Skip to content

Commit f71031a

Browse files
authored
fixes failing smoke tests (#134)
* fixes failing smoke tests
1 parent f4b9042 commit f71031a

5 files changed

Lines changed: 7 additions & 5 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ MaxDiffusion started as a fork of [Diffusers](https://github.com/huggingface/dif
188188

189189
Whether you are forking MaxDiffusion for your own needs or intending to contribute back to the community, a full suite of tests can be found in `tests` and `src/maxdiffusion/tests`.
190190

191-
To run unit tests, simply run:
191+
To run unit tests simply run:
192192
```
193193
python -m pytest
194194
```

docs/getting_started/first_run.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@ multiple hosts.
1212
1. Clone MaxDiffusion in your TPU VM.
1313
1. Within the root directory of the MaxDiffusion `git` repo, install dependencies by running:
1414
```bash
15-
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
16-
pip3 install -r requirements.txt
17-
pip3 install .
15+
bash setup.sh MODE=stable
1816
```
1917

2018
## Getting Starting: Multihost development

setup.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,6 @@ fi
108108

109109
# Install dependencies from requirements.txt
110110
pip3 install -U -r requirements.txt
111+
112+
# Install maxdiffusion
113+
pip3 install -U .

src/maxdiffusion/configs/base_xl_lightning.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ ici_tensor_parallelism: 1
110110
# Dataset
111111
# Replace with dataset path or train_data_dir. One has to be set.
112112
dataset_name: ''
113+
dataset_type: 'tf'
113114
train_data_dir: ''
114115
dataset_config_name: ''
115116
jax_cache_dir: ''

src/maxdiffusion/max_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def write_metrics_for_gcs(metrics, step, config, running_metrics):
121121
"""Writes metrics to gcs"""
122122
metrics_dict_step = _prepare_metrics_for_json(metrics, step, config.run_name)
123123
running_metrics.append(metrics_dict_step)
124-
if (step + 1) % config.log_period == 0 or step == config.steps - 1:
124+
if (step + 1) % config.log_period == 0 or step == config.max_train_steps - 1:
125125
start_step = (step // config.log_period) * config.log_period
126126
metrics_filename = f"metrics_step_{start_step:06}_to_step_{step:06}.txt"
127127
with open(metrics_filename, "w", encoding="utf8") as metrics_for_gcs:

0 commit comments

Comments
 (0)