人工智能

使用多令牌预测增强大型语言模型

mm
Large Language Models with Multi-token Prediction
div]:bg-bg-300 [&_pre]:-mr-4 md:[&_pre]:-mr-9″>
_*]:min-w-0″>
_*]:min-w-0″>

大型语言模型 (LLMs) 像 GPT、LLaMA 和其他模型以其令人惊叹的理解和生成类似人类文本的能力而风靡全球。然而,尽管它们具有令人印象深刻的能力,但训练这些模型的标准方法,即所谓的 “下一个令牌预测”,有一些固有的局限性。

在下一个令牌预测中,模型被训练为预测序列中给定前几个单词的下一个单词。虽然这种方法已经被证明是成功的,但它可能会导致模型难以处理长距离依赖和复杂推理任务。另外,教师强制训练方案和推理过程中自回归生成之间的不匹配可能会导致次优性能。

最近的一篇研究论文由 Gloeckle 等人 (2024) 从 Meta AI 引入了一种新的训练范式,称为 “多令牌预测”,旨在解决这些局限性并增强大型语言模型。在这篇博客文章中,我们将深入探讨这一开创性研究的核心概念、技术细节和潜在影响。

div]:bg-bg-300 [&_pre]:-mr-4 md:[&_pre]:-mr-9″>
_*]:min-w-0″>

单令牌预测:传统方法

在深入探讨多令牌预测的细节之前,了解大型语言模型训练中多年的传统方法至关重要,即单令牌预测,也称为下一个令牌预测。

下一个令牌预测范式

在下一个令牌预测范式中,语言模型被训练为预测序列中给定前几个单词的下一个单词。更正式地,模型的任务是最大化下一个令牌 xt+1 的概率,给定前几个令牌 x1、x2、…、xt。这通常是通过最小化交叉熵损失来完成的:

L = -Σt log P(xt+1 | x1, x2, …, xt)

这种简单却强大的训练目标已经成为许多成功的大型语言模型的基础,例如 GPT(Radford 等,2018 年)、BERT(Devlin 等,2019 年)及其变体。

教师强制和自回归生成

下一个令牌预测依赖于一种称为 “教师强制” 的训练技术,其中模型在训练过程中为每个未来令牌提供了真实值。这使模型能够从正确的上下文和目标序列中学习,促进更稳定和高效的训练。

然而,在推理或生成过程中,模型以自回归的方式运作,根据之前生成的令牌预测一个令牌。教师强制训练方案和自回归生成过程之间的不匹配可能会导致潜在的差异和次优性能,特别是对于更长的序列或复杂的推理任务。

下一个令牌预测的局限性

虽然下一个令牌预测非常成功,但它也有一些固有的局限性:

  1. 短期关注:通过仅预测下一个令牌,模型可能难以捕捉长距离依赖和文本的整体结构和连贯性,可能导致不一致或不连贯的生成。
  2. 局部模式捕获:下一个令牌预测模型可以捕获训练数据中的局部模式,使得模型难以推广到分布之外的场景或需要更抽象推理的任务。
  3. 推理能力:对于涉及多步骤推理、算法思维或复杂逻辑运算的任务,下一个令牌预测可能无法提供足够的归纳偏差或表示来有效地支持这些能力。
  4. 样本效率:由于下一个令牌预测的局部性质,模型可能需要更大的训练数据集来获得必要的知识和推理技能,导致潜在的样本效率低下。

这些局限性激发了研究人员探索替代的训练范式,例如多令牌预测,它旨在解决其中一些缺点并解锁大型语言模型的新能力。

通过对比传统的下一个令牌预测方法和新颖的多令牌预测技术,读者可以更好地理解后者的动机和潜在的好处,为更深入地探索这项开创性的研究奠定了基础。

什么是多令牌预测?

多令牌预测的关键思想是训练语言模型同时预测多个未来令牌,而不是仅预测下一个令牌。具体来说,在训练过程中,模型的任务是预测每个位置在训练语料库中的下 n 个令牌,使用 n 个独立的输出头在共享模型主干上运行。

例如,使用 4 个令牌预测设置,模型将被训练为预测给定前几个单词的下 4 个令牌。这一方法鼓励模型捕捉更长距离的依赖关系并更好地理解文本的整体结构和连贯性。

一个玩具示例

为了更好地理解多令牌预测的概念,让我们考虑一个简单的示例。假设我们有以下句子:

“快速的棕色狐狸跳过懒惰的狗。”

在标准的下一个令牌预测方法中,模型将被训练为预测给定前几个单词的下一个单词。例如,给定上下文 “快速的棕色狐狸跳过懒惰的,”,模型将被任务预测下一个单词 “狗”。

使用多令牌预测,然而,模型将被训练为同时预测多个未来单词。例如,如果我们设置 n=4,模型将被训练为预测下 4 个单词。给定相同的上下文 “快速的棕色狐狸跳过懒惰的,”,模型将被任务预测序列 “狗 。”(注意 “狗” 后面的空格表示句子的结束)。

通过训练模型同时预测多个未来令牌,它被鼓励捕捉长距离依赖并更好地理解文本的整体结构和连贯性。

技术细节

作者提出了一个简单却有效的多令牌预测的架构。模型由一个共享的 Transformer 主干组成,生成输入上下文的潜在表示,然后是 n 个独立的 Transformer 层(输出头),预测相应的未来令牌。

