@@ -86,7 +86,7 @@ def setup(self):
     self.linear1 = nn.Dense(
         self.dim * 3 + self.mlp_hidden_dim,
         kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
-        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
         dtype=self.dtype,
         param_dtype=self.weights_dtype,
         precision=self.precision,
@@ -96,7 +96,7 @@ def setup(self):
     self.linear2 = nn.Dense(
         self.dim,
         kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
-        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
         dtype=self.dtype,
         param_dtype=self.weights_dtype,
         precision=self.precision,
@@ -209,7 +209,7 @@ def setup(self):
             int(self.dim * self.mlp_ratio),
             use_bias=True,
             kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
-            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)),
+            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
             dtype=self.dtype,
             param_dtype=self.weights_dtype,
             precision=self.precision,
@@ -218,8 +218,8 @@ def setup(self):
         nn.Dense(
             self.dim,
             use_bias=True,
-            kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
-            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)),
+            kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
+            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
             dtype=self.dtype,
             param_dtype=self.weights_dtype,
             precision=self.precision,
@@ -240,7 +240,7 @@ def setup(self):
             int(self.dim * self.mlp_ratio),
             use_bias=True,
             kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
-            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)),
+            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
             dtype=self.dtype,
             param_dtype=self.weights_dtype,
             precision=self.precision,
@@ -249,8 +249,8 @@ def setup(self):
         nn.Dense(
             self.dim,
             use_bias=True,
-            kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
-            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)),
+            kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
+            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
             dtype=self.dtype,
             param_dtype=self.weights_dtype,
             precision=self.precision,
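
The recurring change in the `setup()` hunks above drops the logical axis name on every 1-D bias, replacing it with `None` so the bias is left replicated rather than sharded (the second dense layer's kernel axes are also corrected to `("mlp", "embed")` for the projection back to the embedding dimension). A minimal sketch of what the `(None,)` annotation does, using Flax's `nn.with_logical_partitioning` on an illustrative `TinyDense` module that is not part of this PR:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyDense(nn.Module):
  """Illustrative module, not from this PR."""
  features: int = 16

  @nn.compact
  def __call__(self, x):
    return nn.Dense(
        self.features,
        # 2-D kernel: rows carry the "embed" logical axis, columns "mlp".
        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
        # 1-D bias: None leaves the axis unnamed, so no sharding rule
        # applies and the bias stays replicated across the mesh.
        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
    )(x)

variables = TinyDense().init(jax.random.PRNGKey(0), jnp.ones((2, 8)))
# kernel -> PartitionSpec('embed', 'mlp'); bias -> PartitionSpec(None)
print(nn.get_partition_spec(variables))
```
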
@@ -483,6 +483,9 @@ def __call__(
   ):
     hidden_states = self.img_in(hidden_states)
     timestep = self.timestep_embedding(timestep, 256)
+
+    timestep = nn.with_logical_constraint(timestep, ("activation_batch", None))
+
     if self.guidance_embeds:
       guidance = self.timestep_embedding(guidance, 256)
     else:
@@ -492,6 +495,9 @@ def __call__(
       if guidance is None
       else self.time_text_embed(timestep, guidance, pooled_projections)
     )
+
+    temb = nn.with_logical_constraint(temb, ("activation_batch", None))
+
     encoder_hidden_states = self.txt_in(encoder_hidden_states)
     if txt_ids.ndim == 3:
       txt_ids = txt_ids[0]
@@ -501,7 +507,7 @@ def __call__(
     ids = jnp.concatenate((txt_ids, img_ids), axis=0)
     ids = nn.with_logical_constraint(ids, ("activation_batch", None))
     image_rotary_emb = self.pe_embedder(ids)
-    image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, ("activation_batch", "activation_embed"))
+    image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, (None, None))
 
     for double_block in self.double_blocks:
       hidden_states, encoder_hidden_states = double_block(
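
The `__call__` hunks constrain the timestep and `temb` activations to the batch axis and stop sharding `image_rotary_emb` entirely via `(None, None)`. How logical names (or `None`) resolve against mesh rules can be checked with `nn.logical_to_mesh_axes`; the rule mapping below is an assumption for illustration, not taken from this repository's config:

```python
import flax.linen as nn

# Hypothetical mapping from logical axis names to mesh axes.
rules = (("activation_batch", "data"), ("activation_embed", "tensor"))

# ("activation_batch", None): dim 0 is sharded over the "data" mesh axis,
# dim 1 is unconstrained -> PartitionSpec('data', None).
print(nn.logical_to_mesh_axes(("activation_batch", None), rules))

# (None, None): nothing is named, so the array is fully replicated,
# matching the new constraint on image_rotary_emb.
print(nn.logical_to_mesh_axes((None, None), rules))
```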