Suivez nous sur

Attention : rĂ©volutionner l'efficacitĂ© des transformateurs

Intelligence Artificielle

Attention : rĂ©volutionner l'efficacitĂ© des transformateurs

mm
FlashAttention-3 : attention rapide et prĂ©cise avec asynchronie et faible prĂ©cision

À mesure que les modĂšles de transformateurs augmentent en taille et en complexitĂ©, ils sont confrontĂ©s Ă  des dĂ©fis importants en termes d’efficacitĂ© informatique et d’utilisation de la mĂ©moire, en particulier lorsqu'il s'agit de longues sĂ©quences. Flash Attention est une technique d'optimisation qui promet de rĂ©volutionner la façon dont nous implĂ©mentons et faisons Ă©voluer les mĂ©canismes d'attention dans les modĂšles Transformer.

Dans ce guide complet, nous approfondirons Flash Attention, en explorant ses concepts fondamentaux, les dĂ©tails de sa mise en Ɠuvre et son impact profond sur le domaine de l'apprentissage automatique.

Le problĂšme : l’attention coĂ»te cher

Avant d'aborder la solution, comprenons d'abord le problÚme que Flash Attention vise à résoudre. Le mécanisme d'attention, bien que puissant, s'accompagne d'un coût de calcul important, en particulier pour les longues séquences.

Attention standard : un rĂ©capitulatif rapide

Le mĂ©canisme d'attention standard dans les modĂšles Transformer peut ĂȘtre rĂ©sumĂ© par l'Ă©quation suivante :

Attention(Q, K, V) = softmax(QK^T / √d) V

OĂč Q, K et V sont respectivement les matrices RequĂȘte, ClĂ© et Valeur, et d est la dimension des vecteurs clĂ©s.

Bien que cette formulation soit Ă©lĂ©gante, sa mise en Ɠuvre entraĂźne plusieurs inefficacitĂ©s :

  1. Goulot d'Ă©tranglement de la mĂ©moire: La matrice d'attention intermĂ©diaire (QK^T) a une taille de N x N, oĂč N est la longueur de la sĂ©quence. Pour les longues sĂ©quences, cela peut rapidement Ă©puiser la mĂ©moire GPU disponible.
  2. AccĂšs Ă  la mĂ©moire redondante: Dans les implĂ©mentations standard, la matrice d'attention est calculĂ©e, stockĂ©e dans la mĂ©moire Ă  large bande passante (HBM), puis relue pour l'opĂ©ration softmax. Cet accĂšs mĂ©moire redondant constitue un goulot d’étranglement majeur.
  3. Sous-utilisation du calcul GPU: Les GPU modernes ont une capacité de calcul (FLOPS) nettement supérieure à la bande passante mémoire. L'implémentation standard de l'attention est limitée à la mémoire, laissant une grande partie du potentiel de calcul du GPU inexploité.

Illustrons cela avec un simple extrait de code Python qui montre l'implĂ©mentation standard de l'attention :

</pre>
import torch

def standard_attention(Q, K, V):
# Q, K, V shape: (batch_size, seq_len, d_model)
d_k = K.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k))
attention_weights = torch.softmax(scores, dim=-1)
return torch.matmul(attention_weights, V)

Cette mise en Ɠuvre, bien que simple, souffre des inefficacitĂ©s mentionnĂ©es ci-dessus. Le scores le tenseur, qui a une forme (batch_size, seq_len, seq_len), peut devenir d'une taille prohibitive pour les longues sĂ©quences.

Entrez l'attention flash

Attention Ă©clair, prĂ©sentĂ© par Tri Dao et ses collĂšgues dans leur article de 2022, il s'agit d'une approche de l'attention informatique qui rĂ©duit considĂ©rablement l'utilisation de la mĂ©moire et amĂ©liore l'efficacitĂ© des calculs. Les idĂ©es clĂ©s derriĂšre Flash Attention sont les suivantes :

  1. Carrelage: Décomposez la grande matrice d'attention en tuiles plus petites qui s'intÚgrent dans la SRAM rapide sur puce.
  2. Recalcul: Au lieu de stocker l’intĂ©gralitĂ© de la matrice d’attention, recalculez-en certaines parties selon les besoins lors du passage en arriĂšre.
  3. Implémentation IO-Aware: optimisez l'algorithme pour minimiser le mouvement des données entre les différents niveaux de la hiérarchie de la mémoire GPU.

