|
21 | 21 | from jax.experimental import shard_map |
22 | 22 | from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask |
23 | 23 | from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel |
24 | | - |
| 24 | +from einops import rearrange |
25 | 25 | from .. import common_types, max_logging |
26 | 26 |
|
27 | 27 | Array = common_types.Array |
|
35 | 35 | LENGTH = common_types.LENGTH |
36 | 36 | HEAD = common_types.HEAD |
37 | 37 | D_KV = common_types.D_KV |
| 38 | +EMBED = common_types.EMBED |
38 | 39 |
|
39 | 40 |
|
40 | 41 | class AttentionOp(nn.Module): |
@@ -63,7 +64,6 @@ def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None |
63 | 64 | def apply_attention(self, query: Array, key: Array, value: Array): |
64 | 65 | """Routes to different attention kernels.""" |
65 | 66 | self.check_attention_inputs(query, key, value) |
66 | | - |
67 | 67 | can_use_flash_attention = ( |
68 | 68 | query.shape[1] >= self.flash_min_seq_length |
69 | 69 | and key.shape[1] >= self.flash_min_seq_length |
@@ -111,8 +111,7 @@ def wrap_flash_attention(query, key, value): |
111 | 111 | block_q_dq=min(512, query.shape[2]), |
112 | 112 | block_kv_dq=min(512, query.shape[2]), |
113 | 113 | ) |
114 | | - |
115 | | - masks = [splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) for i in range(query.shape[1])] |
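| | + # one full (non-causal) mask per head; combined into a MultiHeadMask below |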
| 114 | + masks = [splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) for _ in range(query.shape[1])] |
116 | 115 | multi_head_mask = splash_attention_mask.MultiHeadMask(masks=masks) |
117 | 116 | splash_kernel = splash_attention_kernel.make_splash_mha( |
118 | 117 | mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes |
@@ -323,6 +322,177 @@ def chunk_scanner(chunk_idx, _): |
323 | 322 |
|
324 | 323 | return jnp.concatenate(res, axis=-3) # fuse the chunked result back |
325 | 324 |
|
| 325 | +class FlaxFluxAttention(nn.Module): |
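| | +    r"""Joint text-image attention block for Flux. |
| | + |
| | +    Projects the image stream ("i_*") and the text-encoder stream ("e_*") |
| | +    separately, then runs a single attention op over the concatenated |
| | +    sequences. |
| | +    """ |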
| 326 | + query_dim: int |
| 327 | + heads: int = 8 |
| 328 | + dim_head: int = 64 |
| 329 | + dropout: float = 0.0 |
| 330 | + use_memory_efficient_attention: bool = False |
| 331 | + split_head_dim: bool = False |
| 332 | + attention_kernel: str = "dot_product" |
| 333 | + flash_min_seq_length: int = 4096 |
| 334 | + flash_block_sizes: BlockSizes = None |
| 335 | + mesh: jax.sharding.Mesh = None |
| 336 | + dtype: jnp.dtype = jnp.float32 |
| 337 | + weights_dtype: jnp.dtype = jnp.float32 |
| 338 | + query_axis_names: AxisNames = (BATCH, LENGTH, HEAD) |
| 339 | + key_axis_names: AxisNames = (BATCH, LENGTH, HEAD) |
| 340 | + value_axis_names: AxisNames = (BATCH, LENGTH, HEAD) |
| 341 | + out_axis_names: AxisNames = (BATCH, LENGTH, EMBED) |
| 342 | + precision: jax.lax.Precision = None |
| 343 | + qkv_bias: bool = False |
| 344 | + |
| 345 | + def setup(self): |
| 346 | + if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None: |
| 347 | +            raise ValueError(f"The {self.attention_kernel} attention kernel requires a mesh, but mesh is {self.mesh}") |
| 348 | + inner_dim = self.dim_head * self.heads |
| 349 | + scale = self.dim_head**-0.5 |
| 350 | + |
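| | +        # AttentionOp routes to the configured kernel (dot_product, flash, cudnn_flash_te) |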
| 351 | + self.attention_op = AttentionOp( |
| 352 | + mesh=self.mesh, |
| 353 | + attention_kernel=self.attention_kernel, |
| 354 | + scale=scale, |
| 355 | + heads=self.heads, |
| 356 | + dim_head=self.dim_head, |
| 357 | + flash_min_seq_length=self.flash_min_seq_length, |
| 358 | + use_memory_efficient_attention=self.use_memory_efficient_attention, |
| 359 | + split_head_dim=self.split_head_dim, |
| 360 | + flash_block_sizes=self.flash_block_sizes, |
| 361 | + dtype=self.dtype, |
| 362 | + float32_qk_product=False, |
| 363 | + ) |
| 364 | + |
| 365 | + kernel_axes = ("embed", "heads") |
| 366 | + qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes) |
| 367 | + |
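| | +        # fused QKV projection for the image stream; the text-encoder stream |
| | +        # below gets its own fused projection |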
| 368 | + self.qkv = nn.Dense( |
| 369 | + inner_dim * 3, |
| 370 | + kernel_init=qkv_init_kernel, |
| 371 | + use_bias=self.qkv_bias, |
| 372 | + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), |
| 373 | + dtype=self.dtype, |
| 374 | + param_dtype=self.weights_dtype, |
| 375 | + name="i_qkv", |
| 376 | + precision=self.precision, |
| 377 | + ) |
| 378 | + |
| 379 | + self.encoder_qkv = nn.Dense( |
| 380 | + inner_dim * 3, |
| 381 | + kernel_init=qkv_init_kernel, |
| 382 | + use_bias=self.qkv_bias, |
| 383 | + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), |
| 384 | + dtype=self.dtype, |
| 385 | + param_dtype=self.weights_dtype, |
| 386 | + name="e_qkv", |
| 387 | + precision=self.precision, |
| 388 | + ) |
| 389 | + |
| 390 | + self.proj_attn = nn.Dense( |
| 391 | + self.query_dim, |
| 392 | + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes), |
| 393 | + use_bias=True, |
| 394 | + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), |
| 395 | + dtype=self.dtype, |
| 396 | + param_dtype=self.weights_dtype, |
| 397 | + name="i_proj", |
| 398 | + precision=self.precision, |
| 399 | + ) |
| 400 | + |
| 401 | + self.encoder_proj_attn = nn.Dense( |
| 402 | + self.query_dim, |
| 403 | + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes), |
| 404 | + use_bias=True, |
| 405 | + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), |
| 406 | + dtype=self.dtype, |
| 407 | + param_dtype=self.weights_dtype, |
| 408 | + name="e_proj", |
| 409 | + precision=self.precision, |
| 410 | + ) |
| 411 | + |
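| | +        # QK RMSNorm (as in Flux) keeps attention logits well-scaled |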
| 412 | + self.query_norm = nn.RMSNorm( |
| 413 | + dtype=self.dtype, |
| 414 | + scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), |
| 415 | + param_dtype=self.weights_dtype, |
| 416 | + ) |
| 417 | + self.key_norm = nn.RMSNorm( |
| 418 | + dtype=self.dtype, |
| 419 | + scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), |
| 420 | + param_dtype=self.weights_dtype, |
| 421 | + ) |
| 422 | + |
| 423 | + self.encoder_query_norm = nn.RMSNorm( |
| 424 | + dtype=self.dtype, |
| 425 | + scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), |
| 426 | + param_dtype=self.weights_dtype, |
| 427 | + ) |
| 428 | + self.encoder_key_norm = nn.RMSNorm( |
| 429 | + dtype=self.dtype, |
| 430 | + scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), |
| 431 | + param_dtype=self.weights_dtype, |
| 432 | + ) |
| 433 | + |
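| | +    # RoPE in rotation-matrix form: each adjacent channel pair of q/k is a |
| | +    # 2-D vector rotated by the cos/sin coefficients carried in freqs_cis |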
| 434 | + def apply_rope(self, xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: |
| 435 | + xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) |
| 436 | + xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) |
| 437 | + |
| 438 | + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] |
| 439 | + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] |
| 440 | + |
| 441 | + return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype) |
| 442 | + |
| 443 | + def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None): |
| 444 | + |
| 445 | + qkv_proj = self.qkv(hidden_states) |
| 446 | + B, L = hidden_states.shape[:2] |
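| | +        # unpack the fused projection: (B, L, 3*H*D) -> 3 x (B, H, L, D) |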
| 447 | + H, D, K = self.heads, qkv_proj.shape[-1] // (self.heads * 3), 3 |
| 448 | + qkv_proj = qkv_proj.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4) |
| 449 | + query_proj, key_proj, value_proj = qkv_proj |
| 450 | + |
| 451 | + query_proj = self.query_norm(query_proj) |
| 452 | + |
| 453 | + key_proj = self.key_norm(key_proj) |
| 454 | + |
| 455 | + if encoder_hidden_states is not None: |
| 456 | + |
| 457 | + encoder_qkv_proj = self.encoder_qkv(encoder_hidden_states) |
| 458 | + B, L = encoder_hidden_states.shape[:2] |
| 459 | + H, D, K = self.heads, encoder_qkv_proj.shape[-1] // (self.heads * 3), 3 |
| 460 | + encoder_qkv_proj = encoder_qkv_proj.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4) |
| 461 | + encoder_query_proj, encoder_key_proj, encoder_value_proj = encoder_qkv_proj |
| 462 | + |
| 463 | + encoder_query_proj = self.encoder_query_norm(encoder_query_proj) |
| 464 | + |
| 465 | + encoder_key_proj = self.encoder_key_norm(encoder_key_proj) |
| 466 | + |
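| | +            # prepend text tokens to image tokens along the sequence axis |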
| 467 | + query_proj = jnp.concatenate((encoder_query_proj, query_proj), axis=2) |
| 468 | + key_proj = jnp.concatenate((encoder_key_proj, key_proj), axis=2) |
| 469 | + value_proj = jnp.concatenate((encoder_value_proj, value_proj), axis=2) |
| 470 | + |
| 471 | + query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names) |
| 472 | + key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names) |
| 473 | + value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names) |
| 474 | + |
| 475 | +        if image_rotary_emb is not None: |
| | +            # reshape the flattened cos/sin table into per-pair 2x2 rotation coefficients |
| | +            image_rotary_emb = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2) |
| 476 | +            query_proj, key_proj = self.apply_rope(query_proj, key_proj, image_rotary_emb) |
| 477 | + |
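| | +        # merge heads: (B, H, L, D) -> (B, L, H*D) for the attention op |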
| 478 | + query_proj = query_proj.transpose(0, 2, 1, 3).reshape(query_proj.shape[0], query_proj.shape[2], -1) |
| 479 | + key_proj = key_proj.transpose(0, 2, 1, 3).reshape(key_proj.shape[0], key_proj.shape[2], -1) |
| 480 | + value_proj = value_proj.transpose(0, 2, 1, 3).reshape(value_proj.shape[0], value_proj.shape[2], -1) |
| 481 | + |
| 482 | + attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) |
| 483 | + context_attn_output = None |
| 484 | + |
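| | +        # split the joint output back into context (text) and image streams |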
| 485 | + if encoder_hidden_states is not None: |
| 486 | + context_attn_output, attn_output = ( |
| 487 | + attn_output[:, : encoder_hidden_states.shape[1]], |
| 488 | + attn_output[:, encoder_hidden_states.shape[1] :], |
| 489 | + ) |
| 490 | + |
| 491 | +        attn_output = self.proj_attn(attn_output) |
| 492 | + |
| | +        # the context stream only exists when encoder_hidden_states was given |
| | +        if context_attn_output is not None: |
| 493 | +            context_attn_output = self.encoder_proj_attn(context_attn_output) |
| 494 | + |
| 495 | + return attn_output, context_attn_output |
326 | 496 |
|
327 | 497 | class FlaxAttention(nn.Module): |
328 | 498 | r""" |
|