Skip to content

Commit 8425c37

Browse files
committed
ff sharding fix
1 parent 6b72f95 commit 8425c37

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -717,8 +717,8 @@ def __init__(
717717
dtype=dtype,
718718
param_dtype=weights_dtype,
719719
precision=precision,
720-
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", None)),
721-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)),
720+
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
721+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
722722
)
723723
self.act = get_activation(activation_fn)
724724
self.net_2 = nnx.Linear(
@@ -729,8 +729,8 @@ def __init__(
729729
dtype=dtype,
730730
param_dtype=weights_dtype,
731731
precision=precision,
732-
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
733-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
732+
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp","embed")),
733+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
734734
)
735735

736736
def __call__(self, hidden_states: Array) -> Array:

0 commit comments

Comments (0)