@@ -257,6 +257,21 @@ def create_device_mesh(config, devices=None, logging=True):
257257 if devices is None :
258258 devices = jax .devices ()
259259 num_devices = len (devices )
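260+ # Special case for LTX-Video: when "fsdp_transpose" is one of the mesh axes, build a single-slice mesh from the ICI parallelism only and return early.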
261+ if "fsdp_transpose" in config .mesh_axes :
262+ num_slices = 1
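263+ # NOTE: num_slices is hard-coded to 1 above; the config-driven form (`1 if config.inference_benchmark_test else config.num_slices`) is intentionally disabled here.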
264+ num_devices_per_slice = num_devices // num_slices
265+ # Find possible unspecified parallelisms
266+ ici_parallelism = fill_unspecified_mesh_axes (config .ici_parallelism .copy (), num_devices_per_slice , "ICI" )
267+ mesh = mesh_utils .create_device_mesh (
268+ ici_parallelism ,
269+ devices ,
270+ )
271+ max_logging .log (f"Num_devices: { num_devices } , shape { mesh .shape } " )
272+
273+ return mesh
274+
260275 try :
261276 num_slices = 1 + max ([d .slice_index for d in devices ])
262277 except :
@@ -288,9 +303,66 @@ def create_device_mesh(config, devices=None, logging=True):
288303 if logging :
289304 max_logging .log (f"Decided on mesh: { mesh } " )
290305
306+
307+
308+
309+
310+
311+
312+
313+
314+
315+
316+
317+
318+
319+
320+
321+
322+
323+
291324 return mesh
292325
293326
327+
328+
329+
330+
331+
332+
333+
334+
335+
336+
337+
338+
339+
340+
341+
342+
343+
344+
345+
346+
347+
348+
349+
350+
351+
352+
353+
354+
355+
356+
357+
358+
359+
360+
361+
362+
363+
364+
365+
294366def unbox_logicallypartioned_trainstate (boxed_train_state : train_state .TrainState ):
295367 """Unboxes the flax.LogicallyPartitioned pieces in a train state.
296368
@@ -402,7 +474,11 @@ def setup_initial_state(
402474 config .enable_single_replica_ckpt_restoring ,
403475 )
404476 if state :
405- state = state [checkpoint_item ]
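477+ # If checkpoint_item matches the sentinel checked below, use the restored state as-is; otherwise select the named item from it.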
478+ if checkpoint_item == " " :
479+ state = state
480+ else :
481+ state = state [checkpoint_item ]
406482 if not state :
407483 max_logging .log (f"Could not find the item in orbax, creating state..." )
408484 init_train_state_partial = functools .partial (
0 commit comments