Skip to content

Commit 615174f

Browse files
authored
Update max_utils.py
1 parent a1ad421 commit 615174f

1 file changed

Lines changed: 77 additions & 1 deletion

File tree

src/maxdiffusion/max_utils.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,21 @@ def create_device_mesh(config, devices=None, logging=True):
257257
if devices is None:
258258
devices = jax.devices()
259259
num_devices = len(devices)
260+
##special case for ltx-video
261+
if "fsdp_transpose" in config.mesh_axes:
262+
num_slices = 1
263+
# if config.inference_benchmark_test else config.num_slices
264+
num_devices_per_slice = num_devices // num_slices
265+
# Find possible unspecified parallelisms
266+
ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
267+
mesh = mesh_utils.create_device_mesh(
268+
ici_parallelism,
269+
devices,
270+
)
271+
max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
272+
273+
return mesh
274+
260275
try:
261276
num_slices = 1 + max([d.slice_index for d in devices])
262277
except:
@@ -288,9 +303,66 @@ def create_device_mesh(config, devices=None, logging=True):
288303
if logging:
289304
max_logging.log(f"Decided on mesh: {mesh}")
290305

306+
307+
308+
309+
310+
311+
312+
313+
314+
315+
316+
317+
318+
319+
320+
321+
322+
323+
291324
return mesh
292325

293326

327+
328+
329+
330+
331+
332+
333+
334+
335+
336+
337+
338+
339+
340+
341+
342+
343+
344+
345+
346+
347+
348+
349+
350+
351+
352+
353+
354+
355+
356+
357+
358+
359+
360+
361+
362+
363+
364+
365+
294366
def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState):
295367
"""Unboxes the flax.LogicallyPartitioned pieces in a train state.
296368
@@ -402,7 +474,11 @@ def setup_initial_state(
402474
config.enable_single_replica_ckpt_restoring,
403475
)
404476
if state:
405-
state = state[checkpoint_item]
477+
###!Edited
478+
if checkpoint_item == " ":
479+
state = state
480+
else:
481+
state = state[checkpoint_item]
406482
if not state:
407483
max_logging.log(f"Could not find the item in orbax, creating state...")
408484
init_train_state_partial = functools.partial(

0 commit comments

Comments
 (0)