File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 55import jax .numpy as jnp
66from flax import nnx
77from maxdiffusion .models .ltx2 .transformer_ltx2 import LTX2VideoTransformer3DModel
8- from maxdiffusion .models .ltx2 .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
8+ from maxdiffusion .models .ltx2 .autoencoder_kl_ltx2 import LTX2VideoAutoencoderKL
99from maxdiffusion .models .ltx2 .ltx2_utils import load_transformer_weights , load_vae_weights
1010from maxdiffusion .models .modeling_flax_pytorch_utils import validate_flax_state_dict
1111from flax .traverse_util import flatten_dict
@@ -84,7 +84,7 @@ def test_load_vae_weights(self):
8484 pretrained_model_name_or_path = "Lightricks/LTX-2"
8585
8686 with jax .default_device (jax .devices ("cpu" )[0 ]):
87- model = AutoencoderKLLTX2Video (
87+ model = LTX2VideoAutoencoderKL (
8888 rngs = self .rngs ,
8989 # Defaults:
9090 in_channels = 3 ,
You can’t perform that action at this time.
0 commit comments