--- a/tests/models/test_causal_self_attention.py
+++ b/tests/models/test_causal_self_attention.py
@@ -1,8 +1,15 @@
 """
-Note: test_attention_types_approximate_equality can print the output of different attention implementations.
+Note: test_attention_types_approximate_equality can print the output of different attention implementations.
 To do so, turn on verbose and run 'pytest tests/models/test_causal_self_attention.py -s'
 """
+
+import os
+import subprocess
+import sys
+import textwrap
 from copy import deepcopy
+from importlib.util import find_spec
+from pathlib import Path
 
 import pytest
 import torch
@@ -17,6 +24,10 @@
 
 torch.manual_seed(0)
 
+FLASH_ATTN_V4_AVAILABLE = find_spec("flash_attn.cute") is not None
+REPO_ROOT = Path(__file__).resolve().parents[2]
+SRC_ROOT = REPO_ROOT / "src"
+
 
 def _get_random_input_seq(embedding_shape):
     flash_attn_supported_dtype = torch.bfloat16
@@ -272,3 +283,146 @@ def test_qk_norm(n_head_q, n_head_kv, n_embd, attention_impl):
 
     assert output_no_norm.shape == output_with_norm.shape == embedding_shape
     assert not torch.allclose(output_no_norm, output_with_norm, atol=1e-6)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.")
