@@ -184,6 +184,52 @@ def get_dummy_ltx2_inputs(config, pipeline, batch_size):
184184 return (latents , audio_latents , timesteps , encoder_hidden_states , audio_encoder_hidden_states , encoder_attention_mask , audio_encoder_attention_mask )
185185
186186
187+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Interpolate the flow-match timestep shift ``mu`` for a sequence length.

    The shift grows linearly from ``base_shift`` at ``base_seq_len`` tokens to
    ``max_shift`` at ``max_seq_len`` tokens; lengths outside that range
    extrapolate along the same line.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
200+
def retrieve_timesteps(
    scheduler,
    scheduler_state,
    num_inference_steps: Optional[int] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Build a scheduler state for the denoising loop.

    Exactly one schedule source is used: custom ``timesteps`` (not yet
    supported), custom ``sigmas``, or the scheduler's own ``set_timesteps``
    with ``num_inference_steps``.

    Args:
        scheduler: Flax flow-match scheduler; must expose ``dtype`` and
            ``config.num_train_timesteps`` (assumed — TODO confirm on the
            FlaxFlowMatchScheduler wrapper).
        scheduler_state: immutable Flax state; updated via ``.replace``.
        num_inference_steps: step count for the default schedule path.
        timesteps: custom integer timesteps — currently raises.
        sigmas: custom sigma schedule in (0, 1].
        **kwargs: forwarded to ``scheduler.set_timesteps`` on the default
            path. On the custom-``sigmas`` path only ``shift`` is consumed
            (see below); any other keys are ignored there.

    Returns:
        The updated scheduler state carrying ``sigmas``, ``timesteps`` and
        ``num_inference_steps``.

    Raises:
        ValueError: if both ``timesteps`` and ``sigmas`` are given.
        NotImplementedError: if ``timesteps`` is given.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is not None:
        # TODO: Support custom timesteps in FlaxFlowMatchScheduler
        raise NotImplementedError("Custom timesteps not yet supported in FlaxFlowMatchScheduler wrapper.")
    elif sigmas is not None:
        # Manually create state with custom sigmas.
        # Replicates the diffusers logic for an immutable Flax state.
        sigmas = jnp.array(sigmas, dtype=scheduler.dtype)

        # Bug fix: this branch bypasses scheduler.set_timesteps, so the
        # `shift` kwarg (the dynamic-shift `mu` computed by the caller via
        # calculate_shift) used to be silently dropped and the shifted
        # schedule was never applied. Apply the flow-match exponential time
        # shift here, matching diffusers' FlowMatchEulerDiscreteScheduler
        # dynamic shifting: sigma' = e^mu * sigma / (e^mu * sigma + 1 - sigma).
        shift = kwargs.pop("shift", None)
        if shift is not None:
            exp_shift = jnp.exp(shift)
            sigmas = exp_shift * sigmas / (exp_shift * sigmas + (1.0 - sigmas))

        # Assuming scheduler.config.num_train_timesteps exists
        timesteps = sigmas * scheduler.config.num_train_timesteps

        # We need to update the state with these new values
        scheduler_state = scheduler_state.replace(
            sigmas=sigmas,
            timesteps=timesteps,
            num_inference_steps=len(sigmas),
        )
    else:
        scheduler_state = scheduler.set_timesteps(scheduler_state, num_inference_steps, **kwargs)

    return scheduler_state
232+
187233class LTX2Pipeline :
188234 """
189235 Pipeline for LTX-2.
@@ -194,7 +240,7 @@ def __init__(
194240 scheduler : FlaxFlowMatchScheduler ,
195241 vae : LTX2VideoAutoencoderKL ,
196242 audio_vae : FlaxAutoencoderKLLTX2Audio ,
197- text_encoder : Any , # Placeholder for Gemma3
243+ text_encoder : Gemma3ForConditionalGeneration , # Using PyTorch Gemma3 encoder directly per user request
198244 tokenizer : Union [GemmaTokenizer , GemmaTokenizerFast ],
199245 connectors : LTX2AudioVideoGemmaTextEncoder ,
200246 transformer : LTX2VideoTransformer3DModel ,
@@ -717,6 +763,7 @@ def prepare_audio_latents(
717763 dtype : Optional [jnp .dtype ] = None ,
718764 generator : Optional [jax .Array ] = None ,
719765 latents : Optional [jax .Array ] = None ,
766+ num_mel_bins : Optional [int ] = None ,
720767 ) -> jax .Array :
721768 if latents is not None :
722769 # Assuming latents is JAX array or compatible
@@ -822,14 +869,32 @@ def __call__(
822869 batch_size = batch_size ,
823870 num_channels_latents = audio_channels ,
824871 audio_latent_length = audio_num_frames ,
872+ noise_scale = noise_scale ,
825873 dtype = dtype ,
826874 generator = key_audio ,
827- noise_scale = noise_scale ,
875+ latents = audio_latents ,
828876 )
829877
830878 # 5. Prepare Timesteps
831- scheduler_state = self .scheduler .set_timesteps (
832- self .scheduler .create_state (), num_inference_steps = num_inference_steps
879+ sigmas = jnp .linspace (1.0 , 1 / num_inference_steps , num_inference_steps ) if sigmas is None else sigmas
880+
881+ video_sequence_length = (num_frames - 1 ) // self .vae_temporal_compression_ratio + 1
882+ video_sequence_length *= (height // self .vae_spatial_compression_ratio ) * (width // self .vae_spatial_compression_ratio )
883+
884+ mu = calculate_shift (
885+ video_sequence_length ,
886+ self .scheduler .config .get ("base_image_seq_len" , 1024 ),
887+ self .scheduler .config .get ("max_image_seq_len" , 4096 ),
888+ self .scheduler .config .get ("base_shift" , 0.95 ),
889+ self .scheduler .config .get ("max_shift" , 2.05 ),
890+ )
891+
892+ scheduler_state = retrieve_timesteps (
893+ self .scheduler ,
894+ self .scheduler .create_state (),
895+ num_inference_steps = num_inference_steps ,
896+ sigmas = sigmas ,
897+ shift = mu ,
833898 )
834899 timesteps = scheduler_state .timesteps
835900
@@ -941,10 +1006,26 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
9411006 if output_type == "latent" :
9421007 return LTX2PipelineOutput (frames = latents , audio = audio_latents )
9431008
944- # Decode Video
945- # Assuming VAE runs in half precision if configured, but here typically we interpret float32 from scheduler
946- latents = latents .astype (self .vae .dtype )
947- video = self .vae .decode (latents , return_dict = False )[0 ]
1009+ if getattr (self .vae .config , "timestep_conditioning" , False ):
1010+ noise = jax .random .normal (generator , latents .shape , dtype = latents .dtype )
1011+
1012+ if not isinstance (decode_timestep , list ):
1013+ decode_timestep = [decode_timestep ] * batch_size
1014+ if decode_noise_scale is None :
1015+ decode_noise_scale = decode_timestep
1016+ elif not isinstance (decode_noise_scale , list ):
1017+ decode_noise_scale = [decode_noise_scale ] * batch_size
1018+
1019+ timestep = jnp .array (decode_timestep , dtype = latents .dtype )
1020+ decode_noise_scale = jnp .array (decode_noise_scale , dtype = latents .dtype )[:, None , None , None , None ]
1021+
1022+ latents = (1 - decode_noise_scale ) * latents + decode_noise_scale * noise
1023+
1024+ latents = latents .astype (self .vae .dtype )
1025+ video = self .vae .decode (latents , timestep = timestep , return_dict = False )[0 ]
1026+ else :
1027+ latents = latents .astype (self .vae .dtype )
1028+ video = self .vae .decode (latents , return_dict = False )[0 ]
9481029 # Post-process video (converts to numpy/PIL)
9491030 # We need to pass numpy to postprocess_video usually, checking if it handles JAX
9501031 video_np = np .array (video )
@@ -981,37 +1062,23 @@ def transformer_forward_pass(
9811062):
9821063 transformer = nnx .merge (graphdef , state )
9831064
984- # 1. Compute Embeddings
985- temb = transformer .time_embed (timestep )
986- temb_audio = transformer .audio_time_embed (timestep )
987-
988- temb_ca_scale_shift = transformer .av_cross_attn_video_scale_shift (timestep )
989- temb_ca_audio_scale_shift = transformer .av_cross_attn_audio_scale_shift (timestep )
990- temb_ca_gate = transformer .av_cross_attn_video_a2v_gate (timestep )
991- temb_ca_audio_gate = transformer .av_cross_attn_audio_v2a_gate (timestep )
992-
1065+ # Expand timestep to batch size
1066+ timestep = jnp .expand_dims (timestep , 0 ).repeat (latents .shape [0 ])
1067+
9931068 noise_pred , noise_pred_audio = transformer (
9941069 hidden_states = latents ,
995- audio_hidden_states = audio_latents ,
9961070 encoder_hidden_states = encoder_hidden_states ,
997- audio_encoder_hidden_states = audio_encoder_hidden_states ,
998- temb = temb ,
999- temb_audio = temb_audio ,
1000- temb_ca_scale_shift = temb_ca_scale_shift ,
1001- temb_ca_audio_scale_shift = temb_ca_audio_scale_shift ,
1002- temb_ca_gate = temb_ca_gate ,
1003- temb_ca_audio_gate = temb_ca_audio_gate ,
1004- video_rotary_emb = None , # Internally computed via height/width/num_frames
1005- audio_rotary_emb = None ,
1006- ca_video_rotary_emb = None ,
1007- ca_audio_rotary_emb = None ,
1071+ timestep = timestep ,
10081072 encoder_attention_mask = encoder_attention_mask ,
1009- audio_encoder_attention_mask = audio_encoder_attention_mask ,
10101073 num_frames = num_frames ,
10111074 height = height ,
10121075 width = width ,
1076+ audio_hidden_states = audio_latents ,
1077+ audio_encoder_hidden_states = audio_encoder_hidden_states ,
1078+ audio_encoder_attention_mask = audio_encoder_attention_mask ,
10131079 fps = fps ,
10141080 audio_num_frames = audio_num_frames ,
1081+ return_dict = False ,
10151082 )
10161083
10171084 return noise_pred , noise_pred_audio
0 commit comments