1 | | -""" |
2 | | - Copyright 2025 Google LLC |
3 | | -
4 | | - Licensed under the Apache License, Version 2.0 (the "License"); |
5 | | - you may not use this file except in compliance with the License. |
6 | | - You may obtain a copy of the License at |
7 | | -
8 | | - https://www.apache.org/licenses/LICENSE-2.0 |
9 | | -
10 | | - Unless required by applicable law or agreed to in writing, software |
11 | | - distributed under the License is distributed on an "AS IS" BASIS, |
12 | | - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | | - See the License for the specific language governing permissions and |
14 | | - limitations under the License. |
15 | | -""" |
16 | | - |
17 | 1 | from absl import app |
18 | 2 | from typing import Sequence |
19 | 3 | import jax |
20 | 4 | import json |
| 5 | +from flax.linen import partitioning as nn_partitioning |
21 | 6 | from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel |
22 | 7 | import os |
23 | 8 | import functools |
24 | 9 | import jax.numpy as jnp |
25 | 10 | from maxdiffusion import pyconfig |
26 | 11 | from maxdiffusion.max_utils import ( |
27 | 12 | create_device_mesh, |
| 13 | + setup_initial_state, |
| 14 | + get_memory_allocations, |
28 | 15 | ) |
29 | | -from jax.sharding import Mesh |
| 16 | +from jax.sharding import Mesh, PartitionSpec as P |
| 17 | +import orbax.checkpoint as ocp |
30 | 18 |
31 | 19 |
32 | | -def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond): |
| 20 | +def validate_transformer_inputs( |
| 21 | + prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids |
| 22 | +): |
33 | 23 | print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) |
34 | 24 | print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) |
35 | 25 | print("latents.shape: ", latents.shape, latents.dtype) |
36 | 26 | print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) |
| 27 | + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) |
| 28 | + print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype) |
| 29 | + print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) |
| 30 | + |
| 31 | + |
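| | +# Body for jax.lax.fori_loop: one transformer forward pass per step.
| | +# The predicted noise takes the latents slot in the carried tuple.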
| 32 | +def loop_body(step, args, transformer, fractional_coords, prompt_embeds, segment_ids, encoder_attention_segment_ids):
| 33 | + latents, state, noise_cond = args |
| 34 | + noise_pred = transformer.apply( |
| 35 | + {"params": state.params}, |
| 36 | + hidden_states=latents, |
| 37 | + indices_grid=fractional_coords,
| 38 | + encoder_hidden_states=prompt_embeds, |
| 39 | + timestep=noise_cond, |
| 40 | + segment_ids=segment_ids, |
| 41 | + encoder_attention_segment_ids=encoder_attention_segment_ids, |
| 42 | + ) |
| 43 | + return noise_pred, state, noise_cond |
| 44 | + |
| 45 | + |
| 46 | +def run_inference( |
| 47 | + states, |
| 48 | + transformer, |
| 49 | + config, |
| 50 | + mesh, |
| 51 | + latents, |
| 52 | + fractional_coords,
| 53 | + prompt_embeds, |
| 54 | + timestep, |
| 55 | + segment_ids, |
| 56 | + encoder_attention_segment_ids, |
| 57 | +): |
| 58 | + transformer_state = states["transformer"] |
| 59 | + loop_body_p = functools.partial( |
| 60 | + loop_body, |
| 61 | + transformer=transformer, |
| 62 | + fractional_coords=fractional_coords,
| 63 | + prompt_embeds=prompt_embeds, |
| 64 | + segment_ids=segment_ids, |
| 65 | + encoder_attention_segment_ids=encoder_attention_segment_ids, |
| 66 | + ) |
| 67 | + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): |
| 68 | + noise_pred, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep)) |
| 69 | + return noise_pred |
37 | 70 |
38 | 71 |
39 | 72 | def run(config): |
40 | | - key = jax.random.PRNGKey(0) |
| 73 | + key = jax.random.PRNGKey(42) |
41 | 74 |
42 | 75 | devices_array = create_device_mesh(config) |
43 | | - mesh = Mesh(devices_array, config.mesh_axes) # noqa F841 |
| 76 | + mesh = Mesh(devices_array, config.mesh_axes) |
44 | 77 |
45 | | - batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128 |
46 | 78 | base_dir = os.path.dirname(__file__) |
47 | 79 |
48 | | - # load in model config |
| 80 | + # load in model config
49 | 81 | config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") |
50 | 82 | with open(config_path, "r") as f: |
51 | 83 | model_config = json.load(f) |
| 84 | + relative_ckpt_path = model_config["ckpt_path"] |
| 85 | + |
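| | +# Pull out values consumed separately (ckpt_path, in_channels) and drop the
| | +# metadata keys that the Transformer3DModel constructor does not accept.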
| 86 | + ignored_keys = [ |
| 87 | + "_class_name", |
| 88 | + "_diffusers_version", |
| 89 | + "_name_or_path", |
| 90 | + "causal_temporal_positioning", |
| 91 | + "in_channels", |
| 92 | + "ckpt_path", |
| 93 | + ] |
| 94 | + in_channels = model_config["in_channels"] |
| 95 | + for name in ignored_keys: |
| 96 | + if name in model_config: |
| 97 | + del model_config[name] |
| 98 | + |
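| | +# Build the model with the "matmul_without_batch" remat policy and hand it
| | +# the mesh so internal layers can apply their sharding constraints.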
| 99 | + transformer = Transformer3DModel( |
| 100 | + **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh |
| 101 | + ) |
| 102 | + transformer_param_shapes = transformer.init_weights(in_channels, key, model_config["caption_channels"], eval_only=True) # noqa F841 |
| 103 | + weights_init_fn = functools.partial( |
| 104 | + transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True |
| 105 | + ) |
52 | 106 |
53 | | - transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") |
54 | | - transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False) # noqa F841 |
| 107 | + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) |
| 108 | + |
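| | +# Restore parameters via Orbax; setup_initial_state returns the state along
| | +# with its shardings so the weights land directly on the mesh.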
| 109 | + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) |
| 110 | + transformer_state, transformer_state_shardings = setup_initial_state( |
| 111 | + model=transformer, |
| 112 | + tx=None, |
| 113 | + config=config, |
| 114 | + mesh=mesh, |
| 115 | + weights_init_fn=weights_init_fn, |
| 116 | + checkpoint_manager=checkpoint_manager, |
| 117 | + checkpoint_item="transformer",
| 118 | + model_params=None, |
| 119 | + training=False, |
| 120 | + ) |
55 | 121 |
56 | | - key, split_key = jax.random.split(key) |
57 | | - weights_init_fn = functools.partial( # noqa F841 |
58 | | - transformer.init_weights, split_key, batch_size, text_tokens, num_tokens, features, eval_only=True |
| 122 | + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) |
| 123 | + get_memory_allocations() |
| 124 | + |
| 125 | + states = {} |
| 126 | + state_shardings = {} |
| 127 | + |
| 128 | + state_shardings["transformer"] = transformer_state_shardings |
| 129 | + states["transformer"] = transformer_state |
| 130 | + |
| 131 | + # create dummy inputs: |
| 132 | + example_inputs = {} |
| 133 | + batch_size, num_tokens = 4, 256 |
| 134 | + input_shapes = { |
| 135 | + "latents": (batch_size, num_tokens, in_channels), |
| 136 | + "fractional_coords": (batch_size, 3, num_tokens), |
| 137 | + "prompt_embeds": (batch_size, 128, model_config["caption_channels"]), |
| 138 | + "timestep": (batch_size, 256), |
| 139 | + "segment_ids": (batch_size, 256), |
| 140 | + "encoder_attention_segment_ids": (batch_size, 128), |
| 141 | + } |
| 142 | + for name, shape in input_shapes.items(): |
| 143 | + example_inputs[name] = jnp.ones( |
| 144 | + shape, dtype=jnp.float32 if name not in ["segment_ids", "encoder_attention_segment_ids"] else jnp.int32
| 145 | + ) |
| 146 | + |
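| | +# Place every input according to the config's data-parallel sharding spec.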
| 147 | + data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) |
| 148 | + latents = jax.device_put(example_inputs["latents"], data_sharding) |
| 149 | + prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding) |
| 150 | + fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding) |
| 151 | + noise_cond = jax.device_put(example_inputs["timestep"], data_sharding) |
| 152 | + segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding) |
| 153 | + encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding) |
| 154 | + |
| 155 | + validate_transformer_inputs( |
| 156 | + prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids |
| 157 | + ) |
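| | +# JIT with the model, mesh, and all inputs closed over via functools.partial;
| | +# only the train states remain traced arguments, matching in_shardings.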
| 158 | + p_run_inference = jax.jit( |
| 159 | + functools.partial( |
| 160 | + run_inference, |
| 161 | + transformer=transformer, |
| 162 | + config=config, |
| 163 | + mesh=mesh, |
| 164 | + latents=latents, |
| 165 | + fractional_coords=fractional_coords,
| 166 | + prompt_embeds=prompt_embeds, |
| 167 | + timestep=noise_cond, |
| 168 | + segment_ids=segment_ids, |
| 169 | + encoder_attention_segment_ids=encoder_attention_segment_ids, |
| 170 | + ), |
| 171 | + in_shardings=(state_shardings,), |
| 172 | + out_shardings=None, |
59 | 173 | ) |
60 | 174 |
| 175 | + noise_pred = p_run_inference(states).block_until_ready() |
| 176 | + print(noise_pred.shape) # (4, 256, 128)
| 177 | + |
61 | 178 |
62 | 179 | def main(argv: Sequence[str]) -> None: |
63 | 180 | pyconfig.initialize(argv) |