1- from json import encoder
21from absl import app
32from typing import Sequence
43import jax
5- from flax import linen as nn
64import json
75from flax .linen import partitioning as nn_partitioning
86from maxdiffusion .models .ltx_video .transformers .transformer3d import Transformer3DModel
1917import orbax .checkpoint as ocp
2018
2119
22- def validate_transformer_inputs (prompt_embeds , fractional_coords , latents , noise_cond , segment_ids , encoder_attention_segment_ids ):
20+ def validate_transformer_inputs (
21+ prompt_embeds , fractional_coords , latents , noise_cond , segment_ids , encoder_attention_segment_ids
22+ ):
2323 print ("prompts_embeds.shape: " , prompt_embeds .shape , prompt_embeds .dtype )
2424 print ("fractional_coords.shape: " , fractional_coords .shape , fractional_coords .dtype )
2525 print ("latents.shape: " , latents .shape , latents .dtype )
@@ -29,15 +29,7 @@ def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise
2929 print ("encoder_attention_segment_ids.shape: " , encoder_attention_segment_ids .shape , encoder_attention_segment_ids .dtype )
3030
3131
32- def loop_body (
33- step ,
34- args ,
35- transformer ,
36- fractional_cords ,
37- prompt_embeds ,
38- segment_ids ,
39- encoder_attention_segment_ids
40- ):
32+ def loop_body (step , args , transformer , fractional_cords , prompt_embeds , segment_ids , encoder_attention_segment_ids ):
4133 latents , state , noise_cond = args
4234 noise_pred = transformer .apply (
4335 {"params" : state .params },
@@ -46,14 +38,22 @@ def loop_body(
4638 encoder_hidden_states = prompt_embeds ,
4739 timestep = noise_cond ,
4840 segment_ids = segment_ids ,
49- encoder_attention_segment_ids = encoder_attention_segment_ids
41+ encoder_attention_segment_ids = encoder_attention_segment_ids ,
5042 )
51- return noise_pred , state , noise_cond
52-
43+ return noise_pred , state , noise_cond
5344
5445
5546def run_inference (
56- states , transformer , config , mesh , latents , fractional_cords , prompt_embeds , timestep , segment_ids , encoder_attention_segment_ids
47+ states ,
48+ transformer ,
49+ config ,
50+ mesh ,
51+ latents ,
52+ fractional_cords ,
53+ prompt_embeds ,
54+ timestep ,
55+ segment_ids ,
56+ encoder_attention_segment_ids ,
5757):
5858 transformer_state = states ["transformer" ]
5959 loop_body_p = functools .partial (
@@ -62,20 +62,19 @@ def run_inference(
6262 fractional_cords = fractional_cords ,
6363 prompt_embeds = prompt_embeds ,
6464 segment_ids = segment_ids ,
65- encoder_attention_segment_ids = encoder_attention_segment_ids
65+ encoder_attention_segment_ids = encoder_attention_segment_ids ,
6666 )
67- ## TODO: add vae decode step
68- ## TODO: add loop
6967 with mesh , nn_partitioning .axis_rules (config .logical_axis_rules ):
70- latents , transformer_state , _ = jax .lax .fori_loop (0 , 1 , loop_body_p , (latents , transformer_state , timestep ))
71- return latents
72-
68+ noise_pred , transformer_state , _ = jax .lax .fori_loop (0 , 1 , loop_body_p , (latents , transformer_state , timestep ))
69+ return noise_pred
70+
71+
7372def run (config ):
74- key = jax .random .PRNGKey (0 )
73+ key = jax .random .PRNGKey (42 )
7574
76- devices_array = create_device_mesh (config )
75+ devices_array = create_device_mesh (config )
7776 mesh = Mesh (devices_array , config .mesh_axes )
78-
77+
7978 base_dir = os .path .dirname (__file__ )
8079
8180 ##load in model config
@@ -84,41 +83,42 @@ def run(config):
8483 model_config = json .load (f )
8584 relative_ckpt_path = model_config ["ckpt_path" ]
8685
87- ignored_keys = ["_class_name" , "_diffusers_version" , "_name_or_path" , "causal_temporal_positioning" , "in_channels" , "ckpt_path" ]
86+ ignored_keys = [
87+ "_class_name" ,
88+ "_diffusers_version" ,
89+ "_name_or_path" ,
90+ "causal_temporal_positioning" ,
91+ "in_channels" ,
92+ "ckpt_path" ,
93+ ]
8894 in_channels = model_config ["in_channels" ]
8995 for name in ignored_keys :
9096 if name in model_config :
9197 del model_config [name ]
92-
93-
94- transformer = Transformer3DModel ( ** model_config , dtype = jnp .float32 , gradient_checkpointing = "matmul_without_batch" , sharding_mesh = mesh )
95- transformer_param_shapes = transformer . init_weights ( in_channels , model_config [ 'caption_channels' ], eval_only = True )
96-
98+
99+ transformer = Transformer3DModel (
100+ ** model_config , dtype = jnp .float32 , gradient_checkpointing = "matmul_without_batch" , sharding_mesh = mesh
101+ )
102+ transformer_param_shapes = transformer . init_weights ( in_channels , key , model_config [ "caption_channels" ], eval_only = True ) # noqa F841
97103 weights_init_fn = functools .partial (
98- transformer .init_weights ,
99- in_channels ,
100- model_config ['caption_channels' ],
101- eval_only = True
104+ transformer .init_weights , in_channels , key , model_config ["caption_channels" ], eval_only = True
102105 )
103106
104107 absolute_ckpt_path = os .path .abspath (relative_ckpt_path )
105108
106109 checkpoint_manager = ocp .CheckpointManager (absolute_ckpt_path )
107110 transformer_state , transformer_state_shardings = setup_initial_state (
108- model = transformer ,
109- tx = None ,
110- config = config ,
111- mesh = mesh ,
112- weights_init_fn = weights_init_fn ,
113- checkpoint_manager = checkpoint_manager ,
114- checkpoint_item = " " ,
115- model_params = None ,
116- training = False ,
111+ model = transformer ,
112+ tx = None ,
113+ config = config ,
114+ mesh = mesh ,
115+ weights_init_fn = weights_init_fn ,
116+ checkpoint_manager = checkpoint_manager ,
117+ checkpoint_item = " " ,
118+ model_params = None ,
119+ training = False ,
117120 )
118121
119-
120-
121-
122122 transformer_state = jax .device_put (transformer_state , transformer_state_shardings )
123123 get_memory_allocations ()
124124
@@ -128,20 +128,20 @@ def run(config):
128128 state_shardings ["transformer" ] = transformer_state_shardings
129129 states ["transformer" ] = transformer_state
130130
131- #create dummy inputs:
131+ # create dummy inputs:
132132 example_inputs = {}
133133 batch_size , num_tokens = 4 , 256
134134 input_shapes = {
135- "latents" : (batch_size , num_tokens , in_channels ),
136- "fractional_coords" : (batch_size , 3 , num_tokens ),
137- "prompt_embeds" : (batch_size , 128 , model_config ["caption_channels" ]),
138- "timestep" : (batch_size , 256 ),
139- "segment_ids" : (batch_size , 256 ),
140- "encoder_attention_segment_ids" : (batch_size , 128 ),
135+ "latents" : (batch_size , num_tokens , in_channels ),
136+ "fractional_coords" : (batch_size , 3 , num_tokens ),
137+ "prompt_embeds" : (batch_size , 128 , model_config ["caption_channels" ]),
138+ "timestep" : (batch_size , 256 ),
139+ "segment_ids" : (batch_size , 256 ),
140+ "encoder_attention_segment_ids" : (batch_size , 128 ),
141141 }
142142 for name , shape in input_shapes .items ():
143143 example_inputs [name ] = jnp .ones (
144- shape , dtype = jnp .float32 if name not in ["attention_mask" , "encoder_attention_mask" ] else jnp .bool
144+ shape , dtype = jnp .float32 if name not in ["attention_mask" , "encoder_attention_mask" ] else jnp .bool
145145 )
146146
147147 data_sharding = jax .sharding .NamedSharding (mesh , P (* config .data_sharding ))
@@ -152,7 +152,9 @@ def run(config):
152152 segment_ids = jax .device_put (example_inputs ["segment_ids" ], data_sharding )
153153 encoder_attention_segment_ids = jax .device_put (example_inputs ["encoder_attention_segment_ids" ], data_sharding )
154154
155- validate_transformer_inputs (prompt_embeds , fractional_coords , latents , noise_cond , segment_ids , encoder_attention_segment_ids )
155+ validate_transformer_inputs (
156+ prompt_embeds , fractional_coords , latents , noise_cond , segment_ids , encoder_attention_segment_ids
157+ )
156158 p_run_inference = jax .jit (
157159 functools .partial (
158160 run_inference ,
@@ -162,16 +164,16 @@ def run(config):
162164 latents = latents ,
163165 fractional_cords = fractional_coords ,
164166 prompt_embeds = prompt_embeds ,
165- timestep = noise_cond ,
167+ timestep = noise_cond ,
166168 segment_ids = segment_ids ,
167- encoder_attention_segment_ids = encoder_attention_segment_ids
169+ encoder_attention_segment_ids = encoder_attention_segment_ids ,
168170 ),
169171 in_shardings = (state_shardings ,),
170172 out_shardings = None ,
171173 )
172174
173175 noise_pred = p_run_inference (states ).block_until_ready ()
174- print (noise_pred ) #(4, 256, 128)
176+ print (noise_pred ) # (4, 256, 128)
175177
176178
177179def main (argv : Sequence [str ]) -> None :
@@ -181,18 +183,3 @@ def main(argv: Sequence[str]) -> None:
181183
182184if __name__ == "__main__" :
183185 app .run (main )
184-
185-
186-
187-
188-
189-
190-
191-
192-
193-
194-
195-
196-
197-
198-
0 commit comments