Skip to content

Commit 6f32622

Browse files
eltsaiPerseus14
authored and committed
Revert "Integrate tokamax ring attention as optional attention kernel for WAN 2.1" (#305)
This reverts commit f68c7b0. Co-authored-by: Elisa Tsai <elisatsai@google.com>
1 parent ad0fd7c commit 6f32622

7 files changed

Lines changed: 172 additions & 203 deletions

File tree

src/maxdiffusion/configs/base_wan_lora_14b.yml

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,16 @@ from_pt: True
6161
split_head_dim: True
6262
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
6363
flash_min_seq_length: 4096
64+
65+
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
66+
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
67+
# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
68+
mask_padding_tokens: True
69+
# Maxdiffusion has 2 types of attention sharding strategies:
70+
# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
71+
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
72+
# in cross attention q.
73+
attention_sharding_uniform: True
6474
dropout: 0.1
6575

6676
#flash_block_sizes: {
@@ -145,8 +155,9 @@ mesh_axes: ['data', 'fsdp', 'tensor']
145155
logical_axis_rules: [
146156
['batch', 'data'],
147157
['activation_batch', 'data'],
158+
['activation_self_attn_heads', ['fsdp', 'tensor']],
159+
['activation_cross_attn_q_length', ['fsdp', 'tensor']],
148160
['activation_length', 'fsdp'],
149-
150161
['activation_heads', 'tensor'],
151162
['mlp','tensor'],
152163
['embed','fsdp'],
@@ -321,8 +332,10 @@ quantization: ''
321332
quantization_local_shard_count: -1
322333
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
323334
use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix.
324-
# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
325-
quantization_calibration_method: "absmax"
335+
# Quantization calibration method used for weights, activations and bwd. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
336+
weight_quantization_calibration_method: "absmax"
337+
act_quantization_calibration_method: "absmax"
338+
bwd_quantization_calibration_method: "absmax"
326339
qwix_module_path: ".*"
327340

328341
# Eval model on per eval_every steps. -1 means don't eval.

src/maxdiffusion/configs/base_wan_lora_27b.yml

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,15 @@ from_pt: True
6161
split_head_dim: True
6262
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
6363
flash_min_seq_length: 4096
64+
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
65+
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
66+
# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
67+
mask_padding_tokens: True
68+
# Maxdiffusion has 2 types of attention sharding strategies:
69+
# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
70+
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
71+
# in cross attention q.
72+
attention_sharding_uniform: True
6473
dropout: 0.1
6574

6675
#flash_block_sizes: {
@@ -145,8 +154,9 @@ mesh_axes: ['data', 'fsdp', 'tensor']
145154
logical_axis_rules: [
146155
['batch', 'data'],
147156
['activation_batch', 'data'],
157+
['activation_self_attn_heads', ['fsdp', 'tensor']],
158+
['activation_cross_attn_q_length', ['fsdp', 'tensor']],
148159
['activation_length', 'fsdp'],
149-
150160
['activation_heads', 'tensor'],
151161
['mlp','tensor'],
152162
['embed','fsdp'],
@@ -333,8 +343,10 @@ quantization: ''
333343
quantization_local_shard_count: -1
334344
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
335345
use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix.
336-
# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
337-
quantization_calibration_method: "absmax"
346+
# Quantization calibration method used for weights, activations and bwd. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
347+
weight_quantization_calibration_method: "absmax"
348+
act_quantization_calibration_method: "absmax"
349+
bwd_quantization_calibration_method: "absmax"
338350
qwix_module_path: ".*"
339351

340352
# Eval model on per eval_every steps. -1 means don't eval.

src/maxdiffusion/generate_wan.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from google.cloud import storage
2626
import flax
2727
from maxdiffusion.common_types import WAN2_1, WAN2_2
28-
from flax import nnx
2928
from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader
3029

3130

src/maxdiffusion/loaders/wan_lora_nnx_loader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525

2626
class Wan2_1NnxLoraLoader(LoRABaseMixin):
2727
"""
28-
Handles loading LoRA weights into NNX-based WAN models.
29-
Assumes WAN pipeline contains 'high_noise_transformer' and 'low_noise_transformer'
28+
Handles loading LoRA weights into NNX-based WAN 2.1 model.
29+
Assumes WAN pipeline contains 'transformer'
3030
attributes that are NNX Modules.
3131
"""
3232

@@ -62,7 +62,7 @@ def load_lora_weights(
6262

6363
class Wan2_2NnxLoraLoader(LoRABaseMixin):
6464
"""
65-
Handles loading LoRA weights into NNX-based WAN models.
65+
Handles loading LoRA weights into NNX-based WAN 2.2 model.
6666
Assumes WAN pipeline contains 'high_noise_transformer' and 'low_noise_transformer'
6767
attributes that are NNX Modules.
6868
"""

src/maxdiffusion/models/attention_flax.py

Lines changed: 36 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
2828
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
2929
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
30-
from tokamax._src.ops.experimental.tpu.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel
3130
from einops import rearrange
3231
from .. import common_types, max_logging
3332

@@ -305,92 +304,62 @@ def wrap_flash_attention(query, key, value):
305304
mask=mask,
306305
q_seq_shards=1, # the sizes of the axis is sharding over seq_len
307306
config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
308-
save_residuals=True if "ring" in attention_kernel else False,
309-
)
310-
elif attention_kernel == "tokamax_ring":
311-
mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),)
312-
splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
313-
mask=mask,
314-
is_mqa=False,
315-
config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
316-
save_residuals=True,
317-
ring_axis="fsdp",
307+
save_residuals=True if attention_kernel == "ring" else False,
318308
)
319309
else:
320310
splash_kernel = splash_attention_kernel.make_splash_mha(
321311
mask=multi_head_mask,
322312
head_shards=1, # the sizes of the axis is sharding over heads
323313
q_seq_shards=1, # the sizes of the axis is sharding over seq_len
324314
block_sizes=block_sizes,
325-
save_residuals=True if "ring" in attention_kernel else False,
315+
save_residuals=True if attention_kernel == "ring" else False,
326316
residual_checkpoint_name=residual_checkpoint_name
327317
)
318+
vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
328319

