Skip to content

Commit 0225201

Browse files
committed
sharding
1 parent ed67622 commit 0225201

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -717,7 +717,7 @@ def __init__(self, rngs: nnx.Rngs, dim: int, dim_out: Optional[int] = None, mult
717 717   param_dtype=weights_dtype,
718 718   precision=precision,
719 719   kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", None)),
720     - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
    720 + bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)),
721 721   )
722 722   self.act = get_activation(activation_fn)
723 723   self.net_2 = nnx.Linear(

0 commit comments

Comments
 (0)