Featured image of post 反向传播与计算图

反向传播与计算图

前向传播

前向传播

前向传播(forward propagation):按顺序(输入层\to输出层)计算和存储神经网络中每层的结果。

讨论示例: 预设:隐藏层不包括偏置项

  • 输入样本xRd\mathbf{x}\in \mathbb{R}^d
  • 隐藏层的权重参数W(1)Rh×d\mathbf{W}^{(1)} \in \mathbb{R}^{h \times d}
  • 中间变量z=W(1)x\mathbf{z}= \mathbf{W}^{(1)} \mathbf{x}
    • zRh\mathbf{z}\in \mathbb{R}^h
  • 激活函数ϕ\phi
  • 隐藏激活向量h=ϕ(z)\mathbf{h}= \phi (\mathbf{z})
    • hRh\mathbf{h}\in \mathbb{R}^h
  • 输出层的权重参数W(2)Rq×h\mathbf{W}^{(2)} \in \mathbb{R}^{q \times h}
  • 输出层变量o=W(2)h\mathbf{o}= \mathbf{W}^{(2)} \mathbf{h}
    • oRq\mathbf{o}\in \mathbb{R}^q
  • 单个数据样本的损失项L=l(o,y)L = l(\mathbf{o}, y)
    • 损失函数ll
    • 样本标签yy
  • 正则化项s=λ2(W(1)F2+W(2)F2)s = \frac{\lambda}{2} \left(\|\mathbf{W}^{(1)}\|_F^2 + \|\mathbf{W}^{(2)}\|_F^2\right)
    • L2L_2正则化
    • 超参数λ\lambda
  • 模型在给定数据样本上的正则化损失J=L+sJ = L + s
    • 即此时的目标函数(objective function)。

  • 正方形表示变量,圆圈表示操作符。
  • 左下角表示输入,右上角表示输出。
  • 箭头显示数据流的方向,主要是向右和向上。

反向传播

反向传播(backward propagation):计算神经网络参数梯度的方法。 根据微积分中的链式法则,按相反的顺序从输出层到输入层遍历网络。

  • 存储计算某些参数梯度时所需的任何中间变量(偏导数)
  • prod\text{prod}运算符:在执行必要的操作(如换位和交换输入位置)后将其参数相乘。
  • 目的:计算梯度JW(1)\frac{\partial J}{\partial \mathbf{W}^{(1)}}JW(2)\frac{\partial J}{\partial \mathbf{W}^{(2)}}
    • 先计算距离输出层更近的JW(2)\frac{\partial J}{\partial \mathbf{W}^{(2)}}
