Commit 69a93b9

wip - context parallelism
1 parent e2cb67f commit 69a93b9

4 files changed

Lines changed: 63 additions & 23 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 12 additions & 4 deletions
@@ -112,8 +112,11 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
   ['batch', 'data'],
-  ['activation_heads', 'fsdp'],
-  ['activation_batch', ['data','fsdp']],
+  #['activation_heads', 'fsdp'],
+  ['activation_length', 'fsdp'],
+  #['activation_heads', 'fsdp'],
+  #['activation_heads', 'fsdp'],
+  #['activation_batch', ['data','fsdp']],
   ['activation_kv', 'tensor'],
   ['mlp','tensor'],
   ['embed','fsdp'],
@@ -141,14 +144,15 @@ ici_tensor_parallelism: 1
 # Replace with dataset path or train_data_dir. One has to be set.
 dataset_name: 'diffusers/pokemon-gpt4-captions'
 train_split: 'train'
-dataset_type: 'tf'
+dataset_type: 'tfrecord'
 cache_latents_text_encoder_outputs: True
 # cache_latents_text_encoder_outputs only apply to dataset_type="tf",
 # only apply to small dataset that fits in memory
 # prepare image latents and text encoder outputs
 # Reduce memory consumption and reduce step time during training
 # transformed dataset is saved at dataset_save_location
-dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
+dataset_save_location: ''
+load_tfrecord_cached: True
 train_data_dir: ''
 dataset_config_name: ''
 jax_cache_dir: ''
@@ -185,6 +189,10 @@ per_device_batch_size: 1
 # If global_batch_size % jax.device_count is not 0, use FSDP sharding.
 global_batch_size: 0
 
+# For creating tfrecords from dataset
+tfrecords_dir: ''
+no_records_per_shard: 0
+
 warmup_steps_fraction: 0.1
 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
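
Note (not part of the commit): the logical_axis_rules change stops sharding activations over heads and batch on 'fsdp' and instead maps 'activation_length' to 'fsdp', so it is the sequence (context) dimension that gets split across that mesh axis; the dataset block switches dataset_type to 'tfrecord' and adds load_tfrecord_cached, tfrecords_dir and no_records_per_shard for creating and loading tfrecord shards. Below is a minimal sketch of how the new sequence rule resolves through Flax's logical-axis machinery; the rule subset and the logical activation shape are illustrative assumptions, not values taken from the model.

from flax import linen as nn

rules = (
  ("batch", "data"),
  ("activation_length", "fsdp"),
  ("activation_kv", "tensor"),
)

with nn.logical_axis_rules(rules):
  # A (batch, sequence, kv_feature) activation: the sequence dimension now
  # lands on the 'fsdp' mesh axis, i.e. the context is sharded over devices.
  spec = nn.logical_to_mesh_axes(("batch", "activation_length", "activation_kv"))

print(spec)  # PartitionSpec('data', 'fsdp', 'tensor')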

src/maxdiffusion/generate_wan.py

Lines changed: 6 additions & 3 deletions
@@ -20,10 +20,13 @@
 from absl import app
 from maxdiffusion.utils import export_to_video
 
+jax.config.update('jax_use_shardy_partitioner', True)
 
-def run(config):
+
+def run(config, pipeline=None, filename_prefix=""):
   print("seed: ", config.seed)
-  pipeline = WanPipeline.from_pretrained(config)
+  if pipeline is None:
+    pipeline = WanPipeline.from_pretrained(config)
   s0 = time.perf_counter()
 
   # Skip layer guidance
@@ -59,7 +62,7 @@ def run(config):
 
   print("compile time: ", (time.perf_counter() - s0))
   for i in range(len(videos)):
-    export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)
+    export_to_video(videos[i], f"{filename_prefix}wan_output_{config.seed}_{i}.mp4", fps=config.fps)
   s0 = time.perf_counter()
   videos = pipeline(
       prompt=prompt,
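
Note (not part of the commit): run() now accepts an optional pre-built pipeline and a filename_prefix, so a caller can load the Wan pipeline once and reuse it across several generation calls. A hedged usage sketch follows; the import paths are assumptions based on the file layout in this commit, and the config construction is elided and assumed to be an already-initialized MaxDiffusion config object.

from maxdiffusion.generate_wan import run
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline

# config: an already-initialized MaxDiffusion config object (construction elided).
pipeline = WanPipeline.from_pretrained(config)

# Reuse the same pipeline for two runs; outputs are written as
# baseline_wan_output_<seed>_<i>.mp4 and cp_wan_output_<seed>_<i>.mp4.
run(config, pipeline=pipeline, filename_prefix="baseline_")
run(config, pipeline=pipeline, filename_prefix="cp_")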

src/maxdiffusion/models/attention_flax.py

Lines changed: 44 additions & 15 deletions
@@ -173,25 +173,54 @@ def _tpu_flash_attention(
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute)
 
   axis_names = nn.logical_to_mesh_axes(flash_axis_names)
+  kv_axis_names = nn.logical_to_mesh_axes((BATCH, HEAD, None, D_KV))
+  flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
+  axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
+  named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
+
+  cp_size=8
 
   @functools.partial(
-      shard_map.shard_map,
-      mesh=mesh,
-      in_specs=(
-          axis_names,
-          axis_names,
-          axis_names,
-      ),
-      out_specs=axis_names,
-      check_rep=False,
+      jax.jit,
+      static_argnames=[
+          "multi_head_mask",
+          "shard_head_size"
+      ],
   )
-  def wrap_flash_attention(query, key, value):
-    masks = [splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) for _ in range(query.shape[1])]
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=masks)
+  def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
     splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes
+        mask=multi_head_mask,
+        head_shards=shard_head_size, # the sizes of the axis is sharding over heads
+        q_seq_shards=cp_size,
+        block_sizes=block_sizes,
     )
-    return jax.vmap(splash_kernel)(query, key, value)
+    return splash_kernel
+
+  shard_head_size = 1
+  mask = splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2]))
+  mask &= splash_attention_mask.LocalMask(
+      shape=(query.shape[2], key.shape[2]),
+      window_size=(query.shape[2], query.shape[2]),
+      offset=0
+  )
+  multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+  splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
+  segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
+  @functools.partial(
+      shard_map.shard_map,
+      mesh=mesh,
+      in_specs=(
+          axis_names,
+          kv_axis_names,
+          kv_axis_names,
+          segment_axis_names_splash_kernel,
+      ),
+      out_specs=axis_names,
+      check_rep=False
+  )
+  def wrap_flash_attention(query, key, value, splash_kernel):
+    attention_output = jax.vmap(splash_kernel)(query, key, value)
+    return attention_output
 
   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
   # This warning might show up when doing model eval for example, when calculating model flops
@@ -201,7 +230,7 @@ def wrap_flash_attention(query, key, value):
         "Warning, batch dimension should be shardable among the devices in data and fsdp"
         f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
     )
-  x = wrap_flash_attention(query, key, value)
+  x = wrap_flash_attention(query, key, value, splash_kernel)
   x = x[:, :, :query_seq_len, :kv_size]
   x = _reshape_heads_to_head_dim(x)
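
Note (not part of the commit): the rewritten attention path builds one splash kernel with q_seq_shards=cp_size, so the query sequence is split into cp_size context-parallel shards, while kv_axis_names leaves the key/value sequence dimension unsharded so each query shard attends over the full K/V sequence; the kernel's manual_sharding_spec is then passed to shard_map as an extra input spec. cp_size is hard-coded to 8 in this wip commit. Below is a hedged sketch, an assumption rather than the author's code, of deriving it from the mesh instead, given that the new logical_axis_rules map 'activation_length' to 'fsdp'.

import jax
import numpy as np
from jax.sharding import Mesh

# Illustrative mesh with the same axis names used by the config.
devices = np.array(jax.devices()).reshape(1, -1, 1)  # ('data', 'fsdp', 'tensor')
mesh = Mesh(devices, ("data", "fsdp", "tensor"))

# Number of query-sequence shards = size of the mesh axis the sequence is
# sharded over; this would replace the hard-coded cp_size = 8.
cp_size = mesh.shape["fsdp"]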

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -397,7 +397,7 @@ def __call__(
         num_channels_latents=num_channel_latents,
     )
 
-    data_sharding = NamedSharding(self.devices_array, P())
+    data_sharding = NamedSharding(self.mesh, P())
     if len(prompt) % jax.device_count() == 0:
       data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))
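
Note (not part of the commit): jax.sharding.NamedSharding takes a Mesh as its first argument, so constructing it from self.devices_array (a raw array of devices) was incorrect; the fix passes self.mesh. A minimal illustration with an assumed mesh layout:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(1, -1, 1), ("data", "fsdp", "tensor"))
replicated = NamedSharding(mesh, P())  # empty spec = fully replicated, as in the fallback branch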
