1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import math
15- from typing import Optional
15+ from typing import Optional , Any
1616import flax .linen as nn
1717from flax import nnx
1818import jax .numpy as jnp
@@ -60,6 +60,9 @@ def get_sinusoidal_embeddings(
6060 return signal
6161
6262
63+
64+
65+
6366class NNXTimestepEmbedding (nnx .Module ):
6467 r"""
6568 Time step Embedding Module. Learns embeddings for input time steps.
@@ -84,7 +87,12 @@ def __init__(
8487 dtype : jnp .dtype = jnp .float32 ,
8588 weights_dtype : jnp .dtype = jnp .float32 ,
8689 precision : jax .lax .Precision = None ,
90+ sharding_specs : Optional [Any ] = None ,
8791 ):
92+ linear_1_kernel = getattr (sharding_specs , "emb_linear_1_kernel" , ("embed" , "mlp" ))
93+ linear_1_bias = getattr (sharding_specs , "emb_linear_1_bias" , ("mlp" ,))
94+ linear_2_kernel = getattr (sharding_specs , "emb_linear_2_kernel" , ("mlp" , "embed" ))
95+ linear_2_bias = getattr (sharding_specs , "emb_linear_2_bias" , ("embed" ,))
8896 self .linear_1 = nnx .Linear (
8997 rngs = rngs ,
9098 in_features = in_channels ,
@@ -95,12 +103,9 @@ def __init__(
95103 precision = precision ,
96104 kernel_init = nnx .with_partitioning (
97105 nnx .initializers .xavier_uniform (),
98- (
99- "embed" ,
100- "mlp" ,
101- ),
106+ linear_1_kernel ,
102107 ),
103- bias_init = nnx .with_partitioning (nnx .initializers .zeros , ( "mlp" ,) ),
108+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , linear_1_bias ),
104109 )
105110
106111 if cond_proj_dim is not None :
@@ -127,12 +132,9 @@ def __init__(
127132 precision = precision ,
128133 kernel_init = nnx .with_partitioning (
129134 nnx .initializers .xavier_uniform (),
130- (
131- "mlp" ,
132- "embed" ,
133- ),
135+ linear_2_kernel ,
134136 ),
135- bias_init = nnx .with_partitioning (nnx .initializers .zeros , ( "embed" ,) ),
137+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , linear_2_bias ),
136138 )
137139
138140 if post_act_fn is None :
@@ -336,7 +338,12 @@ def __init__(
336338 dtype : jnp .dtype = jnp .float32 ,
337339 weights_dtype : jnp .dtype = jnp .float32 ,
338340 precision : jax .lax .Precision = None ,
341+ sharding_specs : Optional [Any ] = None ,
339342 ):
343+ linear_1_kernel = getattr (sharding_specs , "emb_linear_1_kernel" , ("embed" , "mlp" ))
344+ linear_1_bias = getattr (sharding_specs , "emb_linear_1_bias" , ("mlp" ,))
345+ linear_2_kernel = getattr (sharding_specs , "emb_linear_2_kernel" , ("mlp" , "embed" ))
346+ linear_2_bias = getattr (sharding_specs , "emb_linear_2_bias" , ("embed" ,))
340347 if out_features is None :
341348 out_features = hidden_size
342349
@@ -350,12 +357,9 @@ def __init__(
350357 precision = precision ,
351358 kernel_init = nnx .with_partitioning (
352359 nnx .initializers .xavier_uniform (),
353- (
354- "embed" ,
355- "mlp" ,
356- ),
360+ linear_1_kernel ,
357361 ),
358- bias_init = nnx .with_partitioning (nnx .initializers .zeros , ( "mlp" ,) ),
362+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , linear_1_bias ),
359363 )
360364 self .act_1 = get_activation (act_fn )
361365
@@ -369,12 +373,9 @@ def __init__(
369373 precision = precision ,
370374 kernel_init = nnx .with_partitioning (
371375 nnx .initializers .xavier_uniform (),
372- (
373- "mlp" ,
374- "embed" ,
375- ),
376+ linear_2_kernel ,
376377 ),
377- bias_init = nnx .with_partitioning (nnx .initializers .zeros , ( "embed" ,) ),
378+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , linear_2_bias ),
378379 )
379380
380381 def __call__ (self , caption ):
@@ -530,22 +531,38 @@ def __init__(
530531 use_additional_conditions : bool = False ,
531532 dtype : jnp .dtype = jnp .float32 ,
532533 weights_dtype : jnp .dtype = jnp .float32 ,
534+ sharding_specs : Optional [Any ] = None ,
533535 ):
534536 self .outdim = size_emb_dim
535537 self .use_additional_conditions = use_additional_conditions
536538
537539 self .time_proj = NNXTimesteps (num_channels = 256 , flip_sin_to_cos = True , downscale_freq_shift = 0 )
538540 self .timestep_embedder = NNXTimestepEmbedding (
539- rngs = rngs , in_channels = 256 , time_embed_dim = embedding_dim , dtype = dtype , weights_dtype = weights_dtype
541+ rngs = rngs ,
542+ in_channels = 256 ,
543+ time_embed_dim = embedding_dim ,
544+ dtype = dtype ,
545+ weights_dtype = weights_dtype ,
546+ sharding_specs = sharding_specs ,
540547 )
541548
542549 if use_additional_conditions :
543550 self .additional_condition_proj = NNXTimesteps (num_channels = 256 , flip_sin_to_cos = True , downscale_freq_shift = 0 )
544551 self .resolution_embedder = NNXTimestepEmbedding (
545- rngs = rngs , in_channels = 256 , time_embed_dim = size_emb_dim , dtype = dtype , weights_dtype = weights_dtype
552+ rngs = rngs ,
553+ in_channels = 256 ,
554+ time_embed_dim = size_emb_dim ,
555+ dtype = dtype ,
556+ weights_dtype = weights_dtype ,
557+ sharding_specs = sharding_specs ,
546558 )
547559 self .aspect_ratio_embedder = NNXTimestepEmbedding (
548- rngs = rngs , in_channels = 256 , time_embed_dim = size_emb_dim , dtype = dtype , weights_dtype = weights_dtype
560+ rngs = rngs ,
561+ in_channels = 256 ,
562+ time_embed_dim = size_emb_dim ,
563+ dtype = dtype ,
564+ weights_dtype = weights_dtype ,
565+ sharding_specs = sharding_specs ,
549566 )
550567
551568 def __call__ (
0 commit comments