 
 WAN_MODEL = "Wan2.1"
 
-### Common axis rules for ring attention ###
+### Common axis rules for attention sharding ###
 RING_ATTENTION_AXIS_RULES = [
-    [SELF_ATTN_HEAD, None],
-    [SELF_ATTN_Q_LENGTH, CONTEXT],
-    [SELF_ATTN_KV_LENGTH, CONTEXT],
-    [CROSS_ATTN_HEAD, None],
-    [CROSS_ATTN_Q_LENGTH, CONTEXT],
-    [CROSS_ATTN_KV_LENGTH, CONTEXT],
+    (SELF_ATTN_HEAD, None),
+    (SELF_ATTN_Q_LENGTH, CONTEXT),
+    (SELF_ATTN_KV_LENGTH, CONTEXT),
+    (CROSS_ATTN_HEAD, None),
+    (CROSS_ATTN_Q_LENGTH, CONTEXT),
+    (CROSS_ATTN_KV_LENGTH, CONTEXT),
 ]
 
 SEQUENCE_PARALLEL_AXIS_RULES = [
-    [SELF_ATTN_HEAD, None],
-    [SELF_ATTN_Q_LENGTH, CONTEXT],
-    [SELF_ATTN_KV_LENGTH, None],
-    [CROSS_ATTN_HEAD, None],
-    [CROSS_ATTN_Q_LENGTH, CONTEXT],
-    [CROSS_ATTN_KV_LENGTH, None],
+    (SELF_ATTN_HEAD, None),
+    (SELF_ATTN_Q_LENGTH, CONTEXT),
+    (SELF_ATTN_KV_LENGTH, None),
+    (CROSS_ATTN_HEAD, None),
+    (CROSS_ATTN_Q_LENGTH, CONTEXT),
+    (CROSS_ATTN_KV_LENGTH, None),
+]
+
+ULYSSES_ATTENTION_AXIS_RULES = [
+    (SELF_ATTN_HEAD, None),
+    (SELF_ATTN_Q_LENGTH, CONTEXT),
+    (SELF_ATTN_KV_LENGTH, CONTEXT),
+    (CROSS_ATTN_HEAD, None),
+    (CROSS_ATTN_Q_LENGTH, CONTEXT),
+    (CROSS_ATTN_KV_LENGTH, CONTEXT),
+]
+
+ULYSSES_FSDP_ATTENTION_AXIS_RULES = [
+    (SELF_ATTN_HEAD, None),
+    (SELF_ATTN_Q_LENGTH, FSDP),
+    (SELF_ATTN_KV_LENGTH, FSDP),
+    (CROSS_ATTN_HEAD, None),
+    (CROSS_ATTN_Q_LENGTH, FSDP),
+    (CROSS_ATTN_KV_LENGTH, FSDP),
 ]
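Each pair above maps a logical activation axis to a physical mesh axis, with `None` meaning replicated. Below is a minimal sketch of how rules of this shape are typically consumed via Flax's logical-sharding helpers; the concrete axis-name strings (`CONTEXT`, `SELF_ATTN_HEAD`, `SELF_ATTN_Q_LENGTH`) and the use of `nn.logical_to_mesh_sharding` are illustrative assumptions, not taken from this change.

```python
# Sketch only: the axis-name strings and the Flax helper usage below are
# assumptions for illustration; the real constants are defined earlier
# in this file.
import jax
import flax.linen as nn
from jax.sharding import Mesh, PartitionSpec

CONTEXT = "context"                        # assumed mesh axis name
SELF_ATTN_HEAD = "self_attn_heads"         # assumed logical axis names
SELF_ATTN_Q_LENGTH = "self_attn_q_length"

# Same (logical_axis, mesh_axis) shape as the rules in this diff.
rules = [
    (SELF_ATTN_HEAD, None),         # heads replicated on every device
    (SELF_ATTN_Q_LENGTH, CONTEXT),  # query tokens split across "context"
]

# 1-D device mesh whose axis name matches the rules above.
mesh = Mesh(jax.devices(), axis_names=(CONTEXT,))

# Resolve a logical spec (batch, q_length, heads) to a concrete sharding.
sharding = nn.logical_to_mesh_sharding(
    PartitionSpec(None, SELF_ATTN_Q_LENGTH, SELF_ATTN_HEAD),
    mesh,
    rules=rules,
)
print(sharding.spec)  # PartitionSpec(None, 'context', None)
```

The variants differ only in where KV length lands: `SEQUENCE_PARALLEL_AXIS_RULES` maps it to `None` (keys/values replicated while queries are sharded), the ring and Ulysses rules shard it over `CONTEXT`, and the Ulysses+FSDP rules reuse the `FSDP` mesh axis instead.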