基于动态自适应知识蒸馏的联邦学习模型聚合方法

未命名 09-07 阅读:305 评论:0


1.本发明涉及隐私保护和数据安全技术领域,特别涉及一种基于动态自适应知识蒸馏的联邦学习模型聚合方法。


背景技术:

2.传统的集中式学习要求在手机等本地设备上收集的所有数据都要集中存储在数据中心或云服务器上。这一要求不仅引起了对隐私风险和数据泄露的担忧,而且在数据量巨大时,对服务器的存储和计算能力提出了很高的要求。
3.联邦学习是目前在隐私约束下最广泛采用的机器学习模型协作训练框架,旨在训练一个全局模型,可以在分布在不同设备上的数据上进行训练,同时保护数据隐私。但是联邦学习中每个客户端上的训练数据在很大程度上依赖于特定本地设备的使用情况,因此,客户端的数据分布可能彼此完全不同。这种现象被称为非独立同分布(non-iid),它可能会导致严重的模型发散,导致精度降低,模型收敛缓慢甚至无法收敛。也就是说,由于局部数据分布的异质性,具有相同初始参数的局部模型会收敛到不同的模型。在联邦学习过程中,通过平均上传的局部模型得到的共享全局模型与理想模型(本地设备上的数据为iid时得到的模型)之间的差异持续增加,收敛速度减慢,使学习性能恶化。
4.虽然目前已经有一些研究提出可以在本地模型训练时使用知识蒸馏技术约束本地模型向全局模型学习来解决这一问题,但仍然存在许多问题。比如固定的知识蒸馏比例不能自主适应训练过程中的多变性,或是需要额外的辅助数据集来帮助判断合适的知识蒸馏比例,在现实应用中仍然存在诸多困难。因此,如何更好地利用知识蒸馏提升模型的准确率依然是目前亟需解决的技术难题。


技术实现要素:

