2 changes: 2 additions & 0 deletions src/maxdiffusion/__init__.py
@@ -374,6 +374,7 @@
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
_import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"]
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
_import_structure["schedulers"].extend(
[
@@ -453,6 +454,7 @@
from .models.modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from .models.ltx_video.transformers.transformer3d import Transformer3DModel
from .models.vae_flax import FlaxAutoencoderKL
from .pipelines import FlaxDiffusionPipeline
from .schedulers import (
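With both the lazy-import registration and the eager import in place, Transformer3DModel should be importable from the package root (for example, from maxdiffusion import Transformer3DModel), assuming the _LazyModule machinery resolves the new entry the same way as existing ones such as FluxTransformer2DModel.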
68 changes: 68 additions & 0 deletions src/maxdiffusion/configs/ltx_video.yml
@@ -0,0 +1,68 @@
# Copyright 2025 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


#hardware
hardware: 'tpu'
skip_jax_distributed_system: False

jax_cache_dir: ''
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'


run_name: ''
output_dir: 'ltx-video-output'
save_config_to_gcs: False

#parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
logical_axis_rules: [
['batch', 'data'],
['activation_heads', 'fsdp'],
['activation_batch', ['data','fsdp']],
['activation_kv', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
['heads', 'tensor'],
['norm', 'fsdp'],
['conv_batch', ['data','fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
['conv_in', 'fsdp']
]
data_sharding: [['data', 'fsdp', 'tensor']]
dcn_data_parallelism: 1
dcn_fsdp_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1




learning_rate_schedule_steps: -1
max_train_steps: 500 #TODO: change this
pretrained_model_name_or_path: ''
unet_checkpoint: ''
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
per_device_batch_size: 1
compile_topology_num_slices: -1
quantization_local_shard_count: -1
jit_initializers: True
enable_single_replica_ckpt_restoring: False
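
The parallelism block above follows the MaxText-style convention consumed by create_device_mesh: mesh_axes names the mesh dimensions, the dcn_*/ici_* values size them across and within slices (-1 meaning that axis absorbs the remaining devices), and data_sharding describes how batch-like arrays are laid out over the combined axes. A minimal sketch of how these fields are typically used, with axis sizes assumed for a single host rather than computed by create_device_mesh:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh_axes = ("data", "fsdp", "tensor")          # matches mesh_axes above
data_sharding = (("data", "fsdp", "tensor"),)   # matches data_sharding above

# ici_fsdp_parallelism: -1 lets the fsdp axis absorb every local device,
# while the data and tensor axes stay at size 1.
devices = np.asarray(jax.devices()).reshape(1, jax.device_count(), 1)
mesh = Mesh(devices, mesh_axes)

# Batch-like arrays are then sharded over the combined ('data', 'fsdp', 'tensor')
# axes, mirroring the jax.device_put calls in generate_ltx_video.py below.
batch_sharding = NamedSharding(mesh, P(*data_sharding))
latents = jax.device_put(jnp.ones((jax.device_count(), 256, 128)), batch_sharding)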
186 changes: 186 additions & 0 deletions src/maxdiffusion/generate_ltx_video.py
@@ -0,0 +1,186 @@
from absl import app
from typing import Sequence
import jax
import json
from flax.linen import partitioning as nn_partitioning
from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel
import os
import functools
import jax.numpy as jnp
from maxdiffusion import pyconfig
from maxdiffusion.max_utils import (
create_device_mesh,
setup_initial_state,
get_memory_allocations,
)
from jax.sharding import Mesh, PartitionSpec as P
import orbax.checkpoint as ocp


def validate_transformer_inputs(
prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids
):
print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)
print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype)
print("latents.shape: ", latents.shape, latents.dtype)
print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype)
print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype)


def loop_body(step, args, transformer, fractional_coords, prompt_embeds, segment_ids, encoder_attention_segment_ids):
latents, state, noise_cond = args
noise_pred = transformer.apply(
{"params": state.params},
hidden_states=latents,
        indices_grid=fractional_coords,
encoder_hidden_states=prompt_embeds,
timestep=noise_cond,
segment_ids=segment_ids,
encoder_attention_segment_ids=encoder_attention_segment_ids,
)
return noise_pred, state, noise_cond


def run_inference(
states,
transformer,
config,
mesh,
latents,
    fractional_coords,
prompt_embeds,
timestep,
segment_ids,
encoder_attention_segment_ids,
):
transformer_state = states["transformer"]
loop_body_p = functools.partial(
loop_body,
transformer=transformer,
        fractional_coords=fractional_coords,
prompt_embeds=prompt_embeds,
segment_ids=segment_ids,
encoder_attention_segment_ids=encoder_attention_segment_ids,
)

with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
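        # fori_loop runs a single iteration (0 to 1) here, presumably as a stand-in for the full
        # denoising loop; the carry is the (latents, transformer_state, timestep) tuple below.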
noise_pred, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep))
return noise_pred


def run(config):
key = jax.random.PRNGKey(42)

devices_array = create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)

base_dir = os.path.dirname(__file__)

    # Load the model config
config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json")
with open(config_path, "r") as f:
model_config = json.load(f)
relative_ckpt_path = model_config["ckpt_path"]

ignored_keys = [
"_class_name",
"_diffusers_version",
"_name_or_path",
"causal_temporal_positioning",
"in_channels",
"ckpt_path",
]
in_channels = model_config["in_channels"]
for name in ignored_keys:
if name in model_config:
del model_config[name]

transformer = Transformer3DModel(
**model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh
)
transformer_param_shapes = transformer.init_weights(in_channels, key, model_config["caption_channels"], eval_only=True) # noqa F841
weights_init_fn = functools.partial(
transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True
)

absolute_ckpt_path = os.path.abspath(relative_ckpt_path)

checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path)
transformer_state, transformer_state_shardings = setup_initial_state(
model=transformer,
tx=None,
config=config,
mesh=mesh,
weights_init_fn=weights_init_fn,
checkpoint_manager=checkpoint_manager,
checkpoint_item=" ",
model_params=None,
training=False,
)

transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
get_memory_allocations()

states = {}
state_shardings = {}

state_shardings["transformer"] = transformer_state_shardings
states["transformer"] = transformer_state

# create dummy inputs:
example_inputs = {}
batch_size, num_tokens = 4, 256
input_shapes = {
"latents": (batch_size, num_tokens, in_channels),
"fractional_coords": (batch_size, 3, num_tokens),
"prompt_embeds": (batch_size, 128, model_config["caption_channels"]),
"timestep": (batch_size, 256),
"segment_ids": (batch_size, 256),
"encoder_attention_segment_ids": (batch_size, 128),
}
for name, shape in input_shapes.items():
example_inputs[name] = jnp.ones(
shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool
)

data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
latents = jax.device_put(example_inputs["latents"], data_sharding)
prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding)
fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding)
noise_cond = jax.device_put(example_inputs["timestep"], data_sharding)
segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding)
encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding)

validate_transformer_inputs(
prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids
)
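    # Only the state pytree is a traced argument to the jitted function; every input tensor is
    # baked in via functools.partial, so in_shardings only needs to describe the state shardings.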
p_run_inference = jax.jit(
functools.partial(
run_inference,
transformer=transformer,
config=config,
mesh=mesh,
latents=latents,
            fractional_coords=fractional_coords,
prompt_embeds=prompt_embeds,
timestep=noise_cond,
segment_ids=segment_ids,
encoder_attention_segment_ids=encoder_attention_segment_ids,
),
in_shardings=(state_shardings,),
out_shardings=None,
)

noise_pred = p_run_inference(states).block_until_ready()
print(noise_pred) # (4, 256, 128)


def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
run(pyconfig.config)


if __name__ == "__main__":
app.run(main)
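As a smoke test, the script can presumably be launched the same way as the other maxdiffusion entry points, for example: python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml (the exact invocation is an assumption based on pyconfig.initialize(argv); the ckpt_path entry in xora_v1.2-13B-balanced-128.json must point at a restorable Orbax checkpoint for setup_initial_state to succeed).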
2 changes: 1 addition & 1 deletion src/maxdiffusion/max_utils.py
@@ -609,4 +609,4 @@ def maybe_initialize_jax_distributed_system(raw_keys):
initialize_jax_for_gpu()
max_logging.log("Jax distributed system initialized on GPU!")
else:
jax.distributed.initialize()
jax.distributed.initialize()
5 changes: 2 additions & 3 deletions src/maxdiffusion/models/__init__.py
@@ -13,9 +13,7 @@
# limitations under the License.

from typing import TYPE_CHECKING

from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available

from maxdiffusion.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available

_import_structure = {}

@@ -32,6 +30,7 @@
from .vae_flax import FlaxAutoencoderKL
from .lora import *
from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from .ltx_video.transformers.transformer3d import Transformer3DModel

else:
import sys
70 changes: 70 additions & 0 deletions src/maxdiffusion/models/ltx_video/gradient_checkpoint.py
@@ -0,0 +1,70 @@
from enum import Enum, auto
from typing import Optional

import jax
from flax import linen as nn

SKIP_GRADIENT_CHECKPOINT_KEY = "skip"


class GradientCheckpointType(Enum):
"""
    Defines which gradient-checkpointing policy to use.

    NONE - no gradient checkpointing
    FULL - full gradient checkpointing wherever possible (minimum memory usage)
    MATMUL_WITHOUT_BATCH - gradient checkpointing for every linear/matmul operation
    except those that involve a batch dimension; all attention and projection
    layers are checkpointed, but not the backward pass with respect to the parameters
"""

NONE = auto()
FULL = auto()
MATMUL_WITHOUT_BATCH = auto()

@classmethod
def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType":
"""
Constructs the gradient checkpoint type from a string

Args:
s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None.

Returns:
GradientCheckpointType: The policy that corresponds to the string
"""
if s is None:
s = "none"
return GradientCheckpointType[s.upper()]

def to_jax_policy(self):
"""
Converts the gradient checkpoint type to a jax policy
"""
match self:
case GradientCheckpointType.NONE:
return SKIP_GRADIENT_CHECKPOINT_KEY
case GradientCheckpointType.FULL:
return None
case GradientCheckpointType.MATMUL_WITHOUT_BATCH:
return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims

def apply(self, module: nn.Module) -> nn.Module:
"""
        Applies the gradient checkpointing policy to a module.
        If no policy is needed, the module is returned as is.

Args:
module (nn.Module): the module to apply the policy to

Returns:
nn.Module: the module with the policy applied
"""
policy = self.to_jax_policy()
if policy == SKIP_GRADIENT_CHECKPOINT_KEY:
return module
return nn.remat( # pylint: disable=invalid-name
module,
prevent_cse=False,
policy=policy,
)
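
A minimal usage sketch for the policy above (the Block module is a hypothetical stand-in, not part of the PR): the string passed to the transformer, such as gradient_checkpointing="matmul_without_batch" in generate_ltx_video.py, is parsed into the enum, which then wraps a Flax module class in nn.remat under the corresponding policy.

from flax import linen as nn


class Block(nn.Module):  # hypothetical stand-in for a transformer sub-block
    features: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features)(x)


# "matmul_without_batch" maps to jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims.
policy = GradientCheckpointType.from_str("matmul_without_batch")
RematBlock = policy.apply(Block)  # nn.remat-wrapped class; "none" would return Block unchanged
block = RematBlock(features=128)  # constructed and called exactly like the original module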