简单MoE实现
MoE 通过多个专家(Expert)网络和一个门控(Gating)网络,动态选择最合适的专家进行计算,从而提高计算效率和模型能力。
import torch
import torch.nn as nn
import torch.nn.functional as F
class Expert(nn.Module):
    """A single expert: two linear layers with a ReLU in between.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the hidden layer.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc2(relu(fc1(x)))`` along the last dimension of ``x``."""
        return self.fc2(F.relu(self.fc1(x)))
class GatingNetwork(nn.Module):
    """Gating network that scores every expert for each input row.

    Args:
        input_dim: Size of the last dimension of the input.
        num_experts: Number of experts to produce a weight for.
    """

    def __init__(self, input_dim: int, num_experts: int) -> None:
        super().__init__()
        self.fc = nn.Linear(input_dim, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return softmax-normalized expert weights of shape (..., num_experts)."""
        # Softmax over the expert axis so the weights of each row sum to 1.
        return F.softmax(self.fc(x), dim=-1)
class MoE(nn.Module):
    """Mixture-of-experts layer with top-k routing.

    Each input row is scored by the gating network, the ``top_k`` highest
    scoring experts are evaluated on it, and their outputs are summed using
    the (un-renormalized) gate probabilities as weights.

    Args:
        input_dim: Feature size of each input row.
        hidden_dim: Hidden width of every expert.
        output_dim: Feature size of each output row.
        num_experts: Number of experts to create.
        top_k: Number of experts applied to each input row.

    Raises:
        ValueError: If ``top_k`` is not in ``[1, num_experts]``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_experts=3, top_k=2):
        super().__init__()
        # Fail at construction time instead of deep inside torch.topk.
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts={num_experts}], got {top_k}"
            )
        self.num_experts = num_experts
        self.top_k = top_k
        self.output_dim = output_dim
        self.experts = nn.ModuleList(
            [Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
        )
        self.gating_network = GatingNetwork(input_dim, num_experts)

    def forward(self, x):
        """Route ``x`` of shape (batch, input_dim) and return (batch, output_dim)."""
        gate_output = self.gating_network(x)  # (batch, num_experts)
        topk_values, topk_indices = torch.topk(gate_output, self.top_k, dim=-1)

        # new_zeros inherits dtype and device from x, so float64/half inputs
        # are not silently accumulated in float32.
        output = x.new_zeros(x.size(0), self.output_dim)

        # Dispatch per expert rather than per sample: every expert runs once
        # on the batch of rows routed to it.
        for expert_id, expert in enumerate(self.experts):
            row_idx, slot_idx = torch.where(topk_indices == expert_id)
            if row_idx.numel() == 0:
                continue  # No row selected this expert.
            weights = topk_values[row_idx, slot_idx].unsqueeze(-1)
            weighted = weights * expert(x[row_idx])
            # topk indices are distinct within a row, so each row appears at
            # most once per expert and index_add_ acts as a scatter-add.
            output.index_add_(0, row_idx, weighted.to(output.dtype))
        return output
# Smoke test for the simple MoE model.
input_dim, hidden_dim, output_dim = 10, 20, 5
num_experts, top_k = 4, 2
moe_model = MoE(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
    num_experts=num_experts,
    top_k=top_k,
)
x = torch.rand(8, input_dim)  # batch of 8 random samples
output = moe_model(x)
print(output)
- Expert:定义多个专家(全连接网络)。
- GatingNetwork:计算输入在不同专家上的权重,并通过 softmax 归一化。
- MoE:通过 top-k 选择门控权重最大的若干专家(只计算被选中的专家,避免所有专家都参与计算,提高计算效率),并用门控权重对所选专家的输出做加权求和。
- 测试:使用随机输入数据测试 MoE 结构。
复杂MoE实现
- 支持 Sparse MoE(稀疏门控):使用 Top-K 选择部分专家,减少计算量。
- 支持 Transformer 结构:将 MoE 作为 Transformer 的 FFN 层,提高计算效率。
- 支持负载均衡:可配合门控辅助损失(Load Balance Loss)让专家被均衡使用,避免单个专家独占计算(注意:下面的示例代码并未实现该损失,需要在训练时自行添加)。
- 适用于 GPT、BERT、Switch Transformer 等大模型。
import torch
import torch.nn as nn
import torch.nn.functional as F
class Expert(nn.Module):
    """A single FFN-style expert: expand, apply ReLU, project back.

    Args:
        model_dim: Size of the last dimension of both input and output.
        expansion_factor: Hidden width as a multiple of ``model_dim``.
            Defaults to 4.
    """

    def __init__(self, model_dim: int, expansion_factor: int = 4) -> None:
        super().__init__()
        hidden_dim = expansion_factor * model_dim
        self.fc1 = nn.Linear(model_dim, hidden_dim)  # expand
        self.fc2 = nn.Linear(hidden_dim, model_dim)  # project back to model_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc2(relu(fc1(x)))`` along the last dimension of ``x``."""
        return self.fc2(F.relu(self.fc1(x)))
class GatingNetwork(nn.Module):
    """Gating network that produces a probability for every expert.

    Args:
        model_dim: Size of the last dimension of the input.
        num_experts: Number of experts to score.
    """

    def __init__(self, model_dim: int, num_experts: int) -> None:
        super().__init__()
        self.fc = nn.Linear(model_dim, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return softmax-normalized expert weights of shape (..., num_experts)."""
        gate_logits = self.fc(x)
        # Normalize over the expert axis so each row's weights sum to 1.
        return F.softmax(gate_logits, dim=-1)
class SparseMoE(nn.Module):
    """Sparse mixture-of-experts layer with top-k routing per token.

    The input is flattened to one row per token. Each token is scored by the
    gating network, the ``top_k`` highest scoring experts are evaluated on
    it, and their outputs are summed using the (un-renormalized) gate
    probabilities as weights.

    Args:
        model_dim: Feature size of every token (input and output).
        num_experts: Number of experts to create.
        top_k: Number of experts applied to each token.

    Raises:
        ValueError: If ``top_k`` is not in ``[1, num_experts]``.
    """

    def __init__(self, model_dim, num_experts, top_k=2):
        super().__init__()
        # Fail at construction time instead of deep inside torch.topk.
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts={num_experts}], got {top_k}"
            )
        self.num_experts = num_experts
        self.top_k = top_k
        self.experts = nn.ModuleList([Expert(model_dim) for _ in range(num_experts)])
        self.gating_network = GatingNetwork(model_dim, num_experts)

    def forward(self, x):
        """Route ``x`` of shape (batch, seq_len, model_dim); output has the same shape."""
        batch_size, seq_len, model_dim = x.shape
        # reshape (unlike view) also accepts non-contiguous input, e.g. the
        # result of a transpose.
        tokens = x.reshape(-1, model_dim)  # (batch * seq_len, model_dim)

        gate_probs = self.gating_network(tokens)  # (batch * seq_len, num_experts)
        topk_values, topk_indices = torch.topk(gate_probs, self.top_k, dim=-1)

        output = torch.zeros_like(tokens)
        # Dispatch per expert rather than per token: every expert runs once
        # on the batch of tokens routed to it.
        for expert_id, expert in enumerate(self.experts):
            token_idx, slot_idx = torch.where(topk_indices == expert_id)
            if token_idx.numel() == 0:
                continue  # No token selected this expert.
            weights = topk_values[token_idx, slot_idx].unsqueeze(-1)
            weighted = weights * expert(tokens[token_idx])
            # topk indices are distinct within a row, so each token appears
            # at most once per expert and index_add_ acts as a scatter-add.
            output.index_add_(0, token_idx, weighted.to(output.dtype))

        # Restore the (batch, seq_len, model_dim) layout.
        return output.reshape(batch_size, seq_len, model_dim)
class TransformerMoELayer(nn.Module):
    """Pre-norm residual block whose feed-forward part is a sparse MoE.

    Computes ``x + moe(norm(x))``. This block contains no self-attention;
    it only covers the position a feed-forward sublayer would occupy.

    Args:
        model_dim: Feature size of every token.
        num_experts: Number of experts in the MoE sublayer.
        top_k: Number of experts applied to each token.
    """

    def __init__(self, model_dim, num_experts, top_k=2):
        super().__init__()
        self.norm = nn.LayerNorm(model_dim)
        self.moe = SparseMoE(model_dim, num_experts, top_k)

    def forward(self, x):
        """Return ``x`` plus the MoE output computed on the layer-normalized ``x``."""
        normalized = self.norm(x)
        return x + self.moe(normalized)  # residual connection around the MoE
class TransformerWithMoE(nn.Module):
    """Stack of ``TransformerMoELayer`` blocks followed by a final LayerNorm.

    The model applies ``num_layers`` MoE blocks in sequence and then
    normalizes the result. It defines no embedding or attention modules
    itself.

    Args:
        model_dim: Feature size of every token.
        num_experts: Number of experts in each MoE block.
        num_layers: Number of stacked MoE blocks.
        top_k: Number of experts applied to each token.
    """

    def __init__(self, model_dim, num_experts, num_layers, top_k=2):
        super().__init__()
        self.layers = nn.ModuleList(
            [TransformerMoELayer(model_dim, num_experts, top_k) for _ in range(num_layers)]
        )
        self.norm = nn.LayerNorm(model_dim)

    def forward(self, x):
        """Run ``x`` through every layer in order, then apply the final LayerNorm."""
        hidden = x
        for layer in self.layers:
            hidden = layer(hidden)
        return self.norm(hidden)
# Smoke test for the MoE Transformer stack.
batch_size, seq_len, model_dim = 4, 10, 512
num_experts, num_layers, top_k = 8, 6, 2
moe_transformer = TransformerWithMoE(
    model_dim=model_dim,
    num_experts=num_experts,
    num_layers=num_layers,
    top_k=top_k,
)
x = torch.randn(batch_size, seq_len, model_dim)
output = moe_transformer(x)
print(output.shape)  # expected: torch.Size([4, 10, 512])
- 专家(Expert):使用两层 Linear + ReLU 进行维度扩展和降维(类似于 Transformer FFN)。例如,如果 model_dim=512,专家会将其映射到 4*512=2048,然后再压缩回 512。
- 门控网络(Gating Network):计算不同专家的权重,并使用 Softmax 归一化。
- 稀疏 MoE(Sparse MoE):采用 Top-K 选择部分专家(避免计算所有专家,提高效率)。仅对选出的 Top-K 专家计算输出,并加权求和。
- Transformer 结构:参照 Transformer 中 FFN 子层的位置,用 MoE 替换原来的 FFN 层,并添加 LayerNorm 和残差连接(注意:示例为了简化省略了自注意力子层,每层只包含 LayerNorm + MoE + 残差)。
- 完整的 Transformer with MoE:由多个 TransformerMoELayer 组成。MoE 处理序列数据,并使用 LayerNorm 归一化。
1993

被折叠的 条评论
为什么被折叠?



