Batch Normalization Tensorflow Keras Example
Machine learning is such an active field of research that you’ll often see research papers referenced in the documentation of libraries. In this article we’ll cover batch normalization, which was introduced by Ioffe and Szegedy. If you’re the kind of person who likes to get their information directly from the source, check out the link to their paper below.
Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
_Training Deep Neural Networks is complicated by the fact that the distribution of each layer’s inputs changes during…_ (arxiv.org)
Batch normalization is used to stabilize and perhaps accelerate the learning process. It does so by applying a transformation that maintains the mean activation close to 0 and the activation standard deviation close to 1.
At a high level, backpropagation modifies the weights in order to lower the value of the cost function. However, before we can understand the reasoning behind batch normalization, it’s critical that we grasp the actual mathematics underlying backpropagation.
To make the problem simpler, we will assume we have a neural network consisting of two layers, each with a single neuron.
We can express the output of each neuron using the following formulas:
- L = the layer in the neural network
- w = the weight with which the outgoing edges from the neuron are multiplied
- a = output from the neuron in the previous layer (the value of the incoming edge)
- σ = the activation function
- b = the output from the bias neuron (the value of the incoming edge)
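Under the standard conventions for a single-neuron layer, these symbols combine as follows. The superscript notation and the name z for the value that enters the activation function are assumptions made here for illustration.

```latex
\begin{aligned}
z^{L} &= w^{L} a^{L-1} + b^{L} \\
a^{L} &= \sigma\left(z^{L}\right)
\end{aligned}
```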
A typical example of a cost function is mean squared error. For an individual sample, we subtract the actual value (i.e. y) from the predicted value and square the result to account for instances when the predicted value is greater or lower than the actual value.
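Written out for a single sample, and assuming the prediction is the output a^L of the last neuron and that the cost is called C (both names are assumptions for illustration), the squared error described above is:

```latex
C = \left(a^{L} - y\right)^{2}
```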
As we mentioned previously, we modify the weights in order to minimize the cost function. If we plotted the cost in relation to an individual weight, the cost would be at its lowest at the bottom of the parabola.
We can get the partial derivative of the cost function with respect to the weight by making use of the chain rule from calculus.
The partial derivative of each of the terms can be expressed as follows.
Notice how we use the derivative of the activation function.
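As a sketch of that derivation, assume the conventional notation z^L = w^L a^(L-1) + b^L for the value entering the activation function, a^L = σ(z^L) for the neuron’s output, and C = (a^L - y)^2 for the squared-error cost of one sample. These names are assumptions for illustration. The chain rule and its three factors then read:

```latex
\begin{aligned}
\frac{\partial C}{\partial w^{L}}
  &= \frac{\partial z^{L}}{\partial w^{L}}
     \cdot \frac{\partial a^{L}}{\partial z^{L}}
     \cdot \frac{\partial C}{\partial a^{L}} \\
\frac{\partial z^{L}}{\partial w^{L}} &= a^{L-1} \\
\frac{\partial a^{L}}{\partial z^{L}} &= \sigma'\left(z^{L}\right) \\
\frac{\partial C}{\partial a^{L}} &= 2\left(a^{L} - y\right)
\end{aligned}
```

The middle factor, σ'(z^L), is the derivative of the activation function, so whenever it is close to 0 the whole product is close to 0 as well.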
If we use a sigmoid function for our activation function, then, if z (the output of the neuron prior to the activation function) is very large or very small, the derivative will be approximately 0. As a consequence, when we go to compute the gradient and update the weights, the change will be so small that the model won’t improve. This is known as the vanishing gradient problem.
By normalizing the output of the neuron before it enters the activation function, we can ensure it remains close to 0, where the derivative is highest.
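To see how quickly the sigmoid derivative shrinks away from 0, here is a minimal NumPy sketch. The helper names `sigmoid` and `sigmoid_derivative` and the sample values of z are chosen for illustration only.

```python
import numpy as np


def sigmoid(z):
    """Logistic sigmoid: 1 / (1 + e^(-z))."""
    return 1.0 / (1.0 + np.exp(-z))


def sigmoid_derivative(z):
    """Derivative of the sigmoid: sigmoid(z) * (1 - sigmoid(z))."""
    s = sigmoid(z)
    return s * (1.0 - s)


# Illustrative pre-activation values, from very negative to very positive
zs = np.array([-10.0, -2.0, 0.0, 2.0, 10.0])
grads = sigmoid_derivative(zs)

for z, g in zip(zs, grads):
    print(f"z = {z:6.1f}   sigmoid'(z) = {g:.6f}")
```

The derivative peaks at 0.25 when z = 0 and drops below 0.0001 at z = ±10, which is why keeping z near 0 keeps the weight updates meaningful.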
Random processes in nature tend to follow a bell-shaped curve known as a normal distribution.
The mean is the sum of all the data points divided by the total number of points. Increasing the mean shifts the center of the bell-shaped curve to the right, and decreasing the mean shifts it to the left. The standard deviation (the square root of the variance), on the other hand, describes how far the samples deviate from the mean. Increasing the standard deviation widens the curve.
In order to normalize the data, we subtract the mean and divide by the standard deviation.
No matter the data we’re working with, after normalizing it, the mean will be equal to 0 and the standard deviation will be equal to 1.
Note: This is the same as saying it ensures the variance is equal to 1 since the standard deviation is equal to the square root of the variance.
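As a quick check of that claim, the sketch below normalizes a small made-up sample with NumPy. The numbers in `data` are arbitrary and only serve as an example.

```python
import numpy as np

# Arbitrary example sample (mean 5, standard deviation 2)
data = np.array([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])

mean = data.mean()
std = data.std()

# Subtract the mean, then divide by the standard deviation
normalized = (data - mean) / std

print("before:", mean, std)
print("after: ", normalized.mean(), normalized.std())
```

After the transformation the mean is 0 and the standard deviation is 1, up to floating-point rounding.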
Suppose we built a neural network with the goal of classifying grayscale images. The intensity of every pixel in a grayscale image varies from 0 to 255. Prior to entering the neural network, every image is flattened into a one-dimensional array. Then, each pixel is fed into one neuron of the input layer. If the output of each neuron is passed to a sigmoid function, then nearly every nonzero value (i.e. 1 to 255) will be squashed to a number close to 1. Therefore, it’s common to normalize the pixel values of each image before training. Batch normalization, on the other hand, is used to apply normalization to the output of the hidden layers.
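The saturation is easy to reproduce. The sketch below applies a sigmoid directly to a few pixel intensities, assuming for simplicity a weight of 1 and a bias of 0, and then to the same pixels divided by 255. Dividing by 255 is just one simple rescaling picked for illustration.

```python
import numpy as np


def sigmoid(z):
    """Logistic sigmoid: 1 / (1 + e^(-z))."""
    return 1.0 / (1.0 + np.exp(-z))


# A few illustrative grayscale intensities in the range 0 to 255
pixels = np.array([0.0, 1.0, 5.0, 50.0, 255.0])

raw = sigmoid(pixels)             # sigmoid applied to raw intensities
scaled = sigmoid(pixels / 255.0)  # sigmoid applied to intensities in [0, 1]

print("raw:   ", np.round(raw, 4))
print("scaled:", np.round(scaled, 4))
```

With raw intensities, everything from about 5 upward is indistinguishable from 1, while the rescaled inputs stay in the range where the sigmoid still responds to changes in its input.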
Let’s take a look at how we can go about implementing batch normalization in Python.
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
    plt.style.use('dark_background')

    from keras.models import Sequential
    from keras.preprocessing.image import ImageDataGenerator
    from keras.layers import BatchNormalization
    from keras.layers import Conv2D, MaxPooling2D, Dense, Flatten
    from keras.datasets import cifar10
    from keras.utils import normalize, to_categorical
The CIFAR-10 dataset consists of 60,000 32×32 pixel color images divided into 10 classes. The classes and their standard associated integer values are listed below.
- 0: airplane
- 1: automobile
- 2: bird
- 3: cat
- 4: deer
- 5: dog
- 6: frog
- 7: horse
- 8: ship
- 9: truck
Prior to training our model, we normalize the inputs for the same reasons listed above and encode the labels.
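    (X_train, y_train), (X_test, y_test) = cifar10.load_data()

    # keras.utils.normalize rescales the data to unit L2 norm along the given axis
    X_train = normalize(X_train, axis=1)
    X_test = normalize(X_test, axis=1)

    # One-hot encode the integer class labels
    y_train = to_categorical(y_train)
    y_test = to_categorical(y_test)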
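To make the label encoding concrete: `to_categorical` turns each integer class into a one-hot vector. The NumPy snippet below is a minimal sketch that mimics that result on four made-up labels. It is an illustration of the idea, not the library’s implementation.

```python
import numpy as np

# Hypothetical labels shaped (n, 1), the way CIFAR-10 labels are loaded
labels = np.array([[6], [9], [0], [3]])
num_classes = 10

# Row i of the identity matrix is the one-hot vector for class i
one_hot = np.eye(num_classes)[labels.ravel()]

print(one_hot.shape)  # (4, 10)
print(one_hot[0])     # 1.0 at index 6 (frog), 0.0 everywhere else
```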
In order to improve our model’s ability to generalize, we will randomly shear, flip and zoom in/out of the images.
    train_datagen = ImageDataGenerator(
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True
    )

    train_generator = train_datagen.flow(
        X_train,
        y_train,
        batch_size=32
    )
We set the number of steps using the following equation but we could have used any arbitrary value.
    # X_train.shape is a tuple; its first entry is the number of training images
    steps = X_train.shape[0] // 64
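For a sense of scale, the arithmetic can be worked through with the 50,000 images in the CIFAR-10 training split and the batch size of 32 used by the generator. The variable names below are illustrative.

```python
num_train_images = 50_000  # size of the CIFAR-10 training split
batch_size = 32            # batch size passed to the generator

# Number of images divided by 64, rounded down
steps = num_train_images // 64

# Each step draws one batch from the generator
images_per_epoch = steps * batch_size

print(steps)             # 781
print(images_per_epoch)  # 24992
```

With these settings, each epoch sees roughly half of the training images, which is consistent with the remark that the step count is an arbitrary choice.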
We define a function to build models with and without the use of batch normalization as well as the activation function of our choice.
    def build_model(batch_normalization, activation):
        """Build a small CNN, optionally with batch normalization after each conv layer."""
        model = Sequential()

        # First convolutional block
        model.add(Conv2D(32, 3, activation=activation, padding='same', input_shape=(32, 32, 3)))
        if batch_normalization:
            model.add(BatchNormalization())
        model.add(Conv2D(32, 3, activation=activation, padding='same', kernel_initializer='he_uniform'))
        if batch_normalization:
            model.add(BatchNormalization())
        model.add(MaxPooling2D())

        # Second convolutional block
        model.add(Conv2D(64, 3, activation=activation, padding='same', kernel_initializer='he_uniform'))
        if batch_normalization:
            model.add(BatchNormalization())
        model.add(Conv2D(64, 3, activation=activation, padding='same', kernel_initializer='he_uniform'))
        if batch_normalization:
            model.add(BatchNormalization())
        model.add(MaxPooling2D())

        # Classifier head: one output per CIFAR-10 class
        model.add(Flatten())
        model.add(Dense(128, activation=activation, kernel_initializer='he_uniform'))
        model.add(Dense(10, activation='softmax'))

        return model
To highlight the benefits of batch normalization, we’re going to train two models, one with and one without it, and compare their performance.
    sig_model = build_model(batch_normalization=False, activation='sigmoid')
We use rmsprop as our optimizer and categorical crossentropy as our loss function since we’re trying to predict classes.
    sig_model.compile(
        optimizer='rmsprop',
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
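For intuition about the loss, here is a minimal NumPy sketch of categorical cross-entropy on one-hot labels. The function below and its clipping constant are assumptions for illustration, not Keras’s own code, and the three-class predictions are made up.

```python
import numpy as np


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over samples of -sum(y_true * log(y_pred))."""
    # Clip predictions to avoid log(0); eps is an illustrative choice
    y_pred = np.clip(y_pred, eps, 1.0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))


# One sample whose true class is index 2
y_true = np.array([[0.0, 0.0, 1.0]])

good_guess = np.array([[0.1, 0.2, 0.7]])  # most probability on the true class
bad_guess = np.array([[0.7, 0.2, 0.1]])   # most probability on a wrong class

print(categorical_crossentropy(y_true, good_guess))
print(categorical_crossentropy(y_true, bad_guess))
```

The loss only looks at the probability assigned to the true class, so it is small when that probability is high and grows quickly as it approaches 0.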
Next, we train our model.
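    # fit_generator is the generator-based training call in older Keras releases;
    # recent releases accept generators directly in fit()
    sig_history = sig_model.fit_generator(
        train_generator,
        steps_per_epoch=steps,
        epochs=10,
        validation_data=(X_test, y_test)
    )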
We can plot the training and validation accuracy and loss at each epoch by using the history variable returned by the fit function.
    loss = sig_history.history['loss']
    val_loss = sig_history.history['val_loss']
    epochs = range(1, len(loss) + 1)

    plt.plot(epochs, loss, 'y', label='Training loss')
    plt.plot(epochs, val_loss, 'r', label='Validation loss')
    plt.title('Training and validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()
Next, we take the same steps as before only this time we apply batch normalization.
    sig_norm_model = build_model(batch_normalization=True, activation='sigmoid')

    sig_norm_model.compile(
        optimizer='rmsprop',
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    sig_norm_history = sig_norm_model.fit_generator(
        train_generator,
        steps_per_epoch=steps,
        epochs=10,
        validation_data=(X_test, y_test)
    )
As you can see, the training loss and training accuracy curves are much smoother, and the model achieves significantly better results than the one trained without batch normalization.
    loss = sig_norm_history.history['loss']
    val_loss = sig_norm_history.history['val_loss']
    epochs = range(1, len(loss) + 1)

    plt.plot(epochs, loss, 'y', label='Training loss')
    plt.plot(epochs, val_loss, 'r', label='Validation loss')
    plt.title('Training and validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()
The vanishing gradient problem refers to how the gradient decreases exponentially as we propagate back to the initial layers. As a consequence, the weights and biases of the initial layers won’t be updated effectively. Given that these initial layers are often crucial to recognizing the core elements of the input data, this can lead to poor accuracy.
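The exponential decay can be illustrated with a rough bound. The sigmoid derivative is at most 0.25, so a gradient that passes back through n sigmoid layers picks up a factor of at most 0.25 to the power n from the activations alone. The sketch below ignores the weights, which is a simplifying assumption, and the function name is made up for illustration.

```python
MAX_SIGMOID_DERIVATIVE = 0.25  # value of sigmoid'(z) at z = 0, its maximum


def gradient_scale_bound(num_layers):
    """Upper bound on the product of num_layers sigmoid derivatives."""
    return MAX_SIGMOID_DERIVATIVE ** num_layers


for num_layers in (1, 2, 5, 10):
    bound = gradient_scale_bound(num_layers)
    print(f"{num_layers:2d} layers: factor <= {bound:.2e}")
```

After ten layers the factor is already below one in a million, even in the best case where every pre-activation sits exactly at 0.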
The simplest solution is to use another activation function, such as ReLU. Otherwise, we can use batch normalization to mitigate the issue by normalizing the input to each activation so that it remains in the Goldilocks zone, where the derivative is not close to 0.
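As a closing illustration, here is a minimal NumPy sketch of the training-time batch normalization transform applied to a mini-batch of pre-activations. It assumes a scale `gamma`, a shift `beta`, and a small constant `eps` for numerical stability, and it leaves out the running averages a real layer keeps for inference. The helper names and the random input are made up for the example.

```python
import numpy as np


def sigmoid_derivative(z):
    """Derivative of the sigmoid: sigmoid(z) * (1 - sigmoid(z))."""
    s = 1.0 / (1.0 + np.exp(-z))
    return s * (1.0 - s)


def batch_norm(z, gamma=1.0, beta=0.0, eps=1e-3):
    """Normalize each feature (column) over the mini-batch (rows)."""
    mean = z.mean(axis=0)
    var = z.var(axis=0)
    z_hat = (z - mean) / np.sqrt(var + eps)
    # gamma and beta are learned in a real layer; fixed here for simplicity
    return gamma * z_hat + beta


# Hypothetical pre-activations: 32 samples, 4 features, far from 0
rng = np.random.default_rng(0)
z = rng.normal(loc=8.0, scale=4.0, size=(32, 4))
z_bn = batch_norm(z)

print("mean sigmoid derivative before:", sigmoid_derivative(z).mean())
print("mean sigmoid derivative after: ", sigmoid_derivative(z_bn).mean())
```

Before normalization most values sit in the flat part of the sigmoid, so the average derivative is small. Afterwards each feature is centered on 0 with a standard deviation close to 1, and the average derivative is much closer to its maximum of 0.25.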