Skip to content

Commit 9550040

Browse files
committed
debug
1 parent 420c444 commit 9550040

1 file changed

Lines changed: 8 additions & 31 deletions

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,6 @@ def __call__(self, hidden_states: jax.Array, causal: bool = True) -> jax.Array:
145145
# 3. Conv
146146
hidden_states = self.conv(hidden_states)
147147

148-
# LTX-2 specific output expansion
149-
last_channel = hidden_states[..., -1:]
150-
repeats = 127 # 256 - 129
151-
last_channel_repeated = jnp.repeat(last_channel, repeats, axis=-1)
152-
hidden_states = jnp.concatenate([hidden_states, last_channel_repeated], axis=-1)
153148

154149
return hidden_states
155150

@@ -296,11 +291,6 @@ def __call__(
296291
inputs = self.conv_shortcut(inputs)
297292

298293
hidden_states = hidden_states + inputs
299-
# LTX-2 specific output expansion
300-
last_channel = hidden_states[..., -1:]
301-
repeats = 127 # 256 - 129
302-
last_channel_repeated = jnp.repeat(last_channel, repeats, axis=-1)
303-
hidden_states = jnp.concatenate([hidden_states, last_channel_repeated], axis=-1)
304294

305295
return hidden_states
306296

@@ -442,11 +432,6 @@ def __call__(self, hidden_states: jax.Array, causal: bool = True) -> jax.Array:
442432
if self.residual:
443433
hidden_states = hidden_states + residual
444434

445-
# LTX-2 specific output expansion
446-
last_channel = hidden_states[..., -1:]
447-
repeats = 127 # 256 - 129
448-
last_channel_repeated = jnp.repeat(last_channel, repeats, axis=-1)
449-
hidden_states = jnp.concatenate([hidden_states, last_channel_repeated], axis=-1)
450435

451436
return hidden_states
452437

@@ -641,11 +626,6 @@ def __call__(
641626
deterministic=deterministic
642627
)
643628

644-
# LTX-2 specific output expansion
645-
last_channel = hidden_states[..., -1:]
646-
repeats = 127 # 256 - 129
647-
last_channel_repeated = jnp.repeat(last_channel, repeats, axis=-1)
648-
hidden_states = jnp.concatenate([hidden_states, last_channel_repeated], axis=-1)
649629

650630
return hidden_states
651631

@@ -770,11 +750,6 @@ def __call__(
770750
deterministic=deterministic
771751
)
772752

773-
# LTX-2 specific output expansion
774-
last_channel = hidden_states[..., -1:]
775-
repeats = 127 # 256 - 129
776-
last_channel_repeated = jnp.repeat(last_channel, repeats, axis=-1)
777-
hidden_states = jnp.concatenate([hidden_states, last_channel_repeated], axis=-1)
778753

779754
return hidden_states
780755

@@ -914,12 +889,13 @@ def __call__(
914889
hidden_states = self.conv_act(hidden_states)
915890
hidden_states = self.conv_out(hidden_states, causal=causal)
916891

917-
# LTX-2 specific output expansion
892+
# LTX-2 specific output expansion
918893
last_channel = hidden_states[..., -1:]
919894
repeats = 127 # 256 - 129
920895
last_channel_repeated = jnp.repeat(last_channel, repeats, axis=-1)
921896
hidden_states = jnp.concatenate([hidden_states, last_channel_repeated], axis=-1)
922897

898+
923899
return hidden_states
924900

925901

@@ -1095,6 +1071,12 @@ def __call__(
10951071

10961072
hidden_states = self.conv_act(hidden_states)
10971073
hidden_states = self.conv_out(hidden_states, causal=causal)
1074+
1075+
# LTX-2 specific output expansion
1076+
last_channel = hidden_states[..., -1:]
1077+
repeats = 127 # 256 - 129
1078+
last_channel_repeated = jnp.repeat(last_channel, repeats, axis=-1)
1079+
hidden_states = jnp.concatenate([hidden_states, last_channel_repeated], axis=-1)
10981080

10991081
# Unpatchify
11001082
B, T, H, W, C = hidden_states.shape
@@ -1108,11 +1090,6 @@ def __call__(
11081090
hidden_states = hidden_states.transpose(0, 1, 5, 2, 7, 3, 6, 4)
11091091
hidden_states = hidden_states.reshape(B, T * p_t, H * p, W * p, C_out_final)
11101092

1111-
# LTX-2 specific output expansion
1112-
last_channel = hidden_states[..., -1:]
1113-
repeats = 127 # 256 - 129
1114-
last_channel_repeated = jnp.repeat(last_channel, repeats, axis=-1)
1115-
hidden_states = jnp.concatenate([hidden_states, last_channel_repeated], axis=-1)
11161093

11171094
return hidden_states
11181095

0 commit comments

Comments
 (0)