Skip to content

Commit d651f55

Browse files
committed
[VAE] Refactor Video VAE to use JAX idiomatic scan/vmap
Signed-off-by: James Huang <syhuang1201@gmail.com>
1 parent 95fcbd5 commit d651f55

1 file changed

Lines changed: 123 additions & 93 deletions

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 123 additions & 93 deletions
Original file line number | Diff line number | Diff line change
@@ -456,25 +456,25 @@ def __init__(
456456
precision: jax.lax.Precision = None,
457457
):
458458
out_channels = out_channels or in_channels
459-
460-
self.resnets = nnx.List(
461-
[
462-
LTX2VideoResnetBlock3d(
463-
in_channels=in_channels,
464-
out_channels=in_channels,
465-
dropout=dropout,
466-
eps=resnet_eps,
467-
non_linearity=resnet_act_fn,
468-
spatial_padding_mode=spatial_padding_mode,
469-
rngs=rngs,
470-
mesh=mesh,
471-
dtype=dtype,
472-
weights_dtype=weights_dtype,
473-
precision=precision,
474-
)
475-
for _ in range(num_layers)
476-
]
477-
)
459+
self.num_layers = num_layers
460+
461+
@nnx.split_rngs(splits=num_layers)
462+
@nnx.vmap(in_axes=0, out_axes=0, axis_size=num_layers)
463+
def create_resnets(rngs):
464+
return LTX2VideoResnetBlock3d(
465+
in_channels=in_channels,
466+
out_channels=in_channels,
467+
dropout=dropout,
468+
eps=resnet_eps,
469+
non_linearity=resnet_act_fn,
470+
spatial_padding_mode=spatial_padding_mode,
471+
rngs=rngs,
472+
mesh=mesh,
473+
dtype=dtype,
474+
weights_dtype=weights_dtype,
475+
precision=precision,
476+
)
477+
self.resnets = create_resnets(rngs)
478478

479479
self.downsamplers = nnx.List([])
480480
if spatio_temporal_scale:
@@ -544,11 +544,22 @@ def __call__(
544544
causal: bool = True,
545545
deterministic: bool = True,
546546
) -> jax.Array:
547-
for resnet in self.resnets:
548-
subkey = None
549-
if key is not None:
550-
key, subkey = jax.random.split(key)
551-
hidden_states = resnet(hidden_states, temb=temb, key=subkey, causal=causal, deterministic=deterministic)
547+
548+
subkeys = None
549+
if key is not None:
550+
subkeys = jax.random.split(key, self.num_layers)
551+
552+
def resnet_scan_fn(hidden_states, args):
553+
resnet, subkey = args
554+
hidden_states = resnet(hidden_states, temb=temb, key=subkey, causal=causal, deterministic=deterministic)
555+
return hidden_states, None
556+
557+
hidden_states, _ = nnx.scan(
558+
resnet_scan_fn,
559+
length=self.num_layers,
560+
in_axes=(nnx.Carry, 0), # hidden_states is the carry; (resnets, subkeys) is scanned along axis 0
561+
out_axes=(nnx.Carry, 0),
562+
)(hidden_states, (self.resnets, subkeys))
552563

