当前位置: 首页 > news >正文

MOELayer DEMO及注释

MOELayer DEMO及注释

import copy
import torch
from typing import Any
from typing import Callable, Dict, Tuple
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
torch.manual_seed(42)


class MixtralParallelMLPBM(Module):
    """Mixtral-style gated feed-forward expert computing ``w2(silu(w1(x)) * w3(x))``.

    Args:
        hidden_size: size of the last dimension of the input and the output.
        ffn_hidden_size: size of the intermediate (gated) projection.
    """

    def __init__(self, hidden_size, ffn_hidden_size):
        super(MixtralParallelMLPBM, self).__init__()
        self.w1 = torch.nn.Linear(hidden_size, ffn_hidden_size)
        self.w2 = torch.nn.Linear(ffn_hidden_size, hidden_size)
        self.w3 = torch.nn.Linear(hidden_size, ffn_hidden_size)
        self.act_fn = F.silu

    def forward(self, hidden_states):
        print("hidden_states:", hidden_states.shape)
        # SwiGLU-style gating: the SiLU branch (w1) scales the linear branch (w3).
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states


class Experts(torch.nn.Module):
    """Hold ``num_local_experts`` independent deep copies of one expert module.

    ``forward`` splits its input along dim 1 into one chunk per local expert,
    runs each chunk through its own expert and concatenates the results back
    along dim 1.
    """

    def __init__(self, expert, num_local_experts=1):
        super(Experts, self).__init__()
        # deepcopy gives every local expert its own, independent parameters.
        self.experts = torch.nn.ModuleList(
            [copy.deepcopy(expert) for _ in range(num_local_experts)]
        )
        self.num_local_experts = num_local_experts

    def forward(self, inputs):
        """Apply each local expert to its slice of ``inputs``.

        Args:
            inputs: tensor whose dim 1 indexes local experts. In ``MOELayer`` it
                is shaped (ep_size, num_local_experts, capacity, d_model).

        Returns:
            The per-expert outputs re-stacked along dim 1. An expert may return
            either a tensor or an ``(output, bias)`` tuple; a non-None bias is
            added to the output.
        """
        print("Experts input:", inputs.shape)
        chunks = inputs.chunk(self.num_local_experts, dim=1)
        expert_outputs = []
        for chunk, expert in zip(chunks, self.experts):
            print("chunk:", chunk.shape)
            # Drop the size-1 expert axis so the expert sees (..., capacity, d_model).
            chunk = torch.squeeze(chunk, dim=1).contiguous()
            print("chunk:", chunk.shape)
            out = expert(chunk)
            # Unpack an (output, bias) tuple BEFORE touching ``.shape``; printing
            # first made this branch unreachable (tuples have no ``.shape``).
            if isinstance(out, tuple):
                out, bias = out
                if bias is not None:
                    out = out + bias
            print("expert out:", out.shape)
            expert_outputs.append(torch.unsqueeze(out, dim=1))
        return torch.cat(expert_outputs, dim=1)


def _one_hot_to_float(x, num_classes):
    """One-hot encode integer tensor ``x`` into ``num_classes`` classes as float32."""
    return F.one_hot(x, num_classes=num_classes).float()


def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor):
    """Return the per-expert token capacity as a 0-d int64 tensor.

    capacity = ceil(S / E * capacity_factor), raised to at least
    ``min_capacity`` and then capped at S (the number of tokens).

    Args:
        gates: (S, E) gate probabilities; only its shape is used.
        capacity_factor: multiplier applied to the balanced share S / E.
        min_capacity: lower bound on the capacity.
    """
    num_tokens = gates.shape[0]
    num_experts = gates.shape[1]
    max_capacity = num_tokens
    # to(torch.int64) works around a bug in torch.onnx.export:
    # it should cast k to int64 when converting torch.topk but it doesn't.
    capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
    if capacity < min_capacity:
        capacity = min_capacity.to(torch.int64)
    # The upper bound is checked independently of the lower one: if min_capacity
    # exceeds S, the capacity must still be capped, otherwise torch.topk(k=capacity)
    # over the S tokens in top1gating fails.
    if capacity > max_capacity:
        capacity = torch.tensor(max_capacity, dtype=torch.int64)
    return capacity


