Flower Classification using CNN

Photo by dongkyu lim on Dribbble

We come across many flowers every day, yet we often don’t know their names. When we see a beautiful flower, we may wish our computer or phone could identify it for us. That is the motivation behind this article: classifying flower images.

The main objective of this article is to use a Convolutional Neural Network (CNN) to classify flower images into 10 categories.

DATASET

Kaggle Dataset — https://www.kaggle.com/olgabelitskaya/flower-color-images

The 10 classes in the dataset are:

  1. Phlox
  2. Rose
  3. Calendula
  4. Iris
  5. Leucanthemum maximum (Shasta daisy)
  6. Campanula (Bellflower)
  7. Viola
  8. Rudbeckia laciniata (Goldquelle)
  9. Peony
  10. Aquilegia
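Predictions are easier to read when the model's integer outputs can be turned back into flower names. The sketch below is a small lookup table for that. It assumes that the dataset's integer labels 0 to 9 follow the same order as the list above. The names `CLASS_NAMES` and `label_to_name` are hypothetical helpers, not taken from the original notebook.

```python
# Hypothetical lookup from integer class label to flower name.
# ASSUMPTION: dataset labels 0-9 follow the order of the class list above.
CLASS_NAMES = [
    "Phlox",
    "Rose",
    "Calendula",
    "Iris",
    "Leucanthemum maximum (Shasta daisy)",
    "Campanula (Bellflower)",
    "Viola",
    "Rudbeckia laciniata (Goldquelle)",
    "Peony",
    "Aquilegia",
]
NUM_CLASSES = len(CLASS_NAMES)  # size of the model's output layer


def label_to_name(label: int) -> str:
    """Return the flower name for an integer class label (0-9)."""
    return CLASS_NAMES[label]
```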

IMPORTS

I use TensorFlow to implement the CNN, Matplotlib to plot graphs and display images, and Seaborn to display the heatmap.

MODEL

The model consists of two Conv2D layers with 128 filters each. Each is followed by a MaxPooling layer and a Dropout layer. A GlobalMaxPooling layer then feeds two Dense layers.

I used LeakyReLU as the activation function; plain ReLU would probably give good results as well.

The loss function is categorical cross-entropy and the optimizer is Adam.
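A minimal Keras sketch of a model matching this description and the summary below is shown here. It is a reconstruction under stated assumptions, not the author's code. It assumes 128×128 RGB inputs, 3×3 kernels with "valid" padding, a dropout rate of 0.25, Keras' default LeakyReLU slope and a softmax output. Only the input size and kernel size are implied by the output shapes and parameter counts in the summary; the dropout rate and slope are guesses.

```python
import tensorflow as tf
from tensorflow.keras import layers, models


def build_model(input_shape=(128, 128, 3), num_classes=10, dropout_rate=0.25):
    """Two Conv2D blocks, global max pooling and two Dense layers.

    ASSUMPTIONS: input size, 3x3 kernels, dropout_rate and the default
    LeakyReLU slope are not stated in the article.
    """
    model = models.Sequential([
        layers.Input(shape=input_shape),
        # Block 1: 128 filters, 'valid' padding shrinks 128 -> 126
        layers.Conv2D(128, (3, 3)),
        layers.LeakyReLU(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(dropout_rate),
        # Block 2: 128 filters, 63 -> 61, then pooled to 30
        layers.Conv2D(128, (3, 3)),
        layers.LeakyReLU(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(dropout_rate),
        # Collapse each of the 128 feature maps to its maximum value
        layers.GlobalMaxPooling2D(),
        layers.Dense(512),
        layers.LeakyReLU(),
        layers.Dropout(dropout_rate),
        layers.Dense(num_classes),
        # ASSUMPTION: softmax, which pairs with categorical cross-entropy
        layers.Activation("softmax"),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model
```

Calling `build_model().summary()` should report the same 222,346 parameters as the summary below. The layer names will differ (for example `conv2d` rather than `conv2d_4`) because Keras numbers layers per session.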

Model Architecture:

Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_4 (Conv2D)            (None, 126, 126, 128)     3584      
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 126, 126, 128)     0         
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 63, 63, 128)       0         
_________________________________________________________________
dropout_6 (Dropout)          (None, 63, 63, 128)       0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 61, 61, 128)       147584    
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 61, 61, 128)       0         
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 30, 30, 128)       0         
_________________________________________________________________
dropout_7 (Dropout)          (None, 30, 30, 128)       0         
_________________________________________________________________
global_max_pooling2d_2 (Glob (None, 128)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 512)               66048     
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dropout_8 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_5 (Dense)              (None, 10)                5130      
_________________________________________________________________
activation_2 (Activation)    (None, 10)                0         
=================================================================
Total params: 222,346
Trainable params: 222,346
Non-trainable params: 0
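The numbers in this summary can be checked by hand. The sketch below assumes 128×128 RGB input images and 3×3 kernels, which the article does not state but which reproduce every shape and count above. A Conv2D layer stores one weight per kernel position, input channel and filter, plus one bias per filter. A Dense layer stores a full weight matrix plus one bias per output unit.

```python
def conv2d_params(kernel, in_channels, filters):
    """Square-kernel weights for every (input channel, filter) pair, plus one bias per filter."""
    return kernel * kernel * in_channels * filters + filters


def dense_params(in_units, out_units):
    """Full weight matrix plus one bias per output unit."""
    return in_units * out_units + out_units


def conv_valid(size, kernel):
    """Output width of a stride-1 convolution with 'valid' (no) padding."""
    return size - kernel + 1


def pool(size, pool_size=2):
    """Output width of non-overlapping max pooling (any remainder is dropped)."""
    return size // pool_size


# ASSUMPTION: 128x128 RGB input and 3x3 kernels
s1 = conv_valid(128, 3)  # 126
s2 = pool(s1)            # 63
s3 = conv_valid(s2, 3)   # 61
s4 = pool(s3)            # 30

params = {
    "conv2d_4": conv2d_params(3, 3, 128),    # 3584
    "conv2d_5": conv2d_params(3, 128, 128),  # 147584
    "dense_4": dense_params(128, 512),       # 66048 (input: 128 globally pooled channels)
    "dense_5": dense_params(512, 10),        # 5130
}
total = sum(params.values())  # 222346
```

Almost two thirds of the parameters sit in the second convolution. Global max pooling keeps the first Dense layer small, because it sees 128 values instead of 30 × 30 × 128.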

CALLBACKS

I defined two callbacks:

  • ModelCheckpoint: saves the best model seen during training
  • ReduceLROnPlateau: reduces the learning rate when the monitored metric stops improving
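A sketch of these two callbacks using the standard `tf.keras.callbacks` classes follows. The article does not give their settings. The monitored metrics, the file name `best_model.keras`, `factor=0.5`, `patience=5` and `min_lr=1e-6` are all illustrative assumptions.

```python
import tensorflow as tf


def make_callbacks(checkpoint_path="best_model.keras"):
    """Return [ModelCheckpoint, ReduceLROnPlateau] with illustrative settings."""
    # Keep only the weights from the epoch with the highest validation accuracy.
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path,  # ASSUMED file name
        monitor="val_accuracy",
        save_best_only=True,
        verbose=1,
    )
    # Halve the learning rate if validation loss has not improved for 5 epochs.
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,    # ASSUMED
        patience=5,    # ASSUMED
        min_lr=1e-6,   # ASSUMED lower bound
        verbose=1,
    )
    return [checkpoint, reduce_lr]
```

The returned list is passed to `model.fit(..., callbacks=make_callbacks())`.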

TRAIN

The model is trained with a batch size of 32 for 75 epochs.

Because the dataset is small, I use image augmentation to synthesize additional training samples and train the model on the augmented data.
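The article does not say which augmentations were used. The sketch below is one way to do on-the-fly augmentation with batches of 32. It applies Keras' preprocessing layers (`RandomFlip`, `RandomRotation`, `RandomZoom`) inside a `tf.data` pipeline, possibly a different mechanism from the original notebook. The specific transforms, their strengths and the scaling of pixels to [0, 1] are assumptions.

```python
import numpy as np
import tensorflow as tf

# ASSUMED augmentation settings; the notebook's exact transforms are not given.
AUGMENTATION_LAYERS = [
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1),  # up to +/-10% of a full turn
    tf.keras.layers.RandomZoom(0.1),
]


def augment(images):
    """Apply each random transform in turn (active only during training)."""
    for layer in AUGMENTATION_LAYERS:
        images = layer(images, training=True)
    return images


def make_train_dataset(images, labels, batch_size=32, num_classes=10):
    """Shuffle, batch and augment (images, integer labels) on the fly."""
    x = images.astype("float32") / 255.0  # ASSUMPTION: scale pixels to [0, 1]
    y = tf.keras.utils.to_categorical(labels, num_classes)  # one-hot for categorical cross-entropy
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.shuffle(len(x)).batch(batch_size)
    # Augmenting after batching means every epoch sees new random variants.
    ds = ds.map(lambda xb, yb: (augment(xb), yb), num_parallel_calls=tf.data.AUTOTUNE)
    return ds.prefetch(tf.data.AUTOTUNE)
```

Training then comes down to a call such as `model.fit(train_ds, validation_data=(x_val, y_val), epochs=75, callbacks=callbacks)`. Here `model`, `callbacks`, the validation images `x_val` (scaled the same way) and the one-hot labels `y_val` are assumed to exist. The batch size is set in the dataset, so it is not passed to `fit` again.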

RESULT

The model reached a validation accuracy of 80.95%, which is quite decent for such a small dataset. The accuracy and loss curves below show that the model does not overfit much.

Model Accuracy
Model Loss
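Curves like these are typically drawn from the `history` dictionary that Keras' `model.fit` returns. The sketch below assumes the model was compiled with `metrics=["accuracy"]`, so the keys are `accuracy`, `val_accuracy`, `loss` and `val_loss`. The function name `plot_history` is a hypothetical helper, not taken from the original notebook.

```python
import matplotlib

matplotlib.use("Agg")  # off-screen rendering; omit this line inside a notebook
import matplotlib.pyplot as plt


def plot_history(history):
    """Plot training vs. validation accuracy and loss from a History.history dict."""
    epochs = range(1, len(history["accuracy"]) + 1)
    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))

    ax_acc.plot(epochs, history["accuracy"], label="train")
    ax_acc.plot(epochs, history["val_accuracy"], label="validation")
    ax_acc.set_title("Model Accuracy")
    ax_acc.set_xlabel("Epoch")
    ax_acc.set_ylabel("Accuracy")
    ax_acc.legend()

    ax_loss.plot(epochs, history["loss"], label="train")
    ax_loss.plot(epochs, history["val_loss"], label="validation")
    ax_loss.set_title("Model Loss")
    ax_loss.set_xlabel("Epoch")
    ax_loss.set_ylabel("Loss")
    ax_loss.legend()

    fig.tight_layout()
    return fig
```

Usage: `plot_history(model.fit(...).history)`. A widening gap between the train and validation lines would be the sign of overfitting.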

PREDICTIONS

Let’s look at some predictions the model makes on randomly chosen images.

Prediction
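Predicting on randomly chosen images amounts to sampling some indices, running the model, and taking the most probable class for each image. The sketch below does that for any object with a Keras-style `predict(x, verbose=0)` method. `class_names` would be the ten flower names listed in the DATASET section, in label order (an assumption about the label encoding). The function name is hypothetical.

```python
import numpy as np


def predict_random_samples(model, images, class_names, n=5, seed=None):
    """Return (index, predicted name, confidence) for n randomly chosen images."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(images), size=n, replace=False)  # distinct random images
    probs = model.predict(images[idx], verbose=0)          # shape (n, num_classes)
    pred = np.argmax(probs, axis=1)                        # most probable class per image
    return [
        (int(i), class_names[p], float(probs[row, p]))
        for row, (i, p) in enumerate(zip(idx, pred))
    ]
```

Each tuple can then be shown alongside `images[index]` with Matplotlib's `imshow`, titled with the predicted name and confidence. Comparing it with the true label makes misclassified flowers easy to spot.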


Notebook Link: Here

Credit: Rasswanth Shankar
