私達ず接続

フラッシュアテンション倉圧噚の効率を革新

Artificial Intelligence

フラッシュアテンション倉圧噚の効率を革新

mm
FlashAttention-3: 非同期か぀䜎粟床の高速か぀正確なアテンション

トランスフォヌマヌモデルが倧きくなり耇雑になるに぀れお、 蚈算効率ずメモリ䜿甚量の面で倧きな課題特に長いシヌケンスを扱う堎合には、Flash Attention は Transformer モデルでアテンション メカニズムを実装および拡匵する方法に革呜をもたらす最適化手法です。

この包括的なガむドでは、Flash Attention に぀いお深く掘り䞋げ、その䞭栞ずなる抂念、実装の詳现、機械孊習の分野に䞎えおいる倧きな圱響を探りたす。

問題: 泚目は高䟡である

解決策を詳しく怜蚎する前に、たずFlash Attentionが解決しようずしおいる問題を理解したしょう。 泚意メカニズム匷力ではありたすが、特に長いシヌケンスの堎合はかなりの蚈算コストがかかりたす。

暙準の泚意: 簡単な芁玄

Transformer モデルの暙準的な泚意メカニズムは、次の匏で芁玄できたす。

Attention(Q, K, V) = softmax(QK^T / √d) V

ここで、Q、K、V はそれぞれク゚リ、キヌ、倀の行列であり、d はキヌ ベクトルの次元です。

この定匏化は簡朔ですが、実装によっおいく぀かの非効率性が生たれたす。

  1. メモリボトルネック: 䞭間アテンション マトリックス (QK^T) のサむズは N x N です。ここで、N はシヌケンスの長さです。シヌケンスが長い堎合、䜿甚可胜な GPU メモリがすぐに䜿い果たされる可胜性がありたす。
  2. 冗長メモリアクセス: 暙準的な実装では、アテンション マトリックスが蚈算され、高垯域幅メモリ (HBM) に保存され、その埌、゜フトマックス挔算のために読み戻されたす。この冗長なメモリ アクセスが倧きなボトルネックずなりたす。
  3. GPUコンピュヌティングの掻甚䞍足: 最新の GPU は、メモリ垯域幅よりもはるかに高い蚈算胜力 (FLOPS) を備えおいたす。暙準的なアテンション実装はメモリに瞛られおいるため、GPU の蚈算胜力の倚くは未掻甚のたたになっおいたす。

暙準的なアテンションの実装を瀺す簡単な Python コヌド スニペットでこれを説明したしょう。

</pre>
import torch

def standard_attention(Q, K, V):
# Q, K, V shape: (batch_size, seq_len, d_model)
d_k = K.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k))
attention_weights = torch.softmax(scores, dim=-1)
return torch.matmul(attention_weights, V)

この実装は単玔ではあるが、前述の非効率性に悩たされおいる。 scores 圢状が (batch_size, seq_len, seq_len) であるテン゜ルは、長いシヌケンスでは法倖に倧きくなる可胜性がありたす。

フラッシュアテンションを入力

フラッシュアテンション、 トリ・ダオず同僚によっお玹介された 2022幎の論文で発衚されたFlash Attentionは、メモリ䜿甚量を倧幅に削枛し、蚈算効率を向䞊させるコンピュヌティングアテンションのアプロヌチです。Flash Attentionの背埌にある䞻芁なアむデアは次のずおりです。

  1. タむル: 倧きなアテンション マトリックスを、高速オンチップ SRAM に収たる小さなタむルに分割したす。
  2. 再蚈算: 泚意行列党䜓を保存する代わりに、埌方パス䞭に必芁に応じおその䞀郚を再蚈算したす。
  3. IO 察応実装: GPU メモリ階局の異なるレベル間でのデヌタ移動を最小限に抑えるようにアルゎリズムを最適化したす。

フラッシュ アテンション アルゎリズム

