神经网络训练优化器及工具

前言:在上一篇ResNet实战博客中,笔者采用的都是最简单的激活函数和梯度下降训练算法,并且在训练到瓶颈期还需要手动停止训练并调整学习率。所以在本篇中我们将讨论训练优化算法和一些自动训练工具。

Basic:随机梯度下降

SGD:Stochastic gradient descent

这部分属于入门知识,可以看吴恩达老师的课程,讲得非常通俗易懂。

梯度下降

我们的模型(神经网络)本质上是一组不断调整参数的函数组合而成一个我们想要的复杂函数。而根据微积分和高维空间的知识我们知道,任何函数都存在导数(或Jacobian矩阵)并且都存在凹凸特性。而神经网络的训练就是基于这一点,通过人工设计一个目标损失函数,然后以最小化(或最大化)该函数为目标,不断计算输入数据在该损失上的梯度,并不断迭代寻找到函数最低点(上图)。我们将模型和目标损失函数表示为:

$$
\text{网络模型:}\hat{Y} = F(X)
\\
\text{损失函数:}L(\hat{Y},Y)|_{w,b}
$$

SGD训练参数过程:

$$
w := w-\alpha\cdot \frac{\partial L}{\partial w}
\\
b := b-\alpha\cdot \frac{\partial L}{\partial b}
$$

损失函数计算的并不是模型对单个输入数据(样本)的推理结果,而是一批数据,一批数据中样本的个数也就是常说的Batch Size。对于损失函数等可查看笔者先前的这篇神经网络常用函数

Momentum

SGD每次的梯度更新都是直接根据当前输入数据进行计算的。现在考虑如下图情况,蓝线为SGD根据在每个转折点计算梯度前进的结果,显然从起始点到最终的最优点(绿色)有很明显的震荡现象。

梯度下降

如果我们的训练可以减小震荡朝着目标点前进,那么就可以减少很多计算量并且可以更稳定地达到目标点,这就是优化算法要做的事情。减小高频震荡?这很容易让我们想到滤波器(笔者是电信出身,所以认为这么理解更直观,比什么指数加权平均好理解多了),Momentum就相当于是最简单一阶滤波器(效果如上红线),从数学定义上也确实如此:

$$
v= \beta v +(1-\beta)\frac{\partial L}{\partial w}
\\
w : = w - \alpha v
$$

如果上面滤波器系数为0,那么相当于直接退化回SGD。

RMSprop

这个优化算法相对Momentum的改进,在数学上正如它的名字RMS,把原本的梯度滤波内容改为了均方根:

$$
s = \beta s + (1-\beta)[\frac{\partial L}{\partial w} ]^2
\\
w:= w -\alpha \frac{\frac{\partial L}{\partial w}}{\sqrt{s+\epsilon}},\epsilon \to 0
$$

对于其效果的理解,可以把第一个一阶滤波器当作对强度滤波,比如还是上面Momentum例中的模型,显然在纵轴梯度强度比横轴大很多,那么就要削弱纵轴的梯度强度。而第二个参数更新式,通过除去梯度强度滤波器的开方值得到优化后的梯度。

笔者个人认为这个算法不是很好,因为当滤波前梯度本身就比较平滑的时候,考虑一个二维的场,那相当最后算出来的是单位长度、方向只有四个方向(±45°、±135°)的梯度向量

Adam

Adaptive moment estimation

$$
v = \beta_1 v +(1-\beta_1)\frac{\partial L}{\partial w}
\\
s = \beta_2 s + (1-\beta_2)[\frac{\partial L}{\partial w} ]^2
\\
w:= w -\alpha \frac{v}{\sqrt{s+\epsilon}},\epsilon \to 0
$$

看完了公式,首先就能感觉到又有点Momentum的味道(式1)又有点RMSprop的味道(式2),事实上确实如此。Adam就是在对梯度方向做滤波的基础上,再对强度做控制。

好用的工具们

TIP:接下来的都是以使用pytorch框架为基础

数据可视化

在上一篇ResNet博客中其实就已经简单的用到了TensorBoard用来可视化训练效果。这个工具原生是在TensorFlow框架下的,由于太好用所以现在适配到pytorch了。

首先要开一个数据写入管理对象:

1
writer = tensorboard.SummaryWriter('./log/ResNet-CIFAR10/')

入参显然就是一个路径,指定要存储的数据在硬盘上的位置。接着,就可以调用相应的add_type接口将想要的数据存储起来了:

1
2
3
4
5
6
7
8
# 数值类
writer.add_scalar("准确度", 100 * correct / total, epoch)

# 添加网络图
writer.add_graph(net,(netInputBatch,))

# 添加图像
writer.add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')

笔者此次仅举例几个常用的API,其他建议读者直接查官方文档。若训练完成或其他不需要继续记录数据的时候,可安全地关闭数据管理对象:

1
writer.close()

接着,就是一个很香的部分了——以网页方式查看存储的数据。在命令行终端上运行:

1
tensorboard --logdir=./log/ResNet-CIFAR10/ --port 8686

然后就可以在浏览器中输入对应的网址和端口了:

cmd

tensorBoard

这使得我们可以很方便地查看非本地训练时远程负责科学计算的服务器上的模型状况。

自动学习率调整器

笔者从一个专门做深度学习的朋友了解到,虽然Adam等优化算法已经非常好用了,但是模型训练达到瓶颈后,工程上大牛还是会使用最原始的SGD并手动调整学习率以达到更好的效果。

根据我们对梯度下降的场分析可以知道,学习率太大的话,模型无法下降到最低点或稳定停在最低点,而是在附近震荡。所以,我们手动调整时,最直接的逻辑就是目前采用的学习率使得网络效果不再增加时,就适量减少学习率(减少多少就看经验了)。

pytorch中提供了这样一个逻辑的自动学习率调整器,使用方法非常简单:

1
2
3
4
5
6
7
8
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min')
for epoch in range(10):
# 模型前向及方向传播(梯度求解
pass

# 不直接使用优化器,而是使用调整器
scheduler.step(val_loss)
Donate
  • Copyright: Copyright is owned by the author. For commercial reprints, please contact the author for authorization. For non-commercial reprints, please indicate the source.
  • Copyrights © 2022-2024 RY.J
  • Visitors: | Views:

请我喝杯咖啡吧~

支付宝
微信