Flash Attention: Ako transformery zvládajú milióny tokenov

Flash Attention je algoritmická optimalizácia, ktorá zmenila spôsob výpočtu pozornostného mechanizmu v transformeroch — a práve vďaka nej dnes modely ako Claude, GPT-4 či Gemini zvládajú kontextové okná s státisícmi až miliónmi tokenov bez toho, aby im pri tom praskol GPU.


1. Problém: klasická attention je pamäťový moč

Každý transformer v jadre počíta takzvanú scaled dot-product attention. Pre vstup s $N$ tokenmi vzniká matica veľkosti $N \times N$ — každý token sa porovnáva s každým iným. Pri $N = 1,000$ tokenoch je to milión hodnôt. Pri $N = 128,000$ tokenoch je to už 16 miliárd hodnôt len pre jednu vrstvu, jeden attention head.

Tá matica sa musí niekam uložiť. GPU má niekoľko typov pamäte:

  • HBM (High Bandwidth Memory) — veľká (desiatky GB), ale pomalá prenos dát z/do jadier
  • SRAM (on-chip) — bleskovo rýchla, ale tiiiny: niekoľko desiatok MB na celý chip

Klasická implementácia attention opätovne presúva obrovské matice cez pomalý HBM. Výsledok: výpočtový čas narastá kvadraticky ($O(N^2)$) a pamäťové nároky tiež. Dlhé kontexty sú tak nielen pomalé — pri klasickej attention jednoducho nezmestia do GPU.


2. Riešenie: IO-aware algoritmus

Flash Attention, ktorý v roku 2022 publikoval Tri Dao s kolektívom zo Stanfordu, prišiel s odpoveďou: nerátajte celú attention maticu naraz. Namiesto toho rozdeľte vstupné matice $Q$, $K$, $V$ (query, key, value) na malé bloky, ktoré sa zmestia do rýchlej SRAM, a spracujte ich postupne.

Kľúčové triky:

  • Tiling (dlaždičkové spracovanie) — matice $Q$, $K$, $V$ sa porežú na bloky; každý blok sa načíta do SRAM, spracuje a výsledok sa zapíše späť do HBM. Veľká matica $N \times N$ sa v pamäti nikdy ako celok neobjaví.
  • Online softmax — softmax vyžaduje globálne maximum riadku. Flash Attention ho počíta inkrementálne pri spracovaní blokov pomocou numericky stabilného running-max triku (Milakov & Gimelshein, 2018).
  • Rekomputácia pri spätnom prechode — namiesto ukladania celej attention matice pre backpropagation algoritmus maticu znova rýchlo prepočíta z uložených vstupov. Tým šetrí pamäť za cenu relatívne lacného výpočtu.

Výsledok: algoritmus je IO-komplexne subkubický — nemusí prenášať $O(N^2)$ dát cez pomalý HBM, čo je v praxi bottleneck.


3. Evolúcia: FA1 → FA2 → FA3

Verzia Rok Kľúčová inovácia Zrýchlenie vs. štandard
Flash Attention 1 2022 IO-aware tiling + online softmax 3–4× rýchlejší
Flash Attention 2 2023 Lepšie rozdelenie práce medzi warp-y, menej zbytočného syncovania, FP16/BF16 5–9× rýchlejší
Flash Attention 3 2024 Asynchrónny pipeline pre H100 Tensor Cores, FP8 podpora, overlapping GEMM + softmax 1,5–2× oproti FA2

FA3 je navrhnutý špeciálne pre architektúru NVIDIA Hopper (H100/H200) a exploituje asynchrónne inštrukcie WGMMA a TMA, ktoré staršie GPU nemajú.


4. Kde to nájdete v praxi

Flash Attention nie je produkt — je to referenčná implementácia (open-source, licencia BSD) plus integrácia do väčšiny dnes relevantných knižníc:

  • PyTorch 2.0+F.scaled_dot_product_attention() automaticky volá Flash Attention, ak je dostupný a vstup vyhovuje
  • Hugging Face Transformers — väčšina architektúr ho zapne cez attn_implementation="flash_attention_2"
  • vLLM, TensorRT-LLM, SGLang — inference frameworky, kde FA2/FA3 je default pre H100 deployment

Modely, ktoré Flash Attention aktívne využívajú:

  • Meta Llama 2, 3, 4 (Scout, Maverick)
  • Mistral 7B a celá rodina
  • Anthropic Claude (interné tréningové pipeline)
  • Qwen 2/3 séria
  • Prakticky každý moderný open-source model tréningovaný od 2023

Bez Flash Attention by kontextové okno 128k–1M tokenov nebolo ekonomicky realistické. Tréning na dlhých dokumentoch by stál niekoľkonásobne viac.


5. Limity a čo ďalej

Flash Attention nie je všeliek:

  • Závislý na GPU architektúre — FA3 je len pre Hopper (H100/H200). Na starších kartách (A100, RTX 30xx) beží FA2. Na CPU alebo Apple Silicon je podpora iná cesta (Metal, xFormers).
  • Causally masked attention ≠ full attention — implementácia pre decoder-only modely (kauzálna maska) je jednoduchšia; bidirektívna attention (BERT-style) má špecifické varianty.
  • Skupinová attention (GQA/MQA) — Flash Attention podporuje grouped-query a multi-query attention, ale ich správna implementácia pridáva vrstvu zložitosti.
  • Alternatívy — pre extrémne dlhé kontexty (>1M tokenov) sa skúmajú lineárne attention approximácie (Hyena, RetNet, Mamba) a ringová distribúcia attention cez viaceré GPU.

Výskum pokračuje v smere sparse attention (pozornosť len na podmnožinu tokenov), hardware co-design (architektúry GPU navrhnuté s ohľadom na attention), a FP8 tréning kde FA3 už dnes ukazuje sľubné výsledky.


6. Ako ho zapnúť

V praxi Flash Attention nevoláš ručne — knižnice ho zapnú za teba, ak je dostupný:

# PyTorch 2.0+: automaticky použije Flash Attention, ak vstup vyhovuje
import torch.nn.functional as F
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Hugging Face Transformers: explicitné zapnutie
model = AutoModelForCausalLM.from_pretrained(
    name, attn_implementation="flash_attention_2", torch_dtype="bfloat16",
)

Numericky dáva (takmer) rovnaký výsledok ako klasická attention — líši sa len spôsob výpočtu, nie matematika.

7. Súvislosti

  • Attention mechanizmus: Flash Attention je IO-efektívny spôsob, ako ho spočítať.
  • Kontextové okno: práve vďaka nemu sú dlhé okná ekonomicky reálne.
  • Inferencia: v produkcii znižuje pamäť (KV cache) aj latenciu.
  • Deep Learning: príklad, ako hardvérovo uvedomelý algoritmus pohne celé odvetvie.

Zhrnutie: Flash Attention je jeden z tých algoritmických skokov, ktoré zmenili celé odvetvie bez toho, aby sa o ňom veľa hovorilo — nepridáva nové schopnosti modelom, ale spraví to, čo existuje, dostatočne rýchle a lacné na to, aby sa dalo reálne použiť; bez neho by dnešné dlhé kontexty zostali len na papieri.