资讯专栏INFORMATION COLUMN

[论文笔记]DistilBERT, a distilled version of BERT: sma

jzzlee / 3019人阅读

摘要:引言本文是的阅读笔记。有监督损失的计算方法为其中,表示第个类别的标签表示学生模型对该类别的输出概率。在训练阶段,教师模型和学生模型设置同样的温度,此时一般将温度系数设为。在推理阶段,将温度系数设成,还原标准的函数。

引言

本文是DistilBERT1的阅读笔记。

核心思想

DistilBERT是一个更小更快的BERT模型,类似ALBERT,也是用来给BERT瘦身的。

DistilBERT应用了基于三重损失(Triplet Loss)的知识蒸馏(knowledge distillation)方法。相比BERT模型,DistilBERT的参数量压缩至原来的40%,同时带来60%的推理速度提升,并且在多个下游任务上达到BERT模型效果的97%。

并且该模型可以放到像手机?(on-device)这类设备上运行,具备的好处就是更好的隐私保护,一些隐私数据可以不用上传到服务器,直接在手机端针对这些数据就可以为人们带来个性化的服务。

模型剖析

DistilBERT的名字中Distil就是蒸馏的意思,我们先来看下什么是蒸馏。

蒸馏

蒸馏的解释是加热液体汽化,再使蒸气液化,从而除去其中的杂质。

而这里的知识蒸馏是指将已经训练好的模型包含的知识(Knowledge),蒸馏(Distill)提取到另一个模型里面去。

通常前者是一个较大的模型,后者是一个较小的模型。

从另外一个角度思考的话,我们让小模型来学习大模型。因此,我们把大模型当成老师(Teacher),小模型当成学生(Student)2

蒸馏的目标是让学生模型学习到老师模型的泛化能力,而不是学习拟合训练数据,理论上得到的结果会比单纯拟合训练数据要好。

训练损失

为了将教师模型的知识传输到学生模型,DistilBERT采用了三重损失3:有监督MLM损失、蒸馏MLM损失和词向量余弦损失,如下所示:
L = L s − m l m + L d − m l m + L c o s /mathcal{L}=/mathcal{L}^{s-mlm} + /mathcal{L}^{d-mlm}+/mathcal{L}^{cos} L=Lsmlm+Ldmlm+Lcos
有监督MLM损失 利用掩码语言模型训练得到的损失,即通过输入带有掩码的句子,得到每个掩码位置在词表空间上的概率分布,并利用交叉熵损失函数学习。有监督MLM损失的计算方法为:
L s − m l m = − ∑ i y i log ⁡ ( s i ) /mathcal{L}^{s-mlm}= -/sum_i y_i /log (s_i) Lsmlm=iyilog(si)
其中, y i y_i yi表示第 i i i个类别的标签; s i s_i si表示学生模型对该类别的输出概率。

蒸馏MLM损失 利用教师模型的概率作为指导信号,与学生模型的概率计算交叉熵损失进行学习。由于教师模型是已经训练过的预训练语言模型,其输出的概率分布相比学生模型更加准确,能够起到一定的监督训练目的。因此,在预训练语言模型的知识蒸馏中,通常将有监督MLM称作硬标签(Hard Label)训练方法,将蒸馏MLM称作软标签(Soft Label)训练方法。硬标签对应真实的MLM训练标签,而软标签是教师模型输出的概率。蒸馏MLM损失的计算方法为:
L d − m l m = − ∑ i t i log ⁡ ( s i ) /mathcal{L}^{d-mlm} = -/sum_i t_i /log(s_i) Ldmlm=itilog(si)
其中, t i t_i ti表示教师模型对第 i i i个类别的输出概率; s i s_i si表示学生模型对该类别的输出概率。对比上面两个式子可以很容易看出有监督MLM损失和蒸馏MLM损失之间的区别。需要注意的是,当计算概率 t i t_i ti s i s_i si时,DistilBERT采用了带有温度系数的Softmax函数:
P i = exp ⁡ ( z i / T ) ∑ j exp ⁡ ( z j / T ) P_i = /frac{/exp(z_i/T)}{/sum_j /exp(z_j/T)} Pi=jexp(zj/T)exp(zi/T)
其中, P i P_i Pi表示带有温度的概率值, t i t_i ti s i s_i si均使用该方法计算; z i z_i zi z j z_j zj表示为激活的数值; T T T表示蒸馏里面的温度系数,用于控制输出概率的平滑程度。在训练阶段,教师模型和学生模型设置同样的温度 T T T,此时一般将温度系数设为 T = 8 T=8 T=8。在推理阶段,将温度系数设成 T = 1 T=1 T=1,还原标准的Softmax函数。

