from . import max_logging
from . import max_utils
from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
-from maxdiffusion.common_types import LENGTH, KV_LENGTH, RING_ATTENTION_AXIS_RULES
+from maxdiffusion.common_types import LENGTH, KV_LENGTH, RING_ATTENTION_AXIS_RULES, SEQUENCE_PARALLEL_AXIS_RULES


def string_to_bool(s: str) -> bool:
@@ -179,8 +179,8 @@ def user_init(raw_keys):

  raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
  # Verify qkv is sharded across sequence.
-  if raw_keys["attention"] == "ring":
-    max_logging.log("Using ring attention, adding sequence sharding to q and kv if not already present.")
+  if raw_keys["attention"] == "ring" or raw_keys["attention_sharding_uniform"]:
+    max_logging.log(f"Adding sequence sharding to q and kv if not already present because attention=='{raw_keys['attention']}' or attention_sharding_uniform={raw_keys['attention_sharding_uniform']} is set.")
    logical_axis_rules = list(raw_keys["logical_axis_rules"])
    max_logging.log(f"Initial logical axis rules: {logical_axis_rules}")
    new_rules = []
@@ -190,10 +190,16 @@ def user_init(raw_keys):
      logical_axis_rules.append(q_seq_sharding)
    if kv_seq_sharding not in logical_axis_rules:
      logical_axis_rules.append(kv_seq_sharding)
-    for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES:
-      if ring_attention_axis_rule not in logical_axis_rules:
-        max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}")
-        new_rules.append(ring_attention_axis_rule)
+    if raw_keys["attention"] == "ring":
+      for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES:
+        if ring_attention_axis_rule not in logical_axis_rules:
+          max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}")
+          new_rules.append(ring_attention_axis_rule)
+    else:  # attention == "flash", but sequence-parallel sharding was requested for both self and cross attention
+      for seq_parallel_axis_rule in SEQUENCE_PARALLEL_AXIS_RULES:
+        if seq_parallel_axis_rule not in logical_axis_rules:
+          max_logging.log(f"Adding sequence parallel attention axis rule {seq_parallel_axis_rule}")
+          new_rules.append(seq_parallel_axis_rule)
    raw_keys["logical_axis_rules"] = tuple(new_rules) + tuple(logical_axis_rules)
    max_logging.log(f"Final logical axis rules: {raw_keys['logical_axis_rules']}")

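To make the rule-merging behavior above easier to follow, here is a minimal standalone sketch of the same control flow. It is not the maxdiffusion implementation: the axis names, the rule tuples, and the `q_seq_sharding`/`kv_seq_sharding` values below are placeholders (the real ones come from `maxdiffusion.common_types` and the surrounding `user_init` code); only the branching and the prepend-then-concatenate merge mirror the diff.

```python
# Standalone sketch of the sequence-sharding rule merge. All concrete values
# below are placeholders, not the real maxdiffusion constants.
LENGTH = "activation_length"        # placeholder name for the q sequence axis
KV_LENGTH = "activation_kv_length"  # placeholder name for the kv sequence axis
RING_ATTENTION_AXIS_RULES = ((LENGTH, "fsdp"), (KV_LENGTH, "fsdp"))     # placeholder
SEQUENCE_PARALLEL_AXIS_RULES = ((LENGTH, "data"), (KV_LENGTH, "data"))  # placeholder


def merge_sequence_sharding_rules(raw_keys):
  """Mirrors the diff: inject any missing sequence-sharding axis rules."""
  # Only ring attention or an explicit attention_sharding_uniform request
  # triggers the merge, as in the config code above.
  if raw_keys["attention"] != "ring" and not raw_keys["attention_sharding_uniform"]:
    return raw_keys["logical_axis_rules"]

  logical_axis_rules = list(raw_keys["logical_axis_rules"])
  new_rules = []
  q_seq_sharding = (LENGTH, "fsdp")      # placeholder sharding rule for q
  kv_seq_sharding = (KV_LENGTH, "fsdp")  # placeholder sharding rule for kv
  if q_seq_sharding not in logical_axis_rules:
    logical_axis_rules.append(q_seq_sharding)
  if kv_seq_sharding not in logical_axis_rules:
    logical_axis_rules.append(kv_seq_sharding)

  # Ring attention gets its own rule set; any other attention kind reaching
  # this point (e.g. flash with attention_sharding_uniform=True) falls back
  # to the sequence-parallel rules.
  extra_rules = RING_ATTENTION_AXIS_RULES if raw_keys["attention"] == "ring" else SEQUENCE_PARALLEL_AXIS_RULES
  for rule in extra_rules:
    if rule not in logical_axis_rules:
      new_rules.append(rule)

  # Injected rules are placed ahead of the existing ones, matching the diff's
  # tuple(new_rules) + tuple(logical_axis_rules).
  return tuple(new_rules) + tuple(logical_axis_rules)


if __name__ == "__main__":
  cfg = {"attention": "flash", "attention_sharding_uniform": True, "logical_axis_rules": ()}
  print(merge_sequence_sharding_rules(cfg))
```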