1. 引言
在过去的几十年中,深度神经网络推动了计算机视觉领域的蓬勃发展,在比如图像分类[1]-[3]、目标检测[4] [5]、和分割[6] [7]方面获得了巨大的性能提升。但是,由于强大的网络性能常常依赖于庞大的模型容量,它们严重依赖于算力和存储资源,在某些特定环境比如移动设备上无法完成部署。为了解决这一问题,许多方法被提出用来压缩模型的大小。知识蒸馏(Knowledge Distillation)正是其中的一种方法。具体来说,知识蒸馏框架主要包括一个大型模型(教师)和一个小型模型(学生),通过将知识从教师转移到学生的方法,在不增加额外成本的前提下提高了小型模型(学生)的网络性能。
主流的知识蒸馏方法主要分为软目标蒸馏和特征蒸馏。软目标蒸馏仅仅在软目标层面,通过最小化教师和学生之间的KL散度(Kullback-Leibler Divergence) [8] [9]来转移知识。为了更好地利用教师的知识,最近的研究更加专注于教师模型的特征层,通过匹配教师和学生之间的特征分布来进行蒸馏。这种方法被称为特征蒸馏[10]。以前基于特征的蒸馏方法通常会让学生尽可能模仿老师的输出,因为老师的特征具有更强的表现力。然而,本文认为没有必要直接模仿老师来提高学生特征的表征能力。用于蒸馏的特征一般是通过深度网络得到的高阶语义信息。特征像素已经在一定程度上包含了相邻像素的信息。
面对这种情况,本文重点提出了一种掩码生成式解耦特征蒸馏算法(MDKD),这是一种简单高效的基于特征的蒸馏方法。具体来说,首先,本文方法屏蔽了学生特征的随机像素,然后通过一个简单的模块使用屏蔽特征生成教师的完整特征。由于每次迭代都使用随机像素,因此在整个训练过程中将使用所有像素,这意味着特征将更加鲁棒,并且其表示能力将得到提高。其次,本文还提出了一种解耦空间金字塔池化知识蒸馏(DSPP),该方法中应用了空间金字塔池化[11]架构来自动捕获知识,它可以有效地捕获特征图不同尺度的信息知识。然后,基于特征图中较低激活区域在KD中发挥更重要作用的观察结果,即较低激活区域包含更多信息知识线索,设计了一个解耦模块来分析学生和教师网络之间的区域级语义损失。通过利用空间金字塔池化和解耦的区域级损失分配,可以通过更复杂的监督有效优化学生网络。大量实验数据证明,在主流基准测试中我们的方法在同构及异构网络知识蒸馏配置中优于现有的蒸馏技术。
2. 相关知识
知识蒸馏(Knowledge Distillation, KD)的概念最早由Hinton等人提出。涉及一个大型的教师模型和一个轻量级的学生模型。该框架的目标是将教师模型中的知识提炼并转移至学生模型中。具体操作上,通过最小化教师和学生模型预测之间的差异,迫使学生模型模仿教师的输出。知识蒸馏旨在通过从较大教师网络中提取的暗知识,提升较小学生网络的性能。
根据现有的知识,主要把知识蒸馏方法分为三类:基于软目标的蒸馏[8] [9] [12]-[15]、基于中间特征的蒸馏[11] [18]-[23]和基于关系的蒸馏[24]。软目标蒸馏方法侧重于模型输出的Logit,而特征蒸馏方法则侧重于模型内部特征的提取和转移,现有的基于关系的知识蒸馏方法则对不同层与数据样本之间以及不同样本之间的关系进行了探究。这些方法各有其特点和应用场景,为模型压缩和知识蒸馏提供了多样的技术选择。
软目标蒸馏方法通过输出的Logit实现知识的提取。先前的关于软目标蒸馏的研究主要集中在开发有效优化方法上,例如,DKD [8]通过用一个常数值替换与教师置信度负相关的系数,从而将损失从整体软目标蒸馏中解耦出来,提高了对预测良好样本的蒸馏效果。DML [13]通过同时训练两个小型学生网络,利用分类损失函数和模仿对方网络的损失函数,实现学生网络的类别后验对齐。TAKD [12]则引入了一个名为“教师助理”的中型网络,旨在缩小教师和学生之间的差距。MLKD [14]通过实例、批次和类级别的多级预测对齐,优化了知识转移过程。CTKD [15]通过动态调整学习阶段的任务难度,实现了渐进式学习。CTKD以易到难的课程形式,逐步提升蒸馏损失的温度,从而增加学习任务的难度。此外,还有一些研究[16] [17]侧重于对经典知识蒸馏方法进行解释。基于软目标的知识蒸馏算法简单,可以与其他知识蒸馏方法结合,但单独使用的话,模型的性能提升有限。除此之外,这种知识蒸馏方法依赖于softmax函数,只能用于分类相关的任务,并与类别的数量有关。最后,这些方法不能应用于无标签任务。
为了进一步促进知识蒸馏,提出了特征蒸馏这一新的研究方向,该方法对中间特征而不是软目标输出进行蒸馏。具体而言,FitNets [10]扩展了Hinton提出的知识蒸馏(KD) [8]方法,通过利用教师模型特征提取器的中间层输出作为提示,结合KD对更深且更窄的学生模型进行知识传递。AT [21]通过提取复杂模型生成的注意力图来指导简单模型,使其生成的注意力图与复杂模型相似。与FitNet不同,AT采用了多个层来进行知识转移,来捕获多层次信息,从而更好地提高学生的性能。MGD [23]通过随机屏蔽学生模型特征的像素并重构教师模型的完整特征来增强学生模型的表征能力。CRD [20]通过对比性学习来训练学生模型,使其能在教师的数据表述中捕捉到更多信息的新目标。其他研究[25]则通过提取输入的相关性来传递教师的知识。此外,还有一些研究通过同时对中间特种和软目标输出一同蒸馏,SAKD [26]在整个蒸馏期间的每次训练迭代中自适应地确定每个样本的教师网络中的中间特征和软目标层的蒸馏点。CAT-KD [27]引入了类别注意力转移(CAT)和类激活映射(CAM)转移,通过转移它们来增强学生模型的性能。研究者们提出了很多有效的基于特征知识的蒸馏方法,但是如何选取教师模型的暗示依然是一个值得探究的问题。比较常用的选取策略是跨层选取,例如12层模型蒸馏到4层模型,选取策略为每三层的输出选取一层作为暗示。其他有效的选取策略有待进一步挖掘。
与之前只关注单个样本的输出结果不同,RKD [24]将输出样本之间的结构关系迁移给学生。通过将同批次的两个样本之间的距离关系以及三个样本之间的角度关系作为知识传递给学生,让学生学习教师模型的结构化信息。
3. 方法
3.1. 知识蒸馏
对于包含C类的样本,其分类概率可以表示为
,其中
表示第
类的概率,C为该数据集类别数。P中的每个元素都可以通过softmax函数获得:
(1)
其中
表示第i类的logit值,T表示用于温度缩放的超参数。在知识蒸馏中,T通常大于1.0,这有助于帮助学生更好的学习教师网络。知识蒸馏的过程是通过温度变化控制教师网络输出,最小化教师模型和学生模型输出之间的KL散度来实现蒸馏。
(2)
对于不同的任务,模型的架构差异很大。此外,大多数蒸馏方法都是为特定任务而设计的。然而,基于特征的蒸馏可以应用于分类和密集预测。特征蒸馏的基本方法可以表述为:
(3)
其中
和
分别表示教师和学生的特征,
是学生的特征与教师的特征的对齐的适应层。C,H,W表示特征图的形状。
在本节中,主要讨论了基于特征的知识蒸馏的机制,整体结构如图1所示。以往的特征蒸馏方法可以指导学生直接模仿老师的特征,然而,本文提出了掩码生成蒸馏,指导学生生成老师的特征而不是模仿他。此外,本文提出了一种新的解耦金字塔池化蒸馏方法定义知识,以及利用解耦特征图来优化知识蒸馏(KD)训练过程,从而更充分地利用中间层知识。
Figure 1. General structure diagram
图1. 总体结构图
3.2. 掩码特征生成
对于CNN模型来说,深层次的特征具有更大的感受野和更好的原始输入图像表示。也就是说,特征像素已经一定程度上包含了相邻像素的信息。因此,使用部分像素来恢复完整的特征图的方式是可以的,本文的方法旨在通过学生的屏蔽特征生成教师的特征,这可以帮助学生获得更好的表示。图2为本文提出的掩码特征生成结构图。
Figure 2. Mask feature generation structure diagram
图2. 掩码特征生成结构图
本文分别将教师和学生的第l个特征图表示为
和
。首先,设置第l个随机掩码来覆盖学生的第l个特征,其表示为:
(4)
其中
是(0, 1)中的随机数,i,j分别是特征图的水平和垂直坐标。λ是一个超参数,表示掩码的比例。第l个特征图将会被第l个随机掩码所覆盖。
然后我们使用相应的掩码来覆盖学生的特征图,并尝试用左侧像素生成教师的特征图,具体公式如下:
(5)
(6)
其中,
表示为包括两个卷积层Wl1,Wl2和一个激活层ReLU的投影层。在本文中,适应层被设置为1 × 1的卷积层,Wl1,Wl2采用3 × 3的卷积层。
根据以上方法,本文的掩码特征生成蒸馏的损失可以被表示为:
(7)
其中L为蒸馏层的数目,C,H,W表示特征图的形状大小。S和T分别表示学生和教师的特征。
3.3. 解耦金字塔池知识蒸馏
本文提出的DSPP架构如图3所示,学生网络和教师网络通过解耦的空间金字塔池相互交互。从文献[8]中能够得知,logit实际上是类别的概率分布,其对于学生网络过于抽象而使得学生网络无法获得全面的信息性知识。此外,由于模型大小的不同,找到适当匹配的教师和学生模型的提示层是非常困难的。由于以上原因,本文选择使用最后一个提示层来解决logit高度抽象问题以及匹配提示层的复杂问题。在最后一层提示层中,本文使用了转换的操作对齐了学生和教师的最后一层提示层。受到文献[28]的启发,本文引入了一种新的方法来定义提示层中的知识,应用空间金字塔池来捕获不同尺度的提示层中的知识,从而解决了最后一层提示层可能存在的知识过于集中而不全面的问题。此外,本文还提出了一个解耦模块来提高了空间金字塔中较低激活区域的重要性。
Figure 3. Decoupled pyramid pool knowledge distillation structure diagram
图3. 解耦金字塔池知识蒸馏结构图
3.3.1. 空间金字塔池
空间金字塔池化技术最初在文献[11]中被引入至视觉识别领域,该创新策略显著地解除了卷积神经网络对于固定输入尺寸的依赖。鉴于教师模型与学生模型在架构设计上存在的差异,本文采用了空间金字塔池化方法,以应对两者在最后一个特征层(提示层)上形状不匹配的问题。进一步地,空间金字塔池化通过其分层结构,为最后一个特征层引入了多尺度的感受野,这一特性赋予了本文模型从该特征层中同时捕获全局与局部知识的能力。具体而言,空间金字塔池化的实施过程可形式化描述如下:
(8)
其中,L表示输入提示层,W(·)表示计算L的长度函数,k为金字塔层的层数,k共有n层。函数Pooling(·)中总包含两个参数,输入特征图以及池化核的大小。
众所周知,学生模型的logit来自全连接层,与教师模型是高度相似,即它们在一个数据集上的预测是相同的。然而,对于图像分类任务,全连接层缺少了输入图像的二维或三维空间信息。如前所述,很难找到教师和学生模型的提示层的适当匹配,并且这可能会降低KD的可解释性,这是本文选择了最后一个提示层的动机。此外,全连接层直接从最后一个提示层计算得出,该层在理论上最接近所有提示层中的logit。
3.3.2. 解耦模块
为了更多的关注较低的激活区域,本文提出了一个解耦模块来处理空间金字塔池化的扁平化特征。在解耦模块中,根据特征中的每个元素的值将扁平化的特征解耦为两个组件。如图4所示,学生特征Vs通过双向箭头与教师特征Vt进行元素匹配。红色箭头指向Vt中的n个最大元素,其另一端指向Vs中相应的位置。相反,蓝色箭头指向Vt中的最后一个尾部(N − n)个元素,其中N表示Vs或Vt的长度,SPP的损失可以计算为:
Figure 4. Decoupling module structure diagram
图4. 解耦模块结构图
(9)
其中top(·)表示Vt中top-(·)元素的索引,tail(·)表示Vt中tail-(·)元素的索引,函数L2(·)表示L2范数距离。θ和μ是控制解耦权重的超参数。为了提高较低激活区域的重要性,本文让μ大于θ。本文方法更加关注较低激活区域的原因是,具有大量参数的强大教师模型可能具有更复杂的机制来查找反应较低激活区域的输入的更多细节。较低的激活区域有助于提高学生模型的准确性和泛化性。
综上所述,本文将DSPP应用到知识蒸馏任务中去,并将其与本文提出的掩码特征生成蒸馏相结合,其公式如下:
(10)
其中
代表目标与仅来自学生模型预测之间的广泛使用的交叉熵损失,α,β分别用于平衡
和
的权重。
通过将损失的三个部分整合到一起,本文的方法,不仅通过学生的屏蔽特征生成教师的特征,帮助学生获得更好的表示,也应用解耦的语义损失分配来提高在KD中发挥更重要作用的较低激活区域的权重,旨在减轻学生网络的训练难度。通过该方法,可以让学生更加学习更加丰富的教师知识,对于学生网络的性能提升起着重要的作用。
4. 实验
4.1. 数据集与设置
在本文的实验中,我们对图像分类的性能进行了评估。数据集方面,本文选择了两个广泛研究的数据集:1) CIFAR-100 [29],这是一个著名的图像分类数据集,包含100个类别的32 × 32像素的图片,其中50000张图片作为训练集,10000张作为验证集。2) Tiny-ImageNet [30],这是一个大规模的分类数据集,是图像分类领域最重要的基准数据集之一,包含200个类别的100000张图像(每个类别500张),缩小为64 × 64彩色图像。每个类别有500张训练图像、50张验证图像和50张测试图像。
设置方面,本文的实验着重在知识蒸馏上,具体包括了两种不同的设置:1) 同构架构,即教师模型与学生模型采用相同的模型架构,仅仅是模型层数不相同,例如ResNet56和ResNet20。2) 异构架构,即教师模型的模型架构是与学生模型完全不相同的,例如ResNet 32 × 4和ShuffleNetV2。本文的实验包括了多种神经网络架构,如ResNet [1],ShuffleNet [31],vgg [32],WRN [33],MobileNet [34]。
实验配置方面,在CIFAR-100数据集实验中,本文将batch大小设置为64,基础学习率为0.05。在ImageNet数据集实验中,本文将batch大小设置为128,基础学习率为0.01。本文使用1块Nvidia RTX 3090作为训练显卡。
4.2. 实验结果
在本文的实验中,我们评估了该方法的性能,同时与目前主流知识蒸馏方法进行了比较,包括了主流的软目标蒸馏方法以及特征蒸馏方法。表1给出了基于五种同构网络模型组合和八种KD方法的CIFAR-100上的Top-1测试精度,与我们提出的MD知识蒸馏进行了比较。其他方法的部分结果引用自文献[28]。根据表1,表明本文方法在KD的参与下始终比最先进的蒸馏方法获得更高的精度。七种异构网络模型组合的结果如表2所示。显然,当时教师-学生模型组合从同构切换到异构时,在多个中间层上构建的方法往往比提取最后基层或logit的方法表现更差。一些方法甚至可能在学生网络的训练过程中起到相反的负面作用。例如,AT和FitNet的表现甚至比普通学生还要差。这就如之前章节所述,可能是由于提示层的不匹配而造成得这种现象。
具体来说,在本研究中,我们对CIFAR-100和Tiny-ImageNet数据集进行了深入的实验分析,探讨了同构与异构架构下的知识蒸馏效果。在CIFAR-100数据集中,同构架构实验中,教师模型ResNet110的Top-1准确率为74.31%,而采用相同架构但层数减少的学生模型ResNet32的Top-1准确率为71.14%。通过与现有的ReviewKD (特征蒸馏)方法的比较,我们发现,ReviewKD方法将准确率提升至73.89%,而本研究提出的方法能将学生模型的准确率提升至74.13%。在异构架构实验中,教师模型ResNet 32 × 4的Top-1准确率为79.42%,学生模型ShuffleNetV2的原始准确率为71.82%。相较于ReviewKD (特征蒸馏)和DKD (Logit蒸馏)方法,其分别将准确率提升至77.78%和77.07%,本研究方法能显著提升至78.13%。
为了评估本文方法的泛化性能,本文同样在Tiny-ImageNet上对三种经典的师生架构进行了一系列的实验,如表3所示。结果表明本文方法优于其他方法,包括CRD和SAKD的组合,这进一步证明了本文方法的有效性。Tiny-ImageNet中的图像比CIFAR-100中的图像大两倍,特征图同样大两倍,可以提供更多的信息。因此,本文方法在Tiny-ImageNet上的性能优于CIFAR-100。
Table 1. Homogeneous architecture CIFAR-100 results
表1. 同构架构CIFAR-100结果
方法 |
教师网络 |
ResNet56 |
ResNet110 |
ResNet32×4 |
WRN-40-2 |
VGG13 |
72.34 |
74.31 |
79.42 |
75.61 |
75.61 |
学生网络 |
ResNet20 |
ResNet32 |
ResNet8×4 |
WRN-16-2 |
VGG8 |
69.06 |
71.14 |
72.50 |
73.26 |
70.36 |
软目标 |
KD |
70.66 |
73.08 |
74.92 |
73.54 |
72.98 |
DML |
69.52 |
72.03 |
73.58 |
72.68 |
71.79 |
TAKD |
70.83 |
73.37 |
75.06 |
74.33 |
73.23 |
特征 |
FitNet |
69.21 |
71.06 |
73.50 |
72.24 |
71.02 |
RKD |
69.61 |
71.82 |
73.35 |
72.22 |
71.48 |
CRD |
71.16 |
73.48 |
75.51 |
74.14 |
73.94 |
OFD |
70.98 |
73.23 |
75.48 |
74.33 |
73.95 |
ReviewKD |
71.89 |
73.89 |
75.63 |
75.09 |
74.84 |
Ours |
71.35 |
74.13 |
76.03 |
76.87 |
74.96 |
Table 2. Heterogeneous architecture CIFAR-100 results
表2. 异构架构CIFAR-100结果
方法 |
教师网络 |
ResNet 32 × 4 |
ResNet 32 × 4 |
ResNet 32 × 4 |
WRN-40-2 |
WRN-40-2 |
79.42 |
79.42 |
79.42 |
75.61 |
75.61 |
学生网络 |
ShuffleNet-V2 |
WRN-16-2 |
WRN-40-2 |
ResNet 8 × 4 |
MobileNet-V2 |
71.82 |
73.26 |
75.61 |
72.5 |
64.6 |
软目标 |
KD |
74.45 |
74.9 |
77.7 |
73.97 |
68.36 |
CTKD |
75.37 |
74.57 |
77.66 |
74.61 |
68.34 |
DKD |
77.07 |
75.7 |
78.46 |
75.56 |
69.28 |
特征 |
FitNet |
73.54 |
74.7 |
77.69 |
74.61 |
68.64 |
AT |
72.73 |
73.91 |
77.43 |
74.11 |
60.78 |
RKD |
73.21 |
74.86 |
77.82 |
75.26 |
69.27 |
CRD |
75.65 |
75.65 |
78.15 |
75.24 |
70.28 |
OFD |
76.82 |
76.17 |
79.25 |
74.36 |
69.92 |
ReviewKD |
77.78 |
76.11 |
78.96 |
74.34 |
71.28 |
SimKD |
78.39 |
77.17 |
79.29 |
75.29 |
70.1 |
Ours |
78.13 |
76.37 |
78.77 |
76.83 |
70.88 |
对于Tiny-ImageNet数据集,同构架构中教师模型ResNet34的Top-1准确率为73.31%,学生模型ResNet18的准确率原为69.75%,而本研究方法提升至71.93%,相比KD方法的70.66%表现出显著优势。在异构架构中,使用ResNet50作为教师模型,MobileNetV2作为学生模型,其原始Top-1准确率为68.87%,通过本研究方法提升至72.64%,而DKD方法的效果为72.05%。
Table 3. Tiny-ImageNet results
表3. Tiny-ImageNet结果
|
|
Top-1 |
Top-5 |
Top-1 |
Top-5 |
方法 |
教师网络 |
ResNet34 |
ResNet50 |
73.31 |
91.42 |
76.16 |
92.86 |
学生网络 |
ResNet18 |
MobileNet-V2 |
69.75 |
89.07 |
68.87 |
88.76 |
软目标 |
KD |
70.66 |
89.88 |
68.58 |
88.98 |
DML |
70.82 |
90.02 |
71.35 |
90.31 |
TAKD |
70.78 |
90.16 |
70.82 |
90.01 |
DKD |
71.7 |
90.41 |
72.05 |
91.05 |
特征 |
AT |
70.69 |
90.01 |
69.56 |
89.33 |
OFD |
70.81 |
89.98 |
71.25 |
90.34 |
CRD |
71.17 |
90.13 |
71.37 |
90.41 |
ReviewKD |
71.61 |
90.51 |
72.56 |
91 |
Ours |
71.93 |
90.38 |
72.64 |
90.82 |
4.3. 消融实验
该小节研究了本文方法中每个组成部分的贡献,包括掩码特征生成蒸馏(MFD)和解耦金字塔池知识蒸馏(DSPP),如表4所示。实验在CIFAR-100上进行,以ResNet 32 × 4,ResNet 8 × 4和ShuffleNet-V2分别作为同构和异构的学生网络,以Top-1准确率作为评价指标。当采用所有结构时,该方法的表现超过了所有其他蒸馏方法,证明了我们方法的每个部分都是不可或缺的。
Table 4. Ablation experiment
表4. 消融实验
MFD |
DSPP |
ResNet 8 × 4 |
ShuffleNet-V2 |
√ |
|
75.85 |
77.91 |
√ |
√ |
76.03 |
78.13 |
5. 结论
本文提出了一种新型知识蒸馏方法,主要包括掩码特征生成蒸馏(MFD)和解耦空间金字塔池化知识蒸馏(DSPP)。这些方法通过改进学生模型与教师模型的特征对齐方式,实现了知识蒸馏的性能提升。具体来说,MFD创新性地引入了特征生成策略,与传统的模仿教师特征的方式不同,学生模型通过屏蔽随机像素并利用简单的生成模块直接生成教师的完整特征。由于训练过程中屏蔽像素的随机性,模型得以利用全部像素,从而提升了特征的鲁棒性与表示能力。另一方面,DSPP通过解耦的空间金字塔池化操作,减少了对中间层的依赖,并有效捕获了多尺度知识。针对特征图低激活区域包含更多知识线索的观察结果,本文设计了一个解耦模块,用于分析教师和学生网络之间的区域级语义损失。结合空间金字塔池化与解耦区域损失分配的策略,DSPP实现了对学生网络的高效优化。在CIFAR-100和Tiny-ImageNet数据集上的实验结果表明,无论是同构还是异构网络配置,本文方法均显著优于现有蒸馏技术,证明了其在图像分类任务中的优越性。这种结合生成式掩码与解耦池化的蒸馏策略,为知识蒸馏领域提供了一种全新的解决思路。