Skip to content

Commit 96a341d

Browse files
committed
fix
1 parent 9550040 commit 96a341d

1 file changed

Lines changed: 40 additions & 8 deletions

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,16 @@ def __init__(
227227
self.per_channel_scale2 = None
228228

229229
if timestep_conditioning:
230+
self.time_embedder = nnx.data(NNXPixArtAlphaCombinedTimestepSizeEmbeddings(
231+
rngs=rngs,
232+
embedding_dim=in_channels * 4,
233+
size_emb_dim=0,
234+
use_additional_conditions=False,
235+
dtype=dtype,
236+
weights_dtype=weights_dtype
237+
))
238+
else:
239+
self.time_embedder = None
230240
self.scale_shift_table = nnx.Param(
231241
jax.random.normal(rngs.params(), (4, in_channels)) / (in_channels ** 0.5)
232242
)
@@ -573,14 +583,16 @@ def __init__(
573583
precision: jax.lax.Precision = None,
574584
):
575585
if timestep_conditioning:
576-
self.time_embedder = nnx.data(NNXPixArtAlphaCombinedTimestepSizeEmbeddings(
577-
rngs=rngs,
578-
embedding_dim=in_channels * 4,
579-
size_emb_dim=0,
580-
use_additional_conditions=False,
581-
dtype=dtype,
582-
weights_dtype=weights_dtype
583-
))
586+
self.time_embedder = nnx.data(NNXPixArtAlphaCombinedTimestepSizeEmbeddings(
587+
rngs=rngs,
588+
embedding_dim=in_channels * 4,
589+
size_emb_dim=0,
590+
use_additional_conditions=False,
591+
dtype=dtype,
592+
weights_dtype=weights_dtype
593+
))
594+
else:
595+
self.time_embedder = None
584596

585597
self.resnets = nnx.List([
586598
LTX2VideoResnetBlock3d(
@@ -654,6 +666,16 @@ def __init__(
654666
out_channels = out_channels or in_channels
655667

656668
if timestep_conditioning:
669+
self.time_embedder = nnx.data(NNXPixArtAlphaCombinedTimestepSizeEmbeddings(
670+
rngs=rngs,
671+
embedding_dim=in_channels * 4,
672+
size_emb_dim=0,
673+
use_additional_conditions=False,
674+
dtype=dtype,
675+
weights_dtype=weights_dtype
676+
))
677+
else:
678+
self.time_embedder = None
657679
self.time_embedder = nnx.data(NNXPixArtAlphaCombinedTimestepSizeEmbeddings(
658680
rngs=rngs,
659681
embedding_dim=in_channels * 4,
@@ -1011,6 +1033,16 @@ def __init__(
10111033
self.scale_shift_table = None
10121034
self.timestep_scale_multiplier = None
10131035
if timestep_conditioning:
1036+
self.time_embedder = nnx.data(NNXPixArtAlphaCombinedTimestepSizeEmbeddings(
1037+
rngs=rngs,
1038+
embedding_dim=in_channels * 4,
1039+
size_emb_dim=0,
1040+
use_additional_conditions=False,
1041+
dtype=dtype,
1042+
weights_dtype=weights_dtype
1043+
))
1044+
else:
1045+
self.time_embedder = None
10141046
self.timestep_scale_multiplier = nnx.Param(jnp.array(1000.0, dtype=jnp.float32))
10151047
self.time_embedder = nnx.data(NNXPixArtAlphaCombinedTimestepSizeEmbeddings(
10161048
rngs=rngs,

0 commit comments

Comments
 (0)