图像分类模型可信训练及图像分类方法、装置、设备与流程

未命名 08-12 阅读:135 评论:0


1.本技术涉及图像分类技术领域,特别涉及图像分类模型可信训练及图像分类方法、装置、设备。


背景技术:

2.在深度学习领域,知识可以理解为输入与输出之间隐含的映射关系。而知识蒸馏就是从大模型(或者模型集合)向小模型迁移知识的过程,其中大模型称为老师模型,而小模型称为学生模型。
3.在图像多分类任务中知识蒸馏常用做法是将样本同时输入到(已经训练好的)老师模型和学生模型,让学生模型学会老师模型中的知识。但是,存在学生模型进行图像分类时,精度不高的问题。


技术实现要素:

4.有鉴于此,本技术的目的在于提供图像分类模型可信训练及图像分类方法、装置、设备,能够提升学生模型的图像分类精度。其具体方案如下:
5.第一方面,本技术公开了一种图像分类模型可信训练方法,包括:
6.将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量;其中,所述老师模型为训练后的图像分类模型;
7.若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失;
8.基于所述几何结构关系损失更新所述学生模型的参数;
9.当参数更新后的学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。
10.可选的,还包括:
11.确定任意两个图像训练样本的老师模型特征向量对应的第一向量距离以及学生模型特征向量对应的第二向量距离;
12.若所述任意两个图像训练样本为同类样本,且所述第一向量距离小于所述第二向量距离,则判定所述任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信;
13.若所述任意两个图像训练样本为不同类样本,且所述第一向量距离大于所述第二向量距离,则判定所述任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信。
14.可选的,所述几何结构关系损失包括距离关系损失,相应的,所述利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,包
括:
15.利用公式loss2=l
eq
+l
neq
计算距离关系损失;
16.其中,loss2表示距离关系损失,l
eq
=mseloss(d
si,j
,d
ti,j
),l
neq
=mseloss(d
si,j
,d
ti,j
),并且,l
eq
表示同类样本之间的距离损失,l
neq
表示不同类样本之间的距离损失,i、j分别表示第i个图像训练样本、第j个图像训练样本,d
ti,j
表示第一向量距离,d
si,j
表示第二向量距离,mseloss为平均平方误差损失函数。
17.可选的,所述几何结构关系损失包括角度关系损失,相应的,所述利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,包括:
18.确定与所述任意两个图像训练样本不属于相同分类的图像训练样本作为锚点图像训练样本;
19.利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量以及所述锚点图像训练样本对应的老师模型特征向量计算角度关系损失。
20.可选的,所述利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,包括:
21.利用公式loss3=l’eq
+l’neq
计算角度关系损失;
22.其中,loss3表示角度关系损失,l’eq
=mseloss(a
si,j,k
,a
ti,j,k
),l’neq
=mseloss(a
si,j,k
,a
ti,j,k
),a
ti,j,k
=cos(embd
ti-embd
tk
,embd
tj-embd
tk
),a
si,j,k
=cos(embd
si-embd
tk
,embd
sj-embd
tk
),并且,l’eq
表示同类的第i个图像训练样本与j个图像训练样本之间的角度损失,l’neq
表示不同类的第i个图像训练样本与j个图像训练样本之间的角度损失,k表示锚点样本,a
si,j,k
表示学生模型对应的角度关系,a
ti,j,k
表示老师模型对应的角度关系,embd
t
表示老师模型输出的特征向量,embds表示学生模型输出的特征向量,mseloss为平均平方误差损失函数。
23.可选的,所述基于所述几何结构关系损失更新所述学生模型的参数,包括:
24.基于特征向量损失和所述几何结构关系损失计算综合训练损失;其中,特征向量损失的计算公式为:loss1=mseloss(embds,embd
t
);
25.基于所述综合训练损失更新所述学生模型的参数。
26.第二方面,本技术公开了一种图像分类方法,包括:
27.获取待分类图像;
28.将所述待分类图像输入目标图像分类模型,得到图像分类结果;
29.其中,所述目标图像分类模型为基于前述的图像分类模型可信训练方法训练得到。
30.第三方面,本技术公开了一种图像分类模型可信训练装置,包括:
31.特征向量获取模块,用于将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量;其中,所述老师模型为训练后的图像分类模型;
32.关系损失计算模块,用于若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失;
33.模型参数更新模块,用于基于所述几何结构关系损失更新所述学生模型的参数;
34.分类模型确定模块,用于当参数更新后的所述学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。
35.第四方面,本技术公开了一种电子设备,包括存储器和处理器,其中:
36.所述存储器,用于保存计算机程序;
37.所述处理器,用于执行所述计算机程序,以实现前述的图像分类模型可信训练方法,和/或如前述的图像分类方法。
38.第五方面,本技术公开了一种计算机可读存储介质,用于保存计算机程序,其中,所述计算机程序被处理器执行时实现前述的图像分类模型可信训练方法,和/或如前述的图像分类方法。
39.可见,本技术将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量,其中,所述老师模型为训练后的图像分类模型,若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,之后基于所述几何结构关系损失更新所述学生模型的参数,当参数更新后的学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。也即,本技术将图像训练样本同时输入训练得到的老师模型以及学生模型,在任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信时,利用这两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,基于损失更新学生模型,以使学生模型拟合老师模型学习出来的几何结构关系,这样,给知识蒸馏添加限制,避免了在学生模型特征向量之间的几何结构关系可信时,学生模型仍拟合老师模型学习出来的几何结构关系,能够提升学生模型的图像分类精度。
附图说明
40.为了更清楚地说明本技术实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本技术的实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据提供的附图获得其他的附图。
41.图1为本技术公开的一种图像分类模型可信训练方法流程图;
42.图2为本技术公开的一种具体的模型输入输出示意图;
43.图3为本技术公开的一种具体的图像分类方法示意图;
44.图4为本技术公开的一种图像分类模型可信训练装置结构示意图;
45.图5为本技术公开的一种电子设备结构图。
具体实施方式
46.下面将结合本技术实施例中的附图,对本技术实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本技术一部分实施例,而不是全部的实施例。基于本技术中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本技术保护的范围。
47.目前,在图像多分类任务中知识蒸馏常用做法是将样本同时输入到(已经训练好的)老师模型和学生模型,让学生模型学会老师模型中的知识。但是,存在学生模型进行图像分类时,精度不高的问题。为此,本技术提供了一种图像分类模型可信训练方案,能够提升学生模型的图像分类精度。
48.参见图1所示,本技术实施例公开了一种图像分类模型可信训练方法,包括:
49.步骤s11:将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量;其中,所述老师模型为训练后的图像分类模型。
50.在具体的实施方式中,可以先利用图像训练样本集训练一个用于图像分类的老师模型,具体为多分类,训练后老师模型的参数不再更新,图像训练样本集包括图像训练样本以及图像训练样本对应的标签信息,损失函数为交叉熵,老师模型的参数量大于学生模型。然后训练利用该图像训练样本集训练学生模型。老师模型和学生模型均包括两个模块,在前的模块是特征向量提取模块,在后模块是特征向量分类模块。知识蒸馏仅仅针对特征向量提取模块。本技术实施例中,老师模型特征向量为老师模型中特征向量提取模块输出的特征向量,学生模型特征向量为学生模型中特征向量提取模块输出的特征向量,特征向量分类模块输入的是特征向量,输出的是分类结果。老师和学生模型可以使用相同的特征向量分类模块。
51.步骤s12:若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失。
52.在具体的实施方式中,可以确定任意两个图像训练样本的老师模型特征向量对应的第一向量距离以及学生模型特征向量对应的第二向量距离;若所述任意两个图像训练样本为同类样本,且所述第一向量距离小于所述第二向量距离,则判定所述任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信;若所述任意两个图像训练样本为不同类样本,且所述第一向量距离大于所述第二向量距离,则判定所述任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信。
53.进一步的,在一种实施方式中,由于老师模型特征向量约等于学生模型特征向量,并且因为老师模型经过长时间的训练,老师模型特征向量要比学生模型向量更精准,所以在计算第二向量距离时,可以将所述任意两个图像训练样本中的任一样本对应的学生模型特征向量替换为老师模型特征向量,与另一样本对应的学生模型特征向量计算出向量距离,得到第二向量距离。
54.并且,在一种实施方式中,几何结构关系损失包括距离关系损失,可以利用公式loss2=l
eq
+l
neq
计算距离关系损失;其中,loss2表示距离关系损失,l
eq
=mseloss(d
si,j
,d
ti,j
),l
neq
=mseloss(d
si,j
,d
ti,j
),并且,l
eq
表示同类样本之间的距离损失,l
neq
表示不同类样本之间的距离损失,i、j分别表示第i个图像训练样本、第j个图像训练样本,d
ti,j
表示第一向量距离,d
si,j
表示第二向量距离,mseloss为平均平方误差损失函数。
55.另外,在一种实施方式中,所述几何结构关系损失包括角度关系损失,确定与所述任意两个图像训练样本不属于相同分类的图像训练样本作为锚点图像训练样本;利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量以及所述锚点图像
训练样本对应的老师模型特征向量计算角度关系损失。需要指出的是,如果锚点样本对应的特征向量和这两个特征向量很接近(属于相同的分类),那么锚点细微的改变都会造成这两个特征向量之间的角度发生巨大的变化;所以本技术实施例限制锚点必须和这两个特征向量属于不同的分类。
56.进一步的,在具体的实施方式中,可以利用公式loss3=l’eq
+l’neq
计算角度关系损失;其中,loss3表示角度关系损失,l’eq
=mseloss(a
si,j,k
,a
ti,j,k
),l’neq
=mseloss(a
si,j,k
,a
ti,j,k
),a
ti,j,k
=cos(embd
ti-embd
tk
,embd
tj-embd
tk
),a
si,j,k
=cos(embd
si-embd
tk
,embd
sj-embd
tk
),并且,l’eq
表示同类的第i个图像训练样本与j个图像训练样本之间的角度损失,l’neq
表示不同类的第i个图像训练样本与j个图像训练样本之间的角度损失,k表示锚点样本,a
si,j,k
表示学生模型对应的角度关系,a
ti,j,k
表示老师模型对应的角度关系,embd
t
表示老师模型输出的特征向量,embds表示学生模型输出的特征向量,mseloss为平均平方误差损失函数。
57.可以理解的是,本技术实施例中,若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系可信,则不参与损失计算。例如,参见图2所示,图2为本技术实施例公开的一种具体的模型输入输出示意图。f
t
表示老师模型,fs表示学生模型,embd
t
表示老师模型输出的特征向量,embds表示学生模型输出的特征向量,1、2、3分别表示3个样本,1和2属于相同分类,与3为不同分类。1和2样本,学生模型输出的特征向量之间的距离比老师模型更接近,学生模型输出的关系是可信的,所以学生模型不再需要模拟老师模型所输出的特征向量之间的距离和角度关系。
58.需要指出的是,同类样本的特征向量,距离越近越好;反之,属于不同分类的特征向量,它们之间的距离越远越好。如果相同分类的样本在老师模型中输出的特征向量之间的距离,比在学生模型中的远,说明使用学生模型中的特征向量,比使用老师模型中的能得到更精准的分类结果,即学生模型输出的特征向量之间的关系是可信的,反之是不可信的。如果学生模型输出的关系是可信的,再让学生模型拟合老师模型学习出来的关系,只会降低学生模型的效果。本技术实施例如果相同分类的样本在老师模型中输出的特征向量之间的距离比学生模型中的更远,或者不同分类的样本在老师模型中输出的特征向量之间的距离比学生模型中的更近的时候,使用学生模型的特征向量能得到比老师模型更精准的分类结果。此时给学生模型施加限制,不再让学生模型拟合老师模型学习出来的关系。这样,学生模型在学习老师模型特征向量之间的距离和角度关系时,受到条件的限制。当学生模型中输出的特征向量之间的关系,比老师模型输出的能让最后的分类模块表现更好时,学生模型就不再去模拟老师模型,以提升学生模型的效果。
59.步骤s13:基于所述几何结构关系损失更新所述学生模型的参数。
60.在具体的实施方式中,可以基于特征向量损失和所述几何结构关系损失计算综合训练损失;其中,特征向量损失的计算公式为:loss1=mseloss(embds,embd
t
);embd
t
表示老师模型输出的特征向量,embds表示学生模型输出的特征向量,mseloss为平均平方误差损失函数。基于所述综合训练损失更新所述学生模型的参数。
61.在一种实施方式中,可以利用特征向量损失以及特征向量损失对应的超参数、距离关系损失以及距离关系损失对应的超参数、角度关系损失以及角度关系损失对应的超参数计算综合训练损失,具体公式如下:
62.loss=αloss1+βloss2+γloss3;其中,loss表示综合训练损失,α、β、γ为超参数。
63.步骤s14:当参数更新后的学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。
64.可以理解的是,在训练过程中,本技术实施例是选取batch(批量样本)同时输入老师模型和学生模型,利用该批量样本对应的特征向量计算综合训练损失,基于综合训练损失更新学生模型的参数,完成一次迭代,之后重复迭代,直到参数更新后的学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。其中,可以基于综合训练损失更新学生模型的特征向量提取模块的参数,学生模型可以直接使用老师模型的特征向量分类模块。
65.下面以手写数字识别的数据集为例,采用本技术中的方案进行训练和测试:
66.步骤1、训练一个大的老师模型去学习一个多分类任务,选择交叉熵作为损失函数。
67.步骤2、开始训练一个小的学生模型去模拟老师模型中已经学习到的知识,具体做法如下:
68.步骤2.1、固定老师模型的参数,从数据集中取样一些样本,同时输入给老师模型和学生模型,并得到特征向量。
69.其中,embd
t
=f
t
(x),embds=fs(x);f
t
表示老师模型的特向向量提取模块,fs表示学生模型的特征向量提取模块,embd
t
表示老师模型输出的特征向量,embds表示学生模型输出的特征向量,x表示输入的图像训练样本。
70.步骤2.2、设计损失函数loss1,让学生模型的特征向量尽可能接近老师模型的特征向量。loss1=mseloss(embds,embd
t
)。
71.步骤2.3、设计损失函数loss2,在学生模型输出的关系不可信的时候,让学生模型学习老师模型特征向量之间的距离关系。具体方案如下:
72.d
ti,j
=distance(embd
ti
,embd
tj
);
73.d
si,j
=distance(embd
si
,embd
sj
);
74.其中,i和j是表示这些样本中的任意两条,i≠j,embd
ti
是第i条样本输入给老师模型所输出的特征向量,同理embd
sj
是第j条样本输入给学生模型所输出的特征向量,这两条样本在老师模型和学生模型输出的特征向量之间的距离是d
ti,j
、d
si,j
,distance是距离函数,可以选择余弦距离或者欧式距离等。由于embd
ti
≈embd
si
并且embd
ti
已经经过了老师模型长时间的训练,要比embd
si
更精准,所以在公式中选择用embd
ti
来代替embd
si
,即:d
si,j
=distance(embd
ti
,embd
sj
);于是得到损失函数loss2:
75.l
eq
=mseloss(d
si,j
,d
ti,j
),clsi=clsj且d
si,j
》d
ti,j

