@@ -145,6 +145,12 @@ def __call__(self, hidden_states: jax.Array, causal: bool = True) -> jax.Array:
145145 # 3. Conv
146146 hidden_states = self .conv (hidden_states )
147147
148+ # LTX-2 specific output expansion
149+ last_channel = hidden_states [..., - 1 :]
150+ repeats = 127 # 256 - 129
151+ last_channel_repeated = jnp .repeat (last_channel , repeats , axis = - 1 )
152+ hidden_states = jnp .concatenate ([hidden_states , last_channel_repeated ], axis = - 1 )
153+
148154 return hidden_states
149155
150156
@@ -290,6 +296,12 @@ def __call__(
290296 inputs = self .conv_shortcut (inputs )
291297
292298 hidden_states = hidden_states + inputs
299+ # LTX-2 specific output expansion
300+ last_channel = hidden_states [..., - 1 :]
301+ repeats = 127 # 256 - 129
302+ last_channel_repeated = jnp .repeat (last_channel , repeats , axis = - 1 )
303+ hidden_states = jnp .concatenate ([hidden_states , last_channel_repeated ], axis = - 1 )
304+
293305 return hidden_states
294306
295307
@@ -430,6 +442,12 @@ def __call__(self, hidden_states: jax.Array, causal: bool = True) -> jax.Array:
430442 if self .residual :
431443 hidden_states = hidden_states + residual
432444
445+ # LTX-2 specific output expansion
446+ last_channel = hidden_states [..., - 1 :]
447+ repeats = 127 # 256 - 129
448+ last_channel_repeated = jnp .repeat (last_channel , repeats , axis = - 1 )
449+ hidden_states = jnp .concatenate ([hidden_states , last_channel_repeated ], axis = - 1 )
450+
433451 return hidden_states
434452
435453
@@ -548,6 +566,12 @@ def __call__(
548566 for downsampler in self .downsamplers :
549567 hidden_states = downsampler (hidden_states , causal = causal )
550568
569+ # LTX-2 specific output expansion
570+ last_channel = hidden_states [..., - 1 :]
571+ repeats = 127 # 256 - 129
572+ last_channel_repeated = jnp .repeat (last_channel , repeats , axis = - 1 )
573+ hidden_states = jnp .concatenate ([hidden_states , last_channel_repeated ], axis = - 1 )
574+
551575 return hidden_states
552576
553577
@@ -622,6 +646,12 @@ def __call__(
622646 deterministic = deterministic
623647 )
624648
649+ # LTX-2 specific output expansion
650+ last_channel = hidden_states [..., - 1 :]
651+ repeats = 127 # 256 - 129
652+ last_channel_repeated = jnp .repeat (last_channel , repeats , axis = - 1 )
653+ hidden_states = jnp .concatenate ([hidden_states , last_channel_repeated ], axis = - 1 )
654+
625655 return hidden_states
626656
627657
@@ -745,6 +775,12 @@ def __call__(
745775 deterministic = deterministic
746776 )
747777
778+ # LTX-2 specific output expansion
779+ last_channel = hidden_states [..., - 1 :]
780+ repeats = 127 # 256 - 129
781+ last_channel_repeated = jnp .repeat (last_channel , repeats , axis = - 1 )
782+ hidden_states = jnp .concatenate ([hidden_states , last_channel_repeated ], axis = - 1 )
783+
748784 return hidden_states
749785
750786
@@ -835,7 +871,7 @@ def __init__(
835871
836872 self .conv_out = LTX2VideoCausalConv3d (
837873 in_channels = output_channel ,
838- out_channels = out_channels * 2 ,
874+ out_channels = out_channels + 1 ,
839875 kernel_size = 3 ,
840876 stride = 1 ,
841877 spatial_padding_mode = spatial_padding_mode ,
@@ -883,6 +919,12 @@ def __call__(
883919 hidden_states = self .conv_act (hidden_states )
884920 hidden_states = self .conv_out (hidden_states , causal = causal )
885921
922+ # LTX-2 specific output expansion
923+ last_channel = hidden_states [..., - 1 :]
924+ repeats = 127 # 256 - 129
925+ last_channel_repeated = jnp .repeat (last_channel , repeats , axis = - 1 )
926+ hidden_states = jnp .concatenate ([hidden_states , last_channel_repeated ], axis = - 1 )
927+
886928 return hidden_states
887929
888930
@@ -1071,6 +1113,12 @@ def __call__(
10711113 hidden_states = hidden_states .transpose (0 , 1 , 5 , 2 , 7 , 3 , 6 , 4 )
10721114 hidden_states = hidden_states .reshape (B , T * p_t , H * p , W * p , C_out_final )
10731115
1116+ # LTX-2 specific output expansion
1117+ last_channel = hidden_states [..., - 1 :]
1118+ repeats = 127 # 256 - 129
1119+ last_channel_repeated = jnp .repeat (last_channel , repeats , axis = - 1 )
1120+ hidden_states = jnp .concatenate ([hidden_states , last_channel_repeated ], axis = - 1 )
1121+
10741122 return hidden_states
10751123
10761124
0 commit comments