Skip to content

Commit 5b0b8cf

Browse files
authored
Add batch divisibility check for VAE input sharding (#316)
* Adding check for batch size divisibility before sharding video condition tensor * pyink checks * Removed unused var * Moving commit retrieval to before JAX setup init * fix * replaced boundary_timestep with ratio in wan2.2 t2v
1 parent a385c8e commit 5b0b8cf

10 files changed

Lines changed: 33 additions & 22 deletions

File tree

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ guidance_scale_high: 4.0
300300
# The timestep threshold. If `t` is at or above this value,
301301
# the `high_noise_model` is considered as the required model.
302302
# timestep to switch between low noise and high noise transformer
303-
boundary_timestep: 875
303+
boundary_ratio: 0.875
304304

305305
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
306306
guidance_rescale: 0.0

src/maxdiffusion/generate_wan.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
134134
num_inference_steps=config.num_inference_steps,
135135
guidance_scale_low=config.guidance_scale_low,
136136
guidance_scale_high=config.guidance_scale_high,
137-
boundary=config.boundary_timestep,
138137
)
139138
else:
140139
raise ValueError(f"Unsupported model_name for T2Vin config: {model_key}")
@@ -162,13 +161,12 @@ def inference_generate_video(config, pipeline, filename_prefix=""):
162161
return
163162

164163

165-
def run(config, pipeline=None, filename_prefix=""):
164+
def run(config, pipeline=None, filename_prefix="", commit_hash=None):
166165
model_key = config.model_name
167166
writer = max_utils.initialize_summary_writer(config)
168167
if jax.process_index() == 0 and writer:
169168
max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}")
170169

171-
commit_hash = get_git_commit_hash()
172170
if commit_hash:
173171
writer.add_text("inference/git_commit_hash", commit_hash, global_step=0)
174172
max_logging.log(f"Git Commit Hash: {commit_hash}")
@@ -250,12 +248,13 @@ def run(config, pipeline=None, filename_prefix=""):
250248

251249

252250
def main(argv: Sequence[str]) -> None:
251+
commit_hash = get_git_commit_hash()
253252
pyconfig.initialize(argv)
254253
try:
255254
flax.config.update("flax_always_shard_variable", False)
256255
except LookupError:
257256
pass
258-
run(pyconfig.config)
257+
run(pyconfig.config, commit_hash=commit_hash)
259258

260259

261260
if __name__ == "__main__":

src/maxdiffusion/models/attention_flax.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,9 @@ def ring_scan_body(carry, _):
384384
return (m, l, o, k_next, v_next), None
385385

386386
initial_carry = (m, l, o, k1, v1)
387-
(m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_context_shards - 1)
387+
(m_final, l_final, o_final, _, _), _ = jax.lax.scan(
388+
ring_scan_body, initial_carry, None, length=num_context_shards - 1
389+
)
388390

