
Commit 4140047

committed
pyink checks
1 parent 14bee9e commit 4140047

7 files changed

Lines changed: 21 additions & 13 deletions


src/maxdiffusion/models/attention_flax.py

Lines changed: 5 additions & 2 deletions
@@ -384,7 +384,9 @@ def ring_scan_body(carry, _):
       return (m, l, o, k_next, v_next), None

     initial_carry = (m, l, o, k1, v1)
-    (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_context_shards - 1)
+    (m_final, l_final, o_final, _, _), _ = jax.lax.scan(
+        ring_scan_body, initial_carry, None, length=num_context_shards - 1
+    )

     attention_output = o_final / l_final[..., None]
   else:
@@ -749,6 +751,7 @@ def __init__(
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":
       from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error
+
       jax.config.update("jax_use_shardy_partitioner", False)

       dpa_layer = DotProductAttention(

@@ -829,6 +832,7 @@ def setup(self):
     self.dpa_layer = None
     if self.attention_kernel == "cudnn_flash_te":
       from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error
+
       jax.config.update("jax_use_shardy_partitioner", False)

       dpa_layer = DotProductAttention(

@@ -848,7 +852,6 @@ def setup(self):
       variables = {}
       self.dpa_layer = functools.partial(dpa_layer.apply, variables)

-
   def apply_attention(self, query: Array, key: Array, value: Array, attention_mask: Array = None):
     return _apply_attention(
         query=query,
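The last hunk keeps binding an (empty) variables dict to the module's apply via functools.partial so the attention layer can later be called like a plain function. A minimal sketch of that pattern with a stand-in Flax module, not DotProductAttention:

import functools
import jax.numpy as jnp
from flax import linen as nn

class Scale(nn.Module):
  # Stand-in module with no learned parameters, mirroring the empty variables dict above.
  factor: float = 2.0

  @nn.compact
  def __call__(self, x):
    return x * self.factor

layer = Scale()
variables = {}                                   # nothing to bind for a parameter-free module
apply_fn = functools.partial(layer.apply, variables)
print(apply_fn(jnp.ones((2,))))                  # [2. 2.]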

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 2 additions & 1 deletion
@@ -29,10 +29,11 @@

 CACHE_T = 2
 try:
-  flax.config.update('flax_always_shard_variable', False)
+  flax.config.update("flax_always_shard_variable", False)
 except LookupError:
   pass

+
 # Helper to ensure kernel_size, stride, padding are tuples of 3 integers
 def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]:
   """Canonicalizes a value to a tuple of integers."""

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 2 additions & 1 deletion
@@ -118,7 +118,8 @@ def __call__(
     )
     # Set the TE shard_guard context_manager if using TE cudnn_flash attention
     if self.config.attention == "cudnn_flash_te":
-      from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error
+      from transformer_engine.jax.sharding import global_shard_guard, MeshResource  # pytype: disable=import-error
+
       shard_guard = global_shard_guard(MeshResource(cp_resource="context"))
     else:
       shard_guard = nullcontext()
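Both WAN pipelines and the trainer pick between Transformer Engine's global_shard_guard and a no-op context with the same branch. A minimal sketch of that selection as a helper; make_shard_guard and the bare attention argument are illustrative, not maxdiffusion APIs:

from contextlib import nullcontext

def make_shard_guard(attention: str):
  # Map TE's context-parallel resource onto the "context" mesh axis only when the
  # cudnn_flash_te kernel is in use; otherwise return a do-nothing context manager.
  if attention == "cudnn_flash_te":
    from transformer_engine.jax.sharding import global_shard_guard, MeshResource  # pytype: disable=import-error
    return global_shard_guard(MeshResource(cp_resource="context"))
  return nullcontext()

with make_shard_guard("ring"):
  pass  # non-TE kernels run without a TE shard guard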

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 2 additions & 1 deletion
@@ -140,7 +140,8 @@ def __call__(
     )
     # Set the TE shard_guard context_manager if using TE cudnn_flash attention
     if self.config.attention == "cudnn_flash_te":
-      from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error
+      from transformer_engine.jax.sharding import global_shard_guard, MeshResource  # pytype: disable=import-error
+
       shard_guard = global_shard_guard(MeshResource(cp_resource="context"))
     else:
       shard_guard = nullcontext()

src/maxdiffusion/pyconfig.py

Lines changed: 3 additions & 1 deletion
@@ -201,7 +201,9 @@ def user_init(raw_keys):
   raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
   # Verify qkv is sharded across sequence.
   if raw_keys["attention"] == "ring" or raw_keys["attention_sharding_uniform"]:
-    max_logging.log(f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set.")
+    max_logging.log(
+        f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set."
+    )
     logical_axis_rules = list(raw_keys["logical_axis_rules"])
     max_logging.log(f"Initial logical axis rules: {logical_axis_rules}")
     new_rules = []
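A hedged sketch of the kind of rule rewrite the log message describes: appending a sequence/context mesh axis to the q and kv logical axes when it is missing. The axis names below ("attn_q_seq", "attn_kv_seq") and the helper itself are illustrative only; the real rule names live in the surrounding pyconfig code, not shown in this diff.

def add_sequence_sharding(logical_axis_rules, context_axis="context"):
  # Append the context axis to selected logical axes unless it is already present.
  new_rules = []
  for name, mesh_axes in logical_axis_rules:
    axes = (mesh_axes,) if isinstance(mesh_axes, str) else tuple(mesh_axes)
    if name in ("attn_q_seq", "attn_kv_seq") and context_axis not in axes:
      axes = axes + (context_axis,)
    new_rules.append((name, axes))
  return tuple(new_rules)

rules = (("attn_q_seq", ("fsdp",)), ("embed", ("tensor",)))
print(add_sequence_sharding(rules))
# (('attn_q_seq', ('fsdp', 'context')), ('embed', ('tensor',)))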

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 1 addition & 3 deletions
@@ -125,9 +125,7 @@ def test_wan_time_text_embedding(self):

     encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
     dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
-    temb, timestep_proj, encoder_hidden_states, _, _ = layer(
-        dummy_timestep, dummy_encoder_hidden_states
-    )
+    temb, timestep_proj, encoder_hidden_states, _, _ = layer(dummy_timestep, dummy_encoder_hidden_states)
     assert temb.shape == (batch_size, dim)
     assert timestep_proj.shape == (batch_size, time_proj_dim)
     assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 6 additions & 4 deletions
@@ -309,7 +309,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
       pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60)
       max_logging.log(pretty_string)
       max_logging.log("------------------------------------------------")
-    if self.config.hardware != 'gpu':
+    if self.config.hardware != "gpu":
       max_utils.delete_pytree(params)
     data_shardings = self.get_data_shardings(mesh)
     eval_data_shardings = self.get_eval_data_shardings(mesh)

@@ -368,14 +368,16 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data

       # Designate the context parallel axis for sharding
       if self.config.attention == "cudnn_flash_te":
-        from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error
+        from transformer_engine.jax.sharding import global_shard_guard, MeshResource  # pytype: disable=import-error
+
        shard_guard = global_shard_guard(MeshResource(cp_resource="context"))
      else:
        shard_guard = nullcontext()

      next_batch_future = executor.submit(load_next_batch, train_data_iterator, example_batch, self.config)
-      with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, \
-          shard_guard, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      with jax.profiler.StepTraceAnnotation(
+          "train", step_num=step
+      ), pipeline.mesh, shard_guard, nn_partitioning.axis_rules(self.config.logical_axis_rules):
        state, scheduler_state, train_metric, rng = p_train_step(state, example_batch, rng, scheduler_state)
        train_metric["scalar"]["learning/loss"].block_until_ready()
        last_step_completion = datetime.datetime.now()
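The reformatted with statement stacks four context managers on one line. A hedged alternative sketch using contextlib.ExitStack; this is not what the trainer does, just an equivalent way to spell the same stack when the single with gets unwieldy (the arguments stand in for the trainer's locals):

import contextlib
import jax
from flax.linen import partitioning as nn_partitioning

def run_train_step(pipeline, shard_guard, config, p_train_step, state, batch, rng, scheduler_state, step):
  # Enter the same contexts as the diff's with statement, one per line.
  with contextlib.ExitStack() as stack:
    stack.enter_context(jax.profiler.StepTraceAnnotation("train", step_num=step))
    stack.enter_context(pipeline.mesh)
    stack.enter_context(shard_guard)
    stack.enter_context(nn_partitioning.axis_rules(config.logical_axis_rules))
    return p_train_step(state, batch, rng, scheduler_state)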
