Commit f6ae9ae

Adding I2V lora configs

1 parent c199584 commit f6ae9ae

2 files changed: 703 additions & 0 deletions

Lines changed: 345 additions & 0 deletions

@@ -0,0 +1,345 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This sentinel is a reminder to choose a real run name.
run_name: ''

metrics_file: "" # For testing: a local file that stores scalar metrics. If empty, no metrics are written.
# If true, saves metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
write_metrics: True

timing_metrics_file: "" # For testing: a local file that stores function timing metrics such as state creation and compilation. If empty, no metrics are written.
write_timing_metrics: True

gcs_metrics: False
# If true, saves the config to GCS in {base_output_directory}/{run_name}/
save_config_to_gcs: False
log_period: 100

pretrained_model_name_or_path: 'Wan-AI/Wan2.1-I2V-14B-720P-Diffusers'
model_name: wan2.1
model_type: 'I2V'

# Overrides the transformer from pretrained_model_name_or_path.
wan_transformer_pretrained_model_name_or_path: ''

unet_checkpoint: ''
revision: ''
# This will convert the weights to this dtype.
# When running inference on TPU v5e, use weights_dtype: 'bfloat16'
weights_dtype: 'bfloat16'
# This sets the layer's dtype in the model, e.g. nn.Dense(dtype=activations_dtype)
activations_dtype: 'bfloat16'

# Replicates the VAE across devices instead of using the model's sharding annotations.
replicate_vae: False

# Matmul and conv precision, from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST".
# fp32 activations and fp32 weights with HIGHEST provide the best precision,
# at the cost of time.
precision: "DEFAULT"
# Use jax.lax.scan for transformer layers.
scan_layers: True

# If False, the state is not jitted and replicate is called instead, which is useful for debugging on a single host.
# It must be True for multi-host.
jit_initializers: True

# Set to True to load weights from PyTorch.
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
flash_min_seq_length: 4096
dropout: 0.1

# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
# Otherwise we do not pass segment ids, which is faster on VPU-bound hardware such as Trillium.
# However, when padding tokens are significant, skipping them hurts quality, so this should then be set to True.
mask_padding_tokens: True
attention_sharding_uniform: True

flash_block_sizes: {
  "block_q" : 2048,
  "block_kv_compute" : 512,
  "block_kv" : 2048,
  "block_q_dkv" : 2048,
  "block_kv_dkv" : 2048,
  "block_kv_dkv_compute" : 512,
  "use_fused_bwd_kernel" : True
}
# Use on v6e:
# flash_block_sizes: {
#   "block_q" : 3024,
#   "block_kv_compute" : 1024,
#   "block_kv" : 2048,
#   "block_q_dkv" : 3024,
#   "block_kv_dkv" : 2048,
#   "block_kv_dkv_compute" : 2048,
#   "block_q_dq" : 3024,
#   "block_kv_dq" : 2048,
#   "use_fused_bwd_kernel": False,
# }
# GroupNorm groups
norm_num_groups: 32

# Train the text encoder - currently not supported for SDXL.
train_text_encoder: False
text_encoder_learning_rate: 4.25e-6

# https://arxiv.org/pdf/2305.08891.pdf
snr_gamma: -1.0

timestep_bias: {
  # Strategy: one of none, earlier, later, or range.
  # A value of 'later' will increase the frequency of the model's final training timesteps.
  strategy: "none",
  # Multiplier for the bias: a value of 2.0 will double the weight of the bias, and 0.5 will halve it.
  multiplier: 1.0,
  # When using strategy=range, the beginning (inclusive) timestep to bias.
  begin: 0,
  # When using strategy=range, the final (inclusive) timestep to bias.
  end: 1000,
  # Portion of timesteps to bias.
  # 0.5 will bias one half of the timesteps. The value of strategy determines
  # whether the biased portion lies in the earlier or later timesteps.
  portion: 0.25
}
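# Ex with values (illustrative, not a tuned recommendation): bias the last
# quarter of the timesteps, sampling them twice as often:
# timestep_bias: {
#   strategy: "later",
#   multiplier: 2.0,
#   begin: 0,
#   end: 1000,
#   portion: 0.25
# }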

# Overrides parameters from the checkpoint's scheduler.
# Don't override _class_name - use the pretrained UniPCMultistepScheduler.
diffusion_scheduler_config: {
  prediction_type: 'flow_prediction',
  rescale_zero_terminal_snr: False,
  timestep_spacing: 'linspace'
}

# Output directory
# Create a GCS bucket, e.g. my-maxtext-outputs, and set this to "gs://my-maxtext-outputs/"
base_output_directory: ""

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']

# batch        : batch dimension of data and activations
# hidden       :
# embed        : attention qkv dense layer hidden dim, named embed
# heads        : attention head dim = num_heads * head_dim
# length       : attention sequence length
# temb_in      : dense.shape[0] of resnet dense before conv
# out_c        : dense.shape[1] of resnet dense before conv
# out_channels : conv.shape[-1] activation
# keep_1       : conv.shape[0] weight
# keep_2       : conv.shape[1] weight
# conv_in      : conv.shape[2] weight
# conv_out     : conv.shape[-1] weight
logical_axis_rules: [
  ['batch', 'data'],
  ['activation_batch', 'data'],
  ['activation_self_attn_heads', ['fsdp', 'tensor']],
  ['activation_cross_attn_q_length', ['fsdp', 'tensor']],
  ['activation_length', 'fsdp'],
  ['activation_heads', 'tensor'],
  ['mlp', 'tensor'],
  ['embed', 'fsdp'],
  ['heads', 'tensor'],
  ['norm', 'tensor'],
  ['conv_batch', ['data', 'fsdp']],
  ['out_channels', 'tensor'],
  ['conv_out', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
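# Ex: the rule ['embed', 'fsdp'] above maps any tensor dimension annotated with
# the logical name 'embed' onto the physical 'fsdp' mesh axis; logical names
# with no matching rule are left unsharded (replicated).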

# One axis for each parallelism type may hold a placeholder (-1)
# value, to auto-shard based on available slices and devices.
# By default, the product of the DCN axes should equal the number of slices,
# and the product of the ICI axes should equal the number of devices per slice.
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

allow_split_physical_axes: False
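# Ex (illustrative, not part of this config): on two slices with 256 chips each,
# dcn_data_parallelism: 2 makes the product of the DCN axes equal the 2 slices,
# and ici_fsdp_parallelism: 256 makes the product of the ICI axes equal the 256
# devices per slice; leaving either at -1 auto-shards that axis as described above.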

# Dataset
# Set dataset_name to a dataset path, or set train_data_dir; one of the two must be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tfrecord'
cache_latents_text_encoder_outputs: True
# cache_latents_text_encoder_outputs only applies to dataset_type="tf",
# and only to small datasets that fit in memory.
# It precomputes image latents and text encoder outputs,
# reducing memory consumption and step time during training.
# The transformed dataset is saved at dataset_save_location.
dataset_save_location: ''
load_tfrecord_cached: True
train_data_dir: ''
dataset_config_name: ''
jax_cache_dir: ''
hf_data_dir: ''
hf_train_files: ''
hf_access_token: ''
image_column: 'image'
caption_column: 'text'
resolution: 1024
center_crop: False
random_flip: False
# If cache_latents_text_encoder_outputs is True,
# num_proc is set to 1.
tokenize_captions_num_proc: 4
transform_images_num_proc: 4
reuse_example_batch: False
enable_data_shuffling: True

# Defines the type of gradient checkpointing to enable.
# NONE - no gradient checkpointing.
# FULL - full gradient checkpointing wherever possible (minimum memory usage).
# MATMUL_WITHOUT_BATCH - gradient checkpointing for every linear/matmul operation,
#   except those that involve the batch dimension - that means all attention and projection
#   layers have gradient checkpointing, but not the backward pass with respect to the parameters.
# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH, but offloads instead of recomputing.
# CUSTOM - set the names to offload and save below.
remat_policy: "NONE"
# For the CUSTOM policy set below, current annotations are for: attn_output, query_proj, key_proj, value_proj,
# xq_out, xk_out, ffn_activation.
names_which_can_be_saved: []
names_which_can_be_offloaded: []
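# Ex (an illustrative combination using the annotation names listed above):
# remat_policy: "CUSTOM"
# names_which_can_be_saved: ["attn_output"]
# names_which_can_be_offloaded: ["query_proj", "key_proj", "value_proj"]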

# Checkpoint every N samples; -1 means don't checkpoint.
checkpoint_every: -1
checkpoint_dir: ""
# Enables one replica to read the checkpoint and then broadcast it to the rest.
enable_single_replica_ckpt_restoring: False

# Training loop
learning_rate: 1.e-5
scale_lr: False
max_train_samples: -1
# max_train_steps takes priority over num_train_epochs.
max_train_steps: 1500
num_train_epochs: 1
seed: 0
output_dir: 'sdxl-model-finetuned'
per_device_batch_size: 1.0
# If global_batch_size % jax.device_count() is not 0, use FSDP sharding.
global_batch_size: 0
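# Ex (illustrative device count): per_device_batch_size: 1.0 on 8 devices gives
# a global batch of 8; a fractional per_device_batch_size of 0.5 would give 4,
# and since 4 % 8 != 0 that case uses FSDP sharding as noted above.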

# For creating tfrecords from a dataset
tfrecords_dir: ''
no_records_per_shard: 0
enable_eval_timesteps: False
timesteps_list: [125, 250, 375, 500, 625, 750, 875]
num_eval_samples: 420

warmup_steps_fraction: 0.1
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
save_optimizer: False

# However, you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case training will end before
# the learning rate has fully decayed. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
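# Ex: with max_train_steps: 1500 and warmup_steps_fraction: 0.1, the first
# 0.1 * 1500 = 150 steps warm the learning rate up, and it then drops over the
# remaining steps (1500 total here, since learning_rate_schedule_steps is -1).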

# AdamW optimizer parameters
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
adam_eps: 1.e-8 # A small constant applied to the denominator outside of the square root.
adam_weight_decay: 0 # AdamW weight decay
max_grad_norm: 1.0

enable_profiler: False
# Skip the first n steps for profiling, to omit things like compilation and to give
# the iteration time a chance to stabilize.
skip_first_n_steps_for_profiler: 5
profiler_steps: 10

# Enable JAX named scopes for detailed profiling and debugging.
# When enabled, adds named scopes around key operations in transformer and attention layers.
enable_jax_named_scopes: False

# Generation parameters
prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
do_classifier_free_guidance: True
height: 480
width: 832
num_frames: 81
guidance_scale: 5.0
flow_shift: 3.0

# Based on section 3.4 of https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 30
fps: 24
save_final_checkpoint: False
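# Ex: with num_frames: 81 and fps: 24, the saved clip runs 81 / 24 ≈ 3.4 s
# (assuming every generated frame is written once at the listed fps).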

# SDXL Lightning parameters
lightning_from_pt: True
# Empty or "ByteDance/SDXL-Lightning" to enable lightning.
lightning_repo: ""
# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning.
lightning_ckpt: ""

# LoRA parameters
# Values are lists to support multiple LoRA loading during inference in the future.
lora_config: {
  lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras"],
  weight_name: ["wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors"],
  adapter_name: ["wan21-distill-lora-i2v"],
  scale: [],
  from_pt: []
}
# Ex with values:
# lora_config: {
#   lora_model_name_or_path: ["ByteDance/Hyper-SD"],
#   weight_name: ["Hyper-SDXL-2steps-lora.safetensors"],
#   adapter_name: ["hyper-sdxl"],
#   scale: [0.7],
#   from_pt: [True]
# }

enable_mllog: False

# ControlNet
controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
controlnet_from_pt: True
controlnet_conditioning_scale: 0.5
controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
quantization: ''
# Shard the range-finding operation for quantization. By default this is set to the number of slices.
quantization_local_shard_count: -1
compile_topology_num_slices: -1 # Number of target slices; set to a positive integer.
use_qwix_quantization: False # Whether to use qwix for quantization. If True, the WAN transformer is quantized using qwix.
# Quantization calibration method used for weights and activations. Supported methods can be found at https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
quantization_calibration_method: "absmax"
qwix_module_path: ".*"
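# Ex (illustrative; assumes 'int8' is an accepted value for `quantization`):
# use_qwix_quantization: True
# quantization: 'int8'
# qwix_module_path: ".*" # quantize all modules of the WAN transformer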

# Evaluate the model every eval_every steps; -1 means don't eval.
eval_every: -1
eval_data_dir: ""
enable_generate_video_for_eval: False # This will increase TPU memory usage.
eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation, calculated as num_eval_samples / len(timesteps_list).
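# Ex: with num_eval_samples: 420 and the 7 entries in timesteps_list above,
# 420 / 7 = 60 samples per bucket, matching the value set here.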

enable_ssim: False

# i2v-specific parameters
# I2V input image
# URL or local path to the conditioning image.
image_url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
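# Usage sketch (assumption: like other MaxDiffusion configs, this file is passed
# as the first argument to the training script with key=value overrides appended;
# the script path and override names below are illustrative, not confirmed by this commit):
#   python src/maxdiffusion/train.py path/to/this_config.yml \
#     run_name="my-i2v-lora-run" base_output_directory="gs://my-bucket/"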
