Skip to content

Commit 00a19e2

Browse files
committed
reformatted
1 parent 7e92510 commit 00a19e2

5 files changed

Lines changed: 131 additions & 127 deletions

File tree

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,7 @@ def scan_fn(carry, block):
650650

651651
rematted_block_forward = self.gradient_checkpoint.apply(
652652
scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
653-
)
653+
)
654654
initial_carry = (h, rngs)
655655
final_carry, _ = nnx.scan(
656656
rematted_block_forward,
@@ -676,7 +676,10 @@ def layer_forward(hidden_states):
676676
)
677677

678678
rematted_layer_forward = self.gradient_checkpoint.apply(
679-
layer_forward, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
679+
layer_forward,
680+
self.names_which_can_be_saved,
681+
self.names_which_can_be_offloaded,
682+
prevent_cse=not self.scan_layers,
680683
)
681684
h_out = rematted_layer_forward(h_out)
682685
return h_out
@@ -702,7 +705,7 @@ def layer_forward(hidden_states):
702705
)
703706
hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6))
704707
hidden_states = hidden_states.reshape(batch_size, -1, num_frames, height, width)
705-
708+
706709
if return_residual:
707710
return hidden_states, residual_x
708711
return hidden_states

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 120 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -775,7 +775,7 @@ def transformer_forward_pass(
775775
cached_residual=cached_residual,
776776
return_residual=return_residual,
777777
)
778-
778+
779779
if return_residual:
780780
noise_pred, residual_x = outputs
781781
else:
@@ -899,56 +899,61 @@ def transformer_forward_pass_cfg_cache(
899899
noise_pred_merged = noise_uncond_approx + guidance_scale * (noise_cond - noise_uncond_approx)
900900
return noise_pred_merged, noise_cond
901901

902+
902903
def nearest_interp(src, target_len):
903-
"""Nearest neighbor interpolation for ratio scaling layout."""
904-
src_len = len(src)
905-
if target_len == 1:
906-
import numpy as np
907-
return np.array([src[-1]])
904+
"""Nearest neighbor interpolation for ratio scaling layout."""
905+
src_len = len(src)
906+
if target_len == 1:
908907
import numpy as np
909-
indices = np.round(np.linspace(0, src_len - 1, target_len)).astype(np.int32)
910-
return src[indices]
908+
909+
return np.array([src[-1]])
910+
import numpy as np
911+
912+
indices = np.round(np.linspace(0, src_len - 1, target_len)).astype(np.int32)
913+
return src[indices]
914+
911915

912916
def init_magcache(num_inference_steps, retention_ratio, mag_ratios_base):
  """Initialize MagCache variables and interpolate ratios.

  Args:
    num_inference_steps: Number of inference steps.
    retention_ratio: Retention ratio of unchanged steps.
    mag_ratios_base: Base magnitude ratios array or list.

  Returns:
    A 9-tuple ``(accumulated_ratio_cond, accumulated_ratio_uncond,
    accumulated_err_cond, accumulated_err_uncond, accumulated_steps_cond,
    accumulated_steps_uncond, cached_residual, skip_warmup, mag_ratios)``
    where the first six are freshly reset accumulator values, the residual
    cache starts empty, and ``mag_ratios`` is interleaved (cond, uncond)
    per step.
  """
  import numpy as np

  # Leading steps during which caching is never applied.
  skip_warmup = int(num_inference_steps * retention_ratio)

  ratios = np.array(mag_ratios_base)
  if len(ratios) == num_inference_steps * 2:
    # Already one (cond, uncond) pair per inference step — use as-is.
    mag_ratios = ratios
  else:
    # Resample the conditional (even-index) and unconditional (odd-index)
    # streams independently, then re-interleave them pairwise.
    cond = nearest_interp(ratios[0::2], num_inference_steps)
    uncond = nearest_interp(ratios[1::2], num_inference_steps)
    mag_ratios = np.concatenate([cond.reshape(-1, 1), uncond.reshape(-1, 1)], axis=1).reshape(-1)

  # Accumulated ratios start at the multiplicative identity, errors and step
  # counters at zero, and no residual is cached yet.
  return (1.0, 1.0, 0.0, 0.0, 0, 0, None, skip_warmup, mag_ratios)
952957

953958
def magcache_step(
954959
step,
@@ -959,71 +964,71 @@ def magcache_step(
959964
skip_warmup=0,
960965
use_magcache=None,
961966
):
962-
"""Update MagCache accumulated state and decide if to skip.
963-
964-
Args:
965-
step: Current inference step.
966-
mag_ratios: Interpolated magnitude ratios array.
967-
accumulated_state: Tuple containing accumulated variables.
968-
magcache_thresh: Error threshold.
969-
magcache_K: Max skip steps.
970-
skip_warmup: Warmup steps threshold.
971-
use_magcache: Optional manual override boolean to enable/disable cache for this step.
972-
"""
973-
import numpy as np
974-
975-
(
976-
accumulated_ratio_cond,
977-
accumulated_ratio_uncond,
978-
accumulated_err_cond,
979-
accumulated_err_uncond,
980-
accumulated_steps_cond,
981-
accumulated_steps_uncond,
982-
) = accumulated_state
983-
984-
cur_mag_ratio_cond = mag_ratios[step * 2]
985-
cur_mag_ratio_uncond = mag_ratios[step * 2 + 1]
986-
987-
if use_magcache is None:
988-
use_magcache = True
989-
if step < skip_warmup:
990-
use_magcache = False
991-
992-
skip_blocks = False
993-
if use_magcache:
994-
new_ratio_cond = accumulated_ratio_cond * cur_mag_ratio_cond
995-
new_ratio_uncond = accumulated_ratio_uncond * cur_mag_ratio_uncond
996-
997-
err_cond = np.abs(1.0 - new_ratio_cond)
998-
err_uncond = np.abs(1.0 - new_ratio_uncond)
999-
1000-
if (
1001-
accumulated_err_cond + err_cond < magcache_thresh
1002-
and accumulated_steps_cond < magcache_K
1003-
and accumulated_err_uncond + err_uncond < magcache_thresh
1004-
and accumulated_steps_uncond < magcache_K
1005-
):
1006-
skip_blocks = True
1007-
accumulated_ratio_cond = new_ratio_cond
1008-
accumulated_ratio_uncond = new_ratio_uncond
1009-
accumulated_err_cond += err_cond
1010-
accumulated_err_uncond += err_uncond
1011-
accumulated_steps_cond += 1
1012-
accumulated_steps_uncond += 1
1013-
else:
1014-
accumulated_ratio_cond = 1.0
1015-
accumulated_ratio_uncond = 1.0
1016-
accumulated_err_cond = 0.0
1017-
accumulated_err_uncond = 0.0
1018-
accumulated_steps_cond = 0
1019-
accumulated_steps_uncond = 0
1020-
1021-
new_state = (
1022-
accumulated_ratio_cond,
1023-
accumulated_ratio_uncond,
1024-
accumulated_err_cond,
1025-
accumulated_err_uncond,
1026-
accumulated_steps_cond,
1027-
accumulated_steps_uncond,
1028-
)
1029-
return skip_blocks, new_state
967+
"""Update MagCache accumulated state and decide if to skip.
968+
969+
Args:
970+
step: Current inference step.
971+
mag_ratios: Interpolated magnitude ratios array.
972+
accumulated_state: Tuple containing accumulated variables.
973+
magcache_thresh: Error threshold.
974+
magcache_K: Max skip steps.
975+
skip_warmup: Warmup steps threshold.
976+
use_magcache: Optional manual override boolean to enable/disable cache for this step.
977+
"""
978+
import numpy as np
979+
980+
(
981+
accumulated_ratio_cond,
982+
accumulated_ratio_uncond,
983+
accumulated_err_cond,
984+
accumulated_err_uncond,
985+
accumulated_steps_cond,
986+
accumulated_steps_uncond,
987+
) = accumulated_state
988+
989+
cur_mag_ratio_cond = mag_ratios[step * 2]
990+
cur_mag_ratio_uncond = mag_ratios[step * 2 + 1]
991+
992+
if use_magcache is None:
993+
use_magcache = True
994+
if step < skip_warmup:
995+
use_magcache = False
996+
997+
skip_blocks = False
998+
if use_magcache:
999+
new_ratio_cond = accumulated_ratio_cond * cur_mag_ratio_cond
1000+
new_ratio_uncond = accumulated_ratio_uncond * cur_mag_ratio_uncond
1001+
1002+
err_cond = np.abs(1.0 - new_ratio_cond)
1003+
err_uncond = np.abs(1.0 - new_ratio_uncond)
1004+
1005+
if (
1006+
accumulated_err_cond + err_cond < magcache_thresh
1007+
and accumulated_steps_cond < magcache_K
1008+
and accumulated_err_uncond + err_uncond < magcache_thresh
1009+
and accumulated_steps_uncond < magcache_K
1010+
):
1011+
skip_blocks = True
1012+
accumulated_ratio_cond = new_ratio_cond
1013+
accumulated_ratio_uncond = new_ratio_uncond
1014+
accumulated_err_cond += err_cond
1015+
accumulated_err_uncond += err_uncond
1016+
accumulated_steps_cond += 1
1017+
accumulated_steps_uncond += 1
1018+
else:
1019+
accumulated_ratio_cond = 1.0
1020+
accumulated_ratio_uncond = 1.0
1021+
accumulated_err_cond = 0.0
1022+
accumulated_err_uncond = 0.0
1023+
accumulated_steps_cond = 0
1024+
accumulated_steps_uncond = 0
1025+
1026+
new_state = (
1027+
accumulated_ratio_cond,
1028+
accumulated_ratio_uncond,
1029+
accumulated_err_cond,
1030+
accumulated_err_uncond,
1031+
accumulated_steps_cond,
1032+
accumulated_steps_uncond,
1033+
)
1034+
return skip_blocks, new_state

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache, nearest_interp, init_magcache, magcache_step
15+
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache, init_magcache, magcache_step
1616
from ...models.wan.transformers.transformer_wan import WanModel
17-
from typing import List, Union, Optional, Any
17+
from typing import List, Union, Optional
1818
from ...pyconfig import HyperParameters
1919
from functools import partial
2020
from flax import nnx
2121
from flax.linen import partitioning as nn_partitioning
2222
import jax
2323
import jax.numpy as jnp
24-
import numpy as np
2524
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler
2625

2726

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,12 @@
1414

1515
from maxdiffusion import max_logging
1616
from maxdiffusion.image_processor import PipelineImageInput
17-
from .wan_pipeline import WanPipeline, transformer_forward_pass, nearest_interp, init_magcache, magcache_step
17+
from .wan_pipeline import WanPipeline, transformer_forward_pass, init_magcache, magcache_step
1818
from ...models.wan.transformers.transformer_wan import WanModel
1919
from typing import List, Union, Optional, Tuple
2020
from ...pyconfig import HyperParameters
2121
from functools import partial
2222
from flax import nnx
23-
import numpy as np
2423
from flax.linen import partitioning as nn_partitioning
2524
import jax
2625
import jax.numpy as jnp
@@ -315,7 +314,7 @@ def run_inference_2_1_i2v(
315314

316315
for step in range(num_inference_steps):
317316
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
318-
317+
319318
skip_blocks = False
320319
if use_magcache and do_cfg:
321320
accumulated_state = (
@@ -345,7 +344,7 @@ def run_inference_2_1_i2v(
345344
latent_model_input = jnp.concatenate([latents_input, condition_combined], axis=-1)
346345
timestep = jnp.broadcast_to(t, latents_input.shape[0])
347346
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
348-
347+
349348
outputs = transformer_forward_pass(
350349
graphdef,
351350
sharded_state,

src/maxdiffusion/tests/wan_magcache_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ class Wan21I2VMagCacheSmokeTest(unittest.TestCase):
152152

153153
@classmethod
154154
def setUpClass(cls):
155-
156155
pyconfig.initialize(
157156
[
158157
None,
@@ -224,4 +223,3 @@ def test_magcache_speedup_and_fidelity(self):
224223
self.assertGreaterEqual(ssim, 0.98)
225224
self.assertGreater(speedup, 1.0)
226225
self.assertGreaterEqual(psnr, 30.0)
227-

0 commit comments

Comments
 (0)