Commit 705434c

Added sharding
1 parent e5c6324 commit 705434c

1 file changed

src/maxdiffusion/models/attention_flax.py

Lines changed: 22 additions & 2 deletions
@@ -676,9 +676,29 @@ class NNXSimpleFeedForward(nnx.Module):
   def __init__(self, rngs: nnx.Rngs, dim: int, dim_out: Optional[int] = None, mult: int = 4, activation_fn: str = "gelu", dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32, precision: Optional[jax.lax.Precision] = None):
     inner_dim = int(dim * mult)
     dim_out = dim_out if dim_out is not None else dim
-    self.net_0 = nnx.Linear(dim, inner_dim, rngs=rngs, use_bias=True, dtype=dtype, param_dtype=weights_dtype, precision=precision)
+    self.net_0 = nnx.Linear(
+        dim,
+        inner_dim,
+        rngs=rngs,
+        use_bias=True,
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "embed")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+    )
     self.act = get_activation(activation_fn)
-    self.net_2 = nnx.Linear(inner_dim, dim_out, rngs=rngs, use_bias=True, dtype=dtype, param_dtype=weights_dtype, precision=precision)
+    self.net_2 = nnx.Linear(
+        inner_dim,
+        dim_out,
+        rngs=rngs,
+        use_bias=True,
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+    )
 
   def __call__(self, hidden_states: Array) -> Array:
     hidden_states = self.net_0(hidden_states)
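
For context, a minimal sketch of how Flax NNX can consume the sharding metadata this commit attaches: nnx.with_partitioning stores the axis names (e.g. ("embed", "mlp")) on each parameter, and nnx.get_partition_spec turns that metadata into jax PartitionSpecs once a mesh is in scope. The mesh shape, the assumption that these names are physical mesh axis names (rather than logical names remapped by maxdiffusion's own sharding rules), and dim=1024 are illustrative choices, not code from this repository.

import numpy as np
import jax
from flax import nnx

# Import path assumed from the file touched in this diff.
from maxdiffusion.models.attention_flax import NNXSimpleFeedForward

# Hypothetical 1 x N device mesh whose axis names match the annotations above.
mesh = jax.sharding.Mesh(
    np.array(jax.devices()).reshape(1, -1), axis_names=("embed", "mlp")
)

@nnx.jit
def create_sharded_ffn():
  ffn = NNXSimpleFeedForward(rngs=nnx.Rngs(0), dim=1024)
  state = nnx.state(ffn)
  # Read the .sharding metadata written by nnx.with_partitioning, convert it
  # into PartitionSpecs, and constrain the parameters to those shardings.
  pspecs = nnx.get_partition_spec(state)
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(ffn, sharded_state)
  return ffn

with mesh:
  ffn = create_sharded_ffn()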
