Skip to content

Commit d16c020

Browse files
support both dev and schnell loading. Images still incorrect.
1 parent 601f40c commit d16c020

6 files changed

Lines changed: 99 additions & 27 deletions

File tree

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
2828
t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'
2929

3030
# Flux params
31+
flux_name: "flux-dev"
3132
max_sequence_length: 512
3233
time_shift: False
3334
base_shift: 0.5

src/maxdiffusion/configs/base_fux_schnell.yml renamed to src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,12 @@ gcs_metrics: False
2323
save_config_to_gcs: False
2424
log_period: 100
2525

26-
pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev'
26+
pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-schnell'
2727
clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
2828
t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'
2929

3030
# Flux params
31+
flux_name: "flux-schnell"
3132
max_sequence_length: 256
3233
time_shift: False
3334
base_shift: 0.5
@@ -208,10 +209,10 @@ prompt: "A magical castle in the middle of a forest, artistic drawing"
208209
prompt_2: "A magical castle in the middle of a forest, artistic drawing"
209210
negative_prompt: "purple, red"
210211
do_classifier_free_guidance: True
211-
guidance_scale: 3.5
212+
guidance_scale: 0.0
212213
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
213214
guidance_rescale: 0.0
214-
num_inference_steps: 20
215+
num_inference_steps: 4
215216

216217
# SDXL Lightning parameters
217218
lightning_from_pt: True

src/maxdiffusion/generate_flux.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def unpack(x: Array, height: int, width: int) -> Array:
5555
ph=2,
5656
pw=2,
5757
)
58-
58+
from einops import rearrange
5959
def vae_decode(latents, vae, state, config):
6060
img = unpack(x=latents, height=config.resolution, width=config.resolution)
6161
img = img / vae.config.scaling_factor + vae.config.shift_factor
@@ -87,6 +87,8 @@ def loop_body(
8787
guidance=guidance_vec,
8888
y=vec
8989
)
90+
jax.debug.print("*****pred max: {x}", x=np.max(pred))
91+
jax.debug.print("*****pred min: {x}", x=np.min(pred))
9092
latents = latents + (t_prev - t_curr) * pred
9193
latents = jnp.array(latents, dtype=latents_dtype)
9294
return latents, state, c_ts, p_ts
@@ -144,6 +146,8 @@ def run_inference(
144146
timesteps = time_shift(mu, 1.0, timesteps).tolist()
145147
c_ts = timesteps[:-1]
146148
p_ts = timesteps[1:]
149+
# jax.debug.print("c_ts: {x}", x=c_ts)
150+
# jax.debug.print("p_ts: {x}", x=p_ts)
147151

148152

149153
transformer_state = states["transformer"]
@@ -162,7 +166,7 @@ def run_inference(
162166
vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config)
163167

164168
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
165-
latents, _, _, _ = jax.lax.fori_loop(0, config.num_inference_steps, loop_body_p, (latents, transformer_state, c_ts, p_ts))
169+
latents, _, _, _ = jax.lax.fori_loop(0, len(timesteps) - 1, loop_body_p, (latents, transformer_state, c_ts, p_ts))
166170
image = vae_decode_p(latents)
167171
return image
168172

@@ -293,7 +297,8 @@ def encode_prompt(
293297
prompt=prompt_2,
294298
num_images_per_prompt=num_images_per_prompt,
295299
tokenizer=t5_tokenizer,
296-
text_encoder=t5_text_encoder
300+
text_encoder=t5_text_encoder,
301+
max_sequence_length=max_sequence_length
297302
)
298303

299304
text_ids = jnp.zeros((prompt_embeds.shape[0], prompt_embeds.shape[1], 3)).astype(jnp.bfloat16)
@@ -356,7 +361,7 @@ def run(config):
356361
rng=rng
357362
)
358363

359-
# LOAD TEXT ENCODERS - t5 on cpu
364+
# LOAD TEXT ENCODERS
360365
clip_text_encoder = FlaxCLIPTextModel.from_pretrained(
361366
config.pretrained_model_name_or_path,
362367
subfolder="text_encoder",
@@ -389,7 +394,8 @@ def run(config):
389394
clip_text_encoder=clip_text_encoder,
390395
t5_tokenizer=t5_tokenizer,
391396
t5_text_encoder=t5_encoder,
392-
num_images_per_prompt=global_batch_size
397+
num_images_per_prompt=global_batch_size,
398+
max_sequence_length=config.max_sequence_length
393399
)
394400

