
How to Implement Batch Norm with SWA in TensorFlow?

I am using Stochastic Weight Averaging (SWA) with Batch Normalization layers in TensorFlow 2.2. For Batch Norm I use tf.keras.layers.BatchNormalization. For SWA I use my own code t…

Solution 1:

When must the forward/prediction pass be run? At the end of each mini-batch, end of each epoch, end of all training?

At the end of training. Think of it like this: SWA replaces your final weights with a running average, but the batch norm layers still hold statistics computed with your old weights. So we need to run a forward pass to let them catch up.
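A minimal sketch of that catch-up pass, assuming a Keras `model` whose SWA-averaged weights have already been assigned and a training dataset `train_ds` yielding (x, y) batches (the names are placeholders, not from the original post):

```python
import tensorflow as tf

def update_bn_statistics(model, train_ds):
    # Calling the model with training=True makes each BatchNormalization layer
    # update its moving mean/variance from the current batch; no optimizer step
    # is taken, so the (averaged) weights themselves are left untouched.
    for x_batch, _ in train_ds:
        model(x_batch, training=True)
```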

When the forward pass is run, how are the running mean & stdev values made available to the batch norm layers?

During a normal forward pass (prediction) the running mean and standard deviation will not be updated. So what we actually need to do is to train the network, but not update the weights. This is what the paper refers to when it says to run the forward pass in "training mode".

The easiest way to achieve this (that I know of) is to reset the batch normalization layers and train one additional epoch with the learning rate set to 0.
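A rough sketch of that "extra epoch with learning rate 0" idea, under the assumptions that the model is already compiled, its optimizer uses a plain float learning rate (not a schedule), and the batch norm layers are direct children of the model; `model`, `train_ds`, and `swa_weights` are illustrative names:

```python
import tensorflow as tf

def bn_refresh_epoch(model, train_ds, swa_weights):
    model.set_weights(swa_weights)  # swap in the SWA-averaged weights

    # Reset the moving statistics so they are re-estimated during the extra
    # epoch instead of being mixed with values from the old weights.
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.moving_mean.assign(tf.zeros_like(layer.moving_mean))
            layer.moving_variance.assign(tf.ones_like(layer.moving_variance))

    # With learning rate 0 the optimizer leaves the weights unchanged, but
    # fit() still runs in training mode, so the BN statistics are updated.
    tf.keras.backend.set_value(model.optimizer.learning_rate, 0.0)
    model.fit(train_ds, epochs=1)
```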

Is this process performed magically by the tfa.optimizers.SWA class?

I don't know. But if you are using TensorFlow Keras, I have made this Keras SWA callback that handles it as described in the paper, including the learning rate schedules.
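For reference, a sketch of how the TensorFlow Addons wrapper is typically wired up, based on the tfa documentation rather than on the answer above; whether this alone is enough for batch norm is exactly the open question here, so a catch-up pass like the one sketched earlier may still be needed afterwards (`model` and `train_ds` are placeholders):

```python
import tensorflow as tf
import tensorflow_addons as tfa

base_opt = tf.keras.optimizers.SGD(learning_rate=0.01)
# Start averaging after 10 steps, then update the average every 5 steps.
opt = tfa.optimizers.SWA(base_opt, start_averaging=10, average_period=5)

model.compile(optimizer=opt, loss="sparse_categorical_crossentropy")
model.fit(train_ds, epochs=20)

# Copy the maintained running average into the model's variables; the BN
# statistics can then be re-estimated, e.g. with update_bn_statistics above.
opt.assign_average_vars(model.variables)
```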
