Skip to content

Commit 8641526

Browse files
committed
VAE compression ratio change
1 parent eaee6d5 commit 8641526

1 file changed

Lines changed: 9 additions & 2 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,15 @@ def __init__(
235235
self.vae_temporal_compression_ratio = getattr(self.vae, "temporal_compression_ratio", 8)
236236

237237
# Audio VAE compression ratios
238-
self.audio_vae_mel_compression_ratio = getattr(self.audio_vae, "mel_compression_ratio", 4)
239-
self.audio_vae_temporal_compression_ratio = getattr(self.audio_vae, "temporal_compression_ratio", 4)
238+
if hasattr(self.audio_vae, "config") and hasattr(self.audio_vae.config, "patch_size"):
239+
self.audio_vae_mel_compression_ratio = self.audio_vae.config.patch_size
240+
else:
241+
self.audio_vae_mel_compression_ratio = getattr(self.audio_vae, "mel_compression_ratio", 1)
242+
243+
if hasattr(self.audio_vae, "config") and hasattr(self.audio_vae.config, "patch_size_t"):
244+
self.audio_vae_temporal_compression_ratio = self.audio_vae.config.patch_size_t
245+
else:
246+
self.audio_vae_temporal_compression_ratio = getattr(self.audio_vae, "temporal_compression_ratio", 1)
240247

241248
# Transformer patch sizes
242249
self.transformer_spatial_patch_size = getattr(self.transformer.config, "patch_size", 1) if getattr(self, "transformer", None) is not None else 1

0 commit comments

Comments
 (0)