``` ├── .dockerignore ├── .env.example ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── DOCKER.md ├── LICENSE ├── README.md ├── __init__.py ├── assets/ ├── morphik_logo.png ├── core/ ├── __init__.py ├── api.py ├── cache/ ├── base_cache.py ├── base_cache_factory.py ├── hf_cache.py ├── llama_cache.py ├── llama_cache_factory.py ├── completion/ ├── __init__.py ├── base_completion.py ├── litellm_completion.py ├── config.py ├── database/ ├── base_database.py ├── postgres_database.py ├── user_limits_db.py ├── embedding/ ├── __init__.py ├── base_embedding_model.py ``` ## /.dockerignore ```dockerignore path="/.dockerignore" # flyctl launch added from .gitignore # Python-related files **/*__pycache__ **/*.pyc **/*.pyo **/*.pyd **/.Python **/env **/.env **/venv/* **/ENV **/dist **/build **/*.egg-info **/.eggs **/*.egg **/*.pytest_cache # Virtual environment **/.venv **/.vscode **/*.DS_Store # flyctl launch added from .pytest_cache/.gitignore # Created by pytest automatically. .pytest_cache/**/* # flyctl launch added from .venv/lib/python3.12/site-packages/debugpy/_vendored/pydevd/_pydevd_frame_eval/.gitignore .venv/lib/python3.12/site-packages/debugpy/_vendored/pydevd/_pydevd_frame_eval/pydevd_frame_evaluator.*.so .venv/lib/python3.12/site-packages/debugpy/_vendored/pydevd/_pydevd_frame_eval/pydevd_frame_evaluator.*.pyd .venv/lib/python3.12/site-packages/debugpy/_vendored/pydevd/_pydevd_frame_eval/pydevd_frame_evaluator.pyx # flyctl launch added from .venv/lib/python3.12/site-packages/debugpy/_vendored/pydevd/pydevd_attach_to_process/linux_and_mac/.gitignore .venv/lib/python3.12/site-packages/debugpy/_vendored/pydevd/pydevd_attach_to_process/linux_and_mac/attach_x86.dylib .venv/lib/python3.12/site-packages/debugpy/_vendored/pydevd/pydevd_attach_to_process/linux_and_mac/attach_x86_64.dylib .venv/lib/python3.12/site-packages/debugpy/_vendored/pydevd/pydevd_attach_to_process/linux_and_mac/attach_linux_x86.o .venv/lib/python3.12/site-packages/debugpy/_vendored/pydevd/pydevd_attach_to_process/linux_and_mac/attach_linux_x86_64.o fly.toml ``` ## /.env.example ```example path="/.env.example" JWT_SECRET_KEY="..." # Required in production, optional in dev mode (dev_mode=true in morphik.toml) POSTGRES_URI="postgresql+asyncpg://postgres:postgres@localhost:5432/morphik" # Required for PostgreSQL database UNSTRUCTURED_API_KEY="..." # Optional: Needed for parsing via unstructured API OPENAI_API_KEY="..." # Optional: Needed for OpenAI embeddings and completions ASSEMBLYAI_API_KEY="..." # Optional: Needed for combined parser ANTHROPIC_API_KEY="..." # Optional: Needed for contextual parser AWS_ACCESS_KEY="..." # Optional: Needed for AWS S3 storage AWS_SECRET_ACCESS_KEY="..." 
# Optional: Needed for AWS S3 storage ``` ## /.gitignore ```gitignore path="/.gitignore" # Python-related files *__pycache__/ *.pyc *.pyo *.pyd .Python env/ .env venv/* ENV/ dist/ build/ *.egg-info/ .eggs/ *.egg *.pytest_cache/ core/tests/output core/tests/assets # Virtual environment .venv/ .vscode/ *.DS_Store storage/* logs/* samples/* aggregated_code.txt offload/* test.pdf experiments/* ee/ui-component/package-lock.json/* ee/ui-component/node-modules/* ee/ui-component/.next ui-component/notebook-storage/notebooks.json ee/ui-component/package-lock.json ``` ## /.pre-commit-config.yaml ```yaml path="/.pre-commit-config.yaml" repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v2.3.0 hooks: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/pycqa/isort rev: 5.12.0 hooks: - id: isort name: isort (python) - repo: https://github.com/psf/black rev: 24.4.2 hooks: - id: black args: [--line-length=120] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.1.6 hooks: - id: ruff args: [--fix] ``` ## /CODE_OF_CONDUCT.md # Contributor Covenant Code of Conduct ## Our Pledge We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. ## Our Standards Examples of behavior that contributes to a positive environment for our community include: * Demonstrating empathy and kindness toward other people * Being respectful of differing opinions, viewpoints, and experiences * Giving and gracefully accepting constructive feedback * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience * Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: * The use of sexualized language or imagery, and sexual attention or advances of any kind * Trolling, insulting or derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or email address, without their explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. ## Scope This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. 
## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at founders@morphik.ai. All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and security of the reporter of any incident. ## Enforcement Guidelines Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: ### 1. Correction **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. ### 2. Warning **Community Impact**: A violation through a single incident or series of actions. **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. ### 4. Permanent Ban **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. **Consequence**: A permanent ban from any sort of public interaction within the community. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations. ## /DOCKER.md # Docker Setup Guide for Morphik Core Morphik Core provides a streamlined Docker-based setup that includes all necessary components: the core API, PostgreSQL with pgvector, and Ollama for AI models. ## Prerequisites - Docker and Docker Compose installed on your system - At least 10GB of free disk space (for models and data) - 8GB+ RAM recommended ## Quick Start 1. Clone the repository and navigate to the project directory: ```bash git clone https://github.com/morphik-org/morphik-core.git cd morphik-core ``` 2. 
First-time setup: ```bash docker compose up --build ``` This command will: - Build all required containers - Download necessary AI models (nomic-embed-text and llama3.2) - Initialize the PostgreSQL database with pgvector - Start all services The initial setup may take 5-10 minutes depending on your internet speed, as it needs to download the AI models. 3. For subsequent runs: ```bash docker compose up # Start all services docker compose down # Stop all services ``` 4. To completely reset (will delete all data and models): ```bash docker compose down -v ``` ## Configuration ### 1. Default Setup The default configuration works out of the box and includes: - PostgreSQL with pgvector for document storage - Ollama for AI models (embeddings and completions) - Local file storage - Basic authentication ### 2. Configuration File (morphik.toml) The default `morphik.toml` is configured for Docker and includes: ```toml [api] host = "0.0.0.0" # Important: Use 0.0.0.0 for Docker port = 8000 [completion] provider = "ollama" model_name = "llama3.2" base_url = "http://ollama:11434" # Use Docker service name [embedding] provider = "ollama" model_name = "nomic-embed-text" base_url = "http://ollama:11434" # Use Docker service name [database] provider = "postgres" [vector_store] provider = "pgvector" [storage] provider = "local" storage_path = "/app/storage" ``` ### 3. Environment Variables Create a `.env` file to customize these settings: ```bash JWT_SECRET_KEY=your-secure-key-here # Important: Change in production OPENAI_API_KEY=sk-... # Only if using OpenAI HOST=0.0.0.0 # Leave as is for Docker PORT=8000 # Change if needed ``` ### 4. Custom Configuration To use your own configuration: 1. Create a custom `morphik.toml` 2. Mount it in `docker-compose.yml`: ```yaml services: morphik: volumes: - ./my-custom-morphik.toml:/app/morphik.toml ``` ## Accessing Services - Morphik API: http://localhost:8000 - API Documentation: http://localhost:8000/docs - Health Check: http://localhost:8000/health ## Storage and Data - Database data: Stored in the `postgres_data` Docker volume - AI Models: Stored in the `ollama_data` Docker volume - Documents: Stored in `./storage` directory (mounted to container) - Logs: Available in `./logs` directory ## Troubleshooting 1. **Service Won't Start** ```bash # View all logs docker compose logs # View specific service logs docker compose logs morphik docker compose logs postgres docker compose logs ollama ``` 2. **Database Issues** - Check PostgreSQL is healthy: `docker compose ps` - Verify database connection: `docker compose exec postgres psql -U morphik -d morphik` 3. **Model Download Issues** - Check Ollama logs: `docker compose logs ollama` - Ensure enough disk space for models - Try restarting Ollama: `docker compose restart ollama` 4. **Performance Issues** - Monitor resources: `docker stats` - Ensure sufficient RAM (8GB+ recommended) - Check disk space: `df -h` ## Production Deployment For production environments: 1. **Security**: - Change the default `JWT_SECRET_KEY` - Use proper network security groups - Enable HTTPS (recommended: use a reverse proxy) - Regularly update containers and dependencies 2. **Persistence**: - Use named volumes for all data - Set up regular backups of PostgreSQL - Back up the storage directory 3. 
**Monitoring**: - Set up container monitoring - Configure proper logging - Use health checks ## Support For issues and feature requests: - GitHub Issues: [https://github.com/morphik-org/morphik-core/issues](https://github.com/morphik-org/morphik-core/issues) - Documentation: [https://docs.morphik.ai](https://docs.morphik.ai) ## Repository Information - License: MIT ## /LICENSE ``` path="/LICENSE" Copyright (c) 2024-2025 Morphik, Inc. Portions of this software are licensed as follows: * All content that resides under the "ee/" directory of this repository, if that directory exists, is licensed under the license defined in "ee/LICENSE". * All third party components incorporated into the Morphik Software are licensed under the original license provided by the owner of the applicable component. * Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ``` ## /README.md

Morphik Logo


Docs - Community - Why Morphik? - Bug reports

## Morphik is an alternative to traditional RAG for highly technical and visual documents. [Morphik](https://morphik.ai) provides developers with the tools to ingest, search (deep and shallow), transform, and manage unstructured and multimodal documents. Some of our features include: - [Multimodal Search](https://docs.morphik.ai/concepts/colpali): We employ techniques such as ColPali to build search that actually *understands* the visual content of documents you provide. Search over images, PDFs, videos, and more with a single endpoint. - [Knowledge Graphs](https://docs.morphik.ai/concepts/knowledge-graphs): Build knowledge graphs for domain-specific use cases in a single line of code. Use our battle-tested system prompts, or use your own. - [Fast and Scalable Metadata Extraction](https://docs.morphik.ai/concepts/rules-processing): Extract metadata from documents - including bounding boxes, labeling, classification, and more. - [Integrations](https://docs.morphik.ai/integrations): Integrate with existing tools and workflows, including (but not limited to) Google Suite, Slack, and Confluence. - [Cache-Augmented-Generation](https://docs.morphik.ai/python-sdk/create_cache): Create persistent KV-caches of your documents to speed up generation. The best part? Morphik has a [free tier](https://www.morphik.ai/pricing) and is open source! Get started by signing up at [Morphik](https://www.morphik.ai/signup). ## Table of Contents - [Getting Started with Morphik](#getting-started-with-morphik-recommended) - [Self-hosting the open-source version](#self-hosting-the-open-source-version) - [Using Morphik](#using-morphik) - [Contributing](#contributing) - [Open source vs paid](#open-source-vs-paid) ## Getting Started with Morphik (Recommended) The fastest and easiest way to get started with Morphik is by signing up for free at [Morphik](https://www.morphik.ai/signup). Your first 200 pages and 100 queries are on us! After this, you can pay based on usage with discounted rates for heavier use. ## Self-hosting the open-source version If you'd like to self-host Morphik, you can find the dedicated instructions [here](https://docs.morphik.ai/getting-started). We offer options for direct installation and installation via Docker. **Important**: Due to limited resources, we cannot provide full support for open-source deployments. We have an installation guide and a [Discord community](https://discord.gg/BwMtv3Zaju) to help, but we can't guarantee full support. ## Using Morphik Once you've signed up for Morphik, you can get started with ingesting and searching your data right away. ### Code (Example: Python SDK) For programmers, we offer a [Python SDK](https://docs.morphik.ai/python-sdk/morphik) and a [REST API](https://docs.morphik.ai/api-reference/health-check). Ingesting a file is as simple as: ```python from morphik import Morphik morphik = Morphik("") morphik.ingest_file("path/to/your/super/complex/file.pdf") ``` Similarly, searching and querying your data is easy: ```python morphik.query("What's the height of screw 14-A in the chair assembly instructions?") ``` ### Morphik Console You can also interact with Morphik via the Morphik Console. This is a web-based interface that allows you to ingest, search, and query your data. You can upload files, connect to different data sources, and chat with your data all within the same place. ### Model Context Protocol Finally, you can also access Morphik via MCP. Instructions are available [here](https://docs.morphik.ai/using-morphik/mcp).
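In addition to the SDK calls shown above, a self-hosted server can be driven over plain HTTP. The sketch below uses `requests` against the REST endpoints defined in `core/api.py` (reproduced later in this document); the endpoint paths and request fields follow those handlers, while the base URL, bearer token, and payload values are illustrative assumptions for a local deployment (with `dev_mode` enabled the Authorization header can be omitted).

```python
# Minimal sketch of calling the Morphik REST API directly with `requests`.
# Paths and field names follow the handlers in core/api.py; the base URL,
# JWT placeholder, and example content are assumptions for a local setup.
import requests

BASE_URL = "http://localhost:8000"  # default address listed in DOCKER.md
HEADERS = {"Authorization": "Bearer <your-jwt>"}  # omit when dev_mode is enabled

# Ingest a text document (POST /ingest/text)
doc = requests.post(
    f"{BASE_URL}/ingest/text",
    json={
        "content": "Screw 14-A in the chair assembly kit is 25 mm long.",
        "filename": "chair-notes.txt",
        "metadata": {"source": "example"},
    },
    headers=HEADERS,
).json()
print("Ingested document:", doc["external_id"])  # field name from the Document model

# Retrieve the most relevant chunks (POST /retrieve/chunks)
chunks = requests.post(
    f"{BASE_URL}/retrieve/chunks",
    json={"query": "How long is screw 14-A?", "k": 4, "min_score": 0.0},
    headers=HEADERS,
).json()
print(f"Retrieved {len(chunks)} chunks")

# Generate a completion grounded in the retrieved context (POST /query)
answer = requests.post(
    f"{BASE_URL}/query",
    json={"query": "How long is screw 14-A?", "k": 4},
    headers=HEADERS,
).json()
print(answer)
```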
## Contributing You're welcome to contribute to the project! We love: - Bug reports via [GitHub issues](https://github.com/morphik-org/morphik-core/issues) - Feature requests via [GitHub issues](https://github.com/morphik-org/morphik-core/issues) - Pull requests Currently, we're focused on improving speed, integrating with more tools, and finding the research papers that provide the most value to our users. If you have thoughts, let us know on Discord or GitHub! ## Open source vs paid Certain features - such as Morphik Console - are not available in the open-source version. Any feature in the `ee` namespace is not available in the open-source version and carries a different license. Any feature outside that is open source under the MIT Expat license. ## Contributors Visit our special thanks page dedicated to our contributors [here](https://docs.morphik.ai/special-thanks). ## PS We took inspiration from [PostHog](https://posthog.com) while writing this README. If you're from PostHog, thank you ❤️ ## /__init__.py ```py path="/__init__.py" ``` ## /assets/morphik_logo.png Binary file available at https://raw.githubusercontent.com/morphik-org/morphik-core/refs/heads/main/assets/morphik_logo.png ## /core/__init__.py ```py path="/core/__init__.py" ``` ## /core/api.py ```py path="/core/api.py" import asyncio import base64 import json import logging import uuid from datetime import UTC, datetime, timedelta from pathlib import Path from typing import Any, Dict, List, Optional import arq import jwt import tomli from fastapi import Depends, FastAPI, File, Form, Header, HTTPException, UploadFile from fastapi.middleware.cors import CORSMiddleware from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from core.cache.llama_cache_factory import LlamaCacheFactory from core.completion.litellm_completion import LiteLLMCompletionModel from core.config import get_settings from core.database.postgres_database import PostgresDatabase from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel from core.embedding.litellm_embedding import LiteLLMEmbeddingModel from core.limits_utils import check_and_increment_limits from core.models.auth import AuthContext, EntityType from core.models.completion import ChunkSource, CompletionResponse from core.models.documents import ChunkResult, Document, DocumentResult from core.models.folders import Folder, FolderCreate from core.models.graph import Graph from core.models.prompts import validate_prompt_overrides_with_http_exception from core.models.request import ( BatchIngestResponse, CompletionQueryRequest, CreateGraphRequest, GenerateUriRequest, IngestTextRequest, RetrieveRequest, SetFolderRuleRequest, UpdateGraphRequest, ) from core.parser.morphik_parser import MorphikParser from core.reranker.flag_reranker import FlagReranker from core.services.document_service import DocumentService from core.services.telemetry import TelemetryService from core.storage.local_storage import LocalStorage from core.storage.s3_storage import S3Storage from core.vector_store.multi_vector_store import MultiVectorStore from core.vector_store.pgvector_store import PGVectorStore # Initialize FastAPI app app = FastAPI(title="Morphik API") logger = logging.getLogger(__name__) # Add health check endpoints @app.get("/health") async def health_check(): """Basic health check endpoint.""" return {"status": "healthy"} @app.get("/health/ready") async def readiness_check(): """Readiness check that verifies the application is initialized.""" return { "status": "ready",
"components": { "database": settings.DATABASE_PROVIDER, "vector_store": settings.VECTOR_STORE_PROVIDER, "embedding": settings.EMBEDDING_PROVIDER, "completion": settings.COMPLETION_PROVIDER, "storage": settings.STORAGE_PROVIDER, }, } # Initialize telemetry telemetry = TelemetryService() # Add OpenTelemetry instrumentation - exclude HTTP send/receive spans FastAPIInstrumentor.instrument_app( app, excluded_urls="health,health/.*", # Exclude health check endpoints exclude_spans=["send", "receive"], # Exclude HTTP send/receive spans to reduce telemetry volume http_capture_headers_server_request=None, # Don't capture request headers http_capture_headers_server_response=None, # Don't capture response headers tracer_provider=None, # Use the global tracer provider ) # Add CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Initialize service settings = get_settings() # Initialize database if not settings.POSTGRES_URI: raise ValueError("PostgreSQL URI is required for PostgreSQL database") database = PostgresDatabase(uri=settings.POSTGRES_URI) # Redis settings already imported at top of file @app.on_event("startup") async def initialize_database(): """Initialize database tables and indexes on application startup.""" logger.info("Initializing database...") success = await database.initialize() if success: logger.info("Database initialization successful") else: logger.error("Database initialization failed") # We don't raise an exception here to allow the app to continue starting # even if there are initialization errors @app.on_event("startup") async def initialize_vector_store(): """Initialize vector store tables and indexes on application startup.""" # First initialize the primary vector store (PGVectorStore if using pgvector) logger.info("Initializing primary vector store...") if hasattr(vector_store, "initialize"): success = await vector_store.initialize() if success: logger.info("Primary vector store initialization successful") else: logger.error("Primary vector store initialization failed") else: logger.warning("Primary vector store does not have an initialize method") # Then initialize the multivector store if enabled if settings.ENABLE_COLPALI and colpali_vector_store: logger.info("Initializing multivector store...") # Handle both synchronous and asynchronous initialize methods if hasattr(colpali_vector_store.initialize, "__awaitable__"): success = await colpali_vector_store.initialize() else: success = colpali_vector_store.initialize() if success: logger.info("Multivector store initialization successful") else: logger.error("Multivector store initialization failed") @app.on_event("startup") async def initialize_user_limits_database(): """Initialize user service on application startup.""" logger.info("Initializing user service...") if settings.MODE == "cloud": from core.database.user_limits_db import UserLimitsDatabase user_limits_db = UserLimitsDatabase(uri=settings.POSTGRES_URI) await user_limits_db.initialize() @app.on_event("startup") async def initialize_redis_pool(): """Initialize the Redis connection pool for background tasks.""" global redis_pool logger.info("Initializing Redis connection pool...") # Get Redis settings from configuration redis_host = settings.REDIS_HOST redis_port = settings.REDIS_PORT # Log the Redis connection details logger.info(f"Connecting to Redis at {redis_host}:{redis_port}") redis_settings = arq.connections.RedisSettings( host=redis_host, port=redis_port, ) redis_pool = 
await arq.create_pool(redis_settings) logger.info("Redis connection pool initialized successfully") @app.on_event("shutdown") async def close_redis_pool(): """Close the Redis connection pool on application shutdown.""" global redis_pool if redis_pool: logger.info("Closing Redis connection pool...") redis_pool.close() await redis_pool.wait_closed() logger.info("Redis connection pool closed") # Initialize vector store if not settings.POSTGRES_URI: raise ValueError("PostgreSQL URI is required for pgvector store") vector_store = PGVectorStore( uri=settings.POSTGRES_URI, ) # Initialize storage match settings.STORAGE_PROVIDER: case "local": storage = LocalStorage(storage_path=settings.STORAGE_PATH) case "aws-s3": if not settings.AWS_ACCESS_KEY or not settings.AWS_SECRET_ACCESS_KEY: raise ValueError("AWS credentials are required for S3 storage") storage = S3Storage( aws_access_key=settings.AWS_ACCESS_KEY, aws_secret_key=settings.AWS_SECRET_ACCESS_KEY, region_name=settings.AWS_REGION, default_bucket=settings.S3_BUCKET, ) case _: raise ValueError(f"Unsupported storage provider: {settings.STORAGE_PROVIDER}") # Initialize parser parser = MorphikParser( chunk_size=settings.CHUNK_SIZE, chunk_overlap=settings.CHUNK_OVERLAP, use_unstructured_api=settings.USE_UNSTRUCTURED_API, unstructured_api_key=settings.UNSTRUCTURED_API_KEY, assemblyai_api_key=settings.ASSEMBLYAI_API_KEY, anthropic_api_key=settings.ANTHROPIC_API_KEY, use_contextual_chunking=settings.USE_CONTEXTUAL_CHUNKING, ) # Initialize embedding model # Create a LiteLLM model using the registered model config embedding_model = LiteLLMEmbeddingModel( model_key=settings.EMBEDDING_MODEL, ) logger.info(f"Initialized LiteLLM embedding model with model key: {settings.EMBEDDING_MODEL}") # Initialize completion model # Create a LiteLLM model using the registered model config completion_model = LiteLLMCompletionModel( model_key=settings.COMPLETION_MODEL, ) logger.info(f"Initialized LiteLLM completion model with model key: {settings.COMPLETION_MODEL}") # Initialize reranker reranker = None if settings.USE_RERANKING: match settings.RERANKER_PROVIDER: case "flag": reranker = FlagReranker( model_name=settings.RERANKER_MODEL, device=settings.RERANKER_DEVICE, use_fp16=settings.RERANKER_USE_FP16, query_max_length=settings.RERANKER_QUERY_MAX_LENGTH, passage_max_length=settings.RERANKER_PASSAGE_MAX_LENGTH, ) case _: raise ValueError(f"Unsupported reranker provider: {settings.RERANKER_PROVIDER}") # Initialize cache factory cache_factory = LlamaCacheFactory(Path(settings.STORAGE_PATH)) # Initialize ColPali embedding model if enabled colpali_embedding_model = ColpaliEmbeddingModel() if settings.ENABLE_COLPALI else None colpali_vector_store = MultiVectorStore(uri=settings.POSTGRES_URI) if settings.ENABLE_COLPALI else None # Initialize document service with configured components document_service = DocumentService( storage=storage, database=database, vector_store=vector_store, embedding_model=embedding_model, completion_model=completion_model, parser=parser, reranker=reranker, cache_factory=cache_factory, enable_colpali=settings.ENABLE_COLPALI, colpali_embedding_model=colpali_embedding_model, colpali_vector_store=colpali_vector_store, ) async def verify_token(authorization: str = Header(None)) -> AuthContext: """Verify JWT Bearer token or return dev context if dev_mode is enabled.""" # Check if dev mode is enabled if settings.dev_mode: return AuthContext( entity_type=EntityType(settings.dev_entity_type), entity_id=settings.dev_entity_id, 
permissions=set(settings.dev_permissions), user_id=settings.dev_entity_id, # In dev mode, entity_id is also the user_id ) # Normal token verification flow if not authorization: raise HTTPException( status_code=401, detail="Missing authorization header", headers={"WWW-Authenticate": "Bearer"}, ) try: if not authorization.startswith("Bearer "): raise HTTPException(status_code=401, detail="Invalid authorization header") token = authorization[7:] # Remove "Bearer " payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]) if datetime.fromtimestamp(payload["exp"], UTC) < datetime.now(UTC): raise HTTPException(status_code=401, detail="Token expired") # Support both "type" and "entity_type" fields for compatibility entity_type_field = payload.get("type") or payload.get("entity_type") if not entity_type_field: raise HTTPException(status_code=401, detail="Missing entity type in token") return AuthContext( entity_type=EntityType(entity_type_field), entity_id=payload["entity_id"], app_id=payload.get("app_id"), permissions=set(payload.get("permissions", ["read"])), user_id=payload.get("user_id", payload["entity_id"]), # Use user_id if available, fallback to entity_id ) except jwt.InvalidTokenError as e: raise HTTPException(status_code=401, detail=str(e)) @app.post("/ingest/text", response_model=Document) @telemetry.track(operation_type="ingest_text", metadata_resolver=telemetry.ingest_text_metadata) async def ingest_text( request: IngestTextRequest, auth: AuthContext = Depends(verify_token), ) -> Document: """ Ingest a text document. Args: request: IngestTextRequest containing: - content: Text content to ingest - filename: Optional filename to help determine content type - metadata: Optional metadata dictionary - rules: Optional list of rules. Each rule should be either: - MetadataExtractionRule: {"type": "metadata_extraction", "schema": {...}} - NaturalLanguageRule: {"type": "natural_language", "prompt": "..."} - folder_name: Optional folder to scope the document to - end_user_id: Optional end-user ID to scope the document to auth: Authentication context Returns: Document: Metadata of ingested document """ try: return await document_service.ingest_text( content=request.content, filename=request.filename, metadata=request.metadata, rules=request.rules, use_colpali=request.use_colpali, auth=auth, folder_name=request.folder_name, end_user_id=request.end_user_id, ) except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) # Redis pool for background tasks redis_pool = None def get_redis_pool(): """Get the global Redis connection pool for background tasks.""" return redis_pool @app.post("/ingest/file", response_model=Document) @telemetry.track(operation_type="queue_ingest_file", metadata_resolver=telemetry.ingest_file_metadata) async def ingest_file( file: UploadFile, metadata: str = Form("{}"), rules: str = Form("[]"), auth: AuthContext = Depends(verify_token), use_colpali: Optional[bool] = None, folder_name: Optional[str] = Form(None), end_user_id: Optional[str] = Form(None), redis: arq.ArqRedis = Depends(get_redis_pool), ) -> Document: """ Ingest a file document asynchronously. Args: file: File to ingest metadata: JSON string of metadata rules: JSON string of rules list. 
Each rule should be either: - MetadataExtractionRule: {"type": "metadata_extraction", "schema": {...}} - NaturalLanguageRule: {"type": "natural_language", "prompt": "..."} auth: Authentication context use_colpali: Whether to use ColPali embedding model folder_name: Optional folder to scope the document to end_user_id: Optional end-user ID to scope the document to redis: Redis connection pool for background tasks Returns: Document with processing status that can be used to check progress """ try: # Parse metadata and rules metadata_dict = json.loads(metadata) rules_list = json.loads(rules) # Fix bool conversion: ensure string "false" is properly converted to False def str2bool(v): return v if isinstance(v, bool) else str(v).lower() in {"true", "1", "yes"} use_colpali = str2bool(use_colpali) # Ensure user has write permission if "write" not in auth.permissions: raise PermissionError("User does not have write permission") logger.debug(f"API: Queueing file ingestion with use_colpali: {use_colpali}") # Create a document with processing status doc = Document( content_type=file.content_type, filename=file.filename, metadata=metadata_dict, owner={"type": auth.entity_type.value, "id": auth.entity_id}, access_control={ "readers": [auth.entity_id], "writers": [auth.entity_id], "admins": [auth.entity_id], "user_id": [auth.user_id] if auth.user_id else [], }, system_metadata={"status": "processing"}, ) # Add folder_name and end_user_id to system_metadata if provided if folder_name: doc.system_metadata["folder_name"] = folder_name if end_user_id: doc.system_metadata["end_user_id"] = end_user_id # Set processing status doc.system_metadata["status"] = "processing" # Store the document in the database success = await database.store_document(doc) if not success: raise Exception("Failed to store document metadata") # If folder_name is provided, ensure the folder exists and add document to it if folder_name: try: await document_service._ensure_folder_exists(folder_name, doc.external_id, auth) logger.debug(f"Ensured folder '{folder_name}' exists and contains document {doc.external_id}") except Exception as e: # Log error but don't raise - we want document ingestion to continue even if folder operation fails logger.error(f"Error ensuring folder exists: {e}") # Read file content file_content = await file.read() # Generate a unique key for the file file_key = f"ingest_uploads/{uuid.uuid4()}/{file.filename}" # Store the file in the configured storage file_content_base64 = base64.b64encode(file_content).decode() bucket, stored_key = await storage.upload_from_base64(file_content_base64, file_key, file.content_type) logger.debug(f"Stored file in bucket {bucket} with key {stored_key}") # Update document with storage info doc.storage_info = {"bucket": bucket, "key": stored_key} # Initialize storage_files array with the first file from datetime import UTC, datetime from core.models.documents import StorageFileInfo # Create a StorageFileInfo for the initial file initial_file_info = StorageFileInfo( bucket=bucket, key=stored_key, version=1, filename=file.filename, content_type=file.content_type, timestamp=datetime.now(UTC), ) doc.storage_files = [initial_file_info] # Log storage files logger.debug(f"Initial storage_files for {doc.external_id}: {doc.storage_files}") # Update both storage_info and storage_files await database.update_document( document_id=doc.external_id, updates={"storage_info": doc.storage_info, "storage_files": doc.storage_files}, auth=auth, ) # Convert auth context to a dictionary for serialization 
auth_dict = { "entity_type": auth.entity_type.value, "entity_id": auth.entity_id, "app_id": auth.app_id, "permissions": list(auth.permissions), "user_id": auth.user_id, } # Enqueue the background job job = await redis.enqueue_job( "process_ingestion_job", document_id=doc.external_id, file_key=stored_key, bucket=bucket, original_filename=file.filename, content_type=file.content_type, metadata_json=metadata, auth_dict=auth_dict, rules_list=rules_list, use_colpali=use_colpali, folder_name=folder_name, end_user_id=end_user_id, ) logger.info(f"File ingestion job queued with ID: {job.job_id} for document: {doc.external_id}") return doc except json.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}") except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) except Exception as e: logger.error(f"Error during file ingestion: {str(e)}") raise HTTPException(status_code=500, detail=f"Error during file ingestion: {str(e)}") @app.post("/ingest/files", response_model=BatchIngestResponse) @telemetry.track(operation_type="queue_batch_ingest", metadata_resolver=telemetry.batch_ingest_metadata) async def batch_ingest_files( files: List[UploadFile] = File(...), metadata: str = Form("{}"), rules: str = Form("[]"), use_colpali: Optional[bool] = Form(None), parallel: Optional[bool] = Form(True), folder_name: Optional[str] = Form(None), end_user_id: Optional[str] = Form(None), auth: AuthContext = Depends(verify_token), redis: arq.ArqRedis = Depends(get_redis_pool), ) -> BatchIngestResponse: """ Batch ingest multiple files using the task queue. Args: files: List of files to ingest metadata: JSON string of metadata (either a single dict or list of dicts) rules: JSON string of rules list. Can be either: - A single list of rules to apply to all files - A list of rule lists, one per file use_colpali: Whether to use ColPali-style embedding folder_name: Optional folder to scope the documents to end_user_id: Optional end-user ID to scope the documents to auth: Authentication context redis: Redis connection pool for background tasks Returns: BatchIngestResponse containing: - documents: List of created documents with processing status - errors: List of errors that occurred during the batch operation """ if not files: raise HTTPException(status_code=400, detail="No files provided for batch ingestion") try: metadata_value = json.loads(metadata) rules_list = json.loads(rules) # Fix bool conversion: ensure string "false" is properly converted to False def str2bool(v): return str(v).lower() in {"true", "1", "yes"} use_colpali = str2bool(use_colpali) # Ensure user has write permission if "write" not in auth.permissions: raise PermissionError("User does not have write permission") except json.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}") except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) # Validate metadata if it's a list if isinstance(metadata_value, list) and len(metadata_value) != len(files): raise HTTPException( status_code=400, detail=f"Number of metadata items ({len(metadata_value)}) must match number of files ({len(files)})", ) # Validate rules if it's a list of lists if isinstance(rules_list, list) and rules_list and isinstance(rules_list[0], list): if len(rules_list) != len(files): raise HTTPException( status_code=400, detail=f"Number of rule lists ({len(rules_list)}) must match number of files ({len(files)})", ) # Convert auth context to a dictionary for serialization auth_dict = { 
"entity_type": auth.entity_type.value, "entity_id": auth.entity_id, "app_id": auth.app_id, "permissions": list(auth.permissions), "user_id": auth.user_id, } created_documents = [] try: for i, file in enumerate(files): # Get the metadata and rules for this file metadata_item = metadata_value[i] if isinstance(metadata_value, list) else metadata_value file_rules = ( rules_list[i] if isinstance(rules_list, list) and rules_list and isinstance(rules_list[0], list) else rules_list ) # Create a document with processing status doc = Document( content_type=file.content_type, filename=file.filename, metadata=metadata_item, owner={"type": auth.entity_type.value, "id": auth.entity_id}, access_control={ "readers": [auth.entity_id], "writers": [auth.entity_id], "admins": [auth.entity_id], "user_id": [auth.user_id] if auth.user_id else [], }, ) # Add folder_name and end_user_id to system_metadata if provided if folder_name: doc.system_metadata["folder_name"] = folder_name if end_user_id: doc.system_metadata["end_user_id"] = end_user_id # Set processing status doc.system_metadata["status"] = "processing" # Store the document in the database success = await database.store_document(doc) if not success: raise Exception(f"Failed to store document metadata for {file.filename}") # If folder_name is provided, ensure the folder exists and add document to it if folder_name: try: await document_service._ensure_folder_exists(folder_name, doc.external_id, auth) logger.debug(f"Ensured folder '{folder_name}' exists and contains document {doc.external_id}") except Exception as e: # Log error but don't raise - we want document ingestion to continue even if folder operation fails logger.error(f"Error ensuring folder exists: {e}") # Read file content file_content = await file.read() # Generate a unique key for the file file_key = f"ingest_uploads/{uuid.uuid4()}/{file.filename}" # Store the file in the configured storage file_content_base64 = base64.b64encode(file_content).decode() bucket, stored_key = await storage.upload_from_base64(file_content_base64, file_key, file.content_type) logger.debug(f"Stored file in bucket {bucket} with key {stored_key}") # Update document with storage info doc.storage_info = {"bucket": bucket, "key": stored_key} await database.update_document( document_id=doc.external_id, updates={"storage_info": doc.storage_info}, auth=auth ) # Convert metadata to JSON string for job metadata_json = json.dumps(metadata_item) # Enqueue the background job job = await redis.enqueue_job( "process_ingestion_job", document_id=doc.external_id, file_key=stored_key, bucket=bucket, original_filename=file.filename, content_type=file.content_type, metadata_json=metadata_json, auth_dict=auth_dict, rules_list=file_rules, use_colpali=use_colpali, folder_name=folder_name, end_user_id=end_user_id, ) logger.info(f"File ingestion job queued with ID: {job.job_id} for document: {doc.external_id}") # Add document to the list created_documents.append(doc) # Return information about created documents return BatchIngestResponse(documents=created_documents, errors=[]) except Exception as e: logger.error(f"Error queueing batch file ingestion: {str(e)}") raise HTTPException(status_code=500, detail=f"Error queueing batch file ingestion: {str(e)}") @app.post("/retrieve/chunks", response_model=List[ChunkResult]) @telemetry.track(operation_type="retrieve_chunks", metadata_resolver=telemetry.retrieve_chunks_metadata) async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)): """ Retrieve relevant chunks. 
Args: request: RetrieveRequest containing: - query: Search query text - filters: Optional metadata filters - k: Number of results (default: 4) - min_score: Minimum similarity threshold (default: 0.0) - use_reranking: Whether to use reranking - use_colpali: Whether to use ColPali-style embedding model - folder_name: Optional folder to scope the search to - end_user_id: Optional end-user ID to scope the search to auth: Authentication context Returns: List[ChunkResult]: List of relevant chunks """ try: return await document_service.retrieve_chunks( request.query, auth, request.filters, request.k, request.min_score, request.use_reranking, request.use_colpali, request.folder_name, request.end_user_id, ) except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) @app.post("/retrieve/docs", response_model=List[DocumentResult]) @telemetry.track(operation_type="retrieve_docs", metadata_resolver=telemetry.retrieve_docs_metadata) async def retrieve_documents(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)): """ Retrieve relevant documents. Args: request: RetrieveRequest containing: - query: Search query text - filters: Optional metadata filters - k: Number of results (default: 4) - min_score: Minimum similarity threshold (default: 0.0) - use_reranking: Whether to use reranking - use_colpali: Whether to use ColPali-style embedding model - folder_name: Optional folder to scope the search to - end_user_id: Optional end-user ID to scope the search to auth: Authentication context Returns: List[DocumentResult]: List of relevant documents """ try: return await document_service.retrieve_docs( request.query, auth, request.filters, request.k, request.min_score, request.use_reranking, request.use_colpali, request.folder_name, request.end_user_id, ) except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) @app.post("/batch/documents", response_model=List[Document]) @telemetry.track(operation_type="batch_get_documents", metadata_resolver=telemetry.batch_documents_metadata) async def batch_get_documents(request: Dict[str, Any], auth: AuthContext = Depends(verify_token)): """ Retrieve multiple documents by their IDs in a single batch operation. Args: request: Dictionary containing: - document_ids: List of document IDs to retrieve - folder_name: Optional folder to scope the operation to - end_user_id: Optional end-user ID to scope the operation to auth: Authentication context Returns: List[Document]: List of documents matching the IDs """ try: # Extract document_ids from request document_ids = request.get("document_ids", []) folder_name = request.get("folder_name") end_user_id = request.get("end_user_id") if not document_ids: return [] return await document_service.batch_retrieve_documents(document_ids, auth, folder_name, end_user_id) except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) @app.post("/batch/chunks", response_model=List[ChunkResult]) @telemetry.track(operation_type="batch_get_chunks", metadata_resolver=telemetry.batch_chunks_metadata) async def batch_get_chunks(request: Dict[str, Any], auth: AuthContext = Depends(verify_token)): """ Retrieve specific chunks by their document ID and chunk number in a single batch operation. 
Args: request: Dictionary containing: - sources: List of ChunkSource objects (with document_id and chunk_number) - folder_name: Optional folder to scope the operation to - end_user_id: Optional end-user ID to scope the operation to auth: Authentication context Returns: List[ChunkResult]: List of chunk results """ try: # Extract sources from request sources = request.get("sources", []) folder_name = request.get("folder_name") end_user_id = request.get("end_user_id") use_colpali = request.get("use_colpali") if not sources: return [] # Convert sources to ChunkSource objects if needed chunk_sources = [] for source in sources: if isinstance(source, dict): chunk_sources.append(ChunkSource(**source)) else: chunk_sources.append(source) return await document_service.batch_retrieve_chunks(chunk_sources, auth, folder_name, end_user_id, use_colpali) except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) @app.post("/query", response_model=CompletionResponse) @telemetry.track(operation_type="query", metadata_resolver=telemetry.query_metadata) async def query_completion(request: CompletionQueryRequest, auth: AuthContext = Depends(verify_token)): """ Generate completion using relevant chunks as context. When graph_name is provided, the query will leverage the knowledge graph to enhance retrieval by finding relevant entities and their connected documents. Args: request: CompletionQueryRequest containing: - query: Query text - filters: Optional metadata filters - k: Number of chunks to use as context (default: 4) - min_score: Minimum similarity threshold (default: 0.0) - max_tokens: Maximum tokens in completion - temperature: Model temperature - use_reranking: Whether to use reranking - use_colpali: Whether to use ColPali-style embedding model - graph_name: Optional name of the graph to use for knowledge graph-enhanced retrieval - hop_depth: Number of relationship hops to traverse in the graph (1-3) - include_paths: Whether to include relationship paths in the response - prompt_overrides: Optional customizations for entity extraction, resolution, and query prompts - folder_name: Optional folder to scope the operation to - end_user_id: Optional end-user ID to scope the operation to - schema: Optional schema for structured output auth: Authentication context Returns: CompletionResponse: Generated text completion or structured output """ try: # Validate prompt overrides before proceeding if request.prompt_overrides: validate_prompt_overrides_with_http_exception(request.prompt_overrides, operation_type="query") # Check query limits if in cloud mode if settings.MODE == "cloud" and auth.user_id: # Check limits before proceeding await check_and_increment_limits(auth, "query", 1) return await document_service.query( request.query, auth, request.filters, request.k, request.min_score, request.max_tokens, request.temperature, request.use_reranking, request.use_colpali, request.graph_name, request.hop_depth, request.include_paths, request.prompt_overrides, request.folder_name, request.end_user_id, request.schema, ) except ValueError as e: validate_prompt_overrides_with_http_exception(operation_type="query", error=e) except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) @app.post("/documents", response_model=List[Document]) async def list_documents( auth: AuthContext = Depends(verify_token), skip: int = 0, limit: int = 10000, filters: Optional[Dict[str, Any]] = None, folder_name: Optional[str] = None, end_user_id: Optional[str] = None, ): """ List accessible 
documents. Args: auth: Authentication context skip: Number of documents to skip limit: Maximum number of documents to return filters: Optional metadata filters folder_name: Optional folder to scope the operation to end_user_id: Optional end-user ID to scope the operation to Returns: List[Document]: List of accessible documents """ # Create system filters for folder and user scoping system_filters = {} if folder_name: system_filters["folder_name"] = folder_name if end_user_id: system_filters["end_user_id"] = end_user_id return await document_service.db.get_documents(auth, skip, limit, filters, system_filters) @app.get("/documents/{document_id}", response_model=Document) async def get_document(document_id: str, auth: AuthContext = Depends(verify_token)): """Get document by ID.""" try: doc = await document_service.db.get_document(document_id, auth) logger.debug(f"Found document: {doc}") if not doc: raise HTTPException(status_code=404, detail="Document not found") return doc except HTTPException as e: logger.error(f"Error getting document: {e}") raise e @app.get("/documents/{document_id}/status", response_model=Dict[str, Any]) async def get_document_status(document_id: str, auth: AuthContext = Depends(verify_token)): """ Get the processing status of a document. Args: document_id: ID of the document to check auth: Authentication context Returns: Dict containing status information for the document """ try: doc = await document_service.db.get_document(document_id, auth) if not doc: raise HTTPException(status_code=404, detail="Document not found") # Extract status information status = doc.system_metadata.get("status", "unknown") response = { "document_id": doc.external_id, "status": status, "filename": doc.filename, "created_at": doc.system_metadata.get("created_at"), "updated_at": doc.system_metadata.get("updated_at"), } # Add error information if failed if status == "failed": response["error"] = doc.system_metadata.get("error", "Unknown error") return response except HTTPException: raise except Exception as e: logger.error(f"Error getting document status: {str(e)}") raise HTTPException(status_code=500, detail=f"Error getting document status: {str(e)}") @app.delete("/documents/{document_id}") @telemetry.track(operation_type="delete_document", metadata_resolver=telemetry.document_delete_metadata) async def delete_document(document_id: str, auth: AuthContext = Depends(verify_token)): """ Delete a document and all associated data. This endpoint deletes a document and all its associated data, including: - Document metadata - Document content in storage - Document chunks and embeddings in vector store Args: document_id: ID of the document to delete auth: Authentication context (must have write access to the document) Returns: Deletion status """ try: success = await document_service.delete_document(document_id, auth) if not success: raise HTTPException(status_code=404, detail="Document not found or delete failed") return {"status": "success", "message": f"Document {document_id} deleted successfully"} except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) @app.get("/documents/filename/{filename}", response_model=Document) async def get_document_by_filename( filename: str, auth: AuthContext = Depends(verify_token), folder_name: Optional[str] = None, end_user_id: Optional[str] = None, ): """ Get document by filename. 
Args: filename: Filename of the document to retrieve auth: Authentication context folder_name: Optional folder to scope the operation to end_user_id: Optional end-user ID to scope the operation to Returns: Document: Document metadata if found and accessible """ try: # Create system filters for folder and user scoping system_filters = {} if folder_name: system_filters["folder_name"] = folder_name if end_user_id: system_filters["end_user_id"] = end_user_id doc = await document_service.db.get_document_by_filename(filename, auth, system_filters) logger.debug(f"Found document by filename: {doc}") if not doc: raise HTTPException(status_code=404, detail=f"Document with filename '{filename}' not found") return doc except HTTPException as e: logger.error(f"Error getting document by filename: {e}") raise e @app.post("/documents/{document_id}/update_text", response_model=Document) @telemetry.track(operation_type="update_document_text", metadata_resolver=telemetry.document_update_text_metadata) async def update_document_text( document_id: str, request: IngestTextRequest, update_strategy: str = "add", auth: AuthContext = Depends(verify_token), ): """ Update a document with new text content using the specified strategy. Args: document_id: ID of the document to update request: Text content and metadata for the update update_strategy: Strategy for updating the document (default: 'add') Returns: Document: Updated document metadata """ try: doc = await document_service.update_document( document_id=document_id, auth=auth, content=request.content, file=None, filename=request.filename, metadata=request.metadata, rules=request.rules, update_strategy=update_strategy, use_colpali=request.use_colpali, ) if not doc: raise HTTPException(status_code=404, detail="Document not found or update failed") return doc except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) @app.post("/documents/{document_id}/update_file", response_model=Document) @telemetry.track(operation_type="update_document_file", metadata_resolver=telemetry.document_update_file_metadata) async def update_document_file( document_id: str, file: UploadFile, metadata: str = Form("{}"), rules: str = Form("[]"), update_strategy: str = Form("add"), use_colpali: Optional[bool] = None, auth: AuthContext = Depends(verify_token), ): """ Update a document with content from a file using the specified strategy. 
Args: document_id: ID of the document to update file: File to add to the document metadata: JSON string of metadata to merge with existing metadata rules: JSON string of rules to apply to the content update_strategy: Strategy for updating the document (default: 'add') use_colpali: Whether to use multi-vector embedding Returns: Document: Updated document metadata """ try: metadata_dict = json.loads(metadata) rules_list = json.loads(rules) doc = await document_service.update_document( document_id=document_id, auth=auth, content=None, file=file, filename=file.filename, metadata=metadata_dict, rules=rules_list, update_strategy=update_strategy, use_colpali=use_colpali, ) if not doc: raise HTTPException(status_code=404, detail="Document not found or update failed") return doc except json.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}") except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) @app.post("/documents/{document_id}/update_metadata", response_model=Document) @telemetry.track( operation_type="update_document_metadata", metadata_resolver=telemetry.document_update_metadata_resolver, ) async def update_document_metadata( document_id: str, metadata: Dict[str, Any], auth: AuthContext = Depends(verify_token) ): """ Update only a document's metadata. Args: document_id: ID of the document to update metadata: New metadata to merge with existing metadata Returns: Document: Updated document metadata """ try: doc = await document_service.update_document( document_id=document_id, auth=auth, content=None, file=None, filename=None, metadata=metadata, rules=[], update_strategy="add", use_colpali=None, ) if not doc: raise HTTPException(status_code=404, detail="Document not found or update failed") return doc except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) # Usage tracking endpoints @app.get("/usage/stats") @telemetry.track(operation_type="get_usage_stats", metadata_resolver=telemetry.usage_stats_metadata) async def get_usage_stats(auth: AuthContext = Depends(verify_token)) -> Dict[str, int]: """Get usage statistics for the authenticated user.""" if not auth.permissions or "admin" not in auth.permissions: return telemetry.get_user_usage(auth.entity_id) return telemetry.get_user_usage(auth.entity_id) @app.get("/usage/recent") @telemetry.track(operation_type="get_recent_usage", metadata_resolver=telemetry.recent_usage_metadata) async def get_recent_usage( auth: AuthContext = Depends(verify_token), operation_type: Optional[str] = None, since: Optional[datetime] = None, status: Optional[str] = None, ) -> List[Dict]: """Get recent usage records.""" if not auth.permissions or "admin" not in auth.permissions: records = telemetry.get_recent_usage( user_id=auth.entity_id, operation_type=operation_type, since=since, status=status ) else: records = telemetry.get_recent_usage(operation_type=operation_type, since=since, status=status) return [ { "timestamp": record.timestamp, "operation_type": record.operation_type, "tokens_used": record.tokens_used, "user_id": record.user_id, "duration_ms": record.duration_ms, "status": record.status, "metadata": record.metadata, } for record in records ] # Cache endpoints @app.post("/cache/create") @telemetry.track(operation_type="create_cache", metadata_resolver=telemetry.cache_create_metadata) async def create_cache( name: str, model: str, gguf_file: str, filters: Optional[Dict[str, Any]] = None, docs: Optional[List[str]] = None, auth: AuthContext = Depends(verify_token), ) -> 
Dict[str, Any]: """Create a new cache with specified configuration.""" try: # Check cache creation limits if in cloud mode if settings.MODE == "cloud" and auth.user_id: # Check limits before proceeding await check_and_increment_limits(auth, "cache", 1) filter_docs = set(await document_service.db.get_documents(auth, filters=filters)) additional_docs = ( {await document_service.db.get_document(document_id=doc_id, auth=auth) for doc_id in docs} if docs else set() ) docs_to_add = list(filter_docs.union(additional_docs)) if not docs_to_add: raise HTTPException(status_code=400, detail="No documents to add to cache") response = await document_service.create_cache(name, model, gguf_file, docs_to_add, filters) return response except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) @app.get("/cache/{name}") @telemetry.track(operation_type="get_cache", metadata_resolver=telemetry.cache_get_metadata) async def get_cache(name: str, auth: AuthContext = Depends(verify_token)) -> Dict[str, Any]: """Get cache configuration by name.""" try: exists = await document_service.load_cache(name) return {"exists": exists} except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) @app.post("/cache/{name}/update") @telemetry.track(operation_type="update_cache", metadata_resolver=telemetry.cache_update_metadata) async def update_cache(name: str, auth: AuthContext = Depends(verify_token)) -> Dict[str, bool]: """Update cache with new documents matching its filter.""" try: if name not in document_service.active_caches: exists = await document_service.load_cache(name) if not exists: raise HTTPException(status_code=404, detail=f"Cache '{name}' not found") cache = document_service.active_caches[name] docs = await document_service.db.get_documents(auth, filters=cache.filters) docs_to_add = [doc for doc in docs if doc.id not in cache.docs] return cache.add_docs(docs_to_add) except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) @app.post("/cache/{name}/add_docs") @telemetry.track(operation_type="add_docs_to_cache", metadata_resolver=telemetry.cache_add_docs_metadata) async def add_docs_to_cache(name: str, docs: List[str], auth: AuthContext = Depends(verify_token)) -> Dict[str, bool]: """Add specific documents to the cache.""" try: cache = document_service.active_caches[name] docs_to_add = [ await document_service.db.get_document(doc_id, auth) for doc_id in docs if doc_id not in cache.docs ] return cache.add_docs(docs_to_add) except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) @app.post("/cache/{name}/query") @telemetry.track(operation_type="query_cache", metadata_resolver=telemetry.cache_query_metadata) async def query_cache( name: str, query: str, max_tokens: Optional[int] = None, temperature: Optional[float] = None, auth: AuthContext = Depends(verify_token), ) -> CompletionResponse: """Query the cache with a prompt.""" try: # Check cache query limits if in cloud mode if settings.MODE == "cloud" and auth.user_id: # Check limits before proceeding await check_and_increment_limits(auth, "cache_query", 1) cache = document_service.active_caches[name] logger.info(f"Cache state: {cache.state.n_tokens}") return cache.query(query) # , max_tokens, temperature) except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) @app.post("/graph/create", response_model=Graph) @telemetry.track(operation_type="create_graph", metadata_resolver=telemetry.create_graph_metadata) async def create_graph( request: 
CreateGraphRequest, auth: AuthContext = Depends(verify_token), ) -> Graph: """ Create a graph from documents. This endpoint extracts entities and relationships from documents matching the specified filters or document IDs and creates a graph. Args: request: CreateGraphRequest containing: - name: Name of the graph to create - filters: Optional metadata filters to determine which documents to include - documents: Optional list of specific document IDs to include - prompt_overrides: Optional customizations for entity extraction and resolution prompts - folder_name: Optional folder to scope the operation to - end_user_id: Optional end-user ID to scope the operation to auth: Authentication context Returns: Graph: The created graph object """ try: # Validate prompt overrides before proceeding if request.prompt_overrides: validate_prompt_overrides_with_http_exception(request.prompt_overrides, operation_type="graph") # Check graph creation limits if in cloud mode if settings.MODE == "cloud" and auth.user_id: # Check limits before proceeding await check_and_increment_limits(auth, "graph", 1) # Create system filters for folder and user scoping system_filters = {} if request.folder_name: system_filters["folder_name"] = request.folder_name if request.end_user_id: system_filters["end_user_id"] = request.end_user_id return await document_service.create_graph( name=request.name, auth=auth, filters=request.filters, documents=request.documents, prompt_overrides=request.prompt_overrides, system_filters=system_filters, ) except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) except ValueError as e: validate_prompt_overrides_with_http_exception(operation_type="graph", error=e) @app.post("/folders", response_model=Folder) async def create_folder( folder_create: FolderCreate, auth: AuthContext = Depends(verify_token), ) -> Folder: """ Create a new folder. Args: folder_create: Folder creation request containing name and optional description auth: Authentication context Returns: Folder: Created folder """ try: async with telemetry.track_operation( operation_type="create_folder", user_id=auth.entity_id, metadata={ "name": folder_create.name, }, ): # Create a folder object with explicit ID import uuid folder_id = str(uuid.uuid4()) logger.info(f"Creating folder with ID: {folder_id}, auth.user_id: {auth.user_id}") # Set up access control with user_id access_control = { "readers": [auth.entity_id], "writers": [auth.entity_id], "admins": [auth.entity_id], } if auth.user_id: access_control["user_id"] = [auth.user_id] logger.info(f"Adding user_id {auth.user_id} to folder access control") folder = Folder( id=folder_id, name=folder_create.name, description=folder_create.description, owner={ "type": auth.entity_type.value, "id": auth.entity_id, }, access_control=access_control, ) # Store in database success = await document_service.db.create_folder(folder) if not success: raise HTTPException(status_code=500, detail="Failed to create folder") return folder except Exception as e: logger.error(f"Error creating folder: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.get("/folders", response_model=List[Folder]) async def list_folders( auth: AuthContext = Depends(verify_token), ) -> List[Folder]: """ List all folders the user has access to. 
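# ----------------------------------------------------------------------------
# Editor's note: an illustrative sketch (not part of api.py) of driving the
# folder endpoints above over HTTP. The base URL and token are assumptions;
# the request body mirrors FolderCreate (a name plus an optional description).
import httpx


def _example_create_and_list_folders(bearer_token: str) -> list:
    """Sketch: create a folder, then list all folders visible to the caller."""
    headers = {"Authorization": f"Bearer {bearer_token}"}
    created = httpx.post(
        "http://127.0.0.1:8000/folders",
        headers=headers,
        json={"name": "quarterly-reports", "description": "Finance PDFs"},
    )
    created.raise_for_status()
    listing = httpx.get("http://127.0.0.1:8000/folders", headers=headers)
    listing.raise_for_status()
    return listing.json()
# ----------------------------------------------------------------------------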
Args: auth: Authentication context Returns: List[Folder]: List of folders """ try: async with telemetry.track_operation( operation_type="list_folders", user_id=auth.entity_id, ): folders = await document_service.db.list_folders(auth) return folders except Exception as e: logger.error(f"Error listing folders: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.get("/folders/{folder_id}", response_model=Folder) async def get_folder( folder_id: str, auth: AuthContext = Depends(verify_token), ) -> Folder: """ Get a folder by ID. Args: folder_id: ID of the folder auth: Authentication context Returns: Folder: Folder if found and accessible """ try: async with telemetry.track_operation( operation_type="get_folder", user_id=auth.entity_id, metadata={ "folder_id": folder_id, }, ): folder = await document_service.db.get_folder(folder_id, auth) if not folder: raise HTTPException(status_code=404, detail=f"Folder {folder_id} not found") return folder except HTTPException: raise except Exception as e: logger.error(f"Error getting folder: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/folders/{folder_id}/documents/{document_id}") async def add_document_to_folder( folder_id: str, document_id: str, auth: AuthContext = Depends(verify_token), ): """ Add a document to a folder. Args: folder_id: ID of the folder document_id: ID of the document auth: Authentication context Returns: Success status """ try: async with telemetry.track_operation( operation_type="add_document_to_folder", user_id=auth.entity_id, metadata={ "folder_id": folder_id, "document_id": document_id, }, ): success = await document_service.db.add_document_to_folder(folder_id, document_id, auth) if not success: raise HTTPException(status_code=500, detail="Failed to add document to folder") return {"status": "success"} except Exception as e: logger.error(f"Error adding document to folder: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.delete("/folders/{folder_id}/documents/{document_id}") async def remove_document_from_folder( folder_id: str, document_id: str, auth: AuthContext = Depends(verify_token), ): """ Remove a document from a folder. Args: folder_id: ID of the folder document_id: ID of the document auth: Authentication context Returns: Success status """ try: async with telemetry.track_operation( operation_type="remove_document_from_folder", user_id=auth.entity_id, metadata={ "folder_id": folder_id, "document_id": document_id, }, ): success = await document_service.db.remove_document_from_folder(folder_id, document_id, auth) if not success: raise HTTPException(status_code=500, detail="Failed to remove document from folder") return {"status": "success"} except Exception as e: logger.error(f"Error removing document from folder: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.get("/graph/{name}", response_model=Graph) @telemetry.track(operation_type="get_graph", metadata_resolver=telemetry.get_graph_metadata) async def get_graph( name: str, auth: AuthContext = Depends(verify_token), folder_name: Optional[str] = None, end_user_id: Optional[str] = None, ) -> Graph: """ Get a graph by name. This endpoint retrieves a graph by its name if the user has access to it. 
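# ----------------------------------------------------------------------------
# Editor's note: an illustrative sketch (not part of api.py) for the folder
# membership endpoints above. The base URL and IDs are placeholder assumptions.
import httpx


def _example_move_document(folder_id: str, document_id: str, bearer_token: str) -> None:
    """Sketch: attach a document to a folder, then detach it again."""
    headers = {"Authorization": f"Bearer {bearer_token}"}
    base = "http://127.0.0.1:8000"
    httpx.post(f"{base}/folders/{folder_id}/documents/{document_id}", headers=headers).raise_for_status()
    httpx.delete(f"{base}/folders/{folder_id}/documents/{document_id}", headers=headers).raise_for_status()
# ----------------------------------------------------------------------------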
Args: name: Name of the graph to retrieve auth: Authentication context folder_name: Optional folder to scope the operation to end_user_id: Optional end-user ID to scope the operation to Returns: Graph: The requested graph object """ try: # Create system filters for folder and user scoping system_filters = {} if folder_name: system_filters["folder_name"] = folder_name if end_user_id: system_filters["end_user_id"] = end_user_id graph = await document_service.db.get_graph(name, auth, system_filters) if not graph: raise HTTPException(status_code=404, detail=f"Graph '{name}' not found") return graph except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.get("/graphs", response_model=List[Graph]) @telemetry.track(operation_type="list_graphs", metadata_resolver=telemetry.list_graphs_metadata) async def list_graphs( auth: AuthContext = Depends(verify_token), folder_name: Optional[str] = None, end_user_id: Optional[str] = None, ) -> List[Graph]: """ List all graphs the user has access to. This endpoint retrieves all graphs the user has access to. Args: auth: Authentication context folder_name: Optional folder to scope the operation to end_user_id: Optional end-user ID to scope the operation to Returns: List[Graph]: List of graph objects """ try: # Create system filters for folder and user scoping system_filters = {} if folder_name: system_filters["folder_name"] = folder_name if end_user_id: system_filters["end_user_id"] = end_user_id return await document_service.db.list_graphs(auth, system_filters) except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.post("/graph/{name}/update", response_model=Graph) @telemetry.track(operation_type="update_graph", metadata_resolver=telemetry.update_graph_metadata) async def update_graph( name: str, request: UpdateGraphRequest, auth: AuthContext = Depends(verify_token), ) -> Graph: """ Update an existing graph with new documents. This endpoint processes additional documents based on the original graph filters and/or new filters/document IDs, extracts entities and relationships, and updates the graph with new information. 
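# ----------------------------------------------------------------------------
# Editor's note: an illustrative sketch (not part of api.py) showing how the
# folder_name / end_user_id query parameters scope graph reads, matching the
# system_filters built in get_graph and list_graphs above. URL, token, graph
# name, and scope values are assumed.
import httpx


def _example_scoped_graph_reads(bearer_token: str) -> dict:
    """Sketch: fetch one graph and list all graphs within a folder/end-user scope."""
    headers = {"Authorization": f"Bearer {bearer_token}"}
    params = {"folder_name": "quarterly-reports", "end_user_id": "user-42"}
    graph = httpx.get("http://127.0.0.1:8000/graph/finance-graph", headers=headers, params=params)
    graph.raise_for_status()
    graphs = httpx.get("http://127.0.0.1:8000/graphs", headers=headers, params=params)
    graphs.raise_for_status()
    return {"graph": graph.json(), "all_graphs": graphs.json()}
# ----------------------------------------------------------------------------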
Args: name: Name of the graph to update request: UpdateGraphRequest containing: - additional_filters: Optional additional metadata filters to determine which new documents to include - additional_documents: Optional list of additional document IDs to include - prompt_overrides: Optional customizations for entity extraction and resolution prompts - folder_name: Optional folder to scope the operation to - end_user_id: Optional end-user ID to scope the operation to auth: Authentication context Returns: Graph: The updated graph object """ try: # Validate prompt overrides before proceeding if request.prompt_overrides: validate_prompt_overrides_with_http_exception(request.prompt_overrides, operation_type="graph") # Create system filters for folder and user scoping system_filters = {} if request.folder_name: system_filters["folder_name"] = request.folder_name if request.end_user_id: system_filters["end_user_id"] = request.end_user_id return await document_service.update_graph( name=name, auth=auth, additional_filters=request.additional_filters, additional_documents=request.additional_documents, prompt_overrides=request.prompt_overrides, system_filters=system_filters, ) except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) except ValueError as e: validate_prompt_overrides_with_http_exception(operation_type="graph", error=e) except Exception as e: logger.error(f"Error updating graph: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/local/generate_uri", include_in_schema=True) async def generate_local_uri( name: str = Form("admin"), expiry_days: int = Form(30), ) -> Dict[str, str]: """Generate a local URI for development. This endpoint is unprotected.""" try: # Clean name name = name.replace(" ", "_").lower() # Create payload payload = { "type": "developer", "entity_id": name, "permissions": ["read", "write", "admin"], "exp": datetime.now(UTC) + timedelta(days=expiry_days), } # Generate token token = jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM) # Read config for host/port with open("morphik.toml", "rb") as f: config = tomli.load(f) base_url = f"{config['api']['host']}:{config['api']['port']}".replace("localhost", "127.0.0.1") # Generate URI uri = f"morphik://{name}:{token}@{base_url}" return {"uri": uri} except Exception as e: logger.error(f"Error generating local URI: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/cloud/generate_uri", include_in_schema=True) async def generate_cloud_uri( request: GenerateUriRequest, authorization: str = Header(None), ) -> Dict[str, str]: """Generate a URI for cloud hosted applications.""" try: app_id = request.app_id name = request.name user_id = request.user_id expiry_days = request.expiry_days logger.debug(f"Generating cloud URI for app_id={app_id}, name={name}, user_id={user_id}") # Verify authorization header before proceeding if not authorization: logger.warning("Missing authorization header") raise HTTPException( status_code=401, detail="Missing authorization header", headers={"WWW-Authenticate": "Bearer"}, ) # Verify the token is valid if not authorization.startswith("Bearer "): raise HTTPException(status_code=401, detail="Invalid authorization header") token = authorization[7:] # Remove "Bearer " try: # Decode the token to ensure it's valid payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]) # Only allow users to create apps for themselves (or admin) token_user_id = payload.get("user_id") logger.debug(f"Token user 
ID: {token_user_id}") logger.debug(f"User ID: {user_id}") if not (token_user_id == user_id or "admin" in payload.get("permissions", [])): raise HTTPException( status_code=403, detail="You can only create apps for your own account unless you have admin permissions", ) except jwt.InvalidTokenError as e: raise HTTPException(status_code=401, detail=str(e)) # Import UserService here to avoid circular imports from core.services.user_service import UserService user_service = UserService() # Initialize user service if needed await user_service.initialize() # Clean name name = name.replace(" ", "_").lower() # Check if the user is within app limit and generate URI uri = await user_service.generate_cloud_uri(user_id, app_id, name, expiry_days) if not uri: logger.debug("Application limit reached for this account tier with user_id: %s", user_id) raise HTTPException(status_code=403, detail="Application limit reached for this account tier") return {"uri": uri, "app_id": app_id} except HTTPException: # Re-raise HTTP exceptions raise except Exception as e: logger.error(f"Error generating cloud URI: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/folders/{folder_id}/set_rule") @telemetry.track(operation_type="set_folder_rule", metadata_resolver=telemetry.set_folder_rule_metadata) async def set_folder_rule( folder_id: str, request: SetFolderRuleRequest, auth: AuthContext = Depends(verify_token), apply_to_existing: bool = True, ): """ Set extraction rules for a folder. Args: folder_id: ID of the folder to set rules for request: SetFolderRuleRequest containing metadata extraction rules auth: Authentication context apply_to_existing: Whether to apply rules to existing documents in the folder Returns: Success status with processing results """ # Import text here to ensure it's available in this function's scope from sqlalchemy import text try: # Log detailed information about the rules logger.debug(f"Setting rules for folder {folder_id}") logger.debug(f"Number of rules: {len(request.rules)}") for i, rule in enumerate(request.rules): logger.debug(f"\nRule {i + 1}:") logger.debug(f"Type: {rule.type}") logger.debug("Schema:") for field_name, field_config in rule.schema.items(): logger.debug(f" Field: {field_name}") logger.debug(f" Type: {field_config.get('type', 'unknown')}") logger.debug(f" Description: {field_config.get('description', 'No description')}") if "schema" in field_config: logger.debug(" Has JSON schema: Yes") logger.debug(f" Schema: {field_config['schema']}") # Get the folder folder = await document_service.db.get_folder(folder_id, auth) if not folder: raise HTTPException(status_code=404, detail=f"Folder {folder_id} not found") # Check if user has write access to the folder if not document_service.db._check_folder_access(folder, auth, "write"): raise HTTPException(status_code=403, detail="You don't have write access to this folder") # Update folder with rules # Convert rules to dicts for JSON serialization rules_dicts = [rule.model_dump() for rule in request.rules] # Update the folder in the database async with document_service.db.async_session() as session: # Execute update query await session.execute( text( """ UPDATE folders SET rules = :rules WHERE id = :folder_id """ ), {"folder_id": folder_id, "rules": json.dumps(rules_dicts)}, ) await session.commit() logger.info(f"Successfully updated folder {folder_id} with {len(request.rules)} rules") # Get updated folder updated_folder = await document_service.db.get_folder(folder_id, auth) # If apply_to_existing is True, apply these 
rules to all existing documents in the folder processing_results = {"processed": 0, "errors": []} if apply_to_existing and folder.document_ids: logger.info(f"Applying rules to {len(folder.document_ids)} existing documents in folder") # Import rules processor # Get all documents in the folder documents = await document_service.db.get_documents_by_id(folder.document_ids, auth) # Process each document for doc in documents: try: # Get document content logger.info(f"Processing document {doc.external_id}") # For each document, apply the rules from the folder doc_content = None # Get content from system_metadata if available if doc.system_metadata and "content" in doc.system_metadata: doc_content = doc.system_metadata["content"] logger.info(f"Retrieved content from system_metadata for document {doc.external_id}") # If we still have no content, log error and continue if not doc_content: error_msg = f"No content found in system_metadata for document {doc.external_id}" logger.error(error_msg) processing_results["errors"].append({"document_id": doc.external_id, "error": error_msg}) continue # Process document with rules try: # Convert request rules to actual rule models and apply them from core.models.rules import MetadataExtractionRule for rule_request in request.rules: if rule_request.type == "metadata_extraction": # Create the actual rule model rule = MetadataExtractionRule(type=rule_request.type, schema=rule_request.schema) # Apply the rule with retries max_retries = 3 base_delay = 1 # seconds extracted_metadata = None last_error = None for retry_count in range(max_retries): try: if retry_count > 0: # Exponential backoff delay = base_delay * (2 ** (retry_count - 1)) logger.info(f"Retry {retry_count}/{max_retries} after {delay}s delay") await asyncio.sleep(delay) extracted_metadata, _ = await rule.apply(doc_content, {}) logger.info( f"Successfully extracted metadata on attempt {retry_count + 1}: " f"{extracted_metadata}" ) break # Success, exit retry loop except Exception as rule_apply_error: last_error = rule_apply_error logger.warning( f"Metadata extraction attempt {retry_count + 1} failed: " f"{rule_apply_error}" ) if retry_count == max_retries - 1: # Last attempt logger.error(f"All {max_retries} metadata extraction attempts failed") processing_results["errors"].append( { "document_id": doc.external_id, "error": f"Failed to extract metadata after {max_retries} " f"attempts: {str(last_error)}", } ) continue # Skip to next document # Update document metadata if extraction succeeded if extracted_metadata: # Merge new metadata with existing doc.metadata.update(extracted_metadata) # Create an updates dict that only updates metadata # We need to create system_metadata with all preserved fields # Note: In the database, metadata is stored as 'doc_metadata', not 'metadata' updates = { "doc_metadata": doc.metadata, # Use doc_metadata for the database "system_metadata": {}, # Will be merged with existing in update_document } # Explicitly preserve the content field in system_metadata if "content" in doc.system_metadata: updates["system_metadata"]["content"] = doc.system_metadata["content"] # Log the updates we're making logger.info( f"Updating document {doc.external_id} with metadata: {extracted_metadata}" ) logger.info(f"Full metadata being updated: {doc.metadata}") logger.info(f"Update object being sent to database: {updates}") logger.info( f"Preserving content in system_metadata: {'content' in doc.system_metadata}" ) # Update document in database success = await 
document_service.db.update_document(doc.external_id, updates, auth) if success: logger.info(f"Updated metadata for document {doc.external_id}") processing_results["processed"] += 1 else: logger.error(f"Failed to update metadata for document {doc.external_id}") processing_results["errors"].append( { "document_id": doc.external_id, "error": "Failed to update document metadata", } ) except Exception as rule_error: logger.error(f"Error processing rules for document {doc.external_id}: {rule_error}") processing_results["errors"].append( { "document_id": doc.external_id, "error": f"Error processing rules: {str(rule_error)}", } ) except Exception as doc_error: logger.error(f"Error processing document {doc.external_id}: {doc_error}") processing_results["errors"].append({"document_id": doc.external_id, "error": str(doc_error)}) return { "status": "success", "message": "Rules set successfully", "folder_id": folder_id, "rules": updated_folder.rules, "processing_results": processing_results, } except HTTPException: # Re-raise HTTP exceptions raise except Exception as e: logger.error(f"Error setting folder rules: {e}") raise HTTPException(status_code=500, detail=str(e)) ``` ## /core/cache/base_cache.py ```py path="/core/cache/base_cache.py" from abc import ABC, abstractmethod from typing import Any, Dict, List from core.models.completion import CompletionResponse from core.models.documents import Document class BaseCache(ABC): """Base class for cache implementations. This class defines the interface for cache implementations that support document ingestion and cache-augmented querying. """ def __init__(self, name: str, model: str, gguf_file: str, filters: Dict[str, Any], docs: List[Document]): """Initialize the cache with the given parameters. Args: name: Name of the cache instance model: Model identifier gguf_file: Path to the GGUF model file filters: Filters used to create the cache context docs: Initial documents to ingest into the cache """ self.name = name self.filters = filters self.docs = [] # List of document IDs that have been ingested self._initialize(model, gguf_file, docs) @abstractmethod def _initialize(self, model: str, gguf_file: str, docs: List[Document]) -> None: """Internal initialization method to be implemented by subclasses.""" pass @abstractmethod async def add_docs(self, docs: List[Document]) -> bool: """Add documents to the cache. Args: docs: List of documents to add to the cache Returns: bool: True if documents were successfully added """ pass @abstractmethod async def query(self, query: str) -> CompletionResponse: """Query the cache for relevant documents and generate a response. Args: query: Query string to search for relevant documents Returns: CompletionResponse: Generated response based on cached context """ pass @property @abstractmethod def saveable_state(self) -> bytes: """Get the saveable state of the cache as bytes. Returns: bytes: Serialized state that can be used to restore the cache """ pass ``` ## /core/cache/base_cache_factory.py ```py path="/core/cache/base_cache_factory.py" from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Dict from .base_cache import BaseCache class BaseCacheFactory(ABC): """Abstract base factory for creating and loading caches.""" def __init__(self, storage_path: Path): """Initialize the cache factory. 
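# ----------------------------------------------------------------------------
# Editor's note: a toy, in-memory BaseCache subclass (not part of this module)
# to illustrate the contract defined in core/cache/base_cache.py above. It
# performs no real KV caching and simply echoes the stored document IDs; treat
# it as a sketch of the interface, not a working cache backend.
import pickle
from typing import List

from core.cache.base_cache import BaseCache
from core.models.completion import CompletionResponse
from core.models.documents import Document


class EchoCache(BaseCache):
    """Minimal illustration of the BaseCache interface."""

    def _initialize(self, model: str, gguf_file: str, docs: List[Document]) -> None:
        # No model is loaded; we only remember which documents were "ingested".
        self.docs = [doc.external_id for doc in docs]

    async def add_docs(self, docs: List[Document]) -> bool:
        self.docs.extend(doc.external_id for doc in docs)
        return True

    async def query(self, query: str) -> CompletionResponse:
        completion = f"(echo) {query} -- cached docs: {', '.join(self.docs)}"
        return CompletionResponse(
            completion=completion,
            usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        )

    @property
    def saveable_state(self) -> bytes:
        return pickle.dumps(self.docs)
# ----------------------------------------------------------------------------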
Args: storage_path: Base path for storing cache files """ self.storage_path = storage_path self.storage_path.mkdir(parents=True, exist_ok=True) @abstractmethod def create_new_cache(self, name: str, model: str, model_file: str, **kwargs: Dict[str, Any]) -> BaseCache: """Create a new cache instance. Args: name: Name of the cache model: Name/type of the model to use model_file: Path or identifier for the model file **kwargs: Additional arguments for cache creation Returns: BaseCache: The created cache instance """ pass @abstractmethod def load_cache_from_bytes( self, name: str, cache_bytes: bytes, metadata: Dict[str, Any], **kwargs: Dict[str, Any] ) -> BaseCache: """Load a cache from its serialized bytes. Args: name: Name of the cache cache_bytes: Serialized cache data metadata: Cache metadata including model info **kwargs: Additional arguments for cache loading Returns: BaseCache: The loaded cache instance """ pass def get_cache_path(self, name: str) -> Path: """Get the storage path for a cache. Args: name: Name of the cache Returns: Path: Directory path for the cache """ path = self.storage_path / name path.mkdir(parents=True, exist_ok=True) return path ``` ## /core/cache/hf_cache.py ```py path="/core/cache/hf_cache.py" # hugging face cache implementation. from pathlib import Path from typing import List, Optional, Union import torch from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.cache_utils import DynamicCache from core.cache.base_cache import BaseCache from core.models.completion import CompletionRequest, CompletionResponse class HuggingFaceCache(BaseCache): """Hugging Face Cache implementation for cache-augmented generation""" def __init__( self, cache_path: Path, model_name: str = "distilgpt2", device: str = "cpu", default_max_new_tokens: int = 100, use_fp16: bool = False, ): """Initialize the HuggingFace cache. Args: cache_path: Path to store cache files model_name: Name of the HuggingFace model to use device: Device to run the model on (e.g. 
"cpu", "cuda", "mps") default_max_new_tokens: Default maximum number of new tokens to generate use_fp16: Whether to use FP16 precision """ super().__init__() self.cache_path = cache_path self.model_name = model_name self.device = device self.default_max_new_tokens = default_max_new_tokens self.use_fp16 = use_fp16 # Initialize tokenizer and model self.tokenizer = AutoTokenizer.from_pretrained(model_name) # Configure model loading based on device model_kwargs = {"low_cpu_mem_usage": True} if device == "cpu": # For CPU, use standard loading model_kwargs.update({"torch_dtype": torch.float32}) self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs).to(device) else: # For GPU/MPS, use automatic device mapping and optional FP16 model_kwargs.update({"device_map": "auto", "torch_dtype": torch.float16 if use_fp16 else torch.float32}) self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) self.kv_cache = None self.origin_len = None def get_kv_cache(self, prompt: str) -> DynamicCache: """Build KV cache from prompt""" input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) cache = DynamicCache() with torch.no_grad(): _ = self.model(input_ids=input_ids, past_key_values=cache, use_cache=True) return cache def clean_up_cache(self, cache: DynamicCache, origin_len: int): """Clean up cache by removing appended tokens""" for i in range(len(cache.key_cache)): cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :] cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :] def generate(self, input_ids: torch.Tensor, past_key_values, max_new_tokens: Optional[int] = None) -> torch.Tensor: """Generate text using the model and cache""" device = next(self.model.parameters()).device origin_len = input_ids.shape[-1] input_ids = input_ids.to(device) output_ids = input_ids.clone() next_token = input_ids with torch.no_grad(): for _ in range(max_new_tokens or self.default_max_new_tokens): out = self.model(input_ids=next_token, past_key_values=past_key_values, use_cache=True) logits = out.logits[:, -1, :] token = torch.argmax(logits, dim=-1, keepdim=True) output_ids = torch.cat([output_ids, token], dim=-1) past_key_values = out.past_key_values next_token = token.to(device) if self.model.config.eos_token_id is not None and token.item() == self.model.config.eos_token_id: break return output_ids[:, origin_len:] async def ingest(self, docs: List[str]) -> bool: """Ingest documents into cache""" try: # Create system prompt with documents system_prompt = f""" <|system|> You are an assistant who provides concise factual answers. 
<|user|> Context: {' '.join(docs)} Question: """.strip() # Build the cache input_ids = self.tokenizer(system_prompt, return_tensors="pt").input_ids.to(self.device) self.kv_cache = DynamicCache() with torch.no_grad(): # First run to get the cache shape outputs = self.model(input_ids=input_ids, use_cache=True) # Initialize cache with empty tensors of the right shape n_layers = len(outputs.past_key_values) batch_size = input_ids.shape[0] # Handle different model architectures if hasattr(self.model.config, "num_key_value_heads"): # Models with grouped query attention (GQA) like Llama n_kv_heads = self.model.config.num_key_value_heads head_dim = self.model.config.head_dim elif hasattr(self.model.config, "n_head"): # GPT-style models n_kv_heads = self.model.config.n_head head_dim = self.model.config.n_embd // self.model.config.n_head elif hasattr(self.model.config, "num_attention_heads"): # OPT-style models n_kv_heads = self.model.config.num_attention_heads head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads else: raise ValueError(f"Unsupported model architecture: {self.model.config.model_type}") seq_len = input_ids.shape[1] for i in range(n_layers): key_shape = (batch_size, n_kv_heads, seq_len, head_dim) value_shape = key_shape self.kv_cache.key_cache.append(torch.zeros(key_shape, device=self.device)) self.kv_cache.value_cache.append(torch.zeros(value_shape, device=self.device)) # Now run with the initialized cache outputs = self.model(input_ids=input_ids, past_key_values=self.kv_cache, use_cache=True) # Update cache with actual values self.kv_cache.key_cache = [layer[0] for layer in outputs.past_key_values] self.kv_cache.value_cache = [layer[1] for layer in outputs.past_key_values] self.origin_len = self.kv_cache.key_cache[0].shape[-2] return True except Exception as e: print(f"Error ingesting documents: {e}") return False async def update(self, new_doc: str) -> bool: """Update cache with new document""" try: if self.kv_cache is None: return await self.ingest([new_doc]) # Clean up existing cache self.clean_up_cache(self.kv_cache, self.origin_len) # Add new document to cache input_ids = self.tokenizer(new_doc + "\n", return_tensors="pt").input_ids.to(self.device) # First run to get the cache shape outputs = self.model(input_ids=input_ids, use_cache=True) # Initialize cache with empty tensors of the right shape n_layers = len(outputs.past_key_values) batch_size = input_ids.shape[0] # Handle different model architectures if hasattr(self.model.config, "num_key_value_heads"): # Models with grouped query attention (GQA) like Llama n_kv_heads = self.model.config.num_key_value_heads head_dim = self.model.config.head_dim elif hasattr(self.model.config, "n_head"): # GPT-style models n_kv_heads = self.model.config.n_head head_dim = self.model.config.n_embd // self.model.config.n_head elif hasattr(self.model.config, "num_attention_heads"): # OPT-style models n_kv_heads = self.model.config.num_attention_heads head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads else: raise ValueError(f"Unsupported model architecture: {self.model.config.model_type}") seq_len = input_ids.shape[1] # Create a new cache for the update new_cache = DynamicCache() for i in range(n_layers): key_shape = (batch_size, n_kv_heads, seq_len, head_dim) value_shape = key_shape new_cache.key_cache.append(torch.zeros(key_shape, device=self.device)) new_cache.value_cache.append(torch.zeros(value_shape, device=self.device)) # Run with the initialized cache outputs = 
self.model(input_ids=input_ids, past_key_values=new_cache, use_cache=True) # Update cache with actual values self.kv_cache.key_cache = [layer[0] for layer in outputs.past_key_values] self.kv_cache.value_cache = [layer[1] for layer in outputs.past_key_values] return True except Exception as e: print(f"Error updating cache: {e}") return False async def complete(self, request: CompletionRequest) -> CompletionResponse: """Generate completion using cache-augmented generation""" try: if self.kv_cache is None: raise ValueError("Cache not initialized. Please ingest documents first.") # Clean up cache self.clean_up_cache(self.kv_cache, self.origin_len) # Generate completion input_ids = self.tokenizer(request.query + "\n", return_tensors="pt").input_ids.to(self.device) gen_ids = self.generate(input_ids, self.kv_cache, max_new_tokens=request.max_tokens) completion = self.tokenizer.decode(gen_ids[0], skip_special_tokens=True) # Calculate token usage usage = { "prompt_tokens": len(input_ids[0]), "completion_tokens": len(gen_ids[0]), "total_tokens": len(input_ids[0]) + len(gen_ids[0]), } return CompletionResponse(completion=completion, usage=usage) except Exception as e: print(f"Error generating completion: {e}") return CompletionResponse( completion=f"Error: {str(e)}", usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, ) def save_cache(self) -> Path: """Save the KV cache to disk""" if self.kv_cache is None: raise ValueError("No cache to save") cache_dir = self.cache_path / "kv_cache" cache_dir.mkdir(parents=True, exist_ok=True) # Save key and value caches cache_data = { "key_cache": self.kv_cache.key_cache, "value_cache": self.kv_cache.value_cache, "origin_len": self.origin_len, } cache_path = cache_dir / "cache.pt" torch.save(cache_data, cache_path) return cache_path def load_cache(self, cache_path: Union[str, Path]) -> None: """Load KV cache from disk""" cache_path = Path(cache_path) if not cache_path.exists(): raise FileNotFoundError(f"Cache file not found at {cache_path}") cache_data = torch.load(cache_path, map_location=self.device) self.kv_cache = DynamicCache() self.kv_cache.key_cache = cache_data["key_cache"] self.kv_cache.value_cache = cache_data["value_cache"] self.origin_len = cache_data["origin_len"] ``` ## /core/cache/llama_cache.py ```py path="/core/cache/llama_cache.py" import json import logging import pickle from typing import Any, Dict, List from llama_cpp import Llama from core.cache.base_cache import BaseCache from core.models.completion import CompletionResponse from core.models.documents import Document logger = logging.getLogger(__name__) INITIAL_SYSTEM_PROMPT = """<|im_start|>system You are a helpful AI assistant with access to provided documents. Your role is to: 1. Answer questions accurately based on the documents provided 2. Stay focused on the document content and avoid speculation 3. Admit when you don't have enough information to answer 4. Be clear and concise in your responses 5. Use direct quotes from documents when relevant Provided documents: {documents} <|im_end|> """.strip() ADD_DOC_SYSTEM_PROMPT = """<|im_start|>system I'm adding some additional documents for your reference: {documents} Please incorporate this new information along with what you already know from previous documents while maintaining the same guidelines for responses. 
<|im_end|> """.strip() QUERY_PROMPT = """<|im_start|>user {query} <|im_end|> <|im_start|>assistant """.strip() class LlamaCache(BaseCache): def __init__( self, name: str, model: str, gguf_file: str, filters: Dict[str, Any], docs: List[Document], **kwargs, ): logger.info(f"Initializing LlamaCache with name={name}, model={model}") # cache related self.name = name self.model = model self.filters = filters self.docs = docs # llama specific self.gguf_file = gguf_file self.n_gpu_layers = kwargs.get("n_gpu_layers", -1) logger.info(f"Using {self.n_gpu_layers} GPU layers") # late init (when we call _initialize) self.llama = None self.state = None self.cached_tokens = 0 self._initialize(model, gguf_file, docs) logger.info("LlamaCache initialization complete") def _initialize(self, model: str, gguf_file: str, docs: List[Document]) -> None: logger.info(f"Loading Llama model from {model} with file {gguf_file}") try: # Set a reasonable default context size (32K tokens) default_ctx_size = 32768 self.llama = Llama.from_pretrained( repo_id=model, filename=gguf_file, n_gpu_layers=self.n_gpu_layers, n_ctx=default_ctx_size, verbose=False, # Enable verbose mode for better error reporting ) logger.info("Model loaded successfully") # Format and tokenize system prompt documents = "\n".join(doc.system_metadata.get("content", "") for doc in docs) system_prompt = INITIAL_SYSTEM_PROMPT.format(documents=documents) logger.info(f"Built system prompt: {system_prompt[:200]}...") try: tokens = self.llama.tokenize(system_prompt.encode()) logger.info(f"System prompt tokenized to {len(tokens)} tokens") # Process tokens to build KV cache logger.info("Evaluating system prompt") self.llama.eval(tokens) logger.info("Saving initial KV cache state") self.state = self.llama.save_state() self.cached_tokens = len(tokens) logger.info(f"Initial KV cache built with {self.cached_tokens} tokens") except Exception as e: logger.error(f"Error during prompt processing: {str(e)}") raise ValueError(f"Failed to process system prompt: {str(e)}") except Exception as e: logger.error(f"Failed to initialize Llama model: {str(e)}") raise ValueError(f"Failed to initialize Llama model: {str(e)}") def add_docs(self, docs: List[Document]) -> bool: logger.info(f"Adding {len(docs)} new documents to cache") documents = "\n".join(doc.system_metadata.get("content", "") for doc in docs) system_prompt = ADD_DOC_SYSTEM_PROMPT.format(documents=documents) # Tokenize and process new_tokens = self.llama.tokenize(system_prompt.encode()) self.llama.eval(new_tokens) self.state = self.llama.save_state() self.cached_tokens += len(new_tokens) logger.info(f"Added {len(new_tokens)} tokens, total: {self.cached_tokens}") return True def query(self, query: str) -> CompletionResponse: # Format query with proper chat template formatted_query = QUERY_PROMPT.format(query=query) logger.info(f"Processing query: {formatted_query}") # Reset and load cached state self.llama.reset() self.llama.load_state(self.state) logger.info(f"Loaded state with {self.state.n_tokens} tokens") # print(f"Loaded state with {self.state.n_tokens} tokens", file=sys.stderr) # Tokenize and process query query_tokens = self.llama.tokenize(formatted_query.encode()) self.llama.eval(query_tokens) logger.info(f"Evaluated query tokens: {query_tokens}") # print(f"Evaluated query tokens: {query_tokens}", file=sys.stderr) # Generate response output_tokens = [] for token in self.llama.generate(tokens=[], reset=False): output_tokens.append(token) # Stop generation when EOT token is encountered if token == 
self.llama.token_eos(): break # Decode and return completion = self.llama.detokenize(output_tokens).decode() logger.info(f"Generated completion: {completion}") return CompletionResponse( completion=completion, usage={"prompt_tokens": self.cached_tokens, "completion_tokens": len(output_tokens)}, ) @property def saveable_state(self) -> bytes: logger.info("Serializing cache state") state_bytes = pickle.dumps(self.state) logger.info(f"Serialized state size: {len(state_bytes)} bytes") return state_bytes @classmethod def from_bytes(cls, name: str, cache_bytes: bytes, metadata: Dict[str, Any], **kwargs) -> "LlamaCache": """Load a cache from its serialized state. Args: name: Name of the cache cache_bytes: Pickled state bytes metadata: Cache metadata including model info **kwargs: Additional arguments Returns: LlamaCache: Loaded cache instance """ logger.info(f"Loading cache from bytes with name={name}") logger.info(f"Cache metadata: {metadata}") # Create new instance with metadata # logger.info(f"Docs: {metadata['docs']}") docs = [json.loads(doc) for doc in metadata["docs"]] # time.sleep(10) cache = cls( name=name, model=metadata["model"], gguf_file=metadata["model_file"], filters=metadata["filters"], docs=[Document(**doc) for doc in docs], ) # Load the saved state logger.info(f"Loading saved KV cache state of size {len(cache_bytes)} bytes") cache.state = pickle.loads(cache_bytes) cache.llama.load_state(cache.state) logger.info("Cache successfully loaded from bytes") return cache ``` ## /core/cache/llama_cache_factory.py ```py path="/core/cache/llama_cache_factory.py" from typing import Any, Dict from core.cache.base_cache_factory import BaseCacheFactory from core.cache.llama_cache import LlamaCache class LlamaCacheFactory(BaseCacheFactory): def create_new_cache(self, name: str, model: str, model_file: str, **kwargs: Dict[str, Any]) -> LlamaCache: return LlamaCache(name, model, model_file, **kwargs) def load_cache_from_bytes( self, name: str, cache_bytes: bytes, metadata: Dict[str, Any], **kwargs: Dict[str, Any] ) -> LlamaCache: return LlamaCache.from_bytes(name, cache_bytes, metadata, **kwargs) ``` ## /core/completion/__init__.py ```py path="/core/completion/__init__.py" from core.completion.base_completion import BaseCompletionModel from core.completion.litellm_completion import LiteLLMCompletionModel __all__ = ["BaseCompletionModel", "LiteLLMCompletionModel"] ``` ## /core/completion/base_completion.py ```py path="/core/completion/base_completion.py" from abc import ABC, abstractmethod from core.models.completion import CompletionRequest, CompletionResponse class BaseCompletionModel(ABC): """Base class for completion models""" @abstractmethod async def complete(self, request: CompletionRequest) -> CompletionResponse: """Generate completion from query and context""" pass ``` ## /core/completion/litellm_completion.py ```py path="/core/completion/litellm_completion.py" import logging import re # Import re for parsing model name from typing import Any, Dict, List, Optional, Tuple, Union import litellm try: import ollama except ImportError: ollama = None # Make ollama import optional from pydantic import BaseModel from core.config import get_settings from core.models.completion import CompletionRequest, CompletionResponse from .base_completion import BaseCompletionModel logger = logging.getLogger(__name__) def get_system_message() -> Dict[str, str]: """Return the standard system message for Morphik's query agent.""" return { "role": "system", "content": """You are Morphik's powerful query agent. 
Your role is to: 1. Analyze the provided context chunks from documents carefully 2. Use the context to answer questions accurately and comprehensively 3. Be clear and concise in your answers 4. When relevant, cite specific parts of the context to support your answers 5. For image-based queries, analyze the visual content in conjunction with any text context provided Remember: Your primary goal is to provide accurate, context-aware responses that help users understand and utilize the information in their documents effectively.""", } def process_context_chunks(context_chunks: List[str], is_ollama: bool) -> Tuple[List[str], List[str], List[str]]: """ Process context chunks and separate text from images. Args: context_chunks: List of context chunks which may include images is_ollama: Whether we're using Ollama (affects image processing) Returns: Tuple of (context_text, image_urls, ollama_image_data) """ context_text = [] image_urls = [] # For non-Ollama models (full data URI) ollama_image_data = [] # For Ollama models (raw base64) for chunk in context_chunks: if chunk.startswith("data:image/"): if is_ollama: # For Ollama, strip the data URI prefix and just keep the base64 data try: base64_data = chunk.split(",", 1)[1] ollama_image_data.append(base64_data) except IndexError: logger.warning(f"Could not parse base64 data from image chunk: {chunk[:50]}...") else: image_urls.append(chunk) else: context_text.append(chunk) return context_text, image_urls, ollama_image_data def format_user_content(context_text: List[str], query: str, prompt_template: Optional[str] = None) -> str: """ Format the user content based on context and query. Args: context_text: List of context text chunks query: The user query prompt_template: Optional template to format the content Returns: Formatted user content string """ context = "\n" + "\n\n".join(context_text) + "\n\n" if context_text else "" if prompt_template: return prompt_template.format( context=context, question=query, query=query, ) elif context_text: return f"Context: {context} Question: {query}" else: return query def create_dynamic_model_from_schema(schema: Union[type, Dict]) -> Optional[type]: """ Create a dynamic Pydantic model from a schema definition. 
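# ----------------------------------------------------------------------------
# Editor's note: a small standalone sketch (not part of this module) of the
# JSON-schema -> Pydantic mapping that create_dynamic_model_from_schema
# performs below: each entry under "properties" becomes an optional field on a
# dynamically built model. The schema, field names, and values here are
# invented examples.
from pydantic import create_model

_example_schema = {
    "type": "object",
    "properties": {
        "title": {"type": "string"},
        "year": {"type": "integer"},
        "tags": {"type": "array"},
    },
}

# Hand-built equivalent of what the function would produce for the schema
# above: every field is optional and defaults to None.
ExampleQueryModel = create_model(
    "ExampleQueryModel",
    title=(str, None),
    year=(int, None),
    tags=(list, None),
)

# e.g. ExampleQueryModel.model_validate_json('{"title": "Q3 report", "year": 2024}')
# yields ExampleQueryModel(title='Q3 report', year=2024, tags=None)
# ----------------------------------------------------------------------------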
Args: schema: Either a Pydantic BaseModel class or a JSON schema dict Returns: A Pydantic model class or None if schema format is not recognized """ from pydantic import create_model if isinstance(schema, type) and issubclass(schema, BaseModel): return schema elif isinstance(schema, dict) and "properties" in schema: # Create a dynamic model from JSON schema field_definitions = {} schema_dict = schema for field_name, field_info in schema_dict.get("properties", {}).items(): if isinstance(field_info, dict) and "type" in field_info: field_type = field_info.get("type") # Convert schema types to Python types if field_type == "string": field_definitions[field_name] = (str, None) elif field_type == "number": field_definitions[field_name] = (float, None) elif field_type == "integer": field_definitions[field_name] = (int, None) elif field_type == "boolean": field_definitions[field_name] = (bool, None) elif field_type == "array": field_definitions[field_name] = (list, None) elif field_type == "object": field_definitions[field_name] = (dict, None) else: # Default to Any for unknown types field_definitions[field_name] = (Any, None) # Create the dynamic model return create_model("DynamicQueryModel", **field_definitions) else: logger.warning(f"Unrecognized schema format: {schema}") return None class LiteLLMCompletionModel(BaseCompletionModel): """ LiteLLM completion model implementation that provides unified access to various LLM providers. Uses registered models from the config file. Can optionally use direct Ollama client. """ def __init__(self, model_key: str): """ Initialize LiteLLM completion model with a model key from registered_models. Args: model_key: The key of the model in the registered_models config """ settings = get_settings() self.model_key = model_key # Get the model configuration from registered_models if not hasattr(settings, "REGISTERED_MODELS") or model_key not in settings.REGISTERED_MODELS: raise ValueError(f"Model '{model_key}' not found in registered_models configuration") self.model_config = settings.REGISTERED_MODELS[model_key] # Check if it's an Ollama model for potential direct usage self.is_ollama = "ollama" in self.model_config.get("model_name", "").lower() self.ollama_api_base = None self.ollama_base_model_name = None if self.is_ollama: if ollama is None: logger.warning("Ollama model selected, but 'ollama' library not installed. Falling back to LiteLLM.") self.is_ollama = False # Fallback to LiteLLM if library missing else: self.ollama_api_base = self.model_config.get("api_base") if not self.ollama_api_base: logger.warning( f"Ollama model {self.model_key} selected for direct use, " "but 'api_base' is missing in config. Falling back to LiteLLM." ) self.is_ollama = False # Fallback if api_base is missing else: # Extract base model name (e.g., 'llama3.2' from 'ollama_chat/llama3.2') match = re.search(r"[^/]+$", self.model_config["model_name"]) if match: self.ollama_base_model_name = match.group(0) else: logger.warning( f"Could not parse base model name from Ollama model " f"{self.model_config['model_name']}. Falling back to LiteLLM." 
) self.is_ollama = False # Fallback if name parsing fails logger.info( f"Initialized LiteLLM completion model with model_key={model_key}, " f"config={self.model_config}, is_ollama_direct={self.is_ollama}" ) async def _handle_structured_ollama( self, dynamic_model: type, system_message: Dict[str, str], user_content: str, ollama_image_data: List[str], request: CompletionRequest, ) -> CompletionResponse: """Handle structured output generation with Ollama.""" try: client = ollama.AsyncClient(host=self.ollama_api_base) # Add images directly to content if available content_data = user_content if ollama_image_data and len(ollama_image_data) > 0: # Ollama image handling is limited; we can use only the first image content_data = {"content": user_content, "images": [ollama_image_data[0]]} # Create messages for Ollama messages = [system_message, {"role": "user", "content": content_data}] # Get the JSON schema from the dynamic model format_schema = dynamic_model.model_json_schema() # Call Ollama directly with format parameter response = await client.chat( model=self.ollama_base_model_name, messages=messages, format=format_schema, options={ "temperature": request.temperature or 0.1, # Lower temperature for structured output "num_predict": request.max_tokens, }, ) # Parse the response into the dynamic model parsed_response = dynamic_model.model_validate_json(response["message"]["content"]) # Extract token usage information usage = { "prompt_tokens": response.get("prompt_eval_count", 0), "completion_tokens": response.get("eval_count", 0), "total_tokens": response.get("prompt_eval_count", 0) + response.get("eval_count", 0), } return CompletionResponse( completion=parsed_response, usage=usage, finish_reason=response.get("done_reason", "stop"), ) except Exception as e: logger.error(f"Error using Ollama for structured output: {e}") # Fall back to standard completion if structured output fails logger.warning("Falling back to standard Ollama completion without structured output") return None async def _handle_structured_litellm( self, dynamic_model: type, system_message: Dict[str, str], user_content: str, image_urls: List[str], request: CompletionRequest, ) -> CompletionResponse: """Handle structured output generation with LiteLLM.""" import instructor from instructor import Mode try: # Use instructor with litellm client = instructor.from_litellm(litellm.acompletion, mode=Mode.JSON) # Create content list with text and images content_list = [{"type": "text", "text": user_content}] # Add images if available if image_urls: NUM_IMAGES = min(3, len(image_urls)) for img_url in image_urls[:NUM_IMAGES]: content_list.append({"type": "image_url", "image_url": {"url": img_url}}) # Create messages for instructor messages = [system_message, {"role": "user", "content": content_list}] # Extract model configuration model = self.model_config.get("model_name") model_kwargs = {k: v for k, v in self.model_config.items() if k != "model_name"} # Override with completion request parameters if request.temperature is not None: model_kwargs["temperature"] = request.temperature if request.max_tokens is not None: model_kwargs["max_tokens"] = request.max_tokens # Add format forcing for structured output model_kwargs["response_format"] = {"type": "json_object"} # Call instructor with litellm response = await client.chat.completions.create( model=model, messages=messages, response_model=dynamic_model, **model_kwargs, ) # Get token usage from response completion_tokens = model_kwargs.get("response_tokens", 0) prompt_tokens = 
model_kwargs.get("prompt_tokens", 0) return CompletionResponse( completion=response, usage={ "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens, }, finish_reason="stop", ) except Exception as e: logger.error(f"Error using instructor with LiteLLM: {e}") # Fall back to standard completion if instructor fails logger.warning("Falling back to standard LiteLLM completion without structured output") return None async def _handle_standard_ollama( self, user_content: str, ollama_image_data: List[str], request: CompletionRequest ) -> CompletionResponse: """Handle standard (non-structured) output generation with Ollama.""" logger.debug(f"Using direct Ollama client for model: {self.ollama_base_model_name}") client = ollama.AsyncClient(host=self.ollama_api_base) # Construct Ollama messages system_message = {"role": "system", "content": get_system_message()["content"]} user_message_data = {"role": "user", "content": user_content} # Add images directly to the user message if available if ollama_image_data: if len(ollama_image_data) > 1: logger.warning( f"Ollama model {self.model_config['model_name']} only supports one image per message. " "Using the first image and ignoring others." ) # Add 'images' key inside the user message dictionary user_message_data["images"] = [ollama_image_data[0]] ollama_messages = [system_message, user_message_data] # Construct Ollama options options = { "temperature": request.temperature, "num_predict": ( request.max_tokens if request.max_tokens is not None else -1 ), # Default to model's default if None } try: response = await client.chat(model=self.ollama_base_model_name, messages=ollama_messages, options=options) # Map Ollama response to CompletionResponse prompt_tokens = response.get("prompt_eval_count", 0) completion_tokens = response.get("eval_count", 0) return CompletionResponse( completion=response["message"]["content"], usage={ "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens, }, finish_reason=response.get("done_reason", "unknown"), # Map done_reason if available ) except Exception as e: logger.error(f"Error during direct Ollama call: {e}") raise async def _handle_standard_litellm( self, user_content: str, image_urls: List[str], request: CompletionRequest ) -> CompletionResponse: """Handle standard (non-structured) output generation with LiteLLM.""" logger.debug(f"Using LiteLLM for model: {self.model_config['model_name']}") # Build messages for LiteLLM content_list = [{"type": "text", "text": user_content}] include_images = image_urls # Use the collected full data URIs if include_images: NUM_IMAGES = min(3, len(image_urls)) for img_url in image_urls[:NUM_IMAGES]: content_list.append({"type": "image_url", "image_url": {"url": img_url}}) # LiteLLM uses list content format user_message = {"role": "user", "content": content_list} # Use the system prompt defined earlier litellm_messages = [get_system_message(), user_message] # Prepare LiteLLM parameters model_params = { "model": self.model_config["model_name"], "messages": litellm_messages, "max_tokens": request.max_tokens, "temperature": request.temperature, "num_retries": 3, } for key, value in self.model_config.items(): if key != "model_name": model_params[key] = value logger.debug(f"Calling LiteLLM with params: {model_params}") response = await litellm.acompletion(**model_params) return CompletionResponse( completion=response.choices[0].message.content, usage={ "prompt_tokens": 
response.usage.prompt_tokens, "completion_tokens": response.usage.completion_tokens, "total_tokens": response.usage.total_tokens, }, finish_reason=response.choices[0].finish_reason, ) async def complete(self, request: CompletionRequest) -> CompletionResponse: """ Generate completion using LiteLLM or direct Ollama client if configured. Args: request: CompletionRequest object containing query, context, and parameters Returns: CompletionResponse object with the generated text and usage statistics """ # Process context chunks and handle images context_text, image_urls, ollama_image_data = process_context_chunks(request.context_chunks, self.is_ollama) # Format user content user_content = format_user_content(context_text, request.query, request.prompt_template) # Check if structured output is requested structured_output = request.schema is not None # If structured output is requested, use instructor to handle it if structured_output: # Get dynamic model from schema dynamic_model = create_dynamic_model_from_schema(request.schema) # If schema format is not recognized, log warning and fall back to text completion if not dynamic_model: logger.warning(f"Unrecognized schema format: {request.schema}. Falling back to text completion.") structured_output = False else: logger.info(f"Using structured output with model: {dynamic_model.__name__}") # Create system and user messages with enhanced instructions for structured output system_message = { "role": "system", "content": get_system_message()["content"] + "\n\nYou MUST format your response according to the required schema.", } # Create enhanced user message that includes schema information enhanced_user_content = ( user_content + "\n\nPlease format your response according to the required schema." ) # Try structured output based on model type if self.is_ollama: response = await self._handle_structured_ollama( dynamic_model, system_message, enhanced_user_content, ollama_image_data, request ) if response: return response structured_output = False # Fall back if structured output failed else: response = await self._handle_structured_litellm( dynamic_model, system_message, enhanced_user_content, image_urls, request ) if response: return response structured_output = False # Fall back if structured output failed # If we're here, either structured output wasn't requested or instructor failed # Proceed with standard completion based on model type if self.is_ollama: return await self._handle_standard_ollama(user_content, ollama_image_data, request) else: return await self._handle_standard_litellm(user_content, image_urls, request) ``` ## /core/config.py ```py path="/core/config.py" import os from collections import ChainMap from functools import lru_cache from typing import Any, Dict, Literal, Optional import tomli from dotenv import load_dotenv from pydantic_settings import BaseSettings class Settings(BaseSettings): """Morphik configuration settings.""" # Environment variables JWT_SECRET_KEY: str POSTGRES_URI: Optional[str] = None UNSTRUCTURED_API_KEY: Optional[str] = None AWS_ACCESS_KEY: Optional[str] = None AWS_SECRET_ACCESS_KEY: Optional[str] = None OPENAI_API_KEY: Optional[str] = None ANTHROPIC_API_KEY: Optional[str] = None ASSEMBLYAI_API_KEY: Optional[str] = None # API configuration HOST: str PORT: int RELOAD: bool # Auth configuration JWT_ALGORITHM: str dev_mode: bool = False dev_entity_type: str = "developer" dev_entity_id: str = "dev_user" dev_permissions: list = ["read", "write", "admin"] # Registered models configuration REGISTERED_MODELS: Dict[str, 
Dict[str, Any]] = {} # Completion configuration COMPLETION_PROVIDER: Literal["litellm"] = "litellm" COMPLETION_MODEL: str # Database configuration DATABASE_PROVIDER: Literal["postgres"] DATABASE_NAME: Optional[str] = None # Database connection pool settings DB_POOL_SIZE: int = 20 DB_MAX_OVERFLOW: int = 30 DB_POOL_RECYCLE: int = 3600 DB_POOL_TIMEOUT: int = 10 DB_POOL_PRE_PING: bool = True DB_MAX_RETRIES: int = 3 DB_RETRY_DELAY: float = 1.0 # Embedding configuration EMBEDDING_PROVIDER: Literal["litellm"] = "litellm" EMBEDDING_MODEL: str VECTOR_DIMENSIONS: int EMBEDDING_SIMILARITY_METRIC: Literal["cosine", "dotProduct"] # Parser configuration CHUNK_SIZE: int CHUNK_OVERLAP: int USE_UNSTRUCTURED_API: bool FRAME_SAMPLE_RATE: Optional[int] = None USE_CONTEXTUAL_CHUNKING: bool = False # Rules configuration RULES_PROVIDER: Literal["litellm"] = "litellm" RULES_MODEL: str RULES_BATCH_SIZE: int = 4096 # Graph configuration GRAPH_PROVIDER: Literal["litellm"] = "litellm" GRAPH_MODEL: str ENABLE_ENTITY_RESOLUTION: bool = True # Reranker configuration USE_RERANKING: bool RERANKER_PROVIDER: Optional[Literal["flag"]] = None RERANKER_MODEL: Optional[str] = None RERANKER_QUERY_MAX_LENGTH: Optional[int] = None RERANKER_PASSAGE_MAX_LENGTH: Optional[int] = None RERANKER_USE_FP16: Optional[bool] = None RERANKER_DEVICE: Optional[str] = None # Storage configuration STORAGE_PROVIDER: Literal["local", "aws-s3"] STORAGE_PATH: Optional[str] = None AWS_REGION: Optional[str] = None S3_BUCKET: Optional[str] = None # Vector store configuration VECTOR_STORE_PROVIDER: Literal["pgvector"] VECTOR_STORE_DATABASE_NAME: Optional[str] = None # Colpali configuration ENABLE_COLPALI: bool # Mode configuration MODE: Literal["cloud", "self_hosted"] = "cloud" # API configuration API_DOMAIN: str = "api.morphik.ai" # Redis configuration REDIS_HOST: str = "localhost" REDIS_PORT: int = 6379 # Telemetry configuration TELEMETRY_ENABLED: bool = True HONEYCOMB_ENABLED: bool = True HONEYCOMB_ENDPOINT: str = "https://api.honeycomb.io" HONEYCOMB_PROXY_ENDPOINT: str = "https://otel-proxy.onrender.com/" SERVICE_NAME: str = "morphik-core" OTLP_TIMEOUT: int = 10 OTLP_MAX_RETRIES: int = 3 OTLP_RETRY_DELAY: int = 1 OTLP_MAX_EXPORT_BATCH_SIZE: int = 512 OTLP_SCHEDULE_DELAY_MILLIS: int = 5000 OTLP_MAX_QUEUE_SIZE: int = 2048 @lru_cache() def get_settings() -> Settings: """Get cached settings instance.""" load_dotenv(override=True) # Load config.toml with open("morphik.toml", "rb") as f: config = tomli.load(f) em = "'{missing_value}' needed if '{field}' is set to '{value}'" openai_config = {} # load api config api_config = { "HOST": config["api"]["host"], "PORT": int(config["api"]["port"]), "RELOAD": bool(config["api"]["reload"]), } # load auth config auth_config = { "JWT_ALGORITHM": config["auth"]["jwt_algorithm"], "JWT_SECRET_KEY": os.environ.get("JWT_SECRET_KEY", "dev-secret-key"), # Default for dev mode "dev_mode": config["auth"].get("dev_mode", False), "dev_entity_type": config["auth"].get("dev_entity_type", "developer"), "dev_entity_id": config["auth"].get("dev_entity_id", "dev_user"), "dev_permissions": config["auth"].get("dev_permissions", ["read", "write", "admin"]), } # Only require JWT_SECRET_KEY in non-dev mode if not auth_config["dev_mode"] and "JWT_SECRET_KEY" not in os.environ: raise ValueError("JWT_SECRET_KEY is required when dev_mode is disabled") # Load registered models if available registered_models = {} if "registered_models" in config: registered_models = {"REGISTERED_MODELS": config["registered_models"]} # load completion config 
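    # Illustrative sketch (not taken from the repository) of the morphik.toml sections the
    # lookups above and below expect; section and key names mirror the config[...] accesses,
    # while the concrete values here are assumptions:
    #
    #   [registered_models.openai_gpt4]   # any name; the table is stored verbatim in REGISTERED_MODELS
    #   model_name = "gpt-4"              # extra keys are passed through to the completion model config
    #
    #   [completion]
    #   model = "openai_gpt4"             # required; a missing 'model' key raises ValueError below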
completion_config = { "COMPLETION_PROVIDER": "litellm", } # Set the model key for LiteLLM if "model" not in config["completion"]: raise ValueError("'model' is required in the completion configuration") completion_config["COMPLETION_MODEL"] = config["completion"]["model"] # load database config database_config = { "DATABASE_PROVIDER": config["database"]["provider"], "DATABASE_NAME": config["database"].get("name", None), # Add database connection pool settings "DB_POOL_SIZE": config["database"].get("pool_size", 20), "DB_MAX_OVERFLOW": config["database"].get("max_overflow", 30), "DB_POOL_RECYCLE": config["database"].get("pool_recycle", 3600), "DB_POOL_TIMEOUT": config["database"].get("pool_timeout", 10), "DB_POOL_PRE_PING": config["database"].get("pool_pre_ping", True), "DB_MAX_RETRIES": config["database"].get("max_retries", 3), "DB_RETRY_DELAY": config["database"].get("retry_delay", 1.0), } if database_config["DATABASE_PROVIDER"] != "postgres": prov = database_config["DATABASE_PROVIDER"] raise ValueError(f"Unknown database provider selected: '{prov}'") if "POSTGRES_URI" in os.environ: database_config.update({"POSTGRES_URI": os.environ["POSTGRES_URI"]}) else: msg = em.format(missing_value="POSTGRES_URI", field="database.provider", value="postgres") raise ValueError(msg) # load embedding config embedding_config = { "EMBEDDING_PROVIDER": "litellm", "VECTOR_DIMENSIONS": config["embedding"]["dimensions"], "EMBEDDING_SIMILARITY_METRIC": config["embedding"]["similarity_metric"], } # Set the model key for LiteLLM if "model" not in config["embedding"]: raise ValueError("'model' is required in the embedding configuration") embedding_config["EMBEDDING_MODEL"] = config["embedding"]["model"] # load parser config parser_config = { "CHUNK_SIZE": config["parser"]["chunk_size"], "CHUNK_OVERLAP": config["parser"]["chunk_overlap"], "USE_UNSTRUCTURED_API": config["parser"]["use_unstructured_api"], "USE_CONTEXTUAL_CHUNKING": config["parser"].get("use_contextual_chunking", False), } if parser_config["USE_UNSTRUCTURED_API"] and "UNSTRUCTURED_API_KEY" not in os.environ: msg = em.format(missing_value="UNSTRUCTURED_API_KEY", field="parser.use_unstructured_api", value="true") raise ValueError(msg) elif parser_config["USE_UNSTRUCTURED_API"]: parser_config.update({"UNSTRUCTURED_API_KEY": os.environ["UNSTRUCTURED_API_KEY"]}) # load reranker config reranker_config = {"USE_RERANKING": config["reranker"]["use_reranker"]} if reranker_config["USE_RERANKING"]: reranker_config.update( { "RERANKER_PROVIDER": config["reranker"]["provider"], "RERANKER_MODEL": config["reranker"]["model_name"], "RERANKER_QUERY_MAX_LENGTH": config["reranker"]["query_max_length"], "RERANKER_PASSAGE_MAX_LENGTH": config["reranker"]["passage_max_length"], "RERANKER_USE_FP16": config["reranker"]["use_fp16"], "RERANKER_DEVICE": config["reranker"]["device"], } ) # load storage config storage_config = { "STORAGE_PROVIDER": config["storage"]["provider"], "STORAGE_PATH": config["storage"]["storage_path"], } match storage_config["STORAGE_PROVIDER"]: case "local": storage_config.update({"STORAGE_PATH": config["storage"]["storage_path"]}) case "aws-s3" if all(key in os.environ for key in ["AWS_ACCESS_KEY", "AWS_SECRET_ACCESS_KEY"]): storage_config.update( { "AWS_REGION": config["storage"]["region"], "S3_BUCKET": config["storage"]["bucket_name"], "AWS_ACCESS_KEY": os.environ["AWS_ACCESS_KEY"], "AWS_SECRET_ACCESS_KEY": os.environ["AWS_SECRET_ACCESS_KEY"], } ) case "aws-s3": msg = em.format(missing_value="AWS credentials", field="storage.provider", value="aws-s3") 
raise ValueError(msg) case _: prov = storage_config["STORAGE_PROVIDER"] raise ValueError(f"Unknown storage provider selected: '{prov}'") # load vector store config vector_store_config = {"VECTOR_STORE_PROVIDER": config["vector_store"]["provider"]} if vector_store_config["VECTOR_STORE_PROVIDER"] != "pgvector": prov = vector_store_config["VECTOR_STORE_PROVIDER"] raise ValueError(f"Unknown vector store provider selected: '{prov}'") if "POSTGRES_URI" not in os.environ: msg = em.format(missing_value="POSTGRES_URI", field="vector_store.provider", value="pgvector") raise ValueError(msg) # load rules config rules_config = { "RULES_PROVIDER": "litellm", "RULES_BATCH_SIZE": config["rules"]["batch_size"], } # Set the model key for LiteLLM if "model" not in config["rules"]: raise ValueError("'model' is required in the rules configuration") rules_config["RULES_MODEL"] = config["rules"]["model"] # load morphik config morphik_config = { "ENABLE_COLPALI": config["morphik"]["enable_colpali"], "MODE": config["morphik"].get("mode", "cloud"), # Default to "cloud" mode "API_DOMAIN": config["morphik"].get("api_domain", "api.morphik.ai"), # Default API domain } # load redis config redis_config = {} if "redis" in config: redis_config = { "REDIS_HOST": config["redis"].get("host", "localhost"), "REDIS_PORT": int(config["redis"].get("port", 6379)), } # load graph config graph_config = { "GRAPH_PROVIDER": "litellm", "ENABLE_ENTITY_RESOLUTION": config["graph"].get("enable_entity_resolution", True), } # Set the model key for LiteLLM if "model" not in config["graph"]: raise ValueError("'model' is required in the graph configuration") graph_config["GRAPH_MODEL"] = config["graph"]["model"] # load telemetry config telemetry_config = {} if "telemetry" in config: telemetry_config = { "TELEMETRY_ENABLED": config["telemetry"].get("enabled", True), "HONEYCOMB_ENABLED": config["telemetry"].get("honeycomb_enabled", True), "HONEYCOMB_ENDPOINT": config["telemetry"].get("honeycomb_endpoint", "https://api.honeycomb.io"), "SERVICE_NAME": config["telemetry"].get("service_name", "morphik-core"), "OTLP_TIMEOUT": config["telemetry"].get("otlp_timeout", 10), "OTLP_MAX_RETRIES": config["telemetry"].get("otlp_max_retries", 3), "OTLP_RETRY_DELAY": config["telemetry"].get("otlp_retry_delay", 1), "OTLP_MAX_EXPORT_BATCH_SIZE": config["telemetry"].get("otlp_max_export_batch_size", 512), "OTLP_SCHEDULE_DELAY_MILLIS": config["telemetry"].get("otlp_schedule_delay_millis", 5000), "OTLP_MAX_QUEUE_SIZE": config["telemetry"].get("otlp_max_queue_size", 2048), } settings_dict = dict( ChainMap( api_config, auth_config, registered_models, completion_config, database_config, embedding_config, parser_config, reranker_config, storage_config, vector_store_config, rules_config, morphik_config, redis_config, graph_config, telemetry_config, openai_config, ) ) return Settings(**settings_dict) ``` ## /core/database/base_database.py ```py path="/core/database/base_database.py" from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional from ..models.auth import AuthContext from ..models.documents import Document from ..models.folders import Folder from ..models.graph import Graph class BaseDatabase(ABC): """Base interface for document metadata storage.""" @abstractmethod async def store_document(self, document: Document) -> bool: """ Store document metadata. 
Returns: Success status """ pass @abstractmethod async def get_document(self, document_id: str, auth: AuthContext) -> Optional[Document]: """ Retrieve document metadata by ID if user has access. Returns: Document if found and accessible, None otherwise """ pass @abstractmethod async def get_document_by_filename( self, filename: str, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None ) -> Optional[Document]: """ Retrieve document metadata by filename if user has access. If multiple documents have the same filename, returns the most recently updated one. Args: filename: The filename to search for auth: Authentication context system_filters: Optional system metadata filters (e.g. folder_name, end_user_id) Returns: Document if found and accessible, None otherwise """ pass @abstractmethod async def get_documents_by_id( self, document_ids: List[str], auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None, ) -> List[Document]: """ Retrieve multiple documents by their IDs in a single batch operation. Only returns documents the user has access to. Can filter by system metadata fields like folder_name and end_user_id. Args: document_ids: List of document IDs to retrieve auth: Authentication context system_filters: Optional filters for system metadata fields Returns: List of Document objects that were found and user has access to """ pass @abstractmethod async def get_documents( self, auth: AuthContext, skip: int = 0, limit: int = 100, filters: Optional[Dict[str, Any]] = None, system_filters: Optional[Dict[str, Any]] = None, ) -> List[Document]: """ List documents the user has access to. Supports pagination and filtering. Args: auth: Authentication context skip: Number of documents to skip (for pagination) limit: Maximum number of documents to return filters: Optional metadata filters system_filters: Optional system metadata filters (e.g. folder_name, end_user_id) Returns: List of documents matching the criteria """ pass @abstractmethod async def update_document(self, document_id: str, updates: Dict[str, Any], auth: AuthContext) -> bool: """ Update document metadata if user has access. Returns: Success status """ pass @abstractmethod async def delete_document(self, document_id: str, auth: AuthContext) -> bool: """ Delete document metadata if user has admin access. Returns: Success status """ pass @abstractmethod async def find_authorized_and_filtered_documents( self, auth: AuthContext, filters: Optional[Dict[str, Any]] = None, system_filters: Optional[Dict[str, Any]] = None, ) -> List[str]: """Find document IDs matching filters that user has access to. Args: auth: Authentication context filters: Optional metadata filters system_filters: Optional system metadata filters (e.g. folder_name, end_user_id) Returns: List of document IDs matching the criteria """ pass @abstractmethod async def check_access(self, document_id: str, auth: AuthContext, required_permission: str = "read") -> bool: """ Check if user has required permission for document. Returns: True if user has required access, False otherwise """ pass @abstractmethod async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool: """Store metadata for a cache. Args: name: Name of the cache metadata: Cache metadata including model info and storage location Returns: bool: Whether the operation was successful """ pass @abstractmethod async def get_cache_metadata(self, name: str) -> Optional[Dict[str, Any]]: """Get metadata for a cache. 
Args: name: Name of the cache Returns: Optional[Dict[str, Any]]: Cache metadata if found, None otherwise """ pass @abstractmethod async def store_graph(self, graph: Graph) -> bool: """Store a graph. Args: graph: Graph to store Returns: bool: Whether the operation was successful """ pass @abstractmethod async def get_graph( self, name: str, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None ) -> Optional[Graph]: """Get a graph by name. Args: name: Name of the graph auth: Authentication context system_filters: Optional system metadata filters (e.g. folder_name, end_user_id) Returns: Optional[Graph]: Graph if found and accessible, None otherwise """ pass @abstractmethod async def list_graphs(self, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> List[Graph]: """List all graphs the user has access to. Args: auth: Authentication context system_filters: Optional system metadata filters (e.g. folder_name, end_user_id) Returns: List[Graph]: List of graphs """ pass @abstractmethod async def update_graph(self, graph: Graph) -> bool: """Update an existing graph. Args: graph: Graph to update Returns: bool: Whether the operation was successful """ pass @abstractmethod async def create_folder(self, folder: Folder) -> bool: """Create a new folder. Args: folder: Folder to create Returns: bool: Whether the operation was successful """ pass @abstractmethod async def get_folder(self, folder_id: str, auth: AuthContext) -> Optional[Folder]: """Get a folder by ID. Args: folder_id: ID of the folder auth: Authentication context Returns: Optional[Folder]: Folder if found and accessible, None otherwise """ pass @abstractmethod async def get_folder_by_name(self, name: str, auth: AuthContext) -> Optional[Folder]: """Get a folder by name. Args: name: Name of the folder auth: Authentication context Returns: Optional[Folder]: Folder if found and accessible, None otherwise """ pass @abstractmethod async def list_folders(self, auth: AuthContext) -> List[Folder]: """List all folders the user has access to. Args: auth: Authentication context Returns: List[Folder]: List of folders """ pass @abstractmethod async def add_document_to_folder(self, folder_id: str, document_id: str, auth: AuthContext) -> bool: """Add a document to a folder. Args: folder_id: ID of the folder document_id: ID of the document auth: Authentication context Returns: bool: Whether the operation was successful """ pass @abstractmethod async def remove_document_from_folder(self, folder_id: str, document_id: str, auth: AuthContext) -> bool: """Remove a document from a folder. 
Args: folder_id: ID of the folder document_id: ID of the document auth: Authentication context Returns: bool: Whether the operation was successful """ pass ``` ## /core/database/postgres_database.py ```py path="/core/database/postgres_database.py" import json import logging from datetime import UTC, datetime from typing import Any, Dict, List, Optional from sqlalchemy import Column, Index, String, select, text from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import declarative_base, sessionmaker from ..models.auth import AuthContext from ..models.documents import Document, StorageFileInfo from ..models.folders import Folder from ..models.graph import Graph from .base_database import BaseDatabase logger = logging.getLogger(__name__) Base = declarative_base() class DocumentModel(Base): """SQLAlchemy model for document metadata.""" __tablename__ = "documents" external_id = Column(String, primary_key=True) owner = Column(JSONB) content_type = Column(String) filename = Column(String, nullable=True) doc_metadata = Column(JSONB, default=dict) storage_info = Column(JSONB, default=dict) system_metadata = Column(JSONB, default=dict) additional_metadata = Column(JSONB, default=dict) access_control = Column(JSONB, default=dict) chunk_ids = Column(JSONB, default=list) storage_files = Column(JSONB, default=list) # Create indexes __table_args__ = ( Index("idx_owner_id", "owner", postgresql_using="gin"), Index("idx_access_control", "access_control", postgresql_using="gin"), Index("idx_system_metadata", "system_metadata", postgresql_using="gin"), ) class GraphModel(Base): """SQLAlchemy model for graph data.""" __tablename__ = "graphs" id = Column(String, primary_key=True) name = Column(String, index=True) # Not unique globally anymore entities = Column(JSONB, default=list) relationships = Column(JSONB, default=list) graph_metadata = Column(JSONB, default=dict) # Renamed from 'metadata' to avoid conflict system_metadata = Column(JSONB, default=dict) # For folder_name and end_user_id document_ids = Column(JSONB, default=list) filters = Column(JSONB, nullable=True) created_at = Column(String) # ISO format string updated_at = Column(String) # ISO format string owner = Column(JSONB) access_control = Column(JSONB, default=dict) # Create indexes __table_args__ = ( Index("idx_graph_name", "name"), Index("idx_graph_owner", "owner", postgresql_using="gin"), Index("idx_graph_access_control", "access_control", postgresql_using="gin"), Index("idx_graph_system_metadata", "system_metadata", postgresql_using="gin"), # Create a unique constraint on name scoped by owner ID Index("idx_graph_owner_name", "name", text("(owner->>'id')"), unique=True), ) class FolderModel(Base): """SQLAlchemy model for folder data.""" __tablename__ = "folders" id = Column(String, primary_key=True) name = Column(String, index=True) description = Column(String, nullable=True) owner = Column(JSONB) document_ids = Column(JSONB, default=list) system_metadata = Column(JSONB, default=dict) access_control = Column(JSONB, default=dict) rules = Column(JSONB, default=list) # Create indexes __table_args__ = ( Index("idx_folder_name", "name"), Index("idx_folder_owner", "owner", postgresql_using="gin"), Index("idx_folder_access_control", "access_control", postgresql_using="gin"), ) def _serialize_datetime(obj: Any) -> Any: """Helper function to serialize datetime objects to ISO format strings.""" if isinstance(obj, datetime): return obj.isoformat() elif isinstance(obj, 
dict): return {key: _serialize_datetime(value) for key, value in obj.items()} elif isinstance(obj, list): return [_serialize_datetime(item) for item in obj] return obj class PostgresDatabase(BaseDatabase): """PostgreSQL implementation for document metadata storage.""" def __init__( self, uri: str, ): """Initialize PostgreSQL connection for document storage.""" # Load settings from config from core.config import get_settings settings = get_settings() # Get database pool settings from config with defaults pool_size = getattr(settings, "DB_POOL_SIZE", 20) max_overflow = getattr(settings, "DB_MAX_OVERFLOW", 30) pool_recycle = getattr(settings, "DB_POOL_RECYCLE", 3600) pool_timeout = getattr(settings, "DB_POOL_TIMEOUT", 10) pool_pre_ping = getattr(settings, "DB_POOL_PRE_PING", True) logger.info( f"Initializing PostgreSQL connection pool with size={pool_size}, " f"max_overflow={max_overflow}, pool_recycle={pool_recycle}s" ) # Create async engine with explicit pool settings self.engine = create_async_engine( uri, # Prevent connection timeouts by keeping connections alive pool_pre_ping=pool_pre_ping, # Increase pool size to handle concurrent operations pool_size=pool_size, # Maximum overflow connections allowed beyond pool_size max_overflow=max_overflow, # Keep connections in the pool for up to 60 minutes pool_recycle=pool_recycle, # Time to wait for a connection from the pool (10 seconds) pool_timeout=pool_timeout, # Echo SQL for debugging (set to False in production) echo=False, ) self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False) self._initialized = False async def initialize(self): """Initialize database tables and indexes.""" if self._initialized: return True try: logger.info("Initializing PostgreSQL database tables and indexes...") # Create ORM models async with self.engine.begin() as conn: # Explicitly create all tables with checkfirst=True to avoid errors if tables already exist await conn.run_sync(lambda conn: Base.metadata.create_all(conn, checkfirst=True)) # No need to manually create graphs table again since SQLAlchemy does it logger.info("Created database tables successfully") # Create caches table if it doesn't exist (kept as direct SQL for backward compatibility) await conn.execute( text( """ CREATE TABLE IF NOT EXISTS caches ( name TEXT PRIMARY KEY, metadata JSONB NOT NULL, created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP ) """ ) ) # Check if storage_files column exists result = await conn.execute( text( """ SELECT column_name FROM information_schema.columns WHERE table_name = 'documents' AND column_name = 'storage_files' """ ) ) if not result.first(): # Add storage_files column to documents table await conn.execute( text( """ ALTER TABLE documents ADD COLUMN IF NOT EXISTS storage_files JSONB DEFAULT '[]'::jsonb """ ) ) logger.info("Added storage_files column to documents table") # Create indexes for folder_name and end_user_id in system_metadata for documents await conn.execute( text( """ CREATE INDEX IF NOT EXISTS idx_system_metadata_folder_name ON documents ((system_metadata->>'folder_name')); """ ) ) # Create folders table if it doesn't exist await conn.execute( text( """ CREATE TABLE IF NOT EXISTS folders ( id TEXT PRIMARY KEY, name TEXT, description TEXT, owner JSONB, document_ids JSONB DEFAULT '[]', system_metadata JSONB DEFAULT '{}', access_control JSONB DEFAULT '{}' ); """ ) ) # Add rules column to folders table if it doesn't exist result = await 
conn.execute( text( """ SELECT column_name FROM information_schema.columns WHERE table_name = 'folders' AND column_name = 'rules' """ ) ) if not result.first(): # Add rules column to folders table await conn.execute( text( """ ALTER TABLE folders ADD COLUMN IF NOT EXISTS rules JSONB DEFAULT '[]'::jsonb """ ) ) logger.info("Added rules column to folders table") # Create indexes for folders table await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_folder_name ON folders (name);")) await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_folder_owner ON folders USING gin (owner);")) await conn.execute( text("CREATE INDEX IF NOT EXISTS idx_folder_access_control ON folders USING gin (access_control);") ) await conn.execute( text( """ CREATE INDEX IF NOT EXISTS idx_system_metadata_end_user_id ON documents ((system_metadata->>'end_user_id')); """ ) ) # Check if system_metadata column exists in graphs table result = await conn.execute( text( """ SELECT column_name FROM information_schema.columns WHERE table_name = 'graphs' AND column_name = 'system_metadata' """ ) ) if not result.first(): # Add system_metadata column to graphs table await conn.execute( text( """ ALTER TABLE graphs ADD COLUMN IF NOT EXISTS system_metadata JSONB DEFAULT '{}'::jsonb """ ) ) logger.info("Added system_metadata column to graphs table") # Create indexes for folder_name and end_user_id in system_metadata for graphs await conn.execute( text( """ CREATE INDEX IF NOT EXISTS idx_graph_system_metadata_folder_name ON graphs ((system_metadata->>'folder_name')); """ ) ) await conn.execute( text( """ CREATE INDEX IF NOT EXISTS idx_graph_system_metadata_end_user_id ON graphs ((system_metadata->>'end_user_id')); """ ) ) logger.info("Created indexes for folder_name and end_user_id in system_metadata") logger.info("PostgreSQL tables and indexes created successfully") self._initialized = True return True except Exception as e: logger.error(f"Error creating PostgreSQL tables and indexes: {str(e)}") return False async def store_document(self, document: Document) -> bool: """Store document metadata.""" try: doc_dict = document.model_dump() # Rename metadata to doc_metadata if "metadata" in doc_dict: doc_dict["doc_metadata"] = doc_dict.pop("metadata") doc_dict["doc_metadata"]["external_id"] = doc_dict["external_id"] # Ensure system metadata if "system_metadata" not in doc_dict: doc_dict["system_metadata"] = {} doc_dict["system_metadata"]["created_at"] = datetime.now(UTC) doc_dict["system_metadata"]["updated_at"] = datetime.now(UTC) # Handle storage_files if "storage_files" in doc_dict and doc_dict["storage_files"]: # Convert storage_files to the expected format for storage doc_dict["storage_files"] = [file.model_dump() for file in doc_dict["storage_files"]] # Serialize datetime objects to ISO format strings doc_dict = _serialize_datetime(doc_dict) async with self.async_session() as session: doc_model = DocumentModel(**doc_dict) session.add(doc_model) await session.commit() return True except Exception as e: logger.error(f"Error storing document metadata: {str(e)}") return False async def get_document(self, document_id: str, auth: AuthContext) -> Optional[Document]: """Retrieve document metadata by ID if user has access.""" try: async with self.async_session() as session: # Build access filter access_filter = self._build_access_filter(auth) # Query document query = ( select(DocumentModel) .where(DocumentModel.external_id == document_id) .where(text(f"({access_filter})")) ) result = await session.execute(query) doc_model = 
result.scalar_one_or_none() if doc_model: # Convert doc_metadata back to metadata # Also convert storage_files from dict to StorageFileInfo storage_files = [] if doc_model.storage_files: for file_info in doc_model.storage_files: if isinstance(file_info, dict): storage_files.append(StorageFileInfo(**file_info)) else: storage_files.append(file_info) doc_dict = { "external_id": doc_model.external_id, "owner": doc_model.owner, "content_type": doc_model.content_type, "filename": doc_model.filename, "metadata": doc_model.doc_metadata, "storage_info": doc_model.storage_info, "system_metadata": doc_model.system_metadata, "additional_metadata": doc_model.additional_metadata, "access_control": doc_model.access_control, "chunk_ids": doc_model.chunk_ids, "storage_files": storage_files, } return Document(**doc_dict) return None except Exception as e: logger.error(f"Error retrieving document metadata: {str(e)}") return None async def get_document_by_filename( self, filename: str, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None ) -> Optional[Document]: """Retrieve document metadata by filename if user has access. If multiple documents have the same filename, returns the most recently updated one. Args: filename: The filename to search for auth: Authentication context system_filters: Optional system metadata filters (e.g. folder_name, end_user_id) """ try: async with self.async_session() as session: # Build access filter access_filter = self._build_access_filter(auth) system_metadata_filter = self._build_system_metadata_filter(system_filters) filename = filename.replace("'", "''") # Construct where clauses where_clauses = [ f"({access_filter})", f"filename = '{filename}'", # Escape single quotes ] if system_metadata_filter: where_clauses.append(f"({system_metadata_filter})") final_where_clause = " AND ".join(where_clauses) # Query document with system filters query = ( select(DocumentModel).where(text(final_where_clause)) # Order by updated_at in system_metadata to get the most recent document .order_by(text("system_metadata->>'updated_at' DESC")) ) logger.debug(f"Querying document by filename with system filters: {system_filters}") result = await session.execute(query) doc_model = result.scalar_one_or_none() if doc_model: # Convert doc_metadata back to metadata # Also convert storage_files from dict to StorageFileInfo storage_files = [] if doc_model.storage_files: for file_info in doc_model.storage_files: if isinstance(file_info, dict): storage_files.append(StorageFileInfo(**file_info)) else: storage_files.append(file_info) doc_dict = { "external_id": doc_model.external_id, "owner": doc_model.owner, "content_type": doc_model.content_type, "filename": doc_model.filename, "metadata": doc_model.doc_metadata, "storage_info": doc_model.storage_info, "system_metadata": doc_model.system_metadata, "additional_metadata": doc_model.additional_metadata, "access_control": doc_model.access_control, "chunk_ids": doc_model.chunk_ids, "storage_files": storage_files, } return Document(**doc_dict) return None except Exception as e: logger.error(f"Error retrieving document metadata by filename: {str(e)}") return None async def get_documents_by_id( self, document_ids: List[str], auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None, ) -> List[Document]: """ Retrieve multiple documents by their IDs in a single batch operation. Only returns documents the user has access to. Can filter by system metadata fields like folder_name and end_user_id. 
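        Example (illustrative only; assumes an initialized PostgresDatabase, a valid
        AuthContext, and placeholder document IDs):

            docs = await db.get_documents_by_id(
                ["doc_1", "doc_2"],
                auth,
                system_filters={"folder_name": "reports"},
            )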
Args: document_ids: List of document IDs to retrieve auth: Authentication context system_filters: Optional filters for system metadata fields Returns: List of Document objects that were found and user has access to """ try: if not document_ids: return [] async with self.async_session() as session: # Build access filter access_filter = self._build_access_filter(auth) system_metadata_filter = self._build_system_metadata_filter(system_filters) # Construct where clauses document_ids_linked = ", ".join([("'" + doc_id + "'") for doc_id in document_ids]) where_clauses = [f"({access_filter})", f"external_id IN ({document_ids_linked})"] if system_metadata_filter: where_clauses.append(f"({system_metadata_filter})") final_where_clause = " AND ".join(where_clauses) # Query documents with document IDs, access check, and system filters in a single query query = select(DocumentModel).where(text(final_where_clause)) logger.info(f"Batch retrieving {len(document_ids)} documents with a single query") # Execute batch query result = await session.execute(query) doc_models = result.scalars().all() documents = [] for doc_model in doc_models: # Convert doc_metadata back to metadata doc_dict = { "external_id": doc_model.external_id, "owner": doc_model.owner, "content_type": doc_model.content_type, "filename": doc_model.filename, "metadata": doc_model.doc_metadata, "storage_info": doc_model.storage_info, "system_metadata": doc_model.system_metadata, "additional_metadata": doc_model.additional_metadata, "access_control": doc_model.access_control, "chunk_ids": doc_model.chunk_ids, "storage_files": doc_model.storage_files or [], } documents.append(Document(**doc_dict)) logger.info(f"Found {len(documents)} documents in batch retrieval") return documents except Exception as e: logger.error(f"Error batch retrieving documents: {str(e)}") return [] async def get_documents( self, auth: AuthContext, skip: int = 0, limit: int = 10000, filters: Optional[Dict[str, Any]] = None, system_filters: Optional[Dict[str, Any]] = None, ) -> List[Document]: """List documents the user has access to.""" try: async with self.async_session() as session: # Build query access_filter = self._build_access_filter(auth) metadata_filter = self._build_metadata_filter(filters) system_metadata_filter = self._build_system_metadata_filter(system_filters) where_clauses = [f"({access_filter})"] if metadata_filter: where_clauses.append(f"({metadata_filter})") if system_metadata_filter: where_clauses.append(f"({system_metadata_filter})") final_where_clause = " AND ".join(where_clauses) query = select(DocumentModel).where(text(final_where_clause)) query = query.offset(skip).limit(limit) result = await session.execute(query) doc_models = result.scalars().all() return [ Document( external_id=doc.external_id, owner=doc.owner, content_type=doc.content_type, filename=doc.filename, metadata=doc.doc_metadata, storage_info=doc.storage_info, system_metadata=doc.system_metadata, additional_metadata=doc.additional_metadata, access_control=doc.access_control, chunk_ids=doc.chunk_ids, storage_files=doc.storage_files or [], ) for doc in doc_models ] except Exception as e: logger.error(f"Error listing documents: {str(e)}") return [] async def update_document(self, document_id: str, updates: Dict[str, Any], auth: AuthContext) -> bool: """Update document metadata if user has write access.""" try: if not await self.check_access(document_id, auth, "write"): return False # Get existing document to preserve system_metadata existing_doc = await self.get_document(document_id, auth) 
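            # system_metadata supplied in `updates` is merged into the document's existing
            # system_metadata below (rather than replacing it), so ingestion-time fields such
            # as folder_name and end_user_id survive later updates; only the updated_at
            # timestamp is always overwritten.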
if not existing_doc: return False # Update system metadata updates.setdefault("system_metadata", {}) # Merge with existing system_metadata instead of just preserving specific fields if existing_doc.system_metadata: # Start with existing system_metadata merged_system_metadata = dict(existing_doc.system_metadata) # Update with new values merged_system_metadata.update(updates["system_metadata"]) # Replace with merged result updates["system_metadata"] = merged_system_metadata logger.debug("Merged system_metadata during document update, preserving existing fields") # Always update the updated_at timestamp updates["system_metadata"]["updated_at"] = datetime.now(UTC) # Serialize datetime objects to ISO format strings updates = _serialize_datetime(updates) async with self.async_session() as session: result = await session.execute(select(DocumentModel).where(DocumentModel.external_id == document_id)) doc_model = result.scalar_one_or_none() if doc_model: # Log what we're updating logger.info(f"Document update: updating fields {list(updates.keys())}") # Special handling for metadata/doc_metadata conversion if "metadata" in updates and "doc_metadata" not in updates: logger.info("Converting 'metadata' to 'doc_metadata' for database update") updates["doc_metadata"] = updates.pop("metadata") # Set all attributes for key, value in updates.items(): if key == "storage_files" and isinstance(value, list): serialized_value = [ _serialize_datetime( item.model_dump() if hasattr(item, "model_dump") else (item.dict() if hasattr(item, "dict") else item) ) for item in value ] logger.debug("Serializing storage_files before setting attribute") setattr(doc_model, key, serialized_value) else: logger.debug(f"Setting document attribute {key} = {value}") setattr(doc_model, key, value) await session.commit() logger.info(f"Document {document_id} updated successfully") return True return False except Exception as e: logger.error(f"Error updating document metadata: {str(e)}") return False async def delete_document(self, document_id: str, auth: AuthContext) -> bool: """Delete document if user has write access.""" try: if not await self.check_access(document_id, auth, "write"): return False async with self.async_session() as session: result = await session.execute(select(DocumentModel).where(DocumentModel.external_id == document_id)) doc_model = result.scalar_one_or_none() if doc_model: await session.delete(doc_model) await session.commit() return True return False except Exception as e: logger.error(f"Error deleting document: {str(e)}") return False async def find_authorized_and_filtered_documents( self, auth: AuthContext, filters: Optional[Dict[str, Any]] = None, system_filters: Optional[Dict[str, Any]] = None, ) -> List[str]: """Find document IDs matching filters and access permissions.""" try: async with self.async_session() as session: # Build query access_filter = self._build_access_filter(auth) metadata_filter = self._build_metadata_filter(filters) system_metadata_filter = self._build_system_metadata_filter(system_filters) logger.debug(f"Access filter: {access_filter}") logger.debug(f"Metadata filter: {metadata_filter}") logger.debug(f"System metadata filter: {system_metadata_filter}") logger.debug(f"Original filters: {filters}") logger.debug(f"System filters: {system_filters}") where_clauses = [f"({access_filter})"] if metadata_filter: where_clauses.append(f"({metadata_filter})") if system_metadata_filter: where_clauses.append(f"({system_metadata_filter})") final_where_clause = " AND ".join(where_clauses) query = 
select(DocumentModel.external_id).where(text(final_where_clause)) logger.debug(f"Final query: {query}") result = await session.execute(query) doc_ids = [row[0] for row in result.all()] logger.debug(f"Found document IDs: {doc_ids}") return doc_ids except Exception as e: logger.error(f"Error finding authorized documents: {str(e)}") return [] async def check_access(self, document_id: str, auth: AuthContext, required_permission: str = "read") -> bool: """Check if user has required permission for document.""" try: async with self.async_session() as session: result = await session.execute(select(DocumentModel).where(DocumentModel.external_id == document_id)) doc_model = result.scalar_one_or_none() if not doc_model: return False # Check owner access owner = doc_model.owner if owner.get("type") == auth.entity_type and owner.get("id") == auth.entity_id: return True # Check permission-specific access access_control = doc_model.access_control permission_map = {"read": "readers", "write": "writers", "admin": "admins"} permission_set = permission_map.get(required_permission) if not permission_set: return False return auth.entity_id in access_control.get(permission_set, []) except Exception as e: logger.error(f"Error checking document access: {str(e)}") return False def _build_access_filter(self, auth: AuthContext) -> str: """Build PostgreSQL filter for access control.""" filters = [ f"owner->>'id' = '{auth.entity_id}'", f"access_control->'readers' ? '{auth.entity_id}'", f"access_control->'writers' ? '{auth.entity_id}'", f"access_control->'admins' ? '{auth.entity_id}'", ] if auth.entity_type == "DEVELOPER" and auth.app_id: # Add app-specific access for developers filters.append(f"access_control->'app_access' ? '{auth.app_id}'") # Add user_id filter in cloud mode if auth.user_id: from core.config import get_settings settings = get_settings() if settings.MODE == "cloud": # Filter by user_id in access_control filters.append(f"access_control->>'user_id' = '{auth.user_id}'") return " OR ".join(filters) def _build_metadata_filter(self, filters: Dict[str, Any]) -> str: """Build PostgreSQL filter for metadata.""" if not filters: return "" filter_conditions = [] for key, value in filters.items(): # Handle list of values (IN operator) if isinstance(value, list): if not value: # Skip empty lists continue # Build a list of properly escaped values escaped_values = [] for item in value: if isinstance(item, bool): escaped_values.append(str(item).lower()) elif isinstance(item, str): # Use standard replace, avoid complex f-string quoting for black escaped_value = item.replace("'", "''") escaped_values.append(f"'{escaped_value}'") else: escaped_values.append(f"'{item}'") # Join with commas for IN clause values_str = ", ".join(escaped_values) filter_conditions.append(f"doc_metadata->>'{key}' IN ({values_str})") else: # Handle single value (equality) # Convert boolean values to string 'true' or 'false' if isinstance(value, bool): value = str(value).lower() # Use proper SQL escaping for string values if isinstance(value, str): # Replace single quotes with double single quotes to escape them value = value.replace("'", "''") filter_conditions.append(f"doc_metadata->>'{key}' = '{value}'") return " AND ".join(filter_conditions) def _build_system_metadata_filter(self, system_filters: Optional[Dict[str, Any]]) -> str: """Build PostgreSQL filter for system metadata.""" if not system_filters: return "" conditions = [] for key, value in system_filters.items(): if value is None: continue # Handle list of values (IN operator) if 
isinstance(value, list): if not value: # Skip empty lists continue # Build a list of properly escaped values escaped_values = [] for item in value: if isinstance(item, bool): escaped_values.append(str(item).lower()) elif isinstance(item, str): # Use standard replace, avoid complex f-string quoting for black escaped_value = item.replace("'", "''") escaped_values.append(f"'{escaped_value}'") else: escaped_values.append(f"'{item}'") # Join with commas for IN clause values_str = ", ".join(escaped_values) conditions.append(f"system_metadata->>'{key}' IN ({values_str})") else: # Handle single value (equality) if isinstance(value, str): # Replace single quotes with double single quotes to escape them escaped_value = value.replace("'", "''") conditions.append(f"system_metadata->>'{key}' = '{escaped_value}'") elif isinstance(value, bool): conditions.append(f"system_metadata->>'{key}' = '{str(value).lower()}'") else: conditions.append(f"system_metadata->>'{key}' = '{value}'") return " AND ".join(conditions) async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool: """Store metadata for a cache in PostgreSQL. Args: name: Name of the cache metadata: Cache metadata including model info and storage location Returns: bool: Whether the operation was successful """ try: async with self.async_session() as session: await session.execute( text( """ INSERT INTO caches (name, metadata, updated_at) VALUES (:name, :metadata, CURRENT_TIMESTAMP) ON CONFLICT (name) DO UPDATE SET metadata = :metadata, updated_at = CURRENT_TIMESTAMP """ ), {"name": name, "metadata": json.dumps(metadata)}, ) await session.commit() return True except Exception as e: logger.error(f"Failed to store cache metadata: {e}") return False async def get_cache_metadata(self, name: str) -> Optional[Dict[str, Any]]: """Get metadata for a cache from PostgreSQL. Args: name: Name of the cache Returns: Optional[Dict[str, Any]]: Cache metadata if found, None otherwise """ try: async with self.async_session() as session: result = await session.execute(text("SELECT metadata FROM caches WHERE name = :name"), {"name": name}) row = result.first() return row[0] if row else None except Exception as e: logger.error(f"Failed to get cache metadata: {e}") return None async def store_graph(self, graph: Graph) -> bool: """Store a graph in PostgreSQL. This method stores the graph metadata, entities, and relationships in a PostgreSQL table. Args: graph: Graph to store Returns: bool: Whether the operation was successful """ # Ensure database is initialized if not self._initialized: await self.initialize() try: # First serialize the graph model to dict graph_dict = graph.model_dump() # Change 'metadata' to 'graph_metadata' to match our model if "metadata" in graph_dict: graph_dict["graph_metadata"] = graph_dict.pop("metadata") # Serialize datetime objects to ISO format strings graph_dict = _serialize_datetime(graph_dict) # Store the graph metadata in PostgreSQL async with self.async_session() as session: # Store graph metadata in our table graph_model = GraphModel(**graph_dict) session.add(graph_model) await session.commit() logger.info( f"Stored graph '{graph.name}' with {len(graph.entities)} entities " f"and {len(graph.relationships)} relationships" ) return True except Exception as e: logger.error(f"Error storing graph: {str(e)}") return False async def get_graph( self, name: str, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None ) -> Optional[Graph]: """Get a graph by name. 
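        Example (illustrative only; the graph name and filter values are placeholders):

            graph = await db.get_graph(
                "knowledge_base",
                auth,
                system_filters={"folder_name": "research"},
            )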
Args: name: Name of the graph auth: Authentication context system_filters: Optional system metadata filters (e.g. folder_name, end_user_id) Returns: Optional[Graph]: Graph if found and accessible, None otherwise """ # Ensure database is initialized if not self._initialized: await self.initialize() try: async with self.async_session() as session: # Build access filter access_filter = self._build_access_filter(auth) # We need to check if the documents in the graph match the system filters # First get the graph without system filters query = select(GraphModel).where(GraphModel.name == name).where(text(f"({access_filter})")) result = await session.execute(query) graph_model = result.scalar_one_or_none() if graph_model: # If system filters are provided, we need to filter the document_ids document_ids = graph_model.document_ids if system_filters and document_ids: # Apply system_filters to document_ids system_metadata_filter = self._build_system_metadata_filter(system_filters) if system_metadata_filter: # Get document IDs with system filters doc_id_placeholders = ", ".join([f"'{doc_id}'" for doc_id in document_ids]) filter_query = f""" SELECT external_id FROM documents WHERE external_id IN ({doc_id_placeholders}) AND ({system_metadata_filter}) """ filter_result = await session.execute(text(filter_query)) filtered_doc_ids = [row[0] for row in filter_result.all()] # If no documents match system filters, return None if not filtered_doc_ids: return None # Update document_ids with filtered results document_ids = filtered_doc_ids # Convert to Graph model graph_dict = { "id": graph_model.id, "name": graph_model.name, "entities": graph_model.entities, "relationships": graph_model.relationships, "metadata": graph_model.graph_metadata, # Reference the renamed column "system_metadata": graph_model.system_metadata or {}, # Include system_metadata "document_ids": document_ids, # Use possibly filtered document_ids "filters": graph_model.filters, "created_at": graph_model.created_at, "updated_at": graph_model.updated_at, "owner": graph_model.owner, "access_control": graph_model.access_control, } return Graph(**graph_dict) return None except Exception as e: logger.error(f"Error retrieving graph: {str(e)}") return None async def list_graphs(self, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> List[Graph]: """List all graphs the user has access to. Args: auth: Authentication context system_filters: Optional system metadata filters (e.g. 
folder_name, end_user_id) Returns: List[Graph]: List of graphs """ # Ensure database is initialized if not self._initialized: await self.initialize() try: async with self.async_session() as session: # Build access filter access_filter = self._build_access_filter(auth) # Query graphs query = select(GraphModel).where(text(f"({access_filter})")) result = await session.execute(query) graph_models = result.scalars().all() graphs = [] # If system filters are provided, we need to filter each graph's document_ids if system_filters: system_metadata_filter = self._build_system_metadata_filter(system_filters) for graph_model in graph_models: document_ids = graph_model.document_ids if document_ids and system_metadata_filter: # Get document IDs with system filters doc_id_placeholders = ", ".join([f"'{doc_id}'" for doc_id in document_ids]) filter_query = f""" SELECT external_id FROM documents WHERE external_id IN ({doc_id_placeholders}) AND ({system_metadata_filter}) """ filter_result = await session.execute(text(filter_query)) filtered_doc_ids = [row[0] for row in filter_result.all()] # Only include graphs that have documents matching the system filters if filtered_doc_ids: graph = Graph( id=graph_model.id, name=graph_model.name, entities=graph_model.entities, relationships=graph_model.relationships, metadata=graph_model.graph_metadata, # Reference the renamed column system_metadata=graph_model.system_metadata or {}, # Include system_metadata document_ids=filtered_doc_ids, # Use filtered document_ids filters=graph_model.filters, created_at=graph_model.created_at, updated_at=graph_model.updated_at, owner=graph_model.owner, access_control=graph_model.access_control, ) graphs.append(graph) else: # No system filters, include all graphs graphs = [ Graph( id=graph.id, name=graph.name, entities=graph.entities, relationships=graph.relationships, metadata=graph.graph_metadata, # Reference the renamed column system_metadata=graph.system_metadata or {}, # Include system_metadata document_ids=graph.document_ids, filters=graph.filters, created_at=graph.created_at, updated_at=graph.updated_at, owner=graph.owner, access_control=graph.access_control, ) for graph in graph_models ] return graphs except Exception as e: logger.error(f"Error listing graphs: {str(e)}") return [] async def update_graph(self, graph: Graph) -> bool: """Update an existing graph in PostgreSQL. This method updates the graph metadata, entities, and relationships in the PostgreSQL table. 
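        Example (illustrative only; `new_entity` is a placeholder for an entity object):

            graph = await db.get_graph("knowledge_base", auth)
            graph.entities.append(new_entity)
            await db.update_graph(graph)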
Args: graph: Graph to update Returns: bool: Whether the operation was successful """ # Ensure database is initialized if not self._initialized: await self.initialize() try: # First serialize the graph model to dict graph_dict = graph.model_dump() # Change 'metadata' to 'graph_metadata' to match our model if "metadata" in graph_dict: graph_dict["graph_metadata"] = graph_dict.pop("metadata") # Serialize datetime objects to ISO format strings graph_dict = _serialize_datetime(graph_dict) # Update the graph in PostgreSQL async with self.async_session() as session: # Check if the graph exists result = await session.execute(select(GraphModel).where(GraphModel.id == graph.id)) graph_model = result.scalar_one_or_none() if not graph_model: logger.error(f"Graph '{graph.name}' with ID {graph.id} not found for update") return False # Update the graph model with new values for key, value in graph_dict.items(): setattr(graph_model, key, value) await session.commit() logger.info( f"Updated graph '{graph.name}' with {len(graph.entities)} entities " f"and {len(graph.relationships)} relationships" ) return True except Exception as e: logger.error(f"Error updating graph: {str(e)}") return False async def create_folder(self, folder: Folder) -> bool: """Create a new folder.""" try: async with self.async_session() as session: folder_dict = folder.model_dump() # Convert datetime objects to strings for JSON serialization folder_dict = _serialize_datetime(folder_dict) # Check if a folder with this name already exists for this owner # Use only the type/id format stmt = text( """ SELECT id FROM folders WHERE name = :name AND owner->>'id' = :entity_id AND owner->>'type' = :entity_type """ ).bindparams(name=folder.name, entity_id=folder.owner["id"], entity_type=folder.owner["type"]) result = await session.execute(stmt) existing_folder = result.scalar_one_or_none() if existing_folder: logger.info( f"Folder '{folder.name}' already exists with ID {existing_folder}, not creating a duplicate" ) # Update the provided folder's ID to match the existing one # so the caller gets the correct ID folder.id = existing_folder return True # Create a new folder model access_control = folder_dict.get("access_control", {}) # Log access control to debug any issues if "user_id" in access_control: logger.info(f"Storing folder with user_id: {access_control['user_id']}") else: logger.info("No user_id found in folder access_control") folder_model = FolderModel( id=folder.id, name=folder.name, description=folder.description, owner=folder_dict["owner"], document_ids=folder_dict.get("document_ids", []), system_metadata=folder_dict.get("system_metadata", {}), access_control=access_control, rules=folder_dict.get("rules", []), ) session.add(folder_model) await session.commit() logger.info(f"Created new folder '{folder.name}' with ID {folder.id}") return True except Exception as e: logger.error(f"Error creating folder: {e}") return False async def get_folder(self, folder_id: str, auth: AuthContext) -> Optional[Folder]: """Get a folder by ID.""" try: async with self.async_session() as session: # Get the folder logger.info(f"Getting folder with ID: {folder_id}") result = await session.execute(select(FolderModel).where(FolderModel.id == folder_id)) folder_model = result.scalar_one_or_none() if not folder_model: logger.error(f"Folder with ID {folder_id} not found in database") return None # Convert to Folder object folder_dict = { "id": folder_model.id, "name": folder_model.name, "description": folder_model.description, "owner": folder_model.owner, 
"document_ids": folder_model.document_ids, "system_metadata": folder_model.system_metadata, "access_control": folder_model.access_control, "rules": folder_model.rules, } folder = Folder(**folder_dict) # Check if the user has access to the folder if not self._check_folder_access(folder, auth, "read"): return None return folder except Exception as e: logger.error(f"Error getting folder: {e}") return None async def get_folder_by_name(self, name: str, auth: AuthContext) -> Optional[Folder]: """Get a folder by name.""" try: async with self.async_session() as session: # First try to get a folder owned by this entity if auth.entity_type and auth.entity_id: stmt = text( """ SELECT * FROM folders WHERE name = :name AND (owner->>'id' = :entity_id) AND (owner->>'type' = :entity_type) """ ).bindparams(name=name, entity_id=auth.entity_id, entity_type=auth.entity_type.value) result = await session.execute(stmt) folder_row = result.fetchone() if folder_row: # Convert to Folder object folder_dict = { "id": folder_row.id, "name": folder_row.name, "description": folder_row.description, "owner": folder_row.owner, "document_ids": folder_row.document_ids, "system_metadata": folder_row.system_metadata, "access_control": folder_row.access_control, "rules": folder_row.rules, } return Folder(**folder_dict) # If not found, try to find any accessible folder with that name stmt = text( """ SELECT * FROM folders WHERE name = :name AND ( (owner->>'id' = :entity_id AND owner->>'type' = :entity_type) OR (access_control->'readers' ? :entity_id) OR (access_control->'writers' ? :entity_id) OR (access_control->'admins' ? :entity_id) OR (access_control->'user_id' ? :user_id) ) """ ).bindparams( name=name, entity_id=auth.entity_id, entity_type=auth.entity_type.value, user_id=auth.user_id if auth.user_id else "", ) result = await session.execute(stmt) folder_row = result.fetchone() if folder_row: # Convert to Folder object folder_dict = { "id": folder_row.id, "name": folder_row.name, "description": folder_row.description, "owner": folder_row.owner, "document_ids": folder_row.document_ids, "system_metadata": folder_row.system_metadata, "access_control": folder_row.access_control, "rules": folder_row.rules, } return Folder(**folder_dict) return None except Exception as e: logger.error(f"Error getting folder by name: {e}") return None async def list_folders(self, auth: AuthContext) -> List[Folder]: """List all folders the user has access to.""" try: folders = [] async with self.async_session() as session: # Get all folders result = await session.execute(select(FolderModel)) folder_models = result.scalars().all() for folder_model in folder_models: # Convert to Folder object folder_dict = { "id": folder_model.id, "name": folder_model.name, "description": folder_model.description, "owner": folder_model.owner, "document_ids": folder_model.document_ids, "system_metadata": folder_model.system_metadata, "access_control": folder_model.access_control, "rules": folder_model.rules, } folder = Folder(**folder_dict) # Check if the user has access to the folder if self._check_folder_access(folder, auth, "read"): folders.append(folder) return folders except Exception as e: logger.error(f"Error listing folders: {e}") return [] async def add_document_to_folder(self, folder_id: str, document_id: str, auth: AuthContext) -> bool: """Add a document to a folder.""" try: # First, check if the user has access to the folder folder = await self.get_folder(folder_id, auth) if not folder: logger.error(f"Folder {folder_id} not found or user does not have 
access") return False # Check if user has write access to the folder if not self._check_folder_access(folder, auth, "write"): logger.error(f"User does not have write access to folder {folder_id}") return False # Check if the document exists and user has access document = await self.get_document(document_id, auth) if not document: logger.error(f"Document {document_id} not found or user does not have access") return False # Check if the document is already in the folder if document_id in folder.document_ids: logger.info(f"Document {document_id} is already in folder {folder_id}") return True # Add the document to the folder async with self.async_session() as session: # Add document_id to document_ids array new_document_ids = folder.document_ids + [document_id] folder_model = await session.get(FolderModel, folder_id) if not folder_model: logger.error(f"Folder {folder_id} not found in database") return False folder_model.document_ids = new_document_ids # Also update the document's system_metadata to include the folder_name folder_name_json = json.dumps(folder.name) stmt = text( f""" UPDATE documents SET system_metadata = jsonb_set(system_metadata, '{{folder_name}}', '{folder_name_json}'::jsonb) WHERE external_id = :document_id """ ).bindparams(document_id=document_id) await session.execute(stmt) await session.commit() logger.info(f"Added document {document_id} to folder {folder_id}") return True except Exception as e: logger.error(f"Error adding document to folder: {e}") return False async def remove_document_from_folder(self, folder_id: str, document_id: str, auth: AuthContext) -> bool: """Remove a document from a folder.""" try: # First, check if the user has access to the folder folder = await self.get_folder(folder_id, auth) if not folder: logger.error(f"Folder {folder_id} not found or user does not have access") return False # Check if user has write access to the folder if not self._check_folder_access(folder, auth, "write"): logger.error(f"User does not have write access to folder {folder_id}") return False # Check if the document is in the folder if document_id not in folder.document_ids: logger.warning(f"Tried to delete document {document_id} not in folder {folder_id}") return True # Remove the document from the folder async with self.async_session() as session: # Remove document_id from document_ids array new_document_ids = [doc_id for doc_id in folder.document_ids if doc_id != document_id] folder_model = await session.get(FolderModel, folder_id) if not folder_model: logger.error(f"Folder {folder_id} not found in database") return False folder_model.document_ids = new_document_ids # Also update the document's system_metadata to remove the folder_name stmt = text( """ UPDATE documents SET system_metadata = jsonb_set(system_metadata, '{folder_name}', 'null'::jsonb) WHERE external_id = :document_id """ ).bindparams(document_id=document_id) await session.execute(stmt) await session.commit() logger.info(f"Removed document {document_id} from folder {folder_id}") return True except Exception as e: logger.error(f"Error removing document from folder: {e}") return False def _check_folder_access(self, folder: Folder, auth: AuthContext, permission: str = "read") -> bool: """Check if the user has the required permission for the folder.""" # Admin always has access if "admin" in auth.permissions: return True # Check if folder is owned by the user if ( auth.entity_type and auth.entity_id and folder.owner.get("type") == auth.entity_type.value and folder.owner.get("id") == auth.entity_id ): # In cloud 
mode, also verify user_id if present if auth.user_id: from core.config import get_settings settings = get_settings() if settings.MODE == "cloud": folder_user_ids = folder.access_control.get("user_id", []) if auth.user_id not in folder_user_ids: return False return True # Check access control lists access_control = folder.access_control or {} if permission == "read": readers = access_control.get("readers", []) if f"{auth.entity_type.value}:{auth.entity_id}" in readers: return True if permission == "write": writers = access_control.get("writers", []) if f"{auth.entity_type.value}:{auth.entity_id}" in writers: return True # For admin permission, check admins list if permission == "admin": admins = access_control.get("admins", []) if f"{auth.entity_type.value}:{auth.entity_id}" in admins: return True return False ``` ## /core/database/user_limits_db.py ```py path="/core/database/user_limits_db.py" import json import logging from datetime import UTC, datetime, timedelta from typing import Any, Dict, Optional from sqlalchemy import Column, Index, String, select, text from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import declarative_base, sessionmaker logger = logging.getLogger(__name__) Base = declarative_base() class UserLimitsModel(Base): """SQLAlchemy model for user limits data.""" __tablename__ = "user_limits" user_id = Column(String, primary_key=True) tier = Column(String, nullable=False) # free, developer, startup, custom custom_limits = Column(JSONB, nullable=True) usage = Column(JSONB, default=dict) # Holds all usage counters app_ids = Column(JSONB, default=list) # List of app IDs registered by this user stripe_customer_id = Column(String, nullable=True) stripe_subscription_id = Column(String, nullable=True) stripe_product_id = Column(String, nullable=True) subscription_status = Column(String, nullable=True) created_at = Column(String) # ISO format string updated_at = Column(String) # ISO format string # Create indexes __table_args__ = (Index("idx_user_tier", "tier"),) class UserLimitsDatabase: """Database operations for user limits.""" def __init__(self, uri: str): """Initialize database connection.""" self.engine = create_async_engine(uri) self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False) self._initialized = False async def initialize(self) -> bool: """Initialize database tables and indexes.""" if self._initialized: return True try: logger.info("Initializing user limits database tables...") # Create tables if they don't exist async with self.engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) # Check if we need to add the new Stripe columns # This safely adds columns if they don't exist without affecting existing data try: # Check if the columns exist first to avoid errors for column_name in [ "stripe_customer_id", "stripe_subscription_id", "stripe_product_id", "subscription_status", ]: await conn.execute( text( f"DO $$\n" f"BEGIN\n" f" IF NOT EXISTS (SELECT 1 FROM information_schema.columns \n" f" WHERE table_name='user_limits' AND column_name='{column_name}') THEN\n" f" ALTER TABLE user_limits ADD COLUMN {column_name} VARCHAR;\n" f" END IF;\n" f"END$$;" ) ) logger.info("Successfully migrated user_limits table schema if needed") except Exception as migration_error: logger.warning(f"Migration step failed, but continuing: {migration_error}") # We continue even if migration fails as the app can still function self._initialized = True 
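# At this point the schema is in place (including any Stripe columns added by the in-place migration above); subsequent initialize() calls will return early.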
logger.info("User limits database tables initialized successfully") return True except Exception as e: logger.error(f"Failed to initialize user limits database: {e}") return False async def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]: """ Get user limits for a user. Args: user_id: The user ID to get limits for Returns: Dict with user limits if found, None otherwise """ async with self.async_session() as session: result = await session.execute(select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)) user_limits = result.scalars().first() if not user_limits: return None return { "user_id": user_limits.user_id, "tier": user_limits.tier, "custom_limits": user_limits.custom_limits, "usage": user_limits.usage, "app_ids": user_limits.app_ids, "stripe_customer_id": user_limits.stripe_customer_id, "stripe_subscription_id": user_limits.stripe_subscription_id, "stripe_product_id": user_limits.stripe_product_id, "subscription_status": user_limits.subscription_status, "created_at": user_limits.created_at, "updated_at": user_limits.updated_at, } async def create_user_limits(self, user_id: str, tier: str = "free") -> bool: """ Create user limits record. Args: user_id: The user ID tier: Initial tier (defaults to "free") Returns: True if successful, False otherwise """ try: now = datetime.now(UTC).isoformat() async with self.async_session() as session: # Check if already exists result = await session.execute(select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)) if result.scalars().first(): return True # Already exists # Create new record with properly initialized JSONB columns # Create JSON strings and parse them for consistency usage_json = json.dumps( { "storage_file_count": 0, "storage_size_bytes": 0, "hourly_query_count": 0, "hourly_query_reset": now, "monthly_query_count": 0, "monthly_query_reset": now, "hourly_ingest_count": 0, "hourly_ingest_reset": now, "monthly_ingest_count": 0, "monthly_ingest_reset": now, "graph_count": 0, "cache_count": 0, } ) app_ids_json = json.dumps([]) # Empty array but as JSON string # Create the model with the JSON parsed user_limits = UserLimitsModel( user_id=user_id, tier=tier, usage=json.loads(usage_json), app_ids=json.loads(app_ids_json), stripe_customer_id=None, stripe_subscription_id=None, stripe_product_id=None, subscription_status=None, created_at=now, updated_at=now, ) session.add(user_limits) await session.commit() return True except Exception as e: logger.error(f"Failed to create user limits: {e}") return False async def update_user_tier(self, user_id: str, tier: str, custom_limits: Optional[Dict[str, Any]] = None) -> bool: """ Update user tier and custom limits. Args: user_id: The user ID tier: New tier custom_limits: Optional custom limits for CUSTOM tier Returns: True if successful, False otherwise """ try: now = datetime.now(UTC).isoformat() async with self.async_session() as session: result = await session.execute(select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)) user_limits = result.scalars().first() if not user_limits: return False user_limits.tier = tier user_limits.custom_limits = custom_limits user_limits.updated_at = now await session.commit() return True except Exception as e: logger.error(f"Failed to update user tier: {e}") return False async def update_subscription_info(self, user_id: str, subscription_data: Dict[str, Any]) -> bool: """ Update user subscription information. 
Args: user_id: The user ID subscription_data: Dictionary containing subscription information with keys: - stripeCustomerId - stripeSubscriptionId - stripeProductId - subscriptionStatus Returns: True if successful, False otherwise """ try: now = datetime.now(UTC).isoformat() async with self.async_session() as session: result = await session.execute(select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)) user_limits = result.scalars().first() if not user_limits: return False user_limits.stripe_customer_id = subscription_data.get("stripeCustomerId") user_limits.stripe_subscription_id = subscription_data.get("stripeSubscriptionId") user_limits.stripe_product_id = subscription_data.get("stripeProductId") user_limits.subscription_status = subscription_data.get("subscriptionStatus") user_limits.updated_at = now await session.commit() return True except Exception as e: logger.error(f"Failed to update subscription info: {e}") return False async def register_app(self, user_id: str, app_id: str) -> bool: """ Register an app for a user. Args: user_id: The user ID app_id: The app ID to register Returns: True if successful, False otherwise """ try: now = datetime.now(UTC).isoformat() async with self.async_session() as session: # First check if user exists result = await session.execute(select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)) user_limits = result.scalars().first() if not user_limits: logger.error(f"User {user_id} not found in register_app") return False # Use raw SQL with jsonb_array_append to update the app_ids array # This is the most reliable way to append to a JSONB array in PostgreSQL query = text( """ UPDATE user_limits SET app_ids = CASE WHEN NOT (app_ids ? :app_id) -- Check if app_id is not in the array THEN app_ids || :app_id_json -- Append it if not present ELSE app_ids -- Keep it unchanged if already present END, updated_at = :now WHERE user_id = :user_id RETURNING app_ids; """ ) # Execute the query result = await session.execute( query, { "app_id": app_id, # For the check "app_id_json": f'["{app_id}"]', # JSON array format for appending "now": now, "user_id": user_id, }, ) # Log the result for debugging updated_app_ids = result.scalar() logger.info(f"Updated app_ids for user {user_id}: {updated_app_ids}") await session.commit() return True except Exception as e: logger.error(f"Failed to register app: {e}") return False async def update_usage(self, user_id: str, usage_type: str, increment: int = 1) -> bool: """ Update usage counter for a user. 
Args: user_id: The user ID usage_type: Type of usage to update increment: Value to increment by Returns: True if successful, False otherwise """ try: now = datetime.now(UTC) now_iso = now.isoformat() async with self.async_session() as session: result = await session.execute(select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)) user_limits = result.scalars().first() if not user_limits: return False # Create a new dictionary to force SQLAlchemy to detect the change usage = dict(user_limits.usage) if user_limits.usage else {} # Handle different usage types if usage_type == "query": # Check hourly reset hourly_reset_str = usage.get("hourly_query_reset", "") if hourly_reset_str: hourly_reset = datetime.fromisoformat(hourly_reset_str) if now > hourly_reset + timedelta(hours=1): usage["hourly_query_count"] = increment usage["hourly_query_reset"] = now_iso else: usage["hourly_query_count"] = usage.get("hourly_query_count", 0) + increment else: usage["hourly_query_count"] = increment usage["hourly_query_reset"] = now_iso # Check monthly reset monthly_reset_str = usage.get("monthly_query_reset", "") if monthly_reset_str: monthly_reset = datetime.fromisoformat(monthly_reset_str) if now > monthly_reset + timedelta(days=30): usage["monthly_query_count"] = increment usage["monthly_query_reset"] = now_iso else: usage["monthly_query_count"] = usage.get("monthly_query_count", 0) + increment else: usage["monthly_query_count"] = increment usage["monthly_query_reset"] = now_iso elif usage_type == "ingest": # Similar pattern for ingest hourly_reset_str = usage.get("hourly_ingest_reset", "") if hourly_reset_str: hourly_reset = datetime.fromisoformat(hourly_reset_str) if now > hourly_reset + timedelta(hours=1): usage["hourly_ingest_count"] = increment usage["hourly_ingest_reset"] = now_iso else: usage["hourly_ingest_count"] = usage.get("hourly_ingest_count", 0) + increment else: usage["hourly_ingest_count"] = increment usage["hourly_ingest_reset"] = now_iso monthly_reset_str = usage.get("monthly_ingest_reset", "") if monthly_reset_str: monthly_reset = datetime.fromisoformat(monthly_reset_str) if now > monthly_reset + timedelta(days=30): usage["monthly_ingest_count"] = increment usage["monthly_ingest_reset"] = now_iso else: usage["monthly_ingest_count"] = usage.get("monthly_ingest_count", 0) + increment else: usage["monthly_ingest_count"] = increment usage["monthly_ingest_reset"] = now_iso elif usage_type == "storage_file": usage["storage_file_count"] = usage.get("storage_file_count", 0) + increment elif usage_type == "storage_size": usage["storage_size_bytes"] = usage.get("storage_size_bytes", 0) + increment elif usage_type == "graph": usage["graph_count"] = usage.get("graph_count", 0) + increment elif usage_type == "cache": usage["cache_count"] = usage.get("cache_count", 0) + increment # Force SQLAlchemy to recognize the change by assigning a new dict user_limits.usage = usage user_limits.updated_at = now_iso # Explicitly mark as modified session.add(user_limits) # Log the updated usage for debugging logger.info(f"Updated usage for user {user_id}, type: {usage_type}, value: {increment}") logger.info(f"New usage values: {usage}") logger.info(f"About to commit: user_id={user_id}, usage={user_limits.usage}") # Commit and flush to ensure changes are written await session.commit() return True except Exception as e: logger.error(f"Failed to update usage: {e}") import traceback logger.error(traceback.format_exc()) return False ``` ## /core/embedding/__init__.py ```py path="/core/embedding/__init__.py" 
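# Convenience re-exports: the concrete embedding model implementations can be imported directly from core.embedding.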
from core.embedding.base_embedding_model import BaseEmbeddingModel
from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
from core.embedding.litellm_embedding import LiteLLMEmbeddingModel

__all__ = ["BaseEmbeddingModel", "LiteLLMEmbeddingModel", "ColpaliEmbeddingModel"]
```
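The snippet below is a minimal, hypothetical usage sketch for the `UserLimitsDatabase` class defined in `core/database/user_limits_db.py` above. It assumes the repository is on `PYTHONPATH`, a PostgreSQL instance is reachable at the `POSTGRES_URI` shown in `.env.example`, and the user ID is invented for illustration; none of this is part of the repository itself.

```py
import asyncio

from core.database.user_limits_db import UserLimitsDatabase


async def main() -> None:
    # Connection string mirrors the POSTGRES_URI example in .env.example; adjust for your setup.
    db = UserLimitsDatabase("postgresql+asyncpg://postgres:postgres@localhost:5432/morphik")

    # Creates the user_limits table (and any missing Stripe columns) on first use.
    if not await db.initialize():
        raise RuntimeError("could not initialize the user_limits table")

    user_id = "demo-user"  # hypothetical ID, not part of the repository
    await db.create_user_limits(user_id, tier="free")   # no-op if the row already exists
    await db.update_usage(user_id, usage_type="query")  # bumps the hourly/monthly query counters

    limits = await db.get_user_limits(user_id)
    if limits:
        print(limits["tier"], limits["usage"].get("hourly_query_count"))


asyncio.run(main())
```

Because all counters live in a single JSONB `usage` column, `update_usage` rebuilds the whole dict and reassigns it so SQLAlchemy detects the change, which is why the sketch reads the refreshed values back with `get_user_limits` instead of reusing a stale object.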