@@ -76,6 +76,7 @@ def unpack(x: Array, height: int, width: int, vae_scale_factor: int) -> Array:
7676
7777 return x
7878
79+
7980def vae_decode (latents , vae , state , vae_scale_factor , resolution ):
8081 img = unpack (x = latents .astype (jnp .float32 ), height = resolution [0 ], width = resolution [1 ], vae_scale_factor = vae_scale_factor )
8182 img = img / vae .config .scaling_factor + vae .config .shift_factor
@@ -127,18 +128,16 @@ def prepare_latent_image_ids(height, width):
def time_shift(mu: float, sigma: float, t: Array):
    """Warp timestep value(s) ``t`` with an exponential shift.

    Computes ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``.

    Args:
        mu: Shift exponent; only ``exp(mu)`` enters the formula.
        sigma: Power applied to the ``(1 / t - 1)`` term.
        t: Timestep value(s). The expression is applied with ordinary
            arithmetic operators, so a scalar or an array both work.

    Returns:
        The shifted timestep(s), in the same scalar/array form that the
        arithmetic on ``t`` produces.

    Note:
        ``1 / t`` is evaluated directly, so a plain Python ``0`` for ``t``
        raises ``ZeroDivisionError``.
    """
    # exp(mu) does not depend on t, so evaluate it once and reuse it.
    exp_mu = math.exp(mu)
    return exp_mu / (exp_mu + (1 / t - 1) ** sigma)
