```
├── .gitignore
├── INSTALL.md
├── LICENSE.txt
├── README.md
├── assets/
│   ├── comp_effic.png
│   ├── data_for_diff_stage.jpg
│   ├── i2v_res.png
│   ├── logo.png
│   ├── t2v_res.jpg
│   ├── vben_vs_sota.png
│   ├── video_dit_arch.jpg
│   └── video_vae_res.jpg
├── examples/
│   ├── flf2v_input_first_frame.png
│   ├── flf2v_input_last_frame.png
│   └── i2v_input.JPG
├── generate.py
├── gradio/
│   ├── fl2v_14B_singleGPU.py
│   ├── i2v_14B_singleGPU.py
│   ├── t2i_14B_singleGPU.py
│   ├── t2v_1.3B_singleGPU.py
│   └── t2v_14B_singleGPU.py
├── pyproject.toml
├── requirements.txt
├── tests/
│   ├── README.md
│   └── test.sh
└── wan/
    ├── __init__.py
    ├── configs/
    │   ├── __init__.py
    │   ├── shared_config.py
    │   ├── wan_i2v_14B.py
    │   ├── wan_t2v_14B.py
    │   └── wan_t2v_1_3B.py
    ├── distributed/
    │   ├── __init__.py
    │   ├── fsdp.py
    │   └── xdit_context_parallel.py
    ├── first_last_frame2video.py
    ├── image2video.py
    ├── modules/
    │   ├── __init__.py
    │   ├── attention.py
    │   ├── clip.py
    │   ├── model.py
    │   ├── t5.py
    │   ├── tokenizers.py
    │   ├── vae.py
    │   └── xlm_roberta.py
    └── text2video.py
```

## /.gitignore

```gitignore path="/.gitignore"
.*
*.py[cod]
# *.jpg
*.jpeg
# *.png
*.gif
*.bmp
*.mp4
*.mov
*.mkv
*.log
*.zip
*.pt
*.pth
*.ckpt
*.safetensors
*.json
# *.txt
*.backup
*.pkl
*.html
*.pdf
*.whl
cache
__pycache__/
storage/
samples/
!.gitignore
!requirements.txt
.DS_Store
*DS_Store
google/
Wan2.1-T2V-14B/
Wan2.1-T2V-1.3B/
Wan2.1-I2V-14B-480P/
Wan2.1-I2V-14B-720P/
poetry.lock
```

## /INSTALL.md

# Installation Guide

## Install with pip

```bash
pip install .
pip install ".[dev]"  # Also installs the dev tools
```

## Install with Poetry

Ensure you have [Poetry](https://python-poetry.org/docs/#installation) installed on your system. To install all dependencies:

```bash
poetry install
```

### Handling `flash-attn` Installation Issues

If installing `flash-attn` fails due to **PEP 517 build issues**, try one of the following fixes.
#### No-Build-Isolation Installation (Recommended)

```bash
poetry run pip install --upgrade pip setuptools wheel
poetry run pip install flash-attn --no-build-isolation
poetry install
```

#### Install from Git (Alternative)

```bash
poetry run pip install git+https://github.com/Dao-AILab/flash-attention.git
```

---

### Running the Model

Once the installation is complete, you can run **Wan2.1** with:

```bash
poetry run python generate.py --task t2v-14B --size '1280*720' --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```

#### Test

```bash
pytest tests/
```

#### Format

```bash
black .
isort .
```

## /LICENSE.txt

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution.
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability.
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

## /README.md

# Wan2.1

💜 Wan | 🖥️ GitHub | 🤗 Hugging Face | 🤖 ModelScope | 📑 Technical Report | 📑 Blog | 💬 WeChat Group | 📖 Discord
-----

[**Wan: Open and Advanced Large-Scale Video Generative Models**](https://arxiv.org/abs/2503.20314)

In this repository, we present **Wan2.1**, a comprehensive and open suite of video foundation models that pushes the boundaries of video generation. **Wan2.1** offers these key features:

- 👍 **SOTA Performance**: **Wan2.1** consistently outperforms existing open-source models and state-of-the-art commercial solutions across multiple benchmarks.
- 👍 **Supports Consumer-grade GPUs**: The T2V-1.3B model requires only 8.19 GB VRAM, making it compatible with almost all consumer-grade GPUs. It can generate a 5-second 480P video on an RTX 4090 in about 4 minutes (without optimization techniques like quantization). Its performance is even comparable to some closed-source models.
- 👍 **Multiple Tasks**: **Wan2.1** excels in Text-to-Video, Image-to-Video, Video Editing, Text-to-Image, and Video-to-Audio, advancing the field of video generation.
- 👍 **Visual Text Generation**: **Wan2.1** is the first video model capable of generating both Chinese and English text, featuring robust text generation that enhances its practical applications.
- 👍 **Powerful Video VAE**: **Wan-VAE** delivers exceptional efficiency and performance, encoding and decoding 1080P videos of any length while preserving temporal information, making it an ideal foundation for video and image generation.

## Video Demos

## 🔥 Latest News!!

* Apr 17, 2025: 👋 We introduce **Wan2.1** [FLF2V](#run-first-last-frame-to-video-generation) with its inference code and weights!
* Mar 21, 2025: 👋 We are excited to announce the release of the **Wan2.1** [technical report](https://files.alicdn.com/tpsservice/5c9de1c74de03972b7aa657e5a54756b.pdf). We welcome discussions and feedback!
* Mar 3, 2025: 👋 **Wan2.1**'s T2V and I2V have been integrated into Diffusers ([T2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanPipeline) | [I2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanImageToVideoPipeline)). Feel free to give it a try!
* Feb 27, 2025: 👋 **Wan2.1** has been integrated into [ComfyUI](https://comfyanonymous.github.io/ComfyUI_examples/wan/). Enjoy!
* Feb 25, 2025: 👋 We've released the inference code and weights of **Wan2.1**.

## Community Works

If your work has improved **Wan2.1** and you would like more people to see it, please let us know.

- [Phantom](https://github.com/Phantom-video/Phantom) has developed a unified video generation framework for single- and multi-subject references based on **Wan2.1-T2V-1.3B**. Please refer to [their examples](https://github.com/Phantom-video/Phantom).
- [UniAnimate-DiT](https://github.com/ali-vilab/UniAnimate-DiT), based on **Wan2.1-14B-I2V**, has trained a human image animation model and open-sourced its inference and training code. Feel free to try it!
- [CFG-Zero](https://github.com/WeichenFan/CFG-Zero-star) enhances **Wan2.1** (covering both T2V and I2V models) from the perspective of CFG.
- [TeaCache](https://github.com/ali-vilab/TeaCache) now supports **Wan2.1** acceleration, increasing speed by approximately 2x. Feel free to give it a try!
- [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) provides more support for **Wan2.1**, including video-to-video, FP8 quantization, VRAM optimization, LoRA training, and more.
  Please refer to [their examples](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo).

## 📑 Todo List

- Wan2.1 Text-to-Video
    - [x] Multi-GPU Inference code of the 14B and 1.3B models
    - [x] Checkpoints of the 14B and 1.3B models
    - [x] Gradio demo
    - [x] ComfyUI integration
    - [x] Diffusers integration
    - [ ] Diffusers + Multi-GPU Inference
- Wan2.1 Image-to-Video
    - [x] Multi-GPU Inference code of the 14B model
    - [x] Checkpoints of the 14B model
    - [x] Gradio demo
    - [x] ComfyUI integration
    - [x] Diffusers integration
    - [ ] Diffusers + Multi-GPU Inference
- Wan2.1 First-Last-Frame-to-Video
    - [x] Multi-GPU Inference code of the 14B model
    - [x] Checkpoints of the 14B model
    - [x] Gradio demo
    - [ ] ComfyUI integration
    - [ ] Diffusers integration
    - [ ] Diffusers + Multi-GPU Inference

## Quickstart

#### Installation

Clone the repo:

```sh
git clone https://github.com/Wan-Video/Wan2.1.git
cd Wan2.1
```

Install dependencies:

```sh
# Ensure torch >= 2.4.0
pip install -r requirements.txt
```

#### Model Download

| Models | Download Link | Notes |
|--------------|-----------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------|
| T2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P |
| I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P |
| I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P |
| T2V-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P |
| FLF2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | Supports 720P |

> 💡Note:
> * The 1.3B model is capable of generating videos at 720P resolution. However, due to limited training at this resolution, the results are generally less stable than at 480P. For optimal performance, we recommend using 480P resolution.
> * For first-last-frame-to-video generation, we trained our model primarily on Chinese text-video pairs. Therefore, we recommend using Chinese prompts for better results.

Download models using huggingface-cli:

```sh
pip install "huggingface_hub[cli]"
huggingface-cli download Wan-AI/Wan2.1-T2V-14B --local-dir ./Wan2.1-T2V-14B
```

Download models using modelscope-cli:

```sh
pip install modelscope
modelscope download Wan-AI/Wan2.1-T2V-14B --local_dir ./Wan2.1-T2V-14B
```

#### Run Text-to-Video Generation

This repository supports two Text-to-Video models (1.3B and 14B) and two resolutions (480P and 720P). The parameters and configurations for these models are as follows:
| Task | 480P | 720P | Model |
|----------|------|------|-----------------|
| t2v-14B | ✔️ | ✔️ | Wan2.1-T2V-14B |
| t2v-1.3B | ✔️ | ❌ | Wan2.1-T2V-1.3B |
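The commands in this section pass `--size` as `WIDTH*HEIGHT` (for example `1280*720` for 720P and `832*480` for 480P). As a sketch only, the table above can be mirrored in a small lookup so that an unsupported task/resolution pair, or a malformed value such as `1280x720`, is caught before any weights are loaded. The names `RESOLUTION_SIZES`, `TASK_RESOLUTIONS`, `parse_size`, and `is_supported` are hypothetical, only the landscape sizes used in this README are listed, and the width-first reading of `--size` is an assumption.

```python
from typing import Tuple

# Hypothetical encoding of the Text-to-Video table above (not the repo's own config).
# Only the landscape sizes used in this README's commands are listed.
RESOLUTION_SIZES = {"480P": "832*480", "720P": "1280*720"}
TASK_RESOLUTIONS = {
    "t2v-14B": {"480P", "720P"},
    "t2v-1.3B": {"480P"},
}


def parse_size(size: str) -> Tuple[int, int]:
    """Split a "WIDTH*HEIGHT" string into integers (width first, by assumption)."""
    width, height = size.split("*")  # "1280x720" has no "*", so unpacking raises ValueError
    return int(width), int(height)


def is_supported(task: str, size: str) -> bool:
    """Return True if the table lists `size` for `task`."""
    parse_size(size)  # validate the format first
    allowed = {RESOLUTION_SIZES[r] for r in TASK_RESOLUTIONS.get(task, set())}
    return size in allowed


if __name__ == "__main__":
    print(parse_size("1280*720"))
    print(is_supported("t2v-14B", "1280*720"))
    print(is_supported("t2v-1.3B", "1280*720"))
```

The check only reflects the recommendation in the table; as the model-download note explains, the 1.3B model can still be run at 720P with less stable results.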
##### (1) Without Prompt Extension

To facilitate implementation, we will start with a basic version of the inference process that skips the [prompt extension](#2-using-prompt-extension) step.

- Single-GPU inference

```sh
python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```

If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model True` and `--t5_cpu` options to reduce GPU memory usage. For example, on an RTX 4090 GPU:

```sh
python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```

> 💡Note: If you are using the `T2V-1.3B` model, we recommend setting `--sample_guide_scale 6`. The `--sample_shift` parameter can be adjusted within the range of 8 to 12 based on the results.

- Multi-GPU inference using FSDP + xDiT USP

  We use FSDP and [xDiT](https://github.com/xdit-project/xDiT) USP to accelerate inference.

  * Ulysses Strategy

    If you want to use the [`Ulysses`](https://arxiv.org/abs/2309.14509) strategy, set `--ulysses_size $GPU_NUMS`. Note that `num_heads` must be divisible by `ulysses_size` to use the `Ulysses` strategy. For the 1.3B model, `num_heads` is `12`, which is not divisible by 8 (most multi-GPU machines have 8 GPUs). Therefore, we recommend using the `Ring` strategy for the 1.3B model instead.

  * Ring Strategy

    If you want to use the [`Ring`](https://arxiv.org/pdf/2310.01889) strategy, set `--ring_size $GPU_NUMS`. Note that the `sequence length` should be divisible by `ring_size` when using the `Ring` strategy.

  You can also combine the `Ulysses` and `Ring` strategies.
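The divisibility rules above can be checked before launching `torchrun`. This is a minimal pre-flight sketch, not the repository's own validation. Beyond the 1.3B head count of 12 stated above, it rests on assumptions: that `ulysses_size * ring_size` must equal the number of GPUs, that the 14B model has 40 attention heads, and that the sequence length is the latent token count obtained with a VAE stride of (4, 8, 8), a patch size of (1, 2, 2), and 81 frames. The helper names `latent_tokens` and `check_usp` are hypothetical.

```python
from typing import List

# Assumed constants: only the 1.3B head count (12) comes from this README.
NUM_HEADS = {"t2v-1.3B": 12, "t2v-14B": 40}
VAE_STRIDE = (4, 8, 8)  # (time, height, width) compression, assumed
PATCH_SIZE = (1, 2, 2)  # DiT patch size, assumed


def latent_tokens(width: int, height: int, num_frames: int = 81) -> int:
    """Number of DiT tokens for one video under the assumed strides."""
    t = (num_frames - 1) // VAE_STRIDE[0] + 1
    h = height // VAE_STRIDE[1] // PATCH_SIZE[1]
    w = width // VAE_STRIDE[2] // PATCH_SIZE[2]
    return t * h * w


def check_usp(task: str, world_size: int, ulysses_size: int = 1, ring_size: int = 1,
              width: int = 1280, height: int = 720) -> List[str]:
    """Return a list of problems; an empty list means the layout looks valid."""
    problems = []
    if ulysses_size * ring_size != world_size:
        problems.append(f"ulysses_size * ring_size = {ulysses_size * ring_size}, "
                        f"expected {world_size}")
    if NUM_HEADS[task] % ulysses_size != 0:
        problems.append(f"num_heads={NUM_HEADS[task]} is not divisible by "
                        f"ulysses_size={ulysses_size}")
    seq_len = latent_tokens(width, height)
    if seq_len % ring_size != 0:
        problems.append(f"sequence length {seq_len} is not divisible by ring_size={ring_size}")
    return problems


if __name__ == "__main__":
    # 14B with Ulysses on 8 GPUs, then 1.3B with Ulysses vs. Ring on 8 GPUs.
    print(check_usp("t2v-14B", 8, ulysses_size=8))
    print(check_usp("t2v-1.3B", 8, ulysses_size=8, width=832, height=480))
    print(check_usp("t2v-1.3B", 8, ring_size=8, width=832, height=480))
```

Under these assumptions the 1.3B model fails the Ulysses check on 8 GPUs but passes the Ring check, matching the recommendation above. The actual implementation may pad the sequence to a multiple of the parallel degree, so treat this as a sanity check only.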
```sh
pip install "xfuser>=0.4.1"
torchrun --nproc_per_node=8 generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```

##### (2) Using Prompt Extension

Extending the prompts can effectively enrich the details in the generated videos, further enhancing video quality. Therefore, we recommend enabling prompt extension. We provide the following two methods for prompt extension:

- Use the Dashscope API for extension.
  - Apply for a `dashscope.api_key` in advance ([EN](https://www.alibabacloud.com/help/en/model-studio/getting-started/first-api-call-to-qwen) | [CN](https://help.aliyun.com/zh/model-studio/getting-started/first-api-call-to-qwen)).
  - Configure the environment variable `DASH_API_KEY` to specify the Dashscope API key. For users of Alibaba Cloud's international site, you also need to set the environment variable `DASH_API_URL` to 'https://dashscope-intl.aliyuncs.com/api/v1'. For more detailed instructions, please refer to the [dashscope document](https://www.alibabacloud.com/help/en/model-studio/developer-reference/use-qwen-by-calling-api?spm=a2c63.p38356.0.i1).
  - Use the `qwen-plus` model for text-to-video tasks and `qwen-vl-max` for image-to-video tasks.
  - You can modify the model used for extension with the parameter `--prompt_extend_model`. For example:

```sh
DASH_API_KEY=your_key python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --use_prompt_extend --prompt_extend_method 'dashscope' --prompt_extend_target_lang 'zh'
```

- Use a local model for extension.
  - By default, the Qwen model on HuggingFace is used for this extension. Users can choose Qwen models or other models based on the available GPU memory size.
  - For text-to-video tasks, you can use models like `Qwen/Qwen2.5-14B-Instruct`, `Qwen/Qwen2.5-7B-Instruct` and `Qwen/Qwen2.5-3B-Instruct`.
  - For image-to-video or first-last-frame-to-video tasks, you can use models like `Qwen/Qwen2.5-VL-7B-Instruct` and `Qwen/Qwen2.5-VL-3B-Instruct`.
  - Larger models generally provide better extension results but require more GPU memory.
  - You can modify the model used for extension with the parameter `--prompt_extend_model`, which accepts either a local model path or a Hugging Face model. For example:

```sh
python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --use_prompt_extend --prompt_extend_method 'local_qwen' --prompt_extend_target_lang 'zh'
```

##### (3) Running with Diffusers

You can easily run **Wan2.1**-T2V inference with Diffusers using the following code:

```python
import torch
from diffusers.utils import export_to_video
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler

# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
flow_shift = 5.0  # 5.0 for 720P, 3.0 for 480P
scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.scheduler = scheduler
pipe.to("cuda")

prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=720,
    width=1280,
    num_frames=81,
    guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```

> 💡Note: This example does not integrate prompt extension or distributed inference. We will soon update it with an integrated prompt extension and a multi-GPU version of Diffusers.

##### (4) Running local gradio

```sh
cd gradio
# if one uses dashscope’s API for prompt extension
DASH_API_KEY=your_key python t2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir ./Wan2.1-T2V-14B

# if one uses a local model for prompt extension
python t2v_14B_singleGPU.py --prompt_extend_method 'local_qwen' --ckpt_dir ./Wan2.1-T2V-14B
```

#### Run Image-to-Video Generation

Similar to Text-to-Video, Image-to-Video is also divided into processes with and without the prompt extension step. The specific parameters and their corresponding settings are as follows:
| Task | 480P | 720P | Model |
|---------|------|------|---------------------|
| i2v-14B | ❌ | ✔️ | Wan2.1-I2V-14B-720P |
| i2v-14B | ✔️ | ❌ | Wan2.1-I2V-14B-480P |
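For Image-to-Video, `--size` is treated as a target pixel area rather than exact dimensions, and the output aspect ratio follows the input image (see the note under single-GPU inference). The sketch below shows one way to turn an area and an input image size into concrete dimensions. It follows the arithmetic of the Diffusers I2V example in this section, with `mod_value = 16` standing in for `pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]` (an assumed spatial VAE stride of 8 times a patch size of 2). The function name `fit_to_area` is hypothetical, and the native `generate.py` path may round slightly differently.

```python
import math
from typing import Tuple


def fit_to_area(img_width: int, img_height: int, max_area: int = 1280 * 720,
                mod_value: int = 16) -> Tuple[int, int]:
    """Return (height, width) with roughly `max_area` pixels and the image's aspect ratio.

    Each side's square root is rounded to the nearest integer, then snapped down
    to a multiple of `mod_value` (assumed 8 * 2 = 16).
    """
    aspect_ratio = img_height / img_width
    height = round(math.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    width = round(math.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
    return height, width


if __name__ == "__main__":
    print(fit_to_area(1280, 720))   # a 16:9 landscape input
    print(fit_to_area(1024, 1024))  # a square input fitted to the same area
```

Snapping down rather than up keeps the result at or below the requested area, so a 720P area budget never grows past what the 720P checkpoint was run with.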
##### (1) Without Prompt Extension

- Single-GPU inference

```sh
python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```

> 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.

- Multi-GPU inference using FSDP + xDiT USP

```sh
pip install "xfuser>=0.4.1"
torchrun --nproc_per_node=8 generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```

##### (2) Using Prompt Extension

The prompt extension process is described [here](#2-using-prompt-extension).
Run with local prompt extension using `Qwen/Qwen2.5-VL-7B-Instruct`:

```sh
python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --use_prompt_extend --prompt_extend_model Qwen/Qwen2.5-VL-7B-Instruct --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```

Run with remote prompt extension using `dashscope`:

```sh
DASH_API_KEY=your_key python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --use_prompt_extend --prompt_extend_method 'dashscope' --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```

##### (3) Running with Diffusers

You can easily run **Wan2.1**-I2V inference with Diffusers using the following code:

```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel

# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
# Fit the output to a 720P pixel area while keeping the input image's aspect ratio
max_area = 720 * 1280
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))

prompt = (
    "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
    "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=81,
    guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```

> 💡Note: This example does not integrate prompt extension or distributed inference. We will soon update it with an integrated prompt extension and a multi-GPU version of Diffusers.

##### (4) Running local gradio

```sh
cd gradio
# if one only uses 480P model in gradio
DASH_API_KEY=your_key python i2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_480p ./Wan2.1-I2V-14B-480P

# if one only uses 720P model in gradio
DASH_API_KEY=your_key python i2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_720p ./Wan2.1-I2V-14B-720P

# if one uses both 480P and 720P models in gradio
DASH_API_KEY=your_key python i2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_480p ./Wan2.1-I2V-14B-480P --ckpt_dir_720p ./Wan2.1-I2V-14B-720P
```

#### Run First-Last-Frame-to-Video Generation

First-Last-Frame-to-Video is also divided into processes with and without the prompt extension step. Currently, only 720P is supported. The specific parameters and corresponding settings are as follows:
| Task | 480P | 720P | Model |
|------|------|------|-------|
| flf2v-14B | ❌ | ✔️ | Wan2.1-FLF2V-14B-720P |
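As with Image-to-Video, the output resolution for FLF2V is derived from a pixel-area budget and the input aspect ratio rather than set directly. The helper below restates the rounding arithmetic from the Diffusers I2V example earlier in this README as a standalone function. It is a sketch, not code from this repository: the name `fit_to_area` is hypothetical, `mod_value=16` assumes a VAE spatial factor of 8 times a patch size of 2 (the Diffusers pipeline reads both from the model config), and the `wan` package may round differently internally when `generate.py` passes it `MAX_AREA_CONFIGS[args.size]`.

```python
import math


def fit_to_area(img_w: int, img_h: int, max_area: int,
                mod_value: int = 16) -> tuple[int, int]:
    """Return (width, height) with about max_area pixels and the input aspect ratio.

    Both sides are rounded down to a multiple of mod_value, mirroring the
    Diffusers I2V example (mod_value = VAE spatial factor * patch size).
    """
    aspect_ratio = img_h / img_w
    height = round(math.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    width = round(math.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
    return width, height


# A 4:3 input at the 720P area budget (1280*720 pixels).
print(fit_to_area(1024, 768, 1280 * 720))  # (1104, 816)
```

Because both sides are floored to a multiple of `mod_value`, the result can fall slightly below the area budget (here 1104 × 816 = 900,864 pixels versus 921,600).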
##### (1) Without Prompt Extension

- Single-GPU inference

```sh
python generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
```

> 💡As with Image-to-Video, the `size` parameter sets the area of the generated video, while the aspect ratio follows that of the original input image.

- Multi-GPU inference using FSDP + xDiT USP

```sh
pip install "xfuser>=0.4.1"
torchrun --nproc_per_node=8 generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
```

##### (2) Using Prompt Extension

The prompt extension process is described [here](#2-using-prompt-extention).
Run with local prompt extension using `Qwen/Qwen2.5-VL-7B-Instruct`:

```sh
python generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --use_prompt_extend --prompt_extend_model Qwen/Qwen2.5-VL-7B-Instruct --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
```

Run with remote prompt extension using `dashscope`:

```sh
DASH_API_KEY=your_key python generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --use_prompt_extend --prompt_extend_method 'dashscope' --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
```

##### (3) Running local gradio

```sh
cd gradio
# use the 720P model in gradio
DASH_API_KEY=your_key python fl2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_720p ./Wan2.1-FLF2V-14B-720P
```

#### Run Text-to-Image Generation

Wan2.1 is a unified model for both image and video generation. Since it was trained on both types of data, it can also generate images.
The command for generating images is similar to the one for video generation:

##### (1) Without Prompt Extension

- Single-GPU inference

```sh
# Prompt in English: "A simple, dignified beauty"
python generate.py --task t2i-14B --size 1024*1024 --ckpt_dir ./Wan2.1-T2V-14B --prompt '一个朴素端庄的美人'
```

- Multi-GPU inference using FSDP + xDiT USP

```sh
torchrun --nproc_per_node=8 generate.py --dit_fsdp --t5_fsdp --ulysses_size 8 --base_seed 0 --frame_num 1 --task t2i-14B --size 1024*1024 --prompt '一个朴素端庄的美人' --ckpt_dir ./Wan2.1-T2V-14B
```

##### (2) With Prompt Extension

- Single-GPU inference

```sh
python generate.py --task t2i-14B --size 1024*1024 --ckpt_dir ./Wan2.1-T2V-14B --prompt '一个朴素端庄的美人' --use_prompt_extend
```

- Multi-GPU inference using FSDP + xDiT USP

```sh
torchrun --nproc_per_node=8 generate.py --dit_fsdp --t5_fsdp --ulysses_size 8 --base_seed 0 --frame_num 1 --task t2i-14B --size 1024*1024 --ckpt_dir ./Wan2.1-T2V-14B --prompt '一个朴素端庄的美人' --use_prompt_extend
```

## Manual Evaluation

##### (1) Text-to-Video Evaluation

In manual evaluation, the results generated with prompt extension are superior to those from both closed-source and open-source models.
##### (2) Image-to-Video Evaluation

We also conducted extensive manual evaluations of the Image-to-Video model; the results are presented in the table below. They clearly indicate that **Wan2.1** outperforms both closed-source and open-source models.
## Computational Efficiency on Different GPUs

We tested the computational efficiency of the different **Wan2.1** models on different GPUs; the results are shown in the following table in the format **Total time (s) / peak GPU memory (GB)**.
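Multi-GPU settings such as those used in these benchmarks must pass two sequence-parallel checks in `generate.py`: when either `--ulysses_size` or `--ring_size` exceeds 1, their product must equal the number of processes, and the DiT head count (12 for 1.3B and 40 for 14B in the architecture table) must be divisible by `--ulysses_size`. The helper below is a standalone restatement of those two assertions, not code from the repository; `check_parallel_config` and `is_valid` are hypothetical names. It suggests why 8-GPU runs of the 1.3B model presumably use `--ring_size 8` rather than `--ulysses_size 8`: 12 heads cannot be split evenly across 8 Ulysses ranks.

```python
def check_parallel_config(world_size: int, ulysses_size: int,
                          ring_size: int, num_heads: int) -> None:
    """Raise ValueError if these settings would fail generate.py's parallel checks."""
    # Context parallelism must use every process: ulysses_size * ring_size == world_size.
    if (ulysses_size > 1 or ring_size > 1) and ulysses_size * ring_size != world_size:
        raise ValueError(
            f"ulysses_size * ring_size = {ulysses_size * ring_size}, "
            f"but world_size = {world_size}")
    # Ulysses parallelism splits attention heads, so they must divide evenly.
    if ulysses_size > 1 and num_heads % ulysses_size != 0:
        raise ValueError(
            f"num_heads={num_heads} is not divisible by ulysses_size={ulysses_size}")


def is_valid(world_size: int, ulysses_size: int, ring_size: int, num_heads: int) -> bool:
    """Convenience wrapper returning True/False instead of raising."""
    try:
        check_parallel_config(world_size, ulysses_size, ring_size, num_heads)
    except ValueError:
        return False
    return True


print(is_valid(8, 1, 8, 12))  # True: 1.3B with ring attention on 8 GPUs
print(is_valid(8, 8, 1, 12))  # False: 12 heads cannot be split 8 ways
```

Mixed layouts also pass these checks, for example `ulysses_size=4, ring_size=2` on 8 GPUs for the 1.3B model, although this README does not report benchmarks for them.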
> The parameter settings for the tests presented in this table are as follows:
> (1) For the 1.3B model on 8 GPUs, set `--ring_size 8` and `--ulysses_size 1`;
> (2) For the 14B model on 1 GPU, use `--offload_model True`;
> (3) For the 1.3B model on a single 4090 GPU, set `--offload_model True --t5_cpu`;
> (4) For all tests, no prompt extension was applied, meaning `--use_prompt_extend` was not enabled.

> 💡Note: T2V-14B is slower than I2V-14B because the former samples 50 steps while the latter uses 40 steps.

-------

## Introduction of Wan2.1

**Wan2.1** is built on the mainstream diffusion transformer paradigm and achieves significant advances in generative capability through a series of innovations: a novel spatio-temporal variational autoencoder (VAE), scalable training strategies, large-scale data construction, and automated evaluation metrics. Collectively, these contributions enhance the model’s performance and versatility.

##### (1) 3D Variational Autoencoders

We propose a novel 3D causal VAE architecture, termed **Wan-VAE**, specifically designed for video generation. By combining multiple strategies, we improve spatio-temporal compression, reduce memory usage, and ensure temporal causality. **Wan-VAE** demonstrates significant advantages in performance efficiency compared to other open-source VAEs. Furthermore, **Wan-VAE** can encode and decode 1080P videos of unlimited length without losing historical temporal information, making it particularly well-suited for video generation tasks.
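The causal temporal compression explains why `generate.py` asks for frame counts of the form 4n+1. The sketch below computes the latent shape and the resulting DiT sequence length. The strides are assumptions that this README does not state: 4× temporal and 8× spatial compression, 16 latent channels (matching the "Input Dimension" of 16 in the DiT table below), and a DiT patch size of (1, 2, 2). The function names are hypothetical.

```python
def latent_shape(num_frames: int, height: int, width: int,
                 t_stride: int = 4, s_stride: int = 8,
                 z_dim: int = 16) -> tuple[int, int, int, int]:
    """Return (channels, frames, height, width) of the assumed latent video."""
    if num_frames < 1 or (num_frames - 1) % t_stride != 0:
        raise ValueError(f"num_frames must be {t_stride}n+1, got {num_frames}")
    if height % s_stride or width % s_stride:
        raise ValueError(f"height and width must be multiples of {s_stride}")
    # Causal layout: the first frame gets its own latent frame, and each
    # following group of t_stride frames is compressed into one more.
    latent_frames = 1 + (num_frames - 1) // t_stride
    return (z_dim, latent_frames, height // s_stride, width // s_stride)


def num_dit_tokens(shape: tuple[int, int, int, int],
                   patch: tuple[int, int, int] = (1, 2, 2)) -> int:
    """Number of DiT tokens after patchifying the latent with the given patch size."""
    _, t, h, w = shape
    return (t // patch[0]) * (h // patch[1]) * (w // patch[2])


shape = latent_shape(81, 720, 1280)
print(shape, num_dit_tokens(shape))  # (16, 21, 90, 160) 75600
```

Under the same assumptions, a single frame (`--frame_num 1` for text-to-image) maps to one latent frame, which is why T2I can reuse the video model unchanged.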
##### (2) Video Diffusion DiT

**Wan2.1** is designed using the Flow Matching framework within the paradigm of mainstream Diffusion Transformers. The model uses a T5 encoder to encode multilingual text input, and cross-attention in each transformer block injects the text embeddings into the model. Additionally, we employ an MLP with a Linear layer and a SiLU layer to process the input time embeddings and predict six modulation parameters separately. This MLP is shared across all transformer blocks, with each block learning a distinct set of biases. Our experiments show a significant performance improvement with this approach at the same parameter scale.
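The shared time-modulation design can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the model's implementation: it assumes SiLU is applied before a single Linear layer mapping `dim` to `6 * dim` values, and that each block adds its own learned `6 × dim` offsets to that shared output. The class names, toy dimensions, and random initialisation are all hypothetical.

```python
import math
import random


def silu(x: float) -> float:
    """Numerically stable SiLU: x * sigmoid(x)."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


class SharedTimeProjection:
    """One SiLU + Linear map, shared by every block: dim -> 6 * dim."""

    def __init__(self, dim: int, rng: random.Random) -> None:
        self.dim = dim
        self.weight = [[rng.gauss(0.0, dim ** -0.5) for _ in range(dim)]
                       for _ in range(6 * dim)]
        self.bias = [0.0] * (6 * dim)

    def __call__(self, t_emb: list[float]) -> list[list[float]]:
        h = [silu(v) for v in t_emb]
        out = [sum(w * x for w, x in zip(row, h)) + b
               for row, b in zip(self.weight, self.bias)]
        # Split the 6*dim outputs into six dim-sized modulation vectors.
        return [out[i * self.dim:(i + 1) * self.dim] for i in range(6)]


class Block:
    """Hypothetical transformer block that owns only its per-block offsets."""

    def __init__(self, dim: int, rng: random.Random) -> None:
        self.modulation = [[rng.gauss(0.0, dim ** -0.5) for _ in range(dim)]
                           for _ in range(6)]

    def modulation_params(self, shared: list[list[float]]) -> list[list[float]]:
        # Add this block's learned offsets to the shared projection output.
        return [[s + m for s, m in zip(s_row, m_row)]
                for s_row, m_row in zip(shared, self.modulation)]


rng = random.Random(0)
dim, num_layers = 8, 4
projection = SharedTimeProjection(dim, rng)
blocks = [Block(dim, rng) for _ in range(num_layers)]

t_emb = [rng.gauss(0.0, 1.0) for _ in range(dim)]
shared = projection(t_emb)  # computed once per timestep
per_block = [blk.modulation_params(shared) for blk in blocks]
```

Under this sketch's layout, the 14B configuration (dimension 5120, 40 layers) spends about 0.16B parameters on modulation (6 × 5120² shared weights plus biases and 40 × 6 × 5120 per-block offsets), whereas giving every block its own projection would cost about 6.3B.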
| Model | Dimension | Input Dimension | Output Dimension | Feedforward Dimension | Frequency Dimension | Number of Heads | Number of Layers |
|--------|-----------|-----------------|------------------|-----------------------|---------------------|-----------------|------------------|
| 1.3B | 1536 | 16 | 16 | 8960 | 256 | 12 | 30 |
| 14B | 5120 | 16 | 16 | 13824 | 256 | 40 | 40 |

##### Data

We curated and deduplicated a candidate dataset comprising a vast amount of image and video data. During data curation, we designed a four-step data cleaning process focusing on fundamental dimensions, visual quality, and motion quality. This robust data processing pipeline lets us readily obtain high-quality, diverse, and large-scale training sets of images and videos.

![figure1](assets/data_for_diff_stage.jpg "figure1")

##### Comparisons to SOTA

We compared **Wan2.1** with leading open-source and closed-source models to evaluate its performance. Using our carefully designed set of 1,035 internal prompts, we tested across 14 major dimensions and 26 sub-dimensions. We then computed the total score as a weighted combination of the per-dimension scores, using weights derived from human preferences in the matching process. The detailed results are shown in the table below. These results demonstrate our model's superior performance compared to both open-source and closed-source models.

![figure1](assets/vben_vs_sota.png "figure1")

## Citation

If you find our work helpful, please cite us.
```
@article{wan2025,
      title={Wan: Open and Advanced Large-Scale Video Generative Models},
      author={Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
      journal = {arXiv preprint arXiv:2503.20314},
      year={2025}
}
```

## License Agreement

The models in this repository are licensed under the Apache 2.0 License. We claim no rights over your generated content, granting you the freedom to use it while ensuring that your usage complies with the provisions of this license. You are fully accountable for your use of the models, which must not involve sharing any content that violates applicable laws, causes harm to individuals or groups, disseminates personal information intended for harm, spreads misinformation, or targets vulnerable populations. For a complete list of restrictions and details regarding your rights, please refer to the full text of the [license](LICENSE.txt).
## Acknowledgements

We would like to thank the contributors to the [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [Qwen](https://huggingface.co/Qwen), [umt5-xxl](https://huggingface.co/google/umt5-xxl), [diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) repositories for their open research.

## Contact Us

If you would like to leave a message for our research or product teams, feel free to join our [Discord](https://discord.gg/AKNgpMK4Yj) or [WeChat groups](https://gw.alicdn.com/imgextra/i2/O1CN01tqjWFi1ByuyehkTSB_!!6000000000015-0-tps-611-1279.jpg)!

## /assets/comp_effic.png

Binary file available at https://raw.githubusercontent.com/Wan-Video/Wan2.1/refs/heads/main/assets/comp_effic.png

## /assets/data_for_diff_stage.jpg

Binary file available at https://raw.githubusercontent.com/Wan-Video/Wan2.1/refs/heads/main/assets/data_for_diff_stage.jpg

## /assets/i2v_res.png

Binary file available at https://raw.githubusercontent.com/Wan-Video/Wan2.1/refs/heads/main/assets/i2v_res.png

## /assets/logo.png

Binary file available at https://raw.githubusercontent.com/Wan-Video/Wan2.1/refs/heads/main/assets/logo.png

## /assets/t2v_res.jpg

Binary file available at https://raw.githubusercontent.com/Wan-Video/Wan2.1/refs/heads/main/assets/t2v_res.jpg

## /assets/vben_vs_sota.png

Binary file available at https://raw.githubusercontent.com/Wan-Video/Wan2.1/refs/heads/main/assets/vben_vs_sota.png

## /assets/video_dit_arch.jpg

Binary file available at https://raw.githubusercontent.com/Wan-Video/Wan2.1/refs/heads/main/assets/video_dit_arch.jpg

## /assets/video_vae_res.jpg

Binary file available at https://raw.githubusercontent.com/Wan-Video/Wan2.1/refs/heads/main/assets/video_vae_res.jpg

## /examples/flf2v_input_first_frame.png

Binary file available at https://raw.githubusercontent.com/Wan-Video/Wan2.1/refs/heads/main/examples/flf2v_input_first_frame.png

## /examples/flf2v_input_last_frame.png

Binary file available at
https://raw.githubusercontent.com/Wan-Video/Wan2.1/refs/heads/main/examples/flf2v_input_last_frame.png

## /examples/i2v_input.JPG

Binary file available at https://raw.githubusercontent.com/Wan-Video/Wan2.1/refs/heads/main/examples/i2v_input.JPG

## /generate.py

```py path="/generate.py"
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
from datetime import datetime
import logging
import os
import sys
import warnings

warnings.filterwarnings('ignore')

import torch, random
import torch.distributed as dist
from PIL import Image

import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video, cache_image, str2bool


EXAMPLE_PROMPT = {
    "t2v-1.3B": {
        "prompt":
            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "t2v-14B": {
        "prompt":
            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "t2i-14B": {
        # English: "A simple, dignified beauty"
        "prompt": "一个朴素端庄的美人",
    },
    "i2v-14B": {
        "prompt":
            "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
        "image":
            "examples/i2v_input.JPG",
    },
    "flf2v-14B": {
        # English: "CG animation style, a small blue bird takes off from the ground,
        # flapping its wings. ..." (full translation in the FLF2V section of README.md)
        "prompt":
            "CG动画风格,一只蓝色的小鸟从地面起飞,煽动翅膀。小鸟羽毛细腻,胸前有独特的花纹,背景是蓝天白云,阳光明媚。镜跟随小鸟向上移动,展现出小鸟飞翔的姿态和天空的广阔。近景,仰视视角。",
        "first_frame":
            "examples/flf2v_input_first_frame.png",
        "last_frame":
            "examples/flf2v_input_last_frame.png",
    },
}


def _validate_args(args):
    # Basic check
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
    assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"

    # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
    if args.sample_steps is None:
        args.sample_steps = 40 if "i2v" in args.task else 50

    if args.sample_shift is None:
        args.sample_shift = 5.0
        if "i2v" in args.task and args.size in ["832*480", "480*832"]:
            args.sample_shift = 3.0
        if "flf2v" in args.task:
            args.sample_shift = 16

    # The default number of frames is 1 for text-to-image tasks and 81 for other tasks.
    if args.frame_num is None:
        args.frame_num = 1 if "t2i" in args.task else 81

    # T2I frame_num check
    if "t2i" in args.task:
        assert args.frame_num == 1, f"Unsupported frame_num {args.frame_num} for task {args.task}"

    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
        0, sys.maxsize)
    # Size check
    assert args.size in SUPPORTED_SIZES[
        args.task], f"Unsupported size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"


def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate an image or video from a text prompt or image using Wan"
    )
    parser.add_argument(
        "--task",
        type=str,
        default="t2v-14B",
        choices=list(WAN_CONFIGS.keys()),
        help="The task to run.")
    parser.add_argument(
        "--size",
        type=str,
        default="1280*720",
        choices=list(SIZE_CONFIGS.keys()),
        help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
    )
    parser.add_argument(
        "--frame_num",
        type=int,
        default=None,
        help="How many frames to sample from an image or video. The number should be 4n+1"
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--offload_model",
        type=str2bool,
        default=None,
        help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
    )
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="The size of the ulysses parallelism in DiT.")
    parser.add_argument(
        "--ring_size",
        type=int,
        default=1,
        help="The size of the ring attention parallelism in DiT.")
    parser.add_argument(
        "--t5_fsdp",
        action="store_true",
        default=False,
        help="Whether to use FSDP for T5.")
    parser.add_argument(
        "--t5_cpu",
        action="store_true",
        default=False,
        help="Whether to place T5 model on CPU.")
    parser.add_argument(
        "--dit_fsdp",
        action="store_true",
        default=False,
        help="Whether to use FSDP for DiT.")
    parser.add_argument(
        "--save_file",
        type=str,
        default=None,
        help="The file to save the generated image or video to.")
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="The prompt to generate the image or video from.")
    parser.add_argument(
        "--use_prompt_extend",
        action="store_true",
        default=False,
        help="Whether to use prompt extend.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")
    parser.add_argument(
        "--prompt_extend_target_lang",
        type=str,
        default="zh",
        choices=["zh", "en"],
        help="The target language of prompt extend.")
    parser.add_argument(
        "--base_seed",
        type=int,
        default=-1,
        help="The seed to use for generating the image or video.")
    parser.add_argument(
        "--image",
        type=str,
        default=None,
        help="[image to video] The image to generate the video from.")
    parser.add_argument(
        "--first_frame",
        type=str,
        default=None,
        help="[first-last frame to video] The image (first frame) to generate the video from.")
    parser.add_argument(
        "--last_frame",
        type=str,
        default=None,
        help="[first-last frame to video] The image (last frame) to generate the video from.")
    parser.add_argument(
        "--sample_solver",
        type=str,
        default='unipc',
        choices=['unipc', 'dpm++'],
        help="The solver used to sample.")
    parser.add_argument(
        "--sample_steps", type=int, default=None, help="The sampling steps.")
    parser.add_argument(
        "--sample_shift",
        type=float,
        default=None,
        help="Sampling shift factor for flow matching schedulers.")
    parser.add_argument(
        "--sample_guide_scale",
        type=float,
        default=5.0,
        help="Classifier free guidance scale.")

    args = parser.parse_args()

    _validate_args(args)

    return args


def _init_logging(rank):
    # logging
    if rank == 0:
        # set format
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)])
    else:
        logging.basicConfig(level=logging.ERROR)


def generate(args):
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    device = local_rank
    _init_logging(rank)

    if args.offload_model is None:
        args.offload_model = False if world_size > 1 else True
        logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")
    if world_size > 1:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)
    else:
        assert not (
            args.t5_fsdp or args.dit_fsdp
        ), "t5_fsdp and dit_fsdp are not supported in non-distributed environments."
        assert not (
            args.ulysses_size > 1 or args.ring_size > 1
        ), "Context parallelism is not supported in non-distributed environments."

    if args.ulysses_size > 1 or args.ring_size > 1:
        assert args.ulysses_size * args.ring_size == world_size, "The product of ulysses_size and ring_size should equal the world size."
        from xfuser.core.distributed import (initialize_model_parallel,
                                             init_distributed_environment)
        init_distributed_environment(
            rank=dist.get_rank(), world_size=dist.get_world_size())

        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=args.ring_size,
            ulysses_degree=args.ulysses_size,
        )

    if args.use_prompt_extend:
        # Both image-conditioned tasks (i2v and flf2v) need a vision-language expander.
        if args.prompt_extend_method == "dashscope":
            prompt_expander = DashScopePromptExpander(
                model_name=args.prompt_extend_model,
                is_vl="i2v" in args.task or "flf2v" in args.task)
        elif args.prompt_extend_method == "local_qwen":
            prompt_expander = QwenPromptExpander(
                model_name=args.prompt_extend_model,
                is_vl="i2v" in args.task or "flf2v" in args.task,
                device=rank)
        else:
            raise NotImplementedError(
                f"Unsupported prompt_extend_method: {args.prompt_extend_method}")

    cfg = WAN_CONFIGS[args.task]
    if args.ulysses_size > 1:
        assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`."

    logging.info(f"Generation job args: {args}")
    logging.info(f"Generation model config: {cfg}")

    # Make every rank use the seed chosen on rank 0.
    if dist.is_initialized():
        base_seed = [args.base_seed] if rank == 0 else [None]
        dist.broadcast_object_list(base_seed, src=0)
        args.base_seed = base_seed[0]

    if "t2v" in args.task or "t2i" in args.task:
        if args.prompt is None:
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
        logging.info(f"Input prompt: {args.prompt}")
        if args.use_prompt_extend:
            logging.info("Extending prompt ...")
            if rank == 0:
                prompt_output = prompt_expander(
                    args.prompt,
                    tar_lang=args.prompt_extend_target_lang,
                    seed=args.base_seed)
                if not prompt_output.status:
                    logging.info(
                        f"Extending prompt failed: {prompt_output.message}")
                    logging.info("Falling back to original prompt.")
                    input_prompt = args.prompt
                else:
                    input_prompt = prompt_output.prompt
                input_prompt = [input_prompt]
            else:
                input_prompt = [None]
            if dist.is_initialized():
                dist.broadcast_object_list(input_prompt, src=0)
            args.prompt = input_prompt[0]
            logging.info(f"Extended prompt: {args.prompt}")

        logging.info("Creating WanT2V pipeline.")
        wan_t2v = wan.WanT2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
            t5_cpu=args.t5_cpu,
        )

        logging.info(
            f"Generating {'image' if 't2i' in args.task else 'video'} ...")
        video = wan_t2v.generate(
            args.prompt,
            size=SIZE_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)

    elif "i2v" in args.task:
        if args.prompt is None:
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
        if args.image is None:
            args.image = EXAMPLE_PROMPT[args.task]["image"]
        logging.info(f"Input prompt: {args.prompt}")
        logging.info(f"Input image: {args.image}")

        img = Image.open(args.image).convert("RGB")
        if args.use_prompt_extend:
            logging.info("Extending prompt ...")
            if rank == 0:
                prompt_output = prompt_expander(
                    args.prompt,
                    tar_lang=args.prompt_extend_target_lang,
                    image=img,
                    seed=args.base_seed)
                if not prompt_output.status:
                    logging.info(
                        f"Extending prompt failed: {prompt_output.message}")
                    logging.info("Falling back to original prompt.")
                    input_prompt = args.prompt
                else:
                    input_prompt = prompt_output.prompt
                input_prompt = [input_prompt]
            else:
                input_prompt = [None]
            if dist.is_initialized():
                dist.broadcast_object_list(input_prompt, src=0)
            args.prompt = input_prompt[0]
            logging.info(f"Extended prompt: {args.prompt}")

        logging.info("Creating WanI2V pipeline.")
        wan_i2v = wan.WanI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
            t5_cpu=args.t5_cpu,
        )

        logging.info("Generating video ...")
        video = wan_i2v.generate(
            args.prompt,
            img,
            max_area=MAX_AREA_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)

    else:
        if args.prompt is None:
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
        if args.first_frame is None or args.last_frame is None:
            args.first_frame = EXAMPLE_PROMPT[args.task]["first_frame"]
            args.last_frame = EXAMPLE_PROMPT[args.task]["last_frame"]
        logging.info(f"Input prompt: {args.prompt}")
        logging.info(f"Input first frame: {args.first_frame}")
        logging.info(f"Input last frame: {args.last_frame}")
        first_frame = Image.open(args.first_frame).convert("RGB")
        last_frame = Image.open(args.last_frame).convert("RGB")
        if args.use_prompt_extend:
            logging.info("Extending prompt ...")
            if rank == 0:
                prompt_output = prompt_expander(
                    args.prompt,
                    tar_lang=args.prompt_extend_target_lang,
                    image=[first_frame, last_frame],
                    seed=args.base_seed)
                if not prompt_output.status:
                    logging.info(
                        f"Extending prompt failed: {prompt_output.message}")
                    logging.info("Falling back to original prompt.")
                    input_prompt = args.prompt
                else:
                    input_prompt = prompt_output.prompt
                input_prompt = [input_prompt]
            else:
                input_prompt = [None]
            if dist.is_initialized():
                dist.broadcast_object_list(input_prompt, src=0)
            args.prompt = input_prompt[0]
            logging.info(f"Extended prompt: {args.prompt}")

        logging.info("Creating WanFLF2V pipeline.")
        wan_flf2v = wan.WanFLF2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
            t5_cpu=args.t5_cpu,
        )

        logging.info("Generating video ...")
        video = wan_flf2v.generate(
            args.prompt,
            first_frame,
            last_frame,
            max_area=MAX_AREA_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model
        )

    if rank == 0:
        if args.save_file is None:
            formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
            formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50]
            suffix = '.png' if "t2i" in args.task else '.mp4'
            args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix

        if "t2i" in args.task:
            logging.info(f"Saving generated image to {args.save_file}")
            cache_image(
                tensor=video.squeeze(1)[None],
                save_file=args.save_file,
                nrow=1,
                normalize=True,
                value_range=(-1, 1))
        else:
            logging.info(f"Saving generated video to {args.save_file}")
            cache_video(
                tensor=video[None],
                save_file=args.save_file,
                fps=cfg.sample_fps,
                nrow=1,
                normalize=True,
                value_range=(-1, 1))
    logging.info("Finished.")


if __name__ == "__main__":
    args = _parse_args()
    generate(args)
```

## /gradio/fl2v_14B_singleGPU.py

```py path="/gradio/fl2v_14B_singleGPU.py"
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import gc
import os.path as osp
import os
import sys
import warnings

import gradio as gr

warnings.filterwarnings('ignore')

# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video

# Global Var
prompt_expander = None
wan_flf2v_720P = None


# Button Func
def load_model(value):
    global wan_flf2v_720P

    if value == '------':
        print("No model loaded")
        return '------'

    if value == '720P':
        if args.ckpt_dir_720p is None:
            print("Please specify the checkpoint directory for 720P model")
            return '------'
        if wan_flf2v_720P is None:
            gc.collect()

            print("load 14B-720P flf2v model...", end='', flush=True)
            cfg = WAN_CONFIGS['flf2v-14B']
            wan_flf2v_720P = wan.WanFLF2V(
                config=cfg,
                checkpoint_dir=args.ckpt_dir_720p,
                device_id=0,
                rank=0,
                t5_fsdp=False,
                dit_fsdp=False,
                use_usp=False,
            )
            print("done", flush=True)
        return '720P'

    return value


def prompt_enc(prompt, img_first, img_last, tar_lang):
    print('prompt extend...')
    if img_first is None or img_last is None:
        print('Please upload the first and last frames')
        return prompt
    global prompt_expander
    prompt_output = prompt_expander(
        prompt, image=[img_first, img_last], tar_lang=tar_lang.lower())
    if not prompt_output.status:
        return prompt
    return prompt_output.prompt


def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
                     resolution, sd_steps, guide_scale, shift_scale, seed,
                     n_prompt):

    if resolution == '------':
        print(
            'Please specify the resolution ckpt dir or specify the resolution'
        )
        return None

    if resolution == '720P':
        global wan_flf2v_720P
        video = wan_flf2v_720P.generate(
            flf2vid_prompt,
            flf2vid_image_first,
            flf2vid_image_last,
            max_area=MAX_AREA_CONFIGS['720*1280'],
            shift=shift_scale,
            sampling_steps=sd_steps,
            guide_scale=guide_scale,
            n_prompt=n_prompt,
            seed=seed,
            offload_model=True)
    else:
        print('Sorry, currently only 720P is supported.')
        return None

    cache_video(
        tensor=video[None],
        save_file="example.mp4",
        fps=16,
        nrow=1,
        normalize=True,
        value_range=(-1, 1))

    return "example.mp4"


# Interface
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("""
Wan2.1 (FLF2V-14B)

Wan: Open and Advanced Large-Scale Video Generative Models.
""")

        with gr.Row():
            with gr.Column():
                resolution = gr.Dropdown(
                    label='Resolution',
                    choices=['------', '720P'],
                    value='------')
                flf2vid_image_first = gr.Image(
                    type="pil",
                    label="Upload First Frame",
                    elem_id="image_upload",
                )
                flf2vid_image_last = gr.Image(
                    type="pil",
                    label="Upload Last Frame",
                    elem_id="image_upload",
                )
                flf2vid_prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate",
                )
                tar_lang = gr.Radio(
                    choices=["ZH", "EN"],
                    label="Target language of prompt enhance",
                    value="ZH")
                run_p_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    with gr.Row():
                        sd_steps = gr.Slider(
                            label="Diffusion steps",
                            minimum=1,
                            maximum=1000,
                            value=50,
                            step=1)
                        guide_scale = gr.Slider(
                            label="Guide scale",
                            minimum=0,
                            maximum=20,
                            value=5.0,
                            step=1)
                    with gr.Row():
                        shift_scale = gr.Slider(
                            label="Shift scale",
                            minimum=0,
                            maximum=20,
                            value=5.0,
                            step=1)
                        seed = gr.Slider(
                            label="Seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1)
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe the negative prompt you want to add"
                    )

                run_flf2v_button = gr.Button("Generate Video")

            with gr.Column():
                result_gallery = gr.Video(
                    label='Generated Video', interactive=False, height=600)

        resolution.input(
            fn=load_model, inputs=[resolution], outputs=[resolution])

        run_p_button.click(
            fn=prompt_enc,
            inputs=[flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, tar_lang],
            outputs=[flf2vid_prompt])

        run_flf2v_button.click(
            fn=flf2v_generation,
            inputs=[
                flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
                resolution, sd_steps, guide_scale, shift_scale, seed, n_prompt
            ],
            outputs=[result_gallery],
        )

    return demo


# Main
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt or image using Gradio")
    parser.add_argument(
        "--ckpt_dir_720p",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")

    args = parser.parse_args()
    assert args.ckpt_dir_720p is not None, "Please specify the checkpoint directory."

    return args


if __name__ == '__main__':
    args = _parse_args()

    print("Step1: Init prompt_expander...", end='', flush=True)
    if args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=True)
    elif args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=True, device=0)
    else:
        raise NotImplementedError(
            f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)

    demo = gradio_interface()
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
```

## /gradio/i2v_14B_singleGPU.py

```py path="/gradio/i2v_14B_singleGPU.py"
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import gc
import os.path as osp
import os
import sys
import warnings

import gradio as gr

warnings.filterwarnings('ignore')

# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video

# Global Var
prompt_expander = None
wan_i2v_480P = None
wan_i2v_720P = None


# Button Func
def load_model(value):
    global wan_i2v_480P, wan_i2v_720P

    if value == '------':
        print("No model loaded")
        return '------'

    if value == '720P':
        if args.ckpt_dir_720p is None:
            print("Please specify the checkpoint directory for 720P model")
            return '------'
        if wan_i2v_720P is not None:
            pass
        else:
            del wan_i2v_480P
            gc.collect()
            wan_i2v_480P = None

            print("load 14B-720P i2v model...", end='', flush=True)
            cfg = WAN_CONFIGS['i2v-14B']
            wan_i2v_720P = wan.WanI2V(
                config=cfg,
                checkpoint_dir=args.ckpt_dir_720p,
                device_id=0,
                rank=0,
                t5_fsdp=False,
                dit_fsdp=False,
                use_usp=False,
            )
            print("done", flush=True)
            return '720P'

    if value == '480P':
        if args.ckpt_dir_480p is None:
            print("Please specify the checkpoint directory for 480P model")
            return '------'
        if wan_i2v_480P is not None:
            pass
        else:
            del wan_i2v_720P
            gc.collect()
            wan_i2v_720P = None

            print("load 14B-480P i2v model...", end='', flush=True)
            cfg = WAN_CONFIGS['i2v-14B']
            wan_i2v_480P = wan.WanI2V(
                config=cfg,
                checkpoint_dir=args.ckpt_dir_480p,
                device_id=0,
                rank=0,
                t5_fsdp=False,
                dit_fsdp=False,
                use_usp=False,
            )
            print("done", flush=True)
            return '480P'
    return value


def prompt_enc(prompt, img, tar_lang):
    print('prompt extend...')
    if img is None:
        print('Please upload an image')
        return prompt
    global prompt_expander
    prompt_output = prompt_expander(
        prompt, image=img, tar_lang=tar_lang.lower())
    if prompt_output.status == False:
        return prompt
    else:
        return prompt_output.prompt


def i2v_generation(img2vid_prompt, img2vid_image, resolution,
                   sd_steps, guide_scale, shift_scale, seed, n_prompt):
    # print(f"{img2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")

    if resolution == '------':
        print(
            'Please specify at least one resolution ckpt dir or specify the resolution'
        )
        return None

    else:
        if resolution == '720P':
            global wan_i2v_720P
            video = wan_i2v_720P.generate(
                img2vid_prompt,
                img2vid_image,
                max_area=MAX_AREA_CONFIGS['720*1280'],
                shift=shift_scale,
                sampling_steps=sd_steps,
                guide_scale=guide_scale,
                n_prompt=n_prompt,
                seed=seed,
                offload_model=True)
        else:
            global wan_i2v_480P
            video = wan_i2v_480P.generate(
                img2vid_prompt,
                img2vid_image,
                max_area=MAX_AREA_CONFIGS['480*832'],
                shift=shift_scale,
                sampling_steps=sd_steps,
                guide_scale=guide_scale,
                n_prompt=n_prompt,
                seed=seed,
                offload_model=True)

        cache_video(
            tensor=video[None],
            save_file="example.mp4",
            fps=16,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))

        return "example.mp4"


# Interface
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("""
Wan2.1 (I2V-14B)
Wan: Open and Advanced Large-Scale Video Generative Models.
        """)

        with gr.Row():
            with gr.Column():
                resolution = gr.Dropdown(
                    label='Resolution',
                    choices=['------', '720P', '480P'],
                    value='------')

                img2vid_image = gr.Image(
                    type="pil",
                    label="Upload Input Image",
                    elem_id="image_upload",
                )
                img2vid_prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate",
                )
                tar_lang = gr.Radio(
                    choices=["ZH", "EN"],
                    label="Target language of prompt enhance",
                    value="ZH")
                run_p_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    with gr.Row():
                        sd_steps = gr.Slider(
                            label="Diffusion steps",
                            minimum=1,
                            maximum=1000,
                            value=50,
                            step=1)
                        guide_scale = gr.Slider(
                            label="Guide scale",
                            minimum=0,
                            maximum=20,
                            value=5.0,
                            step=1)
                    with gr.Row():
                        shift_scale = gr.Slider(
                            label="Shift scale",
                            minimum=0,
                            maximum=10,
                            value=5.0,
                            step=1)
                        seed = gr.Slider(
                            label="Seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1)
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe the negative prompt you want to add"
                    )

                run_i2v_button = gr.Button("Generate Video")

            with gr.Column():
                result_gallery = gr.Video(
                    label='Generated Video', interactive=False, height=600)

        resolution.input(
            fn=load_model, inputs=[resolution], outputs=[resolution])

        run_p_button.click(
            fn=prompt_enc,
            inputs=[img2vid_prompt, img2vid_image, tar_lang],
            outputs=[img2vid_prompt])

        run_i2v_button.click(
            fn=i2v_generation,
            inputs=[
                img2vid_prompt, img2vid_image, resolution, sd_steps,
                guide_scale, shift_scale, seed, n_prompt
            ],
            outputs=[result_gallery],
        )

    return demo


# Main
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt or image using Gradio")
    parser.add_argument(
        "--ckpt_dir_720p",
        type=str,
        default=None,
        help="The path to the 720P checkpoint directory.")
    parser.add_argument(
        "--ckpt_dir_480p",
        type=str,
        default=None,
        help="The path to the 480P checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")

    args = parser.parse_args()
    assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory."

    return args


if __name__ == '__main__':
    args = _parse_args()

    print("Step1: Init prompt_expander...", end='', flush=True)
    if args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=True)
    elif args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=True, device=0)
    else:
        raise NotImplementedError(
            f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)

    demo = gradio_interface()
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
```

## /gradio/t2i_14B_singleGPU.py

```py path="/gradio/t2i_14B_singleGPU.py"
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import sys
import warnings

import gradio as gr

warnings.filterwarnings('ignore')

# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_image

# Global Var
prompt_expander = None
wan_t2i = None


# Button Func
def prompt_enc(prompt, tar_lang):
    global prompt_expander
    prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
    if prompt_output.status == False:
        return prompt
    else:
        return prompt_output.prompt


def t2i_generation(txt2img_prompt, resolution, sd_steps, guide_scale,
                   shift_scale, seed, n_prompt):
    global wan_t2i
    # print(f"{txt2img_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")

    W = int(resolution.split("*")[0])
    H = int(resolution.split("*")[1])
    video = wan_t2i.generate(
        txt2img_prompt,
        size=(W, H),
        frame_num=1,
        shift=shift_scale,
        sampling_steps=sd_steps,
        guide_scale=guide_scale,
        n_prompt=n_prompt,
        seed=seed,
        offload_model=True)

    cache_image(
        tensor=video.squeeze(1)[None],
        save_file="example.png",
        nrow=1,
        normalize=True,
        value_range=(-1, 1))

    return "example.png"


# Interface
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("""
Wan2.1 (T2I-14B)
Wan: Open and Advanced Large-Scale Video Generative Models.
        """)

        with gr.Row():
            with gr.Column():
                txt2img_prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the image you want to generate",
                )
                tar_lang = gr.Radio(
                    choices=["ZH", "EN"],
                    label="Target language of prompt enhance",
                    value="ZH")
                run_p_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    resolution = gr.Dropdown(
                        label='Resolution(Width*Height)',
                        choices=[
                            '720*1280', '1280*720', '960*960', '1088*832',
                            '832*1088', '480*832', '832*480', '624*624',
                            '704*544', '544*704'
                        ],
                        value='720*1280')

                    with gr.Row():
                        sd_steps = gr.Slider(
                            label="Diffusion steps",
                            minimum=1,
                            maximum=1000,
                            value=50,
                            step=1)
                        guide_scale = gr.Slider(
                            label="Guide scale",
                            minimum=0,
                            maximum=20,
                            value=5.0,
                            step=1)
                    with gr.Row():
                        shift_scale = gr.Slider(
                            label="Shift scale",
                            minimum=0,
                            maximum=10,
                            value=5.0,
                            step=1)
                        seed = gr.Slider(
                            label="Seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1)
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe the negative prompt you want to add"
                    )

                run_t2i_button = gr.Button("Generate Image")

            with gr.Column():
                result_gallery = gr.Image(
                    label='Generated Image', interactive=False, height=600)

        run_p_button.click(
            fn=prompt_enc,
            inputs=[txt2img_prompt, tar_lang],
            outputs=[txt2img_prompt])

        run_t2i_button.click(
            fn=t2i_generation,
            inputs=[
                txt2img_prompt, resolution, sd_steps, guide_scale, shift_scale,
                seed, n_prompt
            ],
            outputs=[result_gallery],
        )

    return demo


# Main
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate an image from a text prompt using Gradio")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default="cache",
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")

    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = _parse_args()

    print("Step1: Init prompt_expander...", end='', flush=True)
    if args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=False)
    elif args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=False, device=0)
    else:
        raise NotImplementedError(
            f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)

    print("Step2: Init 14B t2i model...", end='', flush=True)
    cfg = WAN_CONFIGS['t2i-14B']
    wan_t2i = wan.WanT2V(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
    )
    print("done", flush=True)

    demo = gradio_interface()
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
```

## /gradio/t2v_1.3B_singleGPU.py

```py path="/gradio/t2v_1.3B_singleGPU.py"
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import sys
import warnings

import gradio as gr

warnings.filterwarnings('ignore')

# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video

# Global Var
prompt_expander = None
wan_t2v = None


# Button Func
def prompt_enc(prompt, tar_lang):
    global prompt_expander
    prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
    if prompt_output.status == False:
        return prompt
    else:
        return prompt_output.prompt


def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
                   shift_scale, seed, n_prompt):
    global wan_t2v
    # print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")

    W = int(resolution.split("*")[0])
    H = int(resolution.split("*")[1])
    video = wan_t2v.generate(
        txt2vid_prompt,
        size=(W, H),
        shift=shift_scale,
        sampling_steps=sd_steps,
        guide_scale=guide_scale,
        n_prompt=n_prompt,
        seed=seed,
        offload_model=True)

    cache_video(
        tensor=video[None],
        save_file="example.mp4",
        fps=16,
        nrow=1,
        normalize=True,
        value_range=(-1, 1))

    return "example.mp4"


# Interface
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("""
Wan2.1 (T2V-1.3B)
Wan: Open and Advanced Large-Scale Video Generative Models.
        """)

        with gr.Row():
            with gr.Column():
                txt2vid_prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate",
                )
                tar_lang = gr.Radio(
                    choices=["ZH", "EN"],
                    label="Target language of prompt enhance",
                    value="ZH")
                run_p_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    resolution = gr.Dropdown(
                        label='Resolution(Width*Height)',
                        choices=[
                            '480*832',
                            '832*480',
                            '624*624',
                            '704*544',
                            '544*704',
                        ],
                        value='480*832')

                    with gr.Row():
                        sd_steps = gr.Slider(
                            label="Diffusion steps",
                            minimum=1,
                            maximum=1000,
                            value=50,
                            step=1)
                        guide_scale = gr.Slider(
                            label="Guide scale",
                            minimum=0,
                            maximum=20,
                            value=6.0,
                            step=1)
                    with gr.Row():
                        shift_scale = gr.Slider(
                            label="Shift scale",
                            minimum=0,
                            maximum=20,
                            value=8.0,
                            step=1)
                        seed = gr.Slider(
                            label="Seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1)
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe the negative prompt you want to add"
                    )

                run_t2v_button = gr.Button("Generate Video")

            with gr.Column():
                result_gallery = gr.Video(
                    label='Generated Video', interactive=False, height=600)

        run_p_button.click(
            fn=prompt_enc,
            inputs=[txt2vid_prompt, tar_lang],
            outputs=[txt2vid_prompt])

        run_t2v_button.click(
            fn=t2v_generation,
            inputs=[
                txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale,
                seed, n_prompt
            ],
            outputs=[result_gallery],
        )

    return demo


# Main
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt using Gradio")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default="cache",
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")

    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = _parse_args()

    print("Step1: Init prompt_expander...", end='', flush=True)
    if args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=False)
    elif args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=False, device=0)
    else:
        raise NotImplementedError(
            f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)

    print("Step2: Init 1.3B t2v model...", end='', flush=True)
    cfg = WAN_CONFIGS['t2v-1.3B']
    wan_t2v = wan.WanT2V(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
    )
    print("done", flush=True)

    demo = gradio_interface()
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
```

## /gradio/t2v_14B_singleGPU.py

```py path="/gradio/t2v_14B_singleGPU.py"
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import sys
import warnings

import gradio as gr

warnings.filterwarnings('ignore')

# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video

# Global Var
prompt_expander = None
wan_t2v = None


# Button Func
def prompt_enc(prompt, tar_lang):
    global prompt_expander
    prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
    if prompt_output.status == False:
        return prompt
    else:
        return prompt_output.prompt


def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
                   shift_scale, seed, n_prompt):
    global wan_t2v
    # print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")

    W = int(resolution.split("*")[0])
    H = int(resolution.split("*")[1])
    video = wan_t2v.generate(
        txt2vid_prompt,
        size=(W, H),
        shift=shift_scale,
        sampling_steps=sd_steps,
        guide_scale=guide_scale,
        n_prompt=n_prompt,
        seed=seed,
        offload_model=True)

    cache_video(
        tensor=video[None],
        save_file="example.mp4",
        fps=16,
        nrow=1,
        normalize=True,
        value_range=(-1, 1))

    return "example.mp4"


# Interface
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("""
Wan2.1 (T2V-14B)
Wan: Open and Advanced Large-Scale Video Generative Models.
        """)

        with gr.Row():
            with gr.Column():
                txt2vid_prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate",
                )
                tar_lang = gr.Radio(
                    choices=["ZH", "EN"],
                    label="Target language of prompt enhance",
                    value="ZH")
                run_p_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    resolution = gr.Dropdown(
                        label='Resolution(Width*Height)',
                        choices=[
                            '720*1280', '1280*720', '960*960', '1088*832',
                            '832*1088', '480*832', '832*480', '624*624',
                            '704*544', '544*704'
                        ],
                        value='720*1280')

                    with gr.Row():
                        sd_steps = gr.Slider(
                            label="Diffusion steps",
                            minimum=1,
                            maximum=1000,
                            value=50,
                            step=1)
                        guide_scale = gr.Slider(
                            label="Guide scale",
                            minimum=0,
                            maximum=20,
                            value=5.0,
                            step=1)
                    with gr.Row():
                        shift_scale = gr.Slider(
                            label="Shift scale",
                            minimum=0,
                            maximum=10,
                            value=5.0,
                            step=1)
                        seed = gr.Slider(
                            label="Seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1)
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe the negative prompt you want to add"
                    )

                run_t2v_button = gr.Button("Generate Video")

            with gr.Column():
                result_gallery = gr.Video(
                    label='Generated Video', interactive=False, height=600)

        run_p_button.click(
            fn=prompt_enc,
            inputs=[txt2vid_prompt, tar_lang],
            outputs=[txt2vid_prompt])

        run_t2v_button.click(
            fn=t2v_generation,
            inputs=[
                txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale,
                seed, n_prompt
            ],
            outputs=[result_gallery],
        )

    return demo


# Main
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt using Gradio")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default="cache",
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")

    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = _parse_args()

    print("Step1: Init prompt_expander...", end='', flush=True)
    if args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=False)
    elif args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=False, device=0)
    else:
        raise NotImplementedError(
            f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)

    print("Step2: Init 14B t2v model...", end='', flush=True)
    cfg = WAN_CONFIGS['t2v-14B']
    wan_t2v = wan.WanT2V(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
    )
    print("done", flush=True)

    demo = gradio_interface()
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
```

## /pyproject.toml

```toml path="/pyproject.toml"
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "wan"
version = "2.1.0"
description = "Wan: Open and Advanced Large-Scale Video Generative Models"
authors = [
    { name = "Wan Team", email = "wan.ai@alibabacloud.com" }
]
license = { file = "LICENSE.txt" }
readme = "README.md"
requires-python = ">=3.10,<4.0"
dependencies = [
    "torch>=2.4.0",
    "torchvision>=0.19.0",
    "opencv-python>=4.9.0.80",
    "diffusers>=0.31.0",
    "transformers>=4.49.0",
    "tokenizers>=0.20.3",
    "accelerate>=1.1.1",
    "tqdm",
    "imageio",
    "easydict",
    "ftfy",
    "dashscope",
    "imageio-ffmpeg",
    "flash_attn",
    "gradio>=5.0.0",
    "numpy>=1.23.5,<2"
]

[project.optional-dependencies]
dev = [
    "pytest",
    "black",
    "flake8",
    "isort",
    "mypy",
    "huggingface-hub[cli]"
]

[project.urls]
homepage = "https://wanxai.com"
documentation = "https://github.com/Wan-Video/Wan2.1"
repository = "https://github.com/Wan-Video/Wan2.1"
huggingface = "https://huggingface.co/Wan-AI/"
modelscope = "https://modelscope.cn/organization/Wan-AI"
discord = "https://discord.gg/p5XbdQV7"

[tool.setuptools]
packages = ["wan"]

[tool.setuptools.package-data]
"wan" = ["**/*.py"]

[tool.black]
line-length = 88

[tool.isort]
profile = "black"

[tool.mypy]
strict = true
```

## /requirements.txt

torch>=2.4.0
torchvision>=0.19.0
opencv-python>=4.9.0.80
diffusers>=0.31.0
transformers>=4.49.0
tokenizers>=0.20.3
accelerate>=1.1.1
tqdm
imageio
easydict
ftfy
dashscope
imageio-ffmpeg
flash_attn
gradio>=5.0.0
numpy>=1.23.5,<2

## /tests/README.md

Put all your models (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P) in a folder and specify the maximum number of GPUs you want to use.

```bash
bash ./test.sh <local model dir> <gpu number>
```

## /tests/test.sh

```sh path="/tests/test.sh"
#!/bin/bash


if [ "$#" -eq 2 ]; then
  MODEL_DIR=$(realpath "$1")
  GPUS=$2
else
  echo "Usage: $0 <local model dir> <gpu number>"
  exit 1
fi

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
REPO_ROOT="$(dirname "$SCRIPT_DIR")"
cd "$REPO_ROOT" || exit 1

PY_FILE=./generate.py


function t2v_1_3B() {
    T2V_1_3B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-1.3B"

    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B 1-GPU Test: "
    python $PY_FILE --task t2v-1.3B --size 480*832 --ckpt_dir $T2V_1_3B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS

    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"

    if [ -n "${DASH_API_KEY+x}" ]; then
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend dashscope: "
        torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
    else
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
    fi
}

function t2v_14B() {
    T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"

    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B 1-GPU Test: "
    python $PY_FILE --task t2v-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS

    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
}

function t2i_14B() {
    T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"

    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B 1-GPU Test: "
    python $PY_FILE --task t2i-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS

    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
}

function i2v_14B_480p() {
    I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-480P"

    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: "
    python $PY_FILE --task i2v-14B --size 832*480 --ckpt_dir $I2V_14B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS

    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-VL-3B-Instruct" --prompt_extend_target_lang "en"

    if [ -n "${DASH_API_KEY+x}" ]; then
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend dashscope: "
        torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
    else
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
    fi
}

function i2v_14B_720p() {
    I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-720P"

    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: "
    python $PY_FILE --task i2v-14B --size 720*1280 --ckpt_dir $I2V_14B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
}


t2i_14B
t2v_1_3B
t2v_14B
i2v_14B_480p
i2v_14B_720p
```

## /wan/__init__.py

```py path="/wan/__init__.py"
from . import configs, distributed, modules
from .image2video import WanI2V
from .text2video import WanT2V
from .first_last_frame2video import WanFLF2V
```

## /wan/configs/__init__.py

```py path="/wan/configs/__init__.py"
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import copy
import os

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from .wan_i2v_14B import i2v_14B
from .wan_t2v_1_3B import t2v_1_3B
from .wan_t2v_14B import t2v_14B

# the config of t2i_14B is the same as t2v_14B
t2i_14B = copy.deepcopy(t2v_14B)
t2i_14B.__name__ = 'Config: Wan T2I 14B'

# the config of flf2v_14B is the same as i2v_14B
flf2v_14B = copy.deepcopy(i2v_14B)
flf2v_14B.__name__ = 'Config: Wan FLF2V 14B'
# prepend "镜头切换" ("shot cut") to the i2v negative prompt
flf2v_14B.sample_neg_prompt = "镜头切换," + flf2v_14B.sample_neg_prompt

WAN_CONFIGS = {
    't2v-14B': t2v_14B,
    't2v-1.3B': t2v_1_3B,
    'i2v-14B': i2v_14B,
    't2i-14B': t2i_14B,
    'flf2v-14B': flf2v_14B
}

SIZE_CONFIGS = {
    '720*1280': (720, 1280),
    '1280*720': (1280, 720),
    '480*832': (480, 832),
    '832*480': (832, 480),
    '1024*1024': (1024, 1024),
}

MAX_AREA_CONFIGS = {
    '720*1280': 720 * 1280,
    '1280*720': 1280 * 720,
    '480*832': 480 * 832,
    '832*480': 832 * 480,
}

SUPPORTED_SIZES = {
    't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
    't2v-1.3B': ('480*832', '832*480'),
    'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
    'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
    't2i-14B': tuple(SIZE_CONFIGS.keys()),
}
```

## /wan/configs/shared_config.py

```py path="/wan/configs/shared_config.py"
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

#------------------------ Wan shared config ------------------------#
wan_shared_cfg = EasyDict()

# t5
wan_shared_cfg.t5_model = 'umt5_xxl'
wan_shared_cfg.t5_dtype = torch.bfloat16
wan_shared_cfg.text_len = 512

# transformer
wan_shared_cfg.param_dtype = torch.bfloat16

# inference
wan_shared_cfg.num_train_timesteps = 1000
wan_shared_cfg.sample_fps = 16
# Default negative prompt, kept verbatim in Chinese. English gloss: garish
# colors, overexposed, static, blurry details, subtitles, style, artwork,
# painting, picture, still, overall grayish, worst quality, low quality, JPEG
# compression artifacts, ugly, mutilated, extra fingers, poorly drawn hands,
# poorly drawn face, deformed, disfigured, malformed limbs, fused fingers,
# motionless scene, cluttered background, three legs, many people in the
# background, walking backwards.
wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
```

## /wan/configs/wan_i2v_14B.py

```py path="/wan/configs/wan_i2v_14B.py"
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan I2V 14B ------------------------#

i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
i2v_14B.update(wan_shared_cfg)
# prepend "镜头晃动" ("camera shake") to the shared negative prompt
i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt

# t5
i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
i2v_14B.t5_tokenizer = 'google/umt5-xxl'

# clip
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
i2v_14B.clip_dtype = torch.float16
i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
i2v_14B.clip_tokenizer = 'xlm-roberta-large'

# vae
i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
i2v_14B.vae_stride = (4, 8, 8)

# transformer
i2v_14B.patch_size = (1, 2, 2)
i2v_14B.dim = 5120
i2v_14B.ffn_dim = 13824
i2v_14B.freq_dim = 256
i2v_14B.num_heads = 40
i2v_14B.num_layers = 40
i2v_14B.window_size = (-1, -1)
i2v_14B.qk_norm = True
i2v_14B.cross_attn_norm = True
i2v_14B.eps = 1e-6
```

## /wan/configs/wan_t2v_14B.py

```py path="/wan/configs/wan_t2v_14B.py"
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan T2V 14B ------------------------#

t2v_14B = EasyDict(__name__='Config: Wan T2V 14B')
t2v_14B.update(wan_shared_cfg)

# t5
t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
t2v_14B.t5_tokenizer = 'google/umt5-xxl'

# vae
t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
t2v_14B.vae_stride = (4, 8, 8)

# transformer
t2v_14B.patch_size = (1, 2, 2)
t2v_14B.dim = 5120
t2v_14B.ffn_dim = 13824
t2v_14B.freq_dim = 256
t2v_14B.num_heads = 40
t2v_14B.num_layers = 40
t2v_14B.window_size = (-1, -1)
t2v_14B.qk_norm = True
t2v_14B.cross_attn_norm = True
t2v_14B.eps = 1e-6
```

## /wan/configs/wan_t2v_1_3B.py

```py path="/wan/configs/wan_t2v_1_3B.py"
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan T2V 1.3B ------------------------#

t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B')
t2v_1_3B.update(wan_shared_cfg)

# t5
t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'

# vae
t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth'
t2v_1_3B.vae_stride = (4, 8, 8)

# transformer
t2v_1_3B.patch_size = (1, 2, 2)
t2v_1_3B.dim = 1536
t2v_1_3B.ffn_dim = 8960
t2v_1_3B.freq_dim = 256
t2v_1_3B.num_heads = 12
t2v_1_3B.num_layers = 30
t2v_1_3B.window_size = (-1, -1)
t2v_1_3B.qk_norm = True
t2v_1_3B.cross_attn_norm = True
t2v_1_3B.eps = 1e-6
```

## /wan/distributed/__init__.py

```py path="/wan/distributed/__init__.py"
```

## /wan/distributed/fsdp.py

```py path="/wan/distributed/fsdp.py"
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
from functools import partial

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage


def shard_model(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
):
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            buffer_dtype=buffer_dtype),
        device_id=device_id,
        sync_module_states=sync_module_states)
    return model


def free_model(model):
    for m in model.modules():
        if isinstance(m, FSDP):
            _free_storage(m._handle.flat_param.data)
    del model
    gc.collect()
    torch.cuda.empty_cache()
```

## /wan/distributed/xdit_context_parallel.py
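For orientation, here is a minimal pure-Python sketch of the index arithmetic that `rope_apply` and `usp_dit_forward` in this module rely on. It assumes the padded token sequence divides evenly across the sequence-parallel ranks, so each rank's shard holds `s` tokens and reads rows `[rank * s, (rank + 1) * s)` of a rotary-frequency table padded to `s * sp_size` rows; the `c = head_dim // 2` complex rotary channels per head are split into one temporal group and two spatial groups with the same formula `rope_apply` uses. The helpers `rope_split_sizes` and `rank_slice` are hypothetical illustrations, not part of this module or of xFuser.

```python
# Hypothetical helpers that mirror the index arithmetic in rope_apply /
# usp_dit_forward; they are illustrative only and not part of the module.


def rope_split_sizes(head_dim: int) -> tuple[int, int, int]:
    """Split head_dim // 2 complex channels into (time, height, width) groups
    with the formula used by rope_apply: [c - 2 * (c // 3), c // 3, c // 3]."""
    c = head_dim // 2
    return (c - 2 * (c // 3), c // 3, c // 3)


def rank_slice(seq_len: int, sp_size: int, sp_rank: int) -> range:
    """Token positions owned by one sequence-parallel rank.

    Assumes seq_len is divisible by sp_size, so every shard has s tokens and
    the rank's rotary frequencies are rows [rank * s, (rank + 1) * s) of a
    table padded to s * sp_size rows.
    """
    if seq_len % sp_size != 0:
        raise ValueError("seq_len must be divisible by sp_size")
    s = seq_len // sp_size
    return range(sp_rank * s, (sp_rank + 1) * s)


if __name__ == "__main__":
    # 14B configs: dim=5120, num_heads=40 -> head_dim=128 -> 64 complex channels
    print(rope_split_sizes(5120 // 40))  # (22, 21, 21)
    # Example: a 32760-token sequence sharded over 4 ranks
    for rank in range(4):
        r = rank_slice(32760, 4, rank)
        print(rank, r.start, r.stop)
```

Both the 14B configs (dim 5120, 40 heads) and the 1.3B config (dim 1536, 12 heads) give a head dimension of 128, so the temporal group gets 22 complex channels and each spatial group gets 21.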
```py path="/wan/distributed/xdit_context_parallel.py"
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp
from xfuser.core.distributed import (get_sequence_parallel_rank,
                                     get_sequence_parallel_world_size,
                                     get_sp_group)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention

from ..modules.model import sinusoidal_embedding_1d


def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(
        pad_size,
        s1,
        s2,
        dtype=original_tensor.dtype,
        device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor


@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs):
    """
    x:          [B, L, N, C].
    grid_sizes: [B, 3].
    freqs:      [M, C // 2].
    """
    s, n, c = x.size(1), x.size(2), x.size(3) // 2
    # split freqs
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers
        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
            s, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(seq_len, 1, -1)

        # apply rotary embedding
        sp_size = get_sequence_parallel_world_size()
        sp_rank = get_sequence_parallel_rank()
        freqs_i = pad_freqs(freqs_i, s * sp_size)
        s_per_rank = s
        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
                                                       s_per_rank), :, :]
        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
        x_i = torch.cat([x_i, x[i, s:]])

        # append to collection
        output.append(x_i)
    return torch.stack(output).float()


def usp_dit_forward(
    self,
    x,
    t,
    context,
    seq_len,
    clip_fea=None,
    y=None,
):
    """
    x:              A list of videos each with shape [C, T, H, W].
    t:              [B].
    context:        A list of text embeddings each with shape [L, C].
    """
    if self.model_type == 'i2v':
        assert clip_fea is not None and y is not None
    # params
    device = self.patch_embedding.weight.device
    if self.freqs.device != device:
        self.freqs = self.freqs.to(device)

    if y is not None:
        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

    # embeddings
    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
    grid_sizes = torch.stack(
        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
    x = [u.flatten(2).transpose(1, 2) for u in x]
    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
    assert seq_lens.max() <= seq_len
    x = torch.cat([
        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
        for u in x
    ])

    # time embeddings
    with amp.autocast(dtype=torch.float32):
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t).float())
        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
        assert e.dtype == torch.float32 and e0.dtype == torch.float32

    # context
    context_lens = None
    context = self.text_embedding(
        torch.stack([
            torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
            for u in context
        ]))

    if clip_fea is not None:
        context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
        context = torch.concat([context_clip, context], dim=1)

    # arguments
    kwargs = dict(
        e=e0,
        seq_lens=seq_lens,
        grid_sizes=grid_sizes,
        freqs=self.freqs,
        context=context,
        context_lens=context_lens)

    # Context Parallel
    x = torch.chunk(
        x, get_sequence_parallel_world_size(),
        dim=1)[get_sequence_parallel_rank()]

    for block in self.blocks:
        x = block(x, **kwargs)

    # head
    x = self.head(x, e)

    # Context Parallel
    x = get_sp_group().all_gather(x, dim=1)

    # unpatchify
    x = self.unpatchify(x, grid_sizes)
    return [u.float() for u in x]


def usp_attn_forward(self,
                     x,
                     seq_lens,
                     grid_sizes,
                     freqs,
                     dtype=torch.bfloat16):
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
    half_dtypes = (torch.float16, torch.bfloat16)

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # query, key, value function
    def qkv_fn(x):
        q = self.norm_q(self.q(x)).view(b, s, n, d)
        k = self.norm_k(self.k(x)).view(b, s, n, d)
        v = self.v(x).view(b, s, n, d)
        return q, k, v

    q, k, v = qkv_fn(x)
    q = rope_apply(q, grid_sizes, freqs)
    k = rope_apply(k, grid_sizes, freqs)

    # TODO: We should use unpadded q, k, v for attention.
    # k_lens = seq_lens // get_sequence_parallel_world_size()
    # if k_lens is not None:
    #     q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
    #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
    #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)

    x = xFuserLongContextAttention()(
        None,
        query=half(q),
        key=half(k),
        value=half(v),
        window_size=self.window_size)

    # TODO: padding after attention.
    # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)

    # output
    x = x.flatten(2)
    x = self.o(x)
    return x
```

## /wan/first_last_frame2video.py

```py path="/wan/first_last_frame2video.py"
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm

from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                               get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


class WanFLF2V:

    def __init__(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False,
        init_on_cpu=True,
    ):
        r"""
        Initializes the first-last-frame-to-video generation model components.
        Args:
            config (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints
            device_id (`int`, *optional*, defaults to 0):
                Id of target GPU device
            rank (`int`, *optional*, defaults to 0):
                Process rank for distributed training
            t5_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for T5 model
            dit_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for DiT model
            use_usp (`bool`, *optional*, defaults to False):
                Enable distribution strategy of USP.
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place T5 model on CPU. Only works without t5_fsdp.
            init_on_cpu (`bool`, *optional*, defaults to True):
                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config
        self.rank = rank
        self.use_usp = use_usp
        self.t5_cpu = t5_cpu

        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = config.param_dtype

        shard_fn = partial(shard_model, device_id=device_id)
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=shard_fn if t5_fsdp else None,
        )

        self.vae_stride = config.vae_stride
        self.patch_size = config.patch_size
        self.vae = WanVAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=self.device)

        self.clip = CLIPModel(
            dtype=config.clip_dtype,
            device=self.device,
            checkpoint_path=os.path.join(checkpoint_dir,
                                         config.clip_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

        logging.info(f"Creating WanModel from {checkpoint_dir}")
        self.model = WanModel.from_pretrained(checkpoint_dir)
        self.model.eval().requires_grad_(False)

        if t5_fsdp or dit_fsdp or use_usp:
            init_on_cpu = False

        if use_usp:
            from xfuser.core.distributed import \
                get_sequence_parallel_world_size

            from .distributed.xdit_context_parallel import (usp_attn_forward,
                                                            usp_dit_forward)
            for block in self.model.blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward, block.self_attn)
            self.model.forward = types.MethodType(usp_dit_forward, self.model)
            self.sp_size = get_sequence_parallel_world_size()
        else:
            self.sp_size = 1

        if dist.is_initialized():
            dist.barrier()
        if dit_fsdp:
            self.model = shard_fn(self.model)
        else:
            if not init_on_cpu:
                self.model.to(self.device)

        self.sample_neg_prompt = config.sample_neg_prompt

    def generate(self,
                 input_prompt,
                 first_frame,
                 last_frame,
                 max_area=720 * 1280,
                 frame_num=81,
                 shift=16,
                 sample_solver='unipc',
                 sampling_steps=50,
                 guide_scale=5.5,
                 n_prompt="",
                 seed=-1,
                 offload_model=True):
        r"""
        Generates video frames from input first and last frames and a text prompt using a diffusion process.

        Args:
            input_prompt (`str`):
                Text prompt for content generation.
            first_frame (PIL.Image.Image):
                Input first frame. Converted internally to a tensor of shape [3, H, W]
            last_frame (PIL.Image.Image):
                Input last frame. Converted internally to a tensor of shape [3, H, W]
                [NOTE] If the sizes of first_frame and last_frame are mismatched, last_frame will be resized
                and center-cropped to match first_frame.
            max_area (`int`, *optional*, defaults to 720*1280):
                Maximum pixel area for latent space calculation. Controls video resolution scaling
            frame_num (`int`, *optional*, defaults to 81):
                How many frames to sample from a video. The number should be 4n+1
            shift (`float`, *optional*, defaults to 16):
                Noise schedule shift parameter. Affects temporal dynamics
                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
            sample_solver (`str`, *optional*, defaults to 'unipc'):
                Solver used to sample the video.
            sampling_steps (`int`, *optional*, defaults to 50):
                Number of diffusion sampling steps. Higher values improve quality but slow generation
            guide_scale (`float`, *optional*, defaults to 5.5):
                Classifier-free guidance scale. Controls prompt adherence vs.
                creativity
            n_prompt (`str`, *optional*, defaults to ""):
                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
            seed (`int`, *optional*, defaults to -1):
                Random seed for noise generation. If -1, use random seed
            offload_model (`bool`, *optional*, defaults to True):
                If True, offloads models to CPU during generation to save VRAM

        Returns:
            torch.Tensor:
                Generated video frames tensor. Dimensions: (C, N, H, W) where:
                - C: Color channels (3 for RGB)
                - N: Number of frames (frame_num)
                - H: Frame height (from max_area)
                - W: Frame width (from max_area)
        """
        first_frame_size = first_frame.size
        last_frame_size = last_frame.size
        first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
            self.device)
        last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(
            self.device)

        F = frame_num
        first_frame_h, first_frame_w = first_frame.shape[1:]
        aspect_ratio = first_frame_h / first_frame_w
        lat_h = round(
            np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
            self.patch_size[1] * self.patch_size[1])
        lat_w = round(
            np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
            self.patch_size[2] * self.patch_size[2])
        first_frame_h = lat_h * self.vae_stride[1]
        first_frame_w = lat_w * self.vae_stride[2]
        if first_frame_size != last_frame_size:
            # 1. resize so that the last frame covers the first frame's area
            last_frame_resize_ratio = max(
                first_frame_size[0] / last_frame_size[0],
                first_frame_size[1] / last_frame_size[1])
            last_frame_size = [
                round(last_frame_size[0] * last_frame_resize_ratio),
                round(last_frame_size[1] * last_frame_resize_ratio),
            ]
            # PIL sizes are (W, H); torchvision expects [H, W]
            last_frame = TF.resize(last_frame,
                                   [last_frame_size[1], last_frame_size[0]])
            # 2. center crop to the first frame's size
            last_frame = TF.center_crop(
                last_frame, [first_frame_size[1], first_frame_size[0]])

        max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
            self.patch_size[1] * self.patch_size[2])
        max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size

        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
        seed_g = torch.Generator(device=self.device)
        seed_g.manual_seed(seed)
        noise = torch.randn(
            16, (F - 1) // 4 + 1,
            lat_h,
            lat_w,
            dtype=torch.float32,
            generator=seed_g,
            device=self.device)

        # conditioning mask: 1 for the given first and last frames, 0 for frames to generate
        msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
        msk[:, 1:-1] = 0
        msk = torch.concat([
            torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
        ], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]

        if n_prompt == "":
            n_prompt = self.sample_neg_prompt

        # preprocess
        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context = self.text_encoder([input_prompt], self.device)
            context_null = self.text_encoder([n_prompt], self.device)
            if offload_model:
                self.text_encoder.model.cpu()
        else:
            context = self.text_encoder([input_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]

        self.clip.model.to(self.device)
        clip_context = self.clip.visual(
            [first_frame[:, None, :, :], last_frame[:, None, :, :]])
        if offload_model:
            self.clip.model.cpu()

        y = self.vae.encode([
            torch.concat([
                torch.nn.functional.interpolate(
                    first_frame[None].cpu(),
                    size=(first_frame_h, first_frame_w),
                    mode='bicubic').transpose(0, 1),
                torch.zeros(3, F - 2, first_frame_h, first_frame_w),
                torch.nn.functional.interpolate(
                    last_frame[None].cpu(),
                    size=(first_frame_h, first_frame_w),
                    mode='bicubic').transpose(0, 1),
            ], dim=1).to(self.device)
        ])[0]
        y = torch.concat([msk, y])

        @contextmanager
        def noop_no_sync():
            yield

        no_sync = getattr(self.model, 'no_sync', noop_no_sync)

        # evaluation mode
        with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():

            if sample_solver == 'unipc':
                sample_scheduler = FlowUniPCMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sample_scheduler.set_timesteps(
                    sampling_steps, device=self.device, shift=shift)
                timesteps = sample_scheduler.timesteps
            elif sample_solver == 'dpm++':
                sample_scheduler = FlowDPMSolverMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
                timesteps, _ = retrieve_timesteps(
                    sample_scheduler,
                    device=self.device,
                    sigmas=sampling_sigmas)
            else:
                raise NotImplementedError("Unsupported solver.")

            # sample videos
            latent = noise

            arg_c = {
                'context': [context[0]],
                'clip_fea': clip_context,
                'seq_len': max_seq_len,
                'y': [y],
            }

            arg_null = {
                'context': context_null,
                'clip_fea': clip_context,
                'seq_len': max_seq_len,
                'y': [y],
            }

            if offload_model:
                torch.cuda.empty_cache()

            self.model.to(self.device)
            for _, t in enumerate(tqdm(timesteps)):
                latent_model_input = [latent.to(self.device)]
                timestep = [t]

                timestep = torch.stack(timestep).to(self.device)

                noise_pred_cond = self.model(
                    latent_model_input, t=timestep, **arg_c)[0].to(
                        torch.device('cpu') if offload_model else self.device)
                if offload_model:
                    torch.cuda.empty_cache()
                noise_pred_uncond = self.model(
                    latent_model_input, t=timestep, **arg_null)[0].to(
                        torch.device('cpu') if offload_model else self.device)
                if offload_model:
                    torch.cuda.empty_cache()
                noise_pred = noise_pred_uncond + guide_scale * (
                    noise_pred_cond - noise_pred_uncond)

                latent = latent.to(
                    torch.device('cpu') if offload_model else self.device)

                temp_x0 = sample_scheduler.step(
                    noise_pred.unsqueeze(0),
                    t,
                    latent.unsqueeze(0),
                    return_dict=False,
                    generator=seed_g)[0]
                latent = temp_x0.squeeze(0)

                x0 = [latent.to(self.device)]
                del latent_model_input, timestep

            if offload_model:
                self.model.cpu()
                torch.cuda.empty_cache()

            if self.rank == 0:
                videos = self.vae.decode(x0)

        del noise, latent
        del sample_scheduler
        if offload_model:
            gc.collect()
            torch.cuda.synchronize()
        if dist.is_initialized():
            dist.barrier()

        return videos[0] if self.rank == 0 else None
```

## /wan/image2video.py

```py path="/wan/image2video.py"
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm

from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                               get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


class WanI2V:

    def __init__(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False,
        init_on_cpu=True,
    ):
        r"""
        Initializes the image-to-video generation model components.

        Args:
            config (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints
            device_id (`int`, *optional*, defaults to 0):
                Id of target GPU device
            rank (`int`, *optional*, defaults to 0):
                Process rank for distributed training
            t5_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for T5 model
            dit_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for DiT model
            use_usp (`bool`, *optional*, defaults to False):
                Enable distribution strategy of USP.
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place T5 model on CPU. Only works without t5_fsdp.
            init_on_cpu (`bool`, *optional*, defaults to True):
                Enable initializing Transformer Model on CPU.
                Only works without FSDP or USP.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config
        self.rank = rank
        self.use_usp = use_usp
        self.t5_cpu = t5_cpu

        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = config.param_dtype

        shard_fn = partial(shard_model, device_id=device_id)
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=shard_fn if t5_fsdp else None,
        )

        self.vae_stride = config.vae_stride
        self.patch_size = config.patch_size
        self.vae = WanVAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=self.device)

        self.clip = CLIPModel(
            dtype=config.clip_dtype,
            device=self.device,
            checkpoint_path=os.path.join(checkpoint_dir,
                                         config.clip_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

        logging.info(f"Creating WanModel from {checkpoint_dir}")
        self.model = WanModel.from_pretrained(checkpoint_dir)
        self.model.eval().requires_grad_(False)

        if t5_fsdp or dit_fsdp or use_usp:
            init_on_cpu = False

        if use_usp:
            from xfuser.core.distributed import \
                get_sequence_parallel_world_size

            from .distributed.xdit_context_parallel import (usp_attn_forward,
                                                            usp_dit_forward)
            for block in self.model.blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward, block.self_attn)
            self.model.forward = types.MethodType(usp_dit_forward, self.model)
            self.sp_size = get_sequence_parallel_world_size()
        else:
            self.sp_size = 1

        if dist.is_initialized():
            dist.barrier()
        if dit_fsdp:
            self.model = shard_fn(self.model)
        else:
            if not init_on_cpu:
                self.model.to(self.device)

        self.sample_neg_prompt = config.sample_neg_prompt

    def generate(self,
                 input_prompt,
                 img,
                 max_area=720 * 1280,
                 frame_num=81,
                 shift=5.0,
                 sample_solver='unipc',
                 sampling_steps=40,
                 guide_scale=5.0,
                 n_prompt="",
                 seed=-1,
                 offload_model=True):
        r"""
        Generates video
        frames from an input image and a text prompt using a diffusion process.

        Args:
            input_prompt (`str`):
                Text prompt for content generation.
            img (PIL.Image.Image):
                Input image. Converted internally to a tensor of shape [3, H, W]
            max_area (`int`, *optional*, defaults to 720*1280):
                Maximum pixel area for latent space calculation. Controls video resolution scaling
            frame_num (`int`, *optional*, defaults to 81):
                How many frames to sample from a video. The number should be 4n+1
            shift (`float`, *optional*, defaults to 5.0):
                Noise schedule shift parameter. Affects temporal dynamics
                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
            sample_solver (`str`, *optional*, defaults to 'unipc'):
                Solver used to sample the video.
            sampling_steps (`int`, *optional*, defaults to 40):
                Number of diffusion sampling steps. Higher values improve quality but slow generation
            guide_scale (`float`, *optional*, defaults to 5.0):
                Classifier-free guidance scale. Controls prompt adherence vs. creativity
            n_prompt (`str`, *optional*, defaults to ""):
                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
            seed (`int`, *optional*, defaults to -1):
                Random seed for noise generation. If -1, use random seed
            offload_model (`bool`, *optional*, defaults to True):
                If True, offloads models to CPU during generation to save VRAM

        Returns:
            torch.Tensor:
                Generated video frames tensor.
                Dimensions: (C, N, H, W) where:
                - C: Color channels (3 for RGB)
                - N: Number of frames (frame_num)
                - H: Frame height (from max_area)
                - W: Frame width (from max_area)
        """
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)

        F = frame_num
        h, w = img.shape[1:]
        aspect_ratio = h / w
        lat_h = round(
            np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
            self.patch_size[1] * self.patch_size[1])
        lat_w = round(
            np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
            self.patch_size[2] * self.patch_size[2])
        h = lat_h * self.vae_stride[1]
        w = lat_w * self.vae_stride[2]

        max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
            self.patch_size[1] * self.patch_size[2])
        max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size

        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
        seed_g = torch.Generator(device=self.device)
        seed_g.manual_seed(seed)
        noise = torch.randn(
            16, (F - 1) // 4 + 1,
            lat_h,
            lat_w,
            dtype=torch.float32,
            generator=seed_g,
            device=self.device)

        # conditioning mask: 1 for the given first frame, 0 for frames to generate
        msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
        msk[:, 1:] = 0
        msk = torch.concat([
            torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
        ], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]

        if n_prompt == "":
            n_prompt = self.sample_neg_prompt

        # preprocess
        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context = self.text_encoder([input_prompt], self.device)
            context_null = self.text_encoder([n_prompt], self.device)
            if offload_model:
                self.text_encoder.model.cpu()
        else:
            context = self.text_encoder([input_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]

        self.clip.model.to(self.device)
        clip_context = self.clip.visual([img[:, None, :, :]])
        if offload_model:
            self.clip.model.cpu()

        y = self.vae.encode([
            torch.concat([
                torch.nn.functional.interpolate(
                    img[None].cpu(),
                    size=(h, w),
                    mode='bicubic').transpose(0, 1),
                torch.zeros(3, F - 1, h, w)
            ], dim=1).to(self.device)
        ])[0]
        y = torch.concat([msk, y])

        @contextmanager
        def noop_no_sync():
            yield

        no_sync = getattr(self.model, 'no_sync', noop_no_sync)

        # evaluation mode
        with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():

            if sample_solver == 'unipc':
                sample_scheduler = FlowUniPCMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sample_scheduler.set_timesteps(
                    sampling_steps, device=self.device, shift=shift)
                timesteps = sample_scheduler.timesteps
            elif sample_solver == 'dpm++':
                sample_scheduler = FlowDPMSolverMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
                timesteps, _ = retrieve_timesteps(
                    sample_scheduler,
                    device=self.device,
                    sigmas=sampling_sigmas)
            else:
                raise NotImplementedError("Unsupported solver.")

            # sample videos
            latent = noise

            arg_c = {
                'context': [context[0]],
                'clip_fea': clip_context,
                'seq_len': max_seq_len,
                'y': [y],
            }

            arg_null = {
                'context': context_null,
                'clip_fea': clip_context,
                'seq_len': max_seq_len,
                'y': [y],
            }

            if offload_model:
                torch.cuda.empty_cache()

            self.model.to(self.device)
            for _, t in enumerate(tqdm(timesteps)):
                latent_model_input = [latent.to(self.device)]
                timestep = [t]

                timestep = torch.stack(timestep).to(self.device)

                noise_pred_cond = self.model(
                    latent_model_input, t=timestep, **arg_c)[0].to(
                        torch.device('cpu') if offload_model else self.device)
                if offload_model:
                    torch.cuda.empty_cache()
                noise_pred_uncond = self.model(
                    latent_model_input, t=timestep, **arg_null)[0].to(
                        torch.device('cpu') if offload_model else self.device)
                if offload_model:
                    torch.cuda.empty_cache()
                noise_pred = noise_pred_uncond + guide_scale * (
                    noise_pred_cond - noise_pred_uncond)

                latent = latent.to(
                    torch.device('cpu') if offload_model else self.device)

                temp_x0 = sample_scheduler.step(
                    noise_pred.unsqueeze(0),
                    t,
                    latent.unsqueeze(0),
                    return_dict=False,
                    generator=seed_g)[0]
                latent = temp_x0.squeeze(0)

                x0 = [latent.to(self.device)]
                del latent_model_input, timestep

            if offload_model:
                self.model.cpu()
                torch.cuda.empty_cache()

            if self.rank == 0:
                videos = self.vae.decode(x0)

        del noise, latent
        del sample_scheduler
        if offload_model:
            gc.collect()
            torch.cuda.synchronize()
        if dist.is_initialized():
            dist.barrier()

        return videos[0] if self.rank == 0 else None
```

## /wan/modules/__init__.py

```py path="/wan/modules/__init__.py"
from .attention import flash_attention
from .model import WanModel
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vae import WanVAE

__all__ = [
    'WanVAE',
    'WanModel',
    'T5Model',
    'T5Encoder',
    'T5Decoder',
    'T5EncoderModel',
    'HuggingfaceTokenizer',
    'flash_attention',
]
```

## /wan/modules/attention.py

```py path="/wan/modules/attention.py"
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

import warnings

__all__ = [
    'flash_attention',
    'attention',
]


def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """
    q:              [B, Lq, Nq, C1].
    k:              [B, Lk, Nk, C1].
    v:              [B, Lk, Nk, C2]. Nq must be divisible by Nk.
    q_lens:         [B].
    k_lens:         [B].
    dropout_p:      float. Dropout probability.
    softmax_scale:  float. The scaling of QK^T before applying softmax.
    q_scale:        float. Optional multiplier applied to q before attention.
    causal:         bool. Whether to apply causal attention mask.
    window_size:    (left, right). If not (-1, -1), apply sliding window local attention.
    deterministic:  bool. If True, slightly slower and uses more memory.
    dtype:          torch.dtype.
                    Apply when dtype of q/k/v is not float16/bfloat16.
    version:        int or None. Flash-attention version to use; None prefers FA3 when available, else FA2.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert dtype in half_dtypes
    assert q.device.type == 'cuda' and q.size(-1) <= 256

    # params
    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # preprocess query
    if q_lens is None:
        q = half(q.flatten(0, 1))
        q_lens = torch.tensor(
            [lq] * b, dtype=torch.int32).to(
                device=q.device, non_blocking=True)
    else:
        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))

    # preprocess key, value
    if k_lens is None:
        k = half(k.flatten(0, 1))
        v = half(v.flatten(0, 1))
        k_lens = torch.tensor(
            [lk] * b, dtype=torch.int32).to(
                device=k.device, non_blocking=True)
    else:
        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))

    q = q.to(v.dtype)
    k = k.to(v.dtype)

    if q_scale is not None:
        q = q * q_scale

    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
        warnings.warn(
            'Flash attention 3 is not available, use flash attention 2 instead.'
        )

    # apply attention
    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
        # Note: dropout_p, window_size are not supported in FA3 now.
        x = flash_attn_interface.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            seqused_q=None,
            seqused_k=None,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            softmax_scale=softmax_scale,
            causal=causal,
            deterministic=deterministic)[0].unflatten(0, (b, lq))
    else:
        assert FLASH_ATTN_2_AVAILABLE
        x = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic).unflatten(0, (b, lq))

    # output
    return x.type(out_dtype)


def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
):
    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )
    else:
        if q_lens is not None or k_lens is not None:
            warnings.warn(
                'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
            )
        attn_mask = None

        q = q.transpose(1, 2).to(dtype)
        k = k.transpose(1, 2).to(dtype)
        v = v.transpose(1, 2).to(dtype)

        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)

        out = out.transpose(1, 2).contiguous()
        return out
```

## /wan/modules/clip.py

```py path="/wan/modules/clip.py"
# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

from .attention import flash_attention
from .tokenizers import HuggingfaceTokenizer
from .xlm_roberta import XLMRoberta

__all__ = [
    'XLMRobertaCLIP',
    'clip_xlm_roberta_vit_h_14',
    'CLIPModel',
]


def pos_interpolate(pos, seq_len):
    if pos.size(1) == seq_len:
        return pos
    else:
        src_grid = int(math.sqrt(pos.size(1)))
        tar_grid = int(math.sqrt(seq_len))
        n = pos.size(1) - src_grid * src_grid
        return torch.cat([
            pos[:, :n],
            F.interpolate(
                pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
                    0, 3, 1, 2),
                size=(tar_grid, tar_grid),
                mode='bicubic',
                align_corners=False).flatten(2).transpose(1, 2)
        ], dim=1)


class QuickGELU(nn.Module):

    def forward(self, x):
        return x * torch.sigmoid(1.702 * x)


class LayerNorm(nn.LayerNorm):

    def forward(self, x):
        return super().forward(x.float()).type_as(x)


class SelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 causal=False,
                 attn_dropout=0.0,
                 proj_dropout=0.0):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.causal = causal
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout

        # layers
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """
        x:   [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)

        # compute attention
        p = self.attn_dropout if self.training else 0.0
        x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
        x = x.reshape(b, s, c)

        # output
        x = self.proj(x)
        x = F.dropout(x, self.proj_dropout, self.training)
        return x


class SwiGLU(nn.Module):

    def __init__(self, dim, mid_dim):
        super().__init__()
        self.dim = dim
        self.mid_dim = mid_dim

        # layers
        self.fc1 = nn.Linear(dim, mid_dim)
        self.fc2 = nn.Linear(dim, mid_dim)
        self.fc3 = nn.Linear(mid_dim, dim)

    def forward(self, x):
        x = F.silu(self.fc1(x)) * self.fc2(x)
        x = self.fc3(x)
        return x


class AttentionBlock(nn.Module):

    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 post_norm=False,
                 causal=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert activation in ['quick_gelu', 'gelu', 'swi_glu']
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.causal = causal
        self.norm_eps = norm_eps

        # layers
        self.norm1 = LayerNorm(dim, eps=norm_eps)
        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
                                  proj_dropout)
        self.norm2 = LayerNorm(dim, eps=norm_eps)
        if activation == 'swi_glu':
            self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
        else:
            self.mlp = nn.Sequential(
                nn.Linear(dim, int(dim * mlp_ratio)),
                QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
                nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))

    def forward(self, x):
        if self.post_norm:
            x = x + self.norm1(self.attn(x))
            x = x + self.norm2(self.mlp(x))
        else:
            x = x + self.attn(self.norm1(x))
            x = x + self.mlp(self.norm2(x))
        return x


class AttentionPool(nn.Module):

    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 activation='gelu',
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.proj_dropout = proj_dropout
        self.norm_eps = norm_eps

        # layers
        gain = 1.0 / math.sqrt(dim)
        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.to_q = nn.Linear(dim, dim)
        self.to_kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.norm = LayerNorm(dim, eps=norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))

    def forward(self, x):
        """
        x:  [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
        k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)

        # compute attention
        x = flash_attention(q, k, v, version=2)
        x = x.reshape(b, 1, c)

        # output
        x = self.proj(x)
        x = F.dropout(x, self.proj_dropout, self.training)

        # mlp
        x = x + self.mlp(self.norm(x))
        return x[:, 0]


class VisionTransformer(nn.Module):

    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 dim=768,
                 mlp_ratio=4,
                 out_dim=512,
                 num_heads=12,
                 num_layers=12,
                 pool_type='token',
                 pre_norm=True,
                 post_norm=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        if image_size % patch_size != 0:
            print(
                '[WARNING] image_size is not divisible by patch_size',
                flush=True)
        assert pool_type in ('token', 'token_fc', 'attn_pool')
        out_dim = out_dim or dim
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size)**2
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pool_type = pool_type
        self.post_norm = post_norm
        self.norm_eps = norm_eps

        # embeddings
        gain = 1.0 / math.sqrt(dim)
        self.patch_embedding = nn.Conv2d(
            3,
            dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=not pre_norm)
        if pool_type in ('token', 'token_fc'):
            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(gain * torch.randn(
            1, self.num_patches +
            (1 if pool_type in ('token', 'token_fc') else 0), dim))
        self.dropout = nn.Dropout(embedding_dropout)

        # transformer
        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
        self.transformer = nn.Sequential(*[
            AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
                           activation, attn_dropout, proj_dropout, norm_eps)
            for _ in range(num_layers)
        ])
        self.post_norm = LayerNorm(dim, eps=norm_eps)

        # head
        if pool_type == 'token':
            self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
        elif pool_type == 'token_fc':
            self.head = nn.Linear(dim, out_dim)
        elif pool_type == 'attn_pool':
            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
                                      proj_dropout, norm_eps)

    def forward(self, x, interpolation=False, use_31_block=False):
        b = x.size(0)

        # embeddings
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
        if self.pool_type in ('token', 'token_fc'):
            x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
        if interpolation:
            e = pos_interpolate(self.pos_embedding, x.size(1))
        else:
            e = self.pos_embedding
        x = self.dropout(x + e)
        if self.pre_norm is not None:
            x = self.pre_norm(x)

        # transformer
        if use_31_block:
            x = self.transformer[:-1](x)
            return x
        else:
            x = self.transformer(x)
            return x


class XLMRobertaWithHead(XLMRoberta):

    def __init__(self, **kwargs):
        self.out_dim = kwargs.pop('out_dim')
        super().__init__(**kwargs)

        # head
        mid_dim = (self.dim + self.out_dim) // 2
        self.head = nn.Sequential(
            nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
            nn.Linear(mid_dim, self.out_dim, bias=False))

    def forward(self, ids):
        # xlm-roberta
        x = super().forward(ids)

        # average pooling
        mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
        x = (x * mask).sum(dim=1) / mask.sum(dim=1)

        # head
        x = self.head(x)
        return x


class XLMRobertaCLIP(nn.Module):

    def __init__(self,
                 embed_dim=1024,
                 image_size=224,
                 patch_size=14,
                 vision_dim=1280,
                 vision_mlp_ratio=4,
                 vision_heads=16,
vision_layers=32, vision_pool='token', vision_pre_norm=True, vision_post_norm=False, activation='gelu', vocab_size=250002, max_text_len=514, type_size=1, pad_id=1, text_dim=1024, text_heads=16, text_layers=24, text_post_norm=True, text_dropout=0.1, attn_dropout=0.0, proj_dropout=0.0, embedding_dropout=0.0, norm_eps=1e-5): super().__init__() self.embed_dim = embed_dim self.image_size = image_size self.patch_size = patch_size self.vision_dim = vision_dim self.vision_mlp_ratio = vision_mlp_ratio self.vision_heads = vision_heads self.vision_layers = vision_layers self.vision_pre_norm = vision_pre_norm self.vision_post_norm = vision_post_norm self.activation = activation self.vocab_size = vocab_size self.max_text_len = max_text_len self.type_size = type_size self.pad_id = pad_id self.text_dim = text_dim self.text_heads = text_heads self.text_layers = text_layers self.text_post_norm = text_post_norm self.norm_eps = norm_eps # models self.visual = VisionTransformer( image_size=image_size, patch_size=patch_size, dim=vision_dim, mlp_ratio=vision_mlp_ratio, out_dim=embed_dim, num_heads=vision_heads, num_layers=vision_layers, pool_type=vision_pool, pre_norm=vision_pre_norm, post_norm=vision_post_norm, activation=activation, attn_dropout=attn_dropout, proj_dropout=proj_dropout, embedding_dropout=embedding_dropout, norm_eps=norm_eps) self.textual = XLMRobertaWithHead( vocab_size=vocab_size, max_seq_len=max_text_len, type_size=type_size, pad_id=pad_id, dim=text_dim, out_dim=embed_dim, num_heads=text_heads, num_layers=text_layers, post_norm=text_post_norm, dropout=text_dropout) self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) def forward(self, imgs, txt_ids): """ imgs: [B, 3, H, W] of torch.float32. - mean: [0.48145466, 0.4578275, 0.40821073] - std: [0.26862954, 0.26130258, 0.27577711] txt_ids: [B, L] of torch.long. Encoded by data.CLIPTokenizer. 
""" xi = self.visual(imgs) xt = self.textual(txt_ids) return xi, xt def param_groups(self): groups = [{ 'params': [ p for n, p in self.named_parameters() if 'norm' in n or n.endswith('bias') ], 'weight_decay': 0.0 }, { 'params': [ p for n, p in self.named_parameters() if not ('norm' in n or n.endswith('bias')) ] }] return groups def _clip(pretrained=False, pretrained_name=None, model_cls=XLMRobertaCLIP, return_transforms=False, return_tokenizer=False, tokenizer_padding='eos', dtype=torch.float32, device='cpu', **kwargs): # init a model on device with torch.device(device): model = model_cls(**kwargs) # set device model = model.to(dtype=dtype, device=device) output = (model,) # init transforms if return_transforms: # mean and std if 'siglip' in pretrained_name.lower(): mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] else: mean = [0.48145466, 0.4578275, 0.40821073] std = [0.26862954, 0.26130258, 0.27577711] # transforms transforms = T.Compose([ T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=mean, std=std) ]) output += (transforms,) return output[0] if len(output) == 1 else output def clip_xlm_roberta_vit_h_14( pretrained=False, pretrained_name='open-clip-xlm-roberta-large-vit-huge-14', **kwargs): cfg = dict( embed_dim=1024, image_size=224, patch_size=14, vision_dim=1280, vision_mlp_ratio=4, vision_heads=16, vision_layers=32, vision_pool='token', activation='gelu', vocab_size=250002, max_text_len=514, type_size=1, pad_id=1, text_dim=1024, text_heads=16, text_layers=24, text_post_norm=True, text_dropout=0.1, attn_dropout=0.0, proj_dropout=0.0, embedding_dropout=0.0) cfg.update(**kwargs) return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg) class CLIPModel: def __init__(self, dtype, device, checkpoint_path, tokenizer_path): self.dtype = dtype self.device = device self.checkpoint_path = checkpoint_path self.tokenizer_path = tokenizer_path # init model self.model, self.transforms = 
clip_xlm_roberta_vit_h_14( pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device) self.model = self.model.eval().requires_grad_(False) logging.info(f'loading {checkpoint_path}') self.model.load_state_dict( torch.load(checkpoint_path, map_location='cpu')) # init tokenizer self.tokenizer = HuggingfaceTokenizer( name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean='whitespace') def visual(self, videos): # preprocess size = (self.model.image_size,) * 2 videos = torch.cat([ F.interpolate( u.transpose(0, 1), size=size, mode='bicubic', align_corners=False) for u in videos ]) videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) # forward with torch.cuda.amp.autocast(dtype=self.dtype): out = self.model.visual(videos, use_31_block=True) return out ``` ## /wan/modules/model.py ```py path="/wan/modules/model.py" # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import math import torch import torch.cuda.amp as amp import torch.nn as nn from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin from .attention import flash_attention __all__ = ['WanModel'] T5_CONTEXT_TOKEN_NUMBER = 512 FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER = 257 * 2 def sinusoidal_embedding_1d(dim, position): # preprocess assert dim % 2 == 0 half = dim // 2 position = position.type(torch.float64) # calculation sinusoid = torch.outer( position, torch.pow(10000, -torch.arange(half).to(position).div(half))) x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) return x @amp.autocast(enabled=False) def rope_params(max_seq_len, dim, theta=10000): assert dim % 2 == 0 freqs = torch.outer( torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))) freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs @amp.autocast(enabled=False) def rope_apply(x, grid_sizes, freqs): n, c = x.size(2), x.size(3) // 2 # 
split freqs freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) # loop over samples output = [] for i, (f, h, w) in enumerate(grid_sizes.tolist()): seq_len = f * h * w # precompute multipliers x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape( seq_len, n, -1, 2)) freqs_i = torch.cat([ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) ], dim=-1).reshape(seq_len, 1, -1) # apply rotary embedding x_i = torch.view_as_real(x_i * freqs_i).flatten(2) x_i = torch.cat([x_i, x[i, seq_len:]]) # append to collection output.append(x_i) return torch.stack(output).float() class WanRMSNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.dim = dim self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x): r""" Args: x(Tensor): Shape [B, L, C] """ return self._norm(x.float()).type_as(x) * self.weight def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) class WanLayerNorm(nn.LayerNorm): def __init__(self, dim, eps=1e-6, elementwise_affine=False): super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) def forward(self, x): r""" Args: x(Tensor): Shape [B, L, C] """ return super().forward(x.float()).type_as(x) class WanSelfAttention(nn.Module): def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): assert dim % num_heads == 0 super().__init__() self.dim = dim self.num_heads = num_heads self.head_dim = dim // num_heads self.window_size = window_size self.qk_norm = qk_norm self.eps = eps # layers self.q = nn.Linear(dim, dim) self.k = nn.Linear(dim, dim) self.v = nn.Linear(dim, dim) self.o = nn.Linear(dim, dim) self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() def forward(self, x, seq_lens, grid_sizes, freqs): r""" Args: x(Tensor): Shape [B, L, 
C] seq_lens(Tensor): Shape [B] grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim # query, key, value function def qkv_fn(x): q = self.norm_q(self.q(x)).view(b, s, n, d) k = self.norm_k(self.k(x)).view(b, s, n, d) v = self.v(x).view(b, s, n, d) return q, k, v q, k, v = qkv_fn(x) x = flash_attention( q=rope_apply(q, grid_sizes, freqs), k=rope_apply(k, grid_sizes, freqs), v=v, k_lens=seq_lens, window_size=self.window_size) # output x = x.flatten(2) x = self.o(x) return x class WanT2VCrossAttention(WanSelfAttention): def forward(self, x, context, context_lens): r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] context_lens(Tensor): Shape [B] """ b, n, d = x.size(0), self.num_heads, self.head_dim # compute query, key, value q = self.norm_q(self.q(x)).view(b, -1, n, d) k = self.norm_k(self.k(context)).view(b, -1, n, d) v = self.v(context).view(b, -1, n, d) # compute attention x = flash_attention(q, k, v, k_lens=context_lens) # output x = x.flatten(2) x = self.o(x) return x class WanI2VCrossAttention(WanSelfAttention): def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): super().__init__(dim, num_heads, window_size, qk_norm, eps) self.k_img = nn.Linear(dim, dim) self.v_img = nn.Linear(dim, dim) # self.alpha = nn.Parameter(torch.zeros((1, ))) self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() def forward(self, x, context, context_lens): r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] context_lens(Tensor): Shape [B] """ image_context_length = context.shape[1] - T5_CONTEXT_TOKEN_NUMBER context_img = context[:, :image_context_length] context = context[:, image_context_length:] b, n, d = x.size(0), self.num_heads, self.head_dim # compute query, key, value q = self.norm_q(self.q(x)).view(b, -1, n, d) k = 
self.norm_k(self.k(context)).view(b, -1, n, d) v = self.v(context).view(b, -1, n, d) k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d) v_img = self.v_img(context_img).view(b, -1, n, d) img_x = flash_attention(q, k_img, v_img, k_lens=None) # compute attention x = flash_attention(q, k, v, k_lens=context_lens) # output x = x.flatten(2) img_x = img_x.flatten(2) x = x + img_x x = self.o(x) return x WAN_CROSSATTENTION_CLASSES = { 't2v_cross_attn': WanT2VCrossAttention, 'i2v_cross_attn': WanI2VCrossAttention, } class WanAttentionBlock(nn.Module): def __init__(self, cross_attn_type, dim, ffn_dim, num_heads, window_size=(-1, -1), qk_norm=True, cross_attn_norm=False, eps=1e-6): super().__init__() self.dim = dim self.ffn_dim = ffn_dim self.num_heads = num_heads self.window_size = window_size self.qk_norm = qk_norm self.cross_attn_norm = cross_attn_norm self.eps = eps # layers self.norm1 = WanLayerNorm(dim, eps) self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps) self.norm3 = WanLayerNorm( dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, num_heads, (-1, -1), qk_norm, eps) self.norm2 = WanLayerNorm(dim, eps) self.ffn = nn.Sequential( nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), nn.Linear(ffn_dim, dim)) # modulation self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5) def forward( self, x, e, seq_lens, grid_sizes, freqs, context, context_lens, ): r""" Args: x(Tensor): Shape [B, L, C] e(Tensor): Shape [B, 6, C] seq_lens(Tensor): Shape [B], length of each sequence in batch grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ assert e.dtype == torch.float32 with amp.autocast(dtype=torch.float32): e = (self.modulation + e).chunk(6, dim=1) assert e[0].dtype == torch.float32 # self-attention y = self.self_attn( self.norm1(x).float() * (1 + e[1]) + 
e[0], seq_lens, grid_sizes, freqs) with amp.autocast(dtype=torch.float32): x = x + y * e[2] # cross-attention & ffn function def cross_attn_ffn(x, context, context_lens, e): x = x + self.cross_attn(self.norm3(x), context, context_lens) y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) with amp.autocast(dtype=torch.float32): x = x + y * e[5] return x x = cross_attn_ffn(x, context, context_lens, e) return x class Head(nn.Module): def __init__(self, dim, out_dim, patch_size, eps=1e-6): super().__init__() self.dim = dim self.out_dim = out_dim self.patch_size = patch_size self.eps = eps # layers out_dim = math.prod(patch_size) * out_dim self.norm = WanLayerNorm(dim, eps) self.head = nn.Linear(dim, out_dim) # modulation self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim ** 0.5) def forward(self, x, e): r""" Args: x(Tensor): Shape [B, L1, C] e(Tensor): Shape [B, C] """ assert e.dtype == torch.float32 with amp.autocast(dtype=torch.float32): e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) return x class MLPProj(torch.nn.Module): def __init__(self, in_dim, out_dim, flf_pos_emb=False): super().__init__() self.proj = torch.nn.Sequential( torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), torch.nn.LayerNorm(out_dim)) if flf_pos_emb: # NOTE: we only use this for `flf2v` self.emb_pos = nn.Parameter(torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280)) def forward(self, image_embeds): if hasattr(self, 'emb_pos'): bs, n, d = image_embeds.shape image_embeds = image_embeds.view(-1, 2 * n, d) image_embeds = image_embeds + self.emb_pos clip_extra_context_tokens = self.proj(image_embeds) return clip_extra_context_tokens class WanModel(ModelMixin, ConfigMixin): r""" Wan diffusion backbone supporting text-to-video, image-to-video, and first-last-frame-to-video.
""" ignore_for_config = [ 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size' ] _no_split_modules = ['WanAttentionBlock'] @register_to_config def __init__(self, model_type='t2v', patch_size=(1, 2, 2), text_len=512, in_dim=16, dim=2048, ffn_dim=8192, freq_dim=256, text_dim=4096, out_dim=16, num_heads=16, num_layers=32, window_size=(-1, -1), qk_norm=True, cross_attn_norm=True, eps=1e-6): r""" Initialize the diffusion model backbone. Args: model_type (`str`, *optional*, defaults to 't2v'): Model variant: 't2v' (text-to-video), 'i2v' (image-to-video), or 'flf2v' (first-last-frame-to-video) patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) text_len (`int`, *optional*, defaults to 512): Fixed length for text embeddings in_dim (`int`, *optional*, defaults to 16): Input video channels (C_in) dim (`int`, *optional*, defaults to 2048): Hidden dimension of the transformer ffn_dim (`int`, *optional*, defaults to 8192): Intermediate dimension in feed-forward network freq_dim (`int`, *optional*, defaults to 256): Dimension for sinusoidal time embeddings text_dim (`int`, *optional*, defaults to 4096): Input dimension for text embeddings out_dim (`int`, *optional*, defaults to 16): Output video channels (C_out) num_heads (`int`, *optional*, defaults to 16): Number of attention heads num_layers (`int`, *optional*, defaults to 32): Number of transformer blocks window_size (`tuple`, *optional*, defaults to (-1, -1)): Window size for local attention (-1 indicates global attention) qk_norm (`bool`, *optional*, defaults to True): Enable query/key normalization cross_attn_norm (`bool`, *optional*, defaults to True): Enable cross-attention normalization eps (`float`, *optional*, defaults to 1e-6): Epsilon value for normalization layers """ super().__init__() assert model_type in ['t2v', 'i2v', 'flf2v'] self.model_type = model_type self.patch_size = patch_size self.text_len = text_len
self.in_dim = in_dim self.dim = dim self.ffn_dim = ffn_dim self.freq_dim = freq_dim self.text_dim = text_dim self.out_dim = out_dim self.num_heads = num_heads self.num_layers = num_layers self.window_size = window_size self.qk_norm = qk_norm self.cross_attn_norm = cross_attn_norm self.eps = eps # embeddings self.patch_embedding = nn.Conv3d( in_dim, dim, kernel_size=patch_size, stride=patch_size) self.text_embedding = nn.Sequential( nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), nn.Linear(dim, dim)) self.time_embedding = nn.Sequential( nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) # blocks cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' self.blocks = nn.ModuleList([ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) for _ in range(num_layers) ]) # head self.head = Head(dim, out_dim, patch_size, eps) # buffers (don't use register_buffer otherwise dtype will be changed in to()) assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 d = dim // num_heads self.freqs = torch.cat([ rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6)) ], dim=1) if model_type == 'i2v' or model_type == 'flf2v': self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v') # initialize weights self.init_weights() def forward( self, x, t, context, seq_len, clip_fea=None, y=None, ): r""" Forward pass through the diffusion model Args: x (List[Tensor]): List of input video tensors, each with shape [C_in, F, H, W] t (Tensor): Diffusion timesteps tensor of shape [B] context (List[Tensor]): List of text embeddings each with shape [L, C] seq_len (`int`): Maximum sequence length for positional encoding clip_fea (Tensor, *optional*): CLIP image features for image-to-video mode or first-last-frame-to-video mode y (List[Tensor], *optional*): Conditional video inputs 
for image-to-video mode, same shape as x Returns: List[Tensor]: List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] """ if self.model_type == 'i2v' or self.model_type == 'flf2v': assert clip_fea is not None and y is not None # params device = self.patch_embedding.weight.device if self.freqs.device != device: self.freqs = self.freqs.to(device) if y is not None: x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] # embeddings x = [self.patch_embedding(u.unsqueeze(0)) for u in x] grid_sizes = torch.stack( [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) x = [u.flatten(2).transpose(1, 2) for u in x] seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) assert seq_lens.max() <= seq_len x = torch.cat([ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x ]) # time embeddings with amp.autocast(dtype=torch.float32): e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t).float()) e0 = self.time_projection(e).unflatten(1, (6, self.dim)) assert e.dtype == torch.float32 and e0.dtype == torch.float32 # context context_lens = None context = self.text_embedding( torch.stack([ torch.cat( [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context ])) if clip_fea is not None: context_clip = self.img_emb(clip_fea) # bs x 257 (x2) x dim context = torch.concat([context_clip, context], dim=1) # arguments kwargs = dict( e=e0, seq_lens=seq_lens, grid_sizes=grid_sizes, freqs=self.freqs, context=context, context_lens=context_lens) for block in self.blocks: x = block(x, **kwargs) # head x = self.head(x, e) # unpatchify x = self.unpatchify(x, grid_sizes) return [u.float() for u in x] def unpatchify(self, x, grid_sizes): r""" Reconstruct video tensors from patch embeddings. 
Args: x (List[Tensor]): List of patchified features, each with shape [L, C_out * prod(patch_size)] grid_sizes (Tensor): Original spatial-temporal grid dimensions before patching, shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) Returns: List[Tensor]: Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] """ c = self.out_dim out = [] for u, v in zip(x, grid_sizes.tolist()): u = u[:math.prod(v)].view(*v, *self.patch_size, c) u = torch.einsum('fhwpqrc->cfphqwr', u) u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) out.append(u) return out def init_weights(self): r""" Initialize model parameters using Xavier initialization. """ # basic init for m in self.modules(): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) # init embeddings nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) for m in self.text_embedding.modules(): if isinstance(m, nn.Linear): nn.init.normal_(m.weight, std=.02) for m in self.time_embedding.modules(): if isinstance(m, nn.Linear): nn.init.normal_(m.weight, std=.02) # init output layer nn.init.zeros_(self.head.head.weight) ``` ## /wan/modules/t5.py ```py path="/wan/modules/t5.py" # Modified from transformers.models.t5.modeling_t5 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
import logging import math import torch import torch.nn as nn import torch.nn.functional as F from .tokenizers import HuggingfaceTokenizer __all__ = [ 'T5Model', 'T5Encoder', 'T5Decoder', 'T5EncoderModel', ] def fp16_clamp(x): if x.dtype == torch.float16 and torch.isinf(x).any(): clamp = torch.finfo(x.dtype).max - 1000 x = torch.clamp(x, min=-clamp, max=clamp) return x def init_weights(m): if isinstance(m, T5LayerNorm): nn.init.ones_(m.weight) elif isinstance(m, T5Model): nn.init.normal_(m.token_embedding.weight, std=1.0) elif isinstance(m, T5FeedForward): nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) elif isinstance(m, T5Attention): nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5) nn.init.normal_(m.k.weight, std=m.dim**-0.5) nn.init.normal_(m.v.weight, std=m.dim**-0.5) nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5) elif isinstance(m, T5RelativeEmbedding): nn.init.normal_( m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5) class GELU(nn.Module): def forward(self, x): return 0.5 * x * (1.0 + torch.tanh( math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) class T5LayerNorm(nn.Module): def __init__(self, dim, eps=1e-6): super(T5LayerNorm, self).__init__() self.dim = dim self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x): x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps) if self.weight.dtype in [torch.float16, torch.bfloat16]: x = x.type_as(self.weight) return self.weight * x class T5Attention(nn.Module): def __init__(self, dim, dim_attn, num_heads, dropout=0.1): assert dim_attn % num_heads == 0 super(T5Attention, self).__init__() self.dim = dim self.dim_attn = dim_attn self.num_heads = num_heads self.head_dim = dim_attn // num_heads # layers self.q = nn.Linear(dim, dim_attn, bias=False) self.k = nn.Linear(dim, dim_attn, bias=False) self.v = nn.Linear(dim, 
dim_attn, bias=False) self.o = nn.Linear(dim_attn, dim, bias=False) self.dropout = nn.Dropout(dropout) def forward(self, x, context=None, mask=None, pos_bias=None): """ x: [B, L1, C]. context: [B, L2, C] or None. mask: [B, L2] or [B, L1, L2] or None. """ # check inputs context = x if context is None else context b, n, c = x.size(0), self.num_heads, self.head_dim # compute query, key, value q = self.q(x).view(b, -1, n, c) k = self.k(context).view(b, -1, n, c) v = self.v(context).view(b, -1, n, c) # attention bias attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) if pos_bias is not None: attn_bias += pos_bias if mask is not None: assert mask.ndim in [2, 3] mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1) attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) # compute attention (T5 does not use scaling) attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias attn = F.softmax(attn.float(), dim=-1).type_as(attn) x = torch.einsum('bnij,bjnc->binc', attn, v) # output x = x.reshape(b, -1, n * c) x = self.o(x) x = self.dropout(x) return x class T5FeedForward(nn.Module): def __init__(self, dim, dim_ffn, dropout=0.1): super(T5FeedForward, self).__init__() self.dim = dim self.dim_ffn = dim_ffn # layers self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) self.fc1 = nn.Linear(dim, dim_ffn, bias=False) self.fc2 = nn.Linear(dim_ffn, dim, bias=False) self.dropout = nn.Dropout(dropout) def forward(self, x): x = self.fc1(x) * self.gate(x) x = self.dropout(x) x = self.fc2(x) x = self.dropout(x) return x class T5SelfAttention(nn.Module): def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): super(T5SelfAttention, self).__init__() self.dim = dim self.dim_attn = dim_attn self.dim_ffn = dim_ffn self.num_heads = num_heads self.num_buckets = num_buckets self.shared_pos = shared_pos # layers self.norm1 = T5LayerNorm(dim) self.attn = T5Attention(dim, dim_attn, num_heads, dropout) self.norm2 = 
T5LayerNorm(dim) self.ffn = T5FeedForward(dim, dim_ffn, dropout) self.pos_embedding = None if shared_pos else T5RelativeEmbedding( num_buckets, num_heads, bidirectional=True) def forward(self, x, mask=None, pos_bias=None): e = pos_bias if self.shared_pos else self.pos_embedding( x.size(1), x.size(1)) x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) x = fp16_clamp(x + self.ffn(self.norm2(x))) return x class T5CrossAttention(nn.Module): def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): super(T5CrossAttention, self).__init__() self.dim = dim self.dim_attn = dim_attn self.dim_ffn = dim_ffn self.num_heads = num_heads self.num_buckets = num_buckets self.shared_pos = shared_pos # layers self.norm1 = T5LayerNorm(dim) self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) self.norm2 = T5LayerNorm(dim) self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) self.norm3 = T5LayerNorm(dim) self.ffn = T5FeedForward(dim, dim_ffn, dropout) self.pos_embedding = None if shared_pos else T5RelativeEmbedding( num_buckets, num_heads, bidirectional=False) def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None): e = pos_bias if self.shared_pos else self.pos_embedding( x.size(1), x.size(1)) x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) x = fp16_clamp(x + self.cross_attn( self.norm2(x), context=encoder_states, mask=encoder_mask)) x = fp16_clamp(x + self.ffn(self.norm3(x))) return x class T5RelativeEmbedding(nn.Module): def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): super(T5RelativeEmbedding, self).__init__() self.num_buckets = num_buckets self.num_heads = num_heads self.bidirectional = bidirectional self.max_dist = max_dist # layers self.embedding = nn.Embedding(num_buckets, num_heads) def forward(self, lq, lk): device = self.embedding.weight.device # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ # 
torch.arange(lq).unsqueeze(1).to(device) rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \ torch.arange(lq, device=device).unsqueeze(1) rel_pos = self._relative_position_bucket(rel_pos) rel_pos_embeds = self.embedding(rel_pos) rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze( 0) # [1, N, Lq, Lk] return rel_pos_embeds.contiguous() def _relative_position_bucket(self, rel_pos): # preprocess if self.bidirectional: num_buckets = self.num_buckets // 2 rel_buckets = (rel_pos > 0).long() * num_buckets rel_pos = torch.abs(rel_pos) else: num_buckets = self.num_buckets rel_buckets = 0 rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) # embeddings for small and large positions max_exact = num_buckets // 2 rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)).long() rel_pos_large = torch.min( rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) return rel_buckets class T5Encoder(nn.Module): def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1): super(T5Encoder, self).__init__() self.dim = dim self.dim_attn = dim_attn self.dim_ffn = dim_ffn self.num_heads = num_heads self.num_layers = num_layers self.num_buckets = num_buckets self.shared_pos = shared_pos # layers self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ else nn.Embedding(vocab, dim) self.pos_embedding = T5RelativeEmbedding( num_buckets, num_heads, bidirectional=True) if shared_pos else None self.dropout = nn.Dropout(dropout) self.blocks = nn.ModuleList([ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers) ]) self.norm = T5LayerNorm(dim) # initialize weights self.apply(init_weights) def forward(self, ids, mask=None): x = self.token_embedding(ids) x = self.dropout(x) e = self.pos_embedding(x.size(1), 
                               x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Decoder(nn.Module):

    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5Decoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
            else nn.Embedding(vocab, dim)
        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=False) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                             shared_pos, dropout) for _ in range(num_layers)
        ])
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
        b, s = ids.size()

        # causal mask
        if mask is None:
            mask = torch.tril(torch.ones(1, s, s).to(ids.device))
        elif mask.ndim == 2:
            mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))

        # layers
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Model(nn.Module):

    def __init__(self,
                 vocab_size,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 encoder_layers,
                 decoder_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5Model, self).__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.num_buckets = num_buckets

        # layers
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.encoder = \
            T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn, num_heads,
                      encoder_layers, num_buckets, shared_pos, dropout)
        self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
                                 num_heads, decoder_layers, num_buckets,
                                 shared_pos, dropout)
        self.head = nn.Linear(dim, vocab_size, bias=False)

        # initialize weights
        self.apply(init_weights)

    def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
        x = self.encoder(encoder_ids, encoder_mask)
        x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
        x = self.head(x)
        return x


def _t5(name,
        encoder_only=False,
        decoder_only=False,
        return_tokenizer=False,
        tokenizer_kwargs={},
        dtype=torch.float32,
        device='cpu',
        **kwargs):
    # sanity check
    assert not (encoder_only and decoder_only)

    # params
    if encoder_only:
        model_cls = T5Encoder
        kwargs['vocab'] = kwargs.pop('vocab_size')
        kwargs['num_layers'] = kwargs.pop('encoder_layers')
        _ = kwargs.pop('decoder_layers')
    elif decoder_only:
        model_cls = T5Decoder
        kwargs['vocab'] = kwargs.pop('vocab_size')
        kwargs['num_layers'] = kwargs.pop('decoder_layers')
        _ = kwargs.pop('encoder_layers')
    else:
        model_cls = T5Model

    # init model
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)

    # init tokenizer
    if return_tokenizer:
        from .tokenizers import HuggingfaceTokenizer
        tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
        return model, tokenizer
    else:
        return model


def umt5_xxl(**kwargs):
    cfg = dict(
        vocab_size=256384,
        dim=4096,
        dim_attn=4096,
        dim_ffn=10240,
        num_heads=64,
        encoder_layers=24,
        decoder_layers=24,
        num_buckets=32,
        shared_pos=False,
        dropout=0.1)
    cfg.update(**kwargs)
    return _t5('umt5-xxl', **cfg)


class T5EncoderModel:

    def __init__(
        self,
        text_len,
        dtype=torch.bfloat16,
        device=torch.cuda.current_device(),
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
    ):
        self.text_len = text_len
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = \
            tokenizer_path

        # init model
        model = umt5_xxl(
            encoder_only=True,
            return_tokenizer=False,
            dtype=dtype,
            device=device).eval().requires_grad_(False)
        logging.info(f'loading {checkpoint_path}')
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        self.model = model
        if shard_fn is not None:
            self.model = shard_fn(self.model, sync_module_states=False)
        else:
            self.model.to(self.device)
        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path, seq_len=text_len, clean='whitespace')

    def __call__(self, texts, device):
        ids, mask = self.tokenizer(
            texts, return_mask=True, add_special_tokens=True)
        ids = ids.to(device)
        mask = mask.to(device)
        seq_lens = mask.gt(0).sum(dim=1).long()
        context = self.model(ids, mask)
        return [u[:v] for u, v in zip(context, seq_lens)]
```

## /wan/modules/tokenizers.py

```py path="/wan/modules/tokenizers.py"
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string

import ftfy
import regex as re
from transformers import AutoTokenizer

__all__ = ['HuggingfaceTokenizer']


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


def canonicalize(text, keep_punctuation_exact_string=None):
    text = text.replace('_', ' ')
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(
            part.translate(str.maketrans('', '', string.punctuation))
            for part in text.split(keep_punctuation_exact_string))
    else:
        text = text.translate(str.maketrans('', '', string.punctuation))
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


class HuggingfaceTokenizer:

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, 'whitespace', 'lower', 'canonicalize')
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        return_mask = kwargs.pop('return_mask', False)

        # arguments
        _kwargs = {'return_tensors': 'pt'}
        if self.seq_len is not None:
            _kwargs.update({
                'padding': 'max_length',
                'truncation': True,
                'max_length': self.seq_len
            })
        _kwargs.update(**kwargs)

        # tokenization
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **_kwargs)

        # output
        if return_mask:
            return ids.input_ids, ids.attention_mask
        else:
            return ids.input_ids

    def _clean(self, text):
        if self.clean == 'whitespace':
            text = whitespace_clean(basic_clean(text))
        elif self.clean == 'lower':
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == 'canonicalize':
            text = canonicalize(basic_clean(text))
        return text
```

## /wan/modules/vae.py

```py path="/wan/modules/vae.py"
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging

import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

__all__ = [
    'WanVAE',
]

CACHE_T = 2


class CausalConv3d(nn.Conv3d):
    """
    Causal 3d convolution.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._padding = (self.padding[2], self.padding[2], self.padding[1],
                         self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)

        return super().forward(x)


class RMS_norm(nn.Module):

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.

    def forward(self, x):
        return F.normalize(
            x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias


class Upsample(nn.Upsample):

    def forward(self, x):
        """
        Fix bfloat16 support for nearest neighbor interpolation.
        """
        return super().forward(x.float()).type_as(x)


class Resample(nn.Module):

    def __init__(self, dim, mode):
        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
                        'downsample3d')
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == 'upsample2d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == 'upsample3d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
            self.time_conv = CausalConv3d(
                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

        elif mode == 'downsample2d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == 'downsample3d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(
                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))

        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        b, c, t, h, w = x.size()
        if self.mode == 'upsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = 'Rep'
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[
                            idx] is not None and feat_cache[idx] != 'Rep':
                        # cache last frame of last two chunk
                        cache_x = torch.cat([
                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                                cache_x.device), cache_x
                        ], dim=2)
                    if cache_x.shape[2] < 2 and feat_cache[
                            idx] is not None and feat_cache[idx] == 'Rep':
                        cache_x = torch.cat([
                            torch.zeros_like(cache_x).to(cache_x.device),
                            cache_x
                        ], dim=2)
                    if feat_cache[idx] == 'Rep':
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1

                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
                                    3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.resample(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)

        if self.mode == 'downsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:, :, :].clone()
                    # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
                    #     # cache last frame of last two chunk
                    #     cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

                    x = self.time_conv(
                        torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x

    def init_weight(self, conv):
        conv_weight = conv.weight
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        one_matrix = torch.eye(c1, c2)
        init_matrix = one_matrix
        nn.init.zeros_(conv_weight)
        #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
        conv_weight.data[:, :, 1, 0, 0] = init_matrix  #* 0.5
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def init_weight2(self, conv):
        conv_weight = conv.weight.data
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        init_matrix = torch.eye(c1 // 2, c2)
        #init_matrix = repeat(init_matrix, 'o ...
        #     -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
        conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
        conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)


class ResidualBlock(nn.Module):

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False), nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1))
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
            if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        h = self.shortcut(x)
        for layer in self.residual:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + h


class AttentionBlock(nn.Module):
    """
    Causal self-attention with a single head.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        identity = x
        b, c, t, h, w = x.size()
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.norm(x)
        # compute query, key, value
        q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
            0, 1, 3, 2).contiguous().chunk(3, dim=-1)

        # apply attention
        x = F.scaled_dot_product_attention(
            q,
            k,
            v,
        )
        x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)

        # output
        x = self.proj(x)
        x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
        return x + identity


class Encoder3d(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    downsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
                downsamples.append(Resample(out_dim, mode=mode))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
            ResidualBlock(out_dim, out_dim, dropout))

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, z_dim, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([
                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ], dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x


class Decoder3d(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2**(len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout))

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in \
                enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i == 1 or i == 2 or i == 3:
                in_dim = in_dim // 2
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, 3, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([
                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ], dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x


def count_conv3d(model):
    count = 0
    for m in model.modules():
        if isinstance(m, CausalConv3d):
            count += 1
    return count


class WanVAE_(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z)
        return x_recon, mu, log_var

    def encode(self, x, scale):
        self.clear_cache()
        ## cache
        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4
        ## split the encoder input x along time into chunks of 1, 4, 4, 4, ... frames
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(
                    x[:, :, :1, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
                1, self.z_dim, 1, 1, 1)
        else:
            mu = (mu - scale[0]) * scale[1]
        self.clear_cache()
        return mu

    def decode(self, z, scale):
        self.clear_cache()
        # z: [b,c,t,h,w]
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
                1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
            else:
                out_ = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
                out = torch.cat([out, out_], 2)
        self.clear_cache()
        return out

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        mu, log_var = self.encode(imgs)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        #cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num


def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
    """
    Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
    """
    # params
    cfg = dict(
        dim=96,
        z_dim=z_dim,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[False, True, True],
        dropout=0.0)
    cfg.update(**kwargs)

    # init model
    with torch.device('meta'):
        model = WanVAE_(**cfg)

    # load checkpoint
    logging.info(f'loading {pretrained_path}')
    model.load_state_dict(
        torch.load(pretrained_path, map_location=device), assign=True)

    return model


class WanVAE:

    def __init__(self,
                 z_dim=16,
                 vae_pth='cache/vae_step_411000.pth',
                 dtype=torch.float,
                 device="cuda"):
        self.dtype = dtype
        self.device = device

        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
        ]
        self.mean = torch.tensor(mean, dtype=dtype, device=device)
        self.std = torch.tensor(std, dtype=dtype, device=device)
        self.scale = [self.mean, 1.0 / self.std]

        # init model
        self.model = _video_vae(
            pretrained_path=vae_pth,
            z_dim=z_dim,
        ).eval().requires_grad_(False).to(device)

    def encode(self, videos):
        """
        videos: A list of
        videos each with shape [C, T, H, W].
        """
        with amp.autocast(dtype=self.dtype):
            return [
                self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
                for u in videos
            ]

    def decode(self, zs):
        with amp.autocast(dtype=self.dtype):
            return [
                self.model.decode(u.unsqueeze(0),
                                  self.scale).float().clamp_(-1, 1).squeeze(0)
                for u in zs
            ]
```

## /wan/modules/xlm_roberta.py

```py path="/wan/modules/xlm_roberta.py"
# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['XLMRoberta', 'xlm_roberta_large']


class SelfAttention(nn.Module):

    def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.eps = eps

        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        x: [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)

        # compute attention
        p = self.dropout.p if self.training else 0.0
        x = F.scaled_dot_product_attention(q, k, v, mask, p)
        x = x.permute(0, 2, 1, 3).reshape(b, s, c)

        # output
        x = self.o(x)
        x = self.dropout(x)
        return x


class AttentionBlock(nn.Module):

    def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.eps = eps

        # layers
        self.attn = SelfAttention(dim, num_heads, dropout, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
            nn.Dropout(dropout))
        self.norm2 = nn.LayerNorm(dim, eps=eps)

    def forward(self, x, mask):
        if self.post_norm:
            x = self.norm1(x + self.attn(x, mask))
            x = self.norm2(x + self.ffn(x))
        else:
            x = x + self.attn(self.norm1(x), mask)
            x = x + self.ffn(self.norm2(x))
        return x


class XLMRoberta(nn.Module):
    """
    XLMRobertaModel with no pooler and no LM head.
    """

    def __init__(self,
                 vocab_size=250002,
                 max_seq_len=514,
                 type_size=1,
                 pad_id=1,
                 dim=1024,
                 num_heads=16,
                 num_layers=24,
                 post_norm=True,
                 dropout=0.1,
                 eps=1e-5):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.post_norm = post_norm
        self.eps = eps

        # embeddings
        self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
        self.type_embedding = nn.Embedding(type_size, dim)
        self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
        self.dropout = nn.Dropout(dropout)

        # blocks
        self.blocks = nn.ModuleList([
            AttentionBlock(dim, num_heads, post_norm, dropout, eps)
            for _ in range(num_layers)
        ])

        # norm layer
        self.norm = nn.LayerNorm(dim, eps=eps)

    def forward(self, ids):
        """
        ids: [B, L] of torch.LongTensor.
        """
        b, s = ids.shape
        mask = ids.ne(self.pad_id).long()

        # embeddings
        x = self.token_embedding(ids) + \
            self.type_embedding(torch.zeros_like(ids)) + \
            self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
        if self.post_norm:
            x = self.norm(x)
        x = self.dropout(x)

        # blocks
        mask = torch.where(
            mask.view(b, 1, 1, s).gt(0), 0.0,
            torch.finfo(x.dtype).min)
        for block in self.blocks:
            x = block(x, mask)

        # output
        if not self.post_norm:
            x = self.norm(x)
        return x


def xlm_roberta_large(pretrained=False,
                      return_tokenizer=False,
                      device='cpu',
                      **kwargs):
    """
    XLMRobertaLarge adapted from Huggingface.
    """
    # params
    cfg = dict(
        vocab_size=250002,
        max_seq_len=514,
        type_size=1,
        pad_id=1,
        dim=1024,
        num_heads=16,
        num_layers=24,
        post_norm=True,
        dropout=0.1,
        eps=1e-5)
    cfg.update(**kwargs)

    # init a model on device
    with torch.device(device):
        model = XLMRoberta(**cfg)
    return model
```
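The temporal chunking in `WanVAE_.encode` (`wan/modules/vae.py`) feeds the encoder the first frame on its own and then groups of four frames. The helper below is a hypothetical sketch, not part of the repository: `encoder_chunks` only reproduces the slice arithmetic `1 + 4 * (i - 1):1 + 4 * i` and the `iter_ = 1 + (t - 1) // 4` count from that loop, so the schedule can be inspected without PyTorch or a checkpoint.

```python
def encoder_chunks(num_frames: int) -> list:
    """Return the (start, stop) frame slices that WanVAE_.encode feeds to the
    encoder, in order. Hypothetical helper mirroring the loop in vae.py."""
    if num_frames < 1:
        raise ValueError("need at least one frame")
    iters = 1 + (num_frames - 1) // 4  # same formula as `iter_` in encode()
    chunks = [(0, 1)]  # chunk 0: the first frame alone
    for i in range(1, iters):
        # chunk i covers frames 1 + 4*(i-1) up to, but excluding, 1 + 4*i
        chunks.append((1 + 4 * (i - 1), 1 + 4 * i))
    return chunks


if __name__ == "__main__":
    for t in (1, 5, 10, 81):
        chunks = encoder_chunks(t)
        print(f"{t} frames -> {len(chunks)} chunks, last slice {chunks[-1]}")
```

Frames beyond `1 + 4 * (iter_ - 1)` are never sliced, so clips of `4k + 1` frames use every frame (81 frames give 21 chunks), while a 10-frame clip would silently drop its last frame. By our reading of `Resample`, with `temperal_downsample=[False, True, True]` each chunk should come out as one latent frame, which matches `decode` iterating over `z.shape[2]` one latent frame at a time.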
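`T5RelativeEmbedding._relative_position_bucket` in `wan/modules/t5.py` maps each signed offset `key - query` to a bucket id: small distances get one bucket each, larger ones share log-spaced buckets, and everything past `max_dist` is clamped to the last bucket. The scalar function below is a sketch for inspecting that mapping, not code from the repository; it assumes `num_buckets=32` (as in `umt5_xxl`) and `max_dist=128`, a value taken to be the constructor default, since the constructor lies outside this excerpt.

```python
import math


def relative_position_bucket(rel_pos: int,
                             num_buckets: int = 32,
                             max_dist: int = 128,
                             bidirectional: bool = True) -> int:
    """Scalar sketch of T5RelativeEmbedding._relative_position_bucket.
    max_dist=128 is an assumed default, not confirmed by this excerpt."""
    if bidirectional:
        # the upper half of the buckets is reserved for positive offsets
        num_buckets //= 2
        bucket = num_buckets if rel_pos > 0 else 0
        rel_pos = abs(rel_pos)
    else:
        # unidirectional: positive offsets all collapse to distance 0
        bucket = 0
        rel_pos = -min(rel_pos, 0)

    max_exact = num_buckets // 2
    if rel_pos < max_exact:
        # small distances: one bucket per offset
        return bucket + rel_pos
    # large distances: log-spaced up to max_dist, then clamped to the last bucket
    large = max_exact + int(
        math.log(rel_pos / max_exact) / math.log(max_dist / max_exact) *
        (num_buckets - max_exact))
    return bucket + min(large, num_buckets - 1)


if __name__ == "__main__":
    for offset in (-1000, -8, -3, 0, 3, 8, 1000):
        print(offset, relative_position_bucket(offset))
```

In the bidirectional case, offsets of -3, 0 and 3 fall into buckets 3, 0 and 19, while offsets of -1000 and 1000 saturate at buckets 15 and 31. The PyTorch version computes the same values elementwise with `torch.where`.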