File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -31,8 +31,8 @@ guidance_scale: 3.0
3131guidance_rescale : 0.7
3232audio_guidance_scale : 7.0
3333audio_guidance_rescale : 0.7
34- stg_scale : 0 .0
35- audio_stg_scale : 0 .0
34+ stg_scale : 1 .0
35+ audio_stg_scale : 1 .0
3636modality_scale : 1.0
3737audio_modality_scale : 1.0
3838use_cross_timestep : true
Original file line number Diff line number Diff line change @@ -388,7 +388,7 @@ def __call__(
388388 if encoder_hidden_states is not None :
389389 encoder_hidden_states = jax .lax .with_sharding_constraint (encoder_hidden_states , axis_names )
390390 if audio_encoder_hidden_states is not None :
391- audio_encoder_hidden_states = jax .lax .with_sharding_constraint (audio_encoder_hidden_states , axis_names )
391+ audio_encoder_hidden_states = jax .lax .with_sharding_constraint (audio_encoder_hidden_states , axis_names_audio )
392392
393393 # 1. Video and Audio Self-Attention
394394 norm_hidden_states = self .norm1 (hidden_states )
You can’t perform that action at this time.
0 commit comments