Искусственный интеллект
Внимание: революционный рост эффективности трансформаторов

Эта реализация, хотя и проста, страдает от неэффективности, упомянутой выше. scores тензор, имеющий форму (batch_size, seq_len, seq_len), может стать непомерно большим для длинных последовательностей.
Введите мгновенное внимание
Вспышка внимания, представлено Три Дао и коллегами в своей статье 2022 года — это подход к вычислительному вниманию, который значительно снижает использование памяти и повышает эффективность вычислений. Ключевые идеи Flash Attention:
- Черепица: разбить большую матрицу внимания на более мелкие плитки, которые помещаются в быстродействующую встроенную SRAM.
- Перерасчет: вместо сохранения всей матрицы внимания пересчитывайте ее части по мере необходимости во время обратного прохода.
- Реализация с учетом ввода-вывода: Оптимизировать алгоритм, чтобы минимизировать перемещение данных между различными уровнями иерархии памяти графического процессора.
Алгоритм мгновенного внимания
По своей сути Flash Attention переосмысливает то, как мы вычисляем механизм внимания. Вместо того, чтобы вычислять всю матрицу внимания сразу, он обрабатывает ее блоками, используя иерархию памяти современных графических процессоров.
Вот краткий обзор алгоритма:
- Входные данные: матрицы Q, K, V в HBM (память с высокой пропускной способностью) и встроенная SRAM размера M.
- Размеры блоков рассчитываются на основе доступной SRAM.
- Инициализация выходной матрицы O и вспомогательных векторов l и m.
- Алгоритм делит входные матрицы на блоки для размещения в SRAM.
- Эти блоки обрабатываются двумя вложенными циклами:
- Внешний цикл загружает блоки K и V
- Внутренний цикл загружает блоки Q и выполняет вычисления.
- Встроенные вычисления включают умножение матриц, softmax и расчет выходных данных.
- Результаты записываются обратно в HBM после обработки каждого блока.
Эти блочные вычисления позволяют Flash Attention использовать гораздо меньший объем памяти, сохраняя при этом точное внимание.
Математика, лежащая в основе повышения внимания
Ключом к работе Flash Attention является математический трюк, который позволяет нам вычислять softmax блочным способом. В статье представлены две ключевые формулы:
- Софтмакс-разложение:
softmax(x) = exp(x - m) / Σexp(x - m)где m — максимальное значение x.
- Слияние Softmax:
softmax(x ∪ y) = softmax(softmax(x) * e^(m_x - m), softmax(y) * e^(m_y - m))где m = max(m_x, m_y)
Эти формулы позволяют Flash Attention вычислить частичные результаты softmax для каждого блока, а затем правильно объединить их для получения окончательного результата.
Детали реализации
Давайте рассмотрим упрощенную реализацию Flash Attention, чтобы проиллюстрировать ее основные концепции:
import torch
def flash_attention(Q, K, V, block_size=256):
batch_size, seq_len, d_model = Q.shape
# Initialize output and running statistics
O = torch.zeros_like(Q)
L = torch.zeros((batch_size, seq_len, 1))
M = torch.full((batch_size, seq_len, 1), float('-inf'))
for i in range(0, seq_len, block_size):
Q_block = Q[:, i:i+block_size, :]
for j in range(0, seq_len, block_size):
K_block = K[:, j:j+block_size, :]
V_block = V[:, j:j+block_size, :]
# Compute attention scores for this block
S_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) / (d_model ** 0.5)
# Update running max
M_new = torch.maximum(M[:, i:i+block_size], S_block.max(dim=-1, keepdim=True).values)
# Compute exponentials
exp_S = torch.exp(S_block - M_new)
exp_M_diff = torch.exp(M[:, i:i+block_size] - M_new)
# Update running sum
L_new = exp_M_diff * L[:, i:i+block_size] + exp_S.sum(dim=-1, keepdim=True)
# Compute output for this block
O[:, i:i+block_size] = (
exp_M_diff * O[:, i:i+block_size] +
torch.matmul(exp_S, V_block)
) / L_new
# Update running statistics
L[:, i:i+block_size] = L_new
M[:, i:i+block_size] = M_new
return O
Эта реализация, хотя и упрощена, отражает суть Flash Attention. Он обрабатывает входные данные блоками, сохраняя текущую статистику (M и L) для правильного вычисления softmax для всех блоков.
Влияние мгновенного внимания
Внедрение Flash Attention оказало глубокое влияние на область машинного обучения, особенно для больших языковых моделей и приложений с длинным контекстом. Некоторые ключевые преимущества включают в себя:
- Уменьшенное использование памяти: Flash Attention снижает сложность памяти с O(N^2) до O(N), где N — длина последовательности. Это позволяет обрабатывать гораздо более длинные последовательности на одном и том же оборудовании.
- Повышенная скорость: минимизируя перемещение данных и более эффективно используя вычислительные возможности графического процессора, Flash Attention обеспечивает значительное ускорение. Авторы сообщают, что обучение GPT-3 происходит в 2 раза быстрее по сравнению со стандартными реализациями.
- Точный расчет: В отличие от некоторых других методов оптимизации внимания, Flash Attention рассчитывает точное внимание, а не приближение.
- Масштабируемость: уменьшенный объем памяти позволяет масштабировать более длинные последовательности, потенциально до миллионов токенов.
Воздействие на реальный мир
Влияние мгновенного внимания выходит за рамки академических исследований. Он был быстро принят во многих популярных библиотеках и моделях машинного обучения:
- Трансформеры с обнимающимися лицами: популярная библиотека Transformers интегрировала Flash Attention, что позволяет пользователям легко использовать ее преимущества.
- GPT-4 И дальше: Хотя это и не подтверждено, есть предположение, что продвинутые языковые модели, такие как GPT-4, могут использовать методы, похожие на Flash Attention, для обработки длинных контекстов.
- Длинноконтекстные модели: Flash Attention позволил создать новое поколение моделей, способных обрабатывать чрезвычайно длинные контексты, например модели, которые могут обрабатывать целые книги или длинные видеоролики.
FlashAttention: последние события
FlashAttention-2
Опираясь на успех оригинального Flash Attention, та же команда представил FlashAttention-2 в 2023 году. Эта обновленная версия содержит несколько улучшений:
- Дальнейшая оптимизация: FlashAttention-2 обеспечивает еще лучшее использование графического процессора, достигая до 70 % от теоретического пикового значения FLOPS на графических процессорах A100.
- Улучшенный пас назад: пас назад оптимизирован так, чтобы быть почти таким же быстрым, как и пас вперед, что приводит к значительному ускорению тренировки.
- Поддержка различных вариантов внимания: FlashAttention-2 расширяет поддержку различных вариантов внимания, включая внимание сгруппированных запросов и внимание с несколькими запросами.
FlashAttention-3
FlashAttention-2024 выпущен в 3 году. представляет собой последнее достижение в этой области исследований. В нем представлено несколько новых методов для дальнейшего улучшения производительности:
- Асинхронные вычисления: использование асинхронной природы новых инструкций графического процессора для перекрытия различных вычислений.
- Поддержка FP8: использование вычислений FP8 с низкой точностью для еще более быстрой обработки.
- Некогерентная обработка: метод уменьшения ошибки квантования при использовании форматов низкой точности.
Вот упрощенный пример того, как FlashAttention-3 может использовать асинхронные вычисления:
import torch
from torch.cuda.amp import autocast
def flash_attention_3(Q, K, V, block_size=256):
with autocast(dtype=torch.float8): # Using FP8 for computation
# ... (similar to previous implementation)
# Asynchronous computation example
with torch.cuda.stream(torch.cuda.Stream()):
# Compute GEMM asynchronously
S_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) / (d_model ** 0.5)
# Meanwhile, on the default stream:
# Prepare for softmax computation
# Synchronize streams
torch.cuda.synchronize()
# Continue with softmax and output computation
# ...
return O
Этот фрагмент кода иллюстрирует, как FlashAttention-3 может использовать асинхронные вычисления и точность FP8. Обратите внимание, что это упрощенный пример, и фактическая реализация будет гораздо более сложной и зависит от оборудования.
Использование Flash Attention в ваших проектах
Если вы заинтересованы в использовании Flash Attention в своих собственных проектах, у вас есть несколько вариантов:
- Используйте существующие библиотеки: Многие популярные библиотеки, такие как Hugging Face Transformers, теперь включают реализации Flash Attention. Простого обновления до последней версии и включения соответствующих флагов может быть достаточно.
- Индивидуальная реализация: Для большего контроля или специализированных случаев использования вы можете реализовать Flash Attention самостоятельно. Библиотека xformers предоставляет хорошую эталонную реализацию.
- Аппаратная оптимизация: Если вы работаете со специфическим оборудованием (например, графическими процессорами NVIDIA H100), вам может потребоваться использовать специфичные для оборудования функции для достижения максимальной производительности.
Вот пример того, как можно использовать Flash Attention с библиотекой Hugging Face Transformers:
from transformers import AutoModel, AutoConfig
# Enable Flash Attention
config = AutoConfig.from_pretrained("bert-base-uncased")
config.use_flash_attention = True
# Load model with Flash Attention
model = AutoModel.from_pretrained("bert-base-uncased", config=config)
# Use the model as usual
# ...Проблемы и будущие направления
Несмотря на то, что мгновенное внимание добилось значительных успехов в повышении эффективности механизмов внимания, все еще существуют проблемы и области для будущих исследований:
- Специфика оборудования: Текущие реализации часто оптимизированы для конкретных архитектур графических процессоров. Обобщение этих оптимизаций на различном оборудовании остается сложной задачей.
- Интеграция с другими методами: Сочетание Flash Attention с другими методами оптимизации, такими как сокращение, квантование и сжатие модели, является активной областью исследований.
- Распространение на другие домены: Хотя Flash Attention продемонстрировал большой успех в НЛП, расширение его преимуществ на другие области, такие как компьютерное зрение и мультимодальные модели, требует постоянных усилий.
- Теоретическое понимание: Углубление нашего теоретического понимания того, почему Flash Attention работает так хорошо, может привести к еще более мощной оптимизации.
Заключение
Умно используя иерархию памяти графического процессора и применяя математические приемы, Flash Attention достигает существенных улучшений как в скорости, так и в использовании памяти, не жертвуя при этом точностью.
Как мы выяснили в этой статье, влияние технологии Flash Attention выходит далеко за рамки простого метода оптимизации. Она позволила разработать более мощные и эффективные модели.












