Intelligence artificielle

Jamba : Le nouveau modèle hybride Transformer-Mamba de AI21 Labs

mm
Jamba AI21 style, a sleek hybrid machine with glowing circuitry, merging Transformer and Mamba components, surrounded by swirling data streams and abstract neural connections, set against a futuristic backdrop with soft, ambient lighting

Les modèles de langage ont connu des progrès rapides, avec les architectures basées sur les Transformers à la tête de la procession dans le traitement automatique des langues. Cependant, à mesure que les modèles s’étendent, les défis liés à la gestion de longs contextes, à l’efficacité de la mémoire et au débit sont devenus plus prononcés.

AI21 Labs a introduit une nouvelle solution avec Jamba, un modèle de langage à grande échelle (LLM) d’état de l’art qui combine les forces des architectures Transformer et Mamba dans un cadre hybride. Cet article détaille l’architecture, les performances et les applications potentielles de Jamba.

Présentation de Jamba

Jamba est un modèle de langage hybride développé par AI21 Labs, exploitant une combinaison de couches Transformer et de couches Mamba, intégrées avec un module Mixture-of-Experts (MoE). Cette architecture permet à Jamba d’équilibrer l’utilisation de la mémoire, le débit et les performances, ce qui en fait un outil puissant pour une large gamme de tâches de traitement automatique des langues. Le modèle est conçu pour tenir dans une seule carte graphique de 80 Go, offrant un débit élevé et une petite empreinte mémoire tout en maintenant des performances de pointe sur divers benchmarks.

L’architecture de Jamba

L’architecture de Jamba est la clé de ses capacités. Elle repose sur une conception hybride novatrice qui intercale des couches Transformer avec des couches Mamba, en incorporant des modules MoE pour améliorer la capacité du modèle sans augmenter considérablement les exigences de calcul.

1. Couches Transformer

L’architecture Transformer est devenue la norme pour les modèles de langage modernes en raison de sa capacité à gérer le traitement parallèle de manière efficace et à capturer les dépendances à longue portée dans le texte. Cependant, ses performances sont souvent limitées par des exigences élevées en termes de mémoire et de calcul, en particulier lors du traitement de longs contextes. Jamba répond à ces limitations en intégrant des couches Mamba, que nous allons explorer ensuite.

2. Couches Mamba

Mamba est un modèle d’espace d’état (SSM) récent conçu pour gérer les relations à longue distance dans les séquences de manière plus efficace que les RNN traditionnels ou même les Transformers. Les couches Mamba sont particulièrement efficaces pour réduire l’empreinte mémoire associée au stockage des caches de clés-valeurs (KV) dans les Transformers. En intercalant des couches Mamba avec des couches Transformer, Jamba réduit l’utilisation globale de la mémoire tout en maintenant de hautes performances, en particulier dans les tâches nécessitant la gestion de longs contextes.

3. Modules Mixture-of-Experts (MoE)

Le module MoE dans Jamba introduit une approche flexible pour mettre à l’échelle la capacité du modèle. MoE permet au modèle d’augmenter le nombre de paramètres disponibles sans augmenter proportionnellement les paramètres actifs pendant l’inférence. Dans Jamba, MoE est appliqué à certaines des couches MLP, avec le mécanisme de routage sélectionnant les meilleurs experts à activer pour chaque jeton. Cette activation sélective permet à Jamba de maintenir une haute efficacité tout en gérant des tâches complexes.

L’image ci-dessous démontre le fonctionnement d’une tête d’induction dans un modèle hybride Attention-Mamba, une fonctionnalité clé de Jamba. Dans cet exemple, la tête d’attention est responsable de prédire des étiquettes telles que “Positif” ou “Négatif” en réponse à des tâches d’analyse de sentiments. Les mots mis en évidence illustrent comment l’attention du modèle est fortement concentrée sur les jetons d’étiquette à partir des exemples à quelques coups, en particulier au moment critique avant de prédire l’étiquette finale. Ce mécanisme d’attention joue un rôle crucial dans la capacité du modèle à effectuer un apprentissage en contexte, où le modèle doit déduire l’étiquette appropriée en fonction du contexte donné et des exemples à quelques coups.

Les améliorations de performances offertes par l’intégration de Mixture-of-Experts (MoE) avec l’architecture hybride Attention-Mamba sont mises en évidence dans le tableau. En utilisant MoE, Jamba augmente sa capacité sans augmenter proportionnellement les coûts de calcul. C’est particulièrement évident dans l’augmentation significative des performances sur divers benchmarks tels que HellaSwag, WinoGrande et Natural Questions (NQ). Le modèle avec MoE non seulement atteint une précision plus élevée (par exemple, 66,0 % sur WinoGrande par rapport à 62,5 % sans MoE), mais démontre également des log-probabilités améliorées sur différents domaines (par exemple, -0,534 sur C4).

