
Commit b0bc3a3

Flag for using the same sequence sharding for self and cross attention
Signed-off-by: Kunjan Patel <kunjanp@google.com>
1 parent 0abc904 commit b0bc3a3

2 files changed: 15 additions & 8 deletions


src/maxdiffusion/configs/base_wan_14b.yml (2 additions & 1 deletion)
@@ -58,7 +58,8 @@ from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
 flash_min_seq_length: 0
-mask_padding_tokens: True
+mask_padding_tokens: True # Whether to mask padding tokens in attention computation.
+attention_sharding_uniform: True # Apply the same sequence sharding rules to q in both self and cross attention.
 dropout: 0.1
 
 flash_block_sizes: {
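
For context, the axis-rule constants consumed below in pyconfig.py are not shown in this diff. A minimal sketch of what they might look like in maxdiffusion/common_types.py; the names are real imports from the diff, but the values here are hypothetical, purely for illustration:

# Hypothetical sketch of the constants imported from maxdiffusion.common_types;
# the names exist in the repo, but these values are illustrative assumptions.
LENGTH = "activation_length"        # assumed logical name of the q sequence axis
KV_LENGTH = "activation_kv_length"  # assumed logical name of the kv sequence axis

# Ring attention shards both the q and kv sequence axes over the mesh (assumed shape).
RING_ATTENTION_AXIS_RULES = [(LENGTH, ("fsdp",)), (KV_LENGTH, ("fsdp",))]

# Uniform sequence-parallel sharding for flash attention: shard the q sequence
# axis the same way in self and cross attention (assumed shape).
SEQUENCE_PARALLEL_AXIS_RULES = [(LENGTH, ("fsdp",))]

With attention: 'flash' and attention_sharding_uniform: True in the config above, pyconfig prepends SEQUENCE_PARALLEL_AXIS_RULES instead of the ring rules, as the second file shows.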

src/maxdiffusion/pyconfig.py (13 additions & 7 deletions)
@@ -27,7 +27,7 @@
 from . import max_logging
 from . import max_utils
 from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
-from maxdiffusion.common_types import LENGTH, KV_LENGTH, RING_ATTENTION_AXIS_RULES
+from maxdiffusion.common_types import LENGTH, KV_LENGTH, RING_ATTENTION_AXIS_RULES, SEQUENCE_PARALLEL_AXIS_RULES
 
 
 def string_to_bool(s: str) -> bool:
@@ -179,8 +179,8 @@ def user_init(raw_keys):
 
   raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
   # Verify qkv is sharded across sequence.
-  if raw_keys["attention"] == "ring":
-    max_logging.log("Using ring attention, adding sequence sharding to q and kv if not already present.")
+  if raw_keys["attention"] == "ring" or raw_keys["attention_sharding_uniform"]:
+    max_logging.log(f"Adding sequence sharding to q and kv if not already present because attention=='{raw_keys['attention']}' or attention_sharding_uniform={raw_keys['attention_sharding_uniform']} is set.")
     logical_axis_rules = list(raw_keys["logical_axis_rules"])
     max_logging.log(f"Initial logical axis rules: {logical_axis_rules}")
     new_rules = []
@@ -190,10 +190,16 @@ def user_init(raw_keys):
       logical_axis_rules.append(q_seq_sharding)
     if kv_seq_sharding not in logical_axis_rules:
       logical_axis_rules.append(kv_seq_sharding)
-    for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES:
-      if ring_attention_axis_rule not in logical_axis_rules:
-        max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}")
-        new_rules.append(ring_attention_axis_rule)
+    if raw_keys["attention"] == "ring":
+      for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES:
+        if ring_attention_axis_rule not in logical_axis_rules:
+          max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}")
+          new_rules.append(ring_attention_axis_rule)
+    else:  # attention == 'flash' but sequence-parallel sharding requested for both self and cross attention
+      for seq_parallel_axis_rule in SEQUENCE_PARALLEL_AXIS_RULES:
+        if seq_parallel_axis_rule not in logical_axis_rules:
+          max_logging.log(f"Adding sequence parallel attention axis rule {seq_parallel_axis_rule}")
+          new_rules.append(seq_parallel_axis_rule)
     raw_keys["logical_axis_rules"] = tuple(new_rules) + tuple(logical_axis_rules)
     max_logging.log(f"Final logical axis rules: {raw_keys['logical_axis_rules']}")
 

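To make the precedence of the added rules concrete, here is a small self-contained sketch of the merging logic in user_init (simplified: the q_seq_sharding/kv_seq_sharding handling is omitted, and the rule values are the hypothetical ones from the sketch above, not the real maxdiffusion constants):

# Standalone, simplified sketch of the rule merging in user_init;
# rule values are illustrative assumptions, not the repo's constants.
RING_ATTENTION_AXIS_RULES = [("activation_length", ("fsdp",)),
                             ("activation_kv_length", ("fsdp",))]
SEQUENCE_PARALLEL_AXIS_RULES = [("activation_length", ("fsdp",))]

def merge_axis_rules(logical_axis_rules, attention, attention_sharding_uniform):
    if attention != "ring" and not attention_sharding_uniform:
        return tuple(logical_axis_rules)
    rules = list(logical_axis_rules)
    source = RING_ATTENTION_AXIS_RULES if attention == "ring" else SEQUENCE_PARALLEL_AXIS_RULES
    # Only rules not already present are added, and they are prepended so they
    # take precedence over user-supplied rules when the first match wins.
    new_rules = [r for r in source if r not in rules]
    return tuple(new_rules) + tuple(rules)

user_rules = [("activation_batch", ("data",))]
print(merge_axis_rules(user_rules, "flash", True))
# (('activation_length', ('fsdp',)), ('activation_batch', ('data',)))

Prepending (rather than appending) the new sequence-sharding rules is what lets the flag override a conflicting rule a user may already have for the same logical axis.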