@@ -252,86 +252,115 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_typ
252252 return parallelism_vals
253253
254254
def create_device_mesh(config, devices=None, logging=True):
    """Build the JAX device mesh described by ``config``.

    Two paths exist:

    * Special case (commented in the original as being for ltx-video): when
      ``config.ici_fsdp_transpose_parallelism`` is truthy, a single slice is
      assumed and the mesh shape is taken from ``config.ici_parallelism``.
    * General case: the number of slices is derived from the devices'
      ``slice_index`` attribute, and the mesh shape is taken from the
      data / fsdp / tensor parallelism settings. A hybrid ICI+DCN mesh is
      built when more than one slice is present.

    Args:
      config: Configuration object exposing the parallelism attributes read
        in the body (``ici_*_parallelism``, ``dcn_*_parallelism``, ...).
      devices: Devices to arrange into the mesh. Defaults to
        ``jax.devices()`` when ``None``.
      logging: When truthy, log the final mesh on the general path. The
        device listing and the special-case shape are logged regardless,
        as in the original. (The name shadows the stdlib ``logging`` module
        inside this function; it is kept because it is part of the public
        signature.)

    Returns:
      The device mesh returned by ``mesh_utils.create_device_mesh`` or
      ``mesh_utils.create_hybrid_device_mesh``.
    """
    if devices is None:
        devices = jax.devices()
    num_devices = len(devices)

    # Special case for ltx-video: single slice, mesh shape comes straight
    # from config.ici_parallelism.
    if config.ici_fsdp_transpose_parallelism:
        # One slice is assumed here, so every device belongs to it.
        # Copy so that filling in unspecified axes cannot mutate the config.
        ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices, "ICI")
        mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)
        max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
        return mesh

    try:
        num_slices = 1 + max(d.slice_index for d in devices)
    except (AttributeError, TypeError, ValueError):
        # Best-effort detection: devices without a usable ``slice_index``
        # (or an empty device list) are treated as a single slice. Unlike a
        # bare ``except``, this does not swallow KeyboardInterrupt/SystemExit.
        num_slices = 1
    num_devices_per_slice = num_devices // num_slices
    max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")

    multi_slice_env = num_slices > 1

    dcn_parallelism = [
        config.dcn_data_parallelism,
        config.dcn_fsdp_parallelism,
        config.dcn_tensor_parallelism,
    ]
    ici_parallelism = [
        config.ici_data_parallelism,
        config.ici_fsdp_parallelism,
        config.ici_tensor_parallelism,
    ]

    # Find possible unspecified parallelisms.
    ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
    if multi_slice_env:
        dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
        mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
    else:
        mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)

    if logging:
        max_logging.log(f"Decided on mesh: {mesh}")

    return mesh
296325
297326
298- # def create_device_mesh(config, devices=None, logging=True):
299- # """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas"""
300- # if devices is None:
301- # devices = jax.devices()
302- # num_devices = len(devices)
303- # try:
304- # num_slices = 1 + max([d.slice_index for d in devices])
305- # except:
306- # num_slices = 1
307- # num_devices_per_slice = num_devices // num_slices
308- # max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")
309-
310- # multi_slice_env = num_slices > 1
311-
312- # dcn_parallelism = [
313- # config.dcn_data_parallelism,
314- # config.dcn_fsdp_parallelism,
315- # config.dcn_tensor_parallelism,
316- # ]
317- # ici_parallelism = [
318- # config.ici_data_parallelism,
319- # config.ici_fsdp_parallelism,
320- # config.ici_tensor_parallelism,
321- # ]
322-
323- # # Find possible unspecified parallelisms
324- # ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
325- # if multi_slice_env:
326- # dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
327- # mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
328- # else:
329- # mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)
330-
331- # if logging:
332- # max_logging.log(f"Decided on mesh: {mesh}")
333-
334- # return mesh
327+
328+
329+
330+
331+
332+
333+
334+
335+
336+
337+
338+
339+
340+
341+
342+
343+
344+
345+
346+
347+
348+
349+
350+
351+
352+
353+
354+
355+
356+
357+
358+
359+
360+
361+
362+
363+
335364
336365
337366def unbox_logicallypartioned_trainstate (boxed_train_state : train_state .TrainState ):
@@ -445,6 +474,7 @@ def setup_initial_state(
445474 config .enable_single_replica_ckpt_restoring ,
446475 )
447476 if state :
477+ ###!Edited
448478 if checkpoint_item == " " :
449479 state = state
450480 else :
@@ -655,4 +685,4 @@ def maybe_initialize_jax_distributed_system(raw_keys):
655685 initialize_jax_for_gpu ()
656686 max_logging .log ("Jax distributed system initialized on GPU!" )
657687 else :
658- jax .distributed .initialize ()
688+ jax .distributed .initialize ()
0 commit comments