@@ -83,8 +83,14 @@ def _write_audio(
8383 samples = torch .clip (samples , - 1.0 , 1.0 )
8484 samples = (samples * 32767.0 ).to (torch .int16 )
8585
86+ if hasattr (samples , "cpu" ):
87+ samples_np = samples .contiguous ().reshape (1 , - 1 ).cpu ().numpy ()
88+ else :
89+ import numpy as np
90+ samples_np = np .reshape (np .ascontiguousarray (samples ), (1 , - 1 ))
91+
8692 frame_in = av .AudioFrame .from_ndarray (
87- samples . contiguous (). reshape ( 1 , - 1 ). cpu (). numpy () ,
93+ samples_np ,
8894 format = "s16" ,
8995 layout = "stereo" ,
9096 )
@@ -108,14 +114,20 @@ def encode_video(
108114 if not import_utils .is_av_available ():
109115 raise ImportError (import_utils .AV_IMPORT_ERROR .format ("encode_video" ))
110116
111- video_np = video .cpu ().numpy ()
117+ if hasattr (video , "cpu" ):
118+ video_np = video .cpu ().numpy ()
119+ else :
120+ video_np = np .array (video )
121+
112122 if video_np .ndim == 4 :
113123 # [F, H, W, C]
114124 _ , height , width , _ = video_np .shape
115125 elif video_np .ndim == 5 :
116- raise ValueError ("encode_video expects a single video tensor of shape [F, H, W, C]" )
126+ # [B, F, H, W, C] -> take the first video in the batch
127+ video_np = video_np [0 ]
128+ _ , height , width , _ = video_np .shape
117129 else :
118- _ , height , width , _ = video_np .shape
130+ raise ValueError ( f"encode_video expects a 4D or 5D video tensor, got { video_np .ndim } D" )
119131
120132 container = av .open (output_path , mode = "w" )
121133 stream = container .add_stream ("libx264" , rate = int (fps ))
0 commit comments