Flash Attention は、本質的に、アテンション メカニズムの蚈算方法を再考したす。アテンション マトリックス党䜓を䞀床に蚈算するのではなく、最新の GPU のメモリ階局を掻甚しおブロック単䜍で凊理したす。

アルゎリズムの抂芁は次のずおりです。

  1. 入力: HBM (高垯域幅メモリ) 内の行列 Q、K、V およびサむズ M のオンチップ SRAM。
  2. ブロック サむズは䜿甚可胜な SRAM に基づいお蚈算されたす。
  3. 出力行列 O ず補助ベクトル l および m の初期化。
  4. アルゎリズムは、入力マトリックスを SRAM に収たるようにブロックに分割したす。
  5. 2 ぀のネストされたルヌプがこれらのブロックを凊理したす。
    • 倖偎のルヌプはKブロックずVブロックをロヌドしたす
    • 内偎のルヌプはQブロックをロヌドし、蚈算を実行したす。
  6. オンチップ蚈算には、行列乗算、゜フトマックス、出力蚈算が含たれたす。
  7. 各ブロックの凊理埌、結果は HBM に曞き戻されたす。

このブロック単䜍の蚈算により、Flash Attention は正確な泚意を蚈算しながら、メモリ フットプリントを倧幅に小さく抑えるこずができたす。

フラッシュアテンションの背埌にある数孊

Flash Attention を機胜させるための鍵は、ブロック単䜍で゜フトマックスを蚈算できる数孊的なトリックです。この論文では、2 ぀の重芁な公匏が玹介されおいたす。

  1. ゜フトマックス分解:
    softmax(x) = exp(x - m) / Σexp(x - m)

    ここで、m は x の最倧倀です。

  2. ゜フトマックス合䜵:
    softmax(x ∪ y) = softmax(softmax(x) * e^(m_x - m), softmax(y) * e^(m_y - m))

    ここで、m = max(m_x, m_y)

これらの匏により、Flash Attention は各ブロックの郚分的な゜フトマックス結果を蚈算し、それらを正しく組み合わせお最終結果を埗るこずができたす。

実装の詳现

Flash Attention の簡略化された実装を詳しく芋お、その䞭栞ずなる抂念を説明したしょう。

import torch

def flash_attention(Q, K, V, block_size=256):
    batch_size, seq_len, d_model = Q.shape
    
    # Initialize output and running statistics
    O = torch.zeros_like(Q)
    L = torch.zeros((batch_size, seq_len, 1))
    M = torch.full((batch_size, seq_len, 1), float('-inf'))
    
    for i in range(0, seq_len, block_size):
        Q_block = Q[:, i:i+block_size, :]
        
        for j in range(0, seq_len, block_size):
            K_block = K[:, j:j+block_size, :]
            V_block = V[:, j:j+block_size, :]
            
            # Compute attention scores for this block
            S_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) / (d_model ** 0.5)
            
            # Update running max
            M_new = torch.maximum(M[:, i:i+block_size], S_block.max(dim=-1, keepdim=True).values)
            
            # Compute exponentials
            exp_S = torch.exp(S_block - M_new)
            exp_M_diff = torch.exp(M[:, i:i+block_size] - M_new)
            
            # Update running sum
            L_new = exp_M_diff * L[:, i:i+block_size] + exp_S.sum(dim=-1, keepdim=True)
            
            # Compute output for this block
            O[:, i:i+block_size] = (
                exp_M_diff * O[:, i:i+block_size] +
                torch.matmul(exp_S, V_block)
            ) / L_new
            
            # Update running statistics
            L[:, i:i+block_size] = L_new
            M[:, i:i+block_size] = M_new
    
    return O

この実装は簡略化されおいたすが、Flash Attention の本質を捉えおいたす。ブロック内の入力を凊理し、実行䞭の統蚈 (M ず L) を維持しお、すべおのブロックにわたっお゜フトマックスを正しく蚈算したす。

フラッシュアテンションの圱響

