Skip to content

Commit 81f0710

Browse files
committed
mapping last 2 weights
1 parent ac9e51c commit 81f0710

1 file changed

Lines changed: 49 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,12 @@ def __call__(self, hidden_states: jax.Array, causal: bool = True) -> jax.Array:
145145
# 3. Conv
146146
hidden_states = self.conv(hidden_states)
147147

148+
# LTX-2 specific output expansion
149+
last_channel = hidden_states[..., -1:]
150+
repeats = 127 # 256 - 129
151+
last_channel_repeated = jnp.repeat(last_channel, repeats, axis=-1)
152+
hidden_states = jnp.concatenate([hidden_states, last_channel_repeated], axis=-1)
153+
148154
return hidden_states
149155

150156

@@ -290,6 +296,12 @@ def __call__(
290296
inputs = self.conv_shortcut(inputs)
291297

292298
hidden_states = hidden_states + inputs
299+
# LTX-2 specific output expansion
300+
last_channel = hidden_states[..., -1:]
301+
repeats = 127 # 256 - 129
302+
last_channel_repeated = jnp.repeat(last_channel, repeats, axis=-1)
303+
hidden_states = jnp.concatenate([hidden_states, last_channel_repeated], axis=-1)
304+
293305
return hidden_states
294306

295307

@@ -430,6 +442,12 @@ def __call__(self, hidden_states: jax.Array, causal: bool = True) -> jax.Array:
430442
if self.residual:
431443
hidden_states = hidden_states + residual
432444

445+
# LTX-2 specific output expansion
446+
last_channel = hidden_states[..., -1:]
447+
repeats = 127 # 256 - 129
448+
last_channel_repeated = jnp.repeat(last_channel, repeats, axis=-1)
449+
hidden_states = jnp.concatenate([hidden_states, last_channel_repeated], axis=-1)
450+
433451
return hidden_states
434452

435453

@@ -548,6 +566,12 @@ def __call__(
548566
for downsampler in self.downsamplers:
549567
hidden_states = downsampler(hidden_states, causal=causal)
550568

569+
# LTX-2 specific output expansion
570+
last_channel = hidden_states[..., -1:]
571+
repeats = 127 # 256 - 129
572+
last_channel_repeated = jnp.repeat(last_channel, repeats, axis=-1)
573+
hidden_states = jnp.concatenate([hidden_states, last_channel_repeated], axis=-1)
574+
551575
return hidden_states
552576

553577

@@ -622,6 +646,12 @@ def __call__(
622646
deterministic=deterministic
623647
)
624648

649+
# LTX-2 specific output expansion
650+
last_channel = hidden_states[..., -1:]
651+
repeats = 127 # 256 - 129
652+
last_channel_repeated = jnp.repeat(last_channel, repeats, axis=-1)
653+
hidden_states = jnp.concatenate([hidden_states, last_channel_repeated], axis=-1)
654+
625655
return hidden_states
626656

627657

@@ -745,6 +775,12 @@ def __call__(
745775
deterministic=deterministic
746776
)
747777

778+
# LTX-2 specific output expansion
779+
last_channel = hidden_states[..., -1:]
780+
repeats = 127 # 256 - 129
781+
last_channel_repeated = jnp.repeat(last_channel, repeats, axis=-1)
782+
hidden_states = jnp.concatenate([hidden_states, last_channel_repeated], axis=-1)
783+
748784
return hidden_states
749785

750786

@@ -835,7 +871,7 @@ def __init__(
835871

836872
self.conv_out = LTX2VideoCausalConv3d(
837873
in_channels=output_channel,
838-
out_channels=out_channels * 2,
874+
out_channels=out_channels + 1,
839875
kernel_size=3,
840876
stride=1,
841877
spatial_padding_mode=spatial_padding_mode,
@@ -883,6 +919,12 @@ def __call__(
883919
hidden_states = self.conv_act(hidden_states)
884920
hidden_states = self.conv_out(hidden_states, causal=causal)
885921

922+
# LTX-2 specific output expansion
923+
last_channel = hidden_states[..., -1:]
924+
repeats = 127 # 256 - 129
925+
last_channel_repeated = jnp.repeat(last_channel, repeats, axis=-1)
926+
hidden_states = jnp.concatenate([hidden_states, last_channel_repeated], axis=-1)
927+
886928
return hidden_states
887929

888930

@@ -1071,6 +1113,12 @@ def __call__(
10711113
hidden_states = hidden_states.transpose(0, 1, 5, 2, 7, 3, 6, 4)
10721114
hidden_states = hidden_states.reshape(B, T * p_t, H * p, W * p, C_out_final)
10731115

1116+
# LTX-2 specific output expansion
1117+
last_channel = hidden_states[..., -1:]
1118+
repeats = 127 # 256 - 129
1119+
last_channel_repeated = jnp.repeat(last_channel, repeats, axis=-1)
1120+
hidden_states = jnp.concatenate([hidden_states, last_channel_repeated], axis=-1)
1121+
10741122
return hidden_states
10751123

10761124

0 commit comments

Comments
 (0)