@@ -287,207 +287,207 @@ def run(config):
287287 global_batch_size = config .per_device_batch_size * jax .local_device_count ()
288288
289289 # LOAD VAE
290+ with mesh :
291+ vae , vae_params = FlaxAutoencoderKL .from_pretrained (
292+ config .pretrained_model_name_or_path , subfolder = "vae" , from_pt = True , use_safetensors = True , dtype = "bfloat16"
293+ )
290294
291- vae , vae_params = FlaxAutoencoderKL .from_pretrained (
292- config .pretrained_model_name_or_path , subfolder = "vae" , from_pt = True , use_safetensors = True , dtype = "bfloat16"
293- )
295+ weights_init_fn = functools .partial (vae .init_weights , rng = rng )
296+ vae_state , vae_state_shardings = setup_initial_state (
297+ model = vae ,
298+ tx = None ,
299+ config = config ,
300+ mesh = mesh ,
301+ weights_init_fn = weights_init_fn ,
302+ model_params = vae_params ,
303+ training = False ,
304+ )
294305
295- weights_init_fn = functools .partial (vae .init_weights , rng = rng )
296- vae_state , vae_state_shardings = setup_initial_state (
297- model = vae ,
298- tx = None ,
299- config = config ,
300- mesh = mesh ,
301- weights_init_fn = weights_init_fn ,
302- model_params = vae_params ,
303- training = False ,
304- )
306+ vae_scale_factor = 2 ** (len (vae .config .block_out_channels ) - 1 )
305307
306- vae_scale_factor = 2 ** (len (vae .config .block_out_channels ) - 1 )
307-
308- # LOAD TRANSFORMER
309- flash_block_sizes = get_flash_block_sizes (config )
310- transformer = FluxTransformer2DModel .from_config (
311- config .pretrained_model_name_or_path ,
312- subfolder = "transformer" ,
313- mesh = mesh ,
314- split_head_dim = config .split_head_dim ,
315- attention_kernel = config .attention ,
316- flash_block_sizes = flash_block_sizes ,
317- dtype = config .activations_dtype ,
318- weights_dtype = config .weights_dtype ,
319- precision = get_precision (config ),
320- )
308+ # LOAD TRANSFORMER
309+ flash_block_sizes = get_flash_block_sizes (config )
310+ transformer = FluxTransformer2DModel .from_config (
311+ config .pretrained_model_name_or_path ,
312+ subfolder = "transformer" ,
313+ mesh = mesh ,
314+ split_head_dim = config .split_head_dim ,
315+ attention_kernel = config .attention ,
316+ flash_block_sizes = flash_block_sizes ,
317+ dtype = config .activations_dtype ,
318+ weights_dtype = config .weights_dtype ,
319+ precision = get_precision (config ),
320+ )
321321
322- num_channels_latents = transformer .in_channels // 4
323- latents , latent_image_ids = prepare_latents (
324- batch_size = global_batch_size ,
325- num_channels_latents = num_channels_latents ,
326- height = config .resolution ,
327- width = config .resolution ,
328- dtype = jnp .bfloat16 ,
329- vae_scale_factor = vae_scale_factor ,
330- rng = rng ,
331- )
322+ num_channels_latents = transformer .in_channels // 4
323+ latents , latent_image_ids = prepare_latents (
324+ batch_size = global_batch_size ,
325+ num_channels_latents = num_channels_latents ,
326+ height = config .resolution ,
327+ width = config .resolution ,
328+ dtype = jnp .bfloat16 ,
329+ vae_scale_factor = vae_scale_factor ,
330+ rng = rng ,
331+ )
332332
333- # LOAD TEXT ENCODERS
334- clip_text_encoder = FlaxCLIPTextModel .from_pretrained (
335- config .pretrained_model_name_or_path , subfolder = "text_encoder" , from_pt = True , dtype = config .weights_dtype
336- )
337- clip_tokenizer = CLIPTokenizer .from_pretrained (
338- config .pretrained_model_name_or_path , subfolder = "tokenizer" , dtype = config .weights_dtype
339- )
333+ # LOAD TEXT ENCODERS
334+ clip_text_encoder = FlaxCLIPTextModel .from_pretrained (
335+ config .pretrained_model_name_or_path , subfolder = "text_encoder" , from_pt = True , dtype = config .weights_dtype
336+ )
337+ clip_tokenizer = CLIPTokenizer .from_pretrained (
338+ config .pretrained_model_name_or_path , subfolder = "tokenizer" , dtype = config .weights_dtype
339+ )
340340
341- t5_encoder = FlaxT5EncoderModel .from_pretrained (config .t5xxl_model_name_or_path , dtype = config .weights_dtype )
342- t5_tokenizer = AutoTokenizer .from_pretrained (
343- config .t5xxl_model_name_or_path , max_length = config .max_sequence_length , use_fast = True
344- )
341+ t5_encoder = FlaxT5EncoderModel .from_pretrained (config .t5xxl_model_name_or_path , dtype = config .weights_dtype )
342+ t5_tokenizer = AutoTokenizer .from_pretrained (
343+ config .t5xxl_model_name_or_path , max_length = config .max_sequence_length , use_fast = True
344+ )
345345
346- encoders_sharding = NamedSharding (mesh , P ())
347- partial_device_put_replicated = functools .partial (device_put_replicated , sharding = encoders_sharding )
348- clip_text_encoder .params = jax .tree_util .tree_map (lambda x : x .astype (jnp .bfloat16 ), clip_text_encoder .params )
349- clip_text_encoder .params = jax .tree_util .tree_map (partial_device_put_replicated , clip_text_encoder .params )
350- t5_encoder .params = jax .tree_util .tree_map (lambda x : x .astype (jnp .bfloat16 ), t5_encoder .params )
351- t5_encoder .params = jax .tree_util .tree_map (partial_device_put_replicated , t5_encoder .params )
352-
353- prompt_embeds , pooled_prompt_embeds , text_ids = encode_prompt (
354- prompt = config .prompt ,
355- prompt_2 = config .prompt_2 ,
356- clip_tokenizer = clip_tokenizer ,
357- clip_text_encoder = clip_text_encoder ,
358- t5_tokenizer = t5_tokenizer ,
359- t5_text_encoder = t5_encoder ,
360- num_images_per_prompt = global_batch_size ,
361- max_sequence_length = config .max_sequence_length ,
362- )
346+ encoders_sharding = NamedSharding (mesh , P ())
347+ partial_device_put_replicated = functools .partial (device_put_replicated , sharding = encoders_sharding )
348+ clip_text_encoder .params = jax .tree_util .tree_map (lambda x : x .astype (jnp .bfloat16 ), clip_text_encoder .params )
349+ clip_text_encoder .params = jax .tree_util .tree_map (partial_device_put_replicated , clip_text_encoder .params )
350+ t5_encoder .params = jax .tree_util .tree_map (lambda x : x .astype (jnp .bfloat16 ), t5_encoder .params )
351+ t5_encoder .params = jax .tree_util .tree_map (partial_device_put_replicated , t5_encoder .params )
352+
353+ prompt_embeds , pooled_prompt_embeds , text_ids = encode_prompt (
354+ prompt = config .prompt ,
355+ prompt_2 = config .prompt_2 ,
356+ clip_tokenizer = clip_tokenizer ,
357+ clip_text_encoder = clip_text_encoder ,
358+ t5_tokenizer = t5_tokenizer ,
359+ t5_text_encoder = t5_encoder ,
360+ num_images_per_prompt = global_batch_size ,
361+ max_sequence_length = config .max_sequence_length ,
362+ )
363363
364- def validate_inputs (latents , latent_image_ids , prompt_embeds , text_ids , timesteps , guidance , pooled_prompt_embeds ):
365- print ("latents.shape: " , latents .shape , latents .dtype )
366- print ("latent_image_ids.shape: " , latent_image_ids .shape , latent_image_ids .dtype )
367- print ("text_ids.shape: " , text_ids .shape , text_ids .dtype )
368- print ("prompt_embeds: " , prompt_embeds .shape , prompt_embeds .dtype )
369- print ("timesteps.shape: " , timesteps .shape , timesteps .dtype )
370- print ("guidance.shape: " , guidance .shape , guidance .dtype )
371- print ("pooled_prompt_embeds.shape: " , pooled_prompt_embeds .shape , pooled_prompt_embeds .dtype )
372-
373- guidance = jnp .asarray ([config .guidance_scale ] * global_batch_size , dtype = jnp .bfloat16 )
374-
375- # move inputs to device and shard
376- data_sharding = jax .sharding .NamedSharding (mesh , P (* config .data_sharding ))
377- latents = jax .device_put (latents , data_sharding )
378- latent_image_ids = jax .device_put (latent_image_ids )
379- prompt_embeds = jax .device_put (prompt_embeds , data_sharding )
380- text_ids = jax .device_put (text_ids )
381- guidance = jax .device_put (guidance , data_sharding )
382- pooled_prompt_embeds = jax .device_put (pooled_prompt_embeds , data_sharding )
383-
384- if config .offload_encoders :
385- cpus = jax .devices ("cpu" )
386- t5_encoder .params = jax .device_put (t5_encoder .params , device = cpus [0 ])
387-
388- get_memory_allocations ()
389- # evaluate shapes
390- transformer_eval_params = transformer .init_weights (
391- rngs = rng , max_sequence_length = config .max_sequence_length , eval_only = True
392- )
364+ def validate_inputs (latents , latent_image_ids , prompt_embeds , text_ids , timesteps , guidance , pooled_prompt_embeds ):
365+ print ("latents.shape: " , latents .shape , latents .dtype )
366+ print ("latent_image_ids.shape: " , latent_image_ids .shape , latent_image_ids .dtype )
367+ print ("text_ids.shape: " , text_ids .shape , text_ids .dtype )
368+ print ("prompt_embeds: " , prompt_embeds .shape , prompt_embeds .dtype )
369+ print ("timesteps.shape: " , timesteps .shape , timesteps .dtype )
370+ print ("guidance.shape: " , guidance .shape , guidance .dtype )
371+ print ("pooled_prompt_embeds.shape: " , pooled_prompt_embeds .shape , pooled_prompt_embeds .dtype )
372+
373+ guidance = jnp .asarray ([config .guidance_scale ] * global_batch_size , dtype = jnp .bfloat16 )
374+
375+ # move inputs to device and shard
376+ data_sharding = jax .sharding .NamedSharding (mesh , P (* config .data_sharding ))
377+ latents = jax .device_put (latents , data_sharding )
378+ latent_image_ids = jax .device_put (latent_image_ids )
379+ prompt_embeds = jax .device_put (prompt_embeds , data_sharding )
380+ text_ids = jax .device_put (text_ids )
381+ guidance = jax .device_put (guidance , data_sharding )
382+ pooled_prompt_embeds = jax .device_put (pooled_prompt_embeds , data_sharding )
383+
384+ if config .offload_encoders :
385+ cpus = jax .devices ("cpu" )
386+ t5_encoder .params = jax .device_put (t5_encoder .params , device = cpus [0 ])
387+
388+ get_memory_allocations ()
389+ # evaluate shapes
390+ transformer_eval_params = transformer .init_weights (
391+ rngs = rng , max_sequence_length = config .max_sequence_length , eval_only = True
392+ )
393393
394- # loads pretrained weights
395- transformer_params = load_flow_model (config .flux_name , transformer_eval_params , "cpu" )
396- params = {}
397- params ["transformer" ] = transformer_params
398- # maybe load lora and create interceptor
399- lora_loader = FluxLoraLoaderMixin ()
400- params , lora_interceptors = maybe_load_flux_lora (config , lora_loader , params )
401- transformer_params = params ["transformer" ]
402- # create transformer state
403- weights_init_fn = functools .partial (
404- transformer .init_weights , rngs = rng , max_sequence_length = config .max_sequence_length , eval_only = False
405- )
406- with ExitStack () as stack :
407- _ = [stack .enter_context (nn .intercept_methods (interceptor )) for interceptor in lora_interceptors ]
408- transformer_state , transformer_state_shardings = setup_initial_state (
409- model = transformer ,
410- tx = None ,
411- config = config ,
412- mesh = mesh ,
413- weights_init_fn = weights_init_fn ,
414- model_params = None ,
415- training = False ,
394+ # loads pretrained weights
395+ transformer_params = load_flow_model (config .flux_name , transformer_eval_params , "cpu" )
396+ params = {}
397+ params ["transformer" ] = transformer_params
398+ # maybe load lora and create interceptor
399+ lora_loader = FluxLoraLoaderMixin ()
400+ params , lora_interceptors = maybe_load_flux_lora (config , lora_loader , params )
401+ transformer_params = params ["transformer" ]
402+ # create transformer state
403+ weights_init_fn = functools .partial (
404+ transformer .init_weights , rngs = rng , max_sequence_length = config .max_sequence_length , eval_only = False
416405 )
417- transformer_state = transformer_state .replace (params = transformer_params )
418- transformer_state = jax .device_put (transformer_state , transformer_state_shardings )
419- get_memory_allocations ()
420-
421- states = {}
422- state_shardings = {}
423-
424- state_shardings ["transformer" ] = transformer_state_shardings
425- state_shardings ["vae" ] = vae_state_shardings
426-
427- states ["transformer" ] = transformer_state
428- states ["vae" ] = vae_state
429-
430- # Setup timesteps
431- timesteps = jnp .linspace (1 , 0 , config .num_inference_steps + 1 )
432- # shifting the schedule to favor high timesteps for higher signal images
433- if config .time_shift :
434- # estimate mu based on linear estimation between two points
435- lin_function = get_lin_function (x1 = config .max_sequence_length , y1 = config .base_shift , y2 = config .max_shift )
436- mu = lin_function (latents .shape [1 ])
437- timesteps = time_shift (mu , 1.0 , timesteps )
438- c_ts = timesteps [:- 1 ]
439- p_ts = timesteps [1 :]
440-
441- validate_inputs (latents , latent_image_ids , prompt_embeds , text_ids , timesteps , guidance , pooled_prompt_embeds )
442-
443- p_run_inference = jax .jit (
444- functools .partial (
445- run_inference ,
446- transformer = transformer ,
447- vae = vae ,
406+ with ExitStack () as stack :
407+ _ = [stack .enter_context (nn .intercept_methods (interceptor )) for interceptor in lora_interceptors ]
408+ transformer_state , transformer_state_shardings = setup_initial_state (
409+ model = transformer ,
410+ tx = None ,
448411 config = config ,
449412 mesh = mesh ,
450- latents = latents ,
451- latent_image_ids = latent_image_ids ,
452- prompt_embeds = prompt_embeds ,
453- txt_ids = text_ids ,
454- vec = pooled_prompt_embeds ,
455- guidance_vec = guidance ,
456- c_ts = c_ts ,
457- p_ts = p_ts ,
458- ),
459- in_shardings = (state_shardings ,),
460- out_shardings = None ,
461- )
462- t0 = time .perf_counter ()
463- with ExitStack () as stack :
464- _ = [stack .enter_context (nn .intercept_methods (interceptor )) for interceptor in lora_interceptors ]
465- p_run_inference (states ).block_until_ready ()
466- t1 = time .perf_counter ()
467- max_logging .log (f"Compile time: { t1 - t0 :.1f} s." )
468-
469- t0 = time .perf_counter ()
470- with ExitStack () as stack , jax .profiler .trace ("/tmp/trace/" ):
471- _ = [stack .enter_context (nn .intercept_methods (interceptor )) for interceptor in lora_interceptors ]
472- imgs = p_run_inference (states ).block_until_ready ()
473- t1 = time .perf_counter ()
474- max_logging .log (f"Inference time: { t1 - t0 :.1f} s." )
475-
476- t0 = time .perf_counter ()
477- with ExitStack () as stack :
478- _ = [stack .enter_context (nn .intercept_methods (interceptor )) for interceptor in lora_interceptors ]
479- imgs = p_run_inference (states ).block_until_ready ()
480- imgs = jax .experimental .multihost_utils .process_allgather (imgs , tiled = True )
481- t1 = time .perf_counter ()
482- max_logging .log (f"Inference time: { t1 - t0 :.1f} s." )
483- imgs = np .array (imgs )
484- imgs = (imgs * 0.5 + 0.5 ).clip (0 , 1 )
485- imgs = np .transpose (imgs , (0 , 2 , 3 , 1 ))
486- imgs = np .uint8 (imgs * 255 )
487- for i , image in enumerate (imgs ):
488- Image .fromarray (image ).save (f"flux_{ i } .png" )
489-
490- return imgs
413+ weights_init_fn = weights_init_fn ,
414+ model_params = None ,
415+ training = False ,
416+ )
417+ transformer_state = transformer_state .replace (params = transformer_params )
418+ transformer_state = jax .device_put (transformer_state , transformer_state_shardings )
419+ get_memory_allocations ()
420+
421+ states = {}
422+ state_shardings = {}
423+
424+ state_shardings ["transformer" ] = transformer_state_shardings
425+ state_shardings ["vae" ] = vae_state_shardings
426+
427+ states ["transformer" ] = transformer_state
428+ states ["vae" ] = vae_state
429+
430+ # Setup timesteps
431+ timesteps = jnp .linspace (1 , 0 , config .num_inference_steps + 1 )
432+ # shifting the schedule to favor high timesteps for higher signal images
433+ if config .time_shift :
434+ # estimate mu based on linear estimation between two points
435+ lin_function = get_lin_function (x1 = config .max_sequence_length , y1 = config .base_shift , y2 = config .max_shift )
436+ mu = lin_function (latents .shape [1 ])
437+ timesteps = time_shift (mu , 1.0 , timesteps )
438+ c_ts = timesteps [:- 1 ]
439+ p_ts = timesteps [1 :]
440+
441+ validate_inputs (latents , latent_image_ids , prompt_embeds , text_ids , timesteps , guidance , pooled_prompt_embeds )
442+
443+ p_run_inference = jax .jit (
444+ functools .partial (
445+ run_inference ,
446+ transformer = transformer ,
447+ vae = vae ,
448+ config = config ,
449+ mesh = mesh ,
450+ latents = latents ,
451+ latent_image_ids = latent_image_ids ,
452+ prompt_embeds = prompt_embeds ,
453+ txt_ids = text_ids ,
454+ vec = pooled_prompt_embeds ,
455+ guidance_vec = guidance ,
456+ c_ts = c_ts ,
457+ p_ts = p_ts ,
458+ ),
459+ in_shardings = (state_shardings ,),
460+ out_shardings = None ,
461+ )
462+ t0 = time .perf_counter ()
463+ with ExitStack () as stack :
464+ _ = [stack .enter_context (nn .intercept_methods (interceptor )) for interceptor in lora_interceptors ]
465+ p_run_inference (states ).block_until_ready ()
466+ t1 = time .perf_counter ()
467+ max_logging .log (f"Compile time: { t1 - t0 :.1f} s." )
468+
469+ t0 = time .perf_counter ()
470+ with ExitStack () as stack , jax .profiler .trace ("/tmp/trace/" ):
471+ _ = [stack .enter_context (nn .intercept_methods (interceptor )) for interceptor in lora_interceptors ]
472+ imgs = p_run_inference (states ).block_until_ready ()
473+ t1 = time .perf_counter ()
474+ max_logging .log (f"Inference time: { t1 - t0 :.1f} s." )
475+
476+ t0 = time .perf_counter ()
477+ with ExitStack () as stack :
478+ _ = [stack .enter_context (nn .intercept_methods (interceptor )) for interceptor in lora_interceptors ]
479+ imgs = p_run_inference (states ).block_until_ready ()
480+ imgs = jax .experimental .multihost_utils .process_allgather (imgs , tiled = True )
481+ t1 = time .perf_counter ()
482+ max_logging .log (f"Inference time: { t1 - t0 :.1f} s." )
483+ imgs = np .array (imgs )
484+ imgs = (imgs * 0.5 + 0.5 ).clip (0 , 1 )
485+ imgs = np .transpose (imgs , (0 , 2 , 3 , 1 ))
486+ imgs = np .uint8 (imgs * 255 )
487+ for i , image in enumerate (imgs ):
488+ Image .fromarray (image ).save (f"flux_{ i } .png" )
489+
490+ return imgs
491491
492492
493493def main (argv : Sequence [str ]) -> None :
0 commit comments