5.本发明的目的在于克服现有技术的不足,提供基于动态自适应知识蒸馏的联邦学习模型聚合方法,用以解决联邦学习场景下,固定知识蒸馏比例不能适应各客户端数据分布及模型训练进度不一致的实际情况,进而导致知识蒸馏效果下降,也就无法训练得到高准确度的全局模型的技术问题。
6.为了实现上述目的,本发明采用的技术方案为:基于动态自适应知识蒸馏的联邦学习模型聚合方法,包括如下步骤:
7.步骤1:服务器初始化全局模型并将其发送至参与本轮训练的客户端;
8.步骤2:客户端接收到全局模型后,确定本轮知识蒸馏中对收到的全局模型学习的比例,自适应调整学习本地数据集和全局模型的比例,并动态调整教师模型的输出,使其处于最适合学习的分布状态,训练生成本地模型,并上传给服务器;
9.步骤3:对接收到的本地模型进行聚合生成新的全局模型从而完成本轮训练过程。
10.步骤1中:服务器根据训练任务选择待训练的模型作为本轮全局模型m;然后选择参与本轮训练的客户端c1,c2,...,cn(1≤n≤n),将全局模型下发给参与训练的客户端;其
中客户端c1,c2,...,cn为n个独立的客户端,客户端各自拥有独立的数据d1,d2,...,dn。
11.步骤3中:在接收到全部客户端上传的全部本地模型后,采用联邦平均算法对本地模型进行聚合形成新的全局模型。
12.计算每个客户端数据集大小占总数据集大小的比例,按照比例对对应本地模型参数进行加权形成全局模型的参数进而得到新的全局模型。
13.步骤2中包括:
14.步骤2.1:客户端保存全局模型作为知识蒸馏的教师模型,计算当前客户端对于教师模型的知识蒸馏比例;
15.步骤2.2:客户端计算教师模型输出分布平缓程度;
16.步骤2.3:客户端把全局模型作为本地模型的初始模型,利用本地训练数据集训练本地模型;并且把本地模型作为知识蒸馏中的学生模型,使用步骤2.1、2.2中计算出的知识蒸馏比例和输出分布平缓程度约束教师模型的知识蒸馏过程;
17.步骤2.4:客户端将本地模型上传给服务器。
18.步骤2.1中:
19.利用本地数据集确定本轮知识蒸馏中对收到的全局模型学习的比例。客户端ci(1≤i≤n)接收服务器下发的模型m,作为本地模型mi的教师模型,参与本地模型mi优化训练过程。客户端利用本地训练数据集di=(x,y),测试教师模型在本地训练数据集di上的准确度ai。已知数据样本x和对应标签y,输入教师模型,模型产生logit向量z(x),然后通过softmax函数输出预测的概率p(x)。p(x)中预测概率最高的类别即为模型预测结果,如与对应样本标签y一致,则预测成功,否则预测失败。统计所有样本的预测结果,得到预测准确度并且考虑随着训练过程进行教师模型的变化,加入时间因素,根据当前通信回合轮数t和最终通信回合数t,计算得出本轮教师模型在客户端ci的知识蒸馏比例α
i,kd
=ai*t/t。
20.步骤2.2中:
21.客户端利用本地数据集di计算教师模型输出分布平缓程度λ。对于输入x和目标y,教师模型产生logit向量z
teacher
(x),计算教师模型输出logit向量平缓程度其中z
teacher,i
(x)代表单个类别的教师模型logit向量,k是多类别分类任务的类别数目,mean(z
teacher
(x))是各类别预测概率的平均值,
22.步骤2.3中:
23.客户端ci将全局模型m,作为本地模型mi的起点,利用本地数据集di,对于输入x和目标y,本地模型产生logit向量z(x),然后通过softmax函数输出预测的概率p(x);
24.使用交叉熵函数计算真实概率y和预测概率分布p(x)之间的差异;
25.对于输入x和目标y,教师模型产生logit向量z
teacher
(x);根据计算出的教师模型输出logit
向量平缓程度λ,缩放logit向量对应的softmax分布k为多类别分类任务的类别数目,便于学生模型对教师模型的知识的学习;针对本地模型和教师模型在同一数据样本上不同的预测概率输出,使用kullback-leibler(kl)散度表示教师模型输出概率分布p
teacher
(x)和学生模型输出概率p(x)分布之间的差异,约束本地模型向教师模型学习;
26.并且,根据计算出的知识蒸馏比例α
i,kd
,动态调整交叉熵损失h(p,y)和kl散度损失d
kl
(p
teacher
||p)的比例,总损失函数l=(1-α
i,kd
)h(p,y)+α
i,kddkl
(p
teacher
||p);最后使用总损失反向传播计算梯度更新本地模型。
27.通过本发明所构思的以上技术方案与现有技术相比,本发明的优点在于:
28.1、本发明基于联邦学习,服务器只对客户端上交的模型进行聚合操作即可得到全局模型,客户端数据不会泄露给第三方,能够对客户端身份隐私做到很好的保护,不用担心客户端数据泄露的情况,因此本发明具有很高的隐私保护安全性。
29.2、本发明在确保用户信息数据和隐私不被泄露的情况下实现了对客户端模型的聚合并灵活调控。使用知识蒸馏技术,客户端利用上一轮的全局模型对个性化本地模型训练过程加以约束,很好地改善non-iid场景下模型权重发散导致聚合后模型性能灾难性下降问题。
30.3、每个客户端利用全局模型对本地训练数据集的预测准确度,可以自主计算出合适于当前全局模型的知识蒸馏比例,进行选择性学习。相比传统方案灵活性更高,具有很高的实用性。
31.4、客户端可以根据自己的教师模型输出分布的平缓程度自适应调整教师模型的输出分布,更有利于教师模型与学生模型之间知识的传输,也具有较高的灵活性。
附图说明
32.下面对本发明说明书各幅附图表达的内容及图中的标记作简要说明:
33.图1为本发明联邦学习通信过程的流程示意图;
34.图2为本发明客户端本地训练方法的流程示意图;
35.图3为本发明服务器聚合模型过程的流程示意图。
具体实施方式
36.下面对照附图,通过对最优实施例的描述,对本发明的具体实施方式作进一步详细的说明。应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。此外,下面所描述的本发明各个实施方式中所涉及到的技术特征只要彼此之间未构成冲突就可以相互组合。
37.本发明提供了隐私保护和数据安全技术领域的一种基于动态自适应知识蒸馏的联邦学习模型聚合方法,其目的在于为客户端提供一种基于知识蒸馏的本地训练方法,客户端本地训练阶段,可以自适应地选择对全局模型进行学习,保证本地数据集学习效果的同时维护本地模型权重不过于发散,便于之后服务器直接聚合本地模型生成全局模型。从而实现non-iid场景下,既保护客户端数据隐私信息不被泄露,又可以基于联邦学习协调多
方训练得到的全局模型能够表现出更好的预测准确度,为用户提供服务。
38.本发明的整体思路在于,服务器初始化全局模型发送给客户端;客户端利用本地数据集确定本轮知识蒸馏中对收到的全局模型学习的比例,自适应调整学习本地数据集和全局模型的比例,并动态调整教师模型的输出,使其处于最适合学习的分布状态,灵活进行学习,训练生成本地模型,并上传给服务器;服务器对收到的本地模型进行加权平均,聚合生成新的全局模型,并进入下一次训练过程。最终训练得到一个适用于各个客户端的通用的全局模型。
39.本技术实施例的方案为:一种基于动态自适应知识蒸馏的联邦学习模型聚合方法,包括:
40.服务器初始化全局模型,将全局模型下发给客户端;
41.客户端接收全局模型,使用本地训练数据集训练生成本地模型,并且客户端将训练完成的本地模型上传给服务器;
42.服务器收集客户端发送的n个本地模型,聚合得到新的全局模型参数。
43.优选地,所述服务器初始化全局模型,将全局模型下发给客户端,包括:
44.服务器从初始模型提供者处下载模型或者随机生成初始全局模型m,并将全局模型m统一下发给本轮参与训练的n个客户端c1,c2,...cn,通信模型本身。
45.优选地,所述客户端接收全局模型,使用本地训练数据集训练生成本地模型,并且客户端将训练完成的本地模型上传给服务器,包括:
46.(a)客户端ci(1≤i≤n)接收服务器下发的模型m,作为本地模型mi的教师模型,参与本地模型mi优化训练过程。客户端利用本地训练数据集di=(x,y),测试教师模型在本地训练数据集di上的准确度ai。已知数据样本x和对应标签y,输入教师模型,模型产生logit向量z(x),然后通过softmax函数输出预测的概率p(x)。p(x)中预测概率最高的类别即为模型预测结果,如与对应样本标签y一致,则预测成功,否则预测失败。统计所有样本的预测结果,得到预测准确度并且考虑随着训练过程进行教师模型的变化,加入时间因素,根据当前通信回合轮数t和最终通信回合数t,计算得出本轮教师模型在客户端ci的知识蒸馏比例α
i,kd
=ai*t/t。
47.(b)利用本地训练数据集di=(x,y),客户端计算教师模型输出logit向量平缓程度λ。对于输入x和目标y,教师模型产生logit向量z
teacher
(x),计算得到其中z
teacher,i
(x)代表单个类别的教师模型logit向量,k是多类别分类任务的类别数目,mean(z
teacher
(x))是各类别预测概率的平均值,
48.(c)客户端ci将全局模型m,作为本地模型mi的起点,利用本地数据集di,对于输入x和目标y,本地模型产生logit向量z(x),然后通过softmax函数输出预测的概率p(x)。使用交叉熵函数计算真实概率y和预测概率分布p(x)之间的差异。
49.同样的,对于输入x和目标y,教师模型产生logit向量z
teacher
(x)。根据(b)中计算出的教师模型输出logit向量平缓程度λ,缩放logit向量对应的softmax分布
其中k为多类别分类任务的类别数目,便于教师模型到学生模型的知识传输。针对本地模型和教师模型在同一数据样本上不同的预测概率输出,使用kl散度表示教师模型输出概率分布p
teacher
(x)和学生模型输出概率p(x)分布之间的差异,约束本地模型向教师模型学习。
50.并且,根据(a)中计算出的知识蒸馏比例α
i,kd
,动态调整交叉熵损失h(p,y)和kl散度损失d
kl
(p
teacher
||p)的比例,总损失函数l=(1-α
i,kd
)h(p,y)+α
i,kddkl
(p
teacher
||p)。最后使用总损失反向传播计算梯度更新本地模型。
51.(d)客户端将训练得到的本地模型上传给服务器,通信模型本身。
52.优选地,所述服务器收集客户端发送的n个本地模型,聚合得到新的全局模型参数,包括:
53.服务器收齐所有客户端传输的局部模型,使用联邦平均算法对n个本地模型进行聚合,得到全局模型的m,并进入下一轮的模型训练。
54.下面将具体介绍各个部分:
55.如图1所示,为本发明实施例提供的一种方法流程示意图,主要包括联邦学习中服务器与客户端通信流程。其中各序号流程分别为:
56.①
:服务器初始化全局模型,下发给客户端;
57.②
:客户端训练本地模型;
58.③
:客户端上传本地模型;
59.④
:服务器聚合本地模型生成新一轮全局模型。
60.系统包括一个可信任的服务器,和n个独立的客户端c1,c2,...,cn,客户端各自拥有独立的数据d1,d2,...,dn。当训练开始,服务器根据训练任务选择合适的模型,下载预训练模型或随机初始化模型参数,作为本轮全局模型m。然后选择参与本轮训练的客户端c1,c2,...,cn(1≤n≤n),将全局模型下发给参与训练的客户端。客户端将收到的全局模型作为本地模型m1,m2,..,.mn的起点,利用本地数据集d1,d2,...,dn进行训练得到本地模型,并上传给服务器,本地训练流程如图2所示。然后服务器收集本地模型完成后,按照本地数据集的比例加权平均收集的本地模型参数,得到新的全局模型其中d为所有数据集规模之和,di是每个客户端i的本地数据集大小,mi是客户端i的本地模型,n为参与本轮训练的客户端数目。服务器聚合流程如图3所示。然后进行下一轮的训练。多轮训练完成后,得到一个最终的通用的全局模型m。
61.如图2所示,为客户端本地训练方法的流程示意图,用于客户端本地训练阶段。具体步骤包括:
62.第一步:客户端接收服务器下发的本轮全局模型m,保存为教师模型。利用本地训练数据集di=(x,y),样本x输入教师模型得到logit向量z(x)=[z1(x),z2(x),...,zk(x)],然后通过softmax函数输出预测的概率p(x)=[p1(x),p2(x),...,pk(x)]。p(x)中预测概率最高的类别即为模型预测结果,如与对应样本标签y一致,则预测成功,否则预测失败。统计所有样本的预测结果,得到预测准确度因为全局模型训练效果在训练
不同阶段表现不一致,为更好衡量适合于当前全局模型的知识蒸馏比例,根据当前通信轮数t、总通信轮数t和第一步得到的预测准确度ai,计算得出本轮teacher在客户端ci的知识蒸馏比例α
i,kd
=ai*t/t。
[0063]
第二步:对于本地训练数据集di=(x,y),输入x和目标y,教师模型产生logit向量z
teacher
(x),计算教师模型输出logit向量的平缓程度其中z
teacher,i
(x)代表单个类别的教师模型logit向量,k是多类别分类任务的类别数目,mean(z
teacher
(x))是各类别预测概率的平均值,
[0064]
第三步:客户端将全局模型,作为本地模型mi的起点,对于输入x和目标y,本地模型产生logit向量z(x),然后通过softmax函数输出预测的概率p(x)。同样的,输入x和目标y,教师模型产生logit向量z
teacher
(x)。根据第二步得到的教师模型输出的logit向量平缓程度λ,缩放z
teacher
(x)对应的softmax分布k是多类别分类任务的类别数目,便于学生模型对教师模型学习。
[0065]
第四步:使用交叉熵函数计算真实概率y和预测概率分布p(x)之间的差异,交叉熵损失为其中k是多类别分类任务的类别数目,p(xi)是本地模型对类别i的预测概率。
[0066]
第五步:使用kl散度表示教师模型输出概率分布p
teacher
(x)和学生模型(本地模型)输出概率p(x)分布之间的差异,函数如下:
[0067][0068]
其中,k是多类别分类任务的类别数目,p(xi)是本地模型对类别i的预测概率,p
teacher
(xi)是教师模型对类别i的预测概率。
[0069]
第六步:根据第一步中得到的知识蒸馏比例α
i,kd
,调整第四步计算出的交叉熵损失h(p,y)和第五步计算出的kl散度损失d
kl
(p
teacher
||p)的比例,总损失函数l=(1-α
i,kd
)h(p,y)+α
i,kddkl
(p
teacher
||p)。最后使用总损失反向传播计算梯度更新本地模型,然后返回第三步进入下一轮本地训练。本地模型训练完成后,客户端将本地模型发送给服务器。
[0070]
如图3所示,为服务器聚合阶段的流程示意图。具体步骤包括:
[0071]
第一步:服务器收集各个客户端上传的本地模型。
[0072]
第二步:根据联邦平均算法,计算每个客户端数据集大小占总数据集大小的比例,按照比例对对应本地模型参数进行加权,平均得到新的全局模型。并进入下一轮全局模型的训练。其中客户端1的样本量a1,训练完成后为本地模型m1,则其对应的参数就是a1*m1/(a1+a2+a3...+an),对应的全局模型m的参数m=(a1*m1+a2*m2...+an*mn)/(a1+a2...+an)。
[0073]
显然本发明具体实现并不受上述方式的限制,只要采用了本发明的方法构思和技术方案进行的各种非实质性的改进,均在本发明的保护范围之内。

