Inception3D/TTT3R/main 220k tokens More Tools
```
├── .gitignore (100 tokens)
├── LICENSE (omitted)
├── README.md (800 tokens)
├── add_ckpt_path.py
├── datasets_preprocess/
   ├── long_prepare_bonn.py (600 tokens)
   ├── long_prepare_kitti.py (400 tokens)
   ├── long_prepare_scannet.py (600 tokens)
   ├── long_prepare_tum.py (800 tokens)
   ├── path_to_root.py (100 tokens)
├── demo.py (3.8k tokens)
├── eval/
   ├── eval.md (500 tokens)
   ├── mv_recon/
      ├── base.py (2k tokens)
      ├── criterion.py (3.3k tokens)
      ├── data.py (3.6k tokens)
      ├── dataset_utils/
         ├── __init__.py
         ├── corr.py (700 tokens)
         ├── cropping.py (900 tokens)
         ├── transforms.py (500 tokens)
      ├── launch.py (3.5k tokens)
      ├── run.sh (100 tokens)
      ├── utils.py (400 tokens)
   ├── relpose/
      ├── evo_utils.py (3.1k tokens)
      ├── launch.py (3.5k tokens)
      ├── metadata.py (2.1k tokens)
      ├── run_scannet.sh (200 tokens)
      ├── run_sintel.sh (100 tokens)
      ├── run_tum.sh (200 tokens)
      ├── utils.py (1800 tokens)
   ├── video_depth/
      ├── eval_depth.py (3.1k tokens)
      ├── launch.py (2.4k tokens)
      ├── metadata.py (1600 tokens)
      ├── run_bonn.sh (200 tokens)
      ├── run_kitti.sh (200 tokens)
      ├── run_sintel.sh (200 tokens)
      ├── tools.py (2.7k tokens)
      ├── utils.py (1400 tokens)
├── examples/
   ├── taylor.mp4
   ├── westlake.mp4
├── requirements.txt
├── src/
   ├── croco/
      ├── LICENSE (400 tokens)
      ├── NOTICE (100 tokens)
      ├── README.MD (1600 tokens)
      ├── assets/
         ├── Chateau1.png
         ├── Chateau2.png
         ├── arch.jpg
      ├── croco-stereo-flow-demo.ipynb (1000 tokens)
      ├── datasets/
         ├── __init__.py
         ├── crops/
            ├── README.MD (600 tokens)
            ├── extract_crops_from_images.py (1100 tokens)
         ├── habitat_sim/
            ├── README.MD (700 tokens)
            ├── __init__.py
            ├── generate_from_metadata.py (900 tokens)
            ├── generate_from_metadata_files.py (300 tokens)
            ├── generate_multiview_images.py (1700 tokens)
            ├── multiview_habitat_sim_generator.py (4k tokens)
            ├── pack_metadata_files.py (600 tokens)
            ├── paths.py (1200 tokens)
         ├── pairs_dataset.py (1100 tokens)
         ├── transforms.py (900 tokens)
      ├── interactive_demo.ipynb (2.1k tokens)
      ├── models/
         ├── blocks.py (2.5k tokens)
         ├── criterion.py (300 tokens)
         ├── croco.py (2.4k tokens)
         ├── croco_downstream.py (1000 tokens)
         ├── curope/
            ├── __init__.py
            ├── curope.cpp (500 tokens)
            ├── curope2d.py (300 tokens)
            ├── kernels.cu (800 tokens)
            ├── setup.py (200 tokens)
         ├── dpt_block.py (3k tokens)
         ├── head_downstream.py (600 tokens)
         ├── masking.py (100 tokens)
         ├── pos_embed.py (1400 tokens)
      ├── pretrain.py (2.5k tokens)
      ├── stereoflow/
         ├── README.MD (2.5k tokens)
         ├── augmentor.py (2.9k tokens)
         ├── criterion.py (2.6k tokens)
         ├── datasets_flow.py (6.7k tokens)
         ├── datasets_stereo.py (7.2k tokens)
         ├── download_model.sh (100 tokens)
         ├── engine.py (2.6k tokens)
         ├── test.py (2.1k tokens)
         ├── train.py (3k tokens)
      ├── utils/
         ├── misc.py (4k tokens)
   ├── dust3r/
      ├── __init__.py
      ├── blocks.py (3.5k tokens)
      ├── datasets/
         ├── __init__.py (500 tokens)
         ├── arkitscenes.py (1800 tokens)
         ├── arkitscenes_highres.py (1300 tokens)
         ├── base/
            ├── __init__.py
            ├── base_multiview_dataset.py (4.1k tokens)
            ├── batched_sampler.py (600 tokens)
            ├── easy_dataset.py (1200 tokens)
         ├── bedlam.py (2.9k tokens)
         ├── blendedmvs.py (2.4k tokens)
         ├── co3d.py (1500 tokens)
         ├── cop3d.py (800 tokens)
         ├── dl3dv.py (1100 tokens)
         ├── dynamic_replica.py (900 tokens)
         ├── eden.py (600 tokens)
         ├── hoi4d.py (600 tokens)
         ├── hypersim.py (1000 tokens)
         ├── irs.py (600 tokens)
         ├── mapfree.py (2.1k tokens)
         ├── megadepth.py (700 tokens)
         ├── mp3d.py (1000 tokens)
         ├── mvimgnet.py (1000 tokens)
         ├── mvs_synth.py (1000 tokens)
         ├── omniobject3d.py (1000 tokens)
         ├── pointodyssey.py (1200 tokens)
         ├── realestate10k.py (1000 tokens)
         ├── scannet.py (1000 tokens)
         ├── scannetpp.py (1400 tokens)
         ├── smartportraits.py (600 tokens)
         ├── spring.py (900 tokens)
         ├── synscapes.py (600 tokens)
         ├── tartanair.py (1100 tokens)
         ├── threedkb.py (700 tokens)
         ├── uasol.py (1000 tokens)
         ├── unreal4k.py (1000 tokens)
         ├── urbansyn.py (600 tokens)
         ├── utils/
            ├── __init__.py
            ├── corr.py (800 tokens)
            ├── cropping.py (900 tokens)
            ├── transforms.py (600 tokens)
         ├── vkitti2.py (1100 tokens)
         ├── waymo.py (1200 tokens)
         ├── wildrgbd.py (400 tokens)
      ├── heads/
         ├── __init__.py (300 tokens)
         ├── dpt_head.py (1800 tokens)
         ├── linear_head.py (2.5k tokens)
         ├── postprocess.py (900 tokens)
      ├── inference.py (2.6k tokens)
      ├── losses.py (8.7k tokens)
      ├── model.py (10.2k tokens)
      ├── patch_embed.py (600 tokens)
      ├── post_process.py (400 tokens)
      ├── utils/
         ├── __init__.py
         ├── camera.py (2.9k tokens)
         ├── device.py (600 tokens)
         ├── geometry.py (3.6k tokens)
         ├── image.py (1800 tokens)
         ├── misc.py (800 tokens)
         ├── parallel.py (500 tokens)
         ├── path_to_croco.py (100 tokens)
         ├── render.py (500 tokens)
      ├── viz.py (6.7k tokens)
   ├── train.py (6.3k tokens)
├── viser_utils.py (6.5k tokens)
```


## /.gitignore

```gitignore path="/.gitignore" 
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
.tox/
.coverage
.cache
nosetests.xml
coverage.xml

# Translations
*.mo

# Mr Developer
.mr.developer.cfg
.project
.pydevproject

# Rope
.ropeproject

# Django stuff:
*.log
*.pot

# Sphinx documentation
docs/_build/

# Ignore data and ckpts
*.pth
data
src/checkpoints

# Ignore results
tmp
eval_results

.vscode
```

## /README.md

<h2 align="center"> <a href="https://rover-xingyu.github.io/TTT3R">TTT3R: 3D Reconstruction as Test-Time Training</a>
</h2>

<h5 align="center">

