Skip to content
271 changes: 271 additions & 0 deletions src/maxdiffusion/configs/base_flux_dev_multi_res.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This sentinel is a reminder to choose a real run name.
run_name: ''

metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
write_metrics: True
gcs_metrics: False
# If true save config to GCS in {base_output_directory}/{run_name}/
save_config_to_gcs: False
log_period: 100

pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev'
clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'

# Flux params
flux_name: "flux-dev"
max_sequence_length: 512
time_shift: True
base_shift: 0.5
max_shift: 1.15
# offloads t5 encoder after text encoding to save memory.
offload_encoders: True


unet_checkpoint: ''
revision: 'refs/pr/95'
# This will convert the weights to this dtype.
# When running inference on TPUv5e, use weights_dtype: 'bfloat16'
weights_dtype: 'bfloat16'
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
activations_dtype: 'bfloat16'

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
# fp32 activations and fp32 weights with HIGHEST will provide the best precision
# at the cost of time.
precision: "DEFAULT"

# if False state is not jitted and instead replicate is called. This is good for debugging on single host
# It must be True for multi-host.
jit_initializers: True

# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te

#flash_block_sizes: {}
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
flash_block_sizes: {
"block_q" : 1536,
"block_kv_compute" : 1536,
"block_kv" : 1536,
"block_q_dkv" : 1536,
"block_kv_dkv" : 1536,
"block_kv_dkv_compute" : 1536,
"block_q_dq" : 1536,
"block_kv_dq" : 1536
}
# GroupNorm groups
norm_num_groups: 32

# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch
# else they will be loaded from pretrained_model_name_or_path
train_new_unet: False
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be train_new_flux?


# train text_encoder - Currently not supported for SDXL
train_text_encoder: False
text_encoder_learning_rate: 4.25e-6

# https://arxiv.org/pdf/2305.08891.pdf
snr_gamma: -1.0

timestep_bias: {
# a value of later will increase the frequence of the model's final training steps.
# none, earlier, later, range
strategy: "none",
# multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it.
multiplier: 1.0,
# when using strategy=range, the beginning (inclusive) timestep to bias.
begin: 0,
# when using strategy=range, the final step (inclusive) to bias.
end: 1000,
# portion of timesteps to bias.
# 0.5 will bias one half of the timesteps. Value of strategy determines
# whether the biased portions are in the earlier or later timesteps.
portion: 0.25
}

# Override parameters from checkpoints's scheduler.
diffusion_scheduler_config: {
_class_name: 'FlaxEulerDiscreteScheduler',
prediction_type: 'epsilon',
rescale_zero_terminal_snr: False,
timestep_spacing: 'trailing'
}

# Output directory
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
base_output_directory: ""

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']

# batch : batch dimension of data and activations
# hidden :
# embed : attention qkv dense layer hidden dim named as embed
# heads : attention head dim = num_heads * head_dim
# length : attention sequence length
# temb_in : dense.shape[0] of resnet dense before conv
# out_c : dense.shape[1] of resnet dense before conv
# out_channels : conv.shape[-1] activation
# keep_1 : conv.shape[0] weight
# keep_2 : conv.shape[1] weight
# conv_in : conv.shape[2] weight
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
['batch', 'data'],
['activation_batch', ['data','fsdp']],
['activation_heads', 'tensor'],
['activation_kv', 'tensor'],
# ['embed','fsdp'],
['mlp',['fsdp','tensor']],
['heads', 'tensor'],
['conv_batch', ['data','fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_tensor_parallelism: 1
ici_data_parallelism: -1
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

# Dataset
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
# cache_latents_text_encoder_outputs only apply to dataset_type="tf",
# only apply to small dataset that fits in memory
# prepare image latents and text encoder outputs
# Reduce memory consumption and reduce step time during training
# transformed dataset is saved at dataset_save_location
dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
train_data_dir: ''
dataset_config_name: ''
jax_cache_dir: ''
hf_data_dir: ''
hf_train_files: ''
hf_access_token: ''
image_column: 'image'
caption_column: 'text'
resolution: 1024
center_crop: False
random_flip: False
# If cache_latents_text_encoder_outputs is True
# the num_proc is set to 1
tokenize_captions_num_proc: 4
transform_images_num_proc: 4
reuse_example_batch: False
enable_data_shuffling: True

# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False

# Training loop
learning_rate: 4.e-7
scale_lr: False
max_train_samples: -1
# max_train_steps takes priority over num_train_epochs.
max_train_steps: 200
num_train_epochs: 1
seed: 0
output_dir: 'sdxl-model-finetuned'
per_device_batch_size: 1

warmup_steps_fraction: 0.0
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.

# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.

# AdamW optimizer parameters
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
adam_weight_decay: 1.e-2 # AdamW Weight decay
max_grad_norm: 1.0

enable_profiler: False
# Skip first n steps for profiling, to omit things like compilation and to give
# the iteration time a chance to stabilize.
skip_first_n_steps_for_profiler: 5
profiler_steps: 10

# Generation parameters
prompt: "A magical castle in the middle of a forest, artistic drawing"
prompt_2: "A magical castle in the middle of a forest, artistic drawing"
negative_prompt: "purple, red"
do_classifier_free_guidance: True
guidance_scale: 3.5
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 50

# SDXL Lightning parameters
lightning_from_pt: True
# Empty or "ByteDance/SDXL-Lightning" to enable lightning.
lightning_repo: ""
# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning.
lightning_ckpt: ""

# LoRA parameters
# Values are lists to support multiple LoRA loading during inference in the future.
lora_config: {
lora_model_name_or_path: [],
weight_name: [],
adapter_name: [],
scale: [],
from_pt: []
}
# Ex with values:
# lora_config : {
# lora_model_name_or_path: ["ByteDance/Hyper-SD"],
# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"],
# adapter_name: ["hyper-sdxl"],
# scale: [0.7],
# from_pt: [True]
# }

enable_mllog: False

#controlnet
controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
controlnet_from_pt: True
controlnet_conditioning_scale: 0.5
controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
quantization: ''
# Shard the range finding operation for quantization. By default this is set to number of slices.
quantization_local_shard_count: -1
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

2 changes: 0 additions & 2 deletions src/maxdiffusion/generate_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ def get_t5_prompt_embeds(

prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)

text_inputs = tokenizer(
prompt,
truncation=True,
Expand All @@ -244,7 +243,6 @@ def get_t5_prompt_embeds(
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1))
prompt_embeds = jnp.reshape(prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1))

return prompt_embeds


Expand Down
Loading
Loading