3 changes: 1 addition & 2 deletions src/maxdiffusion/checkpointing/wan_checkpointer.py
@@ -15,7 +15,6 @@
"""

from abc import ABC
-from flax import nnx
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
from ..pipelines.wan.wan_pipeline import WanPipeline
from .. import max_logging, max_utils
@@ -42,7 +41,7 @@ def _create_optimizer(self, model, config, learning_rate):
learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
)
tx = max_utils.create_optimizer(config, learning_rate_scheduler)
-return nnx.Optimizer(model, tx), learning_rate_scheduler
+return tx, learning_rate_scheduler

def load_wan_configs_from_orbax(self, step):
max_logging.log("Restoring stable diffusion configs")
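A minimal sketch of the new contract (the call site is assumed, not shown in this diff): _create_optimizer now returns the bare optax transform, so a caller that previously received an nnx.Optimizer would construct it itself.

# assumed caller code; model, config, learning_rate come from the surrounding trainer
tx, learning_rate_scheduler = self._create_optimizer(model, config, learning_rate)
optimizer = nnx.Optimizer(model, tx)  # previously built inside _create_optimizer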
17 changes: 14 additions & 3 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -54,6 +54,7 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
+flash_min_seq_length: 4096

flash_block_sizes: {}
# Use on v6e
@@ -126,15 +127,17 @@ mesh_axes: ['data', 'fsdp', 'tensor']
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
['batch', 'data'],
+['activation_batch', 'data'],
+['activation_length', 'fsdp'],
+
['activation_heads', 'tensor'],
-['activation_batch', 'data'],
['mlp','tensor'],
['embed','fsdp'],
['heads', 'tensor'],
['norm', 'tensor'],
['conv_batch', ['data','fsdp']],
['out_channels', 'tensor'],
['conv_in', 'fsdp'],
['conv_out', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor']]

@@ -182,6 +185,14 @@ transform_images_num_proc: 4
reuse_example_batch: False
enable_data_shuffling: True

+# Defines the type of gradient checkpointing to enable.
+# NONE - no gradient checkpointing
+# FULL - full gradient checkpointing, wherever possible (minimum memory usage)
+# MATMUL_WITHOUT_BATCH - gradient checkpointing for every linear/matmul operation,
+# except those that involve the batch dimension; all attention and projection
+# layers are checkpointed, but not the backward pass with respect to the parameters
+remat_policy: "NONE"
+
# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
# enables one replica to read the ckpt then broadcast to the rest
@@ -196,7 +207,7 @@ max_train_steps: 1500
num_train_epochs: 1
seed: 0
output_dir: 'sdxl-model-finetuned'
-per_device_batch_size: 1
+per_device_batch_size: 1.0
# If global_batch_size % jax.device_count is not 0, use FSDP sharding.
global_batch_size: 0

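A hedged sketch of what the float-valued per_device_batch_size implies (the derivation is an assumption based on the pre-change generate_wan.py code below, not something this diff states):

import jax

per_device_batch_size = 1.0  # from this config
# assumed derivation: fractional values below 1.0 can express a global batch
# smaller than the device count
global_batch_size_to_train_on = max(1, int(jax.device_count() * per_device_batch_size))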
12 changes: 3 additions & 9 deletions src/maxdiffusion/generate_wan.py
@@ -29,15 +29,9 @@ def run(config, pipeline=None, filename_prefix=""):
pipeline = WanPipeline.from_pretrained(config)
s0 = time.perf_counter()

-# If global_batch_size % jax.device_count is not 0, use FSDP sharding.
-global_batch_size = config.global_batch_size
-if global_batch_size != 0:
-  batch_multiplier = global_batch_size
-else:
-  batch_multiplier = jax.device_count() * config.per_device_batch_size
-
-prompt = [config.prompt] * batch_multiplier
-negative_prompt = [config.negative_prompt] * batch_multiplier
+# Using global_batch_size_to_train_on so as not to create more config variables
+prompt = [config.prompt] * config.global_batch_size_to_train_on
+negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on

max_logging.log(
f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}"
7 changes: 5 additions & 2 deletions src/maxdiffusion/input_pipeline/_tfds_data_processing.py
@@ -117,13 +117,16 @@ def make_tfrecord_iterator(
check out preparation script
maxdiffusion/pedagogical_examples/to_tfrecords.py
"""

# set load_tfrecord_cached to True in config to use the pre-processed tfrecord dataset.
# pedagogical_examples/dataset_tf_cache_to_tfrecord.py converts a tf preprocessed dataset to tfrecord.
# The dataset cache in the github runner test doesn't contain all the features since it's shared; use the default tfrecord iterator.

+# checks that the dataset path is valid. In the case of gcs, the existence of the dir is not checked.
+is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location)

if (
config.cache_latents_text_encoder_outputs
-and os.path.isdir(config.dataset_save_location)
+and is_dataset_dir_valid
and "load_tfrecord_cached" in config.get_keys()
and config.load_tfrecord_cached
):
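A short illustration of why the new guard treats GCS paths specially (illustrative paths, not from the diff): os.path.isdir only understands local filesystem paths, so it returns False for bucket URIs even when they exist.

import os

os.path.isdir("gs://my-bucket/dataset")  # False: not a local path, even if the bucket exists
"gs://" in "gs://my-bucket/dataset"      # True: GCS locations are accepted unchecked
os.path.isdir("/tmp/dataset")            # local paths are still verified on disk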
48 changes: 19 additions & 29 deletions src/maxdiffusion/models/attention_flax.py
@@ -18,6 +18,7 @@
import flax.linen as nn
from flax import nnx
import jax
+from jax.ad_checkpoint import checkpoint_name
from jax.sharding import PartitionSpec
import jax.numpy as jnp
from jax.experimental import shard_map
@@ -187,30 +188,6 @@ def _tpu_flash_attention(
value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
-flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
-axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
-named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
-
-shard_head_size = mesh.shape["tensor"]
-
-@functools.partial(
-    jax.jit,
-    static_argnames=["multi_head_mask", "shard_head_size"],
-)
-def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
-  splash_kernel = splash_attention_kernel.make_splash_mha(
-      mask=multi_head_mask,
-      head_shards=shard_head_size,  # the sizes of the axis is sharding over heads
-      q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
-      block_sizes=block_sizes,
-  )
-  return splash_kernel
-
-mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-
-multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
-splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
-segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)

@functools.partial(
shard_map.shard_map,
@@ -219,12 +196,21 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
q_axis_names,
kv_axis_names,
kv_axis_names,
-segment_axis_names_splash_kernel,
),
out_specs=q_axis_names,
check_rep=False,
)
-def wrap_flash_attention(query, key, value, splash_kernel):
+def wrap_flash_attention(query, key, value):
+  mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+  multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+  # make_splash_mha is wrapped in shard_map, and the seq and head axes are already
+  # sharded according to in_specs, so head_shards=1 and q_seq_shards=1 here.
+  splash_kernel = splash_attention_kernel.make_splash_mha(
+      mask=multi_head_mask,
+      head_shards=1,  # size of the axis sharded over heads
+      q_seq_shards=1,  # size of the axis sharded over seq_len
+      block_sizes=block_sizes,
+  )
attention_output = jax.vmap(splash_kernel)(query, key, value)
return attention_output

@@ -236,7 +222,7 @@ def wrap_flash_attention(query, key, value, splash_kernel):
"Warning, batch dimension should be shardable among the devices in data and fsdp"
f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
)
-x = wrap_flash_attention(query, key, value, splash_kernel)
+x = wrap_flash_attention(query, key, value)
x = x[:, :, :query_seq_len, :kv_size]
x = _reshape_heads_to_head_dim(x)
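A minimal sketch of the shard_map behavior the new comment relies on (toy mesh and function, not the PR's kernel): inside shard_map the body sees only its local shard, which is why make_splash_mha can now be built with head_shards=1 and q_seq_shards=1.

import functools
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import shard_map

mesh = Mesh(jax.devices(), axis_names=("tensor",))

@functools.partial(shard_map.shard_map, mesh=mesh, in_specs=P("tensor"), out_specs=P("tensor"))
def body(x):
  # x is the per-device shard, not the global array
  return x * 2

out = body(jnp.arange(jax.device_count() * 2.0))  # each device doubles only its own slice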

@@ -632,7 +618,7 @@
use_memory_efficient_attention: bool = False,
split_head_dim: bool = False,
attention_kernel: str = "flash",
-flash_min_seq_length: int = 4096,
+flash_min_seq_length: int = 0,
flash_block_sizes: BlockSizes = None,
mesh: jax.sharding.Mesh = None,
dtype: jnp.dtype = jnp.float32,
@@ -809,12 +795,16 @@ def __call__(
query_proj = _unflatten_heads(query_proj, self.heads)
key_proj = _unflatten_heads(key_proj, self.heads)
value_proj = _unflatten_heads(value_proj, self.heads)
+# output of _unflatten_heads: (batch, heads, seq_len, head_dim)
query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)

+query_proj = checkpoint_name(query_proj, "query_proj")
+key_proj = checkpoint_name(key_proj, "key_proj")
+value_proj = checkpoint_name(value_proj, "value_proj")
attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)

attn_output = attn_output.astype(dtype=dtype)

+attn_output = checkpoint_name(attn_output, "attn_output")
hidden_states = self.proj_attn(attn_output)
return hidden_states

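A hedged sketch of how the checkpoint_name tags added above pair with a rematerialization policy (toy function and shapes; the real wiring is in gradient_checkpoint.py below): activations tagged with a name can be saved or offloaded by name instead of being recomputed in the backward pass.

import functools
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint, checkpoint_name
from jax import checkpoint_policies as cp

policy = cp.save_and_offload_only_these_names(
    names_which_can_be_saved=["query_proj", "key_proj", "value_proj", "attn_output"],
    names_which_can_be_offloaded=[],
    offload_src="device",
    offload_dst="pinned_host",
)

@functools.partial(checkpoint, policy=policy)
def attn_block(x, w):
  q = checkpoint_name(x @ w, "query_proj")  # saved under the policy, not rematerialized
  return jnp.sin(q).sum()

dx = jax.grad(attn_block)(jnp.ones((4, 4)), jnp.ones((4, 4)))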
93 changes: 93 additions & 0 deletions src/maxdiffusion/models/gradient_checkpoint.py
@@ -0,0 +1,93 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from enum import Enum, auto
from typing import Optional

import jax
from jax import checkpoint_policies as cp
from flax import nnx

SKIP_GRADIENT_CHECKPOINT_KEY = "skip"


# This class only works with NNX modules.
class GradientCheckpointType(Enum):
"""
Defines the type of the gradient checkpoint we will have

NONE - means no gradient checkpoint
FULL - means full gradient checkpoint, wherever possible (minimum memory usage)
MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation,
except for ones that involve batch dimension - that means that all attention and projection
layers will have gradient checkpoint, but not the backward with respect to the parameters
"""

NONE = auto()
FULL = auto()
MATMUL_WITHOUT_BATCH = auto()
ATTN = auto()

@classmethod
def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType":
"""
Constructs the gradient checkpoint type from a string

Args:
s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None.

Returns:
GradientCheckpointType: The policy that corresponds to the string
"""
if s is None:
s = "none"
return GradientCheckpointType[s.upper()]

def to_jax_policy(self):
"""
Converts the gradient checkpoint type to a jax policy
"""
match self:
case GradientCheckpointType.NONE:
return SKIP_GRADIENT_CHECKPOINT_KEY
case GradientCheckpointType.FULL:
return None
case GradientCheckpointType.ATTN:
return cp.save_and_offload_only_these_names(
names_which_can_be_saved=[], names_which_can_be_offloaded=[], offload_src="device", offload_dst="pinned_host"
)
case GradientCheckpointType.MATMUL_WITHOUT_BATCH:
return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims

def apply(self, module: nnx.Module) -> nnx.Module:
"""
Applies a gradient checkpoint policy to a module
if no policy is needed, it will return the module as is

Args:
module (nn.Module): the module to apply the policy to

Returns:
nn.Module: the module with the policy applied
"""
policy = self.to_jax_policy()
if policy == SKIP_GRADIENT_CHECKPOINT_KEY:
return module
return nnx.remat( # pylint: disable=invalid-name
module,
prevent_cse=False,
policy=policy,
)
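A usage sketch for the new module (the trainer wiring is assumed; from_str and apply are defined above): the remat_policy string added to base_wan_14b.yml maps directly onto this enum.

from maxdiffusion.models.gradient_checkpoint import GradientCheckpointType

remat_policy = "matmul_without_batch"  # would come from the config's remat_policy key
policy = GradientCheckpointType.from_str(remat_policy)
# model would be an nnx.Module from the pipeline; apply is a no-op for "none"
# model = policy.apply(model)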