|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -from abc import ABC, abstractmethod |
| 15 | +from abc import abstractmethod |
16 | 16 | from typing import List, Union, Optional, Type |
17 | 17 | from functools import partial |
18 | 18 | import numpy as np |
@@ -466,7 +466,7 @@ def _denormalize_latents(self, latents: jax.Array) -> jax.Array: |
466 | 466 | latents = latents / latents_std + latents_mean |
467 | 467 | latents = latents.astype(jnp.float32) |
468 | 468 | return latents |
469 | | - |
| 469 | + |
470 | 470 | def _decode_latents_to_video(self, latents: jax.Array) -> np.ndarray: |
471 | 471 | """Decodes latents to video frames and postprocesses.""" |
472 | 472 | with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): |
@@ -508,7 +508,7 @@ def _get_subclass(cls, model_key: str) -> Type['WanPipeline']: |
508 | 508 | f"Supported keys are: {list(cls._SUBCLASS_MAP.keys())}" |
509 | 509 | ) |
510 | 510 | return subclass |
511 | | - |
| 511 | + |
512 | 512 | @classmethod |
513 | 513 | def from_checkpoint(cls, model_key: str, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): |
514 | 514 | subclass = cls._get_subclass(model_key) |
@@ -708,7 +708,6 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t |
708 | 708 | common_components = cls._create_common_components(config, vae_only) |
709 | 709 | low_noise_transformer, high_noise_transformer = None, None |
710 | 710 | if not vae_only and load_transformer: |
711 | | - rngs = nnx.Rngs(jax.random.key(config.seed)) |
712 | 711 | low_noise_transformer = super().load_transformer( |
713 | 712 | devices_array=common_components["devices_array"], |
714 | 713 | mesh=common_components["mesh"], |
|
0 commit comments