@@ -33,85 +33,86 @@ def run(config):
3333 from maxdiffusion .checkpointing .flux_checkpointer import FluxCheckpointer
3434
3535 checkpoint_loader = FluxCheckpointer (config , "FLUX_CHECKPOINT" )
36- pipeline , params = checkpoint_loader .load_checkpoint ()
37-
38- if not params :
39- ## VAE
40- weights_init_fn = functools .partial (pipeline .vae .init_weights , rng = checkpoint_loader .rng )
41- unboxed_abstract_state , _ , _ = max_utils .get_abstract_state (
42- pipeline .vae , None , config , checkpoint_loader .mesh , weights_init_fn , False
43- )
44- # load unet params from orbax checkpoint
45- vae_params = load_params_from_path (
46- config , checkpoint_loader .checkpoint_manager , unboxed_abstract_state .params , "vae_state"
47- )
48-
49- vae_state = {"params" : vae_params }
50-
51- ## Flux
52- weights_init_fn = functools .partial (
53- pipeline .flux .init_weights , rngs = checkpoint_loader .rng , max_sequence_length = config .max_sequence_length
54- )
55-
56- unboxed_abstract_state , _ , _ = max_utils .get_abstract_state (
57- pipeline .flux , None , config , checkpoint_loader .mesh , weights_init_fn , False
58- )
59- # load unet params from orbax checkpoint
60- flux_params = load_params_from_path (
61- config , checkpoint_loader .checkpoint_manager , unboxed_abstract_state .params , "flux_state"
62- )
63- flux_state = {"params" : flux_params }
64- else :
65- weights_init_fn = functools .partial (
66- pipeline .flux .init_weights ,
67- rngs = checkpoint_loader .rng ,
68- max_sequence_length = config .max_sequence_length ,
69- eval_only = False ,
70- )
71- transformer_state , flux_state_shardings = setup_initial_state (
72- model = pipeline .flux ,
73- tx = None ,
74- config = config ,
75- mesh = checkpoint_loader .mesh ,
76- weights_init_fn = weights_init_fn ,
77- model_params = None ,
78- training = False ,
79- )
80- transformer_state = transformer_state .replace (params = params ["flux_transformer_params" ])
81- transformer_state = jax .device_put (transformer_state , flux_state_shardings )
82-
83- weights_init_fn = functools .partial (pipeline .vae .init_weights , rng = checkpoint_loader .rng )
84- vae_state , _ = setup_initial_state (
85- model = pipeline .vae ,
86- tx = None ,
87- config = config ,
88- mesh = checkpoint_loader .mesh ,
89- weights_init_fn = weights_init_fn ,
90- model_params = params ["flux_vae" ],
91- training = False ,
92- )
93-
94- vae_state = {"params" : vae_state .params }
95- flux_state = {"params" : transformer_state .params }
96-
97- t0 = time .perf_counter ()
98- with ExitStack ():
99- imgs = pipeline (flux_params = flux_state , timesteps = config .num_inference_steps , vae_params = vae_state ).block_until_ready ()
100- t1 = time .perf_counter ()
101- max_logging .log (f"Compile time: { t1 - t0 :.1f} s." )
102-
103- t0 = time .perf_counter ()
104- with ExitStack ():
105- imgs = pipeline (flux_params = flux_state , timesteps = config .num_inference_steps , vae_params = vae_state ).block_until_ready ()
106- imgs = jax .experimental .multihost_utils .process_allgather (imgs , tiled = True )
107- t1 = time .perf_counter ()
108- max_logging .log (f"Inference time: { t1 - t0 :.1f} s." )
109- imgs = np .array (imgs )
110- imgs = (imgs * 0.5 + 0.5 ).clip (0 , 1 )
111- imgs = np .transpose (imgs , (0 , 2 , 3 , 1 ))
112- imgs = np .uint8 (imgs * 255 )
113- for i , image in enumerate (imgs ):
114- Image .fromarray (image ).save (f"flux_{ i } .png" )
36+ mesh = checkpoint_loader .mesh
37+ with mesh :
38+ pipeline , params = checkpoint_loader .load_checkpoint ()
39+
40+ if not params :
41+ ## VAE
42+ weights_init_fn = functools .partial (pipeline .vae .init_weights , rng = checkpoint_loader .rng )
43+ unboxed_abstract_state , _ , _ = max_utils .get_abstract_state (
44+ pipeline .vae , None , config , checkpoint_loader .mesh , weights_init_fn , False
45+ )
46+ # load VAE params from orbax checkpoint
47+ vae_params = load_params_from_path (
48+ config , checkpoint_loader .checkpoint_manager , unboxed_abstract_state .params , "vae_state"
49+ )
50+
51+ vae_state = {"params" : vae_params }
52+
53+ ## Flux
54+ weights_init_fn = functools .partial (
55+ pipeline .flux .init_weights , rngs = checkpoint_loader .rng , max_sequence_length = config .max_sequence_length
56+ )
57+ unboxed_abstract_state , _ , _ = max_utils .get_abstract_state (
58+ pipeline .flux , None , config , checkpoint_loader .mesh , weights_init_fn , False
59+ )
60+ # load Flux transformer params from orbax checkpoint
61+ flux_params = load_params_from_path (
62+ config , checkpoint_loader .checkpoint_manager , unboxed_abstract_state .params , "flux_state"
63+ )
64+ flux_state = {"params" : flux_params }
65+ else :
66+ weights_init_fn = functools .partial (
67+ pipeline .flux .init_weights ,
68+ rngs = checkpoint_loader .rng ,
69+ max_sequence_length = config .max_sequence_length ,
70+ eval_only = False ,
71+ )
72+ transformer_state , flux_state_shardings = setup_initial_state (
73+ model = pipeline .flux ,
74+ tx = None ,
75+ config = config ,
76+ mesh = checkpoint_loader .mesh ,
77+ weights_init_fn = weights_init_fn ,
78+ model_params = None ,
79+ training = False ,
80+ )
81+ transformer_state = transformer_state .replace (params = params ["flux_transformer_params" ])
82+ transformer_state = jax .device_put (transformer_state , flux_state_shardings )
83+
84+ weights_init_fn = functools .partial (pipeline .vae .init_weights , rng = checkpoint_loader .rng )
85+ vae_state , _ = setup_initial_state (
86+ model = pipeline .vae ,
87+ tx = None ,
88+ config = config ,
89+ mesh = checkpoint_loader .mesh ,
90+ weights_init_fn = weights_init_fn ,
91+ model_params = params ["flux_vae" ],
92+ training = False ,
93+ )
94+
95+ vae_state = {"params" : vae_state .params }
96+ flux_state = {"params" : transformer_state .params }
97+
98+ t0 = time .perf_counter ()
99+ with ExitStack ():
100+ imgs = pipeline (flux_params = flux_state , timesteps = config .num_inference_steps , vae_params = vae_state ).block_until_ready ()
101+ t1 = time .perf_counter ()
102+ max_logging .log (f"Compile time: { t1 - t0 :.1f} s." )
103+
104+ t0 = time .perf_counter ()
105+ with ExitStack ():
106+ imgs = pipeline (flux_params = flux_state , timesteps = config .num_inference_steps , vae_params = vae_state ).block_until_ready ()
107+ imgs = jax .experimental .multihost_utils .process_allgather (imgs , tiled = True )
108+ t1 = time .perf_counter ()
109+ max_logging .log (f"Inference time: { t1 - t0 :.1f} s." )
110+ imgs = np .array (imgs )
111+ imgs = (imgs * 0.5 + 0.5 ).clip (0 , 1 )
112+ imgs = np .transpose (imgs , (0 , 2 , 3 , 1 ))
113+ imgs = np .uint8 (imgs * 255 )
114+ for i , image in enumerate (imgs ):
115+ Image .fromarray (image ).save (f"flux_{ i } .png" )
115116
116117 return imgs
117118
0 commit comments