JW(2)=JLLW(2)+JssW(2)=1LW(2)+ 1sW(2)=LooW(2)+λW(2)=Loh+λW(2) \begin{align*} \frac{\partial J}{\partial \mathbf{W}^{(2)}} &= \frac{\partial J}{\partial L}\cdot \frac{\partial L}{\partial \mathbf{W}^{(2)}} +\frac{\partial J}{\partial s}\cdot \frac{\partial s}{\partial \mathbf{W}^{(2)}} \\ &= 1\cdot \frac{\partial L}{\partial \mathbf{W}^{(2)}} +\ 1\cdot \frac{\partial s}{\partial \mathbf{W}^{(2)}} \\ &= \frac{\partial L}{\partial \mathbf{o}} \cdot \frac{\partial \mathbf{o}}{\partial \mathbf{W}^{(2)}} + \lambda \mathbf{W}^{(2)}\\ &=\frac{\partial L}{\partial \mathbf{o}} \cdot \mathbf{h}^\top + \lambda \mathbf{W}^{(2)}\\ \end{align*} JW(1)=JLLW(1)+JssW(1)=1LW(1)+ 1sW(1)=LooW(1)+λW(1)=LoohhW(1)+λW(1)=(W(2)Lo)hW(1)+λW(1)=(W(2)Lo)hzzW(1)+λW(1)=(W(2)Lo)ϕ(z)x+λW(1) \begin{align*} \frac{\partial J}{\partial \mathbf{W}^{(1)}} &= \frac{\partial J}{\partial L}\cdot \frac{\partial L}{\partial \mathbf{W}^{(1)}} +\frac{\partial J}{\partial s}\cdot \frac{\partial s}{\partial \mathbf{W}^{(1)}} \\ &= 1\cdot \frac{\partial L}{\partial \mathbf{W}^{(1)}} +\ 1\cdot \frac{\partial s}{\partial \mathbf{W}^{(1)}} \\ &= \frac{\partial L}{\partial \mathbf{o}} \cdot \frac{\partial \mathbf{o}}{\partial \mathbf{W}^{(1)}} + \lambda \mathbf{W}^{(1)}\\ &=\frac{\partial L}{\partial \mathbf{o}} \cdot \frac{\partial \mathbf{o}}{\partial \mathbf{h}} \cdot \frac{\partial \mathbf{h}}{\partial \mathbf{W}^{(1)}} + \lambda \mathbf{W}^{(1)}\\ &=({\mathbf{W}^{(2)}}^\top \cdot \frac{\partial L}{\partial \mathbf{o}}) \cdot \frac{\partial \mathbf{h}}{\partial \mathbf{W}^{(1)}} + \lambda \mathbf{W}^{(1)}\\ &=({\mathbf{W}^{(2)}}^\top \cdot \frac{\partial L}{\partial \mathbf{o}}) \cdot \frac{\partial \mathbf{h}}{\partial \mathbf{z}} \cdot \frac{\partial \mathbf{z}}{\partial \mathbf{W}^{(1)}} + \lambda \mathbf{W}^{(1)}\\ &=({\mathbf{W}^{(2)}}^\top \cdot \frac{\partial L}{\partial \mathbf{o}}) \odot \phi'\left(\mathbf{z}\right) \cdot \mathbf{x}^\top + \lambda \mathbf{W}^{(1)}\\ \end{align*}

由于JL=1\frac{\partial J}{\partial L}=1,故Lh\frac{\partial L}{\partial \mathbf{h}}在数值上等于Jh\frac{\partial J}{\partial \mathbf{h}},下述Jh\frac{\partial J}{\partial \mathbf{h}}Jz\frac{\partial J}{\partial \mathbf{z}}按此对应上述推导。

激活函数ϕ\phi是按元素计算的,计算中间变量z\mathbf{z}的梯度JzRh\frac{\partial J}{\partial \mathbf{z}} \in \mathbb{R}^h 需要使用按元素乘法运算符,用\odot表示:

Jz=prod(Jh,hz)=Jhϕ(z) \frac{\partial J}{\partial \mathbf{z}} = \text{prod}\left(\frac{\partial J}{\partial \mathbf{h}}, \frac{\partial \mathbf{h}}{\partial \mathbf{z}}\right) = \frac{\partial J}{\partial \mathbf{h}} \odot \phi'\left(\mathbf{z}\right)

隐藏层输出的梯度JhRh\frac{\partial J}{\partial \mathbf{h}} \in \mathbb{R}^h由下式给出:

Jh=prod(Jo,oh)=W(2)Jo \frac{\partial J}{\partial \mathbf{h}} = \text{prod}\left(\frac{\partial J}{\partial \mathbf{o}}, \frac{\partial \mathbf{o}}{\partial \mathbf{h}}\right) = {\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}

JJ 是一个标量(损失函数的输出),o\mathbf{o}qq 维向量,按照多元微积分,Jo\frac{\partial J}{\partial \mathbf{o}}形状和 o\mathbf{o} 一致,即 JoRq\frac{\partial J}{\partial \mathbf{o}} \in \mathbb{R}^q;又W(2)Rq×h\mathbf{W}^{(2)} \in \mathbb{R}^{q \times h},故:

W(2)JoRh {{\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}} \in \mathbb{R}^h

W(2)Jo{{\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}}JoW(2){\frac{\partial J}{\partial \mathbf{o}} \cdot \mathbf{W}^{(2)}}

反向传播的推导中,更常用W(2)Jo{{\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}} 的写法,主要考虑以下:

  1. 链式法则 通过上一层的梯度左乘该层权重的转置,得到每一层的梯度
  2. 与前向传播的结构对偶
    • 前向传播:o=W(2)h\mathbf{o} = \mathbf{W}^{(2)} \mathbf{h}
    • 反向传播:Jh=W(2)Jo\frac{\partial J}{\partial \mathbf{h}} = {\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}

