felixtaubner/cap4d/main 585k tokens More Tools
```
├── .gitignore (900 tokens)
├── LICENSE (omitted)
├── README.md (2k tokens)
├── assets/
   ├── banner.gif
├── cap4d/
   ├── datasets/
      ├── utils.py (1800 tokens)
   ├── flame/
      ├── flame.py (1400 tokens)
      ├── mouth.py (700 tokens)
   ├── inference/
      ├── data/
         ├── generation_data.py (700 tokens)
         ├── inference_data.py (900 tokens)
         ├── reference_data.py (400 tokens)
      ├── generate_images.py (1300 tokens)
      ├── utils.py (1000 tokens)
   ├── mmdm/
      ├── conditioning/
         ├── cap4dcond.py (1300 tokens)
         ├── mesh2img.py (2.4k tokens)
      ├── mmdm.py (3.2k tokens)
      ├── modules/
         ├── module.py (400 tokens)
      ├── net/
         ├── attention.py (2.5k tokens)
         ├── mmdm_unet.py (800 tokens)
      ├── sampler.py (2.3k tokens)
      ├── utils.py (300 tokens)
├── configs/
   ├── avatar/
      ├── debug.yaml (300 tokens)
      ├── default.yaml (300 tokens)
      ├── high_quality.yaml (300 tokens)
      ├── low_quality.yaml (300 tokens)
      ├── medium_quality.yaml (300 tokens)
   ├── generation/
      ├── debug.yaml
      ├── default.yaml
      ├── high_quality.yaml
      ├── low_quality.yaml
      ├── medium_quality.yaml
   ├── mmdm/
      ├── cap4d_mmdm_final.yaml (900 tokens)
├── controlnet/
   ├── cldm/
      ├── ddim_hacked.py (3.3k tokens)
      ├── hack.py (700 tokens)
      ├── logger.py (900 tokens)
      ├── model.py (200 tokens)
   ├── ldm/
      ├── data/
         ├── __init__.py
         ├── util.py (100 tokens)
      ├── models/
         ├── autoencoder.py (1700 tokens)
         ├── diffusion/
            ├── __init__.py
            ├── ddim.py (3.6k tokens)
            ├── ddpm.py (17k tokens)
            ├── dpm_solver/
               ├── __init__.py
               ├── dpm_solver.py (13.2k tokens)
               ├── sampler.py (600 tokens)
            ├── plms.py (2.6k tokens)
            ├── sampling_util.py (200 tokens)
      ├── modules/
         ├── attention.py (2.4k tokens)
         ├── diffusionmodules/
            ├── __init__.py
            ├── model.py (6.9k tokens)
            ├── openaimodel.py (6.7k tokens)
            ├── upscaling.py (700 tokens)
            ├── util.py (2000 tokens)
         ├── distributions/
            ├── __init__.py
            ├── distributions.py (600 tokens)
         ├── ema.py (600 tokens)
         ├── encoders/
            ├── __init__.py
            ├── modules.py (1700 tokens)
      ├── util.py (1300 tokens)
   ├── share.py
├── data/
   ├── assets/
      ├── datasets/
         ├── gen_data.npz
      ├── flame/
         ├── blink_blendshape.npy
         ├── cap4d_avatar_template.obj (206.3k tokens)
         ├── cap4d_flame_template.obj (117.3k tokens)
         ├── deformable_verts.txt (4.7k tokens)
         ├── flowface_vertex_mask.npy
         ├── flowface_vertex_weights.npy
         ├── head_template_mesh.obj (112.7k tokens)
         ├── head_vertices.txt (4.6k tokens)
         ├── jaw_regressor.npy
   ├── weights/
      ├── mmdm/
         ├── config_dump.yaml (900 tokens)
├── examples/
   ├── input/
      ├── animation/
         ├── example_video.mp4
         ├── sequence_00/
            ├── fit.npz
            ├── orbit.npz
         ├── sequence_01/
            ├── fit.npz
            ├── orbit.npz
      ├── felix/
         ├── alignment.npz
         ├── bg/
            ├── cam0/
               ├── 0000.png
               ├── 0001.png
               ├── 0002.png
               ├── 0003.png
         ├── fit.npz
         ├── images/
            ├── cam0/
               ├── 00000001.png
               ├── 00000002.png
               ├── 00000003.png
               ├── 00000004.png
         ├── reference_images.json
         ├── visualization/
            ├── vis_cam0.mp4
      ├── lincoln/
         ├── alignment.npz
         ├── fit.npz
         ├── images/
            ├── cam0/
               ├── abraham_lincoln.png
         ├── reference_images.json
         ├── visualization/
            ├── vis_cam0.mp4
      ├── tesla/
         ├── alignment.npz
         ├── bg/
            ├── cam0/
               ├── 0000.png
         ├── fit.npz
         ├── images/
            ├── cam0/
               ├── tesla.png
         ├── reference_images.json
         ├── visualization/
            ├── vis_cam0.mp4
├── flowface/
   ├── flame/
      ├── flame.py (2.7k tokens)
      ├── io.py (300 tokens)
      ├── utils.py (700 tokens)
├── gaussianavatars/
   ├── LICENSE_GS.md (900 tokens)
   ├── animate.py (1500 tokens)
   ├── gaussian_renderer/
      ├── gsplat_renderer.py (600 tokens)
   ├── lpipsPyTorch/
      ├── __init__.py (100 tokens)
      ├── modules/
         ├── lpips.py (200 tokens)
         ├── networks.py (500 tokens)
         ├── utils.py (200 tokens)
   ├── scene/
      ├── cameras.py (300 tokens)
      ├── cap4d_gaussian_model.py (3.4k tokens)
      ├── dataset_readers.py (1700 tokens)
      ├── gaussian_model.py (5.6k tokens)
      ├── net/
         ├── positional_encoding.py (100 tokens)
         ├── unet.py (2.2k tokens)
      ├── scene.py (1100 tokens)
   ├── train.py (3.5k tokens)
   ├── utils/
      ├── camera_utils.py (300 tokens)
      ├── export_utils.py (1600 tokens)
      ├── general_utils.py (800 tokens)
      ├── graphics_utils.py (1000 tokens)
      ├── image_utils.py (200 tokens)
      ├── loss_utils.py (400 tokens)
      ├── mesh_utils.py (100 tokens)
      ├── sh_utils.py (900 tokens)
      ├── system_utils.py (200 tokens)
├── requirements.txt
├── scripts/
   ├── download_flame.sh (200 tokens)
   ├── download_mmdm_weights.sh
   ├── fixes/
      ├── fix_flame_pickle.py (400 tokens)
      ├── pipnet_cpu_nms.pyx (1100 tokens)
   ├── generate_avatar.sh (800 tokens)
   ├── generate_felix.sh (200 tokens)
   ├── generate_lincoln.sh (200 tokens)
   ├── generate_tesla.sh (200 tokens)
   ├── install_pixel3dmm.sh (600 tokens)
   ├── pixel3dmm/
      ├── convert_to_flowface.py (4.1k tokens)
      ├── l2cs_eye_tracker.py (1300 tokens)
      ├── robust_video_matting/
         ├── model/
            ├── decoder.py (1400 tokens)
            ├── deep_guided_filter.py (500 tokens)
            ├── fast_guided_filter.py (600 tokens)
            ├── lraspp.py (200 tokens)
            ├── mobilenetv3.py (600 tokens)
            ├── model.py (600 tokens)
   ├── test_pipeline.sh (200 tokens)
   ├── track_video_pixel3dmm.sh (500 tokens)
```


## /.gitignore

```gitignore path="/.gitignore" 
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# UV
#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#uv.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/

# Visual Studio Code
#  Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 
#  that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
#  and can be added to the global gitignore or merged into this file. However, if you prefer, 
#  you could uncomment the following to ignore the entire vscode folder
# .vscode/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

# Cursor
#  Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
#  exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
#  refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore

```

## /README.md

# 🧢 CAP4D
Official repository for the paper

**CAP4D: Creating Animatable 4D Portrait Avatars with Morphable Multi-View Diffusion Models**, ***CVPR 2025 (Oral)***.

<a href="https://felixtaubner.github.io/" target="_blank">Felix Taubner</a><sup>1,2</sup>, <a href="https://ruihangzhang97.github.io/" target="_blank">Ruihang Zhang</a><sup>1</sup>, <a href="https://mathieutuli.com/" target="_blank">Mathieu Tuli</a><sup>3</sup>, <a href="https://davidlindell.com/" target="_blank">David B. Lindell</a><sup>1,2</sup>

<sup>1</sup>University of Toronto, <sup>2</sup>Vector Institute, <sup>3</sup>LG Electronics

<a href='https://arxiv.org/abs/2412.12093'><img src='https://img.shields.io/badge/arXiv-2412.12093-red'></a> <a href='https://felixtaubner.github.io/cap4d/'><img src='https://img.shields.io/badge/project page-CAP4D-Green'></a> <a href='#citation'><img src='https://img.shields.io/badge/cite-blue'></a>

![Preview](assets/banner.gif)

TL;DR: CAP4D turns any number of reference images into an animatable avatar. 

## ⚡️ Quick start guide

### 🛠️ 1. Create conda environment and install requirements

```bash
# 1. Clone repo
git clone https://github.com/felixtaubner/cap4d/
cd cap4d

# 2. Create conda environment for CAP4D:
conda create --name cap4d_env python=3.10
conda activate cap4d_env

# 3. Install requirements
pip install -r requirements.txt

# 4. Set python path
export PYTHONPATH=$(realpath "./"):$PYTHONPATH
```
Follow the [instructions](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md) and install Pytorch3D. Make sure to install with CUDA support. We recommend to install from source: 

```bash
export FORCE_CUDA=1
pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
```

