Skip to content

Commit 8274aca

Browse files
authored
Fix flops calculation for multi host and also gcs bucket in end2end test (#217)
1 parent deeb20b commit 8274aca

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

end_to_end/tpu/test_sdxl_training_loss.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ done
1212
TRAIN_CMD="python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml \
1313
pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0 \
1414
revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16 metrics_file=metrics.txt write_metrics=True \
15-
dataset_name=gs://jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_xl resolution=1024 per_device_batch_size=1 \
15+
dataset_name=gs://jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_sdxl resolution=1024 per_device_batch_size=1 \
1616
jax_cache_dir=gs://jfacevedo-maxdiffusion/cache_dir/ max_train_steps=$STEPS attention=flash run_name=sdxl-fsdp-v5p-64-ddp enable_profiler=True \
1717
run_name=$RUN_NAME \
1818
output_dir=$OUTPUT_DIR "

src/maxdiffusion/trainers/base_stable_diffusion_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _time_and_log_call(
107107

108108
def calculate_tflops(self, pipeline, params):
109109
per_device_tflops = maxdiffusion_utils.calculate_unet_tflops(
110-
self.config, pipeline, (self.config.per_device_batch_size * jax.local_device_count()), self.rng, train=True
110+
self.config, pipeline, (self.config.per_device_batch_size * jax.device_count()), self.rng, train=True
111111
)
112112
max_logging.log(f"UNET per device TFLOPS: {per_device_tflops}")
113113
return per_device_tflops

0 commit comments

Comments
 (0)