Commit 54ae6e7

Flag for using segment ids and masking padding tokens in attention
Signed-off-by: Kunjan Patel <kunjanp@google.com>
1 parent 94ecdca commit 54ae6e7

5 files changed

Lines changed: 22 additions & 1 deletion

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions

@@ -58,6 +58,7 @@ from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
 flash_min_seq_length: 0
+mask_padding_tokens: True
 dropout: 0.1

 flash_block_sizes: {
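
The flag defaults to True, so padding tokens stay masked out of attention unless a run explicitly opts out. A minimal usage sketch, assuming the usual MaxDiffusion key=value config overrides (the exact invocation is an assumption, not part of this commit):

  python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml mask_padding_tokens=False

With False, the flash path below drops the segment ids before calling the splash kernel, so padding tokens participate in attention.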

src/maxdiffusion/generate_wan.py

Lines changed: 1 addition & 0 deletions

@@ -67,6 +67,7 @@ def delete_file(file_path: str):
 # this leads to CUDA OOMs. WAR for now is to hide GPUs from TF
 # tf.config.set_visible_devices([], "GPU")
 if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
+  max_logging.log("Enabling unsafe RNG bit generator for TPU SPMD.")
   os.environ["LIBTPU_INIT_ARGS"] = (
       os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
   )

src/maxdiffusion/models/attention_flax.py

Lines changed: 14 additions & 1 deletion

@@ -181,6 +181,7 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
+    mask_padding_tokens: bool = True,
 ) -> jax.Array:
   """TPU Flash Attention"""

@@ -248,6 +249,8 @@ def wrap_flash_attention(query, key, value):
     )
     vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))

+    if not mask_padding_tokens:
+      segment_ids = None
     if attention_kernel == "flash":
       attention_output = vmapped_splash(query, key, value, segment_ids)
     else:
@@ -287,6 +290,8 @@ def ring_scan_body(carry, _):
     (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)

     attention_output = o_final / l_final[..., None]
+  else:
+    raise ValueError("ring attention requires fsdp > 1")

   return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

@@ -427,6 +432,7 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
+    mask_padding_tokens: bool = True,
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
@@ -457,10 +463,12 @@
         flash_block_sizes,
         dtype,
         attention_kernel,
+        mask_padding_tokens=mask_padding_tokens,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
+        mask_padding_tokens=mask_padding_tokens,
     )
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
@@ -591,6 +599,7 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
+      mask_padding_tokens: bool = True,
   ):
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":
@@ -610,6 +619,7 @@ def __init__(
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
     self.quant = quant
+    self.mask_padding_tokens = mask_padding_tokens

   def apply_attention(self, query: Array, key: Array, value: Array):
     return _apply_attention(
@@ -630,6 +640,7 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
+        mask_padding_tokens=self.mask_padding_tokens,
     )

@@ -719,6 +730,7 @@ def __init__(
       qkv_bias: bool = False,
       quant: Quant = None,
       is_self_attention: bool = True,
+      mask_padding_tokens: bool = True,
   ):
     if attention_kernel == "cudnn_flash_te":
       raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -757,6 +769,7 @@ def __init__(
         flash_block_sizes=flash_block_sizes,
         dtype=dtype,
         quant=quant,
+        mask_padding_tokens=mask_padding_tokens,
     )
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
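
The flag only decides whether segment ids reach the splash kernel; the masking those ids imply can be illustrated with a schematic, pure-JAX reference. This is not the splash kernel, and the [batch, seq, heads, head_dim] layout and helper name below are assumptions for illustration only:

  import jax
  import jax.numpy as jnp

  # Schematic reference: a query may only attend to keys in the same segment,
  # so padding tokens (carrying a different segment id) are masked out.
  def segment_masked_attention(q, k, v, q_segment_ids, kv_segment_ids):
    # q, k, v: [batch, seq, heads, head_dim]; segment ids: [batch, seq]
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    same_segment = q_segment_ids[:, None, :, None] == kv_segment_ids[:, None, None, :]
    logits = jnp.where(same_segment, logits, jnp.finfo(logits.dtype).min)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)

  # mask_padding_tokens=False corresponds to skipping the jnp.where above
  # (segment_ids = None in _tpu_flash_attention), so padding attends normally.

The added else/raise in the ring branch makes the case the ring path cannot handle (an fsdp sharding of 1) fail with an explicit error.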

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 5 additions & 0 deletions

@@ -263,6 +263,7 @@ def __init__(
       precision: jax.lax.Precision = None,
       attention: str = "dot_product",
       dropout: float = 0.0,
+      mask_padding_tokens: bool = True,
   ):

     # 1. Self-attention
@@ -283,6 +284,7 @@
         attention_kernel=attention,
         dropout=dropout,
         is_self_attention=True,
+        mask_padding_tokens=mask_padding_tokens
     )

     # 1. Cross-attention
@@ -302,6 +304,7 @@
         attention_kernel=attention,
         dropout=dropout,
         is_self_attention=False,
+        mask_padding_tokens=mask_padding_tokens
     )
     assert cross_attn_norm is True
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -404,6 +407,7 @@ def __init__(
       remat_policy: str = "None",
       names_which_can_be_saved: list = [],
       names_which_can_be_offloaded: list = [],
+      mask_padding_tokens: bool = True,
   ):
     inner_dim = num_attention_heads * attention_head_dim
     out_channels = out_channels or in_channels
@@ -458,6 +462,7 @@ def init_block(rngs):
         precision=precision,
         attention=attention,
         dropout=dropout,
+        mask_padding_tokens=mask_padding_tokens,
     )

     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 0 deletions

@@ -90,6 +90,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   wan_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded
   wan_config["flash_min_seq_length"] = config.flash_min_seq_length
   wan_config["dropout"] = config.dropout
+  wan_config["mask_padding_tokens"] = config.mask_padding_tokens

   # 2. eval_shape - will not use flops or create weights on device
   # thus not using HBM memory.
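
Taken together, the flag takes one hop per file from config to kernel. A condensed trace, using only names that appear in this commit (intermediate constructors are summarized in comments, not quoted):

  # base_wan_14b.yml
  #   mask_padding_tokens: True
  # wan_pipeline.create_model copies it into the model kwargs:
  wan_config["mask_padding_tokens"] = config.mask_padding_tokens
  # transformer_wan and attention_flax thread it through their __init__ arguments:
  #   mask_padding_tokens=mask_padding_tokens
  # attention_flax._tpu_flash_attention finally acts on it:
  if not mask_padding_tokens:
    segment_ids = None  # padding tokens are then no longer masked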
