Skip to content

Commit 978213d

Browse files
fix sharding
1 parent 0eb3303 commit 978213d

5 files changed

Lines changed: 105 additions & 41 deletions

File tree

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ output_dir: 'ltx-video-output'
1212
save_config_to_gcs: False
1313

1414
#parallelism
15-
mesh_axes: ['data', 'fsdp', 'tensor']
15+
mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']
1616
logical_axis_rules: [
1717
['batch', 'data'],
1818
['activation_batch', ['data','fsdp']],
@@ -25,13 +25,19 @@ logical_axis_rules: [
2525
['out_channels', 'tensor'],
2626
['conv_out', 'fsdp'],
2727
]
28-
data_sharding: [['data', 'fsdp', 'tensor']]
28+
data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']]
2929
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
3030
dcn_fsdp_parallelism: -1
3131
dcn_tensor_parallelism: 1
32+
3233
ici_data_parallelism: -1
3334
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
3435
ici_tensor_parallelism: 1
36+
ici_fsdp_transpose_parallelism: 1
37+
ici_sequence_parallelism: 1
38+
ici_tensor_transpose_parallelism: 1
39+
ici_expert_parallelism: 1
40+
ici_tensor_sequence_parallelism: 1
3541

3642

3743

src/maxdiffusion/max_utils.py

Lines changed: 70 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -251,46 +251,88 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_typ
251251

252252
return parallelism_vals
253253

254-
255-
def create_device_mesh(config, devices=None, logging=True):
254+
def create_device_mesh(config, devices=None):
256255
"""Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas"""
257256
if devices is None:
258257
devices = jax.devices()
259258
num_devices = len(devices)
260-
try:
261-
num_slices = 1 + max([d.slice_index for d in devices])
262-
except:
263-
num_slices = 1
259+
num_slices = 1
260+
# if config.inference_benchmark_test else config.num_slices
264261
num_devices_per_slice = num_devices // num_slices
265-
max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")
266262

267-
multi_slice_env = num_slices > 1
268-
269-
dcn_parallelism = [
270-
config.dcn_data_parallelism,
271-
config.dcn_fsdp_parallelism,
272-
config.dcn_tensor_parallelism,
273-
]
274-
ici_parallelism = [
275-
config.ici_data_parallelism,
276-
config.ici_fsdp_parallelism,
277-
config.ici_tensor_parallelism,
278-
]
263+
# multi_slice_env = num_slices > 1
279264

280265
# Find possible unspecified parallelisms
281-
ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
282-
if multi_slice_env:
283-
dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
284-
mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
285-
else:
286-
mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)
287-
288-
if logging:
289-
max_logging.log(f"Decided on mesh: {mesh}")
266+
ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
267+
268+
# allow_split_physical_axes = config.allow_split_physical_axes if config.allow_split_physical_axes else False
269+
270+
# if allow_split_physical_axes:
271+
# if max_utils.is_valid_custom_mesh(ici_parallelism, config.custom_mesh):
272+
# mesh = mesh_utils.create_device_mesh(
273+
# [16, 16],
274+
# devices,
275+
# contiguous_submeshes=False,
276+
# allow_split_physical_axes=False,
277+
# )
278+
# mesh = max_utils.reshape_mesh_to_rings(mesh, config.custom_mesh)
279+
# mesh = np.reshape(mesh, ici_parallelism)
280+
# else:
281+
# mesh = mesh_utils.create_device_mesh(
282+
# ici_parallelism,
283+
# devices,
284+
# contiguous_submeshes=False,
285+
# allow_split_physical_axes=allow_split_physical_axes,
286+
# )
287+
# else:
288+
mesh = mesh_utils.create_device_mesh(
289+
ici_parallelism,
290+
devices,
291+
)
292+
max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
290293

291294
return mesh
292295

293296

