Transfer Learning

August 15, 2022

Long gone are the days in which data practitioners trained machine learning models from scratch themselves. Unless you have a very specific use case, you’re better off leveraging the pre-trained models available in public repositories such as TensorFlow Hub and Hugging Face. The process by which we customize a pre-trained model for a given task is known as transfer learning.

There are two techniques under the umbrella of transfer learning:

  • Feature extraction: Uses the output of a pre-trained model, taken before its final classification layer, as the input features to a new classifier. When training the new model, you keep the weights of the pre-trained model fixed.
  • Fine-tuning: Unfreezes some or all of the layers of the base model and trains their weights jointly with those of the newly added classifier layers.
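
To make the distinction concrete, here is a minimal pure-Python sketch of feature extraction. The "base model" is a hypothetical fixed function standing in for a pre-trained network (`frozen_base` and `BASE_WEIGHTS` are made up for illustration, not part of any library). Only the weights of a new logistic-regression head are updated, so the base weights are identical before and after training. Fine-tuning would differ by also updating `BASE_WEIGHTS`.

```python
import math

# Hypothetical stand-in for a pre-trained network: a fixed linear map + ReLU.
# These weights are never updated below, i.e. the base model is frozen.
BASE_WEIGHTS = [[1.0, -1.0], [0.5, 0.5], [-1.0, 2.0]]  # 3 features from 2 inputs


def frozen_base(x):
    """Return the 'feature vector' for input x using the frozen weights."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in BASE_WEIGHTS]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def train_head(inputs, labels, lr=0.5, epochs=200):
    """Train a new logistic-regression head on top of the frozen features."""
    head = [0.0] * len(BASE_WEIGHTS)  # trainable weights of the new classifier
    bias = 0.0
    features = [frozen_base(x) for x in inputs]  # computed once: the base never changes
    for _ in range(epochs):
        for f, y in zip(features, labels):
            p = sigmoid(sum(w * fi for w, fi in zip(head, f)) + bias)
            grad = p - y  # gradient of binary cross-entropy w.r.t. the logit
            head = [w - lr * grad * fi for w, fi in zip(head, f)]
            bias -= lr * grad
    return head, bias


def predict(head, bias, x):
    f = frozen_base(x)
    return sigmoid(sum(w * fi for w, fi in zip(head, f)) + bias)


inputs = [[2.0, 0.0], [1.5, 0.5], [0.0, 2.0], [-0.5, 1.5]]
labels = [1, 1, 0, 0]
head, bias = train_head(inputs, labels)
print([round(predict(head, bias, x)) for x in inputs])  # [1, 1, 0, 0]
```

Because the features are computed once and never change, training only touches the small head, which is the same reason feature extraction is cheap in the real model.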

In this article, we will walk through an example of feature extraction. If you’d like to learn more about fine-tuning, you can check out this article on the subject.


We’ll be using code from the “TF Hub for TF2: Retraining an image classifier” tutorial, since I couldn’t make it any clearer myself.

I recommend using Google Colab to train and run the model since it already has all the dependencies installed. You can check which GPU your runtime has access to as follows:

! nvidia-smi

We begin by importing the following libraries:

import tensorflow as tf  
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import numpy as np

We will use the MobileNetV2 architecture trained on the ImageNet dataset as the base model.

model_handle = "[]("

In TensorFlow Hub, you can download both classification and feature vector models. The feature vector models are specifically designed for transfer learning. As the name implies, their output is a feature vector (here, the 1,280 values produced by the network’s final pooling layer) rather than scores for the ImageNet classes. Thus, you don’t need to remove the output layer of the model when doing transfer learning; you can simply add trainable layers on top.


The model was trained using images of size 224 by 224 pixels, so we resize our inputs to match. We also set a sensible default value for the batch size.

IMAGE_SIZE = (224, 224)
BATCH_SIZE = 16  # a common default; adjust to fit your GPU memory

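Keras provides a utility function, tf.keras.utils.get_file, that downloads a file from a URL and caches it locally. We use it to download the flowers dataset, which is hosted on Google Cloud Storage.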

data_dir = tf.keras.utils.get_file(  

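Using the images we just downloaded, we construct the training and validation datasets. All images are rescaled from [0, 255] to [0, 1]. Optionally, we can augment the training data by transforming the images slightly (rotating, translating, zooming, and flipping). Augmentation is disabled by default; set do_data_augmentation to True to enable it.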

def build_dataset(subset):
  return tf.keras.preprocessing.image_dataset_from_directory(
      data_dir,
      validation_split=.20,
      subset=subset,
      label_mode="categorical",
      # Seed needs to be provided when using validation_split and shuffle = True.
      # A fixed seed is used so that the validation set is stable across runs.
      seed=123,
      image_size=IMAGE_SIZE,
      # batch_size=1 so that cardinality() below counts individual images.
      batch_size=1)

train_ds = build_dataset("training")
class_names = tuple(train_ds.class_names)
train_size = train_ds.cardinality().numpy()
train_ds = train_ds.unbatch().batch(BATCH_SIZE)
train_ds = train_ds.repeat()

normalization_layer = tf.keras.layers.Rescaling(1. / 255)
preprocessing_model = tf.keras.Sequential([normalization_layer])
do_data_augmentation = False #@param {type:"boolean"}
if do_data_augmentation:
  # RandomRotation's factor is a fraction of a full turn: 40 / 360 allows up to ±40 degrees.
  preprocessing_model.add(
      tf.keras.layers.RandomRotation(40 / 360))
  preprocessing_model.add(
      tf.keras.layers.RandomTranslation(0, 0.2))
  preprocessing_model.add(
      tf.keras.layers.RandomTranslation(0.2, 0))
  # Like the old tf.keras.preprocessing.image.ImageDataGenerator(),
  # image sizes are fixed when reading, and then a random zoom is applied.
  # If all training inputs are larger than image_size, one could also use
  # RandomCrop with a batch size of 1 and rebatch later.
  preprocessing_model.add(
      tf.keras.layers.RandomZoom(0.2, 0.2))
  preprocessing_model.add(
      tf.keras.layers.RandomFlip(mode="horizontal"))
train_ds = train_ds.map(lambda images, labels:
                        (preprocessing_model(images), labels))

val_ds = build_dataset("validation")
valid_size = val_ds.cardinality().numpy()
val_ds = val_ds.unbatch().batch(BATCH_SIZE)
# Validation images are only rescaled, never augmented.
val_ds = val_ds.map(lambda images, labels:
                    (normalization_layer(images), labels))
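
Rescaling(1. / 255) simply maps 8-bit pixel values from [0, 255] to [0, 1], and a horizontal flip reverses each row of pixels. The pure-Python sketch below illustrates both on a tiny hypothetical grayscale image stored as nested lists; `rescale`, `flip_horizontal`, and `rotation_factor_to_degrees` are illustrative helpers, not Keras functions. The last helper reflects the fact that Keras's RandomRotation layer expresses its factor as a fraction of a full turn (2π) rather than in degrees.

```python
def rescale(image, scale=1.0 / 255):
    """Multiply every pixel by `scale`, as Rescaling(1. / 255) does."""
    return [[pixel * scale for pixel in row] for row in image]


def flip_horizontal(image):
    """Mirror the image left-to-right by reversing each row of pixels."""
    return [list(reversed(row)) for row in image]


def rotation_factor_to_degrees(factor):
    """Maximum rotation in degrees for a RandomRotation factor given as a fraction of 2*pi."""
    return factor * 360.0


image = [[0, 128, 255],
         [255, 64, 0]]  # hypothetical 2x3 grayscale image with 8-bit pixel values

print(rescale(image))                         # values now lie in [0, 1]
print(flip_horizontal(image))                 # [[255, 128, 0], [0, 64, 255]]
print(rotation_factor_to_degrees(40 / 360))   # approximately 40 degrees
```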

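We freeze the weights of the MobileNetV2 model by setting do_fine_tuning to False. This flag is passed as the trainable argument of hub.KerasLayer when we build the model.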

do_fine_tuning = False

On top of the base model, we add a Dropout layer (to reduce overfitting) and a Dense layer with one output per class (for classification).

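model = tf.keras.Sequential([
    # Explicitly define the input shape so the model can be properly
    # loaded by the TFLiteConverter
    tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE + (3,)),
    hub.KerasLayer(model_handle, trainable=do_fine_tuning),
    tf.keras.layers.Dropout(rate=0.2),
    # No activation: the Dense layer outputs logits
    tf.keras.layers.Dense(len(class_names))
])
model.summary()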
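
The Dropout layer is only active during training. The sketch below mimics the "inverted dropout" behaviour described in the Keras documentation: during training, each activation is zeroed with probability `rate` and the survivors are scaled by 1 / (1 - rate) so the expected activation is unchanged; at inference time the input passes through untouched. The `dropout` helper and the 0.2 rate are illustrative assumptions, not Keras code.

```python
import random


def dropout(activations, rate, training, rng):
    """Inverted dropout: zero each value with probability `rate` during training
    and scale the survivors by 1 / (1 - rate); identity at inference time."""
    if not training or rate == 0.0:
        return list(activations)
    keep_scale = 1.0 / (1.0 - rate)
    return [0.0 if rng.random() < rate else a * keep_scale for a in activations]


rng = random.Random(0)       # fixed seed so the example is reproducible
features = [1.0] * 10_000    # stand-in for a large feature vector of ones

train_out = dropout(features, rate=0.2, training=True, rng=rng)
infer_out = dropout(features, rate=0.2, training=False, rng=rng)

print(sum(v == 0.0 for v in train_out) / len(train_out))  # roughly 0.2
print(sum(train_out) / len(train_out))                    # roughly 1.0
print(infer_out == features)                              # True
```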

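As the summary below shows, only 6,405 of the 2,264,389 parameters are trainable: the 1,280 × 5 weights plus the 5 biases of the Dense layer. Because the base model is frozen, its weights receive no gradient updates, which makes training much faster than training the whole network.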

Model: "sequential_1"  
 Layer (type)                Output Shape              Param #     
 keras_layer (KerasLayer)    (None, 1280)              2257984     
 dropout (Dropout)           (None, 1280)              0           
 dense (Dense)               (None, 5)                 6405        
Total params: 2,264,389  
Trainable params: 6,405  
Non-trainable params: 2,257,984  
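
The trainable-parameter count follows directly from the shape of the Dense layer: it connects each of the 1,280 features to each of the 5 classes and adds one bias per class. The arithmetic below reproduces the numbers in the summary; `dense_params` is just a helper written for this check.

```python
def dense_params(n_inputs, n_outputs):
    """Weights (n_inputs * n_outputs) plus one bias per output unit."""
    return n_inputs * n_outputs + n_outputs


feature_dim = 1280          # output size of the MobileNetV2 feature vector
num_classes = 5             # output size of the Dense layer in the summary
frozen_params = 2_257_984   # parameters in the frozen KerasLayer

trainable = dense_params(feature_dim, num_classes)
total = frozen_params + trainable  # Dropout adds no parameters

print(trainable)                                               # 6405
print(total)                                                   # 2264389
print(f"{trainable / total:.2%} of parameters are trainable")  # 0.28% of parameters are trainable
```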

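We compile the model using stochastic gradient descent (SGD) with a learning rate of 0.005 and momentum of 0.9. For the loss, we use categorical cross-entropy with a label smoothing of 0.1. Since the Dense layer has no activation function, its outputs are logits, so we set from_logits=True.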

model.compile(
  optimizer=tf.keras.optimizers.SGD(learning_rate=0.005, momentum=0.9),
  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
  metrics=["accuracy"])
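
To see what from_logits=True and label_smoothing=0.1 do, here is a pure-Python sketch of the loss for a single example. With K classes, label smoothing replaces the one-hot target y with y * (1 - ε) + ε / K, and from_logits=True means a softmax is applied to the raw Dense outputs inside the loss (computed here with a numerically stable log-softmax). The helper names and the logits are made up for illustration; this is not the Keras implementation.

```python
import math


def smooth_labels(one_hot, epsilon):
    """Label smoothing: y * (1 - epsilon) + epsilon / K."""
    k = len(one_hot)
    return [y * (1.0 - epsilon) + epsilon / k for y in one_hot]


def log_softmax(logits):
    """Numerically stable log-softmax (subtract the max before exponentiating)."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def categorical_crossentropy_from_logits(one_hot, logits, epsilon=0.0):
    """Cross-entropy between (optionally smoothed) targets and softmax(logits)."""
    targets = smooth_labels(one_hot, epsilon)
    return -sum(t * lp for t, lp in zip(targets, log_softmax(logits)))


one_hot = [0, 0, 0, 1, 0]             # true class is index 3 of 5
logits = [0.5, -1.0, 0.2, 3.0, 0.1]   # hypothetical raw Dense-layer outputs

print([round(t, 2) for t in smooth_labels(one_hot, 0.1)])  # [0.02, 0.02, 0.02, 0.92, 0.02]
print(categorical_crossentropy_from_logits(one_hot, logits))
print(categorical_crossentropy_from_logits(one_hot, logits, epsilon=0.1))
```

When the true class already has the highest score, the smoothed loss is larger than the plain one, which nudges the classifier away from becoming over-confident.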

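Since train_ds repeats indefinitely, we need to tell Keras how many steps (batches) make up one epoch. We also define how many validation batches to evaluate at the end of each epoch.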

steps_per_epoch = train_size // BATCH_SIZE  
validation_steps = valid_size // BATCH_SIZE
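
Floor division leaves any final partial batch out of the step count. The hypothetical sizes below (not the actual size of this dataset) show the effect. For the repeated training set this is harmless, because the dataset iterator simply continues into the next epoch; for the validation set, which is not repeated, the images in the final partial batch are not evaluated. `step_counts` is an illustrative helper.

```python
import math


def step_counts(num_examples, batch_size):
    """Return (full steps, images left over, steps if the partial batch were kept)."""
    full_steps = num_examples // batch_size    # what steps_per_epoch computes
    leftover = num_examples % batch_size       # images in the final partial batch
    all_steps = math.ceil(num_examples / batch_size)
    return full_steps, leftover, all_steps


# Hypothetical sizes, chosen only to illustrate the arithmetic.
print(step_counts(1000, 16))   # (62, 8, 63)
print(step_counts(240, 16))    # (15, 0, 15)
```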

Finally, we train the model.

hist = model.fit(
    train_ds, validation_data=val_ds,
    epochs=5, steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps).history

We plot the training and validation loss and accuracy for each epoch.

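for metric in ("loss", "accuracy"):
    plt.figure()
    plt.ylabel(f"{metric.capitalize()} (training and validation)")
    plt.xlabel("Epoch")  # History records one value per epoch
    plt.plot(hist[metric], label="training")
    plt.plot(hist["val_" + metric], label="validation")
    plt.legend()
plt.show()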

As we can see, it didn’t take long for the accuracy of the model to start hovering around 90%.

We can use the model to predict the class of an individual image from the validation set.

x, y = next(iter(val_ds))  
image = x[0, :, :, :]  
true_index = np.argmax(y[0])  
prediction_scores = model.predict(np.expand_dims(image, axis=0))  
predicted_index = np.argmax(prediction_scores)  
print("True label: " + class_names[true_index])  
print("Predicted label: " + class_names[predicted_index])
```
True label: sunflowers
Predicted label: sunflowers
```

As we can see, it correctly classified the image.
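
Because the Dense layer outputs logits, model.predict returns unnormalized scores. np.argmax picks the right class either way, but if you also want a confidence value you can apply a softmax first. The pure-Python sketch below shows both steps. The scores are made up, and `class_names` assumes the five folder names of the flowers dataset in the alphabetical order that image_dataset_from_directory uses by default.

```python
import math

# Assumed class names (alphabetical folder names of the flowers dataset).
class_names = ("daisy", "dandelion", "roses", "sunflowers", "tulips")
prediction_scores = [0.3, -1.2, 0.8, 4.1, 1.0]  # hypothetical logits for one image


def softmax(logits):
    """Convert logits to probabilities (shift by the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value (first one on ties), like np.argmax."""
    return max(range(len(values)), key=values.__getitem__)


probabilities = softmax(prediction_scores)
predicted_index = argmax(prediction_scores)

# Softmax is monotonic, so the argmax of the logits and of the probabilities agree.
assert predicted_index == argmax(probabilities)
print(f"Predicted label: {class_names[predicted_index]} ({probabilities[predicted_index]:.1%})")
```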

Written by Cory Maklin