对梯度爆炸和梯度消失的一些理解
众所周知,RNN可能有梯度爆炸和梯度消失的情况,主要是由于在反向传播时loss函数对远端w权重值求导时会出现一个Ws高阶连乘以及hn对hn-1的导数,这个值是在[0,1]间的,因此连乘下来数值会越来越小,梯度的值接近零,于是远端的信息无法传递过来;如果Ws初始化的值太小也会导致梯度消失,太大则会梯度爆炸,也就是说相同的权重矩阵反复连乘。这里的梯度消失并不是参数完全不更新,而是更新被近距离的信息主导。
LSTM有效的解决了这个问题,因为加了一个乘法门记忆单元,能通过参数来控制反向传播是按照与RNN类似的路径,还是直接通过记忆单元这条路,LSTM 通过记忆这条路径上的梯度拯救了总体的远距离梯度消失问题。如果通过记忆单元来求梯度,可以通过控制与h有关的几个参数来控制Cn对Cn-1(C就是记忆单元)的导数,这个值是遗忘门的输出的乘积和其他部分的和,因此可以通过控制遗忘门的输出值接近1保证这条路的梯度值不会太小,这样就避免了梯度消失;远距离的信息在参与更新时,在Cell上是正常梯度,在另一条与RNN类似的路径上梯度,消失+ 消失梯度 = 正常梯度。
但LSTM还有可能梯度爆炸,但相比RNN概率小了很多,梯度爆炸可以通过梯度修剪来解决:只要设定阈值,当提出梯度超过此阈值,就进行截取。
Cell这个结构类似于ResNet
参考:https://www.zhihu.com/question/34878706