Skip to content

Commit c8196ed

Browse files
jfacevedo-google and ksikiric
authored and committed
flux schnell working
1 parent 1c8ed7b commit c8196ed

8 files changed

Lines changed: 777 additions & 302 deletions

File tree

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,19 @@ activations_dtype: 'bfloat16'
5252
precision: "DEFAULT"
5353

5454
# Set true to load weights from pytorch
55-
from_pt: False
55+
from_pt: True
5656
split_head_dim: True
5757
attention: 'flash' # Supported attention: dot_product, flash
58-
flash_block_sizes: {}
58+
flash_block_sizes: {
59+
"block_q" : 128,
60+
"block_kv" : 128,
61+
"block_kv_compute" : 128,
62+
"block_q_dkv" : 128,
63+
"block_kv_dkv" : 128,
64+
"block_kv_dkv_compute" : 128,
65+
"block_q_dq" : 128,
66+
"block_kv_dq" : 128
67+
}
5968
# GroupNorm groups
6069
norm_num_groups: 32
6170

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,19 @@ activations_dtype: 'bfloat16'
5151
precision: "DEFAULT"
5252

5353
# Set true to load weights from pytorch
54-
from_pt: False
54+
from_pt: True
5555
split_head_dim: True
5656
attention: 'flash' # Supported attention: dot_product, flash
57-
flash_block_sizes: {}
57+
flash_block_sizes: {
58+
"block_q" : 128,
59+
"block_kv" : 128,
60+
"block_kv_compute" : 128,
61+
"block_q_dkv" : 128,
62+
"block_kv_dkv" : 128,
63+
"block_kv_dkv_compute" : 128,
64+
"block_q_dq" : 128,
65+
"block_kv_dq" : 128
66+
}
5867
# GroupNorm groups
5968
norm_num_groups: 32
6069

src/maxdiffusion/generate_flux.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,14 @@ def loop_body(
7979
t_vec = jnp.full((latents.shape[0], ), t_curr, dtype=latents.dtype)
8080
pred = transformer.apply(
8181
{"params" : state.params},
82-
img=latents,
82+
hidden_states=latents,
8383
img_ids=latent_image_ids,
84-
txt=prompt_embeds,
84+
encoder_hidden_states=prompt_embeds,
8585
txt_ids=txt_ids,
86-
timesteps=t_vec,
86+
timestep=t_vec,
8787
guidance=guidance_vec,
88-
y=vec
89-
)
90-
jax.debug.print("*****pred max: {x}", x=np.max(pred))
91-
jax.debug.print("*****pred min: {x}", x=np.min(pred))
88+
pooled_projections=vec
89+
).sample
9290
latents = latents + (t_prev - t_curr) * pred
9391
latents = jnp.array(latents, dtype=latents_dtype)
9492
return latents, state, c_ts, p_ts

src/maxdiffusion/models/attention_flax.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,177 @@ def chunk_scanner(chunk_idx, _):
322322

323323
return jnp.concatenate(res, axis=-3) # fuse the chunked result back
324324