297+
# def create_device_mesh(config, devices=None, logging=True):
298+
# """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas"""
299+
# if devices is None:
300+
# devices = jax.devices()
301+
# num_devices = len(devices)
302+
# try:
303+
# num_slices = 1 + max([d.slice_index for d in devices])
304+
# except:
305+
# num_slices = 1
306+
# num_devices_per_slice = num_devices // num_slices
307+
# max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")
308+
309+
# multi_slice_env = num_slices > 1
310+
311+
# dcn_parallelism = [
312+
# config.dcn_data_parallelism,
313+
# config.dcn_fsdp_parallelism,
314+
# config.dcn_tensor_parallelism,
315+
# ]
316+
# ici_parallelism = [
317+
# config.ici_data_parallelism,
318+
# config.ici_fsdp_parallelism,
319+
# config.ici_tensor_parallelism,
320+
# ]
321+
322+
# # Find possible unspecified parallelisms
323+
# ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
324+
# if multi_slice_env:
325+
# dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
326+
# mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
327+
# else:
328+
# mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)
329+
330+
# if logging:
331+
# max_logging.log(f"Decided on mesh: {mesh}")
332+
333+
# return mesh
334+
335+
294336
def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState):
295337
"""Unboxes the flax.LogicallyPartitioned pieces in a train state.
296338

src/maxdiffusion/models/ltx_video/transformers/attention.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -631,27 +631,27 @@ def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids):
631631
raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.")
632632
# Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim")
633633
# Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py.
634+
qkvo_sharding_spec = jax.sharding.PartitionSpec(
635+
("data", "fsdp", "fsdp_transpose", "expert"),
636+
("tensor", "tensor_transpose", "sequence", "tensor_sequence"),
637+
None,
638+
None,
639+
)
634640
# qkvo_sharding_spec = jax.sharding.PartitionSpec(
635641
# ("data", "fsdp", "fsdp_transpose", "expert"),
636642
# ("tensor", "tensor_transpose", "sequence", "tensor_sequence"),
637643
# None,
638644
# None,
639645
# )
640646
# qkvo_sharding_spec = jax.sharding.PartitionSpec(
641-
# ("data", "fsdp", "fsdp_transpose", "expert"),
642-
# ("tensor", "tensor_transpose", "sequence", "tensor_sequence"),
647+
# None,
648+
# None,
643649
# None,
644650
# None,
645651
# )
646-
qkvo_sharding_spec = jax.sharding.PartitionSpec(
647-
None,
648-
None,
649-
None,
650-
None,
651-
)
652652
#Based on: ("activation_kv_batch", "activation_length")
653-
# qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence")
654-
qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None)
653+
qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence")
654+
# qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None)
655655
wrapped_flash_attention = shard_map(
656656
partial_flash_attention,
657657
mesh=sharding_mesh,

src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"ckpt_path": "/mnt/disks/diffusionproj/jax_weights",
2+
"ckpt_path": "/dev/shm/ltx_converted",
33
"activation_fn": "gelu-approximate",
44
"attention_bias": true,
55
"attention_head_dim": 128,

src/maxdiffusion/pyconfig.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,21 @@ def string_to_bool(s: str) -> bool:
4141
config = None
4242

4343

44+
def create_parallelisms_list(raw_keys):
45+
ici_parallelism = [
46+
raw_keys["ici_data_parallelism"],
47+
raw_keys["ici_fsdp_parallelism"],
48+
raw_keys["ici_fsdp_transpose_parallelism"],
49+
raw_keys["ici_sequence_parallelism"],
50+
raw_keys["ici_tensor_parallelism"],
51+
raw_keys["ici_tensor_transpose_parallelism"],
52+
raw_keys["ici_expert_parallelism"],
53+
raw_keys["ici_tensor_sequence_parallelism"],
54+
]
55+
raw_keys["ici_parallelism"] = ici_parallelism
56+
return raw_keys
57+
58+
4459
def print_system_information():
4560
max_logging.log(f"System Information: Jax Version: {jax.__version__}")
4661
max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}")
@@ -154,6 +169,7 @@ def user_init(raw_keys):
154169
raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"])
155170
raw_keys["num_slices"] = get_num_slices(raw_keys)
156171
raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)
172+
raw_keys = create_parallelisms_list(raw_keys)
157173

158174

159175
def get_num_slices(raw_keys):

0 commit comments

Comments
 (0)