
Commit d6542fb (1 parent: 097f4c3)

Add Wan ulysses_fsdp attention path

12 files changed: 1057 additions & 143 deletions

docs/wan_animate_ulysses_fsdp_walkthrough.md

Lines changed: 515 additions & 0 deletions
Large diffs are not rendered by default.

src/maxdiffusion/common_types.py

Lines changed: 31 additions & 13 deletions

@@ -66,21 +66,39 @@
 
 WAN_MODEL = "Wan2.1"
 
-### Common axis rules for ring attention ###
+### Common axis rules for attention sharding ###
 RING_ATTENTION_AXIS_RULES = [
-    [SELF_ATTN_HEAD, None],
-    [SELF_ATTN_Q_LENGTH, CONTEXT],
-    [SELF_ATTN_KV_LENGTH, CONTEXT],
-    [CROSS_ATTN_HEAD, None],
-    [CROSS_ATTN_Q_LENGTH, CONTEXT],
-    [CROSS_ATTN_KV_LENGTH, CONTEXT],
+    (SELF_ATTN_HEAD, None),
+    (SELF_ATTN_Q_LENGTH, CONTEXT),
+    (SELF_ATTN_KV_LENGTH, CONTEXT),
+    (CROSS_ATTN_HEAD, None),
+    (CROSS_ATTN_Q_LENGTH, CONTEXT),
+    (CROSS_ATTN_KV_LENGTH, CONTEXT),
 ]
 
 SEQUENCE_PARALLEL_AXIS_RULES = [
-    [SELF_ATTN_HEAD, None],
-    [SELF_ATTN_Q_LENGTH, CONTEXT],
-    [SELF_ATTN_KV_LENGTH, None],
-    [CROSS_ATTN_HEAD, None],
-    [CROSS_ATTN_Q_LENGTH, CONTEXT],
-    [CROSS_ATTN_KV_LENGTH, None],
+    (SELF_ATTN_HEAD, None),
+    (SELF_ATTN_Q_LENGTH, CONTEXT),
+    (SELF_ATTN_KV_LENGTH, None),
+    (CROSS_ATTN_HEAD, None),
+    (CROSS_ATTN_Q_LENGTH, CONTEXT),
+    (CROSS_ATTN_KV_LENGTH, None),
+]
+
+ULYSSES_ATTENTION_AXIS_RULES = [
+    (SELF_ATTN_HEAD, None),
+    (SELF_ATTN_Q_LENGTH, CONTEXT),
+    (SELF_ATTN_KV_LENGTH, CONTEXT),
+    (CROSS_ATTN_HEAD, None),
+    (CROSS_ATTN_Q_LENGTH, CONTEXT),
+    (CROSS_ATTN_KV_LENGTH, CONTEXT),
+]
+
+ULYSSES_FSDP_ATTENTION_AXIS_RULES = [
+    (SELF_ATTN_HEAD, None),
+    (SELF_ATTN_Q_LENGTH, FSDP),
+    (SELF_ATTN_KV_LENGTH, FSDP),
+    (CROSS_ATTN_HEAD, None),
+    (CROSS_ATTN_Q_LENGTH, FSDP),
+    (CROSS_ATTN_KV_LENGTH, FSDP),
 ]
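Each pair in these tables maps a logical activation axis to a physical mesh axis (None meaning replicated), so the ulysses table shards attention lengths over the context axis while ulysses_fsdp reuses the fsdp axis instead. Below is a minimal sketch of how such a table resolves to a PartitionSpec via Flax's logical partitioning helper; the axis-name strings are illustrative stand-ins for the repo's SELF_ATTN_* / FSDP constants, not copies of their actual values.

```python
# Sketch only: how a rule table like ULYSSES_FSDP_ATTENTION_AXIS_RULES is
# consumed. Axis-name strings below are assumed, not maxdiffusion's values.
from flax import linen as nn

FSDP = "fsdp"
SELF_ATTN_HEAD = "activation_self_attn_heads"
SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"

rules = [
    (SELF_ATTN_HEAD, None),      # heads replicated: every device keeps all heads
    (SELF_ATTN_Q_LENGTH, FSDP),  # query tokens sharded over the fsdp mesh axis
]

# Resolve a (batch, q_length, heads, head_dim) activation to a PartitionSpec.
spec = nn.logical_to_mesh_axes(
    ("activation_batch", SELF_ATTN_Q_LENGTH, SELF_ATTN_HEAD, None),
    rules=rules + [("activation_batch", "data")],
)
print(spec)  # PartitionSpec('data', 'fsdp', None, None)
```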

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 7 additions & 7 deletions

@@ -60,7 +60,7 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
+attention: 'flash' # Supported attention: dot_product, flash, ulysses, ulysses_fsdp, cudnn_flash_te, ring
 flash_min_seq_length: 0
 
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.

@@ -166,19 +166,19 @@ mesh_axes: ['data', 'fsdp', 'context', 'tensor']
 # conv_in : conv.shape[2] weight
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
-  ['batch', ['data', 'fsdp']],
-  ['activation_batch', ['data', 'fsdp']],
+  ['batch', 'data'],
+  ['activation_batch', 'data'],
   ['activation_self_attn_heads', ['context', 'tensor']],
   ['activation_cross_attn_q_length', ['context', 'tensor']],
   ['activation_length', 'context'],
   ['activation_heads', 'tensor'],
   ['mlp','tensor'],
-  ['embed', ['context', 'fsdp']],
+  ['embed', 'fsdp'],
   ['heads', 'tensor'],
   ['norm', 'tensor'],
-  ['conv_batch', ['data', 'context', 'fsdp']],
+  ['conv_batch', 'data'],
   ['out_channels', 'tensor'],
-  ['conv_out', 'context'],
+  ['conv_out', 'fsdp'],
 ]
 data_sharding: [['data', 'fsdp', 'context', 'tensor']]

@@ -386,4 +386,4 @@ eval_data_dir: ""
 enable_generate_video_for_eval: False # This will increase the used TPU memory.
 eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list).
 
-enable_ssim: False
+enable_ssim: False
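For context on the new 'ulysses' and 'ulysses_fsdp' options in the comment above: Ulysses-style attention keeps activations sequence-sharded outside attention, then uses an all-to-all to trade sequence shards for head shards, so each device computes full-sequence attention for a subset of heads before a second all-to-all restores sequence sharding. A sketch of the idea under stated assumptions (this is not the repo's implementation; it assumes both sequence length and head count divide the mesh size):

```python
# Sketch only: the Ulysses all-to-all pattern behind 'ulysses'/'ulysses_fsdp'.
# Not this repo's code; seq and heads are assumed divisible by the mesh size.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), ("fsdp",))  # 'ulysses_fsdp' shards tokens on fsdp

def ulysses_attention(q, k, v):
  # Local shapes in: (batch, seq/n, heads, head_dim).
  # First all-to-all trades sequence shards for head shards.
  q, k, v = (jax.lax.all_to_all(x, "fsdp", split_axis=2, concat_axis=1, tiled=True)
             for x in (q, k, v))
  # Local shapes now: (batch, seq, heads/n, head_dim): full sequence, few heads.
  scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
  out = jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), v)
  # Second all-to-all restores (batch, seq/n, heads, head_dim).
  return jax.lax.all_to_all(out, "fsdp", split_axis=1, concat_axis=2, tiled=True)

sharded_attn = shard_map(ulysses_attention, mesh=mesh,
                         in_specs=(P(None, "fsdp"),) * 3,
                         out_specs=P(None, "fsdp"))
```

This pattern matches the new rule tables in common_types.py: the head axes stay unsharded (None) while the q/kv length axes shard over CONTEXT ('ulysses') or FSDP ('ulysses_fsdp').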

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 6 additions & 6 deletions

@@ -60,7 +60,7 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
+attention: 'flash' # Supported attention: dot_product, flash, ulysses, ulysses_fsdp, cudnn_flash_te, ring
 flash_min_seq_length: 0
 
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.

@@ -143,19 +143,19 @@ mesh_axes: ['data', 'fsdp', 'context', 'tensor']
 # conv_in : conv.shape[2] weight
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
-  ['batch', ['data', 'fsdp']],
-  ['activation_batch', ['data', 'fsdp']],
+  ['batch', 'data'],
+  ['activation_batch', 'data'],
   ['activation_self_attn_heads', ['context', 'tensor']],
   ['activation_cross_attn_q_length', ['context', 'tensor']],
   ['activation_length', 'context'],
   ['activation_heads', 'tensor'],
   ['mlp','tensor'],
-  ['embed', ['context', 'fsdp']],
+  ['embed', 'fsdp'],
   ['heads', 'tensor'],
   ['norm', 'tensor'],
-  ['conv_batch', ['data', 'context', 'fsdp']],
+  ['conv_batch', 'data'],
   ['out_channels', 'tensor'],
-  ['conv_out', 'context'],
+  ['conv_out', 'fsdp'],
 ]
 data_sharding: [['data', 'fsdp', 'context', 'tensor']]
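The logical_axis_rules edits above move weights as well as activations: a rule like ['embed', 'fsdp'] means any parameter dimension annotated with the logical name 'embed' now shards along the fsdp mesh axis alone. A hedged sketch of how such an annotation is typically attached in Flax (the module and axis names below are illustrative, not from this repo):

```python
# Illustrative only: how a logical-axis annotation on a weight picks up a rule
# like ['embed', 'fsdp']. This module is not from maxdiffusion.
import jax.numpy as jnp
from flax import linen as nn

class Proj(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    kernel = self.param(
        "kernel",
        # Tag kernel dims as ('embed', 'mlp'); the rule table then maps
        # 'embed' -> fsdp and 'mlp' -> tensor when shardings are resolved.
        nn.with_logical_partitioning(nn.initializers.lecun_normal(),
                                     ("embed", "mlp")),
        (x.shape[-1], self.features),
        jnp.float32,
    )
    return x @ kernel
```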

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 7 additions & 7 deletions

@@ -60,7 +60,7 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
+attention: 'flash' # Supported attention: dot_product, flash, ulysses, ulysses_fsdp, cudnn_flash_te, ring
 flash_min_seq_length: 4096
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
 # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.

@@ -154,19 +154,19 @@ mesh_axes: ['data', 'fsdp', 'context', 'tensor']
 # conv_in : conv.shape[2] weight
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
-  ['batch', ['data', 'fsdp']],
-  ['activation_batch', ['data', 'fsdp']],
+  ['batch', 'data'],
+  ['activation_batch', 'data'],
   ['activation_self_attn_heads', ['context', 'tensor']],
   ['activation_cross_attn_q_length', ['context', 'tensor']],
   ['activation_length', 'context'],
   ['activation_heads', 'tensor'],
   ['mlp','tensor'],
-  ['embed', ['context', 'fsdp']],
+  ['embed', 'fsdp'],
   ['heads', 'tensor'],
   ['norm', 'tensor'],
-  ['conv_batch', ['data', 'context', 'fsdp']],
+  ['conv_batch', 'data'],
   ['out_channels', 'tensor'],
-  ['conv_out', 'context'],
+  ['conv_out', 'fsdp'],
 ]
 data_sharding: [['data', 'fsdp', 'context', 'tensor']]

@@ -364,4 +364,4 @@ eval_data_dir: ""
 enable_generate_video_for_eval: False # This will increase the used TPU memory.
 eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list).
 
-enable_ssim: False
+enable_ssim: False

src/maxdiffusion/configs/base_wan_animate_27b.yml

Lines changed: 6 additions & 6 deletions

@@ -62,7 +62,7 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
+attention: 'flash' # Supported attention: dot_product, flash, ulysses, ulysses_fsdp, cudnn_flash_te, ring
 flash_min_seq_length: 4096
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
 # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.

@@ -156,19 +156,19 @@ mesh_axes: ['data', 'fsdp', 'context', 'tensor']
 # conv_in : conv.shape[2] weight
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
-  ['batch', ['data', 'fsdp']],
-  ['activation_batch', ['data', 'fsdp']],
+  ['batch', 'data'],
+  ['activation_batch', 'data'],
   ['activation_self_attn_heads', ['context', 'tensor']],
   ['activation_cross_attn_q_length', ['context', 'tensor']],
   ['activation_length', 'context'],
   ['activation_heads', 'tensor'],
   ['mlp','tensor'],
-  ['embed', ['context', 'fsdp']],
+  ['embed', 'fsdp'],
   ['heads', 'tensor'],
   ['norm', 'tensor'],
-  ['conv_batch', ['data', 'context', 'fsdp']],
+  ['conv_batch', 'data'],
   ['out_channels', 'tensor'],
-  ['conv_out', 'context'],
+  ['conv_out', 'fsdp'],
 ]
 data_sharding: [['data', 'fsdp', 'context', 'tensor']]
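All four configs now advertise the same six attention options, which suggests the selection happens in one place. A speculative sketch of that dispatch follows; the helper name and the empty-list fallback are assumptions, not code from this commit:

```python
# Speculative sketch: mapping the config's `attention` string to the rule
# tables added in common_types.py. Helper name and fallback are assumed.
from maxdiffusion import common_types

_ATTENTION_AXIS_RULES = {
    "ring": common_types.RING_ATTENTION_AXIS_RULES,
    "ulysses": common_types.ULYSSES_ATTENTION_AXIS_RULES,
    "ulysses_fsdp": common_types.ULYSSES_FSDP_ATTENTION_AXIS_RULES,
}

def attention_axis_rules(attention_kind: str):
  """Extra logical-axis rules for the chosen attention path ([] for flash etc.)."""
  return _ATTENTION_AXIS_RULES.get(attention_kind, [])
```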
