2020from absl import app
2121from maxdiffusion .utils import export_to_video
2222import os
23+ from google .cloud import storage
24+
25+ def upload_video_to_gcs (output_dir : str , video_path : str ):
26+ """
27+ Uploads a local video file to a specified Google Cloud Storage bucket.
28+ """
29+ try :
30+ path_without_scheme = output_dir .removeprefix ("gs://" )
31+ parts = path_without_scheme .split ('/' , 1 )
32+ bucket_name = parts [0 ]
33+ folder_name = parts [1 ] if len (parts ) > 1 else ''
34+
35+ # Initialize the GCS client
36+ storage_client = storage .Client ()
37+
38+ # Get the bucket object
39+ bucket = storage_client .bucket (bucket_name )
40+
41+ # Define the source and destination paths
42+ source_file_path = f"./{ video_path } "
43+ destination_blob_name = os .path .join (folder_name , "videos" , video_path )
44+
45+ # Create a blob object
46+ blob = bucket .blob (destination_blob_name )
47+
48+ # Upload the file
49+ print (f"Uploading { source_file_path } to { bucket_name } /{ destination_blob_name } ..." )
50+ blob .upload_from_filename (source_file_path )
51+ print (f"Upload complete { source_file_path } ." )
52+
53+ except Exception as e :
54+ print (f"An error occurred: { e } " )
55+
56+ def delete_file (file_path : str ):
57+ # Best practice: Check if the file exists before trying to delete it.
58+ if os .path .exists (file_path ):
59+ try :
60+ os .remove (file_path )
61+ print (f"Successfully deleted file: { file_path } " )
62+ except OSError as e :
63+ # This catches other issues like permission errors
64+ print (f"Error deleting file '{ file_path } ': { e } " )
65+ else :
66+ print (f"The file '{ file_path } ' does not exist." )
2367
2468jax .config .update ("jax_use_shardy_partitioner" , True )
2569
@@ -44,8 +88,10 @@ def inference_generate_video(config, pipeline, filename_prefix=""):
4488
4589 print (f"video { filename_prefix } , compile time: { (time .perf_counter () - s0 )} " )
4690 for i in range (len (videos )):
47- video_path = os . path . join ( config . output_dir , "videos" , f"{ filename_prefix } wan_output_{ config .seed } _{ i } .mp4" )
91+ video_path = f"{ filename_prefix } wan_output_{ config .seed } _{ i } .mp4"
4892 export_to_video (videos [i ], video_path , fps = config .fps )
93+ upload_video_to_gcs (config .output_dir , video_path )
94+ delete_file (f"./{ video_path } " )
4995 return
5096
5197def run (config , pipeline = None , filename_prefix = "" ):
@@ -79,9 +125,10 @@ def run(config, pipeline=None, filename_prefix=""):
79125 print ("compile time: " , (time .perf_counter () - s0 ))
80126 saved_video_path = []
81127 for i in range (len (videos )):
82- video_path = os .path .join (config . output_dir , "videos" , f"{ filename_prefix } wan_output_{ config .seed } _{ i } .mp4" )
128+ video_path = os .path .join (f"{ filename_prefix } wan_output_{ config .seed } _{ i } .mp4" )
83129 export_to_video (videos [i ], video_path , fps = config .fps )
84130 saved_video_path .append (video_path )
131+ upload_video_to_gcs (config .output_dir , video_path )
85132
86133 s0 = time .perf_counter ()
87134 videos = pipeline (
0 commit comments