Skip to content

Commit cbc7723

Browse files
jfacevedo-google and ksikiric
authored and committed
working loop, bad generation
1 parent dfe1089 commit cbc7723

4 files changed

Lines changed: 71 additions & 64 deletions

File tree

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,10 @@ clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
2828
t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'
2929

3030
# Flux params
31-
flux_name: "flux-dev"
3231
max_sequence_length: 512
33-
time_shift: True
32+
time_shift: False
3433
base_shift: 0.5
3534
max_shift: 1.15
36-
# offloads t5 encoder after text encoding to save memory.
37-
offload_encoders: True
3835

3936

4037
unet_checkpoint: ''
@@ -52,22 +49,10 @@ activations_dtype: 'bfloat16'
5249
precision: "DEFAULT"
5350

5451
# Set true to load weights from pytorch
55-
from_pt: True
52+
from_pt: False
5653
split_head_dim: True
5754
attention: 'flash' # Supported attention: dot_product, flash
58-
5955
flash_block_sizes: {}
60-
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
61-
# flash_block_sizes: {
62-
# "block_q" : 1536,
63-
# "block_kv_compute" : 1536,
64-
# "block_kv" : 1536,
65-
# "block_q_dkv" : 1536,
66-
# "block_kv_dkv" : 1536,
67-
# "block_kv_dkv_compute" : 1536,
68-
# "block_q_dq" : 1536,
69-
# "block_kv_dq" : 1536
70-
# }
7156
# GroupNorm groups
7257
norm_num_groups: 32
7358

@@ -133,7 +118,6 @@ logical_axis_rules: [
133118
['activation_batch', ['data','fsdp']],
134119
['activation_heads', 'tensor'],
135120
['activation_kv', 'tensor'],
136-
['mlp','tensor'],
137121
['embed','fsdp'],
138122
['heads', 'tensor'],
139123
['conv_batch', ['data','fsdp']],
@@ -149,8 +133,8 @@ data_sharding: [['data', 'fsdp', 'tensor']]
149133
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
150134
dcn_fsdp_parallelism: -1
151135
dcn_tensor_parallelism: 1
152-
ici_data_parallelism: -1
153-
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
136+
ici_data_parallelism: 1
137+
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
154138
ici_tensor_parallelism: 1
155139

156140
# Dataset
@@ -226,7 +210,7 @@ do_classifier_free_guidance: True
226210
guidance_scale: 3.5
227211
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
228212
guidance_rescale: 0.0
229-
num_inference_steps: 50
213+
num_inference_steps: 20
230214

231215
# SDXL Lightning parameters
232216
lightning_from_pt: True

src/maxdiffusion/configs/base_flux.yml renamed to src/maxdiffusion/configs/base_fux_schnell.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev'
2727
clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
2828
t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'
2929

30+
# Flux params
31+
max_sequence_length: 256
32+
time_shift: False
33+
base_shift: 0.5
34+
max_shift: 1.15
35+
3036
unet_checkpoint: ''
3137
revision: 'refs/pr/95'
3238
# This will convert the weights to this dtype.

src/maxdiffusion/generate_flux.py

Lines changed: 55 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
from absl import app
1919
import functools
2020
import math
21+
import time
2122
import numpy as np
23+
from PIL import Image
2224
import jax
2325
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
2426
import jax.numpy as jnp
@@ -33,9 +35,8 @@
3335
FlaxT5EncoderModel
3436
)
3537

