@@ -22,47 +22,32 @@ weights_dtype: 'bfloat16'
2222activations_dtype : ' bfloat16'
2323
2424
25- run_name : ' '
26- output_dir : ' ltx-video-output'
27- save_config_to_gcs : False
28-
29- # hardware
30- hardware : ' tpu'
31- skip_jax_distributed_system : False
32-
33- jax_cache_dir : ' '
34- weights_dtype : ' bfloat16'
35- activations_dtype : ' bfloat16'
36-
37-
3825run_name : ' '
3926output_dir : ' ltx-video-output'
4027save_config_to_gcs : False
4128
4229# parallelism
43- mesh_axes : ['data', 'fsdp', 'tensor']
30+ mesh_axes : ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence' ]
4431logical_axis_rules : [
4532 ['batch', 'data'],
46- ['activation_heads', 'fsdp'],
4733 ['activation_batch', ['data','fsdp']],
34+ ['activation_heads', 'tensor'],
4835 ['activation_kv', 'tensor'],
4936 ['mlp','tensor'],
5037 ['embed','fsdp'],
5138 ['heads', 'tensor'],
52- ['norm', 'fsdp'],
5339 ['conv_batch', ['data','fsdp']],
5440 ['out_channels', 'tensor'],
5541 ['conv_out', 'fsdp'],
56- ['conv_in', 'fsdp']
5742 ]
58- data_sharding : [['data', 'fsdp', 'tensor']]
43+ data_sharding : [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence' ]]
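5944dcn_data_parallelism : 1 # recommended DCN axis to be auto-sharded (NOTE(review): this axis is set to 1 while dcn_fsdp_parallelism is -1; confirm this comment is attached to the intended axis)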
6045dcn_fsdp_parallelism : -1
6146dcn_tensor_parallelism : 1
62-
6347ici_data_parallelism : -1
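6548ici_fsdp_parallelism : 1 # recommended ICI axis to be auto-sharded (NOTE(review): this axis is set to 1 while ici_data_parallelism is -1; confirm this comment is attached to the intended axis)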
6549ici_tensor_parallelism : 1
50+
6651ici_fsdp_transpose_parallelism : 1
6752ici_sequence_parallelism : 1
6853ici_tensor_transpose_parallelism : 1
@@ -84,4 +69,4 @@ per_device_batch_size: 1
8469compile_topology_num_slices : -1
8570quantization_local_shard_count : -1
8671jit_initializers : True
87- enable_single_replica_ckpt_restoring : False
72+ enable_single_replica_ckpt_restoring : False
0 commit comments