MULTI-HEADED ATTENTION IN TRANSFORMERS FROM SCRATCH
- MLV Prasad
- Oct 23, 2023
Updated: Nov 14, 2023
The Multi-headed Attention Mechanism Used in the Latest LLM Models

Coding from Scratch in PyTorch.
"ATTENTION" IS ALL YOU NEED - LLM
import math
from typing import Optional, List

import torch
from torch import nn

from labml import tracker

Prepare for multi-head attention
This module applies a linear transformation and splits the resulting vector into a given number of heads for multi-head attention. It is used to transform the query, key, and value vectors.
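As a quick illustration of that split, the minimal sketch below applies one `nn.Linear` to a `[seq_len, batch_size, d_model]` tensor and views the last dimension as `heads` slices of `d_k` features each. The sizes are arbitrary values chosen for the example, not anything prescribed by the article.

```python
import torch
from torch import nn

# Arbitrary illustrative sizes (assumptions, not values from the article)
seq_len, batch_size, d_model, heads = 5, 2, 16, 4
d_k = d_model // heads

linear = nn.Linear(d_model, heads * d_k)
x = torch.randn(seq_len, batch_size, d_model)

projected = linear(x)                                    # [seq_len, batch_size, heads * d_k]
split = projected.view(seq_len, batch_size, heads, d_k)  # one d_k-sized slice per head

# Head h owns features h*d_k:(h+1)*d_k of the single projection
print(split.shape)  # torch.Size([5, 2, 4, 4])
```

One projection of width `heads * d_k` is equivalent to `heads` separate projections of width `d_k`, but runs as a single matrix multiply.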
class PrepareForMultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        # Linear layer for linear transform
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        # Number of heads
        self.heads = heads
        # Number of dimensions in vectors in each head
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
        # Input has shape [seq_len, batch_size, d_model] or [batch_size, d_model].
        # We apply the linear transformation to the last dimension and split that into the heads.
        head_shape = x.shape[:-1]
        # Linear transform
        x = self.linear(x)
        # Split the last dimension into heads
        x = x.view(*head_shape, self.heads, self.d_k)
        # Output has shape [seq_len, batch_size, heads, d_k] or [batch_size, heads, d_k]
        return x

Multi-Head Attention Module
This computes scaled multi-headed attention for the given query, key, and value vectors.

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

In simple terms, it finds the keys that match the query and returns a weighted combination of those keys' values, using the dot product of the query and each key as the measure of how well they match.
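A minimal single-head sketch of the formula on plain matrices, without batching, masking, or learned projections (the sizes and random inputs are assumptions made for illustration):

```python
import math
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions): 3 queries, 4 keys/values, d_k = d_v = 8
n_q, n_k, d_k, d_v = 3, 4, 8, 8
Q = torch.randn(n_q, d_k)
K = torch.randn(n_k, d_k)
V = torch.randn(n_k, d_v)

scores = Q @ K.T / math.sqrt(d_k)        # [n_q, n_k]: how well each query matches each key
weights = torch.softmax(scores, dim=-1)  # normalize over the keys, separately for each query
out = weights @ V                        # [n_q, d_v]: weighted average of the value rows

print(weights.sum(dim=-1))  # each row sums to 1 (up to floating-point error)
```

Each row of `weights` is a probability distribution over the keys, so each output row is a convex combination of the rows of `V`.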
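Before taking the softmax, the dot products are scaled by $1/\sqrt{d_k}$. If the query and key components are independent with zero mean and unit variance, their dot product has variance $d_k$, so without scaling the scores grow with the head dimension and push the softmax into a saturated region where it gives very small gradients. The softmax is calculated along the sequence (or time) axis of the keys.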
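To see why the scaling matters, the sketch below compares the softmax of raw and scaled dot products for one query. The choice of `d_k = 512` and ten random keys is an assumption for illustration only.

```python
import math
import torch

torch.manual_seed(0)
d_k = 512                      # a large head dimension, chosen for illustration
q = torch.randn(d_k)
keys = torch.randn(10, d_k)

raw = keys @ q                 # dot products; their spread grows like sqrt(d_k)
scaled = raw / math.sqrt(d_k)  # rescaled to roughly unit standard deviation

p_raw = torch.softmax(raw, dim=0)
p_scaled = torch.softmax(scaled, dim=0)

# Without scaling, the distribution is typically close to one-hot, so the
# softmax Jacobian p_i * (delta_ij - p_j) is close to zero almost everywhere.
print(p_raw.max().item(), p_scaled.max().item())
```

Dividing the scores by $\sqrt{d_k}$ acts like raising the softmax temperature, so the largest probability can only decrease and the gradient is spread across more keys.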
class MultiHeadAttention(nn.Module):
    # heads is the number of heads.
    # d_model is the number of features in the query, key and value vectors.
    def __init__(self, heads: int, d_model: int,
                 dropout_prob: float = 0.1,
                 bias: bool = True):
        super().__init__()
        # d_model must split evenly across the heads
        assert d_model % heads == 0
        # Number of features per head
        self.d_k = d_model // heads
        # Number of heads
        self.heads = heads
        # These transform the query, key and value vectors for multi-headed attention.
        # Note: as written, the value projection always uses a bias, regardless of `bias`.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
        # Softmax for attention along the time dimension of key
        self.softmax = nn.Softmax(dim=1)
        # Output layer
        self.output = nn.Linear(d_model, d_model)
        # Dropout
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)
        # We store the attention weights so they can be used for logging or other computations if needed
        self.attn = None

    # Calculate scores between queries and keys.
    # This method can be overridden for other variations like relative attention.
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        # Calculate Q K^T, i.e. S[i, j, b, h] = sum_d Q[i, b, h, d] * K[j, b, h, d]
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    # mask has shape [seq_len_q, seq_len_k, batch_size], where the first dimension is the
    # query dimension. If the query dimension is equal to 1, it will be broadcast.
    # The batch dimension may also be 1, in which case it is broadcast as well.
    def prepare_mask(self, mask: torch.Tensor,
                     query_shape: List[int],
                     key_shape: List[int]):
        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
        # Same mask applied to all heads.
        mask = mask.unsqueeze(-1)
        # The resulting mask has shape [seq_len_q, seq_len_k, batch_size, 1],
        # which broadcasts across all heads.
        return mask

    # query, key and value are the tensors that store the collections of query, key and value vectors.
    # They have shape [seq_len, batch_size, d_model].
    # mask has shape [seq_len, seq_len, batch_size], and mask[i, j, b] indicates whether,
    # for batch b, the query at position i has access to the key-value at position j.
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        # query, key and value have shape [seq_len, batch_size, d_model]
        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)

        # Prepare query, key and value for the attention computation.
        # These will then have shape [seq_len, batch_size, heads, d_k].
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Compute attention scores Q K^T.
        # This gives a tensor of shape [seq_len, seq_len, batch_size, heads].
        scores = self.get_scores(query, key)
        # Scale scores: Q K^T / sqrt(d_k)
        scores *= self.scale
        # Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax attention along the key sequence dimension: softmax(Q K^T / sqrt(d_k))
        attn = self.softmax(scores)
        # Save attentions if debugging
        tracker.debug('attn', attn)
        # Apply dropout
        attn = self.dropout(attn)

        # Multiply by values: softmax(Q K^T / sqrt(d_k)) V
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
        # Save attentions for any other calculations
        self.attn = attn.detach()
        # Concatenate multiple heads
        x = x.reshape(seq_len, batch_size, -1)
        # Output layer
        return self.output(x)
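The `get_scores` einsum can be checked against an ordinary batched matrix product. This sketch uses random tensors with arbitrary sizes (assumptions for illustration) and reorders the dimensions so that `@` performs the same contraction:

```python
import torch

torch.manual_seed(0)
# Illustrative sizes (assumptions): seq_len_q=3, seq_len_k=5, batch=2, heads=4, d_k=8
query = torch.randn(3, 2, 4, 8)  # [seq_len_q, batch_size, heads, d_k]
key = torch.randn(5, 2, 4, 8)    # [seq_len_k, batch_size, heads, d_k]

# Same contraction as get_scores: S[i, j, b, h] = sum_d query[i, b, h, d] * key[j, b, h, d]
scores = torch.einsum('ibhd,jbhd->ijbh', query, key)

# Equivalent batched matmul: move batch and heads to the front, then compute Q @ K^T
q = query.permute(1, 2, 0, 3)              # [batch, heads, seq_len_q, d_k]
k = key.permute(1, 2, 0, 3)                # [batch, heads, seq_len_k, d_k]
reference = q @ k.transpose(-2, -1)        # [batch, heads, seq_len_q, seq_len_k]
reference = reference.permute(2, 3, 0, 1)  # back to [seq_len_q, seq_len_k, batch, heads]

print(scores.shape)  # torch.Size([3, 5, 2, 4])
```

The einsum form avoids the explicit permutes and keeps the sequence-first layout used throughout the module.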

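The masking path can also be exercised on its own. This sketch builds a causal (lower-triangular) mask, an assumed example since the article does not fix a particular mask, using the `[seq_len_q, seq_len_k, batch_size]` layout that `prepare_mask` expects, then applies the same `unsqueeze(-1)`, `masked_fill`, and softmax over the key dimension:

```python
import torch

seq_len, batch_size, heads = 4, 2, 3  # illustrative sizes (assumptions)

# Causal mask: mask[i, j, b] == 1 means query i may attend to key j.
# The batch dimension is 1, so the same mask is broadcast to every batch element.
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)  # [seq_len, seq_len, 1]

# As in prepare_mask: add a trailing heads dimension so one mask serves all heads
mask = mask.unsqueeze(-1)                                       # [seq_len, seq_len, 1, 1]

scores = torch.randn(seq_len, seq_len, batch_size, heads)
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = torch.softmax(scores, dim=1)  # normalize over the key dimension

# Query 0 can only attend to key 0, so this prints tensor([1., 0., 0., 0.])
print(attn[0, :, 0, 0])
```

Because `exp(-inf)` is 0, masked keys get exactly zero weight, and every row still sums to 1 as long as each query can see at least one key.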

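The `reshape` that concatenates the heads at the end of `forward` is the inverse of the `view` that splits them in `PrepareForMultiHeadAttention`. A small check with arbitrary sizes (assumptions for illustration):

```python
import torch

seq_len, batch_size, heads, d_k = 5, 2, 4, 8  # illustrative sizes (assumptions)
x = torch.randn(seq_len, batch_size, heads * d_k)

per_head = x.view(seq_len, batch_size, heads, d_k)  # split, as in PrepareForMultiHeadAttention
merged = per_head.reshape(seq_len, batch_size, -1)  # concatenate, as at the end of forward

# Head h's d_k features land back in columns h*d_k:(h+1)*d_k
print(torch.equal(merged, x))  # True
```

This is why the output layer can be a plain `nn.Linear(d_model, d_model)`: after concatenation each position again has `heads * d_k = d_model` features.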