Skip to content

Commit b0e9bab

Browse files
committed
key error fix
1 parent 991a44e commit b0e9bab

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/max_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def create_device_mesh(config, devices=None, logging=True):
258258
devices = jax.devices()
259259
num_devices = len(devices)
260260
##special case for ltx-video
261-
if config.ici_fsdp_transpose_parallelism:
261+
if "fsdp_transpose" in config.mesh_axes:
262262
num_slices = 1
263263
# if config.inference_benchmark_test else config.num_slices
264264
num_devices_per_slice = num_devices // num_slices

0 commit comments

Comments
 (0)