Skip to content

Commit 83585ed

Browse files
committed
sharding added
1 parent 08b9574 commit 83585ed

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -716,7 +716,7 @@ def __init__(self, rngs: nnx.Rngs, dim: int, dim_out: Optional[int] = None, mult
716716
dtype=dtype,
717717
param_dtype=weights_dtype,
718718
precision=precision,
719-
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "embed")),
719+
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", None)),
720720
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
721721
)
722722
self.act = get_activation(activation_fn)

0 commit comments

Comments (0)