Skip to content

Commit d8c64f2

Browse files
committed
vae config hardcoded
1 parent 9c3ce8b commit d8c64f2

1 file changed

Lines changed: 26 additions & 9 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -382,19 +382,36 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
382382
{
383383
"block_out_channels": (256, 512, 1024, 1024),
384384
"decoder_block_out_channels": (256, 512, 512, 1024),
385-
"upsample_residual": False,
386-
"upsample_factor": (2, 2, 1, 2),
385+
"layers_per_block": (4, 6, 4, 2, 2),
386+
"decoder_layers_per_block": (4, 6, 4, 2, 2),
387+
"spatio_temporal_scaling": (True, True, True, True),
388+
"decoder_spatio_temporal_scaling": (True, True, True, True),
389+
"decoder_inject_noise": (False, False, False, False, False),
390+
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
387391
"upsample_type": ("spatiotemporal", "spatiotemporal", "temporal", "spatial"),
392+
"upsample_residual": (False, False, False, False),
393+
"upsample_factor": (2, 2, 1, 2),
394+
"patch_size": 4,
395+
"patch_size_t": 1,
396+
"resnet_norm_eps": 1e-6,
397+
"encoder_causal": True,
398+
"decoder_causal": False,
399+
"encoder_spatial_padding_mode": "zeros",
388400
"decoder_spatial_padding_mode": "zeros",
389401
}
390402
)
391-
vae = LTX2VideoAutoencoderKL.from_config(
392-
config.pretrained_model_name_or_path,
393-
subfolder="vae",
394-
rngs=rngs,
395-
mesh=mesh,
396-
**vae_kwargs,
397-
)
403+
vae = LTX2VideoAutoencoderKL(
404+
rngs=rngs,
405+
**vae_kwargs,
406+
)
407+
else:
408+
vae = LTX2VideoAutoencoderKL.from_config(
409+
config.pretrained_model_name_or_path,
410+
subfolder="vae",
411+
rngs=rngs,
412+
mesh=mesh,
413+
**vae_kwargs,
414+
)
398415
return vae
399416

400417
p_model_factory = partial(create_model, config=config)

0 commit comments

Comments
 (0)