一种多类中心动态对比学习方法及系统
未命名
10-09
阅读:84
评论:0
1.本发明涉及无监督域适应行人重识别技术领域,具体的是一种多类中心动态对比学习方法及系统。
背景技术:
2.近几年来,随着智能监控领域的不断发展,行人重识别问题逐渐受到越来越多的学者的广泛关注。行人重识别技术主要研究如何将不同摄像头下相同身份的行人图像进行关联,即给定一个查询图像,在不同监控设备的图像库检索出同一个行人。随着深度学习技术不断发展,有监督行人重识别技术已经有了突破性的进展。然而,在实际应用场景中,有监督行人重识别任务的数据标注工作往往需要花费大量的人力和财力,这使得有监督行人重识方法难以扩展到现实场景中。其次,在现实场景中,研究者们能够轻松获得大量无标记的行人数据。因此,在行人重识别问题的研究中,如何利用少量有标记的数据和大量无标记的数据来训练得到鲁邦的模型是行人重识别任务中的热点研究问题。为了解决应对这个问题,无监督域适应行人重识别范式被提出,该技术旨在利用在有标记的源域上学习的行人知识来辅助学习目标域上的行人知识从而得到一个较为鲁邦的行人重识别模型。
3.目前大部分无监督域适应行人重识别方法主要集中于基于聚类的方法。这些方法遵循两阶段训练流程:(1)采用有标记的源域数据进行监督训练得到源模型,并利用该模型初始化待训练的模型参数。(2)通过dbscan等单类中心聚类方法对无标记的目标域数据进行聚类生成其相应的伪标签,接着利用带伪标签的目标域样本进行监督训练。然而,由于域间差异的存在,采用源模型初始化的模型提取的目标域样本特征具有较差的鉴别性,从而导致聚类过程中不可避免的会生成噪声伪标签,降低模型的识别性能。其次,这些方法均采用单类中心聚类的方法来生成目标域伪标签。采用单类中心聚类方法会让属于不同身份的样本聚集到一个集群中并分配相同的伪标签,使用带噪声伪标签的样本进行训练会影响模型性能。
4.为了减轻聚类过程中生成的噪声伪标签对模型产生的负面影响,研究者们提出了许多伪标签优化的方法。然而,现有的方法大多从标签层面来减轻噪声标签对模型产生的负面影响,即选择具有可靠伪标签的目标域样本进行训练,将具有噪声伪标签的样本抛弃。例如,cpl方法提出了一种具有高精度的邻域伪标签和相对高召回率的组伪标签的联合学习框架,该框架利用互补伪标签之间是互利的这一原理来减轻聚类过程中生成的噪声伪标签的影响。ssg方法提出了一种相似性分组方法,该方法利用未标记样本的潜在相似性,从全局到局部的不同视图自动构建多个簇,从而筛选出具有可靠伪标签的样本进行监督训练。尽管这些方法在一定程度上减轻了噪声伪标签对模型产生的负面影响,但是这些方法未能充分目标域的所有样本进行训练。其次,这些方法在每次迭代优化过程中均采用静态伪标签进行监督训练,未有效利用更新后的网络学习到的知识。
技术实现要素:
5.为解决上述背景技术中提到的不足,本发明的目的在于提供一种多类中心动态对比学习方法及系统,能够从聚类算法层面来减去聚类过程中生成的噪声标签对模型的影响,然后利用当前批次网络学习到的知识来优化当前批次样本的标签,最后利用动态对比损失和动态标签进行监督训练。
6.本发明的目的可以通过以下技术方案实现:一种多类中心动态对比学习方法,方法包括以下步骤:
7.接收源域数据集,将源域数据集输入至预先建立的跨域行人重识别的源模型内,得到源模型和权重文件;
8.加载源模型和权重文件,得到平均模型和待训练的原模型,利用平均模型提取目标域样本特征,利用目标域样本特征进行多类中心聚类,得到样本伪标签和多类中心;
9.利用平均模型提取当前批次样本特征,计算当前批次样本特征与多类中心的相似概率,利用相似概率对样本伪标签进行优化,得到动态伪标签;
10.利用动态伪标签计算得出动态对比损失、分类损失和动态三元组损失,将动态对比损失、分类损失和动态三元组损失输入原模型内进行监督训练至原模型收敛,在每次迭代训练后,利用原模型参数动量更新平均模型参数,得到性能最好的平均模型性能参数。
11.优选地,所述多类中心存储于多类中心存储器内。
12.优选地,所述多类中心存储器大小为o
×
d,其中o表示所有类中心数,d表示特征维数,且o由o=m
×
k得到,其中m表示每个类的多类中心数目,k表示目标域的伪类别数。
13.优选地,所述目标域样本标记为目标域样本特征标记为所述目标域的伪类别数k通过dbscan聚类方法获取,对于特定的类别k,根据每个样本的对于类别k的预测结果进行排序,选择排在顶端的前n个样本作为类别的采样样本:每个类的采样样本集由以下公式获得:
[0014][0015]
其中表示平均模型的特征提取器,δ
·
表示分类器,表示每个类的采样样本集,且n表示每个类的采样样本个数,且n表示每个类的采样样本个数,且k表示目标域的伪类别数,r是一个决定采样比率的超参数;
[0016]
对于类均衡采样后的样本集采用kmeans聚类算法对每个类的样本集进行聚类得到属于该类的m个类中心每个类的多类中心的形式由以下公式表示:
[0017]
[0018]
其中表示第k个类的第i个类中心,表示第k个类的采样样本集,m表示每个类的多类中心数目,将计算得出的多类中心存储于多类中心存储器内。
[0019]
优选地,所述样本伪标签的计算过程如下:
[0020]
通过计算目标域样本特征与所有类别的多类中心之间的相似概率,并最大化相似概率获得相似概率相应的样本伪标签,样本伪标签的分配公式如下:
[0021][0022]
其中表示平均模型提取的目标域样本x
t
的特征,表示第k个类的第i个类中心,表示第j个类的第i个类中心,k表示目标域的类别数,exp
·
表示指数函数。
[0023]
优选地,通过以下公式迭代进行多类中心聚类,从而获得置信度较高的目标域的样本伪标签:
[0024][0025][0026][0027]
其中表示平均模型提取的目标域样本x
t
的特征,表示第k个类的第i个类中心,表示第j个类的第i个类中心,k表示目标域的类别数,exp
·
表示指数函数。
[0028]
优选地,所述动态伪标签通过最大化动态软伪标签获得,所述动态软伪标签的分配公式如下:
[0029][0030]
其中为动态软伪标签,q
t
当前训练批次的样本,表示平均模型提取的当前训练批次样本q
t
的特征。
[0031]
所述动态伪标签如下:
[0032]
[0033]
其中为动态伪标签;
[0034]
计算动态软伪标签的对称交叉熵损失如下:
[0035][0036]
其中表示第i个样本的动态软伪标签,表示原模型提取的第i个目标域样本的特征,δ
·
表示分类器,n
t
表示目标域样本个数。
[0037]
目标域的分类损失由以下公式表示:
[0038][0039]
其中,α是用于平衡各个子项的平衡参数;
[0040]
采用动态伪标签来构建三元组,并利用三元组计算三元组损失,动态三元组损失如下所示:
[0041][0042]
其中,表示欧式距离,和分别表示样本在每个小批量中最难的正样本和负样本,m=0.5表示三元组距离间隔。
[0043]
优选地,所述动态对比损失的具体形式如下所示:
[0044][0045]
其中c∈rk×d表示每个类的类中心,由每个类的多个类中心的平均值表示,k表示目标域的伪类别数,d表示特征通道数,σ
·
表示soft-max激活函数,表示第i个样本的动态软伪标签,τ表示温度系数控制着类间概率分布的柔软性,设置τ=0,05。
[0046]
优选地,所述性能最好的平均模型性能参数如下:
[0047]
l
t
θ=l
id
θ+βl
dtri
θ+μl
dcl
θ
[0048]
其中,β和μ是平衡因子,用于调节不同子项之间的平衡,θ表示原模型参数,l
id
·
表示分类损失,l
dtri
·
表示动态三元组损失,l
dcl
·
表示动态对比损失。
[0049]
在本发明的又一方面,为了达到上述目的,公开了一种多类中心动态对比学习系统,包括:
[0050]
数据输入模块:用于接收源域数据集,将源域数据集输入至预先建立的跨域行人重识别的源模型内,得到源模型和权重文件;
[0051]
聚类模块:用于加载源模型和权重文件,得到平均模型和待训练的原模型,利用平均模型提取目标域样本特征,利用目标域样本特征进行多类中心聚类,得到样本伪标签和多类中心;
[0052]
优化模块:用于利用平均模型提取当前批次样本特征,计算当前批次样本特征与多类中心的相似概率,利用相似概率对样本伪标签进行优化,得到动态伪标签;
[0053]
动态对比模块:用于利用动态伪标签计算得出动态对比损失、分类损失和动态三元组损失,将动态对比损失、分类损失和动态三元组损失输入原模型内进行监督训练至原模型收敛,在每次迭代训练后,利用原模型参数以动量的方式动态更新平均模型参数,得到性能最好的平均模型性能参数。
[0054]
本发明的有益效果:
[0055]
本发明通过计算当前批次的查询样本特征和多类中心存储器中的多类中心特征之间的相似度来优化样本伪标签,利用动态伪标签和动态对比损失进行监督训练,从而提高模型的鉴别性能和对噪声的容忍能力。
附图说明
[0056]
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图;
[0057]
图1是本发明方法流程示意图;
[0058]
图2是多类中心动态对比学习方法的模型图;
[0059]
图3是多类中动态对比学习方法训练模型的可视化热力图;
[0060]
图4是多类中心动态对比学习方法的框架图;
[0061]
图5是本发明系统结构示意图。
具体实施方式
[0062]
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其它实施例,都属于本发明保护的范围。
[0063]
如图1所示,一种多类中心动态对比学习方法,方法包括以下步骤:
[0064]
接收源域数据集,将源域数据集输入至预先建立的跨域行人重识别的源模型内,得到源模型和权重文件;
[0065]
加载源模型和权重文件,得到平均模型和待训练的原模型,利用平均模型提取目标域样本特征,利用目标域样本特征进行多类中心聚类,得到样本伪标签和多类中心;
[0066]
利用平均模型提取当前批次样本特征,计算当前批次样本特征与多类中心的相似
概率,利用相似概率对样本伪标签进行优化,得到动态伪标签;
[0067]
利用动态伪标签计算得出动态对比损失、分类损失和动态三元组损失,将动态对比损失、分类损失和动态三元组损失输入原模型内进行监督训练至原模型收敛,在每次迭代训练后,利用原模型参数动量更新平均模型参数,得到性能最好的平均模型性能参数。
[0068]
需要进一步进行说明的是,在具体实施过程中,通过以下三个方面进行具体实施:
[0069]
多类中心聚类:
[0070]
参照图2,以往的udare-id方法通常采用单类中心聚类算法生成的伪标签进行训练,由于聚类的性能取决于所提取的样本特征的可鉴别性,因此采用单类中心聚类算法会让属于不同身份的样本聚集到同一个集群中并分配相同的伪标签,使用带噪声伪标签的样本进行训练会影响模型性能。为了解决这个问题,我们提出了一种多类中心聚类方法,该方法通过计算当前批次的样本特征和多个类中心的之间的相似度来生成较为可靠的伪标签。在多类中心聚类过程中,我们构建了一个多类中心存储器来存储每个类的多个类中心,该存储器的大小为o
×
d,其中o表示所有类中心数,d表示特征维数。o可由o=m
×
k得到,其中m表示每个类的多类中心数目,k表示目标域的伪类别数(使用聚类方法获得)。
[0071]
在每次epoch训练过程中,准备阶段首先被执行。在准备阶段,我们使用平均模型来提取样本特征。给定一个目标域样本表示平均模型提取的样本特征。平均模型的参数可由原模型的参数通过指数移动平均法(ema)进行更新,公式如下:
[0072][0073]
其中表示第t次迭代时平均模型的参数,θ表示原模型的参数,该值可以通过梯度反向传播进行更新。λ表示动量更新的超参数,其取值范围为0,1。
[0074]
此外,由于目标域的类别数未知,在进行多类中心聚类之前,我们采用dbscan聚类方法获取目标域的伪类别数。我们基于样本的预测结果为每个类采样样本。具体来说,对于特定的类别k,我们根据每个样本的对于类别k的预测结果进行排序,选择排在顶端的前n个样本作为该类别的采样样本,即每个类的采样样本集可由以下公式获得::
[0075][0076]
其中表示平均模型的特征提取器,δ
·
表示分类器。表示每个类的采样样本集,且n表示每个类的采样样本个数,且k表示目标域的伪类别数,r是一个决定采样比率的超参数。
[0077]
该采样方法是基于全局角度选择类别k的前n个最可靠的样本来构建该类别的多类中心,而不是基于局部的实例预测结果来决定是否采样该样本,即选择样本预测结果为k类的样本作为第k类的采样样本。由此可以看出,该采样方法是类平衡的,每个类都能实现均衡采样,从而避免了类不均衡导致的模型性能下降的问题。
[0078]
对于类均衡采样后的样本集我们采用kmeans聚类算法对每个类的样本集进行聚类得到属于该类的m个类中心每个类的多类中心的具体形式可由以下公式表示:
[0079][0080]
其中表示第k个类的第i个类中心,表示第k个类的采样样本集,m表示每个类的多类中心数目。此外,我们将所得的所有类别的多类中心存储在多类中心存储器中。
[0081]
在获得所有类别的多个类中心后,我们计算目标域样本特征与所有类别的多个类中心之间的相似概率,并最大化该相似概率获得其相应的伪标签。伪标签的分配公式如下所示:
[0082][0083]
其中表示平均模型提取的目标域样本x
t
的特征,表示第k个类的第i个类中心,表示第j个类的第i个类中心,k表示目标域的类别数,exp
·
表示指数函数。
[0084]
通过以下公式迭代进行多类中心聚类,从而获得更加可靠的多类中心和目标域样本伪标签:
[0085][0086]
其中表示平均模型提取的目标域样本x
t
的特征,表示第k个类的第i个类中心,表示第j个类的第i个类中心,k表示目标域的类别数,exp
·
表示指数函数。
[0087]
接着,我们利用所得的可靠伪标签计算交叉熵损失来监督模型训练,具体的公式如下所示:
[0088]
[0089]
其中g
θ
·
表示原模型的特征提取器,表示属于第i个样本的伪标签。
[0090]
动态伪标签:
[0091]
以往的方法通常在每个epoch的初始阶段初始化样本伪标签,然后使用该伪标签以固定的周期运行。然而,这些方法在迭代优化过程中仅采用静态伪标签进行监督训练,未有效利用更新后的网络学习到的知识。为了充分利用这些知识,本文提出一种动态伪标签方法来更新样本伪标签得到更加可靠的伪标签进行监督训练,从而提高模型的性能。具体来说,在每个epoch的初始化阶段,我们利用多类中心聚类方法从全局角度生成每个类的多个类中心和每个样本相应的伪标签。在迭代优化的过程中,我们利用当前批次的样本特征和多类中心之间的相似概率得到当前批次样本的动态软伪标签,动态软伪标签的分配公式如下所示:
[0092][0093]
其中为动态软伪标签,q
t
当前训练批次的样本,表示平均模型提取的当前训练批次样本q
t
的特征。通过最大化动态软伪标签可以得到动态伪标签。动态伪标签的具体形式如下所示:
[0094][0095]
对于所得的动态软伪标签,我们利用其计算软标签对称交叉熵损失来监督模型训练。软标签对称交叉熵损失是mmt方法提出的软标签交叉熵损失的一种更加鲁邦的变体。相比于软标签交叉熵损失,该损失能够进一步增强模型对噪声的容忍能力。相比于对称交叉熵损失,该损失采用软标签计算损失能够减轻硬量化损失的误差。动态软标签损失的具体形式如下所示:
[0096][0097]
其中表示第i个样本的动态软伪标签,表示原模型提取的第i个目标域样本的特征,δ
·
表示分类器,n
t
表示目标域样本个数。
[0098]
为了充分利用静态伪标签和动态软伪标签所包含的知识,我们将动态软标签损失和基于静态伪标签的交叉熵损失相结合来监督模型训练,从而获得更加稳健的效果。因此,目标域的分类损失可由以下公式表示:
[0099][0100]
其中,α是用于平衡各个子项的平衡参数。接着,我们采用动态伪标签来构建三元
组,并利用该三元组计算三元组损失,动态三元组损失如下所示:
[0101][0102][0103]
其中,表示欧式距离,和分别表示样本在每个小批量中最难的正样本和负样本,m=0.5表示三元组距离间隔。
[0104]
此外,为了充分利用当前网络学习到知识,我们利用当前批次的样本特征和多个类中心之间的相似概率计算出当前批次的多类中心,然后利用当前批次的多类中心通过动量移动平均方法更新多类中心存储器中的多类中心。多类中心的更新公式如下所示:
[0105][0106]
其中γ表示多类中心动量更新的超参数,该值的取值范围为0,1。表示采用当前批次样本特征计算的属于第k类的第i个类中心,该值可由以下公式得出:
[0107][0108]
其中b表示batch-size的大小,表示当前批次的样本与多类中心之间的相似概率,该值可以由以下公式计算得出:
[0109][0110]
在下一次迭代过程中,我们利用更新后的多类中心和当前批次的样本特征通过公式(7)和公式(8)来获取动态软伪标签和动态伪标签。
[0111]
动态对比学习:
[0112]
由于三元组损失是从实例级层面来监督模型训练,为了充分利用样本间的类级知识,本文提出了一种动态对比学习方法来充分利用动态伪标签从类级层面监督模型训练。
[0113]
受mmt的启发,为了避免训练过程中由于数据增强导致特征内部的相似性变化而造成的误差放大,从而影响模型性能的问题。本文提出了一种基于动态伪标签的动态对比损失,该损失将对比学习和自监督学习结合起来以非参数的方式避免了由于特征分布的扰动而造成误差放大的问题。
[0114]
与仅使用类别信息的硬伪标签对比学习方法相比,改进的软伪标签对比学习方法不仅考虑了同一批次样本特征分布的一致性,而且能够进一步提高模型的噪声容易能力。具体来说,由于聚类对比损失实际上是交叉熵损失,为了提高模型对噪声标签的容忍能力,
dcl表示使用本发明提出的多类中心动态对比学习方法训练的模型。从图3可以看出,使用本发明提出的mcc-dcl方法的实验效果最好。
[0126]
表1mcc-dcl方法在market和duke数据集上的评估结果
[0127][0128]
表2mcc-dcl方法在大规模数据集msmt17上的评估结果
[0129][0130]
表3mcc-dcl方法在无监督行人重识别任务上的评估结果
[0131][0132]
基于同一种发明构思,本发明还提供一种计算机设备,该计算机设备包括包括:一个或多个处理器,以及存储器,用于存储一个或多个计算机程序;程序包括程序指令,处理器用于执行存储器存储的程序指令。处理器可能是中央处理单元(central processing unit,cpu),还可以是其他通用处理器、数字信号处理器(digital signal processor、dsp)、专用集成电路(application specificintegrated circuit,asic)、现场可编程门阵列(field-programmable gatearray,fpga)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等,其是终端的计算核心以及控制核心,其用于实现一条或一条以上指令,具体用于加载并执行计算机存储介质内一条或一条以上指令从而实现上述方法。
[0133]
需要进一步进行说明的是,基于同一种发明构思,本发明还提供一种计算机存储介质,该存储介质上存储有计算机程序,所述计算机程序被处理器运行时执行上述方法。该存储介质可以采用一个或多个计算机可读的介质的任意组合。计算机可读介质可以是计算机可读信号介质或者计算机可读存储介质。计算机可读存储介质例如可以是但不限于电、磁、光、电、磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。计算机可读存储介质的更具体的例子(非穷举的列表)包括:具有一个或多个导线的电连接、便携式计算机磁盘、硬盘、随机存取存储器(ram)、只读存储器(rom)、可擦式可编程只读存储器(eprom或闪存)、光纤、便携式紧凑磁盘只读存储器(cd-rom)、光存储器件、磁存储器件、或者上述的任意合适的组合。在本发明中,计算机可读存储介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行系统、装置或者器件使用或者与其结合使用。
[0134]
在本说明书的描述中,参考术语“一个实施例”、“示例”、“具体示例”等的描述意指结合该实施例或示例描述的具体特征、结构、材料或者特点包含于本公开的至少一个实施例或示例中。在本说明书中,对上述术语的示意性表述不一定指的是相同的实施例或示例。而且,描述的具体特征、结构、材料或者特点可以在任何的一个或多个实施例或示例中以合适的方式结合。
[0135]
以上显示和描述了本公开的基本原理、主要特征和本公开的优点。本行业的技术人员应该了解,本公开不受上述实施例的限制,上述实施例和说明书中描述的只是说明本公开的原理,在不脱离本公开精神和范围的前提下,本公开还会有各种变化和改进,这些变化和改进都落入要求保护的本公开范围内容。
技术特征:
1.一种多类中心动态对比学习方法,其特征在于,方法包括以下步骤:接收源域数据集,将源域数据集输入至预先建立的跨域行人重识别的源模型内,得到源模型和权重文件;加载源模型和权重文件,得到平均模型和待训练的原模型,利用平均模型提取目标域样本特征,利用目标域样本特征进行多类中心聚类,得到样本伪标签和多类中心;利用平均模型提取当前批次样本特征,计算当前批次样本特征与多类中心的相似概率,利用相似概率对样本伪标签进行优化,得到动态伪标签;利用动态伪标签计算得出动态对比损失、分类损失和动态三元组损失,将动态对比损失、分类损失和动态三元组损失输入原模型内进行监督训练至原模型收敛,在每次迭代训练后,利用原模型参数动量更新平均模型参数,得到性能最好的平均模型性能参数。2.根据权利要求1所述的一种多类中心动态对比学习方法,其特征在于,所述多类中心存储于多类中心存储器内。3.根据权利要求2所述的一种多类中心动态对比学习方法,其特征在于,所述多类中心存储器大小为o
×
d,其中o表示所有类中心数,d表示特征维数,且o由o=m
×
k得到,其中m表示每个类的多类中心数目,k表示目标域的伪类别数。4.根据权利要求1所述的一种多类中心动态对比学习方法,其特征在于,所述目标域样本标记为目标域样本特征标记为所述目标域的伪类别数通过dbscan聚类方法获取,对于特定的类别k,根据每个样本的对于类别k的预测结果进行排序,选择排在顶端的前n个样本作为类别的采样样本:每个类的采样样本集由以下公式获得:其中,表示平均模型的特征提取器,δ
·
表示分类器,表示每个类的采样样本集,且n表示每个类的采样样本个数,且n表示每个类的采样样本个数,且k表示目标域的伪类别数,r是一个决定采样比率的超参数;对于类均衡采样后的样本集采用kmeans聚类算法对每个类的样本集进行聚类得到属于该类的m个类中心每个类的多类中心的形式由以下公式表示:其中表示第k个类的第i个类中心,表示第k个类的采样样本集,m表示每个类的多类中心数目,将计算得出的多类中心存储于多类中心存储器中。5.根据权利要求1所述的一种多类中心动态对比学习方法,其特征在于,所述样本伪标
签的计算过程如下:通过计算目标域样本特征与所有类别的多类中心之间的相似概率,并最大化相似概率获得相似概率相应的样本伪标签,样本伪标签的分配公式如下:其中表示平均模型提取的目标域样本x
t
的特征,表示第k个类的第i个类中心,表示第j个类的第i个类中心,k表示目标域的类别数,exp
·
表示指数函数。6.根据权利要求5所述的一种多类中心动态对比学习方法,其特征在于,通过以下公式迭代进行多类中心聚类,从而获得置信度较高的目标域的样本伪标签:迭代进行多类中心聚类,从而获得置信度较高的目标域的样本伪标签:迭代进行多类中心聚类,从而获得置信度较高的目标域的样本伪标签:其中表示平均模型提取的目标域样本x
t
的特征,表示第k个类的第i个类中心,表示第j个类的第i个类中心,k表示目标域的类别数,exp
·
表示指数函数。7.根据权利要求1所述的一种多类中心动态对比学习方法,其特征在于,所述动态伪标签通过最大化动态软伪标签获得,所述动态软伪标签的分配公式如下:其中为动态软伪标签,q
t
当前训练批次的样本,表示平均模型提取的当前训练批次样本q
t
的特征;所述动态伪标签如下:其中为动态伪标签;计算动态软伪标签的对称交叉熵损失如下:
其中表示第i个样本的动态软伪标签,表示原模型提取的第i个目标域样本的特征,δ
·
表示分类器,n
t
表示目标域样本个数;目标域的分类损失由以下公式表示:其中,α是用于平衡各个子项的平衡参数;采用动态伪标签来构建三元组,并利用三元组计算三元组损失,动态三元组损失如下所示:其中,表示欧式距离,和分别表示样本在每个小批量中最难的正样本和负样本,m=0.5表示三元组距离间隔。8.根据权利要求1所述的一种多类中心动态对比学习方法,其特征在于,所述动态对比损失的具体形式如下所示:其中c∈r
k
×
d
表示每个类的类中心,由每个类的多个类中心的平均值表示,k表示目标域的伪类别数,d表示特征通道数,σ
·
表示soft-max激活函数,表示第i个样本的动态软伪标签,τ表示温度系数控制着类间概率分布的柔软性,设置τ=0,05。9.根据权利要求1所述的一种多类中心动态对比学习方法,其特征在于,所述性能最好的平均模型性能参数如下:l
t
θ=l
id
θ+βl
dtri
θ+μl
dcl
θ其中,β和μ是平衡因子,用于调节不同子项之间的平衡,θ表示原模型参数,l
id
·
表示分类损失,l
dtri
·
表示动态三元组损失,l
dcl
·
表示动态对比损失。
10.一种多类中心动态对比学习系统,其特征在于,包括:数据输入模块:用于接收源域数据集,将源域数据集输入至预先建立的跨域行人重识别的源模型内,得到源模型和权重文件;聚类模块:用于加载源模型和权重文件,得到平均模型和待训练的原模型,利用平均模型提取目标域样本特征,利用目标域样本特征进行多类中心聚类,得到样本伪标签和多类中心;优化模块:用于利用平均模型提取当前批次样本特征,计算当前批次样本特征与多类中心的相似概率,利用相似概率对样本伪标签进行优化,得到动态伪标签;动态对比模块:用于利用动态伪标签计算得出动态对比损失、分类损失和动态三元组损失,将动态对比损失、分类损失和动态三元组损失输入原模型内进行监督训练至原模型收敛,在每次迭代训练后,利用原模型参数以动量的方式动态更新平均模型参数,得到性能最好的平均模型性能参数。
技术总结
本发明公开了一种多类中心动态对比学习方法及系统,涉及无监督域适应行人重识别技术领域,方法包括以下步骤:接收源域数据集,将源域数据集输入至预先建立的跨域行人重识别的源模型内,得到源模型和权重文件;加载源模型和权重文件,得到平均模型和原模型,利用平均模型提取目标域样本特征,利用目标域样本特征进行多类中心聚类,得到样本伪标签和多类中心;利用平均模型提取当前批次样本特征,计算当前批次样本特征与多类中心的相似概率,利用相似概率对样本伪标签进行优化,得到动态伪标签;利用动态伪标签计算得出动态对比损失,将损失输入原模型内进行监督训练至原模型收敛,利用原模型参数动量更新平均模型参数,得到性能最好的平均模型。能最好的平均模型。能最好的平均模型。
技术研发人员:田青 杜晓欣 程耀
受保护的技术使用者:南京信息工程大学
技术研发日:2023.07.08
技术公布日:2023/10/6
版权声明
本文仅代表作者观点,不代表航空之家立场。
本文系作者授权航家号发表,未经原创作者书面授权,任何单位或个人不得引用、复制、转载、摘编、链接或以其他任何方式复制发表。任何单位或个人在获得书面授权使用航空之家内容时,须注明作者及来源 “航空之家”。如非法使用航空之家的部分或全部内容的,航空之家将依法追究其法律责任。(航空之家官方QQ:2926969996)
飞行汽车 https://www.autovtol.com/
