Skip to content

Commit 330b0ce

Browse files
committed
text encoder used from hf
1 parent e009f3d commit 330b0ce

3 files changed

Lines changed: 499 additions & 169 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from maxdiffusion.models.embeddings_flax import NNXPixArtAlphaCombinedTimestepSizeEmbeddings, NNXPixArtAlphaTextProjection
2525
from maxdiffusion.models.gradient_checkpoint import GradientCheckpointType
2626
from maxdiffusion.common_types import BlockSizes
27+
from maxdiffusion.configuration_utils import ConfigMixin, register_to_config
28+
from maxdiffusion.configuration_utils import ConfigMixin, register_to_config
2729

2830

2931
class LTX2AdaLayerNormSingle(nnx.Module):
@@ -483,8 +485,9 @@ def __call__(
483485
return hidden_states, audio_hidden_states
484486

485487

486-
class LTX2VideoTransformer3DModel(nnx.Module):
488+
class LTX2VideoTransformer3DModel(nnx.Module, ConfigMixin):
487489

490+
@register_to_config
488491
def __init__(
489492
self,
490493
rngs: nnx.Rngs,

0 commit comments

Comments
 (0)