Skip to content

Commit 3eb5729

Browse files
Convert PyTorch weights to Flax and load the transformer state.
1 parent bb71982 commit 3eb5729

4 files changed

Lines changed: 154 additions & 80 deletions

File tree

src/maxdiffusion/generate_flux.py

Lines changed: 32 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,8 @@
3838
get_memory_allocations,
3939
create_device_mesh,
4040
get_flash_block_sizes,
41-
get_precision
41+
get_precision,
42+
setup_initial_state
4243
)
4344

4445
def prepare_latent_image_ids(height, width):
@@ -195,7 +196,7 @@ def run(config):
195196
devices_array = create_device_mesh(config)
196197
mesh = Mesh(devices_array, config.mesh_axes)
197198

198-
per_host_number_of_images = 1#config.per_device_batch_size * jax.local_device_count()
199+
per_host_number_of_images = config.per_device_batch_size * jax.local_device_count()
199200

200201
# LOAD VAE
201202

@@ -233,16 +234,6 @@ def run(config):
233234
rng=rng
234235
)
235236

236-
#load_flow_model("flux-dev", "cpu")
237-
238-
# transformer, params = FluxTransformer2DModel.from_pretrained(
239-
# config.pretrained_model_name_or_path,
240-
# subfolder="text_encoder_2",
241-
# from_pt=True,
242-
# dtype=config.weights_dtype
243-
# )
244-
245-
246237
# LOAD TEXT ENCODERS - t5 on cpu
247238
clip_text_encoder = FlaxCLIPTextModel.from_pretrained(
248239
config.pretrained_model_name_or_path,
@@ -303,17 +294,35 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
303294
pooled_prompt_embeds
304295
)
305296
get_memory_allocations()
306-
transformer_params = transformer.init_weights(rng, True)
307-
# transformer_params = transformer.init(
308-
# {"params" : rng},
309-
# img=latents,
310-
# img_ids=latent_image_ids,
311-
# txt=prompt_embeds,
312-
# txt_ids=text_ids,
313-
# timesteps=timesteps,
314-
# guidance=guidance,
315-
# y=pooled_prompt_embeds
316-
# )["params"]
297+
# evaluate shapes
298+
transformer_eval_params = transformer.init_weights(rngs=rng, max_sequence_length=512, eval_only=True)
299+
300+
# loads pretrained weights
301+
transformer_params = load_flow_model("flux-dev", transformer_eval_params, "cpu")
302+
get_memory_allocations()
303+
# create transformer state
304+
weights_init_fn = functools.partial(transformer.init_weights, rngs=rng, max_sequence_length=512, eval_only=False)
305+
transformer_state, transformer_state_shardings = setup_initial_state(
306+
model=transformer,
307+
tx=None,
308+
config=config,
309+
mesh=mesh,
310+
weights_init_fn=weights_init_fn,
311+
model_params=None,
312+
training=False
313+
)
314+
breakpoint()
315+
transformer_state = transformer_state.replace(params=transformer_params)
316+
img = transformer.apply(
317+
{"params" : transformer_state.params},
318+
img=latents,
319+
img_ids=latent_image_ids,
320+
txt=prompt_embeds,
321+
txt_ids=text_ids,
322+
timesteps=timesteps,
323+
guidance=guidance,
324+
y=pooled_prompt_embeds
325+
)
317326
get_memory_allocations()
318327
breakpoint()
319328

src/maxdiffusion/models/flux/modules/layers.py

Lines changed: 50 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -53,11 +53,13 @@ class QKNorm(nn.Module):
5353
def __call__(self, q: Array, k: Array, v: Array) -> tuple[Array, Array]:
5454
q = nn.RMSNorm(
5555
dtype=self.dtype,
56-
param_dtype=self.weights_dtype
56+
param_dtype=self.weights_dtype,
57+
name="query_norm"
5758
)(q)
5859
k = nn.RMSNorm(
5960
dtype=self.dtype,
60-
param_dtype=self.weights_dtype
61+
param_dtype=self.weights_dtype,
62+
name="key_norm"
6163
)(k)
6264
return q, k
6365

@@ -173,7 +175,8 @@ def __call__(self, vec: Array) -> tuple[ModulationOut, ModulationOut | None]:
173175
kernel_init=nn.with_logical_partitioning(
174176
nn.initializers.lecun_normal(),
175177
("embed", "heads")
176-
)
178+
),
179+
name="lin"
177180
)(nn.silu(vec))
178181

179182
out = jnp.split(lin[:, None, :], multiplier, axis=-1)
@@ -205,14 +208,16 @@ def __call__(self, x: Array, vec: Array, pe: Array) -> Array:
205208
double=False,
206209
dtype=self.dtype,
207210
weights_dtype=self.weights_dtype,
208-
precision=self.precision
211+
precision=self.precision,
212+
name="modulation"
209213
)(vec)
210214
x_mod = (1 + mod.scale) * nn.LayerNorm(
211215
use_scale=False,
212216
use_bias=False,
213217
epsilon=1e-6,
214218
dtype=self.dtype,
215-
param_dtype=self.weights_dtype
219+
param_dtype=self.weights_dtype,
220+
name="pre_norm"
216221
)(x) + mod.shift
217222

218223
x_mod = nn.Dense(
@@ -231,7 +236,8 @@ def __call__(self, x: Array, vec: Array, pe: Array) -> Array:
231236
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
232237
q, k = QKNorm(
233238
dtype=self.dtype,
234-
weights_dtype=self.weights_dtype
239+
weights_dtype=self.weights_dtype,
240+
name="norm"
235241
)(q, k, v)
236242

237243
q, k = apply_rope(q, k, pe)
@@ -286,15 +292,17 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
286292
double=True,
287293
dtype=self.dtype,
288294
weights_dtype=self.weights_dtype,
289-
precision=self.precision
295+
precision=self.precision,
296+
name="img_mod"
290297
)(vec)
291298

