Skip to content

Commit 4a3bf4f

Browse files
committed
test vae load
1 parent a25e818 commit 4a3bf4f

3 files changed

Lines changed: 19 additions & 3 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -315,6 +315,10 @@ def load_vae_weights(
315315
pt_tuple_key = tuple(pt_list)
316316

317317
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
318+
319+
if flax_key in [("latents_mean",), ("latents_std",)]:
320+
flax_key = flax_key + ("value",)
321+
318322
flax_key = _tuple_str_to_int(flax_key)
319323

320324
flax_key_str = [str(x) for x in flax_key]
@@ -545,6 +549,9 @@ def load_audio_vae_weights(
545549
flax_key_parts.append(part)
546550

547551
flax_key = tuple(flax_key_parts)
552+
553+
if flax_key in [("latents_mean",), ("latents_std",)]:
554+
flax_key = flax_key + ("value",)
548555

549556
if "up_stages" in flax_key:
550557
try:

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1144,9 +1144,6 @@ def __call__(
11441144
)
11451145
audio_num_frames = round(duration_s * audio_latents_per_second)
11461146

1147-
# Pad audio sequence length to cleanly divide block sizes for Pallas flash attention on TPUs
1148-
audio_num_frames = ((audio_num_frames + 127) // 128) * 128
1149-
11501147
audio_latents = self.prepare_audio_latents(
11511148
batch_size=batch_size,
11521149
num_channels_latents=audio_channels,

test_vae_load.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,12 @@
1+
import jax
2+
import sys
3+
import maxdiffusion.pyconfig as pyconfig
4+
from maxdiffusion.pipelines.ltx2.ltx2_pipeline import LTX2Pipeline
5+
6+
argv = ["", "src/maxdiffusion/configs/ltx2_video.yml"]
7+
pyconfig.initialize(argv)
8+
9+
pipeline = LTX2Pipeline.from_pretrained(pyconfig.config, vae_only=True)
10+
print("latents_mean:", pipeline.vae.latents_mean.value[:10])
11+
print("latents_std:", pipeline.vae.latents_std.value[:10])
12+
print("audio_latents_mean:", pipeline.audio_vae.latents_mean.value[:10])

0 commit comments

Comments
 (0)