3131from maxdiffusion .input_pipeline .input_pipeline_interface import (make_data_iterator )
3232from maxdiffusion .generate_wan import run as generate_wan
3333from maxdiffusion .train_utils import (_tensorboard_writer_worker , load_next_batch , _metrics_queue )
34+ from ..video_processor import VideoProcessor
35+ from ..utils import load_video
36+ from skimage .metrics import structural_similarity as ssim
3437
3538
3639def generate_sample (config , pipeline , filename_prefix ):
3740 """
3841 Generates a video to validate training did not corrupt the model
3942 """
40- generate_wan (config , pipeline , filename_prefix )
43+ return generate_wan (config , pipeline , filename_prefix )
44+
45+
46+ def print_ssim (pretrained_video_path , posttrained_video_path ):
47+ video_processor = VideoProcessor ()
48+ pretrained_video = load_video (pretrained_video_path [0 ])
49+ pretrained_video = video_processor .preprocess_video (pretrained_video )
50+ pretrained_video = np .array (pretrained_video )
51+ pretrained_video = np .transpose (pretrained_video , (0 , 2 , 3 , 4 , 1 ))
52+
53+ posttrained_video = load_video (posttrained_video_path [0 ])
54+ posttrained_video = video_processor .preprocess_video (posttrained_video )
55+ posttrained_video = np .array (posttrained_video )
56+ posttrained_video = np .transpose (posttrained_video , (0 , 2 , 3 , 4 , 1 ))
57+ ssim_compare = ssim (pretrained_video [0 ], posttrained_video [0 ], multichannel = True , channel_axis = - 1 , data_range = 255 )
58+
59+ max_logging .log (f"SSIM score after training is { ssim_compare } " )
4160
4261
4362class WanTrainer (WanCheckpointer ):
@@ -105,7 +124,7 @@ def start_training(self):
105124 # del pipeline.vae
106125
107126 # Generate a sample before training to compare against generated sample after training.
108- generate_sample (self .config , pipeline , filename_prefix = "pre-training-" )
127+ pretrained_video_path = generate_sample (self .config , pipeline , filename_prefix = "pre-training-" )
109128 mesh = pipeline .mesh
110129 data_iterator = self .load_dataset (mesh )
111130
@@ -115,7 +134,12 @@ def start_training(self):
115134 pipeline .scheduler_state = scheduler_state
116135
117136 optimizer , learning_rate_scheduler = self ._create_optimizer (pipeline .transformer , self .config , 1e-5 )
118- self .training_loop (pipeline , optimizer , learning_rate_scheduler , data_iterator )
137+
138+ # Returns pipeline with trained transformer state
139+ pipeline = self .training_loop (pipeline , optimizer , learning_rate_scheduler , data_iterator )
140+
141+ posttrained_video_path = generate_sample (self .config , pipeline , filename_prefix = "post-training-" )
142+ print_ssim (pretrained_video_path , posttrained_video_path )
119143
120144 def training_loop (self , pipeline , optimizer , learning_rate_scheduler , data_iterator ):
121145
@@ -189,8 +213,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
189213 # load new state for trained tranformer
190214 graphdef , _ , rest_of_state = nnx .split (pipeline .transformer , nnx .Param , ...)
191215 pipeline .transformer = nnx .merge (graphdef , state [0 ], rest_of_state )
192-
193- generate_sample (self .config , pipeline , filename_prefix = "post-training-" )
216+ return pipeline
194217
195218
196219def train_step (state , graphdef , scheduler_state , data , rng , scheduler ):
0 commit comments