395401
def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds):
@@ -430,12 +436,12 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
430436

431437
get_memory_allocations()
432438
# evaluate shapes
433-
transformer_eval_params = transformer.init_weights(rngs=rng, max_sequence_length=512, eval_only=True)
439+
transformer_eval_params = transformer.init_weights(rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=True)
434440

435441
# loads pretrained weights
436-
transformer_params = load_flow_model("flux-dev", transformer_eval_params, "cpu")
442+
transformer_params = load_flow_model(config.flux_name, transformer_eval_params, "cpu")
437443
# create transformer state
438-
weights_init_fn = functools.partial(transformer.init_weights, rngs=rng, max_sequence_length=512, eval_only=False)
444+
weights_init_fn = functools.partial(transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False)
439445
transformer_state, transformer_state_shardings = setup_initial_state(
440446
model=transformer,
441447
tx=None,

src/maxdiffusion/models/flux/modules/layers.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,49 @@ def timestep_embedding(
111111
embedding = embedding.astype(t.dtype)
112112

113113
return embedding
114+
import numpy as np
115+
class PixArtAlphaTextProjection(nn.Module):
116+
hidden_dim: int
117+
dtype: jnp.dtype = jnp.float32
118+
weights_dtype: jnp.dtype = jnp.float32
119+
precision: jax.lax.Precision = None
120+
121+
@nn.compact
122+
def __call__(self, x: Array) -> Array:
123+
124+
hidden_states = nn.Dense(
125+
self.hidden_dim,
126+
use_bias=True,
127+
dtype=self.dtype,
128+
param_dtype=self.weights_dtype,
129+
precision=self.precision,
130+
kernel_init=nn.with_logical_partitioning(
131+
nn.initializers.lecun_normal(),
132+
("embed", "heads")
133+
),
134+
name="in_layer"
135+
)(x)
136+
jax.debug.print("PixArtAlphaTextProjection, in_layer min: {x}", x=np.min(hidden_states))
137+
jax.debug.print("PixArtAlphaTextProjection, in_layer max: {x}", x=np.max(hidden_states))
138+
hidden_states = nn.swish(hidden_states)
139+
jax.debug.print("PixArtAlphaTextProjection, act min: {x}", x=np.min(hidden_states))
140+
jax.debug.print("PixArtAlphaTextProjection, act max: {x}", x=np.max(hidden_states))
141+
hidden_states = nn.Dense(
142+
self.hidden_dim,
143+
use_bias=True,
144+
dtype=self.dtype,
145+
param_dtype=self.weights_dtype,
146+
precision=self.precision,
147+
kernel_init=nn.with_logical_partitioning(
148+
nn.initializers.lecun_normal(),
149+
("heads", "embed")
150+
),
151+
name="out_layer"
152+
)(hidden_states)
153+
jax.debug.print("PixArtAlphaTextProjection, out min: {x}", x=np.min(hidden_states))
154+
jax.debug.print("PixArtAlphaTextProjection, out max: {x}", x=np.max(hidden_states))
114155

156+
return hidden_states
115157

116158
class MLPEmbedder(nn.Module):
117159
hidden_dim: int

src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Dict, Optional, Tuple, Union
1818

1919
from einops import repeat, rearrange
20+
import numpy as np
2021
import jax
2122
import jax.numpy as jnp
2223
import flax.linen as nn
@@ -28,7 +29,8 @@
2829
EmbedND,
2930
DoubleStreamBlock,
3031
SingleStreamBlock,
31-
LastLayer
32+
LastLayer,
33+
PixArtAlphaTextProjection
3234
)
3335
from ...modeling_flax_utils import FlaxModelMixin
3436
from ....configuration_utils import ConfigMixin, flax_register_to_config
@@ -129,6 +131,9 @@ def __call__(
129131
inner_dim = self.num_attention_heads * self.attention_head_dim
130132
pe_dim = inner_dim // self.num_attention_heads
131133

134+
jax.debug.print("pooled_projections value min: {x}", x=np.min(y))
135+
jax.debug.print("pooled_projections value max: {x}", x=np.max(y))
136+
132137
img = nn.Dense(
133138
inner_dim,
134139
dtype=self.dtype,
@@ -140,39 +145,57 @@ def __call__(
140145
),
141146
name="img_in"
142147
)(img)
143-
148+
jax.debug.print("img.min: {x}", x=np.min(img))
149+
jax.debug.print("img.max: {x}", x=np.max(img))
150+
timestep = timestep_embedding(timesteps, 256)
151+
jax.debug.print("timestep.min: {x}", x=np.min(timestep))
152+
jax.debug.print("timestep.max: {x}", x=np.max(timestep))
144153
vec = MLPEmbedder(
145154
hidden_dim=inner_dim,
146155
dtype=self.dtype,
147156
weights_dtype=self.weights_dtype,
148157
precision=self.precision,
149158
name="time_in"
150-
)(timestep_embedding(timesteps, 256))
151-
159+
)(timestep)
160+
jax.debug.print("timestep.vec min: {x}", x=np.min(vec))
161+
jax.debug.print("timestep.vec max: {x}", x=np.max(vec))
162+
print(f"guidance_embeds? {self.guidance_embeds}")
152163
if self.guidance_embeds:
153164
if guidance is None:
154165
raise ValueError(
155166
"Didn't get guidance strength for guidance distilled model."
156167
)
168+
guidance_in = timestep_embedding(guidance, 256)
169+
170+
jax.debug.print("guidance_in.min: {x}", x=np.min(guidance_in))
171+
jax.debug.print("guidance_in.max: {x}", x=np.max(guidance_in))
157172
guidance_in = MLPEmbedder(
158173
hidden_dim=inner_dim,
159174
dtype=self.dtype,
160175
weights_dtype=self.weights_dtype,
161176
precision=self.precision,
162177
name="guidance_in"
163-
)(timestep_embedding(guidance, 256))
164-
else:
165-
guidance_in = Identity(timestep_embedding(guidance, 256))
166-
178+
)(guidance_in)
179+
jax.debug.print("guidance.vec min: {x}", x=np.min(guidance_in))
180+
jax.debug.print("guidance.vec max: {x}", x=np.max(guidance_in))
167181
vec = vec + guidance_in
168-
169-
vec = vec + MLPEmbedder(
182+
jax.debug.print("timestep_guidance.vec min: {x}", x=np.min(vec))
183+
jax.debug.print("timestep_guidance.vec max: {x}", x=np.max(vec))
184+
# else:
185+
# guidance_in = Identity()(timestep_embedding(guidance, 256))
186+
187+
pooled_projections = PixArtAlphaTextProjection(
170188
hidden_dim=inner_dim,
171189
dtype=self.dtype,
172190
weights_dtype=self.weights_dtype,
173191
precision=self.precision,
174192
name="vector_in"
175193
)(y)
194+
jax.debug.print("pooled_projections.min: {x}", x=np.min(pooled_projections))
195+
jax.debug.print("pooled_projections.max: {x}", x=np.max(pooled_projections))
196+
vec = vec + pooled_projections
197+
jax.debug.print("temb.min: {x}", x=np.min(vec))
198+
jax.debug.print("temb.max: {x}", x=np.max(vec))
176199

177200
txt = nn.Dense(
178201
inner_dim,
@@ -185,7 +208,8 @@ def __call__(
185208
),
186209
name="txt_in"
187210
)(txt)
188-
211+
jax.debug.print("txt.min: {x}", x=np.min(txt))
212+
jax.debug.print("txt.max: {x}", x=np.max(txt))
189213
ids = jnp.concatenate((txt_ids, img_ids), axis=1)
190214

191215
#pe_embedder
@@ -194,7 +218,8 @@ def __call__(
194218
theta=10000,
195219
axes_dim=self.axes_dims_rope
196220
)(ids)
197-
# breakpoint()
221+
jax.debug.print("pe.min: {x}", x=np.min(pe))
222+
jax.debug.print("pe.max: {x}", x=np.max(pe))
198223
# img, txt = DoubleStreamBlock(
199224
# hidden_size=inner_dim,
200225
# num_heads=self.num_attention_heads,

src/maxdiffusion/models/flux/util.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@
1818
)
1919
from maxdiffusion import max_logging
2020

21-
# from jflux.model import Flux, FluxParams
22-
from .port import port_flux
23-
2421
@dataclass
2522
class FluxParams:
2623
in_channels: int
@@ -42,7 +39,7 @@ def torch2jax(torch_tensor: torch.Tensor) -> Array:
4239
is_bfloat16 = torch_tensor.dtype == torch.bfloat16
4340
if is_bfloat16:
4441
# upcast the tensor to fp32
45-
torch_tensor = torch_tensor.to(dtype=torch.float32)
42+
torch_tensor = torch_tensor.float()
4643

4744
if torch.device.type != "cpu":
4845
torch_tensor = torch_tensor.to("cpu")

0 commit comments

Comments (0)