diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index f25538631..486c8ba3c 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -298,3 +298,4 @@ quantization_calibration_method: "absmax" # Eval model on per eval_every steps. -1 means don't eval. eval_every: -1 eval_data_dir: "" +enable_generate_video_for_eval: False # This will increase the used TPU memory. diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 519bc8cb3..1dc1789a1 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -15,13 +15,78 @@ from typing import Sequence import jax import time +import os from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline from maxdiffusion import pyconfig, max_logging, max_utils from absl import app from maxdiffusion.utils import export_to_video +from google.cloud import storage + +def upload_video_to_gcs(output_dir: str, video_path: str): + """ + Uploads a local video file to a specified Google Cloud Storage bucket. 
+ """ + try: + path_without_scheme = output_dir.removeprefix("gs://") + parts = path_without_scheme.split('/', 1) + bucket_name = parts[0] + folder_name = parts[1] if len(parts) > 1 else '' + + storage_client = storage.Client() + bucket = storage_client.bucket(bucket_name) + + source_file_path = f"./{video_path}" + destination_blob_name = os.path.join(folder_name, "videos", video_path) + + blob = bucket.blob(destination_blob_name) + + max_logging.log(f"Uploading {source_file_path} to {bucket_name}/{destination_blob_name}...") + blob.upload_from_filename(source_file_path) + max_logging.log(f"Upload complete {source_file_path}.") + + except Exception as e: + max_logging.log(f"An error occurred: {e}") + +def delete_file(file_path: str): + if os.path.exists(file_path): + try: + os.remove(file_path) + max_logging.log(f"Successfully deleted file: {file_path}") + except OSError as e: + max_logging.log(f"Error deleting file '{file_path}': {e}") + else: + max_logging.log(f"The file '{file_path}' does not exist.") jax.config.update("jax_use_shardy_partitioner", True) +def inference_generate_video(config, pipeline, filename_prefix=""): + s0 = time.perf_counter() + prompt = [config.prompt] * config.global_batch_size_to_train_on + negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on + + max_logging.log( + f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}, video: {filename_prefix}" + ) + + videos = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + guidance_scale=config.guidance_scale, + ) + + max_logging.log(f"video {filename_prefix}, compile time: {(time.perf_counter() - s0)}") + for i in range(len(videos)): + video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4" + export_to_video(videos[i], video_path, fps=config.fps) + if 
config.output_dir.startswith("gs://"): + upload_video_to_gcs(config.output_dir, video_path) + # Delete local files to avoid storing too many videos + delete_file(f"./{video_path}") + return def run(config, pipeline=None, filename_prefix=""): print("seed: ", config.seed) @@ -57,6 +122,8 @@ def run(config, pipeline=None, filename_prefix=""): video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4" export_to_video(videos[i], video_path, fps=config.fps) saved_video_path.append(video_path) + if config.output_dir.startswith("gs://"): + upload_video_to_gcs(config.output_dir, video_path) s0 = time.perf_counter() videos = pipeline( diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 6b4f646c4..cc8142159 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -31,6 +31,7 @@ from maxdiffusion.checkpointing.wan_checkpointer import (WanCheckpointer, WAN_CHECKPOINT) from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator) from maxdiffusion.generate_wan import run as generate_wan +from maxdiffusion.generate_wan import inference_generate_video from maxdiffusion.train_utils import (_tensorboard_writer_worker, load_next_batch, _metrics_queue) from maxdiffusion.video_processor import VideoProcessor from maxdiffusion.utils import load_video @@ -151,9 +152,10 @@ def start_training(self): # Generate a sample before training to compare against generated sample after training. pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-") - # save some memory. - del pipeline.vae - del pipeline.vae_cache + if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval): + # save some memory. 
+ del pipeline.vae + del pipeline.vae_cache mesh = pipeline.mesh train_data_iterator = self.load_dataset(mesh, is_training=True) @@ -249,6 +251,9 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) if self.config.eval_every > 0 and (step + 1) % self.config.eval_every == 0: + if self.config.enable_generate_video_for_eval: + pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state) + inference_generate_video(self.config, pipeline, filename_prefix=f"{step+1}-train_steps-") # Re-create the iterator each time you start evaluation to reset it # This assumes your data loading logic can be called to get a fresh iterator. eval_data_iterator = self.load_dataset(mesh, is_training=False)