Skip to content

Commit 06514c3

Browse files
committed
pyconfig.py
1 parent 6a1c6a6 commit 06514c3

1 file changed

Lines changed: 1 addition & 0 deletions

File tree

src/maxdiffusion/pyconfig.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def user_init(raw_keys):
197197
raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
198198
logical_axis_rules = list(raw_keys["logical_axis_rules"])
199199
logical_axis_rules.append(('bias', 'tensor'))
200+
logical_axis_rules.append(('attn2', 'add_k_proj', 'bias', 'tensor'))
200201
raw_keys["logical_axis_rules"] = tuple(logical_axis_rules)
201202
# Verify qkv is sharded across sequence.
202203
if raw_keys["attention"] == "ring":

0 commit comments

Comments
 (0)