Skip to content

Commit 8c75d37

Browse files
committed
final pipeline fix 1
1 parent aedbf16 commit 8c75d37

1 file changed

Lines changed: 97 additions & 30 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 97 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,52 @@ def get_dummy_ltx2_inputs(config, pipeline, batch_size):
184184
return (latents, audio_latents, timesteps, encoder_hidden_states, audio_encoder_hidden_states, encoder_attention_mask, audio_encoder_attention_mask)
185185

186186

187+
188+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
  """Linearly interpolate the timestep-shift value ``mu`` for a sequence length.

  Maps ``base_seq_len -> base_shift`` and ``max_seq_len -> max_shift``, with
  sequence lengths in between (or outside) following the same straight line.

  Args:
    image_seq_len: Number of image/video tokens in the latent sequence.
    base_seq_len: Sequence length anchored to ``base_shift``.
    max_seq_len: Sequence length anchored to ``max_shift``.
    base_shift: Shift value at ``base_seq_len``.
    max_shift: Shift value at ``max_seq_len``.

  Returns:
    The interpolated shift ``mu`` as a float.
  """
  slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
  intercept = base_shift - slope * base_seq_len
  return slope * image_seq_len + intercept
def retrieve_timesteps(
    scheduler,
    scheduler_state,
    num_inference_steps: Optional[int] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
  """Build a scheduler state for sampling, optionally from custom sigmas.

  Args:
    scheduler: Flax flow-match scheduler wrapper; must expose ``dtype``,
      ``config.num_train_timesteps``, and ``set_timesteps``.
    scheduler_state: Immutable scheduler state supporting ``.replace(...)``.
    num_inference_steps: Step count used when neither ``timesteps`` nor
      ``sigmas`` is given.
    timesteps: Custom timesteps — not yet supported by the Flax wrapper.
    sigmas: Custom noise sigmas; mutually exclusive with ``timesteps``.
    **kwargs: Forwarded to ``scheduler.set_timesteps``. On the custom-sigmas
      path, ``shift`` (the ``mu`` from ``calculate_shift``) is consumed here.

  Returns:
    The updated scheduler state with ``sigmas``, ``timesteps`` and
    ``num_inference_steps`` populated.

  Raises:
    ValueError: If both ``timesteps`` and ``sigmas`` are passed.
    NotImplementedError: If custom ``timesteps`` are passed.
  """
  if timesteps is not None and sigmas is not None:
    raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

  if timesteps is not None:
    # TODO: Support custom timesteps in FlaxFlowMatchScheduler
    raise NotImplementedError("Custom timesteps not yet supported in FlaxFlowMatchScheduler wrapper.")

  if sigmas is not None:
    # Manually create state with custom sigmas; replicates the diffusers
    # logic but targets an immutable Flax state.
    sigmas = jnp.array(sigmas, dtype=scheduler.dtype)

    # Bug fix: the `shift` (mu) kwarg was previously ignored on this path,
    # so the dynamic shift computed by `calculate_shift` never took effect.
    # Apply the flow-match exponential time shift:
    #   sigma' = e^mu / (e^mu + (1/sigma - 1))
    # `shift=0` (and shift=None) leaves the sigmas unchanged, preserving
    # the old behavior for callers that did not pass a shift.
    shift = kwargs.pop("shift", None)
    if shift is not None:
      exp_mu = jnp.exp(shift)
      sigmas = exp_mu / (exp_mu + (1.0 / sigmas - 1.0))

    timesteps = sigmas * scheduler.config.num_train_timesteps

    # Update the immutable state with the custom schedule.
    return scheduler_state.replace(
        sigmas=sigmas,
        timesteps=timesteps,
        num_inference_steps=len(sigmas),
    )

  # Default path: let the scheduler build its own schedule.
  return scheduler.set_timesteps(scheduler_state, num_inference_steps, **kwargs)
187233
class LTX2Pipeline:
188234
"""
189235
Pipeline for LTX-2.
@@ -194,7 +240,7 @@ def __init__(
194240
scheduler: FlaxFlowMatchScheduler,
195241
vae: LTX2VideoAutoencoderKL,
196242
audio_vae: FlaxAutoencoderKLLTX2Audio,
197-
text_encoder: Any, # Placeholder for Gemma3
243+
text_encoder: Gemma3ForConditionalGeneration, # Using PyTorch Gemma3 encoder directly per user request
198244
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
199245
connectors: LTX2AudioVideoGemmaTextEncoder,
200246
transformer: LTX2VideoTransformer3DModel,
@@ -717,6 +763,7 @@ def prepare_audio_latents(
717763
dtype: Optional[jnp.dtype] = None,
718764
generator: Optional[jax.Array] = None,
719765
latents: Optional[jax.Array] = None,
766+
num_mel_bins: Optional[int] = None,
720767
) -> jax.Array:
721768
if latents is not None:
722769
# Assuming latents is JAX array or compatible
@@ -822,14 +869,32 @@ def __call__(
822869
batch_size=batch_size,
823870
num_channels_latents=audio_channels,
824871
audio_latent_length=audio_num_frames,
872+
noise_scale=noise_scale,
825873
dtype=dtype,
826874
generator=key_audio,
827-
noise_scale=noise_scale,
875+
latents=audio_latents,
828876
)
829877

830878
# 5. Prepare Timesteps
831-
scheduler_state = self.scheduler.set_timesteps(
832-
self.scheduler.create_state(), num_inference_steps=num_inference_steps
879+
sigmas = jnp.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
880+
881+
video_sequence_length = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
882+
video_sequence_length *= (height // self.vae_spatial_compression_ratio) * (width // self.vae_spatial_compression_ratio)
883+
884+
mu = calculate_shift(
885+
video_sequence_length,
886+
self.scheduler.config.get("base_image_seq_len", 1024),
887+
self.scheduler.config.get("max_image_seq_len", 4096),
888+
self.scheduler.config.get("base_shift", 0.95),
889+
self.scheduler.config.get("max_shift", 2.05),
890+
)
891+
892+
scheduler_state = retrieve_timesteps(
893+
self.scheduler,
894+
self.scheduler.create_state(),
895+
num_inference_steps=num_inference_steps,
896+
sigmas=sigmas,
897+
shift=mu,
833898
)
834899
timesteps = scheduler_state.timesteps
835900

@@ -941,10 +1006,26 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
9411006
if output_type == "latent":
9421007
return LTX2PipelineOutput(frames=latents, audio=audio_latents)
9431008

944-
# Decode Video
945-
# Assuming VAE runs in half precision if configured, but here typically we interpret float32 from scheduler
946-
latents = latents.astype(self.vae.dtype)
947-
video = self.vae.decode(latents, return_dict=False)[0]
1009+
if getattr(self.vae.config, "timestep_conditioning", False):
1010+
noise = jax.random.normal(generator, latents.shape, dtype=latents.dtype)
1011+
1012+
if not isinstance(decode_timestep, list):
1013+
decode_timestep = [decode_timestep] * batch_size
1014+
if decode_noise_scale is None:
1015+
decode_noise_scale = decode_timestep
1016+
elif not isinstance(decode_noise_scale, list):
1017+
decode_noise_scale = [decode_noise_scale] * batch_size
1018+
1019+
timestep = jnp.array(decode_timestep, dtype=latents.dtype)
1020+
decode_noise_scale = jnp.array(decode_noise_scale, dtype=latents.dtype)[:, None, None, None, None]
1021+
1022+
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
1023+
1024+
latents = latents.astype(self.vae.dtype)
1025+
video = self.vae.decode(latents, timestep=timestep, return_dict=False)[0]
1026+
else:
1027+
latents = latents.astype(self.vae.dtype)
1028+
video = self.vae.decode(latents, return_dict=False)[0]
9481029
# Post-process video (converts to numpy/PIL)
9491030
# We need to pass numpy to postprocess_video usually, checking if it handles JAX
9501031
video_np = np.array(video)
@@ -981,37 +1062,23 @@ def transformer_forward_pass(
9811062
):
9821063
transformer = nnx.merge(graphdef, state)
9831064

984-
# 1. Compute Embeddings
985-
temb = transformer.time_embed(timestep)
986-
temb_audio = transformer.audio_time_embed(timestep)
987-
988-
temb_ca_scale_shift = transformer.av_cross_attn_video_scale_shift(timestep)
989-
temb_ca_audio_scale_shift = transformer.av_cross_attn_audio_scale_shift(timestep)
990-
temb_ca_gate = transformer.av_cross_attn_video_a2v_gate(timestep)
991-
temb_ca_audio_gate = transformer.av_cross_attn_audio_v2a_gate(timestep)
992-
1065+
# Expand timestep to batch size
1066+
timestep = jnp.expand_dims(timestep, 0).repeat(latents.shape[0])
1067+
9931068
noise_pred, noise_pred_audio = transformer(
9941069
hidden_states=latents,
995-
audio_hidden_states=audio_latents,
9961070
encoder_hidden_states=encoder_hidden_states,
997-
audio_encoder_hidden_states=audio_encoder_hidden_states,
998-
temb=temb,
999-
temb_audio=temb_audio,
1000-
temb_ca_scale_shift=temb_ca_scale_shift,
1001-
temb_ca_audio_scale_shift=temb_ca_audio_scale_shift,
1002-
temb_ca_gate=temb_ca_gate,
1003-
temb_ca_audio_gate=temb_ca_audio_gate,
1004-
video_rotary_emb=None, # Internally computed via height/width/num_frames
1005-
audio_rotary_emb=None,
1006-
ca_video_rotary_emb=None,
1007-
ca_audio_rotary_emb=None,
1071+
timestep=timestep,
10081072
encoder_attention_mask=encoder_attention_mask,
1009-
audio_encoder_attention_mask=audio_encoder_attention_mask,
10101073
num_frames=num_frames,
10111074
height=height,
10121075
width=width,
1076+
audio_hidden_states=audio_latents,
1077+
audio_encoder_hidden_states=audio_encoder_hidden_states,
1078+
audio_encoder_attention_mask=audio_encoder_attention_mask,
10131079
fps=fps,
10141080
audio_num_frames=audio_num_frames,
1081+
return_dict=False,
10151082
)
10161083

10171084
return noise_pred, noise_pred_audio

0 commit comments

Comments
 (0)