```
├── .dockerignore
├── .env.example
├── .gitignore
├── .pre-commit-config.yaml
├── CODE_OF_CONDUCT.md
├── DOCKER.md
├── LICENSE
├── README.md
├── __init__.py
├── assets/
│   └── morphik_logo.png
├── core/
│   ├── __init__.py
│   ├── api.py
│   ├── cache/
│   │   ├── base_cache.py
│   │   ├── base_cache_factory.py
│   │   ├── hf_cache.py
│   │   ├── llama_cache.py
│   │   └── llama_cache_factory.py
│   ├── completion/
│   │   ├── __init__.py
│   │   ├── base_completion.py
│   │   └── litellm_completion.py
│   ├── config.py
│   ├── database/
│   │   ├── base_database.py
│   │   ├── postgres_database.py
│   │   └── user_limits_db.py
│   ├── embedding/
│   │   ├── __init__.py
│   │   ├── base_embedding_model.py
```
## /.dockerignore
```dockerignore path="/.dockerignore"
# flyctl launch added from .gitignore
# Python-related files
**/*__pycache__
**/*.pyc
**/*.pyo
**/*.pyd
**/.Python
**/env
**/.env
**/venv/*
**/ENV
**/dist
**/build
**/*.egg-info
**/.eggs
**/*.egg
**/*.pytest_cache
# Virtual environment
**/.venv
**/.vscode
**/*.DS_Store
# flyctl launch added from .pytest_cache/.gitignore
# Created by pytest automatically.
.pytest_cache/**/*
# flyctl launch added from .venv/lib/python3.12/site-packages/debugpy/_vendored/pydevd/_pydevd_frame_eval/.gitignore
.venv/lib/python3.12/site-packages/debugpy/_vendored/pydevd/_pydevd_frame_eval/pydevd_frame_evaluator.*.so
.venv/lib/python3.12/site-packages/debugpy/_vendored/pydevd/_pydevd_frame_eval/pydevd_frame_evaluator.*.pyd
.venv/lib/python3.12/site-packages/debugpy/_vendored/pydevd/_pydevd_frame_eval/pydevd_frame_evaluator.pyx
# flyctl launch added from .venv/lib/python3.12/site-packages/debugpy/_vendored/pydevd/pydevd_attach_to_process/linux_and_mac/.gitignore
.venv/lib/python3.12/site-packages/debugpy/_vendored/pydevd/pydevd_attach_to_process/linux_and_mac/attach_x86.dylib
.venv/lib/python3.12/site-packages/debugpy/_vendored/pydevd/pydevd_attach_to_process/linux_and_mac/attach_x86_64.dylib
.venv/lib/python3.12/site-packages/debugpy/_vendored/pydevd/pydevd_attach_to_process/linux_and_mac/attach_linux_x86.o
.venv/lib/python3.12/site-packages/debugpy/_vendored/pydevd/pydevd_attach_to_process/linux_and_mac/attach_linux_x86_64.o
fly.toml
```
## /.env.example
```example path="/.env.example"
JWT_SECRET_KEY="..." # Required in production, optional in dev mode (dev_mode=true in morphik.toml)
POSTGRES_URI="postgresql+asyncpg://postgres:postgres@localhost:5432/morphik" # Required for PostgreSQL database
UNSTRUCTURED_API_KEY="..." # Optional: Needed for parsing via unstructured API
OPENAI_API_KEY="..." # Optional: Needed for OpenAI embeddings and completions
ASSEMBLYAI_API_KEY="..." # Optional: Needed for combined parser
ANTHROPIC_API_KEY="..." # Optional: Needed for contextual parser
AWS_ACCESS_KEY="..." # Optional: Needed for AWS S3 storage
AWS_SECRET_ACCESS_KEY="..." # Optional: Needed for AWS S3 storage
```
## /.gitignore
```gitignore path="/.gitignore"
# Python-related files
*__pycache__/
*.pyc
*.pyo
*.pyd
.Python
env/
.env
venv/*
ENV/
dist/
build/
*.egg-info/
.eggs/
*.egg
*.pytest_cache/
core/tests/output
core/tests/assets
# Virtual environment
.venv/
.vscode/
*.DS_Store
storage/*
logs/*
samples/*
aggregated_code.txt
offload/*
test.pdf
experiments/*
ee/ui-component/package-lock.json/*
ee/ui-component/node-modules/*
ee/ui-component/.next
ui-component/notebook-storage/notebooks.json
ee/ui-component/package-lock.json
```
## /.pre-commit-config.yaml
```yaml path="/.pre-commit-config.yaml"
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
name: isort (python)
- repo: https://github.com/psf/black
rev: 24.4.2
hooks:
- id: black
args: [--line-length=120]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.6
hooks:
- id: ruff
args: [--fix]
```
## /CODE_OF_CONDUCT.md
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
founders@morphik.ai.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.
## /DOCKER.md
# Docker Setup Guide for Morphik Core
Morphik Core provides a streamlined Docker-based setup that includes all necessary components: the core API, PostgreSQL with pgvector, and Ollama for AI models.
## Prerequisites
- Docker and Docker Compose installed on your system
- At least 10GB of free disk space (for models and data)
- 8GB+ RAM recommended
## Quick Start
1. Clone the repository and navigate to the project directory:
```bash
git clone https://github.com/morphik-org/morphik-core.git
cd morphik-core
```
2. First-time setup:
```bash
docker compose up --build
```
This command will:
- Build all required containers
- Download necessary AI models (nomic-embed-text and llama3.2)
- Initialize the PostgreSQL database with pgvector
- Start all services
The initial setup may take 5-10 minutes depending on your internet speed, as it needs to download the AI models.
3. For subsequent runs:
```bash
docker compose up # Start all services
docker compose down # Stop all services
```
4. To completely reset (will delete all data and models):
```bash
docker compose down -v
```
## Configuration
### 1. Default Setup
The default configuration works out of the box and includes:
- PostgreSQL with pgvector for document storage
- Ollama for AI models (embeddings and completions)
- Local file storage
- Basic authentication
### 2. Configuration File (morphik.toml)
The default `morphik.toml` is configured for Docker and includes:
```toml
[api]
host = "0.0.0.0" # Important: Use 0.0.0.0 for Docker
port = 8000
[completion]
provider = "ollama"
model_name = "llama3.2"
base_url = "http://ollama:11434" # Use Docker service name
[embedding]
provider = "ollama"
model_name = "nomic-embed-text"
base_url = "http://ollama:11434" # Use Docker service name
[database]
provider = "postgres"
[vector_store]
provider = "pgvector"
[storage]
provider = "local"
storage_path = "/app/storage"
```
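If you edit `morphik.toml`, you can sanity-check it before rebuilding by parsing it with `tomli` (the same TOML library `core/api.py` imports); a minimal sketch:
```python
import tomli

# Parse morphik.toml and print the address the API will bind to.
with open("morphik.toml", "rb") as f:
    config = tomli.load(f)
print(config["api"]["host"], config["api"]["port"])
```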
### 3. Environment Variables
Create a `.env` file to customize these settings:
```bash
JWT_SECRET_KEY=your-secure-key-here # Important: Change in production
OPENAI_API_KEY=sk-... # Only if using OpenAI
HOST=0.0.0.0 # Leave as is for Docker
PORT=8000 # Change if needed
```
### 4. Custom Configuration
To use your own configuration:
1. Create a custom `morphik.toml`
2. Mount it in `docker-compose.yml`:
```yaml
services:
morphik:
volumes:
- ./my-custom-morphik.toml:/app/morphik.toml
```
## Accessing Services
- Morphik API: http://localhost:8000
- API Documentation: http://localhost:8000/docs
- Health Check: http://localhost:8000/health
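Once the containers are up, a quick way to confirm the API is reachable is to call the health endpoint from Python (a minimal sketch; assumes the `requests` package is installed on the host):
```python
import requests

# /health is defined in core/api.py and returns {"status": "healthy"} when the API is up.
print(requests.get("http://localhost:8000/health").json())
```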
## Storage and Data
- Database data: Stored in the `postgres_data` Docker volume
- AI Models: Stored in the `ollama_data` Docker volume
- Documents: Stored in `./storage` directory (mounted to container)
- Logs: Available in `./logs` directory
## Troubleshooting
1. **Service Won't Start**
```bash
# View all logs
docker compose logs
# View specific service logs
docker compose logs morphik
docker compose logs postgres
docker compose logs ollama
```
2. **Database Issues**
- Check PostgreSQL is healthy: `docker compose ps`
- Verify database connection: `docker compose exec postgres psql -U morphik -d morphik`
3. **Model Download Issues**
- Check Ollama logs: `docker compose logs ollama`
- Ensure enough disk space for models
- Try restarting Ollama: `docker compose restart ollama`
4. **Performance Issues**
- Monitor resources: `docker stats`
- Ensure sufficient RAM (8GB+ recommended)
- Check disk space: `df -h`
## Production Deployment
For production environments:
1. **Security**:
- Change the default `JWT_SECRET_KEY`
- Use proper network security groups
- Enable HTTPS (recommended: use a reverse proxy)
- Regularly update containers and dependencies
2. **Persistence**:
- Use named volumes for all data
- Set up regular backups of PostgreSQL
- Back up the storage directory
3. **Monitoring**:
- Set up container monitoring
- Configure proper logging
- Use health checks
## Support
For issues and feature requests:
- GitHub Issues: [https://github.com/morphik-org/morphik-core/issues](https://github.com/morphik-org/morphik-core/issues)
- Documentation: [https://docs.morphik.ai](https://docs.morphik.ai)
## Repository Information
- License: MIT
## /LICENSE
``` path="/LICENSE"
Copyright (c) 2024-2025 Morphik, Inc.
Portions of this software are licensed as follows:
* All content that resides under the "ee/" directory of this repository, if that directory exists, is licensed under the license defined in "ee/LICENSE".
* All third party components incorporated into the Morphik Software are licensed under the original license provided by the owner of the applicable component.
* Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
## /README.md
## Morphik is an alternative to traditional RAG for highly technical and visual documents.
[Morphik](https://morphik.ai) provides developers the tools to ingest, search (deep and shallow), transform, and manage unstructured and multimodal documents. Some of our features include:
- [Multimodal Search](https://docs.morphik.ai/concepts/colpali): We employ techniques such as ColPali to build search that actually *understands* the visual content of documents you provide. Search over images, PDFs, videos, and more with a single endpoint.
- [Knowledge Graphs](https://docs.morphik.ai/concepts/knowledge-graphs): Build knowledge graphs for domain-specific use cases in a single line of code. Use our battle-tested system prompts, or use your own.
- [Fast and Scalable Metadata Extraction](https://docs.morphik.ai/concepts/rules-processing): Extract metadata from documents - including bounding boxes, labeling, classification, and more.
- [Integrations](https://docs.morphik.ai/integrations): Integrate with existing tools and workflows, including (but not limited to) Google Suite, Slack, and Confluence.
- [Cache-Augmented-Generation](https://docs.morphik.ai/python-sdk/create_cache): Create persistent KV-caches of your documents to speed up generation.
The best part? Morphik has a [free tier](https://www.morphik.ai/pricing) and is open source! Get started by signing up at [Morphik](https://www.morphik.ai/signup).
## Table of Contents
- [Getting Started with Morphik](#getting-started-with-morphik-recommended)
- [Self-hosting the open-source version](#self-hosting-the-open-source-version)
- [Using Morphik](#using-morphik)
- [Contributing](#contributing)
- [Open source vs paid](#open-source-vs-paid)
## Getting Started with Morphik (Recommended)
The fastest and easiest way to get started with Morphik is by signing up for free at [Morphik](https://www.morphik.ai/signup). Your first 200 pages and 100 queries are on us! After this, you can pay based on usage with discounted rates for heavier use.
## Self-hosting the open-source version
If you'd like to self-host Morphik, you can find the dedicated instructions [here](https://docs.morphik.ai/getting-started). We offer options for direct installation and installation via Docker.
**Important**: Due to limited resources, we cannot provide full support for open-source deployments. We have an installation guide and a [Discord community](https://discord.gg/BwMtv3Zaju) to help, but we can't guarantee full support.
## Using Morphik
Once you've signed up for Morphik, you can get started with ingesting and searching your data right away.
### Code (Example: Python SDK)
For programmers, we offer a [Python SDK](https://docs.morphik.ai/python-sdk/morphik) and a [REST API](https://docs.morphik.ai/api-reference/health-check). Ingesting a file is as simple as:
```python
from morphik import Morphik
morphik = Morphik("")
morphik.ingest_file("path/to/your/super/complex/file.pdf")
```
Similarly, searching and querying your data is easy too:
```python
morphik.query("What's the height of screw 14-A in the chair assembly instructions?")
```
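If you're self-hosting, the same operations are available over the REST API. The sketch below assumes a local deployment (endpoint and request fields as defined in `core/api.py`; the URL and bearer token are placeholders):
```python
import requests

# POST /query generates a completion using retrieved chunks as context.
# In dev mode (dev_mode=true in morphik.toml) the Authorization header can be omitted.
response = requests.post(
    "http://localhost:8000/query",
    headers={"Authorization": "Bearer <your-jwt>"},
    json={"query": "What's the height of screw 14-A in the chair assembly instructions?", "k": 4},
)
print(response.json())
```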
### Morphik Console
You can also interact with Morphik via the Morphik Console, a web-based interface that allows you to ingest, search, and query your data. You can upload files, connect to different data sources, and chat with your data all in one place.
### Model Context Protocol
Finally, you can also access Morphik via MCP. Instructions are available [here](https://docs.morphik.ai/using-morphik/mcp).
## Contributing
You're welcome to contribute to the project! We love:
- Bug reports via [GitHub issues](https://github.com/morphik-org/morphik-core/issues)
- Feature requests via [GitHub issues](https://github.com/morphik-org/morphik-core/issues)
- Pull requests
Currently, we're focused on improving speed, integrating with more tools, and finding the research papers that provide the most value to our users. If you have thoughts, let us know on Discord or GitHub!
## Open source vs paid
Certain features - such as the Morphik Console - are not available in the open-source version. Any feature in the `ee` namespace is not part of the open-source offering and carries a different license (see `ee/LICENSE`). Anything outside that namespace is open source under the MIT Expat license.
## Contributors
Visit our special thanks page dedicated to our contributors [here](https://docs.morphik.ai/special-thanks).
## PS
We took inspiration from [PostHog](https://posthog.com) while writing this README. If you're from PostHog, thank you ❤️
## /__init__.py
```py path="/__init__.py"
```
## /assets/morphik_logo.png
Binary file available at https://raw.githubusercontent.com/morphik-org/morphik-core/refs/heads/main/assets/morphik_logo.png
## /core/__init__.py
```py path="/core/__init__.py"
```
## /core/api.py
```py path="/core/api.py"
import asyncio
import base64
import json
import logging
import uuid
from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional
import arq
import jwt
import tomli
from fastapi import Depends, FastAPI, File, Form, Header, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from core.cache.llama_cache_factory import LlamaCacheFactory
from core.completion.litellm_completion import LiteLLMCompletionModel
from core.config import get_settings
from core.database.postgres_database import PostgresDatabase
from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
from core.embedding.litellm_embedding import LiteLLMEmbeddingModel
from core.limits_utils import check_and_increment_limits
from core.models.auth import AuthContext, EntityType
from core.models.completion import ChunkSource, CompletionResponse
from core.models.documents import ChunkResult, Document, DocumentResult
from core.models.folders import Folder, FolderCreate
from core.models.graph import Graph
from core.models.prompts import validate_prompt_overrides_with_http_exception
from core.models.request import (
BatchIngestResponse,
CompletionQueryRequest,
CreateGraphRequest,
GenerateUriRequest,
IngestTextRequest,
RetrieveRequest,
SetFolderRuleRequest,
UpdateGraphRequest,
)
from core.parser.morphik_parser import MorphikParser
from core.reranker.flag_reranker import FlagReranker
from core.services.document_service import DocumentService
from core.services.telemetry import TelemetryService
from core.storage.local_storage import LocalStorage
from core.storage.s3_storage import S3Storage
from core.vector_store.multi_vector_store import MultiVectorStore
from core.vector_store.pgvector_store import PGVectorStore
# Initialize FastAPI app
app = FastAPI(title="Morphik API")
logger = logging.getLogger(__name__)
# Add health check endpoints
@app.get("/health")
async def health_check():
"""Basic health check endpoint."""
return {"status": "healthy"}
@app.get("/health/ready")
async def readiness_check():
"""Readiness check that verifies the application is initialized."""
return {
"status": "ready",
"components": {
"database": settings.DATABASE_PROVIDER,
"vector_store": settings.VECTOR_STORE_PROVIDER,
"embedding": settings.EMBEDDING_PROVIDER,
"completion": settings.COMPLETION_PROVIDER,
"storage": settings.STORAGE_PROVIDER,
},
}
# Initialize telemetry
telemetry = TelemetryService()
# Add OpenTelemetry instrumentation - exclude HTTP send/receive spans
FastAPIInstrumentor.instrument_app(
app,
excluded_urls="health,health/.*", # Exclude health check endpoints
exclude_spans=["send", "receive"], # Exclude HTTP send/receive spans to reduce telemetry volume
http_capture_headers_server_request=None, # Don't capture request headers
http_capture_headers_server_response=None, # Don't capture response headers
tracer_provider=None, # Use the global tracer provider
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Initialize service
settings = get_settings()
# Initialize database
if not settings.POSTGRES_URI:
raise ValueError("PostgreSQL URI is required for PostgreSQL database")
database = PostgresDatabase(uri=settings.POSTGRES_URI)
# Redis settings already imported at top of file
@app.on_event("startup")
async def initialize_database():
"""Initialize database tables and indexes on application startup."""
logger.info("Initializing database...")
success = await database.initialize()
if success:
logger.info("Database initialization successful")
else:
logger.error("Database initialization failed")
# We don't raise an exception here to allow the app to continue starting
# even if there are initialization errors
@app.on_event("startup")
async def initialize_vector_store():
"""Initialize vector store tables and indexes on application startup."""
# First initialize the primary vector store (PGVectorStore if using pgvector)
logger.info("Initializing primary vector store...")
if hasattr(vector_store, "initialize"):
success = await vector_store.initialize()
if success:
logger.info("Primary vector store initialization successful")
else:
logger.error("Primary vector store initialization failed")
else:
logger.warning("Primary vector store does not have an initialize method")
# Then initialize the multivector store if enabled
if settings.ENABLE_COLPALI and colpali_vector_store:
logger.info("Initializing multivector store...")
# Handle both synchronous and asynchronous initialize methods
if hasattr(colpali_vector_store.initialize, "__awaitable__"):
success = await colpali_vector_store.initialize()
else:
success = colpali_vector_store.initialize()
if success:
logger.info("Multivector store initialization successful")
else:
logger.error("Multivector store initialization failed")
@app.on_event("startup")
async def initialize_user_limits_database():
"""Initialize user service on application startup."""
logger.info("Initializing user service...")
if settings.MODE == "cloud":
from core.database.user_limits_db import UserLimitsDatabase
user_limits_db = UserLimitsDatabase(uri=settings.POSTGRES_URI)
await user_limits_db.initialize()
@app.on_event("startup")
async def initialize_redis_pool():
"""Initialize the Redis connection pool for background tasks."""
global redis_pool
logger.info("Initializing Redis connection pool...")
# Get Redis settings from configuration
redis_host = settings.REDIS_HOST
redis_port = settings.REDIS_PORT
# Log the Redis connection details
logger.info(f"Connecting to Redis at {redis_host}:{redis_port}")
redis_settings = arq.connections.RedisSettings(
host=redis_host,
port=redis_port,
)
redis_pool = await arq.create_pool(redis_settings)
logger.info("Redis connection pool initialized successfully")
@app.on_event("shutdown")
async def close_redis_pool():
"""Close the Redis connection pool on application shutdown."""
global redis_pool
if redis_pool:
logger.info("Closing Redis connection pool...")
redis_pool.close()
await redis_pool.wait_closed()
logger.info("Redis connection pool closed")
# Initialize vector store
if not settings.POSTGRES_URI:
raise ValueError("PostgreSQL URI is required for pgvector store")
vector_store = PGVectorStore(
uri=settings.POSTGRES_URI,
)
# Initialize storage
match settings.STORAGE_PROVIDER:
case "local":
storage = LocalStorage(storage_path=settings.STORAGE_PATH)
case "aws-s3":
if not settings.AWS_ACCESS_KEY or not settings.AWS_SECRET_ACCESS_KEY:
raise ValueError("AWS credentials are required for S3 storage")
storage = S3Storage(
aws_access_key=settings.AWS_ACCESS_KEY,
aws_secret_key=settings.AWS_SECRET_ACCESS_KEY,
region_name=settings.AWS_REGION,
default_bucket=settings.S3_BUCKET,
)
case _:
raise ValueError(f"Unsupported storage provider: {settings.STORAGE_PROVIDER}")
# Initialize parser
parser = MorphikParser(
chunk_size=settings.CHUNK_SIZE,
chunk_overlap=settings.CHUNK_OVERLAP,
use_unstructured_api=settings.USE_UNSTRUCTURED_API,
unstructured_api_key=settings.UNSTRUCTURED_API_KEY,
assemblyai_api_key=settings.ASSEMBLYAI_API_KEY,
anthropic_api_key=settings.ANTHROPIC_API_KEY,
use_contextual_chunking=settings.USE_CONTEXTUAL_CHUNKING,
)
# Initialize embedding model
# Create a LiteLLM model using the registered model config
embedding_model = LiteLLMEmbeddingModel(
model_key=settings.EMBEDDING_MODEL,
)
logger.info(f"Initialized LiteLLM embedding model with model key: {settings.EMBEDDING_MODEL}")
# Initialize completion model
# Create a LiteLLM model using the registered model config
completion_model = LiteLLMCompletionModel(
model_key=settings.COMPLETION_MODEL,
)
logger.info(f"Initialized LiteLLM completion model with model key: {settings.COMPLETION_MODEL}")
# Initialize reranker
reranker = None
if settings.USE_RERANKING:
match settings.RERANKER_PROVIDER:
case "flag":
reranker = FlagReranker(
model_name=settings.RERANKER_MODEL,
device=settings.RERANKER_DEVICE,
use_fp16=settings.RERANKER_USE_FP16,
query_max_length=settings.RERANKER_QUERY_MAX_LENGTH,
passage_max_length=settings.RERANKER_PASSAGE_MAX_LENGTH,
)
case _:
raise ValueError(f"Unsupported reranker provider: {settings.RERANKER_PROVIDER}")
# Initialize cache factory
cache_factory = LlamaCacheFactory(Path(settings.STORAGE_PATH))
# Initialize ColPali embedding model if enabled
colpali_embedding_model = ColpaliEmbeddingModel() if settings.ENABLE_COLPALI else None
colpali_vector_store = MultiVectorStore(uri=settings.POSTGRES_URI) if settings.ENABLE_COLPALI else None
# Initialize document service with configured components
document_service = DocumentService(
storage=storage,
database=database,
vector_store=vector_store,
embedding_model=embedding_model,
completion_model=completion_model,
parser=parser,
reranker=reranker,
cache_factory=cache_factory,
enable_colpali=settings.ENABLE_COLPALI,
colpali_embedding_model=colpali_embedding_model,
colpali_vector_store=colpali_vector_store,
)
async def verify_token(authorization: str = Header(None)) -> AuthContext:
"""Verify JWT Bearer token or return dev context if dev_mode is enabled."""
# Check if dev mode is enabled
if settings.dev_mode:
return AuthContext(
entity_type=EntityType(settings.dev_entity_type),
entity_id=settings.dev_entity_id,
permissions=set(settings.dev_permissions),
user_id=settings.dev_entity_id, # In dev mode, entity_id is also the user_id
)
# Normal token verification flow
if not authorization:
raise HTTPException(
status_code=401,
detail="Missing authorization header",
headers={"WWW-Authenticate": "Bearer"},
)
try:
if not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = authorization[7:] # Remove "Bearer "
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
if datetime.fromtimestamp(payload["exp"], UTC) < datetime.now(UTC):
raise HTTPException(status_code=401, detail="Token expired")
# Support both "type" and "entity_type" fields for compatibility
entity_type_field = payload.get("type") or payload.get("entity_type")
if not entity_type_field:
raise HTTPException(status_code=401, detail="Missing entity type in token")
return AuthContext(
entity_type=EntityType(entity_type_field),
entity_id=payload["entity_id"],
app_id=payload.get("app_id"),
permissions=set(payload.get("permissions", ["read"])),
user_id=payload.get("user_id", payload["entity_id"]), # Use user_id if available, fallback to entity_id
)
except jwt.InvalidTokenError as e:
raise HTTPException(status_code=401, detail=str(e))
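# Illustrative JWT payload shape expected by verify_token above (field names taken from the
# code; values are placeholders, not produced by this module):
# {
#     "type": "developer",              # or "entity_type"
#     "entity_id": "dev_123",
#     "app_id": "app_456",              # optional
#     "permissions": ["read", "write"],
#     "user_id": "dev_123",             # optional; falls back to entity_id
#     "exp": 1735689600                 # unix expiry timestamp, checked against the current time
# }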
@app.post("/ingest/text", response_model=Document)
@telemetry.track(operation_type="ingest_text", metadata_resolver=telemetry.ingest_text_metadata)
async def ingest_text(
request: IngestTextRequest,
auth: AuthContext = Depends(verify_token),
) -> Document:
"""
Ingest a text document.
Args:
request: IngestTextRequest containing:
- content: Text content to ingest
- filename: Optional filename to help determine content type
- metadata: Optional metadata dictionary
- rules: Optional list of rules. Each rule should be either:
- MetadataExtractionRule: {"type": "metadata_extraction", "schema": {...}}
- NaturalLanguageRule: {"type": "natural_language", "prompt": "..."}
- folder_name: Optional folder to scope the document to
- end_user_id: Optional end-user ID to scope the document to
auth: Authentication context
Returns:
Document: Metadata of ingested document
"""
try:
return await document_service.ingest_text(
content=request.content,
filename=request.filename,
metadata=request.metadata,
rules=request.rules,
use_colpali=request.use_colpali,
auth=auth,
folder_name=request.folder_name,
end_user_id=request.end_user_id,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
# Redis pool for background tasks
redis_pool = None
def get_redis_pool():
"""Get the global Redis connection pool for background tasks."""
return redis_pool
@app.post("/ingest/file", response_model=Document)
@telemetry.track(operation_type="queue_ingest_file", metadata_resolver=telemetry.ingest_file_metadata)
async def ingest_file(
file: UploadFile,
metadata: str = Form("{}"),
rules: str = Form("[]"),
auth: AuthContext = Depends(verify_token),
use_colpali: Optional[bool] = None,
folder_name: Optional[str] = Form(None),
end_user_id: Optional[str] = Form(None),
redis: arq.ArqRedis = Depends(get_redis_pool),
) -> Document:
"""
Ingest a file document asynchronously.
Args:
file: File to ingest
metadata: JSON string of metadata
rules: JSON string of rules list. Each rule should be either:
- MetadataExtractionRule: {"type": "metadata_extraction", "schema": {...}}
- NaturalLanguageRule: {"type": "natural_language", "prompt": "..."}
auth: Authentication context
use_colpali: Whether to use ColPali embedding model
folder_name: Optional folder to scope the document to
end_user_id: Optional end-user ID to scope the document to
redis: Redis connection pool for background tasks
Returns:
Document with processing status that can be used to check progress
"""
try:
# Parse metadata and rules
metadata_dict = json.loads(metadata)
rules_list = json.loads(rules)
# Fix bool conversion: ensure string "false" is properly converted to False
def str2bool(v):
return v if isinstance(v, bool) else str(v).lower() in {"true", "1", "yes"}
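# Example behaviour of str2bool (illustrative): str2bool("false") -> False,
# str2bool("1") -> True, str2bool(True) -> True.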
use_colpali = str2bool(use_colpali)
# Ensure user has write permission
if "write" not in auth.permissions:
raise PermissionError("User does not have write permission")
logger.debug(f"API: Queueing file ingestion with use_colpali: {use_colpali}")
# Create a document with processing status
doc = Document(
content_type=file.content_type,
filename=file.filename,
metadata=metadata_dict,
owner={"type": auth.entity_type.value, "id": auth.entity_id},
access_control={
"readers": [auth.entity_id],
"writers": [auth.entity_id],
"admins": [auth.entity_id],
"user_id": [auth.user_id] if auth.user_id else [],
},
system_metadata={"status": "processing"},
)
# Add folder_name and end_user_id to system_metadata if provided
if folder_name:
doc.system_metadata["folder_name"] = folder_name
if end_user_id:
doc.system_metadata["end_user_id"] = end_user_id
# Set processing status
doc.system_metadata["status"] = "processing"
# Store the document in the database
success = await database.store_document(doc)
if not success:
raise Exception("Failed to store document metadata")
# If folder_name is provided, ensure the folder exists and add document to it
if folder_name:
try:
await document_service._ensure_folder_exists(folder_name, doc.external_id, auth)
logger.debug(f"Ensured folder '{folder_name}' exists and contains document {doc.external_id}")
except Exception as e:
# Log error but don't raise - we want document ingestion to continue even if folder operation fails
logger.error(f"Error ensuring folder exists: {e}")
# Read file content
file_content = await file.read()
# Generate a unique key for the file
file_key = f"ingest_uploads/{uuid.uuid4()}/{file.filename}"
# Store the file in the configured storage
file_content_base64 = base64.b64encode(file_content).decode()
bucket, stored_key = await storage.upload_from_base64(file_content_base64, file_key, file.content_type)
logger.debug(f"Stored file in bucket {bucket} with key {stored_key}")
# Update document with storage info
doc.storage_info = {"bucket": bucket, "key": stored_key}
# Initialize storage_files array with the first file
from datetime import UTC, datetime
from core.models.documents import StorageFileInfo
# Create a StorageFileInfo for the initial file
initial_file_info = StorageFileInfo(
bucket=bucket,
key=stored_key,
version=1,
filename=file.filename,
content_type=file.content_type,
timestamp=datetime.now(UTC),
)
doc.storage_files = [initial_file_info]
# Log storage files
logger.debug(f"Initial storage_files for {doc.external_id}: {doc.storage_files}")
# Update both storage_info and storage_files
await database.update_document(
document_id=doc.external_id,
updates={"storage_info": doc.storage_info, "storage_files": doc.storage_files},
auth=auth,
)
# Convert auth context to a dictionary for serialization
auth_dict = {
"entity_type": auth.entity_type.value,
"entity_id": auth.entity_id,
"app_id": auth.app_id,
"permissions": list(auth.permissions),
"user_id": auth.user_id,
}
# Enqueue the background job
job = await redis.enqueue_job(
"process_ingestion_job",
document_id=doc.external_id,
file_key=stored_key,
bucket=bucket,
original_filename=file.filename,
content_type=file.content_type,
metadata_json=metadata,
auth_dict=auth_dict,
rules_list=rules_list,
use_colpali=use_colpali,
folder_name=folder_name,
end_user_id=end_user_id,
)
logger.info(f"File ingestion job queued with ID: {job.job_id} for document: {doc.external_id}")
return doc
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
logger.error(f"Error during file ingestion: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error during file ingestion: {str(e)}")
@app.post("/ingest/files", response_model=BatchIngestResponse)
@telemetry.track(operation_type="queue_batch_ingest", metadata_resolver=telemetry.batch_ingest_metadata)
async def batch_ingest_files(
files: List[UploadFile] = File(...),
metadata: str = Form("{}"),
rules: str = Form("[]"),
use_colpali: Optional[bool] = Form(None),
parallel: Optional[bool] = Form(True),
folder_name: Optional[str] = Form(None),
end_user_id: Optional[str] = Form(None),
auth: AuthContext = Depends(verify_token),
redis: arq.ArqRedis = Depends(get_redis_pool),
) -> BatchIngestResponse:
"""
Batch ingest multiple files using the task queue.
Args:
files: List of files to ingest
metadata: JSON string of metadata (either a single dict or list of dicts)
rules: JSON string of rules list. Can be either:
- A single list of rules to apply to all files
- A list of rule lists, one per file
use_colpali: Whether to use ColPali-style embedding
folder_name: Optional folder to scope the documents to
end_user_id: Optional end-user ID to scope the documents to
auth: Authentication context
redis: Redis connection pool for background tasks
Returns:
BatchIngestResponse containing:
- documents: List of created documents with processing status
- errors: List of errors that occurred during the batch operation
"""
if not files:
raise HTTPException(status_code=400, detail="No files provided for batch ingestion")
try:
metadata_value = json.loads(metadata)
rules_list = json.loads(rules)
# Fix bool conversion: ensure string "false" is properly converted to False
def str2bool(v):
return str(v).lower() in {"true", "1", "yes"}
use_colpali = str2bool(use_colpali)
# Ensure user has write permission
if "write" not in auth.permissions:
raise PermissionError("User does not have write permission")
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
# Validate metadata if it's a list
if isinstance(metadata_value, list) and len(metadata_value) != len(files):
raise HTTPException(
status_code=400,
detail=f"Number of metadata items ({len(metadata_value)}) must match number of files ({len(files)})",
)
# Validate rules if it's a list of lists
if isinstance(rules_list, list) and rules_list and isinstance(rules_list[0], list):
if len(rules_list) != len(files):
raise HTTPException(
status_code=400,
detail=f"Number of rule lists ({len(rules_list)}) must match number of files ({len(files)})",
)
# Convert auth context to a dictionary for serialization
auth_dict = {
"entity_type": auth.entity_type.value,
"entity_id": auth.entity_id,
"app_id": auth.app_id,
"permissions": list(auth.permissions),
"user_id": auth.user_id,
}
created_documents = []
try:
for i, file in enumerate(files):
# Get the metadata and rules for this file
metadata_item = metadata_value[i] if isinstance(metadata_value, list) else metadata_value
file_rules = (
rules_list[i]
if isinstance(rules_list, list) and rules_list and isinstance(rules_list[0], list)
else rules_list
)
# Create a document with processing status
doc = Document(
content_type=file.content_type,
filename=file.filename,
metadata=metadata_item,
owner={"type": auth.entity_type.value, "id": auth.entity_id},
access_control={
"readers": [auth.entity_id],
"writers": [auth.entity_id],
"admins": [auth.entity_id],
"user_id": [auth.user_id] if auth.user_id else [],
},
)
# Add folder_name and end_user_id to system_metadata if provided
if folder_name:
doc.system_metadata["folder_name"] = folder_name
if end_user_id:
doc.system_metadata["end_user_id"] = end_user_id
# Set processing status
doc.system_metadata["status"] = "processing"
# Store the document in the database
success = await database.store_document(doc)
if not success:
raise Exception(f"Failed to store document metadata for {file.filename}")
# If folder_name is provided, ensure the folder exists and add document to it
if folder_name:
try:
await document_service._ensure_folder_exists(folder_name, doc.external_id, auth)
logger.debug(f"Ensured folder '{folder_name}' exists and contains document {doc.external_id}")
except Exception as e:
# Log error but don't raise - we want document ingestion to continue even if folder operation fails
logger.error(f"Error ensuring folder exists: {e}")
# Read file content
file_content = await file.read()
# Generate a unique key for the file
file_key = f"ingest_uploads/{uuid.uuid4()}/{file.filename}"
# Store the file in the configured storage
file_content_base64 = base64.b64encode(file_content).decode()
bucket, stored_key = await storage.upload_from_base64(file_content_base64, file_key, file.content_type)
logger.debug(f"Stored file in bucket {bucket} with key {stored_key}")
# Update document with storage info
doc.storage_info = {"bucket": bucket, "key": stored_key}
await database.update_document(
document_id=doc.external_id, updates={"storage_info": doc.storage_info}, auth=auth
)
# Convert metadata to JSON string for job
metadata_json = json.dumps(metadata_item)
# Enqueue the background job
job = await redis.enqueue_job(
"process_ingestion_job",
document_id=doc.external_id,
file_key=stored_key,
bucket=bucket,
original_filename=file.filename,
content_type=file.content_type,
metadata_json=metadata_json,
auth_dict=auth_dict,
rules_list=file_rules,
use_colpali=use_colpali,
folder_name=folder_name,
end_user_id=end_user_id,
)
logger.info(f"File ingestion job queued with ID: {job.job_id} for document: {doc.external_id}")
# Add document to the list
created_documents.append(doc)
# Return information about created documents
return BatchIngestResponse(documents=created_documents, errors=[])
except Exception as e:
logger.error(f"Error queueing batch file ingestion: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error queueing batch file ingestion: {str(e)}")
@app.post("/retrieve/chunks", response_model=List[ChunkResult])
@telemetry.track(operation_type="retrieve_chunks", metadata_resolver=telemetry.retrieve_chunks_metadata)
async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
"""
Retrieve relevant chunks.
Args:
request: RetrieveRequest containing:
- query: Search query text
- filters: Optional metadata filters
- k: Number of results (default: 4)
- min_score: Minimum similarity threshold (default: 0.0)
- use_reranking: Whether to use reranking
- use_colpali: Whether to use ColPali-style embedding model
- folder_name: Optional folder to scope the search to
- end_user_id: Optional end-user ID to scope the search to
auth: Authentication context
Returns:
List[ChunkResult]: List of relevant chunks
"""
try:
return await document_service.retrieve_chunks(
request.query,
auth,
request.filters,
request.k,
request.min_score,
request.use_reranking,
request.use_colpali,
request.folder_name,
request.end_user_id,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/retrieve/docs", response_model=List[DocumentResult])
@telemetry.track(operation_type="retrieve_docs", metadata_resolver=telemetry.retrieve_docs_metadata)
async def retrieve_documents(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
"""
Retrieve relevant documents.
Args:
request: RetrieveRequest containing:
- query: Search query text
- filters: Optional metadata filters
- k: Number of results (default: 4)
- min_score: Minimum similarity threshold (default: 0.0)
- use_reranking: Whether to use reranking
- use_colpali: Whether to use ColPali-style embedding model
- folder_name: Optional folder to scope the search to
- end_user_id: Optional end-user ID to scope the search to
auth: Authentication context
Returns:
List[DocumentResult]: List of relevant documents
"""
try:
return await document_service.retrieve_docs(
request.query,
auth,
request.filters,
request.k,
request.min_score,
request.use_reranking,
request.use_colpali,
request.folder_name,
request.end_user_id,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/batch/documents", response_model=List[Document])
@telemetry.track(operation_type="batch_get_documents", metadata_resolver=telemetry.batch_documents_metadata)
async def batch_get_documents(request: Dict[str, Any], auth: AuthContext = Depends(verify_token)):
"""
Retrieve multiple documents by their IDs in a single batch operation.
Args:
request: Dictionary containing:
- document_ids: List of document IDs to retrieve
- folder_name: Optional folder to scope the operation to
- end_user_id: Optional end-user ID to scope the operation to
auth: Authentication context
Returns:
List[Document]: List of documents matching the IDs
"""
try:
# Extract document_ids from request
document_ids = request.get("document_ids", [])
folder_name = request.get("folder_name")
end_user_id = request.get("end_user_id")
if not document_ids:
return []
return await document_service.batch_retrieve_documents(document_ids, auth, folder_name, end_user_id)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/batch/chunks", response_model=List[ChunkResult])
@telemetry.track(operation_type="batch_get_chunks", metadata_resolver=telemetry.batch_chunks_metadata)
async def batch_get_chunks(request: Dict[str, Any], auth: AuthContext = Depends(verify_token)):
"""
Retrieve specific chunks by their document ID and chunk number in a single batch operation.
Args:
request: Dictionary containing:
- sources: List of ChunkSource objects (with document_id and chunk_number)
- folder_name: Optional folder to scope the operation to
- end_user_id: Optional end-user ID to scope the operation to
auth: Authentication context
Returns:
List[ChunkResult]: List of chunk results
"""
try:
# Extract sources from request
sources = request.get("sources", [])
folder_name = request.get("folder_name")
end_user_id = request.get("end_user_id")
use_colpali = request.get("use_colpali")
if not sources:
return []
# Convert sources to ChunkSource objects if needed
chunk_sources = []
for source in sources:
if isinstance(source, dict):
chunk_sources.append(ChunkSource(**source))
else:
chunk_sources.append(source)
return await document_service.batch_retrieve_chunks(chunk_sources, auth, folder_name, end_user_id, use_colpali)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/query", response_model=CompletionResponse)
@telemetry.track(operation_type="query", metadata_resolver=telemetry.query_metadata)
async def query_completion(request: CompletionQueryRequest, auth: AuthContext = Depends(verify_token)):
"""
Generate completion using relevant chunks as context.
When graph_name is provided, the query will leverage the knowledge graph
to enhance retrieval by finding relevant entities and their connected documents.
Args:
request: CompletionQueryRequest containing:
- query: Query text
- filters: Optional metadata filters
- k: Number of chunks to use as context (default: 4)
- min_score: Minimum similarity threshold (default: 0.0)
- max_tokens: Maximum tokens in completion
- temperature: Model temperature
- use_reranking: Whether to use reranking
- use_colpali: Whether to use ColPali-style embedding model
- graph_name: Optional name of the graph to use for knowledge graph-enhanced retrieval
- hop_depth: Number of relationship hops to traverse in the graph (1-3)
- include_paths: Whether to include relationship paths in the response
- prompt_overrides: Optional customizations for entity extraction, resolution, and query prompts
- folder_name: Optional folder to scope the operation to
- end_user_id: Optional end-user ID to scope the operation to
- schema: Optional schema for structured output
auth: Authentication context
Returns:
CompletionResponse: Generated text completion or structured output
"""
try:
# Validate prompt overrides before proceeding
if request.prompt_overrides:
validate_prompt_overrides_with_http_exception(request.prompt_overrides, operation_type="query")
# Check query limits if in cloud mode
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding
await check_and_increment_limits(auth, "query", 1)
return await document_service.query(
request.query,
auth,
request.filters,
request.k,
request.min_score,
request.max_tokens,
request.temperature,
request.use_reranking,
request.use_colpali,
request.graph_name,
request.hop_depth,
request.include_paths,
request.prompt_overrides,
request.folder_name,
request.end_user_id,
request.schema,
)
except ValueError as e:
validate_prompt_overrides_with_http_exception(operation_type="query", error=e)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/documents", response_model=List[Document])
async def list_documents(
auth: AuthContext = Depends(verify_token),
skip: int = 0,
limit: int = 10000,
filters: Optional[Dict[str, Any]] = None,
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
):
"""
List accessible documents.
Args:
auth: Authentication context
skip: Number of documents to skip
limit: Maximum number of documents to return
filters: Optional metadata filters
folder_name: Optional folder to scope the operation to
end_user_id: Optional end-user ID to scope the operation to
Returns:
List[Document]: List of accessible documents
"""
# Create system filters for folder and user scoping
system_filters = {}
if folder_name:
system_filters["folder_name"] = folder_name
if end_user_id:
system_filters["end_user_id"] = end_user_id
return await document_service.db.get_documents(auth, skip, limit, filters, system_filters)
@app.get("/documents/{document_id}", response_model=Document)
async def get_document(document_id: str, auth: AuthContext = Depends(verify_token)):
"""Get document by ID."""
try:
doc = await document_service.db.get_document(document_id, auth)
logger.debug(f"Found document: {doc}")
if not doc:
raise HTTPException(status_code=404, detail="Document not found")
return doc
except HTTPException as e:
logger.error(f"Error getting document: {e}")
raise e
@app.get("/documents/{document_id}/status", response_model=Dict[str, Any])
async def get_document_status(document_id: str, auth: AuthContext = Depends(verify_token)):
"""
Get the processing status of a document.
Args:
document_id: ID of the document to check
auth: Authentication context
Returns:
Dict containing status information for the document
"""
try:
doc = await document_service.db.get_document(document_id, auth)
if not doc:
raise HTTPException(status_code=404, detail="Document not found")
# Extract status information
status = doc.system_metadata.get("status", "unknown")
response = {
"document_id": doc.external_id,
"status": status,
"filename": doc.filename,
"created_at": doc.system_metadata.get("created_at"),
"updated_at": doc.system_metadata.get("updated_at"),
}
# Add error information if failed
if status == "failed":
response["error"] = doc.system_metadata.get("error", "Unknown error")
return response
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting document status: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error getting document status: {str(e)}")
@app.delete("/documents/{document_id}")
@telemetry.track(operation_type="delete_document", metadata_resolver=telemetry.document_delete_metadata)
async def delete_document(document_id: str, auth: AuthContext = Depends(verify_token)):
"""
Delete a document and all associated data.
This endpoint deletes a document and all its associated data, including:
- Document metadata
- Document content in storage
- Document chunks and embeddings in vector store
Args:
document_id: ID of the document to delete
auth: Authentication context (must have write access to the document)
Returns:
Deletion status
"""
try:
success = await document_service.delete_document(document_id, auth)
if not success:
raise HTTPException(status_code=404, detail="Document not found or delete failed")
return {"status": "success", "message": f"Document {document_id} deleted successfully"}
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.get("/documents/filename/{filename}", response_model=Document)
async def get_document_by_filename(
filename: str,
auth: AuthContext = Depends(verify_token),
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
):
"""
Get document by filename.
Args:
filename: Filename of the document to retrieve
auth: Authentication context
folder_name: Optional folder to scope the operation to
end_user_id: Optional end-user ID to scope the operation to
Returns:
Document: Document metadata if found and accessible
"""
try:
# Create system filters for folder and user scoping
system_filters = {}
if folder_name:
system_filters["folder_name"] = folder_name
if end_user_id:
system_filters["end_user_id"] = end_user_id
doc = await document_service.db.get_document_by_filename(filename, auth, system_filters)
logger.debug(f"Found document by filename: {doc}")
if not doc:
raise HTTPException(status_code=404, detail=f"Document with filename '{filename}' not found")
return doc
except HTTPException as e:
logger.error(f"Error getting document by filename: {e}")
raise e
@app.post("/documents/{document_id}/update_text", response_model=Document)
@telemetry.track(operation_type="update_document_text", metadata_resolver=telemetry.document_update_text_metadata)
async def update_document_text(
document_id: str,
request: IngestTextRequest,
update_strategy: str = "add",
auth: AuthContext = Depends(verify_token),
):
"""
Update a document with new text content using the specified strategy.
Args:
document_id: ID of the document to update
request: Text content and metadata for the update
update_strategy: Strategy for updating the document (default: 'add')
Returns:
Document: Updated document metadata
"""
try:
doc = await document_service.update_document(
document_id=document_id,
auth=auth,
content=request.content,
file=None,
filename=request.filename,
metadata=request.metadata,
rules=request.rules,
update_strategy=update_strategy,
use_colpali=request.use_colpali,
)
if not doc:
raise HTTPException(status_code=404, detail="Document not found or update failed")
return doc
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/documents/{document_id}/update_file", response_model=Document)
@telemetry.track(operation_type="update_document_file", metadata_resolver=telemetry.document_update_file_metadata)
async def update_document_file(
document_id: str,
file: UploadFile,
metadata: str = Form("{}"),
rules: str = Form("[]"),
update_strategy: str = Form("add"),
use_colpali: Optional[bool] = None,
auth: AuthContext = Depends(verify_token),
):
"""
Update a document with content from a file using the specified strategy.
Args:
document_id: ID of the document to update
file: File to add to the document
metadata: JSON string of metadata to merge with existing metadata
rules: JSON string of rules to apply to the content
update_strategy: Strategy for updating the document (default: 'add')
use_colpali: Whether to use multi-vector embedding
Returns:
Document: Updated document metadata
"""
try:
metadata_dict = json.loads(metadata)
rules_list = json.loads(rules)
doc = await document_service.update_document(
document_id=document_id,
auth=auth,
content=None,
file=file,
filename=file.filename,
metadata=metadata_dict,
rules=rules_list,
update_strategy=update_strategy,
use_colpali=use_colpali,
)
if not doc:
raise HTTPException(status_code=404, detail="Document not found or update failed")
return doc
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/documents/{document_id}/update_metadata", response_model=Document)
@telemetry.track(
operation_type="update_document_metadata",
metadata_resolver=telemetry.document_update_metadata_resolver,
)
async def update_document_metadata(
document_id: str, metadata: Dict[str, Any], auth: AuthContext = Depends(verify_token)
):
"""
Update only a document's metadata.
Args:
document_id: ID of the document to update
metadata: New metadata to merge with existing metadata
Returns:
Document: Updated document metadata
"""
try:
doc = await document_service.update_document(
document_id=document_id,
auth=auth,
content=None,
file=None,
filename=None,
metadata=metadata,
rules=[],
update_strategy="add",
use_colpali=None,
)
if not doc:
raise HTTPException(status_code=404, detail="Document not found or update failed")
return doc
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
# Usage tracking endpoints
@app.get("/usage/stats")
@telemetry.track(operation_type="get_usage_stats", metadata_resolver=telemetry.usage_stats_metadata)
async def get_usage_stats(auth: AuthContext = Depends(verify_token)) -> Dict[str, int]:
"""Get usage statistics for the authenticated user."""
# NOTE: admin and non-admin callers currently receive the same per-user stats
if not auth.permissions or "admin" not in auth.permissions:
return telemetry.get_user_usage(auth.entity_id)
return telemetry.get_user_usage(auth.entity_id)
@app.get("/usage/recent")
@telemetry.track(operation_type="get_recent_usage", metadata_resolver=telemetry.recent_usage_metadata)
async def get_recent_usage(
auth: AuthContext = Depends(verify_token),
operation_type: Optional[str] = None,
since: Optional[datetime] = None,
status: Optional[str] = None,
) -> List[Dict]:
"""Get recent usage records."""
if not auth.permissions or "admin" not in auth.permissions:
records = telemetry.get_recent_usage(
user_id=auth.entity_id, operation_type=operation_type, since=since, status=status
)
else:
records = telemetry.get_recent_usage(operation_type=operation_type, since=since, status=status)
return [
{
"timestamp": record.timestamp,
"operation_type": record.operation_type,
"tokens_used": record.tokens_used,
"user_id": record.user_id,
"duration_ms": record.duration_ms,
"status": record.status,
"metadata": record.metadata,
}
for record in records
]
# Cache endpoints
@app.post("/cache/create")
@telemetry.track(operation_type="create_cache", metadata_resolver=telemetry.cache_create_metadata)
async def create_cache(
name: str,
model: str,
gguf_file: str,
filters: Optional[Dict[str, Any]] = None,
docs: Optional[List[str]] = None,
auth: AuthContext = Depends(verify_token),
) -> Dict[str, Any]:
"""Create a new cache with specified configuration."""
try:
# Check cache creation limits if in cloud mode
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding
await check_and_increment_limits(auth, "cache", 1)
filter_docs = set(await document_service.db.get_documents(auth, filters=filters))
additional_docs = (
{await document_service.db.get_document(document_id=doc_id, auth=auth) for doc_id in docs}
if docs
else set()
)
docs_to_add = list(filter_docs.union(additional_docs))
if not docs_to_add:
raise HTTPException(status_code=400, detail="No documents to add to cache")
response = await document_service.create_cache(name, model, gguf_file, docs_to_add, filters)
return response
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.get("/cache/{name}")
@telemetry.track(operation_type="get_cache", metadata_resolver=telemetry.cache_get_metadata)
async def get_cache(name: str, auth: AuthContext = Depends(verify_token)) -> Dict[str, Any]:
"""Get cache configuration by name."""
try:
exists = await document_service.load_cache(name)
return {"exists": exists}
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/cache/{name}/update")
@telemetry.track(operation_type="update_cache", metadata_resolver=telemetry.cache_update_metadata)
async def update_cache(name: str, auth: AuthContext = Depends(verify_token)) -> Dict[str, bool]:
"""Update cache with new documents matching its filter."""
try:
if name not in document_service.active_caches:
exists = await document_service.load_cache(name)
if not exists:
raise HTTPException(status_code=404, detail=f"Cache '{name}' not found")
cache = document_service.active_caches[name]
docs = await document_service.db.get_documents(auth, filters=cache.filters)
docs_to_add = [doc for doc in docs if doc.id not in cache.docs]
return cache.add_docs(docs_to_add)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/cache/{name}/add_docs")
@telemetry.track(operation_type="add_docs_to_cache", metadata_resolver=telemetry.cache_add_docs_metadata)
async def add_docs_to_cache(name: str, docs: List[str], auth: AuthContext = Depends(verify_token)) -> Dict[str, bool]:
"""Add specific documents to the cache."""
try:
if name not in document_service.active_caches:
exists = await document_service.load_cache(name)
if not exists:
raise HTTPException(status_code=404, detail=f"Cache '{name}' not found")
cache = document_service.active_caches[name]
docs_to_add = [
await document_service.db.get_document(doc_id, auth) for doc_id in docs if doc_id not in cache.docs
]
return cache.add_docs(docs_to_add)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/cache/{name}/query")
@telemetry.track(operation_type="query_cache", metadata_resolver=telemetry.cache_query_metadata)
async def query_cache(
name: str,
query: str,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
auth: AuthContext = Depends(verify_token),
) -> CompletionResponse:
"""Query the cache with a prompt."""
try:
# Check cache query limits if in cloud mode
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding
await check_and_increment_limits(auth, "cache_query", 1)
if name not in document_service.active_caches:
exists = await document_service.load_cache(name)
if not exists:
raise HTTPException(status_code=404, detail=f"Cache '{name}' not found")
cache = document_service.active_caches[name]
logger.info(f"Cache state: {cache.state.n_tokens}")
return cache.query(query) # , max_tokens, temperature)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/graph/create", response_model=Graph)
@telemetry.track(operation_type="create_graph", metadata_resolver=telemetry.create_graph_metadata)
async def create_graph(
request: CreateGraphRequest,
auth: AuthContext = Depends(verify_token),
) -> Graph:
"""
Create a graph from documents.
This endpoint extracts entities and relationships from documents
matching the specified filters or document IDs and creates a graph.
Args:
request: CreateGraphRequest containing:
- name: Name of the graph to create
- filters: Optional metadata filters to determine which documents to include
- documents: Optional list of specific document IDs to include
- prompt_overrides: Optional customizations for entity extraction and resolution prompts
- folder_name: Optional folder to scope the operation to
- end_user_id: Optional end-user ID to scope the operation to
auth: Authentication context
Returns:
Graph: The created graph object
"""
try:
# Validate prompt overrides before proceeding
if request.prompt_overrides:
validate_prompt_overrides_with_http_exception(request.prompt_overrides, operation_type="graph")
# Check graph creation limits if in cloud mode
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding
await check_and_increment_limits(auth, "graph", 1)
# Create system filters for folder and user scoping
system_filters = {}
if request.folder_name:
system_filters["folder_name"] = request.folder_name
if request.end_user_id:
system_filters["end_user_id"] = request.end_user_id
return await document_service.create_graph(
name=request.name,
auth=auth,
filters=request.filters,
documents=request.documents,
prompt_overrides=request.prompt_overrides,
system_filters=system_filters,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except ValueError as e:
validate_prompt_overrides_with_http_exception(operation_type="graph", error=e)
@app.post("/folders", response_model=Folder)
async def create_folder(
folder_create: FolderCreate,
auth: AuthContext = Depends(verify_token),
) -> Folder:
"""
Create a new folder.
Args:
folder_create: Folder creation request containing name and optional description
auth: Authentication context
Returns:
Folder: Created folder
"""
try:
async with telemetry.track_operation(
operation_type="create_folder",
user_id=auth.entity_id,
metadata={
"name": folder_create.name,
},
):
# Create a folder object with explicit ID
import uuid
folder_id = str(uuid.uuid4())
logger.info(f"Creating folder with ID: {folder_id}, auth.user_id: {auth.user_id}")
# Set up access control with user_id
access_control = {
"readers": [auth.entity_id],
"writers": [auth.entity_id],
"admins": [auth.entity_id],
}
if auth.user_id:
access_control["user_id"] = [auth.user_id]
logger.info(f"Adding user_id {auth.user_id} to folder access control")
folder = Folder(
id=folder_id,
name=folder_create.name,
description=folder_create.description,
owner={
"type": auth.entity_type.value,
"id": auth.entity_id,
},
access_control=access_control,
)
# Store in database
success = await document_service.db.create_folder(folder)
if not success:
raise HTTPException(status_code=500, detail="Failed to create folder")
return folder
except Exception as e:
logger.error(f"Error creating folder: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/folders", response_model=List[Folder])
async def list_folders(
auth: AuthContext = Depends(verify_token),
) -> List[Folder]:
"""
List all folders the user has access to.
Args:
auth: Authentication context
Returns:
List[Folder]: List of folders
"""
try:
async with telemetry.track_operation(
operation_type="list_folders",
user_id=auth.entity_id,
):
folders = await document_service.db.list_folders(auth)
return folders
except Exception as e:
logger.error(f"Error listing folders: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/folders/{folder_id}", response_model=Folder)
async def get_folder(
folder_id: str,
auth: AuthContext = Depends(verify_token),
) -> Folder:
"""
Get a folder by ID.
Args:
folder_id: ID of the folder
auth: Authentication context
Returns:
Folder: Folder if found and accessible
"""
try:
async with telemetry.track_operation(
operation_type="get_folder",
user_id=auth.entity_id,
metadata={
"folder_id": folder_id,
},
):
folder = await document_service.db.get_folder(folder_id, auth)
if not folder:
raise HTTPException(status_code=404, detail=f"Folder {folder_id} not found")
return folder
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting folder: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/folders/{folder_id}/documents/{document_id}")
async def add_document_to_folder(
folder_id: str,
document_id: str,
auth: AuthContext = Depends(verify_token),
):
"""
Add a document to a folder.
Args:
folder_id: ID of the folder
document_id: ID of the document
auth: Authentication context
Returns:
Success status
"""
try:
async with telemetry.track_operation(
operation_type="add_document_to_folder",
user_id=auth.entity_id,
metadata={
"folder_id": folder_id,
"document_id": document_id,
},
):
success = await document_service.db.add_document_to_folder(folder_id, document_id, auth)
if not success:
raise HTTPException(status_code=500, detail="Failed to add document to folder")
return {"status": "success"}
except Exception as e:
logger.error(f"Error adding document to folder: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.delete("/folders/{folder_id}/documents/{document_id}")
async def remove_document_from_folder(
folder_id: str,
document_id: str,
auth: AuthContext = Depends(verify_token),
):
"""
Remove a document from a folder.
Args:
folder_id: ID of the folder
document_id: ID of the document
auth: Authentication context
Returns:
Success status
"""
try:
async with telemetry.track_operation(
operation_type="remove_document_from_folder",
user_id=auth.entity_id,
metadata={
"folder_id": folder_id,
"document_id": document_id,
},
):
success = await document_service.db.remove_document_from_folder(folder_id, document_id, auth)
if not success:
raise HTTPException(status_code=500, detail="Failed to remove document from folder")
return {"status": "success"}
except Exception as e:
logger.error(f"Error removing document from folder: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/graph/{name}", response_model=Graph)
@telemetry.track(operation_type="get_graph", metadata_resolver=telemetry.get_graph_metadata)
async def get_graph(
name: str,
auth: AuthContext = Depends(verify_token),
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
) -> Graph:
"""
Get a graph by name.
This endpoint retrieves a graph by its name if the user has access to it.
Args:
name: Name of the graph to retrieve
auth: Authentication context
folder_name: Optional folder to scope the operation to
end_user_id: Optional end-user ID to scope the operation to
Returns:
Graph: The requested graph object
"""
try:
# Create system filters for folder and user scoping
system_filters = {}
if folder_name:
system_filters["folder_name"] = folder_name
if end_user_id:
system_filters["end_user_id"] = end_user_id
graph = await document_service.db.get_graph(name, auth, system_filters)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph '{name}' not found")
return graph
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/graphs", response_model=List[Graph])
@telemetry.track(operation_type="list_graphs", metadata_resolver=telemetry.list_graphs_metadata)
async def list_graphs(
auth: AuthContext = Depends(verify_token),
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
) -> List[Graph]:
"""
List all graphs the user has access to.
This endpoint retrieves all graphs the user has access to.
Args:
auth: Authentication context
folder_name: Optional folder to scope the operation to
end_user_id: Optional end-user ID to scope the operation to
Returns:
List[Graph]: List of graph objects
"""
try:
# Create system filters for folder and user scoping
system_filters = {}
if folder_name:
system_filters["folder_name"] = folder_name
if end_user_id:
system_filters["end_user_id"] = end_user_id
return await document_service.db.list_graphs(auth, system_filters)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/graph/{name}/update", response_model=Graph)
@telemetry.track(operation_type="update_graph", metadata_resolver=telemetry.update_graph_metadata)
async def update_graph(
name: str,
request: UpdateGraphRequest,
auth: AuthContext = Depends(verify_token),
) -> Graph:
"""
Update an existing graph with new documents.
This endpoint processes additional documents based on the original graph filters
and/or new filters/document IDs, extracts entities and relationships, and
updates the graph with new information.
Args:
name: Name of the graph to update
request: UpdateGraphRequest containing:
- additional_filters: Optional additional metadata filters to determine which new documents to include
- additional_documents: Optional list of additional document IDs to include
- prompt_overrides: Optional customizations for entity extraction and resolution prompts
- folder_name: Optional folder to scope the operation to
- end_user_id: Optional end-user ID to scope the operation to
auth: Authentication context
Returns:
Graph: The updated graph object
"""
try:
# Validate prompt overrides before proceeding
if request.prompt_overrides:
validate_prompt_overrides_with_http_exception(request.prompt_overrides, operation_type="graph")
# Create system filters for folder and user scoping
system_filters = {}
if request.folder_name:
system_filters["folder_name"] = request.folder_name
if request.end_user_id:
system_filters["end_user_id"] = request.end_user_id
return await document_service.update_graph(
name=name,
auth=auth,
additional_filters=request.additional_filters,
additional_documents=request.additional_documents,
prompt_overrides=request.prompt_overrides,
system_filters=system_filters,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except ValueError as e:
validate_prompt_overrides_with_http_exception(operation_type="graph", error=e)
except Exception as e:
logger.error(f"Error updating graph: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/local/generate_uri", include_in_schema=True)
async def generate_local_uri(
name: str = Form("admin"),
expiry_days: int = Form(30),
) -> Dict[str, str]:
"""Generate a local URI for development. This endpoint is unprotected."""
try:
# Clean name
name = name.replace(" ", "_").lower()
# Create payload
payload = {
"type": "developer",
"entity_id": name,
"permissions": ["read", "write", "admin"],
"exp": datetime.now(UTC) + timedelta(days=expiry_days),
}
# Generate token
token = jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
# Read config for host/port
with open("morphik.toml", "rb") as f:
config = tomli.load(f)
base_url = f"{config['api']['host']}:{config['api']['port']}".replace("localhost", "127.0.0.1")
# Generate URI
uri = f"morphik://{name}:{token}@{base_url}"
return {"uri": uri}
except Exception as e:
logger.error(f"Error generating local URI: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/cloud/generate_uri", include_in_schema=True)
async def generate_cloud_uri(
request: GenerateUriRequest,
authorization: str = Header(None),
) -> Dict[str, str]:
"""Generate a URI for cloud hosted applications."""
try:
app_id = request.app_id
name = request.name
user_id = request.user_id
expiry_days = request.expiry_days
logger.debug(f"Generating cloud URI for app_id={app_id}, name={name}, user_id={user_id}")
# Verify authorization header before proceeding
if not authorization:
logger.warning("Missing authorization header")
raise HTTPException(
status_code=401,
detail="Missing authorization header",
headers={"WWW-Authenticate": "Bearer"},
)
# Verify the token is valid
if not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = authorization[7:] # Remove "Bearer "
try:
# Decode the token to ensure it's valid
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
# Only allow users to create apps for themselves (or admin)
token_user_id = payload.get("user_id")
logger.debug(f"Token user ID: {token_user_id}")
logger.debug(f"User ID: {user_id}")
if not (token_user_id == user_id or "admin" in payload.get("permissions", [])):
raise HTTPException(
status_code=403,
detail="You can only create apps for your own account unless you have admin permissions",
)
except jwt.InvalidTokenError as e:
raise HTTPException(status_code=401, detail=str(e))
# Import UserService here to avoid circular imports
from core.services.user_service import UserService
user_service = UserService()
# Initialize user service if needed
await user_service.initialize()
# Clean name
name = name.replace(" ", "_").lower()
# Check if the user is within app limit and generate URI
uri = await user_service.generate_cloud_uri(user_id, app_id, name, expiry_days)
if not uri:
logger.debug("Application limit reached for this account tier with user_id: %s", user_id)
raise HTTPException(status_code=403, detail="Application limit reached for this account tier")
return {"uri": uri, "app_id": app_id}
except HTTPException:
# Re-raise HTTP exceptions
raise
except Exception as e:
logger.error(f"Error generating cloud URI: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/folders/{folder_id}/set_rule")
@telemetry.track(operation_type="set_folder_rule", metadata_resolver=telemetry.set_folder_rule_metadata)
async def set_folder_rule(
folder_id: str,
request: SetFolderRuleRequest,
auth: AuthContext = Depends(verify_token),
apply_to_existing: bool = True,
):
"""
Set extraction rules for a folder.
Args:
folder_id: ID of the folder to set rules for
request: SetFolderRuleRequest containing metadata extraction rules
auth: Authentication context
apply_to_existing: Whether to apply rules to existing documents in the folder
Returns:
Success status with processing results
"""
# Import text here to ensure it's available in this function's scope
from sqlalchemy import text
try:
# Log detailed information about the rules
logger.debug(f"Setting rules for folder {folder_id}")
logger.debug(f"Number of rules: {len(request.rules)}")
for i, rule in enumerate(request.rules):
logger.debug(f"\nRule {i + 1}:")
logger.debug(f"Type: {rule.type}")
logger.debug("Schema:")
for field_name, field_config in rule.schema.items():
logger.debug(f" Field: {field_name}")
logger.debug(f" Type: {field_config.get('type', 'unknown')}")
logger.debug(f" Description: {field_config.get('description', 'No description')}")
if "schema" in field_config:
logger.debug(" Has JSON schema: Yes")
logger.debug(f" Schema: {field_config['schema']}")
# Get the folder
folder = await document_service.db.get_folder(folder_id, auth)
if not folder:
raise HTTPException(status_code=404, detail=f"Folder {folder_id} not found")
# Check if user has write access to the folder
if not document_service.db._check_folder_access(folder, auth, "write"):
raise HTTPException(status_code=403, detail="You don't have write access to this folder")
# Update folder with rules
# Convert rules to dicts for JSON serialization
rules_dicts = [rule.model_dump() for rule in request.rules]
# Update the folder in the database
async with document_service.db.async_session() as session:
# Execute update query
await session.execute(
text(
"""
UPDATE folders
SET rules = :rules
WHERE id = :folder_id
"""
),
{"folder_id": folder_id, "rules": json.dumps(rules_dicts)},
)
await session.commit()
logger.info(f"Successfully updated folder {folder_id} with {len(request.rules)} rules")
# Get updated folder
updated_folder = await document_service.db.get_folder(folder_id, auth)
# If apply_to_existing is True, apply these rules to all existing documents in the folder
processing_results = {"processed": 0, "errors": []}
if apply_to_existing and folder.document_ids:
logger.info(f"Applying rules to {len(folder.document_ids)} existing documents in folder")
# Get all documents in the folder
documents = await document_service.db.get_documents_by_id(folder.document_ids, auth)
# Process each document
for doc in documents:
try:
# Get document content
logger.info(f"Processing document {doc.external_id}")
# For each document, apply the rules from the folder
doc_content = None
# Get content from system_metadata if available
if doc.system_metadata and "content" in doc.system_metadata:
doc_content = doc.system_metadata["content"]
logger.info(f"Retrieved content from system_metadata for document {doc.external_id}")
# If we still have no content, log error and continue
if not doc_content:
error_msg = f"No content found in system_metadata for document {doc.external_id}"
logger.error(error_msg)
processing_results["errors"].append({"document_id": doc.external_id, "error": error_msg})
continue
# Process document with rules
try:
# Convert request rules to actual rule models and apply them
from core.models.rules import MetadataExtractionRule
for rule_request in request.rules:
if rule_request.type == "metadata_extraction":
# Create the actual rule model
rule = MetadataExtractionRule(type=rule_request.type, schema=rule_request.schema)
# Apply the rule with retries
max_retries = 3
base_delay = 1 # seconds
extracted_metadata = None
last_error = None
for retry_count in range(max_retries):
try:
if retry_count > 0:
# Exponential backoff
delay = base_delay * (2 ** (retry_count - 1))
logger.info(f"Retry {retry_count}/{max_retries} after {delay}s delay")
await asyncio.sleep(delay)
extracted_metadata, _ = await rule.apply(doc_content, {})
logger.info(
f"Successfully extracted metadata on attempt {retry_count + 1}: "
f"{extracted_metadata}"
)
break # Success, exit retry loop
except Exception as rule_apply_error:
last_error = rule_apply_error
logger.warning(
f"Metadata extraction attempt {retry_count + 1} failed: "
f"{rule_apply_error}"
)
if retry_count == max_retries - 1: # Last attempt
logger.error(f"All {max_retries} metadata extraction attempts failed")
processing_results["errors"].append(
{
"document_id": doc.external_id,
"error": f"Failed to extract metadata after {max_retries} "
f"attempts: {str(last_error)}",
}
)
continue  # Retries exhausted for this rule; no metadata extracted
# Update document metadata if extraction succeeded
if extracted_metadata:
# Merge new metadata with existing
doc.metadata.update(extracted_metadata)
# Create an updates dict that only updates metadata
# We need to create system_metadata with all preserved fields
# Note: In the database, metadata is stored as 'doc_metadata', not 'metadata'
updates = {
"doc_metadata": doc.metadata, # Use doc_metadata for the database
"system_metadata": {}, # Will be merged with existing in update_document
}
# Explicitly preserve the content field in system_metadata
if "content" in doc.system_metadata:
updates["system_metadata"]["content"] = doc.system_metadata["content"]
# Log the updates we're making
logger.info(
f"Updating document {doc.external_id} with metadata: {extracted_metadata}"
)
logger.info(f"Full metadata being updated: {doc.metadata}")
logger.info(f"Update object being sent to database: {updates}")
logger.info(
f"Preserving content in system_metadata: {'content' in doc.system_metadata}"
)
# Update document in database
success = await document_service.db.update_document(doc.external_id, updates, auth)
if success:
logger.info(f"Updated metadata for document {doc.external_id}")
processing_results["processed"] += 1
else:
logger.error(f"Failed to update metadata for document {doc.external_id}")
processing_results["errors"].append(
{
"document_id": doc.external_id,
"error": "Failed to update document metadata",
}
)
except Exception as rule_error:
logger.error(f"Error processing rules for document {doc.external_id}: {rule_error}")
processing_results["errors"].append(
{
"document_id": doc.external_id,
"error": f"Error processing rules: {str(rule_error)}",
}
)
except Exception as doc_error:
logger.error(f"Error processing document {doc.external_id}: {doc_error}")
processing_results["errors"].append({"document_id": doc.external_id, "error": str(doc_error)})
return {
"status": "success",
"message": "Rules set successfully",
"folder_id": folder_id,
"rules": updated_folder.rules,
"processing_results": processing_results,
}
except HTTPException:
# Re-raise HTTP exceptions
raise
except Exception as e:
logger.error(f"Error setting folder rules: {e}")
raise HTTPException(status_code=500, detail=str(e))
```
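For orientation, here is a minimal client-side sketch that exercises a few of the endpoints defined above over HTTP. The base URL, bearer token, and document ID are placeholders (a development token can be obtained via `/local/generate_uri`), and the request bodies mirror how `IngestTextRequest` and `FolderCreate` are consumed above; optional fields are assumed to have defaults.
```py
import asyncio

import httpx

BASE_URL = "http://127.0.0.1:8000"  # placeholder: wherever the API is served
TOKEN = "<jwt>"  # placeholder bearer token


async def main(document_id: str) -> None:
    headers = {"Authorization": f"Bearer {TOKEN}"}
    async with httpx.AsyncClient(base_url=BASE_URL, headers=headers) as client:
        # Append text to an existing document (update_strategy defaults to "add")
        resp = await client.post(
            f"/documents/{document_id}/update_text",
            json={"content": "Appendix added later.", "metadata": {"source": "manual"}, "rules": []},
        )
        resp.raise_for_status()

        # Create a folder that later ingests and queries can be scoped to
        folder = await client.post("/folders", json={"name": "reports", "description": "Quarterly reports"})
        folder.raise_for_status()

        # Per-user usage statistics
        stats = await client.get("/usage/stats")
        print(stats.json())


if __name__ == "__main__":
    asyncio.run(main("<document-id>"))
```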
## /core/cache/base_cache.py
```py path="/core/cache/base_cache.py"
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from core.models.completion import CompletionResponse
from core.models.documents import Document
class BaseCache(ABC):
"""Base class for cache implementations.
This class defines the interface for cache implementations that support
document ingestion and cache-augmented querying.
"""
def __init__(self, name: str, model: str, gguf_file: str, filters: Dict[str, Any], docs: List[Document]):
"""Initialize the cache with the given parameters.
Args:
name: Name of the cache instance
model: Model identifier
gguf_file: Path to the GGUF model file
filters: Filters used to create the cache context
docs: Initial documents to ingest into the cache
"""
self.name = name
self.filters = filters
self.docs = [] # List of document IDs that have been ingested
self._initialize(model, gguf_file, docs)
@abstractmethod
def _initialize(self, model: str, gguf_file: str, docs: List[Document]) -> None:
"""Internal initialization method to be implemented by subclasses."""
pass
@abstractmethod
async def add_docs(self, docs: List[Document]) -> bool:
"""Add documents to the cache.
Args:
docs: List of documents to add to the cache
Returns:
bool: True if documents were successfully added
"""
pass
@abstractmethod
async def query(self, query: str) -> CompletionResponse:
"""Query the cache for relevant documents and generate a response.
Args:
query: Query string to search for relevant documents
Returns:
CompletionResponse: Generated response based on cached context
"""
pass
@property
@abstractmethod
def saveable_state(self) -> bytes:
"""Get the saveable state of the cache as bytes.
Returns:
bytes: Serialized state that can be used to restore the cache
"""
pass
```
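To make the contract concrete, here is a hypothetical, minimal in-memory subclass; `EchoCache` is illustrative only and not part of the codebase, but it shows which methods a cache implementation has to provide.
```py
import pickle
from typing import List

from core.cache.base_cache import BaseCache
from core.models.completion import CompletionResponse
from core.models.documents import Document


class EchoCache(BaseCache):
    """Toy cache that keeps document text in memory instead of a model KV cache."""

    def _initialize(self, model: str, gguf_file: str, docs: List[Document]) -> None:
        # No real model is loaded; we only keep the document contents around
        self._context: List[str] = [doc.system_metadata.get("content", "") for doc in docs]
        self.docs = [doc.external_id for doc in docs]

    async def add_docs(self, docs: List[Document]) -> bool:
        self._context.extend(doc.system_metadata.get("content", "") for doc in docs)
        self.docs.extend(doc.external_id for doc in docs)
        return True

    async def query(self, query: str) -> CompletionResponse:
        # Echo the query together with the amount of cached context
        return CompletionResponse(
            completion=f"{len(self._context)} docs cached; query was: {query}",
            usage={"prompt_tokens": 0, "completion_tokens": 0},
        )

    @property
    def saveable_state(self) -> bytes:
        return pickle.dumps(self._context)
```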
## /core/cache/base_cache_factory.py
```py path="/core/cache/base_cache_factory.py"
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict
from .base_cache import BaseCache
class BaseCacheFactory(ABC):
"""Abstract base factory for creating and loading caches."""
def __init__(self, storage_path: Path):
"""Initialize the cache factory.
Args:
storage_path: Base path for storing cache files
"""
self.storage_path = storage_path
self.storage_path.mkdir(parents=True, exist_ok=True)
@abstractmethod
def create_new_cache(self, name: str, model: str, model_file: str, **kwargs: Dict[str, Any]) -> BaseCache:
"""Create a new cache instance.
Args:
name: Name of the cache
model: Name/type of the model to use
model_file: Path or identifier for the model file
**kwargs: Additional arguments for cache creation
Returns:
BaseCache: The created cache instance
"""
pass
@abstractmethod
def load_cache_from_bytes(
self, name: str, cache_bytes: bytes, metadata: Dict[str, Any], **kwargs: Dict[str, Any]
) -> BaseCache:
"""Load a cache from its serialized bytes.
Args:
name: Name of the cache
cache_bytes: Serialized cache data
metadata: Cache metadata including model info
**kwargs: Additional arguments for cache loading
Returns:
BaseCache: The loaded cache instance
"""
pass
def get_cache_path(self, name: str) -> Path:
"""Get the storage path for a cache.
Args:
name: Name of the cache
Returns:
Path: Directory path for the cache
"""
path = self.storage_path / name
path.mkdir(parents=True, exist_ok=True)
return path
```
## /core/cache/hf_cache.py
```py path="/core/cache/hf_cache.py"
# Hugging Face cache implementation.
from pathlib import Path
from typing import List, Optional, Union
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache
from core.cache.base_cache import BaseCache
from core.models.completion import CompletionRequest, CompletionResponse
class HuggingFaceCache(BaseCache):
"""Hugging Face Cache implementation for cache-augmented generation"""
def __init__(
self,
cache_path: Path,
model_name: str = "distilgpt2",
device: str = "cpu",
default_max_new_tokens: int = 100,
use_fp16: bool = False,
):
"""Initialize the HuggingFace cache.
Args:
cache_path: Path to store cache files
model_name: Name of the HuggingFace model to use
device: Device to run the model on (e.g. "cpu", "cuda", "mps")
default_max_new_tokens: Default maximum number of new tokens to generate
use_fp16: Whether to use FP16 precision
"""
super().__init__()
self.cache_path = cache_path
self.model_name = model_name
self.device = device
self.default_max_new_tokens = default_max_new_tokens
self.use_fp16 = use_fp16
# Initialize tokenizer and model
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# Configure model loading based on device
model_kwargs = {"low_cpu_mem_usage": True}
if device == "cpu":
# For CPU, use standard loading
model_kwargs.update({"torch_dtype": torch.float32})
self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs).to(device)
else:
# For GPU/MPS, use automatic device mapping and optional FP16
model_kwargs.update({"device_map": "auto", "torch_dtype": torch.float16 if use_fp16 else torch.float32})
self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
self.kv_cache = None
self.origin_len = None
def get_kv_cache(self, prompt: str) -> DynamicCache:
"""Build KV cache from prompt"""
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
cache = DynamicCache()
with torch.no_grad():
_ = self.model(input_ids=input_ids, past_key_values=cache, use_cache=True)
return cache
def clean_up_cache(self, cache: DynamicCache, origin_len: int):
"""Clean up cache by removing appended tokens"""
for i in range(len(cache.key_cache)):
cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
def generate(self, input_ids: torch.Tensor, past_key_values, max_new_tokens: Optional[int] = None) -> torch.Tensor:
"""Generate text using the model and cache"""
device = next(self.model.parameters()).device
origin_len = input_ids.shape[-1]
input_ids = input_ids.to(device)
output_ids = input_ids.clone()
next_token = input_ids
with torch.no_grad():
for _ in range(max_new_tokens or self.default_max_new_tokens):
out = self.model(input_ids=next_token, past_key_values=past_key_values, use_cache=True)
logits = out.logits[:, -1, :]
token = torch.argmax(logits, dim=-1, keepdim=True)
output_ids = torch.cat([output_ids, token], dim=-1)
past_key_values = out.past_key_values
next_token = token.to(device)
if self.model.config.eos_token_id is not None and token.item() == self.model.config.eos_token_id:
break
return output_ids[:, origin_len:]
async def ingest(self, docs: List[str]) -> bool:
"""Ingest documents into cache"""
try:
# Create system prompt with documents
system_prompt = f"""
<|system|>
You are an assistant who provides concise factual answers.
<|user|>
Context:
{' '.join(docs)}
Question:
""".strip()
# Build the cache
input_ids = self.tokenizer(system_prompt, return_tensors="pt").input_ids.to(self.device)
self.kv_cache = DynamicCache()
with torch.no_grad():
# First run to get the cache shape
outputs = self.model(input_ids=input_ids, use_cache=True)
# Initialize cache with empty tensors of the right shape
n_layers = len(outputs.past_key_values)
batch_size = input_ids.shape[0]
# Handle different model architectures
if hasattr(self.model.config, "num_key_value_heads"):
# Models with grouped query attention (GQA) like Llama
n_kv_heads = self.model.config.num_key_value_heads
head_dim = self.model.config.head_dim
elif hasattr(self.model.config, "n_head"):
# GPT-style models
n_kv_heads = self.model.config.n_head
head_dim = self.model.config.n_embd // self.model.config.n_head
elif hasattr(self.model.config, "num_attention_heads"):
# OPT-style models
n_kv_heads = self.model.config.num_attention_heads
head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
else:
raise ValueError(f"Unsupported model architecture: {self.model.config.model_type}")
seq_len = input_ids.shape[1]
for i in range(n_layers):
key_shape = (batch_size, n_kv_heads, seq_len, head_dim)
value_shape = key_shape
self.kv_cache.key_cache.append(torch.zeros(key_shape, device=self.device))
self.kv_cache.value_cache.append(torch.zeros(value_shape, device=self.device))
# Now run with the initialized cache
outputs = self.model(input_ids=input_ids, past_key_values=self.kv_cache, use_cache=True)
# Update cache with actual values
self.kv_cache.key_cache = [layer[0] for layer in outputs.past_key_values]
self.kv_cache.value_cache = [layer[1] for layer in outputs.past_key_values]
self.origin_len = self.kv_cache.key_cache[0].shape[-2]
return True
except Exception as e:
print(f"Error ingesting documents: {e}")
return False
async def update(self, new_doc: str) -> bool:
"""Update cache with new document"""
try:
if self.kv_cache is None:
return await self.ingest([new_doc])
# Clean up existing cache
self.clean_up_cache(self.kv_cache, self.origin_len)
# Add new document to cache
input_ids = self.tokenizer(new_doc + "\n", return_tensors="pt").input_ids.to(self.device)
# First run to get the cache shape
outputs = self.model(input_ids=input_ids, use_cache=True)
# Initialize cache with empty tensors of the right shape
n_layers = len(outputs.past_key_values)
batch_size = input_ids.shape[0]
# Handle different model architectures
if hasattr(self.model.config, "num_key_value_heads"):
# Models with grouped query attention (GQA) like Llama
n_kv_heads = self.model.config.num_key_value_heads
head_dim = self.model.config.head_dim
elif hasattr(self.model.config, "n_head"):
# GPT-style models
n_kv_heads = self.model.config.n_head
head_dim = self.model.config.n_embd // self.model.config.n_head
elif hasattr(self.model.config, "num_attention_heads"):
# OPT-style models
n_kv_heads = self.model.config.num_attention_heads
head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
else:
raise ValueError(f"Unsupported model architecture: {self.model.config.model_type}")
seq_len = input_ids.shape[1]
# Create a new cache for the update
new_cache = DynamicCache()
for i in range(n_layers):
key_shape = (batch_size, n_kv_heads, seq_len, head_dim)
value_shape = key_shape
new_cache.key_cache.append(torch.zeros(key_shape, device=self.device))
new_cache.value_cache.append(torch.zeros(value_shape, device=self.device))
# Run with the initialized cache
outputs = self.model(input_ids=input_ids, past_key_values=new_cache, use_cache=True)
# Update cache with actual values
self.kv_cache.key_cache = [layer[0] for layer in outputs.past_key_values]
self.kv_cache.value_cache = [layer[1] for layer in outputs.past_key_values]
return True
except Exception as e:
print(f"Error updating cache: {e}")
return False
async def complete(self, request: CompletionRequest) -> CompletionResponse:
"""Generate completion using cache-augmented generation"""
try:
if self.kv_cache is None:
raise ValueError("Cache not initialized. Please ingest documents first.")
# Clean up cache
self.clean_up_cache(self.kv_cache, self.origin_len)
# Generate completion
input_ids = self.tokenizer(request.query + "\n", return_tensors="pt").input_ids.to(self.device)
gen_ids = self.generate(input_ids, self.kv_cache, max_new_tokens=request.max_tokens)
completion = self.tokenizer.decode(gen_ids[0], skip_special_tokens=True)
# Calculate token usage
usage = {
"prompt_tokens": len(input_ids[0]),
"completion_tokens": len(gen_ids[0]),
"total_tokens": len(input_ids[0]) + len(gen_ids[0]),
}
return CompletionResponse(completion=completion, usage=usage)
except Exception as e:
print(f"Error generating completion: {e}")
return CompletionResponse(
completion=f"Error: {str(e)}",
usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
)
def save_cache(self) -> Path:
"""Save the KV cache to disk"""
if self.kv_cache is None:
raise ValueError("No cache to save")
cache_dir = self.cache_path / "kv_cache"
cache_dir.mkdir(parents=True, exist_ok=True)
# Save key and value caches
cache_data = {
"key_cache": self.kv_cache.key_cache,
"value_cache": self.kv_cache.value_cache,
"origin_len": self.origin_len,
}
cache_path = cache_dir / "cache.pt"
torch.save(cache_data, cache_path)
return cache_path
def load_cache(self, cache_path: Union[str, Path]) -> None:
"""Load KV cache from disk"""
cache_path = Path(cache_path)
if not cache_path.exists():
raise FileNotFoundError(f"Cache file not found at {cache_path}")
cache_data = torch.load(cache_path, map_location=self.device)
self.kv_cache = DynamicCache()
self.kv_cache.key_cache = cache_data["key_cache"]
self.kv_cache.value_cache = cache_data["value_cache"]
self.origin_len = cache_data["origin_len"]
```
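The core mechanism above can be condensed into a short standalone sketch: prefill the document context once into a `DynamicCache`, then decode each query against that reused cache. The model choice (`distilgpt2`, the class default) and prompts are illustrative; this mirrors `get_kv_cache` and `generate` above rather than adding new behavior.
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

tok = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

# 1) Prefill: run the context once and keep the resulting KV cache
context_ids = tok("Context: Morphik stores and indexes documents.\nQuestion: ", return_tensors="pt").input_ids
cache = DynamicCache()
with torch.no_grad():
    out = model(input_ids=context_ids, past_key_values=cache, use_cache=True)
cache = out.past_key_values  # cache now holds the context tokens

# 2) Decode: greedily generate an answer while reusing the cached context
query_ids = tok("What does Morphik store?", return_tensors="pt").input_ids
generated = query_ids
next_input = query_ids
with torch.no_grad():
    for _ in range(20):
        out = model(input_ids=next_input, past_key_values=cache, use_cache=True)
        cache = out.past_key_values
        token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, token], dim=-1)
        next_input = token

print(tok.decode(generated[0]))
```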
## /core/cache/llama_cache.py
```py path="/core/cache/llama_cache.py"
import json
import logging
import pickle
from typing import Any, Dict, List
from llama_cpp import Llama
from core.cache.base_cache import BaseCache
from core.models.completion import CompletionResponse
from core.models.documents import Document
logger = logging.getLogger(__name__)
INITIAL_SYSTEM_PROMPT = """<|im_start|>system
You are a helpful AI assistant with access to provided documents. Your role is to:
1. Answer questions accurately based on the documents provided
2. Stay focused on the document content and avoid speculation
3. Admit when you don't have enough information to answer
4. Be clear and concise in your responses
5. Use direct quotes from documents when relevant
Provided documents: {documents}
<|im_end|>
""".strip()
ADD_DOC_SYSTEM_PROMPT = """<|im_start|>system
I'm adding some additional documents for your reference:
{documents}
Please incorporate this new information along with what you already know from previous documents while maintaining the same guidelines for responses.
<|im_end|>
""".strip()
QUERY_PROMPT = """<|im_start|>user
{query}
<|im_end|>
<|im_start|>assistant
""".strip()
class LlamaCache(BaseCache):
def __init__(
self,
name: str,
model: str,
gguf_file: str,
filters: Dict[str, Any],
docs: List[Document],
**kwargs,
):
logger.info(f"Initializing LlamaCache with name={name}, model={model}")
# cache related
self.name = name
self.model = model
self.filters = filters
self.docs = docs
# llama specific
self.gguf_file = gguf_file
self.n_gpu_layers = kwargs.get("n_gpu_layers", -1)
logger.info(f"Using {self.n_gpu_layers} GPU layers")
# late init (when we call _initialize)
self.llama = None
self.state = None
self.cached_tokens = 0
self._initialize(model, gguf_file, docs)
logger.info("LlamaCache initialization complete")
def _initialize(self, model: str, gguf_file: str, docs: List[Document]) -> None:
logger.info(f"Loading Llama model from {model} with file {gguf_file}")
try:
# Set a reasonable default context size (32K tokens)
default_ctx_size = 32768
self.llama = Llama.from_pretrained(
repo_id=model,
filename=gguf_file,
n_gpu_layers=self.n_gpu_layers,
n_ctx=default_ctx_size,
verbose=False,  # Keep llama.cpp logging quiet; set to True for more detailed error reporting
)
logger.info("Model loaded successfully")
# Format and tokenize system prompt
documents = "\n".join(doc.system_metadata.get("content", "") for doc in docs)
system_prompt = INITIAL_SYSTEM_PROMPT.format(documents=documents)
logger.info(f"Built system prompt: {system_prompt[:200]}...")
try:
tokens = self.llama.tokenize(system_prompt.encode())
logger.info(f"System prompt tokenized to {len(tokens)} tokens")
# Process tokens to build KV cache
logger.info("Evaluating system prompt")
self.llama.eval(tokens)
logger.info("Saving initial KV cache state")
self.state = self.llama.save_state()
self.cached_tokens = len(tokens)
logger.info(f"Initial KV cache built with {self.cached_tokens} tokens")
except Exception as e:
logger.error(f"Error during prompt processing: {str(e)}")
raise ValueError(f"Failed to process system prompt: {str(e)}")
except Exception as e:
logger.error(f"Failed to initialize Llama model: {str(e)}")
raise ValueError(f"Failed to initialize Llama model: {str(e)}")
def add_docs(self, docs: List[Document]) -> bool:
logger.info(f"Adding {len(docs)} new documents to cache")
documents = "\n".join(doc.system_metadata.get("content", "") for doc in docs)
system_prompt = ADD_DOC_SYSTEM_PROMPT.format(documents=documents)
# Tokenize and process
new_tokens = self.llama.tokenize(system_prompt.encode())
self.llama.eval(new_tokens)
self.state = self.llama.save_state()
self.cached_tokens += len(new_tokens)
logger.info(f"Added {len(new_tokens)} tokens, total: {self.cached_tokens}")
return True
def query(self, query: str) -> CompletionResponse:
# Format query with proper chat template
formatted_query = QUERY_PROMPT.format(query=query)
logger.info(f"Processing query: {formatted_query}")
# Reset and load cached state
self.llama.reset()
self.llama.load_state(self.state)
logger.info(f"Loaded state with {self.state.n_tokens} tokens")
# print(f"Loaded state with {self.state.n_tokens} tokens", file=sys.stderr)
# Tokenize and process query
query_tokens = self.llama.tokenize(formatted_query.encode())
self.llama.eval(query_tokens)
logger.info(f"Evaluated query tokens: {query_tokens}")
# print(f"Evaluated query tokens: {query_tokens}", file=sys.stderr)
# Generate response
output_tokens = []
for token in self.llama.generate(tokens=[], reset=False):
output_tokens.append(token)
# Stop generation when EOT token is encountered
if token == self.llama.token_eos():
break
# Decode and return
completion = self.llama.detokenize(output_tokens).decode()
logger.info(f"Generated completion: {completion}")
return CompletionResponse(
completion=completion,
usage={"prompt_tokens": self.cached_tokens, "completion_tokens": len(output_tokens)},
)
@property
def saveable_state(self) -> bytes:
logger.info("Serializing cache state")
state_bytes = pickle.dumps(self.state)
logger.info(f"Serialized state size: {len(state_bytes)} bytes")
return state_bytes
@classmethod
def from_bytes(cls, name: str, cache_bytes: bytes, metadata: Dict[str, Any], **kwargs) -> "LlamaCache":
"""Load a cache from its serialized state.
Args:
name: Name of the cache
cache_bytes: Pickled state bytes
metadata: Cache metadata including model info
**kwargs: Additional arguments
Returns:
LlamaCache: Loaded cache instance
"""
logger.info(f"Loading cache from bytes with name={name}")
logger.info(f"Cache metadata: {metadata}")
# Create new instance with metadata
# logger.info(f"Docs: {metadata['docs']}")
docs = [json.loads(doc) for doc in metadata["docs"]]
# time.sleep(10)
cache = cls(
name=name,
model=metadata["model"],
gguf_file=metadata["model_file"],
filters=metadata["filters"],
docs=[Document(**doc) for doc in docs],
)
# Load the saved state
logger.info(f"Loading saved KV cache state of size {len(cache_bytes)} bytes")
cache.state = pickle.loads(cache_bytes)
cache.llama.load_state(cache.state)
logger.info("Cache successfully loaded from bytes")
return cache
```
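A save/restore round trip, sketched under the assumption that a GGUF chat model is available (the repo and file names below are placeholders) and with a deliberately simplified `Document`; the real model may require additional fields.
```py
from core.cache.llama_cache import LlamaCache
from core.models.documents import Document

MODEL_REPO = "<hf-repo-with-gguf>"  # placeholder Hugging Face repo id
GGUF_FILE = "<model>.gguf"  # placeholder GGUF filename

# Simplified Document for illustration; the real model may require more fields
docs = [Document(external_id="doc1", system_metadata={"content": "Morphik ingests documents."})]
cache = LlamaCache(name="demo", model=MODEL_REPO, gguf_file=GGUF_FILE, filters={}, docs=docs)

# Persist the evaluated KV state, then restore it without re-evaluating the system prompt
state_bytes = cache.saveable_state
metadata = {
    "model": MODEL_REPO,
    "model_file": GGUF_FILE,
    "filters": {},
    "docs": [doc.model_dump_json() for doc in docs],
}
restored = LlamaCache.from_bytes("demo", state_bytes, metadata)
print(restored.query("What does Morphik do?").completion)
```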
## /core/cache/llama_cache_factory.py
```py path="/core/cache/llama_cache_factory.py"
from typing import Any, Dict
from core.cache.base_cache_factory import BaseCacheFactory
from core.cache.llama_cache import LlamaCache
class LlamaCacheFactory(BaseCacheFactory):
def create_new_cache(self, name: str, model: str, model_file: str, **kwargs: Dict[str, Any]) -> LlamaCache:
return LlamaCache(name, model, model_file, **kwargs)
def load_cache_from_bytes(
self, name: str, cache_bytes: bytes, metadata: Dict[str, Any], **kwargs: Dict[str, Any]
) -> LlamaCache:
return LlamaCache.from_bytes(name, cache_bytes, metadata, **kwargs)
```
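Driving the factory directly might look like the following; the model repo and GGUF filename are placeholders, and the heavyweight calls are left commented out.
```py
from pathlib import Path

from core.cache.llama_cache_factory import LlamaCacheFactory

factory = LlamaCacheFactory(Path("./cache_storage"))

# Each named cache gets its own directory under the storage root
print(factory.get_cache_path("demo"))  # cache_storage/demo, created if missing

# Creating and restoring caches delegates to LlamaCache / LlamaCache.from_bytes:
# cache = factory.create_new_cache("demo", model="<hf-repo-with-gguf>", model_file="<model>.gguf",
#                                  filters={}, docs=[])
# restored = factory.load_cache_from_bytes("demo", cache_bytes, metadata)
```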
## /core/completion/__init__.py
```py path="/core/completion/__init__.py"
from core.completion.base_completion import BaseCompletionModel
from core.completion.litellm_completion import LiteLLMCompletionModel
__all__ = ["BaseCompletionModel", "LiteLLMCompletionModel"]
```
## /core/completion/base_completion.py
```py path="/core/completion/base_completion.py"
from abc import ABC, abstractmethod
from core.models.completion import CompletionRequest, CompletionResponse
class BaseCompletionModel(ABC):
"""Base class for completion models"""
@abstractmethod
async def complete(self, request: CompletionRequest) -> CompletionResponse:
"""Generate completion from query and context"""
pass
```
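For reference, a hypothetical implementation of this interface can be as small as the following; `StaticCompletionModel` is illustrative only, while the real `LiteLLMCompletionModel` follows below.
```py
from core.completion.base_completion import BaseCompletionModel
from core.models.completion import CompletionRequest, CompletionResponse


class StaticCompletionModel(BaseCompletionModel):
    """Toy completion model that ignores any LLM and reports what it received."""

    async def complete(self, request: CompletionRequest) -> CompletionResponse:
        return CompletionResponse(
            completion=f"Received {len(request.context_chunks)} context chunks for: {request.query}",
            usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        )
```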
## /core/completion/litellm_completion.py
```py path="/core/completion/litellm_completion.py"
import logging
import re # Import re for parsing model name
from typing import Any, Dict, List, Optional, Tuple, Union
import litellm
try:
import ollama
except ImportError:
ollama = None # Make ollama import optional
from pydantic import BaseModel
from core.config import get_settings
from core.models.completion import CompletionRequest, CompletionResponse
from .base_completion import BaseCompletionModel
logger = logging.getLogger(__name__)
def get_system_message() -> Dict[str, str]:
"""Return the standard system message for Morphik's query agent."""
return {
"role": "system",
"content": """You are Morphik's powerful query agent. Your role is to:
1. Analyze the provided context chunks from documents carefully
2. Use the context to answer questions accurately and comprehensively
3. Be clear and concise in your answers
4. When relevant, cite specific parts of the context to support your answers
5. For image-based queries, analyze the visual content in conjunction with any text context provided
Remember: Your primary goal is to provide accurate, context-aware responses that help users understand
and utilize the information in their documents effectively.""",
}
def process_context_chunks(context_chunks: List[str], is_ollama: bool) -> Tuple[List[str], List[str], List[str]]:
"""
Process context chunks and separate text from images.
Args:
context_chunks: List of context chunks which may include images
is_ollama: Whether we're using Ollama (affects image processing)
Returns:
Tuple of (context_text, image_urls, ollama_image_data)
"""
context_text = []
image_urls = [] # For non-Ollama models (full data URI)
ollama_image_data = [] # For Ollama models (raw base64)
for chunk in context_chunks:
if chunk.startswith("data:image/"):
if is_ollama:
# For Ollama, strip the data URI prefix and just keep the base64 data
try:
base64_data = chunk.split(",", 1)[1]
ollama_image_data.append(base64_data)
except IndexError:
logger.warning(f"Could not parse base64 data from image chunk: {chunk[:50]}...")
else:
image_urls.append(chunk)
else:
context_text.append(chunk)
return context_text, image_urls, ollama_image_data
def format_user_content(context_text: List[str], query: str, prompt_template: Optional[str] = None) -> str:
"""
Format the user content based on context and query.
Args:
context_text: List of context text chunks
query: The user query
prompt_template: Optional template to format the content
Returns:
Formatted user content string
"""
context = "\n" + "\n\n".join(context_text) + "\n\n" if context_text else ""
if prompt_template:
return prompt_template.format(
context=context,
question=query,
query=query,
)
elif context_text:
return f"Context: {context} Question: {query}"
else:
return query
def create_dynamic_model_from_schema(schema: Union[type, Dict]) -> Optional[type]:
"""
Create a dynamic Pydantic model from a schema definition.
Args:
schema: Either a Pydantic BaseModel class or a JSON schema dict
Returns:
A Pydantic model class or None if schema format is not recognized
"""
from pydantic import create_model
if isinstance(schema, type) and issubclass(schema, BaseModel):
return schema
elif isinstance(schema, dict) and "properties" in schema:
# Create a dynamic model from JSON schema
field_definitions = {}
schema_dict = schema
for field_name, field_info in schema_dict.get("properties", {}).items():
if isinstance(field_info, dict) and "type" in field_info:
field_type = field_info.get("type")
# Convert schema types to Python types
if field_type == "string":
field_definitions[field_name] = (str, None)
elif field_type == "number":
field_definitions[field_name] = (float, None)
elif field_type == "integer":
field_definitions[field_name] = (int, None)
elif field_type == "boolean":
field_definitions[field_name] = (bool, None)
elif field_type == "array":
field_definitions[field_name] = (list, None)
elif field_type == "object":
field_definitions[field_name] = (dict, None)
else:
# Default to Any for unknown types
field_definitions[field_name] = (Any, None)
# Create the dynamic model
return create_model("DynamicQueryModel", **field_definitions)
else:
logger.warning(f"Unrecognized schema format: {schema}")
return None
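# Illustrative usage (hypothetical values): a plain JSON schema with two properties
# yields a Pydantic model whose fields default to None, e.g.
#   Movie = create_dynamic_model_from_schema(
#       {"properties": {"title": {"type": "string"}, "year": {"type": "integer"}}}
#   )
#   Movie(title="Heat", year=1995)  # -> DynamicQueryModel(title='Heat', year=1995)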
class LiteLLMCompletionModel(BaseCompletionModel):
"""
LiteLLM completion model implementation that provides unified access to various LLM providers.
Uses registered models from the config file. Can optionally use direct Ollama client.
"""
def __init__(self, model_key: str):
"""
Initialize LiteLLM completion model with a model key from registered_models.
Args:
model_key: The key of the model in the registered_models config
"""
settings = get_settings()
self.model_key = model_key
# Get the model configuration from registered_models
if not hasattr(settings, "REGISTERED_MODELS") or model_key not in settings.REGISTERED_MODELS:
raise ValueError(f"Model '{model_key}' not found in registered_models configuration")
self.model_config = settings.REGISTERED_MODELS[model_key]
# Check if it's an Ollama model for potential direct usage
self.is_ollama = "ollama" in self.model_config.get("model_name", "").lower()
self.ollama_api_base = None
self.ollama_base_model_name = None
if self.is_ollama:
if ollama is None:
logger.warning("Ollama model selected, but 'ollama' library not installed. Falling back to LiteLLM.")
self.is_ollama = False # Fallback to LiteLLM if library missing
else:
self.ollama_api_base = self.model_config.get("api_base")
if not self.ollama_api_base:
logger.warning(
f"Ollama model {self.model_key} selected for direct use, "
"but 'api_base' is missing in config. Falling back to LiteLLM."
)
self.is_ollama = False # Fallback if api_base is missing
else:
# Extract base model name (e.g., 'llama3.2' from 'ollama_chat/llama3.2')
match = re.search(r"[^/]+$", self.model_config["model_name"])
if match:
self.ollama_base_model_name = match.group(0)
else:
logger.warning(
f"Could not parse base model name from Ollama model "
f"{self.model_config['model_name']}. Falling back to LiteLLM."
)
self.is_ollama = False # Fallback if name parsing fails
logger.info(
f"Initialized LiteLLM completion model with model_key={model_key}, "
f"config={self.model_config}, is_ollama_direct={self.is_ollama}"
)
async def _handle_structured_ollama(
self,
dynamic_model: type,
system_message: Dict[str, str],
user_content: str,
ollama_image_data: List[str],
request: CompletionRequest,
) -> CompletionResponse:
"""Handle structured output generation with Ollama."""
try:
client = ollama.AsyncClient(host=self.ollama_api_base)
# Build the user message; the ollama client expects image data in a top-level
# "images" field on the message rather than nested inside "content"
user_message = {"role": "user", "content": user_content}
if ollama_image_data:
# Ollama image handling is limited; use only the first image
user_message["images"] = [ollama_image_data[0]]
# Create messages for Ollama
messages = [system_message, user_message]
# Get the JSON schema from the dynamic model
format_schema = dynamic_model.model_json_schema()
# Call Ollama directly with format parameter
response = await client.chat(
model=self.ollama_base_model_name,
messages=messages,
format=format_schema,
options={
"temperature": request.temperature or 0.1, # Lower temperature for structured output
"num_predict": request.max_tokens,
},
)
# Parse the response into the dynamic model
parsed_response = dynamic_model.model_validate_json(response["message"]["content"])
# Extract token usage information
usage = {
"prompt_tokens": response.get("prompt_eval_count", 0),
"completion_tokens": response.get("eval_count", 0),
"total_tokens": response.get("prompt_eval_count", 0) + response.get("eval_count", 0),
}
return CompletionResponse(
completion=parsed_response,
usage=usage,
finish_reason=response.get("done_reason", "stop"),
)
except Exception as e:
logger.error(f"Error using Ollama for structured output: {e}")
# Fall back to standard completion if structured output fails
logger.warning("Falling back to standard Ollama completion without structured output")
return None
async def _handle_structured_litellm(
self,
dynamic_model: type,
system_message: Dict[str, str],
user_content: str,
image_urls: List[str],
request: CompletionRequest,
) -> CompletionResponse:
"""Handle structured output generation with LiteLLM."""
import instructor
from instructor import Mode
try:
# Use instructor with litellm
client = instructor.from_litellm(litellm.acompletion, mode=Mode.JSON)
# Create content list with text and images
content_list = [{"type": "text", "text": user_content}]
# Add images if available
if image_urls:
NUM_IMAGES = min(3, len(image_urls))
for img_url in image_urls[:NUM_IMAGES]:
content_list.append({"type": "image_url", "image_url": {"url": img_url}})
# Create messages for instructor
messages = [system_message, {"role": "user", "content": content_list}]
# Extract model configuration
model = self.model_config.get("model_name")
model_kwargs = {k: v for k, v in self.model_config.items() if k != "model_name"}
# Override with completion request parameters
if request.temperature is not None:
model_kwargs["temperature"] = request.temperature
if request.max_tokens is not None:
model_kwargs["max_tokens"] = request.max_tokens
# Add format forcing for structured output
model_kwargs["response_format"] = {"type": "json_object"}
# Call instructor with litellm
response = await client.chat.completions.create(
model=model,
messages=messages,
response_model=dynamic_model,
**model_kwargs,
)
# Token usage is not exposed on the instructor-parsed response here; default to 0
completion_tokens = model_kwargs.get("response_tokens", 0)
prompt_tokens = model_kwargs.get("prompt_tokens", 0)
return CompletionResponse(
completion=response,
usage={
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
finish_reason="stop",
)
except Exception as e:
logger.error(f"Error using instructor with LiteLLM: {e}")
# Fall back to standard completion if instructor fails
logger.warning("Falling back to standard LiteLLM completion without structured output")
return None
async def _handle_standard_ollama(
self, user_content: str, ollama_image_data: List[str], request: CompletionRequest
) -> CompletionResponse:
"""Handle standard (non-structured) output generation with Ollama."""
logger.debug(f"Using direct Ollama client for model: {self.ollama_base_model_name}")
client = ollama.AsyncClient(host=self.ollama_api_base)
# Construct Ollama messages
system_message = {"role": "system", "content": get_system_message()["content"]}
user_message_data = {"role": "user", "content": user_content}
# Add images directly to the user message if available
if ollama_image_data:
if len(ollama_image_data) > 1:
logger.warning(
f"Ollama model {self.model_config['model_name']} only supports one image per message. "
"Using the first image and ignoring others."
)
# Add 'images' key inside the user message dictionary
user_message_data["images"] = [ollama_image_data[0]]
ollama_messages = [system_message, user_message_data]
# Construct Ollama options
options = {
"temperature": request.temperature,
"num_predict": (
request.max_tokens if request.max_tokens is not None else -1
), # Default to model's default if None
}
try:
response = await client.chat(model=self.ollama_base_model_name, messages=ollama_messages, options=options)
# Map Ollama response to CompletionResponse
prompt_tokens = response.get("prompt_eval_count", 0)
completion_tokens = response.get("eval_count", 0)
return CompletionResponse(
completion=response["message"]["content"],
usage={
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
finish_reason=response.get("done_reason", "unknown"), # Map done_reason if available
)
except Exception as e:
logger.error(f"Error during direct Ollama call: {e}")
raise
async def _handle_standard_litellm(
self, user_content: str, image_urls: List[str], request: CompletionRequest
) -> CompletionResponse:
"""Handle standard (non-structured) output generation with LiteLLM."""
logger.debug(f"Using LiteLLM for model: {self.model_config['model_name']}")
# Build messages for LiteLLM
content_list = [{"type": "text", "text": user_content}]
include_images = image_urls # Use the collected full data URIs
if include_images:
NUM_IMAGES = min(3, len(image_urls))
for img_url in image_urls[:NUM_IMAGES]:
content_list.append({"type": "image_url", "image_url": {"url": img_url}})
# LiteLLM uses list content format
user_message = {"role": "user", "content": content_list}
# Use the system prompt defined earlier
litellm_messages = [get_system_message(), user_message]
# Prepare LiteLLM parameters
model_params = {
"model": self.model_config["model_name"],
"messages": litellm_messages,
"max_tokens": request.max_tokens,
"temperature": request.temperature,
"num_retries": 3,
}
for key, value in self.model_config.items():
if key != "model_name":
model_params[key] = value
logger.debug(f"Calling LiteLLM with params: {model_params}")
response = await litellm.acompletion(**model_params)
return CompletionResponse(
completion=response.choices[0].message.content,
usage={
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
},
finish_reason=response.choices[0].finish_reason,
)
async def complete(self, request: CompletionRequest) -> CompletionResponse:
"""
Generate completion using LiteLLM or direct Ollama client if configured.
Args:
request: CompletionRequest object containing query, context, and parameters
Returns:
CompletionResponse object with the generated text and usage statistics
"""
# Process context chunks and handle images
context_text, image_urls, ollama_image_data = process_context_chunks(request.context_chunks, self.is_ollama)
# Format user content
user_content = format_user_content(context_text, request.query, request.prompt_template)
# Check if structured output is requested
structured_output = request.schema is not None
# If structured output is requested, use instructor to handle it
if structured_output:
# Get dynamic model from schema
dynamic_model = create_dynamic_model_from_schema(request.schema)
# If schema format is not recognized, log warning and fall back to text completion
if not dynamic_model:
logger.warning(f"Unrecognized schema format: {request.schema}. Falling back to text completion.")
structured_output = False
else:
logger.info(f"Using structured output with model: {dynamic_model.__name__}")
# Create system and user messages with enhanced instructions for structured output
system_message = {
"role": "system",
"content": get_system_message()["content"]
+ "\n\nYou MUST format your response according to the required schema.",
}
# Create enhanced user message that includes schema information
enhanced_user_content = (
user_content + "\n\nPlease format your response according to the required schema."
)
# Try structured output based on model type
if self.is_ollama:
response = await self._handle_structured_ollama(
dynamic_model, system_message, enhanced_user_content, ollama_image_data, request
)
if response:
return response
structured_output = False # Fall back if structured output failed
else:
response = await self._handle_structured_litellm(
dynamic_model, system_message, enhanced_user_content, image_urls, request
)
if response:
return response
structured_output = False # Fall back if structured output failed
# If we're here, either structured output wasn't requested or instructor failed
# Proceed with standard completion based on model type
if self.is_ollama:
return await self._handle_standard_ollama(user_content, ollama_image_data, request)
else:
return await self._handle_standard_litellm(user_content, image_urls, request)
```
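For context, `complete()` above switches between the structured and standard handlers based on whether `request.schema` is set. The sketch below is illustrative only: it assumes an already-configured instance of the completion class from this file (here called `completer`) and a `CompletionRequest` exposing the fields referenced above; the schema formats actually accepted depend on `create_dynamic_model_from_schema`.

```py
# Minimal sketch, assuming `completer` is a configured instance of the completion
# class above and that CompletionRequest accepts these fields (check core/models).
import asyncio


async def main():
    request = CompletionRequest(
        query="What were the key findings?",
        context_chunks=["First retrieved chunk...", "Second retrieved chunk..."],
        temperature=0.2,
        max_tokens=256,
        # Supplying a schema routes the call through the structured-output path;
        # omit it (or pass None) for a plain-text completion.
        schema={"type": "object", "properties": {"summary": {"type": "string"}}},
    )
    response = await completer.complete(request)
    print(response.completion)
    print(response.usage)


asyncio.run(main())
```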
## /core/config.py
```py path="/core/config.py"
import os
from collections import ChainMap
from functools import lru_cache
from typing import Any, Dict, Literal, Optional
import tomli
from dotenv import load_dotenv
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
"""Morphik configuration settings."""
# Environment variables
JWT_SECRET_KEY: str
POSTGRES_URI: Optional[str] = None
UNSTRUCTURED_API_KEY: Optional[str] = None
AWS_ACCESS_KEY: Optional[str] = None
AWS_SECRET_ACCESS_KEY: Optional[str] = None
OPENAI_API_KEY: Optional[str] = None
ANTHROPIC_API_KEY: Optional[str] = None
ASSEMBLYAI_API_KEY: Optional[str] = None
# API configuration
HOST: str
PORT: int
RELOAD: bool
# Auth configuration
JWT_ALGORITHM: str
dev_mode: bool = False
dev_entity_type: str = "developer"
dev_entity_id: str = "dev_user"
dev_permissions: list = ["read", "write", "admin"]
# Registered models configuration
REGISTERED_MODELS: Dict[str, Dict[str, Any]] = {}
# Completion configuration
COMPLETION_PROVIDER: Literal["litellm"] = "litellm"
COMPLETION_MODEL: str
# Database configuration
DATABASE_PROVIDER: Literal["postgres"]
DATABASE_NAME: Optional[str] = None
# Database connection pool settings
DB_POOL_SIZE: int = 20
DB_MAX_OVERFLOW: int = 30
DB_POOL_RECYCLE: int = 3600
DB_POOL_TIMEOUT: int = 10
DB_POOL_PRE_PING: bool = True
DB_MAX_RETRIES: int = 3
DB_RETRY_DELAY: float = 1.0
# Embedding configuration
EMBEDDING_PROVIDER: Literal["litellm"] = "litellm"
EMBEDDING_MODEL: str
VECTOR_DIMENSIONS: int
EMBEDDING_SIMILARITY_METRIC: Literal["cosine", "dotProduct"]
# Parser configuration
CHUNK_SIZE: int
CHUNK_OVERLAP: int
USE_UNSTRUCTURED_API: bool
FRAME_SAMPLE_RATE: Optional[int] = None
USE_CONTEXTUAL_CHUNKING: bool = False
# Rules configuration
RULES_PROVIDER: Literal["litellm"] = "litellm"
RULES_MODEL: str
RULES_BATCH_SIZE: int = 4096
# Graph configuration
GRAPH_PROVIDER: Literal["litellm"] = "litellm"
GRAPH_MODEL: str
ENABLE_ENTITY_RESOLUTION: bool = True
# Reranker configuration
USE_RERANKING: bool
RERANKER_PROVIDER: Optional[Literal["flag"]] = None
RERANKER_MODEL: Optional[str] = None
RERANKER_QUERY_MAX_LENGTH: Optional[int] = None
RERANKER_PASSAGE_MAX_LENGTH: Optional[int] = None
RERANKER_USE_FP16: Optional[bool] = None
RERANKER_DEVICE: Optional[str] = None
# Storage configuration
STORAGE_PROVIDER: Literal["local", "aws-s3"]
STORAGE_PATH: Optional[str] = None
AWS_REGION: Optional[str] = None
S3_BUCKET: Optional[str] = None
# Vector store configuration
VECTOR_STORE_PROVIDER: Literal["pgvector"]
VECTOR_STORE_DATABASE_NAME: Optional[str] = None
# Colpali configuration
ENABLE_COLPALI: bool
# Mode configuration
MODE: Literal["cloud", "self_hosted"] = "cloud"
# API configuration
API_DOMAIN: str = "api.morphik.ai"
# Redis configuration
REDIS_HOST: str = "localhost"
REDIS_PORT: int = 6379
# Telemetry configuration
TELEMETRY_ENABLED: bool = True
HONEYCOMB_ENABLED: bool = True
HONEYCOMB_ENDPOINT: str = "https://api.honeycomb.io"
HONEYCOMB_PROXY_ENDPOINT: str = "https://otel-proxy.onrender.com/"
SERVICE_NAME: str = "morphik-core"
OTLP_TIMEOUT: int = 10
OTLP_MAX_RETRIES: int = 3
OTLP_RETRY_DELAY: int = 1
OTLP_MAX_EXPORT_BATCH_SIZE: int = 512
OTLP_SCHEDULE_DELAY_MILLIS: int = 5000
OTLP_MAX_QUEUE_SIZE: int = 2048
@lru_cache()
def get_settings() -> Settings:
"""Get cached settings instance."""
load_dotenv(override=True)
# Load config.toml
with open("morphik.toml", "rb") as f:
config = tomli.load(f)
em = "'{missing_value}' needed if '{field}' is set to '{value}'"
openai_config = {}
# load api config
api_config = {
"HOST": config["api"]["host"],
"PORT": int(config["api"]["port"]),
"RELOAD": bool(config["api"]["reload"]),
}
# load auth config
auth_config = {
"JWT_ALGORITHM": config["auth"]["jwt_algorithm"],
"JWT_SECRET_KEY": os.environ.get("JWT_SECRET_KEY", "dev-secret-key"), # Default for dev mode
"dev_mode": config["auth"].get("dev_mode", False),
"dev_entity_type": config["auth"].get("dev_entity_type", "developer"),
"dev_entity_id": config["auth"].get("dev_entity_id", "dev_user"),
"dev_permissions": config["auth"].get("dev_permissions", ["read", "write", "admin"]),
}
# Only require JWT_SECRET_KEY in non-dev mode
if not auth_config["dev_mode"] and "JWT_SECRET_KEY" not in os.environ:
raise ValueError("JWT_SECRET_KEY is required when dev_mode is disabled")
# Load registered models if available
registered_models = {}
if "registered_models" in config:
registered_models = {"REGISTERED_MODELS": config["registered_models"]}
# load completion config
completion_config = {
"COMPLETION_PROVIDER": "litellm",
}
# Set the model key for LiteLLM
if "model" not in config["completion"]:
raise ValueError("'model' is required in the completion configuration")
completion_config["COMPLETION_MODEL"] = config["completion"]["model"]
# load database config
database_config = {
"DATABASE_PROVIDER": config["database"]["provider"],
"DATABASE_NAME": config["database"].get("name", None),
# Add database connection pool settings
"DB_POOL_SIZE": config["database"].get("pool_size", 20),
"DB_MAX_OVERFLOW": config["database"].get("max_overflow", 30),
"DB_POOL_RECYCLE": config["database"].get("pool_recycle", 3600),
"DB_POOL_TIMEOUT": config["database"].get("pool_timeout", 10),
"DB_POOL_PRE_PING": config["database"].get("pool_pre_ping", True),
"DB_MAX_RETRIES": config["database"].get("max_retries", 3),
"DB_RETRY_DELAY": config["database"].get("retry_delay", 1.0),
}
if database_config["DATABASE_PROVIDER"] != "postgres":
prov = database_config["DATABASE_PROVIDER"]
raise ValueError(f"Unknown database provider selected: '{prov}'")
if "POSTGRES_URI" in os.environ:
database_config.update({"POSTGRES_URI": os.environ["POSTGRES_URI"]})
else:
msg = em.format(missing_value="POSTGRES_URI", field="database.provider", value="postgres")
raise ValueError(msg)
# load embedding config
embedding_config = {
"EMBEDDING_PROVIDER": "litellm",
"VECTOR_DIMENSIONS": config["embedding"]["dimensions"],
"EMBEDDING_SIMILARITY_METRIC": config["embedding"]["similarity_metric"],
}
# Set the model key for LiteLLM
if "model" not in config["embedding"]:
raise ValueError("'model' is required in the embedding configuration")
embedding_config["EMBEDDING_MODEL"] = config["embedding"]["model"]
# load parser config
parser_config = {
"CHUNK_SIZE": config["parser"]["chunk_size"],
"CHUNK_OVERLAP": config["parser"]["chunk_overlap"],
"USE_UNSTRUCTURED_API": config["parser"]["use_unstructured_api"],
"USE_CONTEXTUAL_CHUNKING": config["parser"].get("use_contextual_chunking", False),
}
if parser_config["USE_UNSTRUCTURED_API"] and "UNSTRUCTURED_API_KEY" not in os.environ:
msg = em.format(missing_value="UNSTRUCTURED_API_KEY", field="parser.use_unstructured_api", value="true")
raise ValueError(msg)
elif parser_config["USE_UNSTRUCTURED_API"]:
parser_config.update({"UNSTRUCTURED_API_KEY": os.environ["UNSTRUCTURED_API_KEY"]})
# load reranker config
reranker_config = {"USE_RERANKING": config["reranker"]["use_reranker"]}
if reranker_config["USE_RERANKING"]:
reranker_config.update(
{
"RERANKER_PROVIDER": config["reranker"]["provider"],
"RERANKER_MODEL": config["reranker"]["model_name"],
"RERANKER_QUERY_MAX_LENGTH": config["reranker"]["query_max_length"],
"RERANKER_PASSAGE_MAX_LENGTH": config["reranker"]["passage_max_length"],
"RERANKER_USE_FP16": config["reranker"]["use_fp16"],
"RERANKER_DEVICE": config["reranker"]["device"],
}
)
# load storage config
storage_config = {
"STORAGE_PROVIDER": config["storage"]["provider"],
"STORAGE_PATH": config["storage"]["storage_path"],
}
match storage_config["STORAGE_PROVIDER"]:
case "local":
storage_config.update({"STORAGE_PATH": config["storage"]["storage_path"]})
case "aws-s3" if all(key in os.environ for key in ["AWS_ACCESS_KEY", "AWS_SECRET_ACCESS_KEY"]):
storage_config.update(
{
"AWS_REGION": config["storage"]["region"],
"S3_BUCKET": config["storage"]["bucket_name"],
"AWS_ACCESS_KEY": os.environ["AWS_ACCESS_KEY"],
"AWS_SECRET_ACCESS_KEY": os.environ["AWS_SECRET_ACCESS_KEY"],
}
)
case "aws-s3":
msg = em.format(missing_value="AWS credentials", field="storage.provider", value="aws-s3")
raise ValueError(msg)
case _:
prov = storage_config["STORAGE_PROVIDER"]
raise ValueError(f"Unknown storage provider selected: '{prov}'")
# load vector store config
vector_store_config = {"VECTOR_STORE_PROVIDER": config["vector_store"]["provider"]}
if vector_store_config["VECTOR_STORE_PROVIDER"] != "pgvector":
prov = vector_store_config["VECTOR_STORE_PROVIDER"]
raise ValueError(f"Unknown vector store provider selected: '{prov}'")
if "POSTGRES_URI" not in os.environ:
msg = em.format(missing_value="POSTGRES_URI", field="vector_store.provider", value="pgvector")
raise ValueError(msg)
# load rules config
rules_config = {
"RULES_PROVIDER": "litellm",
"RULES_BATCH_SIZE": config["rules"]["batch_size"],
}
# Set the model key for LiteLLM
if "model" not in config["rules"]:
raise ValueError("'model' is required in the rules configuration")
rules_config["RULES_MODEL"] = config["rules"]["model"]
# load morphik config
morphik_config = {
"ENABLE_COLPALI": config["morphik"]["enable_colpali"],
"MODE": config["morphik"].get("mode", "cloud"), # Default to "cloud" mode
"API_DOMAIN": config["morphik"].get("api_domain", "api.morphik.ai"), # Default API domain
}
# load redis config
redis_config = {}
if "redis" in config:
redis_config = {
"REDIS_HOST": config["redis"].get("host", "localhost"),
"REDIS_PORT": int(config["redis"].get("port", 6379)),
}
# load graph config
graph_config = {
"GRAPH_PROVIDER": "litellm",
"ENABLE_ENTITY_RESOLUTION": config["graph"].get("enable_entity_resolution", True),
}
# Set the model key for LiteLLM
if "model" not in config["graph"]:
raise ValueError("'model' is required in the graph configuration")
graph_config["GRAPH_MODEL"] = config["graph"]["model"]
# load telemetry config
telemetry_config = {}
if "telemetry" in config:
telemetry_config = {
"TELEMETRY_ENABLED": config["telemetry"].get("enabled", True),
"HONEYCOMB_ENABLED": config["telemetry"].get("honeycomb_enabled", True),
"HONEYCOMB_ENDPOINT": config["telemetry"].get("honeycomb_endpoint", "https://api.honeycomb.io"),
"SERVICE_NAME": config["telemetry"].get("service_name", "morphik-core"),
"OTLP_TIMEOUT": config["telemetry"].get("otlp_timeout", 10),
"OTLP_MAX_RETRIES": config["telemetry"].get("otlp_max_retries", 3),
"OTLP_RETRY_DELAY": config["telemetry"].get("otlp_retry_delay", 1),
"OTLP_MAX_EXPORT_BATCH_SIZE": config["telemetry"].get("otlp_max_export_batch_size", 512),
"OTLP_SCHEDULE_DELAY_MILLIS": config["telemetry"].get("otlp_schedule_delay_millis", 5000),
"OTLP_MAX_QUEUE_SIZE": config["telemetry"].get("otlp_max_queue_size", 2048),
}
settings_dict = dict(
ChainMap(
api_config,
auth_config,
registered_models,
completion_config,
database_config,
embedding_config,
parser_config,
reranker_config,
storage_config,
vector_store_config,
rules_config,
morphik_config,
redis_config,
graph_config,
telemetry_config,
openai_config,
)
)
return Settings(**settings_dict)
```
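For reference, `get_settings()` reads `morphik.toml` from the working directory plus a handful of environment variables. The snippet below is a hedged, minimal illustration of the sections and keys that function requires; every value (model names, ports, paths) is a placeholder rather than a project recommendation, and the TOML is embedded in a Python string only so the sketch can be run end to end.

```py
# Minimal sketch of a morphik.toml covering the sections get_settings() reads.
# All values are illustrative placeholders, not recommended defaults.
import os
import pathlib

from core.config import get_settings

MINIMAL_TOML = """
[api]
host = "0.0.0.0"
port = 8000
reload = false

[auth]
jwt_algorithm = "HS256"
dev_mode = true

[completion]
model = "my_completion_model"

[database]
provider = "postgres"

[embedding]
model = "my_embedding_model"
dimensions = 1536
similarity_metric = "cosine"

[parser]
chunk_size = 1000
chunk_overlap = 200
use_unstructured_api = false

[reranker]
use_reranker = false

[storage]
provider = "local"
storage_path = "./storage"

[vector_store]
provider = "pgvector"

[rules]
model = "my_rules_model"
batch_size = 4096

[morphik]
enable_colpali = false

[graph]
model = "my_graph_model"
"""

pathlib.Path("morphik.toml").write_text(MINIMAL_TOML)
# POSTGRES_URI must be present in the environment to satisfy both the database
# and pgvector checks inside get_settings().
os.environ.setdefault("POSTGRES_URI", "postgresql+asyncpg://postgres:postgres@localhost:5432/morphik")

settings = get_settings()
print(settings.COMPLETION_MODEL, settings.STORAGE_PROVIDER, settings.VECTOR_DIMENSIONS)
```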
## /core/database/base_database.py
```py path="/core/database/base_database.py"
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from ..models.auth import AuthContext
from ..models.documents import Document
from ..models.folders import Folder
from ..models.graph import Graph
class BaseDatabase(ABC):
"""Base interface for document metadata storage."""
@abstractmethod
async def store_document(self, document: Document) -> bool:
"""
Store document metadata.
Returns: Success status
"""
pass
@abstractmethod
async def get_document(self, document_id: str, auth: AuthContext) -> Optional[Document]:
"""
Retrieve document metadata by ID if user has access.
Returns: Document if found and accessible, None otherwise
"""
pass
@abstractmethod
async def get_document_by_filename(
self, filename: str, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None
) -> Optional[Document]:
"""
Retrieve document metadata by filename if user has access.
If multiple documents have the same filename, returns the most recently updated one.
Args:
filename: The filename to search for
auth: Authentication context
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
Returns:
Document if found and accessible, None otherwise
"""
pass
@abstractmethod
async def get_documents_by_id(
self,
document_ids: List[str],
auth: AuthContext,
system_filters: Optional[Dict[str, Any]] = None,
) -> List[Document]:
"""
Retrieve multiple documents by their IDs in a single batch operation.
Only returns documents the user has access to.
Can filter by system metadata fields like folder_name and end_user_id.
Args:
document_ids: List of document IDs to retrieve
auth: Authentication context
system_filters: Optional filters for system metadata fields
Returns:
List of Document objects that were found and user has access to
"""
pass
@abstractmethod
async def get_documents(
self,
auth: AuthContext,
skip: int = 0,
limit: int = 100,
filters: Optional[Dict[str, Any]] = None,
system_filters: Optional[Dict[str, Any]] = None,
) -> List[Document]:
"""
List documents the user has access to.
Supports pagination and filtering.
Args:
auth: Authentication context
skip: Number of documents to skip (for pagination)
limit: Maximum number of documents to return
filters: Optional metadata filters
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
Returns:
List of documents matching the criteria
"""
pass
@abstractmethod
async def update_document(self, document_id: str, updates: Dict[str, Any], auth: AuthContext) -> bool:
"""
Update document metadata if user has access.
Returns: Success status
"""
pass
@abstractmethod
async def delete_document(self, document_id: str, auth: AuthContext) -> bool:
"""
Delete document metadata if user has admin access.
Returns: Success status
"""
pass
@abstractmethod
async def find_authorized_and_filtered_documents(
self,
auth: AuthContext,
filters: Optional[Dict[str, Any]] = None,
system_filters: Optional[Dict[str, Any]] = None,
) -> List[str]:
"""Find document IDs matching filters that user has access to.
Args:
auth: Authentication context
filters: Optional metadata filters
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
Returns:
List of document IDs matching the criteria
"""
pass
@abstractmethod
async def check_access(self, document_id: str, auth: AuthContext, required_permission: str = "read") -> bool:
"""
Check if user has required permission for document.
Returns: True if user has required access, False otherwise
"""
pass
@abstractmethod
async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool:
"""Store metadata for a cache.
Args:
name: Name of the cache
metadata: Cache metadata including model info and storage location
Returns:
bool: Whether the operation was successful
"""
pass
@abstractmethod
async def get_cache_metadata(self, name: str) -> Optional[Dict[str, Any]]:
"""Get metadata for a cache.
Args:
name: Name of the cache
Returns:
Optional[Dict[str, Any]]: Cache metadata if found, None otherwise
"""
pass
@abstractmethod
async def store_graph(self, graph: Graph) -> bool:
"""Store a graph.
Args:
graph: Graph to store
Returns:
bool: Whether the operation was successful
"""
pass
@abstractmethod
async def get_graph(
self, name: str, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None
) -> Optional[Graph]:
"""Get a graph by name.
Args:
name: Name of the graph
auth: Authentication context
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
Returns:
Optional[Graph]: Graph if found and accessible, None otherwise
"""
pass
@abstractmethod
async def list_graphs(self, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> List[Graph]:
"""List all graphs the user has access to.
Args:
auth: Authentication context
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
Returns:
List[Graph]: List of graphs
"""
pass
@abstractmethod
async def update_graph(self, graph: Graph) -> bool:
"""Update an existing graph.
Args:
graph: Graph to update
Returns:
bool: Whether the operation was successful
"""
pass
@abstractmethod
async def create_folder(self, folder: Folder) -> bool:
"""Create a new folder.
Args:
folder: Folder to create
Returns:
bool: Whether the operation was successful
"""
pass
@abstractmethod
async def get_folder(self, folder_id: str, auth: AuthContext) -> Optional[Folder]:
"""Get a folder by ID.
Args:
folder_id: ID of the folder
auth: Authentication context
Returns:
Optional[Folder]: Folder if found and accessible, None otherwise
"""
pass
@abstractmethod
async def get_folder_by_name(self, name: str, auth: AuthContext) -> Optional[Folder]:
"""Get a folder by name.
Args:
name: Name of the folder
auth: Authentication context
Returns:
Optional[Folder]: Folder if found and accessible, None otherwise
"""
pass
@abstractmethod
async def list_folders(self, auth: AuthContext) -> List[Folder]:
"""List all folders the user has access to.
Args:
auth: Authentication context
Returns:
List[Folder]: List of folders
"""
pass
@abstractmethod
async def add_document_to_folder(self, folder_id: str, document_id: str, auth: AuthContext) -> bool:
"""Add a document to a folder.
Args:
folder_id: ID of the folder
document_id: ID of the document
auth: Authentication context
Returns:
bool: Whether the operation was successful
"""
pass
@abstractmethod
async def remove_document_from_folder(self, folder_id: str, document_id: str, auth: AuthContext) -> bool:
"""Remove a document from a folder.
Args:
folder_id: ID of the folder
document_id: ID of the document
auth: Authentication context
Returns:
bool: Whether the operation was successful
"""
pass
```
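Because `BaseDatabase` only defines the contract, the short sketch below (illustrative, not part of the repository) shows the typical call pattern against any implementation, such as the `PostgresDatabase` in the next file; the `Document` and `AuthContext` instances are assumed to be constructed elsewhere.

```py
# Illustrative only: exercises the abstract interface; `doc` and `auth` are
# assumed to be valid Document / AuthContext instances built elsewhere.
from typing import Optional


async def ingest_and_fetch(db: BaseDatabase, doc: Document, auth: AuthContext) -> Optional[Document]:
    if not await db.store_document(doc):
        raise RuntimeError("failed to store document metadata")
    # Returns None if the document is missing or the caller lacks read access.
    return await db.get_document(doc.external_id, auth)
```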
## /core/database/postgres_database.py
```py path="/core/database/postgres_database.py"
import json
import logging
from datetime import UTC, datetime
from typing import Any, Dict, List, Optional
from sqlalchemy import Column, Index, String, select, text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import declarative_base, sessionmaker
from ..models.auth import AuthContext
from ..models.documents import Document, StorageFileInfo
from ..models.folders import Folder
from ..models.graph import Graph
from .base_database import BaseDatabase
logger = logging.getLogger(__name__)
Base = declarative_base()
class DocumentModel(Base):
"""SQLAlchemy model for document metadata."""
__tablename__ = "documents"
external_id = Column(String, primary_key=True)
owner = Column(JSONB)
content_type = Column(String)
filename = Column(String, nullable=True)
doc_metadata = Column(JSONB, default=dict)
storage_info = Column(JSONB, default=dict)
system_metadata = Column(JSONB, default=dict)
additional_metadata = Column(JSONB, default=dict)
access_control = Column(JSONB, default=dict)
chunk_ids = Column(JSONB, default=list)
storage_files = Column(JSONB, default=list)
# Create indexes
__table_args__ = (
Index("idx_owner_id", "owner", postgresql_using="gin"),
Index("idx_access_control", "access_control", postgresql_using="gin"),
Index("idx_system_metadata", "system_metadata", postgresql_using="gin"),
)
class GraphModel(Base):
"""SQLAlchemy model for graph data."""
__tablename__ = "graphs"
id = Column(String, primary_key=True)
name = Column(String, index=True) # Not unique globally anymore
entities = Column(JSONB, default=list)
relationships = Column(JSONB, default=list)
graph_metadata = Column(JSONB, default=dict) # Renamed from 'metadata' to avoid conflict
system_metadata = Column(JSONB, default=dict) # For folder_name and end_user_id
document_ids = Column(JSONB, default=list)
filters = Column(JSONB, nullable=True)
created_at = Column(String) # ISO format string
updated_at = Column(String) # ISO format string
owner = Column(JSONB)
access_control = Column(JSONB, default=dict)
# Create indexes
__table_args__ = (
Index("idx_graph_name", "name"),
Index("idx_graph_owner", "owner", postgresql_using="gin"),
Index("idx_graph_access_control", "access_control", postgresql_using="gin"),
Index("idx_graph_system_metadata", "system_metadata", postgresql_using="gin"),
# Create a unique constraint on name scoped by owner ID
Index("idx_graph_owner_name", "name", text("(owner->>'id')"), unique=True),
)
class FolderModel(Base):
"""SQLAlchemy model for folder data."""
__tablename__ = "folders"
id = Column(String, primary_key=True)
name = Column(String, index=True)
description = Column(String, nullable=True)
owner = Column(JSONB)
document_ids = Column(JSONB, default=list)
system_metadata = Column(JSONB, default=dict)
access_control = Column(JSONB, default=dict)
rules = Column(JSONB, default=list)
# Create indexes
__table_args__ = (
Index("idx_folder_name", "name"),
Index("idx_folder_owner", "owner", postgresql_using="gin"),
Index("idx_folder_access_control", "access_control", postgresql_using="gin"),
)
def _serialize_datetime(obj: Any) -> Any:
"""Helper function to serialize datetime objects to ISO format strings."""
if isinstance(obj, datetime):
return obj.isoformat()
elif isinstance(obj, dict):
return {key: _serialize_datetime(value) for key, value in obj.items()}
elif isinstance(obj, list):
return [_serialize_datetime(item) for item in obj]
return obj
class PostgresDatabase(BaseDatabase):
"""PostgreSQL implementation for document metadata storage."""
def __init__(
self,
uri: str,
):
"""Initialize PostgreSQL connection for document storage."""
# Load settings from config
from core.config import get_settings
settings = get_settings()
# Get database pool settings from config with defaults
pool_size = getattr(settings, "DB_POOL_SIZE", 20)
max_overflow = getattr(settings, "DB_MAX_OVERFLOW", 30)
pool_recycle = getattr(settings, "DB_POOL_RECYCLE", 3600)
pool_timeout = getattr(settings, "DB_POOL_TIMEOUT", 10)
pool_pre_ping = getattr(settings, "DB_POOL_PRE_PING", True)
logger.info(
f"Initializing PostgreSQL connection pool with size={pool_size}, "
f"max_overflow={max_overflow}, pool_recycle={pool_recycle}s"
)
# Create async engine with explicit pool settings
self.engine = create_async_engine(
uri,
# Prevent connection timeouts by keeping connections alive
pool_pre_ping=pool_pre_ping,
# Increase pool size to handle concurrent operations
pool_size=pool_size,
# Maximum overflow connections allowed beyond pool_size
max_overflow=max_overflow,
# Recycle connections after pool_recycle seconds (3600s = 60 minutes by default)
pool_recycle=pool_recycle,
# Seconds to wait for a connection from the pool (10 by default)
pool_timeout=pool_timeout,
# Echo SQL for debugging (set to False in production)
echo=False,
)
self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)
self._initialized = False
async def initialize(self):
"""Initialize database tables and indexes."""
if self._initialized:
return True
try:
logger.info("Initializing PostgreSQL database tables and indexes...")
# Create ORM models
async with self.engine.begin() as conn:
# Explicitly create all tables with checkfirst=True to avoid errors if tables already exist
await conn.run_sync(lambda conn: Base.metadata.create_all(conn, checkfirst=True))
# No need to manually create graphs table again since SQLAlchemy does it
logger.info("Created database tables successfully")
# Create caches table if it doesn't exist (kept as direct SQL for backward compatibility)
await conn.execute(
text(
"""
CREATE TABLE IF NOT EXISTS caches (
name TEXT PRIMARY KEY,
metadata JSONB NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
)
"""
)
)
# Check if storage_files column exists
result = await conn.execute(
text(
"""
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'documents' AND column_name = 'storage_files'
"""
)
)
if not result.first():
# Add storage_files column to documents table
await conn.execute(
text(
"""
ALTER TABLE documents
ADD COLUMN IF NOT EXISTS storage_files JSONB DEFAULT '[]'::jsonb
"""
)
)
logger.info("Added storage_files column to documents table")
# Create indexes for folder_name and end_user_id in system_metadata for documents
await conn.execute(
text(
"""
CREATE INDEX IF NOT EXISTS idx_system_metadata_folder_name
ON documents ((system_metadata->>'folder_name'));
"""
)
)
# Create folders table if it doesn't exist
await conn.execute(
text(
"""
CREATE TABLE IF NOT EXISTS folders (
id TEXT PRIMARY KEY,
name TEXT,
description TEXT,
owner JSONB,
document_ids JSONB DEFAULT '[]',
system_metadata JSONB DEFAULT '{}',
access_control JSONB DEFAULT '{}'
);
"""
)
)
# Add rules column to folders table if it doesn't exist
result = await conn.execute(
text(
"""
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'folders' AND column_name = 'rules'
"""
)
)
if not result.first():
# Add rules column to folders table
await conn.execute(
text(
"""
ALTER TABLE folders
ADD COLUMN IF NOT EXISTS rules JSONB DEFAULT '[]'::jsonb
"""
)
)
logger.info("Added rules column to folders table")
# Create indexes for folders table
await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_folder_name ON folders (name);"))
await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_folder_owner ON folders USING gin (owner);"))
await conn.execute(
text("CREATE INDEX IF NOT EXISTS idx_folder_access_control ON folders USING gin (access_control);")
)
await conn.execute(
text(
"""
CREATE INDEX IF NOT EXISTS idx_system_metadata_end_user_id
ON documents ((system_metadata->>'end_user_id'));
"""
)
)
# Check if system_metadata column exists in graphs table
result = await conn.execute(
text(
"""
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'graphs' AND column_name = 'system_metadata'
"""
)
)
if not result.first():
# Add system_metadata column to graphs table
await conn.execute(
text(
"""
ALTER TABLE graphs
ADD COLUMN IF NOT EXISTS system_metadata JSONB DEFAULT '{}'::jsonb
"""
)
)
logger.info("Added system_metadata column to graphs table")
# Create indexes for folder_name and end_user_id in system_metadata for graphs
await conn.execute(
text(
"""
CREATE INDEX IF NOT EXISTS idx_graph_system_metadata_folder_name
ON graphs ((system_metadata->>'folder_name'));
"""
)
)
await conn.execute(
text(
"""
CREATE INDEX IF NOT EXISTS idx_graph_system_metadata_end_user_id
ON graphs ((system_metadata->>'end_user_id'));
"""
)
)
logger.info("Created indexes for folder_name and end_user_id in system_metadata")
logger.info("PostgreSQL tables and indexes created successfully")
self._initialized = True
return True
except Exception as e:
logger.error(f"Error creating PostgreSQL tables and indexes: {str(e)}")
return False
async def store_document(self, document: Document) -> bool:
"""Store document metadata."""
try:
doc_dict = document.model_dump()
# Rename metadata to doc_metadata
if "metadata" in doc_dict:
doc_dict["doc_metadata"] = doc_dict.pop("metadata")
doc_dict["doc_metadata"]["external_id"] = doc_dict["external_id"]
# Ensure system metadata
if "system_metadata" not in doc_dict:
doc_dict["system_metadata"] = {}
doc_dict["system_metadata"]["created_at"] = datetime.now(UTC)
doc_dict["system_metadata"]["updated_at"] = datetime.now(UTC)
# Handle storage_files
if "storage_files" in doc_dict and doc_dict["storage_files"]:
# Convert storage_files to the expected format for storage
doc_dict["storage_files"] = [file.model_dump() for file in doc_dict["storage_files"]]
# Serialize datetime objects to ISO format strings
doc_dict = _serialize_datetime(doc_dict)
async with self.async_session() as session:
doc_model = DocumentModel(**doc_dict)
session.add(doc_model)
await session.commit()
return True
except Exception as e:
logger.error(f"Error storing document metadata: {str(e)}")
return False
async def get_document(self, document_id: str, auth: AuthContext) -> Optional[Document]:
"""Retrieve document metadata by ID if user has access."""
try:
async with self.async_session() as session:
# Build access filter
access_filter = self._build_access_filter(auth)
# Query document
query = (
select(DocumentModel)
.where(DocumentModel.external_id == document_id)
.where(text(f"({access_filter})"))
)
result = await session.execute(query)
doc_model = result.scalar_one_or_none()
if doc_model:
# Convert doc_metadata back to metadata
# Also convert storage_files from dict to StorageFileInfo
storage_files = []
if doc_model.storage_files:
for file_info in doc_model.storage_files:
if isinstance(file_info, dict):
storage_files.append(StorageFileInfo(**file_info))
else:
storage_files.append(file_info)
doc_dict = {
"external_id": doc_model.external_id,
"owner": doc_model.owner,
"content_type": doc_model.content_type,
"filename": doc_model.filename,
"metadata": doc_model.doc_metadata,
"storage_info": doc_model.storage_info,
"system_metadata": doc_model.system_metadata,
"additional_metadata": doc_model.additional_metadata,
"access_control": doc_model.access_control,
"chunk_ids": doc_model.chunk_ids,
"storage_files": storage_files,
}
return Document(**doc_dict)
return None
except Exception as e:
logger.error(f"Error retrieving document metadata: {str(e)}")
return None
async def get_document_by_filename(
self, filename: str, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None
) -> Optional[Document]:
"""Retrieve document metadata by filename if user has access.
If multiple documents have the same filename, returns the most recently updated one.
Args:
filename: The filename to search for
auth: Authentication context
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
"""
try:
async with self.async_session() as session:
# Build access filter
access_filter = self._build_access_filter(auth)
system_metadata_filter = self._build_system_metadata_filter(system_filters)
# Escape single quotes so the filename can be embedded safely in the SQL literal
filename = filename.replace("'", "''")
# Construct where clauses
where_clauses = [
f"({access_filter})",
f"filename = '{filename}'",
]
if system_metadata_filter:
where_clauses.append(f"({system_metadata_filter})")
final_where_clause = " AND ".join(where_clauses)
# Query document with system filters
query = (
select(DocumentModel).where(text(final_where_clause))
# Order by updated_at in system_metadata to get the most recent document
.order_by(text("system_metadata->>'updated_at' DESC"))
)
logger.debug(f"Querying document by filename with system filters: {system_filters}")
result = await session.execute(query)
doc_model = result.scalar_one_or_none()
if doc_model:
# Convert doc_metadata back to metadata
# Also convert storage_files from dict to StorageFileInfo
storage_files = []
if doc_model.storage_files:
for file_info in doc_model.storage_files:
if isinstance(file_info, dict):
storage_files.append(StorageFileInfo(**file_info))
else:
storage_files.append(file_info)
doc_dict = {
"external_id": doc_model.external_id,
"owner": doc_model.owner,
"content_type": doc_model.content_type,
"filename": doc_model.filename,
"metadata": doc_model.doc_metadata,
"storage_info": doc_model.storage_info,
"system_metadata": doc_model.system_metadata,
"additional_metadata": doc_model.additional_metadata,
"access_control": doc_model.access_control,
"chunk_ids": doc_model.chunk_ids,
"storage_files": storage_files,
}
return Document(**doc_dict)
return None
except Exception as e:
logger.error(f"Error retrieving document metadata by filename: {str(e)}")
return None
async def get_documents_by_id(
self,
document_ids: List[str],
auth: AuthContext,
system_filters: Optional[Dict[str, Any]] = None,
) -> List[Document]:
"""
Retrieve multiple documents by their IDs in a single batch operation.
Only returns documents the user has access to.
Can filter by system metadata fields like folder_name and end_user_id.
Args:
document_ids: List of document IDs to retrieve
auth: Authentication context
system_filters: Optional filters for system metadata fields
Returns:
List of Document objects that were found and user has access to
"""
try:
if not document_ids:
return []
async with self.async_session() as session:
# Build access filter
access_filter = self._build_access_filter(auth)
system_metadata_filter = self._build_system_metadata_filter(system_filters)
# Construct where clauses
document_ids_linked = ", ".join([("'" + doc_id + "'") for doc_id in document_ids])
where_clauses = [f"({access_filter})", f"external_id IN ({document_ids_linked})"]
if system_metadata_filter:
where_clauses.append(f"({system_metadata_filter})")
final_where_clause = " AND ".join(where_clauses)
# Query documents with document IDs, access check, and system filters in a single query
query = select(DocumentModel).where(text(final_where_clause))
logger.info(f"Batch retrieving {len(document_ids)} documents with a single query")
# Execute batch query
result = await session.execute(query)
doc_models = result.scalars().all()
documents = []
for doc_model in doc_models:
# Convert doc_metadata back to metadata
doc_dict = {
"external_id": doc_model.external_id,
"owner": doc_model.owner,
"content_type": doc_model.content_type,
"filename": doc_model.filename,
"metadata": doc_model.doc_metadata,
"storage_info": doc_model.storage_info,
"system_metadata": doc_model.system_metadata,
"additional_metadata": doc_model.additional_metadata,
"access_control": doc_model.access_control,
"chunk_ids": doc_model.chunk_ids,
"storage_files": doc_model.storage_files or [],
}
documents.append(Document(**doc_dict))
logger.info(f"Found {len(documents)} documents in batch retrieval")
return documents
except Exception as e:
logger.error(f"Error batch retrieving documents: {str(e)}")
return []
async def get_documents(
self,
auth: AuthContext,
skip: int = 0,
limit: int = 10000,
filters: Optional[Dict[str, Any]] = None,
system_filters: Optional[Dict[str, Any]] = None,
) -> List[Document]:
"""List documents the user has access to."""
try:
async with self.async_session() as session:
# Build query
access_filter = self._build_access_filter(auth)
metadata_filter = self._build_metadata_filter(filters)
system_metadata_filter = self._build_system_metadata_filter(system_filters)
where_clauses = [f"({access_filter})"]
if metadata_filter:
where_clauses.append(f"({metadata_filter})")
if system_metadata_filter:
where_clauses.append(f"({system_metadata_filter})")
final_where_clause = " AND ".join(where_clauses)
query = select(DocumentModel).where(text(final_where_clause))
query = query.offset(skip).limit(limit)
result = await session.execute(query)
doc_models = result.scalars().all()
return [
Document(
external_id=doc.external_id,
owner=doc.owner,
content_type=doc.content_type,
filename=doc.filename,
metadata=doc.doc_metadata,
storage_info=doc.storage_info,
system_metadata=doc.system_metadata,
additional_metadata=doc.additional_metadata,
access_control=doc.access_control,
chunk_ids=doc.chunk_ids,
storage_files=doc.storage_files or [],
)
for doc in doc_models
]
except Exception as e:
logger.error(f"Error listing documents: {str(e)}")
return []
async def update_document(self, document_id: str, updates: Dict[str, Any], auth: AuthContext) -> bool:
"""Update document metadata if user has write access."""
try:
if not await self.check_access(document_id, auth, "write"):
return False
# Get existing document to preserve system_metadata
existing_doc = await self.get_document(document_id, auth)
if not existing_doc:
return False
# Update system metadata
updates.setdefault("system_metadata", {})
# Merge with existing system_metadata instead of just preserving specific fields
if existing_doc.system_metadata:
# Start with existing system_metadata
merged_system_metadata = dict(existing_doc.system_metadata)
# Update with new values
merged_system_metadata.update(updates["system_metadata"])
# Replace with merged result
updates["system_metadata"] = merged_system_metadata
logger.debug("Merged system_metadata during document update, preserving existing fields")
# Always update the updated_at timestamp
updates["system_metadata"]["updated_at"] = datetime.now(UTC)
# Serialize datetime objects to ISO format strings
updates = _serialize_datetime(updates)
async with self.async_session() as session:
result = await session.execute(select(DocumentModel).where(DocumentModel.external_id == document_id))
doc_model = result.scalar_one_or_none()
if doc_model:
# Log what we're updating
logger.info(f"Document update: updating fields {list(updates.keys())}")
# Special handling for metadata/doc_metadata conversion
if "metadata" in updates and "doc_metadata" not in updates:
logger.info("Converting 'metadata' to 'doc_metadata' for database update")
updates["doc_metadata"] = updates.pop("metadata")
# Set all attributes
for key, value in updates.items():
if key == "storage_files" and isinstance(value, list):
serialized_value = [
_serialize_datetime(
item.model_dump()
if hasattr(item, "model_dump")
else (item.dict() if hasattr(item, "dict") else item)
)
for item in value
]
logger.debug("Serializing storage_files before setting attribute")
setattr(doc_model, key, serialized_value)
else:
logger.debug(f"Setting document attribute {key} = {value}")
setattr(doc_model, key, value)
await session.commit()
logger.info(f"Document {document_id} updated successfully")
return True
return False
except Exception as e:
logger.error(f"Error updating document metadata: {str(e)}")
return False
async def delete_document(self, document_id: str, auth: AuthContext) -> bool:
"""Delete document if user has write access."""
try:
if not await self.check_access(document_id, auth, "write"):
return False
async with self.async_session() as session:
result = await session.execute(select(DocumentModel).where(DocumentModel.external_id == document_id))
doc_model = result.scalar_one_or_none()
if doc_model:
await session.delete(doc_model)
await session.commit()
return True
return False
except Exception as e:
logger.error(f"Error deleting document: {str(e)}")
return False
async def find_authorized_and_filtered_documents(
self,
auth: AuthContext,
filters: Optional[Dict[str, Any]] = None,
system_filters: Optional[Dict[str, Any]] = None,
) -> List[str]:
"""Find document IDs matching filters and access permissions."""
try:
async with self.async_session() as session:
# Build query
access_filter = self._build_access_filter(auth)
metadata_filter = self._build_metadata_filter(filters)
system_metadata_filter = self._build_system_metadata_filter(system_filters)
logger.debug(f"Access filter: {access_filter}")
logger.debug(f"Metadata filter: {metadata_filter}")
logger.debug(f"System metadata filter: {system_metadata_filter}")
logger.debug(f"Original filters: {filters}")
logger.debug(f"System filters: {system_filters}")
where_clauses = [f"({access_filter})"]
if metadata_filter:
where_clauses.append(f"({metadata_filter})")
if system_metadata_filter:
where_clauses.append(f"({system_metadata_filter})")
final_where_clause = " AND ".join(where_clauses)
query = select(DocumentModel.external_id).where(text(final_where_clause))
logger.debug(f"Final query: {query}")
result = await session.execute(query)
doc_ids = [row[0] for row in result.all()]
logger.debug(f"Found document IDs: {doc_ids}")
return doc_ids
except Exception as e:
logger.error(f"Error finding authorized documents: {str(e)}")
return []
async def check_access(self, document_id: str, auth: AuthContext, required_permission: str = "read") -> bool:
"""Check if user has required permission for document."""
try:
async with self.async_session() as session:
result = await session.execute(select(DocumentModel).where(DocumentModel.external_id == document_id))
doc_model = result.scalar_one_or_none()
if not doc_model:
return False
# Check owner access
owner = doc_model.owner
if owner.get("type") == auth.entity_type and owner.get("id") == auth.entity_id:
return True
# Check permission-specific access
access_control = doc_model.access_control
permission_map = {"read": "readers", "write": "writers", "admin": "admins"}
permission_set = permission_map.get(required_permission)
if not permission_set:
return False
return auth.entity_id in access_control.get(permission_set, [])
except Exception as e:
logger.error(f"Error checking document access: {str(e)}")
return False
def _build_access_filter(self, auth: AuthContext) -> str:
"""Build PostgreSQL filter for access control."""
filters = [
f"owner->>'id' = '{auth.entity_id}'",
f"access_control->'readers' ? '{auth.entity_id}'",
f"access_control->'writers' ? '{auth.entity_id}'",
f"access_control->'admins' ? '{auth.entity_id}'",
]
if auth.entity_type == "DEVELOPER" and auth.app_id:
# Add app-specific access for developers
filters.append(f"access_control->'app_access' ? '{auth.app_id}'")
# Add user_id filter in cloud mode
if auth.user_id:
from core.config import get_settings
settings = get_settings()
if settings.MODE == "cloud":
# Filter by user_id in access_control
filters.append(f"access_control->>'user_id' = '{auth.user_id}'")
return " OR ".join(filters)
def _build_metadata_filter(self, filters: Dict[str, Any]) -> str:
"""Build PostgreSQL filter for metadata."""
if not filters:
return ""
filter_conditions = []
for key, value in filters.items():
# Handle list of values (IN operator)
if isinstance(value, list):
if not value: # Skip empty lists
continue
# Build a list of properly escaped values
escaped_values = []
for item in value:
if isinstance(item, bool):
escaped_values.append(str(item).lower())
elif isinstance(item, str):
# Use standard replace, avoid complex f-string quoting for black
escaped_value = item.replace("'", "''")
escaped_values.append(f"'{escaped_value}'")
else:
escaped_values.append(f"'{item}'")
# Join with commas for IN clause
values_str = ", ".join(escaped_values)
filter_conditions.append(f"doc_metadata->>'{key}' IN ({values_str})")
else:
# Handle single value (equality)
# Convert boolean values to string 'true' or 'false'
if isinstance(value, bool):
value = str(value).lower()
# Use proper SQL escaping for string values
if isinstance(value, str):
# Replace single quotes with double single quotes to escape them
value = value.replace("'", "''")
filter_conditions.append(f"doc_metadata->>'{key}' = '{value}'")
return " AND ".join(filter_conditions)
def _build_system_metadata_filter(self, system_filters: Optional[Dict[str, Any]]) -> str:
"""Build PostgreSQL filter for system metadata."""
if not system_filters:
return ""
conditions = []
for key, value in system_filters.items():
if value is None:
continue
# Handle list of values (IN operator)
if isinstance(value, list):
if not value: # Skip empty lists
continue
# Build a list of properly escaped values
escaped_values = []
for item in value:
if isinstance(item, bool):
escaped_values.append(str(item).lower())
elif isinstance(item, str):
# Use standard replace, avoid complex f-string quoting for black
escaped_value = item.replace("'", "''")
escaped_values.append(f"'{escaped_value}'")
else:
escaped_values.append(f"'{item}'")
# Join with commas for IN clause
values_str = ", ".join(escaped_values)
conditions.append(f"system_metadata->>'{key}' IN ({values_str})")
else:
# Handle single value (equality)
if isinstance(value, str):
# Replace single quotes with double single quotes to escape them
escaped_value = value.replace("'", "''")
conditions.append(f"system_metadata->>'{key}' = '{escaped_value}'")
elif isinstance(value, bool):
conditions.append(f"system_metadata->>'{key}' = '{str(value).lower()}'")
else:
conditions.append(f"system_metadata->>'{key}' = '{value}'")
return " AND ".join(conditions)
async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool:
"""Store metadata for a cache in PostgreSQL.
Args:
name: Name of the cache
metadata: Cache metadata including model info and storage location
Returns:
bool: Whether the operation was successful
"""
try:
async with self.async_session() as session:
await session.execute(
text(
"""
INSERT INTO caches (name, metadata, updated_at)
VALUES (:name, :metadata, CURRENT_TIMESTAMP)
ON CONFLICT (name)
DO UPDATE SET
metadata = :metadata,
updated_at = CURRENT_TIMESTAMP
"""
),
{"name": name, "metadata": json.dumps(metadata)},
)
await session.commit()
return True
except Exception as e:
logger.error(f"Failed to store cache metadata: {e}")
return False
async def get_cache_metadata(self, name: str) -> Optional[Dict[str, Any]]:
"""Get metadata for a cache from PostgreSQL.
Args:
name: Name of the cache
Returns:
Optional[Dict[str, Any]]: Cache metadata if found, None otherwise
"""
try:
async with self.async_session() as session:
result = await session.execute(text("SELECT metadata FROM caches WHERE name = :name"), {"name": name})
row = result.first()
return row[0] if row else None
except Exception as e:
logger.error(f"Failed to get cache metadata: {e}")
return None
async def store_graph(self, graph: Graph) -> bool:
"""Store a graph in PostgreSQL.
This method stores the graph metadata, entities, and relationships
in a PostgreSQL table.
Args:
graph: Graph to store
Returns:
bool: Whether the operation was successful
"""
# Ensure database is initialized
if not self._initialized:
await self.initialize()
try:
# First serialize the graph model to dict
graph_dict = graph.model_dump()
# Change 'metadata' to 'graph_metadata' to match our model
if "metadata" in graph_dict:
graph_dict["graph_metadata"] = graph_dict.pop("metadata")
# Serialize datetime objects to ISO format strings
graph_dict = _serialize_datetime(graph_dict)
# Store the graph metadata in PostgreSQL
async with self.async_session() as session:
# Store graph metadata in our table
graph_model = GraphModel(**graph_dict)
session.add(graph_model)
await session.commit()
logger.info(
f"Stored graph '{graph.name}' with {len(graph.entities)} entities "
f"and {len(graph.relationships)} relationships"
)
return True
except Exception as e:
logger.error(f"Error storing graph: {str(e)}")
return False
async def get_graph(
self, name: str, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None
) -> Optional[Graph]:
"""Get a graph by name.
Args:
name: Name of the graph
auth: Authentication context
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
Returns:
Optional[Graph]: Graph if found and accessible, None otherwise
"""
# Ensure database is initialized
if not self._initialized:
await self.initialize()
try:
async with self.async_session() as session:
# Build access filter
access_filter = self._build_access_filter(auth)
# We need to check if the documents in the graph match the system filters
# First get the graph without system filters
query = select(GraphModel).where(GraphModel.name == name).where(text(f"({access_filter})"))
result = await session.execute(query)
graph_model = result.scalar_one_or_none()
if graph_model:
# If system filters are provided, we need to filter the document_ids
document_ids = graph_model.document_ids
if system_filters and document_ids:
# Apply system_filters to document_ids
system_metadata_filter = self._build_system_metadata_filter(system_filters)
if system_metadata_filter:
# Get document IDs with system filters
doc_id_placeholders = ", ".join([f"'{doc_id}'" for doc_id in document_ids])
filter_query = f"""
SELECT external_id FROM documents
WHERE external_id IN ({doc_id_placeholders})
AND ({system_metadata_filter})
"""
filter_result = await session.execute(text(filter_query))
filtered_doc_ids = [row[0] for row in filter_result.all()]
# If no documents match system filters, return None
if not filtered_doc_ids:
return None
# Update document_ids with filtered results
document_ids = filtered_doc_ids
# Convert to Graph model
graph_dict = {
"id": graph_model.id,
"name": graph_model.name,
"entities": graph_model.entities,
"relationships": graph_model.relationships,
"metadata": graph_model.graph_metadata, # Reference the renamed column
"system_metadata": graph_model.system_metadata or {}, # Include system_metadata
"document_ids": document_ids, # Use possibly filtered document_ids
"filters": graph_model.filters,
"created_at": graph_model.created_at,
"updated_at": graph_model.updated_at,
"owner": graph_model.owner,
"access_control": graph_model.access_control,
}
return Graph(**graph_dict)
return None
except Exception as e:
logger.error(f"Error retrieving graph: {str(e)}")
return None
async def list_graphs(self, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> List[Graph]:
"""List all graphs the user has access to.
Args:
auth: Authentication context
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
Returns:
List[Graph]: List of graphs
"""
# Ensure database is initialized
if not self._initialized:
await self.initialize()
try:
async with self.async_session() as session:
# Build access filter
access_filter = self._build_access_filter(auth)
# Query graphs
query = select(GraphModel).where(text(f"({access_filter})"))
result = await session.execute(query)
graph_models = result.scalars().all()
graphs = []
# If system filters are provided, we need to filter each graph's document_ids
if system_filters:
system_metadata_filter = self._build_system_metadata_filter(system_filters)
for graph_model in graph_models:
document_ids = graph_model.document_ids
if document_ids and system_metadata_filter:
# Get document IDs with system filters
doc_id_placeholders = ", ".join([f"'{doc_id}'" for doc_id in document_ids])
filter_query = f"""
SELECT external_id FROM documents
WHERE external_id IN ({doc_id_placeholders})
AND ({system_metadata_filter})
"""
filter_result = await session.execute(text(filter_query))
filtered_doc_ids = [row[0] for row in filter_result.all()]
# Only include graphs that have documents matching the system filters
if filtered_doc_ids:
graph = Graph(
id=graph_model.id,
name=graph_model.name,
entities=graph_model.entities,
relationships=graph_model.relationships,
metadata=graph_model.graph_metadata, # Reference the renamed column
system_metadata=graph_model.system_metadata or {}, # Include system_metadata
document_ids=filtered_doc_ids, # Use filtered document_ids
filters=graph_model.filters,
created_at=graph_model.created_at,
updated_at=graph_model.updated_at,
owner=graph_model.owner,
access_control=graph_model.access_control,
)
graphs.append(graph)
else:
# No system filters, include all graphs
graphs = [
Graph(
id=graph.id,
name=graph.name,
entities=graph.entities,
relationships=graph.relationships,
metadata=graph.graph_metadata, # Reference the renamed column
system_metadata=graph.system_metadata or {}, # Include system_metadata
document_ids=graph.document_ids,
filters=graph.filters,
created_at=graph.created_at,
updated_at=graph.updated_at,
owner=graph.owner,
access_control=graph.access_control,
)
for graph in graph_models
]
return graphs
except Exception as e:
logger.error(f"Error listing graphs: {str(e)}")
return []
async def update_graph(self, graph: Graph) -> bool:
"""Update an existing graph in PostgreSQL.
This method updates the graph metadata, entities, and relationships
in the PostgreSQL table.
Args:
graph: Graph to update
Returns:
bool: Whether the operation was successful
"""
# Ensure database is initialized
if not self._initialized:
await self.initialize()
try:
# First serialize the graph model to dict
graph_dict = graph.model_dump()
# Change 'metadata' to 'graph_metadata' to match our model
if "metadata" in graph_dict:
graph_dict["graph_metadata"] = graph_dict.pop("metadata")
# Serialize datetime objects to ISO format strings
graph_dict = _serialize_datetime(graph_dict)
# Update the graph in PostgreSQL
async with self.async_session() as session:
# Check if the graph exists
result = await session.execute(select(GraphModel).where(GraphModel.id == graph.id))
graph_model = result.scalar_one_or_none()
if not graph_model:
logger.error(f"Graph '{graph.name}' with ID {graph.id} not found for update")
return False
# Update the graph model with new values
for key, value in graph_dict.items():
setattr(graph_model, key, value)
await session.commit()
logger.info(
f"Updated graph '{graph.name}' with {len(graph.entities)} entities "
f"and {len(graph.relationships)} relationships"
)
return True
except Exception as e:
logger.error(f"Error updating graph: {str(e)}")
return False
async def create_folder(self, folder: Folder) -> bool:
"""Create a new folder."""
try:
async with self.async_session() as session:
folder_dict = folder.model_dump()
# Convert datetime objects to strings for JSON serialization
folder_dict = _serialize_datetime(folder_dict)
# Check if a folder with this name already exists for this owner
# Use only the type/id format
stmt = text(
"""
SELECT id FROM folders
WHERE name = :name
AND owner->>'id' = :entity_id
AND owner->>'type' = :entity_type
"""
).bindparams(name=folder.name, entity_id=folder.owner["id"], entity_type=folder.owner["type"])
result = await session.execute(stmt)
existing_folder = result.scalar_one_or_none()
if existing_folder:
logger.info(
f"Folder '{folder.name}' already exists with ID {existing_folder}, not creating a duplicate"
)
# Update the provided folder's ID to match the existing one
# so the caller gets the correct ID
folder.id = existing_folder
return True
# Create a new folder model
access_control = folder_dict.get("access_control", {})
# Log access control to debug any issues
if "user_id" in access_control:
logger.info(f"Storing folder with user_id: {access_control['user_id']}")
else:
logger.info("No user_id found in folder access_control")
folder_model = FolderModel(
id=folder.id,
name=folder.name,
description=folder.description,
owner=folder_dict["owner"],
document_ids=folder_dict.get("document_ids", []),
system_metadata=folder_dict.get("system_metadata", {}),
access_control=access_control,
rules=folder_dict.get("rules", []),
)
session.add(folder_model)
await session.commit()
logger.info(f"Created new folder '{folder.name}' with ID {folder.id}")
return True
except Exception as e:
logger.error(f"Error creating folder: {e}")
return False
async def get_folder(self, folder_id: str, auth: AuthContext) -> Optional[Folder]:
"""Get a folder by ID."""
try:
async with self.async_session() as session:
# Get the folder
logger.info(f"Getting folder with ID: {folder_id}")
result = await session.execute(select(FolderModel).where(FolderModel.id == folder_id))
folder_model = result.scalar_one_or_none()
if not folder_model:
logger.error(f"Folder with ID {folder_id} not found in database")
return None
# Convert to Folder object
folder_dict = {
"id": folder_model.id,
"name": folder_model.name,
"description": folder_model.description,
"owner": folder_model.owner,
"document_ids": folder_model.document_ids,
"system_metadata": folder_model.system_metadata,
"access_control": folder_model.access_control,
"rules": folder_model.rules,
}
folder = Folder(**folder_dict)
# Check if the user has access to the folder
if not self._check_folder_access(folder, auth, "read"):
return None
return folder
except Exception as e:
logger.error(f"Error getting folder: {e}")
return None
async def get_folder_by_name(self, name: str, auth: AuthContext) -> Optional[Folder]:
"""Get a folder by name."""
try:
async with self.async_session() as session:
# First try to get a folder owned by this entity
if auth.entity_type and auth.entity_id:
stmt = text(
"""
SELECT * FROM folders
WHERE name = :name
AND (owner->>'id' = :entity_id)
AND (owner->>'type' = :entity_type)
"""
).bindparams(name=name, entity_id=auth.entity_id, entity_type=auth.entity_type.value)
result = await session.execute(stmt)
folder_row = result.fetchone()
if folder_row:
# Convert to Folder object
folder_dict = {
"id": folder_row.id,
"name": folder_row.name,
"description": folder_row.description,
"owner": folder_row.owner,
"document_ids": folder_row.document_ids,
"system_metadata": folder_row.system_metadata,
"access_control": folder_row.access_control,
"rules": folder_row.rules,
}
return Folder(**folder_dict)
# If not found, try to find any accessible folder with that name
stmt = text(
"""
SELECT * FROM folders
WHERE name = :name
AND (
(owner->>'id' = :entity_id AND owner->>'type' = :entity_type)
OR (access_control->'readers' ? :entity_id)
OR (access_control->'writers' ? :entity_id)
OR (access_control->'admins' ? :entity_id)
OR (access_control->'user_id' ? :user_id)
)
"""
).bindparams(
name=name,
entity_id=auth.entity_id,
entity_type=auth.entity_type.value,
user_id=auth.user_id if auth.user_id else "",
)
result = await session.execute(stmt)
folder_row = result.fetchone()
if folder_row:
# Convert to Folder object
folder_dict = {
"id": folder_row.id,
"name": folder_row.name,
"description": folder_row.description,
"owner": folder_row.owner,
"document_ids": folder_row.document_ids,
"system_metadata": folder_row.system_metadata,
"access_control": folder_row.access_control,
"rules": folder_row.rules,
}
return Folder(**folder_dict)
return None
except Exception as e:
logger.error(f"Error getting folder by name: {e}")
return None
async def list_folders(self, auth: AuthContext) -> List[Folder]:
"""List all folders the user has access to."""
try:
folders = []
async with self.async_session() as session:
# Get all folders
result = await session.execute(select(FolderModel))
folder_models = result.scalars().all()
for folder_model in folder_models:
# Convert to Folder object
folder_dict = {
"id": folder_model.id,
"name": folder_model.name,
"description": folder_model.description,
"owner": folder_model.owner,
"document_ids": folder_model.document_ids,
"system_metadata": folder_model.system_metadata,
"access_control": folder_model.access_control,
"rules": folder_model.rules,
}
folder = Folder(**folder_dict)
# Check if the user has access to the folder
if self._check_folder_access(folder, auth, "read"):
folders.append(folder)
return folders
except Exception as e:
logger.error(f"Error listing folders: {e}")
return []
async def add_document_to_folder(self, folder_id: str, document_id: str, auth: AuthContext) -> bool:
"""Add a document to a folder."""
try:
# First, check if the user has access to the folder
folder = await self.get_folder(folder_id, auth)
if not folder:
logger.error(f"Folder {folder_id} not found or user does not have access")
return False
# Check if user has write access to the folder
if not self._check_folder_access(folder, auth, "write"):
logger.error(f"User does not have write access to folder {folder_id}")
return False
# Check if the document exists and user has access
document = await self.get_document(document_id, auth)
if not document:
logger.error(f"Document {document_id} not found or user does not have access")
return False
# Check if the document is already in the folder
if document_id in folder.document_ids:
logger.info(f"Document {document_id} is already in folder {folder_id}")
return True
# Add the document to the folder
async with self.async_session() as session:
# Add document_id to document_ids array
new_document_ids = folder.document_ids + [document_id]
folder_model = await session.get(FolderModel, folder_id)
if not folder_model:
logger.error(f"Folder {folder_id} not found in database")
return False
folder_model.document_ids = new_document_ids
# Also update the document's system_metadata to include the folder_name.
# Bind the JSON value instead of interpolating it into the SQL string so that
# folder names containing quotes cannot break the statement.
folder_name_json = json.dumps(folder.name)
stmt = text(
"""
UPDATE documents
SET system_metadata = jsonb_set(system_metadata, '{folder_name}', CAST(:folder_name AS jsonb))
WHERE external_id = :document_id
"""
).bindparams(folder_name=folder_name_json, document_id=document_id)
await session.execute(stmt)
await session.commit()
logger.info(f"Added document {document_id} to folder {folder_id}")
return True
except Exception as e:
logger.error(f"Error adding document to folder: {e}")
return False
async def remove_document_from_folder(self, folder_id: str, document_id: str, auth: AuthContext) -> bool:
"""Remove a document from a folder."""
try:
# First, check if the user has access to the folder
folder = await self.get_folder(folder_id, auth)
if not folder:
logger.error(f"Folder {folder_id} not found or user does not have access")
return False
# Check if user has write access to the folder
if not self._check_folder_access(folder, auth, "write"):
logger.error(f"User does not have write access to folder {folder_id}")
return False
# Check if the document is in the folder
if document_id not in folder.document_ids:
logger.warning(f"Tried to delete document {document_id} not in folder {folder_id}")
return True
# Remove the document from the folder
async with self.async_session() as session:
# Remove document_id from document_ids array
new_document_ids = [doc_id for doc_id in folder.document_ids if doc_id != document_id]
folder_model = await session.get(FolderModel, folder_id)
if not folder_model:
logger.error(f"Folder {folder_id} not found in database")
return False
folder_model.document_ids = new_document_ids
# Also update the document's system_metadata to remove the folder_name
stmt = text(
"""
UPDATE documents
SET system_metadata = jsonb_set(system_metadata, '{folder_name}', 'null'::jsonb)
WHERE external_id = :document_id
"""
).bindparams(document_id=document_id)
await session.execute(stmt)
await session.commit()
logger.info(f"Removed document {document_id} from folder {folder_id}")
return True
except Exception as e:
logger.error(f"Error removing document from folder: {e}")
return False
def _check_folder_access(self, folder: Folder, auth: AuthContext, permission: str = "read") -> bool:
"""Check if the user has the required permission for the folder."""
# Admin always has access
if "admin" in auth.permissions:
return True
# Check if folder is owned by the user
if (
auth.entity_type
and auth.entity_id
and folder.owner.get("type") == auth.entity_type.value
and folder.owner.get("id") == auth.entity_id
):
# In cloud mode, also verify user_id if present
if auth.user_id:
from core.config import get_settings
settings = get_settings()
if settings.MODE == "cloud":
folder_user_ids = folder.access_control.get("user_id", [])
if auth.user_id not in folder_user_ids:
return False
return True
# Check access control lists
access_control = folder.access_control or {}
if permission == "read":
readers = access_control.get("readers", [])
if f"{auth.entity_type.value}:{auth.entity_id}" in readers:
return True
if permission == "write":
writers = access_control.get("writers", [])
if f"{auth.entity_type.value}:{auth.entity_id}" in writers:
return True
# For admin permission, check admins list
if permission == "admin":
admins = access_control.get("admins", [])
if f"{auth.entity_type.value}:{auth.entity_id}" in admins:
return True
return False
```
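For orientation, the sketch below shows one way the folder helpers defined above could be combined. It is an illustration, not code from this repository: only the method names (`get_folder_by_name`, `add_document_to_folder`) and the `Folder.id` field come from the file, while the `db` and `auth` objects are assumed to be constructed elsewhere.
```py
# Hedged usage sketch; relies only on methods shown in postgres_database.py.
from typing import Any


async def attach_document(db: Any, auth: Any, folder_name: str, document_id: str) -> bool:
    """Attach a document to a named folder if the caller can read and write it."""
    folder = await db.get_folder_by_name(folder_name, auth)
    if folder is None:
        # Folder does not exist or the caller lacks read access
        return False
    # add_document_to_folder re-checks write access and is a no-op for duplicates
    return await db.add_document_to_folder(folder.id, document_id, auth)
```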
## /core/database/user_limits_db.py
```py path="/core/database/user_limits_db.py"
import json
import logging
from datetime import UTC, datetime, timedelta
from typing import Any, Dict, Optional
from sqlalchemy import Column, Index, String, select, text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import declarative_base, sessionmaker
logger = logging.getLogger(__name__)
Base = declarative_base()
class UserLimitsModel(Base):
"""SQLAlchemy model for user limits data."""
__tablename__ = "user_limits"
user_id = Column(String, primary_key=True)
tier = Column(String, nullable=False) # free, developer, startup, custom
custom_limits = Column(JSONB, nullable=True)
usage = Column(JSONB, default=dict) # Holds all usage counters
app_ids = Column(JSONB, default=list) # List of app IDs registered by this user
stripe_customer_id = Column(String, nullable=True)
stripe_subscription_id = Column(String, nullable=True)
stripe_product_id = Column(String, nullable=True)
subscription_status = Column(String, nullable=True)
created_at = Column(String) # ISO format string
updated_at = Column(String) # ISO format string
# Create indexes
__table_args__ = (Index("idx_user_tier", "tier"),)
class UserLimitsDatabase:
"""Database operations for user limits."""
def __init__(self, uri: str):
"""Initialize database connection."""
self.engine = create_async_engine(uri)
self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)
self._initialized = False
async def initialize(self) -> bool:
"""Initialize database tables and indexes."""
if self._initialized:
return True
try:
logger.info("Initializing user limits database tables...")
# Create tables if they don't exist
async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Check if we need to add the new Stripe columns
# This safely adds columns if they don't exist without affecting existing data
try:
# Check if the columns exist first to avoid errors
for column_name in [
"stripe_customer_id",
"stripe_subscription_id",
"stripe_product_id",
"subscription_status",
]:
await conn.execute(
text(
f"DO $$\n"
f"BEGIN\n"
f" IF NOT EXISTS (SELECT 1 FROM information_schema.columns \n"
f" WHERE table_name='user_limits' AND column_name='{column_name}') THEN\n"
f" ALTER TABLE user_limits ADD COLUMN {column_name} VARCHAR;\n"
f" END IF;\n"
f"END$$;"
)
)
logger.info("Successfully migrated user_limits table schema if needed")
except Exception as migration_error:
logger.warning(f"Migration step failed, but continuing: {migration_error}")
# We continue even if migration fails as the app can still function
self._initialized = True
logger.info("User limits database tables initialized successfully")
return True
except Exception as e:
logger.error(f"Failed to initialize user limits database: {e}")
return False
async def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
"""
Get user limits for a user.
Args:
user_id: The user ID to get limits for
Returns:
Dict with user limits if found, None otherwise
"""
async with self.async_session() as session:
result = await session.execute(select(UserLimitsModel).where(UserLimitsModel.user_id == user_id))
user_limits = result.scalars().first()
if not user_limits:
return None
return {
"user_id": user_limits.user_id,
"tier": user_limits.tier,
"custom_limits": user_limits.custom_limits,
"usage": user_limits.usage,
"app_ids": user_limits.app_ids,
"stripe_customer_id": user_limits.stripe_customer_id,
"stripe_subscription_id": user_limits.stripe_subscription_id,
"stripe_product_id": user_limits.stripe_product_id,
"subscription_status": user_limits.subscription_status,
"created_at": user_limits.created_at,
"updated_at": user_limits.updated_at,
}
async def create_user_limits(self, user_id: str, tier: str = "free") -> bool:
"""
Create user limits record.
Args:
user_id: The user ID
tier: Initial tier (defaults to "free")
Returns:
True if successful, False otherwise
"""
try:
now = datetime.now(UTC).isoformat()
async with self.async_session() as session:
# Check if already exists
result = await session.execute(select(UserLimitsModel).where(UserLimitsModel.user_id == user_id))
if result.scalars().first():
return True # Already exists
# Create new record with properly initialized JSONB columns
# Create JSON strings and parse them for consistency
usage_json = json.dumps(
{
"storage_file_count": 0,
"storage_size_bytes": 0,
"hourly_query_count": 0,
"hourly_query_reset": now,
"monthly_query_count": 0,
"monthly_query_reset": now,
"hourly_ingest_count": 0,
"hourly_ingest_reset": now,
"monthly_ingest_count": 0,
"monthly_ingest_reset": now,
"graph_count": 0,
"cache_count": 0,
}
)
app_ids_json = json.dumps([]) # Empty array but as JSON string
# Create the model with the JSON parsed
user_limits = UserLimitsModel(
user_id=user_id,
tier=tier,
usage=json.loads(usage_json),
app_ids=json.loads(app_ids_json),
stripe_customer_id=None,
stripe_subscription_id=None,
stripe_product_id=None,
subscription_status=None,
created_at=now,
updated_at=now,
)
session.add(user_limits)
await session.commit()
return True
except Exception as e:
logger.error(f"Failed to create user limits: {e}")
return False
async def update_user_tier(self, user_id: str, tier: str, custom_limits: Optional[Dict[str, Any]] = None) -> bool:
"""
Update user tier and custom limits.
Args:
user_id: The user ID
tier: New tier
custom_limits: Optional custom limits for CUSTOM tier
Returns:
True if successful, False otherwise
"""
try:
now = datetime.now(UTC).isoformat()
async with self.async_session() as session:
result = await session.execute(select(UserLimitsModel).where(UserLimitsModel.user_id == user_id))
user_limits = result.scalars().first()
if not user_limits:
return False
user_limits.tier = tier
user_limits.custom_limits = custom_limits
user_limits.updated_at = now
await session.commit()
return True
except Exception as e:
logger.error(f"Failed to update user tier: {e}")
return False
async def update_subscription_info(self, user_id: str, subscription_data: Dict[str, Any]) -> bool:
"""
Update user subscription information.
Args:
user_id: The user ID
subscription_data: Dictionary containing subscription information with keys:
- stripeCustomerId
- stripeSubscriptionId
- stripeProductId
- subscriptionStatus
Returns:
True if successful, False otherwise
"""
try:
now = datetime.now(UTC).isoformat()
async with self.async_session() as session:
result = await session.execute(select(UserLimitsModel).where(UserLimitsModel.user_id == user_id))
user_limits = result.scalars().first()
if not user_limits:
return False
user_limits.stripe_customer_id = subscription_data.get("stripeCustomerId")
user_limits.stripe_subscription_id = subscription_data.get("stripeSubscriptionId")
user_limits.stripe_product_id = subscription_data.get("stripeProductId")
user_limits.subscription_status = subscription_data.get("subscriptionStatus")
user_limits.updated_at = now
await session.commit()
return True
except Exception as e:
logger.error(f"Failed to update subscription info: {e}")
return False
async def register_app(self, user_id: str, app_id: str) -> bool:
"""
Register an app for a user.
Args:
user_id: The user ID
app_id: The app ID to register
Returns:
True if successful, False otherwise
"""
try:
now = datetime.now(UTC).isoformat()
async with self.async_session() as session:
# First check if user exists
result = await session.execute(select(UserLimitsModel).where(UserLimitsModel.user_id == user_id))
user_limits = result.scalars().first()
if not user_limits:
logger.error(f"User {user_id} not found in register_app")
return False
# Use raw SQL with the JSONB `||` concatenation operator to append to the app_ids array,
# guarded by the `?` containment check so the same app_id is never added twice
query = text(
"""
UPDATE user_limits
SET
app_ids = CASE
WHEN NOT (app_ids ? :app_id) -- Check if app_id is not in the array
THEN app_ids || :app_id_json -- Append it if not present
ELSE app_ids -- Keep it unchanged if already present
END,
updated_at = :now
WHERE user_id = :user_id
RETURNING app_ids;
"""
)
# Execute the query
result = await session.execute(
query,
{
"app_id": app_id, # For the check
"app_id_json": f'["{app_id}"]', # JSON array format for appending
"now": now,
"user_id": user_id,
},
)
# Log the result for debugging
updated_app_ids = result.scalar()
logger.info(f"Updated app_ids for user {user_id}: {updated_app_ids}")
await session.commit()
return True
except Exception as e:
logger.error(f"Failed to register app: {e}")
return False
async def update_usage(self, user_id: str, usage_type: str, increment: int = 1) -> bool:
"""
Update usage counter for a user.
Args:
user_id: The user ID
usage_type: Type of usage to update
increment: Value to increment by
Returns:
True if successful, False otherwise
"""
try:
now = datetime.now(UTC)
now_iso = now.isoformat()
async with self.async_session() as session:
result = await session.execute(select(UserLimitsModel).where(UserLimitsModel.user_id == user_id))
user_limits = result.scalars().first()
if not user_limits:
return False
# Create a new dictionary to force SQLAlchemy to detect the change
usage = dict(user_limits.usage) if user_limits.usage else {}
# Handle different usage types
if usage_type == "query":
# Check hourly reset
hourly_reset_str = usage.get("hourly_query_reset", "")
if hourly_reset_str:
hourly_reset = datetime.fromisoformat(hourly_reset_str)
if now > hourly_reset + timedelta(hours=1):
usage["hourly_query_count"] = increment
usage["hourly_query_reset"] = now_iso
else:
usage["hourly_query_count"] = usage.get("hourly_query_count", 0) + increment
else:
usage["hourly_query_count"] = increment
usage["hourly_query_reset"] = now_iso
# Check monthly reset
monthly_reset_str = usage.get("monthly_query_reset", "")
if monthly_reset_str:
monthly_reset = datetime.fromisoformat(monthly_reset_str)
if now > monthly_reset + timedelta(days=30):
usage["monthly_query_count"] = increment
usage["monthly_query_reset"] = now_iso
else:
usage["monthly_query_count"] = usage.get("monthly_query_count", 0) + increment
else:
usage["monthly_query_count"] = increment
usage["monthly_query_reset"] = now_iso
elif usage_type == "ingest":
# Similar pattern for ingest
hourly_reset_str = usage.get("hourly_ingest_reset", "")
if hourly_reset_str:
hourly_reset = datetime.fromisoformat(hourly_reset_str)
if now > hourly_reset + timedelta(hours=1):
usage["hourly_ingest_count"] = increment
usage["hourly_ingest_reset"] = now_iso
else:
usage["hourly_ingest_count"] = usage.get("hourly_ingest_count", 0) + increment
else:
usage["hourly_ingest_count"] = increment
usage["hourly_ingest_reset"] = now_iso
monthly_reset_str = usage.get("monthly_ingest_reset", "")
if monthly_reset_str:
monthly_reset = datetime.fromisoformat(monthly_reset_str)
if now > monthly_reset + timedelta(days=30):
usage["monthly_ingest_count"] = increment
usage["monthly_ingest_reset"] = now_iso
else:
usage["monthly_ingest_count"] = usage.get("monthly_ingest_count", 0) + increment
else:
usage["monthly_ingest_count"] = increment
usage["monthly_ingest_reset"] = now_iso
elif usage_type == "storage_file":
usage["storage_file_count"] = usage.get("storage_file_count", 0) + increment
elif usage_type == "storage_size":
usage["storage_size_bytes"] = usage.get("storage_size_bytes", 0) + increment
elif usage_type == "graph":
usage["graph_count"] = usage.get("graph_count", 0) + increment
elif usage_type == "cache":
usage["cache_count"] = usage.get("cache_count", 0) + increment
# Force SQLAlchemy to recognize the change by assigning a new dict
user_limits.usage = usage
user_limits.updated_at = now_iso
# Explicitly mark as modified
session.add(user_limits)
# Log the updated usage for debugging
logger.info(f"Updated usage for user {user_id}, type: {usage_type}, value: {increment}")
logger.info(f"New usage values: {usage}")
logger.info(f"About to commit: user_id={user_id}, usage={user_limits.usage}")
# Commit and flush to ensure changes are written
await session.commit()
return True
except Exception as e:
logger.error(f"Failed to update usage: {e}")
import traceback
logger.error(traceback.format_exc())
return False
```
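A minimal end-to-end sketch of `UserLimitsDatabase` follows, assuming a local PostgreSQL instance; the URI and user ID are placeholders, while the class, constructor, and method signatures are taken from the file above.
```py
# Hedged usage sketch for UserLimitsDatabase (illustration only).
import asyncio

from core.database.user_limits_db import UserLimitsDatabase


async def record_query(user_id: str) -> None:
    # Placeholder URI; point this at the real PostgreSQL instance.
    db = UserLimitsDatabase("postgresql+asyncpg://postgres:postgres@localhost:5432/morphik")
    await db.initialize()

    # Creating the row is a no-op if the user already exists.
    await db.create_user_limits(user_id, tier="free")
    # Bump the hourly and monthly query counters (resets are handled internally).
    await db.update_usage(user_id, usage_type="query", increment=1)

    limits = await db.get_user_limits(user_id)
    if limits is not None:
        print(limits["tier"], limits["usage"].get("hourly_query_count"))


if __name__ == "__main__":
    asyncio.run(record_query("user_123"))
```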
## /core/embedding/__init__.py
```py path="/core/embedding/__init__.py"
from core.embedding.base_embedding_model import BaseEmbeddingModel
from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
from core.embedding.litellm_embedding import LiteLLMEmbeddingModel
__all__ = ["BaseEmbeddingModel", "LiteLLMEmbeddingModel", "ColpaliEmbeddingModel"]
```
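Since the package re-exports both embedding backends under one import path, call sites can type against the shared base class. The helper below is purely illustrative (it is not part of the repository) and deliberately avoids constructing the models, whose initializer arguments are not shown here.
```py
# Hedged sketch: dispatch on whichever embedding backend was injected.
from core.embedding import BaseEmbeddingModel, ColpaliEmbeddingModel, LiteLLMEmbeddingModel


def describe(model: BaseEmbeddingModel) -> str:
    """Return a short label for the injected embedding backend."""
    if isinstance(model, ColpaliEmbeddingModel):
        return "colpali"
    if isinstance(model, LiteLLMEmbeddingModel):
        return "litellm"
    return type(model).__name__
```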