What is Overfitting?

When you train a neural network, you have to avoid overfitting. Overfitting is a problem in machine learning and statistics in which a model learns the patterns of a training dataset too well: it explains the training dataset almost perfectly but fails to generalize its predictive power to other sets of data.

To put that another way, an overfitting model will often show extremely high accuracy on the training dataset but low accuracy on data that is collected and run through the model in the future. That’s a quick definition of overfitting, but let’s go over the concept in more detail and look at how overfitting occurs and how it can be avoided.

Understanding “Fit” and Overfitting

Before we delve too deeply into overfitting, it helps to look at the concept of underfitting and “fit” in general. When we train a model, we are trying to develop a framework that can predict the nature, or class, of items in a dataset based on the features that describe those items. A model should be able to explain a pattern within a dataset and predict the classes of future data points based on this pattern. The better the model explains the relationship between the features of the training set and their labels, the better the “fit” of our model.

Blue line represents predictions by a model that is underfitting, while the green line represents a better fit model. Photo: Pep Roca via Wikimedia Commons, CC BY SA 3.0, (https://commons.wikimedia.org/wiki/File:Reg_ls_curvil%C3%ADnia.svg)

A model that poorly explains the relationship between the features of the training data and their labels, and thus fails to accurately classify future data examples, is underfitting the training data. If you were to graph the predictions of an underfitting model against the actual relationship between the features and the labels, the predictions would veer off the mark. On a graph with the actual values of the training set plotted, a severely underfitting model would miss most of the data points by a wide margin. A model with a better fit might cut a path through the center of the data points, with individual data points deviating from the predicted values by only a little.

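Underfitting often occurs when a model is too simple for the data, such as when a linear model is fit to non-linear data, or when the available features do not carry enough information to describe the pattern. Adding more informative features or using a more flexible model will often reduce underfitting. More training data alone usually helps less if the model cannot represent the pattern in the first place.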
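As a rough illustration of a linear model underfitting non-linear data, here is a minimal sketch using NumPy. The quadratic dataset, the noise level, and the helper name `training_mse` are assumptions made up for this example; they do not come from any particular project.

```python
import numpy as np
from numpy.polynomial import Polynomial

rng = np.random.default_rng(0)

# Synthetic data (an assumption for this sketch): a quadratic pattern plus noise.
x = np.linspace(-3.0, 3.0, 60)
y = x**2 + rng.normal(scale=0.5, size=x.shape)

def training_mse(degree):
    """Fit a polynomial of the given degree and return its error on the training data."""
    model = Polynomial.fit(x, y, degree)
    return float(np.mean((model(x) - y) ** 2))

linear_mse = training_mse(1)     # a straight line cannot follow the curve
quadratic_mse = training_mse(2)  # matches the shape of the underlying pattern

print(f"linear model MSE:    {linear_mse:.3f}")
print(f"quadratic model MSE: {quadratic_mse:.3f}")
```

The straight line misses even the training points by a wide margin, which is the signature of underfitting. Giving the model enough flexibility to represent the curve brings the error down sharply.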

So why wouldn’t we just create a model that explains every point in the training data perfectly? Surely perfect accuracy is desirable? Creating a model that has learned the patterns of the training data too well is what causes overfitting. The training data set and other, future datasets you run through the model will not be exactly the same. They will likely be very similar in many respects, but they will also differ in key ways. Therefore, designing a model that explains the training dataset perfectly means you end up with a theory about the relationship between features that doesn’t generalize well to other datasets.

Understanding Overfitting

Overfitting occurs when a model learns the details of the training dataset too well, causing the model to suffer when predictions are made on outside data. This can happen when the model learns not only the meaningful patterns in the dataset but also its random fluctuations, or noise, and places importance on these random or unimportant occurrences.

Overfitting is more likely to occur when nonlinear models are used, as they are more flexible when learning data features. Many nonparametric machine learning algorithms come with settings and techniques that constrain the model’s sensitivity to the data and thereby reduce overfitting. As an example, decision tree models are highly prone to overfitting, but a technique called pruning can be used to remove branches that contribute little to the model’s predictive power, discarding some of the detail that the model has learned.
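To make the idea of pruning concrete, here is a minimal sketch that uses scikit-learn's cost-complexity pruning, controlled by the `ccp_alpha` parameter of `DecisionTreeClassifier`. This is only one form of pruning, and the synthetic dataset and the value `ccp_alpha=0.01` are arbitrary assumptions chosen for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic, deliberately noisy classification data (an assumption for this sketch).
X, y = make_classification(
    n_samples=400, n_features=10, n_informative=4, flip_y=0.2, random_state=0
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0
)

# An unconstrained tree keeps splitting until it explains every training point.
full_tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# Cost-complexity pruning removes the branches that add the least value.
# The value 0.01 is an arbitrary choice, not a recommended setting.
pruned_tree = DecisionTreeClassifier(random_state=0, ccp_alpha=0.01).fit(
    X_train, y_train
)

for name, tree in (("full", full_tree), ("pruned", pruned_tree)):
    print(
        f"{name:>6}: leaves={tree.get_n_leaves()}, "
        f"train acc={tree.score(X_train, y_train):.2f}, "
        f"test acc={tree.score(X_test, y_test):.2f}"
    )
```

The unpruned tree fits the training set perfectly, noise included, while the pruned tree is much smaller. In practice the pruning strength would be tuned on held-out data rather than fixed in advance.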

If you were to graph out the predictions of the model on X and Y axes, you would have a line of prediction that zigzags back and forth, which reflects the fact that the model has tried too hard to fit all the points in the dataset into its explanation.
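The zigzagging prediction line can be reproduced with a small experiment. The sketch below assumes a made-up dataset (a sine curve plus noise) and fits it with a modest polynomial and with one flexible enough to pass through every training point. The degrees, noise level, and sample sizes are illustrative assumptions.

```python
import numpy as np
from numpy.polynomial import Polynomial

rng = np.random.default_rng(0)

def true_fn(x):
    """Underlying pattern the models are trying to learn (assumed for this sketch)."""
    return np.sin(2 * np.pi * x)

# Twelve noisy training points and a larger set of "future" data.
x_train = np.linspace(0.0, 1.0, 12)
y_train = true_fn(x_train) + rng.normal(scale=0.3, size=x_train.shape)
x_test = np.linspace(0.0, 1.0, 200)
y_test = true_fn(x_test) + rng.normal(scale=0.3, size=x_test.shape)

train_mse, test_mse = {}, {}
for degree in (3, 11):
    model = Polynomial.fit(x_train, y_train, degree)
    train_mse[degree] = float(np.mean((model(x_train) - y_train) ** 2))
    test_mse[degree] = float(np.mean((model(x_test) - y_test) ** 2))
    print(
        f"degree {degree:>2}: train MSE={train_mse[degree]:.4f}, "
        f"test MSE={test_mse[degree]:.4f}"
    )
```

A degree-11 polynomial has enough freedom to pass through all twelve training points, so its training error is essentially zero. To do so it has to swing sharply between the points, and that is what hurts it on data it has not seen.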

Controlling Overfitting

When we train a model, we ideally want the model to make no errors. When the model’s performance converges towards making correct predictions on all the data points in the training dataset, the fit is becoming better. A model with a good fit is able to explain almost all of the training dataset without overfitting.

As a model trains, its performance improves over time. The model’s error rate decreases as training proceeds, but only up to a certain point. The point at which the model’s error on the test set begins to rise again is typically the point at which overfitting sets in. To get the best fit for a model, we want to stop training at the point of lowest loss on the test set, before the error starts increasing again. The optimal stopping point can be found by graphing the model’s loss on both the training set and the test set throughout training and stopping when the test loss is lowest. However, one risk with this method of controlling overfitting is that specifying the endpoint for training based on test performance means that the test data becomes somewhat included in the training procedure, and it loses its status as purely “untouched” data.
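The stopping rule described above can be sketched in a few lines. This example assumes the loss on a held-out set (ideally a separate validation set rather than the final test set) has been recorded after each epoch. The function name `early_stopping_epoch`, the `patience` parameter, and the loss values are hypothetical and chosen for illustration.

```python
def early_stopping_epoch(held_out_losses, patience=3):
    """Return (best_epoch, best_loss), stopping once the held-out loss has
    failed to improve for `patience` consecutive epochs."""
    best_loss = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0

    for epoch, loss in enumerate(held_out_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # the loss has been rising or flat for too long

    return best_epoch, best_loss

# Hypothetical held-out losses: they fall at first, then rise as overfitting sets in.
held_out_losses = [0.90, 0.70, 0.55, 0.48, 0.45, 0.47, 0.50, 0.56, 0.63]
best_epoch, best_loss = early_stopping_epoch(held_out_losses, patience=3)
print(f"stop and keep the model from epoch {best_epoch} (loss {best_loss})")
```

Waiting a few epochs before stopping, rather than stopping at the first increase, guards against halting on a temporary bump in the loss curve.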

There are several ways to combat overfitting. One is to use a validation dataset in addition to the test set and to track the model’s performance on the validation set, rather than the test set, during training. This keeps your test dataset unseen. Another is to use a resampling technique, which estimates how accurate the model will be on unseen data by repeatedly training and evaluating it on different portions of the data. A popular resampling method is k-fold cross-validation. This technique divides your data into k subsets, or folds. The model is trained on all but one fold and evaluated on the fold that was held out, and the process is repeated until every fold has served once as the held-out set. The average performance across the folds estimates how the model will perform on outside data.
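Here is a minimal sketch of k-fold cross-validation written with NumPy alone, so that each step is visible. The helper `k_fold_mse`, the synthetic sine dataset, and the polynomial models being compared are assumptions made for this example. Libraries such as scikit-learn provide ready-made equivalents.

```python
import numpy as np
from numpy.polynomial import Polynomial

def k_fold_mse(x, y, degree, k=5, seed=0):
    """Return the held-out mean squared error of each of the k folds."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(x))    # shuffle before splitting
    folds = np.array_split(indices, k)   # k roughly equal subsets

    scores = []
    for i in range(k):
        val_idx = folds[i]               # this fold is held out
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        model = Polynomial.fit(x[train_idx], y[train_idx], degree)
        errors = model(x[val_idx]) - y[val_idx]
        scores.append(float(np.mean(errors ** 2)))
    return scores

# Synthetic data (an assumption for this sketch): a sine curve plus noise.
rng = np.random.default_rng(1)
x = rng.uniform(0.0, 1.0, 40)
y = np.sin(2 * np.pi * x) + rng.normal(scale=0.3, size=x.shape)

cv_mse = {}
for degree in (1, 3, 12):
    fold_scores = k_fold_mse(x, y, degree)
    cv_mse[degree] = float(np.mean(fold_scores))
    print(f"degree {degree:>2}: mean held-out MSE over 5 folds = {cv_mse[degree]:.3f}")
```

Because every data point is used for evaluation exactly once, the averaged score is a less optimistic estimate than the training error. Comparing that score across models is one way to spot a model that is too simple or too flexible.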

Cross-validation is one of the best ways to estimate a model’s accuracy on unseen data, and when it is combined with a validation dataset, overfitting can often be kept to a minimum.

Blogger and programmer with specialties in Machine Learning and Deep Learning topics. Daniel hopes to help others use the power of AI for social good.