Skip to content

Commit bb71982

Browse files
finish transformer
1 parent cb91d5e commit bb71982

4 files changed

Lines changed: 252 additions & 93 deletions

File tree

src/maxdiffusion/configs/base_flux.yml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ save_config_to_gcs: False
2424
log_period: 100
2525

2626
pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev'
27+
clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
28+
t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'
29+
2730
unet_checkpoint: ''
2831
revision: 'refs/pr/95'
2932
# This will convert the weights to this dtype.
@@ -41,7 +44,7 @@ precision: "DEFAULT"
4144
# Set true to load weights from pytorch
4245
from_pt: False
4346
split_head_dim: True
44-
attention: 'dot_product' # Supported attention: dot_product, flash
47+
attention: 'flash' # Supported attention: dot_product, flash
4548
flash_block_sizes: {}
4649
# GroupNorm groups
4750
norm_num_groups: 32
@@ -171,7 +174,7 @@ max_train_steps: 200
171174
num_train_epochs: 1
172175
seed: 0
173176
output_dir: 'sdxl-model-finetuned'
174-
per_device_batch_size: 2
177+
per_device_batch_size: 1
175178

176179
warmup_steps_fraction: 0.0
177180
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.

src/maxdiffusion/generate_flux.py

Lines changed: 59 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,30 @@
1616

1717
from typing import Any, Callable, Dict, List, Optional, Union, Sequence
1818
from absl import app
19-
19+
import functools
2020
import numpy as np
2121
import jax
22+
from jax.sharding import Mesh, PositionalSharding
2223
import jax.numpy as jnp
2324
from chex import Array
2425
from transformers import (
2526
CLIPTokenizer,
2627
FlaxCLIPTextModel,
2728
T5TokenizerFast,
28-
T5EncoderModel
29+
T5EncoderModel,
30+
FlaxT5EncoderModel
2931
)
3032

3133
from maxdiffusion import FlaxAutoencoderKL
3234
from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
33-
3435
from maxdiffusion import pyconfig
36+
from max_utils import (
37+
device_put_replicated,
38+
get_memory_allocations,
39+
create_device_mesh,
40+
get_flash_block_sizes,
41+
get_precision
42+
)
3543

3644
def prepare_latent_image_ids(height, width):
3745
latent_image_ids = jnp.zeros((height, width, 3))
@@ -133,19 +141,17 @@ def get_t5_prompt_embeds(
133141
truncation=True,
134142
return_length=False,
135143
return_overflowing_tokens=False,
136-
return_tensors="pt"
144+
return_tensors="np"
137145
)
138146
text_input_ids = text_inputs.input_ids
139-
140147
prompt_embeds = text_encoder(text_input_ids, output_hidden_states=False)[0]
141148
dtype = text_encoder.dtype
142-
prompt_embeds = prompt_embeds.to(dtype=dtype)
149+
prompt_embeds = prompt_embeds.astype(dtype)
143150

144151
_, seq_len, _ = prompt_embeds.shape
145-
146152
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
147-
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
148-
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
153+
prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1))
154+
prompt_embeds = jnp.reshape(prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1))
149155

150156
return prompt_embeds
151157

@@ -178,15 +184,16 @@ def encode_prompt(
178184
tokenizer=t5_tokenizer,
179185
text_encoder=t5_text_encoder
180186
)
181-
prompt_embeds = jnp.asarray(prompt_embeds.detach().numpy())
182187

183188
text_ids = jnp.zeros((prompt_embeds.shape[0], prompt_embeds.shape[1], 3)).astype(jnp.bfloat16)
184189
return prompt_embeds, pooled_prompt_embeds, text_ids
185190

186191
def run(config):
187192
from maxdiffusion.models.flux.util import load_flow_model
188193

189-
rng = jax.random.PRNGKey(config.seed)
194+
rng = jax.random.key(config.seed)
195+
devices_array = create_device_mesh(config)
196+
mesh = Mesh(devices_array, config.mesh_axes)
190197

191198
per_host_number_of_images = 1#config.per_device_batch_size * jax.local_device_count()
192199

@@ -201,11 +208,18 @@ def run(config):
201208
)
202209
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
203210