553564
for downsampler in self.downsamplers:
554565
hidden_states = downsampler(hidden_states, causal=causal)
@@ -588,26 +599,27 @@ def __init__(
588599
else:
589600
self.time_embedder = None
590601

591-
self.resnets = nnx.List(
592-
[
593-
LTX2VideoResnetBlock3d(
594-
in_channels=in_channels,
595-
out_channels=in_channels,
596-
dropout=dropout,
597-
eps=resnet_eps,
598-
non_linearity=resnet_act_fn,
599-
inject_noise=inject_noise,
600-
timestep_conditioning=timestep_conditioning,
601-
spatial_padding_mode=spatial_padding_mode,
602-
rngs=rngs,
603-
mesh=mesh,
604-
dtype=dtype,
605-
weights_dtype=weights_dtype,
606-
precision=precision,
607-
)
608-
for _ in range(num_layers)
609-
]
610-
)
602+
self.num_layers = num_layers
603+
604+
@nnx.split_rngs(splits=num_layers)
605+
@nnx.vmap(in_axes=0, out_axes=0, axis_size=num_layers)
606+
def create_resnets(rngs):
607+
return LTX2VideoResnetBlock3d(
608+
in_channels=in_channels,
609+
out_channels=in_channels,
610+
dropout=dropout,
611+
eps=resnet_eps,
612+
non_linearity=resnet_act_fn,
613+
inject_noise=inject_noise,
614+
timestep_conditioning=timestep_conditioning,
615+
spatial_padding_mode=spatial_padding_mode,
616+
rngs=rngs,
617+
mesh=mesh,
618+
dtype=dtype,
619+
weights_dtype=weights_dtype,
620+
precision=precision,
621+
)
622+
self.resnets = create_resnets(rngs)
611623

612624
def __call__(
613625
self,
@@ -621,12 +633,21 @@ def __call__(
621633
temb = self.time_embedder(timestep=temb.flatten(), hidden_dtype=hidden_states.dtype)
622634
temb = temb.reshape(temb.shape[0], 1, 1, 1, -1)
623635

624-
for resnet in self.resnets:
625-
subkey = None
626-
if key is not None:
627-
key, subkey = jax.random.split(key)
628-
629-
hidden_states = resnet(hidden_states, temb=temb, key=subkey, causal=causal, deterministic=deterministic)
636+
subkeys = None
637+
if key is not None:
638+
subkeys = jax.random.split(key, self.num_layers)
639+
640+
def resnet_scan_fn(hidden_states, args):
641+
resnet, subkey = args
642+
hidden_states = resnet(hidden_states, temb=temb, key=subkey, causal=causal, deterministic=deterministic)
643+
return hidden_states, None
644+
645+
hidden_states, _ = nnx.scan(
646+
resnet_scan_fn,
647+
length=self.num_layers,
648+
in_axes=(nnx.Carry, 0),
649+
out_axes=(nnx.Carry, 0),
650+
)(hidden_states, (self.resnets, subkeys))
630651

631652
return hidden_states
632653

@@ -688,43 +709,43 @@ def __init__(
688709
)
689710
)
690711

691-
self.upsamplers = nnx.List([])
692712
if spatio_temporal_scale:
693-
self.upsamplers.append(
694-
LTXVideoUpsampler3d(
695-
in_channels=out_channels * upscale_factor,
696-
stride=(2, 2, 2),
697-
residual=upsample_residual,
698-
upscale_factor=upscale_factor,
699-
spatial_padding_mode=spatial_padding_mode,
700-
rngs=rngs,
701-
mesh=mesh,
702-
dtype=dtype,
703-
weights_dtype=weights_dtype,
704-
precision=precision,
705-
)
713+
self.upsampler = LTXVideoUpsampler3d(
714+
in_channels=out_channels * upscale_factor,
715+
stride=(2, 2, 2),
716+
residual=upsample_residual,
717+
upscale_factor=upscale_factor,
718+
spatial_padding_mode=spatial_padding_mode,
719+
rngs=rngs,
720+
mesh=mesh,
721+
dtype=dtype,
722+
weights_dtype=weights_dtype,
723+
precision=precision,
706724
)
707-
708-
self.resnets = nnx.List(
709-
[
710-
LTX2VideoResnetBlock3d(
711-
in_channels=out_channels,
712-
out_channels=out_channels,
713-
dropout=dropout,
714-
eps=resnet_eps,
715-
non_linearity=resnet_act_fn,
716-
inject_noise=inject_noise,
717-
timestep_conditioning=timestep_conditioning,
718-
spatial_padding_mode=spatial_padding_mode,
719-
rngs=rngs,
720-
mesh=mesh,
721-
dtype=dtype,
722-
weights_dtype=weights_dtype,
723-
precision=precision,
724-
)
725-
for _ in range(num_layers)
726-
]
727-
)
725+
else:
726+
self.upsampler = None
727+
728+
self.num_layers = num_layers
729+
730+
@nnx.split_rngs(splits=num_layers)
731+
@nnx.vmap(in_axes=0, out_axes=0, axis_size=num_layers)
732+
def create_resnets(rngs):
733+
return LTX2VideoResnetBlock3d(
734+
in_channels=out_channels,
735+
out_channels=out_channels,
736+
dropout=dropout,
737+
eps=resnet_eps,
738+
non_linearity=resnet_act_fn,
739+
inject_noise=inject_noise,
740+
timestep_conditioning=timestep_conditioning,
741+
spatial_padding_mode=spatial_padding_mode,
742+
rngs=rngs,
743+
mesh=mesh,
744+
dtype=dtype,
745+
weights_dtype=weights_dtype,
746+
precision=precision,
747+
)
748+
self.resnets = create_resnets(rngs)
728749

729750
def __call__(
730751
self,
@@ -744,15 +765,24 @@ def __call__(
744765
temb = self.time_embedder(timestep=temb.flatten(), hidden_dtype=hidden_states.dtype)
745766
temb = temb.reshape(temb.shape[0], 1, 1, 1, -1)
746767

747-
for upsampler in self.upsamplers:
748-
hidden_states = upsampler(hidden_states, causal=causal)
749-
750-
for resnet in self.resnets:
751-
subkey = None
752-
if key is not None:
753-
key, subkey = jax.random.split(key)
768+
if self.upsampler is not None:
769+
hidden_states = self.upsampler(hidden_states, causal=causal)
754770

755-
hidden_states = resnet(hidden_states, temb=temb, key=subkey, causal=causal, deterministic=deterministic)
771+
subkeys = None
772+
if key is not None:
773+
subkeys = jax.random.split(key, self.num_layers)
774+
775+
def resnet_scan_fn(hidden_states, args):
776+
resnet, subkey = args
777+
hidden_states = resnet(hidden_states, temb=temb, key=subkey, causal=causal, deterministic=deterministic)
778+
return hidden_states, None
779+
780+
hidden_states, _ = nnx.scan(
781+
resnet_scan_fn,
782+
length=self.num_layers,
783+
in_axes=(nnx.Carry, 0),
784+
out_axes=(nnx.Carry, 0),
785+
)(hidden_states, (self.resnets, subkeys))
756786

757787
return hidden_states
758788

0 commit comments

Comments
 (0)