Skip to content

Commit 37524bb

Browse files
Test with dot_product attention
1 parent 496f67c commit 37524bb

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ jobs:
5151
train_text_encoder=false \
5252
cache_latents_text_encoder_outputs=true \
5353
per_device_batch_size=1 \
54-
attention=cudnn_flash_te \
54+
attention=dot_product \
5555
activations_dtype=bfloat16 \
5656
weights_dtype=bfloat16 \
5757
max_train_steps=200 \

src/maxdiffusion/trainers/base_stable_diffusion_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@ def start_training(self):
161161
params["scheduler"] = noise_scheduler_state
162162

163163
# Calculate tflops
164-
# per_device_tflops = self.calculate_tflops(pipeline, params)
165-
# self.per_device_tflops = per_device_tflops
164+
per_device_tflops = self.calculate_tflops(pipeline, params)
165+
self.per_device_tflops = per_device_tflops
166166

167167
# Load dataset
168168
data_iterator = self._time_and_log_call(self.load_dataset, pipeline, params, train_states)

0 commit comments

Comments
 (0)