Overview

The Attention class implements multi-head attention with support for Grouped-Query Attention (GQA), rotary position embeddings (RoPE), and key-value caching for efficient autoregressive generation.

Definition

class Attention(nn.Module):
    """Multi-head attention module."""
    def __init__(self, args: ModelArgs): ...
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ) -> torch.Tensor: ...

Initialization

Parameters

args
ModelArgs
Model configuration parameters. See ModelArgs for details.

Attributes

The __init__ method creates the following attributes:
n_kv_heads
int
Number of key and value heads. Defaults to args.n_heads if args.n_kv_heads is None.
n_local_heads
int
Number of local query heads after model parallelism partitioning. Computed as args.n_heads // model_parallel_size.
n_local_kv_heads
int
Number of local key and value heads after model parallelism partitioning. Computed as n_kv_heads // model_parallel_size.
n_rep
int
Number of times each local key-value head is repeated so that keys and values match the number of local query heads in Grouped-Query Attention. Computed as n_local_heads // n_local_kv_heads.
head_dim
int
Dimension size of each attention head. Computed as args.dim // args.n_heads.
wq
ColumnParallelLinear
Linear transformation for queries. Projects from args.dim to args.n_heads * head_dim without bias.
wk
ColumnParallelLinear
Linear transformation for keys. Projects from args.dim to n_kv_heads * head_dim without bias.
wv
ColumnParallelLinear
Linear transformation for values. Projects from args.dim to n_kv_heads * head_dim without bias.
wo
RowParallelLinear
Linear transformation for output. Projects from args.n_heads * head_dim to args.dim without bias.
cache_k
torch.Tensor
Cached keys for attention with shape (max_batch_size, max_seq_len, n_local_kv_heads, head_dim). Pre-allocated on CUDA.
cache_v
torch.Tensor
Cached values for attention with shape (max_batch_size, max_seq_len, n_local_kv_heads, head_dim). Pre-allocated on CUDA.
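
The derived attributes above are plain integer arithmetic on the configuration. As a minimal sketch of that arithmetic (the names HeadLayout, head_layout, and cache_shape are hypothetical helpers for illustration, not part of model.py, and the configuration values are only examples):

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class HeadLayout:
    """Hypothetical container mirroring the derived attributes listed above."""
    n_kv_heads: int
    n_local_heads: int
    n_local_kv_heads: int
    n_rep: int
    head_dim: int


def head_layout(
    dim: int,
    n_heads: int,
    n_kv_heads: Optional[int],
    model_parallel_size: int,
) -> HeadLayout:
    # Fall back to one key-value head per query head when n_kv_heads is unset.
    kv_heads = n_heads if n_kv_heads is None else n_kv_heads
    n_local_heads = n_heads // model_parallel_size
    n_local_kv_heads = kv_heads // model_parallel_size
    return HeadLayout(
        n_kv_heads=kv_heads,
        n_local_heads=n_local_heads,
        n_local_kv_heads=n_local_kv_heads,
        n_rep=n_local_heads // n_local_kv_heads,
        head_dim=dim // n_heads,
    )


def cache_shape(
    layout: HeadLayout, max_batch_size: int, max_seq_len: int
) -> Tuple[int, int, int, int]:
    # Shape of cache_k and cache_v on one model-parallel rank.
    return (max_batch_size, max_seq_len, layout.n_local_kv_heads, layout.head_dim)


# Illustrative values only: 64 query heads, 8 key-value heads, 8 ranks.
layout = head_layout(dim=8192, n_heads=64, n_kv_heads=8, model_parallel_size=8)
print(layout)
print(cache_shape(layout, max_batch_size=4, max_seq_len=2048))
```

With these example values each rank holds 8 query heads and a single key-value head, so n_rep is 8 and the per-rank cache stores one head of dimension 128 per position.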

Forward Pass

Parameters

x
torch.Tensor
Input tensor with shape (batch_size, seq_len, dim).
start_pos
int
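Position in the sequence of the first token in x. The keys and values computed for x are written to KV cache slots start_pos through start_pos + seq_len - 1, which lets autoregressive generation reuse earlier tokens.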
freqs_cis
torch.Tensor
Precomputed frequency tensor for rotary position embeddings (complex exponentials).
mask
Optional[torch.Tensor]
Attention mask tensor. When provided, added to attention scores before softmax to prevent attending to certain positions.

Returns

output
torch.Tensor
Output tensor after attention with shape (batch_size, seq_len, dim).
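
Because the mask is added to the scores before softmax, a caller usually passes 0 for visible positions and negative infinity for blocked ones. The following is a minimal sketch of a causal mask that also accounts for start_pos; the function name causal_mask is hypothetical, and the construction is an assumption about how a caller might build the tensor, not something this page specifies:

```python
import torch


def causal_mask(seq_len: int, start_pos: int) -> torch.Tensor:
    """Build an additive mask of shape (seq_len, start_pos + seq_len)."""
    # -inf strictly above the diagonal, 0 on and below it, so each new
    # token can attend to itself and to earlier new tokens only.
    mask = torch.full((seq_len, seq_len), float("-inf"))
    mask = torch.triu(mask, diagonal=1)
    # Positions already in the cache (columns 0..start_pos-1) stay visible.
    return torch.hstack([torch.zeros((seq_len, start_pos)), mask])


mask = causal_mask(seq_len=3, start_pos=2)
print(mask)
```

The result broadcasts against scores of shape (batch_size, n_heads, seq_len, start_pos + seq_len). When a single token is decoded (seq_len of 1) the mask is all zeros, which is why passing None is equivalent in that case.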

Grouped-Query Attention (GQA)

The module implements GQA by allowing different numbers of query heads (n_heads) and key-value heads (n_kv_heads):
  • When n_kv_heads == n_heads: Standard multi-head attention
  • When n_kv_heads < n_heads: Grouped-Query Attention, in which several query heads share one key-value head, shrinking the KV cache by a factor of n_heads / n_kv_heads
Key-value heads are repeated using the repeat_kv function (model.py:164-173) to match the number of query heads:
keys = repeat_kv(keys, self.n_rep)  # Expand KV heads to match query heads
values = repeat_kv(values, self.n_rep)
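
To illustrate what the repetition does to tensor shapes, here is a sketch of a function with the same call signature. It is named repeat_kv_sketch to make clear that it is an illustration written for this page rather than a copy of the code at model.py:164-173, and it assumes the (batch, seq_len, n_kv_heads, head_dim) layout used by the cache:

```python
import torch


def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key-value head n_rep times along the head dimension."""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        # Standard multi-head attention: nothing to expand.
        return x
    return (
        x[:, :, :, None, :]                             # add a repeat axis
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)  # broadcast without copying
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


kv = torch.randn(2, 5, 2, 16)        # 2 key-value heads
expanded = repeat_kv_sketch(kv, 4)   # matches 8 query heads
print(expanded.shape)
```

The output is the same as torch.repeat_interleave(x, repeats=n_rep, dim=2): copies of a given key-value head sit next to each other, so consecutive query heads share the same keys and values.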

Attention Mechanism

The forward pass implements scaled dot-product attention with caching:
  1. Project inputs: Apply linear transformations wq, wk, wv (model.py:274)
  2. Reshape: Split into multiple heads (model.py:276-278)
  3. Apply RoPE: Apply rotary position embeddings to queries and keys (model.py:280)
  4. Update cache: Store current keys and values in cache (model.py:285-286)
  5. Retrieve from cache: Get all keys and values from position 0 through start_pos + seq_len (model.py:288-289)
  6. Repeat KV: Expand key-value heads for GQA (model.py:292-293)
  7. Compute scores: Calculate query-key dot products scaled by 1 / sqrt(head_dim) (model.py:298)
  8. Apply mask: Add attention mask if provided (model.py:299-300)
  9. Softmax: Normalize scores over the key dimension to obtain attention weights (model.py:301)
  10. Apply attention: Multiply the attention weights by the values (model.py:302)
  11. Output projection: Apply wo transformation (model.py:304)
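
The steps above can be sketched on a single device. This is a simplified illustration, not the implementation in model.py: the class name TinyCachedAttention is hypothetical, nn.Linear stands in for the parallel linear layers, torch.repeat_interleave stands in for repeat_kv, and RoPE (step 3) is omitted to keep the example short.

```python
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


class TinyCachedAttention(nn.Module):
    """Hypothetical single-device sketch of the cached attention steps."""

    def __init__(self, dim: int, n_heads: int, n_kv_heads: int,
                 max_batch_size: int, max_seq_len: int):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_heads // n_kv_heads
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)
        shape = (max_batch_size, max_seq_len, n_kv_heads, self.head_dim)
        self.register_buffer("cache_k", torch.zeros(shape), persistent=False)
        self.register_buffer("cache_v", torch.zeros(shape), persistent=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, start_pos: int,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        bsz, seqlen, _ = x.shape
        # Steps 1-2: project, then split the last dimension into heads.
        xq = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = self.wk(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xv = self.wv(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        # Step 3 (RoPE on xq and xk) is intentionally left out of this sketch.
        # Step 4: write the new keys and values into the cache.
        self.cache_k[:bsz, start_pos:start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos:start_pos + seqlen] = xv
        # Step 5: read everything cached so far, including the new tokens.
        keys = self.cache_k[:bsz, :start_pos + seqlen]
        values = self.cache_v[:bsz, :start_pos + seqlen]
        # Step 6: expand key-value heads to match the query heads.
        keys = torch.repeat_interleave(keys, self.n_rep, dim=2)
        values = torch.repeat_interleave(values, self.n_rep, dim=2)
        # Move heads ahead of the sequence axis: (bsz, heads, seq, head_dim).
        xq, keys, values = (t.transpose(1, 2) for t in (xq, keys, values))
        # Steps 7-9: scaled scores, optional additive mask, softmax.
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask
        weights = F.softmax(scores.float(), dim=-1).type_as(xq)
        # Steps 10-11: weighted sum of values, merge heads, output projection.
        out = torch.matmul(weights, values)
        out = out.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(out)


attn = TinyCachedAttention(dim=16, n_heads=4, n_kv_heads=2,
                           max_batch_size=2, max_seq_len=8)
prompt = torch.randn(2, 5, 16)
causal = torch.triu(torch.full((5, 5), float("-inf")), diagonal=1)
print(attn(prompt, start_pos=0, mask=causal).shape)
```

A useful property to check with a sketch like this is that processing a prompt in one call gives the same output for the last token as prefilling all but the last token and then decoding that token alone with the appropriate start_pos.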

Usage in TransformerBlock

The Attention module is instantiated in TransformerBlock and applied with pre-normalization:
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.attention = Attention(args)  # Initialize attention
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        # ...

    def forward(self, x, start_pos, freqs_cis, mask):
        # Pre-normalization + residual connection
        h = x + self.attention(
            self.attention_norm(x), start_pos, freqs_cis, mask
        )
        # ...
See model.py:375 and model.py:406-408 for the complete implementation.