In machine learning, cross-validation is often performed to select the best hyper-parameters for a model. Once the hyper-parameters are selected, the model is retrained on the combined training and validation sets before being evaluated on the test set. A general workflow for cross-validation looks something like the following.
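As a minimal sketch of that workflow, the snippet below tunes a single hyper-parameter with 5-fold cross-validation, retrains on the combined train+validation data, and evaluates once on the held-out test set. The ridge regression model, synthetic dataset, and alpha grid are illustrative assumptions, not from the article.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import cross_val_score, train_test_split

# Illustrative synthetic regression data (not from the article).
X, y = make_regression(n_samples=200, n_features=10, noise=10.0, random_state=0)

# Hold out a test set; cross-validation happens on the remaining data.
X_trainval, X_test, y_trainval, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0)

# Pick the hyper-parameter with the best mean cross-validation score.
best_alpha, best_score = None, -np.inf
for alpha in [0.01, 0.1, 1.0, 10.0]:
    score = cross_val_score(Ridge(alpha=alpha), X_trainval, y_trainval,
                            cv=5, scoring="neg_mean_squared_error").mean()
    if score > best_score:
        best_alpha, best_score = alpha, score

# Retrain on train + validation data, then evaluate once on the test set.
final_model = Ridge(alpha=best_alpha).fit(X_trainval, y_trainval)
test_mse = mean_squared_error(y_test, final_model.predict(X_test))
print(f"best alpha = {best_alpha}, test MSE = {test_mse:.2f}")
```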
While cross-validation is often used for hyper-parameter tuning, it is also worth doing when no tuning is involved. For simplicity, we leave hyper-parameter tuning out here, but cross-validation still needs to be performed whenever we compare performance across different models.
Cross-validation is done by resampling the dataset, usually without replacement. This article focuses on resampling techniques for cross-validation; these are different from the over/under-sampling methods often used to tackle class imbalance problems. Resampling involves repeatedly drawing samples from a training set and refitting the model of interest on each sample in order to obtain more information about the fitted model.
So, why is cross-validation helpful? In prediction and forecasting, it helps estimate test error rates and assists in model selection. The motivation is to estimate the test error and select models in situations where we do not have a large test dataset. Test errors are estimated by evaluating the model on validation sets to obtain validation errors.
There are 3 broad methods of resampling for cross-validation:
1. Validation set approach (random split)
2. Leave-One-Out Cross-Validation (LOOCV)
3. K-Fold (most common!)
In general, a dataset is split into train, validation and test sets, and the resampling method defines how the validation set is obtained. Resampling lets us generate multiple validation errors from the same model to estimate the test error; these validation errors are then averaged to give an overall cross-validation score. In regression problems, for example, the validation error could be the Mean Squared Error (MSE) computed on the validation set.
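Concretely, if resampling produces K validation sets, the cross-validation score in a regression setting can be written as the average of the K validation MSEs:

$$\mathrm{CV} = \frac{1}{K}\sum_{i=1}^{K}\mathrm{MSE}_i$$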
A good estimator with low bias and low variance gives us higher confidence when selecting the best model based on the lowest mean cross-validation score.
Method 1: Validation Set Approach (Random split)
The validation set approach can be broken down into the following steps:
1. Randomly divide the available data into training and validation set
2. Fit the model using the training set
3. Evaluate the prediction on the validation set
4. Error in validation set approximates error in test set
This can be repeated multiple times to obtain multiple validation errors, which are then averaged to get the cross-validation score, as in the sketch below.
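A minimal sketch of the validation set approach, assuming the same kind of illustrative synthetic regression data as above (the split ratio and number of repeats are arbitrary choices):

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

X, y = make_regression(n_samples=200, n_features=10, noise=10.0, random_state=0)

errors = []
for seed in range(10):
    # Step 1: randomly split into training and validation sets.
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.3, random_state=seed)
    # Steps 2-3: fit on the training set, evaluate on the validation set.
    model = LinearRegression().fit(X_train, y_train)
    errors.append(mean_squared_error(y_val, model.predict(X_val)))

# Step 4: the average validation MSE approximates the test error.
print(f"average validation MSE: {np.mean(errors):.2f}")
```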
Advantages:
1. Simple and easy to implement
2. Computationally efficient
Disadvantages:
1. The validation MSE can vary considerably (high variance) due to the randomness in constructing the training and validation sets
2. It can result in a worse fit, as only a single subset of observations (the training set) is used to fit the model, which tends to over-estimate the test error
Method 2: LOOCV
One specific type of cross-validation is Leave-One-Out Cross-Validation (LOOCV). For a dataset of size n, n training samples can be obtained by leaving one data point out each time. In iteration i, the i-th data point is used as the validation set, and the remaining n-1 points are used to fit the model and obtain a validation error. The LOOCV MSE for the model is the average of the n validation errors.
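A minimal LOOCV sketch using sklearn's LeaveOneOut splitter (the dataset and model are again illustrative):

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import LeaveOneOut

X, y = make_regression(n_samples=50, n_features=5, noise=10.0, random_state=0)

errors = []
for train_idx, val_idx in LeaveOneOut().split(X):
    # Each iteration holds out exactly one point as the validation set.
    model = LinearRegression().fit(X[train_idx], y[train_idx])
    errors.append(mean_squared_error(y[val_idx], model.predict(X[val_idx])))

# The LOOCV MSE is the average over all n held-out errors.
print(f"LOOCV MSE: {np.mean(errors):.2f}")
```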
LOOCV is designed to overcome the randomness of the splits in the first approach, and it maximises the data used for training. However, LOOCV is computationally intensive, as the model has to be fitted n times. Since each fit uses almost the same data, the trained models are highly correlated, which leads to high variance in the test error estimate. This method is well suited to very small datasets, where other methods would further shrink the training set.
Method 3: K-fold Cross-Validation
This method is a trade-off between the validation set approach and LOOCV. In the K-fold approach, we randomly divide the data into K parts (folds), giving K train/validation splits. In iteration i, the i-th fold is used as the validation set and the remaining K-1 folds as the training data; the model is fitted on the training data and the validation error is computed. A minimal sketch follows.
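The sketch below runs 5-fold cross-validation with sklearn's KFold and cross_val_score (the dataset and model are illustrative):

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_val_score

X, y = make_regression(n_samples=200, n_features=10, noise=10.0, random_state=0)

# cross_val_score fits the model K times, holding out one fold each time.
kf = KFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(LinearRegression(), X, y, cv=kf,
                         scoring="neg_mean_squared_error")
print(f"5-fold CV MSE: {-scores.mean():.2f}")
```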
What K to choose? K=5 and K=10 are commonly used; it has been shown empirically that they yield test error estimates that suffer neither from excessively high bias nor from very high variance. When K=n, where n is the total number of samples, K-fold is essentially LOOCV. Validation estimators obtained from K-fold balance the trade-off between the lower bias/higher variance of LOOCV and the lower variance/higher bias of the validation set approach.
Similarly, the CV error for K-fold is simply the average of the K validation errors.
Side note: K-fold cross-validation is used in sklearn's GridSearchCV() and RandomizedSearchCV(), which are used to identify the best hyper-parameters of a model.
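For instance, a short GridSearchCV sketch with K=5 (the estimator and parameter grid are illustrative assumptions):

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

X, y = make_regression(n_samples=200, n_features=10, noise=10.0, random_state=0)

# cv=5 runs 5-fold cross-validation for every candidate alpha.
search = GridSearchCV(Ridge(), param_grid={"alpha": [0.01, 0.1, 1.0, 10.0]},
                      cv=5, scoring="neg_mean_squared_error")
search.fit(X, y)
print(search.best_params_, -search.best_score_)
```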
In regression problems, cross-validation typically measures the MSE. In classification problems, it can use the error rate or other classification performance measures.
The methods above resample without replacement. When the dataset is small, we can mimic the process of obtaining new data through bootstrapping: drawing multiple bootstrap samples (distinct datasets) by repeatedly sampling n observations from the original dataset with replacement. An estimate is computed on each bootstrap sample, yielding multiple bootstrap estimates. The standard error or confidence interval of these bootstrap estimates can then be calculated, approximating the true standard error and confidence interval.
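A minimal bootstrap sketch that estimates the standard error of a sample mean (the statistic, data, and number of resamples are illustrative):

```python
import numpy as np
from sklearn.utils import resample

rng = np.random.default_rng(0)
data = rng.normal(loc=5.0, scale=2.0, size=100)  # original dataset of size n

estimates = []
for _ in range(1000):
    # Draw n observations with replacement to form one bootstrap sample.
    sample = resample(data, replace=True, n_samples=len(data))
    estimates.append(sample.mean())

# The spread of the bootstrap estimates approximates the true standard error.
print(f"bootstrap SE of the mean: {np.std(estimates):.3f}")
```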
The bootstrap is a flexible and powerful statistical tool for quantifying the uncertainty associated with a given estimator or statistical learning method without drawing new data.
Key Learnings:
1) The purpose of a validation set is to obtain a good estimate of the test error using validation errors, which can then be used to perform model selection.
2) K-fold is the most common cross-validation method; it is an in-between option between the validation set approach and LOOCV, balancing the trade-off between high bias and high variance.