@@ -76,8 +76,8 @@ def unpack(x: Array, height: int, width: int) -> Array:
7676 )
7777
7878
79- def vae_decode (latents , vae , state , config ):
80- img = unpack (x = latents .astype (jnp .float32 ), height = config . resolution , width = config . resolution )
79+ def vae_decode (latents , vae , state , config , resolution ):
80+ img = unpack (x = latents .astype (jnp .float32 ), height = resolution , width = resolution )
8181 img = img / vae .config .scaling_factor + vae .config .shift_factor
8282 img = vae .apply ({"params" : state .params }, img , deterministic = True , method = vae .decode ).sample
8383 return img
@@ -135,7 +135,7 @@ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: flo
135135
136136
137137def run_inference (
138- states , transformer , vae , config , mesh , latents , latent_image_ids , prompt_embeds , txt_ids , vec , guidance_vec , c_ts , p_ts
138+ states , transformer , vae , config , resolution , mesh , latents , latent_image_ids , prompt_embeds , txt_ids , vec , guidance_vec , c_ts , p_ts
139139):
140140
141141 transformer_state = states ["transformer" ]
@@ -150,7 +150,7 @@ def run_inference(
150150 vec = vec ,
151151 guidance_vec = guidance_vec ,
152152 )
153- vae_decode_p = functools .partial (vae_decode , vae = vae , state = vae_state , config = config )
153+ vae_decode_p = functools .partial (vae_decode , vae = vae , state = vae_state , config = config , resolution = resolution )
154154
155155 with mesh , nn_partitioning .axis_rules (config .logical_axis_rules ):
156156 latents , _ , _ , _ = jax .lax .fori_loop (0 , len (c_ts ), loop_body_p , (latents , transformer_state , c_ts , p_ts ))
@@ -376,8 +376,6 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
376376
377377 # move inputs to device and shard
378378 data_sharding = jax .sharding .NamedSharding (mesh , P (* config .data_sharding ))
379- latents = jax .device_put (latents , data_sharding )
380- latent_image_ids = jax .device_put (latent_image_ids )
381379 prompt_embeds = jax .device_put (prompt_embeds , data_sharding )
382380 text_ids = jax .device_put (text_ids )
383381 guidance = jax .device_put (guidance , data_sharding )
@@ -429,45 +427,66 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
429427 states ["transformer" ] = transformer_state
430428 states ["vae" ] = vae_state
431429
432- # Setup timesteps
433- timesteps = jnp .linspace (1 , 0 , config .num_inference_steps + 1 )
434- # shifting the schedule to favor high timesteps for higher signal images
435- if config .time_shift :
436- # estimate mu based on linear estimation between two points
437- lin_function = get_lin_function (x1 = config .max_sequence_length , y1 = config .base_shift , y2 = config .max_shift )
438- mu = lin_function (latents .shape [1 ])
439- timesteps = time_shift (mu , 1.0 , timesteps )
440- c_ts = timesteps [:- 1 ]
441- p_ts = timesteps [1 :]
442-
443- validate_inputs (latents , latent_image_ids , prompt_embeds , text_ids , timesteps , guidance , pooled_prompt_embeds )
444-
445- p_run_inference = jax .jit (
446- functools .partial (
447- run_inference ,
448- transformer = transformer ,
449- vae = vae ,
450- config = config ,
451- mesh = mesh ,
452- latents = latents ,
453- latent_image_ids = latent_image_ids ,
454- prompt_embeds = prompt_embeds ,
455- txt_ids = text_ids ,
456- vec = pooled_prompt_embeds ,
457- guidance_vec = guidance ,
458- c_ts = c_ts ,
459- p_ts = p_ts ,
460- ),
461- in_shardings = (state_shardings ,),
462- out_shardings = None ,
463- )
430+ #validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds)
431+
432+ resolutions = [1024 , 768 , 512 ]
433+ p_jitted = {}
434+ for resolution in resolutions :
435+ latents , latent_image_ids = prepare_latents (
436+ batch_size = global_batch_size ,
437+ num_channels_latents = num_channels_latents ,
438+ height = resolution ,
439+ width = resolution ,
440+ dtype = jnp .bfloat16 ,
441+ vae_scale_factor = vae_scale_factor ,
442+ rng = rng ,
443+ )
444+ latents = jax .device_put (latents , data_sharding )
445+ latent_image_ids = jax .device_put (latent_image_ids )
446+
447+ # Setup timesteps
448+ timesteps = jnp .linspace (1 , 0 , config .num_inference_steps + 1 )
449+ # shifting the schedule to favor high timesteps for higher signal images
450+ if config .time_shift :
451+ # estimate mu based on linear estimation between two points
452+ lin_function = get_lin_function (x1 = config .max_sequence_length , y1 = config .base_shift , y2 = config .max_shift )
453+ mu = lin_function (latents .shape [1 ])
454+ timesteps = time_shift (mu , 1.0 , timesteps )
455+ c_ts = timesteps [:- 1 ]
456+ p_ts = timesteps [1 :]
457+
458+ p_run_inference = jax .jit (
459+ functools .partial (
460+ run_inference ,
461+ transformer = transformer ,
462+ vae = vae ,
463+ config = config ,
464+ resolution = resolution ,
465+ mesh = mesh ,
466+ latents = latents ,
467+ latent_image_ids = latent_image_ids ,
468+ prompt_embeds = prompt_embeds ,
469+ txt_ids = text_ids ,
470+ vec = pooled_prompt_embeds ,
471+ guidance_vec = guidance ,
472+ c_ts = c_ts ,
473+ p_ts = p_ts ,
474+ ),
475+ in_shardings = (state_shardings ,),
476+ out_shardings = None ,
477+ )
478+ with ExitStack () as stack :
479+ _ = [stack .enter_context (nn .intercept_methods (interceptor )) for interceptor in lora_interceptors ]
480+ p_run_inference (states ).block_until_ready ()
481+ p_jitted [resolution ] = p_run_inference
482+ breakpoint ()
464483 t0 = time .perf_counter ()
465484 with ExitStack () as stack :
466485 _ = [stack .enter_context (nn .intercept_methods (interceptor )) for interceptor in lora_interceptors ]
467486 p_run_inference (states ).block_until_ready ()
468487 t1 = time .perf_counter ()
469488 max_logging .log (f"Compile time: { t1 - t0 :.1f} s." )
470-
489+ breakpoint ()
471490 t0 = time .perf_counter ()
472491 with ExitStack () as stack , jax .profiler .trace ("/tmp/trace/" ):
473492 _ = [stack .enter_context (nn .intercept_methods (interceptor )) for interceptor in lora_interceptors ]
0 commit comments