@@ -28,50 +28,6 @@ def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise
     print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype)
     print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype)
 
-
-def loop_body(
-    step,
-    args,
-    transformer,
-    fractional_cords,
-    prompt_embeds,
-    segment_ids,
-    encoder_attention_segment_ids,
-):
-    latents, state, noise_cond = args
-    noise_pred = transformer.apply(
-        {"params": state.params},
-        hidden_states=latents,
-        indices_grid=fractional_cords,
-        encoder_hidden_states=prompt_embeds,
-        timestep=noise_cond,
-        segment_ids=segment_ids,
-        encoder_attention_segment_ids=encoder_attention_segment_ids,
-    )
-    import pdb; pdb.set_trace()
-    # TODO: latents should be updated from noise_pred here, but that needs
-    # the scheduler; return noise_pred for now.
-    return noise_pred, state, noise_cond
-
-
-def run_inference(
-    states, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, segment_ids, encoder_attention_segment_ids
-):
-    transformer_state = states["transformer"]
-    loop_body_p = functools.partial(
-        loop_body,
-        transformer=transformer,
-        fractional_cords=fractional_cords,
-        prompt_embeds=prompt_embeds,
-        segment_ids=segment_ids,
-        encoder_attention_segment_ids=encoder_attention_segment_ids,
-    )
-    ## TODO: add vae decode step
-    ## TODO: add the denoising loop over all timesteps
-    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-        latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep))
-    return latents
-
-
 def run(config):
     key = jax.random.PRNGKey(0)
 
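
The deleted `loop_body` runs the transformer once per `fori_loop` step but leaves the latent update as a TODO, since it needs a scheduler. Here is a minimal sketch of how that loop could close over a scheduler step, assuming a linear sigma schedule and a plain Euler update (both are placeholders, not the scheduler the LTX-Video pipeline actually uses, and `apply_fn` stands in for the full `transformer.apply` call):

```python
import jax
import jax.numpy as jnp


def denoise(apply_fn, latents, num_steps=30):
    """Euler-style denoising loop; apply_fn(latents, sigma) -> model prediction."""
    # Assumed schedule: sigmas decay linearly from 1.0 to 0.0.
    sigmas = jnp.linspace(1.0, 0.0, num_steps + 1)

    def loop_body(step, latents):
        sigma, sigma_next = sigmas[step], sigmas[step + 1]
        pred = apply_fn(latents, sigma)  # stand-in for transformer.apply(...)
        # Euler step: move the sample along the predicted direction.
        return latents + (sigma_next - sigma) * pred

    # Only latents ride in the loop carry; the params are read-only during
    # inference, so they can be closed over instead of threaded through.
    return jax.lax.fori_loop(0, num_steps, loop_body, latents)


# Smoke test with a toy prediction function in place of the transformer.
out = denoise(lambda latents, sigma: -latents, jnp.ones((4, 256, 128)))
print(out.shape)  # (4, 256, 128)
```

This shape also sidesteps carrying `transformer_state` through the loop tuple, which the deleted code does but never updates.
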
@@ -119,92 +75,6 @@ def run(config):
     )
 
 
-
-
-    transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
-    get_memory_allocations()
-
-    states = {}
-    state_shardings = {}
-
-    state_shardings["transformer"] = transformer_state_shardings
-    states["transformer"] = transformer_state
-
-    # Create dummy inputs:
-    example_inputs = {}
-    batch_size, num_tokens = 4, 256
-    input_shapes = {
-        "latents": (batch_size, num_tokens, in_channels),
-        "fractional_coords": (batch_size, 3, num_tokens),
-        "prompt_embeds": (batch_size, 128, model_config["caption_channels"]),
-        "timestep": (batch_size, 256),  # TODO: add in the segment id stuff
-        "segment_ids": (batch_size, 256),
-        "encoder_attention_segment_ids": (batch_size, 128),
-    }
-    for name, shape in input_shapes.items():
-        example_inputs[name] = jnp.ones(
-            shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool_
-        )
-
-    data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
-    latents = jax.device_put(example_inputs["latents"], data_sharding)
-    prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding)
-    fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding)
-    noise_cond = jax.device_put(example_inputs["timestep"], data_sharding)
-    segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding)
-    encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding)
-
-    validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids)
-    p_run_inference = jax.jit(
-        functools.partial(
-            run_inference,
-            transformer=transformer,
-            config=config,
-            mesh=mesh,
-            latents=latents,
-            fractional_cords=fractional_coords,
-            prompt_embeds=prompt_embeds,
-            timestep=noise_cond,
-            segment_ids=segment_ids,
-            encoder_attention_segment_ids=encoder_attention_segment_ids,
-        ),
-        in_shardings=(state_shardings,),
-        out_shardings=None,
-    )
-    noise_pred = p_run_inference(states).block_until_ready()
-    print(noise_pred)  # shape: (4, 256, 128)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
 
 
 
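
The deleted setup builds dummy inputs, commits them to devices with `jax.device_put` under a `NamedSharding`, and jits the inference call with explicit shardings. Below is a self-contained sketch of that pattern, with an assumed single "data" mesh axis standing in for the config-driven logical axis rules:

```python
import functools

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed one-axis mesh; the real code derives axis rules from the config.
mesh = Mesh(jax.devices(), axis_names=("data",))
data_sharding = NamedSharding(mesh, P("data"))  # shard the leading batch axis

# Size the batch to the device count so the batch axis always divides evenly.
batch = jax.device_count()
latents = jax.device_put(jnp.ones((batch, 256, 128)), data_sharding)

@functools.partial(jax.jit, in_shardings=data_sharding, out_shardings=data_sharding)
def forward(latents):
    return latents * 2.0  # stand-in for the transformer forward pass

out = forward(latents).block_until_ready()
print(out.shape, out.sharding)
```

Committing inputs to devices before tracing, as the deleted block does, lets `jit` pick up their placements instead of re-sharding on every call.
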
@@ -219,12 +89,4 @@ def main(argv: Sequence[str]) -> None:
 
 
 
-    ### setup_initial_state, can optionally load from a checkpoint
-
-
-
-
-
-
 
-# end-to-end steps from the ltx repo: pipeline_ltx_video.py