Skip to content

Commit 460f7db

Browse files
committed
changed ckpt name
1 parent 740d403 commit 460f7db

13 files changed

Lines changed: 1624 additions & 1874 deletions

File tree

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ pytest==8.2.2
2323
tensorflow>=2.17.0
2424
tensorflow-datasets>=4.9.6
2525
ruff>=0.1.5,<=0.2
26-
git+https://github.com/mlperf/logging.git
2726
git+https://github.com/Lightricks/LTX-Video
2827
git+https://github.com/zmelumian972/xla@torchax/jittable_module_callable#subdirectory=torchax
2928
opencv-python-headless==4.10.0.84

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def load_state_if_possible(
213213
max_logging.log(f"restoring from this run's directory latest step {latest_step}")
214214
try:
215215
if not enable_single_replica_ckpt_restoring:
216-
if checkpoint_item == " ":
216+
if checkpoint_item == "ltxvid_transformer":
217217
return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state))
218218
else:
219219
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}

src/maxdiffusion/generate_ltx_video.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Sequence
2020
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline
2121
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline
22-
from maxdiffusion import pyconfig
22+
from maxdiffusion import pyconfig, max_logging
2323
import imageio
2424
from datetime import datetime
2525
import os
@@ -108,7 +108,7 @@ def run(config):
108108
enhance_prompt=enhance_prompt,
109109
seed=config.seed,
110110
)
111-
print("generation time: ", (time.perf_counter() - s0))
111+
max_logging.log(f"Compile time: {time.perf_counter() - s0:.1f}s.")
112112

113113
(pad_left, pad_right, pad_top, pad_bottom) = padding
114114
pad_bottom = -pad_bottom
@@ -146,7 +146,6 @@ def run(config):
146146
resolution=(height, width, config.num_frames),
147147
dir=output_dir,
148148
)
149-
print(output_filename)
150149
# Write video
151150
with imageio.get_writer(output_filename, fps=fps) as video:
152151
for frame in video_np:

src/maxdiffusion/max_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def setup_initial_state(
405405
config.enable_single_replica_ckpt_restoring,
406406
)
407407
if state:
408-
if checkpoint_item == " ":
408+
if checkpoint_item == "ltxvid_transformer":
409409
state = state
410410
else:
411411
state = state[checkpoint_item]

src/maxdiffusion/models/ltx_video/linear.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,6 @@ def compute_dot_general(inputs, kernel, axis, contract_ind):
9595
axis = _normalize_axes(axis, inputs.ndim)
9696

9797
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
98-
# kernel_in_axis = np.arange(len(axis))
99-
# kernel_out_axis = np.arange(len(axis), len(axis) + len(features))
10098
kernel = self.param(
10199
"kernel",
102100
nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),

src/maxdiffusion/models/ltx_video/transformers/adaln.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray:
126126

127127

128128
class AlphaCombinedTimestepSizeEmbeddings(nn.Module):
129-
""" """
130129

131130
embedding_dim: int
132131
size_emb_dim: int

0 commit comments

Comments
 (0)