Skip to content

Commit 6a1c6a6

Browse files
committed
pyconfig.py
1 parent e9339cb commit 6a1c6a6

1 file changed

Lines changed: 3 additions & 0 deletions

File tree

src/maxdiffusion/pyconfig.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ def user_init(raw_keys):
195195
max_utils.write_config_raw_keys_for_gcs(raw_keys)
196196

197197
raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
198+
logical_axis_rules = list(raw_keys["logical_axis_rules"])
199+
logical_axis_rules.append(('bias', 'tensor'))
200+
raw_keys["logical_axis_rules"] = tuple(logical_axis_rules)
198201
# Verify qkv is sharded across sequence.
199202
if raw_keys["attention"] == "ring":
200203
logical_axis_rules = list(raw_keys["logical_axis_rules"])

0 commit comments

Comments
 (0)