Commit e18128c

transformer step and test

1 parent 7e098c5 commit e18128c

7 files changed

Lines changed: 402 additions & 30 deletions

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 22 additions & 2 deletions
@@ -22,12 +22,25 @@ weights_dtype: 'bfloat16'
 activations_dtype: 'bfloat16'
 
 
+run_name: ''
+output_dir: 'ltx-video-output'
+save_config_to_gcs: False
+
+#hardware
+hardware: 'tpu'
+skip_jax_distributed_system: False
+
+jax_cache_dir: ''
+weights_dtype: 'bfloat16'
+activations_dtype: 'bfloat16'
+
+
 run_name: ''
 output_dir: 'ltx-video-output'
 save_config_to_gcs: False
 
 #parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']
 logical_axis_rules: [
   ['batch', 'data'],
   ['activation_batch', ['data','fsdp']],
@@ -40,13 +53,19 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'fsdp'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert', 'tensor_transpose', 'tensor_sequence', 'sequence']]
 dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: -1
 dcn_tensor_parallelism: 1
+
 ici_data_parallelism: -1
 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
 ici_tensor_parallelism: 1
+ici_fsdp_transpose_parallelism: 1
+ici_sequence_parallelism: 1
+ici_tensor_transpose_parallelism: 1
+ici_expert_parallelism: 1
+ici_sequence_parallelism: 1
 
 
 
@@ -63,3 +82,4 @@ per_device_batch_size: 1
 compile_topology_num_slices: -1
 quantization_local_shard_count: -1
 jit_initializers: True
+enable_single_replica_ckpt_restoring: False
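
Note on the parallelism defaults: with eight mesh axes, only ici_data_parallelism is left at -1, so the full device count is absorbed by the data axis. A minimal sketch of how a single -1 entry can be resolved, assuming behavior along the lines of fill_unspecified_mesh_axes in max_utils.py (fill_auto_axis below is an illustrative name, not the repo's implementation):

import math

def fill_auto_axis(parallelism, num_devices):
  # Resolve a single -1 entry so the product of all axis sizes equals the
  # device count; every explicitly sized axis is kept as-is.
  specified = math.prod(v for v in parallelism if v != -1)
  assert num_devices % specified == 0, "specified axes must divide the device count"
  return [num_devices // specified if v == -1 else v for v in parallelism]

print(fill_auto_axis([-1, 1, 1, 1, 1, 1, 1, 1], 8))  # [8, 1, 1, 1, 1, 1, 1, 1]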

src/maxdiffusion/generate_ltx_video.py

Lines changed: 144 additions & 27 deletions
@@ -1,63 +1,180 @@
-"""
- Copyright 2025 Google LLC
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-      https://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
 from absl import app
 from typing import Sequence
 import jax
 import json
+from flax.linen import partitioning as nn_partitioning
 from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel
 import os
 import functools
 import jax.numpy as jnp
 from maxdiffusion import pyconfig
 from maxdiffusion.max_utils import (
     create_device_mesh,
+    setup_initial_state,
+    get_memory_allocations,
 )
-from jax.sharding import Mesh
+from jax.sharding import Mesh, PartitionSpec as P
+import orbax.checkpoint as ocp
 
 
-def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond):
+def validate_transformer_inputs(
+    prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids
+):
   print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)
   print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype)
   print("latents.shape: ", latents.shape, latents.dtype)
   print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
+  print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
+  print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype)
+  print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype)
+
+
+def loop_body(step, args, transformer, fractional_cords, prompt_embeds, segment_ids, encoder_attention_segment_ids):
+  latents, state, noise_cond = args
+  noise_pred = transformer.apply(
+      {"params": state.params},
+      hidden_states=latents,
+      indices_grid=fractional_cords,
+      encoder_hidden_states=prompt_embeds,
+      timestep=noise_cond,
+      segment_ids=segment_ids,
+      encoder_attention_segment_ids=encoder_attention_segment_ids,
+  )
+  return noise_pred, state, noise_cond
+
+
+def run_inference(
+    states,
+    transformer,
+    config,
+    mesh,
+    latents,
+    fractional_cords,
+    prompt_embeds,
+    timestep,
+    segment_ids,
+    encoder_attention_segment_ids,
+):
+  transformer_state = states["transformer"]
+  loop_body_p = functools.partial(
+      loop_body,
+      transformer=transformer,
+      fractional_cords=fractional_cords,
+      prompt_embeds=prompt_embeds,
+      segment_ids=segment_ids,
+      encoder_attention_segment_ids=encoder_attention_segment_ids,
+  )
+  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+    noise_pred, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep))
+  return noise_pred
 
 
 def run(config):
-  key = jax.random.PRNGKey(0)
+  key = jax.random.PRNGKey(42)
 
   devices_array = create_device_mesh(config)
-  mesh = Mesh(devices_array, config.mesh_axes)  # noqa F841
+  mesh = Mesh(devices_array, config.mesh_axes)
 
-  batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128
   base_dir = os.path.dirname(__file__)
 
-  # load in model config
+  ##load in model config
   config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json")
   with open(config_path, "r") as f:
     model_config = json.load(f)
+  relative_ckpt_path = model_config["ckpt_path"]
+
+  ignored_keys = [
+      "_class_name",
+      "_diffusers_version",
+      "_name_or_path",
+      "causal_temporal_positioning",
+      "in_channels",
+      "ckpt_path",
+  ]
+  in_channels = model_config["in_channels"]
+  for name in ignored_keys:
+    if name in model_config:
+      del model_config[name]
+
+  transformer = Transformer3DModel(
+      **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh
+  )
+  transformer_param_shapes = transformer.init_weights(in_channels, key, model_config["caption_channels"], eval_only=True)  # noqa F841
+  weights_init_fn = functools.partial(
+      transformer.init_weights, in_channels, key, model_config["caption_channels"], eval_only=True
+  )
 
-  transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch")
-  transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False)  # noqa F841
+  absolute_ckpt_path = os.path.abspath(relative_ckpt_path)
+
+  checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path)
+  transformer_state, transformer_state_shardings = setup_initial_state(
+      model=transformer,
+      tx=None,
+      config=config,
+      mesh=mesh,
+      weights_init_fn=weights_init_fn,
+      checkpoint_manager=checkpoint_manager,
+      checkpoint_item=" ",
+      model_params=None,
+      training=False,
+  )
 
-  key, split_key = jax.random.split(key)
-  weights_init_fn = functools.partial(  # noqa F841
-      transformer.init_weights, split_key, batch_size, text_tokens, num_tokens, features, eval_only=True
-  )
+  transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
+  get_memory_allocations()
+
+  states = {}
+  state_shardings = {}
+
+  state_shardings["transformer"] = transformer_state_shardings
+  states["transformer"] = transformer_state
+
+  # create dummy inputs:
+  example_inputs = {}
+  batch_size, num_tokens = 4, 256
+  input_shapes = {
+      "latents": (batch_size, num_tokens, in_channels),
+      "fractional_coords": (batch_size, 3, num_tokens),
+      "prompt_embeds": (batch_size, 128, model_config["caption_channels"]),
+      "timestep": (batch_size, 256),
+      "segment_ids": (batch_size, 256),
+      "encoder_attention_segment_ids": (batch_size, 128),
+  }
+  for name, shape in input_shapes.items():
+    example_inputs[name] = jnp.ones(
+        shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool
+    )
+
+  data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
+  latents = jax.device_put(example_inputs["latents"], data_sharding)
+  prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding)
+  fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding)
+  noise_cond = jax.device_put(example_inputs["timestep"], data_sharding)
+  segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding)
+  encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding)
+
+  validate_transformer_inputs(
+      prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids
+  )
+  p_run_inference = jax.jit(
+      functools.partial(
+          run_inference,
+          transformer=transformer,
+          config=config,
+          mesh=mesh,
+          latents=latents,
+          fractional_cords=fractional_coords,
+          prompt_embeds=prompt_embeds,
+          timestep=noise_cond,
+          segment_ids=segment_ids,
+          encoder_attention_segment_ids=encoder_attention_segment_ids,
+      ),
+      in_shardings=(state_shardings,),
+      out_shardings=None,
   )
 
+  noise_pred = p_run_inference(states).block_until_ready()
+  print(noise_pred)  # (4, 256, 128)
+
 
 def main(argv: Sequence[str]) -> None:
   pyconfig.initialize(argv)
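
Note on the loop structure: run_inference binds every static argument with functools.partial and threads (latents, transformer_state, timestep) through jax.lax.fori_loop, with each step's noise_pred replacing the latents slot of the carry. A self-contained sketch of that carry pattern, with a toy computation standing in for transformer.apply:

import functools
import jax
import jax.numpy as jnp

def loop_body(step, carry, scale):
  # Stand-in for the real body: transformer.apply(...) produces noise_pred,
  # which takes the latents slot of the carry for the next iteration.
  latents, noise_cond = carry
  noise_pred = latents * scale - noise_cond[:, None]
  return noise_pred, noise_cond

latents = jnp.ones((4, 8))
noise_cond = jnp.zeros((4,))
loop_body_p = functools.partial(loop_body, scale=0.5)
latents, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, noise_cond))
print(latents.shape)  # (4, 8)

With an upper bound of 1 the loop executes exactly once, which matches the single transformer step this commit tests.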

src/maxdiffusion/max_utils.py

Lines changed: 20 additions & 1 deletion
@@ -257,6 +257,21 @@ def create_device_mesh(config, devices=None, logging=True):
   if devices is None:
     devices = jax.devices()
   num_devices = len(devices)
+  ##special case for ltx-video
+  if config.ici_fsdp_transpose_parallelism:
+    num_slices = 1
+    # if config.inference_benchmark_test else config.num_slices
+    num_devices_per_slice = num_devices // num_slices
+    # Find possible unspecified parallelisms
+    ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
+    mesh = mesh_utils.create_device_mesh(
+        ici_parallelism,
+        devices,
+    )
+    max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
+
+    return mesh
+
   try:
     num_slices = 1 + max([d.slice_index for d in devices])
   except:
@@ -402,7 +417,11 @@ def setup_initial_state(
         config.enable_single_replica_ckpt_restoring,
     )
     if state:
-      state = state[checkpoint_item]
+      ###!Edited
+      if checkpoint_item == " ":
+        state = state
+      else:
+        state = state[checkpoint_item]
     if not state:
       max_logging.log(f"Could not find the item in orbax, creating state...")
       init_train_state_partial = functools.partial(
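
Note on the mesh construction: the new ltx-video branch hands the flattened ici_parallelism list straight to mesh_utils.create_device_mesh and skips the slice-detection path. A short sketch of the same call, assuming all devices land on the data axis as the shipped config implies:

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Axis names match mesh_axes in ltx_video.yml; every device sits on 'data'.
axes = ('data', 'fsdp', 'tensor', 'fsdp_transpose', 'expert',
        'tensor_transpose', 'tensor_sequence', 'sequence')
ici_parallelism = [len(jax.devices()), 1, 1, 1, 1, 1, 1, 1]
devices = mesh_utils.create_device_mesh(ici_parallelism, jax.devices())
mesh = Mesh(devices, axes)
print(mesh.shape)  # e.g. {'data': 8, 'fsdp': 1, ...} on an 8-device slice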

src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 {
+  "ckpt_path": "",
   "activation_fn": "gelu-approximate",
   "attention_bias": true,
   "attention_head_dim": 128,

src/maxdiffusion/pyconfig.py

Lines changed: 17 additions & 0 deletions
@@ -41,6 +41,21 @@ def string_to_bool(s: str) -> bool:
 config = None
 
 
+def create_parallelisms_list(raw_keys):
+  ici_parallelism = [
+      raw_keys["ici_data_parallelism"],
+      raw_keys["ici_fsdp_parallelism"],
+      raw_keys["ici_fsdp_transpose_parallelism"],
+      raw_keys["ici_sequence_parallelism"],
+      raw_keys["ici_tensor_parallelism"],
+      raw_keys["ici_tensor_transpose_parallelism"],
+      raw_keys["ici_expert_parallelism"],
+      raw_keys["ici_sequence_parallelism"],
+  ]
+  raw_keys["ici_parallelism"] = ici_parallelism
+  return raw_keys
+
+
 def print_system_information():
   max_logging.log(f"System Information: Jax Version: {jax.__version__}")
   max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}")
@@ -154,6 +169,8 @@ def user_init(raw_keys):
   raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"])
   raw_keys["num_slices"] = get_num_slices(raw_keys)
   raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)
+  if "ici_fsdp_transpose_parallelism" in raw_keys:
+    raw_keys = create_parallelisms_list(raw_keys)
 
 
 def get_num_slices(raw_keys):
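
Note on the ordering: create_parallelisms_list enumerates the eight ici_* values in a hard-coded order, and that order has to line up position-for-position with mesh_axes for each mesh axis to receive the intended size; as committed, the list repeats ici_sequence_parallelism where mesh_axes names tensor_sequence. A sketch of one way to derive the list from mesh_axes by key lookup instead (illustrative only; it would require an ici_tensor_sequence_parallelism entry that this commit does not define):

def create_parallelisms_list_by_axis(raw_keys, mesh_axes):
  # Build the list by key lookup so its order cannot drift from the mesh
  # definition; each axis needs a matching ici_<axis>_parallelism key.
  raw_keys["ici_parallelism"] = [raw_keys[f"ici_{axis}_parallelism"] for axis in mesh_axes]
  return raw_keys

# e.g. create_parallelisms_list_by_axis(raw_keys, raw_keys["mesh_axes"])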
