Skip to content

Commit 1ea6590

Browse files
committed
ruff check
2 parents d06dee3 + 6c52603 commit 1ea6590

36 files changed

Lines changed: 4186 additions & 624 deletions

.github/workflows/UnitTests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ jobs:
5050
ruff check .
5151
- name: PyTest
5252
run: |
53-
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest
53+
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x
5454
# add_pull_ready:
5555
# if: github.ref != 'refs/heads/main'
5656
# permissions:

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@ MaxDiffusion supports
3535
* Stable Diffusion 2 base (training and inference)
3636
* Stable Diffusion 2.1 (training and inference)
3737
* Stable Diffusion XL (training and inference).
38+
* Flux Dev and Schnell (Training and inference).
3839
* Stable Diffusion Lightning (inference).
3940
* Hyper-SD XL LoRA loading (inference).
4041
* Load Multiple LoRA (SDXL inference).
4142
* ControlNet inference (Stable Diffusion 1.4 & SDXL).
4243
* Dreambooth training support for Stable Diffusion 1.x, 2.x.
4344

44-
**WARNING: The training code is purely experimental and is under development.**
4545

4646
# Table of Contents
4747

requirements.txt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
jax>=0.4.30
1+
--extra-index-url https://download.pytorch.org/whl/cpu
2+
jax==0.5.3
23
jaxlib>=0.4.30
34
grain-nightly==0.0.10
45
google-cloud-storage==2.17.0
56
absl-py
67
datasets
78
flax>=0.10.2
89
optax>=0.2.3
9-
torch==2.5.1
10-
torchvision==0.20.1
10+
torch==2.6.0
11+
torchvision>=0.20.1
1112
ftfy
1213
tensorboard>=2.17.0
1314
tensorboardx==2.6.2.2
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from abc import ABC
18+
from flax import nnx
19+
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
20+
from ..pipelines.wan.wan_pipeline import WanPipeline
21+
from .. import max_logging, max_utils
22+
23+
WAN_CHECKPOINT = "WAN_CHECKPOINT"
24+
25+
26+
class WanCheckpointer(ABC):
  """Manages Orbax checkpointing for WAN text-to-video pipelines.

  Creates an Orbax checkpoint manager at construction time and provides
  helpers to build an optimizer and to restore a pipeline, falling back to
  the pretrained diffusers weights when no Orbax checkpoint is available.
  """

  def __init__(self, config, checkpoint_type):
    self.config = config
    self.checkpoint_type = checkpoint_type

    # save_interval_steps=1: the manager is willing to save every step; the
    # training loop decides when save() is actually called.
    self.checkpoint_manager = create_orbax_checkpoint_manager(
        self.config.checkpoint_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=checkpoint_type,
        dataset_type=config.dataset_type,
    )

  def _create_optimizer(self, model, config, learning_rate):
    """Builds the training optimizer for `model`.

    Returns:
      A (nnx.Optimizer, learning_rate_scheduler) tuple; the scheduler is
      returned separately so callers can log the current learning rate.
    """
    learning_rate_scheduler = max_utils.create_learning_rate_schedule(
        learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
    )
    tx = max_utils.create_optimizer(config, learning_rate_scheduler)
    return nnx.Optimizer(model, tx), learning_rate_scheduler

  def load_wan_configs_from_orbax(self, step):
    """Resolves which checkpoint step (if any) could be restored.

    Args:
      step: Explicit step to restore, or None to use the latest step.

    Returns:
      None in every case for now — restoring model configs from Orbax is
      not implemented yet, and `load_checkpoint` treats None as "fall back
      to the diffusers weights".
    """
    # Fixed: message previously said "stable diffusion" — copy-paste error
    # in the WAN checkpointer.
    max_logging.log("Restoring WAN checkpoint configs")
    if step is None:
      step = self.checkpoint_manager.latest_step()
    if step is None:
      return None
    # Fixed: the found-step path previously fell off the end implicitly;
    # make the (still unimplemented) result explicit.
    return None

  def load_diffusers_checkpoint(self):
    """Loads the pipeline from the pretrained diffusers weights."""
    pipeline = WanPipeline.from_pretrained(self.config)
    return pipeline

  def load_checkpoint(self, step=None):
    """Restores a WAN pipeline.

    Args:
      step: Optional checkpoint step; None means the latest available.

    Returns:
      A WanPipeline instance.

    Raises:
      NotImplementedError: if Orbax model configs are unexpectedly found
        (Orbax restore is not supported yet).
    """
    model_configs = self.load_wan_configs_from_orbax(step)

    if model_configs:
      raise NotImplementedError("model configs should not exist in orbax")
    pipeline = self.load_diffusers_checkpoint()
    return pipeline
Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,22 @@ run_name: ''
1818
metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
1919
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
2020
write_metrics: True
21+
22+
timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
23+
write_timing_metrics: True
24+
2125
gcs_metrics: False
2226
# If true save config to GCS in {base_output_directory}/{run_name}/
2327
save_config_to_gcs: False
2428
log_period: 100
2529

2630
pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
2731

28-
# Flux params
29-
flux_name: "flux-dev"
30-
max_sequence_length: 512
31-
time_shift: True
32-
base_shift: 0.5
33-
max_shift: 1.15
34-
# offloads t5 encoder after text encoding to save memory.
35-
offload_encoders: True
36-
32+
# Overrides the transformer from pretrained_model_name_or_path
33+
wan_transformer_pretrained_model_name_or_path: ''
3734

