Skip to content

Commit ff823af

Browse files
committed
upsample_type added
1 parent d8c64f2 commit ff823af

1 file changed

Lines changed: 16 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -668,6 +668,7 @@ def __init__(
668 668
timestep_conditioning: bool = False,
669 669
upsample_residual: bool = False,
670 670
upscale_factor: int = 1,
671+
upsample_type: str = "spatiotemporal",
671 672
spatial_padding_mode: str = "constant",
672 673
rngs: Optional[nnx.Rngs] = None,
673 674
mesh: Optional[jax.sharding.Mesh] = None,
@@ -711,9 +712,18 @@ def __init__(
711 712
)
712 713

713 714
if spatio_temporal_scale:
715+
if upsample_type == "spatiotemporal":
716+
stride = (2, 2, 2)
717+
elif upsample_type == "temporal":
718+
stride = (2, 1, 1)
719+
elif upsample_type == "spatial":
720+
stride = (1, 2, 2)
721+
else:
722+
raise ValueError(f"Unknown upsample_type: {upsample_type}")
723+
714 724
self.upsampler = LTXVideoUpsampler3d(
715 725
in_channels=out_channels * upscale_factor,
716-
stride=(2, 2, 2),
726+
stride=stride,
717 727
residual=upsample_residual,
718 728
upscale_factor=upscale_factor,
719 729
spatial_padding_mode=spatial_padding_mode,
@@ -954,6 +964,7 @@ def __init__(
954 964
timestep_conditioning: bool = False,
955 965
upsample_residual: Tuple[bool, ...] = (True, True, True),
956 966
upsample_factor: Tuple[int, ...] = (2, 2, 2),
967+
upsample_type: Tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
957 968
spatial_padding_mode: str = "reflect",
958 969
rngs: Optional[nnx.Rngs] = None,
959 970
mesh: Optional[jax.sharding.Mesh] = None,
@@ -972,6 +983,7 @@ def __init__(
972 983
inject_noise = tuple(reversed(inject_noise))
973 984
upsample_residual = tuple(reversed(upsample_residual))
974 985
upsample_factor = tuple(reversed(upsample_factor))
986+
upsample_type = tuple(reversed(upsample_type))
975 987
output_channel = block_out_channels[0]
976 988

977 989
self.conv_in = LTX2VideoCausalConv3d(
@@ -1020,6 +1032,7 @@ def __init__(
1020 1032
timestep_conditioning=timestep_conditioning,
1021 1033
upsample_residual=upsample_residual[i],
1022 1034
upscale_factor=upsample_factor[i],
1035+
upsample_type=upsample_type[i],
1023 1036
spatial_padding_mode=spatial_padding_mode,
1024 1037
rngs=rngs,
1025 1038
mesh=mesh,
@@ -1139,6 +1152,7 @@ def __init__(
1139 1152
downsample_type: Tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
1140 1153
upsample_residual: Tuple[bool, ...] = (True, True, True),
1141 1154
upsample_factor: Tuple[int, ...] = (2, 2, 2),
1155+
upsample_type: Tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
1142 1156
timestep_conditioning: bool = False,
1143 1157
patch_size: int = 4,
1144 1158
patch_size_t: int = 1,
@@ -1184,6 +1198,7 @@ def __init__(
1184 1198
spatio_temporal_scaling=decoder_spatio_temporal_scaling,
1185 1199
upsample_factor=upsample_factor,
1186 1200
upsample_residual=upsample_residual,
1201+
upsample_type=upsample_type,
1187 1202
patch_size=patch_size,
1188 1203
patch_size_t=patch_size_t,
1189 1204
resnet_norm_eps=resnet_norm_eps,

0 commit comments

Comments
 (0)