```
├── .gitattributes (omitted)
├── .github/
   ├── workflows/
      ├── pylint.yml (100 tokens)
├── .gitignore (600 tokens)
├── .pre-commit-config.yaml (100 tokens)
├── LICENSE (omitted)
├── README.md (5.3k tokens)
├── configs/
   ├── ltxv-13b-0.9.8-dev-fp8.yaml (300 tokens)
   ├── ltxv-13b-0.9.8-dev.yaml (300 tokens)
   ├── ltxv-13b-0.9.8-distilled-fp8.yaml (200 tokens)
   ├── ltxv-13b-0.9.8-distilled.yaml (200 tokens)
   ├── ltxv-2b-0.9.1.yaml (100 tokens)
   ├── ltxv-2b-0.9.5.yaml (100 tokens)
   ├── ltxv-2b-0.9.6-dev.yaml (100 tokens)
   ├── ltxv-2b-0.9.6-distilled.yaml (100 tokens)
   ├── ltxv-2b-0.9.8-distilled-fp8.yaml (200 tokens)
   ├── ltxv-2b-0.9.8-distilled.yaml (200 tokens)
   ├── ltxv-2b-0.9.yaml (100 tokens)
├── docs/
   ├── _static/
      ├── ltx-video_example_00001.gif
      ├── ltx-video_example_00005.gif
      ├── ltx-video_example_00006.gif
      ├── ltx-video_example_00007.gif
      ├── ltx-video_example_00010.gif
      ├── ltx-video_example_00011.gif
      ├── ltx-video_example_00013.gif
      ├── ltx-video_example_00014.gif
      ├── ltx-video_example_00015.gif
      ├── ltx-video_i2v_example_00001.gif
      ├── ltx-video_i2v_example_00002.gif
      ├── ltx-video_i2v_example_00003.gif
      ├── ltx-video_i2v_example_00004.gif
      ├── ltx-video_i2v_example_00005.gif
      ├── ltx-video_i2v_example_00006.gif
      ├── ltx-video_i2v_example_00007.gif
      ├── ltx-video_i2v_example_00008.gif
      ├── ltx-video_i2v_example_00009.gif
      ├── ltx-video_ic_2v_example_00000.gif
      ├── ltx-video_ic_2v_example_00001.gif
      ├── ltx-video_ic_2v_example_00002.gif
      ├── ltx-video_ic_2v_example_00003.gif
      ├── ltx-video_ic_2v_example_00004.gif
├── inference.py (100 tokens)
├── ltx_video/
   ├── __init__.py
   ├── inference.py (5.3k tokens)
   ├── models/
      ├── __init__.py
      ├── autoencoders/
         ├── __init__.py
         ├── causal_conv3d.py (400 tokens)
         ├── causal_video_autoencoder.py (10.5k tokens)
         ├── conv_nd_factory.py (500 tokens)
         ├── dual_conv3d.py (1400 tokens)
         ├── latent_upsampler.py (1400 tokens)
         ├── pixel_norm.py (100 tokens)
         ├── pixel_shuffle.py (200 tokens)
         ├── vae.py (2.9k tokens)
         ├── vae_encode.py (1800 tokens)
         ├── video_autoencoder.py (7.1k tokens)
      ├── transformers/
         ├── __init__.py
         ├── attention.py (10.5k tokens)
         ├── embeddings.py (900 tokens)
         ├── symmetric_patchifier.py (600 tokens)
         ├── transformer3d.py (4.5k tokens)
   ├── pipelines/
      ├── __init__.py
      ├── crf_compressor.py (300 tokens)
      ├── pipeline_ltx_video.py (16.7k tokens)
   ├── schedulers/
      ├── __init__.py
      ├── rf.py (3.1k tokens)
   ├── utils/
      ├── __init__.py
      ├── diffusers_config_mapping.py (1100 tokens)
      ├── prompt_enhance_utils.py (1500 tokens)
      ├── skip_layer_strategy.py
      ├── torch_utils.py (200 tokens)
├── pyproject.toml (200 tokens)
├── tests/
   ├── conftest.py (600 tokens)
   ├── test_configs.py (200 tokens)
   ├── test_inference.py (1800 tokens)
   ├── test_scheduler.py (700 tokens)
   ├── test_vae.py (600 tokens)
   ├── utils/
      ├── .gitattributes
      ├── woman.jpeg
      ├── woman.mp4
```


## /.github/workflows/pylint.yml

```yml path="/.github/workflows/pylint.yml" 
name: Ruff

on: [push]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
    - name: Checkout repository and submodules
      uses: actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install ruff==0.2.2 black==24.2.0
    - name: Analyzing the code with ruff
      run: |
        ruff check $(git ls-files '*.py')
    - name: Verify that no Black changes are required
      run: |
        black --check $(git ls-files '*.py')

```

## /.gitignore

```gitignore path="/.gitignore" 
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# From inference.py
outputs/
*.mp4
*.png
!tests/utils/car.png
```

## /.pre-commit-config.yaml

```yaml path="/.pre-commit-config.yaml" 
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.2.2
    hooks:
      # Run the linter.
      - id: ruff
        args: [--fix]  # Automatically fix issues if possible.
        types: [python]  # Ensure it only runs on .py files.

  - repo: https://github.com/psf/black
    rev: 24.2.0  # Specify the version of Black you want
    hooks:
      - id: black
        name: Black code formatter
        language_version: python3  # Use the Python version you're targeting (e.g., 3.10)
```

## /README.md

<div align="center">

# LTX-Video

[![Website](https://img.shields.io/badge/Website-LTXV-181717?logo=google-chrome)](https://ltx.video)
[![Model](https://img.shields.io/badge/HuggingFace-Model-orange?logo=huggingface)](https://huggingface.co/Lightricks/LTX-Video)
[![Demo](https://img.shields.io/badge/Demo-Try%20Now-brightgreen?logo=vercel)](https://app.ltx.studio/ltx-2-playground/t2v)
[![Paper](https://img.shields.io/badge/Paper-arXiv-B31B1B?logo=arxiv)](https://arxiv.org/abs/2501.00103)
[![Trainer](https://img.shields.io/badge/LTXV-Trainer-9146FF?logo=github)](https://github.com/Lightricks/LTX-Video-Trainer)
[![Discord](https://img.shields.io/badge/Join-Discord-5865F2?logo=discord)](https://discord.gg/ltxplatform)

This is the official repository for LTX-Video.

</div>

---

## 🚀 **New: LTX-2 is Now Available!**

**We're excited to announce [LTX-2](https://github.com/Lightricks/LTX-2) - the next generation of LTX with synchronized audio+video generation!**

LTX-2 is the first DiT-based audio-video foundation model that contains all core capabilities of modern video generation in one model. **LTX-2 is now the primary home for LTX development** and includes significant improvements:

- 🎵 **Synchronized Audio+Video Generation** - Generate videos with perfectly synchronized audio
- 🎬 **Latest Model** - LTX-2 with improved quality and capabilities
- 🔌 **ComfyUI Integration** - Built into ComfyUI core for seamless workflows
- 🎯 **Advanced Features:**
  - Multiple keyframe support
  - IC-LoRA control models for precise generation
  - Standard LoRA support for style customization
  - Latent upsampler for multiscale pipelines
- 🛠️ **Training Tools** - LoRA training capabilities
- 📚 **Comprehensive Documentation** - Full documentation at [https://docs.ltx.video](https://docs.ltx.video)
- 🔄 **Active Development** - Ongoing improvements and community support

**[👉 Check out LTX-2 here](https://github.com/Lightricks/LTX-2)**

**[📖 View Documentation](https://docs.ltx.video)**

---

## Table of Contents

- [Introduction](#introduction)
- [What's New](#news)
- [Models](#models)
- [Quick Start Guide](#quick-start-guide)
  - [Online demo](#online-inference)
  - [Run locally](#run-locally)
    - [Installation](#installation)
    - [Inference](#inference)
  - [ComfyUI Integration](#comfyui-integration)
  - [Diffusers Integration](#diffusers-integration)
- [Model User Guide](#model-user-guide)
- [Community Contribution](#community-contribution)
- [Training](#training)
- [Control Models](#control-models)
- [Join Us!](#join-us)
- [Acknowledgement](#acknowledgement)

# Introduction

LTX-Video is the first DiT-based video generation model that can generate high-quality videos in real time. It can generate 30 FPS videos at 1216 × 704 resolution faster than it takes to watch them (real time on an H100 with the distilled model).
The model is trained on a large-scale dataset of diverse videos and can generate high-resolution videos with realistic and diverse content.

The model supports image-to-video, multi-keyframe conditioning, keyframe-based animation, video extension (both forward and backward), video-to-video transformations, and any combination of these features.

### Image-to-video examples
| | | |
|:---:|:---:|:---:|
| ![example1](./docs/_static/ltx-video_i2v_example_00001.gif) | ![example2](./docs/_static/ltx-video_i2v_example_00002.gif) | ![example3](./docs/_static/ltx-video_i2v_example_00003.gif) |
| ![example4](./docs/_static/ltx-video_i2v_example_00004.gif) | ![example5](./docs/_static/ltx-video_i2v_example_00005.gif) |  ![example6](./docs/_static/ltx-video_i2v_example_00006.gif) |
| ![example7](./docs/_static/ltx-video_i2v_example_00007.gif) |  ![example8](./docs/_static/ltx-video_i2v_example_00008.gif) | ![example9](./docs/_static/ltx-video_i2v_example_00009.gif) |

### Controlled video examples
| | | |
|:---:|:---:|:---:|
| ![control0](./docs/_static/ltx-video_ic_2v_example_00000.gif) | ![control1](./docs/_static/ltx-video_ic_2v_example_00001.gif) | ![control2](./docs/_static/ltx-video_ic_2v_example_00002.gif) |

| | |
|:---:|:---:|
| ![control3](./docs/_static/ltx-video_ic_2v_example_00003.gif) | ![control4](./docs/_static/ltx-video_ic_2v_example_00004.gif) |

# News

## October 23, 2025: LTX-2 Announced

Today we announced our newest foundation model, LTX-2. LTX-2 represents a major leap forward from our previous model, LTXV 0.9.8. Here’s what’s new:
* **Audio + Video, Together**: Visuals and sound are generated in one coherent process, with motion, dialogue, ambience, and music flowing simultaneously.
* **4K Fidelity**: Professional-grade precision with native 4K and up to 50 fps, sharp textures, clean motion, and synchronized audio.
* **Longer Generations**: LTX-2 supports longer, continuous clips with synchronized audio up to 10 seconds.
* **Low Cost & Efficiency**: Up to 50% lower compute cost than competing models, powered by a multi-GPU inference stack.
* **Creative Control**: Multi-keyframe conditioning, 3D camera logic, and LoRA fine-tuning deliver frame-level precision and style consistency.

For more details, please see our [blog post](https://website.ltx.video/blog/introducing-ltx-2). LTX-2 model weights, code, and benchmarks will be released to the community later in 2025. 


## July 16th, 2025: New distilled models v0.9.8 with up to 60 seconds of video:
- Long shot generation in LTXV-13B!
  * LTX-Video now supports up to 60 seconds of video.
  * Also compatible with the official IC-LoRAs.
  * Try now in [ComfyUI](https://github.com/Lightricks/ComfyUI-LTXVideo/tree/master/example_workflows/ltxv-13b-i2v-long-multi-prompt.json).
- Release new distilled models:
  * 13B distilled model [ltxv-13b-0.9.8-distilled](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-distilled.yaml)
  * 2B distilled model [ltxv-2b-0.9.8-distilled](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.8-distilled.yaml)
  * Both models are distilled from the same base model [ltxv-13b-0.9.8-dev](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-dev.yaml) and are compatible for use together in the same multiscale pipeline.
  * Improved prompt understanding and detail generation
  * Includes corresponding FP8 weights and workflows.
- Release a new detailer model [LTX-Video-ICLoRA-detailer-13B-0.9.8](https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8)
  * Available in [ComfyUI](https://github.com/Lightricks/ComfyUI-LTXVideo/tree/master/example_workflows/ltxv-13b-upscale.json).

## July 8th, 2025: New Control Models Released!
- Released three new control models for LTX-Video on HuggingFace:
    * **Depth Control**: [LTX-Video-ICLoRA-depth-13b-0.9.7](https://huggingface.co/Lightricks/LTX-Video-ICLoRA-depth-13b-0.9.7)
    * **Pose Control**: [LTX-Video-ICLoRA-pose-13b-0.9.7](https://huggingface.co/Lightricks/LTX-Video-ICLoRA-pose-13b-0.9.7)
    * **Canny Control**: [LTX-Video-ICLoRA-canny-13b-0.9.7](https://huggingface.co/Lightricks/LTX-Video-ICLoRA-canny-13b-0.9.7)


## May 14th, 2025: New distilled model 13B v0.9.7:
- Release a new 13B distilled model [ltxv-13b-0.9.7-distilled](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled.safetensors)
    * Amazing for iterative work - generates HD videos in 10 seconds, with low-res preview after just 3 seconds (on H100)!
    * Requires neither classifier-free guidance nor spatiotemporal guidance.
    * Supports sampling with 8 (recommended) or fewer diffusion steps.
    * Also released a LoRA version of the distilled model, [ltxv-13b-0.9.7-distilled-lora128](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled-lora128.safetensors)
        * Requires only 1GB of VRAM
        * Can be used with the full 13B model for fast inference
- Release a new quantized distilled model [ltxv-13b-0.9.7-distilled-fp8](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled-fp8.safetensors) for *real-time* generation (on H100) with even less VRAM

## May 5th, 2025: New model 13B v0.9.7:
- Release a new 13B model [ltxv-13b-0.9.7-dev](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev.safetensors)
- Release a new quantized model [ltxv-13b-0.9.7-dev-fp8](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev-fp8.safetensors) for faster inference with less VRAM
- Release new upscalers:
  * [ltxv-temporal-upscaler-0.9.7](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-temporal-upscaler-0.9.7.safetensors)
  * [ltxv-spatial-upscaler-0.9.7](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-spatial-upscaler-0.9.7.safetensors)
- Breakthrough prompt adherence and physical understanding.
- New pipeline for multiscale video rendering for fast, high-quality results


## April 15th, 2025: New checkpoints v0.9.6:
- Release a new checkpoint [ltxv-2b-0.9.6-dev-04-25](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-2b-0.9.6-dev-04-25.safetensors) with improved quality
- Release a new distilled model [ltxv-2b-0.9.6-distilled-04-25](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-2b-0.9.6-distilled-04-25.safetensors)
    * 15x faster inference than non-distilled model.
    * Requires neither classifier-free guidance nor spatiotemporal guidance.
    * Supports sampling with 8 (recommended) or fewer diffusion steps.
- Improved prompt adherence, motion quality and fine details.
- New default resolution and FPS: 1216 × 704 pixels at 30 FPS
    * Still real time on H100 with the distilled model.
    * Other resolutions and FPS are still supported.
- Support stochastic inference (can improve visual quality when using the distilled model)

## March 5th, 2025: New checkpoint v0.9.5
- New license for commercial use ([OpenRail-M](https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.5.license.txt))
- Release a new checkpoint v0.9.5 with improved quality
- Support keyframes and video extension
- Support higher resolutions
- Improved prompt understanding
- Improved VAE
- New online web app in [LTX-Studio](https://app.ltx.studio/ltx-video)
- Automatic prompt enhancement

## February 20th, 2025: More inference options
- Improve STG (Spatiotemporal Guidance) for LTX-Video
- Support MPS on macOS with PyTorch 2.3.0
- Add support for 8-bit model, LTX-VideoQ8
- Add TeaCache for LTX-Video
- Add [ComfyUI-LTXTricks](#comfyui-integration)
- Add Diffusion-Pipe

## December 31st, 2024: Research paper
- Release the [research paper](https://arxiv.org/abs/2501.00103)

## December 20th, 2024: New checkpoint v0.9.1
- Release a new checkpoint v0.9.1 with improved quality
- Support for STG / PAG
- Support loading checkpoints of LTX-Video in Diffusers format (conversion is done on-the-fly)
- Support offloading unused parts to CPU
- Support the new timestep-conditioned VAE decoder
- Reference contributions from the community in the readme file
- Relax transformers dependency

## November 21st, 2024: Initial release v0.9.0
- Initial release of LTX-Video
- Support text-to-video and image-to-video generation


# Models

| Name                    | Notes                                                                                      | inference.py config                                                                                                                                      | ComfyUI workflow (Recommended) |
|-------------------------|--------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------|
| ltxv-13b-0.9.8-dev                   | Highest quality, requires more VRAM                                                        | [ltxv-13b-0.9.8-dev.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-dev.yaml)                                             | [ltxv-13b-i2v-base.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ltxv-13b-i2v-base.json)             |
| [ltxv-13b-0.9.8-mix](https://app.ltx.studio/motion-workspace?videoModel=ltxv-13b)            | Mix ltxv-13b-dev and ltxv-13b-distilled in the same multi-scale rendering workflow for balanced speed-quality | N/A                                             | [ltxv-13b-i2v-mixed-multiscale.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ltxv-13b-i2v-mixed-multiscale.json)             |
| [ltxv-13b-0.9.8-distilled](https://app.ltx.studio/motion-workspace?videoModel=ltxv) | Faster, less VRAM usage, slight quality reduction compared to 13b. Ideal for rapid iterations | [ltxv-13b-0.9.8-distilled.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-distilled.yaml) | [ltxv-13b-dist-i2v-base.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/13b-distilled/ltxv-13b-dist-i2v-base.json) |
| ltxv-2b-0.9.8-distilled | Smaller model, slight quality reduction compared to 13b distilled. Ideal for fast generation with light VRAM usage | [ltxv-2b-0.9.8-distilled.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.8-distilled.yaml) | N/A |
| ltxv-13b-0.9.8-dev-fp8               | Quantized version of ltxv-13b | [ltxv-13b-0.9.8-dev-fp8.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-dev-fp8.yaml) | [ltxv-13b-i2v-base-fp8.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ltxv-13b-i2v-base-fp8.json) |
| ltxv-13b-0.9.8-distilled-fp8     | Quantized version of ltxv-13b-distilled | [ltxv-13b-0.9.8-distilled-fp8.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.8-distilled-fp8.yaml) | [ltxv-13b-dist-i2v-base-fp8.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/13b-distilled/ltxv-13b-dist-i2v-base-fp8.json) |
| ltxv-2b-0.9.8-distilled-fp8     | Quantized version of ltxv-2b-distilled | [ltxv-2b-0.9.8-distilled-fp8.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.8-distilled-fp8.yaml) | N/A |
| ltxv-2b-0.9.6                     | Good quality, lower VRAM requirement than ltxv-13b                                         | [ltxv-2b-0.9.6-dev.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.6-dev.yaml)                                                 | [ltxvideo-i2v.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/low_level/ltxvideo-i2v.json)             |
| ltxv-2b-0.9.6-distilled         | 15× faster, real-time capable, fewer steps needed, no STG/CFG required                     | [ltxv-2b-0.9.6-distilled.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.6-distilled.yaml)                                     | [ltxvideo-i2v-distilled.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/low_level/ltxvideo-i2v-distilled.json)             |


# Quick Start Guide

## Online inference
The model is accessible right away via the following links:
- [LTX-Studio image-to-video (13B-mix)](https://app.ltx.studio/motion-workspace?videoModel=ltxv-13b)
- [LTX-Studio image-to-video (13B distilled)](https://app.ltx.studio/motion-workspace?videoModel=ltxv)
- [Fal.ai image-to-video (13B full)](https://fal.ai/models/fal-ai/ltx-video-13b-dev/image-to-video)
- [Fal.ai image-to-video (13B distilled)](https://fal.ai/models/fal-ai/ltx-video-13b-distilled/image-to-video)
- [Replicate image-to-video](https://replicate.com/lightricks/ltx-video)

## Run locally

### Installation
The codebase was tested with Python 3.10.5, CUDA version 12.2, and supports PyTorch >= 2.1.2.
On macOS, MPS was tested with PyTorch 2.3.0, and should support PyTorch == 2.3 or >= 2.6.

```bash
git clone https://github.com/Lightricks/LTX-Video.git
cd LTX-Video

# create env
python -m venv env
source env/bin/activate
python -m pip install -e .\[inference\]
```
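The version requirements above can also be checked from Python before installing heavier dependencies. The sketch below encodes the stated ranges as plain comparisons; the helper names are hypothetical (not part of this repository), and the parser assumes `X.Y` or `X.Y.Z` version strings, ignoring suffixes such as `+cu121`.

```python
import re


def parse_version(version: str) -> tuple[int, int, int]:
    """Parse 'X.Y' or 'X.Y.Z' (ignoring suffixes like '+cu121') into a tuple."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def cuda_torch_supported(version: str) -> bool:
    """Stated requirement for the CUDA setup: PyTorch >= 2.1.2."""
    return parse_version(version) >= (2, 1, 2)


def mps_torch_supported(version: str) -> bool:
    """Stated support on macOS/MPS: PyTorch 2.3.x or >= 2.6."""
    major, minor, _ = parse_version(version)
    return (major, minor) == (2, 3) or (major, minor) >= (2, 6)
```

For example, `mps_torch_supported(torch.__version__)` reports whether the installed PyTorch falls in the range stated for MPS.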

#### FP8 Kernels (optional)

[FP8 kernels](https://github.com/Lightricks/LTXVideo-Q8-Kernels) developed for LTX-Video provide a performance boost on supported graphics cards (Ada architecture and later). To install FP8 kernels, follow the instructions in that repository.

### Inference

📝 **Note:** For best results, we recommend using our [ComfyUI](#comfyui-integration) workflow. We're working on updating the inference.py script to match the high quality and output fidelity of ComfyUI.

To use our model, please follow the inference code in [inference.py](./inference.py):

#### For image-to-video generation:

```bash
python inference.py --prompt "PROMPT" --conditioning_media_paths IMAGE_PATH --conditioning_start_frames 0 --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.8-distilled.yaml
```

#### Extending a video:

📝 **Note:** Input video segments must contain a multiple of 8 frames plus 1 (e.g., 9, 17, 25, etc.), and the target frame number should be a multiple of 8.


```bash
python inference.py --prompt "PROMPT" --conditioning_media_paths VIDEO_PATH --conditioning_start_frames START_FRAME --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.8-distilled.yaml
```
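The frame-count rules in the note above are easy to check before launching a run. A minimal sketch with hypothetical helper names follows; it assumes the "target frame number" is the value passed to `--conditioning_start_frames`.

```python
def is_valid_segment_length(num_frames: int) -> bool:
    """True if num_frames has the form 8k + 1 (e.g. 1, 9, 17, 25)."""
    return num_frames >= 1 and (num_frames - 1) % 8 == 0


def nearest_valid_segment_length(num_frames: int) -> int:
    """Largest 8k + 1 value that does not exceed num_frames (i.e. trim, don't pad)."""
    if num_frames < 1:
        raise ValueError("a segment needs at least one frame")
    return ((num_frames - 1) // 8) * 8 + 1


def is_valid_start_frame(start_frame: int) -> bool:
    """True if the conditioning start frame is a non-negative multiple of 8."""
    return start_frame >= 0 and start_frame % 8 == 0
```

For example, a 30-frame clip would need to be trimmed to 25 frames (`nearest_valid_segment_length(30)`) before it can be used as a conditioning segment.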

#### For video generation with multiple conditions:

You can now generate a video conditioned on a set of images and/or short video segments.
Simply provide a list of paths to the images or video segments you want to condition on, along with their target frame numbers in the generated video. You can also specify the conditioning strength for each item (default: 1.0).

```bash
python inference.py --prompt "PROMPT" --conditioning_media_paths IMAGE_OR_VIDEO_PATH_1 IMAGE_OR_VIDEO_PATH_2 --conditioning_start_frames TARGET_FRAME_1 TARGET_FRAME_2 --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.8-distilled.yaml
```
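With many conditioning items, it can help to assemble the command from a list of `(path, frame)` pairs so the two flags always stay aligned. The sketch below uses only the flags shown in the commands above; `build_inference_command` is a hypothetical helper, and the file names and prompt are placeholders.

```python
import shlex


def build_inference_command(
    prompt: str,
    conditions: list[tuple[str, int]],
    height: int,
    width: int,
    num_frames: int,
    seed: int,
    pipeline_config: str = "configs/ltxv-13b-0.9.8-distilled.yaml",
) -> list[str]:
    """Build an argv list for inference.py with aligned conditioning flags."""
    args = ["python", "inference.py", "--prompt", prompt]
    if conditions:  # omit both flags entirely for plain text-to-video
        args += ["--conditioning_media_paths", *(path for path, _ in conditions)]
        args += ["--conditioning_start_frames", *(str(frame) for _, frame in conditions)]
    args += [
        "--height", str(height),
        "--width", str(width),
        "--num_frames", str(num_frames),
        "--seed", str(seed),
        "--pipeline_config", pipeline_config,
    ]
    return args


cmd = build_inference_command(
    prompt="A red kite rises over a windy beach",
    conditions=[("first.png", 0), ("last.png", 120)],
    height=704, width=1216, num_frames=121, seed=42,
)
print(shlex.join(cmd))  # run it with subprocess.run(cmd, check=True)
```

Returning a list rather than a single string avoids shell-quoting problems with prompts that contain quotes or spaces.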

### Using as a library

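```python
from ltx_video.inference import infer, InferenceConfig

# Example values: keep height/width divisible by 32 and num_frames of the form 8k + 1.
PROMPT = "A woman with long brown hair walks along a beach at sunset, the camera tracking her from the side."
HEIGHT = 704
WIDTH = 1216
NUM_FRAMES = 121

infer(
    InferenceConfig(
        pipeline_config="configs/ltxv-13b-0.9.8-distilled.yaml",
        prompt=PROMPT,
        height=HEIGHT,
        width=WIDTH,
        num_frames=NUM_FRAMES,
        output_path="output.mp4",
    )
)
```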

## ComfyUI Integration
To use our model with ComfyUI, please follow the instructions at [https://github.com/Lightricks/ComfyUI-LTXVideo/](https://github.com/Lightricks/ComfyUI-LTXVideo/).

## Diffusers Integration
To use our model with the Diffusers Python library, check out the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).

Diffusers also supports an 8-bit version of LTX-Video; [see details below](#ltx-videoq8).

# Model User Guide

## 📝 Prompt Engineering

When writing prompts, focus on detailed, chronological descriptions of actions and scenes. Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. Start directly with the action, and keep descriptions literal and precise. Think like a cinematographer describing a shot list. Keep within 200 words. For best results, build your prompts using this structure:

* Start with main action in a single sentence
* Add specific details about movements and gestures
* Describe character/object appearances precisely
* Include background and environment details
* Specify camera angles and movements
* Describe lighting and colors
* Note any changes or sudden events
* See [examples](#introduction) for more inspiration.
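One way to follow this structure is to write each element separately and join them, in order, into a single paragraph. The sketch below is illustrative only: `build_prompt` is a hypothetical helper, and the word count is a simple whitespace split used to enforce the 200-word guideline.

```python
MAX_PROMPT_WORDS = 200  # guideline from the prompt-engineering advice above


def build_prompt(parts: list[str], max_words: int = MAX_PROMPT_WORDS) -> str:
    """Join prompt elements, in order, into one flowing paragraph."""
    sentences = []
    for part in parts:
        part = part.strip()
        if not part:
            continue  # skip empty elements
        if part[-1] not in ".!?":
            part += "."  # terminate each element as a sentence
        sentences.append(part)
    prompt = " ".join(sentences)
    word_count = len(prompt.split())
    if word_count > max_words:
        raise ValueError(f"prompt has {word_count} words; keep it within {max_words}")
    return prompt


prompt = build_prompt([
    "A woman in a yellow raincoat walks along a wet city street",  # main action
    "She pauses and lifts her umbrella as a bus passes",           # movements
    "Neon signs reflect in the puddles around her",                # environment
    "The camera tracks her from a low angle",                      # camera
    "Cool blue light mixes with warm orange shop windows",         # lighting
])
```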

### Automatic Prompt Enhancement

When using `LTXVideoPipeline` directly, you can enable prompt enhancement by setting `enhance_prompt=True`.

## 🎮 Parameter Guide

* Resolution Preset: Use higher resolutions for detailed scenes and lower ones for faster generation and simpler scenes. The model works on resolutions divisible by 32 and frame counts of the form 8k + 1 (e.g., 257). If the resolution or frame count does not meet these constraints, the input is padded with -1 and then cropped to the desired resolution and number of frames. The model works best at resolutions below 720 x 1280 and with fewer than 257 frames.
* Seed: Save seed values to recreate specific styles or compositions you like
* Guidance Scale: 3-3.5 are the recommended values
* Inference Steps: More steps (40+) for quality, fewer steps (20-30) for speed
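The divisibility rules can be turned into a small calculation of the padded shape the model works on. This sketch assumes padding always rounds up to the next valid size (resolution to a multiple of 32, frame count to the next 8k + 1); the helper names are hypothetical and the repository's own padding code may differ in detail.

```python
def round_up(value: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= value."""
    return -(-value // multiple) * multiple


def padded_shape(height: int, width: int, num_frames: int) -> tuple[int, int, int]:
    """Return (height, width, frames) after padding to the model's valid sizes."""
    padded_h = round_up(height, 32)
    padded_w = round_up(width, 32)
    padded_f = round_up(num_frames - 1, 8) + 1  # next value of the form 8k + 1
    return padded_h, padded_w, padded_f
```

For example, under these assumptions a 720 x 1280 request with 100 frames would be processed as 736 x 1280 with 105 frames and then cropped back, while 704 x 1216 with 121 frames needs no padding.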

📝 For advanced parameters usage, please see `python inference.py --help`
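Assuming the guidance scale above is the standard classifier-free guidance (CFG) weight, its effect can be shown with the usual CFG combination on plain lists. This is a generic sketch of the technique, not code taken from the repository; in a real pipeline the same formula is applied to model output tensors at every denoising step.

```python
def classifier_free_guidance(
    uncond: list[float], cond: list[float], guidance_scale: float
) -> list[float]:
    """Combine predictions element-wise: uncond + scale * (cond - uncond)."""
    if len(uncond) != len(cond):
        raise ValueError("predictions must have the same shape")
    return [u + guidance_scale * (c - u) for u, c in zip(uncond, cond)]


# scale = 1.0 reproduces the conditional prediction; larger values push the
# result further away from the unconditional prediction.
print(classifier_free_guidance([0.0, 1.0], [1.0, 1.0], 3.0))  # [3.0, 1.0]
```

Note that computing CFG requires both an unconditional and a text-conditioned prediction at each step.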

## Community Contribution

### ComfyUI-LTXTricks 🛠️

A community project providing additional nodes for enhanced control over the LTX Video model. It includes implementations of advanced techniques like RF-Inversion, RF-Edit, FlowEdit, and more. These nodes enable workflows such as Image and Video to Video (I+V2V), enhanced sampling via Spatiotemporal Skip Guidance (STG), and interpolation with precise frame settings.

- **Repository:** [ComfyUI-LTXTricks](https://github.com/logtd/ComfyUI-LTXTricks)
- **Features:**
  - 🔄 **RF-Inversion:** Implements [RF-Inversion](https://rf-inversion.github.io/) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_inversion.json).
  - ✂️ **RF-Edit:** Implements [RF-Solver-Edit](https://github.com/wangjiangshan0725/RF-Solver-Edit) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_rf_edit.json).
  - 🌊 **FlowEdit:** Implements [FlowEdit](https://github.com/fallenshock/FlowEdit) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_flow_edit.json).
  - 🎥 **I+V2V:** Enables Video to Video with a reference image. [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_iv2v.json).
  - ✨ **Enhance:** Partial implementation of [STGuidance](https://junhahyung.github.io/STGuidance/). [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltxv_stg.json).
  - 🖼️ **Interpolation and Frame Setting:** Nodes for precise control of latents per frame. [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_interpolation.json).


### LTX-VideoQ8 🎱 <a id="ltx-videoq8"></a>

**LTX-VideoQ8** is an 8-bit optimized version of [LTX-Video](https://github.com/Lightricks/LTX-Video), designed for faster performance on NVIDIA Ada GPUs.

- **Repository:** [LTX-VideoQ8](https://github.com/KONAKONA666/LTX-Video)
- **Features:**
  - 🚀 Up to 3X speed-up with no accuracy loss
  - 🎥 Generate 720x480x121 videos in under a minute on RTX 4060 (8GB VRAM)
  - 🛠️ Fine-tune 2B transformer models with precalculated latents
- **Community Discussion:** [Reddit Thread](https://www.reddit.com/r/StableDiffusion/comments/1h79ks2/fast_ltx_video_on_rtx_4060_and_other_ada_gpus/)
- **Diffusers integration:** A diffusers integration for the 8-bit model is already out! [Details here](https://github.com/sayakpaul/q8-ltx-video)


### TeaCache for LTX-Video 🍵 <a id="TeaCache"></a>

**TeaCache** is a training-free caching approach that leverages timestep differences across model outputs to accelerate LTX-Video inference by up to 2x without significant visual quality degradation.

- **Repository:** [TeaCache4LTX-Video](https://github.com/ali-vilab/TeaCache/tree/main/TeaCache4LTX-Video)
- **Features:**
  - 🚀 Speeds up LTX-Video inference.
  - 📊 Adjustable trade-offs between speed (up to 2x) and visual quality using configurable parameters.
  - 🛠️ No retraining required: Works directly with existing models.
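The caching idea can be illustrated with a simplified decision rule: track how much the model input changes between denoising steps, accumulate the relative change, and reuse the previous result while the accumulated change stays below a threshold. This is a hedged sketch of the general approach only. The actual TeaCache implementation estimates the change from timestep-modulated inputs and rescales it with a fitted polynomial, which is omitted here, and `StepCache` and `threshold` are hypothetical names.

```python
from typing import Callable, Optional


def relative_l1(current: list[float], previous: list[float]) -> float:
    """Sum of absolute changes divided by the sum of previous magnitudes."""
    diff = sum(abs(c - p) for c, p in zip(current, previous))
    scale = sum(abs(p) for p in previous)
    return diff / scale if scale else float("inf")


class StepCache:
    """Skip an expensive model call while the accumulated input change is small."""

    def __init__(self, model: Callable[[list[float]], list[float]], threshold: float):
        self.model = model
        self.threshold = threshold
        self.accumulated = 0.0
        self.prev_input: Optional[list[float]] = None
        self.cached_output: Optional[list[float]] = None
        self.calls = 0  # number of real model evaluations

    def __call__(self, x: list[float]) -> list[float]:
        if self.prev_input is not None and self.cached_output is not None:
            self.accumulated += relative_l1(x, self.prev_input)
            self.prev_input = x  # always compare against the previous step
            if self.accumulated < self.threshold:
                return self.cached_output  # reuse: change still below threshold
        # First step, or change too large: run the model and reset the accumulator.
        self.accumulated = 0.0
        self.prev_input = x
        self.cached_output = self.model(x)
        self.calls += 1
        return self.cached_output


cache = StepCache(lambda x: [2.0 * v for v in x], threshold=0.1)
for step_input in ([1.0, 1.0], [1.02, 1.0], [1.04, 1.0], [2.0, 2.0]):
    cache(step_input)
print(cache.calls)  # 2: the two small changes reuse the first result
```

A larger threshold skips more steps (faster, lower fidelity); a threshold of 0 disables caching.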

### Your Contribution

...is welcome! If you have a project or tool that integrates with LTX-Video,
please let us know by opening an issue or pull request.

# Training

We provide an open-source repository for fine-tuning the LTX-Video model: [LTX-Video-Trainer](https://github.com/Lightricks/LTX-Video-Trainer).
This repository supports both the 2B and 13B model variants, enabling full fine-tuning as well as LoRA (Low-Rank Adaptation) fine-tuning for more efficient training. Supported LoRA use cases include:

- **Control LoRAs**: Train custom control models like depth, pose, and canny control
- **Effect LoRAs**: Create specialized effects and transformations for video generation

Explore the repository to customize the model for your specific use cases!
More information and training instructions can be found in the [README](https://github.com/Lightricks/LTX-Video-Trainer/blob/main/README.md).

# Control Models

The [ComfyUI-LTXVideo](https://github.com/Lightricks/ComfyUI-LTXVideo) repository now contains workflows and models for three specialized models that enable precise control over LTX-Video generation:

- Pose Control
- Depth Control
- Canny Control

**Example ComfyUI Workflow (for all control types):** [ic-lora.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ic_lora/ic-lora.json)

# Join Us

Want to work on cutting-edge AI research and make a real impact on millions of users worldwide?

At **Lightricks**, an AI-first company, we're revolutionizing how visual content is created.

If you are passionate about AI, computer vision, and video generation, we would love to hear from you!

Please visit our [careers page](https://careers.lightricks.com/careers?query=&office=all&department=R%26D) for more information.

# Acknowledgement

We are grateful for the following awesome projects, which we drew on when implementing LTX-Video:
* [DiT](https://github.com/facebookresearch/DiT) and [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha): vision transformers for image generation.


## Citation

📄 Our tech report is out! If you find our work helpful, please ⭐️ star the repository and cite our paper.

```bibtex
@article{HaCohen2024LTXVideo,
  title={LTX-Video: Realtime Video Latent Diffusion},
  author={HaCohen, Yoav and Chiprut, Nisan and Brazowski, Benny and Shalem, Daniel and Moshe, Dudu and Richardson, Eitan and Levin, Eran and Shiran, Guy and Zabari, Nir and Gordon, Ori and Panet, Poriya and Weissbuch, Sapir and Kulikov, Victor and Bitterman, Yaki and Melumian, Zeev and Bibi, Ofir},
  journal={arXiv preprint arXiv:2501.00103},
  year={2024}
}
```


## /configs/ltxv-13b-0.9.8-dev-fp8.yaml

```yaml path="/configs/ltxv-13b-0.9.8-dev-fp8.yaml" 
pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-dev-fp8.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  guidance_scale: [1, 1, 6, 8, 6, 1, 1]
  stg_scale: [0, 0, 4, 4, 4, 2, 1]
  rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
  guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
  skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
  num_inference_steps: 30
  skip_final_inference_steps: 3
  cfg_star_rescale: true

second_pass:
  guidance_scale: [1]
  stg_scale: [1]
  rescaling_scale: [1]
  guidance_timesteps: [1.0]
  skip_block_list: [27]
  num_inference_steps: 30
  skip_initial_inference_steps: 17
  cfg_star_rescale: true

```
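In the multi-scale dev configs, `first_pass` gives `guidance_scale`, `stg_scale`, `rescaling_scale` and `skip_block_list` as lists with one entry per value in `guidance_timesteps`. The sketch below shows one plausible way to pick the entry for a given sampler timestep. It assumes that timestep `t` uses the first entry whose `guidance_timesteps` value is `<= t` (the list is in descending order) and falls back to the last entry otherwise. `guidance_index` and `check_lengths` are hypothetical helpers, not functions from this repository; the actual rule lives in the pipeline code.

```python
# Sketch: map sampler timesteps to entries of the per-interval guidance lists.
# ASSUMPTION: timestep t uses the first index i with guidance_timesteps[i] <= t;
# if no threshold qualifies, the last entry is used.
from typing import Sequence


def guidance_index(t: float, guidance_timesteps: Sequence[float]) -> int:
    """Return the index of the guidance entry assumed to apply at timestep t."""
    for i, threshold in enumerate(guidance_timesteps):
        if threshold <= t:
            return i
    return len(guidance_timesteps) - 1  # assumed fallback below all thresholds


def check_lengths(pass_config: dict) -> None:
    """Every per-interval list needs one entry per guidance timestep."""
    n = len(pass_config["guidance_timesteps"])
    for key in ("guidance_scale", "stg_scale", "rescaling_scale", "skip_block_list"):
        if len(pass_config[key]) != n:
            raise ValueError(f"{key} has {len(pass_config[key])} entries, expected {n}")


# Values copied from first_pass in configs/ltxv-13b-0.9.8-dev.yaml
first_pass = {
    "guidance_scale": [1, 1, 6, 8, 6, 1, 1],
    "stg_scale": [0, 0, 4, 4, 4, 2, 1],
    "rescaling_scale": [1, 1, 0.5, 0.5, 1, 1, 1],
    "guidance_timesteps": [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180],
    "skip_block_list": [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]],
}
check_lengths(first_pass)

for t in (1.0, 0.995, 0.95, 0.5):
    i = guidance_index(t, first_pass["guidance_timesteps"])
    print(t, first_pass["guidance_scale"][i], first_pass["stg_scale"][i])
```

Under this reading, classifier-free guidance and STG are strongest early in the schedule (the 6/8/6 entries) and the pass runs unguided (`guidance_scale: 1`) near the start and end.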

## /configs/ltxv-13b-0.9.8-dev.yaml

```yaml path="/configs/ltxv-13b-0.9.8-dev.yaml" 
pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-dev.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  guidance_scale: [1, 1, 6, 8, 6, 1, 1]
  stg_scale: [0, 0, 4, 4, 4, 2, 1]
  rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
  guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
  skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
  num_inference_steps: 30
  skip_final_inference_steps: 3
  cfg_star_rescale: true

second_pass:
  guidance_scale: [1]
  stg_scale: [1]
  rescaling_scale: [1]
  guidance_timesteps: [1.0]
  skip_block_list: [27]
  num_inference_steps: 30
  skip_initial_inference_steps: 17
  cfg_star_rescale: true
```

## /configs/ltxv-13b-0.9.8-distilled-fp8.yaml

```yaml path="/configs/ltxv-13b-0.9.8-distilled-fp8.yaml" 
pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-distilled-fp8.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]

second_pass:
  timesteps: [0.9094, 0.7250, 0.4219]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]
  tone_map_compression_ratio: 0.6

```

## /configs/ltxv-13b-0.9.8-distilled.yaml

```yaml path="/configs/ltxv-13b-0.9.8-distilled.yaml" 
pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-distilled.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]

second_pass:
  timesteps: [0.9094, 0.7250, 0.4219]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]
  tone_map_compression_ratio: 0.6

```
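All multi-scale configs set `downscale_factor: 0.6666666` together with a `spatial_upscaler_model_path`. The sketch below works through the resulting resolutions under stated assumptions: the first pass runs at the requested size scaled by `downscale_factor` and snapped down to a multiple of 32 (assumed here to be the VAE's spatial compression factor), and the spatial upscaler doubles height and width for the second pass. The helper names are hypothetical.

```python
# Sketch of the multi-scale resolution arithmetic.
# ASSUMPTIONS: 32x spatial VAE compression, first pass at size * downscale_factor
# rounded down to a multiple of 32, and a 2x spatial upscaler before the second pass.
VAE_SPATIAL_FACTOR = 32


def first_pass_size(height: int, width: int, downscale_factor: float) -> tuple:
    h = int(height * downscale_factor)
    w = int(width * downscale_factor)
    # Snap down so the latent grid has an integer size
    return h - h % VAE_SPATIAL_FACTOR, w - w % VAE_SPATIAL_FACTOR


def second_pass_size(height: int, width: int, downscale_factor: float) -> tuple:
    h, w = first_pass_size(height, width, downscale_factor)
    return 2 * h, 2 * w  # the spatial upscaler doubles each side


print(first_pass_size(704, 1216, 0.6666666))   # (448, 800)
print(second_pass_size(704, 1216, 0.6666666))  # (896, 1600)
```

With the default 704x1216 from `InferenceConfig`, these assumptions give a second-pass size larger than requested, so a final resize would be needed; `LTXMultiScalePipeline` is the place to confirm how the pipeline actually handles this.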

## /configs/ltxv-2b-0.9.1.yaml

```yaml path="/configs/ltxv-2b-0.9.1.yaml" 
pipeline_type: base
checkpoint_path: "ltx-video-2b-v0.9.1.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
```

## /configs/ltxv-2b-0.9.5.yaml

```yaml path="/configs/ltxv-2b-0.9.5.yaml" 
pipeline_type: base
checkpoint_path: "ltx-video-2b-v0.9.5.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
```

## /configs/ltxv-2b-0.9.6-dev.yaml

```yaml path="/configs/ltxv-2b-0.9.6-dev.yaml" 
pipeline_type: base
checkpoint_path: "ltxv-2b-0.9.6-dev-04-25.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
```

## /configs/ltxv-2b-0.9.6-distilled.yaml

```yaml path="/configs/ltxv-2b-0.9.6-distilled.yaml" 
pipeline_type: base
checkpoint_path: "ltxv-2b-0.9.6-distilled-04-25.safetensors"
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
num_inference_steps: 8
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: true
```

## /configs/ltxv-2b-0.9.8-distilled-fp8.yaml

```yaml path="/configs/ltxv-2b-0.9.8-distilled-fp8.yaml" 
pipeline_type: multi-scale
checkpoint_path: "ltxv-2b-0.9.8-distilled-fp8.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]

second_pass:
  timesteps: [0.9094, 0.7250, 0.4219]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]

```

## /configs/ltxv-2b-0.9.8-distilled.yaml

```yaml path="/configs/ltxv-2b-0.9.8-distilled.yaml" 
pipeline_type: multi-scale
checkpoint_path: "ltxv-2b-0.9.8-distilled.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]

second_pass:
  timesteps: [0.9094, 0.7250, 0.4219]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]

```

## /configs/ltxv-2b-0.9.yaml

```yaml path="/configs/ltxv-2b-0.9.yaml" 
pipeline_type: base
checkpoint_path: "ltx-video-2b-v0.9.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
```

## /docs/_static/ltx-video_example_00001.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_example_00001.gif

## /docs/_static/ltx-video_example_00005.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_example_00005.gif

## /docs/_static/ltx-video_example_00006.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_example_00006.gif

## /docs/_static/ltx-video_example_00007.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_example_00007.gif

## /docs/_static/ltx-video_example_00010.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_example_00010.gif

## /docs/_static/ltx-video_example_00011.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_example_00011.gif

## /docs/_static/ltx-video_example_00013.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_example_00013.gif

## /docs/_static/ltx-video_example_00014.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_example_00014.gif

## /docs/_static/ltx-video_example_00015.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_example_00015.gif

## /docs/_static/ltx-video_i2v_example_00001.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_i2v_example_00001.gif

## /docs/_static/ltx-video_i2v_example_00002.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_i2v_example_00002.gif

## /docs/_static/ltx-video_i2v_example_00003.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_i2v_example_00003.gif

## /docs/_static/ltx-video_i2v_example_00004.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_i2v_example_00004.gif

## /docs/_static/ltx-video_i2v_example_00005.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_i2v_example_00005.gif

## /docs/_static/ltx-video_i2v_example_00006.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_i2v_example_00006.gif

## /docs/_static/ltx-video_i2v_example_00007.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_i2v_example_00007.gif

## /docs/_static/ltx-video_i2v_example_00008.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_i2v_example_00008.gif

## /docs/_static/ltx-video_i2v_example_00009.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_i2v_example_00009.gif

## /docs/_static/ltx-video_ic_2v_example_00000.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_ic_2v_example_00000.gif

## /docs/_static/ltx-video_ic_2v_example_00001.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_ic_2v_example_00001.gif

## /docs/_static/ltx-video_ic_2v_example_00002.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_ic_2v_example_00002.gif

## /docs/_static/ltx-video_ic_2v_example_00003.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_ic_2v_example_00003.gif

## /docs/_static/ltx-video_ic_2v_example_00004.gif

Binary file available at https://raw.githubusercontent.com/Lightricks/LTX-Video/refs/heads/main/docs/_static/ltx-video_ic_2v_example_00004.gif

## /inference.py

```py path="/inference.py" 
from transformers import HfArgumentParser

from ltx_video.inference import infer, InferenceConfig


def main():
    parser = HfArgumentParser(InferenceConfig)
    config = parser.parse_args_into_dataclasses()[0]
    infer(config=config)


if __name__ == "__main__":
    main()

```
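The entry point parses command-line flags into `InferenceConfig` (defined in `ltx_video/inference.py`), whose defaults are `height=704`, `width=1216` and `num_frames=121`. Those values are multiples of 32 and of the form 8k + 1. The sketch below assumes the model requires exactly that (32x spatial and 8x temporal VAE compression) and shows how arbitrary requested sizes could be rounded up and center-padded. `center_padding` mirrors `calculate_padding` and returns `(left, right, top, bottom)`, the order `torch.nn.functional.pad` expects for the last two dimensions. All helper names are hypothetical.

```python
# Sketch: round requested sizes up to values the model can encode.
# ASSUMPTIONS: 32x spatial and 8x temporal VAE compression, frame counts 8k + 1.


def round_up(value: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= value (for value >= 1)."""
    return ((value - 1) // multiple + 1) * multiple


def padded_frames(num_frames: int) -> int:
    """Smallest frame count of the form 8k + 1 that is >= num_frames."""
    return ((num_frames - 2) // 8 + 1) * 8 + 1


def center_padding(src_h: int, src_w: int, dst_h: int, dst_w: int) -> tuple:
    """Same split as calculate_padding: any odd pixel goes to bottom/right."""
    pad_h, pad_w = dst_h - src_h, dst_w - src_w
    top, left = pad_h // 2, pad_w // 2
    return (left, pad_w - left, top, pad_h - top)  # (left, right, top, bottom)


height, width, frames = 500, 750, 100
h_pad, w_pad = round_up(height, 32), round_up(width, 32)
print(h_pad, w_pad, padded_frames(frames))           # 512 768 105
print(center_padding(height, width, h_pad, w_pad))  # (9, 9, 6, 6)
```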

## /ltx_video/__init__.py

```py path="/ltx_video/__init__.py" 

```

## /ltx_video/inference.py

```py path="/ltx_video/inference.py" 
import os
import random
from datetime import datetime
from pathlib import Path
from diffusers.utils import logging
from typing import Optional, List, Union
import yaml

import imageio
import json
import numpy as np
import torch
from safetensors import safe_open
from PIL import Image
import torchvision.transforms.functional as TVF
from transformers import (
    T5EncoderModel,
    T5Tokenizer,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
)
from huggingface_hub import hf_hub_download
from dataclasses import dataclass, field

from ltx_video.models.autoencoders.causal_video_autoencoder import (
    CausalVideoAutoencoder,
)
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from ltx_video.pipelines.pipeline_ltx_video import (
    ConditioningItem,
    LTXVideoPipeline,
    LTXMultiScalePipeline,
)
from ltx_video.schedulers.rf import RectifiedFlowScheduler
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
import ltx_video.pipelines.crf_compressor as crf_compressor

logger = logging.get_logger("LTX-Video")


def get_total_gpu_memory():
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        return total_memory
    return 0


def get_device():
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, Image.Image],
    target_height: int = 512,
    target_width: int = 768,
    just_crop: bool = False,
) -> torch.Tensor:
    """Load and process an image into a tensor.

    Args:
        image_input: Either a file path (str) or a PIL Image object
        target_height: Desired height of output tensor
        target_width: Desired width of output tensor
        just_crop: If True, only crop the image to the target size without resizing
    """
    if isinstance(image_input, str):
        image = Image.open(image_input).convert("RGB")
    elif isinstance(image_input, Image.Image):
        image = image_input
    else:
        raise ValueError("image_input must be either a file path or a PIL Image object")

    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height
    if aspect_ratio_frame > aspect_ratio_target:
        new_width = int(input_height * aspect_ratio_target)
        new_height = input_height
        x_start = (input_width - new_width) // 2
        y_start = 0
    else:
        new_width = input_width
        new_height = int(input_width / aspect_ratio_target)
        x_start = 0
        y_start = (input_height - new_height) // 2

    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    if not just_crop:
        image = image.resize((target_width, target_height))

    frame_tensor = TVF.to_tensor(image)  # PIL -> tensor (C, H, W), [0,1]
    frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=3, sigma=1.0)
    frame_tensor_hwc = frame_tensor.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
    frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
    frame_tensor = frame_tensor_hwc.permute(2, 0, 1) * 255.0  # (H, W, C) -> (C, H, W)
    frame_tensor = (frame_tensor / 127.5) - 1.0
    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
    return frame_tensor.unsqueeze(0).unsqueeze(2)
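
# Worked example for load_image_to_tensor_with_resize_and_crop (illustrative):
# a 1920x1080 input with target_height=512 and target_width=768 has aspect ratio
# 1920 / 1080 = 1.78 > 768 / 512 = 1.5, so the first branch applies:
#   new_width = int(1080 * 1.5) = 1620, x_start = (1920 - 1620) // 2 = 150
# giving the crop box (150, 0, 1770, 1080). The crop is resized to 768x512,
# blurred, CRF-compressed, scaled to roughly [-1, 1], and returned with shape
# (1, 3, 1, 512, 768).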


def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    """Compute symmetric padding that grows a source frame to the target size.

    When the total padding is odd, the extra pixel goes to the bottom/right.
    """

    # Calculate total padding needed
    pad_height = target_height - source_height
    pad_width = target_width - source_width

    # Calculate padding for each side
    pad_top = pad_height // 2
    pad_bottom = pad_height - pad_top  # Handles odd padding
    pad_left = pad_width // 2
    pad_right = pad_width - pad_left  # Handles odd padding

    # Return the padding amounts (not a padded tensor), in the
    # (left, right, top, bottom) order used by torch.nn.functional.pad
    padding = (pad_left, pad_right, pad_top, pad_bottom)
    return padding


def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    # Keep only letters and whitespace, lowercased
    clean_text = "".join(
        char.lower() for char in text if char.isalpha() or char.isspace()
    )

    # Split into words
    words = clean_text.split()

    # Build result string keeping track of length
    result = []
    current_length = 0

    for word in words:
        # Add word length plus 1 for the "-" separator (except for the first word)
        new_length = current_length + len(word) + (1 if result else 0)

        if new_length <= max_len:
            result.append(word)
            current_length = new_length
        else:
            break

    return "-".join(result)


# Generate output video name
def get_unique_filename(
    base: str,
    ext: str,
    prompt: str,
    seed: int,
    resolution: tuple[int, int, int],
    dir: Path,
    endswith=None,
    index_range=1000,
) -> Path:
    base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
    for i in range(index_range):
        filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
        if not os.path.exists(filename):
            return filename
    raise FileExistsError(
        f"Could not find a unique filename after {index_range} attempts."
    )


def seed_everething(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)


def create_transformer(ckpt_path: str, precision: str) -> Transformer3DModel:
    if precision == "float8_e4m3fn":
        # Guard only the import, so errors raised while loading the checkpoint
        # are not misreported as a missing Q8-Kernels installation.
        try:
            from q8_kernels.integration.patch_transformer import (
                patch_diffusers_transformer as patch_transformer_for_q8_kernels,
            )
        except ImportError as e:
            raise ValueError(
                "Q8-Kernels not found. To use an FP8 checkpoint, please install Q8 kernels from https://github.com/Lightricks/LTXVideo-Q8-Kernels"
            ) from e

        transformer = Transformer3DModel.from_pretrained(
            ckpt_path, dtype=torch.float8_e4m3fn
        )
        patch_transformer_for_q8_kernels(transformer)
        return transformer
    elif precision == "bfloat16":
        return Transformer3DModel.from_pretrained(ckpt_path).to(torch.bfloat16)
    else:
        return Transformer3DModel.from_pretrained(ckpt_path)


def create_ltx_video_pipeline(
    ckpt_path: str,
    precision: str,
    text_encoder_model_name_or_path: str,
    sampler: Optional[str] = None,
    device: Optional[str] = None,
    enhance_prompt: bool = False,
    prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None,
    prompt_enhancer_llm_model_name_or_path: Optional[str] = None,
) -> LTXVideoPipeline:
    ckpt_path = Path(ckpt_path)
    # Explicit check instead of assert, which is skipped under `python -O`
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint path {ckpt_path} does not exist")

    with safe_open(ckpt_path, framework="pt") as f:
        metadata = f.metadata()
        config_str = metadata.get("config")
        configs = json.loads(config_str)
        allowed_inference_steps = configs.get("allowed_inference_steps", None)

    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
    transformer = create_transformer(ckpt_path, precision)

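    # Load the scheduler config from the checkpoint when sampler is
    # "from_checkpoint" or unset; otherwise construct it explicitly
    # ("uniform" selects Uniform, any other value selects LinearQuadratic)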
    if sampler == "from_checkpoint" or not sampler:
        scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
    else:
        scheduler = RectifiedFlowScheduler(
            sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
        )

    text_encoder = T5EncoderModel.from_pretrained(
        text_encoder_model_name_or_path, subfolder="text_encoder"
    )
    patchifier = SymmetricPatchifier(patch_size=1)
    tokenizer = T5Tokenizer.from_pretrained(
        text_encoder_model_name_or_path, subfolder="tokenizer"
    )

    transformer = transformer.to(device)
    vae = vae.to(device)
    text_encoder = text_encoder.to(device)

    if enhance_prompt:
        prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
            prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
        )
        prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
            prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
        )
        prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
            prompt_enhancer_llm_model_name_or_path,
            torch_dtype="bfloat16",
        )
        prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
            prompt_enhancer_llm_model_name_or_path,
        )
    else:
        prompt_enhancer_image_caption_model = None
        prompt_enhancer_image_caption_processor = None
        prompt_enhancer_llm_model = None
        prompt_enhancer_llm_tokenizer = None

    vae = vae.to(torch.bfloat16)
    text_encoder = text_encoder.to(torch.bfloat16)

    # Use submodels for the pipeline
    submodel_dict = {
        "transformer": transformer,
        "patchifier": patchifier,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "scheduler": scheduler,
        "vae": vae,
        "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
        "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
        "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
        "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
        "allowed_inference_steps": allowed_inference_steps,
    }

    pipeline = LTXVideoPipeline(**submodel_dict)
    pipeline = pipeline.to(device)
    return pipeline


def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
    latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
    latent_upsampler.to(device)
    latent_upsampler.eval()
    return latent_upsampler


def load_pipeline_config(pipeline_config: str):
    current_file = Path(__file__)

    path = None
    if os.path.isfile(current_file.parent / pipeline_config):
        path = current_file.parent / pipeline_config
    elif os.path.isfile(pipeline_config):
        path = pipeline_config
    else:
        raise ValueError(f"Pipeline config file {pipeline_config} does not exist")

    with open(path, "r") as f:
        return yaml.safe_load(f)


@dataclass
class InferenceConfig:
    prompt: str = field(metadata={"help": "Prompt for the generation"})

    output_path: str = field(
        default_factory=lambda: Path(
            f"outputs/{datetime.today().strftime('%Y-%m-%d')}"
        ),
        metadata={"help": "Path to the folder to save the output video"},
    )

    # Pipeline settings
    pipeline_config: str = field(
        default="configs/ltxv-13b-0.9.8-dev.yaml",
        metadata={"help": "Path to the pipeline config file"},
    )
    seed: int = field(
        default=171198, metadata={"help": "Random seed for the inference"}
    )
    height: int = field(
        default=704, metadata={"help": "Height of the output video frames"}
    )
    width: int = field(
        default=1216, metadata={"help": "Width of the output video frames"}
    )
    num_frames: int = field(
        default=121,
        metadata={"help": "Number of frames to generate in the output video"},
    )
    frame_rate: int = field(
        default=30, metadata={"help": "Frame rate for the output video"}
    )
    offload_to_cpu: bool = field(
        default=False, metadata={"help": "Offload unnecessary computations to the CPU."}
    )
    negative_prompt: str = field(
        default="worst quality, inconsistent motion, blurry, jittery, distorted",
        metadata={"help": "Negative prompt for undesired features"},
    )

    # Video-to-video arguments
    input_media_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to the input video (or image) to be modified using the video-to-video pipeline"
        },
    )

    # Conditioning
    image_cond_noise_scale: float = field(
        default=0.15,
        metadata={"help": "Amount of noise to add to the conditioned image"},
    )
    conditioning_media_paths: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "List of paths to conditioning media (images or videos). Each path will be used as a conditioning item."
        },
    )
    conditioning_strengths: Optional[List[float]] = field(
        default=None,
        metadata={
            "help": "List of conditioning strengths (between 0 and 1) for each conditioning item. Must match the number of conditioning items."
        },
    )
    conditioning_start_frames: Optional[List[int]] = field(
        default=None,
        metadata={
            "help": "List of frame indices where each conditioning item should be applied. Must match the number of conditioning items."
        },
    )


def infer(config: InferenceConfig):
    pipeline_config = load_pipeline_config(config.pipeline_config)

    ltxv_model_name_or_path = pipeline_config["checkpoint_path"]
    if not os.path.isfile(ltxv_model_name_or_path):
        ltxv_model_path = hf_hub_download(
            repo_id="Lightricks/LTX-Video",
            filename=ltxv_model_name_or_path,
            repo_type="model",
        )
    else:
        ltxv_model_path = ltxv_model_name_or_path

    spatial_upscaler_model_name_or_path = pipeline_config.get(
        "spatial_upscaler_model_path"
    )
    if spatial_upscaler_model_name_or_path and not os.path.isfile(
        spatial_upscaler_model_name_or_path
    ):
        spatial_upscaler_model_path = hf_hub_download(
            repo_id="Lightricks/LTX-Video",
            filename=spatial_upscaler_model_name_or_path,
            repo_type="model",
        )
    else:
        spatial_upscaler_model_path = spatial_upscaler_model_name_or_path

    conditioning_media_paths = config.conditioning_media_paths
    conditioning_strengths = config.conditioning_strengths
    conditioning_start_frames = config.conditioning_start_frames

    # Validate conditioning arguments
    if conditioning_media_paths:
        # Use default strengths of 1.0
        if not conditioning_strengths:
            conditioning_strengths = [1.0] * len(conditioning_media_paths)
        if not conditioning_start_frames:
            raise ValueError(
                "If `conditioning_media_paths` is provided, "
                "`conditioning_start_frames` must also be provided"
            )
        if len(conditioning_media_paths) != len(conditioning_strengths) or len(
            conditioning_media_paths
        ) != len(conditioning_start_frames):
            raise ValueError(
                "`conditioning_media_paths`, `conditioning_strengths`, "
                "and `conditioning_start_frames` must have the same length"
            )
        if any(s < 0 or s > 1 for s in conditioning_strengths):
            raise ValueError("All conditioning strengths must be between 0 and 1")
        if any(f < 0 or f >= config.num_frames for f in conditioning_start_frames):
            raise ValueError(
                f"All conditioning start frames must be between 0 and {config.num_frames-1}"
            )

    seed_everething(config.seed)
    if config.offload_to_cpu and not torch.cuda.is_available():
        logger.warning(
            "offload_to_cpu is set to True, but offloading will not occur since the model is already running on CPU."
        )
        offload_to_cpu = False
    else:
        offload_to_cpu = config.offload_to_cpu and get_total_gpu_memory() < 30

    output_dir = (
        Path(config.output_path)
        if config.output_path
        else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
    height_padded = ((config.height - 1) // 32 + 1) * 32
    width_padded = ((config.width - 1) // 32 + 1) * 32
    num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1

    padding = calculate_padding(
        config.height, config.width, height_padded, width_padded
    )

    logger.warning(
        f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
    )

    device = get_device()

    prompt_enhancement_words_threshold = pipeline_config[
        "prompt_enhancement_words_threshold"
    ]

    prompt_word_count = len(config.prompt.split())
    enhance_prompt = (
        prompt_enhancement_words_threshold > 0
        and prompt_word_count < prompt_enhancement_words_threshold
    )

    if prompt_enhancement_words_threshold > 0 and not enhance_prompt:
        logger.info(
            f"Prompt has {prompt_word_count} words, which exceeds the threshold of {prompt_enhancement_words_threshold}. Prompt enhancement disabled."
        )

    precision = pipeline_config["precision"]
    text_encoder_model_name_or_path = pipeline_config["text_encoder_model_name_or_path"]
    sampler = pipeline_config.get("sampler", None)
    prompt_enhancer_image_caption_model_name_or_path = pipeline_config[
        "prompt_enhancer_image_caption_model_name_or_path"
    ]
    prompt_enhancer_llm_model_name_or_path = pipeline_config[
        "prompt_enhancer_llm_model_name_or_path"
    ]

    pipeline = create_ltx_video_pipeline(
        ckpt_path=ltxv_model_path,
        precision=precision,
        text_encoder_model_name_or_path=text_encoder_model_name_or_path,
        sampler=sampler,
        device=device,
        enhance_prompt=enhance_prompt,
        prompt_enhancer_image_caption_model_name_or_path=prompt_enhancer_image_caption_model_name_or_path,
        prompt_enhancer_llm_model_name_or_path=prompt_enhancer_llm_model_name_or_path,
    )

    if pipeline_config.get("pipeline_type", None) == "multi-scale":
        if not spatial_upscaler_model_path:
            raise ValueError(
                "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
            )
        latent_upsampler = create_latent_upsampler(
            spatial_upscaler_model_path, pipeline.device
        )
        pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)

    media_item = None
    if config.input_media_path:
        media_item = load_media_file(
            media_path=config.input_media_path,
            height=config.height,
            width=config.width,
            max_frames=num_frames_padded,
            padding=padding,
        )

    conditioning_items = (
        prepare_conditioning(
            conditioning_media_paths=conditioning_media_paths,
            conditioning_strengths=conditioning_strengths,
            conditioning_start_frames=conditioning_start_frames,
            height=config.height,
            width=config.width,
            num_frames=config.num_frames,
            padding=padding,
            pipeline=pipeline,
        )
        if conditioning_media_paths
        else None
    )

    # Remove stg_mode so it is not forwarded to the pipeline via **pipeline_config.
    # pop() with a default avoids a KeyError when the config omits the key.
    stg_mode = pipeline_config.pop("stg_mode", "attention_values")
    stg_mode_lower = stg_mode.lower()
    if stg_mode_lower in ("stg_av", "attention_values"):
        skip_layer_strategy = SkipLayerStrategy.AttentionValues
    elif stg_mode_lower in ("stg_as", "attention_skip"):
        skip_layer_strategy = SkipLayerStrategy.AttentionSkip
    elif stg_mode_lower in ("stg_r", "residual"):
        skip_layer_strategy = SkipLayerStrategy.Residual
    elif stg_mode_lower in ("stg_t", "transformer_block"):
        skip_layer_strategy = SkipLayerStrategy.TransformerBlock
    else:
        raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")

    # Prepare input for the pipeline
    sample = {
        "prompt": config.prompt,
        "prompt_attention_mask": None,
        "negative_prompt": config.negative_prompt,
        "negative_prompt_attention_mask": None,
    }

    generator = torch.Generator(device=device).manual_seed(config.seed)

    images = pipeline(
        **pipeline_config,
        skip_layer_strategy=skip_layer_strategy,
        generator=generator,
        output_type="pt",
        callback_on_step_end=None,
        height=height_padded,
        width=width_padded,
        num_frames=num_frames_padded,
        frame_rate=config.frame_rate,
        **sample,
        media_items=media_item,
        conditioning_items=conditioning_items,
        is_video=True,
        vae_per_channel_normalize=True,
        image_cond_noise_scale=config.image_cond_noise_scale,
        mixed_precision=(precision == "mixed_precision"),
        offload_to_cpu=offload_to_cpu,
        device=device,
        enhance_prompt=enhance_prompt,
    ).images

    # Crop the padded images to the desired resolution and number of frames
    (pad_left, pad_right, pad_top, pad_bottom) = padding
    pad_bottom = -pad_bottom
    pad_right = -pad_right
    if pad_bottom == 0:
        pad_bottom = images.shape[3]
    if pad_right == 0:
        pad_right = images.shape[4]
    images = images[:, :, : config.num_frames, pad_top:pad_bottom, pad_left:pad_right]

    for i in range(images.shape[0]):
        # Select sample i: (B, C, F, H, W) -> (C, F, H, W), then permute to (F, H, W, C)
        video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
        # Scale from [0, 1] to [0, 255]; clip first so out-of-range values do not wrap around in uint8
        video_np = (np.clip(video_np, 0.0, 1.0) * 255).astype(np.uint8)
        fps = config.frame_rate
        height, width = video_np.shape[1:3]
        # In case a single image is generated
        if video_np.shape[0] == 1:
            output_filename = get_unique_filename(
                f"image_output_{i}",
                ".png",
                prompt=config.prompt,
                seed=config.seed,
                resolution=(height, width, config.num_frames),
                dir=output_dir,
            )
            imageio.imwrite(output_filename, video_np[0])
        else:
            output_filename = get_unique_filename(
                f"video_output_{i}",
                ".mp4",
                prompt=config.prompt,
                seed=config.seed,
                resolution=(height, width, config.num_frames),
                dir=output_dir,
            )

            # Write video
            with imageio.get_writer(output_filename, fps=fps) as video:
                for frame in video_np:
                    video.append_data(frame)

        logger.warning(f"Output saved to {output_filename}")


def prepare_conditioning(
    conditioning_media_paths: List[str],
    conditioning_strengths: List[float],
    conditioning_start_frames: List[int],
    height: int,
    width: int,
    num_frames: int,
    padding: tuple[int, int, int, int],
    pipeline: LTXVideoPipeline,
) -> List[ConditioningItem]:
    """Prepare conditioning items based on input media paths and their parameters.

    Args:
        conditioning_media_paths: List of paths to conditioning media (images or videos)
        conditioning_strengths: List of conditioning strengths for each media item
        conditioning_start_frames: List of frame indices where each item should be applied
        height: Height of the output frames
        width: Width of the output frames
        num_frames: Number of frames in the output video
        padding: (left, right, top, bottom) padding applied to each frame
        pipeline: LTXVideoPipeline object whose `trim_conditioning_sequence` method (if present) is used to trim conditioning videos

    Returns:
        A list of ConditioningItem objects.
    """
    conditioning_items = []
    for path, strength, start_frame in zip(
        conditioning_media_paths, conditioning_strengths, conditioning_start_frames
    ):
        num_input_frames = orig_num_input_frames = get_media_num_frames(path)
        if hasattr(pipeline, "trim_conditioning_sequence") and callable(
            getattr(pipeline, "trim_conditioning_sequence")
        ):
            num_input_frames = pipeline.trim_conditioning_sequence(
                start_frame, orig_num_input_frames, num_frames
            )
        if num_input_frames < orig_num_input_frames:
            logger.warning(
                f"Trimming conditioning video {path} from {orig_num_input_frames} to {num_input_frames} frames."
            )

        media_tensor = load_media_file(
            media_path=path,
            height=height,
            width=width,
            max_frames=num_input_frames,
            padding=padding,
            just_crop=True,
        )
        conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength))
    return conditioning_items


def get_media_num_frames(media_path: str) -> int:
    """Return the frame count of a video file (detected by extension), or 1 for an image."""
    is_video = any(
        media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
    )
    num_frames = 1
    if is_video:
        reader = imageio.get_reader(media_path)
        num_frames = reader.count_frames()
        reader.close()
    return num_frames


def load_media_file(
    media_path: str,
    height: int,
    width: int,
    max_frames: int,
    padding: tuple[int, int, int, int],
    just_crop: bool = False,
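) -> torch.Tensor:
    """Load an image or video file and preprocess it into a tensor.

    Each frame is brought to (height, width) by
    `load_image_to_tensor_with_resize_and_crop` (see its `just_crop` flag) and
    then padded with `padding` given as (left, right, top, bottom). At most
    `max_frames` frames are read from a video; video frames are stacked along
    dim 2, the temporal axis.
    """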
    is_video = any(
        media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
    )
    if is_video:
        reader = imageio.get_reader(media_path)
        num_input_frames = min(reader.count_frames(), max_frames)

        # Read and preprocess the relevant frames from the video file.
        frames = []
        for i in range(num_input_frames):
            frame = Image.fromarray(reader.get_data(i))
            frame_tensor = load_image_to_tensor_with_resize_and_crop(
                frame, height, width, just_crop=just_crop
            )
            frame_tensor = torch.nn.functional.pad(frame_tensor, padding)
            frames.append(frame_tensor)
        reader.close()

        # Stack frames along the temporal dimension
        media_tensor = torch.cat(frames, dim=2)
    else:  # Input image
        media_tensor = load_image_to_tensor_with_resize_and_crop(
            media_path, height, width, just_crop=just_crop
        )
        media_tensor = torch.nn.functional.pad(media_tensor, padding)
    return media_tensor

```
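
The dimension handling in `infer` above rounds height and width up to the next multiple of 32, rounds the frame count up to the next value of the form 8N + 1, and later crops the decoded output back to the requested size. The following pure-Python sketch reproduces that arithmetic so it can be checked in isolation. The helper names are hypothetical, and `symmetric_padding` is an assumption standing in for `calculate_padding`, whose body is not part of this section.

```python
from typing import Tuple


def pad_to_multiple(size: int, multiple: int = 32) -> int:
    """Round size up to the next multiple of `multiple` (same formula as infer)."""
    return ((size - 1) // multiple + 1) * multiple


def pad_frames(num_frames: int) -> int:
    """Round num_frames up to the next value of the form 8 * N + 1."""
    return ((num_frames - 2) // 8 + 1) * 8 + 1


def symmetric_padding(
    height: int, width: int, height_padded: int, width_padded: int
) -> Tuple[int, int, int, int]:
    """ASSUMPTION: split the extra rows/columns evenly (calculate_padding may differ).

    Returns (pad_left, pad_right, pad_top, pad_bottom), the order unpacked in infer.
    """
    pad_h = height_padded - height
    pad_w = width_padded - width
    pad_top = pad_h // 2
    pad_left = pad_w // 2
    return (pad_left, pad_w - pad_left, pad_top, pad_h - pad_top)


def crop_bounds(
    padding: Tuple[int, int, int, int], padded_height: int, padded_width: int
) -> Tuple[int, int, int, int]:
    """Mirror the crop in infer and return (top, bottom, left, right) slice bounds."""
    pad_left, pad_right, pad_top, pad_bottom = padding
    # infer uses negative indices (or the full size when the pad is 0);
    # subtracting from the padded size gives the same bounds.
    return pad_top, padded_height - pad_bottom, pad_left, padded_width - pad_right


if __name__ == "__main__":
    h, w, f = 720, 1280, 100
    hp, wp, fp = pad_to_multiple(h), pad_to_multiple(w), pad_frames(f)
    print(hp, wp, fp)  # 736 1280 105
    pad = symmetric_padding(h, w, hp, wp)
    print(crop_bounds(pad, hp, wp))  # (8, 728, 0, 1280)
```

With the defaults in `InferenceConfig` (704x1216, 121 frames) no padding is needed, because all three values already satisfy the constraints.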

## /ltx_video/models/__init__.py

```py path="/ltx_video/models/__init__.py" 

```

## /ltx_video/models/autoencoders/__init__.py

```py path="/ltx_video/models/autoencoders/__init__.py" 

```

## /ltx_video/models/autoencoders/causal_conv3d.py

```py path="/ltx_video/models/autoencoders/causal_conv3d.py" 
from typing import Tuple, Union

import torch
import torch.nn as nn


class CausalConv3d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        stride: Union[int, Tuple[int]] = 1,
        dilation: int = 1,
        groups: int = 1,
        spatial_padding_mode: str = "zeros",
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        kernel_size = (kernel_size, kernel_size, kernel_size)
        self.time_kernel_size = kernel_size[0]

        dilation = (dilation, 1, 1)

        height_pad = kernel_size[1] // 2
        width_pad = kernel_size[2] // 2
        padding = (0, height_pad, width_pad)

        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            padding_mode=spatial_padding_mode,
            groups=groups,
        )

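    def forward(self, x, causal: bool = True):
        # x has shape (B, C, T, H, W). self.conv pads only H and W, so the time
        # axis is padded here by replicating boundary frames.
        if causal:
            # Prepend (time_kernel_size - 1) copies of the first frame, so each
            # output frame depends only on the current and earlier input frames.
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, self.time_kernel_size - 1, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x), dim=2)
        else:
            # Replicate the first and last frames (time_kernel_size - 1) // 2 times
            # each, centering the temporal kernel.
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            last_frame_pad = x[:, :, -1:, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
        x = self.conv(x)
        return x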

    @property
    def weight(self):
        return self.conv.weight

```
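
`CausalConv3d.forward` pads the time axis itself by replicating frames, while `nn.Conv3d` pads only height and width. The sketch below models this on a single-channel sequence of scalars, with a plain valid correlation standing in for `nn.Conv3d` along time (a simplification, not the module itself; the function names are hypothetical). It shows that, for stride 1, both padding modes keep the sequence length and that causal padding makes each output ignore later frames.

```python
from typing import List, Sequence


def temporal_conv1d(frames: Sequence[float], kernel: Sequence[float]) -> List[float]:
    """Valid (unpadded) 1D correlation along time, like Conv3d with zero temporal padding."""
    k = len(kernel)
    return [
        sum(w * frames[t + i] for i, w in enumerate(kernel))
        for t in range(len(frames) - k + 1)
    ]


def causal_pad(frames: Sequence[float], kernel_size: int) -> List[float]:
    """Prepend kernel_size - 1 copies of the first frame (the causal branch)."""
    return [frames[0]] * (kernel_size - 1) + list(frames)


def symmetric_pad(frames: Sequence[float], kernel_size: int) -> List[float]:
    """Replicate first/last frames (kernel_size - 1) // 2 times (the non-causal branch)."""
    n = (kernel_size - 1) // 2
    return [frames[0]] * n + list(frames) + [frames[-1]] * n


if __name__ == "__main__":
    x = [1.0, 2.0, 3.0, 4.0]
    newest = [0.0, 0.0, 1.0]  # weights that pick the newest frame in each window
    print(temporal_conv1d(causal_pad(x, 3), newest))  # [1.0, 2.0, 3.0, 4.0]
```

Symmetric replication preserves the length only for odd kernel sizes: with an even `time_kernel_size` k, `(k - 1) // 2` frames on each side add k - 2 frames in total, so the output would be one frame shorter than the input.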

## /ltx_video/models/autoencoders/causal_video_autoencoder.py

```py path="/ltx_video/models/autoencoders/causal_video_autoencoder.py" 
import json
import os
from functools import partial
from types import SimpleNamespace
from typing import Any, Mapping, Optional, Tuple, Union, List
from pathlib import Path

import torch
import numpy as np
from einops import rearrange
from torch import nn
from diffusers.utils import logging
import torch.nn.functional as F
from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
from safetensors import safe_open


from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
from ltx_video.models.autoencoders.pixel_norm import PixelNorm
from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND
from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper
from ltx_video.models.transformers.attention import Attention
from ltx_video.utils.diffusers_config_mapping import (
    diffusers_and_ours_config_mapping,
    make_hashable_key,
    VAE_KEYS_RENAME_DICT,
)

PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics."
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class CausalVideoAutoencoder(AutoencoderKLWrapper):
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        if (
            pretrained_model_name_or_path.is_dir()
            and (pretrained_model_name_or_path / "autoencoder.pth").exists()
        ):
            config_local_path = pretrained_model_name_or_path / "config.json"
            config = cls.load_config(config_local_path, **kwargs)

            model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
            state_dict = torch.load(model_local_path, map_location=torch.device("cpu"))

            statistics_local_path = (
                pretrained_model_name_or_path / "per_channel_statistics.json"
            )
            if statistics_local_path.exists():
                with open(statistics_local_path, "r") as file:
                    data = json.load(file)
                transposed_data = list(zip(*data["data"]))
                data_dict = {
                    col: torch.tensor(vals)
                    for col, vals in zip(data["columns"], transposed_data)
                }
                std_of_means = data_dict["std-of-means"]
                mean_of_means = data_dict.get(
                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
                )
                state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = (
                    std_of_means
                )
                state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = (
                    mean_of_means
                )

        elif pretrained_model_name_or_path.is_dir():
            config_path = pretrained_model_name_or_path / "vae" / "config.json"
            with open(config_path, "r") as f:
                config = make_hashable_key(json.load(f))

            assert config in diffusers_and_ours_config_mapping, (
                "Provided diffusers checkpoint config for VAE is not supported. "
                "We only support diffusers configs found in Lightricks/LTX-Video."
            )

            config = diffusers_and_ours_config_mapping[config]

            state_dict_path = (
                pretrained_model_name_or_path
                / "vae"
                / "diffusion_pytorch_model.safetensors"
            )

            state_dict = {}
            with safe_open(state_dict_path, framework="pt", device="cpu") as f:
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
            for key in list(state_dict.keys()):
                new_key = key
                for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
                    new_key = new_key.replace(replace_key, rename_key)

                state_dict[new_key] = state_dict.pop(key)

        elif pretrained_model_name_or_path.is_file() and str(
            pretrained_model_name_or_path
        ).endswith(".safetensors"):
            state_dict = {}
            with safe_open(
                pretrained_model_name_or_path, framework="pt", device="cpu"
            ) as f:
                metadata = f.metadata()
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
            configs = json.loads(metadata["config"])
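            config = configs["vae"]
        else:
            raise ValueError(
                f"Unsupported VAE checkpoint path: {pretrained_model_name_or_path}. "
                "Expected a directory or a .safetensors file."
            )

        video_vae = cls.from_config(config)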
        if "torch_dtype" in kwargs:
            video_vae.to(kwargs["torch_dtype"])
        video_vae.load_state_dict(state_dict)
        return video_vae

    @staticmethod
    def from_config(config):
        assert (
            config["_class_name"] == "CausalVideoAutoencoder"
        ), "config must have _class_name=CausalVideoAutoencoder"
        if isinstance(config["dims"], list):
            config["dims"] = tuple(config["dims"])

        assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"

        double_z = config.get("double_z", True)
        latent_log_var = config.get(
            "latent_log_var", "per_channel" if double_z else "none"
        )
        use_quant_conv = config.get("use_quant_conv", True)
        normalize_latent_channels = config.get("normalize_latent_channels", False)

        if use_quant_conv and latent_log_var in ["uniform", "constant"]:
            raise ValueError(
                f"latent_log_var={latent_log_var} requires use_quant_conv=False"
            )

        encoder = Encoder(
            dims=config["dims"],
            in_channels=config.get("in_channels", 3),
            out_channels=config["latent_channels"],
            blocks=config.get("encoder_blocks", config.get("blocks")),
            patch_size=config.get("patch_size", 1),
            latent_log_var=latent_log_var,
            norm_layer=config.get("norm_layer", "group_norm"),
            base_channels=config.get("encoder_base_channels", 128),
            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
        )

        decoder = Decoder(
            dims=config["dims"],
            in_channels=config["latent_channels"],
            out_channels=config.get("out_channels", 3),
            blocks=config.get("decoder_blocks", config.get("blocks")),
            patch_size=config.get("patch_size", 1),
            norm_layer=config.get("norm_layer", "group_norm"),
            causal=config.get("causal_decoder", False),
            timestep_conditioning=config.get("timestep_conditioning", False),
            base_channels=config.get("decoder_base_channels", 128),
            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
        )

        dims = config["dims"]
        return CausalVideoAutoencoder(
            encoder=encoder,
            decoder=decoder,
            latent_channels=config["latent_channels"],
            dims=dims,
            use_quant_conv=use_quant_conv,
            normalize_latent_channels=normalize_latent_channels,
        )

    @property
    def config(self):
        return SimpleNamespace(
            _class_name="CausalVideoAutoencoder",
            dims=self.dims,
            in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2,
            out_channels=self.decoder.conv_out.out_channels
            // self.decoder.patch_size**2,
            latent_channels=self.decoder.conv_in.in_channels,
            encoder_blocks=self.encoder.blocks_desc,
            decoder_blocks=self.decoder.blocks_desc,
            scaling_factor=1.0,
            norm_layer=self.encoder.norm_layer,
            patch_size=self.encoder.patch_size,
            latent_log_var=self.encoder.latent_log_var,
            use_quant_conv=self.use_quant_conv,
            causal_decoder=self.decoder.causal,
            timestep_conditioning=self.decoder.timestep_conditioning,
            normalize_latent_channels=self.normalize_latent_channels,
        )

    @property
    def is_video_supported(self):
        """
        Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
        """
        return self.dims != 2

    @property
    def spatial_downscale_factor(self):
        return (
            2
            ** len(
                [
                    block
                    for block in self.encoder.blocks_desc
                    if block[0]
                    in [
                        "compress_space",
                        "compress_all",
                        "compress_all_x_y",
                        "compress_all_res",
                        "compress_space_res",
                    ]
                ]
            )
            * self.encoder.patch_size
        )

    @property
    def temporal_downscale_factor(self):
        return 2 ** len(
            [
                block
                for block in self.encoder.blocks_desc
                if block[0]
                in [
                    "compress_time",
                    "compress_all",
                    "compress_all_x_y",
                    "compress_all_res",
                    "compress_time_res",
                ]
            ]
        )
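
    # Worked example (hypothetical block list, not taken from a shipped config):
    # with encoder blocks_desc = [("res_x", 4), ("compress_all", 1),
    # ("compress_all", 1), ("compress_all", 1)] and patch_size = 4,
    # spatial_downscale_factor = 2**3 * 4 = 32 and temporal_downscale_factor = 2**3 = 8.
    # Each block counted by these properties halves the axes it compresses, so both
    # factors are powers of two; patch_size contributes only to the spatial factor.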

    def to_json_string(self) -> str:
        return json.dumps(self.config.__dict__)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        if any(key.startswith("vae.") for key in state_dict.keys()):
            # Strip only the leading "vae." prefix, not every occurrence in the key.
            state_dict = {
                key.removeprefix("vae."): value
                for key, value in state_dict.items()
                if key.startswith("vae.")
            }
        ckpt_state_dict = {
            key: value
            for key, value in state_dict.items()
            if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
        }

        model_keys = set(name for name, _ in self.named_modules())

        key_mapping = {
            ".resnets.": ".res_blocks.",
            "downsamplers.0": "downsample",
            "upsamplers.0": "upsample",
        }
        converted_state_dict = {}
        for key, value in ckpt_state_dict.items():
            for k, v in key_mapping.items():
                key = key.replace(k, v)

            key_prefix = ".".join(key.split(".")[:-1])
            if "norm" in key and key_prefix not in model_keys:
                logger.info(
                    f"Removing key {key} from state_dict as it is not present in the model"
                )
                continue

            converted_state_dict[key] = value

        super().load_state_dict(converted_state_dict, strict=strict)

        data_dict = {
            key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value
            for key, value in state_dict.items()
            if key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
        }
        if len(data_dict) > 0:
            self.register_buffer("std_of_means", data_dict["std-of-means"])
            self.register_buffer(
                "mean_of_means",
                data_dict.get(
                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
                ),
            )

    def last_layer(self):
        if hasattr(self.decoder, "conv_out"):
            if isinstance(self.decoder.conv_out, nn.Sequential):
                last_layer = self.decoder.conv_out[-1]
            else:
                last_layer = self.decoder.conv_out
        else:
            last_layer = self.decoder.layers[-1]
        return last_layer

    def set_use_tpu_flash_attention(self):
        for block in self.decoder.up_blocks:
            if isinstance(block, UNetMidBlock3D) and block.attention_blocks:
                for attention_block in block.attention_blocks:
                    attention_block.set_use_tpu_flash_attention()


class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
            The number of dimensions to use in convolutions.
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output (latent) channels.
        blocks (`List[Tuple[str, Union[int, dict]]]`, *optional*, defaults to `[("res_x", 1)]`):
            The blocks to use. Each block is a tuple of the block name and either the number of layers or a dict of
            block parameters (e.g. `num_layers`, `multiplier`).
        base_channels (`int`, *optional*, defaults to 128):
            The number of output channels for the first convolutional layer.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
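        latent_log_var (`str`, *optional*, defaults to `per_channel`):
            How the latent log variance is represented. Can be either `per_channel`, `uniform`, `constant` or `none`.
        spatial_padding_mode (`str`, *optional*, defaults to `zeros`):
            The padding mode used for the spatial dimensions of the convolutions.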
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]] = 3,
        in_channels: int = 3,
        out_channels: int = 3,
        blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
        base_channels: int = 128,
        norm_num_groups: int = 32,
        patch_size: Union[int, Tuple[int]] = 1,
        norm_layer: str = "group_norm",  # group_norm, pixel_norm
        latent_log_var: str = "per_channel",
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        self.patch_size = patch_size
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
        self.latent_log_var = latent_log_var
        self.blocks_desc = blocks

        in_channels = in_channels * patch_size**2
        output_channel = base_channels

        self.conv_in = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.down_blocks = nn.ModuleList([])

        for block_name, block_params in blocks:
            input_channel = output_channel
            if isinstance(block_params, int):
                block_params = {"num_layers": block_params}

            if block_name == "res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_eps=1e-6,
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "res_x_y":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    eps=1e-6,
                    groups=norm_num_groups,
                    norm_layer=norm_layer,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_time":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 1, 1),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_space":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(1, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all_x_y":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(2, 2, 2),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_space_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(1, 2, 2),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_time_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(2, 1, 1),
                    spatial_padding_mode=spatial_padding_mode,
                )
            else:
                raise ValueError(f"unknown block: {block_name}")

            self.down_blocks.append(block)

        # Output head: final normalization, activation and convolution
        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()
        elif norm_layer == "layer_norm":
            self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
        else:
            raise ValueError(f"Invalid norm_layer: {norm_layer}")

        self.conv_act = nn.SiLU()

        conv_out_channels = out_channels
        if latent_log_var == "per_channel":
            conv_out_channels *= 2
        elif latent_log_var == "uniform":
            conv_out_channels += 1
        elif latent_log_var == "constant":
            conv_out_channels += 1
        elif latent_log_var != "none":
            raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
        self.conv_out = make_conv_nd(
            dims,
            output_channel,
            conv_out_channels,
            3,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""

        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
        sample = self.conv_in(sample)

        checkpoint_fn = (
            partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
            if self.gradient_checkpointing and self.training
            else lambda x: x
        )

        for down_block in self.down_blocks:
            sample = checkpoint_fn(down_block)(sample)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if self.latent_log_var == "uniform":
            # conv_out predicted `out_channels` means plus one log-variance channel
            # shared by all latent channels. Repeat that last channel so the output
            # holds the means followed by `out_channels` identical log-variances.
            last_channel = sample[:, -1:, ...]
            num_dims = sample.dim()
            if num_dims not in (4, 5):
                raise ValueError(f"Invalid input shape: {sample.shape}")
            # Repeat along the channel axis only, for (B, C, H, W) or (B, C, F, H, W)
            repeat_sizes = (1, sample.shape[1] - 2) + (1,) * (num_dims - 2)
            repeated_last_channel = last_channel.repeat(*repeat_sizes)
            sample = torch.cat([sample, repeated_last_channel], dim=1)
        elif self.latent_log_var == "constant":
            # Drop the extra predicted channel and use a fixed log-variance instead.
            # -30 is the minimal clamp value in DiagonalGaussianDistribution
            # objects, i.e. approximately ln(0).
            sample = sample[:, :-1, ...]
            approx_ln_0 = -30
            sample = torch.cat(
                [sample, torch.full_like(sample, approx_ln_0)],
                dim=1,
            )

        return sample
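

# Illustrative sketch (hypothetical helper, not used by the model): a pure-Python
# view of the channel layout that `Encoder.forward` returns for each
# `latent_log_var` mode, with C = `out_channels` latent channels. It assumes the
# result is later split in half along the channel axis into means and
# log-variances by a DiagonalGaussianDistribution-style consumer; that split is
# not part of the code above.
def sketch_latent_log_var_layout(latent_channels, latent_log_var):
    """Return one label per output channel of `Encoder.forward`."""
    means = [f"mean{i}" for i in range(latent_channels)]
    if latent_log_var == "none":
        # conv_out emits only the C latent channels; no log-variance is predicted
        return means
    if latent_log_var == "per_channel":
        # conv_out emits 2 * C channels: one log-variance per latent channel
        return means + [f"logvar{i}" for i in range(latent_channels)]
    if latent_log_var == "uniform":
        # conv_out emits C + 1 channels; the single shared log-variance channel is
        # repeated C - 1 more times, giving 2 * C channels in total
        return means + ["logvar"] * latent_channels
    if latent_log_var == "constant":
        # the extra conv_out channel is dropped and replaced by C channels filled
        # with -30 (approximately ln(0))
        return means + ["-30"] * latent_channels
    raise ValueError(f"Invalid latent_log_var: {latent_log_var}")


# Example: with 2 latent channels, "uniform" yields
# ["mean0", "mean1", "logvar", "logvar"].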


class Decoder(nn.Module):
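    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        dims (`int` or `Tuple[int, int]`):
            The number of dimensions to use in convolutions.
        in_channels (`int`, *optional*, defaults to 3):
            The number of input (latent) channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels after unpatchifying.
        blocks (`List[Tuple[str, int | dict]]`, *optional*, defaults to `[("res_x", 1)]`):
            The blocks to use, listed in encoder order and applied in reverse. Each block is a tuple of the block
            name and either the number of layers or a dict of block parameters.
        base_channels (`int`, *optional*, defaults to 128):
            The number of channels at the output end of the decoder, before the final convolution. The first
            convolution produces `base_channels` times the product of the channel multipliers in `blocks`.
        layers_per_block (`int`, *optional*, defaults to 2):
            Stored as `self.layers_per_block`; the per-block layer counts are taken from `blocks`.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for group normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The spatial patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be `group_norm`, `pixel_norm` or `layer_norm`.
        causal (`bool`, *optional*, defaults to `True`):
            Whether to run the convolutions causally in time.
        timestep_conditioning (`bool`, *optional*, defaults to `False`):
            Whether to condition the decoder on a timestep, which must then be passed to `forward`.
        spatial_padding_mode (`str`, *optional*, defaults to `"zeros"`):
            The padding mode used for spatial padding in the convolutions.
    """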

    def __init__(
        self,
        dims,
        in_channels: int = 3,
        out_channels: int = 3,
        blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
        base_channels: int = 128,
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        patch_size: int = 1,
        norm_layer: str = "group_norm",
        causal: bool = True,
        timestep_conditioning: bool = False,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        self.patch_size = patch_size
        self.layers_per_block = layers_per_block
        out_channels = out_channels * patch_size**2
        self.causal = causal
        self.blocks_desc = blocks

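        # The decoder mirrors the encoder: start from `base_channels` multiplied by
        # every channel multiplier in `blocks`, so that the up blocks below can
        # divide the channel count back down to `base_channels`.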
        output_channel = base_channels
        for block_name, block_params in list(reversed(blocks)):
            block_params = block_params if isinstance(block_params, dict) else {}
            if block_name == "res_x_y":
                output_channel = output_channel * block_params.get("multiplier", 2)
            if block_name.startswith("compress"):
                output_channel = output_channel * block_params.get("multiplier", 1)

        self.conv_in = make_conv_nd(
            dims,
            in_channels,
            output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.up_blocks = nn.ModuleList([])

        for block_name, block_params in list(reversed(blocks)):
            input_channel = output_channel
            if isinstance(block_params, int):
                block_params = {"num_layers": block_params}

            if block_name == "res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_eps=1e-6,
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                    inject_noise=block_params.get("inject_noise", False),
                    timestep_conditioning=timestep_conditioning,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "attn_res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                    inject_noise=block_params.get("inject_noise", False),
                    timestep_conditioning=timestep_conditioning,
                    attention_head_dim=block_params["attention_head_dim"],
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "res_x_y":
                output_channel = output_channel // block_params.get("multiplier", 2)
                block = ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    eps=1e-6,
                    groups=norm_num_groups,
                    norm_layer=norm_layer,
                    inject_noise=block_params.get("inject_noise", False),
                    timestep_conditioning=False,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_time":
                block = DepthToSpaceUpsample(
                    dims=dims,
                    in_channels=input_channel,
                    stride=(2, 1, 1),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_space":
                block = DepthToSpaceUpsample(
                    dims=dims,
                    in_channels=input_channel,
                    stride=(1, 2, 2),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all":
                output_channel = output_channel // block_params.get("multiplier", 1)
                block = DepthToSpaceUpsample(
                    dims=dims,
                    in_channels=input_channel,
                    stride=(2, 2, 2),
                    residual=block_params.get("residual", False),
                    out_channels_reduction_factor=block_params.get("multiplier", 1),
                    spatial_padding_mode=spatial_padding_mode,
                )
            else:
                raise ValueError(f"unknown block: {block_name}")

            self.up_blocks.append(block)

        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()
        elif norm_layer == "layer_norm":
            self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
        else:
            raise ValueError(f"Invalid norm_layer: {norm_layer}")

        self.conv_act = nn.SiLU()
        self.conv_out = make_conv_nd(
            dims,
            output_channel,
            out_channels,
            3,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.gradient_checkpointing = False

        self.timestep_conditioning = timestep_conditioning

        if timestep_conditioning:
            self.timestep_scale_multiplier = nn.Parameter(
                torch.tensor(1000.0, dtype=torch.float32)
            )
            self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
                output_channel * 2, 0
            )
            self.last_scale_shift_table = nn.Parameter(
                torch.randn(2, output_channel) / output_channel**0.5
            )

    def forward(
        self,
        sample: torch.FloatTensor,
        target_shape,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
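        r"""The forward method of the `Decoder` class.

        Args:
            sample (`torch.FloatTensor`): The latent tensor to decode.
            target_shape: The expected output shape. It must be provided, but is only checked for presence here.
            timestep (`torch.Tensor`, *optional*): One timestep value per batch element; required when
                `timestep_conditioning=True`.
        """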
        assert target_shape is not None, "target_shape must be provided"
        batch_size = sample.shape[0]

        sample = self.conv_in(sample, causal=self.causal)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        checkpoint_fn = (
            partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
            if self.gradient_checkpointing and self.training
            else lambda x: x
        )

        sample = sample.to(upscale_dtype)

        if self.timestep_conditioning:
            assert (
                timestep is not None
            ), "should pass timestep with timestep_conditioning=True"
            scaled_timestep = timestep * self.timestep_scale_multiplier

        for up_block in self.up_blocks:
            if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
                sample = checkpoint_fn(up_block)(
                    sample, causal=self.causal, timestep=scaled_timestep
                )
            else:
                sample = checkpoint_fn(up_block)(sample, causal=self.causal)

        sample = self.conv_norm_out(sample)

        if self.timestep_conditioning:
            embedded_timestep = self.last_time_embedder(
                timestep=scaled_timestep.flatten(),
                resolution=None,
                aspect_ratio=None,
                batch_size=sample.shape[0],
                hidden_dtype=sample.dtype,
            )
            embedded_timestep = embedded_timestep.view(
                batch_size, embedded_timestep.shape[-1], 1, 1, 1
            )
            ada_values = self.last_scale_shift_table[
                None, ..., None, None, None
            ] + embedded_timestep.reshape(
                batch_size,
                2,
                -1,
                embedded_timestep.shape[-3],
                embedded_timestep.shape[-2],
                embedded_timestep.shape[-1],
            )
            shift, scale = ada_values.unbind(dim=1)
            sample = sample * (1 + scale) + shift

        sample = self.conv_act(sample)
        sample = self.conv_out(sample, causal=self.causal)

        sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)

        return sample
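

# Illustrative sketch (hypothetical helper, not used by the model): replays the
# channel bookkeeping of `Decoder.__init__` without building any layers. The
# decoder starts from `base_channels` times every channel multiplier in `blocks`
# (default 2 for "res_x_y", 1 for "compress*" blocks), then walks the blocks in
# reverse. Only "res_x_y" and "compress_all" divide the count back down, so when
# only those blocks carry multipliers the walk ends at `base_channels`.
def sketch_decoder_channel_plan(blocks, base_channels=128):
    """Return [conv_in channels, channels after each up block in decoding order]."""
    channels = base_channels
    for block_name, block_params in reversed(blocks):
        params = block_params if isinstance(block_params, dict) else {}
        if block_name == "res_x_y":
            channels *= params.get("multiplier", 2)
        if block_name.startswith("compress"):
            channels *= params.get("multiplier", 1)

    plan = [channels]
    for block_name, block_params in reversed(blocks):
        params = block_params if isinstance(block_params, dict) else {}
        if block_name == "res_x_y":
            channels //= params.get("multiplier", 2)
        elif block_name == "compress_all":
            channels //= params.get("multiplier", 1)
        elif block_name not in ("res_x", "attn_res_x", "compress_time", "compress_space"):
            raise ValueError(f"unknown block: {block_name}")
        plan.append(channels)
    return plan


# Example: the demo decoder blocks (three "compress_all" blocks with multiplier 2)
# start at 128 * 2**3 = 1024 channels and end at 128.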


class UNetMidBlock3D(nn.Module):
    """
    A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks, optionally interleaved with self-attention.

    Args:
        dims (`int` or `Tuple[int, int]`): The number of dimensions to use in convolutions.
        in_channels (`int`): The number of input channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks. If `None`,
            `min(in_channels // 4, 32)` is used.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be `group_norm`, `pixel_norm` or `layer_norm`.
        inject_noise (`bool`, *optional*, defaults to `False`):
            Whether to inject noise into the hidden states.
        timestep_conditioning (`bool`, *optional*, defaults to `False`):
            Whether to condition the hidden states on the timestep.
        attention_head_dim (`int`, *optional*, defaults to -1):
            The dimension of each attention head. If it is not positive, no attention is used; otherwise an
            attention block follows each residual block.
        spatial_padding_mode (`str`, *optional*, defaults to `"zeros"`):
            The padding mode used for spatial padding in the convolutions.

    Returns:
        `torch.FloatTensor`: The output of the last block, with the same shape as the input, e.g.
        `(batch_size, in_channels, frames, height, width)`.

    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: str = "group_norm",
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        attention_head_dim: int = -1,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        resnet_groups = (
            resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        )
        self.timestep_conditioning = timestep_conditioning

        if timestep_conditioning:
            self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
                in_channels * 4, 0
            )

        self.res_blocks = nn.ModuleList(
            [
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                    inject_noise=inject_noise,
                    timestep_conditioning=timestep_conditioning,
                    spatial_padding_mode=spatial_padding_mode,
                )
                for _ in range(num_layers)
            ]
        )

        self.attention_blocks = None

        if attention_head_dim > 0:
            if attention_head_dim > in_channels:
                raise ValueError(
                    "attention_head_dim must be less than or equal to in_channels"
                )

            self.attention_blocks = nn.ModuleList(
                [
                    Attention(
                        query_dim=in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        bias=True,
                        out_bias=True,
                        qk_norm="rms_norm",
                        residual_connection=True,
                    )
                    for _ in range(num_layers)
                ]
            )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        causal: bool = True,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        timestep_embed = None
        if self.timestep_conditioning:
            assert (
                timestep is not None
            ), "should pass timestep with timestep_conditioning=True"
            batch_size = hidden_states.shape[0]
            timestep_embed = self.time_embedder(
                timestep=timestep.flatten(),
                resolution=None,
                aspect_ratio=None,
                batch_size=batch_size,
                hidden_dtype=hidden_states.dtype,
            )
            timestep_embed = timestep_embed.view(
                batch_size, timestep_embed.shape[-1], 1, 1, 1
            )

        if self.attention_blocks:
            for resnet, attention in zip(self.res_blocks, self.attention_blocks):
                hidden_states = resnet(
                    hidden_states, causal=causal, timestep=timestep_embed
                )

                # Reshape the hidden states to be (batch_size, frames * height * width, channel)
                batch_size, channel, frames, height, width = hidden_states.shape
                hidden_states = hidden_states.view(
                    batch_size, channel, frames * height * width
                ).transpose(1, 2)

                if attention.use_tpu_flash_attention:
                    # Pad the second dimension to be divisible by block_k_major (block in flash attention)
                    seq_len = hidden_states.shape[1]
                    block_k_major = 512
                    pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
                    if pad_len > 0:
                        hidden_states = F.pad(
                            hidden_states, (0, 0, 0, pad_len), "constant", 0
                        )

                    # Create a mask with ones for the original sequence length and zeros for the padded indexes
                    mask = torch.ones(
                        (hidden_states.shape[0], seq_len),
                        device=hidden_states.device,
                        dtype=hidden_states.dtype,
                    )
                    if pad_len > 0:
                        mask = F.pad(mask, (0, pad_len), "constant", 0)

                hidden_states = attention(
                    hidden_states,
                    attention_mask=(
                        None if not attention.use_tpu_flash_attention else mask
                    ),
                )

                if attention.use_tpu_flash_attention:
                    # Remove the padding
                    if pad_len > 0:
                        hidden_states = hidden_states[:, :-pad_len, :]

                # Reshape the hidden states back to (batch_size, channel, frames, height, width)
                hidden_states = hidden_states.transpose(-1, -2).reshape(
                    batch_size, channel, frames, height, width
                )
        else:
            for resnet in self.res_blocks:
                hidden_states = resnet(
                    hidden_states, causal=causal, timestep=timestep_embed
                )

        return hidden_states
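

# Illustrative sketch (hypothetical helper, not used by the model): the padding
# arithmetic in `UNetMidBlock3D.forward` when `attention.use_tpu_flash_attention`
# is set. The flattened sequence of frames * height * width tokens is right-padded
# to a multiple of `block_k_major` (512 above), and the mask marks real tokens
# with 1 and padded tokens with 0.
def sketch_flash_attention_padding(seq_len, block_k_major=512):
    """Return (pad_len, mask) for a sequence of `seq_len` tokens."""
    # The outer modulo makes pad_len 0 (not block_k_major) when seq_len is
    # already a multiple of block_k_major.
    pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
    mask = [1] * seq_len + [0] * pad_len
    return pad_len, mask


# Example: 9 latent frames at 16 x 24 give 3456 tokens, padded by 128 to 3584.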


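class SpaceToDepthDownsample(nn.Module):
    """Downsample by `stride` using space-to-depth with a residual shortcut.

    The main path applies a causal convolution (kernel size 3) to `out_channels // prod(stride)` channels and then
    folds each `stride` window into the channel axis. The shortcut folds the input the same way and averages groups
    of `group_size` channels, so both paths have `out_channels` channels. With a temporal stride of 2, the first
    frame is duplicated so that `1 + 2k` input frames become `1 + k` output frames.
    """

    def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):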
        super().__init__()
        self.stride = stride
        self.group_size = in_channels * np.prod(stride) // out_channels
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=out_channels // np.prod(stride),
            kernel_size=3,
            stride=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

    def forward(self, x, causal: bool = True):
        if self.stride[0] == 2:
            # Duplicate the first frame so that an odd (1 + 2k) frame count becomes
            # divisible by the temporal stride of 2
            x = torch.cat([x[:, :, :1, :, :], x], dim=2)

        # skip connection
        x_in = rearrange(
            x,
            "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
            p1=self.stride[0],
            p2=self.stride[1],
            p3=self.stride[2],
        )
        x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size)
        x_in = x_in.mean(dim=2)

        # conv
        x = self.conv(x, causal=causal)
        x = rearrange(
            x,
            "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
            p1=self.stride[0],
            p2=self.stride[1],
            p3=self.stride[2],
        )

        x = x + x_in

        return x
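

# Illustrative sketch (hypothetical helper, not used by the model): shape
# bookkeeping for `SpaceToDepthDownsample` on a (C, F, H, W) input (batch
# omitted). The main path convolves to out_channels / prod(stride) channels and
# folds each stride window into channels; the shortcut folds the input the same
# way and averages groups of `group_size` channels. With a temporal stride of 2
# the first frame is duplicated first, so 1 + 2k frames become 1 + k.
def sketch_space_to_depth_shapes(in_channels, out_channels, frames, height, width, stride):
    st, sh, sw = stride
    factor = st * sh * sw
    if st == 2:
        frames += 1  # first frame duplicated, as in SpaceToDepthDownsample.forward
    if frames % st or height % sh or width % sw:
        raise ValueError("input size must be divisible by the stride")
    if out_channels % factor or (in_channels * factor) % out_channels:
        raise ValueError("channel counts are incompatible with the stride")
    return {
        "conv_channels": out_channels // factor,
        "group_size": in_channels * factor // out_channels,
        "out_shape": (out_channels, frames // st, height // sh, width // sw),
    }


# Example: 128 -> 256 channels with stride (2, 2, 2) maps 9 x 64 x 64 to 5 x 32 x 32,
# with a 32-channel conv and a shortcut averaging groups of 4 channels.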


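class DepthToSpaceUpsample(nn.Module):
    """Upsample by `stride` using a causal convolution followed by depth-to-space (pixel shuffle).

    The convolution produces `prod(stride) * in_channels // out_channels_reduction_factor` channels, which the
    pixel shuffle trades for a larger frame/height/width extent. With a temporal stride of 2, the first output
    frame is dropped, so `F` input frames become `2F - 1`. If `residual` is set, the pixel-shuffled input,
    repeated along the channel axis to match, is added to the output.
    """
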
    def __init__(
        self,
        dims,
        in_channels,
        stride,
        residual=False,
        out_channels_reduction_factor=1,
        spatial_padding_mode="zeros",
    ):
        super().__init__()
        self.stride = stride
        self.out_channels = (
            np.prod(stride) * in_channels // out_channels_reduction_factor
        )
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
        self.pixel_shuffle = PixelShuffleND(dims=dims, upscale_factors=stride)
        self.residual = residual
        self.out_channels_reduction_factor = out_channels_reduction_factor

    def forward(self, x, causal: bool = True):
        if self.residual:
            # Reshape and duplicate the input to match the output shape
            x_in = self.pixel_shuffle(x)
            num_repeat = np.prod(self.stride) // self.out_channels_reduction_factor
            x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
            if self.stride[0] == 2:
                x_in = x_in[:, :, 1:, :, :]
        x = self.conv(x, causal=causal)
        x = self.pixel_shuffle(x)
        if self.stride[0] == 2:
            x = x[:, :, 1:, :, :]
        if self.residual:
            x = x + x_in
        return x
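

# Illustrative sketch (hypothetical helper, not used by the model): shape
# bookkeeping for `DepthToSpaceUpsample` on a (C, F, H, W) input (batch omitted),
# with r = out_channels_reduction_factor. The conv emits prod(stride) * C / r
# channels, the pixel shuffle turns them into C / r channels over a larger
# (F, H, W), and with a temporal stride of 2 the first frame is dropped
# (F -> 2F - 1). The residual path shuffles the input to C / prod(stride)
# channels and repeats it prod(stride) / r times to match.
def sketch_depth_to_space_shapes(in_channels, frames, height, width, stride, reduction=1):
    st, sh, sw = stride
    factor = st * sh * sw
    conv_channels = factor * in_channels // reduction
    out_frames = frames * st - (1 if st == 2 else 0)
    return {
        "conv_channels": conv_channels,
        "out_shape": (conv_channels // factor, out_frames, height * sh, width * sw),
        "residual_channels": (in_channels // factor) * (factor // reduction),
    }


# Example: 1024 channels at 5 x 32 x 32 with stride (2, 2, 2) and r = 2 become
# 512 channels at 9 x 64 x 64, undoing the 9 -> 5 frame reduction of the encoder.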


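class LayerNorm(nn.Module):
    """Layer normalization over the channel axis of a `(B, C, D, H, W)` tensor.

    Channels are moved last so that `nn.LayerNorm` normalizes each spatio-temporal position independently.
    """
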
    def __init__(self, dim, eps, elementwise_affine=True) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, x):
        x = rearrange(x, "b c d h w -> b d h w c")
        x = self.norm(x)
        x = rearrange(x, "b d h w c -> b c d h w")
        return x
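

# Illustrative sketch (hypothetical helper, not used by the model): the
# `LayerNorm` wrapper above moves channels last so that `nn.LayerNorm` normalizes
# the C channel values of each (frame, row, column) position on their own,
# whereas GroupNorm also pools statistics over all positions. This pure-Python
# version normalizes one such channel vector, without the learnable affine terms.
def sketch_channel_layer_norm(channel_values, eps=1e-6):
    n = len(channel_values)
    mean = sum(channel_values) / n
    # biased variance (divide by n), as nn.LayerNorm uses
    var = sum((v - mean) ** 2 for v in channel_values) / n
    return [(v - mean) / (var + eps) ** 0.5 for v in channel_values]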


class ResnetBlock3D(nn.Module):
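    r"""
    A Resnet block.

    Parameters:
        dims (`int` or `Tuple[int, int]`): The number of dimensions to use in convolutions.
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of output channels for the first conv layer. If `None`, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        groups (`int`, *optional*, defaults to `32`): The number of groups to use for group normalization.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be `group_norm`, `pixel_norm` or `layer_norm`.
        inject_noise (`bool`, *optional*, defaults to `False`):
            Whether to add per-channel-scaled spatial noise after each convolution.
        timestep_conditioning (`bool`, *optional*, defaults to `False`):
            Whether to modulate the normalized hidden states with a timestep embedding.
        spatial_padding_mode (`str`, *optional*, defaults to `"zeros"`):
            The padding mode used for spatial padding in the convolutions.
    """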

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        groups: int = 32,
        eps: float = 1e-6,
        norm_layer: str = "group_norm",
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.inject_noise = inject_noise

        if norm_layer == "group_norm":
            self.norm1 = nn.GroupNorm(
                num_groups=groups, num_channels=in_channels, eps=eps, affine=True
            )
        elif norm_layer == "pixel_norm":
            self.norm1 = PixelNorm()
        elif norm_layer == "layer_norm":
            self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)
        else:
            raise ValueError(f"Invalid norm_layer: {norm_layer}")

        self.non_linearity = nn.SiLU()

        self.conv1 = make_conv_nd(
            dims,
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        if inject_noise:
            # Noise is added after conv1, whose output has `out_channels` channels
            self.per_channel_scale1 = nn.Parameter(torch.zeros((out_channels, 1, 1)))

        if norm_layer == "group_norm":
            self.norm2 = nn.GroupNorm(
                num_groups=groups, num_channels=out_channels, eps=eps, affine=True
            )
        elif norm_layer == "pixel_norm":
            self.norm2 = PixelNorm()
        elif norm_layer == "layer_norm":
            self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)

        self.dropout = torch.nn.Dropout(dropout)

        self.conv2 = make_conv_nd(
            dims,
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        if inject_noise:
            # Noise is added after conv2, whose output has `out_channels` channels
            self.per_channel_scale2 = nn.Parameter(torch.zeros((out_channels, 1, 1)))

        self.conv_shortcut = (
            make_linear_nd(
                dims=dims, in_channels=in_channels, out_channels=out_channels
            )
            if in_channels != out_channels
            else nn.Identity()
        )

        self.norm3 = (
            LayerNorm(in_channels, eps=eps, elementwise_affine=True)
            if in_channels != out_channels
            else nn.Identity()
        )

        self.timestep_conditioning = timestep_conditioning

        if timestep_conditioning:
            self.scale_shift_table = nn.Parameter(
                torch.randn(4, in_channels) / in_channels**0.5
            )

    def _feed_spatial_noise(
        self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
    ) -> torch.FloatTensor:
        spatial_shape = hidden_states.shape[-2:]
        device = hidden_states.device
        dtype = hidden_states.dtype

        # Similar to the "explicit noise inputs" of StyleGAN: a single (H, W) noise map
        # is shared across the batch and frames, and scaled per channel
        spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
        scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
        hidden_states = hidden_states + scaled_noise

        return hidden_states

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        causal: bool = True,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        hidden_states = input_tensor
        batch_size = hidden_states.shape[0]

        hidden_states = self.norm1(hidden_states)
        if self.timestep_conditioning:
            assert (
                timestep is not None
            ), "should pass timestep with timestep_conditioning=True"
            ada_values = self.scale_shift_table[
                None, ..., None, None, None
            ] + timestep.reshape(
                batch_size,
                4,
                -1,
                timestep.shape[-3],
                timestep.shape[-2],
                timestep.shape[-1],
            )
            shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)

            hidden_states = hidden_states * (1 + scale1) + shift1

        hidden_states = self.non_linearity(hidden_states)

        hidden_states = self.conv1(hidden_states, causal=causal)

        if self.inject_noise:
            hidden_states = self._feed_spatial_noise(
                hidden_states, self.per_channel_scale1
            )

        hidden_states = self.norm2(hidden_states)

        if self.timestep_conditioning:
            hidden_states = hidden_states * (1 + scale2) + shift2

        hidden_states = self.non_linearity(hidden_states)

        hidden_states = self.dropout(hidden_states)

        hidden_states = self.conv2(hidden_states, causal=causal)

        if self.inject_noise:
            hidden_states = self._feed_spatial_noise(
                hidden_states, self.per_channel_scale2
            )

        # Shortcut path: normalize and project the input when the channel count changes
        input_tensor = self.norm3(input_tensor)
        input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor
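

# Illustrative sketch (hypothetical helpers, not used by the model): how
# `ResnetBlock3D.forward` derives its four modulation vectors when
# `timestep_conditioning=True`. The per-sample timestep embedding has 4 * C
# entries; reshaping it to (4, C) splits it into four consecutive chunks, each
# added to the matching row of `scale_shift_table`, giving shift1, scale1,
# shift2 and scale2. Each pair is applied after a normalization as
# x * (1 + scale) + shift.
def sketch_ada_modulation(scale_shift_table, embedding):
    """scale_shift_table: 4 rows of C floats; embedding: 4 * C floats for one sample."""
    num_channels = len(scale_shift_table[0])
    if len(scale_shift_table) != 4 or len(embedding) != 4 * num_channels:
        raise ValueError("expected a 4 x C table and a 4 * C embedding")
    shift1, scale1, shift2, scale2 = (
        [value + embedding[k * num_channels + c] for c, value in enumerate(scale_shift_table[k])]
        for k in range(4)
    )
    return shift1, scale1, shift2, scale2


def sketch_modulate(values, scale, shift):
    """Apply x * (1 + scale) + shift channel by channel."""
    return [v * (1 + s) + b for v, s, b in zip(values, scale, shift)]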


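def patchify(x, patch_size_hw, patch_size_t=1):
    """Fold non-overlapping patches into the channel axis.

    Accepts `(B, C, H, W)` or `(B, C, F, H, W)` tensors; `unpatchify` is the inverse.
    """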
    if patch_size_hw == 1 and patch_size_t == 1:
        return x
    if x.dim() == 4:
        x = rearrange(
            x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
        )
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b c (f p) (h q) (w r) -> b (c p r q) f h w",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    return x


def unpatchify(x, patch_size_hw, patch_size_t=1):
    if patch_size_hw == 1 and patch_size_t == 1:
        return x

    if x.dim() == 4:
        x = rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
        )
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b (c p r q) f h w -> b c (f p) (h q) (w r)",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    return x


def create_video_autoencoder_demo_config(
    latent_channels: int = 64,
):
    encoder_blocks = [
        ("res_x", {"num_layers": 2}),
        ("compress_space_res", {"multiplier": 2}),
        ("compress_time_res", {"multiplier": 2}),
        ("compress_all_res", {"multiplier": 2}),
        ("compress_all_res", {"multiplier": 2}),
        ("res_x", {"num_layers": 1}),
    ]
    decoder_blocks = [
        ("res_x", {"num_layers": 2, "inject_noise": False}),
        ("compress_all", {"residual": True, "multiplier": 2}),
        ("compress_all", {"residual": True, "multiplier": 2}),
        ("compress_all", {"residual": True, "multiplier": 2}),
        ("res_x", {"num_layers": 2, "inject_noise": False}),
    ]
    return {
        "_class_name": "CausalVideoAutoencoder",
        "dims": 3,
        "encoder_blocks": encoder_blocks,
        "decoder_blocks": decoder_blocks,
        "latent_channels": latent_channels,
        "norm_layer": "pixel_norm",
        "patch_size": 4,
        "latent_log_var": "uniform",
        "use_quant_conv": False,
        "causal_decoder": False,
        "timestep_conditioning": True,
        "spatial_padding_mode": "replicate",
    }


def test_vae_patchify_unpatchify():
    import torch

    x = torch.randn(2, 3, 8, 64, 64)
    x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
    x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
    assert torch.allclose(x, x_unpatched)


def demo_video_autoencoder_forward_backward():
    # Configuration for the VideoAutoencoder
    config = create_video_autoencoder_demo_config()

    # Instantiate the VideoAutoencoder with the specified configuration
    video_autoencoder = CausalVideoAutoencoder.from_config(config)

    print(video_autoencoder)
    video_autoencoder.eval()
    # Print the total number of parameters in the video autoencoder
    total_params = sum(p.numel() for p in video_autoencoder.parameters())
    print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")

    # Create a mock input tensor simulating a batch of videos
    # Shape: (batch_size, channels, frames, height, width)
    # Here: 2 videos, each with 3 color channels, 17 frames, and 64x64 pixels per frame
    input_videos = torch.randn(2, 3, 17, 64, 64)

    # Forward pass: encode and decode the input videos
    latent = video_autoencoder.encode(input_videos).latent_dist.mode()
    print(f"input shape={input_videos.shape}")
    print(f"latent shape={latent.shape}")

    timestep = torch.ones(input_videos.shape[0]) * 0.1
    reconstructed_videos = video_autoencoder.decode(
        latent, target_shape=input_videos.shape, timestep=timestep
    ).sample

    print(f"reconstructed shape={reconstructed_videos.shape}")

    # Validate that single image gets treated the same way as first frame
    input_image = input_videos[:, :, :1, :, :]
    image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
    _ = video_autoencoder.decode(
        image_latent, target_shape=image_latent.shape, timestep=timestep
    ).sample

    first_frame_latent = latent[:, :, :1, :, :]

    assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)
    # assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6)
    # assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)
    # assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all()

    # Calculate the loss (e.g., mean squared error)
    loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)

    # Perform backward pass
    loss.backward()

    print(f"Demo completed with loss: {loss.item()}")


# Run the forward/backward demo when this file is executed as a script
if __name__ == "__main__":
    demo_video_autoencoder_forward_backward()

```
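`patchify` folds each `patch_size_t x patch_size_hw x patch_size_hw` block of pixels into the channel axis (a space-to-depth transform), and `unpatchify` undoes it. The NumPy sketch below re-expresses the two `einops` patterns with `reshape`/`transpose` so that the channel ordering `(c p r q)` is explicit. `patchify_np` and `unpatchify_np` are hypothetical names, and the sketch assumes every frame, height, and width size is divisible by its patch size.

```python
import numpy as np


def patchify_np(x, patch_size_hw, patch_size_t=1):
    """Sketch of "b c (f p) (h q) (w r) -> b (c p r q) f h w"."""
    if patch_size_hw == 1 and patch_size_t == 1:
        return x
    b, c, frames, height, width = x.shape
    p, q, r = patch_size_t, patch_size_hw, patch_size_hw
    f, h, w = frames // p, height // q, width // r
    # Split each axis into (block index, offset inside block): b c f p h q w r
    x = x.reshape(b, c, f, p, h, q, w, r)
    # Move the offsets next to the channel axis, in (c, p, r, q) order
    x = x.transpose(0, 1, 3, 7, 5, 2, 4, 6)
    return x.reshape(b, c * p * r * q, f, h, w)


def unpatchify_np(x, patch_size_hw, patch_size_t=1):
    """Sketch of the inverse, "b (c p r q) f h w -> b c (f p) (h q) (w r)"."""
    if patch_size_hw == 1 and patch_size_t == 1:
        return x
    b, channels, f, h, w = x.shape
    p, q, r = patch_size_t, patch_size_hw, patch_size_hw
    c = channels // (p * r * q)
    x = x.reshape(b, c, p, r, q, f, h, w)
    # Reorder to b c f p h q w r, then merge each (block, offset) pair
    x = x.transpose(0, 1, 5, 2, 6, 4, 7, 3)
    return x.reshape(b, c, f * p, h * q, w * r)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8, 64, 64))
patched = patchify_np(x, patch_size_hw=4, patch_size_t=4)
print(patched.shape)  # (2, 192, 2, 16, 16)
print(np.array_equal(unpatchify_np(patched, 4, 4), x))  # True
```

Because only axes are regrouped, the round trip is exact, which is what `test_vae_patchify_unpatchify` checks with `torch.allclose`.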

## /ltx_video/models/autoencoders/conv_nd_factory.py

```py path="/ltx_video/models/autoencoders/conv_nd_factory.py" 
from typing import Tuple, Union

import torch

from ltx_video.models.autoencoders.dual_conv3d import DualConv3d
from ltx_video.models.autoencoders.causal_conv3d import CausalConv3d


def make_conv_nd(
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
    bias=True,
    causal=False,
    spatial_padding_mode="zeros",
    temporal_padding_mode="zeros",
):
    if not (spatial_padding_mode == temporal_padding_mode or causal):
        raise NotImplementedError("spatial and temporal padding modes must be equal")
    if dims == 2:
        return torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=spatial_padding_mode,
        )
    elif dims == 3:
        if causal:
            return CausalConv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
                spatial_padding_mode=spatial_padding_mode,
            )
        return torch.nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=spatial_padding_mode,
        )
    elif dims == (2, 1):
        return DualConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=spatial_padding_mode,
        )
    else:
        raise ValueError(f"unsupported dimensions: {dims}")


def make_linear_nd(
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    bias=True,
):
    if dims == 2:
        return torch.nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
        )
    elif dims == 3 or dims == (2, 1):
        return torch.nn.Conv3d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
        )
    else:
        raise ValueError(f"unsupported dimensions: {dims}")

```
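The dispatch in `make_conv_nd` can be summarized without building any layers. The pure-Python sketch below uses a hypothetical helper, `conv_class_for`, that returns the name of the class `make_conv_nd` would construct. It highlights that the padding-mode equality check is skipped whenever `causal=True`, and that only `spatial_padding_mode` is forwarded in every branch.

```python
def conv_class_for(dims, causal=False,
                   spatial_padding_mode="zeros", temporal_padding_mode="zeros"):
    # Hypothetical mirror of make_conv_nd's branching; returns a class name only
    if not (spatial_padding_mode == temporal_padding_mode or causal):
        raise NotImplementedError("spatial and temporal padding modes must be equal")
    if dims == 2:
        return "torch.nn.Conv2d"  # `causal` has no effect for 2D
    if dims == 3:
        return "CausalConv3d" if causal else "torch.nn.Conv3d"
    if dims == (2, 1):
        return "DualConv3d"  # spatial conv followed by a temporal conv
    raise ValueError(f"unsupported dimensions: {dims}")


print(conv_class_for(3, causal=True, spatial_padding_mode="replicate"))  # CausalConv3d
print(conv_class_for((2, 1)))  # DualConv3d
try:
    conv_class_for(3, spatial_padding_mode="replicate")
except NotImplementedError as err:
    print(err)  # spatial and temporal padding modes must be equal
```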

## /ltx_video/models/autoencoders/dual_conv3d.py

```py path="/ltx_video/models/autoencoders/dual_conv3d.py" 
import math
from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class DualConv3d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        groups=1,
        bias=True,
        padding_mode="zeros",
    ):
        super(DualConv3d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding_mode = padding_mode
        # Ensure kernel_size, stride, padding, and dilation are tuples of length 3
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if kernel_size == (1, 1, 1):
            raise ValueError(
                "kernel_size must be greater than 1. Use make_linear_nd instead."
            )
        if isinstance(stride, int):
            stride = (stride, stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding, padding)
        if isinstance(dilation, int):
            dilation = (dilation, dilation, dilation)

        # Set parameters for convolutions
        self.groups = groups
        self.bias = bias

        # Define the size of the channels after the first convolution
        intermediate_channels = (
            out_channels if in_channels < out_channels else in_channels
        )

        # Define parameters for the first convolution
        self.weight1 = nn.Parameter(
            torch.Tensor(
                intermediate_channels,
                in_channels // groups,
                1,
                kernel_size[1],
                kernel_size[2],
            )
        )
        self.stride1 = (1, stride[1], stride[2])
        self.padding1 = (0, padding[1], padding[2])
        self.dilation1 = (1, dilation[1], dilation[2])
        if bias:
            self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
        else:
            self.register_parameter("bias1", None)

        # Define parameters for the second convolution
        self.weight2 = nn.Parameter(
            torch.Tensor(
                out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
            )
        )
        self.stride2 = (stride[0], 1, 1)
        self.padding2 = (padding[0], 0, 0)
        self.dilation2 = (dilation[0], 1, 1)
        if bias:
            self.bias2 = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter("bias2", None)

        # Initialize weights and biases
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
        if self.bias:
            fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
            bound1 = 1 / math.sqrt(fan_in1)
            nn.init.uniform_(self.bias1, -bound1, bound1)
            fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
            bound2 = 1 / math.sqrt(fan_in2)
            nn.init.uniform_(self.bias2, -bound2, bound2)

    def forward(self, x, use_conv3d=False, skip_time_conv=False):
        if use_conv3d:
            return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
        else:
            return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)

    def _pad(self, x, padding):
        # F.conv1d/2d/3d have no padding_mode argument: zero padding is passed to the
        # convolution directly, while any other mode ("reflect", "replicate",
        # "circular") is applied explicitly with F.pad and the convolution uses padding=0.
        if self.padding_mode == "zeros":
            return x, padding
        pad = []
        for p in reversed(padding):  # F.pad lists the last dimension first
            pad.extend([p, p])
        return F.pad(x, pad, mode=self.padding_mode), 0

    def forward_with_3d(self, x, skip_time_conv):
        # First convolution: spatial only (kernel depth is 1)
        x, padding1 = self._pad(x, self.padding1)
        x = F.conv3d(
            x,
            self.weight1,
            self.bias1,
            self.stride1,
            padding1,
            self.dilation1,
            self.groups,
        )

        if skip_time_conv:
            return x

        # Second convolution: temporal only (kernel height and width are 1)
        x, padding2 = self._pad(x, self.padding2)
        x = F.conv3d(
            x,
            self.weight2,
            self.bias2,
            self.stride2,
            padding2,
            self.dilation2,
            self.groups,
        )

        return x

    def forward_with_2d(self, x, skip_time_conv):
        b = x.shape[0]

        # First 2D convolution, applied to every frame independently
        x = rearrange(x, "b c d h w -> (b d) c h w")
        # Squeeze the depth dimension out of weight1 since it's 1
        weight1 = self.weight1.squeeze(2)
        # Select stride, padding, and dilation for the 2D convolution
        stride1 = (self.stride1[1], self.stride1[2])
        padding1 = (self.padding1[1], self.padding1[2])
        dilation1 = (self.dilation1[1], self.dilation1[2])
        x, padding1 = self._pad(x, padding1)
        x = F.conv2d(
            x,
            weight1,
            self.bias1,
            stride1,
            padding1,
            dilation1,
            self.groups,
        )

        _, _, h, w = x.shape

        if skip_time_conv:
            x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
            return x

        # Second convolution, treated as a 1D convolution along the 'd' (frame) dimension
        x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)

        # Reshape weight2 to match the expected dimensions for conv1d
        weight2 = self.weight2.squeeze(-1).squeeze(-1)
        # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
        stride2 = self.stride2[0]
        dilation2 = self.dilation2[0]
        x, padding2 = self._pad(x, (self.padding2[0],))
        x = F.conv1d(
            x,
            weight2,
            self.bias2,
            stride2,
            padding2,
            dilation2,
            self.groups,
        )
        x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)

        return x

    @property
    def weight(self):
        return self.weight2


def test_dual_conv3d_consistency():
    # Initialize parameters
    in_channels = 3
    out_channels = 5
    kernel_size = (3, 3, 3)
    stride = (2, 2, 2)
    padding = (1, 1, 1)

    # Create an instance of the DualConv3d class
    dual_conv3d = DualConv3d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=True,
    )

    # Example input tensor
    test_input = torch.randn(1, 3, 10, 10, 10)

    # Perform forward passes with both 3D and 2D settings
    output_conv3d = dual_conv3d(test_input, use_conv3d=True)
    output_2d = dual_conv3d(test_input, use_conv3d=False)

    # Assert that the outputs from both methods are sufficiently close
    assert torch.allclose(
        output_conv3d, output_2d, atol=1e-6
    ), "Outputs are not consistent between 3D and 2D convolutions."

```
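`DualConv3d` factorizes a full `kt x kh x kw` kernel into a spatial `1 x kh x kw` convolution followed by a temporal `kt x 1 x 1` convolution, with `max(in_channels, out_channels)` intermediate channels. The rough pure-Python sketch below, with hypothetical helper names, compares parameter counts and applies the standard per-axis convolution output-size formula to the settings used in `test_dual_conv3d_consistency`. Note that the factorized layer is a restricted kernel family, so the two layers are not interchangeable in what they can represent.

```python
def conv3d_params(in_ch, out_ch, k, groups=1, bias=True):
    # Parameters of a full torch.nn.Conv3d with kernel k = (kt, kh, kw)
    kt, kh, kw = k
    return out_ch * (in_ch // groups) * kt * kh * kw + (out_ch if bias else 0)


def dual_conv3d_params(in_ch, out_ch, k, groups=1, bias=True):
    # weight1: (mid, in/g, 1, kh, kw) plus weight2: (out, mid/g, kt, 1, 1)
    kt, kh, kw = k
    mid = out_ch if in_ch < out_ch else in_ch  # intermediate_channels
    spatial = mid * (in_ch // groups) * kh * kw + (mid if bias else 0)
    temporal = out_ch * (mid // groups) * kt + (out_ch if bias else 0)
    return spatial + temporal


def conv_out_len(n, k, stride=1, padding=0, dilation=1):
    # Output length along one axis, as documented for torch.nn.Conv1d/2d/3d
    return (n + 2 * padding - dilation * (k - 1) - 1) // stride + 1


# Settings from test_dual_conv3d_consistency
print(conv3d_params(3, 5, (3, 3, 3)))       # 410
print(dual_conv3d_params(3, 5, (3, 3, 3)))  # 220
# Each stage uses kernel 3, stride 2, padding 1 on its own axes: 10 -> 5
print(conv_out_len(10, k=3, stride=2, padding=1))  # 5

# A wider 128 -> 128 channel layer
print(conv3d_params(128, 128, (3, 3, 3)))       # 442496
print(dual_conv3d_params(128, 128, (3, 3, 3)))  # 196864
```

So the test input of shape `(1, 3, 10, 10, 10)` should come out as `(1, 5, 5, 5, 5)` from both code paths.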

## /ltx_video/models/autoencoders/latent_upsampler.py

```py path="/ltx_video/models/autoencoders/latent_upsampler.py" 
from typing import Optional, Union
from pathlib import Path
import os
import json

import torch
import torch.nn as nn
from einops import rearrange
from diffusers import ConfigMixin, ModelMixin
from safetensors.torch import safe_open

from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND


class ResBlock(nn.Module):
    def __init__(
        self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
    ):
        super().__init__()
        if mid_channels is None:
            mid_channels = channels

        Conv = nn.Conv2d if dims == 2 else nn.Conv3d

        self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(32, mid_channels)
        self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, channels)
        self.activation = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.activation(x + residual)
        return x


class LatentUpsampler(ModelMixin, ConfigMixin):
    """
    Model to upsample VAE latents spatially (2x height and width), temporally (F frames become 2F - 1), or both.

    Args:
        in_channels (`int`): Number of channels in the input latent
        mid_channels (`int`): Number of channels in the middle layers
        num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
        dims (`int`): Number of dimensions for convolutions (2 or 3)
        spatial_upsample (`bool`): Whether to spatially upsample the latent
        temporal_upsample (`bool`): Whether to temporally upsample the latent
    """

    def __init__(
        self,
        in_channels: int = 128,
        mid_channels: int = 512,
        num_blocks_per_stage: int = 4,
        dims: int = 3,
        spatial_upsample: bool = True,
        temporal_upsample: bool = False,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.num_blocks_per_stage = num_blocks_per_stage
        self.dims = dims
        self.spatial_upsample = spatial_upsample
        self.temporal_upsample = temporal_upsample

        Conv = nn.Conv2d if dims == 2 else nn.Conv3d

        self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
        self.initial_norm = nn.GroupNorm(32, mid_channels)
        self.initial_activation = nn.SiLU()

        self.res_blocks = nn.ModuleList(
            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
        )

        if spatial_upsample and temporal_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(3),
            )
        elif spatial_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(2),
            )
        elif temporal_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(1),
            )
        else:
            raise ValueError(
                "Either spatial_upsample or temporal_upsample must be True"
            )

        self.post_upsample_res_blocks = nn.ModuleList(
            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
        )

        self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        b, c, f, h, w = latent.shape

        if self.dims == 2:
            x = rearrange(latent, "b c f h w -> (b f) c h w")
            x = self.initial_conv(x)
            x = self.initial_norm(x)
            x = self.initial_activation(x)

            for block in self.res_blocks:
                x = block(x)

            x = self.upsampler(x)

            for block in self.post_upsample_res_blocks:
                x = block(x)

            x = self.final_conv(x)
            x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
        else:
            x = self.initial_conv(latent)
            x = self.initial_norm(x)
            x = self.initial_activation(x)

            for block in self.res_blocks:
                x = block(x)

            if self.temporal_upsample:
                x = self.upsampler(x)
                x = x[:, :, 1:, :, :]
            else:
                x = rearrange(x, "b c f h w -> (b f) c h w")
                x = self.upsampler(x)
                x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)

            for block in self.post_upsample_res_blocks:
                x = block(x)

            x = self.final_conv(x)

        return x

    @classmethod
    def from_config(cls, config):
        return cls(
            in_channels=config.get("in_channels", 4),
            mid_channels=config.get("mid_channels", 128),
            num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
            dims=config.get("dims", 2),
            spatial_upsample=config.get("spatial_upsample", True),
            temporal_upsample=config.get("temporal_upsample", False),
        )

    def config(self):
        return {
            "_class_name": "LatentUpsampler",
            "in_channels": self.in_channels,
            "mid_channels": self.mid_channels,
            "num_blocks_per_stage": self.num_blocks_per_stage,
            "dims": self.dims,
            "spatial_upsample": self.spatial_upsample,
            "temporal_upsample": self.temporal_upsample,
        }

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        pretrained_model_path = Path(pretrained_model_path)
        if not (
            pretrained_model_path.is_file()
            and pretrained_model_path.suffix == ".safetensors"
        ):
            raise ValueError(
                f"Expected a path to a .safetensors file, got: {pretrained_model_path}"
            )
        state_dict = {}
        with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
            metadata = f.metadata()
            for k in f.keys():
                state_dict[k] = f.get_tensor(k)
        config = json.loads(metadata["config"])
        # Build the module on the meta device (no memory is allocated), then attach
        # the loaded tensors directly with assign=True.
        with torch.device("meta"):
            latent_upsampler = cls.from_config(config)
        latent_upsampler.load_state_dict(state_dict, assign=True)
        return latent_upsampler


if __name__ == "__main__":
    latent_upsampler = LatentUpsampler(num_blocks_per_stage=4, dims=3)
    print(latent_upsampler)
    total_params = sum(p.numel() for p in latent_upsampler.parameters())
    print(f"Total number of parameters: {total_params:,}")
    latent = torch.randn(1, 128, 9, 16, 16)
    upsampled_latent = latent_upsampler(latent)
    print(f"Upsampled latent shape: {upsampled_latent.shape}")

```
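The shape bookkeeping in `LatentUpsampler.forward` can be traced without tensors. In this pure-Python sketch (the function name is hypothetical), spatial upsampling doubles height and width, temporal upsampling doubles the frame axis through `PixelShuffleND` and then drops the first frame, and `final_conv` maps `mid_channels` back to `in_channels`, so the channel count is unchanged.

```python
def latent_upsampler_output_shape(shape, spatial_upsample=True, temporal_upsample=False):
    # Hypothetical mirror of the shape logic in LatentUpsampler.forward
    if not (spatial_upsample or temporal_upsample):
        raise ValueError("Either spatial_upsample or temporal_upsample must be True")
    b, c, f, h, w = shape
    if spatial_upsample:
        h, w = 2 * h, 2 * w
    if temporal_upsample:
        # The frame axis is doubled, then x[:, :, 1:] drops the first frame
        f = 2 * f - 1
    return (b, c, f, h, w)


shape = (1, 128, 9, 16, 16)  # same latent as the __main__ example above
print(latent_upsampler_output_shape(shape))  # (1, 128, 9, 32, 32)
print(latent_upsampler_output_shape(shape, False, True))  # (1, 128, 17, 16, 16)
print(latent_upsampler_output_shape(shape, True, True))  # (1, 128, 17, 32, 32)
```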

## /ltx_video/models/autoencoders/pixel_norm.py

```py path="/ltx_video/models/autoencoders/pixel_norm.py" 
import torch
from torch import nn


class PixelNorm(nn.Module):
    def __init__(self, dim=1, eps=1e-8):
        super(PixelNorm, self).__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)

```
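`PixelNorm` rescales each position's feature vector to unit root-mean-square along `dim` (the channel axis by default), with `eps` guarding against division by zero. A NumPy sketch of the same formula, assuming a `(b, c, f, h, w)` layout, shows the two properties that matter: unit RMS per position and invariance to a global rescaling of the input.

```python
import numpy as np


def pixel_norm(x, dim=1, eps=1e-8):
    # x / sqrt(mean(x^2 over dim) + eps), same formula as PixelNorm.forward
    return x / np.sqrt(np.mean(x ** 2, axis=dim, keepdims=True) + eps)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 4, 5, 5))  # b, c, f, h, w
y = pixel_norm(x)
rms = np.sqrt(np.mean(y ** 2, axis=1))
print(np.allclose(rms, 1.0, atol=1e-6))  # True
print(np.allclose(pixel_norm(3.0 * x), y))  # True
```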

## /ltx_video/models/autoencoders/pixel_shuffle.py

```py path="/ltx_video/models/autoencoders/pixel_shuffle.py" 
import torch.nn as nn
from einops import rearrange


class PixelShuffleND(nn.Module):
    def __init__(self, dims, upscale_factors=(2, 2, 2)):
        super().__init__()
        assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
        self.dims = dims
        self.upscale_factors = upscale_factors

    def forward(self, x):
        if self.dims == 3:
            return rearrange(
                x,
                "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
                p1=self.upscale_factors[0],
                p2=self.upscale_factors[1],
                p3=self.upscale_factors[2],
            )
        elif self.dims == 2:
            return rearrange(
                x,
                "b (c p1 p2) h w -> b c (h p1) (w p2)",
                p1=self.upscale_factors[0],
                p2=self.upscale_factors[1],
            )
        elif self.dims == 1:
            return rearrange(
                x,
                "b (c p1) f h w -> b c (f p1) h w",
                p1=self.upscale_factors[0],
            )

```
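For `dims=2`, `PixelShuffleND` moves `p1 * p2` groups of channels into the height and width axes, using the same channel ordering as `torch.nn.PixelShuffle`. The NumPy sketch below (hypothetical function name) spells the `einops` pattern out as `reshape`/`transpose`, so the mapping `out[c, h*p1 + i, w*p2 + j] = in[c*p1*p2 + i*p2 + j, h, w]` can be checked directly.

```python
import numpy as np


def pixel_shuffle_2d(x, p1=2, p2=2):
    """Sketch of "b (c p1 p2) h w -> b c (h p1) (w p2)"."""
    b, channels, h, w = x.shape
    c = channels // (p1 * p2)
    x = x.reshape(b, c, p1, p2, h, w)
    x = x.transpose(0, 1, 4, 2, 5, 3)  # b c h p1 w p2
    return x.reshape(b, c, h * p1, w * p2)


x = np.arange(1 * 8 * 3 * 3).reshape(1, 8, 3, 3)  # 8 = 2 output channels * 2 * 2
y = pixel_shuffle_2d(x)
print(y.shape)  # (1, 2, 6, 6)
```

This is why the spatial branch of `LatentUpsampler` first expands to `4 * mid_channels`: after the shuffle it is back to `mid_channels` at twice the resolution.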

## /ltx_video/models/autoencoders/vae.py
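When height/width tiling is enabled, `_hw_tiled_encode` below starts a new tile every `int(tile_sample_min_size * (1 - tile_overlap_factor))` pixels, and `blend_v` / `blend_h` / `blend_z` cross-fade the overlapping region linearly before each tile is cropped to `row_limit`. This pure-Python/NumPy sketch reproduces that arithmetic with hypothetical helpers; `num_blocks = 4` is an assumed value for `len(self.encoder.down_blocks)`, and `set_tiling_params` assumes the latent is `2 ** (num_blocks - 1)` times smaller than the input.

```python
import numpy as np


def tile_starts(length, tile_size, overlap_factor=0.25):
    # Step between tile origins, the `overlap_size` used by the range() loops
    step = int(tile_size * (1 - overlap_factor))
    return list(range(0, length, step))


def blend_1d(a, b, blend_extent):
    # Same weights as blend_h, along the last axis only; the first `e` entries
    # of b fade from a's tail (x = 0) toward b's own values (x -> e)
    b = b.copy()  # the original methods modify b in place
    e = min(a.shape[-1], b.shape[-1], blend_extent)
    for x in range(e):
        b[..., x] = a[..., -e + x] * (1 - x / e) + b[..., x] * (x / e)
    return b


sample_size, overlap_factor = 512, 0.25
num_blocks = 4  # hypothetical len(encoder.down_blocks)
tile_latent = int(sample_size / (2 ** (num_blocks - 1)))  # 64
blend_extent = int(tile_latent * overlap_factor)  # 16
row_limit = tile_latent - blend_extent  # 48

print(tile_starts(1024, sample_size, overlap_factor))  # [0, 384, 768]
print(blend_1d(np.ones(8), np.zeros(8), 4).tolist())  # [1.0, 0.75, 0.5, 0.25, 0.0, 0.0, 0.0, 0.0]
```

With these values, each 512-pixel tile advances by 384 pixels, which is 48 latent positions at 8x downsampling, matching the `row_limit` that each blended latent tile is cropped to.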

```py path="/ltx_video/models/autoencoders/vae.py" 
from typing import Optional, Union

import torch
import inspect
import math
import torch.nn as nn
from diffusers import ConfigMixin, ModelMixin
from diffusers.models.autoencoders.vae import (
    DecoderOutput,
    DiagonalGaussianDistribution,
)
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd


class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
    """Variational Autoencoder (VAE) model with KL loss.

    VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling.
    This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss.

    Args:
        encoder (`nn.Module`):
            Encoder module.
        decoder (`nn.Module`):
            Decoder module.
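        latent_channels (`int`, *optional*, defaults to 4):
            Number of latent channels.
        dims (`int`, *optional*, defaults to 2):
            Dimensionality of the quantization convolutions and latent normalization (2 for images; any other
            value selects 3D layers).
        sample_size (`int`, *optional*, defaults to 512):
            Tile size in pixels used when tiled encoding/decoding is enabled.
        use_quant_conv (`bool`, *optional*, defaults to `True`):
            Whether to apply 1x1 `quant_conv` / `post_quant_conv` layers around the latent.
        normalize_latent_channels (`bool`, *optional*, defaults to `False`):
            Whether to normalize the latent means with a non-affine BatchNorm.
    """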

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        latent_channels: int = 4,
        dims: int = 2,
        sample_size=512,
        use_quant_conv: bool = True,
        normalize_latent_channels: bool = False,
    ):
        super().__init__()

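        # Encoder module, constructed by the caller
        self.encoder = encoder
        self.use_quant_conv = use_quant_conv
        self.normalize_latent_channels = normalize_latent_channels

        # Decoder module, constructed by the caller; the optional 1x1 quant convs
        # act on the encoder moments (2 * latent_channels) and on the sampled latent
        quant_dims = 2 if dims == 2 else 3
        self.decoder = decoder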
        if use_quant_conv:
            self.quant_conv = make_conv_nd(
                quant_dims, 2 * latent_channels, 2 * latent_channels, 1
            )
            self.post_quant_conv = make_conv_nd(
                quant_dims, latent_channels, latent_channels, 1
            )
        else:
            self.quant_conv = nn.Identity()
            self.post_quant_conv = nn.Identity()

        if normalize_latent_channels:
            if dims == 2:
                self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False)
            else:
                self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False)
        else:
            self.latent_norm_out = nn.Identity()
        self.use_z_tiling = False
        self.use_hw_tiling = False
        self.dims = dims
        self.z_sample_size = 1

        self.decoder_params = inspect.signature(self.decoder.forward).parameters

        # only relevant if vae tiling is enabled
        self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25)

    def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25):
        self.tile_sample_min_size = sample_size
        num_blocks = len(self.encoder.down_blocks)
        self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1)))
        self.tile_overlap_factor = overlap_factor

    def enable_z_tiling(self, z_sample_size: int = 8):
        r"""
        Enable tiling during VAE decoding.

        When this option is enabled, the VAE will split the input tensor in tiles to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_z_tiling = z_sample_size > 1
        self.z_sample_size = z_sample_size
        assert (
            z_sample_size % 8 == 0 or z_sample_size == 1
        ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}."

    def disable_z_tiling(self):
        r"""
        Disable tiling during VAE decoding. If `use_tiling` was previously invoked, this method will go back to computing
        decoding in one step.
        """
        self.use_z_tiling = False

    def enable_hw_tiling(self):
        r"""
        Enable tiling during VAE encoding and decoding along the height and width dimensions.
        """
        self.use_hw_tiling = True

    def disable_hw_tiling(self):
        r"""
        Disable tiling during VAE encoding and decoding along the height and width dimensions.
        """
        self.use_hw_tiling = False

    def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True):
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the input into overlapping tiles of tile_sample_min_size x tile_sample_min_size
        # pixels (512x512 by default) and encode them separately.
        rows = []
        for i in range(0, x.shape[3], overlap_size):
            row = []
            for j in range(0, x.shape[4], overlap_size):
                tile = x[
                    :,
                    :,
                    :,
                    i : i + self.tile_sample_min_size,
                    j : j + self.tile_sample_min_size,
                ]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))

        moments = torch.cat(result_rows, dim=3)
        return moments

    def blend_z(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for z in range(blend_extent):
            b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (
                1 - z / blend_extent
            ) + b[:, :, z, :, :] * (z / blend_extent)
        return b

    def blend_v(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
                1 - y / blend_extent
            ) + b[:, :, :, y, :] * (y / blend_extent)
        return b

    def blend_h(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
                1 - x / blend_extent
            ) + b[:, :, :, :, x] * (x / blend_extent)
        return b

    def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape):
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent
        tile_target_shape = (
            *target_shape[:3],
            self.tile_sample_min_size,
            self.tile_sample_min_size,
        )
        # Split z into overlapping tiles of tile_latent_min_size x tile_latent_min_size latents
        # and decode them separately. The overlap is blended to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[3], overlap_size):
            row = []
            for j in range(0, z.shape[4], overlap_size):
                tile = z[
                    :,
                    :,
                    :,
                    i : i + self.tile_latent_min_size,
                    j : j + self.tile_latent_min_size,
                ]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile, target_shape=tile_target_shape)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))

        dec = torch.cat(result_rows, dim=3)
        return dec

    def encode(
        self, z: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, tuple]:
        if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
            num_splits = z.shape[2] // self.z_sample_size
            sizes = [self.z_sample_size] * num_splits
            sizes = (
                sizes + [z.shape[2] - sum(sizes)]
                if z.shape[2] - sum(sizes) > 0
                else sizes
            )
            tiles = z.split(sizes, dim=2)
            moments_tiles = [
                (
                    self._hw_tiled_encode(z_tile, return_dict)
                    if self.use_hw_tiling
                    else self._encode(z_tile)
                )
                for z_tile in tiles
            ]
            moments = torch.cat(moments_tiles, dim=2)

        else:
            moments = (
                self._hw_tiled_encode(z, return_dict)
                if self.use_hw_tiling
                else self._encode(z)
            )

        posterior = DiagonalGaussianDistribution(moments)
        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
        if isinstance(self.latent_norm_out, nn.BatchNorm3d):
            _, c, _, _, _ = z.shape
            z = torch.cat(
                [
                    self.latent_norm_out(z[:, : c // 2, :, :, :]),
                    z[:, c // 2 :, :, :, :],
                ],
                dim=1,
            )
        elif isinstance(self.latent_norm_out, nn.BatchNorm2d):
            raise NotImplementedError("BatchNorm2d not supported")
        return z

    def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
        if isinstance(self.latent_norm_out, nn.BatchNorm3d):
            running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1)
            running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1)
            eps = self.latent_norm_out.eps

            z = z * torch.sqrt(running_var + eps) + running_mean
        elif isinstance(self.latent_norm_out, nn.BatchNorm2d):
            raise NotImplementedError("BatchNorm2d not supported")
        return z

    def _encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
        h = self.encoder(x)
        moments = self.quant_conv(h)
        moments = self._normalize_latent_channels(moments)
        return moments

    def _decode(
        self,
        z: torch.FloatTensor,
        target_shape=None,
        timestep: Optional[torch.Tensor] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        z = self._unnormalize_latent_channels(z)
        z = self.post_quant_conv(z)
        if "timestep" in self.decoder_params:
            dec = self.decoder(z, target_shape=target_shape, timestep=timestep)
        else:
            dec = self.decoder(z, target_shape=target_shape)
        return dec

    def decode(
        self,
        z: torch.FloatTensor,
        return_dict: bool = True,
        target_shape=None,
        timestep: Optional[torch.Tensor] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        assert target_shape is not None, "target_shape must be provided for decoding"
        if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
            reduction_factor = int(
                self.encoder.patch_size_t
                * 2
                ** (
                    len(self.encoder.down_blocks)
                    - 1
                    - math.sqrt(self.encoder.patch_size)
                )
            )
            split_size = self.z_sample_size // reduction_factor
            num_splits = z.shape[2] // split_size

            # copy target shape and divide the frame dimension (dim 2) by the number of splits
            target_shape_split = list(target_shape)
            target_shape_split[2] = target_shape[2] // num_splits

            decoded_tiles = [
                (
                    self._hw_tiled_decode(z_tile, target_shape_split)
                    if self.use_hw_tiling
                    else self._decode(
                        z_tile, target_shape=target_shape_split, timestep=timestep
                    )
                )
                for z_tile in torch.tensor_split(z, num_splits, dim=2)
            ]
            decoded = torch.cat(decoded_tiles, dim=2)
        else:
            decoded = (
                self._hw_tiled_decode(z, target_shape)
                if self.use_hw_tiling
                else self._decode(z, target_shape=target_shape, timestep=timestep)
            )

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`DecoderOutput`] instead of a plain tuple.
            generator (`torch.Generator`, *optional*):
                Generator used to sample from the posterior.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z, target_shape=sample.shape).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

```
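
In `encode` above, temporal (z) tiling splits the input along the frame axis (dim 2) into chunks of `z_sample_size` frames plus a shorter remainder chunk, encodes each chunk independently, and concatenates the moments. `decode` instead calls `torch.tensor_split(z, num_splits, dim=2)`, which spreads the frames over `num_splits` near-equal chunks. The pure-Python sketch below reproduces only that size arithmetic, assuming `use_z_tiling` is enabled; both function names are hypothetical helpers, not part of LTX-Video.

```python
from typing import List


def encode_chunk_sizes(num_frames: int, z_sample_size: int) -> List[int]:
    """Frame count per chunk, mirroring the `sizes` list built in `encode` (hypothetical helper)."""
    # Tiling only kicks in when num_frames > z_sample_size > 1 (use_z_tiling assumed True).
    if not num_frames > z_sample_size > 1:
        return [num_frames]
    num_splits = num_frames // z_sample_size
    sizes = [z_sample_size] * num_splits
    remainder = num_frames - sum(sizes)
    if remainder > 0:
        sizes.append(remainder)  # shorter trailing chunk
    return sizes


def tensor_split_sizes(length: int, sections: int) -> List[int]:
    """Chunk lengths that torch.tensor_split(x, sections) yields along one dimension."""
    base, extra = divmod(length, sections)
    # The first `extra` chunks receive one additional element.
    return [base + 1] * extra + [base] * (sections - extra)


print(encode_chunk_sizes(20, 8))  # [8, 8, 4]
print(tensor_split_sizes(10, 3))  # [4, 3, 3]
```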

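The tiled decode above stitches overlapping tiles with `blend_v` and `blend_h` (defined elsewhere in this file and not shown in this excerpt) and then crops each tile to `row_limit`. The sketch below assumes those helpers perform a linear cross-fade over `blend_extent` positions, as the similarly named helpers in diffusers' `AutoencoderKL` do, and shows the same stitching in one dimension on plain lists with an identity "decoder". The values of `tile_size`, `overlap_size`, `blend_extent` and `row_limit` are illustrative; in the source they are computed before this excerpt begins.

```python
from typing import List, Sequence


def linear_blend(prev_tile: Sequence[float], tile: Sequence[float], blend_extent: int) -> List[float]:
    """Cross-fade the tail of prev_tile into the head of tile (assumed blend_h/blend_v behaviour)."""
    out = list(tile)
    extent = min(len(prev_tile), len(tile), blend_extent)
    for x in range(extent):
        w = x / extent  # weight of the current tile ramps up from 0
        out[x] = prev_tile[-extent + x] * (1 - w) + tile[x] * w
    return out


def stitch_1d(signal: Sequence[float], tile_size: int, overlap_size: int,
              blend_extent: int, row_limit: int) -> List[float]:
    """Tile, blend and crop a 1-D signal the way the 2-D loop above does (identity decoder)."""
    tiles = [list(signal[i:i + tile_size]) for i in range(0, len(signal), overlap_size)]
    result: List[float] = []
    for j, tile in enumerate(tiles):
        if j > 0:
            # Blend against the unblended previous tile, as the source does with row[j - 1].
            tile = linear_blend(tiles[j - 1], tile, blend_extent)
        result.extend(tile[:row_limit])
    return result


print(linear_blend([0, 0, 0, 0], [4, 4, 4, 4], 2))  # [0.0, 2.0, 4, 4]
print(stitch_1d(list(range(10)), tile_size=4, overlap_size=3, blend_extent=1, row_limit=3))
```

With an identity decoder the overlapping regions hold identical values, so the stitched signal equals the input; with a real decoder the overlapping predictions differ and the ramp is what hides the seam.
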
## /ltx_video/models/autoencoders/vae_encode.py

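`vae_encode` and `vae_decode` below treat `split_size` as the number of sub-batches, not the size of each one: the batch (with frames already folded into it for image VAEs) must divide evenly by `split_size`, and each chunk holds `len(batch) // split_size` items. A list-based sketch of that rule, using a hypothetical helper name:

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_into_sub_batches(items: Sequence[T], split_size: int) -> List[Sequence[T]]:
    """Hypothetical helper mirroring how vae_encode/vae_decode chunk a batch."""
    if split_size <= 1:
        return [items]  # whole batch in a single call
    if len(items) % split_size != 0:
        raise ValueError("The batch size must be divisible by split_size")
    chunk = len(items) // split_size  # items per sub-batch
    return [items[i:i + chunk] for i in range(0, len(items), chunk)]


# 2 videos x 4 frames folded into a batch of 8 frames ("b c n h w -> (b n) c h w"),
# then encoded in 4 sub-batches of 2 frames each.
frames = [f"video{b}_frame{n}" for b in range(2) for n in range(4)]
print(split_into_sub_batches(frames, 4))
```
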
```py path="/ltx_video/models/autoencoders/vae_encode.py" 
from typing import Tuple
import torch
from diffusers import AutoencoderKL
from einops import rearrange
from torch import Tensor


from ltx_video.models.autoencoders.causal_video_autoencoder import (
    CausalVideoAutoencoder,
)
from ltx_video.models.autoencoders.video_autoencoder import (
    Downsample3D,
    VideoAutoencoder,
)

try:
    import torch_xla.core.xla_model as xm
except ImportError:
    xm = None


def vae_encode(
    media_items: Tensor,
    vae: AutoencoderKL,
    split_size: int = 1,
    vae_per_channel_normalize=False,
) -> Tensor:
    """
    Encodes media items (images or videos) into latent representations using a specified VAE model.
    The function supports processing batches of images or video frames and can handle the processing
    in smaller sub-batches if needed.

    Args:
        media_items (Tensor): A torch Tensor containing the media items to encode. The expected
            shape is (batch_size, channels, height, width) for images or (batch_size, channels,
            frames, height, width) for videos.
        vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library,
            pre-configured and loaded with the appropriate model weights.
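        split_size (int, optional): The number of sub-batches to split the input batch into for encoding.
            If set to more than 1, the input is encoded in `split_size` equal sub-batches; the batch size
            (after folding frames into the batch for image VAEs) must be divisible by this value.
            Defaults to 1, which processes all items in a single batch.
        vae_per_channel_normalize (bool, optional): If True, normalize the latents per channel using the
            VAE's `mean_of_means` and `std_of_means` buffers; otherwise multiply them by
            `vae.config.scaling_factor`. Defaults to False.

    Returns:
        Tensor: A torch Tensor of sampled, normalized latents. Video inputs keep the
            (batch_size, channels, frames, height, width) layout; the latent channel count and the
            downscaled dimensions are determined by the VAE's configuration.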

    Examples:
        >>> import torch
        >>> from diffusers import AutoencoderKL
        >>> vae = AutoencoderKL.from_pretrained('your-model-name')
        >>> videos = torch.rand(10, 3, 8, 256, 256)  # Example tensor with 10 videos of 8 frames each.
        >>> latents = vae_encode(videos, vae)
        >>> print(latents.shape)  # Output shape will depend on the model's latent configuration.

    Note:
        For video inputs with a VAE other than `VideoAutoencoder` or `CausalVideoAutoencoder`, the frames
        are folded into the batch dimension and encoded frame by frame.
    """
    is_video_shaped = media_items.dim() == 5
    batch_size, channels = media_items.shape[0:2]

    if channels != 3:
        raise ValueError(f"Expects tensors with 3 channels, got {channels}.")

    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
    if split_size > 1:
        if len(media_items) % split_size != 0:
            raise ValueError(
                "Error: The batch size must be divisible by 'train.vae_bs_split'"
            )
        encode_bs = len(media_items) // split_size
        # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
        latents = []
        if media_items.device.type == "xla":
            xm.mark_step()
        for image_batch in media_items.split(encode_bs):
            latents.append(vae.encode(image_batch).latent_dist.sample())
            if media_items.device.type == "xla":
                xm.mark_step()
        latents = torch.cat(latents, dim=0)
    else:
        latents = vae.encode(media_items).latent_dist.sample()

    latents = normalize_latents(latents, vae, vae_per_channel_normalize)
    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
    return latents


def vae_decode(
    latents: Tensor,
    vae: AutoencoderKL,
    is_video: bool = True,
    split_size: int = 1,
    vae_per_channel_normalize=False,
    timestep=None,
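) -> Tensor:
    """
    Decodes latent representations back into media items using the given VAE.

    Args:
        latents (Tensor): Normalized latents of shape (batch_size, channels, frames, height, width)
            for videos or (batch_size, channels, height, width) for images.
        vae (AutoencoderKL): The VAE model used for decoding.
        is_video (bool, optional): For `VideoAutoencoder` and `CausalVideoAutoencoder`, whether the
            target shape has `frames * temporal_scale` frames (True) or a single frame (False).
            Defaults to True.
        split_size (int, optional): The number of sub-batches to decode separately. The batch size
            (after folding frames into the batch for image VAEs) must be divisible by this value.
            Defaults to 1.
        vae_per_channel_normalize (bool, optional): Whether the latents were normalized per channel;
            should match the value used in `vae_encode`. Defaults to False.
        timestep (Tensor, optional): Decoder timestep, forwarded to `vae.decode` only for
            `VideoAutoencoder` and `CausalVideoAutoencoder`.

    Returns:
        Tensor: The decoded media items, in the same batch/frame layout as the input latents.
    """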
    is_video_shaped = latents.dim() == 5
    batch_size = latents.shape[0]

    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        latents = rearrange(latents, "b c n h w -> (b n) c h w")
    if split_size > 1:
        if len(latents) % split_size != 0:
            raise ValueError(
                "Error: The batch size must be divisible by 'train.vae_bs_split'"
            )
        encode_bs = len(latents) // split_size
        image_batch = [
            _run_decoder(
                latent_batch, vae, is_video, vae_per_channel_normalize, timestep
            )
            for latent_batch in latents.split(encode_bs)
        ]
        images = torch.cat(image_batch, dim=0)
    else:
        images = _run_decoder(
            latents, vae, is_video, vae_per_channel_normalize, timestep
        )

    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
    return images


def _run_decoder(
    latents: Tensor,
    vae: AutoencoderKL,
    is_video: bool,
    vae_per_channel_normalize=False,
    timestep=None,
) -> Tensor:
    if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
        *_, fl, hl, wl = latents.shape
        temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
        latents = latents.to(vae.dtype)
        vae_decode_kwargs = {}
        if timestep is not None:
            vae_decode_kwargs["timestep"] = timestep
        image = vae.decode(
            un_normalize_latents(latents, vae, vae_per_channel_normalize),
            return_dict=False,
            target_shape=(
                1,
                3,
                fl * temporal_scale if is_video else 1,
                hl * spatial_scale,
                wl * spatial_scale,
            ),
            **vae_decode_kwargs,
        )[0]
    else:
        image = vae.decode(
            un_normalize_latents(latents, vae, vae_per_channel_normalize),
            return_dict=False,
        )[0]
    return image


def get_vae_size_scale_factor(vae: AutoencoderKL) -> Tuple[int, int, int]:
    if isinstance(vae, CausalVideoAutoencoder):
        spatial = vae.spatial_downscale_factor
        temporal = vae.temporal_downscale_factor
    else:
        down_blocks = len(
            [
                block
                for block in vae.encoder.down_blocks
                if isinstance(block.downsample, Downsample3D)
            ]
        )
        spatial = vae.config.patch_size * 2**down_blocks
        temporal = (
            vae.config.patch_size_t * 2**down_blocks
            if isinstance(vae, VideoAutoencoder)
            else 1
        )

    return (temporal, spatial, spatial)


def latent_to_pixel_coords(
    latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False
) -> Tensor:
    """
    Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
    configuration.

    Args:
        latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
            containing the latent corner coordinates of each token.
        vae (AutoencoderKL): The VAE model.
        causal_fix (bool): Whether to take into account the different temporal scale
            of the first frame. Default = False for backwards compatibility.
    Returns:
        Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
    """

    scale_factors = get_vae_size_scale_factor(vae)
    causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix
    pixel_coords = latent_to_pixel_coords_from_factors(
        latent_coords, scale_factors, causal_fix
    )
    return pixel_coords


def latent_to_pixel_coords_from_factors(
    latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False
) -> Tensor:
    pixel_coords = (
        latent_coords
        * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
    )
    if causal_fix:
        # Fix temporal scale for first frame to 1 due to causality
        pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
    return pixel_coords


def normalize_latents(
    latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
) -> Tensor:
    return (
        (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
        / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
        if vae_per_channel_normalize
        else latents * vae.config.scaling_factor
    )


def un_normalize_latents(
    latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
) -> Tensor:
    return (
        latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
        + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
        if vae_per_channel_normalize
        else latents / vae.config.scaling_factor
    )

```
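
`latent_to_pixel_coords_from_factors` multiplies each (frame, height, width) latent coordinate by the VAE's scale factors and, when `causal_fix` is set, shifts the frame coordinate by `1 - temporal_scale` and clamps it at 0, reflecting that the first latent frame of a causal VAE covers a single pixel frame (`latent_to_pixel_coords` only enables the fix for `CausalVideoAutoencoder`). The list-based sketch below reproduces that arithmetic without torch; the function name is hypothetical and the scale factors `(8, 32, 32)` are an illustrative assumption, not values read from a config.

```python
from typing import List, Sequence

Coords = List[List[List[int]]]  # [batch][axis (frame, height, width)][token]


def latent_to_pixel_coords_sketch(
    latent_coords: Coords, scale_factors: Sequence[int], causal_fix: bool = False
) -> Coords:
    """Pure-Python mirror of latent_to_pixel_coords_from_factors (illustrative only)."""
    pixel = [
        [[c * scale_factors[axis] for c in row] for axis, row in enumerate(sample)]
        for sample in latent_coords
    ]
    if causal_fix:
        for sample in pixel:
            # The first latent frame maps to one pixel frame, so later frames shift back.
            sample[0] = [max(c + 1 - scale_factors[0], 0) for c in sample[0]]
    return pixel


coords = [[[0, 1, 2], [0, 0, 1], [0, 1, 0]]]  # one sample, three tokens
print(latent_to_pixel_coords_sketch(coords, (8, 32, 32)))
# [[[0, 8, 16], [0, 0, 32], [0, 32, 0]]]
print(latent_to_pixel_coords_sketch(coords, (8, 32, 32), causal_fix=True))
# [[[0, 1, 9], [0, 0, 32], [0, 32, 0]]]
```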

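`normalize_latents` and `un_normalize_latents` are inverses when called with the same `vae_per_channel_normalize` flag: the per-channel branch standardizes each latent channel with `mean_of_means` and `std_of_means`, while the default branch multiplies by `vae.config.scaling_factor` (which `VideoAutoencoder.config` reports as 1.0) and divides by it on the way back. A list-based sketch of both branches, with hypothetical function names and made-up statistics:

```python
from typing import List, Sequence

Channels = List[List[float]]  # [channel][value]


def normalize(latents: Channels, means: Sequence[float], stds: Sequence[float]) -> Channels:
    # Per-channel branch of normalize_latents: (x - mean_c) / std_c
    return [[(v - m) / s for v in ch] for ch, m, s in zip(latents, means, stds)]


def un_normalize(latents: Channels, means: Sequence[float], stds: Sequence[float]) -> Channels:
    # Per-channel branch of un_normalize_latents: x * std_c + mean_c
    return [[v * s + m for v in ch] for ch, m, s in zip(latents, means, stds)]


def scale(latents: Channels, factor: float) -> Channels:
    # Default branch of normalize_latents: multiply by scaling_factor
    return [[v * factor for v in ch] for ch in latents]


def unscale(latents: Channels, factor: float) -> Channels:
    # Default branch of un_normalize_latents: divide by scaling_factor
    return [[v / factor for v in ch] for ch in latents]


latents = [[1.0, 3.0], [2.0, 6.0]]
means, stds = [1.0, 2.0], [2.0, 4.0]
normed = normalize(latents, means, stds)
print(normed)  # [[0.0, 1.0], [0.0, 1.0]]
print(un_normalize(normed, means, stds) == latents)  # True
print(unscale(scale(latents, 0.5), 0.5) == latents)  # True
```
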
## /ltx_video/models/autoencoders/video_autoencoder.py

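In the `Encoder` and `Decoder` below, patchify folds each `patch_size_t x patch_size x patch_size` patch into the channel axis (or multiplies channels by `patch_size**3` when `add_channel_padding` is set), and a block gets a down- or upsampling stage only if it is not the last block and `2**i` (encoder) or `2**(num_blocks - i - 1)` (decoder) clears the patch size. The pure-Python sketch below reproduces only that bookkeeping. The helper names are hypothetical, it assumes `patch_size` is a power of 2 as the docstrings ask, and it assumes blocks built with `add_downsample=False` hold no `Downsample3D` (that branch is outside this excerpt).

```python
from typing import List


def patchified_channels(channels: int, patch_size: int, patch_size_t: int,
                        add_channel_padding: bool = False) -> int:
    """Input channel count after patchify, as computed in Encoder.__init__ (hypothetical helper)."""
    if add_channel_padding:
        return channels * patch_size ** 3
    return channels * patch_size_t * patch_size ** 2


def downsample_flags(num_blocks: int, patch_size: int) -> List[bool]:
    # Encoder: add_downsample = not is_final_block and 2**i >= patch_size
    return [i != num_blocks - 1 and 2 ** i >= patch_size for i in range(num_blocks)]


def upsample_flags(num_blocks: int, patch_size: int) -> List[bool]:
    # Decoder: add_upsample = not is_final_block and 2**(num_blocks - i - 1) > patch_size
    return [i != num_blocks - 1 and 2 ** (num_blocks - i - 1) > patch_size
            for i in range(num_blocks)]


def downscale_factor(num_blocks: int, patch_size: int) -> int:
    # Mirrors Encoder.downscale_factor: 2 ** (number of downsampling blocks) * patch_size
    return 2 ** sum(downsample_flags(num_blocks, patch_size)) * patch_size


print(patchified_channels(3, patch_size=4, patch_size_t=1))  # 48
for p in (1, 2, 4, 8):
    print(p, sum(downsample_flags(4, p)), sum(upsample_flags(4, p)), downscale_factor(4, p))
```

For four blocks the loop shows the patch size replacing early downsampling stages one for one, so the encoder and decoder stay symmetric and the overall spatial factor stays at 8 for every patch size tried.
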
```py path="/ltx_video/models/autoencoders/video_autoencoder.py" 
import json
import os
from functools import partial
from types import SimpleNamespace
from typing import Any, Mapping, Optional, Tuple, Union

import torch
from einops import rearrange
from torch import nn
from torch.nn import functional

from diffusers.utils import logging

from ltx_video.utils.torch_utils import Identity
from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
from ltx_video.models.autoencoders.pixel_norm import PixelNorm
from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper

logger = logging.get_logger(__name__)


class VideoAutoencoder(AutoencoderKLWrapper):
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
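        # Accept str as well as os.PathLike: the "/" joins and .exists() below need a Path.
        from pathlib import Path

        pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        config_local_path = pretrained_model_name_or_path / "config.json"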
        config = cls.load_config(config_local_path, **kwargs)
        video_vae = cls.from_config(config)
        video_vae.to(kwargs["torch_dtype"])

        model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
        ckpt_state_dict = torch.load(model_local_path)
        video_vae.load_state_dict(ckpt_state_dict)

        statistics_local_path = (
            pretrained_model_name_or_path / "per_channel_statistics.json"
        )
        if statistics_local_path.exists():
            with open(statistics_local_path, "r") as file:
                data = json.load(file)
            transposed_data = list(zip(*data["data"]))
            data_dict = {
                col: torch.tensor(vals)
                for col, vals in zip(data["columns"], transposed_data)
            }
            video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
            video_vae.register_buffer(
                "mean_of_means",
                data_dict.get(
                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
                ),
            )

        return video_vae

    @staticmethod
    def from_config(config):
        assert (
            config["_class_name"] == "VideoAutoencoder"
        ), "config must have _class_name=VideoAutoencoder"
        if isinstance(config["dims"], list):
            config["dims"] = tuple(config["dims"])

        assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"

        double_z = config.get("double_z", True)
        latent_log_var = config.get(
            "latent_log_var", "per_channel" if double_z else "none"
        )
        use_quant_conv = config.get("use_quant_conv", True)

        if use_quant_conv and latent_log_var == "uniform":
            raise ValueError("uniform latent_log_var requires use_quant_conv=False")

        encoder = Encoder(
            dims=config["dims"],
            in_channels=config.get("in_channels", 3),
            out_channels=config["latent_channels"],
            block_out_channels=config["block_out_channels"],
            patch_size=config.get("patch_size", 1),
            latent_log_var=latent_log_var,
            norm_layer=config.get("norm_layer", "group_norm"),
            patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
            add_channel_padding=config.get("add_channel_padding", False),
        )

        decoder = Decoder(
            dims=config["dims"],
            in_channels=config["latent_channels"],
            out_channels=config.get("out_channels", 3),
            block_out_channels=config["block_out_channels"],
            patch_size=config.get("patch_size", 1),
            norm_layer=config.get("norm_layer", "group_norm"),
            patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
            add_channel_padding=config.get("add_channel_padding", False),
        )

        dims = config["dims"]
        return VideoAutoencoder(
            encoder=encoder,
            decoder=decoder,
            latent_channels=config["latent_channels"],
            dims=dims,
            use_quant_conv=use_quant_conv,
        )

    @property
    def config(self):
        return SimpleNamespace(
            _class_name="VideoAutoencoder",
            dims=self.dims,
            in_channels=self.encoder.conv_in.in_channels
            // (self.encoder.patch_size_t * self.encoder.patch_size**2),
            out_channels=self.decoder.conv_out.out_channels
            // (self.decoder.patch_size_t * self.decoder.patch_size**2),
            latent_channels=self.decoder.conv_in.in_channels,
            block_out_channels=[
                self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
                for i in range(len(self.encoder.down_blocks))
            ],
            scaling_factor=1.0,
            norm_layer=self.encoder.norm_layer,
            patch_size=self.encoder.patch_size,
            latent_log_var=self.encoder.latent_log_var,
            use_quant_conv=self.use_quant_conv,
            patch_size_t=self.encoder.patch_size_t,
            add_channel_padding=self.encoder.add_channel_padding,
        )

    @property
    def is_video_supported(self):
        """
        Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
        """
        return self.dims != 2

    @property
    def downscale_factor(self):
        return self.encoder.downscale_factor

    def to_json_string(self) -> str:
        return json.dumps(self.config.__dict__)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        model_keys = set(name for name, _ in self.named_parameters())

        key_mapping = {
            ".resnets.": ".res_blocks.",
            "downsamplers.0": "downsample",
            "upsamplers.0": "upsample",
        }

        converted_state_dict = {}
        for key, value in state_dict.items():
            for k, v in key_mapping.items():
                key = key.replace(k, v)

            if "norm" in key and key not in model_keys:
                logger.info(
                    f"Removing key {key} from state_dict as it is not present in the model"
                )
                continue

            converted_state_dict[key] = value

        super().load_state_dict(converted_state_dict, strict=strict)

    def last_layer(self):
        if hasattr(self.decoder, "conv_out"):
            if isinstance(self.decoder.conv_out, nn.Sequential):
                last_layer = self.decoder.conv_out[-1]
            else:
                last_layer = self.decoder.conv_out
        else:
            last_layer = self.decoder.layers[-1]
        return last_layer


class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
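        out_channels (`int`, *optional*, defaults to 3):
            The number of latent channels. `conv_out` emits extra channels for the log variance
            unless `latent_log_var` is `none`.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        latent_log_var (`str`, *optional*, defaults to `per_channel`):
            How the log variance is produced: `per_channel` (one log-variance channel per latent
            channel), `uniform` (a single log-variance channel repeated across the latent channels),
            or `none` (no log-variance channels).
        patch_size_t (`int`, *optional*):
            The temporal patch size. Defaults to `patch_size`.
        add_channel_padding (`bool`, *optional*, defaults to `False`):
            If `True`, the input channels are multiplied by `patch_size**3` rather than
            `patch_size_t * patch_size**2`.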
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]] = 3,
        in_channels: int = 3,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        patch_size: Union[int, Tuple[int]] = 1,
        norm_layer: str = "group_norm",  # group_norm, pixel_norm
        latent_log_var: str = "per_channel",
        patch_size_t: Optional[int] = None,
        add_channel_padding: Optional[bool] = False,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
        self.add_channel_padding = add_channel_padding
        self.layers_per_block = layers_per_block
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
        self.latent_log_var = latent_log_var
        if add_channel_padding:
            in_channels = in_channels * self.patch_size**3
        else:
            in_channels = in_channels * self.patch_size_t * self.patch_size**2
        self.in_channels = in_channels
        output_channel = block_out_channels[0]

        self.conv_in = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.down_blocks = nn.ModuleList([])

        for i in range(len(block_out_channels)):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = DownEncoderBlock3D(
                dims=dims,
                in_channels=input_channel,
                out_channels=output_channel,
                num_layers=self.layers_per_block,
                add_downsample=not is_final_block and 2**i >= patch_size,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_groups=norm_num_groups,
                norm_layer=norm_layer,
            )
            self.down_blocks.append(down_block)

        self.mid_block = UNetMidBlock3D(
            dims=dims,
            in_channels=block_out_channels[-1],
            num_layers=self.layers_per_block,
            resnet_eps=1e-6,
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
        )

        # out
        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[-1],
                num_groups=norm_num_groups,
                eps=1e-6,
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()
        self.conv_act = nn.SiLU()

        conv_out_channels = out_channels
        if latent_log_var == "per_channel":
            conv_out_channels *= 2
        elif latent_log_var == "uniform":
            conv_out_channels += 1
        elif latent_log_var != "none":
            raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
        self.conv_out = make_conv_nd(
            dims, block_out_channels[-1], conv_out_channels, 3, padding=1
        )

        self.gradient_checkpointing = False

    @property
    def downscale_factor(self):
        return (
            2
            ** len(
                [
                    block
                    for block in self.down_blocks
                    if isinstance(block.downsample, Downsample3D)
                ]
            )
            * self.patch_size
        )

    def forward(
        self, sample: torch.FloatTensor, return_features=False
    ) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""

        downsample_in_time = sample.shape[2] != 1

        # patchify
        patch_size_t = self.patch_size_t if downsample_in_time else 1
        sample = patchify(
            sample,
            patch_size_hw=self.patch_size,
            patch_size_t=patch_size_t,
            add_channel_padding=self.add_channel_padding,
        )

        sample = self.conv_in(sample)

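        # When checkpointing, wrap each module so that calling it runs under
        # non-reentrant activation checkpointing (which forwards keyword arguments).
        checkpoint_fn = (
            (
                lambda fn: partial(
                    torch.utils.checkpoint.checkpoint, fn, use_reentrant=False
                )
            )
            if self.gradient_checkpointing and self.training
            else lambda fn: fn
        )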

        if return_features:
            features = []
        for down_block in self.down_blocks:
            sample = checkpoint_fn(down_block)(
                sample, downsample_in_time=downsample_in_time
            )
            if return_features:
                features.append(sample)

        sample = checkpoint_fn(self.mid_block)(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if self.latent_log_var == "uniform":
            last_channel = sample[:, -1:, ...]
            num_dims = sample.dim()

            if num_dims == 4:
                # For shape (B, C, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            elif num_dims == 5:
                # For shape (B, C, F, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            else:
                raise ValueError(f"Invalid input shape: {sample.shape}")

        if return_features:
            features.append(sample[:, : self.latent_channels, ...])
            return sample, features
        return sample


class Decoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
    """

    def __init__(
        self,
        dims,
        in_channels: int = 3,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        patch_size: int = 1,
        norm_layer: str = "group_norm",
        patch_size_t: Optional[int] = None,
        add_channel_padding: Optional[bool] = False,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
        self.add_channel_padding = add_channel_padding
        self.layers_per_block = layers_per_block
        if add_channel_padding:
            out_channels = out_channels * self.patch_size**3
        else:
            out_channels = out_channels * self.patch_size_t * self.patch_size**2
        self.out_channels = out_channels

        self.conv_in = make_conv_nd(
            dims,
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        self.mid_block = UNetMidBlock3D(
            dims=dims,
            in_channels=block_out_channels[-1],
            num_layers=self.layers_per_block,
            resnet_eps=1e-6,
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
        )

        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i in range(len(reversed_block_out_channels)):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1

            up_block = UpDecoderBlock3D(
                dims=dims,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                add_upsample=not is_final_block
                and 2 ** (len(block_out_channels) - i - 1) > patch_size,
                resnet_eps=1e-6,
                resnet_groups=norm_num_groups,
                norm_layer=norm_layer,
            )
            self.up_blocks.append(up_block)

        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()

        self.conv_act = nn.SiLU()
        self.conv_out = make_conv_nd(
            dims, block_out_channels[0], out_channels, 3, padding=1
        )

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class."""
        assert target_shape is not None, "target_shape must be provided"
        upsample_in_time = sample.shape[2] < target_shape[2]

        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

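        # Wrap a module so that calling it runs under activation checkpointing;
        # this is only active while training with gradient checkpointing enabled.
        checkpoint_fn = (
            (
                lambda module: partial(
                    torch.utils.checkpoint.checkpoint, module, use_reentrant=False
                )
            )
            if self.gradient_checkpointing and self.training
            else (lambda module: module)
        )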

        sample = checkpoint_fn(self.mid_block)(sample)
        sample = sample.to(upscale_dtype)

        for up_block in self.up_blocks:
            sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        # un-patchify
        patch_size_t = self.patch_size_t if upsample_in_time else 1
        sample = unpatchify(
            sample,
            patch_size_hw=self.patch_size,
            patch_size_t=patch_size_t,
            add_channel_padding=self.add_channel_padding,
        )

        return sample


class DownEncoderBlock3D(nn.Module):
    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        add_downsample: bool = True,
        downsample_padding: int = 1,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        res_blocks = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            res_blocks.append(
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
            )

        self.res_blocks = nn.ModuleList(res_blocks)

        if add_downsample:
            self.downsample = Downsample3D(
                dims,
                out_channels,
                out_channels=out_channels,
                padding=downsample_padding,
            )
        else:
            self.downsample = Identity()

    def forward(
        self, hidden_states: torch.FloatTensor, downsample_in_time
    ) -> torch.FloatTensor:
        for resnet in self.res_blocks:
            hidden_states = resnet(hidden_states)

        hidden_states = self.downsample(
            hidden_states, downsample_in_time=downsample_in_time
        )

        return hidden_states


class UNetMidBlock3D(nn.Module):
    """
    A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.

    Args:
        dims (`int` or `Tuple[int, int]`): The convolution dimensionality passed to each `ResnetBlock3D`.
        in_channels (`int`): The number of input channels (also the number of output channels).
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks.
            If `None`, `min(in_channels // 4, 32)` is used.
        norm_layer (`str`, *optional*, defaults to `"group_norm"`): Either `"group_norm"` or `"pixel_norm"`.

    Returns:
        `torch.FloatTensor`: The output of the last residual block, a tensor with the same shape as the
        input (the channel count stays `in_channels`).

    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        resnet_groups = (
            resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        )

        self.res_blocks = nn.ModuleList(
            [
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        for resnet in self.res_blocks:
            hidden_states = resnet(hidden_states)

        return hidden_states


class UpDecoderBlock3D(nn.Module):
    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        add_upsample: bool = True,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        res_blocks = []

        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels

            res_blocks.append(
                ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channels,
                    out_channels=out_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
            )

        self.res_blocks = nn.ModuleList(res_blocks)

        if add_upsample:
            self.upsample = Upsample3D(
                dims=dims, channels=out_channels, out_channels=out_channels
            )
        else:
            self.upsample = Identity()

        self.resolution_idx = resolution_idx

    def forward(
        self, hidden_states: torch.FloatTensor, upsample_in_time=True
    ) -> torch.FloatTensor:
        for resnet in self.res_blocks:
            hidden_states = resnet(hidden_states)

        hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time)

        return hidden_states


class ResnetBlock3D(nn.Module):
    r"""
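    A 3D ResNet block: norm -> SiLU -> conv -> norm -> SiLU -> dropout -> conv, plus a residual shortcut.

    Parameters:
        dims (`int` or `Tuple[int, int]`): The convolution dimensionality passed to `make_conv_nd`.
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of output channels for the first conv layer. If `None`, same as `in_channels`.
        conv_shortcut (`bool`, *optional*, defaults to `False`): Stored as `use_conv_shortcut`; the shortcut
            is a linear projection whenever `in_channels != out_channels`, regardless of this flag.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        groups (`int`, *optional*, defaults to `32`): The number of groups used by both group normalization layers.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        norm_layer (`str`, *optional*, defaults to `"group_norm"`): Either `"group_norm"` or `"pixel_norm"`.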
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        groups: int = 32,
        eps: float = 1e-6,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        if norm_layer == "group_norm":
            self.norm1 = torch.nn.GroupNorm(
                num_groups=groups, num_channels=in_channels, eps=eps, affine=True
            )
        elif norm_layer == "pixel_norm":
            self.norm1 = PixelNorm()

        self.non_linearity = nn.SiLU()

        self.conv1 = make_conv_nd(
            dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        if norm_layer == "group_norm":
            self.norm2 = torch.nn.GroupNorm(
                num_groups=groups, num_channels=out_channels, eps=eps, affine=True
            )
        elif norm_layer == "pixel_norm":
            self.norm2 = PixelNorm()

        self.dropout = torch.nn.Dropout(dropout)

        self.conv2 = make_conv_nd(
            dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        self.conv_shortcut = (
            make_linear_nd(
                dims=dims, in_channels=in_channels, out_channels=out_channels
            )
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(
        self,
        input_tensor: torch.FloatTensor,
    ) -> torch.FloatTensor:
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)

        hidden_states = self.non_linearity(hidden_states)

        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)

        hidden_states = self.non_linearity(hidden_states)

        hidden_states = self.dropout(hidden_states)

        hidden_states = self.conv2(hidden_states)

        input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor


class Downsample3D(nn.Module):
    def __init__(
        self,
        dims,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        padding: int = 1,
    ):
        super().__init__()
        stride: int = 2
        self.padding = padding
        self.in_channels = in_channels
        self.dims = dims
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x, downsample_in_time=True):
        conv = self.conv
        if self.padding == 0:
            if self.dims == 2:
                padding = (0, 1, 0, 1)
            else:
                padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0)

            x = functional.pad(x, padding, mode="constant", value=0)

            if self.dims == (2, 1) and not downsample_in_time:
                return conv(x, skip_time_conv=True)

        return conv(x)


class Upsample3D(nn.Module):
    """
    An upsampling layer for 3D tensors of shape (B, C, D, H, W).

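    :param dims: 2 for Conv2d, 3 for Conv3d, or (2, 1) for Conv2d followed by Conv1d.
    :param channels: channels in the inputs.
    :param out_channels: channels in the outputs; defaults to `channels`.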
    """

    def __init__(self, dims, channels, out_channels=None):
        super().__init__()
        self.dims = dims
        self.channels = channels
        self.out_channels = out_channels or channels
        self.conv = make_conv_nd(
            dims, channels, self.out_channels, kernel_size=3, padding=1, bias=True
        )

    def forward(self, x, upsample_in_time):
        if self.dims == 2:
            x = functional.interpolate(
                x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
            )
        else:
            time_scale_factor = 2 if upsample_in_time else 1
            b, c, d, h, w = x.shape
            x = rearrange(x, "b c d h w -> (b d) c h w")
            # height and width interpolate
            x = functional.interpolate(
                x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
            )
            _, _, h, w = x.shape

            if not upsample_in_time and self.dims == (2, 1):
                x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w)
                return self.conv(x, skip_time_conv=True)

            # Second upsampling step: nearest-neighbor interpolation along the 'd' (time) dimension,
            # done as a 2D interpolation over a (1, d) grid
            x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b)

            # (b h w) c 1 d
            new_d = x.shape[-1] * time_scale_factor
            x = functional.interpolate(x, (1, new_d), mode="nearest")
            # (b h w) c 1 new_d
            x = rearrange(
                x, "(b h w) c 1 new_d  -> b c new_d h w", b=b, h=h, w=w, new_d=new_d
            )
            # b c d h w

        return self.conv(x)


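def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
    """Fold spatial (and, for 5-D input, temporal) patches into the channel axis.

    For 5-D input with `patch_size_hw > patch_size_t`, zero channels are prepended when
    `patch_size_t > 1` or `add_channel_padding` is set, growing the channel count by a
    factor of `patch_size_hw // patch_size_t`.
    """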
    if patch_size_hw == 1 and patch_size_t == 1:
        return x
    if x.dim() == 4:
        x = rearrange(
            x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
        )
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b c (f p) (h q) (w r) -> b (c p r q) f h w",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    if (
        (x.dim() == 5)
        and (patch_size_hw > patch_size_t)
        and (patch_size_t > 1 or add_channel_padding)
    ):
        channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1]
        padding_zeros = torch.zeros(
            x.shape[0],
            channels_to_pad,
            x.shape[2],
            x.shape[3],
            x.shape[4],
            device=x.device,
            dtype=x.dtype,
        )
        x = torch.cat([padding_zeros, x], dim=1)

    return x


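def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
    """Unfold channel-packed patches back into space (and time, for 5-D input).

    Under the same condition as `patchify`, only the leading `patch_size_t / patch_size_hw`
    fraction of the channels is kept before unfolding.
    """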
    if patch_size_hw == 1 and patch_size_t == 1:
        return x

    if (
        (x.dim() == 5)
        and (patch_size_hw > patch_size_t)
        and (patch_size_t > 1 or add_channel_padding)
    ):
        channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw))
        x = x[:, :channels_to_keep, :, :, :]

    if x.dim() == 4:
        x = rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
        )
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b (c p r q) f h w -> b c (f p) (h q) (w r)",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )

    return x


def create_video_autoencoder_config(
    latent_channels: int = 4,
):
    config = {
        "_class_name": "VideoAutoencoder",
        "dims": (
            2,
            1,
        ),  # 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
        "in_channels": 3,  # Number of input color channels (e.g., RGB)
        "out_channels": 3,  # Number of output color channels
        "latent_channels": latent_channels,  # Number of channels in the latent space representation
        "block_out_channels": [
            128,
            256,
            512,
            512,
        ],  # Number of output channels of each encoder / decoder inner block
        "patch_size": 1,
    }

    return config


def create_video_autoencoder_pathify4x4x4_config(
    latent_channels: int = 4,
):
    config = {
        "_class_name": "VideoAutoencoder",
        "dims": (
            2,
            1,
        ),  # 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
        "in_channels": 3,  # Number of input color channels (e.g., RGB)
        "out_channels": 3,  # Number of output color channels
        "latent_channels": latent_channels,  # Number of channels in the latent space representation
        "block_out_channels": [512]
        * 4,  # Number of output channels of each encoder / decoder inner block
        "patch_size": 4,
        "latent_log_var": "uniform",
    }

    return config


def create_video_autoencoder_pathify4x4_config(
    latent_channels: int = 4,
):
    config = {
        "_class_name": "VideoAutoencoder",
        "dims": 2,  # 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
        "in_channels": 3,  # Number of input color channels (e.g., RGB)
        "out_channels": 3,  # Number of output color channels
        "latent_channels": latent_channels,  # Number of channels in the latent space representation
        "block_out_channels": [512]
        * 4,  # Number of output channels of each encoder / decoder inner block
        "patch_size": 4,
        "norm_layer": "pixel_norm",
    }

    return config


def test_vae_patchify_unpatchify():
    import torch

    x = torch.randn(2, 3, 8, 64, 64)
    x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
    x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
    assert torch.allclose(x, x_unpatched)


def demo_video_autoencoder_forward_backward():
    # Configuration for the VideoAutoencoder
    config = create_video_autoencoder_pathify4x4x4_config()

    # Instantiate the VideoAutoencoder with the specified configuration
    video_autoencoder = VideoAutoencoder.from_config(config)

    print(video_autoencoder)

    # Print the total number of parameters in the video autoencoder
    total_params = sum(p.numel() for p in video_autoencoder.parameters())
    print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")

    # Create a mock input tensor simulating a batch of videos
    # Shape: (batch_size, channels, depth, height, width)
    # Here: 2 videos, each with 3 color channels, 8 frames, and 64x64 pixels per frame
    input_videos = torch.randn(2, 3, 8, 64, 64)

    # Forward pass: encode and decode the input videos
    latent = video_autoencoder.encode(input_videos).latent_dist.mode()
    print(f"input shape={input_videos.shape}")
    print(f"latent shape={latent.shape}")
    reconstructed_videos = video_autoencoder.decode(
        latent, target_shape=input_videos.shape
    ).sample

    print(f"reconstructed shape={reconstructed_videos.shape}")

    # Calculate the loss (e.g., mean squared error)
    loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)

    # Perform backward pass
    loss.backward()

    print(f"Demo completed with loss: {loss.item()}")


# Run the forward/backward demo when this module is executed as a script
if __name__ == "__main__":
    demo_video_autoencoder_forward_backward()

```
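
The `add_upsample` condition in `Decoder.__init__` decides which `UpDecoderBlock3D`s actually upsample. The pure-Python sketch below re-evaluates that condition for a given number of blocks and `patch_size`. `upsample_schedule` and `spatial_upscale_factor` are hypothetical helpers, not part of the module, and the factor counts height/width only: each enabled `Upsample3D` doubles H and W, and `unpatchify` then multiplies them by `patch_size`.

```python
from typing import List


def upsample_schedule(num_blocks: int, patch_size: int) -> List[bool]:
    """Re-evaluate the `add_upsample` condition from `Decoder.__init__` for each up block."""
    flags = []
    for i in range(num_blocks):
        is_final_block = i == num_blocks - 1
        flags.append(not is_final_block and 2 ** (num_blocks - i - 1) > patch_size)
    return flags


def spatial_upscale_factor(num_blocks: int, patch_size: int) -> int:
    """Height/width factor: 2x per enabled Upsample3D, then `patch_size` from unpatchify."""
    num_upsamples = sum(upsample_schedule(num_blocks, patch_size))
    return patch_size * 2**num_upsamples


if __name__ == "__main__":
    # len(block_out_channels) == 4 in the configs above
    print(upsample_schedule(4, 1))  # [True, True, True, False]
    print(upsample_schedule(4, 4))  # [True, False, False, False]
    print(spatial_upscale_factor(4, 1), spatial_upscale_factor(4, 4))  # 8 8
```

With four blocks, both `patch_size=1` and `patch_size=4` end at an 8x spatial factor: a larger patch size removes upsampling stages rather than adding to them.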

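For 5-D input, `patchify` and `unpatchify` are pure layout changes plus optional zero channels. Below is a minimal numpy sketch that mirrors their logic, using `reshape`/`transpose` in place of `einops.rearrange`. The function names are hypothetical, and it assumes frames, height and width are divisible by the corresponding patch sizes.

```python
import numpy as np


def patchify_5d(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
    """`b c (f p) (h q) (w r) -> b (c p r q) f h w`, then optional zero channels in front."""
    b, c, frames, height, width = x.shape
    p, q, r = patch_size_t, patch_size_hw, patch_size_hw
    x = x.reshape(b, c, frames // p, p, height // q, q, width // r, r)
    # axes: 0=b 1=c 2=f 3=p 4=h 5=q 6=w 7=r  ->  b c p r q f h w
    x = x.transpose(0, 1, 3, 7, 5, 2, 4, 6)
    x = x.reshape(b, c * p * r * q, frames // p, height // q, width // r)
    if patch_size_hw > patch_size_t and (patch_size_t > 1 or add_channel_padding):
        channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1]
        zeros = np.zeros((b, channels_to_pad) + x.shape[2:], dtype=x.dtype)
        x = np.concatenate([zeros, x], axis=1)  # padding is prepended, as in `patchify`
    return x


def unpatchify_5d(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
    """Keep the leading channels (if padded), then `b (c p r q) f h w -> b c (f p) (h q) (w r)`."""
    if patch_size_hw > patch_size_t and (patch_size_t > 1 or add_channel_padding):
        channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw))
        x = x[:, :channels_to_keep]
    b, channels, f, h, w = x.shape
    p, q, r = patch_size_t, patch_size_hw, patch_size_hw
    c = channels // (p * r * q)
    x = x.reshape(b, c, p, r, q, f, h, w)
    # axes: 0=b 1=c 2=p 3=r 4=q 5=f 6=h 7=w  ->  b c f p h q w r
    x = x.transpose(0, 1, 5, 2, 6, 4, 7, 3)
    return x.reshape(b, c, f * p, h * q, w * r)


if __name__ == "__main__":
    x = np.random.default_rng(0).standard_normal((2, 3, 8, 64, 64))
    y = patchify_5d(x, patch_size_hw=4, patch_size_t=4)  # hw == t: no padding
    print(y.shape)  # (2, 192, 2, 16, 16)
    print(np.array_equal(unpatchify_5d(y, 4, 4), x))  # True
    z = patchify_5d(x, patch_size_hw=4, patch_size_t=1, add_channel_padding=True)
    print(z.shape)  # (2, 192, 8, 16, 16): 144 zero channels + 48 real ones
```

`patchify` prepends the zero channels, while `unpatchify` keeps the leading `int(C * patch_size_t / patch_size_hw)` channels. The two therefore only invert each other when no padding is added. With padding, `unpatchify` returns whatever the leading channels hold, which in the decoder are outputs of `conv_out`.
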
## /ltx_video/models/transformers/__init__.py

```py path="/ltx_video/models/transformers/__init__.py" 

```

## /ltx_video/models/transformers/embeddings.py

```py path="/ltx_video/models/transformers/embeddings.py" 
# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py
import math

import numpy as np
import torch
from einops import rearrange
from torch import nn


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
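    Create sinusoidal timestep embeddings, matching the implementation in Denoising Diffusion Probabilistic Models.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param flip_sin_to_cos: if True, place the cosine half before the sine half.
    :param downscale_freq_shift: shift subtracted from `half_dim` in the denominator of the frequency exponent.
    :param scale: multiplier applied to the arguments of sin and cos.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.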
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f):
    """
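    Build 3D sine-cosine positional embeddings.

    grid: coordinates of shape (3, f * h * w), flattened in (f, h, w) order; grid[0], grid[1]
        and grid[2] hold the frame, height and width coordinates.
    w, h, f: the grid width, height and number of frames.
    Returns: an array of shape (f * h * w, embed_dim). `embed_dim` must be divisible by 3,
    and `embed_dim // 3` must be even.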
    """
    grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w)
    grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w)
    grid = grid.reshape([3, 1, w, h, f])
    pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
    pos_embed = pos_embed.transpose(1, 0, 2, 3)
    return rearrange(pos_embed, "h w f c -> (f h w) c")


def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 3 != 0:
        raise ValueError("embed_dim must be divisible by 3")

    # use one third of the dimensions for each axis (frames, height, width)
    emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0])  # (H*W*T, D/3)
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1])  # (H*W*T, D/3)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2])  # (H*W*T, D/3)

    emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1)  # (H*W*T, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
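    embed_dim: output dimension for each position.
    pos: array of positions with a leading axis of size 1, e.g. shape (1, M); only `pos[0]` is encoded.
    Returns: an array of shape pos.shape[1:] + (embed_dim,), sines in the first half and cosines in the second.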
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos_shape = pos.shape

    pos = pos.reshape(-1)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
    out = out.reshape([*pos_shape, -1])[0]

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=-1)  # (M, D)
    return emb


class SinusoidalPositionalEmbedding(nn.Module):
    """Apply positional information to a sequence of embeddings.

    Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
    them.

    Args:
        embed_dim (`int`): Dimension of the positional embedding.
        max_seq_length (`int`, *optional*, defaults to 32): Maximum sequence length to apply positional embeddings to.

    """

    def __init__(self, embed_dim: int, max_seq_length: int = 32):
        super().__init__()
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
        )
        pe = torch.zeros(1, max_seq_length, embed_dim)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        _, seq_length, _ = x.shape
        x = x + self.pe[:, :seq_length]
        return x

```
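
`get_timestep_embedding` builds `half_dim` geometric frequencies and concatenates the sine and cosine of `scale * t * freq`. A numpy port (float64 instead of torch, function name hypothetical) makes that layout easy to inspect:

```python
import math

import numpy as np


def timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False,
                       downscale_freq_shift=1.0, scale=1.0, max_period=10000):
    """numpy port of `get_timestep_embedding` (float64 instead of torch.float32)."""
    timesteps = np.asarray(timesteps, dtype=np.float64)
    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
    half_dim = embedding_dim // 2
    # frequency k = max_period ** (-k / (half_dim - downscale_freq_shift)), k = 0 .. half_dim-1
    exponent = -math.log(max_period) * np.arange(half_dim) / (half_dim - downscale_freq_shift)
    emb = scale * timesteps[:, None] * np.exp(exponent)[None, :]
    emb = np.concatenate([np.sin(emb), np.cos(emb)], axis=-1)
    if flip_sin_to_cos:  # cosines first, then sines
        emb = np.concatenate([emb[:, half_dim:], emb[:, :half_dim]], axis=-1)
    if embedding_dim % 2 == 1:  # one zero column on the right for odd sizes
        emb = np.pad(emb, ((0, 0), (0, 1)))
    return emb


if __name__ == "__main__":
    e = timestep_embedding([0, 10, 999], embedding_dim=8, flip_sin_to_cos=True)
    print(e.shape)  # (3, 8)
    print(e[0])  # [1. 1. 1. 1. 0. 0. 0. 0.]
```

At `t = 0` every sine is 0 and every cosine is 1, so with `flip_sin_to_cos=True` the first half of the row is ones. The first frequency is always 1 (`k = 0`), so the first sine column is simply `sin(scale * t)`.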

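`SinusoidalPositionalEmbedding` precomputes a `(1, max_seq_length, embed_dim)` table with sines on even channels and cosines on odd channels, and `forward` adds its first `seq_length` rows. Below is a numpy sketch of the same table (helper names hypothetical). Like the torch version, it assumes an even `embed_dim`; for an odd size, the cosine slice has one fewer column than `div_term` and the assignment fails.

```python
import math

import numpy as np


def sinusoidal_table(embed_dim, max_seq_length=32):
    """numpy version of the `pe` buffer built in SinusoidalPositionalEmbedding.__init__."""
    position = np.arange(max_seq_length)[:, None]  # (max_seq_length, 1)
    div_term = np.exp(np.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
    pe = np.zeros((1, max_seq_length, embed_dim))
    pe[0, :, 0::2] = np.sin(position * div_term)  # even channels: sine
    pe[0, :, 1::2] = np.cos(position * div_term)  # odd channels: cosine (needs even embed_dim)
    return pe


def add_positions(x, pe):
    """What `forward` does: add the first seq_length rows of the table."""
    _, seq_length, _ = x.shape
    return x + pe[:, :seq_length]


if __name__ == "__main__":
    pe = sinusoidal_table(embed_dim=6, max_seq_length=4)
    out = add_positions(np.zeros((2, 3, 6)), pe)
    print(out.shape)  # (2, 3, 6)
    print(pe[0, 0])  # [0. 1. 0. 1. 0. 1.]
```
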
## /ltx_video/models/transformers/symmetric_patchifier.py

```py path="/ltx_video/models/transformers/symmetric_patchifier.py" 
from abc import ABC, abstractmethod
from typing import Tuple

import torch
from diffusers.configuration_utils import ConfigMixin
from einops import rearrange
from torch import Tensor


class Patchifier(ConfigMixin, ABC):
    def __init__(self, patch_size: int):
        super().__init__()
        self._patch_size = (1, patch_size, patch_size)

    @abstractmethod
    def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
        raise NotImplementedError("Patchify method not implemented")

    @abstractmethod
    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        out_channels: int,
    ) -> Tensor:
        pass

    @property
    def patch_size(self):
        return self._patch_size

    def get_latent_coords(
        self, latent_num_frames, latent_height, latent_width, batch_size, device
    ):
        """
        Return a tensor of shape [batch_size, 3, num_patches] containing the top-left corner
        latent coordinates of each latent patch. The tensor is repeated for each batch element.
        """
        latent_sample_coords = torch.meshgrid(
            torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
            torch.arange(0, latent_height, self._patch_size[1], device=device),
            torch.arange(0, latent_width, self._patch_size[2], device=device),
            indexing="ij",
        )
        latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
        latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
        latent_coords = rearrange(
            latent_coords, "b c f h w -> b c (f h w)", b=batch_size
        )
        return latent_coords


class SymmetricPatchifier(Patchifier):
    def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
        b, _, f, h, w = latents.shape
        latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
        latents = rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )
        return latents, latent_coords

    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        out_channels: int,
    ) -> Tensor:
        output_height = output_height // self._patch_size[1]
        output_width = output_width // self._patch_size[2]
        latents = rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            h=output_height,
            w=output_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )
        return latents

```
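
`SymmetricPatchifier` turns a `(b, c, f, h, w)` latent into `(b, f * (h // p) * (w // p), c * p * p)` tokens using patch size `(1, p, p)`, and `get_latent_coords` returns the top-left `(frame, row, col)` of each patch in the same token order. The numpy sketch below covers both directions. The helper names are hypothetical, `reshape`/`transpose` stand in for `einops.rearrange`, and `h` and `w` are assumed to be divisible by `p`.

```python
import numpy as np


def symmetric_patchify(latents, p=1):
    """Patch size (1, p, p): `b c f (h p2) (w p3) -> b (f h w) (c p2 p3)` plus patch coords."""
    b, c, f, h, w = latents.shape
    x = latents.reshape(b, c, f, h // p, p, w // p, p)
    # axes: 0=b 1=c 2=f 3=h 4=p2 5=w 6=p3  ->  b f h w c p2 p3
    x = x.transpose(0, 2, 3, 5, 1, 4, 6).reshape(b, f * (h // p) * (w // p), c * p * p)
    # Like get_latent_coords: top-left (frame, row, col) of every patch, in (f h w) order
    grids = np.meshgrid(np.arange(0, f, 1), np.arange(0, h, p), np.arange(0, w, p), indexing="ij")
    coords = np.stack(grids, axis=0).reshape(3, -1)
    coords = np.repeat(coords[None], b, axis=0)  # (b, 3, num_patches)
    return x, coords


def symmetric_unpatchify(tokens, output_height, output_width, p=1):
    """`b (f h w) (c p q) -> b c f (h p) (w q)`; the frame count is inferred from the tokens."""
    b, num_tokens, dim = tokens.shape
    h, w = output_height // p, output_width // p
    f, c = num_tokens // (h * w), dim // (p * p)
    x = tokens.reshape(b, f, h, w, c, p, p)
    # axes: 0=b 1=f 2=h 3=w 4=c 5=p 6=q  ->  b c f h p w q
    return x.transpose(0, 4, 1, 2, 5, 3, 6).reshape(b, c, f, h * p, w * p)


if __name__ == "__main__":
    lat = np.random.default_rng(0).standard_normal((2, 4, 3, 8, 6))
    tokens, coords = symmetric_patchify(lat, p=2)
    print(tokens.shape, coords.shape)  # (2, 36, 16) (2, 3, 36)
    print(coords[0, :, :4].T.tolist())  # [[0, 0, 0], [0, 0, 2], [0, 0, 4], [0, 2, 0]]
    print(np.array_equal(symmetric_unpatchify(tokens, 8, 6, p=2), lat))  # True
```

The coordinates are in latent units and advance fastest along the width, matching the `(f h w)` token order. `unpatchify` only needs the output height and width because the frame count follows from the number of tokens.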

## /ltx_video/models/transformers/transformer3d.py

```py path="/ltx_video/models/transformers/transformer3d.py" 
# Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import os
import json
import glob
from pathlib import Path

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import PixArtAlphaTextProjection
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormSingle
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils import logging
from torch import nn
from safetensors import safe_open


from ltx_video.models.transformers.attention import BasicTransformerBlock
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy

from ltx_video.utils.diffusers_config_mapping import (
    diffusers_and_ours_config_mapping,
    make_hashable_key,
    TRANSFORMER_KEYS_RENAME_DICT,
)


logger = logging.get_logger(__name__)


@dataclass
class Transformer3DModelOutput(BaseOutput):
    """
    The output of [`Transformer3DModel`].

    Args:
        sample (`torch.FloatTensor`):
            The hidden states output conditioned on the `encoder_hidden_states` input.
    """

    sample: torch.FloatTensor


class Transformer3DModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        num_vector_embeds: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        adaptive_norm: str = "single_scale_shift",  # 'single_scale_shift' or 'single_scale'
        standardization_norm: str = "layer_norm",  # 'layer_norm' or 'rms_norm'
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        attention_type: str = "default",
        caption_channels: int = None,
        use_tpu_flash_attention: bool = False,  # if True uses the TPU attention offload ('flash attention')
        qk_norm: Optional[str] = None,
        positional_embedding_type: str = "rope",
        positional_embedding_theta: Optional[float] = None,
        positional_embedding_max_pos: Optional[List[int]] = None,
        timestep_scale_multiplier: Optional[float] = None,
        causal_temporal_positioning: bool = False,  # For backward compatibility, will be deprecated
    ):
        super().__init__()
        self.use_tpu_flash_attention = (
            use_tpu_flash_attention  # FIXME: push config down to the attention modules
        )
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.inner_dim = inner_dim
        self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True)
        self.positional_embedding_type = positional_embedding_type
        self.positional_embedding_theta = positional_embedding_theta
        self.positional_embedding_max_pos = positional_embedding_max_pos
        self.use_rope = self.positional_embedding_type == "rope"
        self.timestep_scale_multiplier = timestep_scale_multiplier

        if self.positional_embedding_type == "absolute":
            raise ValueError("Absolute positional embedding is no longer supported")
        elif self.positional_embedding_type == "rope":
            if positional_embedding_theta is None:
                raise ValueError(
                    "If `positional_embedding_type` is 'rope', `positional_embedding_theta` must also be defined"
                )
            if positional_embedding_max_pos is None:
                raise ValueError(
                    "If `positional_embedding_type` is 'rope', `positional_embedding_max_pos` must also be defined"
                )

        # 3. Define transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    double_self_attention=double_self_attention,
                    upcast_attention=upcast_attention,
                    adaptive_norm=adaptive_norm,
                    standardization_norm=standardization_norm,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    attention_type=attention_type,
                    use_tpu_flash_attention=use_tpu_flash_attention,
                    qk_norm=qk_norm,
                    use_rope=self.use_rope,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        self.scale_shift_table = nn.Parameter(
            torch.randn(2, inner_dim) / inner_dim**0.5
        )
        self.proj_out = nn.Linear(inner_dim, self.out_channels)

        self.adaln_single = AdaLayerNormSingle(
            inner_dim, use_additional_conditions=False
        )
        if adaptive_norm == "single_scale":
            self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True)

        self.caption_projection = None
        if caption_channels is not None:
            self.caption_projection = PixArtAlphaTextProjection(
                in_features=caption_channels, hidden_size=inner_dim
            )

        self.gradient_checkpointing = False

    def set_use_tpu_flash_attention(self):
        r"""
        Enable the TPU flash-attention kernel for this model and propagate the flag to every transformer block.
        """
        logger.info("ENABLE TPU FLASH ATTENTION -> TRUE")
        self.use_tpu_flash_attention = True
        # push config down to the attention modules
        for block in self.transformer_blocks:
            block.set_use_tpu_flash_attention()

    def create_skip_layer_mask(
        self,
        batch_size: int,
        num_conds: int,
        ptb_index: int,
        skip_block_list: Optional[List[int]] = None,
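    ):
        """
        Build a `(num_layers, batch_size * num_conds)` mask for skip-layer guidance.

        For each block index in `skip_block_list`, the columns `ptb_index`, `ptb_index + num_conds`, ...
        are set to 0 (skip); every other entry is 1. Returns `None` if `skip_block_list` is empty or `None`.
        """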
        if skip_block_list is None or len(skip_block_list) == 0:
            return None
        num_layers = len(self.transformer_blocks)
        mask = torch.ones(
            (num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype
        )
        for block_idx in skip_block_list:
            mask[block_idx, ptb_index::num_conds] = 0
        return mask

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def get_fractional_positions(self, indices_grid):
        fractional_positions = torch.stack(
            [
                indices_grid[:, i] / self.positional_embedding_max_pos[i]
                for i in range(3)
            ],
            dim=-1,
        )
        return fractional_positions

    def precompute_freqs_cis(self, indices_grid, spacing="exp"):
        dtype = torch.float32  # We need full precision in the freqs_cis computation.
        dim = self.inner_dim
        theta = self.positional_embedding_theta

        fractional_positions = self.get_fractional_positions(indices_grid)

        start = 1
        end = theta
        device = fractional_positions.device
        if spacing == "exp":
            indices = theta ** (
                torch.linspace(
                    math.log(start, theta),
                    math.log(end, theta),
                    dim // 6,
                    device=device,
                    dtype=dtype,
                )
            )
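            indices = indices.to(dtype=dtype)
        elif spacing == "exp_2":
            # Take exactly dim // 6 frequencies, like the other spacings;
            # torch.arange(0, dim, 6) would yield ceil(dim / 6) values when dim % 6 != 0.
            indices = 1.0 / theta ** (
                torch.arange(0, (dim // 6) * 6, 6, device=device) / dim
            )
            indices = indices.to(dtype=dtype)
        elif spacing == "linear":
            indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype)
        elif spacing == "sqrt":
            indices = torch.linspace(
                start**2, end**2, dim // 6, device=device, dtype=dtype
            ).sqrt()
        else:
            raise ValueError(f"Unsupported spacing: {spacing!r}")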

        indices = indices * math.pi / 2

        if spacing == "exp_2":
            freqs = (
                (indices * fractional_positions.unsqueeze(-1))
                .transpose(-1, -2)
                .flatten(2)
            )
        else:
            freqs = (
                (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
                .transpose(-1, -2)
                .flatten(2)
            )

        cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
        sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
        if dim % 6 != 0:
            cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
            sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
            cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
            sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
        return cos_freq.to(self.dtype), sin_freq.to(self.dtype)

    def load_state_dict(
        self,
        state_dict: Dict,
        *args,
        **kwargs,
    ):
        if any([key.startswith("model.diffusion_model.") for key in state_dict.keys()]):
            state_dict = {
                key.replace("model.diffusion_model.", ""): value
                for key, value in state_dict.items()
                if key.startswith("model.diffusion_model.")
            }
        super().load_state_dict(state_dict, *args, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        pretrained_model_path = Path(pretrained_model_path)
        if pretrained_model_path.is_dir():
            config_path = pretrained_model_path / "transformer" / "config.json"
            with open(config_path, "r") as f:
                config = make_hashable_key(json.load(f))

            assert config in diffusers_and_ours_config_mapping, (
                "Provided diffusers checkpoint config for transformer is not supported. "
                "We only support diffusers configs found in Lightricks/LTX-Video."
            )

            config = diffusers_and_ours_config_mapping[config]
            state_dict = {}
            ckpt_paths = (
                pretrained_model_path
                / "transformer"
                / "diffusion_pytorch_model*.safetensors"
            )
            dict_list = glob.glob(str(ckpt_paths))
            for dict_path in dict_list:
                part_dict = {}
                with safe_open(dict_path, framework="pt", device="cpu") as f:
                    for k in f.keys():
                        part_dict[k] = f.get_tensor(k)
                state_dict.update(part_dict)

            for key in list(state_dict.keys()):
                new_key = key
                for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
                    new_key = new_key.replace(replace_key, rename_key)
                state_dict[new_key] = state_dict.pop(key)

            with torch.device("meta"):
                transformer = cls.from_config(config)
            transformer.load_state_dict(state_dict, assign=True, strict=True)
        elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith(
            ".safetensors"
        ):
            comfy_single_file_state_dict = {}
            with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
                metadata = f.metadata()
                for k in f.keys():
                    comfy_single_file_state_dict[k] = f.get_tensor(k)
            configs = json.loads(metadata["config"])
            transformer_config = configs["transformer"]
            with torch.device("meta"):
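                transformer = cls.from_config(transformer_config)
            transformer.load_state_dict(comfy_single_file_state_dict, assign=True)
        else:
            raise ValueError(
                "Expected a diffusers checkpoint directory or a .safetensors file, "
                f"got {pretrained_model_path}"
            )
        return transformer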

    def forward(
        self,
        hidden_states: torch.Tensor,
        indices_grid: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        skip_layer_mask: Optional[torch.Tensor] = None,
        skip_layer_strategy: Optional[SkipLayerStrategy] = None,
        return_dict: bool = True,
    ):
        """
        The [`Transformer3DModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, num tokens, in_channels)`):
                Flattened (patchified) latent tokens. They are projected to the model width by `patchify_proj`.
            indices_grid (`torch.Tensor` of shape `(batch size, 3, num tokens)`):
                Per-token (frame, height, width) coordinates used to compute the rotary positional embeddings
                in `precompute_freqs_cis`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep (`torch.Tensor`, *optional*):
                Denoising timestep, scaled by `timestep_scale_multiplier` (if set) and embedded by `adaln_single`
                (`AdaLayerNormSingle`) to modulate every block and the output layer. Either one value per sample or
                one value per token. Although typed as optional, it must be provided: `forward` calls
                `timestep.flatten()`.
            class_labels (`torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            cross_attention_kwargs (`Dict[str, Any]`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            attention_mask (`torch.Tensor`, *optional*):
                Self-attention mask of shape `(batch, key_tokens)`: `1` keeps a token and `0` discards it. Unless TPU
                flash attention is enabled, a 2-D mask is converted into a bias that adds -10000.0 to the attention
                scores of discarded tokens.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats are supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`, it is interpreted as a mask and converted into a bias in the format above (unless
                TPU flash attention is enabled). This bias is added to the cross-attention scores.
            skip_layer_mask (`torch.Tensor`, *optional*):
                A mask of shape `(num_layers, batch)` that indicates which layers to skip. A `0` at position
                `[layer, batch_idx]` means the layer is skipped for that batch index.
            skip_layer_strategy (`SkipLayerStrategy`, *optional*, defaults to `None`):
                Selects how the layers marked in `skip_layer_mask` are perturbed when computing a perturbed
                prediction for spatiotemporal guidance.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`Transformer3DModelOutput`] instead of a plain tuple.

        Returns:
            If `return_dict` is True, a [`Transformer3DModelOutput`] is returned, otherwise a `tuple` whose first
            element is the sample tensor.
        """
        # The TPU flash-attention kernel consumes 2-D token masks directly, so no conversion is needed there.
        if not self.use_tpu_flash_attention:
            # Convert a 2-D mask into an additive bias with a singleton query_tokens dimension.
            #   An input with ndim != 2 is assumed to be a bias already and is left untouched.
            # Expects a mask of shape:
            #   [batch, key_tokens]
            # and adds a singleton query_tokens dimension:
            #   [batch,                    1, key_tokens]
            # which helps it broadcast as a bias over attention scores in one of the following shapes:
            #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
            #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
            if attention_mask is not None and attention_mask.ndim == 2:
                # assume that mask is expressed as:
                #   (1 = keep,      0 = discard)
                # convert mask into a bias that can be added to attention scores:
                #       (keep = +0,     discard = -10000.0)
                attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
                attention_mask = attention_mask.unsqueeze(1)

            # convert encoder_attention_mask to a bias the same way we do for attention_mask
            if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
                encoder_attention_mask = (
                    1 - encoder_attention_mask.to(hidden_states.dtype)
                ) * -10000.0
                encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        hidden_states = self.patchify_proj(hidden_states)

        if self.timestep_scale_multiplier:
            timestep = self.timestep_scale_multiplier * timestep

        freqs_cis = self.precompute_freqs_cis(indices_grid)

        batch_size = hidden_states.shape[0]
        timestep, embedded_timestep = self.adaln_single(
            timestep.flatten(),
            {"resolution": None, "aspect_ratio": None},
            batch_size=batch_size,
            hidden_dtype=hidden_states.dtype,
        )
        # Second dimension is 1 or number of tokens (if timestep_per_token)
        timestep = timestep.view(batch_size, -1, timestep.shape[-1])
        embedded_timestep = embedded_timestep.view(
            batch_size, -1, embedded_timestep.shape[-1]
        )

        # 2. Blocks
        if self.caption_projection is not None:
            batch_size = hidden_states.shape[0]
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(
                batch_size, -1, hidden_states.shape[-1]
            )

        for block_idx, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = (
                    {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    freqs_cis,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                    (
                        skip_layer_mask[block_idx]
                        if skip_layer_mask is not None
                        else None
                    ),
                    skip_layer_strategy,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    freqs_cis=freqs_cis,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                    skip_layer_mask=(
                        skip_layer_mask[block_idx]
                        if skip_layer_mask is not None
                        else None
                    ),
                    skip_layer_strategy=skip_layer_strategy,
                )

        # 3. Output
        scale_shift_values = (
            self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
        )
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
        hidden_states = self.norm_out(hidden_states)
        # Modulation
        hidden_states = hidden_states * (1 + scale) + shift
        hidden_states = self.proj_out(hidden_states)
        if not return_dict:
            return (hidden_states,)

        return Transformer3DModelOutput(sample=hidden_states)

```
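The mask handling at the top of `Transformer3DModel.forward` turns a 2-D keep/discard mask into an additive bias with a singleton query dimension. A minimal standalone sketch of that conversion, assuming a plain PyTorch softmax over a toy score tensor; the helper name `mask_to_bias` is hypothetical and not part of the package:

```python
import torch


def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Hypothetical helper mirroring the conversion in Transformer3DModel.forward.

    A 2-D mask (batch, key_tokens) with 1 = keep, 0 = discard becomes an additive
    bias (keep -> 0, discard -> -10000.0) of shape (batch, 1, key_tokens).
    Tensors of any other rank are assumed to be biases already and are returned as-is.
    """
    if mask.ndim != 2:
        return mask
    bias = (1 - mask.to(dtype)) * -10000.0
    return bias.unsqueeze(1)


# Two sequences of 4 key tokens; trailing zeros mark padding.
mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
bias = mask_to_bias(mask)  # (2, 1, 4)

# Toy scores shaped (batch, heads, query_tokens, key_tokens).
scores = torch.zeros(2, 2, 3, 4)
# Add a heads axis so the bias broadcasts per batch entry: (2, 1, 1, 4).
probs = torch.softmax(scores + bias.unsqueeze(1), dim=-1)
```

The extra `unsqueeze(1)` is only needed for this 4-D toy score tensor; how the model's attention modules broadcast the `(batch, 1, key_tokens)` bias is not shown in this file.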

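The default `spacing="exp"` path of `precompute_freqs_cis` builds one rotary angle per (axis, frequency) pair from fractional frame/height/width positions. A minimal sketch that reproduces that path outside the model, assuming float32 throughout and skipping the final cast to `self.dtype`; `rope_cos_sin` and its argument names are hypothetical:

```python
import math

import torch


def rope_cos_sin(indices_grid, max_pos, dim, theta=10000.0):
    """Hypothetical standalone version of the "exp" branch of precompute_freqs_cis.

    indices_grid: (batch, 3, num_tokens) frame/height/width coordinates.
    max_pos: per-axis normalizers, e.g. positional_embedding_max_pos.
    Returns cos and sin tensors of shape (batch, num_tokens, dim).
    """
    # Fractional position per axis: (batch, num_tokens, 3).
    fractional = torch.stack(
        [indices_grid[:, i].to(torch.float32) / max_pos[i] for i in range(3)], dim=-1
    )
    # dim // 6 frequencies per axis, log-spaced from 1 to theta, scaled by pi / 2.
    freqs_base = theta ** torch.linspace(
        math.log(1, theta), math.log(theta, theta), dim // 6, dtype=torch.float32
    )
    freqs_base = freqs_base * math.pi / 2
    # Map positions to [-1, 1]; (batch, tokens, 3, dim // 6) -> (batch, tokens, 3 * (dim // 6)).
    freqs = (freqs_base * (fractional.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
    # Repeat each angle twice so adjacent channel pairs share one angle.
    cos = freqs.cos().repeat_interleave(2, dim=-1)
    sin = freqs.sin().repeat_interleave(2, dim=-1)
    # Left-pad the dim % 6 leftover channels with an identity rotation (cos = 1, sin = 0).
    pad = dim % 6
    if pad:
        cos = torch.cat([torch.ones_like(cos[:, :, :pad]), cos], dim=-1)
        sin = torch.cat([torch.zeros_like(sin[:, :, :pad]), sin], dim=-1)
    return cos, sin


# Three tokens; the middle one sits exactly at the centre of every axis.
grid = torch.tensor([[[0, 10, 19], [0, 1024, 2047], [0, 1024, 2047]]])
cos, sin = rope_cos_sin(grid, max_pos=[20, 2048, 2048], dim=64)
```

With `dim=64`, each axis gets `64 // 6 = 10` frequencies, giving `3 * 10 * 2 = 60` channels plus 4 identity channels of padding. A token at fractional position 0.5 on every axis maps to 0 and therefore gets an identity rotation.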
## /ltx_video/pipelines/__init__.py

```py path="/ltx_video/pipelines/__init__.py" 

```

## /ltx_video/pipelines/crf_compressor.py

```py path="/ltx_video/pipelines/crf_compressor.py" 
import av
import torch
import io
import numpy as np


def _encode_single_frame(output_file, image_array: np.ndarray, crf):
    container = av.open(output_file, "w", format="mp4")
    try:
        stream = container.add_stream(
            "libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
        )
        stream.height = image_array.shape[0]
        stream.width = image_array.shape[1]
        av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
            format="yuv420p"
        )
        container.mux(stream.encode(av_frame))
        container.mux(stream.encode())
    finally:
        container.close()


def _decode_single_frame(video_file):
    container = av.open(video_file)
    try:
        stream = next(s for s in container.streams if s.type == "video")
        frame = next(container.decode(stream))
    finally:
        container.close()
    return frame.to_ndarray(format="rgb24")


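def compress(image: torch.Tensor, crf=29):
    """Round-trip a single RGB frame through H.264 (libx264) at the given CRF.

    `image` is an `(H, W, 3)` float tensor with values in `[0, 1]`. It is cropped to even
    height and width (required for `yuv420p`), encoded to an in-memory MP4, and decoded back
    to a tensor with the input's dtype and device. `crf=0` returns `image` unchanged.
    """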
    if crf == 0:
        return image

    image_array = (
        (image[: (image.shape[0] // 2) * 2, : (image.shape[1] // 2) * 2] * 255.0)
        .byte()
        .cpu()
        .numpy()
    )
    with io.BytesIO() as output_file:
        _encode_single_frame(output_file, image_array, crf)
        video_bytes = output_file.getvalue()
    with io.BytesIO(video_bytes) as video_file:
        image_array = _decode_single_frame(video_file)
    tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
    return tensor

```
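Before encoding, `compress` crops the frame to even height and width, because 4:2:0 chroma subsampling (`yuv420p`) halves both chroma dimensions and libx264 requires even sizes for that format, and then quantizes it to 8 bits. The sketch below isolates that step in plain PyTorch so it runs without PyAV; `prepare_frame_for_h264` is a hypothetical name:

```python
import torch


def prepare_frame_for_h264(image: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper reproducing the preprocessing step of `compress`.

    Crops an (H, W, 3) float image in [0, 1] to even H and W, then scales it to
    [0, 255] and converts it to uint8 (the cast truncates toward zero).
    """
    even_h = (image.shape[0] // 2) * 2
    even_w = (image.shape[1] // 2) * 2
    return (image[:even_h, :even_w] * 255.0).byte()


frame = torch.rand(5, 7, 3)  # odd height and width
frame[0, 0] = 1.0
frame[0, 1] = 0.0
prepared = prepare_frame_for_h264(frame)
```

One consequence worth knowing: when `crf != 0` and the input has an odd height or width, `compress` returns a tensor one row or column smaller than its input.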

## /ltx_video/schedulers/__init__.py

```py path="/ltx_video/schedulers/__init__.py" 

```

## /ltx_video/utils/__init__.py

```py path="/ltx_video/utils/__init__.py" 

```

## /ltx_video/utils/diffusers_config_mapping.py

```py path="/ltx_video/utils/diffusers_config_mapping.py" 
def make_hashable_key(dict_key):
    def convert_value(value):
        if isinstance(value, list):
            return tuple(value)
        elif isinstance(value, dict):
            return tuple(sorted((k, convert_value(v)) for k, v in value.items()))
        else:
            return value

    return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items()))


DIFFUSERS_SCHEDULER_CONFIG = {
    "_class_name": "FlowMatchEulerDiscreteScheduler",
    "_diffusers_version": "0.32.0.dev0",
    "base_image_seq_len": 1024,
    "base_shift": 0.95,
    "invert_sigmas": False,
    "max_image_seq_len": 4096,
    "max_shift": 2.05,
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": 0.1,
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}
DIFFUSERS_TRANSFORMER_CONFIG = {
    "_class_name": "LTXVideoTransformer3DModel",
    "_diffusers_version": "0.32.0.dev0",
    "activation_fn": "gelu-approximate",
    "attention_bias": True,
    "attention_head_dim": 64,
    "attention_out_bias": True,
    "caption_channels": 4096,
    "cross_attention_dim": 2048,
    "in_channels": 128,
    "norm_elementwise_affine": False,
    "norm_eps": 1e-06,
    "num_attention_heads": 32,
    "num_layers": 28,
    "out_channels": 128,
    "patch_size": 1,
    "patch_size_t": 1,
    "qk_norm": "rms_norm_across_heads",
}
DIFFUSERS_VAE_CONFIG = {
    "_class_name": "AutoencoderKLLTXVideo",
    "_diffusers_version": "0.32.0.dev0",
    "block_out_channels": [128, 256, 512, 512],
    "decoder_causal": False,
    "encoder_causal": True,
    "in_channels": 3,
    "latent_channels": 128,
    "layers_per_block": [4, 3, 3, 3, 4],
    "out_channels": 3,
    "patch_size": 4,
    "patch_size_t": 1,
    "resnet_norm_eps": 1e-06,
    "scaling_factor": 1.0,
    "spatio_temporal_scaling": [True, True, True, False],
}

OURS_SCHEDULER_CONFIG = {
    "_class_name": "RectifiedFlowScheduler",
    "_diffusers_version": "0.25.1",
    "num_train_timesteps": 1000,
    "shifting": "SD3",
    "base_resolution": None,
    "target_shift_terminal": 0.1,
}

OURS_TRANSFORMER_CONFIG = {
    "_class_name": "Transformer3DModel",
    "_diffusers_version": "0.25.1",
    "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256",
    "activation_fn": "gelu-approximate",
    "attention_bias": True,
    "attention_head_dim": 64,
    "attention_type": "default",
    "caption_channels": 4096,
    "cross_attention_dim": 2048,
    "double_self_attention": False,
    "dropout": 0.0,
    "in_channels": 128,
    "norm_elementwise_affine": False,
    "norm_eps": 1e-06,
    "norm_num_groups": 32,
    "num_attention_heads": 32,
    "num_embeds_ada_norm": 1000,
    "num_layers": 28,
    "num_vector_embeds": None,
    "only_cross_attention": False,
    "out_channels": 128,
    "project_to_2d_pos": True,
    "upcast_attention": False,
    "use_linear_projection": False,
    "qk_norm": "rms_norm",
    "standardization_norm": "rms_norm",
    "positional_embedding_type": "rope",
    "positional_embedding_theta": 10000.0,
    "positional_embedding_max_pos": [20, 2048, 2048],
    "timestep_scale_multiplier": 1000,
}
OURS_VAE_CONFIG = {
    "_class_name": "CausalVideoAutoencoder",
    "dims": 3,
    "in_channels": 3,
    "out_channels": 3,
    "latent_channels": 128,
    "blocks": [
        ["res_x", 4],
        ["compress_all", 1],
        ["res_x_y", 1],
        ["res_x", 3],
        ["compress_all", 1],
        ["res_x_y", 1],
        ["res_x", 3],
        ["compress_all", 1],
        ["res_x", 3],
        ["res_x", 4],
    ],
    "scaling_factor": 1.0,
    "norm_layer": "pixel_norm",
    "patch_size": 4,
    "latent_log_var": "uniform",
    "use_quant_conv": False,
    "causal_decoder": False,
}


diffusers_and_ours_config_mapping = {
    make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG,
    make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG,
    make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG,
}


TRANSFORMER_KEYS_RENAME_DICT = {
    "proj_in": "patchify_proj",
    "time_embed": "adaln_single",
    "norm_q": "q_norm",
    "norm_k": "k_norm",
}


VAE_KEYS_RENAME_DICT = {
    "decoder.up_blocks.3.conv_in": "decoder.up_blocks.7",
    "decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8",
    "decoder.up_blocks.3": "decoder.up_blocks.9",
    "decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5",
    "decoder.up_blocks.2.conv_in": "decoder.up_blocks.4",
    "decoder.up_blocks.2": "decoder.up_blocks.6",
    "decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2",
    "decoder.up_blocks.1": "decoder.up_blocks.3",
    "decoder.up_blocks.0": "decoder.up_blocks.1",
    "decoder.mid_block": "decoder.up_blocks.0",
    "encoder.down_blocks.3": "encoder.down_blocks.8",
    "encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7",
    "encoder.down_blocks.2": "encoder.down_blocks.6",
    "encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4",
    "encoder.down_blocks.1.conv_out": "encoder.down_blocks.5",
    "encoder.down_blocks.1": "encoder.down_blocks.3",
    "encoder.down_blocks.0.conv_out": "encoder.down_blocks.2",
    "encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1",
    "encoder.down_blocks.0": "encoder.down_blocks.0",
    "encoder.mid_block": "encoder.down_blocks.9",
    "conv_shortcut.conv": "conv_shortcut",
    "resnets": "res_blocks",
    "norm3": "norm3.norm",
    "latents_mean": "per_channel_statistics.mean-of-means",
    "latents_std": "per_channel_statistics.std-of-means",
}

```
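`make_hashable_key` turns a JSON config into a nested, sorted tuple so it can serve as a dictionary key; `Transformer3DModel.from_pretrained` uses it to recognize a diffusers `transformer/config.json` through `diffusers_and_ours_config_mapping`. A small sketch with a hypothetical config, copying the function so it runs standalone:

```python
import json


def make_hashable_key(dict_key):
    # Same logic as make_hashable_key above: lists -> tuples, dicts -> sorted tuples.
    def convert_value(value):
        if isinstance(value, list):
            return tuple(value)
        elif isinstance(value, dict):
            return tuple(sorted((k, convert_value(v)) for k, v in value.items()))
        else:
            return value

    return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items()))


known = {"patch_size": 1, "block_out_channels": [128, 256], "extra": {"b": 2, "a": 1}}
mapping = {make_hashable_key(known): "ours"}

# A config read from disk with a different key order still hits the same entry.
loaded = json.loads(
    '{"extra": {"a": 1, "b": 2}, "block_out_channels": [128, 256], "patch_size": 1}'
)
hit = mapping.get(make_hashable_key(loaded))
# Any changed value (here patch_size) misses; from_pretrained turns that into an assertion error.
miss = mapping.get(make_hashable_key({**loaded, "patch_size": 2}))
```

Lookup requires an exact match of every field, including `_diffusers_version`. Lists are converted with `tuple(value)` without recursing, so a list of lists or dicts would still be unhashable; the diffusers configs above only contain flat lists.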

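`Transformer3DModel.from_pretrained` maps diffusers parameter names to this repository's names by running every `TRANSFORMER_KEYS_RENAME_DICT` substitution over each key in dictionary order. The sketch below reproduces that loop without mutating its input; `rename_keys` and the toy keys are hypothetical. The last two calls show why overlapping patterns (such as `decoder.up_blocks.3.conv_in` and `decoder.up_blocks.3` in `VAE_KEYS_RENAME_DICT`) would need to be listed most-specific first if applied the same way:

```python
TRANSFORMER_KEYS_RENAME_DICT = {
    "proj_in": "patchify_proj",
    "time_embed": "adaln_single",
    "norm_q": "q_norm",
    "norm_k": "k_norm",
}


def rename_keys(state_dict: dict, rename_map: dict) -> dict:
    """Apply every substring substitution in `rename_map`, in insertion order, to each key."""
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in rename_map.items():
            new_key = new_key.replace(old, new)
        renamed[new_key] = value
    return renamed


diffusers_sd = {
    "proj_in.weight": 0,
    "time_embed.linear.bias": 1,
    "transformer_blocks.0.attn1.norm_q.weight": 2,
}
ours_sd = rename_keys(diffusers_sd, TRANSFORMER_KEYS_RENAME_DICT)

# Order matters when one pattern is a prefix of another (hypothetical keys):
good = rename_keys({"up.3.conv_in.w": 0}, {"up.3.conv_in": "up.7", "up.3": "up.9"})
bad = rename_keys({"up.3.conv_in.w": 0}, {"up.3": "up.9", "up.3.conv_in": "up.7"})
```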
## /ltx_video/utils/skip_layer_strategy.py

```py path="/ltx_video/utils/skip_layer_strategy.py" 
from enum import Enum, auto


class SkipLayerStrategy(Enum):
    AttentionSkip = auto()
    AttentionValues = auto()
    Residual = auto()
    TransformerBlock = auto()

```
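`SkipLayerStrategy` selects how a block is perturbed; which blocks and which batch entries are affected comes from the mask built by `Transformer3DModel.create_skip_layer_mask`, and `forward` passes row `skip_layer_mask[block_idx]` to each block. A standalone copy of that construction (on CPU in float32 instead of the model's device and dtype, with `num_layers` passed explicitly) shows the resulting layout:

```python
from typing import List, Optional

import torch


def create_skip_layer_mask(
    num_layers: int,
    batch_size: int,
    num_conds: int,
    ptb_index: int,
    skip_block_list: Optional[List[int]] = None,
) -> Optional[torch.Tensor]:
    """Standalone copy of Transformer3DModel.create_skip_layer_mask (CPU, float32)."""
    if not skip_block_list:
        return None
    mask = torch.ones(num_layers, batch_size * num_conds)
    for block_idx in skip_block_list:
        # Zero every num_conds-th column, starting at ptb_index, for this block.
        mask[block_idx, ptb_index::num_conds] = 0
    return mask


# 4 blocks, 2 samples x 3 conditions, perturbed condition at index 2, skip blocks 1 and 3.
mask = create_skip_layer_mask(4, batch_size=2, num_conds=3, ptb_index=2, skip_block_list=[1, 3])
# forward() hands mask[block_idx] (one value per batch entry) to each block.
per_block = [mask[i] for i in range(mask.shape[0])]
```

A `0` in row `i`, column `j` marks batch entry `j` as skipped in block `i`; how each strategy then alters the block is implemented in the attention and block modules, which are not shown in this file.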

## /ltx_video/utils/torch_utils.py

```py path="/ltx_video/utils/torch_utils.py" 
import torch
from torch import nn


def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    elif dims_to_append == 0:
        return x
    return x[(...,) + (None,) * dims_to_append]


class Identity(nn.Module):
    """A placeholder identity operator that is argument-insensitive."""

    def __init__(self, *args, **kwargs) -> None:  # pylint: disable=unused-argument
        super().__init__()

    # pylint: disable=unused-argument
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return x

```
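`append_dims` adds trailing singleton axes so a per-sample tensor broadcasts against a higher-rank one. A short usage sketch with a hypothetical video-latent shape and per-sample scale, copying the function so it runs standalone; where the package itself calls this helper is not shown here:

```python
import torch


def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    # Same as ltx_video.utils.torch_utils.append_dims, copied to keep the example standalone.
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    elif dims_to_append == 0:
        return x
    return x[(...,) + (None,) * dims_to_append]


# Hypothetical use: scale a (batch, channels, frames, height, width) latent by a per-sample value.
latents = torch.ones(2, 4, 3, 8, 8)
sigma = torch.tensor([0.5, 2.0])
scaled = latents * append_dims(sigma, latents.ndim)  # sigma viewed as (2, 1, 1, 1, 1)
```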

## /pyproject.toml

```toml path="/pyproject.toml" 
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "ltx-video"
version = "0.1.2"
description = "A package for LTX-Video model"
authors = [
    { name = "LTX-Video Team", email = "ltx-video@lightricks.com" }
]
requires-python = ">=3.10"
readme = "README.md"
classifiers = [
    "Programming Language :: Python :: 3",
    "Operating System :: OS Independent"
]
dependencies = [
    "torch>=2.1.0",
    "diffusers>=0.28.2",
    "transformers>=4.47.2,<4.52.0",
    "sentencepiece>=0.1.96",
    "huggingface-hub~=0.30",
    "einops",
    "timm"
]

[project.optional-dependencies]
inference = [
    "imageio[ffmpeg]",
    "av",
    "torchvision"
]
test = [
    "pytest",
]

[tool.setuptools.packages.find]
include = ["ltx_video*"]

[tool.setuptools.package-data]
ltx_video = ["configs/*.yaml"]

```

## /tests/conftest.py

```py path="/tests/conftest.py" 
import json
import pytest
import safetensors.torch
import torch

from ltx_video.models.autoencoders.causal_video_autoencoder import (
    CausalVideoAutoencoder,
    create_video_autoencoder_demo_config,
    PER_CHANNEL_STATISTICS_PREFIX,
)
from ltx_video.models.transformers.transformer3d import Transformer3DModel


def pytest_make_parametrize_id(config, val, argname):
    if isinstance(val, str):
        return f"{argname}-{val}"
    return f"{argname}-{repr(val)}"


@pytest.fixture
def num_latent_channels():
    return 16


@pytest.fixture
def video_autoencoder(num_latent_channels):
    config = create_video_autoencoder_demo_config(latent_channels=num_latent_channels)
    model = CausalVideoAutoencoder.from_config(config)
    model.eval().to(torch.bfloat16)
    return model


@pytest.fixture
def transformer_config(num_latent_channels):
    transformer_config = {
        "activation_fn": "gelu-approximate",
        "attention_bias": True,
        "attention_head_dim": 12,
        "attention_type": "default",
        "caption_channels": 4096,
        "cross_attention_dim": 192,
        "double_self_attention": False,
        "dropout": 0.0,
        "in_channels": num_latent_channels,
        "norm_elementwise_affine": False,
        "norm_eps": 1e-06,
        "norm_num_groups": 32,
        "num_attention_heads": 16,
        "num_embeds_ada_norm": 1000,
        "num_layers": 2,
        "num_vector_embeds": None,
        "only_cross_attention": False,
        "out_channels": num_latent_channels,
        "upcast_attention": False,
        "use_linear_projection": False,
        "qk_norm": "rms_norm",
        "standardization_norm": "rms_norm",
        "positional_embedding_type": "rope",
        "positional_embedding_theta": 10000.0,
        "positional_embedding_max_pos": [120, 1, 1],
        "timestep_scale_multiplier": 1000,
    }
    return transformer_config


@pytest.fixture
def synthetic_ckpt_path(
    tmp_path, video_autoencoder, num_latent_channels, transformer_config
):
    # Create transformer
    transformer = Transformer3DModel.from_config(transformer_config)
    transformer.to(torch.bfloat16)

    # Prepare configs and state dicts
    configs = {"transformer": transformer_config, "vae": vars(video_autoencoder.config)}
    transformer_sd = transformer.state_dict()
    transformer_sd = {
        "model.diffusion_model." + key: value for key, value in transformer_sd.items()
    }

    # Prepare VAE state dict with per-channel statistics
    vae_sd = video_autoencoder.state_dict()
    vae_sd[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = torch.rand(
        num_latent_channels,
    )
    vae_sd[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = torch.rand(
        num_latent_channels,
    )
    vae_sd = {"vae." + key: value for key, value in vae_sd.items()}

    out_file_path = f"{tmp_path}/test_ckpt.safetensors"
    safetensors.torch.save_file(
        {**transformer_sd, **vae_sd},
        out_file_path,
        metadata={"config": json.dumps(configs)},
    )
    return out_file_path

```
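The `synthetic_ckpt_path` fixture writes the single-file layout that `Transformer3DModel.from_pretrained` reads: transformer weights under `model.diffusion_model.`, VAE weights under `vae.`, and both configs as JSON in the `config` metadata entry. `Transformer3DModel.load_state_dict` keeps only the keys with the transformer prefix and strips it. A sketch of the same split for both components, with hypothetical helper and key names (the VAE loader is not shown in this file, so its handling of `vae.` is an assumption here):

```python
import torch

TRANSFORMER_PREFIX = "model.diffusion_model."  # prefix written by synthetic_ckpt_path
VAE_PREFIX = "vae."


def split_single_file_state_dict(state_dict):
    """Hypothetical helper: split a combined single-file state dict by component prefix."""
    transformer_sd = {
        k[len(TRANSFORMER_PREFIX):]: v
        for k, v in state_dict.items()
        if k.startswith(TRANSFORMER_PREFIX)
    }
    vae_sd = {
        k[len(VAE_PREFIX):]: v for k, v in state_dict.items() if k.startswith(VAE_PREFIX)
    }
    return transformer_sd, vae_sd


combined = {
    "model.diffusion_model.proj_out.weight": torch.zeros(2, 2),
    "vae.per_channel_statistics.mean-of-means": torch.zeros(2),
}
transformer_sd, vae_sd = split_single_file_state_dict(combined)
```

Slicing off the prefix is slightly stricter than the `key.replace(...)` call in `load_state_dict`, which would also remove the substring if it appeared later in a key; for the keys the fixture writes, both behave the same.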