325+
class FlaxFluxAttention(nn.Module):
    """Joint image/text attention block for the Flux transformer.

    Projects ``hidden_states`` (and optionally ``encoder_hidden_states``) with
    fused QKV layers, applies RMS-norm to Q/K, applies rotary position
    embeddings, runs the configured attention kernel, and returns separate
    image-stream and context-stream outputs.

    Attributes:
        query_dim: Output dimension of the final projections.
        heads: Number of attention heads.
        dim_head: Per-head dimension; inner dim is ``heads * dim_head``.
        dropout: Accepted for interface parity.  # NOTE(review): not applied anywhere in this block — confirm intended
        use_memory_efficient_attention: Forwarded to ``AttentionOp``.
        split_head_dim: Forwarded to ``AttentionOp``.
        attention_kernel: One of "dot_product", "flash", "cudnn_flash_te".
        flash_min_seq_length: Forwarded to ``AttentionOp``.
        flash_block_sizes: Optional block sizes for the flash kernel.
        mesh: Device mesh; required when a flash kernel is selected.
        dtype: Computation dtype.
        weights_dtype: Parameter dtype.
        precision: Matmul precision for the dense layers.
        qkv_bias: Whether the QKV projections carry bias terms.
    """

    query_dim: int
    heads: int = 8
    dim_head: int = 64
    dropout: float = 0.0
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    attention_kernel: str = "dot_product"
    flash_min_seq_length: int = 4096
    flash_block_sizes: BlockSizes = None
    mesh: jax.sharding.Mesh = None
    dtype: jnp.dtype = jnp.float32
    weights_dtype: jnp.dtype = jnp.float32
    query_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
    key_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
    value_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
    out_axis_names: AxisNames = (BATCH, LENGTH, EMBED)
    precision: jax.lax.Precision = None
    qkv_bias: bool = False

    def setup(self):
        # Flash kernels shard across devices, so a mesh is mandatory for them.
        if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None:
            raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
        inner_dim = self.dim_head * self.heads
        scale = self.dim_head**-0.5

        self.attention_op = AttentionOp(
            mesh=self.mesh,
            attention_kernel=self.attention_kernel,
            scale=scale,
            heads=self.heads,
            dim_head=self.dim_head,
            flash_min_seq_length=self.flash_min_seq_length,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            split_head_dim=self.split_head_dim,
            flash_block_sizes=self.flash_block_sizes,
            dtype=self.dtype,
            float32_qk_product=False,
        )

        kernel_axes = ("embed", "heads")
        qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes)

        # Fused Q/K/V projection for the image stream.
        self.qkv = nn.Dense(
            inner_dim * 3,
            kernel_init=qkv_init_kernel,
            use_bias=self.qkv_bias,
            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
            dtype=self.dtype,
            param_dtype=self.weights_dtype,
            name="i_qkv",
            precision=self.precision,
        )

        # Fused Q/K/V projection for the text (encoder) stream.
        self.encoder_qkv = nn.Dense(
            inner_dim * 3,
            kernel_init=qkv_init_kernel,
            use_bias=self.qkv_bias,
            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
            dtype=self.dtype,
            param_dtype=self.weights_dtype,
            name="e_qkv",
            precision=self.precision,
        )

        self.proj_attn = nn.Dense(
            self.query_dim,
            kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes),
            use_bias=True,
            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
            dtype=self.dtype,
            param_dtype=self.weights_dtype,
            name="i_proj",
            precision=self.precision,
        )

        self.encoder_proj_attn = nn.Dense(
            self.query_dim,
            kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes),
            use_bias=True,
            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
            dtype=self.dtype,
            param_dtype=self.weights_dtype,
            name="e_proj",
            precision=self.precision,
        )

        # Per-stream RMS norms applied to queries and keys before attention.
        self.query_norm = nn.RMSNorm(
            dtype=self.dtype,
            scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
            param_dtype=self.weights_dtype,
        )
        self.key_norm = nn.RMSNorm(
            dtype=self.dtype,
            scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
            param_dtype=self.weights_dtype,
        )

        self.encoder_query_norm = nn.RMSNorm(
            dtype=self.dtype,
            scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
            param_dtype=self.weights_dtype,
        )
        self.encoder_key_norm = nn.RMSNorm(
            dtype=self.dtype,
            scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
            param_dtype=self.weights_dtype,
        )

    def apply_rope(self, xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
        """Apply rotary position embeddings to query/key tensors.

        The trailing feature axis is viewed as interleaved (real, imag) pairs
        and rotated by ``freqs_cis`` (a per-position 2x2 rotation table).
        """
        xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)

        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]

        # Restore original shape/dtype after the float32 rotation math.
        return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)

    def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
        """Run joint attention over the image stream and optional text stream.

        Args:
            hidden_states: Image-stream tokens, shape (batch, img_len, features).
            encoder_hidden_states: Optional text-stream tokens; when provided,
                text tokens are prepended along the sequence axis for joint
                attention and split off again afterwards.
            attention_mask: Accepted for interface parity but ignored here.
            image_rotary_emb: Rotary table whose last axis packs a 2x2 rotation;
                assumed to cover the concatenated txt+img positions — TODO confirm.

        Returns:
            Tuple ``(attn_output, context_attn_output)``; the context part is
            ``None`` when ``encoder_hidden_states`` is not given.
        """
        # Fused QKV: (B, L, 3*H*D) -> 3 tensors of shape (B, H, L, D).
        qkv_proj = self.qkv(hidden_states)
        batch, img_len = hidden_states.shape[:2]
        num_heads = self.heads
        head_dim = qkv_proj.shape[-1] // (num_heads * 3)
        qkv_proj = qkv_proj.reshape(batch, img_len, 3, num_heads, head_dim).transpose(2, 0, 3, 1, 4)
        query_proj, key_proj, value_proj = qkv_proj

        query_proj = self.query_norm(query_proj)
        key_proj = self.key_norm(key_proj)

        if encoder_hidden_states is not None:
            encoder_qkv_proj = self.encoder_qkv(encoder_hidden_states)
            batch, txt_len = encoder_hidden_states.shape[:2]
            head_dim = encoder_qkv_proj.shape[-1] // (num_heads * 3)
            encoder_qkv_proj = encoder_qkv_proj.reshape(batch, txt_len, 3, num_heads, head_dim).transpose(2, 0, 3, 1, 4)
            encoder_query_proj, encoder_key_proj, encoder_value_proj = encoder_qkv_proj

            encoder_query_proj = self.encoder_query_norm(encoder_query_proj)
            encoder_key_proj = self.encoder_key_norm(encoder_key_proj)

            # Text tokens go first so the split after attention can peel them off.
            query_proj = jnp.concatenate((encoder_query_proj, query_proj), axis=2)
            key_proj = jnp.concatenate((encoder_key_proj, key_proj), axis=2)
            value_proj = jnp.concatenate((encoder_value_proj, value_proj), axis=2)

        query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names)
        key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names)
        value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)

        # Unpack the trailing (i j) axis of the rotary table into a 2x2 rotation.
        image_rotary_emb = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2)
        query_proj, key_proj = self.apply_rope(query_proj, key_proj, image_rotary_emb)

        # (B, H, S, D) -> (B, S, H*D) as expected by the attention op.
        query_proj = query_proj.transpose(0, 2, 1, 3).reshape(query_proj.shape[0], query_proj.shape[2], -1)
        key_proj = key_proj.transpose(0, 2, 1, 3).reshape(key_proj.shape[0], key_proj.shape[2], -1)
        value_proj = value_proj.transpose(0, 2, 1, 3).reshape(value_proj.shape[0], value_proj.shape[2], -1)

        attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
        context_attn_output = None

        if encoder_hidden_states is not None:
            # Text tokens were concatenated first, so they occupy the leading
            # encoder_hidden_states.shape[1] positions of the joint output.
            context_attn_output, attn_output = (
                attn_output[:, : encoder_hidden_states.shape[1]],
                attn_output[:, encoder_hidden_states.shape[1] :],
            )

        attn_output = self.proj_attn(attn_output)

        # BUG FIX: project the context stream only when it exists. The original
        # called encoder_proj_attn(context_attn_output) unconditionally, which
        # passed None into a Dense layer and crashed whenever the module was
        # invoked without encoder_hidden_states.
        if encoder_hidden_states is not None:
            context_attn_output = self.encoder_proj_attn(context_attn_output)

        return attn_output, context_attn_output
