Skip to content

Commit bca0d0e

Browse files
committed
fix lint issues
1 parent 049940e commit bca0d0e

5 files changed

Lines changed: 69 additions & 69 deletions

File tree

src/maxdiffusion/models/wan/transformers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@
1414
limitations under the License.
1515
"""
1616

17-
from .transformer_wan_animate import NNXWanAnimateTransformer3DModel
17+
from .transformer_wan_animate import WanAnimateTransformer3DModel

src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -819,8 +819,8 @@ def __call__(
819819
return hidden_states
820820

821821

822-
class NNXWanAnimateTransformer3DModel(nnx.Module, FlaxModelMixin, ConfigMixin):
823-
"""NNX Wan Animate transformer with pose and face conditioning."""
822+
class WanAnimateTransformer3DModel(nnx.Module, FlaxModelMixin, ConfigMixin):
823+
"""Wan Animate transformer with pose and face conditioning."""
824824

825825
@register_to_config
826826
def __init__(
@@ -1055,7 +1055,7 @@ def conditional_named_scope(self, name: str):
10551055
def init_weights(self, rng: jax.Array, eval_only: bool = False) -> Dict[str, Any]:
10561056
"""NNX modules initialize parameters eagerly during construction."""
10571057
del rng, eval_only
1058-
raise NotImplementedError("NNXWanAnimateTransformer3DModel initializes weights during construction.")
1058+
raise NotImplementedError("WanAnimateTransformer3DModel initializes weights during construction.")
10591059

10601060
def _apply_face_adapter(self, hidden_states: jax.Array, motion_vec: Optional[jax.Array], block_idx) -> jax.Array:
10611061
"""Inject face-conditioning latents at the configured adapter blocks."""

src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from maxdiffusion.max_utils import device_put_replicated, get_flash_block_sizes, get_precision
4747
from maxdiffusion.video_processor import VideoProcessor
4848

49-
from ...models.wan.transformers.transformer_wan_animate import NNXWanAnimateTransformer3DModel
49+
from ...models.wan.transformers.transformer_wan_animate import WanAnimateTransformer3DModel
5050
from ...models.wan.wan_utils import load_wan_animate_transformer
5151
from ...pyconfig import HyperParameters
5252
from .wan_pipeline import WanPipeline, cast_with_exclusion
@@ -59,22 +59,22 @@ def create_sharded_animate_transformer(
5959
config: HyperParameters,
6060
restored_checkpoint=None,
6161
subfolder: str = "transformer",
62-
) -> NNXWanAnimateTransformer3DModel:
63-
"""Creates a sharded NNXWanAnimateTransformer3DModel on device.
62+
) -> WanAnimateTransformer3DModel:
63+
"""Creates a sharded WanAnimateTransformer3DModel on device.
6464
6565
Follows the same pattern as create_sharded_logical_transformer in
66-
wan_pipeline.py but uses NNXWanAnimateTransformer3DModel and the
66+
wan_pipeline.py but uses WanAnimateTransformer3DModel and the
6767
animate-specific weight loader.
6868
"""
6969

7070
def _create_model(rngs: nnx.Rngs, wan_config: dict):
71-
return NNXWanAnimateTransformer3DModel(**wan_config, rngs=rngs)
71+
return WanAnimateTransformer3DModel(**wan_config, rngs=rngs)
7272

7373
# 1. Load config.
7474
if restored_checkpoint:
7575
wan_config = restored_checkpoint["wan_config"]
7676
else:
77-
wan_config = NNXWanAnimateTransformer3DModel.load_config(config.pretrained_model_name_or_path, subfolder=subfolder)
77+
wan_config = WanAnimateTransformer3DModel.load_config(config.pretrained_model_name_or_path, subfolder=subfolder)
7878

