 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Optional
+from typing import Optional, Any
 import flax.linen as nn
 from flax import nnx
 import jax.numpy as jnp
@@ -84,7 +84,12 @@ def __init__(
         dtype: jnp.dtype = jnp.float32,
         weights_dtype: jnp.dtype = jnp.float32,
         precision: jax.lax.Precision = None,
+        sharding_specs: Optional[Any] = None,
     ):
+        linear_1_kernel = getattr(sharding_specs, "emb_linear_1_kernel", ("embed", "mlp"))
+        linear_1_bias = getattr(sharding_specs, "emb_linear_1_bias", ("mlp",))
+        linear_2_kernel = getattr(sharding_specs, "emb_linear_2_kernel", ("mlp", "embed"))
+        linear_2_bias = getattr(sharding_specs, "emb_linear_2_bias", ("embed",))
         self.linear_1 = nnx.Linear(
             rngs=rngs,
             in_features=in_channels,
@@ -95,12 +100,9 @@ def __init__(
             precision=precision,
             kernel_init=nnx.with_partitioning(
                 nnx.initializers.xavier_uniform(),
-                (
-                    "embed",
-                    "mlp",
-                ),
+                linear_1_kernel,
             ),
-            bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+            bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_1_bias),
         )

         if cond_proj_dim is not None:
@@ -127,12 +129,9 @@ def __init__(
             precision=precision,
             kernel_init=nnx.with_partitioning(
                 nnx.initializers.xavier_uniform(),
-                (
-                    "mlp",
-                    "embed",
-                ),
+                linear_2_kernel,
             ),
-            bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+            bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_2_bias),
         )

         if post_act_fn is None:
@@ -336,7 +335,12 @@ def __init__(
         dtype: jnp.dtype = jnp.float32,
         weights_dtype: jnp.dtype = jnp.float32,
         precision: jax.lax.Precision = None,
+        sharding_specs: Optional[Any] = None,
     ):
+        linear_1_kernel = getattr(sharding_specs, "emb_linear_1_kernel", ("embed", "mlp"))
+        linear_1_bias = getattr(sharding_specs, "emb_linear_1_bias", ("mlp",))
+        linear_2_kernel = getattr(sharding_specs, "emb_linear_2_kernel", ("mlp", "embed"))
+        linear_2_bias = getattr(sharding_specs, "emb_linear_2_bias", ("embed",))
         if out_features is None:
             out_features = hidden_size

@@ -350,12 +354,9 @@ def __init__(
             precision=precision,
             kernel_init=nnx.with_partitioning(
                 nnx.initializers.xavier_uniform(),
-                (
-                    "embed",
-                    "mlp",
-                ),
+                linear_1_kernel,
             ),
-            bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+            bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_1_bias),
         )
         self.act_1 = get_activation(act_fn)

@@ -369,12 +370,9 @@ def __init__(
             precision=precision,
             kernel_init=nnx.with_partitioning(
                 nnx.initializers.xavier_uniform(),
-                (
-                    "mlp",
-                    "embed",
-                ),
+                linear_2_kernel,
             ),
-            bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+            bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_2_bias),
         )

     def __call__(self, caption):
@@ -530,22 +528,38 @@ def __init__(
         use_additional_conditions: bool = False,
         dtype: jnp.dtype = jnp.float32,
         weights_dtype: jnp.dtype = jnp.float32,
+        sharding_specs: Optional[Any] = None,
     ):
         self.outdim = size_emb_dim
         self.use_additional_conditions = use_additional_conditions

         self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.timestep_embedder = NNXTimestepEmbedding(
-            rngs=rngs, in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, weights_dtype=weights_dtype
+            rngs=rngs,
+            in_channels=256,
+            time_embed_dim=embedding_dim,
+            dtype=dtype,
+            weights_dtype=weights_dtype,
+            sharding_specs=sharding_specs,
         )

         if use_additional_conditions:
             self.additional_condition_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
             self.resolution_embedder = NNXTimestepEmbedding(
-                rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
+                rngs=rngs,
+                in_channels=256,
+                time_embed_dim=size_emb_dim,
+                dtype=dtype,
+                weights_dtype=weights_dtype,
+                sharding_specs=sharding_specs,
             )
             self.aspect_ratio_embedder = NNXTimestepEmbedding(
-                rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
+                rngs=rngs,
+                in_channels=256,
+                time_embed_dim=size_emb_dim,
+                dtype=dtype,
+                weights_dtype=weights_dtype,
+                sharding_specs=sharding_specs,
             )

     def __call__(
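
Note on the new argument: `sharding_specs` is duck-typed. Each partition spec is read with `getattr(sharding_specs, name, default)`, so passing `None` (the default) preserves the previous hard-coded ("embed", "mlp") axes, because `getattr(None, name, default)` returns the default. A minimal caller-side sketch follows; the `EmbeddingShardingSpecs` dataclass and the "data" axis name are illustrative assumptions, not part of this patch, and `NNXTimestepEmbedding` is assumed importable from the patched module.

from dataclasses import dataclass

from flax import nnx

# Hypothetical spec container: the patch only reads the four attributes
# below via getattr(), so any object exposing them would work.
@dataclass
class EmbeddingShardingSpecs:
    emb_linear_1_kernel: tuple = ("data", "mlp")
    emb_linear_1_bias: tuple = ("mlp",)
    emb_linear_2_kernel: tuple = ("mlp", "data")
    emb_linear_2_bias: tuple = ("data",)

# Override the logical partition axes of both Linear layers in the embedder
# (other constructor arguments keep their defaults here).
embedder = NNXTimestepEmbedding(
    rngs=nnx.Rngs(0),
    in_channels=256,
    time_embed_dim=1024,
    sharding_specs=EmbeddingShardingSpecs(),
)

# Omitting sharding_specs keeps the old behavior:
# getattr(None, "emb_linear_1_kernel", ("embed", "mlp")) -> ("embed", "mlp")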