Connect with us

Jamba: AI21 Labs’ 새로운 하이브리드 Transformer-Mamba 언어 모델

인공지능

Jamba: AI21 Labs’ 새로운 하이브리드 Transformer-Mamba 언어 모델

mm
Jamba AI21 style, a sleek hybrid machine with glowing circuitry, merging Transformer and Mamba components, surrounded by swirling data streams and abstract neural connections, set against a futuristic backdrop with soft, ambient lighting

언어 모델은 빠른 발전을 거쳐 왔으며, Transformer 기반 아키텍처가 자연어 처리에서 선두를 지키고 있습니다. 그러나 모델이 커질수록, 긴 컨텍스트를 처리하는 것, 메모리 효율성, 처리량 등에 대한課題가 더 두드러지게 되었습니다.

AI21 Labs는 Jamba를紹介하여 새로운 솔루션을 제공했습니다. Jamba는 상태-of-the-art 대형 언어 모델(LLM)로, Transformer와 Mamba 아키텍처의 장점을 하이브리드 프레임워크에서 결합했습니다. 이 기사에서는 Jamba의 아키텍처, 성능, 및 잠재적인 응용 프로그램에 대해 설명합니다.

Jamba 개요

Jamba는 AI21 Labs에서 개발한 하이브리드 대형 언어 모델로, Transformer 레이어와 Mamba 레이어를 결합하고, Mixture-of-Experts (MoE) 모듈을 통합했습니다. 이 아키텍처는 Jamba가 메모리 사용, 처리량, 및 성능을 균형 있게 유지할 수 있도록 해줍니다. 모델은 단일 80GB GPU에서 작동하도록 설계되어 높은 처리량과 작은 메모리 풋프린트를 제공하면서 다양한 벤치마크에서 최첨단 성능을 유지합니다.

Jamba의 아키텍처

Jamba의 아키텍처는其의 능력의 핵심입니다. 그것은 새로운 하이브리드 설계를 기반으로 하며, Transformer 레이어와 Mamba 레이어를 교대로 배치하고, MoE 모듈을 통합하여 모델의 능력을 향상시키면서 계산 요구를 크게 증가시키지 않습니다.

1. Transformer 레이어

Transformer 아키텍처는 현대의 LLM에서 표준이 된 것으로, 병렬 처리를 효율적으로 처리하고, 긴 거리 의존성을 캡처하는 능력으로 인해 널리 사용되고 있습니다. 그러나 Transformer의 성능은 높은 메모리와 계산 요구로 인해 제한될 수 있습니다. 특히 긴 컨텍스트를 처리할 때 이러한 제한이 두드러집니다. Jamba는 이러한 제한을 해결하기 위해 Mamba 레이어를 통합했습니다.

2. Mamba 레이어

Mamba는 최근에 개발된 상태 공간 모델(SSM)로, 전통적인 RNN이나 Transformer보다 더 효율적으로 긴 거리 관계를 처리할 수 있습니다. Mamba 레이어는 Transformer에서 키-값 캐시를 저장하는 데 관련된 메모리 풋프린트를 줄이는 데 특히 효과적입니다. Transformer 레이어와 Mamba 레이어를 교대로 배치함으로써, Jamba는 전체 메모리 사용을 줄이면서 높은 성능을 유지합니다. 특히 긴 컨텍스트를 처리하는 태스크에서 이러한 효과가 두드러집니다.

3. Mixture-of-Experts (MoE) 모듈

Jamba의 MoE 모듈은 모델의 능력을 확장하는 유연한 접근 방식을 소개합니다. MoE는 모델이 활성화된 매개 변수의 수를 비례하지 않게 증가시키지 않으면서도 사용 가능한 매개 변수의 수를 증가시킬 수 있습니다. Jamba에서 MoE는 일부 MLP 레이어에 적용되며, 라우터 메커니즘은 각 토큰에 대해 활성화할 전문가를 선택합니다. 이러한 선택적 활성화는 Jamba가 복잡한 태스크를 처리하면서 높은 효율성을 유지할 수 있도록 합니다.

아래 이미지는 하이브리드 Attention-Mamba 모델의 유도 헤드의 기능을示しています. 이 예에서는 주의 헤드가 감성 분석 태스크에서 “Positive” 또는 “Negative”와 같은 레이블을 예측하는 데 책임이 있습니다. 강조된 단어는 모델의 주의가 몇 샷 예제에서 특히 레이블 토큰에 강하게 집중되어 있음을 보여줍니다. 이러한 주의 메커니즘은 모델이 컨텍스트와 몇 샷 예제를 기반으로 적절한 레이블을 추론하는 Ability에 중요한 역할을 합니다.

MoE를 Attention-Mamba 하이브리드 아키텍처와 통합함으로써 제공되는 성능 개선은 표에 강조되어 있습니다. MoE를 사용하여 Jamba는 계산 비용을 비례하지 않게 증가시키지 않으면서도 능력을 증가시킵니다. 이는 특히 HellaSwag, WinoGrande, 및 Natural Questions(NQ)와 같은 다양한 벤치마크에서 두드러집니다. MoE가 있는 모델은 더 높은 정확도(예: WinoGrande에서 66.0% 대 62.5%)를 달성하면서 다양한 도메인에서 로그 확률을 개선합니다(예: C4에서 -0.534).

