
Commit 0d96afd

wip - context parallelism - not working yet.
1 parent 05f0554 commit 0d96afd

7 files changed

Lines changed: 111 additions & 51 deletions


src/maxdiffusion/common_types.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@
 
 BATCH = "activation_batch"
 LENGTH = "activation_length"
+Q_LENGTH = "activation_q_length"
+KV_LENGTH = "activation_kv_length"
 EMBED = "activation_embed"
 HEAD = "activation_heads"
 D_KV = "activation_kv"
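
Note: these logical names only acquire a physical meaning once they are matched against the logical_axis_rules in the config below. A minimal sketch of that resolution using Flax's public partitioning helpers (the rule values here are an illustrative subset, not the full config):

import flax.linen as nn
from flax.linen import partitioning as nn_partitioning

# Illustrative subset of logical_axis_rules from base_wan_14b.yml.
rules = (
    ("activation_batch", ("data", "fsdp")),
    ("activation_heads", "tensor"),
    ("activation_q_length", ("context",)),
)
with nn_partitioning.axis_rules(rules):
    spec = nn.logical_to_mesh_axes(
        ("activation_batch", "activation_heads", "activation_q_length", "activation_kv_length")
    )
# "activation_kv_length" has no rule here, so it falls back to None (unsharded).
print(spec)  # the resolved PartitionSpec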

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 10 additions & 6 deletions
@@ -96,7 +96,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
 skip_jax_distributed_system: False
 
 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp', 'context', 'tensor']
 
 # batch : batch dimension of data and activations
 # hidden :
@@ -112,29 +112,33 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
   ['batch', 'data'],
-  ['activation_batch', ['data','fsdp']],
+  #['activation_batch', 'fsdp'],
   ['activation_heads', 'tensor'],
+  ['activation_q_length', ['context']],
+  ['activation_kv_length', []],
   ['activation_kv', 'tensor'],
   ['mlp','tensor'],
-  ['embed','fsdp'],
+  ['embed',['fsdp','context']],
   ['heads', 'tensor'],
   ['norm', 'fsdp'],
   ['conv_batch', ['data','fsdp']],
   ['out_channels', 'tensor'],
   ['conv_out', 'fsdp'],
   ['conv_in', 'fsdp']
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp', 'context', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
 # By default, product of the DCN axes should equal number of slices
 # and product of the ICI axes should equal number of devices per slice.
 dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
-dcn_fsdp_parallelism: -1
+dcn_fsdp_parallelism: 1
+dcn_context_parallelism: -1
 dcn_tensor_parallelism: 1
 ici_data_parallelism: 1
-ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
+ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
+ici_context_parallelism: -1
 ici_tensor_parallelism: 1
 
 # Dataset
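
Note: to make the -1 placeholder concrete, with these values and 8 devices per slice the context axis absorbs all 8 devices. A hedged sketch of the fill logic (an approximation of fill_unspecified_mesh_axes, not the actual implementation):

def fill_unspecified(parallelism, target_product):
    # Replace the single -1 entry so the product of all axes equals target_product.
    known = 1
    for v in parallelism:
        if v != -1:
            known *= v
    return [target_product // known if v == -1 else v for v in parallelism]

# ici axes [data, fsdp, context, tensor] = [1, 1, -1, 1] on 8 devices per slice:
assert fill_unspecified([1, 1, -1, 1], 8) == [1, 1, 8, 1]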

src/maxdiffusion/max_utils.py

Lines changed: 3 additions & 11 deletions
@@ -266,21 +266,13 @@ def create_device_mesh(config, devices=None, logging=True):
 
   multi_slice_env = num_slices > 1
 
-  dcn_parallelism = [
-      config.dcn_data_parallelism,
-      config.dcn_fsdp_parallelism,
-      config.dcn_tensor_parallelism,
-  ]
-  ici_parallelism = [
-      config.ici_data_parallelism,
-      config.ici_fsdp_parallelism,
-      config.ici_tensor_parallelism,
-  ]
-
   # Find possible unspecified parallelisms
-  ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
+  ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
   if multi_slice_env:
-    dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
+    dcn_parallelism = fill_unspecified_mesh_axes(config.dcn_parallelism.copy(), num_slices, "DCN")
     mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
   else:
     mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)
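
Note: for orientation, a small sketch of what this function ultimately builds in the single-slice path (axis order assumed to follow mesh_axes in the config; sizes are examples):

import jax
from jax.sharding import Mesh
from jax.experimental import mesh_utils

ici_parallelism = [1, 1, jax.device_count(), 1]  # e.g. after fill_unspecified_mesh_axes
devices = mesh_utils.create_device_mesh(ici_parallelism)
mesh = Mesh(devices, ("data", "fsdp", "context", "tensor"))
print(mesh.shape)  # e.g. {'data': 1, 'fsdp': 1, 'context': 4, 'tensor': 1}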

src/maxdiffusion/models/attention_flax.py

Lines changed: 69 additions & 22 deletions
@@ -39,6 +39,8 @@
 AxisNames = common_types.AxisNames
 BATCH = common_types.BATCH
 LENGTH = common_types.LENGTH
+Q_LENGTH = common_types.Q_LENGTH
+KV_LENGTH = common_types.KV_LENGTH
 HEAD = common_types.HEAD
 D_KV = common_types.D_KV
 EMBED = common_types.EMBED
@@ -139,50 +141,87 @@ def _tpu_flash_attention(
     value: jax.Array,
     heads: int,
     mesh: Mesh,
-    flash_axis_names: AxisNames,
-    flash_block_sizes: BlockSizes,
+    flash_block_sizes: BlockSizes = None,
+    flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
+    flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
+    flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH),
     dtype: jnp.dtype = jnp.float32) -> jax.Array:
   """TPU Flash Attention"""
 
-  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  cp_size = mesh.shape["context"]
+  axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
+  axis_names_q = nn.logical_to_mesh_axes(flash_axis_names_q)
+  axis_names_kv = nn.logical_to_mesh_axes(flash_axis_names_kv)
+  max_logging.log(f"axis_names_q: {axis_names_q}")
+  max_logging.log(f"axis_names_kv: {axis_names_kv}")
+  max_logging.log(f"axis_names_splash_kernel: {axis_names_splash_kernel}")
+
+  max_block_size = 256 if dtype == jnp.bfloat16 else 128
   if flash_block_sizes:
     block_sizes = flash_block_sizes
   else:
     block_sizes = splash_attention_kernel.BlockSizes(
         block_q=min(max_block_size, query.shape[2]),
-        block_kv_compute=min(max_block_size, key.shape[2]),
         block_kv=min(max_block_size, key.shape[2]),
+        block_kv_compute=min(max_block_size, key.shape[2]),
         block_q_dkv=min(max_block_size, query.shape[2]),
         block_kv_dkv=min(max_block_size, key.shape[2]),
         block_kv_dkv_compute=min(max_block_size, query.shape[2]),
         block_q_dq=min(max_block_size, query.shape[2]),
         block_kv_dq=min(max_block_size, query.shape[2]),
+        q_layout=splash_attention_kernel.QKVLayout["HEAD_DIM_MINOR"],
+        k_layout=splash_attention_kernel.QKVLayout["HEAD_DIM_MINOR"],
+        v_layout=splash_attention_kernel.QKVLayout["HEAD_DIM_MINOR"],
     )
 
   query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q)
   key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute)
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute)
 
-  axis_names = nn.logical_to_mesh_axes(flash_axis_names)
-
   @functools.partial(
-      shard_map.shard_map,
-      mesh=mesh,
-      in_specs=(
-          axis_names,
-          axis_names,
-          axis_names,
-      ),
-      out_specs=axis_names,
-      check_rep=False,
+      jax.jit,
+      static_argnames=["multi_head_mask", "shard_head_size"],
   )
-  def wrap_flash_attention(query, key, value):
-    masks = [splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) for _ in range(query.shape[1])]
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=masks)
+  def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
     splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes
+        mask=multi_head_mask,
+        head_shards=shard_head_size,  # size of the mesh axis sharded over heads
+        q_seq_shards=cp_size,
+        block_sizes=block_sizes,
     )
-    return jax.vmap(splash_kernel)(query, key, value)
+    return splash_kernel
+
+  shard_head_size = 1
+
+  masks = [splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) for _ in range(query.shape[1])]
+  multi_head_mask = splash_attention_mask.MultiHeadMask(masks=masks)
+  splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
+  named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
+  segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
+
+  @functools.partial(
+      shard_map.shard_map,
+      mesh=mesh,
+      in_specs=(
+          axis_names_q,
+          axis_names_kv,
+          axis_names_kv,
+          segment_axis_names_splash_kernel,
+          None,
+      ),
+      out_specs=axis_names_q,
+      check_rep=False,
+  )
+  def wrap_flash_attention(query, key, value, splash_kernel, cp_size):
+    attention_output = jax.vmap(splash_kernel)(query, key, value)
+    return attention_output
 
   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
   # This warning might show up when doing model eval for example, when calculating model flops
@@ -192,7 +231,7 @@ def wrap_flash_attention(query, key, value):
         "Warning, batch dimension should be shardable among the devices in data and fsdp"
         f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
     )
-  x = wrap_flash_attention(query, key, value)
+  x = wrap_flash_attention(query, key, value, splash_kernel, cp_size)
   x = x[:, :, :query_seq_len, :kv_size]
   x = _reshape_heads_to_head_dim(x)
 
@@ -343,7 +382,15 @@ def _apply_attention(
   if attention_kernel == "dot_product" or use_memory_efficient_attention or not can_use_flash_attention:
     return _apply_attention_dot(query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention)
   elif attention_kernel == "flash":
-    return _tpu_flash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes, dtype)
+    return _tpu_flash_attention(
+        query=query,
+        key=key * scale,
+        value=value,
+        heads=heads,
+        mesh=mesh,
+        flash_block_sizes=flash_block_sizes,
+        dtype=dtype,
+    )
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
   else:
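
Note: the in_specs above encode the core context-parallel pattern — queries are sharded along the sequence axis while keys and values stay replicated, so each shard computes exact attention for its query chunk. A self-contained sketch of that pattern using plain dot-product attention in place of the splash kernel (the mesh layout is an assumption chosen so the example runs on any host):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils, shard_map

# Put every available device on a single 'context' axis.
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("context",))

def attention(q, k, v):
    # q: (heads, q_chunk, d); k, v: (heads, kv_len, d) -- the full K/V on every shard.
    scores = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(jnp.float32(q.shape[-1]))
    return jnp.einsum("hqk,hkd->hqd", jax.nn.softmax(scores, axis=-1), v)

# Shard the query sequence over 'context'; replicate K and V, mirroring
# axis_names_q vs. axis_names_kv in the diff above.
cp_attention = shard_map.shard_map(
    attention,
    mesh=mesh,
    in_specs=(P(None, "context", None), P(None, None, None), P(None, None, None)),
    out_specs=P(None, "context", None),
)

heads, seq, d = 4, 8 * jax.device_count(), 32
q = k = v = jnp.ones((heads, seq, d))
out = cp_attention(q, k, v)  # exact full attention, computed in query chunks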

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 2 additions & 1 deletion
@@ -21,6 +21,7 @@
 import flax
 import flax.linen as nn
 from flax import nnx
+from flax.linen import partitioning as nn_partitioning
 from ...pyconfig import HyperParameters
 from ... import max_logging
 from ... import max_utils
@@ -434,7 +435,7 @@ def __call__(
         num_transformer_layers=self.transformer.config.num_layers
     )
 
-    with self.mesh:
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       latents = p_run_inference(
           graphdef=graphdef,
           sharded_state=state,

src/maxdiffusion/pyconfig.py

Lines changed: 16 additions & 0 deletions
@@ -155,6 +155,22 @@ def user_init(raw_keys):
   raw_keys["num_slices"] = get_num_slices(raw_keys)
   raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)
 
+  ici_parallelism = [
+      raw_keys["ici_data_parallelism"],
+      raw_keys["ici_fsdp_parallelism"],
+      raw_keys["ici_context_parallelism"],
+      raw_keys["ici_tensor_parallelism"],
+  ]
+  dcn_parallelism = [
+      raw_keys["dcn_data_parallelism"],
+      raw_keys["dcn_fsdp_parallelism"],
+      raw_keys["dcn_context_parallelism"],
+      raw_keys["dcn_tensor_parallelism"],
+  ]
+  raw_keys["ici_parallelism"] = ici_parallelism
+  raw_keys["dcn_parallelism"] = dcn_parallelism
 
 
 def get_num_slices(raw_keys):
   if int(raw_keys["compile_topology_num_slices"]) > 0:
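
Note: these lists are positional, so their order has to match mesh_axes in base_wan_14b.yml ('data', 'fsdp', 'context', 'tensor'). A hypothetical sanity check (not in this commit) makes the pairing explicit:

MESH_AXES = ("data", "fsdp", "context", "tensor")

def as_axis_dict(parallelism_values, mesh_axes=MESH_AXES):
    # One entry per mesh axis, in the same order as mesh_axes.
    assert len(parallelism_values) == len(mesh_axes)
    return dict(zip(mesh_axes, parallelism_values))

print(as_axis_dict([1, 1, -1, 1]))  # {'data': 1, 'fsdp': 1, 'context': -1, 'tensor': 1}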

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 9 additions & 11 deletions
@@ -77,26 +77,24 @@ def start_training(self):
 
   pipeline = self.load_checkpoint()
   mesh = pipeline.mesh
 
   optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, self.config.learning_rate)
 
-  # @nnx.jit
-  # def create_transformer_state(transformer):
-  #   optimizer = self._create_optimizer(transformer, self.config, self.config.learning_rate)
-  #   breakpoint()
-  #   _, state = nnx.split((transformer, optimizer))
-
-  # with mesh:
-  #   create_transformer_state(pipeline.transformer)
-
-  #graphdef, state = nnx.plit((pipeline.transformer, optimizer))
   dummy_inputs = self.load_dataset(pipeline)
   dummy_inputs = tuple([jtu.tree_map_with_path(functools.partial(_form_global_array, global_mesh=mesh), input) for input in dummy_inputs])
 
   self.training_loop(pipeline, optimizer, learning_rate_scheduler, dummy_inputs)
 
 def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data):
-
+  # From the Wan 2.1 paper (https://arxiv.org/pdf/2503.20314):
+  # the input to a DiT block has shape (b, s, h).
+  # b corresponds to data parallelism.
+  # s is the sequence length; it is sharded via context parallelism.
+  # Sharding along the h dimension primarily uses Megatron-style tensor parallelism (TP),
+  # combined with sequence parallelism, which shards the hidden dimension of the
+  # activations by splitting the weights.
   graphdef, state = nnx.split((pipeline.transformer, optimizer))
   state = state.to_pure_dict()
   p_train_step = jax.jit(