```
├── .gitignore (900 tokens)
├── LICENSE (omitted)
├── README.md (2k tokens)
├── assets/
├── banner.gif
├── cap4d/
├── datasets/
├── utils.py (1800 tokens)
├── flame/
├── flame.py (1400 tokens)
├── mouth.py (700 tokens)
├── inference/
├── data/
├── generation_data.py (700 tokens)
├── inference_data.py (900 tokens)
├── reference_data.py (400 tokens)
├── generate_images.py (1300 tokens)
├── utils.py (1000 tokens)
├── mmdm/
├── conditioning/
├── cap4dcond.py (1300 tokens)
├── mesh2img.py (2.4k tokens)
├── mmdm.py (3.2k tokens)
├── modules/
├── module.py (400 tokens)
├── net/
├── attention.py (2.5k tokens)
├── mmdm_unet.py (800 tokens)
├── sampler.py (2.3k tokens)
├── utils.py (300 tokens)
├── configs/
├── avatar/
├── debug.yaml (300 tokens)
├── default.yaml (300 tokens)
├── high_quality.yaml (300 tokens)
├── low_quality.yaml (300 tokens)
├── medium_quality.yaml (300 tokens)
├── generation/
├── debug.yaml
├── default.yaml
├── high_quality.yaml
├── low_quality.yaml
├── medium_quality.yaml
├── mmdm/
├── cap4d_mmdm_final.yaml (900 tokens)
├── controlnet/
├── cldm/
├── ddim_hacked.py (3.3k tokens)
├── hack.py (700 tokens)
├── logger.py (900 tokens)
├── model.py (200 tokens)
├── ldm/
├── data/
├── __init__.py
├── util.py (100 tokens)
├── models/
├── autoencoder.py (1700 tokens)
├── diffusion/
├── __init__.py
├── ddim.py (3.6k tokens)
├── ddpm.py (17k tokens)
├── dpm_solver/
├── __init__.py
├── dpm_solver.py (13.2k tokens)
├── sampler.py (600 tokens)
├── plms.py (2.6k tokens)
├── sampling_util.py (200 tokens)
├── modules/
├── attention.py (2.4k tokens)
├── diffusionmodules/
├── __init__.py
├── model.py (6.9k tokens)
├── openaimodel.py (6.7k tokens)
├── upscaling.py (700 tokens)
├── util.py (2000 tokens)
├── distributions/
├── __init__.py
├── distributions.py (600 tokens)
├── ema.py (600 tokens)
├── encoders/
├── __init__.py
├── modules.py (1700 tokens)
├── util.py (1300 tokens)
├── share.py
├── data/
├── assets/
├── datasets/
├── gen_data.npz
├── flame/
├── blink_blendshape.npy
├── cap4d_avatar_template.obj (206.3k tokens)
├── cap4d_flame_template.obj (117.3k tokens)
├── deformable_verts.txt (4.7k tokens)
├── flowface_vertex_mask.npy
├── flowface_vertex_weights.npy
├── head_template_mesh.obj (112.7k tokens)
├── head_vertices.txt (4.6k tokens)
├── jaw_regressor.npy
├── weights/
├── mmdm/
├── config_dump.yaml (900 tokens)
├── examples/
├── input/
├── animation/
├── example_video.mp4
├── sequence_00/
├── fit.npz
├── orbit.npz
├── sequence_01/
├── fit.npz
├── orbit.npz
├── felix/
├── alignment.npz
├── bg/
├── cam0/
├── 0000.png
├── 0001.png
├── 0002.png
├── 0003.png
├── fit.npz
├── images/
├── cam0/
├── 00000001.png
├── 00000002.png
├── 00000003.png
├── 00000004.png
├── reference_images.json
├── visualization/
├── vis_cam0.mp4
├── lincoln/
├── alignment.npz
├── fit.npz
├── images/
├── cam0/
├── abraham_lincoln.png
├── reference_images.json
├── visualization/
├── vis_cam0.mp4
├── tesla/
├── alignment.npz
├── bg/
├── cam0/
├── 0000.png
├── fit.npz
├── images/
├── cam0/
├── tesla.png
├── reference_images.json
├── visualization/
├── vis_cam0.mp4
├── flowface/
├── flame/
├── flame.py (2.7k tokens)
├── io.py (300 tokens)
├── utils.py (700 tokens)
├── gaussianavatars/
├── LICENSE_GS.md (900 tokens)
├── animate.py (1500 tokens)
├── gaussian_renderer/
├── gsplat_renderer.py (600 tokens)
├── lpipsPyTorch/
├── __init__.py (100 tokens)
├── modules/
├── lpips.py (200 tokens)
├── networks.py (500 tokens)
├── utils.py (200 tokens)
├── scene/
├── cameras.py (300 tokens)
├── cap4d_gaussian_model.py (3.4k tokens)
├── dataset_readers.py (1700 tokens)
├── gaussian_model.py (5.6k tokens)
├── net/
├── positional_encoding.py (100 tokens)
├── unet.py (2.2k tokens)
├── scene.py (1100 tokens)
├── train.py (3.5k tokens)
├── utils/
├── camera_utils.py (300 tokens)
├── export_utils.py (1600 tokens)
├── general_utils.py (800 tokens)
├── graphics_utils.py (1000 tokens)
├── image_utils.py (200 tokens)
├── loss_utils.py (400 tokens)
├── mesh_utils.py (100 tokens)
├── sh_utils.py (900 tokens)
├── system_utils.py (200 tokens)
├── requirements.txt
├── scripts/
├── download_flame.sh (200 tokens)
├── download_mmdm_weights.sh
├── fixes/
├── fix_flame_pickle.py (400 tokens)
├── pipnet_cpu_nms.pyx (1100 tokens)
├── generate_avatar.sh (800 tokens)
├── generate_felix.sh (200 tokens)
├── generate_lincoln.sh (200 tokens)
├── generate_tesla.sh (200 tokens)
├── install_pixel3dmm.sh (600 tokens)
├── pixel3dmm/
├── convert_to_flowface.py (4.1k tokens)
├── l2cs_eye_tracker.py (1300 tokens)
├── robust_video_matting/
├── model/
├── decoder.py (1400 tokens)
├── deep_guided_filter.py (500 tokens)
├── fast_guided_filter.py (600 tokens)
├── lraspp.py (200 tokens)
├── mobilenetv3.py (600 tokens)
├── model.py (600 tokens)
├── test_pipeline.sh (200 tokens)
├── track_video_pixel3dmm.sh (500 tokens)
```
## /.gitignore
```gitignore path="/.gitignore"
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Cursor
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore
```
## /README.md
# 🧢 CAP4D
Official repository for the paper
**CAP4D: Creating Animatable 4D Portrait Avatars with Morphable Multi-View Diffusion Models**, ***CVPR 2025 (Oral)***.
<a href="https://felixtaubner.github.io/" target="_blank">Felix Taubner</a><sup>1,2</sup>, <a href="https://ruihangzhang97.github.io/" target="_blank">Ruihang Zhang</a><sup>1</sup>, <a href="https://mathieutuli.com/" target="_blank">Mathieu Tuli</a><sup>3</sup>, <a href="https://davidlindell.com/" target="_blank">David B. Lindell</a><sup>1,2</sup>
<sup>1</sup>University of Toronto, <sup>2</sup>Vector Institute, <sup>3</sup>LG Electronics
<a href='https://arxiv.org/abs/2412.12093'><img src='https://img.shields.io/badge/arXiv-2301.02379-red'></a> <a href='https://felixtaubner.github.io/cap4d/'><img src='https://img.shields.io/badge/project page-CAP4D-Green'></a> <a href='#citation'><img src='https://img.shields.io/badge/cite-blue'></a>

