
Commit 860e76e

working loop, bad generation
1 parent 4b64f5d commit 860e76e

4 files changed

Lines changed: 312 additions & 43 deletions

Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This sentinel is a reminder to choose a real run name.
run_name: ''

metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
# If true, save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
write_metrics: True
gcs_metrics: False
# If true, save config to GCS in {base_output_directory}/{run_name}/
save_config_to_gcs: False
log_period: 100

pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev'
clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'

# Flux params
max_sequence_length: 512
time_shift: False
base_shift: 0.5
max_shift: 1.15

unet_checkpoint: ''
revision: 'refs/pr/95'
# This will convert the weights to this dtype.
# When running inference on TPUv5e, use weights_dtype: 'bfloat16'
weights_dtype: 'bfloat16'
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
activations_dtype: 'bfloat16'

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
# fp32 activations and fp32 weights with HIGHEST will provide the best precision
# at the cost of time.
precision: "DEFAULT"

# Set to True to load weights from PyTorch.
from_pt: False
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash
flash_block_sizes: {}
# GroupNorm groups
norm_num_groups: 32

# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch,
# else they will be loaded from pretrained_model_name_or_path.
train_new_unet: False

# train text_encoder - Currently not supported for SDXL
train_text_encoder: False
text_encoder_learning_rate: 4.25e-6

# https://arxiv.org/pdf/2305.08891.pdf
snr_gamma: -1.0

timestep_bias: {
  # a value of 'later' will increase the frequency of the model's final training timesteps.
  # none, earlier, later, range
  strategy: "none",
  # multiplier for the bias; a value of 2.0 will double the weight of the bias, 0.5 will halve it.
  multiplier: 1.0,
  # when using strategy=range, the beginning (inclusive) timestep to bias.
  begin: 0,
  # when using strategy=range, the final (inclusive) timestep to bias.
  end: 1000,
  # portion of timesteps to bias.
  # 0.5 will bias one half of the timesteps. The value of strategy determines
  # whether the biased portion is in the earlier or later timesteps.
  portion: 0.25
}

# Override parameters from the checkpoint's scheduler.
diffusion_scheduler_config: {
  _class_name: 'FlaxEulerDiscreteScheduler',
  prediction_type: 'epsilon',
  rescale_zero_terminal_snr: False,
  timestep_spacing: 'trailing'
}

# Output directory
# Create a GCS bucket, e.g. my-maxtext-outputs, and set this to "gs://my-maxtext-outputs/"
base_output_directory: ""

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']

# batch : batch dimension of data and activations
# hidden :
# embed : attention qkv dense layer hidden dim named as embed
# heads : attention head dim = num_heads * head_dim
# length : attention sequence length
# temb_in : dense.shape[0] of resnet dense before conv
# out_c : dense.shape[1] of resnet dense before conv
# out_channels : conv.shape[-1] activation
# keep_1 : conv.shape[0] weight
# keep_2 : conv.shape[1] weight
# conv_in : conv.shape[2] weight
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
  ['batch', 'data'],
  ['activation_batch', ['data','fsdp']],
  ['activation_heads', 'tensor'],
  ['activation_kv', 'tensor'],
  ['embed', 'fsdp'],
  ['heads', 'tensor'],
  ['conv_batch', ['data','fsdp']],
  ['out_channels', 'tensor'],
  ['conv_out', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, the product of the DCN axes should equal the number of slices,
# and the product of the ICI axes should equal the number of devices per slice.
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

# Dataset
# Replace with the dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
# cache_latents_text_encoder_outputs only applies to dataset_type="tf",
# and only to small datasets that fit in memory.
# It precomputes image latents and text encoder outputs to
# reduce memory consumption and step time during training.
# The transformed dataset is saved at dataset_save_location.
dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
train_data_dir: ''
dataset_config_name: ''
jax_cache_dir: ''
hf_data_dir: ''
hf_train_files: ''
hf_access_token: ''
image_column: 'image'
caption_column: 'text'
resolution: 1024
center_crop: False
random_flip: False
# If cache_latents_text_encoder_outputs is True,
# num_proc is set to 1.
tokenize_captions_num_proc: 4
transform_images_num_proc: 4
reuse_example_batch: False
enable_data_shuffling: True

# Checkpoint every this many samples; -1 means don't checkpoint.
checkpoint_every: -1
# Enables one replica to read the ckpt and then broadcast it to the rest.
enable_single_replica_ckpt_restoring: False

# Training loop
learning_rate: 4.e-7
scale_lr: False
max_train_samples: -1
# max_train_steps takes priority over num_train_epochs.
max_train_steps: 200
num_train_epochs: 1
seed: 0
output_dir: 'sdxl-model-finetuned'
per_device_batch_size: 1

warmup_steps_fraction: 0.0
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.

# However, you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case training
# ends before the learning rate has fully decayed, or a shorter schedule, where steps past the
# schedule run at a learning rate of 0.

# AdamW optimizer parameters
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
adam_eps: 1.e-8 # A small constant applied to the denominator outside of the square root.
adam_weight_decay: 1.e-2 # AdamW weight decay
max_grad_norm: 1.0

enable_profiler: False
# Skip the first n steps for profiling, to omit things like compilation and to give
# the iteration time a chance to stabilize.
skip_first_n_steps_for_profiler: 5
profiler_steps: 10

# Generation parameters
prompt: "A magical castle in the middle of a forest, artistic drawing"
prompt_2: "A magical castle in the middle of a forest, artistic drawing"
negative_prompt: "purple, red"
do_classifier_free_guidance: True
guidance_scale: 3.5
# Based on Sec. 3.4 of https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 20

# SDXL Lightning parameters
lightning_from_pt: True
# Empty, or "ByteDance/SDXL-Lightning" to enable lightning.
lightning_repo: ""
# Empty, or "sdxl_lightning_4step_unet.safetensors" to enable lightning.
lightning_ckpt: ""

# LoRA parameters
# Values are lists to support loading multiple LoRAs during inference in the future.
lora_config: {
  lora_model_name_or_path: [],
  weight_name: [],
  adapter_name: [],
  scale: [],
  from_pt: []
}
# Ex with values:
# lora_config: {
#   lora_model_name_or_path: ["ByteDance/Hyper-SD"],
#   weight_name: ["Hyper-SDXL-2steps-lora.safetensors"],
#   adapter_name: ["hyper-sdxl"],
#   scale: [0.7],
#   from_pt: [True]
# }

enable_mllog: False

# ControlNet
controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
controlnet_from_pt: True
controlnet_conditioning_scale: 0.5
controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
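
A note on the Flux params above: base_shift and max_shift parameterize the resolution-dependent timestep shift used by Flux samplers, and time_shift toggles it (max_sequence_length is the text-token budget: 512 for dev, 256 for schnell). A minimal sketch of how the shift values are typically combined, following the reference Flux implementation; calculate_shift, time_shift, and the 256/4096 sequence-length bounds are illustrative assumptions, not code from this commit:

import math

def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
                    base_shift=0.5, max_shift=1.15):
    # Linearly interpolate the shift mu between base_shift and max_shift
    # as the number of image latent tokens grows.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

def time_shift(mu, t):
    # Warp a timestep t in (0, 1] toward the high-noise end of the schedule.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1))

Note that image_seq_len here is the image latent token count, not the text max_sequence_length.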
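
snr_gamma: -1.0 disables SNR-based loss weighting. When positive, it is commonly used for Min-SNR weighting as in the diffusers training scripts (an assumption about intent, not confirmed by the commit). A sketch, assuming the per-timestep SNR has already been computed:

import jax.numpy as jnp

def min_snr_loss_weights(snr, snr_gamma):
    # Clamp the per-timestep SNR at snr_gamma so easy (high-SNR) timesteps
    # do not dominate the loss; snr_gamma <= 0 disables the weighting.
    if snr_gamma <= 0:
        return jnp.ones_like(snr)
    return jnp.minimum(snr, snr_gamma) / snr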
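
The timestep_bias block skews which timesteps are drawn during training. A sketch of sampling weights matching the comments in that block (strategy, portion, multiplier, and the inclusive begin/end range); timestep_sampling_weights is a hypothetical helper:

import numpy as np

def timestep_sampling_weights(bias, num_timesteps=1000):
    # Start from uniform sampling, then scale the biased slice by multiplier.
    weights = np.ones(num_timesteps)
    if bias["strategy"] == "later":
        weights[int(num_timesteps * (1 - bias["portion"])):] *= bias["multiplier"]
    elif bias["strategy"] == "earlier":
        weights[:int(num_timesteps * bias["portion"])] *= bias["multiplier"]
    elif bias["strategy"] == "range":
        # begin and end are both inclusive.
        weights[bias["begin"]:bias["end"] + 1] *= bias["multiplier"]
    return weights / weights.sum()

Timesteps would then be drawn with, e.g., np.random.choice(num_timesteps, p=weights).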
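
The dcn_*/ici_* knobs split parallelism across slices (DCN) and across devices within a slice (ICI), with -1 as the auto-shard placeholder. A single-slice sketch of resolving the placeholder into a JAX mesh; create_mesh is illustrative, not this repo's API:

import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh

def create_mesh(ici_parallelism, mesh_axes):
    # Resolve a single -1 placeholder so the product of the ICI axes
    # equals the number of devices in the slice.
    known = int(np.prod([p for p in ici_parallelism if p != -1]))
    resolved = [p if p != -1 else jax.device_count() // known
                for p in ici_parallelism]
    return Mesh(mesh_utils.create_device_mesh(resolved), mesh_axes)

With the defaults above, create_mesh([1, -1, 1], ('data', 'fsdp', 'tensor')) shards the fsdp axis across all devices of the slice.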
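
warmup_steps_fraction and learning_rate_schedule_steps describe the learning-rate schedule in the training-loop comments. An optax sketch; the warmup-then-cosine-decay shape is an assumption, not something this commit pins down:

import optax

def make_lr_schedule(learning_rate, max_train_steps,
                     warmup_steps_fraction, learning_rate_schedule_steps):
    # -1 means the schedule spans exactly max_train_steps.
    if learning_rate_schedule_steps == -1:
        learning_rate_schedule_steps = max_train_steps
    warmup_steps = int(learning_rate_schedule_steps * warmup_steps_fraction)
    # Linear warmup to the peak LR, then decay to 0 over the remainder;
    # steps past the schedule simply stay at 0.
    return optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=learning_rate,
        warmup_steps=warmup_steps,
        decay_steps=learning_rate_schedule_steps,
    )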
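
guidance_rescale applies the correction from Sec. 3.4 of the paper cited next to it. A JAX sketch of the standard formulation (mirroring diffusers' rescale_noise_cfg; whether maxdiffusion uses exactly this code is an assumption):

import jax.numpy as jnp

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale):
    # Pull the std of the guided prediction back toward that of the
    # text-conditioned prediction to counter over-saturation.
    axes = tuple(range(1, noise_cfg.ndim))
    std_text = jnp.std(noise_pred_text, axis=axes, keepdims=True)
    std_cfg = jnp.std(noise_cfg, axis=axes, keepdims=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg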

src/maxdiffusion/configs/base_flux.yml renamed to src/maxdiffusion/configs/base_fux_schnell.yml

Lines changed: 6 additions & 0 deletions
@@ -27,6 +27,12 @@ pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev'
 clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
 t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'
 
+# Flux params
+max_sequence_length: 256
+time_shift: False
+base_shift: 0.5
+max_shift: 1.15
+
 unet_checkpoint: ''
 revision: 'refs/pr/95'
 # This will convert the weights to this dtype.