面向增量学习的图像分类方法以及相关设备与流程

未命名 08-15 阅读:109 评论:0


1.本技术涉及人工智能技术领域和数字医疗领域,尤其涉及一种面向增量学习的图像分类方法以及相关设备。


背景技术:

2.随着人工智能的发展,越来越多深度学习技术被应用在各行各业中。对于计算机视觉领域中的图像分类技术,其在实际应用中对数据隐私的保护需求使得原始数据无法共享,多个数据源之间形成“数据孤岛”。基于联邦学习的图像分类因此被提出,通过联邦学习方式训练分类模型,使得每个联邦学习的参与者能够从其他参与者的图像数据中获益,同时能够确保每个参与者的图像数据不离开本地,在保证了各方数据隐私的前提下成功解决了数据孤岛的问题。
3.例如在数字医疗领域中,利用深度学习技术进行医学影像处理是医疗辅助诊断中至关重要的一步,通过联邦学习方法训练基于深度学习的分类模型能够在保护医疗数据隐私的前提下解决数据孤岛问题,提高医学影像分类的精准性。
4.此外,图像分类在实际应用中亦需要模型具备的增量学习能力,也就是在不忘记已经学习到的知识的基础上,不断地通过新类数据学习新的知识。然而,目前大多数基于联邦学习的图像分类方法在增量学习的过程中都会出现灾难性遗忘的问题,即在学习新类别的信息后,全局模型在旧类别的表现大幅度降低。
5.因此,如何提高联邦学习架构下图像分类的增量学习能力成为亟待解决的技术问题。


技术实现要素:

