
Neural Network Gradient Computation: Starting from a Simple Example

August 4, 2024

0 Preface

When I first encountered neural networks, I wonder whether anyone else was like me: hazy about how gradients are actually computed. You may know, roughly, that the chain rule is involved, but everyday calculus mostly deals with differentiating single-variable composite functions, so the concrete details of differentiating with respect to a parameter matrix, as neural networks require, can remain completely unclear.

This article starts from the differentiation of multivariable composite functions, the essential foundation of matrix differentiation, and uses simple examples to lay out the complete details of computing gradients with respect to a matrix. I hope it helps anyone who wants to understand the details of gradient computation in neural networks.

1 Differentiating Multivariable Composite Functions

We first introduce the differentiation of multivariable composite functions, the essential foundation of matrix differentiation.

1.1 Problem Setup

Problem: Consider three functions \(z = f(u, v)\), \(u = g(x, y)\), and \(v = h(x, y)\). Assuming all the partial derivatives exist, our goal is to find the partial derivative of \(z\) with respect to \(x\), \(\frac{\partial z}{\partial x}\), and the partial derivative of \(z\) with respect to \(y\), \(\frac{\partial z}{\partial y}\).

Solution: We start by drawing the dependency graph of the variables:

[Dependency graph: \(z\) depends on \(u\) and \(v\); \(u\) and \(v\) each depend on \(x\) and \(y\)]

This dependency graph is very helpful for sorting out the relationships between the variables, and in turn the details of the chain rule. From the dependency graph and the chain rule, we have

\(\frac{\partial z}{\partial x} = \frac{\partial z}{\partial u} \frac{\partial u}{\partial x} + \frac{\partial z}{\partial v} \frac{\partial v}{\partial x}\)
\(\frac{\partial z}{\partial y} = \frac{\partial z}{\partial u} \frac{\partial u}{\partial y} + \frac{\partial z}{\partial v} \frac{\partial v}{\partial y}\)

1.2 An Example

Let's work through a more concrete example:

Problem: Suppose \(z = u^2 + v^2\), where \(u = x + y\) and \(v = x - y\). Find the partial derivative of \(z\) with respect to \(x\), \(\frac{\partial z}{\partial x}\), and the partial derivative of \(z\) with respect to \(y\), \(\frac{\partial z}{\partial y}\).

Solution: The partial derivatives of \(z\) with respect to \(u\) and \(v\) are:

\(\frac{\partial z}{\partial u} = 2u\)
\(\frac{\partial z}{\partial v} = 2v\)

The partial derivatives of \(u\) with respect to \(x\) and \(y\) are:

\(\frac{\partial u}{\partial x} = 1\)
\(\frac{\partial u}{\partial y} = 1\)

The partial derivatives of \(v\) with respect to \(x\) and \(y\) are:

\(\frac{\partial v}{\partial x} = 1\)
\(\frac{\partial v}{\partial y} = -1\)

Therefore:

\(\frac{\partial z}{\partial x} = \frac{\partial z}{\partial u} \frac{\partial u}{\partial x} + \frac{\partial z}{\partial v} \frac{\partial v}{\partial x} = 2u + 2v\)
\(\frac{\partial z}{\partial y} = \frac{\partial z}{\partial u} \frac{\partial u}{\partial y} + \frac{\partial z}{\partial v} \frac{\partial v}{\partial y} = 2u - 2v\)

Substituting \(u = x + y\) and \(v = x - y\) back in gives \(\frac{\partial z}{\partial x} = 4x\) and \(\frac{\partial z}{\partial y} = 4y\).
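If you want to verify this numerically, here is a minimal sketch using PyTorch's autograd (the library choice is mine; any autodiff tool works). At the point \((x, y) = (1.5, -0.7)\), autograd applies the same chain rule and should recover \(4x\) and \(4y\):

```python
import torch

# Evaluate at the concrete point (x, y) = (1.5, -0.7)
x = torch.tensor(1.5, requires_grad=True)
y = torch.tensor(-0.7, requires_grad=True)

u = x + y          # u = g(x, y)
v = x - y          # v = h(x, y)
z = u**2 + v**2    # z = f(u, v)

z.backward()       # autograd applies the chain rule derived above

print(x.grad)      # tensor(6.)      = 2u + 2v = 4x
print(y.grad)      # tensor(-2.8000) = 2u - 2v = 4y
```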

2 Gradient Computation

2.1 The Problem

Neural network algorithms all involve computing gradients of the parameters, i.e., the gradient of the loss function \(L\) with respect to a parameter matrix \(W\). Let's look at an example of this kind of problem:

Problem: Consider \(L = (y_{11} - a_{11})^2 + (y_{12} - a_{12})^2 + (y_{21} - a_{21})^2 + (y_{22} - a_{22})^2\)

where
\(\begin{bmatrix}
y_{11} & y_{12} \\
y_{21} & y_{22}
\end{bmatrix}
=
\begin{bmatrix}
w_{11} & w_{12} \\
w_{21} & w_{22}
\end{bmatrix}
\begin{bmatrix}
x_{11} & x_{12} \\
x_{21} & x_{22}
\end{bmatrix}
\quad (Y = WX)\)

Find the gradient of \(L\) with respect to \(W\), \(\frac{\partial L}{\partial W}\).

2.2 Reformulating the Problem

At its core, the problem above is still a multivariable composite-function differentiation problem, a simple extension of the example we started with. If we adjust how the problem is presented, it becomes:

Consider \(L = (y_{11} - a_{11})^2 + (y_{12} - a_{12})^2 + (y_{21} - a_{21})^2 + (y_{22} - a_{22})^2\)

where

\(y_{11} = w_{11}x_{11} + w_{12}x_{21}\)
\(y_{12} = w_{11}x_{12} + w_{12}x_{22}\)
\(y_{21} = w_{21}x_{11} + w_{22}x_{21}\)
\(y_{22} = w_{21}x_{12} + w_{22}x_{22}\)

Find:

\(\begin{bmatrix}
\frac{\partial L}{\partial w_{11}} & \frac{\partial L}{\partial w_{12}} \\
\frac{\partial L}{\partial w_{21}} & \frac{\partial L}{\partial w_{22}}
\end{bmatrix}\)
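As a side check, a short symbolic sketch with sympy (my choice of tool here) confirms that unrolling \(Y = WX\) yields exactly the four scalar equations above:

```python
import sympy as sp

w11, w12, w21, w22 = sp.symbols("w11 w12 w21 w22")
x11, x12, x21, x22 = sp.symbols("x11 x12 x21 x22")

W = sp.Matrix([[w11, w12], [w21, w22]])
X = sp.Matrix([[x11, x12], [x21, x22]])

Y = W * X  # symbolic matrix product, unrolled element by element
print(Y)
# Matrix([[w11*x11 + w12*x21, w11*x12 + w12*x22],
#         [w21*x11 + w22*x21, w21*x12 + w22*x22]])
```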

2.3 Solving the Problem

We mentioned earlier that a dependency graph is very helpful for sorting out how variables depend on one another. In the function above, \(L\), \(y_{ij}\), and \(w_{ij}\) are related as follows: \(L\) depends on all four \(y_{ij}\); \(y_{11}\) and \(y_{12}\) depend on \(w_{11}\) and \(w_{12}\); and \(y_{21}\) and \(y_{22}\) depend on \(w_{21}\) and \(w_{22}\).

From these dependencies, we have

\(\frac{\partial L}{\partial w_{11}} = \frac{\partial L}{\partial y_{11}} \frac{\partial y_{11}}{\partial w_{11}} + \frac{\partial L}{\partial y_{12}} \frac{\partial y_{12}}{\partial w_{11}}\)
\(\frac{\partial L}{\partial w_{12}} = \frac{\partial L}{\partial y_{11}} \frac{\partial y_{11}}{\partial w_{12}} + \frac{\partial L}{\partial y_{12}} \frac{\partial y_{12}}{\partial w_{12}}\)
\(\frac{\partial L}{\partial w_{21}} = \frac{\partial L}{\partial y_{21}} \frac{\partial y_{21}}{\partial w_{21}} + \frac{\partial L}{\partial y_{22}} \frac{\partial y_{22}}{\partial w_{21}}\)
\(\frac{\partial L}{\partial w_{22}} = \frac{\partial L}{\partial y_{21}} \frac{\partial y_{21}}{\partial w_{22}} + \frac{\partial L}{\partial y_{22}} \frac{\partial y_{22}}{\partial w_{22}}\)

The partial derivatives of \(L\) with respect to \(y_{11}\), \(y_{12}\), \(y_{21}\), and \(y_{22}\) are

\(\frac{\partial L}{\partial y_{11}} = 2(y_{11}-a_{11})\)
\(\frac{\partial L}{\partial y_{12}} = 2(y_{12}-a_{12})\)
\(\frac{\partial L}{\partial y_{21}} = 2(y_{21}-a_{21})\)
\(\frac{\partial L}{\partial y_{22}} = 2(y_{22}-a_{22})\)

The partial derivatives of \(y_{11}\) with respect to \(w_{11}\) and \(w_{12}\) are

\(\frac{\partial y_{11}}{\partial w_{11}} = x_{11}\)
\(\frac{\partial y_{11}}{\partial w_{12}} = x_{21}\)

The partial derivatives of \(y_{12}\) with respect to \(w_{11}\) and \(w_{12}\) are

\(\frac{\partial y_{12}}{\partial w_{11}} = x_{12}\)
\(\frac{\partial y_{12}}{\partial w_{12}} = x_{22}\)

The partial derivatives of \(y_{21}\) with respect to \(w_{21}\) and \(w_{22}\) are

\(\frac{\partial y_{21}}{\partial w_{21}} = x_{11}\)
\(\frac{\partial y_{21}}{\partial w_{22}} = x_{21}\)

The partial derivatives of \(y_{22}\) with respect to \(w_{21}\) and \(w_{22}\) are

\(\frac{\partial y_{22}}{\partial w_{21}} = x_{12}\)
\(\frac{\partial y_{22}}{\partial w_{22}} = x_{22}\)

Therefore,

\(\frac{\partial L}{\partial w_{11}} = \frac{\partial L}{\partial y_{11}} \frac{\partial y_{11}}{\partial w_{11}} + \frac{\partial L}{\partial y_{12}} \frac{\partial y_{12}}{\partial w_{11}} = 2(y_{11}-a_{11})x_{11} + 2(y_{12}-a_{12})x_{12}\)
\(\frac{\partial L}{\partial w_{12}} = \frac{\partial L}{\partial y_{11}} \frac{\partial y_{11}}{\partial w_{12}} + \frac{\partial L}{\partial y_{12}} \frac{\partial y_{12}}{\partial w_{12}} = 2(y_{11}-a_{11})x_{21} + 2(y_{12}-a_{12})x_{22}\)
\(\frac{\partial L}{\partial w_{21}} = \frac{\partial L}{\partial y_{21}} \frac{\partial y_{21}}{\partial w_{21}} + \frac{\partial L}{\partial y_{22}} \frac{\partial y_{22}}{\partial w_{21}} = 2(y_{21}-a_{21})x_{11} + 2(y_{22}-a_{22})x_{12}\)
\(\frac{\partial L}{\partial w_{22}} = \frac{\partial L}{\partial y_{21}} \frac{\partial y_{21}}{\partial w_{22}} + \frac{\partial L}{\partial y_{22}} \frac{\partial y_{22}}{\partial w_{22}} = 2(y_{21}-a_{21})x_{21} + 2(y_{22}-a_{22})x_{22}\)

Collecting these into matrix form gives

\(\begin{bmatrix}
\frac{\partial L}{\partial w_{11}} & \frac{\partial L}{\partial w_{12}} \\
\frac{\partial L}{\partial w_{21}} & \frac{\partial L}{\partial w_{22}}
\end{bmatrix}
\\=
\begin{bmatrix}
2(y_{11}-a_{11})x_{11} + 2(y_{12}-a_{12})x_{12} & 2(y_{11}-a_{11})x_{21} + 2(y_{12}-a_{12})x_{22} \\
2(y_{21}-a_{21})x_{11} + 2(y_{22}-a_{22})x_{12} & 2(y_{21}-a_{21})x_{21} + 2(y_{22}-a_{22})x_{22}
\end{bmatrix}
\\=
\begin{bmatrix}
2(y_{11}-a_{11}) & 2(y_{12}-a_{12}) \\
2(y_{21}-a_{21}) & 2(y_{22}-a_{22})
\end{bmatrix}
\begin{bmatrix}
x_{11} & x_{21} \\
x_{12} & x_{22}
\end{bmatrix}\)

At this point we have successfully computed the gradient of \(L\) with respect to \(W\); in a neural network, this gradient would be used to update the parameter values.
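The result is also easy to double-check numerically. Below is a minimal PyTorch sketch (the random values for \(W\), \(X\), and \(A\) are arbitrary, chosen just for illustration); autograd's W.grad should match the factored matrix we derived:

```python
import torch

torch.manual_seed(0)                      # arbitrary reproducible values
W = torch.randn(2, 2, requires_grad=True)
X = torch.randn(2, 2)
A = torch.randn(2, 2)

Y = W @ X                                 # Y = WX
L = ((Y - A) ** 2).sum()                  # sum of squared errors, as above
L.backward()

manual = 2 * (Y - A).detach() @ X.T       # the factored form derived above
print(torch.allclose(W.grad, manual))     # True
```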

3 Closing Remarks

In fact, in the example above we can also observe that

\(\frac{\partial L}{\partial W} = \frac{\partial L}{\partial Y} \frac{\partial Y}{\partial W}\)

Moreover, when \(Y = WX\), we have

\(\frac{\partial Y}{\partial W} =
\begin{bmatrix}
x_{11} & x_{21} \\
x_{12} & x_{22}
\end{bmatrix}
= X^T\)
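This observation is not tied to the squared-error loss. Here is one more minimal PyTorch sketch, with an arbitrary loss chosen purely for illustration, checking that for \(Y = WX\) the gradient of any scalar loss satisfies \(\frac{\partial L}{\partial W} = \frac{\partial L}{\partial Y} X^T\):

```python
import torch

torch.manual_seed(0)
W = torch.randn(3, 4, requires_grad=True)
X = torch.randn(4, 5)

Y = W @ X
Y.retain_grad()             # keep dL/dY on the intermediate tensor Y
L = Y.tanh().sum()          # an arbitrary scalar loss built on Y

L.backward()
print(torch.allclose(W.grad, Y.grad @ X.T))  # True: dL/dW = (dL/dY) X^T
```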

I hope this article makes it clear that matrix differentiation is, at its core, still multivariable composite-function differentiation, and that drawing a dependency graph is a good way to untangle the chain-rule details of matrix differentiation.

Thanks for reading; I hope it sparked some ideas and proves helpful.

Last updated: August 4, 2024

陈银波

Email: agwave@foxmail.com Zhihu: https://www.zhihu.com/people/agwave GitHub: https://github.com/agwave LeetCode: https://leetcode.cn/u/agwave