技术特征:
1.基于动态自适应知识蒸馏的联邦学习模型聚合方法,其特征在于:包括如下步骤:步骤1:服务器初始化全局模型并将其发送至参与本轮训练的客户端;步骤2:客户端接收到全局模型后,确定本轮知识蒸馏中对收到的全局模型学习的比例,自适应调整学习本地数据集和全局模型的比例,并动态调整教师模型的输出,使其处于最适合学习的分布状态,训练生成本地模型,并上传给服务器;步骤3:对接收到的本地模型进行聚合生成新的全局模型从而完成本轮训练过程。2.如权利要求1所述的基于动态自适应知识蒸馏的联邦学习模型聚合方法,其特征在于:步骤1中:服务器根据训练任务选择待训练的模型作为本轮全局模型m;然后选择参与本轮训练的客户端c1,c2,...,c
n
(1≤n≤n),将全局模型下发给参与训练的客户端;其中客户端c1,c2,...,c
n
为n个独立的客户端,客户端各自拥有独立的数据d1,d2,...,d
n
。3.如权利要求1所述的基于动态自适应知识蒸馏的联邦学习模型聚合方法,其特征在于:步骤3中:在接收到全部客户端上传的全部本地模型后,采用联邦平均算法对本地模型进行聚合形成新的全局模型。4.如权利要求3所述的基于动态自适应知识蒸馏的联邦学习模型聚合方法,其特征在于:计算每个客户端数据集大小占总数据集大小的比例,按照比例对对应本地模型参数进行加权形成全局模型的参数进而得到新的全局模型。5.如权利要求1-4任一所述的基于动态自适应知识蒸馏的联邦学习模型聚合方法,其特征在于:步骤2中包括:步骤2.1:客户端保存全局模型作为知识蒸馏的教师模型,计算当前客户端对于教师模型的知识蒸馏比例;步骤2.2:客户端计算教师模型输出分布平缓程度;步骤2.3:客户端把全局模型作为本地模型的初始模型,利用本地训练数据集训练本地模型;并且把本地模型作为知识蒸馏中的学生模型,使用步骤2.1、2.2中计算出的知识蒸馏比例和输出分布平缓程度约束教师模型的知识蒸馏过程;步骤2.4:客户端将本地模型上传给服务器。6.如权利要求5所述的基于动态自适应知识蒸馏的联邦学习模型聚合方法,其特征在于:步骤2.1中:利用本地数据集确定本轮知识蒸馏中对收到的全局模型学习的比例;客户端c
i
(1≤i≤n)接收服务器下发的模型m,作为本地模型m
i
的教师模型,参与本地模型m
i
优化训练过程;客户端利用本地训练数据集d
i
=(x,y),测试教师模型在本地训练数据集d
i
上的准确度a
i
;已知数据样本x和对应标签y,输入教师模型,模型产生logit向量z(x),然后通过softmax函数输出预测的概率p(x);p(x)中预测概率最高的类别即为模型预测结果,如与对应样本标签y一致,则预测成功,否则预测失败;统计所有样本的预测结果,得到预测准确度
并且考虑随着训练过程进行教师模型的变化,加入时间因素,根据当前通信回合轮数t和最终通信回合数t,计算得出本轮教师模型在客户端c
i
的知识蒸馏比例α
i,kd
=a
i
*t/t。7.如权利要求5所述的基于动态自适应知识蒸馏的联邦学习模型聚合方法,其特征在于:步骤2.2中:客户端利用本地数据集d
i
计算教师模型输出分布平缓程度λ;对于输入x和目标y,教师模型产生logit向量z
teacher
(x),计算教师模型输出logit向量平缓程度其中z
teacher,i
(x)代表单个类别的教师模型logit向量,k是多类别分类任务的类别数目,mean(z
teacher
(x))是各类别预测概率的平均值,8.如权利要求5所述的基于动态自适应知识蒸馏的联邦学习模型聚合方法,其特征在于:步骤2.3中:客户端c
i
将全局模型m,作为本地模型m
i
的起点,利用本地数据集d
i
,对于输入x和目标y,本地模型产生logit向量z(x),然后通过softmax函数输出预测的概率p(x);使用交叉熵函数计算真实概率y和预测概率分布p(x)之间的差异;对于输入x和目标y,教师模型产生logit向量z
teacher
(x);根据计算出的教师模型输出logit向量平缓程度λ,缩放logit向量对应的softmax分布k为多类别分类任务的类别数目,便于学生模型对教师模型的知识的学习;针对本地模型和教师模型在同一数据样本上不同的预测概率输出,使用kullback-leibler(kl)散度表示教师模型输出概率分布p
teacher
(x)和学生模型输出概率p(x)分布之间的差异,约束本地模型向教师模型学习;并且,根据计算出的知识蒸馏比例α
i,kd
,动态调整交叉熵损失h(p,y)和kl散度损失d
kl
(p
teacher
||p)的比例,总损失函数l=(1-α
i,kd
)h(p,y)+α
i,kd
d
kl
(p
teacher
||p);最后使用总损失反向传播计算梯度更新本地模型。

