Skip to content

Commit 0b513d9

Browse files
committed
fix
1 parent d8150bb commit 0b513d9

17 files changed

Lines changed: 189 additions & 247 deletions

src/maxdiffusion/generate_ltx2.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
from maxdiffusion import pyconfig, max_logging, max_utils
2323
from absl import app
2424
from google.cloud import storage
25+
from google.api_core.exceptions import GoogleAPIError
2526
import flax
26-
from maxdiffusion.pipelines.ltx2.ltx2_pipeline_utils import encode_video
27+
from maxdiffusion.utils.export_utils import export_to_video_with_audio
2728

2829

2930
def upload_video_to_gcs(output_dir: str, video_path: str):
@@ -48,8 +49,8 @@ def upload_video_to_gcs(output_dir: str, video_path: str):
4849
blob.upload_from_filename(source_file_path)
4950
max_logging.log(f"Upload complete {source_file_path}.")
5051

51-
except Exception as e:
52-
max_logging.log(f"An error occurred: {e}")
52+
except GoogleAPIError as e:
53+
max_logging.log(f"A storage error occurred during upload: {e}")
5354

5455

5556
def delete_file(file_path: str):
@@ -163,7 +164,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
163164
video_path = f"{filename_prefix}ltx2_output_{getattr(config, 'seed', 0)}_{i}.mp4"
164165
audio_i = audios[i] if audios is not None else None
165166

166-
encode_video(video=videos[i], fps=fps, audio=audio_i, audio_sample_rate=audio_sample_rate, output_path=video_path)
167+
export_to_video_with_audio(video=videos[i], fps=fps, audio=audio_i, audio_sample_rate=audio_sample_rate, output_path=video_path)
167168

168169
saved_video_path.append(video_path)
169170
if config.output_dir.startswith("gs://"):

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -355,13 +355,13 @@ def __init__(
355355
# 1. Define Partitioned Initializers (Logical Axes)
356356
# Q, K, V kernels: [in_features (embed), out_features (heads)]
357357
qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads"))
358-
# Q, K, V biases: [out_features (heads)]
359-
qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
358+
# Q, K, V biases: [out_features (embed)]
359+
qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))
360360

361361
# Out kernel: [in_features (heads), out_features (embed)]
362362
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed"))
363-
# Out bias: [out_features (embed)]
364-
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))
363+
# Out bias: [out_features (heads)]
364+
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
365365

366366
# Norm scales
367367
norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",))

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,26 @@
1+
"""
2+
Copyright 2026 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
117
import json
218
import torch
319
import jax
420
import jax.numpy as jnp
521
from maxdiffusion import max_logging
622
from huggingface_hub import hf_hub_download
23+
from huggingface_hub.utils import EntryNotFoundError
724
from safetensors import safe_open
825
from flax.traverse_util import unflatten_dict, flatten_dict
926
from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)
@@ -101,12 +118,12 @@ def load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device):
101118
with safe_open(shard_path, framework="pt") as f:
102119
for k in f.keys():
103120
tensors[k] = torch2jax(f.get_tensor(k))
104-
except Exception:
121+
except EntryNotFoundError:
105122
# Fallback to single file
106123
filename = "diffusion_pytorch_model.safetensors"
107124
try:
108125
ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
109-
except Exception:
126+
except EntryNotFoundError:
110127
filename = "diffusion_pytorch_model.bin"
111128
ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
112129

src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Copyright 2025 Google LLC
2+
Copyright 2026 Google LLC
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.

src/maxdiffusion/models/ltx2/text_encoders/feature_extractor_ltx2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Copyright 2025 Google LLC
2+
Copyright 2026 Google LLC
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.

src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Copyright 2025 Google LLC
2+
Copyright 2026 Google LLC
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Copyright 2025 Google LLC
2+
Copyright 2026 Google LLC
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -837,8 +837,8 @@ def init_block(rngs):
837837
rngs=rngs,
838838
dtype=self.dtype,
839839
param_dtype=self.weights_dtype,
840-
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, "embed")),
841-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
840+
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", None)),
841+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)),
842842
)
843843

844844
self.audio_norm_out = nnx.LayerNorm(
@@ -850,8 +850,8 @@ def init_block(rngs):
850850
rngs=rngs,
851851
dtype=self.dtype,
852852
param_dtype=self.weights_dtype,
853-
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, "embed")),
854-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
853+
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", None)),
854+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)),
855855
)
856856

857857
def __call__(

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline_utils.py

Lines changed: 0 additions & 148 deletions
This file was deleted.

src/maxdiffusion/tests/ltx2/test_checkpointer_ltx2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Copyright 2025 Google LLC
2+
Copyright 2026 Google LLC
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.

src/maxdiffusion/tests/ltx2/test_embeddings_connector_ltx2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Copyright 2025 Google LLC
2+
Copyright 2026 Google LLC
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.

0 commit comments

Comments
 (0)