Caractéristiques architecturales clés

  • Composition des couches : L’architecture de Jamba se compose de blocs qui combinent les couches Mamba et les couches Transformer dans un rapport spécifique (par exemple, 1:7, ce qui signifie une couche Transformer pour sept couches Mamba). Ce rapport est réglé pour des performances et une efficacité optimales.
  • Intégration de MoE : Les couches MoE sont appliquées toutes les quelques couches, avec 16 experts disponibles et les deux meilleurs experts activés par jeton. Cette configuration permet à Jamba de mettre à l’échelle de manière efficace tout en gérant les compromis entre l’utilisation de la mémoire et l’efficacité de calcul.
  • Normalisation et stabilité : Pour assurer la stabilité pendant l’entraînement, Jamba intègre RMSNorm dans les couches Mamba, ce qui aide à atténuer les problèmes tels que les pics d’activation importants qui peuvent survenir à grande échelle.

Performances et évaluation de Jamba

Jamba a été rigoureusement testé sur une large gamme de benchmarks, démontrant des performances compétitives dans l’ensemble. Les sections suivantes mettent en évidence certains des principaux benchmarks sur lesquels Jamba a excellé, mettant en valeur ses forces dans les tâches de traitement automatique des langues générales et les scénarios de long contexte.

1. Benchmarks de traitement automatique des langues courants

Jamba a été évalué sur plusieurs benchmarks universitaires, notamment :

  • HellaSwag (10 coups) : Une tâche de raisonnement basé sur le bon sens où Jamba a obtenu un score de performance de 87,1 %, surpassant de nombreux modèles concurrents.
  • WinoGrande (5 coups) : Une autre tâche de raisonnement où Jamba a obtenu un score de 82,5 %, mettant à nouveau en évidence sa capacité à gérer des raisonnements linguistiques complexes.
  • ARC-Challenge (25 coups) : Jamba a démontré de fortes performances avec un score de 64,4 %, reflétant sa capacité à gérer des questions à choix multiples difficiles.

Dans les benchmarks agrégés comme MMLU (5 coups), Jamba a obtenu un score de 67,4 %, indiquant sa robustesse sur diverses tâches.

2. Évaluations de long contexte

L’une des fonctionnalités de Jamba est sa capacité à gérer des contextes extrêmement longs. Le modèle prend en charge une longueur de contexte allant jusqu’à 256 K jetons, la plus longue parmi les modèles disponibles publiquement. Cette capacité a été testée à l’aide du benchmark Needle-in-a-Haystack, où Jamba a montré une précision de récupération exceptionnelle sur différentes longueurs de contexte, allant jusqu’à 256 K jetons.

3. Débit et efficacité

L’architecture hybride de Jamba améliore considérablement le débit, en particulier avec de longues séquences.

Dans les tests comparant le débit (jetons par seconde) entre différents modèles, Jamba a constamment surpassé ses pairs, en particulier dans les scénarios impliquant de grandes tailles de lots et de longs contextes. Par exemple, avec un contexte de 128 K jetons, Jamba a atteint 3 fois le débit de Mixtral, un modèle comparable.

Utilisation de Jamba : Python

Pour les développeurs et les chercheurs impatients d’expérimenter avec Jamba, AI21 Labs a mis le modèle à disposition sur des plateformes comme Hugging Face, le rendant accessible pour une large gamme d’applications. Le code suivant montre comment charger et générer du texte à l’aide de Jamba :


<p>from transformers import AutoModelForCausalLM, AutoTokenizer</p>

<p>model = AutoModelForCausalLM.from_pretrained(&quot;ai21labs/Jamba-v0.1&quot;)
tokenizer = AutoTokenizer.from_pretrained(&quot;ai21labs/Jamba-v0.1&quot;)</p>

