Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,4 @@ else
fi

# Install maxdiffusion
pip3 install -U . || echo "Failed to install maxdiffusion" >&2
pip3 install -e . || echo "Failed to install maxdiffusion" >&2
2 changes: 2 additions & 0 deletions src/maxdiffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"]
_import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
_import_structure["schedulers"].extend(
Expand Down Expand Up @@ -453,6 +454,7 @@
from .models.modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from .models.ltx_video.transformers.transformer3d import Transformer3DModel
from .models.vae_flax import FlaxAutoencoderKL
from .pipelines import FlaxDiffusionPipeline
from .schedulers import (
Expand Down
7 changes: 5 additions & 2 deletions src/maxdiffusion/checkpointing/checkpointing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,11 @@ def load_state_if_possible(
max_logging.log(f"restoring from this run's directory latest step {latest_step}")
try:
if not enable_single_replica_ckpt_restoring:
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
if checkpoint_item == " ":
return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state))
else:
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))

def map_to_pspec(data):
pspec = data.sharding.spec
Expand Down
57 changes: 57 additions & 0 deletions src/maxdiffusion/configs/ltx_video.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#hardware
hardware: 'tpu'
skip_jax_distributed_system: False

jax_cache_dir: ''
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'


run_name: ''
output_dir: 'ltx-video-output'
save_config_to_gcs: False

#parallelism
mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']
logical_axis_rules: [
['batch', 'data'],
['activation_batch', ['data','fsdp']],
['activation_heads', 'tensor'],
['activation_kv', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
['heads', 'tensor'],
['conv_batch', ['data','fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']]
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_tensor_parallelism: 1

ici_data_parallelism: -1
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1
ici_fsdp_transpose_parallelism: 1
ici_sequence_parallelism: 1
ici_tensor_transpose_parallelism: 1
ici_expert_parallelism: 1




learning_rate_schedule_steps: -1
max_train_steps: 500 #TODO: change this
pretrained_model_name_or_path: ''
unet_checkpoint: ''
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
per_device_batch_size: 1
compile_topology_num_slices: -1
quantization_local_shard_count: -1
jit_initializers: True
enable_single_replica_ckpt_restoring: False
198 changes: 198 additions & 0 deletions src/maxdiffusion/generate_ltx_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
from json import encoder
from absl import app
from typing import Sequence
import jax
from flax import linen as nn
import json
from flax.linen import partitioning as nn_partitioning
from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel
import os
import functools
import jax.numpy as jnp
from maxdiffusion import pyconfig
from maxdiffusion.max_utils import (
create_device_mesh,
setup_initial_state,
get_memory_allocations,
)
from jax.sharding import Mesh, PartitionSpec as P
import orbax.checkpoint as ocp


def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids):
  """Print the shape and dtype of each transformer input for debugging.

  Purely diagnostic: accepts any objects exposing ``.shape`` and ``.dtype``
  (jax/numpy arrays) and returns None.
  """
  # Fixed label typo ("prompts_embeds" -> "prompt_embeds") and removed the
  # accidental duplicate noise_cond print from the original.
  print("prompt_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)
  print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype)
  print("latents.shape: ", latents.shape, latents.dtype)
  print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
  print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype)
  print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype)


def loop_body(
    step,
    args,
    transformer,
    fractional_cords,
    prompt_embeds,
    segment_ids,
    encoder_attention_segment_ids,
):
  """One denoising step, shaped for use as a ``jax.lax.fori_loop`` body.

  ``args`` is the loop carry ``(latents, state, noise_cond)``; the returned
  carry replaces the latents slot with the transformer's noise prediction
  while the state and timestep conditioning pass through unchanged.
  """
  current_latents, state, noise_cond = args
  predicted_noise = transformer.apply(
      {"params": state.params},
      hidden_states=current_latents,
      indices_grid=fractional_cords,
      encoder_hidden_states=prompt_embeds,
      timestep=noise_cond,
      segment_ids=segment_ids,
      encoder_attention_segment_ids=encoder_attention_segment_ids,
  )
  return predicted_noise, state, noise_cond



def run_inference(
    states, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, segment_ids, encoder_attention_segment_ids
):
  """Run the (single-step) denoising loop under the mesh's sharding rules.

  Returns the transformer's noise prediction for ``latents``; the updated
  transformer state from the loop carry is discarded.
  """
  step_fn = functools.partial(
      loop_body,
      transformer=transformer,
      fractional_cords=fractional_cords,
      prompt_embeds=prompt_embeds,
      segment_ids=segment_ids,
      encoder_attention_segment_ids=encoder_attention_segment_ids,
  )
  ## TODO: add vae decode step
  ## TODO: add full multi-step denoising loop (currently a single iteration)
  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
    initial_carry = (latents, states["transformer"], timestep)
    latents, _, _ = jax.lax.fori_loop(0, 1, step_fn, initial_carry)
  return latents

