Understanding and implementing Transformer from scratch



Last Updated: 2020-04-01

Recently I have been learning about the Transformer, one of the strongest sequence models at present, by implementing it from scratch. This is a brief report on what I have done and what I have understood.

The full code is at my GitHub Repo.

This blog is organized as follows:

  1. Transformer Overview

  2. Positional Encoding

  3. Multi-Head Attention

  4. Position-wise Feed Forward Network

  5. Conclusion

1. Transformer Overview

Let us first have an overview of the whole architecture.

Figure 1: Transformer architecture [1]

First, the Inputs shown in the figure are the source text. They go through an Embedding layer, which converts word ids into continuous embeddings, and then the Positional Encoding, which provides positional information, is added to those embeddings. The sum then passes through a Dropout layer (drop rate = 0.1 in the setting of [1]). Note that Dropout is also applied to the output of every sublayer.
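The input pipeline described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the implementation in my repo: the helper names `embed`, `add_positional` and `dropout` are hypothetical, vectors are plain lists, and the dropout follows the inverted scheme that `torch.nn.Dropout` uses (survivors are scaled by 1 / (1 - p) during training, and the layer is the identity at evaluation time).

```python
import random

def embed(token_ids, table):
    """Look up one embedding row per token id (table is a list of rows)."""
    return [list(table[t]) for t in token_ids]

def add_positional(embeddings, pe_table):
    """Add the positional-encoding row of each position element-wise."""
    return [[e + p for e, p in zip(row, pe_table[pos])]
            for pos, row in enumerate(embeddings)]

def dropout(x, p, training=True, rng=random):
    """Inverted dropout: during training, zero each value with probability p and
    scale the survivors by 1 / (1 - p); at evaluation time, return x unchanged."""
    if not training or p == 0.0:
        return [list(row) for row in x]
    scale = 1.0 / (1.0 - p)
    return [[0.0 if rng.random() < p else v * scale for v in row] for row in x]

# Hypothetical toy setup: a 3-word vocabulary and model_dim = 2.
table = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
pe_table = [[0.0, 1.0], [0.5, 0.5]]
x = dropout(add_positional(embed([2, 0], table), pe_table), p=0.1, rng=random.Random(0))
```

Scaling at training time means nothing has to change at inference: switching `training` off simply returns the embeddings plus positional encodings.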

Then comes the Encoder. The embeddings are encoded by a stack of N (N = 6 in the setting of [1]) identical encoder layers. Each encoder layer first applies Multi-Head Self-Attention with Dropout; its output is added to the residual (the input of the layer) and normalized by Layer Normalization. Then a Position-wise Feed Forward Network is applied, again with Dropout, followed by another residual addition and Layer Normalization, which gives the final output of this encoder layer.
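The "Add & Norm" wiring of an encoder layer amounts to LayerNorm(x + Dropout(Sublayer(x))). Below is a minimal pure-Python sketch of that residual connection, assuming plain lists for vectors; `layer_norm` and `sublayer_connection` are hypothetical helper names, the learned gain and bias are left at their usual initial values of 1 and 0, and `eps` uses the `torch.nn.LayerNorm` default of 1e-5.

```python
import math

def layer_norm(x, eps=1e-5, gamma=None, beta=None):
    """Normalize one vector to zero mean and unit variance, then scale by gamma and
    shift by beta (learned in practice; default 1 and 0, their usual initial values)."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    gamma = [1.0] * n if gamma is None else gamma
    beta = [0.0] * n if beta is None else beta
    return [g * (v - mean) / math.sqrt(var + eps) + b for v, g, b in zip(x, gamma, beta)]

def sublayer_connection(x, sublayer, dropout=lambda seq: seq):
    """Post-norm residual block: LayerNorm(x + Dropout(Sublayer(x))) at every position.
    x is a sequence of vectors; sublayer maps such a sequence to one of the same shape."""
    out = dropout(sublayer(x))
    return [layer_norm([a + b for a, b in zip(xi, oi)]) for xi, oi in zip(x, out)]
```

Under these assumptions, one encoder layer is `sublayer_connection(sublayer_connection(x, self_attention), feed_forward)`, where `self_attention` and `feed_forward` are hypothetical functions mapping a sequence of vectors to a sequence of the same shape.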

Next, let’s see what happens to the Outputs shown in the figure. The word “Outputs” refers to the outputs generated at previous time steps (shifted right by one position). Similarly, they pass through an Embedding layer and a Positional Encoding layer (with Dropout).

The Decoder is almost the same as the Encoder, so I will not say too much about it. But be careful about two differences: the first Multi-Head Self-Attention is Masked, so that a position cannot attend to later positions, and an additional Multi-Head Attention between the Encoder and the Decoder is inserted between the Self-Attention and the Feed Forward Network. We will come back to them in the following sections.
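The masking in the decoder's first self-attention means that position i may only attend to positions 0..i, so a prediction cannot peek at later tokens. Here is a minimal sketch of such a mask in plain Python, using the convention of the attention code in this post (1 = may attend, 0 = masked); `subsequent_mask` and `combine_masks` are hypothetical names.

```python
def subsequent_mask(seq_len):
    """Causal (look-ahead) mask: mask[i][j] == 1 iff query position i may attend to key j <= i."""
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]

def combine_masks(padding_mask, causal_mask):
    """A key is visible only if it is a real token AND not in the future (element-wise AND).
    padding_mask holds one 0/1 flag per key position (1 = real token, 0 = padding)."""
    return [[p & c for p, c in zip(padding_mask, row)] for row in causal_mask]

mask = combine_masks([1, 1, 0], subsequent_mask(3))  # last target position is padding
```

In PyTorch, the same lower-triangular matrix can be built with `torch.tril(torch.ones(seq_len, seq_len))`. In the encoder-decoder attention, the queries come from the decoder while the keys and values come from the encoder output, so only the padding mask of the source sentence is needed there.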

To conclude this section, the Transformer is a large model compared with typical RNNs and CNNs. It is well known for its strong performance on many tasks, but it also has a significant drawback: it contains many layers, sublayers and, of course, parameters. This makes Transformer models expensive to train and deploy, which can make them hard to apply in real-world settings.
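To put "a lot of parameters" into numbers, here is a back-of-the-envelope count for the base setting of [1] (d_model = 512, d_ff = 2048, N = 6). It is a sketch under stated assumptions: every linear layer has a bias (as `nn.Linear` does by default), n_head * key_dim = model_dim, each LayerNorm has a gain and a bias, and the embedding matrices and the output softmax layer are left out. The function names are hypothetical.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of one fully connected layer (nn.Linear with bias=True)."""
    return n_in * n_out + n_out

def transformer_body_params(model_dim=512, hidden_dim=2048, n_layers=6):
    """Parameters of the encoder and decoder stacks, excluding embeddings and the output layer."""
    attention = 4 * linear_params(model_dim, model_dim)  # Q, K, V and output projections, all heads together
    ffn = linear_params(model_dim, hidden_dim) + linear_params(hidden_dim, model_dim)
    layer_norm = 2 * model_dim                            # one gain and one bias per dimension
    encoder_layer = attention + ffn + 2 * layer_norm
    decoder_layer = 2 * attention + ffn + 3 * layer_norm  # masked self-attention + encoder-decoder attention
    return encoder_layer, decoder_layer, n_layers * (encoder_layer + decoder_layer)

enc, dec, total = transformer_body_params()
print(f"encoder layer: {enc:,}  decoder layer: {dec:,}  both stacks: {total:,}")
```

Under these assumptions, an encoder layer has about 3.15 million parameters, a decoder layer about 4.2 million, and the two stacks together roughly 44 million; the embedding matrices come on top of that and grow with the vocabulary size.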

Let’s go back. As we can see, there are three main sub-modules that appear inside and outside the encoders and decoders:

  1. Positional Encoding
  2. Multi-Head Attention
  3. Position-wise Feed Forward

2. Positional Encoding

As we know, the Transformer is built on its strong attention mechanism, which directly connects every pair of words in a sequence, so all the words are modeled jointly. Compared with RNNs, attention also shortens the path between distant words, which alleviates the vanishing/exploding gradient problem, and it makes training parallelizable, since all positions can be processed at once.

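However, using attention alone brings a significant problem: how do we express the position of each word in the sentence? RNNs are good at sequence tasks partly because they are position-sensitive: they read the words one by one, so ‘A student at Tokyo Tech’ and ‘Tech Tokyo at student A’ generally lead to different outputs. Self-attention, on the other hand, has no notion of order: if we shuffle the input words, its outputs are shuffled in exactly the same way, so both sentences look like the same set of words.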
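The claim that attention alone ignores word order can be checked directly. The sketch below is a deliberately bare self-attention in plain Python with Q = K = V = the word vectors. Dropping the learned projections is a simplifying assumption, but since they apply the same matrix to every word they would not change the conclusion: reversing the input only reverses the output rows.

```python
import math

def softmax(xs):
    top = max(xs)  # subtract the maximum for numerical stability
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def bare_self_attention(x):
    """Self-attention with Q = K = V = x: no learned projections and no positional information."""
    d = len(x[0])
    out = []
    for q in x:
        weights = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in x])
        out.append([sum(w * v[j] for w, v in zip(weights, x)) for j in range(d)])
    return out

# Hypothetical 2-d vectors standing in for the words of a 4-word sentence.
sentence = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [0.2, -0.3]]
forward_out = bare_self_attention(sentence)
reversed_out = bare_self_attention(sentence[::-1])
# reversed_out equals forward_out[::-1] up to rounding: the same vectors, just reordered.
```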

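Thus, to convey such position information, besides a normal embedding layer, the Transformer adds a fixed sinusoidal positional encoding to the embeddings. The authors of [1] chose it because they hypothesized that it lets the model easily attend by relative positions, since for any fixed offset k, PE(pos+k) is a linear function of PE(pos). It is defined as: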

\[PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)\]
\[PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)\]

where pos is the position of the word in the sequence, and 2i and 2i+1 index the even and odd dimensions of the d_model-dimensional encoding.

import math
import torch
class PositionalEncoding(torch.nn.Module):
    def __init__(self, model_dim, max_seq_len):
        super().__init__()
        # A buffer moves with the module (e.g. .to(device)) but is not a trainable parameter.
        self.register_buffer('pe_table', self._get_pe_table(model_dim, max_seq_len))
    def forward(self, input):
        """
        input : tensor, encodings of input sequence batch, shape (batch_size, seq_len, model_dim)
        """
        # pe_table[:seq_len] has shape (seq_len, model_dim) and broadcasts over the batch.
        return input + self.pe_table[:input.shape[1]]
    def _get_pe_table(self, model_dim, max_seq_len):
        """Use the logarithm form to simplify the equation (model_dim is assumed to be even)."""
        pe_table = torch.zeros((max_seq_len, model_dim))
        pos = torch.arange(0, max_seq_len, dtype=torch.float32)  # shape (max_seq_len,)
        # 1 / 10000^(2i / model_dim) = exp(-log(10000) * 2i / model_dim), shape (model_dim // 2,)
        div_term = torch.exp(-math.log(10000.0) * torch.arange(0, model_dim, 2, dtype=torch.float32) / model_dim)
        pos = pos.unsqueeze(1)            # (max_seq_len, 1)
        div_term = div_term.unsqueeze(0)  # (1, model_dim // 2)
        pe_table[:, 0::2] = torch.sin(pos * div_term)  # even dimensions
        pe_table[:, 1::2] = torch.cos(pos * div_term)  # odd dimensions
        return pe_table

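Because every position gets its own pattern of sines and cosines added to its embedding, the same word at different positions now produces different inputs, so the order of the words is no longer lost.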
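To get a feel for the numbers, here is the same table computed in plain Python. It is a minimal sketch that mirrors the formula loop by loop (`positional_encoding` is a hypothetical name, and the loop variable `even` runs over the even indices, i.e. it plays the role of 2i):

```python
import math

def positional_encoding(max_seq_len, model_dim):
    """pe[pos][2i] = sin(pos / 10000^(2i / model_dim)), pe[pos][2i + 1] = cos(same angle)."""
    pe = [[0.0] * model_dim for _ in range(max_seq_len)]
    for pos in range(max_seq_len):
        for even in range(0, model_dim, 2):  # `even` plays the role of 2i
            angle = pos / 10000 ** (even / model_dim)
            pe[pos][even] = math.sin(angle)
            if even + 1 < model_dim:
                pe[pos][even + 1] = math.cos(angle)
    return pe

pe = positional_encoding(max_seq_len=50, model_dim=8)
print([round(v, 3) for v in pe[1]])
```

Position 0 encodes to [0, 1, 0, 1, ...], every value stays within [-1, 1], and each position gets a distinct row. Across the dimensions, the wavelengths grow geometrically from 2π to 10000 · 2π, so the first dimensions change quickly with pos and the last ones very slowly.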

Of course, we could also learn the positional encoding instead. According to [1], learned encodings produce ‘nearly identical results’ to the sinusoidal version. Still, the sinusoidal encoding has two advantages over a learned one:

  1. It can be computed for positions beyond any sequence length seen during training, which [1] hypothesized may let the model extrapolate to longer sequences, whereas a learned encoding only exists for the positions it was trained on.

  2. It adds no trainable parameters, since the table is computed once and never updated.
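The relative-position argument of [1] rests on a trigonometric identity: shifting a position by k rotates each (sin, cos) pair by an angle that depends only on k. The sketch below checks this for one frequency (the helper names are hypothetical); it also shows that the formula can be evaluated at any position, including positions far beyond a typical training length.

```python
import math

def pe_pair(pos, freq):
    """One (sin, cos) pair of the encoding; freq = 1 / 10000^(2i / d_model)."""
    return math.sin(pos * freq), math.cos(pos * freq)

def shift_pair(pair, k, freq):
    """Map the pair at position pos to the pair at pos + k using only k:
    sin(a + b) = sin a cos b + cos a sin b,  cos(a + b) = cos a cos b - sin a sin b.
    This is a 2x2 rotation matrix applied to (sin, cos), the same for every pos."""
    s, c = pair
    cb, sb = math.cos(k * freq), math.sin(k * freq)
    return s * cb + c * sb, c * cb - s * sb

freq = 1 / 10000 ** (2 / 8)  # the pair at 2i = 2 of an 8-dimensional encoding
print(shift_pair(pe_pair(100, freq), 5, freq), pe_pair(105, freq))
```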

3. Multi-Head Attention

Multi-head attention is the core of Transformer. Let’s see how it works.

Multi-head attention consists of several Scaled Dot-Product Attention heads, each of which can be expressed by the equation below:

\[\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

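where Q, K and V represent the queries, keys and values respectively, and d_k is the dimension of the keys. Dividing by sqrt(d_k) keeps the dot products from growing with the key dimension, which would otherwise push the softmax into regions with extremely small gradients [1].

For easier understanding, we can imagine that each query is matched against all the keys: the better a key matches, the more weight its value gets in the output.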
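A tiny worked example in plain Python may make the matching intuition concrete. This is a minimal sketch of the equation on lists of row vectors (the helper names are hypothetical), with one query that resembles the first of two keys:

```python
import math

def softmax(xs):
    top = max(xs)  # subtract the maximum for numerical stability
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V, with Q, K, V given as lists of row vectors."""
    d_k = len(K[0])
    out, weights = [], []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        w = softmax(scores)  # how much this query attends to each key
        weights.append(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) for j in range(len(V[0]))])
    return out, weights

# One query that matches the first key better than the second.
Q = [[1.0, 0.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[10.0, 0.0], [0.0, 10.0]]
out, weights = attention(Q, K, V)
```

The query is closer to the first key, so the first value receives about two thirds of the weight and the output is pulled towards [10, 0].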

Knowing about this, we can first implement the “single-head” attention:

import torch
import torch.nn as nn
import torch.nn.functional as F
class ScaledDotProductAttention(torch.nn.Module):
    def __init__(self, model_dim, key_dim, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.key_dim = key_dim
    def forward(self, query, key, value, mask=None):
        """
        query : tensor, shape (batch_size, n_head, seq_len_1, key_dim)
        key : tensor, shape (batch_size, n_head, seq_len_2, key_dim)
        value : tensor, shape (batch_size, n_head, seq_len_2, value_dim)
        mask : tensor, shape (batch_size, 1, seq_len_2) or (batch_size, seq_len_1, seq_len_2); 0 marks positions to ignore
        """
        key = key.transpose(-1, -2)  # (batch_size, n_head, key_dim, seq_len_2)
        attn_score = torch.matmul(query / (self.key_dim ** 0.5), key)  # (batch_size, n_head, seq_len_1, seq_len_2)
        # Mask the unneeded/unseen positions (padding, or future tokens in the decoder).
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # add a head dimension; True where attention is forbidden
            attn_score = attn_score.masked_fill(mask, -1e9)  # a large negative number acts as -infinity
        attn_score = F.softmax(attn_score, dim=-1)  # normalize over the keys exactly once
        attn_score = self.dropout(attn_score)
        out = torch.matmul(attn_score, value)  # (batch_size, n_head, seq_len_1, value_dim)
        return out, attn_score
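The masking step is easy to get wrong, so here is a minimal plain-Python sketch of it (`masked_softmax` is a hypothetical name). The fill value must be a large negative number such as -1e9: after the softmax, its exponential underflows to zero, so masked keys get no weight at all. A tiny positive value such as 1e-9 masks nothing, and the softmax should be applied once, after masking.

```python
import math

def masked_softmax(scores, mask, fill=-1e9):
    """Set scores to `fill` where mask == 0, then apply softmax once over the keys."""
    filled = [s if m else fill for s, m in zip(scores, mask)]
    top = max(filled)
    exps = [math.exp(s - top) for s in filled]  # exp(-1e9 - top) underflows to exactly 0.0
    total = sum(exps)
    return [e / total for e in exps]

scores = [2.0, 1.0, 3.0]
mask = [1, 1, 0]                               # the third key is padding or a future token
good = masked_softmax(scores, mask)            # third weight is exactly 0.0
bad = masked_softmax(scores, mask, fill=1e-9)  # third weight stays clearly positive
```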

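You may wonder where Q, K and V come from. They are simply different linear projections of the inputs, for example the output of the Positional Encoding layer. In self-attention, all three are projected from the same sequence; in the encoder-decoder attention of the decoder, Q comes from the decoder while K and V come from the encoder output. The so-called Multi-Head simply means doing these projections and the Scaled Dot-Product Attention several times in parallel, each head with its own learned projections, and then concatenating the results.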

Let’s find out how it works in the multi-head attention:

import torch
import torch.nn as nn
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, model_dim, key_dim, value_dim, n_head, dropout=0.1):
        """
        n_head : default value 8 as set in [1]
        key_dim, value_dim : default value model_dim / n_head = 512 / 8 = 64 as set in [1]
        """
        super().__init__()
        self.linear_K = nn.Linear(in_features=model_dim, out_features=key_dim * n_head)
        self.linear_Q = nn.Linear(in_features=model_dim, out_features=key_dim * n_head)
        self.linear_V = nn.Linear(in_features=model_dim, out_features=value_dim * n_head)
        self.linear_out = nn.Linear(in_features=value_dim * n_head, out_features=model_dim)
        self.attn = ScaledDotProductAttention(model_dim, key_dim, dropout)  # defined above
        self.dropout = nn.Dropout(p=dropout)
        self.n_head = n_head
        self.key_dim = key_dim
        self.value_dim = value_dim
    def forward(self, query, key, value, mask=None):
        """
        query : tensor, shape (batch_size, seq_len_1, model_dim)
        key, value : tensor, shape (batch_size, seq_len_2, model_dim)
        mask : tensor, shape (batch_size, 1, seq_len_2) or (batch_size, seq_len_1, seq_len_2)
        """
        batch_size, seq_len_1, _ = query.shape
        seq_len_2 = key.shape[1]  # keys/values may be longer or shorter than the queries
        n_head, key_dim, value_dim = self.n_head, self.key_dim, self.value_dim
        # Project, split the last dimension into heads, then move the head axis forward:
        # (batch_size, seq_len, n_head * dim) -> (batch_size, seq_len, n_head, dim) -> (batch_size, n_head, seq_len, dim)
        query = self.linear_Q(query).view(batch_size, seq_len_1, n_head, key_dim).transpose(1, 2)
        key = self.linear_K(key).view(batch_size, seq_len_2, n_head, key_dim).transpose(1, 2)
        value = self.linear_V(value).view(batch_size, seq_len_2, n_head, value_dim).transpose(1, 2)
        # out shape here: (batch_size, n_head, seq_len_1, value_dim)
        out, attn_score = self.attn(query, key, value, mask=mask)
        # Concatenate the heads back: (batch_size, seq_len_1, n_head * value_dim)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len_1, n_head * value_dim)
        out = self.dropout(self.linear_out(out))  # Do not forget Dropout!
        return out, attn_score
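One detail of multi-head attention is easy to get wrong: splitting the heads. `Tensor.view` only reinterprets the underlying memory in row-major order, so the projected tensor has to be viewed as (batch_size, seq_len, n_head, dim) and then transposed to (batch_size, n_head, seq_len, dim). Viewing it straight to (batch_size, n_head, seq_len, dim) mixes features of different tokens into one row. The plain-Python sketch below imitates `view` on nested lists to show the difference (helper names are hypothetical, and the batch dimension is dropped for brevity).

```python
def view(flat, shape):
    """Reinterpret a flat row-major list as nested lists of `shape`, imitating Tensor.view."""
    if len(shape) == 1:
        return list(flat)
    step = len(flat) // shape[0]
    return [view(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]

def split_heads(x, n_head):
    """(seq_len, n_head * d) -> (n_head, seq_len, d): view each token as (n_head, d), then swap axes."""
    seq_len, d = len(x), len(x[0]) // n_head
    per_token = [view(row, (n_head, d)) for row in x]  # (seq_len, n_head, d)
    return [[per_token[t][h] for t in range(seq_len)] for h in range(n_head)]  # like transpose(0, 1)

# Toy projected sequence: value 10 * t + c is feature c of token t (seq_len = 3, n_head = 2, d = 2).
x = [[10 * t + c for c in range(4)] for t in range(3)]
heads = split_heads(x, n_head=2)                         # heads[1][0] == [2, 3]: token 0, head 1
naive = view([v for row in x for v in row], (2, 3, 2))   # naive[0][2] == [10, 11]: token 1 leaks into head 0
```

The reverse step has the same subtlety: after `transpose(1, 2)` the tensor is no longer contiguous, so it needs `.contiguous()` before `.view` (or `.reshape` instead) when the heads are concatenated back.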

4. Position-wise Feed Forward Network

This is the easiest part of the Transformer to understand if you have a basic knowledge of deep learning. It is a two-layer fully connected network with a ReLU activation in between, applied to each position separately and identically:

\[\mathrm{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2\]

import torch
import torch.nn as nn
import torch.nn.functional as F
class PositionwiseFeedForward(torch.nn.Module):
    def __init__(self, model_dim, hidden_dim, dropout=0.1):
        """
        model_dim : default 512 as set in [1]
        hidden_dim : default 2048 as set in [1]
        """
        super().__init__()
        self.linear_1 = nn.Linear(in_features=model_dim, out_features=hidden_dim)  # W_1, b_1
        self.linear_2 = nn.Linear(in_features=hidden_dim, out_features=model_dim)  # W_2, b_2
        self.dropout = nn.Dropout(p=dropout)
    def forward(self, input):
        """
        input : tensor, shape (batch_size, seq_len, model_dim)
        """
        # nn.Linear acts on the last dimension, so every position is transformed independently.
        out = F.relu(self.linear_1(input))
        out = self.dropout(self.linear_2(out))
        return out
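Because `nn.Linear` works on the last dimension, the same two layers are applied to every position independently. A minimal plain-Python sketch of the formula makes that explicit (hypothetical helper names, with toy sizes model_dim = 2 and hidden_dim = 3 instead of 512 and 2048):

```python
def relu(v):
    return [max(0.0, a) for a in v]

def linear(v, W, b):
    """v W + b for one vector v; W is a list of in_dim rows, each with out_dim entries."""
    return [sum(v[i] * W[i][j] for i in range(len(v))) + b[j] for j in range(len(b))]

def position_wise_ffn(seq, W1, b1, W2, b2):
    """FFN(x) = max(0, x W1 + b1) W2 + b2, applied to each position with the same weights."""
    return [linear(relu(linear(x, W1, b1)), W2, b2) for x in seq]

# Hypothetical toy weights: model_dim = 2, hidden_dim = 3.
W1 = [[1.0, -1.0, 0.5], [0.0, 2.0, -1.0]]
b1 = [0.0, 0.0, 0.0]
W2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b2 = [0.5, -0.5]
seq = [[1.0, 1.0], [2.0, -1.0], [0.0, 3.0]]
out = position_wise_ffn(seq, W1, b1, W2, b2)
```

Feeding the positions one at a time gives exactly the same rows as feeding the whole sequence, which is what "position-wise" means.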

5. Conclusion

Now we have looked through the whole architecture and the design of the core modules of the Transformer. In the future, I would like to extend this topic by introducing its application to speech recognition [2] and the well-known pretrained model BERT [3].

Thank you so much for your time. If you want to learn more about my detailed implementation, please refer to my GitHub repository for the full code: Repo.

6. References

[1] Vaswani, A., et al. “Attention Is All You Need.” arXiv preprint arXiv:1706.03762 (2017).

[2] Dong, L., Xu, S., and Xu, B. “Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for Speech Recognition.” 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), Calgary, AB, 2018, pp. 5884-5888.

[3] Devlin, J., et al. “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.” arXiv preprint arXiv:1810.04805 (2018).