<p>input_ids = tokenizer(&quot;Dans le Super Bowl LVIII récent,&quot;, return_tensors=&#039;pt&#039;).to(model.device)[&quot;input_ids&quot;]</p>

<p>outputs = model.generate(input_ids, max_new_tokens=216)</p>

print(tokenizer.batch_decode(outputs))

Ce script simple charge le modèle Jamba et le tokenizer, génère du texte en fonction d’une invite de saisie donnée et imprime la sortie générée.

Affiner Jamba

Jamba est conçu comme un modèle de base, ce qui signifie qu’il peut être affiné pour des tâches ou des applications spécifiques. L’affinement permet aux utilisateurs d’adapter le modèle à des domaines de niche, améliorant ainsi les performances sur des tâches spécialisées. L’exemple suivant montre comment affiner Jamba à l’aide de la bibliothèque PEFT :

import torch
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

<p>tokenizer = AutoTokenizer.from_pretrained(&quot;ai21labs/Jamba-v0.1&quot;)
model = AutoModelForCausalLM.from_pretrained(
&quot;ai21labs/Jamba-v0.1&quot;, device_map=&#039;auto&#039;, torch_dtype=torch.bfloat16)</p>

<p>lora_config = LoraConfig(r=8,
target_modules=[
&quot;embed_tokens&quot;,&quot;x_proj&quot;, &quot;in_proj&quot;, &quot;out_proj&quot;, # mamba
&quot;gate_proj&quot;, &quot;up_proj&quot;, &quot;down_proj&quot;, # mlp
&quot;q_proj&quot;, &quot;k_proj&quot;, &quot;v_proj&quot;
# attention],
task_type=&quot;CAUSAL_LM&quot;, bias=&quot;none&quot;)</p>

<p>dataset = load_dataset(&quot;Abirate/english_quotes&quot;, split=&quot;train&quot;)
training_args = SFTConfig(output_dir=&quot;./results&quot;,
num_train_epochs=2,
per_device_train_batch_size=4,
logging_dir=&#039;./logs&#039;,
logging_steps=10, learning_rate=1e-5, dataset_text_field=&quot;quote&quot;)
trainer = SFTTrainer(model=model, tokenizer=tokenizer, args=training_args,
peft_config=lora_config, train_dataset=dataset,
)
trainer.train()

Ce code affine Jamba sur un jeu de données de citations anglaises, ajustant les paramètres du modèle pour mieux s’adapter à la tâche spécifique de génération de texte dans un domaine spécialisé.

Déploiement et intégration

AI21 Labs a rendu la famille Jamba largement accessible via diverses plateformes et options de déploiement :

  1. Plateformes cloud :
    • Disponible sur les principaux fournisseurs de cloud, y compris Google Cloud Vertex AI, Microsoft Azure et NVIDIA NIM.
    • Bientôt disponible sur Amazon Bedrock, Databricks Marketplace et Snowflake Cortex.
  2. Cadres de développement d’IA :
    • Intégration avec des cadres populaires comme LangChain et LlamaIndex (à venir).
  3. AI21 Studio :
    • Accès direct via la plateforme de développement d’AI21.
  4. Hugging Face :
    • Modèles disponibles pour téléchargement et expérimentation.
  5. Déploiement sur site :
    • Options pour un déploiement privé sur site pour les organisations ayant des besoins spécifiques en matière de sécurité ou de conformité.
  6. Solutions personnalisées :
    • AI21 propose des services de personnalisation et d’affinement de modèle pour les clients d’entreprise.

Fonctionnalités conviviales pour les développeurs

Les modèles Jamba disposent de plusieurs fonctionnalités intégrées qui les rendent particulièrement attrayants pour les développeurs :

  1. Appel de fonction : Intégrez facilement des outils et des API externes dans vos flux de travail d’IA.
  2. Sortie JSON structurée : Générez des structures de données propres et analysables directement à partir de saisies de langage naturel.
  3. Digestion d’objets de document : Traitez et comprenez efficacement des structures de document complexes.
  4. Optimisations RAG : Fonctionnalités intégrées pour améliorer les pipelines de génération augmentée de récupération.

Ces fonctionnalités, combinées avec la fenêtre de contexte longue du modèle et le traitement efficace, font de Jamba un outil polyvalent pour une large gamme de scénarios de développement.

Considérations éthiques et IA responsable

Bien que les capacités de Jamba soient impressionnantes, il est crucial d’aborder son utilisation avec une mentalité d’IA responsable. AI21 Labs met l’accent sur plusieurs points importants :

  1. Nature du modèle de base : Les modèles Jamba 1.5 sont des modèles de base préentraînés sans alignement ou réglage d’instruction spécifique.
  2. Manque de mécanismes de modération intégrés : Les modèles ne disposent pas de mécanismes de modération inhérents.
  3. Déploiement prudent : Des adaptations et des mécanismes de sécurité supplémentaires doivent être mis en œuvre avant d’utiliser Jamba dans des environnements de production ou avec des utilisateurs finals.
  4. Confidentialité des données : Lors de l’utilisation de déploiements basés sur le cloud, soyez attentif à la gestion et aux exigences de conformité des données.
  5. Conscience des biais : Comme tous les grands modèles de langage, Jamba peut refléter les biais présents dans ses données d’entraînement. Les utilisateurs doivent en être conscients et mettre en œuvre des mesures d’atténuation appropriées.

En tenant compte de ces facteurs, les développeurs et les organisations peuvent exploiter les capacités de Jamba de manière responsable et éthique.

Un nouveau chapitre dans le développement de l’IA ?

L’introduction de la famille Jamba par AI21 Labs marque un jalon important dans l’évolution des grands modèles de langage. En combinant les forces des Transformers et des modèles d’espace d’état, en intégrant des techniques de Mixture-of-Experts et en repoussant les limites de la longueur du contexte et de la vitesse de traitement, Jamba ouvre de nouvelles perspectives pour les applications d’IA à travers les industries.

À mesure que la communauté de l’IA continue d’explorer et de construire sur cette architecture innovante, nous pouvons nous attendre à voir de nouvelles avancées en termes d’efficacité du modèle, de compréhension de long contexte et de déploiement pratique de l’IA. La famille Jamba représente non seulement un nouvel ensemble de modèles, mais également un potentiel changement de cap dans la conception et la mise en œuvre de grands systèmes d’IA.

J'ai passé les cinq dernières années à plonger dans le monde fascinant de l'apprentissage automatique et du deep learning. Ma passion et mon expertise m'ont conduit à contribuer à plus de 50 projets de génie logiciel divers, avec un focus particulier sur l'IA/ML. Ma curiosité continue m'a également attiré vers le traitement automatique des langues, un domaine que je suis impatient d'explorer plus en profondeur.