[![arXiv](https://img.shields.io/badge/Arxiv-2509.26645-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2509.26645) 
[![Home Page](https://img.shields.io/badge/Project-Website-33728E.svg)](https://rover-xingyu.github.io/TTT3R) 
[![X](https://img.shields.io/badge/@Xingyu%20Chen-black?logo=X)](https://x.com/RoverXingyu)  [![Bluesky](https://img.shields.io/badge/@Xingyu%20Chen-white?logo=Bluesky)](https://bsky.app/profile/xingyu-chen.bsky.social)


[Xingyu Chen](https://rover-xingyu.github.io/),
[Yue Chen](https://fanegg.github.io/),
[Yuliang Xiu](https://xiuyuliang.cn/),
[Andreas Geiger](https://www.cvlibs.net/),
[Anpei Chen](https://apchenstu.github.io/)
</h5>

<div align="center">
TL;DR: A simple state update rule to enhance length generalization for CUT3R.
</div>
<br>

https://github.com/user-attachments/assets/b7583837-1f1e-43a4-b281-09f340ee5918

## Getting Started

### Installation

1. Clone TTT3R.
```bash
git clone https://github.com/Inception3D/TTT3R.git
cd TTT3R
```

2. Create the environment.
```bash
conda create -n ttt3r python=3.11 cmake=3.14.0
conda activate ttt3r
conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia  # use the correct version of cuda for your system
pip install -r requirements.txt
# issues with pytorch dataloader, see https://github.com/pytorch/pytorch/issues/99625
conda install 'llvm-openmp<16'
# for evaluation
pip install evo
pip install open3d
```

3. Compile the cuda kernels for RoPE (as in CroCo v2).
```bash
cd src/croco/models/curope/
python setup.py build_ext --inplace
cd ../../../../
```

### Download Checkpoints

CUT3R provides checkpoints trained on 4–64 views: [`cut3r_512_dpt_4_64.pth`](https://drive.google.com/file/d/1Asz-ZB3FfpzZYwunhQvNPZEUA8XUNAYD/view?usp=drive_link).

To download the weights, run the following commands:
```bash
cd src
gdown --fuzzy https://drive.google.com/file/d/1Asz-ZB3FfpzZYwunhQvNPZEUA8XUNAYD/view?usp=drive_link
cd ..
```

### Inference Demo

To run the inference demo, you can use the following command:
```bash
# input can be a folder or a video
# the following script will run inference with TTT3R and visualize the output with viser on port 8080
CUDA_VISIBLE_DEVICES=6 python demo.py --model_path MODEL_PATH --size 512 \
    --seq_path SEQ_PATH --output_dir OUT_DIR --port 8080 \
    --model_update_type ttt3r --frame_interval 1 --reset_interval 100 \
    --downsample_factor 1000 --vis_threshold 5.0

# Example:
CUDA_VISIBLE_DEVICES=6 python demo.py --model_path src/cut3r_512_dpt_4_64.pth --size 512 \
    --seq_path examples/westlake.mp4 --output_dir tmp/westlake --port 8080 \
    --model_update_type ttt3r --frame_interval 1 --reset_interval 100 \
    --downsample_factor 100 --vis_threshold 6.0

CUDA_VISIBLE_DEVICES=6 python demo.py --model_path src/cut3r_512_dpt_4_64.pth --size 512 \
    --seq_path examples/taylor.mp4 --output_dir tmp/taylor --port 8080 \
    --model_update_type ttt3r --frame_interval 1 --reset_interval 50 \
    --downsample_factor 100 --vis_threshold 10.0 
```
Output results will be saved to `output_dir`.


### Evaluation
Please refer to the [eval.md](eval/eval.md) for more details.

## Acknowledgements
Our code is based on the following awesome repositories:

- [CUT3R](https://github.com/CUT3R/CUT3R)
- [Easi3R](https://github.com/Inception3D/Easi3R)
- [DUSt3R](https://github.com/naver/dust3r)
- [MonST3R](https://github.com/Junyi42/monst3r.git)
- [Spann3R](https://github.com/HengyiWang/spann3r.git)
- [Viser](https://github.com/nerfstudio-project/viser)

We thank the authors for releasing their code!

## Citation

If you find our work useful, please cite:

```bibtex
@article{chen2025ttt3r,
    title={TTT3R: 3D Reconstruction as Test-Time Training},
    author={Chen, Xingyu and Chen, Yue and Xiu, Yuliang and Geiger, Andreas and Chen, Anpei},
    journal={arXiv preprint arXiv:2509.26645},
    year={2025}
    }
```


## /add_ckpt_path.py

```py path="/add_ckpt_path.py" 
import sys
import os
import os.path as path


def add_path_to_dust3r(ckpt):
    """Prepend the directory containing *ckpt* to ``sys.path``.

    Workaround for sibling imports: modules that live next to the
    checkpoint file become importable after this call.
    """
    ckpt_dir = path.dirname(path.abspath(ckpt))
    sys.path.insert(0, ckpt_dir)
```

## /datasets_preprocess/long_prepare_bonn.py

```py path="/datasets_preprocess/long_prepare_bonn.py" 
import glob
import os
import shutil
import numpy as np

# Build shortened Bonn RGB-D sequences at several target lengths so that
# length generalization can be evaluated on 50- to 500-frame clips.
START_FRAME = 30  # initial frame (skip the sequence head)
for TARGET_FRAMES in [50,100,150,200,250,300,350,400,450,500]:
    END_FRAME = START_FRAME + TARGET_FRAMES   # end frame (exclusive)

    dirs = sorted(glob.glob("/home/xingyu/monst3r/data/bonn/rgbd_bonn_dataset/*/"))

    # create new base directory
    base_new_dir = "./data/long_bonn_s1/rgbd_bonn_dataset/"
    os.makedirs(base_new_dir, exist_ok=True)

    print(f"specified frame range: {START_FRAME} to {END_FRAME}, target frames: {TARGET_FRAMES}")

    # extract frames
    for dir in dirs:
        # mirror the original sequence name under the new base directory
        dir_name = os.path.basename(os.path.dirname(dir))
        new_base_dir = base_new_dir + dir_name + '/'

        # pre-calculate the frames actually available for each modality
        rgb_frames = sorted(glob.glob(dir + 'rgb/*.png'))
        depth_frames = sorted(glob.glob(dir + 'depth/*.png'))
        gt = np.loadtxt(dir + "groundtruth.txt")

        # frames each modality can contribute inside the requested range;
        # clamp at zero in case a sequence is shorter than START_FRAME
        actual_rgb_frames = min(len(rgb_frames) - START_FRAME, TARGET_FRAMES)
        actual_depth_frames = min(len(depth_frames) - START_FRAME, TARGET_FRAMES)
        actual_gt_frames = min(len(gt) - START_FRAME, TARGET_FRAMES)

        # take the minimum over the three modalities to keep them aligned
        final_frame_count = max(0, min(actual_rgb_frames, actual_depth_frames, actual_gt_frames))
        print(f"  final unified frame count: {final_frame_count}")

        # process RGB frames
        rgb_frames = rgb_frames[START_FRAME:START_FRAME + final_frame_count]
        new_dir = new_base_dir + f'rgb_{TARGET_FRAMES}/'
        if os.path.exists(new_dir):
            shutil.rmtree(new_dir)
        # create the directory once (hoisted out of the copy loop); this also
        # guarantees new_base_dir exists for the groundtruth write below even
        # when final_frame_count is 0
        os.makedirs(new_dir, exist_ok=True)
        for frame in rgb_frames:
            shutil.copy(frame, new_dir)

        # process Depth frames
        depth_frames = depth_frames[START_FRAME:START_FRAME + final_frame_count]
        new_dir = new_base_dir + f'depth_{TARGET_FRAMES}/'
        if os.path.exists(new_dir):
            shutil.rmtree(new_dir)
        os.makedirs(new_dir, exist_ok=True)
        for frame in depth_frames:
            shutil.copy(frame, new_dir)

        # process Groundtruth (np.savetxt overwrites any stale file)
        gt_final = gt[START_FRAME:START_FRAME + final_frame_count]
        gt_file = new_base_dir + f'groundtruth_{TARGET_FRAMES}.txt'
        np.savetxt(gt_file, gt_final)
```

## /datasets_preprocess/long_prepare_kitti.py

```py path="/datasets_preprocess/long_prepare_kitti.py" 
from PIL import Image
import numpy as np


def depth_read(filename):
    """Load a KITTI 16-bit depth map from ``filename``.

    Returns a float64 numpy array with depth in meters (raw PNG value
    divided by 256); invalid (zero) pixels are set to -1. For details see
    the KITTI depth devkit readme.txt.
    """
    depth_png = np.array(Image.open(filename), dtype=int)
    # make sure we have a proper 16bit depth map here.. not 8bit!
    assert(np.max(depth_png) > 255)

    # np.float was removed in NumPy 1.24; use the explicit 64-bit dtype
    depth = depth_png.astype(np.float64) / 256.
    depth[depth_png == 0] = -1.
    return depth


import glob
import os
import shutil

# Gather KITTI validation depth ground truth (and matching RGB images) into
# the depth_selection layout, truncated to each target sequence length.
for TARGET_FRAMES in [50,100,150,200,250,300,350,400,450,500]:
    depth_dirs = glob.glob("/home/xingyu/monst3r/data/kitti/val/*/proj_depth/groundtruth/image_02")
    for dir in depth_dirs:
        drive = dir.split("/")[-4]  # e.g. 2011_09_26_drive_0002_sync
        # new depth / image dirs, named after the requested frame count
        new_depth_dir = f"./data/long_kitti_s1/depth_selection/val_selection_cropped/groundtruth_depth_gathered_{TARGET_FRAMES}/" + drive + "_02"
        new_image_dir = f"./data/long_kitti_s1/depth_selection/val_selection_cropped/image_gathered_{TARGET_FRAMES}/" + drive + "_02"
        os.makedirs(new_depth_dir, exist_ok=True)
        os.makedirs(new_image_dir, exist_ok=True)

        # get all depth files and compute how many frames we can actually take
        all_depth_files = sorted(glob.glob(dir + "/*.png"))
        actual_frames = min(len(all_depth_files), TARGET_FRAMES)
        print(f"directory {drive}: target frames {TARGET_FRAMES}, actual available frames {len(all_depth_files)}, actual processed frames {actual_frames}")

        # use the computed actual_frames consistently for the slice
        for depth_file in all_depth_files[:actual_frames]:
            shutil.copy(depth_file, new_depth_dir + "/" + depth_file.split("/")[-1])
            # derive the matching RGB path: the date prefix (first three
            # underscore-separated tokens) of the drive name replaces 'val'
            mid = "_".join(depth_file.split("/")[-5].split("_")[:3])
            image_file = depth_file.replace('val', mid).replace('proj_depth/groundtruth/image_02', 'image_02/data')
            print(image_file)
            # copy the image only if it exists on disk
            if os.path.exists(image_file):
                shutil.copy(image_file, new_image_dir + "/" + image_file.split("/")[-1])
            else:
                print("Image file does not exist: ", image_file)
```

## /datasets_preprocess/long_prepare_scannet.py

```py path="/datasets_preprocess/long_prepare_scannet.py" 
import glob
import os
import shutil
import numpy as np

# Build shortened ScanNet test sequences at several target lengths by
# taking every SAMPLE_INTERVAL-th frame from the original scans.
for TARGET_FRAMES in [50,90,100,150,200,300,400,500,600,700,800,900,1000]:

    SAMPLE_INTERVAL = 3  # sampling interval, take 1 frame every N frames original 3

    seq_list = sorted(os.listdir("/home/share/Dataset/3D_scene/ScanNet/scans_test/"))

    for seq in seq_list:
        # sort numerically by the frame index encoded in each filename
        img_pathes = sorted(glob.glob(f"/home/share/Dataset/3D_scene/ScanNet/scans_test/{seq}/color/*.jpg"), key=lambda x: int(os.path.basename(x).split('.')[0]))
        depth_pathes = sorted(glob.glob(f"/home/share/Dataset/3D_scene/ScanNet/scans_test/{seq}/depth/*.png"), key=lambda x: int(os.path.basename(x).split('.')[0]))
        pose_pathes = sorted(glob.glob(f"/home/share/Dataset/3D_scene/ScanNet/scans_test/{seq}/pose/*.txt"), key=lambda x: int(os.path.basename(x).split('.')[0]))

        # cap the target when the scan is too short to supply enough frames
        total_frames = min(len(img_pathes), len(depth_pathes), len(pose_pathes))
        actual_target_frames = min(TARGET_FRAMES, total_frames // SAMPLE_INTERVAL)

        print(f"{seq}: original frame count {total_frames}, target frames {TARGET_FRAMES}, actual frames {actual_target_frames}")

        # output directories are named after the *requested* frame count
        new_color_dir = f"./data/long_scannet_s{SAMPLE_INTERVAL}/{seq}/color_{TARGET_FRAMES}"
        new_depth_dir = f"./data/long_scannet_s{SAMPLE_INTERVAL}/{seq}/depth_{TARGET_FRAMES}"

        # subsample all three modalities with the same stride so they stay aligned
        new_img_pathes = img_pathes[:actual_target_frames*SAMPLE_INTERVAL:SAMPLE_INTERVAL]
        new_depth_pathes = depth_pathes[:actual_target_frames*SAMPLE_INTERVAL:SAMPLE_INTERVAL]
        new_pose_pathes = pose_pathes[:actual_target_frames*SAMPLE_INTERVAL:SAMPLE_INTERVAL]

        # start from clean output directories on every run
        if os.path.exists(new_color_dir):
            shutil.rmtree(new_color_dir)
        if os.path.exists(new_depth_dir):
            shutil.rmtree(new_depth_dir)

        os.makedirs(new_color_dir, exist_ok=True)
        os.makedirs(new_depth_dir, exist_ok=True)

        for i, (img_path, depth_path) in enumerate(zip(new_img_pathes, new_depth_pathes)):
            shutil.copy(img_path, f"{new_color_dir}/frame_{i:04d}.jpg")
            shutil.copy(depth_path, f"{new_depth_dir}/frame_{i:04d}.png")

        # write one flattened 4x4 pose matrix per line, named after the
        # requested frame count (np.loadtxt accepts a path directly)
        pose_new_path = f"./data/long_scannet_s{SAMPLE_INTERVAL}/{seq}/pose_{TARGET_FRAMES}.txt"
        with open(pose_new_path, 'w') as f:
            for pose_path in new_pose_pathes:
                pose = np.loadtxt(pose_path).reshape(-1)
                f.write(f"{' '.join(map(str, pose))}\n")

```

## /datasets_preprocess/long_prepare_tum.py

```py path="/datasets_preprocess/long_prepare_tum.py" 
import glob
import os
import shutil
import numpy as np

def read_file_list(filename):
    """
    Read a TUM-format trajectory file.

    File format:
    The file format is "stamp d1 d2 d3 ...", where stamp denotes the time
    stamp (to be matched) and "d1 d2 d3 .." is arbitrary data (e.g., a 3D
    position and 3D orientation) associated with this timestamp. Lines
    starting with '#' are comments and are skipped; commas and tabs are
    treated as field separators.

    Input:
    filename -- File name

    Output:
    dict -- dictionary of (stamp, data) tuples, mapping float timestamp to
            the list of remaining string fields
    """
    # context manager guarantees the file handle is closed
    with open(filename) as f:
        data = f.read()
    lines = data.replace(",", " ").replace("\t", " ").split("\n")
    entries = [
        [v.strip() for v in line.split(" ") if v.strip() != ""]
        for line in lines
        if len(line) > 0 and line[0] != "#"
    ]
    # keep only lines that carry data beyond the timestamp itself
    return dict((float(fields[0]), fields[1:]) for fields in entries if len(fields) > 1)

def associate(first_list, second_list, offset, max_difference):
    """
    Associate two dictionaries of (stamp, data). As the time stamps never
    match exactly, we aim to find the closest match for every input tuple.

    Input:
    first_list -- first dictionary of (stamp, data) tuples
    second_list -- second dictionary of (stamp, data) tuples
    offset -- time offset between both dictionaries (e.g., to model the delay between the sensors)
    max_difference -- search radius for candidate generation

    Output:
    matches -- list of matched timestamp pairs (stamp1, stamp2), sorted by
               stamp1. (Note: only the stamps are returned, not the data.)
    """
    # Mutable copies of the key sets so each stamp is consumed at most once.
    first_keys = set(first_list.keys())
    second_keys = set(second_list.keys())

    # All candidate pairs within the search radius, smallest difference first.
    potential_matches = [(abs(a - (b + offset)), a, b)
                         for a in first_keys
                         for b in second_keys
                         if abs(a - (b + offset)) < max_difference]
    potential_matches.sort()

    # Greedily accept candidates in order of increasing time difference.
    matches = []
    for _, a, b in potential_matches:
        if a in first_keys and b in second_keys:
            first_keys.remove(a)
            second_keys.remove(b)
            matches.append((a, b))

    matches.sort()
    return matches

# create new output directory
output_base_dir = "./data/long_tum_s1/"
SAMPLE_INTERVAL = 1  # sampling interval, take 1 frame every N frames
os.makedirs(output_base_dir, exist_ok=True)

dirs = sorted(glob.glob("/home/xingyu/monst3r/data/tum/*/"))

# For each target length, associate RGB frames with ground-truth poses and
# copy the first TARGET_FRAMES matched frames into a new sequence folder.
for TARGET_FRAMES in [50, 100, 150, 200, 300, 400, 500, 600, 700, 800, 900, 1000]:
    for dir in dirs:
        frames = []
        gt = []
        first_list = read_file_list(dir + 'rgb.txt')
        second_list = read_file_list(dir + 'groundtruth.txt')
        matches = associate(first_list, second_list, 0.0, 0.02)

        # keep only frames that have an associated ground-truth pose
        for a, b in matches:
            frames.append(dir + first_list[a][0])
            gt.append([b] + second_list[b])

        print(f"process {dir} with {len(frames)} frames")
        # subsample then truncate to the target length
        frames = frames[::SAMPLE_INTERVAL][:TARGET_FRAMES]

        # get original directory name as subdirectory name
        dir_name = os.path.basename(os.path.dirname(dir))
        new_dir = output_base_dir + dir_name + f'/rgb_{TARGET_FRAMES}/'

        # create the output directory once (hoisted out of the copy loop) so
        # the ground-truth write below always has an existing parent, even
        # when no frame matched
        os.makedirs(new_dir, exist_ok=True)
        for frame in frames:
            shutil.copy(frame, new_dir)

        # write the matching subsampled poses, one "stamp tx ty tz qx qy qz qw" per line
        gt_subset = gt[::SAMPLE_INTERVAL][:TARGET_FRAMES]
        gt_output_file = output_base_dir + dir_name + f'/groundtruth_{TARGET_FRAMES}.txt'
        with open(gt_output_file, 'w') as f:
            for pose in gt_subset:
                f.write(f"{' '.join(map(str, pose))}\n")

        print(f"get {dir_name} with {len(frames)} frames")

```

## /datasets_preprocess/path_to_root.py

```py path="/datasets_preprocess/path_to_root.py" 
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# DUSt3R repo root import
# --------------------------------------------------------

import sys
import os.path as path
HERE_PATH = path.normpath(path.dirname(__file__))
DUST3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../'))
# workaround for sibling import
sys.path.insert(0, DUST3R_REPO_PATH)

```

## /demo.py

```py path="/demo.py" 
#!/usr/bin/env python3
"""
3D Point Cloud Inference and Visualization Script

This script performs inference using the ARCroco3DStereo model and visualizes the
resulting 3D point clouds with the PointCloudViewer. Use the command-line arguments
to adjust parameters such as the model checkpoint path, image sequence directory,
image size, device, etc.

Usage:
    python demo.py [--model_path MODEL_PATH] [--seq_path SEQ_PATH] [--size IMG_SIZE]
                            [--device DEVICE] [--vis_threshold VIS_THRESHOLD] [--output_dir OUT_DIR]

Example:
    python demo.py --model_path src/cut3r_512_dpt_4_64.pth \
        --seq_path examples/001 --device cuda --size 512
"""

import os
import numpy as np
import torch
import time
import glob
import random
import cv2
import argparse
import tempfile
import shutil
from copy import deepcopy
from add_ckpt_path import add_path_to_dust3r
import imageio.v2 as iio
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from sklearn.decomposition import PCA
import datetime
from tqdm import tqdm
from skimage.filters import threshold_otsu, threshold_multiotsu
from einops import rearrange

# Set random seed for reproducibility.
random.seed(42)

framerate = 30

def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace: parsed options controlling the checkpoint path,
        input sequence, device, image size, visualization, and the TTT3R
        state-update/reset behavior.
    """
    parser = argparse.ArgumentParser(
        description="Run 3D point cloud inference and visualization using ARCroco3DStereo."
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="src/cut3r_512_dpt_4_64.pth",
        help="Path to the pretrained model checkpoint.",
    )
    parser.add_argument(
        "--seq_path",
        type=str,
        default="",
        help="Path to the directory containing the image sequence.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to run inference on (e.g., 'cuda' or 'cpu').",
    )
    parser.add_argument(
        "--size",
        type=int,
        # use an int default to match type=int (was the string "512")
        default=512,
        help="Shape that input images will be rescaled to; if using 224+linear model, choose 224 otherwise 512",
    )
    parser.add_argument(
        "--vis_threshold",
        type=float,
        default=1.5,
        help="Visualization threshold for the point cloud viewer. Ranging from 1 to INF",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./demo_tmp",
        help="value for tempfile.tempdir",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="port for the point cloud viewer",
    )
    parser.add_argument(
        "--model_update_type",
        type=str,
        default="cut3r",
        help="model update type: cut3r or ttt3r",
    )
    parser.add_argument(
        "--frame_interval",
        type=int,
        default=1,
        help="Frame interval for video processing (e.g., 1 means every frame, 2 means every other frame)",
    )
    parser.add_argument(
        "--reset_interval",
        type=int,
        default=1000000,
        help="Only used for demo, reset state for extremely long sequence, chunks are aligned via global camera poses",
    )
    parser.add_argument(
        "--downsample_factor",
        type=int,
        default=1,
        help="Downsample factor for the point cloud viewer",
    )
    return parser.parse_args()


def prepare_input(
    img_paths, img_mask, size, raymaps=None, raymap_mask=None, revisit=1, update=True, reset_interval=10000
):
    """
    Prepare input views for inference from a list of image paths.

    Args:
        img_paths (list): List of image file paths.
        img_mask (list of bool): Flags indicating valid images.
        size (int): Target image size.
        raymaps (list, optional): List of ray maps.
        raymap_mask (list, optional): Flags indicating valid ray maps.
        revisit (int): How many times to revisit each view.
        update (bool): Whether to update the state on revisits.
        reset_interval (int): Every `reset_interval`-th view carries
            ``reset=True``; that view is then duplicated with ``reset=False``
            so consecutive chunks overlap by one view and can be re-aligned
            after the state reset.

    Returns:
        list: A list of view dictionaries.
    """
    # Import image loader (delayed import needed after adding ckpt path).
    from src.dust3r.utils.image import load_images

    images = load_images(img_paths, size=size)
    views = []

    if raymaps is None and raymap_mask is None:
        # Only images are provided.
        for i in range(len(images)):
            view = {
                "img": images[i]["img"],
                # No ray map for pure-image input: NaN placeholder of the
                # expected (batch, 6, H, W) shape.
                "ray_map": torch.full(
                    (
                        images[i]["img"].shape[0],
                        6,
                        images[i]["img"].shape[-2],
                        images[i]["img"].shape[-1],
                    ),
                    torch.nan,
                ),
                "true_shape": torch.from_numpy(images[i]["true_shape"]),
                "idx": i,
                "instance": str(i),
                "camera_pose": torch.from_numpy(np.eye(4, dtype=np.float32)).unsqueeze(
                    0
                ),
                "img_mask": torch.tensor(True).unsqueeze(0),
                "ray_mask": torch.tensor(False).unsqueeze(0),
                "update": torch.tensor(True).unsqueeze(0),
                # Request a state reset on every reset_interval-th view.
                "reset": torch.tensor((i+1) % reset_interval == 0).unsqueeze(0),
            }
            views.append(view)
            # Duplicate the reset-boundary view (with reset=False) so the
            # next chunk starts from an overlapping view.
            if (i+1) % reset_interval == 0:
                overlap_view = deepcopy(view)
                overlap_view["reset"] = torch.tensor(False).unsqueeze(0)
                views.append(overlap_view)
    else:
        # Combine images and raymaps.
        num_views = len(images) + len(raymaps)
        assert len(img_mask) == len(raymap_mask) == num_views
        assert sum(img_mask) == len(images) and sum(raymap_mask) == len(raymaps)

        # j / k walk through images / raymaps as their masks select them.
        j = 0
        k = 0
        for i in range(num_views):
            view = {
                "img": (
                    images[j]["img"]
                    if img_mask[i]
                    else torch.full_like(images[0]["img"], torch.nan)
                ),
                "ray_map": (
                    raymaps[k]
                    if raymap_mask[i]
                    else torch.full_like(raymaps[0], torch.nan)
                ),
                "true_shape": (
                    torch.from_numpy(images[j]["true_shape"])
                    if img_mask[i]
                    else torch.from_numpy(np.int32([raymaps[k].shape[1:-1][::-1]]))
                ),
                "idx": i,
                "instance": str(i),
                "camera_pose": torch.from_numpy(np.eye(4, dtype=np.float32)).unsqueeze(
                    0
                ),
                "img_mask": torch.tensor(img_mask[i]).unsqueeze(0),
                "ray_mask": torch.tensor(raymap_mask[i]).unsqueeze(0),
                "update": torch.tensor(img_mask[i]).unsqueeze(0),
                "reset": torch.tensor((i+1) % reset_interval == 0).unsqueeze(0),
            }
            if img_mask[i]:
                j += 1
            if raymap_mask[i]:
                k += 1
            views.append(view)
            # Same overlap duplication as in the image-only branch.
            if (i+1) % reset_interval == 0:
                overlap_view = deepcopy(view)
                overlap_view["reset"] = torch.tensor(False).unsqueeze(0)
                views.append(overlap_view)
        assert j == len(images) and k == len(raymaps)

    if revisit > 1:
        # Repeat the whole sequence `revisit` times; on later passes the
        # state update can be disabled via update=False.
        new_views = []
        for r in range(revisit):
            for i, view in enumerate(views):
                new_view = deepcopy(view)
                new_view["idx"] = r * len(views) + i
                new_view["instance"] = str(r * len(views) + i)
                if r > 0 and not update:
                    new_view["update"] = torch.tensor(False).unsqueeze(0)
                new_views.append(new_view)
        return new_views

    return views


def prepare_output(outputs, outdir, revisit=1, use_pose=True):
    """
    Process inference outputs to generate point clouds and camera parameters for visualization.

    Side effect: wipes and re-creates the "depth", "conf", "color" and "camera"
    sub-directories of `outdir`, saving one file per frame in each.

    Args:
        outputs (dict): Inference outputs with "pred" and "views" lists.
        outdir (str): Directory where per-frame artifacts are written.
        revisit (int): Number of revisits per view; only the last full pass is kept.
        use_pose (bool): Whether to transform points using the recovered camera pose.

    Returns:
        tuple: (points, colors, confidence, camera parameters dictionary)
    """
    from src.dust3r.utils.camera import pose_encoding_to_camera
    from src.dust3r.post_process import estimate_focal_knowing_depth
    from src.dust3r.utils.geometry import geotrf, matrix_cumprod

    # Only keep the outputs corresponding to one full pass.
    valid_length = len(outputs["pred"]) // revisit
    outputs["pred"] = outputs["pred"][-valid_length:]
    outputs["views"] = outputs["views"][-valid_length:]

    # Drop the duplicated "overlap" frames inserted after each reset during
    # input preparation: a frame immediately following a reset=True view is
    # such a duplicate.
    reset_mask = torch.cat([view["reset"] for view in outputs["views"]], 0)
    shifted_reset_mask = torch.cat(
        [torch.tensor(False).unsqueeze(0), reset_mask[:-1]], dim=0
    )

    outputs["pred"] = [
        pred for pred, mask in zip(outputs["pred"], shifted_reset_mask) if not mask
    ]
    outputs["views"] = [
        view for view, mask in zip(outputs["views"], shifted_reset_mask) if not mask
    ]
    reset_mask = reset_mask[~shifted_reset_mask]

    pts3ds_self_ls = [output["pts3d_in_self_view"].cpu() for output in outputs["pred"]]
    pts3ds_other = [output["pts3d_in_other_view"].cpu() for output in outputs["pred"]]
    conf_self = [output["conf_self"].cpu() for output in outputs["pred"]]
    conf_other = [output["conf"].cpu() for output in outputs["pred"]]
    pts3ds_self = torch.cat(pts3ds_self_ls, 0)

    # Recover per-frame camera poses from their encodings (each is 1x4x4).
    pr_poses = [
        pose_encoding_to_camera(pred["camera_pose"].clone()).cpu()
        for pred in outputs["pred"]
    ]

    if reset_mask.any():
        # NOTE(review): poses predicted after a reset appear to be expressed
        # relative to the reset frame; chain the reset-frame poses with a
        # cumulative matrix product and re-apply them so every pose ends up in
        # the first frame's coordinate system — confirm against matrix_cumprod.
        pr_poses = torch.cat(pr_poses, 0)
        identity = torch.eye(4, device=pr_poses.device)
        reset_poses = torch.where(
            reset_mask.unsqueeze(-1).unsqueeze(-1), pr_poses, identity
        )
        cumulative_bases = matrix_cumprod(reset_poses)
        shifted_bases = torch.cat([identity.unsqueeze(0), cumulative_bases[:-1]], dim=0)
        pr_poses = torch.einsum('bij,bjk->bik', shifted_bases, pr_poses)
        # Back to a list of (1, 4, 4) tensors.
        pr_poses = list(pr_poses.unsqueeze(1).unbind(0))

    R_c2w = torch.cat([pr_pose[:, :3, :3] for pr_pose in pr_poses], 0)
    t_c2w = torch.cat([pr_pose[:, :3, 3] for pr_pose in pr_poses], 0)

    if use_pose:
        # Express the self-view points in the world frame via each camera pose.
        transformed_pts3ds_other = []
        for pose, pself in zip(pr_poses, pts3ds_self):
            transformed_pts3ds_other.append(geotrf(pose, pself.unsqueeze(0)))
        pts3ds_other = transformed_pts3ds_other
        conf_other = conf_self

    # Estimate focal length from depth, principal point at the image center.
    B, H, W, _ = pts3ds_self.shape
    pp = torch.tensor([W // 2, H // 2], device=pts3ds_self.device).float().repeat(B, 1)
    focal = estimate_focal_knowing_depth(pts3ds_self, pp, focal_mode="weiszfeld")

    # Map images from [-1, 1] back to [0, 1] RGB, channel-last.
    colors = [
        0.5 * (output["img"].permute(0, 2, 3, 1) + 1.0) for output in outputs["views"]
    ]

    cam_dict = {
        "focal": focal.cpu().numpy(),
        "pp": pp.cpu().numpy(),
        "R": R_c2w.cpu().numpy(),
        "t": t_c2w.cpu().numpy(),
    }

    pts3ds_self_tosave = pts3ds_self  # B, H, W, 3
    depths_tosave = pts3ds_self_tosave[..., 2]
    pts3ds_other_tosave = torch.cat(pts3ds_other)  # B, H, W, 3
    conf_self_tosave = torch.cat(conf_self)  # B, H, W
    conf_other_tosave = torch.cat(conf_other)  # B, H, W
    # Re-use the already-computed colors instead of recomputing them.
    colors_tosave = torch.cat([c.cpu() for c in colors])  # [B, H, W, 3]
    cam2world_tosave = torch.cat(pr_poses)  # B, 4, 4
    intrinsics_tosave = (
        torch.eye(3).unsqueeze(0).repeat(cam2world_tosave.shape[0], 1, 1)
    )  # B, 3, 3
    intrinsics_tosave[:, 0, 0] = focal.detach().cpu()
    intrinsics_tosave[:, 1, 1] = focal.detach().cpu()
    intrinsics_tosave[:, 0, 2] = pp[:, 0]
    intrinsics_tosave[:, 1, 2] = pp[:, 1]

    # Re-create the per-frame output folders from scratch.
    for sub in ("depth", "conf", "color", "camera"):
        subdir = os.path.join(outdir, sub)
        if os.path.exists(subdir):
            shutil.rmtree(subdir)
        os.makedirs(subdir, exist_ok=True)

    for f_id in range(len(pts3ds_self)):
        depth = depths_tosave[f_id].cpu().numpy()
        conf = conf_self_tosave[f_id].cpu().numpy()
        color = colors_tosave[f_id].cpu().numpy()
        c2w = cam2world_tosave[f_id].cpu().numpy()
        intrins = intrinsics_tosave[f_id].cpu().numpy()
        np.save(os.path.join(outdir, "depth", f"{f_id:06d}.npy"), depth)
        np.save(os.path.join(outdir, "conf", f"{f_id:06d}.npy"), conf)
        iio.imwrite(
            os.path.join(outdir, "color", f"{f_id:06d}.png"),
            (color * 255).astype(np.uint8),
        )
        np.savez(
            os.path.join(outdir, "camera", f"{f_id:06d}.npz"),
            pose=c2w,
            intrinsics=intrins,
        )

    return pts3ds_other, colors, conf_other, cam_dict

def parse_seq_path(p, frame_interval=1):
    """Resolve *p* (image directory or video file) into a list of frame paths.

    Returns ``(img_paths, tmpdirname)``: ``tmpdirname`` is a temporary
    directory holding extracted video frames, or ``None`` for image
    directories. Also updates the module-level ``framerate`` as a side effect.
    """
    global framerate

    if os.path.isdir(p):
        entries = sorted(glob.glob(f"{p}/*"))
        valid_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp'}
        img_paths = [
            entry for entry in entries
            if os.path.splitext(entry.lower())[1] in valid_exts
        ]

        if not img_paths:
            raise ValueError(f"No image files found in directory {p}")

        if frame_interval > 1:
            img_paths = img_paths[::frame_interval]
            print(f" - Image sequence: Total images: {len(entries)}, "
                  f"Frame interval: {frame_interval}, Images to process: {len(img_paths)}")

        # Image folders carry no timing information; assume a 30 fps source.
        framerate = 30.0 / frame_interval
        return img_paths, None

    # Video input: sample every `frame_interval`-th frame into a temp dir.
    cap = cv2.VideoCapture(p)
    if not cap.isOpened():
        raise ValueError(f"Error opening video file {p}")
    video_fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if video_fps == 0:
        cap.release()
        raise ValueError(f"Error: Video FPS is 0 for {p}")

    framerate = video_fps / frame_interval
    frame_indices = list(range(0, total_frames, frame_interval))
    print(
        f" - Video FPS: {video_fps}, Frame Interval: {frame_interval}, Total Frames to Read: {len(frame_indices)}, Processed Framerate: {framerate}"
    )

    tmpdirname = tempfile.mkdtemp()
    img_paths = []
    for idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ok, frame = cap.read()
        if not ok:
            break  # stop at the first unreadable frame
        out_path = os.path.join(tmpdirname, f"frame_{idx}.jpg")
        cv2.imwrite(out_path, frame)
        img_paths.append(out_path)
    cap.release()
    return img_paths, tmpdirname


def run_inference(args):
    """
    Execute the full inference and visualization pipeline.

    Args:
        args: Parsed command-line arguments.
    """
    # Pick the compute device, falling back to CPU when CUDA is unavailable.
    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available. Switching to CPU.")
        device = "cpu"

    # Register the checkpoint path (required for model imports in the dust3r package).
    add_path_to_dust3r(args.model_path)

    # These imports only resolve once the ckpt path has been added.
    from src.dust3r.inference import inference, inference_recurrent, inference_recurrent_lighter
    from src.dust3r.model import ARCroco3DStereo
    from viser_utils import PointCloudViewer

    # Collect the input frame paths (directory of images or a video file).
    frame_paths, tmpdirname = parse_seq_path(args.seq_path, args.frame_interval)
    if not frame_paths:
        print(f"No images found in {args.seq_path}. Please verify the path.")
        return

    print(f"Found {len(frame_paths)} images in {args.seq_path}.")

    # Build the model input views: every frame is a real image (no raymaps).
    print("Preparing input views...")
    views = prepare_input(
        img_paths=frame_paths,
        img_mask=[True] * len(frame_paths),
        size=args.size,
        revisit=1,
        update=True,
        reset_interval=args.reset_interval,
    )
    if tmpdirname is not None:
        shutil.rmtree(tmpdirname)

    # Load the pretrained model.
    print(f"Loading model from {args.model_path}...")
    model = ARCroco3DStereo.from_pretrained(args.model_path).to(device)
    model.config.model_update_type = args.model_update_type
    model.eval()

    # Run inference and report throughput.
    print("Running inference...")
    start_time = time.time()
    outputs, state_args = inference_recurrent_lighter(views, model, device)
    total_time = time.time() - start_time
    per_frame_time = total_time / len(views)
    FPS_num = 1 / per_frame_time
    print(
        f"Inference completed in {total_time:.2f} seconds (average {per_frame_time:.2f} s per frame), FPS: {FPS_num:.2f}."
    )

    # Post-process predictions into point clouds + cameras and dump to disk.
    print("Preparing output for visualization...")
    pts3ds_other, colors, conf, cam_dict = prepare_output(
        outputs, args.output_dir, 1, True
    )

    # Move everything to numpy for the viewer.
    pts3ds_to_vis = [p.cpu().numpy() for p in pts3ds_other]
    colors_to_vis = [c.cpu().numpy() for c in colors]

    # Create and run the point cloud viewer.
    print("Launching point cloud viewer...")
    viewer = PointCloudViewer(
        model,
        state_args,
        pts3ds_to_vis,
        colors_to_vis,
        conf,
        cam_dict,
        device=device,
        edge_color_list=[None] * len(pts3ds_to_vis),
        show_camera=True,
        vis_threshold=args.vis_threshold,
        size=args.size,
        port=args.port,
        downsample_factor=args.downsample_factor,
    )
    viewer.run()


def main():
    """Entry point: parse CLI arguments and launch inference."""
    args = parse_args()
    if not args.seq_path:
        # No input was supplied on the command line.
        # (fixed typo: "iteractively" -> "interactively")
        print(
            "No inputs found! Please use our gradio demo if you would like to interactively upload inputs."
        )
        return
    run_inference(args)


# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()

```

## /eval/eval.md

# Evaluation

## Datasets
Please follow [MonST3R](https://github.com/Junyi42/monst3r/blob/main/data/evaluation_script.md) and [Spann3R](https://github.com/HengyiWang/spann3r/blob/main/docs/data_preprocess.md) to download **ScanNet**, **TUM-dynamics**, **Sintel**, **Bonn**, **KITTI** and **7scenes** datasets.

### ScanNet
To prepare the **ScanNet** dataset, execute:
```bash
python datasets_preprocess/long_prepare_scannet.py # You may need to change the path of the dataset
```

### TUM-dynamics
To prepare the **TUM-dynamics** dataset, execute:
```bash
python datasets_preprocess/long_prepare_tum.py # You may need to change the path of the dataset
```

### Bonn
To prepare the **Bonn** dataset, execute:
```bash
python datasets_preprocess/long_prepare_bonn.py # You may need to change the path of the dataset
```

### KITTI
To prepare the **KITTI** dataset, execute:
```bash
python datasets_preprocess/long_prepare_kitti.py # You may need to change the path of the dataset
```

# Evaluation Scripts

Results will be saved in `eval_results/*`.

### Camera Pose Estimation

```bash
CUDA_VISIBLE_DEVICES=6,7 bash eval/relpose/run_scannet.sh # You may need to change [--num_processes] to the number of your gpus and choose sequence length in datasets=('scannet_s3_1000')
CUDA_VISIBLE_DEVICES=6,7 bash eval/relpose/run_tum.sh # You may need to change [--num_processes] to the number of your gpus and choose sequence length in datasets=('tum_s1_1000')
CUDA_VISIBLE_DEVICES=6,7 bash eval/relpose/run_sintel.sh # You may need to change [--num_processes] to the number of your gpus
```

### Video Depth

```bash
CUDA_VISIBLE_DEVICES=5 bash eval/video_depth/run_kitti.sh # You may need to change [--num_processes] to the number of your gpus and choose sequence length in datasets=('kitti_s1_500')
CUDA_VISIBLE_DEVICES=5 bash eval/video_depth/run_bonn.sh # You may need to change [--num_processes] to the number of your gpus and choose sequence length in datasets=('bonn_s1_500')
CUDA_VISIBLE_DEVICES=5 bash eval/video_depth/run_sintel.sh # You may need to change [--num_processes] to the number of your gpus
```



### 3D Reconstruction

```bash
CUDA_VISIBLE_DEVICES=5 bash eval/mv_recon/run.sh # You may need to change [--num_processes] to the number of your gpus and choose sequence length in max_frames
```



## /eval/mv_recon/base.py

```py path="/eval/mv_recon/base.py" 
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# base class for implementing datasets
# --------------------------------------------------------
import PIL
import numpy as np
import torch

from eval.mv_recon.dataset_utils.transforms import ImgNorm
from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates
import eval.mv_recon.dataset_utils.cropping as cropping


class BaseStereoViewDataset:
    """Define all basic options.

    Usage:
        class MyDataset (BaseStereoViewDataset):
            def _get_views(self, idx, rng):
                # overload here
                views = []
                views.append(dict(img=, ...))
                return views
    """

    def __init__(
        self,
        *,  # only keyword arguments
        split=None,
        resolution=None,  # square_size or (width, height) or list of [(width,height), ...]
        transform=ImgNorm,
        aug_crop=False,
        seed=None,
    ):
        self.num_views = 2
        self.split = split
        self._set_resolutions(resolution)

        # BUGFIX: resolve string-valued transforms BEFORE storing them. The
        # previous code assigned self.transform first, so the eval() result
        # was discarded and self.transform stayed a plain string.
        if isinstance(transform, str):
            # NOTE: eval() on a config-provided string -- only pass trusted values.
            transform = eval(transform)
        self.transform = transform

        self.aug_crop = aug_crop
        self.seed = seed

    def __len__(self):
        # One dataset item per scene (self.scenes is set by subclasses).
        return len(self.scenes)

    def get_stats(self):
        """Short human-readable dataset summary."""
        return f"{len(self)} pairs"

    def __repr__(self):
        resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
        return (
            f"""{type(self).__name__}({self.get_stats()},
            {self.split=},
            {self.seed=},
            resolutions={resolutions_str},
            {self.transform=})""".replace(
                "self.", ""
            )
            .replace("\n", "")
            .replace("   ", "")
        )

    def _get_views(self, idx, resolution, rng):
        """Overload in subclasses: return the list of raw views for item *idx*."""
        raise NotImplementedError()

    def __getitem__(self, idx):
        """Load views for item *idx*, compute pts3d/valid_mask and model flags."""
        if isinstance(idx, tuple):
            # the idx is specifying the aspect-ratio
            idx, ar_idx = idx
        else:
            assert len(self._resolutions) == 1
            ar_idx = 0

        # set-up the rng
        if self.seed:  # reseed for each __getitem__
            self._rng = np.random.default_rng(seed=self.seed + idx)
        elif not hasattr(self, "_rng"):
            seed = torch.initial_seed()  # this is different for each dataloader process
            self._rng = np.random.default_rng(seed=seed)

        # over-loaded code
        resolution = self._resolutions[
            ar_idx
        ]  # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
        views = self._get_views(idx, resolution, self._rng)

        # check data-types
        for v, view in enumerate(views):
            assert (
                "pts3d" not in view
            ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
            view["idx"] = v

            # encode the image
            width, height = view["img"].size
            view["true_shape"] = np.int32((height, width))
            view["img"] = self.transform(view["img"])

            assert "camera_intrinsics" in view
            if "camera_pose" not in view:
                view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32)
            else:
                assert np.isfinite(
                    view["camera_pose"]
                ).all(), f"NaN in camera pose for view {view_name(view)}"
            assert "pts3d" not in view
            assert "valid_mask" not in view
            assert np.isfinite(
                view["depthmap"]
            ).all(), f"NaN in depthmap for view {view_name(view)}"
            pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)

            view["pts3d"] = pts3d
            view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)

            # check all datatypes
            for key, val in view.items():
                res, err_msg = is_good_type(key, val)
                assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
            K = view["camera_intrinsics"]
            # Flags consumed by the streaming model: real image, no raymap.
            view["img_mask"] = True
            view["ray_mask"] = False
            view["ray_map"] = torch.full(
                (6, view["img"].shape[-2], view["img"].shape[-1]), torch.nan
            )
            view["update"] = True
            view["reset"] = False

        # last thing done!
        for view in views:
            # transpose to make sure all views are the same size
            transpose_to_landscape(view)
            # this allows to check whether the RNG is is the same state each time
            view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
        return views

    def _set_resolutions(self, resolutions):
        """Set the resolution(s) of the dataset.
        Params:
            - resolutions: int or tuple or list of tuples
        """
        assert resolutions is not None, "undefined resolution"

        if not isinstance(resolutions, list):
            resolutions = [resolutions]

        self._resolutions = []
        for resolution in resolutions:
            if isinstance(resolution, int):
                width = height = resolution
            else:
                width, height = resolution
            assert isinstance(
                width, int
            ), f"Bad type for {width=} {type(width)=}, should be int"
            assert isinstance(
                height, int
            ), f"Bad type for {height=} {type(height)=}, should be int"
            assert width >= height
            self._resolutions.append((width, height))

    def _crop_resize_if_necessary(
        self, image, depthmap, intrinsics, resolution, rng=None, info=None
    ):
        """Center-crop around the principal point, then Lanczos-rescale and crop
        so the output matches *resolution* (possibly transposed for portrait
        inputs). Returns (image, depthmap, intrinsics).
        """
        if not isinstance(image, PIL.Image.Image):
            image = PIL.Image.fromarray(image)

        # downscale with lanczos interpolation so that image.size == resolution
        # cropping centered on the principal point
        W, H = image.size
        cx, cy = intrinsics[:2, 2].round().astype(int)

        # calculate min distance to margin
        min_margin_x = min(cx, W - cx)
        min_margin_y = min(cy, H - cy)
        assert min_margin_x > W / 5, f"Bad principal point in view={info}"
        assert min_margin_y > H / 5, f"Bad principal point in view={info}"

        ## Center crop
        # Crop on the principal point, make it always centered
        # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
        l, t = cx - min_margin_x, cy - min_margin_y
        r, b = cx + min_margin_x, cy + min_margin_y
        crop_bbox = (l, t, r, b)

        image, depthmap, intrinsics = cropping.crop_image_depthmap(
            image, depthmap, intrinsics, crop_bbox
        )

        # transpose the resolution if necessary
        W, H = image.size  # new size
        assert resolution[0] >= resolution[1]
        if H > 1.1 * W:
            # image is portrait mode
            resolution = resolution[::-1]
        elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
            # image is square, so we chose (portrait, landscape) randomly
            if rng.integers(2):
                resolution = resolution[::-1]

        # high-quality Lanczos down-scaling
        target_resolution = np.array(resolution)
        image, depthmap, intrinsics = cropping.rescale_image_depthmap(
            image, depthmap, intrinsics, target_resolution
        )

        # actual cropping (if necessary) with bilinear interpolation
        intrinsics2 = cropping.camera_matrix_of_crop(
            intrinsics, image.size, resolution, offset_factor=0.5
        )
        crop_bbox = cropping.bbox_from_intrinsics_in_out(
            intrinsics, intrinsics2, resolution
        )
        image, depthmap, intrinsics = cropping.crop_image_depthmap(
            image, depthmap, intrinsics, crop_bbox
        )
        return image, depthmap, intrinsics


def is_good_type(key, v):
    """Check that *v* has a type/dtype the pipeline can handle.

    Returns a ``(is_good, err_msg)`` tuple; ``err_msg`` is ``None`` when good.
    """
    if isinstance(v, (str, int, tuple)):
        return True, None
    allowed_dtypes = (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8)
    if v.dtype in allowed_dtypes:
        return True, None
    return False, f"bad {v.dtype=}"


def view_name(view, batch_index=None):
    """Return a "dataset/label/instance" identifier for *view*.

    When *batch_index* selects a single batch element (anything other than
    None or slice(None)), each field is indexed first.
    """
    take_all = batch_index in (None, slice(None))

    def sel(x):
        return x if take_all else x[batch_index]

    db = sel(view["dataset"])
    label = sel(view["label"])
    instance = sel(view["instance"])
    return f"{db}/{label}/{instance}"


def transpose_to_landscape(view):
    """In-place: if the view is portrait (W < H), swap its spatial axes so all
    stored arrays become landscape; intrinsics rows are permuted so the x/y
    pixel axes follow."""
    height, width = view["true_shape"]
    if width >= height:
        return  # already landscape: nothing to do

    # rectify portrait to landscape
    assert view["img"].shape == (3, height, width)
    view["img"] = view["img"].swapaxes(1, 2)

    spatial_shapes = {
        "valid_mask": (height, width),
        "depthmap": (height, width),
        "pts3d": (height, width, 3),
    }
    for key, expected_shape in spatial_shapes.items():
        assert view[key].shape == expected_shape
        view[key] = view[key].swapaxes(0, 1)

    # transpose x and y pixels in the intrinsics (swap the first two rows)
    view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]]

```

## /eval/mv_recon/criterion.py

```py path="/eval/mv_recon/criterion.py" 
import torch
import torch.nn as nn
from copy import copy, deepcopy
from dust3r.utils.misc import invalid_to_zeros, invalid_to_nans
from dust3r.utils.geometry import inv, geotrf, depthmap_to_pts3d
from dust3r.utils.camera import pose_encoding_to_camera


class BaseCriterion(nn.Module):
    """Root class for elementwise criteria; stores the reduction mode
    ("mean", "sum" or "none")."""

    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction


class Criterion(nn.Module):
    """Wrap a BaseCriterion so composite losses can clone and re-reduce it."""

    def __init__(self, criterion=None):
        super().__init__()
        assert isinstance(
            criterion, BaseCriterion
        ), f"{criterion} is not a proper criterion!"
        # keep a private copy so changing our reduction never affects the caller's
        self.criterion = copy(criterion)

    def get_name(self):
        """Human-readable name including the wrapped criterion."""
        return f"{type(self).__name__}({self.criterion})"

    def with_reduction(self, mode="none"):
        """Return a deep copy whose entire loss chain uses *mode* reduction."""
        head = link = deepcopy(self)
        while link is not None:
            assert isinstance(link, Criterion)
            link.criterion.reduction = mode  # make it return the loss for each sample
            link = link._loss2  # next link in the MultiLoss chain (if any)
        return head


class MultiLoss(nn.Module):
    """Easily combinable losses (also keep track of individual loss values):
        loss = MyLoss1() + 0.1*MyLoss2()
    Usage:
        Inherit from this class and override get_name() and compute_loss()
    """

    def __init__(self):
        super().__init__()
        self._alpha = 1       # multiplicative weight of this term
        self._loss2 = None    # next term in the summed chain, if any

    def compute_loss(self, *args, **kwargs):
        raise NotImplementedError()

    def get_name(self):
        raise NotImplementedError()

    def __mul__(self, alpha):
        assert isinstance(alpha, (int, float))
        weighted = copy(self)
        weighted._alpha = alpha
        return weighted

    __rmul__ = __mul__  # scalar * loss behaves like loss * scalar

    def __add__(self, loss2):
        assert isinstance(loss2, MultiLoss)
        # append loss2 at the end of (a shallow copy of) this chain
        head = tail = copy(self)
        while tail._loss2 is not None:
            tail = tail._loss2
        tail._loss2 = loss2
        return head

    def __repr__(self):
        name = self.get_name()
        if self._alpha != 1:
            name = f"{self._alpha:g}*{name}"
        return f"{name} + {self._loss2}" if self._loss2 else name

    def forward(self, *args, **kwargs):
        loss = self.compute_loss(*args, **kwargs)
        if isinstance(loss, tuple):
            loss, details = loss
        elif loss.ndim == 0:
            details = {self.get_name(): float(loss)}
        else:
            details = {}
        loss = loss * self._alpha

        if self._loss2:
            extra_loss, extra_details = self._loss2(*args, **kwargs)
            loss = loss + extra_loss
            details |= extra_details

        return loss, details


class LLoss(BaseCriterion):
    """L-norm loss over point coordinates (last dimension of size 1..3);
    subclasses implement distance()."""

    def forward(self, a, b):
        assert (
            a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3
        ), f"Bad shape = {a.shape}"
        dist = self.distance(a, b)

        if self.reduction == "mean":
            # guard against NaN from .mean() on an empty tensor
            return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
        if self.reduction == "sum":
            return dist.sum()
        if self.reduction == "none":
            return dist
        raise ValueError(f"bad {self.reduction=} mode")

    def distance(self, a, b):
        raise NotImplementedError()


class L21Loss(LLoss):
    """Per-point Euclidean (L2) distance between 3D point maps."""

    def distance(self, a, b):
        # L2 norm over the coordinate dimension
        return (a - b).norm(dim=-1)


L21 = L21Loss()  # default instance used by the regression losses


def get_pred_pts3d(gt, pred, use_pose=False):
    """Extract predicted 3D points from *pred*, optionally mapped to the world frame.

    The source of the points depends on which keys *pred* carries: rebuilt from
    depth + pseudo focal, taken directly from "pts3d", or returned as-is from
    "pts3d_in_other_view" (already transformed).

    Args:
        gt (dict): ground-truth view; only "camera_intrinsics" is read here.
        pred (dict): network outputs for the same view.
        use_pose (bool): transform self-view points with the predicted camera pose.

    Returns:
        Tensor of predicted 3D points.
    """
    if "depth" in pred and "pseudo_focal" in pred:
        try:
            pp = gt["camera_intrinsics"][..., :2, 2]
        except KeyError:
            pp = None  # no intrinsics: let the helper use its default principal point
        pts3d = depthmap_to_pts3d(**pred, pp=pp)

    elif "pts3d" in pred:
        # pts3d from my camera
        pts3d = pred["pts3d"]

    elif "pts3d_in_other_view" in pred:
        # pts3d from the other camera, already transformed
        assert use_pose is True
        return pred["pts3d_in_other_view"]  # return!

    # NOTE(review): if none of the branches above matched and use_pose is
    # False, `pts3d` is unbound here and this raises a NameError — callers are
    # expected to pass a prediction carrying one of the keys above.
    if use_pose:
        camera_pose = pred.get("camera_pose")
        pts3d = pred.get("pts3d_in_self_view")
        assert camera_pose is not None
        assert pts3d is not None
        pts3d = geotrf(pose_encoding_to_camera(camera_pose), pts3d)

    return pts3d


def Sum(losses, masks, conf=None):
    """Combine per-view losses: scalar losses are summed into one global loss;
    per-pixel loss maps are passed through untouched (with masks, and conf if
    given) for the caller to reduce."""
    first = losses[0]
    if first.ndim == 0:
        # global (scalar) losses: accumulate them all
        total = first
        for extra in losses[1:]:
            total = total + extra
        return total
    # per-pixel losses: defer the reduction to the caller
    if conf is not None:
        return losses, masks, conf
    return losses, masks


def get_norm_factor(pts, norm_mode="avg_dis", valids=None, fix_first=True):
    """Compute one normalization factor shared by a list of point maps.

    Args:
        pts (list): point maps with trailing dimension 3; pts[1] may be None.
        norm_mode (str): "<mode>_<dis>" string; only "avg_*" is implemented
            ("avg_dis" averages the distance to the origin, "avg_log1p"
            averages log1p of that distance).
        valids (list): per-map validity masks; invalid points are zeroed out.
        fix_first (bool): when True, only pts[0] contributes to the factor
            (the gathering loop breaks after the first map).

    Returns:
        Tensor clipped to >= 1e-8, unsqueezed to broadcast against pts[0].
    """
    assert pts[0].ndim >= 3 and pts[0].shape[-1] == 3
    assert pts[1] is None or (pts[1].ndim >= 3 and pts[1].shape[-1] == 3)
    norm_mode, dis_mode = norm_mode.split("_")

    nan_pts = []
    nnzs = []

    if norm_mode == "avg":
        # gather all points together (joint normalization)

        for i, pt in enumerate(pts):
            # NOTE(review): invalid_to_zeros is assumed to return the map with
            # invalid points zeroed plus the count of valid points — confirm
            # against dust3r.utils.misc.
            nan_pt, nnz = invalid_to_zeros(pt, valids[i], ndim=3)
            nan_pts.append(nan_pt)
            nnzs.append(nnz)

            if fix_first:
                break
        all_pts = torch.cat(nan_pts, dim=1)

        # compute distance to origin
        all_dis = all_pts.norm(dim=-1)
        if dis_mode == "dis":
            pass  # do nothing
        elif dis_mode == "log1p":
            all_dis = torch.log1p(all_dis)
        else:
            raise ValueError(f"bad {dis_mode=}")

        # average over the valid-point count (zeroed invalid points add nothing)
        norm_factor = all_dis.sum(dim=1) / (torch.cat(nnzs).sum() + 1e-8)
    else:
        raise ValueError(f"Not implemented {norm_mode=}")

    norm_factor = norm_factor.clip(min=1e-8)
    # add trailing singleton dims so the factor broadcasts against pts[0]
    while norm_factor.ndim < pts[0].ndim:
        norm_factor.unsqueeze_(-1)

    return norm_factor


def normalize_pointcloud_t(
    pts, norm_mode="avg_dis", valids=None, fix_first=True, gt=False
):
    """Scale every point map in *pts* by a shared normalization factor.

    Returns (normalized_pts, norm_factor). The *gt* flag is kept for API
    compatibility: in the original implementation both branches computed the
    factor and scaled the maps identically, so they are unified here.
    """
    norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first)
    normalized = [pt / norm_factor for pt in pts]
    return normalized, norm_factor


@torch.no_grad()
def get_joint_pointcloud_depth(zs, valid_masks=None, quantile=0.5):
    """Joint depth statistic across all views, per batch element.

    Invalid depths are set to NaN and ignored; returns the median
    (quantile=0.5) or the requested quantile, shape (B,)."""
    # mark invalid depths as NaN and flatten each view to (B, -1)
    flat = [
        invalid_to_nans(
            zs[i], valid_masks[i] if valid_masks is not None else None
        ).reshape(len(zs[i]), -1)
        for i in range(len(zs))
    ]
    all_z = torch.cat(flat, dim=-1)

    # nan-aware reduction over every view at once
    if quantile == 0.5:
        return torch.nanmedian(all_z, dim=-1).values
    return torch.nanquantile(all_z, quantile, dim=-1)  # (B,)


@torch.no_grad()
def get_joint_pointcloud_center_scale(pts, valid_masks=None, z_only=False, center=True):
    """Robust (median-based) center and scale of all views' valid points.

    Returns (center, scale) shaped (B,1,1,3) and (B,1,1,1) for broadcasting."""
    # mark invalid points as NaN and flatten each view to (B, -1, 3)
    flat = [
        invalid_to_nans(
            pts[i], valid_masks[i] if valid_masks is not None else None
        ).reshape(len(pts[i]), -1, 3)
        for i in range(len(pts))
    ]
    all_pts = torch.cat(flat, dim=1)

    # median 3D center per batch element, ignoring NaNs
    _center = torch.nanmedian(all_pts, dim=1, keepdim=True).values  # (B,1,3)
    if z_only:
        _center[..., :2] = 0  # do not center X and Y

    # median distance to the center (or to the origin when center=False)
    offsets = (all_pts - _center) if center else all_pts
    scale = torch.nanmedian(offsets.norm(dim=-1), dim=1).values
    return _center[:, None, :, :], scale[:, None, None, None]


class Regr3D_t(Criterion, MultiLoss):
    """Temporal multi-view 3D regression loss on point maps.

    GT point maps of every view are expressed in the camera frame of the
    first view, predictions come from ``get_pred_pts3d``, both sides are
    optionally normalized by a shared norm factor, and compared per frame
    with ``self.criterion`` under the GT validity masks.
    """

    def __init__(self, criterion, norm_mode="avg_dis", gt_scale=False, fix_first=True):
        # norm_mode: normalization mode for normalize_pointcloud_t; a falsy
        #            value disables normalization entirely.
        # gt_scale: if True, GT keeps its own scale (GT is not normalized).
        # fix_first: forwarded to get_norm_factor — presumably anchors the
        #            norm factor on the first view; TODO confirm.
        super().__init__(criterion)
        self.norm_mode = norm_mode
        self.gt_scale = gt_scale
        self.fix_first = fix_first

    def get_all_pts3d_t(self, gts, preds, dist_clip=None):
        """Return (gt_pts, pr_pts, gt_factor, pr_factor, valids, monitoring).

        gts/preds are per-view lists; dist_clip, if given, invalidates GT
        points farther than that distance from their own camera.
        """
        # everything is normalized w.r.t. camera of view1
        in_camera1 = inv(gts[0]["camera_pose"])

        gt_pts = []
        valids = []
        pr_pts = []

        for i, gt in enumerate(gts):
            # in_camera1: Bs, 4, 4 gt['pts3d']: Bs, H, W, 3
            gt_pts.append(geotrf(in_camera1, gt["pts3d"]))

            valid = gt["valid_mask"].clone()

            if dist_clip is not None:
                # points that are too far-away == invalid
                dis = gt["pts3d"].norm(dim=-1)
                valid = valid & (dis <= dist_clip)

            valids.append(valid)
            pr_pts.append(get_pred_pts3d(gt, preds[i], use_pose=True))
            # if i != len(gts)-1:
            #     pr_pts_l.append(get_pred_pts3d(gt, preds[i][0], use_pose=(i!=0)))

            # if i != 0:
            #     pr_pts_r.append(get_pred_pts3d(gt, preds[i-1][1], use_pose=(i!=0)))

        # pr_pts = (pr_pts_l, pr_pts_r)

        # Predictions are always normalized when norm_mode is set; GT is
        # normalized only when it should not keep its own scale.
        if self.norm_mode:
            pr_pts, pr_factor = normalize_pointcloud_t(
                pr_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=False
            )
        else:
            pr_factor = None

        if self.norm_mode and not self.gt_scale:
            gt_pts, gt_factor = normalize_pointcloud_t(
                gt_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=True
            )
        else:
            gt_factor = None

        return gt_pts, pr_pts, gt_factor, pr_factor, valids, {}

    def compute_frame_loss(self, gts, preds, **kw):
        """Per-frame regression loss for left/right (reference/target) heads.

        NOTE(review): get_all_pts3d_t returns pred_pts as a flat list with
        one entry per view, yet the 2-tuple unpack below (and the indexing
        of preds[i][0] / preds[i-1][1]) assumes the older paired
        (left, right) structure shown in the commented-out code above —
        confirm which prediction layout callers actually pass.
        """
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            self.get_all_pts3d_t(gts, preds, **kw)
        )

        pred_pts_l, pred_pts_r = pred_pts

        loss_all = []
        mask_all = []
        conf_all = []

        # scalar monitors (not part of the returned loss tensor)
        loss_left = 0
        loss_right = 0
        pred_conf_l = 0
        pred_conf_r = 0

        for i in range(len(gt_pts)):

            # Left (Reference)
            if i != len(gt_pts) - 1:
                frame_loss = self.criterion(
                    pred_pts_l[i][masks[i]], gt_pts[i][masks[i]]
                )

                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(preds[i][0]["conf"])

                # To compare target/reference loss
                if i != 0:
                    loss_left += frame_loss.cpu().detach().numpy().mean()
                    pred_conf_l += preds[i][0]["conf"].cpu().detach().numpy().mean()

            # Right (Target)
            if i != 0:
                frame_loss = self.criterion(
                    pred_pts_r[i - 1][masks[i]], gt_pts[i][masks[i]]
                )

                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(preds[i - 1][1]["conf"])

                # To compare target/reference loss
                if i != len(gt_pts) - 1:
                    loss_right += frame_loss.cpu().detach().numpy().mean()
                    pred_conf_r += preds[i - 1][1]["conf"].cpu().detach().numpy().mean()

        # Penalize predicted norm factors that exceed the GT factor.
        if pr_factor is not None and gt_factor is not None:
            filter_factor = pr_factor[pr_factor > gt_factor]
        else:
            filter_factor = []

        if len(filter_factor) > 0:
            factor_loss = (filter_factor - gt_factor).abs().mean()
        else:
            factor_loss = 0.0

        self_name = type(self).__name__
        details = {
            self_name + "_pts3d_1": float(loss_all[0].mean()),
            self_name + "_pts3d_2": float(loss_all[1].mean()),
            self_name + "loss_left": float(loss_left),
            self_name + "loss_right": float(loss_right),
            self_name + "conf_left": float(pred_conf_l),
            self_name + "conf_right": float(pred_conf_r),
        }

        return Sum(loss_all, mask_all, conf_all), (details | monitoring), factor_loss


class ConfLoss_t(MultiLoss):
    """Weighted regression by learned confidence.
        Assuming the input pixel_loss is a pixel-level regression loss.

    Principle:
        high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
        low  confidence means low  conf = 10  ==> conf_loss = x * 10 - alpha*log(10)

        alpha: hyperparameter
    """

    def __init__(self, pixel_loss, alpha=1):
        """Wrap *pixel_loss*, requesting un-reduced (per-pixel) values."""
        super().__init__()
        assert alpha > 0
        self.alpha = alpha
        self.pixel_loss = pixel_loss.with_reduction("none")

    def get_name(self):
        return f"ConfLoss({self.pixel_loss})"

    def get_conf_log(self, x):
        """Return the confidence and its log (log term acts as regularizer)."""
        return x, torch.log(x)

    def compute_frame_loss(self, gts, preds, **kw):
        """Confidence-weight the per-pixel losses and average per frame.

        Returns (mean conf-weighted loss, details dict, factor loss passed
        through unchanged from the wrapped pixel loss).
        """
        # compute per-pixel loss
        (losses, masks, confs), details, loss_factor = (
            self.pixel_loss.compute_frame_loss(gts, preds, **kw)
        )

        # weight by confidence
        conf_losses = []
        conf_sum = 0
        for i in range(len(losses)):
            conf, log_conf = self.get_conf_log(confs[i][masks[i]])
            # NOTE(review): conf.mean() is NaN when the mask is empty —
            # confirm upstream guarantees at least one valid pixel per frame.
            conf_sum += conf.mean()
            conf_loss = losses[i] * conf - self.alpha * log_conf
            # BUG FIX: the original appended the Python int 0 for an empty
            # mask, which made torch.stack() below raise TypeError. Use a
            # zero *tensor* on the same device/dtype instead.
            if conf_loss.numel() > 0:
                conf_loss = conf_loss.mean()
            else:
                conf_loss = conf_loss.new_zeros(())
            conf_losses.append(conf_loss)

        conf_losses = torch.stack(conf_losses) * 2.0
        conf_loss_mean = conf_losses.mean()

        return (
            conf_loss_mean,
            dict(
                conf_loss_1=float(conf_losses[0]),
                # key spelled without underscore in the original ("conf_loss2");
                # kept as-is so existing log consumers keep working
                conf_loss2=float(conf_losses[1]),
                conf_mean=conf_sum / len(losses),
                **details,
            ),
            loss_factor,
        )


class Regr3D_t_ShiftInv(Regr3D_t):
    """Same than Regr3D but invariant to depth shift."""

    def get_all_pts3d_t(self, gts, preds):
        """Subtract the joint median depth from GT and predicted points.

        Returns the same tuple as Regr3D_t.get_all_pts3d_t, with the Z
        coordinate of every point map re-centered on its joint median.
        """
        # compute unnormalized points
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            super().get_all_pts3d_t(gts, preds)
        )

        # pred_pts_l, pred_pts_r = pred_pts
        gt_zs = [gt_pt[..., 2] for gt_pt in gt_pts]

        pred_zs = [pred_pt[..., 2] for pred_pt in pred_pts]
        # pred_zs.append(pred_pts_r[-1][..., 2])

        # compute median depth; [:, None, None] makes it broadcast against
        # the (B, H, W) depth channel of each point map
        gt_shift_z = get_joint_pointcloud_depth(gt_zs, masks)[:, None, None]
        pred_shift_z = get_joint_pointcloud_depth(pred_zs, masks)[:, None, None]

        # subtract the median depth (in-place on the point maps)
        for i in range(len(gt_pts)):
            gt_pts[i][..., 2] -= gt_shift_z

        for i in range(len(pred_pts)):
            # for j in range(len(pred_pts[i])):
            pred_pts[i][..., 2] -= pred_shift_z

        monitoring = dict(
            monitoring,
            gt_shift_z=gt_shift_z.mean().detach(),
            pred_shift_z=pred_shift_z.mean().detach(),
        )
        return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring


class Regr3D_t_ScaleInv(Regr3D_t):
    """Same than Regr3D but invariant to depth shift.
    if gt_scale == True: enforce the prediction to take the same scale than GT
    """

    def get_all_pts3d_t(self, gts, preds):
        """Rescale point maps so GT and prediction share a common scale."""
        # compute depth-normalized points
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            super().get_all_pts3d_t(gts, preds)
        )

        # measure scene scale

        # pred_pts_l, pred_pts_r = pred_pts

        # clone so that the scale measurement is not affected by the
        # in-place rescaling below
        pred_pts_all = [
            x.clone() for x in pred_pts
        ]  # [pred_pt for pred_pt in pred_pts_l]
        # pred_pts_all.append(pred_pts_r[-1])

        _, gt_scale = get_joint_pointcloud_center_scale(gt_pts, masks)
        _, pred_scale = get_joint_pointcloud_center_scale(pred_pts_all, masks)

        # prevent predictions to be in a ridiculous range
        pred_scale = pred_scale.clip(min=1e-3, max=1e3)

        # subtract the median depth
        if self.gt_scale:
            # bring predictions to GT scale
            for i in range(len(pred_pts)):
                # for j in range(len(pred_pts[i])):
                pred_pts[i] *= gt_scale / pred_scale

        else:
            # NOTE(review): both sides are rescaled here (pred by
            # pred_scale/gt_scale AND gt by gt_scale/pred_scale), i.e. the
            # relative scale correction is applied twice in opposite
            # directions — confirm this is intentional rather than only
            # rescaling one side.
            for i in range(len(pred_pts)):
                # for j in range(len(pred_pts[i])):
                pred_pts[i] *= pred_scale / gt_scale

            for i in range(len(gt_pts)):
                gt_pts[i] *= gt_scale / pred_scale

        monitoring = dict(
            monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach()
        )

        return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring


class Regr3D_t_ScaleShiftInv(Regr3D_t_ScaleInv, Regr3D_t_ShiftInv):
    """Regression loss invariant to both depth shift and scale.

    Through the cooperative super() chain (MRO: ScaleInv -> ShiftInv ->
    Regr3D_t), get_all_pts3d_t runs the base computation first, then the
    shift-invariance step, then the scale-invariance step.
    """

    # calls Regr3D_ShiftInv first, then Regr3D_ScaleInv
    pass

```

## /eval/mv_recon/data.py

```py path="/eval/mv_recon/data.py" 
import os
import cv2
import json
import numpy as np
import os.path as osp
from collections import deque
import random
from eval.mv_recon.base import BaseStereoViewDataset
from dust3r.utils.image import imread_cv2
import eval.mv_recon.dataset_utils.cropping as cropping


def shuffle_deque(dq, seed=None):
    """Return a shuffled copy of *dq* as a new deque.

    Args:
        dq: input deque (left untouched).
        seed: optional int; a given seed always yields the same permutation.

    Fix: the original called ``random.seed(seed)``, silently re-seeding the
    *global* random module for every other consumer. A private
    ``random.Random(seed)`` instance produces the exact same permutation
    (same Mersenne-Twister stream) without that side effect.
    """
    items = list(dq)
    if seed is None:
        random.shuffle(items)  # non-deterministic path: keep the global RNG
    else:
        random.Random(seed).shuffle(items)
    return deque(items)


class SevenScenes(BaseStereoViewDataset):
    """7-Scenes RGB-D sequence loader for multi-view reconstruction eval.

    Frames come either from an explicit ``tuple_list`` entry
    ("<scene_id> <idx> <idx> ...") or from a full sequence, keeping every
    ``kf_every``-th frame and at most ``max_frames`` frames.
    """

    def __init__(
        self,
        num_seq=1,
        num_frames=5,
        min_thresh=10,
        max_thresh=100,
        test_id=None,
        full_video=False,
        tuple_list=None,
        seq_id=None,
        rebuttal=False,
        shuffle_seed=-1,
        kf_every=1,
        max_frames=None,
        *args,
        ROOT,
        **kwargs,
    ):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        self.num_seq = num_seq
        self.num_frames = num_frames
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh
        self.test_id = test_id
        self.full_video = full_video
        self.kf_every = kf_every
        self.seq_id = seq_id
        self.rebuttal = rebuttal
        # shuffle_seed >= 0 requests a *deterministic* shuffle of the frames
        self.shuffle_seed = shuffle_seed
        self.max_frames = max_frames

        # load all scenes
        self.load_all_tuples(tuple_list)
        self.load_all_scenes(ROOT)

    def __len__(self):
        if self.tuple_list is not None:
            return len(self.tuple_list)
        return len(self.scene_list) * self.num_seq

    def load_all_tuples(self, tuple_list):
        """Store the optional pre-defined (scene, frames...) tuples."""
        self.tuple_list = tuple_list

    def load_all_scenes(self, base_dir):
        """Build scene_list of '<scene>/seq-XX' ids (honors test_id/seq_id)."""
        if self.tuple_list is not None:
            # Use pre-defined simplerecon scene_ids
            self.scene_list = [
                "stairs/seq-06",
                "stairs/seq-02",
                "pumpkin/seq-06",
                "chess/seq-01",
                "heads/seq-02",
                "fire/seq-02",
                "office/seq-03",
                "pumpkin/seq-03",
                "redkitchen/seq-07",
                "chess/seq-02",
                "office/seq-01",
                "redkitchen/seq-01",
                "fire/seq-01",
            ]
            print(f"Found {len(self.scene_list)} sequences in split {self.split}")
            return

        scenes = os.listdir(base_dir)

        file_split = {"train": "TrainSplit.txt", "test": "TestSplit.txt"}[self.split]

        self.scene_list = []
        for scene in scenes:
            if self.test_id is not None and scene != self.test_id:
                continue
            # read file split
            with open(osp.join(base_dir, scene, file_split)) as f:
                seq_ids = f.read().splitlines()

                for seq_id in seq_ids:
                    # 'sequenceN' -> zero-padded 'seq-0N'
                    num_part = "".join(filter(str.isdigit, seq_id))
                    seq_id = f"seq-{num_part.zfill(2)}"
                    if self.seq_id is not None and seq_id != self.seq_id:
                        continue
                    self.scene_list.append(f"{scene}/{seq_id}")

        print(f"Found {len(self.scene_list)} sequences in split {self.split}")

    def _get_views(self, idx, resolution, rng):
        """Load all frames of one (sub)sequence as a list of view dicts."""
        if self.tuple_list is not None:
            line = self.tuple_list[idx].split(" ")
            scene_id = line[0]
            img_idxs = line[1:]

        else:
            scene_id = self.scene_list[idx // self.num_seq]

            data_path = osp.join(self.ROOT, scene_id)
            num_files = len([name for name in os.listdir(data_path) if "color" in name])
            img_idxs = [f"{i:06d}" for i in range(num_files)]
            img_idxs = img_idxs[:: self.kf_every]

            if self.max_frames is not None:
                img_idxs = img_idxs[: self.max_frames]

        # Intrinsics used in SimpleRecon
        fx, fy, cx, cy = 525, 525, 320, 240
        intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)

        views = []
        imgs_idxs = deque(img_idxs)
        if self.shuffle_seed >= 0:
            # BUG FIX: the seed was not forwarded before, so shuffle_seed >= 0
            # still produced a non-reproducible frame order.
            imgs_idxs = shuffle_deque(imgs_idxs, seed=self.shuffle_seed)

        while len(imgs_idxs) > 0:
            im_idx = imgs_idxs.popleft()
            impath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.color.png")
            depthpath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.depth.proj.png")
            posepath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.pose.txt")

            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
            rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0]))

            # 65535 marks invalid depth; convert mm -> m, keep (1mm, 10m]
            depthmap[depthmap == 65535] = 0
            depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0
            depthmap[depthmap > 10] = 0
            depthmap[depthmap < 1e-3] = 0

            camera_pose = np.loadtxt(posepath).astype(np.float32)

            if resolution != (224, 224) or self.rebuttal:
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath
                )
            else:
                # resize to 512x384 first, then center-crop to 224x224
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath
                )
                W, H = rgb_image.size
                cx = W // 2
                cy = H // 2
                l, t = cx - 112, cy - 112
                r, b = cx + 112, cy + 112
                crop_bbox = (l, t, r, b)
                rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap(
                    rgb_image, depthmap, intrinsics, crop_bbox
                )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,
                    camera_intrinsics=intrinsics,
                    dataset="7scenes",
                    label=osp.join(scene_id, im_idx),
                    instance=impath,
                )
            )
        return views


class DTU(BaseStereoViewDataset):
    """DTU MVS dataset loader for multi-view reconstruction evaluation.

    Reads images, .npy depth maps, binary object masks and MVSNet-style
    camera files; returns one list of view dicts per (scene, seq) index.
    """

    def __init__(
        self,
        num_seq=49,
        num_frames=5,
        min_thresh=10,
        max_thresh=30,
        test_id=None,
        full_video=False,
        sample_pairs=False,
        kf_every=1,
        *args,
        ROOT,
        **kwargs,
    ):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)

        self.num_seq = num_seq
        self.num_frames = num_frames
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh
        self.test_id = test_id
        self.full_video = full_video
        self.kf_every = kf_every
        # NOTE(review): this bool shares its name with the method call in
        # _get_views (self.sample_pairs(...)) — see the note there.
        self.sample_pairs = sample_pairs

        # load all scenes
        self.load_all_scenes(ROOT)

    def __len__(self):
        return len(self.scene_list) * self.num_seq

    def load_all_scenes(self, base_dir):
        """Populate scene_list from the dataset root or from test_id."""
        if self.test_id is None:
            self.scene_list = os.listdir(osp.join(base_dir))
            print(f"Found {len(self.scene_list)} scenes in split {self.split}")

        else:
            if isinstance(self.test_id, list):
                self.scene_list = self.test_id
            else:
                self.scene_list = [self.test_id]

            print(f"Test_id: {self.test_id}")

    def load_cam_mvsnet(self, file, interval_scale=1):
        """read camera txt file

        Parses the MVSNet cam.txt layout: a 4x4 extrinsic (world-to-cam)
        matrix, a 3x3 intrinsic matrix, and an optional depth-range line
        (min, interval[, num, max]) stored in cam[1][3]. Returns
        (intrinsic 4x4, extrinsic 4x4) as float32.
        """
        cam = np.zeros((2, 4, 4))
        words = file.read().split()
        # read extrinsic
        for i in range(0, 4):
            for j in range(0, 4):
                extrinsic_index = 4 * i + j + 1
                cam[0][i][j] = words[extrinsic_index]

        # read intrinsic
        for i in range(0, 3):
            for j in range(0, 3):
                intrinsic_index = 3 * i + j + 18
                cam[1][i][j] = words[intrinsic_index]

        # depth range: interpretation depends on how many tokens the file has
        if len(words) == 29:
            cam[1][3][0] = words[27]
            cam[1][3][1] = float(words[28]) * interval_scale
            cam[1][3][2] = 192
            cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * cam[1][3][2]
        elif len(words) == 30:
            cam[1][3][0] = words[27]
            cam[1][3][1] = float(words[28]) * interval_scale
            cam[1][3][2] = words[29]
            cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * cam[1][3][2]
        elif len(words) == 31:
            cam[1][3][0] = words[27]
            cam[1][3][1] = float(words[28]) * interval_scale
            cam[1][3][2] = words[29]
            cam[1][3][3] = words[30]
        else:
            cam[1][3][0] = 0
            cam[1][3][1] = 0
            cam[1][3][2] = 0
            cam[1][3][3] = 0

        extrinsic = cam[0].astype(np.float32)
        intrinsic = cam[1].astype(np.float32)

        return intrinsic, extrinsic

    def _get_views(self, idx, resolution, rng):
        """Load one (scene, seq) worth of masked, depth-aligned views."""
        scene_id = self.scene_list[idx // self.num_seq]
        seq_id = idx % self.num_seq

        print("Scene ID:", scene_id)

        image_path = osp.join(self.ROOT, scene_id, "images")
        depth_path = osp.join(self.ROOT, scene_id, "depths")
        mask_path = osp.join(self.ROOT, scene_id, "binary_masks")
        cam_path = osp.join(self.ROOT, scene_id, "cams")
        pairs_path = osp.join(self.ROOT, scene_id, "pair.txt")

        if not self.full_video:
            # NOTE(review): self.sample_pairs is the *bool* stored in
            # __init__, so calling it here raises TypeError — a
            # sample_pairs() method appears to be missing from this class.
            # Confirm against the upstream version before using
            # full_video=False.
            img_idxs = self.sample_pairs(pairs_path, seq_id)
        else:
            img_idxs = sorted(os.listdir(image_path))
            img_idxs = img_idxs[:: self.kf_every]

        views = []
        imgs_idxs = deque(img_idxs)

        while len(imgs_idxs) > 0:
            # frames are consumed from the right end of the deque
            im_idx = imgs_idxs.pop()
            impath = osp.join(image_path, im_idx)
            depthpath = osp.join(depth_path, im_idx.replace(".jpg", ".npy"))
            campath = osp.join(cam_path, im_idx.replace(".jpg", "_cam.txt"))
            maskpath = osp.join(mask_path, im_idx.replace(".jpg", ".png"))

            rgb_image = imread_cv2(impath)
            depthmap = np.load(depthpath)
            depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0)

            # binarize the object mask and erode it to drop boundary pixels
            mask = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED) / 255.0
            mask = mask.astype(np.float32)

            mask[mask > 0.5] = 1.0
            mask[mask < 0.5] = 0.0

            mask = cv2.resize(
                mask,
                (depthmap.shape[1], depthmap.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )
            kernel = np.ones((10, 10), np.uint8)  # Define the erosion kernel
            mask = cv2.erode(mask, kernel, iterations=1)
            depthmap = depthmap * mask

            # MVSNet extrinsic is world-to-cam; invert to get cam-to-world
            cur_intrinsics, camera_pose = self.load_cam_mvsnet(open(campath, "r"))
            intrinsics = cur_intrinsics[:3, :3]
            camera_pose = np.linalg.inv(camera_pose)

            if resolution != (224, 224):
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath
                )
            else:
                # resize to 512x384 first, then center-crop to 224x224
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics, (512, 384), rng=rng, info=impath
                )
                W, H = rgb_image.size
                cx = W // 2
                cy = H // 2
                l, t = cx - 112, cy - 112
                r, b = cx + 112, cy + 112
                crop_bbox = (l, t, r, b)
                rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap(
                    rgb_image, depthmap, intrinsics, crop_bbox
                )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,
                    camera_intrinsics=intrinsics,
                    dataset="dtu",
                    label=osp.join(scene_id, im_idx),
                    instance=impath,
                )
            )

        return views


class NRGBD(BaseStereoViewDataset):
    """Neural-RGBD sequence loader for multi-view reconstruction eval.

    Reads img<N>.png / depth<N>.png pairs plus a poses.txt file (OpenGL
    camera convention, converted to OpenCV here).
    """

    def __init__(
        self,
        num_seq=1,
        num_frames=5,
        min_thresh=10,
        max_thresh=100,
        test_id=None,
        full_video=False,
        tuple_list=None,
        seq_id=None,
        rebuttal=False,
        shuffle_seed=-1,
        kf_every=1,
        max_frames=None,
        *args,
        ROOT,
        **kwargs,
    ):

        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        self.num_seq = num_seq
        self.num_frames = num_frames
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh
        self.test_id = test_id
        self.full_video = full_video
        self.kf_every = kf_every
        self.seq_id = seq_id
        self.rebuttal = rebuttal
        # shuffle_seed >= 0 requests a *deterministic* shuffle of the frames
        self.shuffle_seed = shuffle_seed
        self.max_frames = max_frames

        # load all scenes
        self.load_all_tuples(tuple_list)
        self.load_all_scenes(ROOT)

    def __len__(self):
        if self.tuple_list is not None:
            return len(self.tuple_list)
        return len(self.scene_list) * self.num_seq

    def load_all_tuples(self, tuple_list):
        """Store the optional pre-defined (scene, frames...) tuples."""
        self.tuple_list = tuple_list

    def load_all_scenes(self, base_dir):
        """Scene list = every sub-directory of base_dir, or just test_id."""
        scenes = [
            d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
        ]

        if self.test_id is not None:
            self.scene_list = [self.test_id]

        else:
            self.scene_list = scenes

        print(f"Found {len(self.scene_list)} sequences in split {self.split}")

    def load_poses(self, path):
        """Parse poses.txt (one 4x4 matrix per 4 lines).

        Returns (poses (N,4,4) float32, valid list); rows containing 'nan'
        are replaced by identity and flagged invalid.
        """
        with open(path, "r") as file:
            lines = file.readlines()
        poses = []
        valid = []
        lines_per_matrix = 4
        for i in range(0, len(lines), lines_per_matrix):
            if "nan" in lines[i]:
                valid.append(False)
                poses.append(np.eye(4, 4, dtype=np.float32).tolist())
            else:
                valid.append(True)
                pose_floats = [
                    [float(x) for x in line.split()]
                    for line in lines[i : i + lines_per_matrix]
                ]
                poses.append(pose_floats)

        return np.array(poses, dtype=np.float32), valid

    def _get_views(self, idx, resolution, rng):
        """Load all frames of one sequence as a list of view dicts."""
        if self.tuple_list is not None:
            line = self.tuple_list[idx].split(" ")
            scene_id = line[0]
            img_idxs = line[1:]

        else:
            scene_id = self.scene_list[idx // self.num_seq]

            num_files = len(os.listdir(os.path.join(self.ROOT, scene_id, "images")))
            img_idxs = [f"{i}" for i in range(num_files)]
            # BUG FIX: for sequences of <= 1 frame, len(img_idxs) // 2 == 0
            # and a slice step of 0 raises ValueError — clamp the step to 1.
            step = max(1, min(self.kf_every, len(img_idxs) // 2))
            img_idxs = img_idxs[::step]
            if self.max_frames is not None:
                img_idxs = img_idxs[: self.max_frames]

        fx, fy, cx, cy = 554.2562584220408, 554.2562584220408, 320, 240
        intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)

        posepath = osp.join(self.ROOT, scene_id, "poses.txt")
        camera_poses, valids = self.load_poses(posepath)

        imgs_idxs = deque(img_idxs)
        if self.shuffle_seed >= 0:
            # BUG FIX: the seed was not forwarded before, so shuffle_seed >= 0
            # still produced a non-reproducible frame order.
            imgs_idxs = shuffle_deque(imgs_idxs, seed=self.shuffle_seed)
        views = []

        while len(imgs_idxs) > 0:
            im_idx = imgs_idxs.popleft()

            impath = osp.join(self.ROOT, scene_id, "images", f"img{im_idx}.png")
            depthpath = osp.join(self.ROOT, scene_id, "depth", f"depth{im_idx}.png")

            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
            # depth stored in mm; keep (1mm, 10m]
            depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0
            depthmap[depthmap > 10] = 0
            depthmap[depthmap < 1e-3] = 0

            rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0]))

            camera_pose = camera_poses[int(im_idx)]
            # gl to cv
            camera_pose[:, 1:3] *= -1.0
            if resolution != (224, 224) or self.rebuttal:
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath
                )
            else:
                # resize to 512x384 first, then center-crop to 224x224
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath
                )
                W, H = rgb_image.size
                cx = W // 2
                cy = H // 2
                l, t = cx - 112, cy - 112
                r, b = cx + 112, cy + 112
                crop_bbox = (l, t, r, b)
                rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap(
                    rgb_image, depthmap, intrinsics, crop_bbox
                )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,
                    camera_intrinsics=intrinsics,
                    dataset="nrgbd",
                    label=osp.join(scene_id, im_idx),
                    instance=impath,
                )
            )

        return views

```

## /eval/mv_recon/dataset_utils/__init__.py

```py path="/eval/mv_recon/dataset_utils/__init__.py" 


```

## /eval/mv_recon/dataset_utils/corr.py

```py path="/eval/mv_recon/dataset_utils/corr.py" 
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------

import numpy as np
from dust3r.utils.device import to_numpy
from dust3r.utils.geometry import inv, geotrf


def reproject_view(pts3d, view2):
    """Reproject *pts3d* into the image plane of *view2*."""
    target_shape = view2["pts3d"].shape[:2]
    world2cam = inv(view2["camera_pose"])
    return reproject(pts3d, view2["camera_intrinsics"], world2cam, target_shape)


def reproject(pts3d, K, world2cam, shape):
    """Project an (H, W, 3) point map with K @ world2cam and quantize to
    flat pixel indices of a *shape*-sized image.

    Returns ((H, W) of the source map, flat indices into *shape*).
    """
    H, W, channels = pts3d.shape
    assert channels == 3

    # projection may divide by zero / produce NaNs for invalid points
    with np.errstate(divide="ignore", invalid="ignore"):
        projected = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2)

    return (H, W), ravel_xy(projected, shape)


def ravel_xy(pos, shape):
    """Quantize (..., 2) float (x, y) positions to flat indices in *shape*.

    Coordinates are rounded, clipped into the image, and flattened
    row-major (index = x + W * y).
    """
    H, W = shape
    # casting NaNs to int is undefined but harmless: clipping bounds the result
    with np.errstate(invalid="ignore"):
        xs, ys = pos.reshape(-1, 2).round().astype(np.int32).T
    np.clip(xs, 0, W - 1, out=xs)
    np.clip(ys, 0, H - 1, out=ys)
    return xs + W * ys


def unravel_xy(pos, shape):
    """Convert flat indices back to (x, y) pixel coordinates.

    NOTE(review): relies on np.unravel_index returning views into one
    shared base array, then flipping (y, x) -> (x, y) via ``[:, ::-1]``.
    The ``.base`` trick is numpy-implementation-dependent — confirm it
    holds for the numpy version in use.
    """
    return np.unravel_index(pos, shape)[0].base[:, ::-1].copy()


def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
    """Keep only mutually consistent 1D correspondences.

    A match i -> corres_1_to_2[i] is reciprocal iff mapping it back through
    corres_2_to_1 lands on i again. Returns (pos1, pos2), optionally
    prefixed by the boolean reciprocity mask when ret_recip is True.
    """
    roundtrip = corres_2_to_1[corres_1_to_2]
    is_reciprocal1 = roundtrip == np.arange(len(corres_1_to_2))
    pos1 = np.flatnonzero(is_reciprocal1)
    pos2 = corres_1_to_2[pos1]
    if ret_recip:
        return is_reciprocal1, pos1, pos2
    return pos1, pos2


def extract_correspondences_from_pts3d(
    view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0
):
    """Build pixel correspondences between two views from their GT point maps.

    Each view's 3D points are reprojected into the other view; a match is
    kept when the mapping is reciprocal. If target_n_corres is None all
    reciprocal matches are returned; otherwise exactly target_n_corres
    pairs are sampled, of which a fraction *nneg* are random negatives
    (returned with valid=False). ret_xy converts flat indices to (x, y).
    """
    view1, view2 = to_numpy((view1, view2))

    # flat pixel index in the *other* view for every pixel of each view
    shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2)
    shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1)

    # reciprocity check in both directions
    is_reciprocal1, pos1, pos2 = reciprocal_1d(
        corres1_to_2, corres2_to_1, ret_recip=True
    )
    is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1))

    if target_n_corres is None:
        if ret_xy:
            pos1 = unravel_xy(pos1, shape1)
            pos2 = unravel_xy(pos2, shape2)
        return pos1, pos2

    # decide the positive/negative split, bounded by what is available
    available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum())
    target_n_positives = int(target_n_corres * (1 - nneg))
    n_positives = min(len(pos1), target_n_positives)
    n_negatives = min(target_n_corres - n_positives, available_negatives)

    if n_negatives + n_positives != target_n_corres:
        # not enough negatives: top up with extra positives
        n_positives = target_n_corres - n_negatives
        assert n_positives <= len(pos1)

    assert n_positives <= len(pos1)
    assert n_positives <= len(pos2)
    assert n_negatives <= (~is_reciprocal1).sum()
    assert n_negatives <= (~is_reciprocal2).sum()
    assert n_positives + n_negatives == target_n_corres

    valid = np.ones(n_positives, dtype=bool)
    if n_positives < len(pos1):
        # sub-sample the positives with a shared permutation
        perm = rng.permutation(len(pos1))[:n_positives]
        pos1 = pos1[perm]
        pos2 = pos2[perm]

    if n_negatives > 0:
        # draw negatives uniformly among the non-reciprocal pixels
        def norm(p):
            return p / p.sum()

        pos1 = np.r_[
            pos1,
            rng.choice(
                shape1[0] * shape1[1],
                size=n_negatives,
                replace=False,
                p=norm(~is_reciprocal1),
            ),
        ]
        pos2 = np.r_[
            pos2,
            rng.choice(
                shape2[0] * shape2[1],
                size=n_negatives,
                replace=False,
                p=norm(~is_reciprocal2),
            ),
        ]
        valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]

    if ret_xy:
        pos1 = unravel_xy(pos1, shape1)
        pos2 = unravel_xy(pos2, shape2)
    return pos1, pos2, valid

```

## /eval/mv_recon/dataset_utils/cropping.py

```py path="/eval/mv_recon/dataset_utils/cropping.py" 
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------

import PIL.Image
import os

os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import cv2  # noqa
import numpy as np  # noqa
from dust3r.utils.geometry import (
    colmap_to_opencv_intrinsics,
    opencv_to_colmap_intrinsics,
)  # noqa

# Pillow >= 9.1 moved the resampling filters to PIL.Image.Resampling;
# fall back to the legacy top-level constants on older Pillow versions.
try:
    lanczos = PIL.Image.Resampling.LANCZOS
    bicubic = PIL.Image.Resampling.BICUBIC
except AttributeError:
    lanczos = PIL.Image.LANCZOS
    bicubic = PIL.Image.BICUBIC


class ImageList:
    """Convenience wrapper applying the same PIL operation to one image or
    a whole set of images (numpy arrays are converted to PIL on entry)."""

    def __init__(self, images):
        if not isinstance(images, (tuple, list, set)):
            images = [images]
        self.images = [
            img if isinstance(img, PIL.Image.Image) else PIL.Image.fromarray(img)
            for img in images
        ]

    def __len__(self):
        return len(self.images)

    def to_pil(self):
        """Single image in -> single PIL image out; several -> tuple."""
        return tuple(self.images) if len(self.images) > 1 else self.images[0]

    @property
    def size(self):
        # all wrapped images must share one (W, H)
        sizes = [im.size for im in self.images]
        assert all(s == sizes[0] for s in sizes)
        return sizes[0]

    def resize(self, *args, **kwargs):
        return ImageList(self._dispatch("resize", *args, **kwargs))

    def crop(self, *args, **kwargs):
        return ImageList(self._dispatch("crop", *args, **kwargs))

    def _dispatch(self, method_name, *args, **kwargs):
        # forward the call to every wrapped image
        return [getattr(im, method_name)(*args, **kwargs) for im in self.images]


def rescale_image_depthmap(
    image, depthmap, camera_intrinsics, output_resolution, force=True
):
    """Jointly rescale a (image, depthmap)
    so that (out_width, out_height) >= output_res

    Args:
        image: PIL image / ndarray (or a collection of them, see ImageList).
        depthmap: HxW depth array matching the image size, or None.
        camera_intrinsics: 3x3 intrinsics, rescaled consistently with the image.
        output_resolution: target (width, height).
        force: when False, images already smaller than the target are returned
            unchanged.

    Returns:
        (image, depthmap, camera_intrinsics) tuple after rescaling.
    """
    image = ImageList(image)
    input_resolution = np.array(image.size)  # (W,H)
    output_resolution = np.array(output_resolution)
    if depthmap is not None:
        # depthmap is HxW while image.size is (W,H): compare flipped.
        assert tuple(depthmap.shape[:2]) == image.size[::-1]

    assert output_resolution.shape == (2,)
    # Use the larger of the two per-axis ratios so BOTH output dimensions end
    # up >= the requested resolution; the epsilon guards against rounding.
    scale_final = max(output_resolution / image.size) + 1e-8
    if scale_final >= 1 and not force:  # image is already smaller than what is asked
        return (image.to_pil(), depthmap, camera_intrinsics)
    output_resolution = np.floor(input_resolution * scale_final).astype(int)

    # Lanczos when shrinking, bicubic when enlarging.
    image = image.resize(
        output_resolution, resample=lanczos if scale_final < 1 else bicubic
    )
    if depthmap is not None:
        # Nearest-neighbour keeps depth values unblended across object edges.
        depthmap = cv2.resize(
            depthmap,
            output_resolution,
            fx=scale_final,
            fy=scale_final,
            interpolation=cv2.INTER_NEAREST,
        )

    # Pure rescale (no crop): camera_matrix_of_crop applies the scale and the
    # default centered offset over the (here zero) margin.
    camera_intrinsics = camera_matrix_of_crop(
        camera_intrinsics, input_resolution, output_resolution, scaling=scale_final
    )

    return image.to_pil(), depthmap, camera_intrinsics


def camera_matrix_of_crop(
    input_camera_matrix,
    input_resolution,
    output_resolution,
    scaling=1,
    offset_factor=0.5,
    offset=None,
):
    """Update a 3x3 intrinsics matrix for an image rescale plus optional crop.

    Args:
        input_camera_matrix: 3x3 OpenCV-convention intrinsics.
        input_resolution: (W, H) before scaling.
        output_resolution: (W, H) of the final (scaled, cropped) image.
        scaling: isotropic scale factor applied to the image.
        offset_factor: fraction of the leftover margin used as crop offset
            when ``offset`` is None (0.5 = centered crop).
        offset: explicit (x, y) crop offset in scaled pixels; overrides
            ``offset_factor``.

    Returns:
        3x3 intrinsics (OpenCV convention) valid for the output image.
    """
    # Margin left over once the output window is placed inside the scaled image.
    margins = np.asarray(input_resolution) * scaling - output_resolution
    assert np.all(margins >= 0.0)
    if offset is None:
        offset = offset_factor * margins

    # Scale/shift the principal point in COLMAP convention, then convert back
    # to OpenCV convention (the two differ by a half-pixel origin shift).
    output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
    output_camera_matrix_colmap[:2, :] *= scaling
    output_camera_matrix_colmap[:2, 2] -= offset
    output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)

    return output_camera_matrix


def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
    """Crop image and depthmap, adjusting intrinsics accordingly.

    ``crop_bbox`` is (left, top, right, bottom) in pixel coordinates.
    """
    left, top, right, bottom = crop_bbox
    image = ImageList(image).crop((left, top, right, bottom))
    depthmap = depthmap[top:bottom, left:right]

    # Shift the principal point so intrinsics stay consistent with the crop.
    camera_intrinsics = camera_intrinsics.copy()
    camera_intrinsics[0, 2] -= left
    camera_intrinsics[1, 2] -= top

    return image.to_pil(), depthmap, camera_intrinsics


def bbox_from_intrinsics_in_out(
    input_camera_matrix, output_camera_matrix, output_resolution
):
    """Recover the (l, t, r, b) crop box implied by a pair of intrinsics.

    The top-left corner is the (rounded) shift between the two principal
    points; the box extends by the requested output resolution.
    """
    out_width, out_height = output_resolution
    principal_shift = input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]
    left, top = np.int32(np.round(principal_shift))
    return (left, top, left + out_width, top + out_height)

```

## /eval/mv_recon/dataset_utils/transforms.py

```py path="/eval/mv_recon/dataset_utils/transforms.py" 
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------

import torchvision.transforms as tvf
from dust3r.utils.image import ImgNorm


# Per-image color jitter followed by the dust3r image normalization transform.
ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])


def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
    if isinstance(value, (int, float)):
        if value < 0:
            raise ValueError(f"If  is a single number, it must be non negative.")
        value = [center - float(value), center + float(value)]
        if clip_first_on_zero:
            value[0] = max(value[0], 0.0)
    elif isinstance(value, (tuple, list)) and len(value) == 2:
        value = [float(value[0]), float(value[1])]
    else:
        raise TypeError(f"should be a single number or a list/tuple with length 2.")

    if not bound[0] <= value[0] <= value[1] <= bound[1]:
        raise ValueError(f"values should be between {bound}, but got {value}.")

    if value[0] == value[1] == center:
        return None
    else:
        return tuple(value)


import torch
import torchvision.transforms.functional as F


def SeqColorJitter():
    """
    Return a color jitter transform with same random parameters

    Unlike ``tvf.ColorJitter`` (which resamples per call), the four jitter
    factors and their application order are drawn once here, so the returned
    callable applies the *same* jitter to every image of a sequence, then
    normalizes with ImgNorm.
    """
    # Convert amplitudes to (min, max) ranges; None means "skip this adjustment".
    brightness = _check_input(0.5)
    contrast = _check_input(0.5)
    saturation = _check_input(0.5)
    hue = _check_input(0.1, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    # Random order of the four adjustments, fixed for the whole sequence.
    fn_idx = torch.randperm(4)
    # Sample one factor per adjustment, uniform over its range.
    brightness_factor = (
        None
        if brightness is None
        else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
    )
    contrast_factor = (
        None
        if contrast is None
        else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
    )
    saturation_factor = (
        None
        if saturation is None
        else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
    )
    hue_factor = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

    def _color_jitter(img):
        # Apply the frozen adjustments in the frozen order, then normalize.
        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                img = F.adjust_brightness(img, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                img = F.adjust_contrast(img, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                img = F.adjust_saturation(img, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                img = F.adjust_hue(img, hue_factor)
        return ImgNorm(img)

    return _color_jitter

```

## /eval/mv_recon/launch.py

```py path="/eval/mv_recon/launch.py" 
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import time
import torch
import argparse
import numpy as np
import open3d as o3d
import os.path as osp
from torch.utils.data import DataLoader
from add_ckpt_path import add_path_to_dust3r
from accelerate import Accelerator
from torch.utils.data._utils.collate import default_collate
import tempfile
from tqdm import tqdm


def get_args_parser():
    """Build the CLI argument parser for the 3D reconstruction evaluation."""
    p = argparse.ArgumentParser("3D Reconstruction evaluation", add_help=False)
    p.add_argument("--weights", type=str, default="", help="ckpt name")
    p.add_argument("--device", type=str, default="cuda:0", help="device")
    p.add_argument("--model_name", type=str, default="")
    p.add_argument("--conf_thresh", type=float, default=0.0, help="confidence threshold")
    p.add_argument("--output_dir", type=str, default="", help="value for outdir")
    p.add_argument("--size", type=int, default=512)
    p.add_argument("--revisit", type=int, default=1, help="revisit times")
    p.add_argument("--freeze", action="store_true")
    p.add_argument("--max_frames", type=int, default=None, help="max frames limit")
    p.add_argument("--model_update_type", type=str, default="cut3r", help="model update type")
    p.add_argument("--voxel_size", type=float, default=0.0, help="voxel size for voxel grid downsampling, 0 means no downsampling")
    return p


def main(args):
    """Run multi-view reconstruction evaluation over the configured datasets.

    Pipeline per sequence: run the (recurrent) model over all views, align
    predictions to GT via the criterion's scale/shift estimate plus point-to-
    point ICP, then compute accuracy / completion / normal-consistency
    metrics. Each process appends per-scene results to its own log file; the
    main process merges them into logs_all.txt with mean metrics.
    """
    add_path_to_dust3r(args.weights)
    from eval.mv_recon.data import SevenScenes, NRGBD
    from eval.mv_recon.utils import accuracy, completion

    if args.size == 512:
        resolution = (512, 384)
    elif args.size == 224:
        resolution = 224
    else:
        raise NotImplementedError
    datasets_all = {
        "7scenes": SevenScenes(
            split="test",
            ROOT="/home/share/Dataset/3D_scene/7scenes/", # "./data/7scenes",
            resolution=resolution,
            num_seq=1,
            full_video=True,
            kf_every=2,
            max_frames=args.max_frames,
        )
    }

    # ====== print the number of views for each scene ======
    print("\n=== number of views for each scene ===")
    for name_data, dataset in datasets_all.items():
        print(f"\n{name_data} dataset:")
        for scene_id in dataset.scene_list:
            if name_data == "NRGBD":
                # NRGBD dataset file structure
                data_path = osp.join(dataset.ROOT, scene_id, "images")
                num_files = len([name for name in os.listdir(data_path) if name.endswith('.png')])
                view_count = len([f"{i}" for i in range(num_files)][::dataset.kf_every])
            else:
                # SevenScenes dataset file structure
                data_path = osp.join(dataset.ROOT, scene_id)
                num_files = len([name for name in os.listdir(data_path) if "color" in name])
                view_count = len([f"{i:06d}" for i in range(num_files)][::dataset.kf_every])

            # consider max_frames limit
            if dataset.max_frames is not None:
                actual_view_count = min(view_count, dataset.max_frames)
                print(f"  {scene_id}: {actual_view_count} views (original: {view_count}, limit: {dataset.max_frames})")
            else:
                print(f"  {scene_id}: {view_count} views")
    print("================================\n")
    # ====== print end ======

    # Distributed setup: accelerate assigns the device and process rank.
    accelerator = Accelerator()
    device = accelerator.device
    model_name = args.model_name
    # if model_name == "ours" or model_name == "cut3r":
    from dust3r.model import ARCroco3DStereo
    from eval.mv_recon.criterion import Regr3D_t_ScaleShiftInv, L21
    from dust3r.utils.geometry import geotrf
    from copy import deepcopy

    model = ARCroco3DStereo.from_pretrained(args.weights).to(device)
    model.config.model_update_type = args.model_update_type

    model.eval()
    # else:
    #     raise NotImplementedError
    os.makedirs(args.output_dir, exist_ok=True)

    criterion = Regr3D_t_ScaleShiftInv(L21, norm_mode=False, gt_scale=True)

    with torch.no_grad():
        for name_data, dataset in datasets_all.items():
            save_path = osp.join(args.output_dir, name_data)
            os.makedirs(save_path, exist_ok=True)
            # One log file per process; merged by the main process at the end.
            log_file = osp.join(save_path, f"logs_{accelerator.process_index}.txt")

            # Running sums of metrics (currently only printed per scene).
            acc_all = 0
            acc_all_med = 0
            comp_all = 0
            comp_all_med = 0
            nc1_all = 0
            nc1_all_med = 0
            nc2_all = 0
            nc2_all_med = 0

            fps_all = []
            time_all = []

            # Shard the dataset indices across processes.
            with accelerator.split_between_processes(list(range(len(dataset)))) as idxs:
                for data_idx in tqdm(idxs):
                    batch = default_collate([dataset[data_idx]])
                    # Keys that are not tensors (or should stay on CPU).
                    ignore_keys = set(
                        [
                            "depthmap",
                            "dataset",
                            "label",
                            "instance",
                            "idx",
                            "true_shape",
                            "rng",
                        ]
                    )
                    # Move every remaining tensor (or list of tensors) to the device.
                    for view in batch:
                        for name in view.keys():  # pseudo_focal
                            if name in ignore_keys:
                                continue
                            if isinstance(view[name], tuple) or isinstance(
                                view[name], list
                            ):
                                view[name] = [
                                    x.to(device, non_blocking=True) for x in view[name]
                                ]
                            else:
                                view[name] = view[name].to(device, non_blocking=True)

                    # if model_name == "ours" or model_name == "cut3r":
                    revisit = args.revisit
                    update = not args.freeze
                    if revisit > 1:
                        # repeat input for 'revisit' times
                        new_views = []
                        for r in range(revisit):
                            for i in range(len(batch)):
                                new_view = deepcopy(batch[i])
                                new_view["idx"] = [
                                    (r * len(batch) + i)
                                    for _ in range(len(batch[i]["idx"]))
                                ]
                                new_view["instance"] = [
                                    str(r * len(batch) + i)
                                    for _ in range(len(batch[i]["instance"]))
                                ]
                                if r > 0:
                                    # On revisit passes, optionally freeze the
                                    # recurrent state update.
                                    if not update:
                                        new_view["update"] = torch.zeros_like(
                                            batch[i]["update"]
                                        ).bool()
                                new_views.append(new_view)
                        batch = new_views
                    with torch.cuda.amp.autocast(enabled=False):
                        start = time.time()
                        output = model(batch)
                        # preds, batch = model.forward_recurrent_light(batch)
                        end = time.time()
                        preds, batch = output.ress, output.views
                    # Keep only the final pass when the sequence was revisited.
                    valid_length = len(preds) // revisit
                    preds = preds[-valid_length:]
                    batch = batch[-valid_length:]
                    fps = len(batch) / (end - start)
                    print(
                        f"Finished reconstruction for {name_data} {data_idx+1}/{len(dataset)}, FPS: {fps:.2f}"
                    )
                    # continue
                    fps_all.append(fps)
                    time_all.append(end - start)

                    # Evaluation
                    print(f"Evaluation for {name_data} {data_idx+1}/{len(dataset)}")
                    gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
                        criterion.get_all_pts3d_t(batch, preds)
                    )
                    pred_scale, gt_scale, pred_shift_z, gt_shift_z = (
                        monitoring["pred_scale"],
                        monitoring["gt_scale"],
                        monitoring["pred_shift_z"],
                        monitoring["gt_shift_z"],
                    )

                    # Pose of the first camera, used as the common world frame.
                    in_camera1 = None
                    pts_all = []
                    pts_gt_all = []
                    images_all = []
                    masks_all = []
                    conf_all = []

                    for j, view in enumerate(batch):
                        if in_camera1 is None:
                            in_camera1 = view["camera_pose"][0].cpu()

                        image = view["img"].permute(0, 2, 3, 1).cpu().numpy()[0]
                        mask = view["valid_mask"].cpu().numpy()[0]

                        # pts = preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'].detach().cpu().numpy()[0]
                        pts = pred_pts[j].cpu().numpy()[0]
                        conf = preds[j]["conf"].cpu().data.numpy()[0]
                        # mask = mask & (conf > 1.8)

                        pts_gt = gt_pts[j].detach().cpu().numpy()[0]

                        # Center-crop everything to 224x224 before computing metrics.
                        H, W = image.shape[:2]
                        cx = W // 2
                        cy = H // 2
                        l, t = cx - 112, cy - 112
                        r, b = cx + 112, cy + 112
                        image = image[t:b, l:r]
                        mask = mask[t:b, l:r]
                        pts = pts[t:b, l:r]
                        pts_gt = pts_gt[t:b, l:r]

                        #### Align predicted 3D points to the ground truth
                        pts[..., -1] += gt_shift_z.cpu().numpy().item()
                        pts = geotrf(in_camera1, pts)

                        pts_gt[..., -1] += gt_shift_z.cpu().numpy().item()
                        pts_gt = geotrf(in_camera1, pts_gt)

                        # Images are in [-1, 1]; map to [0, 1] for point colors.
                        images_all.append((image[None, ...] + 1.0) / 2.0)
                        pts_all.append(pts[None, ...])
                        pts_gt_all.append(pts_gt[None, ...])
                        masks_all.append(mask[None, ...])
                        conf_all.append(conf[None, ...])

                    images_all = np.concatenate(images_all, axis=0)
                    pts_all = np.concatenate(pts_all, axis=0)
                    pts_gt_all = np.concatenate(pts_gt_all, axis=0)
                    masks_all = np.concatenate(masks_all, axis=0)

                    # `view` is the last view of the loop above; its label
                    # identifies the scene (e.g. "scene/seq").
                    scene_id = view["label"][0].rsplit("/", 1)[0]

                    save_params = {}

                    save_params["images_all"] = images_all
                    save_params["pts_all"] = pts_all
                    save_params["pts_gt_all"] = pts_gt_all
                    save_params["masks_all"] = masks_all

                    np.save(
                        os.path.join(save_path, f"{scene_id.replace('/', '_')}.npy"),
                        save_params,
                    )

                    # ICP correspondence threshold; DTU scenes use a different
                    # metric scale than the indoor datasets.
                    if "DTU" in name_data:
                        threshold = 100
                    else:
                        threshold = 0.1

                    pts_all_masked = pts_all[masks_all > 0]
                    pts_gt_all_masked = pts_gt_all[masks_all > 0]
                    images_all_masked = images_all[masks_all > 0]

                    pcd = o3d.geometry.PointCloud()
                    pcd.points = o3d.utility.Vector3dVector(pts_all_masked.reshape(-1, 3))
                    pcd.colors = o3d.utility.Vector3dVector(images_all_masked.reshape(-1, 3))
                    pcd_gt = o3d.geometry.PointCloud()
                    pcd_gt.points = o3d.utility.Vector3dVector(pts_gt_all_masked.reshape(-1, 3))
                    pcd_gt.colors = o3d.utility.Vector3dVector(images_all_masked.reshape(-1, 3))

                    # ====== voxel grid downsampling ======
                    if args.voxel_size > 0:
                        pcd = pcd.voxel_down_sample(voxel_size=args.voxel_size)
                        pcd_gt = pcd_gt.voxel_down_sample(voxel_size=args.voxel_size)
                    # ===========================

                    o3d.io.write_point_cloud(
                        os.path.join(
                            save_path, f"{scene_id.replace('/', '_')}-mask.ply"
                        ),
                        pcd,
                    )

                    o3d.io.write_point_cloud(
                        os.path.join(save_path, f"{scene_id.replace('/', '_')}-gt.ply"),
                        pcd_gt,
                    )

                    # Rigid point-to-point ICP to refine the pred->GT alignment.
                    trans_init = np.eye(4)

                    reg_p2p = o3d.pipelines.registration.registration_icp(
                        pcd,
                        pcd_gt,
                        threshold,
                        trans_init,
                        o3d.pipelines.registration.TransformationEstimationPointToPoint(),
                    )

                    transformation = reg_p2p.transformation

                    pcd = pcd.transform(transformation)
                    pcd.estimate_normals()
                    pcd_gt.estimate_normals()

                    gt_normal = np.asarray(pcd_gt.normals)
                    pred_normal = np.asarray(pcd.normals)

                    acc, acc_med, nc1, nc1_med = accuracy(
                        pcd_gt.points, pcd.points, gt_normal, pred_normal
                    )
                    comp, comp_med, nc2, nc2_med = completion(
                        pcd_gt.points, pcd.points, gt_normal, pred_normal
                    )
                    # NOTE: this exact line format is parsed later by `regex`
                    # (module-level) when aggregating logs — keep them in sync.
                    print(
                        f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}"
                    )
                    print(
                        f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}",
                        file=open(log_file, "a"),
                    )

                    acc_all += acc
                    comp_all += comp
                    nc1_all += nc1
                    nc2_all += nc2

                    acc_all_med += acc_med
                    comp_all_med += comp_med
                    nc1_all_med += nc1_med
                    nc2_all_med += nc2_med

                    # release cuda memory
                    torch.cuda.empty_cache()

            accelerator.wait_for_everyone()
            # Get depth from pcd and run TSDFusion
            if accelerator.is_main_process:
                to_write = ""
                # Copy the error log from each process to the main error log
                # (assumes at most 8 processes; stops at the first missing file).
                for i in range(8):
                    if not os.path.exists(osp.join(save_path, f"logs_{i}.txt")):
                        break
                    with open(osp.join(save_path, f"logs_{i}.txt"), "r") as f_sub:
                        to_write += f_sub.read()

                with open(osp.join(save_path, f"logs_all.txt"), "w") as f:
                    log_data = to_write
                    metrics = defaultdict(list)
                    # Re-parse the per-scene lines and average each metric.
                    for line in log_data.strip().split("\n"):
                        match = regex.match(line)
                        if match:
                            data = match.groupdict()
                            # Exclude 'scene_id' from metrics as it's an identifier
                            for key, value in data.items():
                                if key != "scene_id":
                                    metrics[key].append(float(value))
                            metrics["nc"].append(
                                (float(data["nc1"]) + float(data["nc2"])) / 2
                            )
                            metrics["nc_med"].append(
                                (float(data["nc1_med"]) + float(data["nc2_med"])) / 2
                            )
                    mean_metrics = {
                        metric: sum(values) / len(values)
                        for metric, values in metrics.items()
                    }

                    c_name = "mean"
                    print_str = f"{c_name.ljust(20)}: "
                    for m_name in mean_metrics:
                        print_num = np.mean(mean_metrics[m_name])
                        print_str = print_str + f"{m_name}: {print_num:.3f} | "
                    print_str = print_str + "\n"
                    f.write(to_write + print_str)


from collections import defaultdict
import re

# Regex matching the per-scene metric lines that main() prints/logs; used by
# the main process to aggregate the per-process log files into logs_all.txt.
# NOTE: defined after main() in the file, but names resolve at call time, so
# this works as long as the module is fully imported before main() runs.
pattern = r"""
    Idx:\s*(?P<scene_id>[^,]+),\s*
    Acc:\s*(?P<acc>[^,]+),\s*
    Comp:\s*(?P<comp>[^,]+),\s*
    NC1:\s*(?P<nc1>[^,]+),\s*
    NC2:\s*(?P<nc2>[^,]+)\s*-\s*
    Acc_med:\s*(?P<acc_med>[^,]+),\s*
    Compc_med:\s*(?P<comp_med>[^,]+),\s*
    NC1c_med:\s*(?P<nc1_med>[^,]+),\s*
    NC2c_med:\s*(?P<nc2_med>[^,]+)
"""

regex = re.compile(pattern, re.VERBOSE)


if __name__ == "__main__":
    # Parse CLI arguments and launch the evaluation.
    main(get_args_parser().parse_args())

```

## /eval/mv_recon/run.sh

```sh path="/eval/mv_recon/run.sh" 
#!/bin/bash

# Evaluate multi-view reconstruction on 7-Scenes via eval/mv_recon/launch.py.
# Edit the model_names array / max_frames loop below to sweep configurations.

set -e

workdir='.'
model_names=('ttt3r') # ttt3r cut3r
ckpt_name='cut3r_512_dpt_4_64'
model_weights="${workdir}/src/${ckpt_name}.pth"

for model_name in "${model_names[@]}"; do

# for max_frames in 50 100 150 200 250 300 350 400
for max_frames in 200

do
    output_dir="${workdir}/eval_results/video_recon/7scenes_${max_frames}/${model_name}"
    echo "$output_dir"
    # Single process; NCCL_TIMEOUT raised for long sequences.
    NCCL_TIMEOUT=360000 accelerate launch --num_processes 1 --main_process_port 29502 eval/mv_recon/launch.py \
        --weights "$model_weights" \
        --output_dir "$output_dir" \
        --model_name "$model_name" \
        --model_update_type "$model_name" \
        --max_frames $max_frames \

done
done
```

## /eval/mv_recon/utils.py

```py path="/eval/mv_recon/utils.py" 
import numpy as np
from scipy.spatial import cKDTree as KDTree
import torch


def completion_ratio(gt_points, rec_points, dist_th=0.05):
    """Fraction of GT points within ``dist_th`` of the reconstruction.

    Args:
        gt_points: (N, 3) ground-truth point coordinates.
        rec_points: (M, 3) reconstructed point coordinates.
        dist_th: distance threshold below which a GT point counts as covered.

    Returns:
        Scalar in [0, 1]: mean of the per-point threshold indicator.
    """
    rec_points_kd_tree = KDTree(rec_points)
    # workers=-1 parallelizes the query, matching accuracy()/completion().
    # (The original also misnamed this tree "gen_points_kd_tree" despite it
    # being built from the reconstruction.)
    distances, _ = rec_points_kd_tree.query(gt_points, workers=-1)
    comp_ratio = np.mean((distances < dist_th).astype(np.float32))
    return comp_ratio


def accuracy(gt_points, rec_points, gt_normals=None, rec_normals=None):
    """Mean/median distance from each reconstructed point to its nearest GT point.

    When both normal sets are given, also returns the mean/median absolute
    dot product between matched normals (normal consistency).
    """
    tree = KDTree(gt_points)
    dists, nn_idx = tree.query(rec_points, workers=-1)
    mean_dist = np.mean(dists)
    median_dist = np.median(dists)

    if gt_normals is None or rec_normals is None:
        return mean_dist, median_dist

    # abs() makes the agreement orientation-agnostic.
    dots = np.abs(np.sum(gt_normals[nn_idx] * rec_normals, axis=-1))
    return mean_dist, median_dist, np.mean(dots), np.median(dots)


def completion(gt_points, rec_points, gt_normals=None, rec_normals=None):
    """Mean/median distance from each GT point to its nearest reconstructed point.

    When both normal sets are given, also returns the mean/median absolute
    dot product between matched normals (normal consistency).
    """
    # NB: the tree is built over the *reconstruction* and queried with GT.
    tree = KDTree(rec_points)
    dists, nn_idx = tree.query(gt_points, workers=-1)
    mean_dist = np.mean(dists)
    median_dist = np.median(dists)

    if gt_normals is None or rec_normals is None:
        return mean_dist, median_dist

    # abs() makes the agreement orientation-agnostic.
    dots = np.abs(np.sum(gt_normals * rec_normals[nn_idx], axis=-1))
    return mean_dist, median_dist, np.mean(dots), np.median(dots)


def compute_iou(pred_vox, target_vox):
    """Intersection-over-union between two voxel grids, compared by voxel index.

    Args:
        pred_vox / target_vox: objects exposing ``get_voxels()`` returning
            voxels with an integer ``grid_index`` (e.g. open3d VoxelGrid —
            assumed; confirm against callers).

    Returns:
        IoU in [0, 1]. Two empty grids return 1.0 (identical by convention);
        the original raised ZeroDivisionError in that case. The original's
        np.round(idx, 4) was also a no-op on integer grid indices and has
        been dropped.
    """
    pred_filled = {
        tuple(np.asarray(v.grid_index).tolist()) for v in pred_vox.get_voxels()
    }
    target_filled = {
        tuple(np.asarray(v.grid_index).tolist()) for v in target_vox.get_voxels()
    }

    union = pred_filled | target_filled
    if not union:
        # Both grids empty: perfect (vacuous) overlap instead of a crash.
        return 1.0
    intersection = pred_filled & target_filled
    return len(intersection) / len(union)

```

## /eval/relpose/evo_utils.py

```py path="/eval/relpose/evo_utils.py" 
import os
import re
from copy import deepcopy
from pathlib import Path

import evo.main_ape as main_ape
import evo.main_rpe as main_rpe
import matplotlib.pyplot as plt
import numpy as np
from evo.core import sync
from evo.core.metrics import PoseRelation, Unit
from evo.core.trajectory import PosePath3D, PoseTrajectory3D
from evo.tools import file_interface, plot
from scipy.spatial.transform import Rotation
from evo.core import metrics
import json


def sintel_cam_read(filename):
    """Read camera data, return (M,N) tuple.

    M is the 3x3 intrinsic matrix, N is the 3x4 extrinsic matrix, so that

    x = M*N*X,
    with x being a point in homogeneous image pixel coordinates, X being a
    point in homogeneous world coordinates.

    Raises:
        AssertionError: the file does not start with the Sintel float tag.
    """
    TAG_FLOAT = 202021.25  # Sintel sanity-check tag: first float32 of the file

    # Context manager ensures the handle is closed even if parsing fails
    # (the original opened the file and never closed it).
    with open(filename, "rb") as f:
        check = np.fromfile(f, dtype=np.float32, count=1)[0]
        assert (
            check == TAG_FLOAT
        ), " cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? ".format(
            TAG_FLOAT, check
        )
        M = np.fromfile(f, dtype="float64", count=9).reshape((3, 3))
        N = np.fromfile(f, dtype="float64", count=12).reshape((3, 4))
    return M, N


def load_replica_traj(gt_file):
    """Load a Replica trajectory file and return it in TUM layout.

    Returns:
        (traj_tum, timestamps_mat): (N, 7) xyz + wxyz-quaternion rows and the
        (N,) array of synthetic frame-index timestamps.
    """
    raw = np.loadtxt(gt_file)
    assert raw.shape[1] in (12, 16)

    # First 12 values of each row are a flattened 3x4 cam-to-world matrix;
    # append the homogeneous bottom row to get 4x4 SE(3) poses.
    bottom = np.array([[0.0, 0.0, 0.0, 1.0]])
    poses = [np.concatenate([row[:12].reshape(3, 4), bottom], axis=0) for row in raw]

    timestamps_mat = np.arange(raw.shape[0]).astype(float)
    pose_path = PosePath3D(poses_se3=poses)
    traj = PoseTrajectory3D(poses_se3=pose_path.poses_se3, timestamps=timestamps_mat)

    xyz = traj.positions_xyz
    # evo reports quaternions scalar-first (wxyz) already; use
    # np.roll(..., -1, axis=1) here instead if a scalar-last source appears.
    quat = traj.orientations_quat_wxyz

    return np.column_stack((xyz, quat)), timestamps_mat


def load_sintel_traj(gt_file):  # './data/sintel/training/camdata_left/alley_2'
    """Load Sintel GT camera poses (a directory of .cam files) in TUM layout.

    Returns:
        (tum_gt_poses, tt): (N, 7) xyz + wxyz rows with positions centered at
        the origin, and the (N, 1) timestamps parsed from the file names.
    """
    # Refer to ParticleSfM
    cam_files = [
        os.path.join(gt_file, name)
        for name in sorted(os.listdir(gt_file))
        if name.endswith(".cam")
    ]
    tstamps = [float(path.split("/")[-1][:-4].split("_")[-1]) for path in cam_files]

    bottom = np.array([[0, 0, 0, 1]])
    tum_rows = []
    for path in cam_files:
        extrinsic = sintel_cam_read(path)[1]  # [1] selects the 3x4 extrinsic
        w2c = np.concatenate([extrinsic, bottom], 0)
        c2w = np.linalg.inv(w2c)  # world2cam -> cam2world
        xyz = c2w[:3, -1]
        xyzw = Rotation.from_matrix(c2w[:3, :3]).as_quat()  # scalar-last (scipy)
        wxyz = np.array([xyzw[3], xyzw[0], xyzw[1], xyzw[2]])  # -> scalar-first
        tum_rows.append(np.concatenate([xyz, wxyz], 0))

    tum_gt_poses = np.stack(tum_rows, 0)
    # Center the trajectory: positions get zero mean.
    tum_gt_poses[:, :3] -= np.mean(tum_gt_poses[:, :3], 0, keepdims=True)
    tt = np.expand_dims(np.stack(tstamps, 0), -1)
    return tum_gt_poses, tt

def load_iphone_traj(gt_file):
    """Load iPhone GT poses (a directory of per-frame JSON files) in TUM layout.

    Each JSON file must contain a "w2c" 4x4 world-to-camera matrix.

    Returns:
        (tum_gt_poses, tt): (N, 7) xyz + wxyz rows with positions centered at
        the origin, and (N, 1) frame-index timestamps.
    """
    pose_files = [
        os.path.join(gt_file, name)
        for name in sorted(os.listdir(gt_file))
        if name.endswith(".json")
    ]

    tum_rows = []
    for pose_file in pose_files:
        with open(pose_file, "r") as fh:
            w2c = np.array(json.load(fh)["w2c"])
        c2w = np.linalg.inv(w2c)  # world2cam -> cam2world
        xyz = c2w[:3, -1]
        xyzw = Rotation.from_matrix(c2w[:3, :3]).as_quat()  # scalar-last (scipy)
        wxyz = np.array([xyzw[3], xyzw[0], xyzw[1], xyzw[2]])  # -> scalar-first
        tum_rows.append(np.concatenate([xyz, wxyz], 0))

    tum_gt_poses = np.stack(tum_rows, 0)
    # Center the trajectory: positions get zero mean.
    tum_gt_poses[:, :3] -= np.mean(tum_gt_poses[:, :3], 0, keepdims=True)

    # The frame index stands in for a real timestamp.
    tt = np.expand_dims(np.arange(tum_gt_poses.shape[0]).astype(float), -1)
    return tum_gt_poses, tt

def load_traj(gt_traj_file, traj_format="sintel", skip=0, stride=1, num_frames=None):
    """Read a trajectory in one of several formats; return it in TUM-RGBD layout.

    Returns:
        traj_tum (N, 7): camera-to-world poses as (x, y, z, qw, qx, qy, qz)
        timestamps_mat (N, 1): timestamps
    """
    if traj_format == "replica":
        traj_tum, timestamps_mat = load_replica_traj(gt_traj_file)
    elif traj_format == "sintel":
        traj_tum, timestamps_mat = load_sintel_traj(gt_traj_file)
    elif traj_format == "iphone":
        traj_tum, timestamps_mat = load_iphone_traj(gt_traj_file)
    elif traj_format in ("tum", "tartanair"):
        traj = file_interface.read_tum_trajectory_file(gt_traj_file)
        traj_tum = np.column_stack(
            (traj.positions_xyz, traj.orientations_quat_wxyz)
        )
        timestamps_mat = traj.timestamps
    else:
        raise NotImplementedError

    # Subsample, then optionally truncate to the first num_frames poses.
    traj_tum = traj_tum[skip::stride]
    timestamps_mat = timestamps_mat[skip::stride]
    if num_frames is not None:
        traj_tum = traj_tum[:num_frames]
        timestamps_mat = timestamps_mat[:num_frames]
    return traj_tum, timestamps_mat


def update_timestamps(gt_file, traj_format, skip=0, stride=1):
    """Load the timestamps accompanying a GT pose file, subsampled like the poses.

    The companion file is derived from the pose file name. Returns None for
    formats without a known companion timestamp file.
    """
    if traj_format == "tum":
        time_file = gt_file.replace("groundtruth.txt", "rgb.txt")
    elif traj_format == "tartanair":
        time_file = gt_file.replace("gt_pose.txt", "times.txt")
    else:
        return None
    return load_timestamps(time_file, traj_format)[skip::stride]


def load_timestamps(time_file, traj_format="replica"):
    """Parse per-frame timestamps from a TUM/TartanAir-style text file.

    Each non-comment line is expected to start with a float timestamp;
    lines beginning with '#' are skipped.

    Args:
        time_file: path to the timestamp file.
        traj_format: only "tum" and "tartanair" are supported; any other
            value returns None (explicit now; previously an implicit
            fall-through).

    Returns:
        list[float] of timestamps, or None for unsupported formats.
    """
    if traj_format not in ("tum", "tartanair"):
        return None
    # "r" instead of the original "r+": the file is only read, never written.
    with open(time_file, "r") as f:
        lines = f.readlines()
    return [float(x.split(" ")[0]) for x in lines if not x.startswith("#")]


def make_traj(args) -> PoseTrajectory3D:
    """Coerce a (poses, timestamps) pair into an evo PoseTrajectory3D.

    A tuple/list is interpreted as (poses (N,7) in xyz+wxyz order,
    timestamps (N,)); an existing PoseTrajectory3D is deep-copied so the
    caller's object is never mutated downstream.
    """
    if isinstance(args, (tuple, list)):
        poses, stamps = args
        return PoseTrajectory3D(
            positions_xyz=poses[:, :3],
            orientations_quat_wxyz=poses[:, 3:],
            timestamps=stamps,
        )
    assert isinstance(args, PoseTrajectory3D), type(args)
    return deepcopy(args)


def eval_metrics(pred_traj, gt_traj=None, seq="", filename="", sample_stride=1):
    """Score a predicted trajectory against ground truth with evo.

    Computes Sim(3)-aligned ATE plus frame-to-frame RPE (translation and
    rotation, delta = 1 frame, all pairs) and writes the full evo reports
    to `filename`.

    Args:
        pred_traj: (poses (N,7), timestamps (N,)) pair or PoseTrajectory3D.
        gt_traj: ground truth in the same format, or None.
        seq: sequence name written into the report header.
        filename: output path for the text report.
        sample_stride: optional subsampling applied to both trajectories.

    Returns:
        (ate, rpe_trans, rpe_rot) rmse values.
    """
    if sample_stride > 1:
        # Build new pairs instead of assigning into the caller's lists:
        # the original mutated `pred_traj` (and rebuilt `gt_traj`) in place.
        pred_traj = [pred_traj[0][::sample_stride], pred_traj[1][::sample_stride]]
        if gt_traj is not None:
            gt_traj = [gt_traj[0][::sample_stride], gt_traj[1][::sample_stride]]

    pred_traj = make_traj(pred_traj)

    if gt_traj is not None:
        gt_traj = make_traj(gt_traj)

        # Predicted timestamps are synthetic frame indices; adopt the GT
        # stamps when lengths agree so evo's association is exact.
        if pred_traj.timestamps.shape[0] == gt_traj.timestamps.shape[0]:
            pred_traj.timestamps = gt_traj.timestamps
        else:
            print(pred_traj.timestamps.shape[0], gt_traj.timestamps.shape[0])

        gt_traj, pred_traj = sync.associate_trajectories(gt_traj, pred_traj)

    # ATE (absolute trajectory error, Sim(3)-aligned)
    traj_ref = gt_traj
    traj_est = pred_traj

    ate_result = main_ape.ape(
        traj_ref,
        traj_est,
        est_name="traj",
        pose_relation=PoseRelation.translation_part,
        align=True,
        correct_scale=True,
    )
    ate = ate_result.stats["rmse"]

    # RPE rotation and translation; the two original loops over the same
    # delta_list are merged — same calls, same results.
    delta_list = [1]
    rpe_rots, rpe_transs = [], []
    for delta in delta_list:
        rpe_rots_result = main_rpe.rpe(
            traj_ref,
            traj_est,
            est_name="traj",
            pose_relation=PoseRelation.rotation_angle_deg,
            align=True,
            correct_scale=True,
            delta=delta,
            delta_unit=Unit.frames,
            rel_delta_tol=0.01,
            all_pairs=True,
        )
        rpe_rots.append(rpe_rots_result.stats["rmse"])

        rpe_transs_result = main_rpe.rpe(
            traj_ref,
            traj_est,
            est_name="traj",
            pose_relation=PoseRelation.translation_part,
            align=True,
            correct_scale=True,
            delta=delta,
            delta_unit=Unit.frames,
            rel_delta_tol=0.01,
            all_pairs=True,
        )
        rpe_transs.append(rpe_transs_result.stats["rmse"])

    rpe_trans, rpe_rot = np.mean(rpe_transs), np.mean(rpe_rots)
    with open(filename, "w") as f:
        f.write(f"Seq: {seq} \n\n")
        f.write(f"{ate_result}")
        f.write(f"{rpe_rots_result}")
        f.write(f"{rpe_transs_result}")

    # The original printed a literal "(unknown)" placeholder; report the path.
    print(f"Save results to {filename}")
    return ate, rpe_trans, rpe_rot


def eval_metrics_first_pose_align_last_pose(
    pred_traj, gt_traj=None, seq="", filename="", figpath="", sample_stride=1
):
    """Rigidly align prediction to GT at the first pose, then score only the last pose.

    Unlike eval_metrics(), no Umeyama / scale alignment is performed: the
    SE(3) transform mapping the first predicted pose onto the first GT pose
    is applied to the whole predicted trajectory, so the ATE of the final
    pose measures drift accumulated over the sequence.  A trajectory plot
    is exported to `figpath`.

    Args:
        pred_traj: (poses (N,7), timestamps (N,)) pair or PoseTrajectory3D.
        gt_traj: ground truth in the same format, or None.
        seq: sequence name for the report and plot title.
        filename: output path for the text report.
        figpath: output path for the trajectory plot.
        sample_stride: optional subsampling applied to both trajectories.

    Returns:
        ATE rmse of the last pose (float).
    """
    if sample_stride > 1:
        # Rebuild the pair instead of mutating the caller's list in place.
        pred_traj = [pred_traj[0][::sample_stride], pred_traj[1][::sample_stride]]
        if gt_traj is not None:
            gt_traj = [gt_traj[0][::sample_stride], gt_traj[1][::sample_stride]]
    pred_traj = make_traj(pred_traj)
    if gt_traj is not None:
        gt_traj = make_traj(gt_traj)

        if pred_traj.timestamps.shape[0] == gt_traj.timestamps.shape[0]:
            pred_traj.timestamps = gt_traj.timestamps
        else:
            print(
                "Different number of poses:",
                pred_traj.timestamps.shape[0],
                gt_traj.timestamps.shape[0],
            )

        gt_traj, pred_traj = sync.associate_trajectories(gt_traj, pred_traj)

    if gt_traj is not None and pred_traj is not None:
        if len(gt_traj.poses_se3) > 0 and len(pred_traj.poses_se3) > 0:
            first_gt_pose = gt_traj.poses_se3[0]
            first_pred_pose = pred_traj.poses_se3[0]
            # T = (first_gt_pose) * inv(first_pred_pose)
            T = first_gt_pose @ np.linalg.inv(first_pred_pose)

            # Apply T to every predicted pose
            aligned_pred_poses = []
            for pose in pred_traj.poses_se3:
                aligned_pred_poses.append(T @ pose)
            aligned_pred_traj = PoseTrajectory3D(
                poses_se3=aligned_pred_poses,
                timestamps=np.array(pred_traj.timestamps),
            )
            pred_traj = aligned_pred_traj
        plot_trajectory(
            pred_traj,
            gt_traj,
            title=seq,
            filename=figpath,
            align=False,
            correct_scale=False,
        )

    # Reduce both trajectories to their final pose only before scoring.
    if gt_traj is not None and len(gt_traj.poses_se3) > 0:
        gt_traj = PoseTrajectory3D(
            poses_se3=[gt_traj.poses_se3[-1]], timestamps=[gt_traj.timestamps[-1]]
        )
    if pred_traj is not None and len(pred_traj.poses_se3) > 0:
        pred_traj = PoseTrajectory3D(
            poses_se3=[pred_traj.poses_se3[-1]], timestamps=[pred_traj.timestamps[-1]]
        )

    ate_result = main_ape.ape(
        gt_traj,
        pred_traj,
        est_name="traj",
        pose_relation=PoseRelation.translation_part,
        align=False,  # <-- important: first-pose alignment already applied
        correct_scale=False,  # <-- important
    )
    ate = ate_result.stats["rmse"]
    with open(filename, "w") as f:
        f.write(f"Seq: {seq}\n\n")
        f.write(f"{ate_result}")

    # The original printed a literal "(unknown)" placeholder; report the path.
    print(f"Save results to {filename}")

    return ate


def best_plotmode(traj):
    """Pick the evo PlotMode spanning the two highest-variance trajectory axes."""
    axis_variance = np.var(traj.positions_xyz, axis=0)
    order = np.argsort(axis_variance)  # ascending: order[1], order[2] are the two largest
    mode_name = "xyz"[order[2]] + "xyz"[order[1]]
    return getattr(plot.PlotMode, mode_name)


def plot_trajectory(
    pred_traj, gt_traj=None, title="", filename="", align=True, correct_scale=True
):
    """Plot predicted (and optionally ground-truth) trajectories and export the figure.

    Args:
        pred_traj: (poses, timestamps) pair or PoseTrajectory3D.
        gt_traj: optional ground truth in the same format.
        title: axis title (usually the sequence name).
        filename: export path prefix passed to evo's PlotCollection.
        align: if True, align the prediction to GT before plotting.
        correct_scale: also correct scale during alignment.
    """
    pred_traj = make_traj(pred_traj)

    if gt_traj is not None:
        gt_traj = make_traj(gt_traj)
        # Adopt GT timestamps when lengths match so association is exact.
        if pred_traj.timestamps.shape[0] == gt_traj.timestamps.shape[0]:
            pred_traj.timestamps = gt_traj.timestamps
        else:
            print("WARNING", pred_traj.timestamps.shape[0], gt_traj.timestamps.shape[0])

        gt_traj, pred_traj = sync.associate_trajectories(gt_traj, pred_traj)

        if align:
            pred_traj.align(gt_traj, correct_scale=correct_scale)

    plot_collection = plot.PlotCollection("PlotCol")
    fig = plt.figure(figsize=(8, 8))
    # Project onto the two highest-variance axes for a readable 2D view.
    plot_mode = best_plotmode(gt_traj if (gt_traj is not None) else pred_traj)
    ax = plot.prepare_axis(fig, plot_mode)
    ax.set_title(title)
    if gt_traj is not None:
        plot.traj(ax, plot_mode, gt_traj, "--", "gray", "Ground Truth")
    plot.traj(ax, plot_mode, pred_traj, "-", "blue", "Predicted")
    plot_collection.add_figure("traj_error", fig)
    plot_collection.export(filename, confirm_overwrite=False)
    plt.close(fig=fig)
    print(f"Saved trajectory to {filename.replace('.png','')}_traj_error.png")


def save_trajectory_tum_format(traj, filename):
    """Write a trajectory to disk, one pose per line: `t x y z qw qx qy qz`.

    NOTE(review): evo stores quaternions scalar-first (wxyz) while the strict
    TUM-RGBD convention is qx qy qz qw.  The original wrote the wxyz order
    unchanged — its `[[0,1,2,3]]` fancy index was an identity no-op, removed
    here — and the companion loaders in this file read the same order back,
    so the round-trip is self-consistent.  Behavior preserved.

    Args:
        traj: (poses, timestamps) pair or PoseTrajectory3D.
        filename: output text file path.
    """
    traj = make_traj(traj)
    tostr = lambda a: " ".join(map(str, a))
    with Path(filename).open("w") as f:
        for i in range(traj.num_poses):
            f.write(
                f"{traj.timestamps[i]} {tostr(traj.positions_xyz[i])} {tostr(traj.orientations_quat_wxyz[i])}\n"
            )
    # The original printed a literal "(unknown)" placeholder; report the path.
    print(f"Saved trajectory to {filename}")


def extract_metrics(file_path):
    """Parse an evo result dump and return (ate, rpe_trans, rpe_rot) rmse values.

    Missing sections yield 0.0 so partially-written reports still parse.
    """
    with open(file_path, "r") as fh:
        text = fh.read()

    patterns = (
        r"APE w.r.t. translation part \(m\).*?rmse\s+([0-9.]+)",
        r"RPE w.r.t. translation part \(m\).*?rmse\s+([0-9.]+)",
        r"RPE w.r.t. rotation angle in degrees \(deg\).*?rmse\s+([0-9.]+)",
    )
    values = []
    for pattern in patterns:
        match = re.search(pattern, text, re.DOTALL)
        values.append(float(match.group(1)) if match else 0.0)
    return tuple(values)


def process_directory(directory):
    """Collect (seq_name, ate, rpe_trans, rpe_rot) from every *_metric.txt under `directory`.

    Files are visited in sorted order within each directory so the result
    list is deterministic across runs.
    """
    results = []
    for root, _, files in os.walk(directory):
        # os.walk always yields a (possibly empty) list of files; the
        # original `if files is not None` guard was a no-op and is dropped.
        for file in sorted(files):
            if file.endswith("_metric.txt"):
                file_path = os.path.join(root, file)
                # Reports are written as "<seq>_eval_metric.txt" by the callers
                # of eval_metrics(); strip that suffix to recover the seq name.
                seq_name = file.replace("_eval_metric.txt", "")
                ate, rpe_trans, rpe_rot = extract_metrics(file_path)
                results.append((seq_name, ate, rpe_trans, rpe_rot))

    return results


def calculate_averages(results):
    """Average ATE / RPE-trans / RPE-rot over (seq, ate, rpe_trans, rpe_rot) tuples.

    Returns (0.0, 0.0, 0.0) for an empty result list.
    """
    if not results:
        return 0.0, 0.0, 0.0

    count = len(results)
    avg_ate = sum(entry[1] for entry in results) / count
    avg_rpe_trans = sum(entry[2] for entry in results) / count
    avg_rpe_rot = sum(entry[3] for entry in results) / count
    return avg_ate, avg_rpe_trans, avg_rpe_rot

```

## /eval/relpose/launch.py

```py path="/eval/relpose/launch.py" 
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import math
import cv2
import numpy as np
import torch
import argparse

from copy import deepcopy
from eval.relpose.metadata import dataset_metadata
from eval.relpose.utils import *

from accelerate import PartialState
from add_ckpt_path import add_path_to_dust3r

from tqdm import tqdm
import time

def get_args_parser():
    """Build the CLI parser for relative-pose evaluation.

    Returns:
        argparse.ArgumentParser configured with model / dataset / evaluation options.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--weights",
        type=str,
        help="path to the model weights",
        default="",
    )

    parser.add_argument("--device", type=str, default="cuda", help="pytorch device")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="value for outdir",
    )
    # NOTE(review): argparse's type=bool is a known pitfall -- any non-empty
    # string (including "False") parses as True.  Kept for CLI compatibility;
    # the __main__ entry point overrides args.no_crop unconditionally anyway.
    parser.add_argument(
        "--no_crop", type=bool, default=True, help="whether to crop input data"
    )

    parser.add_argument(
        "--eval_dataset",
        type=str,
        default="sintel",
        choices=list(dataset_metadata.keys()),
    )
    # Fixed: the default was the string "224"; argparse only converts string
    # defaults lazily, so use a real int for a consistently-typed args.size.
    parser.add_argument("--size", type=int, default=224)

    parser.add_argument(
        "--model_update_type",
        type=str,
        default="cut3r",
        help="model type for state update strategy: cut3r or ttt3r",
    )

    parser.add_argument(
        "--pose_eval_stride", default=1, type=int, help="stride for pose evaluation"
    )
    parser.add_argument("--shuffle", action="store_true", default=False)
    parser.add_argument(
        "--full_seq",
        action="store_true",
        default=False,
        help="use full sequence for pose evaluation",
    )
    parser.add_argument(
        "--seq_list",
        nargs="+",
        default=None,
        help="list of sequences for pose evaluation",
    )

    # revisit > 1 feeds the sequence through the model multiple times;
    # --freeze_state stops state updates on the repeat passes.
    parser.add_argument("--revisit", type=int, default=1)
    parser.add_argument("--freeze_state", action="store_true", default=False)
    parser.add_argument("--solve_pose", action="store_true", default=False)
    return parser


def eval_pose_estimation(args, model, save_dir=None):
    """Resolve dataset paths from the metadata table and run distributed pose evaluation.

    Returns:
        (ate_mean, rpe_trans_mean, rpe_rot_mean) averaged over all sequences.
    """
    metadata = dataset_metadata.get(args.eval_dataset)
    return eval_pose_estimation_dist(
        args,
        model,
        save_dir=save_dir,
        img_path=metadata["img_path"],
        mask_path=metadata["mask_path"],
    )


def eval_pose_estimation_dist(args, model, img_path, save_dir=None, mask_path=None):
    """Distributed camera-pose evaluation over every sequence of a dataset.

    Sequences are sharded across accelerate processes; each process runs
    the recurrent model over its sequences, saves the predicted trajectory
    and intrinsics, scores ATE / RPE against ground truth, and appends a
    per-process log file.  The main process finally merges the logs and the
    averaged metrics are recomputed from the reports on disk.

    Args:
        args: parsed namespace from get_args_parser().
        model: the already-constructed recurrent stereo model.
        img_path: dataset image root (from the metadata table).
        save_dir: output directory; defaults to args.output_dir.
        mask_path: optional mask root (dataset dependent).

    Returns:
        (avg_ate, avg_rpe_trans, avg_rpe_rot) averaged over every
        *_metric.txt report found under save_dir.
    """
    # NOTE(review): only inference_recurrent_lighter is used below; the other
    # two imports are unused here.
    from dust3r.inference import inference, inference_recurrent, inference_recurrent_lighter

    metadata = dataset_metadata.get(args.eval_dataset)
    anno_path = metadata.get("anno_path", None)

    # Resolve the sequence list: explicit --seq_list wins, then metadata
    # defaults, then (for full_seq datasets) every sub-directory of img_path.
    seq_list = args.seq_list
    if seq_list is None:
        if metadata.get("full_seq", False):
            args.full_seq = True
        else:
            seq_list = metadata.get("seq_list", [])
        if args.full_seq:
            seq_list = os.listdir(img_path)
            seq_list = [
                seq for seq in seq_list if os.path.isdir(os.path.join(img_path, seq))
            ]
        seq_list = sorted(seq_list)

    if save_dir is None:
        save_dir = args.output_dir

    distributed_state = PartialState()
    model.to(distributed_state.device)
    device = distributed_state.device

    # Shard the sequences across the accelerate processes.
    with distributed_state.split_between_processes(seq_list) as seqs:
        ate_list = []
        rpe_trans_list = []
        rpe_rot_list = []
        load_img_size = args.size
        error_log_path = f"{save_dir}/_error_log_{distributed_state.process_index}.txt"  # Unique log file per process
        # NOTE(review): `bug` is set when a sequence has no GT but is never
        # read afterwards — effectively dead state.
        bug = False
        for seq in tqdm(seqs):
            try:
                dir_path = metadata["dir_path_func"](img_path, seq)

                # Handle skip_condition
                skip_condition = metadata.get("skip_condition", None)
                if skip_condition is not None and skip_condition(save_dir, seq):
                    continue

                mask_path_seq_func = metadata.get(
                    "mask_path_seq_func", lambda mask_path, seq: None
                )
                mask_path_seq = mask_path_seq_func(mask_path, seq)

                filelist = [
                    os.path.join(dir_path, name) for name in os.listdir(dir_path)
                ]
                filelist.sort()
                filelist = filelist[:: args.pose_eval_stride]

                # Build the model's per-view input dicts (all views are images).
                views = prepare_input(
                    filelist,
                    [True for _ in filelist],
                    size=load_img_size,
                    crop=not args.no_crop,
                    revisit=args.revisit,
                    update=not args.freeze_state,
                )

                start = time.time()
                outputs, _ = inference_recurrent_lighter(views, model, device)
                end = time.time()
                fps = len(filelist) / (end - start)
                print(f"Finished pose estimation for {args.eval_dataset} {seq: <16}, FPS: {fps:.2f}")

                (
                    colors,
                    pts3ds_self,
                    pts3ds_other,
                    conf_self,
                    conf_other,
                    cam_dict,
                    pr_poses,
                ) = prepare_output(
                    outputs, revisit=args.revisit, solve_pose=args.solve_pose
                )

                # Persist predicted trajectory and intrinsics for this sequence.
                pred_traj = get_tum_poses(pr_poses)
                os.makedirs(f"{save_dir}/{seq}", exist_ok=True)
                save_tum_poses(pr_poses, f"{save_dir}/{seq}/pred_traj.txt")
                save_focals(cam_dict, f"{save_dir}/{seq}/pred_focal.txt")
                save_intrinsics(cam_dict, f"{save_dir}/{seq}/pred_intrinsics.txt")
                # save_depth_maps(pts3ds_self,f'{save_dir}/{seq}', conf_self=conf_self)
                # save_conf_maps(conf_self,f'{save_dir}/{seq}')
                # save_rgb_imgs(colors,f'{save_dir}/{seq}')

                gt_traj_file = metadata["gt_traj_func"](img_path, anno_path, seq)
                traj_format = metadata.get("traj_format", None)

                # Sintel uses load_traj's default format handling; other
                # datasets pass their declared format, or skip GT entirely.
                if args.eval_dataset == "sintel":
                    gt_traj = load_traj(
                        gt_traj_file=gt_traj_file, stride=args.pose_eval_stride
                    )
                elif traj_format is not None:
                    gt_traj = load_traj(
                        gt_traj_file=gt_traj_file,
                        traj_format=traj_format,
                        stride=args.pose_eval_stride,
                    )
                else:
                    gt_traj = None

                if gt_traj is not None:
                    ate, rpe_trans, rpe_rot = eval_metrics(
                        pred_traj,
                        gt_traj,
                        seq=seq,
                        filename=f"{save_dir}/{seq}_eval_metric.txt",
                    )
                    plot_trajectory(
                        pred_traj, gt_traj, title=seq, filename=f"{save_dir}/{seq}.png"
                    )
                else:
                    ate, rpe_trans, rpe_rot = 0, 0, 0
                    bug = True

                ate_list.append(ate)
                rpe_trans_list.append(rpe_trans)
                rpe_rot_list.append(rpe_rot)

                # Write to error log after each sequence
                with open(error_log_path, "a") as f:
                    f.write(
                        f"{args.eval_dataset}-{seq: <16} | ATE: {ate:.5f}, RPE trans: {rpe_trans:.5f}, RPE rot: {rpe_rot:.5f}\n"
                    )
                    f.write(f"{ate:.5f}\n")
                    f.write(f"{rpe_trans:.5f}\n")
                    f.write(f"{rpe_rot:.5f}\n")

            except Exception as e:
                # Known-benign failures are logged and skipped by matching the
                # exception message; anything unexpected is re-raised.
                if "out of memory" in str(e):
                    # Handle OOM
                    torch.cuda.empty_cache()  # Clear the CUDA memory
                    with open(error_log_path, "a") as f:
                        f.write(
                            f"OOM error in sequence {seq}, skipping this sequence.\n"
                        )
                    print(f"OOM error in sequence {seq}, skipping...")
                elif "Degenerate covariance rank" in str(
                    e
                ) or "Eigenvalues did not converge" in str(e):
                    # Handle Degenerate covariance rank exception and Eigenvalues did not converge exception
                    with open(error_log_path, "a") as f:
                        f.write(f"Exception in sequence {seq}: {str(e)}\n")
                    print(f"Traj evaluation error in sequence {seq}, skipping.")
                else:
                    raise e  # Rethrow if it's not an expected exception

    distributed_state.wait_for_everyone()

    # Every process re-reads all reports from disk, so the averages agree
    # across ranks regardless of how sequences were sharded.
    results = process_directory(save_dir)
    avg_ate, avg_rpe_trans, avg_rpe_rot = calculate_averages(results)

    # Write the averages to the error log (only on the main process)
    if distributed_state.is_main_process:
        with open(f"{save_dir}/_error_log.txt", "a") as f:
            # Copy the error log from each process to the main error log
            for i in range(distributed_state.num_processes):
                if not os.path.exists(f"{save_dir}/_error_log_{i}.txt"):
                    break
                with open(f"{save_dir}/_error_log_{i}.txt", "r") as f_sub:
                    f.write(f_sub.read())
            f.write(
                f"Average ATE: {avg_ate:.5f}, Average RPE trans: {avg_rpe_trans:.5f}, Average RPE rot: {avg_rpe_rot:.5f}\n"
            )

    return avg_ate, avg_rpe_trans, avg_rpe_rot


if __name__ == "__main__":
    args = get_args_parser()
    args = args.parse_args()
    # The dust3r package lives next to the checkpoint; make it importable
    # before pulling in any dust3r modules below.
    add_path_to_dust3r(args.weights)
    from dust3r.utils.image import load_images_for_eval as load_images
    from dust3r.post_process import estimate_focal_knowing_depth
    from dust3r.model import ARCroco3DStereo
    from dust3r.utils.camera import pose_encoding_to_camera
    from dust3r.utils.geometry import weighted_procrustes, geotrf

    # Hard-coded overrides: always evaluate per-sequence with cropping enabled,
    # regardless of what was passed on the command line.
    args.full_seq = False
    args.no_crop = False

    def recover_cam_params(pts3ds_self, pts3ds_other, conf_self, conf_other):
        """Recover per-view camera pose, focal length and principal point from point maps.

        Args:
            pts3ds_self: (B, H, W, 3) predicted points — presumably in each
                view's own camera frame (named after `pts3d_in_self_view`);
                TODO confirm against the model output docs.
            pts3ds_other: (B, H, W, 3) the same pixels — presumably in the
                reference/other view's frame.
            conf_self: (B, H, W) confidence for pts3ds_self.
            conf_other: (B, H, W) confidence for pts3ds_other.

        Returns:
            c2w: (B, 4, 4) transforms from weighted Procrustes (return_T=True).
            focal: per-view focal estimated with the Weiszfeld solver.
            pp: (B, 2) principal points, assumed at the image center.
        """
        B, H, W, _ = pts3ds_self.shape
        # Assume the principal point sits at the image center of every view.
        pp = (
            torch.tensor([W // 2, H // 2], device=pts3ds_self.device)
            .float()
            .repeat(B, 1)
            .reshape(B, 1, 2)
        )
        focal = estimate_focal_knowing_depth(pts3ds_self, pp, focal_mode="weiszfeld")

        # Flatten the spatial dims so Procrustes sees (B, N, 3) point clouds.
        pts3ds_self = pts3ds_self.reshape(B, -1, 3)
        pts3ds_other = pts3ds_other.reshape(B, -1, 3)
        conf_self = conf_self.reshape(B, -1)
        conf_other = conf_other.reshape(B, -1)
        # weighted procrustes: per-point weights are the product of log-confidences
        c2w = weighted_procrustes(
            pts3ds_self,
            pts3ds_other,
            torch.log(conf_self) * torch.log(conf_other),
            use_weights=True,
            return_T=True,
        )
        return c2w, focal, pp.reshape(B, 2)

    def prepare_input(
        img_paths,
        img_mask,
        size,
        raymaps=None,
        raymap_mask=None,
        revisit=1,
        update=True,
        crop=True,
    ):
        """Load images (and optional ray maps) into the per-view dict format the model consumes.

        Args:
            img_paths: ordered list of image file paths.
            img_mask: per-view bool list; True where the view is an image.
            size: target load size forwarded to load_images.
            raymaps: optional list of ray-map tensors interleaved with images.
            raymap_mask: per-view bool list; True where the view is a ray map.
            revisit: if > 1, the whole view sequence is repeated that many times.
            update: when False, repeat passes (r > 0) are marked update=False so
                the recurrent state is frozen on revisits.
            crop: whether load_images may crop the inputs.

        Returns:
            list of view dicts with keys img / ray_map / true_shape / idx /
            instance / camera_pose / img_mask / ray_mask / update / reset.
        """
        images = load_images(img_paths, size=size, crop=crop, verbose=False)
        views = []
        if raymaps is None and raymap_mask is None:
            # Image-only sequence: every view carries a real image and a
            # NaN-filled placeholder ray map.
            num_views = len(images)

            for i in range(num_views):
                view = {
                    "img": images[i]["img"],
                    "ray_map": torch.full(
                        (
                            images[i]["img"].shape[0],
                            6,
                            images[i]["img"].shape[-2],
                            images[i]["img"].shape[-1],
                        ),
                        torch.nan,
                    ),
                    "true_shape": torch.from_numpy(images[i]["true_shape"]),
                    "idx": i,
                    "instance": str(i),
                    "camera_pose": torch.from_numpy(
                        np.eye(4).astype(np.float32)
                    ).unsqueeze(0),
                    "img_mask": torch.tensor(True).unsqueeze(0),
                    "ray_mask": torch.tensor(False).unsqueeze(0),
                    "update": torch.tensor(True).unsqueeze(0),
                    "reset": torch.tensor(False).unsqueeze(0),
                }
                views.append(view)
        else:
            # Mixed sequence: interleave images and ray maps according to the
            # two masks; whichever modality is absent at a view is NaN-filled.
            num_views = len(images) + len(raymaps)
            assert len(img_mask) == len(raymap_mask) == num_views
            assert sum(img_mask) == len(images) and sum(raymap_mask) == len(raymaps)

            j = 0  # cursor into images
            k = 0  # cursor into raymaps
            for i in range(num_views):
                view = {
                    "img": (
                        images[j]["img"]
                        if img_mask[i]
                        else torch.full_like(images[0]["img"], torch.nan)
                    ),
                    "ray_map": (
                        raymaps[k]
                        if raymap_mask[i]
                        else torch.full_like(raymaps[0], torch.nan)
                    ),
                    "true_shape": (
                        torch.from_numpy(images[j]["true_shape"])
                        if img_mask[i]
                        else torch.from_numpy(np.int32([raymaps[k].shape[1:-1][::-1]]))
                    ),
                    "idx": i,
                    "instance": str(i),
                    "camera_pose": torch.from_numpy(
                        np.eye(4).astype(np.float32)
                    ).unsqueeze(0),
                    "img_mask": torch.tensor(img_mask[i]).unsqueeze(0),
                    "ray_mask": torch.tensor(raymap_mask[i]).unsqueeze(0),
                    "update": torch.tensor(img_mask[i]).unsqueeze(0),
                    "reset": torch.tensor(False).unsqueeze(0),
                }
                if img_mask[i]:
                    j += 1
                if raymap_mask[i]:
                    k += 1
                views.append(view)
            assert j == len(images) and k == len(raymaps)

        if revisit > 1:
            # repeat input for 'revisit' times; repeats get fresh idx/instance
            # and, when update=False, are marked to not touch model state.
            new_views = []
            for r in range(revisit):
                for i in range(len(views)):
                    new_view = deepcopy(views[i])
                    new_view["idx"] = r * len(views) + i
                    new_view["instance"] = str(r * len(views) + i)
                    if r > 0:
                        if not update:
                            new_view["update"] = torch.tensor(False).unsqueeze(0)
                    new_views.append(new_view)
            return new_views
        return views

    def prepare_output(outputs, revisit=1, solve_pose=False):
        """Convert raw inference outputs into point maps, confidences, camera dict and poses.

        Only the final pass of a revisited sequence is kept.  Poses come
        either from weighted Procrustes on the predicted point maps
        (solve_pose=True) or directly from the model's pose head.

        Args:
            outputs: dict with "pred" and "views" lists from the inference call.
            revisit: how many times the sequence was repeated at input time.
            solve_pose: use recover_cam_params instead of the pose head.

        Returns:
            (colors, pts3ds_self, pts3ds_other, conf_self, conf_other,
             cam_dict, pr_poses) where cam_dict holds numpy focal / pp arrays.
        """
        # Keep only the last revisit pass of predictions and views.
        valid_length = len(outputs["pred"]) // revisit
        outputs["pred"] = outputs["pred"][-valid_length:]
        outputs["views"] = outputs["views"][-valid_length:]

        if solve_pose:
            # Solve poses (and focal/pp) geometrically from the point maps.
            pts3ds_self = [
                output["pts3d_in_self_view"].cpu() for output in outputs["pred"]
            ]
            pts3ds_other = [
                output["pts3d_in_other_view"].cpu() for output in outputs["pred"]
            ]
            conf_self = [output["conf_self"].cpu() for output in outputs["pred"]]
            conf_other = [output["conf"].cpu() for output in outputs["pred"]]
            pr_poses, focal, pp = recover_cam_params(
                torch.cat(pts3ds_self, 0),
                torch.cat(pts3ds_other, 0),
                torch.cat(conf_self, 0),
                torch.cat(conf_other, 0),
            )
            pts3ds_self = torch.cat(pts3ds_self, 0)
        else:
            # Take poses from the model's pose head; estimate focal from the
            # self-view point maps with the image center as principal point.
            pts3ds_self = [
                output["pts3d_in_self_view"].cpu() for output in outputs["pred"]
            ]
            pts3ds_other = [
                output["pts3d_in_other_view"].cpu() for output in outputs["pred"]
            ]
            conf_self = [output["conf_self"].cpu() for output in outputs["pred"]]
            conf_other = [output["conf"].cpu() for output in outputs["pred"]]
            pts3ds_self = torch.cat(pts3ds_self, 0)
            pr_poses = [
                pose_encoding_to_camera(pred["camera_pose"].clone()).cpu()
                for pred in outputs["pred"]
            ]
            pr_poses = torch.cat(pr_poses, 0)

            B, H, W, _ = pts3ds_self.shape
            pp = (
                torch.tensor([W // 2, H // 2], device=pts3ds_self.device)
                .float()
                .repeat(B, 1)
                .reshape(B, 2)
            )
            focal = estimate_focal_knowing_depth(
                pts3ds_self, pp, focal_mode="weiszfeld"
            )

        # Map predicted rgb from [-1, 1] back to [0, 1].
        colors = [0.5 * (output["rgb"][0] + 1.0) for output in outputs["pred"]]
        cam_dict = {
            "focal": focal.cpu().numpy(),
            "pp": pp.cpu().numpy(),
        }
        return (
            colors,
            pts3ds_self,
            pts3ds_other,
            conf_self,
            conf_other,
            cam_dict,
            pr_poses,
        )

    # Load the pretrained recurrent stereo model from the checkpoint path.
    model = ARCroco3DStereo.from_pretrained(args.weights)

    # set model type (state-update strategy: "cut3r" or "ttt3r")
    model.config.model_update_type = args.model_update_type

    eval_pose_estimation(args, model, save_dir=args.output_dir)

```

## /eval/relpose/metadata.py

```py path="/eval/relpose/metadata.py" 
import os
import glob
from tqdm import tqdm

# Define the merged dataset metadata dictionary
def _scannet_style_entry(img_path):
    """Metadata entry shared by all ScanNet-style variants.

    The five original entries were byte-identical except for `img_path`:
    frames live under `<seq>/color_90` and poses in `<seq>/pose_90.txt`.
    """
    return {
        "img_path": img_path,
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, seq, "pose_90.txt"
        ),
        "traj_format": "replica",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,  # lambda save_dir, seq: os.path.exists(os.path.join(save_dir, seq)),
        "process_func": lambda args, img_path: process_scannet(args, img_path),
    }


dataset_metadata = {
    "davis": {
        "img_path": "data/davis/DAVIS/JPEGImages/480p",
        "mask_path": "data/davis/DAVIS/masked_images/480p",
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        "gt_traj_func": lambda img_path, anno_path, seq: None,
        "traj_format": None,
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: os.path.join(mask_path, seq),
        "skip_condition": None,
        "process_func": None,  # Not used in mono depth estimation
    },
    "kitti": {
        "img_path": "data/kitti/depth_selection/val_selection_cropped/image_gathered",  # Default path
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        "gt_traj_func": lambda img_path, anno_path, seq: None,
        "traj_format": None,
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": lambda args, img_path: process_kitti(args, img_path),
    },
    "bonn": {
        "img_path": "data/bonn/rgbd_bonn_dataset",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(
            img_path, f"rgbd_bonn_{seq}", "rgb_110"
        ),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, f"rgbd_bonn_{seq}", "groundtruth_110.txt"
        ),
        "traj_format": "tum",
        "seq_list": ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"],
        "full_seq": False,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": lambda args, img_path: process_bonn(args, img_path),
    },
    "nyu": {
        "img_path": "data/nyu-v2/val/nyu_images",
        "mask_path": None,
        "process_func": lambda args, img_path: process_nyu(args, img_path),
    },
    # ScanNet and its subsampled variants share one layout; only the image
    # root differs (see _scannet_style_entry above).
    "scannet": _scannet_style_entry("data/scannetv2"),
    "scannet-257": _scannet_style_entry("data/scannetv2_3_257"),
    "scannet-129": _scannet_style_entry("data/scannetv2_3_129"),
    "scannet-65": _scannet_style_entry("data/scannetv2_3_65"),
    "scannet-33": _scannet_style_entry("data/scannetv2_3_33"),
    "tum": {
        "img_path": "data/tum",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "rgb_90"),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, seq, "groundtruth_90.txt"
        ),
        "traj_format": "tum",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    "sintel": {
        "img_path": "data/sintel/training/final",
        "anno_path": "data/sintel/training/camdata_left",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(anno_path, seq),
        "traj_format": None,
        "seq_list": [
            "alley_2",
            "ambush_4",
            "ambush_5",
            "ambush_6",
            "cave_2",
            "cave_4",
            "market_2",
            "market_5",
            "market_6",
            "shaman_3",
            "sleeping_1",
            "sleeping_2",
            "temple_2",
            "temple_3",
        ],
        "full_seq": False,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": lambda args, img_path: process_sintel(args, img_path),
    },
}



scannet_numbers = [50, 90, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 950, 1000]
# Long-ScanNet variants, one entry per subsampled sequence length.
scannet_configs = {}
for num in scannet_numbers:
    # num=num binds the current loop value eagerly (late-binding closure pitfall).
    scannet_configs[f"scannet_s3_{num}"] = {
        "img_path": "data/long_scannet_s3",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq, num=num: os.path.join(img_path, seq, f"color_{num}"),
        "gt_traj_func": lambda img_path, anno_path, seq, num=num: os.path.join(
            img_path, seq, f"pose_{num}.txt"
        ),
        "traj_format": "replica",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": lambda args, img_path: process_scannet(args, img_path),
    }
# then update dataset_metadata
dataset_metadata.update(scannet_configs)

tum_numbers = [50, 100, 150, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
# Long-TUM variants, one entry per subsampled sequence length.
tum_configs = {}
for num in tum_numbers:
    # num=num binds the current loop value eagerly (late-binding closure pitfall).
    tum_configs[f"tum_s1_{num}"] = {
        "img_path": "data/long_tum_s1",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq, num=num: os.path.join(img_path, seq, f"rgb_{num}"),
        "gt_traj_func": lambda img_path, anno_path, seq, num=num: os.path.join(
            img_path, seq, f"groundtruth_{num}.txt"
        ),
        "traj_format": "tum",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    }
dataset_metadata.update(tum_configs)


# Define processing functions for each dataset
def process_kitti(args, img_path):
    """Yield one (image file list, save dir) pair per KITTI sequence folder."""
    for seq_dir in tqdm(sorted(glob.glob(f"{img_path}/*"))):
        seq_name = os.path.basename(seq_dir)
        # All frames of the sequence, in lexicographic (temporal) order.
        yield sorted(glob.glob(f"{seq_dir}/*.png")), f"{args.output_dir}/{seq_name}"


def process_bonn(args, img_path):
    """Yield (frame list, save dir) per Bonn sequence.

    With args.full_seq every sequence directory is used; otherwise a fixed
    dynamic-scene subset (or args.seq_list) is evaluated from rgb_110/.
    """
    if args.full_seq:
        for seq_dir in tqdm(sorted(glob.glob(f"{img_path}/*/"))):
            seq_name = os.path.basename(os.path.dirname(seq_dir))
            yield sorted(glob.glob(f"{seq_dir}/rgb/*.png")), f"{args.output_dir}/{seq_name}"
    else:
        default_seqs = ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"]
        seq_list = default_seqs if args.seq_list is None else args.seq_list
        for seq in tqdm(seq_list):
            frames = sorted(glob.glob(f"{img_path}/rgbd_bonn_{seq}/rgb_110/*.png"))
            yield frames, f"{args.output_dir}/{seq}"


def process_nyu(args, img_path):
    """Yield a single (frame list, save dir) pair for the flat NYU image folder."""
    frames = sorted(glob.glob(f"{img_path}/*.png"))
    yield frames, f"{args.output_dir}"


def process_scannet(args, img_path):
    """Yield (frame list, save dir) per ScanNet scene, reading the color_90/ frames."""
    for scene_dir in tqdm(sorted(glob.glob(f"{img_path}/*"))):
        frames = sorted(glob.glob(f"{scene_dir}/color_90/*.jpg"))
        yield frames, f"{args.output_dir}/{os.path.basename(scene_dir)}"


def process_sintel(args, img_path):
    """Yield (frame list, save dir) per Sintel sequence.

    With args.full_seq every sequence directory is used; otherwise the standard
    14-sequence evaluation subset is iterated.
    """
    if args.full_seq:
        for seq_dir in tqdm(sorted(glob.glob(f"{img_path}/*/"))):
            seq_name = os.path.basename(os.path.dirname(seq_dir))
            yield sorted(glob.glob(f"{seq_dir}/*.png")), f"{args.output_dir}/{seq_name}"
    else:
        # Standard Sintel evaluation subset.
        eval_seqs = [
            "alley_2",
            "ambush_4",
            "ambush_5",
            "ambush_6",
            "cave_2",
            "cave_4",
            "market_2",
            "market_5",
            "market_6",
            "shaman_3",
            "sleeping_1",
            "sleeping_2",
            "temple_2",
            "temple_3",
        ]
        for seq in tqdm(eval_seqs):
            yield sorted(glob.glob(f"{img_path}/{seq}/*.png")), f"{args.output_dir}/{seq}"

```

## /eval/relpose/run_scannet.sh

```sh path="/eval/relpose/run_scannet.sh" 
#!/bin/bash
# Relative-pose evaluation on the long ScanNet splits.
# Edit model_names / datasets below to sweep other configurations.

set -e

workdir='.'
model_names=('ttt3r') # ttt3r cut3r

ckpt_name='cut3r_512_dpt_4_64'
model_weights="${workdir}/src/${ckpt_name}.pth"

# Full sweep of available splits:
# datasets=('scannet_s3_50' 'scannet_s3_90' 'scannet_s3_100' 'scannet_s3_150' 'scannet_s3_200' 'scannet_s3_300' 'scannet_s3_400' 'scannet_s3_500'
#             'scannet_s3_600' 'scannet_s3_700' 'scannet_s3_800' 'scannet_s3_900' 'scannet_s3_1000')
datasets=('scannet_s3_1000')

for model_name in "${model_names[@]}"; do
    for data in "${datasets[@]}"; do
        output_dir="${workdir}/eval_results/relpose/${data}/${model_name}"
        echo "$output_dir"
        accelerate launch --num_processes 2 --main_process_port 29550 eval/relpose/launch.py \
            --weights "$model_weights" \
            --output_dir "$output_dir" \
            --eval_dataset "$data" \
            --size 512 \
            --model_update_type "$model_name"
    done
done



```

## /eval/relpose/run_sintel.sh

```sh path="/eval/relpose/run_sintel.sh" 
#!/bin/bash
# Relative-pose evaluation on Sintel.

set -e

workdir='.'
model_names=('ttt3r') # ttt3r cut3r

ckpt_name='cut3r_512_dpt_4_64'
model_weights="${workdir}/src/${ckpt_name}.pth"

datasets=('sintel')

for model_name in "${model_names[@]}"; do
    for data in "${datasets[@]}"; do
        output_dir="${workdir}/eval_results/s1/relpose/${data}/${model_name}"
        echo "$output_dir"
        accelerate launch --num_processes 2 --main_process_port 29550 eval/relpose/launch.py \
            --weights "$model_weights" \
            --output_dir "$output_dir" \
            --eval_dataset "$data" \
            --size 512 \
            --model_update_type "$model_name"
    done
done



```

## /eval/relpose/run_tum.sh

```sh path="/eval/relpose/run_tum.sh" 
#!/bin/bash
# Relative-pose evaluation on the long TUM splits.
# Uses port 29551 so it can run alongside the ScanNet/Sintel launchers.

set -e

workdir='.'
model_names=('ttt3r') # ttt3r cut3r

ckpt_name='cut3r_512_dpt_4_64'
model_weights="${workdir}/src/${ckpt_name}.pth"

# Full sweep of available splits:
# datasets=('tum_s1_50' 'tum_s1_100' 'tum_s1_150' 'tum_s1_200' 'tum_s1_300' 'tum_s1_400' 'tum_s1_500' 'tum_s1_600' 'tum_s1_700' 'tum_s1_800' 'tum_s1_900' 'tum_s1_1000')
datasets=('tum_s1_1000')

for model_name in "${model_names[@]}"; do
    for data in "${datasets[@]}"; do
        output_dir="${workdir}/eval_results/relpose/${data}/${model_name}"
        echo "$output_dir"
        accelerate launch --num_processes 2 --main_process_port 29551 eval/relpose/launch.py \
            --weights "$model_weights" \
            --output_dir "$output_dir" \
            --eval_dataset "$data" \
            --size 512 \
            --model_update_type "$model_name"
    done
done



```

## /eval/relpose/utils.py

```py path="/eval/relpose/utils.py" 
from copy import deepcopy
import cv2

import numpy as np
import torch
import torch.nn as nn
import roma
from copy import deepcopy
import tqdm
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from scipy.spatial.transform import Rotation
from eval.relpose.evo_utils import *
from PIL import Image
import imageio.v2 as iio
from matplotlib.figure import Figure

# from checkpoints.dust3r.viz import colorize_np, colorize


def todevice(batch, device, callback=None, non_blocking=False):
    """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    batch: tensor, ndarray, or arbitrarily nested list/tuple/dict of them
        (non-array leaves are returned unchanged).
    device: pytorch device (or device string), or the special value 'numpy'
        which converts tensors to numpy arrays instead.
    callback: function that would be called on every sub-elements.
    non_blocking: forwarded to ``Tensor.to`` for asynchronous copies.
    """
    if callback:
        batch = callback(batch)

    # Forward callback/non_blocking through the recursion so the documented
    # contract ("called on every sub-elements") holds; previously both were
    # silently dropped for nested containers.
    if isinstance(batch, dict):
        return {
            k: todevice(v, device, callback=callback, non_blocking=non_blocking)
            for k, v in batch.items()
        }

    if isinstance(batch, (tuple, list)):
        return type(batch)(
            todevice(x, device, callback=callback, non_blocking=non_blocking)
            for x in batch
        )

    x = batch
    if device == "numpy":
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x


# Backwards-compatible alias.
to_device = todevice


def to_numpy(x):
    """Recursively convert tensors in ``x`` (or nested containers) to numpy arrays."""
    return todevice(x, device="numpy")


def c2w_to_tumpose(c2w):
    """Convert a 4x4 camera-to-world matrix to a TUM pose vector.

    input: c2w: 4x4 matrix
    output: 1D array (x y z qw qx qy qz)
    """
    mat = to_numpy(c2w)
    translation = mat[:3, -1]
    # scipy returns quaternions as (qx, qy, qz, qw); TUM wants qw first.
    qx, qy, qz, qw = Rotation.from_matrix(mat[:3, :3]).as_quat()
    return np.concatenate([translation, [qw, qx, qy, qz]])


def get_tum_poses(poses):
    """Convert a list of 4x4 c2w matrices to TUM format.

    Returns [pose_array (N, 7), timestamps (N,)] where timestamps are just
    the frame indices as floats.
    """
    timestamps = np.arange(len(poses)).astype(float)
    tum = np.stack([c2w_to_tumpose(p) for p in poses], 0)
    return [tum, timestamps]


def save_tum_poses(poses, path):
    """Write `poses` to `path` in TUM trajectory format; returns the (N, 7) pose array."""
    tum_traj = get_tum_poses(poses)
    save_trajectory_tum_format(tum_traj, path)
    poses_only, _timestamps = tum_traj
    return poses_only


def save_focals(cam_dict, path):
    """Write per-frame focal lengths to a text file (one per line); returns them."""
    focal_values = cam_dict["focal"]
    np.savetxt(path, focal_values, fmt="%.6f")
    return focal_values


def save_intrinsics(cam_dict, path):
    """Assemble per-frame 3x3 pinhole intrinsics from focal/pp, save them as a
    flattened (N, 9) text file, and return the (N, 3, 3) array."""
    n_frames = len(cam_dict["focal"])
    K = np.tile(np.eye(3), (n_frames, 1, 1))
    K[:, 0, 0] = cam_dict["focal"]
    K[:, 1, 1] = cam_dict["focal"]
    K[:, :2, 2] = cam_dict["pp"]  # principal point (cx, cy)
    np.savetxt(path, K.reshape(-1, 9), fmt="%.6f")
    return K


def save_conf_maps(conf, path):
    """Dump each per-frame confidence tensor as <path>/conf_<i>.npy; returns the input list."""
    for idx, conf_map in enumerate(conf):
        np.save(f"{path}/conf_{idx}.npy", conf_map.detach().cpu().numpy())
    return conf


def save_rgb_imgs(colors, path):
    """Write each RGB frame (float tensor in [0, 1]) as <path>/frame_%04d.png.

    Returns the input frames unchanged.
    """
    for idx, frame in enumerate(colors):
        frame_u8 = (frame.cpu().numpy() * 255).astype(np.uint8)
        iio.imwrite(f"{path}/frame_{idx:04d}.png", frame_u8)
    return colors


def save_depth_maps(pts3ds_self, path, conf_self=None):
    """Save per-frame depth maps (z channel of the self-view point maps) as
    colorized PNGs, raw .npy files, and an animated GIF under ``path``.

    When ``conf_self`` is given, a log-confidence panel is appended to the
    right of each depth image. Returns the stacked raw depth maps.
    """
    # Depth is the last channel (z) of each per-frame 3D point map.
    depth_maps = torch.stack([p[..., -1] for p in pts3ds_self], 0)
    d_min = depth_maps.min()
    d_max = depth_maps.max()
    colored_depth = colorize(
        depth_maps,
        cmap_name="Spectral_r",
        range=(d_min, d_max),
        append_cbar=True,
    )

    colored_conf = None
    if conf_self is not None:
        log_conf = torch.log(torch.concat(conf_self, 0))
        colored_conf = colorize(
            log_conf,
            cmap_name="jet",
            range=(log_conf.min(), log_conf.max()),
            append_cbar=True,
        )

    frames = []
    for i, depth_img in enumerate(colored_depth):
        img_path = f"{path}/depth_frame_{(i):04d}.png"
        if colored_conf is None:
            panel = depth_img
        else:
            # Depth and confidence side by side.
            panel = torch.cat([depth_img, colored_conf[i]], dim=1)
        iio.imwrite(img_path, (panel * 255).detach().cpu().numpy().astype(np.uint8))
        frames.append(Image.open(img_path))
        # Raw depth values, used later for quantitative evaluation.
        np.save(f"{path}/frame_{(i):04d}.npy", depth_maps[i].detach().cpu().numpy())

    frames[0].save(
        f"{path}/_depth_maps.gif",
        save_all=True,
        append_images=frames[1:],
        duration=100,
        loop=0,
    )

    return depth_maps


def get_vertical_colorbar(h, vmin, vmax, cmap_name="jet", label=None, cbar_precision=2):
    """Render a standalone vertical colorbar and return it as a float RGB image.

    :param h: target height in pixels (width follows from the aspect ratio)
    :param vmin: minimum value shown on the bar
    :param vmax: maximum value shown on the bar
    :param cmap_name: matplotlib colormap name
    :param label: optional colorbar label
    :param cbar_precision: decimals used for tick labels
    :return: float32 RGB image in [0, 1], shape (h, w, 3)
    """
    fig = Figure(figsize=(2, 8), dpi=100)
    fig.subplots_adjust(right=1.5)
    canvas = FigureCanvasAgg(fig)

    axis = fig.add_subplot(111)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    ticks = np.linspace(vmin, vmax, 6)
    bar = mpl.colorbar.ColorbarBase(
        axis,
        cmap=cm.get_cmap(cmap_name),
        norm=norm,
        ticks=ticks,
        orientation="vertical",
    )

    labels = [str(np.round(v, cbar_precision)) for v in ticks]
    if cbar_precision == 0:
        # Drop the trailing ".0" when integer labels were requested.
        labels = [v[:-2] for v in labels]
    bar.set_ticklabels(labels)
    bar.ax.tick_params(labelsize=18, rotation=0)
    if label is not None:
        bar.set_label(label)

    # Rasterize the figure and grab its RGBA buffer.
    canvas.draw()
    buf, (width, height) = canvas.print_to_buffer()
    img = np.frombuffer(buf, np.uint8).reshape((height, width, 4))
    img = img[:, :, :3].astype(np.float32) / 255.0

    # Rescale to the requested height, preserving aspect ratio.
    if h != img.shape[0]:
        w = int(img.shape[1] / img.shape[0] * h)
        img = cv2.resize(img, (w, h), interpolation=cv2.INTER_AREA)

    return img


def colorize_np(
    x,
    cmap_name="jet",
    mask=None,
    range=None,  # NOTE: shadows the builtin; kept for caller compatibility
    append_cbar=False,
    cbar_in_image=False,
    cbar_precision=2,
):
    """
    turn a grayscale image into a color image
    :param x: input grayscale, [H, W]; NOTE: mutated in place when `mask` is given
    :param cmap_name: the colorization method
    :param mask: the mask image, [H, W]
    :param range: the range for scaling, automatic if None, [min, max]
    :param append_cbar: if append the color bar
    :param cbar_in_image: put the color bar inside the image to keep the output image the same size as the input image
    :param cbar_precision: decimals used for colorbar tick labels
    :return: colorized image, [H, W, 3] (wider if a colorbar is appended)
    """
    # Determine the normalization range: explicit > masked statistics > percentiles.
    if range is not None:
        vmin, vmax = range
    elif mask is not None:
        # Min over non-zero masked values; pixels outside the mask are set to vmin.
        vmin = np.min(x[mask][np.nonzero(x[mask])])
        vmax = np.max(x[mask])
        x[np.logical_not(mask)] = vmin
    else:
        vmin, vmax = np.percentile(x, (1, 100))
        vmax += 1e-6  # guard against vmax == vmin for constant images

    x = np.clip(x, vmin, vmax)
    x = (x - vmin) / (vmax - vmin)

    cmap = cm.get_cmap(cmap_name)
    x_new = cmap(x)[:, :, :3]

    if mask is not None:
        # Blend unmasked pixels towards white.
        mask = np.float32(mask[:, :, np.newaxis])
        x_new = x_new * mask + np.ones_like(x_new) * (1.0 - mask)

    if not append_cbar:
        # Previously the colorbar was rendered (a full matplotlib figure)
        # even when unused; only build it when it is actually appended.
        return x_new

    cbar = get_vertical_colorbar(
        h=x.shape[0],
        vmin=vmin,
        vmax=vmax,
        cmap_name=cmap_name,
        cbar_precision=cbar_precision,
    )
    if cbar_in_image:
        # Overwrite the right edge so the output keeps the input width.
        x_new[:, -cbar.shape[1] :, :] = cbar
    else:
        x_new = np.concatenate((x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1)
    return x_new


# tensor
def colorize(
    x, cmap_name="jet", mask=None, range=None, append_cbar=False, cbar_in_image=False
):
    """
    turn a grayscale image into a color image
    :param x: torch.Tensor, grayscale image, [H, W] or [B, H, W]
    :param mask: torch.Tensor or None, mask image, [H, W] or [B, H, W] or None
    :return: colorized image(s) on x's device; batch dim squeezed for single frames
    """
    device = x.device
    x_np = x.cpu().numpy()

    masks_np = None
    if mask is not None:
        masks_np = mask.cpu().numpy() > 0.99
        kernel = np.ones((3, 3), np.uint8)

    # Normalize to a batch of 2D frames.
    if x_np.ndim == 2:
        x_np = x_np[None]
        if masks_np is not None:
            masks_np = masks_np[None]

    out = []
    for i, frame in enumerate(x_np):
        frame_mask = None
        if masks_np is not None:
            # Fix: erode each frame's own 2D mask once. Previously the whole
            # (possibly 3D) mask array was re-eroded cumulatively on every
            # iteration and passed to colorize_np with a batch dimension,
            # breaking boolean indexing for batched inputs.
            frame_mask = cv2.erode(
                masks_np[i].astype(np.uint8), kernel, iterations=1
            ).astype(bool)

        colored = colorize_np(frame, cmap_name, frame_mask, range, append_cbar, cbar_in_image)
        out.append(torch.from_numpy(colored).to(device).float())

    return torch.stack(out).squeeze(0)

```

## /eval/video_depth/eval_depth.py

```py path="/eval/video_depth/eval_depth.py" 
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from eval.video_depth.tools import depth_evaluation, group_by_directory
import numpy as np
import cv2
from tqdm import tqdm
import glob
from PIL import Image
import argparse
import json
from eval.video_depth.metadata import dataset_metadata


def get_args_parser():
    """CLI options for offline video-depth evaluation of saved predictions."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="value for outdir",
    )
    parser.add_argument(
        "--eval_dataset",
        type=str,
        default="nyu",
        choices=list(dataset_metadata.keys()),
    )
    # How predictions are aligned to ground truth before metrics are computed.
    parser.add_argument(
        "--align",
        type=str,
        default="scale&shift",
        choices=["scale&shift", "scale", "metric"],
    )
    return parser


def main(args):
    """Evaluate saved per-frame depth predictions against dataset ground truth.

    Dispatches on ``args.eval_dataset`` (sintel / bonn* / kitti*): loads the
    dataset-specific GT depth files, groups predictions and GT by sequence,
    aligns predictions according to ``args.align``, and writes the
    valid-pixel-weighted average metrics to <output_dir>/result_<align>.json.
    Predictions are expected as <output_dir>/<seq>/frame_*.npy.
    """
    if args.eval_dataset == "sintel":
        # Magic float32 tag at the head of Sintel .dpt files.
        TAG_FLOAT = 202021.25

        def depth_read(filename):
            """Read depth data from file, return as numpy array."""
            # Layout: float32 tag, int32 width, int32 height, then H*W float32 depths.
            f = open(filename, "rb")
            check = np.fromfile(f, dtype=np.float32, count=1)[0]
            assert (
                check == TAG_FLOAT
            ), " depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? ".format(
                TAG_FLOAT, check
            )
            width = np.fromfile(f, dtype=np.int32, count=1)[0]
            height = np.fromfile(f, dtype=np.int32, count=1)[0]
            size = width * height
            assert (
                width > 0 and height > 0 and size > 1 and size < 100000000
            ), " depth_read:: Wrong input size (width = {0}, height = {1}).".format(
                width, height
            )
            depth = np.fromfile(f, dtype=np.float32, count=-1).reshape((height, width))
            return depth

        pred_pathes = glob.glob(
            f"{args.output_dir}/*/frame_*.npy"
        )  # TODO: update the path to your prediction
        pred_pathes = sorted(pred_pathes)

        # Heuristic: more than 643 predictions means the full training set was
        # processed rather than the 14-sequence evaluation subset.
        if len(pred_pathes) > 643:
            full = True
        else:
            full = False

        if full:
            depth_pathes = glob.glob(f"./data/sintel/training/depth/*/*.dpt")
            depth_pathes = sorted(depth_pathes)
        else:
            # Standard Sintel evaluation subset.
            seq_list = [
                "alley_2",
                "ambush_4",
                "ambush_5",
                "ambush_6",
                "cave_2",
                "cave_4",
                "market_2",
                "market_5",
                "market_6",
                "shaman_3",
                "sleeping_1",
                "sleeping_2",
                "temple_2",
                "temple_3",
            ]
            depth_pathes_folder = [
                f"./data/sintel/training/depth/{seq}" for seq in seq_list
            ]
            depth_pathes = []
            for depth_pathes_folder_i in depth_pathes_folder:
                depth_pathes += glob.glob(depth_pathes_folder_i + "/*.dpt")
            depth_pathes = sorted(depth_pathes)

        def get_video_results():
            # Group per-frame files by their parent directory (one group per sequence).
            grouped_pred_depth = group_by_directory(pred_pathes)

            grouped_gt_depth = group_by_directory(depth_pathes)
            gathered_depth_metrics = []

            for key in tqdm(grouped_pred_depth.keys()):
                pd_pathes = grouped_pred_depth[key]
                gt_pathes = grouped_gt_depth[key.replace("_pred_depth", "")]

                gt_depth = np.stack(
                    [depth_read(gt_path) for gt_path in gt_pathes], axis=0
                )
                # Resize predictions to the GT resolution before comparison.
                pr_depth = np.stack(
                    [
                        cv2.resize(
                            np.load(pd_path),
                            (gt_depth.shape[2], gt_depth.shape[1]),
                            interpolation=cv2.INTER_CUBIC,
                        )
                        for pd_path in pd_pathes
                    ],
                    axis=0,
                )
                # for depth eval, set align_with_lad2=False to use median alignment; set align_with_lad2=True to use scale&shift alignment
                if args.align == "scale&shift":
                    depth_results, error_map, depth_predict, depth_gt = (
                        depth_evaluation(
                            pr_depth,
                            gt_depth,
                            max_depth=70,
                            align_with_lad2=True,
                            use_gpu=True,
                            post_clip_max=70,
                        )
                    )
                elif args.align == "scale":
                    depth_results, error_map, depth_predict, depth_gt = (
                        depth_evaluation(
                            pr_depth,
                            gt_depth,
                            max_depth=70,
                            align_with_scale=True,
                            use_gpu=True,
                            post_clip_max=70,
                        )
                    )
                elif args.align == "metric":
                    depth_results, error_map, depth_predict, depth_gt = (
                        depth_evaluation(
                            pr_depth,
                            gt_depth,
                            max_depth=70,
                            metric_scale=True,
                            use_gpu=True,
                            post_clip_max=70,
                        )
                    )
                gathered_depth_metrics.append(depth_results)

            # Per-sequence metrics are averaged weighted by their valid pixel counts.
            depth_log_path = f"{args.output_dir}/result_{args.align}.json"
            average_metrics = {
                key: np.average(
                    [metrics[key] for metrics in gathered_depth_metrics],
                    weights=[
                        metrics["valid_pixels"] for metrics in gathered_depth_metrics
                    ],
                )
                for key in gathered_depth_metrics[0].keys()
                if key != "valid_pixels"
            }
            print("Average depth evaluation metrics:", average_metrics)
            with open(depth_log_path, "w") as f:
                f.write(json.dumps(average_metrics))

        get_video_results()
    elif args.eval_dataset.startswith("bonn"):

        def depth_read(filename):
            # loads depth map D from png file
            # and returns it as a numpy array
            # TUM/Bonn convention: 16-bit PNG scaled by 5000; 0 marks invalid
            # pixels and is mapped to the -1 sentinel.
            depth_png = np.asarray(Image.open(filename))
            # make sure we have a proper 16bit depth map here.. not 8bit!
            assert np.max(depth_png) > 255
            depth = depth_png.astype(np.float64) / 5000.0
            depth[depth_png == 0] = -1.0
            return depth

        seq_list = ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"]

        # extract number from dataset name, e.g. bonn_400 -> 400
        if "_" in args.eval_dataset:
            bonn_number = args.eval_dataset.split("_")[-1]
        else:
            bonn_number = "110"  # default value

        # NOTE(review): img_pathes is built but never used below — dead code?
        img_pathes_folder = [
            f"./data/long_bonn_s1/rgbd_bonn_dataset/rgbd_bonn_{seq}/rgb_{bonn_number}/*.png"
            for seq in seq_list
        ]
        img_pathes = []
        for img_pathes_folder_i in img_pathes_folder:
            img_pathes += glob.glob(img_pathes_folder_i)
        img_pathes = sorted(img_pathes)
        depth_pathes_folder = [
            f"./data/long_bonn_s1/rgbd_bonn_dataset/rgbd_bonn_{seq}/depth_{bonn_number}/*.png"
            for seq in seq_list
        ]
        depth_pathes = []
        for depth_pathes_folder_i in depth_pathes_folder:
            depth_pathes += glob.glob(depth_pathes_folder_i)
        depth_pathes = sorted(depth_pathes)
        pred_pathes = glob.glob(
            f"{args.output_dir}/*/frame*.npy"
        )  # TODO: update the path to your prediction
        pred_pathes = sorted(pred_pathes)

        def get_video_results():
            grouped_pred_depth = group_by_directory(pred_pathes)
            # idx=-2: group GT by the sequence directory two levels up.
            grouped_gt_depth = group_by_directory(depth_pathes, idx=-2)
            gathered_depth_metrics = []
            for key in tqdm(grouped_gt_depth.keys()):
                # key[10:] strips the "rgbd_bonn_" prefix to match prediction dirs.
                pd_pathes = grouped_pred_depth[key[10:]]
                gt_pathes = grouped_gt_depth[key]
                gt_depth = np.stack(
                    [depth_read(gt_path) for gt_path in gt_pathes], axis=0
                )
                # Resize predictions to the GT resolution before comparison.
                pr_depth = np.stack(
                    [
                        cv2.resize(
                            np.load(pd_path),
                            (gt_depth.shape[2], gt_depth.shape[1]),
                            interpolation=cv2.INTER_CUBIC,
                        )
                        for pd_path in pd_pathes
                    ],
                    axis=0,
                )
                # for depth eval, set align_with_lad2=False to use median alignment; set align_with_lad2=True to use scale&shift alignment
                if args.align == "scale&shift":
                    depth_results, error_map, depth_predict, depth_gt = (
                        depth_evaluation(
                            pr_depth,
                            gt_depth,
                            max_depth=70,
                            align_with_lad2=True,
                            use_gpu=True,
                        )
                    )
                elif args.align == "scale":
                    depth_results, error_map, depth_predict, depth_gt = (
                        depth_evaluation(
                            pr_depth,
                            gt_depth,
                            max_depth=70,
                            align_with_scale=True,
                            use_gpu=True,
                        )
                    )
                elif args.align == "metric":
                    depth_results, error_map, depth_predict, depth_gt = (
                        depth_evaluation(
                            pr_depth,
                            gt_depth,
                            max_depth=70,
                            metric_scale=True,
                            use_gpu=True,
                        )
                    )
                gathered_depth_metrics.append(depth_results)

                # seq_len = gt_depth.shape[0]
                # error_map = error_map.reshape(seq_len, -1, error_map.shape[-1]).cpu()
                # error_map_colored = colorize(error_map, range=(error_map.min(), error_map.max()), append_cbar=True)
                # ImageSequenceClip([x for x in (error_map_colored.numpy()*255).astype(np.uint8)], fps=10).write_videofile(f'{args.output_dir}/errormap_{key}_{args.align}.mp4', fps=10)

            # Per-sequence metrics are averaged weighted by their valid pixel counts.
            depth_log_path = f"{args.output_dir}/result_{args.align}.json"
            average_metrics = {
                key: np.average(
                    [metrics[key] for metrics in gathered_depth_metrics],
                    weights=[
                        metrics["valid_pixels"] for metrics in gathered_depth_metrics
                    ],
                )
                for key in gathered_depth_metrics[0].keys()
                if key != "valid_pixels"
            }
            print("Average depth evaluation metrics:", average_metrics)
            with open(depth_log_path, "w") as f:
                f.write(json.dumps(average_metrics))

        get_video_results()
    elif args.eval_dataset.startswith("kitti"):

        def depth_read(filename):
            # loads depth map D from png file
            # and returns it as a numpy array,
            # for details see readme.txt
            # KITTI convention: 16-bit PNG scaled by 256; 0 marks invalid
            # pixels and is mapped to the -1 sentinel.
            img_pil = Image.open(filename)
            depth_png = np.array(img_pil, dtype=int)
            # make sure we have a proper 16bit depth map here.. not 8bit!
            assert np.max(depth_png) > 255

            depth = depth_png.astype(float) / 256.0
            depth[depth_png == 0] = -1.0
            return depth

        # extract number from dataset name, e.g. kitti_100 -> 100
        if "_" in args.eval_dataset:
            kitti_number = args.eval_dataset.split("_")[-1]
        else:
            kitti_number = "110"  # default value
        
        depth_pathes = glob.glob(
            f"./data/long_kitti_s1/depth_selection/val_selection_cropped/groundtruth_depth_gathered_{kitti_number}/*/*.png"
        )
        depth_pathes = sorted(depth_pathes)
        pred_pathes = glob.glob(
            f"{args.output_dir}/*/frame_*.npy"
        )  # TODO: update the path to your prediction
        pred_pathes = sorted(pred_pathes)

        def get_video_results():
            grouped_pred_depth = group_by_directory(pred_pathes)
            grouped_gt_depth = group_by_directory(depth_pathes)
            gathered_depth_metrics = []
            for key in tqdm(grouped_pred_depth.keys()):
                pd_pathes = grouped_pred_depth[key]
                gt_pathes = grouped_gt_depth[key]
                gt_depth = np.stack(
                    [depth_read(gt_path) for gt_path in gt_pathes], axis=0
                )
                # Resize predictions to the GT resolution before comparison.
                pr_depth = np.stack(
                    [
                        cv2.resize(
                            np.load(pd_path),
                            (gt_depth.shape[2], gt_depth.shape[1]),
                            interpolation=cv2.INTER_CUBIC,
                        )
                        for pd_path in pd_pathes
                    ],
                    axis=0,
                )

                # for depth eval, set align_with_lad2=False to use median alignment; set align_with_lad2=True to use scale&shift alignment
                # KITTI uses no max_depth cap, unlike sintel/bonn.
                if args.align == "scale&shift":
                    depth_results, error_map, depth_predict, depth_gt = (
                        depth_evaluation(
                            pr_depth,
                            gt_depth,
                            max_depth=None,
                            align_with_lad2=True,
                            use_gpu=True,
                        )
                    )
                elif args.align == "scale":
                    depth_results, error_map, depth_predict, depth_gt = (
                        depth_evaluation(
                            pr_depth,
                            gt_depth,
                            max_depth=None,
                            align_with_scale=True,
                            use_gpu=True,
                        )
                    )
                elif args.align == "metric":
                    depth_results, error_map, depth_predict, depth_gt = (
                        depth_evaluation(
                            pr_depth,
                            gt_depth,
                            max_depth=None,
                            metric_scale=True,
                            use_gpu=True,
                        )
                    )
                gathered_depth_metrics.append(depth_results)

            # Per-sequence metrics are averaged weighted by their valid pixel counts.
            depth_log_path = f"{args.output_dir}/result_{args.align}.json"
            average_metrics = {
                key: np.average(
                    [metrics[key] for metrics in gathered_depth_metrics],
                    weights=[
                        metrics["valid_pixels"] for metrics in gathered_depth_metrics
                    ],
                )
                for key in gathered_depth_metrics[0].keys()
                if key != "valid_pixels"
            }
            print("Average depth evaluation metrics:", average_metrics)
            with open(depth_log_path, "w") as f:
                f.write(json.dumps(average_metrics))

        get_video_results()


if __name__ == "__main__":
    # CLI entry point: parse flags and run the evaluation.
    main(get_args_parser().parse_args())

```

## /eval/video_depth/launch.py

```py path="/eval/video_depth/launch.py" 
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import math
import cv2
import numpy as np
import torch
import argparse

from copy import deepcopy
from eval.video_depth.metadata import dataset_metadata
from eval.video_depth.utils import save_depth_maps
from accelerate import PartialState
from add_ckpt_path import add_path_to_dust3r
import time
from tqdm import tqdm


def get_args_parser():
    """Build the CLI argument parser for the video-depth evaluation launcher."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--weights",
        type=str,
        help="path to the model weights",
        default="",
    )
    parser.add_argument("--device", type=str, default="cuda", help="pytorch device")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="value for outdir",
    )
    # NOTE(review): argparse's type=bool does NOT parse "False"/"0" — any
    # non-empty string becomes True. Kept as-is for interface compatibility.
    parser.add_argument(
        "--no_crop", type=bool, default=True, help="whether to crop input data"
    )
    parser.add_argument(
        "--eval_dataset",
        type=str,
        default="sintel",
        choices=list(dataset_metadata.keys()),
    )
    # Was default="224" (a string); argparse happened to convert it via
    # type=int, but an explicit int default does not rely on that behavior.
    parser.add_argument("--size", type=int, default=224)
    parser.add_argument(
        "--model_update_type",
        type=str,
        default="cut3r",
        help="model type for state update strategy: cut3r or ttt3r",
    )
    parser.add_argument(
        "--pose_eval_stride", default=1, type=int, help="stride for pose evaluation"
    )
    parser.add_argument(
        "--full_seq",
        action="store_true",
        default=False,
        help="use full sequence for pose evaluation",
    )
    parser.add_argument(
        "--seq_list",
        nargs="+",
        default=None,
        help="list of sequences for pose evaluation",
    )
    return parser


def eval_pose_estimation(args, model, save_dir=None):
    """Run distributed pose evaluation on args.eval_dataset.

    Returns (mean ATE, mean translational RPE, mean rotational RPE).
    """
    meta = dataset_metadata.get(args.eval_dataset)
    return eval_pose_estimation_dist(
        args,
        model,
        save_dir=save_dir,
        img_path=meta["img_path"],
        mask_path=meta["mask_path"],
    )


def eval_pose_estimation_dist(args, model, img_path, save_dir=None, mask_path=None):
    """Run recurrent inference over the selected sequences and save depth maps.

    Sequences are split across processes via accelerate's ``PartialState`` so
    evaluation can run data-parallel. Per-sequence failures (CUDA OOM,
    degenerate trajectory alignment) are logged to a per-process file and
    skipped instead of aborting the whole run.

    Args:
        args: parsed CLI namespace (eval_dataset, size, pose_eval_stride, ...).
        model: model to evaluate; moved onto this process's device.
        img_path: root directory holding the per-sequence image folders.
        save_dir: output directory; defaults to ``args.output_dir``.
        mask_path: optional root of per-sequence masks (dataset-dependent).

    Returns:
        (None, None, None) — pose metrics are not aggregated here; the useful
        side effect is the depth maps written under ``save_dir/<seq>``.
    """
    from dust3r.inference import inference, inference_recurrent, inference_recurrent_lighter

    metadata = dataset_metadata.get(args.eval_dataset)
    anno_path = metadata.get("anno_path", None)

    # Resolve which sequences to evaluate: an explicit --seq_list wins;
    # otherwise the metadata decides between a fixed subset and every
    # subdirectory of img_path (full_seq).
    seq_list = args.seq_list
    if seq_list is None:
        if metadata.get("full_seq", False):
            args.full_seq = True
        else:
            seq_list = metadata.get("seq_list", [])
        if args.full_seq:
            seq_list = os.listdir(img_path)
            seq_list = [
                seq for seq in seq_list if os.path.isdir(os.path.join(img_path, seq))
            ]
        seq_list = sorted(seq_list)

    if save_dir is None:
        save_dir = args.output_dir

    distributed_state = PartialState()
    model.to(distributed_state.device)
    device = distributed_state.device

    # Each process works on its own shard of the sequence list.
    with distributed_state.split_between_processes(seq_list) as seqs:
        ate_list = []
        rpe_trans_list = []
        rpe_rot_list = []
        load_img_size = args.size
        assert load_img_size == 512
        error_log_path = f"{save_dir}/_error_log_{distributed_state.process_index}.txt"  # Unique log file per process
        # NOTE(review): `bug` is never read afterwards — candidate for removal.
        bug = False
        for seq in tqdm(seqs):
            try:
                dir_path = metadata["dir_path_func"](img_path, seq)

                # Handle skip_condition
                skip_condition = metadata.get("skip_condition", None)
                if skip_condition is not None and skip_condition(save_dir, seq):
                    continue

                mask_path_seq_func = metadata.get(
                    "mask_path_seq_func", lambda mask_path, seq: None
                )
                mask_path_seq = mask_path_seq_func(mask_path, seq)

                # Subsample frames with the requested stride.
                filelist = [
                    os.path.join(dir_path, name) for name in os.listdir(dir_path)
                ]
                filelist.sort()
                filelist = filelist[:: args.pose_eval_stride]

                views = prepare_input(
                    filelist,
                    [True for _ in filelist],
                    size=load_img_size,
                    crop=not args.no_crop,
                )
                start = time.time()
                outputs, _ = inference_recurrent_lighter(views, model, device)
                end = time.time()
                # Throughput over the whole sequence (frames per second).
                fps = len(filelist) / (end - start)

                (
                    colors,
                    pts3ds_self,
                    pts3ds_other,
                    conf_self,
                    conf_other,
                    cam_dict,
                    pr_poses,
                ) = prepare_output(outputs)

                # Persist per-frame depth maps for the downstream depth eval.
                os.makedirs(f"{save_dir}/{seq}", exist_ok=True)
                save_depth_maps(pts3ds_self, f"{save_dir}/{seq}", conf_self=conf_self)

            except Exception as e:
                if "out of memory" in str(e):
                    # Handle OOM
                    torch.cuda.empty_cache()  # Clear the CUDA memory
                    with open(error_log_path, "a") as f:
                        f.write(
                            f"OOM error in sequence {seq}, skipping this sequence.\n"
                        )
                    print(f"OOM error in sequence {seq}, skipping...")
                elif "Degenerate covariance rank" in str(
                    e
                ) or "Eigenvalues did not converge" in str(e):
                    # Handle Degenerate covariance rank exception and Eigenvalues did not converge exception
                    with open(error_log_path, "a") as f:
                        f.write(f"Exception in sequence {seq}: {str(e)}\n")
                    print(f"Traj evaluation error in sequence {seq}, skipping.")
                else:
                    raise e  # Rethrow if it's not an expected exception
    return None, None, None


if __name__ == "__main__":
    args = get_args_parser()
    args = args.parse_args()
    add_path_to_dust3r(args.weights)
    from dust3r.utils.image import load_images_for_eval as load_images
    from dust3r.post_process import estimate_focal_knowing_depth
    from dust3r.model import ARCroco3DStereo
    from dust3r.utils.camera import pose_encoding_to_camera

    if args.eval_dataset == "sintel":
        args.full_seq = True
    else:
        args.full_seq = False
    args.no_crop = True

    def prepare_input(
        img_paths,
        img_mask,
        size,
        raymaps=None,
        raymap_mask=None,
        revisit=1,
        update=True,
        crop=True,
    ):
        """Pack loaded images (and optional ray maps) into model input views.

        Args:
            img_paths: list of image file paths to load.
            img_mask: per-view flags; True where the view carries an image.
            size: target image size forwarded to ``load_images``.
            raymaps: optional list of ray-map tensors to interleave.
            raymap_mask: per-view flags; True where the view carries a ray map.
            revisit: repeat the whole sequence this many times.
            update: when False, repeat passes after the first are marked as
                non-updating (the recurrent state is frozen for them).
            crop: whether ``load_images`` may crop the input frames.

        Returns:
            list of view dicts consumed by the recurrent inference loop.
        """
        images = load_images(img_paths, size=size, crop=crop)
        views = []
        if raymaps is None and raymap_mask is None:
            # Image-only sequence: every view gets a NaN-filled (B, 6, H, W)
            # ray map so downstream code sees a uniform schema.
            num_views = len(images)

            for i in range(num_views):
                view = {
                    "img": images[i]["img"],
                    "ray_map": torch.full(
                        (
                            images[i]["img"].shape[0],
                            6,
                            images[i]["img"].shape[-2],
                            images[i]["img"].shape[-1],
                        ),
                        torch.nan,
                    ),
                    "true_shape": torch.from_numpy(images[i]["true_shape"]),
                    "idx": i,
                    "instance": str(i),
                    # Identity pose placeholder; actual poses are predicted.
                    "camera_pose": torch.from_numpy(
                        np.eye(4).astype(np.float32)
                    ).unsqueeze(0),
                    "img_mask": torch.tensor(True).unsqueeze(0),
                    "ray_mask": torch.tensor(False).unsqueeze(0),
                    "update": torch.tensor(True).unsqueeze(0),
                    "reset": torch.tensor(False).unsqueeze(0),
                }
                views.append(view)
        else:

            # Mixed sequence: interleave image views and ray-map views in the
            # order given by the two masks; j and k walk the two source lists.
            num_views = len(images) + len(raymaps)
            assert len(img_mask) == len(raymap_mask) == num_views
            assert sum(img_mask) == len(images) and sum(raymap_mask) == len(raymaps)

            j = 0
            k = 0
            for i in range(num_views):
                view = {
                    "img": (
                        images[j]["img"]
                        if img_mask[i]
                        else torch.full_like(images[0]["img"], torch.nan)
                    ),
                    "ray_map": (
                        raymaps[k]
                        if raymap_mask[i]
                        else torch.full_like(raymaps[0], torch.nan)
                    ),
                    "true_shape": (
                        torch.from_numpy(images[j]["true_shape"])
                        if img_mask[i]
                        # derive (W, H) from the ray map's spatial dims
                        else torch.from_numpy(np.int32([raymaps[k].shape[1:-1][::-1]]))
                    ),
                    "idx": i,
                    "instance": str(i),
                    "camera_pose": torch.from_numpy(
                        np.eye(4).astype(np.float32)
                    ).unsqueeze(0),
                    "img_mask": torch.tensor(img_mask[i]).unsqueeze(0),
                    "ray_mask": torch.tensor(raymap_mask[i]).unsqueeze(0),
                    "update": torch.tensor(img_mask[i]).unsqueeze(0),
                    "reset": torch.tensor(False).unsqueeze(0),
                }
                if img_mask[i]:
                    j += 1
                if raymap_mask[i]:
                    k += 1
                views.append(view)
            assert j == len(images) and k == len(raymaps)

        if revisit > 1:
            # repeat input for 'revisit' times
            new_views = []
            for r in range(revisit):
                for i in range(len(views)):
                    new_view = deepcopy(views[i])
                    new_view["idx"] = r * len(views) + i
                    new_view["instance"] = str(r * len(views) + i)
                    if r > 0:
                        if not update:
                            # freeze the recurrent state on revisit passes
                            new_view["update"] = torch.tensor(False).unsqueeze(0)
                    new_views.append(new_view)
            return new_views
        return views

    def prepare_output(outputs, revisit=1):
        """Collect per-frame predictions into tensors/lists for saving.

        When the sequence was revisited several times, only the final pass is
        kept. Returns colors, self/other-view point maps, confidences, a
        camera dict (focal + principal point), and predicted poses.
        """
        n_keep = len(outputs["pred"]) // revisit
        outputs["pred"] = outputs["pred"][-n_keep:]
        outputs["views"] = outputs["views"][-n_keep:]
        preds = outputs["pred"]

        pts3ds_self = torch.cat(
            [p["pts3d_in_self_view"].cpu() for p in preds], 0
        )
        pts3ds_other = [p["pts3d_in_other_view"].cpu() for p in preds]
        conf_self = [p["conf_self"].cpu() for p in preds]
        conf_other = [p["conf"].cpu() for p in preds]
        pr_poses = torch.cat(
            [pose_encoding_to_camera(p["camera_pose"].clone()).cpu() for p in preds],
            0,
        )

        B, H, W, _ = pts3ds_self.shape
        # Principal point fixed at the image centre for every frame.
        pp = (
            torch.tensor([W // 2, H // 2], device=pts3ds_self.device)
            .float()
            .repeat(B, 1)
            .reshape(B, 2)
        )
        focal = estimate_focal_knowing_depth(pts3ds_self, pp, focal_mode="weiszfeld")

        # Map predicted RGB from [-1, 1] back to [0, 1].
        colors = [0.5 * (p["rgb"][0] + 1.0) for p in preds]
        cam_dict = {
            "focal": focal.cpu().numpy(),
            "pp": pp.cpu().numpy(),
        }
        return (
            colors,
            pts3ds_self,
            pts3ds_other,
            conf_self,
            conf_other,
            cam_dict,
            pr_poses,
        )

    model = ARCroco3DStereo.from_pretrained(args.weights)

    # set model type
    # Select the recurrent-state update strategy ("cut3r" or "ttt3r")
    # before running the evaluation.
    model.config.model_update_type = args.model_update_type

    eval_pose_estimation(args, model, save_dir=args.output_dir)

```

## /eval/video_depth/metadata.py

```py path="/eval/video_depth/metadata.py" 
import os
import glob
from tqdm import tqdm

# Define the merged dataset metadata dictionary
# Each entry describes how to locate one dataset's frames and ground truth:
#   img_path / mask_path – dataset roots on disk
#   dir_path_func        – (img_path, seq) -> image directory for a sequence
#   gt_traj_func         – (img_path, anno_path, seq) -> GT trajectory file
#   traj_format          – trajectory file format ("tum", "replica", ...)
#   seq_list / full_seq  – fixed subset vs. evaluate every sequence found
#   mask_path_seq_func   – (mask_path, seq) -> per-sequence mask dir (or None)
#   skip_condition       – optional (save_dir, seq) -> bool to skip a sequence
#   process_func         – generator yielding (filelist, save_dir) pairs
dataset_metadata = {
    "davis": {
        "img_path": "data/davis/DAVIS/JPEGImages/480p",
        "mask_path": "data/davis/DAVIS/masked_images/480p",
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        "gt_traj_func": lambda img_path, anno_path, seq: None,
        "traj_format": None,
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: os.path.join(mask_path, seq),
        "skip_condition": None,
        "process_func": None,  # Not used in mono depth estimation
    },
    "kitti": {
        "img_path": "data/kitti/depth_selection/val_selection_cropped/image_gathered",  # Default path
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        "gt_traj_func": lambda img_path, anno_path, seq: None,
        "traj_format": None,
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": lambda args, img_path: process_kitti(args, img_path),
    },
    "bonn": {
        "img_path": "data/bonn/rgbd_bonn_dataset",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(
            img_path, f"rgbd_bonn_{seq}", "rgb_110"
        ),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, f"rgbd_bonn_{seq}", "groundtruth_110.txt"
        ),
        "traj_format": "tum",
        "seq_list": ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"],
        "full_seq": False,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": lambda args, img_path: process_bonn(args, img_path),
    },
    # nyu has no sequence structure, so only the processing entries exist.
    "nyu": {
        "img_path": "data/nyu-v2/val/nyu_images",
        "mask_path": None,
        "process_func": lambda args, img_path: process_nyu(args, img_path),
    },
    "scannet": {
        "img_path": "data/scannetv2",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, seq, "pose_90.txt"
        ),
        "traj_format": "replica",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,  # lambda save_dir, seq: os.path.exists(os.path.join(save_dir, seq)),
        "process_func": lambda args, img_path: process_scannet(args, img_path),
    },
    "tum": {
        "img_path": "data/tum",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "rgb_90"),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, seq, "groundtruth_90.txt"
        ),
        "traj_format": "tum",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    "sintel": {
        "img_path": "data/sintel/training/final",
        "anno_path": "data/sintel/training/camdata_left",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(anno_path, seq),
        "traj_format": None,
        "seq_list": [
            "alley_2",
            "ambush_4",
            "ambush_5",
            "ambush_6",
            "cave_2",
            "cave_4",
            "market_2",
            "market_5",
            "market_6",
            "shaman_3",
            "sleeping_1",
            "sleeping_2",
            "temple_2",
            "temple_3",
        ],
        "full_seq": False,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": lambda args, img_path: process_sintel(args, img_path),
    },
}

# Long-sequence variants: one metadata entry per sequence length `num`.
# The kitti variants only vary in img_path (num is baked into the f-string),
# so their lambdas need no extra binding.
kitti_numbers = [50, 100, 110, 150, 200, 250, 300, 350, 400, 450, 500]
kitti_configs = {
    f"kitti_s1_{num}": {
        "img_path": f"data/long_kitti_s1/depth_selection/val_selection_cropped/image_gathered_{num}",  # Default path
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        "gt_traj_func": lambda img_path, anno_path, seq: None,
        "traj_format": None,
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": lambda args, img_path: process_kitti(args, img_path),
    }
    for num in kitti_numbers
}
dataset_metadata.update(kitti_configs)


bonn_numbers = [50, 100, 110, 150, 200, 250, 300, 350, 400, 450, 500]
bonn_configs = {
    f"bonn_s1_{num}": {
        "img_path": "data/long_bonn_s1/rgbd_bonn_dataset",
        "mask_path": None,
        # `num=num` binds the loop variable at definition time; without it
        # every lambda would see the final value of `num` (late binding).
        "dir_path_func": lambda img_path, seq, num=num: os.path.join(
            img_path, f"rgbd_bonn_{seq}", f"rgb_{num}"
        ),
        "gt_traj_func": lambda img_path, anno_path, seq, num=num: os.path.join(
            img_path, f"rgbd_bonn_{seq}", f"groundtruth_{num}.txt"
        ),
        "traj_format": "tum",
        "seq_list": ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"],
        "full_seq": False,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": lambda args, img_path: process_bonn(args, img_path),
    }
    for num in bonn_numbers
}
dataset_metadata.update(bonn_configs)

# Define processing functions for each dataset
def process_kitti(args, img_path):
    """Yield (sorted PNG list, output dir) for each KITTI sequence folder."""
    for seq_dir in tqdm(sorted(glob.glob(f"{img_path}/*"))):
        frames = sorted(glob.glob(f"{seq_dir}/*.png"))
        yield frames, f"{args.output_dir}/{os.path.basename(seq_dir)}"


def process_bonn(args, img_path):
    """Yield (sorted PNG list, output dir) per Bonn sequence.

    Full-sequence mode walks every ``rgb`` folder under ``img_path``;
    otherwise a fixed (or user-supplied) subset of the 110-frame sequences
    is evaluated.
    """
    if args.full_seq:
        for seq_dir in tqdm(sorted(glob.glob(f"{img_path}/*/"))):
            frames = sorted(glob.glob(f"{seq_dir}/rgb/*.png"))
            seq_name = os.path.basename(os.path.dirname(seq_dir))
            yield frames, f"{args.output_dir}/{seq_name}"
    else:
        if args.seq_list is None:
            seqs = ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"]
        else:
            seqs = args.seq_list
        for seq in tqdm(seqs):
            frames = sorted(glob.glob(f"{img_path}/rgbd_bonn_{seq}/rgb_110/*.png"))
            yield frames, f"{args.output_dir}/{seq}"


def process_nyu(args, img_path):
    """Yield the single (sorted PNG list, output dir) pair for NYU-v2 val."""
    frames = sorted(glob.glob(f"{img_path}/*.png"))
    yield frames, f"{args.output_dir}"


def process_scannet(args, img_path):
    """Yield (sorted JPG list from color_90, output dir) per ScanNet scene."""
    for scene_dir in tqdm(sorted(glob.glob(f"{img_path}/*"))):
        frames = sorted(glob.glob(f"{scene_dir}/color_90/*.jpg"))
        yield frames, f"{args.output_dir}/{os.path.basename(scene_dir)}"


def process_sintel(args, img_path):
    """Yield (sorted PNG list, output dir) per Sintel sequence.

    Full-sequence mode walks every subdirectory; otherwise the standard
    14-sequence evaluation subset is used.
    """
    if args.full_seq:
        for seq_dir in tqdm(sorted(glob.glob(f"{img_path}/*/"))):
            frames = sorted(glob.glob(f"{seq_dir}/*.png"))
            seq_name = os.path.basename(os.path.dirname(seq_dir))
            yield frames, f"{args.output_dir}/{seq_name}"
    else:
        eval_subset = [
            "alley_2",
            "ambush_4",
            "ambush_5",
            "ambush_6",
            "cave_2",
            "cave_4",
            "market_2",
            "market_5",
            "market_6",
            "shaman_3",
            "sleeping_1",
            "sleeping_2",
            "temple_2",
            "temple_3",
        ]
        for seq in tqdm(eval_subset):
            frames = sorted(glob.glob(f"{img_path}/{seq}/*.png"))
            yield frames, f"{args.output_dir}/{seq}"

```

## /eval/video_depth/run_bonn.sh

```sh path="/eval/video_depth/run_bonn.sh" 
#!/bin/bash
# Video-depth evaluation on the long-sequence Bonn subsets: run inference
# once per (model, dataset) pair, then score the saved depth maps under
# three alignment modes.

set -e  # abort on the first failing command

workdir='.'
model_names=('ttt3r') # ttt3r cut3r
ckpt_name='cut3r_512_dpt_4_64'
model_weights="${workdir}/src/${ckpt_name}.pth"
# datasets=('bonn_s1_50' 'bonn_s1_100' 'bonn_s1_110' 'bonn_s1_150' 'bonn_s1_200' 'bonn_s1_250' 'bonn_s1_300' 'bonn_s1_350' 'bonn_s1_400' 'bonn_s1_450' 'bonn_s1_500')
datasets=('bonn_s1_500')


for model_name in "${model_names[@]}"; do
for data in "${datasets[@]}"; do
    output_dir="${workdir}/eval_results/video_depth/${data}/${model_name}"
    echo "$output_dir"

    # Inference: dump per-frame depth maps for the whole dataset.
    accelerate launch --num_processes 1 --main_process_port 29556 eval/video_depth/launch.py \
        --weights "$model_weights" \
        --output_dir "$output_dir" \
        --eval_dataset "$data" \
        --size 512 \
        --model_update_type "$model_name"

    # scale&shift scale metric
    # Scoring: metric (no alignment), scale-only, then scale&shift alignment.
    python eval/video_depth/eval_depth.py \
    --output_dir "$output_dir" \
    --eval_dataset "$data" \
    --align "metric"

    python eval/video_depth/eval_depth.py \
    --output_dir "$output_dir" \
    --eval_dataset "$data" \
    --align "scale"

    python eval/video_depth/eval_depth.py \
    --output_dir "$output_dir" \
    --eval_dataset "$data" \
    --align "scale&shift"
done
done

```

## /eval/video_depth/run_kitti.sh

```sh path="/eval/video_depth/run_kitti.sh" 
#!/bin/bash
# Video-depth evaluation on the long-sequence KITTI subsets: run inference
# once per (model, dataset) pair, then score the saved depth maps under
# three alignment modes.

set -e  # abort on the first failing command

workdir='.'
model_names=('ttt3r') # ttt3r cut3r
ckpt_name='cut3r_512_dpt_4_64'
model_weights="${workdir}/src/${ckpt_name}.pth"
# datasets=('kitti_s1_50' 'kitti_s1_100' 'kitti_s1_110' 'kitti_s1_150' 'kitti_s1_200' 'kitti_s1_250' 'kitti_s1_300' 'kitti_s1_350' 'kitti_s1_400' 'kitti_s1_450' 'kitti_s1_500')
datasets=('kitti_s1_500')


for model_name in "${model_names[@]}"; do
for data in "${datasets[@]}"; do
    output_dir="${workdir}/eval_results/video_depth/${data}/${model_name}"
    echo "$output_dir"

    # Inference: dump per-frame depth maps for the whole dataset.
    accelerate launch --num_processes 1 --main_process_port 29555 eval/video_depth/launch.py \
        --weights "$model_weights" \
        --output_dir "$output_dir" \
        --eval_dataset "$data" \
        --size 512 \
        --model_update_type "$model_name"

    # scale&shift scale metric
    # Scoring: metric (no alignment), scale-only, then scale&shift alignment.
    python eval/video_depth/eval_depth.py \
    --output_dir "$output_dir" \
    --eval_dataset "$data" \
    --align "metric"

    python eval/video_depth/eval_depth.py \
    --output_dir "$output_dir" \
    --eval_dataset "$data" \
    --align "scale"

    python eval/video_depth/eval_depth.py \
    --output_dir "$output_dir" \
    --eval_dataset "$data" \
    --align "scale&shift"
done
done

```

## /eval/video_depth/run_sintel.sh

```sh path="/eval/video_depth/run_sintel.sh" 
#!/bin/bash
# Video-depth evaluation on Sintel: run inference once per model, then score
# the saved depth maps under three alignment modes.

set -e  # abort on the first failing command

workdir='.'
model_names=('ttt3r') # ttt3r cut3r

ckpt_name='cut3r_512_dpt_4_64'
model_weights="${workdir}/src/${ckpt_name}.pth"
datasets=('sintel')

for model_name in "${model_names[@]}"; do
for data in "${datasets[@]}"; do
    output_dir="${workdir}/eval_results/video_depth/${data}/${model_name}"
    echo "$output_dir"

    # Inference: dump per-frame depth maps for the whole dataset.
    accelerate launch --num_processes 1 --main_process_port 29555 eval/video_depth/launch.py \
        --weights "$model_weights" \
        --output_dir "$output_dir" \
        --eval_dataset "$data" \
        --size 512 \
        --model_update_type "$model_name"

    # scale&shift scale metric
    # Scoring: metric (no alignment), scale-only, then scale&shift alignment.
    python eval/video_depth/eval_depth.py \
    --output_dir "$output_dir" \
    --eval_dataset "$data" \
    --align "metric"

    python eval/video_depth/eval_depth.py \
    --output_dir "$output_dir" \
    --eval_dataset "$data" \
    --align "scale"

    python eval/video_depth/eval_depth.py \
    --output_dir "$output_dir" \
    --eval_dataset "$data" \
    --align "scale&shift"
done
done

```

## /eval/video_depth/tools.py

```py path="/eval/video_depth/tools.py" 
import torch
import numpy as np
import cv2
import glob
import argparse
from pathlib import Path
from tqdm import tqdm
from copy import deepcopy
from scipy.optimize import minimize
import os
from collections import defaultdict


def group_by_directory(pathes, idx=-1):
    """Group file paths by one component of their directory path.

    Parameters:
    - pathes (list): List of file paths (with '/' separators).
    - idx (int): Index into the '/'-separated directory components; the
      default -1 selects the immediate parent directory. (The previous
      docstring incorrectly said "second-to-last directory".)

    Returns:
    - dict: Maps the selected directory name to the list of its file paths.
    """
    grouped_pathes = defaultdict(list)

    for path in pathes:
        # Pick the requested component of the directory part of the path.
        dir_name = os.path.dirname(path).split("/")[idx]
        grouped_pathes[dir_name].append(path)

    return grouped_pathes


def depth2disparity(depth, return_mask=False):
    """Convert a depth map to disparity (1/depth), zero where depth <= 0.

    Args:
        depth: torch.Tensor or np.ndarray of depth values.
        return_mask: if True, also return the boolean ``depth > 0`` mask.

    Returns:
        Disparity of the same type/shape as the input, optionally paired
        with the validity mask.

    Raises:
        TypeError: if ``depth`` is neither a torch.Tensor nor an ndarray
        (previously this fell through to an unbound local -> NameError).
    """
    if isinstance(depth, torch.Tensor):
        disparity = torch.zeros_like(depth)
    elif isinstance(depth, np.ndarray):
        disparity = np.zeros_like(depth)
    else:
        raise TypeError(f"unsupported depth type: {type(depth).__name__}")
    positive_mask = depth > 0
    # Invert only strictly-positive depths; zeros/negatives stay 0 disparity.
    disparity[positive_mask] = 1.0 / depth[positive_mask]
    if return_mask:
        return disparity, positive_mask
    return disparity


def absolute_error_loss(params, predicted_depth, ground_truth_depth):
    """Total L1 error of ``scale * pred + shift`` against the ground truth."""
    scale, shift = params
    residual = scale * predicted_depth + shift - ground_truth_depth
    return np.sum(np.abs(residual))


def absolute_value_scaling(predicted_depth, ground_truth_depth, s=1, t=0):
    """Fit scale/shift minimising the summed L1 error via scipy's minimize.

    Args:
        predicted_depth, ground_truth_depth: torch tensors (any shape).
        s, t: initial guesses for scale and shift.

    Returns:
        (s, t) optimised parameters.
    """
    pred_flat = predicted_depth.cpu().numpy().reshape(-1)
    gt_flat = ground_truth_depth.cpu().numpy().reshape(-1)

    fit = minimize(
        absolute_error_loss,
        [s, t],  # initial [scale, shift]
        args=(pred_flat, gt_flat),
    )

    return fit.x[0], fit.x[1]


def absolute_value_scaling2(
    predicted_depth,
    ground_truth_depth,
    s_init=1.0,
    t_init=0.0,
    lr=1e-4,
    max_iters=1000,
    tol=1e-6,
):
    """Fit scale/shift minimising the summed L1 error with Adam.

    Runs at most ``max_iters`` gradient steps, stopping early once the loss
    change between consecutive iterations drops below ``tol``.

    Returns:
        (s, t) as Python floats.
    """
    device = predicted_depth.device
    dtype = predicted_depth.dtype
    # Scale and shift are the two optimisable parameters.
    s = torch.tensor([s_init], requires_grad=True, device=device, dtype=dtype)
    t = torch.tensor([t_init], requires_grad=True, device=device, dtype=dtype)

    optimizer = torch.optim.Adam([s, t], lr=lr)
    last_loss = None

    for _ in range(max_iters):
        optimizer.zero_grad()
        loss = torch.sum(torch.abs(s * predicted_depth + t - ground_truth_depth))
        loss.backward()
        optimizer.step()

        # Convergence test compares this iteration's loss with the previous.
        if last_loss is not None and torch.abs(last_loss - loss) < tol:
            break
        last_loss = loss.item()

    return s.detach().item(), t.detach().item()


def depth_evaluation(
    predicted_depth_original,
    ground_truth_depth_original,
    max_depth=80,
    custom_mask=None,
    post_clip_min=None,
    post_clip_max=None,
    pre_clip_min=None,
    pre_clip_max=None,
    align_with_lstsq=False,
    align_with_lad=False,
    align_with_lad2=False,
    metric_scale=False,
    lr=1e-4,
    max_iters=1000,
    use_gpu=False,
    align_with_scale=False,
    disp_input=False,
):
    """
    Evaluate the depth map using various metrics and return a depth error parity map, with an option for least squares alignment.

    Alignment modes (mutually exclusive, checked in this order):
      metric_scale      – no alignment; predictions are metric already.
      align_with_lstsq  – L2 scale+shift via np.linalg.lstsq.
      align_with_lad    – L1 scale+shift via scipy minimize (absolute_value_scaling).
      align_with_lad2   – L1 scale+shift via Adam (absolute_value_scaling2).
      align_with_scale  – scale-only, L1-robust Weiszfeld iterations.
      (default)         – median-ratio scale only.

    Args:
        predicted_depth (numpy.ndarray or torch.Tensor): The predicted depth map.
        ground_truth_depth (numpy.ndarray or torch.Tensor): The ground truth depth map.
        max_depth (float): The maximum depth value to consider. Default is 80 meters.
        custom_mask: optional extra validity mask (same shape as GT).
        pre_clip_*/post_clip_*: clamp predictions before/after alignment.
        disp_input: if True, predictions are disparities; alignment is done
            in disparity space and converted back afterwards.
        use_gpu: move the flattened tensors to CUDA for the computation.
        align_with_lstsq (bool): If True, perform least squares alignment of the predicted depth with ground truth.

    Returns:
        dict: A dictionary containing the evaluation metrics.
        torch.Tensor: The depth error parity map.
    """
    if isinstance(predicted_depth_original, np.ndarray):
        predicted_depth_original = torch.from_numpy(predicted_depth_original)
    if isinstance(ground_truth_depth_original, np.ndarray):
        ground_truth_depth_original = torch.from_numpy(ground_truth_depth_original)
    if custom_mask is not None and isinstance(custom_mask, np.ndarray):
        custom_mask = torch.from_numpy(custom_mask)

    # if the dimension is 3, flatten to 2d along the batch dimension
    if predicted_depth_original.dim() == 3:
        _, h, w = predicted_depth_original.shape
        predicted_depth_original = predicted_depth_original.view(-1, w)
        ground_truth_depth_original = ground_truth_depth_original.view(-1, w)
        if custom_mask is not None:
            custom_mask = custom_mask.view(-1, w)

    # put to device
    if use_gpu:
        predicted_depth_original = predicted_depth_original.cuda()
        ground_truth_depth_original = ground_truth_depth_original.cuda()

    # Filter out depths greater than max_depth
    if max_depth is not None:
        mask = (ground_truth_depth_original > 0) & (
            ground_truth_depth_original < max_depth
        )
    else:
        mask = ground_truth_depth_original > 0
    # From here on, predicted_depth / ground_truth_depth are 1-D tensors of
    # only the valid pixels; the *_original tensors keep the full image.
    predicted_depth = predicted_depth_original[mask]
    ground_truth_depth = ground_truth_depth_original[mask]

    # Clip the depth values
    if pre_clip_min is not None:
        predicted_depth = torch.clamp(predicted_depth, min=pre_clip_min)
    if pre_clip_max is not None:
        predicted_depth = torch.clamp(predicted_depth, max=pre_clip_max)

    if disp_input:  # align the pred to gt in the disparity space
        real_gt = ground_truth_depth.clone()
        ground_truth_depth = 1 / (ground_truth_depth + 1e-8)

    # various alignment methods
    # NOTE(review): metrics below assume at least one valid pixel; with an
    # empty mask the median/mean calls produce NaNs before the
    # num_valid_pixels == 0 zeroing at the end.
    if metric_scale:
        predicted_depth = predicted_depth
    elif align_with_lstsq:
        # Convert to numpy for lstsq
        predicted_depth_np = predicted_depth.cpu().numpy().reshape(-1, 1)
        ground_truth_depth_np = ground_truth_depth.cpu().numpy().reshape(-1, 1)

        # Add a column of ones for the shift term
        A = np.hstack([predicted_depth_np, np.ones_like(predicted_depth_np)])

        # Solve for scale (s) and shift (t) using least squares
        result = np.linalg.lstsq(A, ground_truth_depth_np, rcond=None)
        s, t = result[0][0], result[0][1]

        # convert to torch tensor
        s = torch.tensor(s, device=predicted_depth_original.device)
        t = torch.tensor(t, device=predicted_depth_original.device)

        # Apply scale and shift
        predicted_depth = s * predicted_depth + t
    elif align_with_lad:
        s, t = absolute_value_scaling(
            predicted_depth,
            ground_truth_depth,
            s=torch.median(ground_truth_depth) / torch.median(predicted_depth),
        )
        predicted_depth = s * predicted_depth + t
    elif align_with_lad2:
        # Initialise the scale from the median ratio, then refine with Adam.
        s_init = (
            torch.median(ground_truth_depth) / torch.median(predicted_depth)
        ).item()
        s, t = absolute_value_scaling2(
            predicted_depth,
            ground_truth_depth,
            s_init=s_init,
            lr=lr,
            max_iters=max_iters,
        )
        predicted_depth = s * predicted_depth + t
    elif align_with_scale:
        # Compute initial scale factor 's' using the closed-form solution (L2 norm)
        dot_pred_gt = torch.nanmean(ground_truth_depth)
        dot_pred_pred = torch.nanmean(predicted_depth)
        s = dot_pred_gt / dot_pred_pred

        # Iterative reweighted least squares using the Weiszfeld method
        for _ in range(10):
            # Compute residuals between scaled predictions and ground truth
            residuals = s * predicted_depth - ground_truth_depth
            abs_residuals = (
                residuals.abs() + 1e-8
            )  # Add small constant to avoid division by zero

            # Compute weights inversely proportional to the residuals
            weights = 1.0 / abs_residuals

            # Update 's' using weighted sums
            weighted_dot_pred_gt = torch.sum(
                weights * predicted_depth * ground_truth_depth
            )
            weighted_dot_pred_pred = torch.sum(weights * predicted_depth**2)
            s = weighted_dot_pred_gt / weighted_dot_pred_pred

        # Optionally clip 's' to prevent extreme scaling
        s = s.clamp(min=1e-3)

        # Detach 's' if you want to stop gradients from flowing through it
        s = s.detach()

        # Apply the scale factor to the predicted depth
        predicted_depth = s * predicted_depth

    else:
        # Align the predicted depth with the ground truth using median scaling
        scale_factor = torch.median(ground_truth_depth) / torch.median(predicted_depth)
        predicted_depth *= scale_factor

    if disp_input:
        # convert back to depth
        ground_truth_depth = real_gt
        predicted_depth = depth2disparity(predicted_depth)

    # Clip the predicted depth values
    if post_clip_min is not None:
        predicted_depth = torch.clamp(predicted_depth, min=post_clip_min)
    if post_clip_max is not None:
        predicted_depth = torch.clamp(predicted_depth, max=post_clip_max)

    if custom_mask is not None:
        # Restrict metrics to pixels valid under BOTH masks.
        # NOTE(review): custom_mask is moved to CPU but `mask` may be on GPU
        # when use_gpu=True — confirm this combination is exercised.
        assert custom_mask.shape == ground_truth_depth_original.shape
        mask_within_mask = custom_mask.cpu()[mask]
        predicted_depth = predicted_depth[mask_within_mask]
        ground_truth_depth = ground_truth_depth[mask_within_mask]

    # Calculate the metrics
    abs_rel = torch.mean(
        torch.abs(predicted_depth - ground_truth_depth) / ground_truth_depth
    ).item()
    sq_rel = torch.mean(
        ((predicted_depth - ground_truth_depth) ** 2) / ground_truth_depth
    ).item()

    # Correct RMSE calculation
    rmse = torch.sqrt(torch.mean((predicted_depth - ground_truth_depth) ** 2)).item()

    # Clip the depth values to avoid log(0)
    predicted_depth = torch.clamp(predicted_depth, min=1e-5)
    log_rmse = torch.sqrt(
        torch.mean((torch.log(predicted_depth) - torch.log(ground_truth_depth)) ** 2)
    ).item()

    # Calculate the accuracy thresholds
    max_ratio = torch.maximum(
        predicted_depth / ground_truth_depth, ground_truth_depth / predicted_depth
    )
    threshold_0 = torch.mean((max_ratio < 1.0).float()).item()
    threshold_1 = torch.mean((max_ratio < 1.25).float()).item()
    threshold_2 = torch.mean((max_ratio < 1.25**2).float()).item()
    threshold_3 = torch.mean((max_ratio < 1.25**3).float()).item()

    # Compute the depth error parity map
    # Each branch re-applies the alignment found above to the FULL image so
    # the error map covers every pixel, not just the masked ones.
    if metric_scale:
        predicted_depth_original = predicted_depth_original
        if disp_input:
            predicted_depth_original = depth2disparity(predicted_depth_original)
        depth_error_parity_map = (
            torch.abs(predicted_depth_original - ground_truth_depth_original)
            / ground_truth_depth_original
        )
    elif align_with_lstsq or align_with_lad or align_with_lad2:
        predicted_depth_original = predicted_depth_original * s + t
        if disp_input:
            predicted_depth_original = depth2disparity(predicted_depth_original)
        depth_error_parity_map = (
            torch.abs(predicted_depth_original - ground_truth_depth_original)
            / ground_truth_depth_original
        )
    elif align_with_scale:
        predicted_depth_original = predicted_depth_original * s
        if disp_input:
            predicted_depth_original = depth2disparity(predicted_depth_original)
        depth_error_parity_map = (
            torch.abs(predicted_depth_original - ground_truth_depth_original)
            / ground_truth_depth_original
        )
    else:
        # `scale_factor` is only defined in the median-scaling branch above,
        # which is exactly the branch that reaches this else.
        predicted_depth_original = predicted_depth_original * scale_factor
        if disp_input:
            predicted_depth_original = depth2disparity(predicted_depth_original)
        depth_error_parity_map = (
            torch.abs(predicted_depth_original - ground_truth_depth_original)
            / ground_truth_depth_original
        )

    # Reshape the depth_error_parity_map back to the original image size
    depth_error_parity_map_full = torch.zeros_like(ground_truth_depth_original)
    depth_error_parity_map_full = torch.where(
        mask, depth_error_parity_map, depth_error_parity_map_full
    )

    predict_depth_map_full = predicted_depth_original
    gt_depth_map_full = torch.zeros_like(ground_truth_depth_original)
    gt_depth_map_full = torch.where(
        mask, ground_truth_depth_original, gt_depth_map_full
    )

    num_valid_pixels = (
        torch.sum(mask).item()
        if custom_mask is None
        else torch.sum(mask_within_mask).item()
    )
    # With no valid pixels, zero out all metrics instead of returning NaNs.
    if num_valid_pixels == 0:
        (
            abs_rel,
            sq_rel,
            rmse,
            log_rmse,
            threshold_0,
            threshold_1,
            threshold_2,
            threshold_3,
        ) = (0, 0, 0, 0, 0, 0, 0, 0)

    results = {
        "Abs Rel": abs_rel,
        "Sq Rel": sq_rel,
        "RMSE": rmse,
        "Log RMSE": log_rmse,
        "δ < 1.": threshold_0,
        "δ < 1.25": threshold_1,
        "δ < 1.25^2": threshold_2,
        "δ < 1.25^3": threshold_3,
        "valid_pixels": num_valid_pixels,
    }

    return (
        results,
        depth_error_parity_map_full,
        predict_depth_map_full,
        gt_depth_map_full,
    )

```

## /eval/video_depth/utils.py

```py path="/eval/video_depth/utils.py" 
from copy import deepcopy
import cv2

import numpy as np
import torch
import torch.nn as nn
import roma
from copy import deepcopy
import tqdm
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from scipy.spatial.transform import Rotation
from PIL import Image
import imageio.v2 as iio
from matplotlib.figure import Figure


def save_focals(cam_dict, path):
    """Persist the per-frame focal lengths from ``cam_dict`` to a text file.

    One focal value is written per line in 6-decimal fixed-point format.
    Returns the focal array unchanged so calls can be chained.
    """
    focal_values = cam_dict["focal"]
    np.savetxt(path, focal_values, fmt="%.6f")
    return focal_values


def save_intrinsics(cam_dict, path):
    """Assemble per-frame 3x3 pinhole intrinsics and save them as text.

    The shared per-frame focal fills fx and fy, and ``cam_dict["pp"]`` is the
    principal point. Each saved row holds one flattened 3x3 matrix (9 values).
    Returns the [N, 3, 3] intrinsics array.
    """
    num_frames = len(cam_dict["focal"])
    intrinsics = np.tile(np.eye(3), (num_frames, 1, 1))
    intrinsics[:, 0, 0] = cam_dict["focal"]
    intrinsics[:, 1, 1] = cam_dict["focal"]
    intrinsics[:, :2, 2] = cam_dict["pp"]
    np.savetxt(path, intrinsics.reshape(num_frames, 9), fmt="%.6f")
    return intrinsics


def save_conf_maps(conf, path):
    """Dump every confidence map to ``<path>/conf_<i>.npy`` and return the input.

    Each tensor is detached and moved to CPU before being serialized.
    """
    for idx, conf_map in enumerate(conf):
        np.save(f"{path}/conf_{idx}.npy", conf_map.detach().cpu().numpy())
    return conf


def save_rgb_imgs(colors, path):
    """Write each RGB frame as ``<path>/frame_XXXX.jpg`` and return the input.

    Frames are expected as float tensors in [0, 1]; they are scaled to uint8
    before encoding. Images are written in RGB order (no channel swap happens
    here).
    """
    for idx, frame in enumerate(colors):
        frame_u8 = (frame.cpu().numpy() * 255).astype(np.uint8)
        iio.imwrite(f"{path}/frame_{idx:04d}.jpg", frame_u8)
    return colors


def save_depth_maps(pts3ds_self, path, conf_self=None):
    """Colorize and save per-frame depth maps (and optional confidence maps).

    For every frame i this writes:
      - ``<path>/frame_XXXX.png``: the depth map colorized with a shared
        global (min, max) range; when ``conf_self`` is given, the log-scaled
        colorized confidence map is concatenated below it.
      - ``<path>/frame_XXXX.npy``: the raw depth values.

    :param pts3ds_self: iterable of per-frame point maps; the last channel is
        taken as depth (z).
    :param path: output directory (must already exist).
    :param conf_self: optional iterable of per-frame confidence maps.
    :return: the stacked raw depth maps tensor, [N, H, W].
    """
    # Depth is the last channel of each self-frame point map.
    depth_maps = torch.stack([pts3d_self[..., -1] for pts3d_self in pts3ds_self], 0)
    # One global range across all frames so the colorization is temporally
    # consistent (quantile-based clipping was considered but left disabled).
    min_depth = depth_maps.min()  # float(torch.quantile(out, 0.01))
    max_depth = depth_maps.max()  # float(torch.quantile(out, 0.99))
    colored_depth = colorize(
        depth_maps,
        cmap_name="Spectral_r",
        range=(min_depth, max_depth),
        append_cbar=True,
    )
    images = []

    if conf_self is not None:
        # Confidence is visualized in log space to compress its dynamic range.
        conf_selfs = torch.concat(conf_self, 0)
        min_conf = torch.log(conf_selfs.min())  # float(torch.quantile(out, 0.01))
        max_conf = torch.log(conf_selfs.max())  # float(torch.quantile(out, 0.99))
        colored_conf = colorize(
            torch.log(conf_selfs),
            cmap_name="jet",
            range=(min_conf, max_conf),
            append_cbar=True,
        )

    for i, depth_map in enumerate(colored_depth):
        # Apply color map to depth map
        img_path = f"{path}/frame_{(i):04d}.png"
        if conf_self is None:
            to_save = (depth_map * 255).detach().cpu().numpy().astype(np.uint8)
        else:
            # Stack depth (top) and confidence (bottom) into one image.
            to_save = torch.cat([depth_map, colored_conf[i]], dim=1)
            to_save = (to_save * 255).detach().cpu().numpy().astype(np.uint8)
        iio.imwrite(img_path, to_save)
        images.append(Image.open(img_path))
        # Also keep the raw (un-colorized) depth values for later evaluation.
        np.save(f"{path}/frame_{(i):04d}.npy", depth_maps[i].detach().cpu().numpy())

    # comment this as it may fail sometimes
    # images[0].save(f'{path}/_depth_maps.gif', save_all=True, append_images=images[1:], duration=100, loop=0)

    return depth_maps


def get_vertical_colorbar(h, vmin, vmax, cmap_name="jet", label=None, cbar_precision=2):
    """Render a standalone vertical color bar as an RGB float image.

    :param h: target height in pixels; the bar is resized to match it
    :param vmin: minimum of the value range shown on the bar
    :param vmax: maximum of the value range shown on the bar
    :param cmap_name: matplotlib colormap name
    :param label: optional text label drawn next to the bar
    :param cbar_precision: number of decimals on the tick labels
    :return: float32 RGB image in [0, 1], shape [h, w, 3]
    """
    # Render off-screen via the Agg canvas so no display is required.
    fig = Figure(figsize=(2, 8), dpi=100)
    fig.subplots_adjust(right=1.5)
    canvas = FigureCanvasAgg(fig)

    # Do some plotting.
    ax = fig.add_subplot(111)
    # NOTE(review): cm.get_cmap and mpl.colorbar.ColorbarBase are deprecated
    # in recent matplotlib releases — confirm the pinned matplotlib version.
    cmap = cm.get_cmap(cmap_name)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

    tick_cnt = 6
    tick_loc = np.linspace(vmin, vmax, tick_cnt)
    cb1 = mpl.colorbar.ColorbarBase(
        ax, cmap=cmap, norm=norm, ticks=tick_loc, orientation="vertical"
    )

    tick_label = [str(np.round(x, cbar_precision)) for x in tick_loc]
    if cbar_precision == 0:
        # Strip the trailing ".0" that np.round leaves on integer labels.
        tick_label = [x[:-2] for x in tick_label]

    cb1.set_ticklabels(tick_label)

    cb1.ax.tick_params(labelsize=18, rotation=0)
    if label is not None:
        cb1.set_label(label)

    # fig.tight_layout()

    # Rasterize the figure and read back the RGBA pixel buffer.
    canvas.draw()
    s, (width, height) = canvas.print_to_buffer()

    im = np.frombuffer(s, np.uint8).reshape((height, width, 4))

    # Drop alpha and normalize to [0, 1] floats.
    im = im[:, :, :3].astype(np.float32) / 255.0
    if h != im.shape[0]:
        # Preserve aspect ratio while matching the requested height.
        w = int(im.shape[1] / im.shape[0] * h)
        im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA)

    return im


def colorize_np(
    x,
    cmap_name="jet",
    mask=None,
    range=None,
    append_cbar=False,
    cbar_in_image=False,
    cbar_precision=2,
):
    """
    Turn a grayscale image into a color image.

    :param x: input grayscale, [H, W]
    :param cmap_name: the colorization method (matplotlib colormap name)
    :param mask: the validity mask image, [H, W]; invalid pixels render white
    :param range: the range for scaling, automatic if None, [min, max]
    :param append_cbar: if append the color bar
    :param cbar_in_image: put the color bar inside the image to keep the
        output image the same size as the input image
    :param cbar_precision: decimals shown on the color-bar tick labels
    :return: colorized image, [H, W(, +cbar), 3] float in [0, 1]
    """
    if range is not None:
        vmin, vmax = range
    elif mask is not None:
        # Smallest non-zero valid value as vmin so zero padding does not
        # dominate the scaling.
        vmin = np.min(x[mask][np.nonzero(x[mask])])
        vmax = np.max(x[mask])
        # Copy before writing: the original mutated the caller's array here.
        x = x.copy()
        x[np.logical_not(mask)] = vmin
    else:
        vmin, vmax = np.percentile(x, (1, 100))
        vmax += 1e-6  # guard against vmax == vmin on constant inputs

    # Normalize to [0, 1] and map through the colormap, dropping alpha.
    x = np.clip(x, vmin, vmax)
    x = (x - vmin) / (vmax - vmin)

    cmap = cm.get_cmap(cmap_name)
    x_new = cmap(x)[:, :, :3]

    if mask is not None:
        # Blend invalid pixels towards white.
        mask = np.float32(mask[:, :, np.newaxis])
        x_new = x_new * mask + np.ones_like(x_new) * (1.0 - mask)

    if not append_cbar:
        return x_new

    # Only build the matplotlib-backed color bar when it is actually used;
    # the original computed it unconditionally and discarded it.
    cbar = get_vertical_colorbar(
        h=x.shape[0],
        vmin=vmin,
        vmax=vmax,
        cmap_name=cmap_name,
        cbar_precision=cbar_precision,
    )
    if cbar_in_image:
        # Overwrite the right edge so the output keeps the input width.
        x_new[:, -cbar.shape[1] :, :] = cbar
    else:
        # Widen the image: content | 5-px black separator | color bar.
        x_new = np.concatenate((x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1)
    return x_new


# tensor
# tensor
def colorize(
    x, cmap_name="jet", mask=None, range=None, append_cbar=False, cbar_in_image=False
):
    """
    turn a grayscale image into a color image

    Wraps colorize_np for torch tensors, handling an optional leading batch
    dimension; the result is returned on the input tensor's device.

    :param x: torch.Tensor, grayscale image, [H, W] or [B, H, W]
    :param mask: torch.Tensor or None, mask image, [H, W] or [B, H, W] or None
    """

    device = x.device
    x = x.cpu().numpy()
    if mask is not None:
        # Threshold soft masks to boolean validity.
        mask = mask.cpu().numpy() > 0.99
        kernel = np.ones((3, 3), np.uint8)

    # Promote single images to a batch of one for the uniform loop below.
    if x.ndim == 2:
        x = x[None]
        if mask is not None:
            mask = mask[None]

    out = []
    for x_ in x:
        if mask is not None:
            # Erode the mask to shave off unreliable boundary pixels.
            # NOTE(review): this erodes the SAME mask array on every loop
            # iteration (cumulative erosion for batches) and passes the full
            # batched mask to colorize_np rather than the i-th slice —
            # presumably only exercised with B == 1; confirm against callers.
            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=1).astype(bool)

        x_ = colorize_np(x_, cmap_name, mask, range, append_cbar, cbar_in_image)
        out.append(torch.from_numpy(x_).to(device).float())
    # Drop the batch dimension again when the input was a single [H, W] image.
    out = torch.stack(out).squeeze(0)
    return out

```

## /examples/taylor.mp4

Binary file available at https://raw.githubusercontent.com/Inception3D/TTT3R/refs/heads/main/examples/taylor.mp4

## /examples/westlake.mp4

Binary file available at https://raw.githubusercontent.com/Inception3D/TTT3R/refs/heads/main/examples/westlake.mp4

## /requirements.txt

numpy==1.26.4
torch
torchvision
roma
gradio
matplotlib
tqdm
opencv-python
scipy
einops
trimesh
tensorboard
pyglet<2
huggingface-hub[torch]>=0.22
viser
lpips
hydra-core
pillow==10.3.0
h5py
accelerate
transformers
scikit-learn

## /src/croco/LICENSE

``` path="/src/croco/LICENSE" 
CroCo, Copyright (c) 2022-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.

A summary of the CC BY-NC-SA 4.0 license is located here:
	https://creativecommons.org/licenses/by-nc-sa/4.0/

The CC BY-NC-SA 4.0 license is located here:
	https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
	
	
SEE NOTICE BELOW WITH RESPECT TO THE FILE: models/pos_embed.py, models/blocks.py

***************************

NOTICE WITH RESPECT TO THE FILE: models/pos_embed.py

This software is being redistributed in a modifiled form. The original form is available here:

https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py

This software in this file incorporates parts of the following software available here:

Transformer: https://github.com/tensorflow/models/blob/master/official/legacy/transformer/model_utils.py
available under the following license: https://github.com/tensorflow/models/blob/master/LICENSE

MoCo v3: https://github.com/facebookresearch/moco-v3
available under the following license: https://github.com/facebookresearch/moco-v3/blob/main/LICENSE

DeiT: https://github.com/facebookresearch/deit
available under the following license: https://github.com/facebookresearch/deit/blob/main/LICENSE


ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW:

https://github.com/facebookresearch/mae/blob/main/LICENSE

Attribution-NonCommercial 4.0 International

***************************

NOTICE WITH RESPECT TO THE FILE: models/blocks.py

This software is being redistributed in a modifiled form. The original form is available here:

https://github.com/rwightman/pytorch-image-models

ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW:

https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
```

## /src/croco/NOTICE

``` path="/src/croco/NOTICE" 
CroCo
Copyright 2022-present NAVER Corp.

This project contains subcomponents with separate copyright notices and license terms. 
Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.

====

facebookresearch/mae
https://github.com/facebookresearch/mae

Attribution-NonCommercial 4.0 International

====

rwightman/pytorch-image-models
https://github.com/rwightman/pytorch-image-models

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
```

## /src/croco/README.MD

# CroCo + CroCo v2 / CroCo-Stereo / CroCo-Flow

[[`CroCo arXiv`](https://arxiv.org/abs/2210.10716)] [[`CroCo v2 arXiv`](https://arxiv.org/abs/2211.10408)] [[`project page and demo`](https://croco.europe.naverlabs.com/)]

This repository contains the code for our CroCo model presented in our NeurIPS'22 paper [CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion](https://openreview.net/pdf?id=wZEfHUM5ri) and its follow-up extension published at ICCV'23 [Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow](https://openaccess.thecvf.com/content/ICCV2023/html/Weinzaepfel_CroCo_v2_Improved_Cross-view_Completion_Pre-training_for_Stereo_Matching_and_ICCV_2023_paper.html), referred to as CroCo v2:

![image](assets/arch.jpg)

```bibtex
@inproceedings{croco,
  title={{CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion}},
  author={{Weinzaepfel, Philippe and Leroy, Vincent and Lucas, Thomas and Br\'egier, Romain and Cabon, Yohann and Arora, Vaibhav and Antsfeld, Leonid and Chidlovskii, Boris and Csurka, Gabriela and Revaud J\'er\^ome}},
  booktitle={{NeurIPS}},
  year={2022}
}

@inproceedings{croco_v2,
  title={{CroCo v2: Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow}},
  author={Weinzaepfel, Philippe and Lucas, Thomas and Leroy, Vincent and Cabon, Yohann and Arora, Vaibhav and Br{\'e}gier, Romain and Csurka, Gabriela and Antsfeld, Leonid and Chidlovskii, Boris and Revaud, J{\'e}r{\^o}me}, 
  booktitle={ICCV},
  year={2023}
}
```

## License

The code is distributed under the CC BY-NC-SA 4.0 License. See [LICENSE](LICENSE) for more information.
Some components are based on code from [MAE](https://github.com/facebookresearch/mae) released under the CC BY-NC-SA 4.0 License and [timm](https://github.com/rwightman/pytorch-image-models) released under the Apache 2.0 License.
Some components for stereo matching and optical flow are based on code from [unimatch](https://github.com/autonomousvision/unimatch) released under the MIT license.

## Preparation

1. Install dependencies on a machine with a NVidia GPU using e.g. conda. Note that `habitat-sim` is required only for the interactive demo and the synthetic pre-training data generation. If you don't plan to use it, you can ignore the line installing it and use a more recent python version.

```bash
conda create -n croco python=3.7 cmake=3.14.0
conda activate croco
conda install habitat-sim headless -c conda-forge -c aihabitat
conda install pytorch torchvision -c pytorch
conda install notebook ipykernel matplotlib
conda install ipywidgets widgetsnbextension
conda install scikit-learn tqdm quaternion opencv # only for pretraining / habitat data generation

```

2. Compile cuda kernels for RoPE

CroCo v2 relies on RoPE positional embeddings for which you need to compile some cuda kernels.
```bash
cd models/curope/
python setup.py build_ext --inplace
cd ../../
```

This can be a bit long as we compile for all cuda architectures, feel free to update L9 of `models/curope/setup.py` to compile for specific architectures only.
You might also need to set the environment `CUDA_HOME` in case you use a custom cuda installation.

In case you cannot compile the kernels, we also provide a slower pure-PyTorch version, which will be loaded automatically.

3. Download pre-trained model

We provide several pre-trained models:

| modelname                                                                                                                          | pre-training data | pos. embed. | Encoder | Decoder |
|------------------------------------------------------------------------------------------------------------------------------------|-------------------|-------------|---------|---------|
| [`CroCo.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth)                                                 | Habitat           | cosine      | ViT-B   | Small   |
| [`CroCo_V2_ViTBase_SmallDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_SmallDecoder.pth) | Habitat + real    | RoPE        | ViT-B   | Small   |
| [`CroCo_V2_ViTBase_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_BaseDecoder.pth)   | Habitat + real    | RoPE        | ViT-B   | Base    |
| [`CroCo_V2_ViTLarge_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth) | Habitat + real    | RoPE        | ViT-L   | Base    |

To download a specific model, i.e., the first one (`CroCo.pth`)
```bash
mkdir -p pretrained_models/
wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth -P pretrained_models/
```

## Reconstruction example

Simply run after downloading the `CroCo_V2_ViTLarge_BaseDecoder` pretrained model (or update the corresponding line in `demo.py`)
```bash
python demo.py
```

## Interactive demonstration of cross-view completion reconstruction on the Habitat simulator

First download the test scene from Habitat:
```bash
python -m habitat_sim.utils.datasets_download --uids habitat_test_scenes --data-path habitat-sim-data/
```

Then, run the Notebook demo `interactive_demo.ipynb`.

In this demo, you should be able to sample a random reference viewpoint from an [Habitat](https://github.com/facebookresearch/habitat-sim) test scene. Use the sliders to change viewpoint and select a masked target view to reconstruct using CroCo.
![croco_interactive_demo](https://user-images.githubusercontent.com/1822210/200516576-7937bc6a-55f8-49ed-8618-3ddf89433ea4.jpg)

## Pre-training 

### CroCo 

To pre-train CroCo, please first generate the pre-training data from the Habitat simulator, following the instructions in [datasets/habitat_sim/README.MD](datasets/habitat_sim/README.MD) and then run the following command:
```
torchrun --nproc_per_node=4 pretrain.py --output_dir ./output/pretraining/
```

Our CroCo pre-training was launched on a single server with 4 GPUs.
It should take around 10 days with A100 or 15 days with V100 to do the 400 pre-training epochs, but decent performances are obtained earlier in training.
Note that, while the code contains the same scaling rule of the learning rate as MAE when changing the effective batch size, we did not experiment to check whether it is valid in our case.
The first run can take a few minutes to start, to parse all available pre-training pairs.

### CroCo v2 

For CroCo v2 pre-training, in addition to the generation of the pre-training data from the Habitat simulator above, please pre-extract the crops from the real datasets following the instructions in [datasets/crops/README.MD](datasets/crops/README.MD).
Then, run the following command for the largest model (ViT-L encoder, Base decoder):
```
torchrun --nproc_per_node=8 pretrain.py --model "CroCoNet(enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_num_heads=12, dec_depth=12, pos_embed='RoPE100')" --dataset "habitat_release+ARKitScenes+MegaDepth+3DStreetView+IndoorVL" --warmup_epochs 12 --max_epoch 125 --epochs 250 --amp 0 --keep_freq 5 --output_dir ./output/pretraining_crocov2/
```

Our CroCo v2 pre-training was launched on a single server with 8 GPUs for the largest model, and on a single server with 4 GPUs for the smaller ones, keeping a batch size of 64 per gpu in all cases.
The largest model should take around 12 days on A100.
Note that, while the code contains the same scaling rule of the learning rate as MAE when changing the effective batch size, we did not experiment to check whether it is valid in our case.

## Stereo matching and Optical flow downstream tasks

For CroCo-Stereo and CroCo-Flow, please refer to [stereoflow/README.MD](stereoflow/README.MD).


## /src/croco/assets/Chateau1.png

Binary file available at https://raw.githubusercontent.com/Inception3D/TTT3R/refs/heads/main/src/croco/assets/Chateau1.png

## /src/croco/assets/Chateau2.png

Binary file available at https://raw.githubusercontent.com/Inception3D/TTT3R/refs/heads/main/src/croco/assets/Chateau2.png

## /src/croco/assets/arch.jpg

Binary file available at https://raw.githubusercontent.com/Inception3D/TTT3R/refs/heads/main/src/croco/assets/arch.jpg

## /src/croco/croco-stereo-flow-demo.ipynb

```ipynb path="/src/croco/croco-stereo-flow-demo.ipynb" 
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9bca0f41",
   "metadata": {},
   "source": [
    "# Simple inference example with CroCo-Stereo or CroCo-Flow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80653ef7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n",
    "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f033862",
   "metadata": {},
   "source": [
    "First download the model(s) of your choice by running\n",
    "\`\`\`\n",
    "bash stereoflow/download_model.sh crocostereo.pth\n",
    "bash stereoflow/download_model.sh crocoflow.pth\n",
    "\`\`\`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fb2e392",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n",
    "device = torch.device('cuda:0' if use_gpu else 'cpu')\n",
    "import matplotlib.pylab as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0e25d77",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stereoflow.test import _load_model_and_criterion\n",
    "from stereoflow.engine import tiled_pred\n",
    "from stereoflow.datasets_stereo import img_to_tensor, vis_disparity\n",
    "from stereoflow.datasets_flow import flowToColor\n",
    "tile_overlap=0.7 # recommended value, higher value can be slightly better but slower"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86a921f5",
   "metadata": {},
   "source": [
    "### CroCo-Stereo example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64e483cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "image1 = np.asarray(Image.open('<path_to_left_image>'))\n",
    "image2 = np.asarray(Image.open('<path_to_right_image>'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0d04303",
   "metadata": {},
   "outputs": [],
   "source": [
    "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocostereo.pth', None, device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47dc14b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
    "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
    "with torch.inference_mode():\n",
    "    pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
    "pred = pred.squeeze(0).squeeze(0).cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "583b9f16",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(vis_disparity(pred))\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2df5d70",
   "metadata": {},
   "source": [
    "### CroCo-Flow example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ee257a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "image1 = np.asarray(Image.open('<path_to_first_image>'))\n",
    "image2 = np.asarray(Image.open('<path_to_second_image>'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5edccf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocoflow.pth', None, device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b19692c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
    "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
    "with torch.inference_mode():\n",
    "    pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
    "pred = pred.squeeze(0).permute(1,2,0).cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26f79db3",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(flowToColor(pred))\n",
    "plt.axis('off')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

```

## /src/croco/datasets/__init__.py

```py path="/src/croco/datasets/__init__.py" 

```

## /src/croco/datasets/habitat_sim/__init__.py

```py path="/src/croco/datasets/habitat_sim/__init__.py" 

```

## /src/croco/datasets/habitat_sim/generate_from_metadata_files.py

```py path="/src/croco/datasets/habitat_sim/generate_from_metadata_files.py" 
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

"""
Script generating commandlines to generate image pairs from metadata files.
"""
import os
import glob
from tqdm import tqdm
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", required=True)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument(
        "--prefix",
        default="",
        # Fixed typo in user-facing help text ("Commanline").
        help="Command-line prefix, useful e.g. to set up the environment.",
    )
    args = parser.parse_args()

    # Lazily find every metadata.json anywhere under the input tree.
    input_metadata_filenames = glob.iglob(
        f"{args.input_dir}/**/metadata.json", recursive=True
    )

    for metadata_filename in tqdm(input_metadata_filenames):
        # Mirror the input directory layout under the output directory.
        output_dir = os.path.join(
            args.output_dir,
            os.path.relpath(os.path.dirname(metadata_filename), args.input_dir),
        )
        # Do not process the scene if the metadata file already exists
        if os.path.exists(os.path.join(output_dir, "metadata.json")):
            continue
        commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}"
        print(commandline)
```

## /src/croco/models/curope/__init__.py

```py path="/src/croco/models/curope/__init__.py" 
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

from .curope2d import cuRoPE2D

```

## /src/croco/stereoflow/download_model.sh

```sh path="/src/croco/stereoflow/download_model.sh" 
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

# Download the requested StereoFlow checkpoint unless it is already cached.
model=$1
outfile="stereoflow_models/${model}"
if [[ -f "$outfile" ]]
then
	echo "Model ${model} already downloaded in ${outfile}."
else
	mkdir -p stereoflow_models/;
	wget "https://download.europe.naverlabs.com/ComputerVision/CroCo/StereoFlow_models/${model}" -P stereoflow_models/;
fi
```

## /src/dust3r/__init__.py

```py path="/src/dust3r/__init__.py" 

```

## /src/dust3r/datasets/base/__init__.py

```py path="/src/dust3r/datasets/base/__init__.py" 

```


The content has been capped at 50000 tokens. The user could consider applying other filters to refine the result. The better and more specific the context, the better the LLM can follow instructions. If the context seems verbose, the user can refine the filter using uithub. Thank you for using https://uithub.com - Perfect LLM context for any GitHub repo.
Copied!