Connect with us

人工智能

Jamba:AI21 Labs 的新型混合 Transformer-Mamba 语言模型

mm
Jamba AI21 style, a sleek hybrid machine with glowing circuitry, merging Transformer and Mamba components, surrounded by swirling data streams and abstract neural connections, set against a futuristic backdrop with soft, ambient lighting

语言模型在近年来经历了快速的发展,基于 Transformer 的架构在自然语言处理中占据了主导地位。然而,随着模型的规模增大,处理长上下文、内存效率和吞吐量的挑战变得更加明显。

AI21 Labs 推出了 Jamba,一种结合了 Transformer 和 Mamba 架构优势的混合框架的最新语言模型。该文章详细介绍了 Jamba 的架构、性能和潜在应用。

Jamba 概述

Jamba 是由 AI21 Labs 开发的混合大语言模型,利用 Transformer 层和 Mamba 层的组合,并集成了 Mixture-of-Experts (MoE) 模块。这种架构使 Jamba 能够平衡内存使用、吞吐量和性能,使其成为自然语言处理任务的强大工具。该模型设计用于适应单个 80GB GPU,提供高吞吐量和小内存占用,同时保持最先进的性能。

Jamba 的架构

Jamba 的架构是其能力的基石。它建立在一个新型的混合设计上,交替使用 Transformer 层和 Mamba 层,并集成了 MoE 模块以增强模型的能力而不显著增加计算需求。

1. Transformer 层

Transformer 架构已成为现代大语言模型的标准,因为它能够高效地处理并行处理和捕获文本中的长距离依赖。然而,其性能往往受到高内存和计算需求的限制,特别是在处理长上下文时。Jamba 通过集成 Mamba 层来解决这些限制,我们将在下面探讨。

2. Mamba 层

Mamba 是一种最近的状态空间模型 (SSM),旨在比传统的 RNN 或甚至 Transformer 更高效地处理序列中的长距离关系。Mamba 层特别适用于减少 Transformer 中存储键值 (KV) 缓存相关的内存占用。通过交替使用 Mamba 层和 Transformer 层,Jamba 减少了整体内存使用,同时保持高性能,特别是在需要处理长上下文的任务中。

3. Mixture-of-Experts (MoE) 模块

Jamba 中的 MoE 模块引入了一种灵活的方式来扩展模型能力。MoE 允许模型增加可用参数的数量,而不需要在推理过程中成比例地增加活跃参数。在 Jamba 中,MoE 应用于某些 MLP 层,路由机制选择每个标记的顶部专家以激活。这一选择性激活使 Jamba 能够在处理复杂任务时保持高效率。

以下图像演示了混合注意力-Mamba 模型中的诱导头的功能,这是 Jamba 的一个关键特征。在这个例子中,注意力头负责预测诸如“积极”或“消极”等标签,以响应情感分析任务。突出显示的单词说明了模型的注意力如何强烈地集中在少数示例中的标签令牌上,特别是在预测最终标签之前的关键时刻。这种注意力机制在模型的上下文学习能力中发挥着至关重要的作用,即模型必须根据给定的上下文和少数示例推断适当的标签。

通过将 Mixture-of-Experts (MoE) 与注意力-Mamba 混合架构集成,Jamba 提供了性能改进。通过使用 MoE,Jamba 增加了其能力,而不成比例地增加计算成本。这在各种基准测试中表现得尤为突出,例如 HellaSwag、WinoGrande 和自然问题 (NQ)。具有 MoE 的模型不仅实现了更高的准确率(例如,在 WinoGrande 上达到 66.0%,而没有 MoE 的模型仅达到 62.5%),而且在不同领域中展示了改进的对数概率(例如,在 C4 上达到 -0.534)。

关键架构特征

  • 层组合: Jamba 的架构由组合 Mamba 和 Transformer 层的块组成,按照特定的比例(例如,1:7,表示每七个 Mamba 层对应一个 Transformer 层)。此比例经过优化,以实现最佳性能和效率。
  • MoE 集成: MoE 层每隔几层应用一次,每个标记有 16 个专家可用,并激活前两位专家。这种配置使 Jamba 能够有效地扩展,同时管理内存使用和计算效率之间的权衡。
  • 归一化和稳定性: 为了确保训练过程中的稳定性,Jamba 在 Mamba 层中采用 RMSNorm,这有助于减轻大型激活脉冲可能带来的问题,例如在大规模训练中可能出现的问题。

Jamba 的性能和基准测试

Jamba 已经在广泛的基准测试中进行了严格的测试,展示了其在各个方面的竞争力。以下部分突出了 Jamba 在各种基准测试中表现出的优势,展示了其在通用 NLP 任务和长上下文场景中的优势。

1. 常见 NLP 基准测试

Jamba 已经在多个学术基准测试中进行了评估,包括:

  • HellaSwag (10 射击):一个常识推理任务,Jamba 实现了 87.1% 的性能分数,超过了许多竞争对手模型。
  • WinoGrande (5 射击):另一个推理任务,Jamba 得分 82.5%,再次展示了其处理复杂语言推理的能力。
  • ARC-Challenge (25 射击):Jamba 展示了强大的性能,得分 64.4%,反映了其处理具有挑战性的多项选择问题的能力。

在聚合基准测试中,如 MMLU (5 射击),Jamba 实现了 67.4% 的分数,表明其在多种任务中的稳健性。

2. 长上下文评估

