@@ -668,6 +668,7 @@ def __init__(
668668 timestep_conditioning : bool = False ,
669669 upsample_residual : bool = False ,
670670 upscale_factor : int = 1 ,
671+ upsample_type : str = "spatiotemporal" ,
671672 spatial_padding_mode : str = "constant" ,
672673 rngs : Optional [nnx .Rngs ] = None ,
673674 mesh : Optional [jax .sharding .Mesh ] = None ,
@@ -711,9 +712,18 @@ def __init__(
711712 )
712713
713714 if spatio_temporal_scale :
715+ if upsample_type == "spatiotemporal" :
716+ stride = (2 , 2 , 2 )
717+ elif upsample_type == "temporal" :
718+ stride = (2 , 1 , 1 )
719+ elif upsample_type == "spatial" :
720+ stride = (1 , 2 , 2 )
721+ else :
722+ raise ValueError (f"Unknown upsample_type: { upsample_type } " )
723+
714724 self .upsampler = LTXVideoUpsampler3d (
715725 in_channels = out_channels * upscale_factor ,
716- stride = ( 2 , 2 , 2 ) ,
726+ stride = stride ,
717727 residual = upsample_residual ,
718728 upscale_factor = upscale_factor ,
719729 spatial_padding_mode = spatial_padding_mode ,
@@ -954,6 +964,7 @@ def __init__(
954964 timestep_conditioning : bool = False ,
955965 upsample_residual : Tuple [bool , ...] = (True , True , True ),
956966 upsample_factor : Tuple [int , ...] = (2 , 2 , 2 ),
967+ upsample_type : Tuple [str , ...] = ("spatiotemporal" , "spatiotemporal" , "spatiotemporal" ),
957968 spatial_padding_mode : str = "reflect" ,
958969 rngs : Optional [nnx .Rngs ] = None ,
959970 mesh : Optional [jax .sharding .Mesh ] = None ,
@@ -972,6 +983,7 @@ def __init__(
972983 inject_noise = tuple (reversed (inject_noise ))
973984 upsample_residual = tuple (reversed (upsample_residual ))
974985 upsample_factor = tuple (reversed (upsample_factor ))
986+ upsample_type = tuple (reversed (upsample_type ))
975987 output_channel = block_out_channels [0 ]
976988
977989 self .conv_in = LTX2VideoCausalConv3d (
@@ -1020,6 +1032,7 @@ def __init__(
10201032 timestep_conditioning = timestep_conditioning ,
10211033 upsample_residual = upsample_residual [i ],
10221034 upscale_factor = upsample_factor [i ],
1035+ upsample_type = upsample_type [i ],
10231036 spatial_padding_mode = spatial_padding_mode ,
10241037 rngs = rngs ,
10251038 mesh = mesh ,
@@ -1139,6 +1152,7 @@ def __init__(
11391152 downsample_type : Tuple [str , ...] = ("spatial" , "temporal" , "spatiotemporal" , "spatiotemporal" ),
11401153 upsample_residual : Tuple [bool , ...] = (True , True , True ),
11411154 upsample_factor : Tuple [int , ...] = (2 , 2 , 2 ),
1155+ upsample_type : Tuple [str , ...] = ("spatiotemporal" , "spatiotemporal" , "spatiotemporal" ),
11421156 timestep_conditioning : bool = False ,
11431157 patch_size : int = 4 ,
11441158 patch_size_t : int = 1 ,
@@ -1184,6 +1198,7 @@ def __init__(
11841198 spatio_temporal_scaling = decoder_spatio_temporal_scaling ,
11851199 upsample_factor = upsample_factor ,
11861200 upsample_residual = upsample_residual ,
1201+ upsample_type = upsample_type ,
11871202 patch_size = patch_size ,
11881203 patch_size_t = patch_size_t ,
11891204 resnet_norm_eps = resnet_norm_eps ,
0 commit comments