August 23, 2022
Learn how data augmentation can reduce image classification overfitting & how VisionERA can help craft unique solutions for you.
There is always some valuable information that can be extracted from any data type. Images containing objects can be used to teach machines to identify the image's label or category correctly. Image classification is one of the popular techniques in computer vision, which aims to enable computers to extract meaningful information from digital media like images or videos. This is essential in specialized applications requiring image recognition, such as
Image recognition or Image classification is the basis for all the above applications. Hence, an image classification model with higher accuracy is desired for the success of these applications in real-world scenarios.
Image classification models are typically deep learning models built on specialized neural networks that can analyze and extract the important features of an image. Based on these features, the model classifies the image into a specific category. In general, if we train the model with a sufficient variety and number of images, the model will be better at identifying the correct category.
However, such a large and diverse body of relevant data is hard to find, and acquiring it is expensive and time-consuming even for companies that make the effort. As a result, these models frequently suffer from small datasets and a lack of adequately labeled data. In addition, intra-class variation, differences in scale and illumination, and noise or clutter in the images reduce accuracy because the model cannot label such images correctly. This is likely to occur when the deep learning model is fed real-world images that vary in brightness, contrast, zoom, or rotation.
One way to make a model robust to these differences in the same object is to train it on a dataset expanded with augmented images. Data Augmentation techniques create modified copies of the existing images, which add much-needed variety to the dataset and can improve model generalization.
The primary goal of every machine learning model is to generalize well. In this context, generalization refers to a Machine Learning (ML) model's ability to produce suitable outputs for inputs it has not seen before. In other words, after training on the dataset, the model is expected to produce reliable and accurate results on new data. However, when we evaluate model performance, two problems are commonly encountered while training Machine Learning models: Underfitting and Overfitting. Either of them can degrade the performance of a machine learning model.
In simple terms, underfitting means that the trained model has not captured the underlying pattern in the data, so it makes many incorrect predictions even on the training set. Overfitting, on the other hand, occurs when the trained model fits the training data closely but fails to make accurate predictions on unseen data, i.e., the training accuracy is relatively high, but the validation accuracy is poor. A high training accuracy indicates that the training error is very small, while a poor validation accuracy means the validation error is large. Ideally, neither should be present in a model, although both are often challenging to remove. Underfitting is less common in ML models than overfitting, but it should not be neglected. Before delving into the details of overfitting, let us first understand some key terms used for evaluating model performance.
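The distinction between the two cases can be expressed as a simple rule of thumb on the training and validation accuracies. The sketch below is only an illustration: the function name `diagnose_fit` and the thresholds `min_acc` and `max_gap` are assumptions chosen for this example, not established cut-offs.

```python
def diagnose_fit(train_acc: float, val_acc: float,
                 min_acc: float = 0.80, max_gap: float = 0.10) -> str:
    """Label a trained model from two accuracy values in the range [0, 1].

    min_acc and max_gap are illustrative thresholds (assumptions).
    """
    if train_acc < min_acc:
        # Poor even on the data the model was trained on.
        return "underfitting"
    if train_acc - val_acc > max_gap:
        # Good on training data, much worse on unseen data.
        return "overfitting"
    return "reasonable fit"


# Illustrative accuracy pairs (training, validation).
for train_acc, val_acc in [(0.99, 0.68), (0.55, 0.53), (0.91, 0.88)]:
    print(train_acc, val_acc, "->", diagnose_fit(train_acc, val_acc))
```

In practice the thresholds depend on the task, so the gap between the two accuracies is more informative than either value alone.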
Following are four important concepts essential to understanding the machine learning model performance.
So, when a model is overfitting, its 'bias' is very low and its 'variance' is high; when a model is underfitting, the reverse holds: the bias is high and the variance is low.
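The contrast between a high-bias model and a high-variance model can be seen on a toy regression problem. The sketch below is not an image classifier; it assumes NumPy and fits polynomials of increasing degree to a few noisy samples of a sine curve, a setup chosen purely for illustration. The helper names `noisy_sine` and `mse` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the run is repeatable


def noisy_sine(n_points):
    """Sample n_points from sin(2*pi*x) on [0, 1] with Gaussian noise added."""
    x = rng.uniform(0.0, 1.0, n_points)
    y = np.sin(2 * np.pi * x) + rng.normal(0.0, 0.2, n_points)
    return x, y


def mse(coeffs, x, y):
    """Mean squared error of a fitted polynomial on the points (x, y)."""
    return float(np.mean((np.polyval(coeffs, x) - y) ** 2))


x_train, y_train = noisy_sine(15)   # small training set
x_val, y_val = noisy_sine(200)      # held-out data

results = {}
for degree in (1, 3, 9):
    # Low degree: rigid model (high bias). High degree: flexible model (high variance).
    coeffs = np.polyfit(x_train, y_train, degree)
    results[degree] = (mse(coeffs, x_train, y_train), mse(coeffs, x_val, y_val))
    print(f"degree {degree}: train MSE {results[degree][0]:.3f}, "
          f"validation MSE {results[degree][1]:.3f}")
```

The training error can only shrink as the degree grows, because each higher-degree polynomial contains the lower-degree ones. What separates a good fit from an overfit one is how the validation error compares with the training error.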
The "goodness of fit" is an optimum fit for the model that indicates the model performs well on unknown data. This can be achieved when variance and bias are kept as low as possible. In reality, there is always a slight trade-off between bias and variance that needs to be considered to make the model acceptable. But then, how do we know if a model is overfitting? Let us understand this in the next section.
Overfitting can only be assessed by evaluating a trained model on data it was not trained on. The dataset is therefore split (generally 80:20) into distinct training and testing sets. Next, a learning curve plot is drawn, which shows the model performance on both the training and the validation/test set at each epoch. If the model performs substantially better on the training set than on the test set, it is overfitting the training data. Learning curve plots often show that performance on the training dataset continues to improve, i.e., loss continues to fall or accuracy continues to rise, whereas performance on the validation/test set improves only up to a certain point and then begins to degrade. Training should be stopped when such a pattern is observed in order to avoid overfitting. Having understood overfitting, let us explore some techniques to reduce it.
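The two steps described here, the 80:20 split and stopping once the validation curve starts to degrade, can be sketched in plain Python. The function names, the `patience` parameter, and the loss values below are assumptions made for illustration; deep learning frameworks ship their own early-stopping utilities.

```python
import random


def split_indices(n_samples, train_fraction=0.8, seed=0):
    """Shuffle sample indices and split them into training and test parts."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    cut = round(n_samples * train_fraction)
    return indices[:cut], indices[cut:]


def early_stopping_epoch(val_losses, patience=3):
    """Return (stop_epoch, best_epoch), both 0-indexed.

    Training stops once the validation loss has failed to improve on its
    best value for `patience` consecutive epochs.
    """
    best_loss = float("inf")
    best_epoch = -1
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return epoch, best_epoch
    return len(val_losses) - 1, best_epoch


train_idx, test_idx = split_indices(100)
print(len(train_idx), len(test_idx))

# Made-up validation losses: they improve at first, then degrade.
val_losses = [0.69, 0.62, 0.58, 0.57, 0.59, 0.63, 0.70, 0.78]
stop_epoch, best_epoch = early_stopping_epoch(val_losses)
print(f"stop at epoch {stop_epoch}, keep weights from epoch {best_epoch}")
```

A small patience value guards against stopping on a single noisy epoch while still ending training soon after the validation loss turns upward.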
Almost all image classification models exhibit a tendency to overfit training data. If a classification model appears to be overfitting, here’s what can be done to achieve the goodness of fit. These are a few strategies that can help reduce overfitting and improve the model generalization.
All the above strategies can be used to address the issue of overfitting in classification models. In the next section, let us explore how “Data Augmentation” reduces overfitting.
To understand how Data Augmentation reduces overfitting, we will use the filtered version of the Kaggle dataset 'cats_and_dogs' (original dataset provided by Microsoft) to build an image classifier. There are 2000 training and 1000 testing images for two labels, ‘cats’ and ‘dogs’. Here are some sample images from the dataset.
Since this is a relatively small-sized image dataset, the trained model can experience overfitting. We will build the image classifier using CNN with the following code -
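A minimal sketch of such a classifier is given here, assuming TensorFlow's Keras API. The input resolution (150x150 RGB), the layer sizes, and the optimizer are assumptions made for illustration and should not be read as the exact architecture behind the results reported in this article.

```python
import tensorflow as tf
from tensorflow.keras import layers

IMG_SIZE = 150  # assumed input resolution; not stated in the article


def build_classifier(img_size: int = IMG_SIZE) -> tf.keras.Model:
    """Small CNN for binary (cat vs. dog) classification.

    Expects pixel values already scaled to [0, 1] by the input pipeline.
    """
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(img_size, img_size, 3)),
        # Three convolution + pooling stages extract increasingly abstract features.
        layers.Conv2D(16, 3, padding="same", activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(32, 3, padding="same", activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, padding="same", activation="relu"),
        layers.MaxPooling2D(),
        # Classification head: one sigmoid unit for the two labels.
        layers.Flatten(),
        layers.Dense(512, activation="relu"),
        layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer="adam",
                  loss="binary_crossentropy",
                  metrics=["accuracy"])
    return model


model = build_classifier()
model.summary()
```

With only 2000 training images, a network of this size has far more parameters than examples, which is one reason it can memorize the training set.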
Training the above model for 100 epochs, we get a training accuracy of 100%, while the validation accuracy is 67.6%. After plotting the curves for these two accuracies, as shown in the figure below, we can clearly see that the model is overfitting: the validation accuracy stops increasing after the first few epochs, and the validation loss does not decrease but instead increases.
Next, we will augment the dataset using selected transformations such as width_shift_range, height_shift_range, rotation, zoom, and flipping. Now we can retrain the model with these augmented images. Here are some sample augmented images -
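The names width_shift_range and height_shift_range match arguments of Keras' ImageDataGenerator, so the sketch below assumes that class was used. Every numeric value is an illustrative assumption rather than the setting behind the reported results, and the helper `make_train_flow` is hypothetical. Recent TensorFlow releases mark ImageDataGenerator as deprecated in favour of preprocessing layers, although it remains available.

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# All numeric values below are illustrative assumptions.
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255,        # scale pixel values to [0, 1]
    rotation_range=40,        # random rotation of up to +/- 40 degrees
    width_shift_range=0.2,    # horizontal shift, as a fraction of image width
    height_shift_range=0.2,   # vertical shift, as a fraction of image height
    zoom_range=0.2,           # random zoom between 0.8x and 1.2x
    horizontal_flip=True,     # randomly mirror images left to right
    fill_mode="nearest",      # fill pixels exposed by a shift or rotation
)

# Validation images are only rescaled, never augmented.
val_datagen = ImageDataGenerator(rescale=1.0 / 255)


def make_train_flow(train_dir: str, img_size: int = 150, batch_size: int = 32):
    """Stream augmented batches from a directory with one sub-folder per class."""
    return train_datagen.flow_from_directory(
        train_dir,
        target_size=(img_size, img_size),
        batch_size=batch_size,
        class_mode="binary",
    )
```

Because the transformations are sampled anew for every batch, the model rarely sees exactly the same image twice, which makes memorizing the training set harder. Augmentation is applied to the training data only, so that validation accuracy is still measured on unmodified images.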
The training accuracy reduces to 69.4% while the validation accuracy increases to 71%. Plotting the training and validation accuracies for the model trained on augmented images, we get -
The model generalizes better when trained with augmented images: the training and validation accuracy curves now track each other closely. Also, the training and validation losses keep decreasing (except for one epoch). This indicates that Data Augmentation has helped to reduce overfitting for this image classification model.
In this article, we saw an overview of overfitting and how Data Augmentation can reduce overfitting in an image classification model.
About us: VisionERA is an Intelligent Document Processing (IDP) platform that uses Data Augmentation for Image Classification to handle various types of documents. It can extract and validate data in bulk volumes with minimal intervention. The platform can also be tailored to the requirements of any industry and use case through its custom DIY workflow feature. It is a scalable and flexible platform providing end-to-end document automation for any organization.
Looking for a document processing solution that uses the enhanced capabilities of image classification using deep learning? Set up a demo today or simply send us a query through the contact us page!