325496

326497
class FlaxFluxAttention(nn.Module):
327498
query_dim: int

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,14 @@ class FlaxTimestepEmbedding(nn.Module):
7373

7474
@nn.compact
7575
def __call__(self, temb):
76-
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, param_dtype=self.weights_dtype, name="linear_1")(temb)
76+
temb = nn.Dense(self.time_embed_dim,
77+
dtype=self.dtype,
78+
param_dtype=self.weights_dtype,
79+
name="linear_1")(temb)
7780
temb = nn.silu(temb)
78-
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, param_dtype=self.weights_dtype, name="linear_2")(temb)
81+
temb = nn.Dense(self.time_embed_dim,
82+
dtype=self.dtype,
83+
param_dtype=self.weights_dtype, name="linear_2")(temb)
7984
return temb
8085

8186

@@ -98,7 +103,6 @@ def __call__(self, timesteps):
98103
timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
99104
)
100105

101-
102106
def get_1d_rotary_pos_embed(
103107
dim: int, pos: Union[jnp.array, int], theta: float = 10000.0, linear_factor=1.0, ntk_factor=1.0, freqs_dtype=jnp.float32
104108
):
@@ -119,7 +123,6 @@ def get_1d_rotary_pos_embed(
119123

120124
return out
121125

122-
123126
class PixArtAlphaTextProjection(nn.Module):
124127
"""
125128
Projects caption embeddings. Also handles dropout for classifier-free guidance.
@@ -236,3 +239,36 @@ def __call__(self, timestep, guidance, pooled_projection):
236239
conditioning = time_guidance_emb + pooled_projections
237240

238241
return conditioning
242+
243+
244+
# class HFEmbedder(nnx.Module):
245+
246+
# def __init__(self, version: str, max_length: int, **hf_kwargs):
247+
# super().__init__()
248+
# self.is_clip = version.split("/")[1].startswith("clip")
249+
# self.max_length = max_length
250+
# self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
251+
252+
# if self.is_clip:
253+
# self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(version, max_length=max_length, use_fast=True)
254+
# self.hf_module: FlaxCLIPTextModel = FlaxCLIPTextModel.from_pretrained(version, **hf_kwargs)
255+
# else:
256+
# self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(version, max_length=max_length, use_fast=True)
257+
# self.hf_module: FlaxT5EncoderModel = FlaxT5EncoderModel.from_pretrained(version, **hf_kwargs)
258+
259+
# def __call__(self, text: list[str]):
260+
# batch_encoding = self.tokenizer(
261+
# text,
262+
# truncation=True,
263+
# max_length=self.max_length,
264+
# return_length=False,
265+
# return_overflowing_tokens=False,
266+
# padding="max_length",
267+
# return_tensors="np",
268+
# )
269+
# outputs = self.hf_module(
270+
# input_ids=batch_encoding["input_ids"],
271+
# attention_mask=None,
272+
# output_hidden_states=False,
273+
# )
274+
# return outputs[self.output_key]

0 commit comments

Comments
 (0)