Skip to content

Commit 3a4463b

Browse files
committed
sharding
1 parent b0dab1a commit 3a4463b

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -685,7 +685,7 @@ def __init__(self, rngs: nnx.Rngs, dim: int, dim_out: Optional[int] = None, mult
685 685    param_dtype=weights_dtype,
686 686    precision=precision,
687 687    kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", None)),
688     -  bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
    688 +  bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)),
689 689    )
690 690    self.act = get_activation(activation_fn)
691 691    self.net_2 = nnx.Linear(

0 commit comments

Comments
 (0)