204-
# LOAD UNET
205-
211+
# LOAD TRANSFORMER
212+
flash_block_sizes = get_flash_block_sizes(config)
206213
transformer = FluxTransformer2DModel.from_config(
207214
config.pretrained_model_name_or_path,
208-
subfolder="transformer"
215+
subfolder="transformer",
216+
mesh=mesh,
217+
split_head_dim=config.split_head_dim,
218+
attention_kernel=config.attention,
219+
flash_block_sizes=flash_block_sizes,
220+
dtype=config.activations_dtype,
221+
weights_dtype=config.weights_dtype,
222+
precision=get_precision(config)
209223
)
210224

211225
num_channels_latents = transformer.in_channels // 4
@@ -242,34 +256,40 @@ def run(config):
242256
dtype=config.weights_dtype
243257
)
244258

245-
t5_encoder_pt = T5EncoderModel.from_pretrained(
246-
config.pretrained_model_name_or_path,
247-
subfolder="text_encoder_2",
259+
t5_encoder = FlaxT5EncoderModel.from_pretrained(
260+
config.t5xxl_model_name_or_path,
261+
dtype=config.weights_dtype
248262
)
249-
250263
t5_tokenizer = T5TokenizerFast.from_pretrained(
251264
config.pretrained_model_name_or_path,
252265
subfolder="tokenizer_2",
253266
)
254267

268+
encoders_sharding = PositionalSharding(devices_array).replicate()
269+
partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
270+
clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
271+
clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)
272+
t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params)
273+
t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params)
274+
255275
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
256276
prompt=config.prompt,
257277
prompt_2=config.prompt_2,
258278
clip_tokenizer=clip_tokenizer,
259279
clip_text_encoder=clip_text_encoder,
260280
t5_tokenizer=t5_tokenizer,
261-
t5_text_encoder=t5_encoder_pt,
281+
t5_text_encoder=t5_encoder,
262282
num_images_per_prompt=per_host_number_of_images
263283
)
264284

265285
def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds):
266-
print("latents.shape: ", latents.shape)
267-
print("latent_image_ids.shape: ", latent_image_ids.shape)
268-
print("text_ids.shape: ", text_ids.shape)
269-
print("prompt_embeds: ", prompt_embeds.shape)
270-
print("timesteps.shape: ", timesteps.shape)
271-
print("guidance.shape: ", guidance.shape)
272-
print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape)
286+
print("latents.shape: ", latents.shape, latents.dtype)
287+
print("latent_image_ids.shape: ", latent_image_ids.shape, latent_image_ids.dtype)
288+
print("text_ids.shape: ", text_ids.shape, text_ids.dtype)
289+
print("prompt_embeds: ", prompt_embeds.shape, prompt_embeds.dtype)
290+
print("timesteps.shape: ", timesteps.shape, timesteps.dtype)
291+
print("guidance.shape: ", guidance.shape, guidance.dtype)
292+
print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape, pooled_prompt_embeds.dtype)
273293

274294
timesteps = jnp.asarray([1.0], dtype=jnp.bfloat16)
275295
guidance = jnp.asarray([3.5], dtype=jnp.bfloat16)
@@ -282,17 +302,19 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
282302
guidance,
283303
pooled_prompt_embeds
284304
)
285-
286-
transformer_params = transformer.init(
287-
{"params" : rng},
288-
img=latents,
289-
img_ids=latent_image_ids,
290-
txt=prompt_embeds,
291-
txt_ids=text_ids,
292-
timesteps=timesteps,
293-
guidance=guidance,
294-
y=pooled_prompt_embeds
295-
)["params"]
305+
get_memory_allocations()
306+
transformer_params = transformer.init_weights(rng, True)
307+
# transformer_params = transformer.init(
308+
# {"params" : rng},
309+
# img=latents,
310+
# img_ids=latent_image_ids,
311+
# txt=prompt_embeds,
312+
# txt_ids=text_ids,
313+
# timesteps=timesteps,
314+
# guidance=guidance,
315+
# y=pooled_prompt_embeds
316+
# )["params"]
317+
get_memory_allocations()
296318
breakpoint()
297319

298320

