
Commit d308079

feat(attention): added support for flash attention 4
1 parent: bf0e6d3

4 files changed

Lines changed: 229 additions & 18 deletions


README.md

Lines changed: 8 additions & 0 deletions
````diff
@@ -52,6 +52,14 @@ uv sync --extra [cpu|cu126|cu128|cu130] --extra tests --extra linting
 pre-commit install --install-hooks
 ```
 
+Additionally, flash attention 4 can be installed via:
+
+```sh
+uv pip install --prerelease=allow flash-attn-4
+# or (if you want to install the CUDA 13 version)
+uv pip install --prerelease=allow flash-attn-4[cu13]
+```
+
 ### Option 2: Using pip and manual installation of dependencies
 
 ```sh
````
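A quick availability check mirrors the probe the new tests use (`find_spec("flash_attn.cute")`); this snippet is illustrative and not part of the commit:

```python
from importlib.util import find_spec

# FA4 exposes its kernels under flash_attn.cute; if the module resolves,
# the dao_flash_v4 attention implementation can be selected.
print("FlashAttention-4 available:", find_spec("flash_attn.cute") is not None)
```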

pyproject.toml

Lines changed: 8 additions & 5 deletions
````diff
@@ -46,17 +46,20 @@ cpu = ["torch>=2.10,<2.11.0", "torchvision"]
 cu126 = [
     "torch>=2.10,<2.11.0",
     "torchvision",
-    "flash-attn==2.8.3; platform_system != 'Darwin' and platform_machine != 'aarch64'"
+    "flash-attn==2.8.3; platform_system != 'Darwin' and platform_machine != 'aarch64'",
+    # "flash-attn-4; platform_system == 'Linux' and platform_machine != 'aarch64'"
 ]
 cu128 = [
     "torch>=2.10,<2.11.0",
     "torchvision",
-    "flash-attn==2.8.3; platform_system != 'Darwin' and platform_machine != 'aarch64'"
+    "flash-attn==2.8.3; platform_system != 'Darwin' and platform_machine != 'aarch64'",
+    # "flash-attn-4; platform_system == 'Linux' and platform_machine != 'aarch64'"
 ]
 cu130 = [
     "torch>=2.10,<2.11.0",
     "torchvision",
-    "flash-attn==2.8.3; platform_system != 'Darwin' and platform_machine != 'aarch64'"
+    "flash-attn==2.8.3; platform_system != 'Darwin' and platform_machine != 'aarch64'",
+    # "flash-attn-4[cu13]; platform_system == 'Linux' and platform_machine != 'aarch64'"
 ]
 
 [tool.uv]
@@ -106,8 +109,8 @@ explicit = true
 
 [tool.uv.extra-build-dependencies]
 flash-attn = [
-    { requirement = "torch", match-runtime = true },
-    { requirement = "ninja", match-runtime = true },
+    { requirement = "torch", match-runtime = true },
+    { requirement = "ninja", match-runtime = true },
 ]
 
 [tool.black]
````
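The `flash-attn-4` requirements are deliberately left commented out, keeping FA4 an explicit opt-in install while it is still a prerelease. Once it stabilizes, enabling it per extra would presumably just mean uncommenting the marker-guarded line, e.g. for `cu130` (a sketch, not part of this commit):

```toml
cu130 = [
    "torch>=2.10,<2.11.0",
    "torchvision",
    "flash-attn==2.8.3; platform_system != 'Darwin' and platform_machine != 'aarch64'",
    "flash-attn-4[cu13]; platform_system == 'Linux' and platform_machine != 'aarch64'",
]
```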

src/modalities/models/gpt2/gpt2_model.py

Lines changed: 58 additions & 12 deletions
````diff
@@ -20,9 +20,14 @@
 from modalities.util import parse_enum_by_name
 
 try:
-    from flash_attn import flash_attn_func
+    from flash_attn import flash_attn_func as flash_attn_func_v2
 except ModuleNotFoundError:
-    flash_attn_func = None
+    flash_attn_func_v2 = None
+
+try:
+    from flash_attn.cute import flash_attn_func as flash_attn_func_v4
+except Exception:
+    flash_attn_func_v4 = None
 
 # Logger configuration
 logger = logging.getLogger(__name__)
````
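Note the asymmetry in the import guards: FA2 only needs `except ModuleNotFoundError`, while the FA4 import is wrapped in the broader `except Exception`, presumably because the prerelease `flash_attn.cute` module can fail at import time for reasons other than absence (e.g. a CUDA/driver mismatch) and should still degrade to `None`.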
````diff
@@ -249,12 +254,14 @@ class AttentionImplementation(str, Enum):
     Attributes:
         MANUAL (str): Manual attention implementation.
         PYTORCH_FLASH (str): PyTorch's flash attention implementation.
-        DAO_FLASH (str): DAO's flash attention implementation.
+        DAO_FLASH (str): DAO's FlashAttention-2 implementation.
+        DAO_FLASH_V4 (str): DAO's FlashAttention-4 implementation.
     """
 
     MANUAL = "manual"
     PYTORCH_FLASH = "pytorch_flash"
     DAO_FLASH = "dao_flash"
+    DAO_FLASH_V4 = "dao_flash_v4"
 
 
 class AttentionConfig(BaseModel):
@@ -439,6 +446,14 @@ def __init__(
         super().__init__()
         assert n_embd % n_head_q == 0, "`n_embd` needs to be divisible by `n_head_q`."
         assert n_head_q % n_head_kv == 0, "`n_head_q` needs to be divisible by `n_head_kv`."
+        if attention_impl == AttentionImplementation.DAO_FLASH:
+            if flash_attn_func_v2 is None:
+                raise NotImplementedError("ERROR! Dao Flash Attention 2 is not installed.")
+        if attention_impl == AttentionImplementation.DAO_FLASH_V4:
+            if flash_attn_func_v4 is None:
+                raise NotImplementedError("ERROR! Dao Flash Attention 4 is not installed.")
+            if dropout > 0.0:
+                raise NotImplementedError("ERROR! Dao Flash Attention 4 does not support attention dropout.")
 
         self.n_rep = n_head_q // n_head_kv
         self.attention_impl = attention_impl
````
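With these up-front checks, misconfiguration surfaces at construction time rather than deep inside `execute_attention`. Following the positional signature the new tests use (`n_head_q, n_head_kv, n_embd, attention_config, attention_impl, bias, dropout` is an inference from those test calls, not spelled out in this diff), requesting FA4 together with dropout fails immediately:

```python
import pytest
from modalities.models.gpt2.gpt2_model import AttentionConfig, CausalSelfAttention

# dao_flash_v4 rejects attention dropout up front (and also raises here
# if the FA4 kernels are not installed).
with pytest.raises(NotImplementedError):
    CausalSelfAttention(4, 4, 32, AttentionConfig(qkv_transforms=[]), "dao_flash_v4", False, 0.1)
```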
````diff
@@ -644,19 +659,50 @@ def execute_attention(
             # Due to the lack of GPUs in github actions and the requirement of those in the flash-attn library,
             # we have to check if the library is installed and raise an error if not.
             # Note, that the library is not required for the CPU-only tests.
-            if flash_attn_func is None:
-                raise NotImplementedError("ERROR! Dao Flash Attention is not installed.")
-            # the next three lines are only needed for flash-attn from Dao Lab
-            q = q.transpose(1, 2).contiguous()  # (B, T, nh_q, hd)
-            k = k.transpose(1, 2).contiguous()  # (B, T, nh_kv, hd)
-            v = v.transpose(1, 2).contiguous()  # (B, T, nh_kv, hd)
-            y = flash_attn_func(
-                q, k, v, dropout_p=dropout, causal=True, softmax_scale=None, window_size=(-1, -1)
-            )  # (B, T, nh_q, hd)
+            y = cls._execute_dao_flash_v2(q, k, v, dropout)
+        elif attention_impl == AttentionImplementation.DAO_FLASH_V4:
+            if cls._requires_backward(q, k, v) and torch.cuda.get_device_capability(q.device)[0] < 9:
+                y = cls._execute_dao_flash_v2(q, k, v, dropout)
+            else:
+                # TODO added due to upstream failure in its pack_gqa handling,
+                # can be removed once the issue is resolved:
+                k, v = cls.repeat_kv_heads(q, k, v)
+                q = q.transpose(1, 2).contiguous()  # (B, T, nh_q, hd)
+                k = k.transpose(1, 2).contiguous()  # (B, T, nh_kv, hd)
+                v = v.transpose(1, 2).contiguous()  # (B, T, nh_kv, hd)
+                y = cls._unwrap_flash_attention_output(
+                    flash_attn_func_v4(
+                        q,
+                        k,
+                        v,
+                        causal=True,
+                        softmax_scale=None,
+                        window_size=(None, None),
+                    )
+                )
         else:
             raise NotImplementedError(f"Attention implementation {attention_impl} not supported")
         return y  # (B, T, nh_q, hd)
 
+    @staticmethod
+    def _unwrap_flash_attention_output(
+        output: torch.Tensor | tuple[torch.Tensor, Optional[torch.Tensor]],
+    ) -> torch.Tensor:
+        if isinstance(output, tuple):
+            return output[0]
+        return output
+
+    @staticmethod
+    def _execute_dao_flash_v2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout: float) -> torch.Tensor:
+        q = q.transpose(1, 2).contiguous()  # (B, T, nh_q, hd)
+        k = k.transpose(1, 2).contiguous()  # (B, T, nh_kv, hd)
+        v = v.transpose(1, 2).contiguous()  # (B, T, nh_kv, hd)
+        return flash_attn_func_v2(q, k, v, dropout_p=dropout, causal=True, softmax_scale=None, window_size=(-1, -1))
+
+    @staticmethod
+    def _requires_backward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> bool:
+        return torch.is_grad_enabled() and any(tensor.requires_grad for tensor in (q, k, v))
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Forward pass of the CausalSelfAttention module.
````
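Taken together, `dao_flash_v4` degrades gracefully: when gradients are required on a pre-Hopper GPU (compute capability below 9), the call silently routes to the FA2 kernel; pure inference, or any GPU at SM90 and above, takes the FA4 path. A hedged usage sketch, assuming a CUDA machine with the kernels installed (not part of the commit itself):

```python
import torch
from modalities.models.gpt2.gpt2_model import CausalSelfAttention

# Grouped-query layout: 8 query heads, 2 KV heads, seq len 12, head dim 64.
q = torch.rand(2, 8, 12, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
k = torch.rand(2, 2, 12, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
v = torch.rand(2, 2, 12, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)

# With grad enabled on an SM8x card this dispatches to _execute_dao_flash_v2;
# under torch.no_grad() (or on SM90+) it runs flash_attn_func_v4 instead.
y = CausalSelfAttention.execute_attention(q, k, v, dropout=0.0, attention_impl="dao_flash_v4")
print(y.shape)  # torch.Size([2, 12, 8, 64]), i.e. (B, T, nh_q, hd)
```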

tests/models/test_causal_self_attention.py

Lines changed: 155 additions & 1 deletion
````diff
@@ -1,8 +1,15 @@
 """
-Note: test_attention_types_approximate_equality can print the output of different attention implementations.
+Note: test_attention_types_approximate_equality can print the output of different attention implementations.
 To do so, turn on verbose and run 'pytest tests/models/test_causal_self_attention.py -s'
 """
+
+import os
+import subprocess
+import sys
+import textwrap
 from copy import deepcopy
+from importlib.util import find_spec
+from pathlib import Path
 
 import pytest
 import torch
@@ -17,6 +24,10 @@
 
 torch.manual_seed(0)
 
+FLASH_ATTN_V4_AVAILABLE = find_spec("flash_attn.cute") is not None
+REPO_ROOT = Path(__file__).resolve().parents[2]
+SRC_ROOT = REPO_ROOT / "src"
+
 
 def _get_random_input_seq(embedding_shape):
     flash_attn_supported_dtype = torch.bfloat16
@@ -272,3 +283,146 @@ def test_qk_norm(n_head_q, n_head_kv, n_embd, attention_impl):
 
     assert output_no_norm.shape == output_with_norm.shape == embedding_shape
     assert not torch.allclose(output_no_norm, output_with_norm, atol=1e-6)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.")
+@pytest.mark.skipif(not FLASH_ATTN_V4_AVAILABLE, reason="FA4 not installed")
+def test_dao_flash_v4_forward_mha_subprocess():
+    result = _run_fa4_subprocess(
+        """
+        import torch
+        from modalities.models.gpt2.gpt2_model import CausalSelfAttention
+
+        q = torch.rand(2, 4, 12, 32, dtype=torch.bfloat16, device='cuda')
+        k = torch.rand(2, 4, 12, 32, dtype=torch.bfloat16, device='cuda')
+        v = torch.rand(2, 4, 12, 32, dtype=torch.bfloat16, device='cuda')
+        out = CausalSelfAttention.execute_attention(q, k, v, dropout=0.0, attention_impl='dao_flash_v4')
+        torch.cuda.synchronize()
+        assert tuple(out.shape) == (2, 12, 4, 32)
+        print('ok')
+        """
+    )
+    assert result.stdout.strip().endswith("ok")
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.")
+@pytest.mark.skipif(not FLASH_ATTN_V4_AVAILABLE, reason="FA4 not installed")
+def test_dao_flash_v4_forward_gqa_subprocess():
+    result = _run_fa4_subprocess(
+        """
+        import torch
+        from modalities.models.gpt2.gpt2_model import CausalSelfAttention
+
+        q = torch.rand(2, 8, 12, 32, dtype=torch.bfloat16, device='cuda')
+        k = torch.rand(2, 2, 12, 32, dtype=torch.bfloat16, device='cuda')
+        v = torch.rand(2, 2, 12, 32, dtype=torch.bfloat16, device='cuda')
+        out = CausalSelfAttention.execute_attention(q, k, v, dropout=0.0, attention_impl='dao_flash_v4')
+        torch.cuda.synchronize()
+        assert tuple(out.shape) == (2, 12, 8, 32)
+        print('ok')
+        """
+    )
+    assert result.stdout.strip().endswith("ok")
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.")
+@pytest.mark.skipif(not FLASH_ATTN_V4_AVAILABLE, reason="FA4 not installed")
+def test_dao_flash_v4_qk_norm_subprocess():
+    result = _run_fa4_subprocess(
+        """
+        import torch
+        from modalities.models.gpt2.gpt2_model import (
+            AttentionConfig,
+            CausalSelfAttention,
+            LayerNorms,
+            LayerNormWrapperConfig,
+            PytorchRMSLayerNormConfig,
+        )
+
+        torch.manual_seed(0)
+        attention_config_no_norm = AttentionConfig(qkv_transforms=[])
+        attention_config_with_norm = AttentionConfig(
+            qkv_transforms=[],
+            qk_norm_config=LayerNormWrapperConfig(
+                norm_type=LayerNorms.pytorch_rms_norm,
+                config=PytorchRMSLayerNormConfig(normalized_shape=8),
+            ),
+        )
+
+        torch.manual_seed(0)
+        layer_no_norm = CausalSelfAttention(
+            4, 4, 32, attention_config_no_norm, 'dao_flash_v4', False, 0.0
+        ).cuda().bfloat16()
+        torch.manual_seed(0)
+        layer_with_norm = CausalSelfAttention(
+            4, 4, 32, attention_config_with_norm, 'dao_flash_v4', False, 0.0
+        ).cuda().bfloat16()
+        x = torch.rand((2, 9, 32), dtype=torch.bfloat16, device='cuda')
+        out_no_norm = layer_no_norm(x)
+        out_with_norm = layer_with_norm(x)
+        torch.cuda.synchronize()
+        assert out_no_norm.shape == out_with_norm.shape == (2, 9, 32)
+        assert not torch.allclose(out_no_norm, out_with_norm, atol=1e-6)
+        print('ok')
+        """
+    )
+    assert result.stdout.strip().endswith("ok")
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.")
+@pytest.mark.skipif(not FLASH_ATTN_V4_AVAILABLE, reason="FA4 not installed")
+def test_dao_flash_v4_backward_approximate_equality_subprocess():
+    result = _run_fa4_subprocess(
+        """
+        import torch
+        from modalities.models.gpt2.gpt2_model import CausalSelfAttention
+
+        query_ref = torch.rand((2, 8, 12, 64), dtype=torch.bfloat16, device='cuda', requires_grad=True)
+        key_ref = torch.rand((2, 2, 12, 64), dtype=torch.bfloat16, device='cuda', requires_grad=True)
+        value_ref = torch.rand((2, 2, 12, 64), dtype=torch.bfloat16, device='cuda', requires_grad=True)
+
+        query_fa4 = query_ref.detach().clone().requires_grad_(True)
+        key_fa4 = key_ref.detach().clone().requires_grad_(True)
+        value_fa4 = value_ref.detach().clone().requires_grad_(True)
+
+        output_ref = CausalSelfAttention.execute_attention(
+            query_ref, key_ref, value_ref, dropout=0.0, attention_impl='pytorch_flash'
+        )
+        output_fa4 = CausalSelfAttention.execute_attention(
+            query_fa4, key_fa4, value_fa4, dropout=0.0, attention_impl='dao_flash_v4'
+        )
+        torch.testing.assert_close(output_ref, output_fa4, atol=2.5e-3, rtol=0.016)
+
+        output_ref.float().sum().backward()
+        output_fa4.float().sum().backward()
+        torch.cuda.synchronize()
+
+        torch.testing.assert_close(query_ref.grad, query_fa4.grad, atol=5e-3, rtol=0.02)
+        torch.testing.assert_close(key_ref.grad, key_fa4.grad, atol=5e-3, rtol=0.02)
+        torch.testing.assert_close(value_ref.grad, value_fa4.grad, atol=5e-3, rtol=0.02)
+        print('ok')
+        """
+    )
+    assert result.stdout.strip().endswith("ok")
+
+
+def _run_fa4_subprocess(code: str) -> subprocess.CompletedProcess[str]:
+    """Run flash attention 4 related code in a subprocess to isolate FA4's CUDA context
+    and avoid conflicts with other tests.
+    The code should print 'ok' if it runs successfully.
+    The function returns the CompletedProcess object,
+    which contains stdout and stderr for further inspection if needed.
+    TODO: This might be an A100 / SM80-specific issue, so we can consider removing this subprocess isolation
+    if we confirm that FA4 works well on newer architectures without it.
+    """
+    env = os.environ.copy()
+    existing_pythonpath = env.get("PYTHONPATH")
+    env["PYTHONPATH"] = f"{SRC_ROOT}:{existing_pythonpath}" if existing_pythonpath else str(SRC_ROOT)
+    return subprocess.run(
+        [sys.executable, "-c", textwrap.dedent(code)],
+        cwd=REPO_ROOT,
+        env=env,
+        check=True,
+        capture_output=True,
+        text=True,
+    )
````
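Because every FA4 test skips itself without a GPU or without `flash_attn.cute`, and all of them share the `dao_flash_v4` name prefix, they can be exercised in isolation. An illustrative invocation, assuming a CUDA host (`-k` simply filters by test name):

```sh
uv pip install --prerelease=allow flash-attn-4
pytest tests/models/test_causal_self_attention.py -k dao_flash_v4
```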