36-
from maxdiffusion import FlaxAutoencoderKL
38+
from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging
3739
from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
38-
from maxdiffusion import pyconfig
3940
from max_utils import (
4041
device_put_replicated,
4142
get_memory_allocations,
@@ -57,8 +58,8 @@ def unpack(x: Array, height: int, width: int) -> Array:
5758

5859
def vae_decode(latents, vae, state, config):
5960
img = unpack(x=latents, height=config.resolution, width=config.resolution)
60-
img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample[0]
61-
breakpoint()
61+
img = img / vae.config.scaling_factor + vae.config.shift_factor
62+
img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample
6263
return img
6364

6465
def loop_body(
@@ -107,6 +108,19 @@ def prepare_latent_image_ids(height, width):
107108

108109
return latent_image_ids.astype(jnp.bfloat16)
109110

111+
def time_shift(mu: float, sigma: float, t: Array):
112+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
113+
114+
def get_lin_function(
115+
x1: float = 256,
116+
y1: float = 0.5,
117+
x2: float = 4096,
118+
y2: float = 1.15
119+
) -> Callable[[float], float]:
120+
m = (y2 - y1) / (x2 - x1)
121+
b = y1 - m * x1
122+
return lambda x: m * x + b
123+
110124
def run_inference(
111125
states,
112126
transformer,
@@ -120,10 +134,18 @@ def run_inference(
120134
vec,
121135
guidance_vec,
122136
):
137+
123138
timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1)
139+
# shifting the schedule to favor high timesteps for higher signal images
140+
if config.time_shift:
141+
# estimate mu based on linear estimation between two points
142+
lin_function = get_lin_function(y1=config.base_shift, y2=config.max_shift)
143+
mu = lin_function(latents.shape[1])
144+
timesteps = time_shift(mu, 1.0, timesteps).tolist()
124145
c_ts = timesteps[:-1]
125146
p_ts = timesteps[1:]
126147

148+
127149
transformer_state = states["transformer"]
128150
vae_state = states["vae"]
129151

@@ -142,7 +164,6 @@ def run_inference(
142164
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
143165
latents, _, _, _ = jax.lax.fori_loop(0, config.num_inference_steps, loop_body_p, (latents, transformer_state, c_ts, p_ts))
144166
image = vae_decode_p(latents)
145-
breakpoint()
146167
return image
147168

148169

@@ -383,6 +404,10 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
383404

384405
timesteps = jnp.asarray([1.0] * global_batch_size, dtype=jnp.bfloat16)
385406
guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16)
407+
408+
# TODO - remove this later and figure out why t5x is returning wrong shape
409+
prompt_embeds = jnp.ones((global_batch_size, 512, 4096))
410+
386411
validate_inputs(
387412
latents,
388413
latent_image_ids,
@@ -393,8 +418,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
393418
pooled_prompt_embeds
394419
)
395420

396-
# TODO - remove this later and figure out why t5x is returning wrong shape
397-
prompt_embeds = jnp.ones((global_batch_size, 512, 4096))
421+
398422

399423
# move inputs to device and shard
400424
data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
@@ -420,11 +444,11 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
420444
config=config,
421445
mesh=mesh,
422446
weights_init_fn=weights_init_fn,
423-
#model_params=transformer_params,
424-
model_params=None,
447+
model_params=transformer_params,
448+
#model_params=None,
425449
training=False
426450
)
427-
transformer_state = transformer_state.replace(params=transformer_params)
451+
#transformer_state = transformer_state.replace(params=transformer_params)
428452
get_memory_allocations()
429453

430454
states = {}
@@ -453,37 +477,27 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
453477
in_shardings=(state_shardings,),
454478
out_shardings=None,
455479
)
456-
457-
img = p_run_inference(states)
458-
459-
460-
461-
462-
# def run_inference(state, transformer):
463-
# img = transformer.apply(
464-
# {"params" : state.params},
465-
# img=latents,
466-
# img_ids=latent_image_ids,
467-
# txt=prompt_embeds,
468-
# txt_ids=text_ids,
469-
# timesteps=timesteps,
470-
# guidance=guidance,
471-
# y=pooled_prompt_embeds
472-
# )
473-
# return img
474-
475-
# p_run_inference = jax.jit(
476-
# functools.partial(
477-
# run_inference,
478-
# transformer=transformer,
479-
# ),
480-
# in_shardings=(transformer_state_shardings,),
481-
# out_shardings=None
482-
# )
483-
484-
img = p_run_inference(transformer_state)
485-
breakpoint()
486-
print("img.shape: ", img.shape)
480+
t0 = time.perf_counter()
481+
p_run_inference(states).block_until_ready()
482+
t1 = time.perf_counter()
483+
max_logging.log(f"Compile time: {t1 - t0:.1f}s.")
484+
485+
t0 = time.perf_counter()
486+
imgs = p_run_inference(states).block_until_ready()
487+
t1 = time.perf_counter()
488+
max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
489+
490+
t0 = time.perf_counter()
491+
imgs = p_run_inference(states).block_until_ready()
492+
imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
493+
t1 = time.perf_counter()
494+
max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
495+
imgs = np.array(imgs)
496+
imgs = (imgs * 0.5 + 0.5).clip(0, 1)
497+
imgs = np.transpose(imgs, (0, 2, 3, 1))
498+
imgs = np.uint8(imgs * 255)
499+
for i, image in enumerate(imgs):
500+
Image.fromarray(image).save(f"flux_{i}.png")
487501

488502

489503
def main(argv: Sequence[str]) -> None:

src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,17 @@ def __call__(
154154
raise ValueError(
155155
"Didn't get guidance strength for guidance distrilled model."
156156
)
157-
158-
vec = vec + MLPEmbedder(
157+
guidance_in = MLPEmbedder(
159158
hidden_dim=inner_dim,
160159
dtype=self.dtype,
161160
weights_dtype=self.weights_dtype,
162161
precision=self.precision,
163162
name="guidance_in"
164163
)(timestep_embedding(guidance, 256))
164+
else:
165+
guidance_in = Identity(timestep_embedding(guidance, 256))
166+
167+
vec = vec + guidance_in
165168

166169
vec = vec + MLPEmbedder(
167170
hidden_dim=inner_dim,

0 commit comments

Comments
 (0)