Skip to content

Commit 1f2e44b

Browse files
committed
audio frames padded up to the next multiple of 128
1 parent c5c9587 commit 1f2e44b

2 files changed

Lines changed: 4 additions & 0 deletions

File tree

src/maxdiffusion/maxdiffusion_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ def get_dummy_ltx2_inputs(config, pipeline, batch_size):
297297
pipeline.audio_sampling_rate / pipeline.audio_hop_length / float(pipeline.audio_vae_temporal_compression_ratio)
298298
)
299299
audio_num_frames = round(duration_s * audio_latents_per_second)
300+
audio_num_frames = ((audio_num_frames + 127) // 128) * 128
300301

301302
hidden_states = pipeline.prepare_latents(
302303
batch_size,

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,6 +1145,9 @@ def __call__(
11451145
)
11461146
audio_num_frames = round(duration_s * audio_latents_per_second)
11471147

1148+
# Pad the audio sequence length up to the next multiple of 128 so it is evenly divisible by the block sizes used by Pallas flash attention on TPUs
1149+
audio_num_frames = ((audio_num_frames + 127) // 128) * 128
1150+
11481151
audio_latents = self.prepare_audio_latents(
11491152
batch_size=batch_size,
11501153
num_channels_latents=audio_channels,

0 commit comments

Comments
 (0)