一种基于类注意力传输的模型压缩方法
未命名
08-07
阅读:118
评论:0
1.本发明属于机器视觉技术领域,具体涉及一种基于类注意力传输的模型压缩方法。
背景技术:
2.神经网络模型的表现受到模型参数容量的影响,参数多的模型往往能获得更好的表现。但大模型具有对硬件要求高、运行速度慢等缺点,这使得它很难部署到小型设备及移动设备上。知识蒸馏技术旨在提高小模型的表现,它的实现方式是:在训练小模型的过程中额外让小模型模仿并学习参数容量较大的预训练模型的中间输出。其中,训练的小模型被称为学生模型,预训练的大模型被称为教师模型,由大模型提供的中间输出被称为软标签。现存的知识蒸馏方法可以根据其传输的软标签的生产方式分为三类:传输预测分数、传输模型的特征图、以及传输模型的注意力图。
3.在现存的方法中,基于传输预测分数和特征图的知识蒸馏方法具有较好的效果。然而,由于神经网络模型的预测分数值以及其中间输出的特征图仍然不能被很好的解释其含义,基于传输预测分数和特征图的知识蒸馏的方法的可解释性差。相较之下,基于传输注意力的方法具有较好的可解释性。然而,与基于传输预测分数和特征图的方法相比,现存的基于传输注意力的知识蒸馏方法的效果较差。
技术实现要素:
4.为了解决相关技术中存在的上述问题,本发明提供了一种基于类注意力传输的模型压缩方法。本发明要解决的技术问题通过以下技术方案实现:
5.本发明提供一种基于类注意力传输的模型压缩方法,包括:
6.获取待识别图像;
7.采用训练好的学生模型进行所述待识别图像中的目标物体的识别,得到所述目标物体的类别;
8.所述训练好的学生模型是采用待训练的学生模型根据训练样本输出的预测分数、所述训练样本对应的真实标签、所述待训练的学生模型自身输出的第一类注意力图,以及预训练的教师模型输出的第二类注意力图,对所述待训练的学生模型进行训练得到的;所述待训练的学生模型和所述教师模型中连接在全局池化层之后的全连接层被替换为连接在所述全局池化层之前的一维卷积核。
9.在一些实施例中,所述学生模型和所述教师模型均为用于进行目标分类的模型。
10.在一些实施例中,所述待训练的学生模型和所述教师模型均包括:卷积层、一维卷积核和全局池化层;其中,所述卷积层用于对输入的图像进行特征提取,得到所述输入的图像的特征图;所述一维卷积核用于根据所述特征图生成类注意力图;所述全局池化层用于将所述类注意力图转换为预测分数。
11.在一些实施例中,在所述采用训练好的学生模型进行所述待识别图像中的目标物
体的识别,得到所述目标物体的类别之前,所述方法还包括:
12.获取多个带有真实标签的样本图像;所述真实标签表征样本图像中的物体的真实类别;
13.获取第t次待训练的学生模型,以及所述预训练的教师模型;其中,当t为1时,所述待训练的学生模型为初始待训练的学生模型;
14.将至少一个样本图像输入所述第t次待训练的学生模型和所述预训练的教师模型,得到所述第t次待训练的学生模型输出的预测分数、所述第t次待训练的学生模型输出的所述第一类注意力图、所述预训练的教师模型输出的所述第二类注意力图;t为大于或等于1的整数;
15.根据所述第t次待训练的学生模型输出的预测分数、所述第t次待训练的学生模型输出的所述第一类注意力图、所述预训练的教师模型输出的所述第二类注意力图,以及所述至少一个样本图像的真实标签,确定第t次的损失值;
16.根据所述第t次的损失值进行反向传播并更新所述第t次待训练的学生模型的参数,得到第t+1次待训练的学生模型,如此迭代,直至满足预设条件时停止训练,得到训练好的待转换学生模型;
17.将所述训练好的待转换学生模型中连接在所述全局池化层之前的一维卷积核,替换为连接在全局池化层之后的全连接层,得到所述训练好的学生模型。
18.在一些实施例中,所述根据所述第t次待训练的学生模型输出的预测分数、所述第t次待训练的学生模型输出的所述第一类注意力图、所述预训练的教师模型输出的所述第二类注意力图,以及所述至少一个样本图像的真实标签,确定第t次的损失值,包括:
19.根据所述第t次待训练的学生模型输出的所述第一类注意力图、所述预训练的教师模型输出的所述第二类注意力图,确定第t次的知识蒸馏损失值;
20.根据所述第t次待训练的学生模型输出预测分数和所述至少一个样本图像的真实标签,确定第t次的交叉熵损失值;
21.根据所述第t次的知识蒸馏损失值和所述第t次的交叉熵损失值,确定所述第t次的损失值。
22.在一些实施例中,所述根据所述第t次待训练的学生模型输出的所述第一类注意力图、所述预训练的教师模型输出的所述第二类注意力图,确定第t次的知识蒸馏损失值,包括:
23.对所述第t次待训练的学生模型输出的所述第一类注意力图进行池化操作,得到第一池化特征;
24.对所述预训练的教师模型输出的所述第二类注意力图进行池化操作,得到第二池化特征;
25.对所述第一池化特征和所述第二池化特征分别进行l2正则化处理,得到第一处理特征和第二处理特征;
26.将所述第一处理特征和所述第二处理特征之间的均方差,作为所述第t次的知识蒸馏损失值。
27.在一些实施例中,所述根据所述第t次的知识蒸馏损失值和所述第t次的交叉熵损失值,确定所述第t次的损失值,包括:
28.对所述第t次的知识蒸馏损失值和所述第t次的交叉熵损失值进行求和,得到求和值;
29.将所述求和值作为所述第t次的损失值。
30.在一些实施例中,所述获取第t次待训练的学生模型,包括:
31.当t为1时,进行学生模型的初始化,得到所述初始待训练的学生模型。
32.本发明具有如下有益技术效果:
33.通过上述方法,使得训练得到的学生模型能够捕捉到图中更多的与目标类别有关的辨识区域,在具有高度的可解释性的同时提高了模型的分类准确度。
34.以下将结合附图及实施例对本发明做进一步详细说明。
附图说明
35.图1为本发明实施例提供的一种基于类注意力传输的模型压缩方法的一个流程图;
36.图2为本发明实施例提供的示例性的本发明提出的模型的结构(converted structure)以及该模型结构从普通模型(normoal cnn)的转化过程示意图;
37.图3为本发明实施例提供的示例性的确定知识蒸馏损失值的原理示意图;
38.图4为本发明实施例的训练好的学生模型得到的类注意力图与普通训练模型的类注意力图的对比示意图;
39.图5为本发明实施例提供的示例性的多个其他方法和本发明提出的方法的训练消耗及效果对比示意图。
具体实施方式
40.下面结合具体实施例对本发明做进一步详细的描述,但本发明的实施方式不限于此。
41.在本发明的描述中,术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”的特征可以明示或者隐含地包括一个或者更多个该特征。在本发明的描述中,“多个”的含义是两个或两个以上,除非另有明确具体的限定。
42.在本说明书的描述中,参考术语“一个实施例”、“一些实施例”、“示例”、“具体示例”、或“一些示例”等的描述意指结合该实施例或示例描述的具体特征、结构、材料或者特点包含于本发明的至少一个实施例或示例中。在本说明书中,对上述术语的示意性表述不必须针对的是相同的实施例或示例。而且,描述的具体特征、结构、材料或者特点可以在任何的一个或多个实施例或示例中以合适的方式结合。此外,本领域的技术人员可以将本说明书中描述的不同实施例或示例进行接合和组合。
43.尽管在此结合各实施例对本发明进行了描述,然而,在实施所要求保护的本发明过程中,本领域技术人员通过查看所述附图、公开内容、以及所附权利要求书,可理解并实现所述公开实施例的其他变化。在权利要求中,“包括”(comprising)一词不排除其他组成部分或步骤,“一”或“一个”不排除多个的情况。单个处理器或其他单元可以实现权利要求中列举的若干项功能。相互不同的从属权利要求中记载了某些措施,但这并不表示这些措
施不能组合起来产生良好的效果。
44.图1是本发明实施例提供的基于类注意力传输的模型压缩方法的一个流程图,如图1所示,所述方法包括以下步骤:
45.s101、获取待识别图像。
46.s102、采用训练好的学生模型进行待识别图像中的目标物体的识别,得到目标物体的类别;训练好的学生模型是采用待训练的学生模型根据训练样本输出的预测分数、训练样本对应的真实标签、待训练的学生模型自身输出的第一类注意力图,以及预训练的教师模型输出的第二类注意力图,对待训练的学生模型进行训练得到的;待训练的学生模型和教师模型中连接在全局池化层之后的全连接层被替换为连接在全局池化层之前的一维卷积核。
47.这里,学生模型和教师模型均为用于进行目标分类的模型。
48.示例性的,待训练的学生模型和预训练的教师模型均包括:卷积层、一维卷积核和全局池化层;其中,卷积层用于对输入的图像进行特征提取,得到输入的图像的特征图;一维卷积核用于根据特征图生成类注意力图;全局池化层用于将类注意力图转换为预测分数。
49.示例性的,训练好的学生模型包括:卷积层、全局池化层和全连接层。
50.如图2所示,在图像分类任务中,主流模型通常使用卷积神经网络(convolutional neural network,cnn)先来提取特征,由cnn产生的多个特征图再被全局池化(global average pooling,gap)成预测分数后,会被输入到一个简单的全连接(full connectivity,fc)层来执行分类任务。而本发明在训练学生模型时,通过将全连接层转化为一维卷积核(1
×
1conv)后,就可以在cnn的前向传播过程中得到类注意力图(class activation mapping,cam)(例如,图2中的k个类注意力图)。
51.在一些实施例中,在s102之前,还包括以下步骤:
52.s201、获取多个带有真实标签的样本图像;真实标签表征样本图像中的物体的真实类别。
53.s202、获取第t次待训练的学生模型,以及预训练的教师模型;其中,当t为1时,待训练的学生模型为初始待训练的学生模型。
54.这里,当t为1时,进行学生模型的初始化,得到初始待训练的学生模型。
55.s203、将至少一个样本图像输入第t次待训练的学生模型和预训练的教师模型,得到第t次待训练的学生模型输出的预测分数、第t次待训练的学生模型输出的第一类注意力图、预训练的教师模型输出的第二类注意力图;t为大于或等于1的整数。
56.具体的,每次训练时,可以随机选择一个或多个样本图像,并将选择的样本图像输入这一次待训练的学生模型,得到这一次待训练的学生模型中间输出的类注意力图(称为第一类注意力图),以及这一次待训练的学生模型输出的预测分数;以及将选择的样本图像输入预训练的教师模型,得到预训练的教师模型,得到预训练的教师模型输出的类注意力图(称为第二类注意力图)。
57.s204、根据第t次待训练的学生模型输出的预测分数、第t次待训练的学生模型输出的第一类注意力图、预训练的教师模型输出的第二类注意力图,以及至少一个样本图像的真实标签,确定第t次的损失值。
58.这里,可以根据第t次待训练的学生模型输出的第一类注意力图、预训练的教师模型输出的第二类注意力图,确定第t次的知识蒸馏损失值;根据第t次待训练的学生模型输出预测分数和至少一个样本图像的真实标签,确定第t次的交叉熵损失值;对第t次的知识蒸馏损失值和第t次的交叉熵损失值进行求和,得到求和值;将求和值作为第t次的损失值。
59.具体的,确定第t次的知识蒸馏损失值的原理如下:对第t次待训练的学生模型输出的第一类注意力图进行池化操作,得到第一池化特征;对预训练的教师模型输出的第二类注意力图进行池化操作,得到第二池化特征;对第一池化特征和第二池化特征分别进行l2正则化处理,得到第一处理特征和第二处理特征;将第一处理特征和第二处理特征之间的均方差,作为第t次的知识蒸馏损失值。
60.示例性的,图3为确定知识蒸馏损失值的原理图。如图3所示,相同的样本图像分别经过预训练的教师模型和待训练的学生模型后,得到预训练的教师模型中间输出的类注意力图transferred cams,以及待训练的学生模型中间输出的类注意力图cams,对transferred cams和cams分别进行池化(pooling)处理后,得到预训练的教师模型对应的尺寸为w*h*k的池化特征(以下称为第一池化特征)和待训练的学生模型对应的尺寸为w*h*k的池化特征(以下称为第二池化特征),对第一池化特征和第二池化特征分别进行正则化(normalization)出后,分别得到预训练的教师模型对应的经过正则化的特征(以下称为第一处理特征)和待训练的学生模型对应的经过正则化的特征(以下称为第二处理特征),计算第一处理特征和第二处理特征之间的均方差,将该方差作为知识蒸馏损失值(cat loss)。
61.s205、根据第t次的损失值进行反向传播并更新第t次待训练的学生模型的参数,得到第t+1次待训练的学生模型,如此迭代,直至满足预设条件时停止训练,得到训练好的待转换学生模型。
62.这里,当训练次数达到预设次数,或者某一次的损失值达到预设阈值时,确定满足预设条件,此时可以停止训练,并将最后一次训练得到的学生模型作为训练好的待转换学生模型。
63.s206、将训练好的待转换学生模型中连接在全局池化层之前的一维卷积核,替换为连接在全局池化层之后的全连接层,得到训练好的学生模型。
64.这里,由于得到的训练好的待转换学生模型的结构如图2中的converted structure所示,因而,通过将训练好的待转换学生模型中连接在全局池化层之前的一维卷积核,替换为连接在全局池化层之后的全连接层,可以得到训练好的学生模型,并且,这一转变不会改变模型输出的预测分数。
65.通过上述方法,本发明使得训练得到的学生模型能够捕捉到图中更多的与目标类别有关的辨识区域,在具有高度的可解释性的同时提高了模型的分类准确度。
66.以下通过实验数据,对本发明实施例所能达到的技术效果进行进一步说明。
67.图4为本方法(cat-kd)的训练好的学生模型得到的类注意力图与普通训练模型的类注意力图的对比示意图。图4中的第一行的三个图为普通训练模型的类注意力图的可视化,图4中的第二行的三个图为本方法的训练好的学生模型的类注意力图的可视化。如图4所示,本发明预的训练的学生模型能够捕捉到图中更多的与目标类别有关的辨识区域。
68.以下表1为本方法(cat-kd)与其他知识蒸馏方法(kd、dkd、crd、ofd、fitnet、rkd、
review kd、at)在cifar100数据集的效果比较。显然,本方法的效果更优。表2为对多个其他方法训练的模型和本方法训练的模型进行迁移学习所得到的结果比较。显然,本方法训练的模型有更好的泛化性。
[0069][0070][0071]
表1
[0072][0073]
表2
[0074]
图5为多个其他方法和本方法的训练消耗及效果对比示意图。显然,本方法是最优的。
[0075]
以上内容是结合具体的优选实施方式对本发明所作的进一步详细说明,不能认定本发明的具体实施只局限于这些说明。对于本发明所属技术领域的普通技术人员来说,在不脱离本发明构思的前提下,还可以做出若干简单推演或替换,都应当视为属于本发明的
保护范围。
技术特征:
1.一种基于类注意力传输的模型压缩方法,其特征在于,包括:获取待识别图像;采用训练好的学生模型进行所述待识别图像中的目标物体的识别,得到所述目标物体的类别;所述训练好的学生模型是采用待训练的学生模型根据训练样本输出的预测分数、所述训练样本对应的真实标签、所述待训练的学生模型自身输出的第一类注意力图,以及预训练的教师模型输出的第二类注意力图,对所述待训练的学生模型进行训练得到的;所述待训练的学生模型和所述教师模型中连接在全局池化层之后的全连接层被替换为连接在所述全局池化层之前的一维卷积核。2.根据权利要求1所述的基于类注意力传输的模型压缩方法,其特征在于,所述学生模型和所述教师模型均为用于进行目标分类的模型。3.根据权利要求1所述的基于类注意力传输的模型压缩方法,其特征在于,所述待训练的学生模型和所述教师模型均包括:卷积层、一维卷积核和全局池化层;其中,所述卷积层用于对输入的图像进行特征提取,得到所述输入的图像的特征图;所述一维卷积核用于根据所述特征图生成类注意力图;所述全局池化层用于将所述类注意力图转换为预测分数。4.根据权利要求1所述的基于类注意力传输的模型压缩方法,其特征在于,在所述采用训练好的学生模型进行所述待识别图像中的目标物体的识别,得到所述目标物体的类别之前,所述方法还包括:获取多个带有真实标签的样本图像;所述真实标签表征样本图像中的物体的真实类别;获取第t次待训练的学生模型,以及所述预训练的教师模型;其中,当t为1时,所述待训练的学生模型为初始待训练的学生模型;将至少一个样本图像输入所述第t次待训练的学生模型和所述预训练的教师模型,得到所述第t次待训练的学生模型输出的预测分数、所述第t次待训练的学生模型输出的所述第一类注意力图、所述预训练的教师模型输出的所述第二类注意力图;t为大于或等于1的整数;根据所述第t次待训练的学生模型输出的预测分数、所述第t次待训练的学生模型输出的所述第一类注意力图、所述预训练的教师模型输出的所述第二类注意力图,以及所述至少一个样本图像的真实标签,确定第t次的损失值;根据所述第t次的损失值进行反向传播并更新所述第t次待训练的学生模型的参数,得到第t+1次待训练的学生模型,如此迭代,直至满足预设条件时停止训练,得到训练好的待转换学生模型;将所述训练好的待转换学生模型中连接在所述全局池化层之前的一维卷积核,替换为连接在全局池化层之后的全连接层,得到所述训练好的学生模型。5.根据权利要求4所述的基于类注意力传输的模型压缩方法,其特征在于,所述根据所述第t次待训练的学生模型输出的预测分数、所述第t次待训练的学生模型输出的所述第一类注意力图、所述预训练的教师模型输出的所述第二类注意力图,以及所述至少一个样本图像的真实标签,确定第t次的损失值,包括:根据所述第t次待训练的学生模型输出的所述第一类注意力图、所述预训练的教师模
型输出的所述第二类注意力图,确定第t次的知识蒸馏损失值;根据所述第t次待训练的学生模型输出预测分数和所述至少一个样本图像的真实标签,确定第t次的交叉熵损失值;根据所述第t次的知识蒸馏损失值和所述第t次的交叉熵损失值,确定所述第t次的损失值。6.根据权利要求5所述的基于类注意力传输的模型压缩方法,其特征在于,所述根据所述第t次待训练的学生模型输出的所述第一类注意力图、所述预训练的教师模型输出的所述第二类注意力图,确定第t次的知识蒸馏损失值,包括:对所述第t次待训练的学生模型输出的所述第一类注意力图进行池化操作,得到第一池化特征;对所述预训练的教师模型输出的所述第二类注意力图进行池化操作,得到第二池化特征;对所述第一池化特征和所述第二池化特征分别进行l2正则化处理,得到第一处理特征和第二处理特征;将所述第一处理特征和所述第二处理特征之间的均方差,作为所述第t次的知识蒸馏损失值。7.根据权利要求5所述的基于类注意力传输的模型压缩方法,其特征在于,所述根据所述第t次的知识蒸馏损失值和所述第t次的交叉熵损失值,确定所述第t次的损失值,包括:对所述第t次的知识蒸馏损失值和所述第t次的交叉熵损失值进行求和,得到求和值;将所述求和值作为所述第t次的损失值。8.根据权利要求5所述的基于类注意力传输的模型压缩方法,其特征在于,所述获取第t次待训练的学生模型,包括:当t为1时,进行学生模型的初始化,得到所述初始待训练的学生模型。
技术总结
本发明公开了一种基于类注意力传输的模型压缩方法,包括:获取待识别图像;采用训练好的学生模型进行所述待识别图像中的目标物体的识别,得到所述目标物体的类别;所述训练好的学生模型是采用待训练的学生模型根据训练样本输出的预测分数、所述训练样本对应的真实标签、所述待训练的学生模型自身输出的第一类注意力图,以及预训练的教师模型输出的第二类注意力图,对所述待训练的学生模型进行训练得到的;所述待训练的学生模型和所述教师模型中连接在全局池化层之后的全连接层被替换为连接在所述全局池化层之前的一维卷积核。本发明与其他模型压缩方法相比具有更高的可解释性,并且,本发明能够提高模型的分类准确度。本发明能够提高模型的分类准确度。本发明能够提高模型的分类准确度。
技术研发人员:李晖 郭子尧 闫浩楠 赵兴文
受保护的技术使用者:西安电子科技大学
技术研发日:2023.03.29
技术公布日:2023/8/5
版权声明
本文仅代表作者观点,不代表航空之家立场。
本文系作者授权航家号发表,未经原创作者书面授权,任何单位或个人不得引用、复制、转载、摘编、链接或以其他任何方式复制发表。任何单位或个人在获得书面授权使用航空之家内容时,须注明作者及来源 “航空之家”。如非法使用航空之家的部分或全部内容的,航空之家将依法追究其法律责任。(航空之家官方QQ:2926969996)
飞行汽车 https://www.autovtol.com/
