Skip to content

Commit 0a36e73

Browse files
committed
embed -> none
1 parent 8425c37 commit 0a36e73

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,7 @@ def __init__(
717717
dtype=dtype,
718718
param_dtype=weights_dtype,
719719
precision=precision,
720-
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
720+
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "mlp")),
721721
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
722722
)
723723
self.act = get_activation(activation_fn)
@@ -729,8 +729,8 @@ def __init__(
729729
dtype=dtype,
730730
param_dtype=weights_dtype,
731731
precision=precision,
732-
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp","embed")),
733-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
732+
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", None)),
733+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)),
734734
)
735735

736736
def __call__(self, hidden_states: Array) -> Array:

0 commit comments

Comments
 (0)