129130
131+
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Linearly map an image sequence length to a shift value.

    The result lies on the straight line through the points
    ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``,
    evaluated at ``image_seq_len``. No clamping is applied, so lengths
    outside ``[base_seq_len, max_seq_len]`` are extrapolated.

    Args:
        image_seq_len: Sequence length at which to evaluate the line.
        base_seq_len: Sequence length that maps to ``base_shift``.
        max_seq_len: Sequence length that maps to ``max_shift``.
        base_shift: Shift value at ``base_seq_len``.
        max_shift: Shift value at ``max_seq_len``.

    Returns:
        The interpolated (or extrapolated) shift ``mu``.

    Raises:
        ZeroDivisionError: If ``max_seq_len == base_seq_len`` (for plain
            Python numbers), because the slope is undefined.
    """
    # Slope and intercept of the line through the two anchor points.
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    mu = image_seq_len * slope + intercept
    return mu
141139
140+
142141def run_inference (
143142 states ,
144143 transformer ,
@@ -154,7 +153,7 @@ def run_inference(
154153 guidance_vec ,
155154 c_ts ,
156155 p_ts ,
157- vae_scale_factor
156+ vae_scale_factor ,
158157):
159158
160159 transformer_state = states ["transformer" ]
@@ -169,7 +168,9 @@ def run_inference(
169168 vec = vec ,
170169 guidance_vec = guidance_vec ,
171170 )
172- vae_decode_p = functools .partial (vae_decode , vae = vae , state = vae_state , vae_scale_factor = vae_scale_factor , resolution = resolution )
171+ vae_decode_p = functools .partial (
172+ vae_decode , vae = vae , state = vae_state , vae_scale_factor = vae_scale_factor , resolution = resolution
173+ )
173174
174175 with mesh , nn_partitioning .axis_rules (config .logical_axis_rules ):
175176 latents , _ , _ , _ = jax .lax .fori_loop (0 , len (c_ts ), loop_body_p , (latents , transformer_state , c_ts , p_ts ))
@@ -211,6 +212,7 @@ def prepare_latents(
211212
212213 return latents , latent_image_ids
213214
215+
214216def tokenize_clip (prompt : Union [str , List [str ]], tokenizer : CLIPTokenizer ):
215217 prompt = [prompt ] if isinstance (prompt , str ) else prompt
216218 text_inputs = tokenizer (
@@ -224,6 +226,7 @@ def tokenize_clip(prompt: Union[str, List[str]], tokenizer: CLIPTokenizer):
224226 )
225227 return text_inputs .input_ids
226228
229+
227230def get_clip_prompt_embeds (
228231 prompt : Union [str , List [str ]], num_images_per_prompt : int , tokenizer : CLIPTokenizer , text_encoder : FlaxCLIPTextModel
229232):
@@ -246,6 +249,7 @@ def get_clip_prompt_embeds(
246249 prompt_embeds = jnp .tile (prompt_embeds , (batch_size * num_images_per_prompt , 1 ))
247250 return prompt_embeds
248251
252+
249253def tokenize_t5 (prompt : Union [str , List [str ]], tokenizer : AutoTokenizer , max_sequence_length : int = 512 ):
250254 prompt = [prompt ] if isinstance (prompt , str ) else prompt
251255 text_inputs = tokenizer (
@@ -259,6 +263,7 @@ def tokenize_t5(prompt: Union[str, List[str]], tokenizer: AutoTokenizer, max_seq
259263 )
260264 return text_inputs .input_ids
261265
266+
262267def get_t5_prompt_embeds (
263268 prompt : Union [str , List [str ]],
264269 num_images_per_prompt : int ,
@@ -288,6 +293,7 @@ def get_t5_prompt_embeds(
288293 prompt_embeds = jnp .reshape (prompt_embeds , (batch_size * num_images_per_prompt , seq_len , - 1 ))
289294 return prompt_embeds
290295
296+
291297def encode_prompt (
292298 prompt : Union [str , List [str ]],
293299 prompt_2 : Union [str , List [str ]],
@@ -318,6 +324,7 @@ def encode_prompt(
318324 text_ids = jnp .zeros ((prompt_embeds .shape [1 ], 3 )).astype (jnp .bfloat16 )
319325 return prompt_embeds , pooled_prompt_embeds , text_ids
320326
327+
321328def run (config ):
322329 from maxdiffusion .models .flux .util import load_flow_model
323330
@@ -436,19 +443,19 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
436443 states ["vae" ] = vae_state
437444 # some resolutions from https://www.reddit.com/r/StableDiffusion/comments/1enxdga/flux_recommended_resolutions_from_01_to_20/
438445 resolutions = [
439- (768 , 768 ),
440- (768 , 1024 ),
441- (1024 , 768 ),
442- (1024 , 1024 ),
443- (1408 , 1408 ),
444- (1728 , 1152 ),
445- (1152 , 1728 ),
446- (1664 , 1216 ),
447- (1216 , 1664 ),
448- (1920 , 1088 ),
449- (1088 , 1920 ),
450- (2176 , 960 ),
451- (960 , 2176 )
446+ (768 , 768 ),
447+ (768 , 1024 ),
448+ (1024 , 768 ),
449+ (1024 , 1024 ),
450+ (1408 , 1408 ),
451+ (1728 , 1152 ),
452+ (1152 , 1728 ),
453+ (1664 , 1216 ),
454+ (1216 , 1664 ),
455+ (1920 , 1088 ),
456+ (1088 , 1920 ),
457+ (2176 , 960 ),
458+ (960 , 2176 ),
452459 ]
453460 p_jitted = {}
454461 recorded_times = {}
@@ -461,14 +468,14 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
461468 t5_encoder .params = jax .tree_util .tree_map (partial_device_put_replicated , t5_encoder .params )
462469 max_logging .log (f"Moving encoder to TPU time: { (time .perf_counter () - s0 )} " )
463470 prompt_embeds , pooled_prompt_embeds , text_ids = encode_prompt (
464- prompt = config .prompt ,
465- prompt_2 = config .prompt_2 ,
466- clip_tokenizer = clip_tokenizer ,
467- clip_text_encoder = clip_text_encoder ,
468- t5_tokenizer = t5_tokenizer ,
469- t5_text_encoder = t5_encoder ,
470- num_images_per_prompt = global_batch_size ,
471- max_sequence_length = config .max_sequence_length ,
471+ prompt = config .prompt ,
472+ prompt_2 = config .prompt_2 ,
473+ clip_tokenizer = clip_tokenizer ,
474+ clip_text_encoder = clip_text_encoder ,
475+ t5_tokenizer = t5_tokenizer ,
476+ t5_text_encoder = t5_encoder ,
477+ num_images_per_prompt = global_batch_size ,
478+ max_sequence_length = config .max_sequence_length ,
472479 )
473480 if config .offload_encoders :
474481 s1 = time .perf_counter ()
@@ -478,15 +485,15 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
478485 text_encoding_time_final = time .perf_counter () - s0
479486 max_logging .log (f"text encoding time: { text_encoding_time_final } " )
480487 latents , latent_image_ids = prepare_latents (
481- batch_size = global_batch_size ,
482- num_channels_latents = num_channels_latents ,
483- height = resolution [0 ],
484- width = resolution [1 ],
485- dtype = jnp .bfloat16 ,
486- vae_scale_factor = vae_scale_factor ,
487- rng = rng ,
488+ batch_size = global_batch_size ,
489+ num_channels_latents = num_channels_latents ,
490+ height = resolution [0 ],
491+ width = resolution [1 ],
492+ dtype = jnp .bfloat16 ,
493+ vae_scale_factor = vae_scale_factor ,
494+ rng = rng ,
488495 )
489-
496+
490497 # move inputs to device and shard
491498 s0 = time .perf_counter ()
492499 data_sharding = jax .sharding .NamedSharding (mesh , P (* config .data_sharding ))
@@ -509,7 +516,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
509516 timesteps = time_shift (mu , 1.0 , timesteps )
510517 c_ts = timesteps [:- 1 ]
511518 p_ts = timesteps [1 :]
512- #validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds)
519+ # validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds)
513520 p_run_inference = p_jitted .get (resolution , None )
514521 if p_run_inference is None :
515522 print ("FN not found, compiling..." )
@@ -537,14 +544,14 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
537544 _ = [stack .enter_context (nn .intercept_methods (interceptor )) for interceptor in lora_interceptors ]
538545 s0 = time .perf_counter ()
539546 imgs = p_run_inference (
540- states ,
541- latents = latents ,
542- latent_image_ids = latent_image_ids ,
543- prompt_embeds = prompt_embeds ,
544- txt_ids = text_ids ,
545- vec = pooled_prompt_embeds ,
547+ states ,
548+ latents = latents ,
549+ latent_image_ids = latent_image_ids ,
550+ prompt_embeds = prompt_embeds ,
551+ txt_ids = text_ids ,
552+ vec = pooled_prompt_embeds ,
546553 ).block_until_ready ()
547- recorded_times [resolution ] = ( time .perf_counter () - s0 )
554+ recorded_times [resolution ] = time .perf_counter () - s0
548555 max_logging .log (f"inference time: { recorded_times [resolution ]} " )
549556 s0 = time .perf_counter ()
550557 imgs = jax .experimental .multihost_utils .process_allgather (imgs , tiled = True )
@@ -566,6 +573,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
566573
567574 return imgs
568575
576+
569577def main (argv : Sequence [str ]) -> None :
570578 pyconfig .initialize (argv )
571579 run (pyconfig .config )
0 commit comments