@@ -125,7 +125,7 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
125125 "audio_attention_head_dim" : 64 ,
126126 "audio_cross_attention_dim" : 2048 ,
127127 "num_layers" : 48 ,
128- "caption_channels" : 4096 ,
128+ "caption_channels" : 3840 ,
129129 "audio_caption_channels" : 2048 ,
130130 "use_prompt_embeddings" : False ,
131131 }
@@ -365,7 +365,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
365365 {
366366 "video_connector_num_layers" : 8 ,
367367 "audio_connector_num_layers" : 8 ,
368- "caption_channels" : 2048 ,
368+ "caption_channels" : 3840 ,
369369 "video_caption_channels" : 4096 ,
370370 "audio_caption_channels" : 2048 ,
371371 "video_connector_num_attention_heads" : 32 ,
@@ -1264,6 +1264,11 @@ def __call__(
12641264 timesteps : List [int ] = None ,
12651265 guidance_scale : float = 3.0 ,
12661266 guidance_rescale : float = 0.0 ,
1267+ stg_scale : float = 0.0 ,
1268+ modality_scale : float = 1.0 ,
1269+ audio_guidance_scale : Optional [float ] = None ,
1270+ audio_stg_scale : Optional [float ] = None ,
1271+ audio_modality_scale : Optional [float ] = None ,
12671272 noise_scale : float = 1.0 ,
12681273 num_videos_per_prompt : Optional [int ] = 1 ,
12691274 generator : Optional [jax .Array ] = None ,
@@ -1279,6 +1284,7 @@ def __call__(
12791284 dtype : Optional [jnp .dtype ] = None ,
12801285 output_type : str = "pil" ,
12811286 return_dict : bool = True ,
1287+ use_cross_timestep : bool = False ,
12821288 ):
12831289 # 1. Check inputs
12841290 self .check_inputs (
@@ -1499,23 +1505,24 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
14991505 audio_num_frames ,
15001506 frame_rate ,
15011507 perturbation_mask = perturbation_mask ,
1508+ use_cross_timestep = use_cross_timestep ,
15021509 )
15031510
1504- do_stg = getattr ( self . config , " stg_scale" , 0.0 ) > 0.0
1511+ do_stg = stg_scale > 0.0
15051512
15061513 if guidance_scale > 1.0 and do_stg :
15071514 noise_pred_uncond , noise_pred_text , noise_pred_perturb = jnp .split (noise_pred , 3 , axis = 0 )
15081515 noise_pred = (
15091516 noise_pred_uncond
15101517 + guidance_scale * (noise_pred_text - noise_pred_uncond )
1511- + self . config . stg_scale * (noise_pred_text - noise_pred_perturb )
1518+ + stg_scale * (noise_pred_text - noise_pred_perturb )
15121519 )
15131520 # Audio guidance
15141521 noise_pred_audio_uncond , noise_pred_audio_text , noise_pred_audio_perturb = jnp .split (noise_pred_audio , 3 , axis = 0 )
15151522 noise_pred_audio = (
15161523 noise_pred_audio_uncond
15171524 + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond )
1518- + self . config . stg_scale * (noise_pred_audio_text - noise_pred_audio_perturb )
1525+ + stg_scale * (noise_pred_audio_text - noise_pred_audio_perturb )
15191526 )
15201527 elif guidance_scale > 1.0 :
15211528 noise_pred_uncond , noise_pred_text = jnp .split (noise_pred , 2 , axis = 0 )
@@ -1525,10 +1532,10 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
15251532 noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond )
15261533 elif do_stg :
15271534 noise_pred_text , noise_pred_perturb = jnp .split (noise_pred , 2 , axis = 0 )
1528- noise_pred = noise_pred_text + self . config . stg_scale * (noise_pred_text - noise_pred_perturb )
1535+ noise_pred = noise_pred_text + stg_scale * (noise_pred_text - noise_pred_perturb )
15291536
15301537 noise_pred_audio_text , noise_pred_audio_perturb = jnp .split (noise_pred_audio , 2 , axis = 0 )
1531- noise_pred_audio = noise_pred_audio_text + self . config . stg_scale * (noise_pred_audio_text - noise_pred_audio_perturb )
1538+ noise_pred_audio = noise_pred_audio_text + stg_scale * (noise_pred_audio_text - noise_pred_audio_perturb )
15321539
15331540 # Extract latents_step based on stacking strategy
15341541 if do_cfg and do_stg :
@@ -1693,6 +1700,8 @@ def transformer_forward_pass(
16931700 fps ,
16941701 perturbation_mask = None ,
16951702 sigma = None ,
1703+ audio_sigma = None ,
1704+ use_cross_timestep = False ,
16961705):
16971706 transformer = nnx .merge (graphdef , state )
16981707
@@ -1704,11 +1713,17 @@ def transformer_forward_pass(
17041713 else :
17051714 sigma = jnp .expand_dims (sigma , 0 ).repeat (latents .shape [0 ])
17061715
1716+ if audio_sigma is None :
1717+ audio_sigma = timestep
1718+ else :
1719+ audio_sigma = jnp .expand_dims (audio_sigma , 0 ).repeat (latents .shape [0 ])
1720+
17071721 noise_pred , noise_pred_audio = transformer (
17081722 hidden_states = latents ,
17091723 encoder_hidden_states = encoder_hidden_states ,
17101724 timestep = timestep ,
17111725 sigma = sigma ,
1726+ audio_sigma = audio_sigma ,
17121727 encoder_attention_mask = encoder_attention_mask ,
17131728 num_frames = latent_num_frames ,
17141729 height = latent_height ,
@@ -1720,6 +1735,7 @@ def transformer_forward_pass(
17201735 audio_num_frames = audio_num_frames ,
17211736 return_dict = False ,
17221737 perturbation_mask = perturbation_mask ,
1738+ use_cross_timestep = use_cross_timestep ,
17231739 )
17241740
17251741 return noise_pred , noise_pred_audio
0 commit comments