0 Preface
When I first encountered neural networks, I wonder whether others were, like me, hazy about how gradients are actually computed. You may know in broad strokes that the chain rule is involved, but ordinary calculus mostly covers derivatives of single-variable composite functions, so the concrete details of differentiating with respect to a parameter matrix, as required in neural networks, remain completely opaque.
This article starts from the differentiation of multivariate composite functions (the essential foundation of matrix calculus), cuts in through simple examples, and lays out the complete details of a matrix derivative. I hope it helps anyone who wants to understand the details of gradient computation in neural networks.
1 Differentiating Multivariate Composite Functions
We begin with the differentiation of multivariate composite functions, the essential foundation of matrix calculus.
1.1 Setting Up the Problem
Problem: consider three functions \(z = f(u, v)\), \(u = g(x, y)\), and \(v = h(x, y)\), with all partial derivatives assumed to exist. Our goal is to find the partial derivative of \(z\) with respect to \(x\), \(\frac{\partial z}{\partial x}\), and the partial derivative of \(z\) with respect to \(y\), \(\frac{\partial z}{\partial y}\).
Solution: we first sketch the dependency graph among the variables: \(z\) depends on \(u\) and \(v\), and each of \(u\) and \(v\) in turn depends on both \(x\) and \(y\).
This dependency graph is extremely helpful for sorting out the relationships among the variables, and in turn the details of the chain rule. From the dependency graph and the chain rule, we have
\(\frac{\partial z}{\partial x} = \frac{\partial z}{\partial u} \frac{\partial u}{\partial x} + \frac{\partial z}{\partial v} \frac{\partial v}{\partial x}\)
\(\frac{\partial z}{\partial y} = \frac{\partial z}{\partial u} \frac{\partial u}{\partial y} + \frac{\partial z}{\partial v} \frac{\partial v}{\partial y}\)
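As a quick symbolic check (my addition, not part of the original derivation, and assuming SymPy is installed), differentiating the composite \(z\) directly reproduces exactly this two-term sum:

```python
# Minimal SymPy sketch: differentiating the composite
# z = f(g(x, y), h(x, y)) reproduces the two-term chain rule.
import sympy as sp

x, y = sp.symbols('x y')
u = sp.Function('g')(x, y)   # u = g(x, y)
v = sp.Function('h')(x, y)   # v = h(x, y)
z = sp.Function('f')(u, v)   # z = f(u, v)

# Each print shows dz/du * du/dx + dz/dv * dv/dx (and likewise for y),
# rendered in SymPy's Derivative/Subs notation.
print(sp.diff(z, x))
print(sp.diff(z, y))
```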
1.2 An Example
Let's work through a more concrete example:
Problem: suppose \(z = u^2 + v^2\), where \(u = x + y\) and \(v = x - y\). Find the partial derivative of \(z\) with respect to \(x\), \(\frac{\partial z}{\partial x}\), and the partial derivative of \(z\) with respect to \(y\), \(\frac{\partial z}{\partial y}\).
Solution: the partial derivatives of \(z\) with respect to \(u\) and \(v\) are
\(\frac{\partial z}{\partial u} = 2u\)
\(\frac{\partial z}{\partial v} = 2v\)
The partial derivatives of \(u\) with respect to \(x\) and \(y\) are
\(\frac{\partial u}{\partial x} = 1\)
\(\frac{\partial u}{\partial y} = 1\)
The partial derivatives of \(v\) with respect to \(x\) and \(y\) are
\(\frac{\partial v}{\partial x} = 1\)
\(\frac{\partial v}{\partial y} = -1\)
Then:
\(\frac{\partial z}{\partial x} = \frac{\partial z}{\partial u} \frac{\partial u}{\partial x} + \frac{\partial z}{\partial v} \frac{\partial v}{\partial x} = 2u + 2v\)
\(\frac{\partial z}{\partial y} = \frac{\partial z}{\partial u} \frac{\partial u}{\partial y} + \frac{\partial z}{\partial v} \frac{\partial v}{\partial y} = 2u - 2v\)
As a consistency check, substituting \(u = x + y\) and \(v = x - y\) gives \(\frac{\partial z}{\partial x} = 4x\) and \(\frac{\partial z}{\partial y} = 4y\), which matches differentiating \(z = (x+y)^2 + (x-y)^2 = 2x^2 + 2y^2\) directly.
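These formulas are also easy to confirm numerically. Here is a small sketch (not from the original article; it assumes NumPy is available) comparing the chain-rule results against central finite differences:

```python
# Numeric check of the example above via central finite differences.
import numpy as np

def z(x, y):
    u, v = x + y, x - y
    return u**2 + v**2

x0, y0, eps = 1.5, -0.7, 1e-6
dz_dx = (z(x0 + eps, y0) - z(x0 - eps, y0)) / (2 * eps)
dz_dy = (z(x0, y0 + eps) - z(x0, y0 - eps)) / (2 * eps)

u0, v0 = x0 + y0, x0 - y0
print(np.isclose(dz_dx, 2 * u0 + 2 * v0))  # True: matches 2u + 2v = 4x
print(np.isclose(dz_dy, 2 * u0 - 2 * v0))  # True: matches 2u - 2v = 4y
```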
2 Gradient Computation
2.1 The Problem
Every neural network algorithm involves computing gradients of the parameters, that is, the gradient of a loss function \(L\) with respect to a parameter matrix \(W\). Here is an example of this kind of problem:
Problem: consider \(L = (y_{11} - a_{11})^2 + (y_{12} - a_{12})^2 + (y_{21} - a_{21})^2 + (y_{22} - a_{22})^2\)
where
\(\begin{bmatrix}
y_{11} & y_{12} \\
y_{21} & y_{22}
\end{bmatrix}
=
\begin{bmatrix}
w_{11} & w_{12} \\
w_{21} & w_{22}
\end{bmatrix}
\begin{bmatrix}
x_{11} & x_{12} \\
x_{21} & x_{22}
\end{bmatrix}
\quad (Y = WX)\)
Find the gradient of \(L\) with respect to \(W\), \(\frac{\partial L}{\partial W}\).
2.2 Reformulating the Problem
At its core, the problem above is still the differentiation of a multivariate composite function; it is a simple extension of the example we began with. If we adjust how the problem is presented, it becomes:
Consider \(L = (y_{11} - a_{11})^2 + (y_{12} - a_{12})^2 + (y_{21} - a_{21})^2 + (y_{22} - a_{22})^2\)
where
\(y_{11} = w_{11}x_{11} + w_{12}x_{21}\)
\(y_{12} = w_{11}x_{12} + w_{12}x_{22}\)
\(y_{21} = w_{21}x_{11} + w_{22}x_{21}\)
\(y_{22} = w_{21}x_{12} + w_{22}x_{22}\)
Find:
\(\begin{bmatrix}\frac{\partial L}{\partial w_{11}} & \frac{\partial L}{\partial w_{12}} \\
\frac{\partial L}{\partial w_{21}} & \frac{\partial L}{\partial w_{22}}
\end{bmatrix}\)
2.3 Solving the Problem
We noted earlier that a dependency graph is very helpful for sorting out how variables depend on one another. In the function above, \(L\), the \(y_{ij}\), and the \(w_{ij}\) are related as follows: \(L\) depends on all four \(y_{ij}\); \(y_{11}\) and \(y_{12}\) depend on \(w_{11}\) and \(w_{12}\), while \(y_{21}\) and \(y_{22}\) depend on \(w_{21}\) and \(w_{22}\).
From these dependencies, we have
\(\frac{\partial L}{\partial w_{11}} = \frac{\partial L}{\partial y_{11}} \frac{\partial y_{11}}{\partial w_{11}} + \frac{\partial L}{\partial y_{12}} \frac{\partial y_{12}}{\partial w_{11}}\)
\(\frac{\partial L}{\partial w_{12}} = \frac{\partial L}{\partial y_{11}} \frac{\partial y_{11}}{\partial w_{12}} + \frac{\partial L}{\partial y_{12}} \frac{\partial y_{12}}{\partial w_{12}}\)
\(\frac{\partial L}{\partial w_{21}} = \frac{\partial L}{\partial y_{21}} \frac{\partial y_{21}}{\partial w_{21}} + \frac{\partial L}{\partial y_{22}} \frac{\partial y_{22}}{\partial w_{21}}\)
\(\frac{\partial L}{\partial w_{22}} = \frac{\partial L}{\partial y_{21}} \frac{\partial y_{21}}{\partial w_{22}} + \frac{\partial L}{\partial y_{22}} \frac{\partial y_{22}}{\partial w_{22}}\)
The partial derivatives of \(L\) with respect to \(y_{11}\), \(y_{12}\), \(y_{21}\), and \(y_{22}\) are
\(\frac{\partial L}{\partial y_{11}} = 2(y_{11}-a_{11})\)
\(\frac{\partial L}{\partial y_{12}} = 2(y_{12}-a_{12})\)
\(\frac{\partial L}{\partial y_{21}} = 2(y_{21}-a_{21})\)
\(\frac{\partial L}{\partial y_{22}} = 2(y_{22}-a_{22})\)
The partial derivatives of \(y_{11}\) with respect to \(w_{11}\) and \(w_{12}\) are
\(\frac{\partial y_{11}}{\partial w_{11}} = x_{11}\)
\(\frac{\partial y_{11}}{\partial w_{12}} = x_{21}\)
The partial derivatives of \(y_{12}\) with respect to \(w_{11}\) and \(w_{12}\) are
\(\frac{\partial y_{12}}{\partial w_{11}} = x_{12}\)
\(\frac{\partial y_{12}}{\partial w_{12}} = x_{22}\)
The partial derivatives of \(y_{21}\) with respect to \(w_{21}\) and \(w_{22}\) are
\(\frac{\partial y_{21}}{\partial w_{21}} = x_{11}\)
\(\frac{\partial y_{21}}{\partial w_{22}} = x_{21}\)
The partial derivatives of \(y_{22}\) with respect to \(w_{21}\) and \(w_{22}\) are
\(\frac{\partial y_{22}}{\partial w_{21}} = x_{12}\)
\(\frac{\partial y_{22}}{\partial w_{22}} = x_{22}\)
Therefore,
\(\frac{\partial L}{\partial w_{11}} = \frac{\partial L}{\partial y_{11}} \frac{\partial y_{11}}{\partial w_{11}} + \frac{\partial L}{\partial y_{12}} \frac{\partial y_{12}}{\partial w_{11}} = 2(y_{11}-a_{11})x_{11} + 2(y_{12}-a_{12})x_{12}\)
\(\frac{\partial L}{\partial w_{12}} = \frac{\partial L}{\partial y_{11}} \frac{\partial y_{11}}{\partial w_{12}} + \frac{\partial L}{\partial y_{12}} \frac{\partial y_{12}}{\partial w_{12}} = 2(y_{11}-a_{11})x_{21} + 2(y_{12}-a_{12})x_{22}\)
\(\frac{\partial L}{\partial w_{21}} = \frac{\partial L}{\partial y_{21}} \frac{\partial y_{21}}{\partial w_{21}} + \frac{\partial L}{\partial y_{22}} \frac{\partial y_{22}}{\partial w_{21}} = 2(y_{21}-a_{21})x_{11} + 2(y_{22}-a_{22})x_{12}\)
\(\frac{\partial L}{\partial w_{22}} = \frac{\partial L}{\partial y_{21}} \frac{\partial y_{21}}{\partial w_{22}} + \frac{\partial L}{\partial y_{22}} \frac{\partial y_{22}}{\partial w_{22}} = 2(y_{21}-a_{21})x_{21} + 2(y_{22}-a_{22})x_{22}\)
Collecting these into matrix form, we get
\(\begin{bmatrix}\frac{\partial L}{\partial w_{11}} & \frac{\partial L}{\partial w_{12}} \\
\frac{\partial L}{\partial w_{21}} & \frac{\partial L}{\partial w_{22}}
\end{bmatrix}
\\=
\begin{bmatrix}
2(y_{11}-a_{11})x_{11} + 2(y_{12}-a_{12})x_{12} & 2(y_{11}-a_{11})x_{21} + 2(y_{12}-a_{12})x_{22} \\
2(y_{21}-a_{21})x_{11} + 2(y_{22}-a_{22})x_{12} & 2(y_{21}-a_{21})x_{21} + 2(y_{22}-a_{22})x_{22}
\end{bmatrix}
\\=
\begin{bmatrix}
2(y_{11}-a_{11}) & 2(y_{12}-a_{12}) \\
2(y_{21}-a_{21}) & 2(y_{22}-a_{22})
\end{bmatrix}
\begin{bmatrix}
x_{11} & x_{21} \\
x_{12} & x_{22}
\end{bmatrix}\)
At this point we have successfully computed the gradient of \(L\) with respect to \(W\); in a neural network, this gradient is exactly what would be used to update the parameter values.
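The factored form \(2(Y - A)X^T\) is easy to verify numerically. The following sketch (my addition, again assuming NumPy is available) compares the analytic gradient against finite differences for random 2×2 matrices:

```python
# Verify dL/dW = 2(Y - A) X^T numerically for the 2x2 case.
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((2, 2))
X = rng.standard_normal((2, 2))
A = rng.standard_normal((2, 2))

def loss(W):
    Y = W @ X
    return np.sum((Y - A) ** 2)

# Analytic gradient from the derivation above
grad_analytic = 2 * (W @ X - A) @ X.T

# Central finite differences, one entry of W at a time
eps = 1e-6
grad_numeric = np.zeros_like(W)
for i in range(2):
    for j in range(2):
        Wp, Wm = W.copy(), W.copy()
        Wp[i, j] += eps
        Wm[i, j] -= eps
        grad_numeric[i, j] = (loss(Wp) - loss(Wm)) / (2 * eps)

print(np.allclose(grad_analytic, grad_numeric, atol=1e-4))  # True
```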
3 Closing Remarks
In fact, in the example above we can further observe that
\(\frac{\partial L}{\partial W} = \frac{\partial L}{\partial Y} \frac{\partial Y}{\partial W}\)
Moreover, when \(Y = WX\), we have
\(\frac{\partial Y}{\partial W} =\begin{bmatrix}
x_{11} & x_{21} \\
x_{12} & x_{22}
\end{bmatrix}
= X^T\)
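This is essentially the rule that automatic differentiation frameworks apply during backpropagation through a linear layer. As a final cross-check (my addition; it assumes PyTorch is installed, which the article itself does not use), autograd recovers the same factorization:

```python
# Cross-check with PyTorch autograd: W.grad equals (dL/dY) X^T.
import torch

torch.manual_seed(0)
W = torch.randn(2, 2, requires_grad=True)
X = torch.randn(2, 2)
A = torch.randn(2, 2)

Y = W @ X
L = ((Y - A) ** 2).sum()
L.backward()  # autograd applies the chain rule derived above

dL_dY = 2 * (Y - A)  # gradient of L with respect to Y
print(torch.allclose(W.grad, dL_dY.detach() @ X.T))  # True
```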
I hope this article has made clear that matrix differentiation is, at its core, still the differentiation of multivariate composite functions, and has shown how drawing a dependency graph helps untangle the concrete chain-rule details of a matrix derivative.
Thank you for reading; I hope you found it enlightening and helpful.