Skip to content

Commit b9b2465

Browse files
wip - attention for wan.
1 parent 2499b2d commit b9b2465

3 files changed

Lines changed: 867 additions & 352 deletions

File tree

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# This sentinel is a reminder to choose a real run name.
16+
run_name: ''
17+
18+
metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
19+
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
20+
write_metrics: True
21+
22+
timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
23+
write_timing_metrics: True
24+
25+
gcs_metrics: False
26+
# If true save config to GCS in {base_output_directory}/{run_name}/
27+
save_config_to_gcs: False
28+
log_period: 100
29+
30+
pretrained_model_name_or_path: ''
31+
32+
unet_checkpoint: ''
33+
revision: ''
34+
# This will convert the weights to this dtype.
35+
# When running inference on TPUv5e, use weights_dtype: 'bfloat16'
36+
weights_dtype: 'bfloat16'
37+
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
38+
activations_dtype: 'bfloat16'
39+
40+
# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
41+
# Options are "DEFAULT", "HIGH", "HIGHEST"
42+
# fp32 activations and fp32 weights with HIGHEST will provide the best precision
43+
# at the cost of time.
44+
precision: "DEFAULT"
45+
46+
# If False, the state is not jitted and replicate is called instead. This is good for debugging on a single host.
47+
# It must be True for multi-host.
48+
jit_initializers: True
49+
50+
# Set true to load weights from pytorch
51+
from_pt: True
52+
split_head_dim: True
53+
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
54+
55+
flash_block_sizes: {
56+
"block_q" : 512,
57+
"block_kv_compute" : 512,
58+
"block_kv" : 512,
59+
"block_q_dkv" : 512,
60+
"block_kv_dkv" : 512,
61+
"block_kv_dkv_compute" : 512,
62+
"block_q_dq" : 512,
63+
"block_kv_dq" : 512
64+
}
65+
# GroupNorm groups
66+
norm_num_groups: 32
67+
68+
# train text_encoder - Currently not supported for SDXL
69+
train_text_encoder: False
70+
text_encoder_learning_rate: 4.25e-6
71+
72+
# https://arxiv.org/pdf/2305.08891.pdf
73+
snr_gamma: -1.0
74+
75+
timestep_bias: {
76+
# a strategy of "later" will increase the frequency of the model's final training steps.
77+
# none, earlier, later, range
78+
strategy: "none",
79+
# Multiplier for the bias: a value of 2.0 will double the weight of the bias, and 0.5 will halve it.
80+
multiplier: 1.0,
81+
# when using strategy=range, the beginning (inclusive) timestep to bias.
82+
begin: 0,
83+
# when using strategy=range, the final step (inclusive) to bias.
84+
end: 1000,
85+
# portion of timesteps to bias.
86+
# 0.5 will bias one half of the timesteps. Value of strategy determines
87+
# whether the biased portions are in the earlier or later timesteps.
88+
portion: 0.25
89+
}
90+
91+
# Override parameters from the checkpoint's scheduler.
92+
diffusion_scheduler_config: {
93+
_class_name: 'FlaxEulerDiscreteScheduler',
94+
prediction_type: 'epsilon',
95+
rescale_zero_terminal_snr: False,
96+
timestep_spacing: 'trailing'
97+
}
98+
99+
# Output directory
100+
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
101+
base_output_directory: ""
102+
103+
# Hardware
104+
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
105+
106+
# Parallelism
107+
mesh_axes: ['data', 'fsdp', 'tensor']
108+
109+
# batch : batch dimension of data and activations
110+
# hidden :
111+
# embed : attention qkv dense layer hidden dim named as embed
112+
# heads : attention head dim = num_heads * head_dim
113+
# length : attention sequence length
114+
# temb_in : dense.shape[0] of resnet dense before conv
115+
# out_c : dense.shape[1] of resnet dense before conv
116+
# out_channels : conv.shape[-1] activation
117+
# keep_1 : conv.shape[0] weight
118+
# keep_2 : conv.shape[1] weight
119+
# conv_in : conv.shape[2] weight
120+
# conv_out : conv.shape[-1] weight
121+
logical_axis_rules: [
122+
['batch', 'data'],
123+
['activation_batch', ['data','fsdp']],
124+
['activation_heads', 'tensor'],
125+
['activation_kv', 'tensor'],
126+
['mlp','tensor'],
127+
['embed','fsdp'],
128+
['heads', 'tensor'],
129+
['conv_batch', ['data','fsdp']],
130+
['out_channels', 'tensor'],
131+
['conv_out', 'fsdp'],
132+
]
133+
data_sharding: [['data', 'fsdp', 'tensor']]
134+
135+
# One axis for each parallelism type may hold a placeholder (-1)
136+
# value to auto-shard based on available slices and devices.
137+
# By default, product of the DCN axes should equal number of slices
138+
# and product of the ICI axes should equal number of devices per slice.
139+
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
140+
dcn_fsdp_parallelism: -1
141+
dcn_tensor_parallelism: 1
142+
ici_data_parallelism: -1
143+
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
144+
ici_tensor_parallelism: 1
145+
146+
# Dataset
147+
# Replace with dataset path or train_data_dir. One has to be set.
148+
dataset_name: 'diffusers/pokemon-gpt4-captions'
149+
train_split: 'train'
150+
dataset_type: 'tf'
151+
cache_latents_text_encoder_outputs: True
152+
# cache_latents_text_encoder_outputs only applies to dataset_type="tf",
153+
# and only to small datasets that fit in memory.
154+
# Prepares image latents and text encoder outputs ahead of time.
155+
# Reduces memory consumption and step time during training.
156+
# The transformed dataset is saved at dataset_save_location.
157+
dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
158+
train_data_dir: ''
159+
dataset_config_name: ''
160+
jax_cache_dir: ''
161+
hf_data_dir: ''
162+
hf_train_files: ''
163+
hf_access_token: ''
164+
image_column: 'image'
165+
caption_column: 'text'
166+
resolution: 1024
167+
center_crop: False
168+
random_flip: False
169+
# If cache_latents_text_encoder_outputs is True
170+
# the num_proc is set to 1
171+
tokenize_captions_num_proc: 4
172+
transform_images_num_proc: 4
173+
reuse_example_batch: False
174+
enable_data_shuffling: True
175+
176+
# checkpoint after every given number of samples; -1 means don't checkpoint.
177+
checkpoint_every: -1
178+
# enables one replica to read the checkpoint and then broadcast it to the rest
179+
enable_single_replica_ckpt_restoring: False
180+
181+
# Training loop
182+
learning_rate: 1.e-5
183+
scale_lr: False
184+
max_train_samples: -1
185+
# max_train_steps takes priority over num_train_epochs.
186+
max_train_steps: 1500
187+
num_train_epochs: 1
188+
seed: 0
189+
output_dir: 'sdxl-model-finetuned'
190+
per_device_batch_size: 1
191+
192+
warmup_steps_fraction: 0.1
193+
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
194+
195+
# However, you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case training will end before
196+
# the learning rate has fully decayed. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
197+
198+
# AdamW optimizer parameters
199+
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
200+
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
201+
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
202+
adam_weight_decay: 0 # AdamW Weight decay
203+
max_grad_norm: 1.0
204+
205+
enable_profiler: False
206+
# Skip first n steps for profiling, to omit things like compilation and to give
207+
# the iteration time a chance to stabilize.
208+
skip_first_n_steps_for_profiler: 5
209+
profiler_steps: 10
210+
211+
# Generation parameters
212+
prompt: "A magical castle in the middle of a forest, artistic drawing"
213+
prompt_2: "A magical castle in the middle of a forest, artistic drawing"
214+
negative_prompt: "purple, red"
215+
do_classifier_free_guidance: True
216+
guidance_scale: 3.5
217+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
218+
guidance_rescale: 0.0
219+
num_inference_steps: 50
220+
save_final_checkpoint: False
221+
222+
# SDXL Lightning parameters
223+
lightning_from_pt: True
224+
# Empty or "ByteDance/SDXL-Lightning" to enable lightning.
225+
lightning_repo: ""
226+
# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning.
227+
lightning_ckpt: ""
228+
229+
# LoRA parameters
230+
# Values are lists to support multiple LoRA loading during inference in the future.
231+
lora_config: {
232+
lora_model_name_or_path: [],
233+
weight_name: [],
234+
adapter_name: [],
235+
scale: [],
236+
from_pt: []
237+
}
238+
# Ex with values:
239+
# lora_config : {
240+
# lora_model_name_or_path: ["ByteDance/Hyper-SD"],
241+
# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"],
242+
# adapter_name: ["hyper-sdxl"],
243+
# scale: [0.7],
244+
# from_pt: [True]
245+
# }
246+
247+
enable_mllog: False
248+
249+
#controlnet
250+
controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
251+
controlnet_from_pt: True
252+
controlnet_conditioning_scale: 0.5
253+
controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
254+
quantization: ''
255+
# Shard the range-finding operation for quantization. By default this is set to the number of slices.
256+
quantization_local_shard_count: -1
257+
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
258+

0 commit comments

Comments
 (0)