 import copy
+from typing import Any, cast
 
 import pytest
+import torch
 import torch.nn as nn
+from _pytest.monkeypatch import MonkeyPatch
 
+from modalities.models.components.layer_norms import LayerNormConfig
 from modalities.models.gpt2.gpt2_model import (
     GPT2LLM,
     ActivationType,
     AttentionConfig,
     AttentionImplementation,
     LayerNorms,
     LayerNormWrapperConfig,
     PositionTypes,
     QueryKeyValueTransformType,
+    is_flash_attn_v4_available,
 )
 from modalities.models.model_factory import ModelFactory
 
 
-def create_gpt2_configs():
+def create_gpt2_configs() -> tuple[AttentionConfig, LayerNormWrapperConfig]:
     attention_config = AttentionConfig(
         qkv_transforms=[
             AttentionConfig.QueryKeyValueTransformConfig(
-                type_hint=QueryKeyValueTransformType.RotaryTransform.name,
+                type_hint=cast(Any, QueryKeyValueTransformType.RotaryTransform.name),
                 config=AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig(
                     n_embd=512, n_head=8, seq_length_dim=-2, base_freq=10000
                 ),
             )
         ]
     )
-    norm_config = LayerNormWrapperConfig(norm_type=LayerNorms.layer_norm, config={"normalized_shape": 512})
+    norm_config = LayerNormWrapperConfig(
+        norm_type=LayerNorms.layer_norm,
+        config=LayerNormConfig(normalized_shape=512, eps=1e-6, elementwise_affine=True, bias=True),
+    )
     return attention_config, norm_config
 
 
 @pytest.fixture
-def gpt2_model():
+def gpt2_model() -> GPT2LLM:
     attention_config, norm_config = create_gpt2_configs()
     model = GPT2LLM(
         sample_key="input_ids",
@@ -58,7 +66,7 @@ def gpt2_model():
     return model
 
 
-def test_get_compiled_model_compiles_blocks(gpt2_model):
+def test_get_compiled_model_compiles_blocks(gpt2_model: GPT2LLM) -> None:
     original_model = copy.deepcopy(gpt2_model)
     original_wte = gpt2_model.transformer.wte
     original_lm_head = gpt2_model.transformer.lm_head
@@ -79,7 +87,7 @@ def test_get_compiled_model_compiles_blocks(gpt2_model):
     assert result_model is gpt2_model, "Should return the same model instance"
 
 
-def test_get_compiled_model_no_matching_blocks(gpt2_model):
+def test_get_compiled_model_no_matching_blocks(gpt2_model: GPT2LLM) -> None:
     """
     Test that get_compiled_model raises a ValueError if no blocks match the specified types.
     """
@@ -88,10 +96,28 @@ def test_get_compiled_model_no_matching_blocks(gpt2_model):
         ModelFactory.get_compiled_model(gpt2_model, block_names=[block_name], fullgraph=True)
 
 
-def test_get_compiled_model_empty_block_names(gpt2_model):
+def test_get_compiled_model_empty_block_names(gpt2_model: GPT2LLM) -> None:
     original_model_dict = dict(gpt2_model.named_modules())
     result_model = ModelFactory.get_compiled_model(gpt2_model, block_names=[], fullgraph=True)
 
     new_model_dict = dict(result_model.named_modules())
     assert new_model_dict == original_model_dict, "Model should remain unchanged with empty block_names"
     assert result_model is gpt2_model, "Should return the same model instance"
+
+
+@pytest.mark.skipif(not is_flash_attn_v4_available(), reason="FA4 not installed")
+def test_get_compiled_model_disables_fullgraph_for_fa4(monkeypatch: MonkeyPatch, gpt2_model: GPT2LLM) -> None:
+    recorded_fullgraph_values: list[bool] = []
+
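+    # Force every transformer block onto the FA4 attention implementation.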
+    for block in gpt2_model.transformer.h.values():
+        block.attn.attention_impl = AttentionImplementation.DAO_FLASH_V4
+
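+    # Stand-in for torch.compile that records the fullgraph flag it receives and returns the module uncompiled.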
+    def fake_compile(module: nn.Module, fullgraph: bool, options: dict[str, object]) -> nn.Module:
+        recorded_fullgraph_values.append(fullgraph)
+        return module
+
+    monkeypatch.setattr(torch, "compile", fake_compile)
+
+    ModelFactory.get_compiled_model(gpt2_model, ["GPT2Block"], fullgraph=True)
+
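+    # Every block should have been compiled with fullgraph disabled, even though fullgraph=True was requested.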
+    assert recorded_fullgraph_values == [False] * len(gpt2_model.transformer.h)