76.l
neq
=mseloss(d
si,j
,d
ti,j
),clsi≠clsj且d
si,j
《d
ti,j

77.loss2=l
eq
+l
neq

78.其中,cls表示样本的分类。相同分类的样本,如果学生模型输出的特征向量之间的距离比老师模型的更远,就让学生模型学习老师模型输出的特征向量之间的距离关系,l
eq
=就是对应的损失函数。不同分类的样本,如果学生模型输出的特征向量之间的距离比老师模型的更近,就让学生模型学习老师模型输出的特征向量之间的距离关系,l
neq
就是对应的损失函数。
79.步骤2.4、设计损失函数loss3,在学生模型输出的关系不可信的时候,让学生模型学习老师模型输出的特征向量之间的角度关系。要想计算两个特征向量之间的角度关系,还需要选择一个锚点embd
tk
。需要注意的是,如果锚点和这两个特征向量很接近,属于相同的分类,,那么锚点细微的改变都会造成这两个特征向量之间的角度发生巨大的变化;所以限制锚点必须和这两个特征向量属于不同的分类。
80.a
ti,j,k
=cos(embd
ti-embd
tk
,embd
tj-embd
tk
),a
si,j,k
=cos(embd
si-embd
tk

81.embd
sj-embd
tk
),i≠j≠k且clsk≠clsi且clsk≠clsj;
82.参考距离关系的损失函数,可以得到角度关系的损失函数:
83.l’eq
=mseloss(a
si,j,k
,a
ti,j,k
),clsi=clsj且d
si,j
》d
ti,j

