2323import jax
2424import jax .numpy as jnp
2525from jax .sharding import PartitionSpec as P
26+ import flax .linen as nn
2627from flax .linen import partitioning as nn_partitioning
2728
28- from maxdiffusion import (
29- FlaxEulerDiscreteScheduler ,
30- )
31-
32-
3329from maxdiffusion import pyconfig , max_utils
3430from maxdiffusion .image_processor import VaeImageProcessor
35- from maxdiffusion .maxdiffusion_utils import (get_add_time_ids , rescale_noise_cfg , load_sdxllightning_unet )
31+ from maxdiffusion .maxdiffusion_utils import (
32+ get_add_time_ids ,
33+ rescale_noise_cfg ,
34+ load_sdxllightning_unet ,
35+ maybe_load_lora ,
36+ create_scheduler ,
37+ )
3638
3739from maxdiffusion .trainers .sdxl_trainer import (StableDiffusionXLTrainer )
3840
@@ -82,7 +84,6 @@ def apply_classifier_free_guidance(noise_pred, guidance_scale):
8284 lambda _ : noise_pred ,
8385 operand = None ,
8486 )
85-
8687 latents , scheduler_state = pipeline .scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
8788
8889 return latents , scheduler_state , state
@@ -217,6 +218,8 @@ def run(config):
217218 checkpoint_loader = GenerateSDXL (config )
218219 pipeline , params = checkpoint_loader .load_checkpoint ()
219220
221+ noise_scheduler , noise_scheduler_state = create_scheduler (pipeline .scheduler .config , config )
222+
220223 weights_init_fn = functools .partial (pipeline .unet .init_weights , rng = checkpoint_loader .rng )
221224 unboxed_abstract_state , _ , _ = max_utils .get_abstract_state (
222225 pipeline .unet , None , config , checkpoint_loader .mesh , weights_init_fn , False
@@ -228,20 +231,24 @@ def run(config):
228231 if unet_params :
229232 params ["unet" ] = unet_params
230233
234+ # maybe load lora and create interceptor
235+ params , lora_interceptor = maybe_load_lora (config , pipeline , params )
236+
231237 if config .lightning_repo :
232238 pipeline , params = load_sdxllightning_unet (config , pipeline , params )
233239
234- # Don't restore the train state to save memory , just restore params
240+ # Don't restore the full train state, instead , just restore params
235241 # and create an inference state.
236- unet_state , unet_state_shardings = max_utils .setup_initial_state (
237- model = pipeline .unet ,
238- tx = None ,
239- config = config ,
240- mesh = checkpoint_loader .mesh ,
241- weights_init_fn = weights_init_fn ,
242- model_params = params .get ("unet" , None ),
243- training = False ,
244- )
242+ with nn .intercept_methods (lora_interceptor ):
243+ unet_state , unet_state_shardings = max_utils .setup_initial_state (
244+ model = pipeline .unet ,
245+ tx = None ,
246+ config = config ,
247+ mesh = checkpoint_loader .mesh ,
248+ weights_init_fn = weights_init_fn ,
249+ model_params = params .get ("unet" , None ),
250+ training = False ,
251+ )
245252
246253 vae_state , vae_state_shardings = checkpoint_loader .create_vae_state (
247254 pipeline , params , checkpoint_item_name = "vae_state" , is_training = False
@@ -267,14 +274,6 @@ def run(config):
267274 states ["text_encoder_state" ] = text_encoder_state
268275 states ["text_encoder_2_state" ] = text_encoder_2_state
269276
270- noise_scheduler , noise_scheduler_state = FlaxEulerDiscreteScheduler .from_pretrained (
271- config .pretrained_model_name_or_path ,
272- revision = config .revision ,
273- subfolder = "scheduler" ,
274- dtype = jnp .float32 ,
275- timestep_spacing = "trailing" ,
276- )
277-
278277 pipeline .scheduler = noise_scheduler
279278 params ["scheduler" ] = noise_scheduler_state
280279
@@ -293,10 +292,12 @@ def run(config):
293292 )
294293
295294 s = time .time ()
296- p_run_inference (states ).block_until_ready ()
295+ with nn .intercept_methods (lora_interceptor ):
296+ p_run_inference (states ).block_until_ready ()
297297 print ("compile time: " , (time .time () - s ))
298298 s = time .time ()
299- images = p_run_inference (states ).block_until_ready ()
299+ with nn .intercept_methods (lora_interceptor ):
300+ images = p_run_inference (states ).block_until_ready ()
300301 print ("inference time: " , (time .time () - s ))
301302 images = jax .experimental .multihost_utils .process_allgather (images )
302303 numpy_images = np .array (images )