Skip to content

Commit 2724462

Browse files
committed
vae latents std and mean fix
1 parent 045dfac commit 2724462

4 files changed

Lines changed: 15 additions & 19 deletions

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,8 +1197,8 @@ def __init__(
11971197
)
11981198

11991199
self.scaling_factor = scaling_factor
1200-
self.latents_mean = tuple([0.0] * latent_channels)
1201-
self.latents_std = tuple([1.0] * latent_channels)
1200+
self.latents_mean = nnx.Param(jnp.zeros((latent_channels,), dtype=dtype))
1201+
self.latents_std = nnx.Param(jnp.ones((latent_channels,), dtype=dtype))
12021202
self.encoder_causal = encoder_causal
12031203
self.decoder_causal = decoder_causal
12041204

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2_audio.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -707,8 +707,8 @@ def __init__(
707707
is_causal=is_causal
708708
)
709709

710-
self.latents_mean = tuple([0.0] * base_channels)
711-
self.latents_std = tuple([1.0] * base_channels)
710+
self.latents_mean = nnx.Param(jnp.zeros((base_channels,), dtype=dtype))
711+
self.latents_std = nnx.Param(jnp.ones((base_channels,), dtype=dtype))
712712

713713
def _normalize_latents(self, h: jnp.ndarray) -> jnp.ndarray:
714714
if self.double_z:
@@ -721,7 +721,7 @@ def _normalize_latents(self, h: jnp.ndarray) -> jnp.ndarray:
721721

722722
# Normalize means ONLY
723723
means_patched = self.patchifier.patchify(means)
724-
means_normalized = (means_patched - jnp.array(self.latents_mean, dtype=means_patched.dtype)) / jnp.array(self.latents_std, dtype=means_patched.dtype)
724+
means_normalized = (means_patched - self.latents_mean.value.astype(means_patched.dtype)) / self.latents_std.value.astype(means_patched.dtype)
725725
means_normalized = self.patchifier.unpatchify(means_normalized, channels, freq)
726726

727727
if logvars is not None:
@@ -734,7 +734,7 @@ def _denormalize_latents(self, z: jnp.ndarray) -> Tuple[jnp.ndarray, Tuple[int,
734734

735735
# Denormalize latents (which are just means)
736736
patched_z = self.patchifier.patchify(z)
737-
denorm_patched_z = (patched_z * jnp.array(self.latents_std, dtype=patched_z.dtype)) + jnp.array(self.latents_mean, dtype=patched_z.dtype)
737+
denorm_patched_z = (patched_z * self.latents_std.value.astype(patched_z.dtype)) + self.latents_mean.value.astype(patched_z.dtype)
738738
z = self.patchifier.unpatchify(denorm_patched_z, channels, freq)
739739

740740
target_frames = time * LATENT_DOWNSAMPLE_FACTOR

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -271,9 +271,7 @@ def load_vae_weights(
271271
random_flax_state_dict[string_tuple] = flattened_eval[key]
272272

273273
for pt_key, tensor in tensors.items():
274-
# Diffusers saves static tensors for these, but they are defined as static tuples in Flax.
275-
if pt_key in ["latents_mean", "latents_std"]:
276-
continue
274+
# latents_mean and latents_std are nnx.Params and will be loaded correctly.
277275
renamed_pt_key = rename_key(pt_key)
278276
renamed_pt_key = renamed_pt_key.replace("nin_shortcut", "conv_shortcut")
279277

@@ -529,8 +527,6 @@ def load_audio_vae_weights(
529527
random_flax_state_dict[string_tuple] = flattened_eval[key]
530528

531529
for pt_key, tensor in tensors.items():
532-
if pt_key in ["latents_mean", "latents_std"]:
533-
continue
534530
key = rename_for_ltx2_audio_vae(pt_key)
535531

536532
should_transpose = False

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -994,8 +994,8 @@ def prepare_latents(
994994
) -> jax.Array:
995995
if latents is not None:
996996
if latents.ndim == 5:
997-
latents_mean = jnp.array(self.vae.latents_mean)
998-
latents_std = jnp.array(self.vae.latents_std)
997+
latents_mean = self.vae.latents_mean.value
998+
latents_std = self.vae.latents_std.value
999999
scaling_factor = self.vae.config.scaling_factor if hasattr(self.vae.config, "scaling_factor") else 1.0
10001000

10011001
latents = self._normalize_latents(latents, latents_mean, latents_std, scaling_factor)
@@ -1045,8 +1045,8 @@ def prepare_audio_latents(
10451045
if latents.ndim != 3:
10461046
raise ValueError("Unexpected audio latents shape")
10471047

1048-
latents_mean = jnp.array(self.audio_vae.latents_mean)
1049-
latents_std = jnp.array(self.audio_vae.latents_std)
1048+
latents_mean = self.audio_vae.latents_mean.value
1049+
latents_std = self.audio_vae.latents_std.value
10501050

10511051
latents = self._normalize_audio_latents(latents, latents_mean, latents_std)
10521052
latents = self._create_noised_state(latents, noise_scale, generator)
@@ -1294,8 +1294,8 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12941294
)
12951295
latents = self._denormalize_latents(
12961296
latents,
1297-
jnp.array(self.vae.latents_mean),
1298-
jnp.array(self.vae.latents_std),
1297+
self.vae.latents_mean.value,
1298+
self.vae.latents_std.value,
12991299
self.vae.config.scaling_factor
13001300
)
13011301

@@ -1305,8 +1305,8 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13051305
# Denormalize and Unpack Audio (Order important: Denorm THEN Unpack)
13061306
audio_latents = self._denormalize_audio_latents(
13071307
audio_latents_jax,
1308-
jnp.array(self.audio_vae.latents_mean),
1309-
jnp.array(self.audio_vae.latents_std)
1308+
self.audio_vae.latents_mean.value,
1309+
self.audio_vae.latents_std.value
13101310
)
13111311

13121312
num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64

0 commit comments

Comments
 (0)