Skip to content

Commit 62c208f

Browse files
committed
import added in wan_checkpointer_test.py
1 parent 354ed44 commit 62c208f

2 files changed

Lines changed: 1 addition & 5 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,11 +1132,6 @@ def __call__(
11321132
encoder_attention_mask_img = jnp.ones((encoder_hidden_states_img.shape[0], padded_img_len), dtype=jnp.int32)
11331133
if image_seq_len_actual < padded_img_len:
11341134
encoder_attention_mask_img = encoder_attention_mask_img.at[:, image_seq_len_actual:].set(0)
1135-
1136-
# Extract image portion of attention mask (includes padded tokens)
1137-
# encoder_attention_mask_img = None
1138-
# if encoder_attention_mask is not None:
1139-
# encoder_attention_mask_img = encoder_attention_mask[:, :padded_img_len]
11401135
else:
11411136
# If no image_seq_len is specified, treat all as text
11421137
encoder_hidden_states_img = None

src/maxdiffusion/tests/wan_checkpointer_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from maxdiffusion.checkpointing.wan_checkpointer_2_2 import WanCheckpointer2_2
1818
from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p1 import WanCheckpointerI2V_2_1
1919
from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2 import WanCheckpointerI2V_2_2
20+
from maxdiffusion.pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1
2021

2122
class WanCheckpointer2_1Test(unittest.TestCase):
2223
"""Tests for WAN 2.1 checkpointer."""

0 commit comments

Comments
 (0)