基于知识蒸馏特征生成的长尾数据联邦学习方法和系统
未命名
10-09
阅读:108
评论:0
1.本发明属于联邦学习技术领域,具体涉及一种基于知识蒸馏特征生成的不平衡数据联邦学习方法和系统。
背景技术:
2.近年来,深度学习技术在人工智能领域中发挥了重要作用,它的成功很大程度上依赖于大量的训练数据。利用深度模型进行建模时,通常的做法是在服务器端收集大量的数据进行训练,模型之间的数据并不互通。但是在现实生活中,企业之间都存在数据孤岛的问题,在服务器端进行中心化建模的方式越来越困难。联邦学习的提出是为了在保证数据隐私安全及合规合法的基础上进行跨组织联合建模,提升人工智能模型的效果,目前已广泛用于人工智能研究方向。
3.联邦学习中的一个重要挑战是不同组织(客户端)之间的数据分布是不同的,即数据异构问题。fedavg是一个经典的方法,在服务器端维持了一个全局模型,客户端首先下载全局模型并且在本地数据上训练,然后服务器端聚合由客户端上传的本地模型的更新,重复上述过程直到收敛。但是fedavg不能很好的解决数据异构的问题,因为它没有考虑数据客户端之间数据分布的差异,并且单一的全局模型不能在每个客户端上都有很好的性能。因此,个性化联邦学习已经被提出,它致力于为每个客户端学习一个定制的模型并且能够从参与联邦学习中受益。
4.一部分的联邦学习方法是把全局模型分成个性化的部分和全局的部分。另外一种个性化联邦学习中常见的方法是限制通过正则的方式限制个性化模型和全局模型的距离。
5.可是,这些方法通常考虑全局分布是平衡的,然而在现实生活中数据分布往往是长尾的。很少的类(头类)包含了大量的样本,但是大部分的类(尾类)仅占据很少的样本。长尾分布加剧了在相同异构条件下每个客户端内部的不平衡程度,导致了本地模型性能的急剧下降。在联邦学习中,ratio loss通过估计全局的不平衡状况来解决不平衡。creff用联邦特征重新训练分类器来解决全局长尾分布的问题。但是,这两种方法都没有考虑生成个性化模型。
技术实现要素:
6.鉴于上述,本发明的目的是提供一种基于知识蒸馏特征生成的不平衡数据联邦学习方法和系统,通过有效解决联邦异构长尾数据分布的问题,以进一步提升个性化联邦学习下的模型性能。
7.为实现上述发明目的,实施例提供的一种基于知识蒸馏特征生成的长尾数据联邦学习方法,包括以下步骤:
8.服务器端接收上传的本地模型、每类标签分布和每类本地特征原型后,基于本地模型参数更新全局模型,基于每类标签分布和每类本地特征原型更新生成模型,更新的生成模型和全局模型下发至各客户端;
9.各客户端接收全局模型和生成模型后,基于生成模型生成每类本地特征并计算每类本地特征原型,依据每类本地特征和本地数据更新本地模型,并上传更新的本地模型、每类标签分布以及每类本地特征原型,基于全局模型初始化个性化模型,基于生成模型为少数类样本生成多特征和为多数类样本生成少特征,基于多特征、少特征和本地数据更新初始化后的个性化模型,同时基于知识蒸馏将全局模型的知识迁移到个性化模型中。
10.在一个实施例中,所述生成模型用于根据输入标签和随机噪声生成特征,基于每类标签分布和每类特征原型更新生成模型采用的损失函数l
gen
为:
11.l
gen
=λ1l
ce
+λ2l
dis
+λ3l
sc
12.其中,λ1、λ2、λ3表示权重系数,l
ce
表示分类损失,l
dis
表示样本调节损失,l
sc
表示对比损失;
13.分类损失l
ce
表示为:
[0014][0015]
其中,y表示输入的标签,f表示生成的特征,vk表示第k个客户端的本地模型中的分类器,l
ce
((vk,f),y)表示将f输入至vk的预测值与y的交叉熵损失,p(y)表示标签分布,表示将y输入至参数为的生成模型得到的特征分布,e表示期望,s
t
表示客户端集合;
[0016]
样本调节损失l
dis
表示为:
[0017][0018]
其中,z表示随机噪声向量,i和j表示向量索引,m表示向量总数,
[0019]2表示l2范数;
[0020]
对比损失l
sc
表示为:
[0021][0022]
其中,表示第k个客户端的第y类的本地特征原型,ωy表示拥有样本第y类的客户端集合,表示第y类全局特征原型,
ya
表示第
ya
类标签,fy表示依据第y类标签生成的特征,a(y)表示标签集合。
[0023]
在一个实施例中,所述基于生成模型生成每类本地特征并计算每类本地特征原型,包括:
[0024][0025]
其中,为第k个客户端的第y类的本地数据,x表示本地样本,dk表示第k个客户端的本地数据,表示第k个客户端的第y类的本地特征原型,表示第k个客户端中本地模
型包括的特征提取器,t表示轮次,r表示最终批次。
[0026]
在一个实施例中,依据每类本地特征和本地数据更新本地模型采用的损失函数为:
[0027][0028]
其中,表示随机采样的标签,表示利用生成模型依据标签生成的特征,表示将输入至的预测值与的交叉熵损失,t表示轮次,r表示批次,dk表示第k个客户端的本地数据,表示第k个本地模型,表示依据将dk输入至的预测值计算的平衡的softmax损失。
[0029]
在一个实施例中,所述基于全局模型初始化个性化模型,包括:
[0030]
所述全局模型包括全局特征提取器和全局分类器,所述个性化模型包括局部特征提取器和局部分类器,利用全局特征提取器初始化局部特征提取器。
[0031]
在一个实施例中,所述基于生成模型为少数类样本生成多特征和为多数类样本生成少特征,包括:
[0032]
针对少数类样本采样多数量标签,采用生成模型基于多数量标签生成多数量特征;
[0033]
针对多数类样本采样少数量标签,采用生成模型基于少数量标签生成少数量特征。
[0034]
在一个实施例中,基于多特征、少特征和本地数据更新初始化后的个性化模型,同时基于知识蒸馏将全局模型的知识迁移到个性化模型中,采用的损失函数l
pfldf
为:
[0035][0036]
其中,λ表示权重,l
p
表示分类损失,l
kl
表示蒸馏损失;
[0037]
分类损失l
p
表示为:
[0038][0039]
其中,表示为少数类样本采样的多标签和为多数类样本采样的少标签,表示基于生成的特征,表示第k个个性化模型包含的分类器,表示将输入至的预测值与的交叉熵损失,dk表示第k个客户端的本地数据,θk表示第k个个性化模型,l
ce
(θk,dk)表示将dk输入至θk的预测值与原始标签y的交叉熵;
[0040]
蒸馏损失l
kl
表示为:
[0041][0042]
其中,q
kl
表示kl散度,w
t
表示全局模型,t表示轮次,r表示批次,σ表示softmax函数,σ(w
t
,dk)表示将dk输入至全局模型w
t
的softmax值,表示将dk输入至个性化模型的softmax值。
[0043]
为实现上述发明目的,实施例还提供了一种基于知识蒸馏特征生成的长尾数据联邦学习系统,包括服务器端和各客户端,
[0044]
所述服务器端用于接收上传的本地模型、每类标签分布和每类本地特征原型后,
基于本地模型参数更新全局模型,基于每类标签分布和每类本地特征原型更新生成模型,更新的生成模型和全局模型下发至各客户端;
[0045]
所述各客户端用于接收全局模型和生成模型后,基于生成模型生成每类本地特征并计算每类本地特征原型,依据每类本地特征和本地数据更新本地模型,并上传更新的本地模型、每类标签分布以及每类本地特征原型,基于全局模型初始化个性化模型,基于生成模型为少数类样本生成多特征,基于多特征和本地数据更新初始化后的个性化模型,同时基于知识蒸馏将全局模型的知识迁移到个性化模型中。
[0046]
与现有技术相比,本发明具有的有益效果至少包括:
[0047]
在服务器上训练一个轻量级的生成模型。生成模型的输入为标签,输出为对应的特征。生成特征可以模拟真实特征的分布。在个性化模型的训练过程中,为每个客户端上的局部少数类生成更多特征,为其他类生成更少特征,这将有助于缓解每个客户端数据的严重不平衡。此外,全局模型是从局部模型聚合而来的,在全局头类中表现更好,而且能够获得高质量和含有丰富信息的特征。将全局模型的知识提取到个性化模型中可以帮助提高个性化模型的性能。通过生成特征来知识蒸馏来训练个性化模型。这样能够有效解决联邦异构长尾数据分布的问题,进一步提升了个性化联邦学习下的模型性能。
附图说明
[0048]
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图做简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动前提下,还可以根据这些附图获得其他附图。
[0049]
图1是本发明实施例提供的基于知识蒸馏特征生成的长尾数据联邦学习方法的流程图;
[0050]
图2是本发明实施例提供的服务器端的工作流程图;
[0051]
图3是本发明实施例提供的客户端进行本地模型更新工作流程图;
[0052]
图4是本发明实施例提供的客户端进行个性化模型更新工作流程图;
[0053]
图5是本发明实施例提供的基于知识蒸馏特征生成的长尾数据联邦学习系统的结构示意图。
具体实施方式
[0054]
为使本发明的目的、技术方案及优点更加清楚明白,以下结合附图及实施例对本发明进行进一步的详细说明。应当理解,此处所描述的具体实施方式仅仅用以解释本发明,并不限定本发明的保护范围。
[0055]
针对个性化联邦学习中异构数据和长尾分布的联合问题导致个性化模型性能差的技术问题,实施例提供了一种基于知识蒸馏特征生成的长尾数据联邦学习方法和系统,以提高个性化联邦学习下的模型性能。
[0056]
如图1所示,实施例提供的基于知识蒸馏特征生成的长尾数据联邦学习方法,实现该联邦学习方法的系统包括服务器端和各客户端,各客户端均与服务器端通信,但是各客户端之间不进行通信,基于该系统,联邦学习方法包括以下步骤:
[0057]
s110,服务器端接收各客户端上传的本地模型、每类标签分布和每类本地特征原型,并进行生成模型和全局模型的更新和下发。
[0058]
基于联邦学习和长尾学习方法的研究,在数据异构和数据长尾分布的情况下,模型的特征提取器受到数据分布的影响较小。因此,将全局模型分成两个部分,第一部分是全局特征提取器g,将其参数化为u。第二部分是全局分类器f,将其参数化为v。此时,样本x的特征h由h=g(x;u)生成,预测结果由f(h;v)给出。服务器端将全局模型分发给各个被选定的客户端,被选取的客户端会上传本地模型、标签分布和本地特征原型,如图2所示,服务器端执行以下步骤:s210,接收上传的本地模型、每类标签分布和每类本地特征原型后;s220,服务器端基于本地模型参数更新全局模型,基于每类标签分布和每类本地特征原型更新生成模型;s230,更新的生成模型和全局模型下发至各客户端。
[0059]
实施例中,针对全局模型,依据上传的本地模型进行全局更新,更新方式不限定。
[0060]
实施例中,生成模型用于根据输入标签y和随机噪声z生成特征f,由于生成特征需要被正确的分类,因此需要参数更新,更新过程中,依据客户端上传的每类标签分布和每类特征原型,具体更新采用的损失函数l
gen
为:
[0061]
l
gen
=λ1l
ce
+λ2l
dis
+λ3l
sc
[0062]
其中,λ1、λ2、λ3表示权重系数,l
ce
表示分类损失,l
dis
表示样本调节损失,l
sc
表示对比损失;
[0063]
分类损失l
ce
表示为:
[0064][0065]
其中,y表示输入的标签,f表示生成的特征,vk表示第k个客户端的本地模型中的分类器,l
ce
((vk,f),y)表示将f输入至vk的预测值与y的交叉熵损失,p(y)表示标签分布,表示将y输入至参数为的生成模型得到的特征分布,e表示期望,s
t
表示客户端集合;
[0066]
由于只采用分类损失会导致生成特征的多样性不足,因此,增加了样本调节损失l
dis
,表示为:
[0067][0068]
其中,z表示随机噪声向量,i和j表示向量索引,m表示向量总数,
[0069]2表示l2范数;
[0070]
由于存在生成特征可以被正确的识别但是和真实的特征不相似的情况,在这种情况下,生成的特征可能会伤害模型的性能,因此增添了对比损失l
sc
来限制生成特征和真实特征的距离,表示为:
[0071][0072]
[0073]
其中,表示第k个客户端的第y类的本地特征原型,ωy表示拥有样本第y类的客户端集合,表示第y类全局特征原型,
ya
表示第
ya
类标签,fy表示依据第y类标签生成的特征,a(y)表示标签集合。
[0074]
在对生成模型和全局模型更新后,将更新的生成模型和全局模型下发至各客户端。
[0075]
s120,各客户端接收全局模型和生成模型,依据生成模型生成特征和本地特征原型,并进行本地模型和个性化模型更新。
[0076]
实施例中,如图3所示,各客户端执行以下步骤:s310,各客户端接收服务器端下发的全局模型和生成模型;s320,基于生成模型生成每类本地特征并计算每类本地特征原型;s330,依据每类本地特征和本地数据更新本地模型,s340,上传更新的本地模型、每类标签分布以及每类本地特征原型。
[0077]
由于生成器模型携带有全局信息的知识,通过随机采样标签生成对应类别的特征和真实的数据一起更新本地模型。对于本地模型的训练,使用如下损失函数:
[0078][0079]
其中,表示随机采样的标签,表示利用生成模型依据标签生成的特征,表示将输入至的预测值与的交叉熵损失,t表示轮次,r表示批次,dk表示第k个客户端的本地数据,表示第k个本地模型,表示依据将dk输入至的预测值计算的平衡的softmax损失。
[0080]
由于希望生成模型生成的特征能够和真实样本的特征距离比较近,统计每类的特征的均值作为各个客户端的本地特征原型。通过如下方式构建本地特征原型:
[0081][0082]
其中,为第k个客户端的第y类的本地数据,x表示本地样本,dk表示第k个客户端的本地数据,表示第k个客户端的第y类的本地特征原型,表示第k个客户端中本地模型包括的特征提取器,t表示轮次,r表示最终批次,为了减少计算代价,只统计在最后一个epoch的特征。
[0083]
在更新完本地模型和本地特征原型后,上传更新的本地模型、每类标签分布以及每类本地特征原型至服务器端,以进行下一轮次联邦学习。
[0084]
实施例中,各客户端在接收全局模型和生成模型后,还进行个性化模型的更新,如图4所示,具体包括:s410,基于全局模型初始化个性化模型;s420,基于生成模型为少数类样本生成多特征和为多数类样本生成少特征;s430,基于多特征、少特征和本地数据更新初始化后的个性化模型,同时基于知识蒸馏将全局模型的知识迁移到个性化模型中。
[0085]
实施例中,个性化模型包括局部特征提取器和局部分类器,基于全局模型初始化个性化模型时,利用全局特征提取器初始化局部特征提取器。
[0086]
实施例中,根据各个客户端的样本分布,为少数类样本生成更多的特征,为多数类样本生成更少的特征,来缓解各个客户端数据分布极端不平衡的问题。具体地,基于生成模
型为少数类样本生成多特征和为多数类样本生成少特征,包括:针对少数类样本采样多数量标签,采用生成模型基于多数量标签生成多数量特征;针对多数类样本采样少数量标签,采用生成模型基于少数量标签生成少数量特征。
[0087]
在个性化模型的训练过程中生成特征和真实样本一起参与训练,首先要保证真实样本能够被正确分类,这在图片分类任务中是基本的操作,又因为生成模型能够模拟真实的特征分布,携带有其他客户端的信息,可以利用生成特征来进一步的再平衡特征,帮助缓解客户端内部数据分布不平衡的状况,另外,全局模型可以获取全局知识并且提供全局视角,还引入基于知识蒸馏将全局模型的知识迁移到个性化模型中,使用kullback-leibler散度来测量全局模型和个性化模型的差异,因此,对个性化模型参数优化采用的损失函数l
pfldf
为:
[0088]
l
pfldf
=l
p
+λl
kl
[0089]
其中,λ表示权重,l
p
表示分类损失,l
kl
表示蒸馏损失;
[0090]
分类损失l
p
表示为:
[0091][0092]
其中,表示为少数类样本采样的多标签和为多数类样本采样的少标签,表示基于生成的特征,表示第k个个性化模型包含的分类器,表示将输入至的预测值与的交叉熵损失,dk表示第k个客户端的本地数据,θk表示第k个个性化模型,l
ce
(θk,dk)表示将dk输入至θk的预测值与原始标签y的交叉熵;
[0093]
蒸馏损失l
kl
表示为:
[0094][0095]
其中,q
kl
表示kl散度,w
t
表示全局模型,t表示轮次,r表示批次,σ表示softmax函数,σ(w
t
,dk)表示将dk输入至全局模型w
t
的softmax值,表示将dk输入至个性化模型的softmax值。
[0096]
基于同样的发明构思,实施例还提供了一种基于知识蒸馏特征生成的长尾数据联邦学习系统,如图5所示,包括服务器端510和各客户端520。
[0097]
其中,服务器端510用于接收上传的本地模型、每类标签分布和每类本地特征原型后,基于本地模型参数更新全局模型,基于每类标签分布和每类本地特征原型更新生成模型,更新的生成模型和全局模型下发至各客户端;
[0098]
其中,各客户端520用于接收全局模型和生成模型后,基于生成模型生成每类本地特征并计算每类本地特征原型,依据每类本地特征和本地数据更新本地模型,并上传更新的本地模型、每类标签分布以及每类本地特征原型,基于全局模型初始化个性化模型,基于生成模型为少数类样本生成多特征,基于多特征和本地数据更新初始化后的个性化模型,同时基于知识蒸馏将全局模型的知识迁移到个性化模型中。
[0099]
需要说明的是,上述实施例提供的基于知识蒸馏特征生成的长尾数据联邦学习系统与基于知识蒸馏特征生成的长尾数据联邦学习方法实施例属于同一构思,其具体实现过程详见基于知识蒸馏特征生成的长尾数据联邦学习方法实施例,这里不再赘述。
[0100]
为验证上述基于知识蒸馏特征生成的长尾数据联邦学习的效果,实施例进行了在
数据集验证,具体地,包括:
[0101]
首先,准备长尾分类数据集,并切分训练数据集给各个客户端。
[0102]
采用cifar-10和cifar-100。cifar-10是用于识别物体的小型数据集,一共包含10个类别的rgb彩色数据图片。图片的尺寸为32x32,共50000张训练图片和10000张测试图片。cifar-100共100类,训练集50000张图片,测试集10000张图片。所有数据集的不平衡因子均为10、50和100,即全局样本量最多的类别除以全局样本量最少的类别的数值。采用带有超参数α的狄利克雷分布来控制客户端数据异构程度。α越接近0数据异构程度更严重。在本发明中,模拟了α=0.1下的数据异构程度。
[0103]
然后,搭建联邦学习框架并初始化模型。
[0104]
在cifar-10-lt和cifar-100-lt数据集采用resnet32架构作为骨干模型。所有实验均利用pytorch框架实现,在nvidia geforce rtx 3080上运行完成。共模拟设计20个客户端数据分布,每次随机选择其中的10个客户端的本地模型进行联邦聚合。真实数据的生成特征的批大小均设置为32,优化器选用sgd。全局训练轮数设为200。
[0105]
接下来,按照上述方法长尾数据联邦学习方法进行联邦学习。
[0106]
表1为本发明与其他几种联邦学习方法在cifar-10-lt和cifar-100-lt数据集上不平衡程度为100、50和10,数据异构程度为0.1的top-1精度(%)比对结果。表中加粗结果为各指标的最优结果。分析可得,本发明提供的方法得到的模型精度高。
[0107]
表1
[0108][0109]
以上所述的具体实施方式对本发明的技术方案和有益效果进行了详细说明,应理解的是以上所述仅为本发明的最优选实施例,并不用于限制本发明,凡在本发明的原则范围内所做的任何修改、补充和等同替换等,均应包含在本发明的保护范围之内。
技术特征:
1.一种基于知识蒸馏特征生成的长尾数据联邦学习方法,其特征在于,包括以下步骤:服务器端接收上传的本地模型、每类标签分布和每类本地特征原型后,基于本地模型参数更新全局模型,基于每类标签分布和每类本地特征原型更新生成模型,更新的生成模型和全局模型下发至各客户端;各客户端接收全局模型和生成模型后,基于生成模型生成每类本地特征并计算每类本地特征原型,依据每类本地特征和本地数据更新本地模型,并上传更新的本地模型、每类标签分布以及每类本地特征原型,基于全局模型初始化个性化模型,基于生成模型为少数类样本生成多特征和为多数类样本生成少特征,基于多特征、少特征和本地数据更新初始化后的个性化模型,同时基于知识蒸馏将全局模型的知识迁移到个性化模型中。2.根据权利要求1所述的基于知识蒸馏特征生成的长尾数据联邦学习方法,其特征在于,所述生成模型用于根据输入标签和随机噪声生成特征,基于每类标签分布和每类特征原型更新生成模型采用的损失函数l
gen
为:l
gen
=λ1l
ce
+λ2l
dis
+λ3l
sc
其中,λ1、λ2、λ3表示权重系数,l
ce
表示分类损失,l
dis
表示样本调节损失,l
sc
表示对比损失;分类损失l
ce
表示为:其中,y表示输入的标签,f表示生成的特征,v
k
表示第k个客户端的本地模型中的分类器,l
ce
((v
k
,f),y)表示将f输入至v
k
的预测值与y的交叉熵损失,p(y)表示标签分布,表示将y输入至参数为的生成模型得到的特征分布,e表示期望,s
t
表示客户端集合;样本调节损失l
dis
表示为:其中,z表示随机噪声向量,i和j表示向量索引,m表示向量总数,2表示l2范数;对比损失l
sc
表示为:表示为:其中,表示第k个客户端的第y类的本地特征原型,
ωy
表示拥有样本第y类的客户端集合,表示第y类全局特征原型,
ya
表示第
ya
类标签,f
y
表示依据第y类标签生成的特征,a(y)表示标签集合。3.根据权利要求1所述的基于知识蒸馏特征生成的长尾数据联邦学习方法,其特征在于,所述基于生成模型生成每类本地特征并计算每类本地特征原型,包括:
其中,为第k个客户端的第y类的本地数据,x表示本地样本,d
k
表示第k个客户端的本地数据,表示第k个客户端的第y类的本地特征原型,表示第k个客户端中本地模型包括的特征提取器,t表示轮次,r表示最终批次。4.根据权利要求1所述的基于知识蒸馏特征生成的长尾数据联邦学习方法,其特征在于,依据每类本地特征和本地数据更新本地模型采用的损失函数为:其中,表示随机采样的标签,表示利用生成模型依据标签生成的特征,表示将输入至的预测值与的交叉熵损失,t表示轮次,r表示批次,d
k
表示第k个客户端的本地数据,表示第k个本地模型,表示依据将d
k
输入至的预测值计算的平衡的softmax损失。5.根据权利要求1所述的基于知识蒸馏特征生成的长尾数据联邦学习方法,其特征在于,所述基于全局模型初始化个性化模型,包括:所述全局模型包括全局特征提取器和全局分类器,所述个性化模型包括局部特征提取器和局部分类器,利用全局特征提取器初始化局部特征提取器。6.根据权利要求1所述的基于知识蒸馏特征生成的长尾数据联邦学习方法,其特征在于,所述基于生成模型为少数类样本生成多特征和为多数类样本生成少特征,包括:针对少数类样本采样多数量标签,采用生成模型基于多数量标签生成多数量特征;针对多数类样本采样少数量标签,采用生成模型基于少数量标签生成少数量特征。7.根据权利要求1所述的基于知识蒸馏特征生成的长尾数据联邦学习方法,其特征在于,基于多特征、少特征和本地数据更新初始化后的个性化模型,同时基于知识蒸馏将全局模型的知识迁移到个性化模型中,采用的损失函数l
pfldf
为:l
pfldf
=l
p
+λl
kl
其中,λ表示权重,l
p
表示分类损失,l
kl
表示蒸馏损失;分类损失l
p
表示为:其中,表示为少数类样本采样的多标签和为多数类样本采样的少标签,表示基于生成的特征,表示第k个个性化模型包含的分类器,表示将输入至的预测值与的交叉熵损失,d
k
表示第k个客户端的本地数据,θ
k
表示第k个个性化模型,l
ce
(θ
k
,d
k
)表示将d
k
输入至θ
k
的预测值与原始标签y的交叉熵;蒸馏损失l
kl
表示为:其中,q
kl
表示kl散度,w
t
表示全局模型,t表示轮次,r表示批次,σ表示softmax函数,σ(w
t
,d
k
)表示将d
k
输入至全局模型w
t
的softmax值,表示将d
k
输入至个性化模型的softmax值。
8.一种基于知识蒸馏特征生成的长尾数据联邦学习系统,其特征在于,包括服务器端和各客户端,所述服务器端用于接收上传的本地模型、每类标签分布和每类本地特征原型后,基于本地模型参数更新全局模型,基于每类标签分布和每类本地特征原型更新生成模型,更新的生成模型和全局模型下发至各客户端;所述各客户端用于接收全局模型和生成模型后,基于生成模型生成每类本地特征并计算每类本地特征原型,依据每类本地特征和本地数据更新本地模型,并上传更新的本地模型、每类标签分布以及每类本地特征原型,基于全局模型初始化个性化模型,基于生成模型为少数类样本生成多特征,基于多特征和本地数据更新初始化后的个性化模型,同时基于知识蒸馏将全局模型的知识迁移到个性化模型中。9.根据权利要求8所述的基于知识蒸馏特征生成的长尾数据联邦学习系统,其特征在于,所述服务器端,基于每类标签分布和每类特征原型更新生成模型采用的损失函数l
gen
为:l
gen
=λ1l
ce
+λ2l
dis
+λ3l
sc
其中,λ1、λ2、λ3表示权重系数,l
ce
表示分类损失,l
dis
表示样本调节损失,l
sc
表示对比损失;分类损失l
ce
表示为:其中,y表示输入的标签,f表示生成的特征,v
k
表示第k个客户端的本地模型中分类器,l
ce
((v
k
,f),y)表示将f输入至v
k
的预测值与y的交叉熵损失,p(y)表示标签分布,表示将y输入至参数为的生成模型得到的特征分布,e表示期望,s
t
表示客户端集合;样本调节损失l
dis
表示为:其中,z表示随机噪声向量,i和j表示向量索引,m表示向量总数,|| ||2表示l2范数;对比损失l
sc
表示为:表示为:其中,表示第k个客户端的第y类的本地特征原型,ω
y
表示拥有样本第y类的客户端集合,表示第y类全局特征原型,y
a
表示第y
a
类标签,f
y
表示依据第y类标签生成的特征,a(y)表示标签集合。10.根据权利要求8所述的基于知识蒸馏特征生成的长尾数据联邦学习系统,其特征在于,所述各客户端中,依据每类本地特征和本地数据更新本地模型采用的损失函数为:
其中,表示随机采样的标签,表示利用生成模型依据标签生成的特征,表示将输入至的预测值与的交叉熵损失,t表示轮次,r表示批次,d
k
表示第k个客户端的本地数据,表示第k个本地模型,表示依据将d
k
输入至的预测值计算的平衡的softmax损失;基于多特征、少特征和本地数据更新初始化后的个性化模型,同时基于知识蒸馏将全局模型的知识迁移到个性化模型中,采用的损失函数l
pfldf
为:l
pfldf
=l
p
+λl
kl
其中,λ表示权重,l
p
表示分类损失,l
kl
表示蒸馏损失;分类损失l
p
表示为:其中,表示为少数类样本采样的多标签和为多数类样本采样的少标签,表示基于生成的特征,表示第k个个性化模型包含的分类器,表示将输入至的预测值与的交叉熵损失,d
k
表示第k个客户端的本地数据,θ
k
表示第k个个性化模型,l
ce
(θ
k
,d
k
)表示将d
k
输入至θ
k
的预测值与原始标签y的交叉熵;蒸馏损失l
kl
表示为:其中,q
kl
表示kl散度,w
t
表示全局模型,t表示轮次,r表示批次,σ表示softmax函数,σ(w
t
,d
k
)表示将d
k
输入至全局模型w
t
的softmax值,表示将d
k
输入至个性化模型的softmax值。
技术总结
本发明公开了一种基于知识蒸馏特征生成的长尾数据联邦学习方法和系统,在服务器上训练一个轻量级的生成模型。生成模型的输入为标签,输出为对应的特征。生成特征可以模拟真实特征的分布。在个性化模型的训练过程中,为每个客户端上的局部少数类生成更多特征,为其他类生成更少特征,这将有助于缓解每个客户端数据的严重不平衡。此外,全局模型是从局部模型聚合而来的,在全局头类中表现更好,而且能够获得高质量和含有丰富信息的特征。将全局模型的知识提取到个性化模型中可以帮助提高个性化模型的性能。通过生成特征来知识蒸馏来训练个性化模型。这样能够有效解决联邦异构长尾数据分布的问题,进一步提升了个性化联邦学习下的模型性能。的模型性能。的模型性能。
技术研发人员:卢杨 吕凤玲 钱品馨 黄刚 华炜 王菡子
受保护的技术使用者:厦门大学
技术研发日:2023.04.25
技术公布日:2023/10/8
版权声明
本文仅代表作者观点,不代表航空之家立场。
本文系作者授权航家号发表,未经原创作者书面授权,任何单位或个人不得引用、复制、转载、摘编、链接或以其他任何方式复制发表。任何单位或个人在获得书面授权使用航空之家内容时,须注明作者及来源 “航空之家”。如非法使用航空之家的部分或全部内容的,航空之家将依法追究其法律责任。(航空之家官方QQ:2926969996)
飞行汽车 https://www.autovtol.com/
