Skip to content

Commit fd4af91

Browse files
committed
fixed mesh
1 parent 1c55452 commit fd4af91

1 file changed

Lines changed: 60 additions & 3 deletions

File tree

src/maxdiffusion/max_utils.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def create_device_mesh(config, devices=None, logging=True):
258258
devices = jax.devices()
259259
num_devices = len(devices)
260260
##special case for ltx-video
261-
if config.ici_fsdp_transpose_parallelism:
261+
if "fsdp_transpose" in config.mesh_axes:
262262
num_slices = 1
263263
# if config.inference_benchmark_test else config.num_slices
264264
num_devices_per_slice = num_devices // num_slices
@@ -271,7 +271,7 @@ def create_device_mesh(config, devices=None, logging=True):
271271
max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
272272

273273
return mesh
274-
274+
275275
try:
276276
num_slices = 1 + max([d.slice_index for d in devices])
277277
except:
@@ -303,9 +303,66 @@ def create_device_mesh(config, devices=None, logging=True):
303303
if logging:
304304
max_logging.log(f"Decided on mesh: {mesh}")
305305

306+
307+
308+
309+
310+
311+
312+
313+
314+
315+
316+
317+
318+
319+
320+
321+
322+
323+
306324
return mesh
307325

308326

327+
328+
329+
330+
331+
332+
333+
334+
335+
336+
337+
338+
339+
340+
341+
342+
343+
344+
345+
346+
347+
348+
349+
350+
351+
352+
353+
354+
355+
356+
357+
358+
359+
360+
361+
362+
363+
364+
365+
309366
def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState):
310367
"""Unboxes the flax.LogicallyPartitioned pieces in a train state.
311368
@@ -628,4 +685,4 @@ def maybe_initialize_jax_distributed_system(raw_keys):
628685
initialize_jax_for_gpu()
629686
max_logging.log("Jax distributed system initialized on GPU!")
630687
else:
631-
jax.distributed.initialize()
688+
jax.distributed.initialize()

0 commit comments

Comments
 (0)