Suivez nous sur

Attention : révolutionner l'efficacité des transformateurs

Intelligence Artificielle

Attention : révolutionner l'efficacité des transformateurs

mm
FlashAttention-3 : attention rapide et précise avec asynchronie et faible précision

À mesure que les modèles de transformateurs augmentent en taille et en complexité, ils sont confrontés à des défis importants en termes d’efficacité informatique et d’utilisation de la mémoire, en particulier lorsqu'il s'agit de longues séquences. Flash Attention est une technique d'optimisation qui promet de révolutionner la façon dont nous implémentons et faisons évoluer les mécanismes d'attention dans les modèles Transformer.

Dans ce guide complet, nous allons plonger en profondeur dans Flash Attention, en explorant ses concepts de base, ses détails de mise en œuvre et l'impact profond qu'il a sur le domaine de l'apprentissage automatique.

Le problème : l’attention coûte cher

Avant de nous plonger dans la solution, comprenons d'abord le problème que Flash Attention vise à résoudre. mécanisme d'attention, bien que puissant, s'accompagne d'un coût de calcul important, en particulier pour les longues séquences.

Attention standard : un récapitulatif rapide

Le mécanisme d'attention standard dans les modèles Transformer peut être résumé par l'équation suivante :

Attention(Q, K, V) = softmax(QK^T / √d) V

Où Q, K et V sont respectivement les matrices Requête, Clé et Valeur, et d est la dimension des vecteurs clés.

Bien que cette formulation soit élégante, sa mise en œuvre entraîne plusieurs inefficacités :

  1. Goulot d'étranglement de la mémoire: La matrice d'attention intermédiaire (QK^T) a une taille de N x N, où N est la longueur de la séquence. Pour les longues séquences, cela peut rapidement épuiser la mémoire GPU disponible.
  2. Accès à la mémoire redondante: Dans les implémentations standard, la matrice d'attention est calculée, stockée dans la mémoire à large bande passante (HBM), puis relue pour l'opération softmax. Cet accès mémoire redondant constitue un goulot d’étranglement majeur.
  3. Sous-utilisation du calcul GPULes GPU modernes ont une capacité de calcul (flops) nettement supérieure à la bande passante mémoire. L'implémentation standard de l'attention est limitée à la mémoire, ce qui laisse une grande partie du potentiel de calcul du GPU inexploité.

Illustrons cela avec un simple extrait de code Python qui montre l’implémentation standard de l’attention :

</pre>
import torch

def standard_attention(Q, K, V):
# Q, K, V shape: (batch_size, seq_len, d_model)
d_k = K.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k))
attention_weights = torch.softmax(scores, dim=-1)
return torch.matmul(attention_weights, V)

Cette mise en œuvre, bien que simple, souffre des inefficacités mentionnées ci-dessus. Le scores le tenseur, qui a une forme (batch_size, seq_len, seq_len), peut devenir d'une taille prohibitive pour les longues séquences.

Entrez l'attention flash

Attention éclair, présenté par Tri Dao et ses collègues dans leur article de 2022, il s'agit d'une approche de l'attention informatique qui réduit considérablement l'utilisation de la mémoire et améliore l'efficacité des calculs. Les idées clés derrière Flash Attention sont les suivantes :

  1. Carrelage: Décomposez la grande matrice d'attention en tuiles plus petites qui s'intègrent dans la SRAM rapide sur puce.
  2. Recalcul: Au lieu de stocker l’intégralité de la matrice d’attention, recalculez-en certaines parties selon les besoins lors du passage en arrière.
  3. Implémentation IO-Aware: optimisez l'algorithme pour minimiser le mouvement des données entre les différents niveaux de la hiérarchie de la mémoire GPU.

L'algorithme d'attention flash

À la base, Flash Attention réinvente la façon dont nous calculons le mécanisme d’attention. Au lieu de calculer l’intégralité de la matrice d’attention en une seule fois, il la traite par blocs, en exploitant la hiérarchie de mémoire des GPU modernes.

Voici un aperçu de haut niveau de l’algorithme :

  1. Entrée : Matrices Q, K, V en HBM (High Bandwidth Memory) et SRAM sur puce de taille M.
  2. Les tailles de bloc sont calculées en fonction de la SRAM disponible.
  3. Initialisation de la matrice de sortie O et des vecteurs auxiliaires l et m.
  4. L'algorithme divise les matrices d'entrée en blocs pour les adapter à la SRAM.
  5. Deux boucles imbriquées traitent ces blocs :
    • La boucle externe charge les blocs K et V
    • La boucle interne charge les blocs Q et effectue des calculs
  6. Les calculs sur puce incluent la multiplication matricielle, le softmax et le calcul de sortie.
  7. Les résultats sont réécrits dans HBM après le traitement de chaque bloc.

Ce calcul par blocs permet à Flash Attention de conserver une empreinte mémoire beaucoup plus petite tout en continuant à calculer une attention exacte.

