Skip to content

Commit b87443f

Browse files
clean up code and lint
1 parent (8905362) · commit b87443f

15 files changed

Lines changed: 352 additions & 648 deletions

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ optax>=0.2.3
99
torch>=2.3.1
1010
torchvision>=0.18.1
1111
ftfy
12-
tensorboard==2.17.0
12+
tensorboard>=2.17.0
1313
tensorboardx==2.6.2.2
1414
tensorboard-plugin-profile==2.15.2
1515
Jinja2
@@ -25,7 +25,7 @@ ruff>=0.1.5,<=0.2
2525
git+https://github.com/mlperf/logging.git
2626
opencv-python-headless==4.10.0.84
2727
orbax-checkpoint==0.10.2
28-
tokenizers==0.20.0
28+
tokenizers==0.21.0
2929
huggingface_hub==0.24.7
3030
transformers==4.48.1
3131
einops==0.8.0

src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -330,10 +330,7 @@ def load_checkpoint(self, step=None, scheduler_class=None):
330330
if self.checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT:
331331
te_pretrained_2_config = CLIPTextConfig(**model_configs[0]["text_encoder_2_config"])
332332
text_encoder_2 = FlaxCLIPTextModelWithProjection(
333-
te_pretrained_2_config,
334-
seed=self.config.seed,
335-
dtype=self.config.activations_dtype,
336-
_do_init=False
333+
te_pretrained_2_config, seed=self.config.seed, dtype=self.config.activations_dtype, _do_init=False
337334
)
338335
pipeline_kwargs["text_encoder_2"] = text_encoder_2
339336
# both tokenizers in sdxl are the same.

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,27 @@ precision: "DEFAULT"
5454
from_pt: True
5555
split_head_dim: True
5656
attention: 'flash' # Supported attention: dot_product, flash
57-
flash_block_sizes: {}
58-
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
57+
flash_block_sizes: {
58+
"block_q" : 256,
59+
"block_kv_compute" : 256,
60+
"block_kv" : 256,
61+
"block_q_dkv" : 256,
62+
"block_kv_dkv" : 256,
63+
"block_kv_dkv_compute" : 256,
64+
"block_q_dq" : 256,
65+
"block_kv_dq" : 256
66+
}
67+
68+
# Use the following flash_block_sizes on v6e (Trillium).
5969
# flash_block_sizes: {
60-
# "block_q" : 1536,
61-
# "block_kv_compute" : 1536,
62-
# "block_kv" : 1536,
63-
# "block_q_dkv" : 1536,
64-
# "block_kv_dkv" : 1536,
65-
# "block_kv_dkv_compute" : 1536,
66-
# "block_q_dq" : 1536,
67-
# "block_kv_dq" : 1536
70+
# "block_q" : 2176,
71+
# "block_kv_compute" : 2176,
72+
# "block_kv" : 2176,
73+
# "block_q_dkv" : 2176,
74+
# "block_kv_dkv" : 2176,
75+
# "block_kv_dkv_compute" : 2176,
76+
# "block_q_dq" : 2176,
77+
# "block_kv_dq" : 2176
6878
# }
6979
# GroupNorm groups
7080
norm_num_groups: 32

0 commit comments

Comments (0)