
Commit 4a12b39

flux schnell working
1 parent d16c020 commit 4a12b39

9 files changed

Lines changed: 1068 additions & 303 deletions


src/maxdiffusion/common_types.py

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@
 
 BATCH = "activation_batch"
 LENGTH = "activation_length"
+EMBED = "activation_embed"
 HEAD = "activation_heads"
 D_KV = "activation_kv"
 KEEP_1 = "activation_keep_1"
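The new EMBED axis name is consumed by the attention-output sharding constraint added in attention_flax.py below. As a minimal sketch (the helper name is hypothetical; the actual sharding comes from the config's logical-axis rules), this is how such a name is attached to an activation:

import flax.linen as nn
from maxdiffusion import common_types

def constrain_attn_output(attn_output):
  # Tag a (batch, sequence, embed) activation with logical axis names so the
  # configured logical-to-mesh rules decide how it is sharded.
  return nn.with_logical_constraint(
      attn_output, (common_types.BATCH, common_types.LENGTH, common_types.EMBED)
  )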

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 11 additions & 2 deletions
@@ -52,10 +52,19 @@ activations_dtype: 'bfloat16'
 precision: "DEFAULT"
 
 # Set true to load weights from pytorch
-from_pt: False
+from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash
-flash_block_sizes: {}
+flash_block_sizes: {
+  "block_q" : 128,
+  "block_kv" : 128,
+  "block_kv_compute" : 128,
+  "block_q_dkv" : 128,
+  "block_kv_dkv" : 128,
+  "block_kv_dkv_compute" : 128,
+  "block_q_dq" : 128,
+  "block_kv_dq" : 128
+}
 # GroupNorm groups
 norm_num_groups: 32
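The keys above mirror the fields of the Pallas splash-attention BlockSizes dataclass that the flash kernel consumes. A minimal sketch of turning a dict like this into kernel block sizes, assuming the standard jax.experimental.pallas splash-attention API (not the exact maxdiffusion loader code):

from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

# Illustrative only: every config key above maps 1:1 onto a BlockSizes field.
flash_block_sizes = {
    "block_q": 128, "block_kv": 128, "block_kv_compute": 128,
    "block_q_dkv": 128, "block_kv_dkv": 128, "block_kv_dkv_compute": 128,
    "block_q_dq": 128, "block_kv_dq": 128,
}
block_sizes = splash_attention_kernel.BlockSizes(**flash_block_sizes)

With flash_block_sizes: {} the attention op appears to fall back to the min(512, seq_len) defaults visible in wrap_flash_attention further down.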

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 11 additions & 2 deletions
@@ -51,10 +51,19 @@ activations_dtype: 'bfloat16'
 precision: "DEFAULT"
 
 # Set true to load weights from pytorch
-from_pt: False
+from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash
-flash_block_sizes: {}
+flash_block_sizes: {
+  "block_q" : 128,
+  "block_kv" : 128,
+  "block_kv_compute" : 128,
+  "block_q_dkv" : 128,
+  "block_kv_dkv" : 128,
+  "block_kv_dkv_compute" : 128,
+  "block_q_dq" : 128,
+  "block_kv_dq" : 128
+}
 # GroupNorm groups
 norm_num_groups: 32

src/maxdiffusion/generate_flux.py

Lines changed: 5 additions & 7 deletions
@@ -79,16 +79,14 @@ def loop_body(
   t_vec = jnp.full((latents.shape[0], ), t_curr, dtype=latents.dtype)
   pred = transformer.apply(
       {"params" : state.params},
-      img=latents,
+      hidden_states=latents,
       img_ids=latent_image_ids,
-      txt=prompt_embeds,
+      encoder_hidden_states=prompt_embeds,
       txt_ids=txt_ids,
-      timesteps=t_vec,
+      timestep=t_vec,
       guidance=guidance_vec,
-      y=vec
-  )
-  jax.debug.print("*****pred max: {x}", x=np.max(pred))
-  jax.debug.print("*****pred min: {x}", x=np.min(pred))
+      pooled_projections=vec
+  ).sample
   latents = latents + (t_prev - t_curr) * pred
   latents = jnp.array(latents, dtype=latents_dtype)
   return latents, state, c_ts, p_ts
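The latents update in the hunk above is a single Euler step of the flow-matching ODE that Flux samples with. A minimal standalone sketch (the function name is hypothetical and not part of this commit):

import jax.numpy as jnp

def euler_step(latents, pred, t_curr, t_prev):
  # `pred` is the velocity predicted by the transformer at t_curr; stepping by
  # (t_prev - t_curr) moves the latents one scheduler step toward t = 0.
  return jnp.asarray(latents + (t_prev - t_curr) * pred, dtype=latents.dtype)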

src/maxdiffusion/models/attention_flax.py

Lines changed: 174 additions & 4 deletions
@@ -21,7 +21,7 @@
 from jax.experimental import shard_map
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
-
+from einops import rearrange
 from .. import common_types, max_logging
 
 Array = common_types.Array
@@ -35,6 +35,7 @@
 LENGTH = common_types.LENGTH
 HEAD = common_types.HEAD
 D_KV = common_types.D_KV
+EMBED = common_types.EMBED
 
 
 class AttentionOp(nn.Module):
@@ -63,7 +64,6 @@ def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None
   def apply_attention(self, query: Array, key: Array, value: Array):
     """Routes to different attention kernels."""
     self.check_attention_inputs(query, key, value)
-
     can_use_flash_attention = (
         query.shape[1] >= self.flash_min_seq_length
         and key.shape[1] >= self.flash_min_seq_length
@@ -111,8 +111,7 @@ def wrap_flash_attention(query, key, value):
           block_q_dq=min(512, query.shape[2]),
           block_kv_dq=min(512, query.shape[2]),
       )
-
-      masks = [splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) for i in range(query.shape[1])]
+      masks = [splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) for _ in range(query.shape[1])]
       multi_head_mask = splash_attention_mask.MultiHeadMask(masks=masks)
       splash_kernel = splash_attention_kernel.make_splash_mha(
          mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes
@@ -323,6 +322,177 @@ def chunk_scanner(chunk_idx, _):
 
   return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
 
+class FlaxFluxAttention(nn.Module):
+  query_dim: int
+  heads: int = 8
+  dim_head: int = 64
+  dropout: float = 0.0
+  use_memory_efficient_attention: bool = False
+  split_head_dim: bool = False
+  attention_kernel: str = "dot_product"
+  flash_min_seq_length: int = 4096
+  flash_block_sizes: BlockSizes = None
+  mesh: jax.sharding.Mesh = None
+  dtype: jnp.dtype = jnp.float32
+  weights_dtype: jnp.dtype = jnp.float32
+  query_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
+  key_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
+  value_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
+  out_axis_names: AxisNames = (BATCH, LENGTH, EMBED)
+  precision: jax.lax.Precision = None
+  qkv_bias: bool = False
+
+  def setup(self):
+    if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None:
+      raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
+    inner_dim = self.dim_head * self.heads
+    scale = self.dim_head**-0.5
+
+    self.attention_op = AttentionOp(
+        mesh=self.mesh,
+        attention_kernel=self.attention_kernel,
+        scale=scale,
+        heads=self.heads,
+        dim_head=self.dim_head,
+        flash_min_seq_length=self.flash_min_seq_length,
+        use_memory_efficient_attention=self.use_memory_efficient_attention,
+        split_head_dim=self.split_head_dim,
+        flash_block_sizes=self.flash_block_sizes,
+        dtype=self.dtype,
+        float32_qk_product=False,
+    )
+
+    kernel_axes = ("embed", "heads")
+    qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes)
+
+    self.qkv = nn.Dense(
+        inner_dim * 3,
+        kernel_init=qkv_init_kernel,
+        use_bias=self.qkv_bias,
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="i_qkv",
+        precision=self.precision,
+    )
+
+    self.encoder_qkv = nn.Dense(
+        inner_dim * 3,
+        kernel_init=qkv_init_kernel,
+        use_bias=self.qkv_bias,
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="e_qkv",
+        precision=self.precision,
+    )
+
+    self.proj_attn = nn.Dense(
+        self.query_dim,
+        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes),
+        use_bias=True,
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="i_proj",
+        precision=self.precision,
+    )
+
+    self.encoder_proj_attn = nn.Dense(
+        self.query_dim,
+        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes),
+        use_bias=True,
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="e_proj",
+        precision=self.precision,
+    )
+
+    self.query_norm = nn.RMSNorm(
+        dtype=self.dtype,
+        scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
+        param_dtype=self.weights_dtype,
+    )
+    self.key_norm = nn.RMSNorm(
+        dtype=self.dtype,
+        scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
+        param_dtype=self.weights_dtype,
+    )
+
+    self.encoder_query_norm = nn.RMSNorm(
+        dtype=self.dtype,
+        scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
+        param_dtype=self.weights_dtype,
+    )
+    self.encoder_key_norm = nn.RMSNorm(
+        dtype=self.dtype,
+        scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
+        param_dtype=self.weights_dtype,
+    )
+
+  def apply_rope(self, xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
+    xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
+    xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
+
+    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+
+    return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)
+
+  def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
+
+    qkv_proj = self.qkv(hidden_states)
+    B, L = hidden_states.shape[:2]
+    H, D, K = self.heads, qkv_proj.shape[-1] // (self.heads * 3), 3
+    qkv_proj = qkv_proj.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4)
+    query_proj, key_proj, value_proj = qkv_proj
+
+    query_proj = self.query_norm(query_proj)
+    key_proj = self.key_norm(key_proj)
+
+    if encoder_hidden_states is not None:
+      encoder_qkv_proj = self.encoder_qkv(encoder_hidden_states)
+      B, L = encoder_hidden_states.shape[:2]
+      H, D, K = self.heads, encoder_qkv_proj.shape[-1] // (self.heads * 3), 3
+      encoder_qkv_proj = encoder_qkv_proj.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4)
+      encoder_query_proj, encoder_key_proj, encoder_value_proj = encoder_qkv_proj
+
+      encoder_query_proj = self.encoder_query_norm(encoder_query_proj)
+      encoder_key_proj = self.encoder_key_norm(encoder_key_proj)
+
+      query_proj = jnp.concatenate((encoder_query_proj, query_proj), axis=2)
+      key_proj = jnp.concatenate((encoder_key_proj, key_proj), axis=2)
+      value_proj = jnp.concatenate((encoder_value_proj, value_proj), axis=2)
+
+    query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names)
+    key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names)
+    value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)
+
+    image_rotary_emb = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2)
+    query_proj, key_proj = self.apply_rope(query_proj, key_proj, image_rotary_emb)
+
+    query_proj = query_proj.transpose(0, 2, 1, 3).reshape(query_proj.shape[0], query_proj.shape[2], -1)
+    key_proj = key_proj.transpose(0, 2, 1, 3).reshape(key_proj.shape[0], key_proj.shape[2], -1)
+    value_proj = value_proj.transpose(0, 2, 1, 3).reshape(value_proj.shape[0], value_proj.shape[2], -1)
+
+    attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+    context_attn_output = None
+
+    if encoder_hidden_states is not None:
+      context_attn_output, attn_output = (
+          attn_output[:, : encoder_hidden_states.shape[1]],
+          attn_output[:, encoder_hidden_states.shape[1] :],
+      )
+
+      attn_output = self.proj_attn(attn_output)
+
+      context_attn_output = self.encoder_proj_attn(context_attn_output)
+
+    return attn_output, context_attn_output
 
 class FlaxAttention(nn.Module):
   r"""
