@@ -86,6 +86,7 @@ def test_transformer_block_shapes(self):
8686 audio_cross_attention_dim = cross_dim ,
8787 activation_fn = "gelu" ,
8888 qk_norm = "rms_norm_across_heads" ,
89+ mesh = self .mesh ,
8990 )
9091
9192 # Create dummy inputs
@@ -188,7 +189,8 @@ def test_transformer_3d_model_instantiation_and_forward(self):
188189 audio_out_channels = self .audio_in_channels ,
189190 audio_num_attention_heads = 2 ,
190191 audio_attention_head_dim = 16 ,
191- audio_cross_attention_dim = 32
192+ audio_cross_attention_dim = 32 ,
193+ mesh = self .mesh ,
192194 )
193195
194196 # Inputs
@@ -283,9 +285,9 @@ def test_scan_remat_parity(self):
283285
284286 # 1. Initialize models
285287 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
286- model_scan = LTX2VideoTransformer3DModel (** args , scan_layers = True )
287- model_loop = LTX2VideoTransformer3DModel (** args , scan_layers = False )
288- model_remat = LTX2VideoTransformer3DModel (** args , scan_layers = True , remat_policy = "full" )
288+ model_scan = LTX2VideoTransformer3DModel (** args , scan_layers = True , mesh = self . mesh )
289+ model_loop = LTX2VideoTransformer3DModel (** args , scan_layers = False , mesh = self . mesh )
290+ model_remat = LTX2VideoTransformer3DModel (** args , scan_layers = True , remat_policy = "full" , mesh = self . mesh )
289291
290292 # 2. Sync weights (crucial for parity)
291293 # We can just copy params from scan to loop/remat
0 commit comments