Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions src/maxdiffusion/generate_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import jax
import time
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
from maxdiffusion import pyconfig, max_logging
from maxdiffusion import pyconfig, max_logging, max_utils
from absl import app
from maxdiffusion.utils import export_to_video

Expand Down Expand Up @@ -59,8 +59,12 @@ def run(config, pipeline=None, filename_prefix=""):
)

print("compile time: ", (time.perf_counter() - s0))
saved_video_path = []
for i in range(len(videos)):
export_to_video(videos[i], f"{filename_prefix}wan_output_{config.seed}_{i}.mp4", fps=config.fps)
video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
export_to_video(videos[i], video_path, fps=config.fps)
saved_video_path.append(video_path)

s0 = time.perf_counter()
videos = pipeline(
prompt=prompt,
Expand All @@ -74,12 +78,11 @@ def run(config, pipeline=None, filename_prefix=""):
slg_start=slg_start,
slg_end=slg_end,
)
print("generation time: ", (time.perf_counter() - s0))
for i in range(len(videos)):
export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)
print("compile time: ", (time.perf_counter() - s0))

s0 = time.perf_counter()
with jax.profiler.trace("/tmp/trace/"):
if config.enable_profiler:
max_utils.activate_profiler(config)
videos = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
Expand All @@ -92,7 +95,9 @@ def run(config, pipeline=None, filename_prefix=""):
slg_start=slg_start,
slg_end=slg_end,
)
print("generation time: ", (time.perf_counter() - s0))
max_utils.deactivate_profiler(config)
print("generation time: ", (time.perf_counter() - s0))
return saved_video_path


def main(argv: Sequence[str]) -> None:
Expand Down
33 changes: 28 additions & 5 deletions src/maxdiffusion/trainers/wan_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,32 @@
from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator)
from maxdiffusion.generate_wan import run as generate_wan
from maxdiffusion.train_utils import (_tensorboard_writer_worker, load_next_batch, _metrics_queue)
from ..video_processor import VideoProcessor
Comment thread
coolkp marked this conversation as resolved.
Outdated
from ..utils import load_video
from skimage.metrics import structural_similarity as ssim
Comment thread
coolkp marked this conversation as resolved.


def generate_sample(config, pipeline, filename_prefix):
"""
Generates a video to validate training did not corrupt the model
"""
generate_wan(config, pipeline, filename_prefix)
return generate_wan(config, pipeline, filename_prefix)


def print_ssim(pretrained_video_path, posttrained_video_path):
video_processor = VideoProcessor()
pretrained_video = load_video(pretrained_video_path[0])
pretrained_video = video_processor.preprocess_video(pretrained_video)
pretrained_video = np.array(pretrained_video)
pretrained_video = np.transpose(pretrained_video, (0, 2, 3, 4, 1))

posttrained_video = load_video(posttrained_video_path[0])
posttrained_video = video_processor.preprocess_video(posttrained_video)
posttrained_video = np.array(posttrained_video)
posttrained_video = np.transpose(posttrained_video, (0, 2, 3, 4, 1))
ssim_compare = ssim(pretrained_video[0], posttrained_video[0], multichannel=True, channel_axis=-1, data_range=255)

max_logging.log(f"SSIM score after training is {ssim_compare}")


class WanTrainer(WanCheckpointer):
Expand Down Expand Up @@ -105,7 +124,7 @@ def start_training(self):
# del pipeline.vae

# Generate a sample before training to compare against generated sample after training.
generate_sample(self.config, pipeline, filename_prefix="pre-training-")
pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
mesh = pipeline.mesh
data_iterator = self.load_dataset(mesh)

Expand All @@ -115,7 +134,12 @@ def start_training(self):
pipeline.scheduler_state = scheduler_state

optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5)
self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator)

# Returns pipeline with trained transformer state
pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator)

posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
print_ssim(pretrained_video_path, posttrained_video_path)

def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator):

Expand Down Expand Up @@ -189,8 +213,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
# load new state for trained tranformer
graphdef, _, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)
pipeline.transformer = nnx.merge(graphdef, state[0], rest_of_state)

generate_sample(self.config, pipeline, filename_prefix="post-training-")
return pipeline


def train_step(state, graphdef, scheduler_state, data, rng, scheduler):
Expand Down
Loading