Skip to content

Commit 0a6fd5c

Browse files
Enhance Inference: Wan Refactor and Kernel Registry
- Refactored generate_wan.py to use InferenceLoader and DiffusionRunner.
- Added ATTENTION_KERNEL_REGISTRY to attention_flax.py to support pluggable custom kernels (e.g. VSA).
1 parent 9aae51d commit 0a6fd5c

2 files changed

Lines changed: 57 additions & 116 deletions

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 44 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,24 @@
1-
# Copyright 2025 Google LLC
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
6-
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License.
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
1416

1517
from typing import Sequence
1618
import jax
1719
import time
1820
import os
1921
import subprocess
20-
from maxdiffusion.checkpointing.wan_checkpointer_2_1 import WanCheckpointer2_1
21-
from maxdiffusion.checkpointing.wan_checkpointer_2_2 import WanCheckpointer2_2
22-
from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p1 import WanCheckpointerI2V_2_1
23-
from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2 import WanCheckpointerI2V_2_2
2422
from maxdiffusion import pyconfig, max_logging, max_utils
2523
from absl import app
2624
from maxdiffusion.utils import export_to_video
@@ -29,6 +27,8 @@
2927
import flax
3028
from maxdiffusion.common_types import WAN2_1, WAN2_2
3129
from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NNXLoraLoader, Wan2_2NNXLoraLoader
30+
from maxdiffusion.inference.loader import InferenceLoader
31+
from maxdiffusion.inference.runner import DiffusionRunner
3232

3333

3434
def upload_video_to_gcs(output_dir: str, video_path: str):
@@ -84,84 +84,6 @@ def get_git_commit_hash():
8484
jax.config.update("jax_use_shardy_partitioner", True)
8585

8686

87-
def call_pipeline(config, pipeline, prompt, negative_prompt):
88-
model_key = config.model_name
89-
model_type = config.model_type
90-
if model_type == "I2V":
91-
image = load_image(config.image_url)
92-
if model_key == WAN2_1:
93-
return pipeline(
94-
prompt=prompt,
95-
image=image,
96-
negative_prompt=negative_prompt,
97-
height=config.height,
98-
width=config.width,
99-
num_frames=config.num_frames,
100-
num_inference_steps=config.num_inference_steps,
101-
guidance_scale=config.guidance_scale,
102-
)
103-
elif model_key == WAN2_2:
104-
return pipeline(
105-
prompt=prompt,
106-
image=image,
107-
negative_prompt=negative_prompt,
108-
height=config.height,
109-
width=config.width,
110-
num_frames=config.num_frames,
111-
num_inference_steps=config.num_inference_steps,
112-
guidance_scale_low=config.guidance_scale_low,
113-
guidance_scale_high=config.guidance_scale_high,
114-
)
115-
else:
116-
raise ValueError(f"Unsupported model_name for I2V in config: {model_key}")
117-
elif model_type == "T2V":
118-
if model_key == WAN2_1:
119-
return pipeline(
120-
prompt=prompt,
121-
negative_prompt=negative_prompt,
122-
height=config.height,
123-
width=config.width,
124-
num_frames=config.num_frames,
125-
num_inference_steps=config.num_inference_steps,
126-
guidance_scale=config.guidance_scale,
127-
)
128-
elif model_key == WAN2_2:
129-
return pipeline(
130-
prompt=prompt,
131-
negative_prompt=negative_prompt,
132-
height=config.height,
133-
width=config.width,
134-
num_frames=config.num_frames,
135-
num_inference_steps=config.num_inference_steps,
136-
guidance_scale_low=config.guidance_scale_low,
137-
guidance_scale_high=config.guidance_scale_high,
138-
)
139-
else:
140-
raise ValueError(f"Unsupported model_name for T2Vin config: {model_key}")
141-
142-
143-
def inference_generate_video(config, pipeline, filename_prefix=""):
144-
s0 = time.perf_counter()
145-
prompt = [config.prompt] * config.global_batch_size_to_train_on
146-
negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on
147-
148-
max_logging.log(
149-
f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}, video: {filename_prefix}"
150-
)
151-
152-
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
153-
154-
max_logging.log(f"video {filename_prefix}, compile time: {(time.perf_counter() - s0)}")
155-
for i in range(len(videos)):
156-
video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
157-
export_to_video(videos[i], video_path, fps=config.fps)
158-
if config.output_dir.startswith("gs://"):
159-
upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path)
160-
# Delete local files to avoid storing too many videos
161-
delete_file(f"./{video_path}")
162-
return
163-
164-
16587
def run(config, pipeline=None, filename_prefix="", commit_hash=None):
16688
model_key = config.model_name
16789
writer = max_utils.initialize_summary_writer(config)
@@ -174,23 +96,22 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
17496
else:
17597
max_logging.log("Could not retrieve Git commit hash.")
17698

