Switch Transformer
Google 的 MoE 开创性工作
观察 Token 如何被路由到不同的专家,以及稀疏激活如何工作。
💡 Mixtral 8x7B: 47B 总参数,每次只激活 13B
一句话定义:MoE (Mixture of Experts) 是一种稀疏激活架构,每次推理只激活部分”专家”网络,在保持大容量的同时大幅降低计算量。
| Dense 模型 | MoE 模型 |
|---|---|
| 每个问题找全科医生 | 根据症状找专科医生 |
| 医生什么都要会 | 每个医生专精一个领域 |
| 一个人处理所有 | 只找相关的 2-3 个专家会诊 |
| 效率低 | 效率高,专业性强 |
MoE 让模型像医院一样:有很多专家,但每次只请相关的几位会诊。
| 模型 | 总参数 | 激活参数 | 推理成本 |
|---|---|---|---|
| GPT-3 | 175B | 175B | 100% |
| Mixtral 8x7B | 47B | 13B | 28% |
MoE 的魔力: 10x 参数,相似计算量!
| 优势 | 说明 |
|---|---|
| 容量大 | 更多参数 = 更多知识 |
| 计算省 | 稀疏激活,只用部分参数 |
| 专业化 | 不同专家学习不同能力 |
| 可扩展 | 加专家不增加推理成本 |
┌────────────────────────────────────────────────────────┐│ MoE Layer │├────────────────────────────────────────────────────────┤│ ││ 输入 x ││ │ ││ ├──→ Router (门控网络) ││ │ │ ││ │ ↓ ││ │ Gate Scores: [0.1, 0.3, 0.5, 0.05, ...] ││ │ │ ││ │ ↓ Top-K (K=2) ││ │ Selected: Expert 2, Expert 3 ││ │ ││ ├──→ Expert 2 ──→ y2 × 0.38 ││ │ ││ └──→ Expert 3 ──→ y3 × 0.62 ││ ││ 输出: y = 0.38 × y2 + 0.62 × y3 │└────────────────────────────────────────────────────────┘防止所有 token 都选同一个专家:
import torchimport torch.nn as nnimport torch.nn.functional as F
class MoELayer(nn.Module): """Mixture of Experts Layer"""
def __init__( self, hidden_size: int, num_experts: int = 8, top_k: int = 2, ffn_hidden_size: int = None ): super().__init__() self.num_experts = num_experts self.top_k = top_k
if ffn_hidden_size is None: ffn_hidden_size = hidden_size * 4
# 门控网络 self.gate = nn.Linear(hidden_size, num_experts, bias=False)
# 专家网络 (每个都是 FFN) self.experts = nn.ModuleList([ nn.Sequential( nn.Linear(hidden_size, ffn_hidden_size), nn.GELU(), nn.Linear(ffn_hidden_size, hidden_size) ) for _ in range(num_experts) ])
def forward(self, x): """ x: [batch_size, seq_len, hidden_size] """ batch_size, seq_len, hidden_size = x.shape x_flat = x.view(-1, hidden_size) # [B*S, H]
# 1. 计算门控分数 gate_logits = self.gate(x_flat) # [B*S, num_experts]
# 2. Top-K 选择 top_k_logits, top_k_indices = torch.topk( gate_logits, self.top_k, dim=-1 ) top_k_weights = F.softmax(top_k_logits, dim=-1) # [B*S, K]
# 3. 计算专家输出 output = torch.zeros_like(x_flat)
for i, expert in enumerate(self.experts): # 找到选择了这个专家的 token mask = (top_k_indices == i).any(dim=-1) # [B*S] if mask.any(): expert_input = x_flat[mask] expert_output = expert(expert_input)
# 获取权重 idx_in_topk = (top_k_indices[mask] == i).float() weights = (top_k_weights[mask] * idx_in_topk).sum(dim=-1, keepdim=True)
output[mask] += weights * expert_output
return output.view(batch_size, seq_len, hidden_size)
def compute_aux_loss(self, gate_logits): """负载均衡损失""" # 每个专家被选中的比例 probs = F.softmax(gate_logits, dim=-1) top_k_indices = torch.topk(gate_logits, self.top_k, dim=-1).indices
# f_i: 专家 i 被选中的频率 expert_mask = F.one_hot(top_k_indices, self.num_experts).sum(dim=1) f = expert_mask.float().mean(dim=0)
# P_i: 专家 i 的平均概率 P = probs.mean(dim=0)
# 负载均衡损失 aux_loss = self.num_experts * (f * P).sum() return aux_lossfrom transformers import AutoModelForCausalLM, AutoTokenizer
# 加载 Mixtral 8x7Bmodel = AutoModelForCausalLM.from_pretrained( "mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.float16, device_map="auto")tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
# 推理inputs = tokenizer("Hello, how are you?", return_tensors="pt")outputs = model.generate(**inputs, max_new_tokens=100)print(tokenizer.decode(outputs[0]))
# 查看模型结构print(model)# 会看到每个 Transformer 层的 FFN 被替换为 MoE# MixtralSparseMoeBlock(# (gate): Linear(in=4096, out=8)# (experts): ModuleList(# (0-7): 8 x MixtralBlockSparseTop2MLP(...)# )# )# Megablocks 提供高效 MoE 实现# pip install megablocks
import megablocksfrom megablocks.layers import moefrom megablocks.layers.arguments import Arguments
# 配置args = Arguments( hidden_size=4096, ffn_hidden_size=14336, num_experts=8, top_k=2, moe_capacity_factor=1.25, # 容量因子)
# 创建高效 MoE 层moe_layer = moe.MoE(args)
# 前向传播 (使用优化的 kernel)x = torch.randn(2, 1024, 4096)output, aux_loss = moe_layer(x)| 模型 | 专家数 | Top-K | 特点 |
|---|---|---|---|
| Switch Transformer | 128 | 1 | 首个大规模 MoE |
| Mixtral 8x7B | 8 | 2 | 开源最强 MoE |
| DeepSeek-MoE | 64 | 6 | 细粒度专家 |
| Qwen-MoE | 60 | 4 | 阿里开源 |
| 维度 | Dense | MoE |
|---|---|---|
| 训练稳定性 | 好 | 需要负载均衡 |
| 推理效率 | 参数=计算 | 参数 > 计算 |
| 显存 | 参数量决定 | 全部专家要加载 |
| 微调 | 简单 | 需要专门方法 |
| 知识存储 | 分布式 | 可能专家化 |