Skip to content

Commit 4b64f5d

Browse files
wip - generate fn
1 parent 956341e commit 4b64f5d

2 files changed

Lines changed: 153 additions & 19 deletions

File tree

src/maxdiffusion/configs/base_flux.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ prompt: "A magical castle in the middle of a forest, artistic drawing"
200200
prompt_2: "A magical castle in the middle of a forest, artistic drawing"
201201
negative_prompt: "purple, red"
202202
do_classifier_free_guidance: True
203-
guidance_scale: 9.0
203+
guidance_scale: 3.5
204204
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
205205
guidance_rescale: 0.0
206206
num_inference_steps: 20

src/maxdiffusion/generate_flux.py

Lines changed: 152 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,14 @@
1717
from typing import Any, Callable, Dict, List, Optional, Union, Sequence
1818
from absl import app
1919
import functools
20+
import math
2021
import numpy as np
2122
import jax
2223
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
2324
import jax.numpy as jnp
2425
from chex import Array
26+
from einops import rearrange
27+
from flax.linen import partitioning as nn_partitioning
2528
from transformers import (
2629
CLIPTokenizer,
2730
FlaxCLIPTextModel,
@@ -42,6 +45,51 @@
4245
setup_initial_state
4346
)
4447

48+
def unpack(x: Array, height: int, width: int) -> Array:
  """Unpack patchified flux latents back into image layout.

  Inverse of the 2x2 patch packing: takes latents of shape
  (b, h*w, c*2*2) and returns (b, c, h*2, w*2), where
  h = ceil(height/16) and w = ceil(width/16).
  """
  h = math.ceil(height / 16)
  w = math.ceil(width / 16)
  b = x.shape[0]
  # b (h w) (c ph pw) -> b h w c ph pw   (ph = pw = 2, c inferred)
  x = x.reshape(b, h, w, -1, 2, 2)
  # b h w c ph pw -> b c h ph w pw
  x = x.transpose(0, 3, 1, 4, 2, 5)
  # b c h ph w pw -> b c (h ph) (w pw)
  return x.reshape(b, x.shape[1], h * 2, w * 2)
57+
58+
def vae_decode(latents, vae, state, config):
  """Decode packed latents into a single image using the VAE.

  Args:
    latents: packed latent array of shape (b, h*w, c*4); unpacking assumes a
      square image of side config.resolution — TODO confirm against caller.
    vae: Flax VAE module exposing a `decode` method.
    state: state object whose `.params` hold the VAE weights.
    config: config providing `resolution`.

  Returns:
    The first decoded image sample from the VAE output.
  """
  img = unpack(x=latents, height=config.resolution, width=config.resolution)
  # Fix: removed a stray `breakpoint()` left from debugging — it would halt
  # inference (and errors if this path is ever traced under jit).
  img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample[0]
  return img
63+
64+
def loop_body(
    step,
    args,
    transformer,
    latent_image_ids,
    prompt_embeds,
    txt_ids,
    vec,
    guidance_vec,
):
  """One Euler step of the rectified-flow sampler, shaped for jax.lax.fori_loop.

  `args` is the loop carry (latents, state, c_ts, p_ts); the transformer
  predicts a velocity which advances the latents from t_curr to t_prev.
  Returns the carry with updated latents; state and schedules pass through.
  """
  latents, state, c_ts, p_ts = args
  orig_dtype = latents.dtype
  t_curr, t_prev = c_ts[step], p_ts[step]
  # Broadcast the scalar timestep across the batch.
  t_vec = jnp.full((latents.shape[0],), t_curr, dtype=orig_dtype)
  pred = transformer.apply(
      {"params": state.params},
      img=latents,
      img_ids=latent_image_ids,
      txt=prompt_embeds,
      txt_ids=txt_ids,
      timesteps=t_vec,
      guidance=guidance_vec,
      y=vec,
  )
  # Euler update, cast back so the carry dtype stays stable across iterations.
  latents = jnp.array(latents + (t_prev - t_curr) * pred, dtype=orig_dtype)
  return latents, state, c_ts, p_ts
