```
├── .gitattributes (omitted)
├── .github/
├── workflows/
├── pylint.yml (100 tokens)
├── .gitignore (600 tokens)
├── .pre-commit-config.yaml (100 tokens)
├── LICENSE (omitted)
├── README.md (5.3k tokens)
├── configs/
├── ltxv-13b-0.9.8-dev-fp8.yaml (300 tokens)
├── ltxv-13b-0.9.8-dev.yaml (300 tokens)
├── ltxv-13b-0.9.8-distilled-fp8.yaml (200 tokens)
├── ltxv-13b-0.9.8-distilled.yaml (200 tokens)
├── ltxv-2b-0.9.1.yaml (100 tokens)
├── ltxv-2b-0.9.5.yaml (100 tokens)
├── ltxv-2b-0.9.6-dev.yaml (100 tokens)
├── ltxv-2b-0.9.6-distilled.yaml (100 tokens)
├── ltxv-2b-0.9.8-distilled-fp8.yaml (200 tokens)
├── ltxv-2b-0.9.8-distilled.yaml (200 tokens)
├── ltxv-2b-0.9.yaml (100 tokens)
├── docs/
├── _static/
├── ltx-video_example_00001.gif
├── ltx-video_example_00005.gif
├── ltx-video_example_00006.gif
├── ltx-video_example_00007.gif
├── ltx-video_example_00010.gif
├── ltx-video_example_00011.gif
├── ltx-video_example_00013.gif
├── ltx-video_example_00014.gif
├── ltx-video_example_00015.gif
├── ltx-video_i2v_example_00001.gif
├── ltx-video_i2v_example_00002.gif
├── ltx-video_i2v_example_00003.gif
├── ltx-video_i2v_example_00004.gif
├── ltx-video_i2v_example_00005.gif
├── ltx-video_i2v_example_00006.gif
├── ltx-video_i2v_example_00007.gif
├── ltx-video_i2v_example_00008.gif
├── ltx-video_i2v_example_00009.gif
├── ltx-video_ic_2v_example_00000.gif
├── ltx-video_ic_2v_example_00001.gif
├── ltx-video_ic_2v_example_00002.gif
├── ltx-video_ic_2v_example_00003.gif
├── ltx-video_ic_2v_example_00004.gif
├── inference.py (100 tokens)
├── ltx_video/
├── __init__.py
├── inference.py (5.3k tokens)
├── models/
├── __init__.py
├── autoencoders/
├── __init__.py
├── causal_conv3d.py (400 tokens)
├── causal_video_autoencoder.py (10.5k tokens)
├── conv_nd_factory.py (500 tokens)
├── dual_conv3d.py (1400 tokens)
├── latent_upsampler.py (1400 tokens)
├── pixel_norm.py (100 tokens)
├── pixel_shuffle.py (200 tokens)
├── vae.py (2.9k tokens)
├── vae_encode.py (1800 tokens)
├── video_autoencoder.py (7.1k tokens)
├── transformers/
├── __init__.py
├── attention.py (10.5k tokens)
├── embeddings.py (900 tokens)
├── symmetric_patchifier.py (600 tokens)
├── transformer3d.py (4.5k tokens)
├── pipelines/
├── __init__.py
├── crf_compressor.py (300 tokens)
├── pipeline_ltx_video.py (16.7k tokens)
├── schedulers/
├── __init__.py
├── rf.py (3.1k tokens)
├── utils/
├── __init__.py
├── diffusers_config_mapping.py (1100 tokens)
├── prompt_enhance_utils.py (1500 tokens)
├── skip_layer_strategy.py
├── torch_utils.py (200 tokens)
├── pyproject.toml (200 tokens)
├── tests/
├── conftest.py (600 tokens)
├── test_configs.py (200 tokens)
├── test_inference.py (1800 tokens)
├── test_scheduler.py (700 tokens)
├── test_vae.py (600 tokens)
├── utils/
├── .gitattributes
├── woman.jpeg
├── woman.mp4
```
## /.github/workflows/pylint.yml
```yml path="/.github/workflows/pylint.yml"
name: Ruff
on: [push]
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
steps:
- name: Checkout repository and submodules
uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ruff==0.2.2 black==24.2.0
- name: Analyzing the code with ruff
run: |
ruff $(git ls-files '*.py')
- name: Verify that no Black changes are required
run: |
black --check $(git ls-files '*.py')
```
## /.gitignore
```gitignore path="/.gitignore"
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# From inference.py
outputs/
*.mp4
*.png
!tests/utils/car.png
```
## /.pre-commit-config.yaml
```yaml path="/.pre-commit-config.yaml"
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.2.2
hooks:
# Run the linter.
- id: ruff
args: [--fix] # Automatically fix issues if possible.
types: [python] # Ensure it only runs on .py files.
- repo: https://github.com/psf/black
rev: 24.2.0 # Specify the version of Black you want
hooks:
- id: black
name: Black code formatter
language_version: python3 # Use the Python version you're targeting (e.g., 3.10)
```
## /README.md
<div align="center">
# LTX-Video
[](https://ltx.video)
[](https://huggingface.co/Lightricks/LTX-Video)
[](https://app.ltx.studio/ltx-2-playground/t2v)
[](https://arxiv.org/abs/2501.00103)
[](https://github.com/Lightricks/LTX-Video-Trainer)
[](https://discord.gg/ltxplatform)
This is the official repository for LTX-Video.
</div>
---
## 🚀 **New: LTX-2 is Now Available!**
**We're excited to announce [LTX-2](https://github.com/Lightricks/LTX-2) - the next generation of LTX with synchronized audio+video generation!**
LTX-2 is the first DiT-based audio-video foundation model that contains all core capabilities of modern video generation in one model. **LTX-2 is now the primary home for LTX development** and includes significant improvements:
- 🎵 **Synchronized Audio+Video Generation** - Generate videos with perfectly synchronized audio
- 🎬 **Latest Model** - LTX-2 with improved quality and capabilities
- 🔌 **ComfyUI Integration** - Built into ComfyUI core for seamless workflows
- 🎯 **Advanced Features:**
- Multiple keyframe support
- IC-LoRA control models for precise generation
- Standard LoRA support for style customization
- Latent upsampler for multiscale pipelines
- 🛠️ **Training Tools** - LoRA training capabilities
- 📚 **Comprehensive Documentation** - Full documentation at [https://docs.ltx.video](https://docs.ltx.video)
- 🔄 **Active Development** - Ongoing improvements and community support
**[👉 Check out LTX-2 here](https://github.com/Lightricks/LTX-2)**
**[📖 View Documentation](https://docs.ltx.video)**
---
## Table of Contents
- [Introduction](#introduction)
- [What's New](#news)
- [Models](#models)
- [Quick Start Guide](#quick-start-guide)
- [Online demo](#online-inference)
- [Run locally](#run-locally)
- [Installation](#installation)
- [Inference](#inference)
- [ComfyUI Integration](#comfyui-integration)
- [Diffusers Integration](#diffusers-integration)
- [Model User Guide](#model-user-guide)
- [Community Contribution](#community-contribution)
- [Training](#training)
- [Control Models](#control-models)
- [Join Us!](#join-us)
- [Acknowledgement](#acknowledgement)
# Introduction
LTX-Video is the first DiT-based video generation model that can generate high-quality videos in real time. It can generate 30 FPS videos at 1216×704 resolution, faster than it takes to watch them.
The model is trained on a large-scale dataset of diverse videos and can generate high-resolution videos with realistic and diverse content.
The model supports image-to-video, multi-keyframe conditioning, keyframe-based animation, video extension (both forward and backward), video-to-video transformations, and any combination of these features.
### Image-to-video examples
| | | |
|:---:|:---:|:---:|
|  |  |  |
|  |  |  |
|  |  |  |
### Controlled video examples
| | | |
|:---:|:---:|:---:|
|  |  |  |
| | |
|:---:|:---:|
|  |  |
# News
## October 23, 2025: LTX-2 Announced
Today we announced our newest foundation model, LTX-2. LTX-2 represents a major leap forward from our previous model, LTXV 0.9.8. Here’s what’s new:
* **Audio + Video, Together**: Visuals and sound are generated in one coherent process, with motion, dialogue, ambience, and music flowing simultaneously.
* **4K Fidelity**: Professional-grade precision with native 4K and up to 50 fps, sharp textures, clean motion, and synchronized audio.
* **Longer Generations**: LTX-2 supports longer, continuous clips of up to 10 seconds with synchronized audio.
* **Low Cost & Efficiency**: Up to 50% lower compute cost than competing models, powered by a multi-GPU inference stack.
* **Creative Control**: Multi-keyframe conditioning, 3D camera logic, and LoRA fine-tuning deliver frame-level precision and style consistency.
For more details, please see our [blog post](https://website.ltx.video/blog/introducing-ltx-2). LTX-2 model weights, code, and benchmarks will be released to the community later in 2025.
## July, 16th, 2025: New Distilled models v0.9.8 with up to 60 seconds of video:
- Long shot generation in LTXV-13B!
* LTX-Video now supports up to 60 seconds of video.
  * Also compatible with the official IC-LoRAs.
* Try now in [ComfyUI](https://github.com/Lightricks/ComfyUI-LTXVideo/tree/master/example_workflows/ltxv-13b-i2v-long-multi-prompt.json).
- Release new distilled models:
* 13B distilled model [ltxv-13b-0.9.8-distilled](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-distilled.yaml)
* 2B distilled model [ltxv-2b-0.9.8-distilled](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.8-distilled.yaml)
* Both models are distilled from the same base model [ltxv-13b-0.9.8-dev](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-dev.yaml) and are compatible for use together in the same multiscale pipeline.
* Improved prompt understanding and detail generation
* Includes corresponding FP8 weights and workflows.
- Release a new detailer model [LTX-Video-ICLoRA-detailer-13B-0.9.8](https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8)
* Available in [ComfyUI](https://github.com/Lightricks/ComfyUI-LTXVideo/tree/master/example_workflows/ltxv-13b-upscale.json).
## July, 8th, 2025: New Control Models Released!
- Released three new control models for LTX-Video on HuggingFace:
* **Depth Control**: [LTX-Video-ICLoRA-depth-13b-0.9.7](https://huggingface.co/Lightricks/LTX-Video-ICLoRA-depth-13b-0.9.7)
* **Pose Control**: [LTX-Video-ICLoRA-pose-13b-0.9.7](https://huggingface.co/Lightricks/LTX-Video-ICLoRA-pose-13b-0.9.7)
* **Canny Control**: [LTX-Video-ICLoRA-canny-13b-0.9.7](https://huggingface.co/Lightricks/LTX-Video-ICLoRA-canny-13b-0.9.7)
## May, 14th, 2025: New distilled model 13B v0.9.7:
- Release a new 13B distilled model [ltxv-13b-0.9.7-distilled](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled.safetensors)
* Amazing for iterative work - generates HD videos in 10 seconds, with low-res preview after just 3 seconds (on H100)!
* Does not require classifier-free guidance and spatio-temporal guidance.
  * Supports sampling with 8 (recommended) or fewer diffusion steps.
* Also released a LoRA version of the distilled model, [ltxv-13b-0.9.7-distilled-lora128](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled-lora128.safetensors)
* Requires only 1GB of VRAM
* Can be used with the full 13B model for fast inference
- Release a new quantized distilled model [ltxv-13b-0.9.7-distilled-fp8](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled-fp8.safetensors) for *real-time* generation (on H100) with even less VRAM
## May, 5th, 2025: New model 13B v0.9.7:
- Release a new 13B model [ltxv-13b-0.9.7-dev](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev.safetensors)
- Release a new quantized model [ltxv-13b-0.9.7-dev-fp8](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev-fp8.safetensors) for faster inference with less VRAM
- Release new upscalers:
* [ltxv-temporal-upscaler-0.9.7](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-temporal-upscaler-0.9.7.safetensors)
* [ltxv-spatial-upscaler-0.9.7](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-spatial-upscaler-0.9.7.safetensors)
- Breakthrough prompt adherence and physical understanding.
- New Pipeline for multi-scale video rendering for fast and high quality results
## April, 15th, 2025: New checkpoints v0.9.6:
- Release a new checkpoint [ltxv-2b-0.9.6-dev-04-25](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-2b-0.9.6-dev-04-25.safetensors) with improved quality
- Release a new distilled model [ltxv-2b-0.9.6-distilled-04-25](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-2b-0.9.6-distilled-04-25.safetensors)
* 15x faster inference than non-distilled model.
* Does not require classifier-free guidance and spatio-temporal guidance.
  * Supports sampling with 8 (recommended) or fewer diffusion steps.
- Improved prompt adherence, motion quality and fine details.
- New default resolution and FPS: 1216 × 704 pixels at 30 FPS
* Still real time on H100 with the distilled model.
* Other resolutions and FPS are still supported.
- Support stochastic inference (can improve visual quality when using the distilled model)
## March, 5th, 2025: New checkpoint v0.9.5
- New license for commercial use ([OpenRail-M](https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.5.license.txt))
- Release a new checkpoint v0.9.5 with improved quality
- Support keyframes and video extension
- Support higher resolutions
- Improved prompt understanding
- Improved VAE
- New online web app in [LTX-Studio](https://app.ltx.studio/ltx-video)
- Automatic prompt enhancement
## February, 20th, 2025: More inference options
- Improve STG (Spatiotemporal Guidance) for LTX-Video
- Support MPS on macOS with PyTorch 2.3.0
- Add support for 8-bit model, LTX-VideoQ8
- Add TeaCache for LTX-Video
- Add [ComfyUI-LTXTricks](#comfyui-integration)
- Add Diffusion-Pipe
## December 31st, 2024: Research paper
- Release the [research paper](https://arxiv.org/abs/2501.00103)
## December 20th, 2024: New checkpoint v0.9.1
- Release a new checkpoint v0.9.1 with improved quality
- Support for STG / PAG
- Support loading checkpoints of LTX-Video in Diffusers format (conversion is done on-the-fly)
- Support offloading unused parts to CPU
- Support the new timestep-conditioned VAE decoder
- Reference contributions from the community in the readme file
- Relax transformers dependency
## November 21st, 2024: Initial release v0.9.0
- Initial release of LTX-Video
- Support text-to-video and image-to-video generation
# Models
| Name | Notes | inference.py config | ComfyUI workflow (Recommended) |
|-------------------------|--------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------|
| ltxv-13b-0.9.8-dev | Highest quality, requires more VRAM | [ltxv-13b-0.9.8-dev.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-dev.yaml) | [ltxv-13b-i2v-base.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ltxv-13b-i2v-base.json) |
| [ltxv-13b-0.9.8-mix](https://app.ltx.studio/motion-workspace?videoModel=ltxv-13b) | Mix ltxv-13b-dev and ltxv-13b-distilled in the same multi-scale rendering workflow for balanced speed-quality | N/A | [ltxv-13b-i2v-mixed-multiscale.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ltxv-13b-i2v-mixed-multiscale.json) |
| [ltxv-13b-0.9.8-distilled](https://app.ltx.studio/motion-workspace?videoModel=ltxv) | Faster, less VRAM usage, slight quality reduction compared to 13b. Ideal for rapid iterations | [ltxv-13b-0.9.8-distilled.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-distilled.yaml) | [ltxv-13b-dist-i2v-base.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/13b-distilled/ltxv-13b-dist-i2v-base.json) |
| ltxv-2b-0.9.8-distilled | Smaller model, slight quality reduction compared to 13b distilled. Ideal for fast generation with light VRAM usage | [ltxv-2b-0.9.8-distilled.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.8-distilled.yaml) | N/A |
| ltxv-13b-0.9.8-dev-fp8 | Quantized version of ltxv-13b | [ltxv-13b-0.9.8-dev-fp8.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-dev-fp8.yaml) | [ltxv-13b-i2v-base-fp8.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ltxv-13b-i2v-base-fp8.json) |
| ltxv-13b-0.9.8-distilled-fp8 | Quantized version of ltxv-13b-distilled | [ltxv-13b-0.9.8-distilled-fp8.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-distilled-fp8.yaml) | [ltxv-13b-dist-i2v-base-fp8.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/13b-distilled/ltxv-13b-dist-i2v-base-fp8.json) |
| ltxv-2b-0.9.8-distilled-fp8 | Quantized version of ltxv-2b-distilled | [ltxv-2b-0.9.8-distilled-fp8.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.8-distilled-fp8.yaml) | N/A |
| ltxv-2b-0.9.6 | Good quality, lower VRAM requirement than ltxv-13b | [ltxv-2b-0.9.6-dev.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.6-dev.yaml) | [ltxvideo-i2v.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/low_level/ltxvideo-i2v.json) |
| ltxv-2b-0.9.6-distilled | 15× faster, real-time capable, fewer steps needed, no STG/CFG required | [ltxv-2b-0.9.6-distilled.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.6-distilled.yaml) | [ltxvideo-i2v-distilled.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/low_level/ltxvideo-i2v-distilled.json) |
# Quick Start Guide
## Online inference
The model is accessible right away via the following links:
- [LTX-Studio image-to-video (13B-mix)](https://app.ltx.studio/motion-workspace?videoModel=ltxv-13b)
- [LTX-Studio image-to-video (13B distilled)](https://app.ltx.studio/motion-workspace?videoModel=ltxv)
- [Fal.ai image-to-video (13B full)](https://fal.ai/models/fal-ai/ltx-video-13b-dev/image-to-video)
- [Fal.ai image-to-video (13B distilled)](https://fal.ai/models/fal-ai/ltx-video-13b-distilled/image-to-video)
- [Replicate image-to-video](https://replicate.com/lightricks/ltx-video)
## Run locally
### Installation
The codebase was tested with Python 3.10.5, CUDA version 12.2, and supports PyTorch >= 2.1.2.
On macOS, MPS was tested with PyTorch 2.3.0, and should support PyTorch == 2.3 or >= 2.6.
```bash
git clone https://github.com/Lightricks/LTX-Video.git
cd LTX-Video
# create env
python -m venv env
source env/bin/activate
python -m pip install -e .\[inference\]
```
#### FP8 Kernels (optional)
[FP8 kernels](https://github.com/Lightricks/LTXVideo-Q8-Kernels) developed for LTX-Video provide performance boost on supported graphics cards (Ada architecture and later). To install FP8 kernels, follow the instructions in that repository.
### Inference
📝 **Note:** For best results, we recommend using our [ComfyUI](#comfyui-integration) workflow. We're working on updating the inference.py script to match the high quality and output fidelity of ComfyUI.
To use our model, please follow the inference code in [inference.py](./inference.py):
#### For image-to-video generation:
```bash
python inference.py --prompt "PROMPT" --conditioning_media_paths IMAGE_PATH --conditioning_start_frames 0 --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.8-distilled.yaml
```
#### Extending a video:
📝 **Note:** Input video segments must contain a multiple of 8 frames plus 1 (e.g., 9, 17, 25, etc.), and the target frame number should be a multiple of 8.
```bash
python inference.py --prompt "PROMPT" --conditioning_media_paths VIDEO_PATH --conditioning_start_frames START_FRAME --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.8-distilled.yaml
```
#### For video generation with multiple conditions:
You can now generate a video conditioned on a set of images and/or short video segments.
Simply provide a list of paths to the images or video segments you want to condition on, along with their target frame numbers in the generated video. You can also specify the conditioning strength for each item (default: 1.0).
```bash
python inference.py --prompt "PROMPT" --conditioning_media_paths IMAGE_OR_VIDEO_PATH_1 IMAGE_OR_VIDEO_PATH_2 --conditioning_start_frames TARGET_FRAME_1 TARGET_FRAME_2 --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.8-distilled.yaml
```
### Using as a library
```python
from ltx_video.inference import infer, InferenceConfig
infer(
InferenceConfig(
pipeline_config="configs/ltxv-13b-0.9.8-distilled.yaml",
prompt=PROMPT,
height=HEIGHT,
width=WIDTH,
num_frames=NUM_FRAMES,
output_path="output.mp4",
)
)
```
## ComfyUI Integration
To use our model with ComfyUI, please follow the instructions at [https://github.com/Lightricks/ComfyUI-LTXVideo/](https://github.com/Lightricks/ComfyUI-LTXVideo/).
## Diffusers Integration
To use our model with the Diffusers Python library, check out the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).
Diffusers also supports an 8-bit version of LTX-Video; [see details below](#ltx-videoq8).
# Model User Guide
## 📝 Prompt Engineering
When writing prompts, focus on detailed, chronological descriptions of actions and scenes. Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. Start directly with the action, and keep descriptions literal and precise. Think like a cinematographer describing a shot list. Keep within 200 words. For best results, build your prompts using this structure:
* Start with main action in a single sentence
* Add specific details about movements and gestures
* Describe character/object appearances precisely
* Include background and environment details
* Specify camera angles and movements
* Describe lighting and colors
* Note any changes or sudden events
* See [examples](#introduction) for more inspiration.
### Automatic Prompt Enhancement
When using `LTXVideoPipeline` directly, you can enable prompt enhancement by setting `enhance_prompt=True`.
## 🎮 Parameter Guide
* Resolution Preset: Higher resolutions for detailed scenes, lower for faster generation and simpler scenes. The model works on resolutions that are divisible by 32 and on frame counts of the form 8k + 1 (e.g. 257). If the resolution or number of frames does not meet these constraints, the input will be padded with -1 and then cropped to the desired resolution and number of frames. The model works best at resolutions under 720 x 1280 and with fewer than 257 frames.
* Seed: Save seed values to recreate specific styles or compositions you like
* Guidance Scale: 3-3.5 are the recommended values
* Inference Steps: More steps (40+) for quality, fewer steps (20-30) for speed
📝 For advanced parameters usage, please see `python inference.py --help`
## Community Contribution
### ComfyUI-LTXTricks 🛠️
A community project providing additional nodes for enhanced control over the LTX Video model. It includes implementations of advanced techniques like RF-Inversion, RF-Edit, FlowEdit, and more. These nodes enable workflows such as Image and Video to Video (I+V2V), enhanced sampling via Spatiotemporal Skip Guidance (STG), and interpolation with precise frame settings.
- **Repository:** [ComfyUI-LTXTricks](https://github.com/logtd/ComfyUI-LTXTricks)
- **Features:**
- 🔄 **RF-Inversion:** Implements [RF-Inversion](https://rf-inversion.github.io/) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_inversion.json).
- ✂️ **RF-Edit:** Implements [RF-Solver-Edit](https://github.com/wangjiangshan0725/RF-Solver-Edit) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_rf_edit.json).
- 🌊 **FlowEdit:** Implements [FlowEdit](https://github.com/fallenshock/FlowEdit) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_flow_edit.json).
- 🎥 **I+V2V:** Enables Video to Video with a reference image. [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_iv2v.json).
- ✨ **Enhance:** Partial implementation of [STGuidance](https://junhahyung.github.io/STGuidance/). [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltxv_stg.json).
- 🖼️ **Interpolation and Frame Setting:** Nodes for precise control of latents per frame. [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_interpolation.json).
### LTX-VideoQ8 🎱 <a id="ltx-videoq8"></a>
**LTX-VideoQ8** is an 8-bit optimized version of [LTX-Video](https://github.com/Lightricks/LTX-Video), designed for faster performance on NVIDIA Ada GPUs.
- **Repository:** [LTX-VideoQ8](https://github.com/KONAKONA666/LTX-Video)
- **Features:**
- 🚀 Up to 3X speed-up with no accuracy loss
- 🎥 Generate 720x480x121 videos in under a minute on RTX 4060 (8GB VRAM)
- 🛠️ Fine-tune 2B transformer models with precalculated latents
- **Community Discussion:** [Reddit Thread](https://www.reddit.com/r/StableDiffusion/comments/1h79ks2/fast_ltx_video_on_rtx_4060_and_other_ada_gpus/)
- **Diffusers integration:** A diffusers integration for the 8-bit model is already out! [Details here](https://github.com/sayakpaul/q8-ltx-video)
### TeaCache for LTX-Video 🍵 <a id="TeaCache"></a>
**TeaCache** is a training-free caching approach that leverages timestep differences across model outputs to accelerate LTX-Video inference by up to 2x without significant visual quality degradation.
- **Repository:** [TeaCache4LTX-Video](https://github.com/ali-vilab/TeaCache/tree/main/TeaCache4LTX-Video)
- **Features:**
- 🚀 Speeds up LTX-Video inference.
- 📊 Adjustable trade-offs between speed (up to 2x) and visual quality using configurable parameters.
- 🛠️ No retraining required: Works directly with existing models.
### Your Contribution
...is welcome! If you have a project or tool that integrates with LTX-Video,
please let us know by opening an issue or pull request.
# Training
We provide an open-source repository for fine-tuning the LTX-Video model: [LTX-Video-Trainer](https://github.com/Lightricks/LTX-Video-Trainer).
This repository supports both the 2B and 13B model variants, enabling full fine-tuning as well as LoRA (Low-Rank Adaptation) fine-tuning for more efficient training. This includes:
- **Control LoRAs**: Train custom control models like depth, pose, and canny control
- **Effect LoRAs**: Create specialized effects and transformations for video generation
Explore the repository to customize the model for your specific use cases!
More information and training instructions can be found in the [README](https://github.com/Lightricks/LTX-Video-Trainer/blob/main/README.md).
# Control Models
[ComfyUI-LTXVideo](https://github.com/Lightricks/ComfyUI-LTXVideo) repository now contains workflows and models for 3 specialized models that enable precise control over LTX-Video generation:
Pose Control, Depth Control and Canny Control
**Example ComfyUI Workflow (for all control types):** [ic-lora.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ic_lora/ic-lora.json)
# Join Us
Want to work on cutting-edge AI research and make a real impact on millions of users worldwide?
At **Lightricks**, an AI-first company, we're revolutionizing how visual content is created.
If you are passionate about AI, computer vision, and video generation, we would love to hear from you!
Please visit our [careers page](https://careers.lightricks.com/careers?query=&office=all&department=R%26D) for more information.
# Acknowledgement
We are grateful for the following awesome projects when implementing LTX-Video:
* [DiT](https://github.com/facebookresearch/DiT) and [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha): vision transformers for image generation.
## Citation
📄 Our tech report is out! If you find our work helpful, please ⭐️ star the repository and cite our paper.
```
@article{HaCohen2024LTXVideo,
title={LTX-Video: Realtime Video Latent Diffusion},
author={HaCohen, Yoav and Chiprut, Nisan and Brazowski, Benny and Shalem, Daniel and Moshe, Dudu and Richardson, Eitan and Levin, Eran and Shiran, Guy and Zabari, Nir and Gordon, Ori and Panet, Poriya and Weissbuch, Sapir and Kulikov, Victor and Bitterman, Yaki and Melumian, Zeev and Bibi, Ofir},
journal={arXiv preprint arXiv:2501.00103},
year={2024}
}
```
## /configs/ltxv-13b-0.9.8-dev-fp8.yaml
```yaml path="/configs/ltxv-13b-0.9.8-dev-fp8.yaml"
pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-dev-fp8.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
first_pass:
guidance_scale: [1, 1, 6, 8, 6, 1, 1]
stg_scale: [0, 0, 4, 4, 4, 2, 1]
rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
num_inference_steps: 30
skip_final_inference_steps: 3
cfg_star_rescale: true
second_pass:
guidance_scale: [1]
stg_scale: [1]
rescaling_scale: [1]
guidance_timesteps: [1.0]
skip_block_list: [27]
num_inference_steps: 30
skip_initial_inference_steps: 17
cfg_star_rescale: true
```
## /configs/ltxv-13b-0.9.8-dev.yaml
```yaml path="/configs/ltxv-13b-0.9.8-dev.yaml"
pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-dev.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
first_pass:
guidance_scale: [1, 1, 6, 8, 6, 1, 1]
stg_scale: [0, 0, 4, 4, 4, 2, 1]
rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
num_inference_steps: 30
skip_final_inference_steps: 3
cfg_star_rescale: true
second_pass:
guidance_scale: [1]
stg_scale: [1]
rescaling_scale: [1]
guidance_timesteps: [1.0]
skip_block_list: [27]
num_inference_steps: 30
skip_initial_inference_steps: 17
cfg_star_rescale: true
```
## /configs/ltxv-13b-0.9.8-distilled-fp8.yaml
```yaml path="/configs/ltxv-13b-0.9.8-distilled-fp8.yaml"
pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-distilled-fp8.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
first_pass:
timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
skip_block_list: [42]
second_pass:
timesteps: [0.9094, 0.7250, 0.4219]
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
skip_block_list: [42]
tone_map_compression_ratio: 0.6
```
## /configs/ltxv-13b-0.9.8-distilled.yaml
```yaml path="/configs/ltxv-13b-0.9.8-distilled.yaml"
pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-distilled.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
first_pass:
timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
skip_block_list: [42]
second_pass:
timesteps: [0.9094, 0.7250, 0.4219]
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
skip_block_list: [42]
tone_map_compression_ratio: 0.6
```
## /configs/ltxv-2b-0.9.1.yaml
```yaml path="/configs/ltxv-2b-0.9.1.yaml"
pipeline_type: base
checkpoint_path: "ltx-video-2b-v0.9.1.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
```
## /configs/ltxv-2b-0.9.5.yaml
```yaml path="/configs/ltxv-2b-0.9.5.yaml"
pipeline_type: base
checkpoint_path: "ltx-video-2b-v0.9.5.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
```
## /configs/ltxv-2b-0.9.6-dev.yaml
```yaml path="/configs/ltxv-2b-0.9.6-dev.yaml"
pipeline_type: base
checkpoint_path: "ltxv-2b-0.9.6-dev-04-25.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
```
## /configs/ltxv-2b-0.9.6-distilled.yaml
```yaml path="/configs/ltxv-2b-0.9.6-distilled.yaml"
pipeline_type: base
checkpoint_path: "ltxv-2b-0.9.6-distilled-04-25.safetensors"
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
num_inference_steps: 8
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: true
```
## /configs/ltxv-2b-0.9.8-distilled-fp8.yaml
```yaml path="/configs/ltxv-2b-0.9.8-distilled-fp8.yaml"
pipeline_type: multi-scale
checkpoint_path: "ltxv-2b-0.9.8-distilled-fp8.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
first_pass:
timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
skip_block_list: [42]
second_pass:
timesteps: [0.9094, 0.7250, 0.4219]
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
skip_block_list: [42]
```
## /configs/ltxv-2b-0.9.8-distilled.yaml
```yaml path="/configs/ltxv-2b-0.9.8-distilled.yaml"
pipeline_type: multi-scale
checkpoint_path: "ltxv-2b-0.9.8-distilled.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
first_pass:
timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
skip_block_list: [42]
second_pass:
timesteps: [0.9094, 0.7250, 0.4219]
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
skip_block_list: [42]
```
## /configs/ltxv-2b-0.9.yaml
```yaml path="/configs/ltxv-2b-0.9.yaml"
pipeline_type: base
checkpoint_path: "ltx-video-2b-v0.9.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
```
## /docs/_static/ltx-video_example_00001.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_example_00001.gif
## /docs/_static/ltx-video_example_00005.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_example_00005.gif
## /docs/_static/ltx-video_example_00006.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_example_00006.gif
## /docs/_static/ltx-video_example_00007.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_example_00007.gif
## /docs/_static/ltx-video_example_00010.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_example_00010.gif
## /docs/_static/ltx-video_example_00011.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_example_00011.gif
## /docs/_static/ltx-video_example_00013.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_example_00013.gif
## /docs/_static/ltx-video_example_00014.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_example_00014.gif
## /docs/_static/ltx-video_example_00015.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_example_00015.gif
## /docs/_static/ltx-video_i2v_example_00001.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_i2v_example_00001.gif
## /docs/_static/ltx-video_i2v_example_00002.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_i2v_example_00002.gif
## /docs/_static/ltx-video_i2v_example_00003.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_i2v_example_00003.gif
## /docs/_static/ltx-video_i2v_example_00004.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_i2v_example_00004.gif
## /docs/_static/ltx-video_i2v_example_00005.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_i2v_example_00005.gif
## /docs/_static/ltx-video_i2v_example_00006.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_i2v_example_00006.gif
## /docs/_static/ltx-video_i2v_example_00007.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_i2v_example_00007.gif
## /docs/_static/ltx-video_i2v_example_00008.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_i2v_example_00008.gif
## /docs/_static/ltx-video_i2v_example_00009.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_i2v_example_00009.gif
## /docs/_static/ltx-video_ic_2v_example_00000.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_ic_2v_example_00000.gif
## /docs/_static/ltx-video_ic_2v_example_00001.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_ic_2v_example_00001.gif
## /docs/_static/ltx-video_ic_2v_example_00002.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_ic_2v_example_00002.gif
## /docs/_static/ltx-video_ic_2v_example_00003.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_ic_2v_example_00003.gif
## /docs/_static/ltx-video_ic_2v_example_00004.gif
Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_ic_2v_example_00004.gif
## /inference.py
```py path="/inference.py"
from transformers import HfArgumentParser
from ltx_video.inference import infer, InferenceConfig
def main():
    """Command-line entry point: parse CLI arguments into an InferenceConfig and run inference."""
    arg_parser = HfArgumentParser(InferenceConfig)
    parsed_configs = arg_parser.parse_args_into_dataclasses()
    # Only one dataclass type was registered, so its instance is the first element.
    infer(config=parsed_configs[0])


if __name__ == "__main__":
    main()
```
## /ltx_video/__init__.py
```py path="/ltx_video/__init__.py"
```
## /ltx_video/inference.py
```py path="/ltx_video/inference.py"
import os
import random
from datetime import datetime
from pathlib import Path
from diffusers.utils import logging
from typing import Optional, List, Union
import yaml
import imageio
import json
import numpy as np
import torch
from safetensors import safe_open
from PIL import Image
import torchvision.transforms.functional as TVF
from transformers import (
T5EncoderModel,
T5Tokenizer,
AutoModelForCausalLM,
AutoProcessor,
AutoTokenizer,
)
from huggingface_hub import hf_hub_download
from dataclasses import dataclass, field
from ltx_video.models.autoencoders.causal_video_autoencoder import (
CausalVideoAutoencoder,
)
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from ltx_video.pipelines.pipeline_ltx_video import (
ConditioningItem,
LTXVideoPipeline,
LTXMultiScalePipeline,
)
from ltx_video.schedulers.rf import RectifiedFlowScheduler
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
import ltx_video.pipelines.crf_compressor as crf_compressor
# Module-level logger obtained from diffusers' logging utilities, registered under the name "LTX-Video".
logger = logging.get_logger("LTX-Video")
def get_total_gpu_memory():
    """Return the total memory of CUDA device 0 in GiB, or 0 when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return 0
    bytes_per_gib = 1024**3
    device_props = torch.cuda.get_device_properties(0)
    return device_props.total_memory / bytes_per_gib
def get_device():
    """Return the preferred torch device name: "cuda", then "mps", otherwise "cpu"."""
    # Probes are wrapped in lambdas so each backend is only queried when every
    # higher-priority backend turned out to be unavailable.
    probes = (
        ("cuda", lambda: torch.cuda.is_available()),
        ("mps", lambda: torch.backends.mps.is_available()),
    )
    for device_name, is_available in probes:
        if is_available():
            return device_name
    return "cpu"
def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, os.PathLike, Image.Image],
    target_height: int = 512,
    target_width: int = 768,
    just_crop: bool = False,
) -> torch.Tensor:
    """Load an image, center-crop it to the target aspect ratio and convert it to a tensor.

    Processing steps: convert to RGB, center-crop to the aspect ratio
    ``target_width / target_height``, optionally resize to exactly
    ``(target_width, target_height)``, apply a 3x3 Gaussian blur (sigma=1.0),
    pass the frame through ``crf_compressor.compress`` and rescale to [-1, 1].

    Args:
        image_input: A file path (``str`` or ``os.PathLike``) or a PIL Image object.
            Both kinds of input are converted to RGB so the result has 3 channels.
        target_height: Desired height of the output tensor (also defines the crop
            aspect ratio).
        target_width: Desired width of the output tensor (also defines the crop
            aspect ratio).
        just_crop: If True, only center-crop to the target aspect ratio; the crop
            keeps the source resolution and is NOT resized to the target size.

    Returns:
        A 5D tensor laid out as (batch_size=1, channels=3, num_frames=1, height, width).

    Raises:
        ValueError: If ``image_input`` is neither a path nor a PIL Image.
    """
    if isinstance(image_input, (str, os.PathLike)):
        # The context manager releases the file handle; convert() returns a
        # fully loaded copy that does not depend on the open file.
        with Image.open(image_input) as opened:
            image = opened.convert("RGB")
    elif isinstance(image_input, Image.Image):
        # Normalize the mode (e.g. RGBA, L, P) exactly like the path branch so the
        # tensor always has 3 channels; this also leaves the caller's image untouched.
        image = image_input.convert("RGB")
    else:
        raise ValueError("image_input must be either a file path or a PIL Image object")

    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height
    # Center-crop along whichever axis is too long for the target aspect ratio.
    if aspect_ratio_frame > aspect_ratio_target:
        new_width = int(input_height * aspect_ratio_target)
        new_height = input_height
        x_start = (input_width - new_width) // 2
        y_start = 0
    else:
        new_width = input_width
        new_height = int(input_width / aspect_ratio_target)
        x_start = 0
        y_start = (input_height - new_height) // 2
    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    if not just_crop:
        image = image.resize((target_width, target_height))

    frame_tensor = TVF.to_tensor(image)  # PIL -> tensor (C, H, W), [0,1]
    frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=3, sigma=1.0)
    frame_tensor_hwc = frame_tensor.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
    frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
    frame_tensor = frame_tensor_hwc.permute(2, 0, 1) * 255.0  # (H, W, C) -> (C, H, W)
    # Map [0, 255] to [-1, 1] (assumes compress() keeps values in [0, 1] — TODO confirm).
    frame_tensor = (frame_tensor / 127.5) - 1.0
    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
    return frame_tensor.unsqueeze(0).unsqueeze(2)
def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    """Compute per-side padding that centers a source frame inside a target frame.

    The total padding along each axis is split as evenly as possible; when a
    positive total is odd, the extra pixel goes to the bottom / right side.
    Negative totals (source larger than target) are split with the same
    floor-division rule.

    Returns:
        ``(left, right, top, bottom)`` — the layout torch.nn.functional.pad uses
        for the last two dimensions.
    """

    def _split(total: int) -> tuple[int, int]:
        # Leading side gets the floor half, trailing side gets the remainder.
        leading = total // 2
        return leading, total - leading

    top, bottom = _split(target_height - source_height)
    left, right = _split(target_width - source_width)
    return left, right, top, bottom
def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    """Turn a free-text prompt into a short, filesystem-friendly slug.

    Only alphabetic characters and whitespace are kept (lower-cased); the
    resulting words are joined with "-". Whole words are taken from the start
    of the prompt for as long as the slug, *including* the "-" separators,
    stays within ``max_len`` characters.

    Args:
        text: The prompt to convert.
        max_len: Maximum length of the returned slug.

    Returns:
        The slug. It may be empty, e.g. when the prompt contains no letters or
        its first word alone is longer than ``max_len``.
    """
    # Remove non-letters and convert to lowercase
    clean_text = "".join(
        char.lower() for char in text if char.isalpha() or char.isspace()
    )
    result = []
    current_length = 0
    for word in clean_text.split():
        # Every word after the first also costs one character for the "-" separator.
        separator_length = 1 if result else 0
        new_length = current_length + separator_length + len(word)
        if new_length > max_len:
            break
        result.append(word)
        current_length = new_length
    return "-".join(result)
# Generate output video name
def get_unique_filename(
    base: str,
    ext: str,
    prompt: str,
    seed: int,
    resolution: tuple[int, int, int],
    dir: Path,
    endswith=None,
    index_range=1000,
) -> Path:
    """Return the first output path that does not exist yet.

    Candidates have the form
    ``{dir}/{base}_{prompt-slug}_{seed}_{r0}x{r1}x{r2}_{i}{endswith}{ext}`` for
    ``i`` in ``range(index_range)``, where the prompt slug comes from
    ``convert_prompt_to_filename(prompt, max_len=30)``.

    Args:
        base: Leading part of the file name.
        ext: File extension, including the leading dot if one is wanted.
        prompt: Prompt used to build the descriptive slug.
        seed: Seed embedded in the file name.
        resolution: Three integers embedded as ``r0xr1xr2``.
        dir: Output directory (a ``Path`` or a plain string). The name shadows
            the builtin but is kept for keyword-argument compatibility.
        endswith: Optional suffix inserted between the index and the extension.
        index_range: Number of indices to try.

    Raises:
        FileExistsError: If every candidate index is already taken.
    """
    directory = Path(dir)  # accept plain string directories as well as Path objects
    base_filename = (
        f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}"
        f"_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
    )
    suffix = f"{endswith if endswith else ''}{ext}"
    for i in range(index_range):
        filename = directory / f"{base_filename}_{i}{suffix}"
        if not filename.exists():
            return filename
    raise FileExistsError(
        f"Could not find a unique filename after {index_range} attempts."
    )
def seed_everething(seed: int):
    """Seed Python's ``random``, NumPy and PyTorch with ``seed``.

    The CUDA and MPS generators are additionally seeded when those backends
    are available. (The function name keeps its historical spelling so that
    existing callers keep working.)
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
def create_transformer(ckpt_path: str, precision: str) -> Transformer3DModel:
    """Load the LTX-Video transformer from ``ckpt_path`` in the requested precision.

    Args:
        ckpt_path: Path to the model checkpoint.
        precision: ``"float8_e4m3fn"`` loads the weights as float8 and patches the
            model with Q8 kernels; ``"bfloat16"`` casts the loaded model to
            bfloat16; any other value (e.g. ``"mixed_precision"``) returns the
            model exactly as ``from_pretrained`` produced it.

    Returns:
        The loaded transformer.

    Raises:
        ValueError: If FP8 precision is requested but the Q8 kernels package
            cannot be imported.
    """
    if precision == "float8_e4m3fn":
        # Guard only the optional import: an ImportError raised while loading or
        # patching the model must surface as-is instead of being misreported as
        # "Q8-Kernels not found".
        try:
            from q8_kernels.integration.patch_transformer import (
                patch_diffusers_transformer as patch_transformer_for_q8_kernels,
            )
        except ImportError as err:
            raise ValueError(
                "Q8-Kernels not found. To use FP8 checkpoint, please install Q8 kernels from https://github.com/Lightricks/LTXVideo-Q8-Kernels"
            ) from err
        transformer = Transformer3DModel.from_pretrained(
            ckpt_path, dtype=torch.float8_e4m3fn
        )
        patch_transformer_for_q8_kernels(transformer)
        return transformer
    if precision == "bfloat16":
        return Transformer3DModel.from_pretrained(ckpt_path).to(torch.bfloat16)
    return Transformer3DModel.from_pretrained(ckpt_path)
def _load_prompt_enhancer(
    image_caption_model_name_or_path: Optional[str],
    llm_model_name_or_path: Optional[str],
):
    """Load the captioning model/processor and the LLM/tokenizer used for prompt enhancement.

    Returns:
        ``(image_caption_model, image_caption_processor, llm_model, llm_tokenizer)``.

    Raises:
        ValueError: If either model name/path is missing.
    """
    if not image_caption_model_name_or_path or not llm_model_name_or_path:
        raise ValueError(
            "enhance_prompt=True requires both "
            "prompt_enhancer_image_caption_model_name_or_path and "
            "prompt_enhancer_llm_model_name_or_path to be set"
        )
    # NOTE(security): trust_remote_code=True executes Python code shipped with the
    # captioning model repository; only point this at trusted sources.
    image_caption_model = AutoModelForCausalLM.from_pretrained(
        image_caption_model_name_or_path, trust_remote_code=True
    )
    image_caption_processor = AutoProcessor.from_pretrained(
        image_caption_model_name_or_path, trust_remote_code=True
    )
    llm_model = AutoModelForCausalLM.from_pretrained(
        llm_model_name_or_path,
        torch_dtype="bfloat16",
    )
    llm_tokenizer = AutoTokenizer.from_pretrained(
        llm_model_name_or_path,
    )
    return image_caption_model, image_caption_processor, llm_model, llm_tokenizer


def create_ltx_video_pipeline(
    ckpt_path: str,
    precision: str,
    text_encoder_model_name_or_path: str,
    sampler: Optional[str] = None,
    device: Optional[str] = None,
    enhance_prompt: bool = False,
    prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None,
    prompt_enhancer_llm_model_name_or_path: Optional[str] = None,
) -> LTXVideoPipeline:
    """Build an ``LTXVideoPipeline`` from a single-file safetensors checkpoint.

    The VAE, the transformer and (by default) the scheduler are loaded from
    ``ckpt_path``. The T5 text encoder and tokenizer are loaded from
    ``text_encoder_model_name_or_path``. The optional ``allowed_inference_steps``
    entry of the checkpoint's JSON ``config`` metadata is forwarded to the
    pipeline. The VAE and text encoder are cast to bfloat16.

    Args:
        ckpt_path: Path to the checkpoint file.
        precision: Transformer precision, passed to ``create_transformer``.
        text_encoder_model_name_or_path: Model name or path containing
            ``text_encoder`` and ``tokenizer`` subfolders.
        sampler: ``None`` or ``"from_checkpoint"`` loads the scheduler from the
            checkpoint; ``"uniform"`` (case-insensitive) selects the Uniform
            sampler; any other value selects LinearQuadratic.
        device: Device the models and the pipeline are moved to.
        enhance_prompt: Whether to load the prompt-enhancement models.
        prompt_enhancer_image_caption_model_name_or_path: Captioning model used
            for prompt enhancement (required when ``enhance_prompt`` is True).
        prompt_enhancer_llm_model_name_or_path: LLM used for prompt enhancement
            (required when ``enhance_prompt`` is True).

    Returns:
        The assembled pipeline, moved to ``device``.

    Raises:
        AssertionError: If ``ckpt_path`` does not exist (only while assertions are enabled).
        ValueError: If ``enhance_prompt`` is True but a prompt-enhancer model is not given.
    """
    ckpt_path = Path(ckpt_path)
    assert os.path.exists(
        ckpt_path
    ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"
    with safe_open(ckpt_path, framework="pt") as f:
        metadata = f.metadata()
        config_str = metadata.get("config")
        configs = json.loads(config_str)
        allowed_inference_steps = configs.get("allowed_inference_steps", None)

    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
    transformer = create_transformer(ckpt_path, precision)

    # Use constructor if sampler is specified, otherwise use from_pretrained
    if sampler == "from_checkpoint" or not sampler:
        scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
    else:
        scheduler = RectifiedFlowScheduler(
            sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
        )

    text_encoder = T5EncoderModel.from_pretrained(
        text_encoder_model_name_or_path, subfolder="text_encoder"
    )
    patchifier = SymmetricPatchifier(patch_size=1)
    tokenizer = T5Tokenizer.from_pretrained(
        text_encoder_model_name_or_path, subfolder="tokenizer"
    )

    # Cast to bfloat16 *before* moving to the device. The text encoder is loaded
    # without an explicit torch_dtype, so moving it first would briefly hold a
    # full-precision copy on the device. The final dtype/device are unchanged.
    transformer = transformer.to(device)
    vae = vae.to(torch.bfloat16).to(device)
    text_encoder = text_encoder.to(torch.bfloat16).to(device)

    if enhance_prompt:
        (
            prompt_enhancer_image_caption_model,
            prompt_enhancer_image_caption_processor,
            prompt_enhancer_llm_model,
            prompt_enhancer_llm_tokenizer,
        ) = _load_prompt_enhancer(
            prompt_enhancer_image_caption_model_name_or_path,
            prompt_enhancer_llm_model_name_or_path,
        )
    else:
        prompt_enhancer_image_caption_model = None
        prompt_enhancer_image_caption_processor = None
        prompt_enhancer_llm_model = None
        prompt_enhancer_llm_tokenizer = None

    # Use submodels for the pipeline
    submodel_dict = {
        "transformer": transformer,
        "patchifier": patchifier,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "scheduler": scheduler,
        "vae": vae,
        "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
        "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
        "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
        "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
        "allowed_inference_steps": allowed_inference_steps,
    }
    pipeline = LTXVideoPipeline(**submodel_dict)
    pipeline = pipeline.to(device)
    return pipeline
def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
    """Load a LatentUpsampler checkpoint, place it on ``device`` and put it in eval mode.

    Args:
        latent_upsampler_model_path: Path to the latent upsampler weights.
        device: Target device for the model.

    Returns:
        The loaded upsampler object.
    """
    upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
    # The return values of .to() / .eval() are not used; the same object that
    # was loaded is what gets returned.
    upsampler.to(device)
    upsampler.eval()
    return upsampler
def load_pipeline_config(pipeline_config: str):
    """Locate and parse a YAML pipeline configuration file.

    ``pipeline_config`` is looked up, in order:
      1. relative to the ``ltx_video`` package directory (this file's directory);
      2. as given (absolute, or relative to the current working directory);
      3. relative to the package's parent directory, which is where the
         repository's ``configs/`` folder lives. This lets relative config paths
         work when the script is launched from another working directory.

    Returns:
        The parsed YAML content (loaded with ``yaml.safe_load``).

    Raises:
        ValueError: If the file is not found in any of these locations.
    """
    package_dir = Path(__file__).parent
    candidates = (
        package_dir / pipeline_config,
        Path(pipeline_config),
        package_dir.parent / pipeline_config,
    )
    path = next((candidate for candidate in candidates if candidate.is_file()), None)
    if path is None:
        raise ValueError(f"Pipeline config file {pipeline_config} does not exist")
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
@dataclass
class InferenceConfig:
    """Configuration for a single LTX-Video inference run.

    The top-level ``inference.py`` script builds this from command-line
    arguments with ``HfArgumentParser``; each field's ``metadata["help"]`` text
    is used as the argument's help string.
    """

    prompt: str = field(metadata={"help": "Prompt for the generation"})
    # NOTE: annotated as ``str`` (the type used for the CLI argument), but the
    # default factory produces a ``pathlib.Path`` of the form
    # ``outputs/YYYY-MM-DD``; consumers should accept either type.
    output_path: str = field(
        default_factory=lambda: Path(
            f"outputs/{datetime.today().strftime('%Y-%m-%d')}"
        ),
        metadata={"help": "Path to the folder to save the output video"},
    )

    # Pipeline settings
    # The default must name a config that actually ships under configs/.
    pipeline_config: str = field(
        default="configs/ltxv-13b-0.9.8-dev.yaml",
        metadata={"help": "Path to the pipeline config file"},
    )
    seed: int = field(
        default=171198, metadata={"help": "Random seed for the inference"}
    )
    height: int = field(
        default=704, metadata={"help": "Height of the output video frames"}
    )
    width: int = field(
        default=1216, metadata={"help": "Width of the output video frames"}
    )
    num_frames: int = field(
        default=121,
        metadata={"help": "Number of frames to generate in the output video"},
    )
    frame_rate: int = field(
        default=30, metadata={"help": "Frame rate for the output video"}
    )
    offload_to_cpu: bool = field(
        default=False, metadata={"help": "Offloading unnecessary computations to CPU."}
    )
    negative_prompt: str = field(
        default="worst quality, inconsistent motion, blurry, jittery, distorted",
        metadata={"help": "Negative prompt for undesired features"},
    )

    # Video-to-video arguments
    input_media_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to the input video (or image) to be modified using the video-to-video pipeline"
        },
    )

    # Conditioning
    image_cond_noise_scale: float = field(
        default=0.15,
        metadata={"help": "Amount of noise to add to the conditioned image"},
    )
    conditioning_media_paths: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "List of paths to conditioning media (images or videos). Each path will be used as a conditioning item."
        },
    )
    conditioning_strengths: Optional[List[float]] = field(
        default=None,
        metadata={
            "help": "List of conditioning strengths (between 0 and 1) for each conditioning item. Must match the number of conditioning items."
        },
    )
    conditioning_start_frames: Optional[List[int]] = field(
        default=None,
        metadata={
            "help": "List of frame indices where each conditioning item should be applied. Must match the number of conditioning items."
        },
    )
def _resolve_model_path(model_name_or_path):
    """Return a local path for a model file, downloading it if necessary.

    An existing local file path is returned unchanged. Any other value is
    treated as a filename inside the ``Lightricks/LTX-Video`` Hugging Face
    model repository and fetched with ``hf_hub_download``; the local path it
    returns is used.
    """
    if os.path.isfile(model_name_or_path):
        return model_name_or_path
    return hf_hub_download(
        repo_id="Lightricks/LTX-Video",
        filename=model_name_or_path,
        repo_type="model",
    )


def _parse_skip_layer_strategy(stg_mode):
    """Map a spatiotemporal-guidance mode name to a ``SkipLayerStrategy``.

    Matching is case-insensitive and accepts both the short names
    ("stg_av", "stg_as", "stg_r", "stg_t") and the long names
    ("attention_values", "attention_skip", "residual", "transformer_block").

    Raises:
        ValueError: If ``stg_mode`` is not one of the recognised names.
    """
    mode = stg_mode.lower()
    if mode in ("stg_av", "attention_values"):
        return SkipLayerStrategy.AttentionValues
    if mode in ("stg_as", "attention_skip"):
        return SkipLayerStrategy.AttentionSkip
    if mode in ("stg_r", "residual"):
        return SkipLayerStrategy.Residual
    if mode in ("stg_t", "transformer_block"):
        return SkipLayerStrategy.TransformerBlock
    raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")


def infer(config: InferenceConfig):
    """Run LTX-Video generation end to end and write the results to disk.

    Steps:
      1. Load the pipeline YAML and resolve the checkpoint and (optional)
         spatial-upscaler paths to local files, downloading from the hub when
         they are not existing local files.
      2. Validate the conditioning arguments (strengths default to 1.0).
      3. Round height/width up to multiples of 32 and the frame count up to the
         next value of the form ``8 * N + 1``.
      4. Build the pipeline (wrapped in ``LTXMultiScalePipeline`` when the
         config's ``pipeline_type`` is "multi-scale"), load the optional input
         and conditioning media, and run generation.
      5. Crop the padding away and save every batch item under the output
         directory: a PNG when a single frame was produced, otherwise an MP4.

    Args:
        config: Inference settings (prompt, resolution, conditioning, ...).

    Raises:
        ValueError: On inconsistent conditioning arguments, a missing spatial
            upscaler for a multi-scale pipeline, or an unknown ``stg_mode``.
    """
    pipeline_config = load_pipeline_config(config.pipeline_config)

    ltxv_model_path = _resolve_model_path(pipeline_config["checkpoint_path"])
    spatial_upscaler_model_name_or_path = pipeline_config.get(
        "spatial_upscaler_model_path"
    )
    # The spatial upscaler is optional: a missing/empty value is kept as-is.
    spatial_upscaler_model_path = (
        _resolve_model_path(spatial_upscaler_model_name_or_path)
        if spatial_upscaler_model_name_or_path
        else spatial_upscaler_model_name_or_path
    )

    conditioning_media_paths = config.conditioning_media_paths
    conditioning_strengths = config.conditioning_strengths
    conditioning_start_frames = config.conditioning_start_frames

    # Validate conditioning arguments
    if conditioning_media_paths:
        # Use default strengths of 1.0
        if not conditioning_strengths:
            conditioning_strengths = [1.0] * len(conditioning_media_paths)
        if not conditioning_start_frames:
            raise ValueError(
                "If `conditioning_media_paths` is provided, "
                "`conditioning_start_frames` must also be provided"
            )
        if len(conditioning_media_paths) != len(conditioning_strengths) or len(
            conditioning_media_paths
        ) != len(conditioning_start_frames):
            raise ValueError(
                "`conditioning_media_paths`, `conditioning_strengths`, "
                "and `conditioning_start_frames` must have the same length"
            )
        if any(s < 0 or s > 1 for s in conditioning_strengths):
            raise ValueError("All conditioning strengths must be between 0 and 1")
        if any(f < 0 or f >= config.num_frames for f in conditioning_start_frames):
            raise ValueError(
                f"All conditioning start frames must be between 0 and {config.num_frames-1}"
            )

    seed_everething(config.seed)

    if config.offload_to_cpu and not torch.cuda.is_available():
        logger.warning(
            "offload_to_cpu is set to True, but offloading will not occur since the model is already running on CPU."
        )
        offload_to_cpu = False
    else:
        # Offloading is only enabled on GPUs with less than 30 (GB, per
        # get_total_gpu_memory) of memory.
        offload_to_cpu = config.offload_to_cpu and get_total_gpu_memory() < 30

    output_dir = (
        Path(config.output_path)
        if config.output_path
        else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
    height_padded = ((config.height - 1) // 32 + 1) * 32
    width_padded = ((config.width - 1) // 32 + 1) * 32
    num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1

    padding = calculate_padding(
        config.height, config.width, height_padded, width_padded
    )

    logger.warning(
        f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
    )

    device = get_device()

    # Prompt enhancement is only used for short prompts (fewer words than the
    # configured threshold); a threshold <= 0 disables it entirely.
    prompt_enhancement_words_threshold = pipeline_config[
        "prompt_enhancement_words_threshold"
    ]
    prompt_word_count = len(config.prompt.split())
    enhance_prompt = (
        prompt_enhancement_words_threshold > 0
        and prompt_word_count < prompt_enhancement_words_threshold
    )
    if prompt_enhancement_words_threshold > 0 and not enhance_prompt:
        logger.info(
            f"Prompt has {prompt_word_count} words, which exceeds the threshold of {prompt_enhancement_words_threshold}. Prompt enhancement disabled."
        )

    precision = pipeline_config["precision"]
    text_encoder_model_name_or_path = pipeline_config["text_encoder_model_name_or_path"]
    sampler = pipeline_config.get("sampler", None)
    prompt_enhancer_image_caption_model_name_or_path = pipeline_config[
        "prompt_enhancer_image_caption_model_name_or_path"
    ]
    prompt_enhancer_llm_model_name_or_path = pipeline_config[
        "prompt_enhancer_llm_model_name_or_path"
    ]

    pipeline = create_ltx_video_pipeline(
        ckpt_path=ltxv_model_path,
        precision=precision,
        text_encoder_model_name_or_path=text_encoder_model_name_or_path,
        sampler=sampler,
        device=device,
        enhance_prompt=enhance_prompt,
        prompt_enhancer_image_caption_model_name_or_path=prompt_enhancer_image_caption_model_name_or_path,
        prompt_enhancer_llm_model_name_or_path=prompt_enhancer_llm_model_name_or_path,
    )

    if pipeline_config.get("pipeline_type", None) == "multi-scale":
        if not spatial_upscaler_model_path:
            raise ValueError(
                "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
            )
        latent_upsampler = create_latent_upsampler(
            spatial_upscaler_model_path, pipeline.device
        )
        pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)

    media_item = None
    if config.input_media_path:
        media_item = load_media_file(
            media_path=config.input_media_path,
            height=config.height,
            width=config.width,
            max_frames=num_frames_padded,
            padding=padding,
        )

    conditioning_items = (
        prepare_conditioning(
            conditioning_media_paths=conditioning_media_paths,
            conditioning_strengths=conditioning_strengths,
            conditioning_start_frames=conditioning_start_frames,
            height=config.height,
            width=config.width,
            num_frames=config.num_frames,
            padding=padding,
            pipeline=pipeline,
        )
        if conditioning_media_paths
        else None
    )

    # `stg_mode` is removed from the config so it is not forwarded to the
    # pipeline via **pipeline_config below. pop() with a default also tolerates
    # configs that omit the key (get() followed by `del` raised KeyError).
    stg_mode = pipeline_config.pop("stg_mode", "attention_values")
    skip_layer_strategy = _parse_skip_layer_strategy(stg_mode)

    # Prepare input for the pipeline
    sample = {
        "prompt": config.prompt,
        "prompt_attention_mask": None,
        "negative_prompt": config.negative_prompt,
        "negative_prompt_attention_mask": None,
    }

    generator = torch.Generator(device=device).manual_seed(config.seed)

    images = pipeline(
        **pipeline_config,
        skip_layer_strategy=skip_layer_strategy,
        generator=generator,
        output_type="pt",
        callback_on_step_end=None,
        height=height_padded,
        width=width_padded,
        num_frames=num_frames_padded,
        frame_rate=config.frame_rate,
        **sample,
        media_items=media_item,
        conditioning_items=conditioning_items,
        is_video=True,
        vae_per_channel_normalize=True,
        image_cond_noise_scale=config.image_cond_noise_scale,
        mixed_precision=(precision == "mixed_precision"),
        offload_to_cpu=offload_to_cpu,
        device=device,
        enhance_prompt=enhance_prompt,
    ).images

    # Crop the padded images to the desired resolution and number of frames.
    # Negative end indices trim the right/bottom padding; a zero pad means
    # "up to the end", which a slice end of -0 would not express.
    (pad_left, pad_right, pad_top, pad_bottom) = padding
    pad_bottom = -pad_bottom
    pad_right = -pad_right
    if pad_bottom == 0:
        pad_bottom = images.shape[3]
    if pad_right == 0:
        pad_right = images.shape[4]
    images = images[:, :, : config.num_frames, pad_top:pad_bottom, pad_left:pad_right]

    for i in range(images.shape[0]):
        # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
        video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
        # Unnormalizing images to [0, 255] range
        video_np = (video_np * 255).astype(np.uint8)
        fps = config.frame_rate
        height, width = video_np.shape[1:3]
        # In case a single image is generated
        if video_np.shape[0] == 1:
            output_filename = get_unique_filename(
                f"image_output_{i}",
                ".png",
                prompt=config.prompt,
                seed=config.seed,
                resolution=(height, width, config.num_frames),
                dir=output_dir,
            )
            imageio.imwrite(output_filename, video_np[0])
        else:
            output_filename = get_unique_filename(
                f"video_output_{i}",
                ".mp4",
                prompt=config.prompt,
                seed=config.seed,
                resolution=(height, width, config.num_frames),
                dir=output_dir,
            )
            # Write video
            with imageio.get_writer(output_filename, fps=fps) as video:
                for frame in video_np:
                    video.append_data(frame)

        logger.warning(f"Output saved to {output_filename}")
def prepare_conditioning(
    conditioning_media_paths: List[str],
    conditioning_strengths: List[float],
    conditioning_start_frames: List[int],
    height: int,
    width: int,
    num_frames: int,
    padding: tuple[int, int, int, int],
    pipeline: LTXVideoPipeline,
) -> Optional[List[ConditioningItem]]:
    """Build one ConditioningItem per conditioning media file.

    Each file is loaded with ``load_media_file(..., just_crop=True)``. If the
    pipeline exposes a callable ``trim_conditioning_sequence`` method, the
    number of frames read from each file is limited to what that method
    returns for the item's start frame, and a warning is logged whenever frames
    are dropped.

    Args:
        conditioning_media_paths: Paths to conditioning images or videos.
        conditioning_strengths: Strength for each media item (same order).
        conditioning_start_frames: Output frame index at which each item applies.
        height: Height of the output frames.
        width: Width of the output frames.
        num_frames: Number of frames in the output video.
        padding: Padding applied to every loaded frame.
        pipeline: Pipeline used (when it supports it) to trim conditioning videos.

    Returns:
        A list of ConditioningItem objects, in input order.
    """
    supports_trimming = callable(getattr(pipeline, "trim_conditioning_sequence", None))

    items = []
    for media_path, item_strength, first_frame in zip(
        conditioning_media_paths, conditioning_strengths, conditioning_start_frames
    ):
        available_frames = get_media_num_frames(media_path)
        frames_to_use = available_frames

        if supports_trimming:
            frames_to_use = pipeline.trim_conditioning_sequence(
                first_frame, available_frames, num_frames
            )
            if frames_to_use < available_frames:
                logger.warning(
                    f"Trimming conditioning video {media_path} from {available_frames} to {frames_to_use} frames."
                )

        media = load_media_file(
            media_path=media_path,
            height=height,
            width=width,
            max_frames=frames_to_use,
            padding=padding,
            just_crop=True,
        )
        items.append(ConditioningItem(media, first_frame, item_strength))

    return items
def get_media_num_frames(media_path: str) -> int:
    """Return the number of frames in a media file.

    Files whose name ends in a video extension (.mp4, .avi, .mov, .mkv;
    case-insensitive) are opened with ``imageio`` and their frames counted.
    Anything else is treated as a single image and reported as 1 frame.
    ``os.PathLike`` paths are accepted as well as strings.
    """
    if not str(media_path).lower().endswith((".mp4", ".avi", ".mov", ".mkv")):
        return 1

    reader = imageio.get_reader(media_path)
    try:
        return reader.count_frames()
    finally:
        # Release the underlying file handle even if counting fails.
        reader.close()
def load_media_file(
    media_path: str,
    height: int,
    width: int,
    max_frames: int,
    padding: tuple[int, int, int, int],
    just_crop: bool = False,
) -> torch.Tensor:
    """Load an image or video file into a padded tensor.

    Every frame goes through ``load_image_to_tensor_with_resize_and_crop``
    (with ``just_crop`` forwarded) and is then padded with
    ``torch.nn.functional.pad(..., padding)``.

    Files whose name ends in a video extension (.mp4, .avi, .mov, .mkv;
    case-insensitive) are read with ``imageio``. At most ``max_frames`` leading
    frames are used, and the per-frame tensors are concatenated along dim 2 (the
    temporal dimension). Any other file is loaded as a single image.

    Raises:
        ValueError: If a video would yield no frames (empty file or
            ``max_frames <= 0``).
    """
    is_video = str(media_path).lower().endswith((".mp4", ".avi", ".mov", ".mkv"))

    if not is_video:  # Input image
        media_tensor = load_image_to_tensor_with_resize_and_crop(
            media_path, height, width, just_crop=just_crop
        )
        return torch.nn.functional.pad(media_tensor, padding)

    reader = imageio.get_reader(media_path)
    try:
        num_input_frames = min(reader.count_frames(), max_frames)
        if num_input_frames <= 0:
            raise ValueError(
                f"No frames to read from video {media_path} (max_frames={max_frames})"
            )

        # Read and preprocess the relevant frames from the video file.
        frames = []
        for i in range(num_input_frames):
            frame = Image.fromarray(reader.get_data(i))
            frame_tensor = load_image_to_tensor_with_resize_and_crop(
                frame, height, width, just_crop=just_crop
            )
            frames.append(torch.nn.functional.pad(frame_tensor, padding))
    finally:
        # Close the reader even if decoding/preprocessing a frame fails.
        reader.close()

    # Stack frames along the temporal dimension
    return torch.cat(frames, dim=2)
```
## /ltx_video/models/__init__.py
```py path="/ltx_video/models/__init__.py"
```
## /ltx_video/models/autoencoders/__init__.py
```py path="/ltx_video/models/autoencoders/__init__.py"
```
## /ltx_video/models/autoencoders/causal_conv3d.py
```py path="/ltx_video/models/autoencoders/causal_conv3d.py"
from typing import Tuple, Union
import torch
import torch.nn as nn
class CausalConv3d(nn.Module):
    """3D convolution with frame-replication padding along the temporal axis.

    ``nn.Conv3d`` pads only the spatial axes (``kernel // 2`` on each side, using
    ``spatial_padding_mode``). Temporal padding is done in ``forward`` by
    repeating edge frames so that, for a temporal stride of 1, the number of
    output frames equals the number of input frames:

    * causal mode repeats the first frame ``(k_t - 1) * dilation`` times in
      front, so each output frame only sees the current and earlier frames;
    * non-causal mode repeats the first and last frames
      ``(k_t - 1) * dilation // 2`` times on each side.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size; an int is used for all three axes, or a
            ``(time, height, width)`` tuple may be given.
        stride: Convolution stride (int or per-axis tuple), passed to ``nn.Conv3d``.
        dilation: Temporal dilation; spatial dilation is always 1.
        groups: Number of blocked connections, passed to ``nn.Conv3d``.
        spatial_padding_mode: ``padding_mode`` for ``nn.Conv3d`` (e.g. ``"zeros"``).
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]] = 3,
        stride: Union[int, Tuple[int]] = 1,
        dilation: int = 1,
        groups: int = 1,
        spatial_padding_mode: str = "zeros",
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        else:
            # Previously a tuple was wrapped into a tuple of tuples.
            kernel_size = tuple(kernel_size)
        self.time_kernel_size = kernel_size[0]

        dilation = (dilation, 1, 1)

        height_pad = kernel_size[1] // 2
        width_pad = kernel_size[2] // 2
        # No temporal padding here: forward() replicates edge frames instead.
        padding = (0, height_pad, width_pad)

        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            padding_mode=spatial_padding_mode,
            groups=groups,
        )

    def forward(self, x, causal: bool = True):
        # Number of extra frames the (dilated) temporal kernel spans. Padding by
        # this amount keeps the frame count for stride 1; ignoring dilation would
        # shrink the output by (dilation - 1) * (k_t - 1) frames.
        temporal_reach = (self.time_kernel_size - 1) * self.conv.dilation[0]
        if causal:
            first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, temporal_reach, 1, 1))
            x = torch.concatenate((first_frame_pad, x), dim=2)
        else:
            half_reach = temporal_reach // 2
            first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, half_reach, 1, 1))
            last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, half_reach, 1, 1))
            x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
        return self.conv(x)

    @property
    def weight(self):
        """The weight tensor of the wrapped ``nn.Conv3d``."""
        return self.conv.weight
```
## /ltx_video/models/autoencoders/causal_video_autoencoder.py
```py path="/ltx_video/models/autoencoders/causal_video_autoencoder.py"
import json
import os
from functools import partial
from types import SimpleNamespace
from typing import Any, Mapping, Optional, Tuple, Union, List
from pathlib import Path
import torch
import numpy as np
from einops import rearrange
from torch import nn
from diffusers.utils import logging
import torch.nn.functional as F
from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
from safetensors import safe_open
from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
from ltx_video.models.autoencoders.pixel_norm import PixelNorm
from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND
from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper
from ltx_video.models.transformers.attention import Attention
from ltx_video.utils.diffusers_config_mapping import (
diffusers_and_ours_config_mapping,
make_hashable_key,
VAE_KEYS_RENAME_DICT,
)
# Key prefix under which per-channel latent statistics ("std-of-means" and
# "mean-of-means") are stored in VAE state dicts; CausalVideoAutoencoder moves
# such entries into buffers instead of loading them as parameters.
PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics."
# Module logger (diffusers' logging wrapper).
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
class CausalVideoAutoencoder(AutoencoderKLWrapper):
    """Video VAE composed of an :class:`Encoder` / :class:`Decoder` pair.

    Adds checkpoint loading for several on-disk layouts (``from_pretrained``),
    construction from a config dict (``from_config``), a read-only ``config``
    view, and a ``load_state_dict`` that adapts legacy key names and stores
    per-channel latent statistics as buffers.
    """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        """Load a VAE from a local checkpoint.

        Supported layouts:
          * a directory containing ``autoencoder.pth`` and ``config.json``
            (plus an optional ``per_channel_statistics.json``);
          * a diffusers-style directory containing ``vae/config.json`` and
            ``vae/diffusion_pytorch_model.safetensors`` (only configs known to
            ``diffusers_and_ours_config_mapping`` are accepted);
          * a single ``.safetensors`` file whose metadata holds a JSON
            ``config`` entry with a ``vae`` section.

        If ``torch_dtype`` is passed in ``kwargs``, the model is cast to it
        before the weights are loaded.

        Raises:
            ValueError: If the path matches none of the supported layouts.
        """
        pretrained_model_name_or_path = Path(pretrained_model_name_or_path)

        if (
            pretrained_model_name_or_path.is_dir()
            and (pretrained_model_name_or_path / "autoencoder.pth").exists()
        ):
            config_local_path = pretrained_model_name_or_path / "config.json"
            config = cls.load_config(config_local_path, **kwargs)

            model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(model_local_path, map_location=torch.device("cpu"))

            statistics_local_path = (
                pretrained_model_name_or_path / "per_channel_statistics.json"
            )
            if statistics_local_path.exists():
                with open(statistics_local_path, "r") as file:
                    data = json.load(file)
                # The JSON stores rows under "data" and names under "columns";
                # transpose it into one tensor per column.
                transposed_data = list(zip(*data["data"]))
                data_dict = {
                    col: torch.tensor(vals)
                    for col, vals in zip(data["columns"], transposed_data)
                }
                std_of_means = data_dict["std-of-means"]
                mean_of_means = data_dict.get(
                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
                )
                state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = (
                    std_of_means
                )
                state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = (
                    mean_of_means
                )

        elif pretrained_model_name_or_path.is_dir():
            config_path = pretrained_model_name_or_path / "vae" / "config.json"
            with open(config_path, "r") as f:
                config = make_hashable_key(json.load(f))

            assert config in diffusers_and_ours_config_mapping, (
                "Provided diffusers checkpoint config for VAE is not suppported. "
                "We only support diffusers configs found in Lightricks/LTX-Video."
            )

            config = diffusers_and_ours_config_mapping[config]

            state_dict_path = (
                pretrained_model_name_or_path
                / "vae"
                / "diffusion_pytorch_model.safetensors"
            )

            state_dict = {}
            with safe_open(state_dict_path, framework="pt", device="cpu") as f:
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
            # Rename diffusers parameter names to this implementation's names.
            for key in list(state_dict.keys()):
                new_key = key
                for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
                    new_key = new_key.replace(replace_key, rename_key)

                state_dict[new_key] = state_dict.pop(key)

        elif pretrained_model_name_or_path.is_file() and str(
            pretrained_model_name_or_path
        ).endswith(".safetensors"):
            state_dict = {}
            with safe_open(
                pretrained_model_name_or_path, framework="pt", device="cpu"
            ) as f:
                metadata = f.metadata()
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
            configs = json.loads(metadata["config"])
            config = configs["vae"]

        else:
            # Previously this fell through and failed with UnboundLocalError.
            raise ValueError(
                f"Unsupported VAE checkpoint location: {pretrained_model_name_or_path}. "
                "Expected a directory (with autoencoder.pth or a diffusers 'vae' "
                "subfolder) or a .safetensors file."
            )

        video_vae = cls.from_config(config)
        if "torch_dtype" in kwargs:
            video_vae.to(kwargs["torch_dtype"])
        video_vae.load_state_dict(state_dict)
        return video_vae

    @staticmethod
    def from_config(config):
        """Build a CausalVideoAutoencoder from a config mapping.

        The mapping is copied first, so the caller's dict (which may be a shared
        entry of ``diffusers_and_ours_config_mapping``) is not modified. A
        list-valued ``dims`` is converted to a tuple.
        """
        config = dict(config)
        assert (
            config["_class_name"] == "CausalVideoAutoencoder"
        ), "config must have _class_name=CausalVideoAutoencoder"
        if isinstance(config["dims"], list):
            config["dims"] = tuple(config["dims"])

        assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"

        double_z = config.get("double_z", True)
        latent_log_var = config.get(
            "latent_log_var", "per_channel" if double_z else "none"
        )
        use_quant_conv = config.get("use_quant_conv", True)
        normalize_latent_channels = config.get("normalize_latent_channels", False)

        if use_quant_conv and latent_log_var in ["uniform", "constant"]:
            raise ValueError(
                f"latent_log_var={latent_log_var} requires use_quant_conv=False"
            )

        encoder = Encoder(
            dims=config["dims"],
            in_channels=config.get("in_channels", 3),
            out_channels=config["latent_channels"],
            blocks=config.get("encoder_blocks", config.get("blocks")),
            patch_size=config.get("patch_size", 1),
            latent_log_var=latent_log_var,
            norm_layer=config.get("norm_layer", "group_norm"),
            base_channels=config.get("encoder_base_channels", 128),
            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
        )

        decoder = Decoder(
            dims=config["dims"],
            in_channels=config["latent_channels"],
            out_channels=config.get("out_channels", 3),
            blocks=config.get("decoder_blocks", config.get("blocks")),
            patch_size=config.get("patch_size", 1),
            norm_layer=config.get("norm_layer", "group_norm"),
            causal=config.get("causal_decoder", False),
            timestep_conditioning=config.get("timestep_conditioning", False),
            base_channels=config.get("decoder_base_channels", 128),
            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
        )

        dims = config["dims"]
        return CausalVideoAutoencoder(
            encoder=encoder,
            decoder=decoder,
            latent_channels=config["latent_channels"],
            dims=dims,
            use_quant_conv=use_quant_conv,
            normalize_latent_channels=normalize_latent_channels,
        )

    @property
    def config(self):
        """A SimpleNamespace describing this model, derived from its submodules."""
        return SimpleNamespace(
            _class_name="CausalVideoAutoencoder",
            dims=self.dims,
            in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2,
            out_channels=self.decoder.conv_out.out_channels
            // self.decoder.patch_size**2,
            latent_channels=self.decoder.conv_in.in_channels,
            encoder_blocks=self.encoder.blocks_desc,
            decoder_blocks=self.decoder.blocks_desc,
            scaling_factor=1.0,
            norm_layer=self.encoder.norm_layer,
            patch_size=self.encoder.patch_size,
            latent_log_var=self.encoder.latent_log_var,
            use_quant_conv=self.use_quant_conv,
            causal_decoder=self.decoder.causal,
            timestep_conditioning=self.decoder.timestep_conditioning,
            normalize_latent_channels=self.normalize_latent_channels,
        )

    @property
    def is_video_supported(self):
        """
        Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
        """
        return self.dims != 2

    @property
    def spatial_downscale_factor(self):
        """Spatial downscale: 2 per spatially-compressing encoder block, times patch_size."""
        return (
            2
            ** len(
                [
                    block
                    for block in self.encoder.blocks_desc
                    if block[0]
                    in [
                        "compress_space",
                        "compress_all",
                        "compress_all_res",
                        "compress_space_res",
                    ]
                ]
            )
            * self.encoder.patch_size
        )

    @property
    def temporal_downscale_factor(self):
        """Temporal downscale: 2 per temporally-compressing encoder block."""
        return 2 ** len(
            [
                block
                for block in self.encoder.blocks_desc
                if block[0]
                in [
                    "compress_time",
                    "compress_all",
                    "compress_all_res",
                    "compress_time_res",
                ]
            ]
        )

    def to_json_string(self) -> str:
        """Serialize ``self.config`` to a JSON string."""
        return json.dumps(self.config.__dict__)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        """Load weights, adapting key names and extracting latent statistics.

        * If any key starts with ``"vae."``, only those keys are kept and that
          leading prefix is removed.
        * Legacy module names (``.resnets.``, ``downsamplers.0``,
          ``upsamplers.0``) are renamed to this model's names.
        * Keys containing ``norm`` whose parent module does not exist in this
          model are dropped (and logged) rather than loaded.
        * Entries under ``PER_CHANNEL_STATISTICS_PREFIX`` are not loaded as
          parameters; they are registered as the ``std_of_means`` and
          ``mean_of_means`` buffers (the mean defaults to zeros when absent).

        Returns:
            Whatever the parent class's ``load_state_dict`` returns for the
            converted weights (for ``nn.Module``: the missing/unexpected keys).
        """
        if any(key.startswith("vae.") for key in state_dict.keys()):
            # removeprefix (not replace) so a "vae." later in a key is preserved.
            state_dict = {
                key.removeprefix("vae."): value
                for key, value in state_dict.items()
                if key.startswith("vae.")
            }
        ckpt_state_dict = {
            key: value
            for key, value in state_dict.items()
            if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
        }

        model_keys = set(name for name, _ in self.named_modules())

        key_mapping = {
            ".resnets.": ".res_blocks.",
            "downsamplers.0": "downsample",
            "upsamplers.0": "upsample",
        }
        converted_state_dict = {}
        for key, value in ckpt_state_dict.items():
            for k, v in key_mapping.items():
                key = key.replace(k, v)

            key_prefix = ".".join(key.split(".")[:-1])
            if "norm" in key and key_prefix not in model_keys:
                logger.info(
                    f"Removing key {key} from state_dict as it is not present in the model"
                )
                continue

            converted_state_dict[key] = value

        load_result = super().load_state_dict(converted_state_dict, strict=strict)

        data_dict = {
            key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value
            for key, value in state_dict.items()
            if key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
        }
        if len(data_dict) > 0:
            self.register_buffer("std_of_means", data_dict["std-of-means"])
            self.register_buffer(
                "mean_of_means",
                data_dict.get(
                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
                ),
            )
        return load_result

    def last_layer(self):
        """Return the decoder's final layer (last element if conv_out is a Sequential)."""
        if hasattr(self.decoder, "conv_out"):
            if isinstance(self.decoder.conv_out, nn.Sequential):
                last_layer = self.decoder.conv_out[-1]
            else:
                last_layer = self.decoder.conv_out
        else:
            last_layer = self.decoder.layers[-1]
        return last_layer

    def set_use_tpu_flash_attention(self):
        """Enable TPU flash attention on every attention block in decoder mid-blocks."""
        for block in self.decoder.up_blocks:
            if isinstance(block, UNetMidBlock3D) and block.attention_blocks:
                for attention_block in block.attention_blocks:
                    attention_block.set_use_tpu_flash_attention()
class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
            The number of dimensions to use in convolutions.
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels (before patch_size expansion).
        out_channels (`int`, *optional*, defaults to 3):
            The number of latent channels. The final convolution produces more
            channels depending on `latent_log_var`.
        blocks (`List[Tuple[str, int | dict]]`, *optional*, defaults to `[("res_x", 1)]`):
            The blocks to use. Each block is a tuple of the block name and either
            the number of layers or a dict of block parameters.
        base_channels (`int`, *optional*, defaults to 128):
            The number of output channels for the first convolutional layer.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be `group_norm`, `pixel_norm` or `layer_norm`.
        latent_log_var (`str`, *optional*, defaults to `per_channel`):
            The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`.
        spatial_padding_mode (`str`, *optional*, defaults to `zeros`):
            Padding mode forwarded to the convolutions for the spatial axes.

    Raises:
        ValueError: For an unknown block name, `norm_layer`, or `latent_log_var`.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]] = 3,
        in_channels: int = 3,
        out_channels: int = 3,
        blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
        base_channels: int = 128,
        norm_num_groups: int = 32,
        patch_size: Union[int, Tuple[int]] = 1,
        norm_layer: str = "group_norm",  # group_norm, pixel_norm, layer_norm
        latent_log_var: str = "per_channel",
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        self.patch_size = patch_size
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
        self.latent_log_var = latent_log_var
        self.blocks_desc = blocks

        # Spatial patchify folds patch_size**2 pixels into the channel axis.
        in_channels = in_channels * patch_size**2
        output_channel = base_channels

        self.conv_in = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.down_blocks = nn.ModuleList([])

        for block_name, block_params in blocks:
            input_channel = output_channel
            if isinstance(block_params, int):
                block_params = {"num_layers": block_params}

            if block_name == "res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_eps=1e-6,
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "res_x_y":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    eps=1e-6,
                    groups=norm_num_groups,
                    norm_layer=norm_layer,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_time":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 1, 1),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_space":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(1, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all_x_y":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(2, 2, 2),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_space_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(1, 2, 2),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_time_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(2, 1, 1),
                    spatial_padding_mode=spatial_padding_mode,
                )
            else:
                raise ValueError(f"unknown block: {block_name}")

            self.down_blocks.append(block)

        # out
        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()
        elif norm_layer == "layer_norm":
            self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
        else:
            # forward() always calls conv_norm_out, so fail early instead of
            # with an AttributeError at the first forward pass.
            raise ValueError(f"Invalid norm_layer: {norm_layer}")

        self.conv_act = nn.SiLU()

        # Extra output channels carry the log-variance (see forward()).
        conv_out_channels = out_channels
        if latent_log_var == "per_channel":
            conv_out_channels *= 2
        elif latent_log_var == "uniform":
            conv_out_channels += 1
        elif latent_log_var == "constant":
            conv_out_channels += 1
        elif latent_log_var != "none":
            raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
        self.conv_out = make_conv_nd(
            dims,
            output_channel,
            conv_out_channels,
            3,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class.

        Patchifies the input spatially and runs conv_in, the down blocks and the
        output head. The log-variance channels are then expanded according to
        `latent_log_var`:

        * ``"uniform"``: the single predicted log-variance channel is repeated so
          the output has ``2 * latent_channels`` channels;
        * ``"constant"``: the extra channel is dropped and ``latent_channels``
          channels filled with -30 are appended.

        When ``gradient_checkpointing`` is enabled in training mode, each down
        block is run under ``torch.utils.checkpoint.checkpoint``.
        """
        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
        sample = self.conv_in(sample)

        if self.gradient_checkpointing and self.training:
            from torch.utils.checkpoint import checkpoint

            def checkpoint_fn(module):
                # The module must be bound as the checkpointed function. Calling
                # checkpoint(module) directly would run module() with no inputs.
                return partial(checkpoint, module, use_reentrant=False)

        else:

            def checkpoint_fn(module):
                return module

        for down_block in self.down_blocks:
            sample = checkpoint_fn(down_block)(sample)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if self.latent_log_var == "uniform":
            last_channel = sample[:, -1:, ...]
            num_dims = sample.dim()

            if num_dims == 4:
                # For shape (B, C, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            elif num_dims == 5:
                # For shape (B, C, F, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            else:
                raise ValueError(f"Invalid input shape: {sample.shape}")
        elif self.latent_log_var == "constant":
            sample = sample[:, :-1, ...]
            approx_ln_0 = (
                -30
            )  # this is the minimal clamp value in DiagonalGaussianDistribution objects
            sample = torch.cat(
                [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
                dim=1,
            )

        return sample
class Decoder(nn.Module):
r"""
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
Args:
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
The number of dimensions to use in convolutions.
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
The blocks to use. Each block is a tuple of the block name and the number of layers.
base_channels (`int`, *optional*, defaults to 128):
The number of output channels for the first convolutional layer.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
patch_size (`int`, *optional*, defaults to 1):
The patch size to use. Should be a power of 2.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
causal (`bool`, *optional*, defaults to `True`):
Whether to use causal convolutions or not.
"""
def __init__(
self,
dims,
in_channels: int = 3,
out_channels: int = 3,
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
base_channels: int = 128,
layers_per_block: int = 2,
norm_num_groups: int = 32,
patch_size: int = 1,
norm_layer: str = "group_norm",
causal: bool = True,
timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
):
super().__init__()
self.patch_size = patch_size
self.layers_per_block = layers_per_block
out_channels = out_channels * patch_size**2
self.causal = causal
self.blocks_desc = blocks
# Compute output channel to be product of all channel-multiplier blocks
output_channel = base_channels
for block_name, block_params in list(reversed(blocks)):
block_params = block_params if isinstance(block_params, dict) else {}
if block_name == "res_x_y":
output_channel = output_channel * block_params.get("multiplier", 2)
if block_name.startswith("compress"):
output_channel = output_channel * block_params.get("multiplier", 1)
self.conv_in = make_conv_nd(
dims,
in_channels,
output_channel,
kernel_size=3,
stride=1,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.up_blocks = nn.ModuleList([])
for block_name, block_params in list(reversed(blocks)):
input_channel = output_channel
if isinstance(block_params, int):
block_params = {"num_layers": block_params}
if block_name == "res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "attn_res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
attention_head_dim=block_params["attention_head_dim"],
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2)
block = ResnetBlock3D(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=False,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(2, 2, 2),
residual=block_params.get("residual", False),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"unknown layer: {block_name}")
self.up_blocks.append(block)
if norm_layer == "group_norm":
self.conv_norm_out = nn.GroupNorm(
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
)
elif norm_layer == "pixel_norm":
self.conv_norm_out = PixelNorm()
elif norm_layer == "layer_norm":
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = make_conv_nd(
dims,
output_channel,
out_channels,
3,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.gradient_checkpointing = False
self.timestep_conditioning = timestep_conditioning
if timestep_conditioning:
self.timestep_scale_multiplier = nn.Parameter(
torch.tensor(1000.0, dtype=torch.float32)
)
self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
output_channel * 2, 0
)
self.last_scale_shift_table = nn.Parameter(
torch.randn(2, output_channel) / output_channel**0.5
)
def forward(
    self,
    sample: torch.FloatTensor,
    target_shape,
    timestep: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
    r"""The forward method of the `Decoder` class.

    Applies `conv_in`, every module in `self.up_blocks` (wrapped in
    non-reentrant activation checkpointing when `gradient_checkpointing` is
    enabled and the module is in training mode), `conv_norm_out`, an optional
    timestep-conditioned shift/scale, SiLU, `conv_out`, and finally
    `unpatchify` with `patch_size_hw=self.patch_size`.

    Args:
        sample: Input to `conv_in`; presumably the latent to decode, shaped
            (batch, channels, frames, height, width) -- TODO confirm.
        target_shape: Must not be None; it is not otherwise used here.
        timestep: Required when `self.timestep_conditioning` is True; it is
            multiplied by `self.timestep_scale_multiplier` before use.

    Returns:
        The unpatchified output of `conv_out`.
    """
    assert target_shape is not None, "target_shape must be provided"
    batch_size = sample.shape[0]
    sample = self.conv_in(sample, causal=self.causal)
    # The up-blocks run in the dtype of their first parameter.
    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
    # `checkpoint_fn(block)` must return a callable that is invoked *later* with
    # the block's inputs. Calling `checkpoint(block)` directly would execute the
    # block right away with no arguments, so bind the module with `partial` and
    # let the subsequent call supply the inputs (non-reentrant checkpointing
    # forwards keyword arguments such as `causal` / `timestep` to the block).
    checkpoint_fn = (
        (
            lambda module: partial(
                torch.utils.checkpoint.checkpoint, module, use_reentrant=False
            )
        )
        if self.gradient_checkpointing and self.training
        else (lambda module: module)
    )
    sample = sample.to(upscale_dtype)
    if self.timestep_conditioning:
        assert (
            timestep is not None
        ), "should pass timestep with timestep_conditioning=True"
        scaled_timestep = timestep * self.timestep_scale_multiplier
    for up_block in self.up_blocks:
        # Only UNetMidBlock3D blocks receive the timestep.
        if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
            sample = checkpoint_fn(up_block)(
                sample, causal=self.causal, timestep=scaled_timestep
            )
        else:
            sample = checkpoint_fn(up_block)(sample, causal=self.causal)
    sample = self.conv_norm_out(sample)
    if self.timestep_conditioning:
        embedded_timestep = self.last_time_embedder(
            timestep=scaled_timestep.flatten(),
            resolution=None,
            aspect_ratio=None,
            batch_size=sample.shape[0],
            hidden_dtype=sample.dtype,
        )
        # Make the embedding broadcastable over (frames, height, width).
        embedded_timestep = embedded_timestep.view(
            batch_size, embedded_timestep.shape[-1], 1, 1, 1
        )
        # Learned per-channel table plus the embedding, split into (shift, scale).
        ada_values = self.last_scale_shift_table[
            None, ..., None, None, None
        ] + embedded_timestep.reshape(
            batch_size,
            2,
            -1,
            embedded_timestep.shape[-3],
            embedded_timestep.shape[-2],
            embedded_timestep.shape[-1],
        )
        shift, scale = ada_values.unbind(dim=1)
        sample = sample * (1 + scale) + shift
    sample = self.conv_act(sample)
    sample = self.conv_out(sample, causal=self.causal)
    sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
    return sample
class UNetMidBlock3D(nn.Module):
    """
    A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.

    Each layer is a `ResnetBlock3D` with `in_channels` input and output
    channels, optionally followed by a self-attention block over all
    (frame, height, width) positions.

    Args:
        dims (`int` or `Tuple[int, int]`): Passed through to every `ResnetBlock3D`.
        in_channels (`int`): The number of input channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks.
            If `None`, `min(in_channels // 4, 32)` is used.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use: `group_norm`, `pixel_norm` or `layer_norm`.
        inject_noise (`bool`, *optional*, defaults to `False`):
            Whether to inject noise into the hidden states.
        timestep_conditioning (`bool`, *optional*, defaults to `False`):
            Whether to condition the hidden states on the timestep.
        attention_head_dim (`int`, *optional*, defaults to -1):
            The dimension of the attention head. If -1 (or any value <= 0), no attention is used.
            Must not exceed `in_channels`.
        spatial_padding_mode (`str`, *optional*, defaults to `"zeros"`):
            Passed through to every `ResnetBlock3D`.
    Returns:
        `torch.FloatTensor`: The output of the last block, a 5D tensor of shape
        `(batch_size, in_channels, frames, height, width)` -- the same shape as the input.
    Raises:
        ValueError: If `attention_head_dim > in_channels`.
    """
    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: str = "group_norm",
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        attention_head_dim: int = -1,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        resnet_groups = (
            resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        )
        self.timestep_conditioning = timestep_conditioning
        if timestep_conditioning:
            # 4 * in_channels: ResnetBlock3D splits this embedding into
            # (shift1, scale1, shift2, scale2), each with in_channels entries.
            self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
                in_channels * 4, 0
            )
        self.res_blocks = nn.ModuleList(
            [
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                    inject_noise=inject_noise,
                    timestep_conditioning=timestep_conditioning,
                    spatial_padding_mode=spatial_padding_mode,
                )
                for _ in range(num_layers)
            ]
        )
        self.attention_blocks = None
        if attention_head_dim > 0:
            if attention_head_dim > in_channels:
                raise ValueError(
                    "attention_head_dim must be less than or equal to in_channels"
                )
            # One attention block per residual block; they are paired in forward().
            self.attention_blocks = nn.ModuleList(
                [
                    Attention(
                        query_dim=in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        bias=True,
                        out_bias=True,
                        qk_norm="rms_norm",
                        residual_connection=True,
                    )
                    for _ in range(num_layers)
                ]
            )
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        causal: bool = True,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Run the residual (and optional attention) layers.

        Args:
            hidden_states: 5D tensor `(batch_size, channel, frames, height, width)`.
            causal: Forwarded to every `ResnetBlock3D`.
            timestep: Required when `timestep_conditioning=True`; it is embedded once
                and the same embedding is passed to every residual block.
        """
        timestep_embed = None
        if self.timestep_conditioning:
            assert (
                timestep is not None
            ), "should pass timestep with timestep_conditioning=True"
            batch_size = hidden_states.shape[0]
            timestep_embed = self.time_embedder(
                timestep=timestep.flatten(),
                resolution=None,
                aspect_ratio=None,
                batch_size=batch_size,
                hidden_dtype=hidden_states.dtype,
            )
            # Make the embedding broadcastable over (frames, height, width).
            timestep_embed = timestep_embed.view(
                batch_size, timestep_embed.shape[-1], 1, 1, 1
            )
        if self.attention_blocks:
            for resnet, attention in zip(self.res_blocks, self.attention_blocks):
                hidden_states = resnet(
                    hidden_states, causal=causal, timestep=timestep_embed
                )
                # Reshape the hidden states to be (batch_size, frames * height * width, channel)
                batch_size, channel, frames, height, width = hidden_states.shape
                hidden_states = hidden_states.view(
                    batch_size, channel, frames * height * width
                ).transpose(1, 2)
                if attention.use_tpu_flash_attention:
                    # Pad the second dimension to be divisible by block_k_major (block in flash attention)
                    seq_len = hidden_states.shape[1]
                    block_k_major = 512
                    pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
                    if pad_len > 0:
                        hidden_states = F.pad(
                            hidden_states, (0, 0, 0, pad_len), "constant", 0
                        )
                    # Create a mask with ones for the original sequence length and zeros for the padded indexes
                    mask = torch.ones(
                        (hidden_states.shape[0], seq_len),
                        device=hidden_states.device,
                        dtype=hidden_states.dtype,
                    )
                    if pad_len > 0:
                        mask = F.pad(mask, (0, pad_len), "constant", 0)
                hidden_states = attention(
                    hidden_states,
                    attention_mask=(
                        None if not attention.use_tpu_flash_attention else mask
                    ),
                )
                if attention.use_tpu_flash_attention:
                    # Remove the padding
                    if pad_len > 0:
                        hidden_states = hidden_states[:, :-pad_len, :]
                # Reshape the hidden states back to (batch_size, channel, frames, height, width)
                hidden_states = hidden_states.transpose(-1, -2).reshape(
                    batch_size, channel, frames, height, width
                )
        else:
            for resnet in self.res_blocks:
                hidden_states = resnet(
                    hidden_states, causal=causal, timestep=timestep_embed
                )
        return hidden_states
class SpaceToDepthDownsample(nn.Module):
    """Downsample by folding strided neighbourhoods into the channel axis.

    The result is the sum of two branches:
      * a learned branch: a causal conv producing ``out_channels / prod(stride)``
        channels, then folded space-to-depth to ``out_channels`` channels;
      * a parameter-free skip branch: the input folded space-to-depth, then
        averaged over consecutive groups of ``group_size`` channels.
    When the temporal stride is 2 the first frame is duplicated beforehand
    (presumably so the frame count becomes divisible by the stride).
    """

    def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
        super().__init__()
        self.stride = stride
        fold = np.prod(stride)
        # How many folded input channels are averaged into one skip channel.
        self.group_size = in_channels * fold // out_channels
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=out_channels // fold,
            kernel_size=3,
            stride=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

    def _space_to_depth(self, tensor):
        """Move each (stride[0], stride[1], stride[2]) neighbourhood into channels."""
        return rearrange(
            tensor,
            "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
            p1=self.stride[0],
            p2=self.stride[1],
            p3=self.stride[2],
        )

    def forward(self, x, causal: bool = True):
        if self.stride[0] == 2:
            # Prepend a copy of the first frame along the time axis.
            x = torch.cat([x[:, :, :1, :, :], x], dim=2)
        shortcut = rearrange(
            self._space_to_depth(x), "b (c g) d h w -> b c g d h w", g=self.group_size
        ).mean(dim=2)
        learned = self._space_to_depth(self.conv(x, causal=causal))
        return learned + shortcut
class DepthToSpaceUpsample(nn.Module):
    """Upsample with a causal conv followed by a depth-to-space pixel shuffle.

    The conv expands ``in_channels`` to ``prod(stride) * in_channels //
    out_channels_reduction_factor`` channels, and ``PixelShuffleND`` then trades
    channels for resolution. With a temporal stride of 2, the first output
    frame is dropped. If ``residual`` is set, the pixel-shuffled input is tiled
    along channels to the same shape and added.
    """

    def __init__(
        self,
        dims,
        in_channels,
        stride,
        residual=False,
        out_channels_reduction_factor=1,
        spatial_padding_mode="zeros",
    ):
        super().__init__()
        self.stride = stride
        expanded = np.prod(stride) * in_channels // out_channels_reduction_factor
        self.out_channels = expanded
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=expanded,
            kernel_size=3,
            stride=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
        self.pixel_shuffle = PixelShuffleND(dims=dims, upscale_factors=stride)
        self.residual = residual
        self.out_channels_reduction_factor = out_channels_reduction_factor

    def forward(self, x, causal: bool = True):
        trim_first_frame = self.stride[0] == 2
        shortcut = None
        if self.residual:
            # Tile the shuffled input along channels so it matches the conv branch.
            reps = np.prod(self.stride) // self.out_channels_reduction_factor
            shortcut = self.pixel_shuffle(x).repeat(1, reps, 1, 1, 1)
            if trim_first_frame:
                shortcut = shortcut[:, :, 1:, :, :]
        upsampled = self.pixel_shuffle(self.conv(x, causal=causal))
        if trim_first_frame:
            upsampled = upsampled[:, :, 1:, :, :]
        if shortcut is not None:
            upsampled = upsampled + shortcut
        return upsampled
class LayerNorm(nn.Module):
    """Channel-wise LayerNorm for 5D ``(b, c, d, h, w)`` tensors.

    The channel axis is moved last, normalized with ``nn.LayerNorm`` and moved
    back to position 1.
    """

    def __init__(self, dim, eps, elementwise_affine=True) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 4, 1)
        normalized = self.norm(channels_last)
        return normalized.permute(0, 4, 1, 2, 3)
class ResnetBlock3D(nn.Module):
    r"""
    A Resnet block: (norm -> SiLU -> conv) twice, added to a shortcut of the input.

    Parameters:
        dims (`int` or `Tuple[int, int]`): Passed to `make_conv_nd` / `make_linear_nd`.
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        groups (`int`, *optional*, default to `32`):
            The number of groups used when `norm_layer="group_norm"`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        norm_layer (`str`, *optional*, defaults to `"group_norm"`):
            One of `"group_norm"`, `"pixel_norm"` or `"layer_norm"`.
        inject_noise (`bool`, *optional*, defaults to `False`):
            Add per-channel-scaled spatial noise after each convolution.
        timestep_conditioning (`bool`, *optional*, defaults to `False`):
            Modulate each normalized activation with a shift/scale taken from the
            `timestep` tensor passed to `forward`.
        spatial_padding_mode (`str`, *optional*, defaults to `"zeros"`): Passed to both convolutions.

    Raises:
        ValueError: If `norm_layer` is not one of the supported names.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        groups: int = 32,
        eps: float = 1e-6,
        norm_layer: str = "group_norm",
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.inject_noise = inject_noise
        self.norm1 = self._make_norm(norm_layer, in_channels, groups, eps)
        self.non_linearity = nn.SiLU()
        self.conv1 = make_conv_nd(
            dims,
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
        if inject_noise:
            # Noise is added to conv1's output, which has `out_channels` channels.
            self.per_channel_scale1 = nn.Parameter(torch.zeros((out_channels, 1, 1)))
        self.norm2 = self._make_norm(norm_layer, out_channels, groups, eps)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = make_conv_nd(
            dims,
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
        if inject_noise:
            # Noise is added to conv2's output, which has `out_channels` channels.
            self.per_channel_scale2 = nn.Parameter(torch.zeros((out_channels, 1, 1)))
        self.conv_shortcut = (
            make_linear_nd(
                dims=dims, in_channels=in_channels, out_channels=out_channels
            )
            if in_channels != out_channels
            else nn.Identity()
        )
        self.norm3 = (
            LayerNorm(in_channels, eps=eps, elementwise_affine=True)
            if in_channels != out_channels
            else nn.Identity()
        )
        self.timestep_conditioning = timestep_conditioning
        if timestep_conditioning:
            # NOTE(review): rows are (shift1, scale1, shift2, scale2), all sized by
            # in_channels, so shift2/scale2 only match norm2's output when
            # in_channels == out_channels.
            self.scale_shift_table = nn.Parameter(
                torch.randn(4, in_channels) / in_channels**0.5
            )

    @staticmethod
    def _make_norm(norm_layer: str, num_channels: int, groups: int, eps: float):
        """Build the normalization module selected by `norm_layer`."""
        if norm_layer == "group_norm":
            return nn.GroupNorm(
                num_groups=groups, num_channels=num_channels, eps=eps, affine=True
            )
        if norm_layer == "pixel_norm":
            return PixelNorm()
        if norm_layer == "layer_norm":
            return LayerNorm(num_channels, eps=eps, elementwise_affine=True)
        raise ValueError(f"unknown norm_layer: {norm_layer}")

    def _feed_spatial_noise(
        self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Add one (height, width) noise map, scaled per channel, to every frame and batch item."""
        spatial_shape = hidden_states.shape[-2:]
        device = hidden_states.device
        dtype = hidden_states.dtype
        # similar to the "explicit noise inputs" method in style-gan
        spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
        # (C, H, W) -> (1, C, 1, H, W) to broadcast over batch and frames.
        scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
        hidden_states = hidden_states + scaled_noise
        return hidden_states

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        causal: bool = True,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Apply the block.

        Args:
            input_tensor: 5D tensor (batch, in_channels, frames, height, width).
            causal: Forwarded to both convolutions.
            timestep: Required when `timestep_conditioning=True`; it must be
                reshapeable to (batch, 4, in_channels, *timestep.shape[-3:]).
        """
        hidden_states = input_tensor
        batch_size = hidden_states.shape[0]
        hidden_states = self.norm1(hidden_states)
        if self.timestep_conditioning:
            assert (
                timestep is not None
            ), "should pass timestep with timestep_conditioning=True"
            ada_values = self.scale_shift_table[
                None, ..., None, None, None
            ] + timestep.reshape(
                batch_size,
                4,
                -1,
                timestep.shape[-3],
                timestep.shape[-2],
                timestep.shape[-1],
            )
            shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)
            hidden_states = hidden_states * (1 + scale1) + shift1
        hidden_states = self.non_linearity(hidden_states)
        hidden_states = self.conv1(hidden_states, causal=causal)
        if self.inject_noise:
            hidden_states = self._feed_spatial_noise(
                hidden_states, self.per_channel_scale1
            )
        hidden_states = self.norm2(hidden_states)
        if self.timestep_conditioning:
            hidden_states = hidden_states * (1 + scale2) + shift2
        hidden_states = self.non_linearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, causal=causal)
        if self.inject_noise:
            hidden_states = self._feed_spatial_noise(
                hidden_states, self.per_channel_scale2
            )
        # Shortcut: LayerNorm + 1x1 conv when the channel count changes, identity otherwise.
        input_tensor = self.norm3(input_tensor)
        input_tensor = self.conv_shortcut(input_tensor)
        output_tensor = input_tensor + hidden_states
        return output_tensor
def patchify(x, patch_size_hw, patch_size_t=1):
    """Fold spatial (and, for 5D input, temporal) patches into the channel axis.

    4D ``(b, c, H, W)`` -> ``(b, c * p_hw**2, H / p_hw, W / p_hw)``; ``patch_size_t``
    is ignored. 5D ``(b, c, F, H, W)`` -> ``(b, c * p_t * p_hw**2, F / p_t, H / p_hw,
    W / p_hw)``. Within the channel axis, the width factor precedes the height factor.
    ``x`` itself is returned when both patch sizes are 1. ``unpatchify`` is the inverse.

    Raises:
        ValueError: If ``x`` is neither 4D nor 5D (and a patch size is not 1).
    """
    if patch_size_hw == 1 and patch_size_t == 1:
        return x
    rank = x.dim()
    if rank == 4:
        return rearrange(
            x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
        )
    if rank == 5:
        return rearrange(
            x,
            "b c (f p) (h q) (w r) -> b (c p r q) f h w",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    raise ValueError(f"Invalid input shape: {x.shape}")
def unpatchify(x, patch_size_hw, patch_size_t=1):
    """Inverse of ``patchify``: unfold channel-packed patches back into space/time.

    Args:
        x: 4D ``(b, c * p_hw**2, h, w)`` or 5D ``(b, c * p_t * p_hw**2, f, h, w)``
            tensor; within the channel axis the width factor precedes the height factor.
        patch_size_hw: Spatial patch size ``p_hw``.
        patch_size_t: Temporal patch size ``p_t`` (only used for 5D input).

    Returns:
        ``x`` itself when both patch sizes are 1, otherwise the unpacked tensor.

    Raises:
        ValueError: If ``x`` is neither 4D nor 5D (and a patch size is not 1),
            matching ``patchify`` instead of silently returning the input.
    """
    if patch_size_hw == 1 and patch_size_t == 1:
        return x
    if x.dim() == 4:
        x = rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
        )
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b (c p r q) f h w -> b c (f p) (h q) (w r)",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")
    return x
def create_video_autoencoder_demo_config(
    latent_channels: int = 64,
):
    """Return a small ``CausalVideoAutoencoder`` config dict for demos and tests.

    The encoder consists of residual stages plus space/time/all compression
    stages; the decoder uses three residual ``compress_all`` upsampling stages.
    Every block entry gets its own fresh parameter dict.

    Args:
        latent_channels: Number of latent channels stored in the config.
    """
    encoder_blocks = [
        ("res_x", {"num_layers": 2}),
        *[
            (name, {"multiplier": 2})
            for name in (
                "compress_space_res",
                "compress_time_res",
                "compress_all_res",
                "compress_all_res",
            )
        ],
        ("res_x", {"num_layers": 1}),
    ]
    decoder_blocks = [
        ("res_x", {"num_layers": 2, "inject_noise": False}),
        *[("compress_all", {"residual": True, "multiplier": 2}) for _ in range(3)],
        ("res_x", {"num_layers": 2, "inject_noise": False}),
    ]
    config = dict(
        _class_name="CausalVideoAutoencoder",
        dims=3,
        encoder_blocks=encoder_blocks,
        decoder_blocks=decoder_blocks,
        latent_channels=latent_channels,
        norm_layer="pixel_norm",
        patch_size=4,
        latent_log_var="uniform",
        use_quant_conv=False,
        causal_decoder=False,
        timestep_conditioning=True,
        spatial_padding_mode="replicate",
    )
    return config
def test_vae_patchify_unpatchify():
    """Round-trip a 5D tensor through patchify/unpatchify and expect the input back."""
    import torch

    video = torch.randn(2, 3, 8, 64, 64)
    packed = patchify(video, patch_size_hw=4, patch_size_t=4)
    restored = unpatchify(packed, patch_size_hw=4, patch_size_t=4)
    assert torch.allclose(video, restored)
def demo_video_autoencoder_forward_backward():
    """Smoke-test a small CausalVideoAutoencoder: encode, decode and backprop.

    Builds the model from `create_video_autoencoder_demo_config`, prints the
    model, its parameter count and the tensor shapes, and checks that encoding
    a single frame yields the same latent as the first latent frame of the
    full clip. It then runs a backward pass through an MSE reconstruction loss.
    """
    # Configuration for the VideoAutoencoder
    config = create_video_autoencoder_demo_config()
    # Instantiate the VideoAutoencoder with the specified configuration
    video_autoencoder = CausalVideoAutoencoder.from_config(config)
    print(video_autoencoder)
    video_autoencoder.eval()
    # Print the total number of parameters in the video autoencoder
    total_params = sum(p.numel() for p in video_autoencoder.parameters())
    print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")
    # Mock batch of videos, shape (batch_size, channels, frames, height, width):
    # 2 videos, each with 3 color channels, 17 frames and 64x64 pixels per frame.
    input_videos = torch.randn(2, 3, 17, 64, 64)
    # Forward pass: encode and decode the input videos
    latent = video_autoencoder.encode(input_videos).latent_dist.mode()
    print(f"input shape={input_videos.shape}")
    print(f"latent shape={latent.shape}")
    timestep = torch.ones(input_videos.shape[0]) * 0.1
    reconstructed_videos = video_autoencoder.decode(
        latent, target_shape=input_videos.shape, timestep=timestep
    ).sample
    print(f"reconstructed shape={reconstructed_videos.shape}")
    # Validate that a single image gets treated the same way as the first frame.
    input_image = input_videos[:, :, :1, :, :]
    image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
    # Exercise the single-frame decode path; as for the full video above,
    # target_shape is the pixel-space shape of the input.
    _ = video_autoencoder.decode(
        image_latent, target_shape=input_image.shape, timestep=timestep
    ).sample
    first_frame_latent = latent[:, :, :1, :, :]
    assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)
    # Calculate the loss (e.g., mean squared error)
    loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)
    # Perform backward pass
    loss.backward()
    print(f"Demo completed with loss: {loss.item()}")
if __name__ == "__main__":
    # Executed only when this module is run directly: run the encode/decode/backward demo.
    demo_video_autoencoder_forward_backward()
```
## /ltx_video/models/autoencoders/conv_nd_factory.py
```py path="/ltx_video/models/autoencoders/conv_nd_factory.py"
from typing import Tuple, Union
import torch
from ltx_video.models.autoencoders.dual_conv3d import DualConv3d
from ltx_video.models.autoencoders.causal_conv3d import CausalConv3d
def make_conv_nd(
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
    bias=True,
    causal=False,
    spatial_padding_mode="zeros",
    temporal_padding_mode="zeros",
):
    """Build a convolution for the requested dimensionality.

    Args:
        dims: ``2`` -> ``torch.nn.Conv2d``; ``3`` -> ``CausalConv3d`` if ``causal``
            else ``torch.nn.Conv3d``; ``(2, 1)`` -> ``DualConv3d`` (spatial conv
            followed by a temporal conv).
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: Standard convolution arguments, forwarded to every variant.
        causal: Only affects ``dims == 3``.
        spatial_padding_mode: Padding mode for the chosen convolution.
        temporal_padding_mode: Must equal ``spatial_padding_mode`` unless ``causal``.

    Raises:
        NotImplementedError: If the padding modes differ and ``causal`` is False.
        ValueError: If ``dims`` is not one of ``2``, ``3`` or ``(2, 1)``.
    """
    if not (spatial_padding_mode == temporal_padding_mode or causal):
        raise NotImplementedError("spatial and temporal padding modes must be equal")
    if dims == 2:
        return torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=spatial_padding_mode,
        )
    elif dims == 3:
        if causal:
            return CausalConv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
                spatial_padding_mode=spatial_padding_mode,
            )
        return torch.nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=spatial_padding_mode,
        )
    elif dims == (2, 1):
        # DualConv3d supports dilation and groups; forward them instead of
        # silently falling back to its defaults. `causal` is not applied here.
        return DualConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=spatial_padding_mode,
        )
    else:
        raise ValueError(f"unsupported dimensions: {dims}")
def make_linear_nd(
    dims: int,
    in_channels: int,
    out_channels: int,
    bias=True,
):
    """Return a pointwise (kernel size 1) convolution acting as a per-position linear layer.

    ``dims == 2`` gives ``torch.nn.Conv2d``; ``dims == 3`` or ``(2, 1)`` gives
    ``torch.nn.Conv3d``.

    Raises:
        ValueError: For any other ``dims``.
    """
    if dims == 2:
        conv_cls = torch.nn.Conv2d
    elif dims in (3, (2, 1)):
        conv_cls = torch.nn.Conv3d
    else:
        raise ValueError(f"unsupported dimensions: {dims}")
    return conv_cls(
        in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
    )
```
## /ltx_video/models/autoencoders/dual_conv3d.py
```py path="/ltx_video/models/autoencoders/dual_conv3d.py"
import math
from typing import Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class DualConv3d(nn.Module):
    """Factorized 3D convolution: a spatial conv followed by a temporal conv.

    The first convolution has kernel ``(1, k_h, k_w)`` and acts only on the
    height/width axes; the second has kernel ``(k_d, 1, 1)`` and acts only on
    the depth (time) axis. ``forward`` evaluates the pair either with
    ``F.conv3d`` or with an equivalent ``F.conv2d`` + ``F.conv1d`` pair.

    Args:
        in_channels, out_channels: Channel counts. The intermediate channel
            count is ``max(in_channels, out_channels)``.
        kernel_size: int or ``(k_d, k_h, k_w)``; must not be ``(1, 1, 1)``.
        stride, padding, dilation: int or 3-tuple ``(d, h, w)``.
        groups: Groups used by both convolutions.
        bias: Whether both convolutions have a bias.
        padding_mode: ``"zeros"``, ``"reflect"``, ``"replicate"`` or
            ``"circular"``, with the same meaning as in ``torch.nn.Conv3d``.

    Raises:
        ValueError: If ``kernel_size`` is ``(1, 1, 1)`` or ``padding_mode`` is unsupported.
    """

    _PADDING_MODES = ("zeros", "reflect", "replicate", "circular")

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        groups=1,
        bias=True,
        padding_mode="zeros",
    ):
        super(DualConv3d, self).__init__()
        if padding_mode not in self._PADDING_MODES:
            raise ValueError(
                f"padding_mode must be one of {self._PADDING_MODES}, got {padding_mode!r}"
            )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding_mode = padding_mode
        # Ensure kernel_size, stride, padding, and dilation are tuples of length 3
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if tuple(kernel_size) == (1, 1, 1):
            raise ValueError(
                "kernel_size must be greater than 1. Use make_linear_nd instead."
            )
        if isinstance(stride, int):
            stride = (stride, stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding, padding)
        if isinstance(dilation, int):
            dilation = (dilation, dilation, dilation)
        # Set parameters for convolutions
        self.groups = groups
        self.bias = bias
        # Define the size of the channels after the first convolution
        intermediate_channels = (
            out_channels if in_channels < out_channels else in_channels
        )
        # First (spatial) convolution: kernel (1, k_h, k_w)
        self.weight1 = nn.Parameter(
            torch.empty(
                intermediate_channels,
                in_channels // groups,
                1,
                kernel_size[1],
                kernel_size[2],
            )
        )
        self.stride1 = (1, stride[1], stride[2])
        self.padding1 = (0, padding[1], padding[2])
        self.dilation1 = (1, dilation[1], dilation[2])
        if bias:
            self.bias1 = nn.Parameter(torch.empty(intermediate_channels))
        else:
            self.register_parameter("bias1", None)
        # Second (temporal) convolution: kernel (k_d, 1, 1)
        self.weight2 = nn.Parameter(
            torch.empty(
                out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
            )
        )
        self.stride2 = (stride[0], 1, 1)
        self.padding2 = (padding[0], 0, 0)
        self.dilation2 = (dilation[0], 1, 1)
        if bias:
            self.bias2 = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter("bias2", None)
        # Initialize weights and biases
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform weights and fan-in-bounded uniform biases (torch.nn.ConvNd defaults)."""
        nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
        if self.bias:
            fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
            bound1 = 1 / math.sqrt(fan_in1)
            nn.init.uniform_(self.bias1, -bound1, bound1)
            fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
            bound2 = 1 / math.sqrt(fan_in2)
            nn.init.uniform_(self.bias2, -bound2, bound2)

    def _conv(self, conv_fn, x, weight, bias, stride, padding, dilation):
        """Apply ``conv_fn`` (``F.conv1d/2d/3d``) honoring ``self.padding_mode``.

        The functional convolutions accept no ``padding_mode`` argument and
        always zero-pad, so any other mode is applied explicitly with ``F.pad``
        and the convolution then runs without padding (the same approach as
        ``torch.nn.Conv3d``).
        """
        if self.padding_mode == "zeros":
            return conv_fn(x, weight, bias, stride, padding, dilation, self.groups)
        pads = (padding,) if isinstance(padding, int) else tuple(padding)
        # F.pad takes (last_lo, last_hi, second_to_last_lo, ...), so reverse the dims.
        pad_arg = tuple(p for p in reversed(pads) for _ in range(2))
        x = F.pad(x, pad_arg, mode=self.padding_mode)
        return conv_fn(x, weight, bias, stride, 0, dilation, self.groups)

    def forward(self, x, use_conv3d=False, skip_time_conv=False):
        """Apply both convolutions (or only the spatial one if ``skip_time_conv``).

        ``x`` is a 5D ``(b, c, d, h, w)`` tensor. ``use_conv3d`` selects the
        ``F.conv3d`` implementation instead of the ``F.conv2d`` + ``F.conv1d`` one.
        """
        if use_conv3d:
            return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
        else:
            return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)

    def forward_with_3d(self, x, skip_time_conv):
        # First (spatial) convolution
        x = self._conv(
            F.conv3d, x, self.weight1, self.bias1, self.stride1, self.padding1, self.dilation1
        )
        if skip_time_conv:
            return x
        # Second (temporal) convolution
        x = self._conv(
            F.conv3d, x, self.weight2, self.bias2, self.stride2, self.padding2, self.dilation2
        )
        return x

    def forward_with_2d(self, x, skip_time_conv):
        b = x.shape[0]
        # First convolution: fold depth into the batch and run a per-frame 2D conv.
        x = rearrange(x, "b c d h w -> (b d) c h w")
        # Squeeze the depth dimension out of weight1 since it's 1
        weight1 = self.weight1.squeeze(2)
        # Select stride, padding, and dilation for the 2D convolution
        stride1 = (self.stride1[1], self.stride1[2])
        padding1 = (self.padding1[1], self.padding1[2])
        dilation1 = (self.dilation1[1], self.dilation1[2])
        x = self._conv(F.conv2d, x, weight1, self.bias1, stride1, padding1, dilation1)
        _, _, h, w = x.shape
        if skip_time_conv:
            x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
            return x
        # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
        x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
        # Reshape weight2 to match the expected dimensions for conv1d
        weight2 = self.weight2.squeeze(-1).squeeze(-1)
        # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
        stride2 = self.stride2[0]
        padding2 = self.padding2[0]
        dilation2 = self.dilation2[0]
        x = self._conv(F.conv1d, x, weight2, self.bias2, stride2, padding2, dilation2)
        x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
        return x

    @property
    def weight(self):
        """The temporal convolution's weight (``weight2``)."""
        return self.weight2
def test_dual_conv3d_consistency():
    """The conv3d path and the conv2d + conv1d path of DualConv3d must agree."""
    dual_conv3d = DualConv3d(
        in_channels=3,
        out_channels=5,
        kernel_size=(3, 3, 3),
        stride=(2, 2, 2),
        padding=(1, 1, 1),
        bias=True,
    )
    sample = torch.randn(1, 3, 10, 10, 10)
    via_conv3d = dual_conv3d(sample, use_conv3d=True)
    via_conv2d = dual_conv3d(sample, use_conv3d=False)
    assert torch.allclose(
        via_conv3d, via_conv2d, atol=1e-6
    ), "Outputs are not consistent between 3D and 2D convolutions."
```
## /ltx_video/models/autoencoders/latent_upsampler.py
```py path="/ltx_video/models/autoencoders/latent_upsampler.py"
from typing import Optional, Union
from pathlib import Path
import os
import json
import torch
import torch.nn as nn
from einops import rearrange
from diffusers import ConfigMixin, ModelMixin
from safetensors.torch import safe_open
from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND
class ResBlock(nn.Module):
    """Two (conv, GroupNorm) stages with SiLU and an identity skip connection.

    Args:
        channels: Input and output channel count (GroupNorm uses 32 groups, so
            it must be divisible by 32).
        mid_channels: Channels between the two convolutions; defaults to ``channels``.
        dims: ``2`` selects ``nn.Conv2d``; any other value selects ``nn.Conv3d``.
    """

    def __init__(
        self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
    ):
        super().__init__()
        hidden = channels if mid_channels is None else mid_channels
        conv_cls = nn.Conv2d if dims == 2 else nn.Conv3d
        self.conv1 = conv_cls(channels, hidden, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(32, hidden)
        self.conv2 = conv_cls(hidden, channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, channels)
        self.activation = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.activation(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))
        # The skip connection is added before the final activation.
        return self.activation(out + x)
class LatentUpsampler(ModelMixin, ConfigMixin):
    """
    Model to spatially and/or temporally upsample VAE latents.

    Args:
        in_channels (`int`): Number of channels in the input latent
        mid_channels (`int`): Number of channels in the middle layers
        num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
        dims (`int`): Number of dimensions for convolutions (2 or 3)
        spatial_upsample (`bool`): Whether to spatially upsample the latent (2x height and width)
        temporal_upsample (`bool`): Whether to temporally upsample the latent (f frames -> 2 * f - 1)

    Raises:
        ValueError: If neither `spatial_upsample` nor `temporal_upsample` is True.
    """

    def __init__(
        self,
        in_channels: int = 128,
        mid_channels: int = 512,
        num_blocks_per_stage: int = 4,
        dims: int = 3,
        spatial_upsample: bool = True,
        temporal_upsample: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.num_blocks_per_stage = num_blocks_per_stage
        self.dims = dims
        self.spatial_upsample = spatial_upsample
        self.temporal_upsample = temporal_upsample
        Conv = nn.Conv2d if dims == 2 else nn.Conv3d
        self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
        self.initial_norm = nn.GroupNorm(32, mid_channels)
        self.initial_activation = nn.SiLU()
        self.res_blocks = nn.ModuleList(
            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
        )
        if spatial_upsample and temporal_upsample:
            # 8x channels, then depth-to-space over (frames, height, width).
            self.upsampler = nn.Sequential(
                nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(3),
            )
        elif spatial_upsample:
            # 4x channels, then depth-to-space over (height, width); runs per frame.
            self.upsampler = nn.Sequential(
                nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(2),
            )
        elif temporal_upsample:
            # 2x channels, then depth-to-space over frames.
            self.upsampler = nn.Sequential(
                nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(1),
            )
        else:
            raise ValueError(
                "Either spatial_upsample or temporal_upsample must be True"
            )
        self.post_upsample_res_blocks = nn.ModuleList(
            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
        )
        self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        """Upsample a 5D latent of shape (b, c, f, h, w).

        With `dims == 2` every frame is processed independently. With
        `dims == 3` the residual stages use 3D convolutions; spatial-only
        upsampling still runs frame by frame, while temporal upsampling drops
        the first output frame (f frames become 2 * f - 1).
        """
        b, _, f, _, _ = latent.shape
        if self.dims == 2:
            x = rearrange(latent, "b c f h w -> (b f) c h w")
            x = self.initial_conv(x)
            x = self.initial_norm(x)
            x = self.initial_activation(x)
            for block in self.res_blocks:
                x = block(x)
            x = self.upsampler(x)
            for block in self.post_upsample_res_blocks:
                x = block(x)
            x = self.final_conv(x)
            x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
        else:
            x = self.initial_conv(latent)
            x = self.initial_norm(x)
            x = self.initial_activation(x)
            for block in self.res_blocks:
                x = block(x)
            if self.temporal_upsample:
                x = self.upsampler(x)
                x = x[:, :, 1:, :, :]
            else:
                x = rearrange(x, "b c f h w -> (b f) c h w")
                x = self.upsampler(x)
                x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
            for block in self.post_upsample_res_blocks:
                x = block(x)
            x = self.final_conv(x)
        return x

    @classmethod
    def from_config(cls, config):
        """Build a model from a plain dict; missing keys fall back to the defaults below."""
        return cls(
            in_channels=config.get("in_channels", 4),
            mid_channels=config.get("mid_channels", 128),
            num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
            dims=config.get("dims", 2),
            spatial_upsample=config.get("spatial_upsample", True),
            temporal_upsample=config.get("temporal_upsample", False),
        )

    # NOTE(review): this plain method overrides ConfigMixin's `config` attribute — confirm intended.
    def config(self):
        """Return the constructor arguments as a dict (with `_class_name`)."""
        return {
            "_class_name": "LatentUpsampler",
            "in_channels": self.in_channels,
            "mid_channels": self.mid_channels,
            "num_blocks_per_stage": self.num_blocks_per_stage,
            "dims": self.dims,
            "spatial_upsample": self.spatial_upsample,
            "temporal_upsample": self.temporal_upsample,
        }

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        """Load a model from a single `.safetensors` file.

        The file's metadata must contain a JSON-encoded `"config"` entry. The
        model is built on the meta device and the tensors are assigned directly.
        Extra positional/keyword arguments are accepted and ignored.

        Raises:
            ValueError: If `pretrained_model_path` is not an existing `.safetensors` file.
        """
        pretrained_model_path = Path(pretrained_model_path)
        if not (
            pretrained_model_path.is_file()
            and str(pretrained_model_path).endswith(".safetensors")
        ):
            raise ValueError(
                "LatentUpsampler.from_pretrained expects a path to an existing "
                f".safetensors file, got: {pretrained_model_path}"
            )
        with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
            metadata = f.metadata()
            state_dict = {k: f.get_tensor(k) for k in f.keys()}
        config = json.loads(metadata["config"])
        with torch.device("meta"):
            latent_upsampler = cls.from_config(config)
        latent_upsampler.load_state_dict(state_dict, assign=True)
        return latent_upsampler
if __name__ == "__main__":
    # Smoke test: build a 3D upsampler with default channel sizes, report its
    # size and upsample a random (1, 128, 9, 16, 16) latent.
    latent_upsampler = LatentUpsampler(num_blocks_per_stage=4, dims=3)
    print(latent_upsampler)
    total_params = sum(param.numel() for param in latent_upsampler.parameters())
    print(f"Total number of parameters: {total_params:,}")
    upsampled_latent = latent_upsampler(torch.randn(1, 128, 9, 16, 16))
    print(f"Upsampled latent shape: {upsampled_latent.shape}")
```
## /ltx_video/models/autoencoders/pixel_norm.py
```py path="/ltx_video/models/autoencoders/pixel_norm.py"
import torch
from torch import nn
class PixelNorm(nn.Module):
    """Divide by the root-mean-square of the values along ``dim``.

    Args:
        dim: Axis over which the mean of squares is taken (default 1).
        eps: Added under the square root to avoid division by zero.
    """

    def __init__(self, dim=1, eps=1e-8):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        mean_square = torch.mean(x**2, dim=self.dim, keepdim=True)
        return x / torch.sqrt(mean_square + self.eps)
```
## /ltx_video/models/autoencoders/pixel_shuffle.py
```py path="/ltx_video/models/autoencoders/pixel_shuffle.py"
import torch.nn as nn
from einops import rearrange
class PixelShuffleND(nn.Module):
    """Depth-to-space rearrangement over 1, 2 or 3 axes.

    Args:
        dims: ``3`` -> ``(d, h, w)`` of a 5D tensor; ``2`` -> ``(h, w)`` of a 4D
            tensor; ``1`` -> the frame axis of a 5D tensor.
        upscale_factors: Per-axis factors; only the first ``dims`` entries are used.

    Raises:
        ValueError: If ``dims`` is not 1, 2 or 3. This uses a real exception
            rather than ``assert``, because assertions are stripped under ``python -O``.
    """

    def __init__(self, dims, upscale_factors=(2, 2, 2)):
        super().__init__()
        if dims not in (1, 2, 3):
            raise ValueError(f"dims must be 1, 2, or 3, got {dims}")
        self.dims = dims
        self.upscale_factors = upscale_factors

    def forward(self, x):
        if self.dims == 3:
            return rearrange(
                x,
                "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
                p1=self.upscale_factors[0],
                p2=self.upscale_factors[1],
                p3=self.upscale_factors[2],
            )
        elif self.dims == 2:
            return rearrange(
                x,
                "b (c p1 p2) h w -> b c (h p1) (w p2)",
                p1=self.upscale_factors[0],
                p2=self.upscale_factors[1],
            )
        elif self.dims == 1:
            return rearrange(
                x,
                "b (c p1) f h w -> b c (f p1) h w",
                p1=self.upscale_factors[0],
            )
```
## /ltx_video/models/autoencoders/vae.py
```py path="/ltx_video/models/autoencoders/vae.py"
from typing import Optional, Union
import torch
import inspect
import math
import torch.nn as nn
from diffusers import ConfigMixin, ModelMixin
from diffusers.models.autoencoders.vae import (
DecoderOutput,
DiagonalGaussianDistribution,
)
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd
class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
    """Variational Autoencoder (VAE) model with KL loss.

    VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling.
    This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the
    reconstruction loss.

    Two optional memory-saving modes are provided:
      * z tiling (`enable_z_tiling`): the frame axis (dim 2) is processed in chunks.
      * hw tiling (`enable_hw_tiling`): height/width (dims 3 and 4) are processed in overlapping
        tiles whose seams are linearly blended.

    Args:
        encoder (`nn.Module`):
            Encoder module. Its `down_blocks` length is used to derive tile sizes.
        decoder (`nn.Module`):
            Decoder module. Its `forward` receives `target_shape`, and also `timestep` when its
            signature declares a `timestep` parameter.
        latent_channels (`int`, *optional*, defaults to 4):
            Number of latent channels.
        dims (`int`, *optional*, defaults to 2):
            `2` selects 2D quant convs / `BatchNorm2d`; any other value selects the 3D variants.
        sample_size (`int`, *optional*, defaults to 512):
            Pixel-space tile size used by hw tiling.
        use_quant_conv (`bool`, *optional*, defaults to `True`):
            Whether to apply 1x1 quant / post-quant convolutions around the latent.
        normalize_latent_channels (`bool`, *optional*, defaults to `False`):
            Whether to normalize latents with a non-affine BatchNorm (`latent_norm_out`).
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        latent_channels: int = 4,
        dims: int = 2,
        sample_size=512,
        use_quant_conv: bool = True,
        normalize_latent_channels: bool = False,
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = encoder
        self.use_quant_conv = use_quant_conv
        self.normalize_latent_channels = normalize_latent_channels

        # pass init params to Decoder
        quant_dims = 2 if dims == 2 else 3
        self.decoder = decoder
        if use_quant_conv:
            # The encoder emits 2 * latent_channels moments; the decoder consumes latent_channels.
            self.quant_conv = make_conv_nd(
                quant_dims, 2 * latent_channels, 2 * latent_channels, 1
            )
            self.post_quant_conv = make_conv_nd(
                quant_dims, latent_channels, latent_channels, 1
            )
        else:
            self.quant_conv = nn.Identity()
            self.post_quant_conv = nn.Identity()

        if normalize_latent_channels:
            if dims == 2:
                self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False)
            else:
                self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False)
        else:
            self.latent_norm_out = nn.Identity()
        self.use_z_tiling = False
        self.use_hw_tiling = False
        self.dims = dims
        self.z_sample_size = 1

        # Cached so `_call_decoder` can tell whether the decoder accepts `timestep`.
        self.decoder_params = inspect.signature(self.decoder.forward).parameters

        # only relevant if vae tiling is enabled
        self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25)

    def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25):
        """Configure the tile sizes used by hw tiling.

        Args:
            sample_size: Tile edge length in pixel space.
            overlap_factor: Fraction of a tile that overlaps its neighbour.
        """
        self.tile_sample_min_size = sample_size
        num_blocks = len(self.encoder.down_blocks)
        # NOTE(review): assumes every encoder down block except the last halves H and W;
        # confirm for encoders whose patch_size or downsample layout differs.
        self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1)))
        self.tile_overlap_factor = overlap_factor

    def enable_z_tiling(self, z_sample_size: int = 8):
        r"""
        Enable tiling along the frame axis during VAE encoding and decoding.

        When this option is enabled, the VAE splits the input along the frame axis and
        processes the chunks in several steps. This is useful to save some memory and allow
        larger batch sizes.

        Args:
            z_sample_size (`int`, *optional*, defaults to 8):
                Chunk size along the frame axis; must be a multiple of 8, or 1 (no tiling).
        """
        # Validate before mutating so an invalid value leaves the previous settings intact.
        assert (
            z_sample_size % 8 == 0 or z_sample_size == 1
        ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}."
        self.use_z_tiling = z_sample_size > 1
        self.z_sample_size = z_sample_size

    def disable_z_tiling(self):
        r"""
        Disable frame-axis tiling. If `enable_z_tiling` was previously invoked, this method goes
        back to processing all frames in one step.
        """
        self.use_z_tiling = False

    def enable_hw_tiling(self):
        r"""
        Enable tiling during VAE encoding and decoding along the height and width dimensions.
        """
        self.use_hw_tiling = True

    def disable_hw_tiling(self):
        r"""
        Disable tiling during VAE encoding and decoding along the height and width dimensions.
        """
        self.use_hw_tiling = False

    def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True):
        """Encode `x` in overlapping height/width tiles and blend the seams.

        `return_dict` is accepted for signature compatibility and is not used.

        Returns:
            The blended moments tensor, normalized exactly as `_encode` normalizes its output.
        """
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the input into overlapping tile_sample_min_size x tile_sample_min_size tiles
        # and encode them separately.
        rows = []
        for i in range(0, x.shape[3], overlap_size):
            row = []
            for j in range(0, x.shape[4], overlap_size):
                tile = x[
                    :,
                    :,
                    :,
                    i : i + self.tile_sample_min_size,
                    j : j + self.tile_sample_min_size,
                ]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))

        moments = torch.cat(result_rows, dim=3)
        # `_encode` normalizes latent channels and `_decode` always un-normalizes them, so the
        # tiled path must normalize too or its latents are incompatible with untiled decoding.
        return self._normalize_latent_channels(moments)

    def blend_z(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        """Linearly cross-fade the tail of `a` into the head of `b` along dim 2 (in place on `b`)."""
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for z in range(blend_extent):
            b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (
                1 - z / blend_extent
            ) + b[:, :, z, :, :] * (z / blend_extent)
        return b

    def blend_v(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        """Linearly cross-fade the tail of `a` into the head of `b` along dim 3 (in place on `b`)."""
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
                1 - y / blend_extent
            ) + b[:, :, :, y, :] * (y / blend_extent)
        return b

    def blend_h(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        """Linearly cross-fade the tail of `a` into the head of `b` along dim 4 (in place on `b`)."""
        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
                1 - x / blend_extent
            ) + b[:, :, :, :, x] * (x / blend_extent)
        return b

    def _hw_tiled_decode(
        self,
        z: torch.FloatTensor,
        target_shape,
        timestep: Optional[torch.Tensor] = None,
    ):
        """Decode `z` in overlapping height/width tiles and blend the seams.

        Args:
            z: Latent tensor (normalized the same way `_encode` produces it).
            target_shape: Full output shape; its first three entries are reused per tile.
            timestep: Optional decoder conditioning, forwarded when the decoder accepts it.
        """
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent
        tile_target_shape = (
            *target_shape[:3],
            self.tile_sample_min_size,
            self.tile_sample_min_size,
        )
        # Match `_decode`: undo latent-channel normalization before post_quant_conv.
        # It is a per-channel affine map, so applying it once to all of `z` is equivalent
        # to applying it per tile.
        z = self._unnormalize_latent_channels(z)
        # Split z into overlapping tile_latent_min_size x tile_latent_min_size tiles and
        # decode them separately. The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[3], overlap_size):
            row = []
            for j in range(0, z.shape[4], overlap_size):
                tile = z[
                    :,
                    :,
                    :,
                    i : i + self.tile_latent_min_size,
                    j : j + self.tile_latent_min_size,
                ]
                tile = self.post_quant_conv(tile)
                decoded = self._call_decoder(tile, tile_target_shape, timestep)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))

        dec = torch.cat(result_rows, dim=3)
        return dec

    def encode(
        self, z: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, tuple]:
        """Encode an input into a diagonal Gaussian posterior.

        With z tiling enabled and more frames than `z_sample_size`, the frame axis (dim 2) is
        split into `z_sample_size` chunks (plus a remainder chunk) whose moments are
        concatenated.

        Args:
            z: Input tensor; dim 2 is the frame axis.
            return_dict: Return an `AutoencoderKLOutput` if True, else a 1-tuple.
        """
        if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
            num_splits = z.shape[2] // self.z_sample_size
            sizes = [self.z_sample_size] * num_splits
            sizes = (
                sizes + [z.shape[2] - sum(sizes)]
                if z.shape[2] - sum(sizes) > 0
                else sizes
            )
            tiles = z.split(sizes, dim=2)
            moments_tiles = [
                (
                    self._hw_tiled_encode(z_tile, return_dict)
                    if self.use_hw_tiling
                    else self._encode(z_tile)
                )
                for z_tile in tiles
            ]
            moments = torch.cat(moments_tiles, dim=2)

        else:
            moments = (
                self._hw_tiled_encode(z, return_dict)
                if self.use_hw_tiling
                else self._encode(z)
            )

        posterior = DiagonalGaussianDistribution(moments)
        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
        """Apply `latent_norm_out` to the first half of the channels (the mean half of the moments)."""
        if isinstance(self.latent_norm_out, nn.BatchNorm3d):
            _, c, _, _, _ = z.shape
            z = torch.cat(
                [
                    self.latent_norm_out(z[:, : c // 2, :, :, :]),
                    z[:, c // 2 :, :, :, :],
                ],
                dim=1,
            )
        elif isinstance(self.latent_norm_out, nn.BatchNorm2d):
            raise NotImplementedError("BatchNorm2d not supported")
        return z

    def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
        """Invert `latent_norm_out` using its running statistics (no-op when it is Identity)."""
        if isinstance(self.latent_norm_out, nn.BatchNorm3d):
            running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1)
            running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1)
            eps = self.latent_norm_out.eps

            z = z * torch.sqrt(running_var + eps) + running_mean
        elif isinstance(self.latent_norm_out, nn.BatchNorm2d):
            # Mirrors `_normalize_latent_channels`; previously this branch re-tested
            # BatchNorm3d, so BatchNorm2d latents were silently returned un-normalized.
            raise NotImplementedError("BatchNorm2d not supported")
        return z

    def _encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Run encoder + quant_conv on `x` and normalize the resulting moments."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        moments = self._normalize_latent_channels(moments)
        return moments

    def _call_decoder(
        self,
        z: torch.FloatTensor,
        target_shape,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Invoke the decoder, passing `timestep` only if its `forward` declares that parameter."""
        if "timestep" in self.decoder_params:
            return self.decoder(z, target_shape=target_shape, timestep=timestep)
        return self.decoder(z, target_shape=target_shape)

    def _decode(
        self,
        z: torch.FloatTensor,
        target_shape=None,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Un-normalize, apply post_quant_conv, and decode `z` in one step."""
        z = self._unnormalize_latent_channels(z)
        z = self.post_quant_conv(z)
        return self._call_decoder(z, target_shape, timestep)

    def decode(
        self,
        z: torch.FloatTensor,
        return_dict: bool = True,
        target_shape=None,
        timestep: Optional[torch.Tensor] = None,
    ) -> Union[DecoderOutput, tuple]:
        """Decode latents `z` into samples.

        Args:
            z: Latent tensor; dim 2 is the frame axis.
            return_dict: Return a `DecoderOutput` if True, else a 1-tuple.
            target_shape: Required. Expected output shape; entry 2 is the output frame count.
            timestep: Optional decoder conditioning. It is forwarded on every decode path
                (tiled or not) whenever the decoder's `forward` accepts it.
        """
        assert target_shape is not None, "target_shape must be provided for decoding"
        if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
            reduction_factor = int(
                self.encoder.patch_size_t
                * 2
                ** (
                    len(self.encoder.down_blocks)
                    - 1
                    - math.sqrt(self.encoder.patch_size)
                )
            )
            split_size = self.z_sample_size // reduction_factor
            num_splits = z.shape[2] // split_size

            # copy target shape, and divide frame dimension (=2) by the context size
            target_shape_split = list(target_shape)
            target_shape_split[2] = target_shape[2] // num_splits

            decoded_tiles = [
                (
                    self._hw_tiled_decode(z_tile, target_shape_split, timestep)
                    if self.use_hw_tiling
                    else self._decode(
                        z_tile, target_shape=target_shape_split, timestep=timestep
                    )
                )
                for z_tile in torch.tensor_split(z, num_splits, dim=2)
            ]
            decoded = torch.cat(decoded_tiles, dim=2)
        else:
            decoded = (
                self._hw_tiled_decode(z, target_shape, timestep)
                if self.use_hw_tiling
                else self._decode(z, target_shape=target_shape, timestep=timestep)
            )

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`DecoderOutput`] instead of a plain tuple.
            generator (`torch.Generator`, *optional*):
                Generator used to sample from the posterior.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z, target_shape=sample.shape).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
```
## /ltx_video/models/autoencoders/vae_encode.py
```py path="/ltx_video/models/autoencoders/vae_encode.py"
from typing import Tuple
import torch
from diffusers import AutoencoderKL
from einops import rearrange
from torch import Tensor
from ltx_video.models.autoencoders.causal_video_autoencoder import (
CausalVideoAutoencoder,
)
from ltx_video.models.autoencoders.video_autoencoder import (
Downsample3D,
VideoAutoencoder,
)
try:
import torch_xla.core.xla_model as xm
except ImportError:
xm = None
def vae_encode(
    media_items: Tensor,
    vae: AutoencoderKL,
    split_size: int = 1,
    vae_per_channel_normalize=False,
) -> Tensor:
    """
    Encodes media items (images or videos) into latent representations using a specified VAE model.
    The function supports processing batches of images or video frames and can handle the processing
    in smaller sub-batches if needed.

    Args:
        media_items (Tensor): A torch Tensor containing the media items to encode. The expected
            shape is (batch_size, channels, height, width) for images or (batch_size, channels,
            frames, height, width) for videos. `channels` must be 3.
        vae (AutoencoderKL): The VAE used for encoding: either a `diffusers` `AutoencoderKL`
            or one of this package's `VideoAutoencoder` / `CausalVideoAutoencoder`, loaded with
            the appropriate model weights.
        split_size (int, optional): The number of sub-batches to split the input batch into for encoding.
            If set to more than 1, the input media items are processed in smaller batches according to
            this value. Defaults to 1, which processes all items in a single batch.
        vae_per_channel_normalize (bool, optional): If True, latents are normalized with the VAE's
            per-channel `mean_of_means` / `std_of_means` statistics; otherwise they are multiplied
            by `vae.config.scaling_factor`. Defaults to False.

    Returns:
        Tensor: The encoded latents, drawn with `latent_dist.sample()` and normalized as described
        above. Video inputs produce (batch_size, latent_channels, frames, height, width) latents.

    Raises:
        ValueError: If `media_items` does not have 3 channels, or if `split_size > 1` and the
            (possibly frame-flattened) batch size is not divisible by `split_size`.

    Examples:
        >>> import torch
        >>> from diffusers import AutoencoderKL
        >>> vae = AutoencoderKL.from_pretrained('your-model-name')
        >>> images = torch.rand(10, 3, 8, 256, 256)  # Example tensor with 10 videos of 8 frames.
        >>> latents = vae_encode(images, vae)
        >>> print(latents.shape)  # Output shape will depend on the model's latent configuration.

    Note:
        `VideoAutoencoder` / `CausalVideoAutoencoder` encode videos as whole 5D tensors. For any
        other VAE, video frames are folded into the batch axis and encoded frame-by-frame, then
        unfolded back to (batch_size, channels, frames, height, width).
    """
    is_video_shaped = media_items.dim() == 5
    batch_size, channels = media_items.shape[0:2]

    if channels != 3:
        raise ValueError(f"Expects tensors with 3 channels, got {channels}.")

    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        # Image-only VAE: treat every frame as an independent image.
        media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
    if split_size > 1:
        if len(media_items) % split_size != 0:
            raise ValueError(
                "Error: The batch size must be divisible by 'train.vae_bs_split"
            )
        encode_bs = len(media_items) // split_size
        latents = []
        # On XLA devices, cut the lazy graph around each sub-batch encode.
        if media_items.device.type == "xla":
            xm.mark_step()
        for image_batch in media_items.split(encode_bs):
            latents.append(vae.encode(image_batch).latent_dist.sample())
            if media_items.device.type == "xla":
                xm.mark_step()
        latents = torch.cat(latents, dim=0)
    else:
        latents = vae.encode(media_items).latent_dist.sample()

    latents = normalize_latents(latents, vae, vae_per_channel_normalize)
    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
    return latents
def vae_decode(
    latents: Tensor,
    vae: AutoencoderKL,
    is_video: bool = True,
    split_size: int = 1,
    vae_per_channel_normalize=False,
    timestep=None,
) -> Tensor:
    """Decode latents back to pixel space with ``vae``.

    For a 5D input and a VAE that is neither `VideoAutoencoder` nor `CausalVideoAutoencoder`,
    frames are folded into the batch axis for decoding and unfolded afterwards.

    Args:
        latents: Latent tensor, 4D or 5D.
        vae: The VAE used for decoding.
        is_video: Forwarded to `_run_decoder` (controls the decoded frame count for video VAEs).
        split_size: If > 1, decode the batch in this many equal sub-batches.
        vae_per_channel_normalize: Which latent un-normalization to apply (see
            `un_normalize_latents`).
        timestep: Optional decoder conditioning forwarded to `_run_decoder`.

    Raises:
        ValueError: If `split_size > 1` and the batch size is not divisible by it.
    """
    five_dim = latents.dim() == 5
    n_batch = latents.shape[0]
    fold_frames = five_dim and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    )
    if fold_frames:
        latents = rearrange(latents, "b c n h w -> (b n) c h w")

    if split_size <= 1:
        images = _run_decoder(
            latents, vae, is_video, vae_per_channel_normalize, timestep
        )
    else:
        if len(latents) % split_size != 0:
            raise ValueError(
                "Error: The batch size must be divisible by 'train.vae_bs_split"
            )
        chunk = len(latents) // split_size
        decoded_chunks = [
            _run_decoder(piece, vae, is_video, vae_per_channel_normalize, timestep)
            for piece in latents.split(chunk)
        ]
        images = torch.cat(decoded_chunks, dim=0)

    if fold_frames:
        images = rearrange(images, "(b n) c h w -> b c n h w", b=n_batch)
    return images
def _run_decoder(
    latents: Tensor,
    vae: AutoencoderKL,
    is_video: bool,
    vae_per_channel_normalize=False,
    timestep=None,
) -> Tensor:
    """Un-normalize ``latents`` and decode them with ``vae``.

    Video VAEs additionally receive a ``target_shape`` of ``(1, 3, frames, height, width)``
    computed from the latent size and `get_vae_size_scale_factor` (frames is 1 when
    ``is_video`` is False), plus ``timestep`` when it is not None. The latents are also cast
    to ``vae.dtype`` first.
    """
    if not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
        return vae.decode(
            un_normalize_latents(latents, vae, vae_per_channel_normalize),
            return_dict=False,
        )[0]

    *_, frames, height, width = latents.shape
    t_scale, s_scale, _ = get_vae_size_scale_factor(vae)
    latents = latents.to(vae.dtype)
    extra_kwargs = {"timestep": timestep} if timestep is not None else {}
    target_shape = (
        1,
        3,
        frames * t_scale if is_video else 1,
        height * s_scale,
        width * s_scale,
    )
    return vae.decode(
        un_normalize_latents(latents, vae, vae_per_channel_normalize),
        return_dict=False,
        target_shape=target_shape,
        **extra_kwargs,
    )[0]
def get_vae_size_scale_factor(vae: AutoencoderKL) -> Tuple[int, int, int]:
    """Return the ``(temporal, spatial, spatial)`` downscale factors from pixels to latents.

    For `CausalVideoAutoencoder` the factors are read from the model's
    `temporal_downscale_factor` / `spatial_downscale_factor`. Otherwise, with ``k`` the number
    of encoder down blocks whose ``downsample`` is a `Downsample3D`, the spatial factor is
    ``config.patch_size * 2**k``; the temporal factor is ``config.patch_size_t * 2**k`` for
    `VideoAutoencoder` and 1 for any other VAE.
    """
    if isinstance(vae, CausalVideoAutoencoder):
        spatial = vae.spatial_downscale_factor
        temporal = vae.temporal_downscale_factor
    else:
        down_blocks = len(
            [
                block
                for block in vae.encoder.down_blocks
                if isinstance(block.downsample, Downsample3D)
            ]
        )
        spatial = vae.config.patch_size * 2**down_blocks
        temporal = (
            vae.config.patch_size_t * 2**down_blocks
            if isinstance(vae, VideoAutoencoder)
            else 1
        )

    return (temporal, spatial, spatial)
def latent_to_pixel_coords(
    latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False
) -> Tensor:
    """
    Map latent coordinates to pixel coordinates using the VAE's downscale factors.

    Args:
        latent_coords (Tensor): Tensor of shape [batch_size, 3, num_latents] holding the
            latent corner coordinates of each token.
        vae (AutoencoderKL): The VAE whose scale factors are applied.
        causal_fix (bool): Account for the different temporal scale of the first frame.
            Only honoured for `CausalVideoAutoencoder`. Default = False for backwards
            compatibility.

    Returns:
        Tensor: Pixel coordinates corresponding to ``latent_coords``.
    """
    # The first-frame correction is only meaningful for causal video VAEs.
    apply_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix
    return latent_to_pixel_coords_from_factors(
        latent_coords, get_vae_size_scale_factor(vae), apply_fix
    )
def latent_to_pixel_coords_from_factors(
    latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False
) -> Tensor:
    """Scale latent coordinates by explicit per-axis factors.

    Args:
        latent_coords: Tensor whose axis 1 indexes the coordinate axes (e.g. shape
            [batch_size, 3, num_latents]).
        scale_factors: One factor per coordinate axis; entry 0 is the temporal factor.
        causal_fix: If True, shift axis 0 by ``1 - scale_factors[0]`` and clamp at zero, so
            the first frame counts as a single pixel frame.

    Returns:
        A new tensor of pixel coordinates; ``latent_coords`` is not modified.
    """
    factors = torch.tensor(scale_factors, device=latent_coords.device)
    pixel_coords = latent_coords * factors[None, :, None]
    if not causal_fix:
        return pixel_coords
    # Fix temporal scale for first frame to 1 due to causality
    shifted_time = pixel_coords[:, 0] + 1 - scale_factors[0]
    pixel_coords[:, 0] = shifted_time.clamp(min=0)
    return pixel_coords
def normalize_latents(
    latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
) -> Tensor:
    """Normalize raw VAE latents.

    Args:
        latents: Latent tensor with channels on axis 1 (4D image or 5D video latents).
        vae: Supplies `mean_of_means` / `std_of_means` (per-channel mode) or
            `config.scaling_factor` (default mode).
        vae_per_channel_normalize: If True, compute ``(latents - mean) / std`` per channel;
            otherwise multiply by ``vae.config.scaling_factor``.

    Returns:
        The normalized latents, same shape as ``latents``.
    """
    if not vae_per_channel_normalize:
        return latents * vae.config.scaling_factor
    # Broadcast the per-channel statistics over every axis after the channel axis. A fixed
    # 5D view mis-broadcasts the 4D latents that `vae_encode` produces for image VAEs.
    stats_shape = (1, -1) + (1,) * (latents.dim() - 2)
    mean = vae.mean_of_means.to(latents.dtype).view(stats_shape)
    std = vae.std_of_means.to(latents.dtype).view(stats_shape)
    return (latents - mean) / std
def un_normalize_latents(
    latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
) -> Tensor:
    """Invert `normalize_latents`.

    Args:
        latents: Normalized latent tensor with channels on axis 1 (4D or 5D).
        vae: Supplies `mean_of_means` / `std_of_means` (per-channel mode) or
            `config.scaling_factor` (default mode).
        vae_per_channel_normalize: If True, compute ``latents * std + mean`` per channel;
            otherwise divide by ``vae.config.scaling_factor``.

    Returns:
        The un-normalized latents, same shape as ``latents``.
    """
    if not vae_per_channel_normalize:
        return latents / vae.config.scaling_factor
    # Broadcast per-channel statistics over all trailing axes so 4D and 5D latents both work.
    stats_shape = (1, -1) + (1,) * (latents.dim() - 2)
    mean = vae.mean_of_means.to(latents.dtype).view(stats_shape)
    std = vae.std_of_means.to(latents.dtype).view(stats_shape)
    return latents * std + mean
```
## /ltx_video/models/autoencoders/video_autoencoder.py
```py path="/ltx_video/models/autoencoders/video_autoencoder.py"
import json
import os
from functools import partial
from types import SimpleNamespace
from typing import Any, Mapping, Optional, Tuple, Union
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional
from diffusers.utils import logging
from ltx_video.utils.torch_utils import Identity
from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
from ltx_video.models.autoencoders.pixel_norm import PixelNorm
from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper
logger = logging.get_logger(__name__)
class VideoAutoencoder(AutoencoderKLWrapper):
    """`AutoencoderKLWrapper` assembled from this module's `Encoder` and `Decoder`."""

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        """Load a `VideoAutoencoder` from a local directory.

        The directory must contain `config.json` and `autoencoder.pth`. An optional
        `per_channel_statistics.json` provides the `std_of_means` / `mean_of_means` buffers
        (`mean_of_means` defaults to zeros when the column is absent).

        Args:
            pretrained_model_name_or_path: Local model directory, as ``str`` or path-like.
            **kwargs: Forwarded to `load_config`. When ``torch_dtype`` is given (and not None)
                the model is cast to it.
        """
        from pathlib import Path

        # Accept plain strings as well as path-like objects, as the annotation promises.
        model_dir = Path(pretrained_model_name_or_path)
        config = cls.load_config(model_dir / "config.json", **kwargs)
        video_vae = cls.from_config(config)
        torch_dtype = kwargs.get("torch_dtype")
        if torch_dtype is not None:
            video_vae.to(torch_dtype)

        # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
        ckpt_state_dict = torch.load(model_dir / "autoencoder.pth")
        video_vae.load_state_dict(ckpt_state_dict)

        statistics_local_path = model_dir / "per_channel_statistics.json"
        if statistics_local_path.exists():
            with open(statistics_local_path, "r") as file:
                data = json.load(file)
            # The file stores rows; transpose them into one tensor per named column.
            transposed_data = list(zip(*data["data"]))
            data_dict = {
                col: torch.tensor(vals)
                for col, vals in zip(data["columns"], transposed_data)
            }
            video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
            video_vae.register_buffer(
                "mean_of_means",
                data_dict.get(
                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
                ),
            )

        return video_vae

    @staticmethod
    def from_config(config):
        """Build a `VideoAutoencoder` (encoder + decoder) from a config mapping.

        Raises:
            ValueError: If `latent_log_var` is "uniform" while `use_quant_conv` is enabled.
        """
        # Work on a copy so the caller's config is not mutated by the dims normalization below.
        config = dict(config)
        assert (
            config["_class_name"] == "VideoAutoencoder"
        ), "config must have _class_name=VideoAutoencoder"
        if isinstance(config["dims"], list):
            config["dims"] = tuple(config["dims"])

        assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"

        double_z = config.get("double_z", True)
        latent_log_var = config.get(
            "latent_log_var", "per_channel" if double_z else "none"
        )
        use_quant_conv = config.get("use_quant_conv", True)

        if use_quant_conv and latent_log_var == "uniform":
            raise ValueError("uniform latent_log_var requires use_quant_conv=False")

        encoder = Encoder(
            dims=config["dims"],
            in_channels=config.get("in_channels", 3),
            out_channels=config["latent_channels"],
            block_out_channels=config["block_out_channels"],
            patch_size=config.get("patch_size", 1),
            latent_log_var=latent_log_var,
            norm_layer=config.get("norm_layer", "group_norm"),
            patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
            add_channel_padding=config.get("add_channel_padding", False),
        )

        decoder = Decoder(
            dims=config["dims"],
            in_channels=config["latent_channels"],
            out_channels=config.get("out_channels", 3),
            block_out_channels=config["block_out_channels"],
            patch_size=config.get("patch_size", 1),
            norm_layer=config.get("norm_layer", "group_norm"),
            patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
            add_channel_padding=config.get("add_channel_padding", False),
        )

        dims = config["dims"]
        return VideoAutoencoder(
            encoder=encoder,
            decoder=decoder,
            latent_channels=config["latent_channels"],
            dims=dims,
            use_quant_conv=use_quant_conv,
        )

    @property
    def config(self):
        """Reconstruct the constructor config from the built modules."""
        return SimpleNamespace(
            _class_name="VideoAutoencoder",
            dims=self.dims,
            in_channels=self.encoder.conv_in.in_channels
            // (self.encoder.patch_size_t * self.encoder.patch_size**2),
            out_channels=self.decoder.conv_out.out_channels
            // (self.decoder.patch_size_t * self.decoder.patch_size**2),
            latent_channels=self.decoder.conv_in.in_channels,
            block_out_channels=[
                self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
                for i in range(len(self.encoder.down_blocks))
            ],
            scaling_factor=1.0,
            norm_layer=self.encoder.norm_layer,
            patch_size=self.encoder.patch_size,
            latent_log_var=self.encoder.latent_log_var,
            use_quant_conv=self.use_quant_conv,
            patch_size_t=self.encoder.patch_size_t,
            add_channel_padding=self.encoder.add_channel_padding,
        )

    @property
    def is_video_supported(self):
        """
        Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
        """
        return self.dims != 2

    @property
    def downscale_factor(self):
        """Spatial downscale factor of the encoder."""
        # `Encoder` exposes `downscale_factor`; the previous `downsample_factor` lookup
        # raised AttributeError.
        return self.encoder.downscale_factor

    def to_json_string(self) -> str:
        """Serialize `config` to a JSON string."""
        return json.dumps(self.config.__dict__)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        """Load a checkpoint after renaming legacy diffusers-style keys.

        Keys containing "norm" that have no counterpart in this model are dropped (and logged
        at info level), presumably to tolerate checkpoints whose normalization layers differ
        from this model's.
        """
        # Use the full state_dict key set (parameters AND buffers) so that norm buffers such as
        # BatchNorm running statistics are loaded instead of being discarded as "unknown".
        model_keys = set(self.state_dict().keys())

        key_mapping = {
            ".resnets.": ".res_blocks.",
            "downsamplers.0": "downsample",
            "upsamplers.0": "upsample",
        }
        converted_state_dict = {}
        for key, value in state_dict.items():
            for k, v in key_mapping.items():
                key = key.replace(k, v)

            if "norm" in key and key not in model_keys:
                logger.info(
                    f"Removing key {key} from state_dict as it is not present in the model"
                )
                continue

            converted_state_dict[key] = value

        super().load_state_dict(converted_state_dict, strict=strict)

    def last_layer(self):
        """Return the decoder's final layer (last element of `conv_out` if it is a Sequential)."""
        if hasattr(self.decoder, "conv_out"):
            if isinstance(self.decoder.conv_out, nn.Sequential):
                last_layer = self.decoder.conv_out[-1]
            else:
                last_layer = self.decoder.conv_out
        else:
            last_layer = self.decoder.layers[-1]
        return last_layer
class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
            Convolution dimensionality passed to `make_conv_nd` and the sub-blocks.
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of latent channels.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        latent_log_var (`str`, *optional*, defaults to `per_channel`):
            The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
        patch_size_t (`int`, *optional*):
            Temporal patch size; defaults to `patch_size`.
        add_channel_padding (`bool`, *optional*, defaults to `False`):
            If True, the patchified input channel count is `in_channels * patch_size**3`
            instead of `in_channels * patch_size_t * patch_size**2`.

    Raises:
        ValueError: If `norm_layer` or `latent_log_var` is not a supported value.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]] = 3,
        in_channels: int = 3,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        patch_size: Union[int, Tuple[int]] = 1,
        norm_layer: str = "group_norm",  # group_norm, pixel_norm
        latent_log_var: str = "per_channel",
        patch_size_t: Optional[int] = None,
        add_channel_padding: Optional[bool] = False,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
        self.add_channel_padding = add_channel_padding
        self.layers_per_block = layers_per_block
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
        self.latent_log_var = latent_log_var
        # Channel count seen by conv_in after patchify.
        if add_channel_padding:
            in_channels = in_channels * self.patch_size**3
        else:
            in_channels = in_channels * self.patch_size_t * self.patch_size**2
        self.in_channels = in_channels
        output_channel = block_out_channels[0]

        self.conv_in = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.down_blocks = nn.ModuleList([])

        for i in range(len(block_out_channels)):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = DownEncoderBlock3D(
                dims=dims,
                in_channels=input_channel,
                out_channels=output_channel,
                num_layers=self.layers_per_block,
                add_downsample=not is_final_block and 2**i >= patch_size,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_groups=norm_num_groups,
                norm_layer=norm_layer,
            )
            self.down_blocks.append(down_block)

        self.mid_block = UNetMidBlock3D(
            dims=dims,
            in_channels=block_out_channels[-1],
            num_layers=self.layers_per_block,
            resnet_eps=1e-6,
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
        )

        # out
        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[-1],
                num_groups=norm_num_groups,
                eps=1e-6,
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()
        else:
            # Without this, `conv_norm_out` would be missing and forward() would fail later.
            raise ValueError(f"Invalid norm_layer: {norm_layer}")
        self.conv_act = nn.SiLU()

        conv_out_channels = out_channels
        if latent_log_var == "per_channel":
            conv_out_channels *= 2
        elif latent_log_var == "uniform":
            conv_out_channels += 1
        elif latent_log_var != "none":
            raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
        self.conv_out = make_conv_nd(
            dims, block_out_channels[-1], conv_out_channels, 3, padding=1
        )

        self.gradient_checkpointing = False

    @property
    def downscale_factor(self):
        """`patch_size` times 2 per down block whose `downsample` is a `Downsample3D`."""
        return (
            2
            ** len(
                [
                    block
                    for block in self.down_blocks
                    if isinstance(block.downsample, Downsample3D)
                ]
            )
            * self.patch_size
        )

    def forward(
        self, sample: torch.FloatTensor, return_features=False
    ) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class.

        Args:
            sample: Input tensor; dim 2 is the frame axis. A single frame disables temporal
                patchify and temporal downsampling.
            return_features: If True, also return the output of every down block plus the
                first `latent_channels` channels of the final output.

        Returns:
            The encoder output, or ``(output, features)`` when `return_features` is True.
        """
        downsample_in_time = sample.shape[2] != 1

        # patchify
        patch_size_t = self.patch_size_t if downsample_in_time else 1
        sample = patchify(
            sample,
            patch_size_hw=self.patch_size,
            patch_size_t=patch_size_t,
            add_channel_padding=self.add_channel_padding,
        )

        sample = self.conv_in(sample)

        def checkpoint_fn(module):
            # Return a callable that runs `module` under non-reentrant activation checkpointing
            # when enabled. Binding `module` as the first argument is essential: a bare
            # `partial(checkpoint, use_reentrant=False)(module)` would call `module()` with no
            # inputs.
            if self.gradient_checkpointing and self.training:
                return partial(
                    torch.utils.checkpoint.checkpoint, module, use_reentrant=False
                )
            return module

        if return_features:
            features = []
        for down_block in self.down_blocks:
            sample = checkpoint_fn(down_block)(
                sample, downsample_in_time=downsample_in_time
            )
            if return_features:
                features.append(sample)

        sample = checkpoint_fn(self.mid_block)(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if self.latent_log_var == "uniform":
            # Expand the single shared log-variance channel so the output has 2 * latent
            # channels (means followed by log-variances).
            last_channel = sample[:, -1:, ...]
            num_dims = sample.dim()

            if num_dims == 4:
                # For shape (B, C, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            elif num_dims == 5:
                # For shape (B, C, F, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            else:
                raise ValueError(f"Invalid input shape: {sample.shape}")

        if return_features:
            features.append(sample[:, : self.latent_channels, ...])
            return sample, features
        return sample
class Decoder(nn.Module):
r"""
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
patch_size (`int`, *optional*, defaults to 1):
The patch size to use. Should be a power of 2.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
"""
def __init__(
self,
dims,
in_channels: int = 3,
out_channels: int = 3,
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
patch_size: int = 1,
norm_layer: str = "group_norm",
patch_size_t: Optional[int] = None,
add_channel_padding: Optional[bool] = False,
):
super().__init__()
self.patch_size = patch_size
self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
self.add_channel_padding = add_channel_padding
self.layers_per_block = layers_per_block
if add_channel_padding:
out_channels = out_channels * self.patch_size**3
else:
out_channels = out_channels * self.patch_size_t * self.patch_size**2
self.out_channels = out_channels
self.conv_in = make_conv_nd(
dims,
in_channels,
block_out_channels[-1],
kernel_size=3,
stride=1,
padding=1,
)
self.mid_block = None
self.up_blocks = nn.ModuleList([])
self.mid_block = UNetMidBlock3D(
dims=dims,
in_channels=block_out_channels[-1],
num_layers=self.layers_per_block,
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
)
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i in range(len(reversed_block_out_channels)):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = UpDecoderBlock3D(
dims=dims,
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
add_upsample=not is_final_block
and 2 ** (len(block_out_channels) - i - 1) > patch_size,
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
)
self.up_blocks.append(up_block)
if norm_layer == "group_norm":
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
)
elif norm_layer == "pixel_norm":
self.conv_norm_out = PixelNorm()
self.conv_act = nn.SiLU()
self.conv_out = make_conv_nd(
dims, block_out_channels[0], out_channels, 3, padding=1
)
self.gradient_checkpointing = False
def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
    r"""The forward method of the `Decoder` class.

    Runs ``conv_in`` -> mid block -> up blocks -> norm/act/``conv_out`` and finally
    ``unpatchify`` on ``sample``.

    Args:
        sample: Input tensor. Its dimension 2 is compared with ``target_shape[2]``
            to decide whether the up blocks should also upsample along that axis
            (passed to them as ``upsample_in_time``).
        target_shape: Shape of the requested output. Must not be ``None``; only
            ``target_shape[2]`` is read here.

    Returns:
        The decoded tensor produced by ``unpatchify``.
    """
    assert target_shape is not None, "target_shape must be provided"
    # Upsample along dim 2 only when the requested output is longer there than the input.
    upsample_in_time = sample.shape[2] < target_shape[2]

    sample = self.conv_in(sample)

    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

    if self.gradient_checkpointing and self.training:
        from torch.utils.checkpoint import checkpoint as _checkpoint

        def checkpoint_fn(module):
            # Bind `module` as the function to checkpoint, so the returned callable
            # takes the module's own (positional and keyword) inputs. Passing the
            # module as the only argument to `checkpoint` would run it with no inputs.
            return partial(_checkpoint, module, use_reentrant=False)

    else:

        def checkpoint_fn(module):
            return module

    sample = checkpoint_fn(self.mid_block)(sample)
    # Match the dtype of the up blocks' parameters.
    sample = sample.to(upscale_dtype)

    for up_block in self.up_blocks:
        sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time)

    # post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    # un-patchify; the temporal patch size only applies when upsampling along dim 2
    patch_size_t = self.patch_size_t if upsample_in_time else 1
    sample = unpatchify(
        sample,
        patch_size_hw=self.patch_size,
        patch_size_t=patch_size_t,
        add_channel_padding=self.add_channel_padding,
    )

    return sample
class DownEncoderBlock3D(nn.Module):
    """Encoder stage: ``num_layers`` ``ResnetBlock3D`` layers followed by an optional ``Downsample3D``.

    Args:
        dims: Convolution dimensionality forwarded to every sub-layer.
        in_channels: Channels entering the first residual block.
        out_channels: Channels produced by every residual block and by the downsampler.
        dropout: Dropout probability inside each residual block.
        num_layers: Number of residual blocks.
        resnet_eps: Epsilon for the residual blocks' normalization layers.
        resnet_groups: Group count for group normalization in the residual blocks.
        add_downsample: When False, an ``Identity`` module replaces the downsampler.
        downsample_padding: Padding forwarded to ``Downsample3D``.
        norm_layer: Normalization type name forwarded to each residual block.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        add_downsample: bool = True,
        downsample_padding: int = 1,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        # Only the first block sees `in_channels`; the rest operate at `out_channels`.
        self.res_blocks = nn.ModuleList(
            [
                ResnetBlock3D(
                    dims=dims,
                    in_channels=out_channels if layer_idx > 0 else in_channels,
                    out_channels=out_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
                for layer_idx in range(num_layers)
            ]
        )
        self.downsample = (
            Downsample3D(
                dims,
                out_channels,
                out_channels=out_channels,
                padding=downsample_padding,
            )
            if add_downsample
            else Identity()
        )

    def forward(
        self, hidden_states: torch.FloatTensor, downsample_in_time
    ) -> torch.FloatTensor:
        out = hidden_states
        for block in self.res_blocks:
            out = block(out)
        return self.downsample(out, downsample_in_time=downsample_in_time)
class UNetMidBlock3D(nn.Module):
    """
    Bottleneck block: a stack of ``num_layers`` ``ResnetBlock3D`` layers that all map
    ``in_channels`` to ``in_channels``.

    Args:
        dims: Convolution dimensionality forwarded to each residual block.
        in_channels (`int`): Channel count of the input and of every residual block's output.
        dropout (`float`, *optional*, defaults to 0.0): Dropout probability inside each residual block.
        num_layers (`int`, *optional*, defaults to 1): Number of residual blocks.
        resnet_eps (`float`, *optional*, defaults to 1e-6): Epsilon for the residual blocks' normalization.
        resnet_groups (`int`, *optional*, defaults to 32): Group count for group normalization.
            If ``None``, ``min(in_channels // 4, 32)`` is used.
        norm_layer (`str`, *optional*, defaults to "group_norm"): Normalization type forwarded to each block.

    Returns:
        `torch.FloatTensor`: The output of the last residual block (the input itself when
        ``num_layers`` is 0), with ``in_channels`` channels.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        blocks = []
        for _ in range(num_layers):
            blocks.append(
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
            )
        self.res_blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        out = hidden_states
        for block in self.res_blocks:
            out = block(out)
        return out
class UpDecoderBlock3D(nn.Module):
    """Decoder stage: ``num_layers`` ``ResnetBlock3D`` layers followed by an optional ``Upsample3D``.

    Args:
        dims: Convolution dimensionality forwarded to every sub-layer.
        in_channels: Channels entering the first residual block.
        out_channels: Channels produced by every residual block and by the upsampler.
        resolution_idx: Stored as ``self.resolution_idx``; not otherwise used in this block.
        dropout: Dropout probability inside each residual block.
        num_layers: Number of residual blocks.
        resnet_eps: Epsilon for the residual blocks' normalization layers.
        resnet_groups: Group count for group normalization in the residual blocks.
        add_upsample: When False, an ``Identity`` module replaces the upsampler.
        norm_layer: Normalization type name forwarded to each residual block.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        add_upsample: bool = True,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        blocks = []
        for layer_idx in range(num_layers):
            blocks.append(
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels if layer_idx == 0 else out_channels,
                    out_channels=out_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
            )
        self.res_blocks = nn.ModuleList(blocks)

        if not add_upsample:
            self.upsample = Identity()
        else:
            self.upsample = Upsample3D(
                dims=dims, channels=out_channels, out_channels=out_channels
            )

        self.resolution_idx = resolution_idx

    def forward(
        self, hidden_states: torch.FloatTensor, upsample_in_time=True
    ) -> torch.FloatTensor:
        out = hidden_states
        for block in self.res_blocks:
            out = block(out)
        return self.upsample(out, upsample_in_time=upsample_in_time)
class ResnetBlock3D(nn.Module):
    r"""
    A pre-activation residual block computing
    ``shortcut(x) + conv2(dropout(act(norm2(conv1(act(norm1(x)))))))`` with ``act = SiLU``.

    Parameters:
        dims: Convolution dimensionality forwarded to ``make_conv_nd`` / ``make_linear_nd``.
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv layer. If None, same as `in_channels`.
        conv_shortcut (`bool`, *optional*, defaults to `False`):
            Stored as ``use_conv_shortcut`` but not otherwise used: the shortcut is a
            ``make_linear_nd`` projection whenever the channel counts differ, and identity otherwise.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        groups (`int`, *optional*, default to `32`): The number of groups to use for group normalization.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        norm_layer (`str`, *optional*, defaults to `"group_norm"`): ``"group_norm"`` or ``"pixel_norm"``.

    Raises:
        ValueError: if ``norm_layer`` is not one of the supported values.
    """

    _SUPPORTED_NORMS = ("group_norm", "pixel_norm")

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        groups: int = 32,
        eps: float = 1e-6,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        # Fail fast: an unknown value would otherwise leave norm1/norm2 unset and only
        # surface as an AttributeError on the first forward call.
        if norm_layer not in self._SUPPORTED_NORMS:
            raise ValueError(
                f"Unsupported norm_layer {norm_layer!r}; expected one of {self._SUPPORTED_NORMS}"
            )
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # Sub-modules are created in the same order as before, so parameter
        # registration order and RNG consumption at init are unchanged.
        self.norm1 = self._make_norm(norm_layer, groups, in_channels, eps)
        self.non_linearity = nn.SiLU()
        self.conv1 = make_conv_nd(
            dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = self._make_norm(norm_layer, groups, out_channels, eps)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = make_conv_nd(
            dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.conv_shortcut = (
            make_linear_nd(
                dims=dims, in_channels=in_channels, out_channels=out_channels
            )
            if in_channels != out_channels
            else nn.Identity()
        )

    @staticmethod
    def _make_norm(norm_layer: str, groups: int, num_channels: int, eps: float):
        """Build the (already validated) normalization layer for ``num_channels`` channels."""
        if norm_layer == "group_norm":
            return torch.nn.GroupNorm(
                num_groups=groups, num_channels=num_channels, eps=eps, affine=True
            )
        return PixelNorm()

    def forward(
        self,
        input_tensor: torch.FloatTensor,
    ) -> torch.FloatTensor:
        hidden_states = input_tensor
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.non_linearity(hidden_states)
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.non_linearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)
        input_tensor = self.conv_shortcut(input_tensor)
        output_tensor = input_tensor + hidden_states
        return output_tensor
class Downsample3D(nn.Module):
    """Stride-2 convolutional downsampler.

    When built with ``padding == 0``, ``forward`` first zero-pads one element at the
    end of each spatial axis (and of the depth axis when ``downsample_in_time``),
    so the stride-2 conv sees an even-sized input.
    """

    def __init__(
        self,
        dims,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        padding: int = 1,
    ):
        super().__init__()
        self.padding = padding
        self.in_channels = in_channels
        self.dims = dims
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=2,
            padding=padding,
        )

    def forward(self, x, downsample_in_time=True):
        # With built-in conv padding, no manual padding or time-skip handling is done.
        if self.padding != 0:
            return self.conv(x)

        if self.dims == 2:
            pad = (0, 1, 0, 1)
        else:
            time_pad = 1 if downsample_in_time else 0
            pad = (0, 1, 0, 1, 0, time_pad)
        x = functional.pad(x, pad, mode="constant", value=0)

        skip_time = self.dims == (2, 1) and not downsample_in_time
        if skip_time:
            return self.conv(x, skip_time_conv=True)
        return self.conv(x)
class Upsample3D(nn.Module):
    """
    Nearest-neighbour upsampling followed by a convolution.

    Height and width are always doubled. When ``dims != 2`` the input is
    (B, C, D, H, W), and D is doubled only when ``upsample_in_time`` is true.

    :param dims: convolution dimensionality passed to ``make_conv_nd`` (2, 3 or (2, 1)).
    :param channels: channels in the inputs.
    :param out_channels: channels in the outputs; defaults to ``channels``.
    """

    def __init__(self, dims, channels, out_channels=None):
        super().__init__()
        self.dims = dims
        self.channels = channels
        self.out_channels = out_channels or channels
        # Use the resolved width: passing the raw argument would hand `None` to the
        # conv when `out_channels` is omitted.
        self.conv = make_conv_nd(
            dims, channels, self.out_channels, kernel_size=3, padding=1, bias=True
        )

    def forward(self, x, upsample_in_time):
        if self.dims == 2:
            x = functional.interpolate(
                x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
            )
            return self.conv(x)

        time_scale_factor = 2 if upsample_in_time else 1
        b, _, _, _, _ = x.shape
        # Spatial upsampling, applied frame by frame.
        x = rearrange(x, "b c d h w -> (b d) c h w")
        x = functional.interpolate(
            x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
        )
        _, _, h, w = x.shape

        if not upsample_in_time and self.dims == (2, 1):
            x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w)
            return self.conv(x, skip_time_conv=True)

        # Temporal upsampling, treated as a 1-D interpolation along 'd' per pixel.
        x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b)
        new_d = x.shape[-1] * time_scale_factor
        x = functional.interpolate(x, (1, new_d), mode="nearest")
        x = rearrange(
            x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d
        )
        return self.conv(x)
def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
    """Fold spatial (and, for 5-D input, temporal) patches into the channel axis.

    4-D ``(b, c, H, W)`` becomes ``(b, c*hw*hw, H/hw, W/hw)``; 5-D
    ``(b, c, F, H, W)`` becomes ``(b, c*t*hw*hw, F/t, H/hw, W/hw)``.

    For 5-D input with ``patch_size_hw > patch_size_t`` and either
    ``patch_size_t > 1`` or ``add_channel_padding``, zero channels are *prepended*
    so the channel count is multiplied by ``patch_size_hw // patch_size_t``.

    Raises:
        ValueError: if ``x`` is neither 4-D nor 5-D (and patching is not a no-op).
    """
    if patch_size_hw == 1 and patch_size_t == 1:
        return x

    ndim = x.dim()
    if ndim == 4:
        x = rearrange(
            x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
        )
    elif ndim == 5:
        x = rearrange(
            x,
            "b c (f p) (h q) (w r) -> b (c p r q) f h w",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    needs_padding = (
        ndim == 5
        and patch_size_hw > patch_size_t
        and (patch_size_t > 1 or add_channel_padding)
    )
    if not needs_padding:
        return x

    channels = x.shape[1]
    channels_to_pad = channels * (patch_size_hw // patch_size_t) - channels
    padding_zeros = torch.zeros(
        (x.shape[0], channels_to_pad, *x.shape[2:]),
        device=x.device,
        dtype=x.dtype,
    )
    return torch.cat([padding_zeros, x], dim=1)
def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
    """Unfold channel-packed patches back into spatial (and temporal) axes.

    Reverses the ``rearrange`` of ``patchify`` for 4-D and 5-D tensors. Under the
    same condition that makes ``patchify`` pad channels (5-D input,
    ``patch_size_hw > patch_size_t`` and ``patch_size_t > 1 or add_channel_padding``),
    only the first ``channels * patch_size_t // patch_size_hw`` channels are kept
    before unfolding.

    NOTE(review): ``patchify`` prepends its zero padding, whereas this keeps the
    *leading* channels, so ``unpatchify(patchify(x))`` is not an identity when
    padding is active. The decoder's learned output layout presumably relies on
    this; do not change it without retraining.

    Raises:
        ValueError: if ``x`` is neither 4-D nor 5-D (and unpatching is not a no-op),
            matching ``patchify``.
    """
    if patch_size_hw == 1 and patch_size_t == 1:
        return x

    ndim = x.dim()
    if ndim not in (4, 5):
        raise ValueError(f"Invalid input shape: {x.shape}")

    if (
        ndim == 5
        and patch_size_hw > patch_size_t
        and (patch_size_t > 1 or add_channel_padding)
    ):
        # Exact integer arithmetic instead of int(c * (t / hw)) with floats.
        channels_to_keep = x.shape[1] * patch_size_t // patch_size_hw
        x = x[:, :channels_to_keep, :, :, :]

    if ndim == 4:
        x = rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
        )
    else:
        x = rearrange(
            x,
            "b (c p r q) f h w -> b c (f p) (h q) (w r)",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )

    return x
def create_video_autoencoder_config(
    latent_channels: int = 4,
):
    """Return the base ``VideoAutoencoder`` config: (2, 1) convolutions, no patchification.

    Args:
        latent_channels: Number of channels in the latent space representation.
    """
    return dict(
        _class_name="VideoAutoencoder",
        # 2 -> Conv2d, 3 -> Conv3d, (2, 1) -> Conv2d followed by Conv1d.
        dims=(2, 1),
        in_channels=3,  # input color channels (e.g. RGB)
        out_channels=3,  # output color channels
        latent_channels=latent_channels,
        # Output channels of each encoder / decoder inner block.
        block_out_channels=[128, 256, 512, 512],
        patch_size=1,
    )
def create_video_autoencoder_pathify4x4x4_config(
    latent_channels: int = 4,
):
    """Return a ``VideoAutoencoder`` config with (2, 1) convolutions and patch size 4.

    Every inner block is 512 channels wide and ``latent_log_var`` is ``"uniform"``.

    Args:
        latent_channels: Number of channels in the latent space representation.
    """
    return dict(
        _class_name="VideoAutoencoder",
        # 2 -> Conv2d, 3 -> Conv3d, (2, 1) -> Conv2d followed by Conv1d.
        dims=(2, 1),
        in_channels=3,  # input color channels (e.g. RGB)
        out_channels=3,  # output color channels
        latent_channels=latent_channels,
        # Output channels of each encoder / decoder inner block.
        block_out_channels=[512] * 4,
        patch_size=4,
        latent_log_var="uniform",
    )
def create_video_autoencoder_pathify4x4_config(
    latent_channels: int = 4,
):
    """Return a 2-D-convolution ``VideoAutoencoder`` config with patch size 4 and pixel norm.

    Args:
        latent_channels: Number of channels in the latent space representation.
    """
    return dict(
        _class_name="VideoAutoencoder",
        dims=2,  # 2 -> Conv2d, 3 -> Conv3d, (2, 1) -> Conv2d followed by Conv1d
        in_channels=3,  # input color channels (e.g. RGB)
        out_channels=3,  # output color channels
        latent_channels=latent_channels,
        # Output channels of each encoder / decoder inner block.
        block_out_channels=[512] * 4,
        patch_size=4,
        norm_layer="pixel_norm",
    )
def test_vae_patchify_unpatchify():
    """Round-trip check: ``unpatchify`` must invert ``patchify`` for 4x4x4 patches."""
    patch_kwargs = dict(patch_size_hw=4, patch_size_t=4)
    video = torch.randn(2, 3, 8, 64, 64)
    restored = unpatchify(patchify(video, **patch_kwargs), **patch_kwargs)
    assert torch.allclose(video, restored)
def demo_video_autoencoder_forward_backward():
    """Smoke-test ``VideoAutoencoder``: build it, encode/decode random videos, backprop an MSE loss.

    Prints the model, its parameter count, the input/latent/reconstruction shapes,
    and the final loss value.
    """
    # Configuration for the VideoAutoencoder
    config = create_video_autoencoder_pathify4x4x4_config()

    # Instantiate the VideoAutoencoder with the specified configuration
    video_autoencoder = VideoAutoencoder.from_config(config)

    print(video_autoencoder)

    # Print the total number of parameters in the video autoencoder
    total_params = sum(p.numel() for p in video_autoencoder.parameters())
    print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")

    # Mock batch shaped (batch_size, channels, frames, height, width):
    # 2 RGB videos, each with 8 frames of 64x64 pixels.
    input_videos = torch.randn(2, 3, 8, 64, 64)

    # Forward pass: encode and decode the input videos
    latent = video_autoencoder.encode(input_videos).latent_dist.mode()
    print(f"input shape={input_videos.shape}")
    print(f"latent shape={latent.shape}")
    reconstructed_videos = video_autoencoder.decode(
        latent, target_shape=input_videos.shape
    ).sample

    print(f"reconstructed shape={reconstructed_videos.shape}")

    # Reconstruction loss (mean squared error)
    loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)

    # Perform backward pass
    loss.backward()

    print(f"Demo completed with loss: {loss.item()}")
# Running this module as a script executes the VideoAutoencoder forward/backward demo.
if __name__ == "__main__":
    demo_video_autoencoder_forward_backward()
```
## /ltx_video/models/transformers/__init__.py
```py path="/ltx_video/models/transformers/__init__.py"
```
## /ltx_video/models/transformers/embeddings.py
```py path="/ltx_video/models/transformers/embeddings.py"
# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py
import math
import numpy as np
import torch
from einops import rearrange
from torch import nn
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    Create sinusoidal timestep embeddings, matching the implementation in Denoising
    Diffusion Probabilistic Models.

    Args:
        timesteps: 1-D tensor of N values, one per batch element; may be fractional.
        embedding_dim: Width of each output embedding.
        flip_sin_to_cos: If True the cosine half comes first, otherwise the sine half does.
        downscale_freq_shift: Subtracted from ``embedding_dim // 2`` in the denominator
            of the frequency exponent.
        scale: Multiplier applied to the phase arguments before sin/cos.
        max_period: Controls the minimum frequency of the embeddings.

    Returns:
        An ``[N, embedding_dim]`` float tensor. For odd ``embedding_dim`` the last
        column is zero padding.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    freq_exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    freq_exponent = freq_exponent / (half_dim - downscale_freq_shift)
    frequencies = torch.exp(freq_exponent)

    phases = scale * (timesteps[:, None].float() * frequencies[None, :])
    sin_part = torch.sin(phases)
    cos_part = torch.cos(phases)
    halves = [cos_part, sin_part] if flip_sin_to_cos else [sin_part, cos_part]
    emb = torch.cat(halves, dim=-1)

    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f):
    """
    Build 3-D sin/cos positional embeddings from a coordinate grid.

    Args (as implied by the reshapes below):
        embed_dim: Total embedding width, forwarded to
            ``get_3d_sincos_pos_embed_from_grid`` (must be divisible by 3).
        grid: Array laid out as ``(c, f*h*w)``; the ``reshape([3, 1, w, h, f])``
            below only succeeds when ``c == 3``.
        w, h, f: Grid width, height and number of frames.

    Returns:
        Array of shape ``(f*h*w, embed_dim)``, rows in the ``(f h w)`` order of the
        final ``rearrange``.

    NOTE(review): ``grid.reshape([3, 1, w, h, f])`` reinterprets a ``(c, h, w, f)``
    array as ``(3, 1, w, h, f)`` instead of transposing it, which keeps coordinates
    aligned only when ``h == w``. Confirm intent before relying on this for
    non-square grids.
    """
    grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w)
    grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w)
    grid = grid.reshape([3, 1, w, h, f])
    # -> (w, h, f, embed_dim): the 1-D helper drops the leading singleton axis.
    pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
    # Swap the first two axes, (w, h, ...) -> (h, w, ...), to match the pattern below.
    pos_embed = pos_embed.transpose(1, 0, 2, 3)
    return rearrange(pos_embed, "h w f c -> (f h w) c")
def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate per-axis 1-D sin/cos embeddings for a 3-axis coordinate grid.

    ``grid[0]``, ``grid[1]`` and ``grid[2]`` are embedded as the f, h and w
    coordinates respectively. Each axis receives ``embed_dim // 3`` channels, and the
    results are concatenated in h, w, f order along the last axis.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by 3.
    """
    if embed_dim % 3 != 0:
        raise ValueError("embed_dim must be divisible by 3")

    # A third of the channels per axis (not a half).
    axis_dim = embed_dim // 3
    emb_f, emb_h, emb_w = (
        get_1d_sincos_pos_embed_from_grid(axis_dim, grid[axis]) for axis in range(3)
    )
    return np.concatenate([emb_h, emb_w, emb_f], axis=-1)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Embed positions with interleaved-by-half sin/cos features.

    Args:
        embed_dim: Output width D; must be even (the first D/2 columns are sines,
            the last D/2 cosines).
        pos: Positions to encode. A 1-D array of size M yields an ``(M, D)`` output.
            For arrays of rank >= 2 the leading axis is dropped, so a
            ``(1, *rest)`` grid yields ``(*rest, D)`` (the layout used by the 3-D helper).

    Returns:
        ``np.ndarray`` of float64 embeddings.

    Raises:
        ValueError: if ``embed_dim`` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = np.asarray(pos)
    pos_shape = pos.shape
    out = np.einsum("m,d->md", pos.reshape(-1), omega)  # (M, D/2), outer product
    out = out.reshape([*pos_shape, -1])
    # Dropping the leading axis is only meaningful for stacked grids; a plain 1-D
    # position vector must keep all M rows (indexing [0] would keep only the first).
    if len(pos_shape) > 1:
        out = out[0]

    emb_sin = np.sin(out)
    emb_cos = np.cos(out)
    emb = np.concatenate([emb_sin, emb_cos], axis=-1)
    return emb
class SinusoidalPositionalEmbedding(nn.Module):
    """Apply positional information to a sequence of embeddings.

    Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds
    a fixed sinusoidal table to them: even channels hold ``sin(p * w_k)``, odd channels
    ``cos(p * w_k)`` with ``w_k = 10000 ** (-2k / embed_dim)``.

    Args:
        embed_dim: (int): Dimension of the positional embedding. Odd values are supported.
        max_seq_length: Maximum sequence length to apply positional embeddings
    """

    def __init__(self, embed_dim: int, max_seq_length: int = 32):
        super().__init__()
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
        )
        pe = torch.zeros(1, max_seq_length, embed_dim)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # There are embed_dim // 2 odd slots, one fewer than even slots when embed_dim
        # is odd; trimming div_term avoids a shape mismatch in that case.
        pe[0, :, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        self.register_buffer("pe", pe)

    def forward(self, x):
        _, seq_length, _ = x.shape
        return x + self.pe[:, :seq_length]
```
## /ltx_video/models/transformers/symmetric_patchifier.py
```py path="/ltx_video/models/transformers/symmetric_patchifier.py"
from abc import ABC, abstractmethod
from typing import Tuple
import torch
from diffusers.configuration_utils import ConfigMixin
from einops import rearrange
from torch import Tensor
class Patchifier(ConfigMixin, ABC):
    """Base class converting latent video tensors to/from token sequences.

    The patch size is stored as ``(1, patch_size, patch_size)`` for the
    (frame, height, width) axes, i.e. no temporal patching.
    """

    def __init__(self, patch_size: int):
        super().__init__()
        self._patch_size = (1, patch_size, patch_size)

    @abstractmethod
    def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
        raise NotImplementedError("Patchify method not implemented")

    @abstractmethod
    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        out_channels: int,
    ) -> Tensor:
        """Inverse of ``patchify``; returns the unpatchified latent tensor."""
        pass

    @property
    def patch_size(self):
        return self._patch_size

    def get_latent_coords(
        self, latent_num_frames, latent_height, latent_width, batch_size, device
    ):
        """
        Return a tensor of shape [batch_size, 3, num_patches] containing the
        top-left corner latent coordinates of each latent patch.
        The tensor is repeated for each batch element.
        """
        # "ij" is torch's historical default; passing it explicitly keeps the same
        # (frame, height, width) ordering without the meshgrid deprecation warning.
        latent_sample_coords = torch.meshgrid(
            torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
            torch.arange(0, latent_height, self._patch_size[1], device=device),
            torch.arange(0, latent_width, self._patch_size[2], device=device),
            indexing="ij",
        )
        latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
        latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
        latent_coords = rearrange(
            latent_coords, "b c f h w -> b c (f h w)", b=batch_size
        )
        return latent_coords
class SymmetricPatchifier(Patchifier):
    """Patchifier that flattens ``self._patch_size`` patches of a (b, c, f, h, w) latent into tokens."""

    def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
        """Flatten ``latents`` into ``(b, num_patches, c * p1 * p2 * p3)`` tokens.

        Returns the tokens together with the per-patch latent coordinates from
        ``get_latent_coords``.
        """
        b, _, f, h, w = latents.shape
        latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
        latents = rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )
        return latents, latent_coords

    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        out_channels: int,
    ) -> Tensor:
        """Rebuild a ``(b, c, f, h, w)`` latent from ``(b, f*h*w, c*p*q)`` tokens.

        ``output_height`` / ``output_width`` are the unpatched latent height/width and
        must be divisible by the spatial patch size. ``out_channels`` is accepted for
        interface compatibility but not used.
        """
        output_height = output_height // self._patch_size[1]
        output_width = output_width // self._patch_size[2]
        latents = rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            h=output_height,
            w=output_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )
        return latents
```
## /ltx_video/models/transformers/transformer3d.py
```py path="/ltx_video/models/transformers/transformer3d.py"
# Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import os
import json
import glob
from pathlib import Path
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import PixArtAlphaTextProjection
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormSingle
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils import logging
from torch import nn
from safetensors import safe_open
from ltx_video.models.transformers.attention import BasicTransformerBlock
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
from ltx_video.utils.diffusers_config_mapping import (
diffusers_and_ours_config_mapping,
make_hashable_key,
TRANSFORMER_KEYS_RENAME_DICT,
)
# Module-level logger from diffusers' logging utilities (used by
# Transformer3DModel.set_use_tpu_flash_attention).
logger = logging.get_logger(__name__)
@dataclass
class Transformer3DModelOutput(BaseOutput):
    """
    The output of [`Transformer3DModel`].

    Args:
        sample (`torch.FloatTensor`):
            The output hidden states of [`Transformer3DModel`], conditioned on the
            `encoder_hidden_states` input. Their shape is determined by the model's
            forward pass.
    """

    sample: torch.FloatTensor
class Transformer3DModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
    self,
    num_attention_heads: int = 16,
    attention_head_dim: int = 88,
    in_channels: Optional[int] = None,
    out_channels: Optional[int] = None,
    num_layers: int = 1,
    dropout: float = 0.0,
    norm_num_groups: int = 32,
    cross_attention_dim: Optional[int] = None,
    attention_bias: bool = False,
    num_vector_embeds: Optional[int] = None,
    activation_fn: str = "geglu",
    num_embeds_ada_norm: Optional[int] = None,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    double_self_attention: bool = False,
    upcast_attention: bool = False,
    adaptive_norm: str = "single_scale_shift",  # 'single_scale_shift' or 'single_scale'
    standardization_norm: str = "layer_norm",  # 'layer_norm' or 'rms_norm'
    norm_elementwise_affine: bool = True,
    norm_eps: float = 1e-5,
    attention_type: str = "default",
    caption_channels: Optional[int] = None,
    use_tpu_flash_attention: bool = False,  # if True uses the TPU attention offload ('flash attention')
    qk_norm: Optional[str] = None,
    positional_embedding_type: str = "rope",
    positional_embedding_theta: Optional[float] = None,
    positional_embedding_max_pos: Optional[List[int]] = None,
    timestep_scale_multiplier: Optional[float] = None,
    causal_temporal_positioning: bool = False,  # For backward compatibility, will be deprecated
):
    """
    Build the 3-D transformer: input projection, ``num_layers`` ``BasicTransformerBlock``s
    and the output head.

    The hidden width is ``inner_dim = num_attention_heads * attention_head_dim``.
    ``in_channels`` is projected to ``inner_dim`` by ``patchify_proj``; ``proj_out``
    maps back to ``out_channels`` (which defaults to ``in_channels``).

    Positional embedding: ``"absolute"`` raises ValueError (no longer supported).
    ``"rope"`` requires both ``positional_embedding_theta`` and
    ``positional_embedding_max_pos``. Other values are not rejected here.

    ``adaptive_norm == "single_scale"`` replaces ``adaln_single.linear`` with a
    projection to ``4 * inner_dim`` features. A ``caption_projection`` is created only
    when ``caption_channels`` is given.

    Several arguments (e.g. ``norm_num_groups``, ``num_vector_embeds``,
    ``causal_temporal_positioning``) are not referenced in this constructor body.
    They are still recorded by ``register_to_config``.

    Raises:
        ValueError: for ``"absolute"`` positional embeddings, or ``"rope"`` without
            theta / max_pos.
    """
    super().__init__()
    self.use_tpu_flash_attention = (
        use_tpu_flash_attention  # FIXME: push config down to the attention modules
    )
    self.use_linear_projection = use_linear_projection
    self.num_attention_heads = num_attention_heads
    self.attention_head_dim = attention_head_dim
    inner_dim = num_attention_heads * attention_head_dim
    self.inner_dim = inner_dim
    # Token projection: in_channels -> inner_dim.
    self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True)
    self.positional_embedding_type = positional_embedding_type
    self.positional_embedding_theta = positional_embedding_theta
    self.positional_embedding_max_pos = positional_embedding_max_pos
    self.use_rope = self.positional_embedding_type == "rope"
    self.timestep_scale_multiplier = timestep_scale_multiplier
    if self.positional_embedding_type == "absolute":
        raise ValueError("Absolute positional embedding is no longer supported")
    elif self.positional_embedding_type == "rope":
        if positional_embedding_theta is None:
            raise ValueError(
                "If `positional_embedding_type` type is rope, `positional_embedding_theta` must also be defined"
            )
        if positional_embedding_max_pos is None:
            raise ValueError(
                "If `positional_embedding_type` type is rope, `positional_embedding_max_pos` must also be defined"
            )
    # 3. Define transformers blocks
    self.transformer_blocks = nn.ModuleList(
        [
            BasicTransformerBlock(
                inner_dim,
                num_attention_heads,
                attention_head_dim,
                dropout=dropout,
                cross_attention_dim=cross_attention_dim,
                activation_fn=activation_fn,
                num_embeds_ada_norm=num_embeds_ada_norm,
                attention_bias=attention_bias,
                only_cross_attention=only_cross_attention,
                double_self_attention=double_self_attention,
                upcast_attention=upcast_attention,
                adaptive_norm=adaptive_norm,
                standardization_norm=standardization_norm,
                norm_elementwise_affine=norm_elementwise_affine,
                norm_eps=norm_eps,
                attention_type=attention_type,
                use_tpu_flash_attention=use_tpu_flash_attention,
                qk_norm=qk_norm,
                use_rope=self.use_rope,
            )
            for d in range(num_layers)
        ]
    )
    # 4. Define output layers
    self.out_channels = in_channels if out_channels is None else out_channels
    self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
    # Learned (2, inner_dim) table initialised with std 1/sqrt(inner_dim).
    self.scale_shift_table = nn.Parameter(
        torch.randn(2, inner_dim) / inner_dim**0.5
    )
    self.proj_out = nn.Linear(inner_dim, self.out_channels)
    self.adaln_single = AdaLayerNormSingle(
        inner_dim, use_additional_conditions=False
    )
    if adaptive_norm == "single_scale":
        # Swap in a projection producing 4 * inner_dim conditioning features.
        self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True)
    self.caption_projection = None
    if caption_channels is not None:
        self.caption_projection = PixArtAlphaTextProjection(
            in_features=caption_channels, hidden_size=inner_dim
        )
    self.gradient_checkpointing = False
def set_use_tpu_flash_attention(self):
    r"""
    Turn on the TPU flash-attention kernel for this model and propagate the setting
    to every block in ``self.transformer_blocks``.
    """
    logger.info("ENABLE TPU FLASH ATTENTION -> TRUE")
    self.use_tpu_flash_attention = True
    # The blocks hold their own flag, so forward the switch to each of them.
    for transformer_block in self.transformer_blocks:
        transformer_block.set_use_tpu_flash_attention()
def create_skip_layer_mask(
    self,
    batch_size: int,
    num_conds: int,
    ptb_index: int,
    skip_block_list: Optional[List[int]] = None,
):
    """
    Build a ``(num_layers, batch_size * num_conds)`` mask of ones in the model's
    device/dtype. For every layer index in ``skip_block_list`` it zeroes columns
    ``ptb_index, ptb_index + num_conds, ptb_index + 2 * num_conds, ...``.

    Returns ``None`` when ``skip_block_list`` is ``None`` or empty.
    """
    if skip_block_list is None or not len(skip_block_list):
        return None

    mask = torch.ones(
        (len(self.transformer_blocks), batch_size * num_conds),
        device=self.device,
        dtype=self.dtype,
    )
    for layer_idx in skip_block_list:
        mask[layer_idx, ptb_index::num_conds] = 0
    return mask
def _set_gradient_checkpointing(self, module, value=False):
    """Set ``module.gradient_checkpointing`` to ``value`` if the module has that attribute."""
    if not hasattr(module, "gradient_checkpointing"):
        return
    module.gradient_checkpointing = value
def get_fractional_positions(self, indices_grid):
    """
    Normalize the first three coordinate rows of ``indices_grid`` by the matching
    ``positional_embedding_max_pos`` entries and stack them along a new last axis.
    """
    max_pos = self.positional_embedding_max_pos
    per_axis = [indices_grid[:, axis] / max_pos[axis] for axis in (0, 1, 2)]
    return torch.stack(per_axis, dim=-1)
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
    """
    Precompute rotary positional-embedding (RoPE) cosine and sine tables.

    Args:
        indices_grid: Per-token coordinates (documented on ``forward`` as
            ``(batch size, 3, num latent pixels)``). Each of the three coordinate
            rows is divided by the matching ``positional_embedding_max_pos`` entry.
        spacing: How the base frequencies between 1 and
            ``positional_embedding_theta`` are spaced: ``"exp"`` (default),
            ``"exp_2"``, ``"linear"`` or ``"sqrt"``.

    Returns:
        ``(cos_freq, sin_freq)`` cast to ``self.dtype``. Each frequency is repeated
        twice (``repeat_interleave``). When ``inner_dim % 6 != 0``, cos=1 / sin=0
        entries are prepended, so for the non-``exp_2`` spacings the last dimension
        equals ``inner_dim``.

    Raises:
        ValueError: if ``spacing`` is not one of the supported values.
    """
    dtype = torch.float32  # We need full precision in the freqs_cis computation.
    dim = self.inner_dim
    theta = self.positional_embedding_theta

    fractional_positions = self.get_fractional_positions(indices_grid)

    start = 1
    end = theta
    device = fractional_positions.device
    if spacing == "exp":
        indices = theta ** (
            torch.linspace(
                math.log(start, theta),
                math.log(end, theta),
                dim // 6,
                device=device,
                dtype=dtype,
            )
        )
        indices = indices.to(dtype=dtype)
    elif spacing == "exp_2":
        # NOTE(review): arange(0, dim, 6) yields ceil(dim / 6) frequencies rather than
        # dim // 6, so when dim % 6 != 0 the padded table exceeds `dim` — confirm.
        indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim)
        indices = indices.to(dtype=dtype)
    elif spacing == "linear":
        indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype)
    elif spacing == "sqrt":
        indices = torch.linspace(
            start**2, end**2, dim // 6, device=device, dtype=dtype
        ).sqrt()
    else:
        # Previously an unknown value surfaced as an UnboundLocalError on `indices`.
        raise ValueError(
            f"Unsupported spacing {spacing!r}; expected 'exp', 'exp_2', 'linear' or 'sqrt'"
        )

    indices = indices * math.pi / 2

    if spacing == "exp_2":
        freqs = (
            (indices * fractional_positions.unsqueeze(-1))
            .transpose(-1, -2)
            .flatten(2)
        )
    else:
        # Positions are rescaled with x * 2 - 1 (maps [0, 1] onto [-1, 1]).
        freqs = (
            (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
            .transpose(-1, -2)
            .flatten(2)
        )

    cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
    sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
    if dim % 6 != 0:
        # Identity rotation (cos=1, sin=0) for the leftover channels.
        cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
        sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
        cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
        sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
    return cos_freq.to(self.dtype), sin_freq.to(self.dtype)
def load_state_dict(
    self,
    state_dict: Dict,
    *args,
    **kwargs,
):
    """
    Load weights, accepting checkpoints whose keys carry a ``model.diffusion_model.`` prefix.

    If any key has that prefix, only the prefixed keys are kept (with the prefix
    removed); all other entries are dropped. The remaining arguments are forwarded
    to ``nn.Module.load_state_dict``, whose result (missing / unexpected keys) is
    returned.
    """
    prefix = "model.diffusion_model."
    if any(key.startswith(prefix) for key in state_dict):
        # Slice off the leading prefix only (str.replace would strip every occurrence).
        state_dict = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
    return super().load_state_dict(state_dict, *args, **kwargs)
@classmethod
def from_pretrained(
    cls,
    pretrained_model_path: Optional[Union[str, os.PathLike]],
    *args,
    **kwargs,
):
    """
    Load the model from a diffusers-style directory or a single ``.safetensors`` file.

    * Directory: reads ``<path>/transformer/config.json`` and maps it through
      ``diffusers_and_ours_config_mapping`` (only configs listed there are accepted).
      It loads every ``<path>/transformer/diffusion_pytorch_model*.safetensors`` shard,
      renames keys with ``TRANSFORMER_KEYS_RENAME_DICT``, and loads them with
      ``strict=True``.
    * Single file: takes the config from the ``"transformer"`` section of the file's
      ``"config"`` metadata entry and loads all tensors in the file.

    In both cases the model is instantiated on the ``meta`` device, and the loaded
    tensors are assigned in place (``assign=True``).

    ``*args`` and ``**kwargs`` are accepted for interface compatibility and ignored.

    Raises:
        AssertionError: if a directory's config is not in the supported mapping.
        ValueError: if the path is neither a directory nor a ``.safetensors`` file.
    """
    pretrained_model_path = Path(pretrained_model_path)
    if pretrained_model_path.is_dir():
        config_path = pretrained_model_path / "transformer" / "config.json"
        with open(config_path, "r", encoding="utf-8") as f:
            config = make_hashable_key(json.load(f))

        assert config in diffusers_and_ours_config_mapping, (
            "Provided diffusers checkpoint config for transformer is not suppported. "
            "We only support diffusers configs found in Lightricks/LTX-Video."
        )

        config = diffusers_and_ours_config_mapping[config]
        state_dict = {}
        ckpt_paths = (
            pretrained_model_path
            / "transformer"
            / "diffusion_pytorch_model*.safetensors"
        )
        for dict_path in glob.glob(str(ckpt_paths)):
            with safe_open(dict_path, framework="pt", device="cpu") as f:
                state_dict.update({k: f.get_tensor(k) for k in f.keys()})

        # Translate diffusers key names to this implementation's names.
        for key in list(state_dict.keys()):
            new_key = key
            for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
                new_key = new_key.replace(replace_key, rename_key)
            state_dict[new_key] = state_dict.pop(key)

        with torch.device("meta"):
            transformer = cls.from_config(config)
        transformer.load_state_dict(state_dict, assign=True, strict=True)
    elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith(
        ".safetensors"
    ):
        comfy_single_file_state_dict = {}
        with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
            metadata = f.metadata()
            for k in f.keys():
                comfy_single_file_state_dict[k] = f.get_tensor(k)
        configs = json.loads(metadata["config"])
        transformer_config = configs["transformer"]
        # Use `cls` (not the hard-coded class) so subclasses load as themselves.
        with torch.device("meta"):
            transformer = cls.from_config(transformer_config)
        transformer.load_state_dict(comfy_single_file_state_dict, assign=True)
    else:
        # Previously this fell through to an UnboundLocalError on `transformer`.
        raise ValueError(
            "Expected a diffusers model directory or a .safetensors file, got: "
            f"{pretrained_model_path}"
        )
    return transformer
def forward(
    self,
    hidden_states: torch.Tensor,
    indices_grid: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    class_labels: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    skip_layer_mask: Optional[torch.Tensor] = None,
    skip_layer_strategy: Optional[SkipLayerStrategy] = None,
    return_dict: bool = True,
):
    """
    The [`Transformer3DModel`] forward method.

    Args:
        hidden_states (`torch.Tensor`):
            Input latent tokens. They are first passed through `patchify_proj`; the result is
            treated as `(batch size, num tokens, hidden dim)` (batch is read from dim 0 and the
            caption embeddings are reshaped to its last dim). Presumably the input is
            `(batch size, num tokens, in_channels)` — confirm against callers.
        indices_grid (`torch.Tensor`):
            Per-token position indices (presumably `(batch size, 3, num tokens)`), passed to
            `precompute_freqs_cis` to build the positional frequencies given to every block.
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
            Conditioning embeddings for cross attention. When `caption_projection` is set they are
            projected and reshaped to `(batch size, -1, hidden dim)` before use.
        timestep (`torch.LongTensor`):
            Denoising timestep(s). Required despite the `None` default. Scaled by
            `timestep_scale_multiplier` (when set), flattened and embedded with `adaln_single`; the
            embedding has one entry per sample or one per token.
        class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
            Forwarded unchanged to every transformer block.
        cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
            Forwarded unchanged to every transformer block.
        attention_mask ( `torch.Tensor`, *optional*):
            Attention mask forwarded to every block. A 2-D `(batch, key_tokens)` mask (1 = keep,
            0 = discard) is converted into an additive bias of shape `(batch, 1, key_tokens)`
            (0 = keep, -10000 = discard); tensors of any other rank are assumed to already be a
            bias. No conversion happens when `use_tpu_flash_attention` is set.
        encoder_attention_mask ( `torch.Tensor`, *optional*):
            Cross-attention mask, converted exactly like `attention_mask`. Two formats supported:
                * Mask `(batch, sequence_length)` True = keep, False = discard.
                * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
        skip_layer_mask ( `torch.Tensor`, *optional*):
            A mask of shape `(num_layers, batch)`; row `i` is passed to transformer block `i`.
            `0` at position `layer, batch_idx` indicates that the layer should be skipped for the
            corresponding batch index.
        skip_layer_strategy ( `SkipLayerStrategy`, *optional*, defaults to `None`):
            Controls which layers are skipped when calculating a perturbed latent for
            spatiotemporal guidance.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`Transformer3DModelOutput`] instead of a plain tuple.

    Returns:
        If `return_dict` is True, a [`Transformer3DModelOutput`] is returned, otherwise a `tuple`
        whose first element is the sample tensor.

    Raises:
        ValueError: If `timestep` is None.
    """
    if timestep is None:
        # Previously surfaced as an opaque TypeError/AttributeError further down.
        raise ValueError("`timestep` is required by Transformer3DModel.forward.")

    def to_bias(mask: Optional[torch.Tensor], dtype: torch.dtype) -> Optional[torch.Tensor]:
        # A 2-D (batch, key_tokens) keep/discard mask becomes an additive bias
        # (keep = +0, discard = -10000.0) with a singleton query_tokens dim, so it
        # broadcasts over attention scores shaped [batch, heads, query_tokens, key_tokens]
        # (torch sdp attn) or [batch * heads, query_tokens, key_tokens] (xformers / classic).
        # Any other rank is taken to be a bias already (e.g. converted by a caller).
        if mask is not None and mask.ndim == 2:
            mask = (1 - mask.to(dtype)) * -10000.0
            mask = mask.unsqueeze(1)
        return mask

    # For TPU attention offload the 2-D token masks are used directly.
    if not self.use_tpu_flash_attention:
        attention_mask = to_bias(attention_mask, hidden_states.dtype)
        encoder_attention_mask = to_bias(encoder_attention_mask, hidden_states.dtype)

    # 1. Input
    hidden_states = self.patchify_proj(hidden_states)
    if self.timestep_scale_multiplier:
        timestep = self.timestep_scale_multiplier * timestep
    freqs_cis = self.precompute_freqs_cis(indices_grid)
    batch_size = hidden_states.shape[0]
    timestep, embedded_timestep = self.adaln_single(
        timestep.flatten(),
        {"resolution": None, "aspect_ratio": None},
        batch_size=batch_size,
        hidden_dtype=hidden_states.dtype,
    )
    # Second dimension is 1 or number of tokens (if timestep_per_token)
    timestep = timestep.view(batch_size, -1, timestep.shape[-1])
    embedded_timestep = embedded_timestep.view(
        batch_size, -1, embedded_timestep.shape[-1]
    )

    # 2. Blocks
    if self.caption_projection is not None:
        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(
            batch_size, -1, hidden_states.shape[-1]
        )

    # Loop-invariant: decided once instead of per block.
    use_checkpointing = self.training and self.gradient_checkpointing
    ckpt_kwargs: Dict[str, Any] = {}
    if use_checkpointing and is_torch_version(">=", "1.11.0"):
        ckpt_kwargs = {"use_reentrant": False}

    for block_idx, block in enumerate(self.transformer_blocks):
        block_skip_mask = (
            skip_layer_mask[block_idx] if skip_layer_mask is not None else None
        )
        if use_checkpointing:
            # checkpoint() calls block(*args); the argument order must match the
            # block's positional signature.
            hidden_states = torch.utils.checkpoint.checkpoint(
                block,
                hidden_states,
                freqs_cis,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                timestep,
                cross_attention_kwargs,
                class_labels,
                block_skip_mask,
                skip_layer_strategy,
                **ckpt_kwargs,
            )
        else:
            hidden_states = block(
                hidden_states,
                freqs_cis=freqs_cis,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
                skip_layer_mask=block_skip_mask,
                skip_layer_strategy=skip_layer_strategy,
            )

    # 3. Output
    scale_shift_values = (
        self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
    )
    shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
    hidden_states = self.norm_out(hidden_states)
    # Modulation
    hidden_states = hidden_states * (1 + scale) + shift
    hidden_states = self.proj_out(hidden_states)
    if not return_dict:
        return (hidden_states,)
    return Transformer3DModelOutput(sample=hidden_states)
```
## /ltx_video/pipelines/__init__.py
```py path="/ltx_video/pipelines/__init__.py"
```
## /ltx_video/pipelines/crf_compressor.py
```py path="/ltx_video/pipelines/crf_compressor.py"
import av
import torch
import io
import numpy as np
def _encode_single_frame(output_file, image_array: np.ndarray, crf):
    """Write ``image_array`` as a one-frame libx264 MP4 into ``output_file``.

    Args:
        output_file: Destination accepted by ``av.open`` (a path or file-like object).
        image_array: RGB frame as accepted by ``av.VideoFrame.from_ndarray`` with
            ``format="rgb24"``; dims 0 and 1 give the stream height and width.
        crf: libx264 Constant Rate Factor, passed as an encoder option.
    """
    container = av.open(output_file, "w", format="mp4")
    try:
        encoder_options = {"crf": str(crf), "preset": "veryfast"}
        video_stream = container.add_stream("libx264", rate=1, options=encoder_options)
        frame_height, frame_width = image_array.shape[0], image_array.shape[1]
        video_stream.height = frame_height
        video_stream.width = frame_width
        rgb_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24")
        yuv_frame = rgb_frame.reformat(format="yuv420p")
        container.mux(video_stream.encode(yuv_frame))
        # Calling encode() without a frame flushes packets still buffered in the encoder.
        container.mux(video_stream.encode())
    finally:
        container.close()
def _decode_single_frame(video_file):
    """Decode the first frame of the first video stream and return it as an rgb24 array.

    Args:
        video_file: Source accepted by ``av.open`` (a path or file-like object).
    """
    container = av.open(video_file)
    try:
        first_video_stream = next(
            candidate for candidate in container.streams if candidate.type == "video"
        )
        decoded = next(container.decode(first_video_stream))
    finally:
        container.close()
    return decoded.to_ndarray(format="rgb24")
def compress(image: torch.Tensor, crf=29):
    """Round-trip ``image`` through single-frame H.264 encoding to add compression artifacts.

    Args:
        image: Tensor indexed as ``image[height, width, ...]`` with values expected in
            ``[0, 1]``. It is encoded as ``rgb24``, so presumably ``(H, W, 3)`` — confirm
            with callers. Values outside ``[0, 1]`` are clamped before quantisation.
        crf: libx264 Constant Rate Factor. ``0`` short-circuits and returns ``image``
            itself (no encoding and no cropping).

    Returns:
        The decoded RGB frame as a tensor with ``image``'s dtype and device and values in
        ``[0, 1]``. Height and width are cropped down to even sizes, as required by the
        ``yuv420p`` pixel format used for encoding.
    """
    if crf == 0:
        return image
    # yuv420p subsamples chroma by 2 in each spatial direction, so dims must be even.
    even_height = (image.shape[0] // 2) * 2
    even_width = (image.shape[1] // 2) * 2
    # Clamp before the uint8 cast: converting out-of-range floats to uint8 is undefined
    # in PyTorch. In-range values are unaffected; truncation (not rounding) is kept.
    image_array = (
        (image[:even_height, :even_width].clamp(0.0, 1.0) * 255.0)
        .byte()
        .cpu()
        .numpy()
    )
    with io.BytesIO() as output_file:
        _encode_single_frame(output_file, image_array, crf)
        video_bytes = output_file.getvalue()
    with io.BytesIO(video_bytes) as video_file:
        image_array = _decode_single_frame(video_file)
    tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
    return tensor
```
## /ltx_video/schedulers/__init__.py
```py path="/ltx_video/schedulers/__init__.py"
```
## /ltx_video/utils/__init__.py
```py path="/ltx_video/utils/__init__.py"
```
## /ltx_video/utils/diffusers_config_mapping.py
```py path="/ltx_video/utils/diffusers_config_mapping.py"
def make_hashable_key(dict_key):
    """Convert a config dict into a hashable key that ignores key order.

    The top-level mapping becomes a tuple of ``(key, value)`` pairs sorted by key.
    Values are converted recursively: dicts become sorted tuples of pairs, and
    lists/tuples become tuples. Nested containers (e.g. a list of lists) are
    therefore hashable too. Any other value is kept as-is and must already be
    hashable.

    Args:
        dict_key: A mapping (anything with ``.items()``), e.g. a parsed config.json.

    Returns:
        A nested tuple usable as a dict key or set member.
    """

    def convert_value(value):
        if isinstance(value, (list, tuple)):
            # Recurse so elements such as inner lists/dicts are hashable as well.
            return tuple(convert_value(item) for item in value)
        if isinstance(value, dict):
            return tuple(sorted((k, convert_value(v)) for k, v in value.items()))
        return value

    return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items()))
# Diffusers-format configs (the ones shipped in the Lightricks/LTX-Video diffusers
# checkpoint) that this package can translate to its own configs. They are used only
# as lookup keys: make_hashable_key turns them into the keys of
# diffusers_and_ours_config_mapping, and a checkpoint's config.json is accepted only
# if it is exactly equal to one of them (key order does not matter, values do).
# Scheduler config expected in the diffusers checkpoint.
DIFFUSERS_SCHEDULER_CONFIG = {
"_class_name": "FlowMatchEulerDiscreteScheduler",
"_diffusers_version": "0.32.0.dev0",
"base_image_seq_len": 1024,
"base_shift": 0.95,
"invert_sigmas": False,
"max_image_seq_len": 4096,
"max_shift": 2.05,
"num_train_timesteps": 1000,
"shift": 1.0,
"shift_terminal": 0.1,
"use_beta_sigmas": False,
"use_dynamic_shifting": True,
"use_exponential_sigmas": False,
"use_karras_sigmas": False,
}
# Transformer config expected in the diffusers checkpoint's transformer/config.json.
DIFFUSERS_TRANSFORMER_CONFIG = {
"_class_name": "LTXVideoTransformer3DModel",
"_diffusers_version": "0.32.0.dev0",
"activation_fn": "gelu-approximate",
"attention_bias": True,
"attention_head_dim": 64,
"attention_out_bias": True,
"caption_channels": 4096,
"cross_attention_dim": 2048,
"in_channels": 128,
"norm_elementwise_affine": False,
"norm_eps": 1e-06,
"num_attention_heads": 32,
"num_layers": 28,
"out_channels": 128,
"patch_size": 1,
"patch_size_t": 1,
"qk_norm": "rms_norm_across_heads",
}
# VAE config expected in the diffusers checkpoint.
DIFFUSERS_VAE_CONFIG = {
"_class_name": "AutoencoderKLLTXVideo",
"_diffusers_version": "0.32.0.dev0",
"block_out_channels": [128, 256, 512, 512],
"decoder_causal": False,
"encoder_causal": True,
"in_channels": 3,
"latent_channels": 128,
"layers_per_block": [4, 3, 3, 3, 4],
"out_channels": 3,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-06,
"scaling_factor": 1.0,
"spatio_temporal_scaling": [True, True, True, False],
}
# This package's equivalents of the DIFFUSERS_* configs above. They are the values
# of diffusers_and_ours_config_mapping. The transformer one is passed to
# Transformer3DModel.from_config when loading a diffusers directory; the scheduler
# and VAE ones are presumably consumed the same way by their loaders (not visible here).
# Scheduler config (RectifiedFlowScheduler).
OURS_SCHEDULER_CONFIG = {
"_class_name": "RectifiedFlowScheduler",
"_diffusers_version": "0.25.1",
"num_train_timesteps": 1000,
"shifting": "SD3",
"base_resolution": None,
"target_shift_terminal": 0.1,
}
# Transformer config (Transformer3DModel).
OURS_TRANSFORMER_CONFIG = {
"_class_name": "Transformer3DModel",
"_diffusers_version": "0.25.1",
"_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256",
"activation_fn": "gelu-approximate",
"attention_bias": True,
"attention_head_dim": 64,
"attention_type": "default",
"caption_channels": 4096,
"cross_attention_dim": 2048,
"double_self_attention": False,
"dropout": 0.0,
"in_channels": 128,
"norm_elementwise_affine": False,
"norm_eps": 1e-06,
"norm_num_groups": 32,
"num_attention_heads": 32,
"num_embeds_ada_norm": 1000,
"num_layers": 28,
"num_vector_embeds": None,
"only_cross_attention": False,
"out_channels": 128,
"project_to_2d_pos": True,
"upcast_attention": False,
"use_linear_projection": False,
"qk_norm": "rms_norm",
"standardization_norm": "rms_norm",
"positional_embedding_type": "rope",
"positional_embedding_theta": 10000.0,
"positional_embedding_max_pos": [20, 2048, 2048],
"timestep_scale_multiplier": 1000,
}
# VAE config (CausalVideoAutoencoder). The "blocks" list is an ordered layer spec;
# NOTE(review): its indices appear to line up with the encoder.down_blocks.N /
# decoder.up_blocks.N targets in VAE_KEYS_RENAME_DICT — confirm in the autoencoder.
OURS_VAE_CONFIG = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"blocks": [
["res_x", 4],
["compress_all", 1],
["res_x_y", 1],
["res_x", 3],
["compress_all", 1],
["res_x_y", 1],
["res_x", 3],
["compress_all", 1],
["res_x", 3],
["res_x", 4],
],
"scaling_factor": 1.0,
"norm_layer": "pixel_norm",
"patch_size": 4,
"latent_log_var": "uniform",
"use_quant_conv": False,
"causal_decoder": False,
}
# Lookup table from a hashed diffusers config (see make_hashable_key) to the
# equivalent config of this package. A diffusers checkpoint is supported only if
# make_hashable_key(its config.json) is one of these keys.
diffusers_and_ours_config_mapping = {
make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG,
make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG,
make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG,
}
# Substring replacements that translate diffusers transformer weight names to this
# package's names. Transformer3DModel.from_pretrained applies each pair to every key
# with str.replace, in insertion order.
TRANSFORMER_KEYS_RENAME_DICT = {
"proj_in": "patchify_proj",
"time_embed": "adaln_single",
"norm_q": "q_norm",
"norm_k": "k_norm",
}
# Substring replacements from diffusers VAE weight names to this package's VAE names.
# NOTE(review): the VAE loader is not visible here; this assumes it applies the pairs
# sequentially with str.replace in insertion order, as the transformer loader does.
# Under that assumption the ORDER IS SIGNIFICANT:
#   * more specific prefixes (e.g. "decoder.up_blocks.3.conv_in") come before their
#     parent prefix ("decoder.up_blocks.3");
#   * source blocks are remapped from the highest index down, so the output of one
#     block-index rule is never matched again by a later block-index rule;
#   * the generic renames (conv_shortcut, resnets, norm3, latents_*) come last.
VAE_KEYS_RENAME_DICT = {
"decoder.up_blocks.3.conv_in": "decoder.up_blocks.7",
"decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8",
"decoder.up_blocks.3": "decoder.up_blocks.9",
"decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5",
"decoder.up_blocks.2.conv_in": "decoder.up_blocks.4",
"decoder.up_blocks.2": "decoder.up_blocks.6",
"decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2",
"decoder.up_blocks.1": "decoder.up_blocks.3",
"decoder.up_blocks.0": "decoder.up_blocks.1",
"decoder.mid_block": "decoder.up_blocks.0",
"encoder.down_blocks.3": "encoder.down_blocks.8",
"encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7",
"encoder.down_blocks.2": "encoder.down_blocks.6",
"encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4",
"encoder.down_blocks.1.conv_out": "encoder.down_blocks.5",
"encoder.down_blocks.1": "encoder.down_blocks.3",
"encoder.down_blocks.0.conv_out": "encoder.down_blocks.2",
"encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1",
"encoder.down_blocks.0": "encoder.down_blocks.0",
"encoder.mid_block": "encoder.down_blocks.9",
"conv_shortcut.conv": "conv_shortcut",
"resnets": "res_blocks",
"norm3": "norm3.norm",
# diffusers latent statistics -> per-channel statistics buffers of the VAE.
"latents_mean": "per_channel_statistics.mean-of-means",
"latents_std": "per_channel_statistics.std-of-means",
}
```
## /ltx_video/utils/skip_layer_strategy.py
```py path="/ltx_video/utils/skip_layer_strategy.py"
from enum import Enum, auto
class SkipLayerStrategy(Enum):
    """How transformer layers selected by ``skip_layer_mask`` are skipped.

    Passed to the transformer forward pass as ``skip_layer_strategy`` when building
    the perturbed prediction used for spatiotemporal guidance. The exact effect of
    each member is implemented by the transformer blocks, not here. Members are
    numbered automatically (1, 2, 3, 4) in declaration order.
    """

    AttentionSkip = auto()  # value 1
    AttentionValues = auto()  # value 2
    Residual = auto()  # value 3
    TransformerBlock = auto()  # value 4
```
## /ltx_video/utils/torch_utils.py
```py path="/ltx_video/utils/torch_utils.py"
import torch
from torch import nn
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """Return ``x`` with trailing singleton dimensions added until it has ``target_dims`` dims.

    If ``x`` already has ``target_dims`` dimensions it is returned as-is (same object);
    otherwise a view with extra size-1 trailing dimensions is returned.

    Raises:
        ValueError: If ``x`` has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    expanded = x
    for _ in range(missing):
        expanded = expanded.unsqueeze(-1)
    return expanded
class Identity(nn.Module):
    """No-op module whose ``forward`` returns its first argument unchanged.

    Both the constructor and ``forward`` accept and ignore any extra positional or
    keyword arguments, so it can replace a layer regardless of how that layer is
    constructed or called.
    """

    def __init__(self, *_args, **_kwargs) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
        return x
```
## /pyproject.toml
```toml path="/pyproject.toml"
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "ltx-video"
version = "0.1.2"
description = "A package for LTX-Video model"
authors = [
{ name = "LTX-Video Team", email = "ltx-video@lightricks.com" }
]
requires-python = ">=3.10"
readme = "README.md"
classifiers = [
"Programming Language :: Python :: 3",
"Operating System :: OS Independent"
]
dependencies = [
"torch>=2.1.0",
"diffusers>=0.28.2",
"transformers>=4.47.2,<4.52.0",
"sentencepiece>=0.1.96",
"huggingface-hub~=0.30",
"einops",
"timm"
]
[project.optional-dependencies]
inference = [
"imageio[ffmpeg]",
"av",
"torchvision"
]
test = [
"pytest",
]
[tool.setuptools.packages.find]
include = ["ltx_video*"]
[tool.setuptools.package-data]
ltx_video = ["configs/*.yaml"]
```
## /tests/conftest.py
```py path="/tests/conftest.py"
import json
import pytest
import safetensors.torch
import torch
from ltx_video.models.autoencoders.causal_video_autoencoder import (
CausalVideoAutoencoder,
create_video_autoencoder_demo_config,
PER_CHANNEL_STATISTICS_PREFIX,
)
from ltx_video.models.transformers.transformer3d import Transformer3DModel
def pytest_make_parametrize_id(config, val, argname):
    """Pytest hook: render each parametrized test ID as ``<argname>-<value>``.

    String values are used verbatim; every other value is rendered with ``repr``.
    """
    rendered = val if isinstance(val, str) else repr(val)
    return f"{argname}-{rendered}"
@pytest.fixture
def num_latent_channels():
    """Latent channel count shared by the synthetic VAE and transformer fixtures."""
    latent_channel_count = 16
    return latent_channel_count
@pytest.fixture
def video_autoencoder(num_latent_channels):
    """Small CausalVideoAutoencoder built from the demo config, in eval mode and bfloat16."""
    demo_config = create_video_autoencoder_demo_config(
        latent_channels=num_latent_channels
    )
    autoencoder = CausalVideoAutoencoder.from_config(demo_config)
    autoencoder.eval()
    autoencoder.to(torch.bfloat16)
    return autoencoder
@pytest.fixture
def transformer_config(num_latent_channels):
    """Tiny two-layer Transformer3DModel config whose input/output channels match the VAE latents."""
    return {
        "activation_fn": "gelu-approximate",
        "attention_bias": True,
        "attention_head_dim": 12,
        "attention_type": "default",
        "caption_channels": 4096,
        "cross_attention_dim": 192,
        "double_self_attention": False,
        "dropout": 0.0,
        # Must match the latent channel count of the video_autoencoder fixture.
        "in_channels": num_latent_channels,
        "norm_elementwise_affine": False,
        "norm_eps": 1e-06,
        "norm_num_groups": 32,
        "num_attention_heads": 16,
        "num_embeds_ada_norm": 1000,
        "num_layers": 2,
        "num_vector_embeds": None,
        "only_cross_attention": False,
        "out_channels": num_latent_channels,
        "upcast_attention": False,
        "use_linear_projection": False,
        "qk_norm": "rms_norm",
        "standardization_norm": "rms_norm",
        "positional_embedding_type": "rope",
        "positional_embedding_theta": 10000.0,
        "positional_embedding_max_pos": [120, 1, 1],
        "timestep_scale_multiplier": 1000,
    }
@pytest.fixture
def synthetic_ckpt_path(
    tmp_path, video_autoencoder, num_latent_channels, transformer_config
):
    """Write a single-file ``.safetensors`` checkpoint with random weights and return its path.

    Transformer weights are stored under ``"model.diffusion_model."`` and VAE weights under
    ``"vae."``. The file metadata ``"config"`` entry holds JSON with ``"transformer"`` and
    ``"vae"`` sub-configs, the layout read by Transformer3DModel.from_pretrained's
    single-file branch. Random per-channel latent statistics are added to the VAE weights.
    """
    model = Transformer3DModel.from_config(transformer_config)
    model.to(torch.bfloat16)

    prefixed_transformer_sd = {
        f"model.diffusion_model.{name}": tensor
        for name, tensor in model.state_dict().items()
    }

    # Per-channel statistics; std first, then mean (keeps the RNG draw order).
    autoencoder_sd = video_autoencoder.state_dict()
    autoencoder_sd[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = torch.rand(
        num_latent_channels
    )
    autoencoder_sd[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = torch.rand(
        num_latent_channels
    )
    prefixed_vae_sd = {f"vae.{name}": tensor for name, tensor in autoencoder_sd.items()}

    all_configs = {
        "transformer": transformer_config,
        "vae": vars(video_autoencoder.config),
    }
    checkpoint_path = f"{tmp_path}/test_ckpt.safetensors"
    safetensors.torch.save_file(
        {**prefixed_transformer_sd, **prefixed_vae_sd},
        checkpoint_path,
        metadata={"config": json.dumps(all_configs)},
    )
    return checkpoint_path
```
The content has been capped at 50000 tokens. The user could consider applying other filters to refine the result. The better and more specific the context, the better the LLM can follow instructions. If the context seems verbose, the user can refine the filter using uithub. Thank you for using https://uithub.com - Perfect LLM context for any GitHub repo.