Skip to content

Commit 719e6db

Browse files
jfacevedo-google authored and ksikiric committed
adding another format lora support.
1 parent ff16ba6 commit 719e6db

5 files changed

Lines changed: 43 additions & 36 deletions

File tree

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'
3030
# Flux params
3131
flux_name: "flux-dev"
3232
max_sequence_length: 512
33-
time_shift: False
33+
time_shift: True
3434
base_shift: 0.5
3535
max_shift: 1.15
3636
# offloads t5 encoder after text encoding to save memory.

src/maxdiffusion/generate_flux.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def unpack(x: Array, height: int, width: int) -> Array:
7777

7878

7979
def vae_decode(latents, vae, state, config):
80-
img = unpack(x=latents, height=config.resolution, width=config.resolution)
80+
img = unpack(x=latents.astype(jnp.float32), height=config.resolution, width=config.resolution)
8181
img = img / vae.config.scaling_factor + vae.config.shift_factor
8282
img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample
8383
return img
@@ -115,13 +115,12 @@ def loop_body(
115115

116116
def prepare_latent_image_ids(height, width):
117117
latent_image_ids = jnp.zeros((height, width, 3))
118-
latent_image_ids = latent_image_ids.at[..., 1].set(latent_image_ids[..., 1] + jnp.arange(height)[:, None])
119-
latent_image_ids = latent_image_ids.at[..., 2].set(latent_image_ids[..., 2] + jnp.arange(width)[None, :])
118+
latent_image_ids = latent_image_ids.at[..., 1].set(jnp.arange(height)[:, None])
119+
latent_image_ids = latent_image_ids.at[..., 2].set(jnp.arange(width)[None, :])
120120

121121
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
122122

123123
latent_image_ids = latent_image_ids.reshape(latent_image_id_height * latent_image_id_width, latent_image_id_channels)
124-
125124
return latent_image_ids.astype(jnp.bfloat16)
126125

127126

@@ -147,20 +146,10 @@ def run_inference(
147146
txt_ids,
148147
vec,
149148
guidance_vec,
149+
c_ts,
150+
p_ts
150151
):
151152

152-
timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1)
153-
# shifting the schedule to favor high timesteps for higher signal images
154-
if config.time_shift:
155-
# estimate mu based on linear estimation between two points
156-
lin_function = get_lin_function(y1=config.base_shift, y2=config.max_shift)
157-
mu = lin_function(latents.shape[1])
158-
timesteps = time_shift(mu, 1.0, timesteps).tolist()
159-
c_ts = timesteps[:-1]
160-
p_ts = timesteps[1:]
161-
# jax.debug.print("c_ts: {x}", x=c_ts)
162-
# jax.debug.print("p_ts: {x}", x=p_ts)
163-
164153
transformer_state = states["transformer"]
165154
vae_state = states["vae"]
166155

@@ -173,11 +162,10 @@ def run_inference(
173162
vec=vec,
174163
guidance_vec=guidance_vec,
175164
)
176-
177165
vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config)
178166

179167
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
180-
latents, _, _, _ = jax.lax.fori_loop(0, len(timesteps) - 1, loop_body_p, (latents, transformer_state, c_ts, p_ts))
168+
latents, _, _, _ = jax.lax.fori_loop(0, len(c_ts), loop_body_p, (latents, transformer_state, c_ts, p_ts))
181169
image = vae_decode_p(latents)
182170
return image
183171

@@ -236,8 +224,7 @@ def get_clip_prompt_embeds(
236224

237225
prompt_embeds = text_encoder(text_input_ids, params=text_encoder.params, train=False)
238226
prompt_embeds = prompt_embeds.pooler_output
239-
prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=-1)
240-
prompt_embeds = np.reshape(prompt_embeds, (batch_size * num_images_per_prompt, -1))
227+
prompt_embeds = jnp.tile(prompt_embeds, (batch_size * num_images_per_prompt, 1))
241228
return prompt_embeds
242229

243230

@@ -300,7 +287,7 @@ def encode_prompt(
300287
max_sequence_length=max_sequence_length,
301288
)
302289

303-
text_ids = jnp.zeros((prompt_embeds.shape[0], prompt_embeds.shape[1], 3)).astype(jnp.bfloat16)
290+
text_ids = jnp.zeros((prompt_embeds.shape[1], 3)).astype(jnp.bfloat16)
304291
return prompt_embeds, pooled_prompt_embeds, text_ids
305292

306293

@@ -397,18 +384,14 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
397384
print("guidance.shape: ", guidance.shape, guidance.dtype)
398385
print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape, pooled_prompt_embeds.dtype)
399386

400-
timesteps = jnp.asarray([1.0] * global_batch_size, dtype=jnp.bfloat16)
401387
guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16)
402388

403-
validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds)
404-
405389
# move inputs to device and shard
406390
data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
407391
latents = jax.device_put(latents, data_sharding)
408-
latent_image_ids = jax.device_put(latent_image_ids, data_sharding)
392+
latent_image_ids = jax.device_put(latent_image_ids)
409393
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
410-
text_ids = jax.device_put(text_ids, data_sharding)
411-
timesteps = jax.device_put(timesteps, data_sharding)
394+
text_ids = jax.device_put(text_ids)
412395
guidance = jax.device_put(guidance, data_sharding)
413396
pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding)
414397

@@ -458,6 +441,19 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
458441
states["transformer"] = transformer_state
459442
states["vae"] = vae_state
460443

