-
Notifications
You must be signed in to change notification settings - Fork 69
Expand file tree
/
Copy pathbase_wan_i2v_27b.yml
More file actions
388 lines (341 loc) · 15.6 KB
/
base_wan_i2v_27b.yml
File metadata and controls
388 lines (341 loc) · 15.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This sentinel is a reminder to choose a real run name.
run_name: ''
metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
write_metrics: True
timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
write_timing_metrics: True
gcs_metrics: False
# If true save config to GCS in {base_output_directory}/{run_name}/
save_config_to_gcs: False
log_period: 100
pretrained_model_name_or_path: 'Wan-AI/Wan2.2-I2V-A14B-Diffusers'
model_name: wan2.2
model_type: 'I2V'
# Overrides the transformer from pretrained_model_name_or_path
wan_transformer_pretrained_model_name_or_path: ''
unet_checkpoint: ''
revision: ''
# This will convert the weights to this dtype.
# When running inference on TPUv5e, use weights_dtype: 'bfloat16'
weights_dtype: 'bfloat16'
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
activations_dtype: 'bfloat16'
# Replicates vae across devices instead of using the model's sharding annotations for sharding.
replicate_vae: False
# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
# fp32 activations and fp32 weights with HIGHEST will provide the best precision
# at the cost of time.
precision: "DEFAULT"
# Use jax.lax.scan for transformer layers
scan_layers: True
# If False, the state is not jitted and replicate is called instead. This is useful for debugging on a single host.
# It must be True for multi-host.
jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
flash_min_seq_length: 4096
dropout: 0.0
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
mask_padding_tokens: True
attention_sharding_uniform: True
# Flash attention block sizes, written as a block mapping.
# The commented-out variant that follows lists the values marked "Use on v6e".
flash_block_sizes:
  block_q: 1024
  block_kv_compute: 256
  block_kv: 1024
  block_q_dkv: 1024
  block_kv_dkv: 1024
  block_kv_dkv_compute: 256
  block_q_dq: 1024
  block_kv_dq: 1024
# Use on v6e
# flash_block_sizes: {
# "block_q" : 3024,
# "block_kv_compute" : 1024,
# "block_kv" : 2048,
# "block_q_dkv" : 3024,
# "block_kv_dkv" : 2048,
# "block_kv_dkv_compute" : 2048,
# "block_q_dq" : 3024,
# "block_kv_dq" : 2048,
# "use_fused_bwd_kernel": False,
# }
# GroupNorm groups
norm_num_groups: 32
# train text_encoder - Currently not supported for SDXL
train_text_encoder: False
text_encoder_learning_rate: 4.25e-6
# https://arxiv.org/pdf/2305.08891.pdf
snr_gamma: -1.0
timestep_bias:
  # One of: none, earlier, later, range.
  # A value of later will increase the frequency of the model's final training steps.
  strategy: "none"
  # Multiplier for the bias: 2.0 doubles the weight of the bias, 0.5 halves it.
  multiplier: 1.0
  # With strategy=range: the first timestep (inclusive) to bias.
  begin: 0
  # With strategy=range: the final timestep (inclusive) to bias.
  end: 1000
  # Portion of timesteps to bias; 0.5 biases one half of the timesteps.
  # The value of strategy determines whether the biased portion lies in the
  # earlier or the later timesteps.
  portion: 0.25
# Overrides for parameters of the checkpoint's scheduler.
diffusion_scheduler_config:
  _class_name: 'FlaxEulerDiscreteScheduler'
  prediction_type: 'epsilon'
  rescale_zero_terminal_snr: False
  timestep_spacing: 'trailing'
