We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent e9339cb commit 6a1c6a6Copy full SHA for 6a1c6a6
1 file changed
src/maxdiffusion/pyconfig.py
@@ -195,6 +195,9 @@ def user_init(raw_keys):
195
max_utils.write_config_raw_keys_for_gcs(raw_keys)
196
197
raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
198
+ logical_axis_rules = list(raw_keys["logical_axis_rules"])
199
+ logical_axis_rules.append(('bias', 'tensor'))
200
+ raw_keys["logical_axis_rules"] = tuple(logical_axis_rules)
201
# Verify qkv is sharded across sequence.
202
if raw_keys["attention"] == "ring":
203
logical_axis_rules = list(raw_keys["logical_axis_rules"])
0 commit comments