2 changes: 2 additions & 0 deletions src/maxdiffusion/__init__.py
@@ -374,6 +374,7 @@
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
_import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"]
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
_import_structure["schedulers"].extend(
[
@@ -453,6 +454,7 @@
from .models.modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from .models.ltx_video.transformers.transformer3d import Transformer3DModel
from .models.vae_flax import FlaxAutoencoderKL
from .pipelines import FlaxDiffusionPipeline
from .schedulers import (
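Note: with the `_import_structure` entry and the matching `TYPE_CHECKING` import above, the new model becomes importable from the package root. A minimal usage sketch, assuming a Flax-enabled environment:

```python
# Resolved lazily through maxdiffusion's _LazyModule machinery; this
# assumes is_flax_available() returns True in the running environment.
from maxdiffusion import Transformer3DModel
```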
5 changes: 4 additions & 1 deletion src/maxdiffusion/checkpointing/checkpointing_utils.py
@@ -213,7 +213,10 @@ def load_state_if_possible(
max_logging.log(f"restoring from this run's directory latest step {latest_step}")
try:
if not enable_single_replica_ckpt_restoring:
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
if checkpoint_item == " ":
Contributor:
Similar comment as Juan's on the previous PR: why is checkpoint_item == " "?

Contributor Author:
If checkpoint_item is set to None, it cannot pass the check `if checkpoint_manager and checkpoint_item:` in max_utils.py, so I set it to a single-space string (which is still truthy) to get around that check.
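A minimal sketch of that truthiness check, with a stand-in object for the manager:

```python
# Stand-in demo of the check in max_utils.py: `checkpoint_manager and
# checkpoint_item` is falsy for None and "", but truthy for the " " sentinel.
checkpoint_manager = object()  # placeholder; assume a real CheckpointManager
for item in (None, "", " "):
    if checkpoint_manager and item:
        print(repr(item), "-> passes the check")
    else:
        print(repr(item), "-> fails the check, restore-by-name is skipped")
```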
return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state))
else:
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))

def map_to_pspec(data):
72 changes: 72 additions & 0 deletions src/maxdiffusion/configs/ltx_video.yml
@@ -0,0 +1,72 @@
# Copyright 2025 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# hardware
hardware: 'tpu'
skip_jax_distributed_system: False

jax_cache_dir: ''
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'


run_name: ''
output_dir: 'ltx-video-output'
save_config_to_gcs: False

# parallelism
mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']
logical_axis_rules: [
['batch', 'data'],
['activation_batch', ['data','fsdp']],
['activation_heads', 'tensor'],
['activation_kv', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
['heads', 'tensor'],
['conv_batch', ['data','fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']]
dcn_data_parallelism: 1
dcn_fsdp_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_tensor_parallelism: 1
ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_fsdp_parallelism: 1
ici_tensor_parallelism: 1

ici_fsdp_transpose_parallelism: 1
ici_sequence_parallelism: 1
ici_tensor_transpose_parallelism: 1
ici_expert_parallelism: 1

learning_rate_schedule_steps: -1
max_train_steps: 500 #TODO: change this
pretrained_model_name_or_path: ''
unet_checkpoint: ''
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
per_device_batch_size: 1
compile_topology_num_slices: -1
quantization_local_shard_count: -1
jit_initializers: True
enable_single_replica_ckpt_restoring: False
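For reference, a minimal sketch of how the `-1` (auto-sharded) entries above are resolved against the available device count; the fill rule is an assumption modeled on `fill_unspecified_mesh_axes` in max_utils.py, and the device count is made up for the example:

```python
import numpy as np

# Hypothetical fill rule: explicit axis sizes are kept, and the single -1
# axis absorbs whatever device count remains after dividing them out.
mesh_axes = ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert',
             'tensor_transpose', 'tensor_sequence', 'sequence']
ici = [-1, 1, 1, 1, 1, 1, 1, 1]  # the ici_*_parallelism values above
num_devices = 8                  # assumption for this example

explicit = int(np.prod([v for v in ici if v != -1]))
resolved = [num_devices // explicit if v == -1 else v for v in ici]
print(dict(zip(mesh_axes, resolved)))
# -> {'data': 8, 'fsdp': 1, 'tensor': 1, ...}
```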
117 changes: 117 additions & 0 deletions src/maxdiffusion/generate_ltx_video.py
@@ -0,0 +1,117 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from absl import app
from typing import Sequence
import jax
import json
from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel
import os
import functools
import jax.numpy as jnp
from maxdiffusion import pyconfig
from maxdiffusion.max_utils import (
create_device_mesh,
setup_initial_state,
get_memory_allocations,
)
from jax.sharding import Mesh
import orbax.checkpoint as ocp


def validate_transformer_inputs(
prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids
):
print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)
print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype)
print("latents.shape: ", latents.shape, latents.dtype)
print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype)
print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype)


def run(config):

key = jax.random.PRNGKey(42)

devices_array = create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)

base_dir = os.path.dirname(__file__)

# Load the model config.
config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json")
with open(config_path, "r") as f:
model_config = json.load(f)
ckpt_path = model_config["ckpt_path"]

ignored_keys = [
"_class_name",
"_diffusers_version",
"_name_or_path",
"causal_temporal_positioning",
"in_channels",
"ckpt_path",
]
in_channels = model_config["in_channels"]
for name in ignored_keys:
if name in model_config:
del model_config[name]

transformer = Transformer3DModel(
**model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh
)
transformer_param_shapes = transformer.init_weights( # noqa: F841
in_channels, key, model_config["caption_channels"], eval_only=True
)
weights_init_fn = functools.partial(
transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True
)

checkpoint_manager = ocp.CheckpointManager(ckpt_path)
transformer_state, transformer_state_shardings = setup_initial_state(
model=transformer,
tx=None,
config=config,
mesh=mesh,
weights_init_fn=weights_init_fn,
checkpoint_manager=checkpoint_manager,
checkpoint_item=" ",
model_params=None,
training=False,
)

transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
get_memory_allocations()

states = {}
state_shardings = {}

state_shardings["transformer"] = transformer_state_shardings
states["transformer"] = transformer_state


def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
run(pyconfig.config)


if __name__ == "__main__":
app.run(main)
80 changes: 78 additions & 2 deletions src/maxdiffusion/max_utils.py
@@ -257,6 +257,21 @@ def create_device_mesh(config, devices=None, logging=True):
if devices is None:
devices = jax.devices()
num_devices = len(devices)
# Special case for ltx-video.
if "fsdp_transpose" in config.mesh_axes:
num_slices = 1  # if config.inference_benchmark_test else config.num_slices
num_devices_per_slice = num_devices // num_slices
# Find possible unspecified parallelisms
ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
mesh = mesh_utils.create_device_mesh(
ici_parallelism,
devices,
)
max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")

return mesh

try:
num_slices = 1 + max([d.slice_index for d in devices])
except:
@@ -288,9 +303,66 @@
if logging:
max_logging.log(f"Decided on mesh: {mesh}")

return mesh

def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState):
"""Unboxes the flax.LogicallyPartitioned pieces in a train state.

@@ -402,7 +474,11 @@ def setup_initial_state(
config.enable_single_replica_ckpt_restoring,
)
if state:
state = state[checkpoint_item]
# " " is a sentinel meaning "restore the whole state, not a named item".
if checkpoint_item != " ":
state = state[checkpoint_item]
if not state:
max_logging.log(f"Could not find the item in orbax, creating state...")
init_train_state_partial = functools.partial(
Expand Down Expand Up @@ -609,4 +685,4 @@ def maybe_initialize_jax_distributed_system(raw_keys):
initialize_jax_for_gpu()
max_logging.log("Jax distributed system initialized on GPU!")
else:
jax.distributed.initialize()
jax.distributed.initialize()
5 changes: 2 additions & 3 deletions src/maxdiffusion/models/__init__.py
@@ -13,9 +13,7 @@
# limitations under the License.

from typing import TYPE_CHECKING

from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available

from maxdiffusion.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available

_import_structure = {}

@@ -32,6 +30,7 @@
from .vae_flax import FlaxAutoencoderKL
from .lora import *
from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from .ltx_video.transformers.transformer3d import Transformer3DModel

else:
import sys