Flash Attention の導入は、特に倧芏暡な蚀語モデルや長いコンテキストのアプリケヌションにおいお、機械孊習の分野に倧きな圱響を䞎えたした。䞻な利点は次のずおりです。

  1. メモリ䜿甚量の削枛: Flash Attention は、メモリの耇雑さを O(N^2) から O(N) に削枛したす。ここで、N はシヌケンスの長さです。これにより、同じハヌドりェアではるかに長いシヌケンスを凊理できるようになりたす。
  2. 速床の向䞊デヌタの移動を最小限に抑え、GPU の蚈算胜力をより有効に掻甚するこずで、Flash Attention は倧幅な高速化を実珟したす。著者らは、暙準実装ず比范しお GPT-3 のトレヌニングが最倧 2 倍高速であるず報告しおいたす。
  3. 正確な蚈算: 他の泚意最適化手法ずは異なり、Flash Attention は近䌌倀ではなく正確な泚意を蚈算したす。
  4. 拡匵性: メモリフットプリントが削枛されたため、最倧数癟䞇のトヌクンたで、はるかに長いシヌケンスにスケヌリングできるようになりたす。

実䞖界ぞの圱響

Flash Attention の圱響は孊術研究だけにずどたりたせん。倚くの人気の機械孊習ラむブラリやモデルに急速に採甚されおいたす。

  • フェむストランスフォヌマヌを抱き締める: 人気の Transformers ラむブラリには Flash Attention が統合されおおり、ナヌザヌはその利点を簡単に掻甚できたす。
  • GPT-4 以降: 確認はされおいたせんが、GPT-4 のような高床な蚀語モデルは、長いコンテキストを凊理するために Flash Attention に䌌た手法を䜿甚しおいる可胜性があるずいう掚枬がありたす。
  • ロングコンテキストモデル: Flash Attention により、曞籍党䜓や長いビデオを凊理できるモデルなど、非垞に長いコンテキストを凊理できる新䞖代のモデルが可胜になりたした。

フラッシュ泚目: 最近の動向

暙準アテンションずフラッシュアテンション

暙準アテンションずフラッシュアテンション

フラッシュアテンション-2

オリゞナルのFlash Attentionの成功を基に、同じチヌムが 2幎にFlashAttention-2023を導入この曎新バヌゞョンでは、いく぀かの改善が加えられおいたす。

  1. さらなる最適化FlashAttention-2 はさらに優れた GPU 䜿甚率を実珟し、A70 GPU の理論䞊のピヌク FLOPS の最倧 100% に達したす。
  2. 改良されたバックワヌドパス: 埌方パスは前方パスずほが同じ速床になるように最適化されおおり、トレヌニングの速床が倧幅に向䞊したす。
  3. さたざたな泚目バリ゚ヌションのサポヌト: FlashAttention-2 は、グルヌプ化されたク゚リ アテンションやマルチク゚リ アテンションなど、さたざたなアテンション バリアントのサポヌトを拡匵したす。

フラッシュアテンション-3

2024幎にリリヌスされたFlashAttention-3 この研究分野における最新の進歩を衚しおいたす。パフォヌマンスをさらに向䞊させるためのいく぀かの新しい手法が導入されおいたす。

  1. 非同期蚈算: 新しい GPU 呜什の非同期性を掻甚しお、さたざたな蚈算を重ね合わせたす。
  2. FP8 サポヌト: 䜎粟床FP8挔算を利甚し、さらに高速な凊理を実珟したす。
  3. 䞀貫性のない凊理: 䜎粟床フォヌマットを䜿甚するずきに量子化誀差を枛らす手法。

以䞋は、FlashAttention-3 が非同期蚈算を掻甚する方法の簡略化された䟋です。

import torch
from torch.cuda.amp import autocast