L'algorithme d'attention flash

À la base, Flash Attention rĂ©invente la façon dont nous calculons le mĂ©canisme d’attention. Au lieu de calculer l’intĂ©gralitĂ© de la matrice d’attention en une seule fois, il la traite par blocs, en exploitant la hiĂ©rarchie de mĂ©moire des GPU modernes.

Voici un aperçu gĂ©nĂ©ral de l'algorithme :

  1. Entrée : Matrices Q, K, V en HBM (High Bandwidth Memory) et SRAM sur puce de taille M.
  2. Les tailles de bloc sont calculées en fonction de la SRAM disponible.
  3. Initialisation de la matrice de sortie O et des vecteurs auxiliaires l et m.
  4. L'algorithme divise les matrices d'entrée en blocs pour les adapter à la SRAM.
  5. Deux boucles imbriquĂ©es traitent ces blocs :
    • La boucle externe charge les blocs K et V
    • La boucle interne charge les blocs Q et effectue des calculs
  6. Les calculs sur puce incluent la multiplication matricielle, le softmax et le calcul de sortie.
  7. Les résultats sont réécrits dans HBM aprÚs le traitement de chaque bloc.

Ce calcul par blocs permet à Flash Attention de conserver une empreinte mémoire beaucoup plus petite tout en continuant à calculer une attention exacte.

Les mathématiques derriÚre l'attention flash

La clĂ© pour faire fonctionner Flash Attention est une astuce mathĂ©matique qui nous permet de calculer softmax par blocs. L'article prĂ©sente deux formules clĂ©s :

  1. DĂ©composition Softmax :
    softmax(x) = exp(x - m) / ÎŁexp(x - m)

    oĂč m est la valeur maximale de x.

  2. Fusion Softmax :
    softmax(x âˆȘ y) = softmax(softmax(x) * e^(m_x - m), softmax(y) * e^(m_y - m))

    oĂč m = max(m_x, m_y)

Ces formules permettent à Flash Attention de calculer les résultats softmax partiels pour chaque bloc, puis de les combiner correctement pour obtenir le résultat final.

Détails d'implémentation

Plongeons dans une implĂ©mentation simplifiĂ©e de Flash Attention pour illustrer ses concepts fondamentaux :

import torch

def flash_attention(Q, K, V, block_size=256):
    batch_size, seq_len, d_model = Q.shape
    
    # Initialize output and running statistics
    O = torch.zeros_like(Q)
    L = torch.zeros((batch_size, seq_len, 1))
    M = torch.full((batch_size, seq_len, 1), float('-inf'))
    
    for i in range(0, seq_len, block_size):
        Q_block = Q[:, i:i+block_size, :]
        
        for j in range(0, seq_len, block_size):
            K_block = K[:, j:j+block_size, :]
            V_block = V[:, j:j+block_size, :]
            
            # Compute attention scores for this block
            S_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) / (d_model ** 0.5)
            
            # Update running max
            M_new = torch.maximum(M[:, i:i+block_size], S_block.max(dim=-1, keepdim=True).values)
            
            # Compute exponentials
            exp_S = torch.exp(S_block - M_new)
            exp_M_diff = torch.exp(M[:, i:i+block_size] - M_new)
            
            # Update running sum
            L_new = exp_M_diff * L[:, i:i+block_size] + exp_S.sum(dim=-1, keepdim=True)
            
            # Compute output for this block
            O[:, i:i+block_size] = (
                exp_M_diff * O[:, i:i+block_size] +
                torch.matmul(exp_S, V_block)
            ) / L_new
            
            # Update running statistics
            L[:, i:i+block_size] = L_new
            M[:, i:i+block_size] = M_new
    
    return O

Cette implémentation, bien que simplifiée, capture l'essence de Flash Attention. Il traite l'entrée en blocs, en maintenant les statistiques d'exécution (M et L) pour calculer correctement le softmax sur tous les blocs.

L'impact de l'attention flash