6.本技术实施例的主要目的在于提出一种面向增量学习的图像分类方法、装置、电子设备及计算机可读存储介质,能够提高联邦学习架构下图像分类的增量学习能力,缓解灾难性遗忘现象。
7.为实现上述目的,本技术实施例的第一方面提出了一种面向增量学习的图像分类方法,所述方法包括:
8.获取第一本地图像样本集;
9.根据所述第一本地图像样本集确定目标分类任务类型;
10.获取所述目标分类任务类型对应的多个适配器以及预测输出模块;
11.将多个所述适配器、所述预测输出模块和预训练模型进行结合处理,得到初始图像分类模型;
12.基于所述第一本地图像样本集对所述初始图像分类模型进行训练,得到中间图像分类模型;
13.将所述中间图像分类模型的本地模型参数上传至所述服务器,以使所述服务器对所述本地模型参数进行整合处理,得到全局模型参数;
14.从所述服务器获取所述全局模型参数,并根据所述全局模型参数更新所述中间图像分类模型;
15.将更新后的所述中间图像分类模型作为教师模型对所述初始图像分类模型进行蒸馏处理,得到目标图像分类模型;
16.获取待分类图像;
17.将所述待分类图像输入至所述目标图像分类模型中,以通过所述目标图像分类模型得到所述待分类图像对应的分类预测结果。
18.根据本发明一些实施例提供的面向增量学习的图像分类方法,所述将更新后的所述中间图像分类模型作为教师模型对所述初始图像分类模型进行蒸馏处理,得到目标图像分类模型,包括:
19.获取第二本地图像样本集,所述第二本地图像样本集与所述第一本地图像样本集为相同分类任务类型的旧图像样本;
20.对所述第一本地图像样本集和所述第二本地图像样本集进行融合处理,得到训练样本集;
21.将所述训练样本集分别输入至所述初始图像分类模型,以通过所述初始图像分类模型得到所述训练样本集中图像样本的第一分类预测结果;
22.将所述训练样本集分别输入至更新后的所述中间图像分类模型,以通过所述中间图像分类模型得到所述训练样本集中图像样本的第二分类预测结果;
23.根据所述第一分类预测结果和所述第二分类预测结果确定损失值;
24.基于所述损失值对所述初始图像分类模型中的多个所述适配器和所述预测输出模块的模型参数进行更新,得到目标图像分类模型。
25.根据本发明一些实施例提供的面向增量学习的图像分类方法,在所述获取所述目标分类任务类型对应的多个适配器以及预测输出模块之前,所述方法还包括:
26.构建网络模块池,所述网络模块池包括多个网络模块组,所述网络模块组用于与预训练模型进行结合以得到图像分类模型;
27.其中,每个所述网络模块组分别对应一个分类任务类型,每个所述网络模块组包括多个适配器以及预测输出模块;
28.所述获取所述目标分类任务类型对应的多个适配器以及预测输出模块,包括:
29.根据所述分类任务类型与所述网络模块组的对应关系,从所述网络模块池中获取所述目标分类任务类型对应的多个适配器以及预测输出模块。
30.根据本发明一些实施例提供的面向增量学习的图像分类方法,在所述将更新后的所述中间图像分类模型作为教师模型对所述初始图像分类模型进行蒸馏处理,得到目标图像分类模型之后,所述方法还包括:
31.获取所述目标图像分类模型中多个所述适配器和所述预测输出模块的第一模型参数;
32.根据所述第一模型参数更新所述网络模块池中对应的网络模块组。
33.根据本发明一些实施例提供的面向增量学习的图像分类方法,在获取所述目标图像分类模型中多个所述适配器和所述预测输出模块的第一模型参数之后,所述方法还包括:
34.将所述第一模型参数上传至所述服务器,以使所述服务器对所述第一模型参数进行整合处理,得到第二模型参数;
35.所述根据所述第一模型参数更新所述网络模块池中对应的网络模块组,包括:
36.从所述服务器获取第二模型参数,根据所述第二模型参数更新所述网络模块池中对应的网络模块组。
37.根据本发明一些实施例提供的面向增量学习的图像分类方法,所述根据所述第一本地图像样本集确定目标分类任务类型,包括:
38.获取训练好的第一图像分类模型;
39.从所述第一本地图像样本集中选择预设数量的图像样本作为测试样本集;
40.将所述测试样本集输入至所述第一图像分类模型,以通过所述第一图像分类模型得到所述测试样本集对应的分类预测结果;
41.根据所述分类预测结果,确定所述第一本地图像样本集对应的分类任务类型。
42.根据本发明一些实施例提供的面向增量学习的图像分类方法,所述全局模型参数通过以下公式得到:
[0043][0044]
其中,所述w
t
为全局模型参数,所述n为参加联邦学习的客户端数量,所述nk为第k个客户端上传的本地图像样本数量,所述为第k个客户端上传的所述本地模型参数。
[0045]
为实现上述目的,本技术实施例的第二方面提出了一种面向增量学习的图像分类装置,所述装置包括:
[0046]
第一获取模块,用于获取第一本地图像样本集;
[0047]
分类模块,用于根据所述第一本地图像样本集确定目标分类任务类型;
[0048]
第二获取模块,用于获取所述目标分类任务类型对应的多个适配器以及预测输出模块;
[0049]
模型结合模块,用于将多个所述适配器、所述预测输出模块和预训练模型进行结合处理,得到初始图像分类模型;
[0050]
模型训练模块,用于基于所述第一本地图像样本集对所述初始图像分类模型进行训练,得到中间图像分类模型;
[0051]
参数上传模块,用于将所述中间图像分类模型的本地模型参数上传至所述服务器,以使所述服务器对所述本地模型参数进行整合处理,得到全局模型参数;
[0052]
第三获取模块,用于从所述服务器获取所述全局模型参数,并根据所述全局模型参数更新所述中间图像分类模型;
[0053]
模型蒸馏模块,用于将更新后的所述中间图像分类模型作为教师模型对所述初始图像分类模型进行蒸馏处理,得到目标图像分类模型;
[0054]
第四获取模块,用于获取待分类图像;
[0055]
图像分类模块,用于将所述待分类图像输入至所述目标图像分类模型中,以通过所述目标图像分类模型得到所述待分类图像对应的分类预测结果。
[0056]
为实现上述目的,本技术实施例的第三方面提出了一种电子设备,所述电子设备
包括存储器、处理器、存储在所述存储器上并可在所述处理器上运行的计算机程序,所述计算机程序被所述处理器执行时实现上述第一方面所述的方法。
[0057]
为实现上述目的,本技术实施例的第四方面提出了一种存储介质,所述存储介质为计算机可读存储介质,用于计算机可读存储,所述存储介质存储有一个或者多个计算机程序,所述一个或者多个计算机程序可被一个或者多个处理器执行,以实现上述第一方面所述的方法。
[0058]
本技术提出一种面向增量学习的图像分类方法、装置、电子设备以及计算机可读存储介质,所述方法首先确定第一本地图像样本集的分类任务类型,并根据目标分类任务类型获取对应的多个适配器以及预测输出模块,之后将多个适配器和预测输出模块和预训练模型进行结合处理,得到初始图像分类模型。基于第一本地图像样本集对初始图像分类模型进行训练,得到中间图像分类模型,之后将中间图像分类模型的本地模型参数上传至服务器,以使服务器对本地模型参数进行整合处理,得到全局模型参数,并从服务器获取全局模型参数,根据全局模型参数更新中间图像分类模型,将更新后的中间图像分类模型作为教师模型对初始图像分类模型进行蒸馏处理,得到目标图像分类模型,最后利用目标图像分类模型进行图像分类,既能在联邦学习框架下充分学习新类数据,并利用学习得到新模型进行图像分类,又能够缓解在学习新类数据时造成的灾难性遗忘,提高联邦学习架构下图像分类的增量学习能力。
[0059]
本技术实施例提出的一种面向增量学习的图像分类方法可以应用于数字医疗领域如基于人工智能技术的医疗辅助诊断系统中,该医疗辅助诊断系统包括医疗终端设备和医疗云服务器,在医疗应用场景中,首先由医疗终端设备确定第一本地医学图像样本集的目标分类任务类型,并根据目标分类任务类型获取对应的多个适配器以及预测输出模块,之后将多个适配器和预测输出模块和预训练模型进行结合处理,得到初始图像分类模型,再基于第一本地医学图像样本集对初始图像分类模型进行训练,将中间图像分类模型的本地模型参数上传至医疗辅助诊断系统中的医疗云服务器,由医疗云服务器对多个医疗终端设备上传的本地模型参数进行整合处理得到全局模型参数,之后医疗终端设备从医疗云服务器获取全局模型参数并根据全局模型参数更新中间图像分类模型,将更新后的中间图像分裂模型作为教师模型对初始图像分类模型进行蒸馏处理得到目标图像分类模型,最后医疗终端设备利用目标图像分类模型对医学图像进行分类处理,能够在保护医学数据隐私的前提下充分学习新类数据,且通过知识蒸馏方法缓解在学习新类数据时造成的灾难性遗忘,提供医疗终端设备面对图像分类时的增量学习能力,提高医学图像分类的准确性。
附图说明
[0060]
图1是本技术实施例提供的一种面向增量学习的图像分类方法的流程示意图;
[0061]
图2是图1中步骤s170的子步骤流程示意图;
[0062]
图3是图1中步骤s120的子步骤流程示意图;
[0063]
图4是本技术另一实施例提供的一种面向增量学习的图像分类方法的流程示意图;
[0064]
图5是本技术另一实施例提供的一种面向增量学习的图像分类方法的流程示意图;
[0065]
图6是图1中步骤s110的子步骤流程示意图;
[0066]
图7是本技术实施例提供的一种联邦学习框架的结构示意图;
[0067]
图8是本技术另一实施例提供的一种面向增量学习的图像分类方法的流程示意图;
[0068]
图9是本技术另一实施例提供的一种面向增量学习的图像分类装置的结构示意图;
[0069]
图10是本技术实施例提供的一种电子设备的硬件结构示意图。
具体实施方式
[0070]
为了使本技术的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本技术进行进一步详细说明。应当理解,此处所描述的具体实施例仅用以解释本技术,并不用于限定本技术。
[0071]
需要说明的是,除非另有定义,本文所使用的所有的技术和科学术语与属于本技术的技术领域的技术人员通常理解的含义相同。本文中所使用的术语只是为了描述本技术实施例的目的,不是旨在限制本技术。
[0072]
随着人工智能的发展,越来越多深度学习技术被应用在各行各业中。对于计算机视觉领域中的图像分类技术,其在实际应用中对数据隐私的保护需求使得原始数据无法共享,多个数据源之间形成“数据孤岛”。基于联邦学习的图像分类因此被提出,通过联邦学习方式训练分类模型,使得每个联邦学习的参与者能够从其他参与者的图像数据中获益,同时能够确保每个参与者的图像数据不离开本地,在保证了各方数据隐私的前提下成功解决了数据孤岛的问题。
[0073]
例如在数字医疗领域中,利用深度学习技术进行医学影像处理是医疗辅助诊断中至关重要的一步,通过联邦学习方法训练基于深度学习的分类模型能够在保护医疗数据隐私的前提下解决数据孤岛问题,提高医学影像分类的精准性。
[0074]
此外,图像分类在实际应用中亦需要模型具备的增量学习能力,也就是在不忘记已经学习到的知识的基础上,不断地通过新类数据学习新的知识。然而,目前大多数基于联邦学习的图像分类方法在增量学习的过程中都会出现灾难性遗忘的问题,即在学习新类别的信息后,全局模型在旧类别的表现大幅度降低。
[0075]
因此,如何提高联邦学习架构下图像分类的增量学习能力成为亟待解决的技术问题。
[0076]
基于此,本技术实施例提供了一种面向增量学习的图像分类方法、装置、电子设备及计算机可读存储介质,能够提高联邦学习架构下图像分类的增量学习能力,缓解灾难性遗忘现象。
[0077]
本技术实施例提供的一种面向增量学习的图像分类方法、装置、电子设备及计算机可读存储介质,具体通过如下实施例进行说明,首先描述本技术实施例中的面向增量学习的图像分类方法。
[0078]
本技术实施例可以基于人工智能技术对相关的数据进行获取和处理。其中,人工智能(artificial intelligence,ai)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及
应用系统。
[0079]
人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、机器人技术、生物识别技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。
[0080]
本技术实施例提供的面向增量学习的图像分类方法可应用于终端中,也可应用于服务器端中,还可以是运行于终端或服务器端中的软件。在一些实施例中,终端可以是智能手机、平板电脑、笔记本电脑、台式计算机等;服务器端可以配置成独立的物理服务器,也可以配置成多个物理服务器构成的服务器集群或者分布式系统,还可以配置成提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、cdn以及大数据和人工智能平台等基础云计算服务的云服务器;软件可以是实现面向增量学习的图像分类方法的应用等,但并不局限于以上形式。
[0081]
本技术可用于众多通用或专用的计算机系统环境或配置中。例如:个人计算机、服务器计算机、手持设备或便携式设备、平板型设备、多处理器系统、基于微处理器的系统、置顶盒、可编程的消费电子设备、网络pc、小型计算机、大型计算机、包括以上任何系统或设备的分布式计算环境等等。本技术可以在由计算机执行的计算机可执行指令的一般上下文中描述,例如程序模块。一般地,程序模块包括执行特定任务或实现特定抽象数据类型的例程、程序、对象、组件、数据结构等等。也可以在分布式计算环境中实践本技术,在这些分布式计算环境中,由通过通信网络而被连接的远程处理设备来执行任务。在分布式计算环境中,程序模块可以位于包括存储设备在内的本地和远程计算机存储介质中。
[0082]
首先描述本技术实施例中的联邦学习框架,参见图7,图7是本技术实施例提供的一种联邦学习框架的结构示意图,如图7所示,该联邦学习框架中包括服务器以及多个客户端,服务器与多个客户端之间通信连接。
[0083]
请参见图1,图1示出了本技术实施例提供的一种面向增量学习的图像分类方法的流程示意图。如图1所示,该面向增量学习的图像分类方法应用于参与联邦学习的客户端,该方法包括但不限于步骤s100至s190:
[0084]
步骤s100,获取第一本地图像样本集;
[0085]
步骤s110,根据该第一本地图像样本集确定目标分类任务类型;
[0086]
步骤s120,获取该目标分类任务类型对应的多个适配器以及预测输出模块;
[0087]
步骤s130,将多个该适配器、该预测输出模块和预训练模型进行结合处理,得到初始图像分类模型;
[0088]
步骤s140,基于该第一本地图像样本集对该初始图像分类模型进行训练,得到中间图像分类模型;
[0089]
步骤s150,将该中间图像分类模型的本地模型参数上传至该服务器,以使该服务器对该本地模型参数进行整合处理,得到全局模型参数;
[0090]
步骤s160,从该服务器获取该全局模型参数,并根据该全局模型参数更新该中间图像分类模型;
[0091]
步骤s170,将更新后的该中间图像分类模型作为教师模型对该初始图像分类模型进行蒸馏处理,得到目标图像分类模型;
[0092]
步骤s180,获取待分类图像;
[0093]
步骤s190,将该待分类图像输入至该目标图像分类模型中,以通过该目标图像分类模型得到该待分类图像对应的分类预测结果。
[0094]
可以理解的是,参与联邦学习的客户端在学习新任务时,通过获取第一本地图像样本集,并根据第一本地图像样本集,也就是根据新类数据确定对应的目标分类任务类型。
[0095]
在一种可能的实现方式中,本技术实施例提供的面向增量学习的图像分类方法应用于数字医疗领域如基于人工智能技术的医疗辅助诊断系统中,该医疗辅助诊断系统包括多个医疗终端设备和医疗云服务器,其中,医疗终端设备和医疗云服务器通过有线或无线通信连接,且共同参与联邦学习,医疗终端设备执行该面向增量学习的图像分类方法,其中,在医学应用场景中,样本图像为医学影像,样本图像包含的对象所属类型为病灶,即机体上发生病变的部分。医学影像是指为了医疗或医学研究,以非侵入方式取得的内部组织,例如,胃部、腹部、心脏、膝盖、脑部的影像,比如,ct(computed tomography,电子计算机断层扫描)、mri(magnetic resonance imaging,磁共振成像)、us(ultrasonic,超声)、x光图像、脑电图以及光学摄影等由医学仪器生成的图像。
[0096]
在一些实施例中,参见图6,图6是图1中步骤s110的子步骤流程示意图,如图6所示,该根据该第一本地图像样本集确定目标分类任务类型,包括:
[0097]
步骤s610,获取训练好的第一图像分类模型;
[0098]
步骤s620,从该第一本地图像样本集中选择预设数量的图像样本作为测试样本集;
[0099]
步骤s630,将该测试样本集输入至该第一图像分类模型,以通过该第一图像分类模型得到该测试样本集对应的分类预测结果;
[0100]
步骤s640,根据该分类预测结果,确定该第一本地图像样本集对应的分类任务类型。
[0101]
在一个具体实施例中,第一图像分类模型为efficientnet模型。通过获取训练好的efficientnet模型,从第一本地图像样本集周昂选择预设数量的图像样本作为测试样本集,之后将测试样本集输入至efficientnet模型,已通过该模型得到该厕所样本集对应的分类预测结果,之后根据该分类预测结果确定该第一本地图像样本集对应的分类任务类型。也就是说,当有新类需要学习时,首先通过efficientnet对图像样本集进行粗分类,确定该新类所属的分类任务类型。
[0102]
在目标分类任务类型之后,获取目标分类任务类型对应的多个适配器和预测输出模块。
[0103]
在一些实施例中,在该获取目标分类任务类型对应的多个适配器以及预测输出模块之前,该方法还包括:
[0104]
构建网络模块池,该网络模块池包括多个网络模块组,该网络模块组用于与预训练模型进行结合以得到图像分类模型;
[0105]
其中,每个该网络模块组分别对应一个分类任务类型,每个该网络模块组包括多个适配器以及预测输出模块;
[0106]
可以理解的是,在获取目标分类任务类型对应的多个适配器以及预测输出模块之前,预先构建网络模块池,其中,参见图8,该网络模块池包括多个网络模块组,而网络模块
组包括多个适配器(adapter)以及预测输出模块(prediction head),同时,每个网络模块组各自都对应有分类任务类型。
[0107]
请参见图3,图3是图1中步骤s120的子步骤流程示意图,如图3所示,获取该目标分类任务类型对应的多个适配器以及预测输出模块,包括但不限于:
[0108]
步骤s310,根据该分类任务类型与该网络模块组的对应关系,从该网络模块池中获取该目标分类任务类型对应的多个适配器以及预测输出模块。
[0109]
可以理解的是,由于每个网络模块组各自对应有一个分类任务类型,因此,在确定目标分类任务类型之后,根据预设的分类任务类型和网络模块组之间的对应关系,可以从网络模块池中获取该目标分类任务类型对应的网络模块组,也就是确定目标分类任务类型对应的多个适配器以及预测输出模块。
[0110]
应能理解的是,本技术通过将分类任务划分成多个小类,而每个小类在网络模块池中都有其对应的多个适配器以及预测输出模块。
[0111]
在一个具体实施例中,该预训练模型包括多个transformer block,如图8所示,在确定目标分类任务类型之后,获取目标分类任务类型对应的多个适配器以及预测输出模块,之后同预训练好的transformer blocks相结合,得到初始图像分类模型。
[0112]
应能理解的是,通过在原始的预训练模型中的每个transformer block中加入一些参数可训练的模块,也就是加入adapter模块。因此,在针对不同的下游任务进行调整时,可以在固定预训练参数的情况下,只针对adapter模块的参数进行训练,能够在使用大规模模型的情况下,大幅度降低联邦学习的通信成本。
[0113]
通过将多个适配器、预测输出模块和预训练模型进行结合处理得到初始图像分类模型之后,基于第一本地图像样本集对应该初始图像分类模型进行训练,得到中间图像分类模型。其中,该初始图像分类模型的训练过程可以包括以下步骤:
[0114]
将该第一本地图像样本集输入至初始图像分类模型中,以通过初始图像分类模型获取第一本地图像样本集中图像样本对应的预测结果,根据该预测结果、图像真实分类以及预设的分类损失函数确定损失值,并基于该损失值对初始图像分类模型进行训练,也就是更新该初始图像分类模型的模型参数,直至该损失值满足预设阈值,得到中间图像分类模型。
[0115]
应能理解的是,在得到中间图像分类模型之后,将中间图像分类模型的本地模型参数上传至服务器,以通过服务器对本地模型参数进行整个处理,得到全局模型参数。
[0116]
需要说明的是,该中间图像分类模型的本地模型参数指的是适配器的参数,本技术的图像分类模型使用预训练模型+适配器的模型结构,在将多个适配器、预测输出模块以及预训练模型进行结合得到初始图像分类模型之后,对初始图像分类模型进行训练,也就是在固定预训练模型的模型参数的情况下,只针对适配器的参数进行训练。因此在模型参数的整合过程中,客户端只需将中间图像分类模型中适配器的参数上传至服务器,由服务器对多个客户端上传的适配器的参数进行整合处理,得到全局模型参数。
[0117]
在一些实施例中,该全局模型参数通过以下公式得到:
[0118][0119]
其中,该w
t
为全局模型参数,该n为参加联邦学习的客户端数量,该nk为第k个客户
端上传的本地图像样本数量,该为第k个客户端上传的该本地模型参数。
[0120]
举例来说,如图7所示,参与联邦学习的n个客户端上传各自的本地模型参数以及本地图像样本数量,服务器根据该上述公式进行计算,得到本次联邦学习的全局模型参数。
[0121]
在另一些实施例中,在将该中间图像分类模型的本地模型参数上传至服务器之后,该方法还包括:
[0122]
获取服务器发送的响应信息,
[0123]
根据该响应信息返回基于第一本地图像样本集对初始图像分类模型进行训练这一步骤。
[0124]
可以理解的是,服务器与本地客户端同样预设有网络模块池,而服务器在对多个客户端上传的本地模型参数进行整合得到全局模型参数之后,基于该全局模型参数更新服务器中由多个适配器、预测输出模块以及预训练模型组成的图像分类模型,并基于该图像分类模型进行测试;若测试结果满足预设条件,则将本次整合得到的模型参数视为最终的全局模型参数,从而使得客户端从服务器获取该全局模型参数,并根据该全局模型参数更新中间图像分类模型;或者,若测试结果未满足预设条件,则向参与本次联邦学习的客户端发送响应信息,以使客户端根据该响应信息继续训练初始图像分类模型,直至该测试结果满足预设条件。
[0125]
通过整合客户端上传的本地模型参数,并对基于该全局模型参数更新的图像分类模型进行测试,根据测试结果确定是否继续对初始本地图像分类模型进行训练,能够有效避免各个客户端之间数据不独立同分布,而导致模型性能下降的问题。
[0126]
在一些实施例中,参见图2,图2是图1中步骤s170的子步骤流程图,如图2所示,该将更新后的该中间图像分类模型作为教师模型对该初始图像分类模型进行蒸馏处理,得到目标图像分类模型,包括:
[0127]
步骤s210,获取第二本地图像样本集,该第二本地图像样本集与该第一本地图像样本集为相同分类任务类型的旧图像样本;
[0128]
步骤s220,对该第一本地图像样本集和该第二本地图像样本集进行融合处理,得到训练样本集;
[0129]
步骤s230,将该训练样本集分别输入至该初始图像分类模型,以通过该初始图像分类模型得到该训练样本集中图像样本的第一分类预测结果;
[0130]
步骤s240,将该训练样本集分别输入至更新后的该中间图像分类模型,以通过该中间图像分类模型得到该训练样本集中图像样本的第二分类预测结果;
[0131]
步骤s250,根据该第一分类预测结果和该第二分类预测结果确定损失值;
[0132]
步骤s260,基于该损失值对该初始图像分类模型中的多个该适配器和该预测输出模块的模型参数进行更新,得到目标图像分类模型。
[0133]
应能理解的是,将第一本地图像样本集和第二本地图像样本集进行合并得到训练样本集,也就是将新旧任务的样本数据进行融合,之后利用该训练样本集对初始图像分类模型和中间图像分类模型进行知识蒸馏,其中,中间图像分类模型作为教师模型,而初始图像分类模型作为学生模型。
[0134]
通过对进行新类数据学习后的中间图像分类模型进行知识蒸馏,以更新初始图像分类模型的模型参数,使得客户端中的本地模型既能充分学习新类数据并以此进行图像分
类,同时又能够缓解在学习新类数据时造成的灾难性遗忘。
[0135]
在一些实施例中,参见图4,图4是本技术另一实施例提供的一种面向增量学习的图像分类方法,如图4所示,在该将更新后的该中间图像分类模型作为教师模型对该初始图像分类模型进行蒸馏处理,得到目标图像分类模型之后,该方法还包括:
[0136]
步骤s410,获取该目标图像分类模型中多个该适配器和该预测输出模块的第一模型参数;
[0137]
步骤s420,根据该第一模型参数更新该网络模块池中对应的网络模块组。
[0138]
应能理解的是,在对中间图像分类模型和初始图像分类模型进行知识蒸馏,以更新初始图像分类模型的模型参数之后,获取该目标图像分类模型中多个适配器和预测输出模型的参数,以更新网络模块池中对应的网络模块组,也就是替换目标分类任务类型对应的网络模块组中适配器以及预测输出模块的参数,从而使得客户端在下次样本学习中能够从网络模块池中获取经过多次学习的网络模块组,实现客户端的联邦增量学习。
[0139]
在一些实施例中,参见图5,图5是本技术另一实施例提供的一种面向增量学习的图像分类方法,如图5所示,在获取该目标图像分类模型中多个该适配器和该预测输出模块的第一模型参数之后,该方法还包括:
[0140]
步骤s510,将该第一模型参数上传至该服务器,以使该服务器对该第一模型参数进行整合处理,得到第二模型参数;
[0141]
该根据该第一模型参数更新该网络模块池中对应的网络模块组,包括:
[0142]
步骤s520,从该服务器获取第二模型参数,根据该第二模型参数更新该网络模块池中对应的网络模块组。
[0143]
应能理解的是,对于目标图像分类模型中多个适配器和预测输出模块的第一模型参数,在根据第一模型参数更新网络模块池中对应的网络模块组之间,可以将第一模型参数上传至服务器,以通过服务器对多个客户端上传的第一模型参数进行整个处理,得到第二模型参数,从而使得客户端根据第二模型参数更新网络模块池中对应的网络模块组,能够进有效避免各个客户端之间数据不独立同分布的问题。
[0144]
还可以理解的是,由多个适配器、预测输出模块以及预训练模型组成的图像分类模型还包括嵌入块(embedding block),通过嵌入块对模型输入进行特征提取。
[0145]
下面通过一个具体实施例描述本技术实施例提供的面向增量学习的图像分类方法:
[0146]
如图7和8所示,该方法应用于参与联邦学习的客户端,该客户端与服务器通信连接。在客户端中预先构建有网络模块池,其中,网络模块池包括n个网络模块组,而每个网络模块组包括有多个适配器以及预测输出模块,且每个网络模块组分别对应有分类任务中的小类。
[0147]
客户端获取第一本地图像样本集并以此学习新知识,客户端首先利用预设的efficientnet模型对第一本地图像样本集进行粗分类,以确定第一本地图像样本集对应的目标分类任务类型,之后基于预设的网络模块组和分类任务类型之间的对应关系,确定目标分类任务类型对应的多个适配器以及预测输出模块,从而客户端将多个适配器以及预测输出模块与预训练好的transformer block相结合,得到初始图像分类模型。
[0148]
继而在初始图像分类模型上进行联邦学习得到中间图像分类模型,再将中间图像
分类模型作为教师模型对初始图像分类模型进行蒸馏处理,得到目标图像分类模型。最后基于目标图像分类模型进行分类预测。同时,获取目标图像分类模型中多个适配器以及预测输出模块的第一模型参数,根据该第一模型参数更新网络模块池中对应的网络模块组,从而实现在联邦学习框架下充分学习新类数据,利用学习得到的目标图像分类模型进行分类预测,同时能够缓解在学习新类数据时造成的灾难性遗忘,并且可以有效避免各个客户端之间数据不独立同分布,而造成模型性能下降的问题。
[0149]
下面通过具体实施例描述本技术实施例的应用场景:
[0150]
本技术实施例提供的面向增量学习的图像分类方法应用于数字医疗领域如基于人工智能技术的医疗辅助诊断系统中,该医疗辅助诊断系统包括多个医疗终端设备、医疗云服务器和数据库服务器,其中,医疗终端设备和医疗云服务器通过有线或无线通信连接,且共同参与联邦学习,医疗终端设备执行该面向增量学习的图像分类方法,数据库服务器中预先构建有网络模块池,其中,网络模块池包括多个网络模块组,而每个网络模块组包括有多个适配器以及预测输出模块,且每个网络模块组对应着分类任务中的小类。
[0151]
医疗终端设备获取第一本地医学图像集并以此学习新知识,医疗终端设备首先利用预设的efficientnet模型对第一本地医学图像集进行粗分类,以确定第一本地医学图像集对应的目标分类任务类型,之后医疗终端设备从数据库服务器中预先构建的网络模块组确定目标分类任务类型对应的多个适配器以及预测输出模块,之后医疗终端设备将多个适配器以及预测输出模块与预训练好的transformer block相结合,得到初始图像分类模型。
[0152]
医疗终端设备继而在初始图像分类模型的本地模型参数上传至医疗云服务器,由医疗云服务器对多个医疗终端设备上传的本地模型参数进行整合处理得到全局模型参数,之后医疗终端设备从医疗云服务器获取全局模型参数并根据全局模型参数更新中间图像分类模型,最后将更新后的中间图像分类模型作为教师模型对作为学生模型的初始图像分类模型进行蒸馏处理,得到目标图像分类模型,最后医疗终端设备利用目标图像分类模型对医学图像进行图像分类处理,能够在保护医学数据隐私的前提下充分学习新类数据,且通过知识蒸馏方法缓解在学习新类数据时造成的灾难性遗忘,提供医疗终端设备面对图像分类时的增量学习能力,提高医学图像分类的准确性。
[0153]
本技术提出一种面向增量学习的图像分类方法,该方法首先确定第一本地图像样本集的分类任务类型,并根据目标分类任务类型获取对应的多个适配器以及预测输出模块,之后将多个适配器和预测输出模块和预训练模型进行结合处理,得到初始图像分类模型。基于第一本地图像样本集对初始图像分类模型进行训练,得到中间图像分类模型,之后将中间图像分类模型的本地模型参数上传至服务器,以使服务器对本地模型参数进行整合处理,得到全局模型参数,并从服务器获取全局模型参数,根据全局模型参数更新中间图像分类模型,将更新后的中间图像分类模型作为教师模型对初始图像分类模型进行蒸馏处理,得到目标图像分类模型,最后利用目标图像分类模型进行图像分类,既能在联邦学习框架下充分学习新类数据,并利用学习得到新模型进行图像分类,又能够缓解在学习新类数据时造成的灾难性遗忘,提高联邦学习架构下图像分类的增量学习能力。
[0154]
请参见图9,本技术实施例还提供了一种面向增量学习的图像分类装置,所述面向增量学习的图像分类装置包括:
[0155]
第一获取模块100,用于获取第一本地图像样本集;
[0156]
分类模块110,用于根据该第一本地图像样本集确定目标分类任务类型;
[0157]
第二获取模块120,用于获取该目标分类任务类型对应的多个适配器以及预测输出模块;
[0158]
模型结合模块130,用于将多个该适配器、该预测输出模块和预训练模型进行结合处理,得到初始图像分类模型;
[0159]
模型训练模块140,用于基于该第一本地图像样本集对该初始图像分类模型进行训练,得到中间图像分类模型;
[0160]
参数上传模块150,用于将该中间图像分类模型的本地模型参数上传至该服务器,以使该服务器对该本地模型参数进行整合处理,得到全局模型参数;
[0161]
第三获取模块160,用于从该服务器获取该全局模型参数,并根据该全局模型参数更新该中间图像分类模型;
[0162]
模型蒸馏模块170,用于将更新后的该中间图像分类模型作为教师模型对该初始图像分类模型进行蒸馏处理,得到目标图像分类模型;
[0163]
第四获取模块180,用于获取待分类图像;
[0164]
图像分类模块190,用于将该待分类图像输入至该目标图像分类模型中,以通过该目标图像分类模型得到该待分类图像对应的分类预测结果。
[0165]
本技术提出一种面向增量学习的图像分类装置,该装置首先确定第一本地图像样本集的分类任务类型,并根据目标分类任务类型获取对应的多个适配器以及预测输出模块,之后将多个适配器和预测输出模块和预训练模型进行结合处理,得到初始图像分类模型。基于第一本地图像样本集对初始图像分类模型进行训练,得到中间图像分类模型,之后将中间图像分类模型的本地模型参数上传至服务器,以使服务器对本地模型参数进行整合处理,得到全局模型参数,并从服务器获取全局模型参数,根据全局模型参数更新中间图像分类模型,将更新后的中间图像分类模型作为教师模型对初始图像分类模型进行蒸馏处理,得到目标图像分类模型,最后利用目标图像分类模型进行图像分类,既能在联邦学习框架下充分学习新类数据,并利用学习得到新模型进行图像分类,又能够缓解在学习新类数据时造成的灾难性遗忘,提高联邦学习架构下图像分类的增量学习能力。
[0166]
需要说明的是,上述装置的模块之间的信息交互、执行过程等内容,由于与本技术方法实施例基于同一构思,其具体功能及带来的技术效果,具体可参见方法实施例部分,此处不再赘述。
[0167]
请参见图10,图10示出本技术实施例提供的一种电子设备的硬件结构,电子设备包括:
[0168]
处理器210,可以采用通用的cpu(central processing unit,中央处理器)、微处理器、应用专用集合成电路(application specific integrated circuit,asic)、或者一个或多个集合成电路等方式实现,用于执行相关计算机程序,以实现本技术实施例所提供的技术方案;
[0169]
存储器220,可以采用只读存储器(read only memory,rom)、静态存储设备、动态存储设备或者随机存取存储器(random access memory,ram)等形式实现。存储器220可以存储操作系统和其他应用程序,在通过软件或者固件来实现本说明书实施例所提供的技术方案时,相关的程序代码保存在存储器220中,并由处理器210来调用执行本技术实施例的
面向增量学习的图像分类方法;
[0170]
输入/输出接口230,用于实现信息输入及输出;
[0171]
通信接口240,用于实现本设备与其他设备的通信交互,可以通过有线方式(例如usb、网线等)实现通信,也可以通过无线方式(例如移动网络、wifi、蓝牙等)实现通信;和总线250,在设备的每个组件(例如处理器210、存储器220、输入/输出接口230和通信接口240)之间传输信息;
[0172]
其中处理器210、存储器220、输入/输出接口230和通信接口240通过总线250实现彼此之间在设备内部的通信连接。
[0173]
本技术实施例还提供了一种存储介质,存储介质为计算机可读存储介质,用于计算机可读存储,存储介质存储有一个或者多个计算机程序,一个或者多个计算机程序可被一个或者多个处理器执行,以实现上述面向增量学习的图像分类方法。
[0174]
存储器作为一种计算机可读存储介质,可用于存储软件程序以及计算机可执行程序。此外,存储器可以包括高速随机存取存储器,还可以包括非暂态存储器,例如至少一个磁盘存储器件、闪存器件、或其他非暂态固态存储器件。在一些实施方式中,存储器可选包括相对于处理器远程设置的存储器,这些远程存储器可以通过网络连接至该处理器。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
[0175]
本技术实施例描述的实施例是为了更加清楚的说明本技术实施例的技术方案,并不构成对于本技术实施例提供的技术方案的限定,本领域技术人员可知,随着技术的演变和新应用场景的出现,本技术实施例提供的技术方案对于类似的技术问题,同样适用。
[0176]
以上所描述的装置实施例仅仅是示意性的,其中作为分离部件说明的单元可以是或者也可以不是物理上分开的,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。
[0177]
本领域普通技术人员可以理解,上文中所公开方法中的全部或某些步骤、系统、设备中的功能模块/单元可以被实施为软件、固件、硬件及其适当的组合。
[0178]
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。
[0179]
本技术的说明书及上述附图中的术语“第一”、“第二”、“第三”、“第四”等(如果存在)是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本技术的实施例能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
[0180]
应当理解,在本技术中,“至少一个(项)”是指一个或者多个,“多个”是指两个或两个以上。“和/或”,用于描述关联对象的关联关系,表示可以存在三种关系,例如,“a和/或b”可以表示:只存在a,只存在b以及同时存在a和b三种情况,其中a,b可以是单数或者复数。字符“/”一般表示前后关联对象是一种“或”的关系。“以下至少一项(个)”或其类似表达,是指这些项中的任意组合,包括单项(个)或复数项(个)的任意组合。例如,a,b或c中的至少一项(个),可以表示:a,b,c,“a和b”,“a和c”,“b和c”,或“a和b和c”,其中a,b,c可以是单个,也可
以是多个。
[0181]
在本技术所提供的几个实施例中,应该理解到,所揭露的装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,上述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集合成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
[0182]
上述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
[0183]
另外,在本技术每个实施例中的各功能单元可以集合成在一个处理单元中,也可以是每个单元单独物理存在,也可以两个或两个以上单元集合成在一个单元中。上述集合成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
[0184]
集合成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本技术的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括多指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本技术每个实施例的方法的全部或部分步骤。而前述的存储介质包括:u盘、移动硬盘、只读存储器(read-only memory,简称rom)、随机存取存储器(random access memory,简称ram)、磁碟或者光盘等各种可以存储程序的介质。
[0185]
以上参照附图说明了本技术实施例的优选实施例,并非因此局限本技术实施例的权利范围。本领域技术人员不脱离本技术实施例的范围和实质内所作的任何修改、等同替换和改进,均应在本技术实施例的权利范围之内。

