Skip to content

Commit 1c8ed7b

Browse files
jfacevedo-google and ksikiric
authored and committed
support both dev and schnell loading. Images still incorrect.
1 parent ac14a4b commit 1c8ed7b

7 files changed

Lines changed: 101 additions & 299 deletions

File tree

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
2828
t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'
2929

3030
# Flux params
31+
flux_name: "flux-dev"
3132
max_sequence_length: 512
3233
time_shift: False
3334
base_shift: 0.5

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -51,31 +51,10 @@ activations_dtype: 'bfloat16'
5151
precision: "DEFAULT"
5252

5353
# Set true to load weights from pytorch
54-
from_pt: True
54+
from_pt: False
5555
split_head_dim: True
5656
attention: 'flash' # Supported attention: dot_product, flash
57-
flash_block_sizes: {
58-
"block_q" : 256,
59-
"block_kv_compute" : 256,
60-
"block_kv" : 256,
61-
"block_q_dkv" : 256,
62-
"block_kv_dkv" : 256,
63-
"block_kv_dkv_compute" : 256,
64-
"block_q_dq" : 256,
65-
"block_kv_dq" : 256
66-
}
67-
68-
# Use the following flash_block_sizes on v6e (Trillium).
69-
# flash_block_sizes: {
70-
# "block_q" : 2176,
71-
# "block_kv_compute" : 2176,
72-
# "block_kv" : 2176,
73-
# "block_q_dkv" : 2176,
74-
# "block_kv_dkv" : 2176,
75-
# "block_kv_dkv_compute" : 2176,
76-
# "block_q_dq" : 2176,
77-
# "block_kv_dq" : 2176
78-
# }
57+
flash_block_sizes: {}
7958
# GroupNorm groups
8059
norm_num_groups: 32
8160

@@ -141,7 +120,6 @@ logical_axis_rules: [
141120
['activation_batch', ['data','fsdp']],
142121
['activation_heads', 'tensor'],
143122
['activation_kv', 'tensor'],
144-
['mlp','tensor'],
145123
['embed','fsdp'],
146124
['heads', 'tensor'],
147125
['conv_batch', ['data','fsdp']],
@@ -154,11 +132,11 @@ data_sharding: [['data', 'fsdp', 'tensor']]
154132
# value to auto-shard based on available slices and devices.
155133
# By default, product of the DCN axes should equal number of slices
156134
# and product of the ICI axes should equal number of devices per slice.
157-
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
158-
dcn_fsdp_parallelism: 1
135+
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
136+
dcn_fsdp_parallelism: -1
159137
dcn_tensor_parallelism: 1
160-
ici_data_parallelism: -1
161-
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
138+
ici_data_parallelism: 1
139+
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
162140
ici_tensor_parallelism: 1
163141

164142
# Dataset

src/maxdiffusion/configs/base_fux_schnell.yml

Lines changed: 0 additions & 247 deletions
This file was deleted.

src/maxdiffusion/generate_flux.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def unpack(x: Array, height: int, width: int) -> Array:
5555
ph=2,
5656
pw=2,
5757
)
58-
58+
from einops import rearrange
5959
def vae_decode(latents, vae, state, config):
6060
img = unpack(x=latents, height=config.resolution, width=config.resolution)
6161
img = img / vae.config.scaling_factor + vae.config.shift_factor
@@ -87,6 +87,8 @@ def loop_body(
8787
guidance=guidance_vec,
8888
y=vec
8989
)
90+
jax.debug.print("*****pred max: {x}", x=np.max(pred))
91+
jax.debug.print("*****pred min: {x}", x=np.min(pred))
9092
latents = latents + (t_prev - t_curr) * pred
9193
latents = jnp.array(latents, dtype=latents_dtype)
9294
return latents, state, c_ts, p_ts
@@ -144,6 +146,8 @@ def run_inference(
144146
timesteps = time_shift(mu, 1.0, timesteps).tolist()
145147
c_ts = timesteps[:-1]
146148
p_ts = timesteps[1:]
149+
# jax.debug.print("c_ts: {x}", x=c_ts)
150+
# jax.debug.print("p_ts: {x}", x=p_ts)
147151

148152

149153
transformer_state = states["transformer"]
@@ -162,7 +166,7 @@ def run_inference(
162166
vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config)
163167

164168
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
165-
latents, _, _, _ = jax.lax.fori_loop(0, config.num_inference_steps, loop_body_p, (latents, transformer_state, c_ts, p_ts))
169+
latents, _, _, _ = jax.lax.fori_loop(0, len(timesteps) - 1, loop_body_p, (latents, transformer_state, c_ts, p_ts))
166170
image = vae_decode_p(latents)
167171
return image
168172

@@ -293,7 +297,8 @@ def encode_prompt(
293297
prompt=prompt_2,
294298
num_images_per_prompt=num_images_per_prompt,
295299
tokenizer=t5_tokenizer,
296-
text_encoder=t5_text_encoder
300+
text_encoder=t5_text_encoder,
301+
max_sequence_length=max_sequence_length
297302
)
298303

299304
text_ids = jnp.zeros((prompt_embeds.shape[0], prompt_embeds.shape[1], 3)).astype(jnp.bfloat16)
@@ -356,7 +361,7 @@ def run(config):
356361
rng=rng
357362
)
358363

359-
# LOAD TEXT ENCODERS - t5 on cpu
364+
# LOAD TEXT ENCODERS
360365
clip_text_encoder = FlaxCLIPTextModel.from_pretrained(
361366
config.pretrained_model_name_or_path,
362367
subfolder="text_encoder",
@@ -389,7 +394,8 @@ def run(config):
389394
clip_text_encoder=clip_text_encoder,
390395
t5_tokenizer=t5_tokenizer,
391396
t5_text_encoder=t5_encoder,
392-
num_images_per_prompt=global_batch_size
397+
num_images_per_prompt=global_batch_size,
398+
max_sequence_length=config.max_sequence_length
393399
)
394400

395401
def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds):
@@ -430,12 +436,12 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
430436

431437
get_memory_allocations()
432438
# evaluate shapes
433-
transformer_eval_params = transformer.init_weights(rngs=rng, max_sequence_length=512, eval_only=True)
439+
transformer_eval_params = transformer.init_weights(rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=True)
434440

435441
# loads pretrained weights
436-
transformer_params = load_flow_model("flux-dev", transformer_eval_params, "cpu")
442+
transformer_params = load_flow_model(config.flux_name, transformer_eval_params, "cpu")
437443
# create transformer state
438-
weights_init_fn = functools.partial(transformer.init_weights, rngs=rng, max_sequence_length=512, eval_only=False)
444+
weights_init_fn = functools.partial(transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False)
439445
transformer_state, transformer_state_shardings = setup_initial_state(
440446
model=transformer,
441447
tx=None,

0 commit comments

Comments (0)