Seoul AI

Seoul AI is the largest international artificial intelligence community in Seoul. A group of professionals from different fields meets in the heart of Seoul to discuss AI and implement various AI applications.

Kaggle - How to use pretrained Keras models?

Martin Kersner

23 January 2018

Training deep learning models from scratch is usually not a good way to approach a new problem, especially if you want to obtain reasonable baseline results fast. One solution is to use pretrained models, which are available for every major deep learning framework (e.g. TensorFlow, Keras), and either fine-tune them for your particular task or train a separate model on the outputs of the pretrained model.

Kaggle offers kernels for competing in challenges without the need to install anything on your computer. However, the preinstalled packages do not include models with pretrained weights. beluga came up with a simple solution for including pretrained models in any Kaggle kernel: five pretrained Keras models were uploaded to Kaggle as a dataset. The following table shows the ImageNet accuracy of each available pretrained model.

Model          Top-1 accuracy   Top-5 accuracy
Xception       0.790            0.945
VGG16          0.715            0.901
VGG19          0.727            0.910
ResNet50       0.759            0.929
InceptionV3    0.788            0.944

If you decide to use a pretrained Keras model on Kaggle, you first have to add the Keras Pretrained Models “dataset” using the Add Data Source button in the Input Files field at the top of the kernel page, and then add some necessary boilerplate code. Keras looks for model weights in the ~/.keras/models directory, which does not exist when the kernel starts, so we have to create it.

# mkdir -p ~/.keras/models
from pathlib import Path

cache_dir = Path.home() / '.keras'
if not cache_dir.exists():
    cache_dir.mkdir()
models_dir = cache_dir / 'models'
if not models_dir.exists():
    models_dir.mkdir()
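The same directory can also be created with a single `pathlib` call: `parents=True` creates `~/.keras` along the way, and `exist_ok=True` makes re-running the kernel cell harmless. This is a minimal sketch; the helper name `ensure_keras_models_dir` and its optional `home` argument are hypothetical and exist only so the function can be pointed at a scratch directory.

```python
from pathlib import Path


def ensure_keras_models_dir(home=None):
    """Create <home>/.keras/models if it is missing and return its path.

    Hypothetical helper: `home` defaults to the current user's home
    directory, which is where Keras looks for cached model weights.
    """
    base = Path(home) if home is not None else Path.home()
    models_dir = base / '.keras' / 'models'
    # parents=True also creates ~/.keras; exist_ok=True avoids an error on reruns
    models_dir.mkdir(parents=True, exist_ok=True)
    return models_dir
```

In a kernel you would simply call `ensure_keras_models_dir()` with no arguments.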

Next, we have to copy the pretrained model files to ~/.keras/models. For example, to use Xception, run the code below.

!cp ../input/keras-pretrained-models/xception* ~/.keras/models/
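The `!cp` line relies on the notebook's shell escape. If you prefer plain Python (for example, in a script kernel), a rough equivalent uses `glob` and `shutil`. This is a sketch, not part of the original kernel: the helper `copy_pretrained_weights` is hypothetical, and it assumes the weight files in the dataset start with the model's name, as the `xception*` pattern above implies.

```python
import shutil
from pathlib import Path


def copy_pretrained_weights(prefix, src_dir, dst_dir):
    """Copy every file in src_dir whose name starts with prefix into dst_dir.

    Hypothetical helper mirroring `cp src_dir/prefix* dst_dir/`.
    Returns the names of the copied files in sorted order.
    """
    src_dir, dst_dir = Path(src_dir), Path(dst_dir)
    dst_dir.mkdir(parents=True, exist_ok=True)
    copied = []
    for src in sorted(src_dir.glob(prefix + '*')):
        if src.is_file():  # skip any matching subdirectories
            shutil.copy(src, dst_dir / src.name)
            copied.append(src.name)
    return copied
```

Usage in a kernel would look like `copy_pretrained_weights('xception', '../input/keras-pretrained-models', Path.home() / '.keras' / 'models')`.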

And that’s all! Now we can import xception, build the model, and predict.

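from keras.applications import xception

# Build Xception with ImageNet weights; Keras finds the copied files in ~/.keras/models
xception_model = xception.Xception(weights='imagenet')
# X_train: a batch of preprocessed 299x299 RGB images
#xception_model.predict(X_train, batch_size=32, verbose=1)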
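The `predict` call expects `X_train` to be a 4-D batch of 299×299 RGB images scaled the way Xception was trained. In Keras this scaling is done by `xception.preprocess_input`, which maps pixel values from [0, 255] to [-1, 1]. The NumPy sketch below mirrors that step so you can check shapes and value ranges before calling `predict`; the function name `make_xception_batch` is hypothetical, and in a real kernel you would normally just call `xception.preprocess_input`.

```python
import numpy as np


def make_xception_batch(images, size=299):
    """Stack HxWx3 images into one batch scaled to [-1, 1].

    Hypothetical helper that mirrors xception.preprocess_input:
    x / 127.5 - 1 maps 0 -> -1 and 255 -> 1.
    """
    batch = np.stack([np.asarray(img, dtype=np.float32) for img in images])
    # Xception's default input shape is (299, 299, 3)
    if batch.shape[1:] != (size, size, 3):
        raise ValueError(f'expected images of shape {(size, size, 3)}, got {batch.shape[1:]}')
    return batch / 127.5 - 1.0
```

The resulting array can be passed to `xception_model.predict`, and the 1000-class output decoded with `xception.decode_predictions`.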