Skip to content

Commit 18250c5

Browse files
jfacevedo-google and ksikiric
authored and committed
clean up code and lint
1 parent ff24ee1 commit 18250c5

15 files changed

Lines changed: 349 additions & 818 deletions

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ ruff>=0.1.5,<=0.2
2525
git+https://github.com/mlperf/logging.git
2626
opencv-python-headless==4.10.0.84
2727
orbax-checkpoint==0.10.2
28-
tokenizers==0.20.0
28+
tokenizers==0.21.0
2929
huggingface_hub==0.24.7
3030
transformers==4.48.1
3131
einops==0.8.0

src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -336,10 +336,7 @@ def load_checkpoint(self, step=None, scheduler_class=None):
336336
if self.checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT:
337337
te_pretrained_2_config = CLIPTextConfig(**model_configs[0]["text_encoder_2_config"])
338338
text_encoder_2 = FlaxCLIPTextModelWithProjection(
339-
te_pretrained_2_config,
340-
seed=self.config.seed,
341-
dtype=self.config.activations_dtype,
342-
_do_init=False
339+
te_pretrained_2_config, seed=self.config.seed, dtype=self.config.activations_dtype, _do_init=False
343340
)
344341
pipeline_kwargs["text_encoder_2"] = text_encoder_2
345342
# both tokenizers in sdxl are the same.

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,27 @@ precision: "DEFAULT"
5454
from_pt: True
5555
split_head_dim: True
5656
attention: 'flash' # Supported attention: dot_product, flash
57-
flash_block_sizes: {}
58-
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
57+
flash_block_sizes: {
58+
"block_q" : 256,
59+
"block_kv_compute" : 256,
60+
"block_kv" : 256,
61+
"block_q_dkv" : 256,
62+
"block_kv_dkv" : 256,
63+
"block_kv_dkv_compute" : 256,
64+
"block_q_dq" : 256,
65+
"block_kv_dq" : 256
66+
}
67+
68+
# Use the following flash_block_sizes on v6e (Trillium).
5969
# flash_block_sizes: {
60-
# "block_q" : 1536,
61-
# "block_kv_compute" : 1536,
62-
# "block_kv" : 1536,
63-
# "block_q_dkv" : 1536,
64-
# "block_kv_dkv" : 1536,
65-
# "block_kv_dkv_compute" : 1536,
66-
# "block_q_dq" : 1536,
67-
# "block_kv_dq" : 1536
70+
# "block_q" : 2176,
71+
# "block_kv_compute" : 2176,
72+
# "block_kv" : 2176,
73+
# "block_q_dkv" : 2176,
74+
# "block_kv_dkv" : 2176,
75+
# "block_kv_dkv_compute" : 2176,
76+
# "block_q_dq" : 2176,
77+
# "block_kv_dq" : 2176
6878
# }
6979
# GroupNorm groups
7080
norm_num_groups: 32

0 commit comments

Comments (0)