Skip to content

Commit f9cc8f8

Browse files
committed
sharding added
1 parent 705434c commit f9cc8f8

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ def __init__(self, rngs: nnx.Rngs, dim: int, dim_out: Optional[int] = None, mult
684684
dtype=dtype,
685685
param_dtype=weights_dtype,
686686
precision=precision,
687-
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "embed")),
687+
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", None)),
688688
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
689689
)
690690
self.act = get_activation(activation_fn)

0 commit comments

Comments
 (0)