From e1779354e41fed3baab82c1658d3de5b1f12bc2a Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 18 Mar 2025 03:44:15 +0000 Subject: [PATCH 1/8] fixes accidental wrong merge. --- src/maxdiffusion/generate_flux.py | 2 +- src/maxdiffusion/max_utils.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 59564a271..e6e5bcb0e 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -469,7 +469,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep max_logging.log(f"Compile time: {t1 - t0:.1f}s.") t0 = time.perf_counter() - with ExitStack() as stack, jax.profiler.trace("/home/jfacevedo/trace/"): + with ExitStack() as stack, jax.profiler.trace("/tmp/trace/"): _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] imgs = p_run_inference(states).block_until_ready() t1 = time.perf_counter() diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 2d37e9416..3dff39e3c 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -424,8 +424,6 @@ def setup_initial_state( if model_params: state = state.replace(params=model_params) state = jax.device_put(state, state_mesh_shardings) - if model_params: - state = state.replace(params=model_params) state = unbox_logicallypartioned_trainstate(state) From 7d3dbd49f4ac0eb32cdef14cac9b4faaa98396c7 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 18 Mar 2025 04:16:39 +0000 Subject: [PATCH 2/8] precompile generate functions with different dimensions. --- src/maxdiffusion/generate_flux.py | 97 ++++++++++++++++++------------- 1 file changed, 58 insertions(+), 39 deletions(-) diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index e6e5bcb0e..867c5077f 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -76,8 +76,8 @@ def unpack(x: Array, height: int, width: int) -> Array: ) -def vae_decode(latents, vae, state, config): - img = unpack(x=latents.astype(jnp.float32), height=config.resolution, width=config.resolution) +def vae_decode(latents, vae, state, config, resolution): + img = unpack(x=latents.astype(jnp.float32), height=resolution, width=resolution) img = img / vae.config.scaling_factor + vae.config.shift_factor img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample return img @@ -135,7 +135,7 @@ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: flo def run_inference( - states, transformer, vae, config, mesh, latents, latent_image_ids, prompt_embeds, txt_ids, vec, guidance_vec, c_ts, p_ts + states, transformer, vae, config, resolution, mesh, latents, latent_image_ids, prompt_embeds, txt_ids, vec, guidance_vec, c_ts, p_ts ): transformer_state = states["transformer"] @@ -150,7 +150,7 @@ def run_inference( vec=vec, guidance_vec=guidance_vec, ) - vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config) + vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config, resolution=resolution) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): latents, _, _, _ = jax.lax.fori_loop(0, len(c_ts), loop_body_p, (latents, transformer_state, c_ts, p_ts)) @@ -376,8 +376,6 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep # move inputs to device and shard data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) - latents = jax.device_put(latents, data_sharding) - latent_image_ids = jax.device_put(latent_image_ids) prompt_embeds = jax.device_put(prompt_embeds, data_sharding) text_ids = jax.device_put(text_ids) guidance = jax.device_put(guidance, data_sharding) @@ -429,45 +427,66 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep states["transformer"] = transformer_state states["vae"] = vae_state - # Setup timesteps - timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1) - # shifting the schedule to favor high timesteps for higher signal images - if config.time_shift: - # estimate mu based on linear estimation between two points - lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift) - mu = lin_function(latents.shape[1]) - timesteps = time_shift(mu, 1.0, timesteps) - c_ts = timesteps[:-1] - p_ts = timesteps[1:] - - validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds) - - p_run_inference = jax.jit( - functools.partial( - run_inference, - transformer=transformer, - vae=vae, - config=config, - mesh=mesh, - latents=latents, - latent_image_ids=latent_image_ids, - prompt_embeds=prompt_embeds, - txt_ids=text_ids, - vec=pooled_prompt_embeds, - guidance_vec=guidance, - c_ts=c_ts, - p_ts=p_ts, - ), - in_shardings=(state_shardings,), - out_shardings=None, - ) + #validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds) + + resolutions = [1024, 768, 512] + p_jitted = {} + for resolution in resolutions: + latents, latent_image_ids = prepare_latents( + batch_size=global_batch_size, + num_channels_latents=num_channels_latents, + height=resolution, + width=resolution, + dtype=jnp.bfloat16, + vae_scale_factor=vae_scale_factor, + rng=rng, + ) + latents = jax.device_put(latents, data_sharding) + latent_image_ids = jax.device_put(latent_image_ids) + + # Setup timesteps + timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1) + # shifting the schedule to favor high timesteps for higher signal images + if config.time_shift: + # estimate mu based on linear estimation between two points + lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift) + mu = lin_function(latents.shape[1]) + timesteps = time_shift(mu, 1.0, timesteps) + c_ts = timesteps[:-1] + p_ts = timesteps[1:] + + p_run_inference = jax.jit( + functools.partial( + run_inference, + transformer=transformer, + vae=vae, + config=config, + resolution=resolution, + mesh=mesh, + latents=latents, + latent_image_ids=latent_image_ids, + prompt_embeds=prompt_embeds, + txt_ids=text_ids, + vec=pooled_prompt_embeds, + guidance_vec=guidance, + c_ts=c_ts, + p_ts=p_ts, + ), + in_shardings=(state_shardings,), + out_shardings=None, + ) + with ExitStack() as stack: + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + p_run_inference(states).block_until_ready() + p_jitted[resolution] = p_run_inference + breakpoint() t0 = time.perf_counter() with ExitStack() as stack: _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] p_run_inference(states).block_until_ready() t1 = time.perf_counter() max_logging.log(f"Compile time: {t1 - t0:.1f}s.") - + breakpoint() t0 = time.perf_counter() with ExitStack() as stack, jax.profiler.trace("/tmp/trace/"): _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] From 3b40223dfc2a66b1a213221f713b96ef6e8eb7db Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 18 Mar 2025 05:16:35 +0000 Subject: [PATCH 3/8] iterate over different resolutions and store precompiled functions in dict. --- src/maxdiffusion/generate_flux.py | 181 +++++++++++++++++------------- 1 file changed, 103 insertions(+), 78 deletions(-) diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 867c5077f..220f4b564 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -77,7 +77,7 @@ def unpack(x: Array, height: int, width: int) -> Array: def vae_decode(latents, vae, state, config, resolution): - img = unpack(x=latents.astype(jnp.float32), height=resolution, width=resolution) + img = unpack(x=latents.astype(jnp.float32), height=resolution[1], width=resolution[0]) img = img / vae.config.scaling_factor + vae.config.shift_factor img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample return img @@ -322,15 +322,6 @@ def run(config): ) num_channels_latents = transformer.in_channels // 4 - latents, latent_image_ids = prepare_latents( - batch_size=global_batch_size, - num_channels_latents=num_channels_latents, - height=config.resolution, - width=config.resolution, - dtype=jnp.bfloat16, - vae_scale_factor=vae_scale_factor, - rng=rng, - ) # LOAD TEXT ENCODERS clip_text_encoder = FlaxCLIPTextModel.from_pretrained( @@ -352,17 +343,6 @@ def run(config): t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params) t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params) - prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( - prompt=config.prompt, - prompt_2=config.prompt_2, - clip_tokenizer=clip_tokenizer, - clip_text_encoder=clip_text_encoder, - t5_tokenizer=t5_tokenizer, - t5_text_encoder=t5_encoder, - num_images_per_prompt=global_batch_size, - max_sequence_length=config.max_sequence_length, - ) - def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds): print("latents.shape: ", latents.shape, latents.dtype) print("latent_image_ids.shape: ", latent_image_ids.shape, latent_image_ids.dtype) @@ -374,13 +354,6 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16) - # move inputs to device and shard - data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) - prompt_embeds = jax.device_put(prompt_embeds, data_sharding) - text_ids = jax.device_put(text_ids) - guidance = jax.device_put(guidance, data_sharding) - pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding) - if config.offload_encoders: cpus = jax.devices("cpu") t5_encoder.params = jax.device_put(t5_encoder.params, device=cpus[0]) @@ -427,58 +400,110 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep states["transformer"] = transformer_state states["vae"] = vae_state - #validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds) - - resolutions = [1024, 768, 512] + resolutions = [ + (768, 768), + (768, 1024), + (1024, 768), + (1024, 1024), + (896, 1152), + (1152, 896), + (1920, 1080), + (1080, 1920) + ] p_jitted = {} for resolution in resolutions: - latents, latent_image_ids = prepare_latents( - batch_size=global_batch_size, - num_channels_latents=num_channels_latents, - height=resolution, - width=resolution, - dtype=jnp.bfloat16, - vae_scale_factor=vae_scale_factor, - rng=rng, - ) - latents = jax.device_put(latents, data_sharding) - latent_image_ids = jax.device_put(latent_image_ids) - - # Setup timesteps - timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1) - # shifting the schedule to favor high timesteps for higher signal images - if config.time_shift: - # estimate mu based on linear estimation between two points - lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift) - mu = lin_function(latents.shape[1]) - timesteps = time_shift(mu, 1.0, timesteps) - c_ts = timesteps[:-1] - p_ts = timesteps[1:] - - p_run_inference = jax.jit( - functools.partial( - run_inference, - transformer=transformer, - vae=vae, - config=config, - resolution=resolution, - mesh=mesh, - latents=latents, - latent_image_ids=latent_image_ids, - prompt_embeds=prompt_embeds, - txt_ids=text_ids, - vec=pooled_prompt_embeds, - guidance_vec=guidance, - c_ts=c_ts, - p_ts=p_ts, - ), - in_shardings=(state_shardings,), - out_shardings=None, - ) - with ExitStack() as stack: - _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] - p_run_inference(states).block_until_ready() - p_jitted[resolution] = p_run_inference + max_logging.log(f"Resolutions: {resolution}") + for _ in range(5): + s0 = time.perf_counter() + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + prompt=config.prompt, + prompt_2=config.prompt_2, + clip_tokenizer=clip_tokenizer, + clip_text_encoder=clip_text_encoder, + t5_tokenizer=t5_tokenizer, + t5_text_encoder=t5_encoder, + num_images_per_prompt=global_batch_size, + max_sequence_length=config.max_sequence_length, + ) + max_logging.log(f"text encoding time: {(time.perf_counter() - s0)}") + latents, latent_image_ids = prepare_latents( + batch_size=global_batch_size, + num_channels_latents=num_channels_latents, + height=resolution[1], + width=resolution[0], + dtype=jnp.bfloat16, + vae_scale_factor=vae_scale_factor, + rng=rng, + ) + + # move inputs to device and shard + s0 = time.perf_counter() + data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) + prompt_embeds = jax.device_put(prompt_embeds, data_sharding) + text_ids = jax.device_put(text_ids) + guidance = jax.device_put(guidance, data_sharding) + pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding) + latents = jax.device_put(latents, data_sharding) + latent_image_ids = jax.device_put(latent_image_ids) + max_logging.log(f"Moving to device time: {(time.perf_counter() - s0)}") + + # Setup timesteps + timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1) + # shifting the schedule to favor high timesteps for higher signal images + if config.time_shift: + # estimate mu based on linear estimation between two points + lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift) + mu = lin_function(latents.shape[1]) + timesteps = time_shift(mu, 1.0, timesteps) + c_ts = timesteps[:-1] + p_ts = timesteps[1:] + validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds) + p_run_inference = p_jitted.get(resolution, None) + if p_run_inference is None: + print("FN not found, compiling...") + p_run_inference = jax.jit( + functools.partial( + run_inference, + transformer=transformer, + vae=vae, + config=config, + resolution=resolution, + mesh=mesh, + latents=latents, + latent_image_ids=latent_image_ids, + prompt_embeds=prompt_embeds, + txt_ids=text_ids, + vec=pooled_prompt_embeds, + guidance_vec=guidance, + c_ts=c_ts, + p_ts=p_ts, + ), + ) + p_jitted[resolution] = p_run_inference + with ExitStack() as stack: + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + s0 = time.perf_counter() + imgs = p_run_inference( + states, + latents = latents, + latent_image_ids=latent_image_ids, + prompt_embeds=prompt_embeds, + txt_ids=text_ids, + vec=pooled_prompt_embeds, + ).block_until_ready() + max_logging.log(f"inference time: {(time.perf_counter() - s0)}") + s0 = time.perf_counter() + imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True) + max_logging.log(f"Gathering all time: {(time.perf_counter() - s0)}") + s0 = time.perf_counter() + imgs = np.array(imgs) + imgs = (imgs * 0.5 + 0.5).clip(0, 1) + imgs = np.transpose(imgs, (0, 2, 3, 1)) + imgs = np.uint8(imgs * 255) + for i, image in enumerate(imgs): + Image.fromarray(image).save(f"flux_{resolution[0]}_{resolution[1]}_{i}.png") + max_logging.log(f"Saving images time: {(time.perf_counter() - s0)}") + get_memory_allocations() breakpoint() t0 = time.perf_counter() with ExitStack() as stack: From e2e2c505958588c853edbbf7493968896f5061a3 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 19 Mar 2025 00:07:00 +0000 Subject: [PATCH 4/8] adds a new inference file to show how to precompile for different resolutions. --- .../configs/base_flux_dev_multi_res.yml | 271 +++++++++ src/maxdiffusion/generate_flux.py | 172 ++---- src/maxdiffusion/generate_flux_multi_res.py | 575 ++++++++++++++++++ 3 files changed, 909 insertions(+), 109 deletions(-) create mode 100644 src/maxdiffusion/configs/base_flux_dev_multi_res.yml create mode 100644 src/maxdiffusion/generate_flux_multi_res.py diff --git a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml new file mode 100644 index 000000000..e0597360d --- /dev/null +++ b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml @@ -0,0 +1,271 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This sentinel is a reminder to choose a real run name. +run_name: '' + +metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. +# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ +write_metrics: True +gcs_metrics: False +# If true save config to GCS in {base_output_directory}/{run_name}/ +save_config_to_gcs: False +log_period: 100 + +pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev' +clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax' +t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax' + +# Flux params +flux_name: "flux-dev" +max_sequence_length: 512 +time_shift: True +base_shift: 0.5 +max_shift: 1.15 +# offloads t5 encoder after text encoding to save memory. +offload_encoders: True + + +unet_checkpoint: '' +revision: 'refs/pr/95' +# This will convert the weights to this dtype. +# When running inference on TPUv5e, use weights_dtype: 'bfloat16' +weights_dtype: 'bfloat16' +# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) +activations_dtype: 'bfloat16' + +# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision +# Options are "DEFAULT", "HIGH", "HIGHEST" +# fp32 activations and fp32 weights with HIGHEST will provide the best precision +# at the cost of time. +precision: "DEFAULT" + +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. +jit_initializers: True + +# Set true to load weights from pytorch +from_pt: True +split_head_dim: True +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te + +#flash_block_sizes: {} +# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem. +flash_block_sizes: { + "block_q" : 1536, + "block_kv_compute" : 1536, + "block_kv" : 1536, + "block_q_dkv" : 1536, + "block_kv_dkv" : 1536, + "block_kv_dkv_compute" : 1536, + "block_q_dq" : 1536, + "block_kv_dq" : 1536 +} +# GroupNorm groups +norm_num_groups: 32 + +# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch +# else they will be loaded from pretrained_model_name_or_path +train_new_unet: False + +# train text_encoder - Currently not supported for SDXL +train_text_encoder: False +text_encoder_learning_rate: 4.25e-6 + +# https://arxiv.org/pdf/2305.08891.pdf +snr_gamma: -1.0 + +timestep_bias: { + # a value of later will increase the frequence of the model's final training steps. + # none, earlier, later, range + strategy: "none", + # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. + multiplier: 1.0, + # when using strategy=range, the beginning (inclusive) timestep to bias. + begin: 0, + # when using strategy=range, the final step (inclusive) to bias. + end: 1000, + # portion of timesteps to bias. + # 0.5 will bias one half of the timesteps. Value of strategy determines + # whether the biased portions are in the earlier or later timesteps. + portion: 0.25 +} + +# Override parameters from checkpoints's scheduler. +diffusion_scheduler_config: { + _class_name: 'FlaxEulerDiscreteScheduler', + prediction_type: 'epsilon', + rescale_zero_terminal_snr: False, + timestep_spacing: 'trailing' +} + +# Output directory +# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" +base_output_directory: "" + +# Hardware +hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' + +# Parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] + +# batch : batch dimension of data and activations +# hidden : +# embed : attention qkv dense layer hidden dim named as embed +# heads : attention head dim = num_heads * head_dim +# length : attention sequence length +# temb_in : dense.shape[0] of resnet dense before conv +# out_c : dense.shape[1] of resnet dense before conv +# out_channels : conv.shape[-1] activation +# keep_1 : conv.shape[0] weight +# keep_2 : conv.shape[1] weight +# conv_in : conv.shape[2] weight +# conv_out : conv.shape[-1] weight +logical_axis_rules: [ + ['batch', 'data'], + ['activation_batch', ['data','fsdp']], + ['activation_heads', 'tensor'], + ['activation_kv', 'tensor'], +# ['embed','fsdp'], + ['mlp',['fsdp','tensor']], + ['heads', 'tensor'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'fsdp'], + ] +data_sharding: [['data', 'fsdp', 'tensor']] + +# One axis for each parallelism type may hold a placeholder (-1) +# value to auto-shard based on available slices and devices. +# By default, product of the DCN axes should equal number of slices +# and product of the ICI axes should equal number of devices per slice. +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: -1 +ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + +# Dataset +# Replace with dataset path or train_data_dir. One has to be set. +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tf' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location +dataset_save_location: '/tmp/pokemon-gpt4-captions_xl' +train_data_dir: '' +dataset_config_name: '' +jax_cache_dir: '' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' +image_column: 'image' +caption_column: 'text' +resolution: 1024 +center_crop: False +random_flip: False +# If cache_latents_text_encoder_outputs is True +# the num_proc is set to 1 +tokenize_captions_num_proc: 4 +transform_images_num_proc: 4 +reuse_example_batch: False +enable_data_shuffling: True + +# checkpoint every number of samples, -1 means don't checkpoint. +checkpoint_every: -1 +# enables one replica to read the ckpt then broadcast to the rest +enable_single_replica_ckpt_restoring: False + +# Training loop +learning_rate: 4.e-7 +scale_lr: False +max_train_samples: -1 +# max_train_steps takes priority over num_train_epochs. +max_train_steps: 200 +num_train_epochs: 1 +seed: 0 +output_dir: 'sdxl-model-finetuned' +per_device_batch_size: 1 + +warmup_steps_fraction: 0.0 +learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. + +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. + +# AdamW optimizer parameters +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. +adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. +adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. +adam_weight_decay: 1.e-2 # AdamW Weight decay +max_grad_norm: 1.0 + +enable_profiler: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. +skip_first_n_steps_for_profiler: 5 +profiler_steps: 10 + +# Generation parameters +prompt: "A magical castle in the middle of a forest, artistic drawing" +prompt_2: "A magical castle in the middle of a forest, artistic drawing" +negative_prompt: "purple, red" +do_classifier_free_guidance: True +guidance_scale: 3.5 +# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf +guidance_rescale: 0.0 +num_inference_steps: 50 + +# SDXL Lightning parameters +lightning_from_pt: True +# Empty or "ByteDance/SDXL-Lightning" to enable lightning. +lightning_repo: "" +# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. +lightning_ckpt: "" + +# LoRA parameters +# Values are lists to support multiple LoRA loading during inference in the future. +lora_config: { + lora_model_name_or_path: [], + weight_name: [], + adapter_name: [], + scale: [], + from_pt: [] +} +# Ex with values: +# lora_config : { +# lora_model_name_or_path: ["ByteDance/Hyper-SD"], +# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], +# adapter_name: ["hyper-sdxl"], +# scale: [0.7], +# from_pt: [True] +# } + +enable_mllog: False + +#controlnet +controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' +controlnet_from_pt: True +controlnet_conditioning_scale: 0.5 +controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' +quantization: '' +# Shard the range finding operation for quantization. By default this is set to number of slices. +quantization_local_shard_count: -1 +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. + diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 220f4b564..ac7f4cd90 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -76,8 +76,8 @@ def unpack(x: Array, height: int, width: int) -> Array: ) -def vae_decode(latents, vae, state, config, resolution): - img = unpack(x=latents.astype(jnp.float32), height=resolution[1], width=resolution[0]) +def vae_decode(latents, vae, state, config): + img = unpack(x=latents.astype(jnp.float32), height=config.resolution, width=config.resolution) img = img / vae.config.scaling_factor + vae.config.shift_factor img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample return img @@ -135,7 +135,7 @@ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: flo def run_inference( - states, transformer, vae, config, resolution, mesh, latents, latent_image_ids, prompt_embeds, txt_ids, vec, guidance_vec, c_ts, p_ts + states, transformer, vae, config, mesh, latents, latent_image_ids, prompt_embeds, txt_ids, vec, guidance_vec, c_ts, p_ts ): transformer_state = states["transformer"] @@ -150,7 +150,7 @@ def run_inference( vec=vec, guidance_vec=guidance_vec, ) - vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config, resolution=resolution) + vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): latents, _, _, _ = jax.lax.fori_loop(0, len(c_ts), loop_body_p, (latents, transformer_state, c_ts, p_ts)) @@ -226,7 +226,6 @@ def get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) - text_inputs = tokenizer( prompt, truncation=True, @@ -244,7 +243,6 @@ def get_t5_prompt_embeds( # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) prompt_embeds = jnp.reshape(prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1)) - return prompt_embeds @@ -322,6 +320,15 @@ def run(config): ) num_channels_latents = transformer.in_channels // 4 + latents, latent_image_ids = prepare_latents( + batch_size=global_batch_size, + num_channels_latents=num_channels_latents, + height=config.resolution, + width=config.resolution, + dtype=jnp.bfloat16, + vae_scale_factor=vae_scale_factor, + rng=rng, + ) # LOAD TEXT ENCODERS clip_text_encoder = FlaxCLIPTextModel.from_pretrained( @@ -343,6 +350,17 @@ def run(config): t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params) t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params) + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + prompt=config.prompt, + prompt_2=config.prompt_2, + clip_tokenizer=clip_tokenizer, + clip_text_encoder=clip_text_encoder, + t5_tokenizer=t5_tokenizer, + t5_text_encoder=t5_encoder, + num_images_per_prompt=global_batch_size, + max_sequence_length=config.max_sequence_length, + ) + def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds): print("latents.shape: ", latents.shape, latents.dtype) print("latent_image_ids.shape: ", latent_image_ids.shape, latent_image_ids.dtype) @@ -354,6 +372,15 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16) + # move inputs to device and shard + data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) + latents = jax.device_put(latents, data_sharding) + latent_image_ids = jax.device_put(latent_image_ids) + prompt_embeds = jax.device_put(prompt_embeds, data_sharding) + text_ids = jax.device_put(text_ids) + guidance = jax.device_put(guidance, data_sharding) + pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding) + if config.offload_encoders: cpus = jax.devices("cpu") t5_encoder.params = jax.device_put(t5_encoder.params, device=cpus[0]) @@ -400,118 +427,45 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep states["transformer"] = transformer_state states["vae"] = vae_state - resolutions = [ - (768, 768), - (768, 1024), - (1024, 768), - (1024, 1024), - (896, 1152), - (1152, 896), - (1920, 1080), - (1080, 1920) - ] - p_jitted = {} - for resolution in resolutions: - max_logging.log(f"Resolutions: {resolution}") - for _ in range(5): - s0 = time.perf_counter() - prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( - prompt=config.prompt, - prompt_2=config.prompt_2, - clip_tokenizer=clip_tokenizer, - clip_text_encoder=clip_text_encoder, - t5_tokenizer=t5_tokenizer, - t5_text_encoder=t5_encoder, - num_images_per_prompt=global_batch_size, - max_sequence_length=config.max_sequence_length, - ) - max_logging.log(f"text encoding time: {(time.perf_counter() - s0)}") - latents, latent_image_ids = prepare_latents( - batch_size=global_batch_size, - num_channels_latents=num_channels_latents, - height=resolution[1], - width=resolution[0], - dtype=jnp.bfloat16, - vae_scale_factor=vae_scale_factor, - rng=rng, - ) - - # move inputs to device and shard - s0 = time.perf_counter() - data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) - prompt_embeds = jax.device_put(prompt_embeds, data_sharding) - text_ids = jax.device_put(text_ids) - guidance = jax.device_put(guidance, data_sharding) - pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding) - latents = jax.device_put(latents, data_sharding) - latent_image_ids = jax.device_put(latent_image_ids) - max_logging.log(f"Moving to device time: {(time.perf_counter() - s0)}") - - # Setup timesteps - timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1) - # shifting the schedule to favor high timesteps for higher signal images - if config.time_shift: - # estimate mu based on linear estimation between two points - lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift) - mu = lin_function(latents.shape[1]) - timesteps = time_shift(mu, 1.0, timesteps) - c_ts = timesteps[:-1] - p_ts = timesteps[1:] - validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds) - p_run_inference = p_jitted.get(resolution, None) - if p_run_inference is None: - print("FN not found, compiling...") - p_run_inference = jax.jit( - functools.partial( - run_inference, - transformer=transformer, - vae=vae, - config=config, - resolution=resolution, - mesh=mesh, - latents=latents, - latent_image_ids=latent_image_ids, - prompt_embeds=prompt_embeds, - txt_ids=text_ids, - vec=pooled_prompt_embeds, - guidance_vec=guidance, - c_ts=c_ts, - p_ts=p_ts, - ), - ) - p_jitted[resolution] = p_run_inference - with ExitStack() as stack: - _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] - s0 = time.perf_counter() - imgs = p_run_inference( - states, - latents = latents, + # Setup timesteps + timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1) + # shifting the schedule to favor high timesteps for higher signal images + if config.time_shift: + # estimate mu based on linear estimation between two points + lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift) + mu = lin_function(latents.shape[1]) + timesteps = time_shift(mu, 1.0, timesteps) + c_ts = timesteps[:-1] + p_ts = timesteps[1:] + + validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds) + + p_run_inference = jax.jit( + functools.partial( + run_inference, + transformer=transformer, + vae=vae, + config=config, + mesh=mesh, + latents=latents, latent_image_ids=latent_image_ids, prompt_embeds=prompt_embeds, txt_ids=text_ids, vec=pooled_prompt_embeds, - ).block_until_ready() - max_logging.log(f"inference time: {(time.perf_counter() - s0)}") - s0 = time.perf_counter() - imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True) - max_logging.log(f"Gathering all time: {(time.perf_counter() - s0)}") - s0 = time.perf_counter() - imgs = np.array(imgs) - imgs = (imgs * 0.5 + 0.5).clip(0, 1) - imgs = np.transpose(imgs, (0, 2, 3, 1)) - imgs = np.uint8(imgs * 255) - for i, image in enumerate(imgs): - Image.fromarray(image).save(f"flux_{resolution[0]}_{resolution[1]}_{i}.png") - max_logging.log(f"Saving images time: {(time.perf_counter() - s0)}") - get_memory_allocations() - breakpoint() + guidance_vec=guidance, + c_ts=c_ts, + p_ts=p_ts, + ), + in_shardings=(state_shardings,), + out_shardings=None, + ) t0 = time.perf_counter() with ExitStack() as stack: _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] p_run_inference(states).block_until_ready() t1 = time.perf_counter() max_logging.log(f"Compile time: {t1 - t0:.1f}s.") - breakpoint() + t0 = time.perf_counter() with ExitStack() as stack, jax.profiler.trace("/tmp/trace/"): _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] @@ -542,4 +496,4 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": - app.run(main) + app.run(main) \ No newline at end of file diff --git a/src/maxdiffusion/generate_flux_multi_res.py b/src/maxdiffusion/generate_flux_multi_res.py new file mode 100644 index 000000000..d1b9b177e --- /dev/null +++ b/src/maxdiffusion/generate_flux_multi_res.py @@ -0,0 +1,575 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from typing import Callable, List, Union, Sequence +from absl import app +from contextlib import ExitStack +import functools +import math +import time +import numpy as np +from PIL import Image +import jax +from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P +import jax.numpy as jnp +import flax.linen as nn +from chex import Array +from einops import rearrange +from flax.linen import partitioning as nn_partitioning +from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer) + +from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging +from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel +from maxdiffusion.max_utils import ( + device_put_replicated, + get_memory_allocations, + create_device_mesh, + get_flash_block_sizes, + get_precision, + setup_initial_state, +) +from maxdiffusion.loaders.flux_lora_pipeline import FluxLoraLoaderMixin + + +def maybe_load_flux_lora(config, lora_loader, params): + def _noop_interceptor(next_fn, args, kwargs, context): + return next_fn(*args, **kwargs) + + lora_config = config.lora_config + interceptors = [_noop_interceptor] + if len(lora_config["lora_model_name_or_path"]) > 0: + interceptors = [] + for i in range(len(lora_config["lora_model_name_or_path"])): + params, rank, network_alphas = lora_loader.load_lora_weights( + config, + lora_config["lora_model_name_or_path"][i], + weight_name=lora_config["weight_name"][i], + params=params, + adapter_name=lora_config["adapter_name"][i], + ) + interceptor = lora_loader.make_lora_interceptor(params, rank, network_alphas, lora_config["adapter_name"][i]) + interceptors.append(interceptor) + return params, interceptors + + +def unpack(x: Array, height: int, width: int, vae_scale_factor: int) -> Array: + batch_size, num_patches, channels = x.shape + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = int(2 * (width // (vae_scale_factor * 2))) + + x = jnp.reshape(x, (batch_size, height // 2, width // 2, channels // 4, 2, 2)) + x = jnp.transpose(x, (0, 3, 1, 4, 2, 5)) + x = jnp.reshape(x, (batch_size, channels // (2 * 2), height, width)) + + return x + +def vae_decode(latents, vae, state, vae_scale_factor, resolution): + img = unpack(x=latents.astype(jnp.float32), height=resolution[0], width=resolution[1], vae_scale_factor=vae_scale_factor) + img = img / vae.config.scaling_factor + vae.config.shift_factor + img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample + return img + + +def loop_body( + step, + args, + transformer, + latent_image_ids, + prompt_embeds, + txt_ids, + vec, + guidance_vec, +): + latents, state, c_ts, p_ts = args + latents_dtype = latents.dtype + t_curr = c_ts[step] + t_prev = p_ts[step] + t_vec = jnp.full((latents.shape[0],), t_curr, dtype=latents.dtype) + pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + img_ids=latent_image_ids, + encoder_hidden_states=prompt_embeds, + txt_ids=txt_ids, + timestep=t_vec, + guidance=guidance_vec, + pooled_projections=vec, + ).sample + latents = latents + (t_prev - t_curr) * pred + latents = jnp.array(latents, dtype=latents_dtype) + return latents, state, c_ts, p_ts + + +def prepare_latent_image_ids(height, width): + latent_image_ids = jnp.zeros((height, width, 3)) + latent_image_ids = latent_image_ids.at[..., 1].set(jnp.arange(height)[:, None]) + latent_image_ids = latent_image_ids.at[..., 2].set(jnp.arange(width)[None, :]) + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape(latent_image_id_height * latent_image_id_width, latent_image_id_channels) + return latent_image_ids.astype(jnp.bfloat16) + + +def time_shift(mu: float, sigma: float, t: Array): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16 +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + +def run_inference( + states, + transformer, + vae, + config, + resolution, + mesh, + latents, + latent_image_ids, + prompt_embeds, + txt_ids, + vec, + guidance_vec, + c_ts, + p_ts, + vae_scale_factor +): + + transformer_state = states["transformer"] + vae_state = states["vae"] + + loop_body_p = functools.partial( + loop_body, + transformer=transformer, + latent_image_ids=latent_image_ids, + prompt_embeds=prompt_embeds, + txt_ids=txt_ids, + vec=vec, + guidance_vec=guidance_vec, + ) + vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, vae_scale_factor=vae_scale_factor, resolution=resolution) + + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + latents, _, _, _ = jax.lax.fori_loop(0, len(c_ts), loop_body_p, (latents, transformer_state, c_ts, p_ts)) + image = vae_decode_p(latents) + return image + + +def pack_latents( + latents: Array, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, +): + latents = jnp.reshape(latents, (batch_size, num_channels_latents, height // 2, 2, width // 2, 2)) + latents = jnp.permute_dims(latents, (0, 2, 4, 1, 3, 5)) + latents = jnp.reshape(latents, (batch_size, (height // 2) * (width // 2), num_channels_latents * 4)) + + return latents + + +def prepare_latents( + batch_size: int, num_channels_latents: int, height: int, width: int, vae_scale_factor: int, dtype: jnp.dtype, rng: Array +): + + # VAE applies 8x compression on images but we must also account for packing which + # requires latent height and width to be divisibly by 2. + height = 2 * (height // (vae_scale_factor * 2)) + width = 2 * (width // (vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + latents = jax.random.normal(rng, shape=shape, dtype=dtype) + # pack latents + latents = pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = prepare_latent_image_ids(height // 2, width // 2) + latent_image_ids = jnp.tile(latent_image_ids, (batch_size, 1, 1)) + + return latents, latent_image_ids + +def tokenize_clip(prompt: Union[str, List[str]], tokenizer: CLIPTokenizer): + prompt = [prompt] if isinstance(prompt, str) else prompt + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="np", + ) + return text_inputs.input_ids + +def get_clip_prompt_embeds( + prompt: Union[str, List[str]], num_images_per_prompt: int, tokenizer: CLIPTokenizer, text_encoder: FlaxCLIPTextModel +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="np", + ) + + text_input_ids = text_inputs.input_ids + + prompt_embeds = text_encoder(text_input_ids, params=text_encoder.params, train=False) + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = jnp.tile(prompt_embeds, (batch_size * num_images_per_prompt, 1)) + return prompt_embeds + +def tokenize_t5(prompt: Union[str, List[str]], tokenizer: AutoTokenizer, max_sequence_length: int = 512): + prompt = [prompt] if isinstance(prompt, str) else prompt + text_inputs = tokenizer( + prompt, + truncation=True, + max_length=max_sequence_length, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="np", + ) + return text_inputs.input_ids + +def get_t5_prompt_embeds( + prompt: Union[str, List[str]], + num_images_per_prompt: int, + tokenizer: AutoTokenizer, + text_encoder: T5EncoderModel, + max_sequence_length: int = 512, +): + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + text_inputs = tokenizer( + prompt, + truncation=True, + max_length=max_sequence_length, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder(text_input_ids, attention_mask=None, output_hidden_states=False)["last_hidden_state"] + dtype = text_encoder.dtype + prompt_embeds = prompt_embeds.astype(dtype) + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) + prompt_embeds = jnp.reshape(prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1)) + return prompt_embeds + +def encode_prompt( + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + clip_tokenizer: CLIPTokenizer, + clip_text_encoder: FlaxCLIPTextModel, + t5_tokenizer: AutoTokenizer, + t5_text_encoder: T5EncoderModel, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, +): + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_2 = prompt or prompt_2 + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + pooled_prompt_embeds = get_clip_prompt_embeds( + prompt=prompt, num_images_per_prompt=num_images_per_prompt, tokenizer=clip_tokenizer, text_encoder=clip_text_encoder + ) + + prompt_embeds = get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + tokenizer=t5_tokenizer, + text_encoder=t5_text_encoder, + max_sequence_length=max_sequence_length, + ) + + text_ids = jnp.zeros((prompt_embeds.shape[1], 3)).astype(jnp.bfloat16) + return prompt_embeds, pooled_prompt_embeds, text_ids + +def run(config): + from maxdiffusion.models.flux.util import load_flow_model + + rng = jax.random.key(config.seed) + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + global_batch_size = config.per_device_batch_size * jax.local_device_count() + + # LOAD VAE + + vae, vae_params = FlaxAutoencoderKL.from_pretrained( + config.pretrained_model_name_or_path, subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16" + ) + + weights_init_fn = functools.partial(vae.init_weights, rng=rng) + vae_state, vae_state_shardings = setup_initial_state( + model=vae, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + model_params=vae_params, + training=False, + ) + + vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + + # LOAD TRANSFORMER + flash_block_sizes = get_flash_block_sizes(config) + transformer = FluxTransformer2DModel.from_config( + config.pretrained_model_name_or_path, + subfolder="transformer", + mesh=mesh, + split_head_dim=config.split_head_dim, + attention_kernel=config.attention, + flash_block_sizes=flash_block_sizes, + dtype=config.activations_dtype, + weights_dtype=config.weights_dtype, + precision=get_precision(config), + ) + + num_channels_latents = transformer.in_channels // 4 + + # LOAD TEXT ENCODERS + clip_text_encoder = FlaxCLIPTextModel.from_pretrained( + config.pretrained_model_name_or_path, subfolder="text_encoder", from_pt=True, dtype=config.weights_dtype + ) + clip_tokenizer = CLIPTokenizer.from_pretrained( + config.pretrained_model_name_or_path, subfolder="tokenizer", dtype=config.weights_dtype + ) + + t5_encoder = FlaxT5EncoderModel.from_pretrained(config.t5xxl_model_name_or_path, dtype=config.weights_dtype) + t5_tokenizer = AutoTokenizer.from_pretrained( + config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True + ) + + encoders_sharding = PositionalSharding(devices_array).replicate() + partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding) + clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params) + clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params) + t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params) + t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params) + + def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds): + print("latents.shape: ", latents.shape, latents.dtype) + print("latent_image_ids.shape: ", latent_image_ids.shape, latent_image_ids.dtype) + print("text_ids.shape: ", text_ids.shape, text_ids.dtype) + print("prompt_embeds: ", prompt_embeds.shape, prompt_embeds.dtype) + print("timesteps.shape: ", timesteps.shape, timesteps.dtype) + print("guidance.shape: ", guidance.shape, guidance.dtype) + print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape, pooled_prompt_embeds.dtype) + + guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16) + + get_memory_allocations() + # evaluate shapes + transformer_eval_params = transformer.init_weights( + rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=True + ) + + # loads pretrained weights + transformer_params = load_flow_model(config.flux_name, transformer_eval_params, "cpu") + params = {} + params["transformer"] = transformer_params + # maybe load lora and create interceptor + lora_loader = FluxLoraLoaderMixin() + params, lora_interceptors = maybe_load_flux_lora(config, lora_loader, params) + transformer_params = params["transformer"] + # create transformer state + weights_init_fn = functools.partial( + transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False + ) + with ExitStack() as stack: + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + model_params=None, + training=False, + ) + transformer_state = transformer_state.replace(params=transformer_params) + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + get_memory_allocations() + + states = {} + state_shardings = {} + + state_shardings["transformer"] = transformer_state_shardings + state_shardings["vae"] = vae_state_shardings + + states["transformer"] = transformer_state + states["vae"] = vae_state + # some resolutions from https://www.reddit.com/r/StableDiffusion/comments/1enxdga/flux_recommended_resolutions_from_01_to_20/ + resolutions = [ + (768, 768), + (768, 1024), + (1024, 768), + (1024, 1024), + (1408, 1408), + (1728, 1152), + (1152, 1728), + (1664, 1216), + (1216, 1664), + (1920, 1088), + (1088, 1920), + (2176, 960), + (960, 2176) + ] + p_jitted = {} + recorded_times = {} + text_encoding_time_final = 0 + for resolution in resolutions: + max_logging.log(f"Resolutions: {resolution}") + for _ in range(2): + s0 = time.perf_counter() + if config.offload_encoders: + t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params) + max_logging.log(f"Moving encoder to TPU time: {(time.perf_counter() - s0)}") + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + prompt=config.prompt, + prompt_2=config.prompt_2, + clip_tokenizer=clip_tokenizer, + clip_text_encoder=clip_text_encoder, + t5_tokenizer=t5_tokenizer, + t5_text_encoder=t5_encoder, + num_images_per_prompt=global_batch_size, + max_sequence_length=config.max_sequence_length, + ) + if config.offload_encoders: + s1 = time.perf_counter() + cpus = jax.devices("cpu") + t5_encoder.params = jax.device_put(t5_encoder.params, device=cpus[0]) + max_logging.log(f"Text encoding offload time: {(time.perf_counter() - s1)}") + text_encoding_time_final = time.perf_counter() - s0 + max_logging.log(f"text encoding time: {text_encoding_time_final}") + latents, latent_image_ids = prepare_latents( + batch_size=global_batch_size, + num_channels_latents=num_channels_latents, + height=resolution[0], + width=resolution[1], + dtype=jnp.bfloat16, + vae_scale_factor=vae_scale_factor, + rng=rng, + ) + + # move inputs to device and shard + s0 = time.perf_counter() + data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) + prompt_embeds = jax.device_put(prompt_embeds, data_sharding) + text_ids = jax.device_put(text_ids) + guidance = jax.device_put(guidance, data_sharding) + pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding) + latents = jax.device_put(latents, data_sharding) + latent_image_ids = jax.device_put(latent_image_ids) + max_logging.log(f"Moving to device time: {(time.perf_counter() - s0)}") + + # Setup timesteps + timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1) + # shifting the schedule to favor high timesteps for higher signal images + if config.time_shift: + # estimate mu based on linear estimation between two points + # lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift) + # mu = lin_function(latents.shape[1]) + mu = calculate_shift(latents.shape[1], base_shift=config.base_shift, max_shift=config.max_shift) + timesteps = time_shift(mu, 1.0, timesteps) + c_ts = timesteps[:-1] + p_ts = timesteps[1:] + #validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds) + p_run_inference = p_jitted.get(resolution, None) + if p_run_inference is None: + print("FN not found, compiling...") + p_run_inference = jax.jit( + functools.partial( + run_inference, + transformer=transformer, + vae=vae, + config=config, + resolution=resolution, + mesh=mesh, + latents=latents, + latent_image_ids=latent_image_ids, + prompt_embeds=prompt_embeds, + txt_ids=text_ids, + vec=pooled_prompt_embeds, + guidance_vec=guidance, + c_ts=c_ts, + p_ts=p_ts, + vae_scale_factor=vae_scale_factor, + ), + ) + p_jitted[resolution] = p_run_inference + with ExitStack() as stack: + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + s0 = time.perf_counter() + imgs = p_run_inference( + states, + latents = latents, + latent_image_ids=latent_image_ids, + prompt_embeds=prompt_embeds, + txt_ids=text_ids, + vec=pooled_prompt_embeds, + ).block_until_ready() + recorded_times[resolution] = (time.perf_counter() - s0) + max_logging.log(f"inference time: {recorded_times[resolution]}") + s0 = time.perf_counter() + imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True) + max_logging.log(f"Gathering all time: {(time.perf_counter() - s0)}") + s0 = time.perf_counter() + imgs = np.array(imgs) + imgs = (imgs * 0.5 + 0.5).clip(0, 1) + imgs = np.transpose(imgs, (0, 2, 3, 1)) + imgs = np.uint8(imgs * 255) + for i, image in enumerate(imgs): + Image.fromarray(image).save(f"flux_{resolution[0]}_{resolution[1]}_{i}.png") + max_logging.log(f"Saving images time: {(time.perf_counter() - s0)}") + get_memory_allocations() + + max_logging.log("***************RESULTS***************") + for key in recorded_times.keys(): + max_logging.log(f"Resolution: {key}, last inference time: {recorded_times[key]}") + max_logging.log(f"\nText encoding time: {text_encoding_time_final}") + + return imgs + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + + +if __name__ == "__main__": + app.run(main) From e0f8163b1886e752d2d8f8c0e9c5991cd2f61366 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 18 Apr 2025 20:50:39 +0000 Subject: [PATCH 5/8] formatting --- src/maxdiffusion/generate_flux.py | 2 +- src/maxdiffusion/generate_flux_multi_res.py | 96 +++++++++++---------- 2 files changed, 53 insertions(+), 45 deletions(-) diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index ac7f4cd90..615f3c241 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -496,4 +496,4 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": - app.run(main) \ No newline at end of file + app.run(main) diff --git a/src/maxdiffusion/generate_flux_multi_res.py b/src/maxdiffusion/generate_flux_multi_res.py index d1b9b177e..fa0f8659b 100644 --- a/src/maxdiffusion/generate_flux_multi_res.py +++ b/src/maxdiffusion/generate_flux_multi_res.py @@ -76,6 +76,7 @@ def unpack(x: Array, height: int, width: int, vae_scale_factor: int) -> Array: return x + def vae_decode(latents, vae, state, vae_scale_factor, resolution): img = unpack(x=latents.astype(jnp.float32), height=resolution[0], width=resolution[1], vae_scale_factor=vae_scale_factor) img = img / vae.config.scaling_factor + vae.config.shift_factor @@ -127,18 +128,16 @@ def prepare_latent_image_ids(height, width): def time_shift(mu: float, sigma: float, t: Array): return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + def calculate_shift( - image_seq_len, - base_seq_len: int = 256, - max_seq_len: int = 4096, - base_shift: float = 0.5, - max_shift: float = 1.16 + image_seq_len, base_seq_len: int = 256, max_seq_len: int = 4096, base_shift: float = 0.5, max_shift: float = 1.16 ): m = (max_shift - base_shift) / (max_seq_len - base_seq_len) b = base_shift - m * base_seq_len mu = image_seq_len * m + b return mu + def run_inference( states, transformer, @@ -154,7 +153,7 @@ def run_inference( guidance_vec, c_ts, p_ts, - vae_scale_factor + vae_scale_factor, ): transformer_state = states["transformer"] @@ -169,7 +168,9 @@ def run_inference( vec=vec, guidance_vec=guidance_vec, ) - vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, vae_scale_factor=vae_scale_factor, resolution=resolution) + vae_decode_p = functools.partial( + vae_decode, vae=vae, state=vae_state, vae_scale_factor=vae_scale_factor, resolution=resolution + ) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): latents, _, _, _ = jax.lax.fori_loop(0, len(c_ts), loop_body_p, (latents, transformer_state, c_ts, p_ts)) @@ -211,6 +212,7 @@ def prepare_latents( return latents, latent_image_ids + def tokenize_clip(prompt: Union[str, List[str]], tokenizer: CLIPTokenizer): prompt = [prompt] if isinstance(prompt, str) else prompt text_inputs = tokenizer( @@ -224,6 +226,7 @@ def tokenize_clip(prompt: Union[str, List[str]], tokenizer: CLIPTokenizer): ) return text_inputs.input_ids + def get_clip_prompt_embeds( prompt: Union[str, List[str]], num_images_per_prompt: int, tokenizer: CLIPTokenizer, text_encoder: FlaxCLIPTextModel ): @@ -246,6 +249,7 @@ def get_clip_prompt_embeds( prompt_embeds = jnp.tile(prompt_embeds, (batch_size * num_images_per_prompt, 1)) return prompt_embeds + def tokenize_t5(prompt: Union[str, List[str]], tokenizer: AutoTokenizer, max_sequence_length: int = 512): prompt = [prompt] if isinstance(prompt, str) else prompt text_inputs = tokenizer( @@ -259,6 +263,7 @@ def tokenize_t5(prompt: Union[str, List[str]], tokenizer: AutoTokenizer, max_seq ) return text_inputs.input_ids + def get_t5_prompt_embeds( prompt: Union[str, List[str]], num_images_per_prompt: int, @@ -288,6 +293,7 @@ def get_t5_prompt_embeds( prompt_embeds = jnp.reshape(prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1)) return prompt_embeds + def encode_prompt( prompt: Union[str, List[str]], prompt_2: Union[str, List[str]], @@ -318,6 +324,7 @@ def encode_prompt( text_ids = jnp.zeros((prompt_embeds.shape[1], 3)).astype(jnp.bfloat16) return prompt_embeds, pooled_prompt_embeds, text_ids + def run(config): from maxdiffusion.models.flux.util import load_flow_model @@ -436,19 +443,19 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep states["vae"] = vae_state # some resolutions from https://www.reddit.com/r/StableDiffusion/comments/1enxdga/flux_recommended_resolutions_from_01_to_20/ resolutions = [ - (768, 768), - (768, 1024), - (1024, 768), - (1024, 1024), - (1408, 1408), - (1728, 1152), - (1152, 1728), - (1664, 1216), - (1216, 1664), - (1920, 1088), - (1088, 1920), - (2176, 960), - (960, 2176) + (768, 768), + (768, 1024), + (1024, 768), + (1024, 1024), + (1408, 1408), + (1728, 1152), + (1152, 1728), + (1664, 1216), + (1216, 1664), + (1920, 1088), + (1088, 1920), + (2176, 960), + (960, 2176), ] p_jitted = {} recorded_times = {} @@ -461,14 +468,14 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params) max_logging.log(f"Moving encoder to TPU time: {(time.perf_counter() - s0)}") prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( - prompt=config.prompt, - prompt_2=config.prompt_2, - clip_tokenizer=clip_tokenizer, - clip_text_encoder=clip_text_encoder, - t5_tokenizer=t5_tokenizer, - t5_text_encoder=t5_encoder, - num_images_per_prompt=global_batch_size, - max_sequence_length=config.max_sequence_length, + prompt=config.prompt, + prompt_2=config.prompt_2, + clip_tokenizer=clip_tokenizer, + clip_text_encoder=clip_text_encoder, + t5_tokenizer=t5_tokenizer, + t5_text_encoder=t5_encoder, + num_images_per_prompt=global_batch_size, + max_sequence_length=config.max_sequence_length, ) if config.offload_encoders: s1 = time.perf_counter() @@ -478,15 +485,15 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep text_encoding_time_final = time.perf_counter() - s0 max_logging.log(f"text encoding time: {text_encoding_time_final}") latents, latent_image_ids = prepare_latents( - batch_size=global_batch_size, - num_channels_latents=num_channels_latents, - height=resolution[0], - width=resolution[1], - dtype=jnp.bfloat16, - vae_scale_factor=vae_scale_factor, - rng=rng, + batch_size=global_batch_size, + num_channels_latents=num_channels_latents, + height=resolution[0], + width=resolution[1], + dtype=jnp.bfloat16, + vae_scale_factor=vae_scale_factor, + rng=rng, ) - + # move inputs to device and shard s0 = time.perf_counter() data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) @@ -509,7 +516,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep timesteps = time_shift(mu, 1.0, timesteps) c_ts = timesteps[:-1] p_ts = timesteps[1:] - #validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds) + # validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds) p_run_inference = p_jitted.get(resolution, None) if p_run_inference is None: print("FN not found, compiling...") @@ -537,14 +544,14 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] s0 = time.perf_counter() imgs = p_run_inference( - states, - latents = latents, - latent_image_ids=latent_image_ids, - prompt_embeds=prompt_embeds, - txt_ids=text_ids, - vec=pooled_prompt_embeds, + states, + latents=latents, + latent_image_ids=latent_image_ids, + prompt_embeds=prompt_embeds, + txt_ids=text_ids, + vec=pooled_prompt_embeds, ).block_until_ready() - recorded_times[resolution] = (time.perf_counter() - s0) + recorded_times[resolution] = time.perf_counter() - s0 max_logging.log(f"inference time: {recorded_times[resolution]}") s0 = time.perf_counter() imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True) @@ -566,6 +573,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep return imgs + def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) run(pyconfig.config) From c39a7419e4f1a4df79a95b4fd37bb2214d58a62b Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 18 Apr 2025 20:53:11 +0000 Subject: [PATCH 6/8] remove unused dependencies. --- src/maxdiffusion/generate_flux_multi_res.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/maxdiffusion/generate_flux_multi_res.py b/src/maxdiffusion/generate_flux_multi_res.py index fa0f8659b..ed1baa67a 100644 --- a/src/maxdiffusion/generate_flux_multi_res.py +++ b/src/maxdiffusion/generate_flux_multi_res.py @@ -14,7 +14,7 @@ limitations under the License. """ -from typing import Callable, List, Union, Sequence +from typing import List, Union, Sequence from absl import app from contextlib import ExitStack import functools @@ -27,7 +27,6 @@ import jax.numpy as jnp import flax.linen as nn from chex import Array -from einops import rearrange from flax.linen import partitioning as nn_partitioning from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer) From 91826cf8a36b5535f14c87279ee7818917335d7d Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 22 Apr 2025 18:31:33 +0000 Subject: [PATCH 7/8] update config with flux names --- src/maxdiffusion/configs/base_flux_dev_multi_res.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml index e0597360d..7c11e698d 100644 --- a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml +++ b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml @@ -75,9 +75,9 @@ flash_block_sizes: { # GroupNorm groups norm_num_groups: 32 -# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch +# If train_new_flux, unet weights will be randomly initialized to train the unet from scratch # else they will be loaded from pretrained_model_name_or_path -train_new_unet: False +train_new_flux: False # train text_encoder - Currently not supported for SDXL train_text_encoder: False From 4f230da1c193b2bf97eb38c98ed1e41a794481fd Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 22 Apr 2025 18:34:47 +0000 Subject: [PATCH 8/8] revert to torch 2.5.1 due to compatibility with torchvision. --- requirements.txt | 2 +- requirements_with_jax_stable_stack.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 4d257e21a..1ca1dc79b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ absl-py datasets flax>=0.10.2 optax>=0.2.3 -torch==2.6.0 +torch==2.5.1 torchvision==0.20.1 ftfy tensorboard>=2.17.0 diff --git a/requirements_with_jax_stable_stack.txt b/requirements_with_jax_stable_stack.txt index b674de387..ba21b4774 100644 --- a/requirements_with_jax_stable_stack.txt +++ b/requirements_with_jax_stable_stack.txt @@ -29,6 +29,6 @@ tensorboardx==2.6.2.2 tensorflow>=2.17.0 tensorflow-datasets>=4.9.6 tokenizers==0.21.0 -torch==2.6.0 +torch==2.5.1 torchvision==0.20.1 transformers==4.48.1 \ No newline at end of file