
Commit b2f41ba

change axis dims

1 parent 2bbf289 commit b2f41ba
2 files changed: 11 additions & 13 deletions

src/maxdiffusion/configs/base_wan_14b.yml

2 additions & 4 deletions

@@ -110,7 +110,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
 skip_jax_distributed_system: False
 
 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['tensor', 'data', 'fsdp']
 
 # batch : batch dimension of data and activations
 # hidden :
@@ -126,11 +126,9 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
   ['batch', 'data'],
-  ['activation_length', ['fsdp','tensor']],
-  ['activation_kv_length', 'fsdp'],
+  ['activation_length', ['fsdp']],
   ['activation_heads', 'tensor'],
   ['activation_batch', 'data'],
-  #['mlp','tensor'],
   ['embed',['fsdp', 'tensor']],
   ['heads', ['tensor', 'fsdp']],
   ['norm', 'tensor'],
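
The new `mesh_axes` ordering puts 'tensor' first, and `activation_length` is now sharded only over 'fsdp' (the 'activation_kv_length' rule is dropped entirely). Below is a minimal sketch, not part of this commit, of how Flax resolves these rules into a PartitionSpec; the mesh shape and the example axis tuple are hypothetical.

import jax
import flax.linen as nn
from jax.sharding import Mesh
from jax.experimental import mesh_utils

mesh_axes = ['tensor', 'data', 'fsdp']        # new ordering from this commit
logical_axis_rules = [
    ('batch', 'data'),
    ('activation_length', ('fsdp',)),         # 'tensor' no longer listed here
    ('activation_heads', 'tensor'),
    ('activation_batch', 'data'),
    ('embed', ('fsdp', 'tensor')),
    ('heads', ('tensor', 'fsdp')),
    ('norm', 'tensor'),
]

# Hypothetical 1x1xN mesh: all devices on 'fsdp', single 'tensor'/'data' shard.
devices = mesh_utils.create_device_mesh((1, 1, len(jax.devices())))
mesh = Mesh(devices, mesh_axes)

# Each logical name resolves to the mesh axes of its first matching rule;
# under the new rules, 'activation_length' maps to 'fsdp' alone.
spec = nn.logical_to_mesh_axes(
    ('activation_batch', 'activation_heads', 'activation_length', None),
    rules=logical_axis_rules,
)
print(spec)  # batch -> 'data', heads -> 'tensor', length -> 'fsdp', last dim replicated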

src/maxdiffusion/models/attention_flax.py

9 additions & 9 deletions
@@ -187,9 +187,9 @@ def _tpu_flash_attention(
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
-  #flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
-  #axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
-  #named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
+  flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
+  axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
+  named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
 
   shard_head_size = mesh.shape["tensor"]
 
@@ -210,7 +210,7 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
 
   multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
   splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
-  #segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
+  segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
 
   @functools.partial(
       shard_map.shard_map,
@@ -219,7 +219,7 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
           q_axis_names,
           kv_axis_names,
           kv_axis_names,
-          None,
+          segment_axis_names_splash_kernel,
       ),
       out_specs=q_axis_names,
       check_rep=False,
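
Taken together, these three hunks re-enable sharding of the splash kernel itself: a NamedSharding over (HEAD, LENGTH) is turned into a manual sharding spec for the kernel, which is then passed to shard_map in place of None so the kernel is split across the mesh consistently with q/k/v. A condensed sketch of the same wiring follows; this is not this file's code, and the mask choice, partition specs, and shapes are simplified assumptions.

import functools
import jax
from jax.experimental import shard_map
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def splash_on_mesh(mesh: Mesh, query, key, value):
  _, num_heads, seq_len, _ = query.shape
  mask = splash_attention_mask.FullMask(_shape=(seq_len, seq_len))  # assumed full mask
  multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * num_heads)
  kernel = splash_attention_kernel.make_splash_mha(
      mask=multi_head_mask,
      head_shards=mesh.shape["tensor"],  # mirrors shard_head_size above
      q_seq_shards=1,                    # sequence sharding omitted for brevity
  )
  # (HEAD, LENGTH) -> mesh axes, as in the uncommented lines above.
  named_sharding = NamedSharding(mesh, P("tensor", None))
  kernel_spec = kernel.manual_sharding_spec(named_sharding)

  qkv_spec = P("data", "tensor", None, None)  # batch on 'data', heads on 'tensor'

  @functools.partial(
      shard_map.shard_map,
      mesh=mesh,
      in_specs=(qkv_spec, qkv_spec, qkv_spec, kernel_spec),
      out_specs=qkv_spec,
      check_rep=False,
  )
  def wrapped(q, k, v, kern):
    # The kernel works on one example at a time, so vmap over the batch dim.
    return jax.vmap(kern)(q, k, v)

  return wrapped(query, key, value, kernel)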
@@ -511,8 +511,8 @@ def __init__(
       use_memory_efficient_attention: bool = False,
       split_head_dim: bool = False,
       float32_qk_product: bool = True,
-      axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
-      axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
+      axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, None),
+      axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, None),
       flash_min_seq_length: int = 4096,
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
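
The last entry of the default q/kv axis names drops D_KV in favor of None, so the head_dim of activations is no longer matched against the sharding rules at all. A tiny sketch of what a None entry does (the rule names here are hypothetical):

import flax.linen as nn

rules = [('batch', 'data'), ('heads', 'tensor'), ('length', 'fsdp')]
spec = nn.logical_to_mesh_axes(('batch', 'heads', 'length', None), rules=rules)
print(spec)  # PartitionSpec('data', 'tensor', 'fsdp', None): last dim stays replicated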
@@ -675,7 +675,7 @@ def __init__(
         quant=quant,
     )
 
-    kernel_axes = ("embed", "heads")
+    kernel_axes = ("embed", None)
     qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes)
 
     self.query = nnx.Linear(
@@ -715,7 +715,7 @@ def __init__(
         rngs=rngs,
         in_features=self.inner_dim,
         out_features=self.inner_dim,
-        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")),
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "embed")),
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
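
Both kernel partitions drop the 'heads' logical axis in favor of None, leaving the head dimension of the qkv and output projections unsharded while 'embed' still follows the axis rules. A small sketch of the nnx.with_partitioning pattern used here, with a hypothetical module standing in for the attention class:

from flax import nnx

class Proj(nnx.Module):
  def __init__(self, dim: int, rngs: nnx.Rngs):
    self.out = nnx.Linear(
        in_features=dim,
        out_features=dim,
        # As in the hunk above: rows unsharded (None), columns on 'embed'.
        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "embed")),
        rngs=rngs,
    )

  def __call__(self, x):
    return self.out(x)

m = Proj(8, nnx.Rngs(0))
print(nnx.get_partition_spec(nnx.state(m)))  # kernel spec: PartitionSpec(None, 'embed')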
