Transformers
Introduction
The Transformer architecture, introduced by Vaswani et al. in the paper "Attention Is All You Need" (2017), marked a paradigm shift in neural sequence processing. Recurrent architectures such as RNNs and LSTMs consume a sequence one token at a time, so computation cannot be parallelized across positions, and information between distant tokens must pass through many intermediate steps. Transformers instead use attention mechanisms that process all positions of a sequence in parallel and connect any two positions directly, which makes long-range dependencies easier to capture.
The architecture is built around multi-head attention, which lets the model attend to different positions of the input simultaneously, with each head working in its own learned representation subspace. Combined with residual connections, layer normalization, and position-wise feed-forward networks, this design has proven highly effective in natural language processing, computer vision, and other applications that can be framed as sequence modeling.
Auxiliary functions for masking
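Masking is an essential component of the Transformer architecture and serves two main purposes. The causal mask prevents each decoder position from attending to later (future) tokens during training, when the whole target sequence is fed to the decoder at once; this preserves the autoregressive nature of sequence generation, since the prediction for position \(t\) may depend only on positions up to \(t\). The padding mask lets attention ignore the padding tokens that are added so that all sequences in a batch have the same length.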
import math

import torch
from torch import nn
from torch.nn import functional as F


def create_causal_mask(size: int) -> torch.Tensor:
    """
    Creates a causal mask to prevent the decoder from attending to future tokens.
    Args:
        size: Length of the sequence.
    Returns:
        Boolean lower-triangular mask of shape (size, size); True marks
        positions that may be attended to. A boolean mask can be combined
        with the padding mask using &.
    """
    return torch.tril(torch.ones(size, size, dtype=torch.bool))


def create_padding_mask(seq: torch.Tensor, pad_token: int = 0) -> torch.Tensor:
    """
    Creates a mask to ignore padding tokens in a sequence.
    Args:
        seq: Sequence of tokens, shape (B, seq_len).
        pad_token: Padding token value.
    Returns:
        Boolean padding mask of shape (B, 1, 1, seq_len); True marks real
        (non-padding) tokens. The singleton dimensions broadcast over heads
        and query positions.
    """
    return (seq != pad_token).unsqueeze(1).unsqueeze(1)
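A quick way to check the mask shapes is to build both masks for a toy batch and combine them into the decoder's self-attention mask. The sketch below is self-contained: it repeats the two helpers (building the causal mask as a boolean tensor so it can be combined with the padding mask via &), and the token ids, with 0 as padding, are made up for illustration.

```python
import torch


# Standalone copies of the helpers above (causal mask built as bool here)
def create_causal_mask(size: int) -> torch.Tensor:
    return torch.tril(torch.ones(size, size, dtype=torch.bool))


def create_padding_mask(seq: torch.Tensor, pad_token: int = 0) -> torch.Tensor:
    return (seq != pad_token).unsqueeze(1).unsqueeze(1)


# Hypothetical target batch: (B=2, seq_len=4), 0 = padding
tgt = torch.tensor([[7, 4, 9, 0],
                    [5, 2, 0, 0]])

pad_mask = create_padding_mask(tgt)        # (2, 1, 1, 4): which keys are real tokens
causal = create_causal_mask(tgt.shape[1])  # (4, 4): which keys are not in the future
tgt_mask = pad_mask & causal               # broadcasts to (2, 1, 4, 4)

# Rows are query positions, columns are key positions (1 = may attend)
print(tgt_mask[1, 0].int())
```

Row \(t\) of each matrix lists the keys that query \(t\) may attend to. Query rows at padding positions still attend to the real tokens; their outputs are typically excluded from the loss. The encoder self-attention and the decoder cross-attention use the padding mask of the source alone.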
Input embeddings and positional encoding
Converting discrete tokens into continuous vectors is the first processing step in a Transformer. The input embedding maps each vocabulary token to a dense vector of dimension \(d_{model}\) that is learned during training. Following the original paper, the embeddings are multiplied by \(\sqrt{d_{model}}\). The paper states this scaling without motivating it; a common interpretation is that it keeps the embeddings on a scale comparable to the positional encodings that are added next.
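Because Transformers lack the built-in notion of order that recurrence gives RNNs, positional information must be injected explicitly. Positional encoding adds fixed (non-learned) vectors to the embeddings, computed with sine and cosine functions of different frequencies:
\[PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)\]
where \(pos\) is the position in the sequence and \(i\) indexes pairs of embedding dimensions (dimension \(2i\) uses the sine, dimension \(2i+1\) the cosine). For any fixed offset \(k\), \(PE_{pos+k}\) is a linear function of \(PE_{pos}\), which the original paper suggests makes it easy for the model to attend by relative position.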
class InputEmbedding(nn.Module):
    """Embeds input tokens into vectors of dimension d_model."""

    def __init__(self, d_model: int, vocab_size: int) -> None:
        """
        Initializes input embedding layer.
        Args:
            d_model: Dimensionality of the embedding vectors.
            vocab_size: Size of the vocabulary.
        """
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the embedding layer.
        Args:
            input_tensor: Input tensor of token indices.
        Returns:
            Tensor of embedded input scaled by sqrt(d_model).
        """
        return self.embedding(input_tensor) * math.sqrt(self.d_model)
class PositionalEncoding(nn.Module):
    """Adds sinusoidal positional encoding to input embeddings."""

    def __init__(self, d_model: int, sequence_length: int, dropout_rate: float) -> None:
        """
        Initializes positional encoding layer.
        Args:
            d_model: Dimensionality of the embedding vectors (must be even).
            sequence_length: Maximum sequence length.
            dropout_rate: Rate of dropout for regularization.
        """
        super().__init__()
        if d_model % 2 != 0:
            # sin and cos fill alternating columns, so an odd d_model cannot be split evenly
            raise ValueError("d_model must be even for sinusoidal positional encoding")
        self.d_model = d_model
        self.sequence_length = sequence_length
        self.dropout = nn.Dropout(dropout_rate)
        pe_matrix = torch.zeros(size=(self.sequence_length, self.d_model))
        # Column vector of positions 0 .. sequence_length - 1, shape (seq_len, 1)
        position = torch.arange(0, self.sequence_length, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i / d_model), computed as exp(-2i * ln(10000) / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe_matrix[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe_matrix[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe_matrix = pe_matrix.unsqueeze(0)  # (1, seq_len, d_model) broadcasts over the batch
        # A buffer follows the module across devices and is saved in state_dict, but is not trained
        self.register_buffer(name="pe_matrix", tensor=pe_matrix)

    def forward(self, input_embedding: torch.Tensor) -> torch.Tensor:
        """
        Forward pass to add positional encoding.
        Args:
            input_embedding: Tensor of input embeddings, shape (B, seq_len, d_model),
                with seq_len <= sequence_length.
        Returns:
            Tensor of embeddings with added positional encoding.
        """
        # Buffers do not require gradients, so the encoding stays fixed
        x = input_embedding + self.pe_matrix[:, : input_embedding.shape[1], :]  # type: ignore
        return self.dropout(x)
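Two properties of the encoding can be checked numerically: the vectorized div_term really equals \(1/10000^{2i/d_{model}}\), and shifting the position by \(k\) rotates each (sine, cosine) pair by the angle \(k\omega_i\), independent of \(pos\). This is the linear relationship between \(PE_{pos+k}\) and \(PE_{pos}\) mentioned above. The helpers sinusoidal_pe and shift are hypothetical standalone functions that mirror PositionalEncoding's construction without dropout or the batch dimension; float64 is used only to keep the comparisons tight.

```python
import math

import torch


def sinusoidal_pe(num_positions: int, d_model: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Same construction as PositionalEncoding; returns (pe, frequencies)."""
    position = torch.arange(num_positions, dtype=torch.float64).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float64) * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(num_positions, d_model, dtype=torch.float64)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe, div_term


d_model = 8
pe, omega = sinusoidal_pe(50, d_model)

# exp(-2i * ln(10000) / d_model) is exactly 1 / 10000^(2i / d_model)
two_i = torch.arange(0, d_model, 2, dtype=torch.float64)
assert torch.allclose(omega, 1.0 / 10000 ** (two_i / d_model))


def shift(pe_row: torch.Tensor, k: int, omega: torch.Tensor) -> torch.Tensor:
    """Rotate every (sin, cos) pair by k * omega using the angle-addition formulas."""
    sin, cos = pe_row[0::2], pe_row[1::2]
    c, s = torch.cos(k * omega), torch.sin(k * omega)
    out = torch.empty_like(pe_row)
    out[0::2] = sin * c + cos * s  # sin(a + b)
    out[1::2] = cos * c - sin * s  # cos(a + b)
    return out


# The same rotation maps PE(pos) to PE(pos + k) for every pos
k = 7
for pos in (0, 3, 20):
    assert torch.allclose(shift(pe[pos], k, omega), pe[pos + k])
```

The rotation depends only on the offset \(k\), not on \(pos\), which is the property the paper cites as motivation for the sinusoidal choice.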
Layer normalization and feed-forward networks
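Layer normalization stabilizes training by normalizing the activations across the features of each position of each example. Batch normalization normalizes each feature across the batch; layer normalization instead computes its statistics per example, so it does not depend on the batch size or on other (possibly padded) sequences in the batch, which makes it better suited to variable-length sequential data. The transformation is defined as:
\[\text{LayerNorm}(x) = \alpha \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta\]
where \(\mu\) and \(\sigma^2\) are the mean and (biased) variance calculated over the features, \(\alpha\) and \(\beta\) are learnable gain and bias parameters (alpha and bias in the code below), and \(\epsilon\) is a small constant for numerical stability.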
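Feed-forward networks apply the same non-linear transformation independently to each position in the sequence. They consist of two linear transformations with a ReLU activation in between:
\[\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2\]
with an inner dimension \(d_{ff}\) (2048 in the original paper, with \(d_{model} = 512\)). The implementation below also applies dropout after the ReLU.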
class LayerNormalization(nn.Module):
    """Applies layer normalization to input embeddings."""

    def __init__(self, features: int, eps: float = 1e-6) -> None:
        """
        Initializes layer normalization.
        Args:
            features: Number of features in the input.
            eps: Small constant for numerical stability.
        """
        super().__init__()
        self.features = features
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(self.features))
        self.bias = nn.Parameter(torch.zeros(self.features))

    def forward(self, input_embedding: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for layer normalization.
        Args:
            input_embedding: Tensor of input embeddings.
        Returns:
            Normalized tensor.
        """
        mean = torch.mean(input=input_embedding, dim=-1, keepdim=True)
        var = torch.var(input=input_embedding, dim=-1, keepdim=True, unbiased=False)
        return (
            self.alpha * ((input_embedding - mean) / (torch.sqrt(var + self.eps)))
            + self.bias
        )
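As a sanity check, the formula implemented by LayerNormalization can be compared with PyTorch's built-in torch.nn.functional.layer_norm, which uses the same biased variance. The sketch below is a standalone re-implementation (layer_norm_from_formula is a hypothetical helper name) with random gain and bias; it also confirms that each example is normalized independently of the rest of the batch.

```python
import torch
from torch.nn import functional as F


def layer_norm_from_formula(x, alpha, beta, eps=1e-6):
    """alpha * (x - mu) / sqrt(var + eps) + beta, statistics over the last dim."""
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return alpha * (x - mu) / torch.sqrt(var + eps) + beta


torch.manual_seed(0)
x = torch.randn(2, 5, 16)  # (batch, seq_len, features)
alpha = torch.rand(16) + 0.5
beta = torch.randn(16)

ours = layer_norm_from_formula(x, alpha, beta)
ref = F.layer_norm(x, normalized_shape=(16,), weight=alpha, bias=beta, eps=1e-6)

# Normalizing the first example on its own gives the same result as in the batch
alone = layer_norm_from_formula(x[:1], alpha, beta)

# With alpha = 1 and beta = 0, every feature vector has ~zero mean and ~unit variance
plain = layer_norm_from_formula(x, torch.ones(16), torch.zeros(16))
```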
class FeedForward(nn.Module):
    """Feed-forward neural network layer."""

    def __init__(self, d_model: int, d_ff: int, dropout_rate: float) -> None:
        """
        Initializes feed-forward network.
        Args:
            d_model: Dimensionality of model embeddings.
            d_ff: Dimensionality of feed-forward layer.
            dropout_rate: Rate of dropout for regularization.
        """
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.ffn = nn.Sequential(
            nn.Linear(in_features=self.d_model, out_features=self.d_ff),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(in_features=self.d_ff, out_features=self.d_model),
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through feed-forward network.
        Args:
            input_tensor: Tensor of input embeddings.
        Returns:
            Tensor processed by feed-forward network.
        """
        return self.ffn(input_tensor)
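"Position-wise" means that applying the network to a whole (batch, seq_len, d_model) tensor is the same as applying it to each position separately, because nn.Linear acts only on the last dimension. The sketch below checks this and also evaluates \(\max(0, xW_1 + b_1)W_2 + b_2\) by hand; dropout is left out (an assumption for a deterministic comparison), and nn.Linear stores its weight as (out_features, in_features), hence the transposes.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model, d_ff = 8, 32
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.randn(2, 5, d_model)  # (batch, seq_len, d_model)
y = ffn(x)                      # whole sequence at once

# One position at a time, then stacked back along the sequence dimension
per_position = torch.stack([ffn(x[:, t, :]) for t in range(x.shape[1])], dim=1)

# Explicit formula: nn.Linear computes x @ weight.T + bias
W1, b1 = ffn[0].weight, ffn[0].bias
W2, b2 = ffn[2].weight, ffn[2].bias
manual = torch.clamp(x @ W1.T + b1, min=0) @ W2.T + b2
```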
Multi-head attention mechanism
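The multi-head attention mechanism is the central component of the Transformer architecture. It lets the model jointly attend to information from different representation subspaces at different positions. Its building block, scaled dot-product attention, is calculated as:
\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V\]
where \(Q\), \(K\), and \(V\) represent the query, key, and value matrices, and \(d_k\) is the dimensionality of the keys. If the components of a query and a key are independent with zero mean and unit variance, their dot product has variance \(d_k\); dividing by \(\sqrt{d_k}\) brings the variance back to 1 and keeps the softmax out of its saturated regions, where gradients become extremely small. Masked positions receive a large negative score before the softmax, so their attention weight is (numerically) zero.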
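The variance argument for the \(1/\sqrt{d_k}\) factor is easy to reproduce empirically. The sketch below draws random query and key vectors with unit-variance components (an assumption matching the argument, not a property of trained models) and measures the variance of their dot products before and after scaling; dot_product_variance is a hypothetical helper name.

```python
import math

import torch

torch.manual_seed(0)


def dot_product_variance(d_k: int, n: int = 10_000) -> tuple[float, float]:
    """Empirical variance of q . k for random unit-variance vectors, raw and scaled."""
    q = torch.randn(n, d_k)
    k = torch.randn(n, d_k)
    dots = (q * k).sum(dim=-1)  # n independent dot products
    return dots.var().item(), (dots / math.sqrt(d_k)).var().item()


for d_k in (16, 64, 256):
    raw, scaled = dot_product_variance(d_k)
    print(f"d_k={d_k:4d}  var(q.k)={raw:8.2f}  var(q.k / sqrt(d_k))={scaled:.3f}")
```

The raw variance grows roughly in proportion to \(d_k\), while the scaled variance stays close to 1 for every head size.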
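Multi-head attention linearly projects the queries, keys, and values \(h\) times with different learned projections, applies the attention function to each projection in parallel, and concatenates the results before a final output projection:
\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O\]
where \(\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\). With \(d_k = d_v = d_{model}/h\), the total computational cost is similar to that of single-head attention with full dimensionality. In the code below, the \(h\) per-head projections are packed into a single \(d_{model} \times d_{model}\) linear layer and then split into heads by reshaping.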
class MultiHeadAttention(nn.Module):
    """Applies multi-head attention mechanism."""

    def __init__(self, d_model: int, h: int, dropout_rate: float) -> None:
        """
        Initializes multi-head attention layer.
        Args:
            d_model: Dimensionality of model embeddings.
            h: Number of attention heads.
            dropout_rate: Rate of dropout for regularization.
        """
        super().__init__()
        if d_model % h != 0:
            raise ValueError("d_model must be divisible by h")
        self.d_model = d_model
        self.h = h
        self.dropout = nn.Dropout(dropout_rate)
        self.d_k = self.d_model // self.h
        self.d_v = self.d_model // self.h
        # Each d_model x d_model projection holds the h per-head projections side by side
        self.W_K = nn.Linear(
            in_features=self.d_model, out_features=self.d_model, bias=False
        )
        self.W_Q = nn.Linear(
            in_features=self.d_model, out_features=self.d_model, bias=False
        )
        self.W_V = nn.Linear(
            in_features=self.d_model, out_features=self.d_model, bias=False
        )
        self.W_OUTPUT_CONCAT = nn.Linear(
            in_features=self.d_model, out_features=self.d_model, bias=False
        )

    @staticmethod
    def attention(
        k: torch.Tensor,
        q: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor | None = None,
        dropout: nn.Dropout | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Computes scaled dot-product attention.
        Args:
            k: Key tensor.
            q: Query tensor.
            v: Value tensor.
            mask: Optional mask tensor (0 or False marks blocked positions).
            dropout: Optional dropout layer.
        Returns:
            Tuple of attention output and attention weights.
        """
        # (B, h, T_q, d_k) @ (B, h, d_k, T_k) -> (B, h, T_q, T_k)
        matmul_q_k = q @ k.transpose(-2, -1)
        d_k = k.shape[-1]
        matmul_q_k_scaled = matmul_q_k / math.sqrt(d_k)
        if mask is not None:
            # Blocked positions get a large negative score, hence ~0 weight after softmax
            matmul_q_k_scaled.masked_fill_(mask == 0, -1e9)
        attention_scores = F.softmax(input=matmul_q_k_scaled, dim=-1)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        # (B, h, T_q, T_k) @ (B, h, T_k, d_k) -> (B, h, T_q, d_k)
        return (attention_scores @ v), attention_scores

    def forward(
        self,
        k: torch.Tensor,
        q: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Forward pass through multi-head attention.
        Args:
            k: Key tensor.
            q: Query tensor.
            v: Value tensor.
            mask: Optional mask tensor.
        Returns:
            Tensor after attention and concatenation.
        """
        key_prima = self.W_K(k)
        query_prima = self.W_Q(q)
        value_prima = self.W_V(v)
        # Split into heads: (B, T, d_model) -> (B, T, h, d_k) -> (B, h, T, d_k)
        key_prima = key_prima.view(
            key_prima.shape[0], key_prima.shape[1], self.h, self.d_k
        ).transpose(1, 2)
        query_prima = query_prima.view(
            query_prima.shape[0], query_prima.shape[1], self.h, self.d_k
        ).transpose(1, 2)
        value_prima = value_prima.view(
            value_prima.shape[0], value_prima.shape[1], self.h, self.d_k
        ).transpose(1, 2)
        attention, attention_scores = MultiHeadAttention.attention(
            k=key_prima,
            q=query_prima,
            v=value_prima,
            mask=mask,
            dropout=self.dropout,
        )
        # Merge heads: (B, h, T_q, d_k) -> (B, T_q, h, d_k) -> (B, T_q, d_model)
        attention = attention.transpose(1, 2)
        b, seq_len, h, d_k = attention.size()
        attention_concat = attention.contiguous().view(b, seq_len, h * d_k)
        return self.W_OUTPUT_CONCAT(attention_concat)
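The reshape in MultiHeadAttention.forward is what turns one large projection into \(h\) separate heads. The sketch below checks that equivalence: computing all heads at once by reshaping gives the same result as looping over the heads with column slices of the projection matrices and concatenating, exactly as in the \(\text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O\) formula. The projection matrices are random stand-ins stored as (in, out), unlike nn.Linear's (out, in) layout, and a causal mask is included only to exercise the masking path.

```python
import math

import torch
from torch.nn import functional as F

torch.manual_seed(0)
B, T, d_model, h = 2, 5, 16, 4
d_k = d_model // h
x = torch.randn(B, T, d_model)
# Hypothetical random projections, stored as (in, out) so that x @ W projects
W_Q, W_K, W_V, W_O = (torch.randn(d_model, d_model) / math.sqrt(d_model) for _ in range(4))


def attention(q, k, v, mask=None):
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    return F.softmax(scores, dim=-1) @ v


def split_heads(t):
    # (B, T, d_model) -> (B, T, h, d_k) -> (B, h, T, d_k), as in MultiHeadAttention.forward
    return t.view(B, T, h, d_k).transpose(1, 2)


causal = torch.tril(torch.ones(T, T, dtype=torch.bool))

# Batched: all heads at once, then merge back to (B, T, d_model) and project
batched = attention(split_heads(x @ W_Q), split_heads(x @ W_K), split_heads(x @ W_V), causal)
batched = batched.transpose(1, 2).contiguous().view(B, T, d_model) @ W_O

# Loop: head i uses columns [i * d_k, (i + 1) * d_k) of each projection matrix
heads = []
for i in range(h):
    cols = slice(i * d_k, (i + 1) * d_k)
    heads.append(attention(x @ W_Q[:, cols], x @ W_K[:, cols], x @ W_V[:, cols], causal))
looped = torch.cat(heads, dim=-1) @ W_O
```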
Residual connections and encoder-decoder blocks
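Residual connections add each sublayer's input to its output, giving gradients a direct path through the network and making deep architectures easier to train. In the original paper, each sublayer is wrapped as \(\text{LayerNorm}(x + \text{Sublayer}(x))\), with normalization applied after the residual addition (post-norm). The implementation below uses the pre-norm variant, which normalizes the sublayer input instead:
\[x + \text{Dropout}(\text{Sublayer}(\text{LayerNorm}(x)))\]
This keeps the identity path free of normalization and is often reported to make deep stacks train more stably.
Encoder blocks apply self-attention followed by a feed-forward network, each wrapped in a residual connection. Decoder blocks first apply masked self-attention, then a cross-attention layer whose queries come from the decoder and whose keys and values come from the encoder output, allowing each decoder position to attend to all positions in the input sequence; a feed-forward network follows.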
from collections.abc import Callable


class ResidualConnection(nn.Module):
    """Applies a pre-norm residual connection around a sublayer."""

    def __init__(self, features: int, dropout_rate: float) -> None:
        """
        Initializes residual connection layer.
        Args:
            features: Number of features in the input.
            dropout_rate: Rate of dropout for regularization.
        """
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.layer_norm = LayerNormalization(features=features)

    def forward(
        self,
        input_tensor: torch.Tensor,
        sublayer: Callable[[torch.Tensor], torch.Tensor],
    ) -> torch.Tensor:
        """
        Forward pass using residual connection.
        Args:
            input_tensor: Input tensor to the residual layer.
            sublayer: Callable applied to the normalized input (a module or a lambda).
        Returns:
            input_tensor + dropout(sublayer(layer_norm(input_tensor))).
        """
        return input_tensor + self.dropout(sublayer(self.layer_norm(input_tensor)))
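The difference between the pre-norm wrapping used by ResidualConnection and the post-norm wrapping of the original paper shows up clearly when the sublayer contributes nothing. The sketch below zeroes a stand-in sublayer (a plain nn.Linear, chosen only for illustration) and uses PyTorch's nn.LayerNorm instead of the LayerNormalization class; dropout is omitted. Pre-norm then reduces to the identity, so even after 12 stacked blocks the gradient reaches the input unchanged, while post-norm still re-centers and re-scales its input.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model = 8
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)  # stand-in for attention or the FFN


def pre_norm(x):   # as in ResidualConnection above, dropout omitted
    return x + sublayer(norm(x))


def post_norm(x):  # as in Vaswani et al. (2017): LayerNorm(x + Sublayer(x))
    return norm(x + sublayer(x))


# Zero the sublayer so only the residual path remains
with torch.no_grad():
    sublayer.weight.zero_()
    sublayer.bias.zero_()

x = torch.randn(4, d_model) * 3.0 + 1.0
pre_out = pre_norm(x)    # exactly x: the identity path is untouched
post_out = post_norm(x)  # LayerNorm(x): re-centered and re-scaled

# Gradient of a scalar loss through 12 stacked pre-norm blocks
x_req = x.clone().requires_grad_(True)
out = x_req
for _ in range(12):
    out = pre_norm(out)
out.sum().backward()     # x_req.grad is all ones: d(sum)/dx flows straight through
```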
class EncoderBlock(nn.Module):
    """Encoder block with attention and feed-forward layers."""

    def __init__(self, d_model: int, d_ff: int, h: int, dropout_rate: float) -> None:
        """
        Initializes encoder block.
        Args:
            d_model: Dimensionality of model embeddings.
            d_ff: Dimensionality of feed-forward layer.
            h: Number of attention heads.
            dropout_rate: Rate of dropout for regularization.
        """
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.h = h
        self.dropout_rate = dropout_rate
        self.multi_head_attention_layer = MultiHeadAttention(
            d_model=self.d_model, h=self.h, dropout_rate=self.dropout_rate
        )
        self.residual_layer_1 = ResidualConnection(
            features=d_model, dropout_rate=self.dropout_rate
        )
        self.feed_forward_layer = FeedForward(
            d_model=self.d_model, d_ff=self.d_ff, dropout_rate=self.dropout_rate
        )
        self.residual_layer_2 = ResidualConnection(
            features=d_model, dropout_rate=self.dropout_rate
        )

    def forward(
        self, input_tensor: torch.Tensor, mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Forward pass through encoder block.
        Args:
            input_tensor: Input tensor to the encoder block.
            mask: Optional mask tensor.
        Returns:
            Tensor after processing by the encoder block.
        """
        input_tensor = self.residual_layer_1(
            input_tensor,
            lambda x: self.multi_head_attention_layer(k=x, q=x, v=x, mask=mask),
        )
        input_tensor = self.residual_layer_2(
            input_tensor, lambda x: self.feed_forward_layer(x)
        )
        return input_tensor
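One way to see what the padding mask buys in an encoder block is to overwrite the padded positions with garbage and confirm that the outputs at the real positions do not change. The sketch below uses PyTorch's built-in nn.TransformerEncoderLayer with norm_first=True, which follows the same pre-norm arrangement as EncoderBlock (details such as projection biases differ), assuming a recent PyTorch version. Note the inverted convention: in src_key_padding_mask, True marks a position to ignore, the opposite of create_padding_mask.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.TransformerEncoderLayer(
    d_model=16, nhead=4, dim_feedforward=32, dropout=0.0,
    batch_first=True, norm_first=True,
)

x = torch.randn(2, 5, 16)  # (batch, seq_len, d_model), already embedded
# Built-in convention: True = padding position to IGNORE
key_padding = torch.tensor([[False, False, False, True, True],
                            [False, False, False, False, True]])

out = layer(x, src_key_padding_mask=key_padding)

# Overwrite only the padded positions; real positions should be unaffected
x_garbage = x.clone()
x_garbage[key_padding] = 100.0
out_garbage = layer(x_garbage, src_key_padding_mask=key_padding)

real = ~key_padding
max_diff = (out[real] - out_garbage[real]).abs().max().item()
print(f"largest change at real positions: {max_diff:.2e}")
```

Without the mask, every real position would mix in the garbage values through self-attention; the feed-forward and normalization sublayers act per position and cannot leak them.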
class DecoderBlock(nn.Module):
    """Decoder block with masked attention, cross-attention, and feed-forward layers."""

    def __init__(self, d_model: int, d_ff: int, h: int, dropout_rate: float) -> None:
        """
        Initializes decoder block.
        Args:
            d_model: Dimensionality of model embeddings.
            d_ff: Dimensionality of feed-forward layer.
            h: Number of attention heads.
            dropout_rate: Rate of dropout for regularization.
        """
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.h = h
        self.dropout_rate = dropout_rate
        self.masked_multi_head_attention_layer = MultiHeadAttention(
            d_model=self.d_model, h=self.h, dropout_rate=self.dropout_rate
        )
        self.residual_layer_1 = ResidualConnection(
            features=d_model, dropout_rate=self.dropout_rate
        )
        self.multi_head_attention_layer = MultiHeadAttention(
            d_model=self.d_model, h=self.h, dropout_rate=self.dropout_rate
        )
        self.residual_layer_2 = ResidualConnection(
            features=d_model, dropout_rate=self.dropout_rate
        )
        self.feed_forward_layer = FeedForward(
            d_model=self.d_model, d_ff=self.d_ff, dropout_rate=self.dropout_rate
        )
        self.residual_layer_3 = ResidualConnection(
            features=d_model, dropout_rate=self.dropout_rate
        )

    def forward(
        self,
        decoder_input: torch.Tensor,
        encoder_output: torch.Tensor,
        src_mask: torch.Tensor | None = None,
        tgt_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Forward pass through decoder block.
        Args:
            decoder_input: Input tensor to the decoder block.
            encoder_output: Output tensor from the encoder.
            src_mask: Optional source mask tensor.
            tgt_mask: Optional target mask tensor.
        Returns:
            Tensor after processing by the decoder block.
        """
        decoder_input = self.residual_layer_1(
            decoder_input,
            lambda x: self.masked_multi_head_attention_layer(
                k=x, q=x, v=x, mask=tgt_mask
            ),
        )
        decoder_input = self.residual_layer_2(
            decoder_input,
            lambda x: self.multi_head_attention_layer(
                k=encoder_output, q=x, v=encoder_output, mask=src_mask
            ),
        )
        decoder_output = self.residual_layer_3(
            decoder_input, lambda x: self.feed_forward_layer(x)
        )
        return decoder_output
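The causal mask in the decoder's self-attention can be checked the same way: changing the last target position must not affect the outputs at earlier positions, even though every position also cross-attends to the encoder output. The sketch below uses PyTorch's nn.TransformerDecoderLayer with norm_first=True as a stand-in for DecoderBlock (again assuming a recent PyTorch version, with random tensors in place of real embeddings). Its boolean tgt_mask uses True for blocked positions, so the mask is the strict upper triangle rather than the lower triangle built by create_causal_mask.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.TransformerDecoderLayer(
    d_model=16, nhead=4, dim_feedforward=32, dropout=0.0,
    batch_first=True, norm_first=True,
)

memory = torch.randn(2, 6, 16)  # stand-in for encoder output: (B, src_len, d_model)
tgt = torch.randn(2, 5, 16)     # embedded target prefix: (B, tgt_len, d_model)
T = tgt.shape[1]
# Built-in convention: True = NOT allowed to attend (future positions)
causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

out = layer(tgt, memory, tgt_mask=causal)

# Change only the last target position
tgt_changed = tgt.clone()
tgt_changed[:, -1] = 100.0
out_changed = layer(tgt_changed, memory, tgt_mask=causal)

# Earlier positions cannot see the change; the last one obviously does
print(torch.allclose(out[:, :-1], out_changed[:, :-1], atol=1e-5))
```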
Complete Transformer architecture
The complete Transformer architecture integrates all the components above into an encoder-decoder structure. The encoder passes the input sequence through a stack of identical layers, each applying self-attention and a feed-forward transformation. The decoder generates the output sequence autoregressively, using masked self-attention over the tokens produced so far and cross-attention over the encoder output. A final projection layer maps each decoder representation to unnormalized scores (logits) over the output vocabulary; a softmax, usually folded into the cross-entropy loss during training, turns these scores into probabilities.
class ProjectionLayer(nn.Module):
    """Maps d_model-dimensional representations to vocabulary logits."""

    def __init__(self, d_model: int, vocab_size: int) -> None:
        """
        Initializes projection layer.
        Args:
            d_model: Dimensionality of model embeddings.
            vocab_size: Size of the vocabulary.
        """
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.projection_layer = nn.Linear(in_features=d_model, out_features=vocab_size)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through projection layer.
        Args:
            input_tensor: Input tensor of shape (..., d_model).
        Returns:
            Unnormalized logits of shape (..., vocab_size); apply softmax
            (or a cross-entropy loss) to obtain probabilities.
        """
        return self.projection_layer(input_tensor)
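Because the projection layer returns logits, training typically feeds them straight into torch.nn.functional.cross_entropy, which applies log-softmax internally. The sketch below uses random logits and hypothetical target ids (with 0 assumed to be the padding id), shows how ignore_index excludes padded positions, and reproduces the same loss by hand as the mean negative log-probability of the real target tokens.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
B, T, V, PAD = 2, 4, 10, 0
logits = torch.randn(B, T, V)  # shape of the ProjectionLayer output: (B, tgt_len, vocab)
targets = torch.tensor([[3, 5, 2, PAD],
                        [7, 1, PAD, PAD]])  # hypothetical ids, 0 = padding

# Probabilities are only needed explicitly for inspection or sampling
probs = F.softmax(logits, dim=-1)

# cross_entropy expects (N, V) logits and (N,) class ids
loss = F.cross_entropy(logits.view(-1, V), targets.view(-1), ignore_index=PAD)

# Same value by hand: mean negative log-probability over non-padding targets
log_probs = F.log_softmax(logits, dim=-1)
nll = -log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # (B, T)
manual = nll[targets != PAD].mean()
print(loss.item(), manual.item())
```

Keeping the softmax inside the loss is numerically safer than taking the log of explicitly computed probabilities.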
class Transformer(nn.Module):
    """Transformer model with encoder and decoder blocks."""

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        src_seq_len: int,
        tgt_seq_len: int,
        num_encoders: int,
        num_decoders: int,
        d_model: int,
        d_ff: int,
        h: int,
        dropout_rate: float,
    ) -> None:
        """
        Initializes transformer model.
        Args:
            src_vocab_size: Size of source vocabulary.
            tgt_vocab_size: Size of target vocabulary.
            src_seq_len: Maximum source sequence length.
            tgt_seq_len: Maximum target sequence length.
            num_encoders: Number of encoder blocks.
            num_decoders: Number of decoder blocks.
            d_model: Dimensionality of model embeddings.
            d_ff: Dimensionality of feed-forward layer.
            h: Number of attention heads.
            dropout_rate: Rate of dropout for regularization.
        """
        super().__init__()
        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        self.src_seq_len = src_seq_len
        self.tgt_seq_len = tgt_seq_len
        self.num_encoders = num_encoders
        self.num_decoders = num_decoders
        self.d_model = d_model
        self.d_ff = d_ff
        self.h = h
        self.dropout_rate = dropout_rate
        self.src_embedding = InputEmbedding(
            d_model=self.d_model, vocab_size=self.src_vocab_size
        )
        self.tgt_embedding = InputEmbedding(
            d_model=self.d_model, vocab_size=self.tgt_vocab_size
        )
        self.src_positional_encoding = PositionalEncoding(
            d_model=self.d_model,
            sequence_length=self.src_seq_len,
            dropout_rate=self.dropout_rate,
        )
        self.tgt_positional_encoding = PositionalEncoding(
            d_model=self.d_model,
            sequence_length=self.tgt_seq_len,
            dropout_rate=self.dropout_rate,
        )
        self.encoder_layers = nn.ModuleList(
            [
                EncoderBlock(
                    d_model=self.d_model,
                    d_ff=self.d_ff,
                    h=self.h,
                    dropout_rate=self.dropout_rate,
                )
                for _ in range(self.num_encoders)
            ]
        )
        self.decoder_layers = nn.ModuleList(
            [
                DecoderBlock(
                    d_model=self.d_model,
                    d_ff=self.d_ff,
                    h=self.h,
                    dropout_rate=self.dropout_rate,
                )
                for _ in range(self.num_decoders)
            ]
        )
        self.projection_layer = ProjectionLayer(
            d_model=self.d_model, vocab_size=self.tgt_vocab_size
        )

    def encode(
        self, encoder_input: torch.Tensor, src_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Encodes source input tensor using encoder blocks.
        Args:
            encoder_input: Input tensor to the encoder.
            src_mask: Optional source mask tensor.
        Returns:
            Encoded tensor.
        """
        x = self.src_embedding(encoder_input)
        x = self.src_positional_encoding(x)
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(input_tensor=x, mask=src_mask)
        return x

    def decode(
        self,
        decoder_input: torch.Tensor,
        encoder_output: torch.Tensor,
        src_mask: torch.Tensor | None = None,
        tgt_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Decodes target input tensor using decoder blocks.
        Args:
            decoder_input: Input tensor to the decoder.
            encoder_output: Output tensor from the encoder.
            src_mask: Optional source mask tensor.
            tgt_mask: Optional target mask tensor.
        Returns:
            Decoded tensor.
        """
        x = self.tgt_embedding(decoder_input)
        x = self.tgt_positional_encoding(x)
        for decoder_layer in self.decoder_layers:
            x = decoder_layer(
                decoder_input=x,
                encoder_output=encoder_output,
                src_mask=src_mask,
                tgt_mask=tgt_mask,
            )
        return x

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: torch.Tensor | None = None,
        tgt_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Processes input and target sequences through the encoder
        and decoder, applying optional source and target masks.
        Args:
            src: Input sequence tensor.
            tgt: Target sequence tensor.
            src_mask: Optional mask for the input sequence.
            tgt_mask: Optional mask for the target sequence.
        Returns:
            Tensor containing the final output after projection.
        """
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
        return self.projection_layer(decoder_output)