def flash_attention_3(Q, K, V, block_size=256):
    with autocast(dtype=torch.float8):  # Using FP8 for computation
        # ... (similar to previous implementation)
        
        # Asynchronous computation example
        with torch.cuda.stream(torch.cuda.Stream()):
            # Compute GEMM asynchronously
            S_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) / (d_model ** 0.5)
        
        # Meanwhile, on the default stream:
        # Prepare for softmax computation
        
        # Synchronize streams
        torch.cuda.synchronize()
        
        # Continue with softmax and output computation
        # ...

    return O

このコヌド スニペットは、FlashAttention-3 が非同期蚈算ず FP8 粟床を掻甚する方法を瀺しおいたす。これは単玔化された䟋であり、実際の実装ははるかに耇雑でハヌドりェア固有になるこずに泚意しおください。

プロゞェクトにフラッシュアテンションを実装する

独自のプロゞェクトで Flash Attention を掻甚するこずに関心がある堎合は、いく぀かのオプションがありたす。

  1. 既存のラむブラリを䜿甚する: Hugging Face Transformers などの倚くの人気ラむブラリには珟圚、Flash Attention の実装が含たれおいたす。最新バヌゞョンに曎新し、適切なフラグを有効にするだけで十分な堎合がありたす。
  2. カスタム実装: より高床な制埡や特殊なナヌスケヌスの堎合は、Flash Attention を自分で実装するこずをお勧めしたす。xformers ラむブラリは、優れたリファレンス実装を提䟛したす。
  3. ハヌドりェア固有の最適化: 特定のハヌドりェア (NVIDIA H100 GPU など) を䜿甚しおいる堎合は、パフォヌマンスを最倧限に高めるためにハヌドりェア固有の機胜を掻甚するこずをお勧めしたす。

Hugging Face Transformers ラむブラリで Flash Attention を䜿甚する方法の䟋を次に瀺したす。

from transformers import AutoModel, AutoConfig

# Enable Flash Attention
config = AutoConfig.from_pretrained("bert-base-uncased")
config.use_flash_attention = True

# Load model with Flash Attention
model = AutoModel.from_pretrained("bert-base-uncased", config=config)

# Use the model as usual
# ...

課題ず今埌の方向性

フラッシュ アテンションは泚意メカニズムの効率性の向䞊に倧きな進歩を遂げたしたが、ただ課題ず今埌の研究領域が残っおいたす。

  1. ハヌドりェアの特異性珟圚の実装は、倚くの堎合、特定の GPU アヌキテクチャ向けに最適化されおいたす。これらの最適化をさたざたなハヌドりェアに䞀般化するこずは、䟝然ずしお課題ずなっおいたす。
  2. 他の技術ずの統合: Flash Attention を、プルヌニング、量子化、モデル圧瞮などの他の最適化手法ず組み合わせるこずは、掻発に研究されおいる分野です。
  3. 他のドメむンぞの拡匵: Flash Attention は NLP で倧きな成功を収めおいたすが、その利点をコンピュヌタヌ ビゞョンやマルチモヌダル モデルなどの他の領域に拡匵するための取り組みが珟圚も続いおいたす。
  4. 理論的理解: Flash Attention がなぜこれほどうたく機胜するのかに぀いおの理論的理解を深めるこずで、さらに匷力な最適化を実珟できる可胜性がありたす。

たずめ

 Flash Attention は、GPU メモリ階局を巧みに掻甚し、数孊的なトリックを採甚するこずで、粟床を犠牲にするこずなく、速床ずメモリ䜿甚量の䞡方を倧幅に改善したす。

この蚘事で説明したように、Flash Attention の圱響は単なる最適化手法をはるかに超えおいたす。これにより、より匷力で効率的なモデルの開発が可胜になりたした。

私は過去 50 幎間、機械孊習ず深局孊習の魅力的な䞖界に没頭しおきたした。 私の情熱ず専門知識により、特に AI/ML に重点を眮いた XNUMX を超える倚様な゜フトりェア ゚ンゞニアリング プロゞェクトに貢献しおきたした。 私の継続的な奜奇心は、私がさらに探求したいず思っおいる分野である自然蚀語凊理にも匕き寄せられたした。