Commit 9716fc5

host offloading (#273)
1 parent f5f212f commit 9716fc5

5 files changed

Lines changed: 48 additions & 7 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 2 additions & 1 deletion
@@ -81,7 +81,8 @@ flash_block_sizes: {
 # "block_kv_dkv" : 2048,
 # "block_kv_dkv_compute" : 2048,
 # "block_q_dq" : 3024,
-# "block_kv_dq" : 2048
+# "block_kv_dq" : 2048,
+# "use_fused_bwd_kernel": False,
 # }
 # GroupNorm groups
 norm_num_groups: 32
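
For reference, the commented-out block above documents the available keys; uncommenting it enables custom flash block sizes. An illustrative partial configuration with the new key, written here as the Python dict that get_flash_block_sizes reads from config.flash_block_sizes (keys mirror the YAML comments in this hunk; the forward-pass keys defined above line 81 of the config are not repeated, and the numeric values are placeholders rather than tuned settings from this commit):

# Illustrative sketch only: keys match the YAML comments shown in the hunk above,
# values are placeholders, and the forward-pass keys are omitted here.
flash_block_sizes_update = {
    "block_kv_dkv": 2048,
    "block_kv_dkv_compute": 2048,
    "block_q_dq": 3024,
    "block_kv_dq": 2048,
    "use_fused_bwd_kernel": False,  # new optional key; when True, the dq blocks can be omitted
}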

src/maxdiffusion/max_utils.py

Lines changed: 8 additions & 2 deletions
@@ -489,6 +489,11 @@ def get_precision(config):
     retval = jax.lax.Precision.HIGHEST
   return retval
 
+def value_or_none(flash_block_sizes, key):
+  if key in flash_block_sizes:
+    return flash_block_sizes[key]
+  else:
+    return None
 
 def get_flash_block_sizes(config):
   """Create custom flash attention BlockSizes."""
@@ -501,8 +506,9 @@ def get_flash_block_sizes(config):
       block_q_dkv=config.flash_block_sizes["block_q_dkv"],
       block_kv_dkv=config.flash_block_sizes["block_kv_dkv"],
       block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
-      block_q_dq=config.flash_block_sizes["block_q_dq"],
-      block_kv_dq=config.flash_block_sizes["block_kv_dq"],
+      block_q_dq=value_or_none(config.flash_block_sizes, "block_q_dq"),
+      block_kv_dq=value_or_none(config.flash_block_sizes, "block_kv_dq"),
+      use_fused_bwd_kernel=value_or_none(config.flash_block_sizes, "use_fused_bwd_kernel")
     )
   return flash_block_sizes
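
The helper makes the two dq block sizes and the new flag optional in the YAML: missing keys become None rather than raising a KeyError. A minimal sketch of its behavior (the cfg dict below is illustrative, not a real config):

# Mirrors the committed helper; `cfg` stands in for config.flash_block_sizes.
def value_or_none(flash_block_sizes, key):
  if key in flash_block_sizes:
    return flash_block_sizes[key]
  return None

cfg = {"block_q_dkv": 3024, "use_fused_bwd_kernel": True}
print(value_or_none(cfg, "use_fused_bwd_kernel"))  # True
print(value_or_none(cfg, "block_q_dq"))            # None -- omitted keys fall back to None
# Functionally equivalent to flash_block_sizes.get(key).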

src/maxdiffusion/models/attention_flax.py

Lines changed: 25 additions & 4 deletions
@@ -174,6 +174,7 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
+    residual_checkpoint_name: str | None = None,
 ) -> jax.Array:
   """TPU Flash Attention"""
 
@@ -213,9 +214,22 @@ def _tpu_flash_attention(
   )
   def wrap_flash_attention(query, key, value):
 
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_sizes.block_q)
-    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv)
-    value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv)
+    uses_fused_kernel = block_sizes.use_fused_bwd_kernel
+    block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv,)
+    block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv,)
+    if uses_fused_kernel:
+      block_q_sizes += (block_sizes.block_q_dkv,)
+      block_kv_sizes += (block_sizes.block_kv_dkv,)
+    else:
+      block_q_sizes += (block_sizes.block_q_dq,)
+      block_kv_sizes += (block_sizes.block_kv_dq,)
+
+    block_q = max(*block_q_sizes)
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
+
+    block_kv = max(*block_kv_sizes)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
+    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
 
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
@@ -237,6 +251,7 @@ def wrap_flash_attention(query, key, value):
         q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
         block_sizes=block_sizes,
         save_residuals=True if attention_kernel == "ring" else False,
+        residual_checkpoint_name=residual_checkpoint_name,
     )
     vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
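
The rewritten padding logic in wrap_flash_attention now pads query and key/value to the largest block used by either the forward or the backward pass, and it consults the dq block sizes only when the fused backward kernel is disabled (they may be None otherwise). A standalone sketch of the selection rule, separate from the repo code:

# Standalone illustration (not the repo function) of the padding-granularity choice
# made in wrap_flash_attention: pad to the largest relevant block size, skipping the
# dq blocks when the fused backward kernel handles dq together with dkv.
def padding_blocks(block_q, block_kv, block_q_dkv, block_kv_dkv,
                   block_q_dq, block_kv_dq, use_fused_bwd_kernel):
  q_sizes = (block_q, block_q_dkv)
  kv_sizes = (block_kv, block_kv_dkv)
  if use_fused_bwd_kernel:
    q_sizes += (block_q_dkv,)
    kv_sizes += (block_kv_dkv,)
  else:
    q_sizes += (block_q_dq,)
    kv_sizes += (block_kv_dq,)
  return max(q_sizes), max(kv_sizes)

print(padding_blocks(3024, 2048, 3024, 2048, None, None, True))   # (3024, 2048)
print(padding_blocks(3024, 2048, 3024, 2048, 3024, 2048, False))  # (3024, 2048)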

@@ -419,6 +434,7 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
+    residual_checkpoint_name: str | None = None,
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
@@ -439,7 +455,7 @@ def _apply_attention(
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, residual_checkpoint_name=residual_checkpoint_name
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
@@ -574,6 +590,7 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
+      residual_checkpoint_name: str | None = None,
   ):
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":
@@ -593,6 +610,7 @@ def __init__(
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
     self.quant = quant
+    self.residual_checkpoint_name = residual_checkpoint_name
 
   def apply_attention(self, query: Array, key: Array, value: Array):
     return _apply_attention(
@@ -613,6 +631,7 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
+        residual_checkpoint_name=self.residual_checkpoint_name,
     )
 
 
@@ -701,6 +720,7 @@ def __init__(
       precision: jax.lax.Precision = None,
      qkv_bias: bool = False,
       quant: Quant = None,
+      residual_checkpoint_name: str | None = None,
   ):
     if attention_kernel == "cudnn_flash_te":
       raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -730,6 +750,7 @@ def __init__(
         flash_block_sizes=flash_block_sizes,
         dtype=dtype,
         quant=quant,
+        residual_checkpoint_name=residual_checkpoint_name,
     )
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.

src/maxdiffusion/models/gradient_checkpoint.py

Lines changed: 8 additions & 0 deletions
@@ -41,6 +41,7 @@ class GradientCheckpointType(Enum):
   MATMUL_WITHOUT_BATCH = auto()
   OFFLOAD_MATMUL_WITHOUT_BATCH = auto()
   CUSTOM = auto()
+  HIDDEN_STATE_WITH_OFFLOAD = auto()
 
   @classmethod
   def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType":
@@ -76,6 +77,13 @@ def to_jax_policy(self, names_which_can_be_saved: list = [], names_which_can_be_
             offload_dst="pinned_host",
         )
         return policy
+      case GradientCheckpointType.HIDDEN_STATE_WITH_OFFLOAD:
+        return jax.checkpoint_policies.save_and_offload_only_these_names(
+            names_which_can_be_saved=[],
+            names_which_can_be_offloaded=["hidden_states","self_attn","cross_attn"],
+            offload_src="device",
+            offload_dst="pinned_host",
+        )
       case GradientCheckpointType.MATMUL_WITHOUT_BATCH:
         return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
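
HIDDEN_STATE_WITH_OFFLOAD keeps intermediates tagged with the listed checkpoint names and moves them to pinned host memory instead of rematerializing them in the backward pass; the names match the "hidden_states" tag added in transformer_wan.py and the attention residual names "self_attn" / "cross_attn". A minimal standalone sketch of the mechanism with plain jax.checkpoint (the toy block function and shapes are assumptions, and actual offload requires a backend that supports the pinned_host memory space, e.g. TPU):

# Minimal sketch of the offload mechanism, independent of the repo's training loop.
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["hidden_states"],
    offload_src="device",
    offload_dst="pinned_host",
)

def block(x, w):
  h = jnp.dot(x, w)
  h = checkpoint_name(h, "hidden_states")  # tagged residual: offloaded to host, not recomputed
  return jnp.sum(jax.nn.gelu(h))

# Gradient of the rematerialized block; tagged residuals are offloaded rather than recomputed
# on backends that support the pinned_host memory space.
grad_fn = jax.grad(jax.checkpoint(block, policy=policy))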

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 5 additions & 0 deletions
@@ -21,6 +21,7 @@
 from jax.sharding import PartitionSpec
 from jax.ad_checkpoint import checkpoint_name
 from flax import nnx
+import flax.linen as nn
 import numpy as np
 from .... import common_types
 from ...modeling_flax_utils import FlaxModelMixin, get_activation
@@ -282,6 +283,7 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        residual_checkpoint_name='self_attn',
     )
 
     # 1. Cross-attention
@@ -300,6 +302,7 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        residual_checkpoint_name='cross_attn',
     )
     assert cross_attn_norm is True
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -335,6 +338,7 @@ def __call__(
         (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
     )
     hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
+    hidden_states = checkpoint_name(hidden_states, "hidden_states")
     encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
 
     # 1. Self-attention
@@ -514,6 +518,7 @@ def __call__(
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
   ) -> Union[jax.Array, Dict[str, jax.Array]]:
+    hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
     batch_size, _, num_frames, height, width = hidden_states.shape
     p_t, p_h, p_w = self.config.patch_size
     post_patch_num_frames = num_frames // p_t
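
The new nn.with_logical_constraint call pins only the leading dimension of the 5-D latent to the logical "batch" axis and leaves the remaining dimensions unconstrained; which mesh axis "batch" resolves to depends on the logical axis rules in effect. A small sketch of the mapping (the ("batch", "data") rule is an assumption for illustration, not taken from the repo's sharding config):

# Sketch of how the logical spec resolves to a physical PartitionSpec under assumed rules.
import flax.linen as nn

rules = (("batch", "data"),)  # assumed rule: logical "batch" -> mesh axis "data"
spec = nn.logical_to_mesh_axes(("batch", None, None, None, None), rules)
print(spec)  # PartitionSpec('data', None, None, None, None)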
