@@ -227,16 +227,6 @@ def __init__(
227227 self .per_channel_scale2 = None
228228
229229 if timestep_conditioning :
230- self .time_embedder = nnx .data (NNXPixArtAlphaCombinedTimestepSizeEmbeddings (
231- rngs = rngs ,
232- embedding_dim = in_channels * 4 ,
233- size_emb_dim = 0 ,
234- use_additional_conditions = False ,
235- dtype = dtype ,
236- weights_dtype = weights_dtype
237- ))
238- else :
239- self .time_embedder = None
240230 self .scale_shift_table = nnx .Param (
241231 jax .random .normal (rngs .params (), (4 , in_channels )) / (in_channels ** 0.5 )
242232 )
@@ -665,25 +655,14 @@ def __init__(
665655 ):
666656 out_channels = out_channels or in_channels
667657
668- if timestep_conditioning :
669- self .time_embedder = nnx .data (NNXPixArtAlphaCombinedTimestepSizeEmbeddings (
670- rngs = rngs ,
671- embedding_dim = in_channels * 4 ,
672- size_emb_dim = 0 ,
673- use_additional_conditions = False ,
674- dtype = dtype ,
675- weights_dtype = weights_dtype
676- ))
677- else :
678- self .time_embedder = None
679- self .time_embedder = nnx .data (NNXPixArtAlphaCombinedTimestepSizeEmbeddings (
680- rngs = rngs ,
681- embedding_dim = in_channels * 4 ,
682- size_emb_dim = 0 ,
683- use_additional_conditions = False ,
684- dtype = dtype ,
685- weights_dtype = weights_dtype
686- ))
658+ self .time_embedder = nnx .data (NNXPixArtAlphaCombinedTimestepSizeEmbeddings (
659+ rngs = rngs ,
660+ embedding_dim = in_channels * 4 ,
661+ size_emb_dim = 0 ,
662+ use_additional_conditions = False ,
663+ dtype = dtype ,
664+ weights_dtype = weights_dtype
665+ ))
687666
688667 if in_channels != out_channels :
689668 self .conv_in = nnx .data (LTX2VideoResnetBlock3d (
@@ -911,11 +890,6 @@ def __call__(
911890 hidden_states = self .conv_act (hidden_states )
912891 hidden_states = self .conv_out (hidden_states , causal = causal )
913892
914- # LTX-2 specific output expansion
915- last_channel = hidden_states [..., - 1 :]
916- repeats = 127 # 256 - 129
917- last_channel_repeated = jnp .repeat (last_channel , repeats , axis = - 1 )
918- hidden_states = jnp .concatenate ([hidden_states , last_channel_repeated ], axis = - 1 )
919893
920894
921895 return hidden_states
@@ -1033,6 +1007,7 @@ def __init__(
10331007 self .scale_shift_table = None
10341008 self .timestep_scale_multiplier = None
10351009 if timestep_conditioning :
1010+ self .timestep_scale_multiplier = nnx .Param (jnp .array (1000.0 , dtype = jnp .float32 ))
10361011 self .time_embedder = nnx .data (NNXPixArtAlphaCombinedTimestepSizeEmbeddings (
10371012 rngs = rngs ,
10381013 embedding_dim = in_channels * 4 ,
@@ -1041,20 +1016,9 @@ def __init__(
10411016 dtype = dtype ,
10421017 weights_dtype = weights_dtype
10431018 ))
1044- else :
1019+ else :
1020+ self .timestep_scale_multiplier = None
10451021 self .time_embedder = None
1046- self .timestep_scale_multiplier = nnx .Param (jnp .array (1000.0 , dtype = jnp .float32 ))
1047- self .time_embedder = nnx .data (NNXPixArtAlphaCombinedTimestepSizeEmbeddings (
1048- rngs = rngs ,
1049- embedding_dim = output_channel * 2 ,
1050- size_emb_dim = 0 ,
1051- use_additional_conditions = False ,
1052- dtype = dtype ,
1053- weights_dtype = weights_dtype
1054- ))
1055- self .scale_shift_table = nnx .Param (
1056- jax .random .normal (rngs .params (), (2 , output_channel )) / (output_channel ** 0.5 )
1057- )
10581022
10591023 @nnx .jit (static_argnames = ("causal" , "deterministic" ))
10601024 def __call__ (
@@ -1084,31 +1048,9 @@ def __call__(
10841048
10851049 hidden_states = self .norm_out (hidden_states )
10861050
1087- if self .time_embedder is not None :
1088- temb = self .time_embedder (timestep = temb .flatten (), hidden_dtype = hidden_states .dtype )
1089-
1090- B = hidden_states .shape [0 ]
1091- C = hidden_states .shape [- 1 ]
1092-
1093- temb = temb .reshape (B , 2 , C )
1094-
1095- # Add table
1096- params = self .scale_shift_table .value [None , :, :] + temb
1097-
1098- shift = params [:, 0 , :]
1099- scale = params [:, 1 , :]
1100-
1101- # Broadcast
1102- hidden_states = hidden_states * (1 + scale [:, None , None , None , :]) + shift [:, None , None , None , :]
1103-
11041051 hidden_states = self .conv_act (hidden_states )
11051052 hidden_states = self .conv_out (hidden_states , causal = causal )
11061053
1107- # LTX-2 specific output expansion
1108- last_channel = hidden_states [..., - 1 :]
1109- repeats = 127 # 256 - 129
1110- last_channel_repeated = jnp .repeat (last_channel , repeats , axis = - 1 )
1111- hidden_states = jnp .concatenate ([hidden_states , last_channel_repeated ], axis = - 1 )
11121054
11131055 # Unpatchify
11141056 B , T , H , W , C = hidden_states .shape
0 commit comments