@@ -227,6 +227,16 @@ def __init__(
227227 self .per_channel_scale2 = None
228228
229229 if timestep_conditioning :
230+ self .time_embedder = nnx .data (NNXPixArtAlphaCombinedTimestepSizeEmbeddings (
231+ rngs = rngs ,
232+ embedding_dim = in_channels * 4 ,
233+ size_emb_dim = 0 ,
234+ use_additional_conditions = False ,
235+ dtype = dtype ,
236+ weights_dtype = weights_dtype
237+ ))
238+ else :
239+ self .time_embedder = None
230240 self .scale_shift_table = nnx .Param (
231241 jax .random .normal (rngs .params (), (4 , in_channels )) / (in_channels ** 0.5 )
232242 )
@@ -573,14 +583,16 @@ def __init__(
573583 precision : jax .lax .Precision = None ,
574584 ):
575585 if timestep_conditioning :
576- self .time_embedder = nnx .data (NNXPixArtAlphaCombinedTimestepSizeEmbeddings (
577- rngs = rngs ,
578- embedding_dim = in_channels * 4 ,
579- size_emb_dim = 0 ,
580- use_additional_conditions = False ,
581- dtype = dtype ,
582- weights_dtype = weights_dtype
583- ))
586+ self .time_embedder = nnx .data (NNXPixArtAlphaCombinedTimestepSizeEmbeddings (
587+ rngs = rngs ,
588+ embedding_dim = in_channels * 4 ,
589+ size_emb_dim = 0 ,
590+ use_additional_conditions = False ,
591+ dtype = dtype ,
592+ weights_dtype = weights_dtype
593+ ))
594+ else :
595+ self .time_embedder = None
584596
585597 self .resnets = nnx .List ([
586598 LTX2VideoResnetBlock3d (
@@ -654,6 +666,16 @@ def __init__(
654666 out_channels = out_channels or in_channels
655667
656668 if timestep_conditioning :
669+ self .time_embedder = nnx .data (NNXPixArtAlphaCombinedTimestepSizeEmbeddings (
670+ rngs = rngs ,
671+ embedding_dim = in_channels * 4 ,
672+ size_emb_dim = 0 ,
673+ use_additional_conditions = False ,
674+ dtype = dtype ,
675+ weights_dtype = weights_dtype
676+ ))
677+ else :
678+ self .time_embedder = None
657679 self .time_embedder = nnx .data (NNXPixArtAlphaCombinedTimestepSizeEmbeddings (
658680 rngs = rngs ,
659681 embedding_dim = in_channels * 4 ,
@@ -1011,6 +1033,16 @@ def __init__(
10111033 self .scale_shift_table = None
10121034 self .timestep_scale_multiplier = None
10131035 if timestep_conditioning :
1036+ self .time_embedder = nnx .data (NNXPixArtAlphaCombinedTimestepSizeEmbeddings (
1037+ rngs = rngs ,
1038+ embedding_dim = in_channels * 4 ,
1039+ size_emb_dim = 0 ,
1040+ use_additional_conditions = False ,
1041+ dtype = dtype ,
1042+ weights_dtype = weights_dtype
1043+ ))
1044+ else :
1045+ self .time_embedder = None
10141046 self .timestep_scale_multiplier = nnx .Param (jnp .array (1000.0 , dtype = jnp .float32 ))
10151047 self .time_embedder = nnx .data (NNXPixArtAlphaCombinedTimestepSizeEmbeddings (
10161048 rngs = rngs ,
0 commit comments