
Commit abb97c3

fix: address various common pipeline bugs
- models/resnet_flax: add missing bias_init logical partitioning to Conv layers
- max_utils: use dict.get() for flash_block_sizes to prevent KeyErrors
- maxdiffusion_utils: simplify the VAE latent encoding loop (pad to the local batch size, then truncate) to prevent dropped samples and out-of-bounds loop references in the PyArrow dataset path
- models/attention_flax: allow custom flash_block_sizes to override cross-attention defaults
1 parent ceca471 commit abb97c3

4 files changed

Lines changed: 26 additions & 25 deletions

src/maxdiffusion/max_utils.py

Lines changed: 8 additions & 8 deletions
@@ -520,20 +520,20 @@ def get_flash_block_sizes(config):
       max_logging.log(
           "Tokamax kernel specified, Note: Tokamax only supports fused backward kernel."
           "Hence following flash block properties specified will be ignored:"
-          f"block_q: {user_block_sizes['block_q']},"
+          f"block_q: {user_block_sizes.get('block_q')},"
           f"block_q_dq: {user_block_sizes.get('block_q_dq')},"
           f"block_kv_dq: {user_block_sizes.get('block_kv_dq')},"
           f"use_fused_bwd_kernel: {user_block_sizes.get('use_fused_bwd_kernel')}"
       )
     flash_block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=user_block_sizes.get("block_q_dkv", user_block_sizes["block_kv"])
+        block_q=user_block_sizes.get("block_q_dkv", user_block_sizes.get("block_kv"))
         if attention_is_tokamax
-        else user_block_sizes["block_q"],
-        block_kv_compute=user_block_sizes["block_kv_compute"],
-        block_kv=user_block_sizes["block_kv"],
-        block_q_dkv=user_block_sizes["block_q_dkv"],
-        block_kv_dkv=user_block_sizes["block_kv_dkv"],
-        block_kv_dkv_compute=user_block_sizes["block_kv_dkv_compute"],
+        else user_block_sizes.get("block_q"),
+        block_kv_compute=user_block_sizes.get("block_kv_compute", user_block_sizes.get("block_kv")),
+        block_kv=user_block_sizes.get("block_kv"),
+        block_q_dkv=user_block_sizes.get("block_q_dkv", user_block_sizes.get("block_q")),
+        block_kv_dkv=user_block_sizes.get("block_kv_dkv", user_block_sizes.get("block_kv")),
+        block_kv_dkv_compute=user_block_sizes.get("block_kv_dkv_compute", user_block_sizes.get("block_kv")),
         block_q_dq=None if attention_is_tokamax else value_or_none(user_block_sizes, "block_q_dq"),
         block_kv_dq=None if attention_is_tokamax else value_or_none(user_block_sizes, "block_kv_dq"),
         use_fused_bwd_kernel=True if attention_is_tokamax else value_or_none(user_block_sizes, "use_fused_bwd_kernel"),
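
A note on the max_utils change: bracket indexing on the flash-block config raises KeyError whenever an optional key such as block_kv_dkv is missing, while dict.get() falls back to a neighbouring value. A minimal sketch of that fallback pattern, using a hypothetical user_block_sizes dict in place of the real config object:

```python
# Hypothetical flash-block config that omits the backward-pass keys.
user_block_sizes = {"block_q": 256, "block_kv": 512}

# Old style: user_block_sizes["block_q_dkv"] raises KeyError for a missing key.
# New style: fall back to the corresponding forward-pass value instead.
block_q_dkv = user_block_sizes.get("block_q_dkv", user_block_sizes.get("block_q"))
block_kv_dkv = user_block_sizes.get("block_kv_dkv", user_block_sizes.get("block_kv"))

print(block_q_dkv)   # 256 -- falls back to block_q
print(block_kv_dkv)  # 512 -- falls back to block_kv
```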

src/maxdiffusion/maxdiffusion_utils.py

Lines changed: 9 additions & 11 deletions
@@ -114,10 +114,15 @@ def transform_images(
   if p_vae_apply:
     tensor_list = np.stack(tensor_list)
     ds_length = tensor_list.shape[0]
-    iters = ds_length // global_batch_size
-    latents_list = []
     local_batch_size = global_batch_size // jax.device_count()
-    for i in range(0, iters * global_batch_size, local_batch_size):
+
+    pad_len = (local_batch_size - (ds_length % local_batch_size)) % local_batch_size
+    if pad_len > 0:
+      pad_tensor = np.zeros((pad_len,) + tensor_list.shape[1:], dtype=tensor_list.dtype)
+      tensor_list = np.concatenate([tensor_list, pad_tensor], axis=0)
+
+    latents_list = []
+    for i in range(0, tensor_list.shape[0], local_batch_size):
       sample_rng, rng = jax.random.split(rng)
       latents = p_vae_apply(tensor_list[i : i + local_batch_size], sample_rng)
       latents_list.append(latents)
@@ -126,14 +131,7 @@ def transform_images(
     b1, b2, c, l1, l2 = latents_list.shape
     latents_list = np.reshape(latents_list, (b1 * b2, c, l1, l2))

-    # TODO (Juan Acevedo): do last iteration, its required for the Pyarrow dataset
-    # to not break due to items being fewer than expected. Is there a better way?
-    if tensor_list[i + local_batch_size :].shape[0] != 0:
-      sample_rng, rng = jax.random.split(rng)
-      latents = p_vae_apply(tensor_list[i + local_batch_size :], sample_rng)
-      examples[pixel_ids_key] = np.append(latents_list, latents, axis=0)
-    else:
-      examples[pixel_ids_key] = latents_list
+    examples[pixel_ids_key] = latents_list[:ds_length]
   else:
     examples[pixel_ids_key] = tf.stack(tensor_list)

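The rewritten loop pads the stacked images up to a multiple of the per-host batch size, encodes fixed-size chunks, and slices the result back to the original dataset length, so the old trailing remainder branch (and its out-of-scope use of the loop variable) is no longer needed. A standalone NumPy sketch of the same pattern, with a stand-in encode function in place of p_vae_apply and made-up shapes:

```python
import numpy as np

def encode(batch):
    # Stand-in for p_vae_apply: just average each image down to a "latent".
    return batch.mean(axis=(1, 2), keepdims=True)

tensor_list = np.random.rand(10, 64, 64, 3).astype(np.float32)  # 10 samples
ds_length = tensor_list.shape[0]
local_batch_size = 4

# Pad so every chunk handed to the encoder is a full local batch.
pad_len = (local_batch_size - (ds_length % local_batch_size)) % local_batch_size
if pad_len > 0:
    pad = np.zeros((pad_len,) + tensor_list.shape[1:], dtype=tensor_list.dtype)
    tensor_list = np.concatenate([tensor_list, pad], axis=0)

latents = np.concatenate(
    [encode(tensor_list[i : i + local_batch_size])
     for i in range(0, tensor_list.shape[0], local_batch_size)],
    axis=0,
)

# Drop the padded rows so the output length matches the dataset again.
latents = latents[:ds_length]
assert latents.shape[0] == ds_length
```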
src/maxdiffusion/models/attention_flax.py

Lines changed: 4 additions & 6 deletions
@@ -250,19 +250,17 @@ def _tpu_flash_attention(
     kv_max_block_size = ((key.shape[1] + 127) // 128) * 128
   else:
     kv_max_block_size = q_max_block_size
-  # ensure that for cross attention we override the block sizes.
-  if flash_block_sizes and key.shape[1] == query.shape[1]:
+  if flash_block_sizes:
     block_sizes = flash_block_sizes
   else:
-    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=block_size_q,
+        block_q=q_max_block_size,
         block_kv_compute=min(kv_max_block_size, key.shape[2]),
         block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=block_size_q,
+        block_q_dkv=q_max_block_size,
         block_kv_dkv=min(kv_max_block_size, key.shape[2]),
         block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
+        block_q_dq=None if attention_kernel == "tokamax_flash" else q_max_block_size,
         block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
         use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
     )
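
With this change, any user-supplied flash_block_sizes is used as-is for both self- and cross-attention; the rounded-up defaults are only computed when nothing was configured. A rough sketch of that selection logic with plain integers (pick_block_sizes and block_config are illustrative names, not the module's API):

```python
def pick_block_sizes(query_seq_len, key_seq_len, block_config=None):
    # Round sequence lengths up to a multiple of 128, as the kernel expects.
    q_max_block_size = ((query_seq_len + 127) // 128) * 128
    kv_max_block_size = ((key_seq_len + 127) // 128) * 128

    if block_config is not None:
        # User-supplied sizes now win, for cross-attention as well.
        return block_config
    # Otherwise fall back to the rounded-up defaults.
    return {"block_q": q_max_block_size, "block_kv": kv_max_block_size}

print(pick_block_sizes(1000, 77))                    # {'block_q': 1024, 'block_kv': 128}
print(pick_block_sizes(1000, 77, {"block_q": 256}))  # {'block_q': 256}
```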

src/maxdiffusion/models/resnet_flax.py

Lines changed: 5 additions & 0 deletions
@@ -51,6 +51,7 @@ def setup(self):
         kernel_init=nn.with_logical_partitioning(
             nn.initializers.lecun_normal(), ("keep_1", "keep_2", "conv_in", "conv_out")
         ),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("conv_out",)),
         precision=self.precision,
     )

@@ -85,6 +86,7 @@ def setup(self):
         kernel_init=nn.with_logical_partitioning(
             nn.initializers.lecun_normal(), ("keep_1", "keep_2", "conv_in", "conv_out")
         ),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("conv_out",)),
         precision=self.precision,
     )

@@ -131,6 +133,7 @@ def setup(self):
         kernel_init=nn.with_logical_partitioning(
             nn.initializers.lecun_normal(), ("keep_1", "keep_2", "conv_in", "conv_out")
         ),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("conv_out",)),
         precision=self.precision,
     )
     out_channels = self.in_channels if self.out_channels is None else self.out_channels
@@ -144,6 +147,7 @@ def setup(self):
         kernel_init=nn.with_logical_partitioning(
             nn.initializers.lecun_normal(), ("keep_1", "keep_2", "conv_in", "conv_out")
         ),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("conv_out",)),
         precision=self.precision,
     )

@@ -159,6 +163,7 @@ def setup(self):
         kernel_init=nn.with_logical_partitioning(
             nn.initializers.lecun_normal(), ("keep_1", "keep_2", "conv_in", "conv_out")
         ),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("conv_out",)),
         precision=self.precision,
     )

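All five resnet_flax hunks add the same thing: a bias_init wrapped in nn.with_logical_partitioning, so the Conv bias carries a logical sharding axis just like the kernel already does. A minimal Flax sketch of one such layer; the feature count, kernel size, and input shape are made up, while the axis names match the diff:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

conv = nn.Conv(
    features=64,        # illustrative channel count
    kernel_size=(3, 3),
    kernel_init=nn.with_logical_partitioning(
        nn.initializers.lecun_normal(), ("keep_1", "keep_2", "conv_in", "conv_out")
    ),
    # The fix: annotate the bias with its own logical axis as well.
    bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("conv_out",)),
)

variables = conv.init(jax.random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
# Both kernel and bias are now boxed with logical axis metadata; printing the
# leaf shapes shows kernel (3, 3, 3, 64) and bias (64,).
print(jax.tree_util.tree_map(lambda x: x.shape, variables))
```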