对于W(2)Rq×h\mathbf{W}^{(2)} \in \mathbb{R}^{q \times h},为了在计算Jh\frac{\partial J}{\partial \mathbf{h}}时形状匹配:

  • Jh=W(2)Jo\frac{\partial J}{\partial \mathbf{h}} = {{\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}}:应将Jo\frac{\partial J}{\partial \mathbf{o}} 视作是 q×1q \times 1 的列向量(Rq×1\mathbb{R}^{q \times 1}),得到JoW(2)Rh×1\frac{\partial J}{\partial \mathbf{o}} \cdot \mathbf{W}^{(2)} \in \mathbb{R}^{h \times 1}
  • Jh=JoW(2)\frac{\partial J}{\partial \mathbf{h}} = {\frac{\partial J}{\partial \mathbf{o}} \cdot \mathbf{W}^{(2)}}:应将Jo\frac{\partial J}{\partial \mathbf{o}} 视作是 1×q1 \times q 的行向量(R1×q\mathbb{R}^{1 \times q}),得到JoW(2)R1×h\frac{\partial J}{\partial \mathbf{o}} \cdot \mathbf{W}^{(2)} \in \mathbb{R}^{1 \times h}

区分行/列向量,不是改变变量的本质维度,是为了在矩阵乘法等操作时形状能严格对齐。

但在具体的编程实现时,W(2)Jo{{\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}}JoW(2){\frac{\partial J}{\partial \mathbf{o}} \cdot \mathbf{W}^{(2)}}很可能一致。

对于表述JoRq\frac{\partial J}{\partial \mathbf{o}} \in \mathbb{R}^q,只表示“qq 维实向量”,不区分是行向量还是列向量,也不强调是1维还是2维(抽象意义上是1维向量);而例如JoRq×1\frac{\partial J}{\partial \mathbf{o}} \in \mathbb{R}^{q \times 1}则明确表示“qq11 列的矩阵”,也就是列向量,在数学和编程实现中是2维的。

在实际编程(如 NumPy、PyTorch)中,

  • Rq\mathbb{R}^q 通常对应 shape 为 (q,) 的一维数组(向量)
  • Rq×1\mathbb{R}^{q \times 1} 对应 shape 为 (q, 1) 的二维数组(列向量)
  • R1×q\mathbb{R}^{1 \times q} 对应 shape 为 (1, q) 的二维数组(行向量)

Jo\frac{\partial J}{\partial \mathbf{o}} 的 shape 取决于你的实现方式和数据批量:

  • 单样本(无batch):

    • o\mathbf{o}qq 维向量,通常 shape 为 (q,)(一维数组),也可能是 (q, 1)(二维列向量)
    • 此时,Jo\frac{\partial J}{\partial \mathbf{o}} 的 shape 通常为 (q,)(q, 1)
  • 批量(batch):

    • o\mathbf{o} 是 shape (batch_size, q),即每一行对应一个样本的输出
    • 此时,Jo\frac{\partial J}{\partial \mathbf{o}} 的 shape 为 (batch_size, q)

多数情况我们会使用batch,借助PyTorch 的 @ 运算(也即matmul),自动对第一个维度(batch 维)进行广播和批量矩阵乘法。此时,不论是通过W(2)Jo{{\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}}还是通过JoW(2){\frac{\partial J}{\partial \mathbf{o}} \cdot \mathbf{W}^{(2)}},均使用dJ_dh = dJ_do @ W2语句得到Jh\frac{\partial J}{\partial \mathbf{h}}

内存开销

反向传播重复利用前向传播中存储的中间值,以避免重复计算。影响之一是需要保留中间值,直到反向传播完成。这也是训练比单纯的预测需要更多的内存(显存)的原因之一。 此外,这些中间值的大小与网络层的数量和批量的大小大致成正比,因此,使用更大的批量来训练更深层次的网络更容易导致内存不足(out of memory)。

使用 Hugo 构建
主题 StackJimmy 设计