Skip to content

Commit 991a44e

Browse files
committed
mesh edit
1 parent c369302 commit 991a44e

2 files changed

Lines changed: 101 additions & 70 deletions

File tree

src/maxdiffusion/max_utils.py

Lines changed: 99 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -252,86 +252,115 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_typ
252252
return parallelism_vals
253253

254254

255-
def create_device_mesh(config, devices=None):
255+
def create_device_mesh(config, devices=None, logging=True):
256256
"""Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas"""
257257
if devices is None:
258258
devices = jax.devices()
259259
num_devices = len(devices)
260-
num_slices = 1
261-
# if config.inference_benchmark_test else config.num_slices
260+
##special case for ltx-video
261+
if config.ici_fsdp_transpose_parallelism:
262+
num_slices = 1
263+
# if config.inference_benchmark_test else config.num_slices
264+
num_devices_per_slice = num_devices // num_slices
265+
# Find possible unspecified parallelisms
266+
ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
267+
mesh = mesh_utils.create_device_mesh(
268+
ici_parallelism,
269+
devices,
270+
)
271+
max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
272+
273+
return mesh
274+
275+
try:
276+
num_slices = 1 + max([d.slice_index for d in devices])
277+
except:
278+
num_slices = 1
262279
num_devices_per_slice = num_devices // num_slices
280+
max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")
281+
282+
multi_slice_env = num_slices > 1
263283

264-
# multi_slice_env = num_slices > 1
284+
dcn_parallelism = [
285+
config.dcn_data_parallelism,
286+
config.dcn_fsdp_parallelism,
287+
config.dcn_tensor_parallelism,
288+
]
289+
ici_parallelism = [
290+
config.ici_data_parallelism,
291+
config.ici_fsdp_parallelism,
292+
config.ici_tensor_parallelism,
293+
]
265294

266295
# Find possible unspecified parallelisms
267-
ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
268-
269-
# allow_split_physical_axes = config.allow_split_physical_axes if config.allow_split_physical_axes else False
270-
271-
# if allow_split_physical_axes:
272-
# if max_utils.is_valid_custom_mesh(ici_parallelism, config.custom_mesh):
273-
# mesh = mesh_utils.create_device_mesh(
274-
# [16, 16],
275-
# devices,
276-
# contiguous_submeshes=False,
277-
# allow_split_physical_axes=False,
278-
# )
279-
# mesh = max_utils.reshape_mesh_to_rings(mesh, config.custom_mesh)
280-
# mesh = np.reshape(mesh, ici_parallelism)
281-
# else:
282-
# mesh = mesh_utils.create_device_mesh(
283-
# ici_parallelism,
284-
# devices,
285-
# contiguous_submeshes=False,
286-
# allow_split_physical_axes=allow_split_physical_axes,
287-
# )
288-
# else:
289-
mesh = mesh_utils.create_device_mesh(
290-
ici_parallelism,
291-
devices,
292-
)
293-
max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
296+
ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
297+
if multi_slice_env:
298+
dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
299+
mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
300+
else:
301+
mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)
302+
303+
if logging:
304+
max_logging.log(f"Decided on mesh: {mesh}")
305+
306+
307+
308+
309+
310+
311+
312+
313+
314+
315+
316+
317+
318+
319+
320+
321+
322+
294323

295324
return mesh
296325

297326

298-
# def create_device_mesh(config, devices=None, logging=True):
299-
# """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas"""
300-
# if devices is None:
301-
# devices = jax.devices()
302-
# num_devices = len(devices)
303-
# try:
304-
# num_slices = 1 + max([d.slice_index for d in devices])
305-
# except:
306-
# num_slices = 1
307-
# num_devices_per_slice = num_devices // num_slices
308-
# max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")
309-
310-
# multi_slice_env = num_slices > 1
311-
312-
# dcn_parallelism = [
313-
# config.dcn_data_parallelism,
314-
# config.dcn_fsdp_parallelism,
315-
# config.dcn_tensor_parallelism,
316-
# ]
317-
# ici_parallelism = [
318-
# config.ici_data_parallelism,
319-
# config.ici_fsdp_parallelism,
320-
# config.ici_tensor_parallelism,
321-
# ]
322-
323-
# # Find possible unspecified parallelisms
324-
# ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
325-
# if multi_slice_env:
326-
# dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
327-
# mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
328-
# else:
329-
# mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)
330-
331-
# if logging:
332-
# max_logging.log(f"Decided on mesh: {mesh}")
333-
334-
# return mesh
327+
328+
329+
330+
331+
332+
333+
334+
335+
336+
337+
338+
339+
340+
341+
342+
343+
344+
345+
346+
347+
348+
349+
350+
351+
352+
353+
354+
355+
356+
357+
358+
359+
360+
361+
362+
363+
335364

336365

337366
def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState):
@@ -445,6 +474,7 @@ def setup_initial_state(
445474
config.enable_single_replica_ckpt_restoring,
446475
)
447476
if state:
477+
###!Edited
448478
if checkpoint_item == " ":
449479
state = state
450480
else:
@@ -655,4 +685,4 @@ def maybe_initialize_jax_distributed_system(raw_keys):
655685
initialize_jax_for_gpu()
656686
max_logging.log("Jax distributed system initialized on GPU!")
657687
else:
658-
jax.distributed.initialize()
688+
jax.distributed.initialize()

src/maxdiffusion/pyconfig.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ def user_init(raw_keys):
169169
raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"])
170170
raw_keys["num_slices"] = get_num_slices(raw_keys)
171171
raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)
172-
raw_keys = create_parallelisms_list(raw_keys)
172+
if "ici_fsdp_transpose_parallelism" in raw_keys:
173+
raw_keys = create_parallelisms_list(raw_keys)
173174

174175

175176
def get_num_slices(raw_keys):

0 commit comments

Comments
 (0)