@@ -77,7 +77,7 @@ def unpack(x: Array, height: int, width: int) -> Array:
7777
7878
7979def vae_decode (latents , vae , state , config ):
80- img = unpack (x = latents , height = config .resolution , width = config .resolution )
80+ img = unpack (x = latents . astype ( jnp . float32 ) , height = config .resolution , width = config .resolution )
8181 img = img / vae .config .scaling_factor + vae .config .shift_factor
8282 img = vae .apply ({"params" : state .params }, img , deterministic = True , method = vae .decode ).sample
8383 return img
@@ -115,13 +115,12 @@ def loop_body(
115115
116116def prepare_latent_image_ids (height , width ):
117117 latent_image_ids = jnp .zeros ((height , width , 3 ))
118- latent_image_ids = latent_image_ids .at [..., 1 ].set (latent_image_ids [..., 1 ] + jnp .arange (height )[:, None ])
119- latent_image_ids = latent_image_ids .at [..., 2 ].set (latent_image_ids [..., 2 ] + jnp .arange (width )[None , :])
118+ latent_image_ids = latent_image_ids .at [..., 1 ].set (jnp .arange (height )[:, None ])
119+ latent_image_ids = latent_image_ids .at [..., 2 ].set (jnp .arange (width )[None , :])
120120
121121 latent_image_id_height , latent_image_id_width , latent_image_id_channels = latent_image_ids .shape
122122
123123 latent_image_ids = latent_image_ids .reshape (latent_image_id_height * latent_image_id_width , latent_image_id_channels )
124-
125124 return latent_image_ids .astype (jnp .bfloat16 )
126125
127126
@@ -147,20 +146,10 @@ def run_inference(
147146 txt_ids ,
148147 vec ,
149148 guidance_vec ,
149+ c_ts ,
150+ p_ts
150151):
151152
152- timesteps = jnp .linspace (1 , 0 , config .num_inference_steps + 1 )
153- # shifting the schedule to favor high timesteps for higher signal images
154- if config .time_shift :
155- # estimate mu based on linear estimation between two points
156- lin_function = get_lin_function (y1 = config .base_shift , y2 = config .max_shift )
157- mu = lin_function (latents .shape [1 ])
158- timesteps = time_shift (mu , 1.0 , timesteps ).tolist ()
159- c_ts = timesteps [:- 1 ]
160- p_ts = timesteps [1 :]
161- # jax.debug.print("c_ts: {x}", x=c_ts)
162- # jax.debug.print("p_ts: {x}", x=p_ts)
163-
164153 transformer_state = states ["transformer" ]
165154 vae_state = states ["vae" ]
166155
@@ -173,11 +162,10 @@ def run_inference(
173162 vec = vec ,
174163 guidance_vec = guidance_vec ,
175164 )
176-
177165 vae_decode_p = functools .partial (vae_decode , vae = vae , state = vae_state , config = config )
178166
179167 with mesh , nn_partitioning .axis_rules (config .logical_axis_rules ):
180- latents , _ , _ , _ = jax .lax .fori_loop (0 , len (timesteps ) - 1 , loop_body_p , (latents , transformer_state , c_ts , p_ts ))
168+ latents , _ , _ , _ = jax .lax .fori_loop (0 , len (c_ts ) , loop_body_p , (latents , transformer_state , c_ts , p_ts ))
181169 image = vae_decode_p (latents )
182170 return image
183171
@@ -236,8 +224,7 @@ def get_clip_prompt_embeds(
236224
237225 prompt_embeds = text_encoder (text_input_ids , params = text_encoder .params , train = False )
238226 prompt_embeds = prompt_embeds .pooler_output
239- prompt_embeds = np .repeat (prompt_embeds , num_images_per_prompt , axis = - 1 )
240- prompt_embeds = np .reshape (prompt_embeds , (batch_size * num_images_per_prompt , - 1 ))
227+ prompt_embeds = jnp .tile (prompt_embeds , (batch_size * num_images_per_prompt , 1 ))
241228 return prompt_embeds
242229
243230
@@ -300,7 +287,7 @@ def encode_prompt(
300287 max_sequence_length = max_sequence_length ,
301288 )
302289
303- text_ids = jnp .zeros ((prompt_embeds .shape [0 ], prompt_embeds . shape [ 1 ], 3 )).astype (jnp .bfloat16 )
290+ text_ids = jnp .zeros ((prompt_embeds .shape [1 ], 3 )).astype (jnp .bfloat16 )
304291 return prompt_embeds , pooled_prompt_embeds , text_ids
305292
306293
@@ -397,18 +384,14 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
397384 print ("guidance.shape: " , guidance .shape , guidance .dtype )
398385 print ("pooled_prompt_embeds.shape: " , pooled_prompt_embeds .shape , pooled_prompt_embeds .dtype )
399386
400- timesteps = jnp .asarray ([1.0 ] * global_batch_size , dtype = jnp .bfloat16 )
401387 guidance = jnp .asarray ([config .guidance_scale ] * global_batch_size , dtype = jnp .bfloat16 )
402388
403- validate_inputs (latents , latent_image_ids , prompt_embeds , text_ids , timesteps , guidance , pooled_prompt_embeds )
404-
405389 # move inputs to device and shard
406390 data_sharding = jax .sharding .NamedSharding (mesh , P (* config .data_sharding ))
407391 latents = jax .device_put (latents , data_sharding )
408- latent_image_ids = jax .device_put (latent_image_ids , data_sharding )
392+ latent_image_ids = jax .device_put (latent_image_ids )
409393 prompt_embeds = jax .device_put (prompt_embeds , data_sharding )
410- text_ids = jax .device_put (text_ids , data_sharding )
411- timesteps = jax .device_put (timesteps , data_sharding )
394+ text_ids = jax .device_put (text_ids )
412395 guidance = jax .device_put (guidance , data_sharding )
413396 pooled_prompt_embeds = jax .device_put (pooled_prompt_embeds , data_sharding )
414397
@@ -458,6 +441,19 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
458441 states ["transformer" ] = transformer_state
459442 states ["vae" ] = vae_state
460443
444+ # Setup timesteps
445+ timesteps = jnp .linspace (1 , 0 , config .num_inference_steps + 1 )
446+ # shifting the schedule to favor high timesteps for higher signal images
447+ if config .time_shift :
448+ # estimate mu based on linear estimation between two points
449+ lin_function = get_lin_function (x1 = config .max_sequence_length , y1 = config .base_shift , y2 = config .max_shift )
450+ mu = lin_function (latents .shape [1 ])
451+ timesteps = time_shift (mu , 1.0 , timesteps )
452+ c_ts = timesteps [:- 1 ]
453+ p_ts = timesteps [1 :]
454+
455+ validate_inputs (latents , latent_image_ids , prompt_embeds , text_ids , timesteps , guidance , pooled_prompt_embeds )
456+
461457 p_run_inference = jax .jit (
462458 functools .partial (
463459 run_inference ,
@@ -471,6 +467,8 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
471467 txt_ids = text_ids ,
472468 vec = pooled_prompt_embeds ,
473469 guidance_vec = guidance ,
470+ c_ts = c_ts ,
471+ p_ts = p_ts
474472 ),
475473 in_shardings = (state_shardings ,),
476474 out_shardings = None ,
0 commit comments