389391
attention_output = o_final / l_final[..., None]
390392
else:
@@ -749,6 +751,7 @@ def __init__(
749751
self.dpa_layer = None
750752
if attention_kernel == "cudnn_flash_te":
751753
from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error
754+
752755
jax.config.update("jax_use_shardy_partitioner", False)
753756

754757
dpa_layer = DotProductAttention(
@@ -829,6 +832,7 @@ def setup(self):
829832
self.dpa_layer = None
830833
if self.attention_kernel == "cudnn_flash_te":
831834
from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error
835+
832836
jax.config.update("jax_use_shardy_partitioner", False)
833837

834838
dpa_layer = DotProductAttention(
@@ -848,7 +852,6 @@ def setup(self):
848852
variables = {}
849853
self.dpa_layer = functools.partial(dpa_layer.apply, variables)
850854

851-
852855
def apply_attention(self, query: Array, key: Array, value: Array, attention_mask: Array = None):
853856
return _apply_attention(
854857
query=query,

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@
2929

3030
CACHE_T = 2
3131
try:
32-
flax.config.update('flax_always_shard_variable', False)
32+
flax.config.update("flax_always_shard_variable", False)
3333
except LookupError:
3434
pass
3535

36+
3637
# Helper to ensure kernel_size, stride, padding are tuples of 3 integers
3738
def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]:
3839
"""Canonicalizes a value to a tuple of integers."""

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -544,8 +544,10 @@ def prepare_latents_i2v_base(
544544
vae_dtype = getattr(self.vae, "dtype", jnp.float32)
545545
video_condition = video_condition.astype(vae_dtype)
546546
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
547-
sharding_spec = P(self.config.mesh_axes[0], None, None, None, None)
548-
video_condition = jax.lax.with_sharding_constraint(video_condition, sharding_spec)
547+
data_mesh_size = self.mesh.shape[self.config.mesh_axes[0]]
548+
if video_condition.shape[0] % data_mesh_size == 0:
549+
sharding_spec = P(self.config.mesh_axes[0], None, None, None, None)
550+
video_condition = jax.lax.with_sharding_constraint(video_condition, sharding_spec)
549551
encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode()
550552

551553
# Normalize latents

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ def __call__(
118118
)
119119
# Set the TE shard_guard context_manager if using TE cudnn_flash attention
120120
if self.config.attention == "cudnn_flash_te":
121-
from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error
121+
from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error
122+
122123
shard_guard = global_shard_guard(MeshResource(cp_resource="context"))
123124
else:
124125
shard_guard = nullcontext()

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(
3838
super().__init__(config=config, **kwargs)
3939
self.low_noise_transformer = low_noise_transformer
4040
self.high_noise_transformer = high_noise_transformer
41+
self.boundary_ratio = config.boundary_ratio
4142

4243
@classmethod
4344
def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True):
@@ -103,7 +104,6 @@ def __call__(
103104
num_inference_steps: int = 50,
104105
guidance_scale_low: float = 3.0,
105106
guidance_scale_high: float = 4.0,
106-
boundary: int = 875,
107107
num_videos_per_prompt: Optional[int] = 1,
108108
max_sequence_length: int = 512,
109109
latents: jax.Array = None,
@@ -129,18 +129,21 @@ def __call__(
129129
low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...)
130130
high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...)
131131

132+
boundary_timestep = self.boundary_ratio * self.scheduler.config.num_train_timesteps
133+
132134
p_run_inference = partial(
133135
run_inference_2_2,
134136
guidance_scale_low=guidance_scale_low,
135137
guidance_scale_high=guidance_scale_high,
136-
boundary=boundary,
138+
boundary=boundary_timestep,
137139
num_inference_steps=num_inference_steps,
138140
scheduler=self.scheduler,
139141
scheduler_state=scheduler_state,
140142
)
141143
# Set the TE shard_guard context_manager if using TE cudnn_flash attention
142144
if self.config.attention == "cudnn_flash_te":
143-
from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error
145+
from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error
146+
144147
shard_guard = global_shard_guard(MeshResource(cp_resource="context"))
145148
else:
146149
shard_guard = nullcontext()

src/maxdiffusion/pyconfig.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,9 @@ def user_init(raw_keys):
201201
raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
202202
# Verify qkv is sharded across sequence.
203203
if raw_keys["attention"] == "ring" or raw_keys["attention_sharding_uniform"]:
204-
max_logging.log(f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set.")
204+
max_logging.log(
205+
f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set."
206+
)
205207
logical_axis_rules = list(raw_keys["logical_axis_rules"])
206208
max_logging.log(f"Initial logical axis rules: {logical_axis_rules}")
207209
new_rules = []

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,7 @@ def test_wan_time_text_embedding(self):
125125

126126
encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
127127
dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
128-
temb, timestep_proj, encoder_hidden_states, _, _ = layer(
129-
dummy_timestep, dummy_encoder_hidden_states
130-
)
128+
temb, timestep_proj, encoder_hidden_states, _, _ = layer(dummy_timestep, dummy_encoder_hidden_states)
131129
assert temb.shape == (batch_size, dim)
132130
assert timestep_proj.shape == (batch_size, time_proj_dim)
133131
assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
309309
pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60)
310310
max_logging.log(pretty_string)
311311
max_logging.log("------------------------------------------------")
312-
if self.config.hardware != 'gpu':
312+
if self.config.hardware != "gpu":
313313
max_utils.delete_pytree(params)
314314
data_shardings = self.get_data_shardings(mesh)
315315
eval_data_shardings = self.get_eval_data_shardings(mesh)
@@ -368,14 +368,16 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
368368

369369
# Designate the context parallel axis for sharding
370370
if self.config.attention == "cudnn_flash_te":
371-
from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error
371+
from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error
372+
372373
shard_guard = global_shard_guard(MeshResource(cp_resource="context"))
373374
else:
374375
shard_guard = nullcontext()
375376

376377
next_batch_future = executor.submit(load_next_batch, train_data_iterator, example_batch, self.config)
377-
with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, \
378-
shard_guard, nn_partitioning.axis_rules(self.config.logical_axis_rules):
378+
with jax.profiler.StepTraceAnnotation(
379+
"train", step_num=step
380+
), pipeline.mesh, shard_guard, nn_partitioning.axis_rules(self.config.logical_axis_rules):
379381
state, scheduler_state, train_metric, rng = p_train_step(state, example_batch, rng, scheduler_state)
380382
train_metric["scalar"]["learning/loss"].block_until_ready()
381383
last_step_completion = datetime.datetime.now()

0 commit comments

Comments (0)