```
├── LICENSE.txt (2.3k tokens)
├── LICENSE_MODEL.txt (100 tokens)
├── README.md (1800 tokens)
├── assets/
│   ├── interactive_monkeys.gif
│   ├── panel_screenshot.png
│   └── teaser.png
├── configs/
│   └── eval.yaml (400 tokens)
├── examples/
│   └── input/
│       ├── cabinet/
│       │   ├── 000.png
│       │   ├── 001.png
│       │   ├── 002.png
│       │   ├── 003.png
│       │   ├── 004.png
│       │   ├── 005.png
│       │   ├── 006.png
│       │   ├── 007.png
│       │   ├── 008.png
│       │   ├── 009.png
│       │   └── 010.png
│       ├── corgi/
│       │   ├── 000.png
│       │   └── 001.png
│       ├── dog/
│       │   ├── 000.png
│       │   └── 001.png
│       ├── elephant/
│       │   ├── 000.png
│       │   ├── 001.png
│       │   ├── 002.png
│       │   ├── 003.png
│       │   ├── 004.png
│       │   ├── 005.png
│       │   ├── 006.png
│       │   ├── 007.png
│       │   ├── 008.png
│       │   ├── 009.png
│       │   ├── 010.png
│       │   ├── 011.png
│       │   ├── 012.png
│       │   ├── 013.png
│       │   ├── 014.png
│       │   ├── 015.png
│       │   ├── 016.png
│       │   ├── 017.png
│       │   ├── 018.png
│       │   ├── 019.png
│       │   ├── 020.png
│       │   ├── 021.png
│       │   ├── 022.png
│       │   ├── 023.png
│       │   ├── 024.png
│       │   ├── 025.png
│       │   ├── 026.png
│       │   ├── 027.png
│       │   ├── 028.png
│       │   └── 029.png
│       ├── monkeys/
│       │   ├── 000.png
│       │   ├── 001.png
│       │   ├── 002.png
│       │   ├── 003.png
│       │   ├── 004.png
│       │   ├── 005.png
│       │   ├── 006.png
│       │   ├── 007.png
│       │   ├── 008.png
│       │   ├── 009.png
│       │   ├── 010.png
│       │   ├── 011.png
│       │   ├── 012.png
│       │   ├── 013.png
│       │   ├── 014.png
│       │   ├── 015.png
│       │   ├── 016.png
│       │   ├── 017.png
│       │   ├── 018.png
│       │   ├── 019.png
│       │   ├── 020.png
│       │   ├── 021.png
│       │   ├── 022.png
│       │   ├── 023.png
│       │   ├── 024.png
│       │   ├── 025.png
│       │   ├── 026.png
│       │   ├── 027.png
│       │   ├── 028.png
│       │   ├── 029.png
│       │   └── 030.png
│       └── parkour/
│           ├── 000.png
│           └── 001.png
├── scripts/
│   ├── infer.py (2k tokens)
│   ├── user_mask.py (1700 tokens)
│   └── view.py (5.2k tokens)
└── trace_anything/
    ├── __init__.py (100 tokens)
    ├── heads.py (1500 tokens)
    ├── layers/
    │   ├── __init__.py (100 tokens)
    │   ├── blocks.py (3.3k tokens)
    │   ├── dpt_block.py (3.2k tokens)
    │   ├── patch_embed.py (800 tokens)
    │   └── pos_embed.py (1500 tokens)
    └── trace_anything.py (3.2k tokens)
```
## /LICENSE.txt
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
## /LICENSE_MODEL.txt
Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0)
The model weights, including all .pth, .pt, .bin, .safetensors, or other checkpoint files
released with this repository, are licensed under the Creative Commons
Attribution-NonCommercial 4.0 International License.
You may obtain a copy of this license at:
https://creativecommons.org/licenses/by-nc/4.0/
You are free to share and adapt the model weights for non-commercial purposes,
provided that you give appropriate credit and indicate if changes were made.
Commercial use of the model weights is prohibited.
© 2025 Bytedance Ltd. and/or its affiliates.
## /README.md
# Trace Anything: Representing Any Video in 4D via Trajectory Fields
<p align="center">
<a href="https://trace-anything.github.io/">
<img src="https://img.shields.io/badge/Project%20Page-222222?logo=googlechrome&logoColor=white" alt="Project Page">
</a>
<a href="https://arxiv.org/abs/2510.13802">
<img src="https://img.shields.io/badge/arXiv-b31b1b?logo=arxiv&logoColor=white" alt="arXiv">
</a>
<a href="https://youtu.be/J6y5l2E6qjA">
<img src="https://img.shields.io/badge/YouTube-ea3323?logo=youtube&logoColor=white" alt="YouTube Video">
</a>
<a href="https://trace-anything.github.io/viser-client/interactive.html">
<img src="https://img.shields.io/badge/🖐️ Interactive%20Results-2b7a78?logoColor=white" alt="Interactive Results">
</a>
<a href="https://huggingface.co/depth-anything/trace-anything">
<img src="https://img.shields.io/badge/Model-f4b400?logo=huggingface&logoColor=black" alt="Hugging Face Model">
</a>
</p>
<div align="center" class="is-size-5 publication-authors">
<span class="author-block">
<a href="https://xinhangliu.com/">Xinhang Liu</a><sup>1,2</sup>
</span>
<span class="author-block">
<a href="https://henry123-boy.github.io/">Yuxi Xiao</a><sup>1,3</sup>
</span>
<span class="author-block">
<a href="https://donydchen.github.io/">Donny Y. Chen</a><sup>1</sup>
</span>
<span class="author-block">
<a href="https://scholar.google.com.sg/citations?user=Q8iay0gAAAAJ&hl=en">Jiashi Feng</a><sup>1</sup>
</span>
<br>
<span class="author-block">
<a href="https://yuwingtai.github.io/">Yu-Wing Tai</a><sup>4</sup>
</span>
<span class="author-block">
<a href="https://cse.hkust.edu.hk/~cktang/bio.html">Chi-Keung Tang</a><sup>2</sup>
</span>
<span class="author-block">
<a href="https://bingykang.github.io/">Bingyi Kang</a><sup>1</sup>
</span>
</div>
<br>
<div align="center" class="is-size-5 publication-authors">
<span class="author-block"><sup>1</sup>Bytedance Seed</span>
<span class="author-block"><sup>2</sup>HKUST</span>
<span class="author-block"><sup>3</sup>Zhejiang University</span>
<span class="author-block"><sup>4</sup>Dartmouth College</span>
</div>
## Overview
We propose a 4D video representation, __trajectory field__, which maps each pixel across frames to a continuous, parametric 3D trajectory. With a single forward pass, the __Trace Anything__ model efficiently estimates such trajectory fields for any video, image pair, or unstructured image set.
This repository provides the official PyTorch implementation for running inference with the Trace Anything model and exploring trajectory fields in an interactive 3D viewer.

## Setup
### Create and activate environment
```bash
# Clone the repository
git clone https://github.com/ByteDance-Seed/TraceAnything.git
cd TraceAnything
# Create and activate environment
conda create -n trace_anything python=3.10
conda activate trace_anything
```
### Requirements
* **Python** ≥ 3.10
* **PyTorch** (install according to your CUDA/CPU setup)
* **Dependencies**:
```bash
pip install einops omegaconf pillow opencv-python viser imageio matplotlib torchvision
```
**Notes**
- **CUDA:** Tested with **CUDA 12.8**.
- **GPU Memory:** The provided examples were tested on a **single GPU with ≥ 48 GB of VRAM**.
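Optionally, you can verify the environment before running inference. The snippet below is a minimal sketch (not part of the repository); `scripts/infer.py` falls back to CPU if no GPU is found:
```python
# Quick environment check: PyTorch version, CUDA device, and available VRAM.
import torch

print("torch:", torch.__version__)
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {props.name}, VRAM: {props.total_memory / 1024**3:.1f} GB")
else:
    print("No CUDA device found; inference will run on CPU (slow).")
```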
### Model weights
Download the pretrained **[model](https://huggingface.co/depth-anything/trace-anything/resolve/main/trace_anything.pt?download=true)** and place it at:
```text
checkpoints/trace_anything.pt
```
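Alternatively, you can fetch the checkpoint programmatically. The sketch below is one possible way, assuming `huggingface_hub` is installed; the repo ID and file name come from the link above:
```python
# One possible way to fetch the checkpoint (assumes `pip install huggingface_hub`).
import os, shutil
from huggingface_hub import hf_hub_download

os.makedirs("checkpoints", exist_ok=True)
cached = hf_hub_download(repo_id="depth-anything/trace-anything", filename="trace_anything.pt")
shutil.copy(cached, "checkpoints/trace_anything.pt")
```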
## Inference
We provide example input videos and image pairs under `examples/input`.
Each subdirectory corresponds to a scene:
```
examples/
  input/
    scene_name_1/
      ...
    scene_name_2/
      ...
```
The inference script loads images from these scene folders and produces outputs.
---
### Notes
* Images must satisfy `W ≥ H`; portrait images are automatically rotated to landscape.
* Images are resized so that the long side equals **512**, then cropped so that both dimensions are multiples of 16 (a model requirement); see the sketch below.
* If the number of views exceeds 40, the script automatically subsamples the frames.
* (Advanced) The script assumes input images are ordered in time (e.g., video frames or paired images). Support for unstructured, unordered inputs will be released in the future.
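For reference, this preprocessing is roughly equivalent to the following sketch (simplified from `_load_images` in `scripts/infer.py`; the actual script locks the crop size to the first frame of each scene):
```python
from PIL import Image

def preprocess(pil: Image.Image) -> Image.Image:
    # Rotate portrait inputs so that W >= H.
    if pil.height > pil.width:
        pil = pil.transpose(Image.Transpose.ROTATE_90)
    # Resize so the long side (width) becomes 512.
    w, h = pil.size
    pil = pil.resize((512, int(h * 512 / w)), Image.BILINEAR)
    # Crop both dimensions down to the nearest multiple of 16.
    w, h = pil.size
    return pil.crop((0, 0, w - w % 16, h - h % 16))
```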
---
### Running inference
Run the model over all scenes:
```bash
python scripts/infer.py
```
#### Default arguments
You can override these paths with flags:
* `--config configs/eval.yaml`
* `--ckpt checkpoints/trace_anything.pt`
* `--input_dir examples/input`
* `--output_dir examples/output`
#### Example
```bash
python scripts/infer.py \
--input_dir examples/input \
--output_dir examples/output \
--ckpt checkpoints/trace_anything.pt
```
Results are saved to:
```text
<output_dir>/<scene>/output.pt
```
---
### What’s inside `output.pt`?
* `preds[i]['ctrl_pts3d']` — 3D control points, shape `[K, H, W, 3]`
* `preds[i]['ctrl_conf']` — confidence maps, shape `[K, H, W]`
* `preds[i]['fg_mask']` — binary mask `[H, W]`, computed via Otsu thresholding on control-point variance.
(Mask images are also saved under `<output_dir>/<scene>/masks`.)
* `preds[i]['time']` — predicted scalar time ∈ `[0, 1)`.
> Even though the true timestamp is implicit in the known sequence order, the network’s timestamp head still estimates it.
* `views[i]['img']` — normalized input image tensor ∈ `[-1, 1]`
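For example, the following minimal sketch loads a result and prints the shapes of these fields (the scene path assumes you ran inference on the `monkeys` example):
```python
import torch

out = torch.load("examples/output/monkeys/output.pt", map_location="cpu")
preds, views = out["preds"], out["views"]

print(len(preds), "frames")
print(preds[0]["ctrl_pts3d"].shape)  # [K, H, W, 3] control points
print(preds[0]["ctrl_conf"].shape)   # [K, H, W] confidence
print(preds[0]["fg_mask"].shape)     # [H, W] boolean foreground mask
print(preds[0]["time"])              # predicted scalar time in [0, 1)
print(views[0]["img"].shape)         # [1, 3, H, W], values in [-1, 1]
```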
## Optional: User-Guided Masks with SAM2
If you prefer **user-guided SAM2 masks** to the automatic masks computed from Trace Anything outputs (e.g., for visualization), we provide the helper script [`scripts/user_mask.py`](scripts/user_mask.py). It lets you interactively select points on the first frame of a scene and propagates them into per-frame foreground masks.
Install [SAM2](https://github.com/facebookresearch/sam2), download its checkpoint, and run:
```bash
python scripts/user_mask.py --scene <output_scene_dir> \
--sam2_cfg configs/sam2.1/sam2.1_hiera_l.yaml \
--sam2_ckpt <path_to_sam2_ckpt>
```
This saves masks to:
```
<scene>/masks/{i:03d}_user.png
```
It also updates `<scene>/output.pt` with:
```python
preds[i]["fg_mask_user"]
```
When visualizing, `fg_mask_user` will automatically be preferred over `fg_mask` if available.
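To confirm that the user masks were written, you can reload `output.pt` and check for the new key (a minimal sketch; the scene path is an example):
```python
import torch

out = torch.load("examples/output/monkeys/output.pt", map_location="cpu")
print(all("fg_mask_user" in p for p in out["preds"]))  # True once SAM2 masks are stored
```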
## Interactive Visualization 🚀
Our visualizer lets you explore the trajectory field interactively:

Fire up the interactive 3D viewer and dive into your trajectory fields:
```bash
python scripts/view.py --output examples/output/<scene>/output.pt
```
### Useful flags
* `--port 8020` — set viewer port
* `--t_step 0.025` — timeline step (smaller = more fine-grained curve evaluation)
* `--ds 2` — downsample all data by `::2` for extra speed
### Remote use (SSH port-forwarding)
```bash
ssh -N -L 8020:localhost:8020 <user>@<server>
# Then open http://localhost:8020 locally
```
### Trajectory panel
Input a frame number, or simply type `"mid"` / `"last"`.
Then hit **Build / Refresh** to construct trajectories, and toggle **Show trajectories** to view them.

### Play around! 🎉
* Pump up or shrink point size
* Filter out noisy background / foreground points by confidence
* Drag to swivel the viewpoint
* Slide through time and watch the trajectories evolve
## Acknowledgements
We sincerely thank the authors of the open-source repositories [DUSt3R](https://github.com/naver/dust3r), [Fast3R](https://github.com/facebookresearch/fast3r), [VGGT](https://github.com/facebookresearch/vggt), [MonST3R](https://github.com/Junyi42/monst3r), [Easi3R](https://github.com/Inception3D/Easi3R), [St4RTrack](https://github.com/HavenFeng/St4RTrack), [POMATO](https://github.com/wyddmw/POMATO), [SpaTrackerV2](https://github.com/henry123-boy/SpaTrackerV2), and [Viser](https://github.com/nerfstudio-project/viser) for their inspiring and high-quality work that greatly contributed to this project.
## License
- **Code**: Licensed under the [Apache 2.0 License](http://www.apache.org/licenses/LICENSE-2.0).
- **Model weights**: Licensed under the [CC BY-NC 4.0 License](https://creativecommons.org/licenses/by-nc/4.0/). These weights are provided for research and non-commercial use only.
## Citation
If you find our repository useful, please consider giving it a star ⭐ and citing our paper in your work:
```bibtex
@misc{liu2025traceanythingrepresentingvideo,
title={Trace Anything: Representing Any Video in 4D via Trajectory Fields},
author={Xinhang Liu and Yuxi Xiao and Donny Y. Chen and Jiashi Feng and Yu-Wing Tai and Chi-Keung Tang and Bingyi Kang},
year={2025},
eprint={2510.13802},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2510.13802},
}
```
## /assets/interactive_monkeys.gif
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/assets/interactive_monkeys.gif
## /assets/panel_screenshot.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/assets/panel_screenshot.png
## /assets/teaser.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/assets/teaser.png
## /configs/eval.yaml
```yaml path="/configs/eval.yaml"
# trace_anything/configs/eval.yaml
hydra:
  run:
    dir: .
  job:
    chdir: false
  output_subdir: null
  job_logging:
    disable_existing_loggers: true
    root:
      handlers: []
    handlers: {}
  verbose: false
ckpt_path: checkpoints/trace_anything.ckpt
model:
  _target_: fast3r.models.multiview_dust3r_module.MultiViewDUSt3RLitModule
  net:
    _target_: fast3r.models.fast3r.Fast3R
    encoder_args:
      encoder_type: croco
      img_size: 512
      patch_size: 16
      patch_embed_cls: ManyAR_PatchEmbed
      embed_dim: 1024
      num_heads: 16
      depth: 24
      mlp_ratio: 4
      pos_embed: RoPE100
      attn_implementation: flash_attention
    decoder_args:
      decoder_type: fast3r
      random_image_idx_embedding: false
      enc_embed_dim: 1024
      embed_dim: 1024
      num_heads: 16
      depth: 24
      mlp_ratio: 4.0
      qkv_bias: true
      drop: 0.0
      attn_drop: 0.0
      attn_implementation: flash_attention
    head_args:
      head_type: dpt
      output_mode: pts3d
      landscape_only: true
      depth_mode: [exp, -.inf, .inf]
      conf_mode: [exp, 1, .inf]
      patch_size: 16
      with_local_head: true
    freeze: none
    targeting_mechanism: bspline_conf
    poly_degree: 10
  train_criterion:
    _target_: fast3r.dust3r.losses.ConfLossMultiviewV2
    pixel_loss:
      _target_: fast3r.dust3r.losses.Regr3DMultiviewV3
      criterion:
        _target_: fast3r.dust3r.losses.L21Loss
      norm_mode: avg_dis
    alpha: 0.2
  validation_criterion:
    _target_: fast3r.dust3r.losses.ConfLossMultiviewV2
    pixel_loss:
      _target_: fast3r.dust3r.losses.Regr3DMultiviewV3
      criterion:
        _target_: fast3r.dust3r.losses.L21Loss
      norm_mode: avg_dis
    alpha: 0.2
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 1.0e-4
    betas: [0.9, 0.95]
    weight_decay: 0.05
  scheduler:
    _target_: pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR
    _partial_: true
    warmup_epochs: 1
    max_epochs: 10
    eta_min: 1.0e-6
  compile: false
```
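For reference, `scripts/infer.py` loads this file with OmegaConf and builds the network from `cfg.model.net`. A minimal sketch of reading the same fields:
```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/eval.yaml")
net_cfg = cfg.model.net
encoder_args = OmegaConf.to_container(net_cfg.encoder_args, resolve=True)
print(encoder_args["img_size"], encoder_args["patch_size"])  # 512 16
print(net_cfg.get("targeting_mechanism", "bspline_conf"))    # bspline_conf
print(net_cfg.get("poly_degree", 10))                        # 10
```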
## /examples/input/cabinet/000.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/000.png
## /examples/input/cabinet/001.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/001.png
## /examples/input/cabinet/002.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/002.png
## /examples/input/cabinet/003.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/003.png
## /examples/input/cabinet/004.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/004.png
## /examples/input/cabinet/005.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/005.png
## /examples/input/cabinet/006.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/006.png
## /examples/input/cabinet/007.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/007.png
## /examples/input/cabinet/008.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/008.png
## /examples/input/cabinet/009.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/009.png
## /examples/input/cabinet/010.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/cabinet/010.png
## /examples/input/corgi/000.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/corgi/000.png
## /examples/input/corgi/001.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/corgi/001.png
## /examples/input/dog/000.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/dog/000.png
## /examples/input/dog/001.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/dog/001.png
## /examples/input/elephant/000.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/000.png
## /examples/input/elephant/001.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/001.png
## /examples/input/elephant/002.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/002.png
## /examples/input/elephant/003.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/003.png
## /examples/input/elephant/004.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/004.png
## /examples/input/elephant/005.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/005.png
## /examples/input/elephant/006.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/006.png
## /examples/input/elephant/007.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/007.png
## /examples/input/elephant/008.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/008.png
## /examples/input/elephant/009.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/009.png
## /examples/input/elephant/010.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/010.png
## /examples/input/elephant/011.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/011.png
## /examples/input/elephant/012.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/012.png
## /examples/input/elephant/013.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/013.png
## /examples/input/elephant/014.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/014.png
## /examples/input/elephant/015.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/015.png
## /examples/input/elephant/016.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/016.png
## /examples/input/elephant/017.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/017.png
## /examples/input/elephant/018.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/018.png
## /examples/input/elephant/019.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/019.png
## /examples/input/elephant/020.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/020.png
## /examples/input/elephant/021.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/021.png
## /examples/input/elephant/022.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/022.png
## /examples/input/elephant/023.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/023.png
## /examples/input/elephant/024.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/024.png
## /examples/input/elephant/025.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/025.png
## /examples/input/elephant/026.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/026.png
## /examples/input/elephant/027.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/027.png
## /examples/input/elephant/028.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/028.png
## /examples/input/elephant/029.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/elephant/029.png
## /examples/input/monkeys/000.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/000.png
## /examples/input/monkeys/001.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/001.png
## /examples/input/monkeys/002.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/002.png
## /examples/input/monkeys/003.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/003.png
## /examples/input/monkeys/004.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/004.png
## /examples/input/monkeys/005.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/005.png
## /examples/input/monkeys/006.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/006.png
## /examples/input/monkeys/007.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/007.png
## /examples/input/monkeys/008.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/008.png
## /examples/input/monkeys/009.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/009.png
## /examples/input/monkeys/010.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/010.png
## /examples/input/monkeys/011.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/011.png
## /examples/input/monkeys/012.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/012.png
## /examples/input/monkeys/013.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/013.png
## /examples/input/monkeys/014.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/014.png
## /examples/input/monkeys/015.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/015.png
## /examples/input/monkeys/016.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/016.png
## /examples/input/monkeys/017.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/017.png
## /examples/input/monkeys/018.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/018.png
## /examples/input/monkeys/019.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/019.png
## /examples/input/monkeys/020.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/020.png
## /examples/input/monkeys/021.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/021.png
## /examples/input/monkeys/022.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/022.png
## /examples/input/monkeys/023.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/023.png
## /examples/input/monkeys/024.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/024.png
## /examples/input/monkeys/025.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/025.png
## /examples/input/monkeys/026.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/026.png
## /examples/input/monkeys/027.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/027.png
## /examples/input/monkeys/028.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/028.png
## /examples/input/monkeys/029.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/029.png
## /examples/input/monkeys/030.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/monkeys/030.png
## /examples/input/parkour/000.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/parkour/000.png
## /examples/input/parkour/001.png
Binary file available at https://raw.githubusercontent.com/ByteDance-Seed/TraceAnything/refs/heads/main/examples/input/parkour/001.png
## /scripts/infer.py
```py path="/scripts/infer.py"
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# scripts/infer.py
"""
Run inference on all scenes and save:
- <scene>/output.pt with {'preds','views'}
- <scene>/masks/{i:03d}.png (binary FG masks)
- <scene>/images/{i:03d}.png (RGB frames used for inference)
Masks are computed from ctrl-pt variance + smart Otsu.
"""
import sys, pathlib
sys.path.append(str(pathlib.Path(__file__).resolve().parents[1]))
import os
import cv2
import time
import argparse
from typing import List, Dict
import numpy as np
import torch
from PIL import Image
import torchvision.transforms as tvf
from omegaconf import OmegaConf
from trace_anything.trace_anything import TraceAnything
def _pretty(msg: str) -> None:
print(f"[{time.strftime('%H:%M:%S')}] {msg}")
# allow ${python_eval: ...} in YAML if used
OmegaConf.register_new_resolver("python_eval", lambda code: eval(code))
# ---------------- image I/O ----------------
def _resize_long_side(pil: Image.Image, long: int = 512) -> Image.Image:
w, h = pil.size
if w >= h:
return pil.resize((long, int(h * long / w)), Image.BILINEAR)
else:
return pil.resize((int(w * long / h), long), Image.BILINEAR)
def _load_images(input_dir: str, device: torch.device) -> List[Dict]:
"""Read images, rotate portrait->landscape, resize(long=512), crop to 16-multiple, normalize [-1,1]."""
tfm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5,)*3, (0.5,)*3)])
fnames = sorted(
f for f in os.listdir(input_dir)
if f.lower().endswith((".png", ".jpg", ".jpeg")) and "_vis" not in f
)
if not fnames:
raise FileNotFoundError(f"No images in {input_dir}")
views, target = [], None
for i, f in enumerate(fnames):
arr = cv2.imread(os.path.join(input_dir, f))
if arr is None:
raise FileNotFoundError(f"Failed to read {f}")
pil = Image.fromarray(cv2.cvtColor(arr, cv2.COLOR_BGR2RGB))
W0, H0 = pil.size
if H0 > W0: # portrait -> landscape
pil = pil.transpose(Image.Transpose.ROTATE_90)
pil = _resize_long_side(pil, 512)
if target is None:
H, W = pil.size[1], pil.size[0]
target = (H - H % 16, W - W % 16)
_pretty(f"📐 target size: {target[0]}x{target[1]} (16-multiple)")
Ht, Wt = target
pil = pil.crop((0, 0, Wt, Ht))
tensor = tfm(pil).unsqueeze(0).to(device) # [1,3,H,W]
t = i / (len(fnames) - 1) if len(fnames) > 1 else 0.0
views.append({"img": tensor, "time_step": t})
return views
# ---------------- ckpt + model ----------------
def _get_state_dict(ckpt: dict) -> dict:
"""Accept either a pure state_dict or a Lightning .ckpt."""
if isinstance(ckpt, dict) and "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
return ckpt["state_dict"]
return ckpt
def _load_cfg(cfg_path: str):
if not os.path.isfile(cfg_path):
raise FileNotFoundError(cfg_path)
return OmegaConf.load(cfg_path)
def _to_dict(x):
# OmegaConf -> plain dict
return OmegaConf.to_container(x, resolve=True) if not isinstance(x, dict) else x
def _build_model_from_cfg(cfg, ckpt_path: str, device: torch.device) -> torch.nn.Module:
if not os.path.isfile(ckpt_path):
raise FileNotFoundError(ckpt_path)
# net config
net_cfg = cfg.get("model", {}).get("net", None) or cfg.get("net", None)
if net_cfg is None:
raise KeyError("expect cfg.model.net or cfg.net in YAML")
model = TraceAnything(
encoder_args=_to_dict(net_cfg["encoder_args"]),
decoder_args=_to_dict(net_cfg["decoder_args"]),
head_args=_to_dict(net_cfg["head_args"]),
targeting_mechanism=net_cfg.get("targeting_mechanism", "bspline_conf"),
poly_degree=net_cfg.get("poly_degree", 10),
whether_local=False,
)
ckpt = torch.load(ckpt_path, map_location="cpu")
sd = _get_state_dict(ckpt)
if all(k.startswith("net.") for k in sd.keys()):
sd = {k[4:]: v for k, v in sd.items()}
model.load_state_dict(sd, strict=False)
model.to(device).eval()
return model
# ---------------- smart var threshold ----------------
def _otsu_threshold_from_hist(hist: np.ndarray, bin_edges: np.ndarray) -> float | None:
total = hist.sum()
if total <= 0:
return None
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0
w1 = np.cumsum(hist)
w2 = total - w1
sum_total = (hist * bin_centers).sum()
sumB = np.cumsum(hist * bin_centers)
valid = (w1 > 0) & (w2 > 0)
if not np.any(valid):
return None
m1 = sumB[valid] / w1[valid]
m2 = (sum_total - sumB[valid]) / w2[valid]
between = w1[valid] * w2[valid] * (m1 - m2) ** 2
idx = np.argmax(between)
return float(bin_centers[valid][idx])
def _smart_var_threshold(var_map_t: torch.Tensor) -> float:
"""
1) log-transform variance
2) Otsu on histogram
3) fallback to 65–80% mid-quantile midpoint
Returns threshold in original variance domain.
"""
var_np = var_map_t.detach().float().cpu().numpy()
v = np.log(var_np + 1e-9)
hist, bin_edges = np.histogram(v, bins=256)
thr_log = _otsu_threshold_from_hist(hist, bin_edges)
if thr_log is None or not np.isfinite(thr_log):
q65 = float(np.quantile(var_np, 0.65))
q80 = float(np.quantile(var_np, 0.80))
return 0.5 * (q65 + q80)
thr_var = float(np.exp(thr_log))
q40 = float(np.quantile(var_np, 0.40))
q95 = float(np.quantile(var_np, 0.95))
return max(q40, min(q95, thr_var))
# ---------------- main loop ----------------
def run(args):
base_in = args.input_dir
base_out = args.output_dir
if not os.path.isdir(base_in):
raise FileNotFoundError(base_in)
os.makedirs(base_out, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# config & model
cfg = _load_cfg(args.config)
_pretty("🔧 loading model …")
model = _build_model_from_cfg(cfg, ckpt_path=args.ckpt, device=device)
_pretty("✅ model ready")
# iterate scenes
for scene in sorted(os.listdir(base_in)):
in_dir = os.path.join(base_in, scene)
if not os.path.isdir(in_dir):
continue
out_dir = os.path.join(base_out, scene)
masks_dir = os.path.join(out_dir, "masks")
images_dir = os.path.join(out_dir, "images")
os.makedirs(out_dir, exist_ok=True)
os.makedirs(masks_dir, exist_ok=True)
os.makedirs(images_dir, exist_ok=True)
_pretty(f"\n📂 Scene: {scene}")
_pretty("🖼️ loading images …")
views = _load_images(in_dir, device=device)
if len(views) > 40:
stride = max(1, len(views) // 39) # floor division
views = views[::stride]
_pretty(f"🧮 {len(views)} views loaded")
_pretty("🚀 inference …")
t0 = time.perf_counter()
with torch.no_grad():
preds = model.forward(views)
dt = time.perf_counter() - t0
ms_per_view = (dt / max(1, len(views))) * 1000.0
_pretty(f"✅ done | {dt:.2f}s total | {ms_per_view:.1f} ms/view")
# ---- compute + save FG masks and images ----
_pretty("🧪 computing FG masks + saving frames …")
for i, pred in enumerate(preds):
# variance map over control points (K), mean over xyz -> [H,W]
ctrl_pts3d = pred["ctrl_pts3d"]
ctrl_pts3d_t = torch.from_numpy(ctrl_pts3d) if isinstance(ctrl_pts3d, np.ndarray) else ctrl_pts3d
var_map = torch.var(ctrl_pts3d_t, dim=0, unbiased=False).mean(-1) # [H,W]
thr = _smart_var_threshold(var_map)
fg_mask = (~(var_map <= thr)).detach().cpu().numpy().astype(bool)
# save mask as binary PNG and stash in preds
cv2.imwrite(os.path.join(masks_dir, f"{i:03d}.png"), (fg_mask.astype(np.uint8) * 255))
pred["fg_mask"] = torch.from_numpy(fg_mask) # CPU bool tensor
# also save the RGB image we actually ran on
img = views[i]["img"].detach().cpu().squeeze(0) # [3,H,W] in [-1,1]
img_np = (img.permute(1, 2, 0).numpy() + 1.0) * 127.5
img_uint8 = np.clip(img_np, 0, 255).astype(np.uint8)
cv2.imwrite(os.path.join(images_dir, f"{i:03d}.png"), cv2.cvtColor(img_uint8, cv2.COLOR_RGB2BGR))
# trim heavy intermediates just in case
pred.pop("track_pts3d", None)
pred.pop("track_conf", None)
# persist
save_path = os.path.join(out_dir, "output.pt")
torch.save({"preds": preds, "views": views}, save_path)
_pretty(f"💾 saved: {save_path}")
_pretty(f"🖼️ masks → {masks_dir} | images → {images_dir}")
def parse_args():
p = argparse.ArgumentParser("TraceAnything inference")
p.add_argument("--config", type=str, default="configs/eval.yaml",
help="Path to YAML config")
p.add_argument("--ckpt", type=str, default="checkpoints/trace_anything.pt",
help="Path to the checkpoint")
p.add_argument("--input_dir", type=str, default="./examples/input",
help="Directory containing scenes (each subfolder is a scene)")
p.add_argument("--output_dir", type=str, default="./examples/output",
help="Directory to write scene outputs")
return p.parse_args()
def main():
args = parse_args()
run(args)
if __name__ == "__main__":
main()
```
## /scripts/user_mask.py
```py path="/scripts/user_mask.py"
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python3
# scripts/user_mask.py
"""
Interactive points on frame 0 -> SAM2 video propagation -> user masks.
Saves per-frame masks to <scene>/masks/{i:03d}_user.png
and stores preds[i]["fg_mask_user"] in <scene>/output.pt.
"""
import sys, pathlib; sys.path.append(str(pathlib.Path(__file__).resolve().parents[1]))
import os, argparse, shutil, tempfile, inspect
from typing import Dict, List, Optional
import numpy as np, torch, cv2
def dilate_mask(mask: np.ndarray, ksize: int = 3, iterations: int = 2) -> np.ndarray:
kernel = np.ones((ksize, ksize), np.uint8)
return cv2.dilate(mask.astype(np.uint8), kernel, iterations=iterations).astype(bool)
def _pretty(msg:str):
import time; print(f"[{time.strftime('%H:%M:%S')}] {msg}")
def load_output(scene:str)->Dict:
p=os.path.join(scene,"output.pt");
if not os.path.isfile(p): raise FileNotFoundError(p)
out=torch.load(p,map_location="cpu"); out["_root_dir"]=scene; return out
def tensor_to_rgb(img3hw:torch.Tensor)->np.ndarray:
arr=(img3hw.detach().cpu().permute(1,2,0).numpy()+1.0)*127.5
return np.clip(arr,0,255).astype(np.uint8)
def build_tmp_jpegs(frames_rgb: List[np.ndarray], scene_dir: str) -> str:
tmp = os.path.join(scene_dir, "tmp_jpg")
os.makedirs(tmp, exist_ok=True)
for i, fr in enumerate(frames_rgb):
fn = os.path.join(tmp, f"{i:06d}.jpg")
cv2.imwrite(fn, cv2.cvtColor(fr, cv2.COLOR_RGB2BGR), [int(cv2.IMWRITE_JPEG_QUALITY), 95])
return tmp
def ask_points(frame_bgr:np.ndarray, preview_dir:str)->np.ndarray:
H,W=frame_bgr.shape[:2]
while True:
print("Enter points as: x y (ENTER empty line to finish)")
pts=[]
while True:
s=input(">>> ").strip()
if s=="": break
try:
x,y=map(float,s.split()); pts.append([x,y,1.0])
except: print(" ex: 320 180")
if not pts:
a=input("No points. [q]=quit, [r]=retry > ").strip().lower()
if a.startswith("q"): raise SystemExit(0)
else: continue
pts=np.array(pts,dtype=np.float32)
pts[:,0]=np.clip(pts[:,0],0,W-1); pts[:,1]=np.clip(pts[:,1],0,H-1)
vis=frame_bgr.copy()
for x,y,_ in pts:
p=(int(round(x)),int(round(y)))
cv2.circle(vis,p,4,(0,255,0),-1,cv2.LINE_AA); cv2.circle(vis,p,8,(0,255,0),1,cv2.LINE_AA)
prev_path = os.path.join(preview_dir, "prompts_preview.png")
cv2.imwrite(prev_path, vis)
print(f"Preview: {os.path.abspath(prev_path)}")
if input("Accept? [y]/n/q: ").strip().lower().startswith("y"): return pts
if input("Retry or quit? [r]/q: ").strip().lower().startswith("q"): raise SystemExit(0)
from sam2.build_sam import build_sam2_video_predictor
def run_propagation(model_cfg:str, ckpt_path:str, jpg_dir:str, points_xy1:np.ndarray,
ann_frame_idx:int=0, ann_obj_id:int=1)->List[np.ndarray]:
if not os.path.isfile(ckpt_path): raise FileNotFoundError(ckpt_path)
predictor=build_sam2_video_predictor(model_cfg, ckpt_path)
coords=points_xy1[:, :2][None,...].astype(np.float32); labels=points_xy1[:, 2][None,...].astype(np.int32)
use_cuda=torch.cuda.is_available()
autocast=torch.autocast("cuda",dtype=torch.bfloat16) if use_cuda else torch.cuda.amp.autocast(enabled=False)
with torch.inference_mode(), autocast:
state=predictor.init_state(jpg_dir)
predictor.add_new_points_or_box(inference_state=state, frame_idx=ann_frame_idx,
obj_id=ann_obj_id, points=coords, labels=labels)
n=len([f for f in os.listdir(jpg_dir) if f.lower().endswith(".jpg")])
masks:List[Optional[np.ndarray]]=[None]*n
sig=inspect.signature(predictor.propagate_in_video)
kwargs={"inference_state":state}
if "start_frame_idx" in sig.parameters and "end_frame_idx" in sig.parameters:
kwargs.update(start_frame_idx=ann_frame_idx,end_frame_idx=n-1)
for yielded in predictor.propagate_in_video(**kwargs):
if isinstance(yielded,tuple) and len(yielded)==3:
fi,obj_ids,ms=yielded
else:
fi=int(yielded["frame_idx"]); obj_ids=yielded.get("object_ids") or yielded.get("obj_ids"); ms=yielded["masks"]
pick=None
for oid,m in zip(obj_ids,ms):
if int(oid)==ann_obj_id: pick=m; break
if pick is None:
pick = masks[0]
if isinstance(pick,torch.Tensor): pick=pick.detach().cpu().numpy()
pick=np.asarray(pick);
if pick.ndim==3 and pick.shape[0]==1: pick=pick[0]
masks[int(fi)]=(pick>0.5)
if hasattr(predictor,"get_frame_masks"):
for i in range(n):
if masks[i] is None:
oids,ms=predictor.get_frame_masks(state,i)
pick=None
for oid,m in zip(oids,ms):
if int(oid)==ann_obj_id: pick=m; break
pick=pick or (ms[0] if ms else None)
if isinstance(pick,torch.Tensor): pick=pick.detach().cpu().numpy()
if pick is None: continue
pick=np.asarray(pick)
if pick.ndim==3 and pick.shape[0]==1: pick=pick[0]
masks[i]=(pick>0.5)
H=W=None
for m in masks:
if m is not None: H,W=m.shape; break
if H is None: raise RuntimeError("SAM2 produced no masks.")
return [dilate_mask(m) if m is not None else np.zeros((H, W), bool) for m in masks]
def save_user_masks(scene:str, masks:List[np.ndarray], preds:List[Dict], views:List[Dict]):
mdir=os.path.join(scene,"masks"); os.makedirs(mdir,exist_ok=True)
for i,m in enumerate(masks):
cv2.imwrite(os.path.join(mdir,f"{i:03d}_user.png"), (m.astype(np.uint8)*255))
preds[i]["fg_mask_user"]=torch.from_numpy(m.astype(bool))
torch.save({"preds":preds,"views":views}, os.path.join(scene,"output.pt"))
def parse_args():
p=argparse.ArgumentParser("User mask via SAM2 video propagation")
p.add_argument("--scene", type=str, default="./examples/output/breakdance",
help="Scene dir containing output.pt")
p.add_argument("--sam2_cfg", type=str, default="configs/sam2.1/sam2.1_hiera_l.yaml",
help="SAM2 config (string or path; passed through)")
p.add_argument("--sam2_ckpt", type=str, default="../sam2/checkpoints/sam2.1_hiera_large.pt",
help="Path to SAM2 checkpoint .pt")
return p.parse_args()
def main():
args=parse_args()
out=load_output(args.scene); preds,views=out["preds"],out["views"]
if not preds: return _pretty("No frames in output.pt")
frames=[tensor_to_rgb(views[i]["img"].squeeze(0)) for i in range(len(views))]
frame0_bgr=cv2.cvtColor(frames[0], cv2.COLOR_RGB2BGR)
tmp_dir = build_tmp_jpegs(frames, args.scene)
pts_xy1 = ask_points(frame0_bgr, tmp_dir)
try:
_pretty("Propagating with SAM2 …")
masks = run_propagation(
model_cfg=args.sam2_cfg,
ckpt_path=args.sam2_ckpt,
jpg_dir=tmp_dir,
points_xy1=pts_xy1,
ann_frame_idx=0,
ann_obj_id=1,
)
save_user_masks(args.scene, masks, preds, views)
_pretty(f"✅ Saved user masks to {os.path.join(args.scene, 'masks')} and updated output.pt")
_pretty(f"🗂️ Kept JPEG frames and preview in: {tmp_dir}")
finally:
# shutil.rmtree(tmp,ignore_errors=True)
_pretty("🧹 Not removing temp JPEG folder (preview included)")
if __name__=="__main__": main()
```
## /scripts/view.py
```py path="/scripts/view.py"
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# scripts/view.py
# TraceAnything viewer (viser) – no mask recomputation.
# - Loads FG masks from (priority):
# 1) <scene>/masks/{i:03d}_user.png or preds[i]["fg_mask_user"]
# 2) <scene>/masks/{i:03d}.png or preds[i]["fg_mask"]
# - BG mask = ~FG
# - Saves per-frame images next to output.pt (does NOT recompute masks)
# - Initial conf filtering: drop bottom 10% (FG/BG)
# - Downsample everything with --ds (H,W stride)
import sys, pathlib
sys.path.append(str(pathlib.Path(__file__).resolve().parents[1]))
import os
import time
import threading
import argparse
from typing import Dict, List, Tuple, Optional
import numpy as np
import torch
import cv2
import viser
# ---- tiny helpers ----
def to_numpy(x):
if isinstance(x, torch.Tensor):
return x.detach().cpu().numpy()
return np.asarray(x)
def as_float(x):
a = np.asarray(x)
return float(a.reshape(-1)[0])
def ensure_dir(p):
os.makedirs(p, exist_ok=True)
return p
# ---- tiny HSV colormap (no heavy deps) ----
import colorsys
def hsv_colormap(vals01: np.ndarray) -> np.ndarray:
"""vals01: [T] in [0,1] -> [T,3] RGB in [0,1]"""
vals01 = np.clip(vals01, 0.0, 1.0)
rgb = [colorsys.hsv_to_rgb(v, 1.0, 1.0) for v in vals01]
return np.asarray(rgb, dtype=np.float32)
# --- repo fn ---
from trace_anything.trace_anything import evaluate_bspline_conf
# ----------------------- I/O -----------------------
def load_output_dict(path_or_dir: str) -> Dict:
path = path_or_dir
if os.path.isdir(path):
path = os.path.join(path, "output.pt")
if not os.path.isfile(path):
raise FileNotFoundError(path)
out = torch.load(path, map_location="cpu")
out["_root_dir"] = os.path.dirname(path)
return out
def save_images(frames: List[Dict], images_dir: str):
print(f"[viewer] Saving images to: {images_dir}")
ensure_dir(images_dir)
for i, fr in enumerate(frames):
cv2.imwrite(os.path.join(images_dir, f"{i:03d}.png"),
cv2.cvtColor(fr["img_rgb_uint8"], cv2.COLOR_RGB2BGR))
# ----------------- mask loading (NO recomputation) -----------------
_REPORTED_MASK_SRC: set[tuple[str, int]] = set()
def _load_fg_mask_for_index(root_dir: str, idx: int, pred: Dict) -> np.ndarray:
"""
Priority:
1) masks/{i:03d}_user.png or pred['fg_mask_user'] (user mask)
2) masks/{i:03d}.png or pred['fg_mask'] (raw mask)
Returns: mask_bool_hw
"""
def _ret(mask_bool: np.ndarray, src: str) -> np.ndarray:
key = (root_dir, idx)
if key not in _REPORTED_MASK_SRC and os.environ.get("TA_SILENCE_MASK_SRC") != "1":
print(f"[viewer] frame {idx:03d}: using {src}")
_REPORTED_MASK_SRC.add(key)
return mask_bool
# 1) USER
# PNG then preds (so external edits override stale preds if any)
p_png = os.path.join(root_dir, "masks", f"{idx:03d}_user.png")
if os.path.isfile(p_png):
arr = cv2.imread(p_png, cv2.IMREAD_GRAYSCALE)
if arr is not None:
# p_png_1 = os.path.join(root_dir, "masks", f"{idx:03d}_user_1.png")
# if os.path.isfile(p_png_1):
# arr1 = cv2.imread(p_png_1, cv2.IMREAD_GRAYSCALE)
# if arr1 is not None and arr1.shape == arr.shape:
# arr = np.maximum(arr, arr1)
return _ret((arr > 0), "user mask (png)")
if "fg_mask_user" in pred and pred["fg_mask_user"] is not None:
m = pred["fg_mask_user"]
if isinstance(m, torch.Tensor): m = m.detach().cpu().numpy()
return _ret(np.asarray(m).astype(bool), "user mask (preds)")
# 2) RAW
p_png = os.path.join(root_dir, "masks", f"{idx:03d}.png")
if os.path.isfile(p_png):
arr = cv2.imread(p_png, cv2.IMREAD_GRAYSCALE)
if arr is not None:
return _ret((arr > 0), "raw mask (png)")
if "fg_mask" in pred and pred["fg_mask"] is not None:
m = pred["fg_mask"]
if isinstance(m, torch.Tensor): m = m.detach().cpu().numpy()
return _ret(np.asarray(m).astype(bool), "raw mask (preds)")
# --- legacy compatibility ---
for key, fname, label in [
("fg_mask_user", f"{idx:03d}_fg_user.png", "user mask (legacy png)"),
("fg_mask", f"{idx:03d}_fg_refined.png", "refined mask (legacy png)"),
("fg_mask_raw", f"{idx:03d}_fg_raw.png", "raw mask (legacy png)"),
]:
if key in pred and pred[key] is not None:
m = pred[key]
if isinstance(m, torch.Tensor): m = m.detach().cpu().numpy()
return _ret(np.asarray(m).astype(bool), f"{label} (preds)")
path = os.path.join(root_dir, "masks", fname)
if os.path.isfile(path):
arr = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
if arr is not None:
return _ret((arr > 0), label)
raise FileNotFoundError(
f"No FG mask for frame {idx}: looked for user/raw (png or preds)."
)
# -------------- precompute tensors for viewer -------------
def build_precomputes(
output: Dict,
t_step: float,
ds: int,
) -> Tuple[np.ndarray, List[Dict], List[np.ndarray], np.ndarray, np.ndarray]:
preds = output["preds"]
views = output["views"]
n = len(preds)
assert n == len(views)
root = output.get("_root_dir", os.getcwd())
# timeline
t_vals = np.arange(0.0, 1.0 + 1e-6, t_step, dtype=np.float32)
if t_vals[-1] >= 1.0:
t_vals[-1] = 0.99
T = len(t_vals)
t_tensor = torch.from_numpy(t_vals)
frames: List[Dict] = []
fg_conf_pool_per_t: List[List[np.ndarray]] = [[] for _ in range(T)]
bg_conf_pool: List[np.ndarray] = []
stride = slice(None, None, ds) # ::ds
for i in range(n):
pred = preds[i]
view = views[i]
# image ([-1,1]) -> uint8 RGB
img = to_numpy(view["img"].squeeze().permute(1, 2, 0))
img_uint8 = np.clip((img + 1.0) * 127.5, 0, 255).astype(np.uint8)
img_uint8 = img_uint8[::ds, ::ds] # downsample for saving/vis
H, W = img_uint8.shape[:2]
HW = H * W
img_flat = (img_uint8.astype(np.float32) / 255.0).reshape(HW, 3)
# load FG mask (prefer user, then raw)
fg_mask = _load_fg_mask_for_index(root, i, pred)
# match resolution to downsampled view
# if mask at full-res and we downsampled, stride it; else resize with nearest
if fg_mask.shape == (H * ds, W * ds) and ds > 1:
fg_mask = fg_mask[::ds, ::ds]
elif fg_mask.shape != (H, W):
fg_mask = cv2.resize(
(fg_mask.astype(np.uint8) * 255),
(W, H),
interpolation=cv2.INTER_NEAREST
) > 0
bg_mask = ~fg_mask
bg_mask_flat = bg_mask.reshape(-1)
fg_mask_flat = fg_mask.reshape(-1)
# control points/conf (K,H,W,[3]) at downsampled stride
ctrl_pts3d = pred["ctrl_pts3d"][:, stride, stride, :] # [K,H,W,3]
ctrl_conf = pred["ctrl_conf"][:, stride, stride] # [K,H,W]
# evaluate curve over T timesteps
pts3d_t, conf_t = evaluate_bspline_conf(ctrl_pts3d, ctrl_conf, t_tensor) # [T,H,W,3], [T,H,W]
pts3d_t = to_numpy(pts3d_t).reshape(T, HW, 3)
conf_t = to_numpy(conf_t).reshape(T, HW)
# FG per t (keep per-t list for later filtering)
pts_fg_per_t = [pts3d_t[t][fg_mask_flat] for t in range(T)]
conf_fg_per_t = [conf_t[t][fg_mask_flat] for t in range(T)]
for t in range(T):
if pts_fg_per_t[t].size > 0:
fg_conf_pool_per_t[t].append(conf_fg_per_t[t])
# BG static
bg_pts = pts3d_t.mean(axis=0)[bg_mask_flat]
bg_conf_mean = conf_t.mean(axis=0)[bg_mask_flat]
bg_conf_pool.append(bg_conf_mean)
frames.append(dict(
img_rgb_uint8=img_uint8,
img_rgb_float=img_flat,
H=H, W=W, HW=HW,
bg_mask_flat=bg_mask_flat,
fg_mask_flat=fg_mask_flat,
pts_fg_per_t=pts_fg_per_t,
conf_fg_per_t=conf_fg_per_t,
bg_pts=bg_pts,
bg_conf_mean=bg_conf_mean,
))
# pools for percentiles
fg_conf_all_t: List[np.ndarray] = []
for t in range(T):
if len(fg_conf_pool_per_t[t]) == 0:
fg_conf_all_t.append(np.empty((0,), dtype=np.float32))
else:
fg_conf_all_t.append(np.concatenate(fg_conf_pool_per_t[t], axis=0).astype(np.float32))
if len(bg_conf_pool):
bg_conf_all_flat = np.concatenate(bg_conf_pool, axis=0).astype(np.float32)
else:
bg_conf_all_flat = np.empty((0,), dtype=np.float32)
# frame times (fallback to views' time_step if missing)
def _get_time(i):
ti = preds[i].get("time", None)
if ti is None:
ti = views[i].get("time_step", float(i / max(1, n - 1)))
return as_float(ti)
times = np.array([_get_time(i) for i in range(n)], dtype=np.float64)
return t_vals, frames, fg_conf_all_t, bg_conf_all_flat, times
def choose_nearest_frame_indices(frame_times: np.ndarray, t_vals: np.ndarray) -> np.ndarray:
return np.array([int(np.argmin(np.abs(frame_times - tv))) for tv in t_vals], dtype=np.int64)
# ---------------- nodes & state ----------------
class ViewerState:
def __init__(self):
self.lock = threading.Lock()
self.is_updating = False
self.status_label = None
self.slider_fg = None
self.slider_bg = None
self.slider_time = None
self.point_size = None
self.bg_point_size = None
self.fg_nodes = [] # len T
self.bg_nodes = [] # len N
# trajectories
self.show_traj = None
self.traj_width = None
self.traj_frames_text = None
self.traj_build_btn = None
self.traj_nodes = []
self.playing = False
self.play_btn = None
self.pause_btn = None
self.fps_slider = None
self.loop_checkbox = None
def build_bg_nodes(server: viser.ViserServer, frames: List[Dict], init_percentile: float,
bg_conf_all_flat: np.ndarray, state: ViewerState):
if bg_conf_all_flat.size == 0:
return
thr_val = np.percentile(bg_conf_all_flat, init_percentile)
for i, fr in enumerate(frames):
keep = fr["bg_conf_mean"] >= thr_val
pts = fr["bg_pts"][keep]
cols = fr["img_rgb_float"][fr["bg_mask_flat"]][keep]
node = server.scene.add_point_cloud(
name=f"/bg/frame{i}",
points=pts,
colors=cols,
point_size=state.bg_point_size.value if state.bg_point_size else 0.0002,
point_shape="rounded",
visible=True,
)
state.bg_nodes.append(node)
print(f"[viewer] BG frame={i:02d}: add {pts.shape[0]} pts")
def build_fg_nodes(server: viser.ViserServer, frames: List[Dict], nearest_idx: np.ndarray, t_vals: np.ndarray,
fg_conf_all_t: List[np.ndarray], init_percentile: float, state: ViewerState):
print("\n[viewer] Building FG nodes per timeline step …")
T = len(t_vals)
for t in range(T):
fi = int(nearest_idx[t])
conf_all = fg_conf_all_t[t]
thr_t = np.percentile(conf_all, init_percentile) if conf_all.size > 0 else np.inf
conf = frames[fi]["conf_fg_per_t"][t]
pts = frames[fi]["pts_fg_per_t"][t]
keep = conf >= thr_t
pts_k = pts[keep]
cols_k = frames[fi]["img_rgb_float"][frames[fi]["fg_mask_flat"]][keep]
node = server.scene.add_point_cloud(
name=f"/fg/t{t:02d}",
points=pts_k,
colors=cols_k,
point_size=state.point_size.value if state.point_size else 0.0002,
point_shape="rounded",
visible=(t == 0),
)
state.fg_nodes.append(node)
print(f"[viewer] t={t:02d}: add FG node with {pts_k.shape[0]} pts")
def update_conf_filtering(server: viser.ViserServer, state: ViewerState, frames: List[Dict], nearest_idx: np.ndarray,
t_vals: np.ndarray, fg_conf_all_t: List[np.ndarray], bg_conf_all_flat: np.ndarray,
fg_percentile: float, bg_percentile: float):
with state.lock:
if state.is_updating:
return
state.is_updating = True
try:
if state.status_label:
state.status_label.value = "⚙️ Filtering… please wait"
if state.slider_fg: state.slider_fg.disabled = True
if state.slider_bg: state.slider_bg.disabled = True
# BG
thr_bg = np.percentile(bg_conf_all_flat, bg_percentile) if bg_conf_all_flat.size > 0 else np.inf
print(f"[filter] BG: percentile={bg_percentile:.1f} thr={thr_bg:.6f}")
for i, node in enumerate(state.bg_nodes):
conf = frames[i]["bg_conf_mean"]
keep = conf >= thr_bg
pts = frames[i]["bg_pts"][keep]
cols = frames[i]["img_rgb_float"][frames[i]["bg_mask_flat"]][keep]
node.points = pts
node.colors = cols
print(f" - frame {i:02d}: keep {pts.shape[0]} pts")
server.flush()
# FG
print(f"[filter] FG: percentile={fg_percentile:.1f}")
T = len(t_vals)
for t in range(T):
fi = int(nearest_idx[t])
conf_all = fg_conf_all_t[t]
thr_t = np.percentile(conf_all, fg_percentile) if conf_all.size > 0 else np.inf
conf = frames[fi]["conf_fg_per_t"][t]
pts = frames[fi]["pts_fg_per_t"][t]
keep = conf >= thr_t
pts_k = pts[keep]
cols_k = frames[fi]["img_rgb_float"][frames[fi]["fg_mask_flat"]][keep]
node = state.fg_nodes[t]
node.points = pts_k
node.colors = cols_k
print(f" - t {t:02d}: frame {fi:02d}, keep {pts_k.shape[0]} pts")
if (t % 3) == 0:
server.flush()
server.flush()
finally:
if state.status_label:
state.status_label.value = ""
if state.slider_fg: state.slider_fg.disabled = False
if state.slider_bg: state.slider_bg.disabled = False
with state.lock:
state.is_updating = False
def _update_traj_visibility(state: "ViewerState", server: viser.ViserServer, tidx: int, on: bool):
if not state.traj_nodes:
return
with server.atomic():
for t, nodes in enumerate(state.traj_nodes):
vis = on and (t <= tidx)
for nd in nodes:
nd.visible = vis
server.flush()
# ---------------- trajectories ----------------
def build_traj_nodes(
server: viser.ViserServer,
output: Dict,
frames: List[Dict],
traj_frames: List[int],
t_vals: np.ndarray,
max_points: int,
state: "ViewerState",
):
# remove old
for lst in state.traj_nodes:
for nd in lst:
nd.remove()
state.traj_nodes = [[] for _ in range(len(t_vals) - 1)]
T = len(t_vals)
vals01 = (np.arange(T - 1, dtype=np.float32)) / max(1, T - 2)
seg_colors = hsv_colormap(vals01) # [T-1, 3]
print("[traj] building …")
for fi in traj_frames:
if fi < 0 or fi >= len(frames):
continue
fr = frames[fi]
fg_mask_flat = fr["fg_mask_flat"]
if not np.any(fg_mask_flat):
continue
fg_idx = np.flatnonzero(fg_mask_flat)
if max_points > 0 and fg_idx.size > max_points:
sel = np.random.default_rng(42).choice(fg_idx, size=max_points, replace=False)
inv = {p: j for j, p in enumerate(fg_idx)}
sel_fg = np.array([inv[p] for p in sel], dtype=np.int64)
else:
sel = fg_idx
sel_fg = np.arange(fg_idx.size, dtype=np.int64)
arr = np.stack([fr["pts_fg_per_t"][t] for t in range(T)], axis=0)
if sel_fg.size == 0:
continue
arr = arr[:, sel_fg, :] # [T,N_sel,3]
for t in range(T - 1):
p0 = arr[t]
p1 = arr[t + 1]
if p0.size == 0:
continue
segs = np.stack([p0, p1], axis=1) # [N,2,3]
col = np.repeat(seg_colors[t][None, :], segs.shape[0], axis=0) # [N,3]
node = server.scene.add_line_segments(
name=f"/traj/frame{fi}/t{t:02d}",
points=segs,
colors=np.repeat(col[:, None, :], 2, axis=1),
line_width=state.traj_width.value if state.traj_width else 0.075,
visible=False,
)
state.traj_nodes[t].append(node)
print("[traj] done.")
# ----------------- main viewer -----------------
def serve_view(output: Dict, port: int = 8020, t_step: float = 0.1, ds: int = 2):
server = viser.ViserServer(port=port)
server.gui.set_panel_label("TraceAnything Viewer")
server.gui.configure_theme(control_layout="floating", control_width="medium", show_logo=False)
# restore camera & scene setup
server.scene.set_up_direction((0.0, -1.0, 0.0))
server.scene.world_axes.visible = False
@server.on_client_connect
def _on_connect(client: viser.ClientHandle):
with client.atomic():
client.camera.position = (-0.00141163, -0.01910395, -0.06794288)
client.camera.look_at = (-0.00352821, -0.01143425, 0.01549390)
client.flush()
root = output.get("_root_dir", os.getcwd())
images_dir = ensure_dir(os.path.join(root, "images"))
masks_dir = ensure_dir(os.path.join(root, "masks"))
t_vals, frames, fg_conf_all_t, bg_conf_all_flat, times = build_precomputes(output, t_step, ds)
nearest_idx = choose_nearest_frame_indices(times, t_vals)
# save images only (masks are assumed precomputed)
save_images(frames, images_dir)
print(f"[viewer] Using precomputed FG masks in: {masks_dir} (or preds['fg_mask*'])")
state = ViewerState()
with server.gui.add_folder("Point Size", expand_by_default=True):
state.point_size = server.gui.add_slider("FG Point Size", min=1e-5, max=1e-3, step=1e-4, initial_value=0.0002)
state.bg_point_size = server.gui.add_slider("BG Point Size", min=1e-5, max=1e-3, step=1e-4, initial_value=0.0002)
with server.gui.add_folder("Confidence Filtering", expand_by_default=True):
state.slider_fg = server.gui.add_slider("FG percentile (drop bottom %)", min=0, max=100, step=1, initial_value=10)
state.slider_bg = server.gui.add_slider("BG percentile (drop bottom %)", min=0, max=100, step=1, initial_value=10)
with server.gui.add_folder("Playback", expand_by_default=True):
state.slider_time = server.gui.add_slider("Time", min=0.0, max=1.0, step=t_step, initial_value=0.0)
state.play_btn = server.gui.add_button("▶ Play")
state.pause_btn = server.gui.add_button("⏸ Pause")
state.fps_slider = server.gui.add_slider("FPS", min=1, max=60, step=1, initial_value=10)
state.loop_checkbox = server.gui.add_checkbox("Loop", True)
# ---- Trajectories panel ----
with server.gui.add_folder("Trajectories", expand_by_default=True):
state.show_traj = server.gui.add_checkbox("Show trajectories", False)
state.traj_width = server.gui.add_slider("Line width", min=0.01, max=0.2, step=0.005, initial_value=0.075)
state.traj_frames_text = server.gui.add_text("Frames (e.g. 0,mid,last)", initial_value="0,mid,last")
state.traj_build_btn = server.gui.add_button("Build / Refresh")
state.status_label = server.gui.add_markdown("")
# build nodes
build_bg_nodes(server, frames, init_percentile=state.slider_bg.value, bg_conf_all_flat=bg_conf_all_flat, state=state)
build_fg_nodes(server, frames, nearest_idx, t_vals, fg_conf_all_t, init_percentile=state.slider_fg.value, state=state)
print("\n[viewer] Ready. Open the printed URL and play with sliders!\n")
# --- playback loop thread ---
def _playback_loop():
while True:
if state.playing:
dt = 0.05  # fallback so the sleep below is always defined, even if reading the sliders fails
try:
fps = max(1, int(state.fps_slider.value)) if state.fps_slider else 10
dt = 1.0 / float(fps)
tv = float(state.slider_time.value)
tv_next = tv + t_step
if tv_next > 1.0 - 1e-6:
if state.loop_checkbox and state.loop_checkbox.value:
tv_next = 0.0
else:
tv_next = 1.0 - 1e-6
state.playing = False
state.slider_time.value = tv_next
except Exception:
pass
time.sleep(dt if state.playing else 0.05)
threading.Thread(target=_playback_loop, daemon=True).start()
# callbacks
@state.slider_time.on_update
def _(_):
tv = state.slider_time.value
tidx = int(round(tv / t_step))
tidx = max(0, min(tidx, len(t_vals) - 1))
with server.atomic():
for t in range(len(state.fg_nodes)):
state.fg_nodes[t].visible = (t == tidx)
server.flush()
# NEW: only show trajectories up to current t
_update_traj_visibility(state, server, tidx, on=(state.show_traj and state.show_traj.value))
@state.play_btn.on_click
def _(_):
state.playing = True
@state.pause_btn.on_click
def _(_):
state.playing = False
def _rebuild_filter(_):
update_conf_filtering(
server=server,
state=state,
frames=frames,
nearest_idx=nearest_idx,
t_vals=t_vals,
fg_conf_all_t=fg_conf_all_t,
bg_conf_all_flat=bg_conf_all_flat,
fg_percentile=state.slider_fg.value,
bg_percentile=state.slider_bg.value,
)
@state.slider_fg.on_update
def _(_):
_rebuild_filter(_)
@state.slider_bg.on_update
def _(_):
_rebuild_filter(_)
@state.point_size.on_update
def _(_):
with server.atomic():
for n in state.fg_nodes:
n.point_size = state.point_size.value
server.flush()
@state.bg_point_size.on_update
def _(_):
with server.atomic():
for n in state.bg_nodes:
n.point_size = state.bg_point_size.value
server.flush()
# --- trajectories build/refresh and live controls ---
def _parse_traj_frames():
txt = (state.traj_frames_text.value or "").strip()
if not txt:
return []
tokens = [t.strip() for t in txt.split(",") if t.strip()]
n = len(frames)
out_idx = []
for tk in tokens:
if tk == "mid":
out_idx.append(n // 2)
elif tk == "last":
out_idx.append(n - 1)
else:
try:
out_idx.append(int(tk))
except Exception:
pass
out_idx = [i for i in sorted(set(out_idx)) if 0 <= i < n]
return out_idx
@state.traj_build_btn.on_click
def _(_):
sel = _parse_traj_frames()
if not sel:
return
if state.status_label:
state.status_label.value = "🧵 Building trajectories…"
build_traj_nodes(
server=server,
output=output,
frames=frames,
traj_frames=sel,
t_vals=t_vals,
max_points=10000,
state=state,
)
tidx = int(round(state.slider_time.value / t_step))
tidx = max(0, min(tidx, len(t_vals) - 1))
_update_traj_visibility(state, server, tidx, on=(state.show_traj and state.show_traj.value))
if state.status_label:
state.status_label.value = ""
if state.show_traj.value:
tidx = int(round(state.slider_time.value / t_step))
tidx = max(0, min(tidx, len(t_vals) - 1))
with server.atomic():
for t, nodes in enumerate(state.traj_nodes):
vis = (t <= tidx)
for nd in nodes:
nd.visible = vis
server.flush()
@state.show_traj.on_update
def _(_):
tidx = int(round(state.slider_time.value / t_step))
tidx = max(0, min(tidx, len(t_vals) - 1))
_update_traj_visibility(state, server, tidx, on=state.show_traj.value)
@state.traj_width.on_update
def _(_):
w = state.traj_width.value
with server.atomic():
for nodes in state.traj_nodes:
for nd in nodes:
nd.line_width = w
server.flush()
return server
def parse_args():
p = argparse.ArgumentParser("TraceAnything viewer")
p.add_argument("--output", type=str, default="./examples/output/elephant/output.pt",
help="Path to output.pt or parent directory.")
p.add_argument("--port", type=int, default=8020)
p.add_argument("--t_step", type=float, default=0.025)
p.add_argument("--ds", type=int, default=1, help="downsample stride for H,W (>=1)")
return p.parse_args()
def main():
args = parse_args()
out = load_output_dict(args.output)
server = serve_view(out, port=args.port, t_step=args.t_step, ds=max(1, args.ds))
try:
while True:
time.sleep(3600)
except KeyboardInterrupt:
pass
if __name__ == "__main__":
main()
```
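The viewer is normally launched from the command line (`python scripts/view.py --output ... --port ... --t_step ... --ds ...`, with the defaults shown in `parse_args` above). Below is a minimal programmatic sketch of the same flow, assuming an `output.pt` with precomputed masks already exists under the default example path and that the repository root is on `PYTHONPATH`:
```py
# Sketch only: mirrors main() above; the path and port are the script's own defaults.
import time
from scripts.view import load_output_dict, serve_view

out = load_output_dict("./examples/output/elephant")      # a directory works too; output.pt is appended
server = serve_view(out, port=8020, t_step=0.025, ds=1)   # builds BG/FG nodes, then the server idles
try:
    while True:
        time.sleep(3600)                                   # keep the viser server alive
except KeyboardInterrupt:
    pass
```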
## /trace_anything/__init__.py
```py path="/trace_anything/__init__.py"
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
## /trace_anything/heads.py
```py path="/trace_anything/heads.py"
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# trace_anything/heads.py
from typing import List
import torch
import torch.nn as nn
from einops import rearrange
from .layers.dpt_block import DPTOutputAdapter  # local (vendored) copy
def _resolve_hw(img_info):
"""Normalize `img_info` into the form (int(H), int(W)).
Allowed input formats:
- torch.Tensor of shape (..., 2)
(e.g. [num_views, B, 2] or [B, 2])
- (h, w) tuple/list, where elements can be int or Tensor
Requirement: H/W must be consistent across the whole batch;
otherwise, raise an error
(current inference path does not support mixed-size batches).
"""
if isinstance(img_info, torch.Tensor):
assert img_info.shape[-1] == 2, f"img_info last dim must be 2, got {img_info.shape}"
h0 = img_info.reshape(-1, 2)[0, 0].item()
w0 = img_info.reshape(-1, 2)[0, 1].item()
if (img_info[..., 0] != img_info[..., 0].reshape(-1)[0]).any() or \
(img_info[..., 1] != img_info[..., 1].reshape(-1)[0]).any():
raise AssertionError(f"Mixed H/W in batch not supported: {tuple(img_info.shape)}")
return int(h0), int(w0)
if isinstance(img_info, (list, tuple)) and len(img_info) == 2:
h, w = img_info
if isinstance(h, torch.Tensor): h = int(h.reshape(-1)[0].item())
if isinstance(w, torch.Tensor): w = int(w.reshape(-1)[0].item())
return int(h), int(w)
raise TypeError(f"Unexpected img_info type: {type(img_info)}")
# ---------- postprocess ----------
def reg_dense_depth(xyz, mode):
mode, vmin, vmax = mode
no_bounds = (vmin == -float("inf")) and (vmax == float("inf"))
assert no_bounds
if mode == "linear":
return xyz if no_bounds else xyz.clip(min=vmin, max=vmax)
d = xyz.norm(dim=-1, keepdim=True)
xyz = xyz / d.clip(min=1e-8)
if mode == "square":
return xyz * d.square()
if mode == "exp":
return xyz * torch.expm1(d)
raise ValueError(f"bad {mode=}")
def reg_dense_conf(x, mode):
mode, vmin, vmax = mode
if mode == "exp":
return vmin + x.exp().clip(max=vmax - vmin)
if mode == "sigmoid":
return (vmax - vmin) * torch.sigmoid(x) + vmin
raise ValueError(f"bad {mode=}")
def postprocess(out, depth_mode, conf_mode):
fmap = out.permute(0, 2, 3, 1) # B,H,W,C
res = dict(pts3d=reg_dense_depth(fmap[..., 0:3], mode=depth_mode))
if conf_mode is not None:
res["conf"] = reg_dense_conf(fmap[..., 3], mode=conf_mode)
return res
def postprocess_multi_point(out, depth_mode, conf_mode):
fmap = out.permute(0, 2, 3, 1) # B,H,W,C
B, H, W, C = fmap.shape
n_point = C // 4
pts_3d = fmap[..., :3*n_point].view(B, H, W, 3, n_point).permute(4, 0, 1, 2, 3) # [K,B,H,W,3]
conf = fmap[..., 3*n_point:].view(B, H, W, 1, n_point).squeeze(3).permute(3, 0, 1, 2) # [K,B,H,W]
res = dict(pts3d=reg_dense_depth(pts_3d, mode=depth_mode))
res["conf"] = reg_dense_conf(conf, mode=conf_mode)
return res
class DPTOutputAdapterFix(DPTOutputAdapter):
def init(self, dim_tokens_enc=768):
super().init(dim_tokens_enc)
del self.act_1_postprocess
del self.act_2_postprocess
del self.act_3_postprocess
del self.act_4_postprocess
def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):
assert (
self.dim_tokens_enc is not None
), "Need to call init(dim_tokens_enc) function first"
image_size = self.image_size if image_size is None else image_size
H, W = image_size
H, W = int(H), int(W)
# Number of patches in height and width
N_H = H // (self.stride_level * self.P_H)
N_W = W // (self.stride_level * self.P_W)
layers = [encoder_tokens[h] for h in self.hooks] # 4 x [B,N,C]
layers = [self.adapt_tokens(l) for l in layers] # 4 x [B,N,C]
layers = [rearrange(l, "b (nh nw) c -> b c nh nw", nh=N_H, nw=N_W) for l in layers]
layers = [self.act_postprocess[i](l) for i, l in enumerate(layers)]
layers = [self.scratch.layer_rn[i](l) for i, l in enumerate(layers)]
p4 = self.scratch.refinenet4(layers[3])[:, :, : layers[2].shape[2], : layers[2].shape[3]]
p3 = self.scratch.refinenet3(p4, layers[2])
p2 = self.scratch.refinenet2(p3, layers[1])
p1 = self.scratch.refinenet1(p2, layers[0])
max_chunk = 1 if self.training else 50
outs = []
for ch in torch.split(p1, max_chunk, dim=0):
outs.append(self.head(ch))
return torch.cat(outs, dim=0) # [B,C,H,W]
# ---------- Heads ----------
class ScalarHead(nn.Module):
def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
output_width_ratio=1, num_channels=1, **kwargs):
super().__init__()
assert n_cls_token == 0
dpt_args = dict(output_width_ratio=output_width_ratio, num_channels=num_channels, **kwargs)
if hooks_idx is not None:
dpt_args.update(hooks=hooks_idx)
self.dpt = DPTOutputAdapterFix(**dpt_args)
dpt_init = {} if dim_tokens is None else {"dim_tokens_enc": dim_tokens}
self.dpt.init(**dpt_init)
self.scalar_head = nn.Sequential(
nn.AdaptiveAvgPool2d(1), # B,C,1,1 (C==1)
nn.Flatten(), # B,1
nn.Linear(1, 1),
nn.Sigmoid(),
)
def forward(self, x_list, img_info):
H, W = _resolve_hw(img_info)
out = self.dpt(x_list, image_size=(H, W)) # [B,1,H,W]
return self.scalar_head(out) # [B,1]
class PixelHead(nn.Module):
"""Output per-pixel (3D point + confidence), supports multiple points (K)."""
def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
output_width_ratio=1, num_channels=1,
postprocess=postprocess_multi_point,
depth_mode=None, conf_mode=None, **kwargs):
super().__init__()
assert n_cls_token == 0
self.postprocess = postprocess
self.depth_mode = depth_mode
self.conf_mode = conf_mode
dpt_args = dict(output_width_ratio=output_width_ratio, num_channels=num_channels, **kwargs)
if hooks_idx is not None:
dpt_args.update(hooks=hooks_idx)
self.dpt = DPTOutputAdapterFix(**dpt_args)
dpt_init = {} if dim_tokens is None else {"dim_tokens_enc": dim_tokens}
self.dpt.init(**dpt_init)
def forward(self, x_list, img_info):
H, W = _resolve_hw(img_info)
out = self.dpt(x_list, image_size=(H, W)) # [B,C,H,W]
return self.postprocess(out, self.depth_mode, self.conf_mode)
```
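`PixelHead` emits `4K` channels per pixel (three coordinates plus one confidence for each of the `K` control points), and `postprocess_multi_point` unpacks them. A minimal shape sketch follows; the mode tuples are illustrative assumptions, since the real `depth_mode`/`conf_mode` come from the model configuration:
```py
# Hypothetical sizes and mode tuples, just to show how the channels are unpacked.
import torch
from trace_anything.heads import postprocess_multi_point

B, K, H, W = 2, 4, 64, 64                              # K control points per pixel
out = torch.randn(B, 4 * K, H, W)                      # DPT output: 3K point channels + K conf channels
depth_mode = ("exp", -float("inf"), float("inf"))      # unbounded, as reg_dense_depth asserts
conf_mode = ("exp", 1.0, float("inf"))
res = postprocess_multi_point(out, depth_mode, conf_mode)
print(res["pts3d"].shape)                              # torch.Size([4, 2, 64, 64, 3])  -> [K, B, H, W, 3]
print(res["conf"].shape)                               # torch.Size([4, 2, 64, 64])     -> [K, B, H, W]
```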
## /trace_anything/layers/__init__.py
```py path="/trace_anything/layers/__init__.py"
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
## /trace_anything/layers/blocks.py
```py path="/trace_anything/layers/blocks.py"
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# --------------------------------------------------------
# Main encoder/decoder blocks
# --------------------------------------------------------
# References:
# timm
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py
import collections.abc
from itertools import repeat
import math
import torch
import torch.nn as nn
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return x
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
def drop_path(
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (
x.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def extra_repr(self):
return f"drop_prob={round(self.drop_prob,3):0.3f}"
class Mlp(nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
bias=True,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class Attention(nn.Module):
def __init__(
self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0,
attn_mask=None, is_causal=False, attn_implementation="pytorch_naive",
attn_bias_for_inference_enabled=False,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
# use attention biasing to accommodate longer sequences than seen during training
self.attn_bias_for_inference_enabled = attn_bias_for_inference_enabled
gamma = 1.0
train_seqlen = 20
inference_seqlen = 137
self.attn_bias_scale = head_dim**-0.5 * (gamma * math.log(inference_seqlen) / math.log(train_seqlen))**0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.dropout_p = attn_drop
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.rope = rope
self.attn_mask = attn_mask
self.is_causal = is_causal
self.attn_implementation = attn_implementation
def forward(self, x, xpos):
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
.transpose(1, 3)
)
q, k, v = [qkv[:, :, i] for i in range(3)]
# q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple)
if self.rope is not None:
with torch.autocast(device_type=next(self.parameters()).device.type, dtype=torch.float32): # FIXME: for some reason Lightning didn't pick up torch.cuda.amp.custom_fwd when using bf16-true
q = self.rope(q, xpos) if xpos is not None else q
k = self.rope(k, xpos) if xpos is not None else k
if not self.training and self.attn_bias_for_inference_enabled:
scale = self.attn_bias_scale
else:
scale = self.scale
# Important: For the fusion Transformer, we forward through the attention with bfloat16 precision
# If you are not using this block for the fusion Transformer, you should double check the precision of the input and output
if self.attn_implementation == "pytorch_naive":
assert self.attn_mask is None, "attn_mask not supported for pytorch_naive implementation of scaled dot product attention"
assert self.is_causal is False, "is_causal not supported for pytorch_naive implementation of scaled dot product attention"
dtype = k.dtype
with torch.autocast("cuda", dtype=torch.bfloat16):
x = (q @ k.transpose(-2, -1)) * scale
x = x.softmax(dim=-1)
x = self.attn_drop(x)
if dtype == torch.float32: # if input was FP32, cast back to FP32
x = x.to(torch.float32)
x = (x @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
elif self.attn_implementation == "flash_attention":
with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
dtype = k.dtype
with torch.autocast("cuda", dtype=torch.bfloat16):
x = scaled_dot_product_attention(q, k, v, attn_mask=self.attn_mask, dropout_p=self.dropout_p, is_causal=self.is_causal, scale=scale)
if dtype == torch.float32: # if input was FP32, cast back to FP32
x = x.to(torch.float32)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
elif self.attn_implementation == "pytorch_auto":
with torch.nn.attention.sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION,]):
dtype = k.dtype
with torch.autocast("cuda", dtype=torch.bfloat16):
x = scaled_dot_product_attention(q, k, v, attn_mask=self.attn_mask, dropout_p=self.dropout_p, is_causal=self.is_causal, scale=scale)
if dtype == torch.float32: # if input was FP32, cast back to FP32
x = x.to(torch.float32)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
else:
raise ValueError(f"Unknown attn_implementation: {self.attn_implementation}")
return x
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
rope=None,
attn_implementation="pytorch_naive",
attn_bias_for_inference_enabled=False,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
rope=rope,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
attn_implementation=attn_implementation,
attn_bias_for_inference_enabled=attn_bias_for_inference_enabled,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
def forward(self, x, xpos):
x = x + self.drop_path(self.attn(self.norm1(x), xpos))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class CrossAttention(nn.Module):
def __init__(
self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, attn_mask=None, is_causal=False, attn_implementation="pytorch_naive"
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.projq = nn.Linear(dim, dim, bias=qkv_bias)
self.projk = nn.Linear(dim, dim, bias=qkv_bias)
self.projv = nn.Linear(dim, dim, bias=qkv_bias)
self.dropout_p = attn_drop
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.rope = rope
self.attn_mask = attn_mask
self.is_causal = is_causal
self.attn_implementation = attn_implementation
def forward(self, query, key, value, qpos, kpos):
B, Nq, C = query.shape
Nk = key.shape[1]
Nv = value.shape[1]
q = (
self.projq(query)
.reshape(B, Nq, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
k = (
self.projk(key)
.reshape(B, Nk, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
v = (
self.projv(value)
.reshape(B, Nv, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
)
if self.rope is not None:
with torch.autocast(device_type=next(self.parameters()).device.type, dtype=torch.float32): # FIXME: for some reason Lightning didn't pick up torch.cuda.amp.custom_fwd when using bf16-true
q = self.rope(q, qpos) if qpos is not None else q
k = self.rope(k, kpos) if kpos is not None else k
if self.attn_implementation == "pytorch_naive":
assert self.attn_mask is None, "attn_mask not supported for pytorch_naive implementation of scaled dot product attention"
assert self.is_causal is False, "is_causal not supported for pytorch_naive implementation of scaled dot product attention"
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
x = self.proj(x)
x = self.proj_drop(x)
elif self.attn_implementation == "flash_attention":
with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
# cast to BF16 to use flash_attention
q = q.to(torch.bfloat16)
k = k.to(torch.bfloat16)
v = v.to(torch.bfloat16)
x = scaled_dot_product_attention(q, k, v, attn_mask=self.attn_mask, dropout_p=self.dropout_p, is_causal=self.is_causal, scale=self.scale)
# cast back to FP32
x = x.to(torch.float32)
x = x.transpose(1, 2).reshape(B, Nq, C)
x = self.proj(x)
x = self.proj_drop(x)
else:
raise ValueError(f"Unknown attn_implementation: {self.attn_implementation}")
return x
class DecoderBlock(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
norm_mem=True,
rope=None,
attn_implementation="pytorch_naive",
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
rope=rope,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
attn_implementation=attn_implementation,
)
self.cross_attn = CrossAttention(
dim,
rope=rope,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
attn_implementation=attn_implementation,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.norm3 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
def forward(self, x, y, xpos, ypos):
x = x + self.drop_path(self.attn(self.norm1(x), xpos))
y_ = self.norm_y(y)
x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
x = x + self.drop_path(self.mlp(self.norm3(x)))
return x, y
# patch embedding
class PositionGetter(object):
"""return positions of patches"""
def __init__(self):
self.cache_positions = {}
def __call__(self, b, h, w, device):
if not (h, w) in self.cache_positions:
x = torch.arange(w, device=device)
y = torch.arange(h, device=device)
self.cache_positions[h, w] = torch.cartesian_prod(y, x)  # (h*w, 2)
pos = self.cache_positions[h, w].view(1, h * w, 2).expand(b, -1, 2).clone()
return pos
class PatchEmbed(nn.Module):
"""just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
self.position_getter = PositionGetter()
def forward(self, x):
B, C, H, W = x.shape
torch._assert(
H == self.img_size[0],
f"Input image height ({H}) doesn't match model ({self.img_size[0]}).",
)
torch._assert(
W == self.img_size[1],
f"Input image width ({W}) doesn't match model ({self.img_size[1]}).",
)
x = self.proj(x)
pos = self.position_getter(B, x.size(2), x.size(3), x.device)
if self.flatten:
x = x.flatten(2).transpose(1, 2).contiguous() # BCHW -> BNC
x = self.norm(x)
return x, pos
def _init_weights(self):
w = self.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
```
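A minimal sketch of how these pieces compose: `PositionGetter` supplies per-token `(y, x)` indices and `Block` runs pre-norm self-attention plus an MLP over the token grid. Sizes are illustrative, and `rope=None` here, so the positions are simply passed through:
```py
import torch
from trace_anything.layers.blocks import Block, PositionGetter

B, H, W, D = 2, 8, 8, 768                              # toy token grid
tokens = torch.randn(B, H * W, D)
pos = PositionGetter()(B, H, W, tokens.device)         # [B, H*W, 2] integer (y, x) positions
blk = Block(dim=D, num_heads=12, attn_implementation="pytorch_naive")
out = blk(tokens, pos)                                 # [B, H*W, D]
print(out.shape)
```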
## /trace_anything/layers/dpt_block.py
```py path="/trace_anything/layers/dpt_block.py"
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# --------------------------------------------------------
# DPT head for ViTs
# --------------------------------------------------------
# References:
# https://github.com/isl-org/DPT
# https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/output_adapters.py
from typing import Iterable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
out_shape4 = out_shape
if expand == True:
out_shape1 = out_shape
out_shape2 = out_shape * 2
out_shape3 = out_shape * 4
out_shape4 = out_shape * 8
scratch.layer1_rn = nn.Conv2d(
in_shape[0],
out_shape1,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups,
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1],
out_shape2,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups,
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2],
out_shape3,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups,
)
scratch.layer4_rn = nn.Conv2d(
in_shape[3],
out_shape4,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups,
)
scratch.layer_rn = nn.ModuleList(
[
scratch.layer1_rn,
scratch.layer2_rn,
scratch.layer3_rn,
scratch.layer4_rn,
]
)
return scratch
class ResidualConvUnit_custom(nn.Module):
"""Residual convolution module."""
def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups = 1
self.conv1 = nn.Conv2d(
features,
features,
kernel_size=3,
stride=1,
padding=1,
bias=not self.bn,
groups=self.groups,
)
self.conv2 = nn.Conv2d(
features,
features,
kernel_size=3,
stride=1,
padding=1,
bias=not self.bn,
groups=self.groups,
)
if self.bn == True:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn == True:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn == True:
out = self.bn2(out)
if self.groups > 1:
out = self.conv_merge(out)
return self.skip_add.add(out, x)
class FeatureFusionBlock_custom(nn.Module):
"""Feature fusion block."""
def __init__(
self,
features,
activation,
deconv=False,
bn=False,
expand=False,
align_corners=True,
width_ratio=1,
):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock_custom, self).__init__()
self.width_ratio = width_ratio
self.deconv = deconv
self.align_corners = align_corners
self.groups = 1
self.expand = expand
out_features = features
if self.expand == True:
out_features = features // 2
self.out_conv = nn.Conv2d(
features,
out_features,
kernel_size=1,
stride=1,
padding=0,
bias=True,
groups=1,
)
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, *xs, max_chunk_size=100):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
if self.width_ratio != 1:
res = F.interpolate(
res, size=(output.shape[2], output.shape[3]), mode="bilinear"
)
output = self.skip_add.add(output, res)
# output += res
output = self.resConfUnit2(output)
if self.width_ratio != 1:
# and output.shape[3] < self.width_ratio * output.shape[2]
# size=(image.shape[])
if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio:
shape = 3 * output.shape[3]
else:
shape = int(self.width_ratio * 2 * output.shape[2])
output = F.interpolate(
output, size=(2 * output.shape[2], shape), mode="bilinear"
)
else:
# Split input into chunks to avoid memory issues with large batches
chunks = torch.split(output, max_chunk_size, dim=0)
outputs = []
for chunk in chunks:
out_chunk = nn.functional.interpolate(
chunk,
scale_factor=2,
mode="bilinear",
align_corners=self.align_corners,
)
outputs.append(out_chunk)
# Concatenate outputs along the batch dimension
output = torch.cat(outputs, dim=0)
output = self.out_conv(output)
return output
def make_fusion_block(features, use_bn, width_ratio=1):
return FeatureFusionBlock_custom(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
width_ratio=width_ratio,
)
class Interpolate(nn.Module):
"""Interpolation module."""
def __init__(self, scale_factor, mode, align_corners=False):
"""Init.
Args:
scale_factor (float): scaling
mode (str): interpolation mode
"""
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: interpolated data
"""
x = self.interp(
x,
scale_factor=self.scale_factor,
mode=self.mode,
align_corners=self.align_corners,
)
return x
class DPTOutputAdapter(nn.Module):
"""DPT output adapter.
:param num_channels: Number of output channels
:param stride_level: Stride level compared to the full-sized image.
E.g. 4 for 1/4th the size of the image.
:param patch_size_full: Int or tuple of the patch size over the full image size.
Patch size for smaller inputs will be computed accordingly.
:param hooks: Index of intermediate layers
:param layer_dims: Dimension of intermediate layers
:param feature_dim: Feature dimension
:param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression
:param use_bn: If set to True, activates batch norm
:param dim_tokens_enc: Dimension of tokens coming from encoder
"""
def __init__(
self,
num_channels: int = 1,
stride_level: int = 1,
patch_size: Union[int, Tuple[int, int]] = 16,
main_tasks: Iterable[str] = ("rgb",),
hooks: List[int] = [2, 5, 8, 11],
layer_dims: List[int] = [96, 192, 384, 768],
feature_dim: int = 256,
last_dim: int = 32,
use_bn: bool = False,
dim_tokens_enc: Optional[int] = None,
head_type: str = "regression",
output_width_ratio=1,
**kwargs
):
super().__init__()
self.num_channels = num_channels
self.stride_level = stride_level
self.patch_size = pair(patch_size)
self.main_tasks = main_tasks
self.hooks = hooks
self.layer_dims = layer_dims
self.feature_dim = feature_dim
self.dim_tokens_enc = (
dim_tokens_enc * len(self.main_tasks)
if dim_tokens_enc is not None
else None
)
self.head_type = head_type
# Actual patch height and width, taking into account stride of input
self.P_H = max(1, self.patch_size[0] // stride_level)
self.P_W = max(1, self.patch_size[1] // stride_level)
self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False)
self.scratch.refinenet1 = make_fusion_block(
feature_dim, use_bn, output_width_ratio
)
self.scratch.refinenet2 = make_fusion_block(
feature_dim, use_bn, output_width_ratio
)
self.scratch.refinenet3 = make_fusion_block(
feature_dim, use_bn, output_width_ratio
)
self.scratch.refinenet4 = make_fusion_block(
feature_dim, use_bn, output_width_ratio
)
if self.head_type == "regression":
# The "DPTDepthModel" head
self.head = nn.Sequential(
nn.Conv2d(
feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1
),
# the act_postprocess layers upsample each patch by 8 in total,
# so self.patch_size / 8 calculates how much more we need to upsample
# to get to the full image size (remember that num_patches = image_size / patch_size)
Interpolate(scale_factor=self.patch_size[0] / 8, mode="bilinear", align_corners=True),
nn.Conv2d(
feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1
),
nn.ReLU(True),
nn.Conv2d(
last_dim, self.num_channels, kernel_size=1, stride=1, padding=0
),
)
elif self.head_type == "semseg":
# The "DPTSegmentationModel" head
self.head = nn.Sequential(
nn.Conv2d(
feature_dim, feature_dim, kernel_size=3, padding=1, bias=False
),
nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
nn.ReLU(True),
nn.Dropout(0.1, False),
nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
)
else:
raise ValueError('DPT head_type must be "regression" or "semseg".')
if self.dim_tokens_enc is not None:
self.init(dim_tokens_enc=dim_tokens_enc)
def init(self, dim_tokens_enc=768):
"""
Initialize parts of decoder that are dependent on dimension of encoder tokens.
Should be called when setting up MultiMAE.
:param dim_tokens_enc: Dimension of tokens coming from encoder
"""
# print(dim_tokens_enc)
# Set up activation postprocessing layers
if isinstance(dim_tokens_enc, int):
dim_tokens_enc = 4 * [dim_tokens_enc]
self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc]
self.act_1_postprocess = nn.Sequential(
nn.Conv2d(
in_channels=self.dim_tokens_enc[0],
out_channels=self.layer_dims[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=self.layer_dims[0],
out_channels=self.layer_dims[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
self.act_2_postprocess = nn.Sequential(
nn.Conv2d(
in_channels=self.dim_tokens_enc[1],
out_channels=self.layer_dims[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=self.layer_dims[1],
out_channels=self.layer_dims[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
self.act_3_postprocess = nn.Sequential(
nn.Conv2d(
in_channels=self.dim_tokens_enc[2],
out_channels=self.layer_dims[2],
kernel_size=1,
stride=1,
padding=0,
)
)
self.act_4_postprocess = nn.Sequential(
nn.Conv2d(
in_channels=self.dim_tokens_enc[3],
out_channels=self.layer_dims[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=self.layer_dims[3],
out_channels=self.layer_dims[3],
kernel_size=3,
stride=2,
padding=1,
),
)
self.act_postprocess = nn.ModuleList(
[
self.act_1_postprocess,
self.act_2_postprocess,
self.act_3_postprocess,
self.act_4_postprocess,
]
)
def adapt_tokens(self, encoder_tokens):
# Adapt tokens
x = []
x.append(encoder_tokens[:, :])
x = torch.cat(x, dim=-1)
return x
def forward(self, encoder_tokens: List[torch.Tensor], image_size):
# input_info: Dict):
assert (
self.dim_tokens_enc is not None
), "Need to call init(dim_tokens_enc) function first"
H, W = image_size
# Number of patches in height and width
N_H = H // (self.stride_level * self.P_H)
N_W = W // (self.stride_level * self.P_W)
# Hook decoder onto 4 layers from specified ViT layers
layers = [encoder_tokens[hook] for hook in self.hooks]
# Extract only task-relevant tokens and ignore global tokens.
layers = [self.adapt_tokens(l) for l in layers]
# Reshape tokens to spatial representation
layers = [
rearrange(l, "b (nh nw) c -> b c nh nw", nh=N_H, nw=N_W) for l in layers
]
layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
# Project layers to chosen feature dim
layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
# Fuse layers using refinement stages
path_4 = self.scratch.refinenet4(layers[3])
path_3 = self.scratch.refinenet3(path_4, layers[2])
path_2 = self.scratch.refinenet2(path_3, layers[1])
path_1 = self.scratch.refinenet1(path_2, layers[0])
# Output head
out = self.head(path_1)
return out
```
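A small smoke test of the adapter under its defaults (`hooks=[2, 5, 8, 11]`, `patch_size=16`): dummy ViT tokens go in, and the regression head upsamples the fused feature pyramid back to the input resolution. The sizes below are illustrative only:
```py
import torch
from trace_anything.layers.dpt_block import DPTOutputAdapter

B, H, W, C = 1, 224, 224, 768
n_tokens = (H // 16) * (W // 16)                                    # 14 x 14 = 196 patches
encoder_tokens = [torch.randn(B, n_tokens, C) for _ in range(12)]   # hooks index into this list
dpt = DPTOutputAdapter(num_channels=4, dim_tokens_enc=C)            # passing dim_tokens_enc triggers init()
out = dpt(encoder_tokens, image_size=(H, W))
print(out.shape)                                                    # torch.Size([1, 4, 224, 224])
```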
## /trace_anything/layers/patch_embed.py
```py path="/trace_anything/layers/patch_embed.py"
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# PatchEmbed implementation for DUST3R,
# in particular ManyAR_PatchEmbed, which handles images with a non-square aspect ratio
# --------------------------------------------------------
import torch
from trace_anything.layers.blocks import PatchEmbed
def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim):
assert patch_embed_cls in ["PatchEmbedDust3R", "ManyAR_PatchEmbed"]
patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim)
return patch_embed
class PatchEmbedDust3R(PatchEmbed):
def forward(self, x, **kw):
B, C, H, W = x.shape
assert (
H % self.patch_size[0] == 0
), f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
assert (
W % self.patch_size[1] == 0
), f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
x = self.proj(x)
pos = self.position_getter(B, x.size(2), x.size(3), x.device)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x, pos
class ManyAR_PatchEmbed(PatchEmbed):
"""Handle images with non-square aspect ratio.
All images in the same batch have the same aspect ratio.
true_shape = [(height, width) ...] indicates the actual shape of each image.
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
):
self.embed_dim = embed_dim
super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten)
def forward(self, img, true_shape):
B, C, H, W = img.shape
assert W >= H, f"img should be in landscape mode, but got {W=} {H=}"
assert (
H % self.patch_size[0] == 0
), f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
assert (
W % self.patch_size[1] == 0
), f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
assert true_shape.shape == (
B,
2,
), f"true_shape has the wrong shape={true_shape.shape}"
# size expressed in tokens
W //= self.patch_size[0]
H //= self.patch_size[1]
n_tokens = H * W
height, width = true_shape.T
is_landscape = width >= height
is_portrait = ~is_landscape
# linear projection, transposed if necessary
if is_landscape.any():
new_landscape_content = self.proj(img[is_landscape])
new_landscape_content = new_landscape_content.permute(0, 2, 3, 1).flatten(1, 2)
if is_portrait.any():
new_portrait_content = self.proj(img[is_portrait].swapaxes(-1, -2))
new_portrait_content = new_portrait_content.permute(0, 2, 3, 1).flatten(1, 2)
# allocate space for result and set the content
x = img.new_empty((B, n_tokens, self.embed_dim), dtype=next(self.named_parameters())[1].dtype) # dynamically set dtype based on the current precision
if is_landscape.any():
x[is_landscape] = new_landscape_content
if is_portrait.any():
x[is_portrait] = new_portrait_content
# allocate space for result and set the content
pos = img.new_empty((B, n_tokens, 2), dtype=torch.int64)
if is_landscape.any():
pos[is_landscape] = self.position_getter(1, H, W, pos.device).expand(is_landscape.sum(), -1, -1)
if is_portrait.any():
pos[is_portrait] = self.position_getter(1, W, H, pos.device).expand(is_portrait.sum(), -1, -1)
x = self.norm(x)
return x, pos
```
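A minimal usage sketch for `ManyAR_PatchEmbed` via `get_patch_embed`: the batch itself must be stored in landscape orientation, while `true_shape` records each image's real `(height, width)` so portrait images can be transposed internally. Sizes are illustrative:
```py
import torch
from trace_anything.layers.patch_embed import get_patch_embed

embed = get_patch_embed("ManyAR_PatchEmbed", img_size=224, patch_size=16, enc_embed_dim=768)
img = torch.randn(2, 3, 224, 224)                      # stored landscape (W >= H)
true_shape = torch.tensor([[224, 224], [224, 224]])    # per-image (height, width)
x, pos = embed(img, true_shape)
print(x.shape, pos.shape)                              # [2, 196, 768], [2, 196, 2]
```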
## /trace_anything/layers/pos_embed.py
```py path="/trace_anything/layers/pos_embed.py"
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------
import numpy as np
import torch
# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if n_cls_token > 0:
pos_embed = np.concatenate(
[np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0
)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=float)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
if "pos_embed" in checkpoint_model:
pos_embed_checkpoint = checkpoint_model["pos_embed"]
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches**0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print(
"Position interpolate from %dx%d to %dx%d"
% (orig_size, orig_size, new_size, new_size)
)
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(
-1, orig_size, orig_size, embedding_size
).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens,
size=(new_size, new_size),
mode="bicubic",
align_corners=False,
)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model["pos_embed"] = new_pos_embed
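# Usage sketch (editor's note; `model` is assumed to be a ViT-style module with
# `patch_embed` and `pos_embed` attributes, and the checkpoint path is hypothetical):
#
#   checkpoint = torch.load("ckpt.pth", map_location="cpu")
#   interpolate_pos_embed(model, checkpoint["model"])   # resizes pos tokens in place
#   model.load_state_dict(checkpoint["model"], strict=False)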
# ----------------------------------------------------------
# RoPE2D: RoPE implementation in 2D
# ----------------------------------------------------------
try:
from fast3r.croco.models.curope import cuRoPE2D
RoPE2D = cuRoPE2D
except ImportError:
print(
"Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead"
)
class RoPE2D(torch.nn.Module):
def __init__(self, freq=100.0, F0=1.0):
super().__init__()
self.base = freq
self.F0 = F0
self.cache = {}
def get_cos_sin(self, D, seq_len, device, dtype):
if (D, seq_len, device, dtype) not in self.cache:
inv_freq = 1.0 / (
self.base ** (torch.arange(0, D, 2).float().to(device) / D)
)
t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
freqs = torch.cat((freqs, freqs), dim=-1)
cos = freqs.cos() # (Seq, Dim)
sin = freqs.sin()
self.cache[D, seq_len, device, dtype] = (cos, sin)
return self.cache[D, seq_len, device, dtype]
@staticmethod
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rope1d(self, tokens, pos1d, cos, sin):
assert pos1d.ndim == 2
cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
return (tokens * cos) + (self.rotate_half(tokens) * sin)
def forward(self, tokens, positions):
"""
input:
* tokens: batch_size x nheads x ntokens x dim
* positions: batch_size x ntokens x 2 (y and x position of each token)
output:
* tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
"""
assert (
tokens.size(3) % 2 == 0
), "number of dimensions should be a multiple of two"
D = tokens.size(3) // 2
assert positions.ndim == 3 and positions.shape[-1] == 2 # Batch, Seq, 2
cos, sin = self.get_cos_sin(
D, int(positions.max()) + 1, tokens.device, tokens.dtype
)
# split features into two along the feature dimension, and apply rope1d on each half
y, x = tokens.chunk(2, dim=-1)
y = self.apply_rope1d(y, positions[:, :, 0], cos, sin)
x = self.apply_rope1d(x, positions[:, :, 1], cos, sin)
tokens = torch.cat((y, x), dim=-1)
return tokens
```
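As a quick illustration of the fallback `RoPE2D` above, here is a minimal sketch (editor's own, not from the repository) of rotating per-head token features by their integer (y, x) grid positions; it assumes the pure-PyTorch branch is taken when `cuRoPE2D` is unavailable, and the CUDA variant exposes the same interface:

```py
import torch
from trace_anything.layers.pos_embed import RoPE2D

B, heads, H, W, dim = 1, 12, 4, 4, 64                        # per-head dim must be even
tokens = torch.randn(B, heads, H * W, dim)

# integer (y, x) position of every token on the 4x4 patch grid
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
positions = torch.stack([ys, xs], dim=-1).reshape(1, -1, 2)  # (B, ntokens, 2)

rope = RoPE2D(freq=100.0)
out = rope(tokens, positions)
print(out.shape)                                             # torch.Size([1, 12, 16, 64])
```

The rotation is position-dependent but norm-preserving, so it can be applied directly to queries and keys inside attention without a separate additive positional table.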
## /trace_anything/trace_anything.py
```py path="/trace_anything/trace_anything.py"
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# trace_anything/trace_anything.py
"""
Minimal inference-only model for this repo.
Assumptions (pruned by asserts):
- encoder_type == 'croco'
- decoder_type == 'transformer'
- head_type == 'dpt'
- targeting_mechanism == 'bspline_conf'
- optional: whether_local (bool)
"""
import math
import time
from copy import deepcopy
from functools import partial
from typing import Dict, List
import torch
import torch.nn as nn
from einops import rearrange
import numpy as np
from .layers.blocks import Block, PositionGetter
from .layers.pos_embed import RoPE2D, get_1d_sincos_pos_embed_from_grid
from .layers.patch_embed import get_patch_embed
from .heads import PixelHead, ScalarHead
from contextlib import contextmanager
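# Note (editor's): when all views share a resolution, frames are concatenated and
# encoded in chunks; otherwise each view is encoded separately (see _encode_images).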
# ======== B-spline ========
PRECOMPUTED_KNOTS = {
4: torch.tensor([0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0], dtype=torch.float32),
7: torch.tensor([0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0], dtype=torch.float32),
10: torch.tensor([0.0, 0.0, 0.0, 0.0, 1/3, 1/3, 1/3, 2/3, 2/3, 2/3, 1.0, 1.0, 1.0, 1.0], dtype=torch.float32),
}
def _precompute_knot_differences(n_ctrl_pts, degree, knots):
denom1 = torch.zeros(n_ctrl_pts, degree + 1, device=knots.device)
denom2 = torch.zeros(n_ctrl_pts, degree + 1, device=knots.device)
for k in range(degree + 1):
for i in range(n_ctrl_pts):
denom1[i, k] = knots[i + k] - knots[i] if i + k < len(knots) else 0.0
denom2[i, k] = knots[i + k + 1] - knots[i + 1] if i + k + 1 < len(knots) else 1.0
return denom1, denom2
PRECOMPUTED_DENOMS = {n: _precompute_knot_differences(n, 3, PRECOMPUTED_KNOTS[n]) for n in [4, 7, 10]}
def _compute_bspline_basis(n_ctrl_pts, degree, t_values, knots, denom1, denom2):
N = t_values.size(0)
basis = torch.zeros(N, n_ctrl_pts, degree + 1, device=t_values.device)
t = t_values
basis_k0 = torch.zeros(N, n_ctrl_pts, device=t.device)
for i in range(n_ctrl_pts):
if i == n_ctrl_pts - 1:
basis_k0[:, i] = ((knots[i] <= t) & (t <= knots[i + 1])).float()
else:
basis_k0[:, i] = ((knots[i] <= t) & (t < knots[i + 1])).float()
basis[:, :, 0] = basis_k0
for k in range(1, degree + 1):
basis_k = torch.zeros(N, n_ctrl_pts, device=t.device)
for i in range(n_ctrl_pts):
term1 = ((t - knots[i]) / denom1[i, k]) * basis[:, i, k-1] if denom1[i, k] > 0 else 0.0
term2 = ((knots[i + k + 1] - t) / denom2[i, k]) * basis[:, i + 1, k-1] if (denom2[i, k] > 0 and i + 1 < n_ctrl_pts) else 0.0
basis_k[:, i] = term1 + term2
basis[:, :, k] = basis_k
return basis[:, :, degree]
def evaluate_bspline_conf(ctrl_pts3d, ctrl_conf, t_values):
"""ctrl_pts3d:[N_ctrl,H,W,3], ctrl_conf:[N_ctrl,H,W], t_values:[T] -> (T,H,W,3),(T,H,W)"""
n_ctrl_pts, H, W, _ = ctrl_pts3d.shape
assert n_ctrl_pts in (4, 7, 10), f"unsupported n_ctrl_pts={n_ctrl_pts}"
degree = 3
knot_vector = PRECOMPUTED_KNOTS[n_ctrl_pts].to(ctrl_pts3d.device)
denom1, denom2 = [d.to(ctrl_pts3d.device) for d in PRECOMPUTED_DENOMS[n_ctrl_pts]]
ctrl_pts3d = ctrl_pts3d.permute(0, 3, 1, 2) # [N,3,H,W]
ctrl_conf = ctrl_conf.unsqueeze(-1).permute(0, 3, 1, 2) # [N,1,H,W]
basis = _compute_bspline_basis(n_ctrl_pts, degree, t_values, knot_vector, denom1, denom2) # [T,N]
basis = basis.view(-1, n_ctrl_pts, 1, 1, 1) # [T,N,1,1,1]
pts3d_t = torch.sum(basis * ctrl_pts3d.unsqueeze(0), dim=1).permute(0, 2, 3, 1) # [T,H,W,3]
conf_t = torch.sum(basis * ctrl_conf.unsqueeze(0), dim=1).squeeze(1) # [T,H,W]
return pts3d_t, conf_t
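# Shape sketch (editor's illustration): with 4 control points over a 2x2 map,
# querying T=5 normalized times returns per-time point maps and confidences.
#
#   ctrl_pts3d = torch.randn(4, 2, 2, 3)              # [N_ctrl, H, W, 3]
#   ctrl_conf  = torch.rand(4, 2, 2)                   # [N_ctrl, H, W]
#   t_values   = torch.linspace(0.0, 1.0, steps=5)     # [T]
#   pts3d_t, conf_t = evaluate_bspline_conf(ctrl_pts3d, ctrl_conf, t_values)
#   # pts3d_t: [5, 2, 2, 3], conf_t: [5, 2, 2]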
# ======== Encoders (CroCo only) ========
class CroCoEncoder(nn.Module):
def __init__(
self,
img_size=512, patch_size=16, patch_embed_cls="ManyAR_PatchEmbed",
embed_dim=768, num_heads=12, depth=12, mlp_ratio=4,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
pos_embed="RoPE100", attn_implementation="pytorch_naive",
):
super().__init__()
assert pos_embed.startswith("RoPE"), f"pos_embed must start with RoPE*, got {pos_embed}"
self.patch_embed = get_patch_embed(patch_embed_cls, img_size, patch_size, embed_dim)
freq = float(pos_embed[len("RoPE"):])
self.rope = RoPE2D(freq=freq)
self.enc_blocks = nn.ModuleList([
Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
qkv_bias=True, norm_layer=nn.LayerNorm, rope=self.rope,
attn_implementation=attn_implementation)
for _ in range(depth)
])
self.enc_norm = norm_layer(embed_dim)
def forward(self, image, true_shape):
x, pos = self.patch_embed(image, true_shape=true_shape)
for blk in self.enc_blocks:
x = blk(x, pos)
x = self.enc_norm(x)
return x, pos
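# Shape note (editor's illustration, assuming the default patch_size=16 and
# embed_dim=768): a [B, 3, 224, 224] batch yields x of shape [B, 196, 768] and
# pos of shape [B, 196, 2] (per-token (y, x) patch coordinates).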
# ======== Decoder ========
class TraceDecoder(nn.Module):
def __init__(
self,
random_image_idx_embedding: bool,
enc_embed_dim: int,
embed_dim: int = 768,
num_heads: int = 12,
depth: int = 12,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
attn_implementation: str = "pytorch_naive",
attn_bias_for_inference_enabled=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
self.decoder_embed = nn.Linear(enc_embed_dim, embed_dim, bias=True)
self.dec_blocks = nn.ModuleList([
Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,
norm_layer=nn.LayerNorm, attn_implementation=attn_implementation,
attn_bias_for_inference_enabled=attn_bias_for_inference_enabled)
for _ in range(depth)
])
self.random_image_idx_embedding = random_image_idx_embedding
self.register_buffer(
"image_idx_emb",
torch.from_numpy(get_1d_sincos_pos_embed_from_grid(embed_dim, np.arange(1000))).float(),
persistent=False,
)
self.dec_norm = norm_layer(embed_dim)
def _get_random_image_pos(self, encoded_feats, batch_size, num_views, max_image_idx, device):
image_ids = torch.zeros(batch_size, num_views, dtype=torch.long)
image_ids[:, 0] = 0
per_forward_pass_seed = torch.randint(0, 2 ** 32, (1,)).item()
per_rank_generator = torch.Generator().manual_seed(per_forward_pass_seed)
for b in range(batch_size):
random_ids = torch.randperm(max_image_idx, generator=per_rank_generator)[:num_views - 1] + 1
image_ids[b, 1:] = random_ids
image_ids = image_ids.to(device)
image_pos_list = []
for i in range(num_views):
num_patches = encoded_feats[i].shape[1]
pos_for_view = self.image_idx_emb[image_ids[:, i]].unsqueeze(1).repeat(1, num_patches, 1)
image_pos_list.append(pos_for_view)
return torch.cat(image_pos_list, dim=1)
def forward(self, encoded_feats: List[torch.Tensor], positions: List[torch.Tensor],
image_ids: torch.Tensor, image_timesteps: torch.Tensor):
x = torch.cat(encoded_feats, dim=1)
pos = torch.cat(positions, dim=1)
outputs = [x]
x = self.decoder_embed(x)
if self.random_image_idx_embedding:
image_pos = self._get_random_image_pos(
encoded_feats=encoded_feats,
batch_size=encoded_feats[0].shape[0],
num_views=len(encoded_feats),
max_image_idx=self.image_idx_emb.shape[0] - 1,
device=x.device,
)
else:
num_embeddings = self.image_idx_emb.shape[0]
indices = (image_timesteps * (num_embeddings - 1)).long().view(-1)
image_pos = torch.index_select(self.image_idx_emb, dim=0, index=indices)
image_pos = image_pos.view(1, image_timesteps.shape[1], self.image_idx_emb.shape[1])
x += image_pos
for blk in self.dec_blocks:
x = blk(x, pos)
outputs.append(x)
x = self.dec_norm(x)
outputs[-1] = x
return outputs, image_pos
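# Note (editor's): `outputs` collects the raw encoder tokens followed by every
# decoder block's output (with the final entry layer-normed); the DPT heads
# below index into this list via `hooks_idx`.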
# ======== main model ========
class TraceAnything(nn.Module):
def __init__(self, *, encoder_args: Dict, decoder_args: Dict, head_args: Dict,
targeting_mechanism: str = "bspline_conf",
poly_degree: int = 10, whether_local: bool = False):
super().__init__()
assert targeting_mechanism == "bspline_conf", f"Only bspline_conf is supported now, got {targeting_mechanism}"
assert encoder_args.get("encoder_type") == "croco"
assert decoder_args.get("decoder_type", "transformer") in ("transformer", "fast3r")
assert head_args.get("head_type") == "dpt"
self.targeting_mechanism = targeting_mechanism
self.poly_degree = int(poly_degree)
self.whether_local = bool(whether_local or head_args.get("with_local_head", False))
# build encoder / decoder
enc_args = deepcopy(encoder_args); enc_args.pop("encoder_type", None)
self.encoder = CroCoEncoder(**enc_args)
dec_args = deepcopy(decoder_args); dec_args.pop("decoder_type", None)
self.decoder = TraceDecoder(**dec_args)
# build heads
feature_dim = 256
last_dim = feature_dim // 2
ed = encoder_args["embed_dim"]; dd = decoder_args["embed_dim"]
hooks = [0, decoder_args["depth"] * 2 // 4, decoder_args["depth"] * 3 // 4, decoder_args["depth"]]
self.ds_head_time = ScalarHead(
num_channels=1, feature_dim=feature_dim, last_dim=last_dim,
hooks_idx=hooks, dim_tokens=[ed, dd, dd, dd], head_type="regression",
patch_size=head_args["patch_size"],
)
out_nchan = (3 + bool(head_args["conf_mode"])) * self.poly_degree
self.ds_head_track = PixelHead(
num_channels=out_nchan, feature_dim=feature_dim, last_dim=last_dim,
hooks_idx=hooks, dim_tokens=[ed, dd, dd, dd], head_type="regression",
patch_size=head_args["patch_size"],
depth_mode=head_args["depth_mode"], conf_mode=head_args["conf_mode"],
)
if self.whether_local:
self.ds_head_local = PixelHead(
num_channels=out_nchan, feature_dim=feature_dim, last_dim=last_dim,
hooks_idx=hooks, dim_tokens=[ed, dd, dd, dd], head_type="regression",
patch_size=head_args["patch_size"],
depth_mode=head_args["depth_mode"], conf_mode=head_args["conf_mode"],
)
self.time_head = self.ds_head_time
self.track_head = self.ds_head_track
if self.whether_local:
self.local_head = self.ds_head_local
self.max_parallel_views_for_head = 25
def _encode_images(self, views, chunk_size=400):
B = views[0]["img"].shape[0]
same_shape = all(v["img"].shape == views[0]["img"].shape for v in views)
if same_shape:
imgs = torch.cat([v["img"] for v in views], dim=0)
true_shapes = torch.cat([v.get("true_shape", torch.tensor(v["img"].shape[-2:])[None].repeat(B, 1)) for v in views], dim=0)
feats_chunks, pos_chunks = [], []
for s in range(0, imgs.shape[0], chunk_size):
e = min(s + chunk_size, imgs.shape[0])
f, p = self.encoder(imgs[s:e], true_shapes[s:e])
feats_chunks.append(f); pos_chunks.append(p)
feats = torch.cat(feats_chunks, dim=0); pos = torch.cat(pos_chunks, dim=0)
encoded_feats = torch.split(feats, B, dim=0)
positions = torch.split(pos, B, dim=0)
shapes = torch.split(true_shapes, B, dim=0)
else:
encoded_feats, positions, shapes = [], [], []
for v in views:
img = v["img"]
true_shape = v.get("true_shape", torch.tensor(img.shape[-2:])[None].repeat(B, 1))
f, p = self.encoder(img, true_shape)
encoded_feats.append(f); positions.append(p); shapes.append(true_shape)
return encoded_feats, positions, shapes
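# Note (editor's): the fast path above batches same-resolution frames and encodes
# them in chunks of `chunk_size`; mixed resolutions fall back to per-view encoding.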
@torch.no_grad()
def forward(self, views, profiling: bool = False):
# 1) encode
encoded_feats, positions, shapes = self._encode_images(views)
# 2) build time embedding
num_images = len(views)
B, P, D = encoded_feats[0].shape
image_ids, image_times = [], []
for i, ef in enumerate(encoded_feats):
num_patches = ef.shape[1]
image_ids.extend([i] * num_patches)
image_times.extend([views[i]["time_step"]] * num_patches)
image_ids = torch.tensor(image_ids * B, device=encoded_feats[0].device).reshape(B, -1)
image_times = torch.tensor(image_times * B, device=encoded_feats[0].device).reshape(B, -1)
# 3) decode
dec_output, _ = self.decoder(encoded_feats, positions, image_ids, image_times)
# 4) gather outputs per view
P_patches = P
gathered_outputs_list = []
for layer_output in dec_output:
layer_output = rearrange(layer_output, 'B (n P) D -> (n B) P D', n=num_images, P=P_patches)
gathered_outputs_list.append(layer_output)
# 5) heads
time_step = self.time_head(gathered_outputs_list, torch.stack(shapes)) # [N, 1] per view
track_tmp = self.track_head(gathered_outputs_list, torch.stack(shapes)) # dict with pts3d/conf after postprocess
ctrl_pts3d, ctrl_conf = track_tmp['pts3d'], track_tmp['conf'] # [K,N,H,W,3], [K,N,H,W]
if self.whether_local:
local_tmp = self.local_head(gathered_outputs_list, torch.stack(shapes))
ctrl_pts3d_local, ctrl_conf_local = local_tmp['pts3d'], local_tmp['conf']
# 6) evaluate bspline track over all reference times
results = [{} for _ in range(num_images)]
t_values = torch.stack([time_step[i].squeeze() for i in range(num_images)]) # [N]
for img_id in range(num_images):
pts3d_t, conf_t = evaluate_bspline_conf(ctrl_pts3d[:, img_id], ctrl_conf[:, img_id], t_values.detach())
res = {
'time': time_step[img_id],
'ctrl_pts3d': ctrl_pts3d[:, img_id],
'ctrl_conf': ctrl_conf[:, img_id],
'track_pts3d': [pts3d_t[[t]] for t in range(num_images)],
'track_conf': [conf_t[[t]] for t in range(num_images)],
}
if self.whether_local:
pts3d_t_l, conf_t_l = evaluate_bspline_conf(ctrl_pts3d_local[:, img_id], ctrl_conf_local[:, img_id], t_values.detach())
res.update({
'ctrl_pts3d_local': ctrl_pts3d_local[:, img_id],
'ctrl_conf_local': ctrl_conf_local[:, img_id],
'track_pts3d_local': [pts3d_t_l[[t]] for t in range(num_images)],
'track_conf_local': [conf_t_l[[t]] for t in range(num_images)],
})
results[img_id] = res
return results
```
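To tie the pieces together, here is a hedged sketch of how the `forward()` interface above might be driven. It assumes a `TraceAnything` instance has already been constructed and checkpoint-loaded elsewhere (e.g. by `scripts/infer.py`), that frames are preprocessed to a resolution the patch embedding accepts, and that timestamps are normalized to [0, 1]; the helper name is the editor's own and not part of the repository.

```py
import torch

def run_trace_anything(model, frames):
    """frames: [N, B, 3, H, W] preprocessed frame tensors; returns one result dict per view."""
    n = frames.shape[0]
    views = [
        {
            "img": frames[i],
            "true_shape": torch.tensor(frames[i].shape[-2:])[None].repeat(frames[i].shape[0], 1),
            "time_step": i / max(n - 1, 1),            # normalized timestamp in [0, 1]
        }
        for i in range(n)
    ]
    with torch.no_grad():
        results = model(views)
    # results[k]["ctrl_pts3d"]  -> [N_ctrl, H, W, 3] B-spline control points for view k
    # results[k]["track_pts3d"] -> list of [1, H, W, 3] maps, one per reference time
    return results
```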