Skip to content

Commit 668390c

Browse files
committed
new fix
1 parent d9a96aa commit 668390c

2 files changed

Lines changed: 80 additions & 9 deletions

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -655,14 +655,16 @@ def __init__(
655655
):
656656
out_channels = out_channels or in_channels
657657

658-
self.time_embedder = nnx.data(NNXPixArtAlphaCombinedTimestepSizeEmbeddings(
659-
rngs=rngs,
660-
embedding_dim=in_channels * 4,
661-
size_emb_dim=0,
662-
use_additional_conditions=False,
663-
dtype=dtype,
664-
weights_dtype=weights_dtype
665-
))
658+
self.time_embedder = None
659+
if timestep_conditioning:
660+
self.time_embedder = nnx.data(NNXPixArtAlphaCombinedTimestepSizeEmbeddings(
661+
rngs=rngs,
662+
embedding_dim=in_channels * 4,
663+
size_emb_dim=0,
664+
use_additional_conditions=False,
665+
dtype=dtype,
666+
weights_dtype=weights_dtype
667+
))
666668

667669
if in_channels != out_channels:
668670
self.conv_in = nnx.data(LTX2VideoResnetBlock3d(
@@ -1068,6 +1070,60 @@ def __call__(
10681070
return hidden_states
10691071

10701072

1073+
1074+
class LTX2DiagonalGaussianDistribution(nnx.Module):
1075+
def __init__(self, parameters: jax.Array, deterministic: bool = False):
1076+
self.parameters = parameters
1077+
# Split into mean and logvar
1078+
# LTX-2 specific: 128 channels for mean, 1 channel for logvar
1079+
self.mean, self.logvar = jnp.split(parameters, [128], axis=-1)
1080+
self.logvar = jnp.clip(self.logvar, -30.0, 20.0)
1081+
self.deterministic = deterministic
1082+
self.std = jnp.exp(0.5 * self.logvar)
1083+
self.var = jnp.exp(self.logvar)
1084+
if self.deterministic:
1085+
self.var = self.std = jnp.zeros_like(
1086+
self.mean, dtype=self.parameters.dtype
1087+
)
1088+
1089+
def sample(self, key: jax.Array) -> jax.Array:
1090+
# make sure sample is on the same device as the parameters and has same dtype
1091+
sample = jax.random.normal(key, self.mean.shape, dtype=self.parameters.dtype)
1092+
x = self.mean + self.std * sample
1093+
return x
1094+
1095+
def kl(self, other: "LTX2DiagonalGaussianDistribution" = None) -> jax.Array:
1096+
if self.deterministic:
1097+
return jnp.array([0.0])
1098+
else:
1099+
if other is None:
1100+
return 0.5 * jnp.sum(
1101+
jnp.power(self.mean, 2) + self.var - 1.0 - self.logvar,
1102+
axis=[1, 2, 3],
1103+
)
1104+
else:
1105+
return 0.5 * jnp.sum(
1106+
jnp.power(self.mean - other.mean, 2) / other.var
1107+
+ self.var / other.var
1108+
- 1.0
1109+
- self.logvar
1110+
+ other.logvar,
1111+
axis=[1, 2, 3],
1112+
)
1113+
1114+
def nll(self, sample: jax.Array, dims: Tuple[int, ...] = (1, 2, 3)) -> jax.Array:
1115+
if self.deterministic:
1116+
return jnp.array([0.0])
1117+
logtwopi = jnp.log(2.0 * jnp.pi)
1118+
return 0.5 * jnp.sum(
1119+
logtwopi + self.logvar + jnp.power(sample - self.mean, 2) / self.var,
1120+
axis=dims,
1121+
)
1122+
1123+
def mode(self) -> jax.Array:
1124+
return self.mean
1125+
1126+
10711127
class LTX2VideoAutoencoderKL(nnx.Module, ConfigMixin):
10721128
_supports_gradient_checkpointing = True
10731129
config_name = "config.json"
@@ -1510,7 +1566,8 @@ def encode(
15101566
else:
15111567
moments = self._encode(sample, key=key, causal=causal)
15121568

1513-
posterior = FlaxDiagonalGaussianDistribution(moments)
1569+
1570+
posterior = LTX2DiagonalGaussianDistribution(moments)
15141571

15151572
if not return_dict:
15161573
return (posterior,)

src/maxdiffusion/tests/ltx2_vae_parity_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,5 +130,19 @@ def fix_keys(d):
130130
print(f"Latents Mean: {latents.mean():.6f}")
131131
print(f"Latents Std: {latents.std():.6f}")
132132

133+
# Assertions
134+
# 1. Check Output Shape
135+
assert jax_recon.shape == jax_input.shape, f"Output shape mismatch: {jax_recon.shape} vs {jax_input.shape}"
136+
137+
# 2. Check Latents Shape (Mean should be 128 channels)
138+
assert latents.shape[-1] == 128, f"Latents channel mismatch: {latents.shape[-1]} vs 128"
139+
140+
# 3. Check Encoder Output Channels (should be 129 before split)
141+
# We can check parameters of the distribution indirectly via moments if accessible,
142+
# but here we checked latents (mode) which is derived from mean (128).
143+
# The fact that it ran without error implies the split worked.
144+
145+
print("\nTest Passed!")
146+
133147
if __name__ == "__main__":
134148
test_ltx2_vae_parity()

0 commit comments

Comments
 (0)