技术总结
本发明提供了一种基于动态自适应知识蒸馏的联邦学习模型聚合方法,可以有效缓解数据异质性带来的精度下降问题。方法包括全局模型初始化、本地模型训练、聚合生成全局模型三个阶段。本发明在本地模型训练阶段使用知识蒸馏技术促进客户端学习全局模型,动态调整知识蒸馏比例使客户端可以根据各自情况自适应学习全局模型,并且动态调整教师模型输出分布使客户端更有效地利用知识蒸馏中教师模型的知识,使得聚合后服务器能够有效生成性能更优的全局模型,同时保证不泄露聚合过程中局部模型和全局模型的额外隐私。本发明能够在保证用户隐私安全的前提下,协同多方训练生成更优的全局模型。模型。模型。


技术研发人员:吕军 马晓静 赵瑞欣 付佳韵 陈付龙 苌婉婷
受保护的技术使用者:安徽师范大学
技术研发日:2023.06.09
技术公布日:2023/9/6
版权声明

本文仅代表作者观点,不代表航空之家立场。
本文系作者授权航家号发表,未经原创作者书面授权,任何单位或个人不得引用、复制、转载、摘编、链接或以其他任何方式复制发表。任何单位或个人在获得书面授权使用航空之家内容时,须注明作者及来源 “航空之家”。如非法使用航空之家的部分或全部内容的,航空之家将依法追究其法律责任。(航空之家官方QQ:2926969996)

飞行汽车 https://www.autovtol.com/

分享:

扫一扫在手机阅读、分享本文

相关推荐