src/maxdiffusion/models/flux/modules/layers.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,9 @@ def __call__(self, vec: Array) -> tuple[ModulationOut, ModulationOut | None]:
177177
)(nn.silu(vec))
178178

179179
out = jnp.split(lin[:, None, :], multiplier, axis=-1)
180-
181180
return (
182181
ModulationOut(*out[:3]),
183-
ModulationOut(*out[3:] if self.double else None)
182+
ModulationOut(*out[3:]) if self.double else None
184183
)
185184

186185
class SingleStreamBlock(nn.Module):
@@ -209,7 +208,6 @@ def __call__(self, x: Array, vec: Array, pe: Array) -> Array:
209208
precision=self.precision
210209
)(vec)
211210
x_mod = (1 + mod.scale) * nn.LayerNorm(
212-
self.hidden_size,
213211
use_scale=False,
214212
use_bias=False,
215213
epsilon=1e-6,
@@ -261,7 +259,7 @@ def __call__(self, x: Array, vec: Array, pe: Array) -> Array:
261259
("embed", "heads")
262260
),
263261
name="linear2"
264-
)(jnp.concatenate((attn, nn.genu(mlp)), 2))
262+
)(jnp.concatenate((attn, nn.gelu(mlp)), 2))
265263
return x + mod.gate * output
266264

267265
class DoubleStreamBlock(nn.Module):
@@ -279,7 +277,7 @@ class DoubleStreamBlock(nn.Module):
279277
attention_kernel: str = "dot_product"
280278

281279
@nn.compact
282-
def __call__(self, img: Array, txt: Array, vec: Array, pe: Array):
280+
def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array, Array]:
283281

284282
mlp_hidden_dim = int(self.hidden_size * self.mlp_ratio)
285283

@@ -422,7 +420,7 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array):
422420
)
423421

424422
# calculate the txt blocks
425-
txt = txt + txt_mod1.gate * nn.Dense(
423+
txt_proj = nn.Dense(
426424
self.hidden_size,
427425
use_bias=True,
428426
dtype=self.dtype,
@@ -433,6 +431,8 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array):
433431
("heads", "embed")
434432
),
435433
)(txt_attn)
434+
txt = txt + txt_mod1.gate * txt_proj
435+
436436
txt = txt + txt_mod2.gate * nn.Sequential(
437437
[
438438
nn.Dense(
@@ -466,4 +466,54 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array):
466466
)(txt) + txt_mod2.shift
467467
)
468468

469-
return img, txt
469+
return img, txt
470+
471+
class LastLayer(nn.Module):
472+
hidden_size: int
473+
patch_size: int
474+
out_channels: int
475+
dtype: jnp.dtype = jnp.float32
476+
weights_dtype: jnp.dtype = jnp.float32
477+
precision: jax.lax.Precision = None
478+
479+
@nn.compact
480+
def __call__(self, x: Array, vec: Array) -> Array:
481+
shift, scale = jnp.split(
482+
nn.Sequential(
483+
[
484+
nn.silu,
485+
nn.Dense(
486+
2 * self.hidden_size,
487+
use_bias=True,
488+
param_dtype=self.weights_dtype,
489+
dtype=self.dtype,
490+
precision=self.precision,
491+
kernel_init=nn.with_logical_partitioning(
492+
nn.initializers.lecun_normal(),
493+
("embed", "heads")
494+
)
495+
)
496+
]
497+
)(vec), 2, axis=1
498+
)
499+
norm_final = nn.LayerNorm(
500+
epsilon=1e-6,
501+
use_scale=False,
502+
use_bias=False,
503+
param_dtype=self.weights_dtype,
504+
name="norm_final"
505+
)(x)
506+
x = (1 + scale[:, None, :]) * norm_final + shift[:, None, :]
507+
x = nn.Dense(
508+
self.patch_size * self.patch_size * self.out_channels,
509+
use_bias=True,
510+
param_dtype=self.weights_dtype,
511+
dtype=self.dtype,
512+
precision=self.precision,
513+
kernel_init=nn.with_logical_partitioning(
514+
nn.initializers.lecun_normal(),
515+
("heads", "embed")
516+
),
517+
name="linear"
518+
)(x)
519+
return x

0 commit comments

Comments
 (0)