How to Choose the Best Source Model for Transfer Learning
This post assumes that you have a basic familiarity with Transfer Learning for Deep Learning models. If you are unfamiliar with Transfer Learning or would like to brush up on it, I strongly recommend reading A Comprehensive Hands-on Guide to Transfer Learning with Real-World Applications in Deep Learning.
The wonderful thing about Transfer Learning is that you don’t have to train the source model yourself, and in practice you rarely do. Instead, you usually find a pre-trained model with good performance online and continue training it on your dataset of interest. Fortunately, pre-trained models are available online in large numbers: ModelDepot alone has over 50,000 freely accessible pre-trained models, with search functionality to help you find a source model.
Since there are so many pre-trained models to choose from, the key question is: which source model will yield the best performance after Transfer Learning? This question is the subject of research by Lior Rokach and me, recently published in IEEE Access under the title “Source Model Selection for Deep Learning in the Time Series Domain”. The rest of this post summarizes our method for selecting the optimal source model for Transfer Learning.
It is important to note that not every pre-trained model can be used for Transfer Learning on a given target dataset. At the very least, the target samples must be valid inputs to the source network.
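As a minimal illustration of this constraint, the sketch below filters a list of candidate models by whether a target sample's shape (time steps, channels) fits each model's expected input shape. The `Candidate` class, its `input_shape` field, and the convention that `None` marks a variable-length time axis are hypothetical assumptions made for this example; they are not part of ModelDepot or any other model zoo's API.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Candidate:
    """Hypothetical metadata describing a pre-trained source model."""
    name: str
    # Expected input shape (time_steps, channels); None means any length (assumption).
    input_shape: Tuple[Optional[int], int]


def accepts(candidate: Candidate, sample_shape: Tuple[int, int]) -> bool:
    """Return True if a sample of shape (time_steps, channels) can be fed to the model."""
    expected_len, expected_channels = candidate.input_shape
    length, channels = sample_shape
    if channels != expected_channels:
        return False
    # A None time axis (e.g. a convolutional net with global pooling) accepts any length.
    return expected_len is None or expected_len == length


def compatible(candidates: List[Candidate], sample_shape: Tuple[int, int]) -> List[str]:
    """Names of the candidates that can take the target samples as input."""
    return [c.name for c in candidates if accepts(c, sample_shape)]


candidates = [
    Candidate("fixed_128_univariate", (128, 1)),
    Candidate("any_length_univariate", (None, 1)),
    Candidate("fixed_128_trivariate", (128, 3)),
]
print(compatible(candidates, (150, 1)))  # ['any_length_univariate']
```

Only the compatible candidates would then go on to the selection steps described below.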
In our view, Neural Networks are encoders of their input data: in each layer, the data is compressed further, until all that remains is the network’s prediction. We therefore propose selecting the source model according to the quality of the encodings of the target dataset that it produces at a given layer.
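The idea of reading off an encoding at a given layer can be illustrated with a toy fully connected network in NumPy: running the forward pass only up to layer k yields that layer's encoding, and here each layer maps to fewer dimensions than the one before it. This is a sketch under stated assumptions: the layer sizes, random weights, and ReLU activations are arbitrary choices for illustration, whereas a real source model would be a trained network (for time series, typically a convolutional one).

```python
import numpy as np


def make_toy_network(layer_sizes, seed=0):
    """Random weight matrices for a toy MLP; a stand-in for a trained source model."""
    rng = np.random.default_rng(seed)
    return [rng.standard_normal((n_in, n_out)) / np.sqrt(n_in)
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]


def encode(weights, x, layer):
    """Run the forward pass through the first `layer` layers and return that encoding."""
    h = x
    for w in weights[:layer]:
        h = np.maximum(h @ w, 0.0)  # ReLU activation
    return h


# 10 univariate time series of length 128, compressed 128 -> 64 -> 16 -> 4.
weights = make_toy_network([128, 64, 16, 4])
x = np.random.default_rng(1).standard_normal((10, 128))
for k in range(1, len(weights) + 1):
    print(k, encode(weights, x, k).shape)
# prints 1 (10, 64), then 2 (10, 16), then 3 (10, 4)
```

Stopping the loop early is the toy equivalent of truncating the network at the desired layer.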
Our method for selecting the optimal source model can be broken down into four steps:
- Truncate all of the source networks at the desired layer.
- Feed the target data through each truncated network to obtain its “encodings”.
- Calculate how well the encodings cluster the target data using the Mean Silhouette Coefficient. The Silhouette Coefficient ranges from -1 to 1: values near 1 indicate a good clustering, in which each sample is much closer to its own cluster than to the nearest other cluster, and values near -1 indicate a poor one. For a more detailed discussion of the Silhouette Coefficient, see our paper (linked at the end of the post). Other resources include Wikipedia and the original paper, Silhouettes: A graphical aid to the interpretation and validation of cluster analysis, by Peter Rousseeuw.
- Select the source model that yields the highest Mean Silhouette Coefficient as the source model for Transfer Learning.
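Steps 3 and 4 can be sketched in NumPy as follows. For each sample i, a(i) is its mean distance to the other samples in its own cluster, b(i) is its smallest mean distance to the samples of any other cluster, and s(i) = (b(i) - a(i)) / max(a(i), b(i)); the Mean Silhouette Coefficient is the average of s(i) over all samples. This is a minimal sketch, not our published code: it assumes the target dataset's class labels serve as the cluster assignments, uses Euclidean distance, and takes the encodings from steps 1 and 2 as given. The model names and synthetic encodings are hypothetical. scikit-learn's `sklearn.metrics.silhouette_score` computes the same quantity.

```python
import numpy as np


def mean_silhouette(encodings, labels):
    """Mean Silhouette Coefficient of `encodings`, using `labels` as clusters.

    Requires at least two distinct labels.
    """
    X = np.asarray(encodings, dtype=float)
    labels = np.asarray(labels)
    # Pairwise Euclidean distances between all encoded samples.
    dist = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
    classes = np.unique(labels)
    scores = np.zeros(len(X))
    for i in range(len(X)):
        same = labels == labels[i]
        if same.sum() == 1:
            continue  # convention: s(i) = 0 for a singleton cluster
        # a(i): mean distance to the other members of i's own cluster (dist[i, i] is 0).
        a = dist[i, same].sum() / (same.sum() - 1)
        # b(i): smallest mean distance to the members of any other cluster.
        b = min(dist[i, labels == c].mean() for c in classes if c != labels[i])
        denom = max(a, b)
        scores[i] = 0.0 if denom == 0 else (b - a) / denom
    return scores.mean()


def select_source_model(encodings_by_model, labels):
    """Return the model whose target encodings give the highest mean silhouette."""
    scores = {name: mean_silhouette(enc, labels)
              for name, enc in encodings_by_model.items()}
    return max(scores, key=scores.get), scores


# Synthetic encodings of 40 target samples from 2 classes (hypothetical models).
rng = np.random.default_rng(0)
labels = np.repeat([0, 1], 20)
separated = rng.standard_normal((40, 8)) + 5.0 * labels[:, None]  # classes far apart
mixed = rng.standard_normal((40, 8))  # classes overlap
best, scores = select_source_model({"model_a": mixed, "model_b": separated}, labels)
print(best)  # model_b
```

The encodings from "model_b" keep the two target classes far apart, so they score higher and that model is selected.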
Our work is based on the experiments run by Fawaz, Forestier, Weber, Idoumghar, and Muller (2018). They propose a source model selection method called Inter Dataset Similarity (IDS), which selects the source model by measuring the similarity between representative sequences of each class in the source datasets and representative sequences of each class in the target dataset.
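To make the contrast concrete, here is a deliberately simplified sketch of a dataset-similarity score in the spirit of IDS. It is not the exact procedure of Fawaz et al.; as assumptions for illustration, each class's representative sequence is the element-wise mean of its samples, sequences are compared with a basic dynamic time warping (DTW) distance, and the distance between two datasets is the smallest distance over all source/target class pairs. All function and dataset names are hypothetical. Note that, unlike the silhouette-based selection, it needs the samples of each source dataset, not just the trained model.

```python
import numpy as np


def dtw(a, b):
    """Classic O(len(a) * len(b)) dynamic time warping distance between 1-D series."""
    n, m = len(a), len(b)
    cost = np.full((n + 1, m + 1), np.inf)
    cost[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            d = abs(a[i - 1] - b[j - 1])
            cost[i, j] = d + min(cost[i - 1, j], cost[i, j - 1], cost[i - 1, j - 1])
    return cost[n, m]


def class_prototypes(X, y):
    """One representative sequence per class (assumption: the element-wise mean)."""
    return [X[y == c].mean(axis=0) for c in np.unique(y)]


def dataset_distance(source_X, source_y, target_X, target_y):
    """Smallest DTW distance between any source prototype and any target prototype."""
    return min(dtw(p, q)
               for p in class_prototypes(source_X, source_y)
               for q in class_prototypes(target_X, target_y))


def select_by_similarity(sources, target_X, target_y):
    """Pick the source whose dataset is closest to the target (needs the source data)."""
    dists = {name: dataset_distance(X, y, target_X, target_y)
             for name, (X, y) in sources.items()}
    return min(dists, key=dists.get)


t = np.linspace(0, 2 * np.pi, 50)
target_X = np.stack([np.sin(t), np.sin(t) + 0.1, np.cos(t), np.cos(t) - 0.1])
target_y = np.array([0, 0, 1, 1])
sources = {
    "sine_like": (np.stack([np.sin(t + 0.2), np.cos(t + 0.2)]), np.array([0, 1])),
    "ramps": (np.stack([t, -t]), np.array([0, 1])),
}
print(select_by_similarity(sources, target_X, target_y))  # sine_like
```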
We evaluate our method using the UCR Time Series Classification Archive, although it could be applied in other domains as well.
In our paper, we show that our method achieves results comparable to IDS without requiring access to the source datasets. This is a critical difference, because a pre-trained model is often available even when the data it was trained on is not.
To delve further, check out our paper and code!
I would like to thank Rochelle Meiseles for editing this post.