Commit 08b9574
Added sharding

Parent: dcef418

1 file changed: src/maxdiffusion/models/attention_flax.py (22 additions, 2 deletions)

@@ -708,9 +708,29 @@ class NNXSimpleFeedForward(nnx.Module):
   def __init__(self, rngs: nnx.Rngs, dim: int, dim_out: Optional[int] = None, mult: int = 4, activation_fn: str = "gelu", dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32, precision: Optional[jax.lax.Precision] = None):
     inner_dim = int(dim * mult)
     dim_out = dim_out if dim_out is not None else dim
-    self.net_0 = nnx.Linear(dim, inner_dim, rngs=rngs, use_bias=True, dtype=dtype, param_dtype=weights_dtype, precision=precision)
+    self.net_0 = nnx.Linear(
+        dim,
+        inner_dim,
+        rngs=rngs,
+        use_bias=True,
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "embed")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+    )
     self.act = get_activation(activation_fn)
-    self.net_2 = nnx.Linear(inner_dim, dim_out, rngs=rngs, use_bias=True, dtype=dtype, param_dtype=weights_dtype, precision=precision)
+    self.net_2 = nnx.Linear(
+        inner_dim,
+        dim_out,
+        rngs=rngs,
+        use_bias=True,
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+    )
 
   def __call__(self, hidden_states: Array) -> Array:
     hidden_states = self.net_0(hidden_states)
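
For readers unfamiliar with nnx.with_partitioning: it wraps an initializer so that the resulting nnx.Param carries logical sharding names (here "embed" and "mlp") as metadata; nothing is physically sharded until those names are resolved against a device mesh. The sketch below is illustrative only and not part of this commit: TinyFeedForward, the 1xN mesh, and the direct use of the logical names as mesh axis names are assumptions for the sake of a self-contained example (MaxDiffusion itself maps logical names onto physical mesh axes through its own sharding rules).

    # Minimal sketch, assuming a 1xN device mesh whose axis names match the
    # logical names used in the commit. TinyFeedForward is a hypothetical
    # stand-in for NNXSimpleFeedForward, following the net_2 annotation pattern.
    import jax
    import numpy as np
    from flax import nnx

    mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(1, -1), ("embed", "mlp"))

    class TinyFeedForward(nnx.Module):
      def __init__(self, rngs: nnx.Rngs):
        self.proj = nnx.Linear(
            16,
            64,
            rngs=rngs,
            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
        )

    @nnx.jit
    def create_sharded_model():
      model = TinyFeedForward(nnx.Rngs(0))
      state = nnx.state(model)
      # Recover PartitionSpecs from the metadata attached by with_partitioning,
      # then apply them so parameters are laid out across the mesh.
      pspecs = nnx.get_partition_spec(state)
      nnx.update(model, jax.lax.with_sharding_constraint(state, pspecs))
      return model

    with mesh:
      model = create_sharded_model()

    print(model.proj.kernel.value.sharding)  # NamedSharding over ("embed", "mlp")

The design keeps sharding declarative: a layer only annotates which logical axis each parameter dimension belongs to, and the training setup decides at initialization or checkpoint-restore time how those axes map onto hardware.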
