Skip to content

Commit a8f80b7

Browse files
add config option to allow split physical mesh axis.
1 parent 793574a commit a8f80b7

10 files changed

Lines changed: 20 additions & 2 deletions

src/maxdiffusion/configs/base14.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
135135
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
136136
ici_tensor_parallelism: 1
137137

138+
allow_split_physical_axes: False
139+
138140
# Dataset
139141
# Replace with dataset path or train_data_dir. One has to be set.
140142
dataset_name: 'diffusers/pokemon-gpt4-captions'

src/maxdiffusion/configs/base21.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
136136
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
137137
ici_tensor_parallelism: 1
138138

139+
allow_split_physical_axes: False
140+
139141
# Dataset
140142
# Replace with dataset path or train_data_dir. One has to be set.
141143
dataset_name: 'diffusers/pokemon-gpt4-captions'

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
149149
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
150150
ici_tensor_parallelism: 1
151151

152+
allow_split_physical_axes: False
153+
152154
# Dataset
153155
# Replace with dataset path or train_data_dir. One has to be set.
154156
dataset_name: 'diffusers/pokemon-gpt4-captions'

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ ici_data_parallelism: -1
162162
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
163163
ici_tensor_parallelism: 1
164164

165+
allow_split_physical_axes: False
166+
165167
# Dataset
166168
# Replace with dataset path or train_data_dir. One has to be set.
167169
dataset_name: 'diffusers/pokemon-gpt4-captions'

src/maxdiffusion/configs/base_flux_dev_multi_res.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ ici_data_parallelism: -1
162162
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
163163
ici_tensor_parallelism: 1
164164

165+
allow_split_physical_axes: False
166+
165167
# Dataset
166168
# Replace with dataset path or train_data_dir. One has to be set.
167169
dataset_name: 'diffusers/pokemon-gpt4-captions'

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ ici_data_parallelism: -1
170170
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
171171
ici_tensor_parallelism: 1
172172

173+
allow_split_physical_axes: False
174+
173175
# Dataset
174176
# Replace with dataset path or train_data_dir. One has to be set.
175177
dataset_name: 'diffusers/pokemon-gpt4-captions'

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ ici_data_parallelism: 1
151151
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
152152
ici_tensor_parallelism: 1
153153

154+
allow_split_physical_axes: False
155+
154156
# Dataset
155157
# Replace with dataset path or train_data_dir. One has to be set.
156158
dataset_name: 'diffusers/pokemon-gpt4-captions'

src/maxdiffusion/configs/base_xl.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ ici_data_parallelism: -1
135135
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
136136
ici_tensor_parallelism: 1
137137

138+
allow_split_physical_axes: False
139+
138140
# Dataset
139141
# Replace with dataset path or train_data_dir. One has to be set.
140142
dataset_name: 'diffusers/pokemon-gpt4-captions'

src/maxdiffusion/configs/base_xl_lightning.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ ici_data_parallelism: -1
115115
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
116116
ici_tensor_parallelism: 1
117117

118+
allow_split_physical_axes: False
119+
118120
# Dataset
119121
# Replace with dataset path or train_data_dir. One has to be set.
120122
dataset_name: ''

src/maxdiffusion/max_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,9 @@ def create_device_mesh(config, devices=None, logging=True):
281281
ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
282282
if multi_slice_env:
283283
dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
284-
mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
284+
mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices, allow_split_physical_axes=config.allow_split_physical_axes)
285285
else:
286-
mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)
286+
mesh = mesh_utils.create_device_mesh(ici_parallelism, devices, allow_split_physical_axes=config.allow_split_physical_axes)
287287

288288
if logging:
289289
max_logging.log(f"Decided on mesh: {mesh}")

0 commit comments

Comments
 (0)