Understanding and implementing Transformer from scratch
Last Updated: 2020-04-01
Recently I’ve been learning about the Transformer, one of the strongest sequence models at present, by implementing it myself. This is a brief report on what I have done and what I have understood.
The full code is at my GitHub Repo.
This post is organized as follows:

Transformer Overview

Positional Encoding

Multi-Head Attention

Position-wise Feed-Forward Network

Conclusion
1. Transformer Overview
Let us first take an overview of the whole architecture.
Figure 1: Transformer architecture [1]
First, the Inputs shown in the figure, i.e. the source text, go through an Embedding layer, which converts word ids into continuous embeddings. A Positional Encoding, which provides positional information, is then added to those embeddings, and the sum passes through a Dropout layer (drop rate = 0.1 in the setting of [1]). Note that Dropout is also applied to the output of every sublayer.
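The input pipeline just described (embedding lookup, adding the positional encoding, then Dropout) can be sketched in NumPy. This is a minimal illustration under assumptions: the vocabulary size, the random embedding table and the helper names `sinusoidal_table` and `embed_inputs` are hypothetical, and Dropout is written out by hand as inverted dropout (kept units are scaled by 1/(1-p) during training, and nothing happens at inference), which is the variant PyTorch's `nn.Dropout` uses.

```python
import numpy as np

def sinusoidal_table(max_len, d_model):
    """Hypothetical helper: sinusoidal positional encodings, shape (max_len, d_model)."""
    pos = np.arange(max_len)[:, None]                   # (max_len, 1)
    two_i = np.arange(0, d_model, 2)[None, :]           # (1, d_model // 2), even dimension indices 2i
    angle = pos / np.power(10000.0, two_i / d_model)    # (max_len, d_model // 2)
    table = np.zeros((max_len, d_model))
    table[:, 0::2] = np.sin(angle)                      # dimensions 2i
    table[:, 1::2] = np.cos(angle)                      # dimensions 2i + 1
    return table

def embed_inputs(token_ids, emb_table, pe_table, p=0.1, training=True, rng=None):
    """Hypothetical helper: embedding lookup -> add positional encoding -> inverted dropout."""
    x = emb_table[token_ids]                            # (batch, seq_len, d_model)
    x = x + pe_table[: token_ids.shape[1]]              # broadcasts over the batch axis
    if training and p > 0:
        if rng is None:
            rng = np.random.default_rng()
        keep = rng.random(x.shape) >= p                 # keep each unit with probability 1 - p
        x = x * keep / (1.0 - p)                        # rescale so the expected value is unchanged
    return x

rng = np.random.default_rng(0)
vocab_size, d_model = 100, 16                           # hypothetical sizes ([1] uses d_model = 512)
emb_table = rng.normal(size=(vocab_size, d_model))      # stands in for learned embeddings
pe_table = sinusoidal_table(max_len=50, d_model=d_model)
token_ids = np.array([[5, 17, 42, 3], [8, 8, 1, 0]])    # (batch = 2, seq_len = 4)

train_out = embed_inputs(token_ids, emb_table, pe_table, training=True, rng=rng)
eval_out = embed_inputs(token_ids, emb_table, pe_table, training=False)
print(train_out.shape)  # (2, 4, 16)
```

At inference time (`training=False`) the output is exactly embedding plus positional encoding; during training roughly 10% of the entries are zeroed and the rest are scaled up by 1/0.9.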
Next comes the Encoder. The embeddings are encoded by a stack of N identical encoder layers (N = 6 in the setting of [1]). Each encoder layer first applies a Multi-Head Self-Attention sublayer with Dropout; its output is added to the residual (the sublayer's input) and normalized by Layer Normalization. A Position-wise Feed-Forward Network is then applied, again with Dropout, followed by another residual addition and Layer Normalization, which gives the output of this encoder layer.
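The residual-plus-normalization pattern of each encoder layer, LayerNorm(x + Dropout(Sublayer(x))), can be sketched in NumPy. Assumptions, all for illustration only: the two sublayers are stand-ins (a single-head self-attention without learned projections, and a small random ReLU network), Layer Normalization leaves out its learnable gain and bias, and Dropout is omitted as it would be at inference time. The function names are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """Normalize each position's feature vector to zero mean and unit variance
    (the learnable gain and bias of Layer Normalization are left out)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)        # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(x):
    """Stand-in sublayer: single-head self-attention with Q = K = V = x (no projections)."""
    d = x.shape[-1]
    return softmax(x @ x.swapaxes(-1, -2) / np.sqrt(d)) @ x

def make_ffn(d_model, d_hidden, rng):
    """Stand-in sublayer: random two-layer ReLU network (biases left out)."""
    w1 = rng.normal(size=(d_model, d_hidden)) / np.sqrt(d_model)
    w2 = rng.normal(size=(d_hidden, d_model)) / np.sqrt(d_hidden)
    return lambda x: np.maximum(0.0, x @ w1) @ w2

def encoder_layer(x, ffn):
    """Post-LN encoder layer: LayerNorm(x + Sublayer(x)) twice; Dropout omitted (inference)."""
    x = layer_norm(x + self_attention(x))        # sublayer 1: self-attention + residual + LayerNorm
    x = layer_norm(x + ffn(x))                   # sublayer 2: feed-forward + residual + LayerNorm
    return x

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 8))                   # (batch, seq_len, d_model), hypothetical sizes
ffns = [make_ffn(8, 32, rng) for _ in range(6)]  # N = 6 layers as in [1], each with its own weights
out = x
for ffn in ffns:
    out = encoder_layer(out, ffn)
```

Because every layer ends with Layer Normalization, each position of the final output has (approximately) zero mean and unit variance across its features, whatever the stand-in sublayers compute.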
Now let’s see what happens to the Outputs shown in the figure. Here “Outputs” means the outputs generated at previous time steps (during training, the target sequence shifted right by one position). Like the Inputs, they pass through an Embedding layer and a Positional Encoding layer (with Dropout).
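During training, [1] feeds the decoder the target sequence shifted right by one position (Figure 1 labels this input “Outputs (shifted right)”), so that the decoder is trained to predict each target token from the tokens before it. A minimal pure-Python sketch; the start-token id, the target ids and the helper name `shift_right` are hypothetical.

```python
def shift_right(target_ids, bos_id):
    """Decoder input for training: prepend a start token and drop the last target token."""
    return [bos_id] + target_ids[:-1]

BOS = 1                             # hypothetical id of a start-of-sequence token
target = [17, 42, 7, 2]             # hypothetical target ids, ending with an end-of-sequence id 2
decoder_input = shift_right(target, BOS)
print(decoder_input)                # [1, 17, 42, 7]
for step, (inp, label) in enumerate(zip(decoder_input, target)):
    # At this step the decoder has seen decoder_input[0..step] and is trained to output `label`.
    print(step, inp, "->", label)
```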
The Decoder is almost the same as the Encoder, so I will not say too much about it. But note two differences: the first Multi-Head Self-Attention is Masked, so that each position can only attend to itself and earlier positions, and a new Multi-Head Attention between the Encoder and Decoder (queries from the decoder, keys and values from the encoder output) is inserted between the Self-Attention and Feed-Forward Network layers. We will explain them in the following sections.
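The Masked Self-Attention of the decoder can be illustrated with a lower-triangular (causal) mask: position i may only attend to positions 0..i, so the decoder cannot peek at outputs it has not generated yet. A NumPy sketch, assuming the same convention as the attention code in this post (1 = may attend, 0 = masked) and a large negative fill value standing in for minus infinity; `causal_mask` and `masked_softmax` are hypothetical names.

```python
import numpy as np

def causal_mask(seq_len):
    """1 where query position i may attend to key position j (j <= i), else 0."""
    return np.tril(np.ones((seq_len, seq_len), dtype=int))

def masked_softmax(scores, mask):
    """Softmax over the last axis, ignoring positions where mask == 0."""
    scores = np.where(mask == 0, -1e9, scores)         # large negative value acts as -infinity
    scores = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 4))                       # raw scores QK^T / sqrt(d_k) for 4 tokens
weights = masked_softmax(scores, causal_mask(4))
print(np.round(weights, 2))                            # entries above the diagonal are 0
```

The first token can only attend to itself, so its whole weight (1.0) sits on position 0; each later row spreads its weight over the tokens up to and including itself.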
To conclude this section, the Transformer is a large model compared with typical RNNs and CNNs. It is well known for its strong performance on many tasks, but it also has a significant drawback: it contains a lot of layers, sublayers and, of course, parameters. This makes Transformer models expensive to train and deploy, which can make them hard to apply in real-world settings.
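To make the size concrete, here is a back-of-the-envelope parameter count for one encoder layer with the base settings of [1] (d_model = 512, 8 heads with d_k = d_v = 64, feed-forward hidden size 2048). It assumes every linear layer has a bias and each Layer Normalization has a gain and a bias vector, as in the PyTorch modules in this post; embeddings and the decoder are left out.

```python
def linear_params(n_in, n_out):
    """Weights plus bias of a fully connected layer."""
    return n_in * n_out + n_out

d_model, n_head, d_k, d_v, d_ff = 512, 8, 64, 64, 2048   # base settings from [1]

mha = (2 * linear_params(d_model, n_head * d_k)           # W_Q and W_K
       + linear_params(d_model, n_head * d_v)             # W_V
       + linear_params(n_head * d_v, d_model))            # output projection
ffn = linear_params(d_model, d_ff) + linear_params(d_ff, d_model)
layer_norms = 2 * (2 * d_model)                           # two LayerNorms, gain + bias each
encoder_layer = mha + ffn + layer_norms

print(f"{mha:,} {ffn:,} {encoder_layer:,} {6 * encoder_layer:,}")
# 1,050,624 2,099,712 3,152,384 18,914,304
```

So the six encoder layers alone hold about 19M parameters, two thirds of them in the feed-forward networks. Each decoder layer adds a second attention sublayer and a third Layer Normalization, and [1] reports about 65M parameters for the whole base model.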
Let’s go back. As we can see, there are three main building blocks that appear in and around the encoders and decoders:
Positional Encoding
Multi-Head Attention
Position-wise Feed-Forward Network
2. Positional Encoding
As we know, the Transformer is powered by its attention mechanism, which directly connects every pair of words in a sequence, so that all words are modeled jointly. Because any two positions are linked by a single attention step rather than a long recurrent chain, attention also alleviates the exploding/vanishing gradient problem of RNNs, and because it does not process tokens one after another, training can be parallelized across positions.
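However, relying on attention alone raises a significant problem: how do we express the positions of words in a sentence? RNNs are good at handling sequence tasks partly because they are position-sensitive: for example, ‘A student at Tokyo Tech’ and ‘Tech Tokyo at student A’ will generally produce different outputs in an RNN. Self-attention by itself has no such sensitivity: reordering the input words simply reorders the outputs in the same way, so without extra information it cannot tell these two sentences apart.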
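The claim that attention alone ignores word order can be checked numerically. For plain self-attention without positional encoding, reversing the input tokens just reverses the output rows, so each word gets exactly the same output vector in both orders. A NumPy sketch; the random projection matrices stand in for learned weights and are purely illustrative.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(x, w_q, w_k, w_v):
    """Single-head self-attention on one sequence x of shape (seq_len, d)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    return softmax(q @ k.T / np.sqrt(k.shape[-1])) @ v

rng = np.random.default_rng(0)
d = 8
w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))
x = rng.normal(size=(5, d))                 # 5 "word" embeddings, e.g. 'A student at Tokyo Tech'
perm = np.array([4, 3, 2, 1, 0])            # reversed order: 'Tech Tokyo at student A'

out = self_attention(x, w_q, w_k, w_v)
out_perm = self_attention(x[perm], w_q, w_k, w_v)
print(np.allclose(out_perm, out[perm]))     # True: same vectors, only reordered
```

This follows from the algebra: permuting the rows of Q, K and V permutes the rows and columns of the score matrix, and the row-wise softmax followed by multiplication with V undoes the column permutation.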
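Thus, to convey position information, the Transformer adds a positional encoding to the output of the normal embedding layer. [1] uses fixed sinusoidal functions of the absolute position; the authors chose them because, for any fixed offset k, PE_{pos+k} can be written as a linear function of PE_{pos}, which they hypothesized would make it easy for the model to attend by relative positions:
\[PE_{(pos, 2i)}=\sin\left(\frac{pos}{10000^{\frac{2i}{d_{model}}}}\right)\] \[PE_{(pos, 2i+1)}=\cos\left(\frac{pos}{10000^{\frac{2i}{d_{model}}}}\right)\] where pos is the position of the word, and 2i and 2i+1 index the even and odd dimensions of the d_{model}-dimensional encoding.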
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, model_dim, max_seq_len):
        super(PositionalEncoding, self).__init__()
        # A buffer is saved with the module and moved by .to(device), but it is not trained.
        self.register_buffer('pe_table', self._get_pe_table(model_dim, max_seq_len))
    def forward(self, input):
        '''
        input : tensor, encodings of input sequence batch, shape (batch_size, seq_len, model_dim)
        '''
        # pe_table[:seq_len] has shape (seq_len, model_dim) and broadcasts over the batch.
        return input + self.pe_table[:input.shape[1]]
    def _get_pe_table(self, model_dim, max_seq_len):
        '''
        Use the logarithm form 1 / 10000^(2i / model_dim) = exp(-2i / model_dim * log(10000)).
        Assumes model_dim is even.
        '''
        pe_table = torch.zeros((max_seq_len, model_dim))
        pos = torch.arange(0, max_seq_len, dtype=torch.float32)  # shape (max_seq_len,)
        div_term = torch.exp(-torch.log(torch.tensor(10000.)) * torch.arange(0, model_dim, 2, dtype=torch.float32) / model_dim)  # shape (model_dim // 2,)
        pos = pos.unsqueeze(1)            # shape (max_seq_len, 1)
        div_term = div_term.unsqueeze(0)  # shape (1, model_dim // 2)
        pe_table[:, 0::2] = torch.sin(pos * div_term)  # even dimensions 2i
        pe_table[:, 1::2] = torch.cos(pos * div_term)  # odd dimensions 2i+1
        return pe_table
By adding these sinusoids to the embeddings, every position gets a distinct pattern of values, so the position of each word in the sequence is encoded explicitly.
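The property [1] gives as the reason for choosing sinusoids can be verified numerically: for a fixed offset k, each (sin, cos) pair of PE_{pos+k} is a 2×2 rotation of the same pair of PE_{pos}, by an angle that depends only on k and the frequency, not on pos. A NumPy sketch; `pe` and `shift_by` are hypothetical helpers, and `pe` evaluates the two formulas above at arbitrary positions, including positions beyond any fixed table size.

```python
import numpy as np

def pe(positions, d_model):
    """Sinusoidal encodings for an array of positions, shape (len(positions), d_model)."""
    pos = np.asarray(positions, dtype=float)[:, None]
    freq = 1.0 / np.power(10000.0, np.arange(0, d_model, 2) / d_model)  # 1 / 10000^(2i/d_model)
    out = np.zeros((len(pos), d_model))
    out[:, 0::2] = np.sin(pos * freq)
    out[:, 1::2] = np.cos(pos * freq)
    return out

def shift_by(pe_rows, k, d_model):
    """Map PE(pos) to PE(pos + k) by rotating each (sin, cos) pair by the angle k * freq_i."""
    freq = 1.0 / np.power(10000.0, np.arange(0, d_model, 2) / d_model)
    s, c = pe_rows[:, 0::2], pe_rows[:, 1::2]
    cos_k, sin_k = np.cos(k * freq), np.sin(k * freq)
    out = np.empty_like(pe_rows)
    out[:, 0::2] = s * cos_k + c * sin_k   # sin(a + b) = sin a cos b + cos a sin b
    out[:, 1::2] = c * cos_k - s * sin_k   # cos(a + b) = cos a cos b - sin a sin b
    return out

d_model, k = 16, 7
positions = np.arange(0, 100)
shifted = shift_by(pe(positions, d_model), k, d_model)
print(np.allclose(shifted, pe(positions + k, d_model)))  # True
```

The same rotation works for every pos, which is exactly the "linear function of PE_{pos}" that [1] refers to.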
Of course, we can also learn the positional encoding. According to [1], learned positional embeddings produce ‘nearly identical results’ to the sinusoidal version. Still, the sinusoidal encoding has two advantages over a learned one:

It may extrapolate to sequences longer than any seen during training, since the functions can be evaluated at any position, whereas a learned encoding only has entries for the positions covered during training.

It adds no trainable parameters, since the table is computed once rather than learned.
3. Multi-Head Attention
Multi-head attention is the core of the Transformer. Let’s see how it works.
Multi-head attention runs several Scaled Dot-Product Attention operations (heads) in parallel, each of which can be expressed by the equation below:
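\[Attention(Q,K,V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V\] where Q, K, V represent query, key and value respectively, and d_k is the dimension of the keys. Dividing by \sqrt{d_k} keeps the dot products from growing large as d_k increases, which would otherwise push the softmax into regions with extremely small gradients [1].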
Intuitively, each query is compared with all the keys, and the resulting weights decide how much attention to pay to each key’s corresponding value.
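A small worked example of the equation in NumPy: the query is built to be similar to the second key, so most of the attention weight goes to the second value. The numbers are made up purely for illustration, and the function follows the formula directly rather than the PyTorch class below.

```python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V for 2-D arrays: q (n_q, d_k), k (n_k, d_k), v (n_k, d_v)."""
    scores = q @ k.T / np.sqrt(k.shape[-1])                 # (n_q, n_k): query-key similarities
    scores = scores - scores.max(axis=-1, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True) # each row sums to 1
    return weights @ v, weights

k = np.array([[1.0, 0.0, 0.0, 0.0],
              [0.0, 4.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 0.0]])    # 3 keys, d_k = 4
v = np.array([[10.0, 0.0],
              [0.0, 10.0],
              [5.0, 5.0]])              # one value per key, d_v = 2
q = np.array([[0.0, 2.0, 0.0, 0.0]])    # a query that matches the second key

out, weights = scaled_dot_product_attention(q, k, v)
```

Here the scores are [0, 8, 0] / √4 = [0, 4, 0], so the second key receives about 0.96 of the weight, and the output, about [0.27, 9.73], is dominated by the second value [0, 10].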
With this in mind, we can first implement the “single-head” attention:
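import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    def __init__(self, model_dim, key_dim, dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.key_dim = key_dim
    def forward(self, query, key, value, mask=None):
        '''
        query : tensor, shape (batch_size, n_head, seq_len_1, key_dim)
        key : tensor, shape (batch_size, n_head, seq_len_2, key_dim)
        value : tensor, shape (batch_size, n_head, seq_len_2, value_dim)
        mask : tensor, shape (batch_size, 1, seq_len_2) or (batch_size, seq_len_1, seq_len_2);
               1 marks positions that may be attended to, 0 marks masked positions
        '''
        key = key.transpose(-2, -1)  # (batch_size, n_head, key_dim, seq_len_2)
        attn_score = torch.matmul(query / (self.key_dim ** 0.5), key)  # (batch_size, n_head, seq_len_1, seq_len_2)
        # Mask the unneeded/unseeable positions in the key/value sequences.
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # add a head axis so the mask broadcasts over heads
            attn_score = attn_score.masked_fill(mask=mask, value=-1e9)  # large negative value acts as -infinity
            attn_score = F.softmax(attn_score, dim=-1).masked_fill(mask=mask, value=0.0)
        else:
            attn_score = F.softmax(attn_score, dim=-1)  # normalize over the keys
        attn_score = self.dropout(attn_score)
        out = torch.matmul(attn_score, value)  # (batch_size, n_head, seq_len_1, value_dim)
        return out, attn_score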
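You may wonder where Q, K and V come from. They are simply different linear transformations of the inputs, for example of the output of the Positional Encoding layer. In self-attention, Q, K and V are all computed from the same input; in the encoder-decoder attention, Q comes from the decoder and K, V come from the encoder output. The so-called Multi-Head attention performs these transformations and the Scaled Dot-Product Attention several times in parallel, with separate projection weights for each head, then concatenates the outputs of all heads and applies a final linear layer.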
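The projection, head-splitting and concatenation steps can be sketched in NumPy for a single sequence (no batch axis, no Dropout, no mask). The weight matrices are random stand-ins for learned parameters, and all names are hypothetical. The key detail is the reshape: the projected features are split as (seq_len, n_head, d_head) and only then is the head axis moved to the front, so each head works on its own contiguous slice of every token's features. The loop at the end checks this against computing each head separately.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(x_q, x_kv, w_q, w_k, w_v, w_o, n_head):
    """x_q: (len_q, d_model), x_kv: (len_kv, d_model); each w_*: (d_model, d_model)."""
    d_model = x_q.shape[-1]
    d_head = d_model // n_head
    def split(t):  # (seq_len, n_head * d_head) -> (n_head, seq_len, d_head)
        return t.reshape(t.shape[0], n_head, d_head).transpose(1, 0, 2)
    q, k, v = split(x_q @ w_q), split(x_kv @ w_k), split(x_kv @ w_v)
    weights = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d_head))  # (n_head, len_q, len_kv)
    heads = weights @ v                                           # (n_head, len_q, d_head)
    concat = heads.transpose(1, 0, 2).reshape(x_q.shape[0], n_head * d_head)
    return concat @ w_o                                           # (len_q, d_model)

rng = np.random.default_rng(0)
d_model, n_head = 16, 4
w_q, w_k, w_v, w_o = (rng.normal(size=(d_model, d_model)) / 4 for _ in range(4))
dec = rng.normal(size=(3, d_model))    # e.g. decoder states (queries)
enc = rng.normal(size=(5, d_model))    # e.g. encoder output (keys and values)

out = multi_head_attention(dec, enc, w_q, w_k, w_v, w_o, n_head)

# Reference: run each head on its own column slice of the projections, then concatenate.
d_head = d_model // n_head
ref_heads = []
for h in range(n_head):
    cols = slice(h * d_head, (h + 1) * d_head)
    q_h, k_h, v_h = (dec @ w_q)[:, cols], (enc @ w_k)[:, cols], (enc @ w_v)[:, cols]
    ref_heads.append(softmax(q_h @ k_h.T / np.sqrt(d_head)) @ v_h)
ref = np.concatenate(ref_heads, axis=-1) @ w_o
```

Reshaping directly to (n_head, seq_len, d_head) without the transpose would still run, but it would hand each head a mixture of different tokens' features, which is why the order of the reshape and transpose matters.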
Let’s see how this works in the multi-head attention module:
import torch.nn as nn

# Uses the ScaledDotProductAttention module defined above.
class MultiHeadAttention(nn.Module):
    def __init__(self, model_dim, key_dim, value_dim, n_head, dropout=0.1):
        '''
        n_head : default value 8 as set in [1]
        key_dim, value_dim : default value model_dim / n_head = 512 / 8 = 64 as set in [1]
        '''
        super(MultiHeadAttention, self).__init__()
        self.linear_K = nn.Linear(in_features=model_dim, out_features=key_dim * n_head)
        self.linear_Q = nn.Linear(in_features=model_dim, out_features=key_dim * n_head)
        self.linear_V = nn.Linear(in_features=model_dim, out_features=value_dim * n_head)
        self.linear_out = nn.Linear(in_features=value_dim * n_head, out_features=model_dim)
        self.attn = ScaledDotProductAttention(model_dim, key_dim, dropout)
        self.dropout = nn.Dropout(p=dropout)
        self.n_head = n_head
        self.key_dim = key_dim
        self.value_dim = value_dim
    def forward(self, query, key, value, mask=None):
        '''
        query : tensor, shape (batch_size, seq_len_1, model_dim)
        key, value : tensor, shape (batch_size, seq_len_2, model_dim)
        mask : tensor, shape (batch_size, 1, seq_len_2) or (batch_size, seq_len_1, seq_len_2)
        '''
        batch_size, seq_len_1, _ = query.shape
        seq_len_2 = key.shape[1]
        n_head = self.n_head
        key_dim = self.key_dim
        value_dim = self.value_dim
        # Project, split the last dimension into heads, then move the head axis forward:
        # (batch_size, seq_len, n_head * dim) -> (batch_size, n_head, seq_len, dim)
        key = self.linear_K(key).view(batch_size, seq_len_2, n_head, key_dim).transpose(1, 2)
        query = self.linear_Q(query).view(batch_size, seq_len_1, n_head, key_dim).transpose(1, 2)
        value = self.linear_V(value).view(batch_size, seq_len_2, n_head, value_dim).transpose(1, 2)
        # out shape here: (batch_size, n_head, seq_len_1, value_dim)
        out, attn_score = self.attn(query, key, value, mask=mask)
        # Concatenate the heads: (batch_size, seq_len_1, n_head * value_dim)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len_1, n_head * value_dim)
        out = self.dropout(self.linear_out(out))  # Do not forget Dropout!
        return out, attn_score
4. Position-wise Feed-Forward Network
This is the easiest part of the Transformer to understand, as long as you have a basic understanding of deep learning. It is simply a two-layer fully connected network with a ReLU activation in between, applied to each position separately and identically:
\[FFN(x)=\max(0, xW_1+b_1)W_2+b_2\]
import torch.nn as nn
import torch.nn.functional as F

class PositionwiseFeedForward(nn.Module):
    def __init__(self, model_dim, hidden_dim, dropout=0.1):
        '''
        model_dim : default 512 as set in [1]
        hidden_dim : default 2048 as set in [1]
        '''
        super(PositionwiseFeedForward, self).__init__()
        self.linear_1 = nn.Linear(in_features=model_dim, out_features=hidden_dim)
        self.linear_2 = nn.Linear(in_features=hidden_dim, out_features=model_dim)
        self.dropout = nn.Dropout(p=dropout)
    def forward(self, input):
        '''
        input : tensor, shape (batch_size, seq_len, model_dim)
        '''
        # nn.Linear acts on the last dimension, so every position is transformed independently.
        out = F.relu(self.linear_1(input))      # max(0, x W_1 + b_1)
        out = self.dropout(self.linear_2(out))  # (...) W_2 + b_2, followed by Dropout
        return out
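“Position-wise” means the same two-layer network, with the same weights, is applied to every position independently. A NumPy sketch that checks this by comparing the output for a whole batch of sequences with the output computed one position at a time; the sizes and random weights are hypothetical ([1] uses 512 and 2048).

```python
import numpy as np

def ffn(x, w1, b1, w2, b2):
    """FFN(x) = max(0, x W1 + b1) W2 + b2, applied over the last axis."""
    return np.maximum(0.0, x @ w1 + b1) @ w2 + b2

rng = np.random.default_rng(0)
d_model, d_hidden, seq_len = 8, 32, 6           # hypothetical sizes
w1, b1 = rng.normal(size=(d_model, d_hidden)), rng.normal(size=d_hidden)
w2, b2 = rng.normal(size=(d_hidden, d_model)), rng.normal(size=d_model)
x = rng.normal(size=(2, seq_len, d_model))      # (batch, seq_len, d_model)

whole = ffn(x, w1, b1, w2, b2)
one_by_one = np.stack([ffn(x[:, t], w1, b1, w2, b2) for t in range(seq_len)], axis=1)
print(np.allclose(whole, one_by_one))  # True
```

As [1] notes, this is equivalent to two convolutions with kernel size 1: no information flows between positions in this sublayer, which is left entirely to attention.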
5. Conclusion
Now we have looked through the whole architecture of the Transformer and the design of its core modules. In the future, I would like to extend this topic by introducing its application to speech recognition [2] and the well-known pre-trained model BERT [3].
Thanks so much for your time. If you want to learn more about my detailed implementation, please refer to my GitHub repository for the full code: Repo.
6. References
[1] Vaswani, A., et al. “Attention Is All You Need.” arXiv preprint arXiv:1706.03762 (2017).
[2] Dong, L., Xu, S., and Xu, B. “Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for Speech Recognition.” 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), Calgary, AB, 2018, pp. 5884-5888.
[3] Devlin, J., et al. “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.” arXiv preprint arXiv:1810.04805 (2018).