3835
unet_checkpoint: ''
39-
revision: 'refs/pr/95'
36+
revision: ''
4037
# This will convert the weights to this dtype.
4138
# When running inference on TPUv5e, use weights_dtype: 'bfloat16'
4239
weights_dtype: 'bfloat16'
@@ -59,24 +56,9 @@ split_head_dim: True
5956
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
6057

6158
flash_block_sizes: {}
62-
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
63-
# flash_block_sizes: {
64-
# "block_q" : 1536,
65-
# "block_kv_compute" : 1536,
66-
# "block_kv" : 1536,
67-
# "block_q_dkv" : 1536,
68-
# "block_kv_dkv" : 1536,
69-
# "block_kv_dkv_compute" : 1536,
70-
# "block_q_dq" : 1536,
71-
# "block_kv_dq" : 1536
72-
# }
7359
# GroupNorm groups
7460
norm_num_groups: 32
7561

76-
# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch
77-
# else they will be loaded from pretrained_model_name_or_path
78-
train_new_unet: False
79-
8062
# train text_encoder - Currently not supported for SDXL
8163
train_text_encoder: False
8264
text_encoder_learning_rate: 4.25e-6
@@ -133,15 +115,17 @@ mesh_axes: ['data', 'fsdp', 'tensor']
133115
# conv_out : conv.shape[-1] weight
134116
logical_axis_rules: [
135117
['batch', 'data'],
118+
['activation_heads', 'fsdp'],
136119
['activation_batch', ['data','fsdp']],
137-
['activation_heads', 'tensor'],
138120
['activation_kv', 'tensor'],
139121
['mlp','tensor'],
140122
['embed','fsdp'],
141123
['heads', 'tensor'],
124+
['norm', 'fsdp'],
142125
['conv_batch', ['data','fsdp']],
143126
['out_channels', 'tensor'],
144127
['conv_out', 'fsdp'],
128+
['conv_in', 'fsdp']
145129
]
146130
data_sharding: [['data', 'fsdp', 'tensor']]
147131

@@ -152,22 +136,23 @@ data_sharding: [['data', 'fsdp', 'tensor']]
152136
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
153137
dcn_fsdp_parallelism: -1
154138
dcn_tensor_parallelism: 1
155-
ici_data_parallelism: -1
156-
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
139+
ici_data_parallelism: 1
140+
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
157141
ici_tensor_parallelism: 1
158142

159143
# Dataset
160144
# Replace with dataset path or train_data_dir. One has to be set.
161145
dataset_name: 'diffusers/pokemon-gpt4-captions'
162146
train_split: 'train'
163-
dataset_type: 'tf'
147+
dataset_type: 'tfrecord'
164148
cache_latents_text_encoder_outputs: True
165149
# cache_latents_text_encoder_outputs only apply to dataset_type="tf",
166150
# only apply to small dataset that fits in memory
167151
# prepare image latents and text encoder outputs
168152
# Reduce memory consumption and reduce step time during training
169153
# transformed dataset is saved at dataset_save_location
170-
dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
154+
dataset_save_location: ''
155+
load_tfrecord_cached: True
171156
train_data_dir: ''
172157
dataset_config_name: ''
173158
jax_cache_dir: ''
@@ -192,17 +177,23 @@ checkpoint_every: -1
192177
enable_single_replica_ckpt_restoring: False
193178

194179
# Training loop
195-
learning_rate: 4.e-7
180+
learning_rate: 1.e-5
196181
scale_lr: False
197182
max_train_samples: -1
198183
# max_train_steps takes priority over num_train_epochs.
199-
max_train_steps: 200
184+
max_train_steps: 1500
200185
num_train_epochs: 1
201186
seed: 0
202187
output_dir: 'sdxl-model-finetuned'
203188
per_device_batch_size: 1
189+
# If global_batch_size % jax.device_count is not 0, use FSDP sharding.
190+
global_batch_size: 0
191+
192+
# For creating tfrecords from dataset
193+
tfrecords_dir: ''
194+
no_records_per_shard: 0
204195

205-
warmup_steps_fraction: 0.0
196+
warmup_steps_fraction: 0.1
206197
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
207198

208199
# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
@@ -212,7 +203,7 @@ learning_rate_schedule_steps: -1 # By default the length of the schedule is set
212203
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
213204
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
214205
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
215-
adam_weight_decay: 1.e-2 # AdamW Weight decay
206+
adam_weight_decay: 0 # AdamW Weight decay
216207
max_grad_norm: 1.0
217208

218209
enable_profiler: False
@@ -222,14 +213,25 @@ skip_first_n_steps_for_profiler: 5
222213
profiler_steps: 10
223214

224215
# Generation parameters
225-
prompt: "A magical castle in the middle of a forest, artistic drawing"
226-
prompt_2: "A magical castle in the middle of a forest, artistic drawing"
227-
negative_prompt: "purple, red"
216+
prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
217+
prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
218+
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
228219
do_classifier_free_guidance: True
229-
guidance_scale: 3.5
220+
height: 480
221+
width: 832
222+
num_frames: 81
223+
guidance_scale: 5.0
224+
flow_shift: 3.0
225+
226+
# skip layer guidance
227+
slg_layers: [9]
228+
slg_start: 0.2
229+
slg_end: 1.0
230230
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
231231
guidance_rescale: 0.0
232-
num_inference_steps: 50
232+
num_inference_steps: 30
233+
fps: 24
234+
save_final_checkpoint: False
233235

234236
# SDXL Lightning parameters
235237
lightning_from_pt: True
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""

0 commit comments

Comments
 (0)