
Commit 3141d69

ruff and code_style
1 parent f05a7be commit 3141d69

4 files changed: 16 additions & 84 deletions

src/maxdiffusion/generate_flux.py

Lines changed: 2 additions & 14 deletions
@@ -135,19 +135,7 @@ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: flo
 
 
 def run_inference(
-    states,
-    transformer,
-    vae,
-    config,
-    mesh,
-    latents,
-    latent_image_ids,
-    prompt_embeds,
-    txt_ids,
-    vec,
-    guidance_vec,
-    c_ts,
-    p_ts
+    states, transformer, vae, config, mesh, latents, latent_image_ids, prompt_embeds, txt_ids, vec, guidance_vec, c_ts, p_ts
 ):
 
   transformer_state = states["transformer"]

@@ -468,7 +456,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
         vec=pooled_prompt_embeds,
         guidance_vec=guidance,
         c_ts=c_ts,
-        p_ts=p_ts
+        p_ts=p_ts,
     ),
     in_shardings=(state_shardings,),
     out_shardings=None,
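
The second hunk above is the call site that pins run_inference's keyword arguments with functools.partial and jits the result with explicit shardings. A minimal runnable sketch of that pattern, using a toy function and toy values rather than the real pipeline objects (which also pass in_shardings=(state_shardings,) and out_shardings=None), looks like this:

import functools

import jax
import jax.numpy as jnp


def run_step(state, *, scale, offset):
  # Stand-in for run_inference: everything except the (sharded) state is
  # pinned ahead of time via functools.partial, mirroring the kwargs above.
  return state * scale + offset


p_run_step = jax.jit(functools.partial(run_step, scale=2.0, offset=0.5))
print(p_run_step(jnp.ones((2, 2))))  # [[2.5 2.5] [2.5 2.5]]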

src/maxdiffusion/models/attention_flax.py

Lines changed: 13 additions & 12 deletions
@@ -55,19 +55,20 @@ class AttentionOp(nn.Module):
   def setup(self):
     if self.attention_kernel == "cudnn_flash_te":
       from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error
+
       self.dpa_layer = DotProductAttention(
-        head_dim=self.dim_head,
-        num_attention_heads=self.heads,
-        num_gqa_groups=self.heads,
-        attn_mask_type="no_mask",  # 'no_mask', 'padding', 'causal', or 'padding_causal'
-        attn_bias_type="NO_BIAS",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
-        # attention_dropout=self.dropout_rate,
-        dropout_rng_name="aqt",
-        dtype=self.dtype,
-        # float32_logits=self.float32_logits,
-        qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
-        scale_factor=self.scale,
-        transpose_batch_sequence=False,
+          head_dim=self.dim_head,
+          num_attention_heads=self.heads,
+          num_gqa_groups=self.heads,
+          attn_mask_type="no_mask",  # 'no_mask', 'padding', 'causal', or 'padding_causal'
+          attn_bias_type="NO_BIAS",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
+          # attention_dropout=self.dropout_rate,
+          dropout_rng_name="aqt",
+          dtype=self.dtype,
+          # float32_logits=self.float32_logits,
+          qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
+          scale_factor=self.scale,
+          transpose_batch_sequence=False,
       )
 
   def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None:
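
The arguments above configure Transformer Engine's fused attention for unmasked, bias-free attention over BSHD (batch, sequence, heads, head_dim) inputs with an explicit scale_factor. As a reference for what that kernel computes, here is a plain jax.numpy sketch of unmasked scaled dot-product attention over BSHD tensors; it is not the Transformer Engine API, just the equivalent math on toy shapes:

import jax
import jax.numpy as jnp


def sdpa_bshd(q, k, v, scale):
  # q, k, v: (batch, seq, heads, head_dim); no mask and no bias, as configured above.
  logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
  weights = jax.nn.softmax(logits, axis=-1)
  return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


q = k = v = jnp.ones((1, 8, 4, 64))
print(sdpa_bshd(q, k, v, scale=64**-0.5).shape)  # (1, 8, 4, 64)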

src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py

Lines changed: 1 addition & 3 deletions
@@ -513,9 +513,7 @@ def __call__(
     hidden_states = jnp.concatenate([encoder_hidden_states, hidden_states], axis=1)
     hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed"))
     for single_block in self.single_blocks:
-      hidden_states = single_block(
-          hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
-      )
+      hidden_states = single_block(hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb)
     hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
 
     hidden_states = self.norm_out(hidden_states, temb)
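
For context, the loop above runs each single transformer block over the concatenation of text (encoder) tokens and image tokens, then slices the text prefix back off. A small runnable sketch of that concatenate / process / slice flow, with a toy stand-in for the block, is:

import jax.numpy as jnp


def toy_single_block(hidden_states, temb):
  # Stand-in for a single block: just adds the conditioning embedding so the
  # shape bookkeeping stays visible; the real block also takes image_rotary_emb.
  return hidden_states + temb[:, None, :]


encoder_hidden_states = jnp.zeros((1, 77, 16))  # text tokens
hidden_states = jnp.ones((1, 256, 16))          # image latents
temb = jnp.zeros((1, 16))

hidden_states = jnp.concatenate([encoder_hidden_states, hidden_states], axis=1)
for _ in range(2):
  hidden_states = toy_single_block(hidden_states, temb)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:, ...]
print(hidden_states.shape)  # (1, 256, 16)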

src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 0 additions & 55 deletions
@@ -278,61 +278,6 @@ def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params,
   return params, rank, network_alphas
 
 
-def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params, adapter_name):
-  pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()}
-  transformer_params = flatten_dict(unfreeze(params["transformer"]))
-  network_alphas = {}
-  rank = None
-  for pt_key, tensor in pt_state_dict.items():
-    renamed_pt_key = rename_key(pt_key)
-    renamed_pt_key = renamed_pt_key.replace("lora_unet_", "")
-    renamed_pt_key = renamed_pt_key.replace("lora_down", f"lora-{adapter_name}.down")
-    renamed_pt_key = renamed_pt_key.replace("lora_up", f"lora-{adapter_name}.up")
-
-    if "double_blocks" in renamed_pt_key:
-      renamed_pt_key = renamed_pt_key.replace("double_blocks.", "double_blocks_")
-      renamed_pt_key = renamed_pt_key.replace("processor.proj_lora1.down", f"attn.i_proj.lora-{adapter_name}.down")
-      renamed_pt_key = renamed_pt_key.replace("processor.proj_lora1.up", f"attn.i_proj.lora-{adapter_name}.up")
-      renamed_pt_key = renamed_pt_key.replace("processor.proj_lora2.down", f"attn.e_proj.lora-{adapter_name}.down")
-      renamed_pt_key = renamed_pt_key.replace("processor.proj_lora2.up", f"attn.e_proj.lora-{adapter_name}.up")
-      renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora1.down", f"attn.i_qkv.lora-{adapter_name}.down")
-      renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora1.up", f"attn.i_qkv.lora-{adapter_name}.up")
-      renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora2.down", f"attn.e_qkv.lora-{adapter_name}.down")
-      renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora2.up", f"attn.e_qkv.lora-{adapter_name}.up")
-
-      renamed_pt_key = renamed_pt_key.replace("_img_attn_proj", ".attn.i_proj")
-      renamed_pt_key = renamed_pt_key.replace("_img_attn_qkv", ".attn.i_qkv")
-      renamed_pt_key = renamed_pt_key.replace("_img_mlp_0", ".img_mlp.layers_0")
-      renamed_pt_key = renamed_pt_key.replace("_img_mlp_2", ".img_mlp.layers_2")
-      renamed_pt_key = renamed_pt_key.replace("_img_mod_lin", ".img_norm1.lin")
-      renamed_pt_key = renamed_pt_key.replace("_txt_attn_proj", ".attn.e_proj")
-      renamed_pt_key = renamed_pt_key.replace("_txt_attn_qkv", ".attn.e_qkv")
-      renamed_pt_key = renamed_pt_key.replace("_txt_mlp_0", ".txt_mlp.layers_0")
-      renamed_pt_key = renamed_pt_key.replace("_txt_mlp_2", ".txt_mlp.layers_2")
-      renamed_pt_key = renamed_pt_key.replace("_txt_mod_lin", ".txt_norm1.lin")
-    elif "single_blocks" in renamed_pt_key:
-      renamed_pt_key = renamed_pt_key.replace("_linear1", ".linear1")
-      renamed_pt_key = renamed_pt_key.replace("_linear2", ".linear2")
-      renamed_pt_key = renamed_pt_key.replace("_modulation_lin", ".norm.lin")
-
-    renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
-
-    pt_tuple_key = tuple(renamed_pt_key.split("."))
-    if "alpha" in pt_tuple_key:
-      pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "down", "kernel")
-      network_alphas[tuple([*pt_tuple_key])] = tensor.item()  # noqa: C409
-      pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "up", "kernel")
-      network_alphas[tuple([*pt_tuple_key])] = tensor.item()  # noqa: C409
-    else:
-      if pt_tuple_key[-2] == "up":
-        rank = tensor.shape[1]
-      transformer_params[tuple([*pt_tuple_key])] = jnp.asarray(tensor.T, dtype=config.weights_dtype)  # noqa: C409
-
-  params["transformer"] = unflatten_dict(transformer_params)
-
-  return params, rank, network_alphas
-
-
 def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alphas, adapter_name):
   # Step 1: Convert pytorch tensor to numpy
   # sometimes we load weights in bf16 and numpy doesn't support it