1818from absl import app
1919import functools
2020import math
21+ import time
2122import numpy as np
23+ from PIL import Image
2224import jax
2325from jax .sharding import Mesh , PositionalSharding , PartitionSpec as P
2426import jax .numpy as jnp
3335 FlaxT5EncoderModel
3436)
3537
36- from maxdiffusion import FlaxAutoencoderKL
38+ from maxdiffusion import FlaxAutoencoderKL , pyconfig , max_logging
3739from maxdiffusion .models .flux .transformers .transformer_flux_flax import FluxTransformer2DModel
38- from maxdiffusion import pyconfig
3940from max_utils import (
4041 device_put_replicated ,
4142 get_memory_allocations ,
@@ -57,8 +58,8 @@ def unpack(x: Array, height: int, width: int) -> Array:
5758
5859def vae_decode (latents , vae , state , config ):
5960 img = unpack (x = latents , height = config .resolution , width = config .resolution )
60- img = vae .apply ({ "params" : state . params }, img , deterministic = True , method = vae .decode ). sample [ 0 ]
61- breakpoint ()
61+ img = img / vae .config . scaling_factor + vae .config . shift_factor
62+ img = vae . apply ({ "params" : state . params }, img , deterministic = True , method = vae . decode ). sample
6263 return img
6364
6465def loop_body (
@@ -107,6 +108,19 @@ def prepare_latent_image_ids(height, width):
107108
108109 return latent_image_ids .astype (jnp .bfloat16 )
109110
def time_shift(mu: float, sigma: float, t: Array):
    """Warp a timestep schedule `t` by a logistic-style shift.

    Larger `mu` pushes the schedule toward higher timesteps (more noise
    retained per step). With sigma == 1 this is the standard Flux
    schedule shift: exp(mu) / (exp(mu) + (1/t - 1)).

    Args:
        mu: shift strength (log-space); typically from get_lin_function.
        sigma: exponent applied to the odds term (1/t - 1).
        t: timesteps in (0, 1]; scalars or arrays both work elementwise.

    Returns:
        The shifted timesteps, same shape as `t`.
    """
    # exp(mu) is invariant in t — compute it once.
    e_mu = math.exp(mu)
    return e_mu / (e_mu + (1 / t - 1) ** sigma)
113+
def get_lin_function(
    x1: float = 256,
    y1: float = 0.5,
    x2: float = 4096,
    y2: float = 1.15,
) -> Callable[[float], float]:
    """Return the linear function through (x1, y1) and (x2, y2).

    Used to estimate the schedule-shift parameter `mu` from the latent
    sequence length: the defaults map length 256 -> 0.5 and 4096 -> 1.15.

    Args:
        x1, y1: first anchor point.
        x2, y2: second anchor point.

    Returns:
        A callable evaluating the interpolating line at any x.
    """
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x: float) -> float:
        return slope * x + intercept

    return line
123+
110124def run_inference (
111125 states ,
112126 transformer ,
@@ -120,10 +134,18 @@ def run_inference(
120134 vec ,
121135 guidance_vec ,
122136):
137+
123138 timesteps = jnp .linspace (1 , 0 , config .num_inference_steps + 1 )
139+ # shifting the schedule to favor high timesteps for higher signal images
140+ if config .time_shift :
141+ # estimate mu based on linear estimation between two points
142+ lin_function = get_lin_function (y1 = config .base_shift , y2 = config .max_shift )
143+ mu = lin_function (latents .shape [1 ])
144+ timesteps = time_shift (mu , 1.0 , timesteps ).tolist ()
124145 c_ts = timesteps [:- 1 ]
125146 p_ts = timesteps [1 :]
126147
148+
127149 transformer_state = states ["transformer" ]
128150 vae_state = states ["vae" ]
129151
@@ -142,7 +164,6 @@ def run_inference(
142164 with mesh , nn_partitioning .axis_rules (config .logical_axis_rules ):
143165 latents , _ , _ , _ = jax .lax .fori_loop (0 , config .num_inference_steps , loop_body_p , (latents , transformer_state , c_ts , p_ts ))
144166 image = vae_decode_p (latents )
145- breakpoint ()
146167 return image
147168
148169
@@ -383,6 +404,10 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
383404
384405 timesteps = jnp .asarray ([1.0 ] * global_batch_size , dtype = jnp .bfloat16 )
385406 guidance = jnp .asarray ([config .guidance_scale ] * global_batch_size , dtype = jnp .bfloat16 )
407+
408+ # TODO - remove this later and figure out why t5x is returning wrong shape
409+ prompt_embeds = jnp .ones ((global_batch_size , 512 , 4096 ))
410+
386411 validate_inputs (
387412 latents ,
388413 latent_image_ids ,
@@ -393,8 +418,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
393418 pooled_prompt_embeds
394419 )
395420
396- # TODO - remove this later and figure out why t5x is returning wrong shape
397- prompt_embeds = jnp .ones ((global_batch_size , 512 , 4096 ))
421+
398422
399423 # move inputs to device and shard
400424 data_sharding = jax .sharding .NamedSharding (mesh , P (* config .data_sharding ))
@@ -420,11 +444,11 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
420444 config = config ,
421445 mesh = mesh ,
422446 weights_init_fn = weights_init_fn ,
423- # model_params=transformer_params,
424- model_params = None ,
447+ model_params = transformer_params ,
448+ # model_params=None,
425449 training = False
426450 )
427- transformer_state = transformer_state .replace (params = transformer_params )
451+ # transformer_state = transformer_state.replace(params=transformer_params)
428452 get_memory_allocations ()
429453
430454 states = {}
@@ -453,37 +477,27 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
453477 in_shardings = (state_shardings ,),
454478 out_shardings = None ,
455479 )
456-
457- img = p_run_inference (states )
458-
459-
460-
461-
462- # def run_inference(state, transformer):
463- # img = transformer.apply(
464- # {"params" : state.params},
465- # img=latents,
466- # img_ids=latent_image_ids,
467- # txt=prompt_embeds,
468- # txt_ids=text_ids,
469- # timesteps=timesteps,
470- # guidance=guidance,
471- # y=pooled_prompt_embeds
472- # )
473- # return img
474-
475- # p_run_inference = jax.jit(
476- # functools.partial(
477- # run_inference,
478- # transformer=transformer,
479- # ),
480- # in_shardings=(transformer_state_shardings,),
481- # out_shardings=None
482- # )
483-
484- img = p_run_inference (transformer_state )
485- breakpoint ()
486- print ("img.shape: " , img .shape )
480+ t0 = time .perf_counter ()
481+ p_run_inference (states ).block_until_ready ()
482+ t1 = time .perf_counter ()
483+ max_logging .log (f"Compile time: { t1 - t0 :.1f} s." )
484+
485+ t0 = time .perf_counter ()
486+ imgs = p_run_inference (states ).block_until_ready ()
487+ t1 = time .perf_counter ()
488+ max_logging .log (f"Inference time: { t1 - t0 :.1f} s." )
489+
490+ t0 = time .perf_counter ()
491+ imgs = p_run_inference (states ).block_until_ready ()
492+ imgs = jax .experimental .multihost_utils .process_allgather (imgs , tiled = True )
493+ t1 = time .perf_counter ()
494+ max_logging .log (f"Inference time: { t1 - t0 :.1f} s." )
495+ imgs = np .array (imgs )
496+ imgs = (imgs * 0.5 + 0.5 ).clip (0 , 1 )
497+ imgs = np .transpose (imgs , (0 , 2 , 3 , 1 ))
498+ imgs = np .uint8 (imgs * 255 )
499+ for i , image in enumerate (imgs ):
500+ Image .fromarray (image ).save (f"flux_{ i } .png" )
487501
488502
489503def main (argv : Sequence [str ]) -> None :
0 commit comments