在训练过程中,前向和后向传递被仔细协调以最小化 GPU 内存占用。共享主干计算潜在表示,然后每个输出头顺序地执行其前向和后向传递,在主干级别累积梯度。这一方法避免了同时实例化所有 logit 向量及其梯度,从而将峰值 GPU 内存使用量从 O(nV + d) 减少到 O(V + d),其中 V 是词汇表大小,d 是潜在表示的维度。

内存高效实现

训练多令牌预测器的一个挑战是减少它们的 GPU 内存使用。由于词汇表大小 (V) 通常远大于潜在表示的维度 (d),logit 向量成为 GPU 内存使用的瓶颈。

为了解决这个挑战,作者提出了一个内存高效的实现,仔细调整了前向和后向操作的顺序。与其同时实例化所有 logit 和它们的梯度,不如顺序地为每个独立输出头计算前向和后向传递,在主干级别累积梯度。

这种方法避免了同时存储所有 logit 向量及其梯度,从而将峰值 GPU 内存使用量从 O(nV + d) 减少到 O(V + d),其中 n 是预测的未来令牌数量。

多令牌预测的优势

研究论文提出了使用多令牌预测训练大型语言模型的几个令人信服的优势:

  1. 改进的样本效率:通过鼓励模型同时预测多个未来令牌,多令牌预测推动模型朝着更好的样本效率发展。作者在代码理解和生成任务上展示了显著的性能改进,模型参数最多可达 13B,平均解决了 15% 更多的问题。
  2. 更快的推理:多令牌预测训练的额外输出头可以用于自我推测解码,一种推测解码的变体,允许并行令牌预测。这导致了在广泛的批次大小范围内最多 3 倍的推理速度加快,即使对于大型模型也是如此。
  3. 促进长距离依赖:多令牌预测鼓励模型捕捉数据中的更长距离依赖和模式,这对于需要理解和推理更大上下文的任务特别有益。
  4. 算法推理:作者在合成任务上进行了实验,展示了多令牌预测模型在开发归纳头和算法推理能力方面的优势,特别是对于较小的模型大小。
  5. 连贯性和一致性:通过训练模型同时预测多个未来令牌,多令牌预测鼓励连贯和一致的表示的发展。这对于需要生成更长、更连贯的文本的任务特别有益,例如讲故事、创作或生成说明手册。
  6. 改进的泛化:作者的实验在合成任务上表明,多令牌预测模型表现出更好的泛化能力,特别是在分布之外的场景中。这可能是由于模型能够捕捉更长距离的模式和依赖,从而更好地推广到未见场景。

示例和直觉

为了提供更多关于为什么多令牌预测效果如此好的直觉,让我们考虑几个示例:

  1. 代码生成:在代码生成的背景下,预测多个令牌可以帮助模型理解和生成更复杂的代码结构。例如,当生成函数定义时,预测下一个令牌可能无法为模型提供足够的上下文来正确生成整个函数签名。然而,通过预测多个令牌,模型可以更好地捕捉函数名、参数和返回类型之间的依赖,从而生成更准确和连贯的代码。
  2. 自然语言推理:考虑一个语言模型需要回答一个需要多步骤推理或处理多个信息片段的问题。通过预测多个令牌,模型可以更好地捕捉推理过程的不同组件之间的依赖,从而生成更连贯和准确的响应。
  3. 长文本生成:在生成长文本(如故事、文章或报告)时,使用下一个令牌预测训练的语言模型可能难以维持连贯性和一致性。多令牌预测鼓励模型开发捕捉文本整体结构和流程的表示,从而可能生成更连贯和一致的长文本。

局限性和未来方向

虽然论文中呈现的结果令人印象深刻,但仍有一些局限性和待探索的问题:

  1. 最佳令牌数量:论文探索了不同的 n 值(要预测的未来令牌数量),发现 n=4 对许多任务有效。然而,最佳的 n 值可能取决于特定的任务、数据集和模型大小。开发确定最佳 n 的原则方法可能会带来进一步的性能改进。
  2. 词汇表大小和标记化:作者指出,多令牌预测模型的最佳词汇表大小和标记化策略可能与下一个令牌预测模型不同。探索这一方面可能会带来更好的压缩序列长度和计算效率之间的权衡。
  3. 辅助预测损失:作者建议他们的工作可能会激发人们对开发大型语言模型的新型辅助预测损失的兴趣,超越标准的下一个令牌预测。研究替代损失和它们与多令牌预测的组合是一个令人兴奋的研究方向。
  4. 理论理解:虽然论文提供了一些直觉和实证证据关于多令牌预测的有效性,但对这一方法为什么和如何有效的更深入的理论理解将会很有价值。

结论

Gloeckle 等人 (2024) 的研究论文 “通过多令牌预测实现更好、更快的大型语言模型” 引入了一种新型的训练范式,它有潜力显著提高大型语言模型的性能和能力。通过训练模型同时预测多个未来令牌,多令牌预测鼓励模型捕捉长距离依赖、算法推理能力和更好的样本效率。

作者提出的技术实现简单却有效,使得这种方法可以应用于大规模语言模型训练。此外,利用自我推测解码进行更快推理的能力是一个显著的实际优势。

虽然仍有一些待探索的问题,但这项研究代表了大型语言模型领域的一个令人兴奋的步骤。随着对更强大和更高效语言模型的需求不断增长,多令牌预测可能成为下一代这些强大 AI 系统的关键组成部分。

我过去五年一直沉浸在令人着迷的机器学习和深度学习世界中。我的热情和专业知识使我能够为超过50个不同的软件工程项目做出贡献,特别注重人工智能/机器学习。我的持续好奇心也使我对自然语言处理产生了兴趣,这是一个我渴望进一步探索的领域。