
Commit 3e540c5

Add option to replicate the VAE; fix cross-attention splash block sizes.

1 parent f279995 · commit 3e540c5

5 files changed: 34 additions & 12 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 0 deletions
```diff
@@ -40,6 +40,9 @@ weights_dtype: 'bfloat16'
 # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
 activations_dtype: 'bfloat16'
 
+# Replicates vae across devices instead of using the model's sharding annotations for sharding.
+replicate_vae: False
+
 # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
 # Options are "DEFAULT", "HIGH", "HIGHEST"
 # fp32 activations and fp32 weights with HIGHEST will provide the best precision
```
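
To see what the new flag means in isolation: a `NamedSharding` built from an empty `PartitionSpec()` places a full copy of an array on every device in the mesh, while a non-empty spec splits it across a mesh axis. A minimal sketch, assuming a 1-D mesh and toy shapes (none of this is MaxDiffusion code):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Toy 1-D mesh over whatever devices are available.
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))

# Stand-in for a VAE weight; rows sized to divide evenly across devices.
weights = np.ones((4 * jax.device_count(), 8), dtype=np.float32)

# P() = no partitioned dimensions = full copy on every device.
replicated = jax.device_put(weights, NamedSharding(mesh, P()))
# P("fsdp") = dimension 0 split across the "fsdp" mesh axis.
sharded = jax.device_put(weights, NamedSharding(mesh, P("fsdp")))

print(replicated.sharding.is_fully_replicated)  # True
print(sharded.sharding.is_fully_replicated)     # False on multi-device meshes
```

Replication trades memory for simplicity: every device holds the whole VAE, so no resharding collectives are needed when it runs.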

src/maxdiffusion/max_utils.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -221,6 +221,10 @@ def walk_and_upload_blobs(config, output_dir):
 
 
 def device_put_replicated(x, sharding):
+  """
+  Although the name indicates replication, this function can be used
+  to also shard an array based on sharding.
+  """
   return jax.make_array_from_callback(x.shape, sharding, lambda index: x[index])
 
 
```
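
The docstring is worth underlining: `jax.make_array_from_callback` builds a global array by asking the callback for whatever index each device's shard covers, so the helper replicates or shards purely depending on the `sharding` it is given. A usage sketch (the helper body is copied from the diff; the setup around it is illustrative):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def device_put_replicated(x, sharding):
  """
  Although the name indicates replication, this function can be used
  to also shard an array based on sharding.
  """
  # `index` is a tuple of slices covering one device's shard: the whole
  # array under P(), a sub-block under a partitioned spec.
  return jax.make_array_from_callback(x.shape, sharding, lambda index: x[index])


mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))
x = np.arange(4 * jax.device_count() * 8, dtype=np.float32).reshape(-1, 8)

full_copy = device_put_replicated(x, NamedSharding(mesh, P()))         # replicated
row_shards = device_put_replicated(x, NamedSharding(mesh, P("fsdp")))  # split on dim 0
```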

src/maxdiffusion/models/attention_flax.py

Lines changed: 15 additions & 10 deletions
```diff
@@ -166,20 +166,25 @@ def _tpu_flash_attention(
     dtype: jnp.dtype = jnp.float32,
 ) -> jax.Array:
   """TPU Flash Attention"""
-
-  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  # Cross-attention where kv dims are much smaller due to encoder_hidden_states.
+  # If kv seq_len is padded too much, it causes issues in attention calculations.
+  if key.shape[1] != query.shape[1]:
+    kv_max_block_size = key.shape[1]
+  else:
+    kv_max_block_size = q_max_block_size
   if flash_block_sizes:
     block_sizes = flash_block_sizes
   else:
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=min(max_block_size, query.shape[2]),
-        block_kv_compute=min(max_block_size, key.shape[2]),
-        block_kv=min(max_block_size, key.shape[2]),
-        block_q_dkv=min(max_block_size, query.shape[2]),
-        block_kv_dkv=min(max_block_size, key.shape[2]),
-        block_kv_dkv_compute=min(max_block_size, query.shape[2]),
-        block_q_dq=min(max_block_size, query.shape[2]),
-        block_kv_dq=min(max_block_size, query.shape[2]),
+        block_q=min(q_max_block_size, query.shape[2]),
+        block_kv_compute=min(kv_max_block_size, key.shape[2]),
+        block_kv=min(kv_max_block_size, key.shape[2]),
+        block_q_dkv=min(q_max_block_size, query.shape[2]),
+        block_kv_dkv=min(kv_max_block_size, key.shape[2]),
+        block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
+        block_q_dq=min(q_max_block_size, query.shape[2]),
+        block_kv_dq=min(kv_max_block_size, query.shape[2]),
     )
 
   num_fsdp_shards = mesh.shape["fsdp"]
```
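
The motivation for splitting `max_block_size` into `q_max_block_size` and `kv_max_block_size`: the splash kernel pads each sequence up to a multiple of its block size, and for cross-attention the kv sequence (from `encoder_hidden_states`) is far shorter than the query sequence, so a 1024-wide kv block would be mostly padding. A standalone restatement of the selection logic with explicit sequence lengths (the sizes in the example calls are hypothetical):

```python
import jax.numpy as jnp


def pick_block_sizes(q_seq_len: int, kv_seq_len: int, dtype=jnp.bfloat16):
  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
  # Cross-attention: the kv sequence comes from the text encoder and is much
  # shorter. Capping block_kv at kv_seq_len keeps kv blocks from being padded
  # far past the real tokens, which distorts the attention calculation.
  if kv_seq_len != q_seq_len:
    kv_max_block_size = kv_seq_len
  else:
    kv_max_block_size = q_max_block_size
  return (min(q_max_block_size, q_seq_len), min(kv_max_block_size, kv_seq_len))


# Self-attention over ~75k video tokens vs. cross-attention to a 512-token
# text encoding (hypothetical sizes):
print(pick_block_sizes(75600, 75600))  # (1024, 1024)
print(pick_block_sizes(75600, 512))    # (1024, 512)
```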

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -220,6 +220,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
   params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
   for path, val in flax.traverse_util.flatten_dict(params).items():
     sharding = logical_state_sharding[path].value
+    if config.replicate_vae:
+      sharding = NamedSharding(mesh, P())
     state[path].value = device_put_replicated(val, sharding)
   state = nnx.from_flat_state(state)
```
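Putting the pieces together, the loop walks every VAE parameter and, when the flag is set, discards the model's logical sharding annotation in favor of full replication before handing the value to `device_put_replicated`. A hedged mini-version with a toy parameter tree (the dict contents and the `replicate_vae` and `annotated` variables are stand-ins for the real config and `logical_state_sharding`):

```python
import jax
import numpy as np
import flax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))
replicate_vae = True  # stand-in for config.replicate_vae

params = {"decoder": {"conv_in": {"kernel": np.ones((3, 3, 4, 8), np.float32)}}}
# Stand-in for logical_state_sharding: one annotation per parameter path.
annotated = {("decoder", "conv_in", "kernel"): NamedSharding(mesh, P(None, None, None, "fsdp"))}

placed = {}
for path, val in flax.traverse_util.flatten_dict(params).items():
  sharding = annotated[path]
  if replicate_vae:
    # Ignore the model's annotation; give every device a full copy.
    sharding = NamedSharding(mesh, P())
  placed[path] = jax.make_array_from_callback(val.shape, sharding, lambda i, v=val: v[i])

print(placed[("decoder", "conv_in", "kernel")].sharding.is_fully_replicated)  # True
```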

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 10 additions & 2 deletions
```diff
@@ -36,6 +36,7 @@
 from maxdiffusion.utils import load_video
 from skimage.metrics import structural_similarity as ssim
 from flax.training import train_state
+from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
 
 
 class TrainState(train_state.TrainState):
@@ -53,6 +54,10 @@ def generate_sample(config, pipeline, filename_prefix):
   """
   Generates a video to validate training did not corrupt the model
   """
+  if not hasattr(pipeline, "vae"):
+    wan_vae, vae_cache = WanPipeline.load_vae(pipeline.mesh.devices, pipeline.mesh, nnx.Rngs(jax.random.key(config.seed)), config)
+    pipeline.vae = wan_vae
+    pipeline.vae_cache = vae_cache
   return generate_wan(config, pipeline, filename_prefix)
 
 
@@ -140,10 +145,13 @@ def prepare_sample(features):
   def start_training(self):
 
     pipeline = self.load_checkpoint()
-    # del pipeline.vae
-
     # Generate a sample before training to compare against generated sample after training.
     pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
+
+    # save some memory.
+    del pipeline.vae
+    del pipeline.vae_cache
+
     mesh = pipeline.mesh
     data_iterator = self.load_dataset(mesh)
```
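
The trainer changes form a small lifecycle protocol: load the pipeline, render the pre-training sample while the VAE is still attached, `del` the VAE and its cache to free memory for training, and let the `hasattr` check in `generate_sample` reload it lazily for the post-training sample. A minimal sketch of that pattern with stand-in objects (none of these classes are the real MaxDiffusion types):

```python
class Pipeline:
  """Stand-in for the WAN pipeline; vae/vae_cache pretend to pin device memory."""

  def __init__(self):
    self.vae = object()
    self.vae_cache = object()


def load_vae():
  # Stand-in for WanPipeline.load_vae(mesh.devices, mesh, rngs, config).
  return object(), object()


def generate_sample(pipeline):
  # Reattach the VAE only if a previous step deleted it to save memory.
  if not hasattr(pipeline, "vae"):
    pipeline.vae, pipeline.vae_cache = load_vae()
  return "sample.mp4"


pipeline = Pipeline()
generate_sample(pipeline)   # pre-training sample: VAE already attached

del pipeline.vae            # free memory for the training loop;
del pipeline.vae_cache      # `del` removes the attributes entirely,

generate_sample(pipeline)   # so the post-training call reloads them lazily
```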
149157
