MOE介绍

moe

1. 普通 MoE 模型

"""
此模块实现了普通的 Mixture of Experts (MoE) 模型。
包含门控网络和多个专家网络,输入数据会被所有专家处理。
"""
import torch
import torch.nn as nn

class Expert(nn.Module):
    """
    专家网络模块,简单的全连接神经网络。
    """
    def __init__(self, input_size, hidden_size, output_size):
        """
        初始化专家网络。

        参数:
        input_size (int): 输入特征的维度。
        hidden_size (int): 隐藏层的维度。
        output_size (int): 输出特征的维度。
        """
        super(Expert, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """
        前向传播方法。

        参数:
        x (torch.Tensor): 输入数据。

        返回:
        torch.Tensor: 专家网络的输出。
        """
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

class GatingNetwork(nn.Module):
    """
    门控网络模块,用于为每个专家分配权重。
    """
    def __init__(self, input_size, num_experts):
        """
        初始化门控网络。

        参数:
        input_size (int): 输入特征的维度。
        num_experts (int): 专家网络的数量。
        """
        super(GatingNetwork, self).__init__()
        self.fc = nn.Linear(input_size, num_experts)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """
        前向传播方法。

        参数:
        x (torch.Tensor): 输入数据。

        返回:
        torch.Tensor: 每个专家的权重。
        """
        x = self.fc(x)
        x = self.softmax(x)
        return x

class MoE(nn.Module):
    """
    普通 MoE 模型模块,包含多个专家网络和一个门控网络。
    """
    def __init__(self, input_size, hidden_size, output_size, num_experts):
        """
        初始化普通 MoE 模型。

        参数:
        input_size (int): 输入特征的维度。
        hidden_size (int): 专家网络隐藏层的维度。
        output_size (int): 输出特征的维度。
        num_experts (int): 专家网络的数量。
        """
        super(MoE, self).__init__()
        self.experts = nn.ModuleList([Expert(input_size, hidden_size, output_size) for _ in range(num_experts)])
        self.gating_network = GatingNetwork(input_size, num_experts)

    def forward(self, x):
        """
        前向传播方法。

        参数:
        x (torch.Tensor): 输入数据。

        返回:
        torch.Tensor: MoE 模型的最终输出。
        """
        # 计算门控权重
        gates = self.gating_network(x)
        expert_outputs = []
        for expert in self.experts:
            expert_outputs.append(expert(x))
        expert_outputs = torch.stack(expert_outputs, dim=1)
        # 加权求和
        output = torch.sum(gates.unsqueeze(-1) * expert_outputs, dim=1)
        return output

# 使用示例
input_size = 10
hidden_size = 20
output_size = 5
num_experts = 3
model = MoE(input_size, hidden_size, output_size, num_experts)
input_data = torch.randn(1, input_size)
output = model(input_data)
print(output)

2. 普通 MoE(Mixture of Experts)和 Sparse MoE 的区别

普通 MoE
普通 MoE (混合专家模型)由多个专家网络(Expert Networks)和一个门控网络(Gating Network)组成。输入数据会同时传递给所有专家网络,然后门控网络为每个专家网络分配一个权重,最后将所有专家网络的输出按照权重进行加权求和得到最终输出。由于所有专家网络都会处理输入数据,计算开销较大。

Sparse MoE
稀疏 MoE 是普通 MoE 的改进版本。在稀疏 MoE 中,输入数据不会同时传递给所有专家网络,而是根据门控网络的输出,只选择一部分(通常是 top-k 个)专家网络进行处理,从而减少计算开销。这种稀疏性使得模型在保持高性能的同时,能够显著降低计算成本。

class SparseMoE(nn.Module):
    """
    稀疏 MoE 模型模块,包含多个专家网络和一个门控网络。
    仅选择部分专家处理输入数据,对 top-k 专家权重重新归一化。
    """
    def __init__(self, input_size, hidden_size, output_size, num_experts, top_k=2):
        """
        初始化稀疏 MoE 模型。

        参数:
        input_size (int): 输入特征的维度。
        hidden_size (int): 专家网络隐藏层的维度。
        output_size (int): 输出特征的维度。
        num_experts (int): 专家网络的数量。
        top_k (int): 选择的专家数量。
        """
        super(SparseMoE, self).__init__()
        self.experts = nn.ModuleList([Expert(input_size, hidden_size, output_size) for _ in range(num_experts)])
        self.gating_network = GatingNetwork(input_size, num_experts)
        self.top_k = top_k

    def forward(self, x):
        """
        前向传播方法。

        参数:
        x (torch.Tensor): 输入数据。

        返回:
        torch.Tensor: 稀疏 MoE 模型的最终输出。
        """
        # 计算门控权重
        gates = self.gating_network(x)
        # 选择 top-k 个专家
        top_k_values, top_k_indices = torch.topk(gates, self.top_k, dim=1)
        # 对 top-k 权重进行重新归一化
        top_k_values = top_k_values / top_k_values.sum(dim=1, keepdim=True)

        output = torch.zeros(x.size(0), self.experts[0].fc2.out_features).to(x.device)
        for i in range(x.size(0)):
            for j in range(self.top_k):
                expert_index = top_k_indices[i][j]
                expert_weight = top_k_values[i][j]
                expert_output = self.experts[expert_index](x[i].unsqueeze(0))
                output[i] += expert_weight * expert_output.squeeze(0)
        return output

# 使用示例
input_size = 10
hidden_size = 20
output_size = 5
num_experts = 3
top_k = 2
model = SparseMoE(input_size, hidden_size, output_size, num_experts, top_k)
input_data = torch.randn(1, input_size)
output = model(input_data)
print(output)

以上的输入都是 batch_size, input_size的shape
实际输入可能是 batch_size, seq_len, input_size的shape,需要对每个seq_len进行处理,所以需要对输入数据进行reshape,将seq_len维度展开,然后进行处理,最后再将输出reshape回来。
比如

x = x.view(-1, x.size(-1))  # 展平成二维 (batch_size * seq_len, input_size)

然后就可以使用moe和sparse_moe了,最后再将输出reshape回来。

out = out.view(batch_size, seq_len, -1)  # 恢复成三维 (batch_size, seq_len, output_size)

3. sparse_moe_vectorized的实现

当前这种使用显式嵌套循环的实现方式在效率上存在一定的局限性,下面从计算复杂度、内存开销和实际性能等方面进行分析,并给出优化建议。

  1. 计算复杂度
    在 forward 方法里,有两层嵌套循环:外层循环遍历输入样本(for i in range(x.size(0))),内层循环遍历 top-k 个专家(for j in range(self.top_k))。假设输入样本数量为 N,top-k 为 K,那么时间复杂度为 O ( N ∗ K ) O(N * K) O(NK)。每次循环内部都要调用专家网络进行前向传播,这会带来额外的计算开销。当 N 和 K 数值较大时,计算量会显著增加,导致运行时间变长。

  2. 内存开销
    主要的内存开销在于存储门控权重 gates、top_k_values、top_k_indices 以及中间的专家输出 expert_output。随着输入样本数量和专家数量的增多,这些中间变量占用的内存会相应增加。此外,Python 的循环在处理大规模数据时,由于解释器的开销,也会额外消耗一定的内存。

  3. 实际性能
    小批量数据:对于小规模的输入数据(即 N 和 K 较小时),这种实现方式的性能损失并不明显,代码的可读性和简洁性优势得以体现。
    大批量数据:当处理大规模数据时,Python 循环的效率远低于 PyTorch 的向量化操作。因为 PyTorch 的底层是用 C++ 实现的,向量化操作可以充分利用 GPU 的并行计算能力,而 Python 循环则无法有效利用这一点,从而导致性能大幅下降。

向量化后的计算代码如下,参考minimind项目代码。
num_shared_experts = 0的时候没有共享expert, 共享expert是对所有token处理


class SparseMoE2(nn.Module):
    """
    稀疏 MoE 模型模块,包含多个专家网络和一个门控网络。
    仅选择部分专家处理输入数据,对 top-k 专家权重重新归一化。
    """
    def __init__(self, input_size, hidden_size, output_size, num_experts, top_k=2, num_shared_experts=0):
        """
        初始化稀疏 MoE 模型。

        参数:
        input_size (int): 输入特征的维度。
        hidden_size (int): 专家网络隐藏层的维度。
        output_size (int): 输出特征的维度。
        num_experts (int): 专家网络的数量。
        top_k (int): 选择的专家数量。
        """
        super(SparseMoE2, self).__init__()
        self.experts = nn.ModuleList([Expert(input_size, hidden_size, output_size) for _ in range(num_experts)])
        self.gating_network = GatingNetwork(input_size, num_experts)
        self.top_k = top_k

        self.num_shared_experts = num_shared_experts
        if num_shared_experts > 0:
            self.shared_experts = nn.ModuleList([Expert(input_size, hidden_size, output_size) for _ in range(num_shared_experts)])  
        else:
            self.shared_experts = None

    def forward(self, x):
        """
        前向传播方法。

        参数:
        x (torch.Tensor): 输入数据。(token_number, h)

        返回:
        torch.Tensor: 稀疏 MoE 模型的最终输出。
        """
        # 计算门控权重
        gates = self.gating_network(x)
        # 选择 top-k 个专家
        top_k_values, top_k_indices = torch.topk(gates, self.top_k, dim=1) # 都是(token_number, topk)
        # 对 top-k 权重进行重新归一化
        top_k_values = top_k_values / top_k_values.sum(dim=1, keepdim=True)

        flat_topk_idx = top_k_indices.view(-1) # (token_number * topk)
        if self.training:
            x = x.repeat_interleave(self.top_k, dim=0) # (token_number * topk , h)
            # 初始化 y 为正确的形状 (token_number * topk , output_size)
            y = torch.empty(x.size(0), self.experts[0].fc2.out_features, dtype=torch.float16, device=x.device)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # 确保类型一致
            y = y.view(-1, self.top_k, self.experts[0].fc2.out_features)
            output = (y * top_k_values.unsqueeze(-1)).sum(dim=1) # (token_number, topk, h) -> (token_number, h)
        else:
            output = self.moe_infer(x) # 训练时使用的推理方法

        # 共享专家的处理
        if self.shared_experts is not None:
            for expert in self.shared_experts:
                output = output + expert(x)
        return output
    
    @torch.no_grad()
    def moe_infer(self, x):
        """
        前向传播方法。

        参数:
        x (torch.Tensor): 输入数据。(token_number, h)

        返回:
        torch.Tensor: 稀疏 MoE 模型的最终输出。
        """
        # 计算门控权重
        gates = self.gating_network(x)
        # 选择 top-k 个专家
        top_k_values, top_k_indices = torch.topk(gates, self.top_k, dim=1) # 都是(token_number, topk)

        # 对 top-k 权重进行重新归一化
        top_k_values = top_k_values / top_k_values.sum(dim=1, keepdim=True)

        top_k_values = top_k_values.view(-1)                                        
        top_k_indices = top_k_indices.view(-1)                                 # (token_number * topk) 记录的是专家的序号
        idxs = top_k_indices.argsort()                                         # (token_number * topk) idxs 专家0-topk 对应的位置
        tokens_per_expert = top_k_indices.bincount().cpu().numpy().cumsum(0)   # (num_experts) 记录每个专家对应的 token 的数量
        token_idxs = idxs // self.top_k                                        # (token_number * topk) idxs (专家0-topk 对应的位置) -> token_idxs(专家0-topk 对应的token), 因为每个token有 topk个专家
        

        expert_cache = torch.zeros(x.size(0), self.experts[0].fc2.out_features, dtype=torch.float32, device=x.device)
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx] # 找到每个专家对应的token, 是一个list存储的是专家对应的token的idx
            expert_tokens = x[exp_token_idx]              # 根据token的idx找到对应的token
            expert_out = expert(expert_tokens)            # 计算结果

            # 获取当前专家对应的权重,并扩展维度以匹配 expert_out 的形状
            expert_weights = top_k_values[idxs[start_idx:end_idx]].unsqueeze(-1)
            expert_out.mul_(expert_weights) # 就地乘法
            expert_cache.index_add_(0, exp_token_idx, expert_out)
            #expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, expert_out.shape[-1]), expert_out)
        return expert_cache
