Skip to content

Commit 40d3956

Browse files
committed
bias_init added in attention_flax.py
1 parent 06514c3 commit 40d3956

2 files changed

Lines changed: 15 additions & 7 deletions

File tree

src/maxdiffusion/models/attention_flax.py
src/maxdiffusion/pyconfig.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -888,14 +888,26 @@ def __init__(
888888
if self.added_kv_proj_dim is not None:
889889
self.add_k_proj = nnx.Linear(
890890
self.added_kv_proj_dim, self.inner_dim, rngs=rngs,
891-
dtype=dtype, param_dtype=weights_dtype, precision=precision
891+
dtype=dtype, param_dtype=weights_dtype, precision=precision,
892+
bias_init=nnx.with_partitioning(
893+
nnx.initializers.zeros,
894+
("embed",),
895+
),
892896
)
893897
self.add_v_proj = nnx.Linear(
894898
self.added_kv_proj_dim, self.inner_dim, rngs=rngs,
895-
dtype=dtype, param_dtype=weights_dtype, precision=precision
899+
dtype=dtype, param_dtype=weights_dtype, precision=precision,
900+
bias_init=nnx.with_partitioning(
901+
nnx.initializers.zeros,
902+
("embed",),
903+
),
896904
)
897905
self.norm_added_k = nnx.RMSNorm(
898-
num_features=self.inner_dim, rngs=rngs, epsilon=eps, dtype=dtype, param_dtype=weights_dtype
906+
num_features=self.inner_dim, rngs=rngs, epsilon=eps, dtype=dtype, param_dtype=weights_dtype,
907+
scale_init=nnx.with_partitioning(
908+
nnx.initializers.ones,
909+
("norm",),
910+
),
899911
)
900912

901913
def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tuple[jax.Array, jax.Array]:

src/maxdiffusion/pyconfig.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,6 @@ def user_init(raw_keys):
195195
max_utils.write_config_raw_keys_for_gcs(raw_keys)
196196

197197
raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
198-
logical_axis_rules = list(raw_keys["logical_axis_rules"])
199-
logical_axis_rules.append(('bias', 'tensor'))
200-
logical_axis_rules.append(('attn2', 'add_k_proj', 'bias', 'tensor'))
201-
raw_keys["logical_axis_rules"] = tuple(logical_axis_rules)
202198
# Verify qkv is sharded across sequence.
203199
if raw_keys["attention"] == "ring":
204200
logical_axis_rules = list(raw_keys["logical_axis_rules"])

0 commit comments

Comments (0)