Training Sequence Models with Attention
Here are a few practical tips for training sequence-to-sequence models with attention. About fifty percent of the time, they should work every time.
If you have experience training other types of deep neural networks, pretty much all of it applies here. The usual optimization techniques, like annealing the learning rate and gradient clipping, as well as regularization techniques such as dropout and weight decay, are useful. This article focuses on tips that you might not know about, even with experience training other models.
Learning to Condition
The first thing we want to know is whether the model is even working. Sometimes, it’s not so obvious. With sequence-to-sequence models, we typically optimize the conditional probability of the output given the input,
\[ p(Y \mid X) \; = \; \prod_{u=1}^U \; p(y_u \mid y_{\lt u}, X). \]
Here $X = [x_1, \ldots, x_T]$ is the input sequence and $Y = [y_1, \ldots, y_U]$ is the output sequence.
One failure mode for sequence-to-sequence models is that they never learn to condition on the input $X$. In effect they optimize
\[ p(Y) \; = \; \prod_{u=1}^U \; p(y_u \mid y_{\lt u}). \]
This is just a language model over the output sequences. Reasonable learning can actually happen in this case even if the model never learns to condition on $X$. That’s one reason it can sometimes be hard to tell if the model is truly working.
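As a small worked example of the factorization above, the log-probability of an output sequence is the sum of the per-step log-probabilities. This is only a sketch: the function name and the toy probabilities are made up for illustration, and a real model would supply the per-step values.

```python
import math

def sequence_log_prob(step_probs):
    """Sum of log p(y_u | y_<u, X) over the output steps u = 1..U.

    step_probs[u] is the probability the model assigned to the correct
    token at step u (hypothetical values here).
    """
    return sum(math.log(p) for p in step_probs)

# Toy example with three output steps.
step_probs = [0.5, 0.25, 0.5]
log_p = sequence_log_prob(step_probs)
prob = math.exp(log_p)  # same value as the product 0.5 * 0.25 * 0.5
```

The same sum is computed whether or not the per-step probabilities actually depend on $X$, which is why a decent loss value alone does not show that the model is conditioning on the input.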
Visualize Attention: This brings us to our first tip. A great way to tell if the model has learned to condition on the input is to visualize the attention. Usually it’s pretty clear if the attention looks reasonable.
I recommend setting up your model, as early as possible, so that it’s easy to extract the attention vectors. Make a Jupyter notebook, or use some other simple method, to load examples and visualize the attention.
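One very simple way to look at attention, sketched here with no plotting library at all, is a text heat map with one row per output step and one column per input step. The function name, the shading characters, and the toy weights are assumptions for illustration; the only requirement is that the model exposes the attention weights as a matrix.

```python
def render_attention(attention, output_tokens):
    """Return a text heat map of attention weights.

    attention[u][t] is the weight output step u places on input step t,
    assumed to lie in [0, 1]. Darker characters mean larger weights.
    """
    shades = " .:-=+*#%@"  # light to dark
    lines = []
    for token, row in zip(output_tokens, attention):
        cells = "".join(
            shades[min(int(w * len(shades)), len(shades) - 1)] for w in row
        )
        lines.append(f"{token:>6} |{cells}|")
    return "\n".join(lines)

# Hypothetical weights for a 3-step output over a 4-step input.
attention = [
    [0.90, 0.10, 0.00, 0.00],
    [0.05, 0.80, 0.15, 0.00],
    [0.00, 0.05, 0.15, 0.80],
]
heat_map = render_attention(attention, ["w", "x", "y"])
print(heat_map)
```

For tasks with a roughly monotonic alignment, such as speech recognition, a healthy model tends to show a dark band moving from left to right. In a notebook, the same matrix can instead be passed to an image plotting function such as matplotlib’s `imshow`.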
The Inference Gap
Sequence-to-sequence models are trained with teacher forcing. Instead of using the predicted output as the input at the next step, the ground-truth output is used. Without teacher forcing these models are much slower to converge, if they do so at all.
Teacher forcing causes a mismatch between training the model and using it for inference. During training we always know the previous ground-truth output, but at inference we do not. Because of this, it’s not uncommon to see a large gap between error rates on a held-out set evaluated with teacher forcing versus true inference.
Scheduled Sampling: A helpful technique to bridge the gap between training and inference is scheduled sampling.^{1} The idea is simple: with probability $p$, feed the model’s previous predicted output as the next input instead of the ground-truth output. The probability should be tuned for the problem. The typical range for $p$ is between 10% and 40%.
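A minimal sketch of the sampling decision follows. It assumes a hypothetical one-step decoder `step_fn` that maps the previous input token to a prediction; a real decoder would also carry hidden state and attention context, which are left out here. All names are illustrative.

```python
import random

def decode_with_scheduled_sampling(step_fn, start_token, targets, p, rng=None):
    """Run the decoder over the targets, mixing in its own predictions.

    At each step the next input is the model's prediction with
    probability p, and the ground-truth target otherwise. Setting
    p = 0 recovers plain teacher forcing.
    """
    rng = rng or random.Random()
    prev_input = start_token
    predictions = []
    for target in targets:
        prediction = step_fn(prev_input)  # hypothetical one-step decoder
        predictions.append(prediction)
        use_prediction = rng.random() < p
        prev_input = prediction if use_prediction else target
    return predictions

def toy_step(prev_token):
    """Stand-in for a model: predicts the previous token plus one."""
    return prev_token + 1

targets = [10, 20, 30]
teacher_forced = decode_with_scheduled_sampling(toy_step, 0, targets, p=0.0)
free_running = decode_with_scheduled_sampling(toy_step, 0, targets, p=1.0)
```

With `p=0.0` every step sees the ground truth, and with `p=1.0` the decoder only ever sees its own outputs, which matches what happens at inference. Values in between expose the model to some of its own mistakes during training.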
Tune with Inference Rates: There can be a big gap between the teacher-forced error rates and the error rates when properly inferring the output. Also, the correlation between the two metrics may not be perfect. Because of this, I recommend performing model selection and hyperparameter tuning based on the inferred output error rates. If you save the model that performs best on a development set during training, use the inference error rate as the performance measure.
This tip is perhaps most important on smaller datasets, where there is likely more variance in the two metrics. In these cases it can make a big difference. For example, on the phoneme recognition task above we see a 13% relative improvement by taking the model with the best inferred error rate instead of the best teacher-forced error rate.^{2} This can be a key difference if you’re trying to reproduce a baseline.
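As a sketch of what this looks like in a training loop, the snippet below picks the checkpoint by inference error rate. The record layout and the numbers are invented purely for illustration and are not results from any experiment.

```python
def select_best_checkpoint(history, metric="inference_error"):
    """Return the record with the lowest value of the chosen metric."""
    return min(history, key=lambda record: record[metric])

# Made-up development-set error rates, logged once per epoch.
history = [
    {"epoch": 1, "teacher_forced_error": 0.30, "inference_error": 0.42},
    {"epoch": 2, "teacher_forced_error": 0.22, "inference_error": 0.35},
    {"epoch": 3, "teacher_forced_error": 0.20, "inference_error": 0.37},
]

best_by_inference = select_best_checkpoint(history)
best_by_teacher_forcing = select_best_checkpoint(
    history, metric="teacher_forced_error"
)
```

In this toy log the two criteria disagree: the teacher-forced error keeps falling through epoch 3, while the inference error is lowest at epoch 2, so that is the checkpoint to keep.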
Efficiency
One downside to using these models is that they can be quite slow. The attention computation scales as the product of the input and output sequence lengths, i.e. $O(TU)$. If both the input sequence and the output sequence double in length, the amount of computation quadruples.
Bucket by Length: When optimizing a model with a minibatch size greater than 1, make sure to bucket the examples by length. For each batch, we’d like the inputs to all be the same length and the outputs to all be the same length. This won’t usually be possible, but we can at least attempt to minimize the largest length mismatch in any given batch.
One heuristic that works pretty well is to make buckets based on the input lengths. For example, all the inputs with lengths 1 to 5 go in the first bucket. Inputs with lengths 6 to 10 go in the second bucket and so on. Then sort the examples in each bucket by the output length followed by the input length.
Naturally, the larger the training set the more likely you are to have minibatches with inputs and outputs that are mostly the same length.
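The bucketing heuristic above can be sketched in a few lines. The function names and the default bucket width of 5 follow the example in the text but are otherwise assumptions; examples are taken to be (input, output) pairs of sequences.

```python
def bucket_by_length(examples, bucket_width=5):
    """Group (input, output) pairs into buckets by input length.

    Inputs of length 1..bucket_width share the first bucket, the next
    bucket_width lengths share the second, and so on. Within a bucket,
    examples are sorted by output length, then by input length.
    """
    buckets = {}
    for inp, out in examples:
        index = (len(inp) - 1) // bucket_width
        buckets.setdefault(index, []).append((inp, out))
    for bucket in buckets.values():
        bucket.sort(key=lambda pair: (len(pair[1]), len(pair[0])))
    return [buckets[i] for i in sorted(buckets)]

def make_batches(buckets, batch_size):
    """Cut each sorted bucket into consecutive minibatches."""
    return [
        bucket[i:i + batch_size]
        for bucket in buckets
        for i in range(0, len(bucket), batch_size)
    ]

# Toy data: strings stand in for input and output sequences.
examples = [
    ("abc", "xy"), ("abcdefg", "x"), ("a", "xyz"),
    ("abcdef", "xyzw"), ("abcd", "x"),
]
buckets = bucket_by_length(examples)
batches = make_batches(buckets, batch_size=2)
```

Each batch then only needs padding up to the longest input and longest output it contains, which is small when the lengths within a batch are close.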
Striding and Subsampling: When the input and output sequences are long these models can grind to a halt. With long input sequences, a good practice is to reduce the encoded sequence length by subsampling. This is common in speech recognition, for example, where the input can have thousands of timesteps.^{3} You won’t see it as much in word-based machine translation since the input sequences aren’t as long. However, with character-based models subsampling is more common.^{4}
Often subsampling the input doesn’t reduce the accuracy of the model. Even with a minor hit to accuracy though, the speedup in training time can be worth it. When the RNN and attention computations are the bottleneck (which they usually are), subsampling the input by a factor of 4 can make training the model 4 times faster.
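To make the arithmetic concrete, here is a sketch of the simplest form of subsampling, keeping every fourth timestep, together with the $T \times U$ attention cost before and after. The helper names and sequence lengths are illustrative assumptions; in practice the subsampling is often done inside the encoder rather than on the raw input.

```python
def subsample(sequence, factor):
    """Keep every factor-th timestep of the sequence (striding)."""
    return sequence[::factor]

def attention_cost(input_len, output_len):
    """Number of (input step, output step) pairs attention must score."""
    return input_len * output_len

# Toy input with 1000 timesteps and an output of 50 steps.
frames = list(range(1000))
short_frames = subsample(frames, factor=4)

full_cost = attention_cost(len(frames), 50)
reduced_cost = attention_cost(len(short_frames), 50)
speedup = full_cost / reduced_cost
```

Here the input shrinks from 1000 to 250 timesteps, so the number of attention scores drops by a factor of 4. The overall training speedup only approaches that factor when the RNN and attention computations dominate the running time.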
That’s All
As you can see, getting these models to work well requires the right basket of tools. These tips are by no means comprehensive; my aim here is precision over recall. Even so, they certainly won’t generalize to every problem. But, as a few first ideas to try when training and improving a baseline sequence-to-sequence model, I strongly recommend all of them.
If you have other practical tips that you think are critical, or comments on any of the above, I’d love to know more. Please leave a comment below or send me a note.
Footnotes
1. See Bengio et al., 2015.
2. Here’s the code for more details or if you want to reproduce this experiment.
3. See Chan et al., 2015 for an example of this in speech recognition.
4. See Xie et al., 2016 for an example of a character-based model for language correction which subsamples in the encoder.