@@ -91,52 +91,30 @@ def setUp(self):
9191
9292 try :
9393 from maxdiffusion .pipelines .ltx2 .ltx2_pipeline import MaxTextGemma3FeatureExtractor
94- from MaxText import common_types
95- # Partial mock for config if needed
96- class MockConfig :
97- vocab_size = 32000
98- emb_dim = 16
99- num_layers = 5
100- num_heads = 2
101- head_dim = 8
102- mlp_dim = 32
103- dtype = jnp .float32
104- weights_dtype = jnp .float32
105- weight_dtype = jnp .float32
106- use_iota_embed = False
107- num_decoder_layers = 5 # Match num_layers
108- normalization_layer_epsilon = 1e-6
109- scan_layers = False
110- param_scan_axis = 1
111- max_prefill_predict_length = 512
112- per_device_batch_size = 1
113- max_target_length = 1024
114- rope_min_timescale = 1
115- rope_max_timescale = 10000
116- rope_type = "interleaved"
117- rope_embedding_dims = 16
118- rope_use_scale = False
119- model_name = "gemma3-4b" # Use a valid model name
120- base_emb_dim = 16
121- base_num_query_heads = 2
122- head_dim = 8
123- num_query_heads = 2
124- num_kv_heads = 1
125- dropout_rate = 0.0
126- float32_qk_product = False
127- float32_logits = False
128- sliding_window_size = 128
129- attn_logits_soft_cap = 50.0
130- use_post_attn_norm = True
131- attention = "dot_product" # attention_kernel
132- quantization = "" # for configure_kv_quant
133- quantize_kvcache = False
134- decoder_block = common_types .DecoderBlockType .GEMMA3
135- use_chunked_prefill = False
136- attention_type = "dot_product"
94+ from MaxText .configs import types
95+ from MaxText import pyconfig
96+
97+ # Use real MaxText Config to avoid missing attributes
98+ raw_config = types .MaxTextConfig (
99+ vocab_size = 32000 ,
100+ model_name = "gemma3-4b" ,
101+ base_emb_dim = 16 ,
102+ base_num_query_heads = 2 ,
103+ base_num_kv_heads = 1 ,
104+ head_dim = 8 ,
105+ num_decoder_layers = 5 ,
106+ max_prefill_predict_length = 512 ,
107+ max_target_length = 1024 ,
108+ per_device_batch_size = 1 ,
109+ dtype = jnp .float32 ,
110+ weight_dtype = jnp .float32 ,
111+ rope_embedding_dims = 16 ,
112+ # Add other overrides as needed for the test to be lightweight
113+ )
114+ config = pyconfig .HyperParameters (raw_config )
137115
138116 self .text_encoder = MaxTextGemma3FeatureExtractor (
139- config = MockConfig () ,
117+ config = config ,
140118 mesh = self .mesh ,
141119 quant = None ,
142120 rngs = self .rng
0 commit comments