关注我们.

人工智能

通过多标记预测增强大型语言模型

mm

发布时间

 on

具有多标记预测的大型语言模型

大型语言模型 GPT、LLaMA 等法学硕士凭借其理解和生成类人文本的卓越能力席卷了世界。然而,尽管它们的能力令人印象深刻,但训练这些模型的标准方法(称为“下一个令牌预测”)具有一些固有的局限性。

在下一个标记预测中,模型被训练为在给定前面的单词的情况下预测序列中的下一个单词。虽然这种方法已被证明是成功的,但它可能会导致模型难以应对长期依赖性和复杂的推理任务。此外,推理过程中教师强制训练制度和自回归生成过程之间的不匹配可能会导致性能不佳。

最近的一篇研究论文 格洛克等人。 (2024) Meta AI 推出了一种新颖的训练范式,称为“多标记预测”旨在解决这些限制并增强大型语言模型。在这篇博文中,我们将深入探讨这项突破性研究的核心概念、技术细节和潜在影响。

单令牌预测:传统方法

在深入研究多令牌预测的细节之前,有必要了解已经成为主流的传统方法。 大语言模型的主力 多年训练——单令牌预测,也称为下一个令牌预测。

下一个令牌预测范式

在下一个标记预测范式中,语言模型被训练来预测给定前面上下文的序列中的下一个单词。更正式地说,模型的任务是在给定先前的标记 x1、x1、…、xt 的情况下,最大化下一个标记 xt+2 的概率。这通常是通过最小化交叉熵损失来完成的:

L = -Σt log P(xt+1 | x1, x2, …, xt)

这个简单而强大的训练目标是许多成功的大​​型语言模型的基础,例如 GPT(Radford 等人,2018)、BERT(Devlin 等人,2019)及其变体。

教师强迫和自回归生成

下一个令牌预测依赖于一种称为“老师强迫”其中在训练期间为模型提供了每个未来标记的基本事实。这使得模型能够从正确的上下文和目标序列中学习,从而促进更稳定、更高效的训练。

然而,在推理或生成过程中,模型以自回归方式运行,根据先前生成的标记一次预测一个标记。训练机制(教师强制)和推理机制(自回归生成)之间的不匹配可能会导致潜在的差异和次优性能,特别是对于较长的序列或复杂的推理任务。

下一个标记预测的局限性

虽然下一个代币预测非常成功,但它也有一些固有的局限性:

  1. 短期关注:仅通过预测下一个标记,模型可能难以捕获远程依赖性以及文本的整体结构和连贯性,可能导致不一致或不连贯的生成。
  2. 局部模式锁存:下一个令牌预测模型可以锁定训练数据中的本地模式,这使得推广到分布外场景或需要更抽象推理的任务变得具有挑战性。
  3. 推理能力:对于涉及多步推理、算法思维或复杂逻辑运算的任务,下一个令牌预测可能无法提供足够的归纳偏差或表示来有效支持此类功能。
  4. 样本效率低下:由于下一个标记预测的局部性质,模型可能需要更大的训练数据集来获取必要的知识和推理技能,从而导致潜在的样本效率低下。

这些限制促使研究人员探索替代训练范例,例如多标记预测,其目的是解决其中一些缺点并解锁大型语言模型的新功能。

通过将传统的下一个令牌预测方法与新颖的多令牌预测技术进行对比,读者可以更好地理解后者的动机和潜在好处,为更深入地探索这项开创性的研究奠定了基础。

什么是多标记预测?

多标记预测背后的关键思想是训练语言模型以同时预测多个未来标记,而不仅仅是下一个标记。具体来说,在训练期间,模型的任务是使用在共享模型主干上运行的 n 个独立输出头来预测训练语料库中每个位置的下一个 n 个标记。

例如,通过 4 个令牌预测设置,模型将被训练为在给定前面的上下文的情况下立即预测接下来的 4 个令牌。这种方法鼓励模型捕获更长期的依赖关系,并更好地理解文本的整体结构和连贯性。

一个玩具示例

为了更好地理解多标记预测的概念,让我们考虑一个简单的例子。假设我们有以下句子:

“敏捷的棕色狐狸跳过了懒狗。”

在标准的下一个标记预测方法中,模型将被训练为在给定先前上下文的情况下预测下一个单词。例如,给定上下文“敏捷的棕色狐狸跳过了”,模型的任务是预测下一个单词“懒惰”。

然而,通过多标记预测,模型将被训练为一次预测多个未来单词。例如,如果我们设置 n=4,模型将被训练为同时预测接下来的 4 个单词。给定相同的上下文“敏捷的棕色狐狸跳过了”,该模型的任务是预测序列“懒惰的狗”。 (注意“dog”后面的空格表示句子的结尾)。

通过训练模型同时预测多个未来标记,可以捕获远程依赖性并更好地理解文本的整体结构和连贯性。

技术细节

作者提出了一种简单而有效的架构来实现多令牌预测。该模型由一个共享的变压器主干组成,它产生输入上下文的潜在表示,然后是 n 独立的变压器层(输出头)预测各自的未来令牌。

在训练期间,前向和后向传递经过精心编排,以最大程度地减少 GPU 内存占用。共享主干计算潜在表示,然后每个输出头顺序执行其前向和后向传递,在主干级别累积梯度。这种方法避免了同时实现所有 Logit 向量及其梯度,从而将峰值 GPU 内存使用量从 O(nV + d)O(V+d),其中 V词汇量d尺寸 的潜在表征。