84.l’neq
=mseloss(a
si,j,k
,a
ti,j,k
),clsi≠clsj且d
si,j
《d
ti,j

85.loss3=l’eq
+l’neq

86.步骤2.5、使用不同的超参数将三个损失函数融合在一起,作为最后的损失函数,就可以实现学生模型的训练。在使用学生模型做最后分类的时候,可以直接使用老师模型的分类模块。
87.loss=αloss1+βloss2+γloss3。
88.这样,借助老师和学生模型输出的特征向量之间的距离关系,来评估学生模型是否需要去学习老师模型输出特征向量之间的距离和角度关系。在计算两个特征向量之间的角度关系时,选择与它们不同分类的特征向量作为锚点,防止因为锚点太过接近这两个特征向量而造成角度计算出现很大的误差。通过消融实验,在手写数字识别的数据集上,通过本技术的方案,将知识从大模型迁移到小模型,使得小模型取得了更高的精度。
89.可见,本技术实施例将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量,其中,所述老师模型为训练后的图像分类模型,若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,之后基于所述几何结构关系损失更新所述学生模型的参数,当参数更新后的学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。也即,本技术实施例将图像训练样本同时输入训练得到的老师模型以及学生模型,在任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信时,利用这两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,基于损失更新学生模型,以使学生模型拟合老师模型学习出来的几何结构关系,这样,给知识蒸馏添加限制,避免了在学生模型特征向量之间的几何结构关系可信时,学生模型仍拟合老师模型学习出来的几何结构关系,能够提升学生模型的图像分类精度。
90.参见图3所示,图3为本技术实施例公开的一种图像分类方法,包括:
91.步骤s21:获取待分类图像;
92.步骤s22:将所述待分类图像输入目标图像分类模型,得到图像分类结果;
93.其中,所述目标图像分类模型为基于前述实施例所述的图像分类模型可信训练方法训练得到。
94.可见,本技术实施例获取待分类图像,将所述待分类图像输入基于前述实施例所述的图像分类模型可信训练方法训练得到的目标图像分类模型,得到图像分类结果,目标
图像分类模型在训练过程中,在任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信时,利用这两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,基于损失更新学生模型,以使学生模型拟合老师模型学习出来的几何结构关系,这样,给知识蒸馏添加限制,避免了在学生模型特征向量之间的几何结构关系可信时,学生模型仍拟合老师模型学习出来的几何结构关系,能够提升学生模型的图像分类精度。
95.参见图4所示,本技术实施例公开了一种图像分类模型可信训练装置,包括:
96.特征向量获取模块11,用于将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量;其中,所述老师模型为训练后的图像分类模型;
97.关系损失计算模块12,用于若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失;
98.模型参数更新模块13,用于基于所述几何结构关系损失更新所述学生模型的参数;
99.分类模型确定模块14,用于当参数更新后的所述学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。
100.可见,本技术实施例将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量,其中,所述老师模型为训练后的图像分类模型,若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,之后基于所述几何结构关系损失更新所述学生模型的参数,当参数更新后的学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。也即,本技术实施例将图像训练样本同时输入训练得到的老师模型以及学生模型,在任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信时,利用这两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,基于损失更新学生模型,以使学生模型拟合老师模型学习出来的几何结构关系,这样,给知识蒸馏添加限制,避免了在学生模型特征向量之间的几何结构关系可信时,学生模型仍拟合老师模型学习出来的几何结构关系,能够提升学生模型的图像分类精度。
101.进一步的,所述装置还包括可信判断模块,用于:
102.确定任意两个图像训练样本的老师模型特征向量对应的第一向量距离以及学生模型特征向量对应的第二向量距离;
103.若所述任意两个图像训练样本为同类样本,且所述第一向量距离小于所述第二向量距离,则判定所述任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信;
104.若所述任意两个图像训练样本为不同类样本,且所述第一向量距离大于所述第二向量距离,则判定所述任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信。
105.在一种具体的实施方式中,所述几何结构关系损失包括距离关系损失,相应的,关
系损失计算模块12,具体用于:
106.利用公式loss2=l
eq
+l
neq
计算距离关系损失;
107.其中,loss2表示距离关系损失,l
eq
=mseloss(d
si,j
,d
ti,j
),l
neq
=mseloss(d
si,j
,d
ti,j
),并且,l
eq
表示同类样本之间的距离损失,l
neq
表示不同类样本之间的距离损失,i、j分别表示第i个图像训练样本、第j个图像训练样本,d
ti,j
表示第一向量距离,d
si,j
表示第二向量距离,mseloss为平均平方误差损失函数。
108.在一种实施方式中,所述几何结构关系损失包括角度关系损失,相应的,关系损失计算模块12,具体用于:确定与所述任意两个图像训练样本不属于相同分类的图像训练样本作为锚点图像训练样本;利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量以及所述锚点图像训练样本对应的老师模型特征向量计算角度关系损失。
109.进一步的,关系损失计算模块12,具体用于:
110.利用公式loss3=l’eq
+l’neq
计算角度关系损失;
111.其中,loss3表示角度关系损失,l’eq
=mseloss(a
si,j,k
,a
ti,j,k
),l’neq
=mseloss(a
si,j,k
,a
ti,j,k
),a
ti,j,k
=cos(embd
ti-embd
tk
,embd
tj-embd
tk
),a
si,j,k
=cos(embd
si-embd
tk
,embd
sj-embd
tk
),并且,l’eq
表示同类的第i个图像训练样本与j个图像训练样本之间的角度损失,l’neq
表示不同类的第i个图像训练样本与j个图像训练样本之间的角度损失,k表示锚点样本,a
si,j,k
表示学生模型对应的角度关系,a
ti,j,k
表示老师模型对应的角度关系,embd
t
表示老师模型输出的特征向量,embds表示学生模型输出的特征向量,mseloss为平均平方误差损失函数。
112.所述装置还包括综合训练损失计算模块,用于:
113.基于特征向量损失和所述几何结构关系损失计算综合训练损失;其中,特征向量损失的计算公式为:loss1=mseloss(embds,embd
t
);
114.相应的,模型参数更新模块13,用于基于所述综合训练损失更新所述学生模型的参数。
115.参见图5所示,本技术实施例公开了一种电子设备20,包括处理器21和存储器22;其中,所述存储器22,用于保存计算机程序;所述处理器21,用于执行所述计算机程序,前述实施例公开的图像分类模型可信训练方法,和/或图像分类方法。
116.关于上述图像分类模型可信训练方法,和/或图像分类方法的具体过程可以参考前述实施例中公开的相应内容,在此不再进行赘述。
117.并且,所述存储器22作为资源存储的载体,可以是只读存储器、随机存储器、磁盘或者光盘等,存储方式可以是短暂存储或者永久存储。
118.另外,所述电子设备20还包括电源23、通信接口24、输入输出接口25和通信总线26;其中,所述电源23用于为所述电子设备20上的各硬件设备提供工作电压;所述通信接口24能够为所述电子设备20创建与外界设备之间的数据传输通道,其所遵循的通信协议是能够适用于本技术技术方案的任意通信协议,在此不对其进行具体限定;所述输入输出接口25,用于获取外界输入数据或向外界输出数据,其具体的接口类型可以根据具体应用需要进行选取,在此不进行具体限定。
119.进一步的,本技术实施例还公开了一种计算机可读存储介质,用于保存计算机程序,其中,所述计算机程序被处理器执行时实现前述实施例公开的图像分类模型可信训练
方法,和/或图像分类方法。
120.关于上述图像分类模型可信训练方法,和/或图像分类方法的具体过程可以参考前述实施例中公开的相应内容,在此不再进行赘述。
121.本说明书中各个实施例采用递进的方式描述,每个实施例重点说明的都是与其它实施例的不同之处,各个实施例之间相同或相似部分互相参见即可。对于实施例公开的装置而言,由于其与实施例公开的方法相对应,所以描述的比较简单,相关之处参见方法部分说明即可。
122.结合本文中所公开的实施例描述的方法或算法的步骤可以直接用硬件、处理器执行的软件模块,或者二者的结合来实施。软件模块可以置于随机存储器(ram)、内存、只读存储器(rom)、电可编程rom、电可擦除可编程rom、寄存器、硬盘、可移动磁盘、cd-rom、或技术领域内所公知的任意其它形式的存储介质中。
123.以上对本技术所提供的图像分类模型可信训练及图像分类方法、装置、设备进行了详细介绍,本文中应用了具体个例对本技术的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本技术的方法及其核心思想;同时,对于本领域的一般技术人员,依据本技术的思想,在具体实施方式及应用范围上均会有改变之处,综上所述,本说明书内容不应理解为对本技术的限制。

