Skip to content

Commit 50a029d

Browse files
lint
1 parent fc77dc0 commit 50a029d

5 files changed

Lines changed: 23 additions & 18 deletions

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ def run(config):
4040
prompt = [config.prompt] * batch_multiplier
4141
negative_prompt = [config.negative_prompt] * batch_multiplier
4242

43-
max_logging.log(f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}")
43+
max_logging.log(
44+
f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}"
45+
)
4446

4547
videos = pipeline(
4648
prompt=prompt,
@@ -91,6 +93,7 @@ def run(config):
9193
)
9294
print("generation time: ", (time.perf_counter() - s0))
9395

96+
9497
def main(argv: Sequence[str]) -> None:
9598
pyconfig.initialize(argv)
9699
run(pyconfig.config)

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -315,9 +315,7 @@ def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, t
315315
)
316316

317317
# 1. Self-attention
318-
norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype(
319-
hidden_states.dtype
320-
)
318+
norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype)
321319
attn_output = self.attn1(
322320
hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb
323321
)
@@ -329,13 +327,9 @@ def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, t
329327
hidden_states = hidden_states + attn_output
330328

331329
# 3. Feed-forward
332-
norm_hidden_states = (self.norm3(hidden_states) * (1 + c_scale_msa) + c_shift_msa).astype(
333-
hidden_states.dtype
334-
)
330+
norm_hidden_states = (self.norm3(hidden_states) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype)
335331
ff_output = self.ffn(norm_hidden_states)
336-
hidden_states = (hidden_states + ff_output * c_gate_msa).astype(
337-
hidden_states.dtype
338-
)
332+
hidden_states = (hidden_states + ff_output * c_gate_msa).astype(hidden_states.dtype)
339333
return hidden_states
340334

341335

src/maxdiffusion/tests/attention_test.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,16 @@ def setUp(self):
3737
def test_splash_attention(self):
3838
"""Test numerics of splash attention are equivalent to dot_product"""
3939

40-
pyconfig.initialize([None, os.path.join(THIS_DIR, "..", "configs", "base21.yml"),
41-
'flash_block_sizes={"block_q" : 512, "block_kv_compute": 512, "block_kv": 512,'
42-
'"block_q_dkv": 512, "block_kv_dkv": 512, "block_kv_dkv_compute": 512,'
43-
'"block_q_dq": 512, "block_kv_dq": 512}',], unittest=True)
40+
pyconfig.initialize(
41+
[
42+
None,
43+
os.path.join(THIS_DIR, "..", "configs", "base21.yml"),
44+
'flash_block_sizes={"block_q" : 512, "block_kv_compute": 512, "block_kv": 512,'
45+
'"block_q_dkv": 512, "block_kv_dkv": 512, "block_kv_dkv_compute": 512,'
46+
'"block_q_dq": 512, "block_kv_dq": 512}',
47+
],
48+
unittest=True,
49+
)
4450
config = pyconfig.config
4551

4652
batch = 8
@@ -57,7 +63,7 @@ def test_splash_attention(self):
5763
split_head_dim=True,
5864
attention_kernel="dot_product",
5965
mesh=None,
60-
dtype=jnp.bfloat16
66+
dtype=jnp.bfloat16,
6167
)
6268

6369
params = dot_product_attention.init(key2, x)["params"]
@@ -75,7 +81,7 @@ def test_splash_attention(self):
7581
attention_kernel="flash",
7682
mesh=mesh,
7783
dtype=jnp.bfloat16,
78-
flash_block_sizes=flash_block_sizes
84+
flash_block_sizes=flash_block_sizes,
7985
)
8086

8187
params = splash_attention.init(key2, x)["params"]

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def test_wan_model(self):
272272
timestep=dummy_timestep,
273273
encoder_hidden_states=dummy_encoder_hidden_states,
274274
is_uncond=jnp.array(True, dtype=jnp.bool_),
275-
slg_mask=jnp.zeros(40, dtype=jnp.bool_)
275+
slg_mask=jnp.zeros(40, dtype=jnp.bool_),
276276
)
277277
assert dummy_output.shape == hidden_states_shape
278278

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,9 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data):
116116
for step in np.arange(start_step, self.config.max_train_steps):
117117
if self.config.enable_profiler and step == first_profiling_step:
118118
max_utils.activate_profiler(self.config)
119-
with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
119+
with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules(
120+
self.config.logical_axis_rules
121+
):
120122
state, train_metric, rng = p_train_step(state, graphdef, data, rng)
121123

122124
new_time = datetime.datetime.now()

0 commit comments

Comments
 (0)