92+
4593
def prepare_latent_image_ids(height, width):
4694
latent_image_ids = jnp.zeros((height, width, 3))
4795
latent_image_ids = latent_image_ids.at[..., 1].set(
@@ -59,6 +107,45 @@ def prepare_latent_image_ids(height, width):
59107

60108
return latent_image_ids.astype(jnp.bfloat16)
61109

110+
def run_inference(
    states,
    transformer,
    vae,
    config,
    mesh,
    latents,
    latent_image_ids,
    prompt_embeds,
    txt_ids,
    vec,
    guidance_vec,
):
  """Run the full denoising loop and decode the result with the VAE.

  Args:
    states: dict with "transformer" and "vae" states holding model params.
    transformer: flux transformer module.
    vae: VAE module used to decode the final latents.
    config: config providing num_inference_steps, logical_axis_rules, resolution.
    mesh: device mesh entered for sharded execution.
    latents: initial noise latents.
    latent_image_ids / prompt_embeds / txt_ids / vec / guidance_vec:
      conditioning inputs forwarded to each transformer step.

  Returns:
    The decoded image from the final latents.
  """
  # Rectified-flow schedule: num_inference_steps+1 points from t=1 down to t=0;
  # pair each current t with the next one for the Euler update.
  timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1)
  c_ts = timesteps[:-1]
  p_ts = timesteps[1:]

  transformer_state = states["transformer"]
  vae_state = states["vae"]

  loop_body_p = functools.partial(
      loop_body,
      transformer=transformer,
      latent_image_ids=latent_image_ids,
      prompt_embeds=prompt_embeds,
      txt_ids=txt_ids,
      vec=vec,
      guidance_vec=guidance_vec,
  )

  vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config)

  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
    latents, _, _, _ = jax.lax.fori_loop(
        0, config.num_inference_steps, loop_body_p, (latents, transformer_state, c_ts, p_ts)
    )
    image = vae_decode_p(latents)
  # Fix: removed a stray `breakpoint()` — this function is wrapped in jax.jit
  # by the caller, and a breakpoint inside traced code fails (and would halt
  # non-interactive runs regardless).
  return image
147+
148+
62149
def pack_latents(
63150
latents: Array,
64151
batch_size: int,
@@ -207,6 +294,18 @@ def run(config):
207294
use_safetensors=True,
208295
dtype="bfloat16"
209296
)
297+
298+
weights_init_fn = functools.partial(vae.init_weights, rng=rng)
299+
vae_state, vae_state_shardings = setup_initial_state(
300+
model=vae,
301+
tx=None,
302+
config=config,
303+
mesh=mesh,
304+
weights_init_fn=weights_init_fn,
305+
model_params=vae_params,
306+
training=False,
307+
)
308+
210309
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
211310

212311
# LOAD TRANSFORMER
@@ -283,7 +382,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
283382
print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape, pooled_prompt_embeds.dtype)
284383

285384
timesteps = jnp.asarray([1.0] * global_batch_size, dtype=jnp.bfloat16)
286-
guidance = jnp.asarray([3.5] * global_batch_size, dtype=jnp.bfloat16)
385+
guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16)
287386
validate_inputs(
288387
latents,
289388
latent_image_ids,
@@ -321,34 +420,69 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
321420
config=config,
322421
mesh=mesh,
323422
weights_init_fn=weights_init_fn,
324-
model_params=transformer_params,
423+
#model_params=transformer_params,
424+
model_params=None,
325425
training=False
326426
)
327-
#transformer_state = transformer_state.replace(params=transformer_params)
427+
transformer_state = transformer_state.replace(params=transformer_params)
328428
get_memory_allocations()
329-
def run_inference(state, transformer):
330-
img = transformer.apply(
331-
{"params" : state.params},
332-
img=latents,
333-
img_ids=latent_image_ids,
334-
txt=prompt_embeds,
335-
txt_ids=text_ids,
336-
timesteps=timesteps,
337-
guidance=guidance,
338-
y=pooled_prompt_embeds
339-
)
340-
return img
429+
430+
states = {}
431+
state_shardings = {}
432+
433+
state_shardings["transformer"] = transformer_state_shardings
434+
state_shardings["vae"] = vae_state_shardings
435+
436+
states["transformer"] = transformer_state
437+
states["vae"] = vae_state
341438

342439
p_run_inference = jax.jit(
343440
functools.partial(
344441
run_inference,
345-
transformer=transformer
442+
transformer=transformer,
443+
vae=vae,
444+
config=config,
445+
mesh=mesh,
446+
latents=latents,
447+
latent_image_ids=latent_image_ids,
448+
prompt_embeds=prompt_embeds,
449+
txt_ids=text_ids,
450+
vec=pooled_prompt_embeds,
451+
guidance_vec=guidance,
346452
),
347-
in_shardings=(transformer_state_shardings,),
348-
out_shardings=None
453+
in_shardings=(state_shardings,),
454+
out_shardings=None,
349455
)
350456

457+
img = p_run_inference(states)
458+
459+
460+
461+
462+
# def run_inference(state, transformer):
463+
# img = transformer.apply(
464+
# {"params" : state.params},
465+
# img=latents,
466+
# img_ids=latent_image_ids,
467+
# txt=prompt_embeds,
468+
# txt_ids=text_ids,
469+
# timesteps=timesteps,
470+
# guidance=guidance,
471+
# y=pooled_prompt_embeds
472+
# )
473+
# return img
474+
475+
# p_run_inference = jax.jit(
476+
# functools.partial(
477+
# run_inference,
478+
# transformer=transformer,
479+
# ),
480+
# in_shardings=(transformer_state_shardings,),
481+
# out_shardings=None
482+
# )
483+
351484
img = p_run_inference(transformer_state)
485+
breakpoint()
352486
print("img.shape: ", img.shape)
353487

354488

0 commit comments

Comments
 (0)