Skip to content

Commit 9b4a16c

Browse files
committed
Make the transformer operate on latent height, width, and number of frames
1 parent 885b441 commit 9b4a16c

1 file changed

Lines changed: 14 additions & 10 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,10 @@ def __call__(
11401140
latents=latents,
11411141
)
11421142

1143+
latent_height = height // self.vae_spatial_compression_ratio
1144+
latent_width = width // self.vae_spatial_compression_ratio
1145+
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
1146+
11431147
# 4. Prepare Audio Latents
11441148
audio_channels = (
11451149
self.audio_vae.config.latent_channels
@@ -1256,9 +1260,9 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12561260
prompt_attention_mask_jax,
12571261
guidance_scale > 1.0,
12581262
guidance_scale,
1259-
num_frames,
1260-
height,
1261-
width,
1263+
latent_num_frames,
1264+
latent_height,
1265+
latent_width,
12621266
audio_num_frames,
12631267
frame_rate,
12641268
)
@@ -1358,7 +1362,7 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13581362

13591363
return LTX2PipelineOutput(frames=video, audio=audio)
13601364

1361-
@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale", "num_frames", "height", "width", "audio_num_frames", "fps"))
1365+
@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale", "latent_num_frames", "latent_height", "latent_width", "audio_num_frames", "fps"))
13621366
def transformer_forward_pass(
13631367
graphdef,
13641368
state,
@@ -1371,9 +1375,9 @@ def transformer_forward_pass(
13711375
audio_encoder_attention_mask,
13721376
do_classifier_free_guidance,
13731377
guidance_scale,
1374-
num_frames,
1375-
height,
1376-
width,
1378+
latent_num_frames,
1379+
latent_height,
1380+
latent_width,
13771381
audio_num_frames,
13781382
fps,
13791383
):
@@ -1387,9 +1391,9 @@ def transformer_forward_pass(
13871391
encoder_hidden_states=encoder_hidden_states,
13881392
timestep=timestep,
13891393
encoder_attention_mask=encoder_attention_mask,
1390-
num_frames=num_frames,
1391-
height=height,
1392-
width=width,
1394+
num_frames=latent_num_frames,
1395+
height=latent_height,
1396+
width=latent_width,
13931397
audio_hidden_states=audio_latents,
13941398
audio_encoder_hidden_states=audio_encoder_hidden_states,
13951399
audio_encoder_attention_mask=audio_encoder_attention_mask,

0 commit comments

Comments
 (0)