```
├── .codestyle/
│   └── copyright.hook
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── example/
│   ├── 24B/
│   │   ├── 24B_config.json
│   │   └── run.sh
│   ├── 4.5B/
│   │   ├── 4.5B_config.json
│   │   └── run.sh
│   └── assets/
│       ├── image.jpeg
│       ├── prefix_video.mp4
│       └── special_tokens.npz
├── figures/
│   ├── algorithm.png
│   ├── dit_architecture.png
│   ├── inhouse_human_evaluation.png
│   └── logo_black.png
├── inference/
│   ├── common/
│   │   ├── __init__.py
│   │   ├── common_utils.py
│   │   ├── config.py
│   │   ├── dataclass.py
│   │   ├── logger.py
│   │   └── timer.py
│   ├── infra/
│   │   ├── checkpoint/
│   │   │   ├── __init__.py
│   │   │   └── checkpointing.py
│   │   ├── distributed/
│   │   │   ├── __init__.py
│   │   │   ├── dist_utils.py
│   │   │   └── parallel_state.py
│   │   └── parallelism/
│   │       ├── __init__.py
│   │       ├── context_parallel.py
│   │       ├── pipeline_parallel.py
│   │       └── tile_parallel.py
│   └── model/
│       ├── dit/
│       │   ├── __init__.py
│       │   ├── dit_model.py
│       │   └── dit_module.py
│       └── t5/
│           └── __init__.py
```
## /.codestyle/copyright.hook
```hook path="/.codestyle/copyright.hook"
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import io
import re
import sys
import os
import datetime
COPYRIGHT = '''Copyright (c) 2025 SandAI. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.'''
def _generate_copyright(comment_mark):
copyright=COPYRIGHT.split(os.linesep)
header = copyright[0].rstrip()
    p = re.search(r'(\d{4})', header).group(0)
now = datetime.datetime.now()
header = header.replace(p,str(now.year))
ans=[comment_mark + " " + header + os.linesep]
for idx, line in enumerate(copyright[1:]):
ans.append(comment_mark + " " + line.rstrip() + os.linesep)
return ans
def _get_comment_mark(path):
lang_type=re.compile(r"\.(py|sh)$")
if lang_type.search(path) is not None:
return "#"
lang_type=re.compile(r"\.(h|c|hpp|cc|cpp|cu|go|cuh|proto)$")
if lang_type.search(path) is not None:
return "//"
return None
RE_ENCODE = re.compile(r"^[ \t\v]*#.*?coding[:=]", re.IGNORECASE)
RE_COPYRIGHT = re.compile(r".*Copyright \(c\) \d{4}", re.IGNORECASE)
RE_SHEBANG = re.compile(r"^[ \t\v]*#[ \t]?\!")
def _check_copyright(path):
head=[]
try:
with open(path) as f:
head = [next(f) for x in range(4)]
except StopIteration:
pass
for idx, line in enumerate(head):
if RE_COPYRIGHT.search(line) is not None:
return True
return False
def generate_copyright(path, comment_mark):
original_contents = io.open(path, encoding="utf-8").readlines()
head = original_contents[0:4]
insert_line_no=0
for i, line in enumerate(head):
if RE_ENCODE.search(line) or RE_SHEBANG.search(line):
insert_line_no=i+1
copyright = _generate_copyright(comment_mark)
if insert_line_no == 0:
new_contents = copyright
if len(original_contents) > 0 and len(original_contents[0].strip()) != 0:
new_contents.append(os.linesep)
new_contents.extend(original_contents)
else:
new_contents=original_contents[0:insert_line_no]
new_contents.append(os.linesep)
new_contents.extend(copyright)
if len(original_contents) > insert_line_no and len(original_contents[insert_line_no].strip()) != 0:
new_contents.append(os.linesep)
new_contents.extend(original_contents[insert_line_no:])
new_contents="".join(new_contents)
with io.open(path, 'w') as output_file:
output_file.write(new_contents)
def main(argv=None):
parser = argparse.ArgumentParser(
description='Checker for copyright declaration.')
parser.add_argument('filenames', nargs='*', help='Filenames to check')
args = parser.parse_args(argv)
retv = 0
for path in args.filenames:
comment_mark = _get_comment_mark(path)
if comment_mark is None:
print("warning:Unsupported file", path, file=sys.stderr)
continue
if _check_copyright(path):
continue
generate_copyright(path, comment_mark)
if __name__ == '__main__':
exit(main())
```
## /.gitignore
```gitignore path="/.gitignore"
__pycache__
*.pyc
*.log
*.pt
*.mp4
ckpt
downloads
```
## /.pre-commit-config.yaml
```yaml path="/.pre-commit-config.yaml"
exclude: \.patch$
repos:
- repo: local
hooks:
- id: copyright_checker
name: copyright_checker
entry: python3 ./.codestyle/copyright.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py|sh)$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-added-large-files
args:
- --maxkb=30720
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
files: (?!.*third_party)^.*$ | (?!.*book)^.*$
- id: end-of-file-fixer
- id: trailing-whitespace
- id: requirements-txt-fixer
- id: sort-simple-yaml
- repo: https://github.com/Lucas-C/pre-commit-hooks.git
rev: v1.5.1
hooks:
- id: remove-crlf
files: (?!.*third_party)^.*$ | (?!.*book)^.*$
- id: remove-tabs
name: Tabs remover (C++)
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps)$
args: [--whitespaces-count, '2']
- id: remove-tabs
name: Tabs remover (Python)
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
args: [--whitespaces-count, '4']
- repo: https://github.com/psf/black.git
rev: 23.3.0
hooks:
- id: black
args: [--line-length=127, --skip-string-normalization, --skip-magic-trailing-comma]
files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
- repo: https://github.com/pre-commit/mirrors-isort
rev: v5.10.1
hooks:
- id: isort
args: [--profile=black, --line-length=127, --multi-line=3, --force-grid-wrap=0]
files: \.py$
- repo: https://github.com/PyCQA/autoflake
rev: v2.3.1
hooks:
- id: autoflake
args: [--remove-all-unused-imports, --remove-unused-variables, --in-place, --ignore-init-module-imports, --ignore-pass-after-docstring]
files: \.py$
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks.git
rev: v2.9.0
hooks:
- id: pretty-format-yaml
args: [--autofix, --indent, '4']
```
## /LICENSE
``` path="/LICENSE"
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
```
## /README.md

-----
# MAGI-1: Autoregressive Video Generation at Scale
This repository contains the pre-trained weights and inference code for the MAGI-1 model. You can find more information in our [technical report](https://static.magi.world/static/files/MAGI_1.pdf) or directly create magic with MAGI-1 [here](http://sand.ai). 🚀✨
## 🔥🔥🔥 Latest News
- Apr 22, 2025: We’re planning to release our 4.5B model by the end of April. Final touches are still underway — stay tuned for the official drop.
- Apr 21, 2025: MAGI-1 is here 🎉. We've released the model weights and inference code — check it out!
## 1. About
We present MAGI-1, a world model that generates videos by ***autoregressively*** predicting a sequence of video chunks, defined as fixed-length segments of consecutive frames. Trained to denoise per-chunk noise that increases monotonically over time, MAGI-1 enables causal temporal modeling and naturally supports streaming generation. It achieves strong performance on image-to-video (I2V) tasks conditioned on text instructions, providing high temporal consistency and scalability, which are made possible by several algorithmic innovations and a dedicated infrastructure stack. MAGI-1 further supports controllable generation via chunk-wise prompting, enabling smooth scene transitions, long-horizon synthesis, and fine-grained text-driven control. We believe MAGI-1 offers a promising direction for unifying high-fidelity video generation with flexible instruction control and real-time deployment.
## 2. Model Summary
### Transformer-based VAE
- Variational autoencoder (VAE) with a transformer-based architecture, 8x spatial and 4x temporal compression.
- Fastest average decoding time and highly competitive reconstruction quality (see the shape sketch below for what these compression ratios mean in practice).
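To make the compression ratios concrete, here is a small shape sketch. The helper below is purely illustrative arithmetic, not repository code; the numbers are taken from `example/24B/24B_config.json`.
```python
# Illustrative only: what 8x spatial / 4x temporal compression means for the
# default 24B example config (96 frames at 720x1280, temporal_downsample_factor=4).
def latent_shape(num_frames, height, width, temporal_factor=4, spatial_factor=8):
    return (num_frames // temporal_factor, height // spatial_factor, width // spatial_factor)

# 96 frames at 720x1280 -> 24 latent frames of 90x160 latent pixels.
# With chunk_width=6 latent frames per chunk, one chunk covers 6 * 4 = 24 video frames,
# matching the 24-frame chunks described in the next section.
print(latent_shape(num_frames=96, height=720, width=1280))  # (24, 90, 160)
```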
### Auto-Regressive Denoising Algorithm
MAGI-1 is an autoregressive denoising video generation model that generates videos chunk by chunk instead of as a whole. Each chunk (24 frames) is denoised holistically, and generation of the next chunk begins as soon as the current one reaches a certain level of denoising. This pipeline design enables concurrent processing of up to four chunks for efficient video generation.

![MAGI-1 auto-regressive denoising algorithm](figures/algorithm.png)
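The pipelined schedule can be pictured with a small sketch. This is a simplification, not the repository's scheduler: it assumes a fixed per-chunk step budget (`num_steps`) and lets a chunk start once its predecessor is `num_steps // window_size` steps ahead, which is enough to show how later chunks stay noisier than earlier ones while at most `window_size` chunks are denoised concurrently.
```python
# Illustrative pipeline schedule (not the actual implementation): chunks are
# denoised left to right, each strictly noisier than its predecessor, with at
# most `window_size` chunks in flight at any time.
def pipeline_schedule(num_chunks, num_steps=8, window_size=4):
    lag = num_steps // window_size            # head start required before the next chunk may begin
    progress = [0] * num_chunks               # denoising steps completed per chunk
    while progress[-1] < num_steps:
        active = [i for i in range(num_chunks)
                  if progress[i] < num_steps
                  and (i == 0 or progress[i - 1] - progress[i] >= lag
                       or progress[i - 1] == num_steps)]
        active = active[:window_size]         # concurrent processing of up to window_size chunks
        for i in active:
            progress[i] += 1                  # one denoising step per active chunk
        yield active

for it, active in enumerate(pipeline_schedule(num_chunks=6)):
    print(f"iteration {it:2d}: denoising chunks {active}")
```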
### Diffusion Model Architecture
MAGI-1 is built upon the Diffusion Transformer, incorporating several key innovations to enhance training efficiency and stability at scale. These advancements include Block-Causal Attention, Parallel Attention Block, QK-Norm and GQA, Sandwich Normalization in FFN, SwiGLU, and Softcap Modulation; a toy sketch of the block-causal pattern follows below. For more details, please refer to the [technical report](https://static.magi.world/static/files/MAGI_1.pdf).
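As a rough illustration of block-causal attention in its usual sense (tokens attend bidirectionally within their own chunk and causally to all earlier chunks), here is a toy mask construction. The real attention kernels in this repository come from MagiAttention and the parallelism modules; this is only a sketch.
```python
import torch

def block_causal_mask(num_chunks: int, tokens_per_chunk: int) -> torch.Tensor:
    """Boolean [T, T] mask; True means the query token may attend to the key token."""
    chunk_id = torch.arange(num_chunks * tokens_per_chunk) // tokens_per_chunk
    # full attention inside a chunk, causal attention across chunks
    return chunk_id[None, :] <= chunk_id[:, None]

print(block_causal_mask(num_chunks=3, tokens_per_chunk=2).int())
```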
### Distillation Algorithm
We adopt a shortcut distillation approach that trains a single velocity-based model to support variable inference budgets. By enforcing a self-consistency constraint—equating one large step with two smaller steps—the model learns to approximate flow-matching trajectories across multiple step sizes. During training, step sizes are cyclically sampled from {64, 32, 16, 8}, and classifier-free guidance distillation is incorporated to preserve conditional alignment. This enables efficient inference with minimal loss in fidelity.
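A minimal sketch of the self-consistency constraint described above, assuming a velocity-based model `model(x, t, dt, cond)` conditioned on the step size (the name and signature are stand-ins, not this repository's API): one step of size `2*dt` should match the result of two consecutive steps of size `dt`.
```python
import torch

def shortcut_self_consistency_loss(model, x, t, dt, cond):
    # teacher target: two consecutive Euler steps of size dt (no gradient)
    with torch.no_grad():
        x_mid = x + dt * model(x, t, dt, cond)
        x_two = x_mid + dt * model(x_mid, t + dt, dt, cond)
        target_velocity = (x_two - x) / (2 * dt)
    # student prediction: a single step of size 2*dt
    pred_velocity = model(x, t, 2 * dt, cond)
    return torch.mean((pred_velocity - target_velocity) ** 2)

# toy check with a stand-in velocity model that ignores conditioning and step size
def dummy_model(x, t, dt, cond):
    return -x

print(shortcut_self_consistency_loss(dummy_model, torch.randn(1, 4), t=0.1, dt=0.125, cond=None))
```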
## 3. Model Zoo
We provide the pre-trained weights for MAGI-1, including the 24B and 4.5B models, as well as the corresponding distill and distill+quant models. The model weight links are shown in the table.
| Model | Link | Recommend Machine |
| ----------------------------- | ------------------------------------------------------------ | ------------------------------- |
| T5 | [T5](https://huggingface.co/sand-ai/MAGI-1/tree/main/ckpt/t5) | - |
| MAGI-1-VAE | [MAGI-1-VAE](https://huggingface.co/sand-ai/MAGI-1/tree/main/ckpt/vae) | - |
| MAGI-1-24B | [MAGI-1-24B](https://huggingface.co/sand-ai/MAGI-1/tree/main/ckpt/magi/24B_base) | H100/H800 \* 8 |
| MAGI-1-24B-distill | [MAGI-1-24B-distill](https://huggingface.co/sand-ai/MAGI-1/tree/main/ckpt/magi/24B_distill) | H100/H800 \* 8 |
| MAGI-1-24B-distill+fp8_quant | [MAGI-1-24B-distill+quant](https://huggingface.co/sand-ai/MAGI-1/tree/main/ckpt/magi/24B_distill_quant) | H100/H800 \* 4 or RTX 4090 \* 8 |
| MAGI-1-4.5B | MAGI-1-4.5B | RTX 4090 \* 1 |
## 4. Evaluation
### In-house Human Evaluation
MAGI-1 achieves state-of-the-art performance among open-source models such as Wan-2.1 and HunyuanVideo, and against the closed-source model Hailuo (i2v-01), particularly excelling in instruction following and motion quality. This positions it as a strong potential competitor to closed-source commercial models such as Kling.

![In-house human evaluation results](figures/inhouse_human_evaluation.png)
### Physical Evaluation
Thanks to the natural advantages of its autoregressive architecture, MAGI-1 achieves far superior precision in predicting physical behavior via video continuation on the [Physics-IQ benchmark](https://github.com/google-deepmind/physics-IQ-benchmark), significantly outperforming all existing models.
| Model | Phys. IQ Score ↑ | Spatial IoU ↑ | Spatiotemporal ↑ | Weighted Spatial IoU ↑ | MSE ↓ |
|----------------|------------------|---------------|-------------------|-------------------------|--------|
| **V2V Models** | | | | | |
| **Magi (V2V)** | **56.02** | **0.367** | **0.270** | **0.304** | **0.005** |
| VideoPoet (V2V)| 29.50 | 0.204 | 0.164 | 0.137 | 0.010 |
| **I2V Models** | | | | | |
| **Magi (I2V)** | **30.23** | **0.203** | **0.151** | **0.154** | **0.012** |
| Kling1.6 (I2V) | 23.64 | 0.197 | 0.086 | 0.144 | 0.025 |
| VideoPoet (I2V)| 20.30 | 0.141 | 0.126 | 0.087 | 0.012 |
| Gen 3 (I2V) | 22.80 | 0.201 | 0.115 | 0.116 | 0.015 |
| Wan2.1 (I2V) | 20.89 | 0.153 | 0.100 | 0.112 | 0.023 |
| Sora (I2V) | 10.00 | 0.138 | 0.047 | 0.063 | 0.030 |
| **GroundTruth**| **100.0** | **0.678** | **0.535** | **0.577** | **0.002** |
## 5. How to Run
### Environment Preparation
We provide two ways to run MAGI-1, with the Docker environment being the recommended option.
**Run with Docker Environment (Recommended)**
```bash
docker pull sandai/magi:latest
docker run -it --gpus all --privileged --shm-size=32g --name magi --net=host --ipc=host --ulimit memlock=-1 --ulimit stack=6710886 sandai/magi:latest /bin/bash
```
**Run with Source Code**
```bash
# Create a new environment
conda create -n magi python==3.10.12
# Install pytorch
conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.4 -c pytorch -c nvidia
# Install other dependencies
pip install -r requirements.txt
# Install ffmpeg
conda install -c conda-forge ffmpeg=4.4
# Install MagiAttention. For more information, please refer to https://github.com/SandAI-org/MagiAttention
git clone git@github.com:SandAI-org/MagiAttention.git
cd MagiAttention
git submodule update --init --recursive
pip install --no-build-isolation .
```
### Inference Command
To run the `MagiPipeline`, you can control the input and output by modifying the parameters in the `example/24B/run.sh` or `example/4.5B/run.sh` script. Below is an explanation of the key parameters:
#### Parameter Descriptions
- `--config_file`: Specifies the path to the configuration file, which contains model configuration parameters, e.g., `example/24B/24B_config.json`.
- `--mode`: Specifies the mode of operation. Available options are:
- `t2v`: Text to Video
- `i2v`: Image to Video
- `v2v`: Video to Video
- `--prompt`: The text prompt used for video generation, e.g., `"Good Boy"`.
- `--image_path`: Path to the image file, used only in `i2v` mode.
- `--prefix_video_path`: Path to the prefix video file, used only in `v2v` mode.
- `--output_path`: Path where the generated video file will be saved.
#### Bash Script
```bash
#!/bin/bash
# Run 24B MAGI-1 model
bash example/24B/run.sh
# Run 4.5B MAGI-1 model
bash example/4.5B/run.sh
```
#### Customizing Parameters
You can modify the parameters in `run.sh` as needed. For example:
- To use the Image to Video mode (`i2v`), set `--mode` to `i2v` and provide `--image_path`:
```bash
--mode i2v \
--image_path example/assets/image.jpeg \
```
- To use the Video to Video mode (`v2v`), set `--mode` to `v2v` and provide `--prefix_video_path`:
```bash
--mode v2v \
--prefix_video_path example/assets/prefix_video.mp4 \
```
By adjusting these parameters, you can flexibly control the input and output to meet different requirements.
### Some Useful Configs (for config.json)
> NOTE: If you are running the 24B model with RTX 4090 \* 8, please set `pp_size: 2` and `cp_size: 4` (a snippet after the table below shows how to apply such overrides programmatically).
| Config | Help |
| -------------- | ------------------------------------------------------------ |
| seed | Random seed used for video generation |
| video_size_h | Height of the video |
| video_size_w | Width of the video |
| num_frames | Controls the duration of the generated video |
| fps | Frames per second; 4 video frames correspond to 1 latent frame |
| cfg_number | The base model uses `cfg_number=3`; distill and quant models use `cfg_number=1` |
| load | Directory containing the model checkpoint |
| t5_pretrained | Path to load pretrained T5 model |
| vae_pretrained | Path to load pretrained VAE model |
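These values can also be inspected or overridden programmatically through `MagiConfig` (defined in `inference/common/config.py`, reproduced later in this document). A minimal sketch, assuming you run it from the repository root with the project on `PYTHONPATH`; the overridden values are only examples:
```python
from inference.common import MagiConfig

config = MagiConfig.from_json("example/24B/24B_config.json")
config.runtime_config.num_frames = 192  # longer video
config.runtime_config.seed = 42         # different random seed
config.engine_config.pp_size = 2        # e.g. for RTX 4090 * 8, per the note above
config.engine_config.cp_size = 4
config.to_json("example/24B/24B_config_custom.json")
```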
## 6. License
This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
## 7. Citation
If you find our code or model useful in your research, please cite:
```bibtex
@misc{magi1,
title={MAGI-1: Autoregressive Video Generation at Scale},
author={Sand-AI},
year={2025},
url={https://static.magi.world/static/files/MAGI_1.pdf},
}
```
## 8. Contact
If you have any questions, please feel free to raise an issue or contact us at [research@sand.ai](mailto:research@sand.ai).
## /example/24B/24B_config.json
```json path="/example/24B/24B_config.json"
{
"model_config": {
"model_name": "videodit_ardf",
"num_layers": 48,
"hidden_size": 6144,
"ffn_hidden_size": 16384,
"num_attention_heads": 48,
"num_query_groups": 8,
"kv_channels": 128,
"layernorm_epsilon": 1e-06,
"apply_layernorm_1p": true,
"x_rescale_factor": 0.1,
"half_channel_vae": true,
"params_dtype": "torch.bfloat16",
"patch_size": 2,
"t_patch_size": 1,
"in_channels": 32,
"out_channels": 32,
"cond_hidden_ratio": 0.25,
"caption_channels": 4096,
"caption_max_length": 800,
"xattn_cond_hidden_ratio": 1.0,
"cond_gating_ratio": 1.0,
"gated_linear_unit": true
},
"runtime_config": {
"cfg_number": 1,
"cfg_t_range": [
0.0,
0.0217,
0.1,
0.3,
0.999
],
"prev_chunk_scales": [
1.5,
1.5,
1.5,
1.0,
1.0
],
"text_scales": [
7.5,
7.5,
7.5,
0.0,
0.0
],
"noise2clean_kvrange": [
5,
4,
3,
2
],
"clean_chunk_kvrange": 1,
"clean_t": 0.9999,
"seed": 1234,
"num_frames": 96,
"video_size_h": 720,
"video_size_w": 1280,
"num_steps": 8,
"window_size": 4,
"fps": 24,
"chunk_width": 6,
"load": "./downloads/24B_base",
"t5_pretrained": "./downloads/t5_pretrained",
"t5_device": "cuda",
"vae_pretrained": "./downloads/vae",
"scale_factor": 0.18215,
"temporal_downsample_factor": 4
},
"engine_config": {
"distributed_backend": "nccl",
"distributed_timeout_minutes": 15,
"pp_size": 1,
"cp_size": 8,
"cp_strategy": "cp_ulysses",
"ulysses_overlap_degree": 1,
"fp8_quant": true,
"distill_nearly_clean_chunk_threshold": 0.3,
"shortcut_mode": "8,16,16",
"distill": true,
"kv_offload": true,
"enable_cuda_graph": false
}
}
```
## /example/24B/run.sh
```sh path="/example/24B/run.sh"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_ALGO=^NVLS
export PAD_HQ=1
export PAD_DURATION=1
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export OFFLOAD_T5_CACHE=true
export OFFLOAD_VAE_CACHE=true
export TORCH_CUDA_ARCH_LIST="8.9;9.0"
GPUS_PER_NODE=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
DISTRIBUTED_ARGS="
--rdzv-backend=c10d \
--rdzv-endpoint=localhost:6009 \
--nnodes=1 \
--nproc_per_node=$GPUS_PER_NODE
"
MAGI_ROOT=$(git rev-parse --show-toplevel)
LOG_DIR=log_$(date "+%Y-%m-%d_%H:%M:%S").log
export PYTHONPATH="$MAGI_ROOT:$PYTHONPATH"
torchrun $DISTRIBUTED_ARGS inference/pipeline/entry.py \
--config_file example/24B/24B_config.json \
--mode i2v \
--prompt "Good Boy" \
--image_path example/assets/image.jpeg \
--output_path example/assets/output_i2v.mp4 \
2>&1 | tee $LOG_DIR
```
## /example/4.5B/4.5B_config.json
```json path="/example/4.5B/4.5B_config.json"
{
"model_config": {
"model_name": "videodit_ardf",
"num_layers": 34,
"hidden_size": 3072,
"ffn_hidden_size": 12288,
"num_attention_heads": 24,
"num_query_groups": 8,
"kv_channels": 128,
"layernorm_epsilon": 1e-06,
"apply_layernorm_1p": true,
"x_rescale_factor": 1,
"half_channel_vae": false,
"params_dtype": "torch.bfloat16",
"patch_size": 2,
"t_patch_size": 1,
"in_channels": 16,
"out_channels": 16,
"cond_hidden_ratio": 0.25,
"caption_channels": 4096,
"caption_max_length": 800,
"xattn_cond_hidden_ratio": 1.0,
"cond_gating_ratio": 1.0,
"gated_linear_unit": false
},
"runtime_config": {
"cfg_number": 3,
"cfg_t_range": [
0.0,
0.0217,
0.1,
0.3,
0.999
],
"prev_chunk_scales": [
1.5,
1.5,
1.5,
1.0,
1.0
],
"text_scales": [
7.5,
7.5,
7.5,
0.0,
0.0
],
"noise2clean_kvrange": [
5,
4,
3,
2
],
"clean_chunk_kvrange": 1,
"clean_t": 0.9999,
"seed": 1234,
"num_frames": 192,
"video_size_h": 720,
"video_size_w": 1280,
"num_steps": 64,
"window_size": 4,
"fps": 24,
"chunk_width": 6,
"load": "./downloads/4.5B_base",
"t5_pretrained": "./downloads/t5_pretrained",
"t5_device": "cpu",
"vae_pretrained": "./downloads/vae",
"scale_factor": 0.18215,
"temporal_downsample_factor": 4
},
"engine_config": {
"distributed_backend": "nccl",
"distributed_timeout_minutes": 15,
"pp_size": 1,
"cp_size": 1,
"cp_strategy": "cp_ulysses",
"ulysses_overlap_degree": 1,
"fp8_quant": false,
"distill_nearly_clean_chunk_threshold": 0.3,
"shortcut_mode": "8,16,16",
"distill": false,
"kv_offload": true,
"enable_cuda_graph": false
}
}
```
## /example/4.5B/run.sh
```sh path="/example/4.5B/run.sh"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
export MASTER_ADDR=localhost
export MASTER_PORT=6009
export GPUS_PER_NODE=1
export NNODES=1
export WORLD_SIZE=1
export CUDA_VISIBLE_DEVICES=1
export PAD_HQ=1
export PAD_DURATION=1
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export OFFLOAD_T5_CACHE=true
export OFFLOAD_VAE_CACHE=true
export TORCH_CUDA_ARCH_LIST="8.9;9.0"
MAGI_ROOT=$(git rev-parse --show-toplevel)
LOG_DIR=log_$(date "+%Y-%m-%d_%H:%M:%S").log
export PYTHONPATH="$MAGI_ROOT:$PYTHONPATH"
python3 inference/pipeline/entry.py \
--config_file example/4.5B/4.5B_config.json \
--mode t2v \
--prompt "Good Boy" \
--output_path example/assets/output_t2v.mp4 \
2>&1 | tee $LOG_DIR
```
## /example/assets/image.jpeg
Binary file available at https://raw.githubusercontent.com/SandAI-org/MAGI-1/refs/heads/main/example/assets/image.jpeg
## /example/assets/prefix_video.mp4
Binary file available at https://raw.githubusercontent.com/SandAI-org/MAGI-1/refs/heads/main/example/assets/prefix_video.mp4
## /example/assets/special_tokens.npz
Binary file available at https://raw.githubusercontent.com/SandAI-org/MAGI-1/refs/heads/main/example/assets/special_tokens.npz
## /figures/algorithm.png
Binary file available at https://raw.githubusercontent.com/SandAI-org/MAGI-1/refs/heads/main/figures/algorithm.png
## /figures/dit_architecture.png
Binary file available at https://raw.githubusercontent.com/SandAI-org/MAGI-1/refs/heads/main/figures/dit_architecture.png
## /figures/inhouse_human_evaluation.png
Binary file available at https://raw.githubusercontent.com/SandAI-org/MAGI-1/refs/heads/main/figures/inhouse_human_evaluation.png
## /figures/logo_black.png
Binary file available at https://raw.githubusercontent.com/SandAI-org/MAGI-1/refs/heads/main/figures/logo_black.png
## /inference/common/__init__.py
```py path="/inference/common/__init__.py"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .common_utils import divide, env_is_true, set_random_seed
from .config import EngineConfig, MagiConfig, ModelConfig, RuntimeConfig
from .dataclass import InferenceParams, ModelMetaArgs, PackedCoreAttnParams, PackedCrossAttnParams
from .logger import magi_logger, print_per_rank, print_rank_0
from .timer import event_path_timer
__all__ = [
"MagiConfig",
"ModelConfig",
"EngineConfig",
"RuntimeConfig",
"magi_logger",
"print_per_rank",
"print_rank_0",
"event_path_timer",
"divide",
"env_is_true",
"set_random_seed",
"PackedCoreAttnParams",
"PackedCrossAttnParams",
"ModelMetaArgs",
"InferenceParams",
]
```
## /inference/common/common_utils.py
```py path="/inference/common/common_utils.py"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import numpy as np
import torch
def env_is_true(env_name: str) -> bool:
return str(os.environ.get(env_name, "0")).lower() in {"1", "true", "yes", "y", "on", "enabled"}
def divide(numerator, denominator):
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
return numerator // denominator
def set_random_seed(seed):
"""Set random seed.
Args:
seed (int): Seed to be used.
If not provided or set to 0, a random seed will be generated.
"""
if not seed or seed == 0:
seed = random.randint(0, 2**32 - 1)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
return seed
```
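A quick usage note for the helpers above (illustrative; the run scripts export `OFFLOAD_T5_CACHE=true` and `OFFLOAD_VAE_CACHE=true`, which is what `env_is_true` reads):
```python
import os
from inference.common import divide, env_is_true, set_random_seed

os.environ["OFFLOAD_T5_CACHE"] = "true"
print(env_is_true("OFFLOAD_T5_CACHE"))  # True ("1", "yes", "y", "on", "enabled" also count)
print(divide(48, 8))                    # 6; raises AssertionError if not divisible
print(set_random_seed(1234))            # 1234; passing 0 draws a fresh random seed
```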
## /inference/common/config.py
```py path="/inference/common/config.py"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import json
import os
import torch
@dataclasses.dataclass
class ModelConfig:
model_name: str
# Transformer
num_layers: int = None # Number of transformer layers.
hidden_size: int = None # Transformer hidden size.
ffn_hidden_size: int = None # Transformer Feed-Forward Network hidden size
num_attention_heads: int = None # Number of transformer attention heads.
    num_query_groups: int = 1 # Number of query groups, used for GQA
kv_channels: int = None # Projection weights dimension in multi-head attention
layernorm_epsilon: float = 1e-6 # Epsilon for layer norm and RMS norm.
apply_layernorm_1p: bool = False # Adjust LayerNorm weights which improves numerical stability.
x_rescale_factor: float = 1.0
half_channel_vae: bool = False
params_dtype: torch.dtype = None
# Embedding
patch_size: int = 2 # (latent) patch size for DiT patch embedding layer
t_patch_size: int = 1 # (latent) patch size for t dim patch embedding layer
in_channels: int = 4 # latent input channel for DiT
out_channels: int = 4 # latent output channel for DiT
cond_hidden_ratio: float = 0.25
caption_channels: int = 4096
caption_max_length: int = 800
xattn_cond_hidden_ratio: float = 1.0
cond_gating_ratio: float = 1.0
gated_linear_unit: bool = False
@dataclasses.dataclass
class RuntimeConfig:
# Inference settings such as cfg, kv range, clean t, etc.
cfg_number: int = None # Number of CFG
cfg_t_range: list = dataclasses.field(
default_factory=lambda: [0, 0.0217, 0.1000, 0.3, 0.999]
) # CFG t-range of each scales
prev_chunk_scales: list = dataclasses.field(
default_factory=lambda: [1.5, 1.5, 1.5, 1.5, 1.5]
) # CFG scales of previous chunks
text_scales: list = dataclasses.field(default_factory=lambda: [7.5, 7.5, 7.5, 7.5, 7.5]) # CFG scales of text
noise2clean_kvrange: list = dataclasses.field(default_factory=list) # Range of kv for noise2clean chunks
clean_chunk_kvrange: int = -1 # Range of kv for clean chunks
clean_t: float = 1.0 # timestep for clean chunks
# Video settings
seed: int = 1234 # Random seed used for python, numpy, pytorch, and cuda.
num_frames: int = 128
video_size_h: int = None
video_size_w: int = None
num_steps: int = 64 # Number of steps for the diffusion model
window_size: int = 4 # Window size for the diffusion model
fps: int = 24 # Frames per second
chunk_width: int = 6 # Clip width for the diffusion model
# Checkpoint, includes t5, vae, dit, etc.
t5_pretrained: str = None # Path to load pretrained T5 model.
t5_device: str = "cuda" # Device for T5 model to run on.
vae_pretrained: str = None # Path to load pretrained VAE model.
scale_factor: float = 0.18215 # Scale factor for the vae
temporal_downsample_factor: int = 4 # Temporal downsample factor for the vae
load: str = None # Directory containing a model checkpoint.
@dataclasses.dataclass
class EngineConfig:
    # Parallelism strategy
distributed_backend: str = "nccl" # Choices: ["nccl", "gloo"]
distributed_timeout_minutes: int = 10 # Timeout minutes for torch.distributed.
pp_size: int = 1 # Degree of pipeline model parallelism.
cp_size: int = 1 # Degree of context parallelism.
cp_strategy: str = "none" # Choices: ["none", "cp_ulysses", "cp_shuffle_overlap"]
ulysses_overlap_degree: int = 1 # Overlap degree for Ulysses
# Quantization
fp8_quant: bool = False # Enable 8-bit floating point quantization for model weights.
# Distillation
distill_nearly_clean_chunk_threshold: float = 0.3 # Threshold for distilling nearly clean chunks
shortcut_mode: str = "8,16,16" # Parameters for shortcut mode
distill: bool = False # Use distill mode
# Optimization
kv_offload: bool = False # Use kv-offload algorithm
enable_cuda_graph: bool = False # Enable CUDA graph for video generation
@dataclasses.dataclass
class MagiConfig:
model_config: ModelConfig
runtime_config: RuntimeConfig
engine_config: EngineConfig
@classmethod
def _check_missing_fields(cls, config_dict: dict, required_fields: list):
actual_fields = set(config_dict.keys())
missing_fields = set(required_fields) - actual_fields
if missing_fields:
raise ValueError(f"Missing fields in the configuration file: {', '.join(missing_fields)}")
@classmethod
def _create_nested_config(cls, config_dict: dict, config_name: str, config_cls):
nested_config_dict = config_dict.get(config_name, {})
cls._check_missing_fields(nested_config_dict, config_cls.__dataclass_fields__.keys())
return config_cls(**nested_config_dict)
@classmethod
def _create_config_from_dict(cls, config_dict: dict):
cls._check_missing_fields(config_dict, cls.__dataclass_fields__.keys())
# Create nested configs
model_config = cls._create_nested_config(config_dict, "model_config", ModelConfig)
runtime_config = cls._create_nested_config(config_dict, "runtime_config", RuntimeConfig)
engine_config = cls._create_nested_config(config_dict, "engine_config", EngineConfig)
return cls(model_config=model_config, runtime_config=runtime_config, engine_config=engine_config)
@classmethod
def from_json(cls, json_path: str):
def simple_json_decoder(dct):
dtype_map = {"torch.bfloat16": torch.bfloat16, "torch.float16": torch.float16, "torch.float32": torch.float32}
if 'params_dtype' in dct:
dct['params_dtype'] = dtype_map[dct['params_dtype']]
return dct
with open(json_path, "r") as f:
config_dict = json.load(f, object_hook=simple_json_decoder)
magi_config = cls._create_config_from_dict(config_dict)
def post_validation(magi_config):
if magi_config.engine_config.fp8_quant or magi_config.engine_config.distill:
assert (
magi_config.runtime_config.cfg_number == 1
), "Please set `cfg_number: 1` in config.json for distill or quant model"
else:
assert magi_config.runtime_config.cfg_number == 3, "Please set `cfg_number: 3` in config.json for base model"
post_validation(magi_config)
return magi_config
def to_json(self, json_path: str):
class SimpleJSONEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, torch.dtype):
return str(obj)
return super().default(obj)
# Ensure the directory exists
os.makedirs(os.path.dirname(json_path), exist_ok=True)
config_dict = {
"model_config": dataclasses.asdict(self.model_config),
"runtime_config": dataclasses.asdict(self.runtime_config),
"engine_config": dataclasses.asdict(self.engine_config),
}
with open(json_path, "w") as f:
json.dump(config_dict, f, indent=4, cls=SimpleJSONEncoder)
```
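A short round-trip example for the dtype handling above (paths are illustrative): `from_json` decodes the string `"torch.bfloat16"` into a real `torch.dtype`, and `to_json` encodes it back to a string.
```python
from inference.common import MagiConfig

cfg = MagiConfig.from_json("example/4.5B/4.5B_config.json")
print(type(cfg.model_config.params_dtype))              # <class 'torch.dtype'>, not str
cfg.to_json("./downloads/4.5B_config_roundtrip.json")   # written back as "torch.bfloat16"
```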
## /inference/common/dataclass.py
```py path="/inference/common/dataclass.py"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List
import numpy as np
import torch
@dataclass(frozen=True)
class PackedCoreAttnParams:
# Packed sequence parameters for core_attn
q_range: torch.Tensor
k_range: torch.Tensor
np_q_range: np.ndarray
np_k_range: np.ndarray
max_seqlen_q: int
max_seqlen_k: int
@dataclass(frozen=True)
class PackedCrossAttnParams:
# Packed sequence parameters for cross_attn
q_ranges: torch.Tensor = None
kv_ranges: torch.Tensor = None
cu_seqlens_q: torch.Tensor = None
cu_seqlens_kv: torch.Tensor = None
max_seqlen_q: int = None
max_seqlen_kv: int = None
@dataclass(frozen=True)
class ModelMetaArgs:
H: int
W: int
cp_pad_size: int
cp_split_sizes: List[int]
slice_point: int
denoising_range_num: int
range_num: int
extract_prefix_video_feature: bool
fwd_extra_1st_chunk: bool
distill_nearly_clean_chunk: bool
clip_token_nums: int
enable_cuda_graph: bool
core_attn_params: PackedCoreAttnParams
cross_attn_params: PackedCrossAttnParams
class InferenceParams:
"""Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""
def __init__(self, max_batch_size, max_sequence_length):
self.max_sequence_length = max_sequence_length
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
self.key_value_memory_dict = {}
self.update_kv_cache = False
def swap_key_value_dict(self, batch_idx):
"swap between batches"
if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict is empty")
for layer_number in self.key_value_memory_dict.keys():
inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
assert len(batch_idx) == inference_key_memory.shape[1] # make sure batch size is the same
new_inference_key_memory = inference_key_memory[:, batch_idx]
new_inference_value_memory = inference_value_memory[:, batch_idx]
self.key_value_memory_dict[layer_number] = (new_inference_key_memory, new_inference_value_memory)
```
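A small illustration of the KV-cache container above; the tensor shape is an assumption chosen for the example (sequence, batch, kv heads, head dim), not something prescribed by the model code.
```python
import torch
from inference.common import InferenceParams

params = InferenceParams(max_batch_size=2, max_sequence_length=4096)
k = torch.zeros(4096, 2, 8, 128)  # [seq, batch, kv_heads, head_dim] -- example shape only
v = torch.zeros(4096, 2, 8, 128)
params.key_value_memory_dict[0] = (k, v)          # cache for layer 0
params.swap_key_value_dict(batch_idx=[1, 0])      # reorder the two batch entries
print(params.key_value_memory_dict[0][0].shape)   # torch.Size([4096, 2, 8, 128])
```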
## /inference/common/logger.py
```py path="/inference/common/logger.py"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import torch
class GlobalLogger:
_logger = None
@classmethod
def get_logger(cls, name=__name__, level=logging.INFO):
if cls._logger is None:
cls._logger = logging.getLogger("magi_logger")
cls._logger.setLevel(logging.INFO)
cls._logger.propagate = False
cls._logger.handlers.clear()
formatter = logging.Formatter("[%(asctime)s - %(levelname)s] %(message)s")
handler = logging.StreamHandler()
handler.setFormatter(formatter)
cls._logger.addHandler(handler)
return cls._logger
magi_logger = GlobalLogger.get_logger()
def print_per_rank(message):
magi_logger.info(message)
def print_rank_0(message):
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
magi_logger.info(message)
else:
magi_logger.info(message)
```
## /inference/common/timer.py
```py path="/inference/common/timer.py"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import datetime
import torch
from .logger import print_rank_0
class EventPathTimer:
"""
A lightweight class for recording time without any distributed barrier.
This class allows for recording elapsed time between events without requiring
synchronization across distributed processes. It maintains the previous message
and time to calculate the duration between consecutive records.
"""
def __init__(self):
"""
Initialize the EventPathTimer.
This constructor sets the previous message and time to None, preparing
the instance for recording events.
"""
self.prev_message: str = None
self.prev_time: datetime = None
def reset(self):
"""
Reset the recorded message and time.
This method clears the previous message and time, allowing for a fresh
start in recording new events.
"""
self.prev_message = None
self.prev_time = None
def synced_record(self, message):
"""
Record the current time with a message.
Args:
message (str): A message to log along with the current time.
This method synchronizes the CUDA operations, records the current time,
and calculates the elapsed time since the last recorded message, if any.
It then logs the elapsed time along with the previous and current messages.
"""
torch.cuda.synchronize()
current_time = datetime.now()
if self.prev_message is not None:
print_rank_0(
f"\nTime Elapsed: [{current_time - self.prev_time}] From [{self.prev_message} ({self.prev_time})] To [{message} ({current_time})]"
)
self.prev_message = message
self.prev_time = current_time
_GLOBAL_LIGHT_TIMER = EventPathTimer()
def event_path_timer() -> EventPathTimer:
"""Get the current EventPathTimer instance.
Returns:
EventPathTimer: The current EventPathTimer instance.
Raises:
AssertionError: If the EventPathTimer has not been initialized.
"""
assert _GLOBAL_LIGHT_TIMER is not None, "light time recorder is not initialized"
return _GLOBAL_LIGHT_TIMER
```
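Typical use of the timer above (requires a CUDA device, since `synced_record` calls `torch.cuda.synchronize()`); the event names are arbitrary examples:
```python
from inference.common import event_path_timer

event_path_timer().synced_record("start text encoding")
# ... run the T5 encoder here ...
event_path_timer().synced_record("finish text encoding")  # logs the elapsed time between the two events
```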
## /inference/infra/checkpoint/__init__.py
```py path="/inference/infra/checkpoint/__init__.py"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .checkpointing import load_checkpoint
__all__ = ["load_checkpoint"]
```
## /inference/infra/checkpoint/checkpointing.py
```py path="/inference/infra/checkpoint/checkpointing.py"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import json
import os
import re
import subprocess
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import numpy as np
import torch
import torch.distributed
from safetensors.torch import load as load_from_bytes
from safetensors.torch import load_file
from tqdm.auto import tqdm
import inference.infra.distributed.parallel_state as mpu
from inference.common import EngineConfig, ModelConfig, RuntimeConfig, print_per_rank, print_rank_0
def _load_shard(shard_path, param_names, num_threads=None):
zstd_path = shard_path + ".zst"
if os.path.exists(zstd_path):
start_time = datetime.now()
print_per_rank(f"Decompressing {zstd_path} with {num_threads} threads")
cmd = ["zstd", "-d"]
if num_threads:
cmd.extend(["-T", str(num_threads)])
process = subprocess.Popen(cmd + ["-c", zstd_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=-1)
decompressed_data = process.stdout.read()
process.stdout.close()
retcode = process.wait()
if retcode != 0:
raise RuntimeError(f"Decompression failed: {process.stderr.read().decode()}")
print_per_rank(
f"Decompressed {zstd_path} with {num_threads} threads, duration: {(datetime.now() - start_time).total_seconds()}s"
)
buffer = io.BytesIO(decompressed_data)
start_time = datetime.now()
print_per_rank(f"Loading {shard_path} from zstd file, start time: {start_time}")
weights = load_from_bytes(buffer.getvalue())
print_per_rank(f"Loaded {shard_path} from zstd file, duration: {(datetime.now() - start_time).total_seconds()}s")
buffer.close()
else:
weights = load_file(shard_path)
return {name: weights[name] for name in param_names}
def load_sharded_safetensors_parallel_with_progress(checkpoint_dir):
index_path = os.path.join(checkpoint_dir, "model.safetensors.index.json")
with open(index_path, "r") as f:
index = json.load(f)
state_dict = {}
shard_map = {}
# Group parameters by shard file
for param_name, shard_file in index["weight_map"].items():
shard_path = os.path.join(checkpoint_dir, shard_file)
if shard_path not in shard_map:
shard_map[shard_path] = []
shard_map[shard_path].append(param_name)
# Load shards in parallel with a progress bar
with ThreadPoolExecutor() as executor:
futures = {
executor.submit(_load_shard, shard_path, param_names): shard_path for shard_path, param_names in shard_map.items()
}
pbar = tqdm(futures, desc="Loading shards", total=len(futures))
for future in pbar:
result = future.result()
state_dict.update(result)
return state_dict
def unwrap_model(model):
return_list = True
if not isinstance(model, list):
model = [model]
return_list = False
unwrapped_model = []
for model_module in model:
while hasattr(model_module, "module"):
model_module = model_module.module
unwrapped_model.append(model_module)
if not return_list:
return unwrapped_model[0]
return unwrapped_model
def _split_state_dict_for_pp(weight_dict: OrderedDict, model_config: ModelConfig):
num_layers = model_config.num_layers
partition = mpu.get_pp_world_size()
## use partition and num_layers to get current rank layer order
layers_for_each_stage = np.array_split(range(num_layers), partition)
current_stage = mpu.get_pp_rank()
allow_layer_num = layers_for_each_stage[current_stage]
layer_offset = allow_layer_num[0]
new_weight_dict = {}
for k, v in weight_dict.items():
if "videodit_blocks.layers" in k:
layer_num = int(re.search(r"videodit_blocks\.layers\.(\d+)", k).group(1))
if layer_num not in allow_layer_num:
continue
## replace the old key name by new layer number
new_layer_num = layer_num - layer_offset
new_k = k.replace(f"videodit_blocks.layers.{layer_num}", f"videodit_blocks.layers.{new_layer_num}")
new_weight_dict[new_k] = v
else:
new_weight_dict[k] = v
return new_weight_dict
def load_state_dict(runtime_config: RuntimeConfig, engine_config: EngineConfig):
load_dir = runtime_config.load
default_subdir = "inference_weight"
if engine_config.fp8_quant:
default_subdir = f"{default_subdir}.fp8"
if engine_config.distill:
default_subdir = f"{default_subdir}.distill"
inference_weight_dir = os.path.join(load_dir, default_subdir)
assert os.path.exists(inference_weight_dir)
print_rank_0(f"load {default_subdir} weight from {inference_weight_dir}")
assert (
os.path.exists(inference_weight_dir) and len(os.listdir(inference_weight_dir)) > 0
    ), f"Ckpt directory {inference_weight_dir} does not exist or is empty. If you are using fp8_quant, please run calibration first."
state_dict = load_sharded_safetensors_parallel_with_progress(inference_weight_dir)
return state_dict
def load_checkpoint(model):
state_dict = load_state_dict(model.runtime_config, model.engine_config)
model = unwrap_model(model)
# if we use pipeline parallelism, we need to load the state dict for each stage
# as it always record layer from 0 -> num_layers//pipeline_parallel_size
# so we need to choose correct layer weight when load_state_dict
if mpu.get_pp_world_size() > 1:
state_dict = _split_state_dict_for_pp(state_dict, model.model_config)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
model.cuda(torch.cuda.current_device())
if mpu.get_pp_world_size() > 1:
rank_msg = f"CP_rank={mpu.get_cp_rank()} PP_rank={mpu.get_pp_rank()}"
print_per_rank(
            f"""[{rank_msg}] Load Weight Missing Keys: {missing_keys} Load Weight Unexpected Keys: {unexpected_keys} You should see the message [missing final layer norm weight] except on the final pipeline stage"""
)
else:
print_rank_0(f"Load Weight Missing Keys: {missing_keys}")
print_rank_0(f"Load Weight Unexpected Keys: {unexpected_keys}")
return model
```
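To make the renumbering in `_split_state_dict_for_pp` concrete: with the 24B model's 48 layers and `pp_size: 2` (the setting suggested for RTX 4090 \* 8), each pipeline stage keeps a contiguous slice of layers and renames them to start from 0. A standalone illustration of that index mapping:
```python
import numpy as np

num_layers, pp_size = 48, 2  # 24B model split across two pipeline stages
for stage, layers in enumerate(np.array_split(range(num_layers), pp_size)):
    print(f"stage {stage}: keeps global layers {layers[0]}-{layers[-1]}, "
          f"renamed to videodit_blocks.layers.0-{layers[-1] - layers[0]}")
# stage 0: keeps global layers 0-23,  renamed to videodit_blocks.layers.0-23
# stage 1: keeps global layers 24-47, renamed to videodit_blocks.layers.0-23
```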
## /inference/infra/distributed/__init__.py
```py path="/inference/infra/distributed/__init__.py"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dist_utils import dist_init, get_device, get_world_size, is_last_rank, is_last_tp_cp_rank
from .parallel_state import (
destroy_model_parallel,
get_cp_group,
get_cp_rank,
get_cp_world_size,
get_dp_group,
get_dp_group_gloo,
get_dp_rank,
get_dp_world_size,
get_pipeline_model_parallel_first_rank,
get_pipeline_model_parallel_last_rank,
get_pipeline_model_parallel_next_rank,
get_pipeline_model_parallel_prev_rank,
get_pp_group,
get_pp_rank,
get_pp_world_size,
get_tensor_model_parallel_last_rank,
get_tensor_model_parallel_ranks,
get_tensor_model_parallel_src_rank,
get_tp_group,
get_tp_rank,
get_tp_world_size,
is_initialized,
is_pipeline_first_stage,
is_pipeline_last_stage,
)
__all__ = [
"dist_init",
"is_initialized",
"get_tp_group",
"get_pp_group",
"get_dp_group",
"get_dp_group_gloo",
"get_cp_group",
"get_tp_world_size",
"get_pp_world_size",
"get_dp_world_size",
"get_cp_world_size",
"get_tp_rank",
"get_pp_rank",
"get_dp_rank",
"get_cp_rank",
"is_pipeline_first_stage",
"is_pipeline_last_stage",
"get_tensor_model_parallel_src_rank",
"get_tensor_model_parallel_ranks",
"get_tensor_model_parallel_last_rank",
"get_pipeline_model_parallel_first_rank",
"get_pipeline_model_parallel_last_rank",
"get_pipeline_model_parallel_next_rank",
"get_pipeline_model_parallel_prev_rank",
"destroy_model_parallel",
"is_last_rank",
"is_last_tp_cp_rank",
"get_world_size",
"get_device",
]
```
## /inference/infra/distributed/dist_utils.py
```py path="/inference/infra/distributed/dist_utils.py"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from datetime import timedelta
import torch
from inference.common import print_rank_0
from inference.infra.parallelism.pipeline_parallel import init_pp_scheduler
from . import parallel_state as mpu
def dist_init(config):
"""Initialize torch.distributed and core model parallel."""
assert torch.cuda.is_available()
device_count = torch.cuda.device_count()
if torch.distributed.is_initialized():
print_rank_0("Torch distribution already initialized, skipping initialization ...")
else:
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
# Manually set the device ids.
if device_count > 0:
device = rank % device_count
torch.cuda.set_device(device)
# Call the init process
torch.distributed.init_process_group(
backend=config.engine_config.distributed_backend,
world_size=world_size,
rank=rank,
timeout=timedelta(minutes=config.engine_config.distributed_timeout_minutes),
)
assert config.engine_config.cp_size * config.engine_config.pp_size == torch.distributed.get_world_size()
if device_count > 0:
if mpu.model_parallel_is_initialized():
print_rank_0("Model parallel is already initialized")
else:
mpu.initialize_model_parallel(
cp_size=config.engine_config.cp_size,
pp_size=config.engine_config.pp_size,
nccl_communicator_config_path=None,
distributed_timeout_minutes=config.engine_config.distributed_timeout_minutes,
order="tp-cp-pp-dp",
)
if mpu.get_pp_world_size() > 1:
init_pp_scheduler()
print_rank_0("Initialize torch distribution and model parallel successfully")
def is_last_rank():
return torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1)
def is_last_tp_cp_rank():
return mpu.get_tp_rank(with_context_parallel=True) == mpu.get_tp_world_size(with_context_parallel=True) - 1
def get_world_size():
if torch.distributed.is_available() and torch.distributed.is_initialized():
world_size = torch.distributed.get_world_size()
else:
world_size = 1
return world_size
def get_device(local_rank=None):
backend = torch.distributed.get_backend()
if backend == "nccl":
if local_rank is None:
device = torch.device("cuda")
else:
device = torch.device(f"cuda:{local_rank}")
elif backend == "gloo":
device = torch.device("cpu")
else:
raise RuntimeError
return device
```
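For orientation, here is a single-process sketch of the bookkeeping `dist_init` performs before creating the process group: rank and world size come from the environment (normally set by torchrun), each rank binds to `rank % device_count`, and the context- and pipeline-parallel sizes must exactly tile the world. The `cp_size`/`pp_size` values and the GPU count below are placeholders, not repository defaults.

```py
import os

# Placeholders standing in for engine_config.cp_size / engine_config.pp_size.
cp_size, pp_size = 4, 2

rank = int(os.getenv("RANK", "0"))                                 # normally set by torchrun
world_size = int(os.getenv("WORLD_SIZE", str(cp_size * pp_size)))
device_count = 8                                                   # stand-in for torch.cuda.device_count()

# Each rank binds to a local GPU by taking its global rank modulo the GPU count.
local_device = rank % device_count

# The context- and pipeline-parallel grid must exactly tile the world.
assert cp_size * pp_size == world_size, f"cp_size * pp_size must equal WORLD_SIZE ({world_size})"
print(f"rank {rank} -> cuda:{local_device}, grid: cp={cp_size} x pp={pp_size}")
```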
## /inference/infra/distributed/parallel_state.py
```py path="/inference/infra/distributed/parallel_state.py"
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model and data parallel groups."""
import warnings
from datetime import timedelta
from typing import List, Optional
import torch
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Tensor parallel group information with context parallel combined.
_TENSOR_MODEL_PARALLEL_GROUP_WITH_CP = None
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP_GLOO = None
# tensor model parallel group and data parallel group combined
# used for fp8 and moe training
_TENSOR_AND_DATA_PARALLEL_GROUP = None
# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None
# A list of global ranks for each data parallel group to ease calculation of the source
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None
# A list of global ranks for each tensor model parallel group to ease calculation of
# the first local rank in the tensor model parallel group
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = None
# Context parallel group that the current rank belongs to
_CONTEXT_PARALLEL_GROUP = None
# A list of global ranks for each context parallel group to ease calculation of the
# destination rank when exchanging KV/dKV between context parallel_ranks
_CONTEXT_PARALLEL_GLOBAL_RANKS = None
# Data parallel group information with context parallel combined.
_DATA_PARALLEL_GROUP_WITH_CP = None
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = None
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None
# combined parallel group of TP, DP, and CP used for fp8
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None
def get_nccl_options(pg_name, nccl_comm_cfgs):
"""Set the NCCL process group options.
Args:
pg_name (str): process group name
nccl_comm_cfgs (dict): nccl communicator configurations
When an option (e.g., max_ctas) is not found in the config, use the NCCL default setting.
"""
if pg_name in nccl_comm_cfgs:
nccl_options = torch.distributed.ProcessGroupNCCL.Options()
nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name].get("cga_cluster_size", 4)
nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name].get("max_ctas", 32)
nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name].get("min_ctas", 1)
return nccl_options
else:
return None
def generate_masked_orthogonal_rank_groups(world_size: int, parallel_size: List[int], mask: List[bool]) -> List[List[int]]:
"""Generate orthogonal parallel groups based on the parallel size and mask.
Arguments:
world_size (int): world size
parallel_size (List[int]):
The parallel size of each orthogonal parallel type. For example, if
tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4,
and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4].
mask (List[bool]):
The mask controls which parallel methods the generated groups represent. If mask[i] is
True, it means the generated group contains the i-th parallelism method. For example,
if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then
the generated group is the `tp-dp` group, if the mask = [False, True, False], then the
generated group is the `pp` group.
Algorithm:
For orthogonal parallelism, such as tp/dp/pp/cp, the global_rank and
local_rank satisfy the following equation:
global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size (1)
tp_rank \in [0, tp_size)
dp_rank \in [0, dp_size)
pp_rank \in [0, pp_size)
If we want to get the `dp_group` (tp_size * pp_size groups of dp_size ranks each.
For example, if the gpu size is 8 and order is 'tp-pp-dp', size is '2-2-2', and the
dp_group here is [[0, 4], [1, 5], [2, 6], [3, 7]].)
The tp_rank and pp_rank will be combined to form the `dp_group_index`.
dp_group_index = tp_rank + pp_rank * tp_size (2)
        So, given that tp_rank and pp_rank satisfy equation (2), and dp_rank is in
        range(0, dp_size), the ranks in dp_group[dp_group_index] satisfy
        equation (1).
        This function solves this math problem.
For example, if the parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4],
and the mask = [False, True, False]. Then,
dp_group_index(0) = tp_rank(0) + pp_rank(0) * 2
dp_group_index(1) = tp_rank(1) + pp_rank(0) * 2
...
dp_group_index(7) = tp_rank(1) + pp_rank(3) * 2
dp_group[0] = 0 + range(0, 3) * 2 + 0 = [0, 2, 4]
dp_group[1] = 1 + range(0, 3) * 2 + 0 = [1, 3, 5]
...
dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23]
"""
def prefix_product(a: List[int], init=1) -> List[int]:
r = [init]
for v in a:
init = init * v
r.append(init)
return r
def inner_product(a: List[int], b: List[int]) -> int:
return sum([x * y for x, y in zip(a, b)])
def decompose(index, shape, stride=None):
"""
        This function solves the math problem below:
There is an equation:
index = sum(idx[i] * stride[i])
And given the value of index, stride.
Return the idx.
        This function will be used to get the tp/dp/pp rank
        from group_index and rank_in_group.
"""
if stride is None:
stride = prefix_product(shape)
idx = [(index // d) % s for s, d in zip(shape, stride)]
# stride is a prefix_product result. And the value of stride[-1]
# is not used.
assert (
sum([x * y for x, y in zip(idx, stride[:-1])]) == index
), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx)
return idx
masked_shape = [s for s, m in zip(parallel_size, mask) if m]
unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m]
global_stride = prefix_product(parallel_size)
masked_stride = [d for d, m in zip(global_stride, mask) if m]
unmasked_stride = [d for d, m in zip(global_stride, mask) if not m]
group_size = prefix_product(masked_shape)[-1]
num_of_group = world_size // group_size
ranks = []
for group_index in range(num_of_group):
        # get indices from unmasked shape for group_index.
decomposed_group_idx = decompose(group_index, unmasked_shape)
rank = []
for rank_in_group in range(group_size):
            # get indices from masked shape for rank_in_group.
decomposed_rank_idx = decompose(rank_in_group, masked_shape)
rank.append(
inner_product(decomposed_rank_idx, masked_stride) + inner_product(decomposed_group_idx, unmasked_stride)
)
ranks.append(rank)
return ranks
class RankGenerator(object):
def __init__(self, tp: int, dp: int, pp: int, cp: int, order: str) -> None:
self.tp = tp
self.dp = dp
self.pp = pp
self.cp = cp
self.world_size = tp * dp * pp * cp
self.name_to_size = {"tp": self.tp, "pp": self.pp, "dp": self.dp, "cp": self.cp}
order = order.lower()
for name in self.name_to_size.keys():
if name not in order and self.name_to_size[name] != 1:
raise RuntimeError(
f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({self.order})."
)
elif name not in order:
order = order + "-" + name
self.order = order
self.ordered_size = [self.name_to_size[token] for token in order.split("-")]
def get_mask(self, order: str, token: str):
ordered_token = order.split("-")
token = token.split("-")
mask = [False] * len(ordered_token)
for t in token:
mask[ordered_token.index(t)] = True
return mask
def get_ranks(self, token):
"""Get rank group by input token.
Arguments:
token (str):
Specify the ranks type that want to get. If we want
to obtain multiple parallel types, we can use a hyphen
'-' to separate them. For example, if we want to obtain
the TP_DP group, the token should be 'tp-dp'.
"""
mask = self.get_mask(self.order, token)
ranks = generate_masked_orthogonal_rank_groups(self.world_size, self.ordered_size, mask)
return ranks
def initialize_model_parallel(
tp_size: int = 1,
pp_size: int = 1,
cp_size: int = 1,
nccl_communicator_config_path: Optional[str] = None,
distributed_timeout_minutes: int = 30,
order: str = "tp-cp-pp-dp",
) -> None:
"""Initialize model data parallel groups.
Borrow from: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
Args:
tp_size (int, default = 1):
The number of GPUs to split individual tensors across.
pp_size (int, default = 1):
The number of tensor parallel GPU groups to split the
Transformer layers across. For example, if tp_size is 4 and
pp_size is 2, the model will be split into 2 groups of 4 GPUs.
cp_size (int, default = 1):
The number of tensor parallel GPU groups to split the
network input sequence length across. Compute of attention
module requires tokens of full sequence length, so GPUs
in a context parallel group need to communicate with each
other to exchange information of other sequence chunks.
Each GPU and its counterparts in other tensor parallel
groups compose a context parallel group.
For example, assume we have 8 GPUs, if tensor model parallel
size is 4 and context parallel size is 2, the network input
will be split into two sequence chunks, which are processed
by 2 different groups of 4 GPUs. One chunk is processed by
GPU0-3, the other chunk is processed by GPU4-7. Four groups
are build to do context parallel communications: [GPU0, GPU4],
[GPU1, GPU5], [GPU2, GPU6], and [GPU3, GPU7].
Context parallelism partitions sequence length, so it has no
impact on weights, which means weights are duplicated among
GPUs in a context parallel group. Hence, weight gradients
all-reduce is required in backward. For simplicity, we piggyback
GPUs of context parallelism on data parallel group for
weight gradient all-reduce.
nccl_communicator_config_path (str, default = None):
Path to the yaml file of NCCL communicator configurations.
`min_ctas`, `max_ctas`, and `cga_cluster_size` can be set
for each communicator.
distributed_timeout_minutes (int, default = 30): Timeout, in
            minutes, for operations executed against distributed
process groups. See PyTorch documentation at
https://pytorch.org/docs/stable/distributed.html for
caveats.
        order (str, default=tp-cp-pp-dp):
            The rank initialization order of parallelism, given as a
            hyphen-separated permutation of tp, cp, pp and dp.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
and 8 data-parallel groups as:
8 data_parallel groups:
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
8 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
4 pipeline model-parallel groups:
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
if world_size % (tp_size * pp_size * cp_size) != 0:
raise RuntimeError(
f"world_size ({world_size}) is not divisible by tp_size "
f"({tp_size}) x pp_size ({pp_size}) "
f"x cp_size ({cp_size})"
)
nccl_comm_cfgs = {}
if nccl_communicator_config_path is not None:
try:
import yaml
except ImportError:
raise RuntimeError("Cannot import `yaml`. Setting custom nccl communicator configs " "requires the yaml package.")
with open(nccl_communicator_config_path, "r") as stream:
nccl_comm_cfgs = yaml.safe_load(stream)
dp_size: int = world_size // (tp_size * pp_size * cp_size)
rank = torch.distributed.get_rank()
rank_generator = RankGenerator(tp=tp_size, dp=dp_size, pp=pp_size, cp=cp_size, order=order)
timeout = timedelta(minutes=distributed_timeout_minutes)
# Build the data-parallel groups.
global _DATA_PARALLEL_GROUP
global _DATA_PARALLEL_GROUP_GLOO
global _DATA_PARALLEL_GLOBAL_RANKS
global _DATA_PARALLEL_GROUP_WITH_CP
global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
for ranks in rank_generator.get_ranks("dp"):
group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=get_nccl_options("dp", nccl_comm_cfgs))
group_gloo = torch.distributed.new_group(ranks, timeout=timeout, backend="gloo")
if rank in ranks:
_DATA_PARALLEL_GROUP = group
_DATA_PARALLEL_GROUP_GLOO = group_gloo
_DATA_PARALLEL_GLOBAL_RANKS = ranks
for ranks_with_cp in rank_generator.get_ranks("dp-cp"):
group_with_cp = torch.distributed.new_group(
ranks_with_cp, timeout=timeout, pg_options=get_nccl_options("dp_cp", nccl_comm_cfgs)
)
group_with_cp_gloo = torch.distributed.new_group(ranks_with_cp, timeout=timeout, backend="gloo")
if rank in ranks_with_cp:
_DATA_PARALLEL_GROUP_WITH_CP = group_with_cp
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = group_with_cp_gloo
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp
# Build the context-parallel groups.
global _CONTEXT_PARALLEL_GROUP
global _CONTEXT_PARALLEL_GLOBAL_RANKS
assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"
for ranks in rank_generator.get_ranks("cp"):
group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=get_nccl_options("cp", nccl_comm_cfgs))
if rank in ranks:
_CONTEXT_PARALLEL_GROUP = group
_CONTEXT_PARALLEL_GLOBAL_RANKS = ranks
# Build the model-parallel groups.
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
for ranks in rank_generator.get_ranks("tp-pp"):
group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=get_nccl_options("mp", nccl_comm_cfgs))
if rank in ranks:
_MODEL_PARALLEL_GROUP = group
# Build the tensor model-parallel groups.
global _TENSOR_MODEL_PARALLEL_GROUP
global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS
assert _TENSOR_MODEL_PARALLEL_GROUP is None, "tensor model parallel group is already initialized"
for ranks in rank_generator.get_ranks("tp"):
group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=get_nccl_options("tp", nccl_comm_cfgs))
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = ranks
# Build the tensor + context parallel groups.
global _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP
global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP
assert (
_TENSOR_MODEL_PARALLEL_GROUP_WITH_CP is None
), "tensor model parallel group with context parallel is already initialized"
for ranks in rank_generator.get_ranks("tp-cp"):
group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=get_nccl_options("tp_cp", nccl_comm_cfgs))
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP_WITH_CP = group
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks
# Build the pipeline model-parallel groups
global _PIPELINE_MODEL_PARALLEL_GROUP
global _PIPELINE_GLOBAL_RANKS
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized"
for ranks in rank_generator.get_ranks("pp"):
group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=get_nccl_options("pp", nccl_comm_cfgs))
if rank in ranks:
_PIPELINE_MODEL_PARALLEL_GROUP = group
_PIPELINE_GLOBAL_RANKS = ranks
# Build the tensor + data parallel groups.
global _TENSOR_AND_DATA_PARALLEL_GROUP
global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
assert _TENSOR_AND_DATA_PARALLEL_GROUP is None, "Tensor + data parallel group is already initialized"
for ranks in rank_generator.get_ranks("tp-cp-dp"):
group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=get_nccl_options("tp_cp_dp", nccl_comm_cfgs))
if rank in ranks:
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = group
for ranks in rank_generator.get_ranks("tp-dp"):
group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=get_nccl_options("tp_dp", nccl_comm_cfgs))
if rank in ranks:
_TENSOR_AND_DATA_PARALLEL_GROUP = group
def is_initialized():
"""Useful for code segments that may be accessed with or without mpu initialization"""
return _DATA_PARALLEL_GROUP is not None
def is_unitialized() -> bool:
"""Check if parallel state has been initialized
Deprecated. Use is_initialized instead.
"""
warnings.warn("is_unitialized is deprecated, use is_initialized instead", DeprecationWarning)
return not is_initialized()
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or _PIPELINE_MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
return False
return True
def get_model_parallel_group():
"""Get the model parallel group the caller rank belongs to."""
assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
return _MODEL_PARALLEL_GROUP
def get_tp_group(check_initialized=True, with_context_parallel=False):
"""Get the tensor model parallel group the caller rank belongs to."""
if check_initialized:
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "tensor model parallel group is not initialized"
if with_context_parallel:
assert (
_TENSOR_MODEL_PARALLEL_GROUP_WITH_CP is not None
), "tensor model parallel group with context parallel combined is not initialized"
return _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP
else:
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "tensor model parallel group is not initialized"
return _TENSOR_MODEL_PARALLEL_GROUP
def get_pp_group():
"""Get the pipeline model parallel group the caller rank belongs to."""
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, "pipeline_model parallel group is not initialized"
return _PIPELINE_MODEL_PARALLEL_GROUP
def get_dp_group(with_context_parallel=False):
"""Get the data parallel group the caller rank belongs to."""
if with_context_parallel:
assert (
_DATA_PARALLEL_GROUP_WITH_CP is not None
), "data parallel group with context parallel combined is not initialized"
return _DATA_PARALLEL_GROUP_WITH_CP
else:
assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
return _DATA_PARALLEL_GROUP
def get_dp_group_gloo(with_context_parallel=False):
"""Get the data parallel group-gloo the caller rank belongs to."""
if with_context_parallel:
assert (
_DATA_PARALLEL_GROUP_WITH_CP_GLOO is not None
), "data parallel group-gloo with context parallel combined is not initialized"
return _DATA_PARALLEL_GROUP_WITH_CP_GLOO
else:
assert _DATA_PARALLEL_GROUP_GLOO is not None, "data parallel group-gloo is not initialized"
return _DATA_PARALLEL_GROUP_GLOO
def get_cp_group(check_initialized=True):
"""Get the context parallel group the caller rank belongs to."""
if check_initialized:
assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized"
return _CONTEXT_PARALLEL_GROUP
def get_tp_world_size(with_context_parallel=False):
"""Return world size for the tensor model parallel group."""
return torch.distributed.get_world_size(group=get_tp_group(with_context_parallel=with_context_parallel))
def get_pp_world_size():
"""Return world size for the pipeline model parallel group."""
return torch.distributed.get_world_size(group=get_pp_group())
def get_tp_rank(with_context_parallel=False):
"""Return my rank for the tensor model parallel group."""
return torch.distributed.get_rank(group=get_tp_group(with_context_parallel=with_context_parallel))
def get_pp_rank():
"""Return my rank for the pipeline model parallel group."""
return torch.distributed.get_rank(group=get_pp_group())
def is_pipeline_first_stage():
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
return get_pp_rank() == 0
def is_pipeline_last_stage():
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
return get_pp_rank() == (get_pp_world_size() - 1)
def get_tensor_model_parallel_src_rank(with_context_parallel=False):
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
assert _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS is not None, "Tensor model parallel group is not initialized"
if with_context_parallel:
assert (
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP is not None
), "Tensor model parallel group with context parallel combined is not initialized"
return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP[0]
else:
return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS[0]
def get_tensor_model_parallel_ranks(with_context_parallel=False):
"""Return all global ranks for the tensor model parallel group."""
if with_context_parallel:
assert (
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP is not None
), "Tensor model parallel group with context parallel combined is not initialized"
return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP
else:
assert _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS is not None, "Tensor model parallel group is not initialized"
return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS
def get_tensor_model_parallel_last_rank(with_context_parallel=False):
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
assert _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS is not None, "Tensor model parallel group is not initialized"
if with_context_parallel:
assert (
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP is not None
), "Tensor model parallel group with context parallel combined is not initialized"
return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP[-1]
else:
return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS[-1]
def get_pipeline_model_parallel_first_rank():
"""Return the global rank of the first process in the pipeline for the
current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
return _PIPELINE_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_last_rank():
"""Return the global rank of the last process in the pipeline for the
current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
last_rank_local = get_pp_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_next_rank():
"""Return the global rank that follows the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
rank_in_pipeline = get_pp_rank()
world_size = get_pp_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
def get_pipeline_model_parallel_prev_rank():
"""Return the global rank that preceeds the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
rank_in_pipeline = get_pp_rank()
world_size = get_pp_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
def get_dp_world_size(with_context_parallel=False):
"""Return world size for the data parallel group."""
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_world_size(group=get_dp_group(with_context_parallel=with_context_parallel))
else:
return 0
def get_dp_rank(with_context_parallel=False):
"""Return my rank for the data parallel group."""
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_rank(group=get_dp_group(with_context_parallel=with_context_parallel))
else:
return 0
def get_cp_world_size():
"""Return world size for the context parallel group."""
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_world_size(group=get_cp_group())
else:
return 0
def get_cp_rank():
"""Return my rank for the context parallel group."""
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_rank(group=get_cp_group())
else:
return 0
def destroy_model_parallel():
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = None
global _TENSOR_MODEL_PARALLEL_GROUP
_TENSOR_MODEL_PARALLEL_GROUP = None
global _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP
_TENSOR_MODEL_PARALLEL_GROUP_WITH_CP = None
global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP = None
global _PIPELINE_MODEL_PARALLEL_GROUP
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP_GLOO
_DATA_PARALLEL_GROUP_GLOO = None
global _TENSOR_AND_DATA_PARALLEL_GROUP
_TENSOR_AND_DATA_PARALLEL_GROUP = None
global _PIPELINE_GLOBAL_RANKS
_PIPELINE_GLOBAL_RANKS = None
global _DATA_PARALLEL_GLOBAL_RANKS
_DATA_PARALLEL_GLOBAL_RANKS = None
global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = None
global _CONTEXT_PARALLEL_GROUP
_CONTEXT_PARALLEL_GROUP = None
global _CONTEXT_PARALLEL_GLOBAL_RANKS
_CONTEXT_PARALLEL_GLOBAL_RANKS = None
global _DATA_PARALLEL_GROUP_WITH_CP
_DATA_PARALLEL_GROUP_WITH_CP = None
global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = None
global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None
global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None
```
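The group layouts described in the docstrings of `generate_masked_orthogonal_rank_groups` and `initialize_model_parallel` can be reproduced directly with `RankGenerator`. The sketch below uses an arbitrary 8-GPU grid (tp=1, cp=2, pp=2, dp=2), chosen purely for illustration, and assumes the module is importable under this repository's package path.

```py
from inference.infra.distributed.parallel_state import RankGenerator

# 8 ranks laid out with the default tp-cp-pp-dp order (sizes chosen only for illustration).
gen = RankGenerator(tp=1, dp=2, pp=2, cp=2, order="tp-cp-pp-dp")

# With this order, cp has the innermost stride, then pp, then dp:
#   cp groups: [[0, 1], [2, 3], [4, 5], [6, 7]]
#   pp groups: [[0, 2], [1, 3], [4, 6], [5, 7]]
#   dp groups: [[0, 4], [1, 5], [2, 6], [3, 7]]
for token in ("cp", "pp", "dp", "dp-cp", "tp-pp"):
    print(token, gen.get_ranks(token))
```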
## /inference/infra/parallelism/__init__.py
```py path="/inference/infra/parallelism/__init__.py"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .context_parallel import CSOHelper, UlyssesScheduler, cp_post_process, cp_pre_process, cso_communication
from .pipeline_parallel import pp_scheduler
from .tile_parallel import TileProcessor
__all__ = [
"CSOHelper",
"cso_communication",
"UlyssesScheduler",
"pp_scheduler",
"TileProcessor",
"cp_pre_process",
"cp_post_process",
]
```
## /inference/infra/parallelism/context_parallel.py
```py path="/inference/infra/parallelism/context_parallel.py"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable, List, Tuple, Union
import torch
import torch.distributed
from einops import rearrange
from inference.common import ModelMetaArgs, PackedCoreAttnParams, PackedCrossAttnParams, divide
from inference.infra.distributed import parallel_state as mpu
#####################################################
# Common Primitives
#####################################################
def scatter_to_context_parallel_region(input_, cp_split_sizes, cp_shuffle_num=1, cp_pad_size=0):
"""Split the tensor along its first dimension and keep the
corresponding slice."""
world_size = mpu.get_cp_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Split along first dimension with padding.
rank = mpu.get_cp_rank()
if cp_shuffle_num > 1:
cp_pad_size = divide(cp_pad_size, cp_shuffle_num)
cp_split_sizes = [divide(s, cp_shuffle_num) for s in cp_split_sizes]
dim_offset = sum(cp_split_sizes[:rank])
xs = []
for x in torch.chunk(input_, cp_shuffle_num, dim=0):
x = torch.nn.functional.pad(x, [0, 0] * (x.dim() - 1) + [0, cp_pad_size], mode="constant", value=0)
xs.append(x[dim_offset : dim_offset + cp_split_sizes[rank]])
output = torch.concat(xs, dim=0)
else:
dim_offset = sum(cp_split_sizes[:rank])
x = torch.nn.functional.pad(input_, [0, 0] * (input_.dim() - 1) + [0, cp_pad_size], mode="constant", value=0)
output = x[dim_offset : dim_offset + cp_split_sizes[rank]].contiguous()
return output
def gather_from_context_parallel_region(input_, cp_split_sizes, cp_shuffle_num=1, cp_pad_size=0):
"""Gather tensors and concatinate along the first dimension."""
world_size = mpu.get_cp_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
input_ = input_.contiguous()
total_seq_len = sum(cp_split_sizes)
dim_size = list(input_.size())
dim_size[0] = total_seq_len
output = torch.empty(dim_size, dtype=input_.dtype, device=input_.device)
outputs = list(torch.split(output, cp_split_sizes, dim=0))
torch.distributed.all_gather(outputs, input_, group=mpu.get_cp_group())
if cp_shuffle_num > 1:
total_seq_len = divide(total_seq_len, cp_shuffle_num)
cp_pad_size = divide(cp_pad_size, cp_shuffle_num)
chunks = [torch.chunk(o, cp_shuffle_num, dim=0) for o in outputs]
output = torch.concat(
[
torch.concat([chunk[i] for chunk in chunks], dim=0)[: total_seq_len - cp_pad_size]
for i in range(cp_shuffle_num)
],
dim=0,
)
else:
output = torch.concat(outputs, dim=0)[: total_seq_len - cp_pad_size]
return output
class FakeHandle:
def __init__(self):
pass
def wait(self):
pass
#####################################################
# Context Parallel Process
#####################################################
def update_packed_seq_params_for_cuda_graph(cross_attn_params: PackedCrossAttnParams, xattn_mask: torch.Tensor):
assert xattn_mask is not None
# xattn_mask: (N * denoising_range_num, L, 1, 1)
xattn_mask = xattn_mask.reshape(xattn_mask.shape[0], -1)
batch_size, static_caption_length = xattn_mask.shape
# Get index_map for kv_range injection, map y_index to static_caption_length
y_index = torch.sum(xattn_mask, dim=-1)
cu_seqlens_k = torch.cat([y_index.new_tensor([0]), y_index]).to(torch.int32).to(xattn_mask.device)
cu_seqlens_k = cu_seqlens_k.cumsum(-1).to(torch.int32)
static_cu_seqlens_k = torch.arange(0, (batch_size + 1) * static_caption_length, static_caption_length)
assert cu_seqlens_k.shape[0] == batch_size + 1 == static_cu_seqlens_k.shape[0]
start_index_map = dict(zip(cu_seqlens_k.flatten().tolist(), static_cu_seqlens_k.flatten().tolist()))
# Move kv_range to the right position
kv_range_start_list = cross_attn_params.kv_ranges[:, 0].flatten().tolist()
static_kv_range_start = [start_index_map[kv_range_start_list[i]] for i in range(len(kv_range_start_list))]
static_kv_range_start = torch.tensor(static_kv_range_start, dtype=torch.int32, device=xattn_mask.device)
assert static_kv_range_start.shape[0] == cross_attn_params.kv_ranges.shape[0]
static_kv_range_diff = cross_attn_params.kv_ranges[:, 1] - cross_attn_params.kv_ranges[:, 0]
static_kv_range_end = static_kv_range_start + static_kv_range_diff
static_kv_range = torch.stack((static_kv_range_start, static_kv_range_end), dim=1)
assert static_kv_range.shape == cross_attn_params.kv_ranges.shape
return PackedCrossAttnParams(
q_ranges=cross_attn_params.q_ranges,
kv_ranges=static_kv_range,
cu_seqlens_q=cross_attn_params.cu_seqlens_q,
cu_seqlens_kv=cross_attn_params.cu_seqlens_kv,
max_seqlen_q=cross_attn_params.max_seqlen_q,
max_seqlen_kv=cross_attn_params.max_seqlen_kv,
)
def cp_update_cross_attn_qkv_range(
cross_attn_params: PackedCrossAttnParams,
batch_size: int,
cp_split_sizes: List[int],
device: torch.device,
cp_shuffle_num: int = 1,
cp_pad_size: int = 0,
):
"""
Update cross_attn_params for cross_attn in context parallel.
Input:
cross_attn_params: PackedCrossAttnParams. Packed sequence parameters for cross_atten
batch_size: int. Batch size
cp_split_sizes: List[int]. Split sizes for each rank
device: torch.device. Device
Output:
cross_attn_params: PackedCrossAttnParams. Updated packed parameters for cross_atten
"""
    # Update cu_seqlens_q and max_seqlen_q because the split of x may be unbalanced
cp_rank = mpu.get_cp_rank()
seq_len_cur_rank = cp_split_sizes[cp_rank]
cp_split_sizes = [divide(x, cp_shuffle_num) for x in cp_split_sizes]
cp_split_sizes = torch.tensor(cp_split_sizes, dtype=torch.int32, device=device)
base_cp_boundaries = torch.cat((torch.zeros(1, dtype=torch.int32, device=device), cp_split_sizes.cumsum(0)))
total_seq_len = base_cp_boundaries[-1]
cu_seqlens_q = cross_attn_params.cu_seqlens_q
cu_seqlens_k = cross_attn_params.cu_seqlens_kv
cu_seqlens_pad = torch.arange(cu_seqlens_q.shape[0], dtype=torch.int32, device=device) * divide(
cp_pad_size, cp_shuffle_num
)
cu_seqlens_q = cu_seqlens_q + cu_seqlens_pad
q_seg_starts, q_seg_ends = cu_seqlens_q[:-1], cu_seqlens_q[1:]
xattn_q_ranges, xattn_k_ranges = [], []
for i in range(batch_size):
inner_xattn_q_ranges, inner_xattn_k_ranges = [], []
for j in range(cp_shuffle_num):
global_offset = i * total_seq_len * cp_shuffle_num + j * total_seq_len
cp_boundaries = base_cp_boundaries + global_offset
this_cp_start, this_cp_end = (cp_boundaries[cp_rank], cp_boundaries[cp_rank + 1])
q_inter_starts = torch.maximum(this_cp_start, q_seg_starts)
q_inter_ends = torch.minimum(this_cp_end, q_seg_ends)
q_mask = q_inter_starts < q_inter_ends
valid_q_starts = q_inter_starts[q_mask]
valid_q_ends = q_inter_ends[q_mask]
k_seg_starts, k_seg_ends = cu_seqlens_k[:-1], cu_seqlens_k[1:]
valid_indices = torch.nonzero(q_mask, as_tuple=True)[0]
valid_k_starts = k_seg_starts[valid_indices]
valid_k_ends = k_seg_ends[valid_indices]
part_xattn_q_rangs = torch.stack((valid_q_starts, valid_q_ends), dim=1)
offset = part_xattn_q_rangs[:, 0].min()
part_xattn_q_rangs = part_xattn_q_rangs - offset
inner_xattn_q_ranges.append(part_xattn_q_rangs)
inner_xattn_k_ranges.append(torch.stack((valid_k_starts, valid_k_ends), dim=1))
inner_end_values = torch.tensor([ranges[-1, -1] for ranges in inner_xattn_q_ranges], dtype=torch.int32)
inner_offsets = torch.cat((torch.zeros(1, dtype=inner_end_values.dtype), torch.cumsum(inner_end_values[:-1], dim=0)))
inner_xattn_q_ranges = [tensor + int(offset) for tensor, offset in zip(inner_xattn_q_ranges, inner_offsets)]
xattn_q_ranges.append(torch.cat(inner_xattn_q_ranges, dim=0))
xattn_k_ranges.append(torch.cat(inner_xattn_k_ranges, dim=0))
end_values = torch.tensor([ranges[-1, -1].item() for ranges in xattn_q_ranges], dtype=torch.int32)
offsets = torch.cat((torch.zeros(1, dtype=end_values.dtype), torch.cumsum(end_values[:-1], dim=0)))
shifted_tensors = [tensor + int(offset) for tensor, offset in zip(xattn_q_ranges, offsets)]
xattn_q_ranges_ts = torch.cat(shifted_tensors, dim=0)
xattn_k_ranges_ts = torch.cat(xattn_k_ranges, dim=0)
cu_seqlens_q = torch.unique(xattn_q_ranges_ts)
cu_seqlens_k = torch.unique(xattn_k_ranges_ts)
assert (
cu_seqlens_q.shape == cu_seqlens_k.shape
), f"cu_seqlens_q.shape: {cu_seqlens_q.shape}, cu_seqlens_k.shape: {cu_seqlens_k.shape}, "
return PackedCrossAttnParams(
q_ranges=xattn_q_ranges_ts,
kv_ranges=xattn_k_ranges_ts,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_k,
max_seqlen_q=seq_len_cur_rank,
max_seqlen_kv=cross_attn_params.max_seqlen_kv,
)
def cp_ulysses_process(
cp_size: int,
x: torch.Tensor,
condition_map: torch.Tensor,
rope: torch.Tensor,
xattn_mask_for_cuda_graph: Union[torch.Tensor, None],
cross_attn_params: PackedCrossAttnParams,
):
seq_len, N, D = x.shape
assert seq_len == rope.size(0), f"seq_len: {seq_len} != rope.size(0): {rope.size(0)}"
assert condition_map.size(0) == seq_len, f"condition_map.size(0): {condition_map.size(0)} != seq_len: {seq_len}"
# Part1: split for CP
cp_split_sizes = [seq_len // cp_size] * cp_size
for i in range(seq_len % cp_size):
cp_split_sizes[i] += 1
# Part2: scatter to CP
x = scatter_to_context_parallel_region(x, cp_split_sizes)
condition_map = scatter_to_context_parallel_region(condition_map, cp_split_sizes)
rope = scatter_to_context_parallel_region(rope, cp_split_sizes)
# Part3: update cross_attn cross_attn_params
cross_attn_params = cp_update_cross_attn_qkv_range(cross_attn_params, N, cp_split_sizes, x.device)
if xattn_mask_for_cuda_graph is not None:
cross_attn_params = update_packed_seq_params_for_cuda_graph(cross_attn_params, xattn_mask_for_cuda_graph)
return x, condition_map, rope, cp_split_sizes, cross_attn_params
def cp_shuffle_overlap_process(
cp_size: int,
x: torch.Tensor,
condition_map: torch.Tensor,
rope: torch.Tensor,
xattn_mask_for_cuda_graph: Union[torch.Tensor, None],
ardf_meta: dict,
core_attn_params: PackedCoreAttnParams,
cross_attn_params: PackedCrossAttnParams,
):
seq_len, N, D = x.shape
assert seq_len == rope.size(0), f"seq_len: {seq_len} != rope.size(0): {rope.size(0)}"
assert condition_map.size(0) == seq_len, f"condition_map.size(0): {condition_map.size(0)} != seq_len: {seq_len}"
cp_shuffle_num = ardf_meta["denoising_range_num"]
# Part1: calculate cp_pad_size and cp_split_sizes
cp_pad_size = 0
if divide(seq_len, cp_shuffle_num) % cp_size != 0:
cp_pad_size = (cp_size - divide(seq_len, cp_shuffle_num) % cp_size) * cp_shuffle_num
cp_split_sizes = [(seq_len + cp_pad_size) // cp_size] * cp_size
# Part2: scatter to CP
x = scatter_to_context_parallel_region(x, cp_split_sizes, cp_shuffle_num, cp_pad_size)
condition_map = scatter_to_context_parallel_region(condition_map, cp_split_sizes, cp_shuffle_num, cp_pad_size)
rope = scatter_to_context_parallel_region(rope, cp_split_sizes, cp_shuffle_num, cp_pad_size)
# Part3: update core_attn_params
gcd = math.gcd(seq_len, seq_len + cp_pad_size)
_sq = seq_len // gcd
_psq = (seq_len + cp_pad_size) // gcd
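    # _psq / _sq is the ratio (seq_len + cp_pad_size) / seq_len reduced by their gcd,
    # so the q_range rescaling below stays in exact integer arithmetic.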
q_range = ardf_meta["q_range"] * _psq // _sq
max_seqlen_q = ardf_meta["max_seqlen_q"] * _psq // _sq
core_attn_params = PackedCoreAttnParams(
q_range=q_range,
k_range=ardf_meta["k_range"],
np_q_range=q_range.cpu().numpy(),
np_k_range=ardf_meta["k_range"].cpu().numpy(),
max_seqlen_q=max_seqlen_q,
max_seqlen_k=ardf_meta["max_seqlen_k"],
)
# Part4: update cross_attn cross_attn_params
cross_attn_params = cp_update_cross_attn_qkv_range(
cross_attn_params, N, cp_split_sizes, x.device, cp_shuffle_num, cp_pad_size
)
if xattn_mask_for_cuda_graph is not None:
cross_attn_params = update_packed_seq_params_for_cuda_graph(cross_attn_params, xattn_mask_for_cuda_graph)
return x, condition_map, rope, cp_pad_size, cp_split_sizes, core_attn_params, cross_attn_params
def cp_pre_process(
cp_size: int,
cp_strategy: str,
x: torch.Tensor,
condition_map: torch.Tensor,
rope: torch.Tensor,
xattn_mask_for_cuda_graph: Union[torch.Tensor, None],
ardf_meta: dict,
core_attn_params: PackedCoreAttnParams,
cross_attn_params: PackedCrossAttnParams,
):
"""
This function is used to handle context parallel behavior,
split input tensors into multiple parts and scatter them to different GPUs.
Input:
        cp_strategy: str. cp_ulysses for Hopper or newer GPUs, cp_shuffle_overlap for 4090 or older
x: (S, N, D). torch.Tensor of inputs embedding (images or latent representations of images)
condition_map: (N * S). torch.Tensor determine which condition to use for each token
rope: (S, 96). torch.Tensor of rope
xattn_mask_for_cuda_graph: (N * denoising_range_num, L, 1, 1). torch.Tensor of xattn mask for cuda graph, None means no cuda graph
core_attn_params: PackedCoreAttnParams. Packed sequence parameters for core_atten
cross_attn_params: PackedCrossAttnParams. Packed sequence parameters for cross_atten
Output:
x: (S', N, D). torch.Tensor of inputs embedding (images or latent representations of images)
condition_map: (N * S'). torch.Tensor determine which condition to use for each token
rope: (S', 96). torch.Tensor of rope
cp_split_sizes: List[int]. Split sizes for each rank
core_attn_params: PackedCoreAttnParams
cross_attn_params: PackedCrossAttnParams
"""
if cp_size == 1:
return x, condition_map, rope, None, None, core_attn_params, cross_attn_params
if cp_strategy == "cp_ulysses":
(x, condition_map, rope, cp_split_sizes, cross_attn_params) = cp_ulysses_process(
cp_size, x, condition_map, rope, xattn_mask_for_cuda_graph, cross_attn_params
)
return (x, condition_map, rope, 0, cp_split_sizes, core_attn_params, cross_attn_params)
elif cp_strategy == "cp_shuffle_overlap":
(
x,
condition_map,
rope,
cp_pad_size,
cp_split_sizes,
core_attn_params,
cross_attn_params,
) = cp_shuffle_overlap_process(
cp_size, x, condition_map, rope, xattn_mask_for_cuda_graph, ardf_meta, core_attn_params, cross_attn_params
)
return (x, condition_map, rope, cp_pad_size, cp_split_sizes, core_attn_params, cross_attn_params)
else:
raise ValueError(f"Invalid CP strategy: {cp_strategy}, expected cp_ulysses or cp_shuffle_overlap")
def cp_post_process(cp_size: int, cp_strategy: str, x: torch.Tensor, meta_args: ModelMetaArgs) -> torch.Tensor:
if cp_size == 1:
return x
if cp_strategy == "cp_shuffle_overlap":
x = gather_from_context_parallel_region(
x, meta_args.cp_split_sizes, meta_args.denoising_range_num, meta_args.cp_pad_size
)
elif cp_strategy == "cp_ulysses":
x = gather_from_context_parallel_region(x, meta_args.cp_split_sizes)
else:
raise ValueError(f"Invalid CP strategy: {cp_strategy}, expected cp_ulysses or cp_shuffle_overlap")
return x
#####################################################
# Ulysses Attention Pipeline
#####################################################
def all_to_all_input_split(tensor: torch.Tensor, cp_split_sizes: List[int]) -> Tuple[torch.Tensor, torch.distributed.Work]:
"""
Scatter head_number and gather seq_len, for example:
input: (seq_len, cp * hn, hd)
output: (seq_len * cp, hn, hd)
    NOTE: the seq_len of the input may differ across ranks; it is given by cp_split_sizes[mpu.get_cp_rank()]
"""
cp_world_size = mpu.get_cp_world_size()
if cp_world_size == 1:
return tensor, FakeHandle()
assert cp_split_sizes is not None
_, hn, _ = tensor.shape
if cp_world_size % hn == 0 and cp_world_size != hn:
tensor = torch.repeat_interleave(tensor, repeats=divide(cp_world_size, hn), dim=1).contiguous()
assert tensor.is_contiguous()
input = rearrange(tensor, "seq (cp hn) hd -> (cp seq) hn hd", cp=cp_world_size).contiguous()
output = torch.empty([sum(cp_split_sizes), *input.shape[1:]], device=input.device, dtype=input.dtype)
handle = torch.distributed.all_to_all_single(
output, input, output_split_sizes=cp_split_sizes, group=mpu.get_cp_group(), async_op=True
)
return output, handle
def all_to_all_output_split(tensor: torch.Tensor, cp_split_sizes: List[int]) -> Tuple[torch.Tensor, torch.distributed.Work]:
"""
Scatter seq_len and gather head_number, for example:
input: (seq_len * cp, hn, hd)
output: (seq_len, cp * hn, hd)
    NOTE: the seq_len of the output may differ across ranks; it is given by cp_split_sizes[mpu.get_cp_rank()]
"""
cp_world_size = mpu.get_cp_world_size()
if cp_world_size == 1:
return tensor, FakeHandle()
assert cp_split_sizes is not None
assert tensor.is_contiguous()
_, hn, _ = tensor.shape
output = torch.empty(
[cp_split_sizes[mpu.get_cp_rank()] * cp_world_size, *tensor.shape[1:]], device=tensor.device, dtype=tensor.dtype
)
handle = torch.distributed.all_to_all_single(
output, tensor, input_split_sizes=cp_split_sizes, group=mpu.get_cp_group(), async_op=True
)
return output, handle
def fused_qkv_communication(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cp_split_sizes: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
cp_world_size = mpu.get_cp_world_size()
if cp_world_size == 1:
return q, k, v
assert cp_split_sizes is not None
_, k_head, _ = k.shape
if cp_world_size % k_head == 0 and cp_world_size != k_head:
k = torch.repeat_interleave(k, repeats=divide(cp_world_size, k_head), dim=1)
v = torch.repeat_interleave(v, repeats=divide(cp_world_size, k_head), dim=1)
q = rearrange(q, "seq (cp hn) hd -> (cp seq) hn hd", cp=cp_world_size).contiguous()
k = rearrange(k, "seq (cp hn) hd -> (cp seq) hn hd", cp=cp_world_size).contiguous()
v = rearrange(v, "seq (cp hn) hd -> (cp seq) hn hd", cp=cp_world_size).contiguous()
head_split_number = [q.shape[1], k.shape[1], v.shape[1]]
qkv = torch.cat([q, k, v], dim=1).contiguous()
qkv_output = torch.empty([sum(cp_split_sizes), *qkv.shape[1:]], device=qkv.device, dtype=qkv.dtype)
torch.distributed.all_to_all_single(
qkv_output, qkv, output_split_sizes=cp_split_sizes, group=mpu.get_cp_group(), async_op=False
)
q, k, v = torch.split(qkv_output, head_split_number, dim=1)
return q, k, v
class UlyssesScheduler:
def __init__(self):
pass
@staticmethod
def get_attn_and_xattn_with_comm_overlap(
get_q_func: Callable, # [seq hn hd]
get_k_func: Callable, # [seq hn hd]
get_v_func: Callable, # [seq hn hd]
kv_cache_func: Callable,
core_attn_func: Callable,
cross_attn_func: Callable,
overlap_degree: int,
batch_size: int,
cp_size: int,
cp_split_sizes: List[int] = None,
):
"""
Get Q, K, V with communication overlap.
Input:
get_q: Callable, function to get q, shape [b, sq, hn, hd]
get_k: Callable, function to get k, shape [sq, b, hn, hd]
get_v: Callable, function to get v, shape [sq, b, hn, hd]
NOTE: Why follow such compute and comm order?
1. v_compute
2. k_compute(overlap with v_comm)
3. q_compute(overlap with k_comm)
4. kv_cache_func(overlap with q_comm)
        Guiding principle: start each communication as soon as possible to hide its latency.
        The compute FLOPs and communication volumes are ordered as:
        flops order: q_compute (larger hidden_size + layernorm) > k_compute (layernorm) > v_compute
        comm order: q_comm (larger hidden_size) > k_comm = v_comm
"""
value = get_v_func()
value, handle_v = all_to_all_input_split(value, cp_split_sizes)
key = get_k_func()
key, handle_k = all_to_all_input_split(key, cp_split_sizes)
query = get_q_func()
query, handle_q = all_to_all_input_split(query, cp_split_sizes)
handle_v.wait()
handle_k.wait()
kv = torch.concat([key, value], dim=-1)
key, value = kv_cache_func(kv)
handle_q.wait()
return UlyssesScheduler.get_attn_and_xattn_base(
query, key, value, core_attn_func, cross_attn_func, overlap_degree, batch_size, cp_size, cp_split_sizes
)
@staticmethod
def get_attn_and_xattn_with_fused_kv_comm(
get_q_func: Callable,
get_kv_func: Callable,
kv_cache_func: Callable,
core_attn_func: Callable,
cross_attn_func: Callable,
overlap_degree: int,
batch_size: int,
cp_size: int,
cp_split_sizes: List[int] = None,
):
"""
        When seq_len is very small, the workload becomes CPU bound. Fusing the k/v
        communication reduces CPU-side work and the number of kernel launches.
"""
kv = get_kv_func()
kv, handle_kv = all_to_all_input_split(kv, cp_split_sizes)
query = get_q_func()
query, handle_q = all_to_all_input_split(query, cp_split_sizes)
handle_kv.wait()
key, value = kv_cache_func(kv)
handle_q.wait()
return UlyssesScheduler.get_attn_and_xattn_base(
query, key, value, core_attn_func, cross_attn_func, overlap_degree, batch_size, cp_size, cp_split_sizes
)
    @staticmethod
    def get_attn_and_xattn_with_fused_qkv_comm(
get_qkv_func: Callable,
kv_cache_func: Callable,
core_attn_func: Callable,
cross_attn_func: Callable,
overlap_degree: int,
batch_size: int,
cp_size: int,
cp_split_sizes: List[int] = None,
):
"""
        Fusing the q, k and v communication into a single all-to-all further reduces CPU-bound overhead.
"""
q, k, v = get_qkv_func()
q, k, v = fused_qkv_communication(q, k, v, cp_split_sizes)
k, v = kv_cache_func(torch.cat([k, v], dim=-1))
return UlyssesScheduler.get_attn_and_xattn_base(
q, k, v, core_attn_func, cross_attn_func, overlap_degree, batch_size, cp_size, cp_split_sizes
)
@staticmethod
def get_attn_and_xattn_base(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
core_attn_func: Callable,
cross_attn_func: Callable,
overlap_degree: int,
batch_size: int,
cp_size: int,
cp_split_sizes: List[int] = None,
):
# Split Query, Key, Value into multiple parts
# k/v may have different sequence length with q due to kv cache
q_seq, q_head, q_hidden = query.shape
kv_seq, kv_head, kv_hidden = key.shape
if overlap_degree == -1:
overlap_degree = q_head // kv_head
else:
assert overlap_degree <= q_head
if overlap_degree == 1:
query = [query]
elif kv_head == 1: # MQA
query = query.chunk(overlap_degree, dim=1)
else: # GQA
assert q_head % (overlap_degree * kv_head) == 0
query = query.reshape(q_seq, kv_head, -1, q_hidden)
query = query.chunk(overlap_degree, dim=2)
query = [q.reshape(q_seq, -1, q_hidden) for q in query]
# Compute Core Attention
handle_attn = None
core_attn_out = None
core_attn_outs = []
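        # Pipelined loop: compute core attention for query chunk i while the all-to-all
        # for chunk i-1's output is still in flight; the cross attention afterwards
        # overlaps with the final chunk's all-to-all.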
for i in range(overlap_degree):
core_attn_out_new = core_attn_func(query[i], key, value)
if handle_attn is not None:
handle_attn.wait()
core_attn_outs.append(core_attn_out)
core_attn_out, handle_attn = all_to_all_output_split(core_attn_out_new, cp_split_sizes)
xattn_out = cross_attn_func()
handle_attn.wait()
core_attn_outs.append(core_attn_out)
core_attn_out = torch.cat(core_attn_outs, dim=1)
core_attn_out = rearrange(core_attn_out, "(cp sq b) hn hd -> (sq) b (cp hn hd)", cp=cp_size, b=batch_size)
return core_attn_out, xattn_out
#####################################################
# CSO(context shuffle overlap) Attention Pipeline
#####################################################
def cso_communication(
input: torch.Tensor, cp_world_size: int, cp_split_sizes: List[int], comm_type: str = None
) -> Tuple[torch.Tensor, torch.distributed.Work]:
if cp_world_size == 1:
return input, FakeHandle()
assert cp_split_sizes is not None
_, hn, _ = input.shape
if comm_type == "kv":
if cp_world_size % hn == 0 and cp_world_size != hn:
input = torch.repeat_interleave(input, repeats=divide(cp_world_size, hn), dim=1)
input = rearrange(input, "spb (cp hn) hd -> (cp spb) hn hd", cp=cp_world_size).contiguous()
output = torch.empty(input.shape, device=input.device, dtype=input.dtype)
handle = torch.distributed.all_to_all_single(
output, input, input_split_sizes=cp_split_sizes, group=mpu.get_cp_group(), async_op=True
)
return output, handle
class CSOHelper:
def __init__(self, cp_shuffle_num, cp_world_size, cp_split_sizes):
self.cp_shuffle_num = cp_shuffle_num
self.cp_world_size = cp_world_size
self.cp_split_sizes = [divide(x, self.cp_shuffle_num) for x in cp_split_sizes]
def split_query_for_overlap(self, query):
query = rearrange(
query, "(dn spb) (cp hn) hd -> (dn cp spb) hn hd", cp=self.cp_world_size, dn=self.cp_shuffle_num
).contiguous()
querys = list(torch.chunk(query, self.cp_shuffle_num, dim=0))
querys[0], handle_q = cso_communication(querys[0], self.cp_world_size, self.cp_split_sizes)
return querys, handle_q
def overlap(self, fattn, qs, k, v):
core_attn_outs = []
for i in range(self.cp_shuffle_num):
if self.cp_shuffle_num == 1:
q = qs[0]
elif i == 0:
q = qs[0]
loop_var, loop_handle = cso_communication(qs[i + 1], self.cp_world_size, self.cp_split_sizes)
else:
loop_handle.wait()
if loop_var.numel() == qs[0].numel():
q = loop_var
else:
assert loop_var.numel() == qs[0].numel() * 2
q, ready_o = torch.chunk(loop_var, 2, dim=-1)
core_attn_outs.append(ready_o)
loop_var = torch.concat([qs[i + 1], o], dim=-1) if i < self.cp_shuffle_num - 1 else o
loop_var, loop_handle = cso_communication(loop_var, self.cp_world_size, self.cp_split_sizes)
o = fattn(q, k, v, i)
if i == self.cp_shuffle_num - 1:
if i != 0:
loop_handle.wait()
assert loop_var.numel() == qs[0].numel()
core_attn_outs.append(loop_var)
last_o, handle_attn = cso_communication(o, self.cp_world_size, self.cp_split_sizes)
core_attn_outs.append(last_o)
return core_attn_outs, handle_attn
```
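As a quick sanity check on the split arithmetic used by `cp_ulysses_process` and `cp_shuffle_overlap_process`, the single-process sketch below re-derives `cp_split_sizes` (and the padding used by the shuffle-overlap strategy) for some made-up sequence lengths; no distributed setup is required, and the helper names are ours, not the repository's.

```py
from typing import List, Tuple

def ulysses_split_sizes(seq_len: int, cp_size: int) -> List[int]:
    # cp_ulysses: near-even split, remainder tokens go to the first ranks.
    sizes = [seq_len // cp_size] * cp_size
    for i in range(seq_len % cp_size):
        sizes[i] += 1
    return sizes

def shuffle_overlap_split(seq_len: int, cp_size: int, cp_shuffle_num: int) -> Tuple[int, List[int]]:
    # cp_shuffle_overlap: pad so that every denoising chunk divides evenly across ranks.
    cp_pad_size = 0
    if (seq_len // cp_shuffle_num) % cp_size != 0:
        cp_pad_size = (cp_size - (seq_len // cp_shuffle_num) % cp_size) * cp_shuffle_num
    return cp_pad_size, [(seq_len + cp_pad_size) // cp_size] * cp_size

print(ulysses_split_sizes(seq_len=10, cp_size=4))                      # [3, 3, 2, 2]
print(shuffle_overlap_split(seq_len=24, cp_size=4, cp_shuffle_num=3))  # (0, [6, 6, 6, 6])
print(shuffle_overlap_split(seq_len=30, cp_size=4, cp_shuffle_num=3))  # (6, [9, 9, 9, 9])
```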
## /inference/infra/parallelism/pipeline_parallel.py
```py path="/inference/infra/parallelism/pipeline_parallel.py"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
from dataclasses import dataclass
from typing import Optional
import torch
from inference.infra.distributed import parallel_state as mpu
@dataclass
class TensorAndHandler:
tensor: torch.Tensor
handler: torch.distributed.Work
class PPScheduler:
def __init__(self):
"""Initialize an instance of the PPScheduler class"""
self.device: torch.device = torch.device(f"cuda:{torch.cuda.current_device()}")
self.recv_queue: queue.Queue = queue.Queue()
def isend_next(self, tensor: torch.Tensor) -> torch.distributed.Work:
"""Asynchronously send a tensor to the next pipeline and return the send handle.
Args:
tensor (torch.Tensor): The tensor to be sent.
Returns:
torch.distributed.Work: The handle for the send operation.
"""
handle = torch.distributed.isend(
tensor.contiguous(), dst=mpu.get_pipeline_model_parallel_next_rank(), group=mpu.get_pp_group()
)
return handle
def irecv_prev(self, buffer: torch.Tensor) -> torch.distributed.Work:
"""Asynchronously receive a tensor from the previous pipeline and return the receive handle.
Args:
buffer (torch.Tensor): The buffer tensor for receiving data.
Returns:
torch.distributed.Work: The handle for the receive operation.
"""
handle = torch.distributed.irecv(buffer, src=mpu.get_pipeline_model_parallel_prev_rank(), group=mpu.get_pp_group())
return handle
def recv_prev_data(self, shape: torch.Size, dtype: torch.dtype) -> torch.Tensor:
"""Receive data from the previous pipeline and return the received tensor.
Args:
shape (torch.Size): The shape of the tensor to receive.
dtype (torch.dtype): The data type of the tensor to receive.
Returns:
torch.Tensor: The received tensor.
"""
recv_tensor = torch.empty(shape, dtype=dtype, device=self.device)
self.irecv_prev(recv_tensor).wait()
return recv_tensor
def queue_irecv_prev(self, shape: torch.Size, dtype: torch.dtype) -> None:
"""Put the asynchronously received tensor and handle into the receive queue.
Args:
shape (torch.Size): The shape of the tensor to receive.
dtype (torch.dtype): The data type of the tensor to receive.
"""
recv_tensor = torch.empty(shape, dtype=dtype, device=self.device)
handle = self.irecv_prev(recv_tensor)
self.recv_queue.put(TensorAndHandler(tensor=recv_tensor, handler=handle))
def queue_irecv_prev_data(self) -> torch.Tensor:
"""Get a tensor from the receive queue and wait for the receive operation to complete.
Returns:
torch.Tensor: The received tensor obtained from the queue.
"""
tensor_and_handler = self.recv_queue.get()
tensor_and_handler.handler.wait()
return tensor_and_handler.tensor
_PP_SCHEDULER: Optional[PPScheduler] = None
def init_pp_scheduler():
"""Initialize the PPScheduler instance.
Raises:
AssertionError: If the PPScheduler is already initialized.
"""
global _PP_SCHEDULER
assert _PP_SCHEDULER is None, "pipeline model parallel group is already initialized"
_PP_SCHEDULER = PPScheduler()
def pp_scheduler() -> PPScheduler:
"""Get the current PPScheduler instance.
Returns:
PPScheduler: The current PPScheduler instance.
Raises:
AssertionError: If the PPScheduler has not been initialized.
"""
assert _PP_SCHEDULER is not None, "pipeline model parallel group is not initialized"
return _PP_SCHEDULER
```
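A hedged usage sketch of `PPScheduler`: `pipeline_step` and `block_fn` are hypothetical names, the receive shape/dtype are placeholders, and the snippet assumes `init_pp_scheduler()` has already run inside an initialized pipeline-parallel group.
```py
# Hypothetical single pipeline step built on the scheduler above; block_fn and the
# tensor shape/dtype are placeholders, not part of this repository.
import torch

from inference.infra.distributed import parallel_state as mpu
from inference.infra.parallelism import pp_scheduler


def pipeline_step(block_fn, x, recv_shape, recv_dtype=torch.bfloat16):
    sched = pp_scheduler()
    if not mpu.is_pipeline_first_stage():
        # Post the receive early so it can overlap with local work, then block on it.
        sched.queue_irecv_prev(recv_shape, recv_dtype)
        x = sched.queue_irecv_prev_data()
    y = block_fn(x)
    if not mpu.is_pipeline_last_stage():
        # Ship the activation to the next stage; wait before reusing the buffer.
        sched.isend_next(y).wait()
    return y
```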
## /inference/infra/parallelism/tile_parallel.py
```py path="/inference/infra/parallelism/tile_parallel.py"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import List
import torch
from tqdm import tqdm
class ParallelHelper:
def __init__(self):
pass
@staticmethod
def split_tile_list(
tile_numel_dict: OrderedDict[int, int], parallel_group: torch.distributed.ProcessGroup = None
) -> List[int]:
"""
Splits the given tiles into per-rank lists of tile indices.
If the distributed environment is not initialized, it returns a list of
integers from 0 to len(tile_numel_dict) - 1, representing each tile index.
If the distributed environment is initialized, tiles are sorted by element count
(largest first) and dealt round-robin across the ranks, with any leftover tiles
going to the lowest ranks, so per-rank workloads stay roughly balanced.
Args:
tile_numel_dict (OrderedDict[int, int]): Dict of index and numel of tiles.
parallel_group (torch.distributed.ProcessGroup, optional):
Distributed decoding group. Defaults to None.
Returns:
List[int]: A list of tile indices assigned to the current rank.
List[int]: A list of global tile indices.
"""
if not torch.distributed.is_initialized():
return list(range(len(tile_numel_dict))), list(range(len(tile_numel_dict)))
else:
tile_idxs = list(OrderedDict(sorted(tile_numel_dict.items(), key=lambda x: x[1], reverse=True)).keys())
world_size = torch.distributed.get_world_size(group=parallel_group)
cur_rank = torch.distributed.get_rank(group=parallel_group)
global_tile_idxs = []
cur_rank_tile_idxs = []
for rank in range(world_size):
rank_tile_idxs = [tile_idxs[rank + world_size * i] for i in range(len(tile_idxs) // world_size)]
if rank < len(tile_idxs) % world_size:
rank_tile_idxs.append(tile_idxs[len(tile_idxs) // world_size * world_size + rank])
if rank == cur_rank:
cur_rank_tile_idxs = rank_tile_idxs
global_tile_idxs = global_tile_idxs + rank_tile_idxs
return cur_rank_tile_idxs, global_tile_idxs
@staticmethod
def gather_frames(
frames: List[torch.Tensor], global_tile_idxs: List[int], parallel_group: torch.distributed.ProcessGroup = None
) -> List[torch.Tensor]:
"""
Gathers frame data from all ranks in a distributed environment.
This method collects frames from all ranks and combines them into a single list.
If the distributed environment is not initialized, it simply returns the input frames.
Args:
frames (List[torch.Tensor]): A list of frames (tensors) from the current rank.
global_tile_idxs (List[int]): A list of global tile indices.
parallel_group (torch.distributed.ProcessGroup, optional):
Distributed decoding group. Defaults to None.
Returns:
List[torch.Tensor]: A list of frames (tensors) from all ranks.
"""
if not torch.distributed.is_initialized():
return frames
else:
# assert len(frames) > 0
# Communicate shapes
if len(frames) == 0:
cur_rank_shapes = []
else:
cur_rank_shapes = [frame.shape for frame in frames]
all_rank_shapes = [None] * torch.distributed.get_world_size(group=parallel_group)
torch.distributed.all_gather_object(all_rank_shapes, cur_rank_shapes, group=parallel_group)
all_rank_sizes = []
total_size = []
for per_rank_shapes in all_rank_shapes:
per_rank_sizes = []
per_rank_total_size = 0
for shape in per_rank_shapes:
per_rank_sizes.append(shape[0] * shape[1] * shape[2] * shape[3] * shape[4])
per_rank_total_size += shape[0] * shape[1] * shape[2] * shape[3] * shape[4]
all_rank_sizes.append(per_rank_sizes)
total_size.append(per_rank_total_size)
# Gather all frames
if len(frames) == 0:
flattened_frames = torch.zeros([0], dtype=torch.bfloat16, device="cuda")
else:
flattened_frames = torch.cat([frame.flatten().contiguous() for frame in frames], dim=0)
assert flattened_frames.dtype == torch.bfloat16
gather_tensors = [
torch.zeros(total_size[i], dtype=torch.bfloat16, device="cuda")
for i in range(torch.distributed.get_world_size(group=parallel_group))
]
torch.distributed.all_gather(gather_tensors, flattened_frames, group=parallel_group)
result_frames = []
for idx, per_rank_shapes in enumerate(all_rank_shapes):
offset = 0
for j, shape in enumerate(per_rank_shapes):
result_frames.append(gather_tensors[idx][offset : offset + all_rank_sizes[idx][j]].view(shape))
offset += all_rank_sizes[idx][j]
result_frames_dict = OrderedDict((idx, frame) for idx, frame in zip(global_tile_idxs, result_frames))
result_frames = list(OrderedDict(sorted(result_frames_dict.items())).values())
return result_frames
@staticmethod
def index_undot(index: int, loop_size: List[int]) -> List[int]:
"""
Converts a single index into a list of indices, representing the position in a multi-dimensional space.
This method takes an integer index and a list of loop sizes, and converts the index into a list of indices
that correspond to the position in a multi-dimensional space.
Args:
index (int): The single index to be converted.
loop_size (List[int]): A list of integers representing the size of each dimension in the multi-dimensional space.
Returns:
List[int]: A list of integers representing the position in the multi-dimensional space.
"""
undotted_index = []
for i in range(len(loop_size) - 1, -1, -1):
undotted_index.append(index % loop_size[i])
index = index // loop_size[i]
undotted_index.reverse()
assert len(undotted_index) == len(loop_size)
return undotted_index
@staticmethod
def index_dot(index: List[int], loop_size: List[int]) -> int:
"""
Converts a list of indices into a single index, representing the position in a multi-dimensional space.
This method takes a list of indices and a list of loop sizes, and converts the list of indices into a single index
that corresponds to the position in a multi-dimensional space.
Args:
index (List[int]): A list of integers representing the position in the multi-dimensional space.
loop_size (List[int]): A list of integers representing the size of each dimension in the multi-dimensional space.
Returns:
int: A single integer representing the position in the multi-dimensional space.
"""
assert len(index) == len(loop_size)
dot_index = 0
strides = [1]
for i in range(len(loop_size) - 1, -1, -1):
strides.append(strides[-1] * loop_size[i])
strides.reverse()
strides = strides[1:]
assert len(index) == len(strides)
for i in range(len(index)):
dot_index += index[i] * strides[i]
return dot_index
class TileProcessor:
def __init__(
self,
encode_fn,
decode_fn,
tile_sample_min_height: int = 256,
tile_sample_min_width: int = 256,
tile_sample_min_length: int = 16,
spatial_downsample_factor: int = 8,
temporal_downsample_factor: int = 1,
spatial_tile_overlap_factor: float = 0.25,
temporal_tile_overlap_factor: float = 0,
sr_ratio=1,
first_frame_as_image: bool = False,
parallel_group: torch.distributed.ProcessGroup = None,
):
"""
Initializes an instance of the class.
Args:
encode_fn (function): The encoding function used for tile sampling.
decode_fn (function): The decoding function used for tile reconstruction.
tile_sample_min_height (int, optional): The minimum height of the sampled tiles. Defaults to 256.
tile_sample_min_width (int, optional): The minimum width of the sampled tiles. Defaults to 256.
tile_sample_min_length (int, optional): The minimum length of the sampled tiles. Defaults to 16.
spatial_downsample_factor (int, optional): The actual spatial downsample factor of the given encode_fn. Defaults to 8.
temporal_downsample_factor (int, optional): The actual temporal downsample factor of the latent-space tiles. Defaults to 1.
spatial_tile_overlap_factor (float, optional): The spatial overlap factor between adjacent tiles. Defaults to 0.25.
temporal_tile_overlap_factor (float, optional): The temporal overlap factor between adjacent tiles. Defaults to 0.
parallel_group (torch.distributed.ProcessGroup, optional): Distributed decoding group. Defaults to None.
"""
self.encode_fn = encode_fn
self.decode_fn = decode_fn
self.spatial_downsample_factor = spatial_downsample_factor
self.temporal_downsample_factor = temporal_downsample_factor
self.tile_sample_min_height = tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width
self.tile_sample_min_length = tile_sample_min_length
self.tile_latent_min_height = tile_sample_min_height // spatial_downsample_factor
self.tile_latent_min_width = tile_sample_min_width // spatial_downsample_factor
self.tile_latent_min_length = tile_sample_min_length // temporal_downsample_factor
if first_frame_as_image:
self.tile_latent_min_length += 1
self.spatial_tile_overlap_factor = spatial_tile_overlap_factor
self.temporal_tile_overlap_factor = temporal_tile_overlap_factor
self.sr_ratio = sr_ratio
self.parallel_group = parallel_group
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for t in range(blend_extent):
b[:, :, t, :, :] = a[:, :, -blend_extent + t, :, :] * (1 - t / blend_extent) + b[:, :, t, :, :] * (
t / blend_extent
)
return b
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
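# With blend_extent = 4 these three loops cross-fade linearly along the blended axis:
# tile `a` contributes weights 1.00, 0.75, 0.50, 0.25 and tile `b` contributes
# 0.00, 0.25, 0.50, 0.75 over the four overlapping slices.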
def tiled_encode(self, x: torch.FloatTensor, verbose: bool = False):
overlap_height = int(self.tile_sample_min_height * (1 - self.spatial_tile_overlap_factor))
overlap_width = int(self.tile_sample_min_width * (1 - self.spatial_tile_overlap_factor))
overlap_length = int(self.tile_sample_min_length * (1 - self.temporal_tile_overlap_factor))
blend_extent_h = int(self.tile_latent_min_height * self.spatial_tile_overlap_factor)
blend_extent_w = int(self.tile_latent_min_width * self.spatial_tile_overlap_factor)
blend_extent_t = int(self.tile_latent_min_length * self.temporal_tile_overlap_factor)
height_limit = self.tile_latent_min_height - blend_extent_h
width_limit = self.tile_latent_min_width - blend_extent_w
frame_limit = self.tile_latent_min_length - blend_extent_t
length_tile_size = (x.shape[2] + overlap_length - 1) // overlap_length
height_tile_size = (x.shape[3] + overlap_height - 1) // overlap_height
width_tile_size = (x.shape[4] + overlap_width - 1) // overlap_width
total_tile_size = length_tile_size * height_tile_size * width_tile_size
for_loop_size = [length_tile_size, height_tile_size, width_tile_size]
tiles = []
tile_numel_dict = OrderedDict()
for tile_index in range(total_tile_size):
undot_tile_index = ParallelHelper.index_undot(tile_index, for_loop_size)
f_idx, i_idx, j_idx = undot_tile_index
f = f_idx * overlap_length
i = i_idx * overlap_height
j = j_idx * overlap_width
# Extract the tile from the input sample; it will be encoded below
tile = x[
:,
:,
f : f + self.tile_sample_min_length,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tiles.append(tile)
tile_numel_dict[tile_index] = tile.numel()
tile_index_list, global_tile_index_list = ParallelHelper.split_tile_list(
tile_numel_dict, parallel_group=self.parallel_group
)
progress_bar = tqdm(
total=len(tile_index_list),
desc=f"[Rank {torch.distributed.get_rank(group=self.parallel_group)}] Encoding Tiles",
disable=not verbose,
)
frames = []
# Encode each tile based on the tile index list
for tile_index in tile_index_list:
tile = tiles[tile_index]
encoded = self.encode_fn(tile)
frames.append(encoded)
progress_bar.update(1)
# Gather all encoded frames from different ranks
frames = ParallelHelper.gather_frames(frames, global_tile_index_list, parallel_group=self.parallel_group)
assert len(frames) == total_tile_size
progress_bar.close()
result_frames = []
# Blend the encoded tiles to create the final output
for tile_index in range(total_tile_size):
undot_tile_index = ParallelHelper.index_undot(tile_index, for_loop_size)
f, i, j = undot_tile_index
tile = frames[tile_index]
# Blend with previous tiles if applicable
if f > 0:
idx = ParallelHelper.index_dot([f - 1, i, j], for_loop_size)
tile = self.blend_t(frames[idx], tile, blend_extent_t)
if i > 0:
idx = ParallelHelper.index_dot([f, i - 1, j], for_loop_size)
tile = self.blend_v(frames[idx], tile, blend_extent_h)
if j > 0:
idx = ParallelHelper.index_dot([f, i, j - 1], for_loop_size)
tile = self.blend_h(frames[idx], tile, blend_extent_w)
result_frames.append(tile[:, :, :frame_limit, :height_limit, :width_limit])
assert len(result_frames) == total_tile_size
concat_frames = []
for f in range(length_tile_size):
result_rows = []
for i in range(height_tile_size):
result_row = []
for j in range(width_tile_size):
idx = ParallelHelper.index_dot([f, i, j], for_loop_size)
result_row.append(result_frames[idx])
result_rows.append(torch.cat(result_row, dim=4))
concat_frames.append(torch.cat(result_rows, dim=3))
# Concatenate all result frames along the temporal dimension
result = torch.cat(concat_frames, dim=2)
return result
def tiled_decode(self, z: torch.FloatTensor, verbose: bool = False):
overlap_height = int(self.tile_latent_min_height * (1 - self.spatial_tile_overlap_factor))
overlap_width = int(self.tile_latent_min_width * (1 - self.spatial_tile_overlap_factor))
overlap_length = int(self.tile_latent_min_length * (1 - self.temporal_tile_overlap_factor))
real_tile_sample_min_height = int(self.tile_latent_min_height * self.spatial_downsample_factor * self.sr_ratio)
real_tile_sample_min_width = int(self.tile_latent_min_width * self.spatial_downsample_factor * self.sr_ratio)
real_tile_sample_min_length = int(self.tile_latent_min_length * self.temporal_downsample_factor)
blend_extent_h = int(real_tile_sample_min_height * self.spatial_tile_overlap_factor)
blend_extent_w = int(real_tile_sample_min_width * self.spatial_tile_overlap_factor)
blend_extent_t = int(real_tile_sample_min_length * self.temporal_tile_overlap_factor)
height_limit = real_tile_sample_min_height - blend_extent_h
width_limit = real_tile_sample_min_width - blend_extent_w
frame_limit = real_tile_sample_min_length - blend_extent_t
length_tile_size = (z.shape[2] + overlap_length - 1) // overlap_length
height_tile_size = (z.shape[3] + overlap_height - 1) // overlap_height
width_tile_size = (z.shape[4] + overlap_width - 1) // overlap_width
total_tile_size = length_tile_size * height_tile_size * width_tile_size
for_loop_size = [length_tile_size, height_tile_size, width_tile_size]
tiles = []
tile_numel_dict = OrderedDict()
for tile_index in range(total_tile_size):
undot_tile_index = ParallelHelper.index_undot(tile_index, for_loop_size)
f_idx, i_idx, j_idx = undot_tile_index
f = f_idx * overlap_length
i = i_idx * overlap_height
j = j_idx * overlap_width
# Extract the tile from the latent representation and decode it
tile = z[
:,
:,
f : f + self.tile_latent_min_length,
i : i + self.tile_latent_min_height,
j : j + self.tile_latent_min_width,
]
tiles.append(tile)
tile_numel_dict[tile_index] = tile.numel()
tile_index_list, global_tile_index_list = ParallelHelper.split_tile_list(
tile_numel_dict, parallel_group=self.parallel_group
)
progress_bar = tqdm(
total=len(tile_index_list),
desc=f"[Rank {torch.distributed.get_rank(group=self.parallel_group)}] Decoding Tiles",
disable=not verbose,
)
frames = []
# Decode each tile based on the tile index list
for tile_index in tile_index_list:
tile = tiles[tile_index]
decoded = self.decode_fn(tile)
frames.append(decoded)
progress_bar.update(1)
progress_bar.close()
# Gather all decoded frames from different ranks
frames = ParallelHelper.gather_frames(frames, global_tile_index_list, parallel_group=self.parallel_group)
assert len(frames) == total_tile_size
result_frames = []
# Blend the decoded tiles to create the final output
for tile_index in tile_index_list:
undot_tile_index = ParallelHelper.index_undot(tile_index, for_loop_size)
f, i, j = undot_tile_index
tile = frames[tile_index].clone()
# Blend with previous tiles if applicable
if f > 0:
idx = ParallelHelper.index_dot([f - 1, i, j], for_loop_size)
tile = torch.compile(self.blend_t, dynamic=False)(frames[idx], tile, blend_extent_t)
if i > 0:
idx = ParallelHelper.index_dot([f, i - 1, j], for_loop_size)
tile = torch.compile(self.blend_v, dynamic=False)(frames[idx], tile, blend_extent_h)
if j > 0:
idx = ParallelHelper.index_dot([f, i, j - 1], for_loop_size)
tile = torch.compile(self.blend_h, dynamic=False)(frames[idx], tile, blend_extent_w)
result_frames.append(tile[:, :, :frame_limit, :height_limit, :width_limit])
# Gather and concatenate the final result frames
result_frames = ParallelHelper.gather_frames(result_frames, global_tile_index_list, parallel_group=self.parallel_group)
assert len(result_frames) == total_tile_size
concat_frames = []
for f in range(length_tile_size):
result_rows = []
for i in range(height_tile_size):
result_row = []
for j in range(width_tile_size):
idx = ParallelHelper.index_dot([f, i, j], for_loop_size)
result_row.append(result_frames[idx])
result_rows.append(torch.cat(result_row, dim=4))
concat_frames.append(torch.cat(result_rows, dim=3))
# Concatenate all result frames along the temporal dimension
result = torch.cat(concat_frames, dim=2)
return result
```
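Not from the repository: a small self-contained check of the pure index helpers and of `split_tile_list`'s non-distributed fallback; the import path follows this repo's layout and the sizes are arbitrary.
```py
from collections import OrderedDict

from inference.infra.parallelism.tile_parallel import ParallelHelper

# Round-trip between a flat tile index and its (temporal, height, width) coordinates.
loop_size = [2, 3, 4]                                   # 2 x 3 x 4 tile grid
flat = ParallelHelper.index_dot([1, 2, 3], loop_size)   # 1*12 + 2*4 + 3
assert flat == 23
assert ParallelHelper.index_undot(flat, loop_size) == [1, 2, 3]

# Without torch.distributed initialized, every tile stays on the local process.
tile_numels = OrderedDict((i, 100 - i) for i in range(6))
local_idxs, global_idxs = ParallelHelper.split_tile_list(tile_numels)
assert local_idxs == global_idxs == list(range(6))
```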
## /inference/model/dit/__init__.py
```py path="/inference/model/dit/__init__.py"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dit_model import get_dit
__all__ = ["get_dit"]
```
## /inference/model/dit/dit_model.py
```py path="/inference/model/dit/dit_model.py"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import math
import os
from typing import Tuple
import torch
import torch.distributed
import torch.nn as nn
from einops import rearrange
from inference.common import (
InferenceParams,
MagiConfig,
ModelMetaArgs,
PackedCoreAttnParams,
PackedCrossAttnParams,
env_is_true,
print_per_rank,
print_rank_0,
)
from inference.infra.checkpoint import load_checkpoint
from inference.infra.distributed import parallel_state as mpu
from inference.infra.parallelism import cp_post_process, cp_pre_process, pp_scheduler
from .dit_module import CaptionEmbedder, FinalLinear, LearnableRotaryEmbeddingCat, TimestepEmbedder, TransformerBlock
class VideoDiTModel(torch.nn.Module):
"""VideoDiT model for video diffusion.
Args:
config (MagiConfig): Transformer config
pre_process (bool, optional): Include embedding layer (used with pipeline parallelism). Defaults to True.
post_process (bool, optional): Include an output layer (used with pipeline parallelism). Defaults to True.
"""
def __init__(self, config: MagiConfig, pre_process: bool = True, post_process: bool = True) -> None:
super().__init__()
self.model_config = config.model_config
self.runtime_config = config.runtime_config
self.engine_config = config.engine_config
self.pre_process = pre_process
self.post_process = post_process
self.in_channels = self.model_config.in_channels
self.out_channels = self.model_config.out_channels
self.patch_size = self.model_config.patch_size
self.t_patch_size = self.model_config.t_patch_size
self.caption_max_length = self.model_config.caption_max_length
self.num_heads = self.model_config.num_attention_heads
self.x_embedder = nn.Conv3d(
self.model_config.in_channels,
self.model_config.hidden_size,
kernel_size=(self.model_config.t_patch_size, self.model_config.patch_size, self.model_config.patch_size),
stride=(self.model_config.t_patch_size, self.model_config.patch_size, self.model_config.patch_size),
bias=False,
)
self.t_embedder = TimestepEmbedder(model_config=self.model_config)
self.y_embedder = CaptionEmbedder(model_config=self.model_config)
self.rope = LearnableRotaryEmbeddingCat(
self.model_config.hidden_size // self.model_config.num_attention_heads, in_pixels=False
)
# trm block
self.videodit_blocks = TransformerBlock(
model_config=self.model_config,
engine_config=self.engine_config,
pre_process=pre_process,
post_process=post_process,
)
self.final_linear = FinalLinear(
self.model_config.hidden_size, self.model_config.patch_size, self.model_config.t_patch_size, self.out_channels
)
def generate_kv_range_for_uncondition(self, uncond_x) -> torch.Tensor:
device = f"cuda:{torch.cuda.current_device()}"
B, C, T, H, W = uncond_x.shape
chunk_token_nums = (
(T // self.model_config.t_patch_size) * (H // self.model_config.patch_size) * (W // self.model_config.patch_size)
)
k_chunk_start = torch.linspace(0, (B - 1) * chunk_token_nums, steps=B).reshape((B, 1))
k_chunk_end = torch.linspace(chunk_token_nums, B * chunk_token_nums, steps=B).reshape((B, 1))
return torch.concat([k_chunk_start, k_chunk_end], dim=1).to(torch.int32).to(device)
def unpatchify(self, x, H, W):
return rearrange(
x,
"(T H W) N (pT pH pW C) -> N C (T pT) (H pH) (W pW)",
H=H,
W=W,
pT=self.t_patch_size,
pH=self.patch_size,
pW=self.patch_size,
).contiguous()
@torch.no_grad()
def get_embedding_and_meta(self, x, t, y, caption_dropout_mask, xattn_mask, kv_range, **kwargs):
"""
Forward embedding and meta for VideoDiT.
NOTE: This function should only handle single card behavior.
Input:
x: (N, C, T, H, W). torch.Tensor of spatial inputs (images or latent representations of images)
t: (N, denoising_range_num). torch.Tensor of diffusion timesteps
y: (N * denoising_range_num, 1, L, C). torch.Tensor of caption embeddings
caption_dropout_mask: (N). torch.Tensor of whether to drop caption
xattn_mask: (N * denoising_range_num, 1, L). torch.Tensor of xattn mask
kv_range: (N * denoising_range_num, 2). torch.Tensor of kv range
Output:
x: (S, N, D). torch.Tensor of inputs embedding (images or latent representations of images)
condition: (N, denoising_range_num, D). torch.Tensor of condition embedding
condition_map: (S, N). torch.Tensor that determines which condition to use for each token
rope: (S, 96). torch.Tensor of rope
y_xattn_flat: (total_token, D). torch.Tensor of y_xattn_flat
cuda_graph_inputs: (y_xattn_flat, xattn_mask) or None. None means no cuda graph
NOTE: y_xattn_flat and xattn_mask with static shape
H: int. Height of the input
W: int. Width of the input
ardf_meta: dict. Meta information for ardf
cross_attn_params: PackedCrossAttnParams. Packed sequence parameters for cross_atten
"""
###################################
# Part1: Embed x #
###################################
x = self.x_embedder(x) # [N, C, T, H, W]
batch_size, _, T, H, W = x.shape
# Prepare necessary variables
range_num = kwargs["range_num"]
denoising_range_num = kwargs["denoising_range_num"]
slice_point = kwargs.get("slice_point", 0)
frame_in_range = T // denoising_range_num
prev_clean_T = frame_in_range * slice_point
T_total = T + prev_clean_T
###################################
# Part2: rope #
###################################
# calculate rescale_factor for multi-resolution & multi-aspect-ratio training
# the base_size [16*16] is a predefined size based on data: (256x256), vae: (8,8,4), patch size: (1,1,2)
# This definition does not have any relationship with the actual input/model/setting.
# ref_feat_shape is used to calculate the inner rescale factor, so it can be float.
rescale_factor = math.sqrt((H * W) / (16 * 16))
rope = self.rope.get_embed(shape=[T_total, H, W], ref_feat_shape=[T_total, H / rescale_factor, W / rescale_factor])
# the shape of rope is (T*H*W, -1), i.e. (seq_length, head_dim); since T is the leading dimension, we can directly keep the last T*H*W rows.
rope = rope[-(T * H * W) :]
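# worked example: a 32x32 latent grid gives rescale_factor = sqrt(32*32 / (16*16)) = 2,
# so ref_feat_shape becomes [T_total, 16, 16] and the spatial frequencies are mapped
# back onto the predefined 16x16 base grid regardless of the actual resolution.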
###################################
# Part3: Embed t #
###################################
assert t.shape[0] == batch_size, f"Invalid t shape, got {t.shape[0]} != {batch_size}" # nolint
assert t.shape[1] == denoising_range_num, f"Invalid t shape, got {t.shape[1]} != {denoising_range_num}" # nolint
t_flat = t.flatten() # (N * denoising_range_num,)
t = self.t_embedder(t_flat) # (N, D)
if self.engine_config.distill:
distill_dt_scalar = 2
if kwargs["num_steps"] == 12:
base_chunk_step = 4
distill_dt_factor = base_chunk_step / kwargs["distill_interval"] * distill_dt_scalar
else:
distill_dt_factor = kwargs["num_steps"] / 4 * distill_dt_scalar
distill_dt = torch.ones_like(t_flat) * distill_dt_factor
distill_dt_embed = self.t_embedder(distill_dt)
t = t + distill_dt_embed
t = t.reshape(batch_size, denoising_range_num, -1) # (N, range_num, D)
######################################################
# Part4: Embed y, prepare condition and y_xattn_flat #
######################################################
# (N * denoising_range_num, 1, L, D)
y_xattn, y_adaln = self.y_embedder(y, self.training, caption_dropout_mask)
assert xattn_mask is not None
xattn_mask = xattn_mask.squeeze(1).squeeze(1)
# condition: (N, range_num, D)
y_adaln = y_adaln.squeeze(1) # (N, D)
condition = t + y_adaln.unsqueeze(1)
assert condition.shape[0] == batch_size
assert condition.shape[1] == denoising_range_num
seqlen_per_chunk = (T * H * W) // denoising_range_num
condition_map = torch.arange(batch_size * denoising_range_num, device=x.device)
condition_map = torch.repeat_interleave(condition_map, seqlen_per_chunk)
condition_map = condition_map.reshape(batch_size, -1).transpose(0, 1).contiguous()
# y_xattn_flat: (total_token, D)
y_xattn_flat = torch.masked_select(y_xattn.squeeze(1), xattn_mask.unsqueeze(-1).bool()).reshape(-1, y_xattn.shape[-1])
xattn_mask_for_cuda_graph = None
######################################################
# Part5: Prepare cross_attn_params for cross_atten #
######################################################
# (N * denoising_range_num, L)
xattn_mask = xattn_mask.reshape(xattn_mask.shape[0], -1)
y_index = torch.sum(xattn_mask, dim=-1)
clip_token_nums = H * W * frame_in_range
cu_seqlens_q = torch.Tensor([0] + ([clip_token_nums] * denoising_range_num * batch_size)).to(torch.int64).to(x.device)
cu_seqlens_k = torch.cat([y_index.new_tensor([0]), y_index]).to(torch.int64).to(x.device)
cu_seqlens_q = cu_seqlens_q.cumsum(-1).to(torch.int32)
cu_seqlens_k = cu_seqlens_k.cumsum(-1).to(torch.int32)
assert (
cu_seqlens_q.shape == cu_seqlens_k.shape
), f"cu_seqlens_q.shape: {cu_seqlens_q.shape}, cu_seqlens_k.shape: {cu_seqlens_k.shape}"
xattn_q_ranges = torch.cat([cu_seqlens_q[:-1].unsqueeze(1), cu_seqlens_q[1:].unsqueeze(1)], dim=1)
xattn_k_ranges = torch.cat([cu_seqlens_k[:-1].unsqueeze(1), cu_seqlens_k[1:].unsqueeze(1)], dim=1)
assert (
xattn_q_ranges.shape == xattn_k_ranges.shape
), f"xattn_q_ranges.shape: {xattn_q_ranges.shape}, xattn_k_ranges.shape: {xattn_k_ranges.shape}"
cross_attn_params = PackedCrossAttnParams(
q_ranges=xattn_q_ranges,
kv_ranges=xattn_k_ranges,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_k,
max_seqlen_q=clip_token_nums,
max_seqlen_kv=self.caption_max_length,
)
##################################################
# Part6: Prepare core_atten related q/kv range #
##################################################
q_range = torch.cat([cu_seqlens_q[:-1].unsqueeze(1), cu_seqlens_q[1:].unsqueeze(1)], dim=1)
flat_kv = torch.unique(kv_range, sorted=True)
max_seqlen_k = (flat_kv[-1] - flat_kv[0]).cpu().item()
ardf_meta = dict(
clip_token_nums=clip_token_nums,
slice_point=slice_point,
range_num=range_num,
denoising_range_num=denoising_range_num,
q_range=q_range,
k_range=kv_range,
max_seqlen_q=clip_token_nums,
max_seqlen_k=max_seqlen_k,
)
return (x, condition, condition_map, rope, y_xattn_flat, xattn_mask_for_cuda_graph, H, W, ardf_meta, cross_attn_params)
@torch.no_grad()
def forward_pre_process(
self, x, t, y, caption_dropout_mask=None, xattn_mask=None, kv_range=None, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, ModelMetaArgs]:
assert kv_range is not None, "Please ensure kv_range is provided"
x = x * self.model_config.x_rescale_factor
if self.model_config.half_channel_vae:
assert x.shape[1] == 16
x = torch.cat([x, x], dim=1)
x = x.float()
t = t.float()
y = y.float()
# the autocast context will ensure that the processing is in high precision even if the embedder params are in bfloat16 mode
with torch.autocast(device_type="cuda", dtype=torch.float32):
(
x,
condition,
condition_map,
rope,
y_xattn_flat,
xattn_mask_for_cuda_graph,
H,
W,
ardf_meta,
cross_attn_params,
) = self.get_embedding_and_meta(x, t, y, caption_dropout_mask, xattn_mask, kv_range, **kwargs)
# Downcast x and rearrange x
x = x.to(self.model_config.params_dtype)
x = rearrange(x, "N C T H W -> (T H W) N C").contiguous() # (thw, N, D)
# condition and y_xattn_flat will be downcast to bfloat16 in transformer block.
condition = condition.to(self.model_config.params_dtype)
y_xattn_flat = y_xattn_flat.to(self.model_config.params_dtype)
core_attn_params = PackedCoreAttnParams(
q_range=ardf_meta["q_range"],
k_range=ardf_meta["k_range"],
np_q_range=ardf_meta["q_range"].cpu().numpy(),
np_k_range=ardf_meta["k_range"].cpu().numpy(),
max_seqlen_q=ardf_meta["max_seqlen_q"],
max_seqlen_k=ardf_meta["max_seqlen_k"],
)
(x, condition_map, rope, cp_pad_size, cp_split_sizes, core_attn_params, cross_attn_params) = cp_pre_process(
self.engine_config.cp_size,
self.engine_config.cp_strategy,
x,
condition_map,
rope,
xattn_mask_for_cuda_graph,
ardf_meta,
core_attn_params,
cross_attn_params,
)
meta_args = ModelMetaArgs(
H=H,
W=W,
cp_pad_size=cp_pad_size,
cp_split_sizes=cp_split_sizes,
slice_point=ardf_meta["slice_point"],
denoising_range_num=ardf_meta["denoising_range_num"],
range_num=ardf_meta["range_num"],
extract_prefix_video_feature=kwargs.get("extract_prefix_video_feature", False),
fwd_extra_1st_chunk=kwargs["fwd_extra_1st_chunk"],
distill_nearly_clean_chunk=kwargs.get("distill_nearly_clean_chunk", False),
clip_token_nums=ardf_meta["clip_token_nums"],
enable_cuda_graph=xattn_mask_for_cuda_graph is not None,
core_attn_params=core_attn_params,
cross_attn_params=cross_attn_params,
)
return (x, condition, condition_map, y_xattn_flat, rope, meta_args)
@torch.no_grad()
def forward_post_process(self, x, meta_args: ModelMetaArgs) -> torch.Tensor:
x = x.float()
# the autocast context will ensure that the processing is in high precision even if the final_linear params are in bfloat16 mode
with torch.autocast(device_type="cuda", dtype=torch.float32):
x = self.final_linear(x) # (thw/cp, N, patch_size ** 2 * out_channels)
# leave context parallel region
x = cp_post_process(self.engine_config.cp_size, self.engine_config.cp_strategy, x, meta_args)
# N C T H W
x = self.unpatchify(x, meta_args.H, meta_args.W)
if self.model_config.half_channel_vae:
assert x.shape[1] == 32
x = x[:, :16]
x = x / self.model_config.x_rescale_factor
return x
@torch.no_grad()
def forward(
self,
x,
t,
y,
caption_dropout_mask=None,
xattn_mask=None,
kv_range=None,
inference_params: InferenceParams = None,
**kwargs,
) -> torch.Tensor:
(x, condition, condition_map, y_xattn_flat, rope, meta_args) = self.forward_pre_process(
x, t, y, caption_dropout_mask, xattn_mask, kv_range, **kwargs
)
if not self.pre_process:
x = pp_scheduler().recv_prev_data(x.shape, x.dtype)
self.videodit_blocks.set_input_tensor(x)
else:
# clone a new tensor to ensure x is not a view of other tensor
x = x.clone()
x = self.videodit_blocks.forward(
hidden_states=x,
condition=condition,
condition_map=condition_map,
y_xattn_flat=y_xattn_flat,
rotary_pos_emb=rope,
inference_params=inference_params,
meta_args=meta_args,
)
if not self.post_process:
pp_scheduler().isend_next(x)
return self.forward_post_process(x, meta_args)
def forward_3cfg(
self, x, timestep, y, mask, kv_range, inference_params, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
"""
Forward pass with 3-way classifier-free guidance: runs the previous-chunk + text conditioned pass, the previous-chunk-only pass, and the unconditional pass.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
assert x.shape[0] == 2
assert mask.shape[0] % 2 == 0 # mask should be a multiple of 2
x = torch.cat([x[0:1], x[0:1]], dim=0)
caption_dropout_mask = torch.tensor([False, True], dtype=torch.bool, device=x.device)
inference_params.update_kv_cache = False
out_cond_pre_and_text = self.forward(
x[0:1],
timestep[0:1],
y[0 : y.shape[0] // 2],
caption_dropout_mask=caption_dropout_mask[0:1],
xattn_mask=mask[0 : y.shape[0] // 2],
kv_range=kv_range,
inference_params=inference_params,
**kwargs,
)
inference_params.update_kv_cache = True
out_cond_pre = self.forward(
x[1:2],
timestep[1:2],
y[y.shape[0] // 2 : y.shape[0]],
caption_dropout_mask=caption_dropout_mask[1:2],
xattn_mask=mask[y.shape[0] // 2 : y.shape[0]],
kv_range=kv_range,
inference_params=inference_params,
**kwargs,
)
def chunk_to_batch(input, denoising_range_num):
input = input.squeeze(0)
input = input.reshape(-1, denoising_range_num, kwargs["chunk_width"], *input.shape[2:])
return input.transpose(0, 1) # (denoising_range_num, chn, chunk_width, h, w)
def batch_to_chunk(input, denoising_range_num):
input = input.transpose(0, 1)
input = input.reshape(1, -1, denoising_range_num * kwargs["chunk_width"], *input.shape[3:])
return input
class UnconditionGuard:
def __init__(self, kwargs):
self.kwargs = kwargs
self.prev_state = {
"range_num": kwargs["range_num"],
"denoising_range_num": kwargs["denoising_range_num"],
"slice_point": kwargs["slice_point"],
"fwd_extra_1st_chunk": kwargs["fwd_extra_1st_chunk"],
}
def __enter__(self):
if self.kwargs.get("fwd_extra_1st_chunk", False):
self.kwargs["denoising_range_num"] -= 1
self.kwargs["slice_point"] += 1
self.kwargs["fwd_extra_1st_chunk"] = False
def __exit__(self, exc_type, exc_val, exc_tb):
self.kwargs["range_num"] = self.prev_state["range_num"]
self.kwargs["denoising_range_num"] = self.prev_state["denoising_range_num"]
self.kwargs["slice_point"] = self.prev_state["slice_point"]
self.kwargs["fwd_extra_1st_chunk"] = self.prev_state["fwd_extra_1st_chunk"]
with UnconditionGuard(kwargs):
denoising_range_num = kwargs["denoising_range_num"]
denoise_width = kwargs["chunk_width"] * denoising_range_num
uncond_x = chunk_to_batch(x[0:1, :, -denoise_width:], denoising_range_num)
timestep = timestep[0:1, -denoising_range_num:].transpose(0, 1)
uncond_y = y[y.shape[0] // 2 : y.shape[0]][-denoising_range_num:]
caption_dropout_mask = torch.tensor([True], dtype=torch.bool, device=x.device)
uncond_mask = mask[y.shape[0] // 2 : y.shape[0]][-denoising_range_num:]
uncond_kv_range = self.generate_kv_range_for_uncondition(uncond_x)
kwargs["range_num"] = 1
kwargs["denoising_range_num"] = 1
kwargs["slice_point"] = 0
out_uncond = self.forward(
uncond_x,
timestep,
uncond_y,
caption_dropout_mask=caption_dropout_mask,
xattn_mask=uncond_mask,
kv_range=uncond_kv_range,
inference_params=None,
**kwargs,
)
out_uncond = batch_to_chunk(out_uncond, denoising_range_num)
return out_cond_pre_and_text, out_cond_pre, out_uncond, denoise_width
def get_cfg_scale(self, t, cfg_t_range, prev_chunk_scale_s, text_scale_s):
indices = torch.searchsorted(cfg_t_range - 1e-7, t) - 1
assert indices.min() >= 0 and indices.max() < len(prev_chunk_scale_s)
return prev_chunk_scale_s[indices], text_scale_s[indices]
def forward_dispatcher(self, x, timestep, y, mask, kv_range, inference_params, **kwargs):
if self.runtime_config.cfg_number == 3:
(out_cond_pre_and_text, out_cond_pre, out_uncond, denoise_width) = self.forward_3cfg(
x, timestep, y, mask, kv_range, inference_params, **kwargs
)
prev_chunk_scale_s = torch.tensor(self.runtime_config.prev_chunk_scales).cuda()
text_scale_s = torch.tensor(self.runtime_config.text_scales).cuda()
cfg_t_range = torch.tensor(self.runtime_config.cfg_t_range).cuda()
applied_cfg_range_num, chunk_width = (kwargs["denoising_range_num"], kwargs["chunk_width"])
if kwargs["fwd_extra_1st_chunk"]:
applied_cfg_range_num -= 1
cfg_timestep = timestep[0, -applied_cfg_range_num:]
assert len(prev_chunk_scale_s) == len(cfg_t_range), "prev_chunks_scale and t_range should have the same length"
assert len(text_scale_s) == len(cfg_t_range), "text_scale and t_range should have the same length"
cfg_output_list = []
for chunk_idx in range(applied_cfg_range_num):
prev_chunk_scale, text_scale = self.get_cfg_scale(
cfg_timestep[chunk_idx], cfg_t_range, prev_chunk_scale_s, text_scale_s
)
l = chunk_idx * chunk_width
r = (chunk_idx + 1) * chunk_width
cfg_output = (
(1 - prev_chunk_scale) * out_uncond[:, :, l:r]
+ (prev_chunk_scale - text_scale) * out_cond_pre[:, :, -denoise_width:][:, :, l:r]
+ text_scale * out_cond_pre_and_text[:, :, -denoise_width:][:, :, l:r]
)
cfg_output_list.append(cfg_output)
cfg_output = torch.cat(cfg_output_list, dim=2)
x = torch.cat([x[0:1, :, :-denoise_width], cfg_output], dim=2)
x = torch.cat([x, x], dim=0)
return x
elif self.runtime_config.cfg_number == 1:
assert x.shape[0] == 2
x = torch.cat([x[0:1], x[0:1]], dim=0)
kwargs["caption_dropout_mask"] = torch.tensor([False], dtype=torch.bool, device=x.device)
inference_params.update_kv_cache = True
if kwargs.get("distill_nearly_clean_chunk", False):
prev_chunks_scale = float(os.getenv("prev_chunks_scale", 0.7))
slice_start = 1 if kwargs["fwd_extra_1st_chunk"] else 0
cond_pre_and_text_channel = x.shape[2]
new_x_chunk = x[0:1, :, slice_start * kwargs["chunk_width"] : (slice_start + 1) * kwargs["chunk_width"]]
new_kvrange = self.generate_kv_range_for_uncondition(new_x_chunk)
kwargs["denoising_range_num"] += 1
cat_x_chunk = torch.cat([x[0:1], new_x_chunk], dim=2)
new_kvrange = new_kvrange + kv_range.max()
cat_kvrange = torch.cat([kv_range, new_kvrange], dim=0)
cat_t = torch.cat([timestep[0:1], timestep[0:1, slice_start : slice_start + 1]], dim=1)
cat_y = torch.cat([y[0 : y.shape[0] // 2], y[slice_start : slice_start + 1]], dim=0)
cat_xattn_mask = torch.cat([mask[0 : y.shape[0] // 2], mask[slice_start : slice_start + 1]], dim=0)
cat_out = self.forward(
cat_x_chunk,
cat_t,
cat_y,
xattn_mask=cat_xattn_mask,
kv_range=cat_kvrange,
inference_params=inference_params,
**kwargs,
)
near_clean_out_cond_pre_and_text = cat_out[
:, :, slice_start * kwargs["chunk_width"] : (slice_start + 1) * kwargs["chunk_width"]
]
near_clean_out_cond_text = cat_out[:, :, cond_pre_and_text_channel:]
near_out_cond_pre_and_text = (
near_clean_out_cond_pre_and_text * prev_chunks_scale + near_clean_out_cond_text * (1 - prev_chunks_scale)
)
cat_out[
:, :, slice_start * kwargs["chunk_width"] : (slice_start + 1) * kwargs["chunk_width"]
] = near_out_cond_pre_and_text
out_cond_pre_and_text = cat_out[:, :, :cond_pre_and_text_channel]
else:
out_cond_pre_and_text = self.forward(
x[0:1],
timestep[0:1],
y[0 : y.shape[0] // 2],
xattn_mask=mask[0 : y.shape[0] // 2],
kv_range=kv_range,
inference_params=inference_params,
**kwargs,
)
denoise_width = kwargs["chunk_width"] * kwargs["denoising_range_num"]
if kwargs["fwd_extra_1st_chunk"]:
denoise_width -= kwargs["chunk_width"]
x = torch.cat([x[0:1, :, :-denoise_width], out_cond_pre_and_text[:, :, -denoise_width:]], dim=2)
x = torch.cat([x[0:1], x[0:1]], dim=0)
return x
else:
raise NotImplementedError
def _build_dit_model(config: MagiConfig):
"""Builds the model"""
device = "cuda" if env_is_true("SKIP_LOAD_MODEL") else "meta"
with torch.device(device):
model = VideoDiTModel(
config=config, pre_process=mpu.is_pipeline_first_stage(), post_process=mpu.is_pipeline_last_stage()
)
print_rank_0(model)
# Print number of parameters.
param_count = sum([p.nelement() for p in model.parameters()])
model_size_gb = sum([p.nelement() * p.element_size() for p in model.parameters()]) / (1024**3)
print_per_rank(
f"(cp, pp) rank ({mpu.get_cp_rank()}, {mpu.get_pp_rank()}): param count {param_count}, model size {model_size_gb:.2f} GB"
)
return model
def _high_precision_promoter(module: VideoDiTModel):
module.x_embedder.float()
module.y_embedder.float()
module.t_embedder.float()
module.final_linear.float()
module.rope.float()
for name, sub_module in module.named_modules():
# skip qk_layernorm_xattn
if "_xattn" in name:
continue
# high precision qk_layernorm by default
if "q_layernorm" in name or "k_layernorm" in name:
sub_module.float()
if "self_attn_post_norm" in name or "mlp_post_norm" in name:
sub_module.float()
if "final_layernorm" in name:
sub_module.float()
return module
def get_dit(config: MagiConfig):
"""Build and load VideoDiT model"""
model = _build_dit_model(config)
print_rank_0("Build DiTModel successfully")
mem_allocated_gb = torch.cuda.memory_allocated() / 1024**3
mem_reserved_gb = torch.cuda.memory_reserved() / 1024**3
print_rank_0(
f"After build_dit_model, memory allocated: {mem_allocated_gb:.2f} GB, memory reserved: {mem_reserved_gb:.2f} GB"
)
# To avoid Error in debug mode, set default iteration to 0
if not env_is_true("SKIP_LOAD_MODEL"):
model = load_checkpoint(model)
mem_allocated_gb = torch.cuda.memory_allocated() / 1024**3
mem_reserved_gb = torch.cuda.memory_reserved() / 1024**3
print_rank_0(
f"After load_checkpoint, memory allocated: {mem_allocated_gb:.2f} GB, memory reserved: {mem_reserved_gb:.2f} GB"
)
model = _high_precision_promoter(model)
mem_allocated_gb = torch.cuda.memory_allocated() / 1024**3
mem_reserved_gb = torch.cuda.memory_reserved() / 1024**3
print_rank_0(
f"After high_precision_promoter, memory allocated: {mem_allocated_gb:.2f} GB, memory reserved: {mem_reserved_gb:.2f} GB"
)
model.eval()
gc.collect()
torch.cuda.empty_cache()
print_rank_0("Load checkpoint successfully")
return model
```
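For reference, the `cfg_number == 3` branch above blends the three forward passes per chunk as `(1 - s_prev) * out_uncond + (s_prev - s_text) * out_cond_pre + s_text * out_cond_pre_and_text`. A toy numeric sketch of that blend (the tensors and scale values are made up, not the shipped config):
```py
import torch

# Stand-ins for the three outputs of forward_3cfg; shapes and scales are illustrative only.
out_uncond            = torch.zeros(1, 4, 8)       # caption dropped, no previous-chunk KV
out_cond_pre          = torch.ones(1, 4, 8)        # conditioned on previous chunks only
out_cond_pre_and_text = 2 * torch.ones(1, 4, 8)    # conditioned on previous chunks + text
prev_chunk_scale, text_scale = 1.5, 7.5

blended = (
    (1 - prev_chunk_scale) * out_uncond
    + (prev_chunk_scale - text_scale) * out_cond_pre
    + text_scale * out_cond_pre_and_text
)
assert torch.allclose(blended, torch.full_like(blended, 9.0))  # -0.5*0 - 6*1 + 7.5*2 = 9
```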
## /inference/model/dit/dit_module.py
```py path="/inference/model/dit/dit_module.py"
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numbers
from functools import partial
from typing import Callable, List, Optional, Tuple
import flashinfer
import torch
import torch.distributed
import torch.nn as nn
import triton
import triton.language as tl
from einops import rearrange
from flash_attn import flash_attn_varlen_func
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn.layers.rotary import apply_rotary_emb as flash_apply_rotary_emb
from flashinfer.gemm import bmm_fp8
from magi_attention.functional import flex_flash_attn_func as flex_attention
# from dffa.functional import flex_flash_attn_func as flex_attention
from torch import Tensor
from torch.nn import Parameter
from inference.common import EngineConfig, InferenceParams, ModelConfig, ModelMetaArgs, PackedCrossAttnParams, divide
from inference.infra.distributed import parallel_state
from inference.infra.parallelism import CSOHelper, UlyssesScheduler, cso_communication
##########################################################
# TimestepEmbedder
##########################################################
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, model_config: ModelConfig, frequency_embedding_size=256):
super().__init__()
self.data_type = model_config.params_dtype
hidden_size = model_config.hidden_size
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, int(hidden_size * model_config.cond_hidden_ratio), bias=True),
nn.SiLU(),
nn.Linear(
int(hidden_size * model_config.cond_hidden_ratio), int(hidden_size * model_config.cond_hidden_ratio), bias=True
),
)
self.frequency_embedding_size = frequency_embedding_size
# rescale the timestep for the general transport model
self.timestep_rescale_factor = 1000
@staticmethod
def timestep_embedding(t, dim, max_period=10000, timestep_rescale_factor=1):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=t.device
)
args = t[:, None].float() * freqs[None] * timestep_rescale_factor
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
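# In closed form: embedding(t)[i] = cos(t * r * max_period^(-i/half)) for i < half and the
# matching sin(...) for the second half, where r is timestep_rescale_factor (1000 in forward).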
def forward(self, t):
t = t.to(torch.float32)
t_freq = self.timestep_embedding(
t, self.frequency_embedding_size, timestep_rescale_factor=self.timestep_rescale_factor
)
t_emb = self.mlp(t_freq.to(self.data_type))
return t_emb
##########################################################
# CaptionEmbedder
##########################################################
class CaptionEmbedder(nn.Module):
"""
Embeds caption features into vector representations. Also handles caption dropout for classifier-free guidance.
"""
def __init__(self, model_config: ModelConfig):
super().__init__()
in_channels = model_config.caption_channels
hidden_size = model_config.hidden_size
caption_max_length = model_config.caption_max_length
self.y_proj_xattn = nn.Sequential(
nn.Linear(in_channels, int(hidden_size * model_config.xattn_cond_hidden_ratio), bias=True), nn.SiLU()
)
self.y_proj_adaln = nn.Sequential(nn.Linear(in_channels, int(hidden_size * model_config.cond_hidden_ratio), bias=True))
self.null_caption_embedding = Parameter(torch.empty(caption_max_length, in_channels))
def caption_drop(self, caption, caption_dropout_mask):
"""
Drops labels to enable classifier-free guidance.
caption.shape = (N, 1, cap_len, C)
"""
dropped_caption = torch.where(
caption_dropout_mask[:, None, None, None], # (N, 1, 1, 1)
self.null_caption_embedding[None, None, :], # (1, 1, cap_len, C)
caption, # (N, 1, cap_len, C)
)
return dropped_caption
def caption_drop_single_token(self, caption_dropout_mask):
dropped_caption = torch.where(
caption_dropout_mask[:, None, None], # (N, 1, 1)
self.null_caption_embedding[None, -1, :], # (1, 1, C)
self.null_caption_embedding[None, -2, :], # (1, 1, C)
)
return dropped_caption # (N, 1, C)
def forward(self, caption, train, caption_dropout_mask=None):
if train and caption_dropout_mask is not None:
caption = self.caption_drop(caption, caption_dropout_mask)
caption_xattn = self.y_proj_xattn(caption)
if caption_dropout_mask is not None:
caption = self.caption_drop_single_token(caption_dropout_mask)
caption_adaln = self.y_proj_adaln(caption)
return caption_xattn, caption_adaln
##########################################################
# FinalLinear
##########################################################
class FinalLinear(nn.Module):
"""
The final linear layer of DiT.
"""
def __init__(self, hidden_size, patch_size, t_patch_size, out_channels):
super().__init__()
self.linear = nn.Linear(hidden_size, patch_size * patch_size * t_patch_size * out_channels, bias=False)
def forward(self, x):
x = self.linear(x)
return x
##########################################################
# AdaModulateLayer
##########################################################
class AdaModulateLayer(torch.nn.Module):
def __init__(self, model_config: ModelConfig):
super().__init__()
self.model_config = model_config
self.gate_num_chunks = 2
self.act = nn.SiLU()
self.proj = nn.Sequential(
nn.Linear(
int(self.model_config.hidden_size * self.model_config.cond_hidden_ratio),
int(self.model_config.hidden_size * self.model_config.cond_gating_ratio * self.gate_num_chunks),
bias=True,
dtype=self.model_config.params_dtype,
)
)
def forward(self, c):
c = self.act(c)
return self.proj(c)
##########################################################
# bias_modulate_add
##########################################################
@triton.jit
def range_mod_kernel_fwd(
X, # pointer to the input
MAP, # map x index to gating index
GATINGS, # pointer to the gatings
Y, # pointer to the output
M, # number of rows in X, unused
N, # number of columns in X
stride_xm, # how much to increase the pointer when moving by 1 row in X
stride_xn, # how much to increase the pointer when moving by 1 column in X
stride_gm, # how much to increase the pointer when moving by 1 row in GATINGS
stride_gn, # how much to increase the pointer when moving by 1 column in GATINGS
stride_ym, # how much to increase the pointer when moving by 1 row in Y
stride_yn, # how much to increase the pointer when moving by 1 column in Y
BLOCK_SIZE: tl.constexpr, # number of columns in a block
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
cur_X = X + row * stride_xm
x_cols = tl.arange(0, BLOCK_SIZE) * stride_xn
x_mask = x_cols < N * stride_xn
x = tl.load(cur_X + x_cols, mask=x_mask, other=0.0)
cur_MAP = MAP + row
gating_index = tl.load(cur_MAP)
cur_GATING = GATINGS + gating_index * stride_gm
gating_cols = tl.arange(0, BLOCK_SIZE) * stride_gn
gating_mask = gating_cols < N * stride_gn
gating = tl.load(cur_GATING + gating_cols, mask=gating_mask, other=0.0)
cur_Y = Y + row * stride_ym
y_cols = tl.arange(0, BLOCK_SIZE) * stride_yn
y_mask = y_cols < N * stride_yn
tl.store(cur_Y + y_cols, x * gating, mask=y_mask)
def range_mod_triton(x, c_mapping, gatings):
"""
Inputs:
x: (s, b, h). Tensor of inputs embedding (images or latent representations of images)
c_mapping: (s, b). Tensor of condition map
gatings: (b, denoising_range_num, h). Tensor of condition embedding
"""
assert x.is_cuda, "x is not on cuda"
assert c_mapping.is_cuda, "c_mapping is not on cuda"
assert gatings.is_cuda, "gatings is not on cuda"
# TODO: use 3D tensor for x, c_mapping, and gatings
s, b, h = x.shape
x = x.transpose(0, 1).flatten(0, 1)
c_mapping = c_mapping.transpose(0, 1).flatten(0, 1)
gatings = gatings.flatten(0, 1)
assert x.dim() == 2, f"x must be a 2D tensor but got {x.dim()}D"
assert c_mapping.dim() == 1, f"c_mapping must be a 1D tensor but got {c_mapping.dim()}D"
assert gatings.dim() == 2, f"gatings must be a 2D tensor but got {gatings.dim()}D"
M, N = x.shape
assert c_mapping.size(0) == M, "c_mapping must have the same number of rows as x"
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_SIZE:
raise RuntimeError("range_mod_triton doesn't support feature dim >= 64KB.")
MAP = c_mapping
y = torch.empty_like(x)
range_mod_kernel_fwd[(M,)](
x,
MAP,
gatings,
y,
M,
N,
x.stride(0),
x.stride(1),
gatings.stride(0),
gatings.stride(1),
y.stride(0),
y.stride(1),
BLOCK_SIZE=BLOCK_SIZE,
)
y = y.reshape(b, s, h).transpose(0, 1)
return y
def bias_modulate_add(
x: torch.Tensor, residual: torch.Tensor, condition_map: torch.Tensor, gate: torch.Tensor, post_norm: torch.nn.Module
):
assert gate.shape[-1] == x.shape[-1]
original_dtype = x.dtype
x = x.float()
residual = residual.float()
gate = gate.float()
x = range_mod_triton(x, condition_map, gate)
x = post_norm(x)
x = x + residual
x = x.to(original_dtype)
return x
##########################################################
# FusedLayerNorm
##########################################################
def make_viewless_tensor(inp, requires_grad):
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
out = torch.empty((1,), dtype=inp.dtype, device=inp.device, requires_grad=requires_grad)
out.data = inp.data
return out
class FusedLayerNorm(torch.nn.Module):
"""
Layer Norm, fused into a single CUDA kernel.
Borrowed from: https://github.com/NVIDIA/Megatron-LM/blob/6501752396e9cc360ce894cda4b2217a58c1c09d/megatron/core/fusions/fused_layer_norm.py#L30
Args:
hidden_size (int): Transformer hidden dimension.
eps (float): Epsilon added to denominator, for numerical stability.
zero_centered_gamma (bool): Adjust LayerNorm weights such that they are
centered around zero. This improves numerical stability.
model_config (ModelConfig): Transformer config. Include to match custom
layer norm interfaces.
normalization (str): Normalization type, used for Transformer Engine.
Must equal 'LayerNorm' here.
"""
def __init__(self, model_config: ModelConfig, hidden_size: int):
super().__init__()
self.zero_centered_gamma = model_config.apply_layernorm_1p
if isinstance(hidden_size, numbers.Integral):
hidden_size = (hidden_size,)
self.hidden_size = torch.Size(hidden_size)
self.eps = model_config.layernorm_epsilon
self.weight = Parameter(torch.empty(*hidden_size, dtype=model_config.params_dtype))
self.bias = Parameter(torch.empty(*hidden_size, dtype=model_config.params_dtype))
def forward(self, input: Tensor) -> Tensor:
weight = self.weight + 1 if self.zero_centered_gamma else self.weight
return torch.nn.functional.layer_norm(input, self.hidden_size, weight, self.bias, self.eps)
def softcap(x: torch.Tensor, cap: int):
return (cap * torch.tanh(x.float() / cap)).to(x.dtype)
def div_clamp_to(x: torch.Tensor, scale: torch.Tensor):
fp8_min = torch.finfo(torch.float8_e4m3fn).min
fp8_max = torch.finfo(torch.float8_e4m3fn).max
prefix_shape = x.shape[:-1]
last_shape = x.shape[-1]
x = x.flatten().reshape(-1, last_shape)
# Split x into 256 MB parts to avoid big memory peak
part_size = 256 * 1024 * 1024 // last_shape
part_num = (x.shape[0] + part_size - 1) // part_size
return (
torch.cat(
[
torch.clamp(x[i * part_size : (i + 1) * part_size].float() / scale.float(), fp8_min, fp8_max).bfloat16()
for i in range(part_num)
],
dim=0,
)
.to(torch.float8_e4m3fn)
.reshape(*prefix_shape, last_shape)
.contiguous()
)
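# Notes on div_clamp_to (descriptive, not in the original source):
# - float8_e4m3fn has a representable range of roughly [-448, 448], so the scaled
#   input is clamped to [finfo.min, finfo.max] before the fp8 cast;
# - the input is processed in slices of roughly 256M elements to bound the peak
#   memory of the temporary fp32/bf16 intermediates.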
##########################################################
# CustomLayerNormLinear
##########################################################
class CustomLayerNormLinear(torch.nn.Module):
def __init__(
self,
input_size: int,
output_size_q: int,
output_size_kv: int,
layer_number: int,
model_config: ModelConfig,
engine_config: EngineConfig,
):
super().__init__()
self.layer_norm = torch.nn.LayerNorm(input_size, eps=model_config.layernorm_epsilon, dtype=model_config.params_dtype)
self.layer_number = layer_number
layers = {"q": output_size_q, "qx": output_size_q, "k": output_size_kv, "v": output_size_kv}
for name, output_size in layers.items():
if not engine_config.fp8_quant or self.layer_number == 0 or self.layer_number == model_config.num_layers - 1:
setattr(self, name, torch.nn.Linear(input_size, output_size, bias=False, dtype=model_config.params_dtype))
else:
setattr(self, name, PerTensorQuantizedFp8Linear(input_size, output_size))
def forward_ln(self, hidden_states):
return self.layer_norm(hidden_states)
def forward_q(self, hidden_states):
return self.q(hidden_states)
def forward_qx(self, hidden_states):
return self.qx(hidden_states)
def forward_k(self, hidden_states):
return self.k(hidden_states)
def forward_v(self, hidden_states):
return self.v(hidden_states)
##########################################################
# PerTensorQuantizedFp8Linear
##########################################################
class PerTensorQuantizedFp8Linear(torch.nn.Module):
# The bias and device parameters are not used; they are included for compatibility with torch.nn.Linear's signature.
def __init__(self, in_features: int, out_features: int, bias=False, dtype=torch.bfloat16, device=None) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.finfo = torch.finfo(torch.float8_e4m3fn)
self.output_dtype = dtype
self.weight = Parameter(torch.empty((1, out_features, in_features), dtype=torch.float8_e4m3fn))
self.weight_scale = Parameter(torch.empty(1, dtype=torch.float32))
self.input_scale = Parameter(torch.empty(in_features, dtype=torch.float32))
def forward(self, input: torch.Tensor):
input = div_clamp_to(input, self.input_scale)
prefix_shape = input.shape[:-1]
# column major weight
return bmm_fp8(
input.reshape(1, -1, self.in_features),
self.weight.transpose(-2, -1),
self.input_scale,
self.weight_scale,
dtype=self.output_dtype,
).reshape(prefix_shape + (self.out_features,))
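# Note (descriptive, assumptions flagged): the input is quantized to fp8 with a single
# per-tensor input_scale via div_clamp_to, and the GEMM runs through flashinfer's
# bmm_fp8 with the transposed (column-major) fp8 weight; input_scale and weight_scale
# are presumably folded back in by the kernel to produce the bf16 output.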
##########################################################
# PerChannelQuantizedFp8Linear
##########################################################
class PerChannelQuantizedFp8Linear(torch.nn.Module):
# The bias and device parameters are not used; they are included for compatibility with torch.nn.Linear's signature.
def __init__(self, in_features: int, out_features: int, bias=False, dtype=torch.bfloat16, device=None) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.output_dtype = dtype
self.finfo = torch.finfo(torch.float8_e4m3fn)
self.weight = Parameter(torch.empty((1, out_features, in_features), dtype=torch.float8_e4m3fn))
self.weight_scale = Parameter(torch.empty(1, dtype=torch.float32))
self.input_scale = Parameter(torch.empty(1, dtype=torch.float32))
self.smooth_scale = Parameter(torch.empty(1, in_features, dtype=torch.float32))
def forward(self, x):
x = div_clamp_to(x, self.smooth_scale.to(torch.float32))
prefix_shape = x.shape[:-1]
return bmm_fp8(
x.reshape(1, -1, self.in_features),
self.weight.transpose(-2, -1),
self.input_scale,
self.weight_scale,
dtype=self.output_dtype,
).reshape(prefix_shape + (self.out_features,))
##########################################################
# CustomMLP
##########################################################
class CustomMLP(torch.nn.Module):
"""
CustomMLP takes an input with hidden size h, projects it to the FFN hidden
dimension (doubled when a gated linear unit is used), applies a nonlinear
transformation, and projects the state back to hidden size h.
We use the following notation:
h: hidden size
p: number of tensor model parallel partitions
b: batch size
s: sequence length
"""
def __init__(self, model_config: ModelConfig, engine_config: EngineConfig, layer_number: int, input_size: int = None):
super().__init__()
self.model_config: ModelConfig = model_config
self.engine_config: EngineConfig = engine_config
self.layer_number = layer_number
self.input_size = input_size if input_size is not None else self.model_config.hidden_size
self.layer_norm = torch.nn.LayerNorm(
self.input_size, eps=self.model_config.layernorm_epsilon, dtype=self.model_config.params_dtype
)
submodules_linear_fc1 = torch.nn.Linear
if self.engine_config.fp8_quant and self.layer_number != 0 and self.layer_number != model_config.num_layers - 1:
submodules_linear_fc1 = PerTensorQuantizedFp8Linear
if self.model_config.gated_linear_unit:
self.linear_fc1 = submodules_linear_fc1(
self.input_size, 2 * self.model_config.ffn_hidden_size, bias=False, dtype=self.model_config.params_dtype
)
else:
self.linear_fc1 = submodules_linear_fc1(
self.input_size, self.model_config.ffn_hidden_size, bias=False, dtype=self.model_config.params_dtype
)
submodules_linear_fc2 = torch.nn.Linear
if engine_config.fp8_quant and self.layer_number != 0 and self.layer_number != model_config.num_layers - 1:
submodules_linear_fc2 = PerChannelQuantizedFp8Linear
self.linear_fc2 = submodules_linear_fc2(
self.model_config.ffn_hidden_size, self.model_config.hidden_size, bias=False, dtype=self.model_config.params_dtype
)
def forward(self, hidden_states):
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.linear_fc1(hidden_states)
if self.model_config.gated_linear_unit:
hidden_states = flashinfer.activation.silu_and_mul(hidden_states)
else:
hidden_states = torch.nn.functional.gelu(hidden_states)
hidden_states = self.linear_fc2(hidden_states)
return hidden_states
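# Note (descriptive, not in the original source): when gated_linear_unit is set,
# linear_fc1 produces 2 * ffn_hidden_size features and silu_and_mul combines the two
# halves (SwiGLU-style); otherwise a plain GELU is applied to ffn_hidden_size features.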
##########################################################
# LearnableRotaryEmbeddingCat
##########################################################
def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
"""generate N-D grid in dimension order.
The ndgrid function is like meshgrid except that the order of the first two input arguments are switched.
That is, the statement
[X1,X2,X3] = ndgrid(x1,x2,x3)
produces the same result as
[X2,X1,X3] = meshgrid(x2,x1,x3)
This naming is based on MATLAB, the purpose is to avoid confusion due to torch's change to make
torch.meshgrid behaviour move from matching ndgrid ('ij') indexing to numpy meshgrid defaults of ('xy').
"""
try:
return torch.meshgrid(*tensors, indexing="ij")
except TypeError:
# old PyTorch < 1.10 will follow this path as it does not have indexing arg,
# the old behaviour of meshgrid was 'ij'
return torch.meshgrid(*tensors)
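# Illustrative example (not in the original source):
#   xs, ys = ndgrid(torch.arange(2), torch.arange(3))
#   # xs: [[0, 0, 0], [1, 1, 1]], ys: [[0, 1, 2], [0, 1, 2]]  ('ij' indexing)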
def pixel_freq_bands(
num_bands: int, max_freq: float = 224.0, linear_bands: bool = True, device: Optional[torch.device] = None
):
if linear_bands:
bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device)
else:
bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device)
return bands * torch.pi
def freq_bands(
num_bands: int, temperature: float = 10000.0, step: int = 2, device: Optional[torch.device] = None
) -> torch.Tensor:
exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands
bands = 1.0 / (temperature**exp)
return bands
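# Illustrative example (not in the original source): with step=1 this yields the
# standard transformer inverse-frequency spacing, e.g.
#   freq_bands(4, temperature=10000.0, step=1)
#   # tensor([1.0000, 0.1000, 0.0100, 0.0010])  == 1 / 10000 ** (i / 4), i = 0..3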
def build_fourier_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
num_bands: int = 64,
max_res: int = 224,
temperature: float = 10000.0,
linear_bands: bool = False,
include_grid: bool = False,
in_pixels: bool = True,
ref_feat_shape: Optional[List[int]] = None,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
"""
Args:
feat_shape: Feature shape for embedding.
bands: Pre-calculated frequency bands.
num_bands: Number of frequency bands (determines output dim).
max_res: Maximum resolution for pixel based freq.
temperature: Temperature for non-pixel freq.
linear_bands: Linear band spacing for pixel based freq.
include_grid: Include the spatial grid in output.
in_pixels: Output in pixel freq.
ref_feat_shape: Reference feature shape for resize / fine-tune.
dtype: Output dtype.
device: Output device.
Returns:
A list of tensors: [sin, cos] position embeddings, preceded by the sampling grid when include_grid is True.
"""
if bands is None:
if in_pixels:
bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, device=device)
else:
bands = freq_bands(num_bands, temperature=temperature, step=1, device=device)
else:
if device is None:
device = bands.device
if dtype is None:
dtype = bands.dtype
if in_pixels:
t = [torch.linspace(-1.0, 1.0, steps=s, device=device, dtype=torch.float32) for s in feat_shape]
else:
t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape]
# align spatial center (H/2,W/2) to (0,0)
t[1] = t[1] - (feat_shape[1] - 1) / 2
t[2] = t[2] - (feat_shape[2] - 1) / 2
if ref_feat_shape is not None:
# eva's scheme for resizing rope embeddings (ref shape = pretrain)
# aligning to the endpoint e.g [0,1,2] -> [0, 0.4, 0.8, 1.2, 1.6, 2]
t_rescaled = []
for x, f, r in zip(t, feat_shape, ref_feat_shape):
# deal with image input
if f == 1:
assert r == 1, "ref_feat_shape must be 1 when feat_shape is 1"
t_rescaled.append(x)
else:
t_rescaled.append(x / (f - 1) * (r - 1))
t = t_rescaled
grid = torch.stack(ndgrid(t), dim=-1)
grid = grid.unsqueeze(-1)
pos = grid * bands
pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype)
out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos]
return out
def build_rotary_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
dim: int = 64,
max_res: int = 224,
temperature: float = 10000.0,
linear_bands: bool = False,
in_pixels: bool = True,
ref_feat_shape: Optional[List[int]] = None,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
"""
Args:
feat_shape: Spatial shape of the target tensor for embedding.
bands: Optional pre-generated frequency bands
dim: Output dimension of embedding tensor.
max_res: Maximum resolution for pixel mode.
temperature: Temperature (inv freq) for non-pixel mode
linear_bands: Linearly (instead of log) spaced bands for pixel mode
in_pixels: Pixel vs language (inv freq) mode.
dtype: Output dtype.
device: Output device.
Returns:
A tuple of (sin_emb, cos_emb) tensors, flattened over the spatial dimensions.
"""
sin_emb, cos_emb = build_fourier_pos_embed(
feat_shape,
bands=bands,
num_bands=dim // 8,
max_res=max_res,
temperature=temperature,
linear_bands=linear_bands,
in_pixels=in_pixels,
ref_feat_shape=ref_feat_shape,
device=device,
dtype=dtype,
)
num_spatial_dim = 1
# this would be much nicer as a .numel() call on torch.Size(), but torchscript does not support it
for x in feat_shape:
num_spatial_dim *= x
sin_emb = sin_emb.reshape(num_spatial_dim, -1)
cos_emb = cos_emb.reshape(num_spatial_dim, -1)
return sin_emb, cos_emb
class LearnableRotaryEmbeddingCat(nn.Module):
"""Rotary position embedding w/ concatenatd sin & cos
The following impl/resources were referenced for this impl:
* https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
* https://blog.eleuther.ai/rotary-embeddings/
"""
def __init__(
self,
dim,
max_res=224,
temperature=10000,
in_pixels=True,
linear_bands: bool = False,
feat_shape: Optional[List[int]] = None,
ref_feat_shape: Optional[List[int]] = None,
):
super().__init__()
self.dim = dim
self.max_res = max_res
self.temperature = temperature
self.in_pixels = in_pixels
self.linear_bands = linear_bands
self.feat_shape = feat_shape
self.ref_feat_shape = ref_feat_shape
self.bands = nn.Parameter(self.get_default_bands())
def get_default_bands(self):
if self.in_pixels:
bands = pixel_freq_bands(
self.dim // 8, float(self.max_res), linear_bands=self.linear_bands, device=torch.cuda.current_device()
)
else:
bands = freq_bands(self.dim // 8, temperature=self.temperature, step=1, device=torch.cuda.current_device())
return bands
def get_embed(self, shape: Optional[List[int]], ref_feat_shape: Optional[List[int]] = None):
# rebuild bands and embeddings every call, use if target shape changes
embeds = build_rotary_pos_embed(
feat_shape=shape,
bands=self.bands, # use learned bands
dim=self.dim,
max_res=self.max_res,
linear_bands=self.linear_bands,
in_pixels=self.in_pixels,
ref_feat_shape=ref_feat_shape if ref_feat_shape else self.ref_feat_shape,
temperature=self.temperature,
device=torch.cuda.current_device(),
)
return torch.cat(embeds, -1)
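# Shape note (descriptive, derived from the code above): for a 3-D feat_shape the
# returned embedding is [prod(feat_shape), 2 * 3 * dim // 8] -- the sin and cos halves
# concatenated along the last dimension, each covering dim // 8 bands per spatial axis.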
##########################################################
# Attention
##########################################################
class Attention(torch.nn.Module):
"""
Attention layer abstract class.
"""
def __init__(self, model_config: ModelConfig, engine_config: EngineConfig, layer_number: int):
super().__init__()
self.model_config: ModelConfig = model_config
self.engine_config: EngineConfig = engine_config
self.layer_number = layer_number
self.hidden_size_per_attention_head = self.model_config.kv_channels
# num_query_groups and num_attention_heads are different for GQA
self.query_projection_size = self.model_config.kv_channels * self.model_config.num_attention_heads
self.kv_projection_size = self.model_config.kv_channels * self.model_config.num_query_groups
# Per attention head and per partition values.
world_size = parallel_state.get_tp_world_size(with_context_parallel=True)
if world_size > self.model_config.num_query_groups and world_size % self.model_config.num_query_groups == 0:
self.num_query_groups_per_partition = 1
else:
self.num_query_groups_per_partition = divide(self.model_config.num_query_groups, world_size)
def _allocate_key_and_value_memory(self, sequence_length, batch_size, dtype):
"""Allocate memory to store kv cache during inference."""
if self.engine_config.kv_offload:
return torch.empty(
sequence_length * batch_size,
self.num_query_groups_per_partition,
self.hidden_size_per_attention_head * 2,
dtype=dtype,
device=torch.cpu.current_device(),
pin_memory=True,
)
else:
return torch.empty(
sequence_length * batch_size,
self.num_query_groups_per_partition,
self.hidden_size_per_attention_head * 2,
dtype=dtype,
device=torch.cuda.current_device(),
)
##########################################################
# FullyParallelAttention
##########################################################
def split_tensor_along_last_dim(
tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False
) -> List[torch.Tensor]:
"""Split a tensor along its last dimension.
Args:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
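# Illustrative example (not in the original source):
#   kv = torch.randn(100, 8, 256)
#   k, v = split_tensor_along_last_dim(kv, 2)  # two views of shape [100, 8, 128]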
class FullyParallelAttention(Attention):
def __init__(self, model_config: ModelConfig, engine_config: EngineConfig, layer_number: int):
super().__init__(model_config=model_config, engine_config=engine_config, layer_number=layer_number)
# output 2x query, one for self-attn, one for cross-attn with condition
self.linear_qkv = CustomLayerNormLinear(
input_size=self.model_config.hidden_size,
output_size_q=self.query_projection_size,
output_size_kv=self.kv_projection_size,
layer_number=self.layer_number,
model_config=self.model_config,
engine_config=self.engine_config,
)
# kv from condition, e.g., caption
self.linear_kv_xattn = torch.nn.Linear(
int(self.model_config.hidden_size * self.model_config.xattn_cond_hidden_ratio), # 6144
2 * self.kv_projection_size, # 2048
dtype=self.model_config.params_dtype,
bias=False,
)
# Output.
self.adapt_linear_quant = (
self.engine_config.fp8_quant and self.layer_number != 0 and self.layer_number != model_config.num_layers - 1
)
submodules_linear_proj = PerChannelQuantizedFp8Linear if self.adapt_linear_quant else torch.nn.Linear
self.linear_proj = submodules_linear_proj(
2 * self.query_projection_size, self.model_config.hidden_size, dtype=self.model_config.params_dtype, bias=False
)
self.q_layernorm = FusedLayerNorm(model_config=self.model_config, hidden_size=self.hidden_size_per_attention_head)
self.q_layernorm_xattn = FusedLayerNorm(
model_config=self.model_config, hidden_size=self.hidden_size_per_attention_head
)
self.k_layernorm = FusedLayerNorm(model_config=self.model_config, hidden_size=self.hidden_size_per_attention_head)
self.k_layernorm_xattn = FusedLayerNorm(
model_config=self.model_config, hidden_size=self.hidden_size_per_attention_head
)
def _full_adjust_key_and_value(
self, inference_params: InferenceParams, key_and_value: torch.Tensor, meta_args: ModelMetaArgs
):
"""
Saves the generated key and value tensors to the end of the buffers in inference_params.
Returns the full size keys and values from the provided inference_params
Returns a tuple: (key, value)
"""
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
inf_max_seq_length = inference_params.max_sequence_length
inf_max_batch_size = inference_params.max_batch_size
if self.layer_number not in inference_params.key_value_memory_dict:
inference_key_and_value_memory = self._allocate_key_and_value_memory(
inf_max_seq_length, inf_max_batch_size, key_and_value.dtype
)
inference_params.key_value_memory_dict[self.layer_number] = inference_key_and_value_memory
else:
# Get the pre-allocated buffers for this layer
inference_key_and_value_memory = inference_params.key_value_memory_dict[self.layer_number]
sequence_start = meta_args.slice_point * meta_args.clip_token_nums * inf_max_batch_size
get_key_and_value = inference_key_and_value_memory[:sequence_start, ...].cuda()
# Copy key and values.
if inference_params.update_kv_cache:
key_and_value_total = key_and_value
clip_size = (
key_and_value_total.size(0) - meta_args.clip_token_nums * inf_max_batch_size
if meta_args.distill_nearly_clean_chunk
else key_and_value_total.size(0)
)
sequence_end = sequence_start + clip_size
assert sequence_end <= inference_key_and_value_memory.size(0)
# update kv cache
inference_key_and_value_memory[sequence_start:sequence_end, ...] = key_and_value_total[:clip_size]
return torch.cat([get_key_and_value, key_and_value], dim=0)
def adjust_key_and_value_for_inference(
self, key_and_value: torch.Tensor, inference_params: InferenceParams, meta_args: ModelMetaArgs
):
if inference_params is None:
return torch.chunk(key_and_value, 2, dim=-1)
# Only update the kv cache when necessary, i.e. in one of three cases:
# 1. extracting the prefix video clean feature
# 2. the first chunk of the current kv is clean and its feature must be saved
# 3. a previous chunk is clean and its feature must be saved / loaded
if meta_args.extract_prefix_video_feature or meta_args.fwd_extra_1st_chunk or meta_args.slice_point > 0:
key_and_value = self._full_adjust_key_and_value(inference_params, key_and_value, meta_args)
key, value = torch.chunk(key_and_value, 2, dim=-1)
return key.contiguous(), value.contiguous()
# =====================
# Get Query for core attn
# [sq, b, (hn hd)] -> [(sq b), hn, hd]
# =====================
def get_q(self, mixed_qqkv: torch.Tensor, cos_emb: torch.Tensor, sin_emb: torch.Tensor):
query = self.linear_qkv.forward_q(mixed_qqkv)
query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
assert self.q_layernorm is not None
original_dtype = query.dtype
query = query.float()
query = self.q_layernorm(query)
query = query.transpose(0, 1).contiguous()
query = flash_apply_rotary_emb(query, cos_emb, sin_emb)
query = query.to(original_dtype)
return rearrange(query, "b sq hn hd -> (sq b) hn hd").contiguous()
# =====================
# Get Key for core attn
# [sq, b, (hn hd)] -> [(sq b), hn, hd]
# =====================
def get_k(self, mixed_qqkv: torch.Tensor, cos_emb: torch.Tensor, sin_emb: torch.Tensor):
key = self.linear_qkv.forward_k(mixed_qqkv)
key = key.reshape(key.size(0), key.size(1), -1, self.hidden_size_per_attention_head)
assert self.k_layernorm is not None
original_dtype = key.dtype
key = key.float()
key = self.k_layernorm(key)
key = key.transpose(0, 1).contiguous()
key = flash_apply_rotary_emb(key, cos_emb, sin_emb)
key = key.to(original_dtype)
return rearrange(key, "b sq hn hd -> (sq b) hn hd").contiguous()
# =====================
# Get Value for core attn
# [sq, b, (hn hd)] -> [(sq b), hn, hd]
# =====================
def get_v(self, mixed_qqkv: torch.Tensor):
value = self.linear_qkv.forward_v(mixed_qqkv)
return rearrange(value, "sq b (hn hd) -> (sq b) hn hd", hd=self.hidden_size_per_attention_head).contiguous()
def get_kv(self, mixed_qqkv: torch.Tensor, cos_emb: torch.Tensor, sin_emb: torch.Tensor):
# Fetch K and V together for better performance when CPU-bound; mainly used by CUDA graph
key = self.get_k(mixed_qqkv, cos_emb, sin_emb)
value = self.get_v(mixed_qqkv)
# [(sq b), hn, hd] -> [(sq b), hn, 2 * hd]
return torch.cat([key, value], dim=-1)
def get_qkv(self, mixed_qqkv: torch.Tensor, cos_emb: torch.Tensor, sin_emb: torch.Tensor):
# Fetch Q, K and V together for better performance when CPU-bound; mainly used by CUDA graph
q = self.get_q(mixed_qqkv, cos_emb, sin_emb)
k = self.get_k(mixed_qqkv, cos_emb, sin_emb)
v = self.get_v(mixed_qqkv)
return q, k, v
def get_xqkv(self, mixed_qqkv: torch.Tensor, key_value_states: torch.Tensor):
query_xattn = self.linear_qkv.forward_qx(mixed_qqkv)
query_xattn = rearrange(query_xattn, "sq b (hn hd) -> (b sq) hn hd", hd=self.hidden_size_per_attention_head)
query_xattn = self.q_layernorm_xattn(query_xattn)
# [y_total_token, h] --> [y_total_token, 2*hp]
mixed_kv_xattn = torch.concat(
[torch.matmul(key_value_states, w.t()) for w in torch.chunk(self.linear_kv_xattn.weight, 8, dim=0)], dim=1
)
# [y_total_token, 2*hn*hd] --> [y_total_token, hn, 2*hd]
mixed_kv_xattn = mixed_kv_xattn.view(key_value_states.shape[0], -1, 2 * self.hidden_size_per_attention_head)
# [y_total_token, hn, 2*hd] --> 2 [y_total_token, hn, hd]
(key_xattn, value_xattn) = split_tensor_along_last_dim(mixed_kv_xattn, 2)
key_xattn = self.k_layernorm_xattn(key_xattn)
return query_xattn, key_xattn, value_xattn
def core_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, bs: int, meta_args: ModelMetaArgs):
# (sq b) hn hd -> b sq hn hd
query = query.reshape(-1, bs, query.shape[1], query.shape[2]).transpose(0, 1).contiguous()
# (sq b) hn hd -> b sq hn hd
key = key.reshape(-1, bs, key.shape[1], key.shape[2]).transpose(0, 1).contiguous()
# (sq b) hn hd -> b sq hn hd
value = value.reshape(-1, bs, value.shape[1], value.shape[2]).transpose(0, 1).contiguous()
if torch.cuda.get_device_capability()[0] >= 9:
core_attn_out, _ = flex_attention(
query.flatten(0, 1),
key.flatten(0, 1),
value.flatten(0, 1),
meta_args.core_attn_params.q_range,
meta_args.core_attn_params.k_range,
max_seqlen_q=meta_args.core_attn_params.max_seqlen_q,
max_seqlen_k=meta_args.core_attn_params.max_seqlen_k,
softmax_scale=None,
deterministic=torch.are_deterministic_algorithms_enabled(),
disable_fwd_atomic_reduction=True,
)
# (b sq) hn hd -> (sq b) hn hd
core_attn_out = rearrange(core_attn_out, "(b sq) h d -> (sq b) h d", b=bs)
else:
# NOTE(lml): Under 3_cfg mode, the third forward pass converts a multi-denoising_range_num input into a multi-batch_size input, so a regular multi-batch_size input cannot be supported here. The assert below guarantees we are in that situation, so the q_range and k_range used later remain valid.
assert not (bs > 1 and meta_args.denoising_range_num > 1)
q_range = meta_args.core_attn_params.np_q_range
k_range = meta_args.core_attn_params.np_k_range
core_attn_outs = []
for i in range(meta_args.denoising_range_num):
if bs == 1:
q = query[:, q_range[i, 0] : q_range[i, 1]]
k = key[:, k_range[i, 0] : k_range[i, 1]]
v = value[:, k_range[i, 0] : k_range[i, 1]]
else:
assert i == 0
q = query[:, q_range[0, 0] : q_range[0, 1]]
k = key[:, k_range[0, 0] : k_range[0, 1]]
v = value[:, k_range[0, 0] : k_range[0, 1]]
o = flash_attn_func(q=q, k=k, v=v, deterministic=torch.are_deterministic_algorithms_enabled())
o = rearrange(o, "b sq h d -> (sq b) h d", b=bs)
core_attn_outs.append(o)
core_attn_out = torch.cat(core_attn_outs, dim=0)
return core_attn_out
def full_attention(self, bs: int, meta_args: ModelMetaArgs, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, i: int):
# NOTE(lml): full_attention is used under the cp_shuffle_overlap strategy. We further restrict it to bs == 1 so we do not have to track the ordering of the sq and b dimensions.
assert bs == 1
if torch.cuda.get_device_capability()[0] >= 9:
q_range = meta_args.core_attn_params.q_range[i : i + 1] - meta_args.core_attn_params.q_range[i, 0]
k_range = meta_args.core_attn_params.k_range[i : i + 1]
o, _ = flex_attention(
q,
k,
v,
q_ranges=q_range,
k_ranges=k_range,
max_seqlen_q=meta_args.core_attn_params.max_seqlen_q,
max_seqlen_k=meta_args.core_attn_params.max_seqlen_k,
softmax_scale=None,
deterministic=torch.are_deterministic_algorithms_enabled(),
disable_fwd_atomic_reduction=True,
)
else:
k_range = meta_args.core_attn_params.np_k_range[i : i + 1]
k = k[k_range[0, 0] : k_range[0, 1]]
v = v[k_range[0, 0] : k_range[0, 1]]
o = flash_attn_func(
q=q.unsqueeze(0),
k=k.unsqueeze(0),
v=v.unsqueeze(0),
deterministic=torch.are_deterministic_algorithms_enabled(),
).flatten(0, 1)
return o
def cross_attention(
self,
mixed_qqkv: torch.Tensor,
key_value_states: torch.Tensor,
cross_attn_params: PackedCrossAttnParams,
get_xqkv_func: Callable,
):
# =================
# cross-attn for aggregating caption / condition
# =================
query_xattn, key_xattn, value_xattn = get_xqkv_func(mixed_qqkv, key_value_states)
if torch.cuda.get_device_capability()[0] >= 9:
xattn_out, _ = flex_attention(
query_xattn,
key_xattn,
value_xattn,
cross_attn_params.q_ranges,
cross_attn_params.kv_ranges,
max_seqlen_q=cross_attn_params.max_seqlen_q,
max_seqlen_k=cross_attn_params.max_seqlen_kv,
softmax_scale=None,
deterministic=False,
disable_fwd_atomic_reduction=True,
)
else:
xattn_out = flash_attn_varlen_func(
query_xattn, # [b*sq, hn, hd]
key_xattn, # [y_total_token, hn, hd]
value_xattn, # [y_total_token, hn, hd]
cu_seqlens_q=cross_attn_params.cu_seqlens_q,
cu_seqlens_k=cross_attn_params.cu_seqlens_kv,
max_seqlen_q=cross_attn_params.max_seqlen_q,
max_seqlen_k=cross_attn_params.max_seqlen_kv,
deterministic=torch.are_deterministic_algorithms_enabled(),
)
batch_size = mixed_qqkv.shape[1]
xattn_out = rearrange(xattn_out, "(b sq) hn hd -> sq b (hn hd)", b=batch_size).contiguous()
return xattn_out
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: torch.Tensor,
inference_params: InferenceParams,
rotary_pos_emb: torch.Tensor,
meta_args: ModelMetaArgs,
):
assert rotary_pos_emb is not None, "FullyParallelAttention needs rotary_pos_emb"
sin_emb, cos_emb = rotary_pos_emb.tensor_split(2, -1)
batch_size = hidden_states.shape[1]
# All communications operate on dimensions shaped as (cp * sq * b)
batch_cp_split_sizes = None if meta_args.cp_split_sizes is None else [x * batch_size for x in meta_args.cp_split_sizes]
# Attention heads [sq, b, h] --> [sq, b, q + qx + k + v]
mixed_qqkv = self.linear_qkv.forward_ln(hidden_states)
# =====================
# Function wrapper
# =====================
get_kv_func = self.get_kv
get_q_func = self.get_q
get_qkv_func = self.get_qkv
get_xqkv_func = self.get_xqkv
# =====================
# Parallel Strategy
# =====================
if self.engine_config.cp_strategy == "none":
assert self.engine_config.cp_size == 1
key_and_value = get_kv_func(mixed_qqkv, cos_emb, sin_emb)
query = get_q_func(mixed_qqkv, cos_emb, sin_emb)
key, value = self.adjust_key_and_value_for_inference(key_and_value, inference_params, meta_args)
core_attn_out = self.core_attention(query, key, value, batch_size, meta_args)
core_attn_out = rearrange(core_attn_out, "(sq b) hn hd -> sq b (hn hd)", b=batch_size)
xattn_out = self.cross_attention(mixed_qqkv, key_value_states, meta_args.cross_attn_params, get_xqkv_func)
elif self.engine_config.cp_strategy == "cp_ulysses":
get_kv_func = partial(get_kv_func, mixed_qqkv, cos_emb, sin_emb)
get_q_func = partial(get_q_func, mixed_qqkv, cos_emb, sin_emb)
get_qkv_func = partial(get_qkv_func, mixed_qqkv, cos_emb, sin_emb)
kv_cache_func = partial(
self.adjust_key_and_value_for_inference, inference_params=inference_params, meta_args=meta_args
)
if meta_args.enable_cuda_graph and meta_args.denoising_range_num <= 3:
# Temporary solution for first-chunk optimization
core_attn_out, xattn_out = UlyssesScheduler.get_attn_and_xattn_with_fused_qkv_comm(
get_qkv_func,
kv_cache_func,
partial(self.core_attention, bs=batch_size, meta_args=meta_args),
partial(self.cross_attention, mixed_qqkv, key_value_states, meta_args.cross_attn_params, get_xqkv_func),
self.engine_config.ulysses_overlap_degree,
batch_size,
self.engine_config.cp_size,
batch_cp_split_sizes,
)
else:
core_attn_out, xattn_out = UlyssesScheduler.get_attn_and_xattn_with_fused_kv_comm(
get_q_func,
get_kv_func,
kv_cache_func,
partial(self.core_attention, bs=batch_size, meta_args=meta_args),
partial(self.cross_attention, mixed_qqkv, key_value_states, meta_args.cross_attn_params, get_xqkv_func),
self.engine_config.ulysses_overlap_degree,
batch_size,
self.engine_config.cp_size,
batch_cp_split_sizes,
)
elif self.engine_config.cp_strategy == "cp_shuffle_overlap":
key_and_value = self.get_kv(mixed_qqkv, cos_emb, sin_emb)
key_and_value, handle_kv = cso_communication(key_and_value, self.engine_config.cp_size, batch_cp_split_sizes, "kv")
query = get_q_func(mixed_qqkv, cos_emb, sin_emb)
cso_helper = CSOHelper(meta_args.denoising_range_num, self.engine_config.cp_size, batch_cp_split_sizes)
query, handle_q = cso_helper.split_query_for_overlap(query)
handle_kv.wait()
# NOTE(lml): rearrange and unpad key_and_value for the later attention computation under the cp_shuffle_overlap strategy; the fused sqb dimension should be split into sq and b once multi-batch_size input is supported.
key_and_value = (
rearrange(
key_and_value,
"(cp dn sqb) hn nhd -> dn (cp sqb) hn nhd",
dn=meta_args.denoising_range_num,
cp=self.engine_config.cp_size,
)[:, : meta_args.clip_token_nums]
.flatten(0, 1)
.contiguous()
)
key, value = self.adjust_key_and_value_for_inference(key_and_value, inference_params, meta_args)
handle_q.wait()
core_attn_out, handle_attn = cso_helper.overlap(
partial(self.full_attention, hidden_states.shape[1], meta_args), query, key, value
)
xattn_out = self.cross_attention(mixed_qqkv, key_value_states, meta_args.cross_attn_params, get_xqkv_func)
handle_attn.wait()
core_attn_out = rearrange(
torch.concat(core_attn_out, dim=0),
"(dn cp sq b) hn hd -> (dn sq) b (cp hn hd)",
cp=self.engine_config.cp_size,
b=hidden_states.shape[1],
dn=meta_args.denoising_range_num,
)
else:
raise ValueError(f"Unsupported cp_strategy: {self.engine_config.cp_strategy}")
return core_attn_out, xattn_out
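# Summary of the parallel strategies above (descriptive, not in the original source):
# - "none":              single rank; compute q/kv, update the kv cache, then run
#                         self-attention and cross-attention sequentially.
# - "cp_ulysses":         Ulysses-style context parallelism; the q/kv projections and the
#                         collective communication are overlapped by UlyssesScheduler
#                         (fused qkv-comm path when CUDA graphs are on and
#                         denoising_range_num <= 3).
# - "cp_shuffle_overlap": kv is exchanged via cso_communication while the query is split
#                         per denoising range; attention and cross-attention overlap with
#                         the outstanding communication handles.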
##########################################################
# TransformerLayer
##########################################################
class TransformerLayer(torch.nn.Module):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(self, model_config: ModelConfig, engine_config: EngineConfig, layer_number: int = 1):
super().__init__()
self.model_config = model_config
self.engine_config = engine_config
self.layer_number = layer_number + self._get_layer_offset()
## [Module 1: AdaModulateLayer]
self.ada_modulate_layer = AdaModulateLayer(model_config=self.model_config)
## [Module 2: SelfAttention]
self.self_attention = FullyParallelAttention(
model_config=self.model_config, engine_config=self.engine_config, layer_number=self.layer_number
)
## [Module 3: SelfAttention PostNorm]
self.self_attn_post_norm = FusedLayerNorm(model_config=self.model_config, hidden_size=self.model_config.hidden_size)
## [Module 4: MLP block]
self.mlp = CustomMLP(model_config=self.model_config, engine_config=self.engine_config, layer_number=self.layer_number)
## [Module 5: MLP PostNorm]
self.mlp_post_norm = FusedLayerNorm(model_config=self.model_config, hidden_size=self.model_config.hidden_size)
def _get_layer_offset(self):
pipeline_rank = parallel_state.get_pp_rank()
num_layers_per_pipeline_rank = self.model_config.num_layers // parallel_state.get_pp_world_size()
# Each stage gets a contiguous set of layers.
if parallel_state.get_pp_world_size() > 1:
offset = pipeline_rank * num_layers_per_pipeline_rank
else:
offset = 0
return offset
def forward(
self,
hidden_states: torch.Tensor,
condition: torch.Tensor,
condition_map: torch.Tensor,
y_xattn_flat: torch.Tensor,
rotary_pos_emb: torch.Tensor,
inference_params: InferenceParams,
meta_args: ModelMetaArgs,
):
# hidden_states: [s/cp/sp, b, h]
residual = hidden_states
# Self attention.
core_attn_out, cross_attn_out = self.self_attention(
hidden_states,
key_value_states=y_xattn_flat,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
meta_args=meta_args,
)
hidden_states = self.attn_post_process(core_attn_out, cross_attn_out, residual, condition, condition_map)
return hidden_states
def attn_post_process(
self,
core_attn_out: torch.Tensor,
cross_attn_out: torch.Tensor,
residual: torch.Tensor,
condition: torch.Tensor,
condition_map: torch.Tensor,
):
hidden_states = self.attn_linear_proj(core_attn_out, cross_attn_out)
hidden_states = self.gating_and_mlp(hidden_states, residual, condition, condition_map)
return hidden_states
def attn_linear_proj(self, core_attn_out: torch.Tensor, cross_attn_out: torch.Tensor):
# ============================================
# attention post-process, output: [sq, b, h]
# ============================================
attn_out = torch.concat([core_attn_out, cross_attn_out], dim=2)
# NOTE: hn=8 is hardcoded to align with TP8 training and TP1 inference
attn_out = rearrange(attn_out, "sq b (n hn hd) -> sq b (hn n hd)", n=2, hn=8)
if self.self_attention.adapt_linear_quant:
attn_out = self.self_attention.linear_proj(attn_out)
else:
# Use high-precision for non-quantized linear projection
with torch.autocast(device_type="cuda", dtype=torch.float32):
attn_out = self.self_attention.linear_proj(attn_out)
return attn_out
def gating_and_mlp(
self, hidden_states: torch.Tensor, residual: torch.Tensor, condition: torch.Tensor, condition_map: torch.Tensor
):
gate_output = self.ada_modulate_layer(condition)
softcap_gate_cap = 1.0
gate_output = softcap(gate_output, softcap_gate_cap)
gate_msa, gate_mlp = gate_output.chunk(2, dim=-1)
# Residual connection for self-attention.
hidden_states = bias_modulate_add(hidden_states, residual, condition_map, gate_msa, self.self_attn_post_norm).to(
self.model_config.params_dtype
)
residual = hidden_states
hidden_states = self.mlp(hidden_states)
# Residual connection for MLP.
hidden_states = bias_modulate_add(hidden_states, residual, condition_map, gate_mlp, self.mlp_post_norm).to(
self.model_config.params_dtype
)
return hidden_states
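# Per-layer dataflow (descriptive, not in the original source):
#   residual -> self_attention (core attn + cross attn) -> attn_linear_proj
#   -> softcapped AdaLN gates (gate_msa / gate_mlp) -> bias_modulate_add with
#   self_attn_post_norm -> CustomMLP -> bias_modulate_add with mlp_post_norm.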
##########################################################
# TransformerBlock
##########################################################
class TransformerBlock(torch.nn.Module):
"""Transformer class."""
def __init__(
self, model_config: ModelConfig, engine_config: EngineConfig, pre_process: bool = True, post_process: bool = True
):
super().__init__()
self.model_config = model_config
self.engine_config = engine_config
self.pre_process = pre_process
self.post_process = post_process
# required for pipeline parallel schedules
self.input_tensor = None
layer_number = self.model_config.num_layers // parallel_state.get_pp_world_size()
# offset is implicit in TransformerLayer
self.layers = torch.nn.ModuleList(
[
TransformerLayer(model_config=self.model_config, engine_config=self.engine_config, layer_number=i)
for i in range(layer_number)
]
)
if self.post_process:
# Final layer norm before output.
self.final_layernorm = FusedLayerNorm(model_config=self.model_config, hidden_size=self.model_config.hidden_size)
def set_input_tensor(self, input_tensor: Tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
@torch.no_grad()
def forward(
self,
hidden_states: Tensor,
condition: Tensor,
condition_map: Tensor,
y_xattn_flat: Tensor,
rotary_pos_emb: Tensor,
inference_params: InferenceParams,
meta_args: ModelMetaArgs,
) -> torch.Tensor:
if not self.pre_process:
assert self.input_tensor is not None, "please call set_input_tensor for pp"
hidden_states = self.input_tensor
for layer in self.layers:
hidden_states = layer(
hidden_states=hidden_states,
condition=condition,
condition_map=condition_map,
y_xattn_flat=y_xattn_flat,
rotary_pos_emb=rotary_pos_emb,
inference_params=inference_params,
meta_args=meta_args,
)
# Final layer norm.
if self.post_process:
hidden_states = self.final_layernorm(hidden_states.float())
return hidden_states
```
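To make the numerics of the gating helpers above concrete, here is a minimal, standalone sketch (not part of the repository) that re-implements `softcap` and the `zero_centered_gamma` layer-norm convention in plain PyTorch and checks them on random data; `softcap_ref`, `layer_norm_1p`, and the chosen sizes are illustrative names and values, not code from this repo.

```python
# Standalone sketch (illustrative only, not part of the SandAI repository).
import torch


def softcap_ref(x: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly squash activations into (-cap, cap), mirroring softcap() above.
    return (cap * torch.tanh(x.float() / cap)).to(x.dtype)


def layer_norm_1p(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # zero_centered_gamma convention: the stored weight is (gamma - 1),
    # so add 1 back before the standard layer norm.
    return torch.nn.functional.layer_norm(x, (x.shape[-1],), weight + 1, bias, eps)


if __name__ == "__main__":
    hidden = 16
    x = torch.randn(4, hidden)

    # softcap is bounded by +/- cap and close to the identity for small inputs.
    assert softcap_ref(100 * x, cap=1.0).abs().max() <= 1.0

    # With a zero-initialized (gamma - 1) weight and zero bias, the 1p layer norm
    # reduces to a plain, unit-scale layer norm.
    w = torch.zeros(hidden)
    b = torch.zeros(hidden)
    expected = torch.nn.functional.layer_norm(x, (hidden,), None, None, 1e-6)
    assert torch.allclose(layer_norm_1p(x, w, b), expected, atol=1e-6)
    print("softcap and zero-centered-gamma sketches behave as expected")
```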