Commit 91b6af7

sync pyink version to 23.10.0
1 parent 0299786

19 files changed

Lines changed: 386 additions & 454 deletions
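Note: this commit is a mechanical reformat produced by bumping pyink (Google's fork of the Black formatter). The recurring pattern in the hunks below appears to be the "hugging" style for calls whose sole argument is a list or tuple literal: the opening `([` stays on one line and the items lose one indentation level. A minimal before/after sketch of the pattern, using illustrative names that are not from this repo:

    # before (older pyink)
    result = process(
        [
            item_a,
            item_b,
        ]
    )

    # after (pyink 23.10.0)
    result = process([
        item_a,
        item_b,
    ])

The other recurring change is dropping the stray blank line that used to follow some `def` signatures.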

src/maxdiffusion/__init__.py

Lines changed: 182 additions & 196 deletions
Large diffs are not rendered by default.

src/maxdiffusion/kernels/splash_attention/ring_attention_kernel.py

Lines changed: 0 additions & 1 deletion
@@ -79,7 +79,6 @@ def _ring_attention_forward(
     ring_axis: str,
     rotate_segment_ids: bool = True,
 ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
-
   if q.shape[-1] != k.shape[-1]:
     raise NotImplementedError("Queries and keys must have the same head dimension.")

src/maxdiffusion/kernels/splash_attention/splash_attention_kernel.py

Lines changed: 0 additions & 2 deletions
@@ -941,7 +941,6 @@ def _splash_attention_fwd(
     dkv_mask_sparsity: float,
     max_logit_value: jax.Array | None = None,
 ) -> tuple[tuple[jax.Array], base.SplashResidualsType]:
-
   # TODO: add some higher order AD check that isn't save_residuals based.
   # if save_residuals:
   #   raise NotImplementedError("Higher-order AD not supported.")
@@ -1180,7 +1179,6 @@ def init():
     dv_scratch_ref[...] = jnp.zeros_like(dv_scratch_ref)

   def body(i, _, has_partial_mask=False):
-
     slice_k = pl.ds(i * bkv_compute, bkv_compute)
     q = q_ref[...]  # We keep q potentially transposed, since it's always RHS
     if config.use_base2_exp:

src/maxdiffusion/kernels/splash_attention/splash_attention_kernel_test.py

Lines changed: 1 addition & 8 deletions
@@ -290,14 +290,7 @@ def _generate_inputs(
     is_mqa: bool,
     is_segmented: bool,
     use_sinks: bool = False,
-) -> tuple[
-    jax.Array,
-    jax.Array,
-    jax.Array,
-    jax.Array | None,
-    splash.SegmentIds | None,
-    jax.Array,
-]:
+) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array | None, splash.SegmentIds | None, jax.Array,]:
   seed = data.draw(seed_strategy())
   key = random.key(seed)
   k1, k2, k3, k_sinks, k_do = random.split(key, 5)

src/maxdiffusion/kernels/splash_attention/splash_attention_mask.py

Lines changed: 19 additions & 25 deletions
@@ -278,14 +278,12 @@ def __eq__(self, other: object):
     return self.shape == other.shape and self.offset == other.offset and np.array_equal(self.q_sequence, other.q_sequence)

   def __hash__(self):
-    return hash(
-        (
-            type(self),
-            self.shape,
-            self.offset,
-            self.q_sequence.tobytes() if self.q_sequence is not None else None,
-        )
-    )
+    return hash((
+        type(self),
+        self.shape,
+        self.offset,
+        self.q_sequence.tobytes() if self.q_sequence is not None else None,
+    ))


 class ChunkedCausalMask(_ComputableMask):
@@ -340,14 +338,12 @@ def __eq__(self, other: object):
     )

   def __hash__(self):
-    return hash(
-        (
-            type(self),
-            self.shape,
-            self.chunk_size,
-            self.q_sequence.tobytes() if self.q_sequence is not None else None,
-        )
-    )
+    return hash((
+        type(self),
+        self.shape,
+        self.chunk_size,
+        self.q_sequence.tobytes() if self.q_sequence is not None else None,
+    ))


 class LocalMask(_ComputableMask):
@@ -419,15 +415,13 @@ def __eq__(self, other: object):
     )

   def __hash__(self):
-    return hash(
-        (
-            type(self),
-            self.shape,
-            self.window_size,
-            self.offset,
-            self.q_sequence.tobytes() if self.q_sequence is not None else None,
-        )
-    )
+    return hash((
+        type(self),
+        self.shape,
+        self.window_size,
+        self.offset,
+        self.q_sequence.tobytes() if self.q_sequence is not None else None,
+    ))


 @dataclasses.dataclass(slots=True)
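An aside on the `__hash__` bodies reformatted above (context, not part of this commit's changes): numpy arrays are unhashable, which is why each mask folds `q_sequence` into its hash via `tobytes()`, an immutable, content-based key consistent with the `np.array_equal` comparison in `__eq__`. A minimal standalone sketch:

    import numpy as np

    q_sequence = np.arange(8, dtype=np.int32)

    # hash(q_sequence) would raise TypeError: unhashable type: 'numpy.ndarray';
    # tobytes() yields a hashable bytes object derived from the array contents.
    key = hash(((8, 8), 0, q_sequence.tobytes()))

    # Arrays that compare equal element-wise serialize to the same bytes,
    # so equal masks hash equally, matching __eq__.
    assert np.arange(8, dtype=np.int32).tobytes() == q_sequence.tobytes()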

src/maxdiffusion/kernels/splash_attention/splash_attention_mask_info.py

Lines changed: 4 additions & 6 deletions
@@ -446,12 +446,10 @@ def _process_mask(
   # Partial blocks are deduplicated and stored in unique_chunks to save memory.
   for coords in np.ndindex((q_blocks_count, kv_blocks_count)):
     (q_idx, kv_idx) = coords
-    chunk = mask[
-        (
-            slice(q_idx * q_block_size, (q_idx + 1) * q_block_size),
-            slice(kv_idx * kv_block_size, (kv_idx + 1) * kv_block_size),
-        )
-    ]
+    chunk = mask[(
+        slice(q_idx * q_block_size, (q_idx + 1) * q_block_size),
+        slice(kv_idx * kv_block_size, (kv_idx + 1) * kv_block_size),
+    )]
     if chunk.any():
       if chunk.all():
         state_grid[q_idx, kv_idx] = 2
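For readers of the hunk above: `_process_mask` walks the mask block by block and classifies each one, marking all-ones blocks fully active (state 2) so the kernel can skip masking them, and keeping mixed blocks as deduplicated partial masks. A self-contained sketch of that classification, using hypothetical block sizes and assuming 0/1 as the codes for empty/partial blocks:

    import numpy as np

    q_block_size, kv_block_size = 4, 4
    mask = np.tril(np.ones((8, 8), dtype=bool))  # causal mask over a 2x2 block grid

    q_blocks_count = mask.shape[0] // q_block_size
    kv_blocks_count = mask.shape[1] // kv_block_size
    state_grid = np.zeros((q_blocks_count, kv_blocks_count), dtype=np.int32)

    for q_idx, kv_idx in np.ndindex((q_blocks_count, kv_blocks_count)):
      chunk = mask[(
          slice(q_idx * q_block_size, (q_idx + 1) * q_block_size),
          slice(kv_idx * kv_block_size, (kv_idx + 1) * kv_block_size),
      )]
      if chunk.any():
        # 2 = fully active block, 1 = partial block that needs its own mask
        state_grid[q_idx, kv_idx] = 2 if chunk.all() else 1

    print(state_grid)  # [[1 0]
                       #  [2 1]]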

src/maxdiffusion/kernels/splash_attention/splash_attention_mask_test.py

Lines changed: 52 additions & 59 deletions
@@ -374,39 +374,37 @@ def test_lazy_causal_mask_chunking(self, block_size: tuple[int, int], shape: tup
         block_size,
     )

-  @parameterized.parameters(
-      [
-          ((256, 256), (1024, 1024), (128, None), 0),
-          ((256, 128), (1024, 1024), (128, None), 16),
-          ((128, 256), (1024, 1024), (128, None), 16),
-          ((256, 256), (1024, 1024), (128, 256), 0),
-          ((256, 128), (1024, 1024), (128, 256), 0),
-          ((128, 256), (1024, 1024), (128, 256), 16),
-          ((256, 256), (1024, 1024), (None, 256), 0),
-          ((256, 128), (1024, 1024), (None, 256), 32),
-          ((128, 256), (1024, 1024), (None, 256), 32),
-          #
-          ((256, 256), (1024, 2048), (128, None), 0),
-          ((256, 128), (1024, 2048), (128, None), 16),
-          ((128, 256), (1024, 2048), (128, None), 16),
-          ((256, 256), (1024, 2048), (128, 256), 0),
-          ((256, 128), (1024, 2048), (128, 256), 0),
-          ((128, 256), (1024, 2048), (128, 256), 16),
-          ((256, 256), (1024, 2048), (None, 256), 0),
-          ((256, 128), (1024, 2048), (None, 256), 32),
-          ((128, 256), (1024, 2048), (None, 256), 32),
-          #
-          ((256, 256), (2048, 1024), (128, None), 0),
-          ((256, 128), (2048, 1024), (128, None), 16),
-          ((128, 256), (2048, 1024), (128, None), 16),
-          ((256, 256), (2048, 1024), (128, 256), 0),
-          ((256, 128), (2048, 1024), (128, 256), 0),
-          ((128, 256), (2048, 1024), (128, 256), 16),
-          ((256, 256), (2048, 1024), (None, 256), 0),
-          ((256, 128), (2048, 1024), (None, 256), 32),
-          ((128, 256), (2048, 1024), (None, 256), 32),
-      ]
-  )
+  @parameterized.parameters([
+      ((256, 256), (1024, 1024), (128, None), 0),
+      ((256, 128), (1024, 1024), (128, None), 16),
+      ((128, 256), (1024, 1024), (128, None), 16),
+      ((256, 256), (1024, 1024), (128, 256), 0),
+      ((256, 128), (1024, 1024), (128, 256), 0),
+      ((128, 256), (1024, 1024), (128, 256), 16),
+      ((256, 256), (1024, 1024), (None, 256), 0),
+      ((256, 128), (1024, 1024), (None, 256), 32),
+      ((128, 256), (1024, 1024), (None, 256), 32),
+      #
+      ((256, 256), (1024, 2048), (128, None), 0),
+      ((256, 128), (1024, 2048), (128, None), 16),
+      ((128, 256), (1024, 2048), (128, None), 16),
+      ((256, 256), (1024, 2048), (128, 256), 0),
+      ((256, 128), (1024, 2048), (128, 256), 0),
+      ((128, 256), (1024, 2048), (128, 256), 16),
+      ((256, 256), (1024, 2048), (None, 256), 0),
+      ((256, 128), (1024, 2048), (None, 256), 32),
+      ((128, 256), (1024, 2048), (None, 256), 32),
+      #
+      ((256, 256), (2048, 1024), (128, None), 0),
+      ((256, 128), (2048, 1024), (128, None), 16),
+      ((128, 256), (2048, 1024), (128, None), 16),
+      ((256, 256), (2048, 1024), (128, 256), 0),
+      ((256, 128), (2048, 1024), (128, 256), 0),
+      ((128, 256), (2048, 1024), (128, 256), 16),
+      ((256, 256), (2048, 1024), (None, 256), 0),
+      ((256, 128), (2048, 1024), (None, 256), 32),
+      ((128, 256), (2048, 1024), (None, 256), 32),
+  ])
   def test_lazy_local_mask_chunking(
       self,
       block_size: tuple[int, int],
@@ -1164,17 +1162,15 @@ def test_two_qseq_shards_causal_local_stacked(self):

     expected_num_active_blocks = np.array([10, 10], dtype=np.int32)

-    expected_partial_mask_blocks = np.stack(
-        [
-            np.tri(*block_shape, dtype=np.int8),
-            np.triu(
-                np.tri(*block_shape, window_size, dtype=np.int8),
-                -window_size,
-            ),
-            np.tri(*block_shape, -window_size, dtype=np.int8),
-            np.triu(np.ones(block_shape, dtype=np.int8), window_size),
-        ]
-    )
+    expected_partial_mask_blocks = np.stack([
+        np.tri(*block_shape, dtype=np.int8),
+        np.triu(
+            np.tri(*block_shape, window_size, dtype=np.int8),
+            -window_size,
+        ),
+        np.tri(*block_shape, -window_size, dtype=np.int8),
+        np.triu(np.ones(block_shape, dtype=np.int8), window_size),
+    ])

     expected_mask_info = mask_info_lib.MaskInfo(
         expected_mask_next,
@@ -1345,20 +1341,18 @@ def test_two_shards_local_wide_local_narrow_stacked(self, q_seq_shards, kv_seq_s

     expected_active_rows_dkv = np.concatenate(
         [
-            np.array(
-                [
-                    0,
-                    0,
-                    1,
-                    1,
-                    1,
-                    2,
-                    2,
-                    2,
-                    3,
-                    3,
-                ]
-            ),
+            np.array([
+                0,
+                0,
+                1,
+                1,
+                1,
+                2,
+                2,
+                2,
+                3,
+                3,
+            ]),
             np.array([0, 0, 1, 1, 2, 2, 3, -1, -1, -1]),
         ],
         axis=0,
@@ -1453,7 +1447,6 @@ def test_causal_two_q_shards_two_kv_shards(self, return_dynamic_grid):
           q_sequence=None,
       )
     else:
-
       expected_mask_info_dkv = mask_info_lib.MaskInfo(
           mask_next=np.array(
               [0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0],
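As a reading aid for the expected blocks above (not part of the diff): `np.triu(np.tri(n, n, w), -w)` keeps exactly the entries with offset j - i in [-w, w], i.e. a local attention band of radius w within one block, while plain `np.tri` gives the causal (lower-triangular) block. A quick numpy sketch checking that identity:

    import numpy as np

    n, w = 6, 2
    band = np.triu(np.tri(n, n, w, dtype=np.int8), -w)

    # band[i, j] == 1 exactly when -w <= j - i <= w
    i, j = np.indices((n, n))
    assert np.array_equal(band, ((j - i >= -w) & (j - i <= w)).astype(np.int8))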

src/maxdiffusion/models/attention_flax.py

Lines changed: 0 additions & 4 deletions
@@ -295,7 +295,6 @@ def _tpu_flash_attention(
       check_rep=False,
   )
   def wrap_flash_attention(query, key, value):
-
     uses_fused_kernel = block_sizes.use_fused_bwd_kernel
     block_q_sizes = (
         block_sizes.block_q,
@@ -1251,7 +1250,6 @@ def setup(self):
     )

   def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
-
     qkv_proj = self.qkv(hidden_states)
     B, L = hidden_states.shape[:2]
     H, D, K = self.heads, qkv_proj.shape[-1] // (self.heads * 3), 3
@@ -1263,7 +1261,6 @@ def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=Non
     key_proj = self.key_norm(key_proj)

     if encoder_hidden_states is not None:
-
       encoder_qkv_proj = self.encoder_qkv(encoder_hidden_states)
       B, L = encoder_hidden_states.shape[:2]
       H, D, K = self.heads, encoder_qkv_proj.shape[-1] // (self.heads * 3), 3
@@ -1357,7 +1354,6 @@ class FlaxAttention(nn.Module):
   quant: Quant = None

   def setup(self):
-
     if self.attention_kernel == "flash" and self.mesh is None:
       raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
     inner_dim = self.dim_head * self.heads

src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py

Lines changed: 42 additions & 46 deletions
@@ -202,29 +202,27 @@ def setup(self):
         dtype=self.dtype,
         param_dtype=self.weights_dtype,
     )
-    self.img_mlp = nn.Sequential(
-        [
-            nn.Dense(
-                int(self.dim * self.mlp_ratio),
-                use_bias=True,
-                kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
-                bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
-                dtype=self.dtype,
-                param_dtype=self.weights_dtype,
-                precision=self.precision,
-            ),
-            nn.gelu,
-            nn.Dense(
-                self.dim,
-                use_bias=True,
-                kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
-                bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
-                dtype=self.dtype,
-                param_dtype=self.weights_dtype,
-                precision=self.precision,
-            ),
-        ]
-    )
+    self.img_mlp = nn.Sequential([
+        nn.Dense(
+            int(self.dim * self.mlp_ratio),
+            use_bias=True,
+            kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
+            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
+            dtype=self.dtype,
+            param_dtype=self.weights_dtype,
+            precision=self.precision,
+        ),
+        nn.gelu,
+        nn.Dense(
+            self.dim,
+            use_bias=True,
+            kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
+            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
+            dtype=self.dtype,
+            param_dtype=self.weights_dtype,
+            precision=self.precision,
+        ),
+    ])

     self.txt_norm2 = nn.LayerNorm(
         use_bias=False,
@@ -233,29 +231,27 @@ def setup(self):
         dtype=self.dtype,
         param_dtype=self.weights_dtype,
     )
-    self.txt_mlp = nn.Sequential(
-        [
-            nn.Dense(
-                int(self.dim * self.mlp_ratio),
-                use_bias=True,
-                kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
-                bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
-                dtype=self.dtype,
-                param_dtype=self.weights_dtype,
-                precision=self.precision,
-            ),
-            nn.gelu,
-            nn.Dense(
-                self.dim,
-                use_bias=True,
-                kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
-                bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
-                dtype=self.dtype,
-                param_dtype=self.weights_dtype,
-                precision=self.precision,
-            ),
-        ]
-    )
+    self.txt_mlp = nn.Sequential([
+        nn.Dense(
+            int(self.dim * self.mlp_ratio),
+            use_bias=True,
+            kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
+            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
+            dtype=self.dtype,
+            param_dtype=self.weights_dtype,
+            precision=self.precision,
+        ),
+        nn.gelu,
+        nn.Dense(
+            self.dim,
+            use_bias=True,
+            kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
+            bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
+            dtype=self.dtype,
+            param_dtype=self.weights_dtype,
+            precision=self.precision,
+        ),
+    ])

     # let chunk size default to None
     self._chunk_size = None
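To verify the sync locally, the pinned formatter should be a no-op on this commit; assuming the standard PyPI package and its Black-compatible CLI, `pip install pyink==23.10.0` followed by `pyink --check src/maxdiffusion` (which exits non-zero if anything would still be reformatted) should pass.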
