Skip to content

Commit d9a3502

Browse files
committed
removed testing for now
1 parent e873a17 commit d9a3502

3 files changed

Lines changed: 2 additions & 260 deletions

File tree

src/maxdiffusion/max_utils.py

Lines changed: 2 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def create_device_mesh(config, devices=None, logging=True):
271271
max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
272272

273273
return mesh
274-
274+
275275
try:
276276
num_slices = 1 + max([d.slice_index for d in devices])
277277
except:
@@ -303,66 +303,9 @@ def create_device_mesh(config, devices=None, logging=True):
303303
if logging:
304304
max_logging.log(f"Decided on mesh: {mesh}")
305305

306-
307-
308-
309-
310-
311-
312-
313-
314-
315-
316-
317-
318-
319-
320-
321-
322-
323-
324306
return mesh
325307

326308

327-
328-
329-
330-
331-
332-
333-
334-
335-
336-
337-
338-
339-
340-
341-
342-
343-
344-
345-
346-
347-
348-
349-
350-
351-
352-
353-
354-
355-
356-
357-
358-
359-
360-
361-
362-
363-
364-
365-
366309
def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState):
367310
"""Unboxes the flax.LogicallyPartitioned pieces in a train state.
368311
@@ -685,4 +628,4 @@ def maybe_initialize_jax_distributed_system(raw_keys):
685628
initialize_jax_for_gpu()
686629
max_logging.log("Jax distributed system initialized on GPU!")
687630
else:
688-
jax.distributed.initialize()
631+
jax.distributed.initialize()

src/maxdiffusion/tests/ltx_transformer_step_test.py

Lines changed: 0 additions & 201 deletions
This file was deleted.
-258 KB
Binary file not shown.

0 commit comments

Comments
 (0)