1717from typing import Any , Callable , Dict , List , Optional , Union , Sequence
1818from absl import app
1919import functools
20+ import math
2021import numpy as np
2122import jax
2223from jax .sharding import Mesh , PositionalSharding , PartitionSpec as P
2324import jax .numpy as jnp
2425from chex import Array
26+ from einops import rearrange
27+ from flax .linen import partitioning as nn_partitioning
2528from transformers import (
2629 CLIPTokenizer ,
2730 FlaxCLIPTextModel ,
4245 setup_initial_state
4346)
4447
def unpack(x: Array, height: int, width: int) -> Array:
    """Invert latent packing: (b, h*w, c*4) tokens -> (b, c, h*2, w*2) grid.

    Each token carries a 2x2 spatial patch per channel; `height`/`width`
    are the target image dimensions with a 16x latent downsampling factor.
    """
    grid_h = math.ceil(height / 16)
    grid_w = math.ceil(width / 16)
    batch = x.shape[0]
    channels = x.shape[2] // 4  # 4 = ph * pw elements of each 2x2 patch
    # Split the token axis into (h, w) and the feature axis into (c, ph, pw).
    patches = x.reshape(batch, grid_h, grid_w, channels, 2, 2)
    # (b, h, w, c, ph, pw) -> (b, c, h, ph, w, pw), then fuse patch dims.
    patches = patches.transpose(0, 3, 1, 4, 2, 5)
    return patches.reshape(batch, channels, grid_h * 2, grid_w * 2)
57+
def vae_decode(latents, vae, state, config):
    """Decode packed latents into an image via the VAE.

    Args:
        latents: packed latent tokens — assumes the layout produced by the
            packing step, i.e. (b, h*w, c*4); TODO confirm against caller.
        vae: Flax VAE module exposing a ``decode`` method.
        state: train state holding the VAE parameters.
        config: config object with ``resolution`` (square images assumed,
            since it is used for both height and width).

    Returns:
        The first sample of the VAE decoder output.
    """
    # Un-pack the token sequence back into an image-shaped latent grid.
    img = unpack(x=latents, height=config.resolution, width=config.resolution)
    # deterministic=True: inference mode, no dropout.
    img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample[0]
    # NOTE(review): removed a leftover `breakpoint()` here — it would halt
    # (or break tracing of) every decode call in production.
    return img
63+
def loop_body(
    step,
    args,
    transformer,
    latent_image_ids,
    prompt_embeds,
    txt_ids,
    vec,
    guidance_vec,
):
    """One Euler step of the flow sampler; shaped as a `fori_loop` body.

    ``args`` is the loop carry ``(latents, state, c_ts, p_ts)``; the latents
    are advanced from timestep ``c_ts[step]`` to ``p_ts[step]`` and the rest
    of the carry is passed through unchanged.
    """
    latents, state, c_ts, p_ts = args
    orig_dtype = latents.dtype
    t_curr, t_prev = c_ts[step], p_ts[step]
    # Broadcast the scalar timestep across the batch dimension.
    timestep_batch = jnp.full((latents.shape[0],), t_curr, dtype=orig_dtype)
    velocity = transformer.apply(
        {"params": state.params},
        img=latents,
        img_ids=latent_image_ids,
        txt=prompt_embeds,
        txt_ids=txt_ids,
        timesteps=timestep_batch,
        guidance=guidance_vec,
        y=vec,
    )
    # Euler update, then cast back in case the model upcast the result.
    stepped = latents + (t_prev - t_curr) * velocity
    return jnp.array(stepped, dtype=orig_dtype), state, c_ts, p_ts
92+
4593def prepare_latent_image_ids (height , width ):
4694 latent_image_ids = jnp .zeros ((height , width , 3 ))
4795 latent_image_ids = latent_image_ids .at [..., 1 ].set (
@@ -59,6 +107,45 @@ def prepare_latent_image_ids(height, width):
59107
60108 return latent_image_ids .astype (jnp .bfloat16 )
61109
def run_inference(
    states,
    transformer,
    vae,
    config,
    mesh,
    latents,
    latent_image_ids,
    prompt_embeds,
    txt_ids,
    vec,
    guidance_vec,
):
    """Run the denoising loop over all timesteps and decode the final image.

    Args:
        states: dict with ``"transformer"`` and ``"vae"`` train states.
        transformer: Flax transformer module used as the denoiser.
        vae: Flax VAE module used to decode the final latents.
        config: config providing ``num_inference_steps``, ``resolution`` and
            ``logical_axis_rules``.
        mesh: JAX device mesh; entered as a context for sharded execution.
        latents: initial noise latents (packed token layout).
        latent_image_ids / prompt_embeds / txt_ids / vec / guidance_vec:
            conditioning inputs forwarded unchanged to each denoiser step.

    Returns:
        The decoded image array.
    """
    # Schedule runs linearly from t=1 down to t=0; step i integrates from
    # c_ts[i] to p_ts[i] (consecutive pairs of the same linspace).
    timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1)
    c_ts = timesteps[:-1]
    p_ts = timesteps[1:]

    transformer_state = states["transformer"]
    vae_state = states["vae"]

    # Bind the step-invariant conditioning so fori_loop sees a (step, carry)
    # signature.
    loop_body_p = functools.partial(
        loop_body,
        transformer=transformer,
        latent_image_ids=latent_image_ids,
        prompt_embeds=prompt_embeds,
        txt_ids=txt_ids,
        vec=vec,
        guidance_vec=guidance_vec,
    )

    vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config)

    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
        latents, _, _, _ = jax.lax.fori_loop(
            0,
            config.num_inference_steps,
            loop_body_p,
            (latents, transformer_state, c_ts, p_ts),
        )
        image = vae_decode_p(latents)
    # NOTE(review): removed a stray `breakpoint()` here — this function is
    # wrapped in jax.jit by its caller, so a Python breakpoint would fire at
    # trace time and break compilation.
    return image
