diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 31ce039ad..2de705a88 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -49,11 +49,14 @@ def print_ssim(pretrained_video_path, posttrained_video_path): pretrained_video = video_processor.preprocess_video(pretrained_video) pretrained_video = np.array(pretrained_video) pretrained_video = np.transpose(pretrained_video, (0, 2, 3, 4, 1)) + pretrained_video = np.uint8(255 * pretrained_video) posttrained_video = load_video(posttrained_video_path[0]) posttrained_video = video_processor.preprocess_video(posttrained_video) posttrained_video = np.array(posttrained_video) posttrained_video = np.transpose(posttrained_video, (0, 2, 3, 4, 1)) + posttrained_video = np.uint8(255 * posttrained_video) + ssim_compare = ssim(pretrained_video[0], posttrained_video[0], multichannel=True, channel_axis=-1, data_range=255) max_logging.log(f"SSIM score after training is {ssim_compare}")