-
Notifications
You must be signed in to change notification settings - Fork 69
Multi res support #159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Multi res support #159
Changes from 7 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
e177935
fixes accidental wrong merge.
jfacevedo-google 7d3dbd4
precompile generate functions with different dimensions.
jfacevedo-google 3b40223
iterate over different resolutions and store precompiled functions in…
jfacevedo-google e2e2c50
adds a new inference file to show how to precompile for different res…
jfacevedo-google c8b71f1
Merge branch 'main' into multi_res_support
jfacevedo-google e0f8163
formatting
jfacevedo-google c39a741
remove unused dependencies.
jfacevedo-google 91826cf
update config with flux names
jfacevedo-google a53b9bc
Merge branch 'main' into multi_res_support
jfacevedo-google 4f230da
revert to torch 2.5.1 due to compatibility with torchvision.
jfacevedo-google File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,271 @@ | ||
| # Copyright 2023 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # This sentinel is a reminder to choose a real run name. | ||
| run_name: '' | ||
|
|
||
| metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. | ||
| # If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ | ||
| write_metrics: True | ||
| gcs_metrics: False | ||
| # If true save config to GCS in {base_output_directory}/{run_name}/ | ||
| save_config_to_gcs: False | ||
| log_period: 100 | ||
|
|
||
| pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev' | ||
| clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax' | ||
| t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax' | ||
|
|
||
| # Flux params | ||
| flux_name: "flux-dev" | ||
| max_sequence_length: 512 | ||
| time_shift: True | ||
| base_shift: 0.5 | ||
| max_shift: 1.15 | ||
| # offloads t5 encoder after text encoding to save memory. | ||
| offload_encoders: True | ||
|
|
||
|
|
||
| unet_checkpoint: '' | ||
| revision: 'refs/pr/95' | ||
| # This will convert the weights to this dtype. | ||
| # When running inference on TPUv5e, use weights_dtype: 'bfloat16' | ||
| weights_dtype: 'bfloat16' | ||
| # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) | ||
| activations_dtype: 'bfloat16' | ||
|
|
||
| # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision | ||
| # Options are "DEFAULT", "HIGH", "HIGHEST" | ||
| # fp32 activations and fp32 weights with HIGHEST will provide the best precision | ||
| # at the cost of time. | ||
| precision: "DEFAULT" | ||
|
|
||
| # if False state is not jitted and instead replicate is called. This is good for debugging on single host | ||
| # It must be True for multi-host. | ||
| jit_initializers: True | ||
|
|
||
| # Set true to load weights from pytorch | ||
| from_pt: True | ||
| split_head_dim: True | ||
| attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te | ||
|
|
||
| #flash_block_sizes: {} | ||
| # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem. | ||
| flash_block_sizes: { | ||
| "block_q" : 1536, | ||
| "block_kv_compute" : 1536, | ||
| "block_kv" : 1536, | ||
| "block_q_dkv" : 1536, | ||
| "block_kv_dkv" : 1536, | ||
| "block_kv_dkv_compute" : 1536, | ||
| "block_q_dq" : 1536, | ||
| "block_kv_dq" : 1536 | ||
| } | ||
| # GroupNorm groups | ||
| norm_num_groups: 32 | ||
|
|
||
| # If train_new_unet, unet weights will be randomly initialized to train the unet from scratch | ||
| # else they will be loaded from pretrained_model_name_or_path | ||
| train_new_unet: False | ||
|
|
||
| # train text_encoder - Currently not supported for SDXL | ||
| train_text_encoder: False | ||
| text_encoder_learning_rate: 4.25e-6 | ||
|
|
||
| # https://arxiv.org/pdf/2305.08891.pdf | ||
| snr_gamma: -1.0 | ||
|
|
||
| timestep_bias: { | ||
| # a value of later will increase the frequence of the model's final training steps. | ||
| # none, earlier, later, range | ||
| strategy: "none", | ||
| # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. | ||
| multiplier: 1.0, | ||
| # when using strategy=range, the beginning (inclusive) timestep to bias. | ||
| begin: 0, | ||
| # when using strategy=range, the final step (inclusive) to bias. | ||
| end: 1000, | ||
| # portion of timesteps to bias. | ||
| # 0.5 will bias one half of the timesteps. Value of strategy determines | ||
| # whether the biased portions are in the earlier or later timesteps. | ||
| portion: 0.25 | ||
| } | ||
|
|
||
| # Override parameters from checkpoints's scheduler. | ||
| diffusion_scheduler_config: { | ||
| _class_name: 'FlaxEulerDiscreteScheduler', | ||
| prediction_type: 'epsilon', | ||
| rescale_zero_terminal_snr: False, | ||
| timestep_spacing: 'trailing' | ||
| } | ||
|
|
||
| # Output directory | ||
| # Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" | ||
| base_output_directory: "" | ||
|
|
||
| # Hardware | ||
| hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' | ||
|
|
||
| # Parallelism | ||
| mesh_axes: ['data', 'fsdp', 'tensor'] | ||
|
|
||
| # batch : batch dimension of data and activations | ||
| # hidden : | ||
| # embed : attention qkv dense layer hidden dim named as embed | ||
| # heads : attention head dim = num_heads * head_dim | ||
| # length : attention sequence length | ||
| # temb_in : dense.shape[0] of resnet dense before conv | ||
| # out_c : dense.shape[1] of resnet dense before conv | ||
| # out_channels : conv.shape[-1] activation | ||
| # keep_1 : conv.shape[0] weight | ||
| # keep_2 : conv.shape[1] weight | ||
| # conv_in : conv.shape[2] weight | ||
| # conv_out : conv.shape[-1] weight | ||
| logical_axis_rules: [ | ||
| ['batch', 'data'], | ||
| ['activation_batch', ['data','fsdp']], | ||
| ['activation_heads', 'tensor'], | ||
| ['activation_kv', 'tensor'], | ||
| # ['embed','fsdp'], | ||
| ['mlp',['fsdp','tensor']], | ||
| ['heads', 'tensor'], | ||
| ['conv_batch', ['data','fsdp']], | ||
| ['out_channels', 'tensor'], | ||
| ['conv_out', 'fsdp'], | ||
| ] | ||
| data_sharding: [['data', 'fsdp', 'tensor']] | ||
|
|
||
| # One axis for each parallelism type may hold a placeholder (-1) | ||
| # value to auto-shard based on available slices and devices. | ||
| # By default, product of the DCN axes should equal number of slices | ||
| # and product of the ICI axes should equal number of devices per slice. | ||
| dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded | ||
| dcn_fsdp_parallelism: -1 | ||
| dcn_tensor_parallelism: 1 | ||
| ici_data_parallelism: -1 | ||
| ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded | ||
| ici_tensor_parallelism: 1 | ||
|
|
||
| # Dataset | ||
| # Replace with dataset path or train_data_dir. One has to be set. | ||
| dataset_name: 'diffusers/pokemon-gpt4-captions' | ||
| train_split: 'train' | ||
| dataset_type: 'tf' | ||
| cache_latents_text_encoder_outputs: True | ||
| # cache_latents_text_encoder_outputs only apply to dataset_type="tf", | ||
| # only apply to small dataset that fits in memory | ||
| # prepare image latents and text encoder outputs | ||
| # Reduce memory consumption and reduce step time during training | ||
| # transformed dataset is saved at dataset_save_location | ||
| dataset_save_location: '/tmp/pokemon-gpt4-captions_xl' | ||
| train_data_dir: '' | ||
| dataset_config_name: '' | ||
| jax_cache_dir: '' | ||
| hf_data_dir: '' | ||
| hf_train_files: '' | ||
| hf_access_token: '' | ||
| image_column: 'image' | ||
| caption_column: 'text' | ||
| resolution: 1024 | ||
| center_crop: False | ||
| random_flip: False | ||
| # If cache_latents_text_encoder_outputs is True | ||
| # the num_proc is set to 1 | ||
| tokenize_captions_num_proc: 4 | ||
| transform_images_num_proc: 4 | ||
| reuse_example_batch: False | ||
| enable_data_shuffling: True | ||
|
|
||
| # checkpoint every number of samples, -1 means don't checkpoint. | ||
| checkpoint_every: -1 | ||
| # enables one replica to read the ckpt then broadcast to the rest | ||
| enable_single_replica_ckpt_restoring: False | ||
|
|
||
| # Training loop | ||
| learning_rate: 4.e-7 | ||
| scale_lr: False | ||
| max_train_samples: -1 | ||
| # max_train_steps takes priority over num_train_epochs. | ||
| max_train_steps: 200 | ||
| num_train_epochs: 1 | ||
| seed: 0 | ||
| output_dir: 'sdxl-model-finetuned' | ||
| per_device_batch_size: 1 | ||
|
|
||
| warmup_steps_fraction: 0.0 | ||
| learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. | ||
|
|
||
| # However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before | ||
| # dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. | ||
|
|
||
| # AdamW optimizer parameters | ||
| adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. | ||
| adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. | ||
| adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. | ||
| adam_weight_decay: 1.e-2 # AdamW Weight decay | ||
| max_grad_norm: 1.0 | ||
|
|
||
| enable_profiler: False | ||
| # Skip first n steps for profiling, to omit things like compilation and to give | ||
| # the iteration time a chance to stabilize. | ||
| skip_first_n_steps_for_profiler: 5 | ||
| profiler_steps: 10 | ||
|
|
||
| # Generation parameters | ||
| prompt: "A magical castle in the middle of a forest, artistic drawing" | ||
| prompt_2: "A magical castle in the middle of a forest, artistic drawing" | ||
| negative_prompt: "purple, red" | ||
| do_classifier_free_guidance: True | ||
| guidance_scale: 3.5 | ||
| # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf | ||
| guidance_rescale: 0.0 | ||
| num_inference_steps: 50 | ||
|
|
||
| # SDXL Lightning parameters | ||
| lightning_from_pt: True | ||
| # Empty or "ByteDance/SDXL-Lightning" to enable lightning. | ||
| lightning_repo: "" | ||
| # Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. | ||
| lightning_ckpt: "" | ||
|
|
||
| # LoRA parameters | ||
| # Values are lists to support multiple LoRA loading during inference in the future. | ||
| lora_config: { | ||
| lora_model_name_or_path: [], | ||
| weight_name: [], | ||
| adapter_name: [], | ||
| scale: [], | ||
| from_pt: [] | ||
| } | ||
| # Ex with values: | ||
| # lora_config : { | ||
| # lora_model_name_or_path: ["ByteDance/Hyper-SD"], | ||
| # weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], | ||
| # adapter_name: ["hyper-sdxl"], | ||
| # scale: [0.7], | ||
| # from_pt: [True] | ||
| # } | ||
|
|
||
| enable_mllog: False | ||
|
|
||
| #controlnet | ||
| controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' | ||
| controlnet_from_pt: True | ||
| controlnet_conditioning_scale: 0.5 | ||
| controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' | ||
| quantization: '' | ||
| # Shard the range finding operation for quantization. By default this is set to number of slices. | ||
| quantization_local_shard_count: -1 | ||
| compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should this be train_new_flux?