
Commit 34a9134

support generating video for eval (#239)
* support generating video for eval
* remove redundant note
1 parent 39b227b commit 34a9134

3 files changed: 76 additions & 3 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions

@@ -298,3 +298,4 @@ quantization_calibration_method: "absmax"
 # Eval the model every eval_every steps. -1 means don't eval.
 eval_every: -1
 eval_data_dir: ""
+enable_generate_video_for_eval: False # This will increase the used TPU memory.
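The new flag defaults to False, so eval-time video generation is opt-in. A minimal sketch of turning it on from Python, assuming the MaxText-style pyconfig.initialize(argv) / pyconfig.config pattern used by this repo's entry points (the step count and script name here are made up):

    from maxdiffusion import pyconfig

    # Hypothetical launch arguments: config path plus key=value overrides.
    pyconfig.initialize(
        [
            "train.py",
            "src/maxdiffusion/configs/base_wan_14b.yml",
            "eval_every=200",                       # evaluate every 200 steps
            "enable_generate_video_for_eval=True",  # also render videos at eval (costs TPU memory)
        ]
    )
    config = pyconfig.config
    assert config.enable_generate_video_for_eval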

src/maxdiffusion/generate_wan.py

Lines changed: 67 additions & 0 deletions

@@ -15,13 +15,78 @@
 from typing import Sequence
 import jax
 import time
+import os
 from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
 from maxdiffusion import pyconfig, max_logging, max_utils
 from absl import app
 from maxdiffusion.utils import export_to_video
+from google.cloud import storage
+
+def upload_video_to_gcs(output_dir: str, video_path: str):
+  """
+  Uploads a local video file to a specified Google Cloud Storage bucket.
+  """
+  try:
+    path_without_scheme = output_dir.removeprefix("gs://")
+    parts = path_without_scheme.split('/', 1)
+    bucket_name = parts[0]
+    folder_name = parts[1] if len(parts) > 1 else ''
+
+    storage_client = storage.Client()
+    bucket = storage_client.bucket(bucket_name)
+
+    source_file_path = f"./{video_path}"
+    destination_blob_name = os.path.join(folder_name, "videos", video_path)
+
+    blob = bucket.blob(destination_blob_name)
+
+    max_logging.log(f"Uploading {source_file_path} to {bucket_name}/{destination_blob_name}...")
+    blob.upload_from_filename(source_file_path)
+    max_logging.log(f"Upload complete: {source_file_path}")
+
+  except Exception as e:
+    max_logging.log(f"An error occurred: {e}")
+
+def delete_file(file_path: str):
+  if os.path.exists(file_path):
+    try:
+      os.remove(file_path)
+      max_logging.log(f"Successfully deleted file: {file_path}")
+    except OSError as e:
+      max_logging.log(f"Error deleting file '{file_path}': {e}")
+  else:
+    max_logging.log(f"The file '{file_path}' does not exist.")
 
 jax.config.update("jax_use_shardy_partitioner", True)
 
+def inference_generate_video(config, pipeline, filename_prefix=""):
+  s0 = time.perf_counter()
+  prompt = [config.prompt] * config.global_batch_size_to_train_on
+  negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on
+
+  max_logging.log(
+      f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}, video: {filename_prefix}"
+  )
+
+  videos = pipeline(
+      prompt=prompt,
+      negative_prompt=negative_prompt,
+      height=config.height,
+      width=config.width,
+      num_frames=config.num_frames,
+      num_inference_steps=config.num_inference_steps,
+      guidance_scale=config.guidance_scale,
+  )
+
+  max_logging.log(f"video {filename_prefix}, compile time: {(time.perf_counter() - s0)}")
+  for i in range(len(videos)):
+    video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
+    export_to_video(videos[i], video_path, fps=config.fps)
+    if config.output_dir.startswith("gs://"):
+      upload_video_to_gcs(config.output_dir, video_path)
+      # Delete local files to avoid storing too many videos.
+      delete_file(f"./{video_path}")
+  return
 
 def run(config, pipeline=None, filename_prefix=""):
   print("seed: ", config.seed)
@@ -57,6 +122,8 @@ def run(config, pipeline=None, filename_prefix=""):
     video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
     export_to_video(videos[i], video_path, fps=config.fps)
     saved_video_path.append(video_path)
+    if config.output_dir.startswith("gs://"):
+      upload_video_to_gcs(config.output_dir, video_path)
 
   s0 = time.perf_counter()
   videos = pipeline(
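The bucket/prefix split in upload_video_to_gcs is easy to sanity-check in isolation. A standalone sketch of the same parsing, runnable without a GCS client (split_gcs_uri is a hypothetical helper, not part of this commit):

    import os

    def split_gcs_uri(output_dir: str) -> tuple[str, str]:
      # "gs://my-bucket/runs/wan" -> ("my-bucket", "runs/wan"); mirrors the
      # removeprefix/split logic in upload_video_to_gcs above.
      path_without_scheme = output_dir.removeprefix("gs://")
      parts = path_without_scheme.split("/", 1)
      return parts[0], parts[1] if len(parts) > 1 else ""

    bucket, prefix = split_gcs_uri("gs://my-bucket/runs/wan")
    assert bucket == "my-bucket"
    # The destination object key is then composed the same way as in the diff:
    assert os.path.join(prefix, "videos", "clip.mp4") == "runs/wan/videos/clip.mp4"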

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 8 additions & 3 deletions

@@ -31,6 +31,7 @@
 from maxdiffusion.checkpointing.wan_checkpointer import (WanCheckpointer, WAN_CHECKPOINT)
 from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator)
 from maxdiffusion.generate_wan import run as generate_wan
+from maxdiffusion.generate_wan import inference_generate_video
 from maxdiffusion.train_utils import (_tensorboard_writer_worker, load_next_batch, _metrics_queue)
 from maxdiffusion.video_processor import VideoProcessor
 from maxdiffusion.utils import load_video
@@ -151,9 +152,10 @@ def start_training(self):
     # Generate a sample before training to compare against generated sample after training.
     pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
 
-    # save some memory.
-    del pipeline.vae
-    del pipeline.vae_cache
+    if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval):
+      # save some memory.
+      del pipeline.vae
+      del pipeline.vae_cache
 
     mesh = pipeline.mesh
     train_data_iterator = self.load_dataset(mesh, is_training=True)
@@ -249,6 +251,9 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
       train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)
 
       if self.config.eval_every > 0 and (step + 1) % self.config.eval_every == 0:
+        if self.config.enable_generate_video_for_eval:
+          pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state)
+          inference_generate_video(self.config, pipeline, filename_prefix=f"{step+1}-train_steps-")
         # Re-create the iterator each time you start evaluation to reset it
         # This assumes your data loading logic can be called to get a fresh iterator.
         eval_data_iterator = self.load_dataset(mesh, is_training=False)
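The nnx.merge call is the key step here: during training the transformer is held as a split train state (graphdef, params, rest_of_state), and merging reassembles a callable module carrying the current weights before inference_generate_video runs. A toy sketch of that split/merge round-trip (the Toy module is made up for illustration; only the flax.nnx calls are real API):

    import jax.numpy as jnp
    from flax import nnx

    class Toy(nnx.Module):
      def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(4, 4, rngs=rngs)

      def __call__(self, x):
        return self.linear(x)

    model = Toy(nnx.Rngs(0))
    # Split into graph structure, trainable params, and everything else
    # (`...` is the catch-all filter) -- the same shape of state the trainer holds.
    graphdef, params, rest_of_state = nnx.split(model, nnx.Param, ...)
    # Merging rebuilds a callable module from the latest state; the trainer
    # assigns the result to pipeline.transformer before generating eval videos.
    merged = nnx.merge(graphdef, params, rest_of_state)
    y = merged(jnp.ones((1, 4)))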
