Skip to content

Commit f0067e1

Browse files
committed
Changes made for moving ltx2_pipeline_utils to export_utils.py
1 parent e75c858 commit f0067e1

2 files changed

Lines changed: 132 additions & 3 deletions

File tree

src/maxdiffusion/generate_ltx2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from absl import app
2424
from google.cloud import storage
2525
import flax
26-
from maxdiffusion.pipelines.ltx2.ltx2_pipeline_utils import encode_video
26+
from maxdiffusion.utils.export_utils import export_to_video_with_audio
2727

2828

2929
def upload_video_to_gcs(output_dir: str, video_path: str):
@@ -163,7 +163,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
163163
video_path = f"{filename_prefix}ltx2_output_{getattr(config, 'seed', 0)}_{i}.mp4"
164164
audio_i = audios[i] if audios is not None else None
165165

166-
encode_video(video=videos[i], fps=fps, audio=audio_i, audio_sample_rate=audio_sample_rate, output_path=video_path)
166+
export_to_video_with_audio(video=videos[i], fps=fps, audio=audio_i, audio_sample_rate=audio_sample_rate, output_path=video_path)
167167

168168
saved_video_path.append(video_path)
169169
if config.output_dir.startswith("gs://"):

src/maxdiffusion/utils/export_utils.py

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,12 @@
2626
import PIL.Image
2727
import PIL.ImageOps
2828

29-
from .import_utils import BACKENDS_MAPPING, is_imageio_available, is_opencv_available
29+
from .import_utils import AV_IMPORT_ERROR, BACKENDS_MAPPING, is_av_available, is_imageio_available, is_opencv_available
3030
from .logging import get_logger
3131

32+
if is_av_available():
33+
import av
34+
3235

3336
global_rng = random.Random()
3437

