Commit aa7befd

Commit message: changed sharding back
Parent: 1ea6590
5 files changed: 18 additions, 9 deletions

src/maxdiffusion/configs/ltx_video.yml (5 additions, 3 deletions)

@@ -40,20 +40,22 @@ output_dir: 'ltx-video-output'
 save_config_to_gcs: False
 
 #parallelism
-mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']
+mesh_axes: ['data', 'fsdp', 'tensor']
 logical_axis_rules: [
   ['batch', 'data'],
+  ['activation_heads', 'fsdp'],
   ['activation_batch', ['data','fsdp']],
-  ['activation_heads', 'tensor'],
   ['activation_kv', 'tensor'],
   ['mlp','tensor'],
   ['embed','fsdp'],
   ['heads', 'tensor'],
+  ['norm', 'fsdp'],
   ['conv_batch', ['data','fsdp']],
   ['out_channels', 'tensor'],
   ['conv_out', 'fsdp'],
+  ['conv_in', 'fsdp']
 ]
-data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']]
+data_sharding: [['data', 'fsdp', 'tensor']]
 dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: -1
 dcn_tensor_parallelism: 1
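For readers unfamiliar with these keys, the sketch below shows how a 3-axis config like this is typically consumed in JAX/Flax. It is an illustration assuming standard jax.sharding and flax.linen APIs, not code from this repository; the device reshape and the example logical spec are assumptions.

# Minimal sketch, assuming standard JAX/Flax APIs; not code from this repo.
# The device reshape and the example logical spec are illustrative.
import jax
import numpy as np
from flax import linen as nn
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# mesh_axes: ['data', 'fsdp', 'tensor'] -> a 3-axis device mesh.
# Put all local devices on the fsdp axis for illustration.
devices = np.array(jax.devices()).reshape(1, -1, 1)
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

# logical_axis_rules map logical tensor dimensions to mesh axes.
rules = (
    ("batch", "data"),
    ("activation_heads", "fsdp"),
    ("embed", "fsdp"),
    ("heads", "tensor"),
)

# Flax resolves a logical PartitionSpec into a mesh-level NamedSharding.
logical_spec = PartitionSpec("batch", "embed")
mesh_sharding = nn.logical_to_mesh_sharding(logical_spec, mesh, rules)

# data_sharding: [['data', 'fsdp', 'tensor']] shards the leading (batch)
# dimension of each input batch across all three mesh axes at once.
data_sharding = NamedSharding(mesh, PartitionSpec(("data", "fsdp", "tensor")))

Since dcn_data_parallelism and dcn_tensor_parallelism are 1 and dcn_fsdp_parallelism is -1, the fsdp axis is the one auto-sharded across the DCN (inter-slice) dimension, consistent with the comment in the config.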

src/maxdiffusion/max_utils.py (1 addition, 1 deletion)

@@ -609,4 +609,4 @@ def maybe_initialize_jax_distributed_system(raw_keys):
     initialize_jax_for_gpu()
     max_logging.log("Jax distributed system initialized on GPU!")
   else:
-    jax.distributed.initialize()
+    jax.distributed.initialize()

(Here, and in the pyconfig.py and test hunks below, the removed and added lines render identically; these are most likely whitespace-only changes, e.g. trailing whitespace or a final newline.)
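As a side note on the call this hunk touches: invoked with no arguments, jax.distributed.initialize() auto-detects the coordinator address, process count, and process id on supported environments such as Cloud TPU. A minimal multi-host startup sketch:

# Minimal sketch of the call in this hunk; with no arguments,
# jax.distributed.initialize() auto-detects the coordinator address,
# process count, and process id on supported environments (e.g. Cloud TPU).
# On other platforms those arguments must be passed explicitly.
import jax

jax.distributed.initialize()
print(f"process {jax.process_index()} of {jax.process_count()}, "
      f"local devices: {jax.local_device_count()}")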

src/maxdiffusion/models/ltx_video/transformers/attention.py (10 additions, 3 deletions)

@@ -622,14 +622,21 @@ def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids):
       raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.")
   # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim")
   # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py.
+  # qkvo_sharding_spec = jax.sharding.PartitionSpec(
+  #     ("data", "fsdp", "fsdp_transpose", "expert"),
+  #     ("tensor", "tensor_transpose", "sequence", "tensor_sequence"),
+  #     None,
+  #     None,
+  # )
   qkvo_sharding_spec = jax.sharding.PartitionSpec(
-      ("data", "fsdp", "fsdp_transpose", "expert"),
-      ("tensor", "tensor_transpose", "sequence", "tensor_sequence"),
+      None,
+      ("data", "fsdp", "tensor"),
       None,
       None,
   )
   # Based on: ("activation_kv_batch", "activation_length")
-  qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence")
+  qkv_segment_ids_spec = jax.sharding.PartitionSpec("fsdp", None)
+  # qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None)
   wrapped_flash_attention = shard_map(
       partial_flash_attention,
       mesh=sharding_mesh,
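A minimal sketch of the shard_map wiring this hunk modifies, under the new 3-axis mesh. The attention_kernel below is a stand-in for the repo's flash-attention kernel, and the (batch, heads, length, head_dim) layout is assumed from the logical-axis comments; only the two PartitionSpecs are taken from the commit itself.

# Minimal sketch; attention_kernel is a stand-in for the repo's
# flash-attention call, and the tensor layout is an assumption.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()).reshape(1, -1, 1),
            axis_names=("data", "fsdp", "tensor"))

# The new specs: q/k/v/o shard the heads dimension over all three mesh
# axes; segment ids shard their batch dimension over fsdp only.
qkvo_sharding_spec = PartitionSpec(None, ("data", "fsdp", "tensor"), None, None)
qkv_segment_ids_spec = PartitionSpec("fsdp", None)

def attention_kernel(q, k, v, q_segment_ids, kv_segment_ids):
  # Plain softmax attention on the local shard; segment ids are ignored
  # here, whereas the real kernel uses them for masking.
  scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
  return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)

wrapped_flash_attention = shard_map(
    attention_kernel,
    mesh=mesh,
    in_specs=(qkvo_sharding_spec, qkvo_sharding_spec, qkvo_sharding_spec,
              qkv_segment_ids_spec, qkv_segment_ids_spec),
    out_specs=qkvo_sharding_spec,
    check_rep=False,  # segment ids are unused in this toy kernel
)

When the wrapped function is called, the heads dimension must divide evenly by the product of the mesh axis sizes it is sharded over, and the segment-id batch dimension by the fsdp axis size.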

src/maxdiffusion/pyconfig.py (1 addition, 1 deletion)

@@ -226,4 +226,4 @@ def initialize(argv, **kwargs):
 if __name__ == "__main__":
   initialize(sys.argv)
   print(config.steps)
-  r = range(config.steps)
+  r = range(config.steps)

src/maxdiffusion/tests/ltx_transformer_step_test.py (1 addition, 1 deletion)

@@ -104,7 +104,7 @@ def test_one_step_transformer(self):
     devices_array = create_device_mesh(config)
     mesh = Mesh(devices_array, config.mesh_axes)
     base_dir = os.path.dirname(__file__)
-    config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json")
+    config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json")
 
     with open(config_path, "r") as f:
       model_config = json.load(f)