147+
148+
62149def pack_latents (
63150 latents : Array ,
64151 batch_size : int ,
@@ -207,6 +294,18 @@ def run(config):
207294 use_safetensors = True ,
208295 dtype = "bfloat16"
209296 )
297+
298+ weights_init_fn = functools .partial (vae .init_weights , rng = rng )
299+ vae_state , vae_state_shardings = setup_initial_state (
300+ model = vae ,
301+ tx = None ,
302+ config = config ,
303+ mesh = mesh ,
304+ weights_init_fn = weights_init_fn ,
305+ model_params = vae_params ,
306+ training = False ,
307+ )
308+
210309 vae_scale_factor = 2 ** (len (vae .config .block_out_channels ) - 1 )
211310
212311 # LOAD TRANSFORMER
@@ -283,7 +382,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
283382 print ("pooled_prompt_embeds.shape: " , pooled_prompt_embeds .shape , pooled_prompt_embeds .dtype )
284383
285384 timesteps = jnp .asarray ([1.0 ] * global_batch_size , dtype = jnp .bfloat16 )
286- guidance = jnp .asarray ([3.5 ] * global_batch_size , dtype = jnp .bfloat16 )
385+ guidance = jnp .asarray ([config . guidance_scale ] * global_batch_size , dtype = jnp .bfloat16 )
287386 validate_inputs (
288387 latents ,
289388 latent_image_ids ,
@@ -321,34 +420,69 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
321420 config = config ,
322421 mesh = mesh ,
323422 weights_init_fn = weights_init_fn ,
324- model_params = transformer_params ,
423+ #model_params=transformer_params,
424+ model_params = None ,
325425 training = False
326426 )
327- # transformer_state = transformer_state.replace(params=transformer_params)
427+ transformer_state = transformer_state .replace (params = transformer_params )
328428 get_memory_allocations ()
329- def run_inference (state , transformer ):
330- img = transformer .apply (
331- {"params" : state .params },
332- img = latents ,
333- img_ids = latent_image_ids ,
334- txt = prompt_embeds ,
335- txt_ids = text_ids ,
336- timesteps = timesteps ,
337- guidance = guidance ,
338- y = pooled_prompt_embeds
339- )
340- return img
429+
430+ states = {}
431+ state_shardings = {}
432+
433+ state_shardings ["transformer" ] = transformer_state_shardings
434+ state_shardings ["vae" ] = vae_state_shardings
435+
436+ states ["transformer" ] = transformer_state
437+ states ["vae" ] = vae_state
341438
342439 p_run_inference = jax .jit (
343440 functools .partial (
344441 run_inference ,
345- transformer = transformer
442+ transformer = transformer ,
443+ vae = vae ,
444+ config = config ,
445+ mesh = mesh ,
446+ latents = latents ,
447+ latent_image_ids = latent_image_ids ,
448+ prompt_embeds = prompt_embeds ,
449+ txt_ids = text_ids ,
450+ vec = pooled_prompt_embeds ,
451+ guidance_vec = guidance ,
346452 ),
347- in_shardings = (transformer_state_shardings ,),
348- out_shardings = None
453+ in_shardings = (state_shardings ,),
454+ out_shardings = None ,
349455 )
350456
457+ img = p_run_inference (states )
458+
459+
460+
461+
462+ # def run_inference(state, transformer):
463+ # img = transformer.apply(
464+ # {"params" : state.params},
465+ # img=latents,
466+ # img_ids=latent_image_ids,
467+ # txt=prompt_embeds,
468+ # txt_ids=text_ids,
469+ # timesteps=timesteps,
470+ # guidance=guidance,
471+ # y=pooled_prompt_embeds
472+ # )
473+ # return img
474+
475+ # p_run_inference = jax.jit(
476+ # functools.partial(
477+ # run_inference,
478+ # transformer=transformer,
479+ # ),
480+ # in_shardings=(transformer_state_shardings,),
481+ # out_shardings=None
482+ # )
483+
351484 img = p_run_inference (transformer_state )
485+ breakpoint ()
352486 print ("img.shape: " , img .shape )
353487
354488
0 commit comments