
Commit 42cbb0e

Add dense padded attention kernel and use unsafe rng key for generation
1 parent 65da062 commit 42cbb0e

10 files changed

Lines changed: 1303 additions & 55 deletions


padded_flash_attn.py

Lines changed: 415 additions & 0 deletions
Large diffs are not rendered by default.
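padded_flash_attn.py itself is not rendered, but its call site in attention_flax.py below shows the intended interface: make_dense_padded_attention(block_sizes=..., kv_padding=...) returns a kernel that is vmapped over the batch dimension, applied to (query, key, value), and returns the attention output plus residuals. As a hedged illustration of the semantics such a kernel would implement (not the Pallas implementation), here is a plain-JAX reference that attends densely over the padded key/value sequence while masking out the padded tail; the shapes below are assumptions.

import jax
import jax.numpy as jnp

def dense_padded_attention_reference(query, key, value, kv_padding: int):
  """Reference (non-Pallas) semantics assumed for the padded kernel: full attention
  over [heads, q_len, head_dim] x [heads, kv_len, head_dim], with the last
  `kv_padding` key/value positions masked out."""
  kv_len = key.shape[1]
  logits = jnp.einsum("hqd,hkd->hqk", query, key)  # scaling assumed applied upstream (key * scale)
  valid = jnp.arange(kv_len) < (kv_len - kv_padding)  # mask the padded tail
  logits = jnp.where(valid[None, None, :], logits, -jnp.inf)
  probs = jax.nn.softmax(logits, axis=-1)
  return jnp.einsum("hqk,hkd->hqd", probs, value)

# Hypothetical shapes mirroring the vmap over batch in attention_flax.py:
q = jnp.zeros((2, 8, 128, 64))   # [batch, heads, q_len, head_dim]
k = jnp.zeros((2, 8, 256, 64))   # kv padded to 256, of which 32 positions are padding
v = jnp.zeros((2, 8, 256, 64))
out = jax.vmap(dense_padded_attention_reference, in_axes=(0, 0, 0, None))(q, k, v, 32)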

splash_attn_benchmark.py

Lines changed: 387 additions & 0 deletions
Large diffs are not rendered by default.
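The benchmark script is likewise not rendered. A rough sketch of how one might time the splash kernel configured the way this commit configures it (full per-head mask, head_shards=1, q_seq_shards=1, vmapped over batch) follows; the shapes, block sizes, and iteration count are invented for illustration and it targets TPU only.

import time
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel, splash_attention_mask)

batch, heads, seq, head_dim = 1, 8, 4096, 128  # arbitrary illustration sizes
block_sizes = splash_attention_kernel.BlockSizes(block_q=1024, block_kv=1024, block_kv_compute=256)
mask = splash_attention_mask.MultiHeadMask(
    [splash_attention_mask.FullMask((seq, seq)) for _ in range(heads)])
kernel = splash_attention_kernel.make_splash_mha(
    mask=mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes)

q = jnp.ones((batch, heads, seq, head_dim), jnp.bfloat16)
k = jnp.ones((batch, heads, seq, head_dim), jnp.bfloat16)
v = jnp.ones((batch, heads, seq, head_dim), jnp.bfloat16)
segment_ids = splash_attention_kernel.SegmentIds(
    q=jnp.ones((seq,), jnp.int32), kv=jnp.ones((seq,), jnp.int32))

run = jax.jit(jax.vmap(kernel, in_axes=(0, 0, 0, None)))
run(q, k, v, segment_ids).block_until_ready()  # warm up / compile once
start = time.perf_counter()
for _ in range(10):
  out = run(q, k, v, segment_ids)
out.block_until_ready()
print("mean step (s):", (time.perf_counter() - start) / 10)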

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 21 additions & 21 deletions
@@ -57,18 +57,18 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
-flash_min_seq_length: 4096
+flash_min_seq_length: 0
 dropout: 0.1
 
 flash_block_sizes: {
-  "block_q" : 1024,
-  "block_kv_compute" : 256,
-  "block_kv" : 1024,
-  "block_q_dkv" : 1024,
-  "block_kv_dkv" : 1024,
-  "block_kv_dkv_compute" : 256,
-  "block_q_dq" : 1024,
-  "block_kv_dq" : 1024
+  "block_q" : 3024,
+  "block_kv_compute" : 1024,
+  "block_kv" : 2048,
+  "block_q_dkv" : 3024,
+  "block_kv_dkv" : 2048,
+  "block_kv_dkv_compute" : 2048,
+  "block_q_dq" : 3024,
+  "block_kv_dq" : 2048
 }
 # Use on v6e
 # flash_block_sizes: {
@@ -82,16 +82,16 @@ flash_block_sizes: {
 # "block_kv_dq" : 2048
 # }
 # Use on v5p
-flash_block_sizes: {
-  "block_q" : 1024,
-  "block_kv_compute" : 256,
-  "block_kv" : 3072,
-  "block_q_dkv" : 1024,
-  "block_kv_dkv" : 3072,
-  "block_kv_dkv_compute" : 256,
-  "block_q_dq" : 1024,
-  "block_kv_dq" : 3072
-}
+# flash_block_sizes: {
+#   "block_q" : 3024,
+#   "block_kv_compute" : 1024,
+#   "block_kv" : 2048,
+#   "block_q_dkv" : 1024,
+#   "block_kv_dkv" : 3072,
+#   "block_kv_dkv_compute" : 256,
+#   "block_q_dq" : 1024,
+#   "block_kv_dq" : 3072
+# }
 # GroupNorm groups
 norm_num_groups: 32
 
@@ -152,7 +152,7 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 logical_axis_rules: [
   ['batch', 'data'],
   ['activation_batch', 'data'],
-  ['activation_self_attn_heads', ['fsdp', 'tensor']],
+  ['activation_self_attn_heads', ['fsdp', 'tensor']],
   ['activation_cross_attn_q_length', ['fsdp', 'tensor']],
   ['activation_length', 'fsdp'],
   ['activation_heads', 'tensor'],
@@ -284,7 +284,7 @@ flow_shift: 3.0
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
 num_inference_steps: 30
-fps: 24
+fps: 16
 save_final_checkpoint: False
 
 # SDXL Lightning parameters
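The new block sizes are notable in that block_q is not a power of two, so a quick way to see how they tile a given padded sequence is to compute the grid dimensions. The sketch below is illustrative arithmetic only; the sequence length is hypothetical and nothing in the config enforces it.

import math

flash_block_sizes = {"block_q": 3024, "block_kv_compute": 1024, "block_kv": 2048}
q_len = kv_len = 75600  # hypothetical padded token length, chosen only for illustration

print("q tiles:", math.ceil(q_len / flash_block_sizes["block_q"]))     # 25 tiles, no remainder
print("kv tiles:", math.ceil(kv_len / flash_block_sizes["block_kv"]))  # 37 tiles
# Roughly how many compute sub-steps the kernel takes per loaded kv tile:
print("kv compute steps per tile:", flash_block_sizes["block_kv"] // flash_block_sizes["block_kv_compute"])  # 2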

src/maxdiffusion/generate_wan.py

Lines changed: 8 additions & 1 deletion
@@ -62,6 +62,14 @@ def delete_file(file_path: str):
 
 
 jax.config.update("jax_use_shardy_partitioner", True)
+jax.config.update("jax_default_prng_impl", "unsafe_rbg")
+# TF allocates extraneous GPU memory when using TFDS data
+# this leads to CUDA OOMs. WAR for now is to hide GPUs from TF
+# tf.config.set_visible_devices([], "GPU")
+if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
+  os.environ["LIBTPU_INIT_ARGS"] = (
+      os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
+  )
 
 
 def inference_generate_video(config, pipeline, filename_prefix=""):
@@ -97,7 +105,6 @@ def inference_generate_video(config, pipeline, filename_prefix=""):
 def run(config, pipeline=None, filename_prefix=""):
   print("seed: ", config.seed)
   from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
-
   checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT")
   pipeline = checkpoint_loader.load_checkpoint()
   if pipeline is None:
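These two additions are what the commit title calls the "unsafe rng key": the default JAX PRNG implementation is switched to unsafe_rbg and the xla_tpu_spmd_rng_bit_generator_unsafe flag is appended to LIBTPU_INIT_ARGS, a combination commonly used to trade RNG rigor for speed on TPU. A standalone sketch of the same setup follows; the ordering matters because both settings need to be in place before the TPU backend initializes, which is presumably why the diff puts them at module import time.

import os

# Must run before jax initializes the TPU backend.
libtpu_args = os.environ.get("LIBTPU_INIT_ARGS", "")
if "xla_tpu_spmd_rng_bit_generator_unsafe" not in libtpu_args:
  os.environ["LIBTPU_INIT_ARGS"] = libtpu_args + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"

import jax  # imported after the env var is set

# Faster, less statistically rigorous PRNG implementation for jax.random keys.
jax.config.update("jax_default_prng_impl", "unsafe_rbg")

key = jax.random.PRNGKey(0)          # keys created after the switch use unsafe_rbg
noise = jax.random.normal(key, (4, 4))
print(noise.shape)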

src/maxdiffusion/max_utils.py

Lines changed: 5 additions & 5 deletions
@@ -498,11 +498,11 @@ def get_flash_block_sizes(config):
         block_q=int(config.flash_block_sizes["block_q"]),
         block_kv_compute=int(config.flash_block_sizes["block_kv_compute"]),
         block_kv=int(config.flash_block_sizes["block_kv"]),
-        block_q_dkv=int(config.flash_block_sizes["block_q_dkv"]),
-        block_kv_dkv=int(config.flash_block_sizes["block_kv_dkv"]),
-        block_kv_dkv_compute=int(config.flash_block_sizes["block_kv_dkv_compute"]),
-        block_q_dq=int(config.flash_block_sizes["block_q_dq"]),
-        block_kv_dq=int(config.flash_block_sizes["block_kv_dq"]),
+        block_q_dkv=config.flash_block_sizes.get("block_q_dkv"),
+        block_kv_dkv=config.flash_block_sizes.get("block_kv_dkv"),
+        block_kv_dkv_compute=config.flash_block_sizes.get("block_kv_dkv_compute"),
+        block_q_dq=config.flash_block_sizes.get("block_q_dq"),
+        block_kv_dq=config.flash_block_sizes.get("block_kv_dq"),
     )
   return flash_block_sizes
 
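The switch from int(config.flash_block_sizes["..."]) to .get("...") makes the backward-pass block sizes optional: a config that only sets the forward sizes no longer raises KeyError, and the missing entries are passed as None. The sketch below assumes the constructed object is the splash-attention BlockSizes (consistent with how attention_flax.py consumes it) and that its backward fields default to None; the forward-only config dict is hypothetical.

from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

# Hypothetical inference-only config: forward block sizes only.
cfg = {"block_q": 3024, "block_kv_compute": 1024, "block_kv": 2048}

# Previously int(cfg["block_q_dkv"]) was a hard KeyError for a config like this;
# now .get() returns None and BlockSizes (assumed) falls back to its defaults.
block_sizes = splash_attention_kernel.BlockSizes(
    block_q=int(cfg["block_q"]),
    block_kv_compute=int(cfg["block_kv_compute"]),
    block_kv=int(cfg["block_kv"]),
    block_q_dkv=cfg.get("block_q_dkv"),
    block_kv_dkv=cfg.get("block_kv_dkv"),
    block_kv_dkv_compute=cfg.get("block_kv_dkv_compute"),
    block_q_dq=cfg.get("block_q_dq"),
    block_kv_dq=cfg.get("block_kv_dq"),
)
print(block_sizes)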

src/maxdiffusion/models/attention_flax.py

Lines changed: 27 additions & 9 deletions
@@ -26,6 +26,7 @@
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
 from einops import rearrange
 from .. import common_types, max_logging
+from .padded_flash_attn import make_dense_padded_attention
 
 from . import quantizations
 
@@ -236,20 +237,23 @@ def wrap_flash_attention(query, key, value):
     kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
     kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
     segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
-
-    # make_splash_mha is wrapped around shardmap and seq and head is already
-    # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
         head_shards=1,  # the sizes of the axis is sharding over heads
         q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
         block_sizes=block_sizes,
         save_residuals=True if attention_kernel == "ring" else False,
     )
-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None), out_axes=0)
 
     if attention_kernel == "flash":
+      # attention_output = vmapped_splash(query, key, value, segment_ids)
       attention_output = vmapped_splash(query, key, value, segment_ids)
+    elif attention_kernel == "dense_padded":
+      padded_kv_len = key.shape[1] - key_seq_len
+      dense_padded_attention_kernel = make_dense_padded_attention(block_sizes=block_sizes, kv_padding=padded_kv_len)
+      vmapped_splash = jax.vmap(dense_padded_attention_kernel, in_axes=(0, 0, 0), out_axes=0)
+      attention_output, _ = vmapped_splash(query, key, value)
     else:
       if num_fsdp_shards > 1:
         out, (lse,) = vmapped_splash(query, key, value, segment_ids)
@@ -458,6 +462,19 @@ def _apply_attention(
         dtype,
         attention_kernel,
     )
+  elif attention_kernel == "dense_padded":
+    return _tpu_flash_attention(
+        query,
+        key * scale,
+        value,
+        heads,
+        mesh,
+        axis_names_q,
+        axis_names_kv,
+        flash_block_sizes,
+        dtype,
+        attention_kernel,
+    )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
         query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel
@@ -877,10 +894,10 @@ def __call__(
     dtype = hidden_states.dtype
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states
-
-    query_proj = self.query(hidden_states)
-    key_proj = self.key(encoder_hidden_states)
-    value_proj = self.value(encoder_hidden_states)
+    with jax.named_scope("attention-projection"):
+      query_proj = self.query(hidden_states)
+      key_proj = self.key(encoder_hidden_states)
+      value_proj = self.value(encoder_hidden_states)
 
     if self.qk_norm:
       query_proj = self.norm_q(query_proj)
@@ -895,7 +912,8 @@ def __call__(
       query_proj = checkpoint_name(query_proj, "query_proj")
      key_proj = checkpoint_name(key_proj, "key_proj")
      value_proj = checkpoint_name(value_proj, "value_proj")
-    attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+    with jax.named_scope("attention-compute"):
+      attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
 
     attn_output = attn_output.astype(dtype=dtype)
     attn_output = checkpoint_name(attn_output, "attn_output")
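The two jax.named_scope blocks added in __call__ do not change any numerics; they attach names to the traced operations so the projection and attention compute show up as distinct ranges in profiler traces and HLO metadata. A tiny self-contained example of the same pattern (the function and shapes are made up for illustration):

import jax
import jax.numpy as jnp

def attention_like(x, w_q, w_k):
  # Ops created inside a named_scope are grouped under that name in traces.
  with jax.named_scope("attention-projection"):
    q = x @ w_q
    k = x @ w_k
  with jax.named_scope("attention-compute"):
    scores = jax.nn.softmax(q @ k.T, axis=-1)
  return scores

x = jnp.ones((16, 64))
w_q = jnp.ones((64, 64))
w_k = jnp.ones((64, 64))
print(jax.jit(attention_like)(x, w_q, w_k).shape)  # scope names surface in profiler output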
