"""MOELayer DEMO及注释 — a mixture-of-experts (MoE) layer demo with step-by-step annotations."""
import copy
import torch
from typing import Any
from typing import Callable, Dict, Tuple
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
torch.manual_seed(42)


class MixtralParallelMLPBM(Module):
    """Gated feed-forward block used as a single expert.

    Computes ``w2(silu(w1(x)) * w3(x))`` on the last dimension:
    ``hidden_size -> ffn_hidden_size -> hidden_size``.
    """

    def __init__(self, hidden_size, ffn_hidden_size):
        super().__init__()
        self.w1 = torch.nn.Linear(hidden_size, ffn_hidden_size)
        self.w2 = torch.nn.Linear(ffn_hidden_size, hidden_size)
        self.w3 = torch.nn.Linear(hidden_size, ffn_hidden_size)
        self.act_fn = F.silu

    def forward(self, hidden_states):
        print("hidden_states:", hidden_states.shape)
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states


class Experts(torch.nn.Module):
    """Holds ``num_local_experts`` copies of one expert module.

    ``forward`` splits ``inputs`` into ``num_local_experts`` chunks along dim 1
    (dim 1 is expected to have size ``num_local_experts``), runs chunk ``i``
    through local expert ``i`` and concatenates the results back along dim 1,
    so the output keeps the input layout.
    """

    def __init__(self, expert, num_local_experts=1):
        super().__init__()
        # deepcopy gives each local expert its own parameters (identical at init).
        self.experts = torch.nn.ModuleList(
            [copy.deepcopy(expert) for _ in range(num_local_experts)]
        )
        self.num_local_experts = num_local_experts

    def forward(self, inputs):
        print("Experts input:", inputs.shape)
        chunks = inputs.chunk(self.num_local_experts, dim=1)
        expert_outputs = []
        for chunk, expert in zip(chunks, self.experts):
            print("chunk:", chunk.shape)
            chunk = torch.squeeze(chunk, dim=1).contiguous()
            print("chunk:", chunk.shape)
            out = expert(chunk)
            # An expert may return (output, bias); fold the bias in. This must
            # happen before touching ``out.shape`` -- a tuple has no ``.shape``.
            if isinstance(out, tuple):
                out, bias = out
                if bias is not None:
                    out = out + bias
            print("expert out:", out.shape)
            expert_outputs.append(torch.unsqueeze(out, dim=1))
        return torch.cat(expert_outputs, dim=1)


def _one_hot_to_float(x, num_classes):
    """One-hot encode integer tensor ``x`` into ``num_classes`` classes, as float."""
    return F.one_hot(x, num_classes=num_classes).float()


def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
    """Return the per-expert token capacity C as a 0-dim int64 tensor.

    ``C = ceil(S / E * capacity_factor)`` for ``gates`` of shape ``[S, E]``,
    raised to at least ``min_capacity`` and then capped at ``S``. The final
    cap matters: ``top1gating`` calls ``torch.topk(k=C)`` over the S token
    rows, which fails if ``C > S``. The cap is therefore applied after (not
    instead of) the ``min_capacity`` floor.
    """
    num_tokens = gates.shape[0]
    num_experts = gates.shape[1]
    # to(torch.int64) works around a bug in torch.onnx.export:
    # it should cast k to int64 when converting torch.topk but it doesn't.
    capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
    if capacity < min_capacity:
        capacity = min_capacity.to(torch.int64)
    if capacity > num_tokens:
        capacity = torch.tensor(num_tokens, dtype=torch.int64)
    return capacity