def run(config):
  """Load the LTX-Video Transformer3DModel from an orbax checkpoint and run
  a single jitted inference pass on dummy inputs.

  Args:
    config: pyconfig config object providing mesh axes, sharding/logical-axis
      rules, and checkpoint-restore settings.
  """
  devices_array = create_device_mesh(config)
  mesh = Mesh(devices_array, config.mesh_axes)

  base_dir = os.path.dirname(__file__)

  # Load the model architecture config that ships alongside this module.
  config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json")
  with open(config_path, "r") as f:
    model_config = json.load(f)
  relative_ckpt_path = model_config["ckpt_path"]

  # Keys consumed here (or diffusers metadata) that Transformer3DModel's
  # constructor does not accept.
  ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", "causal_temporal_positioning", "in_channels", "ckpt_path"]
  in_channels = model_config["in_channels"]
  for name in ignored_keys:
    if name in model_config:
      del model_config[name]

  transformer = Transformer3DModel(**model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh)

  # setup_initial_state calls this lazily; the original also invoked
  # transformer.init_weights eagerly into an unused variable, which has been
  # removed (it duplicated this initialization for no benefit).
  weights_init_fn = functools.partial(
      transformer.init_weights,
      in_channels,
      model_config['caption_channels'],
      eval_only=True,
  )

  absolute_ckpt_path = os.path.abspath(relative_ckpt_path)

  checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path)
  # checkpoint_item=" " is the sentinel this PR uses to request
  # StandardRestore instead of a named Composite item — see
  # checkpointing_utils.load_state_if_possible / setup_initial_state.
  transformer_state, transformer_state_shardings = setup_initial_state(
      model=transformer,
      tx=None,
      config=config,
      mesh=mesh,
      weights_init_fn=weights_init_fn,
      checkpoint_manager=checkpoint_manager,
      checkpoint_item=" ",
      model_params=None,
      training=False,
  )

  transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
  get_memory_allocations()

  states = {"transformer": transformer_state}
  state_shardings = {"transformer": transformer_state_shardings}

  # Create dummy inputs sized for a smoke test; all-float except mask-like
  # names (none present in input_shapes, so everything is float32 here).
  example_inputs = {}
  batch_size, num_tokens = 4, 256
  input_shapes = {
      "latents": (batch_size, num_tokens, in_channels),
      "fractional_coords": (batch_size, 3, num_tokens),
      "prompt_embeds": (batch_size, 128, model_config["caption_channels"]),
      "timestep": (batch_size, 256),
      "segment_ids": (batch_size, 256),
      "encoder_attention_segment_ids": (batch_size, 128),
  }
  for name, shape in input_shapes.items():
    example_inputs[name] = jnp.ones(
        shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool
    )

  data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
  latents = jax.device_put(example_inputs["latents"], data_sharding)
  prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding)
  fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding)
  noise_cond = jax.device_put(example_inputs["timestep"], data_sharding)
  segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding)
  encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding)

  validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids)
  # Everything except the mutable state dict is baked into the jitted partial.
  p_run_inference = jax.jit(
      functools.partial(
          run_inference,
          transformer=transformer,
          config=config,
          mesh=mesh,
          latents=latents,
          fractional_cords=fractional_coords,
          prompt_embeds=prompt_embeds,
          timestep=noise_cond,
          segment_ids=segment_ids,
          encoder_attention_segment_ids=encoder_attention_segment_ids,
      ),
      in_shardings=(state_shardings,),
      out_shardings=None,
  )

  noise_pred = p_run_inference(states).block_until_ready()
  print(noise_pred)  # expected shape: (batch_size, num_tokens, in_channels)


def main(argv: Sequence[str]) -> None:
  """CLI entry point: initialize pyconfig from argv, then run inference."""
  pyconfig.initialize(argv)
  config = pyconfig.config
  run(config)


if __name__ == "__main__":
app.run(main)















45 changes: 13 additions & 32 deletions src/maxdiffusion/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,46 +251,24 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_typ

return parallelism_vals


def create_device_mesh(config, devices=None, logging=True):
def create_device_mesh(config, devices=None):
"""Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas"""
if devices is None:
devices = jax.devices()
num_devices = len(devices)
try:
num_slices = 1 + max([d.slice_index for d in devices])
except:
num_slices = 1
num_slices = 1
num_devices_per_slice = num_devices // num_slices
max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")

multi_slice_env = num_slices > 1

dcn_parallelism = [
config.dcn_data_parallelism,
config.dcn_fsdp_parallelism,
config.dcn_tensor_parallelism,
]
ici_parallelism = [
config.ici_data_parallelism,
config.ici_fsdp_parallelism,
config.ici_tensor_parallelism,
]

# Find possible unspecified parallelisms
ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
if multi_slice_env:
dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
else:
mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)

if logging:
max_logging.log(f"Decided on mesh: {mesh}")
ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
mesh = mesh_utils.create_device_mesh(
ici_parallelism,
devices,
)
max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")

return mesh


def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState):
"""Unboxes the flax.LogicallyPartitioned pieces in a train state.

Expand Down Expand Up @@ -402,7 +380,10 @@ def setup_initial_state(
config.enable_single_replica_ckpt_restoring,
)
if state:
state = state[checkpoint_item]
if checkpoint_item == " ":
state = state
else:
state = state[checkpoint_item]
if not state:
max_logging.log(f"Could not find the item in orbax, creating state...")
init_train_state_partial = functools.partial(
Expand Down Expand Up @@ -609,4 +590,4 @@ def maybe_initialize_jax_distributed_system(raw_keys):
initialize_jax_for_gpu()
max_logging.log("Jax distributed system initialized on GPU!")
else:
jax.distributed.initialize()
jax.distributed.initialize()
4 changes: 2 additions & 2 deletions src/maxdiffusion/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from typing import TYPE_CHECKING

from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
from maxdiffusion.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available


_import_structure = {}
Expand All @@ -32,7 +32,7 @@
from .vae_flax import FlaxAutoencoderKL
from .lora import *
from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel

from .ltx_video.transformers.transformer3d import Transformer3DModel
else:
import sys

Expand Down
2 changes: 1 addition & 1 deletion src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1188,4 +1188,4 @@ def setup(self):
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.proj(hidden_states)
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
Empty file.
Loading
Loading