diff --git a/end_to_end/tpu/test_sdxl_training_loss.sh b/end_to_end/tpu/test_sdxl_training_loss.sh index 44fa2e2cf..78913f6c0 100755 --- a/end_to_end/tpu/test_sdxl_training_loss.sh +++ b/end_to_end/tpu/test_sdxl_training_loss.sh @@ -12,7 +12,7 @@ done TRAIN_CMD="python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml \ pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0 \ revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16 metrics_file=metrics.txt write_metrics=True \ - dataset_name=gs://jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_xl resolution=1024 per_device_batch_size=1 \ + dataset_name=gs://jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_sdxl resolution=1024 per_device_batch_size=1 \ jax_cache_dir=gs://jfacevedo-maxdiffusion/cache_dir/ max_train_steps=$STEPS attention=flash run_name=sdxl-fsdp-v5p-64-ddp enable_profiler=True \ run_name=$RUN_NAME \ output_dir=$OUTPUT_DIR " diff --git a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py index 78d0e8b2e..a9f17adc3 100644 --- a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py +++ b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py @@ -107,7 +107,7 @@ def _time_and_log_call( def calculate_tflops(self, pipeline, params): per_device_tflops = maxdiffusion_utils.calculate_unet_tflops( - self.config, pipeline, (self.config.per_device_batch_size * jax.local_device_count()), self.rng, train=True + self.config, pipeline, (self.config.per_device_batch_size * jax.device_count()), self.rng, train=True ) max_logging.log(f"UNET per device TFLOPS: {per_device_tflops}") return per_device_tflops