# Output directory
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
base_output_directory: ""
# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False
# Parallelism
mesh_axes: ['data', 'fsdp', 'context', 'tensor']
# batch : batch dimension of data and activations
# hidden :
# embed : attention qkv dense layer hidden dim named as embed
# heads : attention head dim = num_heads * head_dim
# length : attention sequence length
# temb_in : dense.shape[0] of resnet dense before conv
# out_c : dense.shape[1] of resnet dense before conv
# out_channels : conv.shape[-1] activation
# keep_1 : conv.shape[0] weight
# keep_2 : conv.shape[1] weight
# conv_in : conv.shape[2] weight
# conv_out : conv.shape[-1] weight
# Each rule pairs a logical axis name with one mesh axis name, or a list of
# mesh axis names, taken from mesh_axes. Rule order is kept as-is.
logical_axis_rules:
  - ['batch', ['data', 'fsdp']]
  - ['activation_batch', ['data', 'fsdp']]
  - ['activation_self_attn_heads', ['context', 'tensor']]
  - ['activation_cross_attn_q_length', ['context', 'tensor']]
  - ['activation_length', 'context']
  - ['activation_heads', 'tensor']
  - ['mlp', 'tensor']
  - ['embed', ['context', 'fsdp']]
  - ['heads', 'tensor']
  - ['norm', 'tensor']
  - ['conv_batch', ['data', 'context', 'fsdp']]
  - ['out_channels', 'tensor']
  - ['conv_out', 'context']
# Logical axis rules used for the VAE; a null target leaves that logical axis
# without a mesh axis.
# NOTE(review): 'redundant' and 'vae_spatial' are not listed in mesh_axes;
# presumably they belong to a separate VAE mesh — confirm against the loader.
vae_logical_axis_rules:
  - ['activation_batch', 'redundant']
  - ['activation_length', 'vae_spatial']
  - ['activation_heads', null]
  - ['activation_kv_length', null]
  - ['embed', null]
  - ['heads', null]
  - ['norm', null]
  - ['conv_batch', 'redundant']
  - ['out_channels', 'vae_spatial']
  - ['conv_out', 'vae_spatial']
  - ['conv_in', 'vae_spatial']
