Attention is all you need
Attention 101
Let's start by taking a look at what attention is and how to calculate it.
What is attention
Basically, attention is a technique that allows a model to assign importance scores to different elements of its input. This allows the model to focus on the most relevant information for the task.
How it works
Attention works with three key components: Query, Key and Value
Query: The current focus of the model
Keys: A set of elements in the input sequence
Values: Each key has a value, containing the actual information associated with the key.
The output of attention is computed as a weighted sum of the values, where each weight is a function of the query and the corresponding key:
attention = softmax(Query * Key) * Value
Basically, the weight is the softmax of the dot product between the Query and the Key. The dot product measures the similarity between the Query and the Key, and the softmax transforms the similarity scores into probabilities between 0 and 1 that sum to 1.
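To make the formula concrete, here is a minimal NumPy sketch of basic attention for a single query. The sizes (one query and three key/value pairs of dimension 4) and the random vectors are just illustrative assumptions.

```python
import numpy as np

def softmax(x):
    x = x - np.max(x)            # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum()

query = np.random.randn(4)       # the current focus of the model
keys = np.random.randn(3, 4)     # one key per input element
values = np.random.randn(3, 4)   # the information associated with each key

scores = keys @ query            # dot products: similarity between the query and each key
weights = softmax(scores)        # probabilities between 0 and 1 that sum to 1
attention = weights @ values     # weighted sum of the values

print(weights.sum())             # 1.0
print(attention.shape)           # (4,)
```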
Walk through a toy example
Let's walk through a toy example of how to calculate basic attention with text as input (Fig. 1):
Let W be the input sequence of length 6 (N=6): "Let's explore and discover the world"
Get the word embeddings and stack them into the embedding matrix X ∈ R^(N×d), where d is the embedding size; we set it to 16 in this example (d=16). Each row corresponds to the embedding of one token (word) in the original sequence W.
Transform each word embedding with weight matrices Q ∈ R^(d×d), K ∈ R^(d×d), and V ∈ R^(d×d) to get the queries (X*Q ∈ R^(N×d)), keys (X*K ∈ R^(N×d)), and values (X*V ∈ R^(N×d)).
Compute pairwise similarities between queries and keys and normalize them with softmax to get the weights: Softmax(X*Q * (X*K)^T)
Compute the basic attention as the weighted sum of the values: Softmax(X*Q * (X*K)^T) * (X*V)
Figure 1 A toy example to calculate basic attention
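The same steps can be sketched in a few lines of NumPy. The embedding matrix X and the weight matrices Q, K, V are random placeholders here; in a real model they would be learned.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

N, d = 6, 16                       # "Let's explore and discover the world", d=16
X = np.random.randn(N, d)          # embedding matrix, one row per token (placeholder)
Q, K, V = (np.random.randn(d, d) for _ in range(3))   # weight matrices (placeholders)

queries, keys, values = X @ Q, X @ K, X @ V   # each is N x d
weights = softmax(queries @ keys.T)           # pairwise similarities, softmax-normalized
attention = weights @ values                  # weighted sum of the values
print(attention.shape)                        # (6, 16)
```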
Positional encoding & Scaled attention
There are a few improvements that the "Attention is all you need" paper contributed to the attention mechanism. We will talk about positional encoding and scaled attention in this section. Both of them are highlighted in red in Fig. 2.
Positional encoding
Unlike RNNs and CNNs, which take the order of the sequence into consideration, the transformer architecture has no notion of position by nature. However, position information is usually important to the task. To make sure this information is available to the transformer, the authors proposed the "positional encoding" idea:
Use a fixed positional encoding or learn a position embedding with the same dimension (d) as the word embedding.
Sum the positional encoding and the word embedding to get the final embedding matrix X.
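As a sketch of the first option, the fixed sinusoidal encoding from the paper, PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)), can be computed as below; the word embeddings are a random placeholder.

```python
import numpy as np

def positional_encoding(N, d):
    positions = np.arange(N)[:, None]               # (N, 1) token positions
    two_i = np.arange(0, d, 2)[None, :]             # even dimension indices 2i
    angles = positions / np.power(10000.0, two_i / d)
    pe = np.zeros((N, d))
    pe[:, 0::2] = np.sin(angles)                    # even dimensions: sine
    pe[:, 1::2] = np.cos(angles)                    # odd dimensions: cosine
    return pe

N, d = 6, 16
X_words = np.random.randn(N, d)          # word embeddings (placeholder)
X = X_words + positional_encoding(N, d)  # final embedding matrix fed to attention
```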
Scaled attention
The authors added a scaling factor 1/sqrt(d) to the dot-product attention: with large d the dot products grow in magnitude and push the softmax into regions with extremely small gradients, so scaling them down avoids this vanishing-gradient problem.
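A minimal sketch of the change, reusing the N=6, d=16 toy setup with random placeholder inputs: the only difference from basic attention is the division by sqrt(d) before the softmax.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_attention(queries, keys, values):
    d = queries.shape[-1]
    scores = queries @ keys.T / np.sqrt(d)   # scaling keeps the softmax out of tiny-gradient regions
    return softmax(scores) @ values

N, d = 6, 16
queries, keys, values = (np.random.randn(N, d) for _ in range(3))
print(scaled_attention(queries, keys, values).shape)   # (6, 16)
```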
Figure 2 Improvement on basic attention: positional encoding and scaled attention.
Multihead attention
Why multi-head attention
A single attention head may only be able to focus on one aspect of the input data. Multiple heads can independently learn to attend to different aspects of the input.
e.g., in language, one head can focus on subject-verb relationships while another focuses on the sentiment of phrases.
Multi-head attention also has an ensemble-like behavior: combining the heads leads to more robust and versatile representations.
Computational efficiency (as can be seen from the toy example below): the multi-head matrices (Fig. 3) add up to the same size as those of the basic attention method (Fig. 1).
Walk through a toy example with 4 heads
Let W be the input sequence of length 6 (N=6): "Let's explore and discover the world"
Sum the word embeddings and positional encodings to get the embedding matrix X ∈ R^(N×d), where d is the embedding size; we set it to 16 in this example (d=16). Each row corresponds to the embedding of one token (word) in the original sequence W.
For each head: transform each word embedding with weight matrices Q ∈ R^(d×(d/h)), K ∈ R^(d×(d/h)), and V ∈ R^(d×(d/h)) to get the queries (X*Q ∈ R^(N×(d/h))), keys (X*K ∈ R^(N×(d/h))), and values (X*V ∈ R^(N×(d/h))), where d=16 and h=4.
For each head: compute pairwise similarities between queries and keys, scale them by 1/sqrt(d/h), and normalize with softmax to get the weights: Softmax(X*Q * (X*K)^T / sqrt(d/h))
For each head: compute the attention output as the weighted sum of the values: Softmax(X*Q * (X*K)^T / sqrt(d/h)) * (X*V)
Concatenate the outputs from all heads to get the final output (in the paper, the concatenated output is additionally projected with an output weight matrix).
Figure 3 A toy example of multihead attention with positional encoding and scaled attention.
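Here is a NumPy sketch of the 4-head toy example. The per-head weight matrices are random placeholders, and the final output projection used in the paper is omitted to mirror the steps above.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

N, d, h = 6, 16, 4
d_head = d // h                      # each head works in dimension d/h = 4
X = np.random.randn(N, d)            # word embeddings + positional encodings (placeholder)

head_outputs = []
for _ in range(h):
    Q, K, V = (np.random.randn(d, d_head) for _ in range(3))  # per-head weights (d x d/h)
    queries, keys, values = X @ Q, X @ K, X @ V               # each is N x d/h
    weights = softmax(queries @ keys.T / np.sqrt(d_head))     # scaled attention per head
    head_outputs.append(weights @ values)                     # N x d/h per head

output = np.concatenate(head_outputs, axis=-1)                # concatenate heads: N x d
print(output.shape)                                           # (6, 16)
```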
Encoder/Decoder: Attention is not all you need
Encoder
The encoder is composed of N (N=6) identical layers
Basic block for encoder
The basic building block for the encoder is composed of:
A multi-head self-attention layer followed by layer normalization
Feed-forward network (FFN) followed by layer normalization
This adds non-linearity on top of the attention output.
A residual connection around the multi-head attention and FFN:
A shortcut path directly adds the input of the layer to its output, which mitigates the vanishing gradient issue in deeper networks and improves training efficiency.
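Putting these pieces together, here is a compact NumPy sketch of one encoder block following the structure above. It is a simplified illustration rather than the paper's implementation: the weights are random placeholders and dropout is omitted.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    std = x.std(axis=-1, keepdims=True)
    return (x - mean) / (std + eps)

def multi_head_self_attention(X, h=4):
    N, d = X.shape
    d_head = d // h
    heads = []
    for _ in range(h):
        Q, K, V = (np.random.randn(d, d_head) for _ in range(3))  # placeholder weights
        weights = softmax((X @ Q) @ (X @ K).T / np.sqrt(d_head))  # scaled attention
        heads.append(weights @ (X @ V))
    return np.concatenate(heads, axis=-1)

def feed_forward(X, d_ff=64):
    d = X.shape[-1]
    W1, W2 = np.random.randn(d, d_ff), np.random.randn(d_ff, d)   # placeholder weights
    return np.maximum(0, X @ W1) @ W2          # ReLU adds the non-linearity

def encoder_block(X):
    X = layer_norm(X + multi_head_self_attention(X))  # residual connection + layer norm
    X = layer_norm(X + feed_forward(X))               # residual connection + layer norm
    return X

X = np.random.randn(6, 16)
print(encoder_block(X).shape)   # (6, 16)
```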
Decoder
The decoder is composed of N (N=6) identical layers
Basic blocks for decoder
The basic building block for the decoder is composed of:
A masked multi-head self-attention layer followed by layer normalization
Mask: to prevent cheating, the mask stops each position from attending to subsequent positions, so a position can only attend to itself and earlier positions.
A cross-attention layer that attends over the encoder output: queries come from the masked multi-head self-attention output, while keys and values come from the encoder output.
Feed-forward network (FFN) followed by layer normalization
A residual connection around the multi-head attention layers and the FFN.
Figure 4 The transformer model architecture
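To close, here is a matching NumPy sketch of one decoder block, again simplified: single-head attention for brevity, random placeholder weights, and no dropout. It shows the causal mask and the cross-attention that takes queries from the decoder and keys/values from the encoder output.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-6):
    return (x - x.mean(axis=-1, keepdims=True)) / (x.std(axis=-1, keepdims=True) + eps)

def attention(queries, keys, values, mask=None):
    d = queries.shape[-1]
    scores = queries @ keys.T / np.sqrt(d)      # scaled dot-product attention
    if mask is not None:
        scores = np.where(mask, scores, -1e9)   # masked positions get ~zero weight
    return softmax(scores) @ values

def decoder_block(Y, encoder_out):
    N, d = Y.shape
    # Masked self-attention: position i may only attend to positions <= i.
    causal_mask = np.tril(np.ones((N, N), dtype=bool))
    Wq, Wk, Wv = (np.random.randn(d, d) for _ in range(3))        # placeholder weights
    Y = layer_norm(Y + attention(Y @ Wq, Y @ Wk, Y @ Wv, mask=causal_mask))
    # Cross-attention: queries from the decoder, keys/values from the encoder output.
    Wq, Wk, Wv = (np.random.randn(d, d) for _ in range(3))
    Y = layer_norm(Y + attention(Y @ Wq, encoder_out @ Wk, encoder_out @ Wv))
    # Feed-forward network with a residual connection and layer norm.
    W1, W2 = np.random.randn(d, 64), np.random.randn(64, d)
    Y = layer_norm(Y + np.maximum(0, Y @ W1) @ W2)
    return Y

Y = np.random.randn(6, 16)            # (shifted) target embeddings
encoder_out = np.random.randn(6, 16)  # output of the encoder stack
print(decoder_block(Y, encoder_out).shape)   # (6, 16)
```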