Skip to content

Commit 0d97a52

Browse files
committed
Ruff format
Signed-off-by: Kunjan Patel <kunjanp@google.com>
1 parent 69d2a30 commit 0d97a52

2 files changed

Lines changed: 4 additions & 4 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
171171

172172
return tensor, kv_size, seq_len
173173

174-
def convert_to_tokamax_splash_config( block_sizes: BlockSizes,
174+
def convert_to_tokamax_splash_config( block_sizes: BlockSizes,
175175
q_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
176176
k_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
177177
v_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
@@ -536,7 +536,7 @@ def _apply_attention(
536536
)
537537
elif attention_kernel == "ring":
538538
return _tpu_flash_attention(
539-
query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
539+
query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
540540
mask_padding_tokens=mask_padding_tokens,
541541
)
542542
elif attention_kernel == "cudnn_flash_te":

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ def __call__(
567567
prompt = [prompt]
568568

569569
batch_size = len(prompt)
570-
570+
571571
with jax.named_scope("Encode-Prompt"):
572572
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
573573
prompt=prompt,
@@ -578,7 +578,7 @@ def __call__(
578578
)
579579

580580
num_channel_latents = self.transformer.config.in_channels
581-
if latents is None:
581+
if latents is None:
582582
latents = self.prepare_latents(
583583
batch_size=batch_size,
584584
vae_scale_factor_temporal=self.vae_scale_factor_temporal,

0 commit comments

Comments (0)