Pytorch梯度裁剪 nn.utils.clip_grad_norm_()
梯度裁剪原理:既然在BP过程中会产生梯度消失(就是偏导无限接近0,导致长时记忆无法更新)或梯度爆炸,那么最简单粗暴的方法就是,梯度截断Clip, 将梯度约束在某一个区间之内
pytorch中的梯度裁剪函数是nn.utils.clip_grad_norm_()
parameters:希望实施梯度裁剪的可迭代网络参数
max_norm:该组网络参数梯度的范数上限
norm_type:范数类型(一般默认为L2 范数, 即范数类型=2)torch.nn.utils.clipgrad_norm() 的使用应该在loss.backward() 之后,optimizer.step()之前.
注意这个方法只在训练的时候使用,在测试的时候验证和测试的时候不用。