def top1gating(logits):
    """Implements Top1Gating on logits.

    Every token is routed to its highest-probability expert; each expert holds
    at most C tokens (see ``_capacity``, called with factor 1.1 and minimum 4).

    Args:
        logits: gate scores of shape ``[S, E]`` (S tokens, E experts).

    Returns:
        combine_weights: float tensor ``[S, E, C]``. Entry ``[s, e, c]`` is the
            softmax probability of token ``s`` when it was routed to expert
            ``e`` and stored in capacity slot ``c``; 0 everywhere else. Tokens
            that overflowed their expert's capacity have an all-zero slice,
            i.e. they are dropped.
        dispatch_mask: ``combine_weights.bool()``.

    The numbered prints show every intermediate tensor for the demo input.
    """
    # TopKGate passes fp32 logits, so the routing math here runs in fp32.
    # token_sel_expert_weights: [S, E], probability of each token picking each expert.
    token_sel_expert_weights = F.softmax(logits, dim=1)
    print(f"5.softmax:{token_sel_expert_weights.shape} token_sel_expert_weights:\n{token_sel_expert_weights}")
    capacity = _capacity(token_sel_expert_weights, torch.tensor(1.1), torch.tensor(4))
    print("6.top1gating capacity:", capacity)

    # [S]: the expert chosen by each token (argmax of the probabilities).
    token_sel_expert_idx = torch.argmax(token_sel_expert_weights, dim=1)
    print("7.每个token对应的专家:", token_sel_expert_idx.shape, "data:", token_sel_expert_idx)
    num_experts = int(token_sel_expert_weights.shape[1])
    # [S, E] 0/1 routing mask; one 1 per row.
    token_sel_expert_mask = F.one_hot(token_sel_expert_idx, num_classes=num_experts)
    print("8.one_hot 编码:", token_sel_expert_mask.shape, "token_sel_expert_mask:\n", token_sel_expert_mask)

    # Capacity enforcement. For each expert column, topk over the 0/1 mask
    # returns C row indices, preferring rows that hold a 1 (tokens routed to
    # that expert). Scattering those indices into ``mask`` and multiplying
    # keeps at most C tokens per expert. All candidates have the same value
    # (1), so which tokens survive is decided by topk's tie-breaking, NOT by
    # gate probability. When an expert has fewer than C tokens, topk also
    # returns rows holding 0; the multiplication zeroes them out again.
    expert_sel_top_c_token_idx = torch.topk(token_sel_expert_mask, k=capacity, dim=0)[1]  # [C, E]
    print(f"9:获取top{capacity}:{expert_sel_top_c_token_idx.shape} expert_sel_top_c_token_idx:\n{expert_sel_top_c_token_idx}")
    mask = torch.zeros_like(token_sel_expert_mask).scatter_(0, expert_sel_top_c_token_idx, 1)
    print(f"10.将上面index所在的位置填成1:{mask.shape},mask:\n{mask}")
    token_sel_expert_mask *= mask
    print(f"11.生成最后的mask:{token_sel_expert_mask.shape} token_sel_expert_mask:\n{token_sel_expert_mask}")

    # Keep only the probability of each token's surviving expert (dropped
    # tokens become all-zero rows). The kept value is NOT renormalized.
    token_sel_expert_mask_float = token_sel_expert_mask.float()
    token_sel_expert_weights = token_sel_expert_weights * token_sel_expert_mask_float
    print(f"12.用mask去取softmax后的值:{token_sel_expert_weights.shape},token_sel_expert_weights:\n{token_sel_expert_weights}")

    # Slot index of each token inside its expert. cumsum down a column counts
    # the tokens routed to that expert so far, and "- 1" makes it 0-based.
    # Example column [1, 0, 1, 1] -> cumsum-1 [0, 0, 1, 2] -> masked [0, 0, 1, 2].
    # The count is taken over the post-capacity mask, so every index is < C.
    token_idx_in_expert_with_noise = torch.cumsum(token_sel_expert_mask, dim=0) - 1
    print(f"13.token_idx_in_expert_with_noise:{token_idx_in_expert_with_noise.shape} token_idx_in_expert_with_noise:\n{token_idx_in_expert_with_noise}")
    # Zero every entry except the token's own expert (this also removes the -1s).
    masked_token_idx_in_expert = token_idx_in_expert_with_noise * token_sel_expert_mask
    print(f"14.masked_token_idx_in_expert:{masked_token_idx_in_expert.shape} masked_token_idx_in_expert:\n{masked_token_idx_in_expert}")
    # [S]: slot of each token. Dropped tokens get 0, but their weights are 0 already.
    token_offset_for_expert = torch.sum(masked_token_idx_in_expert, dim=1)
    print(f"15.token_offset_for_expert:{token_offset_for_expert.shape} token_offset_for_expert:\n{token_offset_for_expert}")
    token_locations_sc = _one_hot_to_float(token_offset_for_expert, capacity)  # [S, C]
    print(f"16.token_locations_sc:{token_locations_sc.shape} token_locations_sc:\n{token_locations_sc}")

    # [S, E] x [S, C] -> [S, E, C]: for each token, which expert and which slot
    # it occupies, carrying the gate probability as the value.
    combine_weights = torch.einsum("se,sc->sec", token_sel_expert_weights, token_locations_sc)
    print(f"17.combine_weights:{combine_weights.shape} combine_weights:\n{combine_weights}")
    dispatch_mask = combine_weights.bool()
    return combine_weights, dispatch_mask


