@@ -504,6 +504,7 @@ def __call__(self, timestep, guidance, pooled_projection):
504504
505505
506506class NNXTimesteps (nnx .Module ):
507+
507508 def __init__ (self , num_channels : int , flip_sin_to_cos : bool , downscale_freq_shift : float , scale : int = 1 ):
508509 self .num_channels = num_channels
509510 self .flip_sin_to_cos = flip_sin_to_cos
@@ -516,75 +517,64 @@ def __call__(self, timesteps: jax.Array) -> jax.Array:
516517 embedding_dim = self .num_channels ,
517518 freq_shift = self .downscale_freq_shift ,
518519 flip_sin_to_cos = self .flip_sin_to_cos ,
519- scale = self .scale
520+ scale = self .scale ,
520521 )
521522
522523
class NNXPixArtAlphaCombinedTimestepSizeEmbeddings(nnx.Module):
  """Combined timestep + optional size conditioning embedding (PixArt-Alpha style, per the class name).

  Always embeds the diffusion timestep. When ``use_additional_conditions`` is
  True, it additionally embeds image ``resolution`` and ``aspect_ratio``
  signals and adds their concatenation onto the timestep embedding.
  """

  def __init__(
      self,
      rngs: nnx.Rngs,
      embedding_dim: int,
      size_emb_dim: int,
      use_additional_conditions: bool = False,
      dtype: jnp.dtype = jnp.float32,
      weights_dtype: jnp.dtype = jnp.float32,
  ):
    """Build the timestep embedder and, optionally, the size/aspect embedders.

    Args:
      rngs: NNX random-number generators used to initialize the sub-embedders.
      embedding_dim: Output width of the timestep embedding.
      size_emb_dim: Output width of each additional-condition embedding
        (resolution and aspect ratio each produce this width).
      use_additional_conditions: If True, also create the resolution and
        aspect-ratio projection/embedder pairs.
      dtype: Computation dtype passed to the embedders.
      weights_dtype: Parameter dtype passed to the embedders.
    """
    self.outdim = size_emb_dim
    self.use_additional_conditions = use_additional_conditions

    # Sinusoidal projection to 256 channels, then an MLP to embedding_dim.
    self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
    self.timestep_embedder = NNXTimestepEmbedding(
        rngs=rngs, in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, weights_dtype=weights_dtype
    )

    if use_additional_conditions:
      # One shared sinusoidal projection feeds both the resolution and
      # aspect-ratio embedders; each embedder has its own parameters.
      self.additional_condition_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
      self.resolution_embedder = NNXTimestepEmbedding(
          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
      )
      self.aspect_ratio_embedder = NNXTimestepEmbedding(
          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
      )

  def __call__(
      self,
      timestep: jax.Array,
      resolution: Optional[jax.Array] = None,
      aspect_ratio: Optional[jax.Array] = None,
      hidden_dtype: jnp.dtype = jnp.float32,
  ) -> jax.Array:
    """Compute the conditioning vector for a batch of timesteps.

    Args:
      timestep: Timestep values; assumed shape (batch, ...) — the batch size
        is taken from ``timestep.shape[0]``. TODO(review): confirm against callers.
      resolution: Resolution conditioning values; required (non-None) when
        ``use_additional_conditions`` is True, ignored otherwise.
      aspect_ratio: Aspect-ratio conditioning values; same requirement as
        ``resolution``.
      hidden_dtype: dtype the projected embeddings are cast to before the
        embedder MLPs.

    Returns:
      The timestep embedding, plus the concatenated resolution/aspect-ratio
      embeddings when additional conditions are enabled.

    Raises:
      ValueError: if additional conditions are enabled but ``resolution`` or
        ``aspect_ratio`` is None.
    """
    timesteps_proj = self.time_proj(timestep)
    timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))

    if self.use_additional_conditions:
      if resolution is None or aspect_ratio is None:
        raise ValueError("resolution and aspect_ratio must be provided when use_additional_conditions is True")

      resolution_emb = self.additional_condition_proj(resolution.flatten()).astype(hidden_dtype)
      resolution_emb = self.resolution_embedder(resolution_emb)
      # Reshape to (batch_size, -1) matching PyTorch's reshape(batch_size, -1)
      # assuming resolution input was (batch_size, ...) so flatten logic holds.
      resolution_emb = resolution_emb.reshape(timestep.shape[0], -1)

      aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).astype(hidden_dtype)
      aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb)
      aspect_ratio_emb = aspect_ratio_emb.reshape(timestep.shape[0], -1)

      # Size conditions are concatenated along the feature axis and added
      # onto the timestep embedding (so widths must line up: 2 * size_emb_dim
      # must equal embedding_dim — enforced by the caller, not here).
      conditioning = timesteps_emb + jnp.concatenate([resolution_emb, aspect_ratio_emb], axis=1)
    else:
      conditioning = timesteps_emb

    return conditioning
0 commit comments