Merged
61 commits
15d242e
wip - wan transformer
jfacevedo-google Mar 9, 2025
30b20bd
Merge branch 'main' into wan
jfacevedo-google Apr 15, 2025
9b63238
adding nnx - wip
jfacevedo-google Apr 16, 2025
ae7a538
Merge branch 'main' into wan
jfacevedo-google Apr 18, 2025
9276f26
wan pipeline wip
jfacevedo-google Apr 18, 2025
120ceb3
wip - vae
jfacevedo-google Apr 22, 2025
4d46776
Merge branch 'wan' of https://github.com/AI-Hypercomputer/maxdiffusio…
jfacevedo-google Apr 22, 2025
cc11bb1
added tests for a couple of wan vae layers.
jfacevedo-google Apr 23, 2025
4e443b8
add unit tests to wan vae padded conv.
jfacevedo-google Apr 23, 2025
aeabe27
wip - test for vae encoder.
jfacevedo-google Apr 24, 2025
0ec4b02
Residual block test
jfacevedo-google Apr 24, 2025
9b42117
add wan vae attention test
jfacevedo-google Apr 24, 2025
4325325
add wan mid block vae test
jfacevedo-google Apr 24, 2025
a5e1e95
finishes vae encoder with matching shapes
jfacevedo-google Apr 28, 2025
efe8528
add cache logic to modules.
jfacevedo-google Apr 28, 2025
cf68754
adds decoder and checks matching resolutions.
jfacevedo-google Apr 29, 2025
cd16f28
run linter
jfacevedo-google Apr 29, 2025
089f8ac
fix unit tests
jfacevedo-google Apr 29, 2025
40d423d
e2e wan vae with weights loading. Still not fully working.
jfacevedo-google May 2, 2025
34ebdbe
debug statements
jfacevedo-google May 7, 2025
04f4909
solves distorted decoded video. Now video is jittery, but frames are ok.
jfacevedo-google May 7, 2025
66146b9
fixes jittery decoder frames in vae.
jfacevedo-google May 8, 2025
d9749e9
Merge branch 'main' into wan_vae_debugging
jfacevedo-google May 8, 2025
4245b24
cleanup unused code.
jfacevedo-google May 8, 2025
c0ba5c1
linting
jfacevedo-google May 8, 2025
bab4d17
remove wan from readme.
jfacevedo-google May 8, 2025
7a8daed
remove unused files.
jfacevedo-google May 8, 2025
b31b4ad
more linter fixes.
jfacevedo-google May 8, 2025
d449d1f
add WanRotaryPosEmbed
jfacevedo-google May 9, 2025
064fc5f
add nnx classes for timestep embeddings and timesteps.
jfacevedo-google May 9, 2025
08444fd
add wan time text embedding layer.
jfacevedo-google May 9, 2025
2499b2d
add fp32 layer norm
jfacevedo-google May 9, 2025
b9b2465
wip - attention for wan.
jfacevedo-google May 13, 2025
1abc00c
wrap up attention.
jfacevedo-google May 13, 2025
4c00085
add transformer block
jfacevedo-google May 13, 2025
440f39c
wan transformer with in/out shapes verified
jfacevedo-google May 13, 2025
0ef8c71
load wan 2.1 transformer weights.
jfacevedo-google May 14, 2025
82b719e
fix rope calculations.
jfacevedo-google May 20, 2025
3267ec9
Merge branch 'main' into wan_transformer
jfacevedo-google May 22, 2025
38bea20
fix gelu block.
jfacevedo-google May 22, 2025
716598b
wip - building pipeline and gen code.
jfacevedo-google May 23, 2025
6973222
initial wan pipeline for txt2vid. Not currently working.
jfacevedo-google May 27, 2025
0731a49
add sharding annotations for vae. Verified transformer correctness fo…
jfacevedo-google May 29, 2025
b7c8ba6
wan pipeline with generation. Correctness is still not verified.
jfacevedo-google May 29, 2025
2388908
use collapse instead of reshape for final activation.
jfacevedo-google May 30, 2025
5cc2e49
implements a working wan 2.1 pipeline.
jfacevedo-google May 30, 2025
5f2434d
fix attention bug for lower frames.
jfacevedo-google Jun 3, 2025
d64e521
reduces memory significantly when loading transformer. Needs clean up.
jfacevedo-google Jun 4, 2025
56f5225
support bs > 1. Issue where all gens except for 1st coming out bad.
jfacevedo-google Jun 4, 2025
87817d0
Merge branch 'main' into wan_transformer
jfacevedo-google Jun 4, 2025
9ee7fd3
improves performance by 14% on v5p.
jfacevedo-google Jun 6, 2025
b84fc34
implements skip layer guidance for better generations.
jfacevedo-google Jun 6, 2025
05f0554
initial commit for wan training
jfacevedo-google Jun 10, 2025
a60d235
working training pipeline on v5p at num_frames=1
jfacevedo-google Jun 11, 2025
b90584c
wan training for single frame + bug fixes.
jfacevedo-google Jun 12, 2025
cd031b8
Merge branch 'main' into wan_training
jfacevedo-google Jun 12, 2025
3bedc5d
lint.
jfacevedo-google Jun 12, 2025
5494644
halves inference time.
jfacevedo-google Jun 13, 2025
fc77dc0
fix some tests.
jfacevedo-google Jun 16, 2025
50a029d
lint
jfacevedo-google Jun 16, 2025
cc2c288
update tests.
jfacevedo-google Jun 16, 2025
66 changes: 66 additions & 0 deletions src/maxdiffusion/checkpointing/wan_checkpointer.py
@@ -0,0 +1,66 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from abc import ABC
from flax import nnx
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
from ..pipelines.wan.wan_pipeline import WanPipeline
from .. import max_logging, max_utils

WAN_CHECKPOINT = "WAN_CHECKPOINT"


class WanCheckpointer(ABC):

  def __init__(self, config, checkpoint_type):
    self.config = config
    self.checkpoint_type = checkpoint_type

    self.checkpoint_manager = create_orbax_checkpoint_manager(
        self.config.checkpoint_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=checkpoint_type,
        dataset_type=config.dataset_type,
    )

  def _create_optimizer(self, model, config, learning_rate):
    learning_rate_scheduler = max_utils.create_learning_rate_schedule(
        learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
    )
    tx = max_utils.create_optimizer(config, learning_rate_scheduler)
    return nnx.Optimizer(model, tx), learning_rate_scheduler

  def load_wan_configs_from_orbax(self, step):
    max_logging.log("Restoring stable diffusion configs")
    if step is None:
      step = self.checkpoint_manager.latest_step()
      if step is None:
        return None

  def load_diffusers_checkpoint(self):
    pipeline = WanPipeline.from_pretrained(self.config)
    return pipeline

  def load_checkpoint(self, step=None):
    model_configs = self.load_wan_configs_from_orbax(step)

    if model_configs:
      raise NotImplementedError("model configs should not exist in orbax")
    else:
      pipeline = self.load_diffusers_checkpoint()

    return pipeline
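For orientation, a minimal usage sketch of this checkpointer. The pyconfig plumbing mirrors generate_wan.py later in this diff; the YAML path below is a placeholder, not a file named in this PR. Note that load_wan_configs_from_orbax always returns None in this version, so load_checkpoint always falls through to the diffusers weights.

# Hypothetical driver, not part of this PR; the config path is a placeholder.
from maxdiffusion import pyconfig
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer, WAN_CHECKPOINT

pyconfig.initialize(["", "path/to/wan_config.yml"])
checkpointer = WanCheckpointer(pyconfig.config, WAN_CHECKPOINT)
# load_wan_configs_from_orbax returns None, so this restores diffusers weights.
pipeline = checkpointer.load_checkpoint()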
@@ -18,25 +18,19 @@ run_name: ''
metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
write_metrics: True

timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
write_timing_metrics: True

gcs_metrics: False
# If true save config to GCS in {base_output_directory}/{run_name}/
save_config_to_gcs: False
log_period: 100

pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'

# Flux params
flux_name: "flux-dev"
max_sequence_length: 512
time_shift: True
base_shift: 0.5
max_shift: 1.15
# offloads t5 encoder after text encoding to save memory.
offload_encoders: True


unet_checkpoint: ''
-revision: 'refs/pr/95'
+revision: ''
# This will convert the weights to this dtype.
# When running inference on TPUv5e, use weights_dtype: 'bfloat16'
weights_dtype: 'bfloat16'
@@ -59,24 +53,9 @@ split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te

flash_block_sizes: {}
-# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
-# flash_block_sizes: {
-#   "block_q" : 1536,
-#   "block_kv_compute" : 1536,
-#   "block_kv" : 1536,
-#   "block_q_dkv" : 1536,
-#   "block_kv_dkv" : 1536,
-#   "block_kv_dkv_compute" : 1536,
-#   "block_q_dq" : 1536,
-#   "block_kv_dq" : 1536
-# }
# GroupNorm groups
norm_num_groups: 32
-
-# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch
-# else they will be loaded from pretrained_model_name_or_path
train_new_unet: False
-
# train text_encoder - Currently not supported for SDXL
train_text_encoder: False
text_encoder_learning_rate: 4.25e-6
@@ -133,15 +112,17 @@ mesh_axes: ['data', 'fsdp', 'tensor']
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
['batch', 'data'],
-['activation_heads', 'fsdp'],
+['activation_batch', ['data','fsdp']],
+['activation_heads', 'tensor'],
+['activation_kv', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
['heads', 'tensor'],
['norm', 'fsdp'],
['conv_batch', ['data','fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
['conv_in', 'fsdp']
]
data_sharding: [['data', 'fsdp', 'tensor']]
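An aside on the logical_axis_rules above: each pair maps a logical axis name used in the model code to one of the physical mesh axes from mesh_axes. A small sketch of that resolution, assuming flax's standard logical-partitioning API rather than maxdiffusion's own wrappers:

# Sketch only; rules abbreviated from the diff above.
from flax import linen as nn

rules = (("batch", "data"), ("embed", "fsdp"), ("heads", "tensor"))
# An array with logical dims ('batch', 'embed') lands on mesh axes ('data', 'fsdp').
spec = nn.logical_to_mesh_axes(("batch", "embed"), rules)
print(spec)  # PartitionSpec('data', 'fsdp')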

@@ -152,8 +133,8 @@ data_sharding: [['data', 'fsdp', 'tensor']]
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_tensor_parallelism: 1
-ici_data_parallelism: -1
-ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
+ici_data_parallelism: 1
+ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

# Dataset
@@ -192,17 +173,19 @@ checkpoint_every: -1
enable_single_replica_ckpt_restoring: False

# Training loop
-learning_rate: 4.e-7
+learning_rate: 1.e-5
scale_lr: False
max_train_samples: -1
# max_train_steps takes priority over num_train_epochs.
-max_train_steps: 200
+max_train_steps: 1500
num_train_epochs: 1
seed: 0
output_dir: 'sdxl-model-finetuned'
per_device_batch_size: 1
+# If global_batch_size % jax.device_count is not 0, use FSDP sharding.
+global_batch_size: 0

-warmup_steps_fraction: 0.0
+warmup_steps_fraction: 0.1
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.

# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
@@ -212,7 +195,7 @@ learning_rate_schedule_steps: -1 # By default the length of the schedule is set
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
-adam_weight_decay: 1.e-2 # AdamW Weight decay
+adam_weight_decay: 0 # AdamW Weight decay
max_grad_norm: 1.0

enable_profiler: False
@@ -222,14 +205,25 @@ skip_first_n_steps_for_profiler: 5
profiler_steps: 10

# Generation parameters
prompt: "A magical castle in the middle of a forest, artistic drawing"
prompt_2: "A magical castle in the middle of a forest, artistic drawing"
negative_prompt: "purple, red"
prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
do_classifier_free_guidance: True
-guidance_scale: 3.5
+height: 480
+width: 832
+num_frames: 81
+guidance_scale: 5.0
+flow_shift: 3.0
+
+# skip layer guidance
+slg_layers: [9]
+slg_start: 0.2
+slg_end: 1.0
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
-num_inference_steps: 50
+num_inference_steps: 30
+fps: 24
+save_final_checkpoint: False

# SDXL Lightning parameters
lightning_from_pt: True
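Two of the new generation settings above benefit from a sketch. slg_start/slg_end bound the fraction of inference steps where skip layer guidance is applied to the blocks listed in slg_layers, and guidance_rescale follows section 3.4 of https://arxiv.org/pdf/2305.08891.pdf. The following illustrates the common formulation and is an assumption about, not a copy of, the WanPipeline internals:

import jax.numpy as jnp

def slg_active(step, num_steps, slg_start, slg_end):
  # With slg_start=0.2, slg_end=1.0, and 30 steps, SLG applies to steps 6..29.
  frac = step / num_steps
  return slg_start <= frac < slg_end

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
  # Section 3.4 of arXiv:2305.08891: rescale the CFG output toward the
  # conditional prediction's per-sample std, then blend with the original.
  axes = tuple(range(1, noise_pred_text.ndim))
  std_text = jnp.std(noise_pred_text, axis=axes, keepdims=True)
  std_cfg = jnp.std(noise_cfg, axis=axes, keepdims=True)
  rescaled = noise_cfg * (std_text / std_cfg)
  return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg

With guidance_rescale: 0.0 as configured above, the blend is a no-op.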
103 changes: 103 additions & 0 deletions src/maxdiffusion/generate_wan.py
@@ -0,0 +1,103 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Sequence
import jax
import time
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
from maxdiffusion import pyconfig, max_logging
from absl import app
from maxdiffusion.utils import export_to_video


def run(config):
  print("seed: ", config.seed)
  pipeline = WanPipeline.from_pretrained(config)
  s0 = time.perf_counter()

  # Skip layer guidance
  slg_layers = config.slg_layers
  slg_start = config.slg_start
  slg_end = config.slg_end
  # If global_batch_size % jax.device_count is not 0, use FSDP sharding.
  global_batch_size = config.global_batch_size
  if global_batch_size != 0:
    batch_multiplier = global_batch_size
  else:
    batch_multiplier = jax.device_count() * config.per_device_batch_size

  prompt = [config.prompt] * batch_multiplier
  negative_prompt = [config.negative_prompt] * batch_multiplier

  max_logging.log(
      f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}"
  )

  videos = pipeline(
      prompt=prompt,
      negative_prompt=negative_prompt,
      height=config.height,
      width=config.width,
      num_frames=config.num_frames,
      num_inference_steps=config.num_inference_steps,
      guidance_scale=config.guidance_scale,
      slg_layers=slg_layers,
      slg_start=slg_start,
      slg_end=slg_end,
  )

  print("compile time: ", (time.perf_counter() - s0))
  for i in range(len(videos)):
    export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)
  s0 = time.perf_counter()
  videos = pipeline(
      prompt=prompt,
      negative_prompt=negative_prompt,
      height=config.height,
      width=config.width,
      num_frames=config.num_frames,
      num_inference_steps=config.num_inference_steps,
      guidance_scale=config.guidance_scale,
      slg_layers=slg_layers,
      slg_start=slg_start,
      slg_end=slg_end,
  )
  print("generation time: ", (time.perf_counter() - s0))
  for i in range(len(videos)):
    export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)

  s0 = time.perf_counter()
  with jax.profiler.trace("/tmp/trace/"):
    videos = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=config.height,
        width=config.width,
        num_frames=config.num_frames,
        num_inference_steps=config.num_inference_steps,
        guidance_scale=config.guidance_scale,
        slg_layers=slg_layers,
        slg_start=slg_start,
        slg_end=slg_end,
    )
  print("generation time: ", (time.perf_counter() - s0))


def main(argv: Sequence[str]) -> None:
  pyconfig.initialize(argv)
  run(pyconfig.config)


if __name__ == "__main__":
  app.run(main)
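The three pipeline invocations in run() follow the usual JAX benchmarking pattern: the first call pays tracing and XLA compilation, the second measures steady-state latency, and the third runs under jax.profiler.trace for inspection. The same pattern in miniature (a sketch, not part of the PR):

import time

def timed(fn, *args, **kwargs):
  # Wall-clock a call; the pipeline returns host-side frames, so the call is
  # already synchronous (bare jitted functions would need jax.block_until_ready).
  t0 = time.perf_counter()
  out = fn(*args, **kwargs)
  return out, time.perf_counter() - t0

# videos, compile_s = timed(pipeline, prompt=prompt, ...)   # includes compilation
# videos, generate_s = timed(pipeline, prompt=prompt, ...)  # steady state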
28 changes: 28 additions & 0 deletions src/maxdiffusion/maxdiffusion_utils.py
@@ -287,6 +287,34 @@ def get_dummy_flux_inputs(config, pipeline, batch_size):
  return (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states)


def get_dummy_wan_inputs(config, pipeline, batch_size):
  latents = pipeline.prepare_latents(
      batch_size,
      vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
      vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
      height=config.height,
      width=config.width,
      num_frames=config.num_frames,
      num_channels_latents=pipeline.transformer.config.in_channels,
  )
  bsz = latents.shape[0]
  prompt_embeds = jax.random.normal(jax.random.key(config.seed), (batch_size, 512, 4096))
  timesteps = jnp.array([0] * bsz, dtype=jnp.int32)
  return (latents, prompt_embeds, timesteps)


def calculate_wan_tflops(config, pipeline, batch_size, rngs, train):
  """
  Calculates WAN transformer TFLOPs.
  batch_size should be per_device_batch_size * jax.local_device_count() or attention's shard_map won't
  cache the compilation when flash is enabled.
  """
  (latents, prompt_embeds, timesteps) = get_dummy_wan_inputs(config, pipeline, batch_size)
  return max_utils.calculate_model_tflops(
      pipeline.transformer,
  )


def calculate_flux_tflops(config, pipeline, batch_size, rngs, train):
  """
  Calculates jflux tflops.
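Per the calculate_wan_tflops docstring, the batch size handed to the dummy-input helper should cover all local devices so flash attention's shard_map compilation is reused. A one-line sketch of that rule (the call-site convention is an assumption, not shown in this diff):

import jax

batch_size = config.per_device_batch_size * jax.local_device_count()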