@@ -69,6 +69,12 @@ def _write_audio(
6969 samples : torch .Tensor ,
7070 audio_sample_rate : int ,
7171) -> None :
72+ import numpy as np
73+
74+ # If it is a torch tensor, we convert to numpy first
75+ if hasattr (samples , "cpu" ):
76+ samples = samples .contiguous ().cpu ().numpy ()
77+
7278 if samples .ndim == 1 :
7379 samples = samples [:, None ]
7480
@@ -79,15 +85,11 @@ def _write_audio(
7985 raise ValueError (f"Expected samples with 2 channels; got shape { samples .shape } ." )
8086
8187 # Convert to int16 packed for ingestion; resampler converts to encoder fmt.
82- if samples .dtype != torch .int16 :
83- samples = torch .clip (samples , - 1.0 , 1.0 )
84- samples = (samples * 32767.0 ).to ( torch .int16 )
88+ if samples .dtype != np .int16 :
89+ samples = np .clip (samples , - 1.0 , 1.0 )
90+ samples = (samples * 32767.0 ).astype ( np .int16 )
8591
86- if hasattr (samples , "cpu" ):
87- samples_np = samples .contiguous ().reshape (1 , - 1 ).cpu ().numpy ()
88- else :
89- import numpy as np
90- samples_np = np .reshape (np .ascontiguousarray (samples ), (1 , - 1 ))
92+ samples_np = np .reshape (np .ascontiguousarray (samples ), (1 , - 1 ))
9193
9294 frame_in = av .AudioFrame .from_ndarray (
9395 samples_np ,
0 commit comments