Skip to content

Commit 4be2899

Browse files
committed
Fix embed axis and add I2V configs to TP strategy
1 parent 08c0cab commit 4be2899

6 files changed

Lines changed: 20 additions & 20 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ logical_axis_rules: [
177177
['activation_length', None],
178178
['activation_heads', 'tensor'],
179179
['mlp','tensor'],
180-
['embed', ['fsdp', 'tensor']],
180+
['embed', 'fsdp'],
181181
['heads', 'tensor'],
182182
['norm', 'tensor'],
183183
['conv_batch', ['data', 'fsdp']],

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ logical_axis_rules: [
150150
['activation_length', None],
151151
['activation_heads', 'tensor'],
152152
['mlp','tensor'],
153-
['embed', ['fsdp', 'tensor']],
153+
['embed', 'fsdp'],
154154
['heads', 'tensor'],
155155
['norm', 'tensor'],
156156
['conv_batch', ['data', 'fsdp']],

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ logical_axis_rules: [
161161
['activation_length', None],
162162
['activation_heads', 'tensor'],
163163
['mlp','tensor'],
164-
['embed', ['fsdp', 'tensor']],
164+
['embed', 'fsdp'],
165165
['heads', 'tensor'],
166166
['norm', 'tensor'],
167167
['conv_batch', ['data', 'fsdp']],

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -151,17 +151,17 @@ mesh_axes: ['data', 'fsdp', 'context', 'tensor']
151151
logical_axis_rules: [
152152
['batch', ['data', 'fsdp']],
153153
['activation_batch', ['data', 'fsdp']],
154-
['activation_self_attn_heads', ['context', 'tensor']],
155-
['activation_cross_attn_q_length', ['context', 'tensor']],
156-
['activation_length', 'context'],
154+
['activation_self_attn_heads', 'tensor'],
155+
['activation_cross_attn_q_length', 'tensor'],
156+
['activation_length', None],
157157
['activation_heads', 'tensor'],
158158
['mlp','tensor'],
159-
['embed', ['context', 'fsdp']],
159+
['embed', 'fsdp'],
160160
['heads', 'tensor'],
161161
['norm', 'tensor'],
162-
['conv_batch', ['data', 'context', 'fsdp']],
162+
['conv_batch', ['data', 'fsdp']],
163163
['out_channels', 'tensor'],
164-
['conv_out', 'context'],
164+
['conv_out', 'tensor'],
165165
]
166166
data_sharding: [['data', 'fsdp', 'context', 'tensor']]
167167

@@ -174,9 +174,9 @@ dcn_fsdp_parallelism: -1
174174
dcn_context_parallelism: 1
175175
dcn_tensor_parallelism: 1
176176
ici_data_parallelism: 1
177-
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
177+
ici_fsdp_parallelism: 1
178178
ici_context_parallelism: 1
179-
ici_tensor_parallelism: 1
179+
ici_tensor_parallelism: -1 # recommended ICI axis to be auto-sharded
180180

181181
allow_split_physical_axes: False
182182

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -152,17 +152,17 @@ mesh_axes: ['data', 'fsdp', 'context', 'tensor']
152152
logical_axis_rules: [
153153
['batch', ['data', 'fsdp']],
154154
['activation_batch', ['data', 'fsdp']],
155-
['activation_self_attn_heads', ['context', 'tensor']],
156-
['activation_cross_attn_q_length', ['context', 'tensor']],
157-
['activation_length', 'context'],
155+
['activation_self_attn_heads', 'tensor'],
156+
['activation_cross_attn_q_length', 'tensor'],
157+
['activation_length', None],
158158
['activation_heads', 'tensor'],
159159
['mlp','tensor'],
160-
['embed', ['context', 'fsdp']],
160+
['embed', 'fsdp'],
161161
['heads', 'tensor'],
162162
['norm', 'tensor'],
163-
['conv_batch', ['data', 'context', 'fsdp']],
163+
['conv_batch', ['data', 'fsdp']],
164164
['out_channels', 'tensor'],
165-
['conv_out', 'context'],
165+
['conv_out', 'tensor'],
166166
]
167167
data_sharding: [['data', 'fsdp', 'context', 'tensor']]
168168

@@ -175,9 +175,9 @@ dcn_fsdp_parallelism: -1
175175
dcn_context_parallelism: 1
176176
dcn_tensor_parallelism: 1
177177
ici_data_parallelism: 1
178-
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
178+
ici_fsdp_parallelism: 1
179179
ici_context_parallelism: 1
180-
ici_tensor_parallelism: 1
180+
ici_tensor_parallelism: -1 # recommended ICI axis to be auto-sharded
181181

182182
allow_split_physical_axes: False
183183

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ logical_axis_rules: [
4949
['activation_length', None],
5050
['activation_heads', 'tensor'],
5151
['mlp','tensor'],
52-
['embed', ['fsdp', 'tensor']],
52+
['embed', 'fsdp'],
5353
['heads', 'tensor'],
5454
['norm', 'tensor'],
5555
['conv_batch', ['data', 'fsdp']],

0 commit comments

Comments
 (0)