### 📦 2. Download FLAME and MMDM weights
Setup your FLAME account at the [FLAME website](https://flame.is.tue.mpg.de/index.html) and set the username 
and password environment variables:
```bash
export FLAME_USERNAME=your_flame_user_name
export FLAME_PWD=your_flame_password
```

Download FLAME and MMDM weights using the provided scripts:

```bash 
# 1. Download FLAME blendshapes
# set your flame username and password
bash scripts/download_flame.sh 

# 2. Download CAP4D MMDM weights
bash scripts/download_mmdm_weights.sh
```

If the FLAME download script did not work, download FLAME2023 from the [FLAME website](https://flame.is.tue.mpg.de/index.html) and place `flame2023_no_jaw.pkl` in `data/assets/flame/`.
Then, fix the flame pkl file to be compatible with newer numpy versions:

```bash
python scripts/fixes/fix_flame_pickle.py --pickle_path data/assets/flame/flame2023_no_jaw.pkl
```

### ✅ 3. Check installation with a test run
Run the pipeline in debug settings to test the installation.

```bash
bash scripts/test_pipeline.sh
```

Check if a video is exported to `examples/debug_output/tesla/sequence_00/renders.mp4`.
If it appears to show a blurry cartoon Nikola Tesla, you're all set! 

### 🎬 4. Inference 
Run the provided scripts to generate avatars and animate them with a single script:

```bash
bash scripts/generate_felix.sh
bash scripts/generate_lincoln.sh
bash scripts/generate_tesla.sh
```

The output directories contain exported animations which you can view in real-time.
Open the [real-time viewer](https://felixtaubner.github.io/cap4d/viewer/) in your browser (powered by [Brush](https://github.com/ArthurBrussee/brush/)). Click `Load file` and
upload the exported animation found in `examples/output/{SUBJECT}/animation_{ID}/exported_animation.ply`.

## 🔧 Custom inference

See below for how to run your custom inference on your own reference images/videos and driving videos.

### ⚙️ 1. Run FLAME 3D face tracking

#### 1.1 FlowFace tracking
Coming soon! For now, only generations using the provided identities with precomputed [FlowFace](https://felixtaubner.github.io/flowface/) annotations are supported. 

#### 1.2 Pixel3DMM tracking
Install [Pixel3DMM](https://github.com/SimonGiebenhain/pixel3dmm) using the provided script. Notice that this is prone to errors due to package version mismatches. Please report any errors as an issue!

```bash
export FLAME_USERNAME=your_flame_user_name
export FLAME_PWD=your_flame_password
export PIXEL3DMM_PATH=$(realpath "../PATH/TO/pixel3dmm")  # set this to where you would like to clone the Pixel3DMM repo (absolute path)
export CAP4D_PATH=$(realpath "./")  # set this to the cap4d directory (absolute path)

bash scripts/install_pixel3dmm.sh
```

Run tracking and conversion on reference images/videos using the provided script. Note: if the input is a directory of frames, it is assumed to be a discontinuous set of (monocular!) images. If the input is a file, it is assumed to be a continuous monocular video.

```bash
export PIXEL3DMM_PATH=$(realpath "../PATH/TO/pixel3dmm")
export CAP4D_PATH=$(realpath "./") 

mkdir examples/output/custom/

# For more information on arguments
bash scripts/track_video_pixel3dmm.sh --help

# Process a directory of (reference) images
bash scripts/track_video_pixel3dmm.sh examples/input/felix/images/cam0/ examples/output/custom/reference_tracking/

# Optional: process a driving (or reference) video
bash scripts/track_video_pixel3dmm.sh examples/input/animation/example_video.mp4 examples/output/custom/driving_video_tracking/
```

Notice that results will be slightly worse than with FlowFace tracking, since the MMDM is trained with FlowFace.

### 🖼️ 2. Generate images using MMDM

```bash
# Generate images with single reference image
python cap4d/inference/generate_images.py --config_path configs/generation/default.yaml --reference_data_path examples/output/custom/reference_tracking/ --output_path examples/output/custom/mmdm/
```
Note: the generation script will use all visible CUDA devices. The more available devices, the faster it runs! This will take hours, and requires lots of RAM (ideally > 64 GB) to run smoothly.

### 👤 3. Fit Gaussian avatar 

```bash
python gaussianavatars/train.py --config_path configs/avatar/default.yaml --source_paths examples/output/custom/mmdm/reference_images/ examples/output/custom/mmdm/generated_images/ --model_path examples/output/custom/avatar/ --interval 5000
```

### 🕺 4. Animate your avatar

Once the avatar is generated, it can be animated with the driving video computed in step 1 or the provided animations. 

```bash
# Animate the avatar with provided animation files
python gaussianavatars/animate.py --model_path examples/output/custom/avatar/ --target_animation_path examples/input/animation/sequence_00/fit.npz  --target_cam_trajectory_path examples/input/animation/sequence_00/orbit.npz  --output_path examples/output/custom/animation_00/ --export_ply 1 --compress_ply 0

# Animate the avatar with driving video (computed using Pixel3DMM)
python gaussianavatars/animate.py --model_path examples/output/custom/avatar/ --target_animation_path examples/output/custom/driving_video_tracking/fit.npz  --target_cam_trajectory_path examples/output/custom/driving_video_tracking/cam_static.npz  --output_path examples/output/custom/animation_example/ --export_ply 1 --compress_ply 0
```

The `--target_animation_path` argument contains FLAME expressions and pose, while the (optional) `--target_cam_trajectory_path` argument contains the relative camera trajectory. 

### ⚡️ 5. Full inference

We provide a convenient script to run full inference using your reference images and optionally a driving video.

```bash
export PIXEL3DMM_PATH=$(realpath "../PATH/TO/pixel3dmm")
export CAP4D_PATH=$(realpath "./") 

# Generate avatar with custom input images/videos.
bash scripts/generate_avatar.sh --help
bash scripts/generate_avatar.sh {INPUT_VIDEO_PATH} {OUTPUT_PATH} [{QUALITY}] [{DRIVING_VIDEO_PATH}]

# Example generation with default quality generation with input images and driving video.
bash scripts/generate_avatar.sh examples/input/felix/images/cam0/ examples/output/felix_custom/ default examples/input/animation/example_video.mp4
```

### ✨ 6. View avatar in live viewer

Open the [real-time viewer](https://felixtaubner.github.io/cap4d/viewer/) in your browser (powered by [Brush](https://github.com/ArthurBrussee/brush/)). Click `Load file` and
upload the exported animation found in 
`examples/output/custom/animation_00/exported_animation.ply` or
`examples/output/custom/animation_example/exported_animation.ply`.

## 📚 Related Resources

The MMDM code is based on [ControlNet](https://github.com/lllyasviel/ControlNet). The 4D Gaussian avatar code is based on [GaussianAvatars](https://github.com/ShenhanQian/GaussianAvatars). Special thanks to the authors for making their code public!

Related work: 
- [CAT3D](https://cat3d.github.io/): Create Anything in 3D with Multi-View Diffusion Models
- [GaussianAvatars](https://shenhanqian.github.io/gaussian-avatars): Photorealistic Head Avatars with Rigged 3D Gaussians
- [FlowFace](https://felixtaubner.github.io/flowface/): 3D Face Tracking from 2D Video through Iterative Dense UV to Image Flow
- [StableDiffusion](https://github.com/Stability-AI/stablediffusion): High-Resolution Image Synthesis with Latent Diffusion Models
- [Pixel3DMM](https://github.com/SimonGiebenhain/pixel3dmm): Versatile Screen-Space Priors for Single-Image 3D Face Reconstruction

Awesome concurrent work:
- [Pippo](https://yashkant.github.io/pippo/): High-Resolution Multi-View Humans from a Single Image
- [Avat3r](https://tobias-kirschstein.github.io/avat3r/): Large Animatable Gaussian Reconstruction Model for High-fidelity 3D Head Avatars

## 📖 Citation

```tex
@inproceedings{taubner2025cap4d,
    author    = {Taubner, Felix and Zhang, Ruihang and Tuli, Mathieu and Lindell, David B.},
    title     = {{CAP4D}: Creating Animatable {4D} Portrait Avatars with Morphable Multi-View Diffusion Models},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2025},
    pages     = {5318-5330}
}
```

## Acknowledgement
This work was developed in collaboration with and with sponsorship from **LG Electronics**. We gratefully acknowledge their support and contributions throughout the course of this project.


## /assets/banner.gif

Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/assets/banner.gif

## /cap4d/datasets/utils.py

```py path="/cap4d/datasets/utils.py" 
import contextlib
from pathlib import Path
from typing import Dict

import numpy as np
import einops
import cv2
from decord import VideoReader
from scipy.spatial.transform import Rotation as R

from cap4d.flame.flame import CAP4DFlameSkinner, compute_flame


CROP_MARGIN = 0.2


@contextlib.contextmanager
def temp_seed(seed):
    """Temporarily seed NumPy's global RNG inside a `with` block.

    The previous global random state is captured on entry and restored on
    exit, even if the body raises, so surrounding code sees no RNG change.
    """
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        # Always restore, regardless of exceptions in the with-body.
        np.random.set_state(saved_state)


def crop_image(
    img: np.ndarray, 
    crop_box: np.ndarray, 
    bg_value=0,
) -> np.ndarray:
    """
    Crop `img` to the pixel range `crop_box` = (x_min, y_min, x_max, y_max).

    Where the crop box extends beyond the image borders, the output is
    filled with `bg_value`, so the result always has the full crop size
    (crop height x crop width x remaining channels).
    """
    src_h = img.shape[0]
    src_w = img.shape[1]
    x_min, y_min, x_max, y_max = crop_box[0], crop_box[1], crop_box[2], crop_box[3]
    out_h = y_max - y_min
    out_w = x_max - x_min

    # How far the crop box sticks out past each border of the source image.
    pad_left = max(0, -x_min)
    pad_right = max(0, x_max - src_w)
    pad_top = max(0, -y_min)
    pad_bottom = max(0, y_max - src_h)

    # ones * bg_value keeps the original dtype-promotion behavior.
    result = np.ones((out_h, out_w, *img.shape[2:]), dtype=img.dtype) * bg_value
    result[pad_top : out_h - pad_bottom, pad_left : out_w - pad_right, ...] = img[
        y_min + pad_top : y_max - pad_bottom,
        x_min + pad_left : x_max - pad_right,
        ...,
    ]

    return result


def rescale_image(
    img: np.ndarray,
    target_resolution: int,
):
    """
    Resize `img` to a square of (target_resolution, target_resolution).

    Area interpolation is used when shrinking (less aliasing on downscale),
    bilinear interpolation otherwise.
    """
    shrinking = target_resolution < img.shape[0]
    mode = cv2.INTER_AREA if shrinking else cv2.INTER_LINEAR
    return cv2.resize(img, (target_resolution, target_resolution), interpolation=mode)


def apply_bg(
    img: np.ndarray,
    bg_weights: np.ndarray,
    bg_color: np.ndarray = None,
):
    """
    Composite `img` over a constant background color.

    Parameters
    ----------
    img: np.ndarray
        foreground image (H, W, C) — presumably RGB; verify against callers
    bg_weights: np.ndarray
        per-pixel foreground weight in [0, 255]; 255 means fully foreground
        (normalized internally by dividing by 255)
    bg_color: np.ndarray, optional
        background color, broadcast over (H, W); defaults to white
        (255, 255, 255)

    Returns
    -------
    np.ndarray
        the composited (float) image
    """
    # A mutable np.ndarray default argument is shared across calls; use a
    # None sentinel instead (backward compatible: default is still white).
    if bg_color is None:
        bg_color = np.array([255, 255, 255])

    # Normalize the alpha-like weights from [0, 255] to [0, 1].
    bg_weights = bg_weights / 255.

    bg_img = bg_color[None, None]  # (1, 1, C) broadcasts over (H, W)
    img = bg_img * (1. - bg_weights) + img * bg_weights

    return img


def verts_to_pytorch3d(
    verts_2d: np.ndarray,
    crop_box: np.ndarray,
):
    """
    Convert pixel-space 2D vertex coordinates into the PyTorch3D screen
    convention relative to `crop_box` (x_min, y_min, x_max, y_max).

    Each axis is normalized over the crop box to [-1, 1] and then negated
    (PyTorch3D's screen axes point the opposite way from pixel axes).

    NOTE: `verts_2d` is modified in place and also returned.
    """
    x_min, x_max = crop_box[..., 0], crop_box[..., 2]
    y_min, y_max = crop_box[..., 1], crop_box[..., 3]

    verts_2d[..., 0] = -((verts_2d[..., 0] - x_min) / (x_max - x_min) * 2. - 1.)
    verts_2d[..., 1] = -((verts_2d[..., 1] - y_min) / (y_max - y_min) * 2. - 1.)

    return verts_2d


def get_square_bbox(
    bbox: np.ndarray,
    border_margin: float = 0.1,
    mode: str = "max",  # min or max
):
    """
    Expand a bounding box into a square box sharing the original center.

    The square's half-size is taken from the larger (mode="max") or smaller
    (mode="min") side of `bbox`, scaled by (1 + border_margin). The returned
    box may extend beyond the image borders; callers pad when cropping
    (see `crop_image`).

    Parameters
    ----------
    bbox: np.ndarray
        the face bounding box (x_min, y_min, x_max, y_max)
    border_margin: float
        fractional margin added around the box
    mode: str
        "max" or "min" — which side of the box sets the square size

    Returns
    -------
    crop_box: tuple[int, int, int, int]
        the index ranges taken from the original image
        x_min, y_min, x_max, y_max

    Raises
    ------
    ValueError
        if `mode` is neither "max" nor "min"
    """

    bbox = bbox.astype(int)

    bbox_h = bbox[3] - bbox[1]
    bbox_w = bbox[2] - bbox[0]
    b_center = ((bbox[2] + bbox[0]) // 2, (bbox[3] + bbox[1]) // 2)
    if mode == "max":
        dim = int(max(bbox_h, bbox_w) // 2.0 * (1.0 + border_margin))
    elif mode == "min":
        dim = int(min(bbox_h, bbox_w) // 2.0 * (1.0 + border_margin))
    else:
        # Previously an unknown mode crashed later with UnboundLocalError on
        # `dim`; fail fast with a clear message instead.
        raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")

    return (
        b_center[0] - dim,
        b_center[1] - dim,
        b_center[0] + dim,
        b_center[1] + dim,
    )


def get_bbox_from_verts(verts_2d, vert_mask):
    """
    Square crop box (with the module-level CROP_MARGIN) around the masked
    subset of 2D vertices.
    """
    selected = verts_2d[vert_mask]
    xs = selected[..., 0]
    ys = selected[..., 1]
    tight_bbox = np.array([xs.min(), ys.min(), xs.max(), ys.max()])

    return np.array(get_square_bbox(tight_bbox, border_margin=CROP_MARGIN))


def load_flame_verts_and_cam(
    flame_skinner: CAP4DFlameSkinner,
    flame_item: Dict[str, np.ndarray],
):
    """
    Run FLAME on a single item and unpack the results.

    Returns a 4-tuple:
    - verts_2d: projected 2D vertices (first batch/time entry)
    - offsets_3d: 3D vertex offsets (first batch entry)
    - intrinsics: 3x3 pinhole matrix assembled from fx/fy/cx/cy scalars
    - extrinsics: camera extrinsics from `flame_item["extr"]`
    """
    flame_out = compute_flame(flame_skinner, flame_item)

    verts_2d = flame_out["verts_2d"][0, 0]
    offsets_3d = flame_out["offsets_3d"][0]

    # Build the pinhole intrinsics matrix from the stored scalar entries.
    # NOTE(review): fx/fy/cx/cy appear to be stored as (1, 1) arrays — confirm
    # against the tracking pipeline that writes flame_item.
    intrinsics = np.eye(3)
    intrinsics[0, 0] = flame_item["fx"][0, 0]
    intrinsics[1, 1] = flame_item["fy"][0, 0]
    intrinsics[0, 2] = flame_item["cx"][0, 0]
    intrinsics[1, 2] = flame_item["cy"][0, 0]

    extrinsics = flame_item["extr"][0]

    return verts_2d, offsets_3d, intrinsics, extrinsics


def load_camera_rays(
    crop_box,
    intr,
    extr,
    target_resolution,
):
    """
    Compute per-pixel world-space ray directions for a cropped, rescaled view.

    The intrinsics are first adjusted for the crop box and target resolution,
    then a unit-length ray is built per pixel in camera coordinates and
    rotated into world coordinates with the inverse extrinsic rotation.

    Returns an array of shape (3, target_resolution, target_resolution).
    """
    res = target_resolution

    # Intrinsics after cropping to `crop_box` and rescaling to `res`.
    scale = res / (crop_box[2] - crop_box[0])
    fx = intr[0, 0] * scale
    fy = intr[1, 1] * scale
    cx = (intr[0, 2] - crop_box[0]) * scale
    cy = (intr[1, 2] - crop_box[1]) * scale

    u, v = np.meshgrid(np.arange(res), np.arange(res))  # each (H, W)

    # Per-pixel ray in camera coordinates, normalized to unit length.
    d = np.stack(((u - cx) / fx, (v - cy) / fy, np.ones_like(u)), axis=0)
    d = d / (np.linalg.norm(d, axis=0, keepdims=True) + 1e-8)
    h, w = d.shape[1:]

    # Rotate rays from camera space back into world space.
    d = d.reshape(3, h * w)
    d = np.linalg.inv(extr[:3, :3]) @ d
    d = d.reshape(3, h, w)

    return d  # ray directions


def adjust_intrinsics_crop(fx, fy, cx, cy, bbox, target_resolution):
    """
    Adjust pinhole intrinsics for a crop to `bbox` (x_min, y_min, x_max,
    y_max) followed by a resize to `target_resolution` pixels.

    Cropping shifts the principal point; rescaling scales all parameters.
    Returns the new (fx, fy, cx, cy).
    """
    zoom = target_resolution / (bbox[2] - bbox[0])

    return fx * zoom, fy * zoom, (cx - bbox[0]) * zoom, (cy - bbox[1]) * zoom


def get_crop_mask(orig_resolution, target_resolution, crop_box):
    """
    Build a (target_resolution x target_resolution) mask that is 1 where the
    crop box overlaps the original image and 0 in padded-out regions.
    """
    # Crop an all-ones image with zero padding: padded pixels become 0,
    # then resize the result to the target resolution.
    full_mask = np.ones((orig_resolution))
    cropped = crop_image(full_mask, crop_box, bg_value=0)

    return rescale_image(cropped, target_resolution)


class FrameReader:
    """
    Reads frames from a directory of image files with a VideoReader-like
    interface (`len` + integer indexing). Frames are ordered by filename.
    """

    def __init__(self, video_path):
        # Sort so frame order follows lexicographic filename order.
        self.frame_list = sorted(Path(video_path).glob("*.*"))

    def __len__(self):
        return len(self.frame_list)

    def __getitem__(self, index):
        # cv2 loads BGR; reorder channels to RGB.
        bgr = cv2.imread(str(self.frame_list[index]))
        return bgr[..., [2, 1, 0]]


def load_frame(
    video_path: Path,  # path to .mp4 or dir containing frames
    frame_id: np.ndarray,
):
    """
    Load a single frame from either a directory of images or a video file.

    Directories are read with FrameReader, files with decord's VideoReader.
    An out-of-range `frame_id` is clamped to the last frame (with a printed
    warning) rather than raising.
    """
    if video_path.is_dir():
        reader = FrameReader(video_path)
    else:
        reader = VideoReader(str(video_path))

    n_frames = len(reader)
    if frame_id >= n_frames:
        print(f"WARNING: Frame {frame_id} out of bounds for video with length {n_frames}")
        frame_id = n_frames - 1

    frame_img = reader[frame_id]
    # decord returns its own NDArray type; convert to numpy if needed.
    if not isinstance(frame_img, np.ndarray):
        frame_img = frame_img.asnumpy()

    return frame_img


def pivot_camera_intrinsic(extrinsics, target, angles, distance_factor=1.):
    """
    Rotates a camera around a target point (rotation in the camera's local frame).

    Parameters:
    - extrinsics: (4x4) numpy array, world_to_camera transformation matrix.
    - target: (3,) numpy array, target coordinates to pivot around.
    - angles: (2,) array-like, intrinsic Euler angles in degrees applied in
      'YX' order (yaw about the camera Y axis, then pitch about X).
    - distance_factor: float, scales the target-to-camera distance
      (1. keeps the original distance).

    Returns:
    - new_extrinsics: (4x4) numpy array, updated world_to_camera transformation.
    """
    # Work in camera-to-world space; invert back at the end.
    cam_to_world = np.linalg.inv(extrinsics)

    R_c2w = cam_to_world[:3, :3]  # 3x3 camera-to-world rotation
    t_c2w = cam_to_world[:3, 3]   # camera position in world coordinates

    # Offset vector from target to camera, optionally re-scaled.
    v = (t_c2w - target) * distance_factor

    # Rotation matrix for the given yaw/pitch angles.
    R_delta = R.from_euler('YX', angles, degrees=True).as_matrix()

    # Apply intrinsic rotation to the camera's rotation (local frame).
    new_R_c2w = R_c2w @ R_delta

    # Rotate the position offset in the camera frame as well, so the camera
    # orbits the target instead of spinning in place.
    new_v = R_c2w @ R_delta @ np.linalg.inv(R_c2w) @ v
    new_t_c2w = target + new_v

    # Reassemble camera-to-world, then invert back to world-to-camera.
    new_cam_to_world = np.eye(4)
    new_cam_to_world[:3, :3] = new_R_c2w
    new_cam_to_world[:3, 3] = new_t_c2w

    return np.linalg.inv(new_cam_to_world)


def get_head_direction(rot):
    """
    World-space head facing direction from an axis-angle rotation vector.

    The head looks along the rotation matrix's -z axis; the y and z
    components are then flipped to go from the PyTorch3D to the OpenCV
    convention.
    """
    head_mat = R.from_rotvec(rot).as_matrix()
    direction = -head_mat[:3, 2]  # -z is head direction
    direction[1:] = -direction[1:]  # p3d to opencv
    return direction


def compute_yaw_pitch_to_face_direction(extrinsics, world_direction):
    """
    Computes yaw and pitch angles (in degrees) needed to rotate the camera
    so its forward vector aligns with the given world-space direction.

    Parameters:
    - extrinsics: (4x4) camera-to-world matrix
    - world_direction: (3,) vector, desired view direction in world space
      (normalized internally)

    Returns:
    - yaw, pitch: angles in degrees
    """

    # Normalize direction
    world_direction = world_direction / np.linalg.norm(world_direction)

    # Camera's current rotation (camera-to-world)
    R_c2w = extrinsics[:3, :3]

    # Convert world direction into camera's local frame
    dir_local = R_c2w.T @ world_direction

    dx, dy, dz = dir_local

    # Clamp dy into arcsin's domain. No guard is needed for dz: np.arctan2
    # is well-defined even when dx == dz == 0. (The previous guard,
    # `np.clip(dz, -1e-6, 1) if dz == 0 else dz`, was a no-op because
    # clipping 0 into [-1e-6, 1] returns 0.)
    dy = np.clip(dy, -1, 1)

    # Yaw: rotation around Y (left-right)
    yaw = np.degrees(np.arctan2(dx, dz))

    # Pitch: rotation around X (up-down)
    pitch = np.degrees(np.arcsin(-dy))  # negative because +Y is up

    return yaw, pitch

```

## /cap4d/flame/flame.py

```py path="/cap4d/flame/flame.py" 
from typing import Dict, Optional

import einops
import numpy as np
import torch

from flowface.flame.flame import FlameSkinner, FLAME_N_SHAPE, FLAME_N_EXPR
from flowface.flame.utils import batch_rodrigues, OPENCV2PYTORCH3D, transform_vertices, project_vertices

from cap4d.flame.mouth import FlameMouth


FLAME_PKL_PATH = "data/assets/flame/flame2023_no_jaw.pkl"
JAW_REGRESSOR_PATH = "data/assets/flame/jaw_regressor.npy"
BLINK_BLENDSHAPE_PATH = "data/assets/flame/blink_blendshape.npy"


# This module has NO trainable parameters
class CAP4DFlameSkinner(FlameSkinner):
    """FLAME skinner extended with optional mouth-interior and lower-jaw
    proxy meshes (half spheres, see FlameMouth) appended after the regular
    FLAME vertices. The lower jaw is driven by a rotation regressed
    linearly from the expression coefficients."""

    def __init__(
        self,
        flame_pkl_path: str  = FLAME_PKL_PATH,
        n_shape_params: int = FLAME_N_SHAPE,
        n_expr_params: int = FLAME_N_EXPR,
        blink_blendshape_path: str = BLINK_BLENDSHAPE_PATH,
        add_mouth: bool = False,
        add_lower_jaw: bool = False,
        jaw_regressor_path: str = JAW_REGRESSOR_PATH,
    ):
        """
        :param flame_pkl_path: path to the FLAME model pickle
        :param n_shape_params: number of shape coefficients
        :param n_expr_params: number of expression coefficients
        :param blink_blendshape_path: path to the blink blendshape (.npy)
        :param add_mouth: append a static mouth-interior sphere to the output
        :param add_lower_jaw: append an expression-driven lower-jaw sphere
        :param jaw_regressor_path: .npy matrix mapping expression
            coefficients to a jaw axis-angle rotation
        """
        super().__init__(flame_pkl_path, n_shape_params, n_expr_params, blink_blendshape_path)

        self.add_mouth = add_mouth
        if add_mouth:
            self.mouth = FlameMouth()
        
        self.add_lower_jaw = add_lower_jaw
        if add_lower_jaw:
            self.lower_jaw = FlameMouth()
            # Linear regressor: expression coefficients -> jaw rotvec
            # (applied in forward via einsum).
            jaw_regressor = torch.tensor(np.load(jaw_regressor_path))
            self.register_buffer("jaw_regressor", jaw_regressor)
    
    def forward(
        self,
        flame_sequence: Dict,
        return_offsets: bool = True,
        return_transforms: bool = False,
    ) -> torch.Tensor:
        """
        Compute 3D vertices given a flame sequence with shape parameters
        flame_sequence (dictionary):
            shape (N_shape)
            and N_t timesteps of expression, pose (rot, tra), eye_rot (optional), jaw_rot (optional):
            expr (N_t, N_exp)
            rot (N_t, 3)
            tra (N_t, 3)
            eye_rot (N_t, 3)
            jaw_rot (N_t, 3)
            neck_rot (N_t, 3)

        output: list [verts (N_t, V, 3)]; offsets appended if
        return_offsets, per-vertex 4x4 transforms appended if
        return_transforms.
        """
        # Shaped (but unposed) template vertices.
        shape_offsets = self._get_shape_offsets(flame_sequence["shape"][None], None)
        shape_verts = self._get_template_vertices(None) + shape_offsets

        expr_offsets = self._get_expr_offsets(flame_sequence["expr"], None)

        verts = shape_verts + expr_offsets  # N_t, V, 3

        # create rotation matrix for joint rotations, we apply base transform separately
        # Joint slots used below: 0 = neck, 2 = jaw, 3/4 = the two eyes;
        # missing/None entries stay identity.
        rotations = torch.eye(3, device=verts.device)[None, None].repeat(verts.shape[0], 5, 1, 1)
        if "neck_rot" in flame_sequence and flame_sequence["neck_rot"] is not None:
            rotations[:, 0, ...] = batch_rodrigues(flame_sequence["neck_rot"])
        if "jaw_rot" in flame_sequence and flame_sequence["jaw_rot"] is not None:
            rotations[:, 2, ...] = batch_rodrigues(flame_sequence["jaw_rot"])
        if "eye_rot" in flame_sequence and flame_sequence["eye_rot"] is not None:
            # Same rotation applied to both eye joints.
            eye_rot = batch_rodrigues(flame_sequence["eye_rot"])
            rotations[:, 3, ...] = eye_rot
            rotations[:, 4, ...] = eye_rot

        verts, v_transforms = self._apply_joint_rotation(verts, rotations=rotations, vert_mask=None, return_transforms=True)

        # compute offsets (including joint rotations)
        offsets = verts - shape_verts
        if self.add_mouth:
            # Static mouth interior: identical for every timestep, zero offsets.
            mouth_verts = self.mouth(shape_verts, self.joint_regressor)
            mouth_verts = mouth_verts.repeat(verts.shape[0], 1, 1)
            verts = torch.cat([verts, mouth_verts], dim=1)
            offsets = torch.cat([offsets, torch.zeros_like(mouth_verts)], dim=1)
            # NOTE(review): mouth vertices receive all-zero 4x4 transforms
            # (not identity) — presumably masked/ignored downstream; verify.
            v_transforms = torch.cat([v_transforms, torch.zeros(mouth_verts.shape[0], mouth_verts.shape[1], 4, 4, device=v_transforms.device)], dim=1)
        if self.add_lower_jaw:
            # Jaw rotation regressed from expression coefficients.
            jaw_rot = einops.einsum(flame_sequence["expr"], self.jaw_regressor, 'b exp, exp r -> b r')
            # Offsets are measured relative to the zero-rotation jaw sphere.
            neutral_jaw_verts = self.lower_jaw(shape_verts, self.joint_regressor, batch_rodrigues(jaw_rot * 0.))
            jaw_verts = self.lower_jaw(shape_verts, self.joint_regressor, batch_rodrigues(jaw_rot))
            verts = torch.cat([verts, jaw_verts], dim=1)
            offsets = torch.cat([offsets, jaw_verts - neutral_jaw_verts], dim=1)
            # Homogeneous 4x4 transform holding only the jaw rotation.
            jaw_transforms = torch.zeros(jaw_verts.shape[0], 4, 4, device=v_transforms.device)
            jaw_transforms[:, :3, :3] = batch_rodrigues(jaw_rot)
            jaw_transforms[..., -1, -1] = 1.
            jaw_transforms = jaw_transforms[:, None].repeat(1, jaw_verts.shape[1], 1, 1)
            v_transforms = torch.cat([v_transforms, jaw_transforms], dim=1)

        # apply base transform separately
        base_rot = batch_rodrigues(flame_sequence["rot"])
        base_tra = flame_sequence["tra"][..., None]
        verts = (base_rot @ verts.permute(0, 2, 1) + base_tra).permute(0, 2, 1)

        output = [verts]

        if return_offsets:
            output.append(offsets)
        if return_transforms:
            # Compose the base rigid transform with the per-vertex joint
            # transforms: T = base @ joint.
            base_transform = torch.cat([base_rot, base_tra], dim=2)
            base_transform = torch.cat([base_transform, torch.zeros_like(base_transform[:, :1, ...])], dim=1)
            base_transform[..., -1, -1] = 1.
            v_transforms = einops.einsum(base_transform, v_transforms, 'b i j, b N j k -> b N i k')

            output.append(v_transforms)

        return output


def compute_flame(
    flame: CAP4DFlameSkinner, 
    fit_3d: Dict[str, np.ndarray],
):
    """Evaluate FLAME for a fitted sequence and project vertices to cameras.

    :param flame: FLAME skinner module
    :param fit_3d: dict of fitted numpy arrays (shape/expr/pose + camera)
    :return: dict with numpy arrays "verts_3d", "verts_3d_cv", "verts_2d",
        "offsets_3d"
    """
    def _t(key):
        # fit arrays come in as numpy; the skinner expects float tensors
        return torch.tensor(fit_3d[key]).float()

    flame_sequence = {key: _t(key) for key in ("shape", "expr", "rot", "tra", "eye_rot")}
    # Neck/jaw rotations are optional; the skinner keeps those joints at
    # identity when None.
    flame_sequence["neck_rot"] = _t("neck_rot") if "neck_rot" in fit_3d else None
    flame_sequence["jaw_rot"] = _t("jaw_rot") if "jaw_rot" in fit_3d else None

    # compute FLAME vertices and per-vertex offsets
    verts_3d, offsets_3d = flame(
        flame_sequence, 
        return_offsets=True,
    )  # [N_t V 3] each

    # camera intrinsics and extrinsics per camera
    cam_parameters = {key: _t(key) for key in ("fx", "fy", "cx", "cy", "extr")}

    # transform into OpenCV camera coordinate convention
    verts_3d_cv = transform_vertices(OPENCV2PYTORCH3D[None].to(verts_3d.device), verts_3d)  # [N_t V 3]
    # project vertices to cameras
    verts_2d = project_vertices(verts_3d_cv, cam_parameters)  # [N_c N_t V 3]

    return {
        "verts_3d": verts_3d.cpu().numpy(),
        "verts_3d_cv": verts_3d_cv.cpu().numpy(),
        "verts_2d": verts_2d.cpu().numpy(),
        "offsets_3d": offsets_3d.cpu().numpy(),
    }

```

## /cap4d/flame/mouth.py

```py path="/cap4d/flame/mouth.py" 
import torch
import torch.nn as nn
import numpy as np
import einops


def generate_uv_sphere(r=1.0, latitude_steps=30, longitude_steps=30):
    """Build half a UV sphere (lower half of the latitude range).

    :param r: sphere radius
    :param latitude_steps: latitude resolution (half of it is kept)
    :param longitude_steps: longitude resolution
    :return: (verts (N, 3) float tensor, face_indices (F, 3) long tensor)
    """
    # Only the lower half of the latitudes — this is half a sphere.
    lats = torch.linspace(-np.pi / 2, np.pi / 2, latitude_steps)[:latitude_steps // 2]
    lons = torch.linspace(0, 2 * np.pi, longitude_steps)

    # Vertices in lat-major order (all longitudes of a latitude row first).
    verts = torch.tensor([
        [
            r * torch.cos(lat) * torch.cos(lon),
            r * torch.cos(lat) * torch.sin(lon),
            r * torch.sin(lat),
        ]
        for lat in lats
        for lon in lons
    ])

    # Two triangles per quad between adjacent latitude rows.
    faces = []
    for row in range(latitude_steps // 2 - 1):
        for col in range(longitude_steps):
            col_next = (col + 1) % longitude_steps
            a = row * longitude_steps + col
            b = row * longitude_steps + col_next
            c = (row + 1) * longitude_steps + col
            d = (row + 1) * longitude_steps + col_next

            # NOTE(review): this condition is always true for the row range
            # above (row <= latitude_steps // 2 - 2); possibly intended as
            # latitude_steps // 2 - 2. Kept as-is to preserve behavior.
            if row < latitude_steps - 2:
                faces.append([a, d, c])

            # Skip the second triangle on the degenerate pole row.
            if row > 0:
                faces.append([a, b, d])

    face_indices = torch.tensor(faces, dtype=torch.long)
    
    return verts, face_indices


class FlameMouth(nn.Module):
    """Half-sphere proxy mesh placed inside the mouth of a FLAME head.

    The unit half sphere is scaled by the jaw-to-lip distance, oriented so
    its z axis points from the jaw joint toward a lip anchor vertex, and
    optionally rotated around the jaw joint by a given jaw rotation.
    """

    def __init__(
        self,
        long_steps=20,
        lat_steps=20,
        lip_v_index=3533,
        lip_offset=0.005,
    ):
        """
        :param long_steps: longitude resolution of the sphere
        :param lat_steps: latitude resolution of the sphere
        :param lip_v_index: index of the FLAME vertex used as the lip anchor
        :param lip_offset: pull-back of the sphere center along the jaw->lip
            direction (FLAME model units — presumably meters; verify)
        """
        super().__init__()

        v_sphere, f_sphere = generate_uv_sphere(
            r=1.,
            latitude_steps=lat_steps,
            longitude_steps=long_steps,
        )
        v_sphere[:, 1] = -v_sphere[:, 1]  # flip axis
        v_sphere[:, 2] = -v_sphere[:, 2]  # flip axis to align in right direction

        self.lip_v_index = lip_v_index
        # Template sphere and faces stored as non-trainable buffers.
        self.register_buffer("vertices", v_sphere)
        self.register_buffer("faces", f_sphere)

        self.lip_offset = lip_offset

    def forward(
        self,
        neutral_verts,
        joint_regressor,
        jaw_rotation=None,
    ):  
        """
        :param neutral_verts: (B, V, 3) FLAME vertices in the neutral pose
        :param joint_regressor: FLAME joint regressor; row 2 regresses the jaw joint
        :param jaw_rotation: optional (B, 3, 3) rotation matrices applied
            around the jaw joint
        :return: (B, S, 3) positioned sphere vertices
        """
        # Jaw joint location regressed from the neutral vertices.
        jaw_joint = einops.einsum(neutral_verts, joint_regressor[2], "b V xyz, V -> b xyz") # (B, 3)

        lip_vert = neutral_verts[:, self.lip_v_index]

        # Vector from the jaw joint to the lip anchor.
        offset = lip_vert - jaw_joint

        distance = offset.norm(dim=-1, keepdim=True)
        direction = offset / distance
        # Build an orthonormal frame whose z axis is the jaw->lip direction.
        y = torch.zeros_like(direction, device=direction.device)
        y[:, 1] = 1 
        # y[:, 0] = 1
        new_x = torch.cross(y, direction, dim=-1)
        new_x = new_x / new_x.norm(dim=-1, keepdim=True)
        new_y = torch.cross(direction, new_x, dim=-1)
        new_y = new_y / new_y.norm(dim=-1, keepdim=True)
        new_z = direction

        rot_mat = torch.stack([new_x, new_y, new_z], dim=-1)
        # Scale the unit sphere relative to the jaw-to-lip distance.
        v_sphere = self.vertices[None] * distance[..., None] * 0.25

        # v_sphere = v_sphere + y * 0.5

        # Rotate into the jaw frame; place the sphere 3/4 of the way to the
        # lip, pulled back by lip_offset.
        v_sphere = (rot_mat @ v_sphere.permute(0, 2, 1)).permute(0, 2, 1)
        center = jaw_joint + offset * 0.75 - self.lip_offset * direction
        # NOTE(review): broadcasting (B, 3) against (B, S, 3) only works
        # when B == 1 (or B == S) — callers appear to pass B == 1; verify.
        v_sphere = v_sphere + center

        if jaw_rotation is not None:
            # Rotate the sphere rigidly around the jaw joint.
            v_offset = jaw_rotation @ (v_sphere - jaw_joint).permute(0, 2, 1)
            v_sphere = jaw_joint + v_offset.permute(0, 2, 1)

        return v_sphere

```

## /cap4d/inference/data/generation_data.py

```py path="/cap4d/inference/data/generation_data.py" 
import numpy as np

from cap4d.datasets.utils import (
    pivot_camera_intrinsic,
    get_head_direction,
    compute_yaw_pitch_to_face_direction,
)

from cap4d.inference.data.inference_data import CAP4DInferenceDataset


def elipsis_sample(yaw_limit, pitch_limit):
    """Rejection-sample a (yaw, pitch) pair uniformly from the open ellipse
    (yaw / yaw_limit)^2 + (pitch / pitch_limit)^2 < 1.

    A zero limit pins that axis to 0; previously a zero on EITHER axis
    pinned BOTH to 0, discarding the other axis' range. Now the nonzero
    axis is still sampled uniformly within its limit.

    :param yaw_limit: yaw half-range (degrees); 0 disables yaw sampling
    :param pitch_limit: pitch half-range (degrees); 0 disables pitch sampling
    :return: (yaw, pitch) tuple
    """
    if yaw_limit == 0. and pitch_limit == 0.:
        return 0., 0.
    if yaw_limit == 0.:
        return 0., np.random.uniform(-pitch_limit, pitch_limit)
    if pitch_limit == 0.:
        return np.random.uniform(-yaw_limit, yaw_limit), 0.

    # Rejection sampling: draw from the bounding box, keep points strictly
    # inside the ellipse.
    while True:
        yaw = np.random.uniform(-yaw_limit, yaw_limit)
        pitch = np.random.uniform(-pitch_limit, pitch_limit)

        if (yaw / yaw_limit) ** 2 + (pitch / pitch_limit) ** 2 < 1.:
            return yaw, pitch


class GenerationDataset(CAP4DInferenceDataset):
    """Dataset of FLAME parameter sets for the images to be generated.

    Expressions and eye rotations are taken from a generation data file;
    the camera is pivoted around the reference head by random yaw/pitch
    offsets sampled inside an ellipse centered on the head direction.
    """

    def __init__(
        self, 
        generation_data_path,
        reference_flame_item,
        n_samples=840,
        yaw_range=55,
        pitch_range=20,
        expr_factor=1.0,
        resolution=512,
        downsample_ratio=8,
    ):
        """
        :param generation_data_path: .npz file providing "expr" and "eye_rot"
        :param reference_flame_item: flame dict of the reference frame
            (shape, pose, camera)
        :param n_samples: number of samples (must not exceed the data file)
        :param yaw_range: camera yaw pivot half-range in degrees
        :param pitch_range: camera pitch pivot half-range in degrees
        :param expr_factor: multiplier on expression / eye-rotation magnitude
        :param resolution: output image resolution
        :param downsample_ratio: latent downsampling factor
        """
        super().__init__(resolution, downsample_ratio)

        self.n_samples = n_samples
        self.yaw_range = yaw_range
        self.pitch_range = pitch_range

        # init_flame_params also sets self.flame_list / self.ref_extr.
        # Bug fix: it previously returned None, silently leaving
        # self.flame_dicts as None.
        self.flame_dicts = self.init_flame_params(
            generation_data_path,
            reference_flame_item,
            n_samples,
            yaw_range,
            pitch_range,
            expr_factor,
        )

    def init_flame_params(
        self,
        generation_data_path,
        reference_flame_item,
        n_samples,
        yaw_range,
        pitch_range,
        expr_factor,
    ):
        """Build one flame dict per sample and return the list.

        Also assigns self.flame_list and self.ref_extr as a side effect.
        """
        gen_data = dict(np.load(generation_data_path))

        ref_extr = reference_flame_item["extr"]
        ref_shape = reference_flame_item["shape"]
        ref_fx = reference_flame_item["fx"]
        ref_fy = reference_flame_item["fy"]
        ref_cx = reference_flame_item["cx"]
        ref_cy = reference_flame_item["cy"]
        ref_resolution = reference_flame_item["resolutions"]
        ref_rot = reference_flame_item["rot"]
        ref_tra = reference_flame_item["tra"]
        ref_tra_cv = ref_tra.copy()
        ref_tra_cv[:, 1:] = -ref_tra_cv[:, 1:]  # p3d to opencv
        # NOTE(review): get_head_direction already negates components 1: of
        # a single direction vector; the line below negates along axis 1
        # again for the batched array — confirm the double flip is intended.
        ref_head_dir = get_head_direction(ref_rot)  # in p3d
        ref_head_dir[:, 1:] = -ref_head_dir[:, 1:]  # p3d to opencv

        # Yaw/pitch that make the camera face the head straight on; sampled
        # offsets are applied around this center.
        center_yaw, center_pitch = compute_yaw_pitch_to_face_direction(
            ref_extr[0],
            ref_head_dir[0],
        )

        flame_list = []

        assert n_samples <= len(gen_data["expr"]), "too many samples"
        for expr, eye_rot in zip(gen_data["expr"][:n_samples], gen_data["eye_rot"][:n_samples]):
            # Random view offset inside the (yaw_range, pitch_range) ellipse.
            yaw, pitch = elipsis_sample(yaw_range, pitch_range)
            yaw += center_yaw
            pitch -= center_pitch  # pitch is flipped for some reason

            # Pivot the reference camera around the head position.
            rotated_extr = pivot_camera_intrinsic(ref_extr[0], ref_tra_cv[0], [yaw, pitch])

            flame_dict = {
                "shape": ref_shape,
                "expr": expr[None] * expr_factor,
                "eye_rot": eye_rot[None] * expr_factor,
                "rot": ref_rot,
                "tra": ref_tra,
                "extr": rotated_extr[None],
                "resolutions": ref_resolution,
                "fx": ref_fx,
                "fy": ref_fy,
                "cx": ref_cx,
                "cy": ref_cy,
            }
            flame_list.append(flame_dict)

        self.flame_list = flame_list
        self.ref_extr = ref_extr[0]

        return flame_list
```

## /cap4d/inference/data/inference_data.py

```py path="/cap4d/inference/data/inference_data.py" 
import numpy as np
from torch.utils.data import Dataset
import einops

from cap4d.flame.flame import CAP4DFlameSkinner
from cap4d.datasets.utils import (
    load_frame, 
    crop_image, 
    rescale_image, 
    apply_bg, 
    load_flame_verts_and_cam,
    get_bbox_from_verts,
    load_camera_rays,
    verts_to_pytorch3d,
)


class CAP4DInferenceDataset(Dataset):
    """Base dataset producing MMDM conditioning inputs per frame.

    Subclasses must populate self.flame_list (one flame/camera dict per
    frame) and self.ref_extr (extrinsics of the reference camera).
    """

    def __init__(
        self, 
        resolution=512,
        downsample_ratio=8,
    ):
        """
        :param resolution: output image resolution in pixels
        :param downsample_ratio: ratio between image and latent resolution
        """
        self.resolution = resolution
        self.latent_resolution = self.resolution // downsample_ratio

        self.flame_skinner = CAP4DFlameSkinner(
            add_mouth=True, 
            n_shape_params=150,
            n_expr_params=65,
        )
        # Vertex ids of the head region, used to compute the crop box.
        self.head_vertex_ids = np.genfromtxt("data/assets/flame/head_vertices.txt").astype(int)

        # Filled in by subclasses.
        self.flame_list = None
        self.ref_extr = None
        self.data_path = None

    def __len__(self):
        assert self.flame_list is not None, "self.flame_list not properly initialized"
        return len(self.flame_list)

    def __getitem__(self, idx):
        """Return image (or zeros), conditioning maps, and flame params for frame idx."""
        flame_item = self.flame_list[idx]

        # Skin FLAME and get projected 2D vertices + camera for this frame.
        verts_2d, offsets_3d, intrinsics, extrinsics = load_flame_verts_and_cam(
            self.flame_skinner,
            flame_item,
        )
        # Pixel-space crop box around the head region.
        crop_box = get_bbox_from_verts(verts_2d, self.head_vertex_ids)
        flame_item["crop_box"] = crop_box

        if "img_dir_path" in flame_item:
            # we have images available, load them
            # load and crop image, including background
            img_dir_path = flame_item["img_dir_path"]
            timestep_id = flame_item["timestep_id"]
            img = load_frame(img_dir_path, timestep_id)
            del flame_item["img_dir_path"]  # delete string from flame dict so that it can be collated
            if "bg_dir_path" in flame_item:
                bg = load_frame(flame_item["bg_dir_path"], timestep_id)
                del flame_item["bg_dir_path"]  # delete string from flame dict so that it can be collated
            else:
                print(f"WARNING: bg does not exist for image {img_dir_path}. Make sure the background is white.")
                bg = np.ones_like(img) * 255
            out_crop_mask = np.ones_like(img[..., [0]])
            img = apply_bg(img, bg)
            img = crop_image(img, crop_box, bg_value=255)
            out_crop_mask = crop_image(out_crop_mask, crop_box, bg_value=0)
            img = rescale_image(img, self.resolution)
            # Normalize pixel values to [-1, 1].
            img = ((img / 127.5) - 1.0).astype(np.float32)
            out_crop_mask = rescale_image(out_crop_mask, self.latent_resolution)
            is_ref = True
        else:
            # no image available means these images need to be generated
            # set image to zero
            img = np.zeros((self.resolution, self.resolution, 3), dtype=np.float32)
            # NOTE(review): this mask is 2D while the branch above carries a
            # trailing channel dim from img[..., [0]] — confirm downstream
            # code accepts both shapes (or that rescale_image normalizes).
            out_crop_mask = np.ones((self.latent_resolution, self.latent_resolution), dtype=np.float32)
            is_ref = False

        # load and transform ray map
        ray_map = load_camera_rays(
            crop_box,
            intrinsics,
            extrinsics,
            self.latent_resolution,
        )
        assert self.ref_extr is not None, "reference extrinsics ref_extr not set"
        # transform raymap to base extrinsics
        ray_map_h = ray_map.shape[1]
        ray_map = einops.rearrange(ray_map, 'v h w -> v (h w)')
        ray_map = self.ref_extr[:3, :3] @ ray_map
        ray_map = einops.rearrange(ray_map, 'v (h w) -> v h w', h=ray_map_h)

        # reference mask is one for reference dataset
        reference_mask = np.ones_like(out_crop_mask) * is_ref

        # convert pixel space vertices to pytorch3d space [-1, 1]
        verts_2d = verts_to_pytorch3d(verts_2d, np.array(crop_box))

        cond_dict = {
            "out_crop_mask": out_crop_mask[None],
            "reference_mask": reference_mask[None],
            "ray_map": ray_map[None],
            "verts_2d": verts_2d[None],
            "offsets_3d": offsets_3d[None],
        }  # [None] is for fake time dimension

        out_dict = {
            "jpg": img[None],  # jpg names comes from controlnet implementation
            "hint": cond_dict,
            "flame_params": flame_item,
        }
        
        return out_dict

```

## /cap4d/inference/data/reference_data.py

```py path="/cap4d/inference/data/reference_data.py" 
import numpy as np
import json
from pathlib import Path

from cap4d.inference.data.inference_data import CAP4DInferenceDataset


class ReferenceDataset(CAP4DInferenceDataset):
    """Dataset over the user-selected reference frames of a fitted capture."""

    def __init__(
        self, 
        data_path,
        resolution=512,
        downsample_ratio=8,
    ):
        """
        :param data_path: capture directory holding fit.npz,
            reference_images.json, images/ and optionally bg/
        :param resolution: output image resolution
        :param downsample_ratio: latent downsampling factor
        """
        super().__init__(resolution, downsample_ratio)
        
        self.load_flame_params(data_path)

    def load_flame_params(
        self,
        data_path: Path,
    ):
        """Populate self.flame_list / self.ref_extr from the fitted capture."""
        fit = dict(np.load(data_path / "fit.npz"))

        # List of (camera name, timestep) pairs picked as references.
        with open(data_path / "reference_images.json") as f:
            reference_entries = json.load(f)

        per_timestep_keys = ("expr", "rot", "tra", "eye_rot")
        per_camera_keys = ("fx", "fy", "cx", "cy", "extr", "resolutions")

        flame_list = []
        ref_extr = None
        for cam_name, timestep_id in reference_entries:
            # Resolve the camera name to its index in the fit arrays.
            cam_id = np.where(fit["camera_order"] == cam_name)[0].item()

            # Slice out a single (camera, timestep) frame, keeping the
            # leading axis via list indexing.
            item = {}
            for key in fit:
                if key in per_timestep_keys:
                    item[key] = fit[key][[timestep_id]]
                elif key in per_camera_keys:
                    item[key] = fit[key][[cam_id]]
                elif key == "shape":
                    item[key] = fit[key]

            item["timestep_id"] = timestep_id
            cam_dir = fit["camera_order"][cam_id]
            item["img_dir_path"] = data_path / "images" / cam_dir
            bg_dir = data_path / "bg" / cam_dir
            if bg_dir.exists():
                item["bg_dir_path"] = bg_dir

            flame_list.append(item)

            # First reference camera becomes the base extrinsics.
            if ref_extr is None:
                ref_extr = item["extr"]

        self.ref_extr = ref_extr[0]
        self.flame_list = flame_list

```

## /cap4d/inference/generate_images.py

```py path="/cap4d/inference/generate_images.py" 
from pathlib import Path
import argparse
import copy
import shutil

import torch
from torch.utils.data import  DataLoader
import gc
from omegaconf import OmegaConf

from pytorch_lightning import seed_everything
from cap4d.mmdm.sampler import StochasticIOSampler
from cap4d.inference.data.reference_data import ReferenceDataset
from cap4d.inference.data.generation_data import GenerationDataset
from cap4d.inference.utils import (
    get_condition_from_dataloader,
    load_model,
    save_visualization,
    save_flame_params,
    convert_and_save_latent_images,
    find_number_of_generated_images,
)


@torch.no_grad()
def main(args):
    """Generate avatar images with the MMDM.

    Pipeline: load reference + generation datasets, precompute conditioning
    for every frame, sample latents with stochastic I/O conditioning
    (replicated across available CUDA devices), then decode and save both
    reference and generated images.
    """
    ref_data_path = Path(args.reference_data_path)

    gen_config_path = Path(args.config_path)
    gen_config = OmegaConf.load(gen_config_path)

    # Output layout: reference_images/ and generated_images/ subfolders.
    output_path = Path(args.output_path)
    output_path.mkdir(exist_ok=True, parents=True)
    output_ref_path = output_path / "reference_images"
    output_ref_path.mkdir(exist_ok=True)
    output_gen_path = output_path / "generated_images"
    output_gen_path.mkdir(exist_ok=True)

    seed_everything(gen_config["seed"])

    # Keep a copy of the config next to the results for reproducibility.
    shutil.copy(gen_config_path, output_path / "mmdm_config_dump.yaml")

    # create reference dataloaders
    print("Creating dataloaders")
    refset = ReferenceDataset(ref_data_path, gen_config["resolution"])
    ref_dataloader = DataLoader(refset, num_workers=1, batch_size=1, shuffle=False)

    # Round the requested sample count so frames tile evenly into sampling
    # windows of V frames with up to R_max references each.
    n_gen_samples = find_number_of_generated_images(
        gen_config["generation_data"]["n_samples"], 
        len(ref_dataloader),
        gen_config["R_max"],
        gen_config["V"],
    )

    # create generation dataloaders
    genset = GenerationDataset(
        generation_data_path=gen_config["generation_data"]["data_path"],
        reference_flame_item=refset.flame_list[0],
        n_samples=n_gen_samples,
        resolution=gen_config["resolution"],
        yaw_range=gen_config["generation_data"]["yaw_range"],
        pitch_range=gen_config["generation_data"]["pitch_range"],
        expr_factor=gen_config["generation_data"]["expr_factor"],
    )
    gen_dataloader = DataLoader(genset, num_workers=1, batch_size=1, shuffle=False)

    # load models
    model = load_model(Path(gen_config["ckpt_path"]))

    # Replicate the model onto every visible CUDA device (or keep a single
    # copy on the requested device); the first replica also does encoding
    # and decoding.
    device_model_map = {}
    first_rank_model = None
    first_rank_device = None
    if "cuda" in args.device:
        for cuda_id in range(torch.cuda.device_count()):
            dev_key = f"cuda:{cuda_id}"
            device_model_map[dev_key] = copy.deepcopy(model).to(dev_key)
            if first_rank_model is None:
                first_rank_model = device_model_map[dev_key]
                first_rank_device = dev_key
    else:
        device_model_map[args.device] = model.to(args.device)
        first_rank_model = device_model_map[args.device]
        first_rank_device = args.device
        # model.cond_stage_model.device = "cpu"

    print(f"Done loading model.")

    # load all reference frames and create conditioning (and unconditional) images for each
    print(f"Loading reference dataset from {ref_data_path}")
    ref_data = get_condition_from_dataloader(
        first_rank_model,
        ref_dataloader,
        args.device,
    )

    # load all generation frames and create conditioning images for each
    print(f"Loading generation dataset from {gen_config['generation_data']['data_path']}")
    gen_data = get_condition_from_dataloader(
        first_rank_model,
        gen_dataloader,
        args.device,
    )

    if args.visualize_conditioning:
        print("Saving visualization of conditioning images")
        save_visualization(ref_data["cond_vis_frames"], output_ref_path)
        save_visualization(gen_data["cond_vis_frames"], output_gen_path)

    print("Saving flame parameters")
    save_flame_params(ref_data["flame_params"], output_ref_path)
    save_flame_params(gen_data["flame_params"], output_gen_path)

    gc.collect()

    # Concatenate the per-batch conditioning lists into single tensors.
    for key in ref_data["cond_frames"]:
        ref_data["cond_frames"][key] = torch.cat(ref_data["cond_frames"][key], dim=0)
        ref_data["uncond_frames"][key] = torch.cat(ref_data["uncond_frames"][key], dim=0)
        gen_data["cond_frames"][key] = torch.cat(gen_data["cond_frames"][key], dim=0)
        gen_data["uncond_frames"][key] = torch.cat(gen_data["uncond_frames"][key], dim=0)

    # Sample with Stochastic I/O Conditioning:
    print(f"Generating images on {len(list(device_model_map))} devices with stochastic I/O.")
    stochastic_io_sampler = StochasticIOSampler(device_model_map)

    z_gen = stochastic_io_sampler.sample(
        S=gen_config["n_ddim_steps"],
        ref_cond=ref_data["cond_frames"],
        ref_uncond=ref_data["uncond_frames"],
        gen_cond=gen_data["cond_frames"],
        gen_uncond=gen_data["uncond_frames"],
        latent_shape=(4, gen_config["resolution"] // 8, gen_config["resolution"] // 8),
        V=gen_config["V"],
        R_max=gen_config["R_max"],
        cfg_scale=gen_config["cfg_scale"],
    )

    print("Done generating.")

    gc.collect()
    z_gen = z_gen.cpu()
    z_ref = ref_data["cond_frames"]["z_input"].cpu()

    # Decode latents back to images and write them to disk.
    print(f"Saving reference images to {output_ref_path}/images")
    convert_and_save_latent_images(z_ref, first_rank_model, first_rank_device, output_ref_path)
    print(f"Saving generated images to {output_gen_path}/images")
    convert_and_save_latent_images(z_gen, first_rank_model, first_rank_device, output_gen_path)


if __name__ == "__main__":
    # CLI entry point: parse arguments and run image generation.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="path to generation config file",
    )
    parser.add_argument(
        "--reference_data_path",
        type=str,
        required=True,
        help="path to reference json file",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="path to output",
    )
    # NOTE(review): --batch_size is accepted but not read by main().
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="batch size",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="inference device",
    )
    # Integer used as a boolean flag (0 disables, nonzero enables).
    parser.add_argument(
        "--visualize_conditioning",
        type=int,
        default=1,
        help="whether to save visualizations of conditioning images",
    )
    args = parser.parse_args()
    main(args)

```

## /cap4d/inference/utils.py

```py path="/cap4d/inference/utils.py" 
from collections import defaultdict
import os

import einops
from omegaconf import OmegaConf
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import cv2

from controlnet.ldm.util import instantiate_from_config
from cap4d.mmdm.mmdm import MMLDM

    
def to_device(batch, device="cuda"):
    """Recursively move all tensors in a (possibly nested) dict to `device`.

    Lists are left untouched; everything else is assumed to support `.to`.
    Mutates `batch` in place and also returns it.
    """
    for key in batch:
        if isinstance(batch[key], dict):
            # Bug fix: propagate the requested device — the recursive call
            # previously fell back to the default "cuda" for nested dicts.
            batch[key] = to_device(batch[key], device=device)
        elif not isinstance(batch[key], list):
            batch[key] = batch[key].to(device)

    return batch


def log_cond(module, batch):
    """Render the conditioning encodings of a batch as [-1, 1] images.

    :param module: model exposing cond_stage_model and control_key
    :param batch: data batch containing the conditioning inputs
    :return: dict of visualization tensors, shape ((b t), h, w, c)
    """
    encoder = module.cond_stage_model
    conditioning = encoder(batch[module.control_key], unconditional=False)
    enc_vis = encoder.get_vis(conditioning["pos_enc"])

    for key, vis in enc_vis.items():
        n_batch = vis.shape[0]
        # Flatten batch/time, upsample 8x to image resolution, clamp.
        flat = einops.rearrange(vis, 'b t h w c -> (b t) c h w')
        flat = F.interpolate(flat, scale_factor=8., mode="nearest")
        flat = flat.clamp(-1., 1.)
        enc_vis[key] = einops.rearrange(flat, '(b t) c h w -> (b t) h w c', b=n_batch)

    return enc_vis


def load_model(ckpt_path):
    """Load the most recent MMLDM checkpoint from `ckpt_path`/checkpoints.

    :param ckpt_path: run directory containing checkpoints/ and
        config_dump.yaml
    :return: the model in eval mode on CPU
    :raises FileNotFoundError: when no .ckpt file is present (previously a
        bare ValueError from max() on an empty sequence)
    """
    list_of_files = list((ckpt_path / "checkpoints").glob("*.ckpt"))
    if not list_of_files:
        raise FileNotFoundError(f"no .ckpt files found in {ckpt_path / 'checkpoints'}")
    # Pick the newest checkpoint by creation time.
    latest_file = max(list_of_files, key=os.path.getctime)
    print("Loading model using checkpoint", latest_file)
    weight_path = latest_file

    # load modified model from the config dumped at training time
    config_path = ckpt_path / "config_dump.yaml"
    config = OmegaConf.load(config_path)
    print(f'Loaded model config from [{config_path}]')
    model: MMLDM = instantiate_from_config(config.model).cpu()

    print("Loading state dict")
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    state_dict = torch.load(weight_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()

    return model


def get_condition_from_dataloader(model, dataloader, device):
    """Run the dataloader once and collect conditioning for every frame.

    :param model: MMLDM providing get_input / cond_stage_model
    :param dataloader: yields batches with "flame_params" and conditioning
    :param device: device batches are moved to before encoding
    :return: dict with per-key lists of conditional / unconditional tensors,
        visualization frames, and per-sample flame parameter dicts (numpy)
    """
    cond_frames = defaultdict(list)
    uncond_frames = defaultdict(list)
    cond_vis_frames = defaultdict(list)
    flame_params = []

    for batch in tqdm(dataloader):
        batch = to_device(batch, device=device)

        # Encode the batch into latents + conditioning.
        z, c = model.get_input(batch, model.first_stage_key, force_conditional=True)

        # Flatten batch/time dims and stash both conditioning variants on CPU.
        cond = c["c_concat"][0]
        uncond = c["c_uncond"][0]
        for key in cond:
            cond_frames[key].append(einops.rearrange(cond[key], 'b t ... -> (b t) ...').cpu())
            uncond_frames[key].append(einops.rearrange(uncond[key], 'b t ... -> (b t) ...').cpu())

        # Rendered conditioning visualizations for logging.
        for key, vis in log_cond(model, batch).items():
            cond_vis_frames[key].append(vis.cpu())

        # Per-sample flame/camera parameters as numpy arrays.
        n_items = batch["flame_params"]["fx"].shape[0]
        for b in range(n_items):
            flame_params.append({
                key: value[b].cpu().numpy()
                for key, value in batch["flame_params"].items()
            })

    return {
        "cond_frames": cond_frames,
        "uncond_frames": uncond_frames,
        "cond_vis_frames": cond_vis_frames,
        "flame_params": flame_params,
    }


def save_visualization(vis_frames, output_dir):
    """Write conditioning visualizations as JPEGs.

    Output layout: output_dir/condition_vis/<key>/<frame_id>.jpg

    :param vis_frames: dict mapping key -> list of image tensors in [-1, 1]
    :param output_dir: base output directory (pathlib.Path)
    """
    condition_base_dir = output_dir / "condition_vis"
    condition_base_dir.mkdir(exist_ok=True)

    for key, frames in vis_frames.items():
        out_dir = condition_base_dir / f"{key}"
        out_dir.mkdir(exist_ok=True)

        for frame_id, vis_img in enumerate(frames):
            # First element of the (fake) batch dim; RGB -> BGR, [-1, 1] -> [0, 255].
            img = vis_img[0]
            img = (((img[..., [2, 1, 0]].cpu().numpy() + 1.) / 2.) * 255).astype(np.uint8)
            cv2.imwrite(str(out_dir / f"{frame_id:05d}.jpg"), img)


def save_flame_params(flame_params, output_dir):
    """Dump each flame parameter dict to output_dir/flame/XXXXX.npz.

    :param flame_params: iterable of dicts of numpy arrays
    :param output_dir: base output directory (pathlib.Path)
    """
    target_dir = output_dir / "flame"
    target_dir.mkdir(exist_ok=True)
    
    for index, params in enumerate(flame_params):
        np.savez(target_dir / f"{index:05d}.npz", **params)


def convert_and_save_latent_images(latents, model, device, output_dir):
    """Decode latents one at a time and write them as PNGs.

    :param latents: (N, ...) latent tensor
    :param model: model providing decode_first_stage
    :param device: device used for decoding
    :param output_dir: base output directory; images go to output_dir/images
    """
    out_img_dir = output_dir / "images"
    out_img_dir.mkdir(exist_ok=True)

    for i in range(latents.shape[0]):
        # Decode one latent; the [None, [i]] indexing keeps the
        # (batch, time) dims the model expects.
        decoded = model.decode_first_stage(latents[None, [i]].to(device))[0, 0]
        # [-1, 1] -> [0, 255], channel-last.
        pixels = ((decoded + 1.) / 2.).clip(0., 1.)
        pixels = pixels.permute(1, 2, 0).cpu().numpy() * 255.
        out_img_path = out_img_dir / f"{i:05d}.png"
        # RGB -> BGR for OpenCV.
        ok = cv2.imwrite(str(out_img_path), pixels[..., [2, 1, 0]].astype(np.uint8))
        assert ok, f"failed to save image to {out_img_path}"


def find_number_of_generated_images(n_gen_default, n_ref, R_max, V):
    """Round the requested number of generated images up to a multiple of V - R.

    Each sampling window holds V frames of which R are references, so the
    generated frames must tile evenly into chunks of V - R.

    :param n_gen_default: requested number of generated images
    :param n_ref: number of available reference images
    :param R_max: maximum references per sampling window
    :param V: total frames per sampling window
    :return: adjusted (possibly larger) number of generated images
    """
    R = min(n_ref, R_max)
    chunk = V - R
    # Frames needed to reach the next multiple of `chunk`.
    # Bug fix: the previous formula added n % chunk, which does not
    # generally yield a multiple of chunk (e.g. 5 + 5 % 4 = 6, not 8).
    missing_frames = (-n_gen_default) % chunk
    n_gen_new = n_gen_default + missing_frames

    if missing_frames > 0:
        print(f"WARNING: Default number of generated images {n_gen_default} incompatible with R={R}, adjusting the number to {n_gen_new}")

    return n_gen_new
    
```

## /cap4d/mmdm/conditioning/cap4dcond.py

```py path="/cap4d/mmdm/conditioning/cap4dcond.py" 
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops

from cap4d.mmdm.conditioning.mesh2img import PropRenderer


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding applied independently per coordinate.

    Each of the last-axis coordinates is multiplied by a bank of power-of-two
    frequencies and expanded into (sin, cos) pairs.
    """

    def __init__(self, channels_per_dim):
        """
        :param channels_per_dim: output channels per input coordinate
            (must be even: half sin, half cos).
        """
        super().__init__()

        assert channels_per_dim % 2 == 0
        half = channels_per_dim // 2
        # frequencies 2^0, 2^1, ..., 2^(half-1)
        self.register_buffer("freqs", 2. ** torch.linspace(0., half - 1, steps=half))

    def forward(self, tensor):
        """
        :param tensor: (b, h, w, d) tensor of coordinates (d typically 3)
        :return: (b, h, w, d * channels_per_dim) encoding
        """
        if tensor.dim() != 4:
            raise RuntimeError("The input tensor has to be 4d!")

        # (b, h, w, d, n_freqs)
        scaled = tensor.unsqueeze(-1) * self.freqs.view(1, 1, 1, 1, -1)

        # sin/cos pairs, then merge the (coordinate, frequency) axes
        encoded = torch.cat([scaled.sin(), scaled.cos()], dim=-1)
        return encoded.flatten(start_dim=-2)


class CAP4DConditioning(nn.Module):
    """Builds the spatial conditioning maps for the multi-view diffusion model.

    The conditioning tensor is a per-pixel concatenation (in this order) of:
      1. a sinusoidal positional encoding of the rasterized mesh geometry
         (``positional_channels`` channels),
      2. optionally 3 channels of rasterized expression-deformation offsets,
      3. optionally 3 channels of camera ray directions from the batch,
      4. 1 channel of reference-frame mask,
      5. optionally 1 channel of crop mask.
    """

    def __init__(
        self,
        image_size=64,                  # side length of the square conditioning map
        positional_channels=42,         # total positional-encoding channels; must be divisible by 3
        positional_multiplier=1.,       # scale applied to the rendered map before encoding
        super_resolution=2,             # integer factor: render larger, then area-downsample
        use_ray_directions=True,        # append 3-channel ray map from the batch
        use_expr_deformation=True,      # rasterize and append 3-channel expression offsets
        use_crop_mask=False,            # append 1-channel out-of-crop mask from the batch
        std_expr_deformation=0.0104,    # normalization constant for offset magnitudes
    ) -> None:
        super().__init__()

        self.image_size = image_size
        # super_resolution must be a positive integer render-scale factor
        assert super_resolution >=1 and super_resolution % 1 == 0
        self.super_resolution = super_resolution
        self.positional_channels = positional_channels
        self.positional_multiplier = positional_multiplier
        self.use_ray_directions = use_ray_directions
        self.use_expr_deformation = use_expr_deformation
        self.std_expr_deformation = std_expr_deformation
        self.use_crop_mask = use_crop_mask

        # the encoding is applied independently to each of the 3 spatial coords
        assert positional_channels % 3 == 0
        self.pos_encoding = PositionalEncoding(positional_channels // 3)
        self.renderer = PropRenderer()

    def forward(self, batch, unconditional=True):
        """Compute the conditioning tensor for a batch.

        Reads ``batch["verts_2d"]``, ``batch["offsets_3d"]`` (both assumed
        (B, T, N, 3) -- TODO confirm against the data loader),
        ``batch["reference_mask"]``, and optionally ``batch["z"]``,
        ``batch["ray_map"]`` and ``batch["out_crop_mask"]``.

        :param unconditional: when True, return an all-zero conditioning map
            of the correct channel count (classifier-free guidance branch) and
            zero out the latent input.
        :return: dict with "pos_enc" (B, T, image_size, image_size, C),
            "z_input" (zeroed when unconditional, or None), and "ref_mask".
        """
        verts = batch["verts_2d"]
        offsets = batch["offsets_3d"]
        ref_mask = batch["reference_mask"][:, :, None]
        B, T = verts.shape[:2]

        z_input = None
        if "z" in batch:
            z_input = batch["z"]

        img_size = self.image_size

        if unconditional:
            # build a zero tensor with the same channel layout as the
            # conditional branch below
            total_channels = self.positional_channels + 1 # positional and ref mask
            if self.use_crop_mask:
                total_channels += 1
            if self.use_ray_directions:
                total_channels += 3
            if self.use_expr_deformation: 
                total_channels += 3
            pose_pos_enc = torch.zeros((B, T, img_size, img_size, total_channels), device=verts.device)
            if z_input is not None:
                z_input = z_input * 0.
        else:
            with torch.no_grad():
                # flatten batch and time so the renderer sees one mesh per item
                verts = einops.rearrange(verts, 'b t n v -> (b t) n v')
                offsets = einops.rearrange(offsets, 'b t n v -> (b t) n v')
                offsets = offsets / self.std_expr_deformation  # normalize offset magnitude

                # rasterize template coordinates (and optionally offsets) at
                # super-resolution for cleaner edges after downsampling
                pose_map, mask = self.renderer.render(
                    verts, 
                    (img_size * self.super_resolution, img_size * self.super_resolution),
                    prop=offsets if self.use_expr_deformation else None,
                )

                if self.use_expr_deformation:
                    # extract last three channels which are offsets
                    pose_map, offsets = pose_map.split([3, 3], dim=-1)

                # Need to unnormalize from [-1, 1] to [0, resolution]
                # pos_enc = self.pos_encoding((uv_img + 1.) / 2. * self.image_size * self.positional_multiplier)

                pose_pos_enc = self.pos_encoding(pose_map * self.positional_multiplier)

                if self.use_expr_deformation:
                    # append expression offset
                    pose_pos_enc = torch.cat([pose_pos_enc, offsets], dim=-1)

                pose_pos_enc = pose_pos_enc * mask  # mask values not rendered

                # downscale pos_enc if we use super resolution (area = box filter)
                pose_pos_enc = einops.rearrange(pose_pos_enc, 'bt h w c -> bt c h w')
                pose_pos_enc = F.interpolate(pose_pos_enc, (img_size, img_size), mode="area")
                pose_pos_enc = einops.rearrange(pose_pos_enc, '(b t) c h w -> b t h w c', b=B)

                if self.use_ray_directions:
                    # concat ray map
                    ray_map = batch["ray_map"]
                    ray_map = einops.rearrange(ray_map, 'b t c h w -> b t h w c')
                    pose_pos_enc = torch.cat([pose_pos_enc, ray_map], dim=-1)

                # concat ref mask
                ref_mask_reshape = einops.rearrange(ref_mask, 'b t c h w -> b t h w c')
                pose_pos_enc = torch.cat([pose_pos_enc, ref_mask_reshape], dim=-1)

                if self.use_crop_mask:
                    crop_mask = batch["out_crop_mask"][..., None]
                    pose_pos_enc = torch.cat([pose_pos_enc, crop_mask], dim=-1)

        return {
            "pos_enc": pose_pos_enc,
            "z_input": z_input,
            "ref_mask": ref_mask,
        }

    def get_vis(self, enc):
        """Slice a conditioning tensor from forward() into named 3-channel maps
        for visualization. The slicing order must mirror the concatenation
        order in forward(): positional, expr offsets, rays, ref mask, crop mask.
        """
        visualizations = {}

        # channels per coordinate of the positional encoding
        n_pos = self.positional_channels // 3

        counter = 0  # overwritten below; kept for clarity of the running offset

        pos_enc = enc[..., 0:self.positional_channels]

        # visualize the two highest-frequency bands; channel layout is
        # [x-bands | y-bands | z-bands], so pick the same band per coordinate
        # for i in [n_pos-1]:
        for i in range(n_pos-2, n_pos):
            visualizations[f"pose_map_{i}"] = pos_enc[..., [i, i + n_pos, i + n_pos * 2]]

        counter = self.positional_channels

        if self.use_expr_deformation:
            visualizations["expr_disp"] = enc[..., counter:counter+3]
            counter += 3

        if self.use_ray_directions:
            visualizations["ray_map"] = enc[..., counter:counter+3]
            counter += 3

        # single-channel masks are replicated to 3 channels for display
        visualizations["ref_mask"] = enc[..., [counter] * 3]
        counter += 1

        if self.use_crop_mask:
            visualizations["crop_mask"] = enc[..., [counter] * 3]
            counter += 1

        return visualizations

```

## /cap4d/mmdm/conditioning/mesh2img.py

```py path="/cap4d/mmdm/conditioning/mesh2img.py" 
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch3d.ops.interp_face_attrs import interpolate_face_attributes
from pytorch3d.renderer import (
    BlendParams,
    PerspectiveCameras,
    hard_rgb_blend,
    rasterize_meshes,
)
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.structures.meshes import Meshes
from pytorch3d.io import load_obj


def create_camera_objects(
    K: torch.Tensor, RT: torch.Tensor, resolution: Tuple[int, int]
) -> PerspectiveCameras:
    """
    Build pytorch3d PerspectiveCameras from OpenCV-style K/RT matrices.

    Parameters
    ----------
    K: torch.Tensor [B, 3, 3]
        Camera calibration matrices
    RT: torch.Tensor [B, 3, 4]
        Camera transform matrices (stated by the original author to be
        cam-to-world in the open3d convention)
    resolution: tuple[int, int]
        (height, width) of the camera images
    """
    rotation = RT[:, :3, :3]
    translation = RT[:, :3, 3]

    # per-camera focal lengths (fx, fy) and principal points (cx, cy)
    focal_length = torch.stack([K[:, 0, 0], K[:, 1, 1]], dim=-1)
    principal_point = K[:, :2, 2]

    # pytorch3d wants integer (width, height) image sizes
    H, W = resolution
    img_size = torch.tensor([[W, H]] * len(K), dtype=torch.int, device=K.device)

    # Screen -> NDC: scale so the smaller side maps to [-1, 1] and the larger
    # side to [-u, u] with u > 1, matching pytorch3d's NDC convention and
    # `get_ndc_to_screen_transform`.
    half_min_side = img_size.min(dim=1, keepdim=True)[0] / 2.0
    half_min_side = half_min_side.expand(-1, 2)

    img_center = img_size / 2.0

    focal_pytorch3d = focal_length / half_min_side
    p0_pytorch3d = -(principal_point - img_center) / half_min_side

    # OpenCV's screen axes point the opposite way, so flip x and y; OpenCV also
    # multiplies points from the left, hence the transpose of R.
    R_pytorch3d = rotation.clone().permute(0, 2, 1)
    T_pytorch3d = translation.clone()
    R_pytorch3d[:, :, :2] *= -1
    T_pytorch3d[:, :2] *= -1

    return PerspectiveCameras(
        R=R_pytorch3d,
        T=T_pytorch3d,
        focal_length=focal_pytorch3d,
        principal_point=p0_pytorch3d,
        image_size=img_size,
        device=K.device,
    )


def create_camera_objects_pytorch3d(K, RT, resolution):
    """
    Build pytorch3d PerspectiveCameras in screen space (in_ndc=False) from
    camera matrices already expressed in the pytorch3d convention.

    :param K: [B, 3, 3] calibration matrices
    :param RT: [B, 3, 4] rotation|translation matrices
    :param resolution: (height, width)
    :return: PerspectiveCameras on K's device
    """
    H, W = resolution
    rotation = RT[:, :, :3]
    translation = RT[:, :, 3]
    img_size = torch.tensor([[H, W]] * len(K), dtype=torch.int, device=K.device)
    focal = torch.stack((K[:, 0, 0], K[:, 1, 1]), dim=-1)
    # NOTE(review): the y principal point is flipped against the image height
    # here -- presumably to convert the vertical screen axis; verify.
    pp = torch.cat([K[:, [0], -1], H - K[:, [1], -1]], dim=1)
    return PerspectiveCameras(
        R=rotation,
        T=translation,
        principal_point=pp,
        focal_length=focal,
        device=K.device,
        image_size=img_size,
        in_ndc=False,
    )


def project_points(
    lmks: torch.Tensor,
    K: torch.Tensor,
    RT: torch.Tensor,
    resolution: Tuple[int, int] = None,
) -> torch.Tensor:
    """
    Project 3D points into 2D pixel coordinates with a pinhole model.

    Parameters
    ----------
    lmks: torch.Tensor [B, N, 3]
        3D points to project
    K: torch.Tensor [B, 3, 3]
        Calibration matrices
    RT: torch.Tensor [B, 3, 4]
        World-to-camera transforms applied before projection
    resolution:
        Unused; kept for interface compatibility.

    Returns
    -------
    torch.Tensor [B, N, 2]
        Projected 2D points
    """
    # homogeneous coordinates for the points and a padded 4x4 transform
    # (the padding row only affects the 4th output coord, which is unused)
    homog = torch.cat(
        [lmks, torch.ones((*lmks.shape[:-1], 1), device=lmks.device)], dim=-1
    )
    rt_full = torch.cat(
        [RT, torch.ones((RT.shape[0], 1, 4), device=RT.device)], dim=-2
    )
    cam_points = (rt_full @ homog.permute(0, 2, 1)).permute(0, 2, 1)

    k = K[:, None, None, ...]
    # perspective divide, then scale/shift by focal length and principal point
    u = cam_points[..., [0]] / cam_points[..., [2]] * k[..., 0, 0] + k[..., 0, 2]
    v = cam_points[..., [1]] / cam_points[..., [2]] * k[..., 1, 1] + k[..., 1, 2]
    return torch.cat([u, v], dim=-1)


class VertexShader(nn.Module):
    """Rasterizes meshes with pytorch3d and interpolates per-vertex
    attributes over the covered pixels.
    """

    def __init__(self) -> None:
        super().__init__()

    def _get_mesh_ndc(
        self,
        meshes: Meshes,
        cameras: PerspectiveCameras,
    ) -> Meshes:
        """Transform mesh vertices from world space to NDC, keeping the
        view-space z coordinate as depth (required by rasterize_meshes).
        """
        eps = None
        verts_world = meshes.verts_padded()
        verts_view = cameras.get_world_to_view_transform().transform_points(
            verts_world, eps=eps
        )
        projection_trafo = cameras.get_projection_transform().compose(
            cameras.get_ndc_camera_transform()
        )
        verts_ndc = projection_trafo.transform_points(verts_view, eps=eps)
        # replace projected z with view-space depth for correct z-buffering
        verts_ndc[..., 2] = verts_view[..., 2]
        meshes_ndc = meshes.update_padded(new_verts_padded=verts_ndc)

        return meshes_ndc

    def _get_fragments(
        self, cameras, meshes_ndc, img_shape, blur_sigma
    ) -> Fragments:
        """Rasterize NDC meshes into per-pixel Fragments.

        :param cameras: used only to derive a near-plane clip value; may be None
        :param blur_sigma: >0 enables soft rasterization with 4 faces per pixel
        """
        znear = None
        if cameras is not None:
            znear = cameras.get_znear()
            if isinstance(znear, torch.Tensor):
                znear = znear.min().detach().item()
        # clip geometry in front of half the near plane to avoid projection blowups
        z_clip = None if znear is None else znear / 2

        fragments = rasterize_meshes(
            meshes_ndc,
            image_size=img_shape,
            blur_radius=np.log(1.0 / 1e-4 - 1.0) * blur_sigma,
            faces_per_pixel=4 if blur_sigma > 0.0 else 1,
            bin_size=None,
            max_faces_per_bin=None,
            clip_barycentric_coords=True,
            perspective_correct=cameras is not None,
            cull_backfaces=True,
            z_clip_value=z_clip,
            cull_to_frustum=False,
        )
        # rasterize_meshes returns a tuple; repackage as a named Fragments struct
        return Fragments(
            pix_to_face=fragments[0],
            zbuf=fragments[1],
            bary_coords=fragments[2],
            dists=fragments[3],
        )

    def _rasterize_property(self, property, fragments):
        """Barycentrically interpolate one per-face-vertex attribute.

        :param property: [B, F, 3, D] attribute per face corner
        """
        # rasterize vertex attribute over faces
        # prop has to be not packed, [B, F, 3, D] -> [B * F, 3, D]
        prop_packed = torch.cat(
            [property[i] for i in range(property.shape[0])], dim=0
        )
        return interpolate_face_attributes(
            fragments.pix_to_face, fragments.bary_coords, prop_packed
        )

    def _rasterize_vertices(
        self, vertices: Dict[str, torch.Tensor], fragments: Fragments
    ):
        """Rasterize every attribute in `vertices` except the special
        "positions" entry (which defines the geometry itself).
        """
        rasterized_properties = {}
        for key, prop in vertices.items():
            if key == "positions":
                continue

            rasterized_properties[key] = self._rasterize_property(
                prop, fragments
            )

        return rasterized_properties

    def forward(
        self,
        vertices: Dict[str, torch.Tensor],  # packed vertex properties!
        faces,
        intrinsics,
        extrinsics,
        img_shape,
        blur_sigma,
        return_meshes_and_cameras=False,
    ):
        """Rasterize meshes and interpolate their vertex attributes.

        :param vertices: dict with "positions" [B, V, 3] plus per-face-corner
            attributes; see _rasterize_vertices
        :param intrinsics: camera K matrices, or None when `vertices["positions"]`
            is already in NDC (no projection is applied then)
        :return: (pixels, fragments) or additionally (meshes, cameras)
        """
        meshes = Meshes(verts=vertices["positions"], faces=faces)
        cameras = None
        if intrinsics is not None:
            cameras = create_camera_objects(intrinsics, extrinsics, img_shape)
            meshes = self._get_mesh_ndc(meshes, cameras)
        fragments = self._get_fragments(cameras, meshes, img_shape, blur_sigma)
        pixels = self._rasterize_vertices(vertices, fragments)

        if return_meshes_and_cameras:
            return pixels, fragments, meshes, cameras
        else:
            return pixels, fragments


class BasePixelShader(nn.Module):
    """Base class for pixel shaders that sample a texture at rasterized UVs."""

    def __init__(self) -> None:
        super().__init__()

    def sample_texture(self, pixels, texture):
        """Sample `texture` at the per-pixel UV coordinates in `pixels`.

        :param pixels: dict containing "uv_coords" of shape (N, H, W, K_faces, D),
            with the first two of the D channels being UVs in [-1, 1]
        :param texture: (N, C, Ht, Wt) texture images
        :return: (N, H, W, K_faces, C) sampled values
        """
        assert "uv_coords" in pixels

        uv = pixels["uv_coords"]
        N, H, W, K_faces, N_f = uv.shape
        # fold the K_faces axis into the batch so grid_sample sees 2D grids
        uv_flat = uv.permute(0, 3, 1, 2, 4).reshape(N * K_faces, H, W, N_f)
        # replicate each texture K_faces times to line up with the folded grids
        tex_repeated = texture.repeat_interleave(K_faces, dim=0)
        sampled = F.grid_sample(tex_repeated, uv_flat[..., :2], align_corners=False)
        return sampled.reshape(N, K_faces, -1, H, W).permute(0, 3, 4, 1, 2)

    def forward(self, fragments, pixels, textures):
        raise NotImplementedError()


class TextureShader(BasePixelShader):
    """Pixel shader producing a hard-blended texture image and a depth image."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, fragments, pixels, texture):
        """Shade rasterized fragments with a texture.

        :param fragments: pytorch3d Fragments (pix_to_face, zbuf, ...)
        :param pixels: dict with "uv_coords" (see BasePixelShader.sample_texture)
        :param texture: (N, C, Ht, Wt) texture images
        :return: (image (N, C, H, W), depth (N, 1, H, W))
        """
        features = self.sample_texture(pixels, texture)

        blend_params = BlendParams(
            sigma=0.,
            gamma=1e-4,
            background_color=[0] * texture.shape[1],
            # N, H, W, K_faces, C
        )
        # BUG FIX: the previous masked assignment wrote through a view of
        # fragments.zbuf, mutating the caller's buffer in place. clamp()
        # returns a new tensor with identical values.
        depth = fragments.zbuf[..., None].clamp(min=0.0)
        depth = depth.repeat(1, 1, 1, 1, 3)

        img = hard_rgb_blend(features, fragments, blend_params)[..., :-1]  # remove alpha channel
        depth_img = hard_rgb_blend(depth, fragments, blend_params)

        return img.permute(0, 3, 1, 2), depth_img[..., [0]].permute(0, 3, 1, 2)


def vertex_to_face_mask(vert_mask, faces):
    """Lift a per-vertex mask to a per-face mask.

    A face is marked when any of its three corner vertices is marked.

    :param vert_mask: [B, V] per-vertex mask values
    :param faces: face index tensor; only faces[0] ([F, 3]) is used
    :return: [B, F] per-face mask (max over the three corners)
    """
    corner_values = vert_mask[:, faces[0]]
    face_mask, _ = corner_values.max(dim=-1)
    return face_mask


class PropRenderer(nn.Module):
    """
    Rasterizes per-vertex properties given ndc meshes.

    Loads the FLAME template mesh once at construction and registers its
    (normalized) vertex positions or UVs as the property to rasterize, plus a
    face mask restricting output to head + mouth faces.
    """

    def __init__(
        self,
        template_path="./data/assets/flame/cap4d_flame_template.obj",
        head_vert_path="./data/assets/flame/head_vertices.txt",
        n_mouth_verts=200,
        prop_type="verts",  # either uv or verts
    ) -> None:
        """
        :param template_path: .obj of the FLAME template mesh
        :param head_vert_path: text file of head vertex indices (one per line)
        :param n_mouth_verts: number of trailing vertices treated as mouth
        :param prop_type: "verts" to rasterize normalized template positions,
            "uvs" to rasterize UV coordinates remapped to [-1, 1]
        """
        super().__init__()

        self.v_shader = VertexShader()

        verts, faces, aux = load_obj(template_path)

        self.register_buffer("faces", faces.verts_idx)
        self.register_buffer("faces_uvs", faces.textures_idx)

        # load old lower neck mask and set all new faces also to zero (because they are from the mouth)
        vert_mask = torch.zeros(verts.shape[0]).bool()
        head_verts = torch.tensor(np.genfromtxt(head_vert_path)).long()
        vert_mask[head_verts] = 1
        # the last n_mouth_verts vertices are the appended mouth geometry
        vert_mask[-n_mouth_verts:] = 1

        # convert vert mask to face mask (a face is kept if any corner is kept)
        face_mask = vert_mask[self.faces].max(dim=-1)[0]
        self.register_buffer("face_mask", face_mask)

        if prop_type == "verts":
            self.register_buffer("props", verts)
            # normalize: center, then scale by the maximum coordinate
            self.props = self.props - self.props.mean(dim=-2, keepdim=True)
            self.props = self.props / self.props.max()
        elif prop_type == "uvs":
            # remap UVs from [0, 1] to [-1, 1] and flip the v axis
            self.register_buffer("props", aux.verts_uvs)
            self.props = self.props * 2. - 1.
            self.props[..., 1] = -self.props[..., 1]

    def render(
        self,
        vertices,
        img_shape,
        prop=None,
    ):
        """Rasterize the template property (and optionally an extra per-vertex
        property) using the given posed vertices.

        :param vertices: [B, V, 3] posed mesh vertices (presumably already in
            NDC since no cameras are passed to the shader -- see forward call)
        :param img_shape: (height, width) of the render
        :param prop: optional [B, V, D] extra per-vertex property appended to
            the output channels
        :return: (img [B, H, W, C], render_mask [B, H, W, 1]) where the mask is
            True only for pixels covered by head/mouth faces
        """
        b = vertices.shape[0]
        # gather template properties per face corner: [B, F, 3, D]
        props_unpacked = self.props[self.faces][None].repeat(b, 1, 1, 1)

        verts = {
            "positions": vertices,
            "prop": props_unpacked,
        }

        if prop is not None:
            add_prop = prop[:, self.faces]
            verts["add_prop"] = add_prop

        # intrinsics/extrinsics are None: rasterize positions as-is
        pixels, fragments = self.v_shader(
            verts, 
            self.faces[None].repeat(vertices.shape[0], 1, 1), 
            None, 
            None, 
            img_shape, 
            0.
        )
        
        # keep only the closest face per pixel (index 0)
        img = pixels["prop"][..., 0, :]

        if prop is not None:
            img = torch.cat([img, pixels["add_prop"][..., 0, :]], dim=-1)

        # valid pixels: covered by some face AND that face is in the head mask
        render_mask = fragments.pix_to_face != -1
        face_mask = self.face_mask.repeat(b)
        face_masked = face_mask[torch.clamp(fragments.pix_to_face, 0)]
        render_mask = torch.logical_and(render_mask, face_masked)

        return img, render_mask




```

## /cap4d/mmdm/mmdm.py

```py path="/cap4d/mmdm/mmdm.py" 
import einops
import torch
import numpy as np
from functools import partial
from einops import rearrange, repeat
from torchvision.utils import make_grid

from controlnet.ldm.models.diffusion.ddpm import LatentDiffusion
from controlnet.ldm.util import exists, default
from controlnet.ldm.modules.diffusionmodules.util import make_beta_schedule
from controlnet.ldm.models.diffusion.ddim import DDIMSampler
from cap4d.mmdm.utils import shift_schedule, enforce_zero_terminal_snr


class MMLDM(LatentDiffusion):
    """
    Morphable multi-view latent diffusion model.

    Extends LatentDiffusion to operate on stacks of T frames at once
    (latents of shape (B, T, C, H, W)), with per-frame diffusion timesteps,
    a loss that excludes reference frames, classifier-free-guidance dropout
    of the conditioning, and an optionally shifted / zero-terminal-SNR
    noise schedule.
    """

    def __init__(
        self, 
        control_key, 
        only_mid_control, 
        n_frames, 
        *args, 
        cfg_probability=0.1,
        shift_schedule=False,
        sqrt_shift=False,
        zero_snr_shift=True,
        minus_one_shift=True,
        negative_shift=False,
        **kwargs
    ):
        # NOTE: these attributes are set BEFORE super().__init__() because the
        # base-class constructor presumably calls the overridden
        # register_schedule(), which reads them -- verify against
        # LatentDiffusion.__init__.
        self.n_frames = n_frames
        self.shift_schedule = shift_schedule
        self.sqrt_shift = sqrt_shift
        self.minus_one_shift = minus_one_shift
        self.control_key = control_key
        self.only_mid_control = only_mid_control
        self.cfg_probability = cfg_probability
        self.negative_shift = negative_shift
        self.zero_snr_shift = zero_snr_shift

        super().__init__(*args, **kwargs)

    # @torch.no_grad()
    def get_input(self, batch, k, bs=None, force_conditional=False, *args, **kwargs):
        """Encode ground-truth frames to latents and build the conditioning.

        :param batch: dict; batch[k] holds images (b, t, h, w, c) in [-1, 1]
            (assumed -- TODO confirm against the dataset), and
            batch[self.control_key] holds the conditioning inputs
        :param bs: optional batch-size cap
        :param force_conditional: skip the CFG dropout mixing
        :return: (z, dict(c_concat=[control], c_uncond=[c_uncond], mask=loss_mask))
        """
        with torch.no_grad():
            # From DDPM
            x = batch[k]
            if len(x.shape) == 3:
                x = x[..., None]
            x = rearrange(x, 'b t h w c -> b t c h w')
            x = x.to(memory_format=torch.contiguous_format) # .float()  CONTIGUOUS
            
            # From LatentDiffusion
            if bs is not None:
                x = x[:bs]
            b_, t_ = x.shape[:2]
            # fold time into the batch for the frame-wise VAE encoder
            x_flat = einops.rearrange(x, 'b t c h w -> (b t) c h w')
            encoder_posterior = self.encode_first_stage(x_flat)
            z_flat = self.get_first_stage_encoding(encoder_posterior).detach()
            z = einops.rearrange(z_flat, '(b t) c h w -> b t c h w', b=b_)

            # Add gt z to control (the conditioner consumes it as "z")
            batch[self.control_key]['z'] = z.detach()
            
            c_uncond = self.get_unconditional_conditioning(batch[self.control_key])

            if "mask" in batch:
                loss_mask = batch["mask"]
            else:
                loss_mask = None

        # conditional branch is computed WITH gradients (trainable conditioner)
        c_cond = self.get_learned_conditioning(batch[self.control_key])

        if not force_conditional:
            # classifier-free-guidance dropout: per sample, swap the
            # conditional signal for the unconditional one with probability
            # cfg_probability; the boolean masks select one of the two
            # per-sample via the einsum products.
            is_uncond = torch.rand(b_, device=x.device) < self.cfg_probability  # do a mix with probability
            is_cond = torch.logical_not(is_uncond)
            control = {}
            for key in c_cond:
                control[key] = (
                    einops.einsum(c_uncond[key], is_uncond, 'b ..., b -> b ...') + 
                    einops.einsum(c_cond[key], is_cond, 'b ..., b -> b ...')
                )
        else:
            control = c_cond

        # New stuff
        assert isinstance(control, dict)

        if bs is not None:
            for key in control:
                control[key] = control[key][:bs]
        
        return z, dict(c_concat=[control], c_uncond=[c_uncond], mask=loss_mask)
    
    @torch.no_grad()
    def decode_first_stage(self, z, predict_cids=False):
        """Decode (b, t, c, h, w) latents by folding time into the batch."""
        b_, t_ = z.shape[:2]
        z = einops.rearrange(z, 'b t c h w -> (b t) c h w')
        z = super().decode_first_stage(z, predict_cids)
        return einops.rearrange(z, '(b t) c h w -> b t c h w', b=b_)
    
    def forward(self, x, c, *args, **kwargs):
        """Training forward: sample an independent timestep PER FRAME
        (shape (B, T)) and compute the diffusion losses.
        """
        t = torch.randint(0, self.num_timesteps, x.shape[:2], device=self.device).long()

        assert c is not None
        assert not self.shorten_cond_schedule

        return self.p_losses(x, c, t, *args, **kwargs)

    def apply_model(self, x_noisy, t, cond, *args, **kwargs):
        """Run the multi-view UNet. Text conditioning is disabled; the
        spatial control dict is passed straight to the diffusion model.
        """
        assert isinstance(cond, dict)
        diffusion_model = self.model.diffusion_model

        cond_txt = None  # remove text conditioning

        assert len(cond['c_concat']) == 1

        control = cond['c_concat'][0]
        eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)

        return eps
    
    def p_losses(self, x_start, cond, t, noise=None):
        """Compute eps-prediction losses with per-frame timesteps, averaging
        only over non-reference frames.

        :param x_start: (B, T, C, H, W) clean latents
        :param t: (B, T) per-frame timesteps
        :return: (loss, loss_dict)
        """
        noise = default(noise, lambda: torch.randn_like(x_start))

        # fold time into batch so q_sample sees per-frame timesteps
        b_, t_ = t.shape[:2]
        t_flat = einops.rearrange(t, 'b t -> (b t)')
        noise_flat = einops.rearrange(noise, 'b t c h w -> (b t) c h w')
        x_start_flat = einops.rearrange(x_start, 'b t c h w -> (b t) c h w')
        x_noisy_flat = self.q_sample(x_start=x_start_flat, t=t_flat, noise=noise_flat)
        x_noisy = einops.rearrange(x_noisy_flat, '(b t) c h w -> b t c h w', b=b_)
        
        model_output = self.apply_model(x_noisy, t, cond)

        loss_dict = {}
        prefix = 'train' if self.training else 'val'

        # only epsilon parameterization is supported here
        assert self.parameterization == 'eps'
        target = noise

        loss_simple = self.get_loss(model_output, target, mean=False)
        # Mask loss by references: reduce spatial/channel dims first -> (B, T)
        loss_simple = loss_simple.mean(dim=[2, 3, 4])

        # True for generated (non-reference) frames after negation.
        # NOTE(review): assumes the conditioner's ref_mask broadcasts against
        # the (B, T) per-frame loss -- confirm its shape upstream.
        ref_mask = torch.logical_not(cond['c_concat'][0]['ref_mask'])
        loss_simple_mean = (loss_simple * ref_mask).sum(dim=-1) / ref_mask.sum(dim=-1)
        
        # Losses updated with time dimension
        loss_dict.update({f'{prefix}/loss_simple': loss_simple_mean.mean()})

        logvar_t = self.logvar[t]  # .to(self.device)
        loss = loss_simple / torch.exp(logvar_t) + logvar_t
        loss = (loss * ref_mask).sum(dim=-1) / ref_mask.sum(dim=-1)
        # loss = loss_simple / torch.exp(self.logvar) + self.logvar
        if self.learn_logvar:
            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
            loss_dict.update({'logvar': self.logvar.data.mean()})

        loss = self.l_simple_weight * loss.mean()

        # variational lower bound term, also masked by reference frames
        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(2, 3, 4))
        loss_vlb = (self.lvlb_weights[t] * loss_vlb * ref_mask).sum(dim=-1) / ref_mask.sum(dim=-1)
        loss_vlb = loss_vlb.mean()
        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
        loss += (self.original_elbo_weight * loss_vlb)
        loss_dict.update({f'{prefix}/loss': loss})

        return loss, loss_dict
    
    # @torch.no_grad()
    def get_learned_conditioning(self, c):
        """Conditional conditioning (with gradients when the conditioner is trainable)."""
        return self.cond_stage_model(c, unconditional=False)

    @torch.no_grad()
    def get_unconditional_conditioning(self, c):
        """Unconditional (zeroed) conditioning for classifier-free guidance."""
        return self.cond_stage_model(c, unconditional=True)

    @torch.no_grad()
    def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                   plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
                   use_ema_scope=True,
                   **kwargs):
        """Produce a dict of visualization tensors (reconstruction, optional
        diffusion rows, samples, and CFG samples) for logging.
        """
        use_ddim = ddim_steps is not None

        log = dict()
        z, c = self.get_input(batch, self.first_stage_key, bs=N, force_conditional=True)
        c_cat = c["c_concat"][0]
        c_uncond = c["c_uncond"][0]
        for key in c_cat:
            c_cat[key] = c_cat[key][:N]
        N = min(z.shape[0], N)
        n_row = min(z.shape[0], n_row)
        log["reconstruction"] = self.decode_first_stage(z)

        if plot_diffusion_rows:
            # get diffusion row: visualize progressively noisier latents
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t], device=self.device), '1 -> b', b=n_row)
                    t = t.long()  # .to(self.device)
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                     batch_size=N, ddim=use_ddim,
                                                     ddim_steps=ddim_steps, eta=ddim_eta)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if unconditional_guidance_scale > 1.0:
            # classifier-free-guided sampling against the zeroed conditioning
            samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [None]},
                                             batch_size=N, ddim=use_ddim,
                                             ddim_steps=ddim_steps, eta=ddim_eta,
                                             unconditional_guidance_scale=unconditional_guidance_scale,
                                             unconditional_conditioning=c_uncond,
                                             )
            x_samples_cfg = self.decode_first_stage(samples_cfg)
            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        return log

    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
        """Draw samples with DDIM over the (T, 4, H, W) multi-frame latent shape."""
        ddim_sampler = DDIMSampler(self)
        # b, c, h, w = cond["c_concat"][0].shape
        # shape = (self.channels, h // 8, w // 8)
        shape = (self.n_frames, 4, self.image_size, self.image_size) # 64, 64)
        samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
        return samples, intermediates

    def configure_optimizers(self):
        """AdamW over the UNet parameters, plus any trainable conditioner params."""
        lr = self.learning_rate
        params = list(self.model.diffusion_model.parameters())

        if self.cond_stage_trainable:
            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
            params = params + list([param for param in self.cond_stage_model.parameters() if param.requires_grad])

            for n, param in self.cond_stage_model.named_parameters():
                if param.requires_grad:
                    print(n)
        
        opt = torch.optim.AdamW(params, lr=lr)
        return opt

    def low_vram_shift(self, is_diffusing):
        """Swap the diffusion model and the VAE/conditioner between GPU and CPU
        so only the currently needed stage resides on the GPU.
        """
        if is_diffusing:
            self.model = self.model.cuda()
            # self.control_model = self.control_model.cuda()
            self.first_stage_model = self.first_stage_model.cpu()
            self.cond_stage_model = self.cond_stage_model.cpu()
        else:
            self.model = self.model.cpu()
            # self.control_model = self.control_model.cpu()
            self.first_stage_model = self.first_stage_model.cuda()
            self.cond_stage_model = self.cond_stage_model.cuda()

    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Build the beta/alpha schedule buffers (non-persistent).

        Differences from the base implementation: optional zero-terminal-SNR
        rescaling (see enforce_zero_terminal_snr) and an optional log-SNR shift
        of the schedule that accounts for resolution and frame count relative
        to a 64x64 single-frame baseline.
        """
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                       cosine_s=cosine_s)
            
        if self.zero_snr_shift:
            print(f"Enforcing zero terminal SNR in noise schedule.")
            betas = enforce_zero_terminal_snr(betas)

        # cap betas to keep the schedule numerically stable
        betas[betas > 0.99] = 0.99

        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)

        if self.shift_schedule:
            n_gen = self.n_frames
            if self.minus_one_shift:
                n_gen = n_gen - 1  # we are generating only n_total - 1 frames technically
            # ratio of baseline pixel count (64x64, 1 frame) to actual count
            shift_ratio = (64 ** 2) / (self.image_size ** 2 * n_gen)
            if self.negative_shift:
                shift_ratio = 1. / shift_ratio
            if self.sqrt_shift:
                shift_ratio = np.sqrt(shift_ratio)
            new_alpha_cumprod, new_betas = shift_schedule(alphas_cumprod, shift_ratio=shift_ratio)

            print(f"Shifted log psnr of noise schedule by {shift_ratio}.")

            alphas = 1. - new_betas
            betas = new_betas
            alphas_cumprod = new_alpha_cumprod

        print("Using non persistent schedule buffers.")

        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        # buffers are non-persistent so checkpoints stay schedule-agnostic
        self.register_buffer('betas', to_torch(betas), persistent=False)
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod), persistent=False)
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev), persistent=False)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)), persistent=False)
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)), persistent=False)
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)), persistent=False)
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)), persistent=False)
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)), persistent=False)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance), persistent=False)
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))), persistent=False)
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)), persistent=False)
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)), persistent=False)

        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
        elif self.parameterization == "v":
            lvlb_weights = torch.ones_like(self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
        else:
            raise NotImplementedError("mu not supported")
        # first weight can be degenerate (division by ~0); copy the second
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        assert not torch.isnan(self.lvlb_weights).all()

        # From LatentDiffusion
        self.shorten_cond_schedule = self.num_timesteps_cond > 1
        if self.shorten_cond_schedule:
            self.make_cond_schedule()

```

## /cap4d/mmdm/modules/module.py

```py path="/cap4d/mmdm/modules/module.py" 
from pathlib import Path

import cv2
import numpy as np
import torch
import pytorch_lightning as pl
from controlnet.ldm.util import instantiate_from_config


def tensor_to_img(img):
    """Convert a (C, H, W) float tensor in [-1, 1] to an (H, W, C) uint8 BGR image."""
    hwc = img.permute(1, 2, 0)
    scaled = ((hwc + 1.) / 2. * 255).clamp(0, 255)
    arr = scaled.detach().cpu().numpy()
    # Swap RGB -> BGR for cv2.imwrite.
    return arr[..., [2, 1, 0]].astype(np.uint8)


class CAP4DModule(pl.LightningModule):
    """Lightning wrapper that trains `model` with a masked L1 reconstruction
    loss and dumps pred/source/target images on the first validation batch of
    each validation epoch."""

    def __init__(
        self, 
        model_config, 
        loss_config,
        # callback_config,
        ckpt_dir,
        *args, 
        **kwargs,
    ) -> None:
        # NOTE: the previous `*args: pl.Any` annotations referenced a
        # non-existent `pytorch_lightning.Any` attribute (evaluated at class
        # definition time); the annotations are dropped here.
        super().__init__(*args, **kwargs)

        self.model = instantiate_from_config(model_config)
        self.loss = instantiate_from_config(loss_config)

        # Directory where per-epoch validation visualizations are written.
        self.ckpt_dir = Path(ckpt_dir)

        # Counts validation epochs, used only for naming the output dirs.
        self.epoch_id = 0

    def _masked_l1(self, batch, y):
        # Per-pixel L1 over channels, averaged over the predicted mask.
        # (kept in place of `self.loss(batch["jpg"], y)` as in the original)
        diff = (batch["jpg"] - y["image"]).abs().mean(dim=1, keepdims=True)
        return (diff * y["mask"]).sum() / y["mask"].sum()

    def training_step(self, batch, batch_idx):
        y = self.model(batch["hint"])
        return self._masked_l1(batch, y)

    def validation_step(self, batch, batch_idx):
        y = self.model(batch["hint"])
        loss = self._masked_l1(batch, y)

        if batch_idx == 0:
            epoch_dir = self.ckpt_dir / f"epoch_{self.epoch_id:04d}"
            epoch_dir.mkdir(parents=True, exist_ok=True)

            for i in range(y["image"].shape[0]):
                cv2.imwrite(str(epoch_dir / f"{i:02d}_pred.png"), tensor_to_img(y["image"][i]))
                cv2.imwrite(str(epoch_dir / f"{i:02d}_source.png"), tensor_to_img(batch["hint"]["source_img"][i]))
                cv2.imwrite(str(epoch_dir / f"{i:02d}_target.png"), tensor_to_img(batch["jpg"][i]))

            # BUGFIX: previously incremented once per validation *batch*, which
            # made the epoch directory index jump by the number of validation
            # batches; now advances exactly once per validation epoch.
            self.epoch_id += 1

        self.log("loss", loss)
        return loss

    def configure_optimizers(self):
        # Plain Adam over all parameters; lr hard-coded as before.
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer
```

## /cap4d/mmdm/net/attention.py

```py path="/cap4d/mmdm/net/attention.py" 
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any

from controlnet.ldm.modules.diffusionmodules.util import checkpoint, GroupNorm32, LayerNorm32

try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except:
    XFORMERS_IS_AVAILBLE = False

# CrossAttn precision handling
import os
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
# Whether to use flash attention
FIX_LEGACY_FAIL = os.environ.get("FIX_LEGACY_FAIL", False)
if FIX_LEGACY_FAIL:
    print("================================")
    print("Fixing legacy failed k and v layers")
    print("================================")

_USE_FP16_ATTENTION = os.environ.get("USE_FP16_ATTENTION", False)
if _USE_FP16_ATTENTION:
    print("================================")
    print("Using fp16 attention")
    print("================================")

_USE_FLASH = os.environ.get("USE_FLASH_ATTENTION", False)
if _USE_FLASH:
    from flash_attn import flash_attn_func
    print("================================")
    print("Using flash attention")
    print("================================")


def exists(val):
    """True iff `val` is not None (falsy values like 0 or "" still exist)."""
    return not (val is None)


def uniq(arr):
    """Deduplicate `arr` preserving first-occurrence order; returns a keys view."""
    return dict.fromkeys(arr).keys()


def default(val, d):
    """Return `val` unless it is None; otherwise return `d`, calling it first
    when it is a plain function (lazy default)."""
    if val is None:
        return d() if isfunction(d) else d
    return val


def max_neg_value(t):
    """Most negative finite value representable in `t`'s floating dtype."""
    info = torch.finfo(t.dtype)
    return -info.max


def init_(tensor):
    """In-place uniform init in [-1/sqrt(d), 1/sqrt(d)], d = last dim; returns `tensor`."""
    bound = 1 / math.sqrt(tensor.shape[-1])
    tensor.uniform_(-bound, bound)
    return tensor


# feedforward
class GEGLU(nn.Module):
    """Gated-GELU projection (GLU variant): splits a doubled linear projection
    into value and gate halves and returns value * gelu(gate)."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        return value * F.gelu(gate)


class FeedForward(nn.Module):
    """Transformer MLP: expand by `mult`, apply GELU (or gated GEGLU when
    `glu` is set), dropout, then project to `dim_out` (defaults to `dim`)."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden_dim = int(dim * mult)
        out_dim = dim if dim_out is None else dim_out
        if glu:
            stem = GEGLU(dim, hidden_dim)
        else:
            stem = nn.Sequential(
                nn.Linear(dim, hidden_dim),
                nn.GELU()
            )

        self.net = nn.Sequential(
            stem,
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """Zero every parameter of `module` in place and return the module."""
    for param in module.parameters():
        param.detach().zero_()
    return module


def Normalize(in_channels):
    """Build a 32-group GroupNorm over `in_channels` channels.

    Uses the project's GroupNorm32 variant (imported from
    controlnet.ldm.modules.diffusionmodules.util) instead of plain
    torch.nn.GroupNorm — presumably for mixed-precision handling; confirm in
    that module.
    """
    # return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    return GroupNorm32(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


def legacy_attention(q, k, v, scale, mask=None, heads=1):
    """Plain (non-fused) scaled dot-product attention.

    Parameters:
        q, k, v: tensors of shape ((batch * heads), seq, head_dim).
        scale: factor applied to the q·k logits.
        mask: optional boolean mask of shape (batch, ...); True = keep.
        heads: number of heads folded into the batch dimension. Callers that
            pass a mask must supply this for correct broadcasting.

    Returns:
        Attention output of shape ((batch * heads), seq, head_dim).
    """
    # force cast to fp32 to avoid overflowing
    if _ATTN_PRECISION == "fp32":
        with torch.autocast(enabled=False, device_type='cuda'):
            q, k = q.float(), k.float()
            sim = einsum('b i d, b j d -> b i j', q, k) * scale
    else:
        sim = einsum('b i d, b j d -> b i j', q, k) * scale

    del q, k

    if mask is not None:
        mask = rearrange(mask, 'b ... -> b (...)')
        # BUGFIX: this branch previously referenced an undefined name `h`,
        # raising NameError whenever a mask was passed; the head count is now
        # an explicit parameter. (Also renamed the local so it no longer
        # shadows the module-level max_neg_value helper.)
        neg_val = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=heads)
        sim.masked_fill_(~mask, neg_val)

    # attention, what we cannot get enough of
    sim = sim.softmax(dim=-1)

    return einsum('b i j, b j d -> b i d', sim, v)


class AttentionModule(nn.Module):
    """Multi-head attention used by the MMDM transformer blocks.

    Four layouts are selected by `mode`:
      - "spatial":  self-attention over each frame's tokens.
      - "context":  cross-attention against an external `context` tensor.
      - "temporal": self-attention across the time axis per spatial token.
      - "3d":       joint self-attention over space and time.
    For the temporal/3d modes, inputs arrive with time folded into the batch
    dimension as (batch * num_timesteps, tokens, channels) — see the
    rearrange patterns in forward().
    """
    def __init__(
        self, 
        query_dim, 
        heads=8, 
        dim_head=64, 
        dropout=0., 
        mode="spatial",  # ["spatial", "context", "temporal", "3d"]
        context_dim=None, 
        num_timesteps=0,
    ):
        super().__init__()
        inner_dim = dim_head * heads

        self.mode = mode
        if mode == "context":
            assert context_dim is not None
            kv_dim = context_dim
        elif mode == "spatial":
            kv_dim = query_dim
        elif mode == "temporal":
            kv_dim = query_dim
            assert num_timesteps > 0
        elif mode == "3d":
            kv_dim = query_dim
            assert num_timesteps > 0
        else:
            assert False, f"ERROR: unrecognized mode {mode}"

        self.scale = dim_head ** -0.5
        self.heads = heads
        self.num_timesteps = num_timesteps

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(kv_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(kv_dim, inner_dim, bias=False)
        # Tracks whether the one-time FIX_LEGACY_FAIL weight patch ran.
        self.k_v_fixed = False

        is_zero_module = mode == "temporal"

        # NOTE(review): this ternary looks inverted relative to the variable
        # name — when is_zero_module is True a plain Linear is used, and
        # zero_module() is applied in all OTHER modes. Confirm intent before
        # changing; pretrained checkpoints may depend on the current behavior.
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim) if is_zero_module else zero_module(nn.Linear(inner_dim, query_dim)),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        # x: ((batch * time), tokens, query_dim) — time folded into batch.
        h = self.heads
        t = self.num_timesteps
        b = x.shape[0]  # is batch_size * time_steps

        q = self.to_q(x)
        if self.mode == "context":
            assert context is not None
            # context attention
            k = self.to_k(context)
            v = self.to_v(context)
        else:
            if FIX_LEGACY_FAIL and not self.k_v_fixed:
                # One-time legacy-checkpoint repair: overwrite to_v weights
                # with to_k weights (see module-level FIX_LEGACY_FAIL flag).
                with torch.no_grad():
                    self.to_v.weight.data = self.to_k.weight.data
                self.k_v_fixed = True
            # self attention
            k = self.to_k(x)
            v = self.to_v(x)


        if _USE_FLASH or XFORMERS_IS_AVAILBLE:  # XFORMER ATTENTION
            # Fused kernels expect (batch, seq, heads, head_dim) layout.
            if self.mode == "3d":
                q, k, v = map(lambda yt: rearrange(yt, '(b t) n (h d) -> b (n t) h d', h=h, t=t), (q, k, v))
            elif self.mode == "temporal":
                q, k, v = map(lambda yt: rearrange(yt, '(b t) n (h d) -> (b n) t h d', h=h, t=t), (q, k, v))
            elif self.mode == "context" or self.mode == "spatial":
                q, k, v = map(lambda yt: rearrange(yt, 'b n (h d) -> b n h d', h=h), (q, k, v))

            if _USE_FLASH:
                # flash_attn kernels require half precision; round-trip dtype.
                dtype_before = q.dtype
                out = flash_attn_func(q.half(), k.half(), v.half()).type(dtype_before)
            else:
                if _USE_FP16_ATTENTION:
                    dtype_before = q.dtype
                    out = xformers.ops.memory_efficient_attention(
                        q.half(), k.half(), v.half(), attn_bias=None, op=None,
                    ).type(dtype_before)
                else:
                    out = xformers.ops.memory_efficient_attention(
                        q, k, v, attn_bias=None, op=None,
                    )

            if self.mode == "3d":
                out = rearrange(out, 'b (n t) h d -> (b t) n (h d)', b=b//t, h=h, t=t)
            elif self.mode == "temporal":
                out = rearrange(out, '(b n) t h d -> (b t) n (h d)', b=b//t, h=h, t=t)
            elif self.mode == "context" or self.mode == "spatial":
                out = rearrange(out, 'b n h d -> b n (h d)', h=h)


        else:  # NORMAL ATTENTION
            # The einsum path folds heads into the batch dimension instead.
            if self.mode == "3d":
                q, k, v = map(lambda yt: rearrange(yt, '(b t) n (h d) -> (b h) (n t) d', h=h, t=t), (q, k, v))
            elif self.mode == "temporal":
                q, k, v = map(lambda yt: rearrange(yt, '(b t) n (h d) -> (b h n) t d', h=h, t=t), (q, k, v))
            elif self.mode == "context" or self.mode == "spatial":
                q, k, v = map(lambda yt: rearrange(yt, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

            # NOTE(review): mask is only forwarded on this legacy path; verify
            # legacy_attention's mask handling before relying on it.
            if _USE_FP16_ATTENTION:
                dtype_before = q.dtype
                out = legacy_attention(q.half(), k.half(), v.half(), self.scale, mask=mask).type(dtype_before)
            else:
                out = legacy_attention(q, k, v, self.scale, mask=mask)

            if self.mode == "3d":
                out = rearrange(out, '(b h) (n t) d -> (b t) n (h d)', b=b//t, h=h, t=t)
            elif self.mode == "temporal":
                out = rearrange(out, '(b h n) t d -> (b t) n (h d)', b=b//t, h=h, t=t)
            elif self.mode == "context" or self.mode == "spatial":
                out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self (or joint spatio-temporal) attention,
    optional context cross-attention, optional temporal attention, then a
    (gated) feed-forward — each with a residual connection."""
    def __init__(
        self, 
        dim, 
        n_heads, 
        d_head, 
        dropout=0., 
        use_context=True,
        context_dim=None, 
        gated_ff=True, 
        temporal_connection_type="none",  # [3d, temporal, none]
        num_timesteps=0,
    ):
        super().__init__()
        self.temporal_connection_type = temporal_connection_type
        if temporal_connection_type != "none":
            assert num_timesteps > 0
        
        self.attn1 = AttentionModule(
            query_dim=dim, 
            heads=n_heads, 
            dim_head=d_head, 
            dropout=dropout,
            mode="spatial" if temporal_connection_type != "3d" else "3d",
            num_timesteps=num_timesteps,
        )  # is a self-attention if not self.disable_self_attn
        self.norm1 = LayerNorm32(dim)  # nn.LayerNorm(dim)

        self.use_context = use_context
        if use_context:
            self.attn2 = AttentionModule(
                query_dim=dim, 
                context_dim=context_dim,
                heads=n_heads, 
                dim_head=d_head, 
                dropout=dropout,
                mode="context",
            )  # is self-attn if context is none
            self.norm2 = LayerNorm32(dim)  # nn.LayerNorm(dim)

        if temporal_connection_type == "temporal":
            # NOTE(review): no mode= is passed here, so this attention is
            # constructed with AttentionModule's default mode="spatial"
            # rather than "temporal". Confirm whether mode="temporal" was
            # intended before changing — pretrained checkpoints may rely on
            # the current behavior.
            self.attn_t = AttentionModule(
                query_dim=dim, 
                context_dim=None,
                heads=n_heads, 
                dim_head=d_head, 
                dropout=dropout, 
                num_timesteps=num_timesteps
            )
            self.norm_t = LayerNorm32(dim)  # nn.LayerNorm(dim)

        self.norm3 = LayerNorm32(dim)  # nn.LayerNorm(dim)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)

    def forward(self, x, context=None):
        return self._forward(x, context)

    def _forward(self, x, context=None):
        # Self (or joint spatio-temporal) attention with residual.
        x = self.attn1(self.norm1(x), context=None) + x

        # context attention
        if self.use_context:
            assert context is not None # DISABLE CROSS ATTENTION IF NO CONTEXT
            x = self.attn2(self.norm2(x), context=context) + x

        # temporal attention
        if self.temporal_connection_type == "temporal":
            attn = self.attn_t(self.norm_t(x), context=None)
            x = attn + x

        # ff and normalization
        x = self.ff(self.norm3(x)) + x
        return x


class SpatioTemporalTransformer(nn.Module):
    """
    Transformer block for image-like data.
    Normalizes the input (b, c, h, w), flattens it to token form
    (b, h*w, inner_dim) via a linear projection, applies the transformer
    blocks, projects back to in_channels, and adds the residual.
    """
    def __init__(
        self, 
        in_channels, 
        n_heads, 
        d_head,
        dropout=0., 
        use_context=True,
        context_dim=None,
        temporal_connection_type="none",  # [3d, temporal, none]
        num_timesteps=0,
    ):
        super().__init__()

        if use_context:
            assert context_dim is not None

        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(
                inner_dim, 
                n_heads, 
                d_head, 
                dropout=dropout, 
                use_context=use_context,
                context_dim=context_dim,
                temporal_connection_type=temporal_connection_type, 
                num_timesteps=num_timesteps,
            )]
        )
        # BUGFIX: the output projection maps tokens (inner_dim) back to
        # in_channels; the argument order was previously swapped, which only
        # worked because inner_dim == in_channels in practice (square weight,
        # so this fix is shape-compatible with existing checkpoints).
        self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))

        self.has_context = True

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        assert not isinstance(context, list)
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        x = self.proj_in(x)
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        return x + x_in


```

## /cap4d/mmdm/net/mmdm_unet.py

```py path="/cap4d/mmdm/net/mmdm_unet.py" 
import torch
import torch.nn as nn
import einops

from controlnet.ldm.modules.diffusionmodules.openaimodel import UNetModel
from controlnet.ldm.modules.diffusionmodules.util import (
    zero_module,
    timestep_embedding,
)

from cap4d.mmdm.net.attention import SpatioTemporalTransformer


class MMDMUnetModel(UNetModel):
    """UNet for the multi-view morphable diffusion model (MMDM).

    Extends the ControlNet UNetModel by
      - swapping attention blocks for SpatioTemporalTransformers with
        temporal or joint space-time ("3d") connections,
      - injecting a zero-initialized per-pixel conditioning embedding after
        the first input block,
      - passing reference-view latents through unchanged via ref_mask.
    Context cross-attention is disabled (use_context = False).
    """
    def __init__(
        self, 
        *args, 
        time_steps,
        condition_channels=50, 
        model_channels=320,
        image_size=32,
        context_dim=1024,
        temporal_mode="3d",  # ["3d", "temporal"]
        **kwargs,
    ):
        assert temporal_mode in ["3d", "temporal"]
        # These attributes must exist BEFORE super().__init__ runs, because
        # the base constructor calls create_attention_block (overridden below).
        self.temporal_mode = temporal_mode
        self.time_steps = time_steps
        self.use_context = False

        super().__init__(*args, model_channels=model_channels, image_size=image_size, context_dim=context_dim, **kwargs)

        # Zero-initialized so the conditioning branch is a no-op at init.
        self.cond_linear = zero_module(nn.Linear(condition_channels, model_channels))

    def create_attention_block(
        self, 
        ch,
        mult,
        use_checkpoint,
        num_heads,
        dim_head,
        transformer_depth,
        context_dim,
        disable_self_attn,
        use_linear,
        use_new_attention_order,
        use_spatial_transformer,
    ):
        # Factory override invoked by UNetModel.__init__; most base-class
        # arguments are accepted for signature compatibility but unused here.
        if self.temporal_mode == "temporal":
            temporal_connection_type = "temporal"
        elif self.temporal_mode == "3d":
            # Joint space-time attention only at deeper UNet levels
            # (channel multiplier >= 2); shallow levels stay purely spatial.
            if mult >= 2:
                temporal_connection_type = "3d"
            else:
                temporal_connection_type = "none"

        return SpatioTemporalTransformer(
            ch, 
            num_heads, 
            dim_head, 
            use_context=self.use_context,
            context_dim=context_dim,
            temporal_connection_type=temporal_connection_type,
            num_timesteps=self.time_steps,
        )

    def forward(self, x, timesteps=None, context=None, control=None, **kwargs):
        """
        x (b t c h w): noisy input latents, t = views/frames.
        timesteps (b t): per-view diffusion timestep.
        control: dict with
            z_input (b t c h w): clean latents, valid at reference views.
            ref_mask: 1 at reference views, 0 at generated ones
                (broadcastable against x).
            pos_enc (b t h w c): per-pixel conditioning (channels-last).
        Returns (b t c h w): network prediction; reference positions carry
        the residual x - z_input instead.
        """
        
        z_input = control["z_input"]

        ref_mask = control["ref_mask"]
        # ground truth noise output
        x_input = x - z_input  

        ref_mask_inv = torch.logical_not(ref_mask)

        # replace with input latents where available
        x = z_input * ref_mask + x * ref_mask_inv

        # Disabling context cross attention for MMDM
        assert context == None

        # Flatten time dimension for processing
        b_, t_ = x.shape[:2]
        x = einops.rearrange(x, 'b t c h w -> (b t) c h w')
        timesteps = einops.rearrange(timesteps, 'b t -> (b t)')

        pos_enc = einops.rearrange(control["pos_enc"], 'b t h w c -> (b t) h w c')
        pos_enc = pos_enc.type(self.dtype)
        pos_embedding = self.cond_linear(pos_enc)
        # (b t) h w c -> (b t) c h w to match the feature maps.
        pos_embedding = pos_embedding.permute(0, 3, 1, 2)

        hs = []
        
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        t_emb = t_emb.type(self.dtype)
        emb = self.time_embed(t_emb)
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)

            # Inject the conditioning embedding once, after the first block.
            if pos_embedding is not None:
                h += pos_embedding
                pos_embedding = None

            hs.append(h)

        h = self.middle_block(h, emb, context)

        for i, module in enumerate(self.output_blocks):
            # UNet skip connections: concatenate matching encoder features.
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)

        h = self.out(h)
        h = h.type(x.dtype)

        # Unflatten time dimension
        h = einops.rearrange(h, '(b t) c h w -> b t c h w', b=b_)
        
        # replace with input latents where available
        h = x_input * ref_mask + h * ref_mask_inv

        return h
```

## /cap4d/mmdm/sampler.py

```py path="/cap4d/mmdm/sampler.py" 
from typing import Tuple, Dict
import torch
import numpy as np
from tqdm import tqdm
from controlnet.ldm.modules.diffusionmodules.util import (
    make_ddim_sampling_parameters, 
    make_ddim_timesteps, 
)


class StochasticIOSampler(object):
    """DDIM-style sampler with stochastic I/O conditioning.

    At every denoising step the generated latents are randomly partitioned
    into groups; each group is concatenated with a random subset of reference
    latents and denoised jointly by the MMDM, optionally sharded over several
    devices via a {name: model} map. All latents live on `mem_device`
    (typically CPU) to bound GPU memory.
    """
    def __init__(self, model, **kwargs):
        super().__init__()

        if isinstance(model, dict):
            # model distributed on different devices
            self.device_model_map = model
        else:
            if torch.cuda.is_available():
                self.device_model_map = {"cuda": model}
            else:
                self.device_model_map = {"cpu": model}

        # The first model in the map supplies the diffusion schedule constants.
        for key in self.device_model_map:
            self.main_model = self.device_model_map[key]
            self.ddpm_num_timesteps = self.main_model.num_timesteps
            break

    def register_buffer(self, name, attr):
        # Plain attribute assignment; this class is not an nn.Module.
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        # Precompute the DDIM timestep subsequence plus the alpha/sigma lookup
        # tables used by sample() (mirrors ControlNet's DDIMSampler).
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
        alphas_cumprod = self.main_model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32)  # .to(self.main_model.device)

        self.register_buffer('betas', to_torch(self.main_model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.main_model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.detach().cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.detach().cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.detach().cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.detach().cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.detach().cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.detach().cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta,verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(
        self,
        S: int,
        ref_cond: Dict[str, torch.Tensor],
        ref_uncond: Dict[str, torch.Tensor],
        gen_cond: Dict[str, torch.Tensor],
        gen_uncond: Dict[str, torch.Tensor],
        latent_shape: Tuple[int, int, int],
        V: int = 8,
        R_max: int = 4,
        cfg_scale: float = 1.,
        eta: float = 0.,
        verbose: bool = False,
    ):
        """
        Generate images from reference images using Stochastic I/O conditioning.

        Parameters:
            S (int): Number of diffusion steps.
            ref_cond (Dict[str, torch.Tensor]): Conditioning tensors for the reference images (ref latents, pose maps, reference masks etc.).
            ref_uncond (Dict[str, torch.Tensor]): Unconditional counterparts of ref_cond (zeroed conditioning).
            gen_cond (Dict[str, torch.Tensor]): Conditioning tensors for the images to generate (pose maps, reference masks etc.).
            gen_uncond (Dict[str, torch.Tensor]): Unconditional counterparts of gen_cond.
            latent_shape (Tuple[int]): Shape of a single latent to be generated (C, H, W).
            V (int): Number of views the MMDM processes per forward pass (references + generated).
            R_max (int, optional): Maximum number of reference images to use. Defaults to 4.
            cfg_scale (float, optional): Classifier-free guidance scale. Higher values increase conditioning strength. Defaults to 1.0.
            eta (float, optional): Noise scaling factor for DDIM sampling. 0 means deterministic sampling. Defaults to 0.
            verbose (bool, optional): Whether to print detailed logs during sampling. Defaults to False.

        Returns:
            torch.Tensor: Generated samples in latent space, shape (n_gen, C, H, W).
        """

        # Device holding the full latent set (taken from the first gen tensor).
        mem_device = next(iter(gen_cond.items()))[1].device
        n_devices = len(self.device_model_map)

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)

        n_gen = next(iter(gen_cond.items()))[1].shape[0]
        n_all_ref = next(iter(ref_cond.items()))[1].shape[0]
        R = min(n_all_ref, R_max)

        assert n_gen % (V - R) == 0, f"number of generated images ({n_gen}) has to be divisible by G ({V-R})"  # has to be divisible for now
        n_its = n_gen // (V - R)

        # store all latents on CPU (to prevent using too much GPU memory)
        all_x_T = torch.randn((n_gen, *latent_shape), device=mem_device)
        all_e_t = torch.zeros_like(all_x_T)

        timesteps = self.ddim_timesteps
        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]

        print(f"Running stochastic I/O sampling with {total_steps} timesteps, {R} reference images and {n_gen} generated images")

        iterator = tqdm(time_range, desc='Stochastic I/O sampler', total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1

            ts = torch.full((1, V), step, device=mem_device, dtype=torch.long)

            # reset e_t accumulator
            all_e_t = all_e_t * 0.

            # gather ref and gen batches
            # NOTE(review): this shortcut triggers whenever R == 1, always
            # picking reference 0 even if n_all_ref > 1 — confirm intended.
            if R == 1:
                ref_batches = np.zeros((n_its, R), dtype=np.int64)
            else:
                ref_batches = np.stack([
                    np.random.permutation(np.arange(n_all_ref))[:R] for _ in range(n_its)
                ], axis=0)

            gen_batches = np.reshape(np.random.permutation(np.arange(n_gen)), (n_its, -1))

            def dict_sample(in_dict, indices, device=None):
                # Index every tensor in the dict, optionally moving the result.
                out_dict = {}
                for key in in_dict:
                    if device is None:
                        out_dict[key] = in_dict[key][indices]
                    else:
                        out_dict[key] = in_dict[key][indices].to(device)
                return out_dict
            
            # Prepare input to GPUs: round-robin the n_its groups over devices.
            batch_indices = []  # [[b] for b in range(n_its)]
            for l in range(int(np.ceil(n_its / n_devices))):
                device_batch = []
                for device_id in range(min(n_devices, n_its)):
                    if l * n_devices + device_id < n_its:
                        device_batch.append([l * n_devices + device_id])

                batch_indices.append(device_batch)

            # Go through all gen_batches and apply noise update
            for dev_batches in batch_indices:
                x_in_list = []
                t_in_list = []
                c_in_list = []
                e_t_list = []
                
                for dev_id, dev_batch in enumerate(dev_batches):
                    dev_key = list(self.device_model_map)[dev_id]
                    dev_device = self.device_model_map[dev_key].device

                    curr_ref_cond = dict_sample(ref_cond, ref_batches[dev_batch], device=dev_device)
                    curr_ref_uncond = dict_sample(ref_uncond, ref_batches[dev_batch], device=dev_device)

                    curr_gen_cond = dict_sample(gen_cond, gen_batches[dev_batch], device=dev_device)
                    curr_gen_uncond = dict_sample(gen_uncond, gen_batches[dev_batch], device=dev_device)

                    curr_x_T = all_x_T[gen_batches[dev_batch]].to(dev_device)  # making batch_size = 1 this way

                    curr_cond = {}
                    curr_uncond = {}
                    c_in = {}
                    for key in curr_ref_cond:
                        # References first, generated views after, along dim 1.
                        curr_cond[key] = torch.cat([curr_ref_cond[key], curr_gen_cond[key]], dim=1)
                        curr_uncond[key] = torch.cat([curr_ref_uncond[key], curr_gen_uncond[key]], dim=1)

                        c_in[key] = torch.cat([curr_uncond[key], curr_cond[key]], dim=0) # stack them to run uncond and cond in one pass
                    
                    t_in = torch.cat([ts] * 2, dim=0).to(dev_device)
                    c_in = dict(c_concat=[c_in])
                    x_in = torch.cat([curr_cond["z_input"][:, :R], curr_x_T], dim=1)
                    x_in = torch.cat([x_in] * 2, dim=0).to(dev_device)

                    x_in_list.append(x_in)
                    t_in_list.append(t_in)
                    c_in_list.append(c_in)

                # Run model in parallel on all available devices
                # NOTE(review): this loop issues the calls sequentially;
                # overlap would only come from async CUDA execution — confirm.
                for dev_id, dev_batch in enumerate(dev_batches):
                    dev_key = list(self.device_model_map)[dev_id]
                    dev_device = self.device_model_map[dev_key].device
                    model_uncond, model_t = self.device_model_map[dev_key].apply_model(
                        x_in_list[dev_id], 
                        t_in_list[dev_id], 
                        c_in_list[dev_id],
                    ).chunk(2)
                    # Classifier-free guidance combination.
                    model_output = model_uncond + cfg_scale * (model_t - model_uncond)

                    e_t = model_output[:, R:]  # eps prediction mode, extract the generation samples starting at n_ref

                    e_t_list.append(e_t)

                for dev_id, dev_batch in enumerate(dev_batches):
                    # Scatter each group's eps prediction back to mem_device.
                    all_e_t[gen_batches[dev_batch]] += e_t_list[dev_id].to(mem_device)

            # DDIM update coefficients for this step.
            # NOTE(review): for eta > 0 no stochastic sigma_t * z noise term is
            # added below — the update is always deterministic; confirm eta is
            # effectively always 0 here.
            alpha_t = self.ddim_alphas.float()[index]
            sqrt_one_minus_alpha_t = self.ddim_sqrt_one_minus_alphas[index]
            sigma_t = self.ddim_sigmas[index]
            alpha_prev_t = torch.tensor(self.ddim_alphas_prev).float()[index]

            # Double precision to reduce accumulation error over many steps.
            alpha_prev_t = alpha_prev_t.double()
            sqrt_one_minus_alpha_t = sqrt_one_minus_alpha_t.double()
            alpha_t = alpha_t.double()
            alpha_prev_t = alpha_prev_t.double()
            
            e_t_factor = -alpha_prev_t.sqrt() * sqrt_one_minus_alpha_t / alpha_t.sqrt() + (1. - alpha_prev_t - sigma_t**2).sqrt()
            x_t_factor = alpha_prev_t.sqrt() / alpha_t.sqrt() 
            
            e_t_factor = e_t_factor.float()
            x_t_factor = x_t_factor.float()

            all_x_T = all_x_T * x_t_factor + all_e_t * e_t_factor

        return all_x_T

            
```

## /cap4d/mmdm/utils.py

```py path="/cap4d/mmdm/utils.py" 
import numpy as np


def shift_schedule(alpha_cumprods, shift_ratio):
    """Shift a diffusion noise schedule in log-SNR space.

    shift_ratio = original_resolution (512) ** 2 / (new_resolution ** 2 * n_images)

    Returns the shifted cumulative alphas and the corresponding betas.
    """
    sigma_cp = 1. - alpha_cumprods
    snr = alpha_cumprods / sigma_cp

    # Shift the signal-to-noise ratio by shift_ratio in log space.
    log_snr_shifted = np.log(snr) + np.log(shift_ratio)
    exp_shifted = np.exp(log_snr_shifted)
    alpha_shifted = exp_shifted / (1 + exp_shifted)

    # Recover per-step betas from consecutive cumulative-alpha ratios.
    step_ratios = alpha_shifted[1:] / alpha_shifted[:-1]
    betas_shifted = 1 - np.concatenate([[1], step_ratios])

    return alpha_shifted, betas_shifted


def enforce_zero_terminal_snr(betas):
    """Rescale a beta schedule so the final timestep has exactly zero SNR.

    Implements Algorithm 1 from "Common Diffusion Noise Schedules and Sample
    Steps are Flawed" (https://arxiv.org/pdf/2305.08891): the sqrt cumulative
    alpha curve is shifted so its terminal value becomes zero, then rescaled
    so its initial value is unchanged.

    Args:
        betas: 1D numpy array of per-step noise variances.

    Returns:
        1D numpy array of rescaled betas with zero terminal SNR.
    """
    # sqrt of the cumulative alpha products.
    ab_sqrt = np.sqrt(np.cumprod(1 - betas, 0))

    # Remember the endpoints before modifying the curve.
    first = ab_sqrt[0].copy()
    last = ab_sqrt[-1].copy()

    # Shift so the last entry is exactly zero, then rescale so the first
    # entry keeps its original value.
    ab_sqrt = (ab_sqrt - last) * (first / (first - last))

    # Convert back: recover per-step alphas from consecutive ratios of
    # alpha_bar, then turn alphas into betas.
    ab = ab_sqrt ** 2
    step_alphas = ab[1:] / ab[:-1]
    step_alphas = np.concatenate([ab[0:1], step_alphas])
    return 1 - step_alphas
```

## /configs/avatar/debug.yaml

```yaml path="/configs/avatar/debug.yaml" 
opt_params:
  iterations: 1_000

  sh_warmup_iterations: 100

  # clip_probability: 0.2

  lambda_scale: 1.
  threshold_scale: 1.
  lambda_xyz: 1e-3
  threshold_xyz: 2.
  metric_xyz: False
  metric_scale: False

  feature_lr: 0.0025
  opacity_lr: 0.025
  scaling_lr: 0.005  # (scaled up according to mean triangle scale)  # 0.005 (original)
  rotation_lr: 0.001
  percent_dense: 0.01
  lambda_dssim: 0.4

  densification_interval: 200
  densify_grad_threshold: 1e-6
  opacity_reset_interval: 200
  densify_until_iter: 700
  densify_from_iter: 80

  position_lr_init: 5e-3
  position_lr_final: 5e-5
  position_lr_delay_mult: 0.01
  position_lr_max_steps: 1000
  
  w_lpips: 0.1
  lambda_lpips_end: 0.9
  lpips_linear_start: 100
  lpips_linear_end: 600
  deform_net_w_decay: 2e-3
  deform_net_lr_init: 1e-5
  deform_net_lr_final: 1e-7
  deform_net_lr_delay_mult: 0.01
  deform_net_lr_max_steps: 1000
  lambda_laplacian: 1.0
  lambda_relative_deform: 0.4
  lambda_relative_rot: 0.005

  neck_lr_init: 1e-5
  neck_lr_final: 1e-7
  neck_lr_delay_mult: 0.01
  neck_lr_max_steps: 1000

  lambda_neck: 1.0
  
model_params:
  n_unet_layers: 6
  n_points_per_triangle: 2
  use_lower_jaw: True
  static_neck: False
  gaussian_init_type: "scaled"
  use_expr_mask: True
  uv_resolution: 128
  n_gaussians_init: 100_000

  sh_degree: 3

```

## /configs/avatar/default.yaml

```yaml path="/configs/avatar/default.yaml" 
opt_params:
  iterations: 100_000

  sh_warmup_iterations: 1_000

  # clip_probability: 0.2

  lambda_scale: 1.
  threshold_scale: 1.
  lambda_xyz: 1e-3
  threshold_xyz: 2.
  metric_xyz: False
  metric_scale: False

  feature_lr: 0.0025
  opacity_lr: 0.025
  scaling_lr: 0.005  # (scaled up according to mean triangle scale)  # 0.005 (original)
  rotation_lr: 0.001
  percent_dense: 0.01
  lambda_dssim: 0.4

  densification_interval: 2000
  densify_grad_threshold: 1e-6
  opacity_reset_interval: 20_000
  densify_until_iter: 70_000
  densify_from_iter: 8_000

  position_lr_init: 5e-3
  position_lr_final: 5e-5
  position_lr_delay_mult: 0.01
  position_lr_max_steps: 100_000
  
  w_lpips: 0.1
  lambda_lpips_end: 0.9
  lpips_linear_start: 10_000
  lpips_linear_end: 60_000
  deform_net_w_decay: 2e-3
  deform_net_lr_init: 1e-5
  deform_net_lr_final: 1e-7
  deform_net_lr_delay_mult: 0.01
  deform_net_lr_max_steps: 100_000
  lambda_laplacian: 1.0
  lambda_relative_deform: 0.4
  lambda_relative_rot: 0.005

  neck_lr_init: 1e-5
  neck_lr_final: 1e-7
  neck_lr_delay_mult: 0.01
  neck_lr_max_steps: 100_000

  lambda_neck: 1.0
  
model_params:
  n_unet_layers: 6
  n_points_per_triangle: 2
  use_lower_jaw: True
  static_neck: False
  gaussian_init_type: "scaled"
  use_expr_mask: True
  uv_resolution: 128
  n_gaussians_init: 100_000

  sh_degree: 3

```

## /configs/avatar/high_quality.yaml

```yaml path="/configs/avatar/high_quality.yaml" 
opt_params:
  iterations: 100_000

  sh_warmup_iterations: 1_000

  # clip_probability: 0.2

  lambda_scale: 1.
  threshold_scale: 1.
  lambda_xyz: 1e-3
  threshold_xyz: 2.
  metric_xyz: False
  metric_scale: False

  feature_lr: 0.0025
  opacity_lr: 0.025
  scaling_lr: 0.005  # (scaled up according to mean triangle scale)  # 0.005 (original)
  rotation_lr: 0.001
  percent_dense: 0.01
  lambda_dssim: 0.4

  densification_interval: 2000
  densify_grad_threshold: 1e-6
  opacity_reset_interval: 20_000
  densify_until_iter: 70_000
  densify_from_iter: 8_000

  position_lr_init: 5e-3
  position_lr_final: 5e-5
  position_lr_delay_mult: 0.01
  position_lr_max_steps: 100_000
  
  w_lpips: 0.1
  lambda_lpips_end: 0.9
  lpips_linear_start: 10_000
  lpips_linear_end: 60_000
  deform_net_w_decay: 2e-3
  deform_net_lr_init: 1e-5
  deform_net_lr_final: 1e-7
  deform_net_lr_delay_mult: 0.01
  deform_net_lr_max_steps: 100_000
  lambda_laplacian: 1.0
  lambda_relative_deform: 0.4
  lambda_relative_rot: 0.005

  neck_lr_init: 1e-5
  neck_lr_final: 1e-7
  neck_lr_delay_mult: 0.01
  neck_lr_max_steps: 100_000

  lambda_neck: 1.0
  
model_params:
  n_unet_layers: 6
  n_points_per_triangle: 2
  use_lower_jaw: True
  static_neck: False
  gaussian_init_type: "scaled"
  use_expr_mask: True
  uv_resolution: 128
  n_gaussians_init: 100_000

  sh_degree: 3

```

## /configs/avatar/low_quality.yaml

```yaml path="/configs/avatar/low_quality.yaml" 
opt_params:
  iterations: 25_000

  sh_warmup_iterations: 500

  # clip_probability: 0.2

  lambda_scale: 1.
  threshold_scale: 1.
  lambda_xyz: 1e-3
  threshold_xyz: 2.
  metric_xyz: False
  metric_scale: False

  feature_lr: 0.0025
  opacity_lr: 0.025
  scaling_lr: 0.005  # (scaled up according to mean triangle scale)  # 0.005 (original)
  rotation_lr: 0.001
  percent_dense: 0.01
  lambda_dssim: 0.4

  densification_interval: 1000
  densify_grad_threshold: 1e-6
  opacity_reset_interval: 10_000
  densify_until_iter: 24_000
  densify_from_iter: 2_000

  position_lr_init: 5e-3
  position_lr_final: 5e-5
  position_lr_delay_mult: 0.01
  position_lr_max_steps: 25_000
  
  w_lpips: 0.1
  lambda_lpips_end: 0.9
  lpips_linear_start: 4_000
  lpips_linear_end: 24_000
  deform_net_w_decay: 2e-3
  deform_net_lr_init: 1e-5
  deform_net_lr_final: 1e-7
  deform_net_lr_delay_mult: 0.01
  deform_net_lr_max_steps: 25_000
  lambda_laplacian: 1.0
  lambda_relative_deform: 0.4
  lambda_relative_rot: 0.005

  neck_lr_init: 1e-5
  neck_lr_final: 1e-7
  neck_lr_delay_mult: 0.01
  neck_lr_max_steps: 25_000

  lambda_neck: 1.0
  
model_params:
  n_unet_layers: 6
  n_points_per_triangle: 2
  use_lower_jaw: True
  static_neck: False
  gaussian_init_type: "scaled"
  use_expr_mask: True
  uv_resolution: 128
  n_gaussians_init: 25_000

  sh_degree: 3

```

## /configs/avatar/medium_quality.yaml

```yaml path="/configs/avatar/medium_quality.yaml" 
opt_params:
  iterations: 50_000

  sh_warmup_iterations: 1_000

  # clip_probability: 0.2

  lambda_scale: 1.
  threshold_scale: 1.
  lambda_xyz: 1e-3
  threshold_xyz: 2.
  metric_xyz: False
  metric_scale: False

  feature_lr: 0.0025
  opacity_lr: 0.025
  scaling_lr: 0.005  # (scaled up according to mean triangle scale)  # 0.005 (original)
  rotation_lr: 0.001
  percent_dense: 0.01
  lambda_dssim: 0.4

  densification_interval: 2000
  densify_grad_threshold: 1e-6
  opacity_reset_interval: 10_000
  densify_until_iter: 40_000
  densify_from_iter: 4_000

  position_lr_init: 5e-3
  position_lr_final: 5e-5
  position_lr_delay_mult: 0.01
  position_lr_max_steps: 50_000
  
  w_lpips: 0.1
  lambda_lpips_end: 0.9
  lpips_linear_start: 5_000
  lpips_linear_end: 40_000
  deform_net_w_decay: 2e-3
  deform_net_lr_init: 1e-5
  deform_net_lr_final: 1e-7
  deform_net_lr_delay_mult: 0.01
  deform_net_lr_max_steps: 50_000
  lambda_laplacian: 1.0
  lambda_relative_deform: 0.4
  lambda_relative_rot: 0.005

  neck_lr_init: 1e-5
  neck_lr_final: 1e-7
  neck_lr_delay_mult: 0.01
  neck_lr_max_steps: 50_000

  lambda_neck: 1.0
  
model_params:
  n_unet_layers: 6
  n_points_per_triangle: 2
  use_lower_jaw: True
  static_neck: False
  gaussian_init_type: "scaled"
  use_expr_mask: True
  uv_resolution: 128
  n_gaussians_init: 50_000

  sh_degree: 3

```

## /configs/generation/debug.yaml

```yaml path="/configs/generation/debug.yaml" 
n_ddim_steps: 10
cfg_scale: 2.0
resolution: 512
seed: 124
R_max: 4
V: 8

ckpt_path: ./data/weights/mmdm/

generation_data:
  data_path: ./data/assets/datasets/gen_data.npz
  yaw_range: 55
  pitch_range: 20
  expr_factor: 1.0
  n_samples: 28

```

## /configs/generation/default.yaml

```yaml path="/configs/generation/default.yaml" 
n_ddim_steps: 100
cfg_scale: 2.0
resolution: 512
seed: 124
R_max: 4
V: 8

ckpt_path: ./data/weights/mmdm/

generation_data:
  data_path: ./data/assets/datasets/gen_data.npz
  yaw_range: 55
  pitch_range: 20
  expr_factor: 1.0
  n_samples: 840

```

## /configs/generation/high_quality.yaml

```yaml path="/configs/generation/high_quality.yaml" 
n_ddim_steps: 100
cfg_scale: 2.0
resolution: 512
seed: 124
R_max: 4
V: 8

ckpt_path: ./data/weights/mmdm/

generation_data:
  data_path: ./data/assets/datasets/gen_data.npz
  yaw_range: 55
  pitch_range: 20
  expr_factor: 1.0
  n_samples: 840

```

## /configs/generation/low_quality.yaml

```yaml path="/configs/generation/low_quality.yaml" 
n_ddim_steps: 50
cfg_scale: 2.0
resolution: 512
seed: 124
R_max: 4
V: 8

ckpt_path: ./data/weights/mmdm/

generation_data:
  data_path: ./data/assets/datasets/gen_data.npz
  yaw_range: 55
  pitch_range: 20
  expr_factor: 1.0
  n_samples: 210

```

## /configs/generation/medium_quality.yaml

```yaml path="/configs/generation/medium_quality.yaml" 
n_ddim_steps: 50
cfg_scale: 2.0
resolution: 512
seed: 124
R_max: 4
V: 8

ckpt_path: ./data/weights/mmdm/

generation_data:
  data_path: ./data/assets/datasets/gen_data.npz
  yaw_range: 55
  pitch_range: 20
  expr_factor: 1.0
  n_samples: 420

```

## /configs/mmdm/cap4d_mmdm_final.yaml

```yaml path="/configs/mmdm/cap4d_mmdm_final.yaml" 
dataset_root: PATH_TO_DATASET

gpu_batch_size: 1
virtual_batch_size: 64
logger_freq: 400
learning_rate: 1e-4
n_steps: 100000
n_ref: 4
save_every_n_steps: 1000
init_path: "models/v2-1_512-ema-pruned.ckpt"

dataset:
  target: cap4d.datasets.concat_dataset.ConcatDataset
  params:
    dataset_configs:
      - target: cap4d.datasets.flow_face_dataset.FlowFaceDataset
        params:
          data_path: ${dataset_root}/mead-compressed/
          split: train
          source_frame: shuffle
          resolution: 512
          n_views: 8
          max_n_references: ${n_ref}
          add_mouth: True
          adapter_config:
            # NOTE(review): this MEAD dataset entry reuses NersembleAdapter —
            # confirm this is intentional and not a copy-paste from the
            # nersemble entry below.
            target: cap4d.datasets.adapters.nersemble_adapter.NersembleAdapter
            params:
              is_compressed: True
      - target: cap4d.datasets.flow_face_dataset.FlowFaceDataset
        params:
          data_path: ${dataset_root}/vfhq-compressed/
          split: train
          source_frame: shuffle
          resolution: 512
          n_views: 8
          max_n_references: ${n_ref}
          add_mouth: True
          adapter_config:
            target: cap4d.datasets.adapters.vfhq_adapter.VFHQAdapter
            params:
              is_compressed: True
      - target: cap4d.datasets.flow_face_dataset.FlowFaceDataset
        params:
          data_path: ${dataset_root}/ava-compressed/
          split: train
          source_frame: shuffle
          resolution: 512
          n_views: 8
          max_n_references: ${n_ref}
          add_mouth: True
          adapter_config:
            target: cap4d.datasets.adapters.ava_adapter.AvaAdapter
            params:
              is_compressed: True
      - target: cap4d.datasets.flow_face_dataset.FlowFaceDataset
        params:
          data_path: ${dataset_root}/nersemble-compressed/
          split: train
          source_frame: shuffle
          resolution: 512
          n_views: 8
          max_n_references: ${n_ref}
          add_mouth: True
          adapter_config:
            target: cap4d.datasets.adapters.nersemble_adapter.NersembleAdapter
            params:
              is_compressed: True

model:
  target: cap4d.mmdm.mmdm.MMLDM
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    n_frames: 8
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    control_key: "hint"
    image_size: 64
    channels: 4
    cond_stage_trainable: False
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    only_mid_control: False

    shift_schedule: True
    zero_snr_shift: True
    sqrt_shift: True
    minus_one_shift: True

    unet_config:
      target: cap4d.mmdm.net.mmdm_unet.MMDMUnetModel
      params:
        # use_checkpoint: True
        image_size: 64 # unused
        time_steps: 8
        temporal_mode: "3d"
        in_channels: 4
        out_channels: 4
        model_channels: 320
        condition_channels: 50  # 42 (pos map) + 3 (expr deform map) + 3 (ray map) + 1 (ref mask) + 1 (crop mask) 
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        use_checkpoint: False
        legacy: False

    first_stage_config:
      target: controlnet.ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 512
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: cap4d.mmdm.conditioning.cap4dcond.CAP4DConditioning
      params:
        image_size: 64
        positional_channels: 42
        positional_multiplier: 1.
        super_resolution: 2
        use_ray_directions: True
        use_expr_deformation: True
        use_crop_mask: True

```

## /controlnet/cldm/ddim_hacked.py

```py path="/controlnet/cldm/ddim_hacked.py" 
"""SAMPLING ONLY."""

import torch
import numpy as np
from tqdm import tqdm

from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor


class DDIMSampler(object):
    """DDIM sampler around a wrapped diffusion model.

    The model is expected to expose `betas`, `alphas_cumprod`,
    `alphas_cumprod_prev`, `num_timesteps`, `device`, `apply_model`, a
    `parameterization` flag and, for v-prediction, the
    `predict_eps_from_z_and_v` / `predict_start_from_z_and_v` helpers.
    Schedule tables are (re)built on every `sample()` call via
    `make_schedule`.
    """

    def __init__(self, model, schedule="linear", **kwargs):
        # `schedule` is stored but not otherwise used by this class;
        # extra kwargs are accepted and ignored.
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        # Not an nn.Module buffer: simply stores the value as an attribute.
        # NOTE(review): tensors are moved to CUDA unconditionally — this
        # assumes a CUDA device is available; confirm for CPU-only runs.
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute the DDIM timestep subsequence and its alpha/sigma tables.

        With ddim_eta == 0 sampling is deterministic; eta > 0 adds DDPM-style
        stochasticity through the ddim_sigmas table.
        """
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.detach().cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.detach().cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.detach().cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.detach().cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.detach().cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.detach().cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta,verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        # np.sqrt applied to the tensor-like ddim_alphas; stored for reuse in
        # p_sample_ddim / stochastic_encode.
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               dynamic_threshold=None,
               ucg_schedule=None,
               **kwargs
               ):
        """Run DDIM sampling with `S` steps; returns (samples, intermediates).

        `shape` excludes the batch dimension; the full sampling size is
        (batch_size, *shape).
        """
        if conditioning is not None:
            # NOTE(review): dead code — this batch-size sanity check is
            # disabled with `if False:`; its list branch also prints an
            # undefined `cbs`, which would raise NameError if re-enabled.
            if False:
                if isinstance(conditioning, dict):
                    ctmp = conditioning[list(conditioning.keys())[0]]
                    while isinstance(ctmp, list): ctmp = ctmp[0]
                    cbs = ctmp.shape[0]
                    if cbs != batch_size:
                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

                elif isinstance(conditioning, list):
                    for ctmp in conditioning:
                        if ctmp.shape[0] != batch_size:
                            print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

                else:
                    if conditioning.shape[0] != batch_size:
                        print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        # C, H, W = shape
        size = (batch_size, *shape)
        print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    ucg_schedule=ucg_schedule
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
                      ucg_schedule=None):
        """Main reverse-diffusion loop; returns (final image, intermediates dict)."""
        device = self.model.betas.device
        # all leading (batch-like) dims; the last three are C, H, W
        b_t = shape[:-3]  # B, T, C, H, W
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            # truncate the ddim schedule to at most `timesteps` entries
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            # `index` counts down and indexes the precomputed ddim tables
            index = total_steps - i - 1
            ts = torch.full(b_t, step, device=device, dtype=torch.long)

            if mask is not None:
                # inpainting: keep masked regions pinned to the noised x0
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            if ucg_schedule is not None:
                # per-step override of the classifier-free guidance scale
                assert len(ucg_schedule) == len(time_range)
                unconditional_guidance_scale = ucg_schedule[i]

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None):
        """Single DDIM update step x_t -> x_{t-1}; returns (x_prev, pred_x0)."""
        b_t, device = x.shape[:-3], x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            model_output = self.model.apply_model(x, t, c)
        else:
            # classifier-free guidance: two forward passes, then extrapolate
            model_t = self.model.apply_model(x, t, c)
            model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)

        if self.model.parameterization == "v":
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
        else:
            e_t = model_output

        if score_corrector is not None:
            assert self.model.parameterization == "eps", 'not implemented'
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((*b_t, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((*b_t, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((*b_t, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((*b_t, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

        # current prediction for x_0
        if self.model.parameterization != "v":
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        if dynamic_threshold is not None:
            raise NotImplementedError()

        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    @torch.no_grad()
    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
        """Deterministically invert x0 for t_enc steps (DDIM inversion).

        Returns (x_encoded, out_dict) where out_dict contains the encoded
        latent, the intermediate step indices and, optionally, intermediates.
        """
        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        num_reference_steps = timesteps.shape[0]

        assert t_enc <= num_reference_steps
        num_steps = t_enc

        if use_original_steps:
            alphas_next = self.alphas_cumprod[:num_steps]
            alphas = self.alphas_cumprod_prev[:num_steps]
        else:
            alphas_next = self.ddim_alphas[:num_steps]
            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

        x_next = x0
        intermediates = []
        inter_steps = []
        for i in tqdm(range(num_steps), desc='Encoding Image'):
            t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
            if unconditional_guidance_scale == 1.:
                noise_pred = self.model.apply_model(x_next, t, c)
            else:
                # batched CFG: conditional and unconditional in one forward pass
                assert unconditional_conditioning is not None
                e_t_uncond, noise_pred = torch.chunk(
                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
                                           torch.cat((unconditional_conditioning, c))), 2)
                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)

            # reversed DDIM update: step x from alpha[i] to alpha_next[i]
            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
            weighted_noise_pred = alphas_next[i].sqrt() * (
                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
            x_next = xt_weighted + weighted_noise_pred
            if return_intermediates and i % (
                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
                intermediates.append(x_next)
                inter_steps.append(i)
            elif return_intermediates and i >= num_steps - 2:
                intermediates.append(x_next)
                inter_steps.append(i)
            if callback: callback(i)

        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
        if return_intermediates:
            out.update({'intermediates': intermediates})
        return x_next, out

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        """Noise x0 directly to timestep t via the closed-form q(x_t | x_0)."""
        # fast, but does not allow for exact reconstruction
        # t serves as an index to gather the correct alphas
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(x0)
        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)

    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
               use_original_steps=False, callback=None):
        """Denoise a partially-noised latent from timestep t_start back to x_0."""

        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
            if callback: callback(i)
        return x_dec

```

## /controlnet/cldm/hack.py

```py path="/controlnet/cldm/hack.py" 
import torch
import einops

import controlnet.ldm.modules.encoders.modules
import controlnet.ldm.modules.attention

from transformers import logging
from controlnet.ldm.modules.attention import default


def disable_verbosity():
    """Silence HuggingFace transformers logging (errors only)."""
    logging.set_verbosity_error()
    print('logging improved.')


def enable_sliced_attention():
    """Monkey-patch CrossAttention.forward with the memory-saving sliced variant.

    Bug fix: the original referenced the bare name `ldm`, which is never
    bound in this module (only `controlnet.ldm.modules.attention` is
    imported), so calling this function raised NameError. Use the fully
    qualified imported module instead.
    """
    controlnet.ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
    print('Enabled sliced_attention.')
    return


def hack_everything(clip_skip=0):
    """Apply all CLIP hacks: quiet logging and patch FrozenCLIPEmbedder.

    Bug fix: the original referenced the bare name `ldm`, which is never
    bound in this module (only `controlnet.ldm.modules.encoders.modules` is
    imported), so calling this function raised NameError. Use the fully
    qualified imported module instead.

    Args:
        clip_skip: number of final transformer layers to skip when encoding
            text (0/1 means use the last hidden state).
    """
    disable_verbosity()
    controlnet.ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
    controlnet.ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
    print('Enabled clip hacks.')
    return


# Written by Lvmin
def _hacked_clip_forward(self, text):
    """Replacement for FrozenCLIPEmbedder.forward that supports prompts longer
    than CLIP's 77-token window.

    Each prompt is tokenized without special tokens, split into three chunks of
    up to 75 tokens (tokens beyond 225 are silently dropped), each chunk is
    wrapped with BOS/EOS and padded to 77, encoded independently, and the three
    encodings are concatenated along the sequence axis, yielding a
    (batch, 3 * 77, channels) conditioning tensor.
    """
    PAD = self.tokenizer.pad_token_id
    EOS = self.tokenizer.eos_token_id
    BOS = self.tokenizer.bos_token_id

    def tokenize(t):
        # Raw token ids per prompt: no truncation, no BOS/EOS.
        return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]

    def transformer_encode(t):
        # clip_skip > 1: take an earlier hidden state and re-apply the final
        # layer norm; otherwise use the standard last hidden state.
        if self.clip_skip > 1:
            rt = self.transformer(input_ids=t, output_hidden_states=True)
            return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
        else:
            return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state

    def split(x):
        # Three consecutive 75-token windows; anything past 225 tokens is lost.
        return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]

    def pad(x, p, i):
        # Truncate to length i, or right-pad with p up to length i.
        return x[:i] if len(x) >= i else x + [p] * (i - len(x))

    raw_tokens_list = tokenize(text)
    tokens_list = []

    for raw_tokens in raw_tokens_list:
        raw_tokens_123 = split(raw_tokens)
        raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
        raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
        tokens_list.append(raw_tokens_123)

    tokens_list = torch.IntTensor(tokens_list).to(self.device)

    # Fold the 3 chunks into the batch dim, encode, then concatenate the
    # per-chunk sequences back along the sequence axis.
    feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
    y = transformer_encode(feed)
    z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)

    return z


# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
    """Memory-saving CrossAttention.forward: computes attention one
    (batch*head) slice at a time instead of materializing the full attention
    tensor at once. Numerically equivalent to the standard forward; `mask` is
    accepted for signature compatibility but ignored.

    NOTE(review): the misspelling "attentin" is kept — `enable_sliced_attention`
    patches this function in by this exact name.
    """
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)
    k = self.to_k(context)
    v = self.to_v(context)
    # Free references early to reduce peak memory.
    del context, x

    # (b, n, h*d) -> (b*h, n, d) for per-head attention.
    q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    limit = k.shape[0]
    att_step = 1  # slices of one (batch*head) row at a time
    q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
    k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
    v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))

    # Reversed so pop() yields chunks in original order.
    q_chunks.reverse()
    k_chunks.reverse()
    v_chunks.reverse()
    # `sim` accumulates the attention OUTPUT (despite the name), shape (b*h, n, d).
    sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
    del k, q, v
    for i in range(0, limit, att_step):
        q_buffer = q_chunks.pop()
        k_buffer = k_chunks.pop()
        v_buffer = v_chunks.pop()
        # Scaled dot-product scores for this slice only.
        sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale

        del k_buffer, q_buffer
        # attention, what we cannot get enough of, by chunks

        sim_buffer = sim_buffer.softmax(dim=-1)

        sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
        del v_buffer
        sim[i:i + att_step, :, :] = sim_buffer

        del sim_buffer
    # (b*h, n, d) -> (b, n, h*d) and final output projection.
    sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(sim)

```

## /controlnet/cldm/logger.py

```py path="/controlnet/cldm/logger.py" 
import os

import numpy as np
import torch
import torchvision
from pathlib import Path
from PIL import Image
from pytorch_lightning.callbacks import Callback
# from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.rank_zero import rank_zero_only
import einops
import torch.nn.functional as F
import gc


class ImageLogger(Callback):
    """Lightning callback that periodically saves visualizations produced by
    the model's `log_images()` (plus conditioning visualizations from the
    conditioning stage model) as stacked JPEG grids under
    `<save_dir>/image_log/<split>/e-<epoch>/`.
    """

    def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
                 rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
                 log_images_kwargs=None):
        """
        Args:
            batch_frequency: log every N batches (see `check_frequency`).
            max_images: cap on the number of samples logged per image key.
            clamp: clamp logged tensors to [-1, 1] before saving.
            increase_log_steps: kept for interface compatibility with the
                original ControlNet logger; `log_steps` is not read here.
            rescale: map [-1, 1] to [0, 1] before converting to uint8.
            disabled: if True, the callback does nothing.
            log_on_batch_idx: kept for interface compatibility; the frequency
                check below always uses batch_idx.
            log_first_step: kept for interface compatibility; unused here.
            log_images_kwargs: extra kwargs forwarded to `log_images()`.
        """
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        # Fix: always define the attribute (the original only set it when
        # increase_log_steps was False, leaving it undefined otherwise).
        self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    # @rank_zero_only
    def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx, gpu_id=0):
        """Render each 5-D entry of `images` to a grid row and save all rows
        stacked vertically as one JPEG. 4-D (single image) entries are
        currently skipped (TODO), as is any entry of unexpected rank."""
        root = os.path.join(save_dir, "image_log", split, f"e-{current_epoch:06d}")

        rows = []

        for k in images:
            if len(images[k].shape) == 4:
                # Single images
                # TODO: Implement
                continue
            if len(images[k].shape) != 5:
                # Fix: unexpected ranks previously reused the previous
                # iteration's `grid` (or raised NameError on the first key).
                continue

            # We have videos: (batch, time, channels, height, width).
            b, t = images[k].shape[:2]
            imgs = einops.rearrange(images[k], 'b t c h w -> (b t) c h w')
            grid = torchvision.utils.make_grid(imgs, nrow=b * t)

            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w

            grid = grid.permute(1, 2, 0).numpy()
            grid = (grid * 255).astype(np.uint8)

            rows.append(grid)

        if not rows:
            # Fix: avoid np.concatenate on an empty list when nothing was
            # renderable.
            return

        filename = "gs-{:06}_b-{:06}_{:02d}_i.jpg".format(global_step, batch_idx, gpu_id)
        path = os.path.join(root, filename)
        os.makedirs(os.path.split(path)[0], exist_ok=True)
        Image.fromarray(np.concatenate(rows, axis=0)).save(path)

    def log_cond(self, pl_module, batch):
        """Run the conditioning stage model on this batch and return its
        positional-encoding visualizations, upsampled x8 with nearest-neighbor
        (presumably latent -> pixel scale — TODO confirm), as
        (batch, time, c, h, w) tensors."""
        cond_model = pl_module.cond_stage_model
        cond_key = pl_module.control_key

        c_cond = cond_model(batch[cond_key], conditioned=True)
        enc_vis = cond_model.get_vis(c_cond["pos_enc"])

        for key in enc_vis:
            vis = enc_vis[key]
            b_ = vis.shape[0]
            vis = einops.rearrange(vis, 'b t h w c -> (b t) c h w')
            vis = F.interpolate(vis, scale_factor=8., mode="nearest")
            enc_vis[key] = einops.rearrange(vis, '(b t) c h w -> b t c h w', b=b_)

        return enc_vis

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        """If the frequency check passes, collect `log_images` output and
        conditioning visualizations and save them via `log_local`. Switches
        the module to eval mode for logging and restores training mode."""
        check_idx = batch_idx  # if self.log_on_batch_idx else pl_module.global_step
        if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and
                callable(pl_module.log_images) and
                self.max_images > 0):
            # Fix: removed the unused `logger = type(pl_module.logger)` local.
            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            # draw vertices!
            with torch.no_grad():
                cond_vis = self.log_cond(pl_module, batch)

                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        # Entries whose last dim is 3 are left unclamped
                        # (original behavior — presumably already uint8-like;
                        # TODO confirm).
                        if not images[k].shape[-1] == 3:
                            images[k] = torch.clamp(images[k], -1., 1.)

            for key in cond_vis:
                images[key] = cond_vis[key].detach().cpu().clamp(-1., 1.)

            self.log_local(
                pl_module.logger.save_dir,
                split,
                images,
                pl_module.global_step,
                pl_module.current_epoch,
                batch_idx,
                gpu_id=pl_module.global_rank,
            )

            gc.collect()

            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        """True on every `batch_freq`-th batch (including batch 0)."""
        return check_idx % self.batch_freq == 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Lightning hook: maybe log images at the end of a training batch."""
        if not self.disabled:
            self.log_img(pl_module, batch, batch_idx, split="train")

```

## /controlnet/cldm/model.py

```py path="/controlnet/cldm/model.py" 
import os
import torch

from omegaconf import OmegaConf
from controlnet.ldm.util import instantiate_from_config


def get_state_dict(d):
    """Return the nested 'state_dict' entry when present, else the mapping itself."""
    if 'state_dict' in d:
        return d['state_dict']
    return d


def load_state_dict(ckpt_path, location='cpu'):
    """Load a model state dict from either a .safetensors file or a pickle
    checkpoint, unwrapping any nested 'state_dict' entry, and return it.

    NOTE(review): the pickle path uses torch.load, which executes arbitrary
    code on untrusted files — only load trusted checkpoints.
    """
    ext = os.path.splitext(ckpt_path)[1].lower()
    if ext == ".safetensors":
        import safetensors.torch
        sd = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        sd = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
    sd = get_state_dict(sd)
    print(f'Loaded state_dict from [{ckpt_path}]')
    return sd


def create_model(config_path):
    """Instantiate the model described by the YAML config's `model` section
    and return it on the CPU."""
    cfg = OmegaConf.load(config_path)
    net = instantiate_from_config(cfg.model).cpu()
    print(f'Loaded model config from [{config_path}]')
    return net

```

## /controlnet/ldm/data/__init__.py

```py path="/controlnet/ldm/data/__init__.py" 

```

## /controlnet/ldm/data/util.py

```py path="/controlnet/ldm/data/util.py" 
import torch

from ldm.modules.midas.api import load_midas_transform


class AddMiDaS(object):
    """Dataset transform that attaches a MiDaS-ready copy of the image to a
    sample dict under the key 'midas_in'."""

    def __init__(self, model_type):
        super().__init__()
        # MiDaS preprocessing transform for the given model variant.
        self.transform = load_midas_transform(model_type)

    def pt2np(self, x):
        # Tensor in [-1, 1] -> numpy array in [0, 1].
        return ((x + 1.0) * .5).detach().cpu().numpy()

    def np2pt(self, x):
        # Numpy array in [0, 1] -> tensor in [-1, 1].
        return torch.from_numpy(x) * 2 - 1.

    def __call__(self, sample):
        # sample['jpg'] is tensor hwc in [-1, 1] at this point
        img = self.pt2np(sample['jpg'])
        sample['midas_in'] = self.transform({"image": img})["image"]
        return sample
```

## /controlnet/ldm/models/autoencoder.py

```py path="/controlnet/ldm/models/autoencoder.py" 
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager

from controlnet.ldm.modules.diffusionmodules.model import Encoder, Decoder
from controlnet.ldm.modules.distributions.distributions import DiagonalGaussianDistribution

from controlnet.ldm.util import instantiate_from_config
from controlnet.ldm.modules.ema import LitEma


class AutoencoderKL(pl.LightningModule):
    """KL-regularized convolutional autoencoder (the VAE "first stage" of
    latent diffusion models).

    Encodes images to a diagonal Gaussian posterior over latents and decodes
    latents back to images. Training alternates two optimizers through
    `self.loss`: index 0 trains encoder/decoder (+ optional logvar), index 1
    trains the loss's discriminator. Optionally keeps an EMA weight copy.
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False
                 ):
        # NOTE(review): `ignore_keys=[]` is a mutable default; harmless here
        # because it is only iterated, never mutated.
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        # double_z: encoder outputs mean and logvar stacked on channels,
        # hence the 2x factors in quant_conv below.
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            # Random projection used by to_rgb() for segmentation inputs.
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0. < ema_decay < 1.
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Load weights from a Lightning checkpoint at `path`, dropping any
        key that starts with one of `ignore_keys` (non-strict load)."""
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    @contextmanager
    def ema_scope(self, context=None):
        """Context manager that temporarily swaps in the EMA weights (no-op
        when EMA is disabled) and restores the training weights on exit."""
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        # Update the EMA copy after every training batch.
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode an image batch to a DiagonalGaussianDistribution posterior."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode a latent batch `z` back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Full autoencode pass. Returns (reconstruction, posterior); samples
        from the posterior when `sample_posterior`, else uses its mode."""
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Fetch batch[k] and convert it from (b, h, w, c) to contiguous
        (b, c, h, w); a missing channel dim is added first."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)  # .float() CONTIGUOUS
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # NOTE(review): the `optimizer_idx` argument is the pre-2.0
        # PyTorch Lightning multi-optimizer API — confirm the installed
        # Lightning version supports it.
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Validate with both the current weights and (when enabled) the EMA
        weights; the EMA pass logs metrics with an "_ema" postfix, but only
        the non-EMA log dict is returned."""
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, postfix=""):
        # Evaluate both loss heads (0 = autoencoder, 1 = discriminator) and
        # log their metrics under "val{postfix}/...".
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val"+postfix)

        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val"+postfix)

        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        """Two Adam optimizers sharing one learning rate: one for the
        autoencoder parameters (+ logvar when learned), one for the
        discriminator inside `self.loss`."""
        lr = self.learning_rate
        ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
            self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
        if self.learn_logvar:
            print(f"{self.__class__.__name__}: Learning logvar")
            ae_params_list.append(self.loss.logvar)
        opt_ae = torch.optim.Adam(ae_params_list,
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        # Used by the loss for adaptive discriminator weighting.
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        """Build a dict of image tensors for logging: inputs, reconstructions
        and decoded random latents, plus EMA variants when enabled."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
            if log_ema or self.use_ema:
                with self.ema_scope():
                    xrec_ema, posterior_ema = self(x)
                    if x.shape[1] > 3:
                        # colorize with random projection
                        assert xrec_ema.shape[1] > 3
                        xrec_ema = self.to_rgb(xrec_ema)
                    log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
                    log["reconstructions_ema"] = xrec_ema
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation tensor to 3 channels via a
        fixed random 1x1 convolution, rescaled to [-1, 1]."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x


class IdentityFirstStage(torch.nn.Module):
    """No-op first stage: encode/decode/forward return their input unchanged.
    With `vq_interface=True`, `quantize` mimics a VQ model's
    (quantized, loss, info) return signature."""

    def __init__(self, *args, vq_interface=False, **kwargs):
        # Fix: initialize nn.Module BEFORE assigning attributes, so
        # Module.__setattr__ bookkeeping is fully set up (the original
        # assigned first, relying on fragile ordering).
        super().__init__()
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        # Identity: latent space == input space.
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            # Mimic (quantized, emb_loss, info) of a real VQ layer.
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x


```

## /controlnet/ldm/models/diffusion/__init__.py

```py path="/controlnet/ldm/models/diffusion/__init__.py" 

```

## /controlnet/ldm/models/diffusion/ddim.py

```py path="/controlnet/ldm/models/diffusion/ddim.py" 
"""SAMPLING ONLY."""

import torch
import numpy as np
from tqdm import tqdm

from controlnet.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor


class DDIMSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        """Wrap a trained diffusion `model` (must expose `num_timesteps`,
        `betas`/`alphas_cumprod` and `apply_model`) for DDIM sampling.
        `schedule` names the beta schedule; extra kwargs are ignored."""
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        """Store `attr` as a plain attribute (this is NOT nn.Module's
        register_buffer — DDIMSampler is a plain object).

        NOTE(review): tensors are force-moved to CUDA here, so sampling
        assumes a CUDA device is available — confirm before running on CPU.
        """
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute the DDIM timestep subsequence and per-step constants.

        Selects `ddim_num_steps` timesteps out of the model's DDPM schedule
        (discretized per `ddim_discretize`) and registers, via the custom
        `register_buffer`, the copied DDPM schedules plus the DDIM alphas,
        previous-alphas and sigmas. `ddim_eta=0` yields deterministic DDIM;
        larger eta adds stochasticity.
        """
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.detach().cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.detach().cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.detach().cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.detach().cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.detach().cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.detach().cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta,verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        # Sigmas matching the FULL DDPM step count, used when sampling with
        # use_original_steps=True.
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               dynamic_threshold=None,
               ucg_schedule=None,
               **kwargs
               ):
        """Public DDIM entry point.

        Samples `batch_size` latents of per-sample shape `shape` using `S`
        DDIM steps with stochasticity `eta`: builds the step schedule via
        `make_schedule`, then delegates to `ddim_sampling`. Returns
        (samples, intermediates).
        """
        # Best-effort sanity check that the conditioning batch size matches
        # `batch_size`; mismatches only produce a warning.
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list): ctmp = ctmp[0]
                if not isinstance(ctmp, dict):
                    cbs = ctmp.shape[0]
                    if cbs != batch_size:
                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

            elif isinstance(conditioning, list):
                for ctmp in conditioning:
                    if ctmp.shape[0] != batch_size:
                        # Fix: this branch previously printed `{cbs}`, a name
                        # not defined here (NameError); report the real size.
                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")

            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        # C, H, W = shape
        size = (batch_size, *shape)
        print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    ucg_schedule=ucg_schedule
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
                      ucg_schedule=None):
        """Run the reverse DDIM loop from noise (or `x_T`) down to x_0.

        Iterates the schedule in descending order, calling `p_sample_ddim`
        at each step. With `mask`/`x0`, known regions are re-noised and
        blended back each step (inpainting-style). `ucg_schedule` can vary
        the guidance scale per step. Returns (final_sample, intermediates),
        where intermediates holds snapshots every `log_every_t` steps.
        """
        device = self.model.betas.device
        # Leading (batch-like) dims of the sample tensor: everything except
        # the trailing (C, H, W). Used to size the per-sample timestep tensor.
        b_ = shape[:-3]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            # Clamp an explicit step count to the precomputed DDIM schedule.
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            # `index` walks the precomputed DDIM constants from the end to 0.
            index = total_steps - i - 1
            ts = torch.full(b_, step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                # Re-noise the known region of x0 to the current step and keep
                # it where mask == 1; only the masked-out region is sampled.
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            if ucg_schedule is not None:
                assert len(ucg_schedule) == len(time_range)
                unconditional_guidance_scale = ucg_schedule[i]

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None):
        """One reverse DDIM step: x_t -> (x_{t-1}, pred_x0).

        Applies classifier-free guidance when `unconditional_conditioning`
        is given and the scale != 1 by batching the unconditional and
        conditional passes. `index` selects the precomputed per-step
        constants; `t` is the corresponding DDPM timestep tensor.
        Supports eps- and v-parameterized models.
        """
        # Leading batch-like dims (everything except trailing C, H, W).
        b_t, device = x.shape[:-3], x.device

        # model_output = self.model.apply_model(x, t, c)
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            model_output = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            # if isinstance(c, dict):
            #     assert isinstance(unconditional_conditioning, dict)
            #     c_in = dict()
            #     for k in c:
            #         if isinstance(c[k], list):
            #             c_in[k] = [torch.cat([
            #                 unconditional_conditioning[k][i],
            #                 c[k][i]]) for i in range(len(c[k]))]
            #         else:
            #             c_in[k] = torch.cat([
            #                     unconditional_conditioning[k],
            #                     c[k]])
            # elif isinstance(c, list):
            #     c_in = list()
            #     assert isinstance(unconditional_conditioning, list)
            #     for i in range(len(c)):
            #         c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
            # else:
            #     c_in = torch.cat([unconditional_conditioning, c])

            # HACK:
            # Batch uncond before cond per key. Assumes c == {"c_concat":
            # [dict-of-tensors]} and unconditional_conditioning is the
            # matching dict — TODO confirm against the conditioning model.
            c_cond = c["c_concat"][0]
            c_uncond = unconditional_conditioning
            c_in = dict()
            for key in c_cond:
                c_in[key] = torch.cat([c_uncond[key], c_cond[key]], dim=0)
            c_in = {"c_concat": [c_in]}
            ###

            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            # Classifier-free guidance: push the prediction away from uncond.
            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)

        if self.model.parameterization == "v":
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
        else:
            e_t = model_output

        if score_corrector is not None:
            assert self.model.parameterization == "eps", 'not implemented'
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((*b_t, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((*b_t, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((*b_t, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((*b_t, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

        # current prediction for x_0
        if self.model.parameterization != "v":
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        if dynamic_threshold is not None:
            raise NotImplementedError()

        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    @torch.no_grad()
    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
        """Deterministically encode `x0` forward for `t_enc` steps (DDIM inversion).

        Runs the DDIM update in reverse so the result can later be decoded back
        to (approximately) `x0` with `decode`. Supports classifier-free guidance
        when `unconditional_conditioning` is given and the scale differs from 1.

        Returns a tuple `(x_encoded, out_dict)` where `out_dict` contains the
        encoded latent, the recorded step indices and, when
        `return_intermediates` is set, the intermediate latents.
        """
        num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]

        assert t_enc <= num_reference_steps
        num_steps = t_enc

        # alpha_cumprod for the next (more noised) position and the current one
        if use_original_steps:
            alphas_next = self.alphas_cumprod[:num_steps]
            alphas = self.alphas_cumprod_prev[:num_steps]
        else:
            alphas_next = self.ddim_alphas[:num_steps]
            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

        x_next = x0
        intermediates = []
        inter_steps = []
        for i in tqdm(range(num_steps), desc='Encoding Image'):
            # NOTE(review): `t` is the loop index, not the DDIM timestep value;
            # this mirrors the upstream CompVis implementation — confirm before changing.
            t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
            if unconditional_guidance_scale == 1.:
                noise_pred = self.model.apply_model(x_next, t, c)
            else:
                assert unconditional_conditioning is not None
                # classifier-free guidance: batch the unconditional and conditional passes
                e_t_uncond, noise_pred = torch.chunk(
                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
                                           torch.cat((unconditional_conditioning, c))), 2)
                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)

            # inverted DDIM step: move x_next from alpha level i to level i+1
            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
            weighted_noise_pred = alphas_next[i].sqrt() * (
                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
            x_next = xt_weighted + weighted_noise_pred
            # record evenly spaced intermediates, plus the final two steps
            if return_intermediates and i % (
                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
                intermediates.append(x_next)
                inter_steps.append(i)
            elif return_intermediates and i >= num_steps - 2:
                intermediates.append(x_next)
                inter_steps.append(i)
            if callback: callback(i)

        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
        if return_intermediates:
            out.update({'intermediates': intermediates})
        return x_next, out

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        """Diffuse `x0` to schedule position `t` by sampling q(x_t | x_0).

        Fast, but does not allow for exact reconstruction; `t` indexes into
        the full DDPM schedule or the DDIM-subsampled one.
        """
        if use_original_steps:
            sqrt_a = self.sqrt_alphas_cumprod
            sqrt_one_minus_a = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_a = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_a = self.ddim_sqrt_one_minus_alphas

        eps = torch.randn_like(x0) if noise is None else noise
        signal = extract_into_tensor(sqrt_a, t, x0.shape) * x0
        perturbation = extract_into_tensor(sqrt_one_minus_a, t, x0.shape) * eps
        return signal + perturbation

    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
               use_original_steps=False, callback=None):
        """Denoise an encoded latent back towards x_0 with DDIM sampling.

        Starts from `x_latent` at schedule position `t_start` and repeatedly
        applies `p_sample_ddim`, optionally with classifier-free guidance.
        """
        if use_original_steps:
            timesteps = np.arange(self.ddpm_num_timesteps)
        else:
            timesteps = self.ddim_timesteps
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        x_dec = x_latent
        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        for i, step in enumerate(iterator):
            # `index` walks the precomputed ddim alpha tables from the end
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
            if callback: callback(i)
        return x_dec
```

## /controlnet/ldm/models/diffusion/dpm_solver/__init__.py

```py path="/controlnet/ldm/models/diffusion/dpm_solver/__init__.py" 
from .sampler import DPMSolverSampler
```

## /controlnet/ldm/models/diffusion/dpm_solver/sampler.py

```py path="/controlnet/ldm/models/diffusion/dpm_solver/sampler.py" 
"""SAMPLING ONLY."""
import torch

from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver


# Maps the LatentDiffusion `parameterization` name to the `model_type`
# string expected by `model_wrapper` in dpm_solver.py.
MODEL_TYPES = {
    "eps": "noise",
    "v": "v"
}


class DPMSolverSampler(object):
    """Samples from a wrapped LatentDiffusion model using DPM-Solver.

    Thin adapter: builds a `NoiseScheduleVP` from the model's alpha schedule
    and delegates the actual sampling to `DPM_Solver`.
    """

    def __init__(self, model, **kwargs):
        # `model` must expose `alphas_cumprod`, `betas`, `parameterization`,
        # `apply_model` and `device` (a LatentDiffusion instance).
        super().__init__()
        self.model = model
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
        self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))

    def register_buffer(self, name, attr):
        """Store `attr` as a plain attribute, moving tensors onto the GPU.

        Bug fix: the original unconditionally called `.to("cuda")`, which
        crashed on CPU-only machines; only migrate when CUDA is available.
        """
        if isinstance(attr, torch.Tensor):
            if attr.device.type != "cuda" and torch.cuda.is_available():
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """Draw `batch_size` samples of latent `shape` (C, H, W) in `S` steps.

        Most keyword arguments exist only for interface compatibility with the
        other samplers and are ignored by DPM-Solver. Returns `(samples, None)`.
        """
        # sanity-check that the conditioning batch matches the requested batch
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')

        device = self.model.betas.device
        # start from pure noise unless a starting latent x_T is supplied
        if x_T is None:
            img = torch.randn(size, device=device)
        else:
            img = x_T

        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)

        # wrap the model with classifier-free guidance for the solver
        model_fn = model_wrapper(
            lambda x, t, c: self.model.apply_model(x, t, c),
            ns,
            model_type=MODEL_TYPES[self.model.parameterization],
            guidance_type="classifier-free",
            condition=conditioning,
            unconditional_condition=unconditional_conditioning,
            guidance_scale=unconditional_guidance_scale,
        )

        dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)

        # second element is None to match the (samples, intermediates) contract
        return x.to(device), None
```

## /controlnet/ldm/models/diffusion/plms.py

```py path="/controlnet/ldm/models/diffusion/plms.py" 
"""SAMPLING ONLY."""

import torch
import numpy as np
from tqdm import tqdm
from functools import partial

from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
from ldm.models.diffusion.sampling_util import norm_thresholding


class PLMSSampler(object):
    """Pseudo Linear Multi-Step (PLMS) sampler for latent diffusion models.

    Implements the pseudo numerical methods of arXiv:2202.09778: the DDIM
    update rule combined with Adams-Bashforth multistep estimates of the
    noise prediction. Deterministic only — requires eta == 0.
    """

    def __init__(self, model, schedule="linear", **kwargs):
        # `model` is the wrapped diffusion model; `schedule` is kept for
        # interface compatibility (the DDPM schedule comes from the model).
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        # Store `attr` as a plain attribute, moving tensors to the GPU
        # (mirrors nn.Module.register_buffer without gradient tracking).
        # NOTE(review): hard-codes CUDA — will fail on CPU-only machines.
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute the DDIM timestep subsequence and the alpha/sigma buffers
        used by `p_sample_plms`. Must be called before sampling."""
        if ddim_eta != 0:
            raise ValueError('ddim_eta must be 0 for PLMS')
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta,verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        # sigma schedule for sampling with the full original DDPM step count
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               dynamic_threshold=None,
               **kwargs
               ):
        """Public entry point: build the schedule for `S` steps and run PLMS
        sampling for `batch_size` latents of `shape` (C, H, W).

        Returns `(samples, intermediates)` from `plms_sampling`.
        """
        # sanity-check that the conditioning batch matches the requested batch
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for PLMS sampling is {size}')

        samples, intermediates = self.plms_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def plms_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None):
        """Main PLMS loop: iterate the timesteps in reverse, feeding the last
        few noise predictions (`old_eps`) into the multistep update.

        Returns `(img, intermediates)` where `intermediates` logs latents and
        x_0 predictions every `log_every_t` indices.
        """
        device = self.model.betas.device
        b = shape[0]
        # start from pure noise unless a starting latent x_T is supplied
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            # clip the precomputed DDIM subsequence to the requested length
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
        # rolling window of the most recent noise predictions (newest last)
        old_eps = []

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            # timestep of the following iteration, needed by the 2nd-order startup step
            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

            if mask is not None:
                # inpainting: keep the known region at the matching noise level
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            # keep at most 3 previous predictions (4th-order Adams-Bashforth)
            if len(old_eps) >= 4:
                old_eps.pop(0)
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
                      dynamic_threshold=None):
        """Single PLMS step: estimate the noise with an Adams-Bashforth
        combination of current and past predictions, then apply the DDIM
        update. Returns `(x_prev, pred_x0, e_t)`."""
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            # noise prediction, optionally with classifier-free guidance
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                e_t = self.model.apply_model(x, t, c)
            else:
                # batch the unconditional and conditional passes together
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

            if score_corrector is not None:
                assert self.model.parameterization == "eps"
                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        def get_x_prev_and_pred_x0(e_t, index):
            # standard DDIM update given a noise estimate e_t
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            if dynamic_threshold is not None:
                pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t

```

## /controlnet/ldm/models/diffusion/sampling_util.py

```py path="/controlnet/ldm/models/diffusion/sampling_util.py" 
import torch
import numpy as np


def append_dims(x, target_dims):
    """Append trailing singleton dimensions to `x` until it has `target_dims` dims.
    Raises ValueError when `x` already has more dimensions than requested.
    From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    index = (...,) + (None,) * missing
    return x[index]


def norm_thresholding(x0, value):
    """Dynamic thresholding: scale each sample of `x0` down so that its
    per-sample RMS norm does not exceed `value`; samples already below the
    threshold are left untouched."""
    rms = x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value)
    # broadcast the per-sample scale back over the trailing dimensions
    scale = value / rms[(...,) + (None,) * (x0.ndim - 1)]
    return x0 * scale


def spatial_norm_thresholding(x0, value):
    """Per-pixel dynamic thresholding for (b, c, h, w) tensors: scale each
    spatial location down so its channel-wise RMS does not exceed `value`."""
    rms = x0.pow(2).mean(1, keepdim=True).sqrt()
    scale = value / rms.clamp(min=value)
    return x0 * scale
```

## /controlnet/ldm/modules/attention.py

```py path="/controlnet/ldm/modules/attention.py" 
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any

from controlnet.ldm.modules.diffusionmodules.util import checkpoint


# xformers provides optional memory-efficient attention kernels.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
# NOTE(review): bare except also swallows KeyboardInterrupt/SystemExit;
# `except ImportError` would be more precise.
except:
    XFORMERS_IS_AVAILBLE = False

# CrossAttn precision handling
# ATTN_PRECISION=fp16 disables the fp32 upcast in CrossAttention.forward.
import os
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")

def exists(val):
    """Return whether `val` carries a value, i.e. it is not None."""
    return val is not None


def uniq(arr):
    """Order-preserving de-duplication: a dict keys view of the unique elements."""
    return dict.fromkeys(arr).keys()


def default(val, d):
    """Return `val` if it is not None; otherwise resolve the fallback `d`,
    calling it first when it is a function."""
    if val is not None:
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    """Most negative finite value representable in `t`'s floating dtype."""
    finfo = torch.finfo(t.dtype)
    return -finfo.max


def init_(tensor):
    """In-place uniform init in [-1/sqrt(fan), 1/sqrt(fan)], where fan is the
    size of the last dimension; returns the same tensor."""
    fan = tensor.shape[-1]
    bound = 1 / math.sqrt(fan)
    tensor.uniform_(-bound, bound)
    return tensor


# feedforward
class GEGLU(nn.Module):
    """Gated-GELU projection: maps to 2*dim_out features and multiplies one
    half by GELU of the other (https://arxiv.org/abs/2002.05202)."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = torch.chunk(self.proj(x), 2, dim=-1)
        return value * F.gelu(gate)


class FeedForward(nn.Module):
    """Transformer feed-forward block: expand width by `mult`, apply GELU (or
    the gated GEGLU variant), dropout, then project back to `dim_out`."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        out_features = dim if dim_out is None else dim_out
        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, out_features),
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """Reset every parameter of `module` to zero (in place) and return it."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module


def Normalize(in_channels):
    """Build the 32-group GroupNorm used throughout these attention blocks."""
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class SpatialSelfAttention(nn.Module):
    """Self-attention over the spatial positions of a (b, c, h, w) feature
    map, using 1x1 convolutions as projections and a residual connection
    (VQGAN/LDM style)."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        # 1x1 convolutions act as per-pixel linear projections for q, k, v
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        # project the group-normalized input to queries, keys and values
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b,c,h,w = q.shape
        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        # w_[b, i, j] = <q_i, k_j>: similarity of query pixel i with key pixel j
        w_ = torch.einsum('bij,bjk->bik', q, k)

        # scale by 1/sqrt(c), then softmax over the key positions
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, 'b c h w -> b c (h w)')
        w_ = rearrange(w_, 'b i j -> b j i')
        h_ = torch.einsum('bij,bjk->bik', v, w_)
        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
        h_ = self.proj_out(h_)

        # residual connection
        return x+h_


class CrossAttention(nn.Module):
    """Multi-head (cross-)attention over token sequences.

    Queries come from `x`; keys/values come from `context`, or from `x`
    itself when no context is given (self-attention). Optionally upcasts the
    similarity computation to fp32 (controlled by the ATTN_PRECISION env var)
    to avoid overflow under autocast.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # with no context_dim, keys/values are projected from the query space
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        # self-attention when no context is provided
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        # fold the head dimension into the batch dimension
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        # force cast to fp32 to avoid overflowing
        # NOTE(review): the autocast context is hard-coded to device_type='cuda'.
        if _ATTN_PRECISION =="fp32":
            with torch.autocast(enabled=False, device_type = 'cuda'):
                q, k = q.float(), k.float()
                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        else:
            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        del q, k

        if exists(mask):
            # mask is True where attending is allowed; masked-out positions get
            # the most negative representable value before the softmax
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', sim, v)
        # unfold heads back out of the batch dimension
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)


class MemoryEfficientCrossAttention(nn.Module):
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    """Multi-head (cross-)attention backed by xformers'
    `memory_efficient_attention` kernel. Same interface as `CrossAttention`,
    but masking is not supported."""

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
              f"{heads} heads.")
        inner_dim = dim_head * heads
        # with no context_dim, keys/values are projected from the query space
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        # optional xformers operator override; None lets xformers pick one
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, mask=None):
        # Fail fast: masking is not implemented for this backend. The original
        # code only raised *after* computing attention, wasting the whole
        # forward pass before the inevitable error.
        if exists(mask):
            raise NotImplementedError

        q = self.to_q(x)
        # self-attention when no context is provided
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        # reshape (b, n, h*d) -> (b*h, n, d) as expected by xformers
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )

        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

        # reshape (b*h, n, d) back to (b, n, h*d)
        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    """Standard pre-norm transformer block: self-attention, cross-attention,
    then a feed-forward layer, each with a residual connection.

    The attention implementation is chosen once at construction time:
    xformers-backed when available, vanilla softmax attention otherwise.
    """

    ATTENTION_MODES = {
        "softmax": CrossAttention,  # vanilla attention
        "softmax-xformers": MemoryEfficientCrossAttention
    }
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                 disable_self_attn=False):
        super().__init__()
        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
        assert attn_mode in self.ATTENTION_MODES
        attn_cls = self.ATTENTION_MODES[attn_mode]
        # with disable_self_attn, attn1 becomes a second cross-attention layer
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                              context_dim=context_dim if self.disable_self_attn else None)  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
                              heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        # when True, activations are recomputed in backward to save memory
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        # wrap _forward in (optional) gradient checkpointing
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image and add the residual.
    NEW: use_linear for more efficiency instead of the 1x1 convs

    :param in_channels: channels of the incoming feature map.
    :param n_heads: number of attention heads (inner_dim = n_heads * d_head).
    :param d_head: dimension per attention head.
    :param depth: number of BasicTransformerBlocks.
    :param context_dim: cross-attention context dim(s); a single value is
        broadcast to all `depth` blocks.
    :param use_linear: use nn.Linear projections on the flattened tokens
        instead of 1x1 convolutions on the feature map.
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False, use_linear=False,
                 use_checkpoint=True):
        super().__init__()
        # Broadcast a scalar context_dim to every block; the original code
        # wrapped it in a one-element list, which raised IndexError for depth > 1.
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(in_channels,
                                     inner_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
                for d in range(depth)]
        )
        # Output projection is zero-initialized so the block starts as identity
        # (the residual connection dominates at init).
        if not use_linear:
            self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                                  in_channels,
                                                  kernel_size=1,
                                                  stride=1,
                                                  padding=0))
        else:
            # BUGFIX: the output projection must map inner_dim back to
            # in_channels; the original swapped the arguments, which only
            # worked because inner_dim == in_channels in every config used.
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            # Broadcast a single context (or None) to every block so that
            # depth > 1 works without the caller pre-wrapping a list.
            context = [context] * len(self.transformer_blocks)
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        # Flatten spatial dims: (b, c, h, w) -> (b, h*w, c).
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in


```

## /controlnet/ldm/modules/diffusionmodules/__init__.py

```py path="/controlnet/ldm/modules/diffusionmodules/__init__.py" 

```

## /controlnet/ldm/modules/diffusionmodules/upscaling.py

```py path="/controlnet/ldm/modules/diffusionmodules/upscaling.py" 
import torch
import torch.nn as nn
import numpy as np
from functools import partial

from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
from ldm.util import default


class AbstractLowScaleModel(nn.Module):
    # for concatenating a downsampled image to the latent representation
    def __init__(self, noise_schedule_config=None):
        # A diffusion schedule is only registered when a config is supplied;
        # subclasses that call q_sample() must provide one.
        super(AbstractLowScaleModel, self).__init__()
        if noise_schedule_config is not None:
            self.register_schedule(**noise_schedule_config)

    def register_schedule(self, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Precompute the DDPM noise schedule and register all derived
        quantities as float32 buffers (so they move with the module's device)."""
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                   cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # alpha_bar_{t-1}, with alpha_bar at t=-1 defined as 1.
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

    def q_sample(self, x_start, t, noise=None):
        # Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def forward(self, x):
        # Base class performs no augmentation and reports no noise level.
        return x, None

    def decode(self, x):
        # Identity decode; subclasses may override.
        return x


class SimpleImageConcat(AbstractLowScaleModel):
    """Concat conditioning without noise augmentation: the reported noise
    level is pinned to zero for every batch element."""

    def __init__(self):
        # No schedule is needed since no noise is ever added.
        super().__init__(noise_schedule_config=None)
        self.max_noise_level = 0

    def forward(self, x):
        # fix to constant noise level
        noise_level = torch.zeros(x.shape[0], device=x.device, dtype=torch.long)
        return x, noise_level


class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
    """Concat conditioning with DDPM-style noise augmentation of the
    conditioning image, returning both the noised image and the level used."""

    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
        # `to_cuda` is accepted for config compatibility but unused here.
        super().__init__(noise_schedule_config=noise_schedule_config)
        self.max_noise_level = max_noise_level

    def forward(self, x, noise_level=None):
        if noise_level is None:
            # Draw one level per batch element; randint's upper bound is
            # exclusive, so levels lie in [0, max_noise_level).
            batch_size = x.shape[0]
            noise_level = torch.randint(0, self.max_noise_level, (batch_size,), device=x.device).long()
        else:
            assert isinstance(noise_level, torch.Tensor)
        return self.q_sample(x, noise_level), noise_level




```

## /controlnet/ldm/modules/diffusionmodules/util.py

```py path="/controlnet/ldm/modules/diffusionmodules/util.py" 
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!


import os
import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat

from controlnet.ldm.util import instantiate_from_config


def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """Build a beta (per-step noise variance) schedule of length `n_timestep`
    as a float64 numpy array. Supported schedules: "linear", "cosine",
    "sqrt_linear", "sqrt". Raises ValueError for unknown names."""
    if schedule == "linear":
        # The DDPM "linear" schedule: linear in sqrt(beta) space, then squared.
        sqrt_betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64)
        betas = sqrt_betas ** 2
    elif schedule == "cosine":
        # Cosine schedule of Nichol & Dhariwal: betas derived from the ratio of
        # consecutive cumulative alphas, clipped to avoid singularities at t=T.
        timesteps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        alphas = torch.cos(timesteps / (1 + cosine_s) * np.pi / 2).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)
    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()


def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """Select the subset of DDPM timesteps used by the DDIM sampler,
    either uniformly or with quadratic spacing, shifted by +1 so the final
    alpha values are the ones that scale to data during sampling."""
    if ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.arange(0, num_ddpm_timesteps, stride)
    elif ddim_discr_method == 'quad':
        spaced = np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)
        ddim_timesteps = (spaced ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """Compute (sigmas, alphas, alphas_prev) for the DDIM sampler from the
    cumulative alphas at the selected timesteps."""
    alphas = alphacums[ddim_timesteps]
    # Shifted-by-one schedule: the "previous" alpha for the first step is alphacums[0].
    earlier_steps = ddim_timesteps[:-1]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[earlier_steps].tolist())

    # according the the formula provided in https://arxiv.org/abs/2010.02502
    variance_ratio = (1 - alphas_prev) / (1 - alphas)
    sigmas = eta * np.sqrt(variance_ratio * (1 - alphas / alphas_prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    n = num_diffusion_timesteps
    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta.
    return np.array([
        min(1 - alpha_bar((i + 1) / n) / alpha_bar(i / n), max_beta)
        for i in range(n)
    ])


def extract_into_tensor(a, t, x_shape):
    """Gather per-timestep scalars from `a` at indices `t` and reshape them to
    (batch, 1, ..., 1) so they broadcast against a tensor of shape `x_shape`."""
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    trailing_ones = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing_ones)


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if not flag:
        # Checkpointing disabled: plain call.
        return func(*inputs)
    args = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *args)


class CheckpointFunction(torch.autograd.Function):
    """Gradient checkpointing: the forward pass runs under no_grad, and the
    backward pass re-runs the function with gradients enabled to recompute
    the activations that were not saved."""
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        # The first `length` args are the function's inputs; the remainder are
        # parameters it closes over (needed so grads reach them too).
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # Capture the autocast state so the recomputation in backward() uses
        # the same numerics as the original forward pass.
        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                                   "dtype": torch.get_autocast_gpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Re-attach the saved inputs to the autograd graph before recomputing.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad(), \
                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # The two leading Nones are the gradients for (run_function, length);
        # the rest line up with *args in forward().
        return (None, None) + input_grads


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, simply tile the raw timestep across `dim`.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)
    half = dim // 2
    # Geometric frequency ladder from 1 down to 1/max_period.
    exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(-math.log(max_period) * exponents).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd dims: pad a zero column so the output is exactly `dim` wide.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def zero_module(module):
    """
    Zero out the parameters of a module, in place, and return it.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module by `scale`, in place, and return it.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    reduce_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=reduce_dims)


def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    # 32-group GroupNorm, computed in float32 (see GroupNorm32).
    return GroupNorm32(32, channels)


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    """Sigmoid-weighted linear unit, for PyTorch versions predating nn.SiLU."""

    def forward(self, x):
        return torch.sigmoid(x) * x


class GroupNorm32(nn.GroupNorm):
    """GroupNorm computed in float32 for numerical stability, with the result
    cast back to the input dtype."""

    def forward(self, x):
        normalized = super().forward(x.float())
        return normalized.type(x.dtype)

class LayerNorm32(nn.LayerNorm):
    """LayerNorm computed in float32 for numerical stability, with the result
    cast back to the input dtype."""

    def forward(self, x):
        normalized = super().forward(x.float())
        return normalized.type(x.dtype)

def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    conv_classes = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    try:
        conv_cls = conv_classes[dims]
    except KeyError:
        raise ValueError(f"unsupported dimensions: {dims}")
    return conv_cls(*args, **kwargs)


def linear(*args, **kwargs):
    """
    Create a linear module.
    Thin wrapper around nn.Linear, kept for symmetry with conv_nd/avg_pool_nd.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    pool_classes = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}
    try:
        pool_cls = pool_classes[dims]
    except KeyError:
        raise ValueError(f"unsupported dimensions: {dims}")
    return pool_cls(*args, **kwargs)


class HybridConditioner(nn.Module):
    """Applies two separately-configured conditioners: one for concat
    conditioning and one for cross-attention conditioning."""

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        # Each stream is wrapped in a list to match the conditioning-dict convention.
        concat_out = self.concat_conditioner(c_concat)
        crossattn_out = self.crossattn_conditioner(c_crossattn)
        return {'c_concat': [concat_out], 'c_crossattn': [crossattn_out]}


def noise_like(shape, device, repeat=False):
    """Sample standard Gaussian noise of `shape` on `device`. When `repeat`
    is True, a single noise sample is tiled across the batch dimension."""
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        return single.repeat(shape[0], *((1,) * (len(shape) - 1)))
    return torch.randn(shape, device=device)
```

## /controlnet/ldm/modules/distributions/__init__.py

```py path="/controlnet/ldm/modules/distributions/__init__.py" 

```

## /controlnet/ldm/modules/distributions/distributions.py

```py path="/controlnet/ldm/modules/distributions/distributions.py" 
import torch
import numpy as np


class AbstractDistribution:
    """Minimal distribution interface: subclasses must provide sample() and mode()."""
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    """Point-mass distribution: sampling and the mode both return the fixed value."""
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian parameterized by a tensor that stacks [mean, logvar]
    along dim 1 (as produced by a VAE encoder head)."""

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        mean, logvar = torch.chunk(parameters, 2, dim=1)
        self.mean = mean
        # Clamp the log-variance for numerical stability before exponentiating.
        self.logvar = torch.clamp(logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Degenerate (point-mass) case: zero variance.
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        # Reparameterized draw: mean + std * eps.
        eps = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * eps

    def kl(self, other=None):
        """KL divergence to `other` (or to the standard normal when None),
        summed over the non-batch dims [1, 2, 3]."""
        if self.deterministic:
            return torch.Tensor([0.])
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        """Negative log-likelihood of `sample` under this Gaussian, summed over `dims`."""
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    # Find any tensor argument to use as the dtype/device anchor.
    tensor = next(
        (obj for obj in (mean1, logvar1, mean2, logvar2) if isinstance(obj, torch.Tensor)),
        None,
    )
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = (
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    )

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )

```

## /controlnet/ldm/modules/ema.py

```py path="/controlnet/ldm/modules/ema.py" 
import torch
from torch import nn


class LitEma(nn.Module):
    """Exponential moving average (EMA) of a model's trainable parameters.

    Shadow copies are stored as buffers (named with '.' stripped, since '.'
    is not allowed in buffer names). Call the instance with the live model to
    update the shadow weights; use copy_to/store/restore to swap them in and
    out for evaluation.
    """
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        # Maps model parameter names to their sanitized shadow-buffer names.
        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        # num_updates >= 0 enables decay warm-up; -1 disables it.
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
        else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def reset_num_updates(self):
        # Re-register the buffer to reset the warm-up counter to zero.
        del self.num_updates
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))

    def forward(self, model):
        """Update the shadow parameters towards the model's current parameters."""
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            # Warm-up: use a smaller effective decay early in training.
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    # shadow <- shadow - (1 - decay) * (shadow - param)
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        """Copy the shadow (EMA) parameters into the given model, in place."""
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
```

## /controlnet/ldm/modules/encoders/__init__.py

```py path="/controlnet/ldm/modules/encoders/__init__.py" 

```

## /controlnet/ldm/modules/encoders/modules.py

```py path="/controlnet/ldm/modules/encoders/modules.py" 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel

import open_clip
from controlnet.ldm.util import default, count_params

OPENAI_CLIP_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073])
OPENAI_CLIP_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711])


class AbstractEncoder(nn.Module):
    """Base class for conditioning encoders; subclasses implement encode()."""
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class IdentityEncoder(AbstractEncoder):
    """Pass-through encoder: returns its input unchanged."""

    def encode(self, x):
        return x


class ClassEmbedder(nn.Module):
    """Embeds integer class labels for cross-attention conditioning, with
    optional dropout to the extra "unconditional guidance" class."""

    def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.ucg_rate = ucg_rate

    def forward(self, batch, key=None, disable_dropout=False):
        key = self.key if key is None else key
        # this is for use in crossattn: add a token axis.
        c = batch[key][:, None]
        if self.ucg_rate > 0. and not disable_dropout:
            # With probability ucg_rate, replace the label with the last
            # ("unconditional") class index.
            mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
            c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
            c = c.long()
        return self.embedding(c)

    def get_unconditional_conditioning(self, bs, device="cuda"):
        # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
        uc_class = self.n_classes - 1
        uc = torch.ones((bs,), device=device) * uc_class
        return {self.key: uc}


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore. The `mode` argument is accepted (to match
    nn.Module.train's signature) but deliberately ignored."""
    return self


class FrozenT5Embedder(AbstractEncoder):
    """Uses the T5 transformer encoder for text"""

    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        # Eval mode and no gradients: the encoder stays fixed during training.
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        encoded = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                 return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        token_ids = encoded["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=token_ids)
        return outputs.last_hidden_state

    def encode(self, text):
        return self(text)


class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    LAYERS = [
        "last",
        "pooled",
        "hidden"
    ]

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
                 freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            # A hidden-state index is required and must lie within the model's
            # 12 transformer layers.
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        # Eval mode and no gradients: the text encoder stays fixed.
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        encoded = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                 return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        token_ids = encoded["input_ids"].to(self.device)
        # Hidden states are only materialized when the "hidden" layer is requested.
        outputs = self.transformer(input_ids=token_ids, output_hidden_states=self.layer == "hidden")
        if self.layer == "last":
            return outputs.last_hidden_state
        if self.layer == "pooled":
            return outputs.pooler_output[:, None, :]
        return outputs.hidden_states[self.layer_idx]

    def encode(self, text):
        return self(text)


class FrozenOpenCLIPEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for text
    (or the visual tower for image embeddings, depending on `embedding_type`).
    """
    LAYERS = [
        #"pooled",
        "last",
        "penultimate"
    ]
    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
                 freeze=True, layer="last", embedding_type="text"):
        super().__init__()
        assert layer in self.LAYERS
        # Build on CPU; the caller is responsible for moving the module to `device`.
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
        assert embedding_type in ["text", "visual"]
        self.embedding_type = embedding_type
        # Drop the unused tower to save memory.
        if embedding_type == "text":
            del model.visual
        elif embedding_type == "visual":
            del model.token_embedding
            del model.transformer
        self.model = model

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        # layer_idx counts transformer blocks skipped from the end
        # (0 = run all blocks, 1 = stop before the final block).
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        # Eval mode and no gradients: the encoder stays fixed during training.
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def img_forward(self, img):
        # Map images from [-1, 1] into CLIP's normalized input space, then
        # resize to the 224x224 resolution the visual tower expects.
        img = ((img + 1.) / 2. - OPENAI_CLIP_MEAN[None, :, None, None].to(img.device)) / OPENAI_CLIP_STD[None, :, None, None].to(img.device)
        img = F.interpolate(img, (224, 224))
        return self.model.visual(img)[:, None, :]

    def text_forward(self, text):
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
        # Run the resblocks, stopping `layer_idx` blocks before the end.
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, x):
        # Dispatch on the configured embedding type.
        if self.embedding_type == "text":
            return self.text_forward(x)
        elif self.embedding_type == "visual":
            return self.img_forward(x)


class FrozenCLIPT5Encoder(AbstractEncoder):
    """Combined conditioning from both a CLIP and a T5 text encoder,
    returned as a two-element list [clip_embedding, t5_embedding]."""

    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
                 clip_max_length=77, t5_max_length=77):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
        print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
              f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")

    def encode(self, text):
        return self(text)

    def forward(self, text):
        # Run both encoders on the same text.
        clip_embedding = self.clip_encoder.encode(text)
        t5_embedding = self.t5_encoder.encode(text)
        return [clip_embedding, t5_embedding]



```

## /controlnet/share.py

```py path="/controlnet/share.py" 
import config
from controlnet.cldm.hack import disable_verbosity, enable_sliced_attention


# Importing this module has side effects: it silences verbose library logging
# and, when config.save_memory is set, switches attention to the sliced
# (lower-memory) implementation.
disable_verbosity()

if config.save_memory:
    enable_sliced_attention()

```

## /data/assets/datasets/gen_data.npz

Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/data/assets/datasets/gen_data.npz

## /data/assets/flame/blink_blendshape.npy

Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/data/assets/flame/blink_blendshape.npy

## /data/assets/flame/flowface_vertex_mask.npy

Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/data/assets/flame/flowface_vertex_mask.npy

## /data/assets/flame/flowface_vertex_weights.npy

Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/data/assets/flame/flowface_vertex_weights.npy

## /data/assets/flame/jaw_regressor.npy

Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/data/assets/flame/jaw_regressor.npy

## /examples/input/animation/example_video.mp4

Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/examples/input/animation/example_video.mp4

## /examples/input/animation/sequence_00/fit.npz

Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/examples/input/animation/sequence_00/fit.npz

## /examples/input/animation/sequence_00/orbit.npz

Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/examples/input/animation/sequence_00/orbit.npz

## /examples/input/animation/sequence_01/fit.npz

Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/examples/input/animation/sequence_01/fit.npz

## /examples/input/animation/sequence_01/orbit.npz

Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/examples/input/animation/sequence_01/orbit.npz

## /examples/input/felix/alignment.npz

Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/examples/input/felix/alignment.npz

## /examples/input/felix/bg/cam0/0000.png

Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/examples/input/felix/bg/cam0/0000.png

## /examples/input/felix/bg/cam0/0001.png

Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/examples/input/felix/bg/cam0/0001.png

## /examples/input/felix/bg/cam0/0002.png

Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/examples/input/felix/bg/cam0/0002.png

## /examples/input/lincoln/reference_images.json

```json path="/examples/input/lincoln/reference_images.json" 
[
    ["cam0", 0]
]
```


The content has been capped at 50,000 tokens. Consider applying other filters to refine the result: the better and more specific the context, the better the LLM can follow instructions. If the context seems verbose, refine the filter using uithub. Thank you for using https://uithub.com — perfect LLM context for any GitHub repo.
Copied!