Commit 2879a65

change configuration to v4-8 (#307)
* change configuration to v4-8
* deprecate tests for older stable diffusion models.
* linting
* deprecate tests that use sd2base checkpoint.
* update wan configs to use default 512 flash block sizes.
* fix typos and mispellings.
* revert change to scheduler test

---------

Co-authored-by: Juan Acevedo <jfacevedo@google.com>
1 parent d1a2b24 commit 2879a65

26 files changed

Lines changed: 56 additions & 46 deletions

.github/workflows/UnitTests.yml

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        tpu-type: ["v5p-8"]
+        tpu-type: ["v4-8"]
     name: "TPU test (${{ matrix.tpu-type }})"
     runs-on: ["self-hosted","${{ matrix.tpu-type }}"]
     steps:

README.md

Lines changed: 1 addition & 1 deletion
@@ -279,7 +279,7 @@ After installation completes, run the training script.

 ### Deploying with XPK

-This assummes the user has already created an xpk cluster, installed all dependencies and the also created the dataset from the step above. For getting started with MaxDiffusion and xpk see [this guide](docs/getting_started/run_maxdiffusion_via_xpk.md).
+This assumes the user has already created an xpk cluster, installed all dependencies and the also created the dataset from the step above. For getting started with MaxDiffusion and xpk see [this guide](docs/getting_started/run_maxdiffusion_via_xpk.md).

 Using v5p-256 Then the command to run on xpk is as follows:
285285

docs/data_README.md

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ Currently MaxDiffusion supports 3 data input pipelines, controlled by the flag `
 | Pipeline | Dataset Location | Dataset formats | Features or limitations |
 | -------- | ---------------- | --------------- | ----------------------- |
 | HuggingFace (hf) | datasets in HuggingFace Hub or local/Cloud Storage | Formats supported in HF Hub: parquet, arrow, json, csv, txt | data are not loaded in memory but streamed from the saved location, good for big dataset |
-| tf | dataset will be downaloaded form HuggingFace Hub to disk | Formats supported in HF Hub: parquet, arrow, json, csv, txt | Will read the whole dataset into memory, works for small dataset |
+| tf | dataset will be downloaded form HuggingFace Hub to disk | Formats supported in HF Hub: parquet, arrow, json, csv, txt | Will read the whole dataset into memory, works for small dataset |
 | tfrecord | local/Cloud Storage | TFRecord | data are not loaded in memory but streamed from the saved location, good for big dataset |
 | Grain | local/Cloud Storage | ArrayRecord (or any random access format) | data are not loaded in memory but streamed from the saved location, good for big dataset, supports global shuffle and data iterator checkpoint for determinism (see details in [doc](https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-pipeline---for-determinism)) |

docs/train_README.md

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ Now let's change the configuration as follows:

 Then our mesh will look like `Mesh('data': 2, 'fsdp': 2, 'tensor': 1)`.

-The `logical_axis_rules` specifies the sharding across the mesh. You are encouranged to add or remove rules and find what best works for you.
+The `logical_axis_rules` specifies the sharding across the mesh. You are encouraged to add or remove rules and find what best works for you.

 ### Checkpointing

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 7 additions & 5 deletions
@@ -74,13 +74,15 @@ attention_sharding_uniform: True
 dropout: 0.1

 flash_block_sizes: {
-  "block_q" : 2048,
+  "block_q" : 512,
   "block_kv_compute" : 512,
-  "block_kv" : 2048,
-  "block_q_dkv" : 2048,
-  "block_kv_dkv" : 2048,
+  "block_kv" : 512,
+  "block_q_dkv" : 512,
+  "block_kv_dkv" : 512,
   "block_kv_dkv_compute" : 512,
-  "use_fused_bwd_kernel": True
+  "block_q_dq" : 512,
+  "block_kv_dq" : 512,
+  "use_fused_bwd_kernel": False,
 }
 # Use on v6e
 # flash_block_sizes: {
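After this change the 14b config uses the 512 default for every tile and, because the fused backward kernel is disabled, supplies the separate `block_q_dq`/`block_kv_dq` tiles. As a minimal sketch (the helper name and the power-of-two check are assumptions for illustration, not MaxDiffusion API), the shape of such a mapping can be sanity-checked like this:

```python
# Hypothetical helper: sanity-check a flash_block_sizes mapping like the one above.
def check_flash_block_sizes(cfg: dict) -> None:
    for key, value in cfg.items():
        if not key.startswith("block_"):
            continue
        # TPU flash-attention tile sizes in these configs are powers of two (256/512/1024/...).
        assert value > 0 and (value & (value - 1)) == 0, f"{key}={value} is not a power of two"
    if cfg.get("use_fused_bwd_kernel", False):
        # A fused backward kernel computes dq together with dkv, so no separate dq tiles.
        assert "block_q_dq" not in cfg and "block_kv_dq" not in cfg
    else:
        # Without the fused kernel, the dq pass needs its own tile sizes.
        assert "block_q_dq" in cfg and "block_kv_dq" in cfg

flash_block_sizes = {
    "block_q": 512, "block_kv_compute": 512, "block_kv": 512,
    "block_q_dkv": 512, "block_kv_dkv": 512, "block_kv_dkv_compute": 512,
    "block_q_dq": 512, "block_kv_dq": 512, "use_fused_bwd_kernel": False,
}
check_flash_block_sizes(flash_block_sizes)
```

The dq-tile/fused-kernel relationship mirrors the TPU splash-attention `BlockSizes` convention, which is my assumption about how MaxDiffusion consumes this config.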

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 9 additions & 8 deletions
@@ -73,14 +73,15 @@ attention_sharding_uniform: True
 dropout: 0.1

 flash_block_sizes: {
-  "block_q" : 1024,
-  "block_kv_compute" : 256,
-  "block_kv" : 1024,
-  "block_q_dkv" : 1024,
-  "block_kv_dkv" : 1024,
-  "block_kv_dkv_compute" : 256,
-  "block_q_dq" : 1024,
-  "block_kv_dq" : 1024
+  "block_q" : 512,
+  "block_kv_compute" : 512,
+  "block_kv" : 512,
+  "block_q_dkv" : 512,
+  "block_kv_dkv" : 512,
+  "block_kv_dkv_compute" : 512,
+  "block_q_dq" : 512,
+  "block_kv_dq" : 512,
+  "use_fused_bwd_kernel": False,
 }
 # Use on v6e
 # flash_block_sizes: {

src/maxdiffusion/configuration_utils.py

Lines changed: 2 additions & 2 deletions
@@ -145,7 +145,7 @@ def __getattr__(self, name: str) -> Any:
         """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
         config attributes directly. See https://github.com/huggingface/diffusers/pull/3129

-        Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite:
+        Tihs function is mostly copied from PyTorch's __getattr__ overwrite:
         https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
         """

@@ -540,7 +540,7 @@ def extract_init_dict(cls, config_dict, **kwargs):
             f"{cls.config_name} configuration file."
         )

-        # 5. Give nice info if config attributes are initiliazed to default because they have not been passed
+        # 5. Give nice info if config attributes are initialized to default because they have not been passed
         passed_keys = set(init_dict.keys())
         if len(expected_keys - passed_keys) > 0:
             logger.info(f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values.")
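The docstring touched by the first hunk describes overriding `__getattr__` to gracefully deprecate direct attribute access on config objects. A minimal self-contained sketch of that pattern (class and attribute names are hypothetical, not the actual MaxDiffusion/diffusers implementation):

```python
import warnings

class ConfigSketch:
    """Sketch: config values live in an internal dict; old-style `obj.attr`
    access still works but emits a deprecation warning."""

    def __init__(self, **kwargs):
        self._internal_dict = dict(kwargs)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, mirroring how
        # PyTorch's Module.__getattr__ intercepts misses.
        if "_internal_dict" in self.__dict__ and name in self.__dict__["_internal_dict"]:
            warnings.warn(
                f"Accessing config attribute `{name}` directly is deprecated; "
                f"use `config.{name}` via the config property instead.",
                FutureWarning,
            )
            return self.__dict__["_internal_dict"][name]
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
```

Because `__getattr__` fires only on lookup misses, regular attributes and methods are unaffected; only legacy-style access to stored config keys takes the warning path.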

src/maxdiffusion/generate.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def loop_body(step, args, model, pipeline, prompt_embeds, guidance_scale, guidan

   # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
   # Helps solve overexposure problem when terminal SNR approaches zero.
-  # Empirical values recomended from the paper are guidance_scale=7.5 and guidance_rescale=0.7
+  # Empirical values recommended from the paper are guidance_scale=7.5 and guidance_rescale=0.7
   noise_pred = jax.lax.cond(
     guidance_rescale[0] > 0,
     lambda _: rescale_noise_cfg(noise_pred, noise_prediction_text, guidance_rescale),
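The comment fixed here refers to the CFG-rescale trick from Sec. 3.4 of the cited paper: the guided prediction is rescaled toward the standard deviation of the text-conditioned prediction, then blended back by `guidance_rescale`. A NumPy sketch of that computation (MaxDiffusion's actual `rescale_noise_cfg` is implemented in JAX and takes `guidance_rescale` as an array; a scalar is used here for simplicity):

```python
import numpy as np

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7):
    # Per-sample std over all non-batch axes, as in Sec. 3.4 of
    # https://arxiv.org/pdf/2305.08891.pdf (keepdims so broadcasting works).
    std_text = noise_pred_text.std(axis=tuple(range(1, noise_pred_text.ndim)), keepdims=True)
    std_cfg = noise_cfg.std(axis=tuple(range(1, noise_cfg.ndim)), keepdims=True)
    # Rescale the guided prediction toward the text prediction's scale.
    rescaled = noise_cfg * (std_text / std_cfg)
    # Blend with the unrescaled prediction to avoid overly "plain" images.
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
```

With `guidance_rescale=0` this is a no-op, which matches the `guidance_rescale[0] > 0` guard in the `jax.lax.cond` above.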

src/maxdiffusion/generate_sdxl_replicated.py

Lines changed: 2 additions & 2 deletions
@@ -32,7 +32,7 @@
 NUM_DEVICES = jax.device_count()

 # 1. Let's start by downloading the model and loading it into our pipeline class
-# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and
+# Adhering to JAX's functional approach, the model's parameters are returned separately and
 # will have to be passed to the pipeline during inference
 pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True

@@ -83,7 +83,7 @@ def replicate_all(prompt_ids, neg_prompt_ids, seed):
 # to the function and tell JAX which are static arguments, that is, arguments that
 # are known at compile time and won't change. In our case, it is num_inference_steps,
 # height, width and return_latents.
-# Once the function is compiled, these parameters are ommited from future calls and
+# Once the function is compiled, these parameters are omitted from future calls and
 # cannot be changed without modifying the code and recompiling.
 def aot_compile(
     prompt=default_prompt,

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ def _make_tfrecord_iterator(
   # Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
   # if is_training is True, loads the training dataset. If False, loads the evaluation dataset.

-  # checks that the dataset path is valid. In case of gcs, the existance of the dir is not checked.
+  # checks that the dataset path is valid. In case of gcs, the existence of the dir is not checked.
   is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location)

   # Determine whether to use the "cached" dataset, which requires externally
