@@ -77,7 +77,7 @@ def unpack(x: Array, height: int, width: int) -> Array:
7777
7878
def vae_decode(latents, vae, state, config, resolution):
    """Decode packed transformer latents into an image batch with the VAE.

    Args:
        latents: packed latent token sequence produced by the diffusion loop.
        vae: Flax VAE module exposing ``config`` (scaling/shift factors) and a
            ``decode`` method.
        state: state object whose ``params`` are passed to ``vae.apply``.
        config: run configuration (not read here; kept so the signature matches
            the other pipeline stages).
        resolution: ``(width, height)`` pair in pixels — ``resolution[0]`` is
            width, ``resolution[1]`` is height.

    Returns:
        The decoded image batch (the ``.sample`` field of the decoder output).
    """
    width, height = resolution[0], resolution[1]
    # Unpack the flat token sequence back into a spatial latent grid, in fp32
    # so the VAE decode runs at full precision.
    spatial = unpack(x=latents.astype(jnp.float32), height=height, width=width)
    # Undo the latent normalisation the VAE was trained with.
    spatial = spatial / vae.config.scaling_factor + vae.config.shift_factor
    decoded = vae.apply({"params": state.params}, spatial, deterministic=True, method=vae.decode)
    return decoded.sample
@@ -322,15 +322,6 @@ def run(config):
322322 )
323323
324324 num_channels_latents = transformer .in_channels // 4
325- latents , latent_image_ids = prepare_latents (
326- batch_size = global_batch_size ,
327- num_channels_latents = num_channels_latents ,
328- height = config .resolution ,
329- width = config .resolution ,
330- dtype = jnp .bfloat16 ,
331- vae_scale_factor = vae_scale_factor ,
332- rng = rng ,
333- )
334325
335326 # LOAD TEXT ENCODERS
336327 clip_text_encoder = FlaxCLIPTextModel .from_pretrained (
@@ -352,17 +343,6 @@ def run(config):
352343 t5_encoder .params = jax .tree_util .tree_map (lambda x : x .astype (jnp .bfloat16 ), t5_encoder .params )
353344 t5_encoder .params = jax .tree_util .tree_map (partial_device_put_replicated , t5_encoder .params )
354345
355- prompt_embeds , pooled_prompt_embeds , text_ids = encode_prompt (
356- prompt = config .prompt ,
357- prompt_2 = config .prompt_2 ,
358- clip_tokenizer = clip_tokenizer ,
359- clip_text_encoder = clip_text_encoder ,
360- t5_tokenizer = t5_tokenizer ,
361- t5_text_encoder = t5_encoder ,
362- num_images_per_prompt = global_batch_size ,
363- max_sequence_length = config .max_sequence_length ,
364- )
365-
366346 def validate_inputs (latents , latent_image_ids , prompt_embeds , text_ids , timesteps , guidance , pooled_prompt_embeds ):
367347 print ("latents.shape: " , latents .shape , latents .dtype )
368348 print ("latent_image_ids.shape: " , latent_image_ids .shape , latent_image_ids .dtype )
@@ -374,13 +354,6 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
374354
375355 guidance = jnp .asarray ([config .guidance_scale ] * global_batch_size , dtype = jnp .bfloat16 )
376356
377- # move inputs to device and shard
378- data_sharding = jax .sharding .NamedSharding (mesh , P (* config .data_sharding ))
379- prompt_embeds = jax .device_put (prompt_embeds , data_sharding )
380- text_ids = jax .device_put (text_ids )
381- guidance = jax .device_put (guidance , data_sharding )
382- pooled_prompt_embeds = jax .device_put (pooled_prompt_embeds , data_sharding )
383-
384357 if config .offload_encoders :
385358 cpus = jax .devices ("cpu" )
386359 t5_encoder .params = jax .device_put (t5_encoder .params , device = cpus [0 ])
@@ -427,58 +400,110 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
427400 states ["transformer" ] = transformer_state
428401 states ["vae" ] = vae_state
429402
430- #validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds)
431-
432- resolutions = [1024 , 768 , 512 ]
403+ resolutions = [
404+ (768 , 768 ),
405+ (768 , 1024 ),
406+ (1024 , 768 ),
407+ (1024 , 1024 ),
408+ (896 , 1152 ),
409+ (1152 , 896 ),
410+ (1920 , 1080 ),
411+ (1080 , 1920 )
412+ ]
433413 p_jitted = {}
434414 for resolution in resolutions :
435- latents , latent_image_ids = prepare_latents (
436- batch_size = global_batch_size ,
437- num_channels_latents = num_channels_latents ,
438- height = resolution ,
439- width = resolution ,
440- dtype = jnp .bfloat16 ,
441- vae_scale_factor = vae_scale_factor ,
442- rng = rng ,
443- )
444- latents = jax .device_put (latents , data_sharding )
445- latent_image_ids = jax .device_put (latent_image_ids )
446-
447- # Setup timesteps
448- timesteps = jnp .linspace (1 , 0 , config .num_inference_steps + 1 )
449- # shifting the schedule to favor high timesteps for higher signal images
450- if config .time_shift :
451- # estimate mu based on linear estimation between two points
452- lin_function = get_lin_function (x1 = config .max_sequence_length , y1 = config .base_shift , y2 = config .max_shift )
453- mu = lin_function (latents .shape [1 ])
454- timesteps = time_shift (mu , 1.0 , timesteps )
455- c_ts = timesteps [:- 1 ]
456- p_ts = timesteps [1 :]
457-
458- p_run_inference = jax .jit (
459- functools .partial (
460- run_inference ,
461- transformer = transformer ,
462- vae = vae ,
463- config = config ,
464- resolution = resolution ,
465- mesh = mesh ,
466- latents = latents ,
467- latent_image_ids = latent_image_ids ,
468- prompt_embeds = prompt_embeds ,
469- txt_ids = text_ids ,
470- vec = pooled_prompt_embeds ,
471- guidance_vec = guidance ,
472- c_ts = c_ts ,
473- p_ts = p_ts ,
474- ),
475- in_shardings = (state_shardings ,),
476- out_shardings = None ,
477- )
478- with ExitStack () as stack :
479- _ = [stack .enter_context (nn .intercept_methods (interceptor )) for interceptor in lora_interceptors ]
480- p_run_inference (states ).block_until_ready ()
481- p_jitted [resolution ] = p_run_inference
415+ max_logging .log (f"Resolutions: { resolution } " )
416+ for _ in range (5 ):
417+ s0 = time .perf_counter ()
418+ prompt_embeds , pooled_prompt_embeds , text_ids = encode_prompt (
419+ prompt = config .prompt ,
420+ prompt_2 = config .prompt_2 ,
421+ clip_tokenizer = clip_tokenizer ,
422+ clip_text_encoder = clip_text_encoder ,
423+ t5_tokenizer = t5_tokenizer ,
424+ t5_text_encoder = t5_encoder ,
425+ num_images_per_prompt = global_batch_size ,
426+ max_sequence_length = config .max_sequence_length ,
427+ )
428+ max_logging .log (f"text encoding time: { (time .perf_counter () - s0 )} " )
429+ latents , latent_image_ids = prepare_latents (
430+ batch_size = global_batch_size ,
431+ num_channels_latents = num_channels_latents ,
432+ height = resolution [1 ],
433+ width = resolution [0 ],
434+ dtype = jnp .bfloat16 ,
435+ vae_scale_factor = vae_scale_factor ,
436+ rng = rng ,
437+ )
438+
439+ # move inputs to device and shard
440+ s0 = time .perf_counter ()
441+ data_sharding = jax .sharding .NamedSharding (mesh , P (* config .data_sharding ))
442+ prompt_embeds = jax .device_put (prompt_embeds , data_sharding )
443+ text_ids = jax .device_put (text_ids )
444+ guidance = jax .device_put (guidance , data_sharding )
445+ pooled_prompt_embeds = jax .device_put (pooled_prompt_embeds , data_sharding )
446+ latents = jax .device_put (latents , data_sharding )
447+ latent_image_ids = jax .device_put (latent_image_ids )
448+ max_logging .log (f"Moving to device time: { (time .perf_counter () - s0 )} " )
449+
450+ # Setup timesteps
451+ timesteps = jnp .linspace (1 , 0 , config .num_inference_steps + 1 )
452+ # shifting the schedule to favor high timesteps for higher signal images
453+ if config .time_shift :
454+ # estimate mu based on linear estimation between two points
455+ lin_function = get_lin_function (x1 = config .max_sequence_length , y1 = config .base_shift , y2 = config .max_shift )
456+ mu = lin_function (latents .shape [1 ])
457+ timesteps = time_shift (mu , 1.0 , timesteps )
458+ c_ts = timesteps [:- 1 ]
459+ p_ts = timesteps [1 :]
460+ validate_inputs (latents , latent_image_ids , prompt_embeds , text_ids , timesteps , guidance , pooled_prompt_embeds )
461+ p_run_inference = p_jitted .get (resolution , None )
462+ if p_run_inference is None :
463+ print ("FN not found, compiling..." )
464+ p_run_inference = jax .jit (
465+ functools .partial (
466+ run_inference ,
467+ transformer = transformer ,
468+ vae = vae ,
469+ config = config ,
470+ resolution = resolution ,
471+ mesh = mesh ,
472+ latents = latents ,
473+ latent_image_ids = latent_image_ids ,
474+ prompt_embeds = prompt_embeds ,
475+ txt_ids = text_ids ,
476+ vec = pooled_prompt_embeds ,
477+ guidance_vec = guidance ,
478+ c_ts = c_ts ,
479+ p_ts = p_ts ,
480+ ),
481+ )
482+ p_jitted [resolution ] = p_run_inference
483+ with ExitStack () as stack :
484+ _ = [stack .enter_context (nn .intercept_methods (interceptor )) for interceptor in lora_interceptors ]
485+ s0 = time .perf_counter ()
486+ imgs = p_run_inference (
487+ states ,
488+ latents = latents ,
489+ latent_image_ids = latent_image_ids ,
490+ prompt_embeds = prompt_embeds ,
491+ txt_ids = text_ids ,
492+ vec = pooled_prompt_embeds ,
493+ ).block_until_ready ()
494+ max_logging .log (f"inference time: { (time .perf_counter () - s0 )} " )
495+ s0 = time .perf_counter ()
496+ imgs = jax .experimental .multihost_utils .process_allgather (imgs , tiled = True )
497+ max_logging .log (f"Gathering all time: { (time .perf_counter () - s0 )} " )
498+ s0 = time .perf_counter ()
499+ imgs = np .array (imgs )
500+ imgs = (imgs * 0.5 + 0.5 ).clip (0 , 1 )
501+ imgs = np .transpose (imgs , (0 , 2 , 3 , 1 ))
502+ imgs = np .uint8 (imgs * 255 )
503+ for i , image in enumerate (imgs ):
504+ Image .fromarray (image ).save (f"flux_{ resolution [0 ]} _{ resolution [1 ]} _{ i } .png" )
505+ max_logging .log (f"Saving images time: { (time .perf_counter () - s0 )} " )
506+ get_memory_allocations ()
482507 breakpoint ()
483508 t0 = time .perf_counter ()
484509 with ExitStack () as stack :
0 commit comments