When training large neural networks, the biggest mini-batch size you can afford may be as small as one. Training can then become very inefficient, or even fail to converge, due to very noisy gradients. Gradient averaging is a technique that lets you increase the effective mini-batch size arbitrarily despite GPU memory constraints. The key idea is to separate computing gradients from applying them: you compute gradients in every iteration, but apply their average less frequently. Let's take a look at a code example (the full code can be found here).
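Before the TensorFlow code, the underlying fact is worth checking: for a loss defined as a mean over samples, the average of gradients computed on equally sized micro-batches equals the full-batch gradient. A minimal NumPy sketch (a hypothetical linear least-squares model; all names here are illustrative, not from the original code):

```python
import numpy as np

# Hypothetical setup: linear model X @ w with mean-squared-error loss.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))   # 100 samples, 3 features
y = rng.normal(size=100)
w = rng.normal(size=3)

def grad(Xb, yb, w):
    # Gradient of mean((Xb @ w - yb)**2) with respect to w.
    return 2.0 * Xb.T @ (Xb @ w - yb) / len(yb)

full = grad(X, y, w)
# Split into 10 micro-batches of 10 samples and average their gradients.
micro = np.mean([grad(X[i:i + 10], y[i:i + 10], w)
                 for i in range(0, 100, 10)], axis=0)
print(np.allclose(full, micro))  # True
```

This is why applying the averaged gradient every N small batches behaves like one update on an N-times-larger batch.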

```python
def setup_train(self, average_gradients=1, lr=1e-3):
    self._average_gradients = average_gradients
    self._gradients = []  # buffer for gradients awaiting averaging
    self._loss_op = tf.losses.softmax_cross_entropy(onehot_labels=self._labels, logits=self._inference_op)
    optimizer = tf.train.AdamOptimizer(lr)  # optimizer choice assumed here
    if average_gradients == 1:
        # This 'train_op' computes gradients and applies them in one step.
        self._train_op = optimizer.minimize(self._loss_op)
    else:
        # Here 'train_op' only applies gradients passed via placeholders stored in 'self._grad_placeholders'.
        grads_and_vars = optimizer.compute_gradients(self._loss_op)
        self._grad_ops = [g for g, _ in grads_and_vars]
        self._grad_placeholders = [tf.placeholder(g.dtype, g.shape) for g in self._grad_ops]
        self._train_op = optimizer.apply_gradients([(p, v) for p, (_, v) in zip(self._grad_placeholders, grads_and_vars)])
```

### Train step

```python
def train(self, session, input_batch, output_batch):
    feed_dict = {
        self._input: input_batch,
        self._labels: output_batch,
        self._training: True
    }
    if self._average_gradients == 1:
        # Compute and apply gradients in a single step.
        loss, _ = session.run([self._loss_op, self._train_op], feed_dict=feed_dict)
    else:
        # Only compute gradients and store them until enough have accumulated.
        loss, grads = session.run([self._loss_op, self._grad_ops], feed_dict=feed_dict)
        self._gradients.append(grads)
        if len(self._gradients) == self._average_gradients:
            # Average the stored gradients per variable and apply them.
            for i, placeholder in enumerate(self._grad_placeholders):
                feed_dict[placeholder] = np.stack([g[i] for g in self._gradients], axis=0).mean(axis=0)
            session.run(self._train_op, feed_dict=feed_dict)
            self._gradients = []
    return loss
```
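The accumulate-then-apply pattern itself is independent of TensorFlow. A framework-free sketch, where `apply_grads` stands in for the `session.run` call that applies gradients (all names here are illustrative, not taken from the original code):

```python
import numpy as np

def run_accumulation(grad_batches, average_gradients, apply_grads):
    stored = []
    for grads in grad_batches:  # one list of per-variable gradients per mini-batch
        stored.append(grads)
        if len(stored) == average_gradients:
            # Average each variable's gradient across the stored iterations,
            # mirroring the np.stack(...).mean(axis=0) line in 'train'.
            mean_grads = [np.stack([g[i] for g in stored], axis=0).mean(axis=0)
                          for i in range(len(grads))]
            apply_grads(mean_grads)
            stored = []

# Usage: two variables, four mini-batches, apply every second batch.
applied = []
batches = [[np.full(2, float(k)), np.full(3, float(k))] for k in range(4)]
run_accumulation(batches, 2, applied.append)
print(len(applied))  # 2
```

Note that updates only happen every `average_gradients` calls, so the loss returned in between reflects a model whose weights have not moved since the last apply.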


### Experiment

Let’s see how gradient averaging affects model performance. To that end, I trained the same network with different mini-batch sizes, but gradients were always applied after the network had seen 100 samples. Moreover, the maximum number of iterations was set so that in each case the same number of weight updates was performed.
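A quick sanity check of the arithmetic behind the runs below: in every configuration, `average_gradients * batch_size` is 100 samples per applied (averaged) gradient, and `iterations / average_gradients` is 1000 actual weight updates.

```python
# (average_gradients, batch_size, iterations), taken from the runs below.
configs = [
    (1, 100, 1000),
    (5, 20, 5000),
    (10, 10, 10000),
    (20, 5, 20000),
]
for avg, batch, iters in configs:
    assert avg * batch == 100    # samples per applied (averaged) gradient
    assert iters // avg == 1000  # number of actual weight updates
print("all configurations are equivalent")
```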

>>> python main.py --average-gradients=1 --batch-size=100 --iterations=1000
mean accuracy (10 runs): 95 +/- 0.172

>>> python main.py --average-gradients=5 --batch-size=20 --iterations=5000
mean accuracy (10 runs): 94.9 +/- 0.325

>>> python main.py --average-gradients=10 --batch-size=10 --iterations=10000
mean accuracy (10 runs): 95.1 +/- 0.36

>>> python main.py --average-gradients=20 --batch-size=5 --iterations=20000
mean accuracy (10 runs): 95.2 +/- 0.303