Added support for WAN 2.2 model #281
0837512 to 56b1761
9c812c4 to 11d30fc
entrpn
left a comment
Can you also make sure Wan2.1 training and inference still run without issues before this PR goes into main?
Sample inference logs:
Sample training logs:
preview-xpk.sh logs: https://console.cloud.google.com/kubernetes/service/us-central1/bodaborg-tpu7x-128/default/prishajain-wan-v7x-20k-30315/details?project=cloud-tpu-multipod-dev
This reverts commit d6cdb1e.
Is there a reason why training logs don't show training steps?
By default, log_period = 100, and we ran this for fewer training steps than that, so no training steps were logged.
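To illustrate the point about log_period, here is a minimal sketch of a periodic logger. It assumes the training loop logs whenever the 1-based step count is a multiple of log_period. The helper name `logged_steps` is hypothetical and the actual logging logic in the trainer may differ.

```python
def logged_steps(num_steps: int, log_period: int = 100) -> list[int]:
    """Return the 1-based steps at which a periodic logger would emit a line.

    Assumption: a line is logged when step % log_period == 0.
    """
    return [step for step in range(1, num_steps + 1) if step % log_period == 0]


if __name__ == "__main__":
    # A run shorter than log_period never reaches a logging step.
    print(logged_steps(50))
    # A longer run logs once per full period.
    print(logged_steps(250))
```

Under this assumption, a short validation run needs a smaller log_period for step logs to appear.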
coolkp
left a comment
What does validate training equal to true do? Should it always default to true?
We use this to validate that model_name = wan2.1 for training only, since we needed to move the model name validation in pyconfig.py.
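A rough sketch of what such a check could look like, assuming a boolean flag that is only set on the training path. The function name `validate_model_name`, the flag name `validate_training`, and the "wan2.2" string are assumptions for illustration, not the actual code in pyconfig.py.

```python
# Assumed model_name values; only "wan2.1" is confirmed in this thread.
SUPPORTED_MODEL_NAMES = ("wan2.1", "wan2.2")
TRAINABLE_MODEL_NAMES = ("wan2.1",)


def validate_model_name(model_name: str, validate_training: bool = False) -> None:
    """Raise ValueError if model_name is not allowed for the requested mode."""
    if model_name not in SUPPORTED_MODEL_NAMES:
        raise ValueError(f"Unsupported model_name: {model_name!r}")
    # The stricter check applies only when the caller is a training entry point.
    if validate_training and model_name not in TRAINABLE_MODEL_NAMES:
        raise ValueError(f"Training is not supported for model_name: {model_name!r}")


if __name__ == "__main__":
    validate_model_name("wan2.2")  # inference path: accepted
    validate_model_name("wan2.1", validate_training=True)  # training path: accepted
```

With this shape, defaulting the flag to false keeps inference open to both models while training stays restricted.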
Added support for WAN 2.2 model
This PR introduces support for the WAN 2.2 model and refactors the existing pipeline to dynamically handle multiple model versions.
New Model Support (WAN 2.2)
The configuration for the WAN 2.2 model is added:
base_wan_27b.yml: New configuration file for the WAN 2.2 (27B) model.
Refactoring for Multi-Model Support
To allow the system to select between WAN 2.1 and 2.2, several core files were modified:
generate_wan.py: Calls the correct pipeline (WAN 2.1 / WAN 2.2) depending on the model_name parameter in config files.
wan_utils.py: Updated all transformer loading functions to include a subfolder parameter.
base_wan_14b.yml: Added the model_name parameter to conform to the new configuration standard.
wan_pipeline.py: The processing pipeline which supports both WAN 2.1 and WAN 2.2 by using separate base classes.
wan_checkpointer.py: Handles checkpoint loading/saving for both WAN 2.1 and WAN 2.2.
wan_checkpointer_test.py: Tests for WAN 2.1 and WAN 2.2 checkpointer.

We tested the above pipeline for both WAN 2.1 and WAN 2.2.
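The model_name dispatch described for generate_wan.py could look roughly like the sketch below. The class names `Wan21Pipeline` and `Wan22Pipeline`, the `Config` dataclass, and the "wan2.2" string are placeholders chosen for illustration; the real classes in wan_pipeline.py and the real config object may be named and structured differently.

```python
from dataclasses import dataclass


@dataclass
class Config:
    """Hypothetical stand-in for the parsed YAML config."""
    model_name: str


class Wan21Pipeline:
    """Placeholder for the WAN 2.1 pipeline."""

    def __init__(self, config: Config):
        self.config = config


class Wan22Pipeline:
    """Placeholder for the WAN 2.2 pipeline."""

    def __init__(self, config: Config):
        self.config = config


# Assumed model_name values mapped to their pipeline classes.
PIPELINES = {"wan2.1": Wan21Pipeline, "wan2.2": Wan22Pipeline}


def select_pipeline(config: Config):
    """Instantiate the pipeline that matches config.model_name."""
    try:
        pipeline_cls = PIPELINES[config.model_name]
    except KeyError:
        raise ValueError(f"Unsupported model_name: {config.model_name!r}") from None
    return pipeline_cls(config)


if __name__ == "__main__":
    pipeline = select_pipeline(Config(model_name="wan2.2"))
    print(type(pipeline).__name__)
```

A lookup table keeps the entry point free of per-model branches, so supporting another version would only need a new entry.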