@@ -145,11 +145,6 @@ def __call__(self, hidden_states: jax.Array, causal: bool = True) -> jax.Array:
145145 # 3. Conv
146146 hidden_states = self .conv (hidden_states )
147147
148- # LTX-2 specific output expansion
149- last_channel = hidden_states [..., - 1 :]
150- repeats = 127 # 256 - 129
151- last_channel_repeated = jnp .repeat (last_channel , repeats , axis = - 1 )
152- hidden_states = jnp .concatenate ([hidden_states , last_channel_repeated ], axis = - 1 )
153148
154149 return hidden_states
155150
@@ -296,11 +291,6 @@ def __call__(
296291 inputs = self .conv_shortcut (inputs )
297292
298293 hidden_states = hidden_states + inputs
299- # LTX-2 specific output expansion
300- last_channel = hidden_states [..., - 1 :]
301- repeats = 127 # 256 - 129
302- last_channel_repeated = jnp .repeat (last_channel , repeats , axis = - 1 )
303- hidden_states = jnp .concatenate ([hidden_states , last_channel_repeated ], axis = - 1 )
304294
305295 return hidden_states
306296
@@ -442,11 +432,6 @@ def __call__(self, hidden_states: jax.Array, causal: bool = True) -> jax.Array:
442432 if self .residual :
443433 hidden_states = hidden_states + residual
444434
445- # LTX-2 specific output expansion
446- last_channel = hidden_states [..., - 1 :]
447- repeats = 127 # 256 - 129
448- last_channel_repeated = jnp .repeat (last_channel , repeats , axis = - 1 )
449- hidden_states = jnp .concatenate ([hidden_states , last_channel_repeated ], axis = - 1 )
450435
451436 return hidden_states
452437
@@ -641,11 +626,6 @@ def __call__(
641626 deterministic = deterministic
642627 )
643628
644- # LTX-2 specific output expansion
645- last_channel = hidden_states [..., - 1 :]
646- repeats = 127 # 256 - 129
647- last_channel_repeated = jnp .repeat (last_channel , repeats , axis = - 1 )
648- hidden_states = jnp .concatenate ([hidden_states , last_channel_repeated ], axis = - 1 )
649629
650630 return hidden_states
651631
@@ -770,11 +750,6 @@ def __call__(
770750 deterministic = deterministic
771751 )
772752
773- # LTX-2 specific output expansion
774- last_channel = hidden_states [..., - 1 :]
775- repeats = 127 # 256 - 129
776- last_channel_repeated = jnp .repeat (last_channel , repeats , axis = - 1 )
777- hidden_states = jnp .concatenate ([hidden_states , last_channel_repeated ], axis = - 1 )
778753
779754 return hidden_states
780755
@@ -914,12 +889,13 @@ def __call__(
914889 hidden_states = self .conv_act (hidden_states )
915890 hidden_states = self .conv_out (hidden_states , causal = causal )
916891
917- # LTX-2 specific output expansion
892+ # LTX-2 specific output expansion
918893 last_channel = hidden_states [..., - 1 :]
919894 repeats = 127 # 256 - 129
920895 last_channel_repeated = jnp .repeat (last_channel , repeats , axis = - 1 )
921896 hidden_states = jnp .concatenate ([hidden_states , last_channel_repeated ], axis = - 1 )
922897
898+
923899 return hidden_states
924900
925901
@@ -1095,6 +1071,12 @@ def __call__(
10951071
10961072 hidden_states = self .conv_act (hidden_states )
10971073 hidden_states = self .conv_out (hidden_states , causal = causal )
1074+
1075+ # LTX-2 specific output expansion
1076+ last_channel = hidden_states [..., - 1 :]
1077+ repeats = 127 # 256 - 129
1078+ last_channel_repeated = jnp .repeat (last_channel , repeats , axis = - 1 )
1079+ hidden_states = jnp .concatenate ([hidden_states , last_channel_repeated ], axis = - 1 )
10981080
10991081 # Unpatchify
11001082 B , T , H , W , C = hidden_states .shape
@@ -1108,11 +1090,6 @@ def __call__(
11081090 hidden_states = hidden_states .transpose (0 , 1 , 5 , 2 , 7 , 3 , 6 , 4 )
11091091 hidden_states = hidden_states .reshape (B , T * p_t , H * p , W * p , C_out_final )
11101092
1111- # LTX-2 specific output expansion
1112- last_channel = hidden_states [..., - 1 :]
1113- repeats = 127 # 256 - 129
1114- last_channel_repeated = jnp .repeat (last_channel , repeats , axis = - 1 )
1115- hidden_states = jnp .concatenate ([hidden_states , last_channel_repeated ], axis = - 1 )
11161093
11171094 return hidden_states
11181095
0 commit comments