Les mathématiques derrière l'attention flash

La clé pour faire fonctionner Flash Attention est une astuce mathématique qui nous permet de calculer softmax par blocs. L'article présente deux formules clés :

  1. Décomposition Softmax :
    softmax(x) = exp(x - m) / Σexp(x - m)

    où m est la valeur maximale de x.

  2. Fusion Softmax :
    softmax(x ∪ y) = softmax(softmax(x) * e^(m_x - m), softmax(y) * e^(m_y - m))

    où m = max(m_x, m_y)

Ces formules permettent à Flash Attention de calculer les résultats softmax partiels pour chaque bloc, puis de les combiner correctement pour obtenir le résultat final.

Détails d'implémentation

Plongeons dans une implémentation simplifiée de Flash Attention pour illustrer ses concepts de base :

import torch

def flash_attention(Q, K, V, block_size=256):
    batch_size, seq_len, d_model = Q.shape
    
    # Initialize output and running statistics
    O = torch.zeros_like(Q)
    L = torch.zeros((batch_size, seq_len, 1))
    M = torch.full((batch_size, seq_len, 1), float('-inf'))
    
    for i in range(0, seq_len, block_size):
        Q_block = Q[:, i:i+block_size, :]
        
        for j in range(0, seq_len, block_size):
            K_block = K[:, j:j+block_size, :]
            V_block = V[:, j:j+block_size, :]
            
            # Compute attention scores for this block
            S_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) / (d_model ** 0.5)
            
            # Update running max
            M_new = torch.maximum(M[:, i:i+block_size], S_block.max(dim=-1, keepdim=True).values)
            
            # Compute exponentials
            exp_S = torch.exp(S_block - M_new)
            exp_M_diff = torch.exp(M[:, i:i+block_size] - M_new)
            
            # Update running sum
            L_new = exp_M_diff * L[:, i:i+block_size] + exp_S.sum(dim=-1, keepdim=True)
            
            # Compute output for this block
            O[:, i:i+block_size] = (
                exp_M_diff * O[:, i:i+block_size] +
                torch.matmul(exp_S, V_block)
            ) / L_new
            
            # Update running statistics
            L[:, i:i+block_size] = L_new
            M[:, i:i+block_size] = M_new
    
    return O

Cette implémentation, bien que simplifiée, capture l'essence de Flash Attention. Il traite l'entrée en blocs, en maintenant les statistiques d'exécution (M et L) pour calculer correctement le softmax sur tous les blocs.

L'impact de l'attention flash

L'introduction de Flash Attention a eu un impact profond sur le domaine de l'apprentissage automatique, en particulier pour les grands modèles de langage et les applications à contexte long. Certains avantages clés incluent :

  1. Utilisation réduite de la mémoire: Flash Attention réduit la complexité de la mémoire de O(N^2) à O(N), où N est la longueur de la séquence. Cela permet de traiter des séquences beaucoup plus longues avec le même matériel.
  2. Vitesse améliorée: En minimisant le mouvement des données et en utilisant mieux les capacités de calcul du GPU, Flash Attention permet d'obtenir des accélérations significatives. Les auteurs signalent une formation jusqu'à 3 fois plus rapide pour GPT-2 par rapport aux implémentations standard.
  3. Calcul exact: Contrairement à d'autres techniques d'optimisation de l'attention, Flash Attention calcule l'attention exacte, et non une approximation.
  4. Évolutivité: L'empreinte mémoire réduite permet une mise à l'échelle vers des séquences beaucoup plus longues, potentiellement jusqu'à des millions de jetons.

Impact réel

L’impact de Flash Attention s’étend au-delà de la recherche universitaire. Il a été rapidement adopté dans de nombreuses bibliothèques et modèles d’apprentissage automatique populaires :

  • Transformateurs de visage étreignant: La populaire bibliothèque Transformers a intégré Flash Attention, permettant aux utilisateurs de tirer facilement parti de ses avantages.
  • GPT-4 et au-delà:Bien que cela ne soit pas confirmé, il existe des spéculations selon lesquelles des modèles de langage avancés comme GPT-4 pourraient utiliser des techniques similaires à Flash Attention pour gérer les contextes longs.
  • Modèles à contexte long: Flash Attention a permis une nouvelle génération de modèles capables de gérer des contextes extrêmement longs, tels que des modèles capables de traiter des livres entiers ou de longues vidéos.

FlashAttention : développements récents

Attention standard contre attention flash

Attention standard contre attention flash

FlashAttention-2

Forte du succès du Flash Attention original, la même équipe introduit FlashAttention-2 en 2023. Cette version mise à jour apporte plusieurs améliorations :

  1. Optimisation supplémentaire: FlashAttention-2 permet une utilisation encore meilleure du GPU, atteignant jusqu'à 70 % des FLOPS de pointe théoriques sur les GPU A100.
  2. Passe arrière amélioré: La passe arrière est optimisée pour être presque aussi rapide que la passe avant, ce qui entraîne des accélérations significatives à l'entraînement.
  3. Prise en charge de différentes variantes d'attention: FlashAttention-2 étend la prise en charge à diverses variantes d'attention, notamment l'attention aux requêtes groupées et l'attention aux requêtes multiples.

