Overview

The FeedForward class implements a SwiGLU (Swish-Gated Linear Unit) feedforward network, a key component of each Transformer layer. It multiplies a SiLU-activated (Swish) projection of the input element-wise with a second, linear projection, then projects the result back to the model dimension.

Definition

class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    )
    def forward(self, x: torch.Tensor) -> torch.Tensor

Initialization

Parameters

dim
int
Input dimension. This is the model’s hidden dimension from ModelArgs.dim.
hidden_dim
int
Base hidden dimension of the feedforward layer. The actual hidden dimension is computed as int(2 * hidden_dim / 3), optionally scaled by ffn_dim_multiplier, and then rounded up to a multiple of multiple_of.
multiple_of
int
Value to ensure the hidden dimension is a multiple of this number. Rounds up the computed hidden dimension to the nearest multiple for computational efficiency on modern hardware.
ffn_dim_multiplier
Optional[float]
Optional custom multiplier for the hidden dimension. When provided, scales the computed hidden dimension by this factor before rounding to multiple_of.

Hidden Dimension Computation

The actual hidden dimension is computed using the following logic (model.py:331-335):
hidden_dim = int(2 * hidden_dim / 3)
if ffn_dim_multiplier is not None:
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
This ensures:
  1. Base scaling by 2/3
  2. Optional custom scaling via ffn_dim_multiplier
  3. Rounding up to nearest multiple of multiple_of
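As a worked check of the three steps above, the following sketch wraps the same arithmetic in a standalone function. The function name compute_hidden_dim and the input values are illustrative assumptions and do not appear in model.py.

```python
from typing import Optional


def compute_hidden_dim(
    hidden_dim: int,
    multiple_of: int,
    ffn_dim_multiplier: Optional[float] = None,
) -> int:
    """Hypothetical helper mirroring the hidden-dimension logic shown above."""
    # Step 1: base scaling by 2/3.
    hidden_dim = int(2 * hidden_dim / 3)
    # Step 2: optional custom scaling.
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # Step 3: round up to the nearest multiple of multiple_of.
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


# Illustrative values: dim = 4096, so the base hidden_dim is 4 * 4096 = 16384.
print(compute_hidden_dim(4 * 4096, multiple_of=256))  # 11008
print(compute_hidden_dim(4 * 4096, multiple_of=1024, ffn_dim_multiplier=1.3))  # 14336
```

In the first call, int(2 * 16384 / 3) gives 10922, which rounds up to 43 * 256 = 11008. In the second, the multiplier raises 10922 to 14198, which rounds up to 14 * 1024 = 14336.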

Attributes

w1
ColumnParallelLinear
Linear transformation for the gate branch. Projects from dim to hidden_dim without bias; its output is passed through SiLU.
w2
RowParallelLinear
Linear transformation for the output layer. Projects from hidden_dim back to dim without bias.
w3
ColumnParallelLinear
Linear transformation for the linear branch. Projects from dim to hidden_dim without bias; its output receives no activation and is multiplied element-wise by the gate branch.
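
Because all three projections are bias-free, the parameter count of one FeedForward module follows directly from the shapes above. A small sketch, in which the helper name ffn_param_count and the example sizes are assumptions chosen for illustration; hidden_dim here is the actual (already rounded) hidden dimension, and the count covers the full weights before they are split across model-parallel ranks:

```python
def ffn_param_count(dim: int, hidden_dim: int) -> int:
    """Hypothetical helper: total weights in w1, w2 and w3 (no bias terms)."""
    w1 = dim * hidden_dim  # dim -> hidden_dim
    w3 = dim * hidden_dim  # dim -> hidden_dim
    w2 = hidden_dim * dim  # hidden_dim -> dim
    return w1 + w2 + w3


print(ffn_param_count(dim=4096, hidden_dim=11008))  # 135266304
```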

Forward Pass

Parameters

x
torch.Tensor
Input tensor with shape (batch_size, seq_len, dim).

Returns

output
torch.Tensor
Output tensor with shape (batch_size, seq_len, dim).

Implementation

The forward pass implements the SwiGLU activation function (model.py:348):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
This computes:
  1. Gate branch: F.silu(self.w1(x)) - Apply SiLU activation to first projection
  2. Linear branch: self.w3(x) - Second projection without activation
  3. Gating: Element-wise multiplication of the two branches
  4. Output projection: self.w2(...) - Project back to model dimension
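
The four steps above can be reproduced on a single device with plain nn.Linear layers. This is a minimal sketch, not the model.py implementation: the class name SimpleFeedForward is hypothetical, nn.Linear stands in for ColumnParallelLinear and RowParallelLinear (so there is no model parallelism), and hidden_dim is assumed to be passed in already rounded.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleFeedForward(nn.Module):
    """Hypothetical single-device stand-in for FeedForward."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate branch
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # linear branch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))  # 1. SiLU on the first projection
        linear = self.w3(x)        # 2. second projection, no activation
        gated = gate * linear      # 3. element-wise gating
        return self.w2(gated)      # 4. project back to dim


ffn = SimpleFeedForward(dim=16, hidden_dim=48)
x = torch.randn(2, 5, 16)  # (batch_size, seq_len, dim)
out = ffn(x)
print(out.shape)  # torch.Size([2, 5, 16])
```

The output shape matches the input shape, which is what lets TransformerBlock add the result back onto the residual stream.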

SwiGLU Activation

SwiGLU combines two concepts:
  • GLU (Gated Linear Unit): Uses a gating mechanism with element-wise multiplication
  • SiLU/Swish: Smooth activation function x * sigmoid(x)
The formula is: SwiGLU(x) = W2 · (Swish(W1·x) ⊙ W3·x), where ⊙ represents element-wise multiplication.
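
To make the gating concrete, here is a scalar sketch using only the standard library. The function names silu and gate are illustrative assumptions; in model.py the same operations run on whole tensors via F.silu and element-wise multiplication.

```python
import math


def silu(z: float) -> float:
    """SiLU/Swish: z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))


def gate(a: float, b: float) -> float:
    """One element of Swish(W1·x) ⊙ W3·x, given a = (W1·x)_i and b = (W3·x)_i."""
    return silu(a) * b


print(silu(0.0))        # 0.0
print(gate(0.0, 5.0))   # 0.0: a zero gate pre-activation blocks the linear branch
print(gate(10.0, 5.0))  # close to 50: for large a, silu(a) is approximately a
```

Unlike a sigmoid gate, the SiLU gate is not bounded to [0, 1]: for large positive inputs it passes the value through almost unchanged, and for negative inputs it yields small negative values.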

Usage in TransformerBlock

The FeedForward module is instantiated in TransformerBlock with a base hidden dimension of 4 * args.dim:
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,  # Base hidden dim is 4x model dim
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        # ...

    def forward(self, x, start_pos, freqs_cis, mask):
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        # Pre-normalization + residual connection
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
See model.py:376-381 for initialization and model.py:409 for the forward pass with pre-normalization and residual connection.

Performance Considerations

  • The multiple_of parameter (default 256) rounds the hidden dimension up to a multiple of that value; with a power-of-two setting such as 256, this improves GPU/TPU efficiency
  • Model parallelism is handled via ColumnParallelLinear and RowParallelLinear from FairScale
  • When ffn_dim_multiplier is None, the 2/3 scaling factor and 4x base expansion result in an effective hidden dimension of approximately 8/3 * dim before rounding