limitations under the License.
"""

17- import functools
18- from absl import app
19- from contextlib import ExitStack
2017from typing import Sequence
18+ from absl import app
2119import time
22-
23- import numpy as np
24- import jax
25- import jax .numpy as jnp
26- from jax .sharding import PartitionSpec as P
27- import flax .linen as nn
28- from flax .linen import partitioning as nn_partitioning
29-
30- from maxdiffusion import pyconfig , max_utils
31- from maxdiffusion .image_processor import VaeImageProcessor
32- from maxdiffusion .maxdiffusion_utils import (
33- get_add_time_ids ,
34- rescale_noise_cfg ,
35- load_sdxllightning_unet ,
36- maybe_load_sdxl_lora ,
37- create_scheduler ,
38- )
39-
40- from maxdiffusion .trainers .sdxl_trainer import (StableDiffusionXLTrainer )
41-
42- from maxdiffusion .checkpointing .checkpointing_utils import load_params_from_path
43-
44-
class GenerateSDXL(StableDiffusionXLTrainer):
  """Thin wrapper around StableDiffusionXLTrainer.

  Used only to reuse the trainer's checkpoint-loading / state-creation
  machinery for inference; it adds no behavior of its own.
  """

  def __init__(self, config):
    super().__init__(config)
def loop_body(step, args, model, pipeline, added_cond_kwargs, prompt_embeds, guidance_scale, guidance_rescale, config):
  """Run one denoising step inside `jax.lax.fori_loop`.

  Args:
    step: loop index selecting the current scheduler timestep.
    args: carry tuple (latents, scheduler_state, state).
    model: the UNet module; `state.params` holds its parameters.
    pipeline: provides the scheduler used for input scaling and stepping.
    added_cond_kwargs: SDXL micro-conditioning (text_embeds / time_ids).
    prompt_embeds: encoder hidden states fed to the UNet.
    guidance_scale: 1-element float32 array, CFG strength.
    guidance_rescale: 1-element float32 array; > 0 enables noise rescaling.
    config: read for `do_classifier_free_guidance`.

  Returns:
    Updated carry tuple (latents, scheduler_state, state).
  """
  latents, scheduler_state, state = args

  cfg_enabled = config.do_classifier_free_guidance
  # Under CFG the batch is doubled: [unconditional, conditional].
  model_input = jnp.concatenate([latents] * 2) if cfg_enabled else latents

  t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
  timesteps = jnp.broadcast_to(t, model_input.shape[0])

  model_input = pipeline.scheduler.scale_model_input(scheduler_state, model_input, t)
  noise_pred = model.apply(
      {"params": state.params},
      jnp.array(model_input),
      jnp.array(timesteps, dtype=jnp.int32),
      encoder_hidden_states=prompt_embeds,
      added_cond_kwargs=added_cond_kwargs,
  ).sample

  if cfg_enabled:
    uncond_pred, text_pred = jnp.split(noise_pred, 2, axis=0)
    guided = uncond_pred + guidance_scale * (text_pred - uncond_pred)

    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
    # Helps solve overexposure problem when terminal SNR approaches zero.
    # Empirical values recomended from the paper are guidance_scale=7.5 and guidance_rescale=0.7
    noise_pred = jax.lax.cond(
        guidance_rescale[0] > 0,
        lambda _: rescale_noise_cfg(guided, text_pred, guidance_rescale),
        lambda _: guided,
        operand=None,
    )

  latents, scheduler_state = pipeline.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
  return latents, scheduler_state, state
91-
92-
def get_embeddings(prompt_ids, pipeline, params):
  """Encode tokenized prompts with both SDXL text encoders.

  Args:
    prompt_ids: array of shape (batch, 2, seq) — one token row per encoder.
    pipeline: provides `text_encoder` and `text_encoder_2`.
    params: dict with keys "text_encoder" and "text_encoder_2".

  Returns:
    (prompt_embeds, text_embeds): per-token embeddings concatenated over the
    feature axis, and the pooled embeddings from the second encoder.
  """
  out_1 = pipeline.text_encoder(prompt_ids[:, 0, :], params=params["text_encoder"], output_hidden_states=True)
  out_2 = pipeline.text_encoder_2(prompt_ids[:, 1, :], params=params["text_encoder_2"], output_hidden_states=True)

  # Both encoders contribute their penultimate hidden state (clip-skip style).
  prompt_embeds = jnp.concatenate(
      [out_1["hidden_states"][-2], out_2["hidden_states"][-2]], axis=-1
  )
  return prompt_embeds, out_2["text_embeds"]
104-
105-
def tokenize(prompt, pipeline):
  """Tokenize `prompt` with both SDXL tokenizers.

  Returns an array of shape (batch, 2, max_length) where axis 1 indexes the
  tokenizer (0 = `pipeline.tokenizer`, 1 = `pipeline.tokenizer_2`).
  """
  per_tokenizer_ids = [
      tok(
          prompt, padding="max_length", max_length=tok.model_max_length, truncation=True, return_tensors="np"
      ).input_ids
      for tok in (pipeline.tokenizer, pipeline.tokenizer_2)
  ]
  return jnp.stack(per_tokenizer_ids, axis=1)
115-
116-
def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):
  """Build every tensor the denoising loop needs.

  Encodes positive/negative prompts, assembles SDXL micro-conditioning,
  samples initial latents, initializes the scheduler, and shards the large
  tensors according to `config.data_sharding`.

  Returns:
    (latents, prompt_embeds, added_cond_kwargs, guidance_scale,
     guidance_rescale, scheduler_state)
  """
  data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))

  # Spatial downscaling between pixel space and VAE latent space.
  vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1)

  prompt_ids = tokenize([config.prompt] * batch_size, pipeline)
  negative_prompt_ids = tokenize([config.negative_prompt] * batch_size, pipeline)

  height = config.resolution
  width = config.resolution

  text_encoder_params = {
      "text_encoder": states["text_encoder_state"].params,
      "text_encoder_2": states["text_encoder_2_state"].params,
  }
  prompt_embeds, pooled_embeds = get_embeddings(prompt_ids, pipeline, text_encoder_params)

  batch_size = prompt_embeds.shape[0]
  add_time_ids = get_add_time_ids(
      (height, width), (0, 0), (height, width), batch_size, dtype=prompt_embeds.dtype
  )

  if config.do_classifier_free_guidance:
    if negative_prompt_ids is None:
      negative_prompt_embeds = jnp.zeros_like(prompt_embeds)
      negative_pooled_embeds = jnp.zeros_like(pooled_embeds)
    else:
      negative_prompt_embeds, negative_pooled_embeds = get_embeddings(negative_prompt_ids, pipeline, text_encoder_params)

    # Batch layout is [unconditional, conditional] — loop_body splits on axis 0.
    prompt_embeds = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0)
    add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0)
    add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0)
  else:
    add_text_embeds = pooled_embeds

  # Ensure model output will be `float32` before going into the scheduler.
  guidance_scale = jnp.array([config.guidance_scale], dtype=jnp.float32)
  guidance_rescale = jnp.array([config.guidance_rescale], dtype=jnp.float32)

  latents_shape = (
      batch_size,
      pipeline.unet.config.in_channels,
      height // vae_scale_factor,
      width // vae_scale_factor,
  )
  latents = jax.random.normal(rng, shape=latents_shape, dtype=jnp.float32)

  scheduler_state = pipeline.scheduler.set_timesteps(
      params["scheduler"], num_inference_steps=config.num_inference_steps, shape=latents.shape
  )
  latents = latents * scheduler_state.init_noise_sigma

  added_cond_kwargs = {
      "text_embeds": jax.device_put(add_text_embeds, data_sharding),
      "time_ids": jax.device_put(add_time_ids, data_sharding),
  }
  latents = jax.device_put(latents, data_sharding)
  prompt_embeds = jax.device_put(prompt_embeds, data_sharding)

  return latents, prompt_embeds, added_cond_kwargs, guidance_scale, guidance_rescale, scheduler_state
181-
182-
def vae_decode(latents, state, pipeline):
  """Decode scheduler-space latents into images.

  Returns images in [0, 1] with layout NHWC (input latents are NCHW).
  """
  # Undo the VAE scaling applied when latents were produced.
  scaled = (1 / pipeline.vae.config.scaling_factor) * latents
  decoded = pipeline.vae.apply({"params": state.params}, scaled, method=pipeline.vae.decode).sample
  # Map [-1, 1] -> [0, 1] and move channels last.
  return (decoded / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
188-
189-
def run_inference(states, pipeline, params, config, rng, mesh, batch_size):
  """Generate images: prepare UNet inputs, denoise, and VAE-decode.

  Args:
    states: dict of train-state objects keyed by "unet_state", "vae_state",
      "text_encoder_state", "text_encoder_2_state".

  Returns:
    Decoded images as produced by `vae_decode` (NHWC, values in [0, 1]).
  """
  unet_state = states["unet_state"]
  vae_state = states["vae_state"]

  (latents, prompt_embeds, added_cond_kwargs, guidance_scale, guidance_rescale, scheduler_state) = get_unet_inputs(
      pipeline, params, states, config, rng, mesh, batch_size
  )

  denoise_step = functools.partial(
      loop_body,
      model=pipeline.unet,
      pipeline=pipeline,
      added_cond_kwargs=added_cond_kwargs,
      prompt_embeds=prompt_embeds,
      guidance_scale=guidance_scale,
      guidance_rescale=guidance_rescale,
      config=config,
  )
  decode = functools.partial(vae_decode, pipeline=pipeline)

  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
    carry = (latents, scheduler_state, unet_state)
    latents, _, _ = jax.lax.fori_loop(0, config.num_inference_steps, denoise_step, carry)
    return decode(latents, vae_state)
214-
20+ from maxdiffusion import pyconfig , max_logging
21+ from maxdiffusion .inference .loader import InferenceLoader
22+ from maxdiffusion .inference .runner import DiffusionRunner
21523
def run(config):
  """Load the model, run diffusion inference, and save the generated images.

  The legacy inline checkpoint-loading/inference body was removed in favor of
  InferenceLoader/DiffusionRunner; keeping both made the new path unreachable.

  Args:
    config: pyconfig configuration object driving loading and inference.

  Returns:
    The list of PIL images produced by the runner.
  """
  # 1. Load Model
  max_logging.log("Initializing InferenceLoader...")
  loaded_model = InferenceLoader.load(config)

  # 2. Initialize Runner
  max_logging.log("Initializing DiffusionRunner...")
  runner = DiffusionRunner(loaded_model, config)

  # 3. Run Inference (perf_counter is monotonic, suitable for timing)
  max_logging.log("Starting Inference...")
  t0 = time.perf_counter()
  pil_images = runner.run()
  t1 = time.perf_counter()
  max_logging.log(f"Inference time: {t1 - t0:.2f}s")

  # 4. Save Images
  for i, image in enumerate(pil_images):
    save_path = f"image_sdxl_{i}.png"
    image.save(save_path)
    max_logging.log(f"Saved image to {save_path}")

  return pil_images
31847
def main(argv: Sequence[str]) -> None:
  """absl.app entry point: initialize pyconfig from argv and run inference."""
  pyconfig.initialize(argv)
  config = pyconfig.config
  run(config)
32251
323-
# Script entry point; absl.app handles flag parsing and calls main(argv).
# (The merged diff had this line duplicated — exactly one call is intended.)
if __name__ == "__main__":
  app.run(main)
0 commit comments