The Backpropagation Through Time (BPTT) Algorithm


Let's first briefly review the basic RNN equations:

$$s_t = \tanh(Ux_t + Ws_{t-1})$$

$$\hat y_t = \mathrm{softmax}(Vs_t)$$

The loss function of the RNN is defined as the cross-entropy loss:

$$E_t(y_t,\hat y_t) = -y_t \log \hat y_t$$

$$E(y,\hat y) = \sum_{t} E_t(y_t, \hat y_t) = -\sum_{t} y_t \log \hat y_t$$

Here $y_t$ is the true value at time step $t$ and $\hat y_t$ is the prediction. We usually treat the whole sequence as one training sample, so the total error is the sum of the errors at each time step. Our goal is to compute the gradients of the loss function and then learn all the parameters $U$, $V$, $W$ by gradient descent, e.g. $\frac{\partial E}{\partial W}=\sum_{t}\frac{\partial E_t}{\partial W}$.
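As a concrete illustration, here is a minimal NumPy sketch of this forward pass and loss; all dimensions, weights, and data below are invented for the example:

```python
import numpy as np

np.random.seed(0)

# Hypothetical dimensions: input size 3, hidden size 4, output size 5, 6 steps
n_in, n_h, n_out, T = 3, 4, 5, 6

U = np.random.randn(n_h, n_in) * 0.1
W = np.random.randn(n_h, n_h) * 0.1
V = np.random.randn(n_out, n_h) * 0.1

xs = [np.random.randn(n_in) for _ in range(T)]     # inputs x_t
ys = [np.random.randint(n_out) for _ in range(T)]  # target class indices y_t

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

# Forward pass: s_t = tanh(U x_t + W s_{t-1}), y_hat_t = softmax(V s_t)
s_prev = np.zeros(n_h)
E = 0.0
for t in range(T):
    s_t = np.tanh(U @ xs[t] + W @ s_prev)
    y_hat = softmax(V @ s_t)
    E += -np.log(y_hat[ys[t]])  # cross-entropy with one-hot y_t
    s_prev = s_t

print(E)  # total loss, summed over all time steps
```

Because each term $-\log \hat y_t[y_t]$ is positive, the total loss only reaches zero when every prediction is perfect.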

To understand BPTT better, let's derive the formulas:

Forward pass, step 1:

$$a_0 = x_0 \cdot u$$

$$b_0 = s_{-1} \cdot w$$

$$z_0 = a_0 + b_0 + k$$

$$s_0 = \mathrm{func}(z_0)$$ ($\mathrm{func}$ is sigmoid or tanh)

Forward pass, step 2:

$$a_1 = x_1 \cdot u$$

$$b_1 = s_0 \cdot w$$

$$z_1 = a_1 + b_1 + k$$

$$s_1 = \mathrm{func}(z_1)$$ ($\mathrm{func}$ is sigmoid or tanh)

$$q = s_1 \cdot v_1$$

In general, for any time step $t$:

$$z_t = u \cdot x_t + w \cdot s_{t-1} + k$$

$$s_t = \mathrm{func}(z_t)$$

Output layer:

$$o = \mathrm{func}(q)$$ ($\mathrm{func}$ is softmax)

$$E = \mathrm{func}(o)$$ ($\mathrm{func}$ is cross-entropy)
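The two-step scalar computation above can be written out directly. The numbers below are arbitrary, and a scalar sigmoid output stands in for softmax so the cross-entropy stays one-dimensional:

```python
import math

# Arbitrary illustrative values for the scalar toy model
u, w, v1, k = 0.5, 0.3, 0.8, 0.1
x0, x1 = 1.0, -0.5
s_minus1 = 0.0   # initial state s_{-1}
y = 1.0          # target (binary, for the sigmoid output)

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Forward pass, step 1
a0 = x0 * u
b0 = s_minus1 * w
z0 = a0 + b0 + k
s0 = math.tanh(z0)   # func = tanh

# Forward pass, step 2
a1 = x1 * u
b1 = s0 * w
z1 = a1 + b1 + k
s1 = math.tanh(z1)

# Output layer
q = s1 * v1
o = sigmoid(q)  # scalar stand-in for softmax
E = -(y * math.log(o) + (1 - y) * math.log(1 - o))  # cross-entropy
print(E)
```

Keeping every intermediate value ($a_t$, $b_t$, $z_t$, $s_t$, $q$, $o$) around is exactly what the backward pass below needs.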

The derivation for $u$ is as follows:

$$\partial E/\partial u = \partial E/\partial u_1 + \partial E/\partial u_0$$

$$\partial E/\partial u_1 = \partial E/\partial o \cdot \partial o/\partial q \cdot \partial q/\partial s_1 \cdot \partial s_1/\partial z_1 \cdot \partial z_1/\partial a_1 \cdot \partial a_1/\partial u_1$$

$$\partial E/\partial u_0 = \partial E/\partial o \cdot \partial o/\partial q \cdot \partial q/\partial s_1 \cdot \partial s_1/\partial z_1 \cdot \partial z_1/\partial b_1 \cdot \partial b_1/\partial s_0 \cdot \partial s_0/\partial z_0 \cdot \partial z_0/\partial a_0 \cdot \partial a_0/\partial u_0$$

Substituting $\partial q/\partial s_1 = v_1$ and the partials of the linear terms ($\partial z_1/\partial a_1 = 1$, $\partial a_1/\partial u_1 = x_1$, $\partial b_1/\partial s_0 = w_1$, $\partial a_0/\partial u_0 = x_0$, etc.):

$$\partial E/\partial u = \partial E/\partial o \cdot \partial o/\partial q \cdot v_1 \cdot \partial s_1/\partial z_1 \cdot ((1 \cdot x_1) + (1 \cdot w_1 \cdot \partial s_0/\partial z_0 \cdot 1 \cdot x_0))$$

$$\partial E/\partial u = \partial E/\partial o \cdot \partial o/\partial q \cdot v_1 \cdot \partial s_1/\partial z_1 \cdot (x_1 + w_1 \cdot \partial s_0/\partial z_0 \cdot x_0)$$

The derivation for the parameter $w$ is as follows:

$$\partial E/\partial w = \partial E/\partial o \cdot \partial o/\partial q \cdot v_1 \cdot \partial s_1/\partial z_1 \cdot (s_0 + w_1 \cdot \partial s_0/\partial z_0 \cdot s_{-1})$$
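A quick way to validate both closed-form gradients is to compare them against central finite differences on the same scalar model. The values are arbitrary, and the sigmoid output again stands in for softmax:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def forward(u, w, v1, k, x0, x1, s_m1, y):
    # Two-step scalar RNN: tanh activation, sigmoid output, cross-entropy loss
    z0 = x0 * u + s_m1 * w + k
    s0 = math.tanh(z0)
    z1 = x1 * u + s0 * w + k
    s1 = math.tanh(z1)
    o = sigmoid(s1 * v1)
    E = -(y * math.log(o) + (1 - y) * math.log(1 - o))
    return E, s0, s1, o

u, w, v1, k = 0.5, 0.3, 0.8, 0.1
x0, x1, s_m1, y = 1.0, -0.5, 0.0, 1.0

E, s0, s1, o = forward(u, w, v1, k, x0, x1, s_m1, y)

# Analytic gradients, term by term from the derivation above
dE_do = (o - y) / (o * (1 - o))  # cross-entropy
do_dq = o * (1 - o)              # sigmoid derivative
ds1_dz1 = 1 - s1 ** 2            # tanh derivative
ds0_dz0 = 1 - s0 ** 2
dE_du = dE_do * do_dq * v1 * ds1_dz1 * (x1 + w * ds0_dz0 * x0)
dE_dw = dE_do * do_dq * v1 * ds1_dz1 * (s0 + w * ds0_dz0 * s_m1)

# Central finite-difference check
eps = 1e-6
num_du = (forward(u + eps, w, v1, k, x0, x1, s_m1, y)[0] -
          forward(u - eps, w, v1, k, x0, x1, s_m1, y)[0]) / (2 * eps)
num_dw = (forward(u, w + eps, v1, k, x0, x1, s_m1, y)[0] -
          forward(u, w - eps, v1, k, x0, x1, s_m1, y)[0]) / (2 * eps)

print(abs(dE_du - num_du), abs(dE_dw - num_dw))  # both should be close to zero
```

If either difference is not tiny, a term in the chain rule was dropped or mis-ordered; this kind of gradient check is standard practice when implementing BPTT by hand.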

Summary

$$\dfrac{\partial E}{\partial u}=\sum_t \dfrac{\partial E}{\partial u_t} = \dfrac{\partial E}{\partial o} \dfrac{\partial o}{\partial s_1} \dfrac{\partial s_1}{\partial u_1}+\dfrac{\partial E}{\partial o} \dfrac{\partial o}{\partial s_1}\dfrac{\partial s_1}{\partial s_0}\dfrac{\partial s_0}{\partial u_0}$$

$$\dfrac{\partial E}{\partial w}=\sum_t \dfrac{\partial E}{\partial w_t} = \dfrac{\partial E}{\partial o} \dfrac{\partial o}{\partial s_1} \dfrac{\partial s_1}{\partial w_1}+\dfrac{\partial E}{\partial o} \dfrac{\partial o}{\partial s_1}\dfrac{\partial s_1}{\partial s_0}\dfrac{\partial s_0}{\partial w_0}$$
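The same sums extend to arbitrarily many time steps. Here is a compact NumPy BPTT sketch (shapes and data invented for the example) that accumulates the per-step gradients backward through time and spot-checks one entry against finite differences:

```python
import numpy as np

np.random.seed(1)
n_in, n_h, n_out, T = 3, 4, 5, 6
U = np.random.randn(n_h, n_in) * 0.1
W = np.random.randn(n_h, n_h) * 0.1
V = np.random.randn(n_out, n_h) * 0.1
xs = [np.random.randn(n_in) for _ in range(T)]
ys = [np.random.randint(n_out) for _ in range(T)]

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def loss(U, W, V):
    # Re-runs the forward pass; used for the finite-difference check
    s = np.zeros(n_h)
    E = 0.0
    for t in range(T):
        s = np.tanh(U @ xs[t] + W @ s)
        E += -np.log(softmax(V @ s)[ys[t]])
    return E

# Forward pass, caching all states: ss[t + 1] is s_t, ss[0] is s_{-1}
ss = [np.zeros(n_h)]
ps = []
for t in range(T):
    ss.append(np.tanh(U @ xs[t] + W @ ss[-1]))
    ps.append(softmax(V @ ss[-1]))

# Backward pass (BPTT)
dU, dW, dV = np.zeros_like(U), np.zeros_like(W), np.zeros_like(V)
ds_next = np.zeros(n_h)
for t in reversed(range(T)):
    dq = ps[t].copy()
    dq[ys[t]] -= 1.0                # dE_t/dq for softmax + cross-entropy
    dV += np.outer(dq, ss[t + 1])
    ds = V.T @ dq + ds_next         # gradient flowing into s_t
    dz = ds * (1 - ss[t + 1] ** 2)  # through tanh
    dU += np.outer(dz, xs[t])
    dW += np.outer(dz, ss[t])
    ds_next = W.T @ dz              # pass back to s_{t-1}

# Spot-check one entry of dW against central finite differences
eps = 1e-6
Wp, Wm = W.copy(), W.copy()
Wp[0, 0] += eps
Wm[0, 0] -= eps
num = (loss(U, Wp, V) - loss(U, Wm, V)) / (2 * eps)
print(abs(dW[0, 0] - num))  # should be close to zero
```

Note how `ds_next` carries the gradient from time step $t+1$ back into $s_t$; this is exactly the $\partial s_1/\partial s_0$ factor in the summary equations, applied repeatedly.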

Here $x_t$ is the input at time $t$.

Figure 1.28.4.1 - Many-to-one RNN

To learn more about RNNs, see the RNN chapter of Goodfellow et al. and Andrej Karpathy's minimal character-level RNN implementation.
