Commit 9ee7fd3

Improves performance by 14% on v5p.
1 parent 87817d0 commit 9ee7fd3

4 files changed

Lines changed: 34 additions & 32 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 2 additions & 10 deletions
@@ -52,16 +52,7 @@ from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te

-flash_block_sizes: {
-  "block_q" : 1024,
-  "block_kv_compute" : 1024,
-  "block_kv" : 1024,
-  "block_q_dkv" : 1024,
-  "block_kv_dkv" : 1024,
-  "block_kv_dkv_compute" : 1024,
-  "block_q_dq" : 1024,
-  "block_kv_dq" : 1024
-}
+flash_block_sizes: {}
 # GroupNorm groups
 norm_num_groups: 32
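Leaving flash_block_sizes as an empty mapping defers block-size selection to the dtype-aware defaults added in attention_flax.py below; the old fixed 1024-wide blocks can still be restored by repopulating the mapping. For reference, a sketch of the Pallas BlockSizes object those eight YAML keys correspond to (constructor fields taken from the attention_flax.py hunk below; the splash_attention_kernel import path is assumed from that file's imports):

from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

# Equivalent of the removed YAML mapping: every block dimension fixed at 1024.
block_sizes = splash_attention_kernel.BlockSizes(
    block_q=1024,
    block_kv_compute=1024,
    block_kv=1024,
    block_q_dkv=1024,
    block_kv_dkv=1024,
    block_kv_dkv_compute=1024,
    block_q_dq=1024,
    block_kv_dq=1024,
)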

@@ -127,6 +118,7 @@ logical_axis_rules: [
   ['mlp','tensor'],
   ['embed','fsdp'],
   ['heads', 'tensor'],
+  ['norm', 'fsdp'],
   ['conv_batch', ['data','fsdp']],
   ['out_channels', 'tensor'],
   ['conv_out', 'fsdp'],
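The new ['norm', 'fsdp'] rule gives the RMSNorm scale parameters (annotated with the logical axis 'norm' later in this commit) a mesh axis to land on. A rough illustration of how such a rule resolves, using Flax's logical-axis helpers rather than the actual MaxDiffusion call site:

import flax.linen as nn

# Under these rules, a parameter dimension named 'norm' is sharded along the
# 'fsdp' mesh axis; names without a rule stay unsharded.
rules = (("norm", "fsdp"), ("heads", "tensor"), ("embed", "fsdp"))
with nn.logical_axis_rules(rules):
    print(nn.logical_to_mesh_axes(("norm",)))  # expected: PartitionSpec('fsdp',)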

src/maxdiffusion/generate_wan.py

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@ def run(config):

   print("compile time: ", (time.perf_counter() - s0))
   for i in range(len(videos)):
-    export_to_video(videos[i], f"wan_output_{i}.mp4", fps=16)
+    export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=16)
   s0 = time.perf_counter()
   with jax.profiler.trace("/tmp/trace/"):
     videos = pipeline(

@@ -49,7 +49,7 @@ def run(config):
     )
   print("generation time: ", (time.perf_counter() - s0))
   for i in range(len(videos)):
-    export_to_video(videos[i], f"wan_output_{i}.mp4", fps=16)
+    export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=16)


 def main(argv: Sequence[str]) -> None:
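The only behavioural change here is the filename: outputs are now keyed by the run's seed as well as the video index, so repeated runs with different seeds no longer overwrite each other. Illustrative values only:

seed, i = 0, 1  # hypothetical config.seed and loop index
print(f"wan_output_{seed}_{i}.mp4")  # -> wan_output_0_1.mp4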

src/maxdiffusion/models/attention_flax.py

Lines changed: 24 additions & 13 deletions
@@ -19,6 +19,7 @@
 import flax.linen as nn
 from flax import nnx
 import jax
+from jax.sharding import PartitionSpec
 import jax.numpy as jnp
 from jax.experimental import shard_map
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
@@ -139,21 +140,23 @@ def _tpu_flash_attention(
     heads: int,
     mesh: Mesh,
     flash_axis_names: AxisNames,
-    flash_block_sizes: BlockSizes) -> jax.Array:
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype = jnp.float32) -> jax.Array:
   """TPU Flash Attention"""

+  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
   if flash_block_sizes:
     block_sizes = flash_block_sizes
   else:
     block_sizes = splash_attention_kernel.BlockSizes(
-      block_q=min(512, query.shape[2]),
-      block_kv_compute=min(512, key.shape[2]),
-      block_kv=min(512, key.shape[2]),
-      block_q_dkv=min(512, query.shape[2]),
-      block_kv_dkv=min(512, key.shape[2]),
-      block_kv_dkv_compute=min(512, query.shape[2]),
-      block_q_dq=min(512, query.shape[2]),
-      block_kv_dq=min(512, query.shape[2]),
+      block_q=min(max_block_size, query.shape[2]),
+      block_kv_compute=min(max_block_size, key.shape[2]),
+      block_kv=min(max_block_size, key.shape[2]),
+      block_q_dkv=min(max_block_size, query.shape[2]),
+      block_kv_dkv=min(max_block_size, key.shape[2]),
+      block_kv_dkv_compute=min(max_block_size, query.shape[2]),
+      block_q_dq=min(max_block_size, query.shape[2]),
+      block_kv_dq=min(max_block_size, query.shape[2]),
     )

   query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q)
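A minimal sketch of the new default-block-size rule above. The presumed rationale (not stated in the commit) is that bfloat16 tiles occupy half the VMEM of float32 tiles, so the splash-attention blocks can be twice as wide before exhausting on-chip memory:

import jax.numpy as jnp

def default_block(dtype, seq_len):
    # Mirrors the hunk above: 1024-wide blocks for bf16, 512 otherwise,
    # always capped at the actual sequence length.
    max_block_size = 1024 if dtype == jnp.bfloat16 else 512
    return min(max_block_size, seq_len)

assert default_block(jnp.bfloat16, 4096) == 1024
assert default_block(jnp.float32, 4096) == 512
assert default_block(jnp.bfloat16, 256) == 256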
@@ -340,7 +343,7 @@ def _apply_attention(
   if attention_kernel == "dot_product" or use_memory_efficient_attention or not can_use_flash_attention:
     return _apply_attention_dot(query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention)
   elif attention_kernel == "flash":
-    return _tpu_flash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes)
+    return _tpu_flash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes, dtype)
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
   else:
@@ -668,15 +671,15 @@ def __init__(
       rngs=rngs,
       epsilon=eps,
       dtype=dtype,
-      scale_init=nnx.with_partitioning(nnx.initializers.ones, ("heads", )),
+      scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm", )),
       param_dtype=weights_dtype
     )

     self.norm_k = nnx.RMSNorm(
       num_features=self.inner_dim,
       rngs=rngs,
       dtype=dtype,
-      scale_init=nnx.with_partitioning(nnx.initializers.ones, ("heads", )),
+      scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm", )),
       param_dtype=weights_dtype
     )
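A small sketch of what the scale_init change does (assuming the flax.nnx API as used above): nnx.with_partitioning wraps an initializer so the created parameter carries logical sharding names, here ('norm',), which the new ['norm', 'fsdp'] rule in base_wan_14b.yml then maps to the fsdp mesh axis:

from flax import nnx

scale_init = nnx.with_partitioning(nnx.initializers.ones, ("norm",))
rms = nnx.RMSNorm(num_features=128, scale_init=scale_init, rngs=nnx.Rngs(0))
print(rms.scale.sharding)  # expected: ('norm',)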

@@ -702,9 +705,12 @@ def __call__(
     encoder_hidden_states: jax.Array = None,
     rotary_emb: Optional[jax.Array] = None
   ) -> jax.Array:
+    hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec('data', 'fsdp', 'tensor'))
+    encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec('data', 'fsdp', 'tensor'))
     dtype = hidden_states.dtype
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states
+
     query_proj = self.query(hidden_states)
     key_proj = self.key(encoder_hidden_states)
     value_proj = self.value(encoder_hidden_states)
@@ -717,8 +723,13 @@ def __call__(
     key_proj = _unflatten_heads(key_proj, self.heads)
     value_proj = _unflatten_heads(value_proj, self.heads)
     query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
-
+    query_proj = jax.lax.with_sharding_constraint(query_proj, PartitionSpec('data', 'tensor', None, None))
+    key_proj = jax.lax.with_sharding_constraint(key_proj, PartitionSpec('data', 'tensor', None, None))
+    value_proj = jax.lax.with_sharding_constraint(value_proj, PartitionSpec('data', 'tensor', None, None))
+
     attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+    attn_output = jax.lax.with_sharding_constraint(attn_output, PartitionSpec('data', None, None))
+
     attn_output = attn_output.astype(dtype=dtype)

     hidden_states = self.proj_attn(attn_output)
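These with_sharding_constraint calls pin the attention activations to the ('data', 'fsdp', 'tensor') mesh instead of leaving the layout to the compiler. Note that the constraint on encoder_hidden_states runs before the None fallback, which is presumably why the transformer block below now passes encoder_hidden_states explicitly for self-attention. A self-contained sketch of how such a constraint resolves (not MaxDiffusion code; it assumes 8 devices and uses NamedSharding so no ambient mesh context is needed, whereas the commit relies on a mesh being active and passes a bare PartitionSpec):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical 2x2x2 device mesh with the axis names used in base_wan_14b.yml.
devices = np.array(jax.devices()).reshape(2, 2, 2)
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

@jax.jit
def constrain(x):
    # Same kind of call the commit adds for hidden_states / encoder_hidden_states.
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, PartitionSpec("data", "fsdp", "tensor")))

x = jnp.ones((4, 1024, 512))
print(constrain(x).sharding)  # batch over 'data', sequence over 'fsdp', features over 'tensor'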

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 6 additions & 7 deletions
@@ -179,8 +179,6 @@ def __init__(
       dtype=dtype,
       param_dtype=weights_dtype,
       precision=precision,
-      kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp",)),
-      bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
     )

   def __call__(self, x: jax.Array) -> jax.Array:
@@ -231,7 +229,6 @@ def __init__(
       param_dtype=weights_dtype,
       precision=precision,
       kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("mlp", "embed",)),
-      bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
     )

   def __call__(self, hidden_states: jax.Array) -> jax.Array:
@@ -338,7 +335,7 @@ def __call__(

     # 1. Self-attention
     norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype)
-    attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+    attn_output = self.attn1(hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
     hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)

     # 2. Cross-attention
@@ -443,11 +440,13 @@ def __init__(
       dtype=dtype,
       param_dtype=weights_dtype,
       precision=precision,
-      kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "mlp",)),
-      bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+      kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", None)),
     )
     key = rngs.params()
-    self.scale_shift_table = nnx.Param(jax.random.normal(key, (1, 2, inner_dim)) / inner_dim**0.5)
+    self.scale_shift_table = nnx.Param(
+      jax.random.normal(key, (1, 2, inner_dim)) / inner_dim**0.5,
+      kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed"))
+    )

   def __call__(
     self,