Jamba 的一个突出特征是其处理极长上下文的能力。该模型支持最长 256K 个标记的上下文长度,在公开可用模型中排名第一。这种能力通过针对海量数据的基准测试进行了评估,Jamba 在不同上下文长度(包括 256K 个标记)下展示了卓越的检索准确率。

3. 吞吐量和效率

Jamba 的混合架构显著提高了吞吐量,特别是在处理长序列时。

在比较不同模型吞吐量(每秒标记数)的测试中,Jamba 在大批量和长上下文场景中始终优于其同行。例如,在 128K 标记的上下文中,Jamba 的吞吐量是 Mixtral(一个可比模型)的三倍。

使用 Jamba:Python

对于渴望尝试 Jamba 的开发人员和研究人员,AI21 Labs 已经在 Hugging Face 等平台上提供了该模型,使其可用于广泛的应用。以下代码片段演示了如何加载和使用 Jamba 生成文本:


from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

input_ids = tokenizer("最近的超级碗 LVIII,", return_tensors='pt').to(model.device)["input_ids"]

outputs = model.generate(input_ids, max_new_tokens=216)

print(tokenizer.batch_decode(outputs))

该简单脚本加载 Jamba 模型和分词器,根据给定的输入提示生成文本,并打印生成的输出。

微调 Jamba

Jamba 被设计为一个基础模型,这意味着它可以被微调以适应特定的任务或应用。微调允许用户将模型适应特定领域,提高模型在特定任务上的性能。以下示例展示了如何使用 PEFT 库微调 Jamba:

import torch
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
model = AutoModelForCausalLM.from_pretrained(
"ai21labs/Jamba-v0.1", device_map='auto', torch_dtype=torch.bfloat16)

lora_config = LoraConfig(r=8,
target_modules=[
"embed_tokens","x_proj", "in_proj", "out_proj", # mamba
"gate_proj", "up_proj", "down_proj", # mlp
"q_proj", "k_proj", "v_proj"
# attention],
task_type="CAUSAL_LM", bias="none")

dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = SFTConfig(output_dir="./results",
num_train_epochs=2,
per_device_train_batch_size=4,
logging_dir='./logs',
logging_steps=10, learning_rate=1e-5, dataset_text_field="quote")
trainer = SFTTrainer(model=model, tokenizer=tokenizer, args=training_args,
peft_config=lora_config, train_dataset=dataset,
)
trainer.train()

该代码片段微调 Jamba 以适应英语名言数据集,调整模型参数以更好地适应特定任务的文本生成。

部署和集成

AI21 Labs 通过各种平台和部署选项使 Jamba 家族广泛可用:

  1. 云平台
    • 可在主要云提供商(包括 Google Cloud Vertex AI、Microsoft Azure 和 NVIDIA NIM)上使用。
    • 即将在 Amazon Bedrock、Databricks Marketplace 和 Snowflake Cortex 上推出。
  2. AI 开发框架:
    • 与流行框架(如 LangChain 和 LlamaIndex(即将推出))集成。
  3. AI21 Studio:
    • 通过 AI21 自有的开发平台直接访问。
  4. Hugging Face
    • 可下载和实验的模型。
  5. 本地部署:
    • 针对具有特定安全或合规需求的组织的私有、现场部署选项。
  6. 定制解决方案:
    • AI21 为企业客户提供定制模型定制和微调服务。

面向开发人员的功能

Jamba 模型具有多个内置功能,使其对开发人员特别有吸引力:

  1. 函数调用:轻松将外部工具和 API 集成到您的 AI 工作流中。
  2. 结构化 JSON 输出:直接从自然语言输入生成干净、可解析的数据结构。
  3. 文档对象消化:高效处理和理解复杂文档结构。
  4. RAG 优化:内置功能以增强检索增强生成管道。

这些功能,结合模型的长上下文窗口和高效处理,使 Jamba 成为广泛开发场景中的多功能工具。

伦理考虑和负责任的 AI

虽然 Jamba 的能力令人印象深刻,但以负责任的 AI 心态来对待其使用至关重要。AI21 Labs 强调了几个重要点:

  1. 基础模型性质:Jamba 1.5 模型是预训练的基础模型,没有特定的对齐或指令微调。
  2. 缺乏内置保障措施:模型没有内置的审查机制。
  3. 谨慎部署:在生产环境或与最终用户一起使用 Jamba 之前,应实施额外的适应和保障措施。
  4. 数据隐私:在使用基于云的部署时,请注意数据处理和合规要求。
  5. 偏见意识:像所有大型语言模型一样,Jamba 可能会反映其训练数据中的偏见。用户应意识到这一点,并实施适当的缓解措施。

通过牢记这些因素,开发人员和组织可以负责任、合乎道德地利用 Jamba 的能力。

AI 开发的新篇章?

AI21 Labs 推出的 Jamba 家族标志着大型语言模型演进的重要里程碑。通过结合 Transformer 和状态空间模型的优势,集成 Mixture-of-Experts 技术,并突破上下文长度和处理速度的界限,Jamba 为各个行业的 AI 应用开启了新的可能性。

随着 AI 社区继续探索和在此创新架构基础上进行建设,我们可以期待在模型效率、长上下文理解和实际 AI 部署方面看到进一步的进步。Jamba 家族不仅代表了一组新模型,也代表了对大规模 AI 系统设计和实施方法的潜在转变。

我过去五年一直沉浸在令人着迷的机器学习和深度学习世界中。我的热情和专业知识使我能够为超过50个不同的软件工程项目做出贡献,特别注重人工智能/机器学习。我的持续好奇心也使我对自然语言处理产生了兴趣,这是一个我渴望进一步探索的领域。