1. 引言
近年来,深度神经网络在计算机视觉领域中的图像分类、目标检测、语义分割等任务上取得了显著成绩。然而,这些网络的结构和超参数往往需要人工选择,耗费了大量的人力。神经架构搜索(NAS)通过在预定义的搜索空间内自动搜索最佳网络架构,不仅有效减少了人工干预,还能够实现比人工设计网络更优异的性能。早期的NAS方法,如基于强化学习(RL) [1]和进化算法(EA) [2]的搜索方法,尽管能够找到高性能架构,但计算成本高,通常需要成千上万个GPU小时。高效神经架构搜索(ENAS) [3]采用在子架构之间共享权重的策略,将搜索时间降低到几个GPU天,极大地提高了NAS的实用性。之后,可微架构搜索(DARTS) [4]通过将离散搜索空间松弛为连续空间,并采用双层优化策略,实现了高效的端到端训练。为提升分类任务中的性能,基于梯度的可微架构搜索(GDAS) [5]方法通过可微分的Gumbel-Softmax来进行离散架构选择。稀疏梯度架构搜索(SGAS) [6]通过逐步修剪冗余操作来减少搜索空间。Shapley-NAS [7]通过评估候选操作的贡献来优化网络结构。基于可微分神经架构搜索(ADARTS) [8]提出了一种基于通道注意力的可微神经架构部分通道连接算法。基于爆炸引力场算法的神经架构搜索(EGFA-NAS) [9]通过爆炸引力场算法提升了全局搜索能力和搜索效率。
上述架构搜索方法中仍存在一些问题,如操作选择不合理导致整个网络性能下降[10]。注意力引导的架构搜索(AGNAS) [11]能够根据注意力机制对重要操作赋予高的权重,避免了操作被错误选择的问题。本文对AGNAS进行改进,将原有的通道注意力替换为瓶颈注意力(BAM)模块[12],以在通道和空间通道上选择性地强调重要特征,并利用权重来量化每个操作对网络的贡献。此外,本文在操作选择中加入文献[13]中的ShuffleUnit操作,结合深度可分离卷积(DWConv)和通道特征组之间交换通道(Channel shuffle)等运算,增强通道组之间的信息交换能力,进而提升模型的特征表达能力。最后,将改进后的AGNAS应用于图像分类任务中,提升图像分类的正确率。
2. 改进的AGNAS架构介绍
改进的AGNAS整体网络架构如图1所示,由多个单元(Cell)堆叠而成。Cell又分为Reduction Cell和Normal Cell两种类型。前者通过下采样操作减少特征图的空间分辨率并增加通道数,以此来压缩信息和减少计算开销;后者用于提取特征,并保持特征图的空间尺寸不变。每个Cell是一个有向无环图(DAG),由4个中间节点(图1中用0、1、2、3示意),及任意两个节点间9条边(图1中用4条边示意)组成。每个节点表示一组特征图,而每条边则表示从输入节点到输出节点的操作。每个Cell的输入为前两个Cell的输出特征,特别地,第一个Cell的输入为固定Stem层(由一个3 × 3卷积和批量归一化层构成)从原始输入图像中提取出特征。每个中间节点通过聚合来自其所有前节点的信息流来生成新的特征表示。输出节点是将中间节点的输出特征按通道维度进行拼接得到。
与原AGNAS相比,本文的改进之处包括两个方面:(1) 每个操作后,采用BAM模块替换原通道注意力模块,以在通道和空间两个维度上关注重要特征;(2) 在原来的8种操作(3 × 3和5 × 5可分离卷积sep_conv、3 × 3和5 × 5扩张可分离卷积dil_conv、3 × 3最大池化max_pool、3 × 3平均池化avg_pool、跳跃连接skip_connect、以及空操作none)基础上,引入ShuffleUnit操作,增加操作选择的多样性。
接下来,先介绍改进的AGNAS的架构搜索原理,然后详细介绍BAM模块和ShuffleUnit操作。
Figure 1. Architecture of the improved AGNAS
图1. 改进的AGNAS的网络架构
2.1. 架构搜索原理
对于用于分类任务的整体架构,其搜索过程中采用的损失函数为,
(1)
其中,N为样本数量,G为类别数量,
为第n个样本的真实标签,
为第n个样本被预测为类别g的概率值。
对于每个Cell,其架构搜索过程如下。首先,采用所有M = 9种候选操作对输入特征进行运算,并将生成的特征在通道维度上拼接,得到中间特征图
,其中,C表示特征通道数,H和W分别表示特征图的高和宽。
然后,对中间特征图F,采用BAM模块关注通道和空间重要特征,得到注意力权重
。进一步,定义任意的第m (
)种候选操作
的重要性
为,
. (2)
最后,按照候选操作重要性的最大值选择每条边上的最佳操作
,
. (3)
重复同样操作,直到选出其余各个节点最佳操作为止。将所有节点最佳操作进行组合,即可得到该Cell的最优网络架构。
2.2. BAM模块
BAM模块的结构如图1中红色虚线框内所示,包括通道和空间注意力两条路径。通道注意力路径用于选择性地增强或抑制特征图中不同通道的响应,而空间注意力用于选择性地增强或抑制特征图中不同空间位置的响应。对特征图
,分别计算相应的通道注意力
和空间注意力
。
在通道注意力路径上,先将特征图F通过全局平均池化(GAP),生成一个包含全局信息的通道向量
。然后,采用多层感知器(MLP) (用全连接运算实现)和批量归一化层(BN)来获取通道注意力
,计算公式为,
(4)
其中,
和
为两个全连接层的权重,
和
为两个全连接层的偏置,r为通道上的降维比例。
在空间注意力路径上,先采用1 × 1卷积来压缩特征图F的通道数,即将F降维为
,然后利用两个3 × 3膨胀卷积(DConv)捕捉上下文信息。最后,用1 × 1卷积(Conv)将特征维度变为
。整个计算过程可表示为,
(5)
其中,上标1 × 1和3 × 3分别表示卷积核的大小。
在获取
和
后,分别将两者的维度复制扩充至
,并采用Sigmoid函数将其映射至0到1范围内,得到合并后的注意映射
为,
(6)
其中,
表示Sigmoid函数,
表示逐元素求和。
进一步地,将
与F相乘,得到加权后的输出特征图
为,
(7)
其中,
表示逐元素相乘。
又将
和F进行逐元素求和得到
,并对
进行通道维度上同操作数的逐元素求和,得到输出节点的特征图
,计算公式为,
. (8)
2.3. ShuffleUnit操作
ShuffleUnit是ShuffleNet V2的一个组成单元,其结构如图2所示。首先,用两个分支对特征进行卷积运算。第一个分支主要的运算依次为1x1 Conv、步长为2的3 × 3 DWConv和1 × 1 Conv;第二个分支主要的运算依次为步长为2的3 × 3 DWConv和1 × 1 Conv。此外,两个分支中,1 × 1 Conv后面需要BN和ReLU激活运算,步长为2的3 × 3 DWConv后面需要BN运算。然后,将两个分支得到的特征图在通道维度上拼接。最后,将拼接的特征图划分为多个通道组,并利用Channel shuffle实现通道组之间的信息交换。
显然,DWConv和1 × 1 Conv的组合能有效提高计算效率,步长为2的DWConv能实现特征在空间上的降采样,两条分支可视为特征的重复使用。因此,ShuffleUnit是一个能高效完成具有更多特征通道和更大网络容量的信息提取单元,而且具有通道组之间信息交换能力。
Figure 2. ShuffleUnit operation
图2. ShuffleUnit操作
3. 实验结果
本文采用的数据集为CIFAR-10,包含60,000幅32 × 32的彩色图像。总共有10个类别目标(飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车),每类包含6000幅图像。由于本文的架构搜索需要验证集参与,因此,将数据集划分为训练、验证和测试集,比例为2.5:2.5:1。
实验软件环境为Ubuntu20.04系统,深度学习框架为pytorch1.12.1,硬件配置CPU为Intel® Xeon(R) W-2235 CPU @ 3.80 GHz × 12,内存64 G,GPU为NVIDA GeForce RTX 3090,显存24G。实验使用SGD优化器来更新网络权重,初始学习率为0.025,动量为0.9,权重衰减为3 × 10−4。
3.1. 架构搜索结果
为了在一个大的空间中搜索,先构建一个包含8个Cell (2个Reduction Cell和6个Normal Cell)的超网。其中,2个Reduction Cell分别放置在第2个(超网的1/3长度处)和第5个(超网的2/3长度处)位置处。采用的候选操作选择包括:3 × 3和5 × 5 sep_conv、3 × 3和5 × 5 dil_conv、shuffle_unit、max_pool_3 × 3、avg_pool_3 × 3、skip_connect、以及none这9种操作。搜索过程中,计算Cell中各节点每个候选操作对应的注意力权重,并选择具有最高注意力权重的操作。将所有被选中操作进行组合,即可得到该Cell的最优网络架构。每一轮搜索中,8个Cell的架构都会更新。
搜索结束后,为了平衡计算效率和性能,仅选择了其中5个Cell进行堆叠。其中,2个Reduction Cell被全部保留,6个Normal Cel中随机选出3个。选出的5个Cell的结构如图3(a)~(e)所示。其中图3(b)和(d)是Reduction Cell (Cell_2和Cell_5),其余为Normal Cell (Cell_1、Cell_3和Cell_6)。图3(a)中Cell_1采用了5 × 5 dil_conv和多个skip_connect,而且在节点1到节点3和两个输入节点分别到节点2的过程中均选择了shuffle_unit操作。Cell_2采用3 × 3和5 × 5的sep_conv、avg_pool_3 × 3和多个skip_connect,且在输入节点(前一个Cell的输出)到节点0的过程中选择了shuffle_unit操作。Cell_3采用3 × 3和5 × 5的sep_conv和3 × 3的dil_conv,并通过skip_connect将不同层的特征图融合,且也在输入节点(前两个Cell的输出)到节点1的过程中选择了shuffle_unit操作。Cell_5采用5 × 5 sep_conv、avg_pool_3 × 3和max_pool_3 × 3,且在输入节点(前两个Cell的输出)到节点3的过程中选择了shuffle_unit操作。Cell_6主要使用3 × 3和5 × 5的dil_conv、max_pool_3 × 3和3 × 3 sep_conv,并通过多个skip_connect融合特征。在这些选中的操作中,shuffle_unit操作被选择到的概率达到了15%,因此,shuffle_unit操作的引入不仅增加了操作的多样性,而且在架构用于分类的性能提高方面发挥着重要作用。
Figure 3. Detailed architecture of 5 cells searched on the CIFAR-10 dataset (Light blue represents input features; light purple represents output features; light yellow represents intermediate nodes)
图3. CIFAR-10数据集上搜索到的5个Cell的详细架构(淡蓝色表示输入特征;淡紫色表示输出特征;黄色表示中间节点)
受文献[11]的启发,将选出的5个Cell进行堆叠,形成一个新的包含20个Cell的网络,以用于最终的分类任务。堆叠过程中,将2个Reduction Cell (Cell_2和Cell_5)分别放置在网络的第7个和第14个位置;其余3个Normal Cell (Cell_1、Cell_3和Cell_6)各堆叠6次,结果如图4所示。采用CIFAR-10的训练集和验证集对该网络的参数进行训练,后进行分类测试,得到的混淆矩阵如表1所示。显然,每类目标的正确率均高于95%,10类目标的平均分类正确率为97.54%。
Figure 4. Classification network
图4. 分类网络
3.2. 消融实验
以原始的AGNAS模型作为基线,分别以及同时加入BAM模块和ShuffleUnit操作,得到测试集上的平均分类错误率如表2所示。基线模型得到的平均测试错误率为2.53%。当引入BAM模块后,测试错误率降低了0.03%;当引入ShuffleUnit操作时,测试错误率降低了0.04%;当同时引入BAM模块和ShuffleUnit操作时,测试错误率降低了0.07%。结果表明,引入BAM模块或ShuffleUnit操作均能够降低测试错误率;当同时引入两者时,分类性能进一步得到提升。
Table 1. Confusion matrix
表1. 混淆矩阵
Class |
airplane |
automobile |
bird |
cat |
deer |
dog |
frog |
horse |
ship |
truck |
Accuracy (%) |
airplane |
975 |
0 |
3 |
2 |
2 |
0 |
0 |
0 |
15 |
3 |
97.5 |
automobile |
0 |
987 |
0 |
0 |
1 |
0 |
0 |
0 |
1 |
11 |
98.7 |
bird |
2 |
0 |
972 |
6 |
4 |
5 |
10 |
0 |
1 |
0 |
97.2 |
cat |
1 |
1 |
2 |
947 |
5 |
35 |
6 |
1 |
2 |
0 |
94.7 |
deer |
0 |
0 |
3 |
8 |
980 |
4 |
2 |
3 |
0 |
0 |
98.0 |
dog |
1 |
1 |
4 |
31 |
6 |
954 |
0 |
3 |
0 |
0 |
95.4 |
frog |
0 |
0 |
3 |
1 |
1 |
1 |
994 |
0 |
0 |
0 |
99.4 |
horse |
0 |
0 |
1 |
4 |
2 |
6 |
0 |
987 |
0 |
0 |
98.7 |
ship |
8 |
7 |
2 |
0 |
0 |
0 |
1 |
0 |
982 |
0 |
98.2 |
truck |
1 |
17 |
0 |
0 |
0 |
0 |
1 |
0 |
5 |
976 |
97.6 |
Total |
|
|
|
|
|
|
|
|
|
|
97.54 |
Table 2. Results of the ablation experiment
表2. 消融实验结果
Baseline |
BAM |
ShuffleUnit |
Test error (%)↓ |
√ |
|
|
2.53 |
√ |
|
√ |
2.49 |
√ |
√ |
|
2.50 |
√ |
√ |
√ |
2.46 |
3.3. 与其他方法的对比
将本文方法与其他方法进行对比,各方法的测试错误率、模型参数量、搜索时间和搜索算法如表3所示。在基于非梯度搜索算法的架构搜索方法(NASNet-A [14]、AmoebaNet-B [2]、PNAS [15]和ENAS [3])中,NASNet-A和AmoebaNet-B两种方法的平均测试错误率分别达到了2.65%和2.55%,但它们的搜索时间却非常长,分别需要1800和3150 GPU天。这些方法依赖于计算密集型的RL学习和EA搜索算法,资源消耗巨大。而基于梯度(Gradient)算法的架构搜索方法(DARTS [4]、GDAS [5]、PC-DARTS [16]、SGAS [6]、DARTS + PT [10]和AGNAS [11]),虽在平均测试错误率上略高,但它们的参数量少、搜索时间短,优势更为明显。其中,PC-DARTS仅需0.1 GPU天。
本文改进模型获取的平均测试错误率达到了2.46%,错误率的波动范围为±0.001%,参数量为3.8 M,搜索时间仅为0.39 GPU天。与原始的AGNAS相比,不仅平均测试错误率降低,而且分类性能稳定。此外,在参数量上和计算效率上与原始的AGNAS相当。与其他梯度方法相比,平均测试错误率得到明显降低,分类性能更为稳定,而耗时在同一数量级上。因此,综合考虑分类错误率、性能稳定性、效率和资源消耗,本文方法比其他方法更具有优势。
Table 3. Comparison with other methods
表3. 与其他方法的比较
Methods |
Test error (%) |
Parameters (M) |
Search cost (GPU-days) |
Search algorithm |
NASNet-A [14] |
2.65 |
3.3 |
1800 |
RL |
AmoebaNet-B [2] |
2.55 ± 0.05 |
2.8 |
3150 |
EA |
PNAS [15] |
3.41 ± 0.09 |
3.2 |
225 |
SMBO |
ENAS [3] |
2.89 |
4.6 |
0.5 |
RL |
DARTS (1st order) [4] |
3.00 ± 0.14 |
3.3 |
1.5 |
Gradient |
DARTS (2st order) [4] |
2.76 ± 0.09 |
3.3 |
4 |
Gradient |
GDAS [5] |
2.93 |
3.4 |
0.21 |
Gradient |
PC-DARTS [16] |
2.57 ± 0.07 |
3.6 |
0.1 |
Gradient |
SGAS [6] |
2.66 ± 0.24 |
3.7 |
0.25 |
Gradient |
DARTS+PT [10] |
2.61 ± 0.08 |
3.0 |
0.8 |
Gradient |
AGNAS [11] |
2.53 ± 0.03 |
3.6 |
0.4 |
Gradient |
Ours |
2.46 ± 0.001 |
3.8 |
0.39 |
Gradient |
4. 结束语
本文对原始AGNAS进行改进,每个操作后引入BAM注意力机制,并增加一个新的ShuffleUnit操作。首先,分析了网络的架构搜索原理,并详细阐述了BAM注意力机制和ShuffleUnit操作的工作原理。然后,在CIFAR-10数据集上开展了图像分类任务实验,获取了适用于该数据集的架构。消融实验和对比实验结果表明,本文提出的改进模型比原AGNAS模型具有更低的分类错误率和更稳定的分类性能。此外,综合考虑测试错误率、参数量和搜索时间多方面因素,本文方法比其他方法优势更为明显。
基金项目
在此特别感谢江西省自然科学基金对本文的支持:20224BAB202002。