算法面经手撕系列(3)--手撕LayerNormlization
LayerNormlization
在许多的语言模型如Bert里,虽然都是说做的LayerNormlization,但计算均值和方差只会沿着channel维度做,并不是沿着seq_L和channel维度一起做,参考:BERT用的LayerNorm可能不是你认为的那个Layer Norm
LayerNormlization计算流程:
- init里初始化C_in大小的scale和shift向量
- 沿Channel维度计算均值和方差
- 归一化
代码
LayerNorm(InstanceNorm)实现如下:
class LayerNormalization(nn.Module):def __init__(self,hidden_dim,eps=1e-6):super(LayerNormalization, self).__init__()self.eps=epsself.gamma=nn.Parameter(torch.ones(hidden_dim))self.beta=nn.Parameter(torch.zeros(hidden_dim))def forward(self,x):B,seq_L,C=x.shapemean=x.mean(dim=-1,keepdim=True)std=x.std(dim=-1,keepdim=True)out=(x-mean)/(std+self.eps)out=out*self.gamma+self.betareturn out
if __name__=="__main__":tensor_input=torch.rand(5,10,8)model=LayerNormalization(8)res=model(tensor_input)print(res)