内存高效的实现

训练多标记预测器的挑战之一是降低 GPU 内存利用率。自从 词汇量(V) 通常比 尺寸 潜在表征的 (四),logit向量成为GPU内存使用瓶颈。

为了应对这一挑战,作者提出了一种内存高效的实现,可以仔细调整前向和后向操作的顺序。该实现不是同时实现所有逻辑及其梯度,而是按顺序计算每个独立输出头的前向和后向传递,在主干级别累积梯度。

这种方法避免了同时将所有 Logit 向量及其梯度存储在内存中,从而降低了 GPU 内存利用率峰值 O(nV + d)O(V+d),其中 n 是预测的未来代币数量。

多标记预测的优点

该研究论文提出了使用多标记预测来训练大型语言模型的几个引人注目的优势:

  1. 提高样品效率:通过鼓励模型同时预测多个未来令牌,多令牌预测可推动模型提高样本效率。作者展示了代码理解和生成任务性能的显着改进,高达 13B 参数的模型平均解决了大约 15% 的问题。
  2. 更快的推理:通过多令牌预测训练的附加输出头可用于自推测解码,这是允许并行令牌预测的推测解码的变体。这使得各种批量大小的推理时间加快了 3 倍,即使对于大型模型也是如此。
  3. 促进远程依赖:多标记预测鼓励模型捕获数据中的较长范围的依赖性和模式,这对于需要在更大的上下文中理解和推理的任务特别有益。
  4. 算法推理:作者提出了关于合成任务的实验,证明了多标记预测模型在开发归纳头和算法推理能力方面的优越性,特别是对于较小的模型大小。
  5. 连贯性和一致性:通过训练模型同时预测多个未来令牌,多令牌预测鼓励连贯一致的表示的发展。这对于需要生成更长、更连贯的文本的任务特别有益,例如讲故事、创意写作或生成教学手册。
  6. 改进泛化:作者对合成任务的实验表明,多标记预测模型表现出更好的泛化能力,特别是在分布外设置中。这可能是由于该模型能够捕获更长期的模式和依赖性,这可以帮助它更有效地推断出未见过的场景。

例子和直觉

为了更直观地了解多标记预测为何如此有效,让我们考虑几个示例:

  1. 代码生成:在代码生成的背景下,同时预测多个标记可以帮助模型理解并生成更复杂的代码结构。例如,在生成函数定义时,仅预测下一个标记可能无法为模型提供足够的上下文来正确生成整个函数签名。然而,通过一次预测多个标记,模型可以更好地捕获函数名称、参数和返回类型之间的依赖关系,从而生成更准确、更连贯的代码。
  2. 自然语言推理:考虑一个场景,其中语言模型的任务是回答需要对多个步骤或信息进行推理的问题。通过同时预测多个标记,该模型可以更好地捕获推理过程的不同组成部分之间的依赖关系,从而产生更加连贯和准确的响应。
  3. 长文本生成:在生成长文本(例如故事、文章或报告)时,对于使用下一个标记预测训练的语言模型来说,在较长时间内保持连贯性和一致性可能具有挑战性。多标记预测鼓励模型开发捕获文本整体结构和流程的表示,这可能会导致更连贯和一致的长格式生成。

限制和未来方向

虽然论文中提出的结果令人印象深刻,但仍存在一些局限性和悬而未决的问题,值得进一步研究:

  1. 最佳代币数量:论文探讨了不同的 n 值(要预测的未来标记的数量),发现 n=4 对于许多任务来说效果很好。然而,n 的最佳值可能取决于特定任务、数据集和模型大小。开发确定最佳 n 的原则性方法可以进一步提高性能。
  2. 词汇量大小和标记化:作者指出,多标记预测模型的最佳词汇量和标记化策略可能与下一个标记预测模型所使用的不同。探索这个方面可能会导致压缩序列长度和计算效率之间更好的权衡。
  3. 辅助预测损失:作者认为,他们的工作可能会激发人们对为大型语言模型开发新颖的辅助预测损失的兴趣,超越标准的下一个标记预测。研究替代辅助损失及其与多令牌预测的组合是一个令人兴奋的研究方向。
  4. 理论理解:虽然本文为多令牌预测的有效性提供了一些直觉和经验证据,但对这种方法为何以及如何如此有效的更深入的理论理解将是有价值的。

结论

Gloeckle 等人的研究论文“Better & Faster Large Language Models via Multi-token Prediction”。引入了一种新颖的训练范式,该范式有可能显着提高大型语言模型的性能和能力。通过训练模型同时预测多个未来令牌,多令牌预测鼓励开发远程依赖性、算法推理能力和更好的样本效率。

作者提出的技术实现非常优雅且计算效率高,使得将该方法应用于大规模语言模型训练是可行的。此外,利用自推测解码来实现更快推理的能力是一个显着的实际优势。

尽管仍然存在悬而未决的问题和需要进一步探索的领域,但这项研究代表了大型语言模型领域向前迈出了令人兴奋的一步。随着对功能更强大、更高效的语言模型的需求不断增长,多标记预测可能成为下一代强大人工智能系统的关键组成部分。

在过去的五年里,我一直沉浸在机器学习和深度学习的迷人世界中。 我的热情和专业知识使我为 50 多个不同的软件工程项目做出了贡献,特别关注人工智能/机器学习。 我持续的好奇心也吸引了我对自然语言处理的兴趣,这是我渴望进一步探索的领域。