技术特征:
1.一种图像分类模型可信训练方法,其特征在于,包括:将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量;其中,所述老师模型为训练后的图像分类模型;若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失;基于所述几何结构关系损失更新所述学生模型的参数;当参数更新后的学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。2.根据权利要求1所述的图像分类模型可信训练方法,其特征在于,还包括:确定任意两个图像训练样本的老师模型特征向量对应的第一向量距离以及学生模型特征向量对应的第二向量距离;若所述任意两个图像训练样本为同类样本,且所述第一向量距离小于所述第二向量距离,则判定所述任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信;若所述任意两个图像训练样本为不同类样本,且所述第一向量距离大于所述第二向量距离,则判定所述任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信。3.根据权利要求2所述的图像分类模型可信训练方法,其特征在于,所述几何结构关系损失包括距离关系损失,相应的,所述利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,包括:利用公式loss2=l
eq
+l
neq
计算距离关系损失;其中,loss2表示距离关系损失,l
eq
=mseloss(d
si,j
,d
ti,j
),l
neq
=mseloss(d
si,j
,d
ti,j
),并且,l
eq
表示同类样本之间的距离损失,l
neq
表示不同类样本之间的距离损失,i、j分别表示第i个图像训练样本、第j个图像训练样本,d
ti,j
表示第一向量距离,d
si,j
表示第二向量距离,mseloss为平均平方误差损失函数。4.根据权利要求3所述的图像分类模型可信训练方法,其特征在于,所述几何结构关系损失包括角度关系损失,相应的,所述利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,包括:确定与所述任意两个图像训练样本不属于相同分类的图像训练样本作为锚点图像训练样本;利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量以及所述锚点图像训练样本对应的老师模型特征向量计算角度关系损失。5.根据权利要求4所述的图像分类模型可信训练方法,其特征在于,所述利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,包括:利用公式loss3=l’eq
+l’neq
计算角度关系损失;其中,loss3表示角度关系损失,l’eq
=mseloss(a
si,j,k
,a
ti,j,k
),l’neq
=mseloss(a
si,j,k

a
ti,j,k
),a
ti,j,k
=cos(embd
ti-embd
tk
,embd
tj-embd
tk
),a
si,j,k
=cos(embd
si-embd
tk
,embd
sj-embd
tk
),并且,l’eq
表示同类的第i个图像训练样本与j个图像训练样本之间的角度损失,l’neq
表示不同类的第i个图像训练样本与j个图像训练样本之间的角度损失,k表示锚点样本,a
si,j,k
表示学生模型对应的角度关系,a
ti,j,k
表示老师模型对应的角度关系,embd
t
表示老师模型输出的特征向量,embd
s
表示学生模型输出的特征向量,mseloss为平均平方误差损失函数。6.根据权利要求5所述的图像分类模型可信训练方法,其特征在于,所述基于所述几何结构关系损失更新所述学生模型的参数,包括:基于特征向量损失和所述几何结构关系损失计算综合训练损失;其中,特征向量损失的计算公式为:loss1=mseloss(embd
s
,embd
t
);基于所述综合训练损失更新所述学生模型的参数。7.一种图像分类方法,其特征在于,包括:获取待分类图像;将所述待分类图像输入目标图像分类模型,得到图像分类结果;其中,所述目标图像分类模型为基于权利要求1至6任一项所述的图像分类模型可信训练方法训练得到。8.一种图像分类模型可信训练装置,其特征在于,包括:特征向量获取模块,用于将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量;其中,所述老师模型为训练后的图像分类模型;关系损失计算模块,用于若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失;模型参数更新模块,用于基于所述几何结构关系损失更新所述学生模型的参数;分类模型确定模块,用于当参数更新后的所述学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。9.一种电子设备,其特征在于,包括存储器和处理器,其中:所述存储器,用于保存计算机程序;所述处理器,用于执行所述计算机程序,以实现如权利要求1至6任一项所述的图像分类模型可信训练方法,和/或如权利要求7所述的图像分类方法。10.一种计算机可读存储介质,其特征在于,用于保存计算机程序,其中,所述计算机程序被处理器执行时实现如权利要求1至6任一项所述的图像分类模型可信训练方法,和/或如权利要求7所述的图像分类方法。

技术总结
本申请公开了图像分类模型可信训练及图像分类方法、装置、设备,包括:将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量;其中,所述老师模型为训练后的图像分类模型;若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失;基于所述几何结构关系损失更新所述学生模型的参数;当参数更新后的学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。这样,能够提升学生模型的图像分类精度。模型的图像分类精度。模型的图像分类精度。


技术研发人员:刘伟华 严宇 左勇
受保护的技术使用者:智慧眼科技股份有限公司
技术研发日:2023.05.30
技术公布日:2023/8/9
版权声明

本文仅代表作者观点,不代表航空之家立场。
本文系作者授权航家号发表,未经原创作者书面授权,任何单位或个人不得引用、复制、转载、摘编、链接或以其他任何方式复制发表。任何单位或个人在获得书面授权使用航空之家内容时,须注明作者及来源 “航空之家”。如非法使用航空之家的部分或全部内容的,航空之家将依法追究其法律责任。(航空之家官方QQ:2926969996)

飞行汽车 https://www.autovtol.com/

分享:

扫一扫在手机阅读、分享本文

相关推荐