data_sharding: [['data', 'fsdp', 'context', 'tensor']]
# One axis for each parallelism type (DCN and ICI) may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, the product of the DCN axes should equal the number of slices
# and the product of the ICI axes should equal the number of devices per slice.
# NOTE(review): the "recommended DCN axis" remark sits on dcn_data_parallelism,
# but the -1 placeholder is currently on dcn_fsdp_parallelism — confirm which
# DCN axis is meant to be auto-sharded.
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_context_parallelism: 1
ici_tensor_parallelism: 1
allow_split_physical_axes: False
# Dataset
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tfrecord'
cache_latents_text_encoder_outputs: True
# cache_latents_text_encoder_outputs only applies to dataset_type="tf",
# and only to small datasets that fit in memory.
# It prepares image latents and text encoder outputs ahead of time, which
# reduces memory consumption and step time during training.
# The transformed dataset is saved at dataset_save_location.
dataset_save_location: ''
load_tfrecord_cached: True
train_data_dir: ''
dataset_config_name: ''
jax_cache_dir: ''
hf_data_dir: ''
hf_train_files: ''
hf_access_token: ''
image_column: 'image'
caption_column: 'text'
resolution: 1024
center_crop: False
random_flip: False
# If cache_latents_text_encoder_outputs is True
# the num_proc is set to 1
tokenize_captions_num_proc: 4
transform_images_num_proc: 4
reuse_example_batch: False
enable_data_shuffling: True
# Defines the type of gradient checkpoint to enable.
# NONE - means no gradient checkpoint
# FULL - means full gradient checkpoint, whenever possible (minimum memory usage)
# MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation,
# except for ones that involve batch dimension - that means that all attention and projection
# layers will have gradient checkpoint, but not the backward with respect to the parameters.
# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH but offload instead of recomputing.
# CUSTOM - set names to offload and save.
remat_policy: "NONE"
# For CUSTOM policy set below, current annotations are for: attn_output, query_proj, key_proj, value_proj
# xq_out, xk_out, ffn_activation
names_which_can_be_saved: []
names_which_can_be_offloaded: []
# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
checkpoint_dir: ""
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False
# Training loop
learning_rate: 1.e-5
scale_lr: False
max_train_samples: -1
# max_train_steps takes priority over num_train_epochs.
max_train_steps: 1500
num_train_epochs: 1
seed: 0
output_dir: 'sdxl-model-finetuned'
per_device_batch_size: 1.0
# If global_batch_size % jax.device_count is not 0, use FSDP sharding.
global_batch_size: 0
# For creating tfrecords from dataset
tfrecords_dir: ''
no_records_per_shard: 0
enable_eval_timesteps: False
timesteps_list: [125, 250, 375, 500, 625, 750, 875]
num_eval_samples: 420
warmup_steps_fraction: 0.1
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
save_optimizer: False
# Note on learning_rate_schedule_steps: you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
# AdamW optimizer parameters
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
adam_weight_decay: 0.0 # AdamW Weight decay
opt_enable_grad_clipping: False
max_grad_value: 1.0
opt_enable_grad_global_norm_clipping: False
max_grad_norm: 1.0
enable_profiler: False
# Skip first n steps for profiling, to omit things like compilation and to give
# the iteration time a chance to stabilize.
skip_first_n_steps_for_profiler: 5
profiler_steps: 10
# Enable JAX named scopes for detailed profiling and debugging
# When enabled, adds named scopes around key operations in transformer and attention layers
enable_jax_named_scopes: False
# Generation parameters
prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." #LoRA prompt "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." #LoRA prompt "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
do_classifier_free_guidance: True
height: 720
width: 1280
num_frames: 81
flow_shift: 5.0
# Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py
# NOTE(review): the link above is the T2V config although model_type in this file is 'I2V';
# presumably the I2V counterpart (wan_i2v_A14B.py) is the relevant reference — TODO confirm the values below against it.
# Guidance scale factor for the low noise transformer.
guidance_scale_low: 3.0
# Guidance scale factor for the high noise transformer.
guidance_scale_high: 4.0
# The timestep threshold used to switch between the low noise and high noise transformer.
# If `t` is at or above this value, the `high_noise_model` is considered as the required model.
boundary_ratio: 0.875
# Diffusion CFG cache (FasterCache-style)
use_cfg_cache: False
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
use_sen_cache: False
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 50
fps: 16
save_final_checkpoint: False
# SDXL Lightning parameters
lightning_from_pt: True
# Empty or "ByteDance/SDXL-Lightning" to enable lightning.
lightning_repo: ""
# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning.
lightning_ckpt: ""
# LoRA parameters
enable_lora: False
# Every value is a list, to support loading multiple LoRAs during inference in the future.
lora_config:
  rank:
    - 64
    - 16
  lora_model_name_or_path:
    - "lightx2v/Wan2.2-Distill-Loras"
    - "ostris/wan22_i2v_14b_orbit_shot_lora"
  high_noise_weight_name:
    - "wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors"
    - "wan22_14b_i2v_orbit_high_noise.safetensors"
  # Empty or "wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors"
  low_noise_weight_name:
    - "wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors"
    - "wan22_14b_i2v_orbit_low_noise.safetensors"
  adapter_name:
    - "wan22-distill-lora"
    - "wan22-orbit-lora"
  scale:
    - 1.0
    - 1.0
  # NOTE(review): left empty while the other lists carry two entries —
  # confirm how the loader treats missing from_pt entries.
  from_pt: []
# Ex with values:
# lora_config : {
# lora_model_name_or_path: ["ByteDance/Hyper-SD"],
# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"],
# adapter_name: ["hyper-sdxl"],
# scale: [0.7],
# from_pt: [True]
# }
enable_mllog: False
#controlnet
controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
controlnet_from_pt: True
controlnet_conditioning_scale: 0.5
controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
quantization: ''
# Shard the range finding operation for quantization. By default this is set to number of slices.
quantization_local_shard_count: -1
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix.
# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
quantization_calibration_method: "absmax"
qwix_module_path: ".*"
# Evaluate the model every eval_every steps. -1 means don't eval.
eval_every: -1
eval_data_dir: ""
enable_generate_video_for_eval: False # This will increase the used TPU memory.
eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list).
enable_ssim: False
# i2v specific parameters
# I2V Input Image
# URL or local path to the conditioning image
image_url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
# ML Diagnostics settings
enable_ml_diagnostics: False
profiler_gcs_path: ""
enable_ondemand_xprof: False