
Commit b2ef6ba

added transformer_step_test
1 parent ed30ace commit b2ef6ba

1 file changed

Lines changed: 218 additions & 0 deletions

@@ -0,0 +1,218 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import functools
import json
import os
import unittest

import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp
import torch
from absl.testing import absltest
from flax.linen import partitioning as nn_partitioning
from jax.sharding import Mesh, PartitionSpec as P

from maxdiffusion import pyconfig
from maxdiffusion.max_utils import (
    create_device_mesh,
    get_memory_allocations,
    setup_initial_state,
)
from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel

THIS_DIR = os.path.dirname(os.path.abspath(__file__))

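# The reference prediction is a noise prediction saved to disk from a PyTorch
# LTX-Video pipeline run; this test re-runs the same single transformer step
# in JAX and compares the two outputs.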
def load_ref_prediction():  # TODO: change these hard-coded paths!
  # The commented-out block below would also load the saved transformer
  # inputs (latents, prompt embeds, indices grid, masks, timestep) from the
  # same PyTorch run:
  # saved_input_path = "/home/serenagu_google_com/LTX-Video/ltx_video/pipelines/transformer_test_input_data"
  # tensor_dict = torch.load(saved_input_path)
  # latent_model_input = jnp.asarray(tensor_dict["latent_model_input"].to(torch.float32))
  # prompt_embeds_batch = jnp.asarray(tensor_dict["encoder_hidden_states"].to(torch.float32))
  # fractional_coords = jnp.asarray(tensor_dict["indices_grid"].to(torch.float32))
  # prompt_attention_mask_batch = jnp.asarray(tensor_dict["encoder_attention_segment_ids"].to(torch.float32))
  # timestep = jnp.asarray(tensor_dict["timestep"].to(torch.float32))
  # segment_ids = None
  saved_prediction_path = "/home/serenagu_google_com/LTX-Video/ltx_video/pipelines/schedulerTest2.0"
  predict_dict = torch.load(saved_prediction_path)
  noise_pred_pt = predict_dict["noise_pred"].to(torch.float32)
  return noise_pred_pt

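# `loop_body` matches the `jax.lax.fori_loop` body signature once the model
# and inputs are bound via functools.partial: it takes the loop index and a
# (latents, state, noise_cond) carry and returns a carry of the same
# structure, with the model's prediction in the latents slot.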
def loop_body(
    step,
    args,
    transformer,
    fractional_cords,
    prompt_embeds,
    segment_ids,
    encoder_attention_segment_ids,
):
  latents, state, noise_cond = args
  noise_pred = transformer.apply(
      {"params": state.params},
      hidden_states=latents,
      indices_grid=fractional_cords,
      encoder_hidden_states=prompt_embeds,
      timestep=noise_cond,
      segment_ids=segment_ids,
      encoder_attention_segment_ids=encoder_attention_segment_ids,
  )
  return noise_pred, state, noise_cond

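# Runs a single transformer step (fori_loop upper bound of 1) under the
# device mesh and the config's logical axis rules; the full denoising loop
# and the VAE decode step are still TODOs below.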
def run_inference(
    states,
    transformer,
    config,
    mesh,
    latents,
    fractional_cords,
    prompt_embeds,
    timestep,
    segment_ids,
    encoder_attention_segment_ids,
):
  transformer_state = states["transformer"]
  loop_body_p = functools.partial(
      loop_body,
      transformer=transformer,
      fractional_cords=fractional_cords,
      prompt_embeds=prompt_embeds,
      segment_ids=segment_ids,
      encoder_attention_segment_ids=encoder_attention_segment_ids,
  )
  ## TODO: add vae decode step
  ## TODO: add loop
  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
    latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep))
  return latents

class LTXTransformerTest(unittest.TestCase):

  def test_one_step_transformer(self):
    pyconfig.initialize(
        [
            None,
            os.path.join(THIS_DIR, "..", "configs", "ltx_video.yml"),
        ],
        unittest=True,
    )
    config = pyconfig.config
    noise_pred_pt = load_ref_prediction()

    # Set up the transformer.
    key = jax.random.PRNGKey(0)
    devices_array = create_device_mesh(config)
    mesh = Mesh(devices_array, config.mesh_axes)
    config_path = "/home/serenagu_google_com/maxdiffusion/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json"
    with open(config_path, "r") as f:
      model_config = json.load(f)
    relative_ckpt_path = model_config["ckpt_path"]
    # Keys that are metadata or handled separately, not Transformer3DModel kwargs.
    ignored_keys = [
        "_class_name",
        "_diffusers_version",
        "_name_or_path",
        "causal_temporal_positioning",
        "in_channels",
        "ckpt_path",
    ]
    in_channels = model_config["in_channels"]
    for name in ignored_keys:
      if name in model_config:
        del model_config[name]

    transformer = Transformer3DModel(
        **model_config,
        dtype=jnp.float32,
        gradient_checkpointing="matmul_without_batch",
        sharding_mesh=mesh,
    )
    weights_init_fn = functools.partial(
        transformer.init_weights,
        in_channels,
        model_config["caption_channels"],
        eval_only=True,
    )

    absolute_ckpt_path = os.path.abspath(relative_ckpt_path)

    # Restore the transformer weights from the Orbax checkpoint.
    checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path)
    transformer_state, transformer_state_shardings = setup_initial_state(
        model=transformer,
        tx=None,
        config=config,
        mesh=mesh,
        weights_init_fn=weights_init_fn,
        checkpoint_manager=checkpoint_manager,
        checkpoint_item=" ",
        model_params=None,
        training=False,
    )

    transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
    get_memory_allocations()

    states = {}
    state_shardings = {}
    state_shardings["transformer"] = transformer_state_shardings
    states["transformer"] = transformer_state

    example_inputs = {}
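    # NOTE: the inputs below are synthetic all-ones placeholders whose shapes
    # are assumed to match the real pipeline tensors; the boolean-mask dtype
    # branch in the loop never fires because no mask keys appear in
    # input_shapes.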
    batch_size, num_tokens = 4, 256
    input_shapes = {
        "latents": (batch_size, num_tokens, in_channels),
        "fractional_coords": (batch_size, 3, num_tokens),
        "prompt_embeds": (batch_size, 128, model_config["caption_channels"]),
        "timestep": (batch_size, 256),
        "segment_ids": (batch_size, 256),
        "encoder_attention_segment_ids": (batch_size, 128),
    }
    for name, shape in input_shapes.items():
      example_inputs[name] = jnp.ones(
          shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool_
      )

    # Shard the example inputs across the data axis of the mesh.
    data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
    latents = jax.device_put(example_inputs["latents"], data_sharding)
    prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding)
    fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding)
    noise_cond = jax.device_put(example_inputs["timestep"], data_sharding)
    segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding)
    encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding)

    # jit with all model inputs baked in; only the state dict is a traced
    # argument, sharded according to state_shardings.
    p_run_inference = jax.jit(
        functools.partial(
            run_inference,
            transformer=transformer,
            config=config,
            mesh=mesh,
            latents=latents,
            fractional_cords=fractional_coords,
            prompt_embeds=prompt_embeds,
            timestep=noise_cond,
            segment_ids=segment_ids,
            encoder_attention_segment_ids=encoder_attention_segment_ids,
        ),
        in_shardings=(state_shardings,),
        out_shardings=None,
    )
    noise_pred = p_run_inference(states).block_until_ready()
    noise_pred = torch.from_numpy(np.array(noise_pred))

    torch.testing.assert_close(noise_pred_pt, noise_pred, atol=0.025, rtol=20)

if __name__ == "__main__":
  absltest.main()
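
Assuming the new file is saved as transformer_step_test.py (the commit message names it transformer_step_test; the exact path is not shown in the diff), the test can be run directly since it uses absltest.main():

python transformer_step_test.py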
