```
├── .gitattributes
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
└── detikzify/
    ├── __init__.py
    ├── dataset/
    │   ├── __init__.py
    │   ├── paper2fig/
    │   │   ├── __init__.py
    │   │   └── paper2fig.py
    │   └── scicap/
    │       ├── __init__.py
    │       └── scicap.py
    ├── evaluate/
    │   ├── __init__.py
    │   ├── clipscore.py
    │   ├── crystalbleu.py
    │   ├── dreamsim.py
    │   ├── eed.py
    │   ├── imagesim.py
    │   └── kid.py
    ├── infer/
    │   ├── __init__.py
    │   ├── generate.py
    │   └── tikz.py
    ├── mcts/
    │   ├── LICENSE
    │   ├── README.md
    │   ├── __init__.py
    │   ├── montecarlo.py
    │   └── node.py
    ├── model/
    │   ├── __init__.py
    │   ├── adapter/
    │   │   ├── __init__.py
    │   │   ├── modeling_adapter.py
    │   │   └── processing_adapter.py
    │   ├── configuration_detikzify.py
    │   ├── modeling_detikzify.py
    │   ├── processing_detikzify.py
    │   └── v1/
    │       ├── __init__.py
    │       ├── configuration_detikzify.py
    │       ├── modeling_detikzify.py
    │       └── processing_detikzify.py
    ├── train/
    │   ├── __init__.py
    │   ├── adapter/
    │   │   ├── __init__.py
    │   │   ├── pretrain.py
    │   │   └── train.py
    │   ├── pretrain.py
    │   └── train.py
    ├── util/
    │   ├── __init__.py
    │   ├── functools.py
    │   ├── generation.py
    │   ├── image.py
    │   ├── subprocess.py
    │   ├── torch.py
    │   └── trainer.py
    └── webui/
        ├── README.md
        ├── __init__.py
        ├── __main__.py
        ├── helpers.py
        ├── strings.py
        └── webui.py
```
## /.gitattributes
```gitattributes path="/.gitattributes"
detikzify/mcts/** linguist-vendored
```
## /.gitignore
```gitignore path="/.gitignore"
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Project specific
*/_version.py
```
## /LICENSE
``` path="/LICENSE"
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
```
## /MANIFEST.in
```in path="/MANIFEST.in"
prune examples
```
## /README.md
# DeTi*k*Zify
Synthesizing Graphics Programs for Scientific Figures and Sketches with Ti*k*Z
[OpenReview](https://openreview.net/forum?id=bcVLFQCOjc)
[arXiv](https://arxiv.org/abs/2405.15306)
[Hugging Face](https://huggingface.co/collections/nllg/detikzify-664460c521aa7c2880095a8b)
[Google Colab](https://colab.research.google.com/drive/1hPWqucbPGTavNlYvOBvSNBAwdcPZKe8F)
Creating high-quality scientific figures can be time-consuming and challenging,
even though sketching ideas on paper is relatively easy. Furthermore,
recreating existing figures that are not stored in formats preserving semantic
information is equally complex. To tackle this problem, we introduce
[DeTi*k*Zify](https://github.com/potamides/DeTikZify), a novel multimodal
language model that automatically synthesizes scientific figures as
semantics-preserving [Ti*k*Z](https://github.com/pgf-tikz/pgf) graphics
programs based on sketches and existing figures. We also introduce an
MCTS-based inference algorithm that enables DeTi*k*Zify to iteratively refine
its outputs without the need for additional training.
https://github.com/potamides/DeTikZify/assets/53401822/203d2853-0b5c-4a2b-9d09-3ccb65880cd3
## News
* **2025-03-17**: We release
[Ti*k*Zero](https://huggingface.co/nllg/tikzero-adapter) adapters which plug
directly into [DeTi*k*Zifyv2
(8b)](https://huggingface.co/nllg/detikzify-v2-8b) and enable zero-shot
text-conditioning, and
[Ti*k*Zero+](https://huggingface.co/nllg/tikzero-plus-10b) with additional
end-to-end fine-tuning. For more information see our
[paper](https://arxiv.org/abs/2503.11509) and usage examples [below](#usage).
* **2024-12-05**: We release [DeTi*k*Zifyv2
(8b)](https://huggingface.co/nllg/detikzify-v2-8b), our latest model, which
surpasses all previous versions in our evaluation and is now the default
model in our [Hugging Face
Space](https://huggingface.co/spaces/nllg/DeTikZify). Check out the [model
card](https://huggingface.co/nllg/detikzify-v2-8b-preview#model-card-for-detikzifyv2-8b)
for more information.
* **2024-09-24**: DeTi*k*Zify was accepted at [NeurIPS
2024](https://neurips.cc/Conferences/2024) as a [spotlight
paper](https://neurips.cc/virtual/2024/poster/94474)!
## Installation
> [!TIP]
> If you encounter difficulties with installation and inference on your own
> hardware, consider visiting our [Hugging Face
> Space](https://huggingface.co/spaces/nllg/DeTikZify) (please note that
> restarting the space can take up to 30 minutes). Should you experience long
> queues, you have the option to
> [duplicate](https://huggingface.co/spaces/nllg/DeTikZify?duplicate=true) it
> with a paid private GPU runtime for a more seamless experience. Additionally,
> you can try our demo on [Google
> Colab](https://colab.research.google.com/drive/1hPWqucbPGTavNlYvOBvSNBAwdcPZKe8F).
> However, setting up the environment there might take some time, and the free
> tier only supports inference for the 1b models.
The Python package of DeTi*k*Zify can be easily installed using
[pip](https://pip.pypa.io/en/stable):
```sh
pip install 'detikzify[legacy] @ git+https://github.com/potamides/DeTikZify'
```
The `[legacy]` extra is only required if you plan to use the
DeTi*k*Zifyv1 models. If you only plan to use
DeTi*k*Zifyv2, you can remove it. If your goal is to run the included
[examples](examples), it is easier to clone the repository and install it in
editable mode like this:
```sh
git clone https://github.com/potamides/DeTikZify
pip install -e DeTikZify[examples]
```
In addition, DeTi*k*Zify requires a full
[TeX Live 2023](https://www.tug.org/texlive) installation,
[ghostscript](https://www.ghostscript.com), and
[poppler](https://poppler.freedesktop.org), which you have to install through
your package manager or via other means.
## Usage
> [!TIP]
> For interactive use and general [usage tips](detikzify/webui#usage-tips),
> we recommend checking out our [web UI](detikzify/webui), which can be started
> directly from the command line (use `--help` for a list of all options):
> ```sh
> python -m detikzify.webui --light
> ```
If all required dependencies are installed, the full range of DeTi*k*Zify
features, such as compiling, rendering, and saving Ti*k*Z graphics, as well as
MCTS-based inference, can be accessed through its programming interface:
**DeTikZify Example**
```python
from operator import itemgetter
from detikzify.model import load
from detikzify.infer import DetikzifyPipeline
image = "https://w.wiki/A7Cc"
pipeline = DetikzifyPipeline(*load(
model_name_or_path="nllg/detikzify-v2-8b",
device_map="auto",
torch_dtype="bfloat16",
))
# generate a single TikZ program
fig = pipeline.sample(image=image)
# if it compiles, rasterize it and show it
if fig.is_rasterizable:
fig.rasterize().show()
# run MCTS for 10 minutes and generate multiple TikZ programs
figs = set()
for score, fig in pipeline.simulate(image=image, timeout=600):
figs.add((score, fig))
# save the best TikZ program
best = sorted(figs, key=itemgetter(0))[-1][1]
best.save("fig.tex")
```
Through [Ti*k*Zero](https://huggingface.co/nllg/tikzero-adapter) adapters and
[Ti*k*Zero+](https://huggingface.co/nllg/tikzero-plus-10b) it is also possible
to synthesize graphics programs conditioned on text (cf. our
[paper](https://arxiv.org/abs/2503.11509) for
details). Note that this is currently only supported through the programming
interface:
**TikZero+ Example**
```python
from detikzify.model import load
from detikzify.infer import DetikzifyPipeline
caption = "A multi-layer perceptron with two hidden layers."
pipeline = DetikzifyPipeline(*load(
model_name_or_path="nllg/tikzero-plus-10b",
device_map="auto",
torch_dtype="bfloat16",
))
# generate a single TikZ program
fig = pipeline.sample(text=caption)
# if it compiles, rasterize it and show it
if fig.is_rasterizable:
fig.rasterize().show()
```
**TikZero Example**
```python
from detikzify.model import load, load_adapter
from detikzify.infer import DetikzifyPipeline
caption = "A multi-layer perceptron with two hidden layers."
pipeline = DetikzifyPipeline(
*load_adapter(
*load(
model_name_or_path="nllg/detikzify-v2-8b",
device_map="auto",
torch_dtype="bfloat16",
),
adapter_name_or_path="nllg/tikzero-adapter",
)
)
# generate a single TikZ program
fig = pipeline.sample(text=caption)
# if it compiles, rasterize it and show it
if fig.is_rasterizable:
fig.rasterize().show()
```
More involved examples, such as those for evaluation and training, can be found
in the [examples](examples) folder.
## Model Weights & Datasets
We upload all our DeTi*k*Zify models and datasets to the [Hugging Face
Hub](https://huggingface.co/collections/nllg/detikzify-664460c521aa7c2880095a8b)
(Ti*k*Zero models are available
[here](https://huggingface.co/collections/nllg/tikzero-67d1952fab69f5bd172de1fe)).
However, please note that for the public release of the
[DaTi*k*Zv2](https://huggingface.co/datasets/nllg/datikz-v2)
and [DaTi*k*Zv3](https://huggingface.co/datasets/nllg/datikz-v3)
datasets, we had to remove a considerable portion of Ti*k*Z drawings
originating from [arXiv](https://arxiv.org), as the [arXiv non-exclusive
license](https://arxiv.org/licenses/nonexclusive-distrib/1.0/license.html) does
not permit redistribution. We do, however, release our [dataset creation
scripts](https://github.com/potamides/DaTikZ) and encourage anyone to recreate
the full version of DaTi*k*Z themselves.
## Citation
If DeTi*k*Zify and Ti*k*Zero have been beneficial for your research or
applications, we kindly request you to acknowledge this by citing them as
follows:
```bibtex
@inproceedings{belouadi2024detikzify,
title={{DeTikZify}: Synthesizing Graphics Programs for Scientific Figures and Sketches with {TikZ}},
author={Jonas Belouadi and Simone Paolo Ponzetto and Steffen Eger},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year={2024},
url={https://openreview.net/forum?id=bcVLFQCOjc}
}
```
```bibtex
@misc{belouadi2025tikzero,
title={{TikZero}: Zero-Shot Text-Guided Graphics Program Synthesis},
author={Jonas Belouadi and Eddy Ilg and Margret Keuper and Hideki Tanaka and Masao Utiyama and Raj Dabre and Steffen Eger and Simone Paolo Ponzetto},
year={2025},
eprint={2503.11509},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2503.11509},
}
```
## Acknowledgments
The implementation of the DeTi*k*Zify model architecture is based on
[LLaVA](https://github.com/haotian-liu/LLaVA) and
[AutomaTikZ](https://github.com/potamides/AutomaTikZ) (v1), and [Idefics
3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (v2). Our MCTS
implementation is based on
[VerMCTS](https://github.com/namin/llm-verified-with-monte-carlo-tree-search).
The Ti*k*Zero architecture draws inspiration from
[Flamingo](https://deepmind.google/discover/blog/tackling-multiple-tasks-with-a-single-visual-language-model/)
and [LLaMA
3.2-Vision](https://ai.meta.com/blog/llama-3-2-connect-2024-vision-edge-mobile-devices).
## /detikzify/__init__.py
```py path="/detikzify/__init__.py"
```
## /detikzify/dataset/__init__.py
```py path="/detikzify/dataset/__init__.py"
from datasets.load import load_dataset as _load_dataset
from os.path import dirname, isdir, join
def load_dataset(path, *args, **kwargs):
if isdir(local := join(dirname(__file__), path)):
return _load_dataset(local, *args, trust_remote_code=True, **kwargs)
return _load_dataset(path, *args, **kwargs)
```
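A minimal usage sketch of this wrapper: names matching a local builder directory (e.g., `paper2fig` or `scicap`) are loaded from this package with `trust_remote_code=True`, everything else is forwarded to `datasets.load_dataset` unchanged. The `size` value below is a placeholder; the local builder configs require it for image resizing.
```python
from detikzify.dataset import load_dataset

# "paper2fig" resolves to the local builder directory next to this __init__.py
paper2fig = load_dataset("paper2fig", size=420, split="train")  # size is illustrative

# Hub datasets pass through to datasets.load_dataset unchanged
datikz = load_dataset("nllg/datikz-v3", split="train")
```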
## /detikzify/dataset/paper2fig/__init__.py
```py path="/detikzify/dataset/paper2fig/__init__.py"
```
## /detikzify/dataset/paper2fig/paper2fig.py
```py path="/detikzify/dataset/paper2fig/paper2fig.py"
"""
Images from the Paper2Fig100k dataset.
"""
from itertools import chain
from json import load
from os.path import basename
import tarfile
from datasets import Features, Image, Sequence, Value, builder
from datasets.info import DatasetInfo
from datasets.splits import Split, SplitGenerator
from detikzify.util import convert, expand
class Paper2FigConfig(builder.BuilderConfig):
"""BuilderConfig for Paper2Fig."""
def __init__(self, size, *args, **kwargs):
super().__init__(*args, **kwargs)
self.size = size
self.archive = "https://zenodo.org/records/7299423/files/Paper2Fig100k.tar.gz"
class Paper2Fig(builder.GeneratorBasedBuilder):
"""The Paper2Fig100k dataset in the format DeTikZify expects (everything is training data)."""
BUILDER_CONFIG_CLASS = Paper2FigConfig
def _info(self):
features = {
"caption": Value("string"),
"mention": Sequence(Sequence(Value("string"))),
"ocr": Sequence(Value("string")),
"image": Image(),
}
return DatasetInfo(
description=str(__doc__),
features=Features(features),
)
def _split_generators(self, dl_manager):
archive = dl_manager.download(self.config.archive) # type: ignore
return [SplitGenerator(name=str(Split.TRAIN), gen_kwargs=dict(archive=archive))]
def _generate_examples(self, archive):
with tarfile.open(archive) as tf:
metadata = dict()
for figdata in chain.from_iterable(load(tf.extractfile(f)) for f in tf if f.name.endswith(".json")): # type: ignore
metadata[figdata.pop("figure_id")] = figdata
for idx, member in enumerate(tf):
if member.name.endswith(".png"):
figure_id = basename(member.name).removesuffix(".png")
figdata = metadata[figure_id]
yield idx, dict(
caption=figdata["captions"][0],
mention=[figdata["captions"][1:]],
ocr=[result['text'] for result in figdata['ocr_result']['ocr_result']],
image=convert(expand(tf.extractfile(member), self.config.size), "png"),
)
```
## /detikzify/dataset/scicap/__init__.py
```py path="/detikzify/dataset/scicap/__init__.py"
```
## /detikzify/dataset/scicap/scicap.py
```py path="/detikzify/dataset/scicap/scicap.py"
"""
The SciCap dataset, unified in a single train split.
"""
from json import load
from os import symlink
from os.path import basename, join
from subprocess import run
from tempfile import TemporaryDirectory
from zipfile import ZipFile
from datasets import Features, Image, Sequence, Value, builder
from datasets.info import DatasetInfo
from datasets.splits import Split, SplitGenerator
from datasets.utils.hub import hf_hub_url
from detikzify.util import convert, expand
class SciCapConfig(builder.BuilderConfig):
"""BuilderConfig for SciCap."""
def __init__(self, size, *args, **kwargs):
super().__init__(*args, **kwargs)
self.repo_id = "CrowdAILab/scicap"
self.size = size
self.files = {
"img": {
(public:="img-split"): 10,
(hidden:="img-hide_test"): 0
},
"text": {
"train": public,
"train-acl": public,
"val": public,
"public-test": public,
"hide_test": hidden,
}
}
class SciCap(builder.GeneratorBasedBuilder):
"""The SciCap dataset in the format DeTikZify expects (everything is training data)."""
BUILDER_CONFIG_CLASS = SciCapConfig
def _info(self):
features = {
"caption": Value("string"),
"mention": Sequence(Sequence(Value("string"))),
"paragraph": Sequence(Value("string")),
"ocr": Sequence(Value("string")),
"image": Image(),
}
return DatasetInfo(
description=str(__doc__),
features=Features(features),
)
def _split_generators(self, dl_manager):
with TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
def dl(path):
return dl_manager.download(hf_hub_url(self.config.repo_id, path)) # type: ignore
def zip_dl(path, num_splits=0):
paths = [f"{path}.zip"] + list(f"{path}.z{{:02d}}".format(i+1) for i in range(num_splits))
downloaded = [dl(path) for path in paths]
if num_splits:
output = join(tmpdirname, f"{path}-joined.zip")
for src, dst in zip(downloaded, paths):
symlink(src, join(tmpdirname, dst)) # type: ignore
run(["zip", "-FF", join(tmpdirname, paths[0]), "--out", output], check=True, capture_output=True)
return output
else:
return downloaded[0]
files_to_download = self.config.files # type: ignore
img = {file:zip_dl(file, num_splits) for file, num_splits in files_to_download['img'].items()}
text = {dl(f"{file}.json"):img[img_file] for file, img_file in files_to_download['text'].items()}
yield SplitGenerator(name=str(Split.TRAIN), gen_kwargs={"shards": text})
def _generate_examples(self, shards):
idx = 0
for path, image_zip in shards.items():
with ZipFile(file=image_zip, mode='r') as zf:
imagemap = {basename(name):name for name in zf.namelist()}
with open(path) as f:
images, annotations = load(f).values()
for annotation, image in zip(annotations, images):
assert image["id"] == annotation['image_id']
with zf.open(imagemap[image['file_name']]) as img:
yield idx, dict(
caption=annotation.get("caption_no_index"),
mention=annotation.get("mention"),
paragraph=annotation.get("paragraph"),
ocr=image.get("ocr"),
image=convert(expand(img, self.config.size), "png")
)
idx += 1
```
## /detikzify/evaluate/__init__.py
```py path="/detikzify/evaluate/__init__.py"
# pyright: reportUnsupportedDunderAll=false
from importlib import import_module
from typing import Any
from .imagesim import * # this metric is used by MCTS, so it is not optional
__all__ = [
"ImageSim",
"CrystalBLEU",
"KernelInceptionDistance",
"TexEditDistance",
"DreamSim",
"ClipScore",
]
# lazy import optional metrics (https://peps.python.org/pep-0562/)
def __getattr__(name) -> Any:
def load(metric):
return getattr(import_module("." + metric, __name__), name)
try:
match name:
case "CrystalBLEU":
return load("crystalbleu")
case "KernelInceptionDistance":
return load("kid")
case "TexEditDistance":
return load("eed")
case "DreamSim":
return load("dreamsim")
case "ClipScore":
return load("clipscore")
except ImportError:
raise ValueError(
"Missing dependencies: "
"Install this project with the [evaluate] feature name!"
)
return import_module("." + name, __name__)
```
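A short import sketch: `ImageSim` is imported eagerly (the MCTS scorer depends on it), while the remaining metric classes are resolved lazily on first access and require installing the project with the `[evaluate]` extra.
```python
from detikzify.evaluate import ImageSim      # eager import, always available
from detikzify.evaluate import CrystalBLEU   # resolved lazily via __getattr__
```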
## /detikzify/evaluate/clipscore.py
```py path="/detikzify/evaluate/clipscore.py"
from functools import cached_property
from typing import List
from PIL import Image
import torch
from torch.cuda import is_available as is_cuda_available, is_bf16_supported
from torchmetrics import Metric
from transformers import AutoModel, AutoProcessor
from ..util import expand, infer_device, load
class ClipScore(Metric):
"""Calculates CLIPScore which is a text-to-image similarity metric."""
higher_is_better = True
def __init__(
self,
model_name: str = "google/siglip-so400m-patch14-384",
preprocess: bool = True,
device: str = infer_device(),
dtype=torch.bfloat16 if is_cuda_available() and is_bf16_supported() else torch.float16,
**kwargs
):
super().__init__(**kwargs)
self.model_name = model_name
self.preprocess = preprocess
self._device = device
self.set_dtype(dtype)
self.add_state("score", torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum")
self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum")
def __str__(self):
return self.__class__.__name__
@cached_property
def model(self):
model = AutoModel.from_pretrained(self.model_name, torch_dtype=self.dtype)
return model.to(self.device)
@cached_property
def processor(self):
return AutoProcessor.from_pretrained(self.model_name)
def update(
self,
images: Image.Image | str | List[Image.Image | str],
text: str | List[str]
):
images = images if isinstance(images, List) else [images]
text = text if isinstance(text, List) else [text]
for img, txt in zip(images, text):
img = load(img)
if self.preprocess:
img = expand(img, max(img.size), do_trim=True)
with torch.inference_mode():
inputs = self.processor(text=txt, images=img, truncation=True, return_tensors="pt")
outputs = self.model(
input_ids=inputs.input_ids.to(self.device),
pixel_values=inputs.pixel_values.to(self.device, self.dtype)
)
self.score += torch.sigmoid(outputs.logits_per_image).item()
self.n_samples += 1
def compute(self):
return (self.score / self.n_samples).item()
```
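A minimal usage sketch; the image path and caption below are placeholders. Images can be given as PIL images, file paths, or URLs.
```python
from detikzify.evaluate import ClipScore

clipscore = ClipScore()
# pair each image with its corresponding caption
clipscore.update(images="figure.png", text="A multi-layer perceptron with two hidden layers.")
print(clipscore.compute())  # mean sigmoid of the image-text logits
```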
## /detikzify/evaluate/crystalbleu.py
```py path="/detikzify/evaluate/crystalbleu.py"
from collections import Counter
from functools import cached_property
from hashlib import md5
from itertools import chain, tee
from pickle import dump, load
from typing import List
from crystalbleu import corpus_bleu
from datasets.utils.logging import get_logger
from huggingface_hub import cached_assets_path
from pygments.lexers.markup import TexLexer
from pygments.token import Comment, Name, Text
from sacremoses import MosesTokenizer
from torchmetrics import Metric
logger = get_logger("datasets")
# adopted from nltk
def pad_sequence(sequence, n, pad_left=False, pad_right=False, left_pad_symbol=None, right_pad_symbol=None):
sequence = iter(sequence)
if pad_left:
sequence = chain((left_pad_symbol,) * (n - 1), sequence)
if pad_right:
sequence = chain(sequence, (right_pad_symbol,) * (n - 1))
return sequence
# adopted from nltk
def ngrams(sequence, n, **kwargs):
sequence = pad_sequence(sequence, n, **kwargs)
iterables = tee(sequence, n)
for i, sub_iterable in enumerate(iterables): # For each window,
for _ in range(i): # iterate through every order of ngrams
next(sub_iterable, None) # generate the ngrams within the window.
return zip(*iterables) # Unpack and flattens the iterables.
class CrystalBLEU(Metric):
"""Wrapper around https://github.com/sola-st/crystalbleu (adapted for LaTeX)"""
def __init__(self, corpus, k=500, n=4, use_cache=True, **kwargs):
super().__init__(**kwargs)
self.lexer = TexLexer()
self.tokenizer = MosesTokenizer()
self.use_cache = use_cache
self.corpus = corpus
self.k = k
self.n = n
self.add_state("list_of_references", [], dist_reduce_fx="cat")
self.add_state("hypotheses", [], dist_reduce_fx="cat")
def __str__(self):
return self.__class__.__name__
@cached_property
def trivially_shared_ngrams(self):
"""
Computes trivially shared ngrams and caches them.
"""
cache_dir = cached_assets_path(library_name="evaluate", namespace=self.__class__.__name__.lower())
dhash = md5()
dhash.update(str(sorted(self.corpus)).encode())
hashname = f"{dhash.hexdigest()}.pkl"
if (cache_file:=(cache_dir / hashname)).is_file() and self.use_cache:
logger.info(f"Found cached trivially shared ngrams ({cache_file})")
with open(cache_file, "rb") as f:
return load(f)
else:
all_ngrams = list()
for o in range(1, self.n+1):
for tex in self.corpus:
all_ngrams.extend(ngrams(self._tokenize(tex), o))
frequencies = Counter(all_ngrams)
trivially_shared_ngrams = dict(frequencies.most_common(self.k))
if self.use_cache:
logger.info(f"Caching trivially shared ngrams ({cache_file})")
with open(cache_file, "wb") as f:
dump(trivially_shared_ngrams, f)
return trivially_shared_ngrams
def _tokenize(self, text):
tokens = list()
for tokentype, value in self.lexer.get_tokens(text):
if value.strip() and not tokentype is Comment:
if any(tokentype is tp for tp in [Text, Name.Attribute, Name.Builtin]):
tokens.extend(self.tokenizer.tokenize(value.strip()))
else:
tokens.append(value.strip())
return tokens
def update(
self,
list_of_references: List[List[str]],
hypotheses: List[str],
):
assert len(list_of_references) == len(hypotheses)
self.list_of_references.extend([self._tokenize(ref) for ref in refs] for refs in list_of_references)
self.hypotheses.extend(self._tokenize(hyp) for hyp in hypotheses)
def compute(self):
return corpus_bleu(
list_of_references=self.list_of_references,
hypotheses=self.hypotheses,
ignoring=self.trivially_shared_ngrams
)
```
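A minimal usage sketch with placeholder TikZ snippets; the corpus (e.g., TikZ code from a DaTi*k*Z training split) is only used to extract the k most frequent, trivially shared n-grams.
```python
from detikzify.evaluate import CrystalBLEU

crystalbleu = CrystalBLEU(corpus=["\\begin{tikzpicture}\\draw (0,0) circle (1);\\end{tikzpicture}"])
crystalbleu.update(
    list_of_references=[["\\begin{tikzpicture}\\draw (0,0) circle (1);\\end{tikzpicture}"]],
    hypotheses=["\\begin{tikzpicture}\\draw (0,0) circle (2);\\end{tikzpicture}"],
)
print(crystalbleu.compute())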
## /detikzify/evaluate/dreamsim.py
```py path="/detikzify/evaluate/dreamsim.py"
from functools import cached_property
from typing import List
from PIL import Image
from dreamsim import dreamsim
from huggingface_hub import cached_assets_path
import torch
from torch.cuda import is_available as is_cuda_available, is_bf16_supported
from torchmetrics import Metric
from ..util import expand, infer_device, load
class DreamSim(Metric):
"""Perceptual image similarity using DreamSim"""
higher_is_better = True
def __init__(
self,
model_name: str = "ensemble",
pretrained: bool = True,
normalize: bool = True,
preprocess: bool = True,
device: str = infer_device(),
dtype=torch.bfloat16 if is_cuda_available() and is_bf16_supported() else torch.float16,
**kwargs
):
super().__init__(**kwargs)
self.model_name = model_name
self.pretrained = pretrained
self.normalize = normalize
self._device = device
self.set_dtype(dtype)
self.preprocess = preprocess
self.add_state("score", torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum")
self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum")
def __str__(self):
return self.__class__.__name__
@cached_property
def dreamsim(self):
model, processor = dreamsim(
dreamsim_type=self.model_name,
pretrained = self.pretrained,
normalize_embeds=self.normalize,
device=str(self.device),
cache_dir=str(cached_assets_path(library_name="evaluate", namespace=self.__class__.__name__.lower()))
)
for extractor in model.extractor_list:
extractor.model = extractor.model.to(self.dtype)
extractor.proj = extractor.proj.to(self.dtype)
return dict(
model=model.to(self.dtype),
processor=processor
)
@property
def model(self):
return self.dreamsim['model']
@property
def processor(self):
return self.dreamsim['processor']
def update(
self,
img1: Image.Image | str | List[Image.Image | str],
img2: Image.Image | str | List[Image.Image | str],
):
if isinstance(img1, List) or isinstance(img2, List):
assert type(img1) == type(img2) and len(img1) == len(img2) # type: ignore
else:
img1, img2 = [img1], [img2]
for i1, i2 in zip(img1, img2): # type: ignore
i1, i2 = load(i1), load(i2)
if self.preprocess:
i1 = expand(load(i1), max(i1.size), do_trim=True)
i2 = expand(load(i2), max(i2.size), do_trim=True)
i1 = self.processor(i1).to(self.device, self.dtype)
i2 = self.processor(i2).to(self.device, self.dtype)
with torch.inference_mode():
self.score += 1 - self.model(i1, i2).item() # type: ignore
self.n_samples += 1
def compute(self):
return (self.score / self.n_samples).item()
```
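Usage follows the same pattern as the other image metrics; the paths below are placeholders.
```python
from detikzify.evaluate import DreamSim

dreamsim = DreamSim()
dreamsim.update(img1="reference.png", img2="prediction.png")
print(dreamsim.compute())  # averaged 1 - DreamSim distance, higher is better
```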
## /detikzify/evaluate/eed.py
```py path="/detikzify/evaluate/eed.py"
from pygments.lexers.markup import TexLexer
from pygments.token import Comment, Text
from torchmetrics.text import ExtendedEditDistance
from torchmetrics.functional.text.eed import (
_compute_sentence_statistics,
_preprocess_en,
_preprocess_ja,
)
from torchmetrics.functional.text.helper import _validate_inputs
class TexEditDistance(ExtendedEditDistance):
"""Adapt torchmetrics ExtendedEditDistance for TeX"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lexer = TexLexer()
def __str__(self):
return self.__class__.__name__
def _preprocess_sentences(self, preds, target, language):
target, preds = _validate_inputs(hypothesis_corpus=preds, ref_corpus=target)
def tokenize(text):
tokens = list()
for tokentype, value in self.lexer.get_tokens(text):
if value.strip():
if tokentype is Text:
if language == "en":
preprocess_function = _preprocess_en
elif language == "ja":
preprocess_function = _preprocess_ja
else:
raise ValueError(f"Expected argument `language` to either be `en` or `ja` but got {language}")
tokens.extend(preprocess_function(value).split())
elif not tokentype is Comment:
tokens.extend(value.split())
return " " + " ".join(tokens) + " "
preds = [tokenize(pred) for pred in preds]
target = [[tokenize(ref) for ref in reference] for reference in target]
return preds, target
def update(self, preds, target):
"""Update state with predictions and targets."""
preds, target = self._preprocess_sentences(preds, target, self.language)
if self.sentence_eed is None:
self.sentence_eed = []
if 0 in (len(preds), len(target[0])):
return self.sentence_eed
for (hypothesis, target_words) in zip(preds, target):
score = _compute_sentence_statistics(
hypothesis,
target_words,
self.alpha,
self.rho,
self.deletion,
self.insertion
)
self.sentence_eed.append(score)
return self.sentence_eed
def compute(self, *args, **kwargs):
return super().compute(*args, **kwargs).item() # type: ignore
```
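A minimal usage sketch with placeholder TikZ snippets.
```python
from detikzify.evaluate import TexEditDistance

eed = TexEditDistance()
eed.update(
    preds=["\\draw (0,0) -- (1,1);"],
    target=[["\\draw (0,0) -- (1,0);"]],
)
print(eed.compute())  # extended edit distance over TeX tokens, lower is better
```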
## /detikzify/evaluate/imagesim.py
```py path="/detikzify/evaluate/imagesim.py"
from functools import cached_property
from math import tanh
from typing import List, Literal, Optional
from PIL import Image
from ot.lp import emd2
import torch
from torch.cuda import is_available as is_cuda_available, is_bf16_supported
import torch.nn.functional as F
from torchmetrics import Metric
from torchmetrics.functional import pairwise_cosine_similarity
from transformers import AutoImageProcessor, AutoModel, PreTrainedModel, ProcessorMixin
from ..model.adapter import (
AdapterProcessor,
CrossAttentionAdapterMixin as AdapterMixin,
has_adapter,
)
from ..util import cast, expand, infer_device, load, unwrap_processor
class ImageSim(Metric):
"""Perceptual image similarity using visual encoders."""
higher_is_better = True
def __init__(
self,
model_name: str = "google/siglip-so400m-patch14-384",
mode: Literal["emd", "cos", "cos_avg"] = "cos",
preprocess: bool = True,
device: str = infer_device(),
dtype=torch.bfloat16 if is_cuda_available() and is_bf16_supported() else torch.float16,
**kwargs
):
super().__init__(**kwargs)
self.model_name = model_name
self.preprocess = preprocess
self.mode = mode
self._device = device
self.set_dtype(dtype)
self.add_state("score", torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum")
self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum")
def __str__(self):
return self.__class__.__name__ + f" ({self.mode.upper().replace('_', '-')})"
@cached_property
def model(self):
# even if we instantiate with from_detikzify we still end up in this function
if (model:=dict(self.named_children()).get("model")) is None:
model = AutoModel.from_pretrained(self.model_name, torch_dtype=self.dtype)
model = model.vision_model.to(self.device)
return model
@cached_property
def processor(self):
return AutoImageProcessor.from_pretrained(self.model_name)
@classmethod
def from_detikzify(cls, model: PreTrainedModel, processor: ProcessorMixin, mode=None, *args, **kwargs):
derived_kwargs = dict(
model_name = model.name_or_path,
mode = getattr(model.config, "pooling_mode", "emd") if mode is None else mode,
device = model.device,
dtype = model.dtype,
)
imagesim = cls(*args, **(derived_kwargs | kwargs))
if has_adapter(model):
class AdapterVisionModel(type(model.model.vision_model), AdapterMixin):
embedding_model=model.embedding_model
adapter=model.adapter
@classmethod
def cast(cls, vision_model):
adapter_vision_model = cast(cls, vision_model)
adapter_vision_model.add_hooks()
return adapter_vision_model
imagesim.model = AdapterVisionModel.cast(model.model.vision_model)
imagesim.processor = AdapterProcessor(
processor=unwrap_processor(processor).image_processor,
tokenizer=processor.tokenizer # type: ignore
)
else:
imagesim.model = model.model.vision_model
imagesim.processor = unwrap_processor(processor).image_processor
return imagesim
def get_vision_features(self, image: Optional[Image.Image | str] = None, text: Optional[str] = None):
if image is not None:
image = load(image)
if self.preprocess:
image = expand(image, max(image.size), do_trim=True)
with torch.inference_mode():
if text is not None:
encoding = self.processor(text=text, images=image, return_tensors="pt").to(self.device, self.dtype)
else:
encoding = self.processor(images=image, return_tensors="pt").to(self.device, self.dtype)
if self.mode == "cos":
return self.model(**encoding).pooler_output.squeeze()
elif self.mode == "cos_avg":
return self.model(**encoding).last_hidden_state.squeeze().mean(dim=0)
else:
return self.model(**encoding).last_hidden_state.squeeze()
def get_similarity(
self,
img1: Optional[Image.Image | str] = None,
img2: Optional[Image.Image | str] = None,
text1: Optional[str] = None,
text2: Optional[str] = None,
):
img1_feats = self.get_vision_features(img1, text1)
img2_feats = self.get_vision_features(img2, text2)
if img1_feats.is_mps: # mps backend does not support dtype double
img1_feats, img2_feats = img1_feats.cpu(), img2_feats.cpu()
if img1_feats.ndim > 1:
dists = 1 - pairwise_cosine_similarity(img1_feats.double(), img2_feats.double()).cpu().numpy()
return 2 * tanh(-emd2(M=dists, a=list(), b=list())) + 1 # type: ignore
else:
return F.cosine_similarity(img1_feats.double(), img2_feats.double(), dim=0).item()
def update(
self,
img1: Optional[Image.Image | str | List[Image.Image | str]] = None,
img2: Optional[Image.Image | str | List[Image.Image | str]] = None,
text1: Optional[str | List[str]] = None,
text2: Optional[str | List[str]] = None,
):
inputs = dict()
for key, value in dict(img1=img1, img2=img2, text1=text1, text2=text2).items():
if value is not None:
inputs[key] = value if isinstance(value, List) else [value]
assert not ({"img1", "text1"}.isdisjoint(inputs.keys()) or {"img2", "text2"}.isdisjoint(inputs.keys()))
assert len(set(map(len, inputs.values()))) == 1
for inpt in zip(*inputs.values()):
self.score += self.get_similarity(**dict(zip(inputs.keys(), inpt)))
self.n_samples += 1
def compute(self):
return (self.score / self.n_samples).item()
```
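A usage sketch, both standalone and as SelfSim, i.e., reusing the vision encoder of a loaded DeTi*k*Zify checkpoint. The image paths are placeholders, and the second variant assumes enough memory to load the full model.
```python
from detikzify.evaluate import ImageSim
from detikzify.model import load

# standalone, with a separately downloaded vision encoder
imagesim = ImageSim(mode="emd")
imagesim.update(img1="reference.png", img2="prediction.png")
print(imagesim.compute())

# SelfSim: share the vision encoder of a DeTikZify checkpoint
selfsim = ImageSim.from_detikzify(*load(model_name_or_path="nllg/detikzify-v2-8b"))
```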
## /detikzify/evaluate/kid.py
```py path="/detikzify/evaluate/kid.py"
from functools import cached_property
from typing import List
from PIL import Image
import torch
from torch import nn
from torch.cuda import is_available as is_cuda_available, is_bf16_supported
from torchmetrics.image.kid import KernelInceptionDistance as KID
from transformers import AutoModel, AutoImageProcessor
from ..util import expand, infer_device, load
class FeatureWrapper(nn.Module):
def __init__(self, model_name, device, dtype):
super().__init__()
self.model_name = model_name
self.device = device
self.dtype = dtype
@cached_property
def model(self):
model = AutoModel.from_pretrained(self.model_name, torch_dtype=self.dtype)
return model.to(self.device)
def forward(self, pixel_values):
with torch.inference_mode():
return self.model.get_image_features(pixel_values.to(self.device, self.dtype))
class KernelInceptionDistance(KID):
"""Wrapper around torchmetrics Kernel Inception Distance with CLIP"""
def __init__(
self,
model_name: str = "google/siglip-so400m-patch14-384",
subset_size: int = 50,
preprocess: bool = True,
device: str = infer_device(),
dtype=torch.bfloat16 if is_cuda_available() and is_bf16_supported() else torch.float16,
**kwargs
):
super().__init__(
subset_size=subset_size,
feature=FeatureWrapper(
model_name=model_name,
device=device,
dtype=dtype),
**kwargs
)
self.preprocess = preprocess
def __str__(self):
return self.__class__.__name__
@cached_property
def processor(self):
return AutoImageProcessor.from_pretrained(self.inception.model_name)
def open(self, img):
img = load(img)
if self.preprocess:
return expand(img, max(img.size), do_trim=True)
return img
def update(self, imgs: Image.Image | str | List[Image.Image | str], *args, **kwargs):
if not isinstance(imgs, List):
imgs = [imgs]
super().update(
self.processor([self.open(img) for img in imgs], return_tensors="pt")["pixel_values"],
*args,
**kwargs
)
def compute(self, *args, **kwargs): # type: ignore
return tuple(tensor.item() for tensor in super().compute(*args, **kwargs))
```
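A toy usage sketch with placeholder paths; note that `subset_size` must not exceed the number of images in either set.
```python
from detikzify.evaluate import KernelInceptionDistance

kid = KernelInceptionDistance(subset_size=2)  # tiny subset size only for this toy example
kid.update(["real1.png", "real2.png"], real=True)
kid.update(["fake1.png", "fake2.png"], real=False)
print(kid.compute())  # (mean, std) of the kernel inception distance
```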
## /detikzify/infer/__init__.py
```py path="/detikzify/infer/__init__.py"
from .tikz import *
from .generate import *
```
## /detikzify/infer/generate.py
```py path="/detikzify/infer/generate.py"
from collections import deque
from dataclasses import dataclass
from functools import cached_property
from math import sqrt
from multiprocessing.pool import ThreadPool
from re import sub
from time import time
from types import SimpleNamespace as Namespace
from typing import Any, Dict, Generator, List, Literal, Optional, Set, Tuple, Union
from PIL import Image
import torch
from torchmetrics import Metric
from transformers import StoppingCriteriaList
from transformers.generation.streamers import BaseStreamer
from ..evaluate.imagesim import ImageSim
from ..mcts.montecarlo import MonteCarlo
from ..mcts.node import Node
from ..model.adapter import has_adapter
from ..util import (
ExplicitAbort,
StreamerList,
TokenStreamer,
cache_cast,
expand,
load,
unwrap_processor as unwrap,
)
from .tikz import TikzDocument
Numeric = Union[int, float]
@dataclass(frozen=True)
class NodeState:
token_ids: torch.Tensor
num_lines: int = 0
def __eq__(self, other: Any) -> bool:
try:
return self.token_ids.equal(other.token_ids)
except (AttributeError, TypeError):
return False
def __hash__(self):
return hash(tuple(self.token_ids.tolist()))
class WideNode(Node):
state: NodeState
def __init__(self, *args, exploration=0.6, is_widen_node=False, **kwargs):
super().__init__(NodeState(*args, **kwargs))
self.discovery_factor = exploration
self.is_widen_node = is_widen_node
self.update_policy_value(1.0)
if not is_widen_node:
self.add_child(WideNode(
*args,
exploration=exploration,
is_widen_node=not is_widen_node,
**kwargs
))
def add_child(self, child):
self.expanded = self.expanded or not child.is_widen_node
super().add_child(child)
@property
def depth(self) -> int:
depth, current = 0, self
while parent:=current.parent:
depth, current = depth + 1, parent
return depth
@property
def token_ids(self):
return self.state.token_ids
@property
def num_lines(self):
return self.state.num_lines
class DynMinMaxNorm:
def __init__(self, default_value: Numeric = 0):
self.scores = set()
self.default_value = default_value
def normalize(self, score: Numeric) -> "MinMaxScore":
self.scores.add(score)
return self.MinMaxScore(score, all_scores=self.scores, default_value=self.default_value)
def __call__(self, *args, **kwargs) -> "MinMaxScore":
return self.normalize(*args, **kwargs)
class MinMaxScore:
def __init__(
self,
*scores: Numeric,
all_scores: Set[Numeric],
default_value: Numeric,
no_minmax_scores: List[Numeric] = list(),
):
self.scores = list(scores)
self.all_scores = all_scores
self.default_value = default_value
self.no_minmax_scores = no_minmax_scores.copy()
@property
def score(self) -> Numeric:
min_score, max_score = min(self.all_scores), max(self.all_scores)
try:
score = sum((score - min_score) / (max_score - min_score) for score in self.scores)
except ZeroDivisionError:
score = self.default_value
return score + sum(self.no_minmax_scores)
def __add__(self, other: Any) -> "DynMinMaxNorm.MinMaxScore":
new = self.__class__(
*self.scores,
all_scores=self.all_scores,
default_value=self.default_value,
no_minmax_scores=self.no_minmax_scores
)
try:
new.scores.extend(other.scores)
new.no_minmax_scores.extend(other.no_minmax_scores)
except AttributeError:
new.no_minmax_scores.append(other)
return new
def __mul__(self, other: Any) -> "DynMinMaxNorm.MinMaxScore":
return self.score * other
def __truediv__(self, other: Any) -> "DynMinMaxNorm.MinMaxScore":
return self.score / other
def __rtruediv__(self, other: Any) -> "DynMinMaxNorm.MinMaxScore":
return other / self.score
__radd__, __rmul__ = __add__, __mul__
class DetikzifyGenerator:
def __init__(
self,
model,
processor,
image: Optional[Image.Image],
text: Optional[str] = None,
metric: Optional[Metric] = None,
compile_timeout: Optional[int] = 60,
mcts_timeout: Optional[int] = None,
streamer: Optional[BaseStreamer] = None,
control: Optional[ExplicitAbort] = None,
exploration: float = 0.6, # exploration coefficient
strict: bool = False, # if True, treat recoverable errors same as fatal errors when computing scores
**gen_kwargs,
):
self.model = model
self.processor = processor
self.metric = metric
self.image = image
self.text = text
self.compile_timeout = compile_timeout
self.mcts_timeout = mcts_timeout
self.streamer = streamer
self.exploration = exploration
self.strict = strict
self.gen_kwargs = gen_kwargs
self.solution = deque(maxlen=1)
self.failed_rollouts = dict()
self.norm = DynMinMaxNorm()
self.control = control or ExplicitAbort()
self.montecarlo = MonteCarlo(
root_node=WideNode(
processor(
images=self.image,
text=self.text,
return_tensors="pt",
).input_ids.to(model.device).squeeze(),
exploration=self.exploration
)
)
self.montecarlo.child_finder = self.child_finder # type: ignore
# https://stackoverflow.com/a/68550238
self.decode = cache_cast(lambda token_ids: tuple(token_ids.tolist()))(self.decode)
self.score = cache_cast(lambda image: image.tobytes())(self.score)
def __call__(self, *args, **kwargs):
return self.simulate(*args, **kwargs)
def simulate(self, expansions: Optional[Numeric] = 1) -> Generator[Tuple[Numeric, TikzDocument], None, None]:
"""
Run the simulations. Yields all rollouts (successful and unsuccessful)
as (score, TikZ picture) tuples.
"""
start_time = time()
while expansions is None or (expansions:=expansions-1) >= 0:
self.montecarlo.simulate()
yield self.solution.pop()
if self.mcts_timeout is not None and time() - start_time > self.mcts_timeout:
return
def generate(self, input_ids: torch.Tensor, streamer: Optional[BaseStreamer] = None, **gen_kwargs) -> torch.Tensor:
streamers, numel = StreamerList(filter(bool, [streamer, self.streamer])), input_ids.numel()
max_length = {**self.model.generation_config.to_dict(), **self.gen_kwargs, **gen_kwargs}["max_length"]
if (numel and input_ids[-1] == unwrap(self.processor).tokenizer.eos_token_id) or numel >= max_length:
streamers.end()
return input_ids # prevent continuing generation after eos
with torch.inference_mode():
token_ids = self.processor(images=self.image, text=self.text, text_kwargs={"truncation": True}, return_tensors="pt")
adapter_kwargs = {k: v for k, v in token_ids.to(self.model.device).items() if k.startswith("adapter")}
return self.model.generate(
input_ids=input_ids.unsqueeze(0),
bad_words_ids=[[self.model.config.image_token_id]],
pixel_values=token_ids.get("pixel_values"),
streamer=streamers,
**adapter_kwargs,
**self.gen_kwargs,
**gen_kwargs
).squeeze()
@cached_property
def newlineinfo(self):
# tokens can potentially contain multiple newlines, so we need special
# handling when we want to map error lines to tokens
newlineinfo = dict()
for token_id in unwrap(self.processor).tokenizer.vocab.values():
# NOTE: Newline normalization might lead to inaccurate estimations
# for windows line separators (if split over two tokens). However,
# the likeliness of such tokens being generated (not in training
# data) as well as the potential impact is negligible.
# https://www.overleaf.com/learn/latex/Articles/An_introduction_to_%5Cendlinechar%3A_How_TeX_reads_lines_from_text_files
token = sub(r"\r\n|\r", r"\n", self.processor.decode([token_id]))
if (num_lines:=token.count("\n")):
newlineinfo[token_id] = Namespace(num_lines=num_lines, trailing=token.endswith("\n"))
assert newlineinfo
return newlineinfo
def rollout(self, state: NodeState) -> Generator[Tuple[torch.Tensor, int], None, None]:
input_ids, num_lines, continuation = state.token_ids, state.num_lines, False
with ThreadPool(processes=1) as thread:
streamer = TokenStreamer()
async_result = thread.apply_async(
func=self.generate,
error_callback=streamer.propagate_error,
args=[input_ids],
kwds=dict(
stopping_criteria=StoppingCriteriaList([self.control.reset()]),
streamer=streamer,
)
)
try:
prev_ids, line = input_ids, list()
for token in streamer:
line.append(token)
if info:=self.newlineinfo.get(token):
# continuations (newline followed by text in a single
# token) don't appear in the llama 3.1 tokenizer, but
# handling them now makes this code future proof
num_lines += info.num_lines - continuation
continuation = not info.trailing
prev_ids = torch.cat((prev_ids, torch.tensor(line, device=prev_ids.device)))
line.clear()
yield prev_ids, num_lines
if line:
yield torch.cat((prev_ids, torch.tensor(line, device=prev_ids.device))), num_lines - continuation
except (GeneratorExit, KeyboardInterrupt):
self.control.abort()
raise
else:
if self.control.should_stop:
raise InterruptedError
finally:
async_result.wait()
def decode(self, token_ids: torch.Tensor) -> TikzDocument:
return TikzDocument(
timeout=self.compile_timeout,
code=self.processor.decode(
token_ids=token_ids[len(self.montecarlo.root_node.token_ids):],
skip_special_tokens=True
)
)
def score(self, image: Image.Image) -> Numeric:
assert self.metric
self.metric.update(img1=image, img2=self.image, text2=self.text)
score = self.metric.compute()
self.metric.reset()
return score
def sample(self):
return self.decode(self.generate(
input_ids=self.montecarlo.root_node.token_ids,
))
def child_finder(self, node: WideNode, montecarlo: MonteCarlo):
new_nodes = list()
for new_state in (rollout:=self.rollout(node.state)):
new_node = WideNode(*new_state, exploration=self.exploration)
if new_node.state in self.failed_rollouts:
new_nodes.extend(self.failed_rollouts[new_node.state])
rollout.close()
break
new_nodes.append(new_node)
if node.is_widen_node:
node.visits += 1
node, new_nodes = self.merge(node.parent, new_nodes) # type: ignore
tikz = self.decode((new_nodes or [node])[-1].token_ids)
skip_idx = round(sqrt(len(new_nodes)))
if scorable:=(tikz.is_rasterizable and not (self.strict and tikz.compiled_with_errors)):
for new_node in new_nodes[:skip_idx]:
node.add_child(node:=new_node)
# Only process failed rollouts when we can locate the error (errorln !=
# 0). In rare cases there is no error information even though the
# tikzpic is not rasterizable because only cropping failed -> use [0].
elif errorln:=min(tikz.errors or [0]):
for idx, new_node in enumerate(new_nodes):
ends_with_eol = self.newlineinfo.get(new_node.token_ids[-1])
if new_node.num_lines < errorln and idx < skip_idx:
node.add_child(node:=new_node)
elif new_node.num_lines > errorln or (new_node.num_lines == errorln and ends_with_eol):
self.failed_rollouts[new_node.state] = new_nodes[idx:]
break
if self.metric:
score = self.score(tikz.rasterize()) if scorable else -1 # type: ignore
else: # if we do not have a metric, use compiler logs instead
score = scorable - tikz.compiled_with_errors
node.update_win_value(self.norm(score) if scorable and self.metric else score)
self.solution.append((score, tikz))
def merge(self, node: WideNode, nodes_to_merge: List[WideNode]) -> Tuple[WideNode, List[WideNode]]:
for merge_node in nodes_to_merge:
for child in node.children:
if child.state == merge_node.state:
node, nodes_to_merge = child, nodes_to_merge[1:]
break
else:
break
return node, nodes_to_merge
class DetikzifyPipeline:
def __init__(
self,
model,
processor,
# hyperparams based on "a systematic evaluation of large language models of code"
temperature: float = 0.8,
top_p: float = 0.95,
top_k: int = 0,
compile_timeout: Optional[int] = 60, # same as old overleaf compile timeout
metric: Union[Literal["model", "fast"], Metric] = "model",
**gen_kwargs,
):
self.model = model
self.processor = processor
if metric == "model": # SelfSim
self.metric = ImageSim.from_detikzify(model, processor, sync_on_compute=False)
elif metric == "fast": # Compiler Diagnostics
self.metric = None
else:
self.metric = metric
self.gen_kwargs: Dict[str, Any] = dict(
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_length=unwrap(processor).tokenizer.model_max_length,
do_sample=True,
compile_timeout=compile_timeout,
**gen_kwargs
)
def load(self, image: Union[Image.Image, str], preprocess: bool = True):
image = load(image)
if preprocess:
return expand(image, max(image.size), do_trim=True)
return image
def check_inputs(self, image, text):
assert text is None or has_adapter(self.model), "You need to load an adapter for textual inputs!"
assert image or text, "Either image or text (or both) required!"
def sample(
self,
image: Optional[Union[Image.Image, str]] = None,
text: Optional[str] = None,
preprocess: bool = True,
**gen_kwargs,
) -> TikzDocument:
"""
DeTikZify a raster image. Samples a single TikZ document and returns it.
image: the image
text: textual instruction
preprocess: whether to preprocess the image (expand to square and
trim to content)
gen_kwargs: additional generation kwargs (potentially overriding
the default ones)
"""
self.check_inputs(image, text)
generator = DetikzifyGenerator(
model=self.model,
processor=self.processor,
image=self.load(image, preprocess=preprocess) if image is not None else None,
text=text,
**self.gen_kwargs,
**gen_kwargs
)
return generator.sample()
def simulate(
self,
image: Optional[Union[Image.Image, str]] = None,
text: Optional[str] = None,
preprocess: bool = True,
expansions: Optional[Numeric] = None,
timeout: Optional[int] = None,
**gen_kwargs,
) -> Generator[Tuple[Numeric, TikzDocument], None, None]:
"""
DeTikZify a raster image using MCTS. Returns an iterator yielding
(score, tikzdoc) tuples of TikZ documents created during rollouts.
image: the image
text: textual instruction
preprocess: whether to preprocess the image (expand to square and
trim to content)
expansions: number of attempted MCTS expansions (set to None, 0 or
math.inf for infinite)
timeout: timeout for MCTS in seconds (set to 0, math.inf, or
None for infinite)
gen_kwargs: additional generation kwargs (potentially overriding
the default ones)
"""
self.check_inputs(image, text)
generator = DetikzifyGenerator(
model=self.model,
processor=self.processor,
metric=self.metric,
mcts_timeout=timeout or None,
image=self.load(image, preprocess=preprocess) if image is not None else None,
text=text,
**self.gen_kwargs,
**gen_kwargs
)
yield from generator.simulate(expansions or None)
def __call__(self, *args, **kwargs) -> TikzDocument:
return self.sample(*args, **kwargs)
```
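For reference, a minimal usage sketch of `DetikzifyPipeline` (not part of the repository); the checkpoint id and input image are placeholders, and `load` is the helper defined below in `detikzify/model/__init__.py`.
```python
# Minimal sketch, not part of the repository. The checkpoint id and the input
# image path are placeholders; substitute a real DeTikZify model and sketch.
from detikzify.model import load
from detikzify.infer.generate import DetikzifyPipeline

model, processor = load("path/or/hub-id-of-a-detikzify-checkpoint")
pipeline = DetikzifyPipeline(model=model, processor=processor)

# Single-shot sampling returns one TikzDocument.
doc = pipeline.sample(image="sketch.png")
print(doc.code)

# MCTS-based refinement yields (score, TikzDocument) pairs during rollouts.
score, best = max(
    pipeline.simulate(image="sketch.png", expansions=10),
    key=lambda pair: pair[0],
)
print(score, best.is_rasterizable)
```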
## /detikzify/infer/tikz.py
```py path="/detikzify/infer/tikz.py"
from collections import namedtuple
from functools import cache, cached_property
from io import BytesIO
from os import environ
from os.path import isfile, join
from re import MULTILINE, escape, findall, search
from subprocess import CalledProcessError, DEVNULL, TimeoutExpired
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Dict, Optional, Union
from PIL import Image
from pdf2image.pdf2image import convert_from_bytes
from pdfCropMargins import crop
import pymupdf
from transformers.utils import logging
from ..util import check_output, expand, redact as redact_text
logger = logging.get_logger("transformers")
class TikzDocument:
"""
Facilitate some operations with TikZ code. To compile the images a full
TeXLive installation is assumed to be on the PATH. Cropping additionally
requires Ghostscript, and rasterization needs poppler.
"""
# engines to try, could also try: https://tex.stackexchange.com/a/495999
engines = ["pdflatex", "lualatex", "xelatex"]
Output = namedtuple("Output", ['pdf', 'status', 'log'], defaults=[None, -1, ""])
def __init__(self, code: str, timeout: Optional[int] = 60):
self.code = code
self.timeout = timeout
# https://stackoverflow.com/a/68550238
self.compile = cache(self.compile)
@property
def status(self) -> int:
return self.compile().status
@property
def pdf(self) -> Optional[pymupdf.Document]: # type: ignore
return self.compile().pdf
@property
def log(self) -> str:
return self.compile().log
@property
def compiled_with_errors(self) -> bool:
return self.status != 0
@property
def errors(self, rootfile: Optional[str] = None) -> Dict[int, str]:
"""
Returns a dict of (linenr, errormsg) pairs. linenr==0 is a special
value reserved for errors that do not have a linenumber in rootfile.
"""
if self.compiled_with_errors:
if not rootfile and (match:=search(r"^\((.+)$", self.log, MULTILINE)):
rootfile = match.group(1)
else:
raise ValueError("rootfile not found!")
errors = dict()
for file, line, error in findall(r'^(.+):(\d+):(.+)$', self.log, MULTILINE):
if file == rootfile:
errors[int(line)] = error.strip()
else: # error occurred in other file
errors[0] = error.strip()
return errors or {0: "Fatal error occurred, no output PDF file produced!"}
return dict()
@cached_property
def is_rasterizable(self) -> bool:
"""true if we have an image"""
return self.rasterize() is not None
@cached_property
def has_content(self) -> bool:
"""true if we have an image that isn't empty"""
return (img:=self.rasterize()) is not None and img.getcolors(1) is None
@classmethod
def set_engines(cls, engines: Union[str, list]):
cls.engines = [engines] if isinstance(engines, str) else engines
def compile(self) -> "Output":
output = dict()
with TemporaryDirectory() as tmpdirname:
with NamedTemporaryFile(dir=tmpdirname, buffering=0) as tmpfile:
codelines = self.code.split("\n")
# make sure we don't have page numbers in compiled pdf (for cropping)
codelines.insert(1, r"{cmd}\AtBeginDocument{{{cmd}}}".format(cmd=r"\thispagestyle{empty}\pagestyle{empty}"))
tmpfile.write("\n".join(codelines).encode())
try:
# compile
errorln, tmppdf, outpdf = -1, f"{tmpfile.name}.pdf", join(tmpdirname, "tikz.pdf")
open(f"{tmpfile.name}.bbl", 'a').close() # some classes expect a bibfile
def try_save_last_page():
try:
doc = pymupdf.open(tmppdf)
doc.select([len(doc)-1])
doc.save(outpdf)
except:
pass
for engine in self.engines:
try:
check_output(
cwd=tmpdirname,
timeout=self.timeout,
stderr=DEVNULL,
env=environ | dict(max_print_line="1000"), # improve formatting of log
args=["latexmk", "-f", "-nobibtex", "-norc", "-file-line-error", "-interaction=nonstopmode", f"-{engine}", tmpfile.name]
)
except (CalledProcessError, TimeoutExpired) as proc:
log = (getattr(proc, "output", b'') or b'').decode(errors="ignore")
error = search(rf'^{escape(tmpfile.name)}:(\d+):.+$', log, MULTILINE)
# only update status and log if first error occurs later than in previous engine
if (linenr:=int(error.group(1)) if error else 0) > errorln:
errorln = linenr
output.update(status=getattr(proc, 'returncode', -1), log=log)
try_save_last_page()
else:
output.update(status=0, log='')
try_save_last_page()
break
# crop
croppdf = f"{tmpfile.name}.crop"
crop(["-gsf", "-c", "gb", "-p", "0", "-a", "-1", "-o", croppdf, outpdf], quiet=True)
if isfile(croppdf):
output['pdf'] = pymupdf.open(croppdf)
except FileNotFoundError:
logger.error("Missing dependencies: Did you install TeX Live?")
except RuntimeError: # pdf error during cropping
pass
if output.get("status") == 0 and not output.get("pdf", None):
logger.warning("Could compile document but something seems to have gone wrong during cropping!")
return self.Output(**output)
def rasterize(self, size=420, expand_to_square=True, redact=False, **redact_kwargs) -> Optional[Image.Image]:
if pdf:=self.pdf:
if redact:
pdf = redact_text(pdf, **redact_kwargs)
image = convert_from_bytes(pdf.tobytes(), size=size, single_file=True)[0]
if expand_to_square:
return expand(image, size)
return image
def save(self, filename: str, *args, **kwargs):
match filename.split(".")[-1]:
case "tex": content = self.code.encode()
case "pdf" if self.pdf: content = self.pdf.tobytes()
case fmt if img := self.rasterize(*args, **kwargs):
img.save(imgByteArr:=BytesIO(), format=fmt)
content = imgByteArr.getvalue()
case fmt: raise ValueError(f"Couldn't save with format '{fmt}'!")
with open(filename, "wb") as f:
f.write(content)
```
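`TikzDocument` can also be used on its own; the following sketch (not part of the repository) assumes a working TeX Live installation plus Ghostscript and poppler, as stated in the docstring.
```python
# Minimal sketch, not part of the repository.
from detikzify.infer.tikz import TikzDocument

code = r"""
\documentclass{standalone}
\usepackage{tikz}
\begin{document}
\begin{tikzpicture}
  \draw[->] (0,0) -- (2,1) node[right] {$x$};
\end{tikzpicture}
\end{document}
""".strip()

doc = TikzDocument(code)
if doc.compiled_with_errors:
    print(doc.errors)        # {line number: message}; 0 means "no line number"
if doc.is_rasterizable:
    doc.rasterize().show()   # cropped PIL image, padded to a square by default
    doc.save("figure.png")   # extension decides between .tex, .pdf, and images
```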
## /detikzify/mcts/LICENSE
``` path="/detikzify/mcts/LICENSE"
The MIT License
Copyright (c) 2010-2018 ImparaAI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
```
## /detikzify/mcts/README.md
A Python3 library for running a [Monte Carlo tree search](https://en.wikipedia.org/wiki/Monte_Carlo_tree_search), either traditionally by drilling down to end game states or with expert policies as might be provided by a neural network.
Adapted from **Version:** 1.3.1 of [ImparaAI/monte-carlo-tree-search](https://github.com/ImparaAI/monte-carlo-tree-search).
# Monte Carlo tree search basics
The Monte Carlo tree search (MCTS) algorithm can help with making a decision from a number of options. It avoids exploring every possible option by randomly sampling a small number of pathways and picking the move with the highest probability of victory. This is commonly applied to games like chess or go where it's useful to know what move should come next if you want to win the game.
MCTS works by expanding the search tree to figure out which moves (or child/subsequent states) are likely to produce a positive result if chosen. While time is available, the algorithm continues to explore the tree, always slightly favoring the direction that has either proven to be fruitful or is less explored. When no time is left, the most explored direction is chosen.
The search tree expansion can be done in two different ways:
- **Traditional**: At least one random rollout to a game's end state (e.g. win, loss, tie) for each move under evaluation so the algorithm can make a choice.
- **Expert policy (i.e. neural network)**: Instead of expensively rolling all the way out to a game's end state, ask an expert (a neural network, for example) which move is most likely to produce a positive outcome.
For a deeper dive into the topic, check out [this article](http://tim.hibal.org/blog/alpha-zero-how-and-why-it-works/).
# This library
As the user of this library, you only have to provide:
- A function that finds the direct children of each search tree node (called the **`child_finder`**)
- A function for evaluating nodes for end state outcomes (called the **`node_evaluator`**)
  - *(Not necessary with an expert policy such as a neural network)*
# Usage
Create a new Monte Carlo tree:
```python
from chess import Game
from montecarlo.node import Node
from montecarlo.montecarlo import MonteCarlo
chess_game = Game()
montecarlo = MonteCarlo(Node(chess_game))
```
The root node describes your current game state. This state will be used by you later in the **`child_finder`** and the **`node_evaluator`**.
For the sake of demonstration, we will assume you have a generic `Game` library that can tell you what moves are possible and allows you to perform those moves to change the game's state.
## Traditional Monte Carlo
Add a **`child_finder`** and a **`node_evaluator`**:
```python
def child_finder(node, montecarlo):
for move in node.state.get_possible_moves():
child = Node(deepcopy(node.state)) #or however you want to construct the child's state
child.state.move(move) #or however your library works
node.add_child(child)
def node_evaluator(node, montecarlo):
if node.state.won():
return 1
elif node.state.lost():
return -1
montecarlo.child_finder = child_finder
montecarlo.node_evaluator = node_evaluator
```
The **`child_finder`** should add any child nodes to the parent node passed into the function, if there are any. If there are none, the parent should be in an end state, so the **`node_evaluator`** should return a value between `-1` and `1`.
## Expert policy (AI)
If you have an expert policy that you can apply to the children as they're being generated, the library will recognize that it doesn't need to make the costly drill down to an end state. If your neural net produces both an expert policy value for the children and a win value for the parent node, you can skip declaring the `node_evaluator` altogether.
```python
def child_finder(node, montecarlo):
win_value, expert_policy_values = neural_network.predict(node.state)
for move in node.state.get_possible_moves():
child = Node(deepcopy(node.state))
child.state.move(move)
child.player_number = child.state.whose_turn()
child.policy_value = get_child_policy_value(child, expert_policy_values) #should return a probability value between 0 and 1
node.add_child(child)
node.update_win_value(win_value)
montecarlo.child_finder = child_finder
```
## Simulate and make a choice
Run the simulations:
```python
montecarlo.simulate(50) #number of expansions to run. higher is typically more accurate at the cost of processing time
```
Once the simulations have run you can ask the instance to make a choice:
```python
chosen_child_node = montecarlo.make_choice()
chosen_child_node.state.do_something()
```
After you've chosen a new root node, you can override it on the `montecarlo` instance and do more simulations from the new position in the tree.
```python
montecarlo.root_node = montecarlo.make_choice()
```
If you're training a neural network, you may want to make a more exploratory choice for the first N moves of a game:
```python
montecarlo.root_node = montecarlo.make_exploratory_choice()
```
This won't provide a purely random choice; rather, it will be random with a bias favoring the more explored pathways.
## Turn-based environments
If you are modeling a turn-based environment (e.g. a two player board game), set the `player_number` on each node so the selection process can invert child win values:
```python
node = Node(state)
node.player_number = 1
```
It doesn't matter what this number is (you can use 1 and 2 or 5 and 6), only that it is consistent with other nodes.
## Tweaking the discovery factor
When building a new child node, you can change the rate at which discovery is preferred:
```python
node = Node(state)
node.discovery_factor = 0.2 #0.35 by default, can be between 0 and 1
```
The closer this number is to 1, the more discovery will be favored over demonstrated value in later simulations.
## /detikzify/mcts/__init__.py
```py path="/detikzify/mcts/__init__.py"
```
## /detikzify/mcts/montecarlo.py
```py path="/detikzify/mcts/montecarlo.py"
import random
import time
class MonteCarlo:
def __init__(self, root_node, mins_timeout=None):
self.root_node = root_node
self.solution = None
self.child_finder = None
self.node_evaluator = lambda child, montecarlo: None
self.stats_expansion_count = 0
self.stats_failed_expansion_count = 0
self.mins_timeout = mins_timeout
def make_choice(self):
best_children = []
most_visits = float("-inf")
for child in self.root_node.children:
if child.visits > most_visits:
most_visits = child.visits
best_children = [child]
elif child.visits == most_visits:
best_children.append(child)
return random.choice(best_children)
def make_exploratory_choice(self):
children_visits = map(lambda child: child.visits, self.root_node.children)
children_visit_probabilities = [
visit / self.root_node.visits for visit in children_visits
]
random_probability = random.uniform(0, 1)
probabilities_already_counted = 0.0
for i, probability in enumerate(children_visit_probabilities):
if probabilities_already_counted + probability >= random_probability:
return self.root_node.children[i]
probabilities_already_counted += probability
def simulate(self, expansion_count=1):
i = 0
start_time = time.time()
while expansion_count is None or i < expansion_count:
i += 1
if self.solution is not None:
return
if self.mins_timeout is not None:
curr_time = time.time()
duration = curr_time - start_time
if duration > (self.mins_timeout * 60):
print("reached timelimit, stopping expansion on current node")
return
current_node = self.root_node
while current_node.expanded:
current_node = current_node.get_preferred_child(self.root_node)
self.expand(current_node)
def expand(self, node):
self.stats_expansion_count += 1
self.child_finder(node, self)
for child in node.children:
child_win_value = self.node_evaluator(child, self)
if child_win_value is not None:
child.update_win_value(child_win_value)
if not child.is_scorable():
self.random_rollout(child)
child.children = []
if len(node.children):
node.expanded = True
else:
self.stats_failed_expansion_count += 1
def random_rollout(self, node):
self.child_finder(node, self)
child = random.choice(node.children)
node.children = []
node.add_child(child)
child_win_value = self.node_evaluator(child, self)
if child_win_value is not None:
node.update_win_value(child_win_value)
else:
self.random_rollout(child)
def print_tree(self, f):
f.write("graph\n{\n")
self.root_node.print_node(f, 0, self.root_node, "a")
f.write("}\n")
```
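A toy, self-contained sketch (not from the repository) of how this class is driven; the counting game below is purely illustrative.
```python
# Toy example (illustrative only): states are integers, reaching 3 is a "win".
from detikzify.mcts.montecarlo import MonteCarlo
from detikzify.mcts.node import Node

def child_finder(node, montecarlo):
    if node.state < 3:
        node.add_child(Node(node.state + 1))

def node_evaluator(node, montecarlo):
    if node.state >= 3:
        return 1

montecarlo = MonteCarlo(Node(0))
montecarlo.child_finder = child_finder
montecarlo.node_evaluator = node_evaluator
montecarlo.simulate(25)
print(montecarlo.make_choice().state)  # 1, the only (and thus best) child
```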
## /detikzify/mcts/node.py
```py path="/detikzify/mcts/node.py"
import random
import json
from math import log, sqrt
class Node:
def __init__(self, state):
self.state = state
self.win_value = 0
self.policy_value = None
self.visits = 0
self.parent = None
self.children = []
self.expanded = False
self.player_number = None
self.discovery_factor = 0.35
self.is_widen_node = False
def update_win_value(self, value):
self.win_value += value
self.visits += 1
if self.parent:
self.parent.update_win_value(value)
def update_policy_value(self, value):
self.policy_value = value
def add_child(self, child):
self.children.append(child)
child.parent = self
def add_children(self, children):
for child in children:
self.add_child(child)
def get_preferred_child(self, root_node):
best_children = []
best_score = float("-inf")
for child in self.children:
score = child.get_score(root_node)
if score > best_score:
best_score = score
best_children = [child]
elif score == best_score:
best_children.append(child)
return random.choice(best_children)
def get_score(self, root_node):
discovery_operand = (
self.discovery_factor
* (self.policy_value or 1)
* sqrt(log(self.parent.visits) / (self.visits or 1))
)
if self.is_widen_node:
win_operand = 0
else:
win_multiplier = (
1 if self.parent.player_number == root_node.player_number else -1
)
win_operand = win_multiplier * self.win_value / (self.visits or 1)
self.score = win_operand + discovery_operand
return self.score
def is_scorable(self):
return self.visits or self.policy_value is not None
def print_node(self, f, i, root, st):
escape = lambda x : json.dumps(x).strip('"')
if self.parent is None:
f.write((' ' * i) + st + " [label=\"" + escape(self.state) + "\",shape=box]\n")
else:
diff = '\n'.join([x for x in self.state.split("\n") if x not in self.parent.state.split("\n")])
f.write((' ' * i) + st + " [label=\"" + escape(diff) + "\",shape=box]\n")
num = 0
for child in self.children:
new_st = st + "_" + str(num)
child.print_node(f, i + 2, root, new_st)
f.write(' ' * i + st + " -- " + new_st + "\n")
num = num + 1
```
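The selection rule in `get_score` is a UCT-style trade-off between exploitation and discovery; a rough, standalone illustration (not from the repository) with made-up numbers:
```python
# Made-up numbers, mirroring Node.get_score for a non-widen node whose parent
# shares the root's player_number (win_multiplier == 1).
from math import log, sqrt

win_value, visits, parent_visits = 3.0, 5, 20
discovery_factor, policy_value = 0.35, 1.0   # defaults: 0.35 and "or 1"

win_operand = win_value / visits  # exploitation term
discovery_operand = discovery_factor * policy_value * sqrt(log(parent_visits) / visits)
print(win_operand + discovery_operand)  # ~0.87
```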
## /detikzify/model/__init__.py
```py path="/detikzify/model/__init__.py"
from datasets import DownloadManager
from safetensors.torch import load_file
from transformers.utils.hub import has_file
from transformers import (
AutoConfig,
AutoModelForVision2Seq,
AutoProcessor,
is_timm_available,
)
from transformers.utils.hub import is_remote_url
from .configuration_detikzify import *
from .modeling_detikzify import *
from .processing_detikzify import *
from .adapter import load as load_adapter
if is_timm_available():
from .v1 import models as v1_models, load as load_v1
def register():
try:
AutoConfig.register("detikzify", DetikzifyConfig)
AutoModelForVision2Seq.register(DetikzifyConfig, DetikzifyForConditionalGeneration)
AutoProcessor.register(DetikzifyConfig, DetikzifyProcessor)
except ValueError:
pass # already registered
def load(model_name_or_path, modality_projector=None, is_v1=False, **kwargs):
# backwards compatibility with v1 models
if is_timm_available() and (is_v1 or model_name_or_path in v1_models): # type: ignore
model, tokenizer, image_processor = load_v1( # type: ignore
model_name_or_path=model_name_or_path,
modality_projector=modality_projector,
**kwargs
)
return model, DetikzifyProcessor(
tokenizer=tokenizer,
image_processor=image_processor,
image_seq_len=model.config.num_patches,
image_token=tokenizer.convert_ids_to_tokens(model.config.patch_token_id)
)
register()
processor = AutoProcessor.from_pretrained(model_name_or_path)
model = AutoModelForVision2Seq.from_pretrained(model_name_or_path, **kwargs)
if modality_projector is not None:
if is_remote_url(modality_projector):
modality_projector = DownloadManager().download(modality_projector)
model.load_state_dict(
state_dict=load_file(
filename=modality_projector, # type: ignore
device=str(model.device)
),
strict=False
)
if has_file(model_name_or_path, "adapter/model.safetensors"):
model, processor = load_adapter(model=model, processor=processor)
return model, processor
```
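A minimal loading sketch (placeholder checkpoint id); extra keyword arguments are forwarded to `AutoModelForVision2Seq.from_pretrained`:
```python
# Minimal sketch, not part of the repository; the model id is a placeholder.
import torch
from detikzify.model import load

model, processor = load(
    "path/or/hub-id-of-a-detikzify-checkpoint",
    torch_dtype=torch.bfloat16,   # forwarded to from_pretrained
    device_map="auto",
)
model.eval()
```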
## /detikzify/model/adapter/__init__.py
```py path="/detikzify/model/adapter/__init__.py"
from transformers import AutoTokenizer
from .modeling_adapter import CrossAttentionAdapterMixin
from .processing_adapter import AdapterProcessor
def has_adapter(model):
return hasattr(model, "adapter")
def load(model, processor, adapter_name_or_path=None, **kwargs):
embedding_model = "meta-llama/Llama-3.2-1B"
model.load_cross_attn_adapter(embedding_model, adapter_name_or_path, **kwargs)
processor = AdapterProcessor(
processor=processor,
tokenizer=AutoTokenizer.from_pretrained(
embedding_model,
pad_token="<|finetune_right_pad_id|>",
model_max_length=512,
),
)
model.embedding_model.config.pad_token_id = processor.tokenizer.pad_token_id
return model, processor
```
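A sketch (placeholder paths) of attaching the adapter so that textual instructions can be used alongside or instead of images; note that `load` pulls `meta-llama/Llama-3.2-1B` as the embedding model.
```python
# Minimal sketch, not part of the repository; checkpoint and adapter paths are
# placeholders. Loading fetches meta-llama/Llama-3.2-1B as the embedding model.
from detikzify.model import load
from detikzify.model.adapter import load as load_adapter

model, processor = load("path/or/hub-id-of-a-detikzify-checkpoint")
model, processor = load_adapter(model, processor, adapter_name_or_path="path/to/adapter")
```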
## /detikzify/model/adapter/modeling_adapter.py
```py path="/detikzify/model/adapter/modeling_adapter.py"
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Adapted from modelling_mllama.py and modelling_siglip.py
# https://github.com/huggingface/transformers/commit/2e24ee4dfa39cc0bc264b89edbccc373c8337086
from functools import partial
from os.path import basename
from typing import Optional, Tuple
import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel, PretrainedConfig
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.utils import (
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
)
if is_flash_attn_2_available():
from transformers.modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
class CrossAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: Optional[PretrainedConfig] = None,
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = self.config.num_attention_heads
self.head_dim = config.hidden_size // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_norm = nn.LayerNorm(self.head_dim, eps=config.layer_norm_eps)
self.k_norm = nn.LayerNorm(self.head_dim, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)
k_v_seq_len = key_states.shape[-2]
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
raise ValueError(
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
class CrossSdpaAttention(CrossAttention):
"""
Attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from `CrossAttention`
as the weights of the module stay untouched. The only changes are on the forward pass to adapt to the SDPA API.
"""
# Adapted from CrossAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"Using CrossSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
cross_attention_states=cross_attention_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=False,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None
class CrossFlashAttention2(CrossAttention):
"""
CrossAttention flash attention module. This module inherits from `CrossAttention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
# Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
output_attentions = False
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.dropout if self.training else 0.0
# In PEFT, we usually cast the layer norms to float32 for training stability,
# so the input hidden states get silently cast to float32. Hence, we need to
# cast them back to the correct dtype just to be sure everything works as expected.
# This might slow down training & inference, so it is recommended not to cast the LayerNorms
# in fp32.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
if attention_mask is not None and attention_mask.all():
# FIXME: figure out why all 1 attention mask leads to different results
attention_mask = None
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
is_causal=False,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights
CROSS_ATTENTION_CLASSES = {
"eager": CrossAttention,
"sdpa": CrossSdpaAttention,
"flash_attention_2": CrossFlashAttention2,
}
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class CrossAttentionLayer(torch.nn.Module):
"""Cross-attention transformer block with sigmoid-gated attention and feedforward."""
def __init__(self, config: PretrainedConfig) -> None:
super().__init__()
self.embed_dim = config.hidden_size
self.cross_attn = CROSS_ATTENTION_CLASSES[config._attn_implementation](config=config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = MLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))
self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1))
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: torch.Tensor,
cross_attention_mask: torch.Tensor,
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.cross_attn(
hidden_states=hidden_states,
attention_mask=cross_attention_mask,
cross_attention_states=cross_attention_states,
output_attentions=output_attentions,
)
hidden_states = residual + self.cross_attn_attn_gate.sigmoid() * hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + self.cross_attn_mlp_gate.sigmoid() * hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class CrossAttentionAdapter(PreTrainedModel):
base_model_prefix = "model"
no_split_modules = ["CrossAttentionLayer"]
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(self, config, input_hidden_size, cross_attn_every_n_layers=1):
super().__init__(config)
self.num_patches = (config.image_size // config.patch_size) ** 2
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layers = nn.ModuleList([ # type: ignore
CrossAttentionLayer(config)
if (layer_idx + 1) % cross_attn_every_n_layers == 0
else None
for layer_idx in range(config.num_hidden_layers)
])
self.connector = nn.Linear(
input_hidden_size,
config.hidden_size,
bias=True
)
self.dummy_input = nn.Parameter(
torch.ones(
config.num_channels,
config.image_size,
config.image_size
)
)
self.post_init()
def connect(self, inputs):
return self.connector(inputs)
def prepare_4d_attention_mask(self, attention_mask, dtype):
if attention_mask is not None and not self._use_flash_attention_2:
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
return _prepare_4d_attention_mask(attention_mask, dtype, self.num_patches)
return attention_mask
# pyright: reportAttributeAccessIssue=false
class CrossAttentionAdapterMixin:
def init_cross_attn_adapter(
self,
model_or_model_name_or_path,
cross_attn_every_n_layers: Optional[int] = 1,
embedding_kwargs: dict = dict(),
adapter_kwargs: dict = dict(),
**kwargs,
):
self.embedding_model = self.load_embedding_model(
model_or_model_name_or_path,
**embedding_kwargs,
**kwargs
).to(self.device)
self.adapter = CrossAttentionAdapter._from_config(
input_hidden_size=self.embedding_model.config.hidden_size,
cross_attn_every_n_layers=cross_attn_every_n_layers,
config=getattr(self.config, "vision_config", self.config),
torch_dtype=self.dtype,
**adapter_kwargs,
**kwargs
).to(self.device, self.dtype)
self.add_hooks()
def load_cross_attn_adapter(
self,
model_or_model_name_or_path,
adapter_name_or_path: Optional[str] = None,
cross_attn_every_n_layers: Optional[int] = 1,
embedding_kwargs: dict = dict(),
adapter_kwargs: dict = dict(),
**kwargs,
):
self.embedding_model = self.load_embedding_model(
model_or_model_name_or_path,
**embedding_kwargs,
**kwargs
)
if adapter_name_or_path is not None:
self.adapter = CrossAttentionAdapter.from_pretrained(
pretrained_model_name_or_path=adapter_name_or_path,
input_hidden_size=self.embedding_model.config.hidden_size,
cross_attn_every_n_layers=cross_attn_every_n_layers,
config=getattr(self.config, "vision_config", self.config),
torch_dtype=self.dtype,
**adapter_kwargs,
**kwargs
).to(self.dtype)
else:
self.adapter = CrossAttentionAdapter.from_pretrained(
pretrained_model_name_or_path=self.config.name_or_path,
input_hidden_size=self.embedding_model.config.hidden_size,
cross_attn_every_n_layers=cross_attn_every_n_layers,
config=getattr(self.config, "vision_config", self.config),
subfolder="adapter",
torch_dtype=self.dtype,
**adapter_kwargs,
**kwargs
).to(self.dtype)
if "device_map" not in kwargs:
self.embedding_model = self.embedding_model.to(self.device)
self.adapter = self.adapter.to(self.device)
self.add_hooks()
def load_embedding_model(self, model_or_model_name_or_path, **model_kwargs):
if isinstance(model_or_model_name_or_path, str):
model = AutoModel.from_pretrained(
pretrained_model_name_or_path=model_or_model_name_or_path,
torch_dtype=self.dtype,
**model_kwargs
)
else:
model = model_or_model_name_or_path
return model.to(self.dtype)
def add_hooks(self):
handles, adapter_inputs, cross_attention = list(), dict(), dict()
for name, module in self.named_modules():
if "vision_model" in name and type(module) == nn.ModuleList:
vision_layers = module
break
else:
raise ValueError("Couldn't locate vision encoder layers!")
# HACK: convert args to kwargs
def args_to_kwargs(module, args):
return dict(zip(type(module).forward.__code__.co_varnames[1:], args))
def forward_hook(layer, args, kwargs):
if kwargs.get("image_hidden_states") is not None:
# we are in .generate method which calls forward iteratively
for key in ["adapter_input_ids", "adapter_attention_mask"]:
kwargs.pop(key, None)
elif (adapter_input_ids:=kwargs.pop("adapter_input_ids", None)) is not None:
if not hasattr(self, "adapter"):
raise ValueError("Got `adapter_input_ids` but no adapter is loaded!")
adapter_inputs.update(
input_ids=adapter_input_ids,
attention_mask=kwargs.pop("adapter_attention_mask", None)
)
if (kwargs | args_to_kwargs(layer, args)).get("pixel_values") is None:
dummy_input = self.adapter.dummy_input.clamp(-1, 1)
kwargs['pixel_values'] = dummy_input.repeat(len(adapter_input_ids), 1, 1, 1)
return args, kwargs
for layer, cross_layer in zip(vision_layers, self.adapter.layers):
if cross_layer is not None:
def layer_hook(cross_layer, layer, args, kwargs):
if adapter_inputs:
embeddings = self.embedding_model(**adapter_inputs).last_hidden_state
cross_attention.update(
cross_attention_states=self.adapter.connect(embeddings),
cross_attention_mask=self.adapter.prepare_4d_attention_mask(
adapter_inputs["attention_mask"],
embeddings.dtype
))
adapter_inputs.clear()
if cross_attention:
kwargs |= args_to_kwargs(layer, args)
kwargs['hidden_states'] = cross_layer(**cross_attention, **kwargs)[0]
return tuple(), kwargs
handles.append(layer.register_forward_pre_hook(partial(layer_hook, cross_layer), with_kwargs=True))
handles.append(self.register_forward_pre_hook(forward_hook, with_kwargs=True))
handles.append(self.register_forward_hook(lambda *_: cross_attention.clear()))
self.handles = handles
def unload_cross_attn_adapter(self):
for handle in self.handles:
handle.remove()
del self.adapter, self.embedding_model, self.handles
def save_cross_attn_adapter(self, *args, **kwargs):
return self.adapter.save_pretrained(*args, **kwargs)
def has_adapter(self):
return hasattr(self, "adapter")
```
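For intuition, the `cross_attn_every_n_layers` argument of `CrossAttentionAdapter` interleaves cross-attention layers with the vision encoder; a small standalone illustration (the value 4 is an arbitrary choice):
```python
# Which of the 27 default vision layers receive a CrossAttentionLayer when
# cross_attn_every_n_layers=4 (arbitrary choice; the loader defaults to 1).
num_hidden_layers, every_n = 27, 4
cross_attn_layers = [i for i in range(num_hidden_layers) if (i + 1) % every_n == 0]
print(cross_attn_layers)  # [3, 7, 11, 15, 19, 23]
```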
## /detikzify/model/adapter/processing_adapter.py
```py path="/detikzify/model/adapter/processing_adapter.py"
from typing import List, Optional, TYPE_CHECKING, Union, Unpack
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, make_list_of_images
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
from transformers.tokenization_utils_base import (
BatchEncoding,
PreTokenizedInput,
TextInput,
)
from transformers.utils import logging
from ...util import DUMMY_IMAGE
if TYPE_CHECKING:
from transformers.tokenization_utils_base import PreTokenizedInput
logger = logging.get_logger(__name__)
class AdapterProcessor(ProcessorMixin):
attributes = ["processor", "tokenizer"]
processor_class = ("ProcessorMixin", "ImageProcessingMixin")
tokenizer_class = "AutoTokenizer"
def __init__(self, processor, tokenizer=None, **kwargs):
if processor is None:
raise ValueError("You need to specify a `processor`.")
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")
super().__init__(processor, tokenizer, **kwargs)
def __call__(
self,
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
images: Optional[ImageInput] = None,
**kwargs: Unpack[ProcessingKwargs],
) -> BatchEncoding:
if images is None and text is None:
raise ValueError("Either `images` or `text` (or both) are expected as arguments to an `AdapterProcessor` instance.")
text_kwargs, images_kwargs = kwargs.pop("text_kwargs", {}), kwargs.pop("images_kwargs", {})
if text is None:
text_inputs = dict()
else:
text = [text] if isinstance(text, str) else text
text_inputs = {f"adapter_{key}": value for key, value in self.tokenizer(text=text, **kwargs, **text_kwargs).items()}
if getattr(self.processor, "model_expects_text", False):
images_kwargs.update(text=text, add_bos_token=True)
if images is None:
image_inputs = self.processor(images=len(text) * [DUMMY_IMAGE], **kwargs, **images_kwargs)
image_inputs = dict((k, image_inputs[k]) for k in ["input_ids", "attention_mask"] if k in image_inputs)
else:
images = make_list_of_images(images)
image_inputs = self.processor(images=images, **kwargs, **images_kwargs)
if text is not None and images is not None and len(images) != len(text):
raise ValueError(
f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image."
)
return BatchFeature(data={**image_inputs, **text_inputs})
def batch_decode(self, *args, **kwargs):
return self.processor.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
return self.processor.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
processor_input_names = self.processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + processor_input_names))
```
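A usage sketch (not part of the repository), assuming `processor` is the `AdapterProcessor` returned by `detikzify.model.adapter.load`:
```python
# Sketch only: `processor` is assumed to be an AdapterProcessor instance.
from PIL import Image

inputs = processor(
    images=Image.new("RGB", (420, 420), "white"),
    text="A flowchart with three boxes connected by arrows.",
    text_kwargs={"padding": True, "truncation": True},
)
# Keys prefixed with "adapter_" are routed to the embedding model via the
# forward hooks of CrossAttentionAdapterMixin; the rest are regular inputs.
print(sorted(inputs.keys()))
```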
## /detikzify/model/configuration_detikzify.py
```py path="/detikzify/model/configuration_detikzify.py"
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Adapted from
# https://github.com/huggingface/transformers/commit/e1b150862e66e16acf951edfa13206ffcd1032be
import os
from typing import Union
from transformers import CONFIG_MAPPING
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class DetikzifyVisionConfig(PretrainedConfig):
model_type = "detikzify"
def __init__(
self,
hidden_size=1152,
intermediate_size=4304,
num_hidden_layers=27,
num_attention_heads=16,
num_channels=3,
image_size=420,
patch_size=14,
hidden_act="gelu_pytorch_tanh",
layer_norm_eps=1e-6,
attention_dropout=0.0,
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.initializer_range = initializer_range
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from DetikzifyConfig
if config_dict.get("model_type") == "detikzify":
config_dict = config_dict["vision_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
class DetikzifyConfig(PretrainedConfig):
model_type = "detikzify"
is_composition = True
def __init__(
self,
use_cache=True,
image_token_id=128005,
tie_word_embeddings=False,
vision_config=None,
text_config=None,
concat_factor=3,
pad_token_id=128004,
**kwargs,
):
self.image_token_id = image_token_id
self.use_cache = use_cache
self.tie_word_embeddings = tie_word_embeddings
if vision_config is None:
self.vision_config = DetikzifyVisionConfig()
logger.info("vision_config is None, using default vision config")
elif isinstance(vision_config, dict):
self.vision_config = DetikzifyVisionConfig(**vision_config)
elif isinstance(vision_config, DetikzifyVisionConfig):
self.vision_config = vision_config
if isinstance(text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
elif text_config is None:
logger.info("text_config is None, using default text config")
text_config = CONFIG_MAPPING["llama"](
rms_norm_eps=1e-5,
pad_token_id=pad_token_id,
tie_word_embeddings=False,
)
self.text_config = text_config
self.concat_factor = concat_factor
super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
```
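With the defaults above, the number of image tokens the text model sees can be derived directly from the config; a quick sanity-check sketch:
```python
# Quick sanity check of the default geometry (not part of the repository).
from detikzify.model import DetikzifyConfig

config = DetikzifyConfig()
patches = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
print(patches)                          # 900 vision patches (420 / 14, squared)
print(patches // config.concat_factor)  # 300 image tokens after the connector
```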
## /detikzify/model/modeling_detikzify.py
```py path="/detikzify/model/modeling_detikzify.py"
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Adapted from
# https://github.com/huggingface/transformers/commit/e1b150862e66e16acf951edfa13206ffcd1032be
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.utils.checkpoint
from transformers import (
AutoModel,
Cache,
DynamicCache,
GenerationMixin,
PreTrainedModel,
SiglipVisionModel,
)
from transformers.modeling_outputs import ModelOutput
from transformers.utils import logging
from .adapter import CrossAttentionAdapterMixin
from .configuration_detikzify import DetikzifyConfig
logger = logging.get_logger(__name__)
@dataclass
class DetikzifyBaseModelOutputWithPast(ModelOutput):
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class DetikzifyCausalLMOutputWithPast(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class DetikzifySimpleMLP(nn.Module):
def __init__(self, config):
super().__init__()
input_size = config.vision_config.hidden_size * config.concat_factor
output_size = config.text_config.hidden_size
self.proj = nn.Linear(input_size, output_size, bias=False)
def forward(self, x):
return self.proj(x)
class DetikzifyConnector(nn.Module):
def __init__(self, config):
super().__init__()
self.concat_factor = config.concat_factor
self.modality_projection = DetikzifySimpleMLP(config)
def concatenate(self, x, concat_factor=3):
bsz, seq, embed_dim = x.size()
return x.reshape(bsz, seq // concat_factor, embed_dim * concat_factor)
def forward(self, image_hidden_states):
image_hidden_states = self.concatenate(image_hidden_states, self.concat_factor)
image_hidden_states = self.modality_projection(image_hidden_states)
return image_hidden_states
class DetikzifyPreTrainedModel(PreTrainedModel):
config_class = DetikzifyConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = []
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class DetikzifyModel(DetikzifyPreTrainedModel):
def __init__(self, config: DetikzifyConfig):
super().__init__(config)
self.padding_idx = self.config.text_config.pad_token_id
self.vocab_size = self.config.text_config.vocab_size
self.vision_model = SiglipVisionModel._from_config(config.vision_config)
self.connector = DetikzifyConnector(config)
self.text_model = AutoModel.from_config(config.text_config)
self.image_seq_len = int(
((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.concat_factor)
)
self.image_token_id = self.config.image_token_id
self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2"
self.post_init()
def enable_input_require_grads(self):
def get_lowest_module(module):
if len(list(module.children())) == 0:
# If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
return module
else:
# Recursively call the function on each child module
return get_lowest_module(list(module.children())[0])
def make_inputs_require_grads(module, input, output):
output.requires_grad_(True)
self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
make_inputs_require_grads
)
def disable_input_require_grads(self):
self._text_require_grads_hook.remove()
self._vision_require_grads_hook.remove()
def get_input_embeddings(self):
return self.text_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.text_model.set_input_embeddings(value)
def inputs_merger(
self,
input_ids: torch.LongTensor,
inputs_embeds: Optional[torch.Tensor],
image_hidden_states: Optional[torch.Tensor],
):
num_images, _, vision_hidden_size = image_hidden_states.shape
special_image_token_mask = input_ids == self.image_token_id
# Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
new_inputs_embeds = inputs_embeds.clone()
reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size)
# cast to the dtype of the input_embeds to support quantized models
reshaped_image_hidden_states = reshaped_image_hidden_states.to(inputs_embeds.dtype)
new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
return new_inputs_embeds
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
image_hidden_states: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, DetikzifyBaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.training and self.text_model.gradient_checkpointing and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# retrieve input_ids and inputs_embeds
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
past_seen_tokens = 0
if use_cache:
if past_key_values is None:
past_key_values = DynamicCache()
past_seen_tokens = past_key_values.get_seq_length()
if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")
if inputs_embeds is None:
inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device)
# START VISUAL INPUTS INTEGRATION
if pixel_values is not None and image_hidden_states is not None:
raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
elif pixel_values is not None:
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values.to(dtype=self.dtype), # fp16 compatibility
).last_hidden_state
# Modality projection
image_hidden_states = self.connector(image_hidden_states)
elif image_hidden_states is not None:
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
# Only merge image features on the first forward pass; during generation we don't want to
# replace image_token_ids that the model itself produced with images that simply don't exist
inputs_embeds = self.inputs_merger(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
image_hidden_states=image_hidden_states,
)
outputs = self.text_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
return DetikzifyBaseModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_hidden_states,
)
class DetikzifyForConditionalGeneration(DetikzifyPreTrainedModel, GenerationMixin, CrossAttentionAdapterMixin):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = DetikzifyModel(config)
self.image_token_id = self.config.image_token_id
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.vocab_size = config.text_config.vocab_size
# Initialize weights and apply final processing
self.post_init()
def enable_input_require_grads(self):
def make_inputs_require_grads(module, input, output):
output.requires_grad_(True)
self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook(
make_inputs_require_grads
)
def disable_input_require_grads(self):
self._text_require_grads_hook.remove()
self._vision_require_grads_hook.remove()
def get_input_embeddings(self):
return self.model.text_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.text_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def tie_weights(self):
output_embeddings = self.get_output_embeddings()
input_embeddings = self.get_input_embeddings()
if getattr(self.config, "tie_word_embeddings", True):
output_embeddings.weight = input_embeddings.weight
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
image_hidden_states: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, DetikzifyCausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
pixel_values=pixel_values,
image_hidden_states=image_hidden_states,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return DetikzifyCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
pixel_values=None,
image_hidden_states=None,
num_logits_to_keep=None,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
# but IDEFICS requires both ids and embeds to be present
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep
if image_hidden_states is not None:
pixel_values = None
# support model.generate method with adapters
if self.has_adapter():
for key in ["adapter_input_ids", "adapter_attention_mask"]:
if key in kwargs:
model_inputs[key] = kwargs[key]
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"image_hidden_states": image_hidden_states,
}
)
return model_inputs
def _validate_model_kwargs(self, model_kwargs):
# support model.generate method with adapters
if self.has_adapter():
for key in ['adapter_input_ids', 'adapter_attention_mask']:
model_kwargs.pop(key, None)
super()._validate_model_kwargs(model_kwargs)
def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs,
model_kwargs=model_kwargs,
is_encoder_decoder=is_encoder_decoder,
**kwargs,
)
# Get the precomputed image_hidden_states
model_kwargs["image_hidden_states"] = outputs.image_hidden_states
return model_kwargs
```
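The classes above follow the standard `PreTrainedModel`/`GenerationMixin` API, so inference reduces to preparing inputs with `DetikzifyProcessor` and calling `generate`. The following is a minimal sketch; the checkpoint identifier is a placeholder, and it assumes a published checkpoint that bundles both the processor and the model configuration.

```py
# Minimal inference sketch. CHECKPOINT is a placeholder, not a real model id.
from PIL import Image
from detikzify.model.modeling_detikzify import DetikzifyForConditionalGeneration
from detikzify.model.processing_detikzify import DetikzifyProcessor

CHECKPOINT = "path/to/a/detikzify/checkpoint"  # placeholder

processor = DetikzifyProcessor.from_pretrained(CHECKPOINT)
model = DetikzifyForConditionalGeneration.from_pretrained(CHECKPOINT)

image = Image.open("figure.png").convert("RGB")
# The processor prepends `image_seq_len` image tokens to the (empty) prompt and
# returns pixel_values next to input_ids/attention_mask.
inputs = processor(images=image, return_tensors="pt")

# pixel_values are encoded by the vision tower, projected by the connector, and
# merged into the input embeddings on the first forward pass.
output_ids = model.generate(**inputs, max_new_tokens=1024)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
```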
## /detikzify/model/processing_detikzify.py
```py path="/detikzify/model/processing_detikzify.py"
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Adapted from
# https://github.com/huggingface/transformers/commit/e1b150862e66e16acf951edfa13206ffcd1032be
from typing import List, Optional, Union, Unpack
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, make_list_of_images
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
from transformers.tokenization_utils_base import (
BatchEncoding,
PreTokenizedInput,
TextInput,
)
from transformers.utils import logging
logger = logging.get_logger(__name__)
class DetikzifyProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"add_special_tokens": False,
"padding": False,
},
}
class DetikzifyProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor,
tokenizer=None,
image_seq_len: int = 300,
image_token: str = "<|reserved_special_token_2|>",
model_expects_text: bool = False,
**kwargs,
):
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")
if image_token not in tokenizer.vocab:
raise ValueError(f"{image_token} needs to be added to the `tokenizer` vocabulary.")
self.image_token = image_token
self.image_seq_len = image_seq_len
self.model_expects_text = model_expects_text
super().__init__(image_processor, tokenizer, **kwargs)
def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
images: ImageInput = None,
image_seq_len: Optional[int] = None,
add_bos_token: bool = None,
add_eos_token: bool = None,
**kwargs: Unpack[DetikzifyProcessorKwargs],
) -> BatchEncoding:
output_kwargs = self._merge_kwargs(
DetikzifyProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
# Temporary fix for "padding_side" in init_kwargs
output_kwargs["text_kwargs"].pop("padding_side", None)
if images is None:
raise ValueError("`images` are expected as arguments to a `DetikzifyProcessor` instance.")
else:
images = make_list_of_images(images)
if text is None:
text = len(images) * [""]
elif isinstance(text, str):
text = [text]
if len(images) != len(text):
raise ValueError(
f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image."
)
prompt_strings = []
for prompt in text:
assert self.image_token not in prompt, "Image tokens are added by the processor!"
if add_bos_token:
prompt += self.tokenizer.bos_token
if add_eos_token:
prompt += self.tokenizer.eos_token
image_seq_len = image_seq_len if image_seq_len is not None else self.image_seq_len
prompt_strings.append((self.image_token * image_seq_len) + prompt)
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"])
return BatchFeature(data={**image_inputs, **text_inputs})
def batch_decode(self, *args, **kwargs):
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
```
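`DetikzifyProcessor` itself only couples an image processor with a tokenizer and expands the image placeholder token. The sketch below illustrates that behaviour; the tokenizer and image-processor identifiers are placeholders, and the image token must already exist in the tokenizer vocabulary.

```py
# Sketch of DetikzifyProcessor input assembly; the *_ID names are placeholders.
from PIL import Image
from transformers import AutoImageProcessor, AutoTokenizer
from detikzify.model.processing_detikzify import DetikzifyProcessor

TOKENIZER_ID = "some/llama-3-style-tokenizer"        # placeholder
IMAGE_PROCESSOR_ID = "some/siglip-image-processor"   # placeholder

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
image_processor = AutoImageProcessor.from_pretrained(IMAGE_PROCESSOR_ID)

processor = DetikzifyProcessor(
    image_processor=image_processor,
    tokenizer=tokenizer,
    image_seq_len=300,  # must match DetikzifyModel.image_seq_len
)

image = Image.new("RGB", (420, 420), "white")
batch = processor(images=image, text="", return_tensors="pt")

# Each prompt is prefixed with `image_seq_len` copies of the image token, which
# DetikzifyModel.inputs_merger later replaces with projected vision features.
image_token_id = tokenizer.convert_tokens_to_ids(processor.image_token)
print(int((batch["input_ids"][0] == image_token_id).sum()))  # -> 300
```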
## /detikzify/model/v1/__init__.py
```py path="/detikzify/model/v1/__init__.py"
from datasets import DownloadManager
from transformers import AutoConfig, AutoModel
from transformers import AutoTokenizer, PretrainedConfig
from transformers.utils.hub import is_remote_url
from .configuration_detikzify import *
from .modeling_detikzify import *
from .processing_detikzify import *
models = [
"nllg/detikzify-ds-1.3b",
"nllg/detikzify-ds-7b",
"nllg/detikzify-tl-1.1b",
"nllg/detikzify-cl-7b",
]
def register():
try:
AutoConfig.register("detikzify", DetikzifyConfig)
AutoModel.register(DetikzifyConfig, DetikzifyForCausalLM)
except ValueError:
pass # already registered
def load(model_name_or_path, vision_tower="vit_so400m_patch14_siglip_384.webli", modality_projector=None, **kwargs):
base_tokenizer = PretrainedConfig.from_pretrained(model_name_or_path).name_or_path or model_name_or_path
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=base_tokenizer,
model_max_length=2048,
add_bos_token=False,
add_eos_token=True,
pad_token="",
padding_side="right", # NOTE: only for training, need to change to "left" for batched inference
legacy=False
)
model = DetikzifyForCausalLM.from_pretrained(
pretrained_model_name_or_path=model_name_or_path,
use_cache=True,
**kwargs
)
model.config.model_type = DetikzifyConfig.model_type # type: ignore
model.generation_config.pad_token_id = tokenizer.pad_token_id # type: ignore
if len(tokenizer) > model.config.vocab_size: # type: ignore
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8) # type: ignore
if modality_projector and is_remote_url(modality_projector):
modality_projector = DownloadManager().download(modality_projector)
processor = model.get_model().initialize_vision_modules( # type: ignore
patch_token_id=tokenizer.bos_token_id,
modality_projector=modality_projector,
vision_tower=getattr(model.config, "vision_tower", vision_tower), # type: ignore
feature_layer=getattr(model.config, "feature_layer", -1), # type: ignore
concat_patches=getattr(model.config, "concat_patches", 3) # type: ignore
)
return model, tokenizer, processor
```
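For the legacy v1 checkpoints, `load` wires tokenizer, vision tower, and modality projector together. A minimal sketch:

```py
# Minimal v1 loading sketch; weights are fetched from the Hugging Face Hub.
from detikzify.model.v1 import load, register

register()  # registers the "detikzify" model type with AutoConfig/AutoModel
model, tokenizer, processor = load(
    "nllg/detikzify-ds-1.3b",  # any entry of the `models` list above
    torch_dtype="auto",        # extra kwargs are forwarded to from_pretrained
)
```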
## /detikzify/model/v1/configuration_detikzify.py
```py path="/detikzify/model/v1/configuration_detikzify.py"
from transformers import LlamaConfig
class DetikzifyConfig(LlamaConfig):
model_type = "detikzify"
# compatibility with new inference code
@property
def image_token_id(self):
return self.patch_token_id
@property
def pooling_mode(self):
return "cos"
```
## /detikzify/model/v1/modeling_detikzify.py
```py path="/detikzify/model/v1/modeling_detikzify.py"
# Adapted from https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava.py. Below is the original copyright:
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pickle import UnpicklingError
from typing import List, Optional, Tuple, Union
from numpy import clip
from safetensors.torch import load_file
from timm import create_model as create_vision_model
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from transformers import (
BatchEncoding,
LlamaConfig,
LlamaForCausalLM,
LlamaModel,
PretrainedConfig,
PreTrainedModel,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPoolingAndNoAttention,
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.utils import logging
from .configuration_detikzify import DetikzifyConfig
from .processing_detikzify import DetikzifyImageProcessor
logger = logging.get_logger("transformers")
class DetikzifyVisionModel(PreTrainedModel):
_no_split_modules = ["VisionTransformer"]
def __init__(self, model, **kwargs) -> None:
super().__init__(PretrainedConfig.from_dict(model.pretrained_cfg), **kwargs)
# HACK: wrap in a list so the vision model is not registered as a submodule and its weights do not count as parameters
self.model = [model]
def get_input_embeddings(self) -> torch.nn.Module:
return self.model[0].patch_embed
def to_input_dtype(self, pixel_values: torch.Tensor):
target_dtype = self.get_input_embeddings().proj.weight.dtype
return pixel_values.to(dtype=target_dtype)
def forward(self, pixel_values: torch.Tensor):
last_hidden_state = self.model[0].forward_features(self.to_input_dtype(pixel_values))
pooler_output = self.model[0].forward_head(last_hidden_state)
return BaseModelOutputWithPoolingAndNoAttention(
last_hidden_state=last_hidden_state,
pooler_output=pooler_output
)
def get_intermediate_layers(self, pixel_values: torch.Tensor, *args, **kwargs):
return self.model[0].get_intermediate_layers(self.to_input_dtype(pixel_values), *args, **kwargs)
class DetikzifyModel(LlamaModel):
config_class = DetikzifyConfig
def __init__(self, config: LlamaConfig):
super(DetikzifyModel, self).__init__(config)
if getattr(config, "use_mm_proj"):
self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
self.vision_model = DetikzifyVisionModel(create_vision_model(config.vision_tower))
def initialize_vision_modules(
self,
vision_tower,
patch_token_id,
concat_patches=3,
feature_layer=-1,
modality_projector=None,
**kwargs
):
vision_model = create_vision_model(vision_tower, pretrained=True, **kwargs)
self.vision_model = DetikzifyVisionModel(vision_model.to(self.device, self.dtype).eval().requires_grad_(False))
processor = DetikzifyImageProcessor.from_pretrained(vision_tower)
self.config.use_mm_proj = True
self.config.vision_tower = vision_tower
self.config.mm_hidden_size = vision_model.embed_dim * concat_patches
self.config.patch_token_id = patch_token_id
self.config.concat_patches = concat_patches
self.config.feature_layer = int(clip(feature_layer, -(depth:=len(vision_model.blocks)), depth-1) % depth)
self.config.vision_config = processor.to_dict() # type: ignore
self.config.num_patches = vision_model.patch_embed.num_patches // concat_patches
if not hasattr(self, 'mm_projector'):
self.mm_projector = nn.Linear(
self.config.mm_hidden_size,
self.config.hidden_size,
dtype=self.dtype,
device=self.device
)
if modality_projector is not None:
try: # first try to load as pickle
mm_projector_weights = torch.load(modality_projector, map_location=self.device)
except UnpicklingError: # and if that fails we try safetensors
mm_projector_weights = load_file(modality_projector, device=str(self.device))
self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})
return processor
# https://stackoverflow.com/a/57208704
def _apply(self, fn):
super()._apply(fn)
if hasattr(self, "vision_model"):
self.set_vision_model = self.vision_model._apply(fn)
return self
def get_vision_features(self, pixel_values):
concat, n_patch, layer = self.config.concat_patches, self.config.num_patches, self.config.feature_layer
feats = self.vision_model.get_intermediate_layers(pixel_values, n=[layer], norm=True)[0]
# in case the number of feature vectors is not divisible by the number
# of patches we want to concatenate, we remove the first feature(s)
return feats[:, -n_patch * concat:].reshape(-1, n_patch, feats.shape[-1] * concat)
def is_tensor(self, thing):
if isinstance(thing, (BatchEncoding, dict)):
return all(isinstance(v, torch.Tensor) for v in thing.values())
return isinstance(thing, torch.Tensor)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
pixel_values: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if hasattr(self, "vision_model") and (input_ids.shape[1] != 1 or self.training) and pixel_values is not None:
with torch.no_grad():
image_features = self.get_vision_features(pixel_values)
image_features = self.mm_projector(image_features)
dummy_image_features = torch.zeros(len(image_features[0]), self.config.mm_hidden_size, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
dummy_image_features = self.mm_projector(dummy_image_features)
new_input_embeds = []
cur_image_idx = 0
for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
if (cur_input_ids == self.config.image_token_id).sum() == 0:
# multimodal LLM, but the current sample is not multimodal
cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
new_input_embeds.append(cur_input_embeds)
cur_image_idx += 1
continue
cur_image_features = image_features[cur_image_idx].to(cur_input_embeds.device)
num_patches = cur_image_features.shape[0]
if (cur_input_ids == self.config.image_token_id).sum() != num_patches:
raise ValueError("The number of image patch tokens should be the same as the number of image patches.")
masked_indices = torch.where(cur_input_ids == self.config.image_token_id)[0]
mask_index_start = masked_indices[0]
if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any():
raise ValueError("The image patch tokens should be consecutive.")
cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0)
new_input_embeds.append(cur_new_input_embeds)
cur_image_idx += 1
inputs_embeds = torch.stack(new_input_embeds, dim=0)
return super(DetikzifyModel, self).forward(
input_ids=None,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
class DetikzifyForCausalLM(LlamaForCausalLM):
config_class = DetikzifyConfig
def __init__(self, config):
super(LlamaForCausalLM, self).__init__(config)
self.model = DetikzifyModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
pixel_values: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
pixel_values=pixel_values
)
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model/pipeline parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values:
input_ids = input_ids[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"pixel_values": kwargs.get("pixel_values", None),
}
)
return model_inputs
```
## /detikzify/model/v1/processing_detikzify.py
```py path="/detikzify/model/v1/processing_detikzify.py"
# Adapted from https://github.com/huggingface/optimum-intel/blob/main/optimum/intel/openvino/modeling_timm.py. Below is the original copyright:
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict, List, Optional, Union
import numpy as np
from timm.data import resolve_data_config
from timm.models import resolve_pretrained_cfg
from transformers.image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_size_dict,
)
from transformers.image_transforms import resize, to_channel_dimension_format
from transformers.image_utils import (
ChannelDimension,
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageInput,
PILImageResampling,
make_list_of_images,
to_numpy_array,
valid_images,
)
from transformers.utils import TensorType
class DetikzifyImageProcessor(BaseImageProcessor):
r"""
Constructs a ViT image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
`preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
`preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size)
self.do_resize = do_resize
self.do_rescale = do_rescale
self.do_normalize = do_normalize
self.size = size
self.resample = resample
self.rescale_factor = rescale_factor
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
**kwargs,
):
pretrained_cfg = resolve_pretrained_cfg(variant=pretrained_model_name_or_path)
timm_config_dict = resolve_data_config(pretrained_cfg.to_dict())
_, im_h, im_w = timm_config_dict.get("input_size", [3, 224, 224])
image_preprocess_config_dict = {
"crop_size": {"height": im_h, "width": im_w},
"do_center_crop": True if timm_config_dict.get("crop_mode") == "center" else False,
"do_normalize": True,
"do_reduce_labels": False,
"do_rescale": True,
"do_resize": True,
"image_mean": timm_config_dict.get("mean", IMAGENET_STANDARD_MEAN),
"image_processor_type": "TimmImageProcessor",
"image_std": timm_config_dict.get("std", IMAGENET_STANDARD_STD),
"resample": 3,
"rescale_factor": 0.00392156862745098,
"size": {"height": im_h, "width": im_w},
}
return cls.from_dict(image_preprocess_config_dict, **kwargs)
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
resample:
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The resized image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
if image.ndim == 2:
image = np.stack([image] * 3, axis=-1)
return resize(
image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
)
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
**kwargs,
):
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
resizing.
resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
`PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use if `do_normalize` is set to `True`.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
resample = resample if resample is not None else self.resample
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
size = size if size is not None else self.size
size_dict = get_size_dict(size)
images = make_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if do_resize:
images = [self.resize(image=image, size=size_dict, resample=resample) for image in images]
if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
```
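The v1 image processor derives its configuration from the timm data config of the vision tower. A small usage sketch (the tower name matches the default in `detikzify/model/v1/__init__.py`):

```py
# Sketch: build the v1 image processor from a timm vision tower and preprocess one image.
from PIL import Image
from detikzify.model.v1.processing_detikzify import DetikzifyImageProcessor

processor = DetikzifyImageProcessor.from_pretrained("vit_so400m_patch14_siglip_384.webli")
pixel_values = processor(Image.open("figure.png").convert("RGB"), return_tensors="pt")["pixel_values"]
print(pixel_values.shape)  # expected (1, 3, 384, 384) for this tower
```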
## /detikzify/train/__init__.py
```py path="/detikzify/train/__init__.py"
from .pretrain import train as pretrain
from .train import train
```
## /detikzify/train/adapter/__init__.py
```py path="/detikzify/train/adapter/__init__.py"
from transformers import SiglipVisionModel
from ...model.adapter import CrossAttentionAdapterMixin
from .pretrain import train as pretrain
from .train import train
class CrossAttentionSiglipVisionModel(SiglipVisionModel, CrossAttentionAdapterMixin):
...
```
## /detikzify/train/adapter/pretrain.py
```py path="/detikzify/train/adapter/pretrain.py"
import os
from typing import Dict, Literal
import torch
from torch.utils.data import Dataset
from torchvision.transforms import v2
from transformers import (
Trainer,
TrainerCallback,
TrainingArguments,
is_torch_xla_available,
)
from transformers.trainer_utils import SaveStrategy, get_last_checkpoint
from transformers.utils import logging
from ...model.adapter.modeling_adapter import CrossAttentionAdapterMixin
from ...util import (
EditCutMix,
EditCutOut,
EditMixUp,
FullErase,
SketchAugment,
SplitEpochSaveCallback,
unwrap_processor,
)
if is_torch_xla_available():
import torch_xla.core.xla_model as xm # type: ignore
logger = logging.get_logger("transformers")
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
class EmbeddingSimilarityLoss():
def __init__(self, elementwise: bool = True, use_mse: bool = False):
self.cosine = torch.nn.CosineSimilarity(dim=-1)
self.mae = torch.nn.L1Loss(reduction="none")
self.mse = torch.nn.MSELoss(reduction="none")
self.elementwise = elementwise
self.use_mse = use_mse
# https://github.com/pytorch/pytorch/issues/104564#issuecomment-1651575112
@torch.compile
def cosine_loss(self, x, y):
if self.elementwise:
return self.mae(cos:=self.cosine(x, y), torch.ones_like(cos)).mean()
else:
X = self.cosine(x.unsqueeze(2), y.unsqueeze(1))
Y = self.cosine(y.unsqueeze(2), y.unsqueeze(1))
return self.mae(X, Y).max(dim=-1)[0].mean()
@torch.compile
def l2_loss(self, x, y):
if self.elementwise:
return self.mse(x, y).mean()
else:
X, Y = torch.cdist(x, y), torch.cdist(y, y)
return self.mae(X, Y).max(dim=-1)[0].mean()
def __call__(self, x, y):
if self.use_mse:
return self.l2_loss(x, y)
else:
return self.cosine_loss(x, y)
# https://huggingface.co/docs/transformers/main/en/tasks/knowledge_distillation_for_image_classification
class AdapterTrainer(Trainer):
def __init__(
self,
model: CrossAttentionAdapterMixin,
loss_term: Literal["avg", "pool", "patch", "layer"] = "patch",
elementwise_loss: bool = True,
mse_loss: bool = False,
pool_train_head: bool = False,
multimodal: bool = False,
*args,
**kwargs,
):
self.term = loss_term
self.loss_function = EmbeddingSimilarityLoss(elementwise=elementwise_loss, use_mse=mse_loss)
train_head = self.term == "pool" and pool_train_head
super().__init__(self.prepare_model(model, train_head, multimodal), *args, **kwargs) # type: ignore
if self.term == "layer":
self.loss_layers = sorted({len(self.model.adapter.layers)} | {
idx for idx, layer in enumerate(self.model.adapter.layers, 1) if layer is not None
})
self.control.layer_losses = {layer: 0 for layer in self.loss_layers}
def prepare_model(self, model, train_head=False, multimodal=False):
for name, param in model.named_parameters():
if not "adapter" in name and (not train_head or not "head" in name):
param.requires_grad = False
elif multimodal and "dummy_input" in name:
param.requires_grad = False
elif model.dtype != torch.float32:
param.data = param.data.to(torch.float32)
if train_head: # in this case we also want gradients for the teacher
model.vision_model.head.forward = torch.enable_grad(model.vision_model.head.forward)
if self.term != "pool":
model.vision_model.use_head = False
model.embedding_model.enable_input_require_grads()
model.enable_input_require_grads()
return model
def compute_loss(self, model, inputs, return_outputs=False, **_):
with torch.no_grad():
teacher_output = model(
pixel_values=inputs.pop("labels"),
output_hidden_states=self.term=="layer"
)
student_output = model(
output_hidden_states=self.term=="layer",
**inputs,
)
if self.term == "avg":
loss = self.loss_function(
student_output.last_hidden_state.mean(dim=1),
teacher_output.last_hidden_state.mean(dim=1),
)
elif self.term == "pool":
loss = self.loss_function(
student_output.pooler_output,
teacher_output.pooler_output,
)
elif self.term == "patch":
loss = self.loss_function(
student_output.last_hidden_state,
teacher_output.last_hidden_state
)
else:
loss = 0
for layer in self.loss_layers:
last_layer = layer == self.loss_layers[-1]
layer_loss = self.loss_function(
student_output.last_hidden_state if last_layer else student_output.hidden_states[layer],
teacher_output.last_hidden_state if last_layer else teacher_output.hidden_states[layer]
)
loss += .5 * (1 if last_layer else 1/(len(self.loss_layers)-1)) * layer_loss
log_layer_loss = layer_loss.mean() if self.args.n_gpu > 1 else layer_loss
log_layer_loss = log_layer_loss.detach() / self.args.gradient_accumulation_steps
self.control.layer_losses[layer] += log_layer_loss
return (loss, student_output) if return_outputs else loss
# https://github.com/naba89/custom_hf_trainer
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time):
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
if is_torch_xla_available():
xm.mark_step() # type: ignore
logs: Dict[str, float] = {}
# all_gather + mean() to get average loss over all processes
tr_loss_scalar = self._nested_gather(tr_loss).mean().item() # type: ignore
# reset tr_loss to zero
tr_loss -= tr_loss
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
if grad_norm is not None:
logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
logs["learning_rate"] = self._get_learning_rate()
if self.term == "layer":
for k, v in self.control.layer_losses.items():
layer_loss = self._nested_gather(v).mean().item() # type: ignore
logs[f"layer_loss_{k}"] = round(layer_loss / (self.state.global_step - self._globalstep_last_logged), 4)
self.control.layer_losses[k] -= v # reset the loss
self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
self.store_flos()
self.log(logs, start_time)
metrics = None
if self.control.should_evaluate:
metrics = self._evaluate(trial, ignore_keys_for_eval)
is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
if self.args.save_strategy == SaveStrategy.BEST:
self.control.should_save = is_new_best_metric
if self.control.should_save:
self._save_checkpoint(model, trial)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
class AdapterDataset(Dataset, TrainerCallback):
def __init__(self, dataset, processor, multimodal=False):
super().__init__()
self.processor = processor
self.dataset = dataset
self.multimodal = multimodal
self.offset = 1
self.sketchify = SketchAugment(intensity=2)
self.mixup = EditMixUp()
self.cutmix = EditCutMix()
self.cutout = EditCutOut()
self.erase = FullErase()
def __len__(self):
return len(self.dataset)
def __getitems__(self, indices) -> Dict[str, torch.Tensor]:
batch, images = self.dataset[indices], None
labels = torch.stack([v2.functional.pil_to_tensor(img) for img in batch['image']])
if self.multimodal:
partition, images = torch.randint(3, (len(indices),)), torch.empty_like(labels)
if len(sketch_ids:=torch.argwhere(partition == 0).flatten()):
images[sketch_ids] = self.sketchify(labels[sketch_ids])
if len(blank_ids:=torch.argwhere(partition == 1).flatten()):
images[blank_ids] = self.erase(labels[blank_ids])
if len(edit_ids:=torch.argwhere(partition == 2).flatten()):
edit_partition = torch.randint(3, (len(edit_ids),))
if len(cutout_ids:=edit_ids[torch.argwhere(edit_partition == 0)].flatten()):
images[cutout_ids] = self.cutout(labels[cutout_ids])
if len(mixup_ids:=edit_ids[torch.argwhere(edit_partition == 1)].flatten()):
mixup_imgs = self.dataset[[(indices[idx] + self.offset) % len(self) for idx in mixup_ids]]['image']
mixup_imgs = torch.stack([v2.functional.pil_to_tensor(img) for img in mixup_imgs])
interleaved_imgs = torch.stack([labels[mixup_ids], mixup_imgs], dim=1).view(-1, *mixup_imgs.shape[1:])
images[mixup_ids] = self.mixup(interleaved_imgs)[::2]
if len(cutmix_ids:=edit_ids[torch.argwhere(edit_partition == 2)].flatten()):
cutmix_imgs = self.dataset[[(indices[idx] + self.offset) % len(self) for idx in cutmix_ids]]['image']
cutmix_imgs = torch.stack([v2.functional.pil_to_tensor(img) for img in cutmix_imgs])
interleaved_imgs = torch.stack([labels[cutmix_ids], cutmix_imgs], dim=1).view(-1, *cutmix_imgs.shape[1:])
images[cutmix_ids] = self.cutmix(interleaved_imgs)[::2]
input_ids = self.processor(
images=images,
text=batch['text'],
return_tensors="pt",
text_kwargs=dict(
padding=True,
truncation=True,
)
)
label_ids = unwrap_processor(self.processor)(images=labels, return_tensors="pt")
input_ids['labels'] = label_ids['pixel_values']
return input_ids
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
return self.__getitems__([index] if isinstance(index, int) else index)
def on_epoch_end(self, *args, **kwargs):
self.offset += 1
def train(
output_dir: str,
model,
processor,
dataset,
overwrite=False,
deepspeed=None,
# training hyperparams
multimodal: bool = False,
batch_size: int = 512,
micro_batch_size: int = 8,
num_epochs: int = 3,
learning_rate: float = 1e-4,
gradient_checkpointing: bool = False,
**loss_kwargs
):
dataset = AdapterDataset(dataset, processor=processor, multimodal=multimodal)
gradient_accumulation_steps = batch_size // micro_batch_size
if WORLD_SIZE != 1:
gradient_accumulation_steps = gradient_accumulation_steps // WORLD_SIZE
last_checkpoint = None
if os.path.isdir(output_dir) and not overwrite:
last_checkpoint = get_last_checkpoint(output_dir)
if last_checkpoint is None and len(os.listdir(output_dir)) > 0:
raise ValueError(
f"Output directory ({output_dir}) already exists and is not empty. "
"Use `overwrite` to overcome."
)
elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `output_dir` or add `overwrite` to train from scratch."
)
trainer = AdapterTrainer(
model=model,
train_dataset=dataset,
multimodal=multimodal,
callbacks=[SplitEpochSaveCallback(step_size=0.5)],
data_collator=lambda batch: batch,
**loss_kwargs,
args=TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
gradient_checkpointing=gradient_checkpointing,
# https://github.com/huggingface/transformers/issues/21381
gradient_checkpointing_kwargs={'use_reentrant': False},
dataloader_num_workers=WORLD_SIZE,
warmup_steps=500,
weight_decay=0.1,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
torch_compile=True,
bf16=True,
tf32=True,
logging_steps=250,
logging_first_step=True,
lr_scheduler_type="cosine",
optim="adamw_torch" if deepspeed else "adamw_torch_fused",
ddp_find_unused_parameters=False,
remove_unused_columns=False,
save_strategy="epoch",
report_to="none",
output_dir=output_dir,
deepspeed=deepspeed,
)
)
trainer.add_callback(trainer.train_dataset)
trainer.train(resume_from_checkpoint=last_checkpoint)
if trainer.is_deepspeed_enabled:
# https://huggingface.co/docs/accelerate/v0.11.0/en/deepspeed#saving-and-loading
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
last_checkpoint = get_last_checkpoint(output_dir)
load_state_dict_from_zero_checkpoint(trainer.model.float(), last_checkpoint)
model.save_cross_attn_adapter(output_dir)
trainer.save_state()
return model, processor
```
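A sketch of how the adapter pretraining entry point might be invoked, assuming `model` is a `CrossAttentionSiglipVisionModel` with the cross-attention adapter already attached, `processor` is the matching adapter processor, and `dataset` is a `datasets.Dataset` exposing `image` and `text` columns:

```py
# Hypothetical invocation; `model`, `processor`, and `dataset` are assumed to be
# prepared as described in the lead-in.
from detikzify.train.adapter import pretrain

model, processor = pretrain(
    output_dir="runs/adapter-pretrain",
    model=model,
    processor=processor,
    dataset=dataset,
    multimodal=True,      # enables the sketch/erase/mixup/cutmix augmentations
    loss_term="layer",    # forwarded to AdapterTrainer: "avg", "pool", "patch", or "layer"
    micro_batch_size=8,
)
```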
## /detikzify/train/adapter/train.py
```py path="/detikzify/train/adapter/train.py"
from copy import deepcopy
from datetime import timedelta
import os
from typing import Dict, List
from accelerate import Accelerator, InitProcessGroupKwargs
from datasets import Dataset
import torch
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import logging
from ...util import SplitEpochSaveCallback, unwrap_processor
logger = logging.get_logger("transformers")
IGNORE_INDEX = -100
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
RANK = int(os.environ.get("RANK", 0))
def tokenize(
batch,
processor,
caption_condition=False,
**kwargs
):
unwrapped_processor = unwrap_processor(processor)
image_token = unwrapped_processor.image_token
image_token_id = unwrapped_processor.tokenizer.convert_tokens_to_ids(image_token)
bos_token = unwrapped_processor.tokenizer.bos_token
input_ids = processor(
text=batch['caption'],
images_kwargs=dict(
text=[bos_token.join(text) for text in zip(batch['caption'], batch['code'])] if caption_condition else batch['code'],
max_length=unwrapped_processor.tokenizer.model_max_length,
pad_to_multiple_of=8,
add_eos_token=True,
truncation=False,
padding=True
),
text_kwargs=dict(
padding=True,
truncation=True,
),
**kwargs
)
input_ids['labels'] = deepcopy(input_ids['input_ids'])
if caption_condition:
# do not train on caption and pad tokens
for label_ids in input_ids['labels']:
after_bos_token = False
for idx, label_id in enumerate(label_ids):
if not after_bos_token or label_id in {image_token_id, unwrapped_processor.tokenizer.pad_token_id}:
if label_id == unwrapped_processor.tokenizer.bos_token_id:
after_bos_token = True
label_ids[idx] = IGNORE_INDEX
elif label_id == unwrapped_processor.tokenizer.bos_token_id:
after_bos_token = True
else:
# do not train on image and pad tokens
for label_ids in input_ids['labels']:
for idx, label_id in enumerate(label_ids):
if label_id in {image_token_id, processor.tokenizer.pad_token_id}:
label_ids[idx] = IGNORE_INDEX
return input_ids
class AdapterDataset(Dataset):
def __init__(self, dataset, processor, caption_condition=False):
super().__init__()
self.processor = processor
self.dataset = dataset.with_transform(self.tokenize)
self.caption_condition = caption_condition
def __len__(self):
return len(self.dataset)
def tokenize(self, batch):
return tokenize(
batch=batch,
processor=self.processor,
caption_condition=self.caption_condition,
return_tensors="pt",
)
def filter(self, *args, **kwargs):
self.dataset = self.dataset.filter(*args, **kwargs)
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
return self.dataset[index]
def __getitems__(self, indices) -> Dict[str, List[torch.Tensor]]:
return self.dataset[*indices]
def train(
output_dir: str,
model,
processor,
dataset,
overwrite=False,
deepspeed=None,
# training hyperparams
caption_condition: bool = False,
batch_size: int = 128,
micro_batch_size: int = 1,
num_epochs: int = 5,
learning_rate: float = 5e-5,
gradient_checkpointing: bool = False,
):
gradient_accumulation_steps = batch_size // micro_batch_size
if WORLD_SIZE != 1:
gradient_accumulation_steps = gradient_accumulation_steps // WORLD_SIZE
for _, param in model.model.vision_model.named_parameters():
param.requires_grad = False
for _, param in model.adapter.named_parameters():
param.requires_grad = False
for _, param in model.embedding_model.named_parameters():
param.requires_grad = False
model.enable_input_require_grads()
model.embedding_model.enable_input_require_grads()
dataset = AdapterDataset(dataset, processor, caption_condition=caption_condition)
logger.info(f"Dataset size before filtering out too long examples: {len(dataset)}")
eos_token_id, model_max_length = unwrap_processor(processor).tokenizer.eos_token_id, unwrap_processor(processor).tokenizer.model_max_length
with Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(days=3))]).main_process_first():
dataset.filter(lambda ex: (ex['input_ids'] == eos_token_id).nonzero() < model_max_length, num_proc=64, batch_size=16)
logger.info(f"Dataset size after filtering out too long examples: {len(dataset)}")
last_checkpoint = None
if os.path.isdir(output_dir) and not overwrite:
last_checkpoint = get_last_checkpoint(output_dir)
if last_checkpoint is None and len(os.listdir(output_dir)) > 0:
raise ValueError(
f"Output directory ({output_dir}) already exists and is not empty. "
"Use `overwrite` to overcome."
)
elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `output_dir` or add `overwrite` to train from scratch."
)
trainer = Trainer(
model=model,
train_dataset=dataset,
args=TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
gradient_checkpointing=gradient_checkpointing,
# https://github.com/huggingface/transformers/issues/32576
#gradient_checkpointing_kwargs={'use_reentrant':False},
dataloader_num_workers=WORLD_SIZE,
warmup_ratio=0.03,
weight_decay=0,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
torch_compile=True,
bf16=True,
tf32=True,
logging_steps=10,
lr_scheduler_type="cosine",
optim="adamw_torch" if deepspeed else "adamw_torch_fused",
ddp_find_unused_parameters=False,
remove_unused_columns=False,
save_strategy="epoch",
report_to="none",
save_total_limit=1,
output_dir=output_dir,
deepspeed=deepspeed,
),
callbacks=[SplitEpochSaveCallback(step_size=0.25)],
data_collator=lambda batch: batch
)
trainer.add_callback(trainer.train_dataset)
trainer.train(resume_from_checkpoint=last_checkpoint)
if trainer.is_deepspeed_enabled:
# https://huggingface.co/docs/accelerate/v0.11.0/en/deepspeed#saving-and-loading
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
last_checkpoint = get_last_checkpoint(output_dir)
load_state_dict_from_zero_checkpoint(trainer.model.float(), last_checkpoint)
trainer.model.unload_cross_attn_adapter()
trainer.save_model(output_dir)
trainer.save_state()
processor.processor.save_pretrained(output_dir)
return model, processor
```
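The corresponding fine-tuning entry point expects a dataset with `caption` and `code` columns (see `tokenize` above). A sketch under the same assumptions about `model` and `processor`:

```py
# Hypothetical invocation; `model`, `processor`, and `dataset` are assumed to be
# prepared beforehand (dataset columns: "caption", "code").
from detikzify.train.adapter import train

model, processor = train(
    output_dir="runs/adapter-train",
    model=model,
    processor=processor,
    dataset=dataset,
    caption_condition=True,  # mask caption tokens out of the language-modeling loss
)
```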
## /detikzify/train/pretrain.py
```py path="/detikzify/train/pretrain.py"
import copy
from functools import partial
import os
from typing import List
from transformers import Trainer, TrainingArguments
IGNORE_INDEX = -100
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
def tokenize(
batch,
processor,
**kwargs
):
image_token = processor.image_token
image_token_id = processor.tokenizer.convert_tokens_to_ids(image_token)
input_ids = processor(
text=batch['text'],
images=batch['image'],
max_length=processor.tokenizer.model_max_length,
pad_to_multiple_of=8,
add_eos_token=True,
**kwargs
)
input_ids['labels'] = copy.deepcopy(input_ids['input_ids'])
# do not train on image and pad tokens
for label_ids in input_ids['labels']:
for idx, label_id in enumerate(label_ids):
if label_id in {image_token_id, processor.tokenizer.pad_token_id}:
label_ids[idx] = IGNORE_INDEX
return input_ids
def train(
output_dir: str,
model,
processor,
dataset,
deepspeed=None,
# training hyperparams
batch_size: int = 256,
micro_batch_size: int = 1,
num_epochs: int = 1,
learning_rate: float = 1e-3,
gradient_checkpointing: bool = False,
full_finetune_modules: List[str] = [
"modality_projection",
],
):
gradient_accumulation_steps = batch_size // micro_batch_size
if WORLD_SIZE != 1:
gradient_accumulation_steps = gradient_accumulation_steps // WORLD_SIZE
for name, param in model.named_parameters():
if not any(module in name for module in full_finetune_modules):
param.requires_grad = False
dataset.set_transform(partial(
tokenize,
processor=processor,
return_tensors="pt",
truncation=True,
padding=True
))
trainer = Trainer(
model=model,
train_dataset=dataset,
args=TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
gradient_checkpointing=gradient_checkpointing,
# https://github.com/huggingface/transformers/issues/21381
gradient_checkpointing_kwargs={'use_reentrant':False},
dataloader_num_workers=WORLD_SIZE,
warmup_ratio=0.03,
weight_decay=0,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
torch_compile=True,
bf16=True,
tf32=True,
logging_steps=10,
lr_scheduler_type="cosine",
optim="adamw_torch" if deepspeed else "adamw_torch_fused",
ddp_find_unused_parameters=False,
remove_unused_columns=False,
save_strategy="no",
report_to="none",
output_dir=output_dir,
deepspeed=deepspeed,
)
)
if trainer.is_deepspeed_enabled and trainer.accelerator.state.deepspeed_plugin.hf_ds_config.is_zero3():
raise ValueError("Pretraining with zero stage 3 is not yet supported.")
trainer.train()
model.save_pretrained(
output_dir,
state_dict={
name: weight
for name, weight in model.state_dict().items()
if any(key_match in name for key_match in full_finetune_modules)
},
)
trainer.save_state()
return model, processor
```
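For orientation, here is a minimal sketch of how this pretraining entry point might be called. It assumes that `detikzify.model.load` accepts a model name and returns a `(model, processor)` pair (as the web UI helpers below suggest) and that the dataset is a `datasets.Dataset` with `text` and `image` columns, as `tokenize` expects; the dataset name is purely hypothetical.
```py
from datasets import load_dataset

from detikzify.model import load
from detikzify.train.pretrain import train

# hypothetical setup: load() is assumed to return (model, processor), and the
# dataset must expose 'text' (TikZ code) and 'image' columns for tokenize()
model, processor = load("nllg/detikzify-v2-8b")
dataset = load_dataset("my-org/pretraining-corpus", split="train")

# only the modules listed in full_finetune_modules (the modality projection by
# default) are trained; all other parameters are frozen
model, processor = train(
    output_dir="pretrain-out",
    model=model,
    processor=processor,
    dataset=dataset,
)
```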
## /detikzify/train/train.py
```py path="/detikzify/train/train.py"
from io import BytesIO
import os
from random import random
from typing import Dict, List
from PIL import Image
import torch
from torch.utils.data import Dataset
from transformers import Trainer, TrainerCallback, TrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import logging
from ..util import SketchAugment, SplitEpochSaveCallback
from .pretrain import tokenize
logger = logging.get_logger("transformers")
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
RANK = int(os.environ.get("RANK", 0))
class ImageSketchDataset(Dataset, TrainerCallback):
"""
Dataset which samples sketches instead of images, when a sketch exists
for the current epoch.
"""
def __init__(self, dataset, processor, ds_sketch_ratio=.5):
super().__init__()
self.processor = processor
self.dataset = dataset.with_transform(self.tokenize)
self.ds_sketch_ratio = ds_sketch_ratio
self.sketchify = SketchAugment()
self.cur_epoch = 0
def __len__(self):
return len(self.dataset)
def tokenize(self, batch):
for idx, sketches in enumerate(batch['sketches']):
if (sketch:=sketches[self.cur_epoch]):
if random() >= self.ds_sketch_ratio:
batch['image'][idx] = Image.open(BytesIO(sketch['bytes'])).convert("RGB")
else:
batch['image'][idx] = self.sketchify(batch['image'][idx])
return tokenize(
batch=batch,
processor=self.processor,
return_tensors="pt",
truncation=False,
padding=True
)
def filter(self, *args, **kwargs):
self.dataset = self.dataset.filter(*args, **kwargs)
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
return self.dataset[index]
def __getitems__(self, indices) -> Dict[str, List[torch.Tensor]]:
return self.dataset[*indices]
def on_epoch_end(self, *args, **kwargs):
self.cur_epoch += 1
def train(
output_dir: str,
model,
processor,
dataset,
overwrite=False,
deepspeed=None,
# training hyperparams
batch_size: int = 128,
micro_batch_size: int = 1,
num_epochs: int = 5,
learning_rate: float = 5e-5,
sketch_ratio=.5,
gradient_checkpointing: bool = False,
):
assert num_epochs <= len(dataset[0]['sketches'])
gradient_accumulation_steps = batch_size // micro_batch_size
if WORLD_SIZE != 1:
gradient_accumulation_steps = gradient_accumulation_steps // WORLD_SIZE
dataset = ImageSketchDataset(dataset, processor, ds_sketch_ratio=sketch_ratio)
logger.info(f"Dataset size before filtering out too long examples: {len(dataset)}")
eos_token_id, model_max_length = processor.tokenizer.eos_token_id, processor.tokenizer.model_max_length
dataset.filter(lambda ex: (ex['input_ids'] == eos_token_id).nonzero() < model_max_length)
logger.info(f"Dataset size after filtering out too long examples: {len(dataset)}")
last_checkpoint = None
if os.path.isdir(output_dir) and not overwrite:
last_checkpoint = get_last_checkpoint(output_dir)
if last_checkpoint is None and len(os.listdir(output_dir)) > 0:
raise ValueError(
f"Output directory ({output_dir}) already exists and is not empty. "
"Use `overwrite` to overcome."
)
elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `output_dir` or add `overwrite` to train from scratch."
)
trainer = Trainer(
model=model,
train_dataset=dataset,
args=TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
gradient_checkpointing=gradient_checkpointing,
# https://github.com/huggingface/transformers/issues/32576
gradient_checkpointing_kwargs={'use_reentrant':False},
dataloader_num_workers=WORLD_SIZE,
warmup_ratio=0.03,
weight_decay=0,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
torch_compile=True,
bf16=True,
tf32=True,
logging_steps=10,
lr_scheduler_type="cosine",
optim="adamw_torch" if deepspeed else "adamw_torch_fused",
ddp_find_unused_parameters=False,
remove_unused_columns=False,
save_strategy="epoch",
report_to="none",
save_total_limit=1,
output_dir=output_dir,
deepspeed=deepspeed,
),
callbacks=[SplitEpochSaveCallback(step_size=0.25)],
data_collator=lambda batch: batch
)
trainer.add_callback(trainer.train_dataset)
trainer.train(resume_from_checkpoint=last_checkpoint)
if trainer.is_deepspeed_enabled:
# https://huggingface.co/docs/accelerate/v0.11.0/en/deepspeed#saving-and-loading
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
last_checkpoint = get_last_checkpoint(output_dir)
load_state_dict_from_zero_checkpoint(trainer.model.float(), last_checkpoint)
trainer.save_model(output_dir)
trainer.save_state()
return model, processor
```
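The fine-tuning dataset above is expected to pair each reference figure with its Ti*k*Z code and one (possibly missing) pre-rendered sketch per epoch (`num_epochs <= len(sketches)` is asserted). A hypothetical example record, reconstructed from how `ImageSketchDataset.tokenize` accesses its fields:
```py
# hypothetical record consumed by ImageSketchDataset; field names follow the code above
example = {
    "image": reference_figure,                 # PIL.Image of the ground-truth figure
    "text": tikz_code,                         # target TikZ program (string)
    "sketches": [{"bytes": png_bytes}, None],  # one (possibly None) entry per epoch; if present,
}                                              # it competes with a SketchAugment view of the figure
```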
## /detikzify/util/__init__.py
```py path="/detikzify/util/__init__.py"
from .functools import *
from .generation import *
from .image import *
from .subprocess import *
from .torch import *
from .trainer import *
```
## /detikzify/util/functools.py
```py path="/detikzify/util/functools.py"
from collections import defaultdict
from collections.abc import Callable
from copy import copy
from functools import cache, wraps
from typing import Any
def cache_cast(cast_func: Callable[..., Any]):
"""
functools.cache which takes a user-defined function to convert arguments
into something immutable so it can be cached.
"""
def decorator(func):
cache_args, cache_kwargs = None, None
@cache
def cached_func(_):
return func(*cache_args, **cache_kwargs)
@wraps(func)
def wrapped_func(*args, **kwargs):
nonlocal cache_args, cache_kwargs
cache_args, cache_kwargs = args, kwargs
return cached_func(cast_func(*args, **kwargs))
return wrapped_func
return decorator
def cast(cls, object):
clone = copy(object)
clone.__class__ = cls
return clone
# https://stackoverflow.com/a/12377059
def listify(fn=None, wrapper=list):
"""
A decorator which wraps a function's return value in ``list(...)``.
Useful when an algorithm can be expressed more cleanly as a generator but
the function should return a list.
Example::
>>> @listify
... def get_lengths(iterable):
... for i in iterable:
... yield len(i)
>>> get_lengths(["spam", "eggs"])
[4, 4]
>>>
>>> @listify(wrapper=tuple)
... def get_lengths_tuple(iterable):
... for i in iterable:
... yield len(i)
>>> get_lengths_tuple(["foo", "bar"])
(3, 3)
"""
def listify_return(fn):
@wraps(fn)
def listify_helper(*args, **kw):
return wrapper(fn(*args, **kw))
return listify_helper
if fn is None:
return listify_return
return listify_return(fn)
def batchify(fn=None):
def batch(list_of_dicts):
batch_dict = defaultdict(list)
for d in list_of_dicts:
for k, v in d.items():
batch_dict[k].append(v)
return batch_dict
return listify(fn=fn, wrapper=batch)
```
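`batchify` combines `listify` with a collation step: the decorated generator yields one dict per example, and callers receive a single dict of lists. A small usage sketch (the decorated function is purely illustrative):
```py
from detikzify.util import batchify

@batchify
def measure(texts):
    # illustrative only: yield one dict per example
    for text in texts:
        yield {"text": text, "length": len(text)}

batch = measure(["spam", "eggs"])
# batch now maps "text" -> ["spam", "eggs"] and "length" -> [4, 4] (a defaultdict of lists)
```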
## /detikzify/util/generation.py
```py path="/detikzify/util/generation.py"
from queue import Queue
from typing import Optional
from transformers import StoppingCriteria
from transformers.generation import streamers
class ExplicitAbort(StoppingCriteria):
"""
Abort a model generation explicitly (i.e., when using a streamer in a thread).
"""
def __init__(self):
super().__init__()
self.should_stop = False
def __call__(self, input_ids, scores, **kwargs) -> bool:
return self.should_stop
def reset(self):
self.should_stop = False
return self
def abort(self):
self.should_stop = True
class TokenStreamer(streamers.BaseStreamer):
"""
Stream raw token ids (i.e., not decoded strings).
"""
def __init__(self, skip_prompt: bool = True, timeout: Optional[float] = None):
self.skip_prompt = skip_prompt
self.next_tokens_are_prompt = True
self.token_queue = Queue()
self.stop_signal = None
self.timeout = timeout
def put(self, value):
if len(value.shape) > 1 and value.shape[0] > 1:
raise ValueError("TokenStreamer only supports batch size 1")
elif len(value.shape) > 1:
value = value[0]
if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
return
for token_id in value.tolist():
self.token_queue.put(token_id, timeout=self.timeout)
def end(self):
self.next_tokens_are_prompt = True
self.token_queue.put(self.stop_signal, timeout=self.timeout)
def propagate_error(self, exc):
self.token_queue.put(exc, timeout=self.timeout)
def __iter__(self):
return self
def __next__(self):
value = self.token_queue.get(timeout=self.timeout)
if value == self.stop_signal:
raise StopIteration()
elif isinstance(value, BaseException):
raise value
else:
return value
class TextIteratorStreamer(streamers.TextIteratorStreamer):
def propagate_error(self, exc):
self.text_queue.put(exc, timeout=self.timeout)
def __next__(self):
value = self.text_queue.get(timeout=self.timeout)
if value == self.stop_signal:
raise StopIteration()
elif isinstance(value, BaseException):
raise value
else:
return value
class StreamerList(list, streamers.BaseStreamer):
"""
Similar to StoppingCriteriaList, only for Streamers.
"""
def put(self, value):
for streamer in self:
streamer.put(value)
def end(self):
for streamer in self:
streamer.end()
def unwrap_processor(processor):
"""
    Unwrap a processor; nested processors can happen when using the adapter
processor.
"""
if hasattr(processor, "processor"):
return unwrap_processor(processor.processor)
else:
return processor
```
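These classes are meant for threaded generation, where tokens are consumed while the model is still running. A minimal sketch of wiring them together (`model` and `inputs` are assumed to already exist, e.g. produced by the processor):
```py
from threading import Thread

from transformers import StoppingCriteriaList

from detikzify.util import ExplicitAbort, TokenStreamer

# assumed to exist: a loaded `model` and tokenized `inputs` with batch size 1
abort, streamer = ExplicitAbort(), TokenStreamer(skip_prompt=True)
Thread(target=model.generate, kwargs=dict(
    **inputs,
    streamer=streamer,
    stopping_criteria=StoppingCriteriaList([abort]),
)).start()

for token_id in streamer:  # yields raw token ids as they are produced
    ...                    # consume tokens; call abort.abort() to stop generation early
```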
## /detikzify/util/image.py
```py path="/detikzify/util/image.py"
from base64 import b64decode
from codecs import encode
from io import BytesIO
from os.path import isfile
from PIL import Image, ImageChops, ImageOps
import pymupdf
import requests
from transformers.utils.hub import is_remote_url
DUMMY_IMAGE = Image.new("RGB", (24, 24), color="white")
def convert(image, filetype):
image.save(imgbytes:=BytesIO(), format=filetype)
return Image.open(imgbytes)
def remove_alpha(image, bg):
# https://stackoverflow.com/a/62414364
background = Image.new('RGBA', image.size, bg)
alpha_composite = Image.alpha_composite(background, image.convert("RGBA"))
return alpha_composite.convert("RGB")
# https://stackoverflow.com/a/10616717
def trim(image, bg="white"):
bg = Image.new(image.mode, image.size, bg)
diff = ImageChops.difference(image, bg)
#diff = ImageChops.add(diff, diff, 2.0, -10)
return image.crop(bbox) if (bbox:=diff.getbbox()) else image
def expand(image, size, do_trim=False, bg="white"):
"""Expand image to a square of size {size}. Optionally trims borders first."""
image = trim(image, bg=bg) if do_trim else image
return ImageOps.pad(image, (size, size), color=bg, method=Image.Resampling.LANCZOS)
# based on transformers/image_utils.py (added support for rgba images)
def load(image: Image.Image | str | bytes, bg="white", timeout=None):
if isinstance(image, bytes):
# assume image bytes and open
image = Image.open(BytesIO(image))
elif isinstance(image, str):
if is_remote_url(image):
# https://stackoverflow.com/a/69791396
headers = {'user-agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:68.0) Gecko/20100101 Firefox/68.0'}
image = Image.open(BytesIO(requests.get(image, timeout=timeout, headers=headers).content))
elif isfile(image):
image = Image.open(image)
else:
try:
image.removeprefix("data:image/")
image = Image.open(BytesIO(b64decode(image)))
except Exception as e:
raise ValueError(
"Incorrect image source. "
"Must be a valid URL starting with `http://` or `https://`, "
"a valid path to an image file, bytes, or a base64 encoded "
f"string. Got {image}. Failed with {e}"
)
image = ImageOps.exif_transpose(image) # type: ignore
return remove_alpha(image, bg=bg)
def redact(doc, rot_13=False):
for page in (copy:=pymupdf.open("pdf", doc.tobytes())):
for word in page.get_text("words", clip=pymupdf.INFINITE_RECT()): # type: ignore
text = encode(word[4], "rot13") if rot_13 else None
page.add_redact_annot(word[:4], text=text, fill=False) # type: ignore
page.apply_redactions( # type: ignore
images=pymupdf.PDF_REDACT_IMAGE_NONE, # type: ignore
graphics=pymupdf.PDF_REDACT_LINE_ART_NONE # type: ignore
)
return copy
```
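A short usage sketch of these helpers: `load` accepts a URL, file path, raw bytes, or a base64 string and flattens any alpha channel onto the background color, while `expand` pads the (optionally trimmed) image to a square. The URL and target resolution below are placeholders.
```py
from detikzify.util import expand, load

# placeholder URL; a file path, bytes, or a base64 string would work the same way
image = load("https://example.com/figure.png", timeout=10)
square = expand(image, size=420, do_trim=True)  # trim white borders, pad to 420x420
```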
## /detikzify/util/subprocess.py
```py path="/detikzify/util/subprocess.py"
from os import killpg, getpgid
from subprocess import Popen, TimeoutExpired, CalledProcessError, CompletedProcess, PIPE
from signal import SIGKILL
def safe_killpg(pid, signal):
try:
killpg(getpgid(pid), signal)
except ProcessLookupError:
        pass # Suppress the race condition error; bpo-40550.
# Patch subprocess.run and subprocess.check_output to also kill children of the
# started process on timeouts (cf.
# https://alexandra-zaharia.github.io/posts/kill-subprocess-and-its-children-on-timeout-python/)
def run(*popenargs, input=None, timeout=None, check=False, **kwargs):
with Popen(*popenargs, start_new_session=True, **kwargs) as process:
try:
stdout, stderr = process.communicate(input, timeout=timeout)
except TimeoutExpired:
safe_killpg(process.pid, SIGKILL)
process.wait()
raise
except:
safe_killpg(process.pid, SIGKILL)
raise
retcode = process.poll()
if check and retcode:
raise CalledProcessError(retcode, process.args,
output=stdout, stderr=stderr)
return CompletedProcess(process.args, retcode, stdout, stderr) # type: ignore
def check_output(*popenargs, timeout=None, **kwargs):
return run(*popenargs, stdout=PIPE, timeout=timeout, check=True, **kwargs).stdout
```
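Compared to the standard library versions, these wrappers start the child in a new session so that a timeout kills the entire process group instead of only the direct child. A hedged usage sketch (the compiler invocation is just an example):
```py
from subprocess import TimeoutExpired

from detikzify.util import check_output

try:
    # example command only; on timeout the whole process group is killed, so
    # helper processes spawned by the compiler do not linger
    log = check_output(["pdflatex", "-halt-on-error", "figure.tex"], timeout=60, text=True)
except TimeoutExpired:
    ...  # handle the timeout; no orphaned children remain
```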
## /detikzify/util/torch.py
```py path="/detikzify/util/torch.py"
from torch.cuda import is_available as is_torch_cuda_available
from transformers.utils import is_torch_npu_available, is_torch_xpu_available
# https://github.com/huggingface/peft/blob/c4cf9e7d3b2948e71ec65a19e6cd1ff230781d13/src/peft/utils/other.py#L60-L71
def infer_device():
if is_torch_cuda_available():
torch_device = "cuda"
elif is_torch_xpu_available():
torch_device = "xpu"
elif is_torch_npu_available():
torch_device = "npu"
else:
torch_device = "cpu"
return torch_device
```
## /detikzify/util/trainer.py
```py path="/detikzify/util/trainer.py"
from functools import partial
from numpy import arange
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
from transformers import (
IntervalStrategy,
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from transformers.trainer_utils import has_length
from torchvision.transforms.v2._utils import query_size
class SplitEpochSaveCallback(TrainerCallback):
"""
If save_strategy==EPOCH also save checkpoints at arbitrary fractions of an
epoch (controlled by step_size).
"""
def __init__(self, step_size: float = 0.5):
self.steps = arange(step_size, 1, step_size)
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
if has_length(train_dataloader:=kwargs['train_dataloader']):
self.num_update_steps_per_epoch = max(len(train_dataloader) // args.gradient_accumulation_steps, 1)
else:
self.num_update_steps_per_epoch = args.max_steps
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): # type: ignore
steps = [round(self.num_update_steps_per_epoch * step) for step in self.steps]
if (
state.global_step % self.num_update_steps_per_epoch in steps
and args.save_strategy == IntervalStrategy.EPOCH
):
control.should_save = True
return control
class SketchAugment(v2.Compose):
def __init__(self, intensity=1):
super().__init__([
v2.RandomOrder([
v2.ElasticTransform(alpha=50. * intensity, fill=255),
v2.JPEG((40 * intensity, 100)),
v2.ColorJitter(brightness=(.75 + .25 * intensity, 1.75)),
v2.RandomEqualize(),
v2.RandomGrayscale()
]),
v2.RGB()
])
class FullErase(v2.Lambda):
def __init__(self, value=255):
super().__init__(partial(v2.functional.erase, i=0, j=0, h=-1, w=-1, v=torch.tensor(value)))
class EditBase(v2.Transform):
def __init__(self, *, alpha: float = 1.0) -> None:
super().__init__()
self.alpha = float(alpha)
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
def _get_boxes(self, flat_inputs):
lam = self._dist.sample((len(flat_inputs),)).squeeze() # type: ignore
H, W = query_size(flat_inputs)
r_x = torch.randint(W, size=(len(flat_inputs),))
r_y = torch.randint(H, size=(len(flat_inputs),))
r = 0.5 * torch.sqrt(1.0 - lam)
r_w_half = (r * W).int()
r_h_half = (r * H).int()
x1 = torch.clamp(r_x - r_w_half, min=0)
y1 = torch.clamp(r_y - r_h_half, min=0)
x2 = torch.clamp(r_x + r_w_half, max=W)
y2 = torch.clamp(r_y + r_h_half, max=H)
grid_x, grid_y = torch.meshgrid(torch.arange(W), torch.arange(H), indexing="ij")
grid_x = grid_x.unsqueeze(0).expand(len(flat_inputs), -1, -1)
grid_y = grid_y.unsqueeze(0).expand(len(flat_inputs), -1, -1)
mask = (grid_x >= x1.unsqueeze(1).unsqueeze(2)) & (grid_x < x2.unsqueeze(1).unsqueeze(2)) & \
(grid_y >= y1.unsqueeze(1).unsqueeze(2)) & (grid_y < y2.unsqueeze(1).unsqueeze(2))
return mask.unsqueeze(1).expand(-1, 3, -1, -1)
class EditCutMix(EditBase):
def _transform(self, inpt, params):
output = inpt.clone()
rolled = inpt.roll(1, 0)
box = self._get_boxes(inpt)
output[box] = rolled[box]
if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
output = tv_tensors.wrap(output, like=inpt)
return output
class EditMixUp(EditBase):
def _transform(self, inpt, params):
lam = self._dist.sample((len(inpt),)).view(-1, *([1] * len(inpt.shape[1:]))) # type: ignore
output = inpt.roll(1, 0).mul(1.0 - lam).add_(inpt.mul(lam)).to(inpt.dtype)
if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
output = tv_tensors.wrap(output, like=inpt)
return output
class EditCutOut(EditBase):
def __init__(self, *args, value=255, **kwargs):
self.value = value
super().__init__(*args, **kwargs)
def _transform(self, inpt, params):
output = inpt.clone()
box = self._get_boxes(inpt)
output[box] = self.value
if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
output = tv_tensors.wrap(output, like=inpt)
return output
```
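`SketchAugment` is the transform `ImageSketchDataset` uses to turn reference figures into synthetic, sketch-like training inputs. A minimal sketch of applying it directly to a PIL image (the file name is a placeholder):
```py
from PIL import Image

from detikzify.util import SketchAugment

augment = SketchAugment()                         # intensity=1 by default
figure = Image.open("figure.png").convert("RGB")  # placeholder file name
synthetic_sketch = augment(figure)                # elastic distortion, JPEG artifacts, color jitter, ...
```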
## /detikzify/webui/README.md
# Web UI
The web UI of DeTi*k*Zify requires [TeX Live
2023](https://www.tug.org/texlive), [ghostscript](https://www.ghostscript.com),
and [poppler](https://poppler.freedesktop.org). You can launch it by running
`python -m detikzify.webui`. It comes with a command-line interface. With the
`--share` flag, for example, you can create a shareable link. Check out `--help`
for a full list of supported options. As scientific figures usually use black
fonts on a white background, it is best to use the web UI in light mode. This
can be enforced by using the `--light` flag. If [FlashAttention-2](
https://huggingface.co/docs/transformers/en/perf_infer_gpu_one?install=NVIDIA#flashattention-2)
is installed, it is picked up automatically and should boost inference speeds.
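The UI can also be launched programmatically. A minimal sketch using `build_ui` (assuming it is re-exported from the `detikzify.webui` package as in `__main__.py`; the keyword arguments mirror the command-line flags and the values below are just examples):
```py
from detikzify.webui import build_ui

# example values only; keyword arguments correspond to the CLI flags
build_ui(light=True, timeout=60).queue().launch(share=False)
```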
## Usage Tips
**Visual Prompting** Creating sketches for DeTi*k*Zify (or providing any input
images) shares many similarities with the process of [prompting large language
models](https://en.wikipedia.org/wiki/Prompt_engineering). If DeTi*k*Zify
struggles to comprehend your intent, consider "rephrasing" your input. This
could mean simplifying your sketches or focusing more on the key issue at hand.
In
[this](https://github.com/potamides/DeTikZify/assets/53401822/2819ebca-81f6-4173-8809-0b4255d3e976)
particular instance, for example, we attempted to "prompt" DeTi*k*Zify to align
characters diagonally around an equal sign, but it was unsuccessful even after
many simulations. However, upon adjusting the input (by reducing the stroke
width and using more easily recognizable characters), we achieved the [intended
output](https://github.com/potamides/DeTikZify/assets/53401822/c8ecfbff-d22e-41d5-8f73-e0cfafe88690)
after only one simulation.
**Image Editor** You can draw sketches in the integrated image editor, but its
feature set is quite limited. If you are not satisfied with the synthesized
Ti*k*Z programs, try drawing more elaborate sketches in an editor of your
choice (perhaps with graphics primitives) and upload them into the UI.
Alternatively, experimenting with line thickness and/or colors in the
integrated editor might also help.
**Input Postprocessing** Please note that all input images are cropped to the
smallest square around their content and then resized to the resolution
DeTi*k*Zify expects. If you leave large margins, DeTi*k*Zify might perceive
your input differently from how you intended (e.g., by drawing
thicker axes). As a rule of thumb, always try to fill as much of the canvas as
possible.
**Input Complexity** If you provide very complex sketches (or figures) and are
not satisfied with the results, you can also try segmenting (or simplifying)
your input and letting DeTi*k*Zify synthesize the individual pieces
independently. This has the advantage that the results will probably be better,
and the disadvantage that you will have to modify and assemble the pieces
yourself.
**Source Code Artifacts** Due to the way we preprocess our
[arXiv.org](https://arxiv.org) data, the preambles of the extracted Ti*k*Z
programs sometimes include packages that are not used inside the `tikzpicture`
environments, and the DeTi*k*Zify models pick up on this behavior. While this
does not hinder compilation in any way, we still recommend checking the
generated preambles and cleaning them up if necessary.
**Accuracy-Efficiency Trade-Off** We noticed that lower temperature and top-p
(nucleus sampling) values force DeTi*k*Zify to generate Ti*k*Z programs that
follow the input images more closely, at the expense of more compile-time
errors. We pick sensible defaults that aim to balance these two aspects, but
you might want to tune these parameters yourself.
**External Graphics** In DaTi*k*Zv2, we replace any externally
included graphics in the `tikzpicture` environments with the [example
image](https://mirrors.ctan.org/macros/latex/contrib/mwe/example-image.pdf)
placeholder from the [mwe](http://www.ctan.org/pkg/mwe) package. So if you want
to generate code with placeholders for your own external graphics, just draw
that example image.
## /detikzify/webui/__init__.py
```py path="/detikzify/webui/__init__.py"
from .webui import *
from .strings import *
from .helpers import *
```
## /detikzify/webui/__main__.py
```py path="/detikzify/webui/__main__.py"
from argparse import ArgumentParser
from .strings import ALGORITHMS, MODELS
from .webui import build_ui
def parse_args():
argument_parser = ArgumentParser(
description="Web UI for DeTikZify."
)
argument_parser.add_argument(
"--model",
default=list(MODELS)[0],
help="Initially selected model. You can also specify a path to your own models.",
)
argument_parser.add_argument(
"--algorithm",
default=list(ALGORITHMS)[0],
choices=list(ALGORITHMS),
help="The inference algorithm to use.",
)
argument_parser.add_argument(
"--lock",
action="store_true",
help="Whether to allow users to change the model or not.",
)
argument_parser.add_argument(
"--lock_reason",
default="Duplicate this space to be able to change this value.",
help="Additional information why model selection is locked.",
)
argument_parser.add_argument(
"--share",
action="store_true",
help="Whether to create a publicly shareable link for the interface.",
)
argument_parser.add_argument(
"--light",
action="store_true",
help= "Whether to enforce light theme (useful for vector graphics with dark text)."
)
argument_parser.add_argument(
"--timeout",
default=60,
type=int,
help="Allowed timeframe for compilation.",
)
return vars(argument_parser.parse_args())
if __name__ == "__main__":
args = parse_args()
share = args.pop("share")
build_ui(**args).queue().launch(share=share)
```
## /detikzify/webui/helpers.py
```py path="/detikzify/webui/helpers.py"
from functools import cache, lru_cache
from inspect import signature
from operator import itemgetter
from os import fdopen
from tempfile import mkstemp
import gradio as gr
from ..infer import TikzDocument
from ..model import load
def to_svg(
tikzdoc: TikzDocument,
build_dir: str
):
if not tikzdoc.is_rasterizable:
if tikzdoc.compiled_with_errors:
raise gr.Error("TikZ code did not compile!")
else:
gr.Warning("TikZ code compiled to an empty image!")
elif tikzdoc.compiled_with_errors:
gr.Warning("TikZ code compiled with errors!")
fd, path = mkstemp(dir=build_dir, suffix=".svg")
with fdopen(fd, "w") as f:
if pdf:=tikzdoc.pdf:
f.write(pdf[0].get_svg_image())
return path if pdf else None
# https://stackoverflow.com/a/50992575
def make_ordinal(n):
n = int(n)
if 11 <= (n % 100) <= 13:
suffix = 'th'
else:
suffix = ['th', 'st', 'nd', 'rd', 'th'][min(n % 10, 4)]
return str(n) + suffix
class MctsOutputs(set):
def __init__(self, build_dir, *args, **kwargs):
super().__init__(*args, **kwargs)
self.build_dir, self.svgmap, self.fails = build_dir, dict(), 0
def add(self, score, tikzdoc): # type: ignore
if (score, tikzdoc) not in self:
try:
svg = to_svg(tikzdoc, build_dir=self.build_dir)
super().add((score, tikzdoc))
self.svgmap[tikzdoc] = svg
except gr.Error:
gr.Warning("TikZ code did not compile, discarding output!")
if len(self): self.fails += 1
elif len(self): self.fails += 1
@property
def programs(self):
return [tikzdoc.code for _, tikzdoc in sorted(self, key=itemgetter(0), reverse=True)]
@property
def images(self):
return [
(self.svgmap[tikzdoc], make_ordinal(idx))
for idx, (_, tikzdoc) in enumerate(sorted(self, key=itemgetter(0), reverse=True), 1)
]
@property
def first_success(self):
return len(self) == 1 and not self.fails
def make_light(stylable):
"""
Patch gradio to only contain light mode colors.
"""
if isinstance(stylable, gr.themes.Base): # remove dark variants from the entire theme
params = signature(stylable.set).parameters
colors = {color: getattr(stylable, color.removesuffix("_dark")) for color in dir(stylable) if color in params}
return stylable.set(**colors)
elif isinstance(stylable, gr.Blocks): # also handle components which do not use the theme (e.g. modals)
stylable.load(
fn=None,
js="() => document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'))"
)
return stylable
else:
raise ValueError
@lru_cache(maxsize=1)
def cached_load(*args, **kwargs):
gr.Info("Instantiating model. This could take a while...")
return load(*args, **kwargs)
@cache
def info_once(message):
gr.Info(message)
class GeneratorLock:
"""
Ensure that only one instance of a given generator is active.
Useful when a previous invocation was canceled. See
https://github.com/gradio-app/gradio/issues/8503 for more information.
"""
def __init__(self, gen_func):
self.gen_func = gen_func
self.generator = None
def generate(self, *args, **kwargs):
if self.generator:
if self.generator.gi_running:
return # somehow we can end up here
self.generator.close()
self.generator = self.gen_func(*args, **kwargs)
yield from self.generator
def __call__(self, *args, **kwargs):
yield from self.generate(*args, **kwargs)
```
## /detikzify/webui/strings.py
```py path="/detikzify/webui/strings.py"
from os.path import basename
from transformers import is_timm_available
BANNER = '''\
DeTikZify: Synthesizing Graphics Programs for Scientific Figures and Sketches with TikZ
'''
MODELS = {
basename(model): model
for model in [
"nllg/detikzify-v2-8b",
#"nllg/detikzify-v2-3b", # coming soon
]
}
if is_timm_available():
MODELS |= {
basename(model).replace("detikzify", "detikzify-v1"): model
for model in [
"nllg/detikzify-ds-7b",
"nllg/detikzify-cl-7b",
"nllg/detikzify-ds-1.3b",
"nllg/detikzify-tl-1.1b",
]
}
ALGORITHMS = {
"mcts": "MCTS",
"sampling": "Sampling"
}
# https://github.com/gradio-app/gradio/issues/3202#issuecomment-1741571240
# https://github.com/gradio-app/gradio/issues/2666#issuecomment-1651127149
# https://stackoverflow.com/a/64033350
CSS = """
.input-image {
flex-grow: 1;
}
.output-code {
flex-grow: 1;
height: 0vh;
min-height: 250px;
scrollbar-width: thin !important;
}
.output-code .hide {
display: none;
}
.output-code .cm-scroller {
flex-grow: 1;
}
.output-code .cm-gutters {
position: relative !important;
}
.output-image {
flex-grow: 1;
height: 0vh;
min-height: 250px;
overflow-y: auto !important;
scrollbar-width: thin !important;
}
.output-image .image-container, .output-image .grid-container {
width: 100%;
height: 100%;
}
.output-image .thumbnail-item img {
object-fit: contain;
}
.output-image .grid-wrap.fixed-height {
max-height: 100% !important;
}
.outputs .tabs {
display: flex;
flex-direction: column;
flex-grow: 1;
}
.outputs .tabitem[style="display: block;"] {
flex-grow: 1;
display: flex !important;
}
.outputs .gap {
flex-grow: 1;
}
.outputs .form {
flex-grow: 1 !important;
}
.outputs .form > :last-child{
flex-grow: 1;
}
"""
# (Ab)use an invisible fake button with id preview-close to propagate the
# actual press of the button that closes a preview
# https://github.com/gradio-app/gradio/issues/6697
GALLERY_DESELECT_HACK = """
"""
```