Flower Classification using CNN

Photo by dongkyu lim on Dribbble

We come across numerous flowers on a daily basis, yet we often don’t even know their names. When we see a beautiful flower, many of us wish our computer or phone could identify it. That is the motivation behind this article: classifying flower images.

The main objective of this article is to use a Convolutional Neural Network (CNN) to classify flower images into 10 categories.

DATASET

Kaggle Dataset — https://www.kaggle.com/olgabelitskaya/flower-color-images

The 10 classes in the dataset are:

  1. Phlox
  2. Rose
  3. Calendula
  4. Iris
  5. Leucanthemum maximum (Shasta daisy)
  6. Campanula (Bellflower)
  7. Viola
  8. Rudbeckia laciniata (Goldquelle)
  9. Peony
  10. Aquilegia

IMPORTS

I will be using TensorFlow to implement the CNN, Matplotlib to plot graphs and display images, and Seaborn to display the heatmap.

https://gist.github.com/IamRash-7/c40e0e16806b260c67890398a4672846
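The exact imports are in the gist above; a minimal sketch of what they likely look like (the NumPy import is my assumption, as it is standard alongside these libraries):

```python
import numpy as np                  # array handling (assumed)
import tensorflow as tf             # CNN implementation
import matplotlib.pyplot as plt     # plotting graphs and displaying images
import seaborn as sns               # displaying the heatmap
```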

MODEL

The model consists of 2 Conv2D layers with 128 filters each, each followed by a MaxPooling and a Dropout layer, then a GlobalMaxPooling layer and 2 Dense layers.

I have used LeakyReLU as the activation here; plain ReLU might also give good results.

The loss is categorical crossentropy and the optimizer is Adam.

https://gist.github.com/IamRash-7/8cbc78f6466ec8cdb060f1a163b4696a

Model Architecture:

Layer (type)                 Output Shape              Param #
=================================================================
conv2d_4 (Conv2D)            (None, 126, 126, 128)     3584
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 126, 126, 128)     0
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 63, 63, 128)       0
_________________________________________________________________
dropout_6 (Dropout)          (None, 63, 63, 128)       0
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 61, 61, 128)       147584
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 61, 61, 128)       0
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 30, 30, 128)       0
_________________________________________________________________
dropout_7 (Dropout)          (None, 30, 30, 128)       0
_________________________________________________________________
global_max_pooling2d_2 (Glob (None, 128)               0
_________________________________________________________________
dense_4 (Dense)              (None, 512)               66048
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 512)               0
_________________________________________________________________
dropout_8 (Dropout)          (None, 512)               0
_________________________________________________________________
dense_5 (Dense)              (None, 10)                5130
_________________________________________________________________
activation_2 (Activation)    (None, 10)                0
=================================================================
Total params: 222,346
Trainable params: 222,346
Non-trainable params: 0
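The summary above can be sketched in Keras as follows. The layer order, filter counts, and output shapes follow the summary; the 128×128 input size is inferred (a 3×3 convolution with no padding maps 128 → 126), while the dropout rates and LeakyReLU slope are my assumptions, not from the article:

```python
import tensorflow as tf
from tensorflow.keras import layers, models

# Input is assumed to be 128x128 RGB images: a 3x3 conv with 'valid'
# padding maps 128 -> 126, matching the first output shape above.
model = models.Sequential([
    layers.Conv2D(128, (3, 3), input_shape=(128, 128, 3)),
    layers.LeakyReLU(alpha=0.02),          # slope is an assumption
    layers.MaxPooling2D((2, 2)),           # 126 -> 63
    layers.Dropout(0.25),                  # rate is an assumption

    layers.Conv2D(128, (3, 3)),            # 63 -> 61
    layers.LeakyReLU(alpha=0.02),
    layers.MaxPooling2D((2, 2)),           # 61 -> 30
    layers.Dropout(0.25),

    layers.GlobalMaxPooling2D(),           # (30, 30, 128) -> (128,)

    layers.Dense(512),
    layers.LeakyReLU(alpha=0.02),
    layers.Dropout(0.5),

    layers.Dense(10),
    layers.Activation('softmax'),          # 10-class probabilities
])

model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
```

The parameter counts check out against the summary: the first conv has (3·3·3 + 1)·128 = 3,584 parameters, the second (3·3·128 + 1)·128 = 147,584, and the dense layers (128 + 1)·512 = 66,048 and (512 + 1)·10 = 5,130, totalling 222,346.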

Callbacks

I have defined 2 callbacks:

  • ModelCheckpoint — saves the best model seen so far during training
  • ReduceLROnPlateau — reduces the learning rate when the monitored metric stops improving

https://gist.github.com/IamRash-7/4e224dd4ca64561d03050435d22e9d35
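The exact settings are in the gist above; a sketch of the two callbacks might look like this (the filename, monitored metric, factor, and patience values are illustrative assumptions):

```python
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau

# Save only the weights with the best validation loss seen so far.
checkpoint = ModelCheckpoint('best_model.h5',      # filename is an assumption
                             monitor='val_loss',
                             save_best_only=True,
                             verbose=1)

# Halve the learning rate when validation loss stops improving.
reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                              factor=0.5,          # illustrative value
                              patience=5,          # illustrative value
                              min_lr=1e-6,
                              verbose=1)

callbacks = [checkpoint, reduce_lr]
```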

Train

The model is trained with a batch size of 32 for 75 epochs.

https://gist.github.com/IamRash-7/1cf390ee64211b8feca61a45b87e52a8

As we don’t have a large amount of data, we use image augmentation to synthesize more training samples and train our model with them.

https://gist.github.com/IamRash-7/2a4df5020c4ee93f2f9b9156be33f8da

RESULT

The model reached a validation accuracy of 80.95%, which is quite decent, and the training curves show that it did not overfit much. So it’s quite a good model.

PREDICTIONS

Let’s look at some predictions made by our model on randomly chosen images.

https://gist.github.com/IamRash-7/1719ea0c049268d420ce6a2caa40ade4

Notebook Link: Here

Credit: Rasswanth Shankar