Finding the Best Hyperparameters for a Deep Learning RNN with Grid Search


Most of the time I’ve spent on Kaggle contests has gone into hyperparameter optimization. It’s a major part of deep learning and fine-tuning if you hope to win, or even to get into the top 10%, in any contest or model-building effort.

Neural networks are difficult to configure, and there are many parameters that need to be set. Worse, neural networks are slow to train, which means you might only test one or two parameter settings a day while you wait for each training run to finish.

I have set up a Colab notebook where you can look at the code I used for a Kaggle contest to determine the best parameters for my training set. It uses the grid search capability of scikit-learn, the Python machine learning library, to tune the hyperparameters of Keras deep learning models.

Keras models can be used in scikit-learn by wrapping them with the KerasClassifier or KerasRegressor class.

To use these wrappers you must define a function that creates and returns your Keras sequential model, then pass this function to the build_fn argument when constructing the KerasClassifier class.
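As a rough illustration of that pattern, here is a minimal sketch, not the exact code from the Colab notebook. The layer choices, the default values, and the names create_model and make_classifier are my own assumptions for the example. It also assumes an older TensorFlow release (before 2.12) where the scikit-learn wrapper still ships with Keras; on newer setups the separate SciKeras package provides a similar class.

```python
def create_model(units=64, dropout=0.2, vocab_size=10000, embed_dim=32):
    """Build and compile a small recurrent binary classifier (illustrative only)."""
    # Imported lazily so TensorFlow is only loaded when a model is actually built.
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Embedding, LSTM, Dropout, Dense

    model = Sequential([
        Embedding(input_dim=vocab_size, output_dim=embed_dim),
        LSTM(units),
        Dropout(dropout),
        Dense(1, activation="sigmoid"),
    ])
    model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
    return model


def make_classifier():
    """Wrap create_model so scikit-learn can treat it as an estimator."""
    # Legacy wrapper location; assumes TensorFlow < 2.12.
    from tensorflow.keras.wrappers.scikit_learn import KerasClassifier

    # Pass the function itself (not an already built model) as build_fn.
    return KerasClassifier(build_fn=create_model, epochs=5, batch_size=64, verbose=0)
```

Because units and dropout are arguments of the build function, the wrapper can set them on each run, which is what lets a grid search vary them later.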

See the model definition in the Colab notebook.

How to Use Grid Search in scikit-learn

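Grid search is a model hyperparameter optimization technique. In scikit-learn it is provided by the GridSearchCV class, which takes a dictionary mapping hyperparameter names to the values to try in its param_grid argument.

For a classifier, accuracy is the score that is optimized by default, but other scores can be specified in the scoring argument of the GridSearchCV constructor.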
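To make the constructor arguments concrete, here is a small sketch of a grid search run. To keep it fast, I have swapped in a LogisticRegression as a stand-in for the wrapped Keras model; the synthetic data and the grid values are placeholders I made up for illustration. The param_grid dictionary maps each hyperparameter name to the list of values to try.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# Small synthetic binary classification problem (placeholder data).
X, y = make_classification(n_samples=200, n_features=10, random_state=0)

# Map each hyperparameter name to the list of values to try.
param_grid = {"C": [0.01, 0.1, 1.0, 10.0]}

grid = GridSearchCV(
    estimator=LogisticRegression(max_iter=1000),  # stand-in for the wrapped Keras model
    param_grid=param_grid,
    scoring="accuracy",  # could be swapped for another scorer such as "f1" or "roc_auc"
    cv=3,
)
grid_result = grid.fit(X, y)

# Report the best mean cross-validated score and the parameters that produced it.
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
```

With a Keras model the call looks the same; you would pass the wrapped classifier as the estimator and use names such as batch_size or epochs as the grid keys.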

By default, the grid search will only use one thread. By setting the n_jobs argument in the GridSearchCV constructor to -1, the process will use all cores on your machine. Depending on your Keras backend, this may interfere with the main neural network training process.

The GridSearchCV process will then construct and evaluate one model for each combination of parameters. Cross-validation is used to evaluate each individual model. The default was 3-fold cross-validation in older scikit-learn releases and is 5-fold from version 0.22 onward, although this can be overridden by specifying the cv argument to the GridSearchCV constructor.
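It helps to count how many training runs a grid implies before launching it, since every combination is trained once per fold. The sketch below does that arithmetic with plain Python; the grid and the fold count are hypothetical values chosen for illustration.

```python
from itertools import product

# Hypothetical grid: three hyperparameters with a few candidate values each.
param_grid = {
    "batch_size": [32, 64, 128],
    "epochs": [5, 10],
    "units": [32, 64],
}
cv_folds = 3  # the value you would pass as cv=...

# Every combination of values is one candidate model configuration.
names = list(param_grid)
combinations = [dict(zip(names, values)) for values in product(*param_grid.values())]

# Each candidate is trained once per fold; the final refit on the
# full data (on by default) adds one more run on top of this.
total_fits = len(combinations) * cv_folds

print(len(combinations), "combinations,", total_fits, "training runs")
# prints: 12 combinations, 36 training runs
```

With slow-to-train networks, this number grows quickly, so it is worth keeping each list of candidate values short.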

In future posts I’ll discuss in more detail how to tune the algorithms.