Skip to content

Commit be689df

Browse files
committed
add flag and complete the video generation during eval
1 parent 07bb615 commit be689df

3 files changed

Lines changed: 21 additions & 15 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,3 +298,4 @@ quantization_calibration_method: "absmax"
298298
# Eval model on per eval_every steps. -1 means don't eval.
299299
eval_every: -1
300300
eval_data_dir: ""
301+
enable_generate_video_for_eval: False # This will increase TPU memory usage.

src/maxdiffusion/generate_wan.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,24 +46,24 @@ def upload_video_to_gcs(output_dir: str, video_path: str):
4646
blob = bucket.blob(destination_blob_name)
4747

4848
# Upload the file
49-
print(f"Uploading {source_file_path} to {bucket_name}/{destination_blob_name}...")
49+
max_logging.log(f"Uploading {source_file_path} to {bucket_name}/{destination_blob_name}...")
5050
blob.upload_from_filename(source_file_path)
51-
print(f"Upload complete {source_file_path}.")
51+
max_logging.log(f"Upload complete {source_file_path}.")
5252

5353
except Exception as e:
54-
print(f"An error occurred: {e}")
54+
max_logging.log(f"An error occurred: {e}")
5555

5656
def delete_file(file_path: str):
5757
# Best practice: Check if the file exists before trying to delete it.
5858
if os.path.exists(file_path):
5959
try:
6060
os.remove(file_path)
61-
print(f"Successfully deleted file: {file_path}")
61+
max_logging.log(f"Successfully deleted file: {file_path}")
6262
except OSError as e:
6363
# This catches other issues like permission errors
64-
print(f"Error deleting file '{file_path}': {e}")
64+
max_logging.log(f"Error deleting file '{file_path}': {e}")
6565
else:
66-
print(f"The file '{file_path}' does not exist.")
66+
max_logging.log(f"The file '{file_path}' does not exist.")
6767

6868
jax.config.update("jax_use_shardy_partitioner", True)
6969

@@ -86,12 +86,14 @@ def inference_generate_video(config, pipeline, filename_prefix=""):
8686
guidance_scale=config.guidance_scale,
8787
)
8888

89-
print(f"video {filename_prefix}, compile time: {(time.perf_counter() - s0)}")
89+
max_logging.log(f"video {filename_prefix}, compile time: {(time.perf_counter() - s0)}")
9090
for i in range(len(videos)):
9191
video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
9292
export_to_video(videos[i], video_path, fps=config.fps)
93-
upload_video_to_gcs(config.output_dir, video_path)
94-
delete_file(f"./{video_path}")
93+
if config.output_dir.startswith("gs://"):
94+
upload_video_to_gcs(config.output_dir, video_path)
95+
# Delete local files to avoid storing too many videos
96+
delete_file(f"./{video_path}")
9597
return
9698

9799
def run(config, pipeline=None, filename_prefix=""):
@@ -128,7 +130,8 @@ def run(config, pipeline=None, filename_prefix=""):
128130
video_path = os.path.join(f"{filename_prefix}wan_output_{config.seed}_{i}.mp4")
129131
export_to_video(videos[i], video_path, fps=config.fps)
130132
saved_video_path.append(video_path)
131-
upload_video_to_gcs(config.output_dir, video_path)
133+
if config.output_dir.startswith("gs://"):
134+
upload_video_to_gcs(config.output_dir, video_path)
132135

133136
s0 = time.perf_counter()
134137
videos = pipeline(

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,10 @@ def start_training(self):
152152
# Generate a sample before training to compare against generated sample after training.
153153
pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
154154

155-
# save some memory.
156-
# del pipeline.vae
157-
# del pipeline.vae_cache
155+
if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval):
156+
# save some memory.
157+
del pipeline.vae
158+
del pipeline.vae_cache
158159

159160
mesh = pipeline.mesh
160161
train_data_iterator = self.load_dataset(mesh, is_training=True)
@@ -250,10 +251,11 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
250251
train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)
251252

252253
if self.config.eval_every > 0 and (step + 1) % self.config.eval_every == 0:
254+
if self.config.enable_generate_video_for_eval:
255+
pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state)
256+
inference_generate_video(self.config, pipeline, filename_prefix=f"{step+1}-train_steps-")
253257
# Re-create the iterator each time you start evaluation to reset it
254258
# This assumes your data loading logic can be called to get a fresh iterator.
255-
pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state)
256-
inference_generate_video(self.config, pipeline, filename_prefix=f"{step+1}-train_steps-")
257259
eval_data_iterator = self.load_dataset(mesh, is_training=False)
258260
eval_rng = jax.random.key(self.config.seed + step)
259261
eval_metrics = []

0 commit comments

Comments
 (0)