292299
txt_mod1, txt_mod2 = Modulation(
293300
self.hidden_size,
294301
double=True,
295302
dtype=self.dtype,
296303
weights_dtype=self.weights_dtype,
297-
precision=self.precision
304+
precision=self.precision,
305+
name="txt_mod"
298306
)(vec)
299307

300308
# prepare image for attention
@@ -303,7 +311,8 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
303311
use_bias=False,
304312
epsilon=1e-6,
305313
dtype=self.dtype,
306-
param_dtype=self.weights_dtype
314+
param_dtype=self.weights_dtype,
315+
name="img_norm1"
307316
)(img)
308317
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
309318
img_qkv = nn.Dense(
@@ -315,14 +324,16 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
315324
kernel_init=nn.with_logical_partitioning(
316325
nn.initializers.lecun_normal(),
317326
("embed", "heads")
318-
)
327+
),
328+
name="img_attn_qkv"
319329
)(img_modulated)
320330
img_q, img_k, img_v = rearrange(
321331
img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
322332
)
323333
img_q, img_k = QKNorm(
324334
dtype=self.dtype,
325-
weights_dtype=self.weights_dtype
335+
weights_dtype=self.weights_dtype,
336+
name="img_attn_norm"
326337
)(img_q, img_k, img_v)
327338

328339
# prepare text for attention
@@ -331,7 +342,8 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
331342
use_bias=False,
332343
epsilon=1e-6,
333344
dtype=self.dtype,
334-
param_dtype=self.weights_dtype
345+
param_dtype=self.weights_dtype,
346+
name="txt_norm1"
335347
)(txt)
336348
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
337349
txt_qkv = nn.Dense(
@@ -343,14 +355,16 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
343355
kernel_init=nn.with_logical_partitioning(
344356
nn.initializers.lecun_normal(),
345357
("embed", "heads")
346-
)
358+
),
359+
name="txt_attn_qkv"
347360
)(txt_modulated)
348361
txt_q, txt_k, txt_v = rearrange(
349362
txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
350363
)
351364
txt_q, txt_k = QKNorm(
352365
dtype=self.dtype,
353-
weights_dtype=self.weights_dtype
366+
weights_dtype=self.weights_dtype,
367+
name="txt_attn_norm"
354368
)(txt_q, txt_k, txt_v)
355369

356370
# run actual attention
@@ -385,6 +399,7 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
385399
nn.initializers.lecun_normal(),
386400
("heads", "embed")
387401
),
402+
name="img_attn_proj"
388403
)(img_attn)
389404
img = img + img_mod2.gate * nn.Sequential(
390405
[
@@ -397,7 +412,8 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
397412
kernel_init=nn.with_logical_partitioning(
398413
nn.initializers.lecun_normal(),
399414
("embed", "heads")
400-
)
415+
),
416+
name="img_mlp_0"
401417
),
402418
nn.gelu,
403419
nn.Dense(
@@ -408,14 +424,16 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
408424
kernel_init=nn.with_logical_partitioning(
409425
nn.initializers.lecun_normal(),
410426
("heads", "embed")
411-
)
412-
)
413-
]
427+
),
428+
name="img_mlp_2"
429+
),
430+
],
414431
)(
415432
(1 + img_mod2.scale) * nn.LayerNorm(
416433
use_scale=False,
417434
use_bias=False,
418-
param_dtype=self.weights_dtype
435+
param_dtype=self.weights_dtype,
436+
name="img_norm2"
419437
)(img) + img_mod2.shift
420438
)
421439

@@ -430,6 +448,7 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
430448
nn.initializers.lecun_normal(),
431449
("heads", "embed")
432450
),
451+
name="txt_attn_proj"
433452
)(txt_attn)
434453
txt = txt + txt_mod1.gate * txt_proj
435454

@@ -444,7 +463,8 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
444463
kernel_init=nn.with_logical_partitioning(
445464
nn.initializers.lecun_normal(),
446465
("embed", "heads")
447-
)
466+
),
467+
name="txt_mlp_0"
448468
),
449469
nn.gelu,
450470
nn.Dense(
@@ -455,14 +475,16 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
455475
kernel_init=nn.with_logical_partitioning(
456476
nn.initializers.lecun_normal(),
457477
("heads", "embed")
458-
)
459-
)
460-
]
478+
),
479+
name="txt_mlp_2"
480+
),
481+
],
461482
)(
462483
(1 + txt_mod2.scale) * nn.LayerNorm(
463484
use_scale=False,
464485
use_bias=False,
465-
param_dtype=self.weights_dtype
486+
param_dtype=self.weights_dtype,
487+
name="txt_norm2"
466488
)(txt) + txt_mod2.shift
467489
)
468490

@@ -491,8 +513,9 @@ def __call__(self, x: Array, vec: Array) -> Array:
491513
kernel_init=nn.with_logical_partitioning(
492514
nn.initializers.lecun_normal(),
493515
("embed", "heads")
494-
)
495-
)
516+
),
517+
name="adaLN_modulation_1"
518+
),
496519
]
497520
)(vec), 2, axis=1
498521
)
@@ -515,5 +538,5 @@ def __call__(self, x: Array, vec: Array) -> Array:
515538
("heads", "embed")
516539
),
517540
name="linear"
518-
)
541+
)(x)
519542
return x

0 commit comments

Comments (0)