class TopKGate(Module):
    """Linear router (no bias) producing ``[S, E]`` logits, followed by top-1 gating."""

    weight: torch.nn.Linear

    def __init__(self, hidden_size, num_experts) -> None:
        super().__init__()
        self.weight = torch.nn.Linear(hidden_size, num_experts, bias=False).float()

    def forward(self, gate_input):
        """Return ``(combine_weights, dispatch_mask)`` from ``top1gating`` for ``[S, M]`` input."""
        # Routing is always computed in fp32, whatever the input dtype.
        input_fp32 = gate_input.float()
        logits = torch.nn.functional.linear(input_fp32, weight=self.weight.weight.float(), bias=None)
        print("4.TopKGate输入:", input_fp32.shape, "权值:", self.weight.weight.shape, "logits输出:", logits.shape)
        return top1gating(logits)


class MOELayer(Module):
    """Mixture-of-experts layer: gate -> dispatch -> experts -> combine.

    ``num_experts = ep_size * num_local_experts`` must equal the gate's number
    of experts, otherwise the reshapes in ``forward`` fail.
    """

    def __init__(self,
                 gate: Module,
                 experts: Module,
                 ep_size,
                 num_local_experts,
                 pipe_experts: bool = False,
                 sequence_parallel: bool = True,
                 pipe_experts_multi_data: int = 1,
                 pipe_experts_multi_stream: bool = False) -> None:
        super().__init__()
        # pipe_experts / sequence_parallel / pipe_experts_multi_* are accepted
        # for interface compatibility but are not used by this demo.
        self.gate = gate
        self.experts = experts
        self.ep_group = None
        self.ep_size = ep_size
        self.num_local_experts = num_local_experts
        self.num_experts = ep_size * num_local_experts
        # Never updated in this file; kept as attributes for compatibility.
        self.exp_counts = None
        self.l_aux = None

    def set_ep_group(self, ep_group):
        self.ep_group = ep_group

    def forward(self, input, **kwargs):
        """Route tokens through the experts and combine the results.

        Goal: different experts process different tokens.

        Steps:
          1. The gate builds an ``[S, E, C]`` routing tensor (S tokens, E
             experts, C capacity slots per expert). With the feature dim M
             this forms an E x C x M "container"; each entry says where (and
             with what weight) a token's features go, and summing over a
             dimension gives a weighted sum of the placed features.
          2. ``einsum("sec,sm->ecm")`` places token features into that
             container using the 0/1 dispatch mask. The fixed capacity bounds
             how many tokens any single expert receives.
          3. (Multi-rank only) an all-to-all would move each slice to the EP
             rank that owns its expert. Disabled in this single-process demo.
          4. Each EP rank owns ``num_local_experts`` experts: the container is
             split into that many chunks, each run through one local expert,
             and the results are concatenated.
          5. (Multi-rank only) a second all-to-all would restore the original
             rank ordering. Disabled here as well.
          6. ``einsum("sec,ecm->sm")`` combines expert outputs weighted by the
             gate probabilities, restoring the ``(tokens, hidden_size)`` layout.

        Args:
            input: sequence whose first element is the hidden-state tensor;
                its last dim is the model dim M (the demo passes
                ``(seq_len, batch, hidden)``).
            **kwargs: accepted and ignored.

        Returns:
            Tensor with the same shape as ``input[0]``.
        """
        d_model = input[0].shape[-1]
        # [S, M] with S = product of all leading dims.
        reshaped_input = input[0].reshape(-1, d_model)
        print("3.将维度转换为二维度(seq_len*batch_size,hidden_size):", reshaped_input.shape)

        # gate: [S, E, C] weights and mask.
        combine_weights, dispatch_mask = self.gate(reshaped_input)
        print(combine_weights.shape, dispatch_mask.shape)
        # [S, E, C] x [S, M] -> [E, C, M]: each slot receives at most one token's features.
        dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input)
        print(f"18.dispatched_input:{dispatched_input.shape},dispatched_input:\n{dispatched_input}")

        # Dispatch all-to-all would go here in a multi-rank run (_AllToAll is
        # not available in this demo):
        # dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)

        # ecm -> gecm: ep_size groups of num_local_experts experts each.
        dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
        print("dispatched_input:", dispatched_input.shape)
        # NOTE: without the all-to-all, dim 0 indexes expert groups rather than
        # source ranks. Experts chunks along dim 1, so global experts e and
        # e + num_local_experts are processed by the same local module here.
        expert_output = self.experts(dispatched_input)
        print("expert_output:", expert_output.shape)

        # Combine all-to-all (restores the original rank ordering) would go here:
        # expert_output = _AllToAll.apply(self.ep_group, expert_output)

        # gecm -> ecm
        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
        # [S, E, C] x [E, C, M] -> [S, M]: weighted merge back into token order.
        combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output)
        return combined_output.reshape(input[0].shape)