99+
loaded_model = None
177100
if pipeline is None:
178-
model_type = config.model_type
179-
if model_key == WAN2_1:
180-
if model_type == "I2V":
181-
checkpoint_loader = WanCheckpointerI2V_2_1(config=config)
182-
else:
183-
checkpoint_loader = WanCheckpointer2_1(config=config)
184-
elif model_key == WAN2_2:
185-
if model_type == "I2V":
186-
checkpoint_loader = WanCheckpointerI2V_2_2(config=config)
187-
else:
188-
checkpoint_loader = WanCheckpointer2_2(config=config)
189-
else:
190-
raise ValueError(f"Unsupported model_name for checkpointer: {model_key}")
191-
pipeline, _, _ = checkpoint_loader.load_checkpoint()
101+
max_logging.log("Initializing InferenceLoader...")
102+
loaded_model = InferenceLoader.load(config)
103+
pipeline = loaded_model["pipeline"]
104+
else:
105+
# If pipeline passed explicitly (e.g. from test), wrap it
106+
# But InferenceLoader logic assumes it creates it.
107+
# We construct a dummy loaded_model dict
108+
loaded_model = {
109+
"pipeline": pipeline,
110+
"mesh": getattr(config, "mesh", None) # Fallback
111+
}
192112

193113
# If LoRA is specified, inject layers and load weights.
114+
# TODO: Move this into InferenceLoader._load_wan eventually
194115
if (
195116
config.enable_lora
196117
and hasattr(config, "lora_config")
@@ -225,17 +146,22 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
225146
scan_layers=config.scan_layers,
226147
dtype=config.weights_dtype,
227148
)
149+
# Update loaded model with modified pipeline
150+
loaded_model["pipeline"] = pipeline
228151

229152
s0 = time.perf_counter()
230153

231-
# Using global_batch_size_to_train_on so not to create more config variables
232-
prompt = [config.prompt] * config.global_batch_size_to_train_on
233-
negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on
154+
max_logging.log("Initializing DiffusionRunner...")
155+
runner = DiffusionRunner(loaded_model, config)
234156

235157
max_logging.log(
236158
f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}"
237159
)
238-
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
160+
161+
# Using global_batch_size_to_train_on logic is handled by Runner/Pipeline mostly now
162+
# But we can override args
163+
164+
videos = runner.run()
239165

240166
max_logging.log("===================== Model details =======================")
241167
max_logging.log(f"model name: {config.model_name}")
@@ -257,9 +183,10 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
257183
saved_video_path.append(video_path)
258184
if config.output_dir.startswith("gs://"):
259185
upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path)
186+
delete_file(f"./{video_path}")
260187

261188
s0 = time.perf_counter()
262-
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
189+
videos = runner.run()
263190
generation_time = time.perf_counter() - s0
264191
max_logging.log(f"generation_time: {generation_time}")
265192
if writer and jax.process_index() == 0:
@@ -272,10 +199,11 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
272199
max_logging.log(f"generation time per video: {generation_time_per_video}")
273200
else:
274201
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")
202+
275203
s0 = time.perf_counter()
276204
if config.enable_profiler:
277205
max_utils.activate_profiler(config)
278-
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
206+
videos = runner.run()
279207
max_utils.deactivate_profiler(config)
280208
generation_time_with_profiler = time.perf_counter() - s0
281209
max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}")
@@ -296,4 +224,4 @@ def main(argv: Sequence[str]) -> None:
296224

297225

298226
if __name__ == "__main__":
299-
app.run(main)
227+
app.run(main)

src/maxdiffusion/models/attention_flax.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,12 @@ def _cudnn_flash_attention(query: Array, key: Array, value: Array, heads: int, m
500500
return _reshape_data_from_cudnn_flash(out)
501501

502502

503+
ATTENTION_KERNEL_REGISTRY = {}
504+
505+
def register_attention_kernel(name: str, func: Callable):
506+
"""Registers a custom attention kernel."""
507+
ATTENTION_KERNEL_REGISTRY[name] = func
508+
503509
def _apply_attention(
504510
query: Array,
505511
key: Array,
@@ -524,6 +530,13 @@ def _apply_attention(
524530
):
525531
"""Routes to different attention kernels."""
526532
_check_attention_inputs(query, key, value)
533+
534+
# Check Registry first
535+
if attention_kernel in ATTENTION_KERNEL_REGISTRY:
536+
return ATTENTION_KERNEL_REGISTRY[attention_kernel](
537+
query, key, value, heads, dim_head, scale, dtype, mesh
538+
)
539+
527540
seq_len_idx = 1
528541
if query.ndim == 4:
529542
seq_len_idx = 2

0 commit comments

Comments (0)