L'introduction de Flash Attention a eu un impact profond sur le domaine de l'apprentissage automatique, en particulier pour les grands modĂšles de langage et les applications Ă  contexte long. Certains avantages clĂ©s incluent :

  1. Utilisation rĂ©duite de la mĂ©moire: Flash Attention rĂ©duit la complexitĂ© de la mĂ©moire de O(N^2) Ă  O(N), oĂč N est la longueur de la sĂ©quence. Cela permet de traiter des sĂ©quences beaucoup plus longues avec le mĂȘme matĂ©riel.
  2. Vitesse améliorée: En minimisant le mouvement des données et en utilisant mieux les capacités de calcul du GPU, Flash Attention permet d'obtenir des accélérations significatives. Les auteurs signalent une formation jusqu'à 3 fois plus rapide pour GPT-2 par rapport aux implémentations standard.
  3. Calcul exact: Contrairement Ă  d'autres techniques d'optimisation de l'attention, Flash Attention calcule l'attention exacte, et non une approximation.
  4. ÉvolutivitĂ©: L'empreinte mĂ©moire rĂ©duite permet une mise Ă  l'Ă©chelle vers des sĂ©quences beaucoup plus longues, potentiellement jusqu'Ă  des millions de jetons.

Impact réel

L’impact de Flash Attention s’étend au-delĂ  de la recherche universitaire. Il a Ă©tĂ© rapidement adoptĂ© dans de nombreuses bibliothĂšques et modĂšles d’apprentissage automatique populaires :

  • Transformateurs de visage Ă©treignant: La populaire bibliothĂšque Transformers a intĂ©grĂ© Flash Attention, permettant aux utilisateurs de tirer facilement parti de ses avantages.
  • GPT-4 et au-delĂ : Bien que cela ne soit pas confirmĂ©, il existe des spĂ©culations selon lesquelles des modĂšles de langage avancĂ©s tels que GPT-4 pourraient utiliser des techniques similaires Ă  Flash Attention pour gĂ©rer des contextes longs.
  • ModĂšles Ă  contexte long: Flash Attention a permis une nouvelle gĂ©nĂ©ration de modĂšles capables de gĂ©rer des contextes extrĂȘmement longs, tels que des modĂšles capables de traiter des livres entiers ou de longues vidĂ©os.

FlashAttention : dĂ©veloppements rĂ©cents

Attention standard contre attention flash

Attention standard contre attention flash

FlashAttention-2

Forte du succĂšs du Flash Attention original, la mĂȘme Ă©quipe introduit FlashAttention-2 en 2023. Cette version mise Ă  jour apporte plusieurs amĂ©liorations :

  1. Optimisation supplémentaire: FlashAttention-2 permet une utilisation encore meilleure du GPU, atteignant jusqu'à 70 % des FLOPS de pointe théoriques sur les GPU A100.
  2. Passe arriĂšre amĂ©liorĂ©: La passe arriĂšre est optimisĂ©e pour ĂȘtre presque aussi rapide que la passe avant, ce qui entraĂźne des accĂ©lĂ©rations significatives Ă  l'entraĂźnement.
  3. Prise en charge de diffĂ©rentes variantes d'attention: FlashAttention-2 Ă©tend la prise en charge Ă  diverses variantes d'attention, notamment l'attention aux requĂȘtes groupĂ©es et l'attention aux requĂȘtes multiples.

FlashAttention-3

Sorti en 2024, FlashAttention-3 reprĂ©sente la derniĂšre avancĂ©e dans cette ligne de recherche. Il introduit plusieurs nouvelles techniques pour amĂ©liorer encore les performances :

  1. Calcul asynchrone: Tirer parti de la nature asynchrone des nouvelles instructions GPU pour chevaucher différents calculs.
  2. Prise en charge du 8e PC: Utilisation du calcul FP8 de faible précision pour un traitement encore plus rapide.
  3. Traitement incohérent: Une technique pour réduire l'erreur de quantification lors de l'utilisation de formats de faible précision.

Voici un exemple simplifiĂ© de la maniĂšre dont FlashAttention-3 peut exploiter le calcul asynchrone :

import torch
from torch.cuda.amp import autocast

