@@ -29,19 +29,29 @@ def test_load_transformer_weights(self):
2929 pretrained_model_name_or_path = "Lightricks/LTX-2"
3030
3131 with jax .default_device (jax .devices ("cpu" )[0 ]):
32- model = LTX2VideoTransformer3DModel (
33- rngs = self .rngs ,
34- # Explicitly setting key params to version 2.0 to be safe
35- in_channels = 128 ,
36- out_channels = 128 ,
37- patch_size = 1 ,
38- patch_size_t = 1 ,
39- num_attention_heads = 32 ,
40- attention_head_dim = 128 ,
41- cross_attention_dim = 4096 ,
42- num_layers = 48 ,
43- scan_layers = True
44- )
32+ self .config = LTX2VideoConfig ()
33+ self .config .audio_attention_head_dim = 128 # Match Checkpoint
34+
35+ self .transformer = LTX2VideoTransformer3DModel (
36+ in_channels = self .config .in_channels ,
37+ out_channels = self .config .out_channels ,
38+ patch_size = self .config .patch_size ,
39+ patch_size_t = self .config .patch_size_t ,
40+ num_attention_heads = self .config .num_attention_heads ,
41+ attention_head_dim = self .config .attention_head_dim ,
42+ cross_attention_dim = self .config .cross_attention_dim ,
43+ audio_in_channels = self .config .audio_in_channels ,
44+ audio_out_channels = self .config .audio_out_channels ,
45+ audio_patch_size = self .config .audio_patch_size ,
46+ audio_patch_size_t = self .config .audio_patch_size_t ,
47+ audio_num_attention_heads = self .config .audio_num_attention_heads ,
48+ audio_attention_head_dim = 128 , # Match Config/Checkpoint
49+ audio_cross_attention_dim = self .config .audio_cross_attention_dim ,
50+ num_layers = self .config .num_layers ,
51+ scan_layers = True ,
52+ param_dtype = jnp .bfloat16 ,
53+ rngs = nnx .Rngs (0 ),
54+ )
4555
4656 # Get abstract state (shapes only)
4757 # We need the PyTree structure of parameters
0 commit comments