1+ """
2+ Copyright 2025 Google LLC
3+
4+ Licensed under the Apache License, Version 2.0 (the "License");
5+ you may not use this file except in compliance with the License.
6+ You may obtain a copy of the License at
7+
8+ https://www.apache.org/licenses/LICENSE-2.0
9+
10+ Unless required by applicable law or agreed to in writing, software
11+ distributed under the License is distributed on an "AS IS" BASIS,
12+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ See the License for the specific language governing permissions and
14+ limitations under the License.
15+ """
16+
import os
import json
import functools
import unittest

import numpy as np
import torch
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from absl.testing import absltest
from flax.linen import partitioning as nn_partitioning
import orbax.checkpoint as ocp

from maxdiffusion import pyconfig
from maxdiffusion.max_utils import (
    create_device_mesh,
    setup_initial_state,
    get_memory_allocations,
)
from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel

THIS_DIR = os.path.dirname(os.path.abspath(__file__))

def load_ref_prediction():  # TODO: change these paths!
  # saved_input_path = "/home/serenagu_google_com/LTX-Video/ltx_video/pipelines/transformer_test_input_data"
  # tensor_dict = torch.load(saved_input_path)
  # latent_model_input = jnp.asarray(tensor_dict["latent_model_input"].to(torch.float32))
  # prompt_embeds_batch = jnp.asarray(tensor_dict["encoder_hidden_states"].to(torch.float32))
  # fractional_coords = jnp.asarray(tensor_dict["indices_grid"].to(torch.float32))
  # prompt_attention_mask_batch = jnp.asarray(tensor_dict["encoder_attention_segment_ids"].to(torch.float32))
  # timestep = jnp.asarray(tensor_dict["timestep"].to(torch.float32))
  # segment_ids = None
  saved_prediction_path = "/home/serenagu_google_com/LTX-Video/ltx_video/pipelines/schedulerTest2.0"
  predict_dict = torch.load(saved_prediction_path)
  noise_pred_pt = predict_dict["noise_pred"].to(torch.float32)
  return noise_pred_pt

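# loop_body matches the jax.lax.fori_loop body signature (step, carry) -> carry:
# the carry is (latents, transformer_state, noise_cond), and every other argument
# is bound ahead of time with functools.partial in run_inference below.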
def loop_body(
    step,
    args,
    transformer,
    fractional_coords,
    prompt_embeds,
    segment_ids,
    encoder_attention_segment_ids,
):
  latents, state, noise_cond = args
  noise_pred = transformer.apply(
      {"params": state.params},
      hidden_states=latents,
      indices_grid=fractional_coords,
      encoder_hidden_states=prompt_embeds,
      timestep=noise_cond,
      segment_ids=segment_ids,
      encoder_attention_segment_ids=encoder_attention_segment_ids,
  )
  return noise_pred, state, noise_cond

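# run_inference closes over everything except the state dict; it runs a single
# fori_loop step under the mesh and logical axis rules (see the TODOs in the body
# for the full sampling loop and VAE decode).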
def run_inference(
    states, transformer, config, mesh, latents, fractional_coords, prompt_embeds, timestep, segment_ids, encoder_attention_segment_ids
):
  transformer_state = states["transformer"]
  loop_body_p = functools.partial(
      loop_body,
      transformer=transformer,
      fractional_coords=fractional_coords,
      prompt_embeds=prompt_embeds,
      segment_ids=segment_ids,
      encoder_attention_segment_ids=encoder_attention_segment_ids,
  )
  # TODO: add vae decode step
  # TODO: add loop
  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
    latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep))
  return latents

class LTXTransformerTest(unittest.TestCase):

  def test_one_step_transformer(self):
    pyconfig.initialize(
        [
            None,
            os.path.join(THIS_DIR, "..", "configs", "ltx_video.yml"),
        ],
        unittest=True,
    )
    config = pyconfig.config
    noise_pred_pt = load_ref_prediction()

    # Set up the transformer.
    key = jax.random.PRNGKey(0)
    devices_array = create_device_mesh(config)
    mesh = Mesh(devices_array, config.mesh_axes)
    config_path = "/home/serenagu_google_com/maxdiffusion/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json"
    with open(config_path, "r") as f:
      model_config = json.load(f)
    relative_ckpt_path = model_config["ckpt_path"]
    ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", "causal_temporal_positioning", "in_channels", "ckpt_path"]
    in_channels = model_config["in_channels"]
    for name in ignored_keys:
      if name in model_config:
        del model_config[name]

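    # The remaining config keys are forwarded directly to the Transformer3DModel
    # constructor; eval_only=True below is assumed to skip training-only state
    # when initializing weights for this inference-only test.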
    transformer = Transformer3DModel(
        **model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh
    )
    weights_init_fn = functools.partial(
        transformer.init_weights,
        in_channels,
        model_config["caption_channels"],
        eval_only=True,
    )

    absolute_ckpt_path = os.path.abspath(relative_ckpt_path)

    checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path)
    transformer_state, transformer_state_shardings = setup_initial_state(
        model=transformer,
        tx=None,
        config=config,
        mesh=mesh,
        weights_init_fn=weights_init_fn,
        checkpoint_manager=checkpoint_manager,
        checkpoint_item=" ",
        model_params=None,
        training=False,
    )

    transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
    get_memory_allocations()

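    # Bundle the restored state and its shardings into dicts keyed by module name;
    # run_inference and the jit in_shardings below look them up under "transformer".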
    states = {}
    state_shardings = {}

    state_shardings["transformer"] = transformer_state_shardings
    states["transformer"] = transformer_state

    # Dummy all-ones inputs with the shapes the transformer expects.
    example_inputs = {}
    batch_size, num_tokens = 4, 256
    input_shapes = {
        "latents": (batch_size, num_tokens, in_channels),
        "fractional_coords": (batch_size, 3, num_tokens),
        "prompt_embeds": (batch_size, 128, model_config["caption_channels"]),
        "timestep": (batch_size, 256),
        "segment_ids": (batch_size, 256),
        "encoder_attention_segment_ids": (batch_size, 128),
    }
    for name, shape in input_shapes.items():
      example_inputs[name] = jnp.ones(
          shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool_
      )

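    # Shard the example inputs across the mesh using config.data_sharding before
    # they are closed over by the jitted inference function.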
    data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
    latents = jax.device_put(example_inputs["latents"], data_sharding)
    prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding)
    fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding)
    noise_cond = jax.device_put(example_inputs["timestep"], data_sharding)
    segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding)
    encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding)

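    # Only the state dict is passed at call time; everything else is bound via
    # functools.partial, and in_shardings tells jit how that state is laid out
    # across the mesh.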
    p_run_inference = jax.jit(
        functools.partial(
            run_inference,
            transformer=transformer,
            config=config,
            mesh=mesh,
            latents=latents,
            fractional_coords=fractional_coords,
            prompt_embeds=prompt_embeds,
            timestep=noise_cond,
            segment_ids=segment_ids,
            encoder_attention_segment_ids=encoder_attention_segment_ids,
        ),
        in_shardings=(state_shardings,),
        out_shardings=None,
    )
    noise_pred = p_run_inference(states).block_until_ready()
    noise_pred = torch.from_numpy(np.array(noise_pred))

    torch.testing.assert_close(noise_pred_pt, noise_pred, atol=0.025, rtol=20)

if __name__ == "__main__":
  absltest.main()