Commit df25e47

fix repeated double and single blocks.
1 parent 05b6fc8 commit df25e47

4 files changed: 50 additions & 75 deletions

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 1 addition & 10 deletions
@@ -55,16 +55,7 @@ precision: "DEFAULT"
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash
-flash_block_sizes: {
-  "block_q" : 128,
-  "block_kv" : 128,
-  "block_kv_compute" : 128,
-  "block_q_dkv" : 128,
-  "block_kv_dkv" : 128,
-  "block_kv_dkv_compute" : 128,
-  "block_q_dq" : 128,
-  "block_kv_dq" : 128
-}
+flash_block_sizes: {}
 # GroupNorm groups
 norm_num_groups: 32

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 1 addition & 10 deletions
@@ -54,16 +54,7 @@ precision: "DEFAULT"
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash
-flash_block_sizes: {
-  "block_q" : 128,
-  "block_kv" : 128,
-  "block_kv_compute" : 128,
-  "block_q_dkv" : 128,
-  "block_kv_dkv" : 128,
-  "block_kv_dkv_compute" : 128,
-  "block_q_dq" : 128,
-  "block_kv_dq" : 128
-}
+flash_block_sizes: {}
 # GroupNorm groups
 norm_num_groups: 32
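Both config diffs leave flash_block_sizes as an empty mapping instead of pinning every block dimension to 128. The removed keys mirror the block-size fields of the Pallas TPU splash-attention kernel, so an empty dict plausibly means "let the attention kernel pick its own defaults." A minimal sketch of that interpretation follows; resolve_flash_block_sizes is a hypothetical helper name, not maxdiffusion API, and the import path assumes a recent JAX release.

# Hypothetical sketch: build explicit BlockSizes only when the YAML mapping is
# non-empty; an empty {} falls back to the kernel's own defaults.
# `resolve_flash_block_sizes` is an assumed helper name, not maxdiffusion API.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel


def resolve_flash_block_sizes(flash_block_sizes: dict):
  if not flash_block_sizes:
    return None  # caller lets splash attention choose default block sizes
  return splash_attention_kernel.BlockSizes(**flash_block_sizes)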

src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py

Lines changed: 48 additions & 52 deletions
@@ -383,48 +383,43 @@ def setup(self):
         precision=self.precision,
     )
 
-    self.double_blocks = nn.Sequential(
-        [
-            *[
-                FluxTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=self.num_attention_heads,
-                    attention_head_dim=self.attention_head_dim,
-                    attention_kernel=self.attention_kernel,
-                    flash_min_seq_length=self.flash_min_seq_length,
-                    flash_block_sizes=self.flash_block_sizes,
-                    mesh=self.mesh,
-                    dtype=self.dtype,
-                    weights_dtype=self.weights_dtype,
-                    precision=self.precision,
-                    mlp_ratio=self.mlp_ratio,
-                    qkv_bias=self.qkv_bias,
-                )
-                for _ in range(self.num_layers)
-            ]
-        ]
-    )
-
-    self.single_blocks = nn.Sequential(
-        [
-            *[
-                FluxSingleTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=self.num_attention_heads,
-                    attention_head_dim=self.attention_head_dim,
-                    attention_kernel=self.attention_kernel,
-                    flash_min_seq_length=self.flash_min_seq_length,
-                    flash_block_sizes=self.flash_block_sizes,
-                    mesh=self.mesh,
-                    dtype=self.dtype,
-                    weights_dtype=self.weights_dtype,
-                    precision=self.precision,
-                    mlp_ratio=self.mlp_ratio,
-                )
-                for _ in range(self.num_single_layers)
-            ]
-        ]
-    )
+    double_blocks = []
+    for _ in range(self.num_layers):
+      double_block = FluxTransformerBlock(
+          dim=self.inner_dim,
+          num_attention_heads=self.num_attention_heads,
+          attention_head_dim=self.attention_head_dim,
+          attention_kernel=self.attention_kernel,
+          flash_min_seq_length=self.flash_min_seq_length,
+          flash_block_sizes=self.flash_block_sizes,
+          mesh=self.mesh,
+          dtype=self.dtype,
+          weights_dtype=self.weights_dtype,
+          precision=self.precision,
+          mlp_ratio=self.mlp_ratio,
+          qkv_bias=self.qkv_bias,
+      )
+      double_blocks.append(double_block)
+    self.double_blocks = double_blocks
+
+    single_blocks = []
+    for _ in range(self.num_single_layers):
+      single_block = FluxSingleTransformerBlock(
+          dim=self.inner_dim,
+          num_attention_heads=self.num_attention_heads,
+          attention_head_dim=self.attention_head_dim,
+          attention_kernel=self.attention_kernel,
+          flash_min_seq_length=self.flash_min_seq_length,
+          flash_block_sizes=self.flash_block_sizes,
+          mesh=self.mesh,
+          dtype=self.dtype,
+          weights_dtype=self.weights_dtype,
+          precision=self.precision,
+          mlp_ratio=self.mlp_ratio,
+      )
+      single_blocks.append(single_block)
+
+    self.single_blocks = single_blocks
 
     self.norm_out = AdaLayerNormContinuous(
         self.inner_dim,
@@ -509,18 +504,19 @@ def __call__(
     image_rotary_emb = self.pe_embedder(ids)
     image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, ("activation_batch", "activation_embed"))
 
-    hidden_states, encoder_hidden_states, temb, image_rotary_emb = self.double_blocks(
-        hidden_states=hidden_states,
-        encoder_hidden_states=encoder_hidden_states,
-        temb=temb,
-        image_rotary_emb=image_rotary_emb,
-    )
+    for double_block in self.double_blocks:
+      hidden_states, encoder_hidden_states, temb, image_rotary_emb = double_block(
+          hidden_states=hidden_states,
+          encoder_hidden_states=encoder_hidden_states,
+          temb=temb,
+          image_rotary_emb=image_rotary_emb,
+      )
     hidden_states = jnp.concatenate([encoder_hidden_states, hidden_states], axis=1)
     hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed"))
-
-    hidden_states, temb, image_rotary_emb = self.single_blocks(
-        hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
-    )
+    for single_block in self.single_blocks:
+      hidden_states, temb, image_rotary_emb = single_block(
+          hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
+      )
     hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
 
     hidden_states = self.norm_out(hidden_states, temb)
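The rewrite leans on the Flax idiom that submodules assigned to an attribute in setup(), including ones collected in a plain Python list, are registered automatically and named <attribute>_<index> in the params tree. Iterating over that list in __call__ then threads the activations through each block explicitly, instead of relying on nn.Sequential, which nests its children under a layers_<i> scope and chains each layer's outputs into the next. Below is a minimal, self-contained sketch of that idiom; DummyBlock and BlockStack are hypothetical names for illustration, not maxdiffusion classes.

# Minimal sketch of the "list of submodules in setup(), loop in __call__" idiom.
# DummyBlock and BlockStack are made-up names, not maxdiffusion classes.
import jax
import jax.numpy as jnp
from flax import linen as nn


class DummyBlock(nn.Module):
  dim: int

  @nn.compact
  def __call__(self, hidden_states, temb):
    # Stand-in for a transformer block: returns updated activations.
    hidden_states = hidden_states + nn.Dense(self.dim)(temb)
    return hidden_states, temb


class BlockStack(nn.Module):
  dim: int
  num_layers: int

  def setup(self):
    # Submodules stored in a Python list are registered by Flax and appear
    # in the params tree as blocks_0, blocks_1, ... (no extra layers_ scope).
    self.blocks = [DummyBlock(dim=self.dim) for _ in range(self.num_layers)]

  def __call__(self, hidden_states, temb):
    for block in self.blocks:
      hidden_states, temb = block(hidden_states, temb)
    return hidden_states


params = BlockStack(dim=8, num_layers=2).init(
    jax.random.PRNGKey(0), jnp.ones((1, 8)), jnp.ones((1, 8))
)
print(jax.tree_util.tree_map(jnp.shape, params))  # keys: blocks_0, blocks_1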

src/maxdiffusion/models/flux/util.py

Lines changed: 0 additions & 3 deletions
@@ -159,7 +159,6 @@ def load_flow_model(name: str, eval_shapes: dict, device: str, hf_download: bool
   for pt_key, tensor in tensors.items():
     renamed_pt_key = rename_key(pt_key)
     if "double_blocks" in renamed_pt_key:
-      renamed_pt_key = renamed_pt_key.replace("double_blocks_", "double_blocks.layers_")
       renamed_pt_key = renamed_pt_key.replace("img_mlp_", "img_mlp.layers_")
       renamed_pt_key = renamed_pt_key.replace("txt_mlp_", "txt_mlp.layers_")
       renamed_pt_key = renamed_pt_key.replace("img_mod", "img_norm1")
@@ -176,7 +175,6 @@ def load_flow_model(name: str, eval_shapes: dict, device: str, hf_download: bool
       renamed_pt_key = renamed_pt_key.replace("in_layer", "linear_1")
       renamed_pt_key = renamed_pt_key.replace("out_layer", "linear_2")
     elif "single_blocks" in renamed_pt_key:
-      renamed_pt_key = renamed_pt_key.replace("single_blocks_", "single_blocks.layers_")
       renamed_pt_key = renamed_pt_key.replace("modulation", "norm")
       renamed_pt_key = renamed_pt_key.replace("norm.key_norm", "attn.key_norm")
       renamed_pt_key = renamed_pt_key.replace("norm.query_norm", "attn.query_norm")
@@ -188,7 +186,6 @@ def load_flow_model(name: str, eval_shapes: dict, device: str, hf_download: bool
     elif "final_layer" in renamed_pt_key:
       renamed_pt_key = renamed_pt_key.replace("final_layer.linear", "proj_out")
       renamed_pt_key = renamed_pt_key.replace("final_layer.adaLN_modulation_1", "norm_out.Dense_0")
-
     pt_tuple_key = tuple(renamed_pt_key.split("."))
     flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, eval_shapes)
     flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
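Because the blocks are now plain list attributes rather than nn.Sequential children, their parameters no longer sit under an extra layers_<i> scope, which is presumably why the double_blocks_/single_blocks_ → *.layers_ key rewrites (and the stray blank line) are dropped from the checkpoint loader. A rough illustration of the naming difference, with placeholder trailing components rather than real checkpoint keys:

# Illustration only: the trailing components are placeholders, not actual keys.
old_style_key = ("double_blocks", "layers_0", "<param path>")  # nn.Sequential nests under layers_<i>
new_style_key = ("double_blocks_0", "<param path>")            # list attribute: double_blocks_<i>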
