Community Contribution | Mixtral 8x7B PyTorch Implementation
0. Preface
This article walks through the PyTorch implementation of the Mixtral 8x7B Mixture-of-Experts model from a code perspective.
1. Paper Overview

Mixtral-8x7B has ignited the MoE technical direction, and more and more tricks for optimizing MoE are appearing. Going back to the model itself, this article looks at:
- Mixtral 8x7B adopts the sMoE model structure. What are the model details? How is the routing load balance computed? How is it implemented in code?
- What do the training and inference pipelines of Mixtral 8x7B look like, and how can training and inference efficiency be improved?
- How is the parameter count of Mixtral 8x7B calculated?
- Mixtral 8x7B goes head-to-head with LLaMA2-70B and GPT-3.5 and sits at the top tier of performance; on MBPP its coding ability surpasses GPT-3.5.

2. Mixtral 8x7B Model Architecture and Computation Flow
Mixtral is based on a transformer architecture [31] and uses the same modifications as described in [18], with the notable exceptions that Mixtral supports a fully dense context length of 32k tokens, and the feed forward blocks are replaced by Mixture-of-Expert layers (Section 2.1). The model architecture parameters are summarized in Table 1.
- The base model structure is Mistral-7B, a modified Transformer.
- MoE is applied to the Feed Forward Blocks.

2.1 Mixtral Model Architecture
In a Transformer model, the MoE layer is applied independently per token and replaces the feed-forward (FFN) sub-block of the transformer block. For Mixtral we use the same SwiGLU architecture as the expert function Ei(x) and set K = 2. This means each token is routed to two SwiGLU sub-blocks with different sets of weights. Taking this all together, the output y for an input token x is computed as:
- In LLaMA2 or Mistral-7B, the MLP takes the SwiGLU form.
- In Mixtral-8x7B, the MLP of every Decoder layer is replaced by an sMoE block.

Transformers Mixtral-of-Experts code implementation:
In the Hugging Face Transformers framework, Mixtral consists of two main parts:
- MixtralDecoderLayer
- MixtralSparseMoeBlock: replaces the original MLP layer
```
MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(32000, 128)
    (layers): ModuleList(
      (1): MixtralDecoderLayer(
        (self_attn): MixtralAttention(
          (q_proj): Linear(in_features=128, out_features=128, bias=False)
          (k_proj): Linear(in_features=128, out_features=128, bias=False)
          (v_proj): Linear(in_features=128, out_features=128, bias=False)
          (o_proj): Linear(in_features=128, out_features=128, bias=False)
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear(in_features=128, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBLockSparseTop2MLP(
              (w1): Linear(in_features=128, out_features=256, bias=False)
              (w2): Linear(in_features=256, out_features=128, bias=False)
              (w3): Linear(in_features=128, out_features=256, bias=False)
              (act_fn): SiLU()
            )
          )
        )
        (input_layernorm): MixtralRMSNorm()
        (post_attention_layernorm): MixtralRMSNorm()
      )
    )
    (norm): MixtralRMSNorm()
  )
)
```
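The toy shapes in this printout (hidden size 128, FFN size 256) and the snippets below rely on a `config` object that the post does not show. A minimal sketch of such a demo config, assuming standard `MixtralConfig` fields and values inferred from the printed shapes:

```python
from transformers import MixtralConfig

# Small demo config matching the printed shapes above
# (assumed values for illustration, not the real 8x7B sizes)
config = MixtralConfig(
    vocab_size=32000,
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    num_local_experts=8,
    num_experts_per_tok=2,
)
```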
2.2 sMoE Layer Implementation
2.2.1 Single Expert Implementation
```python
import torch
from torch import nn
from transformers import MixtralConfig

class MixtralBLockSparseTop2MLP(nn.Module):
    def __init__(self, config: MixtralConfig):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size

        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        self.act_fn = nn.SiLU()

    # The forward pass is SwiGLU
    def forward(self, hidden_states):
        y = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        y = self.w2(y)
        return y

x = torch.randn(1, 64, 128)
expert = MixtralBLockSparseTop2MLP(config)
print('A single expert is the original LLaMA MLP layer')
print(expert)
g = expert(x)
print('single expert input:', x.shape)
print('single expert output:', g.shape)
```
Result:
```
A single expert is the original LLaMA MLP layer
MixtralBLockSparseTop2MLP(
  (w1): Linear(in_features=128, out_features=256, bias=False)
  (w2): Linear(in_features=256, out_features=128, bias=False)
  (w3): Linear(in_features=128, out_features=256, bias=False)
  (act_fn): SiLU()
)
single expert input: torch.Size([1, 64, 128])
single expert output: torch.Size([1, 64, 128])
```
2.2.2 Mixture of Experts Implementation
```python
class MixtralSparseMoeBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        # Several SwiGLU MLP layers make up the mixture of experts
        self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config)
                                      for _ in range(self.num_experts)])

x = torch.randn(1, 64, 128)
experts = MixtralSparseMoeBlock(config)
print('Mixture of experts built from multiple experts')
print(experts)
```
The key structures of the model are implemented above, but the forward pass of the sMoE block is still missing.
2.3 sMoE Computation Flow

2.3.1 Gating Flow
The following shows the gating computation flow for multiple tokens.
```python
import torch.nn.functional as F

# Stage 1
# Compute the sparse gating values
tokens = 6
x = torch.randn(1, tokens, 128)  # 6 tokens
hidden_states = x
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)

# Every layer produces router_logits, which are used at the end for the load-balance loss
router_logits = experts.gate(hidden_states)
print(f'experts.gate output router logits:\n{router_logits}')

# Compute the top-k expert logits and the positions of the top-2 experts
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
print(f'softmax weight:\n{routing_weights}')

routing_weights, selected_experts = torch.topk(routing_weights,
                                               experts.top_k, dim=-1)
print(f'expert select:\n{selected_experts}')
print(f'topk:\n{routing_weights}')

routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
print(f'topk normalized:\n{routing_weights}')

routing_weights = routing_weights.to(hidden_states.dtype)

# One-hot encoding
expert_mask = torch.nn.functional.one_hot(selected_experts,
                                          num_classes=experts.num_experts).permute(2, 1, 0)
for i in range(tokens):
    print(f'[token_{i}]\n', expert_mask[:, :, i])
```

Tracking the result for token x3:

2.3.2 Expert Flow
- In sMoE, computation is organized per expert: each expert gathers the tokens routed to it.
- Token-first view: the left figure shows token3 selecting expert 2 and expert 3 to compute its sMoE result.
- Expert-first view: the right figure shows expert 2 and expert 3 being computed in turn to produce the sMoE result for token3.

The code implementation is:
```python
# Buffer for the final result
final_hidden_states = torch.zeros(
    (batch_size * sequence_length, hidden_dim),
    dtype=hidden_states.dtype, device=hidden_states.device
)
print(f'final moe result shape for each token: {final_hidden_states.shape}')

# Each expert gathers the tokens it needs to compute
for expert_idx in range(experts.num_experts):

    print(f'-------- expert {expert_idx} ---------')

    expert_layer = experts.experts[expert_idx]
    print(expert_mask[expert_idx])
    idx, top_x = torch.where(expert_mask[expert_idx])
    print(f'token indices computed by expert {expert_idx}:', top_x.tolist())  # selected x_idx for this expert
    print(f'expert {expert_idx} slot (top1: 0, top2: 1):', idx.tolist())      # 0 is top1, 1 is top2
    print(f'{len(top_x)} / {x.shape[1]} tokens routed to expert {expert_idx}')

    top_x_list = top_x.tolist()
    idx_list = idx.tolist()

    current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)

    # expert_i(x) * routing_weights
    current_hidden_states = (expert_layer(current_state)
                             * routing_weights[top_x_list, idx_list, None])

    # Scatter this expert's results into the result buffer
    final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

    print(current_state.shape)
    print(routing_weights[top_x_list, idx_list, None].shape)
    print(current_hidden_states.shape)
    print(final_hidden_states.shape)
```
The output is:

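Putting the gating flow from 2.3.1 and the per-expert loop from 2.3.2 together, the missing sMoE forward pass can be assembled as below. This is a minimal sketch that mirrors the steps above; the function name `sparse_moe_forward` and the empty-expert skip are my additions, not the exact Hugging Face source:

```python
import torch
import torch.nn.functional as F

def sparse_moe_forward(moe_block, hidden_states):
    """Gating (2.3.1) + expert-wise computation (2.3.2) in one pass."""
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)

    # Gating: router logits -> softmax -> top-k -> renormalize
    router_logits = moe_block.gate(hidden_states)
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, moe_block.top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(hidden_states.dtype)

    final_hidden_states = torch.zeros_like(hidden_states)
    expert_mask = F.one_hot(selected_experts, num_classes=moe_block.num_experts).permute(2, 1, 0)

    # Expert-first loop: each expert only processes the tokens routed to it
    for expert_idx in range(moe_block.num_experts):
        idx, top_x = torch.where(expert_mask[expert_idx])
        if top_x.numel() == 0:
            continue  # no token picked this expert
        current_state = hidden_states[top_x]
        current_hidden_states = (moe_block.experts[expert_idx](current_state)
                                 * routing_weights[top_x, idx, None])
        final_hidden_states.index_add_(0, top_x, current_hidden_states)

    # router_logits is also returned so the load-balance loss can be computed later
    return final_hidden_states.view(batch_size, sequence_length, hidden_dim), router_logits

y, logits = sparse_moe_forward(experts, torch.randn(1, 6, 128))
print(y.shape, logits.shape)  # torch.Size([1, 6, 128]) torch.Size([6, 8])
```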
2.4 Router Load Balance Computation
The routing load-balance implementation comes from Switch Transformers.
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. See Switch Transformer for more details. This function implements the loss function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between experts is too unbalanced.
2.4.1 Switch Transformers Load Balance Loss
This is a simplified version of the load-balance loss for sMoE, dropping part of the original balance-loss estimate:
- f_i: the fraction of tokens in a batch that are dispatched to expert i
- P_i: the average router probability assigned to expert i over the T tokens in a batch
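For reference, equations (4)-(6) of the Switch Transformers paper define the auxiliary loss over a batch B of T tokens and N experts as follows (the hand-rolled version below computes the same f_i * P_i products, summed over the top-k slots and without the alpha scaling):

```latex
\mathcal{L}_{\text{aux}} = \alpha \, N \sum_{i=1}^{N} f_i \, P_i,
\qquad
f_i = \frac{1}{T} \sum_{x \in \mathcal{B}} \mathbb{1}\{\operatorname{argmax}\, p(x) = i\},
\qquad
P_i = \frac{1}{T} \sum_{x \in \mathcal{B}} p_i(x)
```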
2.4.2 Hand-Rolled Mixtral Load Balance Loss: Computation Flow
Think of it this way: layer norm only acts on the tokens within the current layer, while load balancing covers a much wider scope: over the tokens of all layers, a per-expert load value is computed along each expert's column, and summing these values gives the load-balance loss of the whole network.

2.4.3 Hand-Rolled Mixtral Load Balance
```python
import torch

num_experts = 8
batch = 10
seq_length = 6
top_k = 2

print(f'sMoE num_experts:{num_experts} top_k:{top_k} batch:{batch} seq_length:{seq_length}')

router_logits_1 = torch.randn(batch, seq_length, num_experts).view(-1, num_experts)  # layer 1
router_logits_2 = torch.randn(batch, seq_length, num_experts).view(-1, num_experts)  # layer 2
router_logits = [router_logits_1, router_logits_2]

concatenated_gate_logits = torch.cat(router_logits, dim=0)
print('router logits of a single gating layer:', router_logits_1.shape)
print('router logits of two gating layers:', concatenated_gate_logits.shape)

print('compute the one-hot encoding from the top-k of the logits')
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
print(expert_mask.shape)

tokens_sum_expert = torch.sum(expert_mask.float(), dim=0)
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
print('top1 token count per expert:', tokens_sum_expert[0])
print('top2 token fraction per expert fi:', tokens_per_expert[1])
print('row sums over top1 and top2:', tokens_per_expert.sum(dim=1))

# Compute the average probability of routing to these experts
router_prob_per_expert = torch.mean(routing_weights, dim=0)
print('router_prob_per_expert Pi:', router_prob_per_expert)

print('load of each expert:', tokens_per_expert * router_prob_per_expert.unsqueeze(0))
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
print('final loss:', overall_loss)
```
Result:
```
sMoE num_experts:8 top_k:2 batch:10 seq_length:6
router logits of a single gating layer: torch.Size([60, 8])
router logits of two gating layers: torch.Size([120, 8])
compute the one-hot encoding from the top-k of the logits
torch.Size([120, 2, 8])
top1 token count per expert: tensor([10., 14., 19., 17., 14., 9., 17., 20.])
top2 token fraction per expert fi: tensor([0.1667, 0.1333, 0.1833, 0.0833, 0.1167, 0.1500, 0.0667, 0.1000])
row sums over top1 and top2: tensor([1., 1.])
router_prob_per_expert Pi: tensor([0.1236, 0.1184, 0.1351, 0.1168, 0.1311, 0.1147, 0.1156, 0.1447])
load of each expert: tensor([[0.0103, 0.0138, 0.0214, 0.0165, 0.0153, 0.0086, 0.0164, 0.0241],
        [0.0206, 0.0158, 0.0248, 0.0097, 0.0153, 0.0172, 0.0077, 0.0145]])
final loss: tensor(0.2520)
```
Note that the gating logits here span the batch and all layers, and act on every token.
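Wrapped up as a helper, the hand-rolled computation above looks like the following sketch (the function name `load_balance_loss` is mine; it mirrors the steps above rather than any particular library helper):

```python
import torch
import torch.nn.functional as F

def load_balance_loss(router_logits_per_layer, num_experts, top_k):
    """Expert-wise load-balance loss over the concatenated router logits of all layers.

    router_logits_per_layer: list of [num_tokens, num_experts] tensors, one per layer.
    """
    gate_logits = torch.cat(router_logits_per_layer, dim=0)           # [layers * tokens, E]
    routing_weights = F.softmax(gate_logits, dim=-1)
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)  # [layers * tokens, k]
    expert_mask = F.one_hot(selected_experts, num_experts).float()    # [layers * tokens, k, E]

    tokens_per_expert = expert_mask.mean(dim=0)           # f_i for each top-k slot
    router_prob_per_expert = routing_weights.mean(dim=0)  # P_i
    return torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))

# Same setup as the hand-rolled example: two layers of random router logits
loss = load_balance_loss([torch.randn(60, 8), torch.randn(60, 8)], num_experts=8, top_k=2)
print(loss)
```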
3. Mixtral 8x7B Parameter Count
3.1 Description in the Original Paper
The 13B here refers to the number of model parameters a single token touches. During actual inference, different tokens are routed to different experts, so in practice the full 47B parameters still have to be kept around; using sMoE does not reduce memory usage.
3.2 Model Parameter Count Calculation
Calculation, ignoring GQA:
```python
dim = 4096
n_layers = 32
head_dim = 128
hidden_dim = 14336
n_heads = 32
n_kv_heads = 8  # ignore GQA
vocab_size = 32000
num_experts = 8
top_k_experts = 2

# attention + mlp + layernorm
llama_num = n_layers * (dim * dim * 4 + hidden_dim * dim * 3 + 2 * dim) \
    + 2 * vocab_size * dim
print('llama:', llama_num)

# attention + [mlp * 8] + layernorm
moe_num = n_layers * (dim * dim * 4 + hidden_dim * dim * 3 * 8 + 2 * dim) \
    + 2 * vocab_size * dim
print('moe:', moe_num)

# attention + [mlp * 2] + layernorm
# Top-2 inference
moe_num = n_layers * (dim * dim * 4 + hidden_dim * dim * 3 * 2 + 2 * dim) \
    + 2 * vocab_size * dim
print('moe top-2:', moe_num)
```
Result:
```
llama: 8047034368
moe: 47507046400
moe top-2: 13684178944
```
4. MoE Extensions
4.1 MegaBlocks
MoE layers can be run efficiently on single GPUs with high performance specialized kernels. For example, Megablocks ...
MegaBlocks implements sparse MoE computation.
As an aside: XFormers also implements operators based on a similar idea, using masks to perform sparse multi-sequence attention computation within a batch.
4.2 GShard
The Mixtral paper mentions GShard in its load-balance discussion; GShard was the first work to bring MoE into Transformers.
This formulation is similar to the GShard architecture [21], with the exceptions that we replace all FFN sub-blocks by MoE layers while GShard replaces every other block, and that GShard uses a more elaborate gating strategy for the second expert assigned to each token.
GShard places different experts on different GPUs while all other parameters are shared; dispatching tokens to the experts and gathering the expert outputs back are both implemented with the All-to-All operator.
The DeepSpeed-MoE source implements All-to-All as follows:
```python
class _AllToAll(torch.autograd.Function):

    @staticmethod
    def forward(
            ctx: Any,
            # TODO: replace with DS process group
            group: torch.distributed.ProcessGroup,
            input: Tensor) -> Tensor:  # type: ignore
        ctx.group = group
        input = input.contiguous()
        output = torch.empty_like(input)
        dist.all_to_all_single(output, input, group=group)
        return output

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
        return (None, _AllToAll.apply(ctx.group, *grad_output))

class MOELayer(Base):
    # ...
    def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
        # ...
        dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)

        # Re-shape after all-to-all: ecm -> gecm
        dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)

        expert_output = self.experts(dispatched_input)

        expert_output = _AllToAll.apply(self.ep_group, expert_output)

        # ...
```
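To get a feel for what this All-to-All exchange does, here is a single-process simulation with plain tensor ops (purely illustrative; the sizes and the `simulated_all_to_all` helper are made up, and no `torch.distributed` is involved):

```python
import torch

ep_size = 4            # hypothetical number of expert-parallel ranks
num_local_experts = 2  # experts owned by each rank
capacity = 3           # token capacity per expert slot
d_model = 8

# On each rank, dispatched_input holds one chunk per destination expert,
# grouped by the rank that owns that expert: [ep_size * num_local_experts, capacity, d_model]
per_rank_inputs = [torch.randn(ep_size * num_local_experts, capacity, d_model)
                   for _ in range(ep_size)]

def simulated_all_to_all(tensors, ep_size):
    # all_to_all_single semantics: rank src splits its tensor into ep_size chunks and sends
    # chunk dst to rank dst; rank dst concatenates the chunks it receives in rank order.
    chunked = [t.chunk(ep_size, dim=0) for t in tensors]
    return [torch.cat([chunked[src][dst] for src in range(ep_size)], dim=0)
            for dst in range(ep_size)]

received = simulated_all_to_all(per_rank_inputs, ep_size)
# After the exchange, each rank reshapes ecm -> gecm and runs only its local experts
print(received[0].reshape(ep_size, num_local_experts, -1, d_model).shape)
# torch.Size([4, 2, 3, 8])
```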
4.3 DeepSpeed-MoE
- For a more production-ready implementation, see the open-source DeepSpeed-MoE solution.
- MoE layers are parallelized with Expert Parallelism; the All-to-All implementation is shown above.
- Non-MoE layers use TP + DP.
4.4 LLaMA-MoE
Can't manage to train Mixtral 8x7B? Try converting LLaMA's original MLP into LLaMA-MoE.
The key piece of LLaMA-MoE code replaces the original SwiGLU layer in LLaMA with LinearGLUExperts:
```python
class LinearGLUExperts(nn.Module):
    # ...
    def __init__(...):
        # ...
        # Each expert gets its own SwiGLU MLP parameters
        for i in range(num_experts):
            # this matrix will be transposed when performing linear forwarding
            this_expert_weight_gate = nn.Parameter(
                torch.empty((size_experts[i], in_features), **factory_kwargs)
            )
            # this matrix will be transposed when performing linear forwarding
            this_expert_weight_up = nn.Parameter(
                torch.empty((size_experts[i], in_features), **factory_kwargs)
            )
            # this matrix will be transposed when performing linear forwarding
            this_expert_weight_down = nn.Parameter(
                torch.empty((out_features, size_experts[i]), **factory_kwargs)
            )
            self.weight_gate.append(this_expert_weight_gate)
            self.weight_up.append(this_expert_weight_up)
            self.weight_down.append(this_expert_weight_down)
        # ...
```
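For intuition, here is a rough sketch of how a single expert's forward pass would use the three weight matrices above as a SwiGLU (the helper `expert_forward` and its sizes are hypothetical, not the actual LLaMA-MoE source):

```python
import torch
import torch.nn.functional as F

def expert_forward(x, weight_gate, weight_up, weight_down):
    # Weights are stored as (out_features, in_features) and transposed inside F.linear
    gate = F.silu(F.linear(x, weight_gate))  # (..., size_experts[i])
    up = F.linear(x, weight_up)              # (..., size_experts[i])
    return F.linear(gate * up, weight_down)  # (..., out_features)

# Example with made-up sizes: hidden size 128, expert FFN size 256
x = torch.randn(4, 128)
w_gate = torch.randn(256, 128)
w_up = torch.randn(256, 128)
w_down = torch.randn(128, 256)
print(expert_forward(x, w_gate, w_up, w_down).shape)  # torch.Size([4, 128])
```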
- For a more production-ready implementation, see the open-source LLaMA-MoE project.
5. Mixtral 8x7B Summary & Further Reading
- The Mixtral 8x7B implementation is not complicated; its load-balance loss is computed expert-wise.
- The models released so far still revolve around model structure; we look forward to mistral.AI shipping innovative alignment approaches.
- Distributed sMoE training across multiple nodes and GPUs demands real engineering skill; different model architectures and clusters admit many combinations of DP, TP, EP and so on.
- A counter-intuitive experimental finding in Mixtral: expert knowledge operates at the token level, not the domain level. If you are interested in MoE, it is worth opening the box for further analysis.
References
- Mixture of Experts Explained
- 方佳瑞:MoE训练论文解读之Megablocks:打破动态路由限制
- 方佳瑞:MoE训练系统之JANUS:参数服务器助力MoE训练
- 方佳瑞:MoE训练论文解读之Tutel: 动态切换并行策略实现动态路由
- 西门宇少:对MoE大模型的训练和推理做分布式加速——DeepSpeed-MoE论文速读
- 吃果冻不吐果冻皮:大模型分布式训练并行技术(八)-MOE并行
- 孟繁续:Mixtral-8x7B 模型挖坑
- Mixtral-of-experts
- Mistral-7B
- Gshard
- Switch Transformers
- sMoE
- Transformers-Mixtral-of-Experts
- DeepSpeed-MoE
- Megablocks
- LLaMA-MoE