@@ -251,46 +251,88 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_typ
251251
252252 return parallelism_vals
253253
254-
255- def create_device_mesh (config , devices = None , logging = True ):
254+ def create_device_mesh (config , devices = None ):
256255 """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas"""
257256 if devices is None :
258257 devices = jax .devices ()
259258 num_devices = len (devices )
260- try :
261- num_slices = 1 + max ([d .slice_index for d in devices ])
262- except :
263- num_slices = 1
259+ num_slices = 1
260+ # if config.inference_benchmark_test else config.num_slices
264261 num_devices_per_slice = num_devices // num_slices
265- max_logging .log (f"Devices: { devices } (num_devices: { num_devices } )" )
266262
267- multi_slice_env = num_slices > 1
268-
269- dcn_parallelism = [
270- config .dcn_data_parallelism ,
271- config .dcn_fsdp_parallelism ,
272- config .dcn_tensor_parallelism ,
273- ]
274- ici_parallelism = [
275- config .ici_data_parallelism ,
276- config .ici_fsdp_parallelism ,
277- config .ici_tensor_parallelism ,
278- ]
263+ # multi_slice_env = num_slices > 1
279264
280265 # Find possible unspecified parallelisms
281- ici_parallelism = fill_unspecified_mesh_axes (ici_parallelism , num_devices_per_slice , "ICI" )
282- if multi_slice_env :
283- dcn_parallelism = fill_unspecified_mesh_axes (dcn_parallelism , num_slices , "DCN" )
284- mesh = mesh_utils .create_hybrid_device_mesh (ici_parallelism , dcn_parallelism , devices )
285- else :
286- mesh = mesh_utils .create_device_mesh (ici_parallelism , devices )
287-
288- if logging :
289- max_logging .log (f"Decided on mesh: { mesh } " )
266+ ici_parallelism = fill_unspecified_mesh_axes (config .ici_parallelism .copy (), num_devices_per_slice , "ICI" )
267+
268+ # allow_split_physical_axes = config.allow_split_physical_axes if config.allow_split_physical_axes else False
269+
270+ # if allow_split_physical_axes:
271+ # if max_utils.is_valid_custom_mesh(ici_parallelism, config.custom_mesh):
272+ # mesh = mesh_utils.create_device_mesh(
273+ # [16, 16],
274+ # devices,
275+ # contiguous_submeshes=False,
276+ # allow_split_physical_axes=False,
277+ # )
278+ # mesh = max_utils.reshape_mesh_to_rings(mesh, config.custom_mesh)
279+ # mesh = np.reshape(mesh, ici_parallelism)
280+ # else:
281+ # mesh = mesh_utils.create_device_mesh(
282+ # ici_parallelism,
283+ # devices,
284+ # contiguous_submeshes=False,
285+ # allow_split_physical_axes=allow_split_physical_axes,
286+ # )
287+ # else:
288+ mesh = mesh_utils .create_device_mesh (
289+ ici_parallelism ,
290+ devices ,
291+ )
292+ max_logging .log (f"Num_devices: { num_devices } , shape { mesh .shape } " )
290293
291294 return mesh
292295
293296
297+ # def create_device_mesh(config, devices=None, logging=True):
298+ # """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas"""
299+ # if devices is None:
300+ # devices = jax.devices()
301+ # num_devices = len(devices)
302+ # try:
303+ # num_slices = 1 + max([d.slice_index for d in devices])
304+ # except:
305+ # num_slices = 1
306+ # num_devices_per_slice = num_devices // num_slices
307+ # max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")
308+
309+ # multi_slice_env = num_slices > 1
310+
311+ # dcn_parallelism = [
312+ # config.dcn_data_parallelism,
313+ # config.dcn_fsdp_parallelism,
314+ # config.dcn_tensor_parallelism,
315+ # ]
316+ # ici_parallelism = [
317+ # config.ici_data_parallelism,
318+ # config.ici_fsdp_parallelism,
319+ # config.ici_tensor_parallelism,
320+ # ]
321+
322+ # # Find possible unspecified parallelisms
323+ # ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
324+ # if multi_slice_env:
325+ # dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
326+ # mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
327+ # else:
328+ # mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)
329+
330+ # if logging:
331+ # max_logging.log(f"Decided on mesh: {mesh}")
332+
333+ # return mesh
334+
335+
294336def unbox_logicallypartioned_trainstate (boxed_train_state : train_state .TrainState ):
295337 """Unboxes the flax.LogicallyPartitioned pieces in a train state.
296338
0 commit comments