Skip to content

Commit 5757481

Browse files
committed
import error fix
1 parent e1ecd5b commit 5757481

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxdiffusion/tests/test_ltx2_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import jax.numpy as jnp
66
from flax import nnx
77
from maxdiffusion.models.ltx2.transformer_ltx2 import LTX2VideoTransformer3DModel
8-
from maxdiffusion.models.ltx2.autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
8+
from maxdiffusion.models.ltx2.autoencoder_kl_ltx2 import LTX2VideoAutoencoderKL
99
from maxdiffusion.models.ltx2.ltx2_utils import load_transformer_weights, load_vae_weights
1010
from maxdiffusion.models.modeling_flax_pytorch_utils import validate_flax_state_dict
1111
from flax.traverse_util import flatten_dict
@@ -84,7 +84,7 @@ def test_load_vae_weights(self):
8484
pretrained_model_name_or_path = "Lightricks/LTX-2"
8585

8686
with jax.default_device(jax.devices("cpu")[0]):
87-
model = AutoencoderKLLTX2Video(
87+
model = LTX2VideoAutoencoderKL(
8888
rngs=self.rngs,
8989
# Defaults:
9090
in_channels=3,

0 commit comments

Comments
 (0)