
Commit bffc7dc

Rebased on flux_lora and aligned flux_pipeline with changes in generate_flux.py
1 parent c11ad9d commit bffc7dc

1 file changed

Lines changed: 32 additions & 13 deletions

File tree

src/maxdiffusion/pipelines/flux/flux_pipeline.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from functools import partial
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, Callable
 
 import jax
 import jax.numpy as jnp
@@ -152,8 +152,8 @@ def prepare_latents(
 
   def prepare_latent_image_ids(self, height, width):
     latent_image_ids = jnp.zeros((height, width, 3))
-    latent_image_ids = latent_image_ids.at[..., 1].set(latent_image_ids[..., 1] + jnp.arange(height)[:, None])
-    latent_image_ids = latent_image_ids.at[..., 2].set(latent_image_ids[..., 2] + jnp.arange(width)[None, :])
+    latent_image_ids = latent_image_ids.at[..., 1].set(jnp.arange(height)[:, None])
+    latent_image_ids = latent_image_ids.at[..., 2].set(jnp.arange(width)[None, :])
 
     latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
 
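For reference, the simplified prepare_latent_image_ids body fills channel 1 with row indices and channel 2 with column indices; since the base tensor is all zeros, this gives the same IDs as the removed add-into-zeros form. A toy standalone sketch (illustrative only, not part of the commit):

import jax.numpy as jnp

height, width = 2, 3  # toy sizes
latent_image_ids = jnp.zeros((height, width, 3))
latent_image_ids = latent_image_ids.at[..., 1].set(jnp.arange(height)[:, None])  # row index per position
latent_image_ids = latent_image_ids.at[..., 2].set(jnp.arange(width)[None, :])   # column index per position
# Every position holds [0., row, col]; adding the same aranges into the zero
# tensor (the removed lines) yields exactly the same array.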
@@ -165,7 +165,6 @@ def get_clip_prompt_embeds(
       self, prompt: Union[str, List[str]], num_images_per_prompt: int, tokenizer: CLIPTokenizer, text_encoder: FlaxCLIPTextModel
   ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    batch_size = len(prompt)
     text_inputs = tokenizer(
         prompt,
         padding="max_length",
@@ -180,8 +179,7 @@ def get_clip_prompt_embeds(
 
     prompt_embeds = text_encoder(text_input_ids, params=text_encoder.params, train=False)
     prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=-1)
-    prompt_embeds = np.reshape(prompt_embeds, (batch_size * num_images_per_prompt, -1))
+    prompt_embeds = jnp.tile(prompt_embeds, (num_images_per_prompt, 1))
     return prompt_embeds
 
 
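The pooled CLIP output has shape (batch, hidden). np.repeat(..., axis=-1) duplicates individual feature values along the hidden axis, so the subsequent reshape mixed values from different features into each row; jnp.tile(..., (num_images_per_prompt, 1)) instead stacks whole copies of the embedding matrix along the batch axis, which also removes the need for the batch_size variable dropped in the previous hunk. A toy comparison (illustrative only):

import numpy as np
import jax.numpy as jnp

pooled = jnp.asarray([[1., 2., 3.],   # prompt 0
                      [4., 5., 6.]])  # prompt 1
n = 2  # num_images_per_prompt

# Removed approach: element-wise repeat on the feature axis, then reshape.
old = np.reshape(np.repeat(np.asarray(pooled), n, axis=-1), (pooled.shape[0] * n, -1))
# old == [[1. 1. 2.] [2. 3. 3.] [4. 4. 5.] [5. 6. 6.]] -- rows are no longer intact embeddings.

# New approach: tile whole rows along the batch axis.
new = jnp.tile(pooled, (n, 1))
# new == [[1. 2. 3.] [4. 5. 6.] [1. 2. 3.] [4. 5. 6.]] -- each prompt embedding repeated whole.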
@@ -260,7 +258,8 @@ def _generate(
       txt_ids,
       vec,
       guidance_vec,
-      timesteps,
+      c_ts,
+      p_ts
   ):
 
     def loop_body(
@@ -292,9 +291,6 @@ def loop_body(
       latents = jnp.array(latents, dtype=latents_dtype)
       return latents, state, c_ts, p_ts
 
-    c_ts = timesteps[:-1]
-    p_ts = timesteps[1:]
-
     loop_body_p = partial(
         loop_body,
         transformer=self.flux,
@@ -308,10 +304,28 @@ def loop_body(
     vae_decode_p = partial(self.vae_decode, vae=self.vae, state=vae_params, config=self._config)
 
     with self.mesh, nn_partitioning.axis_rules(self._config.logical_axis_rules):
-      latents, _, _, _ = jax.lax.fori_loop(0, len(timesteps) - 1, loop_body_p, (latents, flux_params, c_ts, p_ts))
+      latents, _, _, _ = jax.lax.fori_loop(0, len(c_ts), loop_body_p, (latents, flux_params, c_ts, p_ts))
       image = vae_decode_p(latents)
     return image
 
+  def do_time_shift(self, mu: float, sigma: float, t: Array):
+    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+
+  def get_lin_function(self, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
+    m = (y2 - y1) / (x2 - x1)
+    b = y1 - m * x1
+    return lambda x: m * x + b
+
+  def time_shift(self, latents, timesteps):
+    # estimate mu based on linear estimation between two points
+    lin_function = self.get_lin_function(x1=self._config.max_sequence_length,
+                                         y1=self._config.base_shift,
+                                         y2=self._config.max_shift)
+    mu = lin_function(latents.shape[1])
+    timesteps = self.do_time_shift(mu, 1.0, timesteps)
+    return timesteps
+
   def __call__(
       self,
       timesteps: int,
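The three new helpers implement a resolution-dependent timestep shift: get_lin_function interpolates mu linearly between (256, base_shift) and (4096, max_shift) based on the packed image-token count (latents.shape[1]), and do_time_shift remaps each timestep t to exp(mu) / (exp(mu) + (1/t - 1)**sigma). A standalone numeric sketch using the default anchor values from get_lin_function (assumed values, for illustration only):

import math

def get_lin_function(x1=256.0, y1=0.5, x2=4096.0, y2=1.15):
  m = (y2 - y1) / (x2 - x1)
  b = y1 - m * x1
  return lambda x: m * x + b

def do_time_shift(mu, sigma, t):
  return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

seq_len = 1024                    # e.g. a 32x32 packed latent grid
mu = get_lin_function()(seq_len)  # ~0.63 on the default (256, 0.5) -> (4096, 1.15) line
for t in (0.75, 0.5, 0.25):
  print(t, "->", round(do_time_shift(mu, 1.0, t), 3))
# With sigma = 1.0 the map simplifies to exp(mu)*t / (exp(mu)*t + 1 - t): a larger mu
# (i.e. a longer token sequence) pushes every timestep toward 1.0, so more of the
# schedule is spent at high noise levels.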
@@ -364,7 +378,11 @@ def __call__(
         rng=self.rng,
     )
 
-    #timesteps = jnp.asarray([1.0] * global_batch_size, dtype=jnp.bfloat16)
+    if self._config.time_shift:
+      timesteps = self.time_shift(latents, timesteps)
+    c_ts = timesteps[:-1]
+    p_ts = timesteps[1:]
+
     guidance = jnp.asarray([self._config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16)
 
     images = self._generate(
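With the shift applied in __call__, the schedule is split once into c_ts = timesteps[:-1] (the timestep each denoising step starts from) and p_ts = timesteps[1:] (the timestep it moves to), and the fori_loop inside _generate runs len(c_ts) iterations, one per pair. A toy sketch of the pairing, assuming a simple descending schedule (illustrative only):

import jax.numpy as jnp

num_steps = 4
timesteps = jnp.linspace(1.0, 0.0, num_steps + 1)  # [1.0, 0.75, 0.5, 0.25, 0.0]

c_ts = timesteps[:-1]  # where each step starts
p_ts = timesteps[1:]   # where each step ends

for i in range(len(c_ts)):
  print(f"step {i}: {float(c_ts[i]):.2f} -> {float(p_ts[i]):.2f}")
# len(c_ts) == num_steps, matching the new fori_loop upper bound.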
@@ -376,7 +394,8 @@ def __call__(
         text_ids,
         pooled_prompt_embeds,
         guidance,
-        timesteps,
+        c_ts,
+        p_ts
     )
 
     images = images
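Taken together, the loop now threads (latents, state, c_ts, p_ts) through jax.lax.fori_loop and indexes both schedules with the loop counter. A stripped-down sketch of that control flow; fake_step is a placeholder for the transformer prediction and Euler update, not pipeline code:

import jax
import jax.numpy as jnp

def fake_step(latents, t_curr, t_prev):
  # placeholder for the model call + update from t_curr to t_prev
  return latents + (t_prev - t_curr) * jnp.ones_like(latents)

def loop_body(i, carry):
  latents, state, c_ts, p_ts = carry
  latents = fake_step(latents, c_ts[i], p_ts[i])
  return latents, state, c_ts, p_ts

timesteps = jnp.linspace(1.0, 0.0, 5)
c_ts, p_ts = timesteps[:-1], timesteps[1:]
latents = jnp.zeros((1, 16, 64))  # (batch, tokens, channels), toy shape
state = {}                        # stands in for the transformer params
latents, _, _, _ = jax.lax.fori_loop(0, len(c_ts), loop_body, (latents, state, c_ts, p_ts))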