def flash_attention_3(Q, K, V, block_size=256):
    with autocast(dtype=torch.float8):  # Using FP8 for computation
        # ... (similar to previous implementation)
        
        # Asynchronous computation example
        with torch.cuda.stream(torch.cuda.Stream()):
            # Compute GEMM asynchronously
            S_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) / (d_model ** 0.5)
        
        # Meanwhile, on the default stream:
        # Prepare for softmax computation
        
        # Synchronize streams
        torch.cuda.synchronize()
        
        # Continue with softmax and output computation
        # ...

    return O

Cet extrait de code illustre comment FlashAttention-3 peut exploiter le calcul asynchrone et la prĂ©cision FP8. Notez qu'il s'agit d'un exemple simplifiĂ© et que la mise en Ɠuvre rĂ©elle serait beaucoup plus complexe et spĂ©cifique au matĂ©riel.

Implémenter Flash Attention dans vos projets

Si vous souhaitez tirer parti de Flash Attention dans vos propres projets, vous disposez de plusieurs options :

  1. Utiliser les bibliothÚques existantes: De nombreuses bibliothÚques populaires comme Hugging Face Transformers incluent désormais des implémentations Flash Attention. Une simple mise à jour vers la derniÚre version et l'activation des indicateurs appropriés peuvent suffire.
  2. ImplĂ©mentation personnalisĂ©e: Pour plus de contrĂŽle ou des cas d'utilisation spĂ©cialisĂ©s, vous souhaiterez peut-ĂȘtre implĂ©menter Flash Attention vous-mĂȘme. La bibliothĂšque xformers fournit une bonne implĂ©mentation de rĂ©fĂ©rence.
  3. Optimisations spĂ©cifiques au matĂ©riel: Si vous travaillez avec du matĂ©riel spĂ©cifique (par exemple, les GPU NVIDIA H100), vous souhaiterez peut-ĂȘtre tirer parti des fonctionnalitĂ©s spĂ©cifiques au matĂ©riel pour des performances maximales.

Voici un exemple de la façon dont vous pouvez utiliser Flash Attention avec la bibliothĂšque Hugging Face Transformers :

from transformers import AutoModel, AutoConfig

# Enable Flash Attention
config = AutoConfig.from_pretrained("bert-base-uncased")
config.use_flash_attention = True

# Load model with Flash Attention
model = AutoModel.from_pretrained("bert-base-uncased", config=config)

# Use the model as usual
# ...

Défis et orientations futures

Bien que Flash Attention ait fait des progrĂšs significatifs dans l’amĂ©lioration de l’efficacitĂ© des mĂ©canismes d’attention, il reste encore des dĂ©fis et des domaines de recherche future :

  1. Spécificité matérielle: Les implémentations actuelles sont souvent optimisées pour des architectures GPU spécifiques. Généraliser ces optimisations sur différents matériels reste un défi.
  2. Intégration avec d'autres techniques: La combinaison de Flash Attention avec d'autres techniques d'optimisation telles que l'élagage, la quantification et la compression de modÚle est un domaine de recherche actif.
  3. Extension à d'autres domaines: Bien que Flash Attention ait connu un grand succÚs en PNL, étendre ses avantages à d'autres domaines comme la vision par ordinateur et les modÚles multimodaux est un effort continu.
  4. Compréhension théorique: Approfondir notre compréhension théorique des raisons pour lesquelles Flash Attention fonctionne si bien pourrait conduire à des optimisations encore plus puissantes.

Conclusion

 En exploitant intelligemment les hiĂ©rarchies de mĂ©moire GPU et en employant des astuces mathĂ©matiques, Flash Attention permet d'amĂ©liorer considĂ©rablement la vitesse et l'utilisation de la mĂ©moire sans sacrifier la prĂ©cision.

Comme nous l'avons exploré dans cet article, l'impact de Flash Attention s'étend bien au-delà d'une simple technique d'optimisation. Elle a permis le développement de modÚles plus puissants et plus efficaces.

J'ai passé les cinq derniÚres années à m'immerger dans le monde fascinant du Machine Learning et du Deep Learning. Ma passion et mon expertise m'ont amené à contribuer à plus de 50 projets de génie logiciel divers, avec un accent particulier sur l'IA/ML. Ma curiosité continue m'a également attiré vers le traitement automatique du langage naturel, un domaine que j'ai hùte d'explorer davantage.