Mixture of Experts
Introduction
The Mixture of Experts (MoE) architecture is a machine learning paradigm that scales model capacity efficiently by dividing the work among specialized submodels. It was originally introduced by Jacobs et al. (1991) and later popularized for deep neural networks. It follows a "divide and conquer" principle: instead of training a single monolithic model for every input, it trains a set of specialized models (experts) together with a routing mechanism (gating network) that decides which experts should process each input.
In sparse MoE variants, computational efficiency comes from conditional computation. The model may contain a large number of parameters distributed among many experts, but only a small subset of those experts (for example, the top-k experts chosen by the gating network) is activated for each input. As a result, the parameter count, and with it the model's capacity, can grow much faster than the compute spent per input, during both training and inference. The architecture has proven particularly effective in large-scale language models, where different experts can specialize in different linguistic domains, styles, or types of knowledge.
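To make the distinction between total and active parameters concrete, the arithmetic below counts expert parameters for a hypothetical sparse configuration in which each input is routed to its top-k experts. The values (8 experts, k = 2, and the layer sizes) are assumptions chosen for illustration. The helper names `ffn_params` and `moe_param_budget` are hypothetical, and the router's own parameters are ignored. Each expert uses the same two-layer feed-forward shape as `ExpertModel` later in this article.

```python
def ffn_params(d_in: int, d_hidden: int, d_out: int) -> int:
    """Weights plus biases of Linear(d_in, d_hidden) -> Linear(d_hidden, d_out)."""
    return d_in * d_hidden + d_hidden + d_hidden * d_out + d_out


def moe_param_budget(
    num_experts: int, top_k: int, d_in: int, d_hidden: int, d_out: int
) -> tuple[int, int]:
    """Return (total expert parameters, expert parameters used per input)."""
    per_expert = ffn_params(d_in, d_hidden, d_out)
    return num_experts * per_expert, top_k * per_expert


# Hypothetical configuration: 8 experts, each input routed to 2 of them.
total, active = moe_param_budget(num_experts=8, top_k=2, d_in=512, d_hidden=2048, d_out=512)
print(total, active, active / total)  # 16797696 4199424 0.25
```

Here only a quarter of the expert parameters take part in any single forward pass, even though all of them add to the model's capacity.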
Individual expert architecture
Each expert in an MoE architecture is an independent neural network that, after training, handles a particular region of the input space. In its simplest form, an expert is a feed-forward network whose hidden layers transform the input into an output representation. Specialization is not programmed in. It emerges during training as the routing mechanism learns to direct different types of inputs to different experts.
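Mathematically, each expert \(E_i\) can be represented as a parameterized function:
\[ E_i(\cdot\,; \theta_i): \mathbb{R}^{d_{in}} \rightarrow \mathbb{R}^{d_{out}}, \qquad x \mapsto E_i(x; \theta_i) \]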
where \(\theta_i\) represents the specific parameters of expert \(i\), and \(d_{in}\) and \(d_{out}\) are the input and output dimensionalities respectively.
```python
import torch
from torch import nn
from torch.nn import functional as F


class ExpertModel(nn.Module):
    """
    Individual expert model for MoE
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int) -> None:
        """
        Initializes an expert model with a simple feed-forward network.

        Args:
            input_dim: Dimensionality of the input data.
            output_dim: Dimensionality of the output data.
            hidden_dim: Dimensionality of the hidden layer.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        # Two-layer MLP: input_dim -> hidden_dim -> output_dim
        self.model = nn.Sequential(
            nn.Linear(in_features=self.input_dim, out_features=self.hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=self.hidden_dim, out_features=self.output_dim),
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the expert model.

        Args:
            input_tensor: Input tensor to the model.

        Returns:
            The model's output tensor.
        """
        return self.model(input_tensor)
```
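A quick check of the mapping \(E_i: \mathbb{R}^{d_{in}} \rightarrow \mathbb{R}^{d_{out}}\) and of the size of \(\theta_i\) follows. To stay self-contained, the sketch rebuilds the same layer stack as `ExpertModel` with a plain `nn.Sequential` and counts its parameters. The dimensions match the usage example at the end of this article.

```python
import torch
from torch import nn

# Same layer stack as ExpertModel.model, rebuilt here so the snippet runs on its own.
d_in, d_hidden, d_out = 10, 128, 5
expert = nn.Sequential(
    nn.Linear(d_in, d_hidden),
    nn.ReLU(),
    nn.Linear(d_hidden, d_out),
)

x = torch.randn(4, d_in)  # batch of 4 inputs, each in R^{d_in}
y = expert(x)             # the expert maps each row to R^{d_out}
num_params = sum(p.numel() for p in expert.parameters())

print(y.shape)     # torch.Size([4, 5])
print(num_params)  # 2053 = 10*128 + 128 + 128*5 + 5
```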
Routing mechanism via gating network
The routing mechanism (gating network) constitutes the central component that determines how inputs are distributed among available experts. This network learns to assign weights to each expert based on input characteristics, producing a probability distribution over experts through a softmax function. The gating network can be interpreted as a soft classifier that determines which experts are most relevant for processing each specific input.
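The gating function \(G(x; \phi)\) produces a vector of normalized weights:
\[ G(x; \phi) = \big(g_1(x), \dots, g_N(x)\big) = \operatorname{softmax}\big(W_g\, h(x) + b_g\big) \]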
where \(h(x)\) represents intermediate transformations applied to the input, \(W_g\) and \(b_g\) are learnable gating parameters, and \(\phi\) denotes the complete set of routing network parameters. The resulting weights satisfy \(\sum_{i=1}^{N} g_i(x) = 1\), where \(N\) is the total number of experts.
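The formula can be worked through numerically with made-up values. The sketch below assumes \(h(x) = x\) (no intermediate layers) and uses a hypothetical 2-D input routed among \(N = 3\) experts. It computes the logits \(W_g x + b_g\) and applies a numerically stable softmax, so the resulting weights are positive and sum to 1. In the `Gating` class below, \(h(x)\) corresponds to the hidden MLP layers, and the final `nn.Linear` supplies \(W_g\) and \(b_g\).

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def linear_gate(x: list[float], W_g: list[list[float]], b_g: list[float]) -> list[float]:
    """G(x) = softmax(W_g h(x) + b_g), taking h(x) = x for illustration."""
    logits = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(W_g, b_g)]
    return softmax(logits)


# Hypothetical input and gating parameters (N = 3 experts, 2-D input).
x = [1.0, -0.5]
W_g = [[2.0, 0.0], [0.0, 1.0], [-1.0, 1.0]]
b_g = [0.0, 0.5, 0.0]

g = linear_gate(x, W_g, b_g)  # logits are [2.0, 0.0, -1.5]
print([round(gi, 3) for gi in g], sum(g))
```

The expert with the largest logit receives most of the weight (about 0.86 here), but every expert keeps a nonzero share, which is why this gate acts as a soft classifier.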
```python
class Gating(nn.Module):
    """
    Gating mechanism to select experts.
    """

    def __init__(
        self, input_dim: int, num_experts: int, dropout_rate: float = 0.2
    ) -> None:
        """
        Initializes a gating network for expert selection.

        Args:
            input_dim: Dimensionality of the input data.
            num_experts: Number of experts to select from.
            dropout_rate: Rate of dropout for regularization.
        """
        super().__init__()
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.dropout_rate = dropout_rate
        # MLP producing one logit per expert
        self.model = nn.Sequential(
            nn.Linear(in_features=self.input_dim, out_features=128),
            nn.Dropout(self.dropout_rate),
            nn.LeakyReLU(),
            nn.Linear(in_features=128, out_features=256),
            nn.LeakyReLU(),
            nn.Dropout(self.dropout_rate),
            nn.Linear(in_features=256, out_features=128),
            nn.LeakyReLU(),
            nn.Dropout(self.dropout_rate),
            nn.Linear(in_features=128, out_features=num_experts),
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the gating network.

        Args:
            input_tensor: Input tensor to the network.

        Returns:
            Softmax probabilities for expert selection.
        """
        # Normalize the logits over the expert dimension
        return F.softmax(self.model(input_tensor), dim=-1)
```
Complete Mixture of Experts architecture
The complete MoE architecture integrates the individual experts with the gating mechanism and produces a final output as a weighted combination of expert predictions. For an input \(x\), the MoE system output is calculated as:
\[ y(x) = \sum_{i=1}^{N} g_i(x)\, E_i(x) \]
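where \(g_i(x)\) represents the weight assigned to expert \(i\) by the gating network, and \(E_i(x)\) is the output of expert \(i\). This formulation lets the model learn which experts are most relevant for different regions of the input space, which encourages specialization and increases model capacity. In this dense form, however, every expert is evaluated for every input, so compute grows linearly with \(N\). The computational savings described in the introduction require sparse routing, in which only the experts with the largest gating weights are evaluated.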
During training, both experts and the gating network are jointly optimized through standard backpropagation. The gradient flows through all experts weighted by their respective gating weights, allowing the system to learn both expert specialization and optimal routing in an end-to-end manner.
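Autograd can confirm that gradients reach every expert scaled by its gate weight. In the sketch below, the gate weights and expert outputs are made-up tensors rather than outputs of the modules in this article. For the scalar loss \(L = \sum_j y_j\), the gradient of \(L\) with respect to each component of \(E_i(x)\) equals \(g_i(x)\).

```python
import torch

# Hypothetical gate weights for N = 3 experts (they sum to 1, as softmax outputs would).
g = torch.tensor([0.7, 0.2, 0.1])
# Hypothetical expert outputs E_i(x), each in R^2, stacked as rows; track their gradients.
expert_out = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], requires_grad=True)

# y = sum_i g_i(x) E_i(x), a vector in R^2
y = (g.unsqueeze(-1) * expert_out).sum(dim=0)

# Scalar loss L = y_1 + y_2, so dL/dy = 1 in every component.
y.sum().backward()

print(y)                # the weighted combination of expert outputs
print(expert_out.grad)  # row i is filled with g_i: each expert's gradient is scaled by its gate weight
```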
```python
class MoE(nn.Module):
    """
    Mixture of Experts
    """

    def __init__(
        self,
        trained_experts: list[ExpertModel],
        input_dim: int,
        dropout_rate: float = 0.2,
    ) -> None:
        """
        Initializes a mixture of experts with gating.

        Args:
            trained_experts: List of trained expert models.
            input_dim: Dimensionality of the input data.
            dropout_rate: Rate of dropout in the gating network.
        """
        super().__init__()
        self.experts = nn.ModuleList(trained_experts)
        self.num_experts = len(trained_experts)
        self.input_dim = input_dim
        self.dropout_rate = dropout_rate
        self.gating_layer = Gating(
            input_dim=self.input_dim,
            num_experts=self.num_experts,
            dropout_rate=self.dropout_rate,
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the mixture of experts.

        Args:
            input_tensor: Input tensor to the model.

        Returns:
            Weighted sum of expert outputs.
        """
        # (..., num_experts); each row sums to 1
        expert_weights = self.gating_layer(input_tensor)
        # Dense MoE: every expert processes every input, each giving (..., output_dim)
        _expert_outputs: list[torch.Tensor] = []
        for expert in self.experts:
            _expert_outputs.append(expert(input_tensor))
        # (..., output_dim, num_experts)
        expert_outputs = torch.stack(_expert_outputs, dim=-1)
        # (..., 1, num_experts) so the weights broadcast over output_dim;
        # unsqueeze(-2) also handles inputs with extra leading dims, unlike unsqueeze(1)
        expert_weights = expert_weights.unsqueeze(-2)
        # Weighted sum over the expert dimension -> (..., output_dim)
        return torch.sum(expert_outputs * expert_weights, dim=-1)
```
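The `MoE` class above is dense: it evaluates all experts for every input. The following is a minimal sketch of sparse top-k routing, offered as an illustration rather than as part of the original implementation. For each input, it keeps the k largest gate probabilities, renormalizes them, and evaluates each expert only on the inputs routed to it. The function name `topk_moe_forward`, its `k` and `out_dim` parameters, and the tiny `nn.Linear` experts are all hypothetical. It omits the load-balancing losses and per-expert capacity limits that practical sparse MoE layers usually add. With `k` equal to the number of experts, it should reproduce the dense weighted sum.

```python
import torch
from torch import nn


def topk_moe_forward(
    x: torch.Tensor,
    gate_probs: torch.Tensor,
    experts: nn.ModuleList,
    k: int,
    out_dim: int,
) -> torch.Tensor:
    """Sparse (top-k) combination of expert outputs.

    x: (batch, d_in); gate_probs: (batch, num_experts) with rows summing to 1.
    Each row of x is passed only to the k experts with the largest gate probabilities.
    """
    # Keep the k largest gate probabilities per row and renormalize them to sum to 1.
    topk_vals, topk_idx = torch.topk(gate_probs, k, dim=-1)  # both (batch, k)
    topk_vals = topk_vals / topk_vals.sum(dim=-1, keepdim=True)

    out = x.new_zeros(x.shape[0], out_dim)
    for e, expert in enumerate(experts):
        # rows: which inputs chose expert e; slots: where e sits among their k choices.
        rows, slots = (topk_idx == e).nonzero(as_tuple=True)
        if rows.numel() == 0:
            continue  # no input routed here, so this expert is not evaluated at all
        weighted = expert(x[rows]) * topk_vals[rows, slots].unsqueeze(-1)
        # Scatter-add the weighted outputs back into their original batch positions.
        out.index_add_(0, rows, weighted)
    return out


torch.manual_seed(0)
experts = nn.ModuleList([nn.Linear(4, 3) for _ in range(3)])  # hypothetical tiny experts
x = torch.randn(8, 4)
gate_probs = torch.softmax(torch.randn(8, 3), dim=-1)  # stand-in for the Gating output

sparse = topk_moe_forward(x, gate_probs, experts, k=2, out_dim=3)

# Sanity check: with k equal to the number of experts, the sparse path
# reproduces the dense weighted sum computed in the same way as MoE.forward.
full = topk_moe_forward(x, gate_probs, experts, k=3, out_dim=3)
dense = torch.stack([expert(x) for expert in experts], dim=-1)
dense = (dense * gate_probs.unsqueeze(-2)).sum(dim=-1)
print(sparse.shape, torch.allclose(full, dense, atol=1e-5))
```

In the sparse path, an expert that a given input does not select receives no gradient from that input. This trade-off is what buys the reduction in compute.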
Usage example and verification
The MoE implementation allows direct integration into existing deep learning pipelines. The following example demonstrates basic model instantiation and usage, including verification that gating weights are correctly normalized.
```python
if __name__ == "__main__":
    input_dim = 10
    output_dim = 5
    num_experts = 3
    batch_size = 32
    hidden_dim = 128

    experts = [
        ExpertModel(input_dim=input_dim, output_dim=output_dim, hidden_dim=hidden_dim)
        for _ in range(num_experts)
    ]
    moe = MoE(experts, input_dim)

    x = torch.randn(batch_size, input_dim)
    output = moe(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Expected output shape: ({batch_size}, {output_dim})")

    gating_weights = moe.gating_layer(x)
    print(f"Gating weights shape: {gating_weights.shape}")
    print(f"Gating weights sum per sample: {gating_weights.sum(dim=1)}")
```
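A freshly constructed `nn.Module` is in training mode, so in the example above the gating network's dropout is active. The separate call `moe.gating_layer(x)` therefore samples a new dropout mask, and its weights are not exactly the ones used inside `moe(x)`, although both sets still sum to 1 per sample. The sketch below illustrates this on a reduced stand-in for the `Gating` stack, which is an assumption made for brevity. Calling `.eval()`, together with `torch.no_grad()` at inference time, makes routing deterministic.

```python
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(0)
# Reduced stand-in for the Gating MLP: Linear -> Dropout -> LeakyReLU -> Linear, then softmax.
gate = nn.Sequential(nn.Linear(10, 128), nn.Dropout(0.2), nn.LeakyReLU(), nn.Linear(128, 3))
x = torch.randn(32, 10)

gate.train()  # dropout active (also the default right after construction)
w1, w2 = F.softmax(gate(x), dim=-1), F.softmax(gate(x), dim=-1)
print(torch.allclose(w1, w2))  # False: each call samples a different dropout mask

gate.eval()  # dropout disabled
with torch.no_grad():
    w3, w4 = F.softmax(gate(x), dim=-1), F.softmax(gate(x), dim=-1)
print(torch.allclose(w3, w4))  # True: routing is now deterministic
print(torch.allclose(w3.sum(dim=-1), torch.ones(32)))  # True: rows still sum to 1
```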