329-
if attention_kernel == "tokamax_ring":
330-
# For tokamax_ring, use the kernel directly without vmap
331-
# The ring attention kernel handles the ring topology internally
332-
if not mask_padding_tokens:
333-
segment_ids = None
334-
attention_output = splash_kernel(
335-
fwd_mask_info=None,
336-
dkv_mask_info=None,
337-
q=query,
338-
k=key,
339-
v=value,
340-
segment_ids=segment_ids,
341-
is_mqa=False,
342-
config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
343-
mask_value=-jnp.inf,
344-
mask_function=None,
345-
fwd_mask_sparsity=1.0,
346-
save_residuals=True,
347-
)
320+
if not mask_padding_tokens:
321+
segment_ids = None
322+
if attention_kernel in ["flash", "tokamax_flash"]:
323+
attention_output = vmapped_splash(query, key, value, segment_ids)
348324
else:
349-
vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
350-
351-
if not mask_padding_tokens:
352-
segment_ids = None
353-
if attention_kernel in ["flash", "tokamax_flash"]:
354-
attention_output = vmapped_splash(query, key, value, segment_ids)
355-
else:
356-
if num_fsdp_shards > 1:
357-
out, (lse,) = vmapped_splash(query, key, value, segment_ids)
358-
m = lse.astype(jnp.float32)
359-
l = jnp.exp(lse - m)
360-
o = out.astype(jnp.float32) * l[..., None]
325+
if num_fsdp_shards > 1:
326+
out, (lse,) = vmapped_splash(query, key, value, segment_ids)
327+
m = lse.astype(jnp.float32)
328+
l = jnp.exp(lse - m)
329+
o = out.astype(jnp.float32) * l[..., None]
361330

362-
perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
331+
perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
363332

364-
k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
365-
v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
333+
k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
334+
v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
366335

367-
def ring_scan_body(carry, _):
368-
m, l, o, k_current, v_current = carry
369-
k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
370-
v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
336+
def ring_scan_body(carry, _):
337+
m, l, o, k_current, v_current = carry
338+
k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
339+
v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
371340

372-
out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
341+
out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
373342

374-
m_chunk = lse_chunk.astype(jnp.float32)
375-
m_old = m
376-
m = jnp.maximum(m_old, m_chunk)
343+
m_chunk = lse_chunk.astype(jnp.float32)
344+
m_old = m
345+
m = jnp.maximum(m_old, m_chunk)
377346

378-
exp_m_diff = jnp.exp(m_old - m)
379-
exp_m_chunk_diff = jnp.exp(m_chunk - m)
347+
exp_m_diff = jnp.exp(m_old - m)
348+
exp_m_chunk_diff = jnp.exp(m_chunk - m)
380349

381-
l = l * exp_m_diff + jnp.exp(lse_chunk - m)
382-
o = o * exp_m_diff[..., None]
383-
o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
350+
l = l * exp_m_diff + jnp.exp(lse_chunk - m)
351+
o = o * exp_m_diff[..., None]
352+
o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
384353

385-
# Return the updated state for the next iteration
386-
return (m, l, o, k_next, v_next), None
354+
# Return the updated state for the next iteration
355+
return (m, l, o, k_next, v_next), None
387356

388-
initial_carry = (m, l, o, k1, v1)
389-
(m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
357+
initial_carry = (m, l, o, k1, v1)
358+
(m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
390359

391-
attention_output = o_final / l_final[..., None]
392-
else:
393-
raise ValueError("ring attention requires fsdp > 1")
360+
attention_output = o_final / l_final[..., None]
361+
else:
362+
raise ValueError("ring attention requires fsdp > 1")
394363

395364
return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
396365

@@ -566,7 +535,7 @@ def _apply_attention(
566535
mask_padding_tokens=mask_padding_tokens,
567536
residual_checkpoint_name=residual_checkpoint_name,
568537
)
569-
elif "ring" in attention_kernel:
538+
elif attention_kernel == "ring":
570539
return _tpu_flash_attention(
571540
query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
572541
mask_padding_tokens=mask_padding_tokens,
@@ -577,7 +546,6 @@ def _apply_attention(
577546
raise ValueError(f"Unexpected attention kernel {attention_kernel=}.")
578547

579548

580-
581549
def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
582550
"""Multi-head dot product attention with a limited number of queries."""
583551
num_kv, num_heads, k_features = key.shape[-3:]

0 commit comments

Comments (0)