def top1gating(logits):
    """Top-1 gating with a fixed per-expert capacity.

    Each token is routed to its highest-probability expert. Each expert accepts
    at most C tokens (``_capacity`` with factor 1.1 and minimum 4); extra tokens
    are dropped, i.e. they get all-zero combine weights.

    Args:
        logits: (S, E) router logits for S tokens and E experts.

    Returns:
        combine_weights: (S, E, C) float tensor. Entry [s, e, c] holds token s's
            softmax probability if the token was placed in slot c of expert e,
            otherwise 0.
        dispatch_mask: (S, E, C) bool tensor, True where combine_weights != 0.

    Running it prints every intermediate tensor (steps 5-17).
    """
    # [S, E]: probability of each token choosing each expert.
    token_sel_expert_weights = F.softmax(logits, dim=1)
    print(f"5.softmax:{token_sel_expert_weights.shape} token_sel_expert_weights:\n{token_sel_expert_weights}")
    capacity = _capacity(token_sel_expert_weights, torch.tensor(1.1), torch.tensor(4))
    print("6.top1gating capacity:", capacity)
    # Plain Python int for the ``k`` / ``num_classes`` arguments below.
    num_slots = int(capacity)

    # [S]: expert chosen by each token (argmax of its probabilities).
    token_sel_expert_idx = torch.argmax(token_sel_expert_weights, dim=1)
    print("7.每个token对应的专家:", token_sel_expert_idx.shape, "data:", token_sel_expert_idx)
    num_experts = int(token_sel_expert_weights.shape[1])
    # [S, E]: one-hot routing mask (int64).
    token_sel_expert_mask = F.one_hot(token_sel_expert_idx, num_classes=num_experts)
    print("8.one_hot 编码:", token_sel_expert_mask.shape, "token_sel_expert_mask:\n", token_sel_expert_mask)

    # Enforce capacity. For every expert (column), topk over the 0/1 mask returns
    # C row indices [C, E], covering up to C of the tokens routed to that expert.
    # The mask carries no probabilities, so when more than C tokens compete,
    # which ones are kept depends on topk's tie order, not on gate weight.
    # Rows returned for experts with fewer than C tokens point at 0 entries and
    # are cancelled by the multiplication with the original mask.
    expert_sel_top_c_token_idx = torch.topk(token_sel_expert_mask, k=num_slots, dim=0)[1]
    print(f"9:获取top{capacity}:{expert_sel_top_c_token_idx.shape} expert_sel_top_c_token_idx:\n{expert_sel_top_c_token_idx}")
    mask = torch.zeros_like(token_sel_expert_mask).scatter_(0, expert_sel_top_c_token_idx, 1)
    print(f"10.将上面index所在的位置填成1:{mask.shape},mask:\n{mask}")
    token_sel_expert_mask *= mask
    print(f"11.生成最后的mask:{token_sel_expert_mask.shape} token_sel_expert_mask:\n{token_sel_expert_mask}")

    # Keep only the probability of each kept token's chosen expert. The kept
    # weights are the raw softmax values; no renormalization happens here.
    token_sel_expert_mask_float = token_sel_expert_mask.float()
    token_sel_expert_weights = token_sel_expert_weights * token_sel_expert_mask_float
    print(f"12.用mask去取softmax后的值:{token_sel_expert_weights.shape},token_sel_expert_weights:\n{token_sel_expert_weights}")

    # Slot of each token inside its expert = running count per column minus 1.
    # Entries where the token is not routed to that expert are meaningless and
    # are zeroed by the mask on the next step. (No noise is added despite the name.)
    token_idx_in_expert_with_noise = torch.cumsum(token_sel_expert_mask, dim=0) - 1
    print(f"13.token_idx_in_expert_with_noise:{token_idx_in_expert_with_noise.shape} token_idx_in_expert_with_noise:\n{token_idx_in_expert_with_noise}")
    masked_token_idx_in_expert = token_idx_in_expert_with_noise * token_sel_expert_mask
    print(f"14.masked_token_idx_in_expert:{masked_token_idx_in_expert.shape} masked_token_idx_in_expert:\n{masked_token_idx_in_expert}")
    # [S]: slot index of each token. Dropped tokens land on slot 0, which is
    # harmless because their weights are already zero.
    token_offset_for_expert = torch.sum(masked_token_idx_in_expert, dim=1)
    print(f"15.token_offset_for_expert:{token_offset_for_expert.shape} token_offset_for_expert:\n{token_offset_for_expert}")
    # [S, C]: one-hot slot position of every token.
    token_locations_sc = _one_hot_to_float(token_offset_for_expert, num_slots)
    print(f"16.token_locations_sc:{token_locations_sc.shape} token_locations_sc:\n{token_locations_sc}")
    # [S, E] x [S, C] -> [S, E, C]: the (expert, slot) each token occupies, with its weight.
    combine_weights = torch.einsum("se,sc->sec", token_sel_expert_weights, token_locations_sc)
    print(f"17.combine_weights:{combine_weights.shape} combine_weights:\n{combine_weights}")
    dispatch_mask = combine_weights.bool()
    return combine_weights, dispatch_mask