技术特征:
1.一种面向增量学习的图像分类方法,其特征在于,所述方法应用于参与联邦学习的客户端,所述客户端与服务器通信连接,其特征在于,所述方法包括:获取第一本地图像样本集;根据所述第一本地图像样本集确定目标分类任务类型;获取所述目标分类任务类型对应的多个适配器以及预测输出模块;将多个所述适配器、所述预测输出模块和预训练模型进行结合处理,得到初始图像分类模型;基于所述第一本地图像样本集对所述初始图像分类模型进行训练,得到中间图像分类模型;将所述中间图像分类模型的本地模型参数上传至所述服务器,以使所述服务器对所述本地模型参数进行整合处理,得到全局模型参数;从所述服务器获取所述全局模型参数,并根据所述全局模型参数更新所述中间图像分类模型;将更新后的所述中间图像分类模型作为教师模型对所述初始图像分类模型进行蒸馏处理,得到目标图像分类模型;获取待分类图像;将所述待分类图像输入至所述目标图像分类模型中,以通过所述目标图像分类模型得到所述待分类图像对应的分类预测结果。2.根据权利要求1所述的方法,其特征在于,所述将更新后的所述中间图像分类模型作为教师模型对所述初始图像分类模型进行蒸馏处理,得到目标图像分类模型,包括:获取第二本地图像样本集,所述第二本地图像样本集与所述第一本地图像样本集为相同分类任务类型的旧图像样本;对所述第一本地图像样本集和所述第二本地图像样本集进行融合处理,得到训练样本集;将所述训练样本集分别输入至所述初始图像分类模型,以通过所述初始图像分类模型得到所述训练样本集中图像样本的第一分类预测结果;将所述训练样本集分别输入至更新后的所述中间图像分类模型,以通过所述中间图像分类模型得到所述训练样本集中图像样本的第二分类预测结果;根据所述第一分类预测结果和所述第二分类预测结果确定损失值;基于所述损失值对所述初始图像分类模型中的多个所述适配器和所述预测输出模块的模型参数进行更新,得到目标图像分类模型。3.根据权利要求1所述的方法,其特征在于,在所述获取所述目标分类任务类型对应的多个适配器以及预测输出模块之前,所述方法还包括:构建网络模块池,所述网络模块池包括多个网络模块组,所述网络模块组用于与预训练模型进行结合以得到图像分类模型;其中,每个所述网络模块组分别对应一个分类任务类型,每个所述网络模块组包括多个适配器以及预测输出模块;所述获取所述目标分类任务类型对应的多个适配器以及预测输出模块,包括:根据所述分类任务类型与所述网络模块组的对应关系,从所述网络模块池中获取所述
目标分类任务类型对应的多个适配器以及预测输出模块。4.根据权利要求3所述的方法,其特征在于,在所述将更新后的所述中间图像分类模型作为教师模型对所述初始图像分类模型进行蒸馏处理,得到目标图像分类模型之后,所述方法还包括:获取所述目标图像分类模型中多个所述适配器和所述预测输出模块的第一模型参数;根据所述第一模型参数更新所述网络模块池中对应的网络模块组。5.根据权利要求4所述的方法,其特征在于,在获取所述目标图像分类模型中多个所述适配器和所述预测输出模块的第一模型参数之后,所述方法还包括:将所述第一模型参数上传至所述服务器,以使所述服务器对所述第一模型参数进行整合处理,得到第二模型参数;所述根据所述第一模型参数更新所述网络模块池中对应的网络模块组,包括:从所述服务器获取第二模型参数,根据所述第二模型参数更新所述网络模块池中对应的网络模块组。6.根据权利要求1所述的方法,其特征在于,所述根据所述第一本地图像样本集确定目标分类任务类型,包括:获取训练好的第一图像分类模型;从所述第一本地图像样本集中选择预设数量的图像样本作为测试样本集;将所述测试样本集输入至所述第一图像分类模型,以通过所述第一图像分类模型得到所述测试样本集对应的分类预测结果;根据所述分类预测结果,确定所述第一本地图像样本集对应的分类任务类型。7.根据权利要求1所述的方法,其特征在于,所述全局模型参数通过以下公式得到:其中,所述w
t
为全局模型参数,所述n为参加联邦学习的客户端数量,所述n
k
为第k个客户端上传的本地图像样本数量,所述为第k个客户端上传的所述本地模型参数。8.一种面向增量学习的图像分类装置,应用于参与联邦学习的客户端,其特征在于,所述装置包括:第一获取模块,用于获取第一本地图像样本集;分类模块,用于根据所述第一本地图像样本集确定目标分类任务类型;第二获取模块,用于获取所述目标分类任务类型对应的多个适配器以及预测输出模块;模型结合模块,用于将多个所述适配器、所述预测输出模块和预训练模型进行结合处理,得到初始图像分类模型;模型训练模块,用于基于所述第一本地图像样本集对所述初始图像分类模型进行训练,得到中间图像分类模型;参数上传模块,用于将所述中间图像分类模型的本地模型参数上传至所述服务器,以使所述服务器对所述本地模型参数进行整合处理,得到全局模型参数;第三获取模块,用于从所述服务器获取所述全局模型参数,并根据所述全局模型参数
更新所述中间图像分类模型;模型蒸馏模块,用于将更新后的所述中间图像分类模型作为教师模型对所述初始图像分类模型进行蒸馏处理,得到目标图像分类模型;第四获取模块,用于获取待分类图像;图像分类模块,用于将所述待分类图像输入至所述目标图像分类模型中,以通过所述目标图像分类模型得到所述待分类图像对应的分类预测结果。9.一种电子设备,其特征在于,包括:至少一个处理器;以及,与所述至少一个处理器通信连接的存储器;其中,所述存储器存储有计算机程序,所述计算机程序被所述至少一个处理器执行,以使所述至少一个处理器能够执行如权利要求1至7中任一项所述的面向增量学习的图像分类方法。10.一种计算机可读存储介质,存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至7中任一项所述的面向增量学习的图像分类方法。

