1313# limitations under the License.
1414
1515from functools import partial
16- from typing import Dict , List , Optional , Union
16+ from typing import Dict , List , Optional , Union , Callable
1717
1818import jax
1919import jax .numpy as jnp
@@ -152,8 +152,8 @@ def prepare_latents(
152152
153153 def prepare_latent_image_ids (self , height , width ):
154154 latent_image_ids = jnp .zeros ((height , width , 3 ))
155- latent_image_ids = latent_image_ids .at [..., 1 ].set (latent_image_ids [..., 1 ] + jnp .arange (height )[:, None ])
156- latent_image_ids = latent_image_ids .at [..., 2 ].set (latent_image_ids [..., 2 ] + jnp .arange (width )[None , :])
155+ latent_image_ids = latent_image_ids .at [..., 1 ].set (jnp .arange (height )[:, None ])
156+ latent_image_ids = latent_image_ids .at [..., 2 ].set (jnp .arange (width )[None , :])
157157
158158 latent_image_id_height , latent_image_id_width , latent_image_id_channels = latent_image_ids .shape
159159
@@ -165,7 +165,6 @@ def get_clip_prompt_embeds(
165165 self , prompt : Union [str , List [str ]], num_images_per_prompt : int , tokenizer : CLIPTokenizer , text_encoder : FlaxCLIPTextModel
166166 ):
167167 prompt = [prompt ] if isinstance (prompt , str ) else prompt
168- batch_size = len (prompt )
169168 text_inputs = tokenizer (
170169 prompt ,
171170 padding = "max_length" ,
@@ -180,8 +179,7 @@ def get_clip_prompt_embeds(
180179
181180 prompt_embeds = text_encoder (text_input_ids , params = text_encoder .params , train = False )
182181 prompt_embeds = prompt_embeds .pooler_output
183- prompt_embeds = np .repeat (prompt_embeds , num_images_per_prompt , axis = - 1 )
184- prompt_embeds = np .reshape (prompt_embeds , (batch_size * num_images_per_prompt , - 1 ))
182+ prompt_embeds = jnp .tile (prompt_embeds , (num_images_per_prompt , 1 ))
185183 return prompt_embeds
186184
187185
@@ -260,7 +258,8 @@ def _generate(
260258 txt_ids ,
261259 vec ,
262260 guidance_vec ,
263- timesteps ,
261+ c_ts ,
262+ p_ts
264263 ):
265264
266265 def loop_body (
@@ -292,9 +291,6 @@ def loop_body(
292291 latents = jnp .array (latents , dtype = latents_dtype )
293292 return latents , state , c_ts , p_ts
294293
295- c_ts = timesteps [:- 1 ]
296- p_ts = timesteps [1 :]
297-
298294 loop_body_p = partial (
299295 loop_body ,
300296 transformer = self .flux ,
@@ -308,10 +304,28 @@ def loop_body(
308304 vae_decode_p = partial (self .vae_decode , vae = self .vae , state = vae_params , config = self ._config )
309305
310306 with self .mesh , nn_partitioning .axis_rules (self ._config .logical_axis_rules ):
311- latents , _ , _ , _ = jax .lax .fori_loop (0 , len (timesteps ) - 1 , loop_body_p , (latents , flux_params , c_ts , p_ts ))
307+ latents , _ , _ , _ = jax .lax .fori_loop (0 , len (c_ts ) , loop_body_p , (latents , flux_params , c_ts , p_ts ))
312308 image = vae_decode_p (latents )
313309 return image
314310
311+ def do_time_shift (self , mu : float , sigma : float , t : Array ):
312+ return math .exp (mu ) / (math .exp (mu ) + (1 / t - 1 ) ** sigma )
313+
314+
315+ def get_lin_function (self , x1 : float = 256 , y1 : float = 0.5 , x2 : float = 4096 , y2 : float = 1.15 ) -> Callable [[float ], float ]:
316+ m = (y2 - y1 ) / (x2 - x1 )
317+ b = y1 - m * x1
318+ return lambda x : m * x + b
319+
320+ def time_shift (self , latents , timesteps ):
321+ # estimate mu based on linear estimation between two points
322+ lin_function = self .get_lin_function (x1 = self ._config .max_sequence_length ,
323+ y1 = self ._config .base_shift ,
324+ y2 = self ._config .max_shift )
325+ mu = lin_function (latents .shape [1 ])
326+ timesteps = self .do_time_shift (mu , 1.0 , timesteps )
327+ return timesteps
328+
315329 def __call__ (
316330 self ,
317331 timesteps : int ,
@@ -364,7 +378,11 @@ def __call__(
364378 rng = self .rng ,
365379 )
366380
367- #timesteps = jnp.asarray([1.0] * global_batch_size, dtype=jnp.bfloat16)
381+ if self ._config .time_shift :
382+ timesteps = self .time_shift (latents , timesteps )
383+ c_ts = timesteps [:- 1 ]
384+ p_ts = timesteps [1 :]
385+
368386 guidance = jnp .asarray ([self ._config .guidance_scale ] * global_batch_size , dtype = jnp .bfloat16 )
369387
370388 images = self ._generate (
@@ -376,7 +394,8 @@ def __call__(
376394 text_ids ,
377395 pooled_prompt_embeds ,
378396 guidance ,
379- timesteps ,
397+ c_ts ,
398+ p_ts
380399 )
381400
382401 images = images
0 commit comments