moe
1. 普通 MoE 模型
"""
此模块实现了普通的 Mixture of Experts (MoE) 模型。
包含门控网络和多个专家网络,输入数据会被所有专家处理。
"""
import torch
import torch.nn as nn
class Expert(nn.Module):
"""
专家网络模块,简单的全连接神经网络。
"""
def __init__(self, input_size, hidden_size, output_size):
"""
初始化专家网络。
参数:
input_size (int): 输入特征的维度。
hidden_size (int): 隐藏层的维度。
output_size (int): 输出特征的维度。
"""
super(Expert, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, output_size)
def forward(self, x):
"""
前向传播方法。
参数:
x (torch.Tensor): 输入数据。
返回:
torch.Tensor: 专家网络的输出。
"""
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
class GatingNetwork(nn.Module):
"""
门控网络模块,用于为每个专家分配权重。
"""
def __init__(self, input_size, num_experts):
"""
初始化门控网络。
参数:
input_size (int): 输入特征的维度。
num_experts (int): 专家网络的数量。
"""
super(GatingNetwork, self).__init__()
self.fc = nn.Linear(input_size, num_experts)
self.softmax = nn.Softmax(dim=1)
def forward(self, x):
"""
前向传播方法。
参数:
x (torch.Tensor): 输入数据。
返回:
torch.Tensor: 每个专家的权重。
"""
x = self.fc(x)
x = self.softmax(x)
return x
class MoE(nn.Module):
"""
普通 MoE 模型模块,包含多个专家网络和一个门控网络。
"""
def __init__(self, input_size, hidden_size, output_size, num_experts):
"""
初始化普通 MoE 模型。
参数:
input_size (int): 输入特征的维度。
hidden_size (int): 专家网络隐藏层的维度。
output_size (int): 输出特征的维度。
num_experts (int): 专家网络的数量。
"""
super(MoE, self).__init__()
self.experts = nn.ModuleList([Expert(input_size, hidden_size, output_size) for _ in range(num_experts)])
self.gating_network = GatingNetwork(input_size, num_experts)
def forward(self, x):
"""
前向传播方法。
参数:
x (torch.Tensor): 输入数据。
返回:
torch.Tensor: MoE 模型的最终输出。
"""
# 计算门控权重
gates = self.gating_network(x)
expert_outputs = []
for expert in self.experts:
expert_outputs.append(expert(x))
expert_outputs = torch.stack(expert_outputs, dim=1)
# 加权求和
output = torch.sum(gates.unsqueeze(-1) * expert_outputs, dim=1)
return output
# 使用示例
input_size = 10
hidden_size = 20
output_size = 5
num_experts = 3
model = MoE(input_size, hidden_size, output_size, num_experts)
input_data = torch.randn(1, input_size)
output = model(input_data)
print(output)
2. 普通 MoE(Mixture of Experts)和 Sparse MoE 的区别
普通 MoE
普通 MoE (混合专家模型)由多个专家网络(Expert Networks)和一个门控网络(Gating Network)组成。输入数据会同时传递给所有专家网络,然后门控网络为每个专家网络分配一个权重,最后将所有专家网络的输出按照权重进行加权求和得到最终输出。由于所有专家网络都会处理输入数据,计算开销较大。
Sparse MoE
稀疏 MoE 是普通 MoE 的改进版本。在稀疏 MoE 中,输入数据不会同时传递给所有专家网络,而是根据门控网络的输出,只选择一部分(通常是 top-k 个)专家网络进行处理,从而减少计算开销。这种稀疏性使得模型在保持高性能的同时,能够显著降低计算成本。
class SparseMoE(nn.Module):
"""
稀疏 MoE 模型模块,包含多个专家网络和一个门控网络。
仅选择部分专家处理输入数据,对 top-k 专家权重重新归一化。
"""
def __init__(self, input_size, hidden_size, output_size, num_experts, top_k=2):
"""
初始化稀疏 MoE 模型。
参数:
input_size (int): 输入特征的维度。
hidden_size (int): 专家网络隐藏层的维度。
output_size (int): 输出特征的维度。
num_experts (int): 专家网络的数量。
top_k (int): 选择的专家数量。
"""
super(SparseMoE, self).__init__()
self.experts = nn.ModuleList([Expert(input_size, hidden_size, output_size) for _ in range(num_experts)])
self.gating_network = GatingNetwork(input_size, num_experts)
self.top_k = top_k
def forward(self, x):
"""
前向传播方法。
参数:
x (torch.Tensor): 输入数据。
返回:
torch.Tensor: 稀疏 MoE 模型的最终输出。
"""
# 计算门控权重
gates = self.gating_network(x)
# 选择 top-k 个专家
top_k_values, top_k_indices = torch.topk(gates, self.top_k, dim=1)
# 对 top-k 权重进行重新归一化
top_k_values = top_k_values / top_k_values.sum(dim=1, keepdim=True)
output = torch.zeros(x.size(0), self.experts[0].fc2.out_features).to(x.device)
for i in range(x.size(0)):
for j in range(self.top_k):
expert_index = top_k_indices[i][j]
expert_weight = top_k_values[i][j]
expert_output = self.experts[expert_index](x[i].unsqueeze(0))
output[i] += expert_weight * expert_output.squeeze(0)
return output
# 使用示例
input_size = 10
hidden_size = 20
output_size = 5
num_experts = 3
top_k = 2
model = SparseMoE(input_size, hidden_size, output_size, num_experts, top_k)
input_data = torch.randn(1, input_size)
output = model(input_data)
print(output)
以上的输入都是 batch_size, input_size的shape
实际输入可能是 batch_size, seq_len, input_size的shape,需要对每个seq_len进行处理,所以需要对输入数据进行reshape,将seq_len维度展开,然后进行处理,最后再将输出reshape回来。
比如
x = x.view(-1, x.size(-1)) # 展平成二维 (batch_size * seq_len, input_size)
然后就可以使用moe和sparse_moe了,最后再将输出reshape回来。
out = out.view(batch_size, seq_len, -1) # 恢复成三维 (batch_size, seq_len, output_size)
3. sparse_moe_vectorized的实现
当前这种使用显式嵌套循环的实现方式在效率上存在一定的局限性,下面从计算复杂度、内存开销和实际性能等方面进行分析,并给出优化建议。
-
计算复杂度
在 forward 方法里,有两层嵌套循环:外层循环遍历输入样本(for i in range(x.size(0))),内层循环遍历 top-k 个专家(for j in range(self.top_k))。假设输入样本数量为 N,top-k 为 K,那么时间复杂度为 O ( N ∗ K ) O(N * K) O(N∗K)。每次循环内部都要调用专家网络进行前向传播,这会带来额外的计算开销。当 N 和 K 数值较大时,计算量会显著增加,导致运行时间变长。 -
内存开销
主要的内存开销在于存储门控权重 gates、top_k_values、top_k_indices 以及中间的专家输出 expert_output。随着输入样本数量和专家数量的增多,这些中间变量占用的内存会相应增加。此外,Python 的循环在处理大规模数据时,由于解释器的开销,也会额外消耗一定的内存。 -
实际性能
小批量数据:对于小规模的输入数据(即 N 和 K 较小时),这种实现方式的性能损失并不明显,代码的可读性和简洁性优势得以体现。
大批量数据:当处理大规模数据时,Python 循环的效率远低于 PyTorch 的向量化操作。因为 PyTorch 的底层是用 C++ 实现的,向量化操作可以充分利用 GPU 的并行计算能力,而 Python 循环则无法有效利用这一点,从而导致性能大幅下降。
向量化后的计算代码如下,参考minimind项目代码。
num_shared_experts = 0的时候没有共享expert, 共享expert是对所有token处理
class SparseMoE2(nn.Module):
"""
稀疏 MoE 模型模块,包含多个专家网络和一个门控网络。
仅选择部分专家处理输入数据,对 top-k 专家权重重新归一化。
"""
def __init__(self, input_size, hidden_size, output_size, num_experts, top_k=2, num_shared_experts=0):
"""
初始化稀疏 MoE 模型。
参数:
input_size (int): 输入特征的维度。
hidden_size (int): 专家网络隐藏层的维度。
output_size (int): 输出特征的维度。
num_experts (int): 专家网络的数量。
top_k (int): 选择的专家数量。
"""
super(SparseMoE2, self).__init__()
self.experts = nn.ModuleList([Expert(input_size, hidden_size, output_size) for _ in range(num_experts)])
self.gating_network = GatingNetwork(input_size, num_experts)
self.top_k = top_k
self.num_shared_experts = num_shared_experts
if num_shared_experts > 0:
self.shared_experts = nn.ModuleList([Expert(input_size, hidden_size, output_size) for _ in range(num_shared_experts)])
else:
self.shared_experts = None
def forward(self, x):
"""
前向传播方法。
参数:
x (torch.Tensor): 输入数据。(token_number, h)
返回:
torch.Tensor: 稀疏 MoE 模型的最终输出。
"""
# 计算门控权重
gates = self.gating_network(x)
# 选择 top-k 个专家
top_k_values, top_k_indices = torch.topk(gates, self.top_k, dim=1) # 都是(token_number, topk)
# 对 top-k 权重进行重新归一化
top_k_values = top_k_values / top_k_values.sum(dim=1, keepdim=True)
flat_topk_idx = top_k_indices.view(-1) # (token_number * topk)
if self.training:
x = x.repeat_interleave(self.top_k, dim=0) # (token_number * topk , h)
# 初始化 y 为正确的形状 (token_number * topk , output_size)
y = torch.empty(x.size(0), self.experts[0].fc2.out_features, dtype=torch.float16, device=x.device)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype) # 确保类型一致
y = y.view(-1, self.top_k, self.experts[0].fc2.out_features)
output = (y * top_k_values.unsqueeze(-1)).sum(dim=1) # (token_number, topk, h) -> (token_number, h)
else:
output = self.moe_infer(x) # 训练时使用的推理方法
# 共享专家的处理
if self.shared_experts is not None:
for expert in self.shared_experts:
output = output + expert(x)
return output
@torch.no_grad()
def moe_infer(self, x):
"""
前向传播方法。
参数:
x (torch.Tensor): 输入数据。(token_number, h)
返回:
torch.Tensor: 稀疏 MoE 模型的最终输出。
"""
# 计算门控权重
gates = self.gating_network(x)
# 选择 top-k 个专家
top_k_values, top_k_indices = torch.topk(gates, self.top_k, dim=1) # 都是(token_number, topk)
# 对 top-k 权重进行重新归一化
top_k_values = top_k_values / top_k_values.sum(dim=1, keepdim=True)
top_k_values = top_k_values.view(-1)
top_k_indices = top_k_indices.view(-1) # (token_number * topk) 记录的是专家的序号
idxs = top_k_indices.argsort() # (token_number * topk) idxs 专家0-topk 对应的位置
tokens_per_expert = top_k_indices.bincount().cpu().numpy().cumsum(0) # (num_experts) 记录每个专家对应的 token 的数量
token_idxs = idxs // self.top_k # (token_number * topk) idxs (专家0-topk 对应的位置) -> token_idxs(专家0-topk 对应的token), 因为每个token有 topk个专家
expert_cache = torch.zeros(x.size(0), self.experts[0].fc2.out_features, dtype=torch.float32, device=x.device)
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx] # 找到每个专家对应的token, 是一个list存储的是专家对应的token的idx
expert_tokens = x[exp_token_idx] # 根据token的idx找到对应的token
expert_out = expert(expert_tokens) # 计算结果
# 获取当前专家对应的权重,并扩展维度以匹配 expert_out 的形状
expert_weights = top_k_values[idxs[start_idx:end_idx]].unsqueeze(-1)
expert_out.mul_(expert_weights) # 就地乘法
expert_cache.index_add_(0, exp_token_idx, expert_out)
#expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, expert_out.shape[-1]), expert_out)
return expert_cache
if __name__ == "__main__":
torch.manual_seed(42)
np.random.seed(42)
input_size = 10
hidden_size = 20
output_size = 5
num_experts = 3
top_k = 2
# 实例化第一个模型
torch.manual_seed(42) # 确保 model 的权重是固定的
model = SparseMoE(input_size, hidden_size, output_size, num_experts, top_k)
# 实例化第二个模型
torch.manual_seed(42) # 再次设置种子,以使得 model2 内部模块的初始化顺序与 model 一致
model2 = SparseMoE2(input_size, hidden_size, output_size, num_experts, top_k)
# 将 model 的权重加载到 model2
# 注意:这要求 model 和 model2 的网络结构完全一致,包括参数名
model2.load_state_dict(model.state_dict())
# 确保 model2 处于训练模式以执行我们分析的路径
model.train() # 虽然 SparseMoE 没有显式的 training 属性影响其 forward,但保持一致
model2.train() # 这会使得 SparseMoE2.training 为 True
input_data = torch.randn(4, input_size) # 使用 batch_size > 1 测试更充分
output = model(input_data)
output2 = model2(input_data)
model2.eval() # 将 model2 设置为评估模式以执行推理
output2 = model2(input_data) # 重新计算输出以确保它是正确的
print("Output from SparseMoE:")
print(output)
print("\nOutput from SparseMoE2 (with synced weights and training mode):")
print(output2)
# 检查是否几乎相等 (由于浮点数精度,可能不是完全相等)
if torch.allclose(output, output2, atol=1e-3):
print("\nOutputs are close enough (numerically equivalent).")
else:
print("\nOutputs are different.")
print("Difference:", torch.abs(output - output2).max())
写两个不同的实现主要是为了 性能优化和适应不同阶段的需求 :
-
训练阶段的需求 (由 forward 的训练分支处理):
- 正确性优先: 训练的核心是梯度下降。MoE 模型中的梯度需要通过门控网络和被选中的专家流回。 repeat_interleave 和后续的计算方式确保了计算图能够正确地跟踪每个 token 到其所有选中专家的路径,从而计算出正确的梯度。这种方式可能涉及更多的数据复制和中间张量,但保证了训练的数学正确性。
- 辅助损失: 用于平衡每个expert的使用频率。这部分代码本blog暂时没有给出,有机会再写
-
推理阶段的需求 (由 moe_infer 处理):
- 效率优先: 推理阶段不需要计算梯度,主要目标是最大化前向传播的速度和最小化内存占用。
- moe_infer 中的排序、切片和 scatter_add_(index_add_) 是一种针对 MoE 推理优化的常见模式。它避免了 repeat_interleave 带来的数据冗余,并且 scatter_add_ 可以高效地将稀疏的专家输出累加到最终结果中,这通常比训练分支中的重塑和加权求和更高效,尤其是在处理大量 token 和专家时。


3万+

被折叠的 条评论
为什么被折叠?



