Commit 807930e

fix(attention): Workaround for flash attention 4 not playing nice with torch compile currently.
1 parent: 71bc205

3 files changed: 79 additions & 28 deletions
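The change has two halves, shown in the diffs below: the FlashAttention-4 kernel call moves into a module-level helper marked `@torch.compiler.disable`, so Dynamo never traces into the `flash_attn.cute` internals, and `ModelFactory` downgrades `fullgraph=True` to `fullgraph=False` for affected blocks, since the graph break that the helper introduces would otherwise be a hard error under `fullgraph=True`. As a minimal, self-contained sketch of the eager-escape pattern (standalone illustration, not code from this repository):

import torch

# @torch.compiler.disable excludes a function from Dynamo tracing: the compiled
# caller inserts a graph break and runs the call eagerly instead of failing while
# tracing library internals. torch.softmax stands in for the FA4 kernel here.
@torch.compiler.disable
def eager_kernel(x: torch.Tensor) -> torch.Tensor:
    return torch.softmax(x, dim=-1)

def block(x: torch.Tensor) -> torch.Tensor:
    return eager_kernel(x * 2.0) + 1.0

compiled = torch.compile(block)           # fine: the graph break is tolerated
# torch.compile(block, fullgraph=True)    # would raise, since disable() forces a break
print(compiled(torch.randn(2, 4)).shape)  # torch.Size([2, 4])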

src/modalities/models/gpt2/gpt2_model.py

Lines changed: 27 additions & 20 deletions
@@ -3,7 +3,7 @@
 from abc import abstractmethod
 from enum import Enum
 from importlib import import_module
-from typing import Annotated, Callable, Optional, overload
+from typing import Annotated, Callable, Optional, cast, overload
 
 import torch
 import torch.nn as nn
@@ -73,6 +73,31 @@ def _raise_flash_attn_v4_unavailable() -> None:
     raise NotImplementedError(error_message)
 
 
+@torch.compiler.disable
+def _execute_dao_flash_v4_eager(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+    flash_attn_v4 = get_flash_attn_func_v4()
+    if flash_attn_v4 is None:
+        _raise_flash_attn_v4_unavailable()
+
+    output = flash_attn_v4(
+        q,
+        k,
+        v,
+        causal=True,
+        softmax_scale=None,
+        window_size=(None, None),
+    )
+    return _unwrap_flash_attention_output(cast(torch.Tensor | tuple[torch.Tensor, Optional[torch.Tensor]], output))
+
+
+def _unwrap_flash_attention_output(
+    output: torch.Tensor | tuple[torch.Tensor, Optional[torch.Tensor]],
+) -> torch.Tensor:
+    if isinstance(output, tuple):
+        return output[0]
+    return output
+
+
 class LayerNorms(LookupEnum):
     """
     Enum lookup class for LayerNorms.
@@ -698,7 +723,6 @@ def execute_attention(
             # Note, that the library is not required for the CPU-only tests.
             y = cls._execute_dao_flash_v2(q, k, v, dropout)
         elif attention_impl == AttentionImplementation.DAO_FLASH_V4:
-            flash_attn_v4 = get_flash_attn_func_v4()
             if cls._requires_backward(q, k, v) and torch.cuda.get_device_capability(q.device)[0] < 9:
                 y = cls._execute_dao_flash_v2(q, k, v, dropout)
             else:
@@ -708,28 +732,11 @@ def execute_attention(
                 q = q.transpose(1, 2).contiguous()  # (B, T, nh_q, hd)
                 k = k.transpose(1, 2).contiguous()  # (B, T, nh_kv, hd)
                 v = v.transpose(1, 2).contiguous()  # (B, T, nh_kv, hd)
-                y = cls._unwrap_flash_attention_output(
-                    flash_attn_v4(
-                        q,
-                        k,
-                        v,
-                        causal=True,
-                        softmax_scale=None,
-                        window_size=(None, None),
-                    )
-                )
+                y = _execute_dao_flash_v4_eager(q, k, v)
         else:
             raise NotImplementedError(f"Attention implementation {attention_impl} not supported")
         return y  # (B, T, nh_q, hd)
 
-    @staticmethod
-    def _unwrap_flash_attention_output(
-        output: torch.Tensor | tuple[torch.Tensor, Optional[torch.Tensor]],
-    ) -> torch.Tensor:
-        if isinstance(output, tuple):
-            return output[0]
-        return output
-
     @staticmethod
     def _execute_dao_flash_v2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout: float) -> torch.Tensor:
         q = q.transpose(1, 2).contiguous()  # (B, T, nh_q, hd)
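For reference, the `_execute_dao_flash_v4_eager` helper added above computes causal attention over tensors in (B, T, nh, hd) layout, which is why `execute_attention` transposes before the call. A rough functional equivalent using PyTorch's built-in SDPA (a sketch assuming equal query and key/value head counts; it is not the FA4 kernel):

import torch
import torch.nn.functional as F

# Reference semantics only. F.scaled_dot_product_attention expects (B, nh, T, hd),
# hence the transposes around the (B, T, nh, hd) layout used by the FA4 helper.
# Assumes nh_q == nh_kv; grouped-query attention would need head broadcasting.
def reference_causal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> (B, nh, T, hd)
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return y.transpose(1, 2)  # back to (B, T, nh, hd)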

src/modalities/models/model_factory.py

Lines changed: 19 additions & 1 deletion
@@ -35,6 +35,7 @@
     GPT2LLM,
     AttentionConfig,
     AttentionImplementation,
+    GPT2Block,
     LayerNormWrapperConfig,
     PositionTypes,
     SwiGLU,
@@ -61,6 +62,14 @@
 class ModelFactory:
     """Model factory class to create models."""
 
+    @staticmethod
+    def _requires_graph_break_friendly_compile(module: nn.Module) -> bool:
+        if isinstance(module, GPT2Block):
+            return module.attn.attention_impl == AttentionImplementation.DAO_FLASH_V4
+
+        attention_impl = getattr(module, "attention_impl", None)
+        return attention_impl == AttentionImplementation.DAO_FLASH_V4
+
     @staticmethod
     def _is_model_on_meta_device(model: nn.Module) -> bool:
         """
@@ -402,7 +411,16 @@ def get_parent_module_and_child_name(child_module: nn.Module, model: nn.Module)
         for _, module in model.named_modules():
             if isinstance(module, block_types):
                 options = {"trace.enabled": True} if debug else {}
-                compiled_module = torch.compile(module, fullgraph=fullgraph, options=options)
+                compiled_fullgraph = fullgraph
+                if compiled_fullgraph and ModelFactory._requires_graph_break_friendly_compile(module):
+                    compiled_fullgraph = False
+                    logger.warning(
+                        "Disabling `fullgraph=True` for `%s` because FlashAttention-4 currently graph-breaks under "
+                        "torch.compile when tracing into flash_attn.cute internals.",
+                        module.__class__.__name__,
+                    )
+
+                compiled_module = torch.compile(module, fullgraph=compiled_fullgraph, options=options)
                 parent_module, child_name = get_parent_module_and_child_name(child_module=module, model=model)
                 parent_module.register_module(name=child_name, module=compiled_module)
         return model
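Net effect of the factory change: a user-requested `fullgraph=True` is downgraded, with a warning, only for blocks that use FA4, while every other block still compiles with a full graph. A condensed sketch of that decision (illustrative; the real code dispatches through `ModelFactory._requires_graph_break_friendly_compile` and the project's `AttentionImplementation` enum):

import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)

# Condensed sketch of the downgrade logic added above. The FA4 check is reduced
# to a plain string attribute comparison purely for illustration.
def compile_block(module: nn.Module, fullgraph: bool) -> nn.Module:
    uses_fa4 = getattr(module, "attention_impl", None) == "dao_flash_v4"
    if fullgraph and uses_fa4:
        fullgraph = False  # FA4 graph-breaks under torch.compile; fullgraph=True would raise
        logger.warning("Disabling fullgraph for %s", type(module).__name__)
    return torch.compile(module, fullgraph=fullgraph)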

tests/test_torch_compile.py

Lines changed: 33 additions & 7 deletions
@@ -1,8 +1,12 @@
 import copy
+from typing import Any, cast
 
 import pytest
+import torch
 import torch.nn as nn
+from _pytest.monkeypatch import MonkeyPatch
 
+from modalities.models.components.layer_norms import LayerNormConfig
 from modalities.models.gpt2.gpt2_model import (
     GPT2LLM,
     ActivationType,
@@ -12,27 +16,31 @@
     LayerNormWrapperConfig,
     PositionTypes,
     QueryKeyValueTransformType,
+    is_flash_attn_v4_available,
 )
 from modalities.models.model_factory import ModelFactory
 
 
-def create_gpt2_configs():
+def create_gpt2_configs() -> tuple[AttentionConfig, LayerNormWrapperConfig]:
     attention_config = AttentionConfig(
         qkv_transforms=[
             AttentionConfig.QueryKeyValueTransformConfig(
-                type_hint=QueryKeyValueTransformType.RotaryTransform.name,
+                type_hint=cast(Any, QueryKeyValueTransformType.RotaryTransform.name),
                 config=AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig(
                     n_embd=512, n_head=8, seq_length_dim=-2, base_freq=10000
                 ),
             )
         ]
     )
-    norm_config = LayerNormWrapperConfig(norm_type=LayerNorms.layer_norm, config={"normalized_shape": 512})
+    norm_config = LayerNormWrapperConfig(
+        norm_type=LayerNorms.layer_norm,
+        config=LayerNormConfig(normalized_shape=512, eps=1e-6, elementwise_affine=True, bias=True),
+    )
     return attention_config, norm_config
 
 
 @pytest.fixture
-def gpt2_model():
+def gpt2_model() -> GPT2LLM:
     attention_config, norm_config = create_gpt2_configs()
     model = GPT2LLM(
         sample_key="input_ids",
@@ -58,7 +66,7 @@ def gpt2_model():
     return model
 
 
-def test_get_compiled_model_compiles_blocks(gpt2_model):
+def test_get_compiled_model_compiles_blocks(gpt2_model: GPT2LLM) -> None:
     original_model = copy.deepcopy(gpt2_model)
     original_wte = gpt2_model.transformer.wte
     original_lm_head = gpt2_model.transformer.lm_head
@@ -79,7 +87,7 @@ def test_get_compiled_model_compiles_blocks(gpt2_model):
     assert result_model is gpt2_model, "Should return the same model instance"
 
 
-def test_get_compiled_model_no_matching_blocks(gpt2_model):
+def test_get_compiled_model_no_matching_blocks(gpt2_model: GPT2LLM) -> None:
     """
     Test that get_compiled_model raises a ValueError if no blocks match the specified types.
     """
@@ -88,10 +96,28 @@ def test_get_compiled_model_no_matching_blocks(gpt2_model):
         ModelFactory.get_compiled_model(gpt2_model, block_names=[block_name], fullgraph=True)
 
 
-def test_get_compiled_model_empty_block_names(gpt2_model):
+def test_get_compiled_model_empty_block_names(gpt2_model: GPT2LLM) -> None:
     original_model_dict = dict(gpt2_model.named_modules())
     result_model = ModelFactory.get_compiled_model(gpt2_model, block_names=[], fullgraph=True)
 
     new_model_dict = dict(result_model.named_modules())
     assert new_model_dict == original_model_dict, "Model should remain unchanged with empty block_names"
     assert result_model is gpt2_model, "Should return the same model instance"
+
+
+@pytest.mark.skipif(not is_flash_attn_v4_available(), reason="FA4 not installed")
+def test_get_compiled_model_disables_fullgraph_for_fa4(monkeypatch: MonkeyPatch, gpt2_model: GPT2LLM) -> None:
+    recorded_fullgraph_values: list[bool] = []
+
+    for block in gpt2_model.transformer.h.values():
+        block.attn.attention_impl = AttentionImplementation.DAO_FLASH_V4
+
+    def fake_compile(module: nn.Module, fullgraph: bool, options: dict[str, object]) -> nn.Module:
+        recorded_fullgraph_values.append(fullgraph)
+        return module
+
+    monkeypatch.setattr(torch, "compile", fake_compile)
+
+    ModelFactory.get_compiled_model(gpt2_model, ["GPT2Block"], fullgraph=True)
+
+    assert recorded_fullgraph_values == [False] * len(gpt2_model.transformer.h)