+@pytest.mark.skipif(not FLASH_ATTN_V4_AVAILABLE, reason="FA4 not installed")
+def test_dao_flash_v4_forward_mha_subprocess():
+    result = _run_fa4_subprocess(
+        """
+        import torch
+        from modalities.models.gpt2.gpt2_model import CausalSelfAttention
+
+        q = torch.rand(2, 4, 12, 32, dtype=torch.bfloat16, device='cuda')
+        k = torch.rand(2, 4, 12, 32, dtype=torch.bfloat16, device='cuda')
+        v = torch.rand(2, 4, 12, 32, dtype=torch.bfloat16, device='cuda')
+        out = CausalSelfAttention.execute_attention(q, k, v, dropout=0.0, attention_impl='dao_flash_v4')
+        torch.cuda.synchronize()
+        assert tuple(out.shape) == (2, 12, 4, 32)
+        print('ok')
+        """
+    )
+    assert result.stdout.strip().endswith("ok")
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.")
+@pytest.mark.skipif(not FLASH_ATTN_V4_AVAILABLE, reason="FA4 not installed")
+def test_dao_flash_v4_forward_gqa_subprocess():
+    result = _run_fa4_subprocess(
+        """
+        import torch
+        from modalities.models.gpt2.gpt2_model import CausalSelfAttention
+
+        q = torch.rand(2, 8, 12, 32, dtype=torch.bfloat16, device='cuda')
+        k = torch.rand(2, 2, 12, 32, dtype=torch.bfloat16, device='cuda')
+        v = torch.rand(2, 2, 12, 32, dtype=torch.bfloat16, device='cuda')
+        out = CausalSelfAttention.execute_attention(q, k, v, dropout=0.0, attention_impl='dao_flash_v4')
+        torch.cuda.synchronize()
+        assert tuple(out.shape) == (2, 12, 8, 32)
+        print('ok')
+        """
+    )
+    assert result.stdout.strip().endswith("ok")
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.")
+@pytest.mark.skipif(not FLASH_ATTN_V4_AVAILABLE, reason="FA4 not installed")
+def test_dao_flash_v4_qk_norm_subprocess():
+    result = _run_fa4_subprocess(
+        """
+        import torch
+        from modalities.models.gpt2.gpt2_model import (
+            AttentionConfig,
+            CausalSelfAttention,
+            LayerNorms,
+            LayerNormWrapperConfig,
+            PytorchRMSLayerNormConfig,
+        )
+
+        torch.manual_seed(0)
+        attention_config_no_norm = AttentionConfig(qkv_transforms=[])
+        attention_config_with_norm = AttentionConfig(
+            qkv_transforms=[],
+            qk_norm_config=LayerNormWrapperConfig(
+                norm_type=LayerNorms.pytorch_rms_norm,
+                config=PytorchRMSLayerNormConfig(normalized_shape=8),
+            ),
+        )
+
+        torch.manual_seed(0)
+        layer_no_norm = CausalSelfAttention(
+            4, 4, 32, attention_config_no_norm, 'dao_flash_v4', False, 0.0
+        ).cuda().bfloat16()
+        torch.manual_seed(0)
+        layer_with_norm = CausalSelfAttention(
+            4, 4, 32, attention_config_with_norm, 'dao_flash_v4', False, 0.0
+        ).cuda().bfloat16()
+        x = torch.rand((2, 9, 32), dtype=torch.bfloat16, device='cuda')
+        out_no_norm = layer_no_norm(x)
+        out_with_norm = layer_with_norm(x)
+        torch.cuda.synchronize()
+        assert out_no_norm.shape == out_with_norm.shape == (2, 9, 32)
+        assert not torch.allclose(out_no_norm, out_with_norm, atol=1e-6)
+        print('ok')
+        """
+    )
+    assert result.stdout.strip().endswith("ok")
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.")
+@pytest.mark.skipif(not FLASH_ATTN_V4_AVAILABLE, reason="FA4 not installed")
+def test_dao_flash_v4_backward_approximate_equality_subprocess():
+    result = _run_fa4_subprocess(
+        """
+        import torch
+        from modalities.models.gpt2.gpt2_model import CausalSelfAttention
+
+        query_ref = torch.rand((2, 8, 12, 64), dtype=torch.bfloat16, device='cuda', requires_grad=True)
+        key_ref = torch.rand((2, 2, 12, 64), dtype=torch.bfloat16, device='cuda', requires_grad=True)
+        value_ref = torch.rand((2, 2, 12, 64), dtype=torch.bfloat16, device='cuda', requires_grad=True)
+
+        query_fa4 = query_ref.detach().clone().requires_grad_(True)
+        key_fa4 = key_ref.detach().clone().requires_grad_(True)
+        value_fa4 = value_ref.detach().clone().requires_grad_(True)
+
+        output_ref = CausalSelfAttention.execute_attention(
+            query_ref, key_ref, value_ref, dropout=0.0, attention_impl='pytorch_flash'
+        )
+        output_fa4 = CausalSelfAttention.execute_attention(
+            query_fa4, key_fa4, value_fa4, dropout=0.0, attention_impl='dao_flash_v4'
+        )
+        torch.testing.assert_close(output_ref, output_fa4, atol=2.5e-3, rtol=0.016)
+
+        output_ref.float().sum().backward()
+        output_fa4.float().sum().backward()
+        torch.cuda.synchronize()
+
+        torch.testing.assert_close(query_ref.grad, query_fa4.grad, atol=5e-3, rtol=0.02)
+        torch.testing.assert_close(key_ref.grad, key_fa4.grad, atol=5e-3, rtol=0.02)
+        torch.testing.assert_close(value_ref.grad, value_fa4.grad, atol=5e-3, rtol=0.02)
+        print('ok')
+        """
+    )
+    assert result.stdout.strip().endswith("ok")
+
+
+def _run_fa4_subprocess(code: str) -> subprocess.CompletedProcess[str]:
+    """Run flash attention 4 related code in a subprocess to isolate FA4's CUDA context
+    and avoid conflicts with other tests.
+    The code should print 'ok' if it runs successfully.
+    The function returns the CompletedProcess object,
+    which contains stdout and stderr for further inspection if needed.
+    TODO: This might be an A100 / SM80-specific issue, so we can consider removing this subprocess isolation
+    if we confirm that FA4 works well on newer architectures without it.
+    """
+    env = os.environ.copy()
+    existing_pythonpath = env.get("PYTHONPATH")
+    env["PYTHONPATH"] = f"{SRC_ROOT}:{existing_pythonpath}" if existing_pythonpath else str(SRC_ROOT)
+    return subprocess.run(
+        [sys.executable, "-c", textwrap.dedent(code)],
+        cwd=REPO_ROOT,
+        env=env,
+        check=True,
+        capture_output=True,
+        text=True,
+    )