学生模型训练方法和文本分类系统与流程
未命名
07-20
阅读:133
评论:0
1.本公开涉及一种深度学习领域,尤其涉及一种学生模型训练方法和文本分类系统。
背景技术:
2.在特定自然语言处理任务上要取得高精度预测结果,通常需要使用大量带标签的数据来训练预训练语言模型(plm)。但大量带标签的数据会使得训练成本过高。为此开发出的小样本学习技术可以使预训练语言模型在少量训练样本的条件下进行训练,从而用较低的训练成本实现较高的预测精确度。
3.然而,为了学习到海量语料中的知识,plm的参数规模十分庞大,现有的gpt-3模型的参数量高达175b。这使得plm无法应用在资源受限或延迟敏感的场景中。为此,需要一种改进的、适用于资源受限场景的深度学习语言模型。
技术实现要素:
4.本公开要解决的一个技术问题是提供一种学生模型训练方法和文本分类系统。本发明通过在知识蒸馏时引入域外教师模型,提升了学生模型的蒸馏精度。进一步地,可以根据域内模型的专家评分对域外教师模型的影响程度加以控制。还可以通过额外的伪分类概率向量来进一步缓解小样本场景下由于标签缺乏导致的过拟合。
5.根据本公开的第一个方面,提供了一种学生模型训练方法,包括:向样本添加提示信息和掩码文本占位符以得到经处理的训练样本;使用所述经处理的训练样本微调预训练语言模型plm,得到经提示微调的教师模型;使用有标签的域外训练数据微调所述plm,得到经域外数据微调的教师模型;以及使用所述经处理的训练样本训练学生模型,并且在训练过程中所述学生模型同时学习所述经提示微调的教师模型和所述经域外数据微调的教师模型输出的分类概率向量。
6.可选地,所述方法还包括:使用所述有标签的域外训练数据训练学生模型,其中,在训练过程中所述学生模型同时学习所述经提示微调的教师模型和所述经域外数据微调的教师模型输出的分类概率向量包括:基于所述经提示微调的教师模型和所述经域外数据微调的教师模型针对所述域外训练数据各自输出的预测结果之间的差异,限定训练过程中所述学生模型对所述经域外数据微调的教师模型输出的分类概率向量的学习程度。
7.可选地,使用所述经处理的训练样本训练学生模型包括:获取所述学生模型对所述掩码文本占位符的对应预测结果;以第一损失函数对所述学生模型的网络参数进行调整,所述第一损失函数根据所述掩码文本占位符的对应预测结果与标签是否相同进行损失求取。
8.可选地,在训练过程中所述学生模型同时学习所述经提示微调的教师模型和所述经域外数据微调的教师模型的分类概率向量包括:以第二损失函数对所述学生模型的网络参数进行调整,所述第二损失函数表征所述学生模型输出的分类概率向量与所述经提示微
调的教师模型输出的分类概率向量的相似性;以及以第三损失函数对所述学生模型的网络参数进行调整,所述第三损失函数表征所述学生模型输出的分类概率向量与所述经域外数据微调的教师模型输出的分类概率向量的相似性。
9.可选地,所述第三损失函数表征所述学生模型输出的分类概率向量与所述经域外数据微调的教师模型输出的分类概率向量的经调整的相似性,所述调整系数对应于基于所述经提示微调的教师模型和所述经域外数据微调的教师模型针对所述域外训练数据各自输出的预测结果之间的差异。
10.可选地,在训练过程中所述学生模型同时学习所述经提示微调的教师模型和所述经域外数据微调的教师模型输出的分类概率向量包括:在训练过程中所述学生模型不学习所述经提示微调的教师模型和所述经域外数据微调的教师模型的中间层输出。
11.可选地,所述方法还包括:在训练过程中所述学生模型学习基于标签平滑操作构造的伪概率分布。
12.可选地,在训练过程中所述学生模型学习基于标签平滑操作构造的伪概率分布包括:将伪概率分布转换成伪分类概率向量;以及以第四损失函数对所述学生模型的网络参数进行调整,所述第四损失函数表征经处理的训练样本的真实标签与所述伪分类概率向量之间的差异。
13.可选地,在训练过程中所述学生模型同时学习所述经提示微调的教师模型和所述经域外数据微调的教师模型输出的分类概率向量包括:使用第一损失函数以及表征所述经提示微调的教师模型和所述经域外数据微调的教师模型输出的分类概率向量分别与所述学生模型输出的分类概率向量相似性的损失函数的加权和作为总损失函数,训练所述学生模型。
14.根据本公开的第二个方面,提供了一种文本分类系统,包括:输入获取单元,用于获取来自用户的文本输入;分类判定单元,包括如第一方面所述的方法获取的学生模型,所述学生模型用于基于所述输入文本进行分类;以及操作单元,用于根据分类结果进行操作,所述操作包括如下至少一项:基于所述输入文本的意图分类结果进行反馈;基于所述输入文本的情感倾向分类结果进行统计;以及基于所述输入文本的属性分类进行报告。
15.根据本公开的第三个方面,提供了一种计算设备,包括:处理器;以及存储器,其上存储有可执行代码,当可执行代码被处理器执行时,使处理器执行如上述第一方面所述的方法。
16.根据本公开的第四个方面,提供了一种计算机程序产品,包括可执行代码,当所述可执行代码被电子设备的处理器执行时,使所述处理器执行如上述第一方面所述的方法。
17.根据本公开的第五个方面,提供了一种非暂时性机器可读存储介质,其上存储有可执行代码,当可执行代码被电子设备的处理器执行时,使处理器执行如上述第一方面所述的方法。
18.由此,本发明通过域外教师模型的引入大大增强了小样本场景下的监督,从而提升学生模型的蒸馏精度。可以根据域内教师模型的专家评分控制域外教师模型的影响程度;还可以通过额外的伪分类概率向量来进一步缓解小样本场景下由于标签缺乏导致的过拟合。
附图说明
19.通过结合附图对本公开示例性实施方式进行更详细的描述,本公开的上述以及其它目的、特征和优势将变得更加明显,其中,在本公开示例性实施方式中,相同的参考标号通常代表相同部件。
20.图1示出了评论情绪分析的提示和标签选择的例子。
21.图2示出了根据本发明一个实施例的学生模型训练方法的示意性流程图。
22.图3示出了软硬目标和温度因子调节软目标的例子。
23.图4示出了本发明基于两个教师模型训练学生模型的整体示意图。
24.图5示出了根据本发明一个实施例的文本分类系统的组成示意图。
25.图6示出了根据本发明一实施例可用于实现上述学生模型训练方法的计算设备的结构示意图。
具体实施方式
26.下面将参照附图更详细地描述本公开的优选实施方式。虽然附图中显示了本公开的优选实施方式,然而应该理解,可以以各种形式实现本公开而不应被这里阐述的实施方式所限制。相反,提供这些实施方式是为了使本公开更加透彻和完整,并且能够将本公开的范围完整地传达给本领域的技术人员。
27.大规模的预训练语言模型在nlp(自然语言处理)的各个领域取得了巨大的成功,人们不再从头训练语言模型,而是首先,在大量的通用语料上,通过一些无监督的代理任务得到通用的plm;随后在下游任务中,通用plm在监督数据上微调参数,即可利用通用语料中已有的语言知识实现目标分类功能。在许多实际的语言应用场景中已经广泛采用了这样的两阶段模型范式。
28.小样本学习(few-shot learning)是机器学习的一种范式,目的是在极小训练样本的情况下,仅仅对模型进行少量的微调(finetuning),得到精度较高的模型。是否拥有从少量样本中学习和概括的能力,是将人工智能和人类智能进行区分的明显分界点,因为人类可以仅通过一个或几个示例就可以轻松地建立对新事物的认知,而机器学习算法通常需要成千上万的有标签样本来保证其泛化能力。在机器视觉、自然语言处理等领域中,数据的标注是昂贵的;而在一个新的场景中,标注数据是十分稀缺的。这限制了深度学习算法的应用。小样本学习在机器学习领域具有重大意义和挑战性。在人类的快速学习能力的启发下,人们希望机器学习模型在学习了一定类别的大量数据后,对于新的类别,只需要少量的样本就能快速学习,这就是小样本学习要解决的问题。
29.针对小样本任务的特殊性质,可以将阶段二的下游任务微调重构为“完形填空问题”,即,使用pet(完形填空训练,pattern exploiting training)。
30.从bert开始,在下游任务中,基于提示微调(prompt-based finetuning)预训练语言模型已成为nlp领域的通用做法。而拥有175b参数的gpt-3模型带来了一种将lm用于下游任务的新方法:通过使用自然语言提示信息(prompt)和任务示例(demonstration)作为上下文,gpt-3只需几个样本即可处理很多任务,而不需更新底层模型中的参数。gpt-3庞大的模型规模是其成功的重要因素,而提示信息和务示例的概念也让我们对如何更好地使用语言模型有了新的认识。提示信息是插入到输入样本中的一段文本,因此可以将原始任务将
预测问题转化为mlm(掩码语言模型)问题。例如,假设我们要对影评“no reason to watch(没有理由去看)”进行情感分类,则可以在句子中附加一个提示“it was(它是)”,得到“no reason to watch.it was[mask]”。“[mask]”字符对应预训练模型mlm head(mlm头)的预测输出映射到实际的类别标签。如对于上述例子,若预测“great”的概率较高,则对应“正”类别,若预测“terrible”的概率较高,则对应“负”类别。在plm具有海量语言知识的情况下,plm会有更高的概率判断“[mask]”字符对应的是“terrible”而不是“great”。
[0031]
图1示出了评论情绪分析的提示和标签选择的例子。如图1所示,为了判定句子“wonderful movie in every aspect(各方面都很棒的电影)”的情感类别(例如,是正性的赞赏还是负性的批评),可以直接向输入文本添加提示“it is[mask].”,并确定模型要预测的[mask]可以是对应于正性标签的“good(好的)”,或是对应于负性标签的“terrible(糟糕的)”。换句话说,可以将提示模板构造为“it is+情感属性词汇”格式,并且使用verbalizer(语言表达器)从词汇表中对应于正性和负性情感的词汇中选择两个作为标签,在此例中为“good”和“terrible”。由此,原始的训练样本“wonderful movie in every aspect.”可以被改造为添加了掩码提示的经处理的训练样本:“wonderful movie in every aspect.it is[mask].”,该训练样本随后被送入plm进行训练,例如,根据plm模型预测[mask]是good还是terrible来进行损失函数的求取和基于反向传播的调整。
[0032]
在图1所示的例子中,可以人为选择正负标签,例如,可以选择图示的“good”和“terrible”,也可以选择词汇表(例如,总词表)中其他用于表示情感属性的词,例如“great”和“bad”。另外,在图1的例子中,提示词“it is(它是)”也可由人为设计。
[0033]
另外,虽然在图中示出了英文文本和提示的例子,但是也可以利用提示、掩码和标签进行针对中文的样本构建和后续分类。
[0034]
虽然现有的plm能够基于小样本学习中的提示微调,快速具备目标任务的分类能力,例如,将评论情绪分类为正性或是负性,但由此得到的经微调的plm的参数规模十分庞大,使其无法应用在资源受限或延迟敏感的场景中。
[0035]
在机器学习中,可以利用知识蒸馏(knowledge distillation,kd),将知识从一个大模型转移到一个小模型。虽然大型模型(如非常深的神经网络或许多模型的集合体)比小型模型有更高的知识容量,但这种容量可能没有被充分利用。另一方面,小型模型相比于大型模型更难以训练。知识蒸馏将知识从一个大的模型转移到一个较小的模型,而不损失其有效性。由于小模型的评估成本较低,它们可以被部署在功能较弱的硬件上(如移动设备)。
[0036]
然而,现有的知识蒸馏技术难以应用于小样本学习场景,因为稀少的标注数据会造成学生模型过拟合,并且现有的知识蒸馏方法无法进行基于提示调整目标模型的训练。
[0037]
为此,本发明提出一种基于提示调整plm的小样本知识蒸馏方案。该方案要求学生模型同时从经提示微调的plm的教师模型和经域外数据微调的教师模型两者中学习,由此通过增加域外监督数据的蒸馏途径,缓解小样本场景下学生模型的过拟合问题。进一步,本发明通过实验发现中间层表示的提取在小样本场景会对蒸馏性能产生负面影响。因此本发明摒弃了知识迁移plm的中间层表示这一本领域内的流行做法,而是在极小样本的情况下,通过增加伪分布概率的蒸馏管线来降低学生模型的过拟合问题。
[0038]
在一个实施例中,本发明可以实现为一种学生模型训练方法。图2示出了根据本发明一个实施例的基于预训练语言模型的学生模型训练方法的示意性流程图。
[0039]
在步骤s210,向样本添加提示信息和掩码文本占位符以得到经处理的训练样本。
[0040]
在本发明的学生模型训练中,需要使用一个小样本训练数据集,例如,给定一个n-way-k-shot的训练数据集x。在此,n表示模型可以输出n个分类,k表示每个分类中的样本数。因此n-way-k-shot的训练数据集x中包含n
×
k个样本,并且在小样本训练的情况下,n
×
k的值将会非常小。
[0041]
样本例如可以是n
×
k个带有感情偏好的句子,并且可以如图1所示通过添加对应提示信息的内容“itis”以及掩码文本占位符“[mask]”来构造输入样本,并且基于句子实际包含的感情偏好,为每个样本生成相应的标签(即,真实标签)。
[0042]
在步骤s220,使用所述经处理的训练样本调整预训练语言模型,得到经提示调整的教师模型。在步骤s230,使用有标签的域外训练数据微调预训练语言模型plm,得到经域外数据微调的教师模型。在此应该理解的是,在使用有标签的域外训练数据微调预训练语言模型plm时,同样可以采用构造提示并进行mlm预测的方式进行。即,域外训练数据也可以是向域外样本添加提示信息和掩码文本占位符以得到经处理的训练样本。
[0043]
在步骤s240,使用所述经处理的训练样本训练学生模型,并且在训练过程中所述学生模型对所述经提示调整的教师模型和所述经域外数据微调的教师模型输出的分类概率向量进行学习。
[0044]
在本发明中,除了构造一个n-way-k-shot的训练数据集x之外,还需要给定一个大规模的plm(教师原模型)和另一个较小的plm(学生原模型)。另外,为了训练域外教师模型,还需要一个比训练数据集x更大的域外数据集(其中),作为知识蒸馏任务的辅助数据集。与教师模型相比,学生模型可以具有相似但更少的子结构。例如,教师模型可以具有n
t+1
个transformer结构,学生模型可以具有n
s+1
个transformer结构,其中,n
s+1
<n
t+1
,并且优选地,n
s+1
<<n
t+1
。本领域技术人员应该理解的时,transformer(变换器)本身是利用自注意力机制提高模型训练速度的深度学习模型,现有的预训练语言模型中会包括多个transformer结构。训练的目标是将教师模型在小样本数据上通过提示微调得到的性能,以知识蒸馏的方式压缩到学生模型中。
[0045]
如上给定内容和训练目标可由数学符号描述。具体地,训练数据集x={(xi,yi)}(在此,yi是输入文本xi的分类标签,其中是标签集,和)。使用θ
t
参数表示经提示调整(也可称为经提示微调)的plm。模型θ
t
是从其预训练的初始化θ
t’提示调整得到。换句话说,在此可以使用θ
t’来表示教师原模型的参数。本发明的目标是获得一个由θs参数表示的小得多的plm,同时使得θs的性能可以尽可能地接近θ
t
。
[0046]
为了实现这一目标,在步骤s210中构造了小样本训练数据集x之后,需要在步骤s220,先从大规模的原始plmθ
t’得到经提示调整的plmθ
t
。在此,可以使用掩码语言模型(mlm)任务来获取经提示调整的plmθ
t
。具体地,送入原始plm的训练样本可以具有例如“wonderful movie in every aspect.it is[mask].”的形式,并且需要输出[mask]对应的标签分类。例如,在分类n=2时(此时,[mask]对应的词仅包括一个正性词汇,例如“good”,和一个负性词汇,例如“terrible”),模型会输出[mask]对应的词是“good”还是“terrible”的概率,如果模型判定是“good”的概率大于是“terrible”的概率,则分类结果为good”。
[0047]
在mlm中,对于单词表中的单词,预测目标向量为one-hot向量。one-hot向量又被称为“独热向量”,即,在包含可预测词汇的集合中,只有与分类标签(此例中为“good”)相对
应的系数为1,预测其他词汇的系数都为0。因此在使用one-hot向量来构造损失函数时,只有模型预测到了标签本身,例如模型输出为“good”时,才不会引起损失,而当模型输出了“good”之外的任何其他单词,都会引起相同的损失。
[0048]
可以根据计算得到的损失基于反向传播算法对原始plm的参数θ
t’进行调整,并在输入了n-way-k-shot的训练数据集x之后,得到经提示调整的plmθ
t
。
[0049]
可以与使用小样本训练数据集x训练教师原模型类似的方式,使用小样本训练数据集x训练学生模型。为此,使用所述经处理的训练样本训练学生模型包括:获取所述学生模型对所述掩码文本占位符的对应预测结果;以第一损失函数对所述学生模型的网络参数进行调整,所述第一损失函数根据所述掩码文本占位符的对应预测结果与被掩码词是否相同进行损失求取。换句话说,同样可以使用掩码语言模型(mlm)任务利用one-hot向量来构造第一损失函数。
[0050]
在一个实施例中,可以遵循pet方法。此时,plm可以是pet。设l(y)是类别y的标签词,是使用输入xi和plmθ
t
的情况下在掩码语言标记处预测l(y)的分数。基于θ
t
将xi分配到分类y的概率定义如下:
[0051][0052]
在此,进一步将表示为所有n个类别的概率向量。是xi对应的n维独热(one-hot)真实向量。可以直接如下所示推导出学生模型的分类损失(对应第一损失函数):
[0053][0054]
其中,ce(
·
,
·
)表示两个向量之间的交叉熵损失。
[0055]
由于本发明仅使用有标签的小样本数据集,因此面临缺乏训练数据,且监督信号相当有限的挑战。因此,在本发明中,大胆启用有标签的域外数据来作为知识蒸馏的补充。在此,域外(out-of-domain)数据是与域内(in-domain)数据相对的概念。如果plm是使用数据集a(例如,收集自报纸的文本数据)训练得到的,而数据集b(例如,收集自xx百科上的文本数据)没有参与当前plm的训练。则对于基于当前plm的任何下游任务而言,数据集a就对应于域内数据。数据集b则对应于域外数据。由于不同的数据集通常具有不同的形式和领域,因此本领域技术人员通常不会使用域外数据来辅助当前模型的微调。
[0056]
本发明的发明人则通过创造性地引入域内专业得分(domain expertise score),将域外数据与当前模型的不匹配的影响降至最低。在一个实施例中,本发明可以利用非小样本的域外数据集(其中)用于知识蒸馏。由于域外数据集并非小样本数据集,使得能够相对容易的解决域内小样本导致的过拟合问题。然而,考虑到域外数据集和训练数据集x之间的域间差异可能会导致学生模型从域外数据集获取本不应该被迁移的知识(例如,对当前域内任务无意义的知识),在一个优选实施例中,本发明可以通过域内专业评分(即,用于评估使用训练数据集x得到的经提示微调的模型给出的概率分布与使用域外数据集得到的经域外数据微调的模型给出的概率分布的相似性的指标)来对学生模型对域外数据集的学习加以限制。在一个实施例中,域内专业评分可以对应于所述经
提示微调的教师模型和所述经域外数据微调的教师模型针对域外数据集各自输出的预测结果之间的差异。此时,在训练过程中所述学生模型同时学习所述经提示微调的教师模型和所述经域外数据微调的教师模型输出的分类概率向量可以包括:基于所述经提示微调的教师模型和所述经域外数据微调的教师模型针对所述域外数据训练样本各自输出的预测结果之间的差异,限定训练过程中所述学生模型对所述经域外数据微调的教师模型输出的分类概率向量的学习程度。如下将结合针对损失函数对域内专业评分进行详述。
[0057]
在本发明中,使用经提示微调的plmθ
t
和经域外数据微调的plmθ
ot
两者作为教师模型,用于进行知识蒸馏,即,在使用小样本训练数据集x训练学生模型的过程中,所述学生模型对所述经提示调整的教师模型和所述经域外数据微调的教师模型输出的分类概率向量进行学习。换句话说,知识蒸馏可以通过学生模型对教师模型的输出的分类概率向量进行学习来实现。
[0058]
用于分类的模型最后会设置一个softmax层,其输出值对应了相应类别的概率值。在知识蒸馏时,由于已经有了一个泛化能力较强的教师模型,因此可以直接让学生模型去学习教师模型的泛化能力。一个很直白且高效的迁移泛化能力的方法就是:使用softmax层输出的类别的概率(即,分类概率向量)来作为“软目标(soft-target)”。
[0059]
常规的神经网络训练方法是定义一个损失函数,目标是使预测值尽可能接近于真实值(对应于hard-target,也可称为“硬目标”),损失函数就是使神经网络的损失值和尽可能小。这种训练过程是对真实值(ground truth)求极大似然。在知识蒸馏中,则涉及使用教师模型的类别概率作为训练学生模型的软目标的训练过程。
[0060]
图3示出了软硬目标和温度因子调节软目标的例子。假设图3对应于一个10分类模型的输出。图3左侧对应于硬目标,包括原始数据集标注的one-shot标签,除了第2类正标签为1之外,其他9个类的负标签都是0。图3中间对应于软目标,例如教师模型softmax层输出的类别概率,每个类别都分配了概率,第2类对应的正标签概率最高(接近0.6),但其他9个分类的负标签也有一定概率,例如第3类的概率接近0.2,虽然这些概率都低于正标签的概率。
[0061]
由于softmax层的输出,除了正例之外,负标签也带有教师模型归纳推理的大量信息,比如某些负标签对应的概率远远大于其他负标签(例如,图3中间所示的第3类),则代表教师模型在推理时认为该样本与该负标签有一定的相似性(例如,狗的图像应该更接近猫而不是飞机,因此在图像以最高概率被分类为狗时,该图像在猫分类的概率应该大于在飞机分类的概率),因此知识蒸馏的训练方式使得每个样本给学生模型带来的信息量大于传统的训练方式。换句话说,使用soft-target训练时,学生模型可以快速学习到教师模型的推理过程。
[0062]
为此,在训练过程中所述学生模型对所述经提示调整的教师模型和所述经域外数据微调的教师模型输出的分类概率向量进行学习包括:以第二损失函数对所述学生模型的网络参数进行调整,所述第二损失函数表征所述学生模型输出的分类概率向量与所述经提示调整的教师模型输出的分类概率向量的差异;以及以第三损失函数对所述学生模型的网络参数进行调整,所述第三损失函数表征所述学生模型输出的分类概率向量与所述经域外数据微调的教师模型输出的分类概率向量的差异(在一个实施例中,第三损失函数还应考虑如上所述的“域内专业评分”,如下将详述)。
[0063]
在此,第二损失函数和第三损失函数涉及的分类概率向量可以是各个模型的softmax层输出的分类概率向量,即,模型最后一层输出的分类概率向量(例如,logits。深度学习中,logits对应于最终的全连接层的输出,而非其本意的logit函数。通常神经网络中都是先有logits,而后通过sigmoid函数或者softmax函数得到概率p)。
[0064]
进一步地,经提示调整的教师模型所提供的“软目标”,是针对如上所述的小样本训练数据集x。具体地,可以将来自小样本训练数据集x的同一个样本xi分别送入经提示调整的教师模型和学生模型,经提示调整的教师模型和学生模型计算各自的分类概率向量,并且可以基于两个分类概率向量的交叉熵构造第二损失函数,并以降低两向量之间的交叉熵为该第二损失函数的调整方向。
[0065]
在一个实施例中,可以将针对有标注的小样本训练数据集x的知识蒸馏损失(对应于第二损失函数)定义如下:
[0066][0067]
在此,α》0是温度因子。如前所述,可以使用教师模型softmax层输出的类别概率作为soft-target,帮助学生模型快速学习到教师模型的推理过程。然而,由于softmax函数会把logits数值在各类别之间进行概率归一,并放大logits数值之间的差异,因此当softmax输出的概率分布熵相对较小时,负标签的值都很接近0,对损失函数的贡献非常小。此时,需要温度因子α来放大负标签携带的信息。在整个知识蒸馏过程中,可以先升高温度因子,然后在测试阶段恢复“低温”,这也是“蒸馏”一词的来源。
[0068]
回到图3,其中,图3中间示出了教师模型softmax层输出的类别概率,此时相当于温度因子α=1。而在蒸馏过程中,可以升高温度因子,由此提升其他负标签对应概率的取值。图3右侧示出了温度因子α升高(大于1)时的软目标分别。显然,此时正标签的概率仍然最大,但负标签的概率占比增加。
[0069]
如前所述,本发明大胆启用有标签的域外数据来作为知识蒸馏的补充,并且优选地,通过创造性地引入域内专业得分,将域外数据与当前模型的不匹配的影响降至最低。域内专业得分用于有效地衡量域外实例是否对没有人工标记的kd有用。为了确保模型的同质性,训练了一个基于域外数据集的plm,即,如上所述的基于域外数据微调的plm,其参数化表示为θ
ot
。实例(xi,yi)被同时传递给θ
ot
和θ
t
,得到各自的预测结果和分数si是如下基于两个概率向量(即,实例(xi,yi))之间的jensen-shannon散度(jsd,也可被称为js散度)计算得到的:
[0070][0071]
其中,kld(
·
||
·
)两个概率分布之间的kullback
–
leibler散度(kld,也可被称为kd散度)。基于域内专业评分,可以将作为第三损失函数的域外知识蒸馏损失定义为:
[0072]
[0073]
在此,kl散度用于衡量两个分布之间的差异,等于一个交叉熵减去一个信息熵。kl散度具有非负性和不对称性,由于kl散度的不对称性会在训练过中存在一些问题,因此在kl散度的基础上引入了js散度。js散度是对称的,取值在0到1之间。应该理解的是,输入在此使用了kl散度以及基于kl散度的js散度来度量两个概率向量之间的差异性,但在其他实施例中,也可以利用其他的度量指标。
[0074]
除了mlm head(mlm头)之外,现有技术认为中间层表示也可以为知识蒸馏提供有用线索。但是,本发明的发明人经过大量实验发现,利用中间层进行知识蒸馏在小样本场景下往往对学生模型的性能有着不利的影响。正是基于这一发现,本发明摒弃了利用教师模型中间层信息进行知识蒸馏的常见做法,此时,步骤s240可以包括:在训练过程中所述学生模型不学习所述经提示微调的教师模型和所述经域外数据微调的教师模型的中间层输出。
[0075]
在小样本场景下,需要尽可能多地挖掘模型中的信息。由于本发明不对中间层信息进行学习,因此需要构造更多的蒸馏管道来满足小样本场景下的监督需求。在一个实施例中,本发明还可以通过引入fake logits(伪logits)来提供蒸馏管道。fake logits能够用于知识蒸馏的前提在于可以将标签平滑分布看作是基于软目标进行知识蒸馏的一种特殊情况。为此,本发明的学习模型训练方法还包括:在训练过程中所述学生模型学习基于标签平滑操作构造的伪概率分布。而在训练过程中所述学生模型学习基于标签平滑操作构造的伪概率分布则可以具体包括:将伪概率分布转换成伪分类概率向量;以及以第四损失函数对所述学生模型的网络参数进行调整,所述第四损失函数表征经处理的训练样本的真实标签与所述伪分类概率向量之间的差异。
[0076]
具体地,本发明模仿教师模型的行为生成用于学生模型学习的fake logits。具体地,可以基于标签平滑操作导出伪概率分布其中:
[0077][0078]
其中,m是一个取值接近1的常数。通过设置一个较高的温度(例如,温度因子取值较大),可以通过将转化伪logit向量由此,可以将伪kd损失定义为:
[0079][0080]
在此,cel(
·
,
·
)是在两个向量之间定义的,与logits之间的交叉熵损失。
[0081]
由此,在训练过程中所述学生模型对所述经提示调整的教师模型和所述经域外数据微调的教师模型输出的分类概率向量进行学习可以包括:使用第一损失函数以及表征所述经提示调整的教师模型和所述经域外数据微调的教师模型输出的分类概率向量分别与所述学生模型输出的分类概率向量相似性的损失函数的加权和作为总损失函数,训练所述学生模型。
[0082]
在本发明的优选实施例中,表征学生模型和教师模型输出的分类概率向量相似性的损失函数可以包括如上学生模型基于训练数据集x进行的mlm任务求取的第一损失函数,基于学生模型学习经提示调整的软目标相似性的第二损失函数,基于学生模型学习经域外数据调整的软目标相似性的第三损失函数(此时,使用的是域外数据集并且需要考虑域
内专业分数)。进一步地,还可以利用fake logits的第四损失函数。
[0083]
综合上述的知识蒸馏目标进行加权求和,得到了如下的最终损失函数:
[0084][0085]
其中λ1和λ2是平衡超参数,由此得到蒸馏后的学生模型。另外,应该理解的是,虽然此处为和分配了相同的超参数,但是在其他实施例中,两者的权重可以不同。
[0086]
图4示出了本发明基于两个教师模型训练学生模型的整体示意图。
[0087]
如图所示,学生模型具有与教师模型类似的结构,但具有更少的transformer,也具有更少的transformer编码器层(图示为trm layer)。两个教师模型具有完全相同的网络结构,只是经提示调整的教师模型(图中称为“域内教师模型“)基于n-way-k-shot的训练数据集x的小样本训练调整后,其参数θ
t
相较于原教师模型的θ
t’有所微调;而经域外数据调整的教师模型(图中称为“域外教师模型“)则基于域外数据集(其中)的非小样本训练调整后,其参数θ
ot
相较于原教师模型的θ
t’有所微调。
[0088]
在对学生模型进行训练时,需要构造任务特定的mlm损失,即如上所述的第一损失函数这是基于有标注数据集x训练得到的。
[0089]
而两个教师模型对学生模型的知识蒸馏,首先可以是基于例如softmax层输出的分类概率向量相似度的知识蒸馏。对于经调整的教师模型和学生模型之间的相似度,如上的分类向量相似度的求取可以针对图4右侧上部的有标注数据训练推理得到的,即对应于如上所述的第二损失函数对于经域外数据微调的教师模型和学生模型之间的相似度,如上的分类向量相似度的求取可以针对图4右侧下部的有标注域外数据集训练推理得到的,即对应于如上所述的第三损失函数
[0090]
进一步地,可以利用fake logits来提供蒸馏管道。可以基于标签平滑操作导出伪概率分布并由此构造伪kd作为第四损失函数
[0091]
如上结合图2和图4描述了本发明基于预训练语言模型的学生模型训练方法。在如上方法获取到了学生模型后,由于学生模型的参数量更少,并且经过多管线知识蒸馏学习到了大规模预训练语言模型中所蕴涵的知识,因此适用于在实际应用场景中加以布置。
[0092]
为此,本发明还可以实现为一种文本分类系统。图5示出了根据本发明一个实施例的文本分类系统的组成示意图、
[0093]
如图所示,该系统500可以包括输入获取单元510、分类判定单元520和操作单元530。
[0094]
输入获取单元510用于获取来自用户的文本输入。此处获取的文本输入,可以是用户自行输入的文本,例如,用户发布的影评,也可以是用户输入转换而来的文本,例如用户语音输入的识别结果。
[0095]
分类判定单元520可以包括经如上所述方法获取的学生模型,所述学生模型用于基于所述文本输入进行分类。操作单元530则可用于根据分类结果进行操作。
[0096]
该文本分类系统可以在多种场景中加以应用。例如,在智能机器人交互场景中,可以从文本框中获取用户输入的内容,并实时判定用户输入内容所蕴含的用户意图,并使得操作单元后续能够根据识别出的意图给予合适的文字反馈或是其他操作。再例如,可以在
例如对针对某一艺术作品的海量评论进行读取和分类,从中给出广大用户对该作品整体评价的情感倾向,并且可以以此作为向其他用户进行推荐的依据。另外,还可以对文本本身是否是软广或是不健康信息加以分类,并且在后续操作中进行删除或是报告。
[0097]
为此,操作单元530根据分类结果所进行的操作可以包括如下至少一项:基于所述输入文本的意图分类结果进行反馈;基于所述输入文本的情感倾向分类结果进行统计;以及基于所述输入文本的属性分类进行报告。
[0098]
图6示出了根据本发明一实施例可用于实现上述基于预训练语言模型的学生模型训练方法的计算设备的结构示意图。
[0099]
参见图6,计算设备600包括存储器610和处理器620。
[0100]
处理器620可以是一个多核的处理器,也可以包含多个处理器。在一些实施例中,处理器620可以包含一个通用的主处理器以及一个或多个特殊的协处理器,例如图形处理器(gpu)、数字信号处理器(dsp)等等。在一些实施例中,处理器620可以使用定制的电路实现,例如特定用途集成电路(application specific integrated circuit,asic)或者现场可编程逻辑门阵列(field programmable gate arrays,fpga)。
[0101]
存储器610可以包括各种类型的存储单元,例如系统内存、只读存储器(rom),和永久存储装置。其中,rom可以存储处理器620或者计算机的其他模块需要的静态数据或者指令。永久存储装置可以是可读写的存储装置。永久存储装置可以是即使计算机断电后也不会失去存储的指令和数据的非易失性存储设备。在一些实施方式中,永久性存储装置采用大容量存储装置(例如磁或光盘、闪存)作为永久存储装置。另外一些实施方式中,永久性存储装置可以是可移除的存储设备(例如软盘、光驱)。系统内存可以是可读写存储设备或者易失性可读写存储设备,例如动态随机访问内存。系统内存可以存储一些或者所有处理器在运行时需要的指令和数据。此外,存储器610可以包括任意计算机可读存储媒介的组合,包括各种类型的半导体存储芯片(dram,sram,sdram,闪存,可编程只读存储器),磁盘和/或光盘也可以采用。在一些实施方式中,存储器610可以包括可读和/或写的可移除的存储设备,例如激光唱片(cd)、只读数字多功能光盘(例如dvd-rom,双层dvd-rom)、只读蓝光光盘、超密度光盘、闪存卡(例如sd卡、min sd卡、micro-sd卡等等)、磁性软盘等等。计算机可读存储媒介不包含载波和通过无线或有线传输的瞬间电子信号。
[0102]
存储器610上存储有可执行代码,当可执行代码被处理器620处理时,可以使处理器620执行上文述及的基于预训练语言模型的学生模型训练方法。
[0103]
上文中已经参考附图详细描述了根据本发明的基于预训练语言模型的学生模型训练和文本分类系统。
[0104]
本发明使用基于提示的学习改善大规模plm的小样本学习性能。为了在资源有限的环境中实现plm的在线应用部署,本发明采用知识蒸馏对大规模plm进行压缩。具体地,本发明提出了一种用于提示微调plm的小样本知识蒸馏实现,并且使得学生模型同时从经提示微调的教师模型和经域外数据微调的教师模型中学习。本发明一改现有技术从中间层学习知识的操作(被证明是对小样本学习下学生模型性能有损害),而是改用伪分布概率来提供额外的监督。
[0105]
需要说明的是,本技术所涉及的用户信息(包括但不限于用户设备信息、用户个人信息等)和数据(包括但不限于用于分析的数据、存储的数据、展示的数据等),均为经用户
授权或者经过各方充分授权的信息和数据,并且相关数据的收集、使用和处理需要遵守相关国家和地区的相关法律法规和标准,并提供有相应的操作入口,供用户选择授权或者拒绝。
[0106]
此外,根据本发明的方法还可以实现为一种计算机程序或计算机程序产品,该计算机程序或计算机程序产品包括用于执行本发明的上述方法中限定的上述各步骤的计算机程序代码指令。
[0107]
或者,本发明还可以实施为一种非暂时性机器可读存储介质(或计算机可读存储介质、或机器可读存储介质),其上存储有可执行代码(或计算机程序、或计算机指令代码),当所述可执行代码(或计算机程序、或计算机指令代码)被电子设备(或计算设备、服务器等)的处理器执行时,使所述处理器执行根据本发明的上述方法的各个步骤。
[0108]
本领域技术人员还将明白的是,结合这里的公开所描述的各种示例性逻辑块、模块、电路和算法步骤可以被实现为电子硬件、计算机软件或两者的组合。
[0109]
附图中的流程图和框图显示了根据本发明的多个实施例的系统和方法的可能实现的体系架构、功能和操作。在这点上,流程图或框图中的每个方框可以代表一个模块、程序段或代码的一部分,所述模块、程序段或代码的一部分包含一个或多个用于实现规定的逻辑功能的可执行指令。也应当注意,在有些作为替换的实现中,方框中所标记的功能也可以以不同于附图中所标记的顺序发生。例如,两个连续的方框实际上可以基本并行地执行,它们有时也可以按相反的顺序执行,这依所涉及的功能而定。也要注意的是,框图和/或流程图中的每个方框、以及框图和/或流程图中的方框的组合,可以用执行规定的功能或操作的专用的基于硬件的系统来实现,或者可以用专用硬件与计算机指令的组合来实现。
[0110]
以上已经描述了本发明的各实施例,上述说明是示例性的,并非穷尽性的,并且也不限于所披露的各实施例。在不偏离所说明的各实施例的范围和精神的情况下,对于本技术领域的普通技术人员来说许多修改和变更都是显而易见的。本文中所用术语的选择,旨在最好地解释各实施例的原理、实际应用或对市场中的技术的改进,或者使本技术领域的其它普通技术人员能理解本文披露的各实施例。
技术特征:
1.一种学生模型训练方法,包括:向样本添加提示信息和掩码文本占位符以得到经处理的训练样本;使用所述经处理的训练样本微调预训练语言模型plm,得到经提示微调的教师模型;使用有标签的域外训练数据微调所述plm,得到经域外数据微调的教师模型;以及使用所述经处理的训练样本训练学生模型,并且在训练过程中所述学生模型同时学习所述经提示微调的教师模型和所述经域外数据微调的教师模型输出的分类概率向量。2.如权利要求1所述的方法,还包括:使用所述有标签的域外训练数据训练学生模型,其中,在训练过程中所述学生模型同时学习所述经提示微调的教师模型和所述经域外数据微调的教师模型输出的分类概率向量包括:基于所述经提示微调的教师模型和所述经域外数据微调的教师模型针对所述域外训练数据各自输出的预测结果之间的差异,限定训练过程中所述学生模型对所述经域外数据微调的教师模型输出的分类概率向量的学习程度。3.如权利要求1所述的方法,其中,使用所述经处理的训练样本训练学生模型包括:获取所述学生模型对所述掩码文本占位符的对应预测结果;以第一损失函数对所述学生模型的网络参数进行调整,所述第一损失函数根据所述掩码文本占位符的对应预测结果与标签是否相同进行损失求取。4.如权利要求3所述的方法,其中,在训练过程中所述学生模型同时学习所述经提示微调的教师模型和所述经域外数据微调的教师模型的分类概率向量包括:以第二损失函数对所述学生模型的网络参数进行调整,所述第二损失函数表征所述学生模型输出的分类概率向量与所述经提示微调的教师模型输出的分类概率向量的相似性;以及以第三损失函数对所述学生模型的网络参数进行调整,所述第三损失函数表征所述学生模型输出的分类概率向量与所述经域外数据微调的教师模型输出的分类概率向量的相似性。5.如权利要求4所述的方法,其中,所述第三损失函数表征所述学生模型输出的分类概率向量与所述经域外数据微调的教师模型输出的分类概率向量的经调整的相似性,所述调整系数对应于基于所述经提示微调的教师模型和所述经域外数据微调的教师模型针对所述域外训练数据各自输出的预测结果之间的差异。6.如权利要求1所述的方法,其中,在训练过程中所述学生模型同时学习所述经提示微调的教师模型和所述经域外数据微调的教师模型输出的分类概率向量包括:在训练过程中所述学生模型不学习所述经提示微调的教师模型和所述经域外数据微调的教师模型的中间层输出。7.如权利要求1所述的方法,还包括:在训练过程中所述学生模型学习基于标签平滑操作构造的伪概率分布。8.如权利要求7所述的方法,其中,在训练过程中所述学生模型学习基于标签平滑操作构造的伪概率分布包括:将伪概率分布转换成伪分类概率向量;以及以第四损失函数对所述学生模型的网络参数进行调整,所述第四损失函数表征经处理
的训练样本的真实标签与所述伪分类概率向量之间的差异。9.如权利要求3所述的方法,其中,在训练过程中所述学生模型同时学习所述经提示微调的教师模型和所述经域外数据微调的教师模型输出的分类概率向量包括:使用第一损失函数以及表征所述经提示微调的教师模型和所述经域外数据微调的教师模型输出的分类概率向量分别与所述学生模型输出的分类概率向量相似性的损失函数的加权和作为总损失函数,训练所述学生模型。10.一种文本分类系统,包括:输入获取单元,用于获取来自用户的文本输入;分类判定单元,包括如权利要求1-9中任一项所述的方法获取的学生模型,所述学生模型用于基于所述输入文本进行分类;以及操作单元,用于根据分类结果进行操作,所述操作包括如下至少一项:基于所述输入文本的意图分类结果进行反馈;基于所述输入文本的情感倾向分类结果进行统计;以及基于所述输入文本的属性分类进行报告。11.一种计算设备,包括:处理器;以及存储器,其上存储有可执行代码,当所述可执行代码被所述处理器执行时,使所述处理器执行如权利要求1-9中任一项所述的方法。12.一种计算机程序产品,包括可执行代码,当所述可执行代码被电子设备的处理器执行时,使所述处理器执行如权利要求1-9中任一项所述的方法。13.一种非暂时性机器可读存储介质,其上存储有可执行代码,当所述可执行代码被电子设备的处理器执行时,使所述处理器执行如权利要求1-9中任一项所述的方法。
技术总结
本公开涉及一种学生模型训练方法和文本分类系统。该方法包括:向样本添加提示信息和掩码文本占位符以得到经处理的训练样本;使用经处理的训练样本微调预训练语言模型PLM,得到经提示微调的教师模型;使用有标签的域外训练数据微调PLM,得到经域外数据微调的教师模型;以及使用经处理的训练样本训练学生模型,并且在训练过程中所述学生模型同时学习如上两个教师模型输出的分类概率向量。本发明通过在知识蒸馏时引入域外教师模型,提升了学生模型的蒸馏精度。进一步地,可以根据域内模型的专家评分对域外教师模型的影响程度加以控制。还可以通过额外的伪分类概率向量来进一步缓解小样本场景下由于标签缺乏导致的过拟合。解小样本场景下由于标签缺乏导致的过拟合。解小样本场景下由于标签缺乏导致的过拟合。
技术研发人员:汪诚愚 陈小庆
受保护的技术使用者:阿里巴巴(中国)有限公司
技术研发日:2023.03.07
技术公布日:2023/7/18
版权声明
本文仅代表作者观点,不代表航空之家立场。
本文系作者授权航家号发表,未经原创作者书面授权,任何单位或个人不得引用、复制、转载、摘编、链接或以其他任何方式复制发表。任何单位或个人在获得书面授权使用航空之家内容时,须注明作者及来源 “航空之家”。如非法使用航空之家的部分或全部内容的,航空之家将依法追究其法律责任。(航空之家官方QQ:2926969996)
飞行汽车 https://www.autovtol.com/
