```
├── .gitattributes
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── detikzify/
    ├── __init__.py
    ├── dataset/
        ├── __init__.py
        ├── paper2fig/
            ├── __init__.py
            ├── paper2fig.py
        ├── scicap/
            ├── __init__.py
            ├── scicap.py
    ├── evaluate/
        ├── __init__.py
        ├── clipscore.py
        ├── crystalbleu.py
        ├── dreamsim.py
        ├── eed.py
        ├── imagesim.py
        ├── kid.py
    ├── infer/
        ├── __init__.py
        ├── generate.py
        ├── tikz.py
    ├── mcts/
        ├── LICENSE
        ├── README.md
        ├── __init__.py
        ├── montecarlo.py
        ├── node.py
    ├── model/
        ├── __init__.py
        ├── adapter/
            ├── __init__.py
            ├── modeling_adapter.py
            ├── processing_adapter.py
        ├── configuration_detikzify.py
        ├── modeling_detikzify.py
        ├── processing_detikzify.py
        ├── v1/
            ├── __init__.py
            ├── configuration_detikzify.py
            ├── modeling_detikzify.py
            ├── processing_detikzify.py
    ├── train/
        ├── __init__.py
        ├── adapter/
            ├── __init__.py
            ├── pretrain.py
            ├── train.py
        ├── pretrain.py
        ├── train.py
    ├── util/
        ├── __init__.py
        ├── functools.py
        ├── generation.py
        ├── image.py
        ├── subprocess.py
        ├── torch.py
        ├── trainer.py
    ├── webui/
        ├── README.md
        ├── __init__.py
        ├── __main__.py
        ├── helpers.py
        ├── strings.py
        ├── webui.py
```

## /.gitattributes

```gitattributes path="/.gitattributes"
detikzify/mcts/** linguist-vendored
```

## /.gitignore

```gitignore path="/.gitignore"
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ # Project pecific */_version.py ``` ## /LICENSE ``` path="/LICENSE" Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) 
The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ``` ## /MANIFEST.in ```in path="/MANIFEST.in" prune examples ``` ## /README.md # DeTi*k*Zify
Synthesizing Graphics Programs for Scientific Figures and Sketches with Ti*k*Z [![OpenReview](https://img.shields.io/badge/View%20on%20OpenReview-8C1B13?labelColor=gray&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAB8AAAAgCAYAAADqgqNBAAAACXBIWXMAABMYAAATGAHPrZVxAAAAGXRFWHRTb2Z0d2FyZQB3d3cuaW5rc2NhcGUub3Jnm+48GgAABkJJREFUSIm9121wVNUZB/D//5y7uyHhxS4CClpBgRmcsTaDRWYKrRQTjJOiTgvm1XasCJtkNwWJo3RGWew4lryS3U1Cp2gFNmJwaG0hQxJoUjpOobVSK5YWSrFQB2RTUCQmu3vPefrBQlGZ0BTi8+m+zfObc85z7z0PRQRDiUfuqfdnqgHd2P5kIpBb8wAELRTkON70iXTa8wdSSps6q369LLd6Wktn1eHBcqkhyQC8xryaSnsOBfKemwyL8QDGi8K3jetkA7jRCvMCOXXZFB5auqDmzsFyOZfDls9ruGbASb8iwvaWXSvrBPIPgnPgOo8PmL4VPidrCok2a1WOAN2g3aCIAWuRdAxtIKdmNcgPmzsfq/10bl5u2pfOr75FKR4GAMKZ6nh0b9pNLYHgVhC9Yu0bSnGyABMhOOCK5xcf+Cf2jj77zzE/2bn8dFlOdbOQpNd90rFZprE9eHZQfOnCcKbqzzoIIOHQLkxb9R0q5NPr5jdtf+JM6N6IL+2m1hjLDZq2FcBMqzhDWdkNwRoBtSK+NsE9V/KeM+aLgLvQCqcS9ljzrqq1g+KLF4e9Y9/POgLBDUK+DStLNO37Aj2XYg4nHc+bHmNOQvCMFWml4joq1Om0ea8/xWOOj+OVUlMFkkHgWwTeorAnZcwJx3HGNXet2A9couAC89fO8p/JytOGcwRSpUSe1UpGCVS+QGotVYfjujmgzIPo59fvrjqixdYqY8U4qt6T6czSGg9S5ETS7etwxVMF4jpDu8rRDAC2hwQvOfJlOTV/JTCdImuadlU9/Yl7d9eWgXIfgVcATCQwG5aN0DJXhAmAPUqh1xB+r/IcPb++5fOaRibGjetva1tkAvPrJ4FusRb18oVqD4e3ev2n0vkkfwiRCMhMAAjk1GUrESe267Hfi8avlFW/tUj/i9DZIkhBibelc+WqT03g8YtPYt1l584fN+9e/m55bnWnodxyYeTrKuL1gN1SGS3d94llyKnpATDHVZjmWE5Xlh9aJaMAqWnuWnnbZwrmf4xg7tqJFBGEw2E1tvfmmyoipUcDOTV7AGQq2NJY1+MHy+dVX2cc3nn9nL5fnnwtKy4iZ1u6qpb+v+jF4QCA/9S0Qkv5KQAPSB9EZgpVKYBVse6qkwBeBYBld9c+D3LU1YCBiwouUh7Phajk4Xd697ruwOxz5OsbO1b2XS3oknhDRXwBRV7SMLMN9H6vsVOWNZeeGk70fCiP9bxmLe9wReUqqpmfFwwAKo1UoVLSBdF7rJg9DYEtkz83nMRCEamvbCr8kyO+m7/fXPBOY8XmxQ3lm2YMN64fnlzQ1jfaDOz83c8Svo/G7Ni7461xCuq0Q/vnWffe3j+cuDoxPvVzsfzj6MT0GUJWJweSPw7GirYZ6hcjoc1zhxUHsRdg4fJo4dvW49mTkeHbtq48vkSs7ARUcl2w9XvDhVNEEA3Gv2rBBcbajZp8yois0QpzIbQgRkDbdiedcbostujc5VMOAW+obJ1A1x4FMALAPgH2EbgDQHYoWpwJAJGK+DEInwJsUsPtLot99+RVwVevbvP6E8lKKOwNNpb8JhKK+MT4twuQHuH23WfViC8kNR+BsIvEyyJ40Pq8RxzX3BRsLNh/RfiFz2toS7aI2QHBNipuFGu/JGQeRW5XlAIDPbYyUtQRDm/1+hOpDguOrIwWfeVK8Av/czHmURDXA/J3SduPoNS1CvyLgFsNMIUiqxvL45v84HEq+4zPxYErgYGLtlEu9LMQfiMULamD5hOgFAajRT8IRYu2EHIXgFuhcCwYK4oL1aGUZslVw1fECo6HYkXdH7/bLAY46b8P2RcA9J65dlIbAFgj9wOsrS/fPPVK8M80DRSdAOwbAvlRONzjXJN492lSHVTEQ2N7j38ZwOvw6K10JeuGhPcoADSWtc7MMJ6/Pbp+0QdDwQdtGiIV8YcF2ABgn8cyL01ppz59V7AxmDz/TDQYv80K3gSkNRQtGdJSDNouCWAAGFDWJ7V8XQlnW+tf1VzW2pBW0gqwlUofgBgCH284hxKDNoqhaPGLxusdFYqUvKCEs/5zeUJaIR/APRB5KNhYsF+AbyrqiqHil28U6xb1A4AVu0lB3ShWP0ekR0KrXhIvAUBltHj7UGEA+DfT0uk46cX1OAAAAABJRU5ErkJggg==)](https://openreview.net/forum?id=bcVLFQCOjc) [![arXiv](https://img.shields.io/badge/View%20on%20arXiv-B31B1B?logo=arxiv&labelColor=gray)](https://arxiv.org/abs/2405.15306) [![Hugging 
Face](https://img.shields.io/badge/View%20on%20Hugging%20Face-blue?labelColor=gray&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAYAAABzenr0AAAAIGNIUk0AAHomAACAhAAA+gAAAIDoAAB1MAAA6mAAADqYAAAXcJy6UTwAAAAGYktHRAD/AP8A/6C9p5MAAAAJcEhZcwAALEsAACxLAaU9lqkAAAAHdElNRQfoBRILLggNuk3UAAAKNklEQVRYw42WeXCU9RnHP7/32CO7m2M3gYVwQwMptAqWVDwwrSLToaVVi7UFWg+mM5UZtdoZay1Oq70tM+rYc3RqhxioFlBRZERGSZSjtiEakgqSAwo5dpNNdrPnu+/7/vrHbxNQccbfzM6+8zue4/s8z/d5BJ+wHn/8cRKJBLNmzeKOO+7gxIkT1NXVTZ4fPHzMF41O87iOI3u6TxbWrP6SNXHW2NjI+Pg4d999N/F4nDVr1lBfX8+nXk8//TRSSlavXo2UEiklmBUiMW7PKzhyve24jzlW+iUnP/qmkx97wynmdxZd+ZucLb8xNJyaNvFGSsn9998PwIEDBy6qS3x0Y/v27SxdupSFCxdO7o1l7IVBn36Hnum+UcSOzCH+L53xHiimQOjgq4GqxTDlSktGlr1ve8LNiZHRbVNrqvoBhBBIKdm/fz/XX3/9J3u+c+dOOjo6Jq1v/sdOM1+UG9xU7/uy7edS/nORlE0eKZ83pNzjl/LVoJR7A1K+4JVyuy5lU1DKvddK2fOcdKzM4UzeWQUIKeWkjj179lwcgVOnThGJRKisrATgzbeO+q68fPl95tCBB2h7KEDqHagJQFUAvAYIcf61BBwXshbEUzCuwdwNyEt+OmR5Z96/YvmSbW3tx91FixYhpeTIkSOEw2EA9AkDYrEYe/bsYd26dTzyq98b3/nWN39ont31Mw7f6Ufrhrk1EC4DU79Y5EAT4DOgMgA+4MwhxHBXUI+uWHnrDx48bejieDKZpKGhgcHBQXbt2nUegebmZpLJJOvXrycUCpErcqM/9vrfeOv2coIJmBEGXShPP80SQLYIPTGo/hryyr/0pZzymyuCvncAWltbiUaj1NXVKQMGBgaYM2cO+XyegZHszKg59IJo2bAMu015rolPqfkjRmQsODUM9T/BuWTLC8faj3/3C5ddOr5p0yZCoRCPPfYYemdnJ/v27WP37t0IIfCa2p3GiSc30r8d5laDRz8vUCv9XwyJibMLzz26CvKZdrToyjmVsy97z9RF19q1a4nFYpimidbZ2UlLSwuaptE3mJlmpE/eQk+zSjifoYQJiCV0/v5KOfsOBSjaH0ek4wMvf9ldyXsnvec3JRAOgpGAU9t8Xs3e8Mqrr5UBdHZ2cu+996IFg0F27NgBQE048EVt6GA9Vi9UBZUAKUHCs6+Vs3V7mIefiXDspFd5XEJmLK3x22fDPP5cJb9rDjM6rivtEtA1qA5C/z60TN+KL664ug7g5ZdfJhqNovX29iKlRAiBR2avYPAtDwFNwVd0YHAUkhnSOY2qkEPdTAvb+TACtiOoqXSYX1skmxfYtoSRFMTHwHUh5IdiPyLRXhMI+C8D6Orq4u2338Z49913AWh68ZBfs0Y+S+p9CPpUnTsOZAugC265LsXVl+RYviiPx5Tgnoc5Uu7wy+/HOT1kEh/TqS63od9Syl0XTANMBxIdwpi1bokQQkgp5RNPPIE2NDQEQP2iujJhjUYpDIPXVJK9JkyPQLicz8ws0nhZlkDAxTQ+nIVCQJlfUj+vwMqlOYSuwZRKiIZB11VienXInEHIYu0PNt9lCiHo6+tDy2azAJT5fR7h5Py4BRW3Ccl+DxiqElLjGr1nTVx5QdZr6jea0uk760E6Jau8ZsmR0j1dg+I4QtrBJUs+pwPE43EMTVPKXNeVEk0KIUC6H683Af897eGRZyJc/fkcqxoyzKixyRY02k542X0wxKxokUc2DSM+ViSlDaEhEe6E047jYPh8PgBi8eH8ovmBFHoZWKOQLyqrDW3yfU2lQ74gaN5fzkuHAlSFXPKWYDSlYxUFy+vzaMYFdkvAclQlFR2oCCOFkdje3GTfcMMNhMNhjJ6eHoQQ/GjL1vxVP/5eXCubAed6wRWK96dXgO1CwWZWZYAVn8uzrzVIY05jNKsRAXqB/kqHxmVZSBdgJAchb6mKUqVmZcP8hbgY8bb/vOOMjY4gpcRoaWmhqqpKPPzA5m/q2f+uwHUg50JtBRRs6BlRVKwJjJzFbV/xUhczWHfCS8YVBITkXNDl5NeSLKnNwKmEUjicVihMCULBgREHcoN48mdvSmfzBwN+70tQ6tVDY9bsKbJ7n2jdsIhku/K4thJqgnAyBn5TfXcPw9wwuCFo80O/CX4XFudhYQGGkjCShXkR6B1R72aFoW8ERnMqnDNuxL3ir++ciRe+OmfG1JgB4PV6Zot490zGu2BeNeRtiKeVgOkVkMiq76AXhtKwwAerbMUFE/yfs2E4A5GAyh1Th2g5xMclliNYUA1FG+KHEfnYvEBg5nQgpikYnABC6DiaUlzug9lVCkah4CeehqkhNXTE0hKnlGQuUJTQn1L3qsqgPykJeiFTkGSLgrkRMDRJPC3BAKQBbtlED8Nv0sDAGz48BfCZcHZMoRAth1xRxdFrKASi5TCQEiTzcnIuiaUhmVNh8+gQ8goqfFB0BdFyGM1KBscFFWWCYj8i/q+KUDCwDJQ5SISB7gNXKu/DZTCSUTQ6Jag8C3iUx1OCqs/HxqHcC0VXITU1BBWqpKkOqLteAwZSEkMXzKyETKHUoHxIiXneAElmkjPSBWVI0KviPjGMTNS2pkGlH/qTAqdU366ECv957pq8KyASEGpWLIWzREyulFkQE01V2Hgj6mGZqRQ4bmn+u8gq2OcNm5Bp2XxsVtSEynyJSk5TgDDBUwFSFidzIDGaeEPOveUMs26FvjQMJJQs21UsdqGi8byafHMWfDAIfXFl0GBSseeFdx2pHJEOnI7DkAZLHsCdcsX7g/3/OwQgNm7cKLZt26YPJ1JrqwL8Qhs4UE93E8TeBmcYvEIlliYU3BkXpq6GaCNYKdAMMAPQswPG2yBoqjJ0XEVARQ08tTB9FSzYiBNpODacSD0QnVrz+j333OOKzs5Obr/9dnH06FGt/b3OBQs+U/dtn2Z9Q09/sJDh//hIvAeZ/4GTB281TFuJnLm2P69X75DSzSJxhaZV+XLd3xZndlcz/G8opsEMQmgeRJYiI8tybtns4zlL7jre8e5zKy5vOL127Vq3qalJ8uijj9La2sr69esnKMXY//obM4aGk1/P5J1fW0X7RdvKHbELmbaiVTiYKzhP9J05dxXgQY2c+ozaWt+5weHr8pbz56KVb7ULmWO2lT9cKDq70jn7kf6hkTUvvrRnWinptU2bNokjR46wdetWjK6uLkzTpKmpSS5fvpz9+/c7q677Uj9wDnjl1ttu8zY2ftkfDIb0s2fPWPfcfVcasC+99
FJRX18vpZS0t7dbtdHqA0JoLX/445+CU6ZO9STHRu19r76ae/755/IluhI333yzvOaaa+TmzZtpaGggFoshEokETz31FLNnz6a1tZUnn3ySo0ePsnfvXtHR0cGxY8dEb2/vZGJfe+21cuXKlfKhhx6ivr4ey7Lo7u5my5YtoqWlhZaWlslSmD9/PsuWLZOLFy+WDz74IKZpct9993HTTTdx7tw5PB4P/wdyObJGug0H9QAAACV0RVh0ZGF0ZTpjcmVhdGUAMjAyNC0wNS0xOFQxMTozMzozNCswMDowMNKO6kUAAAAldEVYdGRhdGU6bW9kaWZ5ADIwMjQtMDUtMThUMTE6MzI6MTArMDA6MDD6exqpAAAAKHRFWHRkYXRlOnRpbWVzdGFtcAAyMDI0LTA1LTE4VDExOjQ2OjA4KzAwOjAwcvvAdgAAAABJRU5ErkJggg==)](https://huggingface.co/collections/nllg/detikzify-664460c521aa7c2880095a8b) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1hPWqucbPGTavNlYvOBvSNBAwdcPZKe8F) Creating high-quality scientific figures can be time-consuming and challenging, even though sketching ideas on paper is relatively easy. Furthermore, recreating existing figures that are not stored in formats preserving semantic information is equally complex. To tackle this problem, we introduce [DeTi*k*Zify](https://github.com/potamides/DeTikZify), a novel multimodal language model that automatically synthesizes scientific figures as semantics-preserving [Ti*k*Z](https://github.com/pgf-tikz/pgf) graphics programs based on sketches and existing figures. We also introduce an MCTS-based inference algorithm that enables DeTi*k*Zify to iteratively refine its outputs without the need for additional training. https://github.com/potamides/DeTikZify/assets/53401822/203d2853-0b5c-4a2b-9d09-3ccb65880cd3 ## News * **2025-03-17**: We release [Ti*k*Zero](https://huggingface.co/nllg/tikzero-adapter) adapters which plug directly into [DeTi*k*Zifyv2 (8b)](https://huggingface.co/nllg/detikzify-v2-8b) and enable zero-shot text-conditioning, and [Ti*k*Zero+](https://huggingface.co/nllg/tikzero-plus-10b) with additional end-to-end fine-tuning. For more information see our [paper](https://arxiv.org/abs/2503.11509) and usage examples [below](#usage). * **2024-12-05**: We release [DeTi*k*Zifyv2 (8b)](https://huggingface.co/nllg/detikzify-v2-8b), our latest model which surpasses all previous versions in our evaluation and make it the new default model in our [Hugging Face Space](https://huggingface.co/spaces/nllg/DeTikZify). Check out the [model card](https://huggingface.co/nllg/detikzify-v2-8b-preview#model-card-for-detikzifyv2-8b) for more information. * **2024-09-24**: DeTi*k*Zify was accepted at [NeurIPS 2024](https://neurips.cc/Conferences/2024) as a [spotlight paper](https://neurips.cc/virtual/2024/poster/94474)! ## Installation > [!TIP] > If you encounter difficulties with installation and inference on your own > hardware, consider visiting our [Hugging Face > Space](https://huggingface.co/spaces/nllg/DeTikZify) (please note that > restarting the space can take up to 30 minutes). Should you experience long > queues, you have the option to > [duplicate](https://huggingface.co/spaces/nllg/DeTikZify?duplicate=true) it > with a paid private GPU runtime for a more seamless experience. Additionally, > you can try our demo on [Google > Colab](https://colab.research.google.com/drive/1hPWqucbPGTavNlYvOBvSNBAwdcPZKe8F). > However, setting up the environment there might take some time, and the free > tier only supports inference for the 1b models. The Python package of DeTi*k*Zify can be easily installed using [pip](https://pip.pypa.io/en/stable): ```sh pip install 'detikzify[legacy] @ git+https://github.com/potamides/DeTikZify' ``` The `[legacy]` extra is only required if you plan to use the DeTi*k*Zifyv1 models. 
If you only plan to use DeTi*k*Zifyv2, you can remove it. If your goal is to run the included [examples](examples), it is easier to clone the repository and install it in editable mode like this:

```sh
git clone https://github.com/potamides/DeTikZify
pip install -e DeTikZify[examples]
```

In addition, DeTi*k*Zify requires a full [TeX Live 2023](https://www.tug.org/texlive) installation, [ghostscript](https://www.ghostscript.com), and [poppler](https://poppler.freedesktop.org), which you have to install through your package manager or via other means.

## Usage

> [!TIP]
> For interactive use and general [usage tips](detikzify/webui#usage-tips),
> we recommend checking out our [web UI](detikzify/webui), which can be started
> directly from the command line (use `--help` for a list of all options):
> ```sh
> python -m detikzify.webui --light
> ```

If all required dependencies are installed, the full range of DeTi*k*Zify features, such as compiling, rendering, and saving Ti*k*Z graphics, and MCTS-based inference, can be accessed through its programming interface:
DeTikZify Example

```python
from operator import itemgetter

from detikzify.model import load
from detikzify.infer import DetikzifyPipeline

image = "https://w.wiki/A7Cc"
pipeline = DetikzifyPipeline(*load(
    model_name_or_path="nllg/detikzify-v2-8b",
    device_map="auto",
    torch_dtype="bfloat16",
))

# generate a single TikZ program
fig = pipeline.sample(image=image)

# if it compiles, rasterize it and show it
if fig.is_rasterizable:
    fig.rasterize().show()

# run MCTS for 10 minutes and generate multiple TikZ programs
figs = set()
for score, fig in pipeline.simulate(image=image, timeout=600):
    figs.add((score, fig))

# save the best TikZ program
best = sorted(figs, key=itemgetter(0))[-1][1]
best.save("fig.tex")
```
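`simulate` yields every rollout, so besides saving the best program you can also inspect how the individual candidates scored. The following is a minimal sketch building on the block above; it assumes that `rasterize()` returns a PIL image (as the `show()` call suggests), so the preview can be written with PIL's `save`:

```python
# print the score of each rollout (higher is better) and whether it compiles
for score, fig in sorted(figs, key=itemgetter(0)):
    print(f"score={score:.3f}, compiles={fig.is_rasterizable}")

# additionally store a PNG preview next to the saved TikZ source
if best.is_rasterizable:
    best.rasterize().save("fig.png")  # assumes a PIL image, cf. .show() above
```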
Through [Ti*k*Zero](https://huggingface.co/nllg/tikzero-adapter) adapters and [Ti*k*Zero+](https://huggingface.co/nllg/tikzero-plus-10b) it is also possible to synthesize graphics programs conditioned on text (cf. our [paper](https://arxiv.org/abs/2503.11509) for details). Note that this is currently only supported through the programming interface:
TikZero+ Example

```python
from detikzify.model import load
from detikzify.infer import DetikzifyPipeline

caption = "A multi-layer perceptron with two hidden layers."
pipeline = DetikzifyPipeline(*load(
    model_name_or_path="nllg/tikzero-plus-10b",
    device_map="auto",
    torch_dtype="bfloat16",
))

# generate a single TikZ program
fig = pipeline.sample(text=caption)

# if it compiles, rasterize it and show it
if fig.is_rasterizable:
    fig.rasterize().show()
```
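Since generation is stochastic and not every sampled program compiles, it can help to draw a few candidates and keep the first one that rasterizes. A minimal sketch reusing `pipeline` and `caption` from the block above (the output file name is just an example):

```python
# sample up to five candidate programs and keep the first one that compiles
for _ in range(5):
    fig = pipeline.sample(text=caption)
    if fig.is_rasterizable:
        fig.save("mlp.tex")
        break
```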
TikZero Example

```python
from detikzify.model import load, load_adapter
from detikzify.infer import DetikzifyPipeline

caption = "A multi-layer perceptron with two hidden layers."
pipeline = DetikzifyPipeline(
    *load_adapter(
        *load(
            model_name_or_path="nllg/detikzify-v2-8b",
            device_map="auto",
            torch_dtype="bfloat16",
        ),
        adapter_name_or_path="nllg/tikzero-adapter",
    )
)

# generate a single TikZ program
fig = pipeline.sample(text=caption)

# if it compiles, rasterize it and show it
if fig.is_rasterizable:
    fig.rasterize().show()
```
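The difference between the two variants is that Ti*k*Zero plugs a zero-shot adapter into DeTi*k*Zifyv2, whereas Ti*k*Zero+ is additionally fine-tuned end-to-end. Once a pipeline is set up either way, the sampling loop looks the same; the sketch below is purely illustrative (the caption list and file names are made up) and only reuses `pipeline.sample` from above:

```python
# generate one figure per caption and save the programs that compile
captions = [
    "A bar chart comparing three models on two benchmarks.",
    "A binary tree of depth three with labeled nodes.",
]
for idx, text in enumerate(captions):
    fig = pipeline.sample(text=text)
    if fig.is_rasterizable:
        fig.save(f"figure-{idx}.tex")
```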
More involved examples, e.g., for evaluation and training, can be found in the [examples](examples) folder.

## Model Weights & Datasets

We upload all our DeTi*k*Zify models and datasets to the [Hugging Face Hub](https://huggingface.co/collections/nllg/detikzify-664460c521aa7c2880095a8b) (Ti*k*Zero models are available [here](https://huggingface.co/collections/nllg/tikzero-67d1952fab69f5bd172de1fe)). However, please note that for the public release of the [DaTi*k*Zv2](https://huggingface.co/datasets/nllg/datikz-v2) and [DaTi*k*Zv3](https://huggingface.co/datasets/nllg/datikz-v3) datasets, we had to remove a considerable portion of Ti*k*Z drawings originating from [arXiv](https://arxiv.org), as the [arXiv non-exclusive license](https://arxiv.org/licenses/nonexclusive-distrib/1.0/license.html) does not permit redistribution. We do, however, release our [dataset creation scripts](https://github.com/potamides/DaTikZ) and encourage anyone to recreate the full version of DaTi*k*Z themselves.

## Citation

If DeTi*k*Zify and Ti*k*Zero have been beneficial for your research or applications, we kindly request you to acknowledge this by citing them as follows:

```bibtex
@inproceedings{belouadi2024detikzify,
  title={{DeTikZify}: Synthesizing Graphics Programs for Scientific Figures and Sketches with {TikZ}},
  author={Jonas Belouadi and Simone Paolo Ponzetto and Steffen Eger},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year={2024},
  url={https://openreview.net/forum?id=bcVLFQCOjc}
}
```

```bibtex
@misc{belouadi2025tikzero,
  title={{TikZero}: Zero-Shot Text-Guided Graphics Program Synthesis},
  author={Jonas Belouadi and Eddy Ilg and Margret Keuper and Hideki Tanaka and Masao Utiyama and Raj Dabre and Steffen Eger and Simone Paolo Ponzetto},
  year={2025},
  eprint={2503.11509},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2503.11509},
}
```

## Acknowledgments

The implementation of the DeTi*k*Zify model architecture is based on [LLaVA](https://github.com/haotian-liu/LLaVA) and [AutomaTikZ](https://github.com/potamides/AutomaTikZ) (v1), and [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (v2). Our MCTS implementation is based on [VerMCTS](https://github.com/namin/llm-verified-with-monte-carlo-tree-search). The Ti*k*Zero architecture draws inspiration from [Flamingo](https://deepmind.google/discover/blog/tackling-multiple-tasks-with-a-single-visual-language-model/) and [LLaMA 3.2-Vision](https://ai.meta.com/blog/llama-3-2-connect-2024-vision-edge-mobile-devices).

## /detikzify/__init__.py

```py path="/detikzify/__init__.py"
```

## /detikzify/dataset/__init__.py

```py path="/detikzify/dataset/__init__.py"
from datasets.load import load_dataset as _load_dataset
from os.path import dirname, isdir, join


def load_dataset(path, *args, **kwargs):
    if isdir(local := join(dirname(__file__), path)):
        return _load_dataset(local, *args, trust_remote_code=True, **kwargs)
    return _load_dataset(path, *args, **kwargs)
```

## /detikzify/dataset/paper2fig/__init__.py

```py path="/detikzify/dataset/paper2fig/__init__.py"
```

## /detikzify/dataset/paper2fig/paper2fig.py

```py path="/detikzify/dataset/paper2fig/paper2fig.py"
"""
Images from the Paper2Fig100k dataset.
""" from itertools import chain from json import load from os.path import basename import tarfile from datasets import Features, Image, Sequence, Value, builder from datasets.info import DatasetInfo from datasets.splits import Split, SplitGenerator from detikzify.util import convert, expand class Paper2FigConfig(builder.BuilderConfig): """BuilderConfig for Paper2Fig.""" def __init__(self, size, *args, **kwargs): super().__init__(*args, **kwargs) self.size = size self.archive = "https://zenodo.org/records/7299423/files/Paper2Fig100k.tar.gz" class Paper2Fig(builder.GeneratorBasedBuilder): """The Paper2Fig100k dataset in the format DeTikZify expects (everything is training data).""" BUILDER_CONFIG_CLASS = Paper2FigConfig def _info(self): features = { "caption": Value("string"), "mention": Sequence(Sequence(Value("string"))), "ocr": Sequence(Value("string")), "image": Image(), } return DatasetInfo( description=str(__doc__), features=Features(features), ) def _split_generators(self, dl_manager): archive = dl_manager.download(self.config.archive) # type: ignore return [SplitGenerator(name=str(Split.TRAIN), gen_kwargs=dict(archive=archive))] def _generate_examples(self, archive): with tarfile.open(archive) as tf: metadata = dict() for figdata in chain.from_iterable(load(tf.extractfile(f)) for f in tf if f.name.endswith(".json")): # type: ignore metadata[figdata.pop("figure_id")] = figdata for idx, member in enumerate(tf): if member.name.endswith(".png"): figure_id = basename(member.name).removesuffix(".png") figdata = metadata[figure_id] yield idx, dict( caption=figdata["captions"][0], mention=[figdata["captions"][1:]], ocr=[result['text'] for result in figdata['ocr_result']['ocr_result']], image=convert(expand(tf.extractfile(member), self.config.size), "png"), ) ``` ## /detikzify/dataset/scicap/__init__.py ```py path="/detikzify/dataset/scicap/__init__.py" ``` ## /detikzify/dataset/scicap/scicap.py ```py path="/detikzify/dataset/scicap/scicap.py" """ The SciCap dataset, unified in a single train split. 
""" from json import load from os import symlink from os.path import basename, join from subprocess import run from tempfile import TemporaryDirectory from zipfile import ZipFile from datasets import Features, Image, Sequence, Value, builder from datasets.info import DatasetInfo from datasets.splits import Split, SplitGenerator from datasets.utils.hub import hf_hub_url from detikzify.util import convert, expand class SciCapConfig(builder.BuilderConfig): """BuilderConfig for SciCap.""" def __init__(self, size, *args, **kwargs): super().__init__(*args, **kwargs) self.repo_id = "CrowdAILab/scicap" self.size = size self.files = { "img": { (public:="img-split"): 10, (hidden:="img-hide_test"): 0 }, "text": { "train": public, "train-acl": public, "val": public, "public-test": public, "hide_test": hidden, } } class SciCap(builder.GeneratorBasedBuilder): """The SciCap dataset in the format DeTikZify expects (everything is training data).""" BUILDER_CONFIG_CLASS = SciCapConfig def _info(self): features = { "caption": Value("string"), "mention": Sequence(Sequence(Value("string"))), "paragraph": Sequence(Value("string")), "ocr": Sequence(Value("string")), "image": Image(), } return DatasetInfo( description=str(__doc__), features=Features(features), ) def _split_generators(self, dl_manager): with TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname: def dl(path): return dl_manager.download(hf_hub_url(self.config.repo_id, path)) # type: ignore def zip_dl(path, num_splits=0): paths = [f"{path}.zip"] + list(f"{path}.z{{:02d}}".format(i+1) for i in range(num_splits)) downloaded = [dl(path) for path in paths] if num_splits: output = join(tmpdirname, f"{path}-joined.zip") for src, dst in zip(downloaded, paths): symlink(src, join(tmpdirname, dst)) # type: ignore run(["zip", "-FF", join(tmpdirname, paths[0]), "--out", output], check=True, capture_output=True) return output else: return downloaded[0] files_to_download = self.config.files # type: ignore img = {file:zip_dl(file, num_splits) for file, num_splits in files_to_download['img'].items()} text = {dl(f"{file}.json"):img[img_file] for file, img_file in files_to_download['text'].items()} yield SplitGenerator(name=str(Split.TRAIN), gen_kwargs={"shards": text}) def _generate_examples(self, shards): idx = 0 for path, image_zip in shards.items(): with ZipFile(file=image_zip, mode='r') as zf: imagemap = {basename(name):name for name in zf.namelist()} with open(path) as f: images, annotations = load(f).values() for annotation, image in zip(annotations, images): assert image["id"] == annotation['image_id'] with zf.open(imagemap[image['file_name']]) as img: yield idx, dict( caption=annotation.get("caption_no_index"), mention=annotation.get("mention"), paragraph=annotation.get("paragraph"), ocr=image.get("ocr"), image=convert(expand(img, self.config.size), "png") ) idx += 1 ``` ## /detikzify/evaluate/__init__.py ```py path="/detikzify/evaluate/__init__.py" # pyright: reportUnsupportedDunderAll=false from importlib import import_module from typing import Any from .imagesim import * # this metric is used by MCTS, so it is not optional __all__ = [ "ImageSim", "CrystalBLEU", "KernelInceptionDistance", "TexEditDistance", "DreamSim", "ClipScore", ] # lazy import optional metrics (https://peps.python.org/pep-0562/) def __getattr__(name) -> Any: def load(metric): return getattr(import_module("." 
+ metric, __name__), name) try: match name: case "CrystalBLEU": return load("crystalbleu") case "KernelInceptionDistance": return load("kid") case "TexEditDistance": return load("eed") case "DreamSim": return load("dreamsim") case "ClipScore": return load("clipscore") except ImportError: raise ValueError( "Missing dependencies: " "Install this project with the [evaluate] feature name!" ) return import_module("." + name, __name__) ``` ## /detikzify/evaluate/clipscore.py ```py path="/detikzify/evaluate/clipscore.py" from functools import cached_property from typing import List from PIL import Image import torch from torch.cuda import is_available as is_cuda_available, is_bf16_supported from torchmetrics import Metric from transformers import AutoModel, AutoProcessor from ..util import expand, infer_device, load class ClipScore(Metric): """Calculates CLIPScore which is a text-to-image similarity metric.""" higher_is_better = True def __init__( self, model_name: str = "google/siglip-so400m-patch14-384", preprocess: bool = True, device: str = infer_device(), dtype=torch.bfloat16 if is_cuda_available() and is_bf16_supported() else torch.float16, **kwargs ): super().__init__(**kwargs) self.model_name = model_name self.preprocess = preprocess self._device = device self.set_dtype(dtype) self.add_state("score", torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum") self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") def __str__(self): return self.__class__.__name__ @cached_property def model(self): model = AutoModel.from_pretrained(self.model_name, torch_dtype=self.dtype) return model.to(self.device) @cached_property def processor(self): return AutoProcessor.from_pretrained(self.model_name) def update( self, images: Image.Image | str | List[Image.Image | str], text: str | List[str] ): images = images if isinstance(images, List) else [images] text = text if isinstance(text, List) else [text] for img, txt in zip(images, text): img = load(img) if self.preprocess: img = expand(img, max(img.size), do_trim=True) with torch.inference_mode(): inputs = self.processor(text=txt, images=img, truncation=True, return_tensors="pt") outputs = self.model( input_ids=inputs.input_ids.to(self.device), pixel_values=inputs.pixel_values.to(self.device, self.dtype) ) self.score += torch.sigmoid(outputs.logits_per_image).item() self.n_samples += 1 def compute(self): return (self.score / self.n_samples).item() ``` ## /detikzify/evaluate/crystalbleu.py ```py path="/detikzify/evaluate/crystalbleu.py" from collections import Counter from functools import cached_property from hashlib import md5 from itertools import chain, tee from pickle import dump, load from typing import List from crystalbleu import corpus_bleu from datasets.utils.logging import get_logger from huggingface_hub import cached_assets_path from pygments.lexers.markup import TexLexer from pygments.token import Comment, Name, Text from sacremoses import MosesTokenizer from torchmetrics import Metric logger = get_logger("datasets") # adopted from nltk def pad_sequence(sequence, n, pad_left=False, pad_right=False, left_pad_symbol=None, right_pad_symbol=None): sequence = iter(sequence) if pad_left: sequence = chain((left_pad_symbol,) * (n - 1), sequence) if pad_right: sequence = chain(sequence, (right_pad_symbol,) * (n - 1)) return sequence # adopted from nltk def ngrams(sequence, n, **kwargs): sequence = pad_sequence(sequence, n, **kwargs) iterables = tee(sequence, n) for i, sub_iterable in enumerate(iterables): # For each 
window, for _ in range(i): # iterate through every order of ngrams next(sub_iterable, None) # generate the ngrams within the window. return zip(*iterables) # Unpack and flattens the iterables. class CrystalBLEU(Metric): """Wrapper around https://github.com/sola-st/crystalbleu (adapted for LaTeX)""" def __init__(self, corpus, k=500, n=4, use_cache=True, **kwargs): super().__init__(**kwargs) self.lexer = TexLexer() self.tokenizer = MosesTokenizer() self.use_cache = use_cache self.corpus = corpus self.k = k self.n = n self.add_state("list_of_references", [], dist_reduce_fx="cat") self.add_state("hypotheses", [], dist_reduce_fx="cat") def __str__(self): return self.__class__.__name__ @cached_property def trivially_shared_ngrams(self): """ Computes trivially shared ngrams and caches them. """ cache_dir = cached_assets_path(library_name="evaluate", namespace=self.__class__.__name__.lower()) dhash = md5() dhash.update(str(sorted(self.corpus)).encode()) hashname = f"{dhash.hexdigest()}.pkl" if (cache_file:=(cache_dir / hashname)).is_file() and self.use_cache: logger.info(f"Found cached trivially shared ngrams ({cache_file})") with open(cache_file, "rb") as f: return load(f) else: all_ngrams = list() for o in range(1, self.n+1): for tex in self.corpus: all_ngrams.extend(ngrams(self._tokenize(tex), o)) frequencies = Counter(all_ngrams) trivially_shared_ngrams = dict(frequencies.most_common(self.k)) if self.use_cache: logger.info(f"Caching trivially shared ngrams ({cache_file})") with open(cache_file, "wb") as f: dump(trivially_shared_ngrams, f) return trivially_shared_ngrams def _tokenize(self, text): tokens = list() for tokentype, value in self.lexer.get_tokens(text): if value.strip() and not tokentype is Comment: if any(tokentype is tp for tp in [Text, Name.Attribute, Name.Builtin]): tokens.extend(self.tokenizer.tokenize(value.strip())) else: tokens.append(value.strip()) return tokens def update( self, list_of_references: List[List[str]], hypotheses: List[str], ): assert len(list_of_references) == len(hypotheses) self.list_of_references.extend([self._tokenize(ref) for ref in refs] for refs in list_of_references) self.hypotheses.extend(self._tokenize(hyp) for hyp in hypotheses) def compute(self): return corpus_bleu( list_of_references=self.list_of_references, hypotheses=self.hypotheses, ignoring=self.trivially_shared_ngrams ) ``` ## /detikzify/evaluate/dreamsim.py ```py path="/detikzify/evaluate/dreamsim.py" from functools import cached_property from typing import List from PIL import Image from dreamsim import dreamsim from huggingface_hub import cached_assets_path import torch from torch.cuda import is_available as is_cuda_available, is_bf16_supported from torchmetrics import Metric from ..util import expand, infer_device, load class DreamSim(Metric): """Perceptual image similarity using DreamSim""" higher_is_better = True def __init__( self, model_name: str = "ensemble", pretrained: bool = True, normalize: bool = True, preprocess: bool = True, device: str = infer_device(), dtype=torch.bfloat16 if is_cuda_available() and is_bf16_supported() else torch.float16, **kwargs ): super().__init__(**kwargs) self.model_name = model_name self.pretrained = pretrained self.normalize = normalize self._device = device self.set_dtype(dtype) self.preprocess = preprocess self.add_state("score", torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum") self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") def __str__(self): return self.__class__.__name__ @cached_property def 
dreamsim(self): model, processor = dreamsim( dreamsim_type=self.model_name, pretrained = self.pretrained, normalize_embeds=self.normalize, device=str(self.device), cache_dir=str(cached_assets_path(library_name="evaluate", namespace=self.__class__.__name__.lower())) ) for extractor in model.extractor_list: extractor.model = extractor.model.to(self.dtype) extractor.proj = extractor.proj.to(self.dtype) return dict( model=model.to(self.dtype), processor=processor ) @property def model(self): return self.dreamsim['model'] @property def processor(self): return self.dreamsim['processor'] def update( self, img1: Image.Image | str | List[Image.Image | str], img2: Image.Image | str | List[Image.Image | str], ): if isinstance(img1, List) or isinstance(img2, List): assert type(img1) == type(img2) and len(img1) == len(img2) # type: ignore else: img1, img2 = [img1], [img2] for i1, i2 in zip(img1, img2): # type: ignore i1, i2 = load(i1), load(i2) if self.preprocess: i1 = expand(load(i1), max(i1.size), do_trim=True) i2 = expand(load(i2), max(i2.size), do_trim=True) i1 = self.processor(i1).to(self.device, self.dtype) i2 = self.processor(i2).to(self.device, self.dtype) with torch.inference_mode(): self.score += 1 - self.model(i1, i2).item() # type: ignore self.n_samples += 1 def compute(self): return (self.score / self.n_samples).item() ``` ## /detikzify/evaluate/eed.py ```py path="/detikzify/evaluate/eed.py" from pygments.lexers.markup import TexLexer from pygments.token import Comment, Text from torchmetrics.text import ExtendedEditDistance from torchmetrics.functional.text.eed import ( _compute_sentence_statistics, _preprocess_en, _preprocess_ja, ) from torchmetrics.functional.text.helper import _validate_inputs class TexEditDistance(ExtendedEditDistance): """Adapt torchmetrics ExtendedEditDistance for TeX""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.lexer = TexLexer() def __str__(self): return self.__class__.__name__ def _preprocess_sentences(self, preds, target, language): target, preds = _validate_inputs(hypothesis_corpus=preds, ref_corpus=target) def tokenize(text): tokens = list() for tokentype, value in self.lexer.get_tokens(text): if value.strip(): if tokentype is Text: if language == "en": preprocess_function = _preprocess_en elif language == "ja": preprocess_function = _preprocess_ja else: raise ValueError(f"Expected argument `language` to either be `en` or `ja` but got {language}") tokens.extend(preprocess_function(value).split()) elif not tokentype is Comment: tokens.extend(value.split()) return " " + " ".join(tokens) + " " preds = [tokenize(pred) for pred in preds] target = [[tokenize(ref) for ref in reference] for reference in target] return preds, target def update(self, preds, target): """Update state with predictions and targets.""" preds, target = self._preprocess_sentences(preds, target, self.language) if self.sentence_eed is None: self.sentence_eed = [] if 0 in (len(preds), len(target[0])): return self.sentence_eed for (hypothesis, target_words) in zip(preds, target): score = _compute_sentence_statistics( hypothesis, target_words, self.alpha, self.rho, self.deletion, self.insertion ) self.sentence_eed.append(score) return self.sentence_eed def compute(self, *args, **kwargs): return super().compute(*args, **kwargs).item() # type: ignore ``` ## /detikzify/evaluate/imagesim.py ```py path="/detikzify/evaluate/imagesim.py" from functools import cached_property from math import tanh from typing import List, Literal, Optional from PIL import Image from ot.lp 
import emd2 import torch from torch.cuda import is_available as is_cuda_available, is_bf16_supported import torch.nn.functional as F from torchmetrics import Metric from torchmetrics.functional import pairwise_cosine_similarity from transformers import AutoImageProcessor, AutoModel, PreTrainedModel, ProcessorMixin from ..model.adapter import ( AdapterProcessor, CrossAttentionAdapterMixin as AdapterMixin, has_adapter, ) from ..util import cast, expand, infer_device, load, unwrap_processor class ImageSim(Metric): """Perceptual image similarity using visual encoders.""" higher_is_better = True def __init__( self, model_name: str = "google/siglip-so400m-patch14-384", mode: Literal["emd", "cos", "cos_avg"] = "cos", preprocess: bool = True, device: str = infer_device(), dtype=torch.bfloat16 if is_cuda_available() and is_bf16_supported() else torch.float16, **kwargs ): super().__init__(**kwargs) self.model_name = model_name self.preprocess = preprocess self.mode = mode self._device = device self.set_dtype(dtype) self.add_state("score", torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum") self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") def __str__(self): return self.__class__.__name__ + f" ({self.mode.upper().replace('_', '-')})" @cached_property def model(self): # even if we instantiate with from_detikzify we still end up in this function if (model:=dict(self.named_children()).get("model")) is None: model = AutoModel.from_pretrained(self.model_name, torch_dtype=self.dtype) model = model.vision_model.to(self.device) return model @cached_property def processor(self): return AutoImageProcessor.from_pretrained(self.model_name) @classmethod def from_detikzify(cls, model: PreTrainedModel, processor: ProcessorMixin, mode=None, *args, **kwargs): derived_kwargs = dict( model_name = model.name_or_path, mode = getattr(model.config, "pooling_mode", "emd") if mode is None else mode, device = model.device, dtype = model.dtype, ) imagesim = cls(*args, **(derived_kwargs | kwargs)) if has_adapter(model): class AdapterVisionModel(type(model.model.vision_model), AdapterMixin): embedding_model=model.embedding_model adapter=model.adapter @classmethod def cast(cls, vision_model): adapter_vision_model = cast(cls, vision_model) adapter_vision_model.add_hooks() return adapter_vision_model imagesim.model = AdapterVisionModel.cast(model.model.vision_model) imagesim.processor = AdapterProcessor( processor=unwrap_processor(processor).image_processor, tokenizer=processor.tokenizer # type: ignore ) else: imagesim.model = model.model.vision_model imagesim.processor = unwrap_processor(processor).image_processor return imagesim def get_vision_features(self, image: Optional[Image.Image | str] = None, text: Optional[str] = None): if image is not None: image = load(image) if self.preprocess: image = expand(image, max(image.size), do_trim=True) with torch.inference_mode(): if text is not None: encoding = self.processor(text=text, images=image, return_tensors="pt").to(self.device, self.dtype) else: encoding = self.processor(images=image, return_tensors="pt").to(self.device, self.dtype) if self.mode == "cos": return self.model(**encoding).pooler_output.squeeze() elif self.mode == "cos_avg": return self.model(**encoding).last_hidden_state.squeeze().mean(dim=0) else: return self.model(**encoding).last_hidden_state.squeeze() def get_similarity( self, img1: Optional[Image.Image | str] = None, img2: Optional[Image.Image | str] = None, text1: Optional[str] = None, text2: Optional[str] = None, ): 
img1_feats = self.get_vision_features(img1, text1) img2_feats = self.get_vision_features(img2, text2) if img1_feats.is_mps: # mps backend does not support dtype double img1_feats, img2_feats = img1_feats.cpu(), img2_feats.cpu() if img1_feats.ndim > 1: dists = 1 - pairwise_cosine_similarity(img1_feats.double(), img2_feats.double()).cpu().numpy() return 2 * tanh(-emd2(M=dists, a=list(), b=list())) + 1 # type: ignore else: return F.cosine_similarity(img1_feats.double(), img2_feats.double(), dim=0).item() def update( self, img1: Optional[Image.Image | str | List[Image.Image | str]] = None, img2: Optional[Image.Image | str | List[Image.Image | str]] = None, text1: Optional[str | List[str]] = None, text2: Optional[str | List[str]] = None, ): inputs = dict() for key, value in dict(img1=img1, img2=img2, text1=text1, text2=text2).items(): if value is not None: inputs[key] = value if isinstance(value, List) else [value] assert not ({"img1", "text1"}.isdisjoint(inputs.keys()) or {"img2", "text2"}.isdisjoint(inputs.keys())) assert len(set(map(len, inputs.values()))) == 1 for inpt in zip(*inputs.values()): self.score += self.get_similarity(**dict(zip(inputs.keys(), inpt))) self.n_samples += 1 def compute(self): return (self.score / self.n_samples).item() ``` ## /detikzify/evaluate/kid.py ```py path="/detikzify/evaluate/kid.py" from functools import cached_property from typing import List from PIL import Image import torch from torch import nn from torch.cuda import is_available as is_cuda_available, is_bf16_supported from torchmetrics.image.kid import KernelInceptionDistance as KID from transformers import AutoModel, AutoImageProcessor from ..util import expand, infer_device, load class FeatureWrapper(nn.Module): def __init__(self, model_name, device, dtype): super().__init__() self.model_name = model_name self.device = device self.dtype = dtype @cached_property def model(self): model = AutoModel.from_pretrained(self.model_name, torch_dtype=self.dtype) return model.to(self.device) def forward(self, pixel_values): with torch.inference_mode(): return self.model.get_image_features(pixel_values.to(self.device, self.dtype)) class KernelInceptionDistance(KID): """Wrapper around torchmetrics Kernel Inception Distance with CLIP""" def __init__( self, model_name: str = "google/siglip-so400m-patch14-384", subset_size: int = 50, preprocess: bool = True, device: str = infer_device(), dtype=torch.bfloat16 if is_cuda_available() and is_bf16_supported() else torch.float16, **kwargs ): super().__init__( subset_size=subset_size, feature=FeatureWrapper( model_name=model_name, device=device, dtype=dtype), **kwargs ) self.preprocess = preprocess def __str__(self): return self.__class__.__name__ @cached_property def processor(self): return AutoImageProcessor.from_pretrained(self.inception.model_name) def open(self, img): img = load(img) if self.preprocess: return expand(img, max(img.size), do_trim=True) return img def update(self, imgs: Image.Image | str | List[Image.Image | str], *args, **kwargs): if not isinstance(imgs, List): imgs = [imgs] super().update( self.processor([self.open(img) for img in imgs], return_tensors="pt")["pixel_values"], *args, **kwargs ) def compute(self, *args, **kwargs): # type: ignore return tuple(tensor.item() for tensor in super().compute(*args, **kwargs)) ``` ## /detikzify/infer/__init__.py ```py path="/detikzify/infer/__init__.py" from .tikz import * from .generate import * ``` ## /detikzify/infer/generate.py ```py path="/detikzify/infer/generate.py" from collections import deque from 
dataclasses import dataclass from functools import cached_property from math import sqrt from multiprocessing.pool import ThreadPool from re import sub from time import time from types import SimpleNamespace as Namespace from typing import Any, Dict, Generator, List, Literal, Optional, Set, Tuple, Union from PIL import Image import torch from torchmetrics import Metric from transformers import StoppingCriteriaList from transformers.generation.streamers import BaseStreamer from ..evaluate.imagesim import ImageSim from ..mcts.montecarlo import MonteCarlo from ..mcts.node import Node from ..model.adapter import has_adapter from ..util import ( ExplicitAbort, StreamerList, TokenStreamer, cache_cast, expand, load, unwrap_processor as unwrap, ) from .tikz import TikzDocument Numeric = Union[int, float] @dataclass(frozen=True) class NodeState: token_ids: torch.Tensor num_lines: int = 0 def __eq__(self, other: Any) -> bool: try: return self.token_ids.equal(other.token_ids) except (AttributeError, TypeError): return False def __hash__(self): return hash(tuple(self.token_ids.tolist())) class WideNode(Node): state: NodeState def __init__(self, *args, exploration=0.6, is_widen_node=False, **kwargs): super().__init__(NodeState(*args, **kwargs)) self.discovery_factor = exploration self.is_widen_node = is_widen_node self.update_policy_value(1.0) if not is_widen_node: self.add_child(WideNode( *args, exploration=exploration, is_widen_node=not is_widen_node, **kwargs )) def add_child(self, child): self.expanded = self.expanded or not child.is_widen_node super().add_child(child) @property def depth(self) -> int: depth, current = 0, self while parent:=current.parent: depth, current = depth + 1, parent return depth @property def token_ids(self): return self.state.token_ids @property def num_lines(self): return self.state.num_lines class DynMinMaxNorm: def __init__(self, default_value: Numeric = 0): self.scores = set() self.default_value = default_value def normalize(self, score: Numeric) -> "MinMaxScore": self.scores.add(score) return self.MinMaxScore(score, all_scores=self.scores, default_value=self.default_value) def __call__(self, *args, **kwargs) -> "MinMaxScore": return self.normalize(*args, **kwargs) class MinMaxScore: def __init__( self, *scores: Numeric, all_scores: Set[Numeric], default_value: Numeric, no_minmax_scores: List[Numeric] = list(), ): self.scores = list(scores) self.all_scores = all_scores self.default_value = default_value self.no_minmax_scores = no_minmax_scores.copy() @property def score(self) -> Numeric: min_score, max_score = min(self.all_scores), max(self.all_scores) try: score = sum((score - min_score) / (max_score - min_score) for score in self.scores) except ZeroDivisionError: score = self.default_value return score + sum(self.no_minmax_scores) def __add__(self, other: Any) -> "DynMinMaxNorm.MinMaxScore": new = self.__class__( *self.scores, all_scores=self.all_scores, default_value=self.default_value, no_minmax_scores=self.no_minmax_scores ) try: new.scores.extend(other.scores) new.no_minmax_scores.extend(other.no_minmax_scores) except AttributeError: new.no_minmax_scores.append(other) return new def __mul__(self, other: Any) -> "DynMinMaxNorm.MinMaxScore": return self.score * other def __truediv__(self, other: Any) -> "DynMinMaxNorm.MinMaxScore": return self.score / other def __rtruediv__(self, other: Any) -> "DynMinMaxNorm.MinMaxScore": return other / self.score __radd__, __rmul__ = __add__, __mul__ class DetikzifyGenerator: def __init__( self, model, processor, image: 
Optional[Image.Image], text: Optional[str] = None, metric: Optional[Metric] = None, compile_timeout: Optional[int] = 60, mcts_timeout: Optional[int] = None, streamer: Optional[BaseStreamer] = None, control: Optional[ExplicitAbort] = None, exploration: float = 0.6, # exploration coefficient strict: bool = False, # if True, treat recoverable errors same as fatal errors when computing scores **gen_kwargs, ): self.model = model self.processor = processor self.metric = metric self.image = image self.text = text self.compile_timeout = compile_timeout self.mcts_timeout = mcts_timeout self.streamer = streamer self.exploration = exploration self.strict = strict self.gen_kwargs = gen_kwargs self.solution = deque(maxlen=1) self.failed_rollouts = dict() self.norm = DynMinMaxNorm() self.control = control or ExplicitAbort() self.montecarlo = MonteCarlo( root_node=WideNode( processor( images=self.image, text=self.text, return_tensors="pt", ).input_ids.to(model.device).squeeze(), exploration=self.exploration ) ) self.montecarlo.child_finder = self.child_finder # type: ignore # https://stackoverflow.com/a/68550238 self.decode = cache_cast(lambda token_ids: tuple(token_ids.tolist()))(self.decode) self.score = cache_cast(lambda image: image.tobytes())(self.score) def __call__(self, *args, **kwargs): return self.simulate(*args, **kwargs) def simulate(self, expansions: Optional[Numeric] = 1) -> Generator[Tuple[Numeric, TikzDocument], None, None]: """ Run the simulations. Yields all rollouts (successful and unsuccessful) as (score, TikZ picture) tuples. """ start_time = time() while expansions is None or (expansions:=expansions-1) >= 0: self.montecarlo.simulate() yield self.solution.pop() if self.mcts_timeout is not None and time() - start_time > self.mcts_timeout: return def generate(self, input_ids: torch.Tensor, streamer: Optional[BaseStreamer] = None, **gen_kwargs) -> torch.Tensor: streamers, numel = StreamerList(filter(bool, [streamer, self.streamer])), input_ids.numel() max_length = {**self.model.generation_config.to_dict(), **self.gen_kwargs, **gen_kwargs}["max_length"] if (numel and input_ids[-1] == unwrap(self.processor).tokenizer.eos_token_id) or numel >= max_length: streamers.end() return input_ids # prevent continuing generation after eos with torch.inference_mode(): token_ids = self.processor(images=self.image, text=self.text, text_kwargs={"truncation": True}, return_tensors="pt") adapter_kwargs = {k: v for k, v in token_ids.to(self.model.device).items() if k.startswith("adapter")} return self.model.generate( input_ids=input_ids.unsqueeze(0), bad_words_ids=[[self.model.config.image_token_id]], pixel_values=token_ids.get("pixel_values"), streamer=streamers, **adapter_kwargs, **self.gen_kwargs, **gen_kwargs ).squeeze() @cached_property def newlineinfo(self): # tokens can potentially contain multiple newlines, so we need special # handling when we want to map error lines to tokens newlineinfo = dict() for token_id in unwrap(self.processor).tokenizer.vocab.values(): # NOTE: Newline normalization might lead to inaccurate estimations # for windows line separators (if split over two tokens). However, # the likeliness of such tokens being generated (not in training # data) as well as the potential impact is negligible. 
# https://www.overleaf.com/learn/latex/Articles/An_introduction_to_%5Cendlinechar%3A_How_TeX_reads_lines_from_text_files token = sub(r"\r\n|\r", r"\n", self.processor.decode([token_id])) if (num_lines:=token.count("\n")): newlineinfo[token_id] = Namespace(num_lines=num_lines, trailing=token.endswith("\n")) assert newlineinfo return newlineinfo def rollout(self, state: NodeState) -> Generator[Tuple[torch.Tensor, int], None, None]: input_ids, num_lines, continuation = state.token_ids, state.num_lines, False with ThreadPool(processes=1) as thread: streamer = TokenStreamer() async_result = thread.apply_async( func=self.generate, error_callback=streamer.propagate_error, args=[input_ids], kwds=dict( stopping_criteria=StoppingCriteriaList([self.control.reset()]), streamer=streamer, ) ) try: prev_ids, line = input_ids, list() for token in streamer: line.append(token) if info:=self.newlineinfo.get(token): # continuations (newline followed by text in a single # token) don't appear in the llama 3.1 tokenizer, but # handling them now makes this code future proof num_lines += info.num_lines - continuation continuation = not info.trailing prev_ids = torch.cat((prev_ids, torch.tensor(line, device=prev_ids.device))) line.clear() yield prev_ids, num_lines if line: yield torch.cat((prev_ids, torch.tensor(line, device=prev_ids.device))), num_lines - continuation except (GeneratorExit, KeyboardInterrupt): self.control.abort() raise else: if self.control.should_stop: raise InterruptedError finally: async_result.wait() def decode(self, token_ids: torch.Tensor) -> TikzDocument: return TikzDocument( timeout=self.compile_timeout, code=self.processor.decode( token_ids=token_ids[len(self.montecarlo.root_node.token_ids):], skip_special_tokens=True ) ) def score(self, image: Image.Image) -> Numeric: assert self.metric self.metric.update(img1=image, img2=self.image, text2=self.text) score = self.metric.compute() self.metric.reset() return score def sample(self): return self.decode(self.generate( input_ids=self.montecarlo.root_node.token_ids, )) def child_finder(self, node: WideNode, montecarlo: MonteCarlo): new_nodes = list() for new_state in (rollout:=self.rollout(node.state)): new_node = WideNode(*new_state, exploration=self.exploration) if new_node.state in self.failed_rollouts: new_nodes.extend(self.failed_rollouts[new_node.state]) rollout.close() break new_nodes.append(new_node) if node.is_widen_node: node.visits += 1 node, new_nodes = self.merge(node.parent, new_nodes) # type: ignore tikz = self.decode((new_nodes or [node])[-1].token_ids) skip_idx = round(sqrt(len(new_nodes))) if scorable:=(tikz.is_rasterizable and not (self.strict and tikz.compiled_with_errors)): for new_node in new_nodes[:skip_idx]: node.add_child(node:=new_node) # Only process failed rollouts when we can locate the error (errorln != # 0). In rare cases there is no error information even though the # tikzpic is not rasterizable because only cropping failed -> use [0]. 
elif errorln:=min(tikz.errors or [0]): for idx, new_node in enumerate(new_nodes): ends_with_eol = self.newlineinfo.get(new_node.token_ids[-1]) if new_node.num_lines < errorln and idx < skip_idx: node.add_child(node:=new_node) elif new_node.num_lines > errorln or (new_node.num_lines == errorln and ends_with_eol): self.failed_rollouts[new_node.state] = new_nodes[idx:] break if self.metric: score = self.score(tikz.rasterize()) if scorable else -1 # type: ignore else: # if we do not have a metric, use compiler logs instead score = scorable - tikz.compiled_with_errors node.update_win_value(self.norm(score) if scorable and self.metric else score) self.solution.append((score, tikz)) def merge(self, node: WideNode, nodes_to_merge: List[WideNode]) -> Tuple[WideNode, List[WideNode]]: for merge_node in nodes_to_merge: for child in node.children: if child.state == merge_node.state: node, nodes_to_merge = child, nodes_to_merge[1:] break else: break return node, nodes_to_merge class DetikzifyPipeline: def __init__( self, model, processor, # hyperparams based on "a systematic evaluation of large language models of code" temperature: float = 0.8, top_p: float = 0.95, top_k: int = 0, compile_timeout: Optional[int] = 60, # same as old overleaf compile timeout metric: Union[Literal["model", "fast"], Metric] = "model", **gen_kwargs, ): self.model = model self.processor = processor if metric == "model": # SelfSim self.metric = ImageSim.from_detikzify(model, processor, sync_on_compute=False) elif metric == "fast": # Compiler Diagnostics self.metric = None else: self.metric = metric self.gen_kwargs: Dict[str, Any] = dict( temperature=temperature, top_p=top_p, top_k=top_k, max_length=unwrap(processor).tokenizer.model_max_length, do_sample=True, compile_timeout=compile_timeout, **gen_kwargs ) def load(self, image: Union[Image.Image, str], preprocess: bool = True): image = load(image) if preprocess: return expand(image, max(image.size), do_trim=True) return image def check_inputs(self, image, text): assert text is None or has_adapter(self.model), "You need to load an adapter for textual inputs!" assert image or text, "Either image or text (or both) required!" def sample( self, image: Optional[Union[Image.Image, str]] = None, text: Optional[str] = None, preprocess: bool = True, **gen_kwargs, ) -> TikzDocument: """ DeTikZify a raster image. Samples a single image and returns it. image: the image text: textual instruction preprocess: whether to preprocess the image (expand to square and trim to content) gen_kwargs: additional generation kwargs (potentially overriding the default ones) """ self.check_inputs(image, text) generator = DetikzifyGenerator( model=self.model, processor=self.processor, image=self.load(image, preprocess=preprocess) if image is not None else None, text=text, **self.gen_kwargs, **gen_kwargs ) return generator.sample() def simulate( self, image: Optional[Union[Image.Image, str]] = None, text: Optional[str] = None, preprocess: bool = True, expansions: Optional[Numeric] = None, timeout: Optional[int] = None, **gen_kwargs, ) -> Generator[Tuple[Numeric, TikzDocument], None, None]: """ DeTikZify a raster image using MCTS. Returns an iterator yielding (score, tikzdoc) tuples of TikZ documents created during rollouts. 
image: the image text: textual instruction preprocess: whether to preprocess the image (expand to square and trim to content) expansions: number of attempted MCTS expansions (set to None, 0 or math.inf for infinite) timeout: timeout for MCTS in seconds (set to 0, math.inf, or None for infinite) gen_kwargs: additional generation kwargs (potentially overriding the default ones) """ self.check_inputs(image, text) generator = DetikzifyGenerator( model=self.model, processor=self.processor, metric=self.metric, mcts_timeout=timeout or None, image=self.load(image, preprocess=preprocess) if image is not None else None, text=text, **self.gen_kwargs, **gen_kwargs ) yield from generator.simulate(expansions or None) def __call__(self, *args, **kwargs) -> TikzDocument: return self.sample(*args, **kwargs) ``` ## /detikzify/infer/tikz.py ```py path="/detikzify/infer/tikz.py" from collections import namedtuple from functools import cache, cached_property from io import BytesIO from os import environ from os.path import isfile, join from re import MULTILINE, escape, findall, search from subprocess import CalledProcessError, DEVNULL, TimeoutExpired from tempfile import NamedTemporaryFile, TemporaryDirectory from typing import Dict, Optional, Union from PIL import Image from pdf2image.pdf2image import convert_from_bytes from pdfCropMargins import crop import pymupdf from transformers.utils import logging from ..util import check_output, expand, redact as redact_text logger = logging.get_logger("transformers") class TikzDocument: """ Facilitate some operations with TikZ code. To compile the images a full TeXLive installation is assumed to be on the PATH. Cropping additionally requires Ghostscript, and rasterization needs poppler. """ # engines to try, could also try: https://tex.stackexchange.com/a/495999 engines = ["pdflatex", "lualatex", "xelatex"] Output = namedtuple("Output", ['pdf', 'status', 'log'], defaults=[None, -1, ""]) def __init__(self, code: str, timeout: Optional[int] = 60): self.code = code self.timeout = timeout # https://stackoverflow.com/a/68550238 self.compile = cache(self.compile) @property def status(self) -> int: return self.compile().status @property def pdf(self) -> Optional[pymupdf.Document]: # type: ignore return self.compile().pdf @property def log(self) -> str: return self.compile().log @property def compiled_with_errors(self) -> bool: return self.status != 0 @property def errors(self, rootfile: Optional[str] = None) -> Dict[int, str]: """ Returns a dict of (linenr, errormsg) pairs. linenr==0 is a special value reserved for errors that do not have a linenumber in rootfile. 
""" if self.compiled_with_errors: if not rootfile and (match:=search(r"^\((.+)$", self.log, MULTILINE)): rootfile = match.group(1) else: ValueError("rootfile not found!") errors = dict() for file, line, error in findall(r'^(.+):(\d+):(.+)$', self.log, MULTILINE): if file == rootfile: errors[int(line)] = error.strip() else: # error occurred in other file errors[0] = error.strip() return errors or {0: "Fatal error occurred, no output PDF file produced!"} return dict() @cached_property def is_rasterizable(self) -> bool: """true if we have an image""" return self.rasterize() is not None @cached_property def has_content(self) -> bool: """true if we have an image that isn't empty""" return (img:=self.rasterize()) is not None and img.getcolors(1) is None @classmethod def set_engines(cls, engines: Union[str, list]): cls.engines = [engines] if isinstance(engines, str) else engines def compile(self) -> "Output": output = dict() with TemporaryDirectory() as tmpdirname: with NamedTemporaryFile(dir=tmpdirname, buffering=0) as tmpfile: codelines = self.code.split("\n") # make sure we don't have page numbers in compiled pdf (for cropping) codelines.insert(1, r"{cmd}\AtBeginDocument{{{cmd}}}".format(cmd=r"\thispagestyle{empty}\pagestyle{empty}")) tmpfile.write("\n".join(codelines).encode()) try: # compile errorln, tmppdf, outpdf = -1, f"{tmpfile.name}.pdf", join(tmpdirname, "tikz.pdf") open(f"{tmpfile.name}.bbl", 'a').close() # some classes expect a bibfile def try_save_last_page(): try: doc = pymupdf.open(tmppdf) doc.select([len(doc)-1]) doc.save(outpdf) except: pass for engine in self.engines: try: check_output( cwd=tmpdirname, timeout=self.timeout, stderr=DEVNULL, env=environ | dict(max_print_line="1000"), # improve formatting of log args=["latexmk", "-f", "-nobibtex", "-norc", "-file-line-error", "-interaction=nonstopmode", f"-{engine}", tmpfile.name] ) except (CalledProcessError, TimeoutExpired) as proc: log = (getattr(proc, "output", b'') or b'').decode(errors="ignore") error = search(rf'^{escape(tmpfile.name)}:(\d+):.+$', log, MULTILINE) # only update status and log if first error occurs later than in previous engine if (linenr:=int(error.group(1)) if error else 0) > errorln: errorln = linenr output.update(status=getattr(proc, 'returncode', -1), log=log) try_save_last_page() else: output.update(status=0, log='') try_save_last_page() break # crop croppdf = f"{tmpfile.name}.crop" crop(["-gsf", "-c", "gb", "-p", "0", "-a", "-1", "-o", croppdf, outpdf], quiet=True) if isfile(croppdf): output['pdf'] = pymupdf.open(croppdf) except FileNotFoundError: logger.error("Missing dependencies: Did you install TeX Live?") except RuntimeError: # pdf error during cropping pass if output.get("status") == 0 and not output.get("pdf", None): logger.warning("Could compile document but something seems to have gone wrong during cropping!") return self.Output(**output) def rasterize(self, size=420, expand_to_square=True, redact=False, **redact_kwargs) -> Optional[Image.Image]: if pdf:=self.pdf: if redact: pdf = redact_text(pdf, **redact_kwargs) image = convert_from_bytes(pdf.tobytes(), size=size, single_file=True)[0] if expand_to_square: return expand(image, size) return image def save(self, filename: str, *args, **kwargs): match filename.split(".")[-1]: case "tex": content = self.code.encode() case "pdf" if self.pdf: content = self.pdf.tobytes() case fmt if img := self.rasterize(*args, **kwargs): img.save(imgByteArr:=BytesIO(), format=fmt) content = imgByteArr.getvalue() case fmt: raise ValueError(f"Couldn't save with 
format '{fmt}'!") with open(filename, "wb") as f: f.write(content) ``` ## /detikzify/mcts/LICENSE ``` path="/detikzify/mcts/LICENSE" The MIT License Copyright (c) 2010-2018 ImparaAI Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ``` ## /detikzify/mcts/README.md A Python3 library for running a [Monte Carlo tree search](https://en.wikipedia.org/wiki/Monte_Carlo_tree_search), either traditionally by drilling down to end game states or with expert policies as might be provided by a neural network. Adapted from **Version:** 1.3.1 of [ImparaAI/monte-carlo-tree-search](https://github.com/ImparaAI/monte-carlo-tree-search). # Monte Carlo tree search basics The Monte Carlo tree search (MCTS) algorithm can help with making a decision from a number of options. It avoids exploring every possible option by randomly sampling a small number of pathways and picking the move with the highest probability of victory. This is commonly applied to games like chess or go where it's useful to know what move should come next if you want to win the game. MCTS works by expanding the search tree to figure out which moves (or child/subsequent states) are likely to produce a positive result if chosen. While time is available, the algorithm continues to explore the tree, always slightly favoring the direction that has either proven to be fruitful or is less explored. When no time is left, the most explored direction is chosen. The search tree expansion can be done in two different ways: - **Traditional**: At least one random rollout to a game's end state (e.g. win, loss, tie) for each move under evaluation so the algorithm can make a choice. - **Expert policy (i.e. neural network)**: Instead of expensively rolling all the way out to a game's end state ask an expert (a neural network for example) which move is most likely to produce a positive outcome. For a deeper dive into the topic, check out [this article](http://tim.hibal.org/blog/alpha-zero-how-and-why-it-works/). # This library As the user of this library, you only have to provide: - A function that finds the direct children of each search tree node (called the **`child_finder`**) - A function for evaluating nodes for end state outcomes (called the **`node_evaluator`**) -- *(Not necessary with neural network)* # Usage Create a new Monte Carlo tree: ```python from chess import Game from montecarlo.node import Node from montecarlo.montecarlo import MonteCarlo chess_game = Game() montecarlo = MonteCarlo(Node(chess_game)) ``` The root node describes your current game state. 
This state will be used by you later in the **`child_finder`** and the **`node_evaluator`**. For the sake of demonstration, we will assume you have a generic `Game` library that can tell you what moves are possible and allows you to perform those moves to change the game's state. ## Traditional Monte Carlo Add a **`child_finder`** and a **`node_evaluator`**: ```python def child_finder(node, montecarlo): for move in node.state.get_possible_moves(): child = Node(deepcopy(node.state)) #or however you want to construct the child's state child.state.move(move) #or however your library works node.add_child(child) def node_evaluator(node, montecarlo): if node.state.won(): return 1 elif node.state.lost(): return -1 montecarlo.child_finder = child_finder montecarlo.node_evaluator = node_evaluator ``` The **`child_finder`** should add any child nodes to the parent node passed into the function, if there are any. If there are none, the parent should be in an end state, so the **`node_evaluator`** should return a value between `-1` and `1`. ## Expert policy (AI) If you have an expert policy that you can apply to the children as they're being generated, the library will recognize that it doesn't need to make the costly drill down to an end state. If your neural net produces both an expert policy value for the children and a win value for the parent node, you can skip declaring the `node_evaluator` altogether. ```python def child_finder(node, montecarlo): win_value, expert_policy_values = neural_network.predict(node.state) for move in node.state.get_possible_moves(): child = Node(deepcopy(node.state)) child.state.move(move) child.player_number = child.state.whose_turn() child.policy_value = get_child_policy_value(child, expert_policy_values) #should return a probability value between 0 and 1 node.add_child(child) node.update_win_value(win_value) montecarlo.child_finder = child_finder ``` ## Simulate and make a choice Run the simulations: ```python montecarlo.simulate(50) #number of expansions to run. higher is typically more accurate at the cost of processing time ``` Once the simulations have run you can ask the instance to make a choice: ```python chosen_child_node = montecarlo.make_choice() chosen_child_node.state.do_something() ``` After you've chosen a new root node, you can override it on the `montecarlo` instance and do more simulations from the new position in the tree. ```python montecarlo.root_node = montecarlo.make_choice() ``` If you're training a neural network, you may want to make a more exploratory choice for the first N moves of a game: ```python montecarlo.root_node = montecarlo.make_exploratory_choice() ``` This won't provide a purely random choice, rather it will be random with a bias favoring the more explored pathways. ## Turn-based environments If you are modeling a turn-based environment (e.g. a two player board game), set the `player_number` on each node so the selection process can invert child win values: ```python node = Node(state) node.player_number = 1 ``` It doesn't matter what this number is (you can use 1 and 2 or 5 and 6), only that it is consistent with other nodes. ## Tweaking the discovery factor When building a new child node, you can change the rate at which discovery is preferred: ```python node = Node(state) node.discovery_factor = 0.2 #0.35 by default, can be between 0 and 1 ``` The closer this number is to 1, the more discovery will be favored over demonstrated value in later simulations. 
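In this repository the same API is driven programmatically: `DetikzifyGenerator` in `detikzify/infer/generate.py` builds a `MonteCarlo` over a `WideNode` root and assigns its own `child_finder`. The snippet below is only a minimal, self-contained sketch of the traditional flow using the vendored module paths (`detikzify.mcts`); the counting game itself is made up purely for illustration, and only the `MonteCarlo`/`Node` API calls come from the library.

```python
from detikzify.mcts.montecarlo import MonteCarlo
from detikzify.mcts.node import Node

TARGET = 10  # toy game: start at 0, add 1 or 2 per move; exactly 10 wins, overshooting loses

def child_finder(node, montecarlo):
    # the state here is just an integer, so no deepcopy is needed
    for step in (1, 2):
        node.add_child(Node(node.state + step))

def node_evaluator(node, montecarlo):
    if node.state == TARGET:
        return 1
    elif node.state > TARGET:
        return -1
    # returning None means "not an end state", so the library keeps rolling out

montecarlo = MonteCarlo(Node(0))
montecarlo.child_finder = child_finder
montecarlo.node_evaluator = node_evaluator
montecarlo.simulate(50)
print(montecarlo.make_choice().state)  # the most visited first move
```

DeTikZify itself follows the expert-policy variant instead: its `child_finder` scores complete rollouts with an image-similarity metric (or compiler diagnostics) and feeds the result to `update_win_value` directly, as seen in `detikzify/infer/generate.py` above.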
## /detikzify/mcts/__init__.py ```py path="/detikzify/mcts/__init__.py" ``` ## /detikzify/mcts/montecarlo.py ```py path="/detikzify/mcts/montecarlo.py" import random import time class MonteCarlo: def __init__(self, root_node, mins_timeout=None): self.root_node = root_node self.solution = None self.child_finder = None self.node_evaluator = lambda child, montecarlo: None self.stats_expansion_count = 0 self.stats_failed_expansion_count = 0 self.mins_timeout = mins_timeout def make_choice(self): best_children = [] most_visits = float("-inf") for child in self.root_node.children: if child.visits > most_visits: most_visits = child.visits best_children = [child] elif child.visits == most_visits: best_children.append(child) return random.choice(best_children) def make_exploratory_choice(self): children_visits = map(lambda child: child.visits, self.root_node.children) children_visit_probabilities = [ visit / self.root_node.visits for visit in children_visits ] random_probability = random.uniform(0, 1) probabilities_already_counted = 0.0 for i, probability in enumerate(children_visit_probabilities): if probabilities_already_counted + probability >= random_probability: return self.root_node.children[i] probabilities_already_counted += probability def simulate(self, expansion_count=1): i = 0 start_time = time.time() while expansion_count is None or i < expansion_count: i += 1 if self.solution is not None: return if self.mins_timeout is not None: curr_time = time.time() duration = curr_time - start_time if duration > (self.mins_timeout * 60): print("reached timelimit, stopping expansion on current node") return current_node = self.root_node while current_node.expanded: current_node = current_node.get_preferred_child(self.root_node) self.expand(current_node) def expand(self, node): self.stats_expansion_count += 1 self.child_finder(node, self) for child in node.children: child_win_value = self.node_evaluator(child, self) if child_win_value != None: child.update_win_value(child_win_value) if not child.is_scorable(): self.random_rollout(child) child.children = [] if len(node.children): node.expanded = True else: self.stats_failed_expansion_count += 1 def random_rollout(self, node): self.child_finder(node, self) child = random.choice(node.children) node.children = [] node.add_child(child) child_win_value = self.node_evaluator(child, self) if child_win_value != None: node.update_win_value(child_win_value) else: self.random_rollout(child) def print_tree(self, f): f.write("graph\n{\n") self.root_node.print_node(f, 0, self.root_node, "a") f.write("}\n") ``` ## /detikzify/mcts/node.py ```py path="/detikzify/mcts/node.py" import random import json from math import log, sqrt class Node: def __init__(self, state): self.state = state self.win_value = 0 self.policy_value = None self.visits = 0 self.parent = None self.children = [] self.expanded = False self.player_number = None self.discovery_factor = 0.35 self.is_widen_node = False def update_win_value(self, value): self.win_value += value self.visits += 1 if self.parent: self.parent.update_win_value(value) def update_policy_value(self, value): self.policy_value = value def add_child(self, child): self.children.append(child) child.parent = self def add_children(self, children): for child in children: self.add_child(child) def get_preferred_child(self, root_node): best_children = [] best_score = float("-inf") for child in self.children: score = child.get_score(root_node) if score > best_score: best_score = score best_children = [child] elif score == best_score: 
best_children.append(child) return random.choice(best_children) def get_score(self, root_node): discovery_operand = ( self.discovery_factor * (self.policy_value or 1) * sqrt(log(self.parent.visits) / (self.visits or 1)) ) if self.is_widen_node: win_operand = 0 else: win_multiplier = ( 1 if self.parent.player_number == root_node.player_number else -1 ) win_operand = win_multiplier * self.win_value / (self.visits or 1) self.score = win_operand + discovery_operand return self.score def is_scorable(self): return self.visits or self.policy_value != None def print_node(self, f, i, root, st): escape = lambda x : json.dumps(x).strip('"') if self.parent is None: f.write((' ' * i) + st + " [label=\"" + escape(self.state) + "\",shape=box]\n") else: diff = '\n'.join([x for x in self.state.split("\n") if x not in self.parent.state.split("\n")]) f.write((' ' * i) + st + " [label=\"" + escape(diff) + "\",shape=box]\n") num = 0 for child in self.children: new_st = st + "_" + str(num) child.print_node(f, i + 2, root, new_st) f.write(' ' * i + st + " -- " + new_st + "\n") num = num + 1 ``` ## /detikzify/model/__init__.py ```py path="/detikzify/model/__init__.py" from datasets import DownloadManager from safetensors.torch import load_file from transformers.utils.hub import has_file from transformers import ( AutoConfig, AutoModelForVision2Seq, AutoProcessor, is_timm_available, ) from transformers.utils.hub import is_remote_url from .configuration_detikzify import * from .modeling_detikzify import * from .processing_detikzify import * from .adapter import load as load_adapter if is_timm_available(): from .v1 import models as v1_models, load as load_v1 def register(): try: AutoConfig.register("detikzify", DetikzifyConfig) AutoModelForVision2Seq.register(DetikzifyConfig, DetikzifyForConditionalGeneration) AutoProcessor.register(DetikzifyConfig, DetikzifyProcessor) except ValueError: pass # already registered def load(model_name_or_path, modality_projector=None, is_v1=False, **kwargs): # backwards compatibility with v1 models if is_timm_available() and (is_v1 or model_name_or_path in v1_models): # type: ignore model, tokenizer, image_processor = load_v1( # type: ignore model_name_or_path=model_name_or_path, modality_projector=modality_projector, **kwargs ) return model, DetikzifyProcessor( tokenizer=tokenizer, image_processor=image_processor, image_seq_len=model.config.num_patches, image_token=tokenizer.convert_ids_to_tokens(model.config.patch_token_id) ) register() processor = AutoProcessor.from_pretrained(model_name_or_path) model = AutoModelForVision2Seq.from_pretrained(model_name_or_path, **kwargs) if modality_projector is not None: if is_remote_url(modality_projector): modality_projector = DownloadManager().download(modality_projector) model.load_state_dict( state_dict=load_file( filename=modality_projector, # type: ignore device=str(model.device) ), strict=False ) if has_file(model_name_or_path, "adapter/model.safetensors"): model, processor = load_adapter(model=model, processor=processor) return model, processor ``` ## /detikzify/model/adapter/__init__.py ```py path="/detikzify/model/adapter/__init__.py" from transformers import AutoTokenizer from .modeling_adapter import CrossAttentionAdapterMixin from .processing_adapter import AdapterProcessor def has_adapter(model): return hasattr(model, "adapter") def load(model, processor, adapter_name_or_path=None, **kwargs): embedding_model = "meta-llama/Llama-3.2-1B" model.load_cross_attn_adapter(embedding_model, adapter_name_or_path, **kwargs) processor = 
AdapterProcessor( processor=processor, tokenizer=AutoTokenizer.from_pretrained( embedding_model, pad_token="<|finetune_right_pad_id|>", model_max_length=512, ), ) model.embedding_model.config.pad_token_id = processor.tokenizer.pad_token_id return model, processor ``` ## /detikzify/model/adapter/modeling_adapter.py ```py path="/detikzify/model/adapter/modeling_adapter.py" # coding=utf-8 # Copyright 2024 the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Adapted from modelling_mllama.py and modelling_siglip.py # https://github.com/huggingface/transformers/commit/2e24ee4dfa39cc0bc264b89edbccc373c8337086 from functools import partial from os.path import basename from typing import Optional, Tuple import torch from torch import nn from transformers import AutoModel, PreTrainedModel, PretrainedConfig from transformers.activations import ACT2FN from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from transformers.utils import ( is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, ) if is_flash_attn_2_available(): from transformers.modeling_flash_attention_utils import _flash_attention_forward logger = logging.get_logger(__name__) class CrossAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__( self, config: Optional[PretrainedConfig] = None, ): super().__init__() self.config = config self.embed_dim = config.hidden_size self.num_heads = self.config.num_attention_heads self.head_dim = config.hidden_size // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads})." 
) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_norm = nn.LayerNorm(self.head_dim, eps=config.layer_norm_eps) self.k_norm = nn.LayerNorm(self.head_dim, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, cross_attention_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" batch_size, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(cross_attention_states) value_states = self.v_proj(cross_attention_states) query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) query_states = self.q_norm(query_states) key_states = self.k_norm(key_states) k_v_seq_len = key_states.shape[-2] attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): raise ValueError( f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" f" {attn_weights.size()}" ) if attention_mask is not None: if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): raise ValueError( f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, attn_weights class CrossSdpaAttention(CrossAttention): """ Attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from `CrossAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to SDPA API. """ # Adapted from CrossAttention.forward def forward( self, hidden_states: torch.Tensor, cross_attention_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( "Using CrossSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. 
This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( hidden_states=hidden_states, cross_attention_states=cross_attention_states, attention_mask=attention_mask, output_attentions=output_attentions, ) batch_size, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(cross_attention_states) value_states = self.v_proj(cross_attention_states) query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) query_states = self.q_norm(query_states) key_states = self.k_norm(key_states) # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and attention_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=False, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(batch_size, q_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, None class CrossFlashAttention2(CrossAttention): """ CrossAttention flash attention module. This module inherits from `CrossAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward def forward( self, hidden_states: torch.Tensor, cross_attention_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: output_attentions = False batch_size, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(cross_attention_states) value_states = self.v_proj(cross_attention_states) query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) query_states = self.q_norm(query_states) key_states = self.k_norm(key_states) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) dropout_rate = self.dropout if self.training else 0.0 # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. input_dtype = query_states.dtype if input_dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype else: target_dtype = self.q_proj.weight.dtype logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}." 
) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) if attention_mask is not None and attention_mask.all(): # FIXME: figure out why all 1 attention mask leads to different results attention_mask = None attn_output = _flash_attention_forward( query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, is_causal=False, use_top_left_mask=self._flash_attn_uses_top_left_mask, ) attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous() attn_output = self.out_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights CROSS_ATTENTION_CLASSES = { "eager": CrossAttention, "sdpa": CrossSdpaAttention, "flash_attention_2": CrossFlashAttention2, } class MLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.config = config self.activation_fn = ACT2FN[config.hidden_act] self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = self.fc2(hidden_states) return hidden_states class CrossAttentionLayer(torch.nn.Module): """Cross-attention transformer block with sigmoid-gated attention and feedforward.""" def __init__(self, config: PretrainedConfig) -> None: super().__init__() self.embed_dim = config.hidden_size self.cross_attn = CROSS_ATTENTION_CLASSES[config._attn_implementation](config=config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = MLP(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1)) self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) def forward( self, hidden_states: torch.Tensor, cross_attention_states: torch.Tensor, cross_attention_mask: torch.Tensor, attention_mask: torch.Tensor, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states, attn_weights = self.cross_attn( hidden_states=hidden_states, attention_mask=cross_attention_mask, cross_attention_states=cross_attention_states, output_attentions=output_attentions, ) hidden_states = residual + self.cross_attn_attn_gate.sigmoid() * hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + self.cross_attn_mlp_gate.sigmoid() * hidden_states outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs class CrossAttentionAdapter(PreTrainedModel): base_model_prefix = "model" no_split_modules = ["CrossAttentionLayer"] supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True def __init__(self, config, input_hidden_size, cross_attn_every_n_layers=1): super().__init__(config) self.num_patches = (config.image_size // config.patch_size) ** 2 self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.layers = nn.ModuleList([ # type: ignore CrossAttentionLayer(config) if (layer_idx + 1) % cross_attn_every_n_layers == 0 else None for layer_idx in range(config.num_hidden_layers) ]) self.connector = nn.Linear( input_hidden_size, config.hidden_size, bias=True ) self.dummy_input = nn.Parameter( 
torch.ones( config.num_channels, config.image_size, config.image_size ) ) self.post_init() def connect(self, inputs): return self.connector(inputs) def prepare_4d_attention_mask(self, attention_mask, dtype): if attention_mask is not None and not self._use_flash_attention_2: # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] return _prepare_4d_attention_mask(attention_mask, dtype, self.num_patches) return attention_mask # pyright: reportAttributeAccessIssue=false class CrossAttentionAdapterMixin: def init_cross_attn_adapter( self, model_or_model_name_or_path, cross_attn_every_n_layers: Optional[int] = 1, embedding_kwargs: dict = dict(), adapter_kwargs: dict = dict(), **kwargs, ): self.embedding_model = self.load_embedding_model( model_or_model_name_or_path, **embedding_kwargs, **kwargs ).to(self.device) self.adapter = CrossAttentionAdapter._from_config( input_hidden_size=self.embedding_model.config.hidden_size, cross_attn_every_n_layers=cross_attn_every_n_layers, config=getattr(self.config, "vision_config", self.config), torch_dtype=self.dtype, **adapter_kwargs, **kwargs ).to(self.device, self.dtype) self.add_hooks() def load_cross_attn_adapter( self, model_or_model_name_or_path, adapter_name_or_path: Optional[str] = None, cross_attn_every_n_layers: Optional[int] = 1, embedding_kwargs: dict = dict(), adapter_kwargs: dict = dict(), **kwargs, ): self.embedding_model = self.load_embedding_model( model_or_model_name_or_path, **embedding_kwargs, **kwargs ) if adapter_name_or_path is not None: self.adapter = CrossAttentionAdapter.from_pretrained( pretrained_model_name_or_path=adapter_name_or_path, input_hidden_size=self.embedding_model.config.hidden_size, cross_attn_every_n_layers=cross_attn_every_n_layers, config=getattr(self.config, "vision_config", self.config), torch_dtype=self.dtype, **adapter_kwargs, **kwargs ).to(self.dtype) else: self.adapter = CrossAttentionAdapter.from_pretrained( pretrained_model_name_or_path=self.config.name_or_path, input_hidden_size=self.embedding_model.config.hidden_size, cross_attn_every_n_layers=cross_attn_every_n_layers, config=getattr(self.config, "vision_config", self.config), subfolder="adapter", torch_dtype=self.dtype, **adapter_kwargs, **kwargs ).to(self.dtype) if "device_map" not in kwargs: self.embedding_model = self.embedding_model.to(self.device) self.adapter = self.adapter.to(self.device) self.add_hooks() def load_embedding_model(self, model_or_model_name_or_path, **model_kwargs): if isinstance(model_or_model_name_or_path, str): model = AutoModel.from_pretrained( pretrained_model_name_or_path=model_or_model_name_or_path, torch_dtype=self.dtype, **model_kwargs ) else: model = model_or_model_name_or_path return model.to(self.dtype) def add_hooks(self): handles, adapter_inputs, cross_attention = list(), dict(), dict() for name, module in self.named_modules(): if "vision_model" in name and type(module) == nn.ModuleList: vision_layers = module break else: raise ValueError("Couldn't locate vision encoder layers!") # HACK: convert args to kwargs def args_to_kwargs(module, args): return dict(zip(type(module).forward.__code__.co_varnames[1:], args)) def forward_hook(layer, args, kwargs): if kwargs.get("image_hidden_states") is not None: # we are in .generate method which calls forward iteratively for key in ["adapter_input_ids", "adapter_attention_mask"]: kwargs.pop(key, None) elif (adapter_input_ids:=kwargs.pop("adapter_input_ids", None)) is not None: if not hasattr(self, "adapter"): raise ValueError("Got `adapter_input_ids` but no adapter is 
loaded!") adapter_inputs.update( input_ids=adapter_input_ids, attention_mask=kwargs.pop("adapter_attention_mask", None) ) if (kwargs | args_to_kwargs(layer, args)).get("pixel_values") is None: dummy_input = self.adapter.dummy_input.clamp(-1, 1) kwargs['pixel_values'] = dummy_input.repeat(len(adapter_input_ids), 1, 1, 1) return args, kwargs for layer, cross_layer in zip(vision_layers, self.adapter.layers): if cross_layer is not None: def layer_hook(cross_layer, layer, args, kwargs): if adapter_inputs: embeddings = self.embedding_model(**adapter_inputs).last_hidden_state cross_attention.update( cross_attention_states=self.adapter.connect(embeddings), cross_attention_mask=self.adapter.prepare_4d_attention_mask( adapter_inputs["attention_mask"], embeddings.dtype )) adapter_inputs.clear() if cross_attention: kwargs |= args_to_kwargs(layer, args) kwargs['hidden_states'] = cross_layer(**cross_attention, **kwargs)[0] return tuple(), kwargs handles.append(layer.register_forward_pre_hook(partial(layer_hook, cross_layer), with_kwargs=True)) handles.append(self.register_forward_pre_hook(forward_hook, with_kwargs=True)) handles.append(self.register_forward_hook(lambda *_: cross_attention.clear())) self.handles = handles def unload_cross_attn_adapter(self): for handle in self.handles: handle.remove() del self.adapter, self.embedding_model, self.handles def save_cross_attn_adapter(self, *args, **kwargs): return self.adapter.save_pretrained(*args, **kwargs) def has_adapter(self): return hasattr(self, "adapter") ``` ## /detikzify/model/adapter/processing_adapter.py ```py path="/detikzify/model/adapter/processing_adapter.py" from typing import List, Optional, TYPE_CHECKING, Union, Unpack from transformers.feature_extraction_utils import BatchFeature from transformers.image_utils import ImageInput, make_list_of_images from transformers.processing_utils import ProcessingKwargs, ProcessorMixin from transformers.tokenization_utils_base import ( BatchEncoding, PreTokenizedInput, TextInput, ) from transformers.utils import logging from ...util import DUMMY_IMAGE if TYPE_CHECKING: from transformers.tokenization_utils_base import PreTokenizedInput logger = logging.get_logger(__name__) class AdapterProcessor(ProcessorMixin): attributes = ["processor", "tokenizer"] processor_class = ("ProcessorMixin", "ImageProcessingMixin") tokenizer_class = "AutoTokenizer" def __init__(self, processor, tokenizer=None, **kwargs): if processor is None: raise ValueError("You need to specify a `processor`.") if tokenizer is None: raise ValueError("You need to specify a `tokenizer`.") super().__init__(processor, tokenizer, **kwargs) def __call__( self, text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, images: Optional[ImageInput] = None, **kwargs: Unpack[ProcessingKwargs], ) -> BatchEncoding: if images is None and text is None: raise ValueError("Either `images` or `text` (or both) are expected as arguments to an `AdapterProcessor` instance.") text_kwargs, images_kwargs = kwargs.pop("text_kwargs", {}), kwargs.pop("images_kwargs", {}) if text is None: text_inputs = dict() else: text = [text] if isinstance(text, str) else text text_inputs = {f"adapter_{key}": value for key, value in self.tokenizer(text=text, **kwargs, **text_kwargs).items()} if getattr(self.processor, "model_expects_text", False): images_kwargs.update(text=text, add_bos_token=True) if images is None: image_inputs = self.processor(images=len(text) * [DUMMY_IMAGE], **kwargs, **images_kwargs) image_inputs = dict((k, 
image_inputs[k]) for k in ["input_ids", "attention_mask"] if k in image_inputs) else: images = make_list_of_images(images) image_inputs = self.processor(images=images, **kwargs, **images_kwargs) if text is not None and images is not None and len(images) != len(text): raise ValueError( f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image." ) return BatchFeature(data={**image_inputs, **text_inputs}) def batch_decode(self, *args, **kwargs): return self.processor.batch_decode(*args, **kwargs) def decode(self, *args, **kwargs): return self.processor.decode(*args, **kwargs) @property def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names processor_input_names = self.processor.model_input_names return list(dict.fromkeys(tokenizer_input_names + processor_input_names)) ``` ## /detikzify/model/configuration_detikzify.py ```py path="/detikzify/model/configuration_detikzify.py" # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Adapted from # https://github.com/huggingface/transformers/commit/e1b150862e66e16acf951edfa13206ffcd1032be import os from typing import Union from transformers import CONFIG_MAPPING from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging logger = logging.get_logger(__name__) class DetikzifyVisionConfig(PretrainedConfig): model_type = "detikzify" def __init__( self, hidden_size=1152, intermediate_size=4304, num_hidden_layers=27, num_attention_heads=16, num_channels=3, image_size=420, patch_size=14, hidden_act="gelu_pytorch_tanh", layer_norm_eps=1e-6, attention_dropout=0.0, initializer_range=0.02, **kwargs, ): super().__init__(**kwargs) self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.num_channels = num_channels self.patch_size = patch_size self.image_size = image_size self.attention_dropout = attention_dropout self.layer_norm_eps = layer_norm_eps self.hidden_act = hidden_act self.initializer_range = initializer_range @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": cls._set_token_in_kwargs(kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) # get the vision config dict if we are loading from DetikzifyConfig if config_dict.get("model_type") == "detikzify": config_dict = config_dict["vision_config"] if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: logger.warning( f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." 
) return cls.from_dict(config_dict, **kwargs) class DetikzifyConfig(PretrainedConfig): model_type = "detikzify" is_composition = True def __init__( self, use_cache=True, image_token_id=128005, tie_word_embeddings=False, vision_config=None, text_config=None, concat_factor=3, pad_token_id=128004, **kwargs, ): self.image_token_id = image_token_id self.use_cache = use_cache self.tie_word_embeddings = tie_word_embeddings if vision_config is None: self.vision_config = DetikzifyVisionConfig() logger.info("vision_config is None, using default vision config") elif isinstance(vision_config, dict): self.vision_config = DetikzifyVisionConfig(**vision_config) elif isinstance(vision_config, DetikzifyVisionConfig): self.vision_config = vision_config if isinstance(text_config, dict): text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) elif text_config is None: logger.info("text_config is None, using default text config") text_config = CONFIG_MAPPING["llama"]( rms_norm_eps=1e-5, pad_token_id=pad_token_id, tie_word_embeddings=False, ) self.text_config = text_config self.concat_factor = concat_factor super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings) ``` ## /detikzify/model/modeling_detikzify.py ```py path="/detikzify/model/modeling_detikzify.py" # Copyright 2024 the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# # Adapted from # https://github.com/huggingface/transformers/commit/e1b150862e66e16acf951edfa13206ffcd1032be from dataclasses import dataclass from typing import List, Optional, Tuple, Union import torch from torch import nn from torch.nn import CrossEntropyLoss import torch.utils.checkpoint from transformers import ( AutoModel, Cache, DynamicCache, GenerationMixin, PreTrainedModel, SiglipVisionModel, ) from transformers.modeling_outputs import ModelOutput from transformers.utils import logging from .adapter import CrossAttentionAdapterMixin from .configuration_detikzify import DetikzifyConfig logger = logging.get_logger(__name__) @dataclass class DetikzifyBaseModelOutputWithPast(ModelOutput): last_hidden_state: torch.FloatTensor = None past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None @dataclass class DetikzifyCausalLMOutputWithPast(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None past_key_values: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None class DetikzifySimpleMLP(nn.Module): def __init__(self, config): super().__init__() input_size = config.vision_config.hidden_size * config.concat_factor output_size = config.text_config.hidden_size self.proj = nn.Linear(input_size, output_size, bias=False) def forward(self, x): return self.proj(x) class DetikzifyConnector(nn.Module): def __init__(self, config): super().__init__() self.concat_factor = config.concat_factor self.modality_projection = DetikzifySimpleMLP(config) def concatenate(self, x, concat_factor=3): bsz, seq, embed_dim = x.size() return x.reshape(bsz, seq // concat_factor, embed_dim * concat_factor) def forward(self, image_hidden_states): image_hidden_states = self.concatenate(image_hidden_states, self.concat_factor) image_hidden_states = self.modality_projection(image_hidden_states) return image_hidden_states class DetikzifyPreTrainedModel(PreTrainedModel): config_class = DetikzifyConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = [] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True def _init_weights(self, module): std = ( self.config.initializer_range if hasattr(self.config, "initializer_range") else self.config.text_config.initializer_range ) if hasattr(module, "class_embedding"): module.class_embedding.data.normal_(mean=0.0, std=std) if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() class DetikzifyModel(DetikzifyPreTrainedModel): def __init__(self, config: DetikzifyConfig): super().__init__(config) self.padding_idx = self.config.text_config.pad_token_id self.vocab_size = self.config.text_config.vocab_size self.vision_model = SiglipVisionModel._from_config(config.vision_config) self.connector = DetikzifyConnector(config) self.text_model = AutoModel.from_config(config.text_config) self.image_seq_len = int( ((config.vision_config.image_size // 
config.vision_config.patch_size) ** 2) / (config.concat_factor) ) self.image_token_id = self.config.image_token_id self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2" self.post_init() def enable_input_require_grads(self): def get_lowest_module(module): if len(list(module.children())) == 0: # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.) return module else: # Recursively call the function on each child module return get_lowest_module(list(module.children())[0]) def make_inputs_require_grads(module, input, output): output.requires_grad_(True) self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook( make_inputs_require_grads ) def disable_input_require_grads(self): self._text_require_grads_hook.remove() self._vision_require_grads_hook.remove() def get_input_embeddings(self): return self.text_model.get_input_embeddings() def set_input_embeddings(self, value): self.text_model.set_input_embeddings(value) def inputs_merger( self, input_ids: torch.LongTensor, inputs_embeds: Optional[torch.Tensor], image_hidden_states: Optional[torch.Tensor], ): num_images, _, vision_hidden_size = image_hidden_states.shape special_image_token_mask = input_ids == self.image_token_id # Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation. new_inputs_embeds = inputs_embeds.clone() reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size) # cast to the dtype of the input_embeds to support quantized models reshaped_image_hidden_states = reshaped_image_hidden_states.to(inputs_embeds.dtype) new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states return new_inputs_embeds def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, image_hidden_states: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, DetikzifyBaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.training and self.text_model.gradient_checkpointing and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False # retrieve input_ids and inputs_embeds if input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: raise ValueError("You have to specify either input_ids or inputs_embeds") past_seen_tokens = 0 if use_cache: if past_key_values is None: past_key_values = DynamicCache() past_seen_tokens = past_key_values.get_seq_length() if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") if inputs_embeds is None: inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device) # START VISUAL INPUTS INTEGRATION if pixel_values is not None and image_hidden_states is not None: raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") elif pixel_values is not None: # Get sequence from the vision encoder image_hidden_states = self.vision_model( pixel_values=pixel_values.to(dtype=self.dtype), # fp16 compatibility ).last_hidden_state # Modality projection image_hidden_states = self.connector(image_hidden_states) elif image_hidden_states is not None: image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None: # When we generate, we don't want to replace the potential image_token_id that we generated by images # that simply don't exist inputs_embeds = self.inputs_merger( input_ids=input_ids, inputs_embeds=inputs_embeds, image_hidden_states=image_hidden_states, ) outputs = self.text_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) if not return_dict: return tuple(v for v in [*outputs, image_hidden_states] if v is not None) return DetikzifyBaseModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_hidden_states, ) class DetikzifyForConditionalGeneration(DetikzifyPreTrainedModel, GenerationMixin, CrossAttentionAdapterMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) self.model = DetikzifyModel(config) self.image_token_id = self.config.image_token_id self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.vocab_size = config.text_config.vocab_size # Initialize weights and apply final processing self.post_init() def enable_input_require_grads(self): def make_inputs_require_grads(module, input, output): output.requires_grad_(True) self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook( make_inputs_require_grads ) def disable_input_require_grads(self): self._text_require_grads_hook.remove() self._vision_require_grads_hook.remove() def get_input_embeddings(self): return self.model.text_model.get_input_embeddings() def set_input_embeddings(self, value): self.model.text_model.set_input_embeddings(value) def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): 
self.lm_head = new_embeddings def tie_weights(self): output_embeddings = self.get_output_embeddings() input_embeddings = self.get_input_embeddings() if getattr(self.config, "tie_word_embeddings", True): output_embeddings.weight = input_embeddings.weight def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, image_hidden_states: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, DetikzifyCausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, pixel_values=pixel_values, image_hidden_states=image_hidden_states, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() loss = None if labels is not None: labels = labels.to(logits.device) # Shift so that tokens < n predict n if attention_mask is not None: # we use the input attention mask to shift the logits and labels, because it is 2D. 
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous() shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous() else: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return DetikzifyCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=outputs.image_hidden_states, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, pixel_values=None, image_hidden_states=None, num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens if past_key_values is not None: if inputs_embeds is not None: # Exception 1 input_ids = input_ids[:, -cache_position.shape[0] :] elif input_ids.shape[1] != cache_position.shape[0]: input_ids = input_ids[:, cache_position] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step # but IDEFICS requires both ids and embeds to be present if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids} else: # The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if num_logits_to_keep is not None: model_inputs["num_logits_to_keep"] = num_logits_to_keep if image_hidden_states is not None: pixel_values = None # support model.generate method with adapters if self.has_adapter(): for key in ["adapter_input_ids", "adapter_attention_mask"]: if key in kwargs: model_inputs[key] = kwargs[key] model_inputs.update( { "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "pixel_values": pixel_values, "image_hidden_states": image_hidden_states, } ) return model_inputs def _validate_model_kwargs(self, model_kwargs): # support model.generate method with adapters if self.has_adapter(): for key in ['adapter_input_ids', 'adapter_attention_mask']: model_kwargs.pop(key, None) super()._validate_model_kwargs(model_kwargs) def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs): model_kwargs = super()._update_model_kwargs_for_generation( outputs=outputs, model_kwargs=model_kwargs, is_encoder_decoder=is_encoder_decoder, **kwargs, ) # Get the precomputed image_hidden_states model_kwargs["image_hidden_states"] = outputs.image_hidden_states return model_kwargs ``` ## /detikzify/model/processing_detikzify.py ```py path="/detikzify/model/processing_detikzify.py" # Copyright 2024 the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
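# NOTE: DetikzifyProcessor below pairs every image with a (possibly empty) text prompt: it rejects
# prompts that already contain the image token, optionally appends the tokenizer's BOS/EOS tokens,
# prepends `image_seq_len` copies of `image_token` (by default 300 copies of
# "<|reserved_special_token_2|>"), and returns the merged image-processor and tokenizer outputs as a
# single BatchFeature. A minimal usage sketch, assuming a saved checkpoint directory (the path below
# is a placeholder):
#
#   processor = DetikzifyProcessor.from_pretrained("path/to/detikzify-checkpoint")
#   batch = processor(images=pil_image, text="", return_tensors="pt")
#   # -> BatchFeature with pixel_values, input_ids and attention_mask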
# # Adapted from # https://github.com/huggingface/transformers/commit/e1b150862e66e16acf951edfa13206ffcd1032be from typing import List, Optional, Union, Unpack from transformers.feature_extraction_utils import BatchFeature from transformers.image_utils import ImageInput, make_list_of_images from transformers.processing_utils import ProcessingKwargs, ProcessorMixin from transformers.tokenization_utils_base import ( BatchEncoding, PreTokenizedInput, TextInput, ) from transformers.utils import logging logger = logging.get_logger(__name__) class DetikzifyProcessorKwargs(ProcessingKwargs, total=False): _defaults = { "text_kwargs": { "add_special_tokens": False, "padding": False, }, } class DetikzifyProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" def __init__( self, image_processor, tokenizer=None, image_seq_len: int = 300, image_token: str = "<|reserved_special_token_2|>", model_expects_text: bool = False, **kwargs, ): if image_processor is None: raise ValueError("You need to specify an `image_processor`.") if tokenizer is None: raise ValueError("You need to specify a `tokenizer`.") if image_token not in tokenizer.vocab: raise ValueError(f"{image_token} needs to be added to the `tokenizer` vocabulary.") self.image_token = image_token self.image_seq_len = image_seq_len self.model_expects_text = model_expects_text super().__init__(image_processor, tokenizer, **kwargs) def __call__( self, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, images: ImageInput = None, image_seq_len: Optional[int] = None, add_bos_token: bool = None, add_eos_token: bool = None, **kwargs: Unpack[DetikzifyProcessorKwargs], ) -> BatchEncoding: output_kwargs = self._merge_kwargs( DetikzifyProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) # Temporary fix for "padding_side" in init_kwargs output_kwargs["text_kwargs"].pop("padding_side", None) if images is None: raise ValueError("`images` are expected as arguments to a `DetikzifyProcessor` instance.") else: images = make_list_of_images(images) if text is None: text = len(images) * [""] elif isinstance(text, str): text = [text] if len(images) != len(text): raise ValueError( f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image." ) prompt_strings = [] for prompt in text: assert self.image_token not in prompt, "Image tokens are added by the processor!" 
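            # the flags below append BOS/EOS to the prompt first; the image tokens are then
            # prepended in front of the (possibly extended) prompt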
if add_bos_token: prompt += self.tokenizer.bos_token if add_eos_token: prompt += self.tokenizer.eos_token image_seq_len = image_seq_len if image_seq_len is not None else self.image_seq_len prompt_strings.append((self.image_token * image_seq_len) + prompt) image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"]) return BatchFeature(data={**image_inputs, **text_inputs}) def batch_decode(self, *args, **kwargs): return self.tokenizer.batch_decode(*args, **kwargs) def decode(self, *args, **kwargs): return self.tokenizer.decode(*args, **kwargs) @property def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names image_processor_input_names = self.image_processor.model_input_names return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) ``` ## /detikzify/model/v1/__init__.py ```py path="/detikzify/model/v1/__init__.py" from datasets import DownloadManager from transformers import AutoConfig, AutoModel from transformers import AutoTokenizer, PretrainedConfig from transformers.utils.hub import is_remote_url from .configuration_detikzify import * from .modeling_detikzify import * from .processing_detikzify import * models = [ "nllg/detikzify-ds-1.3b", "nllg/detikzify-ds-7b", "nllg/detikzify-tl-1.1b", "nllg/detikzify-cl-7b", ] def register(): try: AutoConfig.register("detikzify", DetikzifyConfig) AutoModel.register(DetikzifyConfig, DetikzifyForCausalLM) except ValueError: pass # already registered def load(model_name_or_path, vision_tower="vit_so400m_patch14_siglip_384.webli", modality_projector=None, **kwargs): base_tokenizer = PretrainedConfig.from_pretrained(model_name_or_path).name_or_path or model_name_or_path tokenizer = AutoTokenizer.from_pretrained( pretrained_model_name_or_path=base_tokenizer, model_max_length=2048, add_bos_token=False, add_eos_token=True, pad_token="", padding_side="right", # NOTE: only for training, need to change to "left" for batched inference legacy=False ) model = DetikzifyForCausalLM.from_pretrained( pretrained_model_name_or_path=model_name_or_path, use_cache=True, **kwargs ) model.config.model_type = DetikzifyConfig.model_type # type: ignore model.generation_config.pad_token_id = tokenizer.pad_token_id # type: ignore if len(tokenizer) > model.config.vocab_size: # type: ignore model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8) # type: ignore if modality_projector and is_remote_url(modality_projector): modality_projector = DownloadManager().download(modality_projector) processor = model.get_model().initialize_vision_modules( # type: ignore patch_token_id=tokenizer.bos_token_id, modality_projector=modality_projector, vision_tower=getattr(model.config, "vision_tower", vision_tower), # type: ignore feature_layer=getattr(model.config, "feature_layer", -1), # type: ignore concat_patches=getattr(model.config, "concat_patches", 3) # type: ignore ) return model, tokenizer, processor ``` ## /detikzify/model/v1/configuration_detikzify.py ```py path="/detikzify/model/v1/configuration_detikzify.py" from transformers import LlamaConfig class DetikzifyConfig(LlamaConfig): model_type = "detikzify" # compatibility with new inference code @property def image_token_id(self): return self.patch_token_id @property def pooling_mode(self): return "cos" ``` ## /detikzify/model/v1/modeling_detikzify.py ```py path="/detikzify/model/v1/modeling_detikzify.py" # Adopted from 
https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava.py. Below is the original copyright: # Copyright 2023 Haotian Liu # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from pickle import UnpicklingError from typing import List, Optional, Tuple, Union from numpy import clip from safetensors.torch import load_file from timm import create_model as create_vision_model import torch import torch.nn as nn from torch.nn import CrossEntropyLoss import torch.nn.functional as F from transformers import ( BatchEncoding, LlamaConfig, LlamaForCausalLM, LlamaModel, PretrainedConfig, PreTrainedModel, ) from transformers.modeling_outputs import ( BaseModelOutputWithPoolingAndNoAttention, BaseModelOutputWithPast, CausalLMOutputWithPast, ) from transformers.utils import logging from .configuration_detikzify import DetikzifyConfig from .processing_detikzify import DetikzifyImageProcessor logger = logging.get_logger("transformers") class DetikzifyVisionModel(PreTrainedModel): _no_split_modules = ["VisionTransformer"] def __init__(self, model, **kwargs) -> None: super().__init__(PretrainedConfig.from_dict(model.pretrained_cfg), **kwargs) # HACK: wrap in list so that vision model does not count as a parameter self.model = [model] def get_input_embeddings(self) -> torch.nn.Module: return self.model[0].patch_embed def to_input_dtype(self, pixel_values: torch.Tensor): target_dtype = self.get_input_embeddings().proj.weight.dtype return pixel_values.to(dtype=target_dtype) def forward(self, pixel_values: torch.Tensor): last_hidden_state = self.model[0].forward_features(self.to_input_dtype(pixel_values)) pooler_output = self.model[0].forward_head(last_hidden_state) return BaseModelOutputWithPoolingAndNoAttention( last_hidden_state=last_hidden_state, pooler_output=pooler_output ) def get_intermediate_layers(self, pixel_values: torch.Tensor, *args, **kwargs): return self.model[0].get_intermediate_layers(self.to_input_dtype(pixel_values), *args, **kwargs) class DetikzifyModel(LlamaModel): config_class = DetikzifyConfig def __init__(self, config: LlamaConfig): super(DetikzifyModel, self).__init__(config) if getattr(config, "use_mm_proj"): self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size) self.vision_model = DetikzifyVisionModel(create_vision_model(config.vision_tower)) def initialize_vision_modules( self, vision_tower, patch_token_id, concat_patches=3, feature_layer=-1, modality_projector=None, **kwargs ): vision_model = create_vision_model(vision_tower, pretrained=True, **kwargs) self.vision_model = DetikzifyVisionModel(vision_model.to(self.device, self.dtype).eval().requires_grad_(False)) processor = DetikzifyImageProcessor.from_pretrained(vision_tower) self.config.use_mm_proj = True self.config.vision_tower = vision_tower self.config.mm_hidden_size = vision_model.embed_dim * concat_patches self.config.patch_token_id = patch_token_id self.config.concat_patches = concat_patches self.config.feature_layer = int(clip(feature_layer, -(depth:=len(vision_model.blocks)), depth-1) % depth) 
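        # (the statement above clips `feature_layer` into [-depth, depth - 1] and maps it to a
        # non-negative block index via `% depth`, so negative layer indices are supported)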
self.config.vision_config = processor.to_dict() # type: ignore self.config.num_patches = vision_model.patch_embed.num_patches // concat_patches if not hasattr(self, 'mm_projector'): self.mm_projector = nn.Linear( self.config.mm_hidden_size, self.config.hidden_size, dtype=self.dtype, device=self.device ) if modality_projector is not None: try: # first try to load as pickle mm_projector_weights = torch.load(modality_projector, map_location=self.device) except UnpicklingError: # and if that fails we try safetensors mm_projector_weights = load_file(modality_projector, device=str(self.device)) self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()}) return processor # https://stackoverflow.com/a/57208704 def _apply(self, fn): super()._apply(fn) if hasattr(self, "vision_model"): self.set_vision_model = self.vision_model._apply(fn) return self def get_vision_features(self, pixel_values): concat, n_patch, layer = self.config.concat_patches, self.config.num_patches, self.config.feature_layer feats = self.vision_model.get_intermediate_layers(pixel_values, n=[layer], norm=True)[0] # in case the number of feature vectors is not divisible by the number # of patches we want to concatenate, we remove the first feature(s) return feats[:, -n_patch * concat:].reshape(-1, n_patch, feats.shape[-1] * concat) def is_tensor(self, thing): if isinstance(thing, (BatchEncoding, dict)): return all(isinstance(v, torch.Tensor) for v in thing.values()) return isinstance(thing, torch.Tensor) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, pixel_values: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if hasattr(self, "vision_model") and (input_ids.shape[1] != 1 or self.training) and pixel_values is not None: with torch.no_grad(): image_features = self.get_vision_features(pixel_values) image_features = self.mm_projector(image_features) dummy_image_features = torch.zeros(len(image_features[0]), self.config.mm_hidden_size, device=inputs_embeds.device, dtype=inputs_embeds.dtype) dummy_image_features = self.mm_projector(dummy_image_features) new_input_embeds = [] cur_image_idx = 0 for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds): if (cur_input_ids == self.config.image_token_id).sum() == 0: # multimodal LLM, but the current sample is not multimodal cur_input_embeds = cur_input_embeds + (0. 
* dummy_image_features).sum() new_input_embeds.append(cur_input_embeds) cur_image_idx += 1 continue cur_image_features = image_features[cur_image_idx].to(cur_input_embeds.device) num_patches = cur_image_features.shape[0] if (cur_input_ids == self.config.image_token_id).sum() != num_patches: raise ValueError("The number of image patch tokens should be the same as the number of image patches.") masked_indices = torch.where(cur_input_ids == self.config.image_token_id)[0] mask_index_start = masked_indices[0] if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any(): raise ValueError("The image patch tokens should be consecutive.") cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0) new_input_embeds.append(cur_new_input_embeds) cur_image_idx += 1 inputs_embeds = torch.stack(new_input_embeds, dim=0) return super(DetikzifyModel, self).forward( input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) class DetikzifyForCausalLM(LlamaForCausalLM): config_class = DetikzifyConfig def __init__(self, config): super(LlamaForCausalLM, self).__init__(config) self.model = DetikzifyModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_model(self): return self.model def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, pixel_values: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, pixel_values=pixel_values ) hidden_states = outputs[0] if self.config.pretraining_tp > 1: lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) else: logits = self.lm_head(hidden_states) logits = logits.float() loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model/pipeline parallelism shift_labels = 
shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): if past_key_values: input_ids = input_ids[:, -1:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "pixel_values": kwargs.get("pixel_values", None), } ) return model_inputs ``` ## /detikzify/model/v1/processing_detikzify.py ```py path="/detikzify/model/v1/processing_detikzify.py" # Adopted from https://github.com/huggingface/optimum-intel/blob/main/optimum/intel/openvino/modeling_timm.py Below is the original copyright: # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from typing import Dict, List, Optional, Union import numpy as np from timm.data import resolve_data_config from timm.models import resolve_pretrained_cfg from transformers.image_processing_utils import ( BaseImageProcessor, BatchFeature, get_size_dict, ) from transformers.image_transforms import resize, to_channel_dimension_format from transformers.image_utils import ( ChannelDimension, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageInput, PILImageResampling, make_list_of_images, to_numpy_array, valid_images, ) from transformers.utils import TensorType class DetikzifyImageProcessor(BaseImageProcessor): r""" Constructs a ViT image processor. Args: do_resize (`bool`, *optional*, defaults to `True`): Whether to resize the image's (height, width) dimensions to the specified `(size["height"], size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method. size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`): Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` method. resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the `preprocess` method. do_rescale (`bool`, *optional*, defaults to `True`): Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` parameter in the `preprocess` method. rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the `preprocess` method. 
do_normalize (`bool`, *optional*, defaults to `True`): Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` method. image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): Mean to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): Standard deviation to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. """ model_input_names = ["pixel_values"] def __init__( self, do_resize: bool = True, size: Optional[Dict[str, int]] = None, resample: PILImageResampling = PILImageResampling.BILINEAR, do_rescale: bool = True, rescale_factor: Union[int, float] = 1 / 255, do_normalize: bool = True, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, **kwargs, ) -> None: super().__init__(**kwargs) size = size if size is not None else {"height": 224, "width": 224} size = get_size_dict(size) self.do_resize = do_resize self.do_rescale = do_rescale self.do_normalize = do_normalize self.size = size self.resample = resample self.rescale_factor = rescale_factor self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD @classmethod def from_pretrained( cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs, ): pretrained_cfg = resolve_pretrained_cfg(variant=pretrained_model_name_or_path) timm_config_dict = resolve_data_config(pretrained_cfg.to_dict()) _, im_h, im_w = timm_config_dict.get("input_size", [3, 224, 224]) image_preprocess_config_dict = { "crop_size": {"height": im_h, "width": im_w}, "do_center_crop": True if timm_config_dict.get("crop_mode") == "center" else False, "do_normalize": True, "do_reduce_labels": False, "do_rescale": True, "do_resize": True, "image_mean": timm_config_dict.get("mean", IMAGENET_STANDARD_MEAN), "image_processor_type": "TimmImageProcessor", "image_std": timm_config_dict.get("std", IMAGENET_STANDARD_STD), "resample": 3, "rescale_factor": 0.00392156862745098, "size": {"height": im_h, "width": im_w}, } return cls.from_dict(image_preprocess_config_dict, **kwargs) def resize( self, image: np.ndarray, size: Dict[str, int], resample: PILImageResampling = PILImageResampling.BILINEAR, data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs, ) -> np.ndarray: """ Resize an image to `(size["height"], size["width"])`. Args: image (`np.ndarray`): Image to resize. size (`Dict[str, int]`): Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. resample: `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`. data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format for the output image. If unset, the channel dimension format of the input image is used. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. Returns: `np.ndarray`: The resized image. 
""" size = get_size_dict(size) if "height" not in size or "width" not in size: raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}") if image.ndim == 2: image = np.stack([image] * 3, axis=-1) return resize( image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs ) def preprocess( self, images: ImageInput, do_resize: Optional[bool] = None, size: Dict[str, int] = None, resample: PILImageResampling = None, do_rescale: Optional[bool] = None, rescale_factor: Optional[float] = None, do_normalize: Optional[bool] = None, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, return_tensors: Optional[Union[str, TensorType]] = None, data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, **kwargs, ): """ Preprocess an image or batch of images. Args: images (`ImageInput`): Image to preprocess. do_resize (`bool`, *optional*, defaults to `self.do_resize`): Whether to resize the image. size (`Dict[str, int]`, *optional*, defaults to `self.size`): Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after resizing. resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`): `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has an effect if `do_resize` is set to `True`. do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): Whether to rescale the image values between [0 - 1]. rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): Rescale factor to rescale the image by if `do_rescale` is set to `True`. do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): Whether to normalize the image. image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): Image mean to use if `do_normalize` is set to `True`. image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): Image standard deviation to use if `do_normalize` is set to `True`. return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - Unset: Use the channel dimension format of the input image. 
""" do_resize = do_resize if do_resize is not None else self.do_resize do_rescale = do_rescale if do_rescale is not None else self.do_rescale do_normalize = do_normalize if do_normalize is not None else self.do_normalize resample = resample if resample is not None else self.resample rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor image_mean = image_mean if image_mean is not None else self.image_mean image_std = image_std if image_std is not None else self.image_std size = size if size is not None else self.size size_dict = get_size_dict(size) images = make_list_of_images(images) if not valid_images(images): raise ValueError( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) if do_resize and size is None: raise ValueError("Size must be specified if do_resize is True.") if do_rescale and rescale_factor is None: raise ValueError("Rescale factor must be specified if do_rescale is True.") # All transformations expect numpy arrays. images = [to_numpy_array(image) for image in images] if do_resize: images = [self.resize(image=image, size=size_dict, resample=resample) for image in images] if do_rescale: images = [self.rescale(image=image, scale=rescale_factor) for image in images] if do_normalize: images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images] images = [to_channel_dimension_format(image, data_format) for image in images] data = {"pixel_values": images} return BatchFeature(data=data, tensor_type=return_tensors) ``` ## /detikzify/train/__init__.py ```py path="/detikzify/train/__init__.py" from .pretrain import train as pretrain from .train import train ``` ## /detikzify/train/adapter/__init__.py ```py path="/detikzify/train/adapter/__init__.py" from transformers import SiglipVisionModel from ...model.adapter import CrossAttentionAdapterMixin from .pretrain import train as pretrain from .train import train #from .train import train class CrossAttentionSiglipVisionModel(SiglipVisionModel, CrossAttentionAdapterMixin): ... 
``` ## /detikzify/train/adapter/pretrain.py ```py path="/detikzify/train/adapter/pretrain.py" import os from typing import Dict, Literal import torch from torch.utils.data import Dataset from torchvision.transforms import v2 from transformers import ( Trainer, TrainerCallback, TrainingArguments, is_torch_xla_available, ) from transformers.trainer_utils import SaveStrategy, get_last_checkpoint from transformers.utils import logging from ...model.adapter.modeling_adapter import CrossAttentionAdapterMixin from ...util import ( EditCutMix, EditCutOut, EditMixUp, FullErase, SketchAugment, SplitEpochSaveCallback, unwrap_processor, ) if is_torch_xla_available(): import torch_xla.core.xla_model as xm # type: ignore logger = logging.get_logger("transformers") WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1)) class EmbeddingSimilarityLoss(): def __init__(self, elementwise: bool = True, use_mse: bool = False): self.cosine = torch.nn.CosineSimilarity(dim=-1) self.mae = torch.nn.L1Loss(reduction="none") self.mse = torch.nn.MSELoss(reduction="none") self.elementwise = elementwise self.use_mse = use_mse # https://github.com/pytorch/pytorch/issues/104564#issuecomment-1651575112 @torch.compile def cosine_loss(self, x, y): if self.elementwise: return self.mae(cos:=self.cosine(x, y), torch.ones_like(cos)).mean() else: X = self.cosine(x.unsqueeze(2), y.unsqueeze(1)) Y = self.cosine(y.unsqueeze(2), y.unsqueeze(1)) return self.mae(X, Y).max(dim=-1)[0].mean() @torch.compile def l2_loss(self, x, y): if self.elementwise: return self.mse(x, y).mean() else: X, Y = torch.cdist(x, y), torch.cdist(y, y) return self.mae(X, Y).max(dim=-1)[0].mean() def __call__(self, x, y): if self.use_mse: return self.l2_loss(x, y) else: return self.cosine_loss(x, y) # https://huggingface.co/docs/transformers/main/en/tasks/knowledge_distillation_for_image_classification class AdapterTrainer(Trainer): def __init__( self, model: CrossAttentionAdapterMixin, loss_term: Literal["avg", "pool", "patch", "layer"] = "patch", elementwise_loss: bool = True, mse_loss: bool = False, pool_train_head: bool = False, multimodal: bool = False, *args, **kwargs, ): self.term = loss_term self.loss_function = EmbeddingSimilarityLoss(elementwise=elementwise_loss, use_mse=mse_loss) train_head = self.term == "pool" and pool_train_head super().__init__(self.prepare_model(model, train_head, multimodal), *args, **kwargs) # type: ignore if self.term == "layer": self.loss_layers = sorted({len(self.model.adapter.layers)} | { idx for idx, layer in enumerate(self.model.adapter.layers, 1) if layer is not None }) self.control.layer_losses = {layer: 0 for layer in self.loss_layers} def prepare_model(self, model, train_head=False, multimodal=False): for name, param in model.named_parameters(): if not "adapter" in name and (not train_head or not "head" in name): param.requires_grad = False elif multimodal and "dummy_input" in name: param.requires_grad = False elif model.dtype != torch.float32: param.data = param.data.to(torch.float32) if train_head: # in this case we also want gradients for the teacher model.vision_model.head.forward = torch.enable_grad(model.vision_model.head.forward) if self.term != "pool": model.vision_model.use_head = False model.embedding_model.enable_input_require_grads() model.enable_input_require_grads() return model def compute_loss(self, model, inputs, return_outputs=False, **_): with torch.no_grad(): teacher_output = model( pixel_values=inputs.pop("labels"), output_hidden_states=self.term=="layer" ) student_output = model( 
output_hidden_states=self.term=="layer", **inputs, ) if self.term == "avg": loss = self.loss_function( student_output.last_hidden_state.mean(dim=1), teacher_output.last_hidden_state.mean(dim=1), ) elif self.term == "pool": loss = self.loss_function( student_output.pooler_output, teacher_output.pooler_output, ) elif self.term == "patch": loss = self.loss_function( student_output.last_hidden_state, teacher_output.last_hidden_state ) else: loss = 0 for layer in self.loss_layers: last_layer = layer == self.loss_layers[-1] layer_loss = self.loss_function( student_output.last_hidden_state if last_layer else student_output.hidden_states[layer], teacher_output.last_hidden_state if last_layer else teacher_output.hidden_states[layer] ) loss += .5 * (1 if last_layer else 1/(len(self.loss_layers)-1)) * layer_loss log_layer_loss = layer_loss.mean() if self.args.n_gpu > 1 else layer_loss log_layer_loss = log_layer_loss.detach() / self.args.gradient_accumulation_steps self.control.layer_losses[layer] += log_layer_loss return (loss, student_output) if return_outputs else loss # https://github.com/naba89/custom_hf_trainer def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time): if self.control.should_log and self.state.global_step > self._globalstep_last_logged: if is_torch_xla_available(): xm.mark_step() # type: ignore logs: Dict[str, float] = {} # all_gather + mean() to get average loss over all processes tr_loss_scalar = self._nested_gather(tr_loss).mean().item() # type: ignore # reset tr_loss to zero tr_loss -= tr_loss logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) if grad_norm is not None: logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm logs["learning_rate"] = self._get_learning_rate() if self.term == "layer": for k, v in self.control.layer_losses.items(): layer_loss = self._nested_gather(v).mean().item() # type: ignore logs[f"layer_loss_{k}"] = round(layer_loss / (self.state.global_step - self._globalstep_last_logged), 4) self.control.layer_losses[k] -= v # reset the loss self._total_loss_scalar += tr_loss_scalar self._globalstep_last_logged = self.state.global_step self.store_flos() self.log(logs, start_time) metrics = None if self.control.should_evaluate: metrics = self._evaluate(trial, ignore_keys_for_eval) is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial) if self.args.save_strategy == SaveStrategy.BEST: self.control.should_save = is_new_best_metric if self.control.should_save: self._save_checkpoint(model, trial) self.control = self.callback_handler.on_save(self.args, self.state, self.control) class AdapterDataset(Dataset, TrainerCallback): def __init__(self, dataset, processor, multimodal=False): super().__init__() self.processor = processor self.dataset = dataset self.multimodal = multimodal self.offset = 1 self.sketchify = SketchAugment(intensity=2) self.mixup = EditMixUp() self.cutmix = EditCutMix() self.cutout = EditCutOut() self.erase = FullErase() def __len__(self): return len(self.dataset) def __getitems__(self, indices) -> Dict[str, torch.Tensor]: batch, images = self.dataset[indices], None labels = torch.stack([v2.functional.pil_to_tensor(img) for img in batch['image']]) if self.multimodal: partition, images = torch.randint(3, (len(indices),)), torch.empty_like(labels) if len(sketch_ids:=torch.argwhere(partition == 0).flatten()): images[sketch_ids] = self.sketchify(labels[sketch_ids]) if 
len(blank_ids:=torch.argwhere(partition == 1).flatten()): images[blank_ids] = self.erase(labels[blank_ids]) if len(edit_ids:=torch.argwhere(partition == 2).flatten()): edit_partition = torch.randint(3, (len(edit_ids),)) if len(cutout_ids:=edit_ids[torch.argwhere(edit_partition == 0)].flatten()): images[cutout_ids] = self.cutout(labels[cutout_ids]) if len(mixup_ids:=edit_ids[torch.argwhere(edit_partition == 1)].flatten()): mixup_imgs = self.dataset[[(indices[idx] + self.offset) % len(self) for idx in mixup_ids]]['image'] mixup_imgs = torch.stack([v2.functional.pil_to_tensor(img) for img in mixup_imgs]) interleaved_imgs = torch.stack([labels[mixup_ids], mixup_imgs], dim=1).view(-1, *mixup_imgs.shape[1:]) images[mixup_ids] = self.mixup(interleaved_imgs)[::2] if len(cutmix_ids:=edit_ids[torch.argwhere(edit_partition == 2)].flatten()): cutmix_imgs = self.dataset[[(indices[idx] + self.offset) % len(self) for idx in cutmix_ids]]['image'] cutmix_imgs = torch.stack([v2.functional.pil_to_tensor(img) for img in cutmix_imgs]) interleaved_imgs = torch.stack([labels[cutmix_ids], cutmix_imgs], dim=1).view(-1, *cutmix_imgs.shape[1:]) images[cutmix_ids] = self.cutmix(interleaved_imgs)[::2] input_ids = self.processor( images=images, text=batch['text'], return_tensors="pt", text_kwargs=dict( padding=True, truncation=True, ) ) label_ids = unwrap_processor(self.processor)(images=labels, return_tensors="pt") input_ids['labels'] = label_ids['pixel_values'] return input_ids def __getitem__(self, index) -> Dict[str, torch.Tensor]: return self.__getitems__([index] if isinstance(index, int) else index) def on_epoch_end(self, *args, **kwargs): self.offset += 1 def train( output_dir: str, model, processor, dataset, overwrite=False, deepspeed=None, # training hyperparams multimodal: bool = False, batch_size: int = 512, micro_batch_size: int = 8, num_epochs: int = 3, learning_rate: float = 1e-4, gradient_checkpointing: bool = False, **loss_kwargs ): dataset = AdapterDataset(dataset, processor=processor, multimodal=multimodal) gradient_accumulation_steps = batch_size // micro_batch_size if WORLD_SIZE != 1: gradient_accumulation_steps = gradient_accumulation_steps // WORLD_SIZE last_checkpoint = None if os.path.isdir(output_dir) and not overwrite: last_checkpoint = get_last_checkpoint(output_dir) if last_checkpoint is None and len(os.listdir(output_dir)) > 0: raise ValueError( f"Output directory ({output_dir}) already exists and is not empty. " "Use `overwrite` to overcome." ) elif last_checkpoint is not None: logger.info( f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " "the `output_dir` or add `overwrite` to train from scratch." 
) trainer = AdapterTrainer( model=model, train_dataset=dataset, multimodal=multimodal, callbacks=[SplitEpochSaveCallback(step_size=0.5)], data_collator=lambda batch: batch, **loss_kwargs, args=TrainingArguments( per_device_train_batch_size=micro_batch_size, gradient_accumulation_steps=gradient_accumulation_steps, gradient_checkpointing=gradient_checkpointing, # https://github.com/huggingface/transformers/issues/21381 gradient_checkpointing_kwargs={'use_reentrant': False}, dataloader_num_workers=WORLD_SIZE, warmup_steps=500, weight_decay=0.1, num_train_epochs=num_epochs, learning_rate=learning_rate, torch_compile=True, bf16=True, tf32=True, logging_steps=250, logging_first_step=True, lr_scheduler_type="cosine", optim="adamw_torch" if deepspeed else "adamw_torch_fused", ddp_find_unused_parameters=False, remove_unused_columns=False, save_strategy="epoch", report_to="none", output_dir=output_dir, deepspeed=deepspeed, ) ) trainer.add_callback(trainer.train_dataset) trainer.train(resume_from_checkpoint=last_checkpoint) if trainer.is_deepspeed_enabled: # https://huggingface.co/docs/accelerate/v0.11.0/en/deepspeed#saving-and-loading from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint last_checkpoint = get_last_checkpoint(output_dir) load_state_dict_from_zero_checkpoint(trainer.model.float(), last_checkpoint) model.save_cross_attn_adapter(output_dir) trainer.save_state() return model, processor ``` ## /detikzify/train/adapter/train.py ```py path="/detikzify/train/adapter/train.py" from copy import deepcopy from datetime import timedelta import os from typing import Dict, List from accelerate import Accelerator, InitProcessGroupKwargs from datasets import Dataset import torch from torch.utils.data import Dataset from transformers import Trainer, TrainingArguments from transformers.trainer_utils import get_last_checkpoint from transformers.utils import logging from ...util import SplitEpochSaveCallback, unwrap_processor logger = logging.get_logger("transformers") IGNORE_INDEX = -100 WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1)) RANK = int(os.environ.get("RANK", 0)) def tokenize( batch, processor, caption_condition=False, **kwargs ): unwrapped_processor = unwrap_processor(processor) image_token = unwrapped_processor.image_token image_token_id = unwrapped_processor.tokenizer.convert_tokens_to_ids(image_token) bos_token = unwrapped_processor.tokenizer.bos_token input_ids = processor( text=batch['caption'], images_kwargs=dict( text=[bos_token.join(text) for text in zip(batch['caption'], batch['code'])] if caption_condition else batch['code'], max_length=unwrapped_processor.tokenizer.model_max_length, pad_to_multiple_of=8, add_eos_token=True, truncation=False, padding=True ), text_kwargs=dict( padding=True, truncation=True, ), **kwargs ) input_ids['labels'] = deepcopy(input_ids['input_ids']) if caption_condition: # do not train on caption and pad tokens for label_ids in input_ids['labels']: after_bos_token = False for idx, label_id in enumerate(label_ids): if not after_bos_token or label_id in {image_token_id, unwrapped_processor.tokenizer.pad_token_id}: if label_id == unwrapped_processor.tokenizer.bos_token_id: after_bos_token = True label_ids[idx] = IGNORE_INDEX elif label_id == unwrapped_processor.tokenizer.bos_token_id: after_bos_token = True else: # do not train on image and pad tokens for label_ids in input_ids['labels']: for idx, label_id in enumerate(label_ids): if label_id in {image_token_id, processor.tokenizer.pad_token_id}: label_ids[idx] = IGNORE_INDEX return 
input_ids class AdapterDataset(Dataset): def __init__(self, dataset, processor, caption_condition=False): super().__init__() self.processor = processor self.dataset = dataset.with_transform(self.tokenize) self.caption_condition = caption_condition def __len__(self): return len(self.dataset) def tokenize(self, batch): return tokenize( batch=batch, processor=self.processor, caption_condition=self.caption_condition, return_tensors="pt", ) def filter(self, *args, **kwargs): self.dataset = self.dataset.filter(*args, **kwargs) def __getitem__(self, index) -> Dict[str, torch.Tensor]: return self.dataset[index] def __getitems__(self, indices) -> Dict[str, List[torch.Tensor]]: return self.dataset[*indices] def train( output_dir: str, model, processor, dataset, overwrite=False, deepspeed=None, # training hyperparams caption_condition: bool = False, batch_size: int = 128, micro_batch_size: int = 1, num_epochs: int = 5, learning_rate: float = 5e-5, gradient_checkpointing: bool = False, ): gradient_accumulation_steps = batch_size // micro_batch_size if WORLD_SIZE != 1: gradient_accumulation_steps = gradient_accumulation_steps // WORLD_SIZE for _, param in model.model.vision_model.named_parameters(): param.requires_grad = False for _, param in model.adapter.named_parameters(): param.requires_grad = False for _, param in model.embedding_model.named_parameters(): param.requires_grad = False model.enable_input_require_grads() model.embedding_model.enable_input_require_grads() dataset = AdapterDataset(dataset, processor, caption_condition=caption_condition) logger.info(f"Dataset size before filtering out too long examples: {len(dataset)}") eos_token_id, model_max_length = unwrap_processor(processor).tokenizer.eos_token_id, unwrap_processor(processor).tokenizer.model_max_length with Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(days=3))]).main_process_first(): dataset.filter(lambda ex: (ex['input_ids'] == eos_token_id).nonzero() < model_max_length, num_proc=64, batch_size=16) logger.info(f"Dataset size after filtering out too long examples: {len(dataset)}") last_checkpoint = None if os.path.isdir(output_dir) and not overwrite: last_checkpoint = get_last_checkpoint(output_dir) if last_checkpoint is None and len(os.listdir(output_dir)) > 0: raise ValueError( f"Output directory ({output_dir}) already exists and is not empty. " "Use `overwrite` to overcome." ) elif last_checkpoint is not None: logger.info( f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " "the `output_dir` or add `overwrite` to train from scratch." 
) trainer = Trainer( model=model, train_dataset=dataset, args=TrainingArguments( per_device_train_batch_size=micro_batch_size, gradient_accumulation_steps=gradient_accumulation_steps, gradient_checkpointing=gradient_checkpointing, # https://github.com/huggingface/transformers/issues/32576 #gradient_checkpointing_kwargs={'use_reentrant':False}, dataloader_num_workers=WORLD_SIZE, warmup_ratio=0.03, weight_decay=0, num_train_epochs=num_epochs, learning_rate=learning_rate, torch_compile=True, bf16=True, tf32=True, logging_steps=10, lr_scheduler_type="cosine", optim="adamw_torch" if deepspeed else "adamw_torch_fused", ddp_find_unused_parameters=False, remove_unused_columns=False, save_strategy="epoch", report_to="none", save_total_limit=1, output_dir=output_dir, deepspeed=deepspeed, ), callbacks=[SplitEpochSaveCallback(step_size=0.25)], data_collator=lambda batch: batch ) trainer.add_callback(trainer.train_dataset) trainer.train(resume_from_checkpoint=last_checkpoint) if trainer.is_deepspeed_enabled: # https://huggingface.co/docs/accelerate/v0.11.0/en/deepspeed#saving-and-loading from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint last_checkpoint = get_last_checkpoint(output_dir) load_state_dict_from_zero_checkpoint(trainer.model.float(), last_checkpoint) trainer.model.unload_cross_attn_adapter() trainer.save_model(output_dir) trainer.save_state() processor.processor.save_pretrained(output_dir) return model, processor ``` ## /detikzify/train/pretrain.py ```py path="/detikzify/train/pretrain.py" import copy from functools import partial import os from typing import List from transformers import Trainer, TrainingArguments IGNORE_INDEX = -100 WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1)) def tokenize( batch, processor, **kwargs ): image_token = processor.image_token image_token_id = processor.tokenizer.convert_tokens_to_ids(image_token) input_ids = processor( text=batch['text'], images=batch['image'], max_length=processor.tokenizer.model_max_length, pad_to_multiple_of=8, add_eos_token=True, **kwargs ) input_ids['labels'] = copy.deepcopy(input_ids['input_ids']) # do not train on image and pad tokens for label_ids in input_ids['labels']: for idx, label_id in enumerate(label_ids): if label_id in {image_token_id, processor.tokenizer.pad_token_id}: label_ids[idx] = IGNORE_INDEX return input_ids def train( output_dir: str, model, processor, dataset, deepspeed=None, # training hyperparams batch_size: int = 256, micro_batch_size: int = 1, num_epochs: int = 1, learning_rate: float = 1e-3, gradient_checkpointing: bool = False, full_finetune_modules: List[str] = [ "modality_projection", ], ): gradient_accumulation_steps = batch_size // micro_batch_size if WORLD_SIZE != 1: gradient_accumulation_steps = gradient_accumulation_steps // WORLD_SIZE for name, param in model.named_parameters(): if not any(module in name for module in full_finetune_modules): param.requires_grad = False dataset.set_transform(partial( tokenize, processor=processor, return_tensors="pt", truncation=True, padding=True )) trainer = Trainer( model=model, train_dataset=dataset, args=TrainingArguments( per_device_train_batch_size=micro_batch_size, gradient_accumulation_steps=gradient_accumulation_steps, gradient_checkpointing=gradient_checkpointing, # https://github.com/huggingface/transformers/issues/21381 gradient_checkpointing_kwargs={'use_reentrant':False}, dataloader_num_workers=WORLD_SIZE, warmup_ratio=0.03, weight_decay=0, num_train_epochs=num_epochs, learning_rate=learning_rate, torch_compile=True, 
bf16=True, tf32=True, logging_steps=10, lr_scheduler_type="cosine", optim="adamw_torch" if deepspeed else "adamw_torch_fused", ddp_find_unused_parameters=False, remove_unused_columns=False, save_strategy="no", report_to="none", output_dir=output_dir, deepspeed=deepspeed, ) ) if trainer.is_deepspeed_enabled and trainer.accelerator.state.deepspeed_plugin.hf_ds_config.is_zero3(): raise ValueError("Pretraining with zero stage 3 is not yet supported.") trainer.train() model.save_pretrained( output_dir, state_dict={ name: weight for name, weight in model.state_dict().items() if any(key_match in name for key_match in full_finetune_modules) }, ) trainer.save_state() return model, processor ``` ## /detikzify/train/train.py ```py path="/detikzify/train/train.py" from io import BytesIO import os from random import random from typing import Dict, List from PIL import Image import torch from torch.utils.data import Dataset from transformers import Trainer, TrainerCallback, TrainingArguments from transformers.trainer_utils import get_last_checkpoint from transformers.utils import logging from ..util import SketchAugment, SplitEpochSaveCallback from .pretrain import tokenize logger = logging.get_logger("transformers") WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1)) RANK = int(os.environ.get("RANK", 0)) class ImageSketchDataset(Dataset, TrainerCallback): """ Dataset which samples sketches instead of images, when a sketch exists for the current epoch. """ def __init__(self, dataset, processor, ds_sketch_ratio=.5): super().__init__() self.processor = processor self.dataset = dataset.with_transform(self.tokenize) self.ds_sketch_ratio = ds_sketch_ratio self.sketchify = SketchAugment() self.cur_epoch = 0 def __len__(self): return len(self.dataset) def tokenize(self, batch): for idx, sketches in enumerate(batch['sketches']): if (sketch:=sketches[self.cur_epoch]): if random() >= self.ds_sketch_ratio: batch['image'][idx] = Image.open(BytesIO(sketch['bytes'])).convert("RGB") else: batch['image'][idx] = self.sketchify(batch['image'][idx]) return tokenize( batch=batch, processor=self.processor, return_tensors="pt", truncation=False, padding=True ) def filter(self, *args, **kwargs): self.dataset = self.dataset.filter(*args, **kwargs) def __getitem__(self, index) -> Dict[str, torch.Tensor]: return self.dataset[index] def __getitems__(self, indices) -> Dict[str, List[torch.Tensor]]: return self.dataset[*indices] def on_epoch_end(self, *args, **kwargs): self.cur_epoch += 1 def train( output_dir: str, model, processor, dataset, overwrite=False, deepspeed=None, # training hyperparams batch_size: int = 128, micro_batch_size: int = 1, num_epochs: int = 5, learning_rate: float = 5e-5, sketch_ratio=.5, gradient_checkpointing: bool = False, ): assert num_epochs <= len(dataset[0]['sketches']) gradient_accumulation_steps = batch_size // micro_batch_size if WORLD_SIZE != 1: gradient_accumulation_steps = gradient_accumulation_steps // WORLD_SIZE dataset = ImageSketchDataset(dataset, processor, ds_sketch_ratio=sketch_ratio) logger.info(f"Dataset size before filtering out too long examples: {len(dataset)}") eos_token_id, model_max_length = processor.tokenizer.eos_token_id, processor.tokenizer.model_max_length dataset.filter(lambda ex: (ex['input_ids'] == eos_token_id).nonzero() < model_max_length) logger.info(f"Dataset size after filtering out too long examples: {len(dataset)}") last_checkpoint = None if os.path.isdir(output_dir) and not overwrite: last_checkpoint = get_last_checkpoint(output_dir) if last_checkpoint is None and 
len(os.listdir(output_dir)) > 0: raise ValueError( f"Output directory ({output_dir}) already exists and is not empty. " "Use `overwrite` to overcome." ) elif last_checkpoint is not None: logger.info( f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " "the `output_dir` or add `overwrite` to train from scratch." ) trainer = Trainer( model=model, train_dataset=dataset, args=TrainingArguments( per_device_train_batch_size=micro_batch_size, gradient_accumulation_steps=gradient_accumulation_steps, gradient_checkpointing=gradient_checkpointing, # https://github.com/huggingface/transformers/issues/32576 gradient_checkpointing_kwargs={'use_reentrant':False}, dataloader_num_workers=WORLD_SIZE, warmup_ratio=0.03, weight_decay=0, num_train_epochs=num_epochs, learning_rate=learning_rate, torch_compile=True, bf16=True, tf32=True, logging_steps=10, lr_scheduler_type="cosine", optim="adamw_torch" if deepspeed else "adamw_torch_fused", ddp_find_unused_parameters=False, remove_unused_columns=False, save_strategy="epoch", report_to="none", save_total_limit=1, output_dir=output_dir, deepspeed=deepspeed, ), callbacks=[SplitEpochSaveCallback(step_size=0.25)], data_collator=lambda batch: batch ) trainer.add_callback(trainer.train_dataset) trainer.train(resume_from_checkpoint=last_checkpoint) if trainer.is_deepspeed_enabled: # https://huggingface.co/docs/accelerate/v0.11.0/en/deepspeed#saving-and-loading from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint last_checkpoint = get_last_checkpoint(output_dir) load_state_dict_from_zero_checkpoint(trainer.model.float(), last_checkpoint) trainer.save_model(output_dir) trainer.save_state() return model, processor ``` ## /detikzify/util/__init__.py ```py path="/detikzify/util/__init__.py" from .functools import * from .generation import * from .image import * from .subprocess import * from .torch import * from .trainer import * ``` ## /detikzify/util/functools.py ```py path="/detikzify/util/functools.py" from collections import defaultdict from collections.abc import Callable from copy import copy from functools import cache, wraps from typing import Any def cache_cast(cast_func: Callable[..., Any]): """ functools.cache which takes a user-defined function to convert arguments into something immutable so it can be cached. """ def decorator(func): cache_args, cache_kwargs = None, None @cache def cached_func(_): return func(*cache_args, **cache_kwargs) @wraps(func) def wrapped_func(*args, **kwargs): nonlocal cache_args, cache_kwargs cache_args, cache_kwargs = args, kwargs return cached_func(cast_func(*args, **kwargs)) return wrapped_func return decorator def cast(cls, object): clone = copy(object) clone.__class__ = cls return clone # https://stackoverflow.com/a/12377059 def listify(fn=None, wrapper=list): """ A decorator which wraps a function's return value in ``list(...)``. Useful when an algorithm can be expressed more cleanly as a generator but the function should return a list. Example:: >>> @listify ... def get_lengths(iterable): ... for i in iterable: ... yield len(i) >>> get_lengths(["spam", "eggs"]) [4, 4] >>> >>> @listify(wrapper=tuple) ... def get_lengths_tuple(iterable): ... for i in iterable: ... 
yield len(i) >>> get_lengths_tuple(["foo", "bar"]) (3, 3) """ def listify_return(fn): @wraps(fn) def listify_helper(*args, **kw): return wrapper(fn(*args, **kw)) return listify_helper if fn is None: return listify_return return listify_return(fn) def batchify(fn=None): def batch(list_of_dicts): batch_dict = defaultdict(list) for d in list_of_dicts: for k, v in d.items(): batch_dict[k].append(v) return batch_dict return listify(fn=fn, wrapper=batch) ``` ## /detikzify/util/generation.py ```py path="/detikzify/util/generation.py" from queue import Queue from typing import Optional from transformers import StoppingCriteria from transformers.generation import streamers class ExplicitAbort(StoppingCriteria): """ Abort a model generation explicitly (i.e., when using a streamer in a thread). """ def __init__(self): super().__init__() self.should_stop = False def __call__(self, input_ids, scores, **kwargs) -> bool: return self.should_stop def reset(self): self.should_stop = False return self def abort(self): self.should_stop = True class TokenStreamer(streamers.BaseStreamer): """ Stream raw token ids (i.e., not decoded strings). """ def __init__(self, skip_prompt: bool = True, timeout: Optional[float] = None): self.skip_prompt = skip_prompt self.next_tokens_are_prompt = True self.token_queue = Queue() self.stop_signal = None self.timeout = timeout def put(self, value): if len(value.shape) > 1 and value.shape[0] > 1: raise ValueError("TokenStreamer only supports batch size 1") elif len(value.shape) > 1: value = value[0] if self.skip_prompt and self.next_tokens_are_prompt: self.next_tokens_are_prompt = False return for token_id in value.tolist(): self.token_queue.put(token_id, timeout=self.timeout) def end(self): self.next_tokens_are_prompt = True self.token_queue.put(self.stop_signal, timeout=self.timeout) def propagate_error(self, exc): self.token_queue.put(exc, timeout=self.timeout) def __iter__(self): return self def __next__(self): value = self.token_queue.get(timeout=self.timeout) if value == self.stop_signal: raise StopIteration() elif isinstance(value, BaseException): raise value else: return value class TextIteratorStreamer(streamers.TextIteratorStreamer): def propagate_error(self, exc): self.text_queue.put(exc, timeout=self.timeout) def __next__(self): value = self.text_queue.get(timeout=self.timeout) if value == self.stop_signal: raise StopIteration() elif isinstance(value, BaseException): raise value else: return value class StreamerList(list, streamers.BaseStreamer): """ Similar to StoppingCriteriaList, only for Streamers. """ def put(self, value): for streamer in self: streamer.put(value) def end(self): for streamer in self: streamer.end() def unwrap_processor(processor): """ Unwrap a processor, nested processors can happen when using the adapter processor. 
""" if hasattr(processor, "processor"): return unwrap_processor(processor.processor) else: return processor ``` ## /detikzify/util/image.py ```py path="/detikzify/util/image.py" from base64 import b64decode from codecs import encode from io import BytesIO from os.path import isfile from PIL import Image, ImageChops, ImageOps import pymupdf import requests from transformers.utils.hub import is_remote_url DUMMY_IMAGE = Image.new("RGB", (24, 24), color="white") def convert(image, filetype): image.save(imgbytes:=BytesIO(), format=filetype) return Image.open(imgbytes) def remove_alpha(image, bg): # https://stackoverflow.com/a/62414364 background = Image.new('RGBA', image.size, bg) alpha_composite = Image.alpha_composite(background, image.convert("RGBA")) return alpha_composite.convert("RGB") # https://stackoverflow.com/a/10616717 def trim(image, bg="white"): bg = Image.new(image.mode, image.size, bg) diff = ImageChops.difference(image, bg) #diff = ImageChops.add(diff, diff, 2.0, -10) return image.crop(bbox) if (bbox:=diff.getbbox()) else image def expand(image, size, do_trim=False, bg="white"): """Expand image to a square of size {size}. Optionally trims borders first.""" image = trim(image, bg=bg) if do_trim else image return ImageOps.pad(image, (size, size), color=bg, method=Image.Resampling.LANCZOS) # based on transformers/image_utils.py (added support for rgba images) def load(image: Image.Image | str | bytes, bg="white", timeout=None): if isinstance(image, bytes): # assume image bytes and open image = Image.open(BytesIO(image)) elif isinstance(image, str): if is_remote_url(image): # https://stackoverflow.com/a/69791396 headers = {'user-agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:68.0) Gecko/20100101 Firefox/68.0'} image = Image.open(BytesIO(requests.get(image, timeout=timeout, headers=headers).content)) elif isfile(image): image = Image.open(image) else: try: image.removeprefix("data:image/") image = Image.open(BytesIO(b64decode(image))) except Exception as e: raise ValueError( "Incorrect image source. " "Must be a valid URL starting with `http://` or `https://`, " "a valid path to an image file, bytes, or a base64 encoded " f"string. Got {image}. Failed with {e}" ) image = ImageOps.exif_transpose(image) # type: ignore return remove_alpha(image, bg=bg) def redact(doc, rot_13=False): for page in (copy:=pymupdf.open("pdf", doc.tobytes())): for word in page.get_text("words", clip=pymupdf.INFINITE_RECT()): # type: ignore text = encode(word[4], "rot13") if rot_13 else None page.add_redact_annot(word[:4], text=text, fill=False) # type: ignore page.apply_redactions( # type: ignore images=pymupdf.PDF_REDACT_IMAGE_NONE, # type: ignore graphics=pymupdf.PDF_REDACT_LINE_ART_NONE # type: ignore ) return copy ``` ## /detikzify/util/subprocess.py ```py path="/detikzify/util/subprocess.py" from os import killpg, getpgid from subprocess import Popen, TimeoutExpired, CalledProcessError, CompletedProcess, PIPE from signal import SIGKILL def safe_killpg(pid, signal): try: killpg(getpgid(pid), signal) except ProcessLookupError: pass # Supress the race condition error; bpo-40550. # Patch subprocess.run and subprocess.check_output to also kill children of the # started process on timeouts (cf. 
# https://alexandra-zaharia.github.io/posts/kill-subprocess-and-its-children-on-timeout-python/) def run(*popenargs, input=None, timeout=None, check=False, **kwargs): with Popen(*popenargs, start_new_session=True, **kwargs) as process: try: stdout, stderr = process.communicate(input, timeout=timeout) except TimeoutExpired: safe_killpg(process.pid, SIGKILL) process.wait() raise except: safe_killpg(process.pid, SIGKILL) raise retcode = process.poll() if check and retcode: raise CalledProcessError(retcode, process.args, output=stdout, stderr=stderr) return CompletedProcess(process.args, retcode, stdout, stderr) # type: ignore def check_output(*popenargs, timeout=None, **kwargs): return run(*popenargs, stdout=PIPE, timeout=timeout, check=True, **kwargs).stdout ``` ## /detikzify/util/torch.py ```py path="/detikzify/util/torch.py" from torch.cuda import is_available as is_torch_cuda_available from transformers.utils import is_torch_npu_available, is_torch_xpu_available # https://github.com/huggingface/peft/blob/c4cf9e7d3b2948e71ec65a19e6cd1ff230781d13/src/peft/utils/other.py#L60-L71 def infer_device(): if is_torch_cuda_available(): torch_device = "cuda" elif is_torch_xpu_available(): torch_device = "xpu" elif is_torch_npu_available(): torch_device = "npu" else: torch_device = "cpu" return torch_device ``` ## /detikzify/util/trainer.py ```py path="/detikzify/util/trainer.py" from functools import partial from numpy import arange import torch from torchvision import tv_tensors from torchvision.transforms import v2 from transformers import ( IntervalStrategy, TrainerCallback, TrainerControl, TrainerState, TrainingArguments, ) from transformers.trainer_utils import has_length from torchvision.transforms.v2._utils import query_size class SplitEpochSaveCallback(TrainerCallback): """ If save_strategy==EPOCH also save checkpoints at arbitrary fractions of an epoch (controlled by step_size). """ def __init__(self, step_size: float = 0.5): self.steps = arange(step_size, 1, step_size) def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): if has_length(train_dataloader:=kwargs['train_dataloader']): self.num_update_steps_per_epoch = max(len(train_dataloader) // args.gradient_accumulation_steps, 1) else: self.num_update_steps_per_epoch = args.max_steps def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): # type: ignore steps = [round(self.num_update_steps_per_epoch * step) for step in self.steps] if ( state.global_step % self.num_update_steps_per_epoch in steps and args.save_strategy == IntervalStrategy.EPOCH ): control.should_save = True return control class SketchAugment(v2.Compose): def __init__(self, intensity=1): super().__init__([ v2.RandomOrder([ v2.ElasticTransform(alpha=50. 
* intensity, fill=255), v2.JPEG((40 * intensity, 100)), v2.ColorJitter(brightness=(.75 + .25 * intensity, 1.75)), v2.RandomEqualize(), v2.RandomGrayscale() ]), v2.RGB() ]) class FullErase(v2.Lambda): def __init__(self, value=255): super().__init__(partial(v2.functional.erase, i=0, j=0, h=-1, w=-1, v=torch.tensor(value))) class EditBase(v2.Transform): def __init__(self, *, alpha: float = 1.0) -> None: super().__init__() self.alpha = float(alpha) self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) def _get_boxes(self, flat_inputs): lam = self._dist.sample((len(flat_inputs),)).squeeze() # type: ignore H, W = query_size(flat_inputs) r_x = torch.randint(W, size=(len(flat_inputs),)) r_y = torch.randint(H, size=(len(flat_inputs),)) r = 0.5 * torch.sqrt(1.0 - lam) r_w_half = (r * W).int() r_h_half = (r * H).int() x1 = torch.clamp(r_x - r_w_half, min=0) y1 = torch.clamp(r_y - r_h_half, min=0) x2 = torch.clamp(r_x + r_w_half, max=W) y2 = torch.clamp(r_y + r_h_half, max=H) grid_x, grid_y = torch.meshgrid(torch.arange(W), torch.arange(H), indexing="ij") grid_x = grid_x.unsqueeze(0).expand(len(flat_inputs), -1, -1) grid_y = grid_y.unsqueeze(0).expand(len(flat_inputs), -1, -1) mask = (grid_x >= x1.unsqueeze(1).unsqueeze(2)) & (grid_x < x2.unsqueeze(1).unsqueeze(2)) & \ (grid_y >= y1.unsqueeze(1).unsqueeze(2)) & (grid_y < y2.unsqueeze(1).unsqueeze(2)) return mask.unsqueeze(1).expand(-1, 3, -1, -1) class EditCutMix(EditBase): def _transform(self, inpt, params): output = inpt.clone() rolled = inpt.roll(1, 0) box = self._get_boxes(inpt) output[box] = rolled[box] if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)): output = tv_tensors.wrap(output, like=inpt) return output class EditMixUp(EditBase): def _transform(self, inpt, params): lam = self._dist.sample((len(inpt),)).view(-1, *([1] * len(inpt.shape[1:]))) # type: ignore output = inpt.roll(1, 0).mul(1.0 - lam).add_(inpt.mul(lam)).to(inpt.dtype) if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)): output = tv_tensors.wrap(output, like=inpt) return output class EditCutOut(EditBase): def __init__(self, *args, value=255, **kwargs): self.value = value super().__init__(*args, **kwargs) def _transform(self, inpt, params): output = inpt.clone() box = self._get_boxes(inpt) output[box] = self.value if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)): output = tv_tensors.wrap(output, like=inpt) return output ``` ## /detikzify/webui/README.md # Web UI The web UI of DeTi*k*Zify requires [TeX Live 2023](https://www.tug.org/texlive), [ghostscript](https://www.ghostscript.com), and [poppler](https://poppler.freedesktop.org). You can launch it by running `python -m detikzify.webui`. It comes with a command line interface. With the `--share` flag, for example, you can create a shareable link. Checkout `--help` for a full list of supported options. As scientific figures usually use black fonts on a white background, it is best to use the web UI in light mode. This can be enforced by using the `--light` flag. If [FlashAttention-2]( https://huggingface.co/docs/transformers/en/perf_infer_gpu_one?install=NVIDIA#flashattention-2) is installed, it is picked up automatically and should boost inference speeds. ## Usage Tips **Visual Prompting** Creating sketches for DeTi*k*Zify (or providing any input images) shares many similarities with the process of [prompting large language models](https://en.wikipedia.org/wiki/Prompt_engineering). If DeTi*k*Zify struggles to comprehend your intent, consider "rephrasing" your input. 
This could mean simplifying your sketches or focusing more on the key issue at hand. In [this](https://github.com/potamides/DeTikZify/assets/53401822/2819ebca-81f6-4173-8809-0b4255d3e976) particular instance, for example, we attempted to "prompt" DeTikZify to align characters diagonally around an equal sign, but it was unsuccessful even after many simulations. However, upon adjusting the input (by reducing the stroke width and using more easily recognizable characters) we achieved the [intended output](https://github.com/potamides/DeTikZify/assets/53401822/c8ecfbff-d22e-41d5-8f73-e0cfafe88690) after only one simulation. **Image Editor** You can draw sketches in the integrated image editor, but its feature set is quite limited. If you are not satisfied with the synthesized Ti*k*Z programs, try drawing more elaborate sketches in an editor of your choice (perhaps with graphics primitives) and upload them into the UI. Alternatively, experimenting with line thickness and/or colors in the integrated editor might also help. **Input Postprocessing** Please note that all input images are cropped to the smallest square around their content and then resized to the resolution DeTi*k*Zify expects. If you leave large margins this means that DeTi*k*Zify might perceive your input differently from how you intended (e.g., by drawing thicker axes). As a rule of thumb, always try to fill as much of the canvas as possible. **Input Complexity** If you provide very complex sketches (or figures) and are not satisfied with the results, you can also try segmenting (or simplifying) your input and letting DeTi*k*Zify synthesize the individual pieces independently. This has the advantage that the results will probably be better, and the disadvantage that you will have to modify and assemble the pieces yourself. **Source Code Artifacts** Due to the way we preprocess our [arXiv.org](https://arxiv.org) data, the preambles of the extracted Ti*k*Z programs sometimes include packages that are not used inside the `tikzpicture` environments, and the DeTi*k*Zify models pick up on this behavior. While this does not hinder compilation in any way, we still recommend everyone to check the generated preambles and clean them up, if necessary. **Accuracy-Efficiency Trade-Off** We noticed that lower values for temperatures and top-p (nucleus) values force DeTi*k*Zify to generate Ti*k*Z programs that follow the input images more closely, at the expense of generating more compile-time errors. We pick sensible defaults that aim to balance these two aspects, but you might want to try to tune these parameters yourself. **External Graphics** In DaTi*k*Zv2, we replace any externally included graphics in the `tikzpicture` environments with the [example image](https://mirrors.ctan.org/macros/latex/contrib/mwe/example-image.pdf) placeholder from the [mwe](http://www.ctan.org/pkg/mwe) package. So if you want to generate code with placeholders for your own external graphics, just draw that example image. ## /detikzify/webui/__init__.py ```py path="/detikzify/webui/__init__.py" from .webui import * from .strings import * from .helpers import * ``` ## /detikzify/webui/__main__.py ```py path="/detikzify/webui/__main__.py" from argparse import ArgumentParser from .strings import ALGORITHMS, MODELS from .webui import build_ui def parse_args(): argument_parser = ArgumentParser( description="Web UI for DeTikZify." ) argument_parser.add_argument( "--model", default=list(MODELS)[0], help="Initially selected model. 
You can also specify a path to your own models.", ) argument_parser.add_argument( "--algorithm", default=list(ALGORITHMS)[0], choices=list(ALGORITHMS), help="The inference algorithm to use.", ) argument_parser.add_argument( "--lock", action="store_true", help="Whether to allow users to change the model or not.", ) argument_parser.add_argument( "--lock_reason", default="Duplicate this space to be able to change this value.", help="Additional information why model selection is locked.", ) argument_parser.add_argument( "--share", action="store_true", help="Whether to create a publicly shareable link for the interface.", ) argument_parser.add_argument( "--light", action="store_true", help= "Whether to enforce light theme (useful for vector graphics with dark text)." ) argument_parser.add_argument( "--timeout", default=60, type=int, help="Allowed timeframe for compilation.", ) return vars(argument_parser.parse_args()) if __name__ == "__main__": args = parse_args() share = args.pop("share") build_ui(**args).queue().launch(share=share) ``` ## /detikzify/webui/helpers.py ```py path="/detikzify/webui/helpers.py" from functools import cache, lru_cache from inspect import signature from operator import itemgetter from os import fdopen from tempfile import mkstemp import gradio as gr from ..infer import TikzDocument from ..model import load def to_svg( tikzdoc: TikzDocument, build_dir: str ): if not tikzdoc.is_rasterizable: if tikzdoc.compiled_with_errors: raise gr.Error("TikZ code did not compile!") else: gr.Warning("TikZ code compiled to an empty image!") elif tikzdoc.compiled_with_errors: gr.Warning("TikZ code compiled with errors!") fd, path = mkstemp(dir=build_dir, suffix=".svg") with fdopen(fd, "w") as f: if pdf:=tikzdoc.pdf: f.write(pdf[0].get_svg_image()) return path if pdf else None # https://stackoverflow.com/a/50992575 def make_ordinal(n): n = int(n) if 11 <= (n % 100) <= 13: suffix = 'th' else: suffix = ['th', 'st', 'nd', 'rd', 'th'][min(n % 10, 4)] return str(n) + suffix class MctsOutputs(set): def __init__(self, build_dir, *args, **kwargs): super().__init__(*args, **kwargs) self.build_dir, self.svgmap, self.fails = build_dir, dict(), 0 def add(self, score, tikzdoc): # type: ignore if (score, tikzdoc) not in self: try: svg = to_svg(tikzdoc, build_dir=self.build_dir) super().add((score, tikzdoc)) self.svgmap[tikzdoc] = svg except gr.Error: gr.Warning("TikZ code did not compile, discarding output!") if len(self): self.fails += 1 elif len(self): self.fails += 1 @property def programs(self): return [tikzdoc.code for _, tikzdoc in sorted(self, key=itemgetter(0), reverse=True)] @property def images(self): return [ (self.svgmap[tikzdoc], make_ordinal(idx)) for idx, (_, tikzdoc) in enumerate(sorted(self, key=itemgetter(0), reverse=True), 1) ] @property def first_success(self): return len(self) == 1 and not self.fails def make_light(stylable): """ Patch gradio to only contain light mode colors. """ if isinstance(stylable, gr.themes.Base): # remove dark variants from the entire theme params = signature(stylable.set).parameters colors = {color: getattr(stylable, color.removesuffix("_dark")) for color in dir(stylable) if color in params} return stylable.set(**colors) elif isinstance(stylable, gr.Blocks): # also handle components which do not use the theme (e.g. 
modals) stylable.load( fn=None, js="() => document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'))" ) return stylable else: raise ValueError @lru_cache(maxsize=1) def cached_load(*args, **kwargs): gr.Info("Instantiating model. This could take a while...") return load(*args, **kwargs) @cache def info_once(message): gr.Info(message) class GeneratorLock: """ Ensure that only one instance of a given generator is active. Useful when a previous invocation was canceled. See https://github.com/gradio-app/gradio/issues/8503 for more information. """ def __init__(self, gen_func): self.gen_func = gen_func self.generator = None def generate(self, *args, **kwargs): if self.generator: if self.generator.gi_running: return # somehow we can end up here self.generator.close() self.generator = self.gen_func(*args, **kwargs) yield from self.generator def __call__(self, *args, **kwargs): yield from self.generate(*args, **kwargs) ``` ## /detikzify/webui/strings.py ```py path="/detikzify/webui/strings.py" from os.path import basename from transformers import is_timm_available BANNER = '''\

DeTikZify: Synthesizing Graphics Programs for Scientific Figures and Sketches with TikZ

View on arXiv | View on GitHub | View on Hugging Face | Open in Colab

''' MODELS = { basename(model): model for model in [ "nllg/detikzify-v2-8b", #"nllg/detikzify-v2-3b", # coming soon ] } if is_timm_available(): MODELS |= { basename(model).replace("detikzify", "detikzify-v1"): model for model in [ "nllg/detikzify-ds-7b", "nllg/detikzify-cl-7b", "nllg/detikzify-ds-1.3b", "nllg/detikzify-tl-1.1b", ] } ALGORITHMS = { "mcts": "MCTS", "sampling": "Sampling" } # https://github.com/gradio-app/gradio/issues/3202#issuecomment-1741571240 # https://github.com/gradio-app/gradio/issues/2666#issuecomment-1651127149 # https://stackoverflow.com/a/64033350 CSS = """ .input-image { flex-grow: 1; } .output-code { flex-grow: 1; height: 0vh; min-height: 250px; scrollbar-width: thin !important; } .output-code .hide { display: none; } .output-code .cm-scroller { flex-grow: 1; } .output-code .cm-gutters { position: relative !important; } .output-image { flex-grow: 1; height: 0vh; min-height: 250px; overflow-y: auto !important; scrollbar-width: thin !important; } .output-image .image-container, .output-image .grid-container { width: 100%; height: 100%; } .output-image .thumbnail-item img { object-fit: contain; } .output-image .grid-wrap.fixed-height { max-height: 100% !important; } .outputs .tabs { display: flex; flex-direction: column; flex-grow: 1; } .outputs .tabitem[style="display: block;"] { flex-grow: 1; display: flex !important; } .outputs .gap { flex-grow: 1; } .outputs .form { flex-grow: 1 !important; } .outputs .form > :last-child{ flex-grow: 1; } """ # (Ab)use an invisible fake button with id preview-close to propagate the # actual press of the button that closes a preview # https://github.com/gradio-app/gradio/issues/6697 GALLERY_DESELECT_HACK = """ """ ```
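As a usage sketch (not a file from the repository), the web UI documented in `detikzify/webui/README.md` can also be launched programmatically rather than via `python -m detikzify.webui`, mirroring the call in `detikzify/webui/__main__.py`. The keyword arguments below correspond to the command-line flags of the same names; the model name is one of the entries from `MODELS` in `strings.py`, and any parameters not passed are assumed to keep their defaults in `build_ui`.

```py
# Hypothetical launch script, mirroring detikzify/webui/__main__.py.
# build_ui's keyword arguments correspond to the CLI flags documented in
# detikzify/webui/README.md; parameters not passed here are assumed to
# have defaults in build_ui.
from detikzify.webui.webui import build_ui

demo = build_ui(
    model="nllg/detikzify-v2-8b",  # any entry of MODELS or a path to your own model
    light=True,                    # enforce light theme, like the --light flag
    timeout=60,                    # allowed timeframe for compilation (seconds)
)
demo.queue().launch(share=False)   # share=True would create a public Gradio link
```

Running this requires the same system dependencies as the CLI entry point, namely TeX Live 2023, ghostscript, and poppler.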