18 | 18 | import math |
19 | 19 | from typing import Any, Dict, Optional, Tuple |
20 | 20 | from enum import Enum, auto |
| 21 | + |
21 | 22 | import jax |
22 | 23 | import jax.nn as jnn |
23 | 24 | import jax.numpy as jnp |
@@ -213,7 +214,8 @@ def __call__( |
213 | 214 | |
214 | 215 | # Adaptive Norm |
215 | 216 | if self.adaptive_norm in ["single_scale_shift", "single_scale"]: |
216 | | - assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] |
| 217 | + # [batch, 1 or num_tokens, embedding_dim] |
| 218 | + assert timestep.ndim == 3 |
217 | 219 | num_ada_params = self.scale_shift_table.shape[0] |
218 | 220 | ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( |
219 | 221 | batch_size, timestep.shape[1], num_ada_params, -1 |
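
As a shape check for the hunk above: scale_shift_table is [num_ada_params, dim], the two leading None axes lift it to [1, 1, num_ada_params, dim], and the reshaped timestep broadcasts against it. A minimal sketch with made-up sizes (batch 2, one token, 6 ada params, dim 4), not values from the model config:

    import jax.numpy as jnp

    batch_size, num_tokens, num_ada_params, dim = 2, 1, 6, 4
    scale_shift_table = jnp.zeros((num_ada_params, dim))                 # learned table
    timestep = jnp.ones((batch_size, num_tokens, num_ada_params * dim))  # conditioning embedding

    # [1, 1, num_ada_params, dim] + [batch, num_tokens, num_ada_params, dim]
    ada_values = scale_shift_table[None, None] + timestep.reshape(
        batch_size, timestep.shape[1], num_ada_params, -1
    )
    print(ada_values.shape)  # (2, 1, 6, 4)
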
@@ -452,7 +454,7 @@ def __call__( |
452 | 454 | deterministic: bool = True, |
453 | 455 | **cross_attention_kwargs, |
454 | 456 | ) -> jnp.ndarray: |
455 | | - cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} # noqa: F821 |
| 457 | + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} # noqa F821 |
456 | 458 | assert cross_attention_kwargs.get("scale", None) is None, "Not supported" |
457 | 459 | |
458 | 460 | input_axis_names = ("activation_batch", "activation_length", "activation_embed") |
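
The F821 suppression on this line exists because attn_parameters is defined elsewhere in the module; the comprehension itself just drops any kwarg the attention call does not accept. One plausible way such a whitelist gets built is from the call signature; a sketch only, with AttentionOp standing in for the real attention class, which this hunk does not show:

    import inspect

    # Hypothetical stand-in for the attention class used by this block.
    class AttentionOp:
        def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
            ...

    attn_parameters = set(inspect.signature(AttentionOp.__call__).parameters.keys())
    cross_attention_kwargs = {"attention_mask": None, "unused_flag": True}
    cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
    print(cross_attention_kwargs)  # {'attention_mask': None}
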
@@ -636,27 +638,20 @@ def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): |
636 | 638 | raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") |
637 | 639 | # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") |
638 | 640 | # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. |
639 | | - qkvo_sharding_spec = jax.sharding.PartitionSpec( |
640 | | - ("data", "fsdp", "fsdp_transpose", "expert"), |
641 | | - ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), |
642 | | - None, |
643 | | - None, |
644 | | - ) |
645 | 641 | # qkvo_sharding_spec = jax.sharding.PartitionSpec( |
646 | 642 | # ("data", "fsdp", "fsdp_transpose", "expert"), |
647 | 643 | # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), |
648 | 644 | # None, |
649 | 645 | # None, |
650 | 646 | # ) |
651 | | - # qkvo_sharding_spec = jax.sharding.PartitionSpec( |
652 | | - # None, |
653 | | - # None, |
654 | | - # None, |
655 | | - # None, |
656 | | - # ) |
| 647 | + qkvo_sharding_spec = jax.sharding.PartitionSpec( |
| 648 | + "data", |
| 649 | + "fsdp", |
| 650 | + None, |
| 651 | + "tensor", |
| 652 | + ) |
657 | 653 | # Based on: ("activation_kv_batch", "activation_length") |
658 | | - qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") |
659 | | - # qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None) |
| 654 | + qkv_segment_ids_spec = jax.sharding.PartitionSpec("data", None) |
660 | 655 | wrapped_flash_attention = shard_map( |
661 | 656 | partial_flash_attention, |
662 | 657 | mesh=sharding_mesh, |
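
Reading the new specs positionally against the logical axes in the comment above, the operands are [batch, heads, length, head_dim]: batch is split over the data axis, heads over fsdp, length is replicated, and head_dim is split over tensor, while the segment ids [batch, length] are split over data only. A minimal sketch under the assumption of a one-device mesh with those four axis names (the real mesh shape comes from the run config, not from this file):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # One-device mesh purely for illustration; only the axis names matter here.
    devices = np.array(jax.devices()[:1]).reshape(1, 1, 1, 1)
    mesh = Mesh(devices, axis_names=("data", "fsdp", "sequence", "tensor"))

    qkvo_sharding_spec = PartitionSpec("data", "fsdp", None, "tensor")  # [batch, heads, length, head_dim]
    qkv_segment_ids_spec = PartitionSpec("data", None)                  # [batch, length]

    q = jnp.zeros((2, 8, 128, 64))
    segment_ids = jnp.zeros((2, 128), dtype=jnp.int32)
    q = jax.device_put(q, NamedSharding(mesh, qkvo_sharding_spec))
    segment_ids = jax.device_put(segment_ids, NamedSharding(mesh, qkv_segment_ids_spec))
    print(q.sharding.spec, segment_ids.sharding.spec)
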
@@ -841,7 +836,8 @@ def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: |
841 | 836 | inner_dim = dim * self.mult |
842 | 837 | if inner_dim < 256: |
843 | 838 | raise ValueError("inner_dim must be at least 256") |
844 | | - inner_dim = round(inner_dim / 256) * 256 # round to nearest multiple of 256 |
| 839 | + # round to nearest multiple of 256 |
| 840 | + inner_dim = round(inner_dim / 256) * 256 |
845 | 841 | else: |
846 | 842 | inner_dim = self.inner_dim |
847 | 843 | |
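
The rounding in the hunk above relies on Python's built-in round, which sends exact halves to the even multiple. A couple of illustrative values (not taken from any model config), wrapped in a hypothetical helper that mirrors the same arithmetic:

    def _round_inner_dim(dim: int, mult: int) -> int:
        inner_dim = dim * mult
        if inner_dim < 256:
            raise ValueError("inner_dim must be at least 256")
        # round to nearest multiple of 256; exact halves go to the even multiple
        return round(inner_dim / 256) * 256

    print(_round_inner_dim(300, 4))  # 1200 -> 1280
    print(_round_inner_dim(480, 4))  # 1920 -> 2048 (7.5 rounds to 8)
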