
Commit 4c68d53

Merge branch 'main' into flux_lora
2 parents 9e07358 + 7f0f5bc

4 files changed: 30 additions & 32 deletions


README.md

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 [![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)
 
 # What's new?
-- **`2025/02/08**: Flux schnell & dev inference.
+- **`2025/02/08`**: Flux schnell & dev inference.
 - **`2024/12/12`**: Load multiple LoRAs for inference.
 - **`2024/10/22`**: LoRA support for Hyper SDXL.
 - **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format.

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 20 additions & 10 deletions
@@ -54,17 +54,27 @@ precision: "DEFAULT"
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash
-flash_block_sizes: {}
-# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
+flash_block_sizes: {
+  "block_q" : 256,
+  "block_kv_compute" : 256,
+  "block_kv" : 256,
+  "block_q_dkv" : 256,
+  "block_kv_dkv" : 256,
+  "block_kv_dkv_compute" : 256,
+  "block_q_dq" : 256,
+  "block_kv_dq" : 256
+}
+
+# Use the following flash_block_sizes on v6e (Trillium).
 # flash_block_sizes: {
-#   "block_q" : 1536,
-#   "block_kv_compute" : 1536,
-#   "block_kv" : 1536,
-#   "block_q_dkv" : 1536,
-#   "block_kv_dkv" : 1536,
-#   "block_kv_dkv_compute" : 1536,
-#   "block_q_dq" : 1536,
-#   "block_kv_dq" : 1536
+#   "block_q" : 2176,
+#   "block_kv_compute" : 2176,
+#   "block_kv" : 2176,
+#   "block_q_dkv" : 2176,
+#   "block_kv_dkv" : 2176,
+#   "block_kv_dkv_compute" : 2176,
+#   "block_q_dq" : 2176,
+#   "block_kv_dq" : 2176
 # }
 # GroupNorm groups
 norm_num_groups: 32
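For context on what these keys configure: they mirror, one-to-one, the fields of the Pallas splash-attention `BlockSizes` dataclass used for TPU flash attention, so a consumer can splat the dict straight into the constructor. A minimal sketch of that wiring, assuming MaxText-style usage (note the `jax.experimental` import path can move between JAX releases):

```python
# Sketch only: maps a flash_block_sizes dict like the one above onto
# Pallas splash-attention block sizes. The import path is experimental.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

flash_block_sizes = {
    "block_q": 256,
    "block_kv_compute": 256,
    "block_kv": 256,
    "block_q_dkv": 256,
    "block_kv_dkv": 256,
    "block_kv_dkv_compute": 256,
    "block_q_dq": 256,
    "block_kv_dq": 256,
}

# The dataclass fields match the YAML keys, so the dict splats directly.
# Larger blocks (e.g. 2176 on v6e) trade more vmem for fewer grid steps.
block_sizes = splash_attention_kernel.BlockSizes(**flash_block_sizes)
```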

src/maxdiffusion/generate_flux.py

Lines changed: 3 additions & 15 deletions
@@ -77,7 +77,7 @@ def unpack(x: Array, height: int, width: int) -> Array:
 
 
 def vae_decode(latents, vae, state, config):
-  img = unpack(x=latents.astype(jnp.float32), height=config.resolution, width=config.resolution)
+  img = unpack(x=latents, height=config.resolution, width=config.resolution)
   img = img / vae.config.scaling_factor + vae.config.shift_factor
   img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample
   return img
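The behavioral change above is dtype only: `unpack` now runs in the latents' native dtype instead of upcasting to float32. The line after it is the standard Flux latent denormalization, the inverse of the encoder-side `(x - shift) * scale`. A toy illustration with placeholder constants (the real values come from `vae.config`):

```python
import jax.numpy as jnp

# Placeholder constants for illustration; read them from vae.config.
scaling_factor, shift_factor = 0.3611, 0.1159

def denormalize(latents: jnp.ndarray) -> jnp.ndarray:
  # Inverse of the encoder-side normalization (x - shift) * scale.
  return latents / scaling_factor + shift_factor
```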
@@ -135,19 +135,7 @@ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: flo
 
 
 def run_inference(
-    states,
-    transformer,
-    vae,
-    config,
-    mesh,
-    latents,
-    latent_image_ids,
-    prompt_embeds,
-    txt_ids,
-    vec,
-    guidance_vec,
-    c_ts,
-    p_ts
+    states, transformer, vae, config, mesh, latents, latent_image_ids, prompt_embeds, txt_ids, vec, guidance_vec, c_ts, p_ts
 ):
 
   transformer_state = states["transformer"]
@@ -468,7 +456,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
       vec=pooled_prompt_embeds,
       guidance_vec=guidance,
       c_ts=c_ts,
-      p_ts=p_ts
+      p_ts=p_ts,
     ),
     in_shardings=(state_shardings,),
     out_shardings=None,
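The call being edited is a sharded `jax.jit` of `run_inference`; only a trailing comma is added. For readers unfamiliar with the pattern, a minimal self-contained sketch of `in_shardings`/`out_shardings` (the mesh axis and step function here are hypothetical, not maxdiffusion's):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D mesh over all available devices.
mesh = Mesh(np.array(jax.devices()), ("data",))

def step(x):
  return x * 2

jitted = jax.jit(
    step,
    # One entry per positional argument, like in_shardings=(state_shardings,).
    in_shardings=(NamedSharding(mesh, P("data")),),
    out_shardings=None,  # let the compiler pick the output sharding
)
```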

src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py

Lines changed: 6 additions & 6 deletions
@@ -144,10 +144,10 @@ def __call__(self, hidden_states, temb, image_rotary_emb=None):
     hidden_states = self.linear2(attn_mlp)
     hidden_states = gate * hidden_states
     hidden_states = residual + hidden_states
-    if hidden_states.dtype == jnp.float16:
+    if hidden_states.dtype == jnp.float16 or hidden_states.dtype == jnp.bfloat16:
       hidden_states = jnp.clip(hidden_states, -65504, 65504)
 
-    return hidden_states
+    return hidden_states, temb, image_rotary_emb
 
 
 class FluxTransformerBlock(nn.Module):
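The clamp bound is exactly float16's largest finite value; bfloat16 has roughly the same range as float32, so extending the clamp to bfloat16 actively narrows those activations to the float16 range rather than merely guarding against overflow. A quick check:

```python
import jax.numpy as jnp

print(jnp.finfo(jnp.float16).max)   # 65504.0
print(jnp.finfo(jnp.bfloat16).max)  # ~3.39e38
```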
@@ -294,9 +294,9 @@ def __call__(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=
 
     context_ff_output = self.txt_mlp(norm_encoder_hidden_states)
     encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
-    if encoder_hidden_states.dtype == jnp.float16:
+    if encoder_hidden_states.dtype == jnp.float16 or encoder_hidden_states.dtype == jnp.bfloat16:
       encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
-    return hidden_states, encoder_hidden_states
+    return hidden_states, encoder_hidden_states, temb, image_rotary_emb
 
 
 @flax_register_to_config
@@ -504,7 +504,7 @@ def __call__(
       image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, ("activation_batch", "activation_embed"))
 
     for double_block in self.double_blocks:
-      hidden_states, encoder_hidden_states = double_block(
+      hidden_states, encoder_hidden_states, temb, image_rotary_emb = double_block(
           hidden_states=hidden_states,
           encoder_hidden_states=encoder_hidden_states,
           temb=temb,
@@ -513,7 +513,7 @@
     hidden_states = jnp.concatenate([encoder_hidden_states, hidden_states], axis=1)
     hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed"))
     for single_block in self.single_blocks:
-      hidden_states = single_block(
+      hidden_states, temb, image_rotary_emb = single_block(
           hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
       )
     hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
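A plausible reading of the widened return signatures (not stated in the commit message): every block now returns exactly what it consumes, which is the structure `flax.linen.scan` demands of a carried value when scanning a layer stack. A minimal sketch of that pattern with a toy block; the names here are hypothetical, not maxdiffusion's:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

class ToyBlock(nn.Module):
  @nn.compact
  def __call__(self, carry, _):
    hidden_states, temb, rope = carry
    hidden_states = hidden_states + nn.Dense(hidden_states.shape[-1])(temb)
    # Carry out the same structure that came in, as nn.scan requires.
    return (hidden_states, temb, rope), None

ScannedBlocks = nn.scan(
    ToyBlock,
    variable_axes={"params": 0},   # stack per-layer params along axis 0
    split_rngs={"params": True},   # separate init RNG per layer
    length=4,
)

h = jnp.ones((1, 16, 64))
temb = jnp.ones((1, 64))
rope = jnp.ones((1, 16, 64))
params = ScannedBlocks().init(jax.random.PRNGKey(0), (h, temb, rope), None)
```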