class TopKGate(Module):
    """Linear router scoring every token against every expert, followed by top-1 gating.

    Despite the name, only top-1 routing (``top1gating``) is implemented.
    """

    weight: torch.nn.Linear

    def __init__(self, hidden_size, num_experts) -> None:
        super(TopKGate, self).__init__()
        self.weight = torch.nn.Linear(hidden_size, num_experts, bias=False).float()

    def forward(self, gate_input):
        """Route (S, hidden_size) tokens; returns ``top1gating``'s (combine_weights, dispatch_mask)."""
        # Routing is computed in fp32 regardless of the input dtype.
        input_fp32 = gate_input.float()
        logits = torch.nn.functional.linear(input_fp32, weight=self.weight.weight.float(), bias=None)
        print("4.TopKGate输入:", input_fp32.shape, "权值:", self.weight.weight.shape, "logits输出:", logits.shape)
        gate_output = top1gating(logits)
        return gate_output


class MOELayer(Module):
    """Mixture-of-experts layer: gate -> dispatch -> experts -> combine.

    Args:
        gate: module mapping (S, d_model) tokens to (combine_weights, dispatch_mask).
        experts: module applied to a (ep_size, num_local_experts, C, d_model) tensor.
        ep_size: expert-parallel world size.
        num_local_experts: number of experts hosted on each EP rank.
        pipe_experts, sequence_parallel, pipe_experts_multi_data,
        pipe_experts_multi_stream: accepted for interface compatibility;
            not used by this demo.
    """

    def __init__(self,
                 gate: Module,
                 experts: Module,
                 ep_size,
                 num_local_experts,
                 pipe_experts: bool = False,
                 sequence_parallel: bool = True,
                 pipe_experts_multi_data: int = 1,
                 pipe_experts_multi_stream: bool = False) -> None:
        super().__init__()
        self.gate = gate
        self.experts = experts
        self.ep_group = None
        self.ep_size = ep_size
        self.num_local_experts = num_local_experts
        self.num_experts = ep_size * num_local_experts
        self.exp_counts = None
        self.l_aux = None

    def set_ep_group(self, ep_group):
        self.ep_group = ep_group

    def forward(self, input, **kwargs):
        """Route tokens to experts and merge the expert outputs back.

        Goal: different experts process different tokens.

        Steps:
        1. The gate builds an (S, E, C) routing tensor. S = tokens,
           E = experts, C = slots per expert. Each entry says how much of a
           token goes to an (expert, slot); summing over (E, C) reassembles
           the token.
        2. An einsum copies every token's features into its slot of an
           (E, C, M) buffer; M = hidden size.
        3. (Distributed only) an all-to-all would move each expert's slots to
           the EP rank that hosts the expert. It is disabled in this demo.
        4. The buffer is reshaped to (ep_size, num_local_experts, C, M); each
           local expert processes its slice and the results are concatenated.
        5. (Distributed only) a second all-to-all would restore the original
           rank layout. It is disabled in this demo.
        6. An einsum with combine_weights sums the weighted expert outputs
           back to (S, M), which is reshaped to the input's shape.

        Because the all-to-alls are disabled, after the reshape in step 4 local
        expert ``l`` processes the slots of global experts ``l``,
        ``l + num_local_experts``, ... (assuming the gate's E equals
        ``ep_size * num_local_experts``).

        Args:
            input: sequence whose first element is the hidden-state tensor,
                e.g. (seq_len, batch, d_model).
            **kwargs: ignored.

        Returns:
            Tensor with the same shape as ``input[0]``.
        """
        d_model = input[0].shape[-1]
        # [S, M] with S = seq_len * batch.
        reshaped_input = input[0].reshape(-1, d_model)
        print("3.将维度转换为二维度(seq_len*batch_size,hidden_size):", reshaped_input.shape)

        combine_weights, dispatch_mask = self.gate(reshaped_input)
        print(combine_weights.shape, dispatch_mask.shape)
        # [S, E, C] x [S, M] -> [E, C, M]. dispatch_mask is 0/1, so this copies
        # each kept token into its (expert, slot); gate weights are only applied
        # when combining. Fixed-size per-expert buffers bound each expert's load.
        dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input)
        print(f"18.dispatched_input:{dispatched_input.shape},dispatched_input:\n{dispatched_input}")

        # NOTE: the dispatch all-to-all would go here in a distributed run:
        #     dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
        # Re-shape after all-to-all: ecm -> gecm
        dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
        print("dispatched_input:", dispatched_input.shape)
        expert_output = self.experts(dispatched_input)
        print("expert_output:", expert_output.shape)

        # NOTE: the combine all-to-all would go here in a distributed run:
        #     expert_output = _AllToAll.apply(self.ep_group, expert_output)
        # Re-shape back: gecm -> ecm
        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
        # [S, E, C] x [E, C, M] -> [S, M]: weighted merge of the expert outputs.
        combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output)
        return combined_output.reshape(input[0].shape)


