Skip to content

Commit 0db81e3

Browse files
committed
lint
1 parent 618550c commit 0db81e3

2 files changed

Lines changed: 6 additions & 9 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -193,15 +193,11 @@ def cudnn_flash_attention(
193193
value = nn.with_logical_constraint(value, axis_names)
194194

195195
@functools.partial(
196-
shard_map.shard_map,
197-
mesh=self.mesh,
198-
in_specs=(
199-
axis_names,
200-
axis_names,
201-
axis_names
202-
),
203-
out_specs=axis_names,
204-
check_rep=False
196+
shard_map.shard_map,
197+
mesh=self.mesh,
198+
in_specs=(axis_names, axis_names, axis_names),
199+
out_specs=axis_names,
200+
check_rep=False,
205201
)
206202
def wrap_flash_attention(query, key, value):
207203
return jax.vmap(self.dpa_layer)(query, key, value, mask=None)

src/maxdiffusion/train_sdxl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def main(argv: Sequence[str]) -> None:
4949
import os
5050
import tensorflow as tf
5151
import torch
52+
5253
tf.config.set_visible_devices([], "GPU")
5354
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
5455
torch.set_default_device("cpu")

0 commit comments

Comments (0)