if __name__ == "__main__":
    torch.manual_seed(42)
    np.random.seed(42)

    input_size = 10
    hidden_size = 20
    output_size = 5
    num_experts = 3
    top_k = 2

    # 实例化第一个模型
    torch.manual_seed(42) # 确保 model 的权重是固定的
    model = SparseMoE(input_size, hidden_size, output_size, num_experts, top_k)

    # 实例化第二个模型
    torch.manual_seed(42) # 再次设置种子,以使得 model2 内部模块的初始化顺序与 model 一致
    model2 = SparseMoE2(input_size, hidden_size, output_size, num_experts, top_k)
    # 将 model 的权重加载到 model2
    # 注意:这要求 model 和 model2 的网络结构完全一致,包括参数名
    model2.load_state_dict(model.state_dict())

    # 确保 model2 处于训练模式以执行我们分析的路径
    model.train() # 虽然 SparseMoE 没有显式的 training 属性影响其 forward,但保持一致
    model2.train() # 这会使得 SparseMoE2.training 为 True

    input_data = torch.randn(4, input_size) # 使用 batch_size > 1 测试更充分
    
    output = model(input_data)
    output2 = model2(input_data)
    model2.eval() # 将 model2 设置为评估模式以执行推理
    output2 = model2(input_data) # 重新计算输出以确保它是正确的

    print("Output from SparseMoE:")
    print(output)
    print("\nOutput from SparseMoE2 (with synced weights and training mode):")
    print(output2)

    # 检查是否几乎相等 (由于浮点数精度,可能不是完全相等)
    if torch.allclose(output, output2, atol=1e-3):
        print("\nOutputs are close enough (numerically equivalent).")
    else:
        print("\nOutputs are different.")
        print("Difference:", torch.abs(output - output2).max())

写两个不同的实现主要是为了 性能优化和适应不同阶段的需求 :

  1. 训练阶段的需求 (由 forward 的训练分支处理):

    • 正确性优先: 训练的核心是梯度下降。MoE 模型中的梯度需要通过门控网络和被选中的专家流回。 repeat_interleave 和后续的计算方式确保了计算图能够正确地跟踪每个 token 到其所有选中专家的路径,从而计算出正确的梯度。这种方式可能涉及更多的数据复制和中间张量,但保证了训练的数学正确性。
    • 辅助损失: 用于平衡每个expert的使用频率。这部分代码本blog暂时没有给出,有机会再写
  2. 推理阶段的需求 (由 moe_infer 处理):

    • 效率优先: 推理阶段不需要计算梯度,主要目标是最大化前向传播的速度和最小化内存占用。
    • moe_infer 中的排序、切片和 scatter_add_(index_add_) 是一种针对 MoE 推理优化的常见模式。它避免了 repeat_interleave 带来的数据冗余,并且 scatter_add_ 可以高效地将稀疏的专家输出累加到最终结果中,这通常比训练分支中的重塑和加权求和更高效,尤其是在处理大量 token 和专家时。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值