7979
wan_config["mesh"] = mesh
8080
wan_config["dtype"] = config.activations_dtype
@@ -215,15 +215,15 @@ class WanAnimatePipeline(WanPipeline):
215215
216216
Args:
217217
config: HyperParameters configuration.
218-
transformer: NNXWanAnimateTransformer3DModel instance (may be None for
218+
transformer: WanAnimateTransformer3DModel instance (may be None for
219219
VAE-only mode).
220220
**kwargs: Passed to WanPipeline.__init__ (tokenizer, text_encoder, vae, etc.)
221221
"""
222222

223223
def __init__(
224224
self,
225225
config: HyperParameters,
226-
transformer: Optional[NNXWanAnimateTransformer3DModel],
226+
transformer: Optional[WanAnimateTransformer3DModel],
227227
**kwargs,
228228
):
229229
super().__init__(config=config, **kwargs)
@@ -255,7 +255,7 @@ def load_animate_transformer(
255255
config: HyperParameters,
256256
restored_checkpoint=None,
257257
subfolder: str = "transformer",
258-
) -> NNXWanAnimateTransformer3DModel:
258+
) -> WanAnimateTransformer3DModel:
259259
with mesh:
260260
return create_sharded_animate_transformer(
261261
devices_array=devices_array,
@@ -273,7 +273,7 @@ def _load_and_init(
273273
restored_checkpoint=None,
274274
vae_only: bool = False,
275275
load_transformer: bool = True,
276-
) -> Tuple["WanAnimatePipeline", Optional[NNXWanAnimateTransformer3DModel]]:
276+
) -> Tuple["WanAnimatePipeline", Optional[WanAnimateTransformer3DModel]]:
277277
common_components = cls._create_common_components(config, vae_only)
278278
transformer = None
279279
if not vae_only and load_transformer:

src/maxdiffusion/tests/wan_animate_diffusers_parity_test.py

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -608,54 +608,54 @@ def test_mask_video_preprocessing_matches_diffusers(self):
608608

609609
def test_check_inputs_matches_diffusers_validation(self):
610610
invalid_calls = [
611-
dict(
612-
prompt="prompt",
613-
negative_prompt=None,
614-
image=PIL.Image.new("RGB", (16, 16)),
615-
pose_video=[PIL.Image.new("RGB", (16, 16))],
616-
face_video=[PIL.Image.new("RGB", (16, 16))],
617-
background_video=None,
618-
mask_video=None,
619-
height=16,
620-
width=16,
621-
prompt_embeds=jnp.zeros((1, 1, 1)),
622-
negative_prompt_embeds=None,
623-
image_embeds=None,
624-
mode="animate",
625-
prev_segment_conditioning_frames=1,
626-
),
627-
dict(
628-
prompt="prompt",
629-
negative_prompt=None,
630-
image=PIL.Image.new("RGB", (16, 16)),
631-
pose_video=[PIL.Image.new("RGB", (16, 16))],
632-
face_video=[PIL.Image.new("RGB", (16, 16))],
633-
background_video=None,
634-
mask_video=None,
635-
height=18,
636-
width=16,
637-
prompt_embeds=None,
638-
negative_prompt_embeds=None,
639-
image_embeds=None,
640-
mode="animate",
641-
prev_segment_conditioning_frames=1,
642-
),
643-
dict(
644-
prompt="prompt",
645-
negative_prompt=None,
646-
image=PIL.Image.new("RGB", (16, 16)),
647-
pose_video=[PIL.Image.new("RGB", (16, 16))],
648-
face_video=[PIL.Image.new("RGB", (16, 16))],
649-
background_video=None,
650-
mask_video=None,
651-
height=16,
652-
width=16,
653-
prompt_embeds=None,
654-
negative_prompt_embeds=None,
655-
image_embeds=None,
656-
mode="replace",
657-
prev_segment_conditioning_frames=3,
658-
),
611+
{
612+
"prompt": "prompt",
613+
"negative_prompt": None,
614+
"image": PIL.Image.new("RGB", (16, 16)),
615+
"pose_video": [PIL.Image.new("RGB", (16, 16))],
616+
"face_video": [PIL.Image.new("RGB", (16, 16))],
617+
"background_video": None,
618+
"mask_video": None,
619+
"height": 16,
620+
"width": 16,
621+
"prompt_embeds": jnp.zeros((1, 1, 1)),
622+
"negative_prompt_embeds": None,
623+
"image_embeds": None,
624+
"mode": "animate",
625+
"prev_segment_conditioning_frames": 1,
626+
},
627+
{
628+
"prompt": "prompt",
629+
"negative_prompt": None,
630+
"image": PIL.Image.new("RGB", (16, 16)),
631+
"pose_video": [PIL.Image.new("RGB", (16, 16))],
632+
"face_video": [PIL.Image.new("RGB", (16, 16))],
633+
"background_video": None,
634+
"mask_video": None,
635+
"height": 18,
636+
"width": 16,
637+
"prompt_embeds": None,
638+
"negative_prompt_embeds": None,
639+
"image_embeds": None,
640+
"mode": "animate",
641+
"prev_segment_conditioning_frames": 1,
642+
},
643+
{
644+
"prompt": "prompt",
645+
"negative_prompt": None,
646+
"image": PIL.Image.new("RGB", (16, 16)),
647+
"pose_video": [PIL.Image.new("RGB", (16, 16))],
648+
"face_video": [PIL.Image.new("RGB", (16, 16))],
649+
"background_video": None,
650+
"mask_video": None,
651+
"height": 16,
652+
"width": 16,
653+
"prompt_embeds": None,
654+
"negative_prompt_embeds": None,
655+
"image_embeds": None,
656+
"mode": "replace",
657+
"prev_segment_conditioning_frames": 3,
658+
},
659659
]
660660

661661
for kwargs in invalid_calls:
@@ -780,7 +780,7 @@ def _scalar(x):
780780
hf_negative = torch.tensor(to_numpy(max_negative))
781781
hf_image = torch.tensor(to_numpy(max_image))
782782

783-
scheduler_config = dict(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=5.0)
783+
scheduler_config = {"prediction_type": "flow_prediction", "use_flow_sigmas": True, "flow_shift": 5.0}
784784
max_scheduler = FlaxUniPCMultistepScheduler(**scheduler_config)
785785
max_state = max_scheduler.create_state()
786786
max_state = max_scheduler.set_timesteps(max_state, num_inference_steps=timestep_count, shape=max_latents.shape)
@@ -852,7 +852,7 @@ def _scalar(x):
852852
np.testing.assert_allclose(to_numpy(max_next), hf_channel_first_to_last(hf_next), atol=1e-5, rtol=1e-5)
853853

854854
def test_flax_unipc_flow_sigmas_match_diffusers(self):
855-
scheduler_config = dict(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=5.0)
855+
scheduler_config = {"prediction_type": "flow_prediction", "use_flow_sigmas": True, "flow_shift": 5.0}
856856

857857
max_scheduler = FlaxUniPCMultistepScheduler(**scheduler_config)
858858
max_state = max_scheduler.create_state()

src/maxdiffusion/tests/wan_animate_module_parity_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
MotionConv2d,
5252
MotionEncoderResBlock,
5353
MotionLinear,
54-
NNXWanAnimateTransformer3DModel,
54+
WanAnimateTransformer3DModel,
5555
WanAnimateFaceBlockCrossAttention,
5656
WanAnimateFaceEncoder,
5757
WanAnimateMotionEncoder,
@@ -320,7 +320,7 @@ def test_wan_animate_transformer_weight_mapping_covers_all_local_params(self):
320320
hf_model = HFWanAnimateTransformer3DModel(**cfg).eval()
321321

322322
with self.mesh, nn_partitioning.axis_rules(self.logical_axis_rules):
323-
max_model = NNXWanAnimateTransformer3DModel(rngs=self.rngs, scan_layers=False, mesh=self.mesh, **cfg)
323+
max_model = WanAnimateTransformer3DModel(rngs=self.rngs, scan_layers=False, mesh=self.mesh, **cfg)
324324
missing_keys, flax_state_dict = map_hf_wan_animate_state_to_local(
325325
max_model, hf_model, num_layers=cfg["num_layers"], scan_layers=False
326326
)
@@ -359,7 +359,7 @@ def test_wan_animate_transformer_weight_mapping_covers_all_local_params_scanned(
359359
hf_model = HFWanAnimateTransformer3DModel(**cfg).eval()
360360

361361
with self.mesh, nn_partitioning.axis_rules(self.logical_axis_rules):
362-
max_model = NNXWanAnimateTransformer3DModel(rngs=self.rngs, scan_layers=True, mesh=self.mesh, **cfg)
362+
max_model = WanAnimateTransformer3DModel(rngs=self.rngs, scan_layers=True, mesh=self.mesh, **cfg)
363363
missing_keys, flax_state_dict = map_hf_wan_animate_state_to_local(
364364
max_model, hf_model, num_layers=cfg["num_layers"], scan_layers=True
365365
)
@@ -432,7 +432,7 @@ def test_wan_animate_transformer_forward_parity(self):
432432
hf_model = HFWanAnimateTransformer3DModel(**cfg).eval()
433433

434434
with self.mesh, nn_partitioning.axis_rules(self.logical_axis_rules):
435-
max_model = NNXWanAnimateTransformer3DModel(rngs=self.rngs, scan_layers=False, mesh=self.mesh, **cfg)
435+
max_model = WanAnimateTransformer3DModel(rngs=self.rngs, scan_layers=False, mesh=self.mesh, **cfg)
436436
missing_keys, _ = map_hf_wan_animate_state_to_local(
437437
max_model, hf_model, num_layers=cfg["num_layers"], scan_layers=False
438438
)
@@ -497,7 +497,7 @@ def test_wan_animate_transformer_forward_parity_scanned(self):
497497
hf_model = HFWanAnimateTransformer3DModel(**cfg).eval()
498498

499499
with self.mesh, nn_partitioning.axis_rules(self.logical_axis_rules):
500-
max_model = NNXWanAnimateTransformer3DModel(rngs=self.rngs, scan_layers=True, mesh=self.mesh, **cfg)
500+
max_model = WanAnimateTransformer3DModel(rngs=self.rngs, scan_layers=True, mesh=self.mesh, **cfg)
501501
missing_keys, _ = map_hf_wan_animate_state_to_local(
502502
max_model, hf_model, num_layers=cfg["num_layers"], scan_layers=True
503503
)

0 commit comments

Comments (0)