基于知识蒸馏恢复策略剪枝的长短期记忆压缩方法
未命名
10-18
阅读:119
评论:0
1.本发明涉及人工智能技术领域,尤其涉及一种基于知识蒸馏恢复策略剪枝的长短期记忆压缩方法。
背景技术:
2.深度神经网络压缩和加速是针对资源受限环境下的深度神经网络模型进行优化的一系列技术和方法。这些技术的目标是减小模型大小、加快推理速度,并降低计算资源的消耗,从而使得模型能够高效地部署在移动设备、嵌入式系统和边缘设备等场景中。其中,权重剪枝是一种深度神经网络压缩的方法,它通过去除模型中冗余的、贡献较小的连接或参数来减小模型的大小。剪枝可以基于权重敏感性或梯度敏感性等标准进行,在训练前、训练中或训练后进行。
3.粗粒度剪枝和细粒度剪枝是深度神经网络压缩中常用的两种剪枝方法。粗粒度剪枝是一种相对较粗的剪枝方法,它通常在网络的层级或模块级别进行参数修剪。具体来说,该方法会选择整个层或模块中的一部分参数进行修剪,而不是单独剪枝每个参数。这样的剪枝方式使得整个层或模块中的一部分连接被去除,从而减小了网络的规模。由于粗粒度剪枝选择的是整个层或模块的一部分参数,因此它可能会导致一些信息损失,同时也可能降低模型的精度。但它的优点在于其相对简单和快速的操作,适用于硬件加速和高效推理;细粒度剪枝是一种更细致的剪枝方法,它会针对每个参数进行选择性修剪。具体来说,该方法会根据参数的敏感度或重要性,选择性地修剪网络中的一些参数,而保留其他参数。
4.细粒度剪枝相对于粗粒度剪枝来说,更加精细和精确,因为它可以更好地保留重要的连接和参数,减少信息损失,同时还有可能更好地保持模型的精度。然而,使用粗粒度剪枝容易导致模型精度下降,特别是在需要高压缩率的情况下。即使对于压缩效果较好的细粒度剪枝,当修剪参数的比例过大时,模型的精度仍然可能会降低到不足的水平。
5.知识蒸馏是另一种用于深度神经网络压缩的方法,通常通过将教师网络的输出(logits)作为模型内部隐藏的“暗知识”传递给较小的学生网络,从而使学生网络能够逼近教师网络的性能。相比于直接使用one-hot标签进行训练,知识蒸馏策略在提高小型网络准确率方面表现更优。通过知识蒸馏,学生网络可以获得比直接使用one-hot标签训练更多的信息,从而在相对较小的网络结构下,实现更好的准确率。这使得知识蒸馏成为一种有力的神经网络压缩方法,特别适用于资源受限的环境,如移动设备和嵌入式系统等。
6.通常情况下,过参数化的原始模型具有较强的学习和表达能力。然而,在进行剪枝操作后,网络规模变小或受到一定的约束,自身的学习能力可能无法获得复杂的表征,因此即使通过微调也难以完全恢复到原始模型的精度。传统的迭代剪枝方法中,修剪后的模型会经过微调来恢复模型的精度。然而,对于过度细粒度剪枝模型或粗粒度剪枝模型,微调可能会面临一定的困难,难以完全恢复其精度。
技术实现要素:
7.为解决现有技术存在的局限和缺陷,本发明提供一种基于知识蒸馏恢复策略剪枝的长短期记忆压缩方法,包括:
8.步骤s1、根据得到的数据集训练长短期记忆模型,获得具有预设的泛化能力的原始模型,保存所述原始模型;
9.步骤s2、设置剪枝参数,所述剪枝参数包括权重剪枝方法、稀疏度的初始值、稀疏度的期望值;
10.步骤s3、根据所述权重剪枝方法评估连接或权重块的重要性,排序后根据所述稀疏度确定修剪比例,根据所述修剪比例将对应的参数置零,同时禁止已经置零的参数进行更新,得到剪枝模型;
11.步骤s4、使用知识蒸馏方法对所述剪枝模型进行训练,将所述原始模型作为教师,将所述剪枝模型作为学生,通过在损失函数中加入蒸馏损失,使得学生模型拟合教师模型的logits输出,迭代训练预设的次数之后,得到精度恢复的模型;
12.步骤s5、评估所述精度恢复的模型的精度,调整所述稀疏度,根据预设的精度损失范围增减所述稀疏度,返回步骤s3继续剪枝,直至达到所述稀疏度的期望值或满足预设的终止条件。
13.可选的,还包括:
14.获取在预设的任务上进行预设的微调的bert模型;
15.将所述bert模型作为教师,将所述剪枝模型作为学生,使用知识蒸馏方法对所述剪枝模型进行训练。
16.可选的,还包括:
17.使用均方误差损失直接比较logits输出的结果差异,以计算所述蒸馏损失,所述蒸馏损失的表达式如下:
[0018][0019]
其中,z
t
为所述教师模型的logits输出,zs为所述学生模型的logits输出,n是预测类别的个数。
[0020]
可选的,还包括:
[0021]
使用输出概率分布与真实标签之间的交叉熵损失作为目标函数的一部分,最终的损失函数的表达式如下:
[0022][0023]
其中,ys是所述学生模型预测的输出概率分布,由logits经过softmax函数得到;当样本来自原始的标记数据集时,t为标记的真值标签;当样本来自数据增强生成的数据集时,将bert模型的预测结果作为真值标签;α为权重超参数。
[0024]
本发明具有下述有益效果:
[0025]
本发明在知识蒸馏的实现中,除了引入蒸馏损失用于学生模型拟合教师模型的logits输出,还使用输出概率分布与真实标签的交叉熵损失,以确保学生模型的输出与样本的真实标签相互匹配。这两部分损失共同构成目标函数,帮助学生模型从教师模型的“暗知识”中进行学习,优化模型的输出概率分布,从而提高剪枝模型的准确率。另外,本发明将知识蒸馏应用于lstm模型的剪枝过程中,通过合理传递知识,使得剪枝后的模型具备更强的表征能力。
附图说明
[0026]
图1为本发明实施例一提供的基于原始模型的压缩方法的流程图。
[0027]
图2为本发明实施例一提供的基于bert模型的压缩方法的流程图。
[0028]
图3为本发明实施例一提供的知识蒸馏架构示意图。
[0029]
图4为本发明实施例一提供的针对单句文本分类任务的bilstm网络示意图。
[0030]
图5为本发明实施例一提供的针对句子对匹配任务的bilstm网络示意图。
[0031]
图6为本发明实施例一提供的在句子对匹配任务微调bert-base的流程图。
[0032]
图7为本发明实施例一提供的压缩模型与原始模型在单句分类任务上功耗和推理时间上的比较示意图。
具体实施方式
[0033]
为使本领域的技术人员更好地理解本发明的技术方案,下面结合附图对本发明提供的基于知识蒸馏恢复策略剪枝的长短期记忆压缩方法进行详细描述。
[0034]
实施例一
[0035]
本实施例旨在提升长短期记忆(long short-term memory,lstm)剪枝后模型的精度,为此引入了知识蒸馏的方法。通常情况下,过参数化的原始模型具有较强的学习和表达能力。然而,在进行剪枝操作后,网络规模变小或受到一定的约束,自身的学习能力可能无法获得复杂的表征,因此即使通过微调也难以完全恢复到原始模型的精度。知识蒸馏是解决这一问题的有效方法,它从大模型(例如过参数化的原始lstm网络或泛化能力更好的bert网络)中提取监督信息,并将这些知识指导剪枝后模型的训练。这样可以提升剪枝模型的表征能力,使其更好地恢复精度。具体做法是将lstm剪枝流程中的微调过程替换为知识蒸馏,通过这种方式将过参数化网络中的知识传递到剪枝后的网络中。教师网络可以是剪枝前的过度参数化的原始lstm网络,也可以是泛化能力更好的bert网络。通过知识蒸馏,剪枝后的模型能够从教师网络中获取更丰富的信息,从而提高其精度,并达到更好的压缩效果。本实施例的创新之处在于将知识蒸馏应用于lstm的剪枝过程中,通过合理传递知识,使得剪枝后的模型能够具备更强的表征能力。这为在资源受限环境下部署高性能lstm模型提供了新的解决方案。
[0036]
本实施例提供一种基于知识蒸馏的剪枝模型恢复策略,用于在剪枝过程中更好地恢复模型精度,达到更好的压缩效果。本策略旨在解决剪枝过程中剪去大量参数导致泛化能力差的问题,通过知识蒸馏方法,将教师模型的知识引入到剪枝后的学生模型中,从而提高剪枝模型的准确率。
[0037]
传统的微调方法在剪枝模型上进行训练时,通常使用原始训练数据集的one-hot
标签,这种方式在剪枝后的小模型上可能无法充分利用教师模型的知识,导致恢复精度较低。而本实施例提出的基于知识蒸馏的剪枝模型恢复策略,通过教师模型输出的“暗知识”作为监督信号,避免了one-hot标签训练,使剪枝模型能够更好地从泛化能力更强的教师模型中学习,从而显著提高剪枝模型的准确率。
[0038]
本实施例在剪枝过程中引入知识蒸馏的步骤,使剪枝后的模型能够获得来自教师模型的重要信息,有效地弥补了剪枝过程中可能造成的信息损失。通过细致地传递教师模型的知识,剪枝模型的表征能力得到提升,从而在资源受限的情况下,实现更好的压缩效果。通过选择不同来源的教师模型,将策略分为基于原始模型的精度恢复策略与基于bert模型的精度恢复策略。图1展示了基于原始模型的精度恢复策略与迭代剪枝方法结合后的压缩方法。整体压缩方法与迭代剪枝类似,主要区别在于采用知识蒸馏方法在再训练步骤中恢复剪枝模型精度,而不是传统的微调方法。具体步骤如下:
[0039]
1、训练原始lstm模型:在给定数据集上训练lstm,获得一个具有强大泛化能力的原始模型,保存之。
[0040]
2、设置剪枝参数:包括权重剪枝方法、稀疏度(剪枝率)初值和期望的最终稀疏度。
[0041]
3、剪枝:根据修剪方法评估连接或权重块的重要性,排序后根据稀疏度确定修剪比例,将不重要的参数置零并禁止其更新,得到剪枝模型。
[0042]
4、知识蒸馏:使用知识蒸馏方法再训练剪枝模型。将第一步得到的原始模型作为教师,剪枝模型作为学生,通过在损失函数中加入蒸馏损失,使得学生模型能够拟合教师模型的logits输出。迭代训练若干次,得到精度恢复的模型。
[0043]
5、判断是否继续剪枝:评估模型精度并调整稀疏度。根据预期的精度损失范围,增减稀疏度并返回第3步继续剪枝,直至达到预期的最终稀疏度或满足终止条件。
[0044]
在基于原始模型的压缩方法中,教师模型和剪枝模型都是由同一原始模型生成的,因此它们具有高度相似性和兼容性,使得剪枝模型较容易模仿教师模型的logits输出。
[0045]
图2显示了基于bert模型的精度恢复策略与迭代剪枝方法结合后的压缩方法。与基于原始模型的压缩方法类似,区别在于教师不再是原始模型,而是在特定任务上微调的bert模型。微调的bert模型具有更优秀的泛化能力和学习能力,学生模型能够从中学习更丰富的知识,从而更好地恢复被剪枝模型的精度。为了促进bert网络到lstm网络的有效知识转移,本专利采用数据增强手段生成未标记数据集,并利用教师模型为这些样本提供预测标签。在训练过程中,学生模型不仅从原始数据集学习,还从教师模型处理的增强样本的输出中学习。涵盖计算机视觉和自然语言处理任务中的常用数据增强手段,确保有效知识的传递。
[0046]
本实施例采用的知识蒸馏架构如图3所示,包括学生模型和教师模型。学生模型是被剪枝后的较小、简单的lstm稀疏模型,而教师模型可以是剪枝前的原始lstm模型,也可以是经过微调得到的大型、复杂的bert模型。图3显示了本实施例使用的知识蒸馏架构,学生模型和教师模型之间的蒸馏损失是实现剪枝模型恢复精度的关键部分。通过在损失函数中加入蒸馏损失,学生模型被迫拟合教师模型的输出,从而能够更好地恢复精度。蒸馏损失的设计和计算能够帮助学生模型从教师模型的“暗知识”中学习,使得剪枝模型能够接近教师模型的性能,从而提高剪枝模型的准确率。在知识蒸馏的实现细节中,本实施例对剪枝lstm模型的损失函数进行修改,增加一个蒸馏损失项,用于惩罚学生模型输出与教师模型产生
的软标签之间的差异。具体而言,本实施例采用均方误差损失来直接比较logits结果的差异,以实现蒸馏损失的计算,蒸馏损失如下所示:
[0047][0048]
其中,z
t
为教师模型的logits输出,zs为学生模型的logits输出,n是预测类别的个数。
[0049]
除了让学生模型拟合教师模型的logits外,本实施例还需要确保学生模型的输出与样本对应的真实标签匹配。这部分的损失与传统的模型训练时使用的损失函数是一样的。具体而言,本实施例采用的是输出概率分布与真实标签之间的交叉熵损失作为目标函数的一部分,最终的损失函数如下:
[0050][0051]
其中,ys是学生的预测的输出概率分布,由logits经过softmax函数得到。当样本来自原始的标记数据集进行蒸馏时,t为标记的真值标签。当样本来自数据增强生成的数据集时,可以把bert模型预测得到的结果当成真值标签。α为权重超参数。综合起来,本实施例在知识蒸馏的实现中,除了引入蒸馏损失用于让学生模型拟合教师模型的logits,还采用输出概率分布与真实标签的交叉熵损失,以确保学生模型的输出与样本的真实标签相匹配。这两部分损失共同构成目标函数,帮助学生模型从教师模型的“暗知识”中学习,并优化模型的输出概率分布,从而提高剪枝模型的准确率。
[0052]
目前主流的剪枝方法通常采用微调来恢复剪枝后模型的精度。然而,对于更适合推理加速的粗粒度剪枝,微调往往无法有效地恢复模型精度。另一方面,知识蒸馏不仅可以作为一种压缩方法,还可以视为一种小网络训练方法。目前对于lstm权重剪枝与知识蒸馏训练相结合的研究还较为缺乏。为了实现lstm模型的压缩,本实施例引入了基于知识蒸馏的剪枝模型精度恢复策略,旨在更好地恢复模型精度。
[0053]
在此基础上,本实施例提出了两种压缩方法:基于原始模型的压缩方法和基于bert模型的压缩方法。通过将基于知识蒸馏的剪枝模型精度恢复策略、粗粒度剪枝和量化操作相结合,实现了压缩模型在不损失模型精度的情况下,获得相较于原始稠密模型更高的推理速度和能效。
[0054]
在实验过程中,本实施例综合了基于知识蒸馏的剪枝模型精度恢复策略、粗粒度剪枝和量化操作,通过将这些技术相互结合,使得压缩后的模型在保持高精度的同时,显著提高了推理速度和能效。这些创新性的技术综合应用,为模型压缩和加速提供了一种有效的解决方案。因此,本实施例在lstm模型压缩领域具有重要的应用价值,并具备专利申请的条件。
[0055]
本实施例选择在单句文本分类和句子对语义匹配任务上进行评估。在单句文本分类任务中,模型需要将单句输入文本分成不同的类别,这种任务常常被应用在垃圾邮件识别、新闻主题分类等场景中。而在句子对语义匹配任务中,模型需要判断两个文本之间的语
义是否相同,这种任务多用于信息检索、问答系统、对话系统等领域。通过在这两种任务上的评估,本实施例能够全面地验证知识蒸馏对剪枝模型的精度恢复策略的有效性和适用性。同时,这些任务在自然语言处理领域中具有重要的应用价值,因此对于剪枝模型在实际应用中的性能提升具有一定的指导意义。
[0056]
首先是剪枝模型精度的评价指标,在分类任务中,标签可以分为正例和负例。输入样本后,分类算法的判断结果也只有两种,预测为正例,预测为负例。将样本对应的真实标签和预测结果结合起来,会出现下表1四种情况:
[0057]
表1模型预测的四种情况
[0058][0059]
其中,tp和tn两种情况为判断正确的情况,fp和fn为判断错误的情况。统计tp、tn、fp、fn的数量,可以得到下面模型精度的三个评估指标:
[0060]
正确率(accuracy,acc),指的是分类正确的个数占总样本个数的比例。正确率是针对所有样本的统计量,衡量的是分类器整体的性能。acc的计算如下所示:
[0061][0062]
精确率(precision),指的是分类正确的正例个数占分类器判定为正例个数的比例。精确率是针对部分样本的统计量,侧重对分类器判定为正例的数据的统计。精确率的计算如下所示:
[0063][0064]
召回率(recall),指分类正确的正例个数占真正的正例个数的比例。召回率也是对部分样本的统计量,侧重对真实的正例样本的统计。召回率的计算如下所示:
[0065][0066]
除了正确率,还可以使用f1分数衡量模型精度,f1是精确率和召回率的调和平均数,既可以兼顾精确率又可以兼顾召回率。f1分数越高说明精确率和召回率达到了一个很高的平衡点。f1分数的计算如下所示:
[0067][0068]
稀疏度是0值在剪枝后的权重矩阵中所占的比例。稀疏度相同情况下,剪枝模型的精度越高,压缩方法的压缩效果越好。
[0069]
由于要评估压缩前后模型的性能和能效,还需要测量以下指标:
[0070]
速度:可以通过比较速度来比较性能。执行同一个预测任务时,可以测量模型的推理时间来比较速度,推理时间越短,速度越快,模型性能越好。
[0071]
功耗:执行预测任务时的功耗大小,单位为瓦特(watt,w)。本章实验使用的硬件平台是nvidia的gpu,本实施例使用nvidia-smi记录模型推理时的gpu功耗。
[0072]
能效比:能效指的是完成一次预测任务所需的能量,由硬件功耗乘以推理时间得到。执行同一个预测任务时,可以计算速度和功耗的比值得到能效比。能效比越高,说明模型的计算效率越高。
[0073]
对于单句文本分类任务,选择glue(the general language understanding evaluation)的sst-2(the stanford sentiment treebank)数据集。对于句子对匹配任务,本实施例选择glue的qqp(the quora question pairs)数据集。glue是纽约大学、华盛顿大学等机构创建的一个多任务的自然语言理解基准和分析平台。
[0074]
sst-2语料来自斯坦福大学收集整理的电影评论。其中训练集8551个,开发集1043个,测试集1063个。样本由电影评论中的句子和它们情感的人类注释构成。样本评论分别被标注为正面情感(样本标签对应1),和负面情感(样本标签对应为0)。sst-2对应的是一个单句文本二分类任务,即输入一段单句文本,判断文本在情感上是好评还是差评。
[0075]
qqp属于相似性和释义任务,语料来自与社区问答网站quora中的问题对。其中训练集363870个,开发集40431个,测试集390965个。每个样例是两句话,中间用“[tab]”隔开。语义等效的句子对的正样本(标签为1);语义不等效的句子对为负样本(标签为0)。qqp是正负样本不均衡的,其中负样本占63%,正样本是37%。qqp对应的任务是句子对文本匹配任务,即输入一对文本,判断文本对在语义上是否等效。
[0076]
sst-2任务使用acc作为评估指标,qqp使用acc和f1作为评估指标。
[0077]
bilstm在长文本处理中能够更好的提取文本的上下文特征,是一种经典的lstm网络,本实施例使用bilstm并训练针对单句文本分类和句子对匹配的模型。针对单句文本分类任务的bilstm网络如图4所示。网络由嵌入层、bilstm层、最大池化层、全连接层、softmax函数组成。具体的推理过程如下,对于一个t个单词的句子{w
t
}
t=1,2...,t
,首先嵌入层会处理得到词向量序列{x
t
}
t=1,2,...,t
。bilstm层处理词向量序列得到一组特征向量{h
t
}
t=1,2,...,t
。h
t
是由正向传递的隐藏状态和反向传递的拼接得到的向量。计算过程如下:
[0078][0079][0080][0081]
为了降低特征向量的维度,本实施例选择使用最大池化层提取该组特征向量每行的最大值,组合成句子的表示向量v。表示向量v馈送到全连接层后得到logits,logits经过softmax函数计算得到预测类别的概率分布。概率分布中最大的概率值对应类别就是模型对该句文本的预测类别。针对句子对匹配任务的bilstm网络结构如图5所示。具体的计算过程是,首先将句子对中的句子a输入网络a中,得到句子a对应的表示向量va,然后将句子对中的句子b输入网络b中,得到句子b对应的表示向量vb,然后在两个句子的表示向量之间应用下式获取匹配向量m:
[0082]
m=f(va,vb)=[va,vb,vaevb,|v
a-vb|]
[0083]
匹配向量m被馈送到全连接层后得到logits,接着通过softmax函数计算得到预测
类别的概率分布。在这里,类别只有两类,分别是句子对语义匹配和句子语义不匹配。需要注意的是,图中的网络a和网络b事实上是同一个网络,也就是其中嵌入层和bilstm层的权重参数都是相同的。在具体的实现细节上,本实施例通过先后输入句子a和句子b来获取它们各自的表示向量,然后将这些表示向量输入到全连接层中进行进一步处理,得到最后的预测结果。通过这样的设计,可以在保持网络参数共享的情况下,实现对句子a和句子b的有效表示和匹配,从而完成句子对语义匹配任务。这种共享参数的设计有助于减少模型的存储空间和计算量,并且能够更好地利用有限的资源进行高效部署。为了提升模型的精度,嵌入层一般采用预训练的word2vec或glove模型,本实施例使用的是在谷歌新闻上训练得到300维word2vec模型。其次,一般情况下,序列需要填充到一个固定的长度输入到模型中,对于sst-2任务,本实施例将固定长度值设置为200,对于qqp任务,该值设置为100。
[0084]
bert已被证明在序列处理任务中非常有效,并在许多基准测试中取得了最先进的结果,其泛化能力和学习能力远强于lstm。bilstm是一个双向模型,通过在前向和后向两个方向处理序列来生成上下文表示。而bert在生成表征时也考虑了一个词的上下文,因此也可以视为一个双向模型。这两者在某种程度上的相似性有助于本实施例从bert向lstm进行知识迁移。
[0085]
不同于lstm在特定任务上从头开始进行端到端的训练,bert通常在特定任务上进行微调。本实施例选择了bert-base作为教师模型,并对其进行微调来获取教师模型。bert有几个变体,具有不同的规模和能力,而本实施例选择的是bert-base版本。bert-base由12个transformer层组成,每个transformer层有768个隐藏单元和12个自注意头,总共有1.1亿个参数。由于bert-base已经在大规模的预训练任务上进行了训练,它具备了强大的泛化能力和学习能力,可以作为教师模型为剪枝后的lstm模型提供有益的指导和知识转移。
[0086]
本实施例采用的预训练bert-base模型来自tensorflow hub,该预训练bert-base模型并不包含用于文本分类任务的分类层。因此,在使用bert-base模型时,需要在其之上添加一个特定任务的分类层。由于需要进行简单的二元分类,因此选择直接创建一个有2个输出单元的全连接层,并将其添加到网络中。在使用bert-base模型进行文本分类时,输入序列需要按照特定格式进行处理,其中需要在序列开头添加特殊的“[cls]”标记。然后,将“[cls]”对应位置的表征馈送到全连接层,得到logits。最后,logits经过softmax函数计算,可以得到预测类别的概率分布。这样的处理方式使得bert-base模型可以适用于特定的文本分类任务,并通过简单的分类层获得相应的分类结果。
[0087]
针对句子对匹配任务的bert-base微调过程如图6所示。可以观察到,句子对匹配任务的bert和单句文本分类任务的bert在结构上是相同的。不同之处在于在创建输入序列的阶段。由于输入是句子对,所以需要在两个句子之间额外添加一个“[sep]”标记,用于分隔句子。这样的处理方式使得bert-base模型能够适用于句子对匹配任务,并保持与单句文本分类任务相同的网络结构。在微调过程中,通过输入句子对并对相应任务进行训练,bert-base模型可以根据不同的任务要求学习到适用于句子对匹配任务的语义表示。这样,bert-base模型就能够在句子对匹配任务中发挥作用,同时保持了与单句文本分类任务相同的网络结构,从而提高了模型的泛化能力和适用性。
[0088]
为了验证本实施例压缩方法在不同粒度的权重剪枝下的表现,我们在细粒度剪枝和粗粒度剪枝两种剪枝粒度上进行了实验。同时,为了更好地将bert模型内部的知识迁移
到剪枝后的lstm模型中,我们使用了数据增强方法来生成未标记的数据。在具体的实现细节上,本实施例采用了随机遮掩方法,即以概率pmask随机选择句子中的词,被选择的词会被遮掩并替换为“[mask]”标识。我们设置pmask为0.15。最后,在对隐藏层大小为1600的bilstm模型进行粗粒度剪枝的过程中,将块大小设置为8*8可以在性能和精度之间取得一个平衡。
[0089]
表2实验环境
[0090][0091]
表2是一些实验环境信息。实验实现使用tensorflow完成。tensorflow是一个由谷歌开发的用于数值计算和机器学习的开源软件库。它于2015年发布,此后成为使用最广泛和最受欢迎的机器学习框架之一。除了tensorflow外,实验还使用了包括numpy、gensim、matplotlib、anaconda等工具。本实施例对bilstm模型应用本实施例的压缩方法与一些研究中提到的迭代剪枝微调方法进行比较。本实施例压缩方法分为两种:基于原始模型的压缩方法,将原始lstm模型作为知识蒸馏的教师网络;基于bert的压缩方法,将bert-base模型作为知识蒸馏的教师网络。为了验证压缩方法在不同粒度、不同稀疏度下的表现,实验中的剪枝粒度包括细粒度剪枝和粗粒度剪枝,实验中的稀疏度包括30%、60%、90%三个不同等级的稀疏度。
[0092]
表3剪枝粒度设置为细粒度,压缩bilstm在sst-2和qqp上的实验结果
[0093][0094]
a基准模型指的是未经压缩的bilstm模型
[0095]
b kd-lstm指的是基于原始模型的压缩方法(不包括量化)
[0096]
c kd-bert指的是基于bert模型的压缩方法(不包括量化)
[0097]d↓
指的是压缩模型相对于基准模型的精度变化
[0098]
表3显示了剪枝粒度设置为细粒度,压缩bilstm模型在sst-2和qqp上的实验结果。在sst-2任务上,相较于剪枝微调得到的模型,本实施例的压缩模型在正确率上高出0.4到1.36个百分点;在qqp任务上,相较于剪枝微调得到的模型,本实施例的压缩模型在正确率上高出0.69到1.05个百分点,在f1分数上高出0.49到0.82。
[0099]
表4剪枝粒度设置为粗粒度,压缩bilstm在sst-2和qqp上的实验结果
[0100][0101]
表4显示了剪枝粒度设置为粗粒度,压缩bilstm模型在sst-2和qqp上的实验结果。在sst-2任务上,相较于剪枝微调得到的模型,本实施例的压缩模型在正确率上高出0.32到0.92个百分点;在qqp任务上,相较于剪枝微调得到的模型,本实施例的压缩模型在正确率
上高出0.32到0.81个百分点,在f1分数上高出0.4到0.97。
[0102]
观察实验结果,在低稀疏度下,基于bert的压缩方法的模型精度恢复效果要优于基于原始模型的压缩方法;而在高稀疏度的情况下,两者的精度恢复效果相差不大,甚至基于原始模型的压缩方法的精度恢复效果要略优一些。在低稀疏度下,剪枝模型保留了足够的参数来拟合bert网络的输出并学习bert网络内部的知识。而在高稀疏度下,剩余参数较少,剪枝模型无法轻易学习bert网络内部的知识。此时,具有更相似结构的原始lstm模型中的知识更容易迁移到剪枝模型中,从而提升了模型的精度恢复效果。综合来看,基于知识蒸馏的压缩策略能够更好地帮助剪枝模型学习并拟合教师模型的知识,尤其在低稀疏度下,其优势更为明显。在实际应用中,我们可以根据压缩率和模型精度的需求来选择合适的剪枝粒度和压缩方法,以达到更好的压缩效果和性能恢复。
[0103]
表5量化操作对压缩模型(设置为粗粒度剪枝)精度的影响
[0104][0105]
为了进一步压缩和加速剪枝模型,本实施例在实验中对稀疏度为60%的粗粒度剪枝模型应用了16位量化操作。根据表5的结果可以看到,量化操作对模型精度的影响较小。在此情况下,压缩模型相较于基准模型,仍然能够保持较高的精度,没有明显的精度损失。这表明本实施例提出的压缩方法不仅在剪枝过程中恢复了模型的精度,而且在后续的量化操作中也能够保持模型的精度稳定。量化操作在减小模型存储和加速推理过程中起到了重要的作用,同时保持了模型性能的高水平。
[0106]
本实施例还对压缩前后模型的速度和能效进行了评估。如图7所示,对同一句文本进行单句分类时,比较了压缩模型(压缩流程包括粗粒度剪枝、基于知识蒸馏的精度恢复策略和量化)和原始模型的推理时间和功耗。图中的稀疏度0表示未经压缩的原始bilstm模型。实验结果显示随着稀疏度的上升,压缩模型进行单句分类所需的推理时间显著减少,同时推理时的功耗也有所下降。具体而言,当将粗粒度剪枝稀疏度设置为60%并进行16位量化时,压缩后的bilstm模型精度尚未受到明显影响。由此可见,本实施例提出的压缩方法在不损失模型精度的前提下,显著提高了模型的推理速度和能效。压缩模型相较于原始模型获得了约2.3倍的加速比和约2.8倍的能效比提升。这意味着在资源受限的情况下,本实施例的压缩模型能够在保持高精度的同时,提供更高效的推理能力,适用于诸如移动设备、嵌入式系统和边缘设备等场景的部署。
[0107]
本实施例在知识蒸馏的实现中,除了引入蒸馏损失用于学生模型拟合教师模型的logits输出,还使用输出概率分布与真实标签的交叉熵损失,以确保学生模型的输出与样本的真实标签相互匹配。这两部分损失共同构成目标函数,帮助学生模型从教师模型的“暗知识”中进行学习,优化模型的输出概率分布,从而提高剪枝模型的准确率。另外,本实施例
将知识蒸馏应用于lstm模型的剪枝过程中,通过合理传递知识,使得剪枝后的模型具备更强的表征能力。
[0108]
可以理解的是,以上实施方式仅仅是为了说明本发明的原理而采用的示例性实施方式,然而本发明并不局限于此。对于本领域内的普通技术人员而言,在不脱离本发明的精神和实质的情况下,可以做出各种变型和改进,这些变型和改进也视为本发明的保护范围。
技术特征:
1.一种基于知识蒸馏恢复策略剪枝的长短期记忆压缩方法,其特征在于,包括:步骤s1、根据得到的数据集训练长短期记忆模型,获得具有预设的泛化能力的原始模型,保存所述原始模型;步骤s2、设置剪枝参数,所述剪枝参数包括权重剪枝方法、稀疏度的初始值、稀疏度的期望值;步骤s3、根据所述权重剪枝方法评估连接或权重块的重要性,排序后根据所述稀疏度确定修剪比例,根据所述修剪比例将对应的参数置零,同时禁止已经置零的参数进行更新,得到剪枝模型;步骤s4、使用知识蒸馏方法对所述剪枝模型进行训练,将所述原始模型作为教师,将所述剪枝模型作为学生,通过在损失函数中加入蒸馏损失,使得学生模型拟合教师模型的logits输出,迭代训练预设的次数之后,得到精度恢复的模型;步骤s5、评估所述精度恢复的模型的精度,调整所述稀疏度,根据预设的精度损失范围增减所述稀疏度,返回步骤s3继续剪枝,直至达到所述稀疏度的期望值或满足预设的终止条件。2.根据权利要求1所述的基于知识蒸馏恢复策略剪枝的长短期记忆压缩方法,其特征在于,还包括:获取在预设的任务上进行预设的微调的bert模型;将所述bert模型作为教师,将所述剪枝模型作为学生,使用知识蒸馏方法对所述剪枝模型进行训练。3.根据权利要求2所述的基于知识蒸馏恢复策略剪枝的长短期记忆压缩方法,其特征在于,还包括:使用均方误差损失直接比较logits输出的结果差异,以计算所述蒸馏损失,所述蒸馏损失的表达式如下:其中,z
t
为所述教师模型的logits输出,z
s
为所述学生模型的logits输出,n是预测类别的个数。4.根据权利要求3所述的基于知识蒸馏恢复策略剪枝的长短期记忆压缩方法,其特征在于,还包括:使用输出概率分布与真实标签之间的交叉熵损失作为目标函数的一部分,最终的损失函数的表达式如下:其中,y
s
是所述学生模型预测的输出概率分布,由logits经过softmax函数得到;当样本来自原始的标记数据集时,t为标记的真值标签;当样本来自数据增强生成的数据集时,将bert模型的预测结果作为真值标签;α为权重超参数。
技术总结
本发明公开了一种基于知识蒸馏恢复策略剪枝的长短期记忆压缩方法,在知识蒸馏的实现中,除了引入蒸馏损失用于学生模型拟合教师模型的logits输出,还使用输出概率分布与真实标签的交叉熵损失,以确保学生模型的输出与样本的真实标签相互匹配。这两部分损失共同构成目标函数,帮助学生模型从教师模型的“暗知识”中进行学习,优化模型的输出概率分布,从而提高剪枝模型的准确率。另外,本发明将知识蒸馏应用于LSTM模型的剪枝过程中,通过合理传递知识,使得剪枝后的模型具备更强的表征能力。使得剪枝后的模型具备更强的表征能力。使得剪枝后的模型具备更强的表征能力。
技术研发人员:王思野 李元东 赵中原 梁步顺 徐文波 赖锦林 麦吉
受保护的技术使用者:北京邮电大学
技术研发日:2023.08.16
技术公布日:2023/10/11
版权声明
本文仅代表作者观点,不代表航空之家立场。
本文系作者授权航家号发表,未经原创作者书面授权,任何单位或个人不得引用、复制、转载、摘编、链接或以其他任何方式复制发表。任何单位或个人在获得书面授权使用航空之家内容时,须注明作者及来源 “航空之家”。如非法使用航空之家的部分或全部内容的,航空之家将依法追究其法律责任。(航空之家官方QQ:2926969996)
飞行汽车 https://www.autovtol.com/
上一篇:一种刺梨薏仁米酸奶及其制备方法 下一篇:一种涤纶护肤面料及其生产方法与流程
