Skip to content

Commit 07bb615

Browse files
committed
add upload video
1 parent 64a404d commit 07bb615

2 files changed

Lines changed: 52 additions & 4 deletions

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,50 @@
2020
from absl import app
2121
from maxdiffusion.utils import export_to_video
2222
import os
23+
from google.cloud import storage
24+
25+
def upload_video_to_gcs(output_dir: str, video_path: str):
26+
"""
27+
Uploads a local video file to a specified Google Cloud Storage bucket.
28+
"""
29+
try:
30+
path_without_scheme = output_dir.removeprefix("gs://")
31+
parts = path_without_scheme.split('/', 1)
32+
bucket_name = parts[0]
33+
folder_name = parts[1] if len(parts) > 1 else ''
34+
35+
# Initialize the GCS client
36+
storage_client = storage.Client()
37+
38+
# Get the bucket object
39+
bucket = storage_client.bucket(bucket_name)
40+
41+
# Define the source and destination paths
42+
source_file_path = f"./{video_path}"
43+
destination_blob_name = os.path.join(folder_name, "videos", video_path)
44+
45+
# Create a blob object
46+
blob = bucket.blob(destination_blob_name)
47+
48+
# Upload the file
49+
print(f"Uploading {source_file_path} to {bucket_name}/{destination_blob_name}...")
50+
blob.upload_from_filename(source_file_path)
51+
print(f"Upload complete {source_file_path}.")
52+
53+
except Exception as e:
54+
print(f"An error occurred: {e}")
55+
56+
def delete_file(file_path: str):
57+
# Best practice: Check if the file exists before trying to delete it.
58+
if os.path.exists(file_path):
59+
try:
60+
os.remove(file_path)
61+
print(f"Successfully deleted file: {file_path}")
62+
except OSError as e:
63+
# This catches other issues like permission errors
64+
print(f"Error deleting file '{file_path}': {e}")
65+
else:
66+
print(f"The file '{file_path}' does not exist.")
2367

2468
jax.config.update("jax_use_shardy_partitioner", True)
2569

@@ -44,8 +88,10 @@ def inference_generate_video(config, pipeline, filename_prefix=""):
4488

4589
print(f"video {filename_prefix}, compile time: {(time.perf_counter() - s0)}")
4690
for i in range(len(videos)):
47-
video_path = os.path.join(config.output_dir, "videos", f"{filename_prefix}wan_output_{config.seed}_{i}.mp4")
91+
video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
4892
export_to_video(videos[i], video_path, fps=config.fps)
93+
upload_video_to_gcs(config.output_dir, video_path)
94+
delete_file(f"./{video_path}")
4995
return
5096

5197
def run(config, pipeline=None, filename_prefix=""):
@@ -79,9 +125,10 @@ def run(config, pipeline=None, filename_prefix=""):
79125
print("compile time: ", (time.perf_counter() - s0))
80126
saved_video_path = []
81127
for i in range(len(videos)):
82-
video_path = os.path.join(config.output_dir, "videos", f"{filename_prefix}wan_output_{config.seed}_{i}.mp4")
128+
video_path = os.path.join(f"{filename_prefix}wan_output_{config.seed}_{i}.mp4")
83129
export_to_video(videos[i], video_path, fps=config.fps)
84130
saved_video_path.append(video_path)
131+
upload_video_to_gcs(config.output_dir, video_path)
85132

86133
s0 = time.perf_counter()
87134
videos = pipeline(

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ def start_training(self):
153153
pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
154154

155155
# save some memory.
156-
del pipeline.vae
157-
del pipeline.vae_cache
156+
# del pipeline.vae
157+
# del pipeline.vae_cache
158158

159159
mesh = pipeline.mesh
160160
train_data_iterator = self.load_dataset(mesh, is_training=True)
@@ -252,6 +252,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
252252
if self.config.eval_every > 0 and (step + 1) % self.config.eval_every == 0:
253253
# Re-create the iterator each time you start evaluation to reset it
254254
# This assumes your data loading logic can be called to get a fresh iterator.
255+
pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state)
255256
inference_generate_video(self.config, pipeline, filename_prefix=f"{step+1}-train_steps-")
256257
eval_data_iterator = self.load_dataset(mesh, is_training=False)
257258
eval_rng = jax.random.key(self.config.seed + step)

0 commit comments

Comments
 (0)