TL;DR: CAP4D turns any number of reference images into an animatable avatar.
## ⚡️ Quick start guide
### 🛠️ 1. Create conda environment and install requirements
```bash
# 1. Clone repo
git clone https://github.com/felixtaubner/cap4d/
cd cap4d
# 2. Create conda environment for CAP4D:
conda create --name cap4d_env python=3.10
conda activate cap4d_env
# 3. Install requirements
pip install -r requirements.txt
# 4. Set python path
export PYTHONPATH=$(realpath "./"):$PYTHONPATH
```
Follow the [instructions](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md) and install Pytorch3D. Make sure to install with CUDA support. We recommend to install from source:
```bash
export FORCE_CUDA=1
pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
```
### 📦 2. Download FLAME and MMDM weights
Setup your FLAME account at the [FLAME website](https://flame.is.tue.mpg.de/index.html) and set the username
and password environment variables:
```bash
export FLAME_USERNAME=your_flame_user_name
export FLAME_PWD=your_flame_password
```
Download FLAME and MMDM weights using the provided scripts:
```bash
# 1. Download FLAME blendshapes
# set your flame username and password
bash scripts/download_flame.sh
# 2. Download CAP4D MMDM weights
bash scripts/download_mmdm_weights.sh
```
If the FLAME download script did not work, download FLAME2023 from the [FLAME website](https://flame.is.tue.mpg.de/index.html) and place `flame2023_no_jaw.pkl` in `data/assets/flame/`.
Then, fix the flame pkl file to be compatible with newer numpy versions:
```bash
python scripts/fixes/fix_flame_pickle.py --pickle_path data/assets/flame/flame2023_no_jaw.pkl
```
### ✅ 3. Check installation with a test run
Run the pipeline in debug settings to test the installation.
```bash
bash scripts/test_pipeline.sh
```
Check if a video is exported to `examples/debug_output/tesla/sequence_00/renders.mp4`.
If it appears to show a blurry cartoon Nikola Tesla, you're all set!
### 🎬 4. Inference
Run the provided scripts to generate avatars and animate them with a single script:
```bash
bash scripts/generate_felix.sh
bash scripts/generate_lincoln.sh
bash scripts/generate_tesla.sh
```
The output directories contain exported animations which you can view in real-time.
Open the [real-time viewer](https://felixtaubner.github.io/cap4d/viewer/) in your browser (powered by [Brush](https://github.com/ArthurBrussee/brush/)). Click `Load file` and
upload the exported animation found in `examples/output/{SUBJECT}/animation_{ID}/exported_animation.ply`.
## 🔧 Custom inference
See below for how to run your custom inference on your own reference images/videos and driving videos.
### ⚙️ 1. Run FLAME 3D face tracking
#### 1.1 FlowFace tracking
Coming soon! For now, only generations using the provided identities with precomputed [FlowFace](https://felixtaubner.github.io/flowface/) annotations are supported.
#### 1.2 Pixel3DMM tracking
Install [Pixel3DMM](https://github.com/SimonGiebenhain/pixel3dmm) using the provided script. Notice that this is prone to errors due to package version mismatches. Please report any errors as an issue!
```bash
export FLAME_USERNAME=your_flame_user_name
export FLAME_PWD=your_flame_password
export PIXEL3DMM_PATH=$(realpath "../PATH/TO/pixel3dmm") # set this to where you would like to clone the Pixel3DMM repo (absolute path)
export CAP4D_PATH=$(realpath "./") # set this to the cap4d directory (absolute path)
bash scripts/install_pixel3dmm.sh
```
Run tracking and conversion on reference images/videos using the provided script. Note: If the input is a directory of frames, it is assumed to be a discontinuous set of (monocular!) images. If the input is a file, it is assumed to be a continuous monocular video.
```bash
export PIXEL3DMM_PATH=$(realpath "../PATH/TO/pixel3dmm")
export CAP4D_PATH=$(realpath "./")
mkdir examples/output/custom/
# For more information on arguments
bash scripts/track_video_pixel3dmm.sh --help
# Process a directory of (reference) images
bash scripts/track_video_pixel3dmm.sh examples/input/felix/images/cam0/ examples/output/custom/reference_tracking/
# Optional: process a driving (or reference) video
bash scripts/track_video_pixel3dmm.sh examples/input/animation/example_video.mp4 examples/output/custom/driving_video_tracking/
```
Notice that results will be slightly worse than with FlowFace tracking, since the MMDM is trained with FlowFace.
### 🖼️ 2. Generate images using MMDM
```bash
# Generate images with single reference image
python cap4d/inference/generate_images.py --config_path configs/generation/default.yaml --reference_data_path examples/output/custom/reference_tracking/ --output_path examples/output/custom/mmdm/
```
Note: the generation script will use all visible CUDA devices. The more available devices, the faster it runs! This will take hours, and requires lots of RAM (ideally > 64 GB) to run smoothly.
### 👤 3. Fit Gaussian avatar
```bash
python gaussianavatars/train.py --config_path configs/avatar/default.yaml --source_paths examples/output/custom/mmdm/reference_images/ examples/output/custom/mmdm/generated_images/ --model_path examples/output/custom/avatar/ --interval 5000
```
### 🕺 4. Animate your avatar
Once the avatar is generated, it can be animated with the driving video computed in step 1 or the provided animations.
```bash
# Animate the avatar with provided animation files
python gaussianavatars/animate.py --model_path examples/output/custom/avatar/ --target_animation_path examples/input/animation/sequence_00/fit.npz --target_cam_trajectory_path examples/input/animation/sequence_00/orbit.npz --output_path examples/output/custom/animation_00/ --export_ply 1 --compress_ply 0
# Animate the avatar with driving video (computed using Pixel3DMM)
python gaussianavatars/animate.py --model_path examples/output/custom/avatar/ --target_animation_path examples/output/custom/driving_video_tracking/fit.npz --target_cam_trajectory_path examples/output/custom/driving_video_tracking/cam_static.npz --output_path examples/output/custom/animation_example/ --export_ply 1 --compress_ply 0
```
The `--target_animation_path` argument contains FLAME expressions and pose, while the (optional) `--target_cam_trajectory_path` argument contains the relative camera trajectory.
### ⚡️ 5. Full inference
We provide a convenient script to run full inference using your reference images and optionally a driving video.
```bash
export PIXEL3DMM_PATH=$(realpath "../PATH/TO/pixel3dmm")
export CAP4D_PATH=$(realpath "./")
# Generate avatar with custom input images/videos.
bash scripts/generate_avatar.sh --help
bash scripts/generate_avatar.sh {INPUT_VIDEO_PATH} {OUTPUT_PATH} [{QUALITY}] [{DRIVING_VIDEO_PATH}]
# Example generation with default quality generation with input images and driving video.
bash scripts/generate_avatar.sh examples/input/felix/images/cam0/ examples/output/felix_custom/ default examples/input/animation/example_video.mp4
```
### ✨ 6. View avatar in live viewer
Open the [real-time viewer](https://felixtaubner.github.io/cap4d/viewer/) in your browser (powered by [Brush](https://github.com/ArthurBrussee/brush/)). Click `Load file` and
upload the exported animation found in
`examples/output/custom/animation_00/exported_animation.ply` or
`examples/output/custom/animation_example/exported_animation.ply`.
## 📚 Related Resources
The MMDM code is based on [ControlNet](https://github.com/lllyasviel/ControlNet). The 4D Gaussian avatar code is based on [GaussianAvatars](https://github.com/ShenhanQian/GaussianAvatars). Special thanks to the authors for making their code public!
Related work:
- [CAT3D](https://cat3d.github.io/): Create Anything in 3D with Multi-View Diffusion Models
- [GaussianAvatars](https://shenhanqian.github.io/gaussian-avatars): Photorealistic Head Avatars with Rigged 3D Gaussians
- [FlowFace](https://felixtaubner.github.io/flowface/): 3D Face Tracking from 2D Video through Iterative Dense UV to Image Flow
- [StableDiffusion](https://github.com/Stability-AI/stablediffusion): High-Resolution Image Synthesis with Latent Diffusion Models
- [Pixel3DMM](https://github.com/SimonGiebenhain/pixel3dmm): Versatile Screen-Space Priors for Single-Image 3D Face Reconstruction
Awesome concurrent work:
- [Pippo](https://yashkant.github.io/pippo/): High-Resolution Multi-View Humans from a Single Image
- [Avat3r](https://tobias-kirschstein.github.io/avat3r/): Large Animatable Gaussian Reconstruction Model for High-fidelity 3D Head Avatars
## 📖 Citation
```tex
@inproceedings{taubner2025cap4d,
author = {Taubner, Felix and Zhang, Ruihang and Tuli, Mathieu and Lindell, David B.},
title = {{CAP4D}: Creating Animatable {4D} Portrait Avatars with Morphable Multi-View Diffusion Models},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2025},
pages = {5318-5330}
}
```
## Acknowledgement
This work was developed in collaboration with and with sponsorship from **LG Electronics**. We gratefully acknowledge their support and contributions throughout the course of this project.
## /assets/banner.gif
Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/assets/banner.gif
## /cap4d/datasets/utils.py
```py path="/cap4d/datasets/utils.py"
import contextlib
from pathlib import Path
from typing import Dict
import numpy as np
import einops
import cv2
from decord import VideoReader
from scipy.spatial.transform import Rotation as R
from cap4d.flame.flame import CAP4DFlameSkinner, compute_flame
CROP_MARGIN = 0.2
@contextlib.contextmanager
def temp_seed(seed):
    """Temporarily seed NumPy's global RNG inside a `with` block.

    The previous RNG state is captured on entry and unconditionally
    restored on exit, even if the body raises.
    """
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        # Restore whatever state the global RNG had before entering.
        np.random.set_state(saved_state)
def crop_image(
    img: np.ndarray,
    crop_box: np.ndarray,
    bg_value=0,
) -> np.ndarray:
    """
    Crop `img` to `crop_box` = (x_min, y_min, x_max, y_max).

    Regions of the crop box that fall outside the image are filled with
    `bg_value` (black padding by default).
    """
    img_h, img_w = img.shape[0], img.shape[1]
    x0, y0, x1, y1 = crop_box[0], crop_box[1], crop_box[2], crop_box[3]
    out_h = y1 - y0
    out_w = x1 - x0
    # How far the crop box sticks out past each image edge.
    pad_left = max(0, -x0)
    pad_right = max(0, x1 - img_w)
    pad_top = max(0, -y0)
    pad_bottom = max(0, y1 - img_h)
    # Background canvas; the in-bounds region is copied over it.
    canvas = np.ones((out_h, out_w, *img.shape[2:]), dtype=img.dtype) * bg_value
    canvas[pad_top : out_h - pad_bottom, pad_left : out_w - pad_right, ...] = img[
        y0 + pad_top : y1 - pad_bottom,
        x0 + pad_left : x1 - pad_right,
        ...,
    ]
    return canvas
def rescale_image(
    img: np.ndarray,
    target_resolution: int,
):
    """Resize `img` to a square `target_resolution` x `target_resolution`.

    Uses INTER_AREA when shrinking (reduces aliasing) and INTER_LINEAR
    otherwise. Assumes the input is square (only shape[0] is checked).
    """
    if target_resolution < img.shape[0]:
        interp = cv2.INTER_AREA
    else:
        interp = cv2.INTER_LINEAR
    return cv2.resize(img, (target_resolution, target_resolution), interpolation=interp)
def apply_bg(
    img: np.ndarray,
    bg_weights: np.ndarray,
    bg_color: np.ndarray = np.array([255, 255, 255]),
):
    """Alpha-composite `img` over a solid background color.

    `bg_weights` is a 0..255 foreground weight map (255 = fully image,
    0 = fully background). Returns the blended image as float.
    """
    alpha = bg_weights / 255.
    background = bg_color[None, None]  # broadcast color over H, W
    return background * (1. - alpha) + img * alpha
def verts_to_pytorch3d(
    verts_2d: np.ndarray,
    crop_box: np.ndarray,
):
    """
    Convert vertex 2D pixel coordinates into pytorch3d NDC screen space.

    Coordinates are normalized to [-1, 1] within the crop box and then
    sign-flipped (pytorch3d's screen axes point opposite to pixel axes).
    Note: `verts_2d` is modified in place and also returned.
    """
    x0, y0 = crop_box[..., 0], crop_box[..., 1]
    x1, y1 = crop_box[..., 2], crop_box[..., 3]
    verts_2d[..., 0] = 1. - (verts_2d[..., 0] - x0) / (x1 - x0) * 2.
    verts_2d[..., 1] = 1. - (verts_2d[..., 1] - y0) / (y1 - y0) * 2.
    return verts_2d
def get_square_bbox(
    bbox: np.ndarray,
    border_margin: float = 0.1,
    mode: str = "max",  # "min" or "max"
):
    """
    Expand a bounding box into a square crop box with a relative margin.

    Parameters
    ----------
    bbox: np.ndarray
        the face bounding box as (x_min, y_min, x_max, y_max)
    border_margin: float
        relative margin added around the box (fraction of the half-size)
    mode: str
        "max" squares to the larger side, "min" to the smaller side

    Returns
    -------
    crop_box: tuple[int, int, int, int]
        the index ranges taken from the original image
        x_min, y_min, x_max, y_max (may extend beyond the image; callers
        such as crop_image pad the out-of-range region)

    Raises
    ------
    ValueError
        if `mode` is neither "max" nor "min"
    """
    bbox = bbox.astype(int)
    bbox_h = bbox[3] - bbox[1]
    bbox_w = bbox[2] - bbox[0]
    b_center = ((bbox[2] + bbox[0]) // 2, (bbox[3] + bbox[1]) // 2)
    if mode == "max":
        dim = int(max(bbox_h, bbox_w) // 2.0 * (1.0 + border_margin))
    elif mode == "min":
        dim = int(min(bbox_h, bbox_w) // 2.0 * (1.0 + border_margin))
    else:
        # Previously an unknown mode fell through and raised NameError on `dim`.
        raise ValueError(f"Unknown square bbox mode: {mode!r} (expected 'max' or 'min')")
    return (
        b_center[0] - dim,
        b_center[1] - dim,
        b_center[0] + dim,
        b_center[1] + dim,
    )
def get_bbox_from_verts(verts_2d, vert_mask):
    """Tight 2D bbox over the masked vertices, expanded to a square crop box
    with the module-wide CROP_MARGIN."""
    pts = verts_2d[vert_mask]
    xs, ys = pts[..., 0], pts[..., 1]
    tight_bbox = np.array([xs.min(), ys.min(), xs.max(), ys.max()])
    return np.array(get_square_bbox(tight_bbox, border_margin=CROP_MARGIN))
def load_flame_verts_and_cam(
    flame_skinner: CAP4DFlameSkinner,
    flame_item: Dict[str, np.ndarray],
):
    """Evaluate FLAME for one tracked item and return its 2D vertices,
    3D offsets, 3x3 intrinsics matrix, and extrinsics.

    Only the first camera / first timestep of the fit is returned.
    """
    flame_out = compute_flame(flame_skinner, flame_item)
    verts_2d = flame_out["verts_2d"][0, 0]
    offsets_3d = flame_out["offsets_3d"][0]
    # Assemble a pinhole intrinsics matrix from the scalar camera params.
    intrinsics = np.eye(3)
    for (row, col), key in (((0, 0), "fx"), ((1, 1), "fy"), ((0, 2), "cx"), ((1, 2), "cy")):
        intrinsics[row, col] = flame_item[key][0, 0]
    extrinsics = flame_item["extr"][0]
    return verts_2d, offsets_3d, intrinsics, extrinsics
def load_camera_rays(
    crop_box,
    intr,
    extr,
    target_resolution,
):
    """Compute per-pixel unit ray directions (world frame) for a crop.

    Intrinsics are first rescaled to the cropped-and-resized image, then a
    ray is shot through every pixel center and rotated into world space.
    Returns an array of shape (3, H, W).
    """
    res = target_resolution
    scale = res / (crop_box[2] - crop_box[0])
    fx = intr[0, 0] * scale
    fy = intr[1, 1] * scale
    cx = (intr[0, 2] - crop_box[0]) * scale
    cy = (intr[1, 2] - crop_box[1]) * scale
    u, v = np.meshgrid(np.arange(res), np.arange(res))  # [H, W] pixel grids
    rays = np.stack(((u - cx) / fx, (v - cy) / fy, np.ones_like(u)), axis=0)
    rays = rays / (np.linalg.norm(rays, axis=0, keepdims=True) + 1e-8)
    h, w = rays.shape[1:]
    # Rotate ray directions from the camera frame back into the world frame.
    flat = rays.reshape(3, h * w)
    flat = np.linalg.inv(extr[:3, :3]) @ flat
    return flat.reshape(3, h, w)  # ray directions
def adjust_intrinsics_crop(fx, fy, cx, cy, bbox, target_resolution):
    """Rescale pinhole intrinsics after cropping to `bbox` and resizing the
    crop to `target_resolution`. Returns (fx, fy, cx, cy)."""
    s = target_resolution / (bbox[2] - bbox[0])
    # Focal lengths scale; principal point shifts by the crop origin first.
    return fx * s, fy * s, (cx - bbox[0]) * s, (cy - bbox[1]) * s
def get_crop_mask(orig_resolution, target_resolution, crop_box):
    """Build a mask marking which pixels of the crop lie inside the original
    image (1) vs. in the padded out-of-bounds region (0).

    NOTE(review): assumes `orig_resolution` is a shape accepted by np.ones,
    i.e. (h, w) — confirm against callers.
    """
    mask = np.ones((orig_resolution))
    mask = crop_image(mask, crop_box, bg_value=0)
    return rescale_image(mask, target_resolution)
class FrameReader:
    """Sequence-like reader over a directory of image frames.

    Frames are ordered lexicographically by filename, mirroring the
    index-based access of decord's VideoReader.
    """

    def __init__(self, video_path):
        # Sort so frame order follows zero-padded frame numbering.
        self.frame_list = sorted(Path(video_path).glob("*.*"))

    def __len__(self):
        return len(self.frame_list)

    def __getitem__(self, index):
        # cv2 loads BGR; reorder channels to RGB.
        return cv2.imread(str(self.frame_list[index]))[..., [2, 1, 0]]
def load_frame(
    video_path: Path,  # path to .mp4 or dir containing frames
    frame_id: np.ndarray,
):
    """Load one frame from a video file or a directory of frame images.

    Out-of-range frame ids are clamped to the last frame (with a warning).
    Always returns a numpy array.
    """
    if video_path.is_dir():
        reader = FrameReader(video_path)
    else:
        reader = VideoReader(str(video_path))
    if frame_id >= len(reader):
        print(f"WARNING: Frame {frame_id} out of bounds for video with length {len(reader)}")
        frame_id = len(reader) - 1
    frame_img = reader[frame_id]
    # decord returns its own NDArray type; convert it to numpy.
    if not isinstance(frame_img, np.ndarray):
        frame_img = frame_img.asnumpy()
    return frame_img
def pivot_camera_intrinsic(extrinsics, target, angles, distance_factor=1.):
    """
    Pivot a camera around a target point using intrinsic (local-frame) rotations.

    Parameters:
    - extrinsics: (4x4) numpy array, world_to_camera transformation matrix.
    - target: (3,) numpy array, point to pivot around.
    - angles: (2,) array-like, rotation angles in degrees applied about the
      camera's local Y axis then X axis.
    - distance_factor: scales the camera-to-target distance.

    Returns:
    - new_extrinsics: (4x4) numpy array, updated world_to_camera transformation.
    """
    cam_to_world = np.linalg.inv(extrinsics)
    rot_c2w = cam_to_world[:3, :3]
    pos_c2w = cam_to_world[:3, 3]
    # Offset from the pivot target to the camera, optionally rescaled.
    offset = (pos_c2w - target) * distance_factor
    delta = R.from_euler('YX', angles, degrees=True).as_matrix()
    # Rotate the orientation in the camera's own (local) frame ...
    pivoted_rot = rot_c2w @ delta
    # ... and apply the same local-frame rotation to the position offset.
    pivoted_offset = rot_c2w @ delta @ np.linalg.inv(rot_c2w) @ offset
    result = np.eye(4)
    result[:3, :3] = pivoted_rot
    result[:3, 3] = target + pivoted_offset
    return np.linalg.inv(result)
def get_head_direction(rot):
    """Convert an axis-angle head rotation to a world-space look direction.

    The head looks along -z of its local frame; y and z are then flipped to
    convert from pytorch3d to OpenCV axis conventions.
    """
    rotation = R.from_rotvec(rot).as_matrix()
    direction = -rotation[:3, 2]  # -z is the head's forward axis
    direction[1:] = -direction[1:]  # p3d -> opencv
    return direction
def compute_yaw_pitch_to_face_direction(extrinsics, world_direction):
    """
    Computes yaw and pitch angles (in degrees) needed to rotate the camera
    so its forward vector aligns with the given world-space direction.

    Parameters:
    - extrinsics: (4x4) camera-to-world matrix
    - world_direction: (3,) vector, desired view direction in world space
      (normalized internally)

    Returns:
    - yaw, pitch: angles in degrees
    """
    # Normalize direction
    world_direction = world_direction / np.linalg.norm(world_direction)
    # Camera's current rotation (camera-to-world)
    R_c2w = extrinsics[:3, :3]
    # Convert world direction into camera's local frame
    dir_local = R_c2w.T @ world_direction
    dx, dy, dz = dir_local
    # No epsilon guard is needed for dz: arctan2 is well-defined at dz == 0.
    # (The previous `np.clip(dz, -1e-6, 1) if dz == 0 else dz` was a no-op,
    # since clipping 0 into [-1e-6, 1] returns 0.)
    dy = np.clip(dy, -1, 1)  # keep arcsin within its valid domain
    # Yaw: rotation around Y (left-right)
    yaw = np.degrees(np.arctan2(dx, dz))
    # Pitch: rotation around X (up-down)
    pitch = np.degrees(np.arcsin(-dy))  # negative because +Y is up
    return yaw, pitch
```
## /cap4d/flame/flame.py
```py path="/cap4d/flame/flame.py"
from typing import Dict, Optional
import einops
import numpy as np
import torch
from flowface.flame.flame import FlameSkinner, FLAME_N_SHAPE, FLAME_N_EXPR
from flowface.flame.utils import batch_rodrigues, OPENCV2PYTORCH3D, transform_vertices, project_vertices
from cap4d.flame.mouth import FlameMouth
FLAME_PKL_PATH = "data/assets/flame/flame2023_no_jaw.pkl"
JAW_REGRESSOR_PATH = "data/assets/flame/jaw_regressor.npy"
BLINK_BLENDSHAPE_PATH = "data/assets/flame/blink_blendshape.npy"
# This module has NO trainable parameters
class CAP4DFlameSkinner(FlameSkinner):
    """FLAME skinner variant used by CAP4D.

    Extends the base FlameSkinner with an optional mouth-interior mesh and an
    optional lower-jaw mesh whose rotation is regressed linearly from the
    expression coefficients. This module has no trainable parameters.
    """

    def __init__(
        self,
        flame_pkl_path: str = FLAME_PKL_PATH,
        n_shape_params: int = FLAME_N_SHAPE,
        n_expr_params: int = FLAME_N_EXPR,
        blink_blendshape_path: str = BLINK_BLENDSHAPE_PATH,
        add_mouth: bool = False,
        add_lower_jaw: bool = False,
        jaw_regressor_path: str = JAW_REGRESSOR_PATH,
    ):
        super().__init__(flame_pkl_path, n_shape_params, n_expr_params, blink_blendshape_path)
        self.add_mouth = add_mouth
        if add_mouth:
            # Extra mesh appended behind the lips (see forward()).
            self.mouth = FlameMouth()
        self.add_lower_jaw = add_lower_jaw
        if add_lower_jaw:
            self.lower_jaw = FlameMouth()
            # Linear map from expression coefficients to a jaw rotation vector.
            jaw_regressor = torch.tensor(np.load(jaw_regressor_path))
            self.register_buffer("jaw_regressor", jaw_regressor)

    def forward(
        self,
        flame_sequence: Dict,
        return_offsets: bool = True,
        return_transforms: bool = False,
    ) -> torch.Tensor:
        """
        Compute 3D vertices given a flame sequence with shape parameters
        flame_sequence (dictionary):
            shape (N_shape)
        and N_t timesteps of expression, pose (rot, tra), eye_rot (optional), jaw_rot (optional):
            expr (N_t, N_exp)
            rot (N_t, 3)
            tra (N_t, 3)
            eye_rot (N_t, 3)
            jaw_rot (N_t, 3)
            neck_rot (N_t, 3)
        output: a list [verts (N_t, V, 3)], plus offsets and per-vertex
        transforms when `return_offsets` / `return_transforms` are set.
        """
        # Shape-dependent vertices shared by all timesteps, then per-timestep
        # expression offsets on top.
        shape_offsets = self._get_shape_offsets(flame_sequence["shape"][None], None)
        shape_verts = self._get_template_vertices(None) + shape_offsets
        expr_offsets = self._get_expr_offsets(flame_sequence["expr"], None)
        verts = shape_verts + expr_offsets  # N_t, V, 3
        # create rotation matrix for joint rotations, we apply base transform separately
        # 5 joints; identity unless the corresponding rotation is provided.
        # Joint indices used below: 0 = neck, 2 = jaw, 3/4 = eyes.
        rotations = torch.eye(3, device=verts.device)[None, None].repeat(verts.shape[0], 5, 1, 1)
        if "neck_rot" in flame_sequence and flame_sequence["neck_rot"] is not None:
            rotations[:, 0, ...] = batch_rodrigues(flame_sequence["neck_rot"])
        if "jaw_rot" in flame_sequence and flame_sequence["jaw_rot"] is not None:
            rotations[:, 2, ...] = batch_rodrigues(flame_sequence["jaw_rot"])
        if "eye_rot" in flame_sequence and flame_sequence["eye_rot"] is not None:
            # Both eyes share the same rotation.
            eye_rot = batch_rodrigues(flame_sequence["eye_rot"])
            rotations[:, 3, ...] = eye_rot
            rotations[:, 4, ...] = eye_rot
        verts, v_transforms = self._apply_joint_rotation(verts, rotations=rotations, vert_mask=None, return_transforms=True)
        # compute offsets (including joint rotations)
        offsets = verts - shape_verts
        if self.add_mouth:
            # Mouth mesh depends only on the identity shape; repeat over time.
            mouth_verts = self.mouth(shape_verts, self.joint_regressor)
            mouth_verts = mouth_verts.repeat(verts.shape[0], 1, 1)
            verts = torch.cat([verts, mouth_verts], dim=1)
            # Mouth vertices carry no expression offsets and no transforms.
            offsets = torch.cat([offsets, torch.zeros_like(mouth_verts)], dim=1)
            v_transforms = torch.cat([v_transforms, torch.zeros(mouth_verts.shape[0], mouth_verts.shape[1], 4, 4, device=v_transforms.device)], dim=1)
        if self.add_lower_jaw:
            # Jaw rotation is regressed linearly from the expression coefficients.
            jaw_rot = einops.einsum(flame_sequence["expr"], self.jaw_regressor, 'b exp, exp r -> b r')
            # Jaw offsets are measured relative to the zero-rotation (neutral) jaw.
            neutral_jaw_verts = self.lower_jaw(shape_verts, self.joint_regressor, batch_rodrigues(jaw_rot * 0.))
            jaw_verts = self.lower_jaw(shape_verts, self.joint_regressor, batch_rodrigues(jaw_rot))
            verts = torch.cat([verts, jaw_verts], dim=1)
            offsets = torch.cat([offsets, jaw_verts - neutral_jaw_verts], dim=1)
            # Per-vertex homogeneous transform for the jaw: pure rotation.
            jaw_transforms = torch.zeros(jaw_verts.shape[0], 4, 4, device=v_transforms.device)
            jaw_transforms[:, :3, :3] = batch_rodrigues(jaw_rot)
            jaw_transforms[..., -1, -1] = 1.
            jaw_transforms = jaw_transforms[:, None].repeat(1, jaw_verts.shape[1], 1, 1)
            v_transforms = torch.cat([v_transforms, jaw_transforms], dim=1)
        # apply base transform separately
        base_rot = batch_rodrigues(flame_sequence["rot"])
        base_tra = flame_sequence["tra"][..., None]
        verts = (base_rot @ verts.permute(0, 2, 1) + base_tra).permute(0, 2, 1)
        output = [verts]
        if return_offsets:
            output.append(offsets)
        if return_transforms:
            # Compose the base rigid transform with each per-vertex transform.
            base_transform = torch.cat([base_rot, base_tra], dim=2)
            base_transform = torch.cat([base_transform, torch.zeros_like(base_transform[:, :1, ...])], dim=1)
            base_transform[..., -1, -1] = 1.
            v_transforms = einops.einsum(base_transform, v_transforms, 'b i j, b N j k -> b N i k')
            output.append(v_transforms)
        return output
def compute_flame(
    flame: CAP4DFlameSkinner,
    fit_3d: Dict[str, np.ndarray],
):
    """Run the FLAME skinner on a fitted sequence and project the vertices.

    Converts the numpy fit into float tensors, evaluates the skinner to obtain
    3D vertices and expression offsets, transforms the vertices into the
    OpenCV camera convention, and projects them into every camera.

    Returns a dict of numpy arrays: verts_3d, verts_3d_cv, verts_2d, offsets_3d.
    """
    flame_sequence = {
        key: torch.tensor(fit_3d[key]).float()
        for key in ("shape", "expr", "rot", "tra", "eye_rot")
    }
    # optional joint rotations default to None when absent from the fit
    for optional_key in ("jaw_rot", "neck_rot"):
        if optional_key in fit_3d:
            flame_sequence[optional_key] = torch.tensor(fit_3d[optional_key]).float()
        else:
            flame_sequence[optional_key] = None

    # compute FLAME vertices
    verts_3d, offsets_3d = flame(
        flame_sequence,
        return_offsets=True,
    )  # [N_t V 3], [N_t 2 3]

    cam_parameters = {
        key: torch.tensor(fit_3d[key]).float()
        for key in ("fx", "fy", "cx", "cy", "extr")
    }

    # transform into OpenCV camera coordinate convention
    verts_3d_cv = transform_vertices(OPENCV2PYTORCH3D[None].to(verts_3d.device), verts_3d)  # [N_t V 3]
    # project vertices to cameras
    verts_2d = project_vertices(verts_3d_cv, cam_parameters)  # [N_c N_t V 3]

    return {
        "verts_3d": verts_3d.cpu().numpy(),
        "verts_3d_cv": verts_3d_cv.cpu().numpy(),
        "verts_2d": verts_2d.cpu().numpy(),
        "offsets_3d": offsets_3d.cpu().numpy(),
    }
```
## /cap4d/flame/mouth.py
```py path="/cap4d/flame/mouth.py"
import torch
import torch.nn as nn
import numpy as np
import einops
def generate_uv_sphere(r=1.0, latitude_steps=30, longitude_steps=30):
    """Create the lower half of a UV sphere.

    Only the latitudes below the equator are kept, sampled on a regular
    latitude/longitude grid of ``latitude_steps // 2`` rings.

    Returns (verts, face_indices): a float tensor [V, 3] and a long tensor [F, 3].
    """
    half_rings = latitude_steps // 2
    latitudes = torch.linspace(-np.pi / 2, np.pi / 2, latitude_steps)[:half_rings]
    longitudes = torch.linspace(0, 2 * np.pi, longitude_steps)

    # vertex grid: latitude-major, longitude-minor ordering
    verts = torch.tensor([
        [
            r * torch.cos(lat) * torch.cos(lon),
            r * torch.cos(lat) * torch.sin(lon),
            r * torch.sin(lat),
        ]
        for lat in latitudes
        for lon in longitudes
    ])

    indices = []
    for ring in range(half_rings - 1):
        for col in range(longitude_steps):
            # quad corners on this ring and the next, wrapping around in longitude
            cur = ring * longitude_steps + col
            cur_next = ring * longitude_steps + (col + 1) % longitude_steps
            up = (ring + 1) * longitude_steps + col
            up_next = (ring + 1) * longitude_steps + (col + 1) % longitude_steps
            # NOTE(review): bounds compare against latitude_steps (not the number
            # of kept rings) — kept exactly as in the original implementation.
            if ring < latitude_steps - 2:
                indices.append([cur, up_next, up])
            if ring > 0:
                indices.append([cur, cur_next, up_next])

    face_indices = torch.tensor(indices, dtype=torch.long)
    return verts, face_indices
class FlameMouth(nn.Module):
    """Half-sphere proxy mesh for the mouth interior of a FLAME head.

    The sphere is generated once at construction time; ``forward`` scales and
    orients it per batch so it sits behind the lips, optionally rotating it
    rigidly with the jaw.
    """

    def __init__(
        self,
        long_steps=20,
        lat_steps=20,
        lip_v_index=3533,
        lip_offset=0.005,
    ):
        # long_steps / lat_steps: tessellation resolution of the half sphere.
        # lip_v_index: index of a FLAME vertex on the lips used to anchor the
        #   sphere (presumably an inner-lip vertex — TODO confirm).
        # lip_offset: distance by which the sphere is pulled back behind the lip.
        super().__init__()
        v_sphere, f_sphere = generate_uv_sphere(
            r=1.,
            latitude_steps=lat_steps,
            longitude_steps=long_steps,
        )
        v_sphere[:, 1] = -v_sphere[:, 1]  # flip axis
        v_sphere[:, 2] = -v_sphere[:, 2]  # flip axis to align in right direction
        self.lip_v_index = lip_v_index
        self.register_buffer("vertices", v_sphere)
        self.register_buffer("faces", f_sphere)
        self.lip_offset = lip_offset

    def forward(
        self,
        neutral_verts,
        joint_regressor,
        jaw_rotation=None,
    ):
        """Place the mouth sphere for a batch of neutral FLAME meshes.

        neutral_verts: [B, V, 3] neutral-pose FLAME vertices.
        joint_regressor: per-joint vertex regressor; row 2 is used here
            (appears to be the jaw joint — TODO confirm FLAME joint order).
        jaw_rotation: optional [B, 3, 3] rotation applied rigidly about the
            jaw joint.

        Returns the transformed sphere vertices.
        NOTE(review): the ``v_sphere + center`` broadcast appears to assume
        B == 1 (as used by the caller in flame.py) — confirm before batching.
        """
        # regress the jaw joint location from the neutral vertices
        jaw_joint = einops.einsum(neutral_verts, joint_regressor[2], "b V xyz, V -> b xyz")  # (B, 3)
        lip_vert = neutral_verts[:, self.lip_v_index]
        offset = lip_vert - jaw_joint
        distance = offset.norm(dim=-1, keepdim=True)
        direction = offset / distance
        # build an orthonormal frame whose z-axis points from the jaw joint to the lip
        y = torch.zeros_like(direction, device=direction.device)
        y[:, 1] = 1
        new_x = torch.cross(y, direction, dim=-1)
        new_x = new_x / new_x.norm(dim=-1, keepdim=True)
        new_y = torch.cross(direction, new_x, dim=-1)
        new_y = new_y / new_y.norm(dim=-1, keepdim=True)
        new_z = direction
        rot_mat = torch.stack([new_x, new_y, new_z], dim=-1)
        # scale the sphere relative to the jaw-to-lip distance
        v_sphere = self.vertices[None] * distance[..., None] * 0.25
        v_sphere = (rot_mat @ v_sphere.permute(0, 2, 1)).permute(0, 2, 1)
        # place the sphere 3/4 of the way to the lip, pulled slightly inward
        center = jaw_joint + offset * 0.75 - self.lip_offset * direction
        v_sphere = v_sphere + center
        if jaw_rotation is not None:
            # rotate rigidly about the jaw joint
            v_offset = jaw_rotation @ (v_sphere - jaw_joint).permute(0, 2, 1)
            v_sphere = jaw_joint + v_offset.permute(0, 2, 1)
        return v_sphere
```
## /cap4d/inference/data/generation_data.py
```py path="/cap4d/inference/data/generation_data.py"
import numpy as np
from cap4d.datasets.utils import (
pivot_camera_intrinsic,
get_head_direction,
compute_yaw_pitch_to_face_direction,
)
from cap4d.inference.data.inference_data import CAP4DInferenceDataset
def elipsis_sample(yaw_limit, pitch_limit):
    """Draw a (yaw, pitch) pair uniformly from the ellipse with the given half-axes.

    Uses rejection sampling from the bounding rectangle; returns (0., 0.)
    immediately when either limit is zero.
    """
    if yaw_limit == 0. or pitch_limit == 0.:
        return 0., 0.
    while True:
        yaw = np.random.uniform(-yaw_limit, yaw_limit)
        pitch = np.random.uniform(-pitch_limit, pitch_limit)
        # accept only points strictly inside the unit ellipse
        if (yaw / yaw_limit) ** 2 + (pitch / pitch_limit) ** 2 < 1.:
            return yaw, pitch
class GenerationDataset(CAP4DInferenceDataset):
    """Dataset of FLAME parameter sets for the images to be generated.

    Each item reuses the reference identity (shape), head pose, and camera
    intrinsics, pairs them with a pre-computed expression/eye-pose sample,
    and pivots the camera by a yaw/pitch offset drawn uniformly from an
    ellipse around the reference viewing direction.
    """

    def __init__(
        self,
        generation_data_path,
        reference_flame_item,
        n_samples=840,
        yaw_range=55,
        pitch_range=20,
        expr_factor=1.0,
        resolution=512,
        downsample_ratio=8,
    ):
        # generation_data_path: .npz with "expr" and "eye_rot" sample banks.
        # reference_flame_item: one item from ReferenceDataset.flame_list.
        # yaw_range / pitch_range: half-axes of the camera sampling ellipse
        #   (presumably degrees — TODO confirm against pivot_camera_intrinsic).
        # expr_factor: scales expression and eye-pose magnitudes.
        super().__init__(resolution, downsample_ratio)
        self.n_samples = n_samples
        self.yaw_range = yaw_range
        self.pitch_range = pitch_range
        # BUG FIX: init_flame_params previously returned None (it only set
        # self.flame_list), leaving self.flame_dicts always None. It now
        # returns the list it builds.
        self.flame_dicts = self.init_flame_params(
            generation_data_path,
            reference_flame_item,
            n_samples,
            yaw_range,
            pitch_range,
            expr_factor,
        )

    def init_flame_params(
        self,
        generation_data_path,
        reference_flame_item,
        n_samples,
        yaw_range,
        pitch_range,
        expr_factor,
    ):
        """Build the per-sample FLAME dicts.

        Stores the list in self.flame_list (and the reference extrinsics in
        self.ref_extr) and also returns it.
        """
        gen_data = dict(np.load(generation_data_path))

        ref_extr = reference_flame_item["extr"]
        ref_shape = reference_flame_item["shape"]
        ref_fx = reference_flame_item["fx"]
        ref_fy = reference_flame_item["fy"]
        ref_cx = reference_flame_item["cx"]
        ref_cy = reference_flame_item["cy"]
        ref_resolution = reference_flame_item["resolutions"]
        ref_rot = reference_flame_item["rot"]
        ref_tra = reference_flame_item["tra"]

        ref_tra_cv = ref_tra.copy()
        ref_tra_cv[:, 1:] = -ref_tra_cv[:, 1:]  # p3d to opencv
        ref_head_dir = get_head_direction(ref_rot)  # in p3d
        ref_head_dir[:, 1:] = -ref_head_dir[:, 1:]  # p3d to opencv

        # yaw/pitch of the reference camera relative to the head direction;
        # sampled offsets are centered on this
        center_yaw, center_pitch = compute_yaw_pitch_to_face_direction(
            ref_extr[0],
            ref_head_dir[0],
        )

        flame_list = []
        assert n_samples <= len(gen_data["expr"]), "too many samples"
        for expr, eye_rot in zip(gen_data["expr"][:n_samples], gen_data["eye_rot"][:n_samples]):
            yaw, pitch = elipsis_sample(yaw_range, pitch_range)
            yaw += center_yaw
            pitch -= center_pitch  # pitch is flipped for some reason
            rotated_extr = pivot_camera_intrinsic(ref_extr[0], ref_tra_cv[0], [yaw, pitch])
            flame_dict = {
                "shape": ref_shape,
                "expr": expr[None] * expr_factor,
                "eye_rot": eye_rot[None] * expr_factor,
                "rot": ref_rot,
                "tra": ref_tra,
                "extr": rotated_extr[None],
                "resolutions": ref_resolution,
                "fx": ref_fx,
                "fy": ref_fy,
                "cx": ref_cx,
                "cy": ref_cy,
            }
            flame_list.append(flame_dict)

        self.flame_list = flame_list
        self.ref_extr = ref_extr[0]
        return flame_list
```
## /cap4d/inference/data/inference_data.py
```py path="/cap4d/inference/data/inference_data.py"
import numpy as np
from torch.utils.data import Dataset
import einops
from cap4d.flame.flame import CAP4DFlameSkinner
from cap4d.datasets.utils import (
load_frame,
crop_image,
rescale_image,
apply_bg,
load_flame_verts_and_cam,
get_bbox_from_verts,
load_camera_rays,
verts_to_pytorch3d,
)
class CAP4DInferenceDataset(Dataset):
    """Base dataset producing conditioning inputs for the MMDM.

    Subclasses populate ``self.flame_list`` (one dict of FLAME/camera
    parameters per item) and ``self.ref_extr`` (reference camera extrinsics)
    before the dataset is used.
    """

    def __init__(
        self,
        resolution=512,
        downsample_ratio=8,
    ):
        # resolution: output image resolution; latent-space maps use
        # resolution // downsample_ratio (the autoencoder's spatial factor).
        self.resolution = resolution
        self.latent_resolution = self.resolution // downsample_ratio
        self.flame_skinner = CAP4DFlameSkinner(
            add_mouth=True,
            n_shape_params=150,
            n_expr_params=65,
        )
        # vertex ids of the head region, used for crop-box estimation
        self.head_vertex_ids = np.genfromtxt("data/assets/flame/head_vertices.txt").astype(int)
        self.flame_list = None
        self.ref_extr = None
        self.data_path = None

    def __len__(self):
        assert self.flame_list is not None, "self.flame_list not properly initialized"
        return len(self.flame_list)

    def __getitem__(self, idx):
        # BUG FIX: work on a shallow copy of the stored dict. The original
        # code mutated self.flame_list[idx] in place (inserting "crop_box",
        # deleting "img_dir_path"/"bg_dir_path"), so iterating the dataset a
        # second time silently treated reference frames as generation frames.
        flame_item = dict(self.flame_list[idx])
        verts_2d, offsets_3d, intrinsics, extrinsics = load_flame_verts_and_cam(
            self.flame_skinner,
            flame_item,
        )
        crop_box = get_bbox_from_verts(verts_2d, self.head_vertex_ids)
        flame_item["crop_box"] = crop_box

        if "img_dir_path" in flame_item:
            # we have images available, load them
            # load and crop image, including background
            img_dir_path = flame_item["img_dir_path"]
            timestep_id = flame_item["timestep_id"]
            img = load_frame(img_dir_path, timestep_id)
            del flame_item["img_dir_path"]  # delete string from flame dict so that it can be collated
            if "bg_dir_path" in flame_item:
                bg = load_frame(flame_item["bg_dir_path"], timestep_id)
                del flame_item["bg_dir_path"]  # delete string from flame dict so that it can be collated
            else:
                print(f"WARNING: bg does not exist for image {img_dir_path}. Make sure the background is white.")
                bg = np.ones_like(img) * 255
            out_crop_mask = np.ones_like(img[..., [0]])
            img = apply_bg(img, bg)
            img = crop_image(img, crop_box, bg_value=255)
            out_crop_mask = crop_image(out_crop_mask, crop_box, bg_value=0)
            img = rescale_image(img, self.resolution)
            img = ((img / 127.5) - 1.0).astype(np.float32)
            out_crop_mask = rescale_image(out_crop_mask, self.latent_resolution)
            is_ref = True
        else:
            # no image available means these images need to be generated
            # set image to zero
            img = np.zeros((self.resolution, self.resolution, 3), dtype=np.float32)
            out_crop_mask = np.ones((self.latent_resolution, self.latent_resolution), dtype=np.float32)
            is_ref = False

        # load and transform ray map
        ray_map = load_camera_rays(
            crop_box,
            intrinsics,
            extrinsics,
            self.latent_resolution,
        )
        assert self.ref_extr is not None, "reference extrinsics ref_extr not set"
        # rotate ray directions into the reference camera frame
        ray_map_h = ray_map.shape[1]
        ray_map = einops.rearrange(ray_map, 'v h w -> v (h w)')
        ray_map = self.ref_extr[:3, :3] @ ray_map
        ray_map = einops.rearrange(ray_map, 'v (h w) -> v h w', h=ray_map_h)

        # reference mask is one for reference dataset
        reference_mask = np.ones_like(out_crop_mask) * is_ref
        # convert pixel space vertices to pytorch3d space [-1, 1]
        verts_2d = verts_to_pytorch3d(verts_2d, np.array(crop_box))

        cond_dict = {
            "out_crop_mask": out_crop_mask[None],
            "reference_mask": reference_mask[None],
            "ray_map": ray_map[None],
            "verts_2d": verts_2d[None],
            "offsets_3d": offsets_3d[None],
        }  # [None] is for fake time dimension
        out_dict = {
            "jpg": img[None],  # jpg names comes from controlnet implementation
            "hint": cond_dict,
            "flame_params": flame_item,
        }
        return out_dict
```
## /cap4d/inference/data/reference_data.py
```py path="/cap4d/inference/data/reference_data.py"
import numpy as np
import json
from pathlib import Path
from cap4d.inference.data.inference_data import CAP4DInferenceDataset
class ReferenceDataset(CAP4DInferenceDataset):
    """Dataset over the reference frames listed in reference_images.json."""

    def __init__(
        self,
        data_path,
        resolution=512,
        downsample_ratio=8,
    ):
        super().__init__(resolution, downsample_ratio)
        self.load_flame_params(data_path)

    def load_flame_params(
        self,
        data_path: Path,
    ):
        """Populate self.flame_list / self.ref_extr from fit.npz and reference_images.json."""
        flame_dict = dict(np.load(data_path / "fit.npz"))
        with open(data_path / "reference_images.json") as f:
            ref_json = json.load(f)

        per_timestep_keys = ("expr", "rot", "tra", "eye_rot")
        per_camera_keys = ("fx", "fy", "cx", "cy", "extr", "resolutions")

        flame_list = []
        ref_extr = None
        for cam_name, timestep_id in ref_json:
            # resolve the camera name into its index in the fit arrays
            cam_id = np.where(flame_dict["camera_order"] == cam_name)[0].item()

            # slice out a single (camera, timestep) frame from the fit;
            # list indexing keeps a leading axis of size one
            flame_item = {}
            for key, value in flame_dict.items():
                if key in per_timestep_keys:
                    flame_item[key] = value[[timestep_id]]
                elif key in per_camera_keys:
                    flame_item[key] = value[[cam_id]]
                elif key == "shape":
                    flame_item[key] = value
            flame_item["timestep_id"] = timestep_id

            cam_dir_path = flame_dict["camera_order"][cam_id]
            flame_item["img_dir_path"] = data_path / "images" / cam_dir_path
            bg_dir_path = data_path / "bg" / cam_dir_path
            if bg_dir_path.exists():
                flame_item["bg_dir_path"] = bg_dir_path
            flame_list.append(flame_item)

            # remember the extrinsics of the first reference frame
            if ref_extr is None:
                ref_extr = flame_item["extr"]

        self.ref_extr = ref_extr[0]
        self.flame_list = flame_list
```
## /cap4d/inference/generate_images.py
```py path="/cap4d/inference/generate_images.py"
from pathlib import Path
import argparse
import copy
import shutil
import torch
from torch.utils.data import DataLoader
import gc
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from cap4d.mmdm.sampler import StochasticIOSampler
from cap4d.inference.data.reference_data import ReferenceDataset
from cap4d.inference.data.generation_data import GenerationDataset
from cap4d.inference.utils import (
get_condition_from_dataloader,
load_model,
save_visualization,
save_flame_params,
convert_and_save_latent_images,
find_number_of_generated_images,
)
@torch.no_grad()
def main(args):
    """End-to-end image generation with the MMDM.

    Loads the reference dataset and a bank of generation targets, encodes
    conditioning for both, samples latents with stochastic I/O conditioning
    (replicated across all CUDA devices when available), and decodes and
    saves every image.
    """
    ref_data_path = Path(args.reference_data_path)
    gen_config_path = Path(args.config_path)
    gen_config = OmegaConf.load(gen_config_path)

    # prepare the output folder layout
    output_path = Path(args.output_path)
    output_path.mkdir(exist_ok=True, parents=True)
    output_ref_path = output_path / "reference_images"
    output_ref_path.mkdir(exist_ok=True)
    output_gen_path = output_path / "generated_images"
    output_gen_path.mkdir(exist_ok=True)

    seed_everything(gen_config["seed"])
    # keep a copy of the config next to the outputs for reproducibility
    shutil.copy(gen_config_path, output_path / "mmdm_config_dump.yaml")

    # create reference dataloaders
    print("Creating dataloaders")
    refset = ReferenceDataset(ref_data_path, gen_config["resolution"])
    ref_dataloader = DataLoader(refset, num_workers=1, batch_size=1, shuffle=False)
    # adjust the requested sample count to the sampler's window layout
    n_gen_samples = find_number_of_generated_images(
        gen_config["generation_data"]["n_samples"],
        len(ref_dataloader),
        gen_config["R_max"],
        gen_config["V"],
    )
    # create generation dataloaders
    genset = GenerationDataset(
        generation_data_path=gen_config["generation_data"]["data_path"],
        reference_flame_item=refset.flame_list[0],
        n_samples=n_gen_samples,
        resolution=gen_config["resolution"],
        yaw_range=gen_config["generation_data"]["yaw_range"],
        pitch_range=gen_config["generation_data"]["pitch_range"],
        expr_factor=gen_config["generation_data"]["expr_factor"],
    )
    gen_dataloader = DataLoader(genset, num_workers=1, batch_size=1, shuffle=False)

    # load models: one replica per CUDA device, or a single replica on the
    # requested (non-CUDA) device; the first replica is used for encoding
    model = load_model(Path(gen_config["ckpt_path"]))
    device_model_map = {}
    first_rank_model = None
    first_rank_device = None
    if "cuda" in args.device:
        for cuda_id in range(torch.cuda.device_count()):
            dev_key = f"cuda:{cuda_id}"
            device_model_map[dev_key] = copy.deepcopy(model).to(dev_key)
            if first_rank_model is None:
                first_rank_model = device_model_map[dev_key]
                first_rank_device = dev_key
    else:
        device_model_map[args.device] = model.to(args.device)
        first_rank_model = device_model_map[args.device]
        first_rank_device = args.device
    print(f"Done loading model.")

    # load all reference frames and create conditioning (and unconditional) images for each
    print(f"Loading reference dataset from {ref_data_path}")
    ref_data = get_condition_from_dataloader(
        first_rank_model,
        ref_dataloader,
        args.device,
    )
    # load all generation frames and create conditioning images for each
    print(f"Loading generation dataset from {gen_config['generation_data']['data_path']}")
    gen_data = get_condition_from_dataloader(
        first_rank_model,
        gen_dataloader,
        args.device,
    )

    if args.visualize_conditioning:
        print("Saving visualization of conditioning images")
        save_visualization(ref_data["cond_vis_frames"], output_ref_path)
        save_visualization(gen_data["cond_vis_frames"], output_gen_path)

    print("Saving flame parameters")
    save_flame_params(ref_data["flame_params"], output_ref_path)
    save_flame_params(gen_data["flame_params"], output_gen_path)

    gc.collect()

    # concatenate the per-batch conditioning lists into single tensors
    for key in ref_data["cond_frames"]:
        ref_data["cond_frames"][key] = torch.cat(ref_data["cond_frames"][key], dim=0)
        ref_data["uncond_frames"][key] = torch.cat(ref_data["uncond_frames"][key], dim=0)
        gen_data["cond_frames"][key] = torch.cat(gen_data["cond_frames"][key], dim=0)
        gen_data["uncond_frames"][key] = torch.cat(gen_data["uncond_frames"][key], dim=0)

    # Sample with Stochastic I/O Conditioning:
    print(f"Generating images on {len(list(device_model_map))} devices with stochastic I/O.")
    stochastic_io_sampler = StochasticIOSampler(device_model_map)
    z_gen = stochastic_io_sampler.sample(
        S=gen_config["n_ddim_steps"],
        ref_cond=ref_data["cond_frames"],
        ref_uncond=ref_data["uncond_frames"],
        gen_cond=gen_data["cond_frames"],
        gen_uncond=gen_data["uncond_frames"],
        latent_shape=(4, gen_config["resolution"] // 8, gen_config["resolution"] // 8),
        V=gen_config["V"],
        R_max=gen_config["R_max"],
        cfg_scale=gen_config["cfg_scale"],
    )
    print("Done generating.")
    gc.collect()

    z_gen = z_gen.cpu()
    z_ref = ref_data["cond_frames"]["z_input"].cpu()

    print(f"Saving reference images to {output_ref_path}/images")
    convert_and_save_latent_images(z_ref, first_rank_model, first_rank_device, output_ref_path)
    print(f"Saving generated images to {output_gen_path}/images")
    convert_and_save_latent_images(z_gen, first_rank_model, first_rank_device, output_gen_path)
if __name__ == "__main__":
    # CLI entry point: declare all arguments, parse, and run the pipeline.
    parser = argparse.ArgumentParser()
    for flag, options in (
        ("--config_path", dict(type=str, required=True, help="path to generation config file")),
        ("--reference_data_path", dict(type=str, required=True, help="path to reference json file")),
        ("--output_path", dict(type=str, required=True, help="path to output")),
        ("--batch_size", dict(type=int, default=1, help="batch size")),
        ("--device", dict(type=str, default="cuda", help="inference device")),
        ("--visualize_conditioning", dict(type=int, default=1, help="whether to save visualizations of conditioning images")),
    ):
        parser.add_argument(flag, **options)
    main(parser.parse_args())
```
## /cap4d/inference/utils.py
```py path="/cap4d/inference/utils.py"
from collections import defaultdict
import os
import einops
from omegaconf import OmegaConf
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import cv2
from controlnet.ldm.util import instantiate_from_config
from cap4d.mmdm.mmdm import MMLDM
def to_device(batch, device="cuda"):
    """Recursively move a (possibly nested) dict of tensors to ``device``.

    List values are left untouched; nested dicts are handled recursively.
    The dict is modified in place and also returned.
    """
    for key in batch:
        if isinstance(batch[key], dict):
            # BUG FIX: propagate the requested device into nested dicts;
            # previously the recursion fell back to the default "cuda".
            batch[key] = to_device(batch[key], device=device)
        elif not isinstance(batch[key], list):
            batch[key] = batch[key].to(device)
    return batch
def log_cond(module, batch):
    """Render visualization images of the conditioning for ``batch``.

    Runs the module's conditioning stage in conditional mode, converts the
    resulting positional encodings into per-key visualization tensors,
    upsamples them 8x (nearest) to image resolution, and returns a dict of
    channel-last tensors clamped to [-1, 1].
    """
    cond_model = module.cond_stage_model
    cond_key = module.control_key
    # conditional (not classifier-free-dropped) encoding for visualization
    c_cond = cond_model(batch[cond_key], unconditional=False)
    enc_vis = cond_model.get_vis(c_cond["pos_enc"])
    for key in enc_vis:
        vis = enc_vis[key]
        b_ = vis.shape[0]
        vis = einops.rearrange(vis, 'b t h w c -> (b t) c h w')
        # upsample from latent resolution (factor 8) for viewing
        vis = F.interpolate(vis, scale_factor=8., mode="nearest")
        vis = vis.clamp(-1., 1.)
        enc_vis[key] = einops.rearrange(vis, '(b t) c h w -> (b t) h w c', b=b_)
    return enc_vis
def load_model(ckpt_path):
    """Load the newest MMLDM checkpoint found under ``ckpt_path``.

    Picks the most recently created *.ckpt file in ckpt_path/"checkpoints",
    instantiates the model from the dumped config, and loads its weights on CPU.
    """
    checkpoint_files = (ckpt_path / "checkpoints").glob("*.ckpt")
    weight_path = max(checkpoint_files, key=os.path.getctime)
    print("Loading model using checkpoint", weight_path)

    # instantiate the architecture from the config saved next to the weights
    config_path = ckpt_path / "config_dump.yaml"
    config = OmegaConf.load(config_path)
    print(f'Loaded model config from [{config_path}]')

    model: MMLDM = instantiate_from_config(config.model).cpu()
    print("Loading state dict")
    model.load_state_dict(torch.load(weight_path, map_location='cpu'))
    model.eval()
    return model
def get_condition_from_dataloader(model, dataloader, device):
    """Encode conditioning (and unconditional) inputs for every batch.

    Iterates the dataloader once, pushing each batch through the model's
    input encoding, and collects per-key conditioning tensors, their
    unconditional counterparts, visualization images, and the FLAME/camera
    parameters.

    Returns a dict with keys: "cond_frames", "uncond_frames",
    "cond_vis_frames" (each a dict of lists of CPU tensors) and
    "flame_params" (a list of dicts of numpy arrays).
    """
    cond_frames = defaultdict(list)
    uncond_frames = defaultdict(list)
    cond_vis_frames = defaultdict(list)
    flame_params = []
    for frame_id, batch in enumerate(tqdm(dataloader)):
        batch = to_device(batch, device=device)
        # get conditioning from data batch
        z, c = model.get_input(batch, model.first_stage_key, force_conditional=True)
        # store conditioning for later; fold the batch/time axes into one
        for key in c["c_concat"][0]:
            c_cond = einops.rearrange(c["c_concat"][0][key], 'b t ... -> (b t) ...')
            c_uncond = einops.rearrange(c["c_uncond"][0][key], 'b t ... -> (b t) ...')
            cond_frames[key].append(c_cond.cpu())
            uncond_frames[key].append(c_uncond.cpu())
        # store conditioning visualization for later
        cond_vis = log_cond(model, batch)
        for key in cond_vis:
            cond_vis_frames[key].append(cond_vis[key].cpu())
        # store the flame and camera parameters for later, one dict per batch element
        for b in range(batch["flame_params"]["fx"].shape[0]):
            flame_dict = {}
            for key in batch["flame_params"]:
                flame_dict[key] = batch["flame_params"][key][b].cpu().numpy()
            flame_params.append(flame_dict)
    return {
        "cond_frames": cond_frames,
        "uncond_frames": uncond_frames,
        "cond_vis_frames": cond_vis_frames,
        "flame_params": flame_params,
    }
def save_visualization(vis_frames, output_dir):
    """Write conditioning visualizations as JPEGs under output_dir/condition_vis/<key>/."""
    condition_base_dir = output_dir / "condition_vis"
    condition_base_dir.mkdir(exist_ok=True)
    for key, frames in vis_frames.items():
        key_dir = condition_base_dir / f"{key}"
        key_dir.mkdir(exist_ok=True)
        for frame_id, vis_img in enumerate(frames):
            # frames are [-1, 1] RGB; drop the batch axis, swap to BGR for
            # OpenCV, and rescale to uint8
            img = vis_img[0][..., [2, 1, 0]].cpu().numpy()
            img = (((img + 1.) / 2.) * 255).astype(np.uint8)
            cv2.imwrite(str(key_dir / f"{frame_id:05d}.jpg"), img)
def save_flame_params(flame_params, output_dir):
    """Save each FLAME parameter dict as a numbered .npz under output_dir/flame/."""
    flame_dir = output_dir / "flame"
    flame_dir.mkdir(exist_ok=True)
    for idx, params in enumerate(flame_params):
        np.savez(flame_dir / f"{idx:05d}.npz", **params)
def convert_and_save_latent_images(latents, model, device, output_dir):
    """Decode latent images one at a time and save them as PNGs.

    latents: [N, ...] latent tensor; each latent is decoded individually to
        bound memory use.
    model: model providing ``decode_first_stage``.
    device: device the latents are moved to for decoding.
    output_dir: directory that will contain an "images" subfolder.
    """
    out_img_dir = output_dir / "images"
    out_img_dir.mkdir(exist_ok=True)
    for i in range(latents.shape[0]):
        # Convert latent to RGB
        x_samples = model.decode_first_stage(latents[None, [i]].to(device))[0, 0]
        # decoded range is [-1, 1]; rescale to [0, 255]
        img = ((x_samples + 1.) / 2.).clip(0., 1.)
        img = img.permute(1, 2, 0).cpu().numpy() * 255.
        out_img_path = out_img_dir / f"{i:05d}.png"
        # RGB -> BGR for OpenCV
        success = cv2.imwrite(str(out_img_path), img[..., [2, 1, 0]].astype(np.uint8))
        assert success, f"failed to save image to {out_img_path}"
def find_number_of_generated_images(n_gen_default, n_ref, R_max, V):
    """Return a generated-image count compatible with the sampler's windows.

    Each sampling window holds V frames of which R = min(n_ref, R_max) are
    reference frames, leaving V - R generated frames per window; the total
    generated count must therefore be a multiple of V - R.

    BUG FIX: the previous padding ``n % (V - R)`` only yields a multiple of
    V - R by coincidence (e.g. 10 with V-R=3 became 11). Pad with
    ``(-n) % (V - R)`` instead, which is the distance to the next multiple
    (zero when already compatible).
    """
    R = min(n_ref, R_max)
    frames_per_window = V - R
    missing_frames = (-n_gen_default) % frames_per_window
    n_gen_new = n_gen_default + missing_frames
    if missing_frames > 0:
        print(f"WARNING: Default number of generated images {n_gen_default} incompatible with R={R}, adjusting the number to {n_gen_new}")
    return n_gen_new
```
## /cap4d/mmdm/conditioning/cap4dcond.py
```py path="/cap4d/mmdm/conditioning/cap4dcond.py"
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from cap4d.mmdm.conditioning.mesh2img import PropRenderer
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding applied independently to each input channel.

    Every input channel is multiplied by a bank of power-of-two frequencies
    and mapped through sin and cos, expanding each channel into
    ``channels_per_dim`` output features.
    """

    def __init__(self, channels_per_dim):
        """channels_per_dim: output features per input channel (must be even)."""
        super().__init__()
        assert channels_per_dim % 2 == 0
        half = channels_per_dim // 2
        # frequencies 2^0 .. 2^(half-1), stored as float32
        self.register_buffer("freqs", torch.pow(2.0, torch.arange(half, dtype=torch.float32)))

    def forward(self, tensor):
        """Encode a (b, h, w, c) tensor into (b, h, w, c * channels_per_dim)."""
        if len(tensor.shape) != 4:
            raise RuntimeError("The input tensor has to be 4d!")
        scaled = tensor.unsqueeze(-1) * self.freqs.view(1, 1, 1, 1, -1)
        encoded = torch.cat((scaled.sin(), scaled.cos()), dim=-1)
        # merge (channel, frequency) into a single trailing feature axis
        return encoded.flatten(start_dim=-2)
class CAP4DConditioning(nn.Module):
    """Builds the spatial conditioning tensor of the MMDM from FLAME geometry.

    Per frame: rasterizes 2D vertices into a position map (optionally carrying
    per-vertex expression offsets), encodes the positions sinusoidally, and
    concatenates optional camera-ray directions, the reference-frame mask, and
    an optional crop mask along the channel axis.
    """

    def __init__(
        self,
        image_size=64,
        positional_channels=42,
        positional_multiplier=1.,
        super_resolution=2,
        use_ray_directions=True,
        use_expr_deformation=True,
        use_crop_mask=False,
        std_expr_deformation=0.0104,
    ) -> None:
        # image_size: conditioning resolution (latent space).
        # positional_channels: total positional-encoding channels, split evenly
        #   across the x/y/z components (must be divisible by 3).
        # positional_multiplier: scales positions before encoding.
        # super_resolution: integer factor; rendering happens at
        #   image_size * super_resolution and is area-downsampled afterwards.
        # std_expr_deformation: normalization constant for expression offsets
        #   (presumably a dataset statistic — TODO confirm).
        super().__init__()
        self.image_size = image_size
        assert super_resolution >=1 and super_resolution % 1 == 0
        self.super_resolution = super_resolution
        self.positional_channels = positional_channels
        self.positional_multiplier = positional_multiplier
        self.use_ray_directions = use_ray_directions
        self.use_expr_deformation = use_expr_deformation
        self.std_expr_deformation = std_expr_deformation
        self.use_crop_mask = use_crop_mask
        assert positional_channels % 3 == 0
        self.pos_encoding = PositionalEncoding(positional_channels // 3)
        self.renderer = PropRenderer()

    def forward(self, batch, unconditional=True):
        """Return the conditioning dict for ``batch``.

        Uses batch keys "verts_2d", "offsets_3d", "reference_mask" and,
        depending on configuration, "ray_map", "out_crop_mask", "z"
        (shapes inferred from usage here — TODO confirm against the dataset).

        When ``unconditional`` is True, the encoding (and z, if present) is
        all zeros, implementing classifier-free dropout of the conditioning.
        """
        verts = batch["verts_2d"]
        offsets = batch["offsets_3d"]
        ref_mask = batch["reference_mask"][:, :, None]
        B, T = verts.shape[:2]

        z_input = None
        if "z" in batch:
            z_input = batch["z"]

        img_size = self.image_size
        if unconditional:
            # zero tensor with the same channel count the conditional branch produces
            total_channels = self.positional_channels + 1  # positional and ref mask
            if self.use_crop_mask:
                total_channels += 1
            if self.use_ray_directions:
                total_channels += 3
            if self.use_expr_deformation:
                total_channels += 3
            pose_pos_enc = torch.zeros((B, T, img_size, img_size, total_channels), device=verts.device)
            if z_input is not None:
                z_input = z_input * 0.
        else:
            with torch.no_grad():
                verts = einops.rearrange(verts, 'b t n v -> (b t) n v')
                offsets = einops.rearrange(offsets, 'b t n v -> (b t) n v')
                offsets = offsets / self.std_expr_deformation  # normalize offset magnitude
                # rasterize vertex positions (optionally with offsets as extra channels)
                pose_map, mask = self.renderer.render(
                    verts,
                    (img_size * self.super_resolution, img_size * self.super_resolution),
                    prop=offsets if self.use_expr_deformation else None,
                )
                if self.use_expr_deformation:
                    # extract last three channels which are offsets
                    pose_map, offsets = pose_map.split([3, 3], dim=-1)
                pose_pos_enc = self.pos_encoding(pose_map * self.positional_multiplier)
                if self.use_expr_deformation:
                    # append expression offset
                    pose_pos_enc = torch.cat([pose_pos_enc, offsets], dim=-1)
                pose_pos_enc = pose_pos_enc * mask  # mask values not rendered
                # downscale pos_enc if we use super resolution
                pose_pos_enc = einops.rearrange(pose_pos_enc, 'bt h w c -> bt c h w')
                pose_pos_enc = F.interpolate(pose_pos_enc, (img_size, img_size), mode="area")
                pose_pos_enc = einops.rearrange(pose_pos_enc, '(b t) c h w -> b t h w c', b=B)
                if self.use_ray_directions:
                    # concat ray map
                    ray_map = batch["ray_map"]
                    ray_map = einops.rearrange(ray_map, 'b t c h w -> b t h w c')
                    pose_pos_enc = torch.cat([pose_pos_enc, ray_map], dim=-1)
                # concat ref mask
                ref_mask_reshape = einops.rearrange(ref_mask, 'b t c h w -> b t h w c')
                pose_pos_enc = torch.cat([pose_pos_enc, ref_mask_reshape], dim=-1)
                if self.use_crop_mask:
                    crop_mask = batch["out_crop_mask"][..., None]
                    pose_pos_enc = torch.cat([pose_pos_enc, crop_mask], dim=-1)
        return {
            "pos_enc": pose_pos_enc,
            "z_input": z_input,
            "ref_mask": ref_mask,
        }

    def get_vis(self, enc):
        """Slice the encoded conditioning into per-modality visualization maps.

        Returns RGB-like [..., 3] tensors keyed by modality name, mirroring
        the channel layout produced by ``forward``.
        """
        visualizations = {}
        n_pos = self.positional_channels // 3
        counter = 0
        pos_enc = enc[..., 0:self.positional_channels]
        # visualize the two highest-frequency position bands; the same feature
        # index is picked for each of the three spatial components
        for i in range(n_pos-2, n_pos):
            visualizations[f"pose_map_{i}"] = pos_enc[..., [i, i + n_pos, i + n_pos * 2]]
        counter = self.positional_channels
        if self.use_expr_deformation:
            visualizations["expr_disp"] = enc[..., counter:counter+3]
            counter += 3
        if self.use_ray_directions:
            visualizations["ray_map"] = enc[..., counter:counter+3]
            counter += 3
        # replicate single-channel masks to 3 channels for display
        visualizations["ref_mask"] = enc[..., [counter] * 3]
        counter += 1
        if self.use_crop_mask:
            visualizations["crop_mask"] = enc[..., [counter] * 3]
            counter += 1
        return visualizations
```
## /cap4d/mmdm/conditioning/mesh2img.py
```py path="/cap4d/mmdm/conditioning/mesh2img.py"
from typing import Dict, List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch3d.ops.interp_face_attrs import interpolate_face_attributes
from pytorch3d.renderer import (
BlendParams,
PerspectiveCameras,
hard_rgb_blend,
rasterize_meshes,
)
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.structures.meshes import Meshes
from pytorch3d.io import load_obj
def create_camera_objects(
    K: torch.Tensor, RT: torch.Tensor, resolution: Tuple[int, int]
) -> PerspectiveCameras:
    """Create PyTorch3D cameras from OpenCV-convention K/RT batches.

    Parameters
    ----------
    K: torch.Tensor [B, 3, 3]
        Camera calibration matrices.
    RT: torch.Tensor [B, 3, 4] (or larger)
        World-to-camera transforms in the OpenCV convention.
    resolution: tuple[int, int]
        (height, width) of the camera images.
    """
    H, W = resolution
    # image_size expects (width, height) per camera
    img_size = torch.tensor([[W, H]] * len(K), dtype=torch.int, device=K.device)

    # Screen -> NDC: the shortest image side maps to [-1, 1], the longer side
    # to [-u, u] with u > 1 (PyTorch3D's non-square convention).
    half_min_side = img_size.min(dim=1, keepdim=True)[0] / 2.0
    half_min_side = half_min_side.expand(-1, 2)
    half_size = img_size / 2.0

    focal_pytorch3d = torch.stack([K[:, 0, 0], K[:, 1, 1]], dim=-1) / half_min_side
    p0_pytorch3d = -(K[:, :2, 2] - half_size) / half_min_side

    # OpenCV -> PyTorch3D extrinsics: transpose R (row- vs column-vector
    # convention) and flip the x/y screen axes.
    R_pytorch3d = RT[:, :3, :3].clone().permute(0, 2, 1)
    T_pytorch3d = RT[:, :3, 3].clone()
    R_pytorch3d[:, :, :2] *= -1
    T_pytorch3d[:, :2] *= -1

    return PerspectiveCameras(
        R=R_pytorch3d,
        T=T_pytorch3d,
        focal_length=focal_pytorch3d,
        principal_point=p0_pytorch3d,
        image_size=img_size,
        device=K.device,
    )
def create_camera_objects_pytorch3d(K, RT, resolution):
    """Create PyTorch3D cameras from screen-space intrinsics (in_ndc=False).

    Unlike create_camera_objects, K and RT are assumed to already follow the
    PyTorch3D convention; only the principal point's y coordinate is flipped
    to screen coordinates.
    """
    H, W = resolution
    img_size = torch.tensor([[H, W]] * len(K), dtype=torch.int, device=K.device)
    focal = torch.stack((K[:, 0, 0], K[:, 1, 1]), dim=-1)
    # flip y of the principal point (screen origin convention)
    principal = torch.cat([K[:, [0], -1], H - K[:, [1], -1]], dim=1)
    return PerspectiveCameras(
        R=RT[:, :, :3],
        T=RT[:, :, 3],
        principal_point=principal,
        focal_length=focal,
        device=K.device,
        image_size=img_size,
        in_ndc=False,
    )
def project_points(
    lmks: torch.Tensor,
    K: torch.Tensor,
    RT: torch.Tensor,
    resolution: Tuple[int, int] = None,
) -> torch.Tensor:
    """Project 3D points into pixel coordinates with a pinhole camera model.

    Parameters
    ----------
    lmks: torch.Tensor [B, N, 3]
        3D points to project.
    K: torch.Tensor [B, 3, 3]
        Camera calibration matrices.
    RT: torch.Tensor [B, 3, 4]
        World-to-camera transforms.
    resolution:
        Unused; kept for interface compatibility.

    Returns
    -------
    torch.Tensor [B, N, 2] projected pixel coordinates.
    """
    rot = RT[:, :3, :3]
    tra = RT[:, :3, 3]
    # camera space: p_cam = R p + t
    cam_pts = lmks @ rot.transpose(1, 2) + tra[:, None, :]

    fx = K[:, 0, 0].view(-1, 1, 1)
    fy = K[:, 1, 1].view(-1, 1, 1)
    cx = K[:, 0, 2].view(-1, 1, 1)
    cy = K[:, 1, 2].view(-1, 1, 1)

    # perspective division followed by the intrinsic mapping
    depth = cam_pts[..., 2:3]
    u = cam_pts[..., 0:1] / depth * fx + cx
    v = cam_pts[..., 1:2] / depth * fy + cy
    return torch.cat([u, v], dim=-1)
class VertexShader(nn.Module):
    """Rasterizes meshes and interpolates per-vertex attributes into pixel space."""

    def __init__(self) -> None:
        super().__init__()

    def _get_mesh_ndc(
        self,
        meshes: Meshes,
        cameras: PerspectiveCameras,
    ) -> Meshes:
        # Project world-space vertices into NDC while keeping view-space z:
        # rasterize_meshes expects NDC x/y combined with view-space depth.
        eps = None
        verts_world = meshes.verts_padded()
        verts_view = cameras.get_world_to_view_transform().transform_points(
            verts_world, eps=eps
        )
        projection_trafo = cameras.get_projection_transform().compose(
            cameras.get_ndc_camera_transform()
        )
        verts_ndc = projection_trafo.transform_points(verts_view, eps=eps)
        verts_ndc[..., 2] = verts_view[..., 2]
        meshes_ndc = meshes.update_padded(new_verts_padded=verts_ndc)
        return meshes_ndc

    def _get_fragments(
        self, cameras, meshes_ndc, img_shape, blur_sigma
    ) -> Fragments:
        # Rasterize with optional soft blur; geometry closer than half the
        # nearest znear is clipped when cameras are provided.
        znear = None
        if cameras is not None:
            znear = cameras.get_znear()
            if isinstance(znear, torch.Tensor):
                znear = znear.min().detach().item()
        z_clip = None if znear is None else znear / 2
        fragments = rasterize_meshes(
            meshes_ndc,
            image_size=img_shape,
            blur_radius=np.log(1.0 / 1e-4 - 1.0) * blur_sigma,
            # multiple faces per pixel are only needed when blending softly
            faces_per_pixel=4 if blur_sigma > 0.0 else 1,
            bin_size=None,
            max_faces_per_bin=None,
            clip_barycentric_coords=True,
            perspective_correct=cameras is not None,
            cull_backfaces=True,
            z_clip_value=z_clip,
            cull_to_frustum=False,
        )
        return Fragments(
            pix_to_face=fragments[0],
            zbuf=fragments[1],
            bary_coords=fragments[2],
            dists=fragments[3],
        )

    def _rasterize_property(self, property, fragments):
        # rasterize vertex attribute over faces
        # prop has to be not packed, [B, F, 3, D] -> [B * F, 3, D]
        prop_packed = torch.cat(
            [property[i] for i in range(property.shape[0])], dim=0
        )
        return interpolate_face_attributes(
            fragments.pix_to_face, fragments.bary_coords, prop_packed
        )

    def _rasterize_vertices(
        self, vertices: Dict[str, torch.Tensor], fragments: Fragments
    ):
        # Interpolate every per-vertex property except the raw positions,
        # which only drive the rasterization itself.
        rasterized_properties = {}
        for key, prop in vertices.items():
            if key == "positions":
                continue
            rasterized_properties[key] = self._rasterize_property(
                prop, fragments
            )
        return rasterized_properties

    def forward(
        self,
        vertices: Dict[str, torch.Tensor],  # packed vertex properties!
        faces,
        intrinsics,
        extrinsics,
        img_shape,
        blur_sigma,
        return_meshes_and_cameras=False,
    ):
        """Rasterize ``vertices["positions"]``/``faces`` and interpolate the
        remaining vertex properties into pixel space.

        When ``intrinsics`` is None the positions are rasterized as-is
        (assumed to already be in NDC — TODO confirm with callers).

        Returns (pixels, fragments), plus the NDC meshes and cameras when
        ``return_meshes_and_cameras`` is set.
        """
        meshes = Meshes(verts=vertices["positions"], faces=faces)
        cameras = None
        if intrinsics is not None:
            cameras = create_camera_objects(intrinsics, extrinsics, img_shape)
            meshes = self._get_mesh_ndc(meshes, cameras)
        fragments = self._get_fragments(cameras, meshes, img_shape, blur_sigma)
        pixels = self._rasterize_vertices(vertices, fragments)
        if return_meshes_and_cameras:
            return pixels, fragments, meshes, cameras
        else:
            return pixels, fragments
class BasePixelShader(nn.Module):
    """Base class for pixel shaders that sample a texture at rasterized UVs."""

    def __init__(self) -> None:
        super().__init__()

    def sample_texture(self, pixels, texture):
        """Sample `texture` at the per-pixel UV coordinates.

        pixels["uv_coords"]: (N, H, W, K_faces, C) with uv in the first two
        channels; texture: (N, C_tex, Ht, Wt). Returns (N, H, W, K_faces, C_tex).
        """
        assert "uv_coords" in pixels
        uvs = pixels["uv_coords"]
        batch, height, width, k_faces, uv_dim = uvs.shape
        # fold the faces-per-pixel dimension into the batch for grid_sample:
        # (N, H, W, K, C) -> (N, K, H, W, C) -> (N*K, H, W, C)
        grid = uvs.permute(0, 3, 1, 2, 4).reshape(
            batch * k_faces, height, width, uv_dim
        )
        # repeat each texture K times so it lines up with the folded batch
        tex_stack = torch.cat(
            [texture[[idx]].expand(k_faces, -1, -1, -1) for idx in range(batch)]
        )
        sampled = F.grid_sample(tex_stack, grid[..., :2], align_corners=False)
        # (N*K, C_tex, H, W) -> (N, H, W, K, C_tex)
        return sampled.reshape(batch, k_faces, -1, height, width).permute(0, 3, 4, 1, 2)

    def forward(self, fragments, pixels, textures):
        raise NotImplementedError()
class TextureShader(BasePixelShader):
    """Pixel shader producing a hard-blended texture image and a depth image.

    Returns (img, depth) as channels-first tensors; background pixels get the
    zero background color / zero depth.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, fragments, pixels, texture):
        features = self.sample_texture(pixels, texture)
        blend_params = BlendParams(
            sigma=0.,
            gamma=1e-4,
            background_color=[0] * texture.shape[1],
            # N, H, W, K_faces, C
        )
        # BUGFIX: fragments.zbuf[..., None] is a *view* of zbuf, so the
        # original in-place `depth[depth < 0.0] = 0.0` silently mutated the
        # caller's fragments. clamp() produces the same values (background
        # zbuf entries are negative) without the side effect.
        depth = fragments.zbuf[..., None].clamp(min=0.0)
        # replicate depth into 3 channels so hard_rgb_blend can treat it as rgb
        depth = depth.repeat(1, 1, 1, 1, 3)
        img = hard_rgb_blend(features, fragments, blend_params)[..., :-1]  # remove alpha channel
        depth_img = hard_rgb_blend(depth, fragments, blend_params)
        # (N, H, W, C) -> (N, C, H, W); keep a single depth channel
        return img.permute(0, 3, 1, 2), depth_img[..., [0]].permute(0, 3, 1, 2)
def vertex_to_face_mask(vert_mask, faces):
    """Lift a per-vertex mask (B, V) to a per-face mask (B, F).

    Uses the first batch entry of `faces` (*, F, 3); a face is marked when
    any of its three vertices is marked (max over corners).
    """
    per_corner = vert_mask[:, faces[0]]  # (B, F, 3)
    face_mask, _ = per_corner.max(dim=-1)
    return face_mask
class PropRenderer(nn.Module):
    """
    Rasterizes per-vertex properties given ndc meshes.
    """
    def __init__(
        self,
        template_path="./data/assets/flame/cap4d_flame_template.obj",
        head_vert_path="./data/assets/flame/head_vertices.txt",
        n_mouth_verts=200,
        prop_type="verts",  # either uv or verts
    ) -> None:
        super().__init__()
        self.v_shader = VertexShader()
        # load the FLAME template mesh (vertices, face indices, uvs)
        verts, faces, aux = load_obj(template_path)
        self.register_buffer("faces", faces.verts_idx)
        self.register_buffer("faces_uvs", faces.textures_idx)
        # load old lower neck mask and set all new faces also to zero (because they are from the mouth)
        vert_mask = torch.zeros(verts.shape[0]).bool()
        head_verts = torch.tensor(np.genfromtxt(head_vert_path)).long()
        vert_mask[head_verts] = 1
        # the last n_mouth_verts vertices are the appended mouth interior
        vert_mask[-n_mouth_verts:] = 1
        # convert vert mask to face mask
        face_mask = vert_mask[self.faces].max(dim=-1)[0]
        self.register_buffer("face_mask", face_mask)
        if prop_type == "verts":
            # property = normalized template vertex positions
            self.register_buffer("props", verts)
            # normalize:
            self.props = self.props - self.props.mean(dim=-2, keepdim=True)
            self.props = self.props / self.props.max()
        elif prop_type == "uvs":
            # property = uv coordinates mapped from [0, 1] to [-1, 1],
            # with v flipped to image convention
            self.register_buffer("props", aux.verts_uvs)
            self.props = self.props * 2. - 1.
            self.props[..., 1] = -self.props[..., 1]
    def render(
        self,
        vertices,
        img_shape,
        prop=None,
    ):
        """Rasterize template properties (plus an optional extra per-vertex
        `prop`) for vertices (B, V, 3); returns (img, render_mask).

        NOTE(review): vertices are passed to the shader with no cameras, so
        they are presumably already in NDC space — confirm against callers.
        """
        b = vertices.shape[0]
        # gather per-face-corner properties: (F, 3, D) repeated over batch
        props_unpacked = self.props[self.faces][None].repeat(b, 1, 1, 1)
        verts = {
            "positions": vertices,
            "prop": props_unpacked,
        }
        if prop is not None:
            add_prop = prop[:, self.faces]
            verts["add_prop"] = add_prop
        # intrinsics/extrinsics are None: no projection, hard rasterization
        pixels, fragments = self.v_shader(
            verts,
            self.faces[None].repeat(vertices.shape[0], 1, 1),
            None,
            None,
            img_shape,
            0.
        )
        # take the closest face per pixel (faces_per_pixel index 0)
        img = pixels["prop"][..., 0, :]
        if prop is not None:
            img = torch.cat([img, pixels["add_prop"][..., 0, :]], dim=-1)
        # valid pixels are those hit by a face...
        render_mask = fragments.pix_to_face != -1
        # ...restricted to head/mouth faces; pix_to_face ids are offset per
        # batch element, matching the batch-repeated face mask
        face_mask = self.face_mask.repeat(b)
        face_masked = face_mask[torch.clamp(fragments.pix_to_face, 0)]
        render_mask = torch.logical_and(render_mask, face_masked)
        return img, render_mask
```
## /cap4d/mmdm/mmdm.py
```py path="/cap4d/mmdm/mmdm.py"
import einops
import torch
import numpy as np
from functools import partial
from einops import rearrange, repeat
from torchvision.utils import make_grid
from controlnet.ldm.models.diffusion.ddpm import LatentDiffusion
from controlnet.ldm.util import exists, default
from controlnet.ldm.modules.diffusionmodules.util import make_beta_schedule
from controlnet.ldm.models.diffusion.ddim import DDIMSampler
from cap4d.mmdm.utils import shift_schedule, enforce_zero_terminal_snr
class MMLDM(LatentDiffusion):
    """
    Class for morphable multi-view latent diffusion model
    """
    def __init__(
        self,
        control_key,
        only_mid_control,
        n_frames,
        *args,
        cfg_probability=0.1,
        shift_schedule=False,
        sqrt_shift=False,
        zero_snr_shift=True,
        minus_one_shift=True,
        negative_shift=False,
        **kwargs
    ):
        # attributes are assigned *before* super().__init__() because the base
        # LatentDiffusion constructor calls register_schedule(), which reads
        # the shift / zero-SNR flags below
        self.n_frames = n_frames
        self.shift_schedule = shift_schedule
        self.sqrt_shift = sqrt_shift
        self.minus_one_shift = minus_one_shift
        self.control_key = control_key
        self.only_mid_control = only_mid_control
        self.cfg_probability = cfg_probability
        self.negative_shift = negative_shift
        self.zero_snr_shift = zero_snr_shift
        super().__init__(*args, **kwargs)
    # @torch.no_grad()
    def get_input(self, batch, k, bs=None, force_conditional=False, *args, **kwargs):
        """Encode images to latents and build (un)conditional controls.

        Returns (z, dict(c_concat=[control], c_uncond=[c_uncond], mask=...)).
        Encoding and the unconditional branch run under no_grad; the
        conditional branch stays differentiable so the conditioner can train.
        """
        with torch.no_grad():
            # From DDPM
            x = batch[k]
            if len(x.shape) == 3:
                x = x[..., None]
            x = rearrange(x, 'b t h w c -> b t c h w')
            x = x.to(memory_format=torch.contiguous_format)  # .float() CONTIGUOUS
            # From LatentDiffusion
            if bs is not None:
                x = x[:bs]
            b_, t_ = x.shape[:2]
            # fold time into batch for the (2d) first-stage autoencoder
            x_flat = einops.rearrange(x, 'b t c h w -> (b t) c h w')
            encoder_posterior = self.encode_first_stage(x_flat)
            z_flat = self.get_first_stage_encoding(encoder_posterior).detach()
            z = einops.rearrange(z_flat, '(b t) c h w -> b t c h w', b=b_)
            # Add gt z to control
            batch[self.control_key]['z'] = z.detach()
            c_uncond = self.get_unconditional_conditioning(batch[self.control_key])
        if "mask" in batch:
            loss_mask = batch["mask"]
        else:
            loss_mask = None
        c_cond = self.get_learned_conditioning(batch[self.control_key])
        if not force_conditional:
            # classifier-free guidance training: per-sample random dropout of
            # the conditioning, blended via 0/1 masks over the batch dim
            is_uncond = torch.rand(b_, device=x.device) < self.cfg_probability  # do a mix with probability
            is_cond = torch.logical_not(is_uncond)
            control = {}
            for key in c_cond:
                control[key] = (
                    einops.einsum(c_uncond[key], is_uncond, 'b ..., b -> b ...') +
                    einops.einsum(c_cond[key], is_cond, 'b ..., b -> b ...')
                )
        else:
            control = c_cond
        # New stuff
        assert isinstance(control, dict)
        if bs is not None:
            for key in control:
                control[key] = control[key][:bs]
        return z, dict(c_concat=[control], c_uncond=[c_uncond], mask=loss_mask)
    @torch.no_grad()
    def decode_first_stage(self, z, predict_cids=False):
        """Decode (b, t, c, h, w) latents by folding time into the batch."""
        b_, t_ = z.shape[:2]
        z = einops.rearrange(z, 'b t c h w -> (b t) c h w')
        z = super().decode_first_stage(z, predict_cids)
        return einops.rearrange(z, '(b t) c h w -> b t c h w', b=b_)
    def forward(self, x, c, *args, **kwargs):
        # one independent diffusion timestep per frame: t has shape (b, t)
        t = torch.randint(0, self.num_timesteps, x.shape[:2], device=self.device).long()
        assert c is not None
        assert not self.shorten_cond_schedule
        return self.p_losses(x, c, t, *args, **kwargs)
    def apply_model(self, x_noisy, t, cond, *args, **kwargs):
        """Run the UNet on noisy latents with the control dict (no text)."""
        assert isinstance(cond, dict)
        diffusion_model = self.model.diffusion_model
        cond_txt = None  # remove text conditioning
        assert len(cond['c_concat']) == 1
        control = cond['c_concat'][0]
        eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
        return eps
    def p_losses(self, x_start, cond, t, noise=None):
        """Eps-prediction loss with per-frame timesteps; reference frames
        (marked in control['ref_mask']) are excluded from the loss."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        b_, t_ = t.shape[:2]
        # q_sample operates per-sample, so fold time into batch first
        t_flat = einops.rearrange(t, 'b t -> (b t)')
        noise_flat = einops.rearrange(noise, 'b t c h w -> (b t) c h w')
        x_start_flat = einops.rearrange(x_start, 'b t c h w -> (b t) c h w')
        x_noisy_flat = self.q_sample(x_start=x_start_flat, t=t_flat, noise=noise_flat)
        x_noisy = einops.rearrange(x_noisy_flat, '(b t) c h w -> b t c h w', b=b_)
        model_output = self.apply_model(x_noisy, t, cond)
        loss_dict = {}
        prefix = 'train' if self.training else 'val'
        assert self.parameterization == 'eps'
        target = noise
        loss_simple = self.get_loss(model_output, target, mean=False)
        # Mask loss by references
        loss_simple = loss_simple.mean(dim=[2, 3, 4])
        # ref_mask is 1 where a frame is a (known) reference -> invert it so
        # only generated frames contribute to the loss
        ref_mask = torch.logical_not(cond['c_concat'][0]['ref_mask'])
        loss_simple_mean = (loss_simple * ref_mask).sum(dim=-1) / ref_mask.sum(dim=-1)
        # Losses updated with time dimension
        loss_dict.update({f'{prefix}/loss_simple': loss_simple_mean.mean()})
        logvar_t = self.logvar[t]  # .to(self.device)
        loss = loss_simple / torch.exp(logvar_t) + logvar_t
        loss = (loss * ref_mask).sum(dim=-1) / ref_mask.sum(dim=-1)
        # loss = loss_simple / torch.exp(self.logvar) + self.logvar
        if self.learn_logvar:
            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
            loss_dict.update({'logvar': self.logvar.data.mean()})
        loss = self.l_simple_weight * loss.mean()
        # variational lower bound term, also reference-masked
        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(2, 3, 4))
        loss_vlb = (self.lvlb_weights[t] * loss_vlb * ref_mask).sum(dim=-1) / ref_mask.sum(dim=-1)
        loss_vlb = loss_vlb.mean()
        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
        loss += (self.original_elbo_weight * loss_vlb)
        loss_dict.update({f'{prefix}/loss': loss})
        return loss, loss_dict
    # @torch.no_grad()
    def get_learned_conditioning(self, c):
        # differentiable: the conditioner may be trained (see configure_optimizers)
        return self.cond_stage_model(c, unconditional=False)
    @torch.no_grad()
    def get_unconditional_conditioning(self, c):
        return self.cond_stage_model(c, unconditional=True)
    @torch.no_grad()
    def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                   plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
                   use_ema_scope=True,
                   **kwargs):
        """Build a dict of visualization tensors for logging (reconstruction,
        optional diffusion rows, samples, and CFG samples)."""
        use_ddim = ddim_steps is not None
        log = dict()
        z, c = self.get_input(batch, self.first_stage_key, bs=N, force_conditional=True)
        c_cat = c["c_concat"][0]
        c_uncond = c["c_uncond"][0]
        for key in c_cat:
            c_cat[key] = c_cat[key][:N]
        N = min(z.shape[0], N)
        n_row = min(z.shape[0], n_row)
        log["reconstruction"] = self.decode_first_stage(z)
        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t], device=self.device), '1 -> b', b=n_row)
                    t = t.long()  # .to(self.device)
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))
            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid
        if sample:
            # get denoise row
            samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                     batch_size=N, ddim=use_ddim,
                                                     ddim_steps=ddim_steps, eta=ddim_eta)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid
        if unconditional_guidance_scale > 1.0:
            samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [None]},
                                             batch_size=N, ddim=use_ddim,
                                             ddim_steps=ddim_steps, eta=ddim_eta,
                                             unconditional_guidance_scale=unconditional_guidance_scale,
                                             unconditional_conditioning=c_uncond,
                                             )
            x_samples_cfg = self.decode_first_stage(samples_cfg)
            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
        return log
    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
        """DDIM-sample a (n_frames, 4, H, W) latent volume."""
        ddim_sampler = DDIMSampler(self)
        # b, c, h, w = cond["c_concat"][0].shape
        # shape = (self.channels, h // 8, w // 8)
        shape = (self.n_frames, 4, self.image_size, self.image_size)  # 64, 64)
        samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
        return samples, intermediates
    def configure_optimizers(self):
        """AdamW over the diffusion UNet (plus trainable conditioner params)."""
        lr = self.learning_rate
        params = list(self.model.diffusion_model.parameters())
        if self.cond_stage_trainable:
            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
            params = params + list([param for param in self.cond_stage_model.parameters() if param.requires_grad])
            for n, param in self.cond_stage_model.named_parameters():
                if param.requires_grad:
                    print(n)
        opt = torch.optim.AdamW(params, lr=lr)
        return opt
    def low_vram_shift(self, is_diffusing):
        # move only the sub-model needed for the current phase onto the GPU
        if is_diffusing:
            self.model = self.model.cuda()
            # self.control_model = self.control_model.cuda()
            self.first_stage_model = self.first_stage_model.cpu()
            self.cond_stage_model = self.cond_stage_model.cpu()
        else:
            self.model = self.model.cpu()
            # self.control_model = self.control_model.cpu()
            self.first_stage_model = self.first_stage_model.cuda()
            self.cond_stage_model = self.cond_stage_model.cuda()
    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Build the beta schedule with optional zero-terminal-SNR enforcement
        and resolution/frame-count dependent log-SNR shifting, then register
        all derived quantities as non-persistent buffers."""
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                       cosine_s=cosine_s)
        if self.zero_snr_shift:
            print(f"Enforcing zero terminal SNR in noise schedule.")
            betas = enforce_zero_terminal_snr(betas)
            # clamp so no step destroys all signal
            betas[betas > 0.99] = 0.99
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        if self.shift_schedule:
            # shift the schedule's log-SNR to account for generating
            # (image_size / 64)^2 * n_gen times more latent pixels than the
            # base model was tuned for
            n_gen = self.n_frames
            if self.minus_one_shift:
                n_gen = n_gen - 1  # we are generating only n_total - 1 frames technically
            shift_ratio = (64 ** 2) / (self.image_size ** 2 * n_gen)
            if self.negative_shift:
                shift_ratio = 1. / shift_ratio
            if self.sqrt_shift:
                shift_ratio = np.sqrt(shift_ratio)
            new_alpha_cumprod, new_betas = shift_schedule(alphas_cumprod, shift_ratio=shift_ratio)
            print(f"Shifted log psnr of noise schedule by {shift_ratio}.")
            alphas = 1. - new_betas
            betas = new_betas
            alphas_cumprod = new_alpha_cumprod
        print("Using non persistent schedule buffers.")
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = partial(torch.tensor, dtype=torch.float32)
        self.register_buffer('betas', to_torch(betas), persistent=False)
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod), persistent=False)
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev), persistent=False)
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)), persistent=False)
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)), persistent=False)
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)), persistent=False)
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)), persistent=False)
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)), persistent=False)
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance), persistent=False)
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))), persistent=False)
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)), persistent=False)
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)), persistent=False)
        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
        elif self.parameterization == "v":
            lvlb_weights = torch.ones_like(self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
        else:
            raise NotImplementedError("mu not supported")
        # avoid the degenerate weight at t=0
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        assert not torch.isnan(self.lvlb_weights).all()
        # From LatentDiffusion
        self.shorten_cond_schedule = self.num_timesteps_cond > 1
        if self.shorten_cond_schedule:
            self.make_cond_schedule()
```
## /cap4d/mmdm/modules/module.py
```py path="/cap4d/mmdm/modules/module.py"
from pathlib import Path
import cv2
import numpy as np
import torch
import pytorch_lightning as pl
from controlnet.ldm.util import instantiate_from_config
def tensor_to_img(img):
    """Convert a CHW float tensor in [-1, 1] to an HWC uint8 BGR image (cv2)."""
    hwc = img.permute(1, 2, 0)
    scaled = ((hwc + 1.) / 2. * 255).clamp(0, 255)
    arr = scaled.detach().cpu().numpy()
    # RGB -> BGR channel order for cv2.imwrite
    return arr[..., [2, 1, 0]].astype(np.uint8)
class CAP4DModule(pl.LightningModule):
    """Lightning wrapper around a CAP4D model.

    Trains with an L1 reconstruction loss between the model's predicted image
    and the target, weighted by the model's predicted mask; dumps qualitative
    prediction/source/target images on the first validation batch of each epoch.
    """

    def __init__(
        self,
        model_config,
        loss_config,
        # callback_config,
        ckpt_dir,
        *args,   # dropped the bogus `pl.Any` annotations: pytorch_lightning
        **kwargs  # does not export `Any`; use plain untyped varargs instead
    ) -> None:
        super().__init__(*args, **kwargs)
        self.model = instantiate_from_config(model_config)
        self.loss = instantiate_from_config(loss_config)
        self.ckpt_dir = Path(ckpt_dir)
        # counter for naming validation dump directories
        self.epoch_id = 0

    def _masked_l1_loss(self, batch):
        """Run the model on batch["hint"]; return (prediction dict, loss).

        Loss is the mask-weighted mean absolute error against batch["jpg"],
        averaged over channels first.
        """
        y = self.model(batch["hint"])
        # `keepdim` is torch's documented kwarg (original used numpy-style `keepdims`)
        diff = (batch["jpg"] - y["image"]).abs().mean(dim=1, keepdim=True)
        loss = (diff * y["mask"]).sum() / y["mask"].sum()  # self.loss(batch["jpg"], y)
        return y, loss

    def training_step(self, batch, batch_idx):
        _, loss = self._masked_l1_loss(batch)
        return loss

    def validation_step(self, batch, batch_idx):
        y, loss = self._masked_l1_loss(batch)
        if batch_idx == 0:
            # dump qualitative results for the first validation batch
            epoch_dir = self.ckpt_dir / f"epoch_{self.epoch_id:04d}"
            # parents=True so a missing ckpt_dir does not crash validation
            epoch_dir.mkdir(parents=True, exist_ok=True)
            print(self.ckpt_dir)
            for i in range(y["image"].shape[0]):
                cv2.imwrite(str(epoch_dir / f"{i:02d}_pred.png"), tensor_to_img(y["image"][i]))
                cv2.imwrite(str(epoch_dir / f"{i:02d}_source.png"), tensor_to_img(batch["hint"]["source_img"][i]))
                cv2.imwrite(str(epoch_dir / f"{i:02d}_target.png"), tensor_to_img(batch["jpg"][i]))
            # assumed once-per-epoch increment (flat source made the original
            # indentation ambiguous) — dirs are numbered consecutively
            self.epoch_id += 1
        self.log("loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer
```
## /cap4d/mmdm/net/attention.py
```py path="/cap4d/mmdm/net/attention.py"
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any
from controlnet.ldm.modules.diffusionmodules.util import checkpoint, GroupNorm32, LayerNorm32
# Optional xformers backend: fall back to the einsum attention path when the
# package is missing or fails to load.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    # narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are not
    # swallowed; any xformers import failure still disables the backend
    XFORMERS_IS_AVAILBLE = False

# CrossAttn precision handling
import os

# precision used in the legacy (einsum) attention path; "fp32" avoids overflow
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")

# one-time repair of k/v projection layers from legacy checkpoints
FIX_LEGACY_FAIL = os.environ.get("FIX_LEGACY_FAIL", False)
if FIX_LEGACY_FAIL:
    print("================================")
    print("Fixing legacy failed k and v layers")
    print("================================")

# run attention in fp16 regardless of surrounding precision
_USE_FP16_ATTENTION = os.environ.get("USE_FP16_ATTENTION", False)
if _USE_FP16_ATTENTION:
    print("================================")
    print("Using fp16 attention")
    print("================================")

# Whether to use flash attention
_USE_FLASH = os.environ.get("USE_FLASH_ATTENTION", False)
if _USE_FLASH:
    from flash_attn import flash_attn_func
    print("================================")
    print("Using flash attention")
    print("================================")
def exists(val):
    """Return True when *val* carries a value (i.e. is not None)."""
    return val is not None
def uniq(arr):
    """Order-preserving de-duplication; returns a dict keys view."""
    return dict.fromkeys(arr).keys()
def default(val, d):
    """Return *val* if not None; otherwise *d* (called first if it is a function)."""
    if val is not None:
        return val
    return d() if isfunction(d) else d
def max_neg_value(t):
    """Most negative finite value representable in *t*'s dtype."""
    finfo = torch.finfo(t.dtype)
    return -finfo.max
def init_(tensor):
    """Uniform in-place init in [-1/sqrt(d), 1/sqrt(d)] where d is the last dim.

    Returns the same tensor for chaining.
    """
    bound = 1 / math.sqrt(tensor.shape[-1])
    tensor.uniform_(-bound, bound)
    return tensor
# feedforward
class GEGLU(nn.Module):
    """Gated-GELU projection: project to 2*dim_out, gate one half by GELU of the other."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)
class FeedForward(nn.Module):
    """Transformer MLP: Linear+GELU (or GEGLU) expansion, dropout, projection back."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        if dim_out is None:
            dim_out = dim
        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(
                nn.Linear(dim, inner_dim),
                nn.GELU()
            )
        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)
def zero_module(module):
    """
    Zero out the parameters of a module (in place) and return it.
    """
    for param in module.parameters():
        # zero the underlying storage without tracking gradients
        param.detach().zero_()
    return module
def Normalize(in_channels):
    """Build the standard 32-group GroupNorm used before attention layers."""
    # GroupNorm32 is the project's GroupNorm wrapper
    # (previously: torch.nn.GroupNorm with the same hyperparameters)
    norm = GroupNorm32(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
    )
    return norm
def legacy_attention(q, k, v, scale, mask=None):
    """Plain (non-xformers) scaled dot-product attention.

    q, k, v: (batch * heads, seq, dim); mask: boolean (batch, ...) with True
    meaning "attend". Returns (batch * heads, seq_q, dim).
    """
    # force cast to fp32 to avoid overflowing
    if _ATTN_PRECISION == "fp32":
        with torch.autocast(enabled=False, device_type='cuda'):
            q, k = q.float(), k.float()
            sim = einsum('b i d, b j d -> b i j', q, k) * scale
    else:
        sim = einsum('b i d, b j d -> b i j', q, k) * scale
    del q, k
    if exists(mask):
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        # sim's batch is (batch * heads) while the mask batch is (batch);
        # recover the head count from the ratio. (BUGFIX: the original
        # referenced an undefined `h` here, raising NameError whenever a
        # mask was supplied.)
        h = sim.shape[0] // mask.shape[0]
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)
    # attention, what we cannot get enough of
    sim = sim.softmax(dim=-1)
    return einsum('b i j, b j d -> b i d', sim, v)
class AttentionModule(nn.Module):
    """Unified attention layer for the MMDM transformer blocks.

    Modes:
      - "spatial":  self-attention over tokens within each frame
      - "context":  cross-attention to an external context sequence
      - "temporal": attention across the time axis per spatial location
      - "3d":       joint attention over all tokens of all frames

    Inputs for the non-"context" modes are laid out as (batch * time, n, c).
    """
    def __init__(
        self,
        query_dim,
        heads=8,
        dim_head=64,
        dropout=0.,
        mode="spatial",  # ["spatial", "context", "temporal", "3d"]
        context_dim=None,
        num_timesteps=0,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.mode = mode
        # key/value input width depends on whether we attend to x or context
        if mode == "context":
            assert context_dim is not None
            kv_dim = context_dim
        elif mode == "spatial":
            kv_dim = query_dim
        elif mode == "temporal":
            kv_dim = query_dim
            assert num_timesteps > 0
        elif mode == "3d":
            kv_dim = query_dim
            assert num_timesteps > 0
        else:
            assert False, f"ERROR: unrecognized mode {mode}"
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.num_timesteps = num_timesteps
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(kv_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(kv_dim, inner_dim, bias=False)
        # set once the FIX_LEGACY_FAIL repair has run (see forward)
        self.k_v_fixed = False
        # NOTE(review): the name suggests temporal layers get a
        # zero-initialized output projection, yet the condition below gives
        # "temporal" a plain Linear and everything else the zeroed one —
        # confirm this matches the released checkpoints before changing.
        is_zero_module = mode == "temporal"
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim) if is_zero_module else zero_module(nn.Linear(inner_dim, query_dim)),
            nn.Dropout(dropout)
        )
    def forward(self, x, context=None, mask=None):
        h = self.heads
        t = self.num_timesteps
        b = x.shape[0]  # is batch_size * time_steps
        q = self.to_q(x)
        if self.mode == "context":
            assert context is not None
            # context attention
            k = self.to_k(context)
            v = self.to_v(context)
        else:
            if FIX_LEGACY_FAIL and not self.k_v_fixed:
                # one-time repair: copy k-projection weights into v
                with torch.no_grad():
                    self.to_v.weight.data = self.to_k.weight.data
                self.k_v_fixed = True
            # self attention
            k = self.to_k(x)
            v = self.to_v(x)
        if _USE_FLASH or XFORMERS_IS_AVAILBLE:  # XFORMER ATTENTION
            # reshape to the (batch, seq, heads, dim) layout expected by
            # xformers / flash-attn, merging/splitting the time axis per mode
            if self.mode == "3d":
                q, k, v = map(lambda yt: rearrange(yt, '(b t) n (h d) -> b (n t) h d', h=h, t=t), (q, k, v))
            elif self.mode == "temporal":
                q, k, v = map(lambda yt: rearrange(yt, '(b t) n (h d) -> (b n) t h d', h=h, t=t), (q, k, v))
            elif self.mode == "context" or self.mode == "spatial":
                q, k, v = map(lambda yt: rearrange(yt, 'b n (h d) -> b n h d', h=h), (q, k, v))
            if _USE_FLASH:
                # flash-attn only supports fp16/bf16; round-trip through half
                dtype_before = q.dtype
                out = flash_attn_func(q.half(), k.half(), v.half()).type(dtype_before)
            else:
                if _USE_FP16_ATTENTION:
                    dtype_before = q.dtype
                    out = xformers.ops.memory_efficient_attention(
                        q.half(), k.half(), v.half(), attn_bias=None, op=None,
                    ).type(dtype_before)
                else:
                    out = xformers.ops.memory_efficient_attention(
                        q, k, v, attn_bias=None, op=None,
                    )
            # undo the mode-specific folding back to (batch * time, n, h*d)
            if self.mode == "3d":
                out = rearrange(out, 'b (n t) h d -> (b t) n (h d)', b=b//t, h=h, t=t)
            elif self.mode == "temporal":
                out = rearrange(out, '(b n) t h d -> (b t) n (h d)', b=b//t, h=h, t=t)
            elif self.mode == "context" or self.mode == "spatial":
                out = rearrange(out, 'b n h d -> b n (h d)', h=h)
        else:  # NORMAL ATTENTION
            # fold heads into the batch dimension for the einsum-based path
            if self.mode == "3d":
                q, k, v = map(lambda yt: rearrange(yt, '(b t) n (h d) -> (b h) (n t) d', h=h, t=t), (q, k, v))
            elif self.mode == "temporal":
                q, k, v = map(lambda yt: rearrange(yt, '(b t) n (h d) -> (b h n) t d', h=h, t=t), (q, k, v))
            elif self.mode == "context" or self.mode == "spatial":
                q, k, v = map(lambda yt: rearrange(yt, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
            if _USE_FP16_ATTENTION:
                dtype_before = q.dtype
                out = legacy_attention(q.half(), k.half(), v.half(), self.scale, mask=mask).type(dtype_before)
            else:
                out = legacy_attention(q, k, v, self.scale, mask=mask)
            if self.mode == "3d":
                out = rearrange(out, '(b h) (n t) d -> (b t) n (h d)', b=b//t, h=h, t=t)
            elif self.mode == "temporal":
                out = rearrange(out, '(b h n) t d -> (b t) n (h d)', b=b//t, h=h, t=t)
            elif self.mode == "context" or self.mode == "spatial":
                out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)
class BasicTransformerBlock(nn.Module):
    """Transformer block: (spatial or joint 3d) self-attention, optional
    context cross-attention, optional temporal attention, then feed-forward,
    each with a pre-norm and residual connection."""
    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.,
        use_context=True,
        context_dim=None,
        gated_ff=True,
        temporal_connection_type="none",  # [3d, temporal, none]
        num_timesteps=0,
    ):
        super().__init__()
        self.temporal_connection_type = temporal_connection_type
        if temporal_connection_type != "none":
            assert num_timesteps > 0
        self.attn1 = AttentionModule(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            mode="spatial" if temporal_connection_type != "3d" else "3d",
            num_timesteps=num_timesteps,
        )  # is a self-attention if not self.disable_self_attn
        self.norm1 = LayerNorm32(dim)  # nn.LayerNorm(dim)
        self.use_context = use_context
        if use_context:
            self.attn2 = AttentionModule(
                query_dim=dim,
                context_dim=context_dim,
                heads=n_heads,
                dim_head=d_head,
                dropout=dropout,
                mode="context",
            )  # is self-attn if context is none
            self.norm2 = LayerNorm32(dim)  # nn.LayerNorm(dim)
        if temporal_connection_type == "temporal":
            # NOTE(review): no mode= is passed here, so this AttentionModule
            # falls back to its default mode="spatial" rather than "temporal".
            # Pretrained checkpoints were trained with this code path — confirm
            # before "fixing".
            self.attn_t = AttentionModule(
                query_dim=dim,
                context_dim=None,
                heads=n_heads,
                dim_head=d_head,
                dropout=dropout,
                num_timesteps=num_timesteps
            )
            self.norm_t = LayerNorm32(dim)  # nn.LayerNorm(dim)
        self.norm3 = LayerNorm32(dim)  # nn.LayerNorm(dim)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
    def forward(self, x, context=None):
        return self._forward(x, context)
    def _forward(self, x, context=None):
        # self / 3d attention with residual
        x = self.attn1(self.norm1(x), context=None) + x
        # context attention
        if self.use_context:
            assert context is not None  # DISABLE CROSS ATTENTION IF NO CONTEXT
            x = self.attn2(self.norm2(x), context=context) + x
        # temporal attention
        if self.temporal_connection_type == "temporal":
            attn = self.attn_t(self.norm_t(x), context=None)
            x = attn + x
        # ff and normalization
        x = self.ff(self.norm3(x)) + x
        return x
class SpatioTemporalTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs
    """
    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        dropout=0.,
        use_context=True,
        context_dim=None,
        temporal_connection_type="none",  # [3d, temporal, none]
        num_timesteps=0,
    ):
        super().__init__()
        if use_context:
            assert context_dim is not None
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        self.proj_in = nn.Linear(in_channels, inner_dim)
        # a single transformer block, kept in a ModuleList (checkpoint naming)
        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                use_context=use_context,
                context_dim=context_dim,
                temporal_connection_type=temporal_connection_type,
                num_timesteps=num_timesteps,
            )]
        )
        # NOTE(review): arguments are (in_channels, inner_dim); the residual
        # `x + x_in` in forward only works when inner_dim == in_channels —
        # confirm for configs where they would differ.
        self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
        self.has_context = True
    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        assert not isinstance(context, list)
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        # (b, c, h, w) -> (b, h*w, c) token sequence
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context)
        x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        # residual connection (proj_out is zero-initialized -> identity at start)
        return x + x_in
```
## /cap4d/mmdm/net/mmdm_unet.py
```py path="/cap4d/mmdm/net/mmdm_unet.py"
import torch
import torch.nn as nn
import einops
from controlnet.ldm.modules.diffusionmodules.openaimodel import UNetModel
from controlnet.ldm.modules.diffusionmodules.util import (
zero_module,
timestep_embedding,
)
from cap4d.mmdm.net.attention import SpatioTemporalTransformer
class MMDMUnetModel(UNetModel):
    """UNet for the morphable multi-view diffusion model (MMDM).

    Extends the ControlNet/LDM UNetModel with spatio-temporal attention blocks
    and a zero-initialized linear embedding of the per-pixel conditioning.
    """
    def __init__(
        self,
        *args,
        time_steps,
        condition_channels=50,
        model_channels=320,
        image_size=32,
        context_dim=1024,
        temporal_mode="3d",  # ["3d", "temporal"]
        **kwargs,
    ):
        assert temporal_mode in ["3d", "temporal"]
        # set before super().__init__(): the base constructor builds its
        # attention layers via create_attention_block, which reads these
        self.temporal_mode = temporal_mode
        self.time_steps = time_steps
        self.use_context = False  # MMDM disables text cross-attention
        super().__init__(*args, model_channels=model_channels, image_size=image_size, context_dim=context_dim, **kwargs)
        # zero-initialized so the conditioning starts out as a no-op
        self.cond_linear = zero_module(nn.Linear(condition_channels, model_channels))
    def create_attention_block(
        self,
        ch,
        mult,
        use_checkpoint,
        num_heads,
        dim_head,
        transformer_depth,
        context_dim,
        disable_self_attn,
        use_linear,
        use_new_attention_order,
        use_spatial_transformer,
    ):
        """Factory hook called by the base UNet for each attention block;
        returns a SpatioTemporalTransformer instead of the default one."""
        if self.temporal_mode == "temporal":
            temporal_connection_type = "temporal"
        elif self.temporal_mode == "3d":
            # only use (expensive) joint 3d attention at deeper levels
            if mult >= 2:
                temporal_connection_type = "3d"
            else:
                temporal_connection_type = "none"
        return SpatioTemporalTransformer(
            ch,
            num_heads,
            dim_head,
            use_context=self.use_context,
            context_dim=context_dim,
            temporal_connection_type=temporal_connection_type,
            num_timesteps=self.time_steps,
        )
    def forward(self, x, timesteps=None, context=None, control=None, **kwargs):
        """
        x (b t h w c): input latent
        control (b t h w c_cond): input conditioning signal
        """
        z_input = control["z_input"]
        ref_mask = control["ref_mask"]
        # ground truth noise output
        x_input = x - z_input
        ref_mask_inv = torch.logical_not(ref_mask)
        # replace with input latents where available
        x = z_input * ref_mask + x * ref_mask_inv
        # Disabling context cross attention for MMDM
        assert context == None
        # Flatten time dimension for processing
        b_, t_ = x.shape[:2]
        x = einops.rearrange(x, 'b t c h w -> (b t) c h w')
        timesteps = einops.rearrange(timesteps, 'b t -> (b t)')
        pos_enc = einops.rearrange(control["pos_enc"], 'b t h w c -> (b t) h w c')
        pos_enc = pos_enc.type(self.dtype)
        pos_embedding = self.cond_linear(pos_enc)
        # channels-last -> channels-first
        pos_embedding = pos_embedding.permute(0, 3, 1, 2)
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        t_emb = t_emb.type(self.dtype)
        emb = self.time_embed(t_emb)
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            if pos_embedding is not None:
                # inject the conditioning once, after the first input block
                h += pos_embedding
                pos_embedding = None
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for i, module in enumerate(self.output_blocks):
            # concatenate skip connection from the matching input block
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = self.out(h)
        h = h.type(x.dtype)
        # Unflatten time dimension
        h = einops.rearrange(h, '(b t) c h w -> b t c h w', b=b_)
        # replace with input latents where available
        h = x_input * ref_mask + h * ref_mask_inv
        return h
```
## /cap4d/mmdm/sampler.py
```py path="/cap4d/mmdm/sampler.py"
from typing import Tuple, Dict
import torch
import numpy as np
from tqdm import tqdm
from controlnet.ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
)
class StochasticIOSampler(object):
    """DDIM-style sampler with stochastic I/O conditioning.

    At every denoising step, the set of latents to be generated is randomly
    partitioned into groups of size V - R, each group is paired with a random
    subset of R reference images, and the MMDM denoises each (references +
    group) window jointly. Work can be spread over several model replicas
    living on different devices.
    """

    def __init__(self, model, **kwargs):
        """model: a single diffusion model, or a dict mapping device keys to
        replicas of the model for multi-GPU sampling."""
        super().__init__()
        if isinstance(model, dict):
            # model distributed on different devices
            self.device_model_map = model
        else:
            if torch.cuda.is_available():
                self.device_model_map = {"cuda": model}
            else:
                self.device_model_map = {"cpu": model}
        # Use the first replica as the "main" model for schedule constants.
        for key in self.device_model_map:
            self.main_model = self.device_model_map[key]
            self.ddpm_num_timesteps = self.main_model.num_timesteps
            break

    def register_buffer(self, name, attr):
        # Plain attribute assignment; named for parity with nn.Module buffers.
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute DDPM/DDIM schedule constants for `ddim_num_steps` steps.

        Unlike the ControlNet DDIMSampler, buffers are NOT moved to a fixed
        device here; latents live on the caller's `mem_device` during sampling.
        """
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
        alphas_cumprod = self.main_model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32)  # .to(self.main_model.device)
        self.register_buffer('betas', to_torch(self.main_model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.main_model.alphas_cumprod_prev))
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.detach().cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.detach().cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.detach().cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.detach().cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.detach().cpu() - 1)))
        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.detach().cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta,verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(
        self,
        S: int,
        ref_cond: Dict[str, torch.Tensor],
        ref_uncond: Dict[str, torch.Tensor],
        gen_cond: Dict[str, torch.Tensor],
        gen_uncond: Dict[str, torch.Tensor],
        latent_shape: Tuple[int, int, int],
        V: int = 8,
        R_max: int = 4,
        cfg_scale: float = 1.,
        eta: float = 0.,
        verbose: bool = False,
    ):
        """
        Generate images from reference images using Stochastic I/O conditioning.

        Parameters:
            S (int): Number of diffusion steps.
            ref_cond (Dict[str, torch.Tensor]): Conditioning images used for reference (ref latents, pose maps, reference masks etc.).
            ref_uncond (Dict[str, torch.Tensor]): Unconditional conditioning images used for reference (zeroed conditioning).
            gen_cond (Dict[str, torch.Tensor]): Conditioning images for the views to be generated (pose maps, reference masks etc.).
            gen_uncond (Dict[str, torch.Tensor]): Unconditional conditioning images for the views to be generated.
            latent_shape (Tuple[int, int, int]): Shape of each latent to be generated (C, H, W).
            V (int): Number of views supported by the MMDM.
            R_max (int, optional): Maximum number of reference images to use. Defaults to 4.
            cfg_scale (float, optional): Classifier-free guidance scale. Higher values increase conditioning strength. Defaults to 1.0.
            eta (float, optional): Noise scaling factor for DDIM sampling. 0 means deterministic sampling. Defaults to 0.
            verbose (bool, optional): Whether to print detailed logs during sampling. Defaults to False.

        Returns:
            torch.Tensor: A tensor representing the generated sample(s) in latent space.
        """
        # Device where the (potentially large) latent pool is stored.
        mem_device = next(iter(gen_cond.items()))[1].device
        n_devices = len(self.device_model_map)
        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        n_gen = next(iter(gen_cond.items()))[1].shape[0]
        n_all_ref = next(iter(ref_cond.items()))[1].shape[0]
        R = min(n_all_ref, R_max)
        assert n_gen % (V - R) == 0, f"number of generated images ({n_gen}) has to be divisible by G ({V-R})"  # has to be divisible for now
        n_its = n_gen // (V - R)
        # store all latents on CPU (to prevent using too much GPU memory)
        all_x_T = torch.randn((n_gen, *latent_shape), device=mem_device)
        all_e_t = torch.zeros_like(all_x_T)
        timesteps = self.ddim_timesteps
        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running stochastic I/O sampling with {total_steps} timesteps, {R} reference images and {n_gen} generated images")
        iterator = tqdm(time_range, desc='Stochastic I/O sampler', total=total_steps)
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            # All V views in a window share the same timestep (batch size 1).
            ts = torch.full((1, V), step, device=mem_device, dtype=torch.long)
            # reset e_t accumulator
            all_e_t = all_e_t * 0.
            # gather ref and gen batches
            if R == 1:
                # NOTE(review): with R == 1 this always selects reference index 0.
                # Fine when n_all_ref == 1; verify intent for R_max == 1 with
                # multiple references available.
                ref_batches = np.zeros((n_its, R), dtype=np.int64)
            else:
                ref_batches = np.stack([
                    np.random.permutation(np.arange(n_all_ref))[:R] for _ in range(n_its)
                ], axis=0)
            # Fresh random partition of the generated views at every step.
            gen_batches = np.reshape(np.random.permutation(np.arange(n_gen)), (n_its, -1))

            def dict_sample(in_dict, indices, device=None):
                # Index every tensor in the dict, optionally moving it to `device`.
                out_dict = {}
                for key in in_dict:
                    if device is None:
                        out_dict[key] = in_dict[key][indices]
                    else:
                        out_dict[key] = in_dict[key][indices].to(device)
                return out_dict

            # Prepare input to GPUs
            batch_indices = []  # [[b] for b in range(n_its)]
            for l in range(int(np.ceil(n_its / n_devices))):
                device_batch = []
                for device_id in range(min(n_devices, n_its)):
                    if l * n_devices + device_id < n_its:
                        device_batch.append([l * n_devices + device_id])
                batch_indices.append(device_batch)
            # Go through all gen_batches and apply noise update
            for dev_batches in batch_indices:
                x_in_list = []
                t_in_list = []
                c_in_list = []
                e_t_list = []
                # Stage inputs for each device in this round.
                for dev_id, dev_batch in enumerate(dev_batches):
                    dev_key = list(self.device_model_map)[dev_id]
                    dev_device = self.device_model_map[dev_key].device
                    curr_ref_cond = dict_sample(ref_cond, ref_batches[dev_batch], device=dev_device)
                    curr_ref_uncond = dict_sample(ref_uncond, ref_batches[dev_batch], device=dev_device)
                    curr_gen_cond = dict_sample(gen_cond, gen_batches[dev_batch], device=dev_device)
                    curr_gen_uncond = dict_sample(gen_uncond, gen_batches[dev_batch], device=dev_device)
                    curr_x_T = all_x_T[gen_batches[dev_batch]].to(dev_device)  # making batch_size = 1 this way
                    curr_cond = {}
                    curr_uncond = {}
                    c_in = {}
                    # Views are ordered [references | generated] along dim 1.
                    for key in curr_ref_cond:
                        curr_cond[key] = torch.cat([curr_ref_cond[key], curr_gen_cond[key]], dim=1)
                        curr_uncond[key] = torch.cat([curr_ref_uncond[key], curr_gen_uncond[key]], dim=1)
                        c_in[key] = torch.cat([curr_uncond[key], curr_cond[key]], dim=0)  # stack them to run uncond and cond in one pass
                    t_in = torch.cat([ts] * 2, dim=0).to(dev_device)
                    c_in = dict(c_concat=[c_in])
                    x_in = torch.cat([curr_cond["z_input"][:, :R], curr_x_T], dim=1)
                    x_in = torch.cat([x_in] * 2, dim=0).to(dev_device)
                    x_in_list.append(x_in)
                    t_in_list.append(t_in)
                    c_in_list.append(c_in)
                # Run model in parallel on all available devices
                for dev_id, dev_batch in enumerate(dev_batches):
                    dev_key = list(self.device_model_map)[dev_id]
                    dev_device = self.device_model_map[dev_key].device
                    # First half of the batch is unconditional, second is conditional.
                    model_uncond, model_t = self.device_model_map[dev_key].apply_model(
                        x_in_list[dev_id],
                        t_in_list[dev_id],
                        c_in_list[dev_id],
                    ).chunk(2)
                    # Classifier-free guidance.
                    model_output = model_uncond + cfg_scale * (model_t - model_uncond)
                    e_t = model_output[:, R:]  # eps prediction mode, extract the generation samples starting at n_ref
                    e_t_list.append(e_t)
                # Scatter each group's eps prediction back into the full pool.
                for dev_id, dev_batch in enumerate(dev_batches):
                    all_e_t[gen_batches[dev_batch]] += e_t_list[dev_id].to(mem_device)
            # DDIM update, rewritten as x_{t-1} = a * x_t + b * e_t so it can be
            # applied to the whole latent pool at once:
            #   a = sqrt(alpha_prev / alpha_t)
            #   b = sqrt(1 - alpha_prev - sigma_t^2) - sqrt(alpha_prev) * sqrt(1 - alpha_t) / sqrt(alpha_t)
            alpha_t = self.ddim_alphas.float()[index]
            sqrt_one_minus_alpha_t = self.ddim_sqrt_one_minus_alphas[index]
            sigma_t = self.ddim_sigmas[index]
            alpha_prev_t = torch.tensor(self.ddim_alphas_prev).float()[index]
            # Do the factor arithmetic in float64 for precision, then cast back.
            alpha_prev_t = alpha_prev_t.double()
            sqrt_one_minus_alpha_t = sqrt_one_minus_alpha_t.double()
            alpha_t = alpha_t.double()
            alpha_prev_t = alpha_prev_t.double()
            e_t_factor = -alpha_prev_t.sqrt() * sqrt_one_minus_alpha_t / alpha_t.sqrt() + (1. - alpha_prev_t - sigma_t**2).sqrt()
            x_t_factor = alpha_prev_t.sqrt() / alpha_t.sqrt()
            e_t_factor = e_t_factor.float()
            x_t_factor = x_t_factor.float()
            # NOTE(review): no sigma_t * noise term is added here, so eta > 0 only
            # shrinks the e_t direction — confirm eta is meant to stay 0.
            all_x_T = all_x_T * x_t_factor + all_e_t * e_t_factor
        return all_x_T
```
## /cap4d/mmdm/utils.py
```py path="/cap4d/mmdm/utils.py"
import numpy as np
def shift_schedule(alpha_cumprods, shift_ratio):
    """Shift a diffusion noise schedule in log-SNR space.

    shift_ratio = original_resolution (512) ** 2 / (new_resolution ** 2 * n_images)

    Returns:
        (alpha_shifted, betas_shifted): the shifted cumulative alphas and the
        per-step betas recovered from their consecutive ratios. The first beta
        is 0 by construction (the leading cumulative ratio is fixed to 1).
    """
    # Signal-to-noise ratio of the original schedule at each timestep.
    signal_to_noise = alpha_cumprods / (1. - alpha_cumprods)
    # Scaling the SNR by shift_ratio is an additive shift in log space.
    shifted_log_snr = np.log(signal_to_noise) + np.log(shift_ratio)
    # Sigmoid of the log-SNR maps back to a cumulative alpha.
    exp_log_snr = np.exp(shifted_log_snr)
    alpha_shifted = exp_log_snr / (1 + exp_log_snr)
    # Per-step alphas are ratios of consecutive cumulative products.
    step_alphas = np.concatenate([[1], alpha_shifted[1:] / alpha_shifted[:-1]])
    betas_shifted = 1 - step_alphas
    return alpha_shifted, betas_shifted
# https://arxiv.org/pdf/2305.08891
def enforce_zero_terminal_snr(betas):
    """Rescale a beta schedule so the terminal SNR is exactly zero.

    Implements Algorithm 1 of "Common Diffusion Noise Schedules and Sample
    Steps are Flawed" (arXiv:2305.08891): shift and rescale sqrt(alpha_bar)
    so the last cumulative alpha becomes 0 while the first keeps its value.

    Returns the adjusted betas (same shape as the input).
    """
    # Work on sqrt(alpha_bar), the quantity the algorithm rescales linearly.
    sqrt_cumprod = np.sqrt((1 - betas).cumprod(0))
    first = sqrt_cumprod[0].copy()
    last = sqrt_cumprod[-1].copy()
    # Shift the curve down so it ends at zero, then rescale so the first
    # entry matches its original value again.
    rescaled = (sqrt_cumprod - last) * (first / (first - last))
    # Back to cumulative products, then to per-step alphas via ratios.
    cumprod = rescaled ** 2
    step_alphas = np.concatenate([cumprod[0:1], cumprod[1:] / cumprod[:-1]])
    return 1 - step_alphas
```
## /configs/avatar/debug.yaml
```yaml path="/configs/avatar/debug.yaml"
opt_params:
iterations: 1_000
sh_warmup_iterations: 100
# clip_probability: 0.2
lambda_scale: 1.
threshold_scale: 1.
lambda_xyz: 1e-3
threshold_xyz: 2.
metric_xyz: False
metric_scale: False
feature_lr: 0.0025
opacity_lr: 0.025
scaling_lr: 0.005 # (scaled up according to mean triangle scale) # 0.005 (original)
rotation_lr: 0.001
percent_dense: 0.01
lambda_dssim: 0.4
densification_interval: 200
densify_grad_threshold: 1e-6
opacity_reset_interval: 200
densify_until_iter: 700
densify_from_iter: 80
position_lr_init: 5e-3
position_lr_final: 5e-5
position_lr_delay_mult: 0.01
position_lr_max_steps: 1000
w_lpips: 0.1
lambda_lpips_end: 0.9
lpips_linear_start: 100
lpips_linear_end: 600
deform_net_w_decay: 2e-3
deform_net_lr_init: 1e-5
deform_net_lr_final: 1e-7
deform_net_lr_delay_mult: 0.01
deform_net_lr_max_steps: 1000
lambda_laplacian: 1.0
lambda_relative_deform: 0.4
lambda_relative_rot: 0.005
neck_lr_init: 1e-5
neck_lr_final: 1e-7
neck_lr_delay_mult: 0.01
neck_lr_max_steps: 1000
lambda_neck: 1.0
model_params:
n_unet_layers: 6
n_points_per_triangle: 2
use_lower_jaw: True
static_neck: False
gaussian_init_type: "scaled"
use_expr_mask: True
uv_resolution: 128
n_gaussians_init: 100_000
sh_degree: 3
```
## /configs/avatar/default.yaml
```yaml path="/configs/avatar/default.yaml"
opt_params:
iterations: 100_000
sh_warmup_iterations: 1_000
# clip_probability: 0.2
lambda_scale: 1.
threshold_scale: 1.
lambda_xyz: 1e-3
threshold_xyz: 2.
metric_xyz: False
metric_scale: False
feature_lr: 0.0025
opacity_lr: 0.025
scaling_lr: 0.005 # (scaled up according to mean triangle scale) # 0.005 (original)
rotation_lr: 0.001
percent_dense: 0.01
lambda_dssim: 0.4
densification_interval: 2000
densify_grad_threshold: 1e-6
opacity_reset_interval: 20_000
densify_until_iter: 70_000
densify_from_iter: 8_000
position_lr_init: 5e-3
position_lr_final: 5e-5
position_lr_delay_mult: 0.01
position_lr_max_steps: 100_000
w_lpips: 0.1
lambda_lpips_end: 0.9
lpips_linear_start: 10_000
lpips_linear_end: 60_000
deform_net_w_decay: 2e-3
deform_net_lr_init: 1e-5
deform_net_lr_final: 1e-7
deform_net_lr_delay_mult: 0.01
deform_net_lr_max_steps: 100_000
lambda_laplacian: 1.0
lambda_relative_deform: 0.4
lambda_relative_rot: 0.005
neck_lr_init: 1e-5
neck_lr_final: 1e-7
neck_lr_delay_mult: 0.01
neck_lr_max_steps: 100_000
lambda_neck: 1.0
model_params:
n_unet_layers: 6
n_points_per_triangle: 2
use_lower_jaw: True
static_neck: False
gaussian_init_type: "scaled"
use_expr_mask: True
uv_resolution: 128
n_gaussians_init: 100_000
sh_degree: 3
```
## /configs/avatar/high_quality.yaml
```yaml path="/configs/avatar/high_quality.yaml"
opt_params:
iterations: 100_000
sh_warmup_iterations: 1_000
# clip_probability: 0.2
lambda_scale: 1.
threshold_scale: 1.
lambda_xyz: 1e-3
threshold_xyz: 2.
metric_xyz: False
metric_scale: False
feature_lr: 0.0025
opacity_lr: 0.025
scaling_lr: 0.005 # (scaled up according to mean triangle scale) # 0.005 (original)
rotation_lr: 0.001
percent_dense: 0.01
lambda_dssim: 0.4
densification_interval: 2000
densify_grad_threshold: 1e-6
opacity_reset_interval: 20_000
densify_until_iter: 70_000
densify_from_iter: 8_000
position_lr_init: 5e-3
position_lr_final: 5e-5
position_lr_delay_mult: 0.01
position_lr_max_steps: 100_000
w_lpips: 0.1
lambda_lpips_end: 0.9
lpips_linear_start: 10_000
lpips_linear_end: 60_000
deform_net_w_decay: 2e-3
deform_net_lr_init: 1e-5
deform_net_lr_final: 1e-7
deform_net_lr_delay_mult: 0.01
deform_net_lr_max_steps: 100_000
lambda_laplacian: 1.0
lambda_relative_deform: 0.4
lambda_relative_rot: 0.005
neck_lr_init: 1e-5
neck_lr_final: 1e-7
neck_lr_delay_mult: 0.01
neck_lr_max_steps: 100_000
lambda_neck: 1.0
model_params:
n_unet_layers: 6
n_points_per_triangle: 2
use_lower_jaw: True
static_neck: False
gaussian_init_type: "scaled"
use_expr_mask: True
uv_resolution: 128
n_gaussians_init: 100_000
sh_degree: 3
```
## /configs/avatar/low_quality.yaml
```yaml path="/configs/avatar/low_quality.yaml"
opt_params:
iterations: 25_000
sh_warmup_iterations: 500
# clip_probability: 0.2
lambda_scale: 1.
threshold_scale: 1.
lambda_xyz: 1e-3
threshold_xyz: 2.
metric_xyz: False
metric_scale: False
feature_lr: 0.0025
opacity_lr: 0.025
scaling_lr: 0.005 # (scaled up according to mean triangle scale) # 0.005 (original)
rotation_lr: 0.001
percent_dense: 0.01
lambda_dssim: 0.4
densification_interval: 1000
densify_grad_threshold: 1e-6
opacity_reset_interval: 10_000
densify_until_iter: 24_000
densify_from_iter: 2_000
position_lr_init: 5e-3
position_lr_final: 5e-5
position_lr_delay_mult: 0.01
position_lr_max_steps: 25_000
w_lpips: 0.1
lambda_lpips_end: 0.9
lpips_linear_start: 4_000
lpips_linear_end: 24_000
deform_net_w_decay: 2e-3
deform_net_lr_init: 1e-5
deform_net_lr_final: 1e-7
deform_net_lr_delay_mult: 0.01
deform_net_lr_max_steps: 25_000
lambda_laplacian: 1.0
lambda_relative_deform: 0.4
lambda_relative_rot: 0.005
neck_lr_init: 1e-5
neck_lr_final: 1e-7
neck_lr_delay_mult: 0.01
neck_lr_max_steps: 25_000
lambda_neck: 1.0
model_params:
n_unet_layers: 6
n_points_per_triangle: 2
use_lower_jaw: True
static_neck: False
gaussian_init_type: "scaled"
use_expr_mask: True
uv_resolution: 128
n_gaussians_init: 25_000
sh_degree: 3
```
## /configs/avatar/medium_quality.yaml
```yaml path="/configs/avatar/medium_quality.yaml"
opt_params:
iterations: 50_000
sh_warmup_iterations: 1_000
# clip_probability: 0.2
lambda_scale: 1.
threshold_scale: 1.
lambda_xyz: 1e-3
threshold_xyz: 2.
metric_xyz: False
metric_scale: False
feature_lr: 0.0025
opacity_lr: 0.025
scaling_lr: 0.005 # (scaled up according to mean triangle scale) # 0.005 (original)
rotation_lr: 0.001
percent_dense: 0.01
lambda_dssim: 0.4
densification_interval: 2000
densify_grad_threshold: 1e-6
opacity_reset_interval: 10_000
densify_until_iter: 40_000
densify_from_iter: 4_000
position_lr_init: 5e-3
position_lr_final: 5e-5
position_lr_delay_mult: 0.01
position_lr_max_steps: 50_000
w_lpips: 0.1
lambda_lpips_end: 0.9
lpips_linear_start: 5_000
lpips_linear_end: 40_000
deform_net_w_decay: 2e-3
deform_net_lr_init: 1e-5
deform_net_lr_final: 1e-7
deform_net_lr_delay_mult: 0.01
deform_net_lr_max_steps: 50_000
lambda_laplacian: 1.0
lambda_relative_deform: 0.4
lambda_relative_rot: 0.005
neck_lr_init: 1e-5
neck_lr_final: 1e-7
neck_lr_delay_mult: 0.01
neck_lr_max_steps: 50_000
lambda_neck: 1.0
model_params:
n_unet_layers: 6
n_points_per_triangle: 2
use_lower_jaw: True
static_neck: False
gaussian_init_type: "scaled"
use_expr_mask: True
uv_resolution: 128
n_gaussians_init: 50_000
sh_degree: 3
```
## /configs/generation/debug.yaml
```yaml path="/configs/generation/debug.yaml"
n_ddim_steps: 10
cfg_scale: 2.0
resolution: 512
seed: 124
R_max: 4
V: 8
ckpt_path: ./data/weights/mmdm/
generation_data:
data_path: ./data/assets/datasets/gen_data.npz
yaw_range: 55
pitch_range: 20
expr_factor: 1.0
n_samples: 28
```
## /configs/generation/default.yaml
```yaml path="/configs/generation/default.yaml"
n_ddim_steps: 100
cfg_scale: 2.0
resolution: 512
seed: 124
R_max: 4
V: 8
ckpt_path: ./data/weights/mmdm/
generation_data:
data_path: ./data/assets/datasets/gen_data.npz
yaw_range: 55
pitch_range: 20
expr_factor: 1.0
n_samples: 840
```
## /configs/generation/high_quality.yaml
```yaml path="/configs/generation/high_quality.yaml"
n_ddim_steps: 100
cfg_scale: 2.0
resolution: 512
seed: 124
R_max: 4
V: 8
ckpt_path: ./data/weights/mmdm/
generation_data:
data_path: ./data/assets/datasets/gen_data.npz
yaw_range: 55
pitch_range: 20
expr_factor: 1.0
n_samples: 840
```
## /configs/generation/low_quality.yaml
```yaml path="/configs/generation/low_quality.yaml"
n_ddim_steps: 50
cfg_scale: 2.0
resolution: 512
seed: 124
R_max: 4
V: 8
ckpt_path: ./data/weights/mmdm/
generation_data:
data_path: ./data/assets/datasets/gen_data.npz
yaw_range: 55
pitch_range: 20
expr_factor: 1.0
n_samples: 210
```
## /configs/generation/medium_quality.yaml
```yaml path="/configs/generation/medium_quality.yaml"
n_ddim_steps: 50
cfg_scale: 2.0
resolution: 512
seed: 124
R_max: 4
V: 8
ckpt_path: ./data/weights/mmdm/
generation_data:
data_path: ./data/assets/datasets/gen_data.npz
yaw_range: 55
pitch_range: 20
expr_factor: 1.0
n_samples: 420
```
## /configs/mmdm/cap4d_mmdm_final.yaml
```yaml path="/configs/mmdm/cap4d_mmdm_final.yaml"
dataset_root: PATH_TO_DATASET
gpu_batch_size: 1
virtual_batch_size: 64
logger_freq: 400
learning_rate: 1e-4
n_steps: 100000
n_ref: 4
save_every_n_steps: 1000
init_path: "models/v2-1_512-ema-pruned.ckpt"
dataset:
target: cap4d.datasets.concat_dataset.ConcatDataset
params:
dataset_configs:
- target: cap4d.datasets.flow_face_dataset.FlowFaceDataset
params:
data_path: ${dataset_root}/mead-compressed/
split: train
source_frame: shuffle
resolution: 512
n_views: 8
max_n_references: ${n_ref}
add_mouth: True
adapter_config:
target: cap4d.datasets.adapters.nersemble_adapter.NersembleAdapter
params:
is_compressed: True
- target: cap4d.datasets.flow_face_dataset.FlowFaceDataset
params:
data_path: ${dataset_root}/vfhq-compressed/
split: train
source_frame: shuffle
resolution: 512
n_views: 8
max_n_references: ${n_ref}
add_mouth: True
adapter_config:
target: cap4d.datasets.adapters.vfhq_adapter.VFHQAdapter
params:
is_compressed: True
- target: cap4d.datasets.flow_face_dataset.FlowFaceDataset
params:
data_path: ${dataset_root}/ava-compressed/
split: train
source_frame: shuffle
resolution: 512
n_views: 8
max_n_references: ${n_ref}
add_mouth: True
adapter_config:
target: cap4d.datasets.adapters.ava_adapter.AvaAdapter
params:
is_compressed: True
- target: cap4d.datasets.flow_face_dataset.FlowFaceDataset
params:
data_path: ${dataset_root}/nersemble-compressed/
split: train
source_frame: shuffle
resolution: 512
n_views: 8
max_n_references: ${n_ref}
add_mouth: True
adapter_config:
target: cap4d.datasets.adapters.nersemble_adapter.NersembleAdapter
params:
is_compressed: True
model:
target: cap4d.mmdm.mmdm.MMLDM
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
n_frames: 8
first_stage_key: "jpg"
cond_stage_key: "txt"
control_key: "hint"
image_size: 64
channels: 4
cond_stage_trainable: False
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
only_mid_control: False
shift_schedule: True
zero_snr_shift: True
sqrt_shift: True
minus_one_shift: True
unet_config:
target: cap4d.mmdm.net.mmdm_unet.MMDMUnetModel
params:
# use_checkpoint: True
image_size: 64 # unused
time_steps: 8
temporal_mode: "3d"
in_channels: 4
out_channels: 4
model_channels: 320
condition_channels: 50 # 42 (pos map) + 3 (expr deform map) + 3 (ray map) + 1 (ref mask) + 1 (crop mask)
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
use_checkpoint: False
legacy: False
first_stage_config:
target: controlnet.ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 512
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: cap4d.mmdm.conditioning.cap4dcond.CAP4DConditioning
params:
image_size: 64
positional_channels: 42
positional_multiplier: 1.
super_resolution: 2
use_ray_directions: True
use_expr_deformation: True
use_crop_mask: True
```
## /controlnet/cldm/ddim_hacked.py
```py path="/controlnet/cldm/ddim_hacked.py"
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.detach().cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.detach().cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.detach().cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.detach().cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.detach().cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.detach().cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
ucg_schedule=None,
**kwargs
):
if conditioning is not None:
if False:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
# C, H, W = shape
size = (batch_size, *shape)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None):
device = self.model.betas.device
b_t = shape[:-3] # B, T, C, H, W
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full(b_t, step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
if ucg_schedule is not None:
assert len(ucg_schedule) == len(time_range)
unconditional_guidance_scale = ucg_schedule[i]
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None):
b_t, device = x.shape[:-3], x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else:
model_t = self.model.apply_model(x, t, c)
model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
else:
e_t = model_output
if score_corrector is not None:
assert self.model.parameterization == "eps", 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((*b_t, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((*b_t, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((*b_t, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((*b_t, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
raise NotImplementedError()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
num_reference_steps = timesteps.shape[0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
torch.cat((unconditional_conditioning, c))), 2)
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = alphas_next[i].sqrt() * (
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
x_next = xt_weighted + weighted_noise_pred
if return_intermediates and i % (
num_steps // return_intermediates) == 0 and i < num_steps - 1:
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback: callback(i)
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
if return_intermediates:
out.update({'intermediates': intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
    """Diffuse x0 to timestep index t in a single jump by mixing in Gaussian noise.

    Fast, but (unlike `encode`) does not allow exact reconstruction; `t` is an
    index into the (full or DDIM-subsampled) alpha tables.
    """
    if noise is None:
        noise = torch.randn_like(x0)
    # Pick the schedule the index refers to.
    if use_original_steps:
        signal_scale = self.sqrt_alphas_cumprod
        noise_scale = self.sqrt_one_minus_alphas_cumprod
    else:
        signal_scale = torch.sqrt(self.ddim_alphas)
        noise_scale = self.ddim_sqrt_one_minus_alphas
    scaled_signal = extract_into_tensor(signal_scale, t, x0.shape) * x0
    scaled_noise = extract_into_tensor(noise_scale, t, x0.shape) * noise
    return scaled_signal + scaled_noise
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
           use_original_steps=False, callback=None):
    """Denoise an encoded latent back to an image latent via DDIM sampling.

    Walks the first t_start schedule entries in reverse, applying
    p_sample_ddim (with optional classifier-free guidance) at each step.
    """
    if use_original_steps:
        schedule = np.arange(self.ddpm_num_timesteps)
    else:
        schedule = self.ddim_timesteps
    schedule = schedule[:t_start]
    total_steps = schedule.shape[0]
    print(f"Running DDIM Sampling with {total_steps} timesteps")
    progress = tqdm(np.flip(schedule), desc='Decoding image', total=total_steps)
    x_dec = x_latent
    for step_idx, step in enumerate(progress):
        # `index` addresses the DDIM alpha/sigma tables from the end backwards.
        index = total_steps - step_idx - 1
        ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
        x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning)
        if callback:
            callback(step_idx)
    return x_dec
```
## /controlnet/cldm/hack.py
```py path="/controlnet/cldm/hack.py"
import torch
import einops
import controlnet.ldm.modules.encoders.modules
import controlnet.ldm.modules.attention
from transformers import logging
from controlnet.ldm.modules.attention import default
def disable_verbosity():
    """Silence HuggingFace transformers logging down to errors only."""
    logging.set_verbosity_error()
    print('logging improved.')
def enable_sliced_attention():
    """Monkey-patch CrossAttention.forward with the memory-saving sliced version.

    Fix: the attention module is imported at the top of this file as
    ``controlnet.ldm.modules.attention``; the original body referenced a
    nonexistent top-level ``ldm`` name and raised NameError when called.
    """
    controlnet.ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
    print('Enabled sliced_attention.')
def hack_everything(clip_skip=0):
    """Apply all hacks: quiet transformers logging and patch FrozenCLIPEmbedder.forward
    with the chunked (3x75-token) CLIP encoder, optionally with CLIP-skip.

    Fix: the encoder module is imported at the top of this file as
    ``controlnet.ldm.modules.encoders.modules``; the original body referenced a
    nonexistent top-level ``ldm`` name and raised NameError when called.
    """
    disable_verbosity()
    controlnet.ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
    controlnet.ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
    print('Enabled clip hacks.')
# Written by Lvmin
def _hacked_clip_forward(self, text):
    """Replacement FrozenCLIPEmbedder.forward that encodes long prompts.

    Splits each prompt into three 75-token chunks, wraps each chunk with
    BOS/EOS and pads to 77 tokens, encodes all chunks in one batched pass,
    and concatenates the three 77-token embeddings along the sequence axis
    (so output is (b, 231, c)). Supports CLIP-skip via self.clip_skip.
    """
    PAD = self.tokenizer.pad_token_id
    EOS = self.tokenizer.eos_token_id
    BOS = self.tokenizer.bos_token_id
    def tokenize(t):
        # Raw token ids, no truncation and no special tokens (added manually below).
        return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
    def transformer_encode(t):
        if self.clip_skip > 1:
            # CLIP-skip: take an earlier hidden state, then apply the final layer norm.
            rt = self.transformer(input_ids=t, output_hidden_states=True)
            return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
        else:
            return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
    def split(x):
        # First three 75-token windows; anything beyond 225 tokens is dropped.
        return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
    def pad(x, p, i):
        return x[:i] if len(x) >= i else x + [p] * (i - len(x))
    raw_tokens_list = tokenize(text)
    tokens_list = []
    for raw_tokens in raw_tokens_list:
        raw_tokens_123 = split(raw_tokens)
        raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
        raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
        tokens_list.append(raw_tokens_123)
    tokens_list = torch.IntTensor(tokens_list).to(self.device)
    # Fold the 3 chunks into the batch dim for a single transformer pass.
    feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
    y = transformer_encode(feed)
    # Unfold and concatenate the chunk embeddings along the sequence axis.
    z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
    return z
# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
    """Memory-saving CrossAttention.forward: attention computed row-chunk by row-chunk.

    NOTE: the "attentin" typo in the name is intentional here — it is the exact
    name patched in by enable_sliced_attention. `mask` is accepted but ignored.
    """
    h = self.heads
    q = self.to_q(x)
    context = default(context, x)  # self-attention when no context is given
    k = self.to_k(context)
    v = self.to_v(context)
    del context, x
    # Split heads into the batch dim: (b, n, h*d) -> (b*h, n, d).
    q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
    limit = k.shape[0]
    att_step = 1  # one (b*h) row per chunk
    q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
    k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
    v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
    # Reversed so .pop() below yields chunks in forward order.
    q_chunks.reverse()
    k_chunks.reverse()
    v_chunks.reverse()
    # Output accumulator; the full attention matrix is never materialized.
    sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
    del k, q, v
    for i in range(0, limit, att_step):
        q_buffer = q_chunks.pop()
        k_buffer = k_chunks.pop()
        v_buffer = v_chunks.pop()
        sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
        del k_buffer, q_buffer
        # attention, what we cannot get enough of, by chunks
        sim_buffer = sim_buffer.softmax(dim=-1)
        sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
        del v_buffer
        sim[i:i + att_step, :, :] = sim_buffer
        del sim_buffer
    # Merge heads back: (b*h, n, d) -> (b, n, h*d).
    sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(sim)
```
## /controlnet/cldm/logger.py
```py path="/controlnet/cldm/logger.py"
import os
import numpy as np
import torch
import torchvision
from pathlib import Path
from PIL import Image
from pytorch_lightning.callbacks import Callback
# from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.rank_zero import rank_zero_only
import einops
import torch.nn.functional as F
import gc
class ImageLogger(Callback):
    """Lightning callback that periodically renders model samples and
    conditioning visualizations to JPEG grids under <save_dir>/image_log/."""

    def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
                 rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
                 log_images_kwargs=None):
        super().__init__()
        self.rescale = rescale            # map [-1, 1] -> [0, 1] before saving
        self.batch_freq = batch_frequency
        self.max_images = max_images      # cap on logged samples per key
        # NOTE(review): log_steps is only set when increase_log_steps is False
        # and is never read in this class — presumably leftover from upstream.
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    # @rank_zero_only
    def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx, gpu_id=0):
        """Write all 5-D (b, t, c, h, w) entries of `images` as one stacked JPEG."""
        root = os.path.join(save_dir, "image_log", split, f"e-{current_epoch:06d}")
        rows = []
        for k in images:
            if len(images[k].shape) == 4:
                # Single images
                # TODO: Implement
                continue
            if len(images[k].shape) == 5:
                # We have videos
                b, t = images[k].shape[:2]
                imgs = einops.rearrange(images[k], 'b t c h w -> (b t) c h w')
                grid = torchvision.utils.make_grid(imgs, nrow=b * t)
                if self.rescale:
                    grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
                grid = grid.permute(1, 2, 0).numpy()
                grid = (grid * 255).astype(np.uint8)
                rows.append(grid)
        # One file per call; each logged key becomes one row in the image.
        # NOTE(review): np.concatenate raises if no key had a 5-D tensor.
        filename = "gs-{:06}_b-{:06}_{:02d}_i.jpg".format(global_step, batch_idx, gpu_id)
        path = os.path.join(root, filename)
        os.makedirs(os.path.split(path)[0], exist_ok=True)
        Image.fromarray(np.concatenate(rows, axis=0)).save(path)

    def log_cond(self, pl_module, batch, conditioned=True):
        """Build upsampled visualizations of the conditioning encodings."""
        cond_model = pl_module.cond_stage_model
        cond_key = pl_module.control_key
        c_cond = cond_model(batch[cond_key], conditioned=True)
        enc_vis = cond_model.get_vis(c_cond["pos_enc"])
        for key in enc_vis:
            vis = enc_vis[key]
            b_ = vis.shape[0]
            vis = einops.rearrange(vis, 'b t h w c -> (b t) c h w')
            # x8 nearest upsampling so the (latent-resolution) encodings are visible.
            vis = F.interpolate(vis, scale_factor=8., mode="nearest")
            enc_vis[key] = einops.rearrange(vis, '(b t) c h w -> b t c h w', b=b_)
        return enc_vis

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        """Sample from the module and save sample + conditioning grids."""
        check_idx = batch_idx  # if self.log_on_batch_idx else pl_module.global_step
        if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and
                callable(pl_module.log_images) and
                self.max_images > 0):
            logger = type(pl_module.logger)
            is_train = pl_module.training
            if is_train:
                pl_module.eval()  # sample in eval mode; restored below
            # draw vertices!
            with torch.no_grad():
                cond_vis = self.log_cond(pl_module, batch)
                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
            for k in images:
                # Trim to at most max_images samples and move to CPU.
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        if not images[k].shape[-1] == 3:
                            images[k] = torch.clamp(images[k], -1., 1.)
            for key in cond_vis:
                images[key] = cond_vis[key].detach().cpu().clamp(-1., 1.)
            self.log_local(
                pl_module.logger.save_dir,
                split,
                images,
                pl_module.global_step,
                pl_module.current_epoch,
                batch_idx,
                gpu_id=pl_module.global_rank,
            )
            gc.collect()
            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        # Log every batch_freq batches (including batch 0).
        return check_idx % self.batch_freq == 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if not self.disabled:
            self.log_img(pl_module, batch, batch_idx, split="train")
```
## /controlnet/cldm/model.py
```py path="/controlnet/cldm/model.py"
import os
import torch
from omegaconf import OmegaConf
from controlnet.ldm.util import instantiate_from_config
def get_state_dict(d):
    """Unwrap a checkpoint dict: return d['state_dict'] if present, else d itself."""
    if 'state_dict' in d:
        return d['state_dict']
    return d
def load_state_dict(ckpt_path, location='cpu'):
    """Load a checkpoint's state dict from a .safetensors or torch file.

    Torch checkpoints are unwrapped twice (checkpoint -> 'state_dict' entry),
    exactly as the original did; safetensors files are unwrapped once.
    """
    extension = os.path.splitext(ckpt_path)[1].lower()
    if extension == ".safetensors":
        import safetensors.torch
        loaded = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        raw = torch.load(ckpt_path, map_location=torch.device(location))
        loaded = raw.get('state_dict', raw)
    state_dict = loaded.get('state_dict', loaded)
    print(f'Loaded state_dict from [{ckpt_path}]')
    return state_dict
def create_model(config_path):
    """Instantiate the model described by an OmegaConf YAML config, on CPU."""
    cfg = OmegaConf.load(config_path)
    model = instantiate_from_config(cfg.model)
    model = model.cpu()
    print(f'Loaded model config from [{config_path}]')
    return model
```
## /controlnet/ldm/data/__init__.py
```py path="/controlnet/ldm/data/__init__.py"
```
## /controlnet/ldm/data/util.py
```py path="/controlnet/ldm/data/util.py"
import torch
from ldm.modules.midas.api import load_midas_transform
class AddMiDaS(object):
    """Sample transform that attaches a MiDaS-ready network input.

    Converts the sample's 'jpg' tensor (hwc, values in [-1, 1]) to numpy in
    [0, 1], runs the MiDaS preprocessing transform, and stores the result
    under 'midas_in'.
    """

    def __init__(self, model_type):
        super().__init__()
        self.transform = load_midas_transform(model_type)

    def pt2np(self, x):
        # tensor in [-1, 1] -> numpy array in [0, 1]
        return ((x + 1.0) * .5).detach().cpu().numpy()

    def np2pt(self, x):
        # numpy array in [0, 1] -> tensor in [-1, 1]
        return torch.from_numpy(x) * 2 - 1.

    def __call__(self, sample):
        # sample['jpg'] is tensor hwc in [-1, 1] at this point
        midas_image = self.pt2np(sample['jpg'])
        sample['midas_in'] = self.transform({"image": midas_image})["image"]
        return sample
```
## /controlnet/ldm/models/autoencoder.py
```py path="/controlnet/ldm/models/autoencoder.py"
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from controlnet.ldm.modules.diffusionmodules.model import Encoder, Decoder
from controlnet.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from controlnet.ldm.util import instantiate_from_config
from controlnet.ldm.modules.ema import LitEma
class AutoencoderKL(pl.LightningModule):
    """KL-regularized convolutional autoencoder (the VAE first stage of
    latent diffusion), trained jointly with a discriminator loss."""

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False
                 ):
        # NOTE(review): ignore_keys=[] is a mutable default argument; harmless
        # here since it is never mutated, but a None sentinel would be safer.
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        # double_z: encoder outputs mean and logvar channels for the posterior.
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            # Random projection used by to_rgb() for segmentation inputs.
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0. < ema_decay < 1.
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Load weights from a Lightning checkpoint, dropping ignored prefixes."""
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        # strict=False: tolerate missing/extra keys after deletion.
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap in EMA weights; restores training weights on exit."""
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        # Update the EMA shadow weights after every optimizer step.
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode an image batch to a DiagonalGaussianDistribution posterior."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode a latent z back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Full round trip; returns (reconstruction, posterior)."""
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        # Batch entries arrive as (b, h, w, c); convert to (b, c, h, w).
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)  # .float() CONTIGUOUS
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Alternating GAN-style training: idx 0 = autoencoder, idx 1 = discriminator."""
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss
        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        # Validate with current weights and (when enabled) EMA weights.
        # NOTE(review): log_dict_ema is unused; the EMA pass only logs via self.log.
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, postfix=""):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val"+postfix)
        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val"+postfix)
        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        # NOTE(review): returns the bound log_dict method, not a dict — looks
        # unintentional, but callers ignore the value so it is inert.
        return self.log_dict

    def configure_optimizers(self):
        """Two Adam optimizers: one for the AE (+optional logvar), one for the discriminator."""
        lr = self.learning_rate
        ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
            self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
        if self.learn_logvar:
            print(f"{self.__class__.__name__}: Learning logvar")
            ae_params_list.append(self.loss.logvar)
        opt_ae = torch.optim.Adam(ae_params_list,
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        # Used by the loss for adaptive discriminator weighting.
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        """Return input / reconstruction / random-sample images for logging."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
            if log_ema or self.use_ema:
                with self.ema_scope():
                    xrec_ema, posterior_ema = self(x)
                    if x.shape[1] > 3:
                        # colorize with random projection
                        assert xrec_ema.shape[1] > 3
                        xrec_ema = self.to_rgb(xrec_ema)
                    log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
                    log["reconstructions_ema"] = xrec_ema
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation map to RGB via a fixed random conv."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x
class IdentityFirstStage(torch.nn.Module):
    """A pass-through first stage: encode/decode/forward return their input unchanged.

    With vq_interface=True, quantize() mimics a VQ model's (z, loss, info)
    return signature so downstream code can unpack it uniformly.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        self.vq_interface = vq_interface
        super().__init__()

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if not self.vq_interface:
            return x
        return x, None, [None, None, None]

    def forward(self, x, *args, **kwargs):
        return x
```
## /controlnet/ldm/models/diffusion/__init__.py
```py path="/controlnet/ldm/models/diffusion/__init__.py"
```
## /controlnet/ldm/models/diffusion/ddim.py
```py path="/controlnet/ldm/models/diffusion/ddim.py"
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from controlnet.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
class DDIMSampler(object):
    """DDIM sampler for a trained latent-diffusion model.

    Fixes relative to the original:
    - the list-conditioning batch-size warning printed the undefined name
      ``cbs`` (NameError when the warning fired); it now prints the actual size;
    - ``encode`` indexed the model with the loop counter ``i`` instead of the
      real timestep value, which is wrong for the subsampled DDIM schedule —
      it now mirrors the corrected encode in cldm/ddim_hacked.py.
    """

    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        # Plain attribute storage (this is not an nn.Module). Tensors are
        # force-moved to CUDA — NOTE(review): breaks on CPU-only machines.
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute the DDIM timestep subset and alpha/sigma lookup tables."""
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.detach().cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.detach().cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.detach().cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.detach().cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.detach().cpu() - 1)))
        # ddim sampling parameters (subsampled to the chosen timesteps)
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.detach().cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,  # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               dynamic_threshold=None,
               ucg_schedule=None,
               **kwargs
               ):
        """Set up the schedule and run S DDIM steps; returns (samples, intermediates)."""
        # Sanity-check that the conditioning batch size matches batch_size.
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list): ctmp = ctmp[0]
                if not isinstance(ctmp, dict):
                    cbs = ctmp.shape[0]
                    if cbs != batch_size:
                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            elif isinstance(conditioning, list):
                for ctmp in conditioning:
                    if ctmp.shape[0] != batch_size:
                        # fix: original printed the undefined name `cbs` here
                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        # C, H, W = shape
        size = (batch_size, *shape)
        print(f'Data shape for DDIM sampling is {size}, eta {eta}')
        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    ucg_schedule=ucg_schedule
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
                      ucg_schedule=None):
        """Main DDIM loop: iterate the schedule in reverse, denoising step by step."""
        device = self.model.betas.device
        # Leading (batch-like) dims; the last three are assumed to be c, h, w.
        b_ = shape[:-3]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T
        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]
        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")
        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full(b_, step, device=device, dtype=torch.long)
            if mask is not None:
                # Inpainting: keep masked region pinned to the noised x0.
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img
            if ucg_schedule is not None:
                # Per-step guidance-scale schedule.
                assert len(ucg_schedule) == len(time_range)
                unconditional_guidance_scale = ucg_schedule[i]
            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)
            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)
        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None):
        """One DDIM denoising step x_t -> x_{t-1}; returns (x_prev, pred_x0)."""
        b_t, device = x.shape[:-3], x.device
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            model_output = self.model.apply_model(x, t, c)
        else:
            # Classifier-free guidance: batch uncond + cond into one forward pass.
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            # HACK: this sampler assumes the project's conditioning layout
            # {"c_concat": [<dict of tensors>]} with unconditional_conditioning
            # being the bare inner dict; concatenate key-by-key along batch.
            c_cond = c["c_concat"][0]
            c_uncond = unconditional_conditioning
            c_in = dict()
            for key in c_cond:
                c_in[key] = torch.cat([c_uncond[key], c_cond[key]], dim=0)
            c_in = {"c_concat": [c_in]}
            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
        if self.model.parameterization == "v":
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
        else:
            e_t = model_output
        if score_corrector is not None:
            assert self.model.parameterization == "eps", 'not implemented'
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((*b_t, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((*b_t, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((*b_t, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((*b_t, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
        # current prediction for x_0
        if self.model.parameterization != "v":
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        if dynamic_threshold is not None:
            raise NotImplementedError()
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    @torch.no_grad()
    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
        """Deterministic DDIM inversion: diffuse x0 forward for t_enc steps.

        Fix: the model is now queried at the actual timestep value
        (timesteps[i]) rather than the loop counter i, which was wrong for the
        subsampled DDIM schedule (matches cldm/ddim_hacked.py).
        """
        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        num_reference_steps = timesteps.shape[0]
        assert t_enc <= num_reference_steps
        num_steps = t_enc
        if use_original_steps:
            alphas_next = self.alphas_cumprod[:num_steps]
            alphas = self.alphas_cumprod_prev[:num_steps]
        else:
            alphas_next = self.ddim_alphas[:num_steps]
            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
        x_next = x0
        intermediates = []
        inter_steps = []
        for i in tqdm(range(num_steps), desc='Encoding Image'):
            t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
            if unconditional_guidance_scale == 1.:
                noise_pred = self.model.apply_model(x_next, t, c)
            else:
                assert unconditional_conditioning is not None
                e_t_uncond, noise_pred = torch.chunk(
                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
                                           torch.cat((unconditional_conditioning, c))), 2)
                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
            # Inverted DDIM step.
            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
            weighted_noise_pred = alphas_next[i].sqrt() * (
                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
            x_next = xt_weighted + weighted_noise_pred
            # Keep evenly spaced intermediates plus the last couple of steps.
            if return_intermediates and i % (
                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
                intermediates.append(x_next)
                inter_steps.append(i)
            elif return_intermediates and i >= num_steps - 2:
                intermediates.append(x_next)
                inter_steps.append(i)
            if callback: callback(i)
        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
        if return_intermediates:
            out.update({'intermediates': intermediates})
        return x_next, out

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        """Diffuse x0 to timestep index t in one jump by mixing in Gaussian noise."""
        # fast, but does not allow for exact reconstruction
        # t serves as an index to gather the correct alphas
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
        if noise is None:
            noise = torch.randn_like(x0)
        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)

    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
               use_original_steps=False, callback=None):
        """Denoise an encoded latent back to an image latent via DDIM sampling."""
        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]
        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")
        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
            if callback: callback(i)
        return x_dec
```
## /controlnet/ldm/models/diffusion/dpm_solver/__init__.py
```py path="/controlnet/ldm/models/diffusion/dpm_solver/__init__.py"
from .sampler import DPMSolverSampler
```
## /controlnet/ldm/models/diffusion/dpm_solver/sampler.py
```py path="/controlnet/ldm/models/diffusion/dpm_solver/sampler.py"
"""SAMPLING ONLY."""
import torch
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
# Map the diffusion model's parameterization ("eps"/"v") to the prediction-type
# names expected by dpm_solver.model_wrapper.
MODEL_TYPES = {
    "eps": "noise",
    "v": "v"
}
class DPMSolverSampler(object):
    """Thin wrapper that samples from a latent-diffusion model with DPM-Solver."""

    def __init__(self, model, **kwargs):
        super().__init__()
        self.model = model
        # Detached float32 copy of the model's noise schedule on the model's device.
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
        self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))

    def register_buffer(self, name, attr):
        # Plain attribute storage (this is not an nn.Module). Tensors are
        # force-moved to CUDA — NOTE(review): breaks on CPU-only machines.
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """Run S order-2 multistep DPM-Solver steps; returns (samples, None).

        Most keyword arguments exist only for signature compatibility with the
        other samplers and are ignored here (only S, batch_size, shape,
        conditioning, x_T and the guidance arguments are used).
        """
        # Sanity-check the conditioning batch size.
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
        device = self.model.betas.device
        if x_T is None:
            img = torch.randn(size, device=device)
        else:
            img = x_T
        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
        # Wrap the UNet with classifier-free guidance for the solver.
        model_fn = model_wrapper(
            lambda x, t, c: self.model.apply_model(x, t, c),
            ns,
            model_type=MODEL_TYPES[self.model.parameterization],
            guidance_type="classifier-free",
            condition=conditioning,
            unconditional_condition=unconditional_conditioning,
            guidance_scale=unconditional_guidance_scale,
        )
        dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
        return x.to(device), None
```
## /controlnet/ldm/models/diffusion/plms.py
```py path="/controlnet/ldm/models/diffusion/plms.py"
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
from ldm.models.diffusion.sampling_util import norm_thresholding
class PLMSSampler(object):
    """Pseudo Linear Multi-Step (PLMS) sampler (https://arxiv.org/abs/2202.09778).

    Combines the deterministic DDIM update rule with Adams-Bashforth multistep
    estimates of the predicted noise. Schedule buffers are created lazily by
    ``make_schedule`` (called from ``sample``).
    """
    def __init__(self, model, schedule="linear", **kwargs):
        """Wrap a trained DDPM ``model``; no buffers are created yet."""
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        # Store ``attr`` as a plain attribute; tensors are moved to CUDA first.
        # NOTE(review): the hard-coded "cuda" device fails on CPU-only setups.
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Select the DDIM timestep subset and precompute alpha/sigma buffers.

        PLMS is only defined for the deterministic case, hence eta must be 0.
        """
        if ddim_eta != 0:
            raise ValueError('ddim_eta must be 0 for PLMS')
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        # Sigmas corresponding to running the full original DDPM step count.
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               dynamic_threshold=None,
               **kwargs
               ):
        """Run PLMS sampling for ``S`` steps and a batch of ``batch_size``.

        Args:
            S: number of PLMS steps.
            shape: latent shape (C, H, W); the batch dim is prepended.
            conditioning: conditioning tensor or dict of tensors (batch-first).
            eta: must stay 0 (enforced by make_schedule).
            mask / x0: optional inpainting inputs (keep x0 where mask == 1).
            x_T: optional starting latent; standard normal noise if None.
            unconditional_guidance_scale / unconditional_conditioning:
                classifier-free guidance weight and negative conditioning.
            dynamic_threshold: optional RMS cap applied to pred_x0.

        Returns:
            (samples, intermediates) with intermediates logging x_inter/pred_x0.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                # Sanity-check the batch dimension of the first conditioning entry.
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for PLMS sampling is {size}')

        samples, intermediates = self.plms_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def plms_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None):
        """Main PLMS loop: iterate the timesteps in reverse, keeping the last
        up-to-3 noise predictions in ``old_eps`` for the multistep formulas."""
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            # A scalar `timesteps` truncates the precomputed DDIM schedule.
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
        old_eps = []

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            # The following timestep is needed for the 2nd-order bootstrap step.
            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

            if mask is not None:
                # Inpainting: re-noise the known region x0 to the current level
                # and keep it where mask == 1.
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0, e_t = outs
            # Keep at most the last 3 predictions (4th-order Adams-Bashforth).
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                old_eps.pop(0)
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
                      dynamic_threshold=None):
        """Single PLMS update: predict eps, combine it with the history in
        ``old_eps`` via an Adams-Bashforth formula, then take a DDIM step.

        Returns (x_prev, pred_x0, e_t); ``e_t`` is the raw (guided) prediction
        that the caller appends to the multistep history.
        """
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            # Classifier-free guidance: run the unconditional and conditional
            # passes in one batch, then extrapolate.
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                e_t = self.model.apply_model(x, t, c)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

            if score_corrector is not None:
                assert self.model.parameterization == "eps"
                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

            # current prediction for x_0 (eps-parameterization inversion)
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            if dynamic_threshold is not None:
                pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order): bootstrap with one extra model call.
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t
```
## /controlnet/ldm/models/diffusion/sampling_util.py
```py path="/controlnet/ldm/models/diffusion/sampling_util.py"
import torch
import numpy as np
def append_dims(x, target_dims):
    """Right-pad *x* with singleton dimensions until it has *target_dims* dims.

    From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
    extra = target_dims - x.ndim
    if extra < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    index = (Ellipsis,) + (None,) * extra
    return x[index]
def norm_thresholding(x0, value):
    """Cap the per-sample RMS norm of *x0* at *value* (dynamic thresholding).

    Samples whose RMS is already below *value* are left unchanged.
    """
    rms = x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value)
    # Re-attach trailing singleton dims so the scale broadcasts over x0.
    rms = rms.reshape(rms.shape + (1,) * (x0.ndim - rms.ndim))
    return x0 * (value / rms)
def spatial_norm_thresholding(x0, value):
    """Cap the channel-wise RMS at each spatial location of a (b, c, h, w) tensor."""
    channel_rms = x0.pow(2).mean(dim=1, keepdim=True).sqrt()
    scale = value / channel_rms.clamp(min=value)
    return x0 * scale
```
## /controlnet/ldm/modules/attention.py
```py path="/controlnet/ldm/modules/attention.py"
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any
from controlnet.ldm.modules.diffusionmodules.util import checkpoint
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except:
XFORMERS_IS_AVAILBLE = False
# CrossAttn precision handling
import os
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
def exists(val):
    """Return True when *val* is not None."""
    return val is not None
def uniq(arr):
    """Order-preserving de-duplication; returns a dict keys view."""
    return dict.fromkeys(arr).keys()
def default(val, d):
    """Return *val* unless it is None; otherwise *d* (called first if it is a function)."""
    if val is not None:
        return val
    return d() if isfunction(d) else d
def max_neg_value(t):
    """Most negative finite value representable in *t*'s dtype (used for masking)."""
    dtype_info = torch.finfo(t.dtype)
    return -dtype_info.max
def init_(tensor):
    """Uniform init in [-1/sqrt(dim), 1/sqrt(dim)] (in place); returns *tensor*."""
    bound = 1 / math.sqrt(tensor.shape[-1])
    tensor.uniform_(-bound, bound)
    return tensor
# feedforward
class GEGLU(nn.Module):
    """Gated-GELU projection (https://arxiv.org/abs/2002.05202): value * gelu(gate)."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # Single linear layer produces both the value and the gate halves.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)
class FeedForward(nn.Module):
    """Position-wise transformer MLP: Linear -> GELU (or GEGLU) -> Dropout -> Linear.

    Args:
        dim: input feature size.
        dim_out: output feature size; defaults to ``dim``.
        mult: hidden-layer expansion factor.
        glu: use a GEGLU gate instead of plain Linear + GELU.
        dropout: dropout probability before the output projection.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        if dim_out is None:
            dim_out = dim
        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)
def zero_module(module):
    """Zero all parameters of *module* in place and return it."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
def Normalize(in_channels):
    """32-group GroupNorm with the LDM-standard eps and learnable affine params."""
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class SpatialSelfAttention(nn.Module):
    """Single-head self-attention over the spatial positions of a (b, c, h, w) map.

    1x1 convolutions implement the q/k/v and output projections; the attention
    output is added to the input (residual connection).
    """
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention: w_[b, i, j] = <q_i, k_j> over flattened spatial dims
        b, c, h, w = q.shape
        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        w_ = torch.einsum('bij,bjk->bik', q, k)

        w_ = w_ * (int(c)**(-0.5))  # scale by 1/sqrt(channels)
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values (transpose w_ so the sum runs over key positions j)
        v = rearrange(v, 'b c h w -> b c (h w)')
        w_ = rearrange(w_, 'b i j -> b j i')
        h_ = torch.einsum('bij,bjk->bik', v, w_)
        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
        h_ = self.proj_out(h_)

        return x+h_
class CrossAttention(nn.Module):
    """Multi-head attention on sequence tensors (b, n, d).

    Acts as self-attention when `context` is None (keys/values come from `x`)
    and as cross-attention otherwise.
    """
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        # Fold heads into the batch dim: (b, n, h*d) -> (b*h, n, d).
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        # force cast to fp32 to avoid overflowing
        if _ATTN_PRECISION =="fp32":
            with torch.autocast(enabled=False, device_type = 'cuda'):
                q, k = q.float(), k.float()
                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        else:
            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        del q, k

        if exists(mask):
            # Flatten the mask, broadcast over heads, and push masked logits to -max.
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', sim, v)
        # Un-fold heads: (b*h, n, d) -> (b, n, h*d).
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)
class MemoryEfficientCrossAttention(nn.Module):
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    """Cross-attention backed by xformers' memory_efficient_attention.

    Requires the `xformers` package; attention masking is not implemented.
    Drop-in replacement for CrossAttention (same constructor and forward).
    """
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
              f"{heads} heads.")
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        # Optional explicit xformers attention op (None = let xformers pick one).
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, mask=None):
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        # (b, n, h*d) -> (b*h, n, d), the layout memory_efficient_attention expects.
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )

        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

        if exists(mask):
            raise NotImplementedError
        # (b*h, n, d) -> (b, n, h*d): restore the batch/sequence layout.
        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        return self.to_out(out)
class BasicTransformerBlock(nn.Module):
    """Standard transformer block: self-attn, cross-attn and feed-forward,
    each with pre-LayerNorm and a residual connection."""
    ATTENTION_MODES = {
        "softmax": CrossAttention,  # vanilla attention
        "softmax-xformers": MemoryEfficientCrossAttention
    }
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                 disable_self_attn=False):
        super().__init__()
        # Prefer the memory-efficient implementation whenever xformers imported.
        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
        assert attn_mode in self.ATTENTION_MODES
        attn_cls = self.ATTENTION_MODES[attn_mode]
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                              context_dim=context_dim if self.disable_self_attn else None)  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
                              heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        # Flag consumed by the module-level `checkpoint` helper (note: the
        # `checkpoint` parameter intentionally shadows that helper here).
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        # Route through the checkpoint helper so activations can be recomputed
        # during backward when self.checkpoint is True.
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False, use_linear=False,
                 use_checkpoint=True):
        super().__init__()
        # Normalize context_dim into one entry per transformer block; None
        # means "no cross-conditioning", so every block self-attends.
        # (Bugfix: previously a None context_dim crashed at context_dim[d].)
        if context_dim is None:
            context_dim = [None] * depth
        elif not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(in_channels,
                                     inner_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
                for d in range(depth)]
        )
        if not use_linear:
            self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                                  in_channels,
                                                  kernel_size=1,
                                                  stride=1,
                                                  padding=0))
        else:
            # Bugfix: the output projection must map inner_dim back to
            # in_channels (arguments were swapped before, which only happened
            # to work when inner_dim == in_channels).
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        """Apply the transformer stack to a (b, c, h, w) map; residual output."""
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        # Treat the h*w spatial positions as a token sequence.
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
```
## /controlnet/ldm/modules/diffusionmodules/__init__.py
```py path="/controlnet/ldm/modules/diffusionmodules/__init__.py"
```
## /controlnet/ldm/modules/diffusionmodules/upscaling.py
```py path="/controlnet/ldm/modules/diffusionmodules/upscaling.py"
import torch
import torch.nn as nn
import numpy as np
from functools import partial
from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
from ldm.util import default
class AbstractLowScaleModel(nn.Module):
    # for concatenating a downsampled image to the latent representation
    """Base class for low-resolution image conditioning stages.

    Optionally registers a DDPM noise schedule so subclasses can noise-augment
    the low-res image via ``q_sample``. ``forward``/``decode`` are identity
    pass-throughs here.
    """
    def __init__(self, noise_schedule_config=None):
        super(AbstractLowScaleModel, self).__init__()
        if noise_schedule_config is not None:
            self.register_schedule(**noise_schedule_config)

    def register_schedule(self, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Precompute the forward-diffusion constants as float32 buffers."""
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                   cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # alpha_bar shifted one step back, with 1.0 prepended for t = 0.
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` to noise level ``t``: sqrt(a_bar)*x + sqrt(1-a_bar)*eps."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def forward(self, x):
        # Identity conditioning: no noise level is reported.
        return x, None

    def decode(self, x):
        return x
class SimpleImageConcat(AbstractLowScaleModel):
    """Low-scale conditioning that passes the image through unchanged.

    No noise-level conditioning: the reported noise level is always zero.
    """

    def __init__(self):
        # No noise schedule needed since nothing is ever noised.
        super().__init__(noise_schedule_config=None)
        self.max_noise_level = 0

    def forward(self, x):
        # Constant (zero) noise level for every sample in the batch.
        zero_levels = torch.zeros(x.shape[0], device=x.device).long()
        return x, zero_levels
class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
    """Low-res conditioning with DDPM noise augmentation at a (random) level."""
    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
        # NOTE(review): `to_cuda` is accepted but never used in this class.
        super().__init__(noise_schedule_config=noise_schedule_config)
        self.max_noise_level = max_noise_level

    def forward(self, x, noise_level=None):
        """Noise `x` to a per-sample level; returns (noised_x, noise_level)."""
        if noise_level is None:
            # Sample a random augmentation level per batch element.
            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        else:
            assert isinstance(noise_level, torch.Tensor)
        z = self.q_sample(x, noise_level)
        return z, noise_level
```
## /controlnet/ldm/modules/diffusionmodules/util.py
```py path="/controlnet/ldm/modules/diffusionmodules/util.py"
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!
import os
import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat
from controlnet.ldm.util import instantiate_from_config
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """Build a beta (noise-variance) schedule and return it as a float64 numpy array.

    Supported schedules: "linear" (linear in sqrt-space), "cosine",
    "sqrt_linear" (plain linear) and "sqrt".
    """
    if schedule == "linear":
        # Interpolate between sqrt(linear_start) and sqrt(linear_end), then square.
        sqrt_betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64)
        betas = sqrt_betas ** 2
    elif schedule == "cosine":
        timesteps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        alphas = torch.cos(timesteps / (1 + cosine_s) * np.pi / 2).pow(2)
        alphas = alphas / alphas[0]
        # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}, clipped for stability.
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)
    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """Pick the subset of DDPM timesteps to use for DDIM-style sampling."""
    if ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.arange(0, num_ddpm_timesteps, stride)
    elif ddim_discr_method == 'quad':
        # Quadratic spacing: denser steps near t = 0.
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """Compute (sigma_t, alpha_t, alpha_{t-1}) for the chosen DDIM timesteps."""
    # select alphas for computing the variance schedule
    selected = alphacums[ddim_timesteps]
    # alpha one DDIM step earlier; alphacums[0] seeds the first step.
    prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according to the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt((1 - prev) / (1 - selected) * (1 - selected / prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {selected}; a_(t-1): {prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, selected, prev
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    # For each step, beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta.
    edges = [(i / num_diffusion_timesteps, (i + 1) / num_diffusion_timesteps)
             for i in range(num_diffusion_timesteps)]
    betas = [min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta) for t1, t2 in edges]
    return np.array(betas)
def extract_into_tensor(a, t, x_shape):
    """Gather a[t] per batch element, reshaped to broadcast against `x_shape`."""
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    # Append singleton dims so the result broadcasts over the non-batch axes of x.
    return gathered.reshape(batch, *((1,) * (len(x_shape) - 1)))
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if not flag:
        return func(*inputs)
    # Flatten inputs + params into one arg list; length marks the boundary.
    args = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *args)
class CheckpointFunction(torch.autograd.Function):
    """Gradient checkpointing: skip activation caching in forward, recompute
    the forward pass during backward.

    NOTE(review): uses the legacy ctx-attribute style and the deprecated
    `torch.cuda.amp.autocast` API; kept as-is for compatibility.
    """
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        # Split the flat arg list back into real inputs vs. extra parameters
        # the function closes over (kept so autograd can return their grads).
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # Record the autocast state so the recomputation matches this pass.
        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                                   "dtype": torch.get_autocast_gpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Re-run the forward pass with grad enabled on detached inputs.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad(), \
                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # Two leading Nones correspond to the (run_function, length) args.
        return (None, None) + input_grads
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, just tile the raw timestep value across dim.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)
    half = dim // 2
    # Geometric frequency ladder from 1 down to 1/max_period.
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd dims: pad with one zero column so the output is exactly `dim` wide.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
def scale_module(module, scale):
    """
    Scale the parameters of a module (in place) and return it.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.mul_(scale)
    return module
def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    non_batch_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=non_batch_dims)
def normalization(channels):
    """
    Make a standard normalization layer (32-group GroupNorm computed in fp32).

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    num_groups = 32
    return GroupNorm32(num_groups, channels)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    """Manual SiLU (x * sigmoid(x)) — fallback for PyTorch versions before 1.7."""

    def forward(self, x):
        return torch.sigmoid(x) * x
class GroupNorm32(nn.GroupNorm):
    """GroupNorm computed in float32, with the result cast back to the input dtype."""

    def forward(self, x):
        normalized = super().forward(x.float())
        return normalized.type(x.dtype)
class LayerNorm32(nn.LayerNorm):
    """LayerNorm computed in float32, with the result cast back to the input dtype."""

    def forward(self, x):
        normalized = super().forward(x.float())
        return normalized.type(x.dtype)
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    conv_types = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    if dims in conv_types:
        return conv_types[dims](*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
    """
    Create a linear module (thin alias for nn.Linear, kept for API symmetry
    with conv_nd / avg_pool_nd).
    """
    return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    pool_types = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}
    if dims in pool_types:
        return pool_types[dims](*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
class HybridConditioner(nn.Module):
    """Wraps two conditioning encoders: one for concat-conditioning, one for
    cross-attention conditioning."""

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        # Both sub-conditioners are instantiated from their configs via the
        # project-level factory helper.
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        c_concat = self.concat_conditioner(c_concat)
        c_crossattn = self.crossattn_conditioner(c_crossattn)
        # Values are wrapped in single-element lists — presumably the
        # multi-conditioning format expected downstream; verify against caller.
        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
def noise_like(shape, device, repeat=False):
    """Standard-normal noise of `shape`; with repeat=True, a single noise sample
    is drawn and tiled across the batch dimension."""
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        return single.repeat(shape[0], *((1,) * (len(shape) - 1)))
    return torch.randn(shape, device=device)
```
## /controlnet/ldm/modules/distributions/__init__.py
```py path="/controlnet/ldm/modules/distributions/__init__.py"
```
## /controlnet/ldm/modules/distributions/distributions.py
```py path="/controlnet/ldm/modules/distributions/distributions.py"
import torch
import numpy as np
class AbstractDistribution:
    """Minimal distribution interface: subclasses implement sample() and mode()."""

    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
    """Degenerate distribution with all probability mass on a single value."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        # Deterministic: sampling always yields the stored value.
        return self.value

    def mode(self):
        return self.value
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian parameterized by [mean, logvar] concatenated on dim 1.

    With deterministic=True the variance is forced to zero, so sample() and
    mode() both return the mean and KL/NLL are reported as zero.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp the log-variance before exponentiating, for numerical stability.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        noise = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * noise

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        if other is None:
            # KL(self || N(0, I)), summed over non-batch dims.
            return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                                   dim=[1, 2, 3])
        # KL(self || other) for two diagonal Gaussians.
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
                               dim=dims)

    def mode(self):
        return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    candidates = (mean1, logvar1, mean2, logvar2)
    tensor = next((obj for obj in candidates if isinstance(obj, torch.Tensor)), None)
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    if not isinstance(logvar1, torch.Tensor):
        logvar1 = torch.tensor(logvar1).to(tensor)
    if not isinstance(logvar2, torch.Tensor):
        logvar2 = torch.tensor(logvar2).to(tensor)

    return 0.5 * (
        logvar2
        - logvar1
        - 1.0
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
```
## /controlnet/ldm/modules/ema.py
```py path="/controlnet/ldm/modules/ema.py"
import torch
from torch import nn
class LitEma(nn.Module):
    """Maintains exponential-moving-average (EMA) shadow copies of a model's
    trainable parameters.

    Shadow values are registered as buffers so they are saved in the state
    dict and follow the module across devices/dtypes.

    Args:
        model: Module whose trainable parameters will be tracked.
        decay: Target EMA decay rate, in [0, 1].
        use_num_upates: When True (spelling kept as-is for checkpoint/config
            compatibility), the effective decay is warmed up over the first
            updates; when False, ``decay`` is used unchanged from the start.
    """

    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        # Maps each parameter name to its sanitized shadow-buffer name.
        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        # num_updates == -1 disables the warm-up schedule in forward().
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
                             else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        # Temporary storage used by store()/restore().
        self.collected_params = []

    def reset_num_updates(self):
        """Restart the decay warm-up schedule from zero updates."""
        del self.num_updates
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))

    def forward(self, model):
        """Blend the model's current trainable parameters into the shadows."""
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            # Warm-up: effective decay ramps from (1+n)/(10+n) toward
            # self.decay, so early updates track the model more closely.
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    # Keep the shadow's dtype in sync with the live parameter.
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    # In-place EMA step: shadow -= (1 - decay) * (shadow - param).
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    # Non-trainable params must never have a shadow entry.
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        """Overwrite the model's trainable parameters with the EMA values."""
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
```
## /controlnet/ldm/modules/encoders/__init__.py
```py path="/controlnet/ldm/modules/encoders/__init__.py"
```
## /controlnet/ldm/modules/encoders/modules.py
```py path="/controlnet/ldm/modules/encoders/modules.py"
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
import open_clip
from controlnet.ldm.util import default, count_params
# Per-channel RGB normalization statistics used by OpenAI CLIP's image
# preprocessing (applied in FrozenOpenCLIPEmbedder.img_forward).
OPENAI_CLIP_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073])
OPENAI_CLIP_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711])
class AbstractEncoder(nn.Module):
    """Base class for conditioning encoders; subclasses implement encode()."""

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
    """Pass-through encoder: conditioning is returned exactly as given."""

    def encode(self, x):
        return x
class ClassEmbedder(nn.Module):
    """Embeds integer class labels for cross-attention conditioning.

    The last class index (``n_classes - 1``) is reserved as the
    "unconditional" label for classifier-free guidance; during training,
    labels are replaced by it with probability ``ucg_rate``.
    """

    def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.ucg_rate = ucg_rate

    def forward(self, batch, key=None, disable_dropout=False):
        key = self.key if key is None else key
        # Add a singleton token axis so the output fits cross-attention.
        c = batch[key][:, None]
        if self.ucg_rate > 0. and not disable_dropout:
            # With probability ucg_rate, swap the label for the reserved
            # unconditional class (n_classes - 1).
            keep = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
            c = keep * c + (1 - keep) * torch.ones_like(c) * (self.n_classes - 1)
        return self.embedding(c.long())

    def get_unconditional_conditioning(self, bs, device="cuda"):
        # e.g. 1000 classes -> labels 0..999, with class 1000 reserved for ucg.
        uc_class = self.n_classes - 1
        uc = torch.ones((bs,), device=device) * uc_class
        return {self.key: uc}
def disabled_train(self, mode=True):
    """No-op replacement for ``nn.Module.train`` so a module's
    train/eval mode can never be changed once this is patched in."""
    return self
class FrozenT5Embedder(AbstractEncoder):
    """Uses the T5 transformer encoder for text"""

    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True):
        # other checkpoints: google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        if freeze:
            self.freeze()

    def freeze(self):
        """Put the encoder in eval mode and stop gradients for all weights."""
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        token_ids = encoded["input_ids"].to(self.device)
        # Return the full sequence of final-layer hidden states.
        return self.transformer(input_ids=token_ids).last_hidden_state

    def encode(self, text):
        return self(text)
class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    LAYERS = [
        "last",
        "pooled",
        "hidden"
    ]

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
                 freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            # Hidden-state extraction needs a valid layer index
            # (negative indices count from the end).
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        """Switch to eval mode and disable gradients for every weight."""
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        token_ids = encoded["input_ids"].to(self.device)
        # Only materialize all hidden states when they are actually needed.
        outputs = self.transformer(input_ids=token_ids, output_hidden_states=self.layer == "hidden")
        if self.layer == "pooled":
            # One pooled vector per prompt, with a singleton token axis.
            return outputs.pooler_output[:, None, :]
        if self.layer == "hidden":
            return outputs.hidden_states[self.layer_idx]
        return outputs.last_hidden_state

    def encode(self, text):
        return self(text)
class FrozenOpenCLIPEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    LAYERS = [
        #"pooled",
        "last",
        "penultimate"
    ]

    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
                 freeze=True, layer="last", embedding_type="text"):
        super().__init__()
        assert layer in self.LAYERS
        # Instantiate on CPU; moving to `device` is left to the caller.
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
        assert embedding_type in ["text", "visual"]
        self.embedding_type = embedding_type
        # Delete the unused tower so only the needed weights stay in memory.
        if embedding_type == "text":
            del model.visual
        elif embedding_type == "visual":
            del model.token_embedding
            del model.transformer
        self.model = model
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        # layer_idx = number of final transformer blocks to skip.
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        """Put the model in eval mode and disable gradients for all weights."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def img_forward(self, img):
        """Embed an image batch with the CLIP visual tower.

        NOTE(review): assumes `img` is in [-1, 1] (implied by the
        (img + 1) / 2 rescale below) with channels-first layout — confirm
        at call sites.
        """
        # Rescale to [0, 1], then apply OpenAI CLIP's channel normalization.
        img = ((img + 1.) / 2. - OPENAI_CLIP_MEAN[None, :, None, None].to(img.device)) / OPENAI_CLIP_STD[None, :, None, None].to(img.device)
        # Resize to the 224x224 input resolution of the visual tower.
        img = F.interpolate(img, (224, 224))
        return self.model.visual(img)[:, None, :]

    def text_forward(self, text):
        """Tokenize a batch of strings and run them through the text tower."""
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
        """Run the transformer blocks, stopping `layer_idx` blocks before the end."""
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
                # Trade compute for memory when gradient checkpointing is enabled.
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, x):
        # Dispatch on the configured modality (validated in __init__).
        if self.embedding_type == "text":
            return self.text_forward(x)
        elif self.embedding_type == "visual":
            return self.img_forward(x)
class FrozenCLIPT5Encoder(AbstractEncoder):
    """Dual text encoder returning both CLIP and T5 embeddings as a pair."""

    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
                 clip_max_length=77, t5_max_length=77):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
        print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
              f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")

    def forward(self, text):
        # Both encoders see the same prompts; results are kept separate.
        return [self.clip_encoder.encode(text), self.t5_encoder.encode(text)]

    def encode(self, text):
        return self(text)
```
## /controlnet/share.py
```py path="/controlnet/share.py"
# Shared import-time setup for the ControlNet scripts.
import config
from controlnet.cldm.hack import disable_verbosity, enable_sliced_attention

disable_verbosity()

# NOTE(review): presumably enables a memory-saving sliced attention path;
# confirm against controlnet/cldm/hack.py.
if config.save_memory:
    enable_sliced_attention()
```
## /data/assets/datasets/gen_data.npz
Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/data/assets/datasets/gen_data.npz
## /data/assets/flame/blink_blendshape.npy
Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/data/assets/flame/blink_blendshape.npy
## /data/assets/flame/flowface_vertex_mask.npy
Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/data/assets/flame/flowface_vertex_mask.npy
## /data/assets/flame/flowface_vertex_weights.npy
Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/data/assets/flame/flowface_vertex_weights.npy
## /data/assets/flame/jaw_regressor.npy
Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/data/assets/flame/jaw_regressor.npy
## /examples/input/animation/example_video.mp4
Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/examples/input/animation/example_video.mp4
## /examples/input/animation/sequence_00/fit.npz
Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/examples/input/animation/sequence_00/fit.npz
## /examples/input/animation/sequence_00/orbit.npz
Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/examples/input/animation/sequence_00/orbit.npz
## /examples/input/animation/sequence_01/fit.npz
Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/examples/input/animation/sequence_01/fit.npz
## /examples/input/animation/sequence_01/orbit.npz
Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/examples/input/animation/sequence_01/orbit.npz
## /examples/input/felix/alignment.npz
Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/examples/input/felix/alignment.npz
## /examples/input/felix/bg/cam0/0000.png
Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/examples/input/felix/bg/cam0/0000.png
## /examples/input/felix/bg/cam0/0001.png
Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/examples/input/felix/bg/cam0/0001.png
## /examples/input/felix/bg/cam0/0002.png
Binary file available at https://raw.githubusercontent.com/felixtaubner/cap4d/refs/heads/main/examples/input/felix/bg/cam0/0002.png
## /examples/input/lincoln/reference_images.json
```json path="/examples/input/lincoln/reference_images.json"
[
["cam0", 0]
]
```
The content has been capped at 50000 tokens. The user could consider applying other filters to refine the result. The better and more specific the context, the better the LLM can follow instructions. If the context seems verbose, the user can refine the filter using uithub. Thank you for using https://uithub.com - Perfect LLM context for any GitHub repo.