```
├── .gitignore (100 tokens)
├── LICENSE (omitted)
├── README.md (800 tokens)
├── add_ckpt_path.py
├── datasets_preprocess/
├── long_prepare_bonn.py (600 tokens)
├── long_prepare_kitti.py (400 tokens)
├── long_prepare_scannet.py (600 tokens)
├── long_prepare_tum.py (800 tokens)
├── path_to_root.py (100 tokens)
├── demo.py (3.8k tokens)
├── eval/
├── eval.md (500 tokens)
├── mv_recon/
├── base.py (2k tokens)
├── criterion.py (3.3k tokens)
├── data.py (3.6k tokens)
├── dataset_utils/
├── __init__.py
├── corr.py (700 tokens)
├── cropping.py (900 tokens)
├── transforms.py (500 tokens)
├── launch.py (3.5k tokens)
├── run.sh (100 tokens)
├── utils.py (400 tokens)
├── relpose/
├── evo_utils.py (3.1k tokens)
├── launch.py (3.5k tokens)
├── metadata.py (2.1k tokens)
├── run_scannet.sh (200 tokens)
├── run_sintel.sh (100 tokens)
├── run_tum.sh (200 tokens)
├── utils.py (1800 tokens)
├── video_depth/
├── eval_depth.py (3.1k tokens)
├── launch.py (2.4k tokens)
├── metadata.py (1600 tokens)
├── run_bonn.sh (200 tokens)
├── run_kitti.sh (200 tokens)
├── run_sintel.sh (200 tokens)
├── tools.py (2.7k tokens)
├── utils.py (1400 tokens)
├── examples/
├── taylor.mp4
├── westlake.mp4
├── requirements.txt
├── src/
├── croco/
├── LICENSE (400 tokens)
├── NOTICE (100 tokens)
├── README.MD (1600 tokens)
├── assets/
├── Chateau1.png
├── Chateau2.png
├── arch.jpg
├── croco-stereo-flow-demo.ipynb (1000 tokens)
├── datasets/
├── __init__.py
├── crops/
├── README.MD (600 tokens)
├── extract_crops_from_images.py (1100 tokens)
├── habitat_sim/
├── README.MD (700 tokens)
├── __init__.py
├── generate_from_metadata.py (900 tokens)
├── generate_from_metadata_files.py (300 tokens)
├── generate_multiview_images.py (1700 tokens)
├── multiview_habitat_sim_generator.py (4k tokens)
├── pack_metadata_files.py (600 tokens)
├── paths.py (1200 tokens)
├── pairs_dataset.py (1100 tokens)
├── transforms.py (900 tokens)
├── interactive_demo.ipynb (2.1k tokens)
├── models/
├── blocks.py (2.5k tokens)
├── criterion.py (300 tokens)
├── croco.py (2.4k tokens)
├── croco_downstream.py (1000 tokens)
├── curope/
├── __init__.py
├── curope.cpp (500 tokens)
├── curope2d.py (300 tokens)
├── kernels.cu (800 tokens)
├── setup.py (200 tokens)
├── dpt_block.py (3k tokens)
├── head_downstream.py (600 tokens)
├── masking.py (100 tokens)
├── pos_embed.py (1400 tokens)
├── pretrain.py (2.5k tokens)
├── stereoflow/
├── README.MD (2.5k tokens)
├── augmentor.py (2.9k tokens)
├── criterion.py (2.6k tokens)
├── datasets_flow.py (6.7k tokens)
├── datasets_stereo.py (7.2k tokens)
├── download_model.sh (100 tokens)
├── engine.py (2.6k tokens)
├── test.py (2.1k tokens)
├── train.py (3k tokens)
├── utils/
├── misc.py (4k tokens)
├── dust3r/
├── __init__.py
├── blocks.py (3.5k tokens)
├── datasets/
├── __init__.py (500 tokens)
├── arkitscenes.py (1800 tokens)
├── arkitscenes_highres.py (1300 tokens)
├── base/
├── __init__.py
├── base_multiview_dataset.py (4.1k tokens)
├── batched_sampler.py (600 tokens)
├── easy_dataset.py (1200 tokens)
├── bedlam.py (2.9k tokens)
├── blendedmvs.py (2.4k tokens)
├── co3d.py (1500 tokens)
├── cop3d.py (800 tokens)
├── dl3dv.py (1100 tokens)
├── dynamic_replica.py (900 tokens)
├── eden.py (600 tokens)
├── hoi4d.py (600 tokens)
├── hypersim.py (1000 tokens)
├── irs.py (600 tokens)
├── mapfree.py (2.1k tokens)
├── megadepth.py (700 tokens)
├── mp3d.py (1000 tokens)
├── mvimgnet.py (1000 tokens)
├── mvs_synth.py (1000 tokens)
├── omniobject3d.py (1000 tokens)
├── pointodyssey.py (1200 tokens)
├── realestate10k.py (1000 tokens)
├── scannet.py (1000 tokens)
├── scannetpp.py (1400 tokens)
├── smartportraits.py (600 tokens)
├── spring.py (900 tokens)
├── synscapes.py (600 tokens)
├── tartanair.py (1100 tokens)
├── threedkb.py (700 tokens)
├── uasol.py (1000 tokens)
├── unreal4k.py (1000 tokens)
├── urbansyn.py (600 tokens)
├── utils/
├── __init__.py
├── corr.py (800 tokens)
├── cropping.py (900 tokens)
├── transforms.py (600 tokens)
├── vkitti2.py (1100 tokens)
├── waymo.py (1200 tokens)
├── wildrgbd.py (400 tokens)
├── heads/
├── __init__.py (300 tokens)
├── dpt_head.py (1800 tokens)
├── linear_head.py (2.5k tokens)
├── postprocess.py (900 tokens)
├── inference.py (2.6k tokens)
├── losses.py (8.7k tokens)
├── model.py (10.2k tokens)
├── patch_embed.py (600 tokens)
├── post_process.py (400 tokens)
├── utils/
├── __init__.py
├── camera.py (2.9k tokens)
├── device.py (600 tokens)
├── geometry.py (3.6k tokens)
├── image.py (1800 tokens)
├── misc.py (800 tokens)
├── parallel.py (500 tokens)
├── path_to_croco.py (100 tokens)
├── render.py (500 tokens)
├── viz.py (6.7k tokens)
├── train.py (6.3k tokens)
├── viser_utils.py (6.5k tokens)
```
## /.gitignore
```gitignore path="/.gitignore"
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
# C extensions
*.so
# Distribution / packaging
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
.tox/
.coverage
.cache
nosetests.xml
coverage.xml
# Translations
*.mo
# Mr Developer
.mr.developer.cfg
.project
.pydevproject
# Rope
.ropeproject
# Django stuff:
*.log
*.pot
# Sphinx documentation
docs/_build/
# Ignore data and ckpts
*.pth
data
src/checkpoints
# Ignore results
tmp
eval_results
.vscode
```
## /README.md
<h2 align="center"> <a href="https://rover-xingyu.github.io/TTT3R">TTT3R: 3D Reconstruction as Test-Time Training</a>
</h2>
<h5 align="center">
[](https://arxiv.org/abs/2509.26645)
[](https://rover-xingyu.github.io/TTT3R)
[](https://x.com/RoverXingyu) [](https://bsky.app/profile/xingyu-chen.bsky.social)
[Xingyu Chen](https://rover-xingyu.github.io/),
[Yue Chen](https://fanegg.github.io/),
[Yuliang Xiu](https://xiuyuliang.cn/),
[Andreas Geiger](https://www.cvlibs.net/),
[Anpei Chen](https://apchenstu.github.io/)
</h5>
<div align="center">
TL;DR: A simple state update rule to enhance length generalization for CUT3R.
</div>
<br>
https://github.com/user-attachments/assets/b7583837-1f1e-43a4-b281-09f340ee5918
## Getting Started
### Installation
1. Clone TTT3R.
```bash
git clone https://github.com/Inception3D/TTT3R.git
cd TTT3R
```
2. Create the environment.
```bash
conda create -n ttt3r python=3.11 cmake=3.14.0
conda activate ttt3r
conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia # use the correct version of cuda for your system
pip install -r requirements.txt
# issues with pytorch dataloader, see https://github.com/pytorch/pytorch/issues/99625
conda install 'llvm-openmp<16'
# for evaluation
pip install evo
pip install open3d
```
3. Compile the cuda kernels for RoPE (as in CroCo v2).
```bash
cd src/croco/models/curope/
python setup.py build_ext --inplace
cd ../../../../
```
### Download Checkpoints
CUT3R provides a checkpoint trained on 4-64 views: [`cut3r_512_dpt_4_64.pth`](https://drive.google.com/file/d/1Asz-ZB3FfpzZYwunhQvNPZEUA8XUNAYD/view?usp=drive_link).
To download the weights, run the following commands:
```bash
cd src
gdown --fuzzy https://drive.google.com/file/d/1Asz-ZB3FfpzZYwunhQvNPZEUA8XUNAYD/view?usp=drive_link
cd ..
```
### Inference Demo
To run the inference demo, you can use the following command:
```bash
# input can be a folder or a video
# the following script will run inference with TTT3R and visualize the output with viser on port 8080
CUDA_VISIBLE_DEVICES=6 python demo.py --model_path MODEL_PATH --size 512 \
--seq_path SEQ_PATH --output_dir OUT_DIR --port 8080 \
--model_update_type ttt3r --frame_interval 1 --reset_interval 100 \
--downsample_factor 1000 --vis_threshold 5.0
# Example:
CUDA_VISIBLE_DEVICES=6 python demo.py --model_path src/cut3r_512_dpt_4_64.pth --size 512 \
--seq_path examples/westlake.mp4 --output_dir tmp/westlake --port 8080 \
--model_update_type ttt3r --frame_interval 1 --reset_interval 100 \
--downsample_factor 100 --vis_threshold 6.0
CUDA_VISIBLE_DEVICES=6 python demo.py --model_path src/cut3r_512_dpt_4_64.pth --size 512 \
--seq_path examples/taylor.mp4 --output_dir tmp/taylor --port 8080 \
--model_update_type ttt3r --frame_interval 1 --reset_interval 50 \
--downsample_factor 100 --vis_threshold 10.0
```
Output results will be saved to `output_dir`.
### Evaluation
Please refer to the [eval.md](eval/eval.md) for more details.
## Acknowledgements
Our code is based on the following awesome repositories:
- [CUT3R](https://github.com/CUT3R/CUT3R)
- [Easi3R](https://github.com/Inception3D/Easi3R)
- [DUSt3R](https://github.com/naver/dust3r)
- [MonST3R](https://github.com/Junyi42/monst3r.git)
- [Spann3R](https://github.com/HengyiWang/spann3r.git)
- [Viser](https://github.com/nerfstudio-project/viser)
We thank the authors for releasing their code!
## Citation
If you find our work useful, please cite:
```bibtex
@article{chen2025ttt3r,
title={TTT3R: 3D Reconstruction as Test-Time Training},
author={Chen, Xingyu and Chen, Yue and Xiu, Yuliang and Geiger, Andreas and Chen, Anpei},
journal={arXiv preprint arXiv:2509.26645},
year={2025}
}
```
## /add_ckpt_path.py
```py path="/add_ckpt_path.py"
import sys
import os
import os.path as path
def add_path_to_dust3r(ckpt):
    """Prepend the checkpoint's directory to ``sys.path``.

    This lets modules that live next to the checkpoint (e.g. the ``src``
    package) be imported without installing them.

    Args:
        ckpt: Path to a checkpoint file; its parent directory is added.
    """
    # workaround for sibling import
    checkpoint_dir = path.dirname(path.abspath(ckpt))
    sys.path.insert(0, checkpoint_dir)
```
## /datasets_preprocess/long_prepare_bonn.py
```py path="/datasets_preprocess/long_prepare_bonn.py"
import glob
import os
import shutil
import numpy as np
START_FRAME = 30  # initial frame: skip this many frames at the head of every sequence

# Build truncated copies of each Bonn RGB-D sequence at several target lengths,
# keeping RGB, depth, and groundtruth trajectories frame-aligned.
for TARGET_FRAMES in [50, 100, 150, 200, 250, 300, 350, 400, 450, 500]:
    END_FRAME = START_FRAME + TARGET_FRAMES  # end frame (exclusive)
    # NOTE(review): source path is hard-coded for a specific machine — adjust
    # before running elsewhere.
    dirs = sorted(glob.glob("/home/xingyu/monst3r/data/bonn/rgbd_bonn_dataset/*/"))
    # create new base directory
    base_new_dir = "./data/long_bonn_s1/rgbd_bonn_dataset/"
    os.makedirs(base_new_dir, exist_ok=True)
    print(f"specified frame range: {START_FRAME} to {END_FRAME}, target frames: {TARGET_FRAMES}")
    # extract frames
    for dir in dirs:
        # get original directory name
        dir_name = os.path.basename(os.path.dirname(dir))
        # build new directory path
        new_base_dir = base_new_dir + dir_name + '/'
        # pre-calculate the actual available frames for each modality
        rgb_frames = sorted(glob.glob(dir + 'rgb/*.png'))
        available_rgb_frames = len(rgb_frames)
        depth_frames = sorted(glob.glob(dir + 'depth/*.png'))
        available_depth_frames = len(depth_frames)
        gt_path = dir + "groundtruth.txt"
        gt = np.loadtxt(gt_path)
        available_gt_frames = len(gt)
        # frames actually available in [START_FRAME, START_FRAME + TARGET_FRAMES)
        actual_rgb_frames = min(available_rgb_frames - START_FRAME, TARGET_FRAMES)
        actual_depth_frames = min(available_depth_frames - START_FRAME, TARGET_FRAMES)
        actual_gt_frames = min(available_gt_frames - START_FRAME, TARGET_FRAMES)
        # take the minimum over the three modalities so they stay in sync
        final_frame_count = min(actual_rgb_frames, actual_depth_frames, actual_gt_frames)
        print(f" final unified frame count: {final_frame_count}")
        # process RGB frames: recreate the target dir, then copy.
        # makedirs is hoisted out of the copy loop (it was re-invoked per
        # frame) so the directory also exists when final_frame_count == 0.
        rgb_frames = rgb_frames[START_FRAME:START_FRAME + final_frame_count]
        new_dir = new_base_dir + f'rgb_{TARGET_FRAMES}/'
        if os.path.exists(new_dir):
            shutil.rmtree(new_dir)
        os.makedirs(new_dir, exist_ok=True)
        for frame in rgb_frames:
            shutil.copy(frame, new_dir)
        # process Depth frames
        depth_frames = depth_frames[START_FRAME:START_FRAME + final_frame_count]
        new_dir = new_base_dir + f'depth_{TARGET_FRAMES}/'
        if os.path.exists(new_dir):
            shutil.rmtree(new_dir)
        os.makedirs(new_dir, exist_ok=True)
        for frame in depth_frames:
            shutil.copy(frame, new_dir)
        # process Groundtruth: overwrite any stale file from a previous run
        gt_final = gt[START_FRAME:START_FRAME + final_frame_count]
        gt_file = new_base_dir + f'groundtruth_{TARGET_FRAMES}.txt'
        if os.path.exists(gt_file):
            os.remove(gt_file)
        np.savetxt(gt_file, gt_final)
```
## /datasets_preprocess/long_prepare_kitti.py
```py path="/datasets_preprocess/long_prepare_kitti.py"
from PIL import Image
import numpy as np
def depth_read(filename):
    """Load a KITTI 16-bit depth map PNG as a float array.

    Pixel value 0 marks missing depth and is mapped to -1; all other values
    are divided by 256 to obtain depth in meters (for details see the KITTI
    devkit readme.txt).

    Args:
        filename: Path to a 16-bit depth PNG.

    Returns:
        np.ndarray of float64 depth values, with -1.0 for invalid pixels.
    """
    depth_png = np.array(Image.open(filename), dtype=int)
    # make sure we have a proper 16bit depth map here.. not 8bit!
    assert np.max(depth_png) > 255
    # np.float was removed in NumPy 1.20+; use np.float64 explicitly.
    depth = depth_png.astype(np.float64) / 256.0
    depth[depth_png == 0] = -1.0
    return depth
import glob
import os
import shutil
# Build length-truncated copies of the KITTI depth validation sequences.
# NOTE(review): both source and destination paths are hard-coded for a
# specific machine layout — adjust before running elsewhere.
for TARGET_FRAMES in [50,100,150,200,250,300,350,400,450,500]:
    depth_dirs = glob.glob("/home/xingyu/monst3r/data/kitti/val/*/proj_depth/groundtruth/image_02")
    for dir in depth_dirs:
        # new depth dir (named after the drive, e.g. "<drive>_02")
        new_depth_dir = f"./data/long_kitti_s1/depth_selection/val_selection_cropped/groundtruth_depth_gathered_{TARGET_FRAMES}/" + dir.split("/")[-4]+"_02"
        # print(new_depth_dir)
        new_image_dir = f"./data/long_kitti_s1/depth_selection/val_selection_cropped/image_gathered_{TARGET_FRAMES}/" + dir.split("/")[-4]+"_02"
        os.makedirs(new_depth_dir, exist_ok=True)
        os.makedirs(new_image_dir, exist_ok=True)
        # get all depth files and calculate the actual frame count
        all_depth_files = sorted(glob.glob(dir + "/*.png"))
        actual_frames = min(len(all_depth_files), TARGET_FRAMES)
        print(f"directory {dir.split('/')[-4]}: target frames {TARGET_FRAMES}, actual available frames {len(all_depth_files)}, actual processed frames {actual_frames}")
        # copy at most TARGET_FRAMES depth maps (slicing caps at the list
        # length, so this processes exactly actual_frames files)
        for depth_file in all_depth_files[:TARGET_FRAMES]:
            new_path = new_depth_dir + "/" + depth_file.split("/")[-1]
            shutil.copy(depth_file, new_path)
            # get the path of the corresponding image: derive the date prefix
            # (e.g. "2011_09_26") from the drive directory name in the path
            mid = "_".join(depth_file.split("/")[-5].split("_")[:3])
            image_file = depth_file.replace('val', mid).replace('proj_depth/groundtruth/image_02', 'image_02/data')
            print(image_file)
            # check if the image file exists
            if os.path.exists(image_file):
                new_path = new_image_dir + "/" + image_file.split("/")[-1]
                shutil.copy(image_file, new_path)
            else:
                print("Image file does not exist: ", image_file)
```
## /datasets_preprocess/long_prepare_scannet.py
```py path="/datasets_preprocess/long_prepare_scannet.py"
import glob
import os
import shutil
import numpy as np
# configurable parameters
# Build subsampled, length-truncated copies of the ScanNet test scenes.
# NOTE(review): source path is hard-coded for a specific machine — adjust
# before running elsewhere.
for TARGET_FRAMES in [50,90,100,150,200,300,400,500,600,700,800,900,1000]:
    SAMPLE_INTERVAL = 3  # sampling interval, take 1 frame every N frames original 3
    seq_list = sorted(os.listdir("/home/share/Dataset/3D_scene/ScanNet/scans_test/"))
    for seq in seq_list:
        # frames are named "<index>.<ext>"; sort numerically by that index
        img_pathes = sorted(glob.glob(f"/home/share/Dataset/3D_scene/ScanNet/scans_test/{seq}/color/*.jpg"), key=lambda x: int(os.path.basename(x).split('.')[0]))
        depth_pathes = sorted(glob.glob(f"/home/share/Dataset/3D_scene/ScanNet/scans_test/{seq}/depth/*.png"), key=lambda x: int(os.path.basename(x).split('.')[0]))
        pose_pathes = sorted(glob.glob(f"/home/share/Dataset/3D_scene/ScanNet/scans_test/{seq}/pose/*.txt"), key=lambda x: int(os.path.basename(x).split('.')[0]))
        # calculate the required original frame count
        # (required_frames is informational; the slice below does the capping)
        required_frames = TARGET_FRAMES * SAMPLE_INTERVAL
        total_frames = min(len(img_pathes), len(depth_pathes), len(pose_pathes))
        # if the original frame count is not enough, adjust the target frame count
        actual_target_frames = min(TARGET_FRAMES, total_frames // SAMPLE_INTERVAL)
        print(f"{seq}: original frame count {total_frames}, target frames {TARGET_FRAMES}, actual frames {actual_target_frames}")
        # use target frame count to name the directory
        new_color_dir = f"./data/long_scannet_s{SAMPLE_INTERVAL}/{seq}/color_{TARGET_FRAMES}"
        new_depth_dir = f"./data/long_scannet_s{SAMPLE_INTERVAL}/{seq}/depth_{TARGET_FRAMES}"
        # sample according to the target frame count (stride SAMPLE_INTERVAL)
        new_img_pathes = img_pathes[:actual_target_frames*SAMPLE_INTERVAL:SAMPLE_INTERVAL]
        new_depth_pathes = depth_pathes[:actual_target_frames*SAMPLE_INTERVAL:SAMPLE_INTERVAL]
        new_pose_pathes = pose_pathes[:actual_target_frames*SAMPLE_INTERVAL:SAMPLE_INTERVAL]
        # if the target directory exists, delete it so reruns start clean
        if os.path.exists(new_color_dir):
            shutil.rmtree(new_color_dir)
        if os.path.exists(new_depth_dir):
            shutil.rmtree(new_depth_dir)
        os.makedirs(new_color_dir, exist_ok=True)
        os.makedirs(new_depth_dir, exist_ok=True)
        # copy frames, renaming them to a contiguous frame_%04d index
        for i, (img_path, depth_path) in enumerate(zip(new_img_pathes, new_depth_pathes)):
            shutil.copy(img_path, f"{new_color_dir}/frame_{i:04d}.jpg")
            shutil.copy(depth_path, f"{new_depth_dir}/frame_{i:04d}.png")
        # use target frame count to name the pose file; write one flattened
        # pose (row-major) per line
        pose_new_path = f"./data/long_scannet_s{SAMPLE_INTERVAL}/{seq}/pose_{TARGET_FRAMES}.txt"
        with open(pose_new_path, 'w') as f:
            for i, pose_path in enumerate(new_pose_pathes):
                with open(pose_path, 'r') as pose_file:
                    pose = np.loadtxt(pose_file)
                pose = pose.reshape(-1)
                f.write(f"{' '.join(map(str, pose))}\n")
```
## /datasets_preprocess/long_prepare_tum.py
```py path="/datasets_preprocess/long_prepare_tum.py"
import glob
import os
import shutil
import numpy as np
def read_file_list(filename):
    """Read a TUM-format trajectory/association file.

    File format:
        Each non-comment line is "stamp d1 d2 d3 ...", where stamp denotes
        the time stamp (to be matched) and "d1 d2 d3 ..." is arbitrary data
        (e.g., a 3D position and 3D orientation) associated to this
        timestamp. Lines starting with '#' are ignored.

    Args:
        filename: Path to the text file.

    Returns:
        dict mapping float timestamp -> list of data strings.
    """
    # Use a context manager so the file handle is always closed (the
    # original left it open).
    with open(filename) as file:
        data = file.read()
    lines = data.replace(",", " ").replace("\t", " ").split("\n")
    # Tokenize each non-empty, non-comment line.
    rows = [
        [v.strip() for v in line.split(" ") if v.strip() != ""]
        for line in lines
        if len(line) > 0 and line[0] != "#"
    ]
    # Keep only rows carrying data beyond the timestamp; also avoid
    # shadowing the builtin `list` as the original did.
    entries = [(float(row[0]), row[1:]) for row in rows if len(row) > 1]
    return dict(entries)
def associate(first_list, second_list, offset, max_difference):
    """Greedily pair timestamps from two (stamp -> data) dictionaries.

    As the time stamps never match exactly, each stamp in ``first_list`` is
    paired with the closest stamp in ``second_list`` (after shifting by
    ``offset``), closest pairs first, and every stamp is used at most once.

    Args:
        first_list: first dictionary of (stamp, data) entries.
        second_list: second dictionary of (stamp, data) entries.
        offset: time offset added to the second dictionary's stamps
            (e.g., to model the delay between the sensors).
        max_difference: search radius for candidate generation.

    Returns:
        List of (stamp1, stamp2) tuples, sorted by stamp1.
    """
    remaining_first = set(first_list.keys())
    remaining_second = set(second_list.keys())

    # Enumerate every candidate pair within the search radius, ranked by
    # absolute time difference so the best matches are claimed first.
    candidates = sorted(
        (abs(a - (b + offset)), a, b)
        for a in remaining_first
        for b in remaining_second
        if abs(a - (b + offset)) < max_difference
    )

    matches = []
    for _, a, b in candidates:
        # Claim this pair only if neither stamp was matched already.
        if a in remaining_first and b in remaining_second:
            remaining_first.remove(a)
            remaining_second.remove(b)
            matches.append((a, b))

    matches.sort()
    return matches
# create new output directory
output_base_dir = "./data/long_tum_s1/"
SAMPLE_INTERVAL = 1  # sampling interval, take 1 frame every N frames (originally 3)
os.makedirs(output_base_dir, exist_ok=True)
# NOTE(review): source path is hard-coded for a specific machine — adjust
# before running elsewhere.
dirs = glob.glob("/home/xingyu/monst3r/data/tum/*/")
dirs = sorted(dirs)
# extract frames
# total_frames_list = []
for TARGET_FRAMES in [50, 100, 150, 200, 300, 400, 500, 600, 700, 800, 900, 1000]:
    for dir in dirs:
        frames = []
        gt = []
        # associate each RGB timestamp with its closest groundtruth timestamp
        first_file = dir + 'rgb.txt'
        second_file = dir + 'groundtruth.txt'
        first_list = read_file_list(first_file)
        second_list = read_file_list(second_file)
        matches = associate(first_list, second_list, 0.0, 0.02)
        # for a,b in matches[:10]:
        #     print("%f %s %f %s"%(a," ".join(first_list[a]),b," ".join(second_list[b])))
        for a,b in matches:
            frames.append(dir + first_list[a][0])
            gt.append([b]+second_list[b])
        # subsample at SAMPLE_INTERVAL and keep the first TARGET_FRAMES frames
        print(f"process {dir} with {len(frames)} frames")
        frames = frames[::SAMPLE_INTERVAL][:TARGET_FRAMES]
        # get original directory name as subdirectory name
        dir_name = os.path.basename(os.path.dirname(dir))
        new_dir = output_base_dir + dir_name + f'/rgb_{TARGET_FRAMES}/'
        for frame in frames:
            os.makedirs(new_dir, exist_ok=True)
            shutil.copy(frame, new_dir)
            # print(f'cp {frame} {new_dir}')
        # matching subsampled groundtruth poses (the gt_90 name is a leftover
        # from an earlier fixed 90-frame version of this script)
        gt_90 = gt[::SAMPLE_INTERVAL][:TARGET_FRAMES]
        gt_output_file = output_base_dir + dir_name + f'/groundtruth_{TARGET_FRAMES}.txt'
        with open(gt_output_file, 'w') as f:
            for pose in gt_90:
                f.write(f"{' '.join(map(str, pose))}\n")
        print(f"get {dir_name} with {len(frames)} frames")
```
## /datasets_preprocess/path_to_root.py
```py path="/datasets_preprocess/path_to_root.py"
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# DUSt3R repo root import
# --------------------------------------------------------
import sys
import os.path as path

# Resolve this script's directory and the repository root one level up.
HERE_PATH = path.normpath(path.dirname(__file__))
DUST3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../'))
# workaround for sibling import: make repo-root packages importable when the
# preprocessing scripts are run directly from this directory
sys.path.insert(0, DUST3R_REPO_PATH)
```
## /demo.py
```py path="/demo.py"
#!/usr/bin/env python3
"""
3D Point Cloud Inference and Visualization Script
This script performs inference using the ARCroco3DStereo model and visualizes the
resulting 3D point clouds with the PointCloudViewer. Use the command-line arguments
to adjust parameters such as the model checkpoint path, image sequence directory,
image size, device, etc.
Usage:
python demo.py [--model_path MODEL_PATH] [--seq_path SEQ_PATH] [--size IMG_SIZE]
[--device DEVICE] [--vis_threshold VIS_THRESHOLD] [--output_dir OUT_DIR]
Example:
python demo.py --model_path src/cut3r_512_dpt_4_64.pth \
--seq_path examples/001 --device cuda --size 512
"""
import os
import numpy as np
import torch
import time
import glob
import random
import cv2
import argparse
import tempfile
import shutil
from copy import deepcopy
from add_ckpt_path import add_path_to_dust3r
import imageio.v2 as iio
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from sklearn.decomposition import PCA
import datetime
from tqdm import tqdm
from skimage.filters import threshold_otsu, threshold_multiotsu
from einops import rearrange
# Set random seed for reproducibility.
random.seed(42)
framerate = 30
def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace: the demo configuration (model checkpoint,
        input sequence, device, visualization and state-reset options).
    """
    parser = argparse.ArgumentParser(
        description="Run 3D point cloud inference and visualization using ARCroco3DStereo."
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="src/cut3r_512_dpt_4_64.pth",
        help="Path to the pretrained model checkpoint.",
    )
    parser.add_argument(
        "--seq_path",
        type=str,
        default="",
        help="Path to the directory containing the image sequence.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to run inference on (e.g., 'cuda' or 'cpu').",
    )
    parser.add_argument(
        "--size",
        type=int,
        # was default="512" (a string); use an int so the default matches
        # the declared type directly instead of relying on argparse's
        # string-default coercion
        default=512,
        help="Shape that input images will be rescaled to; if using 224+linear model, choose 224 otherwise 512",
    )
    parser.add_argument(
        "--vis_threshold",
        type=float,
        default=1.5,
        help="Visualization threshold for the point cloud viewer. Ranging from 1 to INF",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./demo_tmp",
        help="value for tempfile.tempdir",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="port for the point cloud viewer",
    )
    parser.add_argument(
        "--model_update_type",
        type=str,
        default="cut3r",
        help="model update type: cut3r or ttt3r",
    )
    parser.add_argument(
        "--frame_interval",
        type=int,
        default=1,
        help="Frame interval for video processing (e.g., 1 means every frame, 2 means every other frame)",
    )
    parser.add_argument(
        "--reset_interval",
        type=int,
        default=1000000,
        help="Only used for demo, reset state for extremely long sequence, chunks are aligned via global camera poses",
    )
    parser.add_argument(
        "--downsample_factor",
        type=int,
        default=1,
        help="Downsample factor for the point cloud viewer",
    )
    return parser.parse_args()
def prepare_input(
    img_paths, img_mask, size, raymaps=None, raymap_mask=None, revisit=1, update=True, reset_interval=10000
):
    """
    Prepare input views for inference from a list of image paths.
    Args:
        img_paths (list): List of image file paths.
        img_mask (list of bool): Flags indicating valid images.
        size (int): Target image size.
        raymaps (list, optional): List of ray maps.
        raymap_mask (list, optional): Flags indicating valid ray maps.
        revisit (int): How many times to revisit each view.
        update (bool): Whether to update the state on revisits.
        reset_interval (int): Every reset_interval-th view is marked with
            reset=True and followed by a duplicate "overlap" view, so chunks
            of a long sequence share one frame for later alignment.
    Returns:
        list: A list of view dictionaries.
    """
    # Import image loader (delayed import needed after adding ckpt path).
    from src.dust3r.utils.image import load_images

    images = load_images(img_paths, size=size)
    views = []
    if raymaps is None and raymap_mask is None:
        # Only images are provided.
        for i in range(len(images)):
            view = {
                "img": images[i]["img"],
                # No camera rays available: fill the 6-channel ray map with NaNs.
                "ray_map": torch.full(
                    (
                        images[i]["img"].shape[0],
                        6,
                        images[i]["img"].shape[-2],
                        images[i]["img"].shape[-1],
                    ),
                    torch.nan,
                ),
                "true_shape": torch.from_numpy(images[i]["true_shape"]),
                "idx": i,
                "instance": str(i),
                # Identity placeholder pose; actual poses come from the model.
                "camera_pose": torch.from_numpy(np.eye(4, dtype=np.float32)).unsqueeze(
                    0
                ),
                "img_mask": torch.tensor(True).unsqueeze(0),
                "ray_mask": torch.tensor(False).unsqueeze(0),
                "update": torch.tensor(True).unsqueeze(0),
                "reset": torch.tensor((i+1) % reset_interval == 0).unsqueeze(0),
            }
            views.append(view)
            # Right after a reset view, append a copy of it with reset=False so
            # consecutive chunks overlap by exactly one frame.
            if (i+1) % reset_interval == 0:
                overlap_view = deepcopy(view)
                overlap_view["reset"] = torch.tensor(False).unsqueeze(0)
                views.append(overlap_view)
    else:
        # Combine images and raymaps.
        num_views = len(images) + len(raymaps)
        assert len(img_mask) == len(raymap_mask) == num_views
        assert sum(img_mask) == len(images) and sum(raymap_mask) == len(raymaps)
        j = 0  # index into images
        k = 0  # index into raymaps
        for i in range(num_views):
            view = {
                "img": (
                    images[j]["img"]
                    if img_mask[i]
                    else torch.full_like(images[0]["img"], torch.nan)
                ),
                "ray_map": (
                    raymaps[k]
                    if raymap_mask[i]
                    else torch.full_like(raymaps[0], torch.nan)
                ),
                "true_shape": (
                    torch.from_numpy(images[j]["true_shape"])
                    if img_mask[i]
                    # derive (H, W) from the raymap shape when no image exists
                    else torch.from_numpy(np.int32([raymaps[k].shape[1:-1][::-1]]))
                ),
                "idx": i,
                "instance": str(i),
                "camera_pose": torch.from_numpy(np.eye(4, dtype=np.float32)).unsqueeze(
                    0
                ),
                "img_mask": torch.tensor(img_mask[i]).unsqueeze(0),
                "ray_mask": torch.tensor(raymap_mask[i]).unsqueeze(0),
                "update": torch.tensor(img_mask[i]).unsqueeze(0),
                "reset": torch.tensor((i+1) % reset_interval == 0).unsqueeze(0),
            }
            if img_mask[i]:
                j += 1
            if raymap_mask[i]:
                k += 1
            views.append(view)
            if (i+1) % reset_interval == 0:
                overlap_view = deepcopy(view)
                overlap_view["reset"] = torch.tensor(False).unsqueeze(0)
                views.append(overlap_view)
        assert j == len(images) and k == len(raymaps)
    if revisit > 1:
        # Repeat the whole sequence `revisit` times with fresh indices; on
        # passes after the first, optionally freeze state updates.
        new_views = []
        for r in range(revisit):
            for i, view in enumerate(views):
                new_view = deepcopy(view)
                new_view["idx"] = r * len(views) + i
                new_view["instance"] = str(r * len(views) + i)
                if r > 0 and not update:
                    new_view["update"] = torch.tensor(False).unsqueeze(0)
                new_views.append(new_view)
        return new_views
    return views
def prepare_output(outputs, outdir, revisit=1, use_pose=True):
    """
    Process inference outputs to generate point clouds and camera parameters for visualization.
    Also saves per-frame depth/conf/color/camera files under `outdir`.
    Args:
        outputs (dict): Inference outputs with "pred" and "views" lists.
        outdir (str): Output directory for the saved per-frame files.
        revisit (int): Number of revisits per view.
        use_pose (bool): Whether to transform points using camera pose.
    Returns:
        tuple: (points, colors, confidence, camera parameters dictionary)
    """
    from src.dust3r.utils.camera import pose_encoding_to_camera
    from src.dust3r.post_process import estimate_focal_knowing_depth
    from src.dust3r.utils.geometry import geotrf, matrix_cumprod
    import roma
    from viser_utils import convert_scene_output_to_glb

    # Only keep the outputs corresponding to one full pass.
    valid_length = len(outputs["pred"]) // revisit
    outputs["pred"] = outputs["pred"][-valid_length:]
    outputs["views"] = outputs["views"][-valid_length:]
    # Delete overlaps: drop the duplicated view that directly follows each
    # reset (reset=True) entry from outputs["pred"] and outputs["views"].
    reset_mask = torch.cat([view["reset"] for view in outputs["views"]], 0)
    shifted_reset_mask = torch.cat([torch.tensor(False).unsqueeze(0), reset_mask[:-1]], dim=0)
    outputs["pred"] = [
        pred for pred, mask in zip(outputs["pred"], shifted_reset_mask) if not mask]
    outputs["views"] = [
        view for view, mask in zip(outputs["views"], shifted_reset_mask) if not mask]
    reset_mask = reset_mask[~shifted_reset_mask]
    # Per-view point maps and confidences. "self" is each view's own camera
    # frame; "other" is presumably the shared reference frame — confirm
    # against the model's head documentation.
    pts3ds_self_ls = [output["pts3d_in_self_view"].cpu() for output in outputs["pred"]]
    pts3ds_other = [output["pts3d_in_other_view"].cpu() for output in outputs["pred"]]
    conf_self = [output["conf_self"].cpu() for output in outputs["pred"]]
    conf_other = [output["conf"].cpu() for output in outputs["pred"]]
    pts3ds_self = torch.cat(pts3ds_self_ls, 0)
    # Recover camera poses.
    pr_poses = [
        pose_encoding_to_camera(pred["camera_pose"].clone()).cpu()
        for pred in outputs["pred"]
    ]
    if reset_mask.any():
        # Chain chunk-local poses into one global trajectory: take the pose
        # at each reset as a new base, accumulate the bases with a cumulative
        # matrix product, and left-multiply each view's pose by the base of
        # its chunk.
        pr_poses = torch.cat(pr_poses, 0)
        identity = torch.eye(4, device=pr_poses.device)
        reset_poses = torch.where(reset_mask.unsqueeze(-1).unsqueeze(-1), pr_poses, identity)
        cumulative_bases = matrix_cumprod(reset_poses)
        shifted_bases = torch.cat([identity.unsqueeze(0), cumulative_bases[:-1]], dim=0)
        pr_poses = torch.einsum('bij,bjk->bik', shifted_bases, pr_poses)
        # Convert back to a list of per-view [1, 4, 4] poses.
        pr_poses = list(pr_poses.unsqueeze(1).unbind(0))
    R_c2w = torch.cat([pr_pose[:, :3, :3] for pr_pose in pr_poses], 0)
    t_c2w = torch.cat([pr_pose[:, :3, 3] for pr_pose in pr_poses], 0)
    if use_pose:
        # Re-express self-view points in world coordinates via the recovered
        # camera-to-world poses.
        transformed_pts3ds_other = []
        for pose, pself in zip(pr_poses, pts3ds_self):
            transformed_pts3ds_other.append(geotrf(pose, pself.unsqueeze(0)))
        pts3ds_other = transformed_pts3ds_other
        conf_other = conf_self
    # Estimate focal length based on depth.
    B, H, W, _ = pts3ds_self.shape
    pp = torch.tensor([W // 2, H // 2], device=pts3ds_self.device).float().repeat(B, 1)
    focal = estimate_focal_knowing_depth(pts3ds_self, pp, focal_mode="weiszfeld")
    # Map images from [-1, 1] back to [0, 1] RGB.
    colors = [
        0.5 * (output["img"].permute(0, 2, 3, 1) + 1.0) for output in outputs["views"]
    ]
    cam_dict = {
        "focal": focal.cpu().numpy(),
        "pp": pp.cpu().numpy(),
        "R": R_c2w.cpu().numpy(),
        "t": t_c2w.cpu().numpy(),
    }
    pts3ds_self_tosave = pts3ds_self  # B, H, W, 3
    # depth is the z-component of the self-view points
    depths_tosave = pts3ds_self_tosave[..., 2]
    pts3ds_other_tosave = torch.cat(pts3ds_other)  # B, H, W, 3
    conf_self_tosave = torch.cat(conf_self)  # B, H, W
    conf_other_tosave = torch.cat(conf_other)  # B, H, W
    colors_tosave = torch.cat(
        [
            0.5 * (output["img"].permute(0, 2, 3, 1).cpu() + 1.0)
            for output in outputs["views"]
        ]
    )  # [B, H, W, 3]
    cam2world_tosave = torch.cat(pr_poses)  # B, 4, 4
    intrinsics_tosave = (
        torch.eye(3).unsqueeze(0).repeat(cam2world_tosave.shape[0], 1, 1)
    )  # B, 3, 3
    # Fill pinhole intrinsics from the estimated focal and principal point.
    intrinsics_tosave[:, 0, 0] = focal.detach().cpu()
    intrinsics_tosave[:, 1, 1] = focal.detach().cpu()
    intrinsics_tosave[:, 0, 2] = pp[:, 0]
    intrinsics_tosave[:, 1, 2] = pp[:, 1]
    # Recreate the per-frame output folders from scratch.
    if os.path.exists(os.path.join(outdir, "depth")):
        shutil.rmtree(os.path.join(outdir, "depth"))
    if os.path.exists(os.path.join(outdir, "conf")):
        shutil.rmtree(os.path.join(outdir, "conf"))
    if os.path.exists(os.path.join(outdir, "color")):
        shutil.rmtree(os.path.join(outdir, "color"))
    if os.path.exists(os.path.join(outdir, "camera")):
        shutil.rmtree(os.path.join(outdir, "camera"))
    os.makedirs(os.path.join(outdir, "depth"), exist_ok=True)
    os.makedirs(os.path.join(outdir, "conf"), exist_ok=True)
    os.makedirs(os.path.join(outdir, "color"), exist_ok=True)
    os.makedirs(os.path.join(outdir, "camera"), exist_ok=True)
    # Save depth, confidence, color, and camera files, one per frame.
    for f_id in range(len(pts3ds_self)):
        depth = depths_tosave[f_id].cpu().numpy()
        conf = conf_self_tosave[f_id].cpu().numpy()
        color = colors_tosave[f_id].cpu().numpy()
        c2w = cam2world_tosave[f_id].cpu().numpy()
        intrins = intrinsics_tosave[f_id].cpu().numpy()
        np.save(os.path.join(outdir, "depth", f"{f_id:06d}.npy"), depth)
        np.save(os.path.join(outdir, "conf", f"{f_id:06d}.npy"), conf)
        iio.imwrite(
            os.path.join(outdir, "color", f"{f_id:06d}.png"),
            (color * 255).astype(np.uint8),
        )
        np.savez(
            os.path.join(outdir, "camera", f"{f_id:06d}.npz"),
            pose=c2w,
            intrinsics=intrins,
        )
    # # convert_scene_output_to_glb(outdir, (colors_tosave * 255).to(torch.uint8), pts3ds_other_tosave, conf_other_tosave > 1, focal, cam2world_tosave, as_pointcloud=True)
    return pts3ds_other, colors, conf_other, cam_dict
def parse_seq_path(p, frame_interval=1):
    """Collect frame image paths from a directory of images or a video file.

    Args:
        p: Path to a directory of images, or to a video file readable by OpenCV.
        frame_interval: Keep every ``frame_interval``-th frame (>= 1).

    Returns:
        (img_paths, tmpdirname): sorted list of image paths to process, and the
        temporary directory holding extracted video frames (``None`` for a
        directory input; the caller is responsible for removing it).

    Side effects:
        Sets the module-level ``framerate`` global (directory inputs assume a
        nominal 30 fps source).
    """
    global framerate
    if os.path.isdir(p):
        all_img_paths = sorted(glob.glob(f"{p}/*"))
        img_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp'}
        img_paths = [path for path in all_img_paths
                     if os.path.splitext(path.lower())[1] in img_extensions]
        if not img_paths:
            raise ValueError(f"No image files found in directory {p}")
        if frame_interval > 1:
            img_paths = img_paths[::frame_interval]
        print(f" - Image sequence: Total images: {len(all_img_paths)}, "
              f"Frame interval: {frame_interval}, Images to process: {len(img_paths)}")
        # no FPS metadata for image folders; assume a 30 fps source
        framerate = 30.0 / frame_interval
        tmpdirname = None
        return img_paths, tmpdirname

    cap = cv2.VideoCapture(p)
    if not cap.isOpened():
        raise ValueError(f"Error opening video file {p}")
    video_fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if video_fps == 0:
        cap.release()
        raise ValueError(f"Error: Video FPS is 0 for {p}")
    framerate = video_fps / frame_interval
    # expected count from container metadata (actual count may differ slightly)
    n_expected = len(range(0, total_frames, frame_interval))
    print(
        f" - Video FPS: {video_fps}, Frame Interval: {frame_interval}, Total Frames to Read: {n_expected}, Processed Framerate: {framerate}"
    )
    img_paths = []
    tmpdirname = tempfile.mkdtemp()
    # Decode sequentially instead of seeking with CAP_PROP_POS_FRAMES for each
    # frame: per-frame seeking is O(n) per seek and inaccurate on many codecs.
    frame_id = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if frame_id % frame_interval == 0:
            frame_path = os.path.join(tmpdirname, f"frame_{frame_id}.jpg")
            cv2.imwrite(frame_path, frame)
            img_paths.append(frame_path)
        frame_id += 1
    cap.release()
    return img_paths, tmpdirname
def run_inference(args):
    """
    Execute the full inference and visualization pipeline.

    Steps: device selection, model import/loading, input-view preparation,
    recurrent inference, output post-processing, and launching the viewer.

    Args:
        args: Parsed command-line arguments (seq_path, model_path, device,
              size, frame_interval, reset_interval, output_dir, ...).
    """
    # Set up the computation device.
    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available. Switching to CPU.")
        device = "cpu"
    # Add the checkpoint path (required for model imports in the dust3r package).
    add_path_to_dust3r(args.model_path)
    # Import model and inference helper after registering the ckpt path.
    # Only the names actually used below are imported (the previously imported
    # `inference` and `inference_recurrent` were unused).
    from src.dust3r.inference import inference_recurrent_lighter
    from src.dust3r.model import ARCroco3DStereo
    from viser_utils import PointCloudViewer

    # Prepare image file paths.
    img_paths, tmpdirname = parse_seq_path(args.seq_path, args.frame_interval)
    if not img_paths:
        print(f"No images found in {args.seq_path}. Please verify the path.")
        return
    print(f"Found {len(img_paths)} images in {args.seq_path}.")
    img_mask = [True] * len(img_paths)
    # Prepare input views.
    print("Preparing input views...")
    views = prepare_input(
        img_paths=img_paths,
        img_mask=img_mask,
        size=args.size,
        revisit=1,
        update=True,
        reset_interval=args.reset_interval
    )
    # Extracted video frames are no longer needed once views are built.
    if tmpdirname is not None:
        shutil.rmtree(tmpdirname)
    # Load and prepare the model.
    print(f"Loading model from {args.model_path}...")
    model = ARCroco3DStereo.from_pretrained(args.model_path).to(device)
    model.config.model_update_type = args.model_update_type
    model.eval()
    # Run inference.
    print("Running inference...")
    start_time = time.time()
    outputs, state_args = inference_recurrent_lighter(views, model, device)
    total_time = time.time() - start_time
    per_frame_time = total_time / len(views)
    FPS_num = 1 / per_frame_time
    print(
        f"Inference completed in {total_time:.2f} seconds (average {per_frame_time:.2f} s per frame), FPS: {FPS_num:.2f}."
    )
    # Process outputs for visualization.
    print("Preparing output for visualization...")
    pts3ds_other, colors, conf, cam_dict = prepare_output(
        outputs, args.output_dir, 1, True
    )
    # Convert tensors to numpy arrays for visualization.
    pts3ds_to_vis = [p.cpu().numpy() for p in pts3ds_other]
    colors_to_vis = [c.cpu().numpy() for c in colors]
    edge_colors = [None] * len(pts3ds_to_vis)
    # Create and run the point cloud viewer.
    print("Launching point cloud viewer...")
    viewer = PointCloudViewer(
        model,
        state_args,
        pts3ds_to_vis,
        colors_to_vis,
        conf,
        cam_dict,
        device=device,
        edge_color_list=edge_colors,
        show_camera=True,
        vis_threshold=args.vis_threshold,
        size=args.size,
        port=args.port,
        downsample_factor=args.downsample_factor
    )
    viewer.run()
def main():
    """CLI entry point: parse arguments and run the inference pipeline."""
    args = parse_args()
    if not args.seq_path:
        # Typo fixed: "iteractively" -> "interactively".
        print(
            "No inputs found! Please use our gradio demo if you would like to interactively upload inputs."
        )
        return
    run_inference(args)


if __name__ == "__main__":
    main()
```
## /eval/eval.md
# Evaluation
## Datasets
Please follow [MonST3R](https://github.com/Junyi42/monst3r/blob/main/data/evaluation_script.md) and [Spann3R](https://github.com/HengyiWang/spann3r/blob/main/docs/data_preprocess.md) to download **ScanNet**, **TUM-dynamics**, **Sintel**, **Bonn**, **KITTI** and **7scenes** datasets.
### ScanNet
To prepare the **ScanNet** dataset, execute:
```bash
python datasets_preprocess/long_prepare_scannet.py # You may need to change the path of the dataset
```
### TUM-dynamics
To prepare the **TUM-dynamics** dataset, execute:
```bash
python datasets_preprocess/long_prepare_tum.py # You may need to change the path of the dataset
```
### Bonn
To prepare the **Bonn** dataset, execute:
```bash
python datasets_preprocess/long_prepare_bonn.py # You may need to change the path of the dataset
```
### KITTI
To prepare the **KITTI** dataset, execute:
```bash
python datasets_preprocess/long_prepare_kitti.py # You may need to change the path of the dataset
```
# Evaluation Scripts
Results will be saved in `eval_results/*`.
### Camera Pose Estimation
```bash
CUDA_VISIBLE_DEVICES=6,7 bash eval/relpose/run_scannet.sh # You may need to change [--num_processes] to the number of your gpus and choose sequence length in datasets=('scannet_s3_1000')
CUDA_VISIBLE_DEVICES=6,7 bash eval/relpose/run_tum.sh # You may need to change [--num_processes] to the number of your gpus and choose sequence length in datasets=('tum_s1_1000')
CUDA_VISIBLE_DEVICES=6,7 bash eval/relpose/run_sintel.sh # You may need to change [--num_processes] to the number of your gpus
```
### Video Depth
```bash
CUDA_VISIBLE_DEVICES=5 bash eval/video_depth/run_kitti.sh # You may need to change [--num_processes] to the number of your gpus and choose sequence length in datasets=('kitti_s1_500')
CUDA_VISIBLE_DEVICES=5 bash eval/video_depth/run_bonn.sh # You may need to change [--num_processes] to the number of your gpus and choose sequence length in datasets=('bonn_s1_500')
CUDA_VISIBLE_DEVICES=5 bash eval/video_depth/run_sintel.sh # You may need to change [--num_processes] to the number of your gpus
```
### 3D Reconstruction
```bash
CUDA_VISIBLE_DEVICES=5 bash eval/mv_recon/run.sh # You may need to change [--num_processes] to the number of your gpus and choose sequence length in max_frames
```
## /eval/mv_recon/base.py
```py path="/eval/mv_recon/base.py"
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# base class for implementing datasets
# --------------------------------------------------------
import PIL
import numpy as np
import torch
from eval.mv_recon.dataset_utils.transforms import ImgNorm
from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates
import eval.mv_recon.dataset_utils.cropping as cropping
class BaseStereoViewDataset:
    """Define all basic options.

    Usage:
        class MyDataset (BaseStereoViewDataset):
            def _get_views(self, idx, rng):
                # overload here
                views = []
                views.append(dict(img=, ...))
                return views
    """

    def __init__(
        self,
        *,  # only keyword arguments
        split=None,
        resolution=None,  # square_size or (width, height) or list of [(width,height), ...]
        transform=ImgNorm,
        aug_crop=False,
        seed=None,
    ):
        self.num_views = 2
        self.split = split
        self._set_resolutions(resolution)
        # BUGFIX: resolve a string spec (e.g. "ImgNorm") *before* storing it.
        # Previously `self.transform = transform` ran first, so the eval()
        # result was assigned to the local variable only and then discarded.
        # NOTE: eval() here must only ever see trusted config strings.
        if isinstance(transform, str):
            transform = eval(transform)
        self.transform = transform
        self.aug_crop = aug_crop
        self.seed = seed

    def __len__(self):
        # subclasses are expected to define self.scenes (or override __len__)
        return len(self.scenes)

    def get_stats(self):
        return f"{len(self)} pairs"

    def __repr__(self):
        # whitespace/newlines inside the f-string are stripped below
        resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
        return (
            f"""{type(self).__name__}({self.get_stats()},
            {self.split=},
            {self.seed=},
            resolutions={resolutions_str},
            {self.transform=})""".replace(
                "self.", ""
            )
            .replace("\n", "")
            .replace(" ", "")
        )

    def _get_views(self, idx, resolution, rng):
        """Subclass hook: return the list of raw view dicts for sample `idx`."""
        raise NotImplementedError()

    def __getitem__(self, idx):
        """Fetch views for one sample and post-process them in place:
        image transform, 3-D point computation, validity masks, model flags,
        and a final landscape-orientation pass."""
        if isinstance(idx, tuple):
            # the idx is specifying the aspect-ratio
            idx, ar_idx = idx
        else:
            assert len(self._resolutions) == 1
            ar_idx = 0
        # set-up the rng
        if self.seed:  # reseed for each __getitem__
            self._rng = np.random.default_rng(seed=self.seed + idx)
        elif not hasattr(self, "_rng"):
            seed = torch.initial_seed()  # this is different for each dataloader process
            self._rng = np.random.default_rng(seed=seed)
        # over-loaded code
        resolution = self._resolutions[
            ar_idx
        ]  # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
        views = self._get_views(idx, resolution, self._rng)
        # check data-types
        for v, view in enumerate(views):
            assert (
                "pts3d" not in view
            ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
            view["idx"] = v
            # encode the image
            width, height = view["img"].size
            view["true_shape"] = np.int32((height, width))
            view["img"] = self.transform(view["img"])
            assert "camera_intrinsics" in view
            if "camera_pose" not in view:
                view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32)
            else:
                assert np.isfinite(
                    view["camera_pose"]
                ).all(), f"NaN in camera pose for view {view_name(view)}"
            assert "pts3d" not in view
            assert "valid_mask" not in view
            assert np.isfinite(
                view["depthmap"]
            ).all(), f"NaN in depthmap for view {view_name(view)}"
            pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
            view["pts3d"] = pts3d
            view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)
            # check all datatypes
            for key, val in view.items():
                res, err_msg = is_good_type(key, val)
                assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
            # flags consumed downstream: real image present, no ray prior
            view["img_mask"] = True
            view["ray_mask"] = False
            view["ray_map"] = torch.full(
                (6, view["img"].shape[-2], view["img"].shape[-1]), torch.nan
            )
            view["update"] = True
            view["reset"] = False
        # last thing done!
        for view in views:
            # transpose to make sure all views are the same size
            transpose_to_landscape(view)
            # this allows to check whether the RNG is in the same state each time
            view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
        return views

    def _set_resolutions(self, resolutions):
        """Set the resolution(s) of the dataset.

        Params:
            - resolutions: int or tuple or list of tuples
        """
        assert resolutions is not None, "undefined resolution"
        if not isinstance(resolutions, list):
            resolutions = [resolutions]
        self._resolutions = []
        for resolution in resolutions:
            if isinstance(resolution, int):
                width = height = resolution
            else:
                width, height = resolution
            assert isinstance(
                width, int
            ), f"Bad type for {width=} {type(width)=}, should be int"
            assert isinstance(
                height, int
            ), f"Bad type for {height=} {type(height)=}, should be int"
            # landscape convention: stored resolutions are (width, height)
            assert width >= height
            self._resolutions.append((width, height))

    def _crop_resize_if_necessary(
        self, image, depthmap, intrinsics, resolution, rng=None, info=None
    ):
        """Center-crop around the principal point, then Lanczos-downscale and
        final-crop so that ``image.size == resolution`` (the target may be
        transposed for portrait/square inputs).

        Returns the updated (image, depthmap, intrinsics).
        """
        if not isinstance(image, PIL.Image.Image):
            image = PIL.Image.fromarray(image)
        # cropping centered on the principal point
        W, H = image.size
        cx, cy = intrinsics[:2, 2].round().astype(int)
        # calculate min distance to margin
        min_margin_x = min(cx, W - cx)
        min_margin_y = min(cy, H - cy)
        assert min_margin_x > W / 5, f"Bad principal point in view={info}"
        assert min_margin_y > H / 5, f"Bad principal point in view={info}"
        # the new window is a rectangle of size (2*min_margin_x, 2*min_margin_y)
        # centered on (cx, cy), so the principal point ends up centered
        l, t = cx - min_margin_x, cy - min_margin_y
        r, b = cx + min_margin_x, cy + min_margin_y
        crop_bbox = (l, t, r, b)
        image, depthmap, intrinsics = cropping.crop_image_depthmap(
            image, depthmap, intrinsics, crop_bbox
        )
        # transpose the target resolution if necessary
        W, H = image.size  # new size
        assert resolution[0] >= resolution[1]
        if H > 1.1 * W:
            # image is portrait mode
            resolution = resolution[::-1]
        elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
            # image is square, so we chose (portrait, landscape) randomly
            if rng.integers(2):
                resolution = resolution[::-1]
        # high-quality Lanczos down-scaling
        target_resolution = np.array(resolution)
        image, depthmap, intrinsics = cropping.rescale_image_depthmap(
            image, depthmap, intrinsics, target_resolution
        )
        # actual cropping (if necessary) with bilinear interpolation
        intrinsics2 = cropping.camera_matrix_of_crop(
            intrinsics, image.size, resolution, offset_factor=0.5
        )
        crop_bbox = cropping.bbox_from_intrinsics_in_out(
            intrinsics, intrinsics2, resolution
        )
        image, depthmap, intrinsics = cropping.crop_image_depthmap(
            image, depthmap, intrinsics, crop_bbox
        )
        return image, depthmap, intrinsics
def is_good_type(key, v):
    """returns (is_good, err_msg)"""
    # plain Python scalars/containers used as metadata are always fine
    if isinstance(v, (str, int, tuple)):
        return True, None
    # arrays/tensors must use one of the dtypes the pipeline expects
    allowed_dtypes = (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8)
    if v.dtype in allowed_dtypes:
        return True, None
    return False, f"bad {v.dtype=}"
def view_name(view, batch_index=None):
    """Human-readable id "<dataset>/<label>/<instance>" for a view,
    optionally selecting one element of a batched view."""

    def pick(field):
        # None / full slice means "take the field as-is"
        return field if batch_index in (None, slice(None)) else field[batch_index]

    dataset = pick(view["dataset"])
    label = pick(view["label"])
    instance = pick(view["instance"])
    return f"{dataset}/{label}/{instance}"
def transpose_to_landscape(view):
    """Swap the spatial axes of a portrait view in place so it becomes
    landscape; a view that is already landscape is left untouched."""
    height, width = view["true_shape"]
    if width >= height:
        return  # already landscape — nothing to do
    # rectify portrait to landscape
    assert view["img"].shape == (3, height, width)
    view["img"] = view["img"].swapaxes(1, 2)
    assert view["valid_mask"].shape == (height, width)
    view["valid_mask"] = view["valid_mask"].swapaxes(0, 1)
    assert view["depthmap"].shape == (height, width)
    view["depthmap"] = view["depthmap"].swapaxes(0, 1)
    assert view["pts3d"].shape == (height, width, 3)
    view["pts3d"] = view["pts3d"].swapaxes(0, 1)
    # exchanging x and y pixels means swapping the first two intrinsics rows
    view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]]
```
## /eval/mv_recon/criterion.py
```py path="/eval/mv_recon/criterion.py"
import torch
import torch.nn as nn
from copy import copy, deepcopy
from dust3r.utils.misc import invalid_to_zeros, invalid_to_nans
from dust3r.utils.geometry import inv, geotrf, depthmap_to_pts3d
from dust3r.utils.camera import pose_encoding_to_camera
class BaseCriterion(nn.Module):
    """Root of all criteria; stores the reduction mode ("mean", "sum" or
    "none") that subclasses apply when collapsing per-pixel losses."""

    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction
class Criterion(nn.Module):
    """Wraps a BaseCriterion so that composite losses can propagate settings
    (e.g. the reduction mode) down a chain of `_loss2` links."""

    def __init__(self, criterion=None):
        super().__init__()
        assert isinstance(
            criterion, BaseCriterion
        ), f"{criterion} is not a proper criterion!"
        # keep a private copy so later reduction changes don't leak outward
        self.criterion = copy(criterion)

    def get_name(self):
        return f"{type(self).__name__}({self.criterion})"

    def with_reduction(self, mode="none"):
        """Deep-copy this loss (and its chained losses) with a new reduction."""
        root = deepcopy(self)
        node = root
        while node is not None:
            assert isinstance(node, Criterion)
            node.criterion.reduction = mode  # make it return the loss for each sample
            node = node._loss2  # we assume loss is a Multiloss
        return root
class MultiLoss(nn.Module):
    """Easily combinable losses (also keep track of individual loss values):
    loss = MyLoss1() + 0.1*MyLoss2()
    Usage:
    Inherit from this class and override get_name() and compute_loss()
    """

    def __init__(self):
        super().__init__()
        self._alpha = 1
        self._loss2 = None

    def compute_loss(self, *args, **kwargs):
        raise NotImplementedError()

    def get_name(self):
        raise NotImplementedError()

    def __mul__(self, alpha):
        # scale this loss by a constant factor
        assert isinstance(alpha, (int, float))
        scaled = copy(self)
        scaled._alpha = alpha
        return scaled

    __rmul__ = __mul__  # so `0.5 * loss` works too

    def __add__(self, loss2):
        # append loss2 at the end of this loss's chain
        assert isinstance(loss2, MultiLoss)
        head = tail = copy(self)
        while tail._loss2 is not None:
            tail = tail._loss2
        tail._loss2 = loss2
        return head

    def __repr__(self):
        name = self.get_name()
        if self._alpha != 1:
            name = f"{self._alpha:g}*{name}"
        if self._loss2:
            name = f"{name} + {self._loss2}"
        return name

    def forward(self, *args, **kwargs):
        out = self.compute_loss(*args, **kwargs)
        if isinstance(out, tuple):
            loss, details = out
        elif out.ndim == 0:
            # scalar loss: record its *unscaled* value under this loss's name
            loss, details = out, {self.get_name(): float(out)}
        else:
            loss, details = out, {}
        loss = loss * self._alpha
        if self._loss2:
            extra, extra_details = self._loss2(*args, **kwargs)
            loss = loss + extra
            details |= extra_details
        return loss, details
class LLoss(BaseCriterion):
    """L-norm loss"""

    def forward(self, a, b):
        # both inputs must be identically-shaped maps with 1..3 channels
        assert (
            a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3
        ), f"Bad shape = {a.shape}"
        dist = self.distance(a, b)
        if self.reduction == "none":
            return dist
        if self.reduction == "sum":
            return dist.sum()
        if self.reduction == "mean":
            # empty selection -> zero scalar instead of NaN
            if dist.numel() > 0:
                return dist.mean()
            return dist.new_zeros(())
        raise ValueError(f"bad {self.reduction=} mode")

    def distance(self, a, b):
        # subclasses define the actual per-pixel distance
        raise NotImplementedError()
class L21Loss(LLoss):
    """Euclidean distance between 3d points"""

    def distance(self, a, b):
        # per-pixel L2 norm over the channel (xyz) dimension
        return torch.norm(a - b, dim=-1)


L21 = L21Loss()  # shared default instance
def get_pred_pts3d(gt, pred, use_pose=False):
    """Extract a 3-D point map from a prediction dict; when `use_pose` is set,
    re-express the self-view points through the predicted camera pose."""
    pts = None
    if "depth" in pred and "pseudo_focal" in pred:
        # back-project a predicted depth map (principal point from GT if known)
        try:
            pp = gt["camera_intrinsics"][..., :2, 2]
        except KeyError:
            pp = None
        pts = depthmap_to_pts3d(**pred, pp=pp)
    elif "pts3d" in pred:
        # pts3d from my camera
        pts = pred["pts3d"]
    elif "pts3d_in_other_view" in pred:
        # pts3d from the other camera, already transformed
        assert use_pose is True
        return pred["pts3d_in_other_view"]  # return!
    if use_pose:
        camera_pose = pred.get("camera_pose")
        pts = pred.get("pts3d_in_self_view")
        assert camera_pose is not None
        assert pts is not None
        pts = geotrf(pose_encoding_to_camera(camera_pose), pts)
    return pts
def Sum(losses, masks, conf=None):
    """If losses are per-pixel maps, hand everything back to the caller;
    if they are scalars, accumulate them into one global loss."""
    first = losses[0]
    if first.ndim > 0:
        # per-pixel mode: the caller needs the raw lists
        if conf is not None:
            return losses, masks, conf
        return losses, masks
    # scalar mode: add all the contributions together
    total = first
    for extra in losses[1:]:
        total = total + extra
    return total
def get_norm_factor(pts, norm_mode="avg_dis", valids=None, fix_first=True):
    """Compute a per-batch normalization factor (average point distance to the
    origin, optionally log1p-compressed) over the given point maps.
    With fix_first=True only the first view defines the scale."""
    assert pts[0].ndim >= 3 and pts[0].shape[-1] == 3
    assert pts[1] is None or (pts[1].ndim >= 3 and pts[1].shape[-1] == 3)
    norm_mode, dis_mode = norm_mode.split("_")
    if norm_mode != "avg":
        raise ValueError(f"Not implemented {norm_mode=}")
    # gather all points together (joint normalization)
    zeroed, counts = [], []
    for i, pt in enumerate(pts):
        filled, nnz = invalid_to_zeros(pt, valids[i], ndim=3)
        zeroed.append(filled)
        counts.append(nnz)
        if fix_first:
            break  # scale is anchored on the first view only
    all_pts = torch.cat(zeroed, dim=1)
    # distance of every (valid) point to the origin
    all_dis = all_pts.norm(dim=-1)
    if dis_mode == "dis":
        pass  # raw distances
    elif dis_mode == "log1p":
        all_dis = torch.log1p(all_dis)
    else:
        raise ValueError(f"bad {dis_mode=}")
    norm_factor = all_dis.sum(dim=1) / (torch.cat(counts).sum() + 1e-8)
    norm_factor = norm_factor.clip(min=1e-8)
    # broadcast-ready shape matching the point maps
    while norm_factor.ndim < pts[0].ndim:
        norm_factor = norm_factor.unsqueeze(-1)
    return norm_factor
def normalize_pointcloud_t(
    pts, norm_mode="avg_dis", valids=None, fix_first=True, gt=False
):
    """Divide every point map by a shared normalization factor.

    The `gt` flag is kept for interface compatibility: both branches of the
    original implementation performed the identical computation.
    Returns (scaled point maps, the factor itself)."""
    norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first)
    scaled = [pt / norm_factor for pt in pts]
    return scaled, norm_factor
@torch.no_grad()
def get_joint_pointcloud_depth(zs, valid_masks=None, quantile=0.5):
    """Per-batch robust depth statistic (median, or the given quantile) taken
    jointly over all views; invalid pixels are turned into NaN and ignored."""
    flat = []
    for i, z in enumerate(zs):
        mask = valid_masks[i] if valid_masks is not None else None
        flat.append(invalid_to_nans(z, mask).reshape(len(z), -1))
    flat = torch.cat(flat, dim=-1)
    # compute the statistic while ignoring NaNs
    if quantile == 0.5:
        return torch.nanmedian(flat, dim=-1).values  # (B,)
    return torch.nanquantile(flat, quantile, dim=-1)  # (B,)
@torch.no_grad()
def get_joint_pointcloud_center_scale(pts, valid_masks=None, z_only=False, center=True):
    """Median center and median-norm scale of a set of point maps, computed
    jointly over all views with invalid pixels NaN-masked out."""
    gathered = []
    for i, pt in enumerate(pts):
        mask = valid_masks[i] if valid_masks is not None else None
        gathered.append(invalid_to_nans(pt, mask).reshape(len(pt), -1, 3))
    gathered = torch.cat(gathered, dim=1)
    # NaN-robust median center
    med_center = torch.nanmedian(gathered, dim=1, keepdim=True).values  # (B,1,3)
    if z_only:
        med_center[..., :2] = 0  # do not center X and Y
    # NaN-robust median distance to the (optional) center
    offsets = (gathered - med_center) if center else gathered
    scale = torch.nanmedian(offsets.norm(dim=-1), dim=1).values
    return med_center[:, None, :, :], scale[:, None, None, None]
class Regr3D_t(Criterion, MultiLoss):
    """Temporal 3-D regression loss: compares predicted point maps of a view
    sequence against ground truth expressed in the first camera's frame,
    with optional joint normalization of the point clouds.

    Args:
        criterion: per-pixel BaseCriterion (e.g. L21) applied to matched points.
        norm_mode: normalization spec passed to normalize_pointcloud_t
                   ("avg_dis"), or falsy to disable normalization.
        gt_scale: if True, GT keeps its own scale (GT is not normalized).
        fix_first: anchor the normalization factor on the first view only.
    """

    def __init__(self, criterion, norm_mode="avg_dis", gt_scale=False, fix_first=True):
        super().__init__(criterion)
        self.norm_mode = norm_mode
        self.gt_scale = gt_scale
        self.fix_first = fix_first

    def get_all_pts3d_t(self, gts, preds, dist_clip=None):
        """Return (gt_pts, pr_pts, gt_factor, pr_factor, valids, monitoring).

        gt point maps are transformed into the frame of the first GT camera;
        predicted point maps come from get_pred_pts3d(..., use_pose=True).
        """
        # everything is normalized w.r.t. camera of view1
        in_camera1 = inv(gts[0]["camera_pose"])
        gt_pts = []
        valids = []
        pr_pts = []
        for i, gt in enumerate(gts):
            # in_camera1: Bs, 4, 4  gt['pts3d']: Bs, H, W, 3
            gt_pts.append(geotrf(in_camera1, gt["pts3d"]))
            valid = gt["valid_mask"].clone()
            if dist_clip is not None:
                # points that are too far-away == invalid
                dis = gt["pts3d"].norm(dim=-1)
                valid = valid & (dis <= dist_clip)
            valids.append(valid)
            pr_pts.append(get_pred_pts3d(gt, preds[i], use_pose=True))
        # normalize predictions and (unless gt_scale) ground truth with a
        # shared scheme so the regression is scale-consistent
        if self.norm_mode:
            pr_pts, pr_factor = normalize_pointcloud_t(
                pr_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=False
            )
        else:
            pr_factor = None
        if self.norm_mode and not self.gt_scale:
            gt_pts, gt_factor = normalize_pointcloud_t(
                gt_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=True
            )
        else:
            gt_factor = None
        return gt_pts, pr_pts, gt_factor, pr_factor, valids, {}

    def compute_frame_loss(self, gts, preds, **kw):
        """Per-frame loss over left (reference) and right (target) predictions.

        NOTE(review): `pred_pts` returned by get_all_pts3d_t above is a flat
        list (one entry per view), yet it is unpacked here into
        (pred_pts_l, pred_pts_r) and `preds[i]` is indexed as a pair
        (`preds[i][0]`, `preds[i-1][1]`). This matches an older paired-output
        layout (see the commented-out code in get_all_pts3d_t) — confirm the
        caller actually supplies paired predictions before relying on this
        method.
        """
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            self.get_all_pts3d_t(gts, preds, **kw)
        )
        pred_pts_l, pred_pts_r = pred_pts
        loss_all = []
        mask_all = []
        conf_all = []
        loss_left = 0
        loss_right = 0
        pred_conf_l = 0
        pred_conf_r = 0
        for i in range(len(gt_pts)):
            # Left (Reference): every view except the last has a left prediction
            if i != len(gt_pts) - 1:
                frame_loss = self.criterion(
                    pred_pts_l[i][masks[i]], gt_pts[i][masks[i]]
                )
                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(preds[i][0]["conf"])
                # To compare target/reference loss
                if i != 0:
                    loss_left += frame_loss.cpu().detach().numpy().mean()
                    pred_conf_l += preds[i][0]["conf"].cpu().detach().numpy().mean()
            # Right (Target): every view except the first has a right prediction
            if i != 0:
                frame_loss = self.criterion(
                    pred_pts_r[i - 1][masks[i]], gt_pts[i][masks[i]]
                )
                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(preds[i - 1][1]["conf"])
                # To compare target/reference loss
                if i != len(gt_pts) - 1:
                    loss_right += frame_loss.cpu().detach().numpy().mean()
                    pred_conf_r += preds[i - 1][1]["conf"].cpu().detach().numpy().mean()
        # penalize predicted scale factors that exceed the GT factor
        if pr_factor is not None and gt_factor is not None:
            filter_factor = pr_factor[pr_factor > gt_factor]
        else:
            filter_factor = []
        if len(filter_factor) > 0:
            factor_loss = (filter_factor - gt_factor).abs().mean()
        else:
            factor_loss = 0.0
        self_name = type(self).__name__
        # assumes at least two entries in loss_all — TODO confirm for 1-view inputs
        details = {
            self_name + "_pts3d_1": float(loss_all[0].mean()),
            self_name + "_pts3d_2": float(loss_all[1].mean()),
            self_name + "loss_left": float(loss_left),
            self_name + "loss_right": float(loss_right),
            self_name + "conf_left": float(pred_conf_l),
            self_name + "conf_right": float(pred_conf_r),
        }
        return Sum(loss_all, mask_all, conf_all), (details | monitoring), factor_loss
class ConfLoss_t(MultiLoss):
    """Weighted regression by learned confidence.
    Assuming the input pixel_loss is a pixel-level regression loss.
    Principle:
    high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
    low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10)
    alpha: hyperparameter
    """

    def __init__(self, pixel_loss, alpha=1):
        super().__init__()
        assert alpha > 0
        self.alpha = alpha
        # force the wrapped loss to emit per-pixel values
        self.pixel_loss = pixel_loss.with_reduction("none")

    def get_name(self):
        return f"ConfLoss({self.pixel_loss})"

    def get_conf_log(self, x):
        return x, torch.log(x)

    def compute_frame_loss(self, gts, preds, **kw):
        # per-pixel losses plus matching masks/confidence maps
        (losses, masks, confs), details, loss_factor = (
            self.pixel_loss.compute_frame_loss(gts, preds, **kw)
        )
        # weight each per-pixel loss by its confidence, minus a log bonus
        weighted = []
        conf_sum = 0
        for pix_loss, mask, conf_map in zip(losses, masks, confs):
            conf, log_conf = self.get_conf_log(conf_map[mask])
            conf_sum += conf.mean()
            term = pix_loss * conf - self.alpha * log_conf
            weighted.append(term.mean() if term.numel() > 0 else 0)
        weighted = torch.stack(weighted) * 2.0
        mean_loss = weighted.mean()
        return (
            mean_loss,
            dict(
                conf_loss_1=float(weighted[0]),
                conf_loss2=float(weighted[1]),
                conf_mean=conf_sum / len(losses),
                **details,
            ),
            loss_factor,
        )
class Regr3D_t_ShiftInv(Regr3D_t):
    """Same than Regr3D but invariant to depth shift."""

    def get_all_pts3d_t(self, gts, preds):
        # compute unnormalized points
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            super().get_all_pts3d_t(gts, preds)
        )
        # per-batch median depth of GT and prediction, broadcastable to (B,H,W)
        gt_shift_z = get_joint_pointcloud_depth(
            [pts[..., 2] for pts in gt_pts], masks
        )[:, None, None]
        pred_shift_z = get_joint_pointcloud_depth(
            [pts[..., 2] for pts in pred_pts], masks
        )[:, None, None]
        # subtract the median depth in place from every view
        for pts in gt_pts:
            pts[..., 2] -= gt_shift_z
        for pts in pred_pts:
            pts[..., 2] -= pred_shift_z
        monitoring = dict(
            monitoring,
            gt_shift_z=gt_shift_z.mean().detach(),
            pred_shift_z=pred_shift_z.mean().detach(),
        )
        return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
class Regr3D_t_ScaleInv(Regr3D_t):
    """Same than Regr3D but invariant to depth shift.
    if gt_scale == True: enforce the prediction to take the same scale than GT
    """

    def get_all_pts3d_t(self, gts, preds):
        # compute depth-normalized points
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            super().get_all_pts3d_t(gts, preds)
        )
        # measure scene scale (pred_pts is cloned so the scale estimate is
        # unaffected by the in-place rescaling below)
        pred_pts_all = [
            x.clone() for x in pred_pts
        ]
        _, gt_scale = get_joint_pointcloud_center_scale(gt_pts, masks)
        _, pred_scale = get_joint_pointcloud_center_scale(pred_pts_all, masks)
        # prevent predictions to be in a ridiculous range
        pred_scale = pred_scale.clip(min=1e-3, max=1e3)
        if self.gt_scale:
            # bring predictions to the GT scale
            for i in range(len(pred_pts)):
                pred_pts[i] *= gt_scale / pred_scale
        else:
            # NOTE(review): this branch multiplies predictions by
            # pred_scale/gt_scale AND ground truth by gt_scale/pred_scale,
            # which amplifies (rather than cancels) any scale mismatch.
            # It may be intentional (e.g. a symmetric penalty), but verify
            # against the upstream implementation before trusting it.
            for i in range(len(pred_pts)):
                pred_pts[i] *= pred_scale / gt_scale
            for i in range(len(gt_pts)):
                gt_pts[i] *= gt_scale / pred_scale
        monitoring = dict(
            monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach()
        )
        return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
class Regr3D_t_ScaleShiftInv(Regr3D_t_ScaleInv, Regr3D_t_ShiftInv):
    # MRO is (ScaleInv, ShiftInv, Regr3D_t): ScaleInv.get_all_pts3d_t's
    # super() call reaches ShiftInv first, so points are shift-normalized
    # before being scale-normalized — i.e. Regr3D_ShiftInv runs first,
    # then Regr3D_ScaleInv.
    pass
```
## /eval/mv_recon/data.py
```py path="/eval/mv_recon/data.py"
import os
import cv2
import json
import numpy as np
import os.path as osp
from collections import deque
import random
from eval.mv_recon.base import BaseStereoViewDataset
from dust3r.utils.image import imread_cv2
import eval.mv_recon.dataset_utils.cropping as cropping
def shuffle_deque(dq, seed=None):
    """Return a new deque containing the items of `dq` in random order;
    passing a seed makes the order reproducible."""
    if seed is not None:
        random.seed(seed)  # reseed the global RNG for reproducibility
    # deques have no shuffle, so go through a list
    items = list(dq)
    random.shuffle(items)
    return deque(items)
class SevenScenes(BaseStereoViewDataset):
    """7-Scenes evaluation dataset (SimpleRecon-style layout).

    Each sample is a full (optionally key-framed and truncated) RGB-D
    sequence with per-frame poses; fixed SimpleRecon intrinsics are used.
    """

    def __init__(
        self,
        num_seq=1,
        num_frames=5,
        min_thresh=10,
        max_thresh=100,
        test_id=None,
        full_video=False,
        tuple_list=None,
        seq_id=None,
        rebuttal=False,
        shuffle_seed=-1,
        kf_every=1,
        max_frames=None,
        *args,
        ROOT,
        **kwargs,
    ):
        """
        Args:
            num_seq: number of sequences sampled per scene.
            test_id: restrict to a single scene name, if given.
            tuple_list: optional pre-defined "scene idx idx ..." tuples.
            seq_id: restrict to a single sequence id (e.g. "seq-01").
            shuffle_seed: >= 0 enables seeded frame shuffling.
            kf_every: keep every k-th frame (keyframing).
            max_frames: cap on the number of frames per sequence.
            ROOT: dataset root directory (keyword-only).
        """
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        self.num_seq = num_seq
        self.num_frames = num_frames
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh
        self.test_id = test_id
        self.full_video = full_video
        self.kf_every = kf_every
        self.seq_id = seq_id
        self.rebuttal = rebuttal
        self.shuffle_seed = shuffle_seed
        self.max_frames = max_frames
        # load all scenes
        self.load_all_tuples(tuple_list)
        self.load_all_scenes(ROOT)

    def __len__(self):
        if self.tuple_list is not None:
            return len(self.tuple_list)
        return len(self.scene_list) * self.num_seq

    def load_all_tuples(self, tuple_list):
        """Store the optional pre-defined frame tuples (or None)."""
        if tuple_list is not None:
            self.tuple_list = tuple_list
        else:
            self.tuple_list = None

    def load_all_scenes(self, base_dir):
        """Build self.scene_list, either from the fixed SimpleRecon scene ids
        (when tuples are given) or by scanning the split files on disk."""
        if self.tuple_list is not None:
            # Use pre-defined simplerecon scene_ids
            self.scene_list = [
                "stairs/seq-06",
                "stairs/seq-02",
                "pumpkin/seq-06",
                "chess/seq-01",
                "heads/seq-02",
                "fire/seq-02",
                "office/seq-03",
                "pumpkin/seq-03",
                "redkitchen/seq-07",
                "chess/seq-02",
                "office/seq-01",
                "redkitchen/seq-01",
                "fire/seq-01",
            ]
            print(f"Found {len(self.scene_list)} sequences in split {self.split}")
            return
        scenes = os.listdir(base_dir)
        file_split = {"train": "TrainSplit.txt", "test": "TestSplit.txt"}[self.split]
        self.scene_list = []
        for scene in scenes:
            if self.test_id is not None and scene != self.test_id:
                continue
            # read file split
            with open(osp.join(base_dir, scene, file_split)) as f:
                seq_ids = f.read().splitlines()
            for seq_id in seq_ids:
                # seq is a string like "sequence1"; keep the digits and
                # zero-pad to the on-disk "seq-01" naming
                num_part = "".join(filter(str.isdigit, seq_id))
                seq_id = f"seq-{num_part.zfill(2)}"
                if self.seq_id is not None and seq_id != self.seq_id:
                    continue
                self.scene_list.append(f"{scene}/{seq_id}")
        print(f"Found {len(self.scene_list)} sequences in split {self.split}")

    def _get_views(self, idx, resolution, rng):
        """Load all frames of one sequence as view dicts (RGB, depth, pose)."""
        if self.tuple_list is not None:
            line = self.tuple_list[idx].split(" ")
            scene_id = line[0]
            img_idxs = line[1:]
        else:
            scene_id = self.scene_list[idx // self.num_seq]
            seq_id = idx % self.num_seq
            data_path = osp.join(self.ROOT, scene_id)
            num_files = len([name for name in os.listdir(data_path) if "color" in name])
            img_idxs = [f"{i:06d}" for i in range(num_files)]
            img_idxs = img_idxs[:: self.kf_every]
            if self.max_frames is not None:
                img_idxs = img_idxs[:self.max_frames]
        # Intrinsics used in SimpleRecon
        fx, fy, cx, cy = 525, 525, 320, 240
        intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
        views = []
        imgs_idxs = deque(img_idxs)
        if self.shuffle_seed >= 0:
            # BUGFIX: pass the configured seed through — it was previously
            # dropped, which made `shuffle_seed` a non-reproducible no-op.
            imgs_idxs = shuffle_deque(imgs_idxs, seed=self.shuffle_seed)
        while len(imgs_idxs) > 0:
            im_idx = imgs_idxs.popleft()
            impath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.color.png")
            depthpath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.depth.proj.png")
            posepath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.pose.txt")
            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
            rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0]))
            # 65535 marks missing depth; convert mm -> m and clamp the range
            depthmap[depthmap == 65535] = 0
            depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0
            depthmap[depthmap > 10] = 0
            depthmap[depthmap < 1e-3] = 0
            camera_pose = np.loadtxt(posepath).astype(np.float32)
            if resolution != (224, 224) or self.rebuttal:
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath
                )
            else:
                # 224x224 path: resize to 512x384 first, then center-crop 224
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath
                )
                W, H = rgb_image.size
                cx = W // 2
                cy = H // 2
                l, t = cx - 112, cy - 112
                r, b = cx + 112, cy + 112
                crop_bbox = (l, t, r, b)
                rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap(
                    rgb_image, depthmap, intrinsics, crop_bbox
                )
            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,
                    camera_intrinsics=intrinsics,
                    dataset="7scenes",
                    label=osp.join(scene_id, im_idx),
                    instance=impath,
                )
            )
        return views
class DTU(BaseStereoViewDataset):
    """DTU multi-view stereo evaluation dataset.

    Loads images, depth maps, binary object masks and MVSNet-style camera
    files from ``ROOT/<scene>/{images,depths,binary_masks,cams}`` and yields
    a list of per-frame view dicts for reconstruction evaluation.
    """

    def __init__(
        self,
        num_seq=49,
        num_frames=5,
        min_thresh=10,
        max_thresh=30,
        test_id=None,
        full_video=False,
        sample_pairs=False,
        kf_every=1,
        *args,
        ROOT,
        **kwargs,
    ):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        self.num_seq = num_seq  # number of sequences sampled per scene
        self.num_frames = num_frames
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh
        self.test_id = test_id  # restrict to one scene (str) or several (list)
        self.full_video = full_video  # True: use every kf_every-th frame
        self.kf_every = kf_every  # keyframe stride
        # NOTE(review): this boolean shadows the `self.sample_pairs(...)` call
        # made in `_get_views` when full_video is False — confirm a callable
        # `sample_pairs` is provided elsewhere (e.g. by the base class) before
        # using full_video=False.
        self.sample_pairs = sample_pairs

        # load all scenes
        self.load_all_scenes(ROOT)

    def __len__(self):
        # One dataset item per (scene, sequence) pair.
        return len(self.scene_list) * self.num_seq

    def load_all_scenes(self, base_dir):
        """Populate ``self.scene_list`` from base_dir or from ``self.test_id``."""
        if self.test_id is None:
            self.scene_list = os.listdir(osp.join(base_dir))
            print(f"Found {len(self.scene_list)} scenes in split {self.split}")
        else:
            if isinstance(self.test_id, list):
                self.scene_list = self.test_id
            else:
                self.scene_list = [self.test_id]
            print(f"Test_id: {self.test_id}")

    def load_cam_mvsnet(self, file, interval_scale=1):
        """read camera txt file

        Parses an MVSNet-format camera file. Returns ``(intrinsic, extrinsic)``
        where `intrinsic` is a 4x4 block whose top-left 3x3 is K and whose last
        row stores depth sampling parameters (per the MVSNet camera convention:
        depth min, interval, plane count, depth max), and `extrinsic` is the
        4x4 world-to-camera matrix.
        """
        cam = np.zeros((2, 4, 4))
        words = file.read().split()
        # read extrinsic
        for i in range(0, 4):
            for j in range(0, 4):
                extrinsic_index = 4 * i + j + 1
                cam[0][i][j] = words[extrinsic_index]

        # read intrinsic
        for i in range(0, 3):
            for j in range(0, 3):
                intrinsic_index = 3 * i + j + 18
                cam[1][i][j] = words[intrinsic_index]

        # Optional depth-range trailer: 2, 3 or 4 numbers depending on variant.
        if len(words) == 29:
            cam[1][3][0] = words[27]
            cam[1][3][1] = float(words[28]) * interval_scale
            cam[1][3][2] = 192
            cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * cam[1][3][2]
        elif len(words) == 30:
            cam[1][3][0] = words[27]
            cam[1][3][1] = float(words[28]) * interval_scale
            cam[1][3][2] = words[29]
            cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * cam[1][3][2]
        elif len(words) == 31:
            cam[1][3][0] = words[27]
            cam[1][3][1] = float(words[28]) * interval_scale
            cam[1][3][2] = words[29]
            cam[1][3][3] = words[30]
        else:
            cam[1][3][0] = 0
            cam[1][3][1] = 0
            cam[1][3][2] = 0
            cam[1][3][3] = 0

        extrinsic = cam[0].astype(np.float32)
        intrinsic = cam[1].astype(np.float32)

        return intrinsic, extrinsic

    def _get_views(self, idx, resolution, rng):
        """Build the list of per-frame view dicts for dataset item ``idx``."""
        scene_id = self.scene_list[idx // self.num_seq]
        seq_id = idx % self.num_seq

        print("Scene ID:", scene_id)
        image_path = osp.join(self.ROOT, scene_id, "images")
        depth_path = osp.join(self.ROOT, scene_id, "depths")
        mask_path = osp.join(self.ROOT, scene_id, "binary_masks")
        cam_path = osp.join(self.ROOT, scene_id, "cams")
        pairs_path = osp.join(self.ROOT, scene_id, "pair.txt")

        if not self.full_video:
            # NOTE(review): `self.sample_pairs` was assigned a bool in
            # __init__; this call only works if a callable with this name is
            # injected elsewhere — verify before relying on this branch.
            img_idxs = self.sample_pairs(pairs_path, seq_id)
        else:
            img_idxs = sorted(os.listdir(image_path))
            img_idxs = img_idxs[:: self.kf_every]

        views = []
        imgs_idxs = deque(img_idxs)
        # Frames are consumed back-to-front (pop from the right).
        while len(imgs_idxs) > 0:
            im_idx = imgs_idxs.pop()

            impath = osp.join(image_path, im_idx)
            depthpath = osp.join(depth_path, im_idx.replace(".jpg", ".npy"))
            campath = osp.join(cam_path, im_idx.replace(".jpg", "_cam.txt"))
            maskpath = osp.join(mask_path, im_idx.replace(".jpg", ".png"))

            rgb_image = imread_cv2(impath)
            depthmap = np.load(depthpath)
            depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0)

            # Binarize the object mask and resize it to the depth resolution.
            mask = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED) / 255.0
            mask = mask.astype(np.float32)
            mask[mask > 0.5] = 1.0
            mask[mask < 0.5] = 0.0
            mask = cv2.resize(
                mask,
                (depthmap.shape[1], depthmap.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )
            kernel = np.ones((10, 10), np.uint8)  # Define the erosion kernel
            # Erode the mask to drop unreliable depth near object boundaries.
            mask = cv2.erode(mask, kernel, iterations=1)
            depthmap = depthmap * mask

            cur_intrinsics, camera_pose = self.load_cam_mvsnet(open(campath, "r"))
            intrinsics = cur_intrinsics[:3, :3]
            # world-to-cam -> cam-to-world
            camera_pose = np.linalg.inv(camera_pose)

            if resolution != (224, 224):
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath
                )
            else:
                # For 224x224: first resize to 512x384, then take a centered
                # 224x224 crop.
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics, (512, 384), rng=rng, info=impath
                )
                W, H = rgb_image.size
                cx = W // 2
                cy = H // 2
                l, t = cx - 112, cy - 112
                r, b = cx + 112, cy + 112
                crop_bbox = (l, t, r, b)
                rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap(
                    rgb_image, depthmap, intrinsics, crop_bbox
                )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,
                    camera_intrinsics=intrinsics,
                    dataset="dtu",
                    label=osp.join(scene_id, im_idx),
                    instance=impath,
                )
            )

        return views
class NRGBD(BaseStereoViewDataset):
    """Neural RGB-D evaluation dataset.

    Loads ``img<i>.png`` / ``depth<i>.png`` frames and a single ``poses.txt``
    per scene from ``ROOT/<scene>/{images,depth}`` and yields per-frame view
    dicts for reconstruction evaluation.
    """

    def __init__(
        self,
        num_seq=1,
        num_frames=5,
        min_thresh=10,
        max_thresh=100,
        test_id=None,
        full_video=False,
        tuple_list=None,
        seq_id=None,
        rebuttal=False,
        shuffle_seed=-1,
        kf_every=1,
        max_frames=None,
        *args,
        ROOT,
        **kwargs,
    ):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        self.num_seq = num_seq
        self.num_frames = num_frames
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh
        self.test_id = test_id  # restrict to a single scene if set
        self.full_video = full_video
        self.kf_every = kf_every  # keyframe stride
        self.seq_id = seq_id
        self.rebuttal = rebuttal  # forces resolution-based cropping path
        self.shuffle_seed = shuffle_seed  # >= 0 -> shuffle frame order
        self.max_frames = max_frames  # cap on frames per sequence
        # load all scenes
        self.load_all_tuples(tuple_list)
        self.load_all_scenes(ROOT)

    def __len__(self):
        # Explicit tuples override the (scene x sequence) enumeration.
        if self.tuple_list is not None:
            return len(self.tuple_list)
        return len(self.scene_list) * self.num_seq

    def load_all_tuples(self, tuple_list):
        """Store an optional explicit list of "scene idx idx ..." strings."""
        if tuple_list is not None:
            self.tuple_list = tuple_list
            # with open(tuple_path) as f:
            #     self.tuple_list = f.read().splitlines()
        else:
            self.tuple_list = None

    def load_all_scenes(self, base_dir):
        """Populate ``self.scene_list`` with scene sub-directories (or test_id)."""
        scenes = [
            d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
        ]
        if self.test_id is not None:
            self.scene_list = [self.test_id]
        else:
            self.scene_list = scenes
        print(f"Found {len(self.scene_list)} sequences in split {self.split}")

    def load_poses(self, path):
        """Read 4x4 poses (4 text lines per matrix); return (poses, valid).

        Poses containing NaNs are replaced by identity and flagged invalid.
        NOTE(review): only the first line of each matrix is checked for "nan".
        """
        file = open(path, "r")
        lines = file.readlines()
        file.close()
        poses = []
        valid = []
        lines_per_matrix = 4
        for i in range(0, len(lines), lines_per_matrix):
            if "nan" in lines[i]:
                valid.append(False)
                poses.append(np.eye(4, 4, dtype=np.float32).tolist())
            else:
                valid.append(True)
                pose_floats = [
                    [float(x) for x in line.split()]
                    for line in lines[i : i + lines_per_matrix]
                ]
                poses.append(pose_floats)

        return np.array(poses, dtype=np.float32), valid

    def _get_views(self, idx, resolution, rng):
        """Build the list of per-frame view dicts for dataset item ``idx``."""
        if self.tuple_list is not None:
            line = self.tuple_list[idx].split(" ")
            scene_id = line[0]
            img_idxs = line[1:]
        else:
            scene_id = self.scene_list[idx // self.num_seq]
            num_files = len(os.listdir(os.path.join(self.ROOT, scene_id, "images")))
            img_idxs = [f"{i}" for i in range(num_files)]
            # NOTE(review): if the scene has < 2 frames the slice step becomes
            # 0 and this raises ValueError — verify inputs upstream.
            img_idxs = img_idxs[:: min(self.kf_every, len(img_idxs) // 2)]
            if self.max_frames is not None:
                img_idxs = img_idxs[:self.max_frames]

        # Hard-coded pinhole intrinsics used for this dataset.
        fx, fy, cx, cy = 554.2562584220408, 554.2562584220408, 320, 240
        intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
        posepath = osp.join(self.ROOT, scene_id, f"poses.txt")
        camera_poses, valids = self.load_poses(posepath)

        imgs_idxs = deque(img_idxs)
        if self.shuffle_seed >= 0:
            imgs_idxs = shuffle_deque(imgs_idxs)
        views = []

        while len(imgs_idxs) > 0:
            im_idx = imgs_idxs.popleft()
            impath = osp.join(self.ROOT, scene_id, "images", f"img{im_idx}.png")
            depthpath = osp.join(self.ROOT, scene_id, "depth", f"depth{im_idx}.png")

            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
            # Millimeters -> meters; clamp implausible values to 0 (invalid).
            depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0
            depthmap[depthmap > 10] = 0
            depthmap[depthmap < 1e-3] = 0

            rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0]))

            camera_pose = camera_poses[int(im_idx)]
            # gl to cv
            camera_pose[:, 1:3] *= -1.0

            if resolution != (224, 224) or self.rebuttal:
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath
                )
            else:
                # For 224x224: first resize to 512x384, then take a centered
                # 224x224 crop.
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath
                )
                W, H = rgb_image.size
                cx = W // 2
                cy = H // 2
                l, t = cx - 112, cy - 112
                r, b = cx + 112, cy + 112
                crop_bbox = (l, t, r, b)
                rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap(
                    rgb_image, depthmap, intrinsics, crop_bbox
                )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,
                    camera_intrinsics=intrinsics,
                    dataset="nrgbd",
                    label=osp.join(scene_id, im_idx),
                    instance=impath,
                )
            )

        return views
```
## /eval/mv_recon/dataset_utils/__init__.py
```py path="/eval/mv_recon/dataset_utils/__init__.py"
```
## /eval/mv_recon/dataset_utils/corr.py
```py path="/eval/mv_recon/dataset_utils/corr.py"
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
import numpy as np
from dust3r.utils.device import to_numpy
from dust3r.utils.geometry import inv, geotrf
def reproject_view(pts3d, view2):
    """Project the point map `pts3d` into the image plane of `view2`."""
    target_shape = view2["pts3d"].shape[:2]
    world2cam2 = inv(view2["camera_pose"])
    return reproject(pts3d, view2["camera_intrinsics"], world2cam2, target_shape)
def reproject(pts3d, K, world2cam, shape):
    """Project an (H, W, 3) point map through K @ world2cam.

    Returns ((H, W), flat_indices) where flat_indices index into an image of
    size `shape` (see ravel_xy).
    """
    H, W, C = pts3d.shape
    assert C == 3
    proj = K @ world2cam[:3]
    with np.errstate(divide="ignore", invalid="ignore"):
        pix = geotrf(proj, pts3d, norm=1, ncol=2)
    return (H, W), ravel_xy(pix, shape)
def ravel_xy(pos, shape):
    """Round (x, y) positions and clamp them into an (H, W) image, returning
    flat indices x + W * y as int32."""
    H, W = shape
    with np.errstate(invalid="ignore"):
        xy = pos.reshape(-1, 2).round().astype(np.int32)
    x_idx = np.clip(xy[:, 0], 0, W - 1)
    y_idx = np.clip(xy[:, 1], 0, H - 1)
    return x_idx + W * y_idx
def unravel_xy(pos, shape):
    """Convert flat indices back to (x, y) pixel coordinates.

    NOTE(review): relies on `np.unravel_index` returning views of one shared
    (ndim, N) base array; `.base[:, ::-1]` then presumably swaps (row, col)
    into (x, y) order — confirm against the numpy version in use before
    modifying.
    """
    return np.unravel_index(pos, shape)[0].base[:, ::-1].copy()
def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
    """Keep only mutually consistent (reciprocal) matches.

    A match i -> corres_1_to_2[i] is reciprocal when mapping back through
    corres_2_to_1 returns i. Returns (pos1, pos2), optionally preceded by the
    full boolean reciprocity mask when ret_recip is True.
    """
    ids = np.arange(len(corres_1_to_2))
    recip_mask = corres_2_to_1[corres_1_to_2] == ids
    idx1 = np.flatnonzero(recip_mask)
    idx2 = corres_1_to_2[idx1]
    if ret_recip:
        return recip_mask, idx1, idx2
    return idx1, idx2
def extract_correspondences_from_pts3d(
    view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0
):
    """Extract correspondences between two views from their 3D point maps.

    Cross-projects each view's "pts3d" into the other view and keeps mutually
    consistent (reciprocal) matches as positives; optionally pads the result
    with a fraction `nneg` of non-reciprocal (negative) samples.

    Args:
        view1, view2: view dicts with "pts3d", "camera_intrinsics", "camera_pose".
        target_n_corres: total number of correspondences to return, or None to
            return every reciprocal match (without a validity mask).
        rng: numpy random generator/module used for subsampling.
        ret_xy: if True, return (x, y) pixel coordinates instead of flat indices.
        nneg: fraction of the output reserved for negative samples.

    Returns:
        (pos1, pos2) when target_n_corres is None, else (pos1, pos2, valid)
        where `valid` flags the positive (reciprocal) pairs.
    """
    view1, view2 = to_numpy((view1, view2))

    shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2)
    shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1)

    # Reciprocal matches in both directions.
    is_reciprocal1, pos1, pos2 = reciprocal_1d(
        corres1_to_2, corres2_to_1, ret_recip=True
    )
    is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1))

    if target_n_corres is None:
        if ret_xy:
            pos1 = unravel_xy(pos1, shape1)
            pos2 = unravel_xy(pos2, shape2)
        return pos1, pos2

    available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum())
    target_n_positives = int(target_n_corres * (1 - nneg))
    n_positives = min(len(pos1), target_n_positives)
    n_negatives = min(target_n_corres - n_positives, available_negatives)

    if n_negatives + n_positives != target_n_corres:
        # Not enough negatives available: top up with extra positives.
        n_positives = target_n_corres - n_negatives
        assert n_positives <= len(pos1)

    # (the original code asserted `n_positives <= len(pos1)` twice; once is enough)
    assert n_positives <= len(pos2)
    assert n_negatives <= (~is_reciprocal1).sum()
    assert n_negatives <= (~is_reciprocal2).sum()
    assert n_positives + n_negatives == target_n_corres

    valid = np.ones(n_positives, dtype=bool)
    if n_positives < len(pos1):
        # Random sub-sampling of the reciprocal correspondences.
        perm = rng.permutation(len(pos1))[:n_positives]
        pos1 = pos1[perm]
        pos2 = pos2[perm]

    if n_negatives > 0:
        # Draw wrong correspondences uniformly from non-reciprocal pixels.
        def norm(p):
            return p / p.sum()

        pos1 = np.r_[
            pos1,
            rng.choice(
                shape1[0] * shape1[1],
                size=n_negatives,
                replace=False,
                p=norm(~is_reciprocal1),
            ),
        ]
        pos2 = np.r_[
            pos2,
            rng.choice(
                shape2[0] * shape2[1],
                size=n_negatives,
                replace=False,
                p=norm(~is_reciprocal2),
            ),
        ]
        valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]

    if ret_xy:
        pos1 = unravel_xy(pos1, shape1)
        pos2 = unravel_xy(pos2, shape2)
    return pos1, pos2, valid
```
## /eval/mv_recon/dataset_utils/cropping.py
```py path="/eval/mv_recon/dataset_utils/cropping.py"
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
import PIL.Image
import os
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import cv2 # noqa
import numpy as np # noqa
from dust3r.utils.geometry import (
colmap_to_opencv_intrinsics,
opencv_to_colmap_intrinsics,
) # noqa
# Pillow >= 9.1 moved the resampling filters under PIL.Image.Resampling;
# fall back to the old top-level constants on older Pillow versions.
try:
    lanczos = PIL.Image.Resampling.LANCZOS
    bicubic = PIL.Image.Resampling.BICUBIC
except AttributeError:
    lanczos = PIL.Image.LANCZOS
    bicubic = PIL.Image.BICUBIC
class ImageList:
    """Convenience class to aply the same operation to a whole set of images."""

    def __init__(self, images):
        # Accept a single image or any tuple/list/set of images; anything that
        # is not already a PIL image is converted via fromarray.
        if not isinstance(images, (tuple, list, set)):
            images = [images]
        self.images = [
            img if isinstance(img, PIL.Image.Image) else PIL.Image.fromarray(img)
            for img in images
        ]

    def __len__(self):
        return len(self.images)

    def to_pil(self):
        # Single image unwraps to a plain PIL image; otherwise a tuple.
        if len(self.images) > 1:
            return tuple(self.images)
        return self.images[0]

    @property
    def size(self):
        # All images must share the same (W, H).
        sizes = [im.size for im in self.images]
        assert all(s == sizes[0] for s in sizes)
        return sizes[0]

    def resize(self, *args, **kwargs):
        return ImageList(self._dispatch("resize", *args, **kwargs))

    def crop(self, *args, **kwargs):
        return ImageList(self._dispatch("crop", *args, **kwargs))

    def _dispatch(self, func, *args, **kwargs):
        # Apply the named PIL method to every image with the same arguments.
        return [getattr(im, func)(*args, **kwargs) for im in self.images]
def rescale_image_depthmap(
    image, depthmap, camera_intrinsics, output_resolution, force=True
):
    """Jointly rescale a (image, depthmap)
    so that (out_width, out_height) >= output_res

    Scales image, depthmap and intrinsics by one common factor, chosen so the
    smaller relative side reaches the requested size. When `force` is False
    and the image is already smaller than requested, inputs are returned
    unchanged.
    """
    image = ImageList(image)
    input_resolution = np.array(image.size)  # (W,H)
    output_resolution = np.array(output_resolution)
    if depthmap is not None:
        # Depthmap is (H, W); it must match the image resolution.
        assert tuple(depthmap.shape[:2]) == image.size[::-1]
    assert output_resolution.shape == (2,)
    # Tiny epsilon presumably keeps np.floor below from undershooting the
    # requested size — TODO confirm intent.
    scale_final = max(output_resolution / image.size) + 1e-8
    if scale_final >= 1 and not force:  # image is already smaller than what is asked
        return (image.to_pil(), depthmap, camera_intrinsics)
    output_resolution = np.floor(input_resolution * scale_final).astype(int)

    # Lanczos for downscaling, bicubic for upscaling.
    image = image.resize(
        output_resolution, resample=lanczos if scale_final < 1 else bicubic
    )
    if depthmap is not None:
        # Nearest-neighbour so depth values are never blended across pixels.
        depthmap = cv2.resize(
            depthmap,
            output_resolution,
            fx=scale_final,
            fy=scale_final,
            interpolation=cv2.INTER_NEAREST,
        )

    # Rescale the intrinsics accordingly (no crop offset here).
    camera_intrinsics = camera_matrix_of_crop(
        camera_intrinsics, input_resolution, output_resolution, scaling=scale_final
    )

    return image.to_pil(), depthmap, camera_intrinsics
def camera_matrix_of_crop(
    input_camera_matrix,
    input_resolution,
    output_resolution,
    scaling=1,
    offset_factor=0.5,
    offset=None,
):
    """Intrinsics of a view obtained by scaling the input then cropping
    `output_resolution` pixels at `offset` (centered crop by default)."""
    # Pixels left over after scaling; the crop must fit inside the image.
    margins = np.asarray(input_resolution) * scaling - output_resolution
    assert np.all(margins >= 0.0)
    if offset is None:
        offset = offset_factor * margins

    # Scale and shift the principal point in COLMAP convention, then convert back.
    colmap_K = opencv_to_colmap_intrinsics(input_camera_matrix)
    colmap_K[:2, :] *= scaling
    colmap_K[:2, 2] -= offset
    return colmap_to_opencv_intrinsics(colmap_K)
def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
    """
    Return a crop of the input view.

    Crops image and depthmap to the (l, t, r, b) bbox and shifts the
    principal point of the intrinsics to the crop origin.
    """
    left, top, right, bottom = crop_bbox

    cropped = ImageList(image).crop((left, top, right, bottom))
    depthmap = depthmap[top:bottom, left:right]

    K = camera_intrinsics.copy()
    K[0, 2] -= left
    K[1, 2] -= top

    return cropped.to_pil(), depthmap, K
def bbox_from_intrinsics_in_out(
    input_camera_matrix, output_camera_matrix, output_resolution
):
    """Recover the (l, t, r, b) crop bbox that maps the input intrinsics onto
    the output intrinsics for a crop of size `output_resolution`."""
    out_width, out_height = output_resolution
    # The crop origin is the shift of the principal point, rounded to pixels.
    shift = input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]
    left, top = np.int32(np.round(shift))
    return (left, top, left + out_width, top + out_height)
```
## /eval/mv_recon/dataset_utils/transforms.py
```py path="/eval/mv_recon/dataset_utils/transforms.py"
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
import torchvision.transforms as tvf
from dust3r.utils.image import ImgNorm
# Per-image random color jitter followed by the standard image normalization.
ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
if isinstance(value, (int, float)):
if value < 0:
raise ValueError(f"If is a single number, it must be non negative.")
value = [center - float(value), center + float(value)]
if clip_first_on_zero:
value[0] = max(value[0], 0.0)
elif isinstance(value, (tuple, list)) and len(value) == 2:
value = [float(value[0]), float(value[1])]
else:
raise TypeError(f"should be a single number or a list/tuple with length 2.")
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"values should be between {bound}, but got {value}.")
if value[0] == value[1] == center:
return None
else:
return tuple(value)
import torch
import torchvision.transforms.functional as F
def SeqColorJitter():
    """
    Return a color jitter transform with same random parameters

    Draws the four jitter factors (brightness/contrast/saturation/hue) and a
    random application order once, then returns a closure that applies exactly
    that jitter followed by ImgNorm — so every image of a sequence receives
    the identical augmentation.
    """
    brightness = _check_input(0.5)
    contrast = _check_input(0.5)
    saturation = _check_input(0.5)
    hue = _check_input(0.1, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    # Random order in which the four adjustments are applied.
    fn_idx = torch.randperm(4)
    brightness_factor = (
        None
        if brightness is None
        else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
    )
    contrast_factor = (
        None
        if contrast is None
        else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
    )
    saturation_factor = (
        None
        if saturation is None
        else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
    )
    hue_factor = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

    def _color_jitter(img):
        # Apply the frozen factors in the frozen order, then normalize.
        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                img = F.adjust_brightness(img, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                img = F.adjust_contrast(img, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                img = F.adjust_saturation(img, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                img = F.adjust_hue(img, hue_factor)
        return ImgNorm(img)

    return _color_jitter
```
## /eval/mv_recon/launch.py
```py path="/eval/mv_recon/launch.py"
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import time
import torch
import argparse
import numpy as np
import open3d as o3d
import os.path as osp
from torch.utils.data import DataLoader
from add_ckpt_path import add_path_to_dust3r
from accelerate import Accelerator
from torch.utils.data._utils.collate import default_collate
import tempfile
from tqdm import tqdm
def get_args_parser():
    """Build the CLI argument parser for the 3D reconstruction evaluation."""
    parser = argparse.ArgumentParser("3D Reconstruction evaluation", add_help=False)
    # Model / checkpoint selection.
    parser.add_argument("--weights", type=str, default="", help="ckpt name")
    parser.add_argument("--device", type=str, default="cuda:0", help="device")
    parser.add_argument("--model_name", type=str, default="")
    parser.add_argument("--conf_thresh", type=float, default=0.0, help="confidence threshold")
    # Output location.
    parser.add_argument("--output_dir", type=str, default="", help="value for outdir")
    # Evaluation settings.
    parser.add_argument("--size", type=int, default=512)
    parser.add_argument("--revisit", type=int, default=1, help="revisit times")
    parser.add_argument("--freeze", action="store_true")
    parser.add_argument("--max_frames", type=int, default=None, help="max frames limit")
    parser.add_argument("--model_update_type", type=str, default="cut3r", help="model update type")
    parser.add_argument("--voxel_size", type=float, default=0.0, help="voxel size for voxel grid downsampling, 0 means no downsampling")
    return parser
def main(args):
    """Run the multi-view reconstruction evaluation.

    Loads the checkpoint, iterates over the evaluation datasets, reconstructs
    each scene with the model, aligns predictions to ground truth (scale/shift
    then ICP), and logs accuracy / completion / normal-consistency metrics per
    process; the main process finally merges all per-process logs and writes a
    mean summary to `logs_all.txt`.
    """
    # Make the dust3r package bundled next to the checkpoint importable
    # before importing anything from it below.
    add_path_to_dust3r(args.weights)
    from eval.mv_recon.data import SevenScenes, NRGBD
    from eval.mv_recon.utils import accuracy, completion

    if args.size == 512:
        resolution = (512, 384)
    elif args.size == 224:
        resolution = 224
    else:
        raise NotImplementedError
    datasets_all = {
        "7scenes": SevenScenes(
            split="test",
            ROOT="/home/share/Dataset/3D_scene/7scenes/",  # "./data/7scenes",
            resolution=resolution,
            num_seq=1,
            full_video=True,
            kf_every=2,
            max_frames=args.max_frames,
        )
    }

    # ====== print the number of views for each scene ======
    print("\n=== number of views for each scene ===")
    for name_data, dataset in datasets_all.items():
        print(f"\n{name_data} dataset:")
        for scene_id in dataset.scene_list:
            if name_data == "NRGBD":
                # NRGBD dataset file structure
                data_path = osp.join(dataset.ROOT, scene_id, "images")
                num_files = len([name for name in os.listdir(data_path) if name.endswith('.png')])
                view_count = len([f"{i}" for i in range(num_files)][::dataset.kf_every])
            else:
                # SevenScenes dataset file structure
                data_path = osp.join(dataset.ROOT, scene_id)
                num_files = len([name for name in os.listdir(data_path) if "color" in name])
                view_count = len([f"{i:06d}" for i in range(num_files)][::dataset.kf_every])
            # consider max_frames limit
            if dataset.max_frames is not None:
                actual_view_count = min(view_count, dataset.max_frames)
                print(f" {scene_id}: {actual_view_count} views (original: {view_count}, limit: {dataset.max_frames})")
            else:
                print(f" {scene_id}: {view_count} views")
    print("================================\n")
    # ====== print end ======

    accelerator = Accelerator()
    device = accelerator.device
    model_name = args.model_name

    # if model_name == "ours" or model_name == "cut3r":
    from dust3r.model import ARCroco3DStereo
    from eval.mv_recon.criterion import Regr3D_t_ScaleShiftInv, L21
    from dust3r.utils.geometry import geotrf
    from copy import deepcopy

    model = ARCroco3DStereo.from_pretrained(args.weights).to(device)
    model.config.model_update_type = args.model_update_type
    model.eval()
    # else:
    #     raise NotImplementedError

    os.makedirs(args.output_dir, exist_ok=True)
    criterion = Regr3D_t_ScaleShiftInv(L21, norm_mode=False, gt_scale=True)

    with torch.no_grad():
        for name_data, dataset in datasets_all.items():
            save_path = osp.join(args.output_dir, name_data)
            os.makedirs(save_path, exist_ok=True)
            # Each process writes its own log; rank 0 merges them at the end.
            log_file = osp.join(save_path, f"logs_{accelerator.process_index}.txt")

            # Running sums of the per-scene metrics (not re-used after the
            # loop here, kept for parity with other launchers).
            acc_all = 0
            acc_all_med = 0
            comp_all = 0
            comp_all_med = 0
            nc1_all = 0
            nc1_all_med = 0
            nc2_all = 0
            nc2_all_med = 0

            fps_all = []
            time_all = []

            with accelerator.split_between_processes(list(range(len(dataset)))) as idxs:
                for data_idx in tqdm(idxs):
                    batch = default_collate([dataset[data_idx]])
                    # Keys kept on CPU (metadata / GT-only fields).
                    ignore_keys = set(
                        [
                            "depthmap",
                            "dataset",
                            "label",
                            "instance",
                            "idx",
                            "true_shape",
                            "rng",
                        ]
                    )
                    # Move every remaining tensor (or list of tensors) to GPU.
                    for view in batch:
                        for name in view.keys():  # pseudo_focal
                            if name in ignore_keys:
                                continue
                            if isinstance(view[name], tuple) or isinstance(
                                view[name], list
                            ):
                                view[name] = [
                                    x.to(device, non_blocking=True) for x in view[name]
                                ]
                            else:
                                view[name] = view[name].to(device, non_blocking=True)

                    # if model_name == "ours" or model_name == "cut3r":
                    revisit = args.revisit
                    update = not args.freeze
                    if revisit > 1:
                        # repeat input for 'revisit' times
                        new_views = []
                        for r in range(revisit):
                            for i in range(len(batch)):
                                new_view = deepcopy(batch[i])
                                new_view["idx"] = [
                                    (r * len(batch) + i)
                                    for _ in range(len(batch[i]["idx"]))
                                ]
                                new_view["instance"] = [
                                    str(r * len(batch) + i)
                                    for _ in range(len(batch[i]["instance"]))
                                ]
                                if r > 0:
                                    if not update:
                                        # Freeze state updates on revisit passes.
                                        new_view["update"] = torch.zeros_like(
                                            batch[i]["update"]
                                        ).bool()
                                new_views.append(new_view)
                        batch = new_views

                    with torch.cuda.amp.autocast(enabled=False):
                        start = time.time()
                        output = model(batch)
                        # preds, batch = model.forward_recurrent_light(batch)
                        end = time.time()
                    preds, batch = output.ress, output.views
                    # Only the final pass over the sequence is scored.
                    valid_length = len(preds) // revisit
                    preds = preds[-valid_length:]
                    batch = batch[-valid_length:]

                    fps = len(batch) / (end - start)
                    print(
                        f"Finished reconstruction for {name_data} {data_idx+1}/{len(dataset)}, FPS: {fps:.2f}"
                    )
                    # continue
                    fps_all.append(fps)
                    time_all.append(end - start)

                    # Evaluation
                    print(f"Evaluation for {name_data} {data_idx+1}/{len(dataset)}")
                    gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
                        criterion.get_all_pts3d_t(batch, preds)
                    )
                    pred_scale, gt_scale, pred_shift_z, gt_shift_z = (
                        monitoring["pred_scale"],
                        monitoring["gt_scale"],
                        monitoring["pred_shift_z"],
                        monitoring["gt_shift_z"],
                    )

                    in_camera1 = None
                    pts_all = []
                    pts_gt_all = []
                    images_all = []
                    masks_all = []
                    conf_all = []
                    for j, view in enumerate(batch):
                        # First view's pose defines the common reference frame.
                        if in_camera1 is None:
                            in_camera1 = view["camera_pose"][0].cpu()

                        image = view["img"].permute(0, 2, 3, 1).cpu().numpy()[0]
                        mask = view["valid_mask"].cpu().numpy()[0]

                        # pts = preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'].detach().cpu().numpy()[0]
                        pts = pred_pts[j].cpu().numpy()[0]
                        conf = preds[j]["conf"].cpu().data.numpy()[0]
                        # mask = mask & (conf > 1.8)
                        pts_gt = gt_pts[j].detach().cpu().numpy()[0]

                        # Evaluate on a fixed centered 224x224 crop of each view.
                        H, W = image.shape[:2]
                        cx = W // 2
                        cy = H // 2
                        l, t = cx - 112, cy - 112
                        r, b = cx + 112, cy + 112
                        image = image[t:b, l:r]
                        mask = mask[t:b, l:r]
                        pts = pts[t:b, l:r]
                        pts_gt = pts_gt[t:b, l:r]

                        #### Align predicted 3D points to the ground truth
                        pts[..., -1] += gt_shift_z.cpu().numpy().item()
                        pts = geotrf(in_camera1, pts)

                        pts_gt[..., -1] += gt_shift_z.cpu().numpy().item()
                        pts_gt = geotrf(in_camera1, pts_gt)

                        # Images are in [-1, 1]; store as [0, 1] colors.
                        images_all.append((image[None, ...] + 1.0) / 2.0)
                        pts_all.append(pts[None, ...])
                        pts_gt_all.append(pts_gt[None, ...])
                        masks_all.append(mask[None, ...])
                        conf_all.append(conf[None, ...])

                    images_all = np.concatenate(images_all, axis=0)
                    pts_all = np.concatenate(pts_all, axis=0)
                    pts_gt_all = np.concatenate(pts_gt_all, axis=0)
                    masks_all = np.concatenate(masks_all, axis=0)

                    # `view` is the last view from the loop above; all views of
                    # a sample share the same scene prefix in their label.
                    scene_id = view["label"][0].rsplit("/", 1)[0]

                    save_params = {}
                    save_params["images_all"] = images_all
                    save_params["pts_all"] = pts_all
                    save_params["pts_gt_all"] = pts_gt_all
                    save_params["masks_all"] = masks_all

                    np.save(
                        os.path.join(save_path, f"{scene_id.replace('/', '_')}.npy"),
                        save_params,
                    )

                    # ICP distance threshold (DTU uses different units).
                    if "DTU" in name_data:
                        threshold = 100
                    else:
                        threshold = 0.1

                    # Keep only valid (masked) points for the point clouds.
                    pts_all_masked = pts_all[masks_all > 0]
                    pts_gt_all_masked = pts_gt_all[masks_all > 0]
                    images_all_masked = images_all[masks_all > 0]

                    pcd = o3d.geometry.PointCloud()
                    pcd.points = o3d.utility.Vector3dVector(pts_all_masked.reshape(-1, 3))
                    pcd.colors = o3d.utility.Vector3dVector(images_all_masked.reshape(-1, 3))

                    pcd_gt = o3d.geometry.PointCloud()
                    pcd_gt.points = o3d.utility.Vector3dVector(pts_gt_all_masked.reshape(-1, 3))
                    pcd_gt.colors = o3d.utility.Vector3dVector(images_all_masked.reshape(-1, 3))

                    # ====== voxel grid downsampling ======
                    if args.voxel_size > 0:
                        pcd = pcd.voxel_down_sample(voxel_size=args.voxel_size)
                        pcd_gt = pcd_gt.voxel_down_sample(voxel_size=args.voxel_size)
                    # ===========================

                    o3d.io.write_point_cloud(
                        os.path.join(
                            save_path, f"{scene_id.replace('/', '_')}-mask.ply"
                        ),
                        pcd,
                    )
                    o3d.io.write_point_cloud(
                        os.path.join(save_path, f"{scene_id.replace('/', '_')}-gt.ply"),
                        pcd_gt,
                    )

                    # Refine the alignment with point-to-point ICP before scoring.
                    trans_init = np.eye(4)
                    reg_p2p = o3d.pipelines.registration.registration_icp(
                        pcd,
                        pcd_gt,
                        threshold,
                        trans_init,
                        o3d.pipelines.registration.TransformationEstimationPointToPoint(),
                    )
                    transformation = reg_p2p.transformation
                    pcd = pcd.transform(transformation)

                    pcd.estimate_normals()
                    pcd_gt.estimate_normals()

                    gt_normal = np.asarray(pcd_gt.normals)
                    pred_normal = np.asarray(pcd.normals)

                    acc, acc_med, nc1, nc1_med = accuracy(
                        pcd_gt.points, pcd.points, gt_normal, pred_normal
                    )
                    comp, comp_med, nc2, nc2_med = completion(
                        pcd_gt.points, pcd.points, gt_normal, pred_normal
                    )
                    print(
                        f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}"
                    )
                    print(
                        f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}",
                        file=open(log_file, "a"),
                    )

                    acc_all += acc
                    comp_all += comp
                    nc1_all += nc1
                    nc2_all += nc2

                    acc_all_med += acc_med
                    comp_all_med += comp_med
                    nc1_all_med += nc1_med
                    nc2_all_med += nc2_med

                    # release cuda memory
                    torch.cuda.empty_cache()

            accelerator.wait_for_everyone()
            # Get depth from pcd and run TSDFusion
            if accelerator.is_main_process:
                to_write = ""
                # Copy the error log from each process to the main error log
                for i in range(8):
                    if not os.path.exists(osp.join(save_path, f"logs_{i}.txt")):
                        break
                    with open(osp.join(save_path, f"logs_{i}.txt"), "r") as f_sub:
                        to_write += f_sub.read()

                with open(osp.join(save_path, f"logs_all.txt"), "w") as f:
                    log_data = to_write
                    # Parse every metric line (see module-level `regex`) and
                    # aggregate per-metric means.
                    metrics = defaultdict(list)
                    for line in log_data.strip().split("\n"):
                        match = regex.match(line)
                        if match:
                            data = match.groupdict()
                            # Exclude 'scene_id' from metrics as it's an identifier
                            for key, value in data.items():
                                if key != "scene_id":
                                    metrics[key].append(float(value))
                            metrics["nc"].append(
                                (float(data["nc1"]) + float(data["nc2"])) / 2
                            )
                            metrics["nc_med"].append(
                                (float(data["nc1_med"]) + float(data["nc2_med"])) / 2
                            )
                    mean_metrics = {
                        metric: sum(values) / len(values)
                        for metric, values in metrics.items()
                    }
                    c_name = "mean"
                    print_str = f"{c_name.ljust(20)}: "
                    for m_name in mean_metrics:
                        print_num = np.mean(mean_metrics[m_name])
                        print_str = print_str + f"{m_name}: {print_num:.3f} | "
                    print_str = print_str + "\n"
                    f.write(to_write + print_str)
from collections import defaultdict
import re
# Regex matching one per-scene metric line as printed by main(); the named
# groups feed the per-metric aggregation written to logs_all.txt.
pattern = r"""
Idx:\s*(?P<scene_id>[^,]+),\s*
Acc:\s*(?P<acc>[^,]+),\s*
Comp:\s*(?P<comp>[^,]+),\s*
NC1:\s*(?P<nc1>[^,]+),\s*
NC2:\s*(?P<nc2>[^,]+)\s*-\s*
Acc_med:\s*(?P<acc_med>[^,]+),\s*
Compc_med:\s*(?P<comp_med>[^,]+),\s*
NC1c_med:\s*(?P<nc1_med>[^,]+),\s*
NC2c_med:\s*(?P<nc2_med>[^,]+)
"""
regex = re.compile(pattern, re.VERBOSE)
# Script entry point: parse CLI arguments and run the evaluation.
if __name__ == "__main__":
    parser = get_args_parser()
    args = parser.parse_args()
    main(args)
```
## /eval/mv_recon/run.sh
```sh path="/eval/mv_recon/run.sh"
#!/bin/bash
# Evaluate multi-view reconstruction on 7-Scenes for each model variant.
set -e

workdir='.'
model_names=('ttt3r') # ttt3r cut3r
ckpt_name='cut3r_512_dpt_4_64'
model_weights="${workdir}/src/${ckpt_name}.pth"

for model_name in "${model_names[@]}"; do
    # for max_frames in 50 100 150 200 250 300 350 400
    for max_frames in 200
    do
        output_dir="${workdir}/eval_results/video_recon/7scenes_${max_frames}/${model_name}"
        echo "$output_dir"
        # NOTE: the last option must NOT end with a line continuation —
        # a trailing backslash here would swallow the `done` below and
        # break the loop (this was a bug in the previous version).
        NCCL_TIMEOUT=360000 accelerate launch --num_processes 1 --main_process_port 29502 eval/mv_recon/launch.py \
            --weights "$model_weights" \
            --output_dir "$output_dir" \
            --model_name "$model_name" \
            --model_update_type "$model_name" \
            --max_frames $max_frames
    done
done
```
## /eval/mv_recon/utils.py
```py path="/eval/mv_recon/utils.py"
import numpy as np
from scipy.spatial import cKDTree as KDTree
import torch
def completion_ratio(gt_points, rec_points, dist_th=0.05):
    """Fraction of GT points lying within `dist_th` of some reconstructed point."""
    rec_tree = KDTree(rec_points)
    dists, _ = rec_tree.query(gt_points)
    hits = (dists < dist_th).astype(np.float32)
    return np.mean(hits)
def accuracy(gt_points, rec_points, gt_normals=None, rec_normals=None):
    """Mean/median distance from each reconstructed point to its nearest GT
    point; with normals, also mean/median absolute normal agreement."""
    gt_tree = KDTree(gt_points)
    dists, nn_idx = gt_tree.query(rec_points, workers=-1)
    mean_dist = np.mean(dists)
    median_dist = np.median(dists)
    if gt_normals is None or rec_normals is None:
        return mean_dist, median_dist
    # |cos| between each reconstructed normal and its nearest GT normal.
    dots = np.abs(np.sum(gt_normals[nn_idx] * rec_normals, axis=-1))
    return mean_dist, median_dist, np.mean(dots), np.median(dots)
def completion(gt_points, rec_points, gt_normals=None, rec_normals=None):
    """Mean/median distance from each GT point to its nearest reconstructed
    point; with normals, also mean/median absolute normal agreement.

    Mirrors `accuracy` with the query direction reversed: the KD-tree is
    built over the *reconstructed* points and queried with the GT points.
    """
    # Renamed from the misleading `gt_points_kd_tree` — this tree indexes
    # the reconstructed points, not the ground truth.
    rec_points_kd_tree = KDTree(rec_points)
    distances, idx = rec_points_kd_tree.query(gt_points, workers=-1)
    comp = np.mean(distances)
    comp_median = np.median(distances)
    if gt_normals is not None and rec_normals is not None:
        normal_dot = np.sum(gt_normals * rec_normals[idx], axis=-1)
        normal_dot = np.abs(normal_dot)
        return comp, comp_median, np.mean(normal_dot), np.median(normal_dot)
    return comp, comp_median
def compute_iou(pred_vox, target_vox):
    """IoU between two open3d voxel grids over their occupied voxel indices."""
    # Get voxel indices
    v_pred_indices = [voxel.grid_index for voxel in pred_vox.get_voxels()]
    v_target_indices = [voxel.grid_index for voxel in target_vox.get_voxels()]

    # Convert to sets for set operations
    # (np.round looks like a no-op here if grid_index is integral — verify)
    v_pred_filled = set(tuple(np.round(x, 4)) for x in v_pred_indices)
    v_target_filled = set(tuple(np.round(x, 4)) for x in v_target_indices)

    # Compute intersection and union
    intersection = v_pred_filled & v_target_filled
    union = v_pred_filled | v_target_filled

    # Compute IoU
    # NOTE(review): raises ZeroDivisionError when both grids are empty.
    iou = len(intersection) / len(union)
    return iou
```
## /eval/relpose/evo_utils.py
```py path="/eval/relpose/evo_utils.py"
import os
import re
from copy import deepcopy
from pathlib import Path
import evo.main_ape as main_ape
import evo.main_rpe as main_rpe
import matplotlib.pyplot as plt
import numpy as np
from evo.core import sync
from evo.core.metrics import PoseRelation, Unit
from evo.core.trajectory import PosePath3D, PoseTrajectory3D
from evo.tools import file_interface, plot
from scipy.spatial.transform import Rotation
from evo.core import metrics
import json
def sintel_cam_read(filename):
    """Read Sintel camera data, return (M, N) tuple.

    M is the 3x3 intrinsic matrix, N is the 3x4 extrinsic matrix, so that
    x = M*N*X, with x a point in homogeneous image pixel coordinates and X a
    point in homogeneous world coordinates.

    Raises:
        AssertionError: if the file does not start with the Sintel magic tag.
    """
    TAG_FLOAT = 202021.25  # magic number at the head of every Sintel .cam file
    # Fix: the original leaked the file handle; a context manager guarantees close.
    with open(filename, "rb") as f:
        check = np.fromfile(f, dtype=np.float32, count=1)[0]
        assert (
            check == TAG_FLOAT
        ), " cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? ".format(
            TAG_FLOAT, check
        )
        M = np.fromfile(f, dtype="float64", count=9).reshape((3, 3))
        N = np.fromfile(f, dtype="float64", count=12).reshape((3, 4))
    return M, N
def load_replica_traj(gt_file):
    """Load a Replica ground-truth trajectory from a text file.

    Each row holds a flattened camera-to-world matrix (12 or 16 values; only
    the first 12 — the top 3x4 — are used).

    Returns:
        traj_tum (N, 7): rows of (x, y, z, qw, qx, qy, qz).
        timestamps_mat (N,): frame indices as floats.
    """
    raw = np.loadtxt(gt_file)
    assert raw.shape[1] == 12 or raw.shape[1] == 16
    poses = []
    for row in raw:
        mat = np.eye(4)
        mat[:3, :4] = row[:12].reshape(3, 4)
        poses.append(mat)
    pose_path = PosePath3D(poses_se3=poses)
    timestamps_mat = np.arange(raw.shape[0]).astype(float)
    traj = PoseTrajectory3D(poses_se3=pose_path.poses_se3, timestamps=timestamps_mat)
    # evo keeps quaternions scalar-first (wxyz); kept as-is here.
    # np.roll(..., -1) would convert to scalar-last if ever needed.
    traj_tum = np.column_stack((traj.positions_xyz, traj.orientations_quat_wxyz))
    return (traj_tum, timestamps_mat)
def load_sintel_traj(gt_file):  # './data/sintel/training/camdata_left/alley_2'
    """Load Sintel ground-truth camera poses from a directory of .cam files.

    Follows ParticleSfM's convention: the stored extrinsics are world-to-camera
    and are inverted to camera-to-world; translations are mean-centered.

    Returns:
        tum_gt_poses (N, 7): rows of (x, y, z, qw, qx, qy, qz), centered.
        tt (N, 1): timestamps parsed from the file names.
    """
    cam_files = sorted(
        os.path.join(gt_file, name)
        for name in os.listdir(gt_file)
        if name.endswith(".cam")
    )
    # Timestamp = trailing "_<number>" of the file name, without the ".cam" suffix.
    tstamps = [float(p.split("/")[-1][:-4].split("_")[-1]) for p in cam_files]
    rows = []
    for path in cam_files:
        w2c = sintel_cam_read(path)[1]  # [1] selects the extrinsic matrix
        w2c_h = np.concatenate([w2c, np.array([[0, 0, 0, 1]])], 0)
        c2w = np.linalg.inv(w2c_h)  # world2cam -> cam2world
        xyz = c2w[:3, -1]
        q_xyzw = Rotation.from_matrix(c2w[:3, :3]).as_quat()  # scipy: scalar-last
        q_wxyz = np.roll(q_xyzw, 1)  # reorder to scalar-first
        rows.append(np.concatenate([xyz, q_wxyz], 0))
    tum_gt_poses = np.stack(rows, 0)
    tum_gt_poses[:, :3] = tum_gt_poses[:, :3] - np.mean(
        tum_gt_poses[:, :3], 0, keepdims=True
    )
    tt = np.expand_dims(np.stack(tstamps, 0), -1)
    return tum_gt_poses, tt
def load_iphone_traj(gt_file):
    """Load iPhone ground-truth poses from a directory of per-frame JSON files.

    Each JSON file must hold a 4x4 world-to-camera matrix under the key "w2c".
    Matrices are inverted to camera-to-world and translations mean-centered.

    Returns:
        tum_gt_poses (N, 7): rows of (x, y, z, qw, qx, qy, qz), centered.
        tt (N, 1): frame indices used as timestamps.
    """
    json_files = sorted(
        os.path.join(gt_file, name)
        for name in os.listdir(gt_file)
        if name.endswith(".json")
    )
    rows = []
    for path in json_files:
        with open(path, 'r') as f:
            camera_data = json.load(f)
        w2c = np.array(camera_data['w2c'])
        c2w = np.linalg.inv(w2c)  # world2cam -> cam2world
        xyz = c2w[:3, -1]
        q_xyzw = Rotation.from_matrix(c2w[:3, :3]).as_quat()  # scipy: scalar-last
        q_wxyz = np.roll(q_xyzw, 1)  # reorder to scalar-first
        rows.append(np.concatenate([xyz, q_wxyz], 0))
    tum_gt_poses = np.stack(rows, 0)
    # Center the translations on their mean.
    tum_gt_poses[:, :3] = tum_gt_poses[:, :3] - np.mean(
        tum_gt_poses[:, :3], 0, keepdims=True
    )
    # Use the array index as the timestamp.
    tt = np.expand_dims(np.arange(tum_gt_poses.shape[0]).astype(float), -1)
    return tum_gt_poses, tt
def load_traj(gt_traj_file, traj_format="sintel", skip=0, stride=1, num_frames=None):
    """Read a trajectory in one of several formats; return it TUM-RGBD-style.

    Args:
        gt_traj_file: path to the trajectory file (or directory, for sintel/iphone).
        traj_format: one of "replica", "sintel", "tum", "tartanair", "iphone".
        skip: number of leading frames to drop.
        stride: keep every `stride`-th frame.
        num_frames: optional cap on the number of frames returned.

    Returns:
        traj_tum (N, 7): camera-to-world poses as (x, y, z, qw, qx, qy, qz).
            (Docstring fix: all loaders here emit scalar-FIRST quaternions,
            not the (qx,qy,qz,qw) order the old docstring claimed.)
        timestamps_mat (N,) or (N, 1): per-frame timestamps.

    Raises:
        NotImplementedError: for an unrecognized traj_format.
    """
    if traj_format == "replica":
        traj_tum, timestamps_mat = load_replica_traj(gt_traj_file)
    elif traj_format == "sintel":
        traj_tum, timestamps_mat = load_sintel_traj(gt_traj_file)
    elif traj_format in ["tum", "tartanair"]:
        traj = file_interface.read_tum_trajectory_file(gt_traj_file)
        xyz = traj.positions_xyz
        quat = traj.orientations_quat_wxyz
        timestamps_mat = traj.timestamps
        traj_tum = np.column_stack((xyz, quat))
    elif traj_format in ["iphone"]:
        traj_tum, timestamps_mat = load_iphone_traj(gt_traj_file)
    else:
        raise NotImplementedError

    # Subsample, then optionally truncate.
    traj_tum = traj_tum[skip::stride]
    timestamps_mat = timestamps_mat[skip::stride]
    if num_frames is not None:
        traj_tum = traj_tum[:num_frames]
        timestamps_mat = timestamps_mat[:num_frames]
    return traj_tum, timestamps_mat
def update_timestamps(gt_file, traj_format, skip=0, stride=1):
    """Load the real frame timestamps that accompany a ground-truth pose file.

    The GT pose file path is rewritten to its sibling timestamp file
    (rgb.txt for TUM, times.txt for TartanAir) and the timestamps are read
    from there, subsampled with the same skip/stride as the poses.

    Returns:
        list[float] of timestamps, or None for formats without a separate
        timestamp file (matching the original implicit-None behavior).
    """
    if traj_format == "tum":
        traj_t_map_file = gt_file.replace("groundtruth.txt", "rgb.txt")
    elif traj_format == "tartanair":
        traj_t_map_file = gt_file.replace("gt_pose.txt", "times.txt")
    else:
        return None
    timestamps = load_timestamps(traj_t_map_file, traj_format)
    return timestamps[skip::stride]
def load_timestamps(time_file, traj_format="replica"):
    """Read per-line timestamps (first whitespace-separated token) from a file.

    Only "tum" and "tartanair" formats carry a timestamp file; any other
    format returns None (matching the original implicit behavior). Lines
    starting with '#' are treated as comments and skipped.

    Returns:
        list[float] of timestamps, or None.
    """
    if traj_format not in ("tum", "tartanair"):
        return None
    # Fix: open read-only ("r"), not "r+" — the file is never written, and
    # "r+" fails on read-only mounts.
    with open(time_file, "r") as f:
        return [
            float(line.split(" ")[0]) for line in f if not line.startswith("#")
        ]
def make_traj(args) -> PoseTrajectory3D:
    """Coerce input into an evo PoseTrajectory3D.

    Accepts either a (poses (N,7), timestamps) tuple/list — with poses in
    (x, y, z, qw, qx, qy, qz) order — or an existing PoseTrajectory3D,
    which is deep-copied so the caller's object is never mutated.
    """
    if isinstance(args, (tuple, list)):
        poses, stamps = args
        return PoseTrajectory3D(
            positions_xyz=poses[:, :3],
            orientations_quat_wxyz=poses[:, 3:],
            timestamps=stamps,
        )
    assert isinstance(args, PoseTrajectory3D), type(args)
    return deepcopy(args)
def eval_metrics(pred_traj, gt_traj=None, seq="", filename="", sample_stride=1):
    """Compute ATE, RPE-translation and RPE-rotation (rmse) with evo.

    Trajectories are Sim(3)-aligned (align + scale correction) before both
    ATE and RPE; RPE is averaged over the frame deltas in `delta_list`.

    Args:
        pred_traj: (poses (N,7), timestamps) pair or PoseTrajectory3D.
        gt_traj: optional ground truth in the same form.
        seq: sequence name written into the report.
        filename: path of the text report that is (over)written.
        sample_stride: subsample both trajectories before evaluation.
            NOTE: mutates the caller's pred_traj list in place (kept for
            backward compatibility).

    Returns:
        (ate, rpe_trans, rpe_rot) rmse values.
    """
    if sample_stride > 1:
        pred_traj[0] = pred_traj[0][::sample_stride]
        pred_traj[1] = pred_traj[1][::sample_stride]
        if gt_traj is not None:
            gt_traj = [gt_traj[0][::sample_stride], gt_traj[1][::sample_stride]]

    pred_traj = make_traj(pred_traj)

    if gt_traj is not None:
        gt_traj = make_traj(gt_traj)
        if pred_traj.timestamps.shape[0] == gt_traj.timestamps.shape[0]:
            # Same length: force identical timestamps so association is 1:1.
            pred_traj.timestamps = gt_traj.timestamps
        else:
            print(pred_traj.timestamps.shape[0], gt_traj.timestamps.shape[0])
        gt_traj, pred_traj = sync.associate_trajectories(gt_traj, pred_traj)

    # ATE (absolute trajectory error, Sim(3)-aligned)
    traj_ref = gt_traj
    traj_est = pred_traj
    ate_result = main_ape.ape(
        traj_ref,
        traj_est,
        est_name="traj",
        pose_relation=PoseRelation.translation_part,
        align=True,
        correct_scale=True,
    )
    ate = ate_result.stats["rmse"]

    # RPE rotation and translation, one pass per frame delta
    delta_list = [1]
    rpe_rots, rpe_transs = [], []
    for delta in delta_list:
        rpe_rots_result = main_rpe.rpe(
            traj_ref,
            traj_est,
            est_name="traj",
            pose_relation=PoseRelation.rotation_angle_deg,
            align=True,
            correct_scale=True,
            delta=delta,
            delta_unit=Unit.frames,
            rel_delta_tol=0.01,
            all_pairs=True,
        )
        rpe_rots.append(rpe_rots_result.stats["rmse"])
    for delta in delta_list:
        rpe_transs_result = main_rpe.rpe(
            traj_ref,
            traj_est,
            est_name="traj",
            pose_relation=PoseRelation.translation_part,
            align=True,
            correct_scale=True,
            delta=delta,
            delta_unit=Unit.frames,
            rel_delta_tol=0.01,
            all_pairs=True,
        )
        rpe_transs.append(rpe_transs_result.stats["rmse"])

    rpe_trans, rpe_rot = np.mean(rpe_transs), np.mean(rpe_rots)
    with open(filename, "w+") as f:
        f.write(f"Seq: {seq} \n\n")
        f.write(f"{ate_result}")
        f.write(f"{rpe_rots_result}")
        f.write(f"{rpe_transs_result}")

    # Fix: the original f-string had lost its placeholder and printed a literal.
    print(f"Save results to {filename}")
    return ate, rpe_trans, rpe_rot
def eval_metrics_first_pose_align_last_pose(
    pred_traj, gt_traj=None, seq="", filename="", figpath="", sample_stride=1
):
    """ATE of the FINAL pose only, after rigidly aligning the FIRST poses.

    The prediction is pre-multiplied by T = gt_first @ inv(pred_first) so both
    trajectories start at the same pose, a figure of the aligned trajectories
    is saved to `figpath`, and the unaligned (no Sim(3)) APE of the last pose
    alone is reported — a drift measure for long sequences.

    Returns:
        ate: rmse of the single last-pose translation error.
    """
    if sample_stride > 1:
        pred_traj[0] = pred_traj[0][::sample_stride]
        pred_traj[1] = pred_traj[1][::sample_stride]
        if gt_traj is not None:
            gt_traj = [gt_traj[0][::sample_stride], gt_traj[1][::sample_stride]]

    pred_traj = make_traj(pred_traj)

    if gt_traj is not None:
        gt_traj = make_traj(gt_traj)
        if pred_traj.timestamps.shape[0] == gt_traj.timestamps.shape[0]:
            pred_traj.timestamps = gt_traj.timestamps
        else:
            print(
                "Different number of poses:",
                pred_traj.timestamps.shape[0],
                gt_traj.timestamps.shape[0],
            )
            gt_traj, pred_traj = sync.associate_trajectories(gt_traj, pred_traj)

    if gt_traj is not None and pred_traj is not None:
        if len(gt_traj.poses_se3) > 0 and len(pred_traj.poses_se3) > 0:
            first_gt_pose = gt_traj.poses_se3[0]
            first_pred_pose = pred_traj.poses_se3[0]
            # T = (first_gt_pose) * inv(first_pred_pose)
            T = first_gt_pose @ np.linalg.inv(first_pred_pose)
            # Apply T to every predicted pose so the trajectories share a start.
            aligned_pred_poses = [T @ pose for pose in pred_traj.poses_se3]
            pred_traj = PoseTrajectory3D(
                poses_se3=aligned_pred_poses,
                timestamps=np.array(pred_traj.timestamps),
            )

    plot_trajectory(
        pred_traj,
        gt_traj,
        title=seq,
        filename=figpath,
        align=False,
        correct_scale=False,
    )

    # Keep only the last pose of each trajectory for the drift measurement.
    if gt_traj is not None and len(gt_traj.poses_se3) > 0:
        gt_traj = PoseTrajectory3D(
            poses_se3=[gt_traj.poses_se3[-1]], timestamps=[gt_traj.timestamps[-1]]
        )
    if pred_traj is not None and len(pred_traj.poses_se3) > 0:
        pred_traj = PoseTrajectory3D(
            poses_se3=[pred_traj.poses_se3[-1]], timestamps=[pred_traj.timestamps[-1]]
        )

    ate_result = main_ape.ape(
        gt_traj,
        pred_traj,
        est_name="traj",
        pose_relation=PoseRelation.translation_part,
        align=False,  # <-- important: drift must not be re-aligned away
        correct_scale=False,  # <-- important
    )
    ate = ate_result.stats["rmse"]

    with open(filename, "w+") as f:
        f.write(f"Seq: {seq}\n\n")
        f.write(f"{ate_result}")

    # Fix: the original f-string had lost its placeholder and printed a literal.
    print(f"Save results to {filename}")
    return ate
def best_plotmode(traj):
    """Choose the evo PlotMode projecting onto the two highest-variance axes."""
    variance_order = np.argsort(np.var(traj.positions_xyz, axis=0))
    # Largest-variance axis first, second-largest second; drop the flattest axis.
    second, first = variance_order[1], variance_order[2]
    return getattr(plot.PlotMode, "xyz"[first] + "xyz"[second])
def plot_trajectory(
    pred_traj, gt_traj=None, title="", filename="", align=True, correct_scale=True
):
    """Plot predicted (and optional GT) trajectories to <filename>_traj_error.png.

    When a GT trajectory is given, the prediction is associated with it and,
    if `align` is set, Umeyama-aligned (scale-corrected when `correct_scale`).
    """
    pred_traj = make_traj(pred_traj)

    if gt_traj is not None:
        gt_traj = make_traj(gt_traj)
        if pred_traj.timestamps.shape[0] == gt_traj.timestamps.shape[0]:
            pred_traj.timestamps = gt_traj.timestamps
        else:
            print("WARNING", pred_traj.timestamps.shape[0], gt_traj.timestamps.shape[0])
        gt_traj, pred_traj = sync.associate_trajectories(gt_traj, pred_traj)
        if align:
            pred_traj.align(gt_traj, correct_scale=correct_scale)

    collection = plot.PlotCollection("PlotCol")
    fig = plt.figure(figsize=(8, 8))
    # Pick the projection plane from the GT when available.
    reference = gt_traj if (gt_traj is not None) else pred_traj
    mode = best_plotmode(reference)
    ax = plot.prepare_axis(fig, mode)
    ax.set_title(title)
    if gt_traj is not None:
        plot.traj(ax, mode, gt_traj, "--", "gray", "Ground Truth")
    plot.traj(ax, mode, pred_traj, "-", "blue", "Predicted")
    collection.add_figure("traj_error", fig)
    collection.export(filename, confirm_overwrite=False)
    plt.close(fig=fig)
    print(f"Saved trajectory to {filename.replace('.png','')}_traj_error.png")
def save_trajectory_tum_format(traj, filename):
    """Write a trajectory as TUM-style text: `t x y z qw qx qy qz` per line.

    NOTE: quaternions are written scalar-first (wxyz), matching this module's
    loaders — standard TUM files are scalar-last; keep that in mind when
    feeding these files to external tools.
    """
    traj = make_traj(traj)
    tostr = lambda a: " ".join(map(str, a))
    with Path(filename).open("w") as f:
        for i in range(traj.num_poses):
            # Fix: dropped the original's `[[0, 1, 2, 3]]` fancy index — it was
            # an identity permutation (a no-op) on the quaternion.
            f.write(
                f"{traj.timestamps[i]} {tostr(traj.positions_xyz[i])} {tostr(traj.orientations_quat_wxyz[i])}\n"
            )
    # Fix: the original f-string had lost its placeholder and printed a literal.
    print(f"Saved trajectory to {filename}")
def extract_metrics(file_path):
    """Parse an evo report file and return its (ate, rpe_trans, rpe_rot) rmse.

    Each metric is located by its section header ("APE w.r.t. ...") and the
    first `rmse <value>` row that follows it. Missing sections yield 0.0.
    """
    with open(file_path, "r") as file:
        content = file.read()

    def _rmse_after(header):
        # re.escape fixes the original's unescaped dots/parens, which could
        # match arbitrary characters; DOTALL lets .*? span the stats table.
        m = re.search(
            re.escape(header) + r".*?rmse\s+([0-9.]+)", content, re.DOTALL
        )
        return float(m.group(1)) if m else 0.0

    ate = _rmse_after("APE w.r.t. translation part (m)")
    rpe_trans = _rmse_after("RPE w.r.t. translation part (m)")
    rpe_rot = _rmse_after("RPE w.r.t. rotation angle in degrees (deg)")
    return ate, rpe_trans, rpe_rot
def process_directory(directory):
    """Collect per-sequence metrics from every *_metric.txt under `directory`.

    Walks the tree recursively, parses each report with extract_metrics, and
    returns a list of (seq_name, ate, rpe_trans, rpe_rot) tuples; files are
    visited in sorted order within each directory.
    """
    results = []
    for root, _, files in os.walk(directory):
        # Fix: dropped the original `if files is not None` — os.walk always
        # yields a (possibly empty) list, so the check was dead code.
        for file in sorted(files):
            if file.endswith("_metric.txt"):
                file_path = os.path.join(root, file)
                seq_name = file.replace("_eval_metric.txt", "")
                ate, rpe_trans, rpe_rot = extract_metrics(file_path)
                results.append((seq_name, ate, rpe_trans, rpe_rot))
    return results
def calculate_averages(results):
    """Average the metric columns of (seq_name, ate, rpe_trans, rpe_rot) rows.

    Returns:
        (avg_ate, avg_rpe_trans, avg_rpe_rot); (0.0, 0.0, 0.0) for no results.
    """
    # Fix: guard the empty case first instead of summing and then checking.
    count = len(results)
    if count == 0:
        return 0.0, 0.0, 0.0
    avg_ate = sum(r[1] for r in results) / count
    avg_rpe_trans = sum(r[2] for r in results) / count
    avg_rpe_rot = sum(r[3] for r in results) / count
    return avg_ate, avg_rpe_trans, avg_rpe_rot
```
## /eval/relpose/launch.py
```py path="/eval/relpose/launch.py"
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import math
import cv2
import numpy as np
import torch
import argparse
from copy import deepcopy
from eval.relpose.metadata import dataset_metadata
from eval.relpose.utils import *
from accelerate import PartialState
from add_ckpt_path import add_path_to_dust3r
from tqdm import tqdm
import time
def get_args_parser():
    """Build the CLI argument parser for relative-pose evaluation."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--weights",
        type=str,
        help="path to the model weights",
        default="",
    )
    parser.add_argument("--device", type=str, default="cuda", help="pytorch device")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="value for outdir",
    )
    # NOTE(review): type=bool is an argparse footgun — any non-empty string
    # (including "False") parses as True. Harmless here because __main__
    # overrides args.no_crop, but avoid passing this flag on the CLI.
    parser.add_argument(
        "--no_crop", type=bool, default=True, help="whether to crop input data"
    )
    parser.add_argument(
        "--eval_dataset",
        type=str,
        default="sintel",
        choices=list(dataset_metadata.keys()),
    )
    # Fix: default was the string "224"; use the int directly (type=int only
    # guarantees conversion for string defaults as a special case).
    parser.add_argument("--size", type=int, default=224)
    parser.add_argument(
        "--model_update_type",
        type=str,
        default="cut3r",
        help="model type for state update strategy: cut3r or ttt3r",
    )
    parser.add_argument(
        "--pose_eval_stride", default=1, type=int, help="stride for pose evaluation"
    )
    parser.add_argument("--shuffle", action="store_true", default=False)
    parser.add_argument(
        "--full_seq",
        action="store_true",
        default=False,
        help="use full sequence for pose evaluation",
    )
    parser.add_argument(
        "--seq_list",
        nargs="+",
        default=None,
        help="list of sequences for pose evaluation",
    )
    # Feed the sequence 'revisit' times; optionally freeze the state after pass 1.
    parser.add_argument("--revisit", type=int, default=1)
    parser.add_argument("--freeze_state", action="store_true", default=False)
    parser.add_argument("--solve_pose", action="store_true", default=False)
    return parser
def eval_pose_estimation(args, model, save_dir=None):
    """Look up the dataset's paths and delegate to eval_pose_estimation_dist.

    Returns:
        (ate_mean, rpe_trans_mean, rpe_rot_mean) averaged over all sequences.
    """
    meta = dataset_metadata.get(args.eval_dataset)
    return eval_pose_estimation_dist(
        args,
        model,
        save_dir=save_dir,
        img_path=meta["img_path"],
        mask_path=meta["mask_path"],
    )
def eval_pose_estimation_dist(args, model, img_path, save_dir=None, mask_path=None):
    """Run pose estimation over all sequences, split across accelerate processes.

    For each sequence: loads frames, runs the recurrent model, saves predicted
    trajectory/focals/intrinsics, evaluates ATE/RPE against ground truth when
    available, and appends per-sequence numbers to a per-process log. Finally
    the main process merges the logs and writes dataset-wide averages.

    Returns:
        (avg_ate, avg_rpe_trans, avg_rpe_rot) aggregated from the report files.
    """
    # Deferred import: dust3r only becomes importable after add_path_to_dust3r().
    from dust3r.inference import inference, inference_recurrent, inference_recurrent_lighter

    metadata = dataset_metadata.get(args.eval_dataset)
    anno_path = metadata.get("anno_path", None)

    seq_list = args.seq_list
    if seq_list is None:
        if metadata.get("full_seq", False):
            args.full_seq = True
        else:
            seq_list = metadata.get("seq_list", [])
        if args.full_seq:
            # Every subdirectory of img_path is a sequence.
            seq_list = os.listdir(img_path)
            seq_list = [
                seq for seq in seq_list if os.path.isdir(os.path.join(img_path, seq))
            ]
        seq_list = sorted(seq_list)

    if save_dir is None:
        save_dir = args.output_dir

    distributed_state = PartialState()
    model.to(distributed_state.device)
    device = distributed_state.device

    # Each process handles its shard of the sequence list.
    with distributed_state.split_between_processes(seq_list) as seqs:
        ate_list = []
        rpe_trans_list = []
        rpe_rot_list = []
        load_img_size = args.size
        error_log_path = f"{save_dir}/_error_log_{distributed_state.process_index}.txt"  # Unique log file per process
        bug = False  # set when a sequence has no ground truth to evaluate against
        for seq in tqdm(seqs):
            try:
                dir_path = metadata["dir_path_func"](img_path, seq)

                # Handle skip_condition
                skip_condition = metadata.get("skip_condition", None)
                if skip_condition is not None and skip_condition(save_dir, seq):
                    continue

                mask_path_seq_func = metadata.get(
                    "mask_path_seq_func", lambda mask_path, seq: None
                )
                mask_path_seq = mask_path_seq_func(mask_path, seq)

                filelist = [
                    os.path.join(dir_path, name) for name in os.listdir(dir_path)
                ]
                filelist.sort()
                filelist = filelist[:: args.pose_eval_stride]

                views = prepare_input(
                    filelist,
                    [True for _ in filelist],
                    size=load_img_size,
                    crop=not args.no_crop,
                    revisit=args.revisit,
                    update=not args.freeze_state,
                )
                start = time.time()
                outputs, _ = inference_recurrent_lighter(views, model, device)
                end = time.time()
                fps = len(filelist) / (end - start)
                print(f"Finished pose estimation for {args.eval_dataset} {seq: <16}, FPS: {fps:.2f}")
                (
                    colors,
                    pts3ds_self,
                    pts3ds_other,
                    conf_self,
                    conf_other,
                    cam_dict,
                    pr_poses,
                ) = prepare_output(
                    outputs, revisit=args.revisit, solve_pose=args.solve_pose
                )
                pred_traj = get_tum_poses(pr_poses)

                # Persist per-sequence predictions.
                os.makedirs(f"{save_dir}/{seq}", exist_ok=True)
                save_tum_poses(pr_poses, f"{save_dir}/{seq}/pred_traj.txt")
                save_focals(cam_dict, f"{save_dir}/{seq}/pred_focal.txt")
                save_intrinsics(cam_dict, f"{save_dir}/{seq}/pred_intrinsics.txt")
                # save_depth_maps(pts3ds_self,f'{save_dir}/{seq}', conf_self=conf_self)
                # save_conf_maps(conf_self,f'{save_dir}/{seq}')
                # save_rgb_imgs(colors,f'{save_dir}/{seq}')

                # Load ground truth, if this dataset provides one.
                gt_traj_file = metadata["gt_traj_func"](img_path, anno_path, seq)
                traj_format = metadata.get("traj_format", None)
                if args.eval_dataset == "sintel":
                    gt_traj = load_traj(
                        gt_traj_file=gt_traj_file, stride=args.pose_eval_stride
                    )
                elif traj_format is not None:
                    gt_traj = load_traj(
                        gt_traj_file=gt_traj_file,
                        traj_format=traj_format,
                        stride=args.pose_eval_stride,
                    )
                else:
                    gt_traj = None

                if gt_traj is not None:
                    ate, rpe_trans, rpe_rot = eval_metrics(
                        pred_traj,
                        gt_traj,
                        seq=seq,
                        filename=f"{save_dir}/{seq}_eval_metric.txt",
                    )
                    plot_trajectory(
                        pred_traj, gt_traj, title=seq, filename=f"{save_dir}/{seq}.png"
                    )
                else:
                    # No GT: record zeros so downstream averaging still works.
                    ate, rpe_trans, rpe_rot = 0, 0, 0
                    bug = True

                ate_list.append(ate)
                rpe_trans_list.append(rpe_trans)
                rpe_rot_list.append(rpe_rot)

                # Write to error log after each sequence
                with open(error_log_path, "a") as f:
                    f.write(
                        f"{args.eval_dataset}-{seq: <16} | ATE: {ate:.5f}, RPE trans: {rpe_trans:.5f}, RPE rot: {rpe_rot:.5f}\n"
                    )
                    f.write(f"{ate:.5f}\n")
                    f.write(f"{rpe_trans:.5f}\n")
                    f.write(f"{rpe_rot:.5f}\n")

            except Exception as e:
                if "out of memory" in str(e):
                    # Handle OOM
                    torch.cuda.empty_cache()  # Clear the CUDA memory
                    with open(error_log_path, "a") as f:
                        f.write(
                            f"OOM error in sequence {seq}, skipping this sequence.\n"
                        )
                    print(f"OOM error in sequence {seq}, skipping...")
                elif "Degenerate covariance rank" in str(
                    e
                ) or "Eigenvalues did not converge" in str(e):
                    # Handle Degenerate covariance rank exception and Eigenvalues did not converge exception
                    with open(error_log_path, "a") as f:
                        f.write(f"Exception in sequence {seq}: {str(e)}\n")
                    print(f"Traj evaluation error in sequence {seq}, skipping.")
                else:
                    raise e  # Rethrow if it's not an expected exception

    distributed_state.wait_for_everyone()

    # Aggregate from the written report files (covers all processes' work).
    results = process_directory(save_dir)
    avg_ate, avg_rpe_trans, avg_rpe_rot = calculate_averages(results)

    # Write the averages to the error log (only on the main process)
    if distributed_state.is_main_process:
        with open(f"{save_dir}/_error_log.txt", "a") as f:
            # Copy the error log from each process to the main error log
            for i in range(distributed_state.num_processes):
                if not os.path.exists(f"{save_dir}/_error_log_{i}.txt"):
                    break
                with open(f"{save_dir}/_error_log_{i}.txt", "r") as f_sub:
                    f.write(f_sub.read())
            f.write(
                f"Average ATE: {avg_ate:.5f}, Average RPE trans: {avg_rpe_trans:.5f}, Average RPE rot: {avg_rpe_rot:.5f}\n"
            )
    return avg_ate, avg_rpe_trans, avg_rpe_rot
if __name__ == "__main__":
    args = get_args_parser()
    args = args.parse_args()

    # Must run before any dust3r import: it puts the checkpoint's bundled
    # dust3r source tree on sys.path.
    add_path_to_dust3r(args.weights)
    from dust3r.utils.image import load_images_for_eval as load_images
    from dust3r.post_process import estimate_focal_knowing_depth
    from dust3r.model import ARCroco3DStereo
    from dust3r.utils.camera import pose_encoding_to_camera
    from dust3r.utils.geometry import weighted_procrustes, geotrf

    # Hard overrides for this benchmark: metadata-selected sequences, with cropping.
    args.full_seq = False
    args.no_crop = False

    def recover_cam_params(pts3ds_self, pts3ds_other, conf_self, conf_other):
        """Recover per-frame camera parameters from predicted point maps.

        Args:
            pts3ds_self: (B, H, W, 3) points in each frame's own camera frame —
                TODO confirm layout against the model's output convention.
            pts3ds_other: (B, H, W, 3) the same points in the reference frame.
            conf_self, conf_other: per-pixel confidence maps for each point map.

        Returns:
            (c2w, focal, pp): poses from confidence-weighted Procrustes between
            the two point maps, Weiszfeld focal estimate, and the principal
            point fixed at the image center.
        """
        B, H, W, _ = pts3ds_self.shape
        # Principal point assumed at the image center.
        pp = (
            torch.tensor([W // 2, H // 2], device=pts3ds_self.device)
            .float()
            .repeat(B, 1)
            .reshape(B, 1, 2)
        )
        focal = estimate_focal_knowing_depth(pts3ds_self, pp, focal_mode="weiszfeld")
        pts3ds_self = pts3ds_self.reshape(B, -1, 3)
        pts3ds_other = pts3ds_other.reshape(B, -1, 3)
        conf_self = conf_self.reshape(B, -1)
        conf_other = conf_other.reshape(B, -1)
        # weighted procrustes; log-confidences act as per-point weights
        c2w = weighted_procrustes(
            pts3ds_self,
            pts3ds_other,
            torch.log(conf_self) * torch.log(conf_other),
            use_weights=True,
            return_T=True,
        )
        return c2w, focal, pp.reshape(B, 2)

    def prepare_input(
        img_paths,
        img_mask,
        size,
        raymaps=None,
        raymap_mask=None,
        revisit=1,
        update=True,
        crop=True,
    ):
        """Build the per-view input dicts consumed by the recurrent model.

        Args:
            img_paths: list of image file paths.
            img_mask: per-view flags marking views that carry an image.
            size: target load size for the images.
            raymaps / raymap_mask: optional ray-map views interleaved with images.
            revisit: feed the whole sequence this many times.
            update: when False, revisit passes after the first do not update
                the model state (views get update=False).
            crop: forwarded to the image loader.

        Returns:
            list of view dicts (img, ray_map, true_shape, masks, flags, ...).
        """
        images = load_images(img_paths, size=size, crop=crop, verbose=False)
        views = []
        if raymaps is None and raymap_mask is None:
            # Image-only input: NaN ray maps act as "absent" placeholders.
            num_views = len(images)
            for i in range(num_views):
                view = {
                    "img": images[i]["img"],
                    "ray_map": torch.full(
                        (
                            images[i]["img"].shape[0],
                            6,
                            images[i]["img"].shape[-2],
                            images[i]["img"].shape[-1],
                        ),
                        torch.nan,
                    ),
                    "true_shape": torch.from_numpy(images[i]["true_shape"]),
                    "idx": i,
                    "instance": str(i),
                    "camera_pose": torch.from_numpy(
                        np.eye(4).astype(np.float32)
                    ).unsqueeze(0),
                    "img_mask": torch.tensor(True).unsqueeze(0),
                    "ray_mask": torch.tensor(False).unsqueeze(0),
                    "update": torch.tensor(True).unsqueeze(0),
                    "reset": torch.tensor(False).unsqueeze(0),
                }
                views.append(view)
        else:
            # Mixed input: img_mask / raymap_mask select each view's modality.
            num_views = len(images) + len(raymaps)
            assert len(img_mask) == len(raymap_mask) == num_views
            assert sum(img_mask) == len(images) and sum(raymap_mask) == len(raymaps)
            j = 0  # cursor into images
            k = 0  # cursor into raymaps
            for i in range(num_views):
                view = {
                    "img": (
                        images[j]["img"]
                        if img_mask[i]
                        else torch.full_like(images[0]["img"], torch.nan)
                    ),
                    "ray_map": (
                        raymaps[k]
                        if raymap_mask[i]
                        else torch.full_like(raymaps[0], torch.nan)
                    ),
                    "true_shape": (
                        torch.from_numpy(images[j]["true_shape"])
                        if img_mask[i]
                        else torch.from_numpy(np.int32([raymaps[k].shape[1:-1][::-1]]))
                    ),
                    "idx": i,
                    "instance": str(i),
                    "camera_pose": torch.from_numpy(
                        np.eye(4).astype(np.float32)
                    ).unsqueeze(0),
                    "img_mask": torch.tensor(img_mask[i]).unsqueeze(0),
                    "ray_mask": torch.tensor(raymap_mask[i]).unsqueeze(0),
                    "update": torch.tensor(img_mask[i]).unsqueeze(0),
                    "reset": torch.tensor(False).unsqueeze(0),
                }
                if img_mask[i]:
                    j += 1
                if raymap_mask[i]:
                    k += 1
                views.append(view)
            assert j == len(images) and k == len(raymaps)

        if revisit > 1:
            # repeat input for 'revisit' times
            new_views = []
            for r in range(revisit):
                for i in range(len(views)):
                    new_view = deepcopy(views[i])
                    new_view["idx"] = r * len(views) + i
                    new_view["instance"] = str(r * len(views) + i)
                    if r > 0:
                        if not update:
                            # Later passes read the state but do not modify it.
                            new_view["update"] = torch.tensor(False).unsqueeze(0)
                    new_views.append(new_view)
            return new_views
        return views

    def prepare_output(outputs, revisit=1, solve_pose=False):
        """Extract point maps, confidences, poses and intrinsics from raw outputs.

        When revisit > 1, only the last pass over the sequence is kept.
        With solve_pose=True, poses are re-estimated from the two point maps
        via weighted Procrustes; otherwise the pose head's encoding is decoded
        directly and the focal is estimated from the self-view points.

        Returns:
            (colors, pts3ds_self, pts3ds_other, conf_self, conf_other,
             cam_dict, pr_poses).
        """
        valid_length = len(outputs["pred"]) // revisit
        outputs["pred"] = outputs["pred"][-valid_length:]
        outputs["views"] = outputs["views"][-valid_length:]
        if solve_pose:
            pts3ds_self = [
                output["pts3d_in_self_view"].cpu() for output in outputs["pred"]
            ]
            pts3ds_other = [
                output["pts3d_in_other_view"].cpu() for output in outputs["pred"]
            ]
            conf_self = [output["conf_self"].cpu() for output in outputs["pred"]]
            conf_other = [output["conf"].cpu() for output in outputs["pred"]]
            pr_poses, focal, pp = recover_cam_params(
                torch.cat(pts3ds_self, 0),
                torch.cat(pts3ds_other, 0),
                torch.cat(conf_self, 0),
                torch.cat(conf_other, 0),
            )
            pts3ds_self = torch.cat(pts3ds_self, 0)
        else:
            pts3ds_self = [
                output["pts3d_in_self_view"].cpu() for output in outputs["pred"]
            ]
            pts3ds_other = [
                output["pts3d_in_other_view"].cpu() for output in outputs["pred"]
            ]
            conf_self = [output["conf_self"].cpu() for output in outputs["pred"]]
            conf_other = [output["conf"].cpu() for output in outputs["pred"]]
            pts3ds_self = torch.cat(pts3ds_self, 0)
            pr_poses = [
                pose_encoding_to_camera(pred["camera_pose"].clone()).cpu()
                for pred in outputs["pred"]
            ]
            pr_poses = torch.cat(pr_poses, 0)
            B, H, W, _ = pts3ds_self.shape
            # Principal point assumed at the image center.
            pp = (
                torch.tensor([W // 2, H // 2], device=pts3ds_self.device)
                .float()
                .repeat(B, 1)
                .reshape(B, 2)
            )
            focal = estimate_focal_knowing_depth(
                pts3ds_self, pp, focal_mode="weiszfeld"
            )
        # Predicted RGB is in [-1, 1]; map back to [0, 1].
        colors = [0.5 * (output["rgb"][0] + 1.0) for output in outputs["pred"]]
        cam_dict = {
            "focal": focal.cpu().numpy(),
            "pp": pp.cpu().numpy(),
        }
        return (
            colors,
            pts3ds_self,
            pts3ds_other,
            conf_self,
            conf_other,
            cam_dict,
            pr_poses,
        )

    model = ARCroco3DStereo.from_pretrained(args.weights)
    # set model type
    model.config.model_update_type = args.model_update_type
    eval_pose_estimation(args, model, save_dir=args.output_dir)
```
## /eval/relpose/metadata.py
```py path="/eval/relpose/metadata.py"
import os
import glob
from tqdm import tqdm
# Define the merged dataset metadata dictionary
# Define the merged dataset metadata dictionary.
# Schema of each entry:
#   img_path:           root directory of input frames
#   mask_path:          optional root of per-sequence masks (or None)
#   dir_path_func:      (img_path, seq) -> directory of the sequence's frames
#   gt_traj_func:       (img_path, anno_path, seq) -> GT trajectory path (or None)
#   traj_format:        format tag consumed by load_traj (None = no GT eval)
#   seq_list:           explicit sequence names, or None
#   full_seq:           when True, every subdirectory of img_path is a sequence
#   mask_path_seq_func: (mask_path, seq) -> per-sequence mask dir (or None)
#   skip_condition:     optional (save_dir, seq) -> bool to skip finished work
#   process_func:       generator used by the mono-depth eval path (or None)
dataset_metadata = {
    "davis": {
        "img_path": "data/davis/DAVIS/JPEGImages/480p",
        "mask_path": "data/davis/DAVIS/masked_images/480p",
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        "gt_traj_func": lambda img_path, anno_path, seq: None,
        "traj_format": None,
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: os.path.join(mask_path, seq),
        "skip_condition": None,
        "process_func": None,  # Not used in mono depth estimation
    },
    "kitti": {
        "img_path": "data/kitti/depth_selection/val_selection_cropped/image_gathered",  # Default path
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        "gt_traj_func": lambda img_path, anno_path, seq: None,
        "traj_format": None,
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": lambda args, img_path: process_kitti(args, img_path),
    },
    "bonn": {
        "img_path": "data/bonn/rgbd_bonn_dataset",
        "mask_path": None,
        # "_110" suffix: the 110-frame subsampled variant of each sequence.
        "dir_path_func": lambda img_path, seq: os.path.join(
            img_path, f"rgbd_bonn_{seq}", "rgb_110"
        ),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, f"rgbd_bonn_{seq}", "groundtruth_110.txt"
        ),
        "traj_format": "tum",
        "seq_list": ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"],
        "full_seq": False,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": lambda args, img_path: process_bonn(args, img_path),
    },
    # NYU has no trajectories; only the depth-eval process_func applies.
    "nyu": {
        "img_path": "data/nyu-v2/val/nyu_images",
        "mask_path": None,
        "process_func": lambda args, img_path: process_nyu(args, img_path),
    },
    "scannet": {
        "img_path": "data/scannetv2",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, seq, "pose_90.txt"
        ),
        "traj_format": "replica",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,  # lambda save_dir, seq: os.path.exists(os.path.join(save_dir, seq)),
        "process_func": lambda args, img_path: process_scannet(args, img_path),
    },
    # scannet-<N>: fixed-length (N-frame) ScanNet variants sharing the layout above.
    "scannet-257": {
        "img_path": "data/scannetv2_3_257",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, seq, "pose_90.txt"
        ),
        "traj_format": "replica",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,  # lambda save_dir, seq: os.path.exists(os.path.join(save_dir, seq)),
        "process_func": lambda args, img_path: process_scannet(args, img_path),
    },
    "scannet-129": {
        "img_path": "data/scannetv2_3_129",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, seq, "pose_90.txt"
        ),
        "traj_format": "replica",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,  # lambda save_dir, seq: os.path.exists(os.path.join(save_dir, seq)),
        "process_func": lambda args, img_path: process_scannet(args, img_path),
    },
    "scannet-65": {
        "img_path": "data/scannetv2_3_65",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, seq, "pose_90.txt"
        ),
        "traj_format": "replica",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,  # lambda save_dir, seq: os.path.exists(os.path.join(save_dir, seq)),
        "process_func": lambda args, img_path: process_scannet(args, img_path),
    },
    "scannet-33": {
        "img_path": "data/scannetv2_3_33",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, seq, "pose_90.txt"
        ),
        "traj_format": "replica",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,  # lambda save_dir, seq: os.path.exists(os.path.join(save_dir, seq)),
        "process_func": lambda args, img_path: process_scannet(args, img_path),
    },
    "tum": {
        "img_path": "data/tum",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "rgb_90"),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, seq, "groundtruth_90.txt"
        ),
        "traj_format": "tum",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    "sintel": {
        "img_path": "data/sintel/training/final",
        "anno_path": "data/sintel/training/camdata_left",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(anno_path, seq),
        "traj_format": None,
        "seq_list": [
            "alley_2",
            "ambush_4",
            "ambush_5",
            "ambush_6",
            "cave_2",
            "cave_4",
            "market_2",
            "market_5",
            "market_6",
            "shaman_3",
            "sleeping_1",
            "sleeping_2",
            "temple_2",
            "temple_3",
        ],
        "full_seq": False,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": lambda args, img_path: process_sintel(args, img_path),
    },
}
# Long-sequence variants: "scannet_s3_<N>" / "tum_s1_<N>" evaluate N-frame
# subsequences; each entry mirrors the base scannet/tum schema above.
scannet_numbers = [50, 90, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 950, 1000]
scannet_configs = {
    f"scannet_s3_{num}": {
        "img_path": "data/long_scannet_s3",
        "mask_path": None,
        # num=num binds the current value at definition time (lambdas in a
        # comprehension would otherwise all capture the last `num`).
        "dir_path_func": lambda img_path, seq, num=num: os.path.join(img_path, seq, f"color_{num}"),
        "gt_traj_func": lambda img_path, anno_path, seq, num=num: os.path.join(
            img_path, seq, f"pose_{num}.txt"
        ),
        "traj_format": "replica",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": lambda args, img_path: process_scannet(args, img_path),
    }
    for num in scannet_numbers
}
# then update dataset_metadata
dataset_metadata.update(scannet_configs)

tum_numbers = [50, 100, 150, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
tum_configs = {
    f"tum_s1_{num}": {
        "img_path": "data/long_tum_s1",
        "mask_path": None,
        # num=num: same late-binding fix as the scannet configs above.
        "dir_path_func": lambda img_path, seq, num=num: os.path.join(img_path, seq, f"rgb_{num}"),
        "gt_traj_func": lambda img_path, anno_path, seq, num=num: os.path.join(
            img_path, seq, f"groundtruth_{num}.txt"
        ),
        "traj_format": "tum",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    }
    for num in tum_numbers
}
dataset_metadata.update(tum_configs)
# Define processing functions for each dataset
def process_kitti(args, img_path):
    """Yield (filelist, save_dir) for each KITTI sequence folder under img_path."""
    # Fix: renamed the loop variable from `dir`, which shadowed the builtin.
    for seq_dir in tqdm(sorted(glob.glob(f"{img_path}/*"))):
        filelist = sorted(glob.glob(f"{seq_dir}/*.png"))
        save_dir = f"{args.output_dir}/{os.path.basename(seq_dir)}"
        yield filelist, save_dir
def process_bonn(args, img_path):
    """Yield (filelist, save_dir) pairs for Bonn RGB-D sequences.

    With args.full_seq, every sequence directory's full rgb/ folder is used;
    otherwise the default (or args.seq_list) subset of 110-frame sequences.
    """
    if args.full_seq:
        # Fix: renamed the loop variable from `dir`, which shadowed the builtin.
        for seq_dir in tqdm(sorted(glob.glob(f"{img_path}/*/"))):
            filelist = sorted(glob.glob(f"{seq_dir}/rgb/*.png"))
            save_dir = f"{args.output_dir}/{os.path.basename(os.path.dirname(seq_dir))}"
            yield filelist, save_dir
    else:
        seq_list = (
            ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"]
            if args.seq_list is None
            else args.seq_list
        )
        for seq in tqdm(seq_list):
            filelist = sorted(glob.glob(f"{img_path}/rgbd_bonn_{seq}/rgb_110/*.png"))
            save_dir = f"{args.output_dir}/{seq}"
            yield filelist, save_dir
def process_nyu(args, img_path):
    """Yield a single (filelist, save_dir) pair covering every NYU frame."""
    frames = sorted(glob.glob(f"{img_path}/*.png"))
    yield frames, f"{args.output_dir}"
def process_scannet(args, img_path):
    """Yield (filelist, save_dir) for each ScanNet sequence (frames in color_90/)."""
    for seq_dir in tqdm(sorted(glob.glob(f"{img_path}/*"))):
        frames = sorted(glob.glob(f"{seq_dir}/color_90/*.jpg"))
        yield frames, f"{args.output_dir}/{os.path.basename(seq_dir)}"
def process_sintel(args, img_path):
    """Yield (filelist, save_dir) per Sintel sequence.

    full_seq scans every sub-directory; otherwise the fixed 14-sequence
    benchmark subset is used.
    """
    if args.full_seq:
        for seq_dir in tqdm(sorted(glob.glob(f"{img_path}/*/"))):
            seq_name = os.path.basename(os.path.dirname(seq_dir))
            yield sorted(glob.glob(f"{seq_dir}/*.png")), f"{args.output_dir}/{seq_name}"
    else:
        benchmark_seqs = [
            "alley_2",
            "ambush_4",
            "ambush_5",
            "ambush_6",
            "cave_2",
            "cave_4",
            "market_2",
            "market_5",
            "market_6",
            "shaman_3",
            "sleeping_1",
            "sleeping_2",
            "temple_2",
            "temple_3",
        ]
        for seq in tqdm(benchmark_seqs):
            yield sorted(glob.glob(f"{img_path}/{seq}/*.png")), f"{args.output_dir}/{seq}"
```
## /eval/relpose/run_scannet.sh
```sh path="/eval/relpose/run_scannet.sh"
#!/bin/bash
# Relative-pose evaluation on the long-ScanNet (s3) stride variants.
# Launches eval/relpose/launch.py once per (model, dataset) pair on 2 GPUs.
set -e
workdir='.'
model_names=('ttt3r') # ttt3r cut3r
ckpt_name='cut3r_512_dpt_4_64'
model_weights="${workdir}/src/${ckpt_name}.pth"
# Full stride sweep (uncomment to run every variant):
# datasets=('scannet_s3_50' 'scannet_s3_90' 'scannet_s3_100' 'scannet_s3_150' 'scannet_s3_200' 'scannet_s3_300' 'scannet_s3_400' 'scannet_s3_500'
# 'scannet_s3_600' 'scannet_s3_700' 'scannet_s3_800' 'scannet_s3_900' 'scannet_s3_1000')
datasets=('scannet_s3_1000')
for model_name in "${model_names[@]}"; do
    for data in "${datasets[@]}"; do
        # Results land under eval_results/relpose/<dataset>/<model>.
        output_dir="${workdir}/eval_results/relpose/${data}/${model_name}"
        echo "$output_dir"
        accelerate launch --num_processes 2 --main_process_port 29550 eval/relpose/launch.py \
            --weights "$model_weights" \
            --output_dir "$output_dir" \
            --eval_dataset "$data" \
            --size 512 \
            --model_update_type "$model_name"
    done
done
```
## /eval/relpose/run_sintel.sh
```sh path="/eval/relpose/run_sintel.sh"
#!/bin/bash
# Relative-pose evaluation on Sintel.
set -e

workdir='.'
ckpt_name='cut3r_512_dpt_4_64'
model_weights="${workdir}/src/${ckpt_name}.pth"
model_names=('ttt3r') # ttt3r cut3r
datasets=('sintel')

for model_name in "${model_names[@]}"; do
  for dataset in "${datasets[@]}"; do
    output_dir="${workdir}/eval_results/s1/relpose/${dataset}/${model_name}"
    echo "$output_dir"
    accelerate launch --num_processes 2 --main_process_port 29550 eval/relpose/launch.py \
      --weights "$model_weights" \
      --output_dir "$output_dir" \
      --eval_dataset "$dataset" \
      --size 512 \
      --model_update_type "$model_name"
  done
done
```
## /eval/relpose/run_tum.sh
```sh path="/eval/relpose/run_tum.sh"
#!/bin/bash
# Relative-pose evaluation on the long-TUM (s1) stride variants.
# Uses a distinct main_process_port (29551) so it can run next to the ScanNet job.
set -e
workdir='.'
model_names=('ttt3r') # ttt3r cut3r
ckpt_name='cut3r_512_dpt_4_64'
model_weights="${workdir}/src/${ckpt_name}.pth"
# Full stride sweep (uncomment to run every variant):
# datasets=('tum_s1_50' 'tum_s1_100' 'tum_s1_150' 'tum_s1_200' 'tum_s1_300' 'tum_s1_400' 'tum_s1_500' 'tum_s1_600' 'tum_s1_700' 'tum_s1_800' 'tum_s1_900' 'tum_s1_1000')
datasets=('tum_s1_1000')
for model_name in "${model_names[@]}"; do
    for data in "${datasets[@]}"; do
        output_dir="${workdir}/eval_results/relpose/${data}/${model_name}"
        echo "$output_dir"
        accelerate launch --num_processes 2 --main_process_port 29551 eval/relpose/launch.py \
            --weights "$model_weights" \
            --output_dir "$output_dir" \
            --eval_dataset "$data" \
            --size 512 \
            --model_update_type "$model_name"
    done
done
```
## /eval/relpose/utils.py
```py path="/eval/relpose/utils.py"
from copy import deepcopy
import cv2
import numpy as np
import torch
import torch.nn as nn
import roma
from copy import deepcopy
import tqdm
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from scipy.spatial.transform import Rotation
from eval.relpose.evo_utils import *
from PIL import Image
import imageio.v2 as iio
from matplotlib.figure import Figure
# from checkpoints.dust3r.viz import colorize_np, colorize
def todevice(batch, device, callback=None, non_blocking=False):
    """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    batch: list, tuple, dict of tensors or other things
    device: pytorch device or 'numpy'
    callback: function that would be called on every sub-elements.

    Fix: the recursive calls previously dropped `callback` and `non_blocking`,
    so the callback only ran on the top-level object and non-blocking copies
    were silently disabled for nested tensors; both are now propagated.
    """
    if callback:
        batch = callback(batch)
    if isinstance(batch, dict):
        return {
            k: todevice(v, device, callback, non_blocking) for k, v in batch.items()
        }
    if isinstance(batch, (tuple, list)):
        return type(batch)(todevice(x, device, callback, non_blocking) for x in batch)
    x = batch
    if device == "numpy":
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x
# Backwards-compatible alias kept for callers that import `to_device`.
to_device = todevice  # alias


def to_numpy(x):
    """Recursively convert any tensors inside `x` (possibly nested) to numpy arrays."""
    return todevice(x, "numpy")
def c2w_to_tumpose(c2w):
    """Convert a 4x4 camera-to-world matrix into a TUM-style 7-vector.

    input: c2w: 4x4 matrix
    output: tuple of translation and rotation (x y z qw qx qy qz)
    """
    c2w = to_numpy(c2w)
    translation = c2w[:3, -1]
    # scipy yields scalar-last (qx, qy, qz, qw); this pipeline stores qw first.
    qx, qy, qz, qw = Rotation.from_matrix(c2w[:3, :3]).as_quat()
    return np.concatenate([translation, [qw, qx, qy, qz]])
def get_tum_poses(poses):
    """Convert a list of 4x4 c2w matrices to [N x 7 TUM poses, N timestamps].

    Timestamps are simply the frame indices as floats.
    """
    timestamps = np.arange(len(poses)).astype(float)
    tum_stack = np.stack([c2w_to_tumpose(p) for p in poses], axis=0)
    return [tum_stack, timestamps]
def save_tum_poses(poses, path):
    """Write `poses` to `path` in TUM trajectory format; return the [N,7] pose array."""
    tum_traj = get_tum_poses(poses)
    save_trajectory_tum_format(tum_traj, path)
    return tum_traj[0]
def save_focals(cam_dict, path):
    """Dump the per-frame focal lengths from `cam_dict` to `path`, one per line."""
    focal_values = cam_dict["focal"]
    np.savetxt(path, focal_values, fmt="%.6f")
    return focal_values
def save_intrinsics(cam_dict, path):
    """Assemble per-frame 3x3 pinhole intrinsics and write them flattened (9 per row).

    Uses a single focal for fx and fy and the per-frame principal point `pp`.
    Returns the [N, 3, 3] intrinsics array.
    """
    n_frames = len(cam_dict["focal"])
    K = np.eye(3)[None].repeat(n_frames, axis=0)
    K[:, 0, 0] = cam_dict["focal"]
    K[:, 1, 1] = cam_dict["focal"]
    K[:, :2, 2] = cam_dict["pp"]
    np.savetxt(path, K.reshape(-1, 9), fmt="%.6f")
    return K
def save_conf_maps(conf, path):
    """Save each per-frame confidence tensor as `{path}/conf_{i}.npy`; return the input."""
    for idx, conf_map in enumerate(conf):
        np.save(f"{path}/conf_{idx}.npy", conf_map.detach().cpu().numpy())
    return conf
def save_rgb_imgs(colors, path):
    """Write each [0, 1] RGB frame as an 8-bit PNG `frame_%04d.png`; return the input."""
    for idx, frame in enumerate(colors):
        frame_u8 = (frame.cpu().numpy() * 255).astype(np.uint8)
        iio.imwrite(f"{path}/frame_{idx:04d}.png", frame_u8)
    return colors
def save_depth_maps(pts3ds_self, path, conf_self=None):
    """Colorize and save per-frame depth maps plus raw .npy dumps and a GIF.

    pts3ds_self: sequence of self-view point maps; depth is the last channel.
    path: existing output directory.
    conf_self: optional per-frame confidence maps, shown log-scaled beside depth.
    Returns the stacked depth tensor.
    """
    depth_maps = torch.stack([pts3d_self[..., -1] for pts3d_self in pts3ds_self], 0)
    # One shared color range across the whole sequence for temporal consistency.
    min_depth = depth_maps.min()  # float(torch.quantile(out, 0.01))
    max_depth = depth_maps.max()  # float(torch.quantile(out, 0.99))
    colored_depth = colorize(
        depth_maps,
        cmap_name="Spectral_r",
        range=(min_depth, max_depth),
        append_cbar=True,
    )
    images = []
    if conf_self is not None:
        conf_selfs = torch.concat(conf_self, 0)
        # Confidence is visualized in log space to compress its dynamic range.
        min_conf = torch.log(conf_selfs.min())  # float(torch.quantile(out, 0.01))
        max_conf = torch.log(conf_selfs.max())  # float(torch.quantile(out, 0.99))
        colored_conf = colorize(
            torch.log(conf_selfs),
            cmap_name="jet",
            range=(min_conf, max_conf),
            append_cbar=True,
        )
    for i, depth_map in enumerate(colored_depth):
        # Apply color map to depth map
        img_path = f"{path}/depth_frame_{(i):04d}.png"
        if conf_self is None:
            to_save = (depth_map * 255).detach().cpu().numpy().astype(np.uint8)
        else:
            # Depth and log-confidence panels concatenated along image height/rows.
            to_save = torch.cat([depth_map, colored_conf[i]], dim=1)
            to_save = (to_save * 255).detach().cpu().numpy().astype(np.uint8)
        iio.imwrite(img_path, to_save)
        # Re-open the just-written PNG so PIL can assemble the GIF below.
        images.append(Image.open(img_path))
        # Raw (un-colorized) depth values for quantitative evaluation.
        np.save(f"{path}/frame_{(i):04d}.npy", depth_maps[i].detach().cpu().numpy())
    # NOTE(review): raises IndexError when pts3ds_self is empty — confirm callers
    # always pass at least one frame.
    images[0].save(
        f"{path}/_depth_maps.gif",
        save_all=True,
        append_images=images[1:],
        duration=100,
        loop=0,
    )
    return depth_maps
def get_vertical_colorbar(h, vmin, vmax, cmap_name="jet", label=None, cbar_precision=2):
    """Render a standalone vertical colorbar as an RGB float image of height `h`.

    :param h: target height in pixels (the rasterized figure is resized to this)
    :param vmin: min value
    :param vmax: max value
    :param cmap_name: matplotlib colormap name
    :param label: optional text label for the colorbar
    :param cbar_precision: decimals for tick labels (0 strips the trailing ".0")
    :return: float32 RGB image in [0, 1], shape (h, w, 3)
    """
    fig = Figure(figsize=(2, 8), dpi=100)
    fig.subplots_adjust(right=1.5)
    canvas = FigureCanvasAgg(fig)
    # Do some plotting.
    ax = fig.add_subplot(111)
    # NOTE(review): cm.get_cmap is deprecated in newer matplotlib releases —
    # verify the pinned version still provides it.
    cmap = cm.get_cmap(cmap_name)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    tick_cnt = 6
    tick_loc = np.linspace(vmin, vmax, tick_cnt)
    cb1 = mpl.colorbar.ColorbarBase(
        ax, cmap=cmap, norm=norm, ticks=tick_loc, orientation="vertical"
    )
    tick_label = [str(np.round(x, cbar_precision)) for x in tick_loc]
    if cbar_precision == 0:
        # Drop the ".0" suffix that np.round leaves at precision 0.
        tick_label = [x[:-2] for x in tick_label]
    cb1.set_ticklabels(tick_label)
    cb1.ax.tick_params(labelsize=18, rotation=0)
    if label is not None:
        cb1.set_label(label)
    # fig.tight_layout()
    canvas.draw()
    # Rasterize to an RGBA byte buffer, keep RGB only, scale to [0, 1].
    s, (width, height) = canvas.print_to_buffer()
    im = np.frombuffer(s, np.uint8).reshape((height, width, 4))
    im = im[:, :, :3].astype(np.float32) / 255.0
    # Resize to the requested height while preserving aspect ratio.
    if h != im.shape[0]:
        w = int(im.shape[1] / im.shape[0] * h)
        im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA)
    return im
def colorize_np(
    x,
    cmap_name="jet",
    mask=None,
    range=None,
    append_cbar=False,
    cbar_in_image=False,
    cbar_precision=2,
):
    """
    turn a grayscale image into a color image
    :param x: input grayscale, [H, W]; NOTE: mutated in place when `mask` is given
    :param cmap_name: the colorization method
    :param mask: the mask image, [H, W]
    :param range: the range for scaling, automatic if None, [min, max]
    :param append_cbar: if append the color bar
    :param cbar_in_image: put the color bar inside the image to keep the output image the same size as the input image
    :param cbar_precision: decimals for colorbar tick labels
    :return: colorized image, [H, W, 3] (wider if a colorbar is appended)

    Fix: the colorbar was rendered unconditionally — a whole matplotlib figure
    per call even with append_cbar=False — and the trailing if/else returned
    the same value on both branches. The colorbar is now built only when used.
    """
    if range is not None:
        vmin, vmax = range
    elif mask is not None:
        # Lower bound: smallest non-zero value inside the mask.
        vmin = np.min(x[mask][np.nonzero(x[mask])])
        vmax = np.max(x[mask])
        # Pixels outside the mask are clamped to the lower bound.
        x[np.logical_not(mask)] = vmin
    else:
        vmin, vmax = np.percentile(x, (1, 100))
        vmax += 1e-6
    x = np.clip(x, vmin, vmax)
    x = (x - vmin) / (vmax - vmin)
    cmap = cm.get_cmap(cmap_name)
    x_new = cmap(x)[:, :, :3]
    if mask is not None:
        # Blend un-masked pixels toward white.
        mask = np.float32(mask[:, :, np.newaxis])
        x_new = x_new * mask + np.ones_like(x_new) * (1.0 - mask)
    if not append_cbar:
        return x_new
    cbar = get_vertical_colorbar(
        h=x.shape[0],
        vmin=vmin,
        vmax=vmax,
        cmap_name=cmap_name,
        cbar_precision=cbar_precision,
    )
    if cbar_in_image:
        # Overwrite the right edge so the output keeps the input width.
        x_new[:, -cbar.shape[1] :, :] = cbar
        return x_new
    # Otherwise widen the image: content | 5-px black gap | colorbar.
    return np.concatenate((x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1)
# tensor
def colorize(
    x, cmap_name="jet", mask=None, range=None, append_cbar=False, cbar_in_image=False
):
    """
    turn a grayscale image into a color image
    :param x: torch.Tensor, grayscale image, [H, W] or [B, H, W]
    :param mask: torch.Tensor or None, mask image, [H, W] or [B, H, W] or None
    :return: colorized float tensor on x's original device;
             [H, W, 3] for a single image, [B, H, W, 3] for a batch
    """
    device = x.device
    x = x.cpu().numpy()
    if mask is not None:
        # Binarize: keep only (near-)certainly valid pixels.
        mask = mask.cpu().numpy() > 0.99
        kernel = np.ones((3, 3), np.uint8)
    # Promote a single image to a batch of one so the loop below is uniform.
    if x.ndim == 2:
        x = x[None]
        if mask is not None:
            mask = mask[None]
    out = []
    for x_ in x:
        if mask is not None:
            # NOTE(review): `mask` is re-assigned every iteration, so for B > 1
            # the erosion accumulates across frames and the same (whole) mask is
            # passed to colorize_np for each frame — confirm this is intentional.
            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=1).astype(bool)
        x_ = colorize_np(x_, cmap_name, mask, range, append_cbar, cbar_in_image)
        out.append(torch.from_numpy(x_).to(device).float())
    # squeeze(0) undoes the batch-of-one promotion for 2-D inputs.
    out = torch.stack(out).squeeze(0)
    return out
```
## /eval/video_depth/eval_depth.py
```py path="/eval/video_depth/eval_depth.py"
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from eval.video_depth.tools import depth_evaluation, group_by_directory
import numpy as np
import cv2
from tqdm import tqdm
import glob
from PIL import Image
import argparse
import json
from eval.video_depth.metadata import dataset_metadata
def get_args_parser():
    """Build the CLI parser for offline video-depth evaluation."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="value for outdir",
    )
    parser.add_argument(
        "--eval_dataset",
        type=str,
        default="nyu",
        choices=list(dataset_metadata.keys()),
    )
    # How predictions are aligned to GT before computing metrics.
    parser.add_argument(
        "--align",
        type=str,
        default="scale&shift",
        choices=["scale&shift", "scale", "metric"],
    )
    return parser
def main(args):
    """Evaluate predicted video depth maps against ground truth.

    Supports three dataset families (sintel, bonn*, kitti*). For each, GT depth
    is loaded, the predicted `.npy` depth maps are resized to the GT resolution,
    `depth_evaluation` runs with the alignment mode selected by --align, and
    the valid-pixel-weighted average metrics are written to
    `{args.output_dir}/result_{args.align}.json`.

    Fixes vs. the previous version: the align-mode dispatch, prediction
    resizing and metric averaging were triplicated per dataset (now shared
    helpers); the Sintel depth reader leaked its file handle; an unused
    `img_pathes` scan was removed; empty-result runs now fail with a clear
    error instead of an IndexError.
    """

    def _align_kwargs():
        # Map the --align choice onto depth_evaluation's flag arguments.
        if args.align == "scale&shift":
            return {"align_with_lad2": True}
        if args.align == "scale":
            return {"align_with_scale": True}
        if args.align == "metric":
            return {"metric_scale": True}
        raise ValueError(f"unknown alignment mode: {args.align}")

    def _evaluate_pair(pr_depth, gt_depth, max_depth, post_clip_max=None):
        # Run one sequence through depth_evaluation; returns the metrics dict.
        extra = {"post_clip_max": post_clip_max} if post_clip_max is not None else {}
        depth_results, _error_map, _depth_predict, _depth_gt = depth_evaluation(
            pr_depth,
            gt_depth,
            max_depth=max_depth,
            use_gpu=True,
            **_align_kwargs(),
            **extra,
        )
        return depth_results

    def _resize_preds(pd_pathes, gt_depth):
        # Load predictions and resize each to the GT (W, H) with bicubic interp.
        return np.stack(
            [
                cv2.resize(
                    np.load(pd_path),
                    (gt_depth.shape[2], gt_depth.shape[1]),
                    interpolation=cv2.INTER_CUBIC,
                )
                for pd_path in pd_pathes
            ],
            axis=0,
        )

    def _save_average(gathered_depth_metrics):
        # Valid-pixel-weighted average over sequences, printed and dumped to JSON.
        if not gathered_depth_metrics:
            raise RuntimeError(
                "no sequences were evaluated; check --output_dir and data paths"
            )
        average_metrics = {
            key: np.average(
                [metrics[key] for metrics in gathered_depth_metrics],
                weights=[
                    metrics["valid_pixels"] for metrics in gathered_depth_metrics
                ],
            )
            for key in gathered_depth_metrics[0].keys()
            if key != "valid_pixels"
        }
        print("Average depth evaluation metrics:", average_metrics)
        with open(f"{args.output_dir}/result_{args.align}.json", "w") as f:
            f.write(json.dumps(average_metrics))

    if args.eval_dataset == "sintel":
        TAG_FLOAT = 202021.25

        def depth_read(filename):
            """Read a Sintel .dpt depth file into an (H, W) float32 array."""
            # `with` closes the handle even on assertion failure (was leaked).
            with open(filename, "rb") as f:
                check = np.fromfile(f, dtype=np.float32, count=1)[0]
                assert (
                    check == TAG_FLOAT
                ), " depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? ".format(
                    TAG_FLOAT, check
                )
                width = np.fromfile(f, dtype=np.int32, count=1)[0]
                height = np.fromfile(f, dtype=np.int32, count=1)[0]
                size = width * height
                assert (
                    width > 0 and height > 0 and size > 1 and size < 100000000
                ), " depth_read:: Wrong input size (width = {0}, height = {1}).".format(
                    width, height
                )
                depth = np.fromfile(f, dtype=np.float32, count=-1).reshape(
                    (height, width)
                )
            return depth

        pred_pathes = sorted(
            glob.glob(f"{args.output_dir}/*/frame_*.npy")
        )  # TODO: update the path to your prediction
        # Benchmark heuristic: more than 643 predictions means the full training
        # split was processed rather than the 14-sequence subset.
        full = len(pred_pathes) > 643
        if full:
            depth_pathes = sorted(glob.glob("./data/sintel/training/depth/*/*.dpt"))
        else:
            seq_list = [
                "alley_2",
                "ambush_4",
                "ambush_5",
                "ambush_6",
                "cave_2",
                "cave_4",
                "market_2",
                "market_5",
                "market_6",
                "shaman_3",
                "sleeping_1",
                "sleeping_2",
                "temple_2",
                "temple_3",
            ]
            depth_pathes = []
            for seq in seq_list:
                depth_pathes += glob.glob(f"./data/sintel/training/depth/{seq}/*.dpt")
            depth_pathes = sorted(depth_pathes)

        grouped_pred_depth = group_by_directory(pred_pathes)
        grouped_gt_depth = group_by_directory(depth_pathes)
        gathered_depth_metrics = []
        for key in tqdm(grouped_pred_depth.keys()):
            # Prediction folders carry a "_pred_depth" suffix the GT lacks.
            gt_pathes = grouped_gt_depth[key.replace("_pred_depth", "")]
            gt_depth = np.stack([depth_read(p) for p in gt_pathes], axis=0)
            pr_depth = _resize_preds(grouped_pred_depth[key], gt_depth)
            gathered_depth_metrics.append(
                _evaluate_pair(pr_depth, gt_depth, max_depth=70, post_clip_max=70)
            )
        _save_average(gathered_depth_metrics)

    elif args.eval_dataset.startswith("bonn"):

        def depth_read(filename):
            """Read a Bonn 16-bit depth PNG (1/5000 m units); invalid pixels -> -1."""
            depth_png = np.asarray(Image.open(filename))
            # make sure we have a proper 16bit depth map here.. not 8bit!
            assert np.max(depth_png) > 255
            depth = depth_png.astype(np.float64) / 5000.0
            depth[depth_png == 0] = -1.0
            return depth

        seq_list = ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"]
        # extract number from dataset name, e.g. bonn_400 -> 400
        if "_" in args.eval_dataset:
            bonn_number = args.eval_dataset.split("_")[-1]
        else:
            bonn_number = "110"  # default value
        depth_pathes = []
        for seq in seq_list:
            depth_pathes += glob.glob(
                f"./data/long_bonn_s1/rgbd_bonn_dataset/rgbd_bonn_{seq}/depth_{bonn_number}/*.png"
            )
        depth_pathes = sorted(depth_pathes)
        pred_pathes = sorted(
            glob.glob(f"{args.output_dir}/*/frame*.npy")
        )  # TODO: update the path to your prediction

        grouped_pred_depth = group_by_directory(pred_pathes)
        grouped_gt_depth = group_by_directory(depth_pathes, idx=-2)
        gathered_depth_metrics = []
        for key in tqdm(grouped_gt_depth.keys()):
            # GT keys carry the "rgbd_bonn_" prefix (10 chars); predictions don't.
            pd_pathes = grouped_pred_depth[key[10:]]
            gt_depth = np.stack(
                [depth_read(p) for p in grouped_gt_depth[key]], axis=0
            )
            pr_depth = _resize_preds(pd_pathes, gt_depth)
            gathered_depth_metrics.append(
                _evaluate_pair(pr_depth, gt_depth, max_depth=70)
            )
        _save_average(gathered_depth_metrics)

    elif args.eval_dataset.startswith("kitti"):

        def depth_read(filename):
            """Read a KITTI 16-bit depth PNG (1/256 m units); invalid pixels -> -1."""
            img_pil = Image.open(filename)
            depth_png = np.array(img_pil, dtype=int)
            # make sure we have a proper 16bit depth map here.. not 8bit!
            assert np.max(depth_png) > 255
            depth = depth_png.astype(float) / 256.0
            depth[depth_png == 0] = -1.0
            return depth

        # extract number from dataset name, e.g. kitti_100 -> 100
        if "_" in args.eval_dataset:
            kitti_number = args.eval_dataset.split("_")[-1]
        else:
            kitti_number = "110"  # default value
        depth_pathes = sorted(
            glob.glob(
                f"./data/long_kitti_s1/depth_selection/val_selection_cropped/groundtruth_depth_gathered_{kitti_number}/*/*.png"
            )
        )
        pred_pathes = sorted(
            glob.glob(f"{args.output_dir}/*/frame_*.npy")
        )  # TODO: update the path to your prediction

        grouped_pred_depth = group_by_directory(pred_pathes)
        grouped_gt_depth = group_by_directory(depth_pathes)
        gathered_depth_metrics = []
        for key in tqdm(grouped_pred_depth.keys()):
            gt_depth = np.stack(
                [depth_read(p) for p in grouped_gt_depth[key]], axis=0
            )
            pr_depth = _resize_preds(grouped_pred_depth[key], gt_depth)
            gathered_depth_metrics.append(
                _evaluate_pair(pr_depth, gt_depth, max_depth=None)
            )
        _save_average(gathered_depth_metrics)
if __name__ == "__main__":
    # Parse CLI arguments and run the evaluation for the selected dataset.
    main(get_args_parser().parse_args())
```
## /eval/video_depth/launch.py
```py path="/eval/video_depth/launch.py"
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import math
import cv2
import numpy as np
import torch
import argparse
from copy import deepcopy
from eval.video_depth.metadata import dataset_metadata
from eval.video_depth.utils import save_depth_maps
from accelerate import PartialState
from add_ckpt_path import add_path_to_dust3r
import time
from tqdm import tqdm
def get_args_parser():
    """Build the CLI parser for the video-depth inference launcher.

    Fix: `--size` now has an integer default (224) instead of the string
    "224"; relying on argparse to type-convert string defaults is fragile and
    inconsistent with the other integer options in this parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--weights",
        type=str,
        help="path to the model weights",
        default="",
    )
    parser.add_argument("--device", type=str, default="cuda", help="pytorch device")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="value for outdir",
    )
    # NOTE(review): type=bool makes any non-empty CLI string truthy (e.g.
    # "--no_crop False" is still True). Harmless here because __main__
    # overwrites args.no_crop, but worth fixing if the flag is ever exposed.
    parser.add_argument(
        "--no_crop", type=bool, default=True, help="whether to crop input data"
    )
    parser.add_argument(
        "--eval_dataset",
        type=str,
        default="sintel",
        choices=list(dataset_metadata.keys()),
    )
    parser.add_argument("--size", type=int, default=224)
    parser.add_argument(
        "--model_update_type",
        type=str,
        default="cut3r",
        help="model type for state update strategy: cut3r or ttt3r",
    )
    parser.add_argument(
        "--pose_eval_stride", default=1, type=int, help="stride for pose evaluation"
    )
    parser.add_argument(
        "--full_seq",
        action="store_true",
        default=False,
        help="use full sequence for pose evaluation",
    )
    parser.add_argument(
        "--seq_list",
        nargs="+",
        default=None,
        help="list of sequences for pose evaluation",
    )
    return parser
def eval_pose_estimation(args, model, save_dir=None):
    """Resolve dataset paths for args.eval_dataset and run the distributed eval.

    Returns the (ate_mean, rpe_trans_mean, rpe_rot_mean) triple produced by
    eval_pose_estimation_dist.
    """
    meta = dataset_metadata.get(args.eval_dataset)
    return eval_pose_estimation_dist(
        args,
        model,
        save_dir=save_dir,
        img_path=meta["img_path"],
        mask_path=meta["mask_path"],
    )
def eval_pose_estimation_dist(args, model, img_path, save_dir=None, mask_path=None):
    """Run per-sequence inference across processes and save predicted depth maps.

    Despite the name, this variant only runs the recurrent model and writes
    depth maps per sequence; no trajectory metrics are computed here, and the
    function always returns (None, None, None).
    """
    from dust3r.inference import inference, inference_recurrent, inference_recurrent_lighter

    metadata = dataset_metadata.get(args.eval_dataset)
    anno_path = metadata.get("anno_path", None)
    # Resolve the sequence list: CLI override > metadata list > directory scan.
    seq_list = args.seq_list
    if seq_list is None:
        if metadata.get("full_seq", False):
            args.full_seq = True
        else:
            seq_list = metadata.get("seq_list", [])
        if args.full_seq:
            # Every sub-directory of the dataset root is a sequence.
            seq_list = os.listdir(img_path)
            seq_list = [
                seq for seq in seq_list if os.path.isdir(os.path.join(img_path, seq))
            ]
        seq_list = sorted(seq_list)
    if save_dir is None:
        save_dir = args.output_dir
    # Shard sequences across accelerate processes.
    distributed_state = PartialState()
    model.to(distributed_state.device)
    device = distributed_state.device
    with distributed_state.split_between_processes(seq_list) as seqs:
        # NOTE(review): these accumulators (and `bug`, and `fps` below) are
        # never consumed — presumably leftovers from the metric-computing
        # variant of this function.
        ate_list = []
        rpe_trans_list = []
        rpe_rot_list = []
        load_img_size = args.size
        assert load_img_size == 512
        error_log_path = f"{save_dir}/_error_log_{distributed_state.process_index}.txt"  # Unique log file per process
        bug = False
        for seq in tqdm(seqs):
            try:
                dir_path = metadata["dir_path_func"](img_path, seq)
                # Handle skip_condition
                skip_condition = metadata.get("skip_condition", None)
                if skip_condition is not None and skip_condition(save_dir, seq):
                    continue
                mask_path_seq_func = metadata.get(
                    "mask_path_seq_func", lambda mask_path, seq: None
                )
                mask_path_seq = mask_path_seq_func(mask_path, seq)
                filelist = [
                    os.path.join(dir_path, name) for name in os.listdir(dir_path)
                ]
                filelist.sort()
                filelist = filelist[:: args.pose_eval_stride]
                # prepare_input / prepare_output are module-level names defined
                # under the __main__ guard of this script — resolved at call time.
                views = prepare_input(
                    filelist,
                    [True for _ in filelist],
                    size=load_img_size,
                    crop=not args.no_crop,
                )
                start = time.time()
                outputs, _ = inference_recurrent_lighter(views, model, device)
                end = time.time()
                fps = len(filelist) / (end - start)
                (
                    colors,
                    pts3ds_self,
                    pts3ds_other,
                    conf_self,
                    conf_other,
                    cam_dict,
                    pr_poses,
                ) = prepare_output(outputs)
                os.makedirs(f"{save_dir}/{seq}", exist_ok=True)
                save_depth_maps(pts3ds_self, f"{save_dir}/{seq}", conf_self=conf_self)
            except Exception as e:
                if "out of memory" in str(e):
                    # Handle OOM
                    torch.cuda.empty_cache()  # Clear the CUDA memory
                    with open(error_log_path, "a") as f:
                        f.write(
                            f"OOM error in sequence {seq}, skipping this sequence.\n"
                        )
                    print(f"OOM error in sequence {seq}, skipping...")
                elif "Degenerate covariance rank" in str(
                    e
                ) or "Eigenvalues did not converge" in str(e):
                    # Handle Degenerate covariance rank exception and Eigenvalues did not converge exception
                    with open(error_log_path, "a") as f:
                        f.write(f"Exception in sequence {seq}: {str(e)}\n")
                    print(f"Traj evaluation error in sequence {seq}, skipping.")
                else:
                    raise e  # Rethrow if it's not an expected exception
    return None, None, None
if __name__ == "__main__":
    args = get_args_parser()
    args = args.parse_args()
    # The checkpoint path doubles as the anchor for locating the dust3r
    # package, so it must be registered before the dust3r imports below.
    add_path_to_dust3r(args.weights)
    from dust3r.utils.image import load_images_for_eval as load_images
    from dust3r.post_process import estimate_focal_knowing_depth
    from dust3r.model import ARCroco3DStereo
    from dust3r.utils.camera import pose_encoding_to_camera
    # Sintel is evaluated over every sequence in the split; other datasets
    # use their metadata-defined subset.
    if args.eval_dataset == "sintel":
        args.full_seq = True
    else:
        args.full_seq = False
    # Inputs are resized (not cropped) to the inference resolution.
    args.no_crop = True
def prepare_input(
    img_paths,
    img_mask,
    size,
    raymaps=None,
    raymap_mask=None,
    revisit=1,
    update=True,
    crop=True,
):
    """Build the list of per-frame view dicts consumed by the recurrent model.

    img_paths: image files to load; img_mask: per-view flags marking image views.
    raymaps / raymap_mask: optional ray-map views interleaved with image views.
    revisit: if > 1, repeat the whole sequence that many times; later passes
    can be frozen (no state update) via update=False. Returns a list of dicts.
    """
    images = load_images(img_paths, size=size, crop=crop)
    views = []
    if raymaps is None and raymap_mask is None:
        # Image-only input: each view carries a real image plus a NaN ray map.
        num_views = len(images)
        for i in range(num_views):
            view = {
                "img": images[i]["img"],
                # Placeholder 6-channel ray map (all NaN) matching the image size.
                "ray_map": torch.full(
                    (
                        images[i]["img"].shape[0],
                        6,
                        images[i]["img"].shape[-2],
                        images[i]["img"].shape[-1],
                    ),
                    torch.nan,
                ),
                "true_shape": torch.from_numpy(images[i]["true_shape"]),
                "idx": i,
                "instance": str(i),
                # Identity pose placeholder; poses are estimated by the model.
                "camera_pose": torch.from_numpy(
                    np.eye(4).astype(np.float32)
                ).unsqueeze(0),
                "img_mask": torch.tensor(True).unsqueeze(0),
                "ray_mask": torch.tensor(False).unsqueeze(0),
                "update": torch.tensor(True).unsqueeze(0),
                "reset": torch.tensor(False).unsqueeze(0),
            }
            views.append(view)
    else:
        # Mixed input: interleave image and ray-map views per the two masks.
        num_views = len(images) + len(raymaps)
        assert len(img_mask) == len(raymap_mask) == num_views
        assert sum(img_mask) == len(images) and sum(raymap_mask) == len(raymaps)
        j = 0  # index of the next unused image
        k = 0  # index of the next unused ray map
        for i in range(num_views):
            view = {
                "img": (
                    images[j]["img"]
                    if img_mask[i]
                    else torch.full_like(images[0]["img"], torch.nan)
                ),
                "ray_map": (
                    raymaps[k]
                    if raymap_mask[i]
                    else torch.full_like(raymaps[0], torch.nan)
                ),
                "true_shape": (
                    torch.from_numpy(images[j]["true_shape"])
                    if img_mask[i]
                    # Derive (H, W) from the ray map when no image is present.
                    else torch.from_numpy(np.int32([raymaps[k].shape[1:-1][::-1]]))
                ),
                "idx": i,
                "instance": str(i),
                "camera_pose": torch.from_numpy(
                    np.eye(4).astype(np.float32)
                ).unsqueeze(0),
                "img_mask": torch.tensor(img_mask[i]).unsqueeze(0),
                "ray_mask": torch.tensor(raymap_mask[i]).unsqueeze(0),
                "update": torch.tensor(img_mask[i]).unsqueeze(0),
                "reset": torch.tensor(False).unsqueeze(0),
            }
            if img_mask[i]:
                j += 1
            if raymap_mask[i]:
                k += 1
            views.append(view)
        assert j == len(images) and k == len(raymaps)
    if revisit > 1:
        # repeat input for 'revisit' times
        new_views = []
        for r in range(revisit):
            for i in range(len(views)):
                new_view = deepcopy(views[i])
                new_view["idx"] = r * len(views) + i
                new_view["instance"] = str(r * len(views) + i)
                if r > 0:
                    if not update:
                        # Later passes only read recurrent state, never write it.
                        new_view["update"] = torch.tensor(False).unsqueeze(0)
                new_views.append(new_view)
        return new_views
    return views
def prepare_output(outputs, revisit=1):
    """Unpack model outputs for the final revisit pass and derive camera params.

    Keeps only the last len/revisit predictions, estimates per-frame focal
    length via Weiszfeld from the self-view point maps (principal point fixed
    at the image center), and returns (colors, pts3ds_self, pts3ds_other,
    conf_self, conf_other, cam_dict, pr_poses).
    """
    valid_length = len(outputs["pred"]) // revisit
    outputs["pred"] = outputs["pred"][-valid_length:]
    outputs["views"] = outputs["views"][-valid_length:]
    pts3ds_self = [output["pts3d_in_self_view"].cpu() for output in outputs["pred"]]
    pts3ds_other = [
        output["pts3d_in_other_view"].cpu() for output in outputs["pred"]
    ]
    conf_self = [output["conf_self"].cpu() for output in outputs["pred"]]
    conf_other = [output["conf"].cpu() for output in outputs["pred"]]
    pts3ds_self = torch.cat(pts3ds_self, 0)
    # Decode predicted camera encodings into 4x4 pose matrices.
    pr_poses = [
        pose_encoding_to_camera(pred["camera_pose"].clone()).cpu()
        for pred in outputs["pred"]
    ]
    pr_poses = torch.cat(pr_poses, 0)
    B, H, W, _ = pts3ds_self.shape
    # Principal point assumed at the image center for every frame.
    pp = (
        torch.tensor([W // 2, H // 2], device=pts3ds_self.device)
        .float()
        .repeat(B, 1)
        .reshape(B, 2)
    )
    focal = estimate_focal_knowing_depth(pts3ds_self, pp, focal_mode="weiszfeld")
    # Map predicted RGB from [-1, 1] back to [0, 1].
    colors = [0.5 * (output["rgb"][0] + 1.0) for output in outputs["pred"]]
    cam_dict = {
        "focal": focal.cpu().numpy(),
        "pp": pp.cpu().numpy(),
    }
    return (
        colors,
        pts3ds_self,
        pts3ds_other,
        conf_self,
        conf_other,
        cam_dict,
        pr_poses,
    )
# Load the pretrained checkpoint, select the recurrent-state update rule
# ("cut3r" or "ttt3r"), then run the evaluation entry point.
model = ARCroco3DStereo.from_pretrained(args.weights)
# set model type
model.config.model_update_type = args.model_update_type
eval_pose_estimation(args, model, save_dir=args.output_dir)
```
## /eval/video_depth/metadata.py
```py path="/eval/video_depth/metadata.py"
import os
import glob
from tqdm import tqdm
# Define the merged dataset metadata dictionary
# Schema per entry: dataset root paths plus callables resolving per-sequence
# image dirs ("dir_path_func"), GT trajectories ("gt_traj_func") and masks
# ("mask_path_seq_func"); "process_func" yields (filelist, save_dir) pairs.
dataset_metadata = {
    "davis": {
        "img_path": "data/davis/DAVIS/JPEGImages/480p",
        "mask_path": "data/davis/DAVIS/masked_images/480p",
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        "gt_traj_func": lambda img_path, anno_path, seq: None,
        "traj_format": None,
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: os.path.join(mask_path, seq),
        "skip_condition": None,
        "process_func": None,  # Not used in mono depth estimation
    },
    "kitti": {
        "img_path": "data/kitti/depth_selection/val_selection_cropped/image_gathered",  # Default path
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        "gt_traj_func": lambda img_path, anno_path, seq: None,
        "traj_format": None,
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": lambda args, img_path: process_kitti(args, img_path),
    },
    "bonn": {
        "img_path": "data/bonn/rgbd_bonn_dataset",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(
            img_path, f"rgbd_bonn_{seq}", "rgb_110"
        ),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, f"rgbd_bonn_{seq}", "groundtruth_110.txt"
        ),
        "traj_format": "tum",
        "seq_list": ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"],
        "full_seq": False,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": lambda args, img_path: process_bonn(args, img_path),
    },
    # NYU is a flat frame collection — no per-sequence path/trajectory helpers.
    "nyu": {
        "img_path": "data/nyu-v2/val/nyu_images",
        "mask_path": None,
        "process_func": lambda args, img_path: process_nyu(args, img_path),
    },
    "scannet": {
        "img_path": "data/scannetv2",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, seq, "pose_90.txt"
        ),
        "traj_format": "replica",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,  # lambda save_dir, seq: os.path.exists(os.path.join(save_dir, seq)),
        "process_func": lambda args, img_path: process_scannet(args, img_path),
    },
    "tum": {
        "img_path": "data/tum",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "rgb_90"),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, seq, "groundtruth_90.txt"
        ),
        "traj_format": "tum",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    "sintel": {
        "img_path": "data/sintel/training/final",
        "anno_path": "data/sintel/training/camdata_left",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(anno_path, seq),
        "traj_format": None,
        # Fixed 14-sequence benchmark subset.
        "seq_list": [
            "alley_2",
            "ambush_4",
            "ambush_5",
            "ambush_6",
            "cave_2",
            "cave_4",
            "market_2",
            "market_5",
            "market_6",
            "shaman_3",
            "sleeping_1",
            "sleeping_2",
            "temple_2",
            "temple_3",
        ],
        "full_seq": False,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": lambda args, img_path: process_sintel(args, img_path),
    },
}
kitti_numbers = [50, 100, 110, 150, 200, 250, 300, 350, 400, 450, 500]


def _kitti_stride_cfg(num):
    """Metadata entry for the long-KITTI (s1) split gathered at `num` frames."""
    return {
        "img_path": f"data/long_kitti_s1/depth_selection/val_selection_cropped/image_gathered_{num}",  # Default path
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        "gt_traj_func": lambda img_path, anno_path, seq: None,
        "traj_format": None,
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": lambda args, img_path: process_kitti(args, img_path),
    }


# One registry entry per sampling length, then merge into the shared table.
kitti_configs = {f"kitti_s1_{num}": _kitti_stride_cfg(num) for num in kitti_numbers}
dataset_metadata.update(kitti_configs)
bonn_numbers = [50, 100, 110, 150, 200, 250, 300, 350, 400, 450, 500]


def _bonn_stride_cfg(num):
    """Metadata entry for the long-Bonn (s1) split resampled to `num` frames."""
    return {
        "img_path": "data/long_bonn_s1/rgbd_bonn_dataset",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(
            img_path, f"rgbd_bonn_{seq}", f"rgb_{num}"
        ),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, f"rgbd_bonn_{seq}", f"groundtruth_{num}.txt"
        ),
        "traj_format": "tum",
        "seq_list": ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"],
        "full_seq": False,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": lambda args, img_path: process_bonn(args, img_path),
    }


# One registry entry per sampling length, then merge into the shared table.
bonn_configs = {f"bonn_s1_{num}": _bonn_stride_cfg(num) for num in bonn_numbers}
dataset_metadata.update(bonn_configs)
# Define processing functions for each dataset
def process_kitti(args, img_path):
    """Yield (filelist, save_dir) for every KITTI sequence directory under img_path."""
    for seq_dir in tqdm(sorted(glob.glob(f"{img_path}/*"))):
        frames = sorted(glob.glob(f"{seq_dir}/*.png"))
        yield frames, f"{args.output_dir}/{os.path.basename(seq_dir)}"
def process_bonn(args, img_path):
    """Yield (filelist, save_dir) pairs for Bonn RGB-D sequences.

    With ``args.full_seq`` every sequence directory is used at full length;
    otherwise a fixed dynamic-scene subset is read from its ``rgb_110`` split.
    """
    if args.full_seq:
        for seq_dir in tqdm(sorted(glob.glob(f"{img_path}/*/"))):
            frames = sorted(glob.glob(f"{seq_dir}/rgb/*.png"))
            yield frames, f"{args.output_dir}/{os.path.basename(os.path.dirname(seq_dir))}"
    else:
        sequences = args.seq_list
        if sequences is None:
            sequences = ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"]
        for seq in tqdm(sequences):
            frames = sorted(glob.glob(f"{img_path}/rgbd_bonn_{seq}/rgb_110/*.png"))
            yield frames, f"{args.output_dir}/{seq}"
def process_nyu(args, img_path):
    """Yield a single (filelist, save_dir) pair covering all NYU frames."""
    yield sorted(glob.glob(f"{img_path}/*.png")), f"{args.output_dir}"
def process_scannet(args, img_path):
    """Yield (filelist, save_dir) for each ScanNet scene's color_90 split."""
    for scene in tqdm(sorted(glob.glob(f"{img_path}/*"))):
        frames = sorted(glob.glob(f"{scene}/color_90/*.jpg"))
        yield frames, f"{args.output_dir}/{os.path.basename(scene)}"
def process_sintel(args, img_path):
    """Yield (filelist, save_dir) pairs for Sintel sequences.

    With ``args.full_seq`` every sequence directory under ``img_path`` is
    used; otherwise a fixed 14-sequence evaluation subset is processed.
    """
    if args.full_seq:
        for seq_dir in tqdm(sorted(glob.glob(f"{img_path}/*/"))):
            frames = sorted(glob.glob(f"{seq_dir}/*.png"))
            yield frames, f"{args.output_dir}/{os.path.basename(os.path.dirname(seq_dir))}"
    else:
        sequences = [
            "alley_2",
            "ambush_4",
            "ambush_5",
            "ambush_6",
            "cave_2",
            "cave_4",
            "market_2",
            "market_5",
            "market_6",
            "shaman_3",
            "sleeping_1",
            "sleeping_2",
            "temple_2",
            "temple_3",
        ]
        for seq in tqdm(sequences):
            frames = sorted(glob.glob(f"{img_path}/{seq}/*.png"))
            yield frames, f"{args.output_dir}/{seq}"
```
## /eval/video_depth/run_bonn.sh
```sh path="/eval/video_depth/run_bonn.sh"
#!/bin/bash
# Video-depth evaluation on the Bonn RGB-D dynamic dataset.
# Stage 1: run model inference via accelerate (eval/video_depth/launch.py).
# Stage 2: score the saved predictions (eval/video_depth/eval_depth.py)
#          under three depth-alignment modes: metric, scale, scale&shift.
set -e
workdir='.'
# Model variants to evaluate; all share the same checkpoint weights.
model_names=('ttt3r') # ttt3r cut3r
ckpt_name='cut3r_512_dpt_4_64'
model_weights="${workdir}/src/${ckpt_name}.pth"
# Full sequence-length sweep (uncomment to run every split):
# datasets=('bonn_s1_50' 'bonn_s1_100' 'bonn_s1_110' 'bonn_s1_150' 'bonn_s1_200' 'bonn_s1_250' 'bonn_s1_300' 'bonn_s1_350' 'bonn_s1_400' 'bonn_s1_450' 'bonn_s1_500')
datasets=('bonn_s1_500')
for model_name in "${model_names[@]}"; do
    for data in "${datasets[@]}"; do
        output_dir="${workdir}/eval_results/video_depth/${data}/${model_name}"
        echo "$output_dir"
        # Inference: writes per-frame depth predictions into $output_dir.
        accelerate launch --num_processes 1 --main_process_port 29556 eval/video_depth/launch.py \
            --weights "$model_weights" \
            --output_dir "$output_dir" \
            --eval_dataset "$data" \
            --size 512 \
            --model_update_type "$model_name"
        # scale&shift scale metric
        python eval/video_depth/eval_depth.py \
            --output_dir "$output_dir" \
            --eval_dataset "$data" \
            --align "metric"
        python eval/video_depth/eval_depth.py \
            --output_dir "$output_dir" \
            --eval_dataset "$data" \
            --align "scale"
        python eval/video_depth/eval_depth.py \
            --output_dir "$output_dir" \
            --eval_dataset "$data" \
            --align "scale&shift"
    done
done
```
## /eval/video_depth/run_kitti.sh
```sh path="/eval/video_depth/run_kitti.sh"
#!/bin/bash
# Video-depth evaluation on the KITTI dataset.
# Stage 1: run model inference via accelerate (eval/video_depth/launch.py).
# Stage 2: score the saved predictions (eval/video_depth/eval_depth.py)
#          under three depth-alignment modes: metric, scale, scale&shift.
set -e
workdir='.'
# Model variants to evaluate; all share the same checkpoint weights.
model_names=('ttt3r') # ttt3r cut3r
ckpt_name='cut3r_512_dpt_4_64'
model_weights="${workdir}/src/${ckpt_name}.pth"
# Full sequence-length sweep (uncomment to run every split):
# datasets=('kitti_s1_50' 'kitti_s1_100' 'kitti_s1_110' 'kitti_s1_150' 'kitti_s1_200' 'kitti_s1_250' 'kitti_s1_300' 'kitti_s1_350' 'kitti_s1_400' 'kitti_s1_450' 'kitti_s1_500')
datasets=('kitti_s1_500')
for model_name in "${model_names[@]}"; do
    for data in "${datasets[@]}"; do
        output_dir="${workdir}/eval_results/video_depth/${data}/${model_name}"
        echo "$output_dir"
        # Inference: writes per-frame depth predictions into $output_dir.
        accelerate launch --num_processes 1 --main_process_port 29555 eval/video_depth/launch.py \
            --weights "$model_weights" \
            --output_dir "$output_dir" \
            --eval_dataset "$data" \
            --size 512 \
            --model_update_type "$model_name"
        # scale&shift scale metric
        python eval/video_depth/eval_depth.py \
            --output_dir "$output_dir" \
            --eval_dataset "$data" \
            --align "metric"
        python eval/video_depth/eval_depth.py \
            --output_dir "$output_dir" \
            --eval_dataset "$data" \
            --align "scale"
        python eval/video_depth/eval_depth.py \
            --output_dir "$output_dir" \
            --eval_dataset "$data" \
            --align "scale&shift"
    done
done
```
## /eval/video_depth/run_sintel.sh
```sh path="/eval/video_depth/run_sintel.sh"
#!/bin/bash
# Video-depth evaluation on the MPI Sintel dataset.
# Stage 1: run model inference via accelerate (eval/video_depth/launch.py).
# Stage 2: score the saved predictions (eval/video_depth/eval_depth.py)
#          under three depth-alignment modes: metric, scale, scale&shift.
set -e
workdir='.'
# Model variants to evaluate; all share the same checkpoint weights.
model_names=('ttt3r') # ttt3r cut3r
ckpt_name='cut3r_512_dpt_4_64'
model_weights="${workdir}/src/${ckpt_name}.pth"
datasets=('sintel')
for model_name in "${model_names[@]}"; do
    for data in "${datasets[@]}"; do
        output_dir="${workdir}/eval_results/video_depth/${data}/${model_name}"
        echo "$output_dir"
        # Inference: writes per-frame depth predictions into $output_dir.
        accelerate launch --num_processes 1 --main_process_port 29555 eval/video_depth/launch.py \
            --weights "$model_weights" \
            --output_dir "$output_dir" \
            --eval_dataset "$data" \
            --size 512 \
            --model_update_type "$model_name"
        # scale&shift scale metric
        python eval/video_depth/eval_depth.py \
            --output_dir "$output_dir" \
            --eval_dataset "$data" \
            --align "metric"
        python eval/video_depth/eval_depth.py \
            --output_dir "$output_dir" \
            --eval_dataset "$data" \
            --align "scale"
        python eval/video_depth/eval_depth.py \
            --output_dir "$output_dir" \
            --eval_dataset "$data" \
            --align "scale&shift"
    done
done
```
## /eval/video_depth/tools.py
```py path="/eval/video_depth/tools.py"
import torch
import numpy as np
import cv2
import glob
import argparse
from pathlib import Path
from tqdm import tqdm
from copy import deepcopy
from scipy.optimize import minimize
import os
from collections import defaultdict
def group_by_directory(pathes, idx=-1):
    """
    Group file paths by one component of their directory path.

    Parameters:
    - pathes (list): List of '/'-separated file paths.
    - idx (int): Index of the directory component to group by; the default
      of -1 selects the immediate parent directory.

    Returns:
    - dict: Maps the selected directory name to the list of paths under it.
    """
    grouped_pathes = defaultdict(list)
    for path in pathes:
        components = os.path.dirname(path).split("/")
        grouped_pathes[components[idx]].append(path)
    return grouped_pathes
def depth2disparity(depth, return_mask=False):
    """Convert a depth map to disparity (1 / depth).

    Pixels with non-positive depth are left at 0 disparity.

    Args:
        depth (torch.Tensor or numpy.ndarray): Depth map.
        return_mask (bool): If True, also return the boolean mask of valid
            (depth > 0) pixels.

    Returns:
        Disparity map with the same type and shape as ``depth``; when
        ``return_mask`` is True, a (disparity, mask) tuple.

    Raises:
        TypeError: If ``depth`` is neither a torch.Tensor nor a numpy.ndarray.
    """
    if isinstance(depth, torch.Tensor):
        disparity = torch.zeros_like(depth)
    elif isinstance(depth, np.ndarray):
        disparity = np.zeros_like(depth)
    else:
        # Previously an unsupported type fell through to an undefined
        # `disparity` and raised a confusing NameError.
        raise TypeError(f"Unsupported depth type: {type(depth)}")
    non_negative_mask = depth > 0
    disparity[non_negative_mask] = 1.0 / depth[non_negative_mask]
    if return_mask:
        return disparity, non_negative_mask
    else:
        return disparity
def absolute_error_loss(params, predicted_depth, ground_truth_depth):
    """Summed absolute error of ``scale * pred + shift`` against ground truth."""
    scale, shift = params
    residual = scale * predicted_depth + shift - ground_truth_depth
    return np.sum(np.abs(residual))
def absolute_value_scaling(predicted_depth, ground_truth_depth, s=1, t=0):
    """Fit scale ``s`` and shift ``t`` minimizing the L1 alignment error.

    Flattens both tensors and runs scipy's general-purpose minimizer on the
    summed absolute error between ``s * predicted + t`` and the ground truth,
    starting from the provided (s, t).

    Returns:
        tuple: The fitted (s, t).
    """
    pred_flat = predicted_depth.cpu().numpy().reshape(-1)
    gt_flat = ground_truth_depth.cpu().numpy().reshape(-1)
    result = minimize(absolute_error_loss, [s, t], args=(pred_flat, gt_flat))
    fitted_s, fitted_t = result.x
    return fitted_s, fitted_t
def absolute_value_scaling2(
    predicted_depth,
    ground_truth_depth,
    s_init=1.0,
    t_init=0.0,
    lr=1e-4,
    max_iters=1000,
    tol=1e-6,
):
    """Fit scale/shift minimizing the L1 error via Adam gradient descent.

    Args:
        predicted_depth (torch.Tensor): Predicted depth values.
        ground_truth_depth (torch.Tensor): Ground-truth depth values.
        s_init (float): Initial scale.
        t_init (float): Initial shift.
        lr (float): Adam learning rate.
        max_iters (int): Maximum number of optimization steps.
        tol (float): Stop once the loss change falls below this value.

    Returns:
        tuple[float, float]: The fitted (scale, shift).
    """
    # Learnable scale/shift living on the same device/dtype as the input.
    tensor_kwargs = dict(
        requires_grad=True,
        device=predicted_depth.device,
        dtype=predicted_depth.dtype,
    )
    scale = torch.tensor([s_init], **tensor_kwargs)
    shift = torch.tensor([t_init], **tensor_kwargs)
    optimizer = torch.optim.Adam([scale, shift], lr=lr)
    last_loss = None
    for _ in range(max_iters):
        optimizer.zero_grad()
        # L1 alignment error of scale * pred + shift against the ground truth.
        loss = torch.sum(
            torch.abs(scale * predicted_depth + shift - ground_truth_depth)
        )
        loss.backward()
        optimizer.step()
        # Stop once the loss has plateaued.
        if last_loss is not None and torch.abs(last_loss - loss) < tol:
            break
        last_loss = loss.item()
    return scale.detach().item(), shift.detach().item()
def depth_evaluation(
    predicted_depth_original,
    ground_truth_depth_original,
    max_depth=80,
    custom_mask=None,
    post_clip_min=None,
    post_clip_max=None,
    pre_clip_min=None,
    pre_clip_max=None,
    align_with_lstsq=False,
    align_with_lad=False,
    align_with_lad2=False,
    metric_scale=False,
    lr=1e-4,
    max_iters=1000,
    use_gpu=False,
    align_with_scale=False,
    disp_input=False,
):
    """
    Evaluate a predicted depth map against ground truth and return standard
    depth metrics plus a per-pixel relative-error ("parity") map.

    Exactly one alignment mode is applied to the prediction before scoring,
    chosen by the first matching flag in this order: ``metric_scale`` (no
    alignment), ``align_with_lstsq`` (L2 scale+shift), ``align_with_lad``
    (L1 scale+shift via scipy), ``align_with_lad2`` (L1 scale+shift via
    Adam), ``align_with_scale`` (scale-only Weiszfeld), otherwise
    median-ratio scaling.

    Args:
        predicted_depth_original (numpy.ndarray or torch.Tensor): Predicted depth map.
        ground_truth_depth_original (numpy.ndarray or torch.Tensor): Ground-truth depth map.
        max_depth (float or None): GT pixels at or beyond this depth (or <= 0)
            are excluded. Default is 80.
        custom_mask: Optional extra validity mask with the GT's shape.
        post_clip_min, post_clip_max (float or None): Clamp applied to the
            prediction after alignment.
        pre_clip_min, pre_clip_max (float or None): Clamp applied before alignment.
        align_with_lstsq, align_with_lad, align_with_lad2, align_with_scale,
            metric_scale (bool): Alignment-mode flags (see above).
        lr (float), max_iters (int): Optimizer settings for ``align_with_lad2``.
        use_gpu (bool): Move tensors to CUDA before computing.
        disp_input (bool): Treat the prediction as disparity; alignment then
            happens in disparity space and is converted back afterwards.

    Returns:
        tuple: (metrics dict, depth-error-parity map, aligned prediction map,
        masked ground-truth map).
    """
    # Accept numpy inputs; all computation below is done in torch.
    if isinstance(predicted_depth_original, np.ndarray):
        predicted_depth_original = torch.from_numpy(predicted_depth_original)
    if isinstance(ground_truth_depth_original, np.ndarray):
        ground_truth_depth_original = torch.from_numpy(ground_truth_depth_original)
    if custom_mask is not None and isinstance(custom_mask, np.ndarray):
        custom_mask = torch.from_numpy(custom_mask)
    # if the dimension is 3, flatten to 2d along the batch dimension
    if predicted_depth_original.dim() == 3:
        _, h, w = predicted_depth_original.shape
        predicted_depth_original = predicted_depth_original.view(-1, w)
        ground_truth_depth_original = ground_truth_depth_original.view(-1, w)
        if custom_mask is not None:
            custom_mask = custom_mask.view(-1, w)
    # put to device
    # NOTE(review): custom_mask is indexed on CPU further down; with
    # use_gpu=True and a custom mask the devices differ - confirm this
    # combination is actually exercised.
    if use_gpu:
        predicted_depth_original = predicted_depth_original.cuda()
        ground_truth_depth_original = ground_truth_depth_original.cuda()
    # Filter out depths greater than max_depth (and non-positive GT pixels)
    if max_depth is not None:
        mask = (ground_truth_depth_original > 0) & (
            ground_truth_depth_original < max_depth
        )
    else:
        mask = ground_truth_depth_original > 0
    # 1-D vectors holding only the valid pixels.
    predicted_depth = predicted_depth_original[mask]
    ground_truth_depth = ground_truth_depth_original[mask]
    # Clip the depth values (before alignment)
    if pre_clip_min is not None:
        predicted_depth = torch.clamp(predicted_depth, min=pre_clip_min)
    if pre_clip_max is not None:
        predicted_depth = torch.clamp(predicted_depth, max=pre_clip_max)
    if disp_input:  # align the pred to gt in the disparity space
        real_gt = ground_truth_depth.clone()
        ground_truth_depth = 1 / (ground_truth_depth + 1e-8)
    # various alignment methods
    if metric_scale:
        # Metric-scale model output: use the prediction as-is.
        predicted_depth = predicted_depth
    elif align_with_lstsq:
        # Convert to numpy for lstsq
        predicted_depth_np = predicted_depth.cpu().numpy().reshape(-1, 1)
        ground_truth_depth_np = ground_truth_depth.cpu().numpy().reshape(-1, 1)
        # Add a column of ones for the shift term
        A = np.hstack([predicted_depth_np, np.ones_like(predicted_depth_np)])
        # Solve for scale (s) and shift (t) using least squares
        result = np.linalg.lstsq(A, ground_truth_depth_np, rcond=None)
        s, t = result[0][0], result[0][1]
        # convert to torch tensor
        s = torch.tensor(s, device=predicted_depth_original.device)
        t = torch.tensor(t, device=predicted_depth_original.device)
        # Apply scale and shift
        predicted_depth = s * predicted_depth + t
    elif align_with_lad:
        # L1 (least absolute deviation) fit, seeded with the median ratio.
        s, t = absolute_value_scaling(
            predicted_depth,
            ground_truth_depth,
            s=torch.median(ground_truth_depth) / torch.median(predicted_depth),
        )
        predicted_depth = s * predicted_depth + t
    elif align_with_lad2:
        # L1 fit by gradient descent (Adam), seeded with the median ratio.
        s_init = (
            torch.median(ground_truth_depth) / torch.median(predicted_depth)
        ).item()
        s, t = absolute_value_scaling2(
            predicted_depth,
            ground_truth_depth,
            s_init=s_init,
            lr=lr,
            max_iters=max_iters,
        )
        predicted_depth = s * predicted_depth + t
    elif align_with_scale:
        # Compute initial scale factor 's' using the closed-form solution (L2 norm)
        dot_pred_gt = torch.nanmean(ground_truth_depth)
        dot_pred_pred = torch.nanmean(predicted_depth)
        s = dot_pred_gt / dot_pred_pred
        # Iterative reweighted least squares using the Weiszfeld method
        for _ in range(10):
            # Compute residuals between scaled predictions and ground truth
            residuals = s * predicted_depth - ground_truth_depth
            abs_residuals = (
                residuals.abs() + 1e-8
            )  # Add small constant to avoid division by zero
            # Compute weights inversely proportional to the residuals
            weights = 1.0 / abs_residuals
            # Update 's' using weighted sums
            weighted_dot_pred_gt = torch.sum(
                weights * predicted_depth * ground_truth_depth
            )
            weighted_dot_pred_pred = torch.sum(weights * predicted_depth**2)
            s = weighted_dot_pred_gt / weighted_dot_pred_pred
        # Optionally clip 's' to prevent extreme scaling
        s = s.clamp(min=1e-3)
        # Detach 's' if you want to stop gradients from flowing through it
        s = s.detach()
        # Apply the scale factor to the predicted depth
        predicted_depth = s * predicted_depth
    else:
        # Align the predicted depth with the ground truth using median scaling
        scale_factor = torch.median(ground_truth_depth) / torch.median(predicted_depth)
        predicted_depth *= scale_factor
    if disp_input:
        # convert back to depth
        ground_truth_depth = real_gt
        predicted_depth = depth2disparity(predicted_depth)
    # Clip the predicted depth values (after alignment)
    if post_clip_min is not None:
        predicted_depth = torch.clamp(predicted_depth, min=post_clip_min)
    if post_clip_max is not None:
        predicted_depth = torch.clamp(predicted_depth, max=post_clip_max)
    if custom_mask is not None:
        assert custom_mask.shape == ground_truth_depth_original.shape
        # Restrict the already-masked 1-D vectors further by the custom mask.
        mask_within_mask = custom_mask.cpu()[mask]
        predicted_depth = predicted_depth[mask_within_mask]
        ground_truth_depth = ground_truth_depth[mask_within_mask]
    # Calculate the metrics
    abs_rel = torch.mean(
        torch.abs(predicted_depth - ground_truth_depth) / ground_truth_depth
    ).item()
    sq_rel = torch.mean(
        ((predicted_depth - ground_truth_depth) ** 2) / ground_truth_depth
    ).item()
    # Correct RMSE calculation
    rmse = torch.sqrt(torch.mean((predicted_depth - ground_truth_depth) ** 2)).item()
    # Clip the depth values to avoid log(0)
    predicted_depth = torch.clamp(predicted_depth, min=1e-5)
    log_rmse = torch.sqrt(
        torch.mean((torch.log(predicted_depth) - torch.log(ground_truth_depth)) ** 2)
    ).item()
    # Calculate the accuracy thresholds
    max_ratio = torch.maximum(
        predicted_depth / ground_truth_depth, ground_truth_depth / predicted_depth
    )
    # NOTE(review): max_ratio >= 1 by construction, so this first threshold is
    # always 0 - confirm the "< 1.0" bound is intended.
    threshold_0 = torch.mean((max_ratio < 1.0).float()).item()
    threshold_1 = torch.mean((max_ratio < 1.25).float()).item()
    threshold_2 = torch.mean((max_ratio < 1.25**2).float()).item()
    threshold_3 = torch.mean((max_ratio < 1.25**3).float()).item()
    # Compute the depth error parity map on the full-size tensors, re-applying
    # the fitted alignment. Out-of-mask pixels are zeroed further below.
    if metric_scale:
        predicted_depth_original = predicted_depth_original
        if disp_input:
            predicted_depth_original = depth2disparity(predicted_depth_original)
        depth_error_parity_map = (
            torch.abs(predicted_depth_original - ground_truth_depth_original)
            / ground_truth_depth_original
        )
    elif align_with_lstsq or align_with_lad or align_with_lad2:
        predicted_depth_original = predicted_depth_original * s + t
        if disp_input:
            predicted_depth_original = depth2disparity(predicted_depth_original)
        depth_error_parity_map = (
            torch.abs(predicted_depth_original - ground_truth_depth_original)
            / ground_truth_depth_original
        )
    elif align_with_scale:
        predicted_depth_original = predicted_depth_original * s
        if disp_input:
            predicted_depth_original = depth2disparity(predicted_depth_original)
        depth_error_parity_map = (
            torch.abs(predicted_depth_original - ground_truth_depth_original)
            / ground_truth_depth_original
        )
    else:
        predicted_depth_original = predicted_depth_original * scale_factor
        if disp_input:
            predicted_depth_original = depth2disparity(predicted_depth_original)
        depth_error_parity_map = (
            torch.abs(predicted_depth_original - ground_truth_depth_original)
            / ground_truth_depth_original
        )
    # Reshape the depth_error_parity_map back to the original image size
    depth_error_parity_map_full = torch.zeros_like(ground_truth_depth_original)
    depth_error_parity_map_full = torch.where(
        mask, depth_error_parity_map, depth_error_parity_map_full
    )
    predict_depth_map_full = predicted_depth_original
    gt_depth_map_full = torch.zeros_like(ground_truth_depth_original)
    gt_depth_map_full = torch.where(
        mask, ground_truth_depth_original, gt_depth_map_full
    )
    # Valid-pixel count honors the custom mask when one was supplied.
    num_valid_pixels = (
        torch.sum(mask).item()
        if custom_mask is None
        else torch.sum(mask_within_mask).item()
    )
    if num_valid_pixels == 0:
        # No valid pixels: report zeroed metrics instead of NaNs.
        (
            abs_rel,
            sq_rel,
            rmse,
            log_rmse,
            threshold_0,
            threshold_1,
            threshold_2,
            threshold_3,
        ) = (0, 0, 0, 0, 0, 0, 0, 0)
    results = {
        "Abs Rel": abs_rel,
        "Sq Rel": sq_rel,
        "RMSE": rmse,
        "Log RMSE": log_rmse,
        "δ < 1.": threshold_0,
        "δ < 1.25": threshold_1,
        "δ < 1.25^2": threshold_2,
        "δ < 1.25^3": threshold_3,
        "valid_pixels": num_valid_pixels,
    }
    return (
        results,
        depth_error_parity_map_full,
        predict_depth_map_full,
        gt_depth_map_full,
    )
```
## /eval/video_depth/utils.py
```py path="/eval/video_depth/utils.py"
from copy import deepcopy
import cv2
import numpy as np
import torch
import torch.nn as nn
import roma
from copy import deepcopy
import tqdm
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from scipy.spatial.transform import Rotation
from PIL import Image
import imageio.v2 as iio
from matplotlib.figure import Figure
def save_focals(cam_dict, path):
    """Dump the per-frame focal lengths from ``cam_dict`` to a text file.

    Returns the focal values unchanged so the call can be chained.
    """
    focal_values = cam_dict["focal"]
    np.savetxt(path, focal_values, fmt="%.6f")
    return focal_values
def save_intrinsics(cam_dict, path):
    """Build per-frame 3x3 intrinsics from focals and principal points.

    The matrices are written to ``path`` flattened to one 9-value row per
    frame; the (N, 3, 3) array is returned.
    """
    num_frames = len(cam_dict["focal"])
    K_raw = np.tile(np.eye(3), (num_frames, 1, 1))
    K_raw[:, 0, 0] = cam_dict["focal"]
    K_raw[:, 1, 1] = cam_dict["focal"]
    K_raw[:, :2, 2] = cam_dict["pp"]
    np.savetxt(path, K_raw.reshape(num_frames, 9), fmt="%.6f")
    return K_raw
def save_conf_maps(conf, path):
    """Save each confidence map as ``conf_<i>.npy`` under ``path`` and return the input."""
    for idx, conf_map in enumerate(conf):
        np.save(f"{path}/conf_{idx}.npy", conf_map.detach().cpu().numpy())
    return conf
def save_rgb_imgs(colors, path):
    """Write each RGB frame as ``frame_XXXX.jpg`` under ``path``.

    ``colors`` holds float images in [0, 1]; each is scaled to uint8 before
    writing. Returns the input sequence unchanged.
    """
    for idx, frame in enumerate(colors):
        frame_u8 = (frame.cpu().numpy() * 255).astype(np.uint8)
        iio.imwrite(f"{path}/frame_{idx:04d}.jpg", frame_u8)
    return colors
def save_depth_maps(pts3ds_self, path, conf_self=None):
    """Save per-frame depth maps as colorized PNGs and raw ``.npy`` files.

    Args:
        pts3ds_self: Sequence of per-frame point maps; depth is taken from
            the last channel (z) of each map.
        path: Output directory (must already exist).
        conf_self: Optional per-frame confidence maps. When given, a
            log-scale colorized confidence image is concatenated to the
            right of each colorized depth image.

    Returns:
        torch.Tensor: The stacked raw depth maps.
    """
    # Depth = z-channel of each self-frame point map.
    depth_maps = torch.stack([pts3d_self[..., -1] for pts3d_self in pts3ds_self], 0)
    # One shared color range across all frames for temporal consistency.
    min_depth = depth_maps.min()  # float(torch.quantile(out, 0.01))
    max_depth = depth_maps.max()  # float(torch.quantile(out, 0.99))
    colored_depth = colorize(
        depth_maps,
        cmap_name="Spectral_r",
        range=(min_depth, max_depth),
        append_cbar=True,
    )
    images = []
    if conf_self is not None:
        # Confidence is visualized in log space with its own shared range.
        conf_selfs = torch.concat(conf_self, 0)
        min_conf = torch.log(conf_selfs.min())  # float(torch.quantile(out, 0.01))
        max_conf = torch.log(conf_selfs.max())  # float(torch.quantile(out, 0.99))
        colored_conf = colorize(
            torch.log(conf_selfs),
            cmap_name="jet",
            range=(min_conf, max_conf),
            append_cbar=True,
        )
    for i, depth_map in enumerate(colored_depth):
        # Apply color map to depth map
        img_path = f"{path}/frame_{(i):04d}.png"
        if conf_self is None:
            to_save = (depth_map * 255).detach().cpu().numpy().astype(np.uint8)
        else:
            # Depth on the left, confidence on the right.
            to_save = torch.cat([depth_map, colored_conf[i]], dim=1)
            to_save = (to_save * 255).detach().cpu().numpy().astype(np.uint8)
        iio.imwrite(img_path, to_save)
        images.append(Image.open(img_path))
        # Also dump the raw (uncolorized) depth for later numeric use.
        np.save(f"{path}/frame_{(i):04d}.npy", depth_maps[i].detach().cpu().numpy())
    # comment this as it may fail sometimes
    # images[0].save(f'{path}/_depth_maps.gif', save_all=True, append_images=images[1:], duration=100, loop=0)
    return depth_maps
def get_vertical_colorbar(h, vmin, vmax, cmap_name="jet", label=None, cbar_precision=2):
    """
    Render a vertical colorbar as an RGB float image of height ``h``.

    :param h: target height in pixels (width scales proportionally)
    :param vmin: min value
    :param vmax: max value
    :param cmap_name: matplotlib colormap name
    :param label: optional colorbar label text
    :param cbar_precision: decimals used for the tick labels
    :return: float32 RGB image in [0, 1], shape [h, w, 3]
    """
    fig = Figure(figsize=(2, 8), dpi=100)
    fig.subplots_adjust(right=1.5)
    canvas = FigureCanvasAgg(fig)
    # Do some plotting.
    ax = fig.add_subplot(111)
    cmap = cm.get_cmap(cmap_name)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    tick_cnt = 6
    tick_loc = np.linspace(vmin, vmax, tick_cnt)
    cb1 = mpl.colorbar.ColorbarBase(
        ax, cmap=cmap, norm=norm, ticks=tick_loc, orientation="vertical"
    )
    tick_label = [str(np.round(x, cbar_precision)) for x in tick_loc]
    if cbar_precision == 0:
        # Drop the trailing ".0" produced by rounding to zero decimals.
        tick_label = [x[:-2] for x in tick_label]
    cb1.set_ticklabels(tick_label)
    cb1.ax.tick_params(labelsize=18, rotation=0)
    if label is not None:
        cb1.set_label(label)
    # fig.tight_layout()
    canvas.draw()
    # Rasterize the figure into an RGBA byte buffer, keep RGB, scale to [0, 1].
    s, (width, height) = canvas.print_to_buffer()
    im = np.frombuffer(s, np.uint8).reshape((height, width, 4))
    im = im[:, :, :3].astype(np.float32) / 255.0
    # Resize to the requested height, preserving the aspect ratio.
    if h != im.shape[0]:
        w = int(im.shape[1] / im.shape[0] * h)
        im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA)
    return im
def colorize_np(
    x,
    cmap_name="jet",
    mask=None,
    range=None,
    append_cbar=False,
    cbar_in_image=False,
    cbar_precision=2,
):
    """
    turn a grayscale image into a color image
    :param x: input grayscale, [H, W]. NOTE(review): when ``mask`` is given,
        the masked-out pixels of ``x`` are overwritten in place with vmin -
        callers passing a reused array should copy first.
    :param cmap_name: the colorization method
    :param mask: the mask image, [H, W]
    :param range: the range for scaling, automatic if None, [min, max]
    :param append_cbar: if append the color bar
    :param cbar_in_image: put the color bar inside the image to keep the output image the same size as the input image
    :param cbar_precision: decimals used for the colorbar tick labels
    :return: colorized image, [H, W, 3] (wider if a colorbar is appended)
    """
    # Determine the value range: explicit > mask-derived > percentile.
    if range is not None:
        vmin, vmax = range
    elif mask is not None:
        # vmin, vmax = np.percentile(x[mask], (2, 100))
        # Smallest non-zero masked value, so zeros don't dominate the range.
        vmin = np.min(x[mask][np.nonzero(x[mask])])
        vmax = np.max(x[mask])
        # vmin = vmin - np.abs(vmin) * 0.01
        x[np.logical_not(mask)] = vmin
        # print(vmin, vmax)
    else:
        vmin, vmax = np.percentile(x, (1, 100))
        vmax += 1e-6
    # Normalize to [0, 1] and apply the colormap (RGB only, alpha dropped).
    x = np.clip(x, vmin, vmax)
    x = (x - vmin) / (vmax - vmin)
    # x = np.clip(x, 0., 1.)
    cmap = cm.get_cmap(cmap_name)
    x_new = cmap(x)[:, :, :3]
    if mask is not None:
        # Blend masked-out pixels toward white.
        mask = np.float32(mask[:, :, np.newaxis])
        x_new = x_new * mask + np.ones_like(x_new) * (1.0 - mask)
    cbar = get_vertical_colorbar(
        h=x.shape[0],
        vmin=vmin,
        vmax=vmax,
        cmap_name=cmap_name,
        cbar_precision=cbar_precision,
    )
    if append_cbar:
        if cbar_in_image:
            # Overwrite the right edge of the image with the colorbar.
            x_new[:, -cbar.shape[1] :, :] = cbar
        else:
            # Append the colorbar (with a small black gap) to the right.
            x_new = np.concatenate(
                (x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1
            )
        return x_new
    else:
        return x_new
# tensor
def colorize(
    x, cmap_name="jet", mask=None, range=None, append_cbar=False, cbar_in_image=False
):
    """
    turn a grayscale image into a color image
    :param x: torch.Tensor, grayscale image, [H, W] or [B, H, W]
    :param mask: torch.Tensor or None, mask image, [H, W] or [B, H, W] or None
    :return: torch.Tensor on the input's device, [H, W, 3] or [B, H, W, 3]
    """
    device = x.device
    x = x.cpu().numpy()
    if mask is not None:
        # Binarize the mask; the erosion kernel is shared across frames.
        mask = mask.cpu().numpy() > 0.99
        kernel = np.ones((3, 3), np.uint8)
    # Promote a single image to a batch of one.
    if x.ndim == 2:
        x = x[None]
        if mask is not None:
            mask = mask[None]
    out = []
    for x_ in x:
        if mask is not None:
            # NOTE(review): `mask` is re-assigned each iteration, so for a
            # batch the erosion accumulates across frames and the per-frame
            # mask is not indexed - confirm batched masked use is intended.
            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=1).astype(bool)
        x_ = colorize_np(x_, cmap_name, mask, range, append_cbar, cbar_in_image)
        out.append(torch.from_numpy(x_).to(device).float())
    # Squeeze restores [H, W, 3] for single-image input.
    out = torch.stack(out).squeeze(0)
    return out
```
## /examples/taylor.mp4
Binary file available at https://raw.githubusercontent.com/Inception3D/TTT3R/refs/heads/main/examples/taylor.mp4
## /examples/westlake.mp4
Binary file available at https://raw.githubusercontent.com/Inception3D/TTT3R/refs/heads/main/examples/westlake.mp4
## /requirements.txt
numpy==1.26.4
torch
torchvision
roma
gradio
matplotlib
tqdm
opencv-python
scipy
einops
trimesh
tensorboard
pyglet<2
huggingface-hub[torch]>=0.22
viser
lpips
hydra-core
pillow==10.3.0
h5py
accelerate
transformers
scikit-learn
## /src/croco/LICENSE
``` path="/src/croco/LICENSE"
CroCo, Copyright (c) 2022-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
A summary of the CC BY-NC-SA 4.0 license is located here:
https://creativecommons.org/licenses/by-nc-sa/4.0/
The CC BY-NC-SA 4.0 license is located here:
https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
SEE NOTICE BELOW WITH RESPECT TO THE FILE: models/pos_embed.py, models/blocks.py
***************************
NOTICE WITH RESPECT TO THE FILE: models/pos_embed.py
This software is being redistributed in a modified form. The original form is available here:
https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
This software in this file incorporates parts of the following software available here:
Transformer: https://github.com/tensorflow/models/blob/master/official/legacy/transformer/model_utils.py
available under the following license: https://github.com/tensorflow/models/blob/master/LICENSE
MoCo v3: https://github.com/facebookresearch/moco-v3
available under the following license: https://github.com/facebookresearch/moco-v3/blob/main/LICENSE
DeiT: https://github.com/facebookresearch/deit
available under the following license: https://github.com/facebookresearch/deit/blob/main/LICENSE
ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW:
https://github.com/facebookresearch/mae/blob/main/LICENSE
Attribution-NonCommercial 4.0 International
***************************
NOTICE WITH RESPECT TO THE FILE: models/blocks.py
This software is being redistributed in a modified form. The original form is available here:
https://github.com/rwightman/pytorch-image-models
ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW:
https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
```
## /src/croco/NOTICE
``` path="/src/croco/NOTICE"
CroCo
Copyright 2022-present NAVER Corp.
This project contains subcomponents with separate copyright notices and license terms.
Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
====
facebookresearch/mae
https://github.com/facebookresearch/mae
Attribution-NonCommercial 4.0 International
====
rwightman/pytorch-image-models
https://github.com/rwightman/pytorch-image-models
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
```
## /src/croco/README.MD
# CroCo + CroCo v2 / CroCo-Stereo / CroCo-Flow
[[`CroCo arXiv`](https://arxiv.org/abs/2210.10716)] [[`CroCo v2 arXiv`](https://arxiv.org/abs/2211.10408)] [[`project page and demo`](https://croco.europe.naverlabs.com/)]
This repository contains the code for our CroCo model presented in our NeurIPS'22 paper [CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion](https://openreview.net/pdf?id=wZEfHUM5ri) and its follow-up extension published at ICCV'23 [Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow](https://openaccess.thecvf.com/content/ICCV2023/html/Weinzaepfel_CroCo_v2_Improved_Cross-view_Completion_Pre-training_for_Stereo_Matching_and_ICCV_2023_paper.html), referred to as CroCo v2:

```bibtex
@inproceedings{croco,
title={{CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion}},
author={{Weinzaepfel, Philippe and Leroy, Vincent and Lucas, Thomas and Br\'egier, Romain and Cabon, Yohann and Arora, Vaibhav and Antsfeld, Leonid and Chidlovskii, Boris and Csurka, Gabriela and Revaud J\'er\^ome}},
booktitle={{NeurIPS}},
year={2022}
}
@inproceedings{croco_v2,
title={{CroCo v2: Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow}},
author={Weinzaepfel, Philippe and Lucas, Thomas and Leroy, Vincent and Cabon, Yohann and Arora, Vaibhav and Br{\'e}gier, Romain and Csurka, Gabriela and Antsfeld, Leonid and Chidlovskii, Boris and Revaud, J{\'e}r{\^o}me},
booktitle={ICCV},
year={2023}
}
```
## License
The code is distributed under the CC BY-NC-SA 4.0 License. See [LICENSE](LICENSE) for more information.
Some components are based on code from [MAE](https://github.com/facebookresearch/mae) released under the CC BY-NC-SA 4.0 License and [timm](https://github.com/rwightman/pytorch-image-models) released under the Apache 2.0 License.
Some components for stereo matching and optical flow are based on code from [unimatch](https://github.com/autonomousvision/unimatch) released under the MIT license.
## Preparation
1. Install dependencies on a machine with an NVIDIA GPU using e.g. conda. Note that `habitat-sim` is required only for the interactive demo and the synthetic pre-training data generation. If you don't plan to use it, you can ignore the line installing it and use a more recent python version.
```bash
conda create -n croco python=3.7 cmake=3.14.0
conda activate croco
conda install habitat-sim headless -c conda-forge -c aihabitat
conda install pytorch torchvision -c pytorch
conda install notebook ipykernel matplotlib
conda install ipywidgets widgetsnbextension
conda install scikit-learn tqdm quaternion opencv # only for pretraining / habitat data generation
```
2. Compile cuda kernels for RoPE
CroCo v2 relies on RoPE positional embeddings for which you need to compile some cuda kernels.
```bash
cd models/curope/
python setup.py build_ext --inplace
cd ../../
```
This can be a bit long as we compile for all cuda architectures, feel free to update L9 of `models/curope/setup.py` to compile for specific architectures only.
You might also need to set the environment `CUDA_HOME` in case you use a custom cuda installation.
In case you cannot compile the cuda kernels, we also provide a slow pytorch fallback version, which will be loaded automatically.
3. Download pre-trained model
We provide several pre-trained models:
| modelname | pre-training data | pos. embed. | Encoder | Decoder |
|------------------------------------------------------------------------------------------------------------------------------------|-------------------|-------------|---------|---------|
| [`CroCo.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth) | Habitat | cosine | ViT-B | Small |
| [`CroCo_V2_ViTBase_SmallDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_SmallDecoder.pth) | Habitat + real | RoPE | ViT-B | Small |
| [`CroCo_V2_ViTBase_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_BaseDecoder.pth) | Habitat + real | RoPE | ViT-B | Base |
| [`CroCo_V2_ViTLarge_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth) | Habitat + real | RoPE | ViT-L | Base |
To download a specific model, i.e., the first one (`CroCo.pth`)
```bash
mkdir -p pretrained_models/
wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth -P pretrained_models/
```
## Reconstruction example
Simply run after downloading the `CroCo_V2_ViTLarge_BaseDecoder` pretrained model (or update the corresponding line in `demo.py`)
```bash
python demo.py
```
## Interactive demonstration of cross-view completion reconstruction on the Habitat simulator
First download the test scene from Habitat:
```bash
python -m habitat_sim.utils.datasets_download --uids habitat_test_scenes --data-path habitat-sim-data/
```
Then, run the Notebook demo `interactive_demo.ipynb`.
In this demo, you should be able to sample a random reference viewpoint from an [Habitat](https://github.com/facebookresearch/habitat-sim) test scene. Use the sliders to change viewpoint and select a masked target view to reconstruct using CroCo.

## Pre-training
### CroCo
To pre-train CroCo, please first generate the pre-training data from the Habitat simulator, following the instructions in [datasets/habitat_sim/README.MD](datasets/habitat_sim/README.MD) and then run the following command:
```
torchrun --nproc_per_node=4 pretrain.py --output_dir ./output/pretraining/
```
Our CroCo pre-training was launched on a single server with 4 GPUs.
It should take around 10 days with A100 or 15 days with V100 to do the 400 pre-training epochs, but decent performances are obtained earlier in training.
Note that, while the code contains the same scaling rule of the learning rate as MAE when changing the effective batch size, we did not experiment to check whether it is valid in our case.
The first run can take a few minutes to start, to parse all available pre-training pairs.
### CroCo v2
For CroCo v2 pre-training, in addition to the generation of the pre-training data from the Habitat simulator above, please pre-extract the crops from the real datasets following the instructions in [datasets/crops/README.MD](datasets/crops/README.MD).
Then, run the following command for the largest model (ViT-L encoder, Base decoder):
```
torchrun --nproc_per_node=8 pretrain.py --model "CroCoNet(enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_num_heads=12, dec_depth=12, pos_embed='RoPE100')" --dataset "habitat_release+ARKitScenes+MegaDepth+3DStreetView+IndoorVL" --warmup_epochs 12 --max_epoch 125 --epochs 250 --amp 0 --keep_freq 5 --output_dir ./output/pretraining_crocov2/
```
Our CroCo v2 pre-training was launched on a single server with 8 GPUs for the largest model, and on a single server with 4 GPUs for the smaller ones, keeping a batch size of 64 per gpu in all cases.
The largest model should take around 12 days on A100.
Note that, while the code contains the same scaling rule of the learning rate as MAE when changing the effective batch size, we did not experiment to check whether it is valid in our case.
## Stereo matching and Optical flow downstream tasks
For CroCo-Stereo and CroCo-Flow, please refer to [stereoflow/README.MD](stereoflow/README.MD).
## /src/croco/assets/Chateau1.png
Binary file available at https://raw.githubusercontent.com/Inception3D/TTT3R/refs/heads/main/src/croco/assets/Chateau1.png
## /src/croco/assets/Chateau2.png
Binary file available at https://raw.githubusercontent.com/Inception3D/TTT3R/refs/heads/main/src/croco/assets/Chateau2.png
## /src/croco/assets/arch.jpg
Binary file available at https://raw.githubusercontent.com/Inception3D/TTT3R/refs/heads/main/src/croco/assets/arch.jpg
## /src/croco/croco-stereo-flow-demo.ipynb
```ipynb path="/src/croco/croco-stereo-flow-demo.ipynb"
{
"cells": [
{
"cell_type": "markdown",
"id": "9bca0f41",
"metadata": {},
"source": [
"# Simple inference example with CroCo-Stereo or CroCo-Flow"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "80653ef7",
"metadata": {},
"outputs": [],
"source": [
"# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n",
"# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)."
]
},
{
"cell_type": "markdown",
"id": "4f033862",
"metadata": {},
"source": [
"First download the model(s) of your choice by running\n",
"\`\`\`\n",
"bash stereoflow/download_model.sh crocostereo.pth\n",
"bash stereoflow/download_model.sh crocoflow.pth\n",
"\`\`\`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1fb2e392",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n",
"device = torch.device('cuda:0' if use_gpu else 'cpu')\n",
"import matplotlib.pylab as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e0e25d77",
"metadata": {},
"outputs": [],
"source": [
"from stereoflow.test import _load_model_and_criterion\n",
"from stereoflow.engine import tiled_pred\n",
"from stereoflow.datasets_stereo import img_to_tensor, vis_disparity\n",
"from stereoflow.datasets_flow import flowToColor\n",
"tile_overlap=0.7 # recommended value, higher value can be slightly better but slower"
]
},
{
"cell_type": "markdown",
"id": "86a921f5",
"metadata": {},
"source": [
"### CroCo-Stereo example"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "64e483cb",
"metadata": {},
"outputs": [],
"source": [
"image1 = np.asarray(Image.open('<path_to_left_image>'))\n",
"image2 = np.asarray(Image.open('<path_to_right_image>'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0d04303",
"metadata": {},
"outputs": [],
"source": [
"model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocostereo.pth', None, device)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "47dc14b5",
"metadata": {},
"outputs": [],
"source": [
"im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
"im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
"with torch.inference_mode():\n",
" pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
"pred = pred.squeeze(0).squeeze(0).cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "583b9f16",
"metadata": {},
"outputs": [],
"source": [
"plt.imshow(vis_disparity(pred))\n",
"plt.axis('off')"
]
},
{
"cell_type": "markdown",
"id": "d2df5d70",
"metadata": {},
"source": [
"### CroCo-Flow example"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ee257a7",
"metadata": {},
"outputs": [],
"source": [
"image1 = np.asarray(Image.open('<path_to_first_image>'))\n",
"image2 = np.asarray(Image.open('<path_to_second_image>'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5edccf0",
"metadata": {},
"outputs": [],
"source": [
"model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocoflow.pth', None, device)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b19692c3",
"metadata": {},
"outputs": [],
"source": [
"im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
"im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
"with torch.inference_mode():\n",
" pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
"pred = pred.squeeze(0).permute(1,2,0).cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "26f79db3",
"metadata": {},
"outputs": [],
"source": [
"plt.imshow(flowToColor(pred))\n",
"plt.axis('off')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
```
## /src/croco/datasets/__init__.py
```py path="/src/croco/datasets/__init__.py"
```
## /src/croco/datasets/habitat_sim/__init__.py
```py path="/src/croco/datasets/habitat_sim/__init__.py"
```
## /src/croco/datasets/habitat_sim/generate_from_metadata_files.py
```py path="/src/croco/datasets/habitat_sim/generate_from_metadata_files.py"
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
"""
Script generating commandlines to generate image pairs from metadata files.

Walks --input_dir recursively for `metadata.json` files and, for each scene
whose output has not been generated yet, prints the command that would
invoke `generate_from_metadata.py` for it. The printed lines are intended
to be dispatched by the user (e.g. piped to a shell or a job scheduler).
"""
import os
import glob
from tqdm import tqdm
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", required=True)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument(
        "--prefix",
        default="",
        # Fixed typo in user-facing help text: "Commanline" -> "Commandline".
        help="Commandline prefix, useful e.g. to setup environment.",
    )
    args = parser.parse_args()

    # Lazy iterator over every metadata.json below input_dir (any depth).
    input_metadata_filenames = glob.iglob(
        f"{args.input_dir}/**/metadata.json", recursive=True
    )

    for metadata_filename in tqdm(input_metadata_filenames):
        # Mirror the input directory layout under output_dir.
        output_dir = os.path.join(
            args.output_dir,
            os.path.relpath(os.path.dirname(metadata_filename), args.input_dir),
        )
        # Do not process the scene if the metadata file already exists
        # in the output directory (i.e. it was already generated).
        if os.path.exists(os.path.join(output_dir, "metadata.json")):
            continue
        commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}"
        print(commandline)
```
## /src/croco/models/curope/__init__.py
```py path="/src/croco/models/curope/__init__.py"
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
from .curope2d import cuRoPE2D
```
## /src/croco/stereoflow/download_model.sh
```sh path="/src/croco/stereoflow/download_model.sh"
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

# Download a pre-trained CroCo stereo/flow checkpoint into stereoflow_models/,
# skipping the download when the file is already present.
# Usage: bash stereoflow/download_model.sh <model_filename>
model=$1
outfile="stereoflow_models/${model}"
if [[ ! -f "$outfile" ]]
then
    mkdir -p stereoflow_models/;
    # Use the named $model variable (not the raw positional $1) for consistency,
    # and quote the URL to be safe with unusual filenames.
    wget "https://download.europe.naverlabs.com/ComputerVision/CroCo/StereoFlow_models/${model}" -P stereoflow_models/;
else
    echo "Model ${model} already downloaded in ${outfile}."
fi
```
## /src/dust3r/__init__.py
```py path="/src/dust3r/__init__.py"
```
## /src/dust3r/datasets/base/__init__.py
```py path="/src/dust3r/datasets/base/__init__.py"
```
The content has been capped at 50000 tokens. The user could consider applying other filters to refine the result. The better and more specific the context, the better the LLM can follow instructions. If the context seems verbose, the user can refine the filter using uithub. Thank you for using https://uithub.com - Perfect LLM context for any GitHub repo.