# Model-parallel configuration this demo was modeled on (8 ranks):
#   TP=2, PP=2, DP=2, EP=2, num_experts=4
#   all tp groups [[0, 1], [2, 3], [4, 5], [6, 7]]
#   all ep groups [[0, 2], [1, 3], [4, 6], [5, 7]]
#   all dp groups [[0, 2], [1, 3], [4, 6], [5, 7]]
#   all pp groups [[0, 4], [1, 5], [2, 6], [3, 7]]
def main():
    """Run the MoE layer once on random input, printing every intermediate step."""
    num_experts = 4
    num_local_experts = 2
    seq_len = 16
    batch_size = 1
    hidden_size = 8
    ffn_hidden_size = 16
    ep_size = 2
    gate = TopKGate(hidden_size, num_experts)
    moe = MOELayer(gate,
                   Experts(MixtralParallelMLPBM(hidden_size, ffn_hidden_size), num_local_experts),
                   ep_size,
                   num_local_experts)
    hidden = torch.randn(seq_len, batch_size, hidden_size, dtype=torch.float32)
    print("1.原始的输入shape(32,1,64),因为序列并行,进入到MOE时维度为(16,1,64)")
    print("2.输入数列的shpae:", hidden.shape)
    output = moe([hidden])
    print("output:", output.shape)


if __name__ == "__main__":
    main()

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 你想活出怎样的人生?我只活一次,所以想做自己
  • LLM的训练与推断
  • 字节测开面筋大总结!!!!
  • Flutter 中自定义DNS解析的实现
  • 移动式气象站:便携科技的天气守望者
  • 制作excel模板,用于管理后台批量导入船舶数据
  • 优选算法之位运算
  • React基础知识 精简全面 推荐
  • AI绘画3分钟解决英文恐惧症,comfyui汉化插件
  • 安装python插件命令集合
  • 分布式文件存储行业解决方案和技术选型分析
  • 【MySQL进阶之路 | 高级篇】显式事务和隐式事务
  • electron 网页TodoList应用打包win桌面软件数据持久化
  • 00-从零开始安装Oracle19c之数据库安装规划
  • 这款ERP云进销存系统,直接封神
  • CAP 一致性协议及应用解析
  • Centos6.8 使用rpm安装mysql5.7
  • ComponentOne 2017 V2版本正式发布
  • CSS相对定位
  • es6--symbol
  • ES6之路之模块详解
  • HTTP中的ETag在移动客户端的应用
  • js正则,这点儿就够用了
  • linux安装openssl、swoole等扩展的具体步骤
  • Quartz实现数据同步 | 从0开始构建SpringCloud微服务(3)
  • SpiderData 2019年2月25日 DApp数据排行榜
  • Tornado学习笔记(1)
  • 初识 webpack
  • 聊聊directory traversal attack
  • 前端每日实战:61# 视频演示如何用纯 CSS 创作一只咖啡壶
  • 巧用 TypeScript (一)
  • 线性表及其算法(java实现)
  • ​​​【收录 Hello 算法】10.4 哈希优化策略
  • # 透过事物看本质的能力怎么培养?
  • #《AI中文版》V3 第 1 章 概述
  • #define、const、typedef的差别
  • #if 1...#endif
  • %3cscript放入php,跟bWAPP学WEB安全(PHP代码)--XSS跨站脚本攻击
  • (LLM) 很笨
  • (二)c52学习之旅-简单了解单片机
  • (二十四)Flask之flask-session组件
  • (附程序)AD采集中的10种经典软件滤波程序优缺点分析
  • (论文阅读笔记)Network planning with deep reinforcement learning
  • (每日一问)基础知识:堆与栈的区别
  • (自用)仿写程序
  • ../depcomp: line 571: exec: g++: not found
  • ./include/caffe/util/cudnn.hpp: In function ‘const char* cudnnGetErrorString(cudnnStatus_t)’: ./incl
  • .bat批处理(二):%0 %1——给批处理脚本传递参数
  • .NET面试题解析(11)-SQL语言基础及数据库基本原理
  • /proc/interrupts 和 /proc/stat 查看中断的情况
  • /proc/vmstat 详解
  • @param注解什么意思_9000字,通俗易懂的讲解下Java注解
  • @RequestBody与@ModelAttribute
  • [ASP.NET MVC]如何定制Numeric属性/字段验证消息
  • [CC-FNCS]Chef and Churu