Skip to content

Commit d9a96aa

Browse files
committed
fix
1 parent 96a341d commit d9a96aa

1 file changed

Lines changed: 11 additions & 69 deletions

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 11 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -227,16 +227,6 @@ def __init__(
227227
self.per_channel_scale2 = None
228228

229229
if timestep_conditioning:
230-
self.time_embedder = nnx.data(NNXPixArtAlphaCombinedTimestepSizeEmbeddings(
231-
rngs=rngs,
232-
embedding_dim=in_channels * 4,
233-
size_emb_dim=0,
234-
use_additional_conditions=False,
235-
dtype=dtype,
236-
weights_dtype=weights_dtype
237-
))
238-
else:
239-
self.time_embedder = None
240230
self.scale_shift_table = nnx.Param(
241231
jax.random.normal(rngs.params(), (4, in_channels)) / (in_channels ** 0.5)
242232
)
@@ -665,25 +655,14 @@ def __init__(
665655
):
666656
out_channels = out_channels or in_channels
667657

668-
if timestep_conditioning:
669-
self.time_embedder = nnx.data(NNXPixArtAlphaCombinedTimestepSizeEmbeddings(
670-
rngs=rngs,
671-
embedding_dim=in_channels * 4,
672-
size_emb_dim=0,
673-
use_additional_conditions=False,
674-
dtype=dtype,
675-
weights_dtype=weights_dtype
676-
))
677-
else:
678-
self.time_embedder = None
679-
self.time_embedder = nnx.data(NNXPixArtAlphaCombinedTimestepSizeEmbeddings(
680-
rngs=rngs,
681-
embedding_dim=in_channels * 4,
682-
size_emb_dim=0,
683-
use_additional_conditions=False,
684-
dtype=dtype,
685-
weights_dtype=weights_dtype
686-
))
658+
self.time_embedder = nnx.data(NNXPixArtAlphaCombinedTimestepSizeEmbeddings(
659+
rngs=rngs,
660+
embedding_dim=in_channels * 4,
661+
size_emb_dim=0,
662+
use_additional_conditions=False,
663+
dtype=dtype,
664+
weights_dtype=weights_dtype
665+
))
687666

688667
if in_channels != out_channels:
689668
self.conv_in = nnx.data(LTX2VideoResnetBlock3d(
@@ -911,11 +890,6 @@ def __call__(
911890
hidden_states = self.conv_act(hidden_states)
912891
hidden_states = self.conv_out(hidden_states, causal=causal)
913892

914-
# LTX-2 specific output expansion
915-
last_channel = hidden_states[..., -1:]
916-
repeats = 127 # 256 - 129
917-
last_channel_repeated = jnp.repeat(last_channel, repeats, axis=-1)
918-
hidden_states = jnp.concatenate([hidden_states, last_channel_repeated], axis=-1)
919893

920894

921895
return hidden_states
@@ -1033,6 +1007,7 @@ def __init__(
10331007
self.scale_shift_table = None
10341008
self.timestep_scale_multiplier = None
10351009
if timestep_conditioning:
1010+
self.timestep_scale_multiplier = nnx.Param(jnp.array(1000.0, dtype=jnp.float32))
10361011
self.time_embedder = nnx.data(NNXPixArtAlphaCombinedTimestepSizeEmbeddings(
10371012
rngs=rngs,
10381013
embedding_dim=in_channels * 4,
@@ -1041,20 +1016,9 @@ def __init__(
10411016
dtype=dtype,
10421017
weights_dtype=weights_dtype
10431018
))
1044-
else:
1019+
else:
1020+
self.timestep_scale_multiplier = None
10451021
self.time_embedder = None
1046-
self.timestep_scale_multiplier = nnx.Param(jnp.array(1000.0, dtype=jnp.float32))
1047-
self.time_embedder = nnx.data(NNXPixArtAlphaCombinedTimestepSizeEmbeddings(
1048-
rngs=rngs,
1049-
embedding_dim=output_channel * 2,
1050-
size_emb_dim=0,
1051-
use_additional_conditions=False,
1052-
dtype=dtype,
1053-
weights_dtype=weights_dtype
1054-
))
1055-
self.scale_shift_table = nnx.Param(
1056-
jax.random.normal(rngs.params(), (2, output_channel)) / (output_channel ** 0.5)
1057-
)
10581022

10591023
@nnx.jit(static_argnames=("causal", "deterministic"))
10601024
def __call__(
@@ -1084,31 +1048,9 @@ def __call__(
10841048

10851049
hidden_states = self.norm_out(hidden_states)
10861050

1087-
if self.time_embedder is not None:
1088-
temb = self.time_embedder(timestep=temb.flatten(), hidden_dtype=hidden_states.dtype)
1089-
1090-
B = hidden_states.shape[0]
1091-
C = hidden_states.shape[-1]
1092-
1093-
temb = temb.reshape(B, 2, C)
1094-
1095-
# Add table
1096-
params = self.scale_shift_table.value[None, :, :] + temb
1097-
1098-
shift = params[:, 0, :]
1099-
scale = params[:, 1, :]
1100-
1101-
# Broadcast
1102-
hidden_states = hidden_states * (1 + scale[:, None, None, None, :]) + shift[:, None, None, None, :]
1103-
11041051
hidden_states = self.conv_act(hidden_states)
11051052
hidden_states = self.conv_out(hidden_states, causal=causal)
11061053

1107-
# LTX-2 specific output expansion
1108-
last_channel = hidden_states[..., -1:]
1109-
repeats = 127 # 256 - 129
1110-
last_channel_repeated = jnp.repeat(last_channel, repeats, axis=-1)
1111-
hidden_states = jnp.concatenate([hidden_states, last_channel_repeated], axis=-1)
11121054

11131055
# Unpatchify
11141056
B, T, H, W, C = hidden_states.shape

0 commit comments

Comments
 (0)