一种联邦学习方法
未命名
08-02
阅读:139
评论:0
1.本发明涉及计算机技术领域,具体来说,涉及联邦学习领域,更具体来说,涉及一种联邦学习方法。
背景技术:
2.随着《个人信息保护法》的发布,企业在使用、管理、存储隐私数据方面的成本不断增加,使得数据变得更加难以流通和共享,各领域“数据孤岛”现象愈发突出。而联邦学习技术的出现,为解决“数据孤岛”问题提供了新的思路。联邦学习是一种分布式训练ai模型的一种技术方案,它保证隐私数据保留在本地,参与训练模型的各方通过自身的隐私数据建立联合模型。由于隐私数据并不直接参与到共享中,仅通过模型的形式共享数据的特征,数据的所有权、使用权均不受影响,缓冲了隐私数据保护和使用的矛盾,打通“数据孤岛”,为数据驱动型产业在高数据监管力度下带来新的解决方案。但与此同时,联邦学习仍然面临着众多挑战,其中之一便是难以衡量各节点对于联合模型的贡献。
3.在实际应用中,各节点的数据数量和数据质量往往存在较大差异,公平的联邦学习系统需要综合考虑各节点的数据数量和数据质量,以分配公平的贡献度和奖励,来维持联邦学习系统的稳定性。但是,各节点的数据并不透明,仅以模型的形式进行公开,难以获取各节点数据数量和数据质量情况,例如,传统联邦学习系统贡献度计算通常直接计算各节点模型与联合模型相似度,虽然模型能够体现数据特征,但模型参数存在可解释性弱的问题,且微小的数据扰动都会引起模型参数的巨大变动。因此,直接衡量各节点模型与联合模型相似度的传统方案虽然为联邦学习系统的贡献度提供了一种量化方案,但其可解释性较弱且联邦学习效果较差。
4.因此,亟需一种在各节点数据不透明的情况下公平地衡量各节点的贡献度,以基于各节点的贡献度更好地进行联邦学习的方法。
技术实现要素:
5.因此,本发明的目的在于克服上述现有技术的缺陷,提供一种联邦学习方法。
6.本发明的目的是通过以下技术方案实现的:
7.根据本发明的第一方面,提供一种联邦学习方法,所述方法包括由中心节点将初始化的联合模型分发给多个客户端作为初始的客户端模型,并由中心节点和客户端配合完成多轮联邦训练,得到最终的联合模型,其中,每轮联邦训练包括:由中心节点获取每个客户端上传的当前轮训练后的客户端模型,其中,当前轮训练后的客户端模型是利用客户端的本地训练集以最新获得的客户端模型为基础训练得到的;由中心节点基于各客户端的贡献度对多个客户端当前轮训练后的客户端模型进行聚合,得到当前轮更新后的联合模型,并将该联合模型分发给多个客户端作为各客户端下一轮训练的基础,其中,各客户端的贡献度基于各客户端模型分别对预设的仿真样本的分类准确率确定。
8.在本发明的一些实施例中,所述每个客户端的贡献度根据该客户端的数据标签分
布情况和该客户端模型对各类别下仿真样本的分类准确率确定,其中,所述数据标签分布情况指示对应客户端的各类别下的非仿真样本的占比。
9.在本发明的一些实施例中,所述每个客户端的贡献度是该客户端的数据标签分布情况中每个类别下的非仿真样本的占比与该客户端模型对该类别的仿真样本的分类准确率的乘积之和。
10.在本发明的一些实施例中,由中心节点按照以下方式获得当前轮更新后的联合模型:每轮联邦训练前,获取最新更新的每个客户端的贡献度,比较各客户端中最新更新的每个客户端的贡献度和预设阈值的大小,剔除贡献度小于预设阈值所对应的客户端,得到当前轮参与训练的多个客户端;对当前轮参与训练的多个客户端上传的当前轮训练后的客户端模型进行聚合,得到当前轮更新后的联合模型。
11.在本发明的一些实施例中,所述仿真样本按照以下方式生成:获取基于生成对抗方式训练得到的经训练的生成模型;利用经训练的生成模型针对每种类别对应生成多个仿真样本,其中,每种类别下的仿真样本的数据标签向量与为该类别预设的标签向量的距离小于预设阈值,数据标签向量是基于将仿真样本输入当前轮更新后的联合模型中得到的。
12.在本发明的一些实施例中,基于所述生成对抗方式进行一轮或者多轮迭代对抗训练,每轮对抗训练包括:获取对抗生成网络,其包括生成模型和判别模型;获取第一训练集训练判别模型,得到当轮训练的判别模型,所述第一训练集包括多个第一样本和每个第一样本对应的指示其是非仿真样本的置信度标签,单个第一样本为仿真样本或者非仿真样本,该置信度标签基于将第一样本输入当轮更新后的联合模型得到的输出结果确定;将生成的仿真样本输入当轮训练的判别模型,利用判别模型对生成的仿真样本的判别损失更新生成模型的参数。
13.在本发明的一些实施例中,所述第一训练集中每个第一样本的生成方式包括:获取数据集中的非仿真样本或生成模型基于随机数生成的仿真样本;将非仿真样本或仿真样本输入当轮更新后的联合模型中,得到数据标签向量,计算数据标签向量与为多种类别中每种类别预设的标签向量的距离,得到多个距离,从多个距离中选择与数据标签向量最小的距离;获取预设的超参数,根据预设的超参数以及计算的最小距离确定第一样本的置信度标签。
14.在本发明的一些实施例中,第一样本的置信度标签按照以下方式确定:
[0015][0016]
其中,h表示置信度标签,α为预设的超参数,d表示计算的非仿真样本或仿真样本输入当前轮更新后的联合模型中得到的数据标签向量与为各类别预设的标签向量的最小距离,β表示预设的超参数,γ表示预设的超参数。
[0017]
在本发明的一些实施例中,各客户端的贡献度是每隔预设的轮次周期性更新的,或者在所有的联邦训练的轮次中指定的联邦训练的轮次被间隔更新的。
[0018]
根据本发明的第二方面,提供一种图像分类方法,所述方法包括:获取根据本发明的第一方面任一项所述的联邦学习方法得到最终的联合模型,其中,联邦学习方法中,各个客户端的本地训练集中的样本数据为图像数据,标签为对应样本数据对应的类别;基于最
终的联合模型对输入的图像数据进行图像分类。
[0019]
在本发明的一些实施例中,样本数据对应的类别为飞机、汽车、鸟、猫、鹿、狗、蛙、马、船或卡车。
[0020]
根据本发明的第三方面,提供一种用户分类方法,所述方法包括:获取根据本发明的第一方面任一项所述的联邦学习方法得到最终的联合模型,其中,联邦学习方法中,各个客户端的本地训练集中的样本数据为用户特征数据,标签为对应样本数据对应的用户类别;基于最终的联合模型对输入的用户特征数据进行用户分类。
[0021]
在本发明的一些实施例中,所述样本数据对应的用户类别为优质用户、良好用户或普通用户。
[0022]
根据本发明的第四方面,提供一种电子设备,包括:一个或多个处理器;以及存储器,其中存储器用于存储可执行指令;所述一个或多个处理器被配置为经由执行所述可执行指令以实现本发明的第一方面、第二方面、第三方面中任一项所述方法的步骤。
[0023]
与现有技术相比,本发明的优点在于:
[0024]
本发明利用中心节点预设的仿真样本为联邦学习中各客户端的贡献度评估提供数据支撑,通过各客户端模型分别对预设的仿真样本的分类准确率确定各客户端的贡献度,为联邦学习提供公平合理、可解释的贡献度计算;在聚合联合模型时,该贡献度的计算使中心节点更好地评估每个客户端的重要性,从而更好地利用每个客户端的贡献度对多个客户端当前轮训练后的客户端模型进行聚合,提高联邦学习的效率以及联合模型的泛化能力。
附图说明
[0025]
以下参照附图对本发明实施例作进一步说明,其中:
[0026]
图1为根据本发明一个实施例的联邦学习方法的流程示意图;
[0027]
图2为根据本发明一个实施例的联邦学习系统架构示意图;
[0028]
图3为根据本发明一个实施例的联合模型和客户端模型的结构原理示意图;
[0029]
图4为根据本发明一个实施例的生成模型的结构原理示意图;
[0030]
图5为根据本发明一个实施例的判别模型的结构原理示意图;
[0031]
图6为根据本发明一个实施例的一种联邦学习方法整体流程示意图。
具体实施方式
[0032]
为了使本发明的目的,技术方案及优点更加清楚明白,以下结合附图通过具体实施例对本发明进一步详细说明。应当理解,此处所描述的具体实施例仅用以解释本发明,并不用于限定本发明。
[0033]
如在背景技术部分提到的,现有进行联邦学习方法中各节点(即客户端)的数据并不透明,仅以模型的形式进行公开,难以获取各节点数据数量和数据质量情况,导致基于现有计算得到的贡献度进行联邦学习的效果较差。
[0034]
为了解决上述问题,首先,发明人考虑到直接基于模型的形式量化各个客户端的方式导致贡献度衡量效果差,联邦学习效率低。因此,发明人以数据的形式来衡量各个客户端的贡献度,综合考虑各个客户端的数据数量和数据质量,以公平地衡量各个客户端的贡
献度,以维持联邦学习的稳定性;其次,发明人还考虑到各个客户端的数据也并不透明,而中心节点通常缺少样本数据,因此,发明人利用中心节点预设仿真样本,计算各客户端的客户端模型对各类别下的仿真样本的分类准确率,基于该分类准确率来确定各客户端的贡献度,为联邦学习提供公平合理、可解释性强的贡献度计算方式。最后,中心节点在获取每个客户端上传的当前轮训练后的客户端模型后,基于本发明的贡献度的计算可以帮助中心节点更好地评估每个客户端的重要性,从而根据每个客户端的贡献度对各客户端模型进行聚合使得到的联合模型具有更强的泛化能力。
[0035]
基于上述分析,本发明提供一种联邦学习方法,所述方法包括由中心节点将初始化的联合模型分发给多个客户端作为初始的客户端模型,并由中心节点和客户端配合完成多轮联邦训练,得到最终的联合模型,其中,每轮联邦训练包括步骤s1和s2,参见图1,在步骤s1中,由中心节点获取每个客户端上传的当前轮训练后的客户端模型,其中,当前轮训练后的客户端模型是利用客户端的本地训练集以最新获得的客户端模型为基础训练得到的,本发明通过将本地训练集训练的客户端模型上传到中心节点,保证了联邦学习中的数据隐私,通过上传训练后的各客户端模型实现数据共享;在步骤s2中,由中心节点基于各客户端的贡献度对多个客户端当前轮训练后的客户端模型进行聚合,得到当前轮更新后的联合模型,并将该联合模型分发给多个客户端作为各客户端下一轮训练的基础,其中,各客户端的贡献度基于各客户端模型分别对预设的仿真样本的分类准确率确定。本发明利用中心节点预设的仿真样本为联邦学习中各节点客户端的贡献度评估提供数据支撑,通过各客户端模型分别对该仿真样本的分类准确率确定各客户端的贡献度,为联邦学习提供公平合理、可解释的贡献度计算,在聚合联合模型时,该贡献度的计算使中心节点更好地评估每个客户端的重要性,从而更好地利用每个客户端的贡献度对多个客户端当前轮训练后的客户端模型进行聚合,提高联邦学习的效率以及联合模型的泛化能力。
[0036]
根据本发明的一个实施例,参见图2,图2为联邦学习系统架构示意图,联邦学习系统架构中包括中心节点c和参与联邦学习的多个客户端,中心节点c中包括联合模型m0,参与联邦学习的多个客户端分别为客户端e1、客户端e2
……
客户端ek,每个客户端包括一个本地数据库和一个客户端模型,即客户端e1包括本地数据库d1和客户端模型m1、客户端e2包括本地数据库d2和客户端模型m2、客户端ek包括本地数据库dk和客户端模型mk。基于该联邦学习系统架构执行本发明的联邦学习方法时,首先,在进行多轮联邦训练前,通过中心节点初始化联合模型m0以及各个客户端的贡献度,其次,将初始化的联合模型分发给多个客户端作为初始的客户端模型,最后,由中心节点和客户端配合完成多轮联邦训练,得到最终的联合模型。其中,每轮联邦训练过程中,中心节点获取每个客户端上传的当前轮训练后的客户端模型,其中,当前轮训练后的客户端模型是利用客户端的本地训练集(即本地数据库)以最新获得的客户端模型为基础训练得到的;由中心节点基于各客户端的贡献度对多个客户端当前轮训练后的客户端模型进行聚合,得到当前轮更新后的联合模型,并将当前轮更新后的联合模型分发给多个客户端作为各客户端下一轮训练的基础,其中,各客户端的贡献度基于各客户端模型分别对预设的仿真样本的分类准确率确定。
[0037]
根据本发明的一个实施例,中心节点c的联合模型和参与联邦学习的多个客户端的客户端模型可根据联邦学习任务不同选择对应的模型结构。示意性的,以图像分类任务为例,参见图3,图3为联合模型以及客户端模型的结构原理示意图,具体结构依次为卷积层
1、激活函数层relu、批标准化层、卷积层2、激活函数层relu、批标准化层、池化层、卷积层3、批标准化层、激活函数层relu、卷积层4、激活函数层relu、批标准化层、池化层、卷积层5、激活函数层relu、批标准化层、卷积层6、激活函数层relu、批标准化层、池化层、全连接层1、全连接层2和激活函数层softmax。将一张32*32像素的rgb图像(彩色图像),每个rgb图像分为3个通道(r通道、g通道、b通道),将该图像作为输入通过图3的模型结构进行处理。其中,每个卷积层均以3*3的卷积核和步长1对其输入数据进行处理,即该图像通过卷积层1、激活函数层relu、批标准化层处理,输出维度为3*32*32的图像数据,再依次通过卷积层2、激活函数层relu、批标准化层、池化层处理输出维度为32*32*32的图像数据,再依次通过卷积层3、批标准化层、激活函数层relu处理输出维度为64*32*32的图像数据,再依次通过卷积层4、激活函数层relu、批标准化层、池化层处理输出维度为64*64*64的图像数据,再依次通过卷积层5、激活函数层relu、批标准化层处理输出维度为128*128*128的图像数据,再依次通过卷积层6、激活函数层relu、批标准化层、池化层处理输出维度为128*128*128的图像数据,最后,通过全连接层1、全连接层2、激活函数层softmax处理得到维度为10的输出数据,输出数据表示输入图像属于10个类别中每个类别的概率值。
[0038]
根据本发明的一个实施例,每轮联邦训练过程中包括步骤s1和s2,为了更好地理解本发明,下面结合具体的实施例针对每一个步骤分别进行详细说明。
[0039]
在步骤s1中,由中心节点获取每个客户端上传的当前轮训练后的客户端模型,其中,当前轮训练后的客户端模型是利用客户端的本地训练集以最新获得的客户端模型为基础训练得到的。
[0040]
根据本发明的一个实施例,联邦学习方法中,当联合模型用于图像分类任务时,各个客户端的本地训练集中的样本数据为图像数据,标签为对应样本数据对应的类别;基于最终的联合模型对输入的图像数据进行图像分类。其中,样本数据对应的类别为飞机、汽车、鸟、猫、鹿、狗、蛙、马、船或卡车。应当理解,此处仅为示意,当联邦学习任务不同时,其选择的训练集也不同,本发明对此并不限定,例如进行用户分类任务时,各个客户端的本地训练集中的样本数据可以为用户特征数据,标签为对应样本数据对应的用户类别;基于最终的联合模型对输入的用户特征数据进行用户分类。其中,所述样本数据对应的用户类别为优质用户、良好用户或普通用户。
[0041]
在步骤s2中,由中心节点基于各客户端的贡献度对多个客户端当前轮训练后的客户端模型进行聚合,得到当前轮更新后的联合模型,并将该联合模型分发给多个客户端作为各客户端下一轮训练的基础,其中,各客户端的贡献度基于各客户端模型分别对预设的仿真样本的分类准确率确定。本发明通过对仿真样本的分类准确率来确定各客户端的贡献度,为联邦学习提供公平合理、可解释的贡献度确定方式,在聚合联合模型时,使中心节点更好地评估每个客户端的重要性,可构建具有更强泛化能力的联合模型。
[0042]
根据本发明的一个实施例,由中心节点获得当前轮更新后的联合模型的方式:每轮联邦训练前,获取最新更新的每个客户端的贡献度,比较各客户端中最新更新的每个客户端的贡献度和预设阈值的大小,剔除贡献度小于预设阈值所对应的客户端,得到当前轮参与训练的多个客户端;对当前轮参与训练的多个客户端上传的当前轮训练后的客户端模型进行聚合,得到当前轮更新后的联合模型。本发明通过将贡献度低于预设阈值的客户端剔除,筛选低数据质量客户端和恶意客户端,保证联邦学习系统的稳定性和可持续性。中心
节点获取各个客户端模型的参数,使用fedavg聚合方法,对所有客户端模型的参数进行求和平均,得到更新后的联合模型,聚合公式如下:
[0043][0044]
其中,表示第i轮联邦训练得到更新后的联合模型的参数,并根据第i轮联合模型的参数得到第i轮更新后的联合模型符号的下标i表示联邦训练轮次,上标g表示其为联合模型,n表示参与联邦训练的客户端数量,表示第i轮训练后的第k个客户端模型的参数。应当理解,由中心节点获得当前轮更新后的联合模型也可以为其他方式,本发明对此并不限定,例如,基于各个客户端的贡献度,采用加权求和的方式来聚合各个客户端上传的当前轮训练后的客户端模型,具体计算如下:
[0045][0046]
其中,v0表示第1个客户端模型的贡献度,表示第i轮训练后的第1个客户端模型的参数,v1表示第2个客户端模型的贡献度,表示第i轮训练后的第2个客户端模型的参数,v
n-1
表示第n个客户端模型的贡献度,表示训练后的第n个客户端模型的参数,其中,该计算方式中各个客户端的贡献度为归一化处理后得到的,即v0+v1+
…
+v
n-1
=1。
[0047]
根据本发明的一个实施例,各客户端的贡献度是每隔预设的轮次周期性更新的,或者在所有的联邦训练的轮次中指定的联邦训练的轮次被间隔更新的。例如每隔40轮周期性更新客户端的贡献度,则第1轮到40轮之间的每轮的各个客户端的贡献度为第1轮获取的各客户端的贡献度,第40轮更新得到各个客户端的贡献度,第41轮到第80轮间的每轮的各客户端的贡献度采用第40轮最新更新的每个客户端的贡献度,第80轮再次更新得到各个客户端的贡献度,后续轮次则采用第80轮最新更新的贡献度,以此每隔预设的40轮次周期性更新;又或是在联邦训练的200轮次中指定第60轮和第140轮联邦训练更新得到各客户端的贡献度,则在第61轮前采用第1轮获取的各客户端的贡献度,第61轮及61轮到140轮间每轮的各客户端的贡献度采用第60轮最新更新的每个客户端的贡献度,并在第140轮次再次更新各客户端的贡献度,第141轮及141轮到200轮间各客户端的贡献度采用第140轮最新更新的每个客户端的贡献度。其中,第一轮中各客户端的贡献度由中心节点进行初始化得到,执行第一轮联邦训练是基于初始化的各客户端的贡献度对多个客户端当前轮训练后的客户端模型进行聚合。本发明在联邦训练过程中,设计间隔的贡献度计算轮次,降低计算成本并提高联邦学习效率,同时,避免恶意客户端对贡献度计算过程的针对性攻击,提高了联邦学习稳定性。
[0048]
与此同时,在非独立同分布数据场景中,基于各个客户端模型来量化客户端的贡献度对非独立同分布数据的贡献度衡量效果较差,且可解释性弱,另外,各客户端的客户端模型存在偏向性,即对应客户端存在数据质量好,但各类别数据分布不均匀,导致各客户端训练出的客户端模型对部分类别的数据分类准确率高而对其他类别的数据分类准确率低的问题。因此,本发明计算贡献度时不仅结合了对数据的分类准确率,还考虑了客户端的不同数据类别的分布情况,即数据标签分布情况。根据本发明的一个实施例,在中心节点中确
定各客户端的贡献度的方式为:每个客户端的贡献度根据该客户端的数据标签分布情况和该客户端模型对各类别下仿真样本的分类准确率确定,其中,所述数据标签分布情况指示对应客户端的各类别下的非仿真样本的占比。每个客户端各类别下的非仿真样本为客户端本地已有的数据,不是仿真生成的数据,且对应客户端的各类别下的非仿真样本的占比为客户端提供的数据分布情况,各类别下的非仿真样本的占比可以是客户端提供的每种标签对应类别下的数据与该客户端数据总量之比。如第j个客户端共有500条数据,其中,100条数据标签是狗,100条是猫,300条是飞机,那么对应飞机、汽车、鸟、猫、鹿、狗、蛙、马、船或卡车10类数据类别,可知第j个客户端的数据标签分布情况总体为[0.6,0,0,0.2,0,0.2,0,0,0,0]。应当理解,此处仅为示意,可根据具体情况的需要进行设置,例如,每个客户端各类别下的非仿真样本的占比也可以是客户端提供的每种标签对应类别下采用的训练数据与该客户端采用的本地训练集的总量之比,即客户端数据总量为1000时,采用300条作为本地训练集,本地训练集有150条数据为狗,150条为马,可知第j个客户端的数据标签分布情况总体为[0,0,0,0,0,0.5,0,0.5,0,0];又例如,每个客户端各类别下的非仿真样本的占比也可以是客户端提供的每种标签对应类别下采样的数据与该客户端采样的数据总量之比,本发明对此不作任何限制。本发明在计算各客户端模型在各数据类别下的分类准确率后,与各客户端的数据标签分布情况相结合,共同得出各客户端的贡献度,缓解了贡献度计算过程中异构数据造成的偏见,对数据质量好但各类别数据分布不均匀的客户端的贡献度估计更准确,提高对各客户端的贡献度的计算准确性。
[0049]
根据本发明的一个实施例,在计算客户端的贡献度时,首先需针对各类别生成对应的仿真样本,再确定客户端模型对各类别下仿真样本的分类准确率,根据分类准确率和数据标签分布情况确定贡献度。其中,仿真样本生成方式包括:获取基于生成对抗方式训练得到的经训练的生成模型;利用经训练的生成模型针对每种类别对应生成多个仿真样本,其中,每种类别下的仿真样本的数据标签向量与为该类别预设的标签向量的距离小于预设阈值,数据标签向量是基于将仿真样本输入当前轮更新后的联合模型中得到的。其中,数据标签向量和预设的标签向量均为one-hot编码,将仿真样本作为当前轮更新后的联合模型的输入,得到数据标签向量,数据标签向量与对应类别预设的标签向量距离越小表明该仿真样本的数据质量越好。因此,本发明计算数据标签向量与各类别预设的标签向量的最小距离,同时,将最小距离与预设阈值对比,该最小距离小于预设阈值,即表示该仿真样本质量高,可用于测试客户端模型的分类准确率,为联邦学习中各客户端的贡献度评估提供数据支撑,提高联邦学习的效率。
[0050]
根据本发明的一个实施例,客户端模型对各类别下仿真样本的分类准确率计算方式包括:将每种类别下的多个仿真样本输入到每个客户端上传的当前轮训练后的客户端模型中,得到每个客户端模型的输出结果,根据每个客户端模型的输出结果计算每个类别下每个客户端模型的准确率。示意性的,假设数据类别有10类,分别为飞机、汽车、鸟、猫、鹿、狗、蛙、马、船和卡车,针对每种类别生成200条仿真样本,将鸟类的200条仿真样本输入各客户端模型中,若一个客户端模型输出的200个结果中有20个分类错误,则得到其对鸟类的分类准确率为0.9。
[0051]
根据本发明的一个实施例,根据分类准确率和数据标签分布情况确定贡献度的方式为:计算每个客户端的贡献度是该客户端的数据标签分布情况中每个类别下的非仿真样
本的占比与该客户端模型对该类别的仿真样本的分类准确率的乘积之和,即每个客户端的贡献度计算方式如下:
[0052][0053]
其中,vj表示第j个客户端的贡献度,j表示第j个客户端,表示第j个客户端模型的对第f种类别的分类准确率,f表示第f种类别,表示第j个客户端中第f种类别的数据标签分布情况。
[0054]
根据本发明的一个实施例,仿真样本生成前,基于所述生成对抗方式进行一轮或者多轮迭代对抗训练,本发明中设置为3000轮后停止对抗训练,得到经训练的生成模型,通过经训练的生成模型生成仿真样本。其中,每轮对抗训练包括:如下步骤a1、a2和a3:
[0055]
在步骤a1中,获取对抗生成网络,其包括生成模型和判别模型。
[0056]
示意性的,参见图4,图4为生成模型的结构原理示意图,具体结构依次包括反卷积层1、批标准化层、激活函数层relu、反卷积层2、批标准化层、激活函数层relu、反卷积层3、批标准化层、激活函数层relu、反卷积层4和激活函数层tanh。其中,反卷积层1以4*4的卷积核和步长1对其输入数据进行处理,反卷积层2、反卷积层3均以4*4的卷积核和步长2对其输入数据进行处理,输入数据为100个随机数,依次通过反卷积层1、批标准化层、激活函数层relu处理输出维度为512*4*4的数据,再依次通过反卷积层2、批标准化层、激活函数层relu处理输出维度为256*8*8的数据,再依次通过反卷积层3、批标准化层、激活函数层relu处理输出维度为128*16*16的数据,再依次通过反卷积层4和激活函数层tanh处理输出维度为3*32*32的数据。
[0057]
示意性的,参见图5,图5为判别模型的结构原理示意图,具体结构依次包括卷积层1、激活函数层leaky relu、卷积层2、批标准化层、激活函数层leaky relu、卷积层3、批标准化层、激活函数层leaky relu、卷积层4、批标准化层、激活函数层leaky relu和激活函数层sigmoid。其中,卷积层1、卷积层2、卷积层3均以4*4的卷积核和步长2对其输入数据进行处理,卷积层4以4*4的卷积核和步长1对其输入数据进行处理,将数据输入判别模型中,依次通过卷积层1、激活函数层leaky relu处理输出维度为3*32*32的数据,再依次通过卷积层2、批标准化层、激活函数层leaky relu处理输出维度为64*16*16的数据,再依次通过卷积层3、批标准化层、激活函数层leaky relu处理输出维度为128*8*8的数据,再依次通过卷积层4、批标准化层、激活函数层leaky relu处理输出维度为256*4*4的数据,最后,通过激活函数层sigmoid处理输出一个判别模型对输入数据的判别结果。
[0058]
在步骤a2中,获取第一训练集训练判别模型,得到当轮训练的判别模型,所述第一训练集包括多个第一样本和每个第一样本对应的指示其是非仿真样本的置信度标签,单个第一样本为仿真样本或者非仿真样本,该置信度标签基于将第一样本输入当轮更新后的联合模型得到的输出结果确定。
[0059]
由于隐私原因中心节点无法采集大量的样本数据来进行对抗训练,通常采用小批量的训练集来训练,且训练时采用的非仿真样本也有质量较差的,采用的仿真样本也有伪造的质量比较好,需先为这些样本赋予一个合理的置信度标签,提高训练效率以及生成模型的准确度,而联合模型就是基于各个客户端的本地已经存在的样本数据训练聚合而来
的,多轮联邦训练后的联合模型包含各个客户端的本地样本数据的信息,本发明通过当前轮更新后的联合模型先对第一样本进行一个处理评估,输出第一样本的评估结果,基于该评估结果以及预设的超参数计算得到第一样本的置信度标签,以此方式得到第一训练集后再进行对抗训练。因此,根据本发明的一个实施例,所述第一训练集中每个第一样本的生成方式包括:获取数据集中的非仿真样本或生成模型基于随机数生成的仿真样本;将非仿真样本或仿真样本输入当轮更新后的联合模型中,得到数据标签向量,计算数据标签向量与为多种类别中每种类别预设的标签向量的距离,得到多个距离,从多个距离中选择与数据标签向量最小的距离;获取预设的超参数,根据预设的超参数以及计算的最小距离确定第一样本的置信度标签。示意性的,若联邦学习任务为图像分类,则采用的第一样本为数据集中的非仿真图像样本或生成模型基于随机数生成的仿真图像样本,该处数据集可采用cifar10,cifar10数据集包括若干个非仿真图像样本,每个非仿真图像样本都是一张32*32像素的rgb图像。其中,第一样本的置信度标签的计算方式如下:
[0060][0061]
其中,h表示置信度标签,α为预设的超参数,d表示计算的非仿真样本或仿真样本输入当前轮更新后的联合模型中得到的数据标签向量与为各类别预设的标签向量的最小距离,β表示预设的超参数,γ表示预设的超参数,最终,第一样本的形式为[非仿真样本,置信度标签=1-αd]或者为[仿真样本,置信度标签=β-γd],且α,β,γ均为人工设定的超参数,本发明中α=1,β=1,γ=0.5。其中,最小距离d可采用欧氏距离,范数等计算方式,例如以范数方式计算,计算方式如下:
[0062]
d=min
c∈ce
{||l
j-c||},
[0063]
其中,d表示最小距离,min表示求最小,c表示与数据标签向量最小的距离对应的类别预设的标签向量,ce表示各类别预设的标签向量,‖
·
‖表示范数,lj表示非仿真样本或仿真样本对应的数据标签向量。其中,数据标签向量和预设的标签向量均为one-hot编码,数据标签向量获得方式为:将第一样本(即非仿真样本或仿真样本)作为当前轮更新后的联合模型的输入,得到第一样本的数据标签向量lj。本发明置信度标签的计算方式基于当前轮更新后的联合模型和超参数得到,即通过当前轮更新后的联合模型评估并结合超参数计算第一样本对应的置信度标签,不管采用的第一样本是非仿真样本或生成模型生成的仿真样本,均可以为对应第一样本设置合理的置信度标签来指示该第一样本是非仿真样本的概率,提高了小批量的训练集下生成模型的精度和效果,基于该生成模型生成的仿真样本可以更准确地测试各个客户端模型的分类准确率,提高各个客户端的贡献度计算的准确性。
[0064]
根据本发明的一个实施例,获取上述实施例得到的第一训练集,在利用该第一训练集训练判别模型时,将第一样本输入判别模型得到第一判别结果,基于第一样本的第一判别结果和置信度标签计算得到的第一判别损失来更新判别模型的参数,以得到当轮训练后的判别模型。第一判别结果即为第一样本是非仿真样本的置信度,采用均方误差损失函数根据判别模型输出的置信度和第一样本的置信度标签的差异得到第一判别损失。
[0065]
在步骤a3中,将生成的仿真样本输入当轮训练的判别模型,利用判别模型对生成的仿真样本的判别损失更新生成模型的参数。
[0066]
根据本发明的一个实施例,在步骤a3中更新生成模型的参数的方式为:将生成的仿真样本输入当轮训练的判别模型,利用判别模型对生成的仿真样本的第二判别损失取反后求梯度并反向传播更新生成模型的参数。
[0067]
根据本发明的另一个实施例,在步骤a3中更新生成模型的参数的方式也可以是:通过构建第二训练集,第二训练集包括多个第二样本和每个第二样本对应指示其是非仿真样本的置信度标签,每个第二样本为生成模型生成的仿真样本,每个第二样本对应的置信度标签的值设置为指示其是非仿真样本的最大置信度,其中,置信度∈[0,1],1用于表示非仿真样本,0用于表示仿真样本,即第二样本的结构形式为[非仿真样本,置信度标签=1]。将第二训练集中的第二样本输入当轮训练的判别模型得到第二判别结果,第二判别结果为当轮训练的判别模型基于输入的仿真样本输出的指示其是非仿真样本的置信度,基于第二样本的第二判别结果和对应置信度标签计算得到的第二判别损失来反向更新生成模型的参数,即通过该第二判别损失奖励生成模型。本发明中的对抗训练中要尽可能让判别模型判别正确以及使生成模型要尽可能让判别模型对其生成的数据判别错误,以确保经训练的生成模型生成质量更好的仿真样本数据。
[0068]
示意性的,参见图6,图6为基于上述实施例给出的一种联邦学习方法整体流程示意图,包括步骤b1、b2、b3、b4、b5、b6、b7、b8和b9。开始执行时,步骤b1:由中心节点初始化联合模型,并初始化参与联邦训练的多个客户端的贡献度,联邦训练轮次设置为g,每轮联邦训练中设置客户端模型利用本地训练集训练轮次为l;步骤b2:参与联邦训练的客户端下载联合模型,使用本地数据训练l轮得到各自客户端当前联邦训练轮次训练后的客户端模型,并将当前轮训练后的客户端模型上传至中心节点;步骤b3:中心节点基于各客户端模型的贡献度聚合上传的客户端模型,得到第i轮联邦训练的联合模型;若最新计算更新的某个客户端的贡献度低于预设阈值,则该客户端被判定为低数据质量客户端或者恶意客户端,则在本轮及后续轮次中,该客户端的客户端模型不会参与到联邦训练聚合过程中;步骤b4:中心节点判断第i轮联邦训练是否为贡献度计算轮次(即该轮次是否更新各客户端的贡献度),如果是则继续执行步骤5,否则执行步骤9;贡献度计算轮次判定属于超参数,在联邦训练启动时设定,本发明设置一次贡献度计算轮次。步骤b5:中心节点获取各客户端的数据标签分布情况;步骤b6:中心节点训练生成模型和判别模型,得到经训练的生成模型;步骤b7:中心节点利用经训练的生成模型针对每种类别生成多个仿真样本,基于仿真样本测试各客户端的客户端模型的分类准确率;步骤b8:计算基于各客户端模型在各类别下的分类准确率,结合对应客户端的数据标签分布情况计算该客户端的贡献度;步骤b9:中心节点基于将贡献度小于预设阈值的客户端的贡献度设置为0,且该客户端不参与后续联邦训练,其中,贡献度低于预设阈值的客户端判定为低数据质量节点或恶意节点,不让其参与后续联邦学习的训练过程,保证联邦学习的稳定性;最后,重复步骤b2~b9,直到达到预设的联邦训练轮次后训练完成,得到参与联邦训练的各个客户端的贡献度和最终的联合模型。
[0069]
根据本发明的一个实施例,提供一种图像分类方法,所述方法包括:获取根据上述实施例所述的联邦学习方法得到最终的联合模型,其中,联邦学习方法中,各个客户端的本地训练集中的样本数据为图像数据,标签为对应样本数据对应的类别;基于最终的联合模型对输入的图像数据进行图像分类。其中,样本数据对应的类别为飞机、汽车、鸟、猫、鹿、狗、蛙、马、船或卡车。
[0070]
根据本发明的一个实施例,提供一种用户分类方法,所述方法包括:获取根据上述实施例所述的联邦学习方法得到最终的联合模型,其中,联邦学习方法中,各个客户端的本地训练集中的样本数据为用户特征数据,标签为对应样本数据对应的用户类别;基于最终的联合模型对输入的用户特征数据进行用户分类。其中,所述样本数据对应的用户类别为优质用户、良好用户或普通用户。
[0071]
为了验证本发明的有益效果,发明人进行了如下实验:
[0072]
实验过程中采用的硬件为:amd ryzen 7 4800h处理器;4g内存;windows 10操作系统;软件:pytorch1.8深度学习框架;cuda10.1运算平台,基于该硬件和软件环境针对客户端可能存在的不同数据分布情况进行如下三种实验:
[0073]
1、对不同数据量客户端的贡献度评估实验
[0074]
实验过程:设置60个参与训练客户端,为第一组参与训练的客户端a1分配5000条数据,为第二组参与训练的客户端a2分配500条数据,为第三组参与训练的客户端a3分配50条数据,采用本发明方法计算每组中各个客户端的贡献度并进行统计,得到不同数据量客户端的贡献度评估结果,如下表1所示:
[0075]
表1:不同数据量客户端的贡献度评估结果
[0076][0077]
基于上表1中的实验结果,可知第一组客户端的贡献度平均值最高,第三组客户端的贡献度平均值最低,即本发明能够区分来自不同数据量a1,a2,a3集合的客户端,且数据量越大的客户端集合的贡献度平均值越高。
[0078]
2、对不同数据质量的客户端的贡献度评估
[0079]
实验过程:设置60个参与训练的客户端,为第一组参与训练的客户端b1分配1000条数据,为第二组参与训练的客户端b2分配1000条添加30%随机噪声的数据,为第三组参与训练的客户端b3分配1000条添加50%随机噪声的数据,采用本发明方法计算每组中各个客户端的贡献度并进行统计,得到不同数据质量的客户端的贡献度评估结果,如下表2所示:
[0080]
表2:不同数据质量的客户端的贡献度评估结果
[0081][0082]
基于上表2中的实验结果,可知第一组客户端的贡献度平均值最高,第三组客户端的贡献度平均值最低,即本发明能够区分来自不同数据质量b1,b2,b3集合的客户端,且数据质量越高的客户端集合的贡献度平均值越高。
[0083]
3、对不同数据标签分布情况的客户端的贡献度评估
[0084]
实验过程:设置60个参与训练的客户端,为第一组参与训练的客户端c1分配1000
条独立同分布数据(即1000条数据中包括所有类别对应的数据),为第二组参与训练的客户端c2分配1000条仅有50%数据标签的数据(即1000条数据中仅包括所有类别的一半类别对应的数据,其余类别对应的数据为0条),为第三组参与训练的客户端c3分配1000条仅有20%数据标签的数据(即1000条数据中仅包括所有类别中的四分之一类别对应的数据,其余类别对应的数据为0条),采用本发明方法计算每组中各个客户端的贡献度并进行统计,得到不同数据标签分布情况的客户端的贡献度评估结果,如下表3所示:
[0085]
表3:不同数据标签分布情况的客户端的贡献度评估结果
[0086][0087]
基于上表3中的实验结果,可知本发明对于拥有相似数据数量和数据质量的三组客户端贡献度评估平均值相似,因此对于非独立同分布数据的客户端的贡献度评估较好。且由于非独立同分布程度最高的第三组的客户端每个标签的数据量越大,因此贡献度评估平均值最高;由于非独立同分布程度最低的第二组的客户端拥有更广泛的数据类型,因此贡献度评估平均值次高。
[0088]
最后,基于上述三个方面的实验及实验结果可以知道,本发明采用仿真样本测试各客户端的分类准确率,并结合客户端的分类准确率和客户端的数据标签分布情况计算该客户端的贡献度,缓解了贡献度计算过程中异构数据造成的偏见,且对联邦学习中各客户端的贡献度的估计更加准确,基于准确性高的贡献度对多个客户端训练后的客户端模型进行聚合,得到联合模型,极大地提高了联邦学习效率。
[0089]
需要说明的是,虽然上文按照特定顺序描述了各个步骤,但是并不意味着必须按照上述特定顺序来执行各个步骤,实际上,这些步骤中的一些可以并发执行,甚至改变顺序,只要能够实现所需要的功能即可。
[0090]
本发明可以是系统、方法和/或计算机程序产品。计算机程序产品可以包括计算机可读存储介质,其上载有用于使处理器实现本发明的各个方面的计算机可读程序指令。
[0091]
计算机可读存储介质可以是保持和存储由指令执行设备使用的指令的有形设备。计算机可读存储介质例如可以包括但不限于电存储设备、磁存储设备、光存储设备、电磁存储设备、半导体存储设备或者上述的任意合适的组合。计算机可读存储介质的更具体的例子(非穷举的列表)包括:便携式计算机盘、硬盘、随机存取存储器(ram)、只读存储器(rom)、可擦式可编程只读存储器(eprom或闪存)、静态随机存取存储器(sram)、便携式压缩盘只读存储器(cd-rom)、数字多功能盘(dvd)、记忆棒、软盘、机械编码设备、例如其上存储有指令的打孔卡或凹槽内凸起结构、以及上述的任意合适的组合。
[0092]
以上已经描述了本发明的各实施例,上述说明是示例性的,并非穷尽性的,并且也不限于所披露的各实施例。在不偏离所说明的各实施例的范围和精神的情况下,对于本技术领域的普通技术人员来说许多修改和变更都是显而易见的。本文中所用术语的选择,旨在最好地解释各实施例的原理、实际应用或对市场中的技术改进,或者使本技术领域的其它普通技术人员能理解本文披露的各实施例。
技术特征:
1.一种联邦学习方法,所述方法包括由中心节点将初始化的联合模型分发给多个客户端作为初始的客户端模型,并由中心节点和客户端配合完成多轮联邦训练,得到最终的联合模型,其中,每轮联邦训练包括:由中心节点获取每个客户端上传的当前轮训练后的客户端模型,其中,当前轮训练后的客户端模型是利用客户端的本地训练集以最新获得的客户端模型为基础训练得到的;由中心节点基于各客户端的贡献度对多个客户端当前轮训练后的客户端模型进行聚合,得到当前轮更新后的联合模型,并将该联合模型分发给多个客户端作为各客户端下一轮训练的基础,其中,各客户端的贡献度基于各客户端模型分别对预设的仿真样本的分类准确率确定。2.根据权利要求1所述的方法,其特征在于,所述每个客户端的贡献度根据该客户端的数据标签分布情况和该客户端模型对各类别下仿真样本的分类准确率确定,其中,所述数据标签分布情况指示对应客户端的各类别下的非仿真样本的占比。3.根据权利要求2所述的方法,其特征在于,所述每个客户端的贡献度是该客户端的数据标签分布情况中每个类别下的非仿真样本的占比与该客户端模型对该类别的仿真样本的分类准确率的乘积之和。4.根据权利要求2所述的方法,其特征在于,由中心节点按照以下方式获得当前轮更新后的联合模型:每轮联邦训练前,获取最新更新的每个客户端的贡献度,比较各客户端中最新更新的每个客户端的贡献度和预设阈值的大小,剔除贡献度小于预设阈值所对应的客户端,得到当前轮参与训练的多个客户端;对当前轮参与训练的多个客户端上传的当前轮训练后的客户端模型进行聚合,得到当前轮更新后的联合模型。5.根据权利要求1所述的方法,其特征在于,所述仿真样本按照以下方式生成:获取基于生成对抗方式训练得到的经训练的生成模型;利用经训练的生成模型针对每种类别对应生成多个仿真样本,其中,每种类别下的仿真样本的数据标签向量与为该类别预设的标签向量的距离小于预设阈值,数据标签向量是基于将仿真样本输入当前轮更新后的联合模型中得到的。6.根据权利要求5所述的方法,其特征在于,基于所述生成对抗方式进行一轮或者多轮迭代对抗训练,每轮对抗训练包括:获取对抗生成网络,其包括生成模型和判别模型;获取第一训练集训练判别模型,得到当轮训练的判别模型,所述第一训练集包括多个第一样本和每个第一样本对应的指示其是非仿真样本的置信度标签,单个第一样本为仿真样本或者非仿真样本,该置信度标签基于将第一样本输入当轮更新后的联合模型得到的输出结果确定;将生成的仿真样本输入当轮训练的判别模型,利用判别模型对生成的仿真样本的判别损失更新生成模型的参数。7.根据权利要求6所述的方法,其特征在于,所述第一训练集中每个第一样本的生成方式包括:获取数据集中的非仿真样本或生成模型基于随机数生成的仿真样本;
将非仿真样本或仿真样本输入当轮更新后的联合模型中,得到数据标签向量,计算数据标签向量与为多种类别中每种类别预设的标签向量的距离,得到多个距离,从多个距离中选择与数据标签向量最小的距离;获取预设的超参数,根据预设的超参数以及计算的最小距离确定第一样本的置信度标签。8.根据权利要求6所述的方法,其特征在于,第一样本的置信度标签按照以下方式确定:其中,h表示置信度标签,α为预设的超参数,d表示计算的非仿真样本或仿真样本输入当前轮更新后的联合模型中得到的数据标签向量与为各类别预设的标签向量的最小距离,β表示预设的超参数,γ表示预设的超参数。9.根据权利要求1-8任一项所述的方法,其特征在于,各客户端的贡献度是每隔预设的轮次周期性更新的,或者在所有的联邦训练的轮次中指定的联邦训练的轮次被间隔更新的。10.一种图像分类方法,其特征在于,所述方法包括:获取根据权利要求1-9任一项所述的联邦学习方法得到最终的联合模型,其中,联邦学习方法中,各个客户端的本地训练集中的样本数据为图像数据,标签为对应样本数据对应的类别;基于最终的联合模型对输入的图像数据进行图像分类。11.根据权利要求10所述的方法,其特征在于,样本数据对应的类别为飞机、汽车、鸟、猫、鹿、狗、蛙、马、船或卡车。12.一种用户分类方法,其特征在于,所述方法包括:获取根据权利要求1-9任一项所述的联邦学习方法得到最终的联合模型,其中,联邦学习方法中,各个客户端的本地训练集中的样本数据为用户特征数据,标签为对应样本数据对应的用户类别;基于最终的联合模型对输入的用户特征数据进行用户分类。13.根据权利要求12所述的方法,其特征在于,所述样本数据对应的用户类别为优质用户、良好用户或普通用户。14.一种计算机可读存储介质,其特征在于,其上存储有计算机程序,所述计算机程序可被处理器执行以实现权利要求1至13中任一项所述方法的步骤。15.一种电子设备,其特征在于,包括:一个或多个处理器;以及存储器,其中存储器用于存储可执行指令;所述一个或多个处理器被配置为经由执行所述可执行指令以实现权利要求1至13中任一项所述方法的步骤。
技术总结
本发明实施例提供了一种联邦学习方法,所述方法包括由中心节点将初始化的联合模型分发给多个客户端作为初始的客户端模型,并由中心节点和客户端配合完成多轮联邦训练,得到最终的联合模型,其中,每轮联邦训练包括:由中心节点获取每个客户端上传的当前轮训练后的客户端模型,其中,当前轮训练后的客户端模型是利用客户端的本地训练集以最新获得的客户端模型为基础训练得到的;由中心节点基于各客户端的贡献度对多个客户端当前轮训练后的客户端模型进行聚合,得到当前轮更新后的联合模型,并将该联合模型分发给多个客户端作为各客户端下一轮训练的基础,其中,各客户端的贡献度基于各客户端模型分别对预设的仿真样本的分类准确率确定。分类准确率确定。分类准确率确定。
技术研发人员:史红周 余孙婕 曾辉
受保护的技术使用者:中国科学院计算技术研究所
技术研发日:2023.04.24
技术公布日:2023/8/1
版权声明
本文仅代表作者观点,不代表航空之家立场。
本文系作者授权航家号发表,未经原创作者书面授权,任何单位或个人不得引用、复制、转载、摘编、链接或以其他任何方式复制发表。任何单位或个人在获得书面授权使用航空之家内容时,须注明作者及来源 “航空之家”。如非法使用航空之家的部分或全部内容的,航空之家将依法追究其法律责任。(航空之家官方QQ:2926969996)
飞行汽车 https://www.autovtol.com/