444+
# Setup timesteps
445+
timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1)
446+
# shifting the schedule to favor high timesteps for higher signal images
447+
if config.time_shift:
448+
# estimate mu based on linear estimation between two points
449+
lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift)
450+
mu = lin_function(latents.shape[1])
451+
timesteps = time_shift(mu, 1.0, timesteps)
452+
c_ts = timesteps[:-1]
453+
p_ts = timesteps[1:]
454+
455+
validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds)
456+
461457
p_run_inference = jax.jit(
462458
functools.partial(
463459
run_inference,
@@ -471,6 +467,8 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
471467
txt_ids=text_ids,
472468
vec=pooled_prompt_embeds,
473469
guidance_vec=guidance,
470+
c_ts=c_ts,
471+
p_ts=p_ts
474472
),
475473
in_shardings=(state_shardings,),
476474
out_shardings=None,

src/maxdiffusion/loaders/flux_lora_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def rename_for_interceptor(params_keys, network_alphas, adapter_name):
5353
new_layer_lora = layer_lora[: layer_lora.index(lora_name)]
5454
if new_layer_lora not in new_params_keys:
5555
new_params_keys.append(new_layer_lora)
56-
network_alpha = network_alphas[layer_lora]
56+
network_alpha = network_alphas.get(layer_lora, None)
5757
new_network_alphas[new_layer_lora] = network_alpha
5858
return new_params_keys, new_network_alphas
5959

src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,10 +144,10 @@ def __call__(self, hidden_states, temb, image_rotary_emb=None):
144144
hidden_states = self.linear2(attn_mlp)
145145
hidden_states = gate * hidden_states
146146
hidden_states = residual + hidden_states
147-
if hidden_states.dtype == jnp.float16 or hidden_states.dtype == jnp.bfloat16:
147+
if hidden_states.dtype == jnp.float16:
148148
hidden_states = jnp.clip(hidden_states, -65504, 65504)
149149

150-
return hidden_states, temb, image_rotary_emb
150+
return hidden_states
151151

152152

153153
class FluxTransformerBlock(nn.Module):
@@ -294,9 +294,9 @@ def __call__(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=
294294

295295
context_ff_output = self.txt_mlp(norm_encoder_hidden_states)
296296
encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
297-
if encoder_hidden_states.dtype == jnp.float16 or encoder_hidden_states.dtype == jnp.bfloat16:
297+
if encoder_hidden_states.dtype == jnp.float16:
298298
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
299-
return hidden_states, encoder_hidden_states, temb, image_rotary_emb
299+
return hidden_states, encoder_hidden_states
300300

301301

302302
@flax_register_to_config
@@ -504,7 +504,7 @@ def __call__(
504504
image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, ("activation_batch", "activation_embed"))
505505

506506
for double_block in self.double_blocks:
507-
hidden_states, encoder_hidden_states, temb, image_rotary_emb = double_block(
507+
hidden_states, encoder_hidden_states = double_block(
508508
hidden_states=hidden_states,
509509
encoder_hidden_states=encoder_hidden_states,
510510
temb=temb,
@@ -513,7 +513,7 @@ def __call__(
513513
hidden_states = jnp.concatenate([encoder_hidden_states, hidden_states], axis=1)
514514
hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed"))
515515
for single_block in self.single_blocks:
516-
hidden_states, temb, image_rotary_emb = single_block(
516+
hidden_states = single_block(
517517
hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
518518
)
519519
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,12 +229,21 @@ def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params,
229229
rank = None
230230
for pt_key, tensor in pt_state_dict.items():
231231
renamed_pt_key = rename_key(pt_key)
232-
print("renamed_pt_key:", renamed_pt_key)
233232
renamed_pt_key = renamed_pt_key.replace("lora_unet_", "")
234233
renamed_pt_key = renamed_pt_key.replace("lora_down", f"lora-{adapter_name}.down")
235234
renamed_pt_key = renamed_pt_key.replace("lora_up", f"lora-{adapter_name}.up")
236235

237236
if "double_blocks" in renamed_pt_key:
237+
renamed_pt_key = renamed_pt_key.replace("double_blocks.", "double_blocks_")
238+
renamed_pt_key = renamed_pt_key.replace("processor.proj_lora1.down", f"attn.i_proj.lora-{adapter_name}.down")
239+
renamed_pt_key = renamed_pt_key.replace("processor.proj_lora1.up", f"attn.i_proj.lora-{adapter_name}.up")
240+
renamed_pt_key = renamed_pt_key.replace("processor.proj_lora2.down", f"attn.e_proj.lora-{adapter_name}.down")
241+
renamed_pt_key = renamed_pt_key.replace("processor.proj_lora2.up", f"attn.e_proj.lora-{adapter_name}.up")
242+
renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora1.down", f"attn.i_qkv.lora-{adapter_name}.down")
243+
renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora1.up", f"attn.i_qkv.lora-{adapter_name}.up")
244+
renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora2.down", f"attn.e_qkv.lora-{adapter_name}.down")
245+
renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora2.up", f"attn.e_qkv.lora-{adapter_name}.up")
246+
238247
renamed_pt_key = renamed_pt_key.replace("_img_attn_proj", ".attn.i_proj")
239248
renamed_pt_key = renamed_pt_key.replace("_img_attn_qkv", ".attn.i_qkv")
240249
renamed_pt_key = renamed_pt_key.replace("_img_mlp_0", ".img_mlp.layers_0")

0 commit comments

Comments (0)