diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py
index 8486c79d5..fc089337c 100644
--- a/src/maxdiffusion/generate_wan.py
+++ b/src/maxdiffusion/generate_wan.py
@@ -16,7 +16,7 @@ import jax
 import time
 
 from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
-from maxdiffusion import pyconfig, max_logging
+from maxdiffusion import pyconfig, max_logging, max_utils
 from absl import app
 from maxdiffusion.utils import export_to_video
 
@@ -59,8 +59,12 @@ def run(config, pipeline=None, filename_prefix=""):
   )
   print("compile time: ", (time.perf_counter() - s0))
 
+  saved_video_path = []
   for i in range(len(videos)):
-    export_to_video(videos[i], f"{filename_prefix}wan_output_{config.seed}_{i}.mp4", fps=config.fps)
+    video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
+    export_to_video(videos[i], video_path, fps=config.fps)
+    saved_video_path.append(video_path)
+
   s0 = time.perf_counter()
   videos = pipeline(
       prompt=prompt,
@@ -74,12 +78,11 @@ def run(config, pipeline=None, filename_prefix=""):
       slg_start=slg_start,
       slg_end=slg_end,
   )
-  print("generation time: ", (time.perf_counter() - s0))
-  for i in range(len(videos)):
-    export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)
+  print("compile time: ", (time.perf_counter() - s0))
 
   s0 = time.perf_counter()
-  with jax.profiler.trace("/tmp/trace/"):
+  if config.enable_profiler:
+    max_utils.activate_profiler(config)
   videos = pipeline(
       prompt=prompt,
       negative_prompt=negative_prompt,
@@ -92,7 +95,9 @@ def run(config, pipeline=None, filename_prefix=""):
       slg_start=slg_start,
       slg_end=slg_end,
   )
-  print("generation time: ", (time.perf_counter() - s0))
+  max_utils.deactivate_profiler(config)
+  print("generation time: ", (time.perf_counter() - s0))
+  return saved_video_path
 
 
 def main(argv: Sequence[str]) -> None:
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index 6af138bc7..31ce039ad 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -24,20 +24,39 @@ import jax.numpy as jnp
 import jax
 from flax import nnx
-from ..schedulers import FlaxFlowMatchScheduler
+from maxdiffusion.schedulers import FlaxFlowMatchScheduler
 from flax.linen import partitioning as nn_partitioning
-from .. import max_utils, max_logging, train_utils
-from ..checkpointing.wan_checkpointer import (WanCheckpointer, WAN_CHECKPOINT)
+from maxdiffusion import max_utils, max_logging, train_utils
+from maxdiffusion.checkpointing.wan_checkpointer import (WanCheckpointer, WAN_CHECKPOINT)
 from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator)
 from maxdiffusion.generate_wan import run as generate_wan
 from maxdiffusion.train_utils import (_tensorboard_writer_worker, load_next_batch, _metrics_queue)
+from maxdiffusion.video_processor import VideoProcessor
+from maxdiffusion.utils import load_video
+from skimage.metrics import structural_similarity as ssim
 
 
 def generate_sample(config, pipeline, filename_prefix):
   """
   Generates a video to validate training did not corrupt the model
   """
-  generate_wan(config, pipeline, filename_prefix)
+  return generate_wan(config, pipeline, filename_prefix)
+
+
+def print_ssim(pretrained_video_path, posttrained_video_path):
+  video_processor = VideoProcessor()
+  pretrained_video = load_video(pretrained_video_path[0])
+  pretrained_video = video_processor.preprocess_video(pretrained_video)
+  pretrained_video = np.array(pretrained_video)
+  pretrained_video = np.transpose(pretrained_video, (0, 2, 3, 4, 1))
+
+  posttrained_video = load_video(posttrained_video_path[0])
+  posttrained_video = video_processor.preprocess_video(posttrained_video)
+  posttrained_video = np.array(posttrained_video)
+  posttrained_video = np.transpose(posttrained_video, (0, 2, 3, 4, 1))
+  ssim_compare = ssim(pretrained_video[0], posttrained_video[0], multichannel=True, channel_axis=-1, data_range=255)
+
+  max_logging.log(f"SSIM score after training is {ssim_compare}")
 
 
 class WanTrainer(WanCheckpointer):
@@ -105,7 +124,7 @@ def start_training(self):
     # del pipeline.vae
 
     # Generate a sample before training to compare against generated sample after training.
-    generate_sample(self.config, pipeline, filename_prefix="pre-training-")
+    pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
 
     mesh = pipeline.mesh
     data_iterator = self.load_dataset(mesh)
@@ -115,7 +134,12 @@ def start_training(self):
     pipeline.scheduler_state = scheduler_state
 
     optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5)
-    self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator)
+
+    # Returns pipeline with trained transformer state
+    pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator)
+
+    posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
+    print_ssim(pretrained_video_path, posttrained_video_path)
 
   def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator):
@@ -189,8 +213,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
     # load new state for trained tranformer
     graphdef, _, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)
     pipeline.transformer = nnx.merge(graphdef, state[0], rest_of_state)
-
-    generate_sample(self.config, pipeline, filename_prefix="post-training-")
+    return pipeline
 
 
 def train_step(state, graphdef, scheduler_state, data, rng, scheduler):
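
As a quick way to sanity-check the SSIM comparison this patch introduces, the sketch below recomputes it outside the trainer. It is not the patch's own code path: it assumes imageio (with an ffmpeg or pyav backend) for decoding the exported mp4 files instead of maxdiffusion's load_video/VideoProcessor, it reports a mean per-frame SSIM rather than the single structural_similarity call over the whole (frames, height, width, channels) volume that print_ssim makes, and the file names are placeholders that depend on config.seed.

# Standalone sanity check for the SSIM comparison, decoding the exported mp4s
# with imageio (requires imageio plus imageio-ffmpeg or pyav) instead of
# maxdiffusion's load_video/VideoProcessor.
import imageio.v3 as iio
import numpy as np
from skimage.metrics import structural_similarity as ssim


def mean_frame_ssim(pre_path, post_path):
  """Returns the mean per-frame SSIM between two videos of the same resolution."""
  scores = [
      # Frames decode as uint8 RGB arrays of shape (height, width, 3).
      ssim(pre_frame, post_frame, channel_axis=-1, data_range=255)
      for pre_frame, post_frame in zip(iio.imiter(pre_path), iio.imiter(post_path))
  ]
  return float(np.mean(scores))


if __name__ == "__main__":
  # Illustrative file names; the trainer writes
  # f"{filename_prefix}wan_output_{config.seed}_{i}.mp4" for each generated video.
  score = mean_frame_ssim("pre-training-wan_output_0_0.mp4", "post-training-wan_output_0_0.mp4")
  print(f"SSIM score after training is {score}")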