``` ├── .codestyle/ ├── copyright.hook ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── example/ ├── 24B/ ├── 24B_config.json ├── run.sh ├── 4.5B/ ├── 4.5B_config.json ├── run.sh ├── assets/ ├── image.jpeg ├── prefix_video.mp4 ├── special_tokens.npz ├── figures/ ├── algorithm.png ├── dit_architecture.png ├── inhouse_human_evaluation.png ├── logo_black.png ├── inference/ ├── common/ ├── __init__.py ├── common_utils.py ├── config.py ├── dataclass.py ├── logger.py ├── timer.py ├── infra/ ├── checkpoint/ ├── __init__.py ├── checkpointing.py ├── distributed/ ├── __init__.py ├── dist_utils.py ├── parallel_state.py ├── parallelism/ ├── __init__.py ├── context_parallel.py ├── pipeline_parallel.py ├── tile_parallel.py ├── model/ ├── dit/ ├── __init__.py ├── dit_model.py ├── dit_module.py ├── t5/ ├── __init__.py ``` ## /.codestyle/copyright.hook ```hook path="/.codestyle/copyright.hook" from __future__ import absolute_import from __future__ import print_function from __future__ import unicode_literals import argparse import io import re import sys import os import datetime COPYRIGHT = '''Copyright (c) 2025 SandAI. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.''' def _generate_copyright(comment_mark): copyright=COPYRIGHT.split(os.linesep) header = copyright[0].rstrip() p = re.search('(\d{4})', header).group(0) now = datetime.datetime.now() header = header.replace(p,str(now.year)) ans=[comment_mark + " " + header + os.linesep] for idx, line in enumerate(copyright[1:]): ans.append(comment_mark + " " + line.rstrip() + os.linesep) return ans def _get_comment_mark(path): lang_type=re.compile(r"\.(py|sh)$") if lang_type.search(path) is not None: return "#" lang_type=re.compile(r"\.(h|c|hpp|cc|cpp|cu|go|cuh|proto)$") if lang_type.search(path) is not None: return "//" return None RE_ENCODE = re.compile(r"^[ \t\v]*#.*?coding[:=]", re.IGNORECASE) RE_COPYRIGHT = re.compile(r".*Copyright \(c\) \d{4}", re.IGNORECASE) RE_SHEBANG = re.compile(r"^[ \t\v]*#[ \t]?\!") def _check_copyright(path): head=[] try: with open(path) as f: head = [next(f) for x in range(4)] except StopIteration: pass for idx, line in enumerate(head): if RE_COPYRIGHT.search(line) is not None: return True return False def generate_copyright(path, comment_mark): original_contents = io.open(path, encoding="utf-8").readlines() head = original_contents[0:4] insert_line_no=0 for i, line in enumerate(head): if RE_ENCODE.search(line) or RE_SHEBANG.search(line): insert_line_no=i+1 copyright = _generate_copyright(comment_mark) if insert_line_no == 0: new_contents = copyright if len(original_contents) > 0 and len(original_contents[0].strip()) != 0: new_contents.append(os.linesep) new_contents.extend(original_contents) else: new_contents=original_contents[0:insert_line_no] new_contents.append(os.linesep) new_contents.extend(copyright) if len(original_contents) > insert_line_no and len(original_contents[insert_line_no].strip()) != 0: new_contents.append(os.linesep) new_contents.extend(original_contents[insert_line_no:]) 
new_contents="".join(new_contents) with io.open(path, 'w') as output_file: output_file.write(new_contents) def main(argv=None): parser = argparse.ArgumentParser( description='Checker for copyright declaration.') parser.add_argument('filenames', nargs='*', help='Filenames to check') args = parser.parse_args(argv) retv = 0 for path in args.filenames: comment_mark = _get_comment_mark(path) if comment_mark is None: print("warning:Unsupported file", path, file=sys.stderr) continue if _check_copyright(path): continue generate_copyright(path, comment_mark) if __name__ == '__main__': exit(main()) ``` ## /.gitignore ```gitignore path="/.gitignore" __pycache__ *.pyc *.log *.pt *.mp4 ckpt downloads ``` ## /.pre-commit-config.yaml ```yaml path="/.pre-commit-config.yaml" exclude: \.patch$ repos: - repo: local hooks: - id: copyright_checker name: copyright_checker entry: python3 ./.codestyle/copyright.hook language: system files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py|sh)$ - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - id: check-added-large-files args: - --maxkb=30720 - id: check-merge-conflict - id: check-symlinks - id: detect-private-key files: (?!.*third_party)^.*$ | (?!.*book)^.*$ - id: end-of-file-fixer - id: trailing-whitespace - id: requirements-txt-fixer - id: sort-simple-yaml - repo: https://github.com/Lucas-C/pre-commit-hooks.git rev: v1.5.1 hooks: - id: remove-crlf files: (?!.*third_party)^.*$ | (?!.*book)^.*$ - id: remove-tabs name: Tabs remover (C++) files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps)$ args: [--whitespaces-count, '2'] - id: remove-tabs name: Tabs remover (Python) files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ args: [--whitespaces-count, '4'] - repo: https://github.com/psf/black.git rev: 23.3.0 hooks: - id: black args: [--line-length=127, --skip-string-normalization, --skip-magic-trailing-comma] files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ - repo: https://github.com/pre-commit/mirrors-isort rev: v5.10.1 hooks: - id: isort args: [--profile=black, --line-length=127, --multi-line=3, --force-grid-wrap=0] files: \.py$ - repo: https://github.com/PyCQA/autoflake rev: v2.3.1 hooks: - id: autoflake args: [--remove-all-unused-imports, --remove-unused-variables, --in-place, --ignore-init-module-imports, --ignore-pass-after-docstring] files: \.py$ - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks.git rev: v2.9.0 hooks: - id: pretty-format-yaml args: [--autofix, --indent, '4'] ``` ## /LICENSE ``` path="/LICENSE" Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ``` ## /README.md ![magi-logo](figures/logo_black.png) -----


# MAGI-1: Autoregressive Video Generation at Scale

This repository contains the pre-trained weights and inference code for the MAGI-1 model. You can find more details in our [technical report](https://static.magi.world/static/files/MAGI_1.pdf), or create magic with MAGI-1 directly [here](http://sand.ai). 🚀✨

## 🔥🔥🔥 Latest News

- Apr 22, 2025: We plan to release our 4.5B model by the end of April. Final touches are still underway, so stay tuned for the official drop.
- Apr 21, 2025: MAGI-1 is here 🎉. We've released the model weights and inference code. Check it out!

## 1. About

We present MAGI-1, a world model that generates videos by ***autoregressively*** predicting a sequence of video chunks, defined as fixed-length segments of consecutive frames. Trained to denoise per-chunk noise that increases monotonically over time, MAGI-1 enables causal temporal modeling and naturally supports streaming generation. It achieves strong performance on image-to-video (I2V) tasks conditioned on text instructions, providing the high temporal consistency and scalability made possible by several algorithmic innovations and a dedicated infrastructure stack. MAGI-1 further supports controllable generation via chunk-wise prompting, enabling smooth scene transitions, long-horizon synthesis, and fine-grained text-driven control. We believe MAGI-1 offers a promising direction for unifying high-fidelity video generation with flexible instruction control and real-time deployment.
## 2. Model Summary

### Transformer-based VAE

- Variational autoencoder (VAE) with a transformer-based architecture, 8x spatial and 4x temporal compression.
- Fastest average decoding time and highly competitive reconstruction quality.

### Auto-Regressive Denoising Algorithm

MAGI-1 is an autoregressive denoising video generation model that generates videos chunk by chunk instead of as a whole. Each chunk (24 frames) is denoised holistically, and generation of the next chunk begins as soon as the current one reaches a certain level of denoising. This pipelined design enables concurrent processing of up to four chunks for efficient video generation.

![auto-regressive denoising algorithm](figures/algorithm.png)

### Diffusion Model Architecture

MAGI-1 is built upon the Diffusion Transformer and incorporates several key innovations that improve training efficiency and stability at scale, including Block-Causal Attention, the Parallel Attention Block, QK-Norm and GQA, Sandwich Normalization in the FFN, SwiGLU, and Softcap Modulation. For more details, please refer to the [technical report](https://static.magi.world/static/files/MAGI_1.pdf).
![diffusion model architecture](figures/dit_architecture.png)
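The chunk-wise schedule described above can be pictured with a short, self-contained sketch. This is not the repository's pipeline code: `denoise_one_step` is a hypothetical stand-in for a real denoiser call, and `WINDOW_SIZE` / `NUM_STEPS` only loosely mirror the `window_size` / `num_steps` config entries. New chunks enter a sliding window fully noised, every active chunk takes one denoising step per iteration, and the oldest chunk (always the least noisy) is emitted once clean.

```python
# Schematic sketch of chunk-wise autoregressive denoising (illustration only).
from collections import deque

WINDOW_SIZE = 4   # up to four chunks are denoised concurrently
NUM_STEPS = 8     # denoising steps per chunk (placeholder value)


def denoise_one_step(chunk_id: int, noise_level: float) -> float:
    """Placeholder: one denoising step lowers the chunk's noise level."""
    return max(noise_level - 1.0 / NUM_STEPS, 0.0)


def generate(num_chunks: int):
    """Emit chunks in order; older chunks are always less noisy than newer ones."""
    active = deque()   # entries are [chunk_id, noise_level], oldest first
    next_chunk = 0
    finished = []
    while len(finished) < num_chunks:
        # Admit a new fully-noised chunk whenever the window has room.
        if next_chunk < num_chunks and len(active) < WINDOW_SIZE:
            active.append([next_chunk, 1.0])
            next_chunk += 1
        # One concurrent denoising step over every active chunk.
        for entry in active:
            entry[1] = denoise_one_step(entry[0], entry[1])
        # The oldest chunk is always closest to clean; emit it when done.
        while active and active[0][1] == 0.0:
            finished.append(active.popleft()[0])
    return finished


print(generate(num_chunks=6))  # -> [0, 1, 2, 3, 4, 5]
```

Because the oldest chunk in the window is always the closest to clean, the per-chunk noise levels stay monotonically ordered in time, which is what enables causal temporal modeling and streaming generation.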
### Distillation Algorithm

We adopt a shortcut distillation approach that trains a single velocity-based model to support variable inference budgets. By enforcing a self-consistency constraint that equates one large step with two smaller steps, the model learns to approximate flow-matching trajectories across multiple step sizes. During training, step sizes are cyclically sampled from {64, 32, 16, 8}, and classifier-free guidance distillation is incorporated to preserve conditional alignment. This enables efficient inference with minimal loss in fidelity.

## 3. Model Zoo

We provide the pre-trained weights for MAGI-1, including the 24B and 4.5B models as well as the corresponding distilled and distilled+quantized variants. The model weight links are shown in the table below.

| Model | Link | Recommended Machine |
| ----- | ---- | ------------------- |
| T5 | [T5](https://huggingface.co/sand-ai/MAGI-1/tree/main/ckpt/t5) | - |
| MAGI-1-VAE | [MAGI-1-VAE](https://huggingface.co/sand-ai/MAGI-1/tree/main/ckpt/vae) | - |
| MAGI-1-24B | [MAGI-1-24B](https://huggingface.co/sand-ai/MAGI-1/tree/main/ckpt/magi/24B_base) | H100/H800 \* 8 |
| MAGI-1-24B-distill | [MAGI-1-24B-distill](https://huggingface.co/sand-ai/MAGI-1/tree/main/ckpt/magi/24B_distill) | H100/H800 \* 8 |
| MAGI-1-24B-distill+fp8_quant | [MAGI-1-24B-distill+quant](https://huggingface.co/sand-ai/MAGI-1/tree/main/ckpt/magi/24B_distill_quant) | H100/H800 \* 4 or RTX 4090 \* 8 |
| MAGI-1-4.5B | MAGI-1-4.5B | RTX 4090 \* 1 |

## 4. Evaluation

### In-house Human Evaluation

MAGI-1 achieves state-of-the-art performance among open-source models, outperforming Wan-2.1 and HunyuanVideo as well as the closed-source Hailuo (i2v-01). It particularly excels in instruction following and motion quality, positioning it as a strong potential competitor to closed-source commercial models such as Kling.

![inhouse human evaluation](figures/inhouse_human_evaluation.png)

### Physical Evaluation

Thanks to the natural advantages of its autoregressive architecture, MAGI-1 achieves far superior precision in predicting physical behavior through video continuation on the [Physics-IQ benchmark](https://github.com/google-deepmind/physics-IQ-benchmark), significantly outperforming all existing models.

| Model | Phys. IQ Score ↑ | Spatial IoU ↑ | Spatio Temporal ↑ | Weighted Spatial IoU ↑ | MSE ↓ |
| ----- | ---------------- | ------------- | ----------------- | ---------------------- | ----- |
| **V2V Models** | | | | | |
| **Magi (V2V)** | **56.02** | **0.367** | **0.270** | **0.304** | **0.005** |
| VideoPoet (V2V) | 29.50 | 0.204 | 0.164 | 0.137 | 0.010 |
| **I2V Models** | | | | | |
| **Magi (I2V)** | **30.23** | **0.203** | **0.151** | **0.154** | **0.012** |
| Kling1.6 (I2V) | 23.64 | 0.197 | 0.086 | 0.144 | 0.025 |
| VideoPoet (I2V) | 20.30 | 0.141 | 0.126 | 0.087 | 0.012 |
| Gen 3 (I2V) | 22.80 | 0.201 | 0.115 | 0.116 | 0.015 |
| Wan2.1 (I2V) | 20.89 | 0.153 | 0.100 | 0.112 | 0.023 |
| Sora (I2V) | 10.00 | 0.138 | 0.047 | 0.063 | 0.030 |
| **GroundTruth** | **100.0** | **0.678** | **0.535** | **0.577** | **0.002** |

## 5. How to run

### Environment Preparation

We provide two ways to run MAGI-1, with the Docker environment being the recommended option.
**Run with Docker Environment (Recommend)** ```bash docker pull sandai/magi:latest docker run -it --gpus all --privileged --shm-size=32g --name magi --net=host --ipc=host --ulimit memlock=-1 --ulimit stack=6710886 sandai/magi:latest /bin/bash ``` **Run with Source Code** ```bash # Create a new environment conda create -n magi python==3.10.12 # Install pytorch conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.4 -c pytorch -c nvidia # Install other dependencies pip install -r requirements.txt # Install ffmpeg conda install -c conda-forge ffmpeg=4.4 # Install MagiAttention, for more information, please refer to https://github.com/SandAI-org/MagiAttention# git clone git@github.com:SandAI-org/MagiAttention.git cd MagiAttention git submodule update --init --recursive pip install --no-build-isolation . ``` ### Inference Command To run the `MagiPipeline`, you can control the input and output by modifying the parameters in the `example/24B/run.sh` or `example/4.5B/run.sh` script. Below is an explanation of the key parameters: #### Parameter Descriptions - `--config_file`: Specifies the path to the configuration file, which contains model configuration parameters, e.g., `example/24B/24B_config.json`. - `--mode`: Specifies the mode of operation. Available options are: - `t2v`: Text to Video - `i2v`: Image to Video - `v2v`: Video to Video - `--prompt`: The text prompt used for video generation, e.g., `"Good Boy"`. - `--image_path`: Path to the image file, used only in `i2v` mode. - `--prefix_video_path`: Path to the prefix video file, used only in `v2v` mode. - `--output_path`: Path where the generated video file will be saved. #### Bash Script ```bash #!/bin/bash # Run 24B MAGI-1 model bash example/24B/run.sh # Run 4.5B MAGI-1 model bash example/4.5B/run.sh ``` #### Customizing Parameters You can modify the parameters in `run.sh` as needed. For example: - To use the Image to Video mode (`i2v`), set `--mode` to `i2v` and provide `--image_path`: ```bash --mode i2v \ --image_path example/assets/image.jpeg \ ``` - To use the Video to Video mode (`v2v`), set `--mode` to `v2v` and provide `--prefix_video_path`: ```bash --mode v2v \ --prefix_video_path example/assets/prefix_video.mp4 \ ``` By adjusting these parameters, you can flexibly control the input and output to meet different requirements. ### Some Useful Configs (for config.json) > NOTE: If you are running 24B model with RTX 4090 \* 8, please set `pp_size:2 cp_size: 4`. | Config | Help | | -------------- | ------------------------------------------------------------ | | seed | Random seed used for video generation | | video_size_h | Height of the video | | video_size_w | Width of the video | | num_frames | Controls the duration of generated video | | fps | Frames per second, 4 video frames correspond to 1 latent_frame | | cfg_number | Base model uses cfg_number==3, distill and quant model uses cfg_number=1 | | load | Directory containing a model checkpoint. | | t5_pretrained | Path to load pretrained T5 model | | vae_pretrained | Path to load pretrained VAE model | ## 6. License This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details. ## 7. Citation If you find our code or model useful in your research, please cite: ```bibtex @misc{magi1, title={MAGI-1: Autoregressive Video Generation at Scale}, author={Sand-AI}, year={2025}, url={https://static.magi.world/static/files/MAGI_1.pdf}, } ``` ## 8. 
Contact If you have any questions, please feel free to raise an issue or contact us at [research@sand.ai](mailto:research@sand.ai) . ## /example/24B/24B_config.json ```json path="/example/24B/24B_config.json" { "model_config": { "model_name": "videodit_ardf", "num_layers": 48, "hidden_size": 6144, "ffn_hidden_size": 16384, "num_attention_heads": 48, "num_query_groups": 8, "kv_channels": 128, "layernorm_epsilon": 1e-06, "apply_layernorm_1p": true, "x_rescale_factor": 0.1, "half_channel_vae": true, "params_dtype": "torch.bfloat16", "patch_size": 2, "t_patch_size": 1, "in_channels": 32, "out_channels": 32, "cond_hidden_ratio": 0.25, "caption_channels": 4096, "caption_max_length": 800, "xattn_cond_hidden_ratio": 1.0, "cond_gating_ratio": 1.0, "gated_linear_unit": true }, "runtime_config": { "cfg_number": 1, "cfg_t_range": [ 0.0, 0.0217, 0.1, 0.3, 0.999 ], "prev_chunk_scales": [ 1.5, 1.5, 1.5, 1.0, 1.0 ], "text_scales": [ 7.5, 7.5, 7.5, 0.0, 0.0 ], "noise2clean_kvrange": [ 5, 4, 3, 2 ], "clean_chunk_kvrange": 1, "clean_t": 0.9999, "seed": 1234, "num_frames": 96, "video_size_h": 720, "video_size_w": 1280, "num_steps": 8, "window_size": 4, "fps": 24, "chunk_width": 6, "load": "./downloads/24B_base", "t5_pretrained": "./downloads/t5_pretrained", "t5_device": "cuda", "vae_pretrained": "./downloads/vae", "scale_factor": 0.18215, "temporal_downsample_factor": 4 }, "engine_config": { "distributed_backend": "nccl", "distributed_timeout_minutes": 15, "pp_size": 1, "cp_size": 8, "cp_strategy": "cp_ulysses", "ulysses_overlap_degree": 1, "fp8_quant": true, "distill_nearly_clean_chunk_threshold": 0.3, "shortcut_mode": "8,16,16", "distill": true, "kv_offload": true, "enable_cuda_graph": false } } ``` ## /example/24B/run.sh ```sh path="/example/24B/run.sh" # Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
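# This launcher sets inference environment variables, detects the number of local GPUs
# via nvidia-smi, and starts inference/pipeline/entry.py with torchrun in image-to-video
# (i2v) mode. Adjust --mode, --prompt, --image_path, and --output_path in the command
# below to change the task.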
export CUDA_DEVICE_MAX_CONNECTIONS=1 export NCCL_ALGO=^NVLS export PAD_HQ=1 export PAD_DURATION=1 export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True export OFFLOAD_T5_CACHE=true export OFFLOAD_VAE_CACHE=true export TORCH_CUDA_ARCH_LIST="8.9;9.0" GPUS_PER_NODE=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) DISTRIBUTED_ARGS=" --rdzv-backend=c10d \ --rdzv-endpoint=localhost:6009 \ --nnodes=1 \ --nproc_per_node=$GPUS_PER_NODE " MAGI_ROOT=$(git rev-parse --show-toplevel) LOG_DIR=log_$(date "+%Y-%m-%d_%H:%M:%S").log export PYTHONPATH="$MAGI_ROOT:$PYTHONPATH" torchrun $DISTRIBUTED_ARGS inference/pipeline/entry.py \ --config_file example/24B/24B_config.json \ --mode i2v \ --prompt "Good Boy" \ --image_path example/assets/image.jpeg \ --output_path example/assets/output_i2v.mp4 \ 2>&1 | tee $LOG_DIR ``` ## /example/4.5B/4.5B_config.json ```json path="/example/4.5B/4.5B_config.json" { "model_config": { "model_name": "videodit_ardf", "num_layers": 34, "hidden_size": 3072, "ffn_hidden_size": 12288, "num_attention_heads": 24, "num_query_groups": 8, "kv_channels": 128, "layernorm_epsilon": 1e-06, "apply_layernorm_1p": true, "x_rescale_factor": 1, "half_channel_vae": false, "params_dtype": "torch.bfloat16", "patch_size": 2, "t_patch_size": 1, "in_channels": 16, "out_channels": 16, "cond_hidden_ratio": 0.25, "caption_channels": 4096, "caption_max_length": 800, "xattn_cond_hidden_ratio": 1.0, "cond_gating_ratio": 1.0, "gated_linear_unit": false }, "runtime_config": { "cfg_number": 3, "cfg_t_range": [ 0.0, 0.0217, 0.1, 0.3, 0.999 ], "prev_chunk_scales": [ 1.5, 1.5, 1.5, 1.0, 1.0 ], "text_scales": [ 7.5, 7.5, 7.5, 0.0, 0.0 ], "noise2clean_kvrange": [ 5, 4, 3, 2 ], "clean_chunk_kvrange": 1, "clean_t": 0.9999, "seed": 1234, "num_frames": 192, "video_size_h": 720, "video_size_w": 1280, "num_steps": 64, "window_size": 4, "fps": 24, "chunk_width": 6, "load": "./downloads/4.5B_base", "t5_pretrained": "./downloads/t5_pretrained", "t5_device": "cpu", "vae_pretrained": "./downloads/vae", "scale_factor": 0.18215, "temporal_downsample_factor": 4 }, "engine_config": { "distributed_backend": "nccl", "distributed_timeout_minutes": 15, "pp_size": 1, "cp_size": 1, "cp_strategy": "cp_ulysses", "ulysses_overlap_degree": 1, "fp8_quant": false, "distill_nearly_clean_chunk_threshold": 0.3, "shortcut_mode": "8,16,16", "distill": false, "kv_offload": true, "enable_cuda_graph": false } } ``` ## /example/4.5B/run.sh ```sh path="/example/4.5B/run.sh" # Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
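# This launcher configures a single-process, single-GPU run (CUDA_VISIBLE_DEVICES selects
# one device) and calls inference/pipeline/entry.py directly with python3 in text-to-video
# (t2v) mode. Adjust --mode, --prompt, and --output_path in the command below to change
# the task.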
export MASTER_ADDR=localhost export MASTER_PORT=6009 export GPUS_PER_NODE=1 export NNODES=1 export WORLD_SIZE=1 export CUDA_VISIBLE_DEVICES=1 export PAD_HQ=1 export PAD_DURATION=1 export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True export OFFLOAD_T5_CACHE=true export OFFLOAD_VAE_CACHE=true export TORCH_CUDA_ARCH_LIST="8.9;9.0" MAGI_ROOT=$(git rev-parse --show-toplevel) LOG_DIR=log_$(date "+%Y-%m-%d_%H:%M:%S").log export PYTHONPATH="$MAGI_ROOT:$PYTHONPATH" python3 inference/pipeline/entry.py \ --config_file example/4.5B/4.5B_config.json \ --mode t2v \ --prompt "Good Boy" \ --output_path example/assets/output_t2v.mp4 \ 2>&1 | tee $LOG_DIR ``` ## /example/assets/image.jpeg Binary file available at https://raw.githubusercontent.com/SandAI-org/MAGI-1/refs/heads/main/example/assets/image.jpeg ## /example/assets/prefix_video.mp4 Binary file available at https://raw.githubusercontent.com/SandAI-org/MAGI-1/refs/heads/main/example/assets/prefix_video.mp4 ## /example/assets/special_tokens.npz Binary file available at https://raw.githubusercontent.com/SandAI-org/MAGI-1/refs/heads/main/example/assets/special_tokens.npz ## /figures/algorithm.png Binary file available at https://raw.githubusercontent.com/SandAI-org/MAGI-1/refs/heads/main/figures/algorithm.png ## /figures/dit_architecture.png Binary file available at https://raw.githubusercontent.com/SandAI-org/MAGI-1/refs/heads/main/figures/dit_architecture.png ## /figures/inhouse_human_evaluation.png Binary file available at https://raw.githubusercontent.com/SandAI-org/MAGI-1/refs/heads/main/figures/inhouse_human_evaluation.png ## /figures/logo_black.png Binary file available at https://raw.githubusercontent.com/SandAI-org/MAGI-1/refs/heads/main/figures/logo_black.png ## /inference/common/__init__.py ```py path="/inference/common/__init__.py" # Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from .common_utils import divide, env_is_true, set_random_seed from .config import EngineConfig, MagiConfig, ModelConfig, RuntimeConfig from .dataclass import InferenceParams, ModelMetaArgs, PackedCoreAttnParams, PackedCrossAttnParams from .logger import magi_logger, print_per_rank, print_rank_0 from .timer import event_path_timer __all__ = [ "MagiConfig", "ModelConfig", "EngineConfig", "RuntimeConfig", "magi_logger", "print_per_rank", "print_rank_0", "event_path_timer", "divide", "env_is_true", "set_random_seed", "PackedCoreAttnParams", "PackedCrossAttnParams", "ModelMetaArgs", "InferenceParams", ] ``` ## /inference/common/common_utils.py ```py path="/inference/common/common_utils.py" # Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import random import numpy as np import torch def env_is_true(env_name: str) -> bool: return str(os.environ.get(env_name, "0")).lower() in {"1", "true", "yes", "y", "on", "enabled"} def divide(numerator, denominator): assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) return numerator // denominator def set_random_seed(seed): """Set random seed. Args: seed (int): Seed to be used. If not provided or set to 0, a random seed will be generated. """ if not seed or seed == 0: seed = random.randint(0, 2**32 - 1) random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) return seed ``` ## /inference/common/config.py ```py path="/inference/common/config.py" # Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import dataclasses import json import os import torch @dataclasses.dataclass class ModelConfig: model_name: str # Transformer num_layers: int = None # Number of transformer layers. hidden_size: int = None # Transformer hidden size. ffn_hidden_size: int = None # Transformer Feed-Forward Network hidden size num_attention_heads: int = None # Number of transformer attention heads. num_query_groups: int = 1 # Number of query groups, which used for GQA kv_channels: int = None # Projection weights dimension in multi-head attention layernorm_epsilon: float = 1e-6 # Epsilon for layer norm and RMS norm. apply_layernorm_1p: bool = False # Adjust LayerNorm weights which improves numerical stability. x_rescale_factor: float = 1.0 half_channel_vae: bool = False params_dtype: torch.dtype = None # Embedding patch_size: int = 2 # (latent) patch size for DiT patch embedding layer t_patch_size: int = 1 # (latent) patch size for t dim patch embedding layer in_channels: int = 4 # latent input channel for DiT out_channels: int = 4 # latent output channel for DiT cond_hidden_ratio: float = 0.25 caption_channels: int = 4096 caption_max_length: int = 800 xattn_cond_hidden_ratio: float = 1.0 cond_gating_ratio: float = 1.0 gated_linear_unit: bool = False @dataclasses.dataclass class RuntimeConfig: # Inference settings such as cfg, kv range, clean t, etc. 
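    # The fields below fall into three groups: CFG / denoising-schedule settings,
    # video and sampling settings (seed, frames, resolution, fps, steps, chunk width),
    # and checkpoint paths for the T5, VAE, and DiT components.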
cfg_number: int = None # Number of CFG cfg_t_range: list = dataclasses.field( default_factory=lambda: [0, 0.0217, 0.1000, 0.3, 0.999] ) # CFG t-range of each scales prev_chunk_scales: list = dataclasses.field( default_factory=lambda: [1.5, 1.5, 1.5, 1.5, 1.5] ) # CFG scales of previous chunks text_scales: list = dataclasses.field(default_factory=lambda: [7.5, 7.5, 7.5, 7.5, 7.5]) # CFG scales of text noise2clean_kvrange: list = dataclasses.field(default_factory=list) # Range of kv for noise2clean chunks clean_chunk_kvrange: int = -1 # Range of kv for clean chunks clean_t: float = 1.0 # timestep for clean chunks # Video settings seed: int = 1234 # Random seed used for python, numpy, pytorch, and cuda. num_frames: int = 128 video_size_h: int = None video_size_w: int = None num_steps: int = 64 # Number of steps for the diffusion model window_size: int = 4 # Window size for the diffusion model fps: int = 24 # Frames per second chunk_width: int = 6 # Clip width for the diffusion model # Checkpoint, includes t5, vae, dit, etc. t5_pretrained: str = None # Path to load pretrained T5 model. t5_device: str = "cuda" # Device for T5 model to run on. vae_pretrained: str = None # Path to load pretrained VAE model. scale_factor: float = 0.18215 # Scale factor for the vae temporal_downsample_factor: int = 4 # Temporal downsample factor for the vae load: str = None # Directory containing a model checkpoint. @dataclasses.dataclass class EngineConfig: # Parallism strategy distributed_backend: str = "nccl" # Choices: ["nccl", "gloo"] distributed_timeout_minutes: int = 10 # Timeout minutes for torch.distributed. pp_size: int = 1 # Degree of pipeline model parallelism. cp_size: int = 1 # Degree of context parallelism. cp_strategy: str = "none" # Choices: ["none", "cp_ulysses", "cp_shuffle_overlap"] ulysses_overlap_degree: int = 1 # Overlap degree for Ulysses # Quantization fp8_quant: bool = False # Enable 8-bit floating point quantization for model weights. 
# Distillation distill_nearly_clean_chunk_threshold: float = 0.3 # Threshold for distilling nearly clean chunks shortcut_mode: str = "8,16,16" # Parameters for shortcut mode distill: bool = False # Use distill mode # Optimization kv_offload: bool = False # Use kv-offload algorithm enable_cuda_graph: bool = False # Enable CUDA graph for video generation @dataclasses.dataclass class MagiConfig: model_config: ModelConfig runtime_config: RuntimeConfig engine_config: EngineConfig @classmethod def _check_missing_fields(cls, config_dict: dict, required_fields: list): actual_fields = set(config_dict.keys()) missing_fields = set(required_fields) - actual_fields if missing_fields: raise ValueError(f"Missing fields in the configuration file: {', '.join(missing_fields)}") @classmethod def _create_nested_config(cls, config_dict: dict, config_name: str, config_cls): nested_config_dict = config_dict.get(config_name, {}) cls._check_missing_fields(nested_config_dict, config_cls.__dataclass_fields__.keys()) return config_cls(**nested_config_dict) @classmethod def _create_config_from_dict(cls, config_dict: dict): cls._check_missing_fields(config_dict, cls.__dataclass_fields__.keys()) # Create nested configs model_config = cls._create_nested_config(config_dict, "model_config", ModelConfig) runtime_config = cls._create_nested_config(config_dict, "runtime_config", RuntimeConfig) engine_config = cls._create_nested_config(config_dict, "engine_config", EngineConfig) return cls(model_config=model_config, runtime_config=runtime_config, engine_config=engine_config) @classmethod def from_json(cls, json_path: str): def simple_json_decoder(dct): dtype_map = {"torch.bfloat16": torch.bfloat16, "torch.float16": torch.float16, "torch.float32": torch.float32} if 'params_dtype' in dct: dct['params_dtype'] = dtype_map[dct['params_dtype']] return dct with open(json_path, "r") as f: config_dict = json.load(f, object_hook=simple_json_decoder) magi_config = cls._create_config_from_dict(config_dict) def post_validation(magi_config): if magi_config.engine_config.fp8_quant or magi_config.engine_config.distill: assert ( magi_config.runtime_config.cfg_number == 1 ), "Please set `cfg_number: 1` in config.json for distill or quant model" else: assert magi_config.runtime_config.cfg_number == 3, "Please set `cfg_number: 3` in config.json for base model" post_validation(magi_config) return magi_config def to_json(self, json_path: str): class SimpleJSONEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, torch.dtype): return str(obj) return super().default(obj) # Ensure the directory exists os.makedirs(os.path.dirname(json_path), exist_ok=True) config_dict = { "model_config": dataclasses.asdict(self.model_config), "runtime_config": dataclasses.asdict(self.runtime_config), "engine_config": dataclasses.asdict(self.engine_config), } with open(json_path, "w") as f: json.dump(config_dict, f, indent=4, cls=SimpleJSONEncoder) ``` ## /inference/common/dataclass.py ```py path="/inference/common/dataclass.py" # Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import List import numpy as np import torch @dataclass(frozen=True) class PackedCoreAttnParams: # Packed sequence parameters for core_attn q_range: torch.Tensor k_range: torch.Tensor np_q_range: np.ndarray np_k_range: np.ndarray max_seqlen_q: int max_seqlen_k: int @dataclass(frozen=True) class PackedCrossAttnParams: # Packed sequence parameters for cross_attn q_ranges: torch.Tensor = None kv_ranges: torch.Tensor = None cu_seqlens_q: torch.Tensor = None cu_seqlens_kv: torch.Tensor = None max_seqlen_q: int = None max_seqlen_kv: int = None @dataclass(frozen=True) class ModelMetaArgs: H: int W: int cp_pad_size: int cp_split_sizes: List[int] slice_point: int denoising_range_num: int range_num: int extract_prefix_video_feature: bool fwd_extra_1st_chunk: bool distill_nearly_clean_chunk: bool clip_token_nums: int enable_cuda_graph: bool core_attn_params: PackedCoreAttnParams cross_attn_params: PackedCrossAttnParams class InferenceParams: """Inference parameters that are passed to the main model in order to efficienly calculate and store the context during inference.""" def __init__(self, max_batch_size, max_sequence_length): self.max_sequence_length = max_sequence_length self.max_batch_size = max_batch_size self.sequence_len_offset = 0 self.key_value_memory_dict = {} self.update_kv_cache = False def swap_key_value_dict(self, batch_idx): "swap between batches" if len(self.key_value_memory_dict) == 0: raise ValueError("should not swap when dict in empty") for layer_number in self.key_value_memory_dict.keys(): inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number] assert len(batch_idx) == inference_key_memory.shape[1] # make sure batch size is the same new_inference_key_memory = inference_key_memory[:, batch_idx] new_inference_value_memory = inference_value_memory[:, batch_idx] self.key_value_memory_dict[layer_number] = (new_inference_key_memory, new_inference_value_memory) ``` ## /inference/common/logger.py ```py path="/inference/common/logger.py" # Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
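# Logging utilities for MAGI inference: GlobalLogger lazily creates a single
# "magi_logger" with a StreamHandler; print_per_rank logs on every rank, while
# print_rank_0 logs only on rank 0 when torch.distributed is initialized
# (and unconditionally otherwise).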
import logging import torch class GlobalLogger: _logger = None @classmethod def get_logger(cls, name=__name__, level=logging.INFO): if cls._logger is None: cls._logger = logging.getLogger("magi_logger") cls._logger.setLevel(logging.INFO) cls._logger.propagate = False cls._logger.handlers.clear() formatter = logging.Formatter("[%(asctime)s - %(levelname)s] %(message)s") handler = logging.StreamHandler() handler.setFormatter(formatter) cls._logger.addHandler(handler) return cls._logger magi_logger = GlobalLogger.get_logger() def print_per_rank(message): magi_logger.info(message) def print_rank_0(message): if torch.distributed.is_initialized(): if torch.distributed.get_rank() == 0: magi_logger.info(message) else: magi_logger.info(message) ``` ## /inference/common/timer.py ```py path="/inference/common/timer.py" # Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from datetime import datetime import torch from .logger import print_rank_0 class EventPathTimer: """ A lightweight class for recording time without any distributed barrier. This class allows for recording elapsed time between events without requiring synchronization across distributed processes. It maintains the previous message and time to calculate the duration between consecutive records. """ def __init__(self): """ Initialize the EventPathTimer. This constructor sets the previous message and time to None, preparing the instance for recording events. """ self.prev_message: str = None self.prev_time: datetime = None def reset(self): """ Reset the recorded message and time. This method clears the previous message and time, allowing for a fresh start in recording new events. """ self.prev_message = None self.prev_time = None def synced_record(self, message): """ Record the current time with a message. Args: message (str): A message to log along with the current time. This method synchronizes the CUDA operations, records the current time, and calculates the elapsed time since the last recorded message, if any. It then logs the elapsed time along with the previous and current messages. """ torch.cuda.synchronize() current_time = datetime.now() if self.prev_message is not None: print_rank_0( f"\nTime Elapsed: [{current_time - self.prev_time}] From [{self.prev_message} ({self.prev_time})] To [{message} ({current_time})]" ) self.prev_message = message self.prev_time = current_time _GLOBAL_LIGHT_TIMER = EventPathTimer() def event_path_timer() -> EventPathTimer: """Get the current EventPathTimer instance. Returns: EventPathTimer: The current EventPathTimer instance. Raises: AssertionError: If the EventPathTimer has not been initialized. """ assert _GLOBAL_LIGHT_TIMER is not None, "light time recorder is not initialized" return _GLOBAL_LIGHT_TIMER ``` ## /inference/infra/checkpoint/__init__.py ```py path="/inference/infra/checkpoint/__init__.py" # Copyright (c) 2025 SandAI. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from .checkpointing import load_checkpoint __all__ = ["load_checkpoint"] ``` ## /inference/infra/checkpoint/checkpointing.py ```py path="/inference/infra/checkpoint/checkpointing.py" # Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import io import json import os import re import subprocess from collections import OrderedDict from concurrent.futures import ThreadPoolExecutor from datetime import datetime import numpy as np import torch import torch.distributed from safetensors.torch import load as load_from_bytes from safetensors.torch import load_file from tqdm.auto import tqdm import inference.infra.distributed.parallel_state as mpu from inference.common import EngineConfig, ModelConfig, RuntimeConfig, print_per_rank, print_rank_0 def _load_shard(shard_path, param_names, num_threads=None): zstd_path = shard_path + ".zst" if os.path.exists(zstd_path): start_time = datetime.now() print_per_rank(f"Decompressing {zstd_path} with {num_threads} threads") cmd = ["zstd", "-d"] if num_threads: cmd.extend(["-T", str(num_threads)]) process = subprocess.Popen(cmd + ["-c", zstd_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=-1) decompressed_data = process.stdout.read() process.stdout.close() retcode = process.wait() if retcode != 0: raise RuntimeError(f"Decompression failed: {process.stderr.read().decode()}") print_per_rank( f"Decompressed {zstd_path} with {num_threads} threads, duration: {(datetime.now() - start_time).total_seconds()}s" ) buffer = io.BytesIO(decompressed_data) start_time = datetime.now() print_per_rank(f"Loading {shard_path} from zstd file, start time: {start_time}") weights = load_from_bytes(buffer.getvalue()) print_per_rank(f"Loaded {shard_path} from zstd file, duration: {(datetime.now() - start_time).total_seconds()}s") buffer.close() else: weights = load_file(shard_path) return {name: weights[name] for name in param_names} def load_sharded_safetensors_parallel_with_progress(checkpoint_dir): index_path = os.path.join(checkpoint_dir, "model.safetensors.index.json") with open(index_path, "r") as f: index = json.load(f) state_dict = {} shard_map = {} # Group parameters by shard file for param_name, shard_file in index["weight_map"].items(): shard_path = os.path.join(checkpoint_dir, shard_file) if shard_path not in shard_map: shard_map[shard_path] = [] shard_map[shard_path].append(param_name) # Load shards in parallel with a progress bar with ThreadPoolExecutor() as executor: futures = { 
executor.submit(_load_shard, shard_path, param_names): shard_path for shard_path, param_names in shard_map.items() } pbar = tqdm(futures, desc="Loading shards", total=len(futures)) for future in pbar: result = future.result() state_dict.update(result) return state_dict def unwrap_model(model): return_list = True if not isinstance(model, list): model = [model] return_list = False unwrapped_model = [] for model_module in model: while hasattr(model_module, "module"): model_module = model_module.module unwrapped_model.append(model_module) if not return_list: return unwrapped_model[0] return unwrapped_model def _split_state_dict_for_pp(weight_dict: OrderedDict, model_config: ModelConfig): num_layers = model_config.num_layers partition = mpu.get_pp_world_size() ## use partition and num_layers to get current rank layer order layers_for_each_stage = np.array_split(range(num_layers), partition) current_stage = mpu.get_pp_rank() allow_layer_num = layers_for_each_stage[current_stage] layer_offset = allow_layer_num[0] new_weight_dict = {} for k, v in weight_dict.items(): if "videodit_blocks.layers" in k: layer_num = int(re.search(r"videodit_blocks\.layers\.(\d+)", k).group(1)) if layer_num not in allow_layer_num: continue ## replace the old key name by new layer number new_layer_num = layer_num - layer_offset new_k = k.replace(f"videodit_blocks.layers.{layer_num}", f"videodit_blocks.layers.{new_layer_num}") new_weight_dict[new_k] = v else: new_weight_dict[k] = v return new_weight_dict def load_state_dict(runtime_config: RuntimeConfig, engine_config: EngineConfig): load_dir = runtime_config.load default_subdir = "inference_weight" if engine_config.fp8_quant: default_subdir = f"{default_subdir}.fp8" if engine_config.distill: default_subdir = f"{default_subdir}.distill" inference_weight_dir = os.path.join(load_dir, default_subdir) assert os.path.exists(inference_weight_dir) print_rank_0(f"load {default_subdir} weight from {inference_weight_dir}") assert ( os.path.exists(inference_weight_dir) and len(os.listdir(inference_weight_dir)) > 0 ), f"Ckpt directory {inference_weight_dir} does not exist or empty. If you are using fp8_quant, please run calibration first." state_dict = load_sharded_safetensors_parallel_with_progress(inference_weight_dir) return state_dict def load_checkpoint(model): state_dict = load_state_dict(model.runtime_config, model.engine_config) model = unwrap_model(model) # if we use pipeline parallelism, we need to load the state dict for each stage # as it always record layer from 0 -> num_layers//pipeline_parallel_size # so we need to choose correct layer weight when load_state_dict if mpu.get_pp_world_size() > 1: state_dict = _split_state_dict_for_pp(state_dict, model.model_config) missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True) model.cuda(torch.cuda.current_device()) if mpu.get_pp_world_size() > 1: rank_msg = f"CP_rank={mpu.get_cp_rank()} PP_rank={mpu.get_pp_rank()}" print_per_rank( f"""[{rank_msg}] Load Weight Missing Keys: {missing_keys} Load Weight Unexpected Keys: {unexpected_keys} You should see message [missing fianl layer norm weight] except the final pipeline stage""" ) else: print_rank_0(f"Load Weight Missing Keys: {missing_keys}") print_rank_0(f"Load Weight Unexpected Keys: {unexpected_keys}") return model ``` ## /inference/infra/distributed/__init__.py ```py path="/inference/infra/distributed/__init__.py" # Copyright (c) 2025 SandAI. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from .dist_utils import dist_init, get_device, get_world_size, is_last_rank, is_last_tp_cp_rank from .parallel_state import ( destroy_model_parallel, get_cp_group, get_cp_rank, get_cp_world_size, get_dp_group, get_dp_group_gloo, get_dp_rank, get_dp_world_size, get_pipeline_model_parallel_first_rank, get_pipeline_model_parallel_last_rank, get_pipeline_model_parallel_next_rank, get_pipeline_model_parallel_prev_rank, get_pp_group, get_pp_rank, get_pp_world_size, get_tensor_model_parallel_last_rank, get_tensor_model_parallel_ranks, get_tensor_model_parallel_src_rank, get_tp_group, get_tp_rank, get_tp_world_size, is_initialized, is_pipeline_first_stage, is_pipeline_last_stage, ) __all__ = [ "dist_init", "is_initialized", "get_tp_group", "get_pp_group", "get_dp_group", "get_dp_group_gloo", "get_cp_group", "get_tp_world_size", "get_pp_world_size", "get_dp_world_size", "get_cp_world_size", "get_tp_rank", "get_pp_rank", "get_dp_rank", "get_cp_rank", "is_pipeline_first_stage", "is_pipeline_last_stage", "get_tensor_model_parallel_src_rank", "get_tensor_model_parallel_ranks", "get_tensor_model_parallel_last_rank", "get_pipeline_model_parallel_first_rank", "get_pipeline_model_parallel_last_rank", "get_pipeline_model_parallel_next_rank", "get_pipeline_model_parallel_prev_rank", "destroy_model_parallel", "is_last_rank", "is_last_tp_cp_rank", "get_world_size", "get_device", ] ``` ## /inference/infra/distributed/dist_utils.py ```py path="/inference/infra/distributed/dist_utils.py" # Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from datetime import timedelta import torch import inference.infra.distributed.parallel_state as mpu from inference.common import print_rank_0 from inference.infra.parallelism.pipeline_parallel import init_pp_scheduler from . import parallel_state as mpu def dist_init(config): """Initialize torch.distributed and core model parallel.""" assert torch.cuda.is_available() device_count = torch.cuda.device_count() if torch.distributed.is_initialized(): print_rank_0("Torch distribution already initialized, skipping initialization ...") else: rank = int(os.getenv("RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "1")) # Manually set the device ids. 
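        # Each rank binds to cuda:(rank % device_count) before the process group is initialized.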
if device_count > 0: device = rank % device_count torch.cuda.set_device(device) # Call the init process torch.distributed.init_process_group( backend=config.engine_config.distributed_backend, world_size=world_size, rank=rank, timeout=timedelta(minutes=config.engine_config.distributed_timeout_minutes), ) assert config.engine_config.cp_size * config.engine_config.pp_size == torch.distributed.get_world_size() if device_count > 0: if mpu.model_parallel_is_initialized(): print_rank_0("Model parallel is already initialized") else: mpu.initialize_model_parallel( cp_size=config.engine_config.cp_size, pp_size=config.engine_config.pp_size, nccl_communicator_config_path=None, distributed_timeout_minutes=config.engine_config.distributed_timeout_minutes, order="tp-cp-pp-dp", ) if mpu.get_pp_world_size() > 1: init_pp_scheduler() print_rank_0("Initialize torch distribution and model parallel successfully") def is_last_rank(): return torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1) def is_last_tp_cp_rank(): return mpu.get_tp_rank(with_context_parallel=True) == mpu.get_tp_world_size(with_context_parallel=True) - 1 def get_world_size(): if torch.distributed.is_available() and torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() else: world_size = 1 return world_size def get_device(local_rank=None): backend = torch.distributed.get_backend() if backend == "nccl": if local_rank is None: device = torch.device("cuda") else: device = torch.device(f"cuda:{local_rank}") elif backend == "gloo": device = torch.device("cpu") else: raise RuntimeError return device ``` ## /inference/infra/distributed/parallel_state.py ```py path="/inference/infra/distributed/parallel_state.py" # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Model and data parallel groups.""" import warnings from datetime import timedelta from typing import List, Optional import torch # Intra-layer model parallel group that the current rank belongs to. _TENSOR_MODEL_PARALLEL_GROUP = None # Tensor parallel group information with context parallel combined. _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP = None _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP = None # Inter-layer model parallel group that the current rank belongs to. _PIPELINE_MODEL_PARALLEL_GROUP = None # Model parallel group (both intra- and pipeline) that the current rank belongs to. _MODEL_PARALLEL_GROUP = None # Data parallel group that the current rank belongs to. _DATA_PARALLEL_GROUP = None _DATA_PARALLEL_GROUP_GLOO = None # tensor model parallel group and data parallel group combined # used for fp8 and moe training _TENSOR_AND_DATA_PARALLEL_GROUP = None # A list of global ranks for each pipeline group to ease calculation of the source # rank when broadcasting from the first or last pipeline stage. 
_PIPELINE_GLOBAL_RANKS = None # A list of global ranks for each data parallel group to ease calculation of the source # rank when broadcasting weights from src to all other data parallel ranks _DATA_PARALLEL_GLOBAL_RANKS = None # A list of global ranks for each tensor model parallel group to ease calculation of # the first local rank in the tensor model parallel group _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = None # Context parallel group that the current rank belongs to _CONTEXT_PARALLEL_GROUP = None # A list of global ranks for each context parallel group to ease calculation of the # destination rank when exchanging KV/dKV between context parallel_ranks _CONTEXT_PARALLEL_GLOBAL_RANKS = None # Data parallel group information with context parallel combined. _DATA_PARALLEL_GROUP_WITH_CP = None _DATA_PARALLEL_GROUP_WITH_CP_GLOO = None _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None # combined parallel group of TP, DP, and CP used for fp8 _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None def get_nccl_options(pg_name, nccl_comm_cfgs): """Set the NCCL process group options. Args: pg_name (str): process group name nccl_comm_cfgs (dict): nccl communicator configurations When an option (e.g., max_ctas) is not found in the config, use the NCCL default setting. """ if pg_name in nccl_comm_cfgs: nccl_options = torch.distributed.ProcessGroupNCCL.Options() nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name].get("cga_cluster_size", 4) nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name].get("max_ctas", 32) nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name].get("min_ctas", 1) return nccl_options else: return None def generate_masked_orthogonal_rank_groups(world_size: int, parallel_size: List[int], mask: List[bool]) -> List[List[int]]: """Generate orthogonal parallel groups based on the parallel size and mask. Arguments: world_size (int): world size parallel_size (List[int]): The parallel size of each orthogonal parallel type. For example, if tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. mask (List[bool]): The mask controls which parallel methods the generated groups represent. If mask[i] is True, it means the generated group contains the i-th parallelism method. For example, if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then the generated group is the `tp-dp` group, if the mask = [False, True, False], then the generated group is the `pp` group. Algorithm: For orthogonal parallelism, such as tp/dp/pp/cp, the global_rank and local_rank satisfy the following equation: global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size (1) tp_rank \in [0, tp_size) dp_rank \in [0, dp_size) pp_rank \in [0, pp_size) If we want to get the `dp_group` (tp_size * pp_size groups of dp_size ranks each. For example, if the gpu size is 8 and order is 'tp-pp-dp', size is '2-2-2', and the dp_group here is [[0, 4], [1, 5], [2, 6], [3, 7]].) The tp_rank and pp_rank will be combined to form the `dp_group_index`. dp_group_index = tp_rank + pp_rank * tp_size (2) So, Given that tp_rank and pp_rank satisfy equation (2), and dp_rank in range(0, dp_size), the ranks in dp_group[dp_group_index] satisfies the equation (1). This function solve this math problem. For example, if the parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4], and the mask = [False, True, False]. 
Then, dp_group_index(0) = tp_rank(0) + pp_rank(0) * 2 dp_group_index(1) = tp_rank(1) + pp_rank(0) * 2 ... dp_group_index(7) = tp_rank(1) + pp_rank(3) * 2 dp_group[0] = 0 + range(0, 3) * 2 + 0 = [0, 2, 4] dp_group[1] = 1 + range(0, 3) * 2 + 0 = [1, 3, 5] ... dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23] """ def prefix_product(a: List[int], init=1) -> List[int]: r = [init] for v in a: init = init * v r.append(init) return r def inner_product(a: List[int], b: List[int]) -> int: return sum([x * y for x, y in zip(a, b)]) def decompose(index, shape, stride=None): """ This function solve the math problem below: There is an equation: index = sum(idx[i] * stride[i]) And given the value of index, stride. Return the idx. This function will used to get the pp/dp/pp_rank from group_index and rank_in_group. """ if stride is None: stride = prefix_product(shape) idx = [(index // d) % s for s, d in zip(shape, stride)] # stride is a prefix_product result. And the value of stride[-1] # is not used. assert ( sum([x * y for x, y in zip(idx, stride[:-1])]) == index ), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx) return idx masked_shape = [s for s, m in zip(parallel_size, mask) if m] unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m] global_stride = prefix_product(parallel_size) masked_stride = [d for d, m in zip(global_stride, mask) if m] unmasked_stride = [d for d, m in zip(global_stride, mask) if not m] group_size = prefix_product(masked_shape)[-1] num_of_group = world_size // group_size ranks = [] for group_index in range(num_of_group): # get indices from unmaksed for group_index. decomposed_group_idx = decompose(group_index, unmasked_shape) rank = [] for rank_in_group in range(group_size): # get indices from masked for rank_in_group. decomposed_rank_idx = decompose(rank_in_group, masked_shape) rank.append( inner_product(decomposed_rank_idx, masked_stride) + inner_product(decomposed_group_idx, unmasked_stride) ) ranks.append(rank) return ranks class RankGenerator(object): def __init__(self, tp: int, dp: int, pp: int, cp: int, order: str) -> None: self.tp = tp self.dp = dp self.pp = pp self.cp = cp self.world_size = tp * dp * pp * cp self.name_to_size = {"tp": self.tp, "pp": self.pp, "dp": self.dp, "cp": self.cp} order = order.lower() for name in self.name_to_size.keys(): if name not in order and self.name_to_size[name] != 1: raise RuntimeError( f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({self.order})." ) elif name not in order: order = order + "-" + name self.order = order self.ordered_size = [self.name_to_size[token] for token in order.split("-")] def get_mask(self, order: str, token: str): ordered_token = order.split("-") token = token.split("-") mask = [False] * len(ordered_token) for t in token: mask[ordered_token.index(t)] = True return mask def get_ranks(self, token): """Get rank group by input token. Arguments: token (str): Specify the ranks type that want to get. If we want to obtain multiple parallel types, we can use a hyphen '-' to separate them. For example, if we want to obtain the TP_DP group, the token should be 'tp-dp'. 
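Example (hypothetical sizes, not taken from a real config): with tp=1, cp=2,
pp=2, dp=1 and order 'tp-cp-pp-dp' on 4 ranks, get_ranks('cp') returns
[[0, 1], [2, 3]], get_ranks('pp') returns [[0, 2], [1, 3]], and
get_ranks('dp') returns [[0], [1], [2], [3]].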
""" mask = self.get_mask(self.order, token) ranks = generate_masked_orthogonal_rank_groups(self.world_size, self.ordered_size, mask) return ranks def initialize_model_parallel( tp_size: int = 1, pp_size: int = 1, cp_size: int = 1, nccl_communicator_config_path: Optional[str] = None, distributed_timeout_minutes: int = 30, order: str = "tp-cp-pp-dp", ) -> None: """Initialize model data parallel groups. Borrow from: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py Args: tp_size (int, default = 1): The number of GPUs to split individual tensors across. pp_size (int, default = 1): The number of tensor parallel GPU groups to split the Transformer layers across. For example, if tp_size is 4 and pp_size is 2, the model will be split into 2 groups of 4 GPUs. cp_size (int, default = 1): The number of tensor parallel GPU groups to split the network input sequence length across. Compute of attention module requires tokens of full sequence length, so GPUs in a context parallel group need to communicate with each other to exchange information of other sequence chunks. Each GPU and its counterparts in other tensor parallel groups compose a context parallel group. For example, assume we have 8 GPUs, if tensor model parallel size is 4 and context parallel size is 2, the network input will be split into two sequence chunks, which are processed by 2 different groups of 4 GPUs. One chunk is processed by GPU0-3, the other chunk is processed by GPU4-7. Four groups are build to do context parallel communications: [GPU0, GPU4], [GPU1, GPU5], [GPU2, GPU6], and [GPU3, GPU7]. Context parallelism partitions sequence length, so it has no impact on weights, which means weights are duplicated among GPUs in a context parallel group. Hence, weight gradients all-reduce is required in backward. For simplicity, we piggyback GPUs of context parallelism on data parallel group for weight gradient all-reduce. nccl_communicator_config_path (str, default = None): Path to the yaml file of NCCL communicator configurations. `min_ctas`, `max_ctas`, and `cga_cluster_size` can be set for each communicator. distributed_timeout_minutes (int, default = 30): Timeout, in minutes,for operations executed against distributed process groups. See PyTorch documentation at https://pytorch.org/docs/stable/distributed.html for caveats. order (str, default=tp-dp-pp): The rank initialization order of parallelism. Now we support tp-dp-pp and tp-pp-dp orders. Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize the model pipeline. The present function will create 8 tensor model-parallel groups, 4 pipeline model-parallel groups and 8 data-parallel groups as: 8 data_parallel groups: [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15] 8 tensor model-parallel groups: [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15] 4 pipeline model-parallel groups: [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15] Note that for efficiency, the caller should make sure adjacent ranks are on the same DGX box. For example if we are using 2 DGX-1 boxes with a total of 16 GPUs, rank 0 to 7 belong to the first box and ranks 8 to 15 belong to the second box. """ # Get world size and rank. Ensure some consistencies. 
assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() if world_size % (tp_size * pp_size * cp_size) != 0: raise RuntimeError( f"world_size ({world_size}) is not divisible by tp_size " f"({tp_size}) x pp_size ({pp_size}) " f"x cp_size ({cp_size})" ) nccl_comm_cfgs = {} if nccl_communicator_config_path is not None: try: import yaml except ImportError: raise RuntimeError("Cannot import `yaml`. Setting custom nccl communicator configs " "requires the yaml package.") with open(nccl_communicator_config_path, "r") as stream: nccl_comm_cfgs = yaml.safe_load(stream) dp_size: int = world_size // (tp_size * pp_size * cp_size) rank = torch.distributed.get_rank() rank_generator = RankGenerator(tp=tp_size, dp=dp_size, pp=pp_size, cp=cp_size, order=order) timeout = timedelta(minutes=distributed_timeout_minutes) # Build the data-parallel groups. global _DATA_PARALLEL_GROUP global _DATA_PARALLEL_GROUP_GLOO global _DATA_PARALLEL_GLOBAL_RANKS global _DATA_PARALLEL_GROUP_WITH_CP global _DATA_PARALLEL_GROUP_WITH_CP_GLOO global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" for ranks in rank_generator.get_ranks("dp"): group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=get_nccl_options("dp", nccl_comm_cfgs)) group_gloo = torch.distributed.new_group(ranks, timeout=timeout, backend="gloo") if rank in ranks: _DATA_PARALLEL_GROUP = group _DATA_PARALLEL_GROUP_GLOO = group_gloo _DATA_PARALLEL_GLOBAL_RANKS = ranks for ranks_with_cp in rank_generator.get_ranks("dp-cp"): group_with_cp = torch.distributed.new_group( ranks_with_cp, timeout=timeout, pg_options=get_nccl_options("dp_cp", nccl_comm_cfgs) ) group_with_cp_gloo = torch.distributed.new_group(ranks_with_cp, timeout=timeout, backend="gloo") if rank in ranks_with_cp: _DATA_PARALLEL_GROUP_WITH_CP = group_with_cp _DATA_PARALLEL_GROUP_WITH_CP_GLOO = group_with_cp_gloo _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp # Build the context-parallel groups. global _CONTEXT_PARALLEL_GROUP global _CONTEXT_PARALLEL_GLOBAL_RANKS assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized" for ranks in rank_generator.get_ranks("cp"): group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=get_nccl_options("cp", nccl_comm_cfgs)) if rank in ranks: _CONTEXT_PARALLEL_GROUP = group _CONTEXT_PARALLEL_GLOBAL_RANKS = ranks # Build the model-parallel groups. global _MODEL_PARALLEL_GROUP assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" for ranks in rank_generator.get_ranks("tp-pp"): group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=get_nccl_options("mp", nccl_comm_cfgs)) if rank in ranks: _MODEL_PARALLEL_GROUP = group # Build the tensor model-parallel groups. global _TENSOR_MODEL_PARALLEL_GROUP global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS assert _TENSOR_MODEL_PARALLEL_GROUP is None, "tensor model parallel group is already initialized" for ranks in rank_generator.get_ranks("tp"): group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=get_nccl_options("tp", nccl_comm_cfgs)) if rank in ranks: _TENSOR_MODEL_PARALLEL_GROUP = group _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = ranks # Build the tensor + context parallel groups. 
global _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP assert ( _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP is None ), "tensor model parallel group with context parallel is already initialized" for ranks in rank_generator.get_ranks("tp-cp"): group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=get_nccl_options("tp_cp", nccl_comm_cfgs)) if rank in ranks: _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP = group _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks # Build the pipeline model-parallel groups global _PIPELINE_MODEL_PARALLEL_GROUP global _PIPELINE_GLOBAL_RANKS assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized" for ranks in rank_generator.get_ranks("pp"): group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=get_nccl_options("pp", nccl_comm_cfgs)) if rank in ranks: _PIPELINE_MODEL_PARALLEL_GROUP = group _PIPELINE_GLOBAL_RANKS = ranks # Build the tensor + data parallel groups. global _TENSOR_AND_DATA_PARALLEL_GROUP global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP assert _TENSOR_AND_DATA_PARALLEL_GROUP is None, "Tensor + data parallel group is already initialized" for ranks in rank_generator.get_ranks("tp-cp-dp"): group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=get_nccl_options("tp_cp_dp", nccl_comm_cfgs)) if rank in ranks: _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = group for ranks in rank_generator.get_ranks("tp-dp"): group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=get_nccl_options("tp_dp", nccl_comm_cfgs)) if rank in ranks: _TENSOR_AND_DATA_PARALLEL_GROUP = group def is_initialized(): """Useful for code segments that may be accessed with or without mpu initialization""" return _DATA_PARALLEL_GROUP is not None def is_unitialized() -> bool: """Check if parallel state has been initialized Deprecated. Use is_initialized instead. 
""" warnings.warn("is_unitialized is deprecated, use is_initialized instead", DeprecationWarning) return not is_initialized() def model_parallel_is_initialized(): """Check if model and data parallel groups are initialized.""" if _TENSOR_MODEL_PARALLEL_GROUP is None or _PIPELINE_MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: return False return True def get_model_parallel_group(): """Get the model parallel group the caller rank belongs to.""" assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" return _MODEL_PARALLEL_GROUP def get_tp_group(check_initialized=True, with_context_parallel=False): """Get the tensor model parallel group the caller rank belongs to.""" if check_initialized: assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "tensor model parallel group is not initialized" if with_context_parallel: assert ( _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP is not None ), "tensor model parallel group with context parallel combined is not initialized" return _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP else: assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "tensor model parallel group is not initialized" return _TENSOR_MODEL_PARALLEL_GROUP def get_pp_group(): """Get the pipeline model parallel group the caller rank belongs to.""" assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, "pipeline_model parallel group is not initialized" return _PIPELINE_MODEL_PARALLEL_GROUP def get_dp_group(with_context_parallel=False): """Get the data parallel group the caller rank belongs to.""" if with_context_parallel: assert ( _DATA_PARALLEL_GROUP_WITH_CP is not None ), "data parallel group with context parallel combined is not initialized" return _DATA_PARALLEL_GROUP_WITH_CP else: assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" return _DATA_PARALLEL_GROUP def get_dp_group_gloo(with_context_parallel=False): """Get the data parallel group-gloo the caller rank belongs to.""" if with_context_parallel: assert ( _DATA_PARALLEL_GROUP_WITH_CP_GLOO is not None ), "data parallel group-gloo with context parallel combined is not initialized" return _DATA_PARALLEL_GROUP_WITH_CP_GLOO else: assert _DATA_PARALLEL_GROUP_GLOO is not None, "data parallel group-gloo is not initialized" return _DATA_PARALLEL_GROUP_GLOO def get_cp_group(check_initialized=True): """Get the context parallel group the caller rank belongs to.""" if check_initialized: assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized" return _CONTEXT_PARALLEL_GROUP def get_tp_world_size(with_context_parallel=False): """Return world size for the tensor model parallel group.""" return torch.distributed.get_world_size(group=get_tp_group(with_context_parallel=with_context_parallel)) def get_pp_world_size(): """Return world size for the pipeline model parallel group.""" return torch.distributed.get_world_size(group=get_pp_group()) def get_tp_rank(with_context_parallel=False): """Return my rank for the tensor model parallel group.""" return torch.distributed.get_rank(group=get_tp_group(with_context_parallel=with_context_parallel)) def get_pp_rank(): """Return my rank for the pipeline model parallel group.""" return torch.distributed.get_rank(group=get_pp_group()) def is_pipeline_first_stage(): """Return True if in the first pipeline model-parallel stage, False otherwise.""" return get_pp_rank() == 0 def is_pipeline_last_stage(): """Return True if in the last pipeline model-parallel stage, False otherwise.""" return get_pp_rank() == (get_pp_world_size() - 1) def 
get_tensor_model_parallel_src_rank(with_context_parallel=False): """Calculate the global rank corresponding to the first local rank in the tensor model parallel group.""" assert _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS is not None, "Tensor model parallel group is not initialized" if with_context_parallel: assert ( _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP is not None ), "Tensor model parallel group with context parallel combined is not initialized" return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP[0] else: return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS[0] def get_tensor_model_parallel_ranks(with_context_parallel=False): """Return all global ranks for the tensor model parallel group.""" if with_context_parallel: assert ( _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP is not None ), "Tensor model parallel group with context parallel combined is not initialized" return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP else: assert _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS is not None, "Tensor model parallel group is not initialized" return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS def get_tensor_model_parallel_last_rank(with_context_parallel=False): """Calculate the global rank corresponding to the first local rank in the tensor model parallel group.""" assert _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS is not None, "Tensor model parallel group is not initialized" if with_context_parallel: assert ( _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP is not None ), "Tensor model parallel group with context parallel combined is not initialized" return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP[-1] else: return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS[-1] def get_pipeline_model_parallel_first_rank(): """Return the global rank of the first process in the pipeline for the current tensor parallel group""" assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" return _PIPELINE_GLOBAL_RANKS[0] def get_pipeline_model_parallel_last_rank(): """Return the global rank of the last process in the pipeline for the current tensor parallel group""" assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" last_rank_local = get_pp_world_size() - 1 return _PIPELINE_GLOBAL_RANKS[last_rank_local] def get_pipeline_model_parallel_next_rank(): """Return the global rank that follows the caller in the pipeline""" assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" rank_in_pipeline = get_pp_rank() world_size = get_pp_world_size() return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] def get_pipeline_model_parallel_prev_rank(): """Return the global rank that preceeds the caller in the pipeline""" assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" rank_in_pipeline = get_pp_rank() world_size = get_pp_world_size() return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] def get_dp_world_size(with_context_parallel=False): """Return world size for the data parallel group.""" if torch.distributed.is_available() and torch.distributed.is_initialized(): return torch.distributed.get_world_size(group=get_dp_group(with_context_parallel=with_context_parallel)) else: return 0 def get_dp_rank(with_context_parallel=False): """Return my rank for the data parallel group.""" if torch.distributed.is_available() and torch.distributed.is_initialized(): return torch.distributed.get_rank(group=get_dp_group(with_context_parallel=with_context_parallel)) else: return 0 def get_cp_world_size(): """Return world size for the context 
parallel group.""" if torch.distributed.is_available() and torch.distributed.is_initialized(): return torch.distributed.get_world_size(group=get_cp_group()) else: return 0 def get_cp_rank(): """Return my rank for the context parallel group.""" if torch.distributed.is_available() and torch.distributed.is_initialized(): return torch.distributed.get_rank(group=get_cp_group()) else: return 0 def destroy_model_parallel(): """Set the groups to none.""" global _MODEL_PARALLEL_GROUP _MODEL_PARALLEL_GROUP = None global _TENSOR_MODEL_PARALLEL_GROUP _TENSOR_MODEL_PARALLEL_GROUP = None global _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP = None global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP = None global _PIPELINE_MODEL_PARALLEL_GROUP _PIPELINE_MODEL_PARALLEL_GROUP = None global _DATA_PARALLEL_GROUP _DATA_PARALLEL_GROUP = None global _DATA_PARALLEL_GROUP_GLOO _DATA_PARALLEL_GROUP_GLOO = None global _TENSOR_AND_DATA_PARALLEL_GROUP _TENSOR_AND_DATA_PARALLEL_GROUP = None global _PIPELINE_GLOBAL_RANKS _PIPELINE_GLOBAL_RANKS = None global _DATA_PARALLEL_GLOBAL_RANKS _DATA_PARALLEL_GLOBAL_RANKS = None global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = None global _CONTEXT_PARALLEL_GROUP _CONTEXT_PARALLEL_GROUP = None global _CONTEXT_PARALLEL_GLOBAL_RANKS _CONTEXT_PARALLEL_GLOBAL_RANKS = None global _DATA_PARALLEL_GROUP_WITH_CP _DATA_PARALLEL_GROUP_WITH_CP = None global _DATA_PARALLEL_GROUP_WITH_CP_GLOO _DATA_PARALLEL_GROUP_WITH_CP_GLOO = None global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None ``` ## /inference/infra/parallelism/__init__.py ```py path="/inference/infra/parallelism/__init__.py" # Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from .context_parallel import CSOHelper, UlyssesScheduler, cp_post_process, cp_pre_process, cso_communication from .pipeline_parallel import pp_scheduler from .tile_parallel import TileProcessor __all__ = [ "CSOHelper", "cso_communication", "UlyssesScheduler", "pp_scheduler", "TileProcessor", "cp_pre_process", "cp_post_process", ] ``` ## /inference/infra/parallelism/context_parallel.py ```py path="/inference/infra/parallelism/context_parallel.py" # Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import math from typing import Callable, List, Tuple, Union import torch import torch.distributed from einops import rearrange from inference.common import ModelMetaArgs, PackedCoreAttnParams, PackedCrossAttnParams, divide from inference.infra.distributed import parallel_state as mpu ##################################################### # Common Primitives ##################################################### def scatter_to_context_parallel_region(input_, cp_split_sizes, cp_shuffle_num=1, cp_pad_size=0): """Split the tensor along its first dimension and keep the corresponding slice.""" world_size = mpu.get_cp_world_size() # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ # Split along first dimension with padding. rank = mpu.get_cp_rank() if cp_shuffle_num > 1: cp_pad_size = divide(cp_pad_size, cp_shuffle_num) cp_split_sizes = [divide(s, cp_shuffle_num) for s in cp_split_sizes] dim_offset = sum(cp_split_sizes[:rank]) xs = [] for x in torch.chunk(input_, cp_shuffle_num, dim=0): x = torch.nn.functional.pad(x, [0, 0] * (x.dim() - 1) + [0, cp_pad_size], mode="constant", value=0) xs.append(x[dim_offset : dim_offset + cp_split_sizes[rank]]) output = torch.concat(xs, dim=0) else: dim_offset = sum(cp_split_sizes[:rank]) x = torch.nn.functional.pad(input_, [0, 0] * (input_.dim() - 1) + [0, cp_pad_size], mode="constant", value=0) output = x[dim_offset : dim_offset + cp_split_sizes[rank]].contiguous() return output def gather_from_context_parallel_region(input_, cp_split_sizes, cp_shuffle_num=1, cp_pad_size=0): """Gather tensors and concatinate along the first dimension.""" world_size = mpu.get_cp_world_size() # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ input_ = input_.contiguous() total_seq_len = sum(cp_split_sizes) dim_size = list(input_.size()) dim_size[0] = total_seq_len output = torch.empty(dim_size, dtype=input_.dtype, device=input_.device) outputs = list(torch.split(output, cp_split_sizes, dim=0)) torch.distributed.all_gather(outputs, input_, group=mpu.get_cp_group()) if cp_shuffle_num > 1: total_seq_len = divide(total_seq_len, cp_shuffle_num) cp_pad_size = divide(cp_pad_size, cp_shuffle_num) chunks = [torch.chunk(o, cp_shuffle_num, dim=0) for o in outputs] output = torch.concat( [ torch.concat([chunk[i] for chunk in chunks], dim=0)[: total_seq_len - cp_pad_size] for i in range(cp_shuffle_num) ], dim=0, ) else: output = torch.concat(outputs, dim=0)[: total_seq_len - cp_pad_size] return output class FakeHandle: def __init__(self): pass def wait(self): pass ##################################################### # Context Parallel Process ##################################################### def update_packed_seq_params_for_cuda_graph(cross_attn_params: PackedCrossAttnParams, xattn_mask: torch.Tensor): assert xattn_mask is not None # xattn_mask: (N * denoising_range_num, L, 1, 1) xattn_mask = xattn_mask.reshape(xattn_mask.shape[0], -1) batch_size, static_caption_length = xattn_mask.shape # Get index_map for kv_range injection, map y_index to static_caption_length y_index = torch.sum(xattn_mask, dim=-1) cu_seqlens_k = torch.cat([y_index.new_tensor([0]), y_index]).to(torch.int32).to(xattn_mask.device) cu_seqlens_k = cu_seqlens_k.cumsum(-1).to(torch.int32) static_cu_seqlens_k = torch.arange(0, (batch_size + 1) * static_caption_length, static_caption_length) assert cu_seqlens_k.shape[0] == batch_size + 1 == static_cu_seqlens_k.shape[0] start_index_map = dict(zip(cu_seqlens_k.flatten().tolist(), 
static_cu_seqlens_k.flatten().tolist())) # Move kv_range to the right position kv_range_start_list = cross_attn_params.kv_ranges[:, 0].flatten().tolist() static_kv_range_start = [start_index_map[kv_range_start_list[i]] for i in range(len(kv_range_start_list))] static_kv_range_start = torch.tensor(static_kv_range_start, dtype=torch.int32, device=xattn_mask.device) assert static_kv_range_start.shape[0] == cross_attn_params.kv_ranges.shape[0] static_kv_range_diff = cross_attn_params.kv_ranges[:, 1] - cross_attn_params.kv_ranges[:, 0] static_kv_range_end = static_kv_range_start + static_kv_range_diff static_kv_range = torch.stack((static_kv_range_start, static_kv_range_end), dim=1) assert static_kv_range.shape == cross_attn_params.kv_ranges.shape return PackedCrossAttnParams( q_ranges=cross_attn_params.q_ranges, kv_ranges=static_kv_range, cu_seqlens_q=cross_attn_params.cu_seqlens_q, cu_seqlens_kv=cross_attn_params.cu_seqlens_kv, max_seqlen_q=cross_attn_params.max_seqlen_q, max_seqlen_kv=cross_attn_params.max_seqlen_kv, ) def cp_update_cross_attn_qkv_range( cross_attn_params: PackedCrossAttnParams, batch_size: int, cp_split_sizes: List[int], device: torch.device, cp_shuffle_num: int = 1, cp_pad_size: int = 0, ): """ Update cross_attn_params for cross_attn in context parallel. Input: cross_attn_params: PackedCrossAttnParams. Packed sequence parameters for cross_atten batch_size: int. Batch size cp_split_sizes: List[int]. Split sizes for each rank device: torch.device. Device Output: cross_attn_params: PackedCrossAttnParams. Updated packed parameters for cross_atten """ # Update cu_seqlens_q and max_seqlen_q because split x maybe unbalanced cp_rank = mpu.get_cp_rank() seq_len_cur_rank = cp_split_sizes[cp_rank] cp_split_sizes = [divide(x, cp_shuffle_num) for x in cp_split_sizes] cp_split_sizes = torch.tensor(cp_split_sizes, dtype=torch.int32, device=device) base_cp_boundaries = torch.cat((torch.zeros(1, dtype=torch.int32, device=device), cp_split_sizes.cumsum(0))) total_seq_len = base_cp_boundaries[-1] cu_seqlens_q = cross_attn_params.cu_seqlens_q cu_seqlens_k = cross_attn_params.cu_seqlens_kv cu_seqlens_pad = torch.arange(cu_seqlens_q.shape[0], dtype=torch.int32, device=device) * divide( cp_pad_size, cp_shuffle_num ) cu_seqlens_q = cu_seqlens_q + cu_seqlens_pad q_seg_starts, q_seg_ends = cu_seqlens_q[:-1], cu_seqlens_q[1:] xattn_q_ranges, xattn_k_ranges = [], [] for i in range(batch_size): inner_xattn_q_ranges, inner_xattn_k_ranges = [], [] for j in range(cp_shuffle_num): global_offset = i * total_seq_len * cp_shuffle_num + j * total_seq_len cp_boundaries = base_cp_boundaries + global_offset this_cp_start, this_cp_end = (cp_boundaries[cp_rank], cp_boundaries[cp_rank + 1]) q_inter_starts = torch.maximum(this_cp_start, q_seg_starts) q_inter_ends = torch.minimum(this_cp_end, q_seg_ends) q_mask = q_inter_starts < q_inter_ends valid_q_starts = q_inter_starts[q_mask] valid_q_ends = q_inter_ends[q_mask] k_seg_starts, k_seg_ends = cu_seqlens_k[:-1], cu_seqlens_k[1:] valid_indices = torch.nonzero(q_mask, as_tuple=True)[0] valid_k_starts = k_seg_starts[valid_indices] valid_k_ends = k_seg_ends[valid_indices] part_xattn_q_rangs = torch.stack((valid_q_starts, valid_q_ends), dim=1) offset = part_xattn_q_rangs[:, 0].min() part_xattn_q_rangs = part_xattn_q_rangs - offset inner_xattn_q_ranges.append(part_xattn_q_rangs) inner_xattn_k_ranges.append(torch.stack((valid_k_starts, valid_k_ends), dim=1)) inner_end_values = torch.tensor([ranges[-1, -1] for ranges in inner_xattn_q_ranges], dtype=torch.int32) inner_offsets 
= torch.cat((torch.zeros(1, dtype=inner_end_values.dtype), torch.cumsum(inner_end_values[:-1], dim=0))) inner_xattn_q_ranges = [tensor + int(offset) for tensor, offset in zip(inner_xattn_q_ranges, inner_offsets)] xattn_q_ranges.append(torch.cat(inner_xattn_q_ranges, dim=0)) xattn_k_ranges.append(torch.cat(inner_xattn_k_ranges, dim=0)) end_values = torch.tensor([ranges[-1, -1].item() for ranges in xattn_q_ranges], dtype=torch.int32) offsets = torch.cat((torch.zeros(1, dtype=end_values.dtype), torch.cumsum(end_values[:-1], dim=0))) shifted_tensors = [tensor + int(offset) for tensor, offset in zip(xattn_q_ranges, offsets)] xattn_q_ranges_ts = torch.cat(shifted_tensors, dim=0) xattn_k_ranges_ts = torch.cat(xattn_k_ranges, dim=0) cu_seqlens_q = torch.unique(xattn_q_ranges_ts) cu_seqlens_k = torch.unique(xattn_k_ranges_ts) assert ( cu_seqlens_q.shape == cu_seqlens_k.shape ), f"cu_seqlens_q.shape: {cu_seqlens_q.shape}, cu_seqlens_k.shape: {cu_seqlens_k.shape}, " return PackedCrossAttnParams( q_ranges=xattn_q_ranges_ts, kv_ranges=xattn_k_ranges_ts, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=seq_len_cur_rank, max_seqlen_kv=cross_attn_params.max_seqlen_kv, ) def cp_ulysses_process( cp_size: int, x: torch.Tensor, condition_map: torch.Tensor, rope: torch.Tensor, xattn_mask_for_cuda_graph: Union[torch.Tensor, None], cross_attn_params: PackedCrossAttnParams, ): seq_len, N, D = x.shape assert seq_len == rope.size(0), f"seq_len: {seq_len} != rope.size(0): {rope.size(0)}" assert condition_map.size(0) == seq_len, f"condition_map.size(0): {condition_map.size(0)} != seq_len: {seq_len}" # Part1: split for CP cp_split_sizes = [seq_len // cp_size] * cp_size for i in range(seq_len % cp_size): cp_split_sizes[i] += 1 # Part2: scatter to CP x = scatter_to_context_parallel_region(x, cp_split_sizes) condition_map = scatter_to_context_parallel_region(condition_map, cp_split_sizes) rope = scatter_to_context_parallel_region(rope, cp_split_sizes) # Part3: update cross_attn cross_attn_params cross_attn_params = cp_update_cross_attn_qkv_range(cross_attn_params, N, cp_split_sizes, x.device) if xattn_mask_for_cuda_graph is not None: cross_attn_params = update_packed_seq_params_for_cuda_graph(cross_attn_params, xattn_mask_for_cuda_graph) return x, condition_map, rope, cp_split_sizes, cross_attn_params def cp_shuffle_overlap_process( cp_size: int, x: torch.Tensor, condition_map: torch.Tensor, rope: torch.Tensor, xattn_mask_for_cuda_graph: Union[torch.Tensor, None], ardf_meta: dict, core_attn_params: PackedCoreAttnParams, cross_attn_params: PackedCrossAttnParams, ): seq_len, N, D = x.shape assert seq_len == rope.size(0), f"seq_len: {seq_len} != rope.size(0): {rope.size(0)}" assert condition_map.size(0) == seq_len, f"condition_map.size(0): {condition_map.size(0)} != seq_len: {seq_len}" cp_shuffle_num = ardf_meta["denoising_range_num"] # Part1: calculate cp_pad_size and cp_split_sizes cp_pad_size = 0 if divide(seq_len, cp_shuffle_num) % cp_size != 0: cp_pad_size = (cp_size - divide(seq_len, cp_shuffle_num) % cp_size) * cp_shuffle_num cp_split_sizes = [(seq_len + cp_pad_size) // cp_size] * cp_size # Part2: scatter to CP x = scatter_to_context_parallel_region(x, cp_split_sizes, cp_shuffle_num, cp_pad_size) condition_map = scatter_to_context_parallel_region(condition_map, cp_split_sizes, cp_shuffle_num, cp_pad_size) rope = scatter_to_context_parallel_region(rope, cp_split_sizes, cp_shuffle_num, cp_pad_size) # Part3: update core_attn_params gcd = math.gcd(seq_len, seq_len + cp_pad_size) _sq = seq_len // gcd 
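# Hedged numeric example (shapes invented): seq_len=30, denoising_range_num=3,
# cp_size=4 -> the per-range length 10 is not divisible by 4, so
# cp_pad_size = (4 - 10 % 4) * 3 = 6 and cp_split_sizes = [9, 9, 9, 9];
# gcd(30, 36) = 6 then gives _sq = 5 and _psq = 6, i.e. q_range and
# max_seqlen_q are rescaled by 36/30 to account for the padded tokens.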
_psq = (seq_len + cp_pad_size) // gcd q_range = ardf_meta["q_range"] * _psq // _sq max_seqlen_q = ardf_meta["max_seqlen_q"] * _psq // _sq core_attn_params = PackedCoreAttnParams( q_range=q_range, k_range=ardf_meta["k_range"], np_q_range=q_range.cpu().numpy(), np_k_range=ardf_meta["k_range"].cpu().numpy(), max_seqlen_q=max_seqlen_q, max_seqlen_k=ardf_meta["max_seqlen_k"], ) # Part4: update cross_attn cross_attn_params cross_attn_params = cp_update_cross_attn_qkv_range( cross_attn_params, N, cp_split_sizes, x.device, cp_shuffle_num, cp_pad_size ) if xattn_mask_for_cuda_graph is not None: cross_attn_params = update_packed_seq_params_for_cuda_graph(cross_attn_params, xattn_mask_for_cuda_graph) return x, condition_map, rope, cp_pad_size, cp_split_sizes, core_attn_params, cross_attn_params def cp_pre_process( cp_size: int, cp_strategy: str, x: torch.Tensor, condition_map: torch.Tensor, rope: torch.Tensor, xattn_mask_for_cuda_graph: Union[torch.Tensor, None], ardf_meta: dict, core_attn_params: PackedCoreAttnParams, cross_attn_params: PackedCrossAttnParams, ): """ This function is used to handle context parallel behavior, split input tensors into multiple parts and scatter them to different GPUs. Input: cp_strategy: str. cp_ulysses for hopper or newer, cp_shuffle_overlap for 4090 or older x: (S, N, D). torch.Tensor of inputs embedding (images or latent representations of images) condition_map: (N * S). torch.Tensor determine which condition to use for each token rope: (S, 96). torch.Tensor of rope xattn_mask_for_cuda_graph: (N * denoising_range_num, L, 1, 1). torch.Tensor of xattn mask for cuda graph, None means no cuda graph core_attn_params: PackedCoreAttnParams. Packed sequence parameters for core_atten cross_attn_params: PackedCrossAttnParams. Packed sequence parameters for cross_atten Output: x: (S', N, D). torch.Tensor of inputs embedding (images or latent representations of images) condition_map: (N * S'). torch.Tensor determine which condition to use for each token rope: (S', 96). torch.Tensor of rope cp_split_sizes: List[int]. 
Split sizes for each rank core_attn_params: PackedCoreAttnParams cross_attn_params: PackedCrossAttnParams """ if cp_size == 1: return x, condition_map, rope, None, None, core_attn_params, cross_attn_params if cp_strategy == "cp_ulysses": (x, condition_map, rope, cp_split_sizes, cross_attn_params) = cp_ulysses_process( cp_size, x, condition_map, rope, xattn_mask_for_cuda_graph, cross_attn_params ) return (x, condition_map, rope, 0, cp_split_sizes, core_attn_params, cross_attn_params) elif cp_strategy == "cp_shuffle_overlap": ( x, condition_map, rope, cp_pad_size, cp_split_sizes, core_attn_params, cross_attn_params, ) = cp_shuffle_overlap_process( cp_size, x, condition_map, rope, xattn_mask_for_cuda_graph, ardf_meta, core_attn_params, cross_attn_params ) return (x, condition_map, rope, cp_pad_size, cp_split_sizes, core_attn_params, cross_attn_params) else: raise ValueError(f"Invalid CP strategy: {cp_strategy}, expected cp_ulysses or cp_shuffle_overlap") def cp_post_process(cp_size: int, cp_strategy: str, x: torch.Tensor, meta_args: ModelMetaArgs) -> torch.Tensor: if cp_size == 1: return x if cp_strategy == "cp_shuffle_overlap": x = gather_from_context_parallel_region( x, meta_args.cp_split_sizes, meta_args.denoising_range_num, meta_args.cp_pad_size ) elif cp_strategy == "cp_ulysses": x = gather_from_context_parallel_region(x, meta_args.cp_split_sizes) else: raise ValueError(f"Invalid CP strategy: {cp_strategy}, expected cp_ulysses or cp_shuffle_overlap") return x ##################################################### # Ulysses Attention Pipeline ##################################################### def all_to_all_input_split(tensor: torch.Tensor, cp_split_sizes: List[int]) -> Tuple[torch.Tensor, torch.distributed.Work]: """ Scatter head_number and gather seq_len, for example: input: (seq_len, cp * hn, hd) output: (seq_len * cp, hn, hd) NOTE: seq_len of input maybe not equal, which depends on cp_split_sizes[mpu.get_cp_rank()] """ cp_world_size = mpu.get_cp_world_size() if cp_world_size == 1: return tensor, FakeHandle() assert cp_split_sizes is not None _, hn, _ = tensor.shape if cp_world_size % hn == 0 and cp_world_size != hn: tensor = torch.repeat_interleave(tensor, repeats=divide(cp_world_size, hn), dim=1).contiguous() assert tensor.is_contiguous() input = rearrange(tensor, "seq (cp hn) hd -> (cp seq) hn hd", cp=cp_world_size).contiguous() output = torch.empty([sum(cp_split_sizes), *input.shape[1:]], device=input.device, dtype=input.dtype) handle = torch.distributed.all_to_all_single( output, input, output_split_sizes=cp_split_sizes, group=mpu.get_cp_group(), async_op=True ) return output, handle def all_to_all_output_split(tensor: torch.Tensor, cp_split_sizes: List[int]) -> Tuple[torch.Tensor, torch.distributed.Work]: """ Scatter seq_len and gather head_number, for example: input: (seq_len * cp, hn, hd) output: (seq_len, cp * hn, hd) NOTE: seq_len of output maybe not equal, which depends on cp_split_sizes[mpu.get_cp_rank()] """ cp_world_size = mpu.get_cp_world_size() if cp_world_size == 1: return tensor, FakeHandle() assert cp_split_sizes is not None assert tensor.is_contiguous() _, hn, _ = tensor.shape output = torch.empty( [cp_split_sizes[mpu.get_cp_rank()] * cp_world_size, *tensor.shape[1:]], device=tensor.device, dtype=tensor.dtype ) handle = torch.distributed.all_to_all_single( output, tensor, input_split_sizes=cp_split_sizes, group=mpu.get_cp_group(), async_op=True ) return output, handle def fused_qkv_communication( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cp_split_sizes: 
List[int] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: cp_world_size = mpu.get_cp_world_size() if cp_world_size == 1: return q, k, v assert cp_split_sizes is not None _, k_head, _ = k.shape if cp_world_size % k_head == 0 and cp_world_size != k_head: k = torch.repeat_interleave(k, repeats=divide(cp_world_size, k_head), dim=1) v = torch.repeat_interleave(v, repeats=divide(cp_world_size, k_head), dim=1) q = rearrange(q, "seq (cp hn) hd -> (cp seq) hn hd", cp=cp_world_size).contiguous() k = rearrange(k, "seq (cp hn) hd -> (cp seq) hn hd", cp=cp_world_size).contiguous() v = rearrange(v, "seq (cp hn) hd -> (cp seq) hn hd", cp=cp_world_size).contiguous() head_split_number = [q.shape[1], k.shape[1], v.shape[1]] qkv = torch.cat([q, k, v], dim=1).contiguous() qkv_output = torch.empty([sum(cp_split_sizes), *qkv.shape[1:]], device=qkv.device, dtype=qkv.dtype) torch.distributed.all_to_all_single( qkv_output, qkv, output_split_sizes=cp_split_sizes, group=mpu.get_cp_group(), async_op=False ) q, k, v = torch.split(qkv_output, head_split_number, dim=1) return q, k, v class UlyssesScheduler: def __init__(self): pass @staticmethod def get_attn_and_xattn_with_comm_overlap( get_q_func: Callable, # [seq hn hd] get_k_func: Callable, # [seq hn hd] get_v_func: Callable, # [seq hn hd] kv_cache_func: Callable, core_attn_func: Callable, cross_attn_func: Callable, overlap_degree: int, batch_size: int, cp_size: int, cp_split_sizes: List[int] = None, ): """ Get Q, K, V with communication overlap. Input: get_q: Callable, function to get q, shape [b, sq, hn, hd] get_k: Callable, function to get k, shape [sq, b, hn, hd] get_v: Callable, function to get v, shape [sq, b, hn, hd] NOTE: Why follow such compute and comm order? 1. v_compute 2. k_compute(overlap with v_comm) 3. q_compute(overlap with k_comm) 4. kv_cache_func(overlap with q_comm) Follow the principle: We need to begin comm as soon as possible to hide the comm latency. The computation flops and commnunication order is: flops order: q_compute (larger hidden_size + layernorm) > k_compute (layernorm) > v_compute comm order: q_compute (larger hidden_size) > k_compute = v_compute """ value = get_v_func() value, handle_v = all_to_all_input_split(value, cp_split_sizes) key = get_k_func() key, handle_k = all_to_all_input_split(key, cp_split_sizes) query = get_q_func() query, handle_q = all_to_all_input_split(query, cp_split_sizes) handle_v.wait() handle_k.wait() kv = torch.concat([key, value], dim=-1) key, value = kv_cache_func(kv) handle_q.wait() return UlyssesScheduler.get_attn_and_xattn_base( query, key, value, core_attn_func, cross_attn_func, overlap_degree, batch_size, cp_size, cp_split_sizes ) @staticmethod def get_attn_and_xattn_with_fused_kv_comm( get_q_func: Callable, get_kv_func: Callable, kv_cache_func: Callable, core_attn_func: Callable, cross_attn_func: Callable, overlap_degree: int, batch_size: int, cp_size: int, cp_split_sizes: List[int] = None, ): """ When seq_len is very small, CPU-bound issues are severe. By fusing kv communication, CPU operations and the number of kernel launches are reduced. 
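Hedged call sketch mirroring the body below (names as in the signature):

    kv, h_kv = all_to_all_input_split(get_kv_func(), cp_split_sizes)
    q, h_q = all_to_all_input_split(get_q_func(), cp_split_sizes)
    h_kv.wait(); key, value = kv_cache_func(kv)
    h_q.wait()  # then get_attn_and_xattn_base runs core and cross attention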
""" kv = get_kv_func() kv, handle_kv = all_to_all_input_split(kv, cp_split_sizes) query = get_q_func() query, handle_q = all_to_all_input_split(query, cp_split_sizes) handle_kv.wait() key, value = kv_cache_func(kv) handle_q.wait() return UlyssesScheduler.get_attn_and_xattn_base( query, key, value, core_attn_func, cross_attn_func, overlap_degree, batch_size, cp_size, cp_split_sizes ) def get_attn_and_xattn_with_fused_qkv_comm( get_qkv_func: Callable, kv_cache_func: Callable, core_attn_func: Callable, cross_attn_func: Callable, overlap_degree: int, batch_size: int, cp_size: int, cp_split_sizes: List[int] = None, ): """ By fusing the communication of q, k, and v together, further optimize CPU-bound issues. """ q, k, v = get_qkv_func() q, k, v = fused_qkv_communication(q, k, v, cp_split_sizes) k, v = kv_cache_func(torch.cat([k, v], dim=-1)) return UlyssesScheduler.get_attn_and_xattn_base( q, k, v, core_attn_func, cross_attn_func, overlap_degree, batch_size, cp_size, cp_split_sizes ) @staticmethod def get_attn_and_xattn_base( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, core_attn_func: Callable, cross_attn_func: Callable, overlap_degree: int, batch_size: int, cp_size: int, cp_split_sizes: List[int] = None, ): # Split Query, Key, Value into multiple parts # k/v may have different sequence length with q due to kv cache q_seq, q_head, q_hidden = query.shape kv_seq, kv_head, kv_hidden = key.shape if overlap_degree == -1: overlap_degree = q_head // kv_head else: assert overlap_degree <= q_head if overlap_degree == 1: query = [query] elif kv_head == 1: # MQA query = query.chunk(overlap_degree, dim=1) else: # GQA assert q_head % (overlap_degree * kv_head) == 0 query = query.reshape(q_seq, kv_head, -1, q_hidden) query = query.chunk(overlap_degree, dim=2) query = [q.reshape(q_seq, -1, q_hidden) for q in query] # Compute Core Attention handle_attn = None core_attn_out = None core_attn_outs = [] for i in range(overlap_degree): core_attn_out_new = core_attn_func(query[i], key, value) if handle_attn is not None: handle_attn.wait() core_attn_outs.append(core_attn_out) core_attn_out, handle_attn = all_to_all_output_split(core_attn_out_new, cp_split_sizes) xattn_out = cross_attn_func() handle_attn.wait() core_attn_outs.append(core_attn_out) core_attn_out = torch.cat(core_attn_outs, dim=1) core_attn_out = rearrange(core_attn_out, "(cp sq b) hn hd -> (sq) b (cp hn hd)", cp=cp_size, b=batch_size) return core_attn_out, xattn_out ##################################################### # CSO(context shuffle overlap) Attention Pipeline ##################################################### def cso_communication( input: torch.Tensor, cp_world_size: int, cp_split_sizes: List[int], comm_type: str = None ) -> Tuple[torch.Tensor, torch.distributed.Work]: if cp_world_size == 1: return input, FakeHandle() assert cp_split_sizes is not None _, hn, _ = input.shape if comm_type == "kv": if cp_world_size % hn == 0 and cp_world_size != hn: input = torch.repeat_interleave(input, repeats=divide(cp_world_size, hn), dim=1) input = rearrange(input, "spb (cp hn) hd -> (cp spb) hn hd", cp=cp_world_size).contiguous() output = torch.empty(input.shape, device=input.device, dtype=input.dtype) handle = torch.distributed.all_to_all_single( output, input, input_split_sizes=cp_split_sizes, group=mpu.get_cp_group(), async_op=True ) return output, handle class CSOHelper: def __init__(self, cp_shuffle_num, cp_world_size, cp_split_sizes): self.cp_shuffle_num = cp_shuffle_num self.cp_world_size = cp_world_size self.cp_split_sizes = 
[divide(x, self.cp_shuffle_num) for x in cp_split_sizes] def split_query_for_overlap(self, query): query = rearrange( query, "(dn spb) (cp hn) hd -> (dn cp spb) hn hd", cp=self.cp_world_size, dn=self.cp_shuffle_num ).contiguous() querys = list(torch.chunk(query, self.cp_shuffle_num, dim=0)) querys[0], handle_q = cso_communication(querys[0], self.cp_world_size, self.cp_split_sizes) return querys, handle_q def overlap(self, fattn, qs, k, v): core_attn_outs = [] for i in range(self.cp_shuffle_num): if self.cp_shuffle_num == 1: q = qs[0] elif i == 0: q = qs[0] loop_var, loop_handle = cso_communication(qs[i + 1], self.cp_world_size, self.cp_split_sizes) else: loop_handle.wait() if loop_var.numel() == qs[0].numel(): q = loop_var else: assert loop_var.numel() == qs[0].numel() * 2 q, ready_o = torch.chunk(loop_var, 2, dim=-1) core_attn_outs.append(ready_o) loop_var = torch.concat([qs[i + 1], o], dim=-1) if i < self.cp_shuffle_num - 1 else o loop_var, loop_handle = cso_communication(loop_var, self.cp_world_size, self.cp_split_sizes) o = fattn(q, k, v, i) if i == self.cp_shuffle_num - 1: if i != 0: loop_handle.wait() assert loop_var.numel() == qs[0].numel() core_attn_outs.append(loop_var) last_o, handle_attn = cso_communication(o, self.cp_world_size, self.cp_split_sizes) core_attn_outs.append(last_o) return core_attn_outs, handle_attn ``` ## /inference/infra/parallelism/pipeline_parallel.py ```py path="/inference/infra/parallelism/pipeline_parallel.py" # Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import queue from dataclasses import dataclass from typing import Optional import torch from inference.infra.distributed import parallel_state as mpu @dataclass class TensorAndHandler: tensor: torch.Tensor handler: torch.distributed.Work class PPScheduler: def __init__(self): """Initialize an instance of the PPScheduler class""" self.device: torch.device = torch.device(f"cuda:{torch.cuda.current_device()}") self.recv_queue: queue.Queue = queue.Queue() def isend_next(self, tensor: torch.Tensor) -> torch.distributed.Work: """Asynchronously send a tensor to the next pipeline and return the send handle. Args: tensor (torch.Tensor): The tensor to be sent. Returns: torch.distributed.Work: The handle for the send operation. """ handle = torch.distributed.isend( tensor.contiguous(), dst=mpu.get_pipeline_model_parallel_next_rank(), group=mpu.get_pp_group() ) return handle def irecv_prev(self, buffer: torch.Tensor) -> torch.distributed.Work: """Asynchronously receive a tensor from the previous pipeline and return the receive handle. Args: buffer (torch.Tensor): The buffer tensor for receiving data. Returns: torch.distributed.Work: The handle for the receive operation. """ handle = torch.distributed.irecv(buffer, src=mpu.get_pipeline_model_parallel_prev_rank(), group=mpu.get_pp_group()) return handle def recv_prev_data(self, shape: torch.Size, dtype: torch.dtype) -> torch.Tensor: """Receive data from the previous pipeline and return the received tensor. 
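Hedged usage sketch (shape and dtype are illustrative): the receiving stage
posts a buffer matching what the previous stage sent with isend_next, e.g.
pp_scheduler().recv_prev_data(torch.Size([s, b, h]), torch.bfloat16) pairs
with an isend_next of an (s, b, h) bfloat16 activation.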
Args: shape (torch.Size): The shape of the tensor to receive. dtype (torch.dtype): The data type of the tensor to receive. Returns: torch.Tensor: The received tensor. """ recv_tensor = torch.empty(shape, dtype=dtype, device=self.device) self.irecv_prev(recv_tensor).wait() return recv_tensor def queue_irecv_prev(self, shape: torch.Size, dtype: torch.dtype) -> None: """Put the asynchronously received tensor and handle into the receive queue. Args: shape (torch.Size): The shape of the tensor to receive. dtype (torch.dtype): The data type of the tensor to receive. """ recv_tensor = torch.empty(shape, dtype=dtype, device=self.device) handle = self.irecv_prev(recv_tensor) self.recv_queue.put(TensorAndHandler(tensor=recv_tensor, handler=handle)) def queue_irecv_prev_data(self) -> torch.Tensor: """Get a tensor from the receive queue and wait for the receive operation to complete. Returns: torch.Tensor: The received tensor obtained from the queue. """ tensor_and_handler = self.recv_queue.get() tensor_and_handler.handler.wait() return tensor_and_handler.tensor _PP_SCHEDULER: Optional[PPScheduler] = None def init_pp_scheduler(): """Initialize the PPScheduler instance. Raises: AssertionError: If the PPScheduler is already initialized. """ global _PP_SCHEDULER assert _PP_SCHEDULER is None, "pipeline model parallel group is already initialized" _PP_SCHEDULER = PPScheduler() def pp_scheduler() -> PPScheduler: """Get the current PPScheduler instance. Returns: PPScheduler: The current PPScheduler instance. Raises: AssertionError: If the PPScheduler has not been initialized. """ assert _PP_SCHEDULER is not None, "pipeline model parallel group is not initialized" return _PP_SCHEDULER ``` ## /inference/infra/parallelism/tile_parallel.py ```py path="/inference/infra/parallelism/tile_parallel.py" # Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections import OrderedDict from typing import List import torch from tqdm import tqdm class ParallelHelper: def __init__(self): pass @staticmethod def split_tile_list( tile_numel_dict: OrderedDict[int, int], parallel_group: torch.distributed.ProcessGroup = None ) -> List[int]: """ Splits the given tile size into a list of sizes that each rank should handle. This method takes into account the number of ranks in a distributed setting. If the distributed environment is not initialized, it returns a list of integers from 0 to tile_size - 1, representing each tile index. If the distributed environment is initialized, it calculates the base tile size for each rank and distributes any remaining tiles among the ranks. Args: tile_numel_dict (OrderedDict[int, int]): Dict of index and numel of tiles. parallel_group (torch.distributed.ProcessGroup, optional): Distributed decoding group. Defaults to None. Returns: List[int]: A list of tile indices assigned to the current rank. List[int]: A list of global tile indices. 
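Example (tile sizes invented for illustration): with
tile_numel_dict = {0: 40, 1: 100, 2: 60, 3: 80, 4: 20} and two ranks, tiles
are ordered by numel as [1, 3, 2, 0, 4]; rank 0 is assigned [1, 2, 4], rank 1
is assigned [3, 0], and global_tile_idxs is [1, 2, 4, 3, 0], which lets
gather_frames restore the original tile order afterwards.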
""" if not torch.distributed.is_initialized(): return list(range(len(tile_numel_dict))), list(range(len(tile_numel_dict))) else: tile_idxs = list(OrderedDict(sorted(tile_numel_dict.items(), key=lambda x: x[1], reverse=True)).keys()) world_size = torch.distributed.get_world_size(group=parallel_group) cur_rank = torch.distributed.get_rank(group=parallel_group) global_tile_idxs = [] cur_rank_tile_idxs = [] for rank in range(world_size): rank_tile_idxs = [tile_idxs[rank + world_size * i] for i in range(len(tile_idxs) // world_size)] if rank < len(tile_idxs) % world_size: rank_tile_idxs.append(tile_idxs[len(tile_idxs) // world_size * world_size + rank]) if rank == cur_rank: cur_rank_tile_idxs = rank_tile_idxs global_tile_idxs = global_tile_idxs + rank_tile_idxs return cur_rank_tile_idxs, global_tile_idxs @staticmethod def gather_frames( frames: List[torch.Tensor], global_tile_idxs: List[int], parallel_group: torch.distributed.ProcessGroup = None ) -> List[torch.Tensor]: """ Gathers frame data from all ranks in a distributed environment. This method collects frames from all ranks and combines them into a single list. If the distributed environment is not initialized, it simply returns the input frames. Args: frames (List[torch.Tensor]): A list of frames (tensors) from the current rank. global_tile_idxs (List[int]): A list of global tile indices. parallel_group (torch.distributed.ProcessGroup, optional): Distributed decoding group. Defaults to None. Returns: List[torch.Tensor]: A list of frames (tensors) from all ranks. """ if not torch.distributed.is_initialized(): return frames else: # assert len(frames) > 0 # Communicate shapes if len(frames) == 0: cur_rank_shapes = [] else: cur_rank_shapes = [frame.shape for frame in frames] all_rank_shapes = [None] * torch.distributed.get_world_size(group=parallel_group) torch.distributed.all_gather_object(all_rank_shapes, cur_rank_shapes, group=parallel_group) all_rank_sizes = [] total_size = [] for per_rank_shapes in all_rank_shapes: per_rank_sizes = [] per_rank_total_size = 0 for shape in per_rank_shapes: per_rank_sizes.append(shape[0] * shape[1] * shape[2] * shape[3] * shape[4]) per_rank_total_size += shape[0] * shape[1] * shape[2] * shape[3] * shape[4] all_rank_sizes.append(per_rank_sizes) total_size.append(per_rank_total_size) # Gather all frames if len(frames) == 0: flattened_frames = torch.zeros([0], dtype=torch.bfloat16, device="cuda") else: flattened_frames = torch.cat([frame.flatten().contiguous() for frame in frames], dim=0) assert flattened_frames.dtype == torch.bfloat16 gather_tensors = [ torch.zeros(total_size[i], dtype=torch.bfloat16, device="cuda") for i in range(torch.distributed.get_world_size(group=parallel_group)) ] torch.distributed.all_gather(gather_tensors, flattened_frames, group=parallel_group) result_frames = [] for idx, per_rank_shapes in enumerate(all_rank_shapes): offset = 0 for j, shape in enumerate(per_rank_shapes): result_frames.append(gather_tensors[idx][offset : offset + all_rank_sizes[idx][j]].view(shape)) offset += all_rank_sizes[idx][j] result_frames_dict = OrderedDict((idx, frame) for idx, frame in zip(global_tile_idxs, result_frames)) result_frames = list(OrderedDict(sorted(result_frames_dict.items())).values()) return result_frames @staticmethod def index_undot(index: int, loop_size: List[int]) -> List[int]: """ Converts a single index into a list of indices, representing the position in a multi-dimensional space. 
This method takes an integer index and a list of loop sizes, and converts the index into a list of indices that correspond to the position in a multi-dimensional space. Args: index (int): The single index to be converted. loop_size (List[int]): A list of integers representing the size of each dimension in the multi-dimensional space. Returns: List[int]: A list of integers representing the position in the multi-dimensional space. """ undotted_index = [] for i in range(len(loop_size) - 1, -1, -1): undotted_index.append(index % loop_size[i]) index = index // loop_size[i] undotted_index.reverse() assert len(undotted_index) == len(loop_size) return undotted_index @staticmethod def index_dot(index: List[int], loop_size: List[int]) -> int: """ Converts a list of indices into a single index, representing the position in a multi-dimensional space. This method takes a list of indices and a list of loop sizes, and converts the list of indices into a single index that corresponds to the position in a multi-dimensional space. Args: index (List[int]): A list of integers representing the position in the multi-dimensional space. loop_size (List[int]): A list of integers representing the size of each dimension in the multi-dimensional space. Returns: int: A single integer representing the position in the multi-dimensional space. """ assert len(index) == len(loop_size) dot_index = 0 strides = [1] for i in range(len(loop_size) - 1, -1, -1): strides.append(strides[-1] * loop_size[i]) strides.reverse() strides = strides[1:] assert len(index) == len(strides) for i in range(len(index)): dot_index += index[i] * strides[i] return dot_index class TileProcessor: def __init__( self, encode_fn, decode_fn, tile_sample_min_height: int = 256, tile_sample_min_width: int = 256, tile_sample_min_length: int = 16, spatial_downsample_factor: int = 8, temporal_downsample_factor: int = 1, spatial_tile_overlap_factor: float = 0.25, temporal_tile_overlap_factor: float = 0, sr_ratio=1, first_frame_as_image: bool = False, parallel_group: torch.distributed.ProcessGroup = None, ): """ Initializes a TileProcessor instance. Args: encode_fn (function): The encoding function used for tile sampling. decode_fn (function): The decoding function used for tile reconstruction. tile_sample_min_height (int, optional): The minimum height of the sampled tiles. Defaults to 256. tile_sample_min_width (int, optional): The minimum width of the sampled tiles. Defaults to 256. tile_sample_min_length (int, optional): The minimum length of the sampled tiles. Defaults to 16. spatial_downsample_factor (int, optional): The actual spatial downsample factor of the given encode_fn. Defaults to 8. temporal_downsample_factor (int, optional): The actual temporal downsample factor of the latent space tiles. Defaults to 1. spatial_tile_overlap_factor (float, optional): The spatial overlap factor between adjacent tiles. Defaults to 0.25. temporal_tile_overlap_factor (float, optional): The temporal overlap factor between adjacent tiles. Defaults to 0. sr_ratio (int, optional): Extra spatial scale applied to decoded tile sizes in tiled_decode. Defaults to 1. first_frame_as_image (bool, optional): If True, one extra latent frame is reserved in tile_latent_min_length. Defaults to False. parallel_group (torch.distributed.ProcessGroup, optional): Distributed decoding group. Defaults to None.
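Note: the latent-space tile sizes (tile_latent_min_height/width/length) are derived in __init__ by dividing the sample sizes by the corresponding spatial/temporal downsample factors.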
""" self.encode_fn = encode_fn self.decode_fn = decode_fn self.spatial_downsample_factor = spatial_downsample_factor self.temporal_downsample_factor = temporal_downsample_factor self.tile_sample_min_height = tile_sample_min_height self.tile_sample_min_width = tile_sample_min_width self.tile_sample_min_length = tile_sample_min_length self.tile_latent_min_height = tile_sample_min_height // spatial_downsample_factor self.tile_latent_min_width = tile_sample_min_width // spatial_downsample_factor self.tile_latent_min_length = tile_sample_min_length // temporal_downsample_factor if first_frame_as_image: self.tile_latent_min_length += 1 self.spatial_tile_overlap_factor = spatial_tile_overlap_factor self.temporal_tile_overlap_factor = temporal_tile_overlap_factor self.sr_ratio = sr_ratio self.parallel_group = parallel_group def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: blend_extent = min(a.shape[2], b.shape[2], blend_extent) for t in range(blend_extent): b[:, :, t, :, :] = a[:, :, -blend_extent + t, :, :] * (1 - t / blend_extent) + b[:, :, t, :, :] * ( t / blend_extent ) return b def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: blend_extent = min(a.shape[3], b.shape[3], blend_extent) for y in range(blend_extent): b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( y / blend_extent ) return b def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: blend_extent = min(a.shape[4], b.shape[4], blend_extent) for x in range(blend_extent): b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( x / blend_extent ) return b def tiled_encode(self, x: torch.FloatTensor, verbose: bool = False): overlap_height = int(self.tile_sample_min_height * (1 - self.spatial_tile_overlap_factor)) overlap_width = int(self.tile_sample_min_width * (1 - self.spatial_tile_overlap_factor)) overlap_length = int(self.tile_sample_min_length * (1 - self.temporal_tile_overlap_factor)) blend_extent_h = int(self.tile_latent_min_height * self.spatial_tile_overlap_factor) blend_extent_w = int(self.tile_latent_min_width * self.spatial_tile_overlap_factor) blend_extent_t = int(self.tile_latent_min_length * self.temporal_tile_overlap_factor) height_limit = self.tile_latent_min_height - blend_extent_h width_limit = self.tile_latent_min_width - blend_extent_w frame_limit = self.tile_latent_min_length - blend_extent_t length_tile_size = (x.shape[2] + overlap_length - 1) // overlap_length height_tile_size = (x.shape[3] + overlap_height - 1) // overlap_height width_tile_size = (x.shape[4] + overlap_width - 1) // overlap_width total_tile_size = length_tile_size * height_tile_size * width_tile_size for_loop_size = [length_tile_size, height_tile_size, width_tile_size] tiles = [] tile_numel_dict = OrderedDict() for tile_index in range(total_tile_size): undot_tile_index = ParallelHelper.index_undot(tile_index, for_loop_size) f_idx, i_idx, j_idx = undot_tile_index f = f_idx * overlap_length i = i_idx * overlap_height j = j_idx * overlap_width # Extract the tile from the latent representation and decode it tile = x[ :, :, f : f + self.tile_sample_min_length, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width, ] tiles.append(tile) tile_numel_dict[tile_index] = tile.numel() tile_index_list, global_tile_index_list = ParallelHelper.split_tile_list( tile_numel_dict, parallel_group=self.parallel_group ) progress_bar = tqdm( 
total=len(tile_index_list), desc=f"[Rank {torch.distributed.get_rank(group=self.parallel_group)}] Encoding Tiles", disable=not verbose, ) frames = [] # Encode each tile based on the tile index list for tile_index in tile_index_list: tile = tiles[tile_index] encoded = self.encode_fn(tile) frames.append(encoded) progress_bar.update(1) # Gather all decoded frames from different ranks frames = ParallelHelper.gather_frames(frames, global_tile_index_list, parallel_group=self.parallel_group) assert len(frames) == total_tile_size progress_bar.close() result_frames = [] # Blend the encoded tiles to create the final output for tile_index in range(total_tile_size): undot_tile_index = ParallelHelper.index_undot(tile_index, for_loop_size) f, i, j = undot_tile_index tile = frames[tile_index] # Blend with previous tiles if applicable if f > 0: idx = ParallelHelper.index_dot([f - 1, i, j], for_loop_size) tile = self.blend_t(frames[idx], tile, blend_extent_t) if i > 0: idx = ParallelHelper.index_dot([f, i - 1, j], for_loop_size) tile = self.blend_v(frames[idx], tile, blend_extent_h) if j > 0: idx = ParallelHelper.index_dot([f, i, j - 1], for_loop_size) tile = self.blend_h(frames[idx], tile, blend_extent_w) result_frames.append(tile[:, :, :frame_limit, :height_limit, :width_limit]) assert len(result_frames) == total_tile_size concat_frames = [] for f in range(length_tile_size): result_rows = [] for i in range(height_tile_size): result_row = [] for j in range(width_tile_size): idx = ParallelHelper.index_dot([f, i, j], for_loop_size) result_row.append(result_frames[idx]) result_rows.append(torch.cat(result_row, dim=4)) concat_frames.append(torch.cat(result_rows, dim=3)) # Concatenate all result frames along the temporal dimension result = torch.cat(concat_frames, dim=2) return result def tiled_decode(self, z: torch.FloatTensor, verbose: bool = False): overlap_height = int(self.tile_latent_min_height * (1 - self.spatial_tile_overlap_factor)) overlap_width = int(self.tile_latent_min_width * (1 - self.spatial_tile_overlap_factor)) overlap_length = int(self.tile_latent_min_length * (1 - self.temporal_tile_overlap_factor)) real_tile_sample_min_height = int(self.tile_latent_min_height * self.spatial_downsample_factor * self.sr_ratio) real_tile_sample_min_width = int(self.tile_latent_min_width * self.spatial_downsample_factor * self.sr_ratio) real_tile_sample_min_length = int(self.tile_latent_min_length * self.temporal_downsample_factor) blend_extent_h = int(real_tile_sample_min_height * self.spatial_tile_overlap_factor) blend_extent_w = int(real_tile_sample_min_width * self.spatial_tile_overlap_factor) blend_extent_t = int(real_tile_sample_min_length * self.temporal_tile_overlap_factor) height_limit = real_tile_sample_min_height - blend_extent_h width_limit = real_tile_sample_min_width - blend_extent_w frame_limit = real_tile_sample_min_length - blend_extent_t length_tile_size = (z.shape[2] + overlap_length - 1) // overlap_length height_tile_size = (z.shape[3] + overlap_height - 1) // overlap_height width_tile_size = (z.shape[4] + overlap_width - 1) // overlap_width total_tile_size = length_tile_size * height_tile_size * width_tile_size for_loop_size = [length_tile_size, height_tile_size, width_tile_size] tiles = [] tile_numel_dict = OrderedDict() for tile_index in range(total_tile_size): undot_tile_index = ParallelHelper.index_undot(tile_index, for_loop_size) f_idx, i_idx, j_idx = undot_tile_index f = f_idx * overlap_length i = i_idx * overlap_height j = j_idx * overlap_width # Extract the tile from the latent 
representation and decode it tile = z[ :, :, f : f + self.tile_latent_min_length, i : i + self.tile_latent_min_height, j : j + self.tile_latent_min_width, ] tiles.append(tile) tile_numel_dict[tile_index] = tile.numel() tile_index_list, global_tile_index_list = ParallelHelper.split_tile_list( tile_numel_dict, parallel_group=self.parallel_group ) progress_bar = tqdm( total=len(tile_index_list), desc=f"[Rank {torch.distributed.get_rank(group=self.parallel_group)}] Decoding Tiles", disable=not verbose, ) frames = [] # Decode each tile based on the tile index list for tile_index in tile_index_list: tile = tiles[tile_index] decoded = self.decode_fn(tile) frames.append(decoded) progress_bar.update(1) progress_bar.close() # Gather all decoded frames from different ranks frames = ParallelHelper.gather_frames(frames, global_tile_index_list, parallel_group=self.parallel_group) assert len(frames) == total_tile_size result_frames = [] # Blend the decoded tiles to create the final output for tile_index in tile_index_list: undot_tile_index = ParallelHelper.index_undot(tile_index, for_loop_size) f, i, j = undot_tile_index tile = frames[tile_index].clone() # Blend with previous tiles if applicable if f > 0: idx = ParallelHelper.index_dot([f - 1, i, j], for_loop_size) tile = torch.compile(self.blend_t, dynamic=False)(frames[idx], tile, blend_extent_t) if i > 0: idx = ParallelHelper.index_dot([f, i - 1, j], for_loop_size) tile = torch.compile(self.blend_v, dynamic=False)(frames[idx], tile, blend_extent_h) if j > 0: idx = ParallelHelper.index_dot([f, i, j - 1], for_loop_size) tile = torch.compile(self.blend_h, dynamic=False)(frames[idx], tile, blend_extent_w) result_frames.append(tile[:, :, :frame_limit, :height_limit, :width_limit]) # Gather and concatenate the final result frames result_frames = ParallelHelper.gather_frames(result_frames, global_tile_index_list, parallel_group=self.parallel_group) assert len(result_frames) == total_tile_size concat_frames = [] for f in range(length_tile_size): result_rows = [] for i in range(height_tile_size): result_row = [] for j in range(width_tile_size): idx = ParallelHelper.index_dot([f, i, j], for_loop_size) result_row.append(result_frames[idx]) result_rows.append(torch.cat(result_row, dim=4)) concat_frames.append(torch.cat(result_rows, dim=3)) # Concatenate all result frames along the temporal dimension result = torch.cat(concat_frames, dim=2) return result ``` ## /inference/model/dit/__init__.py ```py path="/inference/model/dit/__init__.py" # Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from .dit_model import get_dit __all__ = ["get_dit"] ``` ## /inference/model/dit/dit_model.py ```py path="/inference/model/dit/dit_model.py" # Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import gc import math import os from typing import Tuple import torch import torch.distributed import torch.nn as nn from einops import rearrange from inference.common import ( InferenceParams, MagiConfig, ModelMetaArgs, PackedCoreAttnParams, PackedCrossAttnParams, env_is_true, print_per_rank, print_rank_0, ) from inference.infra.checkpoint import load_checkpoint from inference.infra.distributed import parallel_state as mpu from inference.infra.parallelism import cp_post_process, cp_pre_process, pp_scheduler from .dit_module import CaptionEmbedder, FinalLinear, LearnableRotaryEmbeddingCat, TimestepEmbedder, TransformerBlock class VideoDiTModel(torch.nn.Module): """VideoDiT model for video diffusion. Args: config (MagiConfig): Transformer config pre_process (bool, optional): Include embedding layer (used with pipeline parallelism). Defaults to True. post_process (bool, optional): Include an output layer (used with pipeline parallelism). Defaults to True. """ def __init__(self, config: MagiConfig, pre_process: bool = True, post_process: bool = True) -> None: super().__init__() self.model_config = config.model_config self.runtime_config = config.runtime_config self.engine_config = config.engine_config self.pre_process = pre_process self.post_process = post_process self.in_channels = self.model_config.in_channels self.out_channels = self.model_config.out_channels self.patch_size = self.model_config.patch_size self.t_patch_size = self.model_config.t_patch_size self.caption_max_length = self.model_config.caption_max_length self.num_heads = self.model_config.num_attention_heads self.x_embedder = nn.Conv3d( self.model_config.in_channels, self.model_config.hidden_size, kernel_size=(self.model_config.t_patch_size, self.model_config.patch_size, self.model_config.patch_size), stride=(self.model_config.t_patch_size, self.model_config.patch_size, self.model_config.patch_size), bias=False, ) self.t_embedder = TimestepEmbedder(model_config=self.model_config) self.y_embedder = CaptionEmbedder(model_config=self.model_config) self.rope = LearnableRotaryEmbeddingCat( self.model_config.hidden_size // self.model_config.num_attention_heads, in_pixels=False ) # trm block self.videodit_blocks = TransformerBlock( model_config=self.model_config, engine_config=self.engine_config, pre_process=pre_process, post_process=post_process, ) self.final_linear = FinalLinear( self.model_config.hidden_size, self.model_config.patch_size, self.model_config.t_patch_size, self.out_channels ) def generate_kv_range_for_uncondition(self, uncond_x) -> torch.Tensor: device = f"cuda:{torch.cuda.current_device()}" B, C, T, H, W = uncond_x.shape chunk_token_nums = ( (T // self.model_config.t_patch_size) * (H // self.model_config.patch_size) * (W // self.model_config.patch_size) ) k_chunk_start = torch.linspace(0, (B - 1) * chunk_token_nums, steps=B).reshape((B, 1)) k_chunk_end = torch.linspace(chunk_token_nums, B * chunk_token_nums, steps=B).reshape((B, 1)) return torch.concat([k_chunk_start, k_chunk_end], dim=1).to(torch.int32).to(device) def unpatchify(self, x, H, W): return rearrange( x, "(T H W) N (pT pH pW C) -> N C (T pT) (H pH) (W 
pW)", H=H, W=W, pT=self.t_patch_size, pH=self.patch_size, pW=self.patch_size, ).contiguous() @torch.no_grad() def get_embedding_and_meta(self, x, t, y, caption_dropout_mask, xattn_mask, kv_range, **kwargs): """ Forward embedding and meta for VideoDiT. NOTE: This function should only handle single card behavior. Input: x: (N, C, T, H, W). torch.Tensor of spatial inputs (images or latent representations of images) t: (N, denoising_range_num). torch.Tensor of diffusion timesteps y: (N * denoising_range_num, 1, L, C). torch.Tensor of class labels caption_dropout_mask: (N). torch.Tensor of whether to drop caption xattn_mask: (N * denoising_range_num, 1, L). torch.Tensor of xattn mask kv_range: (N * denoising_range_num, 2). torch.Tensor of kv range Output: x: (S, N, D). torch.Tensor of inputs embedding (images or latent representations of images) condition: (N, denoising_range_num, D). torch.Tensor of condition embedding condition_map: (S, N). torch.Tensor determine which condition to use for each token rope: (S, 96). torch.Tensor of rope y_xattn_flat: (total_token, D). torch.Tensor of y_xattn_flat cuda_graph_inputs: (y_xattn_flat, xattn_mask) or None. None means no cuda graph NOTE: y_xattn_flat and xattn_mask with static shape H: int. Height of the input W: int. Width of the input ardf_meta: dict. Meta information for ardf cross_attn_params: PackedCrossAttnParams. Packed sequence parameters for cross_atten """ ################################### # Part1: Embed x # ################################### x = self.x_embedder(x) # [N, C, T, H, W] batch_size, _, T, H, W = x.shape # Prepare necessary variables range_num = kwargs["range_num"] denoising_range_num = kwargs["denoising_range_num"] slice_point = kwargs.get("slice_point", 0) frame_in_range = T // denoising_range_num prev_clean_T = frame_in_range * slice_point T_total = T + prev_clean_T ################################### # Part2: rope # ################################### # caculate rescale_factor for multi-resolution & multi aspect-ratio training # the base_size [16*16] is A predefined size based on data:(256x256) vae: (8,8,4) patch size: (1,1,2) # This definition do not have any relationship with the actual input/model/setting. # ref_feat_shape is used to calculate innner rescale factor, so it can be float. rescale_factor = math.sqrt((H * W) / (16 * 16)) rope = self.rope.get_embed(shape=[T_total, H, W], ref_feat_shape=[T_total, H / rescale_factor, W / rescale_factor]) # the shape of rope is (T*H*W, -1) aka (seq_length, head_dim), as T is the first dimension, we can directly cut it. 
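# For example (illustrative numbers): with H = W = 32, rescale_factor = sqrt((32 * 32) / (16 * 16)) = 2,
# so ref_feat_shape becomes [T_total, 16, 16] and the spatial frequencies match the 16x16 base grid.
# The slice below keeps only the last T * H * W rows, i.e. the positions of the T frames actually
# present in x; the prev_clean_T prefix positions are dropped.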
rope = rope[-(T * H * W) :] ################################### # Part3: Embed t # ################################### assert t.shape[0] == batch_size, f"Invalid t shape, got {t.shape[0]} != {batch_size}" # nolint assert t.shape[1] == denoising_range_num, f"Invalid t shape, got {t.shape[1]} != {denoising_range_num}" # nolint t_flat = t.flatten() # (N * denoising_range_num,) t = self.t_embedder(t_flat) # (N, D) if self.engine_config.distill: distill_dt_scalar = 2 if kwargs["num_steps"] == 12: base_chunk_step = 4 distill_dt_factor = base_chunk_step / kwargs["distill_interval"] * distill_dt_scalar else: distill_dt_factor = kwargs["num_steps"] / 4 * distill_dt_scalar distill_dt = torch.ones_like(t_flat) * distill_dt_factor distill_dt_embed = self.t_embedder(distill_dt) t = t + distill_dt_embed t = t.reshape(batch_size, denoising_range_num, -1) # (N, range_num, D) ###################################################### # Part4: Embed y, prepare condition and y_xattn_flat # ###################################################### # (N * denoising_range_num, 1, L, D) y_xattn, y_adaln = self.y_embedder(y, self.training, caption_dropout_mask) assert xattn_mask is not None xattn_mask = xattn_mask.squeeze(1).squeeze(1) # condition: (N, range_num, D) y_adaln = y_adaln.squeeze(1) # (N, D) condition = t + y_adaln.unsqueeze(1) assert condition.shape[0] == batch_size assert condition.shape[1] == denoising_range_num seqlen_per_chunk = (T * H * W) // denoising_range_num condition_map = torch.arange(batch_size * denoising_range_num, device=x.device) condition_map = torch.repeat_interleave(condition_map, seqlen_per_chunk) condition_map = condition_map.reshape(batch_size, -1).transpose(0, 1).contiguous() # y_xattn_flat: (total_token, D) y_xattn_flat = torch.masked_select(y_xattn.squeeze(1), xattn_mask.unsqueeze(-1).bool()).reshape(-1, y_xattn.shape[-1]) xattn_mask_for_cuda_graph = None ###################################################### # Part5: Prepare cross_attn_params for cross_atten # ###################################################### # (N * denoising_range_num, L) xattn_mask = xattn_mask.reshape(xattn_mask.shape[0], -1) y_index = torch.sum(xattn_mask, dim=-1) clip_token_nums = H * W * frame_in_range cu_seqlens_q = torch.Tensor([0] + ([clip_token_nums] * denoising_range_num * batch_size)).to(torch.int64).to(x.device) cu_seqlens_k = torch.cat([y_index.new_tensor([0]), y_index]).to(torch.int64).to(x.device) cu_seqlens_q = cu_seqlens_q.cumsum(-1).to(torch.int32) cu_seqlens_k = cu_seqlens_k.cumsum(-1).to(torch.int32) assert ( cu_seqlens_q.shape == cu_seqlens_k.shape ), f"cu_seqlens_q.shape: {cu_seqlens_q.shape}, cu_seqlens_k.shape: {cu_seqlens_k.shape}" xattn_q_ranges = torch.cat([cu_seqlens_q[:-1].unsqueeze(1), cu_seqlens_q[1:].unsqueeze(1)], dim=1) xattn_k_ranges = torch.cat([cu_seqlens_k[:-1].unsqueeze(1), cu_seqlens_k[1:].unsqueeze(1)], dim=1) assert ( xattn_q_ranges.shape == xattn_k_ranges.shape ), f"xattn_q_ranges.shape: {xattn_q_ranges.shape}, xattn_k_ranges.shape: {xattn_k_ranges.shape}" cross_attn_params = PackedCrossAttnParams( q_ranges=xattn_q_ranges, kv_ranges=xattn_k_ranges, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=clip_token_nums, max_seqlen_kv=self.caption_max_length, ) ################################################## # Part6: Prepare core_atten related q/kv range # ################################################## q_range = torch.cat([cu_seqlens_q[:-1].unsqueeze(1), cu_seqlens_q[1:].unsqueeze(1)], dim=1) flat_kv = torch.unique(kv_range, sorted=True) 
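# max_seqlen_k below is the span between the smallest and largest boundary in kv_range,
# an upper bound on how many key/value tokens any single query chunk can attend to.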
max_seqlen_k = (flat_kv[-1] - flat_kv[0]).cpu().item() ardf_meta = dict( clip_token_nums=clip_token_nums, slice_point=slice_point, range_num=range_num, denoising_range_num=denoising_range_num, q_range=q_range, k_range=kv_range, max_seqlen_q=clip_token_nums, max_seqlen_k=max_seqlen_k, ) return (x, condition, condition_map, rope, y_xattn_flat, xattn_mask_for_cuda_graph, H, W, ardf_meta, cross_attn_params) @torch.no_grad() def forward_pre_process( self, x, t, y, caption_dropout_mask=None, xattn_mask=None, kv_range=None, **kwargs ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, ModelMetaArgs]: assert kv_range is not None, "Please ensure kv_range is provided" x = x * self.model_config.x_rescale_factor if self.model_config.half_channel_vae: assert x.shape[1] == 16 x = torch.cat([x, x], dim=1) x = x.float() t = t.float() y = y.float() # embedder context will ensure that the processing is in high precision even if the embedder params is in bfloat16 mode with torch.autocast(device_type="cuda", dtype=torch.float32): ( x, condition, condition_map, rope, y_xattn_flat, xattn_mask_for_cuda_graph, H, W, ardf_meta, cross_attn_params, ) = self.get_embedding_and_meta(x, t, y, caption_dropout_mask, xattn_mask, kv_range, **kwargs) # Downcast x and rearrange x x = x.to(self.model_config.params_dtype) x = rearrange(x, "N C T H W -> (T H W) N C").contiguous() # (thw, N, D) # condition and y_xattn_flat will be downcast to bfloat16 in transformer block. condition = condition.to(self.model_config.params_dtype) y_xattn_flat = y_xattn_flat.to(self.model_config.params_dtype) core_attn_params = PackedCoreAttnParams( q_range=ardf_meta["q_range"], k_range=ardf_meta["k_range"], np_q_range=ardf_meta["q_range"].cpu().numpy(), np_k_range=ardf_meta["k_range"].cpu().numpy(), max_seqlen_q=ardf_meta["max_seqlen_q"], max_seqlen_k=ardf_meta["max_seqlen_k"], ) (x, condition_map, rope, cp_pad_size, cp_split_sizes, core_attn_params, cross_attn_params) = cp_pre_process( self.engine_config.cp_size, self.engine_config.cp_strategy, x, condition_map, rope, xattn_mask_for_cuda_graph, ardf_meta, core_attn_params, cross_attn_params, ) meta_args = ModelMetaArgs( H=H, W=W, cp_pad_size=cp_pad_size, cp_split_sizes=cp_split_sizes, slice_point=ardf_meta["slice_point"], denoising_range_num=ardf_meta["denoising_range_num"], range_num=ardf_meta["range_num"], extract_prefix_video_feature=kwargs.get("extract_prefix_video_feature", False), fwd_extra_1st_chunk=kwargs["fwd_extra_1st_chunk"], distill_nearly_clean_chunk=kwargs.get("distill_nearly_clean_chunk", False), clip_token_nums=ardf_meta["clip_token_nums"], enable_cuda_graph=xattn_mask_for_cuda_graph is not None, core_attn_params=core_attn_params, cross_attn_params=cross_attn_params, ) return (x, condition, condition_map, y_xattn_flat, rope, meta_args) @torch.no_grad() def forward_post_process(self, x, meta_args: ModelMetaArgs) -> torch.Tensor: x = x.float() # embedder context will ensure that the processing is in high precision even if the embedder params is in bfloat16 mode with torch.autocast(device_type="cuda", dtype=torch.float32): x = self.final_linear(x) # (thw/cp, N, patch_size ** 2 * out_channels) # leave context parallel region x = cp_post_process(self.engine_config.cp_size, self.engine_config.cp_strategy, x, meta_args) # N C T H W x = self.unpatchify(x, meta_args.H, meta_args.W) if self.model_config.half_channel_vae: assert x.shape[1] == 32 x = x[:, :16] x = x / self.model_config.x_rescale_factor return x @torch.no_grad() def forward( self, x, t, y, 
caption_dropout_mask=None, xattn_mask=None, kv_range=None, inference_params: InferenceParams = None, **kwargs, ) -> torch.Tensor: (x, condition, condition_map, y_xattn_flat, rope, meta_args) = self.forward_pre_process( x, t, y, caption_dropout_mask, xattn_mask, kv_range, **kwargs ) if not self.pre_process: x = pp_scheduler().recv_prev_data(x.shape, x.dtype) self.videodit_blocks.set_input_tensor(x) else: # clone a new tensor to ensure x is not a view of other tensor x = x.clone() x = self.videodit_blocks.forward( hidden_states=x, condition=condition, condition_map=condition_map, y_xattn_flat=y_xattn_flat, rotary_pos_emb=rope, inference_params=inference_params, meta_args=meta_args, ) if not self.post_process: pp_scheduler().isend_next(x) return self.forward_post_process(x, meta_args) def forward_3cfg( self, x, timestep, y, mask, kv_range, inference_params, **kwargs ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: """ Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. """ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb assert x.shape[0] == 2 assert mask.shape[0] % 2 == 0 # mask should be a multiple of 2 x = torch.cat([x[0:1], x[0:1]], dim=0) caption_dropout_mask = torch.tensor([False, True], dtype=torch.bool, device=x.device) inference_params.update_kv_cache = False out_cond_pre_and_text = self.forward( x[0:1], timestep[0:1], y[0 : y.shape[0] // 2], caption_dropout_mask=caption_dropout_mask[0:1], xattn_mask=mask[0 : y.shape[0] // 2], kv_range=kv_range, inference_params=inference_params, **kwargs, ) inference_params.update_kv_cache = True out_cond_pre = self.forward( x[1:2], timestep[1:2], y[y.shape[0] // 2 : y.shape[0]], caption_dropout_mask=caption_dropout_mask[1:2], xattn_mask=mask[y.shape[0] // 2 : y.shape[0]], kv_range=kv_range, inference_params=inference_params, **kwargs, ) def chunk_to_batch(input, denoising_range_num): input = input.squeeze(0) input = input.reshape(-1, denoising_range_num, kwargs["chunk_width"], *input.shape[2:]) return input.transpose(0, 1) # (denoising_range_num, chn, chunk_width, h, w) def batch_to_chunk(input, denoising_range_num): input = input.transpose(0, 1) input = input.reshape(1, -1, denoising_range_num * kwargs["chunk_width"], *input.shape[3:]) return input class UnconditionGuard: def __init__(self, kwargs): self.kwargs = kwargs self.prev_state = { "range_num": kwargs["range_num"], "denoising_range_num": kwargs["denoising_range_num"], "slice_point": kwargs["slice_point"], "fwd_extra_1st_chunk": kwargs["fwd_extra_1st_chunk"], } def __enter__(self): if self.kwargs.get("fwd_extra_1st_chunk", False): self.kwargs["denoising_range_num"] -= 1 self.kwargs["slice_point"] += 1 self.kwargs["fwd_extra_1st_chunk"] = False def __exit__(self, exc_type, exc_val, exc_tb): self.kwargs["range_num"] = self.prev_state["range_num"] self.kwargs["denoising_range_num"] = self.prev_state["denoising_range_num"] self.kwargs["slice_point"] = self.prev_state["slice_point"] self.kwargs["fwd_extra_1st_chunk"] = self.prev_state["fwd_extra_1st_chunk"] with UnconditionGuard(kwargs): denoising_range_num = kwargs["denoising_range_num"] denoise_width = kwargs["chunk_width"] * denoising_range_num uncond_x = chunk_to_batch(x[0:1, :, -denoise_width:], denoising_range_num) timestep = timestep[0:1, -denoising_range_num:].transpose(0, 1) uncond_y = y[y.shape[0] // 2 : y.shape[0]][-denoising_range_num:] caption_dropout_mask = torch.tensor([True], dtype=torch.bool, device=x.device) uncond_mask = mask[y.shape[0] // 2 
: y.shape[0]][-denoising_range_num:] uncond_kv_range = self.generate_kv_range_for_uncondition(uncond_x) kwargs["range_num"] = 1 kwargs["denoising_range_num"] = 1 kwargs["slice_point"] = 0 out_uncond = self.forward( uncond_x, timestep, uncond_y, caption_dropout_mask=caption_dropout_mask, xattn_mask=uncond_mask, kv_range=uncond_kv_range, inference_params=None, **kwargs, ) out_uncond = batch_to_chunk(out_uncond, denoising_range_num) return out_cond_pre_and_text, out_cond_pre, out_uncond, denoise_width def get_cfg_scale(self, t, cfg_t_range, prev_chunk_scale_s, text_scale_s): indices = torch.searchsorted(cfg_t_range - 1e-7, t) - 1 assert indices.min() >= 0 and indices.max() < len(prev_chunk_scale_s) return prev_chunk_scale_s[indices], text_scale_s[indices] def forward_dispatcher(self, x, timestep, y, mask, kv_range, inference_params, **kwargs): if self.runtime_config.cfg_number == 3: (out_cond_pre_and_text, out_cond_pre, out_uncond, denoise_width) = self.forward_3cfg( x, timestep, y, mask, kv_range, inference_params, **kwargs ) prev_chunk_scale_s = torch.tensor(self.runtime_config.prev_chunk_scales).cuda() text_scale_s = torch.tensor(self.runtime_config.text_scales).cuda() cfg_t_range = torch.tensor(self.runtime_config.cfg_t_range).cuda() applied_cfg_range_num, chunk_width = (kwargs["denoising_range_num"], kwargs["chunk_width"]) if kwargs["fwd_extra_1st_chunk"]: applied_cfg_range_num -= 1 cfg_timestep = timestep[0, -applied_cfg_range_num:] assert len(prev_chunk_scale_s) == len(cfg_t_range), "prev_chunks_scale and t_range should have the same length" assert len(text_scale_s) == len(cfg_t_range), "text_scale and t_range should have the same length" cfg_output_list = [] for chunk_idx in range(applied_cfg_range_num): prev_chunk_scale, text_scale = self.get_cfg_scale( cfg_timestep[chunk_idx], cfg_t_range, prev_chunk_scale_s, text_scale_s ) l = chunk_idx * chunk_width r = (chunk_idx + 1) * chunk_width cfg_output = ( (1 - prev_chunk_scale) * out_uncond[:, :, l:r] + (prev_chunk_scale - text_scale) * out_cond_pre[:, :, -denoise_width:][:, :, l:r] + text_scale * out_cond_pre_and_text[:, :, -denoise_width:][:, :, l:r] ) cfg_output_list.append(cfg_output) cfg_output = torch.cat(cfg_output_list, dim=2) x = torch.cat([x[0:1, :, :-denoise_width], cfg_output], dim=2) x = torch.cat([x, x], dim=0) return x elif self.runtime_config.cfg_number == 1: assert x.shape[0] == 2 x = torch.cat([x[0:1], x[0:1]], dim=0) kwargs["caption_dropout_mask"] = torch.tensor([False], dtype=torch.bool, device=x.device) inference_params.update_kv_cache = True if kwargs.get("distill_nearly_clean_chunk", False): prev_chunks_scale = float(os.getenv("prev_chunks_scale", 0.7)) slice_start = 1 if kwargs["fwd_extra_1st_chunk"] else 0 cond_pre_and_text_channel = x.shape[2] new_x_chunk = x[0:1, :, slice_start * kwargs["chunk_width"] : (slice_start + 1) * kwargs["chunk_width"]] new_kvrange = self.generate_kv_range_for_uncondition(new_x_chunk) kwargs["denoising_range_num"] += 1 cat_x_chunk = torch.cat([x[0:1], new_x_chunk], dim=2) new_kvrange = new_kvrange + kv_range.max() cat_kvrange = torch.cat([kv_range, new_kvrange], dim=0) cat_t = torch.cat([timestep[0:1], timestep[0:1, slice_start : slice_start + 1]], dim=1) cat_y = torch.cat([y[0 : y.shape[0] // 2], y[slice_start : slice_start + 1]], dim=0) cat_xattn_mask = torch.cat([mask[0 : y.shape[0] // 2], mask[slice_start : slice_start + 1]], dim=0) cat_out = self.forward( cat_x_chunk, cat_t, cat_y, xattn_mask=cat_xattn_mask, kv_range=cat_kvrange, inference_params=inference_params, **kwargs, ) 
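# The chunk appended above is denoised again without prefix context (its kv_range is offset past
# the existing ranges, so it attends only to its own tokens). Below, the two predictions for that
# chunk are blended as prev_chunks_scale * out_with_prefix + (1 - prev_chunks_scale) * out_without_prefix
# and written back into cat_out.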
near_clean_out_cond_pre_and_text = cat_out[ :, :, slice_start * kwargs["chunk_width"] : (slice_start + 1) * kwargs["chunk_width"] ] near_clean_out_cond_text = cat_out[:, :, cond_pre_and_text_channel:] near_out_cond_pre_and_text = ( near_clean_out_cond_pre_and_text * prev_chunks_scale + near_clean_out_cond_text * (1 - prev_chunks_scale) ) cat_out[ :, :, slice_start * kwargs["chunk_width"] : (slice_start + 1) * kwargs["chunk_width"] ] = near_out_cond_pre_and_text out_cond_pre_and_text = cat_out[:, :, :cond_pre_and_text_channel] else: out_cond_pre_and_text = self.forward( x[0:1], timestep[0:1], y[0 : y.shape[0] // 2], xattn_mask=mask[0 : y.shape[0] // 2], kv_range=kv_range, inference_params=inference_params, **kwargs, ) denoise_width = kwargs["chunk_width"] * kwargs["denoising_range_num"] if kwargs["fwd_extra_1st_chunk"]: denoise_width -= kwargs["chunk_width"] x = torch.cat([x[0:1, :, :-denoise_width], out_cond_pre_and_text[:, :, -denoise_width:]], dim=2) x = torch.cat([x[0:1], x[0:1]], dim=0) return x else: raise NotImplementedError def _build_dit_model(config: MagiConfig): """Builds the model""" device = "cuda" if env_is_true("SKIP_LOAD_MODEL") else "meta" with torch.device(device): model = VideoDiTModel( config=config, pre_process=mpu.is_pipeline_first_stage(), post_process=mpu.is_pipeline_last_stage() ) print_rank_0(model) # Print number of parameters. param_count = sum([p.nelement() for p in model.parameters()]) model_size_gb = sum([p.nelement() * p.element_size() for p in model.parameters()]) / (1024**3) print_per_rank( f"(cp, pp) rank ({mpu.get_cp_rank()}, {mpu.get_pp_rank()}): param count {param_count}, model size {model_size_gb:.2f} GB".format( mpu.get_cp_rank(), mpu.get_pp_rank(), param_count, model_size_gb ) ) return model def _high_precision_promoter(module: VideoDiTModel): module.x_embedder.float() module.y_embedder.float() module.t_embedder.float() module.final_linear.float() module.rope.float() for name, sub_module in module.named_modules(): # skip qk_layernorm_xattn if "_xattn" in name: continue # high precision qk_layernorm by default if "q_layernorm" in name or "k_layernorm" in name: sub_module.float() if "self_attn_post_norm" in name or "mlp_post_norm" in name: sub_module.float() if "final_layernorm" in name: sub_module.float() return module def get_dit(config: MagiConfig): """Build and load VideoDiT model""" model = _build_dit_model(config) print_rank_0("Build DiTModel successfully") mem_allocated_gb = torch.cuda.memory_allocated() / 1024**3 mem_reserved_gb = torch.cuda.memory_reserved() / 1024**3 print_rank_0( f"After build_dit_model, memory allocated: {mem_allocated_gb:.2f} GB, memory reserved: {mem_reserved_gb:.2f} GB" ) # To avoid Error in debug mode, set default iteration to 0 if not env_is_true("SKIP_LOAD_MODEL"): model = load_checkpoint(model) mem_allocated_gb = torch.cuda.memory_allocated() / 1024**3 mem_reserved_gb = torch.cuda.memory_reserved() / 1024**3 print_rank_0( f"After load_checkpoint, memory allocated: {mem_allocated_gb:.2f} GB, memory reserved: {mem_reserved_gb:.2f} GB" ) model = _high_precision_promoter(model) mem_allocated_gb = torch.cuda.memory_allocated() / 1024**3 mem_reserved_gb = torch.cuda.memory_reserved() / 1024**3 print_rank_0( f"After high_precision_promoter, memory allocated: {mem_allocated_gb:.2f} GB, memory reserved: {mem_reserved_gb:.2f} GB" ) model.eval() gc.collect() torch.cuda.empty_cache() print_rank_0("Load checkpoint successfully") return model ``` ## /inference/model/dit/dit_module.py ```py path="/inference/model/dit/dit_module.py" 
# Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import numbers from functools import partial from typing import Callable, List, Optional, Tuple import flashinfer import torch import torch.distributed import torch.nn as nn import triton import triton.language as tl from einops import rearrange from flash_attn import flash_attn_varlen_func from flash_attn.flash_attn_interface import flash_attn_func from flash_attn.layers.rotary import apply_rotary_emb as flash_apply_rotary_emb from flashinfer.gemm import bmm_fp8 from magi_attention.functional import flex_flash_attn_func as flex_attention # from dffa.functional import flex_flash_attn_func as flex_attention from torch import Tensor from torch.nn import Parameter from inference.common import EngineConfig, InferenceParams, ModelConfig, ModelMetaArgs, PackedCrossAttnParams, divide from inference.infra.distributed import parallel_state from inference.infra.parallelism import CSOHelper, UlyssesScheduler, cso_communication ########################################################## # TimestepEmbedder ########################################################## class TimestepEmbedder(nn.Module): """ Embeds scalar timesteps into vector representations. """ def __init__(self, model_config: ModelConfig, frequency_embedding_size=256): super().__init__() self.data_type = model_config.params_dtype hidden_size = model_config.hidden_size self.mlp = nn.Sequential( nn.Linear(frequency_embedding_size, int(hidden_size * model_config.cond_hidden_ratio), bias=True), nn.SiLU(), nn.Linear( int(hidden_size * model_config.cond_hidden_ratio), int(hidden_size * model_config.cond_hidden_ratio), bias=True ), ) self.frequency_embedding_size = frequency_embedding_size # rescale the timestep for the general transport model self.timestep_rescale_factor = 1000 @staticmethod def timestep_embedding(t, dim, max_period=10000, timestep_rescale_factor=1): """ Create sinusoidal timestep embeddings. :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an (N, D) Tensor of positional embeddings. 
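Illustrative example: with the static method's default arguments, timestep_embedding(torch.tensor([0.0]), dim=4) returns [[1., 1., 0., 0.]], i.e. the cosine terms for both frequencies followed by the sine terms (forward() additionally applies timestep_rescale_factor=1000 to t).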
""" # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py half = dim // 2 freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( device=t.device ) args = t[:, None].float() * freqs[None] * timestep_rescale_factor embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding def forward(self, t): t = t.to(torch.float32) t_freq = self.timestep_embedding( t, self.frequency_embedding_size, timestep_rescale_factor=self.timestep_rescale_factor ) t_emb = self.mlp(t_freq.to(self.data_type)) return t_emb ########################################################## # CaptionEmbedder ########################################################## class CaptionEmbedder(nn.Module): """ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. """ def __init__(self, model_config: ModelConfig): super().__init__() in_channels = model_config.caption_channels hidden_size = model_config.hidden_size caption_max_length = model_config.caption_max_length self.y_proj_xattn = nn.Sequential( nn.Linear(in_channels, int(hidden_size * model_config.xattn_cond_hidden_ratio), bias=True), nn.SiLU() ) self.y_proj_adaln = nn.Sequential(nn.Linear(in_channels, int(hidden_size * model_config.cond_hidden_ratio), bias=True)) self.null_caption_embedding = Parameter(torch.empty(caption_max_length, in_channels)) def caption_drop(self, caption, caption_dropout_mask): """ Drops labels to enable classifier-free guidance. caption.shape = (N, 1, cap_len, C) """ dropped_caption = torch.where( caption_dropout_mask[:, None, None, None], # (N, 1, 1, 1) self.null_caption_embedding[None, None, :], # (1, 1, cap_len, C) caption, # (N, 1, cap_len, C) ) return dropped_caption def caption_drop_single_token(self, caption_dropout_mask): dropped_caption = torch.where( caption_dropout_mask[:, None, None], # (N, 1, 1) self.null_caption_embedding[None, -1, :], # (1, 1, C) self.null_caption_embedding[None, -2, :], # (1, 1, C) ) return dropped_caption # (N, 1, C) def forward(self, caption, train, caption_dropout_mask=None): if train and caption_dropout_mask is not None: caption = self.caption_drop(caption, caption_dropout_mask) caption_xattn = self.y_proj_xattn(caption) if caption_dropout_mask is not None: caption = self.caption_drop_single_token(caption_dropout_mask) caption_adaln = self.y_proj_adaln(caption) return caption_xattn, caption_adaln ########################################################## # FinalLinear ########################################################## class FinalLinear(nn.Module): """ The final linear layer of DiT. 
""" def __init__(self, hidden_size, patch_size, t_patch_size, out_channels): super().__init__() self.linear = nn.Linear(hidden_size, patch_size * patch_size * t_patch_size * out_channels, bias=False) def forward(self, x): x = self.linear(x) return x ########################################################## # AdaModulateLayer ########################################################## class AdaModulateLayer(torch.nn.Module): def __init__(self, model_config: ModelConfig): super().__init__() self.model_config = model_config self.gate_num_chunks = 2 self.act = nn.SiLU() self.proj = nn.Sequential( nn.Linear( int(self.model_config.hidden_size * self.model_config.cond_hidden_ratio), int(self.model_config.hidden_size * self.model_config.cond_gating_ratio * self.gate_num_chunks), bias=True, dtype=self.model_config.params_dtype, ) ) def forward(self, c): c = self.act(c) return self.proj(c) ########################################################## # bias_modulate_add ########################################################## @triton.jit def range_mod_kernel_fwd( X, # pointer to the input MAP, # map x index to gating index GATINGS, # pointer to the gatings Y, # pointer to the output M, # number of rows in X, unused N, # number of columns in X stride_xm, # how much to increase the pointer when moving by 1 row in X stride_xn, # how much to increase the pointer when moving by 1 column in X stride_gm, # how much to increase the pointer when moving by 1 row in GATINGS stride_gn, # how much to increase the pointer when moving by 1 column in GATINGS stride_ym, # how much to increase the pointer when moving by 1 row in Y stride_yn, # how much to increase the pointer when moving by 1 column in Y BLOCK_SIZE: tl.constexpr, # number of columns in a block ): # Map the program id to the row of X and Y it should compute. row = tl.program_id(0) cur_X = X + row * stride_xm x_cols = tl.arange(0, BLOCK_SIZE) * stride_xn x_mask = x_cols < N * stride_xn x = tl.load(cur_X + x_cols, mask=x_mask, other=0.0) cur_MAP = MAP + row gating_index = tl.load(cur_MAP) cur_GATING = GATINGS + gating_index * stride_gm gating_cols = tl.arange(0, BLOCK_SIZE) * stride_gn gating_mask = gating_cols < N * stride_gn gating = tl.load(cur_GATING + gating_cols, mask=gating_mask, other=0.0) cur_Y = Y + row * stride_ym y_cols = tl.arange(0, BLOCK_SIZE) * stride_yn y_mask = y_cols < N * stride_yn tl.store(cur_Y + y_cols, x * gating, mask=y_mask) def range_mod_triton(x, c_mapping, gatings): """ Inputs: x: (s, b, h). Tensor of inputs embedding (images or latent representations of images) c_mapping: (s, b). Tensor of condition map gatings: (b, denoising_range_num, h). 
Tensor of condition embedding """ assert x.is_cuda, "x is not on cuda" assert c_mapping.is_cuda, "c_mapping is not on cuda" assert gatings.is_cuda, "gatings is not on cuda" # TODO: use 3D tensor for x, c_mapping, and gatings s, b, h = x.shape x = x.transpose(0, 1).flatten(0, 1) c_mapping = c_mapping.transpose(0, 1).flatten(0, 1) gatings = gatings.flatten(0, 1) assert x.dim() == 2, f"x must be a 2D tensor but got {x.dim()}D" assert c_mapping.dim() == 1, f"c_mapping must be a 1D tensor but got {c_mapping.dim()}D" assert gatings.dim() == 2, f"gatings must be a 2D tensor but got {gatings.dim()}D" M, N = x.shape assert c_mapping.size(0) == M, "c_mapping must have the same number of rows as x" # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) if N > BLOCK_SIZE: raise RuntimeError("range_mod_triton doesn't support feature dim >= 64KB.") MAP = c_mapping y = torch.empty_like(x) range_mod_kernel_fwd[(M,)]( x, MAP, gatings, y, M, N, x.stride(0), x.stride(1), gatings.stride(0), gatings.stride(1), y.stride(0), y.stride(1), BLOCK_SIZE=BLOCK_SIZE, ) y = y.reshape(b, s, h).transpose(0, 1) return y def bias_modulate_add( x: torch.Tensor, residual: torch.Tensor, condition_map: torch.Tensor, gate: torch.Tensor, post_norm: torch.nn.Module ): assert gate.shape[-1] == x.shape[-1] original_dtype = x.dtype x = x.float() residual = residual.float() gate = gate.float() x = range_mod_triton(x, condition_map, gate) x = post_norm(x) x = x + residual x = x.to(original_dtype) return x ########################################################## # FusedLayerNorm ########################################################## def make_viewless_tensor(inp, requires_grad): # return tensor as-is, if not a 'view' if inp._base is None: return inp out = torch.empty((1,), dtype=inp.dtype, device=inp.device, requires_grad=requires_grad) out.data = inp.data return out class FusedLayerNorm(torch.nn.Module): """ Layer Norm, fused into a single CUDA kernel. Borrow from: https://github.com/NVIDIA/Megatron-LM/blob/6501752396e9cc360ce894cda4b2217a58c1c09d/megatron/core/fusions/fused_layer_norm.py#L30 Args: hidden_size (int): Transformer hidden dimension. eps (float): Epsilon added to denominator, for numerical stability. zero_centered_gamma (bool): Adjust LayerNorm weights such that they are centered around zero. This improves numerical stability. model_config (ModelConfig): Transformer config. Include to match custom layer norm interfaces. normalization (str): Normalization type, used for Transformer Engine. Must equal 'LayerNorm' here. 
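Note: when zero_centered_gamma is enabled (model_config.apply_layernorm_1p), the stored weight holds gamma - 1 and forward() adds 1 back before calling torch.nn.functional.layer_norm.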
""" def __init__(self, model_config: ModelConfig, hidden_size: int): super().__init__() self.zero_centered_gamma = model_config.apply_layernorm_1p if isinstance(hidden_size, numbers.Integral): hidden_size = (hidden_size,) self.hidden_size = torch.Size(hidden_size) self.eps = model_config.layernorm_epsilon self.weight = Parameter(torch.empty(*hidden_size, dtype=model_config.params_dtype)) self.bias = Parameter(torch.empty(*hidden_size, dtype=model_config.params_dtype)) def forward(self, input: Tensor) -> Tensor: weight = self.weight + 1 if self.zero_centered_gamma else self.weight return torch.nn.functional.layer_norm(input, self.hidden_size, weight, self.bias, self.eps) def softcap(x: torch.Tensor, cap: int): return (cap * torch.tanh(x.float() / cap)).to(x.dtype) def div_clamp_to(x: torch.Tensor, scale: torch.Tensor): fp8_min = torch.finfo(torch.float8_e4m3fn).min fp8_max = torch.finfo(torch.float8_e4m3fn).max prefix_shape = x.shape[:-1] last_shape = x.shape[-1] x = x.flatten().reshape(-1, last_shape) # Split x into 256 MB parts to avoid big memory peak part_size = 256 * 1024 * 1024 // last_shape part_num = (x.shape[0] + part_size - 1) // part_size return ( torch.cat( [ torch.clamp(x[i * part_size : (i + 1) * part_size].float() / scale.float(), fp8_min, fp8_max).bfloat16() for i in range(part_num) ], dim=0, ) .to(torch.float8_e4m3fn) .reshape(*prefix_shape, last_shape) .contiguous() ) ########################################################## # CustomLayerNormLinear ########################################################## class CustomLayerNormLinear(torch.nn.Module): def __init__( self, input_size: int, output_size_q: int, output_size_kv: int, layer_number: int, model_config: ModelConfig, engine_config: EngineConfig, ): super().__init__() self.layer_norm = torch.nn.LayerNorm(input_size, eps=model_config.layernorm_epsilon, dtype=model_config.params_dtype) self.layer_number = layer_number layers = {"q": output_size_q, "qx": output_size_q, "k": output_size_kv, "v": output_size_kv} for name, output_size in layers.items(): if not engine_config.fp8_quant or self.layer_number == 0 or self.layer_number == model_config.num_layers - 1: setattr(self, name, torch.nn.Linear(input_size, output_size, bias=False, dtype=model_config.params_dtype)) else: setattr(self, name, PerTensorQuantizedFp8Linear(input_size, output_size)) def forward_ln(self, hidden_states): return self.layer_norm(hidden_states) def forward_q(self, hidden_states): return self.q(hidden_states) def forward_qx(self, hidden_states): return self.qx(hidden_states) def forward_k(self, hidden_states): return self.k(hidden_states) def forward_v(self, hidden_states): return self.v(hidden_states) ########################################################## # PerTensorQuantizedFp8Linear ########################################################## class PerTensorQuantizedFp8Linear(torch.nn.Module): # The bias and device parameter is not used; it is included for compatibility with Linear's parameters. 
def __init__(self, in_features: int, out_features: int, bias=False, dtype=torch.bfloat16, device=None) -> None: super().__init__() self.in_features = in_features self.out_features = out_features self.finfo = torch.finfo(torch.float8_e4m3fn) self.output_dtype = dtype self.weight = Parameter(torch.empty((1, out_features, in_features), dtype=torch.float8_e4m3fn)) self.weight_scale = Parameter(torch.empty(1, dtype=torch.float32)) self.input_scale = Parameter(torch.empty(in_features, dtype=torch.float32)) def forward(self, input: torch.Tensor): input = div_clamp_to(input, self.input_scale) prefix_shape = input.shape[:-1] # column major weight return bmm_fp8( input.reshape(1, -1, self.in_features), self.weight.transpose(-2, -1), self.input_scale, self.weight_scale, dtype=self.output_dtype, ).reshape(prefix_shape + (self.out_features,)) ########################################################## # PerChannelQuantizedFp8Linear ########################################################## class PerChannelQuantizedFp8Linear(torch.nn.Module): # The bias and device parameter is not used; it is included for compatibility with Linear's parameters. def __init__(self, in_features: int, out_features: int, bias=False, dtype=torch.bfloat16, device=None) -> None: super().__init__() self.in_features = in_features self.out_features = out_features self.output_dtype = dtype self.finfo = torch.finfo(torch.float8_e4m3fn) self.weight = Parameter(torch.empty((1, out_features, in_features), dtype=torch.float8_e4m3fn)) self.weight_scale = Parameter(torch.empty(1, dtype=torch.float32)) self.input_scale = Parameter(torch.empty(1, dtype=torch.float32)) self.smooth_scale = Parameter(torch.empty(1, in_features, dtype=torch.float32)) def forward(self, x): x = div_clamp_to(x, self.smooth_scale.to(torch.float32)) prefix_shape = x.shape[:-1] return bmm_fp8( x.reshape(1, -1, self.in_features), self.weight.transpose(-2, -1), self.input_scale, self.weight_scale, dtype=self.output_dtype, ).reshape(prefix_shape + (self.out_features,)) ########################################################## # CustomMLP ########################################################## class CustomMLP(torch.nn.Module): """ CustomMLP will take the input with h hidden state, project it to 4*h hidden dimension, perform nonlinear transformation, and project the state back into h hidden dimension. Returns an output and a bias to be added to the output. 
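(Here both linear_fc1 and linear_fc2 are created with bias=False, so forward() returns only the transformed hidden states.)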
We use the following notation: h: hidden size p: number of tensor model parallel partitions b: batch size s: sequence length """ def __init__(self, model_config: ModelConfig, engine_config: EngineConfig, layer_number: int, input_size: int = None): super().__init__() self.model_config: ModelConfig = model_config self.engine_config: EngineConfig = engine_config self.layer_number = layer_number self.input_size = input_size if input_size != None else self.model_config.hidden_size self.layer_norm = torch.nn.LayerNorm( self.input_size, eps=self.model_config.layernorm_epsilon, dtype=self.model_config.params_dtype ) submodules_linear_fc1 = torch.nn.Linear if self.engine_config.fp8_quant and self.layer_number != 0 and self.layer_number != model_config.num_layers - 1: submodules_linear_fc1 = PerTensorQuantizedFp8Linear if self.model_config.gated_linear_unit: self.linear_fc1 = submodules_linear_fc1( self.input_size, 2 * self.model_config.ffn_hidden_size, bias=False, dtype=self.model_config.params_dtype ) else: self.linear_fc1 = submodules_linear_fc1( self.input_size, self.model_config.ffn_hidden_size, bias=False, dtype=self.model_config.params_dtype ) submodules_linear_fc2 = torch.nn.Linear if engine_config.fp8_quant and self.layer_number != 0 and self.layer_number != model_config.num_layers - 1: submodules_linear_fc2 = PerChannelQuantizedFp8Linear self.linear_fc2 = submodules_linear_fc2( self.model_config.ffn_hidden_size, self.model_config.hidden_size, bias=False, dtype=self.model_config.params_dtype ) def forward(self, hidden_states): hidden_states = self.layer_norm(hidden_states) hidden_states = self.linear_fc1(hidden_states) if self.model_config.gated_linear_unit: hidden_states = flashinfer.activation.silu_and_mul(hidden_states) else: hidden_states = torch.nn.functional.gelu(hidden_states) hidden_states = self.linear_fc2(hidden_states) return hidden_states ########################################################## # LearnableRotaryEmbeddingCat ########################################################## def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]: """generate N-D grid in dimension order. The ndgrid function is like meshgrid except that the order of the first two input arguments are switched. That is, the statement [X1,X2,X3] = ndgrid(x1,x2,x3) produces the same result as [X2,X1,X3] = meshgrid(x2,x1,x3) This naming is based on MATLAB, the purpose is to avoid confusion due to torch's change to make torch.meshgrid behaviour move from matching ndgrid ('ij') indexing to numpy meshgrid defaults of ('xy'). 
""" try: return torch.meshgrid(*tensors, indexing="ij") except TypeError: # old PyTorch < 1.10 will follow this path as it does not have indexing arg, # the old behaviour of meshgrid was 'ij' return torch.meshgrid(*tensors) def pixel_freq_bands( num_bands: int, max_freq: float = 224.0, linear_bands: bool = True, device: Optional[torch.device] = None ): if linear_bands: bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device) else: bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device) return bands * torch.pi def freq_bands( num_bands: int, temperature: float = 10000.0, step: int = 2, device: Optional[torch.device] = None ) -> torch.Tensor: exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands bands = 1.0 / (temperature**exp) return bands def build_fourier_pos_embed( feat_shape: List[int], bands: Optional[torch.Tensor] = None, num_bands: int = 64, max_res: int = 224, temperature: float = 10000.0, linear_bands: bool = False, include_grid: bool = False, in_pixels: bool = True, ref_feat_shape: Optional[List[int]] = None, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ) -> List[torch.Tensor]: """ Args: feat_shape: Feature shape for embedding. bands: Pre-calculated frequency bands. num_bands: Number of frequency bands (determines output dim). max_res: Maximum resolution for pixel based freq. temperature: Temperature for non-pixel freq. linear_bands: Linear band spacing for pixel based freq. include_grid: Include the spatial grid in output. in_pixels: Output in pixel freq. ref_feat_shape: Reference feature shape for resize / fine-tune. dtype: Output dtype. device: Output device. Returns: """ if bands is None: if in_pixels: bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, device=device) else: bands = freq_bands(num_bands, temperature=temperature, step=1, device=device) else: if device is None: device = bands.device if dtype is None: dtype = bands.dtype if in_pixels: t = [torch.linspace(-1.0, 1.0, steps=s, device=device, dtype=torch.float32) for s in feat_shape] else: t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape] # align spatial center (H/2,W/2) to (0,0) t[1] = t[1] - (feat_shape[1] - 1) / 2 t[2] = t[2] - (feat_shape[2] - 1) / 2 if ref_feat_shape is not None: # eva's scheme for resizing rope embeddings (ref shape = pretrain) # aligning to the endpoint e.g [0,1,2] -> [0, 0.4, 0.8, 1.2, 1.6, 2] t_rescaled = [] for x, f, r in zip(t, feat_shape, ref_feat_shape): # deal with image input if f == 1: assert r == 1, "ref_feat_shape must be 1 when feat_shape is 1" t_rescaled.append(x) else: t_rescaled.append(x / (f - 1) * (r - 1)) t = t_rescaled grid = torch.stack(ndgrid(t), dim=-1) grid = grid.unsqueeze(-1) pos = grid * bands pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype) out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos] return out def build_rotary_pos_embed( feat_shape: List[int], bands: Optional[torch.Tensor] = None, dim: int = 64, max_res: int = 224, temperature: float = 10000.0, linear_bands: bool = False, in_pixels: bool = True, ref_feat_shape: Optional[List[int]] = None, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ): """ Args: feat_shape: Spatial shape of the target tensor for embedding. bands: Optional pre-generated frequency bands dim: Output dimension of embedding tensor. 
        max_res: Maximum resolution for pixel mode.
        temperature: Temperature (inv freq) for non-pixel mode.
        linear_bands: Linearly (instead of log) spaced bands for pixel mode.
        in_pixels: Pixel vs language (inv freq) mode.
        dtype: Output dtype.
        device: Output device.

    Returns:
        Tuple of (sin_emb, cos_emb) tensors.
    """
    sin_emb, cos_emb = build_fourier_pos_embed(
        feat_shape,
        bands=bands,
        num_bands=dim // 8,
        max_res=max_res,
        temperature=temperature,
        linear_bands=linear_bands,
        in_pixels=in_pixels,
        ref_feat_shape=ref_feat_shape,
        device=device,
        dtype=dtype,
    )
    num_spatial_dim = 1
    # this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks
    for x in feat_shape:
        num_spatial_dim *= x
    sin_emb = sin_emb.reshape(num_spatial_dim, -1)
    cos_emb = cos_emb.reshape(num_spatial_dim, -1)
    return sin_emb, cos_emb


class LearnableRotaryEmbeddingCat(nn.Module):
    """Rotary position embedding w/ concatenated sin & cos

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """

    def __init__(
        self,
        dim,
        max_res=224,
        temperature=10000,
        in_pixels=True,
        linear_bands: bool = False,
        feat_shape: Optional[List[int]] = None,
        ref_feat_shape: Optional[List[int]] = None,
    ):
        super().__init__()
        self.dim = dim
        self.max_res = max_res
        self.temperature = temperature
        self.in_pixels = in_pixels
        self.linear_bands = linear_bands
        self.feat_shape = feat_shape
        self.ref_feat_shape = ref_feat_shape
        self.bands = nn.Parameter(self.get_default_bands())

    def get_default_bands(self):
        if self.in_pixels:
            bands = pixel_freq_bands(
                self.dim // 8, float(self.max_res), linear_bands=self.linear_bands, device=torch.cuda.current_device()
            )
        else:
            bands = freq_bands(self.dim // 8, temperature=self.temperature, step=1, device=torch.cuda.current_device())
        return bands

    def get_embed(self, shape: Optional[List[int]], ref_feat_shape: Optional[List[int]] = None):
        # rebuild bands and embeddings every call, use if target shape changes
        embeds = build_rotary_pos_embed(
            feat_shape=shape,
            bands=self.bands,  # use learned bands
            dim=self.dim,
            max_res=self.max_res,
            linear_bands=self.linear_bands,
            in_pixels=self.in_pixels,
            ref_feat_shape=ref_feat_shape if ref_feat_shape else self.ref_feat_shape,
            temperature=self.temperature,
            device=torch.cuda.current_device(),
        )
        return torch.cat(embeds, -1)


##########################################################
# Attention
##########################################################


class Attention(torch.nn.Module):
    """
    Attention layer abstract class.
    """

    def __init__(self, model_config: ModelConfig, engine_config: EngineConfig, layer_number: int):
        super().__init__()
        self.model_config: ModelConfig = model_config
        self.engine_config: EngineConfig = engine_config
        self.layer_number = layer_number

        self.hidden_size_per_attention_head = self.model_config.kv_channels
        # num_query_groups and num_attention_heads are different for GQA
        self.query_projection_size = self.model_config.kv_channels * self.model_config.num_attention_heads
        self.kv_projection_size = self.model_config.kv_channels * self.model_config.num_query_groups

        # Per attention head and per partition values.
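        # Example with hypothetical numbers: num_query_groups = 8 and a TP*CP world size of 32 hits the
        # first branch (32 > 8 and 32 % 8 == 0), so each rank keeps a single KV group; a world size of 4
        # takes the second branch and keeps 8 / 4 = 2 groups per rank.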
world_size = parallel_state.get_tp_world_size(with_context_parallel=True) if world_size > self.model_config.num_query_groups and world_size % self.model_config.num_query_groups == 0: self.num_query_groups_per_partition = 1 else: self.num_query_groups_per_partition = divide(self.model_config.num_query_groups, world_size) def _allocate_key_and_value_memory(self, sequence_length, batch_size, dtype): """Allocate memory to store kv cache during inference.""" if self.engine_config.kv_offload: return torch.empty( sequence_length * batch_size, self.num_query_groups_per_partition, self.hidden_size_per_attention_head * 2, dtype=dtype, device=torch.cpu.current_device(), pin_memory=True, ) else: return torch.empty( sequence_length * batch_size, self.num_query_groups_per_partition, self.hidden_size_per_attention_head * 2, dtype=dtype, device=torch.cuda.current_device(), ) ########################################################## # FullyParallelAttention ########################################################## def split_tensor_along_last_dim( tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False ) -> List[torch.Tensor]: """Split a tensor along its last dimension. Args: tensor: input tensor. num_partitions: number of partitions to split the tensor contiguous_split_chunks: If True, make each chunk contiguous in memory. Returns: A list of Tensors """ # Get the size and dimension. last_dim = tensor.dim() - 1 last_dim_size = divide(tensor.size()[last_dim], num_partitions) # Split. tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) # Note: torch.split does not create contiguous tensors by default. if contiguous_split_chunks: return tuple(chunk.contiguous() for chunk in tensor_list) return tensor_list class FullyParallelAttention(Attention): def __init__(self, model_config: ModelConfig, engine_config: EngineConfig, layer_number: int): super().__init__(model_config=model_config, engine_config=engine_config, layer_number=layer_number) # output 2x query, one for self-attn, one for cross-attn with condition self.linear_qkv = CustomLayerNormLinear( input_size=self.model_config.hidden_size, output_size_q=self.query_projection_size, output_size_kv=self.kv_projection_size, layer_number=self.layer_number, model_config=self.model_config, engine_config=self.engine_config, ) # kv from condition, e.g., caption self.linear_kv_xattn = torch.nn.Linear( int(self.model_config.hidden_size * self.model_config.xattn_cond_hidden_ratio), # 6144 2 * self.kv_projection_size, # 2048 dtype=self.model_config.params_dtype, bias=False, ) # Output. 
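        # As in CustomLayerNormLinear above, FP8 is only applied to the "middle" layers: the first
        # (layer_number == 0) and last (num_layers - 1) layers keep the bf16 linear projection,
        # a common precaution for the layers most sensitive to quantization error.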
self.adapt_linear_quant = ( self.engine_config.fp8_quant and self.layer_number != 0 and self.layer_number != model_config.num_layers - 1 ) submodules_linear_proj = PerChannelQuantizedFp8Linear if self.adapt_linear_quant else torch.nn.Linear self.linear_proj = submodules_linear_proj( 2 * self.query_projection_size, self.model_config.hidden_size, dtype=self.model_config.params_dtype, bias=False ) self.q_layernorm = FusedLayerNorm(model_config=self.model_config, hidden_size=self.hidden_size_per_attention_head) self.q_layernorm_xattn = FusedLayerNorm( model_config=self.model_config, hidden_size=self.hidden_size_per_attention_head ) self.k_layernorm = FusedLayerNorm(model_config=self.model_config, hidden_size=self.hidden_size_per_attention_head) self.k_layernorm_xattn = FusedLayerNorm( model_config=self.model_config, hidden_size=self.hidden_size_per_attention_head ) def _full_adjust_key_and_value( self, inference_params: InferenceParams, key_and_value: torch.Tensor, meta_args: ModelMetaArgs ): """ Saves the generated key and value tensors to the end of the buffers in inference_params. Returns the full size keys and values from the provided inference_params Returns a tuple: (key, value) """ # ================================================= # Pre-allocate memory for key-values for inference. # ================================================= inf_max_seq_length = inference_params.max_sequence_length inf_max_batch_size = inference_params.max_batch_size if self.layer_number not in inference_params.key_value_memory_dict: inference_key_and_value_memory = self._allocate_key_and_value_memory( inf_max_seq_length, inf_max_batch_size, key_and_value.dtype ) inference_params.key_value_memory_dict[self.layer_number] = inference_key_and_value_memory else: # Get the pre-allocated buffers for this layer inference_key_and_value_memory = inference_params.key_value_memory_dict[self.layer_number] sequence_start = meta_args.slice_point * meta_args.clip_token_nums * inf_max_batch_size get_key_and_value = inference_key_and_value_memory[:sequence_start, ...].cuda() # Copy key and values. if inference_params.update_kv_cache: key_and_value_total = key_and_value clip_size = ( key_and_value_total.size(0) - meta_args.clip_token_nums * inf_max_batch_size if meta_args.distill_nearly_clean_chunk else key_and_value_total.size(0) ) sequence_end = sequence_start + clip_size assert sequence_end <= inference_key_and_value_memory.size(0) # update kv cache inference_key_and_value_memory[sequence_start:sequence_end, ...] = key_and_value_total[:clip_size] return torch.cat([get_key_and_value, key_and_value], dim=0) def adjust_key_and_value_for_inference( self, key_and_value: torch.Tensor, inference_params: InferenceParams, meta_args: ModelMetaArgs ): if inference_params is None: return torch.chunk(key_and_value, 2, dim=-1) # Only update kvcache when necessary, include 3 conditions: # 1. extract prefix video clean feature # 2. the first chunk of current kv is clean, we need to save their feature # 3. 
previous chunk is clean and we need to save/load their feature if meta_args.extract_prefix_video_feature or meta_args.fwd_extra_1st_chunk or meta_args.slice_point > 0: key_and_value = self._full_adjust_key_and_value(inference_params, key_and_value, meta_args) key, value = torch.chunk(key_and_value, 2, dim=-1) return key.contiguous(), value.contiguous() # ===================== # Get Query for core attn # [sq, b, (hn hd)] -> [(sq b), hn, hd] # ===================== def get_q(self, mixed_qqkv: torch.Tensor, cos_emb: torch.Tensor, sin_emb: torch.Tensor): query = self.linear_qkv.forward_q(mixed_qqkv) query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) assert self.q_layernorm is not None original_dtype = query.dtype query = query.float() query = self.q_layernorm(query) query = query.transpose(0, 1).contiguous() query = flash_apply_rotary_emb(query, cos_emb, sin_emb) query = query.to(original_dtype) return rearrange(query, "b sq hn hd -> (sq b) hn hd").contiguous() # ===================== # Get Key for core attn # [sq, b, (hn hd)] -> [(sq b), hn, hd] # ===================== def get_k(self, mixed_qqkv: torch.Tensor, cos_emb: torch.Tensor, sin_emb: torch.Tensor): key = self.linear_qkv.forward_k(mixed_qqkv) key = key.reshape(key.size(0), key.size(1), -1, self.hidden_size_per_attention_head) assert self.k_layernorm is not None original_dtype = key.dtype key = key.float() key = self.k_layernorm(key) key = key.transpose(0, 1).contiguous() key = flash_apply_rotary_emb(key, cos_emb, sin_emb) key = key.to(original_dtype) return rearrange(key, "b sq hn hd -> (sq b) hn hd").contiguous() # ===================== # Get Value for core attn # [sq, b, (hn hd)] -> [(sq b), hn, hd] # ===================== def get_v(self, mixed_qqkv: torch.Tensor): value = self.linear_qkv.forward_v(mixed_qqkv) return rearrange(value, "sq b (hn hd) -> (sq b) hn hd", hd=self.hidden_size_per_attention_head).contiguous() def get_kv(self, mixed_qqkv: torch.Tensor, cos_emb: torch.Tensor, sin_emb: torch.Tensor): # Get KV together for better performance when encoutering cpu-bound, mainly used by cuda graph key = self.get_k(mixed_qqkv, cos_emb, sin_emb) value = self.get_v(mixed_qqkv) # [(sq b), hn, hd] -> [(sq b), hn, 2 * hd] return torch.cat([key, value], dim=-1) def get_qkv(self, mixed_qqkv: torch.Tensor, cos_emb: torch.Tensor, sin_emb: torch.Tensor): # Get QKV together for better performance when encoutering cpu-bound, mainly used by cuda graph q = self.get_q(mixed_qqkv, cos_emb, sin_emb) k = self.get_k(mixed_qqkv, cos_emb, sin_emb) v = self.get_v(mixed_qqkv) return q, k, v def get_xqkv(self, mixed_qqkv: torch.Tensor, key_value_states: torch.Tensor): query_xattn = self.linear_qkv.forward_qx(mixed_qqkv) query_xattn = rearrange(query_xattn, "sq b (hn hd) -> (b sq) hn hd", hd=self.hidden_size_per_attention_head) query_xattn = self.q_layernorm_xattn(query_xattn) # [y_total_token, h] --> [y_total_token, 2*hp] mixed_kv_xattn = torch.concat( [torch.matmul(key_value_states, w.t()) for w in torch.chunk(self.linear_kv_xattn.weight, 8, axis=0)], axis=1 ) # [y_total_token, 2*hn*hd] --> [y_total_token, hn, 2*hd] mixed_kv_xattn = mixed_kv_xattn.view(key_value_states.shape[0], -1, 2 * self.hidden_size_per_attention_head) # [y_total_token, hn, 2*hd] --> 2 [y_total_token, hn, hd] (key_xattn, value_xattn) = split_tensor_along_last_dim(mixed_kv_xattn, 2) key_xattn = self.k_layernorm_xattn(key_xattn) return query_xattn, key_xattn, value_xattn def core_attention(self, query: torch.Tensor, key: torch.Tensor, value: 
torch.Tensor, bs: int, meta_args: ModelMetaArgs): # (sq b) hn hd -> b sq hn hd query = query.reshape(-1, bs, query.shape[1], query.shape[2]).transpose(0, 1).contiguous() # (sq b) hn hd -> b sq hn hd key = key.reshape(-1, bs, key.shape[1], key.shape[2]).transpose(0, 1).contiguous() # (sq b) hn hd -> b sq hn hd value = value.reshape(-1, bs, value.shape[1], value.shape[2]).transpose(0, 1).contiguous() if torch.cuda.get_device_capability()[0] >= 9: core_attn_out, _ = flex_attention( query.flatten(0, 1), key.flatten(0, 1), value.flatten(0, 1), meta_args.core_attn_params.q_range, meta_args.core_attn_params.k_range, max_seqlen_q=meta_args.core_attn_params.max_seqlen_q, max_seqlen_k=meta_args.core_attn_params.max_seqlen_k, softmax_scale=None, deterministic=torch.are_deterministic_algorithms_enabled(), disable_fwd_atomic_reduction=True, ) # (b sq) hn hd -> (sq b) hn hd core_attn_out = rearrange(core_attn_out, "(b sq) h d -> (sq b) h d", b=bs) else: # NOTE(lml): We convert multi denoising_range_num input into multi batch_size input at third time forward under 3_cfg mode, thus could not support normal multi batch_size input. We use an assert statement to ensure that it is still in this situation, thereby guaranteeing the correct use of q_range and k_range later on. assert not (bs > 1 and meta_args.denoising_range_num > 1) q_range = meta_args.core_attn_params.np_q_range k_range = meta_args.core_attn_params.np_k_range core_attn_outs = [] for i in range(meta_args.denoising_range_num): if bs == 1: q = query[:, q_range[i, 0] : q_range[i, 1]] k = key[:, k_range[i, 0] : k_range[i, 1]] v = value[:, k_range[i, 0] : k_range[i, 1]] else: assert i == 0 q = query[:, q_range[0, 0] : q_range[0, 1]] k = key[:, k_range[0, 0] : k_range[0, 1]] v = value[:, k_range[0, 0] : k_range[0, 1]] o = flash_attn_func(q=q, k=k, v=v, deterministic=torch.are_deterministic_algorithms_enabled()) o = rearrange(o, "b sq h d -> (sq b) h d", b=bs) core_attn_outs.append(o) core_attn_out = torch.cat(core_attn_outs, dim=0) return core_attn_out def full_attention(self, bs: int, meta_args: ModelMetaArgs, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, i: int): # NOTE(lml): full_attention is used under cp_shuffle_overlap strategy. We further limit it to the case of bs=1, so that we do not need to pay attention to the arrangement of sq and bs dimensions. 
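        # The dispatch below mirrors core_attention: on SM90+ (Hopper) the i-th range is handled by a
        # single flex_attention call with explicit q_ranges / k_ranges; on older GPUs k / v are sliced
        # to that range and handed to flash_attn_func.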
assert bs == 1 if torch.cuda.get_device_capability()[0] >= 9: q_range = meta_args.core_attn_params.q_range[i : i + 1] - meta_args.core_attn_params.q_range[i, 0] k_range = meta_args.core_attn_params.k_range[i : i + 1] o, _ = flex_attention( q, k, v, q_ranges=q_range, k_ranges=k_range, max_seqlen_q=meta_args.core_attn_params.max_seqlen_q, max_seqlen_k=meta_args.core_attn_params.max_seqlen_k, softmax_scale=None, deterministic=torch.are_deterministic_algorithms_enabled(), disable_fwd_atomic_reduction=True, ) else: k_range = meta_args.core_attn_params.np_k_range[i : i + 1] k = k[k_range[0, 0] : k_range[0, 1]] v = v[k_range[0, 0] : k_range[0, 1]] o = flash_attn_func( q=q.unsqueeze(0), k=k.unsqueeze(0), v=v.unsqueeze(0), deterministic=torch.are_deterministic_algorithms_enabled(), ).flatten(0, 1) return o def cross_attention( self, mixed_qqkv: torch.Tensor, key_value_states: torch.Tensor, cross_attn_params: PackedCrossAttnParams, get_xqkv_func: Callable, ): # ================= # cross-attn for aggragating caption / condition # ================= query_xattn, key_xattn, value_xattn = get_xqkv_func(mixed_qqkv, key_value_states) if torch.cuda.get_device_capability()[0] >= 9: xattn_out, _ = flex_attention( query_xattn, key_xattn, value_xattn, cross_attn_params.q_ranges, cross_attn_params.kv_ranges, max_seqlen_q=cross_attn_params.max_seqlen_q, max_seqlen_k=cross_attn_params.max_seqlen_kv, softmax_scale=None, deterministic=False, disable_fwd_atomic_reduction=True, ) else: xattn_out = flash_attn_varlen_func( query_xattn, # [b*sq, hn, hd] key_xattn, # [y_total_token, hn, hd] value_xattn, # [y_total_token, hn, hd] cu_seqlens_q=cross_attn_params.cu_seqlens_q, cu_seqlens_k=cross_attn_params.cu_seqlens_kv, max_seqlen_q=cross_attn_params.max_seqlen_q, max_seqlen_k=cross_attn_params.max_seqlen_kv, deterministic=torch.are_deterministic_algorithms_enabled(), ) batch_size = mixed_qqkv.shape[1] xattn_out = rearrange(xattn_out, "(b sq) hn hd -> sq b (hn hd)", b=batch_size).contiguous() return xattn_out def forward( self, hidden_states: torch.Tensor, key_value_states: torch.Tensor, inference_params: InferenceParams, rotary_pos_emb: torch.Tensor, meta_args: ModelMetaArgs, ): assert rotary_pos_emb is not None, "FullyParallelAttention needs rotary_pos_emb" sin_emb, cos_emb = rotary_pos_emb.tensor_split(2, -1) batch_size = hidden_states.shape[1] # All comminications operate on dimensions shaped as (cp * sq * b) batch_cp_split_sizes = None if meta_args.cp_split_sizes is None else [x * batch_size for x in meta_args.cp_split_sizes] # Attention heads [sq, b, h] --> [sq, b, q + qx + k + v] mixed_qqkv = self.linear_qkv.forward_ln(hidden_states) # ===================== # Function wrapper # ===================== get_kv_func = self.get_kv get_q_func = self.get_q get_qkv_func = self.get_qkv get_xqkv_func = self.get_xqkv # ===================== # Parallel Strategy # ===================== if self.engine_config.cp_strategy == "none": assert self.engine_config.cp_size == 1 key_and_value = get_kv_func(mixed_qqkv, cos_emb, sin_emb) query = get_q_func(mixed_qqkv, cos_emb, sin_emb) key, value = self.adjust_key_and_value_for_inference(key_and_value, inference_params, meta_args) core_attn_out = self.core_attention(query, key, value, batch_size, meta_args) core_attn_out = rearrange(core_attn_out, "(sq b) hn hd -> sq b (hn hd)", b=batch_size) xattn_out = self.cross_attention(mixed_qqkv, key_value_states, meta_args.cross_attn_params, get_xqkv_func) elif self.engine_config.cp_strategy == "cp_ulysses": get_kv_func = partial(get_kv_func, 
mixed_qqkv, cos_emb, sin_emb) get_q_func = partial(get_q_func, mixed_qqkv, cos_emb, sin_emb) get_qkv_func = partial(get_qkv_func, mixed_qqkv, cos_emb, sin_emb) kv_cache_func = partial( self.adjust_key_and_value_for_inference, inference_params=inference_params, meta_args=meta_args ) if meta_args.enable_cuda_graph and meta_args.denoising_range_num <= 3: # Temporal solution for first chunk opt core_attn_out, xattn_out = UlyssesScheduler.get_attn_and_xattn_with_fused_qkv_comm( get_qkv_func, kv_cache_func, partial(self.core_attention, bs=batch_size, meta_args=meta_args), partial(self.cross_attention, mixed_qqkv, key_value_states, meta_args.cross_attn_params, get_xqkv_func), self.engine_config.ulysses_overlap_degree, batch_size, self.engine_config.cp_size, batch_cp_split_sizes, ) else: core_attn_out, xattn_out = UlyssesScheduler.get_attn_and_xattn_with_fused_kv_comm( get_q_func, get_kv_func, kv_cache_func, partial(self.core_attention, bs=batch_size, meta_args=meta_args), partial(self.cross_attention, mixed_qqkv, key_value_states, meta_args.cross_attn_params, get_xqkv_func), self.engine_config.ulysses_overlap_degree, batch_size, self.engine_config.cp_size, batch_cp_split_sizes, ) elif self.engine_config.cp_strategy == "cp_shuffle_overlap": key_and_value = self.get_kv(mixed_qqkv, cos_emb, sin_emb) key_and_value, handle_kv = cso_communication(key_and_value, self.engine_config.cp_size, batch_cp_split_sizes, "kv") query = get_q_func(mixed_qqkv, cos_emb, sin_emb) cso_helper = CSOHelper(meta_args.denoising_range_num, self.engine_config.cp_size, batch_cp_split_sizes) query, handle_q = cso_helper.split_query_for_overlap(query) handle_kv.wait() # NOTE(lml): rearrange and unpad key_and_value for later attention compute under cp_shuffle_overlap strategy, and we should split sqb into sq and b when support multi batch_size input. key_and_value = ( rearrange( key_and_value, "(cp dn sqb) hn nhd -> dn (cp sqb) hn nhd", dn=meta_args.denoising_range_num, cp=self.engine_config.cp_size, )[:, : meta_args.clip_token_nums] .flatten(0, 1) .contiguous() ) key, value = self.adjust_key_and_value_for_inference(key_and_value, inference_params, meta_args) handle_q.wait() core_attn_out, handle_attn = cso_helper.overlap( partial(self.full_attention, hidden_states.shape[1], meta_args), query, key, value ) xattn_out = self.cross_attention(mixed_qqkv, key_value_states, meta_args.cross_attn_params, get_xqkv_func) handle_attn.wait() core_attn_out = rearrange( torch.concat(core_attn_out, dim=0), "(dn cp sq b) hn hd -> (dn sq) b (cp hn hd)", cp=self.engine_config.cp_size, b=hidden_states.shape[1], dn=meta_args.denoising_range_num, ) else: raise ValueError(f"Unsupported cp_strategy: {self.engine_config.cp_strategy}") return core_attn_out, xattn_out ########################################################## # TransformerLayer ########################################################## class TransformerLayer(torch.nn.Module): """A single transformer layer. Transformer layer takes input with size [s, b, h] and returns an output of the same size. 
""" def __init__(self, model_config: ModelConfig, engine_config: EngineConfig, layer_number: int = 1): super().__init__() self.model_config = model_config self.engine_config = engine_config self.layer_number = layer_number + self._get_layer_offset() ## [Module 1: ada_modulate_layer self.ada_modulate_layer = AdaModulateLayer(model_config=self.model_config) ## [Module 2: SelfAttention] self.self_attention = FullyParallelAttention( model_config=self.model_config, engine_config=self.engine_config, layer_number=self.layer_number ) ## [Module 3: SelfAttention PostNorm] self.self_attn_post_norm = FusedLayerNorm(model_config=self.model_config, hidden_size=self.model_config.hidden_size) ## [Module 4: MLP block] self.mlp = CustomMLP(model_config=self.model_config, engine_config=self.engine_config, layer_number=self.layer_number) ## [Module 5: MLP PostNorm] self.mlp_post_norm = FusedLayerNorm(model_config=self.model_config, hidden_size=self.model_config.hidden_size) def _get_layer_offset(self): pipeline_rank = parallel_state.get_pp_rank() num_layers_per_pipeline_rank = self.model_config.num_layers // parallel_state.get_pp_world_size() # Each stage gets a contiguous set of layers. if parallel_state.get_pp_world_size() > 1: offset = pipeline_rank * num_layers_per_pipeline_rank else: offset = 0 return offset def forward( self, hidden_states: torch.Tensor, condition: torch.Tensor, condition_map: torch.Tensor, y_xattn_flat: torch.Tensor, rotary_pos_emb: torch.Tensor, inference_params: InferenceParams, meta_args: ModelMetaArgs, ): # hidden_states: [s/cp/sp, b, h] residual = hidden_states # Self attention. core_attn_out, cross_attn_out = self.self_attention( hidden_states, key_value_states=y_xattn_flat, inference_params=inference_params, rotary_pos_emb=rotary_pos_emb, meta_args=meta_args, ) hidden_states = self.attn_post_process(core_attn_out, cross_attn_out, residual, condition, condition_map) return hidden_states def attn_post_process( self, core_attn_out: torch.Tensor, cross_attn_out: torch.Tensor, residual: torch.Tensor, condition: torch.Tensor, condition_map: torch.Tensor, ): hidden_states = self.attn_linear_proj(core_attn_out, cross_attn_out) hidden_states = self.gating_and_mlp(hidden_states, residual, condition, condition_map) return hidden_states def attn_linear_proj(self, core_attn_out: torch.Tensor, cross_attn_out: torch.Tensor): # ============================================ # attention post-process , output. [sq, b, h] # ============================================ attn_out = torch.concat([core_attn_out, cross_attn_out], dim=2) # NOTE: hn=8 is hardcoded to align with TP8 traning and TP1 inference attn_out = rearrange(attn_out, "sq b (n hn hd) -> sq b (hn n hd)", n=2, hn=8) if self.self_attention.adapt_linear_quant: attn_out = self.self_attention.linear_proj(attn_out) else: # Use high-precision for non-quantized linear projection with torch.autocast(device_type="cuda", dtype=torch.float32): attn_out = self.self_attention.linear_proj(attn_out) return attn_out def gating_and_mlp( self, hidden_states: torch.Tensor, residual: torch.Tensor, condition: torch.Tensor, condition_map: torch.Tensor ): gate_output = self.ada_modulate_layer(condition) softcap_gate_cap = 1.0 gate_output = softcap(gate_output, softcap_gate_cap) gate_msa, gate_mlp = gate_output.chunk(2, dim=-1) # Residual connection for self-attention. 
        hidden_states = bias_modulate_add(hidden_states, residual, condition_map, gate_msa, self.self_attn_post_norm).to(
            self.model_config.params_dtype
        )

        residual = hidden_states
        hidden_states = self.mlp(hidden_states)

        # Residual connection for MLP.
        hidden_states = bias_modulate_add(hidden_states, residual, condition_map, gate_mlp, self.mlp_post_norm).to(
            self.model_config.params_dtype
        )

        return hidden_states


##########################################################
# TransformerBlock
##########################################################


class TransformerBlock(torch.nn.Module):
    """Transformer class."""

    def __init__(
        self, model_config: ModelConfig, engine_config: EngineConfig, pre_process: bool = True, post_process: bool = True
    ):
        super().__init__()
        self.model_config = model_config
        self.engine_config = engine_config
        self.pre_process = pre_process
        self.post_process = post_process

        # required for pipeline parallel schedules
        self.input_tensor = None

        layer_number = self.model_config.num_layers // parallel_state.get_pp_world_size()
        # offset is implicit in TransformerLayer
        self.layers = torch.nn.ModuleList(
            [
                TransformerLayer(model_config=self.model_config, engine_config=self.engine_config, layer_number=i)
                for i in range(layer_number)
            ]
        )

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = FusedLayerNorm(model_config=self.model_config, hidden_size=self.model_config.hidden_size)

    def set_input_tensor(self, input_tensor: Tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous stage
        comes from communication, not from the input, so the model's
        forward_step_func won't have it. This function is thus used by
        internal code to bypass the input provided by the forward_step_func"""
        self.input_tensor = input_tensor

    @torch.no_grad()
    def forward(
        self,
        hidden_states: Tensor,
        condition: Tensor,
        condition_map: Tensor,
        y_xattn_flat: Tensor,
        rotary_pos_emb: Tensor,
        inference_params: InferenceParams,
        meta_args: ModelMetaArgs,
    ) -> torch.Tensor:
        if not self.pre_process:
            assert self.input_tensor is not None, "please call set_input_tensor for pp"
            hidden_states = self.input_tensor

        for layer in self.layers:
            hidden_states = layer(
                hidden_states=hidden_states,
                condition=condition,
                condition_map=condition_map,
                y_xattn_flat=y_xattn_flat,
                rotary_pos_emb=rotary_pos_emb,
                inference_params=inference_params,
                meta_args=meta_args,
            )

        # Final layer norm.
        if self.post_process:
            hidden_states = self.final_layernorm(hidden_states.float())

        return hidden_states
```
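A few stand-alone sketches follow; they are written against stock PyTorch for illustration and are not files from this repository. First, `FusedLayerNorm` above supports `apply_layernorm_1p` (zero-centered gamma): the stored weight holds `gamma - 1` and `weight + 1` scales the normalized activations, so a zero-initialized weight is an identity scale:

```python
import torch
import torch.nn.functional as F

hidden = 32
x = torch.randn(4, hidden)
eps = 1e-6

# Zero-centered convention: the parameter stores gamma - 1, so a weight of all zeros
# means an effective scale of exactly 1.
stored_weight = torch.zeros(hidden)
bias = torch.zeros(hidden)
out_zero_centered = F.layer_norm(x, (hidden,), stored_weight + 1, bias, eps)

# Ordinary LayerNorm with weight = 1 produces the same output.
out_plain = F.layer_norm(x, (hidden,), torch.ones(hidden), bias, eps)
print(torch.allclose(out_zero_centered, out_plain))  # True
```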
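`softcap` and `div_clamp_to` implement the numeric core of the FP8 path: a smooth tanh cap and a divide-then-clamp into the `float8_e4m3fn` range. A minimal equivalent (without the chunked memory handling), with a made-up calibration scale:

```python
import torch

def softcap(x: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly squash x into (-cap, cap), matching softcap() above.
    return (cap * torch.tanh(x.float() / cap)).to(x.dtype)

def quantize_per_tensor_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Divide by a per-tensor scale and clamp into the float8_e4m3fn range,
    # the same idea as div_clamp_to() minus its chunked processing.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.clamp(x.float() / scale.float(), finfo.min, finfo.max).to(torch.float8_e4m3fn)

x = torch.randn(4, 8, dtype=torch.bfloat16) * 10
scale = torch.tensor(0.1)                       # made-up calibration scale
x_fp8 = quantize_per_tensor_fp8(x, scale)
x_back = x_fp8.to(torch.float32) * scale        # dequantize to inspect the error
print(softcap(torch.tensor([3.0, -5.0]), 1.0))  # values squashed into (-1, 1)
print((x.float() - x_back).abs().max())         # worst-case quantization error
```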
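The rotary-embedding helpers build per-axis frequency bands over a (T, H, W) latent grid, center the two spatial axes, and flatten the result into per-token sin/cos tables. A miniature of that recipe, assuming language-mode (inverse-frequency) bands and arbitrary sizes:

```python
import torch

def mini_rope_tables(feat_shape, dim_per_axis=16, temperature=10000.0):
    # One band set shared by all axes, like freq_bands(step=1) above.
    exp = torch.arange(dim_per_axis, dtype=torch.float32) / dim_per_axis
    bands = 1.0 / temperature ** exp                                # [dim_per_axis]
    t = [torch.arange(s, dtype=torch.float32) for s in feat_shape]
    # Center the spatial axes (H, W) around zero, as build_fourier_pos_embed does.
    t[1] = t[1] - (feat_shape[1] - 1) / 2
    t[2] = t[2] - (feat_shape[2] - 1) / 2
    grid = torch.stack(torch.meshgrid(*t, indexing="ij"), dim=-1)   # [T, H, W, 3]
    pos = grid.unsqueeze(-1) * bands                                # [T, H, W, 3, dim_per_axis]
    n_tokens = feat_shape[0] * feat_shape[1] * feat_shape[2]
    return pos.sin().reshape(n_tokens, -1), pos.cos().reshape(n_tokens, -1)

sin, cos = mini_rope_tables((4, 6, 6))  # e.g. 4 latent frames of 6x6 tokens
print(sin.shape, cos.shape)             # torch.Size([144, 48]) twice
```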
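`_full_adjust_key_and_value` addresses its pre-allocated KV buffer in units of `clip_token_nums * batch_size` tokens per chunk: `slice_point` chunks are already cached, and the current chunk's keys/values are written right after them when `update_kv_cache` is set. A toy walk-through of the index arithmetic with made-up sizes:

```python
# Hypothetical sizes, chosen only to make the arithmetic concrete.
clip_token_nums = 1560      # tokens per video chunk (made up)
batch_size = 1
slice_point = 2             # two chunks are already clean / cached

sequence_start = slice_point * clip_token_nums * batch_size  # 3120: end of the cached prefix
new_tokens = clip_token_nums * batch_size                    # KV produced for the current chunk
sequence_end = sequence_start + new_tokens                   # 4680: slots the new KV occupies

print(sequence_start, sequence_end)
# Attention for the current chunk then runs against
# cache[:sequence_start] concatenated with the freshly computed KV.
```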