词向量余弦损失 词向量余弦损失用来对齐教师模型和学生模型的隐藏状态向量的方向,从隐藏状态维度拉近教师模型和学生模型的距离,如下:
L c o s = cos ⁡ ( h t , h s ) /mathcal{L}^{cos} = /cos(h^t,h^s) Lcos=cos(ht,hs)
其中, h t h^t ht h s h^s hs分别表示教师模型和学生模型最后一层的隐藏状态输出。

DistilBERT:一个蒸馏版本的BERT

学生模型结构 学生模型(DistilBERT)的基本结构是一个六层的BERT模型,同时去掉了标记类型嵌入和池化模块(Pooler)。线性层和层归一化层已经被高度优化且证明有效,因此作者不改动。最后一层的隐藏向量大小,作者发现减少该值并不太影响模型效果。层数能影响模型效果和推理速度,因此作者注重于此参数优化。

学生模型初始化 教师模型直接使用了原版的BERT-base模型。由于教师模型和学生模型的前六层结构基本相同,为了最大化复用教师模型中的知识,学生模型使用了教师模型的前六层进行初始化。

蒸馏 DistilBERT 是在非常大的批次上使用动态掩码利用梯度累积(每批次最多 4K 个样本)进行蒸馏的,没有下一句预测目标。

评估

GLUE : DistilBERT的参数量压缩至原来的40%,并且在多个下游任务上达到BERT模型效果的97%,甚至在WNLI任务上超过了BERT。

IMDb准确率 :BERT(93.46) DistilBERT(92.82)

推理速度(跑了所有的GLUE任务):DistilBERT(410s) BERT(668s) ELMo(895s)

参考


  1. DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter ↩︎

  2. Distilling the knowledge in a neural network ↩︎

  3. 自然语言处理——基于预训练模型的方法 ↩︎

文章版权归作者所有,未经允许请勿转载,若此文章存在违规行为,您可以联系管理员删除。

转载请注明本文地址:https://www.ucloud.cn/yun/121111.html

相关文章

  • Hinton提出泛化更优的「软决策树」:可解释DNN具体决策

    摘要:近日,针对泛化能力强大的深度神经网络无法解释其具体决策的问题,深度学习殿堂级人物等人发表论文提出软决策树。即使没有使用无标签数据,仍然有可能通过使用一种称为蒸馏法,的技术和一种执行软决策的决策树,将神经网络的泛化能力迁移到决策树上。 近日,针对泛化能力强大的深度神经网络(DNN)无法解释其具体决策的问题,深度学习殿堂级人物 Geoffrey Hinton 等人发表 arXiv 论文提出「软决...

    SillyMonkey 评论0 收藏0
  • 2018年深度学习的主要进步

    摘要:在过去几年中,深度学习改变了整个人工智能的发展。在本文中,我将介绍年深度学习的一些主要进展,与年深度学习进展版本一样,我没有办法进行详尽的审查。最后的想法与去年的情况一样,年深度学习技术的使用持续增加。 在过去几年中,深度学习改变了整个人工智能的发展。深度学习技术已经开始在医疗保健,金融,人力资源,零售,地震检测和自动驾驶汽车等领域的应用程序中出现。至于现有的成果表现也一直在稳步提高。在学术...

    sushi 评论0 收藏0
  • 文章总结:Distilling the Knowledge in a Neural Network(

    摘要:通常,应尽量反映任务的真实目标。在训练时,促使尽可能进行正确分类的同时,让其尽可能接近通过获得的。在中,每个都会贡献在某一方向的梯度,对应于的每个,。因此,在过小时将无法捕获中所有的知识。 原文地址:https://arxiv.org/abs/1503.02... Abstract: 在机器学习领域,ensemble learning是一种普遍适用的用来提升模型表现的方法, 将通过...

    happen 评论0 收藏0

发表评论

0条评论

最新活动
阅读需要支付1元查看
<