@@ -382,19 +382,36 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
382382 {
383383 "block_out_channels" : (256 , 512 , 1024 , 1024 ),
384384 "decoder_block_out_channels" : (256 , 512 , 512 , 1024 ),
385- "upsample_residual" : False ,
386- "upsample_factor" : (2 , 2 , 1 , 2 ),
385+ "layers_per_block" : (4 , 6 , 4 , 2 , 2 ),
386+ "decoder_layers_per_block" : (4 , 6 , 4 , 2 , 2 ),
387+ "spatio_temporal_scaling" : (True , True , True , True ),
388+ "decoder_spatio_temporal_scaling" : (True , True , True , True ),
389+ "decoder_inject_noise" : (False , False , False , False , False ),
390+ "downsample_type" : ("spatial" , "temporal" , "spatiotemporal" , "spatiotemporal" ),
387391 "upsample_type" : ("spatiotemporal" , "spatiotemporal" , "temporal" , "spatial" ),
392+ "upsample_residual" : (False , False , False , False ),
393+ "upsample_factor" : (2 , 2 , 1 , 2 ),
394+ "patch_size" : 4 ,
395+ "patch_size_t" : 1 ,
396+ "resnet_norm_eps" : 1e-6 ,
397+ "encoder_causal" : True ,
398+ "decoder_causal" : False ,
399+ "encoder_spatial_padding_mode" : "zeros" ,
388400 "decoder_spatial_padding_mode" : "zeros" ,
389401 }
390402 )
391- vae = LTX2VideoAutoencoderKL .from_config (
392- config .pretrained_model_name_or_path ,
393- subfolder = "vae" ,
394- rngs = rngs ,
395- mesh = mesh ,
396- ** vae_kwargs ,
397- )
403+ vae = LTX2VideoAutoencoderKL (
404+ rngs = rngs ,
405+ ** vae_kwargs ,
406+ )
407+ else :
408+ vae = LTX2VideoAutoencoderKL .from_config (
409+ config .pretrained_model_name_or_path ,
410+ subfolder = "vae" ,
411+ rngs = rngs ,
412+ mesh = mesh ,
413+ ** vae_kwargs ,
414+ )
398415 return vae
399416
400417 p_model_factory = partial (create_model , config = config )
0 commit comments