技术总结
本申请涉及人工智能领域以及数字医疗领域,提供了一种面向增量学习的图像分类方法以及相关设备,该方法通过确定第一本地图像样本集的分类任务类型,并根据目标分类任务类型获取对应的多个适配器以及预测输出模块,之后将多个适配器、预测输出模块和预训练模型进行结合以得到初始图像分类模型,之后基于初始图像分类模型进行联邦学习得到中间图像分类模型,并将中间图像分类模型作为教师模型对初始图像分类模型进行蒸馏处理,得到目标图像分类模型,最后利用目标图像分类模型进行分类预测,既能在联邦学习框架下充分学习新类数据,又能够缓解在学习新类数据时造成的灾难性遗忘,提高联邦学习架构下图像分类的增量学习能力。高联邦学习架构下图像分类的增量学习能力。高联邦学习架构下图像分类的增量学习能力。


技术研发人员:瞿晓阳 王健宗 王亮
受保护的技术使用者:平安科技(深圳)有限公司
技术研发日:2023.05.31
技术公布日:2023/8/14
版权声明

本文仅代表作者观点,不代表航空之家立场。
本文系作者授权航家号发表,未经原创作者书面授权,任何单位或个人不得引用、复制、转载、摘编、链接或以其他任何方式复制发表。任何单位或个人在获得书面授权使用航空之家内容时,须注明作者及来源 “航空之家”。如非法使用航空之家的部分或全部内容的,航空之家将依法追究其法律责任。(航空之家官方QQ:2926969996)

飞行汽车 https://www.autovtol.com/

分享:

扫一扫在手机阅读、分享本文

相关推荐