@@ -222,3 +225,129 @@ def export_to_video(
222225
writer.append_data(frame)
223226

224227
return output_video_path
228+
229+
230+
def _prepare_audio_stream(container, audio_sample_rate: int):
  """
  Add an AAC audio stream to `container` and configure its codec context.

  Args:
    container: Output container object exposing `add_stream`.
    audio_sample_rate: Sample rate used for the stream, the codec context and
      the codec time base (1 / audio_sample_rate).

  Returns:
    The newly created audio stream, configured with a "stereo" layout.
  """
  from fractions import Fraction

  stream = container.add_stream("aac", rate=audio_sample_rate)

  # Configure everything through one handle on the codec context.
  codec = stream.codec_context
  codec.sample_rate = audio_sample_rate
  codec.layout = "stereo"
  codec.time_base = Fraction(1, audio_sample_rate)

  return stream
240+
241+
242+
def _resample_audio(container, audio_stream, frame_in) -> None:
  """
  Resample `frame_in` to the encoder's parameters, then encode and mux the result.

  The target format, layout and rate are read from the stream's codec context,
  falling back to "fltp", "stereo" and the input frame's sample rate when the
  codec context reports no value. After all frames are encoded the audio
  encoder is flushed.

  Args:
    container: Output container that receives the encoded audio packets.
    audio_stream: Audio stream whose codec context defines the target parameters.
    frame_in: Input audio frame holding the samples to write.
  """
  cc = audio_stream.codec_context

  target_format = cc.format or "fltp"
  target_layout = cc.layout or "stereo"
  target_rate = cc.sample_rate or frame_in.sample_rate

  audio_resampler = av.audio.resampler.AudioResampler(
      format=target_format,
      layout=target_layout,
      rate=target_rate,
  )

  # Convert the input, then pass None to drain samples the resampler may still
  # be buffering; without the drain the tail of the audio can be lost.
  resampled_frames = list(audio_resampler.resample(frame_in))
  resampled_frames.extend(audio_resampler.resample(None))

  audio_next_pts = 0
  for rframe in resampled_frames:
    # Guard in case the flush call hands back a None entry instead of a frame.
    if rframe is None:
      continue
    if rframe.pts is None:
      rframe.pts = audio_next_pts
    audio_next_pts += rframe.samples
    # Frames leaving the resampler are at the target rate, not the input rate.
    rframe.sample_rate = target_rate
    container.mux(audio_stream.encode(rframe))

  # flush audio encoder
  for packet in audio_stream.encode():
    container.mux(packet)
266+
267+
268+
def _write_audio(
    container,
    audio_stream,
    samples: Any,
    audio_sample_rate: int,
) -> None:
  """
  Convert a two-channel waveform to interleaved int16 PCM and write it to the audio stream.

  Args:
    container: Output container that receives the encoded audio packets.
    audio_stream: Audio stream of `container` used for encoding.
    samples: Array-like waveform shaped (2, time) or (time, 2). Data that is
      not already int16 is clipped to [-1.0, 1.0] and scaled by 32767.
    audio_sample_rate: Sample rate assigned to the input audio frame.

  Raises:
    ValueError: If `samples` is not 1-D or 2-D, or does not have exactly two
      channels (a 1-D input is treated as a single channel and is rejected).
  """
  import numpy as np

  samples = np.asarray(samples)

  # Only the first two axes are inspected below, so scalars would fail with an
  # IndexError and batched (3-D+) input would be flattened into garbage audio.
  if samples.ndim not in (1, 2):
    raise ValueError(f"Expected 1D or 2D audio samples; got shape {samples.shape}.")

  if samples.ndim == 1:
    samples = samples[:, None]

  # The Vocoder naturally outputs (Channels=2, Time)
  if samples.shape[0] == 2 and samples.shape[1] != 2:
    samples = samples.T  # Now (Time, 2)

  if samples.shape[1] != 2:
    raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")

  if samples.dtype != np.int16:
    samples = np.clip(samples, -1.0, 1.0)
    samples = (samples * 32767.0).astype(np.int16)

  # Packed "s16" audio is a single plane with the channels interleaved, which
  # is exactly the row-major flattening of a (Time, 2) array.
  samples_np = np.ascontiguousarray(samples).reshape(1, -1)

  frame_in = av.AudioFrame.from_ndarray(
      samples_np,
      format="s16",
      layout="stereo",
  )
  frame_in.sample_rate = audio_sample_rate

  _resample_audio(container, audio_stream, frame_in)
301+
302+
303+
def export_to_video_with_audio(video: Any, fps: int, audio: Optional[Any], audio_sample_rate: Optional[int], output_path: str) -> None:
  """
  Encodes video (and optionally audio) to a file using PyAV.

  The video is written as a "libx264" stream with the "yuv420p" pixel format.
  When `audio` is given, an audio stream is added and written after all video
  packets have been muxed.

  Args:
    video: Video array-like [F, H, W, C] (frames, height, width, channels), or
      [B, F, H, W, C], in which case only the first batch entry is written.
      Frames are handed to PyAV as "rgb24" (presumably uint8 RGB; confirm
      against the caller).
    fps: Frames per second; converted with `int()` for the stream rate.
    audio: Audio array-like [C, L] or [L, C], or None to write video only.
    audio_sample_rate: Audio sample rate; required when `audio` is provided.
    output_path: Output file path

  Raises:
    ImportError: If PyAV is not available.
    ValueError: If `video` is not 4D or 5D, or if `audio` is provided without
      `audio_sample_rate`.
  """
  if not is_av_available():
    raise ImportError(AV_IMPORT_ERROR.format("export_to_video_with_audio"))

  video_np = np.asarray(video)

  if video_np.ndim == 5:
    # [B, F, H, W, C] -> take the first video in the batch
    video_np = video_np[0]
  elif video_np.ndim != 4:
    raise ValueError(f"export_to_video_with_audio expects a 4D or 5D video tensor, got {video_np.ndim}D")

  # video_np is now [F, H, W, C]
  _, height, width, _ = video_np.shape

  # Validate the audio arguments before touching the filesystem so that a bad
  # call does not leave a partially created output file behind.
  if audio is not None and audio_sample_rate is None:
    raise ValueError("audio_sample_rate is required when audio is provided")

  # The context manager closes the container even if encoding raises.
  with av.open(output_path, mode="w") as container:
    stream = container.add_stream("libx264", rate=int(fps))
    stream.width = width
    stream.height = height
    stream.pix_fmt = "yuv420p"

    # The audio stream is registered before any packet is muxed.
    audio_stream = None
    if audio is not None:
      audio_stream = _prepare_audio_stream(container, audio_sample_rate)

    for frame_array in video_np:
      # frame_array is [H, W, C]
      frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
      for packet in stream.encode(frame):
        container.mux(packet)

    # Flush encoder
    for packet in stream.encode():
      container.mux(packet)

    if audio_stream is not None:
      _write_audio(container, audio_stream, audio, audio_sample_rate)

0 commit comments

Comments
 (0)