@@ -55,7 +55,7 @@ def unpack(x: Array, height: int, width: int) -> Array:
5555 ph = 2 ,
5656 pw = 2 ,
5757 )
58-
58+ from einops import rearrange
5959def vae_decode (latents , vae , state , config ):
6060 img = unpack (x = latents , height = config .resolution , width = config .resolution )
6161 img = img / vae .config .scaling_factor + vae .config .shift_factor
@@ -87,6 +87,8 @@ def loop_body(
8787 guidance = guidance_vec ,
8888 y = vec
8989 )
90+ jax .debug .print ("*****pred max: {x}" , x = np .max (pred ))
91+ jax .debug .print ("*****pred min: {x}" , x = np .min (pred ))
9092 latents = latents + (t_prev - t_curr ) * pred
9193 latents = jnp .array (latents , dtype = latents_dtype )
9294 return latents , state , c_ts , p_ts
@@ -144,6 +146,8 @@ def run_inference(
144146 timesteps = time_shift (mu , 1.0 , timesteps ).tolist ()
145147 c_ts = timesteps [:- 1 ]
146148 p_ts = timesteps [1 :]
149+ # jax.debug.print("c_ts: {x}", x=c_ts)
150+ # jax.debug.print("p_ts: {x}", x=p_ts)
147151
148152
149153 transformer_state = states ["transformer" ]
@@ -162,7 +166,7 @@ def run_inference(
162166 vae_decode_p = functools .partial (vae_decode , vae = vae , state = vae_state , config = config )
163167
164168 with mesh , nn_partitioning .axis_rules (config .logical_axis_rules ):
165- latents , _ , _ , _ = jax .lax .fori_loop (0 , config . num_inference_steps , loop_body_p , (latents , transformer_state , c_ts , p_ts ))
169+ latents , _ , _ , _ = jax .lax .fori_loop (0 , len ( timesteps ) - 1 , loop_body_p , (latents , transformer_state , c_ts , p_ts ))
166170 image = vae_decode_p (latents )
167171 return image
168172
@@ -293,7 +297,8 @@ def encode_prompt(
293297 prompt = prompt_2 ,
294298 num_images_per_prompt = num_images_per_prompt ,
295299 tokenizer = t5_tokenizer ,
296- text_encoder = t5_text_encoder
300+ text_encoder = t5_text_encoder ,
301+ max_sequence_length = max_sequence_length
297302 )
298303
299304 text_ids = jnp .zeros ((prompt_embeds .shape [0 ], prompt_embeds .shape [1 ], 3 )).astype (jnp .bfloat16 )
@@ -356,7 +361,7 @@ def run(config):
356361 rng = rng
357362 )
358363
359- # LOAD TEXT ENCODERS - t5 on cpu
364+ # LOAD TEXT ENCODERS
360365 clip_text_encoder = FlaxCLIPTextModel .from_pretrained (
361366 config .pretrained_model_name_or_path ,
362367 subfolder = "text_encoder" ,
@@ -389,7 +394,8 @@ def run(config):
389394 clip_text_encoder = clip_text_encoder ,
390395 t5_tokenizer = t5_tokenizer ,
391396 t5_text_encoder = t5_encoder ,
392- num_images_per_prompt = global_batch_size
397+ num_images_per_prompt = global_batch_size ,
398+ max_sequence_length = config .max_sequence_length
393399 )
394400
395401 def validate_inputs (latents , latent_image_ids , prompt_embeds , text_ids , timesteps , guidance , pooled_prompt_embeds ):
@@ -430,12 +436,12 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
430436
431437 get_memory_allocations ()
432438 # evaluate shapes
433- transformer_eval_params = transformer .init_weights (rngs = rng , max_sequence_length = 512 , eval_only = True )
439+ transformer_eval_params = transformer .init_weights (rngs = rng , max_sequence_length = config . max_sequence_length , eval_only = True )
434440
435441 # loads pretrained weights
436- transformer_params = load_flow_model ("flux-dev" , transformer_eval_params , "cpu" )
442+ transformer_params = load_flow_model (config . flux_name , transformer_eval_params , "cpu" )
437443 # create transformer state
438- weights_init_fn = functools .partial (transformer .init_weights , rngs = rng , max_sequence_length = 512 , eval_only = False )
444+ weights_init_fn = functools .partial (transformer .init_weights , rngs = rng , max_sequence_length = config . max_sequence_length , eval_only = False )
439445 transformer_state , transformer_state_shardings = setup_initial_state (
440446 model = transformer ,
441447 tx = None ,
0 commit comments