키 아키텍처 특징

  • 레이어 구성: Jamba의 아키텍처는 Mamba와 Transformer 레이어를 특정 비율(예: 1:7, 즉 7개의 Mamba 레이어당 1개의 Transformer 레이어)로 결합한 블록으로 구성됩니다. 이 비율은 최적의 성능과 효율성을 위해 조정됩니다.
  • MoE 통합: MoE 레이어는 몇 개의 레이어마다 적용되며, 16개의 전문가가 사용 가능하며, 각 토큰당 상위 2개의 전문가를 활성화합니다. 이 구성은 Jamba가 효과적으로 확장하면서 메모리 사용과 계산 효율성 간의 트레이드오프를 관리할 수 있도록 합니다.
  • 정규화 및 안정성: 훈련 중에 안정성을 보장하기 위해 Jamba는 Mamba 레이어에서 RMSNorm을 통합합니다. 이는 큰 활성화 스파이크와 같은 문제를 완화하는 데 도움이 됩니다.

Jamba의 성능 및 벤치마크

Jamba는 다양한 벤치마크에서 철저하게 테스트되어 전반적으로 경쟁력 있는 성능을 보여주었습니다. 다음 섹션에서는 Jamba가 우수한 성능을 보여준 몇 가지 주요 벤치마크를 강조합니다.

1. 일반 NLP 벤치마크

Jamba는 여러 학술 벤치마크에서 평가되었습니다.

  • HellaSwag (10샷): 공통 감성 추론 태스크에서 Jamba는 87.1%의 성능 점수를 달성하여 많은 경쟁 모델을 앞섰습니다.
  • WinoGrande (5샷): 또 다른 추론 태스크에서 Jamba는 82.5%의 점수를 달성하여 복잡한 언어적 추론을 처리하는 Ability를 보여주었습니다.
  • ARC-Challenge (25샷): Jamba는 64.4%의 점수를 달성하여 어려운 다중 선택 질문을 관리하는 Ability를 보여주었습니다.

집계 벤치마크인 MMLU(5샷)에서 Jamba는 67.4%의 점수를 달성하여 다양한 태스크에서 강건함을 보여주었습니다.

2. 긴 컨텍스트 평가

Jamba의 주요 특징 중 하나는极めて 긴 컨텍스트를 처리하는 Ability입니다. 모델은 최대 256K 토큰의 컨텍스트 길이를 지원하며, 이는 공개적으로 사용 가능한 모델 중 가장 긴 것입니다. 이러한 Ability는 Needle-in-a-Haystack 벤치마크에서 테스트되었으며, Jamba는 다양한 컨텍스트 길이에서, 최대 256K 토큰에 이르기까지, 예외적인 검색 정확도를 보여주었습니다.

3. 처리량 및 효율성

Jamba의 하이브리드 아키텍처는 특히 긴 시퀀스에서 처리량을 크게 개선합니다.

다른 모델과 비교하여 처리량(초당 토큰 수)을 테스트한 결과, Jamba는 일관되게 동등한 모델을 앞섰습니다. 특히 대형 배치 크기와 긴 컨텍스트가 포함된 시나리오에서 이러한 효과가 두드러졌습니다. 예를 들어, 128K 토큰의 컨텍스트에서 Jamba는 Mixtral과 같은 모델의 3배의 처리량을 달성했습니다.

Python을 사용하여 Jamba 사용

개발자와 연구자들이 Jamba를 실험하기를 원하는 경우, AI21 Labs는 Hugging Face와 같은 플랫폼에서 모델을 제공하여 다양한 응용 프로그램에 쉽게 접근할 수 있도록 했습니다. 다음 코드 조각은 Jamba를 로드하고 텍스트를 생성하는 방법을示しています:


from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

input_ids = tokenizer("최근 Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]

outputs = model.generate(input_ids, max_new_tokens=216)

print(tokenizer.batch_decode(outputs))

이 간단한 스크립트는 Jamba 모델과 토크나이저를 로드하고, 주어진 입력 프롬프트에 따라 텍스트를 생성하며, 생성된 출력을 인쇄합니다.

Jamba 미세 조정

Jamba는 기본 모델로 설계되어 특정 태스크 또는 응용 프로그램을 위해 미세 조정이 가능합니다. 미세 조정은 사용자가 모델을 특정 도메인에 적응시키고, 전문 태스크에서 성능을 개선하는 것을 가능하게 합니다. 다음 예는 PEFT 라이브러리를 사용하여 Jamba를 미세 조정하는 방법을示합니다:

import torch
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
model = AutoModelForCausalLM.from_pretrained(
"ai21labs/Jamba-v0.1", device_map='auto', torch_dtype=torch.bfloat16)

lora_config = LoraConfig(r=8,
target_modules=[
"embed_tokens","x_proj","in_proj","out_proj", # mamba
"gate_proj","up_proj","down_proj", # mlp
"q_proj","k_proj","v_proj"
# attention],
task_type="CAUSAL_LM", bias="none")

dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = SFTConfig(output_dir="./results",
num_train_epochs=2,
per_device_train_batch_size=4,
logging_dir='./logs',
logging_steps=10, learning_rate=1e-5, dataset_text_field="quote")
trainer = SFTTrainer(model=model, tokenizer=tokenizer, args=training_args,
peft_config=lora_config, train_dataset=dataset,
)
trainer.train()

지난 5년 동안私は Machine Learning과 Deep Learning의 매력적인 세계에 몰두해 왔습니다.私の情熱と専門知識は、AI/ML에 중점을 둔 50개 이상의 다양한 소프트웨어 엔지니어링 프로젝트에 기여했습니다.私の継続的な 호기심은 또한 자연어 처리 분야로私の 관심을 끌었고, 더 깊이 탐구하고 싶은 분야입니다.