Forward propagation: compute and store the results of each layer of a neural network in order, from the input layer to the output layer.
Worked example:
Assumption: the hidden layer has no bias term.
- Input sample $\mathbf{x} \in \mathbb{R}^d$
- Hidden-layer weights $\mathbf{W}^{(1)} \in \mathbb{R}^{h \times d}$
- Intermediate variable $\mathbf{z} = \mathbf{W}^{(1)} \mathbf{x}$, with $\mathbf{z} \in \mathbb{R}^h$
- Activation function $\phi$
- Hidden activation $\mathbf{h} = \phi(\mathbf{z})$, with $\mathbf{h} \in \mathbb{R}^h$
- Output-layer weights $\mathbf{W}^{(2)} \in \mathbb{R}^{q \times h}$
- Output variable $\mathbf{o} = \mathbf{W}^{(2)} \mathbf{h}$, with $\mathbf{o} \in \mathbb{R}^q$
- Loss for a single data sample $L = l(\mathbf{o}, y)$
- Regularization term $s = \frac{\lambda}{2} \left(\|\mathbf{W}^{(1)}\|_F^2 + \|\mathbf{W}^{(2)}\|_F^2\right)$
- Regularized loss on the given sample $J = L + s$, which is the objective function here.
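A minimal NumPy sketch of this forward pass, just to make the shapes concrete. The sizes `d`, `h`, `q`, the value of `lam`, the ReLU activation, and the cross-entropy loss are illustrative assumptions, not part of the setup above.

```python
import numpy as np

# Illustrative sizes; d, h, q are not fixed by the setup above
d, h, q = 4, 3, 2
rng = np.random.default_rng(0)

x = rng.normal(size=d)            # input sample, x ∈ R^d
W1 = rng.normal(size=(h, d))      # hidden-layer weights, W1 ∈ R^{h×d}
W2 = rng.normal(size=(q, h))      # output-layer weights, W2 ∈ R^{q×h}
lam = 1e-2                        # regularization strength λ (assumed value)

phi = lambda v: np.maximum(v, 0)  # assume φ = ReLU for concreteness

z = W1 @ x                        # intermediate variable, z ∈ R^h
h_act = phi(z)                    # hidden activation, h ∈ R^h
o = W2 @ h_act                    # output variable, o ∈ R^q

y = 0                                           # class label (assumed)
L = -o[y] + np.log(np.exp(o).sum())             # e.g. a cross-entropy loss l(o, y)
s = lam / 2 * ((W1**2).sum() + (W2**2).sum())   # Frobenius-norm regularizer
J = L + s                                       # objective function
```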
Computation graph: squares denote variables and circles denote operators; the input sits at the lower left and the output at the upper right; arrows show the direction of data flow, mostly to the right and upward.
Backward propagation (backpropagation): the method for computing the gradients of the neural network's parameters.
Following the chain rule from calculus, it traverses the network in reverse order, from the output layer to the input layer, storing any intermediate variables (partial derivatives) needed to compute the gradients of certain parameters.
The $\text{prod}$ operator multiplies its arguments after performing the necessary operations (such as transposition and swapping input positions).
Goal: compute the gradients $\frac{\partial J}{\partial \mathbf{W}^{(1)}}$ and $\frac{\partial J}{\partial \mathbf{W}^{(2)}}$. First compute $\frac{\partial J}{\partial \mathbf{W}^{(2)}}$, which is closer to the output layer:
\begin{align*}
\frac{\partial J}{\partial \mathbf{W}^{(2)}}
&= \frac{\partial J}{\partial L}\cdot \frac{\partial L}{\partial \mathbf{W}^{(2)}}
+\frac{\partial J}{\partial s}\cdot \frac{\partial s}{\partial \mathbf{W}^{(2)}} \\
&= 1\cdot \frac{\partial L}{\partial \mathbf{W}^{(2)}}
+\ 1\cdot \frac{\partial s}{\partial \mathbf{W}^{(2)}} \\
&= \frac{\partial L}{\partial \mathbf{o}} \cdot \frac{\partial \mathbf{o}}{\partial \mathbf{W}^{(2)}}
+ \lambda \mathbf{W}^{(2)}\\
&=\frac{\partial L}{\partial \mathbf{o}}
\cdot \mathbf{h}^\top
+ \lambda \mathbf{W}^{(2)}\\
\end{align*}
Then compute $\frac{\partial J}{\partial \mathbf{W}^{(1)}}$, which is farther from the output layer:
\begin{align*}
\frac{\partial J}{\partial \mathbf{W}^{(1)}}
&= \frac{\partial J}{\partial L}\cdot \frac{\partial L}{\partial \mathbf{W}^{(1)}}
+\frac{\partial J}{\partial s}\cdot \frac{\partial s}{\partial \mathbf{W}^{(1)}} \\
&= 1\cdot \frac{\partial L}{\partial \mathbf{W}^{(1)}}
+\ 1\cdot \frac{\partial s}{\partial \mathbf{W}^{(1)}} \\
&= \frac{\partial L}{\partial \mathbf{o}} \cdot \frac{\partial \mathbf{o}}{\partial \mathbf{W}^{(1)}}
+ \lambda \mathbf{W}^{(1)}\\
&=\frac{\partial L}{\partial \mathbf{o}}
\cdot \frac{\partial \mathbf{o}}{\partial \mathbf{h}}
\cdot \frac{\partial \mathbf{h}}{\partial \mathbf{W}^{(1)}}
+ \lambda \mathbf{W}^{(1)}\\
&=({\mathbf{W}^{(2)}}^\top \cdot \frac{\partial L}{\partial \mathbf{o}})
\cdot \frac{\partial \mathbf{h}}{\partial \mathbf{W}^{(1)}}
+ \lambda \mathbf{W}^{(1)}\\
&=({\mathbf{W}^{(2)}}^\top \cdot \frac{\partial L}{\partial \mathbf{o}})
\cdot \frac{\partial \mathbf{h}}{\partial \mathbf{z}}
\cdot \frac{\partial \mathbf{z}}{\partial \mathbf{W}^{(1)}}
+ \lambda \mathbf{W}^{(1)}\\
&=({\mathbf{W}^{(2)}}^\top \cdot \frac{\partial L}{\partial \mathbf{o}})
\odot \phi'\left(\mathbf{z}\right)
\cdot \mathbf{x}^\top
+ \lambda \mathbf{W}^{(1)}\\
\end{align*}
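As a sanity check of the two formulas just derived, the following sketch compares the manual gradients with PyTorch autograd. Here $\phi = \mathrm{ReLU}$, the squared-error loss $l(\mathbf{o}, y) = \frac{1}{2}\|\mathbf{o} - \mathbf{y}\|^2$ (so that $\frac{\partial L}{\partial \mathbf{o}} = \mathbf{o} - \mathbf{y}$), and all concrete sizes are assumptions made only for this example.

```python
import torch

d, h, q, lam = 4, 3, 2, 1e-2   # illustrative sizes and λ
torch.manual_seed(0)

x = torch.randn(d)
y = torch.randn(q)                               # target for the assumed squared-error loss
W1 = torch.randn(h, d, requires_grad=True)
W2 = torch.randn(q, h, requires_grad=True)

# Forward pass, assuming φ = ReLU
z = W1 @ x
h_act = torch.relu(z)
o = W2 @ h_act
L = 0.5 * ((o - y) ** 2).sum()                   # l(o, y), so that ∂L/∂o = o − y
s = lam / 2 * (W1.pow(2).sum() + W2.pow(2).sum())
J = L + s
J.backward()

# Manual gradients following the derivation above
dL_do = (o - y).detach()                         # ∂L/∂o ∈ R^q
dJ_dW2 = torch.outer(dL_do, h_act.detach()) + lam * W2.detach()
dJ_dh = W2.detach().T @ dL_do                    # W^(2)ᵀ · ∂L/∂o
dJ_dz = dJ_dh * (z.detach() > 0).float()         # ⊙ φ'(z), ReLU derivative
dJ_dW1 = torch.outer(dJ_dz, x) + lam * W1.detach()

print(torch.allclose(dJ_dW2, W2.grad), torch.allclose(dJ_dW1, W1.grad))
```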
Since $\frac{\partial J}{\partial L} = 1$, $\frac{\partial L}{\partial \mathbf{h}}$ is numerically equal to $\frac{\partial J}{\partial \mathbf{h}}$; the $\frac{\partial J}{\partial \mathbf{h}}$ and $\frac{\partial J}{\partial \mathbf{z}}$ below correspond to the derivation above in this sense.
Because the activation function $\phi$ is applied elementwise, computing the gradient of the intermediate variable $\mathbf{z}$, $\frac{\partial J}{\partial \mathbf{z}} \in \mathbb{R}^h$, requires the elementwise multiplication operator, denoted by $\odot$:
\frac{\partial J}{\partial \mathbf{z}}
= \text{prod}\left(\frac{\partial J}{\partial \mathbf{h}}, \frac{\partial \mathbf{h}}{\partial \mathbf{z}}\right)
= \frac{\partial J}{\partial \mathbf{h}} \odot \phi'\left(\mathbf{z}\right)
The gradient of the hidden-layer output, $\frac{\partial J}{\partial \mathbf{h}} \in \mathbb{R}^h$, is given by:
\frac{\partial J}{\partial \mathbf{h}}
= \text{prod}\left(\frac{\partial J}{\partial \mathbf{o}}, \frac{\partial \mathbf{o}}{\partial \mathbf{h}}\right)
= {\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}
$J$ is a scalar (the output of the loss function) and $\mathbf{o}$ is a $q$-dimensional vector, so by multivariable calculus $\frac{\partial J}{\partial \mathbf{o}}$ has the same shape as $\mathbf{o}$, i.e. $\frac{\partial J}{\partial \mathbf{o}} \in \mathbb{R}^q$; since $\mathbf{W}^{(2)} \in \mathbb{R}^{q \times h}$, it follows that:
{{\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}} \in \mathbb{R}^h
Between ${\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}$ and $\frac{\partial J}{\partial \mathbf{o}} \cdot \mathbf{W}^{(2)}$, derivations of backpropagation more commonly use the form ${\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}$, mainly for the following reasons:
- Chain rule: each layer's gradient is obtained by left-multiplying the incoming gradient by the transpose of that layer's weights.
- Duality with the structure of forward propagation: forward, $\mathbf{o} = \mathbf{W}^{(2)} \mathbf{h}$; backward, $\frac{\partial J}{\partial \mathbf{h}} = {\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}$.

For $\mathbf{W}^{(2)} \in \mathbb{R}^{q \times h}$, to make the shapes match when computing $\frac{\partial J}{\partial \mathbf{h}}$:
- $\frac{\partial J}{\partial \mathbf{h}} = {\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}$: treat $\frac{\partial J}{\partial \mathbf{o}}$ as a $q \times 1$ column vector ($\mathbb{R}^{q \times 1}$), giving ${\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}} \in \mathbb{R}^{h \times 1}$.
- $\frac{\partial J}{\partial \mathbf{h}} = \frac{\partial J}{\partial \mathbf{o}} \cdot \mathbf{W}^{(2)}$: treat $\frac{\partial J}{\partial \mathbf{o}}$ as a $1 \times q$ row vector ($\mathbb{R}^{1 \times q}$), giving $\frac{\partial J}{\partial \mathbf{o}} \cdot \mathbf{W}^{(2)} \in \mathbb{R}^{1 \times h}$.

Distinguishing row vectors from column vectors does not change a variable's intrinsic dimensionality; it only ensures that shapes align exactly in matrix multiplication and similar operations.
In a concrete programming implementation, however, ${\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}$ and $\frac{\partial J}{\partial \mathbf{o}} \cdot \mathbf{W}^{(2)}$ will most likely give the same result.
The statement $\frac{\partial J}{\partial \mathbf{o}} \in \mathbb{R}^q$ only says "a $q$-dimensional real vector"; it does not distinguish row from column vectors, nor does it specify 1-D versus 2-D storage (abstractly it is a 1-D vector). By contrast, $\frac{\partial J}{\partial \mathbf{o}} \in \mathbb{R}^{q \times 1}$ explicitly means "a matrix with $q$ rows and $1$ column", i.e. a column vector, which is 2-D both in mathematics and in code.
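A small NumPy illustration of these conventions (the concrete matrix `W2` and gradient vector `g` are arbitrary): for a bare shape-`(q,)` gradient the two forms coincide numerically, while explicit column/row vectors differ only in shape.

```python
import numpy as np

q, h_dim = 2, 3
W2 = np.arange(q * h_dim, dtype=float).reshape(q, h_dim)   # W2 ∈ R^{q×h}
g = np.array([1.0, -2.0])                                  # ∂J/∂o as a "bare" vector, shape (q,)

# Treated as an abstract R^q vector (shape (q,)), both forms give the same numbers:
a = W2.T @ g              # W2ᵀ · ∂J/∂o  -> shape (h,)
b = g @ W2                # ∂J/∂o · W2   -> shape (h,)
print(np.allclose(a, b))  # True

# With explicit 2-D column / row vectors the shapes differ, but the entries agree:
col = g.reshape(q, 1)         # R^{q×1}, column vector
row = g.reshape(1, q)         # R^{1×q}, row vector
print((W2.T @ col).shape)     # (3, 1), i.e. (h, 1)
print((row @ W2).shape)       # (1, 3), i.e. (1, h)
```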
In practical programming (e.g. NumPy, PyTorch):
- $\mathbb{R}^q$ usually corresponds to a 1-D array (vector) of shape `(q,)`
- $\mathbb{R}^{q \times 1}$ corresponds to a 2-D array (column vector) of shape `(q, 1)`
- $\mathbb{R}^{1 \times q}$ corresponds to a 2-D array (row vector) of shape `(1, q)`

The shape of $\frac{\partial J}{\partial \mathbf{o}}$ depends on your implementation and on whether the data is batched:
- Single sample (no batch): $\mathbf{o}$ is a $q$-dimensional vector, usually of shape `(q,)` (a 1-D array), possibly `(q, 1)` (a 2-D column vector). The shape of $\frac{\partial J}{\partial \mathbf{o}}$ is then usually `(q,)` or `(q, 1)`.
- Batched: $\mathbf{o}$ has shape `(batch_size, q)`, with each row corresponding to one sample's output. The shape of $\frac{\partial J}{\partial \mathbf{o}}$ is then `(batch_size, q)`.
In most cases we work with batches and rely on PyTorch's `@` operator (i.e. `matmul`), which handles the batch dimension automatically through broadcasting and batched matrix multiplication. Then, whether we think of the computation as ${\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}$ or as $\frac{\partial J}{\partial \mathbf{o}} \cdot \mathbf{W}^{(2)}$, the statement `dJ_dh = dJ_do @ W2` yields $\frac{\partial J}{\partial \mathbf{h}}$ (see the batched sketch below).
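A minimal batched sketch in PyTorch, assuming `dJ_do` has already been computed with shape `(batch_size, q)`; the sizes are arbitrary.

```python
import torch

batch_size, q, h_dim = 5, 2, 3
W2 = torch.randn(q, h_dim)            # W2 ∈ R^{q×h}
dJ_do = torch.randn(batch_size, q)    # one gradient row per sample

dJ_dh = dJ_do @ W2                    # (batch_size, q) @ (q, h) -> (batch_size, h)
print(dJ_dh.shape)                    # torch.Size([5, 3])
```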
Memory overhead: backpropagation reuses the intermediate values stored during forward propagation to avoid recomputing them. One consequence is that these intermediate values must be kept until backpropagation finishes; this is one reason training needs more memory (GPU memory) than plain prediction.
Moreover, the size of these intermediate values is roughly proportional to the number of network layers and to the batch size, so training deeper networks with larger batches more easily leads to running out of memory.
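A rough back-of-the-envelope sketch of this scaling; the layer widths, batch sizes, and float32 storage are assumptions for illustration only.

```python
# Rough estimate of activation memory kept for the backward pass:
# each layer stores one (batch_size × width) float32 activation.
def activation_memory_mb(batch_size, widths, bytes_per_value=4):
    return batch_size * sum(widths) * bytes_per_value / 1e6

widths = [1024] * 10                          # 10 hidden layers of width 1024 (assumed)
for bs in (32, 256, 2048):
    print(bs, round(activation_memory_mb(bs, widths), 1), "MB")
# Memory grows linearly with both depth (len(widths)) and batch size.
```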