Skip to content

Commit 62d994a

Browse files
committed
Enabled Quantization for LTX-2 Transformer
1 parent d651f55 commit 62d994a

13 files changed

Lines changed: 1726 additions & 1546 deletions

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 63 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -461,19 +461,20 @@ def __init__(
461461
@nnx.split_rngs(splits=num_layers)
462462
@nnx.vmap(in_axes=0, out_axes=0, axis_size=num_layers)
463463
def create_resnets(rngs):
464-
return LTX2VideoResnetBlock3d(
465-
in_channels=in_channels,
466-
out_channels=in_channels,
467-
dropout=dropout,
468-
eps=resnet_eps,
469-
non_linearity=resnet_act_fn,
470-
spatial_padding_mode=spatial_padding_mode,
471-
rngs=rngs,
472-
mesh=mesh,
473-
dtype=dtype,
474-
weights_dtype=weights_dtype,
475-
precision=precision,
476-
)
464+
return LTX2VideoResnetBlock3d(
465+
in_channels=in_channels,
466+
out_channels=in_channels,
467+
dropout=dropout,
468+
eps=resnet_eps,
469+
non_linearity=resnet_act_fn,
470+
spatial_padding_mode=spatial_padding_mode,
471+
rngs=rngs,
472+
mesh=mesh,
473+
dtype=dtype,
474+
weights_dtype=weights_dtype,
475+
precision=precision,
476+
)
477+
477478
self.resnets = create_resnets(rngs)
478479

479480
self.downsamplers = nnx.List([])
@@ -544,20 +545,19 @@ def __call__(
544545
causal: bool = True,
545546
deterministic: bool = True,
546547
) -> jax.Array:
547-
548548
subkeys = None
549549
if key is not None:
550-
subkeys = jax.random.split(key, self.num_layers)
551-
550+
subkeys = jax.random.split(key, self.num_layers)
551+
552552
def resnet_scan_fn(hidden_states, args):
553-
resnet, subkey = args
554-
hidden_states = resnet(hidden_states, temb=temb, key=subkey, causal=causal, deterministic=deterministic)
555-
return hidden_states, None
556-
553+
resnet, subkey = args
554+
hidden_states = resnet(hidden_states, temb=temb, key=subkey, causal=causal, deterministic=deterministic)
555+
return hidden_states, None
556+
557557
hidden_states, _ = nnx.scan(
558558
resnet_scan_fn,
559559
length=self.num_layers,
560-
in_axes=(nnx.Carry, 0), # Scan over 0-th dim of input tuple
560+
in_axes=(nnx.Carry, 0), # Scan over 0-th dim of input tuple
561561
out_axes=(nnx.Carry, 0),
562562
)(hidden_states, (self.resnets, subkeys))
563563

@@ -604,21 +604,22 @@ def __init__(
604604
@nnx.split_rngs(splits=num_layers)
605605
@nnx.vmap(in_axes=0, out_axes=0, axis_size=num_layers)
606606
def create_resnets(rngs):
607-
return LTX2VideoResnetBlock3d(
608-
in_channels=in_channels,
609-
out_channels=in_channels,
610-
dropout=dropout,
611-
eps=resnet_eps,
612-
non_linearity=resnet_act_fn,
613-
inject_noise=inject_noise,
614-
timestep_conditioning=timestep_conditioning,
615-
spatial_padding_mode=spatial_padding_mode,
616-
rngs=rngs,
617-
mesh=mesh,
618-
dtype=dtype,
619-
weights_dtype=weights_dtype,
620-
precision=precision,
621-
)
607+
return LTX2VideoResnetBlock3d(
608+
in_channels=in_channels,
609+
out_channels=in_channels,
610+
dropout=dropout,
611+
eps=resnet_eps,
612+
non_linearity=resnet_act_fn,
613+
inject_noise=inject_noise,
614+
timestep_conditioning=timestep_conditioning,
615+
spatial_padding_mode=spatial_padding_mode,
616+
rngs=rngs,
617+
mesh=mesh,
618+
dtype=dtype,
619+
weights_dtype=weights_dtype,
620+
precision=precision,
621+
)
622+
622623
self.resnets = create_resnets(rngs)
623624

624625
def __call__(
@@ -635,12 +636,12 @@ def __call__(
635636

636637
subkeys = None
637638
if key is not None:
638-
subkeys = jax.random.split(key, self.num_layers)
639-
639+
subkeys = jax.random.split(key, self.num_layers)
640+
640641
def resnet_scan_fn(hidden_states, args):
641-
resnet, subkey = args
642-
hidden_states = resnet(hidden_states, temb=temb, key=subkey, causal=causal, deterministic=deterministic)
643-
return hidden_states, None
642+
resnet, subkey = args
643+
hidden_states = resnet(hidden_states, temb=temb, key=subkey, causal=causal, deterministic=deterministic)
644+
return hidden_states, None
644645

645646
hidden_states, _ = nnx.scan(
646647
resnet_scan_fn,
@@ -730,21 +731,22 @@ def __init__(
730731
@nnx.split_rngs(splits=num_layers)
731732
@nnx.vmap(in_axes=0, out_axes=0, axis_size=num_layers)
732733
def create_resnets(rngs):
733-
return LTX2VideoResnetBlock3d(
734-
in_channels=out_channels,
735-
out_channels=out_channels,
736-
dropout=dropout,
737-
eps=resnet_eps,
738-
non_linearity=resnet_act_fn,
739-
inject_noise=inject_noise,
740-
timestep_conditioning=timestep_conditioning,
741-
spatial_padding_mode=spatial_padding_mode,
742-
rngs=rngs,
743-
mesh=mesh,
744-
dtype=dtype,
745-
weights_dtype=weights_dtype,
746-
precision=precision,
747-
)
734+
return LTX2VideoResnetBlock3d(
735+
in_channels=out_channels,
736+
out_channels=out_channels,
737+
dropout=dropout,
738+
eps=resnet_eps,
739+
non_linearity=resnet_act_fn,
740+
inject_noise=inject_noise,
741+
timestep_conditioning=timestep_conditioning,
742+
spatial_padding_mode=spatial_padding_mode,
743+
rngs=rngs,
744+
mesh=mesh,
745+
dtype=dtype,
746+
weights_dtype=weights_dtype,
747+
precision=precision,
748+
)
749+
748750
self.resnets = create_resnets(rngs)
749751

750752
def __call__(
@@ -770,12 +772,12 @@ def __call__(
770772

771773
subkeys = None
772774
if key is not None:
773-
subkeys = jax.random.split(key, self.num_layers)
774-
775+
subkeys = jax.random.split(key, self.num_layers)
776+
775777
def resnet_scan_fn(hidden_states, args):
776-
resnet, subkey = args
777-
hidden_states = resnet(hidden_states, temb=temb, key=subkey, causal=causal, deterministic=deterministic)
778-
return hidden_states, None
778+
resnet, subkey = args
779+
hidden_states = resnet(hidden_states, temb=temb, key=subkey, causal=causal, deterministic=deterministic)
780+
return hidden_states, None
779781

780782
hidden_states, _ = nnx.scan(
781783
resnet_scan_fn,

0 commit comments

Comments (0)