Commit 6e17c3e

Adding Img2Vid support for WAN 2.1 and WAN 2.2 (#306)
Squashed commit messages:

- WAN Img2Vid Implementation
- Removed randn_tensor function import
- logical_axis rules and attention_sharding_uniform added in config files
- removed attn_mask from FlaxWanAttn call
- fix to prevent load_image_encoder from running for wan 2.2 i2v
- boundary_ratio removed from generate_wan.py
- testing with 720p
- model restored
- attn_mask correction
- transformer corrected in wan 2.2 t2v and config files updated
- revert
- corrected
- import added in wan_checkpointer_test.py
- wan_checkpointer_test.py corrected (×3)
- removed redundant img attn mask
- Fix for multiple videos (×4)
- removed redundant args (×2)
- trying dot attn fix
- reverting fix to see if that was the issue
- fix verified
- updated comments
- Added sharding (×2)
- ruff checks (×2)
- README updated
- sharding
1 parent 2879a65 · commit 6e17c3e

18 files changed: 2,303 additions & 74 deletions

README.md

Lines changed: 20 additions & 0 deletions
````diff
@@ -17,6 +17,7 @@
 [![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml)
 
 # What's new?
+- **`2026/1/15`**: Wan2.1 and Wan2.2 Img2vid generation is now supported
 - **`2025/11/11`**: Wan2.2 txt2vid generation is now supported
 - **`2025/10/10`**: Wan2.1 txt2vid training and generation is now supported.
 - **`2025/10/14`**: NVIDIA DGX Spark Flux support.
@@ -482,19 +483,38 @@ To generate images, run the following command:
 
 Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).
 
+### Text2Vid
+
 ```bash
 HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/
 LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445
 ```
+
+### Img2Vid
+
+```bash
+HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/
+LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_i2v_14b.yml attention="flash" num_inference_steps=30 num_frames=81 width=832 height=480 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=3.0 enable_profiler=True run_name=wan-i2v-inference-testing-480p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445
+```
+
 ## Wan2.2
 
 Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).
 
+### Text2Vid
+
 ```bash
 HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/
 LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_27b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445
 ```
 
+### Img2Vid
+
+```bash
+HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/
+LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_i2v_27b.yml attention="flash" num_inference_steps=30 num_frames=81 width=832 height=480 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=3.0 enable_profiler=True run_name=wan-i2v-inference-testing-480p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445
+```
+
 ## Flux
 
 First make sure you have permissions to access the Flux repos in Huggingface.
````
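The `flash_block_sizes` flag in the commands above is passed as a JSON string on the command line. As a minimal sketch of decoding it (assuming plain `json.loads` parsing, which may differ from maxdiffusion's actual config handling):

```python
import json

# CLI value copied from the Text2Vid command above (truncated to three keys).
raw = '{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048}'

# Decode the JSON string into a dict of block-size integers.
block_sizes = {key: int(value) for key, value in json.loads(raw).items()}
print(block_sizes["block_q"])  # 3024
```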

src/maxdiffusion/checkpointing/wan_checkpointer.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -19,6 +19,8 @@
 from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
 from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1
 from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2
+from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1
+from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2
 from .. import max_logging, max_utils
 import orbax.checkpoint as ocp
 
@@ -59,7 +61,7 @@ def load_diffusers_checkpoint(self):
     raise NotImplementedError
 
   @abstractmethod
-  def load_checkpoint(self, step=None) -> Tuple[Optional[WanPipeline2_1 | WanPipeline2_2], Optional[dict], Optional[int]]:
+  def load_checkpoint(self, step=None) -> Tuple[Optional[WanPipeline2_1 | WanPipeline2_2 | WanPipelineI2V_2_1 | WanPipelineI2V_2_2], Optional[dict], Optional[int]]:
     raise NotImplementedError
 
   @abstractmethod
```
(new file — path not shown in this capture; it defines `WanCheckpointerI2V_2_1`)

Lines changed: 95 additions & 0 deletions

```python
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import jax
import numpy as np
from typing import Optional, Tuple
from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1
from .. import max_logging
import orbax.checkpoint as ocp
from etils import epath
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer


class WanCheckpointerI2V_2_1(WanCheckpointer):

  def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
    if step is None:
      step = self.checkpoint_manager.latest_step()
      max_logging.log(f"Latest WAN checkpoint step: {step}")
      if step is None:
        max_logging.log("No WAN checkpoint found.")
        return None, None
    max_logging.log(f"Loading WAN checkpoint from step {step}")
    metadatas = self.checkpoint_manager.item_metadata(step)
    transformer_metadata = metadatas.wan_state
    abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
    params_restore = ocp.args.PyTreeRestore(
        restore_args=jax.tree.map(
            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
            abstract_tree_structure_params,
        )
    )

    max_logging.log("Restoring WAN checkpoint")
    restored_checkpoint = self.checkpoint_manager.restore(
        directory=epath.Path(self.config.checkpoint_dir),
        step=step,
        args=ocp.args.Composite(
            wan_state=params_restore,
            wan_config=ocp.args.JsonRestore(),
        ),
    )
    max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
    max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}")
    max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}")
    max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
    return restored_checkpoint, step

  def load_diffusers_checkpoint(self):
    pipeline = WanPipelineI2V_2_1.from_pretrained(self.config)
    return pipeline

  def load_checkpoint(self, step=None) -> Tuple[WanPipelineI2V_2_1, Optional[dict], Optional[int]]:
    restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
    opt_state = None
    if restored_checkpoint:
      max_logging.log("Loading WAN pipeline from checkpoint")
      pipeline = WanPipelineI2V_2_1.from_checkpoint(self.config, restored_checkpoint)
      if "opt_state" in restored_checkpoint.wan_state.keys():
        opt_state = restored_checkpoint.wan_state["opt_state"]
    else:
      max_logging.log("No checkpoint found, loading default pipeline.")
      pipeline = self.load_diffusers_checkpoint()

    return pipeline, opt_state, step

  def save_checkpoint(self, train_step, pipeline: WanPipelineI2V_2_1, train_states: dict):
    """Saves the training state and model configurations."""

    def config_to_json(model_or_config):
      return json.loads(model_or_config.to_json_string())

    max_logging.log(f"Saving checkpoint for step {train_step}")
    items = {
        "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
    }
    items["wan_state"] = ocp.args.PyTreeSave(train_states)

    # Save the checkpoint
    self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
    max_logging.log(f"Checkpoint for step {train_step} saved.")
```
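The load-or-fallback flow in `WanCheckpointerI2V_2_1.load_checkpoint` above can be sketched without any maxdiffusion or Orbax dependencies. This is a hedged stand-in: the dicts and string labels below are illustrative, not the real pipeline objects.

```python
# Sketch of load_checkpoint: prefer a restored Orbax checkpoint (and its
# optimizer state, if saved), otherwise fall back to pretrained weights.
def load_or_fallback(restored_checkpoint):
    opt_state = None
    if restored_checkpoint:
        # Stands in for WanPipelineI2V_2_1.from_checkpoint(config, restored_checkpoint)
        pipeline = "from_checkpoint"
        opt_state = restored_checkpoint["wan_state"].get("opt_state")
    else:
        # Stands in for self.load_diffusers_checkpoint()
        pipeline = "from_pretrained"
    return pipeline, opt_state

print(load_or_fallback({"wan_state": {"opt_state": {"step": 3}}}))  # ('from_checkpoint', {'step': 3})
print(load_or_fallback(None))  # ('from_pretrained', None)
```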
(new file — path not shown in this capture; it defines `WanCheckpointerI2V_2_2`)

Lines changed: 114 additions & 0 deletions

```python
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import jax
import numpy as np
from typing import Optional, Tuple
from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2
from .. import max_logging
import orbax.checkpoint as ocp
from etils import epath
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer


class WanCheckpointerI2V_2_2(WanCheckpointer):

  def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
    if step is None:
      step = self.checkpoint_manager.latest_step()
      max_logging.log(f"Latest WAN checkpoint step: {step}")
      if step is None:
        max_logging.log("No WAN checkpoint found.")
        return None, None
    max_logging.log(f"Loading WAN checkpoint from step {step}")
    metadatas = self.checkpoint_manager.item_metadata(step)

    # Handle low_noise_transformer
    low_noise_transformer_metadata = metadatas.low_noise_transformer_state
    abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata)
    low_params_restore = ocp.args.PyTreeRestore(
        restore_args=jax.tree.map(
            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
            abstract_tree_structure_low_params,
        )
    )

    # Handle high_noise_transformer
    high_noise_transformer_metadata = metadatas.high_noise_transformer_state
    abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata)
    high_params_restore = ocp.args.PyTreeRestore(
        restore_args=jax.tree.map(
            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
            abstract_tree_structure_high_params,
        )
    )

    max_logging.log("Restoring WAN 2.2 checkpoint")
    restored_checkpoint = self.checkpoint_manager.restore(
        directory=epath.Path(self.config.checkpoint_dir),
        step=step,
        args=ocp.args.Composite(
            low_noise_transformer_state=low_params_restore,
            high_noise_transformer_state=high_params_restore,
            wan_config=ocp.args.JsonRestore(),
        ),
    )
    max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
    max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}")
    max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}")
    max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}")
    max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}")
    max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
    return restored_checkpoint, step

  def load_diffusers_checkpoint(self):
    pipeline = WanPipelineI2V_2_2.from_pretrained(self.config)
    return pipeline

  def load_checkpoint(self, step=None) -> Tuple[WanPipelineI2V_2_2, Optional[dict], Optional[int]]:
    restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
    opt_state = None
    if restored_checkpoint:
      max_logging.log("Loading WAN pipeline from checkpoint")
      pipeline = WanPipelineI2V_2_2.from_checkpoint(self.config, restored_checkpoint)
      # Check for optimizer state in either transformer
      if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys():
        opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"]
      elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys():
        opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"]
    else:
      max_logging.log("No checkpoint found, loading default pipeline.")
      pipeline = self.load_diffusers_checkpoint()

    return pipeline, opt_state, step

  def save_checkpoint(self, train_step, pipeline: WanPipelineI2V_2_2, train_states: dict):
    """Saves the training state and model configurations."""

    def config_to_json(model_or_config):
      return json.loads(model_or_config.to_json_string())

    max_logging.log(f"Saving checkpoint for step {train_step}")
    items = {
        "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
    }

    items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"])
    items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"])

    # Save the checkpoint
    self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
    max_logging.log(f"Checkpoint for step {train_step} saved.")
```
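Because WAN 2.2 restores two transformers, `WanCheckpointerI2V_2_2.load_checkpoint` above has to decide which saved optimizer state to reuse: the low-noise transformer's state is preferred, then the high-noise one. A hedged, dependency-free sketch of that selection (plain dicts stand in for the restored Orbax pytrees):

```python
# Mirror of the opt_state selection in load_checkpoint: low-noise first,
# then high-noise, else None.
def pick_opt_state(low_state: dict, high_state: dict):
    if "opt_state" in low_state:
        return low_state["opt_state"]
    if "opt_state" in high_state:
        return high_state["opt_state"]
    return None

print(pick_opt_state({"opt_state": "low"}, {"opt_state": "high"}))  # low
print(pick_opt_state({}, {"opt_state": "high"}))  # high
print(pick_opt_state({}, {}))  # None
```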

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
```diff
@@ -29,6 +29,7 @@ log_period: 100
 
 pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
 model_name: wan2.1
+model_type: 'T2V'
 
 # Overrides the transformer from pretrained_model_name_or_path
 wan_transformer_pretrained_model_name_or_path: ''
```

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 1 addition & 0 deletions
```diff
@@ -29,6 +29,7 @@ log_period: 100
 
 pretrained_model_name_or_path: 'Wan-AI/Wan2.2-T2V-A14B-Diffusers'
 model_name: wan2.2
+model_type: 'T2V'
 
 # Overrides the transformer from pretrained_model_name_or_path
 wan_transformer_pretrained_model_name_or_path: ''
```
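With the new `model_type` key in the configs, the pair (`model_name`, `model_type`) is enough to pick the right pipeline class. The mapping below is an illustrative sketch only — the class names are taken from this commit's imports, but the dispatch table itself is not the project's actual factory code:

```python
# Hypothetical dispatch from config keys to pipeline class names.
PIPELINES = {
    ("wan2.1", "T2V"): "WanPipeline2_1",
    ("wan2.1", "I2V"): "WanPipelineI2V_2_1",
    ("wan2.2", "T2V"): "WanPipeline2_2",
    ("wan2.2", "I2V"): "WanPipelineI2V_2_2",
}

def select_pipeline(model_name: str, model_type: str) -> str:
    # Raises KeyError for unsupported combinations.
    return PIPELINES[(model_name, model_type)]

print(select_pipeline("wan2.1", "I2V"))  # WanPipelineI2V_2_1
print(select_pipeline("wan2.2", "T2V"))  # WanPipeline2_2
```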
