Intelligence Artificielle
Attention : révolutionner l'efficacité des transformateurs

Cette mise en Ćuvre, bien que simple, souffre des inefficacitĂ©s mentionnĂ©es ci-dessus. Le scores
le tenseur, qui a une forme (batch_size, seq_len, seq_len), peut devenir d'une taille prohibitive pour les longues séquences.
Entrez l'attention flash
Attention éclair, présenté par Tri Dao et ses collÚgues dans leur article de 2022, il s'agit d'une approche de l'attention informatique qui réduit considérablement l'utilisation de la mémoire et améliore l'efficacité des calculs. Les idées clés derriÚre Flash Attention sont les suivantes :
- Carrelage: Décomposez la grande matrice d'attention en tuiles plus petites qui s'intÚgrent dans la SRAM rapide sur puce.
- Recalcul: Au lieu de stocker lâintĂ©gralitĂ© de la matrice dâattention, recalculez-en certaines parties selon les besoins lors du passage en arriĂšre.
- Implémentation IO-Aware: optimisez l'algorithme pour minimiser le mouvement des données entre les différents niveaux de la hiérarchie de la mémoire GPU.
L'algorithme d'attention flash
Ă la base, Flash Attention rĂ©invente la façon dont nous calculons le mĂ©canisme dâattention. Au lieu de calculer lâintĂ©gralitĂ© de la matrice dâattention en une seule fois, il la traite par blocs, en exploitant la hiĂ©rarchie de mĂ©moire des GPU modernes.
Voici un aperçu général de l'algorithme :
- Entrée : Matrices Q, K, V en HBM (High Bandwidth Memory) et SRAM sur puce de taille M.
- Les tailles de bloc sont calculées en fonction de la SRAM disponible.
- Initialisation de la matrice de sortie O et des vecteurs auxiliaires l et m.
- L'algorithme divise les matrices d'entrée en blocs pour les adapter à la SRAM.
- Deux boucles imbriquées traitent ces blocs :
- La boucle externe charge les blocs K et V
- La boucle interne charge les blocs Q et effectue des calculs
- Les calculs sur puce incluent la multiplication matricielle, le softmax et le calcul de sortie.
- Les résultats sont réécrits dans HBM aprÚs le traitement de chaque bloc.
Ce calcul par blocs permet à Flash Attention de conserver une empreinte mémoire beaucoup plus petite tout en continuant à calculer une attention exacte.
Les mathématiques derriÚre l'attention flash
La clé pour faire fonctionner Flash Attention est une astuce mathématique qui nous permet de calculer softmax par blocs. L'article présente deux formules clés :
- Décomposition Softmax :
softmax(x) = exp(x - m) / ÎŁexp(x - m)
oĂč m est la valeur maximale de x.
- Fusion Softmax :
softmax(x âȘ y) = softmax(softmax(x) * e^(m_x - m), softmax(y) * e^(m_y - m))
oĂč m = max(m_x, m_y)
Ces formules permettent à Flash Attention de calculer les résultats softmax partiels pour chaque bloc, puis de les combiner correctement pour obtenir le résultat final.
Détails d'implémentation
Plongeons dans une implémentation simplifiée de Flash Attention pour illustrer ses concepts fondamentaux :
import torch def flash_attention(Q, K, V, block_size=256): batch_size, seq_len, d_model = Q.shape # Initialize output and running statistics O = torch.zeros_like(Q) L = torch.zeros((batch_size, seq_len, 1)) M = torch.full((batch_size, seq_len, 1), float('-inf')) for i in range(0, seq_len, block_size): Q_block = Q[:, i:i+block_size, :] for j in range(0, seq_len, block_size): K_block = K[:, j:j+block_size, :] V_block = V[:, j:j+block_size, :] # Compute attention scores for this block S_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) / (d_model ** 0.5) # Update running max M_new = torch.maximum(M[:, i:i+block_size], S_block.max(dim=-1, keepdim=True).values) # Compute exponentials exp_S = torch.exp(S_block - M_new) exp_M_diff = torch.exp(M[:, i:i+block_size] - M_new) # Update running sum L_new = exp_M_diff * L[:, i:i+block_size] + exp_S.sum(dim=-1, keepdim=True) # Compute output for this block O[:, i:i+block_size] = ( exp_M_diff * O[:, i:i+block_size] + torch.matmul(exp_S, V_block) ) / L_new # Update running statistics L[:, i:i+block_size] = L_new M[:, i:i+block_size] = M_new return O
Cette implémentation, bien que simplifiée, capture l'essence de Flash Attention. Il traite l'entrée en blocs, en maintenant les statistiques d'exécution (M et L) pour calculer correctement le softmax sur tous les blocs.
L'impact de l'attention flash
L'introduction de Flash Attention a eu un impact profond sur le domaine de l'apprentissage automatique, en particulier pour les grands modÚles de langage et les applications à contexte long. Certains avantages clés incluent :
- Utilisation rĂ©duite de la mĂ©moire: Flash Attention rĂ©duit la complexitĂ© de la mĂ©moire de O(N^2) Ă O(N), oĂč N est la longueur de la sĂ©quence. Cela permet de traiter des sĂ©quences beaucoup plus longues avec le mĂȘme matĂ©riel.
- Vitesse améliorée: En minimisant le mouvement des données et en utilisant mieux les capacités de calcul du GPU, Flash Attention permet d'obtenir des accélérations significatives. Les auteurs signalent une formation jusqu'à 3 fois plus rapide pour GPT-2 par rapport aux implémentations standard.
- Calcul exact: Contrairement Ă d'autres techniques d'optimisation de l'attention, Flash Attention calcule l'attention exacte, et non une approximation.
- ĂvolutivitĂ©: L'empreinte mĂ©moire rĂ©duite permet une mise Ă l'Ă©chelle vers des sĂ©quences beaucoup plus longues, potentiellement jusqu'Ă des millions de jetons.
Impact réel
Lâimpact de Flash Attention sâĂ©tend au-delĂ de la recherche universitaire. Il a Ă©tĂ© rapidement adoptĂ© dans de nombreuses bibliothĂšques et modĂšles dâapprentissage automatique populaires :
- Transformateurs de visage étreignant: La populaire bibliothÚque Transformers a intégré Flash Attention, permettant aux utilisateurs de tirer facilement parti de ses avantages.
- GPT-4 et au-delà : Bien que cela ne soit pas confirmé, il existe des spéculations selon lesquelles des modÚles de langage avancés tels que GPT-4 pourraient utiliser des techniques similaires à Flash Attention pour gérer des contextes longs.
- ModĂšles Ă contexte long: Flash Attention a permis une nouvelle gĂ©nĂ©ration de modĂšles capables de gĂ©rer des contextes extrĂȘmement longs, tels que des modĂšles capables de traiter des livres entiers ou de longues vidĂ©os.
FlashAttention : développements récents
FlashAttention-2
Forte du succĂšs du Flash Attention original, la mĂȘme Ă©quipe introduit FlashAttention-2 en 2023. Cette version mise Ă jour apporte plusieurs amĂ©liorations :
- Optimisation supplémentaire: FlashAttention-2 permet une utilisation encore meilleure du GPU, atteignant jusqu'à 70 % des FLOPS de pointe théoriques sur les GPU A100.
- Passe arriĂšre amĂ©liorĂ©: La passe arriĂšre est optimisĂ©e pour ĂȘtre presque aussi rapide que la passe avant, ce qui entraĂźne des accĂ©lĂ©rations significatives Ă l'entraĂźnement.
- Prise en charge de diffĂ©rentes variantes d'attention: FlashAttention-2 Ă©tend la prise en charge Ă diverses variantes d'attention, notamment l'attention aux requĂȘtes groupĂ©es et l'attention aux requĂȘtes multiples.
FlashAttention-3
Sorti en 2024, FlashAttention-3 représente la derniÚre avancée dans cette ligne de recherche. Il introduit plusieurs nouvelles techniques pour améliorer encore les performances :
- Calcul asynchrone: Tirer parti de la nature asynchrone des nouvelles instructions GPU pour chevaucher différents calculs.
- Prise en charge du 8e PC: Utilisation du calcul FP8 de faible précision pour un traitement encore plus rapide.
- Traitement incohérent: Une technique pour réduire l'erreur de quantification lors de l'utilisation de formats de faible précision.
Voici un exemple simplifié de la maniÚre dont FlashAttention-3 peut exploiter le calcul asynchrone :
import torch from torch.cuda.amp import autocast def flash_attention_3(Q, K, V, block_size=256): with autocast(dtype=torch.float8): # Using FP8 for computation # ... (similar to previous implementation) # Asynchronous computation example with torch.cuda.stream(torch.cuda.Stream()): # Compute GEMM asynchronously S_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) / (d_model ** 0.5) # Meanwhile, on the default stream: # Prepare for softmax computation # Synchronize streams torch.cuda.synchronize() # Continue with softmax and output computation # ... return O
Cet extrait de code illustre comment FlashAttention-3 peut exploiter le calcul asynchrone et la prĂ©cision FP8. Notez qu'il s'agit d'un exemple simplifiĂ© et que la mise en Ćuvre rĂ©elle serait beaucoup plus complexe et spĂ©cifique au matĂ©riel.
Implémenter Flash Attention dans vos projets
Si vous souhaitez tirer parti de Flash Attention dans vos propres projets, vous disposez de plusieurs options :
- Utiliser les bibliothÚques existantes: De nombreuses bibliothÚques populaires comme Hugging Face Transformers incluent désormais des implémentations Flash Attention. Une simple mise à jour vers la derniÚre version et l'activation des indicateurs appropriés peuvent suffire.
- ImplĂ©mentation personnalisĂ©e: Pour plus de contrĂŽle ou des cas d'utilisation spĂ©cialisĂ©s, vous souhaiterez peut-ĂȘtre implĂ©menter Flash Attention vous-mĂȘme. La bibliothĂšque xformers fournit une bonne implĂ©mentation de rĂ©fĂ©rence.
- Optimisations spĂ©cifiques au matĂ©riel: Si vous travaillez avec du matĂ©riel spĂ©cifique (par exemple, les GPU NVIDIA H100), vous souhaiterez peut-ĂȘtre tirer parti des fonctionnalitĂ©s spĂ©cifiques au matĂ©riel pour des performances maximales.
Voici un exemple de la façon dont vous pouvez utiliser Flash Attention avec la bibliothÚque Hugging Face Transformers :
from transformers import AutoModel, AutoConfig # Enable Flash Attention config = AutoConfig.from_pretrained("bert-base-uncased") config.use_flash_attention = True # Load model with Flash Attention model = AutoModel.from_pretrained("bert-base-uncased", config=config) # Use the model as usual # ...
Défis et orientations futures
Bien que Flash Attention ait fait des progrĂšs significatifs dans lâamĂ©lioration de lâefficacitĂ© des mĂ©canismes dâattention, il reste encore des dĂ©fis et des domaines de recherche future :
- Spécificité matérielle: Les implémentations actuelles sont souvent optimisées pour des architectures GPU spécifiques. Généraliser ces optimisations sur différents matériels reste un défi.
- Intégration avec d'autres techniques: La combinaison de Flash Attention avec d'autres techniques d'optimisation telles que l'élagage, la quantification et la compression de modÚle est un domaine de recherche actif.
- Extension à d'autres domaines: Bien que Flash Attention ait connu un grand succÚs en PNL, étendre ses avantages à d'autres domaines comme la vision par ordinateur et les modÚles multimodaux est un effort continu.
- Compréhension théorique: Approfondir notre compréhension théorique des raisons pour lesquelles Flash Attention fonctionne si bien pourrait conduire à des optimisations encore plus puissantes.
Conclusion
En exploitant intelligemment les hiérarchies de mémoire GPU et en employant des astuces mathématiques, Flash Attention permet d'améliorer considérablement la vitesse et l'utilisation de la mémoire sans sacrifier la précision.
Comme nous l'avons exploré dans cet article, l'impact de Flash Attention s'étend bien au-delà d'une simple technique d'optimisation. Elle a permis le développement de modÚles plus puissants et plus efficaces.