# Reference model configuration (模型配置) from the original write-up:
#   TP=2, PP=2, DP=2, EP=2, num_experts=4
# Resulting process groups over 8 ranks:
#   all tp groups [[0, 1], [2, 3], [4, 5], [6, 7]]
#   all ep groups [[0, 2], [1, 3], [4, 6], [5, 7]]
#   all dp groups [[0, 2], [1, 3], [4, 6], [5, 7]]
#   all pp groups [[0, 4], [1, 5], [2, 6], [3, 7]]
# main() below runs single-process, so no process groups are created.


def main():
    """Run one forward pass of the demo MoE layer on random input."""
    num_local_experts = 2
    ep_size = 2
    # Derived so the gate width always matches the MOELayer reshapes.
    num_experts = ep_size * num_local_experts
    seq_len = 16
    batch_size = 1
    hidden_size = 8
    ffn_hidden_size = 16
    gate = TopKGate(hidden_size, num_experts)
    expert = MixtralParallelMLPBM(hidden_size, ffn_hidden_size)
    moe = MOELayer(gate, Experts(expert, num_local_experts), ep_size, num_local_experts)
    inputs = torch.randn(seq_len, batch_size, hidden_size, dtype=torch.float32)
    # The message below describes the original model (hidden 64); this demo
    # actually uses hidden_size=8, so the printed shapes will differ.
    print("1.原始的输入shape(32,1,64),因为序列并行,进入到MOE时维度为(16,1,64)")
    print("2.输入数列的shpae:", inputs.shape)
    output = moe([inputs])
    print("output:", output.shape)


if __name__ == "__main__":
    main()