Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion end_to_end/tpu/test_sdxl_training_loss.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ done
# Assemble the SDXL training command: train_sdxl.py with the base_xl.yml config,
# followed by key=value overrides. The string is double-quoted, so $STEPS,
# $RUN_NAME and $OUTPUT_DIR are expanded right here, when TRAIN_CMD is assigned.
# dataset_name is passed exactly once and points at pokemon-gpt4-captions_sdxl.
# NOTE(review): run_name is passed twice (the literal sdxl-fsdp-v5p-64-ddp, then
# $RUN_NAME); which one takes effect depends on how train_sdxl.py resolves
# repeated overrides -- confirm, then drop the redundant one.
TRAIN_CMD="python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml \
pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0 \
revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16 metrics_file=metrics.txt write_metrics=True \
dataset_name=gs://jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_sdxl resolution=1024 per_device_batch_size=1 \
jax_cache_dir=gs://jfacevedo-maxdiffusion/cache_dir/ max_train_steps=$STEPS attention=flash run_name=sdxl-fsdp-v5p-64-ddp enable_profiler=True \
run_name=$RUN_NAME \
output_dir=$OUTPUT_DIR "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _time_and_log_call(

def calculate_tflops(self, pipeline, params):
    """Estimate the UNet training TFLOPs and log the result.

    Args:
        pipeline: Pipeline object passed straight through to
            ``maxdiffusion_utils.calculate_unet_tflops``.
        params: Not read by this method; kept so existing call sites
            continue to work unchanged.

    Returns:
        Whatever ``maxdiffusion_utils.calculate_unet_tflops`` returns for
        this configuration (logged as "UNET per device TFLOPS").
    """
    # Size the batch by every device JAX reports across all processes
    # (jax.device_count()), not just the devices attached to this host
    # (jax.local_device_count()), so multi-host runs use the global batch.
    global_batch_size = self.config.per_device_batch_size * jax.device_count()
    per_device_tflops = maxdiffusion_utils.calculate_unet_tflops(
        self.config, pipeline, global_batch_size, self.rng, train=True
    )
    max_logging.log(f"UNET per device TFLOPS: {per_device_tflops}")
    return per_device_tflops
Expand Down
Loading