ByteDance-Seed/TraceAnything/main (27k tokens)
```
├── LICENSE.txt (2.3k tokens)
├── LICENSE_MODEL.txt (100 tokens)
├── README.md (1800 tokens)
├── assets/
   ├── interactive_monkeys.gif
   ├── panel_screenshot.png
   ├── teaser.png
├── configs/
   ├── eval.yaml (400 tokens)
├── examples/
   ├── input/
      ├── cabinet/
         ├── 000.png
         ├── 001.png
         ├── 002.png
         ├── 003.png
         ├── 004.png
         ├── 005.png
         ├── 006.png
         ├── 007.png
         ├── 008.png
         ├── 009.png
         ├── 010.png
      ├── corgi/
         ├── 000.png
         ├── 001.png
      ├── dog/
         ├── 000.png
         ├── 001.png
      ├── elephant/
         ├── 000.png
         ├── 001.png
         ├── 002.png
         ├── 003.png
         ├── 004.png
         ├── 005.png
         ├── 006.png
         ├── 007.png
         ├── 008.png
         ├── 009.png
         ├── 010.png
         ├── 011.png
         ├── 012.png
         ├── 013.png
         ├── 014.png
         ├── 015.png
         ├── 016.png
         ├── 017.png
         ├── 018.png
         ├── 019.png
         ├── 020.png
         ├── 021.png
         ├── 022.png
         ├── 023.png
         ├── 024.png
         ├── 025.png
         ├── 026.png
         ├── 027.png
         ├── 028.png
         ├── 029.png
      ├── monkeys/
         ├── 000.png
         ├── 001.png
         ├── 002.png
         ├── 003.png
         ├── 004.png
         ├── 005.png
         ├── 006.png
         ├── 007.png
         ├── 008.png
         ├── 009.png
         ├── 010.png
         ├── 011.png
         ├── 012.png
         ├── 013.png
         ├── 014.png
         ├── 015.png
         ├── 016.png
         ├── 017.png
         ├── 018.png
         ├── 019.png
         ├── 020.png
         ├── 021.png
         ├── 022.png
         ├── 023.png
         ├── 024.png
         ├── 025.png
         ├── 026.png
         ├── 027.png
         ├── 028.png
         ├── 029.png
         ├── 030.png
      ├── parkour/
         ├── 000.png
         ├── 001.png
├── scripts/
   ├── infer.py (2k tokens)
   ├── user_mask.py (1700 tokens)
   ├── view.py (5.2k tokens)
├── trace_anything/
   ├── __init__.py (100 tokens)
   ├── heads.py (1500 tokens)
   ├── layers/
      ├── __init__.py (100 tokens)
      ├── blocks.py (3.3k tokens)
      ├── dpt_block.py (3.2k tokens)
      ├── patch_embed.py (800 tokens)
      ├── pos_embed.py (1500 tokens)
   ├── trace_anything.py (3.2k tokens)
```


## /LICENSE.txt

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


## /LICENSE_MODEL.txt

Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0)

The model weights, including all .pth, .pt, .bin, .safetensors, or other checkpoint files
released with this repository, are licensed under the Creative Commons
Attribution-NonCommercial 4.0 International License.

You may obtain a copy of this license at:
    https://creativecommons.org/licenses/by-nc/4.0/

You are free to share and adapt the model weights for non-commercial purposes,
provided that you give appropriate credit and indicate if changes were made.

Commercial use of the model weights is prohibited.

© 2025 Bytedance Ltd. and/or its affiliates.


## /README.md



# Trace Anything: Representing Any Video in 4D via Trajectory Fields
<p align="center">
  <a href="https://trace-anything.github.io/">
    <img src="https://img.shields.io/badge/Project%20Page-222222?logo=googlechrome&logoColor=white" alt="Project Page">
  </a>
  <a href="https://arxiv.org/abs/2510.13802">
    <img src="https://img.shields.io/badge/arXiv-b31b1b?logo=arxiv&logoColor=white" alt="arXiv">
  </a>
  <a href="https://youtu.be/J6y5l2E6qjA">
    <img src="https://img.shields.io/badge/YouTube-ea3323?logo=youtube&logoColor=white" alt="YouTube Video">
  </a>
  <a href="https://trace-anything.github.io/viser-client/interactive.html">
    <img src="https://img.shields.io/badge/🖐️ Interactive%20Results-2b7a78?logoColor=white" alt="Interactive Results">
  </a>
  <a href="https://huggingface.co/depth-anything/trace-anything">
    <img src="https://img.shields.io/badge/Model-f4b400?logo=huggingface&logoColor=black" alt="Hugging Face Model">
  </a>
</p>



<div align="center" class="is-size-5 publication-authors">
  <span class="author-block">
    <a href="https://xinhangliu.com/">Xinhang Liu</a><sup>1,2</sup>&nbsp;&nbsp;&nbsp;&nbsp;
  </span>
  <span class="author-block">
    <a href="https://henry123-boy.github.io/">Yuxi Xiao</a><sup>1,3</sup>&nbsp;&nbsp;&nbsp;&nbsp;
  </span>
  <span class="author-block">
    <a href="https://donydchen.github.io/">Donny Y. Chen</a><sup>1</sup>&nbsp;&nbsp;&nbsp;&nbsp;
  </span>
  <span class="author-block">
    <a href="https://scholar.google.com.sg/citations?user=Q8iay0gAAAAJ&hl=en">Jiashi Feng</a><sup>1</sup>&nbsp;&nbsp;&nbsp;&nbsp;
  </span>
  <br>
  <span class="author-block">
    <a href="https://yuwingtai.github.io/">Yu-Wing Tai</a><sup>4</sup>&nbsp;&nbsp;&nbsp;&nbsp;
  </span>
  <span class="author-block">
    <a href="https://cse.hkust.edu.hk/~cktang/bio.html">Chi-Keung Tang</a><sup>2</sup>&nbsp;&nbsp;&nbsp;&nbsp;
  </span>
  <span class="author-block">
    <a href="https://bingykang.github.io/">Bingyi Kang</a><sup>1</sup>&nbsp;&nbsp;&nbsp;&nbsp;
  </span>
</div>

<br>

<div align="center" class="is-size-5 publication-authors">
  <span class="author-block"><sup>1</sup>Bytedance Seed</span>&nbsp;&nbsp;&nbsp;&nbsp;
  <span class="author-block"><sup>2</sup>HKUST</span>&nbsp;&nbsp;&nbsp;&nbsp;
  <span class="author-block"><sup>3</sup>Zhejiang University</span>&nbsp;&nbsp;&nbsp;&nbsp;
  <span class="author-block"><sup>4</sup>Dartmouth College</span>
</div>


## Overview
We propose a 4D video representation, the __trajectory field__, which maps each pixel in each frame to a continuous, parametric 3D trajectory. With a single forward pass, the __Trace Anything__ model efficiently estimates such trajectory fields for any video, image pair, or unstructured image set.
This repository provides the official PyTorch implementation for running inference with the Trace Anything model and exploring trajectory fields in an interactive 3D viewer.

  ![Teaser](assets/teaser.png)




##  Setup

### Create and activate environment

```bash
# Clone the repository
git clone https://github.com/ByteDance-Seed/TraceAnything.git
cd TraceAnything

# Create and activate environment
conda create -n trace_anything python=3.10
conda activate trace_anything
```

### Requirements

* **Python** ≥ 3.10
* **PyTorch** (install according to your CUDA/CPU setup)
* **Dependencies**:

  ```bash
  pip install einops omegaconf pillow opencv-python viser imageio matplotlib torchvision
  ```

**Notes**

- **CUDA:** Tested with **CUDA 12.8**.  
- **GPU Memory:** The provided examples are tested to run on a **single GPU with ≥ 48 GB VRAM**.


### Model weights

Download the pretrained **[model](https://huggingface.co/depth-anything/trace-anything/resolve/main/trace_anything.pt?download=true)** and place it at:

```text
checkpoints/trace_anything.pt
```
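
Alternatively, you can fetch the checkpoint programmatically. A minimal sketch using `huggingface_hub` (not in the dependency list above; the repo id and filename come from the download link):

```python
import os
import shutil

from huggingface_hub import hf_hub_download

# Download trace_anything.pt from the Hugging Face repo and place it where the scripts expect it.
path = hf_hub_download(repo_id="depth-anything/trace-anything", filename="trace_anything.pt")
os.makedirs("checkpoints", exist_ok=True)
shutil.copy(path, "checkpoints/trace_anything.pt")
```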



##  Inference

We provide example input videos and image pairs under `examples/input`.
Each subdirectory corresponds to a scene:

```
examples/
  input/
    scene_name_1/
      ...
    scene_name_2/
      ...
```

The inference script loads images from these scene folders and produces outputs.

---

### Notes

* Images must satisfy `W ≥ H`. (Portrait images are automatically transposed.)
* Images are resized so that the long side = **512**, then cropped down so height and width are multiples of 16 (a model requirement); see the sketch after these notes.
* If the number of views exceeds 40, the script automatically downsamples.
* (Advanced) The script assumes input images are ordered in time (e.g., video frames or paired images). Support for unstructured, unordered inputs will be released in the future.
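
For reference, the preprocessing described in these notes boils down to roughly the following (a minimal sketch mirroring the image loading in `scripts/infer.py`):

```python
from PIL import Image

def preprocess(pil: Image.Image, long_side: int = 512) -> Image.Image:
    """Rotate portrait to landscape, resize so the long side is 512, crop H/W to multiples of 16."""
    w, h = pil.size
    if h > w:  # portrait -> landscape
        pil = pil.transpose(Image.Transpose.ROTATE_90)
        w, h = pil.size
    pil = pil.resize((long_side, int(h * long_side / w)), Image.BILINEAR)  # W >= H here
    w, h = pil.size
    return pil.crop((0, 0, w - w % 16, h - h % 16))
```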

---

### Running inference

Run the model over all scenes:

```bash
python scripts/infer.py
```

#### Default arguments

You can override these paths with flags:

* `--config     configs/eval.yaml`
* `--ckpt       checkpoints/trace_anything.pt`
* `--input_dir  examples/input`
* `--output_dir examples/output`

#### Example

```bash
python scripts/infer.py \
  --input_dir examples/input \
  --output_dir examples/output \
  --ckpt checkpoints/trace_anything.pt
```

Results are saved to:

```text
<output_dir>/<scene>/output.pt
```

---

### What’s inside `output.pt`?

* `preds[i]['ctrl_pts3d']` — 3D control points, shape `[K, H, W, 3]`
* `preds[i]['ctrl_conf']` — confidence maps, shape `[K, H, W]`
* `preds[i]['fg_mask']` — binary mask `[H, W]`, computed via Otsu thresholding on control-point variance.
  (Mask images are also saved under `<output_dir>/<scene>/masks`.)
* `preds[i]['time']` — predicted scalar time ∈ `[0, 1)`.

  > Even though the true timestamp is implicit in the known sequence order, the network's timestamp head still estimates it.
* `views[i]['img']` — normalized input image tensor ∈ `[-1, 1]`
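
As a quick sanity check, you can load a saved `output.pt` and print the shapes described above (a minimal sketch; the `monkeys` path is just an example):

```python
import torch

out = torch.load("examples/output/monkeys/output.pt", map_location="cpu")
preds, views = out["preds"], out["views"]

for i, (pred, view) in enumerate(zip(preds, views)):
    print(
        f"frame {i:03d}",
        "ctrl_pts3d", tuple(pred["ctrl_pts3d"].shape),  # [K, H, W, 3]
        "ctrl_conf", tuple(pred["ctrl_conf"].shape),    # [K, H, W]
        "fg_mask", tuple(pred["fg_mask"].shape),        # [H, W]
        "time", float(pred["time"]),                    # scalar in [0, 1)
        "img", tuple(view["img"].shape),                # [1, 3, H, W] in [-1, 1]
    )
```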



## Optional: User-Guided Masks with SAM2

For visualization, if you prefer **user-guided SAM2 masks** over the automatic masks computed from Trace Anything outputs, we provide a helper script [`scripts/user_mask.py`](scripts/user_mask.py). It lets you interactively select points on the first frame of a scene to produce per-frame foreground masks.

Install [SAM2](https://github.com/facebookresearch/sam2) and download its checkpoint. Then run with:

```bash
python scripts/user_mask.py --scene <output_scene_dir> \
  --sam2_cfg configs/sam2.1/sam2.1_hiera_l.yaml \
  --sam2_ckpt <path_to_sam2_ckpt>
```

This saves masks to:

  ```
  <scene>/masks/{i:03d}_user.png
  ```
It also updates `<scene>/output.pt` with:

  ```python
  preds[i]["fg_mask_user"]
  ```

When visualizing, `fg_mask_user` will automatically be preferred over `fg_mask` if available.
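
In code, that preference amounts to something like the following (a sketch; the actual selection logic lives in `scripts/view.py`):

```python
# Prefer the user-guided SAM2 mask when present; otherwise fall back to the automatic one.
fg_mask = pred.get("fg_mask_user", pred.get("fg_mask"))
```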




##  Interactive Visualization 🚀
Our visualizer lets you explore the trajectory field interactively:

![Interactive trajectory field demo](./assets/interactive_monkeys.gif)


Fire up the interactive 3D viewer and dive into your trajectory fields:

```bash
python scripts/view.py --output examples/output/<scene>/output.pt
```

### Useful flags

* `--port 8020` — set viewer port
* `--t_step 0.025` — timeline step (smaller = more fine-grained curve evaluation)
* `--ds 2` — downsample all data by `::2` for extra speed


### Remote use (SSH port-forwarding)

  ```bash
  ssh -N -L 8020:localhost:8020 <user>@<server>
  # Then open http://localhost:8020 locally
  ```

### Trajectory panel
  Input a frame number, or simply type `"mid"` / `"last"`.
  Then hit **Build / Refresh** to construct trajectories, and toggle **Show trajectories** to view them.

  ![Trajectories Panel](assets/panel_screenshot.png)

### Play around! 🎉

  * Pump up or shrink point size
  * Filter out noisy background / foreground points by confidence
  * Drag to swivel the viewpoint
  * Slide through time and watch the trajectories evolve

## Acknowledgements
We sincerely thank the authors of the open-source repositories [DUSt3R](https://github.com/naver/dust3r), [Fast3R](https://github.com/facebookresearch/fast3r), [VGGT](https://github.com/facebookresearch/vggt), [MonST3R](https://github.com/Junyi42/monst3r), [Easi3R](https://github.com/Inception3D/Easi3R), [St4RTrack](https://github.com/HavenFeng/St4RTrack?tab=readme-ov-file), [POMATO](https://github.com/wyddmw/POMATO?tab=readme-ov-file), [SpaTrackerV2](https://github.com/henry123-boy/SpaTrackerV2) and [Viser](https://github.com/nerfstudio-project/viser) for their inspiring and high-quality work that greatly contributed to this project.


## License

- **Code**: Licensed under the [Apache 2.0 License](http://www.apache.org/licenses/LICENSE-2.0).  
- **Model weights**: Licensed under the [CC BY-NC 4.0 License](https://creativecommons.org/licenses/by-nc/4.0/). These weights are provided for research and non-commercial use only.

## Citation

If you find our repository useful, please consider giving it a star ⭐ and citing our paper in your work:

```bibtex
@misc{liu2025traceanythingrepresentingvideo,
      title={Trace Anything: Representing Any Video in 4D via Trajectory Fields}, 
      author={Xinhang Liu and Yuxi Xiao and Donny Y. Chen and Jiashi Feng and Yu-Wing Tai and Chi-Keung Tang and Bingyi Kang},
      year={2025},
      eprint={2510.13802},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2510.13802}, 
}
```


## /assets/interactive_monkeys.gif

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/assets/interactive_monkeys.gif

## /assets/panel_screenshot.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/assets/panel_screenshot.png

## /assets/teaser.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/assets/teaser.png

## /configs/eval.yaml

```yaml path="/configs/eval.yaml" 
# trace_anything/configs/eval.yaml

hydra:
  run:
    dir: .        
  job:
    chdir: false
  output_subdir: null
  job_logging:
    disable_existing_loggers: true
    root:
      handlers: []
    handlers: {}
  verbose: false

ckpt_path: checkpoints/trace_anything.ckpt

model:
  _target_: fast3r.models.multiview_dust3r_module.MultiViewDUSt3RLitModule

  net:
    _target_: fast3r.models.fast3r.Fast3R
    encoder_args:
      encoder_type: croco
      img_size: 512
      patch_size: 16
      patch_embed_cls: ManyAR_PatchEmbed
      embed_dim: 1024
      num_heads: 16
      depth: 24
      mlp_ratio: 4
      pos_embed: RoPE100
      attn_implementation: flash_attention
    decoder_args:
      decoder_type: fast3r
      random_image_idx_embedding: false
      enc_embed_dim: 1024
      embed_dim: 1024
      num_heads: 16
      depth: 24
      mlp_ratio: 4.0
      qkv_bias: true
      drop: 0.0
      attn_drop: 0.0
      attn_implementation: flash_attention
    head_args:
      head_type: dpt
      output_mode: pts3d
      landscape_only: true
      depth_mode: [exp, -.inf, .inf]
      conf_mode: [exp, 1, .inf]
      patch_size: 16
      with_local_head: true
    freeze: none
    targeting_mechanism: bspline_conf   
    poly_degree: 10

  train_criterion:
    _target_: fast3r.dust3r.losses.ConfLossMultiviewV2
    pixel_loss:
      _target_: fast3r.dust3r.losses.Regr3DMultiviewV3
      criterion:
        _target_: fast3r.dust3r.losses.L21Loss
      norm_mode: avg_dis
    alpha: 0.2

  validation_criterion:
    _target_: fast3r.dust3r.losses.ConfLossMultiviewV2
    pixel_loss:
      _target_: fast3r.dust3r.losses.Regr3DMultiviewV3
      criterion:
        _target_: fast3r.dust3r.losses.L21Loss
      norm_mode: avg_dis
    alpha: 0.2

  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 1.0e-4
    betas: [0.9, 0.95]
    weight_decay: 0.05

  scheduler:
    _target_: pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR
    _partial_: true
    warmup_epochs: 1
    max_epochs: 10
    eta_min: 1.0e-6

  compile: false

```

## /examples/input/cabinet/000.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/000.png

## /examples/input/cabinet/001.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/001.png

## /examples/input/cabinet/002.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/002.png

## /examples/input/cabinet/003.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/003.png

## /examples/input/cabinet/004.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/004.png

## /examples/input/cabinet/005.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/005.png

## /examples/input/cabinet/006.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/006.png

## /examples/input/cabinet/007.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/007.png

## /examples/input/cabinet/008.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/008.png

## /examples/input/cabinet/009.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/009.png

## /examples/input/cabinet/010.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/010.png

## /examples/input/corgi/000.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/corgi/000.png

## /examples/input/corgi/001.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/corgi/001.png

## /examples/input/dog/000.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/dog/000.png

## /examples/input/dog/001.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/dog/001.png

## /examples/input/elephant/000.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/000.png

## /examples/input/elephant/001.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/001.png

## /examples/input/elephant/002.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/002.png

## /examples/input/elephant/003.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/003.png

## /examples/input/elephant/004.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/004.png

## /examples/input/elephant/005.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/005.png

## /examples/input/elephant/006.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/006.png

## /examples/input/elephant/007.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/007.png

## /examples/input/elephant/008.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/008.png

## /examples/input/elephant/009.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/009.png

## /examples/input/elephant/010.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/010.png

## /examples/input/elephant/011.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/011.png

## /examples/input/elephant/012.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/012.png

## /examples/input/elephant/013.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/013.png

## /examples/input/elephant/014.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/014.png

## /examples/input/elephant/015.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/015.png

## /examples/input/elephant/016.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/016.png

## /examples/input/elephant/017.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/017.png

## /examples/input/elephant/018.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/018.png

## /examples/input/elephant/019.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/019.png

## /examples/input/elephant/020.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/020.png

## /examples/input/elephant/021.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/021.png

## /examples/input/elephant/022.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/022.png

## /examples/input/elephant/023.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/023.png

## /examples/input/elephant/024.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/024.png

## /examples/input/elephant/025.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/025.png

## /examples/input/elephant/026.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/026.png

## /examples/input/elephant/027.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/027.png

## /examples/input/elephant/028.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/028.png

## /examples/input/elephant/029.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/029.png

## /examples/input/monkeys/000.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/000.png

## /examples/input/monkeys/001.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/001.png

## /examples/input/monkeys/002.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/002.png

## /examples/input/monkeys/003.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/003.png

## /examples/input/monkeys/004.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/004.png

## /examples/input/monkeys/005.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/005.png

## /examples/input/monkeys/006.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/006.png

## /examples/input/monkeys/007.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/007.png

## /examples/input/monkeys/008.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/008.png

## /examples/input/monkeys/009.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/009.png

## /examples/input/monkeys/010.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/010.png

## /examples/input/monkeys/011.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/011.png

## /examples/input/monkeys/012.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/012.png

## /examples/input/monkeys/013.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/013.png

## /examples/input/monkeys/014.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/014.png

## /examples/input/monkeys/015.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/015.png

## /examples/input/monkeys/016.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/016.png

## /examples/input/monkeys/017.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/017.png

## /examples/input/monkeys/018.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/018.png

## /examples/input/monkeys/019.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/019.png

## /examples/input/monkeys/020.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/020.png

## /examples/input/monkeys/021.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/021.png

## /examples/input/monkeys/022.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/022.png

## /examples/input/monkeys/023.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/023.png

## /examples/input/monkeys/024.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/024.png

## /examples/input/monkeys/025.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/025.png

## /examples/input/monkeys/026.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/026.png

## /examples/input/monkeys/027.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/027.png

## /examples/input/monkeys/028.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/028.png

## /examples/input/monkeys/029.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/029.png

## /examples/input/monkeys/030.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/030.png

## /examples/input/parkour/000.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/parkour/000.png

## /examples/input/parkour/001.png

Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/parkour/001.png

## /scripts/infer.py

```py path="/scripts/infer.py" 
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# scripts/infer.py
"""
Run inference on all scenes and save:
  - <scene>/output.pt with {'preds','views'}
  - <scene>/masks/{i:03d}.png   (binary FG masks)
  - <scene>/images/{i:03d}.png  (RGB frames used for inference)

Masks are computed from ctrl-pt variance + smart Otsu.
"""

import sys, pathlib
sys.path.append(str(pathlib.Path(__file__).resolve().parents[1]))

import os
import cv2
import time
import argparse
from typing import List, Dict

import numpy as np
import torch
from PIL import Image
import torchvision.transforms as tvf
from omegaconf import OmegaConf

from trace_anything.trace_anything import TraceAnything


def _pretty(msg: str) -> None:
    print(f"[{time.strftime('%H:%M:%S')}] {msg}")


# allow ${python_eval: ...} in YAML if used
OmegaConf.register_new_resolver("python_eval", lambda code: eval(code))


# ---------------- image I/O ----------------
def _resize_long_side(pil: Image.Image, long: int = 512) -> Image.Image:
    w, h = pil.size
    if w >= h:
        return pil.resize((long, int(h * long / w)), Image.BILINEAR)
    else:
        return pil.resize((int(w * long / h), long), Image.BILINEAR)


def _load_images(input_dir: str, device: torch.device) -> List[Dict]:
    """Read images, rotate portrait->landscape, resize(long=512), crop to 16-multiple, normalize [-1,1]."""
    tfm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5,)*3, (0.5,)*3)])

    fnames = sorted(
        f for f in os.listdir(input_dir)
        if f.lower().endswith((".png", ".jpg", ".jpeg")) and "_vis" not in f
    )
    if not fnames:
        raise FileNotFoundError(f"No images in {input_dir}")

    views, target = [], None
    for i, f in enumerate(fnames):
        arr = cv2.imread(os.path.join(input_dir, f))
        if arr is None:
            raise FileNotFoundError(f"Failed to read {f}")
        pil = Image.fromarray(cv2.cvtColor(arr, cv2.COLOR_BGR2RGB))

        W0, H0 = pil.size
        if H0 > W0:  # portrait -> landscape
            pil = pil.transpose(Image.Transpose.ROTATE_90)

        pil = _resize_long_side(pil, 512)
        if target is None:
            H, W = pil.size[1], pil.size[0]
            target = (H - H % 16, W - W % 16)
            _pretty(f"📐 target size: {target[0]}x{target[1]} (16-multiple)")
        Ht, Wt = target
        pil = pil.crop((0, 0, Wt, Ht))

        tensor = tfm(pil).unsqueeze(0).to(device)  # [1,3,H,W]
        t = i / (len(fnames) - 1) if len(fnames) > 1 else 0.0
        views.append({"img": tensor, "time_step": t})
    return views


# ---------------- ckpt + model ----------------
def _get_state_dict(ckpt: dict) -> dict:
    """Accept either a pure state_dict or a Lightning .ckpt."""
    if isinstance(ckpt, dict) and "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
        return ckpt["state_dict"]
    return ckpt


def _load_cfg(cfg_path: str):
    if not os.path.isfile(cfg_path):
        raise FileNotFoundError(cfg_path)
    return OmegaConf.load(cfg_path)


def _to_dict(x):
    # OmegaConf -> plain dict
    return OmegaConf.to_container(x, resolve=True) if not isinstance(x, dict) else x


def _build_model_from_cfg(cfg, ckpt_path: str, device: torch.device) -> torch.nn.Module:
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(ckpt_path)

    # net config
    net_cfg = cfg.get("model", {}).get("net", None) or cfg.get("net", None)
    if net_cfg is None:
        raise KeyError("expect cfg.model.net or cfg.net in YAML")

    model = TraceAnything(
        encoder_args=_to_dict(net_cfg["encoder_args"]),
        decoder_args=_to_dict(net_cfg["decoder_args"]),
        head_args=_to_dict(net_cfg["head_args"]),
        targeting_mechanism=net_cfg.get("targeting_mechanism", "bspline_conf"),
        poly_degree=net_cfg.get("poly_degree", 10),
        whether_local=False,
    )

    ckpt = torch.load(ckpt_path, map_location="cpu")
    sd = _get_state_dict(ckpt)

    if all(k.startswith("net.") for k in sd.keys()):
        sd = {k[4:]: v for k, v in sd.items()}

    model.load_state_dict(sd, strict=False)
    model.to(device).eval()
    return model


# ---------------- smart var threshold ----------------
def _otsu_threshold_from_hist(hist: np.ndarray, bin_edges: np.ndarray) -> float | None:
    total = hist.sum()
    if total <= 0:
        return None
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0
    w1 = np.cumsum(hist)
    w2 = total - w1
    sum_total = (hist * bin_centers).sum()
    sumB = np.cumsum(hist * bin_centers)
    valid = (w1 > 0) & (w2 > 0)
    if not np.any(valid):
        return None
    m1 = sumB[valid] / w1[valid]
    m2 = (sum_total - sumB[valid]) / w2[valid]
    between = w1[valid] * w2[valid] * (m1 - m2) ** 2
    idx = np.argmax(between)
    return float(bin_centers[valid][idx])


def _smart_var_threshold(var_map_t: torch.Tensor) -> float:
    """
    1) log-transform variance
    2) Otsu on histogram
    3) fallback to 65–80% mid-quantile midpoint
    Returns threshold in original variance domain.
    """
    var_np = var_map_t.detach().float().cpu().numpy()
    v = np.log(var_np + 1e-9)
    hist, bin_edges = np.histogram(v, bins=256)
    thr_log = _otsu_threshold_from_hist(hist, bin_edges)
    if thr_log is None or not np.isfinite(thr_log):
        q65 = float(np.quantile(var_np, 0.65))
        q80 = float(np.quantile(var_np, 0.80))
        return 0.5 * (q65 + q80)
    thr_var = float(np.exp(thr_log))
    q40 = float(np.quantile(var_np, 0.40))
    q95 = float(np.quantile(var_np, 0.95))
    return max(q40, min(q95, thr_var))


# ---------------- main loop ----------------
def run(args):
    base_in = args.input_dir
    base_out = args.output_dir

    if not os.path.isdir(base_in):
        raise FileNotFoundError(base_in)
    os.makedirs(base_out, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # config & model
    cfg = _load_cfg(args.config)
    _pretty("🔧 loading model …")
    model = _build_model_from_cfg(cfg, ckpt_path=args.ckpt, device=device)
    _pretty("✅ model ready")

    # iterate scenes
    for scene in sorted(os.listdir(base_in)):
        in_dir = os.path.join(base_in, scene)
        if not os.path.isdir(in_dir):
            continue
        out_dir = os.path.join(base_out, scene)
        masks_dir = os.path.join(out_dir, "masks")
        images_dir = os.path.join(out_dir, "images")
        os.makedirs(out_dir, exist_ok=True)
        os.makedirs(masks_dir, exist_ok=True)
        os.makedirs(images_dir, exist_ok=True)

        _pretty(f"\n📂 Scene: {scene}")
        _pretty("🖼️  loading images …")
        views = _load_images(in_dir, device=device)
        if len(views) > 40:
            stride = max(1, len(views) // 39)   # floor division
            views = views[::stride]
        _pretty(f"🧮 {len(views)} views loaded")

        _pretty("🚀 inference …")
        t0 = time.perf_counter()
        with torch.no_grad():
            preds = model.forward(views)
        dt = time.perf_counter() - t0
        ms_per_view = (dt / max(1, len(views))) * 1000.0
        _pretty(f"✅ done | {dt:.2f}s total | {ms_per_view:.1f} ms/view")

        # ---- compute + save FG masks and images ----
        _pretty("🧪 computing FG masks + saving frames …")
        for i, pred in enumerate(preds):
            # variance map over control points (K), mean over xyz -> [H,W]
            ctrl_pts3d = pred["ctrl_pts3d"]
            ctrl_pts3d_t = torch.from_numpy(ctrl_pts3d) if isinstance(ctrl_pts3d, np.ndarray) else ctrl_pts3d
            var_map = torch.var(ctrl_pts3d_t, dim=0, unbiased=False).mean(-1)  # [H,W]
            thr = _smart_var_threshold(var_map)
            fg_mask = (~(var_map <= thr)).detach().cpu().numpy().astype(bool)

            # save mask as binary PNG and stash in preds
            cv2.imwrite(os.path.join(masks_dir, f"{i:03d}.png"), (fg_mask.astype(np.uint8) * 255))
            pred["fg_mask"] = torch.from_numpy(fg_mask)  # CPU bool tensor

            # also save the RGB image we actually ran on
            img = views[i]["img"].detach().cpu().squeeze(0)  # [3,H,W] in [-1,1]
            img_np = (img.permute(1, 2, 0).numpy() + 1.0) * 127.5
            img_uint8 = np.clip(img_np, 0, 255).astype(np.uint8)
            cv2.imwrite(os.path.join(images_dir, f"{i:03d}.png"), cv2.cvtColor(img_uint8, cv2.COLOR_RGB2BGR))

            # trim heavy intermediates just in case
            pred.pop("track_pts3d", None)
            pred.pop("track_conf", None)

        # persist
        save_path = os.path.join(out_dir, "output.pt")
        torch.save({"preds": preds, "views": views}, save_path)
        _pretty(f"💾 saved: {save_path}")
        _pretty(f"🖼️  masks → {masks_dir} | images → {images_dir}")

def parse_args():
    p = argparse.ArgumentParser("TraceAnything inference")
    p.add_argument("--config", type=str, default="configs/eval.yaml",
                   help="Path to YAML config")
    p.add_argument("--ckpt", type=str, default="checkpoints/trace_anything.pt",
                   help="Path to the checkpoint")
    p.add_argument("--input_dir", type=str, default="./examples/input",
                   help="Directory containing scenes (each subfolder is a scene)")
    p.add_argument("--output_dir", type=str, default="./examples/output",
                   help="Directory to write scene outputs")
    return p.parse_args()

def main():
    args = parse_args()
    run(args)

if __name__ == "__main__":
    main()

```

## /scripts/user_mask.py

```py path="/scripts/user_mask.py" 
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/usr/bin/env python3
# scripts/user_mask.py
"""
Interactive points on frame 0 -> SAM2 video propagation -> user masks.

Saves per-frame masks to <scene>/masks/{i:03d}_user.png
and stores preds[i]["fg_mask_user"] in <scene>/output.pt.
"""

import sys, pathlib; sys.path.append(str(pathlib.Path(__file__).resolve().parents[1]))
import os, argparse, shutil, tempfile, inspect
from typing import Dict, List, Optional
import numpy as np, torch, cv2

def dilate_mask(mask: np.ndarray, ksize: int = 3, iterations: int = 2) -> np.ndarray:
    kernel = np.ones((ksize, ksize), np.uint8)
    return cv2.dilate(mask.astype(np.uint8), kernel, iterations=iterations).astype(bool)


def _pretty(msg:str):
    import time; print(f"[{time.strftime('%H:%M:%S')}] {msg}")

def load_output(scene: str) -> Dict:
    p = os.path.join(scene, "output.pt")
    if not os.path.isfile(p):
        raise FileNotFoundError(p)
    out = torch.load(p, map_location="cpu")
    out["_root_dir"] = scene
    return out

def tensor_to_rgb(img3hw:torch.Tensor)->np.ndarray:
    arr=(img3hw.detach().cpu().permute(1,2,0).numpy()+1.0)*127.5
    return np.clip(arr,0,255).astype(np.uint8)

def build_tmp_jpegs(frames_rgb: List[np.ndarray], scene_dir: str) -> str:
    tmp = os.path.join(scene_dir, "tmp_jpg")
    os.makedirs(tmp, exist_ok=True)
    for i, fr in enumerate(frames_rgb):
        fn = os.path.join(tmp, f"{i:06d}.jpg")
        cv2.imwrite(fn, cv2.cvtColor(fr, cv2.COLOR_RGB2BGR), [int(cv2.IMWRITE_JPEG_QUALITY), 95])
    return tmp

def ask_points(frame_bgr:np.ndarray, preview_dir:str)->np.ndarray:
    H,W=frame_bgr.shape[:2]
    while True:
        print("Enter points as: x y   (ENTER empty line to finish)")
        pts=[]
        while True:
            s=input(">>> ").strip()
            if s=="": break
            try:
                x,y=map(float,s.split()); pts.append([x,y,1.0])
            except: print("  ex: 320 180")
        if not pts:
            a=input("No points. [q]=quit, [r]=retry > ").strip().lower()
            if a.startswith("q"): raise SystemExit(0)
            else: continue
        pts=np.array(pts,dtype=np.float32)
        pts[:,0]=np.clip(pts[:,0],0,W-1); pts[:,1]=np.clip(pts[:,1],0,H-1)
        vis=frame_bgr.copy()
        for x,y,_ in pts:
            p=(int(round(x)),int(round(y)))
            cv2.circle(vis,p,4,(0,255,0),-1,cv2.LINE_AA); cv2.circle(vis,p,8,(0,255,0),1,cv2.LINE_AA)
        prev_path = os.path.join(preview_dir, "prompts_preview.png")
        cv2.imwrite(prev_path, vis)
        print(f"Preview: {os.path.abspath(prev_path)}")
        if input("Accept? [y]/n/q: ").strip().lower().startswith("y"): return pts
        if input("Retry or quit? [r]/q: ").strip().lower().startswith("q"): raise SystemExit(0)

from sam2.build_sam import build_sam2_video_predictor

def run_propagation(model_cfg:str, ckpt_path:str, jpg_dir:str, points_xy1:np.ndarray,
                    ann_frame_idx:int=0, ann_obj_id:int=1)->List[np.ndarray]:
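    # Build a SAM2 video predictor, seed it with the user's positive points on frame
    # `ann_frame_idx`, and propagate the mask for `ann_obj_id` through every JPEG frame
    # in `jpg_dir`. Frames the propagation skips are filled via get_frame_masks (when
    # available) or left empty; each returned mask is a slightly dilated boolean array.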
    if not os.path.isfile(ckpt_path): raise FileNotFoundError(ckpt_path)
    predictor=build_sam2_video_predictor(model_cfg, ckpt_path)
    coords=points_xy1[:, :2][None,...].astype(np.float32); labels=points_xy1[:, 2][None,...].astype(np.int32)
    use_cuda=torch.cuda.is_available()
    autocast=torch.autocast("cuda",dtype=torch.bfloat16) if use_cuda else torch.cuda.amp.autocast(enabled=False)
    with torch.inference_mode(), autocast:
        state=predictor.init_state(jpg_dir)
        predictor.add_new_points_or_box(inference_state=state, frame_idx=ann_frame_idx,
                                        obj_id=ann_obj_id, points=coords, labels=labels)
        n=len([f for f in os.listdir(jpg_dir) if f.lower().endswith(".jpg")])
        masks:List[Optional[np.ndarray]]=[None]*n
        sig=inspect.signature(predictor.propagate_in_video)
        kwargs={"inference_state":state}
        if "start_frame_idx" in sig.parameters and "end_frame_idx" in sig.parameters:
            kwargs.update(start_frame_idx=ann_frame_idx,end_frame_idx=n-1)
        for yielded in predictor.propagate_in_video(**kwargs):
            if isinstance(yielded,tuple) and len(yielded)==3:
                fi,obj_ids,ms=yielded
            else:
                fi=int(yielded["frame_idx"]); obj_ids=yielded.get("object_ids") or yielded.get("obj_ids"); ms=yielded["masks"]
            pick=None
            for oid,m in zip(obj_ids,ms):
                if int(oid)==ann_obj_id: pick=m; break
            if pick is None:
                pick = masks[0]
            if isinstance(pick,torch.Tensor): pick=pick.detach().cpu().numpy()
            pick=np.asarray(pick); 
            if pick.ndim==3 and pick.shape[0]==1: pick=pick[0]
            masks[int(fi)]=(pick>0.5)
        if hasattr(predictor,"get_frame_masks"):
            for i in range(n):
                if masks[i] is None:
                    oids,ms=predictor.get_frame_masks(state,i)
                    pick=None
                    for oid,m in zip(oids,ms):
                        if int(oid)==ann_obj_id: pick=m; break
                    pick = pick if pick is not None else (ms[0] if ms else None)
                    if isinstance(pick,torch.Tensor): pick=pick.detach().cpu().numpy()
                    if pick is None: continue
                    pick=np.asarray(pick)
                    if pick.ndim==3 and pick.shape[0]==1: pick=pick[0]
                    masks[i]=(pick>0.5)
    H=W=None
    for m in masks:
        if m is not None: H,W=m.shape; break
    if H is None: raise RuntimeError("SAM2 produced no masks.")
    return [dilate_mask(m) if m is not None else np.zeros((H, W), bool) for m in masks]


def save_user_masks(scene:str, masks:List[np.ndarray], preds:List[Dict], views:List[Dict]):
    mdir=os.path.join(scene,"masks"); os.makedirs(mdir,exist_ok=True)
    for i,m in enumerate(masks):
        cv2.imwrite(os.path.join(mdir,f"{i:03d}_user.png"), (m.astype(np.uint8)*255))
        preds[i]["fg_mask_user"]=torch.from_numpy(m.astype(bool))
    torch.save({"preds":preds,"views":views}, os.path.join(scene,"output.pt"))

def parse_args():
    p=argparse.ArgumentParser("User mask via SAM2 video propagation")
    p.add_argument("--scene", type=str, default="./examples/output/breakdance",
                   help="Scene dir containing output.pt")
    p.add_argument("--sam2_cfg", type=str, default="configs/sam2.1/sam2.1_hiera_l.yaml",
                   help="SAM2 config (string or path; passed through)")
    p.add_argument("--sam2_ckpt", type=str, default="../sam2/checkpoints/sam2.1_hiera_large.pt",
                   help="Path to SAM2 checkpoint .pt")
    return p.parse_args()

def main():
    args=parse_args()
    out=load_output(args.scene); preds,views=out["preds"],out["views"]
    if not preds: return _pretty("No frames in output.pt")
    frames=[tensor_to_rgb(views[i]["img"].squeeze(0)) for i in range(len(views))]
    frame0_bgr=cv2.cvtColor(frames[0], cv2.COLOR_RGB2BGR)
    tmp_dir = build_tmp_jpegs(frames, args.scene)
    pts_xy1 = ask_points(frame0_bgr, tmp_dir)
    try:
        _pretty("Propagating with SAM2 …")
        masks = run_propagation(
            model_cfg=args.sam2_cfg,
            ckpt_path=args.sam2_ckpt,
            jpg_dir=tmp_dir,
            points_xy1=pts_xy1,
            ann_frame_idx=0,
            ann_obj_id=1,
        )
        save_user_masks(args.scene, masks, preds, views)
        _pretty(f"✅ Saved user masks to {os.path.join(args.scene, 'masks')} and updated output.pt")
        _pretty(f"🗂️  Kept JPEG frames and preview in: {tmp_dir}")
    finally:
        # shutil.rmtree(tmp_dir, ignore_errors=True)
        _pretty("🧹 Not removing temp JPEG folder (preview included)")

if __name__=="__main__": main()

```
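
A minimal, non-interactive sketch of driving the helpers above from Python instead of the CLI. It assumes the functions defined earlier in this script (`load_output`, `tensor_to_rgb`, `build_tmp_jpegs`, `run_propagation`, `save_user_masks`) are in scope; the scene path, SAM2 config/checkpoint paths, and the click coordinate are placeholders to adapt.

```py
import numpy as np

scene = "./examples/output/monkeys"          # placeholder scene dir containing output.pt
out = load_output(scene)
preds, views = out["preds"], out["views"]

frames = [tensor_to_rgb(views[i]["img"].squeeze(0)) for i in range(len(views))]
jpg_dir = build_tmp_jpegs(frames, scene)

# one positive click on frame 0; columns are (x, y, label), label 1 = foreground
points_xy1 = np.array([[320.0, 240.0, 1.0]], dtype=np.float32)

masks = run_propagation(
    model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",         # placeholder
    ckpt_path="../sam2/checkpoints/sam2.1_hiera_large.pt",  # placeholder
    jpg_dir=jpg_dir,
    points_xy1=points_xy1,
)
save_user_masks(scene, masks, preds, views)
```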

## /scripts/view.py

```py path="/scripts/view.py" 
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# scripts/view.py
# TraceAnything viewer (viser) – no mask recomputation.
# - Loads FG masks from (priority):
#     1) <scene>/masks/{i:03d}_user.png  or preds[i]["fg_mask_user"]
#     2) <scene>/masks/{i:03d}.png       or preds[i]["fg_mask"]
# - BG mask = ~FG
# - Saves per-frame images next to output.pt (does NOT recompute masks)
# - Initial conf filtering: drop bottom 10% (FG/BG)
# - Downsample everything with --ds (H,W stride)

import sys, pathlib
sys.path.append(str(pathlib.Path(__file__).resolve().parents[1]))

import os
import time
import threading
import argparse
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import cv2
import viser

# ---- tiny helpers ----
def to_numpy(x):
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)

def as_float(x):
    a = np.asarray(x)
    return float(a.reshape(-1)[0])

def ensure_dir(p):
    os.makedirs(p, exist_ok=True)
    return p

# ---- tiny HSV colormap (no heavy deps) ----
import colorsys
def hsv_colormap(vals01: np.ndarray) -> np.ndarray:
    """vals01: [T] in [0,1] -> [T,3] RGB in [0,1]"""
    vals01 = np.clip(vals01, 0.0, 1.0)
    rgb = [colorsys.hsv_to_rgb(v, 1.0, 1.0) for v in vals01]
    return np.asarray(rgb, dtype=np.float32)

# --- repo fn ---
from trace_anything.trace_anything import evaluate_bspline_conf

# ----------------------- I/O -----------------------
def load_output_dict(path_or_dir: str) -> Dict:
    path = path_or_dir
    if os.path.isdir(path):
        path = os.path.join(path, "output.pt")
    if not os.path.isfile(path):
        raise FileNotFoundError(path)
    out = torch.load(path, map_location="cpu")
    out["_root_dir"] = os.path.dirname(path)
    return out

def save_images(frames: List[Dict], images_dir: str):
    print(f"[viewer] Saving images to: {images_dir}")
    ensure_dir(images_dir)
    for i, fr in enumerate(frames):
        cv2.imwrite(os.path.join(images_dir, f"{i:03d}.png"),
                    cv2.cvtColor(fr["img_rgb_uint8"], cv2.COLOR_RGB2BGR))

# ----------------- mask loading (NO recomputation) -----------------
_REPORTED_MASK_SRC: set[tuple[str, int]] = set()

def _load_fg_mask_for_index(root_dir: str, idx: int, pred: Dict) -> np.ndarray:
    """
    Priority:
      1) masks/{i:03d}_user.png  or pred['fg_mask_user']  (user mask)
      2) masks/{i:03d}.png       or pred['fg_mask']       (raw mask)
    Returns: mask_bool_hw
    """
    def _ret(mask_bool: np.ndarray, src: str) -> np.ndarray:
        key = (root_dir, idx)
        if key not in _REPORTED_MASK_SRC and os.environ.get("TA_SILENCE_MASK_SRC") != "1":
            print(f"[viewer] frame {idx:03d}: using {src}")
            _REPORTED_MASK_SRC.add(key)
        return mask_bool

    # 1) USER
    #   PNG then preds (so external edits override stale preds if any)
    p_png = os.path.join(root_dir, "masks", f"{idx:03d}_user.png")
    if os.path.isfile(p_png):
        arr = cv2.imread(p_png, cv2.IMREAD_GRAYSCALE)
        if arr is not None:
            # p_png_1 = os.path.join(root_dir, "masks", f"{idx:03d}_user_1.png")
            # if os.path.isfile(p_png_1):
            #     arr1 = cv2.imread(p_png_1, cv2.IMREAD_GRAYSCALE)
            #     if arr1 is not None and arr1.shape == arr.shape:
            #         arr = np.maximum(arr, arr1)
            return _ret((arr > 0), "user mask (png)")
    if "fg_mask_user" in pred and pred["fg_mask_user"] is not None:
        m = pred["fg_mask_user"]
        if isinstance(m, torch.Tensor): m = m.detach().cpu().numpy()
        return _ret(np.asarray(m).astype(bool), "user mask (preds)")

    # 2) RAW
    p_png = os.path.join(root_dir, "masks", f"{idx:03d}.png")
    if os.path.isfile(p_png):
        arr = cv2.imread(p_png, cv2.IMREAD_GRAYSCALE)
        if arr is not None:
            return _ret((arr > 0), "raw mask (png)")
    if "fg_mask" in pred and pred["fg_mask"] is not None:
        m = pred["fg_mask"]
        if isinstance(m, torch.Tensor): m = m.detach().cpu().numpy()
        return _ret(np.asarray(m).astype(bool), "raw mask (preds)")

    # --- legacy compatibility ---
    for key, fname, label in [
        ("fg_mask_user", f"{idx:03d}_fg_user.png", "user mask (legacy png)"),
        ("fg_mask",      f"{idx:03d}_fg_refined.png", "refined mask (legacy png)"),
        ("fg_mask_raw",  f"{idx:03d}_fg_raw.png", "raw mask (legacy png)"),
    ]:
        if key in pred and pred[key] is not None:
            m = pred[key]
            if isinstance(m, torch.Tensor): m = m.detach().cpu().numpy()
            return _ret(np.asarray(m).astype(bool), f"{label} (preds)")
        path = os.path.join(root_dir, "masks", fname)
        if os.path.isfile(path):
            arr = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if arr is not None:
                return _ret((arr > 0), label)

    raise FileNotFoundError(
        f"No FG mask for frame {idx}: looked for user/raw (png or preds)."
    )

# -------------- precompute tensors for viewer -------------
def build_precomputes(
    output: Dict,
    t_step: float,
    ds: int,
) -> Tuple[np.ndarray, List[Dict], List[np.ndarray], np.ndarray, np.ndarray]:
    preds = output["preds"]
    views = output["views"]
    n = len(preds)
    assert n == len(views)

    root = output.get("_root_dir", os.getcwd())

    # timeline
    t_vals = np.arange(0.0, 1.0 + 1e-6, t_step, dtype=np.float32)
    if t_vals[-1] >= 1.0:
        t_vals[-1] = 0.99
    T = len(t_vals)
    t_tensor = torch.from_numpy(t_vals)

    frames: List[Dict] = []
    fg_conf_pool_per_t: List[List[np.ndarray]] = [[] for _ in range(T)]
    bg_conf_pool: List[np.ndarray] = []

    stride = slice(None, None, ds)  # ::ds

    for i in range(n):
        pred = preds[i]
        view = views[i]

        # image ([-1,1]) -> uint8 RGB
        img = to_numpy(view["img"].squeeze().permute(1, 2, 0))
        img_uint8 = np.clip((img + 1.0) * 127.5, 0, 255).astype(np.uint8)
        img_uint8 = img_uint8[::ds, ::ds]  # downsample for saving/vis
        H, W = img_uint8.shape[:2]
        HW = H * W
        img_flat = (img_uint8.astype(np.float32) / 255.0).reshape(HW, 3)

        # load FG mask (prefer user, then raw)
        fg_mask = _load_fg_mask_for_index(root, i, pred)

        # match resolution to downsampled view
        # if mask at full-res and we downsampled, stride it; else resize with nearest
        if fg_mask.shape == (H * ds, W * ds) and ds > 1:
            fg_mask = fg_mask[::ds, ::ds]
        elif fg_mask.shape != (H, W):
            fg_mask = cv2.resize(
                (fg_mask.astype(np.uint8) * 255),
                (W, H),
                interpolation=cv2.INTER_NEAREST
            ) > 0

        bg_mask = ~fg_mask
        bg_mask_flat = bg_mask.reshape(-1)
        fg_mask_flat = fg_mask.reshape(-1)

        # control points/conf (K,H,W,[3]) at downsampled stride
        ctrl_pts3d = pred["ctrl_pts3d"][:, stride, stride, :]    # [K,H,W,3]
        ctrl_conf  = pred["ctrl_conf"][:, stride, stride]         # [K,H,W]

        # evaluate curve over T timesteps
        pts3d_t, conf_t = evaluate_bspline_conf(ctrl_pts3d, ctrl_conf, t_tensor)  # [T,H,W,3], [T,H,W]
        pts3d_t = to_numpy(pts3d_t).reshape(T, HW, 3)
        conf_t  = to_numpy(conf_t).reshape(T, HW)

        # FG per t (keep per-t list for later filtering)
        pts_fg_per_t  = [pts3d_t[t][fg_mask_flat] for t in range(T)]
        conf_fg_per_t = [conf_t[t][fg_mask_flat]  for t in range(T)]
        for t in range(T):
            if pts_fg_per_t[t].size > 0:
                fg_conf_pool_per_t[t].append(conf_fg_per_t[t])

        # BG static
        bg_pts = pts3d_t.mean(axis=0)[bg_mask_flat]
        bg_conf_mean = conf_t.mean(axis=0)[bg_mask_flat]
        bg_conf_pool.append(bg_conf_mean)

        frames.append(dict(
            img_rgb_uint8=img_uint8,
            img_rgb_float=img_flat,
            H=H, W=W, HW=HW,
            bg_mask_flat=bg_mask_flat,
            fg_mask_flat=fg_mask_flat,
            pts_fg_per_t=pts_fg_per_t,
            conf_fg_per_t=conf_fg_per_t,
            bg_pts=bg_pts,
            bg_conf_mean=bg_conf_mean,
        ))

    # pools for percentiles
    fg_conf_all_t: List[np.ndarray] = []
    for t in range(T):
        if len(fg_conf_pool_per_t[t]) == 0:
            fg_conf_all_t.append(np.empty((0,), dtype=np.float32))
        else:
            fg_conf_all_t.append(np.concatenate(fg_conf_pool_per_t[t], axis=0).astype(np.float32))

    if len(bg_conf_pool):
        bg_conf_all_flat = np.concatenate(bg_conf_pool, axis=0).astype(np.float32)
    else:
        bg_conf_all_flat = np.empty((0,), dtype=np.float32)

    # frame times (fallback to views' time_step if missing)
    def _get_time(i):
        ti = preds[i].get("time", None)
        if ti is None:
            ti = views[i].get("time_step", float(i / max(1, n - 1)))
        return as_float(ti)
    times = np.array([_get_time(i) for i in range(n)], dtype=np.float64)

    return t_vals, frames, fg_conf_all_t, bg_conf_all_flat, times

def choose_nearest_frame_indices(frame_times: np.ndarray, t_vals: np.ndarray) -> np.ndarray:
    return np.array([int(np.argmin(np.abs(frame_times - tv))) for tv in t_vals], dtype=np.int64)

# ---------------- nodes & state ----------------
class ViewerState:
    def __init__(self):
        self.lock = threading.Lock()
        self.is_updating = False
        self.status_label = None
        self.slider_fg = None
        self.slider_bg = None
        self.slider_time = None
        self.point_size = None
        self.bg_point_size = None
        self.fg_nodes = []  # len T
        self.bg_nodes = []  # len N
        # trajectories
        self.show_traj = None
        self.traj_width = None
        self.traj_frames_text = None
        self.traj_build_btn = None
        self.traj_nodes = []
        self.playing = False
        self.play_btn = None
        self.pause_btn = None
        self.fps_slider = None
        self.loop_checkbox = None

def build_bg_nodes(server: viser.ViserServer, frames: List[Dict], init_percentile: float,
                   bg_conf_all_flat: np.ndarray, state: ViewerState):
    if bg_conf_all_flat.size == 0:
        return
    thr_val = np.percentile(bg_conf_all_flat, init_percentile)
    for i, fr in enumerate(frames):
        keep = fr["bg_conf_mean"] >= thr_val
        pts  = fr["bg_pts"][keep]
        cols = fr["img_rgb_float"][fr["bg_mask_flat"]][keep]
        node = server.scene.add_point_cloud(
            name=f"/bg/frame{i}",
            points=pts,
            colors=cols,
            point_size=state.bg_point_size.value if state.bg_point_size else 0.0002,
            point_shape="rounded",
            visible=True,
        )
        state.bg_nodes.append(node)
        print(f"[viewer] BG frame={i:02d}: add {pts.shape[0]} pts")

def build_fg_nodes(server: viser.ViserServer, frames: List[Dict], nearest_idx: np.ndarray, t_vals: np.ndarray,
                   fg_conf_all_t: List[np.ndarray], init_percentile: float, state: ViewerState):
    print("\n[viewer] Building FG nodes per timeline step …")
    T = len(t_vals)
    for t in range(T):
        fi = int(nearest_idx[t])
        conf_all = fg_conf_all_t[t]
        thr_t = np.percentile(conf_all, init_percentile) if conf_all.size > 0 else np.inf
        conf = frames[fi]["conf_fg_per_t"][t]
        pts  = frames[fi]["pts_fg_per_t"][t]
        keep = conf >= thr_t
        pts_k = pts[keep]
        cols_k = frames[fi]["img_rgb_float"][frames[fi]["fg_mask_flat"]][keep]
        node = server.scene.add_point_cloud(
            name=f"/fg/t{t:02d}",
            points=pts_k,
            colors=cols_k,
            point_size=state.point_size.value if state.point_size else 0.0002,
            point_shape="rounded",
            visible=(t == 0),
        )
        state.fg_nodes.append(node)
        print(f"[viewer] t={t:02d}: add FG node with {pts_k.shape[0]} pts")

def update_conf_filtering(server: viser.ViserServer, state: ViewerState, frames: List[Dict], nearest_idx: np.ndarray,
                          t_vals: np.ndarray, fg_conf_all_t: List[np.ndarray], bg_conf_all_flat: np.ndarray,
                          fg_percentile: float, bg_percentile: float):
    with state.lock:
        if state.is_updating:
            return
        state.is_updating = True
    try:
        if state.status_label:
            state.status_label.value = "⚙️ Filtering… please wait"
        if state.slider_fg: state.slider_fg.disabled = True
        if state.slider_bg: state.slider_bg.disabled = True

        # BG
        thr_bg = np.percentile(bg_conf_all_flat, bg_percentile) if bg_conf_all_flat.size > 0 else np.inf
        print(f"[filter] BG: percentile={bg_percentile:.1f} thr={thr_bg:.6f}")
        for i, node in enumerate(state.bg_nodes):
            conf = frames[i]["bg_conf_mean"]
            keep = conf >= thr_bg
            pts  = frames[i]["bg_pts"][keep]
            cols = frames[i]["img_rgb_float"][frames[i]["bg_mask_flat"]][keep]
            node.points = pts
            node.colors = cols
            print(f"  - frame {i:02d}: keep {pts.shape[0]} pts")
        server.flush()

        # FG
        print(f"[filter] FG: percentile={fg_percentile:.1f}")
        T = len(t_vals)
        for t in range(T):
            fi = int(nearest_idx[t])
            conf_all = fg_conf_all_t[t]
            thr_t = np.percentile(conf_all, fg_percentile) if conf_all.size > 0 else np.inf
            conf = frames[fi]["conf_fg_per_t"][t]
            pts  = frames[fi]["pts_fg_per_t"][t]
            keep = conf >= thr_t
            pts_k = pts[keep]
            cols_k = frames[fi]["img_rgb_float"][frames[fi]["fg_mask_flat"]][keep]
            node = state.fg_nodes[t]
            node.points = pts_k
            node.colors = cols_k
            print(f"  - t {t:02d}: frame {fi:02d}, keep {pts_k.shape[0]} pts")
            if (t % 3) == 0:
                server.flush()
        server.flush()
    finally:
        if state.status_label:
            state.status_label.value = ""
        if state.slider_fg: state.slider_fg.disabled = False
        if state.slider_bg: state.slider_bg.disabled = False
        with state.lock:
            state.is_updating = False

def _update_traj_visibility(state: "ViewerState", server: viser.ViserServer, tidx: int, on: bool):
    if not state.traj_nodes:
        return
    with server.atomic():
        for t, nodes in enumerate(state.traj_nodes):
            vis = on and (t <= tidx)
            for nd in nodes:
                nd.visible = vis
    server.flush()

# ---------------- trajectories ----------------
def build_traj_nodes(
    server: viser.ViserServer,
    output: Dict,
    frames: List[Dict],
    traj_frames: List[int],
    t_vals: np.ndarray,
    max_points: int,
    state: "ViewerState",
):
    # remove old
    for lst in state.traj_nodes:
        for nd in lst:
            nd.remove()
    state.traj_nodes = [[] for _ in range(len(t_vals) - 1)]

    T = len(t_vals)
    vals01 = (np.arange(T - 1, dtype=np.float32)) / max(1, T - 2)
    seg_colors = hsv_colormap(vals01)  # [T-1, 3]

    print("[traj] building …")
    for fi in traj_frames:
        if fi < 0 or fi >= len(frames):
            continue
        fr = frames[fi]
        fg_mask_flat = fr["fg_mask_flat"]
        if not np.any(fg_mask_flat):
            continue

        fg_idx = np.flatnonzero(fg_mask_flat)
        if max_points > 0 and fg_idx.size > max_points:
            sel = np.random.default_rng(42).choice(fg_idx, size=max_points, replace=False)
            inv = {p: j for j, p in enumerate(fg_idx)}
            sel_fg = np.array([inv[p] for p in sel], dtype=np.int64)
        else:
            sel = fg_idx
            sel_fg = np.arange(fg_idx.size, dtype=np.int64)

        arr = np.stack([fr["pts_fg_per_t"][t] for t in range(T)], axis=0)
        if sel_fg.size == 0:
            continue
        arr = arr[:, sel_fg, :]  # [T,N_sel,3]

        for t in range(T - 1):
            p0 = arr[t]
            p1 = arr[t + 1]
            if p0.size == 0:
                continue
            segs = np.stack([p0, p1], axis=1)  # [N,2,3]
            col = np.repeat(seg_colors[t][None, :], segs.shape[0], axis=0)  # [N,3]
            node = server.scene.add_line_segments(
                name=f"/traj/frame{fi}/t{t:02d}",
                points=segs,
                colors=np.repeat(col[:, None, :], 2, axis=1),
                line_width=state.traj_width.value if state.traj_width else 0.075,
                visible=False,
            )
            state.traj_nodes[t].append(node)
    print("[traj] done.")

# ----------------- main viewer -----------------
def serve_view(output: Dict, port: int = 8020, t_step: float = 0.1, ds: int = 2):
    server = viser.ViserServer(port=port)
    server.gui.set_panel_label("TraceAnything Viewer")
    server.gui.configure_theme(control_layout="floating", control_width="medium", show_logo=False)

    # restore camera & scene setup
    server.scene.set_up_direction((0.0, -1.0, 0.0))
    server.scene.world_axes.visible = False

    @server.on_client_connect
    def _on_connect(client: viser.ClientHandle):
        with client.atomic():
            client.camera.position = (-0.00141163, -0.01910395, -0.06794288)
            client.camera.look_at  = (-0.00352821, -0.01143425,  0.01549390)
        client.flush()

    root = output.get("_root_dir", os.getcwd())
    images_dir = ensure_dir(os.path.join(root, "images"))
    masks_dir  = ensure_dir(os.path.join(root, "masks"))

    t_vals, frames, fg_conf_all_t, bg_conf_all_flat, times = build_precomputes(output, t_step, ds)
    nearest_idx = choose_nearest_frame_indices(times, t_vals)

    # save images only (masks are assumed precomputed)
    save_images(frames, images_dir)
    print(f"[viewer] Using precomputed FG masks in: {masks_dir} (or preds['fg_mask*'])")

    state = ViewerState()
    with server.gui.add_folder("Point Size", expand_by_default=True):
        state.point_size = server.gui.add_slider("FG Point Size", min=1e-5, max=1e-3, step=1e-4, initial_value=0.0002)
        state.bg_point_size = server.gui.add_slider("BG Point Size", min=1e-5, max=1e-3, step=1e-4, initial_value=0.0002)

    with server.gui.add_folder("Confidence Filtering", expand_by_default=True):
        state.slider_fg = server.gui.add_slider("FG percentile (drop bottom %)", min=0, max=100, step=1, initial_value=10)
        state.slider_bg = server.gui.add_slider("BG percentile (drop bottom %)", min=0, max=100, step=1, initial_value=10)

    with server.gui.add_folder("Playback", expand_by_default=True):
        state.slider_time = server.gui.add_slider("Time", min=0.0, max=1.0, step=t_step, initial_value=0.0)
        state.play_btn = server.gui.add_button("▶ Play")
        state.pause_btn = server.gui.add_button("⏸ Pause")
        state.fps_slider = server.gui.add_slider("FPS", min=1, max=60, step=1, initial_value=10)
        state.loop_checkbox = server.gui.add_checkbox("Loop", True)

    # ---- Trajectories panel ----
    with server.gui.add_folder("Trajectories", expand_by_default=True):
        state.show_traj = server.gui.add_checkbox("Show trajectories", False)
        state.traj_width = server.gui.add_slider("Line width", min=0.01, max=0.2, step=0.005, initial_value=0.075)
        state.traj_frames_text = server.gui.add_text("Frames (e.g. 0,mid,last)", initial_value="0,mid,last")
        state.traj_build_btn = server.gui.add_button("Build / Refresh")

    state.status_label = server.gui.add_markdown("")

    # build nodes
    build_bg_nodes(server, frames, init_percentile=state.slider_bg.value, bg_conf_all_flat=bg_conf_all_flat, state=state)
    build_fg_nodes(server, frames, nearest_idx, t_vals, fg_conf_all_t, init_percentile=state.slider_fg.value, state=state)
    print("\n[viewer] Ready. Open the printed URL and play with sliders!\n")

    # --- playback loop thread ---
    def _playback_loop():
        dt = 0.1  # default sleep interval; overwritten each tick while playing
        while True:
            if state.playing:
                try:
                    fps = max(1, int(state.fps_slider.value)) if state.fps_slider else 10
                    dt = 1.0 / float(fps)
                    tv = float(state.slider_time.value)
                    tv_next = tv + t_step
                    if tv_next > 1.0 - 1e-6:
                        if state.loop_checkbox and state.loop_checkbox.value:
                            tv_next = 0.0
                        else:
                            tv_next = 1.0 - 1e-6
                            state.playing = False
                    state.slider_time.value = tv_next
                except Exception:
                    pass
            time.sleep(dt if state.playing else 0.05)

    threading.Thread(target=_playback_loop, daemon=True).start()

    # callbacks
    @state.slider_time.on_update
    def _(_):
        tv = state.slider_time.value
        tidx = int(round(tv / t_step))
        tidx = max(0, min(tidx, len(t_vals) - 1))
        with server.atomic():
            for t in range(len(state.fg_nodes)):
                state.fg_nodes[t].visible = (t == tidx)
        server.flush()
        # NEW: only show trajectories up to current t
        _update_traj_visibility(state, server, tidx, on=(state.show_traj and state.show_traj.value))

    @state.play_btn.on_click
    def _(_):
        state.playing = True

    @state.pause_btn.on_click
    def _(_):
        state.playing = False

    def _rebuild_filter(_):
        update_conf_filtering(
            server=server,
            state=state,
            frames=frames,
            nearest_idx=nearest_idx,
            t_vals=t_vals,
            fg_conf_all_t=fg_conf_all_t,
            bg_conf_all_flat=bg_conf_all_flat,
            fg_percentile=state.slider_fg.value,
            bg_percentile=state.slider_bg.value,
        )

    @state.slider_fg.on_update
    def _(_):
        _rebuild_filter(_)

    @state.slider_bg.on_update
    def _(_):
        _rebuild_filter(_)

    @state.point_size.on_update
    def _(_):
        with server.atomic():
            for n in state.fg_nodes:
                n.point_size = state.point_size.value
        server.flush()

    @state.bg_point_size.on_update
    def _(_):
        with server.atomic():
            for n in state.bg_nodes:
                n.point_size = state.bg_point_size.value
        server.flush()

    # --- trajectories build/refresh and live controls ---
    def _parse_traj_frames():
        txt = (state.traj_frames_text.value or "").strip()
        if not txt:
            return []
        tokens = [t.strip() for t in txt.split(",") if t.strip()]
        n = len(frames)
        out_idx = []
        for tk in tokens:
            if tk == "mid":
                out_idx.append(n // 2)
            elif tk == "last":
                out_idx.append(n - 1)
            else:
                try:
                    out_idx.append(int(tk))
                except Exception:
                    pass
        out_idx = [i for i in sorted(set(out_idx)) if 0 <= i < n]
        return out_idx

    @state.traj_build_btn.on_click
    def _(_):
        sel = _parse_traj_frames()
        if not sel:
            return
        if state.status_label:
            state.status_label.value = "🧵 Building trajectories…"
        build_traj_nodes(
            server=server,
            output=output,
            frames=frames,
            traj_frames=sel,
            t_vals=t_vals,
            max_points=10000,
            state=state,
        )
        tidx = int(round(state.slider_time.value / t_step))
        tidx = max(0, min(tidx, len(t_vals) - 1))
        _update_traj_visibility(state, server, tidx, on=(state.show_traj and state.show_traj.value))
        if state.status_label:
            state.status_label.value = ""

    @state.show_traj.on_update
    def _(_):
        tidx = int(round(state.slider_time.value / t_step))
        tidx = max(0, min(tidx, len(t_vals) - 1))
        _update_traj_visibility(state, server, tidx, on=state.show_traj.value)

    @state.traj_width.on_update
    def _(_):
        w = state.traj_width.value
        with server.atomic():
            for nodes in state.traj_nodes:
                for nd in nodes:
                    nd.line_width = w
        server.flush()

    return server

def parse_args():
    p = argparse.ArgumentParser("TraceAnything viewer")
    p.add_argument("--output", type=str, default="./examples/output/elephant/output.pt",
                   help="Path to output.pt or parent directory.")
    p.add_argument("--port", type=int, default=8020)
    p.add_argument("--t_step", type=float, default=0.025)
    p.add_argument("--ds", type=int, default=1, help="downsample stride for H,W (>=1)")
    return p.parse_args()

def main():
    args = parse_args()
    out = load_output_dict(args.output)
    server = serve_view(out, port=args.port, t_step=args.t_step, ds=max(1, args.ds))
    try:
        while True:
            time.sleep(3600)
    except KeyboardInterrupt:
        pass

if __name__ == "__main__":
    main()

```
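
For reference, here is a sketch of the `output.pt` schema the viewer consumes, inferred from `build_precomputes()` and the mask loader above. The field names are the ones actually read; the shapes and values below are dummies (K control points per pixel, an H×W image).

```py
import os
import torch

scene = "./examples/output/demo"        # placeholder scene directory
os.makedirs(scene, exist_ok=True)

K, H, W = 4, 8, 8                       # dummy sizes
dummy = {
    "views": [{
        "img": torch.zeros(1, 3, H, W),             # RGB in [-1, 1], [1, 3, H, W]
        "time_step": 0.0,                           # fallback frame time in [0, 1]
    }],
    "preds": [{
        "ctrl_pts3d": torch.zeros(K, H, W, 3),      # B-spline control points per pixel
        "ctrl_conf":  torch.ones(K, H, W),          # confidence per control point
        "fg_mask":    torch.ones(H, W, dtype=torch.bool),  # or <scene>/masks/{i:03d}.png
        "time":       0.0,                          # optional; overrides time_step
    }],
}
torch.save(dummy, os.path.join(scene, "output.pt"))
# then: python scripts/view.py --output ./examples/output/demo
```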

## /trace_anything/__init__.py

```py path="/trace_anything/__init__.py" 
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```

## /trace_anything/heads.py

```py path="/trace_anything/heads.py" 
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# trace_anything/heads.py
from typing import List
import torch
import torch.nn as nn
from einops import rearrange
from .layers.dpt_block import DPTOutputAdapter  # vendored local copy


def _resolve_hw(img_info):
    """Normalize `img_info` into the form (int(H), int(W)).
    Allowed input formats:
      - torch.Tensor of shape (..., 2) 
        (e.g. [num_views, B, 2] or [B, 2])
      - (h, w) tuple/list, where elements can be int or Tensor
    Requirement: H/W must be consistent across the whole batch; 
    otherwise, raise an error 
    (current inference path does not support mixed-size batches).
    """
    if isinstance(img_info, torch.Tensor):
        assert img_info.shape[-1] == 2, f"img_info last dim must be 2, got {img_info.shape}"
        h0 = img_info.reshape(-1, 2)[0, 0].item()
        w0 = img_info.reshape(-1, 2)[0, 1].item()
        if (img_info[..., 0] != img_info[..., 0].reshape(-1)[0]).any() or \
           (img_info[..., 1] != img_info[..., 1].reshape(-1)[0]).any():
            raise AssertionError(f"Mixed H/W in batch not supported: {tuple(img_info.shape)}")
        return int(h0), int(w0)

    if isinstance(img_info, (list, tuple)) and len(img_info) == 2:
        h, w = img_info
        if isinstance(h, torch.Tensor): h = int(h.reshape(-1)[0].item())
        if isinstance(w, torch.Tensor): w = int(w.reshape(-1)[0].item())
        return int(h), int(w)

    raise TypeError(f"Unexpected img_info type: {type(img_info)}")


# ---------- postprocess ----------
def reg_dense_depth(xyz, mode):
    mode, vmin, vmax = mode
    no_bounds = (vmin == -float("inf")) and (vmax == float("inf"))
    assert no_bounds
    if mode == "linear":
        return xyz if no_bounds else xyz.clip(min=vmin, max=vmax)
    d = xyz.norm(dim=-1, keepdim=True)
    xyz = xyz / d.clip(min=1e-8)
    if mode == "square":
        return xyz * d.square()
    if mode == "exp":
        return xyz * torch.expm1(d)
    raise ValueError(f"bad {mode=}")

def reg_dense_conf(x, mode):
    mode, vmin, vmax = mode
    if mode == "exp":
        return vmin + x.exp().clip(max=vmax - vmin)
    if mode == "sigmoid":
        return (vmax - vmin) * torch.sigmoid(x) + vmin
    raise ValueError(f"bad {mode=}")

def postprocess(out, depth_mode, conf_mode):
    fmap = out.permute(0, 2, 3, 1)  # B,H,W,C
    res = dict(pts3d=reg_dense_depth(fmap[..., 0:3], mode=depth_mode))
    if conf_mode is not None:
        res["conf"] = reg_dense_conf(fmap[..., 3], mode=conf_mode)
    return res

def postprocess_multi_point(out, depth_mode, conf_mode):
    fmap = out.permute(0, 2, 3, 1)  # B,H,W,C
    B, H, W, C = fmap.shape
    n_point = C // 4
    pts_3d = fmap[..., :3*n_point].view(B, H, W, 3, n_point).permute(4, 0, 1, 2, 3)  # [K,B,H,W,3]
    conf   = fmap[..., 3*n_point:].view(B, H, W, 1, n_point).squeeze(3).permute(3, 0, 1, 2)  # [K,B,H,W]
    res = dict(pts3d=reg_dense_depth(pts_3d, mode=depth_mode))
    res["conf"] = reg_dense_conf(conf, mode=conf_mode)
    return res

class DPTOutputAdapterFix(DPTOutputAdapter):
    def init(self, dim_tokens_enc=768):
        super().init(dim_tokens_enc)
        del self.act_1_postprocess
        del self.act_2_postprocess
        del self.act_3_postprocess
        del self.act_4_postprocess
    
    def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):
        assert (
            self.dim_tokens_enc is not None
        ), "Need to call init(dim_tokens_enc) function first"
        image_size = self.image_size if image_size is None else image_size
        H, W = image_size
        H, W = int(H), int(W) 

        # Number of patches in height and width
        N_H = H // (self.stride_level * self.P_H)
        N_W = W // (self.stride_level * self.P_W)
        layers = [encoder_tokens[h] for h in self.hooks]          # 4 x [B,N,C]
        layers = [self.adapt_tokens(l) for l in layers]           # 4 x [B,N,C]
        layers = [rearrange(l, "b (nh nw) c -> b c nh nw", nh=N_H, nw=N_W) for l in layers]
        layers = [self.act_postprocess[i](l) for i, l in enumerate(layers)]
        layers = [self.scratch.layer_rn[i](l) for i, l in enumerate(layers)]
        p4 = self.scratch.refinenet4(layers[3])[:, :, : layers[2].shape[2], : layers[2].shape[3]]
        p3 = self.scratch.refinenet3(p4, layers[2])
        p2 = self.scratch.refinenet2(p3, layers[1])
        p1 = self.scratch.refinenet1(p2, layers[0])

        max_chunk = 1 if self.training else 50
        outs = []
        for ch in torch.split(p1, max_chunk, dim=0):
            outs.append(self.head(ch))
        return torch.cat(outs, dim=0)  # [B,C,H,W]

# ---------- Heads ----------
class ScalarHead(nn.Module):
    def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
                 output_width_ratio=1, num_channels=1, **kwargs):
        super().__init__()
        assert n_cls_token == 0
        dpt_args = dict(output_width_ratio=output_width_ratio, num_channels=num_channels, **kwargs)
        if hooks_idx is not None:
            dpt_args.update(hooks=hooks_idx)
        self.dpt = DPTOutputAdapterFix(**dpt_args)
        dpt_init = {} if dim_tokens is None else {"dim_tokens_enc": dim_tokens}
        self.dpt.init(**dpt_init)
        self.scalar_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # B,C,1,1  (C==1)
            nn.Flatten(),             # B,1
            nn.Linear(1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x_list, img_info):
        H, W = _resolve_hw(img_info)                     
        out = self.dpt(x_list, image_size=(H, W))       # [B,1,H,W]
        return self.scalar_head(out)                    # [B,1]

class PixelHead(nn.Module):
    """Output per-pixel (3D point + confidence), supports multiple points (K)."""
    def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
                 output_width_ratio=1, num_channels=1,
                 postprocess=postprocess_multi_point,
                 depth_mode=None, conf_mode=None, **kwargs):
        super().__init__()
        assert n_cls_token == 0
        self.postprocess = postprocess
        self.depth_mode = depth_mode
        self.conf_mode  = conf_mode
        dpt_args = dict(output_width_ratio=output_width_ratio, num_channels=num_channels, **kwargs)
        if hooks_idx is not None:
            dpt_args.update(hooks=hooks_idx)
        self.dpt = DPTOutputAdapterFix(**dpt_args)
        dpt_init = {} if dim_tokens is None else {"dim_tokens_enc": dim_tokens}
        self.dpt.init(**dpt_init)

    def forward(self, x_list, img_info):
        H, W = _resolve_hw(img_info)                    
        out = self.dpt(x_list, image_size=(H, W))        # [B,C,H,W]
        return self.postprocess(out, self.depth_mode, self.conf_mode)
```
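
A small shape-check sketch for `postprocess_multi_point()` and `_resolve_hw()` above: a head emitting C = 4*K channels is split into K 3D points plus K confidences per pixel. The `depth_mode`/`conf_mode` tuples here are assumptions chosen to satisfy the asserts in `reg_dense_depth`/`reg_dense_conf`; the actual configs may use different modes.

```py
import torch
from trace_anything.heads import postprocess_multi_point, _resolve_hw

B, K, H, W = 2, 4, 32, 32
raw = torch.randn(B, 4 * K, H, W)                     # [B, C, H, W] head output

res = postprocess_multi_point(
    raw,
    depth_mode=("exp", -float("inf"), float("inf")),  # unbounded, as the assert requires
    conf_mode=("exp", 1.0, float("inf")),             # conf = 1 + exp(x)
)
print(res["pts3d"].shape)   # torch.Size([4, 2, 32, 32, 3]) -> [K, B, H, W, 3]
print(res["conf"].shape)    # torch.Size([4, 2, 32, 32])    -> [K, B, H, W]

# _resolve_hw accepts an (H, W) pair or any (..., 2) tensor of per-view sizes
print(_resolve_hw(torch.tensor([[32, 32], [32, 32]])))  # (32, 32)
```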

## /trace_anything/layers/__init__.py

```py path="/trace_anything/layers/__init__.py" 
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```

## /trace_anything/layers/blocks.py

```py path="/trace_anything/layers/blocks.py" 
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).


# --------------------------------------------------------
# Main encoder/decoder blocks
# --------------------------------------------------------
# References:
# timm
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py


import collections.abc
from itertools import repeat
import math

import torch
import torch.nn as nn
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend


def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)


def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob,3):0.3f}"


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        bias=True,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class Attention(nn.Module):
    def __init__(
        self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0,
        attn_mask=None, is_causal=False, attn_implementation="pytorch_naive",
        attn_bias_for_inference_enabled=False,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        # use attention biasing to accommodate for longer sequences than during training
        self.attn_bias_for_inference_enabled = attn_bias_for_inference_enabled
        gamma = 1.0
        train_seqlen = 20
        inference_seqlen = 137
        self.attn_bias_scale = head_dim**-0.5 * (gamma * math.log(inference_seqlen) / math.log(train_seqlen))**0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.dropout_p = attn_drop
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope
        self.attn_mask = attn_mask
        self.is_causal = is_causal
        self.attn_implementation = attn_implementation

    def forward(self, x, xpos):
        B, N, C = x.shape

        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .transpose(1, 3)
        )
        q, k, v = [qkv[:, :, i] for i in range(3)]
        # q,k,v = qkv.unbind(2)  # make torchscript happy (cannot use tensor as tuple)

        if self.rope is not None:
            with torch.autocast(device_type=next(self.parameters()).device.type, dtype=torch.float32):  # FIXME: for some reason Lightning didn't pick up torch.cuda.amp.custom_fwd when using bf16-true
                q = self.rope(q, xpos) if xpos is not None else q
                k = self.rope(k, xpos) if xpos is not None else k

        if not self.training and self.attn_bias_for_inference_enabled:
            scale = self.attn_bias_scale
        else:
            scale = self.scale

        # Important: For the fusion Transformer, we forward through the attention with bfloat16 precision
        # If you are not using this block for the fusion Transformer, you should double check the precision of the input and output
        if self.attn_implementation == "pytorch_naive":
            assert self.attn_mask is None, "attn_mask not supported for pytorch_naive implementation of scaled dot product attention"
            assert self.is_causal is False, "is_causal not supported for pytorch_naive implementation of scaled dot product attention"
            dtype = k.dtype
            with torch.autocast("cuda", dtype=torch.bfloat16):
                x = (q @ k.transpose(-2, -1)) * scale
                x = x.softmax(dim=-1)
                x = self.attn_drop(x)
            if dtype == torch.float32:  # if input was FP32, cast back to FP32
                x = x.to(torch.float32)
            x = (x @ v).transpose(1, 2).reshape(B, N, C)
            x = self.proj(x)
            x = self.proj_drop(x)  
        elif self.attn_implementation == "flash_attention":
            with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                dtype = k.dtype
                with torch.autocast("cuda", dtype=torch.bfloat16):
                    x = scaled_dot_product_attention(q, k, v, attn_mask=self.attn_mask, dropout_p=self.dropout_p, is_causal=self.is_causal, scale=scale)
                if dtype == torch.float32:  # if input was FP32, cast back to FP32
                    x = x.to(torch.float32)
                x = x.transpose(1, 2).reshape(B, N, C)
                x = self.proj(x)
                x = self.proj_drop(x)
        elif self.attn_implementation == "pytorch_auto":
            with torch.nn.attention.sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION,]):
                dtype = k.dtype
                with torch.autocast("cuda", dtype=torch.bfloat16):
                    x = scaled_dot_product_attention(q, k, v, attn_mask=self.attn_mask, dropout_p=self.dropout_p, is_causal=self.is_causal, scale=scale)
                if dtype == torch.float32:  # if input was FP32, cast back to FP32
                    x = x.to(torch.float32)
                x = x.transpose(1, 2).reshape(B, N, C)
                x = self.proj(x)
                x = self.proj_drop(x)
        else:
            raise ValueError(f"Unknown attn_implementation: {self.attn_implementation}")

        return x


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        rope=None,
        attn_implementation="pytorch_naive",
        attn_bias_for_inference_enabled=False,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            rope=rope,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            attn_implementation=attn_implementation,
            attn_bias_for_inference_enabled=attn_bias_for_inference_enabled,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, xpos):
        x = x + self.drop_path(self.attn(self.norm1(x), xpos))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class CrossAttention(nn.Module):
    def __init__(
        self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, attn_mask=None, is_causal=False, attn_implementation="pytorch_naive"
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.projq = nn.Linear(dim, dim, bias=qkv_bias)
        self.projk = nn.Linear(dim, dim, bias=qkv_bias)
        self.projv = nn.Linear(dim, dim, bias=qkv_bias)
        self.dropout_p = attn_drop
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.rope = rope

        self.attn_mask = attn_mask
        self.is_causal = is_causal
        self.attn_implementation = attn_implementation

    def forward(self, query, key, value, qpos, kpos):
        B, Nq, C = query.shape
        Nk = key.shape[1]
        Nv = value.shape[1]

        q = (
            self.projq(query)
            .reshape(B, Nq, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        k = (
            self.projk(key)
            .reshape(B, Nk, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        v = (
            self.projv(value)
            .reshape(B, Nv, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )

        if self.rope is not None:
            with torch.autocast(device_type=next(self.parameters()).device.type, dtype=torch.float32):  # FIXME: for some reason Lightning didn't pick up torch.cuda.amp.custom_fwd when using bf16-true
                q = self.rope(q, qpos) if qpos is not None else q
                k = self.rope(k, kpos) if kpos is not None else k

        if self.attn_implementation == "pytorch_naive":
            assert self.attn_mask is None, "attn_mask not supported for pytorch_naive implementation of scaled dot product attention"
            assert self.is_causal is False, "is_causal not supported for pytorch_naive implementation of scaled dot product attention"
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

            x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
            x = self.proj(x)
            x = self.proj_drop(x)
        elif self.attn_implementation == "flash_attention":
            with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                # cast to BF16 to use flash_attention
                q = q.to(torch.bfloat16)
                k = k.to(torch.bfloat16)
                v = v.to(torch.bfloat16)
                x = scaled_dot_product_attention(q, k, v, attn_mask=self.attn_mask, dropout_p=self.dropout_p, is_causal=self.is_causal, scale=self.scale)
                # cast back to FP32
                x = x.to(torch.float32)
                x = x.transpose(1, 2).reshape(B, Nq, C)
                x = self.proj(x)
                x = self.proj_drop(x)
        else:
            raise ValueError(f"Unknown attn_implementation: {self.attn_implementation}")

        return x


class DecoderBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        norm_mem=True,
        rope=None,
        attn_implementation="pytorch_naive",
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            rope=rope,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            attn_implementation=attn_implementation,
        )
        self.cross_attn = CrossAttention(
            dim,
            rope=rope,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            attn_implementation=attn_implementation,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm3 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()

    def forward(self, x, y, xpos, ypos):
        x = x + self.drop_path(self.attn(self.norm1(x), xpos))
        y_ = self.norm_y(y)
        x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
        x = x + self.drop_path(self.mlp(self.norm3(x)))
        return x, y


# patch embedding
class PositionGetter(object):
    """return positions of patches"""

    def __init__(self):
        self.cache_positions = {}

    def __call__(self, b, h, w, device):
        if not (h, w) in self.cache_positions:
            x = torch.arange(w, device=device)
            y = torch.arange(h, device=device)
            self.cache_positions[h, w] = torch.cartesian_prod(y, x)  # (h, w, 2)
        pos = self.cache_positions[h, w].view(1, h * w, 2).expand(b, -1, 2).clone()
        return pos


class PatchEmbed(nn.Module):
    """just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed"""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

        self.position_getter = PositionGetter()

    def forward(self, x):
        B, C, H, W = x.shape
        torch._assert(
            H == self.img_size[0],
            f"Input image height ({H}) doesn't match model ({self.img_size[0]}).",
        )
        torch._assert(
            W == self.img_size[1],
            f"Input image width ({W}) doesn't match model ({self.img_size[1]}).",
        )
        x = self.proj(x)
        pos = self.position_getter(B, x.size(2), x.size(3), x.device)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2).contiguous()  # BCHW -> BNC
        x = self.norm(x)
        return x, pos

    def _init_weights(self):
        w = self.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

```
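
A numeric sketch of the inference-time attention biasing in `Attention` above: the standard 1/sqrt(head_dim) scale is multiplied by sqrt(gamma * log(inference_seqlen) / log(train_seqlen)) to sharpen logits when running on longer token sequences than seen during training. The gamma and sequence-length constants are the ones hardcoded in the class; head_dim = 64 is only illustrative.

```py
import math

head_dim = 64                                          # illustrative head dimension
gamma, train_seqlen, inference_seqlen = 1.0, 20, 137   # constants from Attention.__init__

default_scale = head_dim ** -0.5
biased_scale = head_dim ** -0.5 * (
    gamma * math.log(inference_seqlen) / math.log(train_seqlen)
) ** 0.5

print(f"default scale: {default_scale:.4f}")    # 0.1250
print(f"biased  scale: {biased_scale:.4f}")     # ~0.1602, about 28% larger than default
```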

## /trace_anything/layers/dpt_block.py

```py path="/trace_anything/layers/dpt_block.py" 
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

# --------------------------------------------------------
# DPT head for ViTs
# --------------------------------------------------------
# References:
# https://github.com/isl-org/DPT
# https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/output_adapters.py

from typing import Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat


def pair(t):
    return t if isinstance(t, tuple) else (t, t)


def make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()

    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    out_shape4 = out_shape
    if expand == True:
        out_shape1 = out_shape
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        out_shape4 = out_shape * 8

    scratch.layer1_rn = nn.Conv2d(
        in_shape[0],
        out_shape1,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )
    scratch.layer2_rn = nn.Conv2d(
        in_shape[1],
        out_shape2,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )
    scratch.layer3_rn = nn.Conv2d(
        in_shape[2],
        out_shape3,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )
    scratch.layer4_rn = nn.Conv2d(
        in_shape[3],
        out_shape4,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )

    scratch.layer_rn = nn.ModuleList(
        [
            scratch.layer1_rn,
            scratch.layer2_rn,
            scratch.layer3_rn,
            scratch.layer4_rn,
        ]
    )

    return scratch


class ResidualConvUnit_custom(nn.Module):
    """Residual convolution module."""

    def __init__(self, features, activation, bn):
        """Init.
        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn

        self.groups = 1

        self.conv1 = nn.Conv2d(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=not self.bn,
            groups=self.groups,
        )

        self.conv2 = nn.Conv2d(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=not self.bn,
            groups=self.groups,
        )

        if self.bn == True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.
        Args:
            x (tensor): input
        Returns:
            tensor: output
        """

        out = self.activation(x)
        out = self.conv1(out)
        if self.bn == True:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn == True:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)


class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        width_ratio=1,
    ):
        """Init.
        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock_custom, self).__init__()
        self.width_ratio = width_ratio

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features
        if self.expand:
            out_features = features // 2

        self.out_conv = nn.Conv2d(
            features,
            out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            groups=1,
        )

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs, max_chunk_size=100):
        """Forward pass.
        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            if self.width_ratio != 1:
                res = F.interpolate(
                    res, size=(output.shape[2], output.shape[3]), mode="bilinear"
                )

            output = self.skip_add.add(output, res)
            # output += res

        output = self.resConfUnit2(output)

        if self.width_ratio != 1:
            # and output.shape[3] < self.width_ratio * output.shape[2]
            # size=(image.shape[])
            if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio:
                shape = 3 * output.shape[3]
            else:
                shape = int(self.width_ratio * 2 * output.shape[2])
            output = F.interpolate(
                output, size=(2 * output.shape[2], shape), mode="bilinear"
            )
        else:
            # Split input into chunks to avoid memory issues with large batches

            chunks = torch.split(output, max_chunk_size, dim=0)
            outputs = []

            for chunk in chunks:
                out_chunk = nn.functional.interpolate(
                    chunk,
                    scale_factor=2,
                    mode="bilinear",
                    align_corners=self.align_corners,
                )
                outputs.append(out_chunk)

            # Concatenate outputs along the batch dimension
            output = torch.cat(outputs, dim=0)

        output = self.out_conv(output)
        return output


def make_fusion_block(features, use_bn, width_ratio=1):
    return FeatureFusionBlock_custom(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        width_ratio=width_ratio,
    )


class Interpolate(nn.Module):
    """Interpolation module."""

    def __init__(self, scale_factor, mode, align_corners=False):
        """Init.
        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.
        Args:
            x (tensor): input
        Returns:
            tensor: interpolated data
        """

        x = self.interp(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )

        return x


class DPTOutputAdapter(nn.Module):
    """DPT output adapter.

    :param num_channels: Number of output channels
    :param stride_level: Stride level compared to the full-sized image.
        E.g. 4 for 1/4th the size of the image.
    :param patch_size_full: Int or tuple of the patch size over the full image size.
        Patch size for smaller inputs will be computed accordingly.
    :param hooks: Index of intermediate layers
    :param layer_dims: Dimension of intermediate layers
    :param feature_dim: Feature dimension
    :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression
    :param use_bn: If set to True, activates batch norm
    :param dim_tokens_enc:  Dimension of tokens coming from encoder
    """

    def __init__(
        self,
        num_channels: int = 1,
        stride_level: int = 1,
        patch_size: Union[int, Tuple[int, int]] = 16,
        main_tasks: Iterable[str] = ("rgb",),
        hooks: List[int] = [2, 5, 8, 11],
        layer_dims: List[int] = [96, 192, 384, 768],
        feature_dim: int = 256,
        last_dim: int = 32,
        use_bn: bool = False,
        dim_tokens_enc: Optional[int] = None,
        head_type: str = "regression",
        output_width_ratio=1,
        **kwargs
    ):
        super().__init__()
        self.num_channels = num_channels
        self.stride_level = stride_level
        self.patch_size = pair(patch_size)
        self.main_tasks = main_tasks
        self.hooks = hooks
        self.layer_dims = layer_dims
        self.feature_dim = feature_dim
        self.dim_tokens_enc = (
            dim_tokens_enc * len(self.main_tasks)
            if dim_tokens_enc is not None
            else None
        )
        self.head_type = head_type

        # Actual patch height and width, taking into account stride of input
        self.P_H = max(1, self.patch_size[0] // stride_level)
        self.P_W = max(1, self.patch_size[1] // stride_level)

        self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False)

        self.scratch.refinenet1 = make_fusion_block(
            feature_dim, use_bn, output_width_ratio
        )
        self.scratch.refinenet2 = make_fusion_block(
            feature_dim, use_bn, output_width_ratio
        )
        self.scratch.refinenet3 = make_fusion_block(
            feature_dim, use_bn, output_width_ratio
        )
        self.scratch.refinenet4 = make_fusion_block(
            feature_dim, use_bn, output_width_ratio
        )

        if self.head_type == "regression":
            # The "DPTDepthModel" head
            self.head = nn.Sequential(
                nn.Conv2d(
                    feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1
                ),
                # the act_postprocess layers upsample each patch by 8 in total,
                # so self.patch_size / 8 calculates how much more we need to upsample
                # to get to the full image size (remember that num_patches = image_size / patch_size)
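                # e.g. for a 224x224 input with patch_size=16, the 14x14 patch grid has already been
                # upsampled to 112x112 by the refinenet path, so the remaining factor here is 16 / 8 = 2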
                Interpolate(scale_factor=self.patch_size[0] / 8, mode="bilinear", align_corners=True),
                nn.Conv2d(
                    feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1
                ),
                nn.ReLU(True),
                nn.Conv2d(
                    last_dim, self.num_channels, kernel_size=1, stride=1, padding=0
                ),
            )
        elif self.head_type == "semseg":
            # The "DPTSegmentationModel" head
            self.head = nn.Sequential(
                nn.Conv2d(
                    feature_dim, feature_dim, kernel_size=3, padding=1, bias=False
                ),
                nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
                nn.ReLU(True),
                nn.Dropout(0.1, False),
                nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
                Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            )
        else:
            raise ValueError('DPT head_type must be "regression" or "semseg".')

        if self.dim_tokens_enc is not None:
            self.init(dim_tokens_enc=dim_tokens_enc)

    def init(self, dim_tokens_enc=768):
        """
        Initialize parts of decoder that are dependent on dimension of encoder tokens.
        Should be called when setting up MultiMAE.

        :param dim_tokens_enc: Dimension of tokens coming from encoder
        """
        # print(dim_tokens_enc)

        # Set up activation postprocessing layers
        if isinstance(dim_tokens_enc, int):
            dim_tokens_enc = 4 * [dim_tokens_enc]

        self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc]

        self.act_1_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc[0],
                out_channels=self.layer_dims[0],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=self.layer_dims[0],
                out_channels=self.layer_dims[0],
                kernel_size=4,
                stride=4,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )

        self.act_2_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc[1],
                out_channels=self.layer_dims[1],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=self.layer_dims[1],
                out_channels=self.layer_dims[1],
                kernel_size=2,
                stride=2,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )

        self.act_3_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc[2],
                out_channels=self.layer_dims[2],
                kernel_size=1,
                stride=1,
                padding=0,
            )
        )

        self.act_4_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc[3],
                out_channels=self.layer_dims[3],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.Conv2d(
                in_channels=self.layer_dims[3],
                out_channels=self.layer_dims[3],
                kernel_size=3,
                stride=2,
                padding=1,
            ),
        )

        self.act_postprocess = nn.ModuleList(
            [
                self.act_1_postprocess,
                self.act_2_postprocess,
                self.act_3_postprocess,
                self.act_4_postprocess,
            ]
        )

    def adapt_tokens(self, encoder_tokens):
        # Adapt tokens
        x = []
        x.append(encoder_tokens[:, :])
        x = torch.cat(x, dim=-1)
        return x

    def forward(self, encoder_tokens: List[torch.Tensor], image_size):
        # input_info: Dict):
        assert (
            self.dim_tokens_enc is not None
        ), "Need to call init(dim_tokens_enc) function first"
        H, W = image_size

        # Number of patches in height and width
        N_H = H // (self.stride_level * self.P_H)
        N_W = W // (self.stride_level * self.P_W)

        # Hook decoder onto 4 layers from specified ViT layers
        layers = [encoder_tokens[hook] for hook in self.hooks]

        # Extract only task-relevant tokens and ignore global tokens.
        layers = [self.adapt_tokens(l) for l in layers]

        # Reshape tokens to spatial representation
        layers = [
            rearrange(l, "b (nh nw) c -> b c nh nw", nh=N_H, nw=N_W) for l in layers
        ]

        layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
        # Project layers to chosen feature dim
        layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]

        # Fuse layers using refinement stages
        path_4 = self.scratch.refinenet4(layers[3])
        path_3 = self.scratch.refinenet3(path_4, layers[2])
        path_2 = self.scratch.refinenet2(path_3, layers[1])
        path_1 = self.scratch.refinenet1(path_2, layers[0])

        # Output head
        out = self.head(path_1)

        return out

```
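A minimal sketch of driving `DPTOutputAdapter` with ViT-Base-style tokens (12 layers of 768-dim tokens, 16x16 patches on a 224x224 input); the tensors are random placeholders and the import path simply mirrors this repo's layout:

```python
import torch
from trace_anything.layers.dpt_block import DPTOutputAdapter

# 224x224 image with 16x16 patches -> 14x14 = 196 tokens per encoder layer
adapter = DPTOutputAdapter(
    num_channels=1,
    patch_size=16,
    hooks=[2, 5, 8, 11],             # encoder layers to tap
    layer_dims=[96, 192, 384, 768],
    feature_dim=256,
    dim_tokens_enc=768,              # triggers init(), creating the act_postprocess layers
)
encoder_tokens = [torch.randn(2, 14 * 14, 768) for _ in range(12)]
out = adapter(encoder_tokens, image_size=(224, 224))
print(out.shape)  # torch.Size([2, 1, 224, 224])
```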

## /trace_anything/layers/patch_embed.py

```py path="/trace_anything/layers/patch_embed.py" 
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# PatchEmbed implementation for DUSt3R,
# in particular ManyAR_PatchEmbed, which handles images with a non-square aspect ratio
# --------------------------------------------------------
import torch
from trace_anything.layers.blocks import PatchEmbed


def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim):
    assert patch_embed_cls in ["PatchEmbedDust3R", "ManyAR_PatchEmbed"]
    patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim)
    return patch_embed


class PatchEmbedDust3R(PatchEmbed):
    def forward(self, x, **kw):
        B, C, H, W = x.shape
        assert (
            H % self.patch_size[0] == 0
        ), f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
        assert (
            W % self.patch_size[1] == 0
        ), f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
        x = self.proj(x)
        pos = self.position_getter(B, x.size(2), x.size(3), x.device)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x, pos


class ManyAR_PatchEmbed(PatchEmbed):
    """Handle images with non-square aspect ratio.
    All images in the same batch have the same aspect ratio.
    true_shape = [(height, width) ...] indicates the actual shape of each image.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
    ):
        self.embed_dim = embed_dim
        super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten)

    def forward(self, img, true_shape):
        B, C, H, W = img.shape
        assert W >= H, f"img should be in landscape mode, but got {W=} {H=}"
        assert (
            H % self.patch_size[0] == 0
        ), f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
        assert (
            W % self.patch_size[1] == 0
        ), f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
        assert true_shape.shape == (
            B,
            2,
        ), f"true_shape has the wrong shape={true_shape.shape}"

        # size expressed in tokens
        W //= self.patch_size[0]
        H //= self.patch_size[1]
        n_tokens = H * W

        height, width = true_shape.T
        is_landscape = width >= height
        is_portrait = ~is_landscape

        # linear projection, transposed if necessary
        if is_landscape.any():
            new_landscape_content = self.proj(img[is_landscape])
            new_landscape_content = new_landscape_content.permute(0, 2, 3, 1).flatten(1, 2)
        if is_portrait.any():
            new_portrait_content = self.proj(img[is_portrait].swapaxes(-1, -2))
            new_portrait_content = new_portrait_content.permute(0, 2, 3, 1).flatten(1, 2)

        # allocate space for result and set the content
        x = img.new_empty((B, n_tokens, self.embed_dim), dtype=next(self.named_parameters())[1].dtype)  # dynamically set dtype based on the current precision
        if is_landscape.any():
            x[is_landscape] = new_landscape_content
        if is_portrait.any():
            x[is_portrait] = new_portrait_content

        # allocate space for positions and set the content
        pos = img.new_empty((B, n_tokens, 2), dtype=torch.int64)
        if is_landscape.any():
            pos[is_landscape] = self.position_getter(1, H, W, pos.device).expand(is_landscape.sum(), -1, -1)
        if is_portrait.any():
            pos[is_portrait] = self.position_getter(1, W, H, pos.device).expand(is_portrait.sum(), -1, -1)

        x = self.norm(x)
        return x, pos

```
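A small usage sketch for `ManyAR_PatchEmbed` on a landscape batch, assuming the base `PatchEmbed` and `PositionGetter` in `blocks.py` follow the usual DUSt3R-style conv projection; the output shapes come directly from the forward pass above:

```python
import torch
from trace_anything.layers.patch_embed import ManyAR_PatchEmbed

patch_embed = ManyAR_PatchEmbed(img_size=512, patch_size=16, in_chans=3, embed_dim=768)

img = torch.randn(2, 3, 384, 512)                    # landscape batch (W >= H), divisible by 16
true_shape = torch.tensor([[384, 512], [384, 512]])  # actual (height, width) of each image

tokens, pos = patch_embed(img, true_shape)
# tokens: [2, (384//16) * (512//16), 768] = [2, 768, 768]
# pos:    [2, 768, 2] integer (row, col) token coordinates
```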

## /trace_anything/layers/pos_embed.py

```py path="/trace_anything/layers/pos_embed.py" 
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).


# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------


import numpy as np
import torch


# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if n_cls_token > 0:
        pos_embed = np.concatenate(
            [np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0
        )
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if "pos_embed" in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches**0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print(
                "Position interpolate from %dx%d to %dx%d"
                % (orig_size, orig_size, new_size, new_size)
            )
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(
                -1, orig_size, orig_size, embedding_size
            ).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens,
                size=(new_size, new_size),
                mode="bicubic",
                align_corners=False,
            )
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model["pos_embed"] = new_pos_embed


# ----------------------------------------------------------
# RoPE2D: RoPE implementation in 2D
# ----------------------------------------------------------

try:
    from fast3r.croco.models.curope import cuRoPE2D

    RoPE2D = cuRoPE2D
except ImportError:
    print(
        "Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead"
    )

    class RoPE2D(torch.nn.Module):
        def __init__(self, freq=100.0, F0=1.0):
            super().__init__()
            self.base = freq
            self.F0 = F0
            self.cache = {}

        def get_cos_sin(self, D, seq_len, device, dtype):
            if (D, seq_len, device, dtype) not in self.cache:
                inv_freq = 1.0 / (
                    self.base ** (torch.arange(0, D, 2).float().to(device) / D)
                )
                t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
                freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
                freqs = torch.cat((freqs, freqs), dim=-1)
                cos = freqs.cos()  # (Seq, Dim)
                sin = freqs.sin()
                self.cache[D, seq_len, device, dtype] = (cos, sin)
            return self.cache[D, seq_len, device, dtype]

        @staticmethod
        def rotate_half(x):
            x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)

        def apply_rope1d(self, tokens, pos1d, cos, sin):
            assert pos1d.ndim == 2
            cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
            sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
            return (tokens * cos) + (self.rotate_half(tokens) * sin)

        def forward(self, tokens, positions):
            """
            input:
                * tokens: batch_size x nheads x ntokens x dim
                * positions: batch_size x ntokens x 2 (y and x position of each token)
            output:
                * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
            """
            assert (
                tokens.size(3) % 2 == 0
            ), "number of dimensions should be a multiple of two"
            D = tokens.size(3) // 2
            assert positions.ndim == 3 and positions.shape[-1] == 2  # Batch, Seq, 2
            cos, sin = self.get_cos_sin(
                D, int(positions.max()) + 1, tokens.device, tokens.dtype
            )
            # split features into two along the feature dimension, and apply rope1d on each half
            y, x = tokens.chunk(2, dim=-1)
            y = self.apply_rope1d(y, positions[:, :, 0], cos, sin)
            x = self.apply_rope1d(x, positions[:, :, 1], cos, sin)
            tokens = torch.cat((y, x), dim=-1)
            return tokens

```
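A quick sketch of the two utilities above; `get_2d_sincos_pos_embed` is pure NumPy, while the `RoPE2D` call exercises the slow PyTorch fallback (the CUDA `cuRoPE2D`, when available, is expected to expose the same interface):

```python
import torch
from trace_anything.layers.pos_embed import RoPE2D, get_2d_sincos_pos_embed

# Sine-cosine table for a 14x14 patch grid with 768-dim embeddings
pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14)
print(pos_embed.shape)  # (196, 768), numpy array

# 2D rotary embedding applied to attention tokens
rope = RoPE2D(freq=100.0)
tokens = torch.randn(2, 12, 196, 64)  # [batch, heads, n_tokens, head_dim]
ys, xs = torch.meshgrid(torch.arange(14), torch.arange(14), indexing="ij")
positions = torch.stack([ys, xs], dim=-1).reshape(1, -1, 2).expand(2, -1, -1)  # [2, 196, 2]
rotated = rope(tokens, positions)     # same shape as tokens
```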

## /trace_anything/trace_anything.py

```py path="/trace_anything/trace_anything.py" 
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# trace_anything/trace_anything.py
"""
Minimal inference-only model for this repo.

Assumptions (pruned by asserts):
- encoder_type == 'croco'
- decoder_type == 'transformer'
- head_type == 'dpt'
- targeting_mechanism == 'bspline_conf'
- optional: whether_local (bool)
"""

import math
import time
from copy import deepcopy
from functools import partial
from typing import Dict, List

import torch
import torch.nn as nn
from einops import rearrange
import numpy as np

from .layers.blocks import Block, PositionGetter
from .layers.pos_embed import RoPE2D, get_1d_sincos_pos_embed_from_grid
from .layers.patch_embed import get_patch_embed
from .heads import PixelHead, ScalarHead

from contextlib import contextmanager



# ======== B-spline  ========
PRECOMPUTED_KNOTS = {
    4:  torch.tensor([0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0], dtype=torch.float32),
    7:  torch.tensor([0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0], dtype=torch.float32),
    10: torch.tensor([0.0, 0.0, 0.0, 0.0, 1/3, 1/3, 1/3, 2/3, 2/3, 2/3, 1.0, 1.0, 1.0, 1.0], dtype=torch.float32),
}

def _precompute_knot_differences(n_ctrl_pts, degree, knots):
    denom1 = torch.zeros(n_ctrl_pts, degree + 1, device=knots.device)
    denom2 = torch.zeros(n_ctrl_pts, degree + 1, device=knots.device)
    for k in range(degree + 1):
        for i in range(n_ctrl_pts):
            denom1[i, k] = knots[i + k] - knots[i] if i + k < len(knots) else 0.0
            denom2[i, k] = knots[i + k + 1] - knots[i + 1] if i + k + 1 < len(knots) else 1.0
    return denom1, denom2

PRECOMPUTED_DENOMS = {n: _precompute_knot_differences(n, 3, PRECOMPUTED_KNOTS[n]) for n in [4, 7, 10]}

def _compute_bspline_basis(n_ctrl_pts, degree, t_values, knots, denom1, denom2):
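    # Cox-de Boor recursion: degree-k basis functions are built from the degree-(k-1)
    # ones, evaluated at each t in t_values against the clamped knot vector.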
    N = t_values.size(0)
    basis = torch.zeros(N, n_ctrl_pts, degree + 1, device=t_values.device)
    t = t_values
    basis_k0 = torch.zeros(N, n_ctrl_pts, device=t.device)
    for i in range(n_ctrl_pts):
        if i == n_ctrl_pts - 1:
            basis_k0[:, i] = ((knots[i] <= t) & (t <= knots[i + 1])).float()
        else:
            basis_k0[:, i] = ((knots[i] <= t) & (t < knots[i + 1])).float()
    basis[:, :, 0] = basis_k0
    for k in range(1, degree + 1):
        basis_k = torch.zeros(N, n_ctrl_pts, device=t.device)
        for i in range(n_ctrl_pts):
            term1 = ((t - knots[i]) / denom1[i, k]) * basis[:, i, k-1] if denom1[i, k] > 0 else 0.0
            term2 = ((knots[i + k + 1] - t) / denom2[i, k]) * basis[:, i + 1, k-1] if (denom2[i, k] > 0 and i + 1 < n_ctrl_pts) else 0.0
            basis_k[:, i] = term1 + term2
        basis[:, :, k] = basis_k
    return basis[:, :, degree]

def evaluate_bspline_conf(ctrl_pts3d, ctrl_conf, t_values):
    """ctrl_pts3d:[N_ctrl,H,W,3], ctrl_conf:[N_ctrl,H,W], t_values:[T] -> (T,H,W,3),(T,H,W)"""
    n_ctrl_pts, H, W, _ = ctrl_pts3d.shape
    assert n_ctrl_pts in (4, 7, 10), f"unsupported n_ctrl_pts={n_ctrl_pts}"
    degree = 3
    knot_vector = PRECOMPUTED_KNOTS[n_ctrl_pts].to(ctrl_pts3d.device)
    denom1, denom2 = [d.to(ctrl_pts3d.device) for d in PRECOMPUTED_DENOMS[n_ctrl_pts]]
    ctrl_pts3d = ctrl_pts3d.permute(0, 3, 1, 2)         # [N,3,H,W]
    ctrl_conf  = ctrl_conf.unsqueeze(-1).permute(0, 3, 1, 2)  # [N,1,H,W]
    basis = _compute_bspline_basis(n_ctrl_pts, degree, t_values, knot_vector, denom1, denom2)  # [T,N]
    basis = basis.view(-1, n_ctrl_pts, 1, 1, 1)               # [T,N,1,1,1]
    pts3d_t = torch.sum(basis * ctrl_pts3d.unsqueeze(0), dim=1).permute(0, 2, 3, 1)  # [T,H,W,3]
    conf_t  = torch.sum(basis * ctrl_conf.unsqueeze(0),  dim=1).squeeze(1)           # [T,H,W]
    return pts3d_t, conf_t

# ======== Encoders (CroCo only) ========
class CroCoEncoder(nn.Module):
    def __init__(
        self,
        img_size=512, patch_size=16, patch_embed_cls="ManyAR_PatchEmbed",
        embed_dim=768, num_heads=12, depth=12, mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        pos_embed="RoPE100", attn_implementation="pytorch_naive",
    ):
        super().__init__()
        assert pos_embed.startswith("RoPE"), f"pos_embed must start with RoPE*, got {pos_embed}"
        self.patch_embed = get_patch_embed(patch_embed_cls, img_size, patch_size, embed_dim)
        freq = float(pos_embed[len("RoPE"):])
        self.rope = RoPE2D(freq=freq)
        self.enc_blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                  qkv_bias=True, norm_layer=nn.LayerNorm, rope=self.rope,
                  attn_implementation=attn_implementation)
            for _ in range(depth)
        ])
        self.enc_norm = norm_layer(embed_dim)

    def forward(self, image, true_shape):
        x, pos = self.patch_embed(image, true_shape=true_shape)
        for blk in self.enc_blocks:
            x = blk(x, pos)
        x = self.enc_norm(x)
        return x, pos


# ======== Decoder ========
class TraceDecoder(nn.Module):
    def __init__(
        self,
        random_image_idx_embedding: bool,
        enc_embed_dim: int,
        embed_dim: int = 768,
        num_heads: int = 12,
        depth: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        attn_implementation: str = "pytorch_naive",
        attn_bias_for_inference_enabled=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.decoder_embed = nn.Linear(enc_embed_dim, embed_dim, bias=True)
        self.dec_blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,
                  norm_layer=nn.LayerNorm, attn_implementation=attn_implementation,
                  attn_bias_for_inference_enabled=attn_bias_for_inference_enabled)
            for _ in range(depth)
        ])
        self.random_image_idx_embedding = random_image_idx_embedding
        self.register_buffer(
            "image_idx_emb",
            torch.from_numpy(get_1d_sincos_pos_embed_from_grid(embed_dim, np.arange(1000))).float(),
            persistent=False,
        )
        self.dec_norm = norm_layer(embed_dim)

    def _get_random_image_pos(self, encoded_feats, batch_size, num_views, max_image_idx, device):
        image_ids = torch.zeros(batch_size, num_views, dtype=torch.long)
        image_ids[:, 0] = 0
        per_forward_pass_seed = torch.randint(0, 2 ** 32, (1,)).item()
        per_rank_generator = torch.Generator().manual_seed(per_forward_pass_seed)
        for b in range(batch_size):
            random_ids = torch.randperm(max_image_idx, generator=per_rank_generator)[:num_views - 1] + 1
            image_ids[b, 1:] = random_ids
        image_ids = image_ids.to(device)
        image_pos_list = []
        for i in range(num_views):
            num_patches = encoded_feats[i].shape[1]
            pos_for_view = self.image_idx_emb[image_ids[:, i]].unsqueeze(1).repeat(1, num_patches, 1)
            image_pos_list.append(pos_for_view)
        return torch.cat(image_pos_list, dim=1)

    def forward(self, encoded_feats: List[torch.Tensor], positions: List[torch.Tensor],
                image_ids: torch.Tensor, image_timesteps: torch.Tensor):
        x = torch.cat(encoded_feats, dim=1)
        pos = torch.cat(positions, dim=1)
        outputs = [x]
        x = self.decoder_embed(x)

        if self.random_image_idx_embedding:
            image_pos = self._get_random_image_pos(
                encoded_feats=encoded_feats,
                batch_size=encoded_feats[0].shape[0],
                num_views=len(encoded_feats),
                max_image_idx=self.image_idx_emb.shape[0] - 1,
                device=x.device,
            )
        else:
            num_embeddings = self.image_idx_emb.shape[0]
            indices = (image_timesteps * (num_embeddings - 1)).long().view(-1)
            image_pos = torch.index_select(self.image_idx_emb, dim=0, index=indices)
            image_pos = image_pos.view(1, image_timesteps.shape[1], self.image_idx_emb.shape[1])

        x += image_pos
        for blk in self.dec_blocks:
            x = blk(x, pos)
            outputs.append(x)

        x = self.dec_norm(x)
        outputs[-1] = x
        return outputs, image_pos


# ======== main model ========
class TraceAnything(nn.Module):
    def __init__(self, *, encoder_args: Dict, decoder_args: Dict, head_args: Dict,
                 targeting_mechanism: str = "bspline_conf",
                 poly_degree: int = 10, whether_local: bool = False):
        super().__init__()

        assert targeting_mechanism == "bspline_conf", f"Only bspline_conf is supported now, got {targeting_mechanism}"
        assert encoder_args.get("encoder_type") == "croco"
        assert decoder_args.get("decoder_type", "transformer") in ("transformer", "fast3r")
        assert head_args.get("head_type") == "dpt"

        self.targeting_mechanism = targeting_mechanism
        self.poly_degree = int(poly_degree)
        self.whether_local = bool(whether_local or head_args.get("with_local_head", False))

        # build encoder / decoder
        enc_args = deepcopy(encoder_args); enc_args.pop("encoder_type", None)
        self.encoder = CroCoEncoder(**enc_args)

        dec_args = deepcopy(decoder_args); dec_args.pop("decoder_type", None)
        self.decoder = TraceDecoder(**dec_args)

        # build heads
        feature_dim = 256
        last_dim = feature_dim // 2
        ed = encoder_args["embed_dim"]; dd = decoder_args["embed_dim"]
        hooks = [0, decoder_args["depth"] * 2 // 4, decoder_args["depth"] * 3 // 4, decoder_args["depth"]]
        self.ds_head_time = ScalarHead(
            num_channels=1, feature_dim=feature_dim, last_dim=last_dim,
            hooks_idx=hooks, dim_tokens=[ed, dd, dd, dd], head_type="regression",
            patch_size=head_args["patch_size"],
        )
        out_nchan = (3 + bool(head_args["conf_mode"])) * self.poly_degree
        self.ds_head_track = PixelHead(
            num_channels=out_nchan, feature_dim=feature_dim, last_dim=last_dim,
            hooks_idx=hooks, dim_tokens=[ed, dd, dd, dd], head_type="regression",
            patch_size=head_args["patch_size"],
            depth_mode=head_args["depth_mode"], conf_mode=head_args["conf_mode"],
        )
        if self.whether_local:
            self.ds_head_local = PixelHead(
                num_channels=out_nchan, feature_dim=feature_dim, last_dim=last_dim,
                hooks_idx=hooks, dim_tokens=[ed, dd, dd, dd], head_type="regression",
                patch_size=head_args["patch_size"],
                depth_mode=head_args["depth_mode"], conf_mode=head_args["conf_mode"],
            )

        self.time_head  = self.ds_head_time
        self.track_head = self.ds_head_track
        if self.whether_local:
            self.local_head = self.ds_head_local

        self.max_parallel_views_for_head = 25

    def _encode_images(self, views, chunk_size=400):
        B = views[0]["img"].shape[0]
        same_shape = all(v["img"].shape == views[0]["img"].shape for v in views)
        if same_shape:
            imgs = torch.cat([v["img"] for v in views], dim=0)
            true_shapes = torch.cat([v.get("true_shape", torch.tensor(v["img"].shape[-2:])[None].repeat(B, 1)) for v in views], dim=0)
            feats_chunks, pos_chunks = [], []
            for s in range(0, imgs.shape[0], chunk_size):
                e = min(s + chunk_size, imgs.shape[0])
                f, p = self.encoder(imgs[s:e], true_shapes[s:e])
                feats_chunks.append(f); pos_chunks.append(p)
            feats = torch.cat(feats_chunks, dim=0); pos = torch.cat(pos_chunks, dim=0)
            encoded_feats = torch.split(feats, B, dim=0)
            positions    = torch.split(pos,   B, dim=0)
            shapes       = torch.split(true_shapes, B, dim=0)
        else:
            encoded_feats, positions, shapes = [], [], []
            for v in views:
                img = v["img"]
                true_shape = v.get("true_shape", torch.tensor(img.shape[-2:])[None].repeat(B, 1))
                f, p = self.encoder(img, true_shape)
                encoded_feats.append(f); positions.append(p); shapes.append(true_shape)
        return encoded_feats, positions, shapes

    @torch.no_grad()
    def forward(self, views, profiling: bool = False):
        # 1) encode
        encoded_feats, positions, shapes = self._encode_images(views)

        # 2) build time embedding
        num_images = len(views)
        B, P, D = encoded_feats[0].shape
        image_ids, image_times = [], []
        for i, ef in enumerate(encoded_feats):
            num_patches = ef.shape[1]
            image_ids.extend([i] * num_patches)
            image_times.extend([views[i]["time_step"]] * num_patches)
        image_ids  = torch.tensor(image_ids * B,  device=encoded_feats[0].device).reshape(B, -1)
        image_times = torch.tensor(image_times * B, device=encoded_feats[0].device).reshape(B, -1)

        # 3) decode
        dec_output, _ = self.decoder(encoded_feats, positions, image_ids, image_times)

        # 4) gather outputs per view
        P_patches = P
        gathered_outputs_list = []
        for layer_output in dec_output:
            layer_output = rearrange(layer_output, 'B (n P) D -> (n B) P D', n=num_images, P=P_patches)
            gathered_outputs_list.append(layer_output)

        # 5) heads
        time_step = self.time_head(gathered_outputs_list, torch.stack(shapes))          # [N, 1] per view
        track_tmp = self.track_head(gathered_outputs_list, torch.stack(shapes))        # dict with pts3d/conf after postprocess
        ctrl_pts3d, ctrl_conf = track_tmp['pts3d'], track_tmp['conf']                  # [K,N,H,W,3], [K,N,H,W]

        if self.whether_local:
            local_tmp = self.local_head(gathered_outputs_list, torch.stack(shapes))
            ctrl_pts3d_local, ctrl_conf_local = local_tmp['pts3d'], local_tmp['conf']

        # 6) evaluate bspline track over all reference times
        results = [{} for _ in range(num_images)]
        t_values = torch.stack([time_step[i].squeeze() for i in range(num_images)])    # [N]
        for img_id in range(num_images):
            pts3d_t, conf_t = evaluate_bspline_conf(ctrl_pts3d[:, img_id], ctrl_conf[:, img_id], t_values.detach())
            res = {
                'time': time_step[img_id],
                'ctrl_pts3d': ctrl_pts3d[:, img_id],
                'ctrl_conf':  ctrl_conf[:,  img_id],
                'track_pts3d': [pts3d_t[[t]] for t in range(num_images)],
                'track_conf':  [conf_t[[t]]  for t in range(num_images)],
            }
            if self.whether_local:
                pts3d_t_l, conf_t_l = evaluate_bspline_conf(ctrl_pts3d_local[:, img_id], ctrl_conf_local[:, img_id], t_values.detach())
                res.update({
                    'ctrl_pts3d_local': ctrl_pts3d_local[:, img_id],
                    'ctrl_conf_local':  ctrl_conf_local[:,  img_id],
                    'track_pts3d_local': [pts3d_t_l[[t]] for t in range(num_images)],
                    'track_conf_local':  [conf_t_l[[t]]  for t in range(num_images)],
                })
            results[img_id] = res

        return results


    

```
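The B-spline evaluation can be exercised on its own with random control points; a minimal sketch (shapes follow the `evaluate_bspline_conf` docstring, and `n_ctrl_pts` must be 4, 7, or 10 to match the precomputed knot vectors):

```python
import torch
from trace_anything.trace_anything import evaluate_bspline_conf

n_ctrl, H, W = 7, 32, 32
ctrl_pts3d = torch.rand(n_ctrl, H, W, 3)   # per-pixel 3D control points
ctrl_conf  = torch.rand(n_ctrl, H, W)      # per-pixel confidence control points
t_values   = torch.linspace(0.0, 1.0, 5)   # normalized query times

pts3d_t, conf_t = evaluate_bspline_conf(ctrl_pts3d, ctrl_conf, t_values)
print(pts3d_t.shape, conf_t.shape)  # torch.Size([5, 32, 32, 3]) torch.Size([5, 32, 32])
```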