FlashAttention-3

Sorti en 2024, FlashAttention-3 représente la dernière avancée dans cette ligne de recherche. Il introduit plusieurs nouvelles techniques pour améliorer encore les performances :

  1. Calcul asynchrone: Tirer parti de la nature asynchrone des nouvelles instructions GPU pour chevaucher différents calculs.
  2. Prise en charge du 8e PC: Utilisation du calcul FP8 de faible précision pour un traitement encore plus rapide.
  3. Traitement incohérent: Une technique pour réduire l'erreur de quantification lors de l'utilisation de formats de faible précision.

Voici un exemple simplifié de la manière dont FlashAttention-3 pourrait exploiter le calcul asynchrone :

import torch
from torch.cuda.amp import autocast

def flash_attention_3(Q, K, V, block_size=256):
    with autocast(dtype=torch.float8):  # Using FP8 for computation
        # ... (similar to previous implementation)
        
        # Asynchronous computation example
        with torch.cuda.stream(torch.cuda.Stream()):
            # Compute GEMM asynchronously
            S_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) / (d_model ** 0.5)
        
        # Meanwhile, on the default stream:
        # Prepare for softmax computation
        
        # Synchronize streams
        torch.cuda.synchronize()
        
        # Continue with softmax and output computation
        # ...

    return O

Cet extrait de code illustre comment FlashAttention-3 peut exploiter le calcul asynchrone et la précision FP8. Notez qu'il s'agit d'un exemple simplifié et que la mise en œuvre réelle serait beaucoup plus complexe et spécifique au matériel.

Implémenter Flash Attention dans vos projets

Si vous souhaitez exploiter Flash Attention dans vos propres projets, plusieurs options s'offrent à vous :

  1. Utiliser les bibliothèques existantes: De nombreuses bibliothèques populaires comme Hugging Face Transformers incluent désormais des implémentations Flash Attention. Une simple mise à jour vers la dernière version et l'activation des indicateurs appropriés peuvent suffire.
  2. Implémentation personnalisée: Pour plus de contrôle ou des cas d'utilisation spécialisés, vous souhaiterez peut-être implémenter Flash Attention vous-même. La bibliothèque xformers fournit une bonne implémentation de référence.
  3. Optimisations spécifiques au matériel:Si vous travaillez avec du matériel spécifique (par exemple, les GPU NVIDIA H100), vous souhaiterez peut-être tirer parti des fonctionnalités spécifiques au matériel pour des performances maximales.

Voici un exemple de la façon dont vous pourriez utiliser Flash Attention avec la bibliothèque Hugging Face Transformers :

from transformers import AutoModel, AutoConfig

# Enable Flash Attention
config = AutoConfig.from_pretrained("bert-base-uncased")
config.use_flash_attention = True

# Load model with Flash Attention
model = AutoModel.from_pretrained("bert-base-uncased", config=config)

# Use the model as usual
# ...

Défis et orientations futures

Bien que Flash Attention ait fait des progrès significatifs dans l’amélioration de l’efficacité des mécanismes d’attention, il reste encore des défis et des domaines de recherche future :

  1. Spécificité matérielle: Les implémentations actuelles sont souvent optimisées pour des architectures GPU spécifiques. Généraliser ces optimisations sur différents matériels reste un défi.
  2. Intégration avec d'autres techniques: La combinaison de Flash Attention avec d'autres techniques d'optimisation telles que l'élagage, la quantification et la compression de modèle est un domaine de recherche actif.
  3. Extension à d'autres domaines: Bien que Flash Attention ait connu un grand succès en PNL, étendre ses avantages à d'autres domaines comme la vision par ordinateur et les modèles multimodaux est un effort continu.
  4. Compréhension théorique: Approfondir notre compréhension théorique des raisons pour lesquelles Flash Attention fonctionne si bien pourrait conduire à des optimisations encore plus puissantes.

Conclusion

 En exploitant intelligemment les hiérarchies de mémoire GPU et en employant des astuces mathématiques, Flash Attention permet d'améliorer considérablement la vitesse et l'utilisation de la mémoire sans sacrifier la précision.

Comme nous l'avons vu dans cet article, l'impact de Flash Attention va bien au-delà d'une simple technique d'optimisation. Il a permis le développement de modèles plus puissants et plus efficaces.

J'ai passé les cinq dernières années à m'immerger dans le monde fascinant du Machine Learning et du Deep Learning. Ma passion et mon expertise m'ont amené à contribuer à plus de 50 projets de génie logiciel divers, avec un accent particulier sur l'IA/ML. Ma curiosité continue m'a également attiré vers le traitement automatique du langage naturel, un domaine que j'ai hâte d'explorer davantage.