```
├── .env.template
├── .github/
│   └── workflows/
│       └── pre-commit.yml
├── .gitignore
├── .gitmodules
├── .pre-commit-config.yaml
├── CODE_OF_CONDUCT.md
├── GETTING_STARTED.md
├── Jenkinsfile
├── LICENSE
├── README.md
├── alembic.ini
├── app/
│   ├── alembic/
│   │   ├── README
│   │   ├── env.py
│   │   ├── script.py.mako
│   │   └── versions/
│   │       ├── 20240812184546_6d16b920a3ec_initial_migration.py
│   │       ├── 20240812190934_5ceb460ac3ef_adding_support_for_projects.py
│   │       ├── 20240812211350_bcc569077106_utc_timestamps_and_indexing.py
│   │       ├── 20240813145447_56e7763c7d20_add_on_delete_cascade_to_message_.py
│   │       ├── 20240820182032_d3f532773223_changes_for_implementation_of_.py
│   │       ├── 20240823164559_05069444feee_project_id_to_string_anddelete_col.py
│   │       ├── 20240826215938_3c7be0985b17_search_index.py
│   │       ├── 20240828094302_48240c0ce09e_add_agent_id_support_in_conversation_.py
│   │       ├── 20240902105155_6b44dc81d95d_prompt_tables.py
│   │       ├── 20240905144257_342902c88262_add_user_preferences_table.py
│   │       ├── 20240927094023_fb0b353e69d0_support_for_citations_in_backend.py
│   │       ├── 20241003153813_827623103002_add_shared_with_email_to_the_.py
│   │       ├── 20241020111943_262d870e9686_custom_agents.py
│   │       ├── 20241028204107_684a330f9e9f_new_migration.py
│   │       ├── 20241127095409_625f792419e7_support_for_repo_path.py
│   │       ├── 20250303164854_414f9ab20475_custom_agent_sharing.py
│   │       ├── 20250310201406_97a740b07a50_custom_agent_sharing.py
│   │       └── 82eb6e97aed3_merge_heads.py
│   ├── api/
│   │   └── router.py
│   ├── celery/
│   │   ├── celery_app.py
│   │   ├── tasks/
│   │   │   └── parsing_tasks.py
│   │   └── worker.py
│   ├── core/
│   │   ├── base_model.py
│   │   ├── config_provider.py
│   │   ├── database.py
│   │   └── models.py
│   ├── main.py
│   ├── modules/
│   │   ├── auth/
│   │   │   ├── api_key_service.py
│   │   │   ├── auth_router.py
│   │   │   ├── auth_schema.py
│   │   │   ├── auth_service.py
│   │   │   └── tests/
│   │   │       └── auth_service_test.py
│   │   ├── code_provider/
│   │   │   ├── code_provider_service.py
│   │   │   ├── github/
│   │   │   │   ├── github_controller.py
│   │   │   │   ├── github_router.py
│   │   │   │   └── github_service.py
│   │   │   └── local_repo/
│   │   │       └── local_repo_service.py
│   │   ├── conversations/
│   │   │   ├── access/
│   │   │   │   ├── access_schema.py
│   │   │   │   └── access_service.py
│   │   │   ├── conversation/
│   │   │   │   ├── conversation_controller.py
│   │   │   │   ├── conversation_model.py
│   │   │   │   ├── conversation_schema.py
│   │   │   │   └── conversation_service.py
│   │   │   ├── conversations_router.py
│   │   │   └── message/
│   │   │       ├── message_model.py
│   │   │       ├── message_schema.py
│   │   │       └── message_service.py
│   │   ├── intelligence/
│   │   │   ├── __init__.py
│   │   │   ├── agents/
│   │   │   │   ├── agents_controller.py
```

## /.env.template

```template path="/.env.template"
isDevelopmentMode=enabled
ENV=development
OPENAI_API_KEY=
POSTGRES_SERVER=postgresql://postgres:mysecretpassword@localhost:5432/momentum
NEO4J_URI=bolt://127.0.0.1:7687
NEO4J_USERNAME=neo4j
NEO4J_PASSWORD=mysecretpassword
REDISHOST=127.0.0.1
REDISPORT=6379
BROKER_URL=redis://127.0.0.1:6379/0
CELERY_QUEUE_NAME=dev
defaultUsername=defaultuser
PROJECT_PATH=projects # repositories will be downloaded/cloned to this path on your system
{PROVIDER}_API_KEY=ollama
INFERENCE_MODEL=ollama_chat/qwen2.5-coder:7b
CHAT_MODEL=ollama_chat/qwen2.5-coder:7b
LLM_API_BASE=
LLM_API_VERSION=

# following are for production mode
PORTKEY_API_KEY=
GCP_PROJECT=
FIREBASE_SERVICE_ACCOUNT=
KNOWLEDGE_GRAPH_URL=
GITHUB_APP_ID=
GITHUB_PRIVATE_KEY=
GH_TOKEN_LIST=
TRANSACTION_EMAILS_ENABLED=
EMAIL_FROM_ADDRESS=
RESEND_API_KEY=
ANTHROPIC_API_KEY=
POSTHOG_API_KEY=
POSTHOG_HOST=
FIRECRAWL_API_KEY=
```

## /.github/workflows/pre-commit.yml

```yml path="/.github/workflows/pre-commit.yml"
name: Pre-commit

on:
  pull_request:
    types: [opened, synchronize]

permissions:
  contents: write # Grants write access to push changes

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0 # Fetch full commit history
          token: ${{ secrets.GITHUB_TOKEN }} # Allows pushing changes

      - uses: actions/setup-python@v4
        with:
          python-version: "3.11" # Set a consistent Python version

      - name: Install dependencies
        run: pip install --upgrade pip pre-commit # Ensure latest pre-commit version

      - name: Run pre-commit
        run: pre-commit run --all-files --show-diff-on-failure || true # Run all hooks without failing

      - name: Check for changes and commit
        run: |
          if [[ `git status --porcelain` ]]; then
            git config --global user.name "github-actions[bot]"
            git config --global user.email "github-actions[bot]@users.noreply.github.com"
            git add .
            git commit -m "chore: Auto-fix pre-commit issues"
            git push origin HEAD:${{ github.head_ref }}
          fi
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # Ensures authentication
```

## /.gitignore

```gitignore path="/.gitignore"
.DS_Store
.vscode
*__pycache__
.venv
firebase_service_account.json
firebase_service_account.json:Zone.Identifier
cli/dist
cli/.venv
venv/
.env
*.json
.momentum
db
cli/momentum_cli/.momentum
.hypothesis

# Ignore all .env.* files
.env.*
# Except for .env.template
!.env.template

*.log
.cursorrules
projects/

# Ignore PyCharm config
.idea

# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
# dependencies
/node_modules
/.pnp
.pnp.js
.yarn/install-state.gz
/.code-workspace

# testing
/coverage

# next.js
/.next/
/out/

# production
/build

# misc
.DS_Store
*.pem

# debug
npm-debug.log*
yarn-debug.log*
yarn-error.log*

# local env files
.env*.local
.env

# vercel
.vercel

# typescript
*.tsbuildinfo
next-env.d.ts

certificates
package-lock.json
.next
.vscode/
```

## /.gitmodules

```gitmodules path="/.gitmodules"
[submodule "potpie-ui"]
	path = potpie-ui
	url = https://github.com/potpie-ai/potpie-ui
```

## /.pre-commit-config.yaml

```yaml path="/.pre-commit-config.yaml"
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
      - id: check-yaml
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-merge-conflict
      - id: check-added-large-files
        args: ["--maxkb=51200"]
      - id: debug-statements
  - repo: https://github.com/psf/black
    rev: 24.8.0
    hooks:
      - id: black
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.6.2
    hooks:
      - id: ruff
        args: ["--fix"]
  - repo: https://github.com/PyCQA/bandit
    rev: 1.7.9
    hooks:
      - id: bandit
```
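The same hooks that the CI workflow runs on every pull request can be exercised locally before pushing. A minimal sketch, assuming `pre-commit` is installed into your active Python environment:

```bash
pip install pre-commit
pre-commit install                                 # set up the git hook once
pre-commit run --all-files --show-diff-on-failure  # run every configured hook across the repo
```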
## /CODE_OF_CONDUCT.md

# Potpie Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at dhiren@potpie.ai. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq

## /GETTING_STARTED.md

# Development mode

## Running Potpie

**Install Python 3.10**: Download and install Python 3.10 from the official Python website: https://www.python.org/downloads/release/python-3100/

1. **Ensure Docker is Installed**: Verify that Docker is installed and running on your system.

2. **Set Up the Environment**: Create a `.env` file based on the provided `.env.template` in the repository. This file should include all necessary configuration settings for the application. Ensure that:
   ```
   isDevelopmentMode=enabled
   ENV=development
   OPENAI_API_KEY=
   ```

   Create a virtual environment using Python 3.10:
   ```bash
   python3.10 -m venv venv
   source venv/bin/activate
   ```
   Alternatively, you can use the `virtualenv` library.

   Install dependencies in your venv:
   ```bash
   pip install -r requirements.txt
   ```
   If you face any issues with the dependencies, try installing them with:
   ```bash
   pip install -r requirements.txt --use-deprecated=legacy-resolver
   ```

3. You can use the following env config to run Potpie with local models:
   ```
   INFERENCE_MODEL=ollama_chat/qwen2.5-coder:7b
   CHAT_MODEL=ollama_chat/qwen2.5-coder:7b
   ```
   To run Potpie with any other model, use the following env configuration:
   ```
   {PROVIDER}_API_KEY=sk-or-your-key # your provider key, e.g. OPENAI_API_KEY for OpenAI
   INFERENCE_MODEL=openrouter/deepseek/deepseek-chat # provider model name
   CHAT_MODEL=openrouter/deepseek/deepseek-chat # provider model name
   ```

   **`INFERENCE_MODEL`** and **`CHAT_MODEL`** correspond to the models that will be used for generating the knowledge graph and for agent reasoning, respectively. These model names should be in the `provider/model_name` format expected by LiteLLM. For more information, refer to the [LiteLLM documentation](https://docs.litellm.ai/docs/providers).
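   If you go with the local Ollama setup above, make sure the referenced model actually exists locally before starting Potpie. A minimal sketch, assuming Ollama is installed and running:

   ```bash
   # Pull the model named in INFERENCE_MODEL / CHAT_MODEL so LiteLLM can reach it
   ollama pull qwen2.5-coder:7b
   # Confirm it is available
   ollama list
   ```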
4. **Run Potpie**: Execute the following command:
   ```bash
   ./start.sh
   ```
   You may need to make it executable by running:
   ```bash
   chmod +x start.sh
   ```

5. Start using Potpie with your local codebases!

# Production setup

For a production deployment with Firebase authentication, GitHub access, secret management, etc.

## Firebase Setup

To set up Firebase, follow these steps:

1. **Create a Firebase Project**: Go to the [Firebase Console](https://console.firebase.google.com/) and create a new project.
2. **Generate a Service Account Key**:
   - Click on **Project Overview Gear ⚙** from the sidebar.
   - Open the **Service Accounts** tab.
   - Click on the option to generate a new private key in the Firebase Admin SDK sub-section.
   - Read the warning and generate the key. Rename the downloaded key to `firebase_service_account.json` and move it to the root of the potpie source code.
3. **Create a Firebase App**
   - Go to the **Project Overview Gear ⚙** from the sidebar.
   - Create a Firebase app.
   - You will find keys for hosting, storage, and other services. Use these keys in your `.env` file.

---

## PostHog Integration

PostHog is an open-source platform that helps us analyze user behavior on Potpie.

- **Sign Up**: Create a free account at [PostHog](https://us.posthog.com/signup) and keep your API key in `.env` as `POSTHOG_API_KEY`, along with `POSTHOG_HOST`.

---

## Portkey Integration

Portkey provides observability and monitoring capabilities for AI integration with Potpie.

- **Sign Up**: Create a free account at [Portkey](https://app.portkey.ai/signup) and keep your API key in `.env` as `PORTKEY_API_KEY`.

---

## Setting Up GitHub App

To enable login via GitHub, create a GitHub app by following these steps:

1. Visit [GitHub App Creation](https://github.com/settings/apps/new).
2. **Name Your App**: Choose a name relevant to Potpie (e.g., `potpie-auth`).
3. **Set Permissions**:
   - **Repository Permissions**:
     - Contents: Read Only
     - Metadata: Read Only
     - Pull Requests: Read and Write
     - Secrets: Read Only
     - Webhook: Read Only
   - **Organization Permissions**: Members: Read Only
   - **Account Permissions**: Email Address: Read Only
   - **Homepage URL**: https://potpie.ai
   - **Webhook**: Inactive
4. **Generate a Private Key**: Download the private key and place it in the project root. Add your app ID to `GITHUB_APP_ID`.
5. **Format your Private Key**: Use `format_pem.sh` to format your key:
   ```bash
   chmod +x format_pem.sh
   ./format_pem.sh your-key.pem
   ```
   The formatted key will be displayed in the terminal. Copy it and add it to the env under `GITHUB_PRIVATE_KEY`.
6. **Install the App**: From the left sidebar, select **Install App** and install it next to your organization/user account.
7. **Create a GitHub Token**: Go to your GitHub Settings > Developer Settings > Personal Access Tokens > Tokens (classic). Add the token to your `.env` file under `GH_TOKEN_LIST`.

---

## Enabling GitHub Auth on Firebase

1. Open Firebase and navigate to **Authentication**.
2. Enable GitHub sign-in capability by adding a GitHub OAuth app from your account. This will provide you with a client secret and client ID to add to Firebase.
3. Copy the callback URL from Firebase and add it to your GitHub app.

GitHub Auth with Firebase is now ready.

---

## Google Cloud Setup

Potpie uses Google Secret Manager to securely manage API keys. If you created a Firebase app, a linked Google Cloud account will be created automatically. You can use that or create a new one as needed.

Follow these steps to set up the Secret Manager and Application Default Credentials (ADC) for Potpie:

1. Install the gcloud CLI. Follow the official installation guide: https://cloud.google.com/sdk/docs/install

   After installation, initialize the gcloud CLI:
   ```bash
   gcloud init
   ```
   Say yes to configuring a default compute region. Select your local region when prompted.

2. Set up the gcloud Secret Manager API.

3. Configure Application Default Credentials for local use: https://cloud.google.com/docs/authentication/set-up-adc-local-dev-environment

Once completed, you are ready to proceed with the Potpie setup.
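For reference, steps 2 and 3 above typically come down to two commands. A hedged sketch, assuming the gcloud CLI is already initialized against the correct project:

```bash
# Enable the Secret Manager API for the active project (step 2)
gcloud services enable secretmanager.googleapis.com
# Create Application Default Credentials for local development (step 3)
gcloud auth application-default login
```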
---

## Running Potpie

1. **Ensure Docker is Installed**: Verify that Docker is installed and running on your system.

2. **Set Up the Environment**: Create a `.env` file based on the provided `.env.template` in the repository. This file should include all necessary configuration settings for the application.

3. **Google Cloud Authentication**: Log in to your Google Cloud account and set up Application Default Credentials (ADC). Detailed instructions can be found in the documentation. Alternatively, place the service account key file for your GCP project in a `service-account.json` file at the root of the codebase.

4. **Run Potpie**: Execute the following command:
   ```bash
   ./start.sh
   ```
   You may need to make it executable by running:
   ```bash
   chmod +x start.sh
   ```

## /Jenkinsfile

``` path="/Jenkinsfile"
pipeline {
    agent any
    parameters {
        string(name: 'namespace', defaultValue: "mom-server", description: 'namespace to deploy')
    }
    environment {
        // Access environment variables using Jenkins credentials
        DOCKER_REGISTRY = credentials('momentum-server-docker-registry')
        GKE_CLUSTER = credentials('mom-core-gke-cluster')
        GKE_ZONE = credentials('gke-zone')
        GCP_PROJECT = credentials('gcp-project')
        GOOGLE_APPLICATION_CREDENTIALS = credentials('google-application-credentials')
    }
    stages {
        stage('Checkout') {
            steps {
                script {
                    // Determine environment based on branch
                    def branch = env.GIT_BRANCH
                    if (branch == "origin/temp") {
                        env.ENVIRONMENT = 'temp'
                    } else if (branch == "origin/main") {
                        env.ENVIRONMENT = 'main'
                    } else if (branch == "origin/devops") {
                        env.ENVIRONMENT = 'devops'
                    } else {
                        error("Unknown branch: ${branch}. This pipeline only supports the main, temp, and devops branches.")
                    }
                    checkout scm
                    // Capture the short Git commit hash to use as the image tag
                    env.GIT_COMMIT_HASH = sh(returnStdout: true, script: 'git rev-parse --short HEAD').trim()
                }
            }
        }
        stage('Configure Docker Authentication') {
            steps {
                script {
                    // Extract the registry's hostname for authentication
                    def registryHost = env.DOCKER_REGISTRY.tokenize('/')[0]
                    sh """
                        sudo gcloud auth configure-docker ${registryHost}
                    """
                }
            }
        }
        stage('Build Docker Image') {
            steps {
                script {
                    // Use the Git commit hash as the image tag
                    def imageTag = env.GIT_COMMIT_HASH
                    def dockerRegistry = env.DOCKER_REGISTRY
                    echo "Printing the saved docker registry from env:"
                    echo "${dockerRegistry}"
                    sh "sudo docker build -t ${DOCKER_REGISTRY}/momentum-server:${imageTag} ."
                }
            }
        }
        stage('Push Docker Image') {
            steps {
                script {
                    // Use the Git commit hash as the image tag
                    def imageTag = env.GIT_COMMIT_HASH
                    echo "printing the user here"
                    sh "whoami && pwd"
                    sh "sudo docker push ${DOCKER_REGISTRY}/momentum-server:${imageTag}"
                }
            }
        }
        stage('Configure GKE Authentication') {
            steps {
                script {
                    // Use the service account path from credentials
                    sh """
                        sudo gcloud auth activate-service-account --key-file=${GOOGLE_APPLICATION_CREDENTIALS}
                        sudo gcloud container clusters get-credentials ${GKE_CLUSTER} --zone ${GKE_ZONE} --project ${GCP_PROJECT}
                    """
                }
            }
        }
        stage('Ask User for Deployment Confirmation') {
            steps {
                script {
                    def deployConfirmation = input(
                        id: 'userInput',
                        message: 'Do you want to deploy the new Docker image?',
                        parameters: [
                            choice(name: 'Deploy', choices: ['Yes', 'No'], description: 'Select Yes to deploy the image or No to abort.')
                        ]
                    )
                    if (deployConfirmation == 'No') {
                        error('User chose not to deploy the images, stopping the pipeline.')
                    }
                }
            }
        }
        stage('Deploy Image') {
            steps {
                script {
                    def imageDeploySucceeded = false
                    def imageTag = env.GIT_COMMIT_HASH
                    echo "this is the fetched docker image tag: ${imageTag}"
                    try {
                        sh """
                            kubectl set image deployment/momentum-server-deployment momentum-server=${DOCKER_REGISTRY}/momentum-server:${imageTag} -n ${params.namespace}
                            kubectl rollout status deployment/momentum-server-deployment -n ${params.namespace}
                        """
                        imageDeploySucceeded = true
                    } catch (Exception e) {
                        echo "Deployment failed: ${e}"
                    }
                    if (!imageDeploySucceeded) {
                        echo 'Rolling back to previous revision...'
                        sh 'kubectl rollout undo deployment/momentum-server-deployment -n ${params.namespace}'
                    }
                }
            }
        }
        stage('Pipeline finished') {
            steps {
                script {
                    echo "Pipeline finished"
                    // Check the deployment status
                    sh """
                        echo "checking the deployment status" &&
                        kubectl get pods -n ${params.namespace}
                    """
                }
            }
        }
    }
    post {
        always {
            echo "Pipeline finished"
            // Optional cleanup action
            script {
                // Clean up local Docker images
                def imageTag = env.GIT_COMMIT_HASH
                sh """
                    docker rmi ${DOCKER_REGISTRY}/momentum-server:${imageTag} || true
                """
            }
        }
    }
}
```

## /LICENSE

``` path="/LICENSE"
                              Apache License
                        Version 2.0, January 2004
                     http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

   "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

   "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

   "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

   "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

   "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

   "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
   You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

   (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

   (b) You must cause any modified files to carry prominent notices stating that You changed the files; and

   (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

   (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

   You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability.
   In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

Copyright 2024 Momenta Softwares Inc.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
```

## /README.md

[Potpie AI logo]

App | Documentation | API Reference | Chat with 🥧 Repo

Apache 2.0 | GitHub Repo stars | Join our Discord | VS Code Extension

Prompt-To-Agent: Create custom engineering agents for your code

Potpie is an open-source platform that creates AI agents specialized in your codebase, enabling automated code analysis, testing, and development tasks. By building a comprehensive knowledge graph of your code, Potpie's agents can understand complex relationships and assist with everything from debugging to feature development.

## 📚 Table of Contents

- [🥧 Why Potpie?](#why-potpie)
- [🤖 Our Prebuilt Agents](#prebuilt-agents)
- [🛠️ Tooling](#potpies-tooling-system)
- [🚀 Getting Started](#getting-started)
- [💡 Use Cases](#use-cases)
- [🛠️ Custom Agents](#custom-agents-upgrade)
- [🗝️ Accessing Agents via API Key](#accessing-agents-via-api-key)
- [🎨 Make Potpie Your Own](#make-potpie-your-own)
- [🤝 Contributing](#contributing)
- [📜 License](#license)
- [💪 Contributors](#-thanks-to-all-contributors)

## 🥧 Why Potpie?

- 🧠 **Deep Code Understanding**: Built-in knowledge graph captures relationships between code components
- 🤖 **Pre-built & Custom Agents**: Ready-to-use agents for common tasks + build your own
- 🔄 **Seamless Integration**: Works with your existing development workflow
- 📈 **Flexible**: Handles codebases of any size or language

## 🔌 VSCode Extension

Bring the power of Potpie's AI agents directly into your development environment with our VSCode extension:

- **Direct Integration**: Access all Potpie agents without leaving your editor
- **Quick Setup**: Install directly from the [VSCode Marketplace](https://marketplace.visualstudio.com/items?itemName=PotpieAI.potpie-vscode-extension)
- **Seamless Workflow**: Ask questions, get explanations, and implement suggestions right where you code

## 🧩 Slack Integration

Bring your custom AI agents directly into your team's communication hub with our Slack integration:

- **Team Collaboration**: Access all Potpie agents where your team already communicates
- **Effortless Setup**: Install and configure in under 2 minutes. Check out the [Potpie docs](https://docs.potpie.ai/extensions/slack)
- **Contextual Assistance**: Get answers, code solutions, and project insights directly in your Slack threads

👉 Install the Potpie Slack App: [Here](https://slack.potpie.ai/slack/install)

## 🤖 Potpie's Prebuilt Agents

Potpie offers a suite of specialized codebase agents for automating and optimizing key aspects of software development:

- **Debugging Agent**: Automatically analyzes stacktraces and provides debugging steps specific to your codebase.
- **Codebase Q&A Agent**: Answers questions about your codebase and explains functions, features, and architecture.
- **Code Changes Agent**: Analyzes code changes, identifies affected APIs, and suggests improvements before merging.
- **Integration Test Agent**: Generates integration test plans and code for flows to ensure components work together properly.
- **Unit Test Agent**: Automatically creates unit test plans and code for individual functions to enhance test coverage.
- **LLD Agent**: Creates a low-level design for implementing a new feature from the functional requirements you provide.
- **Code Generation Agent**: Generates code for new features, refactors existing code, and suggests optimizations.

## 🛠️ Potpie's Tooling System

Potpie provides a set of tools that agents can use to interact with the knowledge graph and the underlying infrastructure:

- **get_code_from_probable_node_name**: Retrieves code snippets based on a probable node name.
- **get_code_from_node_id**: Fetches code associated with a specific node ID.
- **get_code_from_multiple_node_ids**: Retrieves code snippets for multiple node IDs simultaneously.
- **ask_knowledge_graph_queries**: Executes vector similarity searches to obtain relevant information.
- **get_nodes_from_tags**: Retrieves nodes tagged with specific keywords.
- **get_code_graph_from_node_id/name**: Fetches code graph structures for a specific node.
- **change_detection**: Detects changes in the current branch compared to the default branch.
- **get_code_file_structure**: Retrieves the file structure of the codebase.

## 🚀 Getting Started

### Prerequisites

- Docker installed and running
- Git installed (for repository access)
- Python 3.10.x

### Potpie UI

An easy-to-use interface to interact with your agents.

#### Initialize the UI Submodule

To initialize the submodule:

```bash
git submodule update --init
```

#### 1. Navigate to the `potpie-ui` Directory

```bash
cd potpie-ui
```

#### 2. Update the Main Branch and Checkout

```bash
git checkout main
git pull origin main
```

#### 3. Set Up the Environment

Create a `.env` file in the `potpie-ui` directory and copy the required configuration from `.env.template`.

```bash
cp .env.template .env
```

#### 4. Build the Frontend

```bash
pnpm build
```

#### 5. Start the Application

```bash
pnpm start
```

### Setup Steps

**Install Python 3.10** - Download and install Python 3.10 from the official Python website: https://www.python.org/downloads/release/python-3100/

1. **Prepare Your Environment**
   - Create a `.env` file based on the `.env.template`
   - Add the following required configurations:
     ```bash
     isDevelopmentMode=enabled
     ENV=development
     POSTGRES_SERVER=postgresql://postgres:mysecretpassword@localhost:5432/momentum
     NEO4J_URI=bolt://127.0.0.1:7687
     NEO4J_USERNAME=neo4j
     NEO4J_PASSWORD=mysecretpassword
     REDISHOST=127.0.0.1
     REDISPORT=6379
     BROKER_URL=redis://127.0.0.1:6379/0
     CELERY_QUEUE_NAME=dev
     defaultUsername=defaultuser
     PROJECT_PATH=projects # repositories will be downloaded/cloned to this path on your system
     {PROVIDER}_API_KEY=sk-proj-your-key # your provider key, e.g. ANTHROPIC_API_KEY for Anthropic
     INFERENCE_MODEL=ollama_chat/qwen2.5-coder:7b # provider model name
     CHAT_MODEL=ollama_chat/qwen2.5-coder:7b # provider model name
     ```

   **`INFERENCE_MODEL`** and **`CHAT_MODEL`** correspond to the models that will be used for generating the knowledge graph and for agent reasoning, respectively. These model names should be in the `provider/model_name` format expected by LiteLLM. For more information, refer to the [LiteLLM documentation](https://docs.litellm.ai/docs/providers).
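   For example, a hosted-provider setup might look like this (an illustration only, with a placeholder key and a sample Anthropic model id in LiteLLM's `provider/model_name` format; substitute your own provider's values):

   ```bash
   ANTHROPIC_API_KEY=sk-ant-your-key                     # the {PROVIDER}_API_KEY for Anthropic
   INFERENCE_MODEL=anthropic/claude-3-5-sonnet-20241022  # knowledge-graph generation
   CHAT_MODEL=anthropic/claude-3-5-sonnet-20241022       # agent reasoning
   ```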
   - Create a virtual environment using Python 3.10:
     ```bash
     python3.10 -m venv venv
     source venv/bin/activate
     ```
   - Install dependencies in your venv:
     ```bash
     pip install -r requirements.txt
     ```

2. **Start Potpie**

   To start all Potpie services:
   ```bash
   chmod +x start.sh
   ./start.sh
   ```
   **Windows**
   ```powershell
   ./start.ps1
   ```
   This will:
   - Start required Docker services
   - Wait for PostgreSQL to be ready
   - Apply database migrations
   - Start the FastAPI application
   - Start the Celery worker

3. **Stop Potpie**

   To stop all Potpie services:
   ```bash
   ./stop.sh
   ```
   **Windows**
   ```powershell
   ./stop.ps1
   ```
   This will gracefully stop:
   - The FastAPI application
   - The Celery worker
   - All Docker Compose services

4. **Authentication Setup** (Skip this step in development mode)
   ```bash
   curl -X POST 'http://localhost:8001/api/v1/login' \
     -H 'Content-Type: application/json' \
     -d '{
       "email": "your-email",
       "password": "your-password"
     }'
   # Save the bearer token from the response for subsequent requests
   ```

5. **Initialize Repository Parsing**
   ```bash
   # For development mode:
   curl -X POST 'http://localhost:8001/api/v1/parse' \
     -H 'Content-Type: application/json' \
     -d '{
       "repo_path": "path/to/local/repo",
       "branch_name": "main"
     }'

   # For production mode:
   curl -X POST 'http://localhost:8001/api/v1/parse' \
     -H 'Content-Type: application/json' \
     -d '{
       "repo_name": "owner/repo-name",
       "branch_name": "main"
     }'

   # Save the project_id from the response
   ```

6. **Monitor Parsing Status**
   ```bash
   curl -X GET 'http://localhost:8001/api/v1/parsing-status/your-project-id'
   # Wait until parsing is complete
   ```

7. **View Available Agents**
   ```bash
   curl -X GET 'http://localhost:8001/api/v1/list-available-agents/?list_system_agents=true'
   # Note down the agent_id you want to use
   ```

8. **Create a Conversation**
   ```bash
   curl -X POST 'http://localhost:8001/api/v1/conversations/' \
     -H 'Content-Type: application/json' \
     -d '{
       "user_id": "your_user_id",
       "title": "My First Conversation",
       "status": "active",
       "project_ids": ["your-project-id"],
       "agent_ids": ["chosen-agent-id"]
     }'
   # Save the conversation_id from the response
   ```

9. **Start Interacting with Your Agent**
   ```bash
   curl -X POST 'http://localhost:8001/api/v1/conversations/your-conversation-id/message/' \
     -H 'Content-Type: application/json' \
     -d '{
       "content": "Your question or request here",
       "node_ids": []
     }'
   ```

10. **View Conversation History** (Optional)
    ```bash
    curl -X GET 'http://localhost:8001/api/v1/conversations/your-conversation-id/messages/?start=0&limit=10'
    ```
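Chained together, steps 5-9 can be scripted end to end. A hedged sketch for development mode; it assumes `jq` is installed, and that the responses expose `project_id` and `conversation_id` fields and a `ready` parsing status as described above:

```bash
BASE=http://localhost:8001/api/v1

# Step 5: parse a local repository and capture the project id
PROJECT_ID=$(curl -s -X POST "$BASE/parse" -H 'Content-Type: application/json' \
  -d '{"repo_path": "path/to/local/repo", "branch_name": "main"}' | jq -r '.project_id')

# Step 6: poll until parsing reports a ready status
until curl -s "$BASE/parsing-status/$PROJECT_ID" | grep -q 'ready'; do sleep 5; done

# Step 8: create a conversation with a chosen agent
CONVERSATION_ID=$(curl -s -X POST "$BASE/conversations/" -H 'Content-Type: application/json' \
  -d "{\"user_id\": \"defaultuser\", \"title\": \"My First Conversation\", \"status\": \"active\",
       \"project_ids\": [\"$PROJECT_ID\"], \"agent_ids\": [\"chosen-agent-id\"]}" | jq -r '.conversation_id')

# Step 9: send a first message
curl -s -X POST "$BASE/conversations/$CONVERSATION_ID/message/" -H 'Content-Type: application/json' \
  -d '{"content": "Your question or request here", "node_ids": []}'
```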
## 💡 Use Cases

- **Onboarding**: For developers new to a codebase, the Codebase Q&A agent helps them understand the codebase and get up to speed quickly. Ask it how to set up a new project, how to run the tests, etc.
  > We onboarded ourselves with Potpie to the [**AgentOps**](https://github.com/AgentOps-AI/AgentOps) codebase and it worked like a charm: video [here](https://youtu.be/_mPixNDn2r8).
- **Codebase Understanding**: Answer questions about any library you're integrating; explain functions, features, and architecture.
  > We used the Q&A agent to understand the underlying working of a feature of the [**CrewAI**](https://github.com/CrewAIInc/CrewAI) codebase that was not documented in the official docs: video [here](https://www.linkedin.com/posts/dhirenmathur_what-do-you-do-when-youre-stuck-and-even-activity-7256704603977613312-8X8G).
- **Low Level Design**: Get detailed implementation plans for new features or improvements before writing code.
  > We fed an open issue from the [**Portkey-AI/Gateway**](https://github.com/Portkey-AI/Gateway) project to this agent to generate a low level design for it: video [here](https://www.linkedin.com/posts/dhirenmathur_potpie-ai-agents-vs-llms-i-am-extremely-activity-7255607456448286720-roOC).
- **Reviewing Code Changes**: Understand the functional impact of changes and compute the blast radius of modifications.
- **Debugging**: Get step-by-step debugging guidance based on stacktraces and codebase context.
- **Testing**: Generate contextually aware unit and integration test plans and test code that understand your codebase's structure and purpose.

## 🛠️ Custom Agents [Upgrade ✨](https://potpie.ai/pricing)

With Custom Agents, you can design personalized tools that handle repeatable tasks with precision. Key components include:

- **System Instructions**: Define the agent's task, goal, and expected output
- **Agent Information**: Metadata about the agent's role and context
- **Tasks**: Individual steps for job completion
- **Tools**: Functions for querying the knowledge graph or retrieving code

## 🗝️ Accessing Agents via API Key

You can access Potpie agents through an API key, enabling integration into CI/CD workflows and other automated processes. For detailed instructions, please refer to the [Potpie API documentation](https://docs.potpie.ai/agents/api-access).

- **Generate an API Key**: Easily create an API key for secure access.
- **Parse Repositories**: Use the Parse API to analyze code repositories and obtain a project ID.
- **Monitor Parsing Status**: Check the status of your parsing requests.
- **Create Conversations**: Initiate conversations with specific agents using project and agent IDs, and get a conversation ID.
- **Send Messages**: Communicate with agents by sending messages within a conversation.

## 🎨 Make Potpie Your Own

Potpie is designed to be flexible and customizable. Here are key areas to personalize your own deployment:

### Effortless Agent Creation

Design custom agents tailored to your specific tasks using a single prompt. Use the following API to create your custom agents:

```bash
curl -X POST "http://localhost:8001/api/v1/custom-agents/agents/auto" \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "An agent that takes a stacktrace as input and gives root cause analysis and a proposed solution as output"
  }'
```

Read more about other custom agent APIs to edit and delete your custom agents in our [documentation](https://docs.potpie.ai/open-source/agents/create-agent-from-prompt).

### Tool Integration

Edit or add tools in the `app/modules/intelligence/tools` directory for your custom agents. Initialize the tools in the `app/modules/intelligence/tools/tool_service.py` file and include them in your agent.

## 🤝 Contributing

We welcome contributions! To contribute:

1. Fork the repository
2. Create a new branch (`git checkout -b feature-branch`)
3. Make your changes
4. Commit (`git commit -m 'Add new feature'`)
5. Push to the branch (`git push origin feature-branch`)
6. Open a Pull Request

See the [Contributing Guide](./contributing.md) for more details.

## 📜 License

This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.

## 💪 Thanks To All Contributors

Thanks for spending your time helping build Potpie. Keep rocking 🥂

[Contributors]

## /alembic.ini

```ini path="/alembic.ini"
# A generic, single database configuration.
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = app/alembic

# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .

# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =

# max length of characters to apply to the "slug" field
# truncate_slug_length = 40

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false

# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false

# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions

# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os  # Use os.pathsep. Default configuration used for new projects.

# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false

# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8

sqlalchemy.url = driver://user:pass@localhost/dbname


[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts.
# See the documentation for further
# detail and examples

# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME

# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME

# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
```

## /app/alembic/README

``` path="/app/alembic/README"
Generic single-database configuration.
```

## /app/alembic/env.py

```py path="/app/alembic/env.py"
import os
import time
from logging.config import fileConfig

from alembic import context
from alembic.operations import ops
from dotenv import load_dotenv
from sqlalchemy import engine_from_config, pool

from app.core.base_model import Base
from app.core.models import *  # noqa

target_metadata = Base.metadata

# Load environment variables from .env
load_dotenv(override=True)

# Interpret the config file for Python logging.
fileConfig(context.config.config_file_name)

target_metadata = Base.metadata

# Construct the database URL from environment variables
POSTGRES_SERVER = os.getenv("POSTGRES_SERVER", "localhost")

# Set the SQLAlchemy URL dynamically from the constructed DATABASE_URL
config = context.config
config.set_main_option("sqlalchemy.url", POSTGRES_SERVER)


# Add your models' metadata object for 'autogenerate' support
def process_revision_directives(context, revision, directives):
    """Automatically prepend timestamp to migration filenames."""
    for directive in directives:
        if isinstance(directive, ops.MigrationScript):
            # Get the current timestamp
            timestamp = time.strftime("%Y%m%d%H%M%S")
            # Modify the revision ID to include the timestamp
            directive.rev_id = f"{timestamp}_{directive.rev_id}"


def run_migrations_online():
    connectable = engine_from_config(
        config.get_section(config.config_ini_section),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection,
            target_metadata=target_metadata,
            version_table="alembic_version",
            compare_type=True,
            process_revision_directives=process_revision_directives,  # Add the timestamp hook here
        )

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    raise Exception("Offline migrations not supported")
else:
    run_migrations_online()
```
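The `process_revision_directives` hook above is what produces the timestamp-prefixed revision ids visible in `app/alembic/versions/`. A hedged sketch of the round trip, assuming the virtualenv is active and `POSTGRES_SERVER` is set in `.env`:

```bash
# Autogenerate a revision; env.py prepends a timestamp to the revision id,
# yielding a filename like 20240812184546_6d16b920a3ec_initial_migration.py
alembic revision --autogenerate -m "describe your change"

# Apply all pending migrations
alembic upgrade head
```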
## /app/alembic/script.py.mako

```mako path="/app/alembic/script.py.mako"
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
```

## /app/alembic/versions/20240812184546_6d16b920a3ec_initial_migration.py

```py path="/app/alembic/versions/20240812184546_6d16b920a3ec_initial_migration.py"
"""Initial migration

Revision ID: 20240812184546_6d16b920a3ec
Revises:
Create Date: 2024-08-12 18:45:46.599604

"""
from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "20240812184546_6d16b920a3ec"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "users",
        sa.Column("uid", sa.String(length=255), nullable=False),
        sa.Column("email", sa.String(length=255), nullable=False),
        sa.Column("display_name", sa.String(length=255), nullable=True),
        sa.Column("email_verified", sa.Boolean(), nullable=True),
        sa.Column("created_at", sa.TIMESTAMP(), nullable=True),
        sa.Column("last_login_at", sa.TIMESTAMP(), nullable=True),
        sa.Column(
            "provider_info", postgresql.JSONB(astext_type=sa.Text()), nullable=True
        ),
        sa.Column("provider_username", sa.String(length=255), nullable=True),
        sa.PrimaryKeyConstraint("uid"),
        sa.UniqueConstraint("email"),
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table("users")
    # ### end Alembic commands ###
```

## /app/alembic/versions/20240812190934_5ceb460ac3ef_adding_support_for_projects.py

```py path="/app/alembic/versions/20240812190934_5ceb460ac3ef_adding_support_for_projects.py"
"""Adding support for projects

Revision ID: 20240812190934_5ceb460ac3ef
Revises: 20240812184546_6d16b920a3ec
Create Date: 2024-08-12 19:09:34.063355

"""
from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "20240812190934_5ceb460ac3ef"
down_revision: Union[str, None] = "20240812184546_6d16b920a3ec"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "projects",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("directory", sa.Text(), nullable=True),
        sa.Column("is_default", sa.Boolean(), nullable=True),
        sa.Column("project_name", sa.Text(), nullable=True),
        sa.Column("properties", postgresql.BYTEA(), nullable=True),
        sa.Column("repo_name", sa.Text(), nullable=True),
        sa.Column("branch_name", sa.Text(), nullable=True),
        sa.Column("user_id", sa.String(length=255), nullable=False),
        sa.Column("created_at", sa.TIMESTAMP(), nullable=True),
        sa.Column("commit_id", sa.String(length=255), nullable=True),
        sa.Column("is_deleted", sa.Boolean(), nullable=True),
        sa.Column("updated_at", sa.TIMESTAMP(), nullable=True),
        sa.Column("status", sa.String(length=255), nullable=True),
        sa.CheckConstraint(
            "status IN ('created', 'ready', 'error')", name="check_status"
        ),
        sa.ForeignKeyConstraint(["user_id"], ["users.uid"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("directory"),
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table("projects")
    # ### end Alembic commands ###
```

## /app/alembic/versions/20240812211350_bcc569077106_utc_timestamps_and_indexing.py

```py path="/app/alembic/versions/20240812211350_bcc569077106_utc_timestamps_and_indexing.py"
"""UTC Timestamps and Indexing

Revision ID: 20240812211350_bcc569077106
Revises: 20240812190934_5ceb460ac3ef
Create Date: 2024-08-12 21:13:50.136975

"""
from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "20240812211350_bcc569077106"
down_revision: Union[str, None] = "20240812190934_5ceb460ac3ef"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "conversations",
        sa.Column("id", sa.String(length=255), nullable=False),
        sa.Column("user_id", sa.String(length=255), nullable=False),
        sa.Column("title", sa.String(length=255), nullable=False),
        sa.Column(
            "status",
            sa.Enum("ACTIVE", "ARCHIVED", "DELETED", name="conversationstatus"),
            nullable=False,
        ),
        sa.Column("project_ids", postgresql.ARRAY(sa.String()), nullable=False),
        sa.Column("agent_ids", postgresql.ARRAY(sa.String()), nullable=False),
        sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False),
        sa.Column("updated_at", sa.TIMESTAMP(timezone=True), nullable=False),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["users.uid"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(op.f("ix_conversations_id"), "conversations", ["id"], unique=False)
    op.create_index(
        op.f("ix_conversations_user_id"), "conversations", ["user_id"], unique=False
    )
    op.create_table(
        "messages",
        sa.Column("id", sa.String(length=255), nullable=False),
        sa.Column("conversation_id", sa.String(length=255), nullable=False),
        sa.Column("content", sa.Text(), nullable=False),
        sa.Column("sender_id", sa.String(length=255), nullable=True),
        sa.Column(
            "type", sa.Enum("AI_GENERATED", "HUMAN", name="messagetype"), nullable=False
        ),
        sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False),
        sa.CheckConstraint(
            "(type = 'HUMAN' AND sender_id IS NOT NULL) OR (type = 'AI_GENERATED' AND sender_id IS NULL)",
            name="check_sender_id_for_type",
        ),
        sa.ForeignKeyConstraint(
            ["conversation_id"],
            ["conversations.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        op.f("ix_messages_conversation_id"),
        "messages",
        ["conversation_id"],
        unique=False,
    )
    op.alter_column(
        "projects",
        "created_at",
        existing_type=postgresql.TIMESTAMP(),
        type_=sa.TIMESTAMP(timezone=True),
        nullable=False,
    )
    op.alter_column(
        "users",
        "created_at",
        existing_type=postgresql.TIMESTAMP(),
        type_=sa.TIMESTAMP(timezone=True),
        nullable=False,
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.alter_column(
        "users",
        "created_at",
        existing_type=sa.TIMESTAMP(timezone=True),
        type_=postgresql.TIMESTAMP(),
        nullable=True,
    )
    op.alter_column(
        "projects",
        "created_at",
        existing_type=sa.TIMESTAMP(timezone=True),
        type_=postgresql.TIMESTAMP(),
        nullable=True,
    )
    op.drop_index(op.f("ix_messages_conversation_id"), table_name="messages")
    op.drop_table("messages")
    op.drop_index(op.f("ix_conversations_user_id"), table_name="conversations")
    op.drop_index(op.f("ix_conversations_id"), table_name="conversations")
    op.drop_table("conversations")
    # ### end Alembic commands ###
```

## /app/alembic/versions/20240813145447_56e7763c7d20_add_on_delete_cascade_to_message_.py

```py path="/app/alembic/versions/20240813145447_56e7763c7d20_add_on_delete_cascade_to_message_.py"
"""Add ON DELETE CASCADE to Message.conversation_id

Revision ID: 20240813145447_56e7763c7d20
Revises: 20240812211350_bcc569077106
Create Date: 2024-08-13 14:54:47.718210

"""
from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "20240813145447_56e7763c7d20"
down_revision: Union[str, None] = "20240812211350_bcc569077106"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_constraint("messages_conversation_id_fkey", "messages", type_="foreignkey")
    op.create_foreign_key(
        None,
        "messages",
        "conversations",
        ["conversation_id"],
        ["id"],
        ondelete="CASCADE",
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_constraint(None, "messages", type_="foreignkey")
    op.create_foreign_key(
        "messages_conversation_id_fkey",
        "messages",
        "conversations",
        ["conversation_id"],
        ["id"],
    )
    # ### end Alembic commands ###
```

## /app/alembic/versions/20240820182032_d3f532773223_changes_for_implementation_of_.py

```py path="/app/alembic/versions/20240820182032_d3f532773223_changes_for_implementation_of_.py"
"""Changes for implementation of conversations

Revision ID: 20240820182032_d3f532773223
Revises: 20240813145447_56e7763c7d20
Create Date: 2024-08-20 18:20:32.408674

"""
from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import ENUM

# revision identifiers, used by Alembic.
revision: str = "20240820182032_d3f532773223"
down_revision: Union[str, None] = "20240813145447_56e7763c7d20"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

message_status_enum = ENUM(
    "ACTIVE", "ARCHIVED", "DELETED", name="message_status_enum", create_type=False
)


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    # Add the SYSTEM_GENERATED value to the message type enum
    op.execute("ALTER TYPE messagetype ADD VALUE 'SYSTEM_GENERATED'")

    # Commit to ensure the new enum value is recognized
    op.execute("COMMIT")

    # Drop the old foreign key constraint and create a new one with ON DELETE CASCADE
    op.drop_constraint(
        "conversations_user_id_fkey", "conversations", type_="foreignkey"
    )
    op.create_foreign_key(
        None, "conversations", "users", ["user_id"], ["uid"], ondelete="CASCADE"
    )

    # Drop the agent_ids column from conversations table
    op.drop_column("conversations", "agent_ids")

    # Drop the existing check constraint for sender_id
    op.drop_constraint("check_sender_id_for_type", "messages", type_="check")

    # Create a new check constraint with the correct logic
    op.create_check_constraint(
        "check_sender_id_for_type",
        "messages",
        "((type = 'HUMAN'::messagetype AND sender_id IS NOT NULL) OR "
        "(type IN ('AI_GENERATED'::messagetype, 'SYSTEM_GENERATED'::messagetype) AND sender_id IS NULL))",
    )
    op.alter_column(
        "projects",
        "updated_at",
        existing_type=postgresql.TIMESTAMP(),
        type_=sa.TIMESTAMP(timezone=True),
        existing_nullable=True,
    )
    op.alter_column(
        "users",
        "last_login_at",
        existing_type=postgresql.TIMESTAMP(),
        type_=sa.TIMESTAMP(timezone=True),
        existing_nullable=True,
    )

    # Create ENUM type in the database
    message_status_enum.create(op.get_bind(), checkfirst=True)

    # Add new column using the ENUM type
    op.add_column(
        "messages",
        sa.Column(
            "status", message_status_enum, nullable=False, server_default="ACTIVE"
        ),
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    # Re-add the agent_ids column to conversations table.
    # NOTE: re-adding it as NOT NULL will fail if conversations still holds
    # rows; a server_default would be needed for a populated table.
    op.add_column(
        "conversations",
        sa.Column(
            "agent_ids",
            postgresql.ARRAY(sa.VARCHAR()),
            autoincrement=False,
            nullable=False,
        ),
    )

    # Drop the new foreign key constraint and re-create the old one
    op.drop_constraint(
        "conversations_user_id_fkey", "conversations", type_="foreignkey"
    )
    op.create_foreign_key(
        "conversations_user_id_fkey", "conversations", "users", ["user_id"], ["uid"]
    )

    # Drop the new check constraint and re-create the old one
    op.drop_constraint("check_sender_id_for_type", "messages", type_="check")
    op.create_check_constraint(
        "check_sender_id_for_type",
        "messages",
        "((type = 'HUMAN'::messagetype AND sender_id IS NOT NULL) OR "
        "(type = 'AI_GENERATED'::messagetype AND sender_id IS NULL))",
    )

    op.alter_column(
        "users",
        "last_login_at",
        existing_type=sa.TIMESTAMP(timezone=True),
        type_=postgresql.TIMESTAMP(),
        existing_nullable=True,
    )
    op.alter_column(
        "projects",
        "updated_at",
        existing_type=sa.TIMESTAMP(timezone=True),
        type_=postgresql.TIMESTAMP(),
        existing_nullable=True,
    )

    # Drop the column
    op.drop_column("messages", "status")

    # Drop the ENUM type if it is no longer used
    message_status_enum.drop(op.get_bind(), checkfirst=False)
    # ### end Alembic commands ###
```

## /app/alembic/versions/20240823164559_05069444feee_project_id_to_string_anddelete_col.py

```py path="/app/alembic/versions/20240823164559_05069444feee_project_id_to_string_anddelete_col.py"
"""project id to string and delete col

Revision ID: 20240823164559_05069444feee
Revises: 20240820182032_d3f532773223
Create Date: 2024-08-23 16:45:59.991109

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "20240823164559_05069444feee"
down_revision: Union[str, None] = "20240820182032_d3f532773223"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.alter_column(
        "projects",
        "id",
        existing_type=sa.INTEGER(),
        type_=sa.Text(),
        existing_nullable=False,
    )
    op.drop_constraint("projects_directory_key", "projects", type_="unique")
    op.drop_column("projects", "directory")
    op.drop_column("projects", "is_default")
    op.drop_column("projects", "project_name")
    op.drop_constraint("check_status", "projects", type_="check")
    op.create_check_constraint(
        "check_status",
        "projects",
        "status IN ('submitted', 'cloned', 'parsed', 'ready', 'error')",
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust!
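    # Caveat on the Text -> INTEGER reversal below (illustrative note, not
    # part of the original migration): PostgreSQL needs an explicit cast when
    # narrowing a column type, and the cast only succeeds while every
    # projects.id value is still numeric. A downgrade that spells this out:
    #
    #   op.alter_column(
    #       "projects",
    #       "id",
    #       existing_type=sa.Text(),
    #       type_=sa.INTEGER(),
    #       existing_nullable=False,
    #       postgresql_using="id::integer",
    #   )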
### op.add_column( "projects", sa.Column("project_name", sa.TEXT(), autoincrement=False, nullable=True), ) op.add_column( "projects", sa.Column("is_default", sa.BOOLEAN(), autoincrement=False, nullable=True), ) op.add_column( "projects", sa.Column("directory", sa.TEXT(), autoincrement=False, nullable=True), ) op.create_unique_constraint("projects_directory_key", "projects", ["directory"]) op.alter_column( "projects", "id", existing_type=sa.Text(), type_=sa.INTEGER(), existing_nullable=False, ) op.drop_constraint("check_status", "projects", type_="check") op.create_check_constraint( "check_status", "projects", "status IN ('created', 'ready', 'error')" ) # ### end Alembic commands ### ``` ## /app/alembic/versions/20240826215938_3c7be0985b17_search_index.py ```py path="/app/alembic/versions/20240826215938_3c7be0985b17_search_index.py" """search index Revision ID: 20240826215938_3c7be0985b17 Revises: 20240823164559_05069444feee Create Date: 2024-08-26 21:59:38.638095 """ from typing import Sequence, Union import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. revision: str = "20240826215938_3c7be0985b17" down_revision: Union[str, None] = "20240823164559_05069444feee" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.create_table( "search_indices", sa.Column("id", sa.Integer(), nullable=False), sa.Column("project_id", sa.Text(), nullable=True), sa.Column("node_id", sa.String(), nullable=True), sa.Column("name", sa.String(), nullable=True), sa.Column("file_path", sa.String(), nullable=True), sa.Column("content", sa.Text(), nullable=True), sa.ForeignKeyConstraint( ["project_id"], ["projects.id"], ), sa.PrimaryKeyConstraint("id"), ) op.create_index( op.f("ix_search_indices_file_path"), "search_indices", ["file_path"], unique=False, ) op.create_index( op.f("ix_search_indices_id"), "search_indices", ["id"], unique=False ) op.create_index( op.f("ix_search_indices_name"), "search_indices", ["name"], unique=False ) op.create_index( op.f("ix_search_indices_node_id"), "search_indices", ["node_id"], unique=False ) op.create_index( op.f("ix_search_indices_project_id"), "search_indices", ["project_id"], unique=False, ) op.create_table( "tasks", sa.Column("id", sa.Integer(), nullable=False), sa.Column( "task_type", sa.Enum( "CODEBASE_PROCESSING", "FILE_INFERENCE", "FLOWS_PROCESSING", name="tasktype", ), nullable=False, ), sa.Column("custom_status", sa.String(length=50), nullable=False), sa.Column("created_at", sa.DateTime(), nullable=True), sa.Column("updated_at", sa.DateTime(), nullable=True), sa.Column("project_id", sa.String(), nullable=False), sa.Column("result", sa.String(), nullable=True), sa.ForeignKeyConstraint( ["project_id"], ["projects.id"], ), sa.PrimaryKeyConstraint("id"), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! 
### op.drop_table("tasks") op.drop_index(op.f("ix_search_indices_project_id"), table_name="search_indices") op.drop_index(op.f("ix_search_indices_node_id"), table_name="search_indices") op.drop_index(op.f("ix_search_indices_name"), table_name="search_indices") op.drop_index(op.f("ix_search_indices_id"), table_name="search_indices") op.drop_index(op.f("ix_search_indices_file_path"), table_name="search_indices") op.drop_table("search_indices") # ### end Alembic commands ### ``` ## /app/alembic/versions/20240828094302_48240c0ce09e_add_agent_id_support_in_conversation_.py ```py path="/app/alembic/versions/20240828094302_48240c0ce09e_add_agent_id_support_in_conversation_.py" """Add agent id support in conversation table Revision ID: 20240828094302_48240c0ce09e Revises: 20240826215938_3c7be0985b17 Create Date: 2024-08-28 09:43:02.922148 """ from typing import Sequence, Union import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. revision: str = "20240828094302_48240c0ce09e" down_revision: Union[str, None] = "20240826215938_3c7be0985b17" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.add_column( "conversations", sa.Column("agent_ids", sa.ARRAY(sa.String()), nullable=False) ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.drop_column("conversations", "agent_ids") # ### end Alembic commands ### ``` ## /app/alembic/versions/20240902105155_6b44dc81d95d_prompt_tables.py ```py path="/app/alembic/versions/20240902105155_6b44dc81d95d_prompt_tables.py" """Prompt Tables Revision ID: 20240902105155_6b44dc81d95d Revises: 20240828094302_48240c0ce09e Create Date: 2024-09-02 10:51:55.205130 """ from typing import Sequence, Union import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. revision: str = "20240902105155_6b44dc81d95d" down_revision: Union[str, None] = "20240828094302_48240c0ce09e" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! 
### op.create_table( "prompts", sa.Column("id", sa.String(), nullable=False), sa.Column("text", sa.Text(), nullable=False), sa.Column( "type", sa.Enum("SYSTEM", "HUMAN", name="prompttype"), nullable=False ), sa.Column("version", sa.Integer(), nullable=False, server_default="1"), sa.Column( "status", sa.Enum("ACTIVE", "INACTIVE", name="promptstatustype"), nullable=False, server_default="ACTIVE", ), sa.Column("created_by", sa.String(), nullable=True), sa.Column( "created_at", sa.TIMESTAMP(timezone=True), server_default=sa.func.now(), nullable=False, ), sa.Column( "updated_at", sa.TIMESTAMP(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False, ), sa.ForeignKeyConstraint( ["created_by"], ["users.uid"], ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint( "text", "version", "created_by", name="unique_text_version_user" ), sa.CheckConstraint("version > 0", name="check_version_positive"), sa.CheckConstraint("created_at <= updated_at", name="check_timestamps"), ) op.create_table( "agent_prompt_mappings", sa.Column("id", sa.String(), nullable=False), sa.Column("agent_id", sa.String(), nullable=False), sa.Column("prompt_id", sa.String(), nullable=False), sa.Column("prompt_stage", sa.Integer(), nullable=False), sa.ForeignKeyConstraint(["prompt_id"], ["prompts.id"], ondelete="CASCADE"), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint( "agent_id", "prompt_stage", name="unique_agent_prompt_stage" ), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.drop_table("prompts") op.drop_table("agent_prompt_mappings") # ### end Alembic commands ### ``` ## /app/alembic/versions/20240905144257_342902c88262_add_user_preferences_table.py ```py path="/app/alembic/versions/20240905144257_342902c88262_add_user_preferences_table.py" """Add user preferences table Revision ID: 20240905144257_342902c88262 Revises: 20240902105155_6b44dc81d95d Create Date: 2024-09-05 14:42:57.885405 """ from typing import Sequence, Union import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. revision: str = "20240905144257_342902c88262" down_revision: Union[str, None] = "20240902105155_6b44dc81d95d" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.create_table( "user_preferences", sa.Column("user_id", sa.String(), nullable=False), sa.Column("preferences", sa.JSON(), nullable=False), sa.ForeignKeyConstraint( ["user_id"], ["users.uid"], ), sa.PrimaryKeyConstraint("user_id"), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.drop_table("user_preferences") # ### end Alembic commands ### ``` ## /app/alembic/versions/20240927094023_fb0b353e69d0_support_for_citations_in_backend.py ```py path="/app/alembic/versions/20240927094023_fb0b353e69d0_support_for_citations_in_backend.py" """Support for citations in backend Revision ID: 20240927094023_fb0b353e69d0 Revises: 20240905144257_342902c88262 Create Date: 2024-09-27 09:40:23.874379 """ from typing import Sequence, Union import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. 
revision: str = "20240927094023_fb0b353e69d0" down_revision: Union[str, None] = "20240905144257_342902c88262" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.add_column("messages", sa.Column("citations", sa.Text(), nullable=True)) op.create_index( "idx_user_preferences_user_id", "user_preferences", ["user_id"], unique=False ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.drop_index("idx_user_preferences_user_id", table_name="user_preferences") op.drop_column("messages", "citations") # ### end Alembic commands ### ``` ## /app/alembic/versions/20241003153813_827623103002_add_shared_with_email_to_the_.py ```py path="/app/alembic/versions/20241003153813_827623103002_add_shared_with_email_to_the_.py" """add shared with email to the conversation model Revision ID: 20241003153813_827623103002 Revises: 20240927094023_fb0b353e69d0 Create Date: 2024-10-03 15:38:13.436502 """ from typing import Sequence, Union import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. revision: str = "20241003153813_827623103002" down_revision: Union[str, None] = "20240927094023_fb0b353e69d0" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.add_column( "conversations", sa.Column("shared_with_emails", sa.ARRAY(sa.String()), nullable=True), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.drop_column("conversations", "shared_with_emails") # ### end Alembic commands ### ``` ## /app/alembic/versions/20241020111943_262d870e9686_custom_agents.py ```py path="/app/alembic/versions/20241020111943_262d870e9686_custom_agents.py" """custom_agents Revision ID: 20241020111943_262d870e9686 Revises: Create Date: 2024-10-20 11:19:43.653649 """ from typing import Sequence, Union import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
revision: str = "20241020111943_262d870e9686" down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None # Add this line branch_labels = ("custom_agents_microservice",) def upgrade() -> None: if not op.get_bind().dialect.has_table(op.get_bind(), "custom_agents"): op.create_table( "custom_agents", sa.Column("id", sa.String(), nullable=False), sa.Column("user_id", sa.String(), nullable=True), sa.Column("role", sa.String(), nullable=True), sa.Column("goal", sa.String(), nullable=True), sa.Column("backstory", sa.String(), nullable=True), sa.Column("system_prompt", sa.String(), nullable=True), sa.Column("tasks", postgresql.JSONB(astext_type=sa.Text()), nullable=True), sa.Column("deployment_url", sa.String(), nullable=True), sa.Column("deployment_status", sa.String(), nullable=False), sa.Column("created_at", sa.DateTime(), nullable=False), sa.Column("updated_at", sa.DateTime(), nullable=False), sa.PrimaryKeyConstraint("id"), ) op.create_index( op.f("ix_custom_agents_id"), "custom_agents", ["id"], unique=False ) op.create_index( op.f("ix_custom_agents_user_id"), "custom_agents", ["user_id"], unique=False ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.drop_index(op.f("ix_custom_agents_user_id"), table_name="custom_agents") op.drop_index(op.f("ix_custom_agents_id"), table_name="custom_agents") op.drop_table("custom_agents") # ### end Alembic commands ### ``` ## /app/alembic/versions/20241028204107_684a330f9e9f_new_migration.py ```py path="/app/alembic/versions/20241028204107_684a330f9e9f_new_migration.py" """New migration Revision ID: 20241028204107_684a330f9e9f Revises: 20241003153813_827623103002 Create Date: 2024-10-28 20:41:07.469748 """ from typing import Sequence, Union import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. revision: str = "20241028204107_684a330f9e9f" down_revision: Union[str, None] = "20241003153813_827623103002" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.execute("CREATE TYPE visibility AS ENUM ('PRIVATE', 'PUBLIC')") op.add_column( "conversations", sa.Column( "visibility", sa.Enum("PRIVATE", "PUBLIC", name="visibility"), nullable=True ), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.drop_column("conversations", "visibility") op.execute("DROP TYPE visibility") # ### end Alembic commands ### ``` ## /app/alembic/versions/20241127095409_625f792419e7_support_for_repo_path.py ```py path="/app/alembic/versions/20241127095409_625f792419e7_support_for_repo_path.py" """Support for repo path Revision ID: 20241127095409_625f792419e7 Revises: 20241028204107_684a330f9e9f Create Date: 2024-11-27 09:54:09.683918 """ from typing import Sequence, Union from alembic import op # revision identifiers, used by Alembic. 
revision: str = "20241127095409_625f792419e7" down_revision: Union[str, None] = "20241028204107_684a330f9e9f" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: op.execute("ALTER TABLE projects ADD COLUMN repo_path TEXT DEFAULT NULL") def downgrade() -> None: op.drop_column("projects", "repo_path") ``` ## /app/alembic/versions/20250303164854_414f9ab20475_custom_agent_sharing.py ```py path="/app/alembic/versions/20250303164854_414f9ab20475_custom_agent_sharing.py" """custom_agent_sharing Revision ID: 20250303164854_414f9ab20475 Revises: 82eb6e97aed3 Create Date: 2025-03-03 16:48:54.711260 """ from typing import Sequence, Union from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision: str = "20250303164854_414f9ab20475" down_revision: Union[str, None] = "82eb6e97aed3" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.add_column( "custom_agents", sa.Column("visibility", sa.String(), nullable=False, server_default="private"), ) op.alter_column( "custom_agents", "user_id", existing_type=sa.VARCHAR(), nullable=False ) op.alter_column( "custom_agents", "deployment_status", existing_type=sa.VARCHAR(), nullable=True ) op.alter_column( "custom_agents", "created_at", existing_type=postgresql.TIMESTAMP(), type_=sa.DateTime(timezone=True), existing_nullable=False, ) op.alter_column( "custom_agents", "updated_at", existing_type=postgresql.TIMESTAMP(), type_=sa.DateTime(timezone=True), existing_nullable=False, ) op.drop_index("ix_custom_agents_id", table_name="custom_agents") op.drop_index("ix_custom_agents_user_id", table_name="custom_agents") op.create_foreign_key( "fk_custom_agents_user_id", "custom_agents", "users", ["user_id"], ["uid"] ) # ### end Alembic commands ### def downgrade() -> None: op.drop_constraint("fk_custom_agents_user_id", "custom_agents", type_="foreignkey") op.create_index( "ix_custom_agents_user_id", "custom_agents", ["user_id"], unique=False ) op.create_index("ix_custom_agents_id", "custom_agents", ["id"], unique=False) op.alter_column( "custom_agents", "updated_at", existing_type=sa.DateTime(timezone=True), type_=postgresql.TIMESTAMP(), existing_nullable=False, ) op.alter_column( "custom_agents", "created_at", existing_type=sa.DateTime(timezone=True), type_=postgresql.TIMESTAMP(), existing_nullable=False, ) op.alter_column( "custom_agents", "deployment_status", existing_type=sa.VARCHAR(), nullable=False ) op.alter_column( "custom_agents", "user_id", existing_type=sa.VARCHAR(), nullable=True ) op.drop_column("custom_agents", "visibility") # ### end Alembic commands ### ``` ## /app/alembic/versions/20250310201406_97a740b07a50_custom_agent_sharing.py ```py path="/app/alembic/versions/20250310201406_97a740b07a50_custom_agent_sharing.py" """custom_agent_sharing Revision ID: 20250310201406_97a740b07a50 Revises: 20250303164854_414f9ab20475 Create Date: 2025-03-10 20:14:06.456798 """ from typing import Sequence, Union from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision: str = "20250310201406_97a740b07a50" down_revision: Union[str, None] = "20250303164854_414f9ab20475" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.create_table( "custom_agent_shares", sa.Column("id", sa.String(), nullable=False), sa.Column("agent_id", sa.String(), nullable=False), sa.Column("shared_with_user_id", sa.String(), nullable=False), sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), sa.ForeignKeyConstraint(["agent_id"], ["custom_agents.id"], ondelete="CASCADE"), sa.ForeignKeyConstraint( ["shared_with_user_id"], ["users.uid"], ondelete="CASCADE" ), sa.PrimaryKeyConstraint("id"), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.drop_table("custom_agent_shares") # ### end Alembic commands ### ``` ## /app/alembic/versions/82eb6e97aed3_merge_heads.py ```py path="/app/alembic/versions/82eb6e97aed3_merge_heads.py" """merge heads Revision ID: 82eb6e97aed3 Revises: 20241020111943_262d870e9686, 20241127095409_625f792419e7 Create Date: 2025-03-03 16:48:42.230151 """ from typing import Sequence, Union # revision identifiers, used by Alembic. revision: str = "82eb6e97aed3" down_revision: Union[str, None] = ( "20241020111943_262d870e9686", "20241127095409_625f792419e7", ) branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: pass def downgrade() -> None: pass ``` ## /app/api/router.py ```py path="/app/api/router.py" from datetime import datetime import os from typing import List, Optional from fastapi import Depends, Header, HTTPException, Query from pydantic import BaseModel from sqlalchemy.orm import Session from app.core.database import get_db from app.modules.auth.api_key_service import APIKeyService from app.modules.conversations.conversation.conversation_controller import ( ConversationController, ) from app.modules.conversations.conversation.conversation_schema import ( ConversationStatus, CreateConversationRequest, CreateConversationResponse, ) from app.modules.conversations.message.message_schema import ( DirectMessageRequest, MessageRequest, ) from app.modules.intelligence.agents.agents_controller import AgentsController from app.modules.intelligence.prompts.prompt_service import PromptService from app.modules.intelligence.provider.provider_service import ProviderService from app.modules.intelligence.tools.tool_service import ToolService from app.modules.parsing.graph_construction.parsing_controller import ParsingController from app.modules.parsing.graph_construction.parsing_schema import ParsingRequest from app.modules.projects.projects_controller import ProjectController from app.modules.users.user_service import UserService from app.modules.utils.APIRouter import APIRouter from app.modules.usage.usage_service import UsageService from app.modules.search.search_service import SearchService from app.modules.search.search_schema import SearchRequest, SearchResponse router = APIRouter() class SimpleConversationRequest(BaseModel): project_ids: List[str] agent_ids: List[str] async def get_api_key_user( x_api_key: Optional[str] = Header(None), x_user_id: Optional[str] = Header(None), db: Session = Depends(get_db), ) -> dict: """Dependency to validate API key and get user info.""" if not x_api_key: raise HTTPException( status_code=401, detail="API key is required", 
headers={"WWW-Authenticate": "ApiKey"}, ) if x_api_key == os.environ.get("INTERNAL_ADMIN_SECRET"): user = UserService(db).get_user_by_uid(x_user_id or "") if not user: raise HTTPException( status_code=401, detail="Invalid user_id", headers={"WWW-Authenticate": "ApiKey"}, ) return {"user_id": user.uid, "email": user.email, "auth_type": "api_key"} user = await APIKeyService.validate_api_key(x_api_key, db) if not user: raise HTTPException( status_code=401, detail="Invalid API key", headers={"WWW-Authenticate": "ApiKey"}, ) return user @router.post("/conversations/", response_model=CreateConversationResponse) async def create_conversation( conversation: SimpleConversationRequest, hidden: bool = Query( True, description="Whether to hide this conversation from the web UI" ), db: Session = Depends(get_db), user=Depends(get_api_key_user), ): user_id = user["user_id"] # This will either return True or raise an HTTPException await UsageService.check_usage_limit(user_id) # Create full conversation request with defaults full_request = CreateConversationRequest( user_id=user_id, title=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), status=ConversationStatus.ACTIVE, # Let hidden parameter control the final status project_ids=conversation.project_ids, agent_ids=conversation.agent_ids, ) controller = ConversationController(db, user_id, None) return await controller.create_conversation(full_request, hidden) @router.post("/parse") async def parse_directory( repo_details: ParsingRequest, db: Session = Depends(get_db), user=Depends(get_api_key_user), ): return await ParsingController.parse_directory(repo_details, db, user) @router.get("/parsing-status/{project_id}") async def get_parsing_status( project_id: str, db: Session = Depends(get_db), user=Depends(get_api_key_user), ): return await ParsingController.fetch_parsing_status(project_id, db, user) @router.post("/conversations/{conversation_id}/message/") async def post_message( conversation_id: str, message: MessageRequest, db: Session = Depends(get_db), user=Depends(get_api_key_user), ): if message.content == "" or message.content is None or message.content.isspace(): raise HTTPException(status_code=400, detail="Message content cannot be empty") user_id = user["user_id"] checked = await UsageService.check_usage_limit(user_id) if not checked: raise HTTPException( status_code=402, detail="Subscription required to create a conversation.", ) # Note: email is no longer available with API key auth controller = ConversationController(db, user_id, None) message_stream = controller.post_message(conversation_id, message, stream=False) async for chunk in message_stream: return chunk @router.post("/project/{project_id}/message/") async def create_conversation_and_message( project_id: str, message: DirectMessageRequest, hidden: bool = Query( True, description="Whether to hide this conversation from the web UI" ), db: Session = Depends(get_db), user=Depends(get_api_key_user), ): if message.content == "" or message.content is None or message.content.isspace(): raise HTTPException(status_code=400, detail="Message content cannot be empty") user_id = user["user_id"] # default agent_id to codebase_qna_agent if message.agent_id is None: message.agent_id = "codebase_qna_agent" controller = ConversationController(db, user_id, None) # Create conversation with hidden parameter res = await controller.create_conversation( CreateConversationRequest( user_id=user_id, title=message.content, project_ids=[project_id], agent_ids=[message.agent_id], status=ConversationStatus.ACTIVE, # Let 
hidden parameter control the final status ), hidden, ) message_stream = controller.post_message( conversation_id=res.conversation_id, message=MessageRequest(content=message.content, node_ids=message.node_ids), stream=False, ) async for chunk in message_stream: return chunk @router.get("/projects/list") async def list_projects( db: Session = Depends(get_db), user=Depends(get_api_key_user), ): return await ProjectController.get_project_list(user, db) @router.get("/list-available-agents") async def list_agents( db: Session = Depends(get_db), user=Depends(get_api_key_user), ): user_id: str = user["user_id"] llm_provider = ProviderService(db, user_id) tools_provider = ToolService(db, user_id) prompt_provider = PromptService(db) controller = AgentsController(db, llm_provider, prompt_provider, tools_provider) return await controller.list_available_agents(user, True) @router.post("/search", response_model=SearchResponse) async def search_codebase( search_request: SearchRequest, db: Session = Depends(get_db), user=Depends(get_api_key_user), ): """Search codebase using API key authentication""" search_service = SearchService(db) results = await search_service.search_codebase( search_request.project_id, search_request.query ) return SearchResponse(results=results) ``` ## /app/celery/celery_app.py ```py path="/app/celery/celery_app.py" import logging import os from celery import Celery from dotenv import load_dotenv from app.core.models import * # noqa #This will import and initialize all models # Load environment variables from a .env file if present load_dotenv() # Redis configuration redishost = os.getenv("REDISHOST", "localhost") redisport = int(os.getenv("REDISPORT", 6379)) redisuser = os.getenv("REDISUSER", "") redispassword = os.getenv("REDISPASSWORD", "") queue_name = os.getenv("CELERY_QUEUE_NAME", "staging") # Construct the Redis URL if redisuser and redispassword: redis_url = f"redis://{redisuser}:{redispassword}@{redishost}:{redisport}/0" else: redis_url = f"redis://{redishost}:{redisport}/0" # Initialize the Celery app celery_app = Celery("KnowledgeGraph", broker=redis_url, backend=redis_url) # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Add logging for Redis connection logger.info(f"Connecting to Redis at: {redis_url}") try: celery_app.backend.client.ping() logger.info("Successfully connected to Redis") except Exception as e: logger.error(f"Failed to connect to Redis: {str(e)}") def configure_celery(queue_prefix: str): celery_app.conf.update( task_serializer="json", accept_content=["json"], result_serializer="json", timezone="UTC", enable_utc=True, task_routes={ "app.celery.tasks.parsing_tasks.process_parsing": { "queue": f"{queue_prefix}_process_repository" }, }, # Optimize task distribution worker_prefetch_multiplier=1, task_acks_late=True, task_track_started=True, task_time_limit=5400, # 90 minutes in seconds # Add fair task distribution settings worker_max_tasks_per_child=200, # Restart worker after 200 tasks to prevent memory leaks worker_max_memory_per_child=2000000, # Restart worker if using more than 2GB task_default_rate_limit="10/m", # Limit tasks to 10 per minute per worker task_reject_on_worker_lost=True, # Requeue tasks if worker dies broker_transport_options={ "visibility_timeout": 5400 }, # 45 minutes visibility timeout ) configure_celery(queue_name) # Import the lock decorator # Import the lock decorator from celery.contrib.abortable import AbortableTask # noqa # Import tasks to ensure they are registered import 
app.celery.tasks.parsing_tasks # noqa # Ensure the task module is imported ``` ## /app/celery/tasks/parsing_tasks.py ```py path="/app/celery/tasks/parsing_tasks.py" import asyncio import logging from typing import Any, Dict from celery import Task from app.celery.celery_app import celery_app from app.core.database import SessionLocal from app.modules.parsing.graph_construction.parsing_schema import ParsingRequest from app.modules.parsing.graph_construction.parsing_service import ParsingService logger = logging.getLogger(__name__) class BaseTask(Task): _db = None @property def db(self): if self._db is None: self._db = SessionLocal() return self._db def after_return(self, *args, **kwargs): if self._db is not None: self._db.close() self._db = None @celery_app.task( bind=True, base=BaseTask, name="app.celery.tasks.parsing_tasks.process_parsing", ) def process_parsing( self, repo_details: Dict[str, Any], user_id: str, user_email: str, project_id: str, cleanup_graph: bool = True, ) -> None: logger.info(f"Task received: Starting parsing process for project {project_id}") try: parsing_service = ParsingService(self.db, user_id) async def run_parsing(): import time start_time = time.time() await parsing_service.parse_directory( ParsingRequest(**repo_details), user_id, user_email, project_id, cleanup_graph, ) end_time = time.time() elapsed_time = end_time - start_time logger.info( f"Parsing process took {elapsed_time:.2f} seconds for project {project_id}" ) asyncio.run(run_parsing()) logger.info(f"Parsing process completed for project {project_id}") except Exception as e: logger.error(f"Error during parsing for project {project_id}: {str(e)}") raise logger.info("Parsing tasks module loaded") ``` ## /app/celery/worker.py ```py path="/app/celery/worker.py" # Import the module containing the task from app.celery.celery_app import celery_app, logger from app.celery.tasks.parsing_tasks import ( process_parsing, # Ensure the task is imported ) # Register tasks def register_tasks(): logger.info("Registering tasks") # Register parsing tasks celery_app.tasks.register(process_parsing) # If there are more tasks in other modules, register them here # For example: # from app.celery.tasks import other_tasks # celery_app.tasks.register(other_tasks.some_other_task) logger.info("Tasks registered successfully") # Call register_tasks() immediately register_tasks() logger.info("Celery worker initialization completed") if __name__ == "__main__": logger.info("Starting Celery worker") celery_app.start() ``` ## /app/core/base_model.py ```py path="/app/core/base_model.py" import typing as t from sqlalchemy.ext.declarative import as_declarative, declared_attr class_registry: t.Dict = {} @as_declarative(class_registry=class_registry) class Base: id: t.Any __name__: str # Generate __tablename__ automatically @declared_attr def __tablename__(cls) -> str: return cls.__name__.lower() ``` ## /app/core/config_provider.py ```py path="/app/core/config_provider.py" import os from dotenv import load_dotenv load_dotenv() class ConfigProvider: def __init__(self): self.neo4j_config = { "uri": os.getenv("NEO4J_URI"), "username": os.getenv("NEO4J_USERNAME"), "password": os.getenv("NEO4J_PASSWORD"), } self.github_key = os.getenv("GITHUB_PRIVATE_KEY") self.is_development_mode = os.getenv("isDevelopmentMode", "disabled") def get_neo4j_config(self): return self.neo4j_config def get_github_key(self): return self.github_key def get_demo_repo_list(self): return [ { "id": "demo8", "name": "langchain", "full_name": "langchain-ai/langchain", 
"private": False, "url": "https://github.com/langchain-ai/langchain", "owner": "langchain-ai", }, { "id": "demo6", "name": "cal.com", "full_name": "calcom/cal.com", "private": False, "url": "https://github.com/calcom/cal.com", "owner": "calcom", }, { "id": "demo5", "name": "formbricks", "full_name": "formbricks/formbricks", "private": False, "url": "https://github.com/formbricks/formbricks", "owner": "formbricks", }, { "id": "demo3", "name": "gateway", "full_name": "Portkey-AI/gateway", "private": False, "url": "https://github.com/Portkey-AI/gateway", "owner": "Portkey-AI", }, { "id": "demo2", "name": "crewAI", "full_name": "crewAIInc/crewAI", "private": False, "url": "https://github.com/crewAIInc/crewAI", "owner": "crewAIInc", }, { "id": "demo1", "name": "agentops", "full_name": "AgentOps-AI/agentops", "private": False, "url": "https://github.com/AgentOps-AI/agentops", "owner": "AgentOps-AI", }, { "id": "demo0", "name": "agentstack", "full_name": "AgentOps-AI/AgentStack", "private": False, "url": "https://github.com/AgentOps-AI/AgentStack", "owner": "AgentOps-AI", }, ] def get_redis_url(self): redishost = os.getenv("REDISHOST", "localhost") redisport = int(os.getenv("REDISPORT", 6379)) redisuser = os.getenv("REDISUSER", "") redispassword = os.getenv("REDISPASSWORD", "") # Construct the Redis URL if redisuser and redispassword: redis_url = f"redis://{redisuser}:{redispassword}@{redishost}:{redisport}/0" else: redis_url = f"redis://{redishost}:{redisport}/0" return redis_url def get_is_development_mode(self): return self.is_development_mode == "enabled" config_provider = ConfigProvider() ``` ## /app/core/database.py ```py path="/app/core/database.py" import os from dotenv import load_dotenv from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker load_dotenv(override=True) # Create engine with connection pooling and best practices engine = create_engine( os.getenv("POSTGRES_SERVER"), pool_size=10, # Initial number of connections in the pool max_overflow=10, # Maximum number of connections beyond pool_size pool_timeout=30, # Timeout in seconds for getting a connection from the pool pool_recycle=1800, # Recycle connections every 30 minutes (to avoid stale connections) pool_pre_ping=True, # Check the connection is alive before using it echo=False, # Set to True for SQL query logging, False in production ) # Create session factory SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) # Base class for all ORM models Base = declarative_base() # Dependency to be used in routes def get_db(): db = SessionLocal() try: yield db finally: db.close() ``` ## /app/core/models.py ```py path="/app/core/models.py" from app.modules.conversations.conversation.conversation_model import ( # noqa Conversation, ) from app.modules.conversations.message.message_model import Message # noqa from app.modules.intelligence.prompts.prompt_model import ( # noqa AgentPromptMapping, Prompt, ) from app.modules.intelligence.agents.custom_agents.custom_agent_model import ( # noqa CustomAgent, ) from app.modules.projects.projects_model import Project # noqa from app.modules.search.search_models import SearchIndex # noqa from app.modules.tasks.task_model import Task # noqa from app.modules.users.user_model import User # noqa from app.modules.users.user_preferences_model import UserPreferences # noqa ``` ## /app/main.py ```py path="/app/main.py" import logging import os import subprocess import sentry_sdk from dotenv import load_dotenv 
from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from app.api.router import router as potpie_api_router from app.core.base_model import Base from app.core.database import SessionLocal, engine from app.core.models import * # noqa #necessary for models to not give import errors from app.modules.auth.auth_router import auth_router from app.modules.code_provider.github.github_router import router as github_router from app.modules.conversations.conversations_router import ( router as conversations_router, ) from app.modules.intelligence.agents.agents_router import router as agent_router from app.modules.intelligence.prompts.prompt_router import router as prompt_router from app.modules.intelligence.prompts.system_prompt_setup import SystemPromptSetup from app.modules.intelligence.provider.provider_router import router as provider_router from app.modules.intelligence.tools.tool_router import router as tool_router from app.modules.key_management.secret_manager import router as secret_manager_router from app.modules.parsing.graph_construction.parsing_router import ( router as parsing_router, ) from app.modules.projects.projects_router import router as projects_router from app.modules.search.search_router import router as search_router from app.modules.usage.usage_router import router as usage_router from app.modules.users.user_router import router as user_router from app.modules.users.user_service import UserService from app.modules.utils.firebase_setup import FirebaseSetup logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) class MainApp: def __init__(self): load_dotenv(override=True) if ( os.getenv("isDevelopmentMode") == "enabled" and os.getenv("ENV") != "development" ): logging.error( "Development mode enabled but ENV is not set to development. Exiting." ) exit(1) self.setup_sentry() self.app = FastAPI() self.setup_cors() self.initialize_database() self.setup_data() self.include_routers() def setup_sentry(self): if os.getenv("ENV") == "production": sentry_sdk.init( dsn=os.getenv("SENTRY_DSN"), traces_sample_rate=0.25, profiles_sample_rate=1.0, ) def setup_cors(self): origins = ["*"] self.app.add_middleware( CORSMiddleware, allow_origins=origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) def setup_data(self): if os.getenv("isDevelopmentMode") == "enabled": logging.info("Development mode enabled. 
Skipping Firebase setup.") # Setup dummy user for development mode db = SessionLocal() user_service = UserService(db) user_service.setup_dummy_user() db.close() logging.info("Dummy user created") else: FirebaseSetup.firebase_init() def initialize_database(self): # Initialize database tables Base.metadata.create_all(bind=engine) def include_routers(self): self.app.include_router(auth_router, prefix="/api/v1", tags=["Auth"]) self.app.include_router(user_router, prefix="/api/v1", tags=["User"]) self.app.include_router(parsing_router, prefix="/api/v1", tags=["Parsing"]) self.app.include_router( conversations_router, prefix="/api/v1", tags=["Conversations"] ) self.app.include_router(prompt_router, prefix="/api/v1", tags=["Prompts"]) self.app.include_router(projects_router, prefix="/api/v1", tags=["Projects"]) self.app.include_router(search_router, prefix="/api/v1", tags=["Search"]) self.app.include_router(github_router, prefix="/api/v1", tags=["Github"]) self.app.include_router(agent_router, prefix="/api/v1", tags=["Agents"]) self.app.include_router(provider_router, prefix="/api/v1", tags=["Providers"]) self.app.include_router(tool_router, prefix="/api/v1", tags=["Tools"]) self.app.include_router(usage_router, prefix="/api/v1/usage", tags=["Usage"]) self.app.include_router( potpie_api_router, prefix="/api/v2", tags=["Potpie API"] ) self.app.include_router( secret_manager_router, prefix="/api/v1", tags=["Secret Manager"] ) def add_health_check(self): @self.app.get("/health", tags=["Health"]) def health_check(): return { "status": "ok", "version": subprocess.check_output( ["git", "rev-parse", "--short", "HEAD"] ) .strip() .decode("utf-8"), } async def startup_event(self): db = SessionLocal() try: system_prompt_setup = SystemPromptSetup(db) await system_prompt_setup.initialize_system_prompts() logging.info("System prompts initialized successfully") except Exception as e: logging.error(f"Failed to initialize system prompts: {str(e)}") raise finally: db.close() def run(self): self.add_health_check() self.app.add_event_handler("startup", self.startup_event) return self.app # Create an instance of MainApp and run it main_app = MainApp() app = main_app.run() ``` ## /app/modules/auth/api_key_service.py ```py path="/app/modules/auth/api_key_service.py" import hashlib import os import secrets from typing import Optional from fastapi import HTTPException from google.cloud import secretmanager from sqlalchemy import text from sqlalchemy.orm import Session from app.modules.users.user_model import User from app.modules.users.user_preferences_model import UserPreferences import logging logger = logging.getLogger(__name__) class APIKeyService: SECRET_PREFIX = "sk-" KEY_LENGTH = 32 @staticmethod def get_client_and_project(): """Get Secret Manager client and project ID based on environment.""" is_dev_mode = os.getenv("isDevelopmentMode", "enabled") == "enabled" if is_dev_mode: return None, None project_id = os.environ.get("GCP_PROJECT") if not project_id: raise HTTPException( status_code=500, detail="GCP_PROJECT environment variable is not set" ) try: client = secretmanager.SecretManagerServiceClient() return client, project_id except Exception as e: raise HTTPException( status_code=500, detail=f"Failed to initialize Secret Manager client: {str(e)}", ) @staticmethod def generate_api_key() -> str: """Generate a new API key with prefix.""" random_key = secrets.token_hex(APIKeyService.KEY_LENGTH) return f"{APIKeyService.SECRET_PREFIX}{random_key}" @staticmethod def hash_api_key(api_key: str) -> str: """Hash the API 
key for storage and comparison.""" return hashlib.sha256(api_key.encode()).hexdigest() @staticmethod async def create_api_key(user_id: str, db: Session) -> str: """Create a new API key for a user.""" api_key = APIKeyService.generate_api_key() hashed_key = APIKeyService.hash_api_key(api_key) # Store hashed key in user preferences user_pref = ( db.query(UserPreferences).filter(UserPreferences.user_id == user_id).first() ) if not user_pref: user_pref = UserPreferences(user_id=user_id, preferences={}) db.add(user_pref) if "api_key_hash" not in user_pref.preferences: pref = user_pref.preferences.copy() pref["api_key_hash"] = hashed_key user_pref.preferences = pref db.commit() db.refresh(user_pref) # Store actual key in Secret Manager if os.getenv("isDevelopmentMode") != "enabled": client, project_id = APIKeyService.get_client_and_project() secret_id = f"user-api-key-{user_id}" parent = f"projects/{project_id}" try: # Create secret secret = {"replication": {"automatic": {}}} response = client.create_secret( request={"parent": parent, "secret_id": secret_id, "secret": secret} ) # Add secret version version = {"payload": {"data": api_key.encode("UTF-8")}} client.add_secret_version( request={"parent": response.name, "payload": version["payload"]} ) except Exception as e: # Rollback database changes if secret manager fails if "api_key_hash" in user_pref.preferences: del user_pref.preferences["api_key_hash"] db.commit() raise HTTPException( status_code=500, detail=f"Failed to store API key: {str(e)}" ) return api_key @staticmethod async def validate_api_key(api_key: str, db: Session) -> Optional[dict]: """Validate an API key and return user info if valid.""" try: # Check if API key follows the correct syntax and prefix if not api_key.startswith(APIKeyService.SECRET_PREFIX): logger.error( f"Invalid API key format: missing required prefix '{APIKeyService.SECRET_PREFIX}'" ) return None hashed_key = APIKeyService.hash_api_key(api_key) # Find user with matching hashed key result = ( db.query(UserPreferences, User.email) .join(User, UserPreferences.user_id == User.uid) .filter(text("preferences->>'api_key_hash' = :hashed_key")) .params(hashed_key=hashed_key) .first() ) # No match found for Hashed API key if not result: logger.error("No user found with the provided API key hash") return None user_pref, email = result return { "user_id": user_pref.user_id, "email": email, "auth_type": "api_key", } except Exception as e: logger.error(f"Error validating API key: {str(e)}") return None @staticmethod async def revoke_api_key(user_id: str, db: Session) -> bool: """Revoke a user's API key.""" user_pref = ( db.query(UserPreferences).filter(UserPreferences.user_id == user_id).first() ) if not user_pref: return False if "api_key_hash" in user_pref.preferences: # Create a new dictionary without the api_key_hash updated_preferences = user_pref.preferences.copy() del updated_preferences["api_key_hash"] user_pref.preferences = updated_preferences db.commit() # Delete from Secret Manager if not in dev mode if os.getenv("isDevelopmentMode") != "enabled": client, project_id = APIKeyService.get_client_and_project() secret_id = f"user-api-key-{user_id}" name = f"projects/{project_id}/secrets/{secret_id}" try: client.delete_secret(request={"name": name}) except Exception: pass # Ignore if secret doesn't exist return True @staticmethod async def get_api_key(user_id: str, db: Session) -> Optional[str]: """Retrieve the existing API key for a user.""" user_pref = ( db.query(UserPreferences).filter(UserPreferences.user_id == 
user_id).first() ) if not user_pref or "api_key_hash" not in user_pref.preferences: return None if os.getenv("isDevelopmentMode") == "enabled": return None # In dev mode, we can't retrieve the actual key for security client, project_id = APIKeyService.get_client_and_project() secret_id = f"user-api-key-{user_id}" name = f"projects/{project_id}/secrets/{secret_id}/versions/latest" try: response = client.access_secret_version(request={"name": name}) return response.payload.data.decode("UTF-8") except Exception as e: raise HTTPException( status_code=500, detail=f"Failed to retrieve API key: {str(e)}" ) ``` ## /app/modules/auth/auth_router.py ```py path="/app/modules/auth/auth_router.py" import json import os from datetime import datetime import requests from dotenv import load_dotenv from fastapi import Depends, Request from fastapi.responses import JSONResponse, Response from fastapi.exceptions import HTTPException from sqlalchemy.orm import Session from app.core.database import get_db from app.modules.auth.auth_schema import LoginRequest from app.modules.auth.auth_service import auth_handler from app.modules.users.user_schema import CreateUser from app.modules.users.user_service import UserService from app.modules.utils.APIRouter import APIRouter from app.modules.utils.posthog_helper import PostHogClient SLACK_WEBHOOK_URL = os.getenv("SLACK_WEBHOOK_URL", None) auth_router = APIRouter() load_dotenv(override=True) async def send_slack_message(message: str): payload = {"text": message} if SLACK_WEBHOOK_URL: requests.post(SLACK_WEBHOOK_URL, json=payload) class AuthAPI: @auth_router.post("/login") async def login(login_request: LoginRequest): email, password = login_request.email, login_request.password try: res = auth_handler.login(email=email, password=password) id_token = res.get("idToken") return JSONResponse(content={"token": id_token}, status_code=200) except ValueError: return JSONResponse( content={"error": "Invalid email or password"}, status_code=401 ) except HTTPException as he: return JSONResponse( content={"error": f"HTTP Error: {str(he)}"}, status_code=he.status_code ) except Exception as e: return JSONResponse(content={"error": f"ERROR: {str(e)}"}, status_code=400) @auth_router.post("/signup") async def signup(request: Request, db: Session = Depends(get_db)): body = json.loads(await request.body()) uid = body["uid"] oauth_token = body["accessToken"] user_service = UserService(db) user = user_service.get_user_by_uid(uid) if user: message, error = user_service.update_last_login(uid, oauth_token) if error: return Response(content=message, status_code=400) else: return Response( content=json.dumps({"uid": uid, "exists": True}), status_code=200, ) else: first_login = datetime.utcnow() provider_info = body["providerData"][0] provider_info["access_token"] = oauth_token user = CreateUser( uid=uid, email=body["email"], display_name=body["displayName"], email_verified=body["emailVerified"], created_at=first_login, last_login_at=first_login, provider_info=provider_info, provider_username=body["providerUsername"], ) uid, message, error = user_service.create_user(user) await send_slack_message( f"New signup: {body['email']} ({body['displayName']})" ) PostHogClient().send_event( uid, "signup_event", { "email": body["email"], "display_name": body["displayName"], "github_username": body["providerUsername"], }, ) if error: return Response(content=message, status_code=400) return Response( content=json.dumps({"uid": uid, "exists": False}), status_code=201, ) ``` ## /app/modules/auth/auth_schema.py 
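A minimal sketch of how this schema is exercised by the `/api/v1/login` route in `auth_router.py` above. The base URL and credentials are assumptions for illustration; the endpoint path, request body, and `{"token": ...}` response shape come from the router code:

```py
import requests

# POST credentials matching the LoginRequest schema; a successful login
# returns {"token": "<firebase-id-token>"} per AuthAPI.login above.
resp = requests.post(
    "http://localhost:8000/api/v1/login",
    json={"email": "user@example.com", "password": "secret"},
)
resp.raise_for_status()
token = resp.json()["token"]

# The token is then sent as a Bearer header to authenticated endpoints.
headers = {"Authorization": f"Bearer {token}"}
```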
```py path="/app/modules/auth/auth_schema.py" from pydantic import BaseModel class LoginRequest(BaseModel): email: str password: str ``` ## /app/modules/auth/auth_service.py ```py path="/app/modules/auth/auth_service.py" import logging import os import requests from dotenv import load_dotenv from fastapi import Depends, HTTPException, Request, Response, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from firebase_admin import auth from firebase_admin.exceptions import FirebaseError load_dotenv(override=True) class AuthService: def login(self, email, password): log_prefix = "AuthService::login:" identity_tool_kit_id = os.getenv("GOOGLE_IDENTITY_TOOL_KIT_KEY") identity_url = f"https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={identity_tool_kit_id}" user_auth_response = requests.post( url=identity_url, json={ "email": email, "password": password, "returnSecureToken": True, }, ) try: user_auth_response.raise_for_status() return user_auth_response.json() except Exception as e: logging.exception(f"{log_prefix} {str(e)}") raise Exception(user_auth_response.json()) def signup(self, email: str, password: str, name: str) -> tuple: try: user = auth.create_user(email=email, password=password, display_name=name) return {"user": user, "message": "New user created successfully"}, None except FirebaseError as fe: return None, {"error": f"Firebase error: {fe.message}"} except ValueError as _ve: return None, {"error": "Invalid input data provided."} except Exception as e: return None, {"error": f"An unexpected error occurred: {str(e)}"} @classmethod @staticmethod async def check_auth( request: Request, res: Response, credential: HTTPAuthorizationCredentials = Depends( HTTPBearer(auto_error=False) ), ): # Check if the application is in debug mode if os.getenv("isDevelopmentMode") == "enabled" and credential is None: request.state.user = {"user_id": os.getenv("defaultUsername")} logging.info("Development mode enabled. Using Mock Authentication.") return { "user_id": os.getenv("defaultUsername"), "email": "defaultuser@potpie.ai", } else: if credential is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Bearer authentication is needed", headers={"WWW-Authenticate": 'Bearer realm="auth_required"'}, ) try: decoded_token = auth.verify_id_token(credential.credentials) request.state.user = decoded_token except Exception as err: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=f"Invalid authentication from Firebase. 
{err}", headers={"WWW-Authenticate": 'Bearer error="invalid_token"'}, ) res.headers["WWW-Authenticate"] = 'Bearer realm="auth_required"' return decoded_token auth_handler = AuthService() ``` ## /app/modules/auth/tests/auth_service_test.py ```py path="/app/modules/auth/tests/auth_service_test.py" import os import pytest import requests from unittest.mock import patch, MagicMock from fastapi import HTTPException from app.modules.auth.auth_service import AuthService class TestAuthService: """Test cases for AuthService class.""" @pytest.fixture def auth_service(self): """Create an instance of AuthService for testing.""" return AuthService() @pytest.fixture def mock_env_vars(self): """Mock environment variables for testing.""" with patch.dict( os.environ, { "GOOGLE_IDENTITY_TOOL_KIT_KEY": "test_key", "isDevelopmentMode": "disabled", "defaultUsername": "test_user", }, ): yield @pytest.fixture def mock_successful_login_response(self): """Mock successful login response.""" mock_response = MagicMock() mock_response.json.return_value = { "idToken": "test_token", "email": "test@example.com", "localId": "test_user_id", } mock_response.raise_for_status.return_value = None return mock_response @pytest.fixture def mock_failed_login_response(self): """Mock failed login response.""" mock_response = MagicMock() mock_response.json.return_value = {"error": {"message": "INVALID_PASSWORD"}} mock_response.raise_for_status.side_effect = Exception(mock_response.json()) return mock_response class TestLogin: """Test cases for login functionality.""" def test_login_success( self, auth_service, mock_env_vars, mock_successful_login_response ): """Test successful login with valid credentials.""" with patch("requests.post", return_value=mock_successful_login_response): result = auth_service.login("test@example.com", "valid_password") assert result["idToken"] == "test_token" assert result["email"] == "test@example.com" def test_login_invalid_credentials( self, auth_service, mock_env_vars, mock_failed_login_response ): """Test login with invalid credentials.""" with patch("requests.post", return_value=mock_failed_login_response): with pytest.raises(Exception) as exc_info: auth_service.login("test@example.com", "invalid_password") assert "INVALID_PASSWORD" in str(exc_info.value) def test_login_empty_credentials(self, auth_service, mock_env_vars): """Test login with empty credentials.""" mock_response = MagicMock() mock_response.json.return_value = {"error": {"message": "MISSING_EMAIL"}} mock_response.raise_for_status.side_effect = Exception(mock_response.json()) with patch("requests.post", return_value=mock_response): with pytest.raises(Exception) as exc_info: auth_service.login("", "") assert "MISSING_EMAIL" in str(exc_info.value) def test_login_network_error(self, auth_service, mock_env_vars): """Test login with network connection error.""" with patch( "requests.post", side_effect=requests.exceptions.ConnectionError("Failed to connect"), ): with pytest.raises(Exception) as exc_info: auth_service.login("test@example.com", "password") assert isinstance(exc_info.value, requests.exceptions.ConnectionError) def test_login_timeout_error(self, auth_service, mock_env_vars): """Test login with request timeout.""" with patch( "requests.post", side_effect=requests.exceptions.Timeout("Request timed out"), ): with pytest.raises(Exception) as exc_info: auth_service.login("test@example.com", "password") assert isinstance(exc_info.value, requests.exceptions.Timeout) class TestSignup: """Test cases for signup functionality.""" 
@pytest.fixture def mock_user(self): """Mock user object for successful signup.""" mock_user = MagicMock() mock_user.uid = "test_user_id" mock_user.email = "test@example.com" mock_user.display_name = "Test User" return mock_user def test_signup_success(self, auth_service, mock_user): """Test successful user signup.""" with patch("firebase_admin.auth.create_user", return_value=mock_user): result = auth_service.signup( "test@example.com", "password123", "Test User" ) assert result.uid == "test_user_id" assert result.email == "test@example.com" assert result.display_name == "Test User" def test_signup_duplicate_email(self, auth_service): """Test signup with duplicate email.""" with patch( "firebase_admin.auth.create_user", side_effect=Exception("Email already exists"), ): with pytest.raises(Exception) as exc_info: auth_service.signup( "existing@example.com", "password123", "Test User" ) assert "Email already exists" in str(exc_info.value) def test_signup_invalid_email_format(self, auth_service): """Test signup with invalid email format.""" with patch( "firebase_admin.auth.create_user", side_effect=Exception("INVALID_EMAIL"), ): with pytest.raises(Exception) as exc_info: auth_service.signup("invalid-email", "password123", "Test User") assert "INVALID_EMAIL" in str(exc_info.value) def test_signup_weak_password(self, auth_service): """Test signup with weak password.""" with patch( "firebase_admin.auth.create_user", side_effect=Exception("WEAK_PASSWORD"), ): with pytest.raises(Exception) as exc_info: auth_service.signup("test@example.com", "123", "Test User") assert "WEAK_PASSWORD" in str(exc_info.value) def test_signup_empty_display_name(self, auth_service): """Test signup with empty display name.""" with patch( "firebase_admin.auth.create_user", side_effect=Exception("INVALID_DISPLAY_NAME"), ): with pytest.raises(Exception) as exc_info: auth_service.signup("test@example.com", "password123", "") assert "INVALID_DISPLAY_NAME" in str(exc_info.value) class TestAuthCheck: """Test cases for authentication check functionality.""" @pytest.fixture def mock_request_response(self): """Mock request and response objects.""" mock_request = MagicMock() mock_response = MagicMock() return mock_request, mock_response @pytest.mark.asyncio async def test_check_auth_valid_token( self, auth_service, mock_env_vars, mock_request_response ): """Test authentication check with valid token.""" mock_request, mock_response = mock_request_response mock_credential = MagicMock() mock_credential.credentials = "valid_token" mock_decoded_token = {"uid": "test_user_id", "email": "test@example.com"} with patch( "firebase_admin.auth.verify_id_token", return_value=mock_decoded_token ): result = await auth_service.check_auth( mock_request, mock_response, mock_credential ) assert result == mock_decoded_token assert mock_request.state.user == mock_decoded_token @pytest.mark.asyncio async def test_check_auth_invalid_token( self, auth_service, mock_env_vars, mock_request_response ): """Test authentication check with invalid token.""" mock_request, mock_response = mock_request_response mock_credential = MagicMock() mock_credential.credentials = "invalid_token" with patch( "firebase_admin.auth.verify_id_token", side_effect=Exception("Invalid token"), ): with pytest.raises(HTTPException) as exc_info: await auth_service.check_auth( mock_request, mock_response, mock_credential ) assert exc_info.value.status_code == 401 assert "Invalid authentication from Firebase" in str( exc_info.value.detail ) @pytest.mark.asyncio async def 
test_check_auth_missing_token( self, auth_service, mock_env_vars, mock_request_response ): """Test authentication check with missing token.""" mock_request, mock_response = mock_request_response with pytest.raises(HTTPException) as exc_info: await auth_service.check_auth(mock_request, mock_response, None) assert exc_info.value.status_code == 401 assert "Bearer authentication is needed" in str(exc_info.value.detail) @pytest.mark.asyncio async def test_check_auth_development_mode( self, auth_service, mock_request_response ): """Test authentication check in development mode.""" mock_request, mock_response = mock_request_response with patch.dict( os.environ, {"isDevelopmentMode": "enabled", "defaultUsername": "dev_user"}, ): result = await auth_service.check_auth( mock_request, mock_response, None ) assert result["user_id"] == "dev_user" assert result["email"] == "defaultuser@potpie.ai" assert mock_request.state.user == {"user_id": "dev_user"} @pytest.mark.asyncio async def test_check_auth_expired_token( self, auth_service, mock_env_vars, mock_request_response ): """Test authentication check with expired token.""" mock_request, mock_response = mock_request_response mock_credential = MagicMock() mock_credential.credentials = "expired_token" with patch( "firebase_admin.auth.verify_id_token", side_effect=Exception("Token has expired"), ): with pytest.raises(HTTPException) as exc_info: await auth_service.check_auth( mock_request, mock_response, mock_credential ) assert exc_info.value.status_code == 401 assert "Invalid authentication from Firebase" in str( exc_info.value.detail ) @pytest.mark.asyncio async def test_check_auth_malformed_token( self, auth_service, mock_env_vars, mock_request_response ): """Test authentication check with malformed token.""" mock_request, mock_response = mock_request_response mock_credential = MagicMock() mock_credential.credentials = "malformed_token" with patch( "firebase_admin.auth.verify_id_token", side_effect=Exception("Malformed token"), ): with pytest.raises(HTTPException) as exc_info: await auth_service.check_auth( mock_request, mock_response, mock_credential ) assert exc_info.value.status_code == 401 assert "Invalid authentication from Firebase" in str( exc_info.value.detail ) ``` ## /app/modules/code_provider/code_provider_service.py ```py path="/app/modules/code_provider/code_provider_service.py" import os from typing import Optional from app.modules.code_provider.github.github_service import GithubService from app.modules.code_provider.local_repo.local_repo_service import LocalRepoService class CodeProviderService: def __init__(self, sql_db): self.sql_db = sql_db self.service_instance = self._get_service_instance() def _get_service_instance(self): if os.getenv("isDevelopmentMode") == "enabled": return LocalRepoService(self.sql_db) else: return GithubService(self.sql_db) def get_repo(self, repo_name): return self.service_instance.get_repo(repo_name) async def get_project_structure_async(self, project_id, path: Optional[str] = None): return await self.service_instance.get_project_structure_async(project_id, path) def get_file_content( self, repo_name, file_path, start_line, end_line, branch_name, project_id ): return self.service_instance.get_file_content( repo_name, file_path, start_line, end_line, branch_name, project_id ) ``` ## /app/modules/code_provider/github/github_controller.py ```py path="/app/modules/code_provider/github/github_controller.py" from fastapi import HTTPException from sqlalchemy.orm import Session from 
app.modules.code_provider.github.github_service import GithubService class GithubController: def __init__(self, db: Session): self.github_service = GithubService(db) async def get_user_repos(self, user): user_id = user["user_id"] return await self.github_service.get_combined_user_repos(user_id) async def get_branch_list(self, repo_name: str): return await self.github_service.get_branch_list(repo_name) async def check_public_repo(self, repo_name: str): is_public = await self.github_service.check_public_repo(repo_name) if not is_public: raise HTTPException(status_code=403, detail="Repository is not public or does not exist") return {"is_public": is_public} ``` ## /app/modules/code_provider/github/github_router.py ```py path="/app/modules/code_provider/github/github_router.py" from fastapi import Depends, Query from sqlalchemy.orm import Session from app.core.config_provider import config_provider from app.core.database import get_db from app.modules.auth.auth_service import AuthService from app.modules.code_provider.github.github_controller import GithubController from app.modules.utils.APIRouter import APIRouter router = APIRouter() @router.get("/github/user-repos") async def get_user_repos( user=Depends(AuthService.check_auth), db: Session = Depends(get_db) ): user_repo_list = await GithubController(db).get_user_repos(user=user) if not config_provider.get_is_development_mode(): user_repo_list["repositories"].extend(config_provider.get_demo_repo_list()) # Remove duplicates while preserving order seen = set() deduped_repos = [] for repo in reversed(user_repo_list["repositories"]): # Use the repo's full_name as the dedupe key repo_key = repo["full_name"] if repo_key not in seen: seen.add(repo_key) deduped_repos.append(repo) user_repo_list["repositories"] = deduped_repos return user_repo_list @router.get("/github/get-branch-list") async def get_branch_list( repo_name: str = Query(..., description="Repository name"), user=Depends(AuthService.check_auth), db: Session = Depends(get_db), ): return await GithubController(db).get_branch_list(repo_name=repo_name) @router.get("/github/check-public-repo") async def check_public_repo( repo_name: str = Query(..., description="Repository name"), user=Depends(AuthService.check_auth), db: Session = Depends(get_db), ): return await GithubController(db).check_public_repo(repo_name=repo_name) ``` ## /app/modules/code_provider/github/github_service.py ```py path="/app/modules/code_provider/github/github_service.py" import asyncio import logging import os import random import re from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict, List, Optional, Tuple import aiohttp import chardet import git import requests from fastapi import HTTPException from github import Github from github.Auth import AppAuth from sqlalchemy import func from sqlalchemy.orm import Session from redis import Redis from app.core.config_provider import config_provider from app.modules.projects.projects_model import Project from app.modules.projects.projects_service import ProjectService from app.modules.users.user_model import User logger = logging.getLogger(__name__) class GithubService: gh_token_list: List[str] = [] @classmethod def initialize_tokens(cls): token_string = os.getenv("GH_TOKEN_LIST", "") cls.gh_token_list = [ token.strip() for token in token_string.split(",") if token.strip() ] if not cls.gh_token_list: raise ValueError( "GitHub token list is empty or not set in environment variables" ) logger.info(f"Initialized {len(cls.gh_token_list)} GitHub tokens") def __init__(self,
db: Session): self.db = db self.project_manager = ProjectService(db) if not GithubService.gh_token_list: GithubService.initialize_tokens() self.redis = Redis.from_url(config_provider.get_redis_url()) self.max_workers = 10 self.max_depth = 4 self.executor = ThreadPoolExecutor(max_workers=self.max_workers) self.is_development_mode = config_provider.get_is_development_mode() def get_github_repo_details(self, repo_name: str) -> Tuple[Github, Dict, str]: private_key = ( "-----BEGIN RSA PRIVATE KEY-----\n" + config_provider.get_github_key() + "\n-----END RSA PRIVATE KEY-----\n" ) app_id = os.environ["GITHUB_APP_ID"] auth = AppAuth(app_id=app_id, private_key=private_key) jwt = auth.create_jwt() owner = repo_name.split("/")[0] url = f"https://api.github.com/repos/{owner}/{repo_name.split('/')[1]}/installation" headers = { "Accept": "application/vnd.github+json", "Authorization": f"Bearer {jwt}", "X-GitHub-Api-Version": "2022-11-28", } response = requests.get(url, headers=headers) if response.status_code != 200: raise HTTPException( status_code=400, detail=f"Failed to get installation ID for {repo_name}" ) app_auth = auth.get_installation_auth(response.json()["id"]) github = Github(auth=app_auth) return github, response.json(), owner def get_github_app_client(self, repo_name: str) -> Github: try: # Try authenticated access first private_key = ( "-----BEGIN RSA PRIVATE KEY-----\n" + config_provider.get_github_key() + "\n-----END RSA PRIVATE KEY-----\n" ) app_id = os.environ["GITHUB_APP_ID"] auth = AppAuth(app_id=app_id, private_key=private_key) jwt = auth.create_jwt() # Get installation ID url = f"https://api.github.com/repos/{repo_name}/installation" headers = { "Accept": "application/vnd.github+json", "Authorization": f"Bearer {jwt}", "X-GitHub-Api-Version": "2022-11-28", } response = requests.get(url, headers=headers) if response.status_code != 200: raise Exception(f"Failed to get installation ID for {repo_name}") app_auth = auth.get_installation_auth(response.json()["id"]) return Github(auth=app_auth) except Exception as private_error: logging.info(f"Failed to access private repo: {str(private_error)}") # If authenticated access fails, try public access try: return self.get_public_github_instance() except Exception as public_error: logging.error(f"Failed to access public repo: {str(public_error)}") raise Exception( f"Repository {repo_name} not found or inaccessible on GitHub" ) def get_file_content( self, repo_name: str, file_path: str, start_line: int, end_line: int, branch_name: str, project_id: str, ) -> str: logger.info(f"Attempting to access file: {file_path} in repo: {repo_name}") try: # Try authenticated access first github, repo = self.get_repo(repo_name) file_contents = repo.get_contents(file_path, ref=branch_name) except Exception as private_error: logger.info(f"Failed to access private repo: {str(private_error)}") # If authenticated access fails, try public access try: github = self.get_public_github_instance() repo = github.get_repo(repo_name) file_contents = repo.get_contents(file_path) except Exception as public_error: logger.error(f"Failed to access public repo: {str(public_error)}") raise HTTPException( status_code=404, detail=f"Repository or file not found or inaccessible: {repo_name}/{file_path}", ) if isinstance(file_contents, list): raise HTTPException( status_code=400, detail="Provided path is a directory, not a file" ) try: content_bytes = file_contents.decoded_content encoding = self._detect_encoding(content_bytes) decoded_content = content_bytes.decode(encoding) lines = 
decoded_content.splitlines() if (start_line == end_line == 0) or (start_line is None and end_line is None): return decoded_content # pull start back two lines so the function definition/decorator line is included start = start_line - 2 if start_line - 2 > 0 else 0 selected_lines = lines[start:end_line] return "\n".join(selected_lines) except Exception as e: logger.error( f"Error processing file content for {repo_name}/{file_path}: {e}", exc_info=True, ) raise HTTPException( status_code=500, detail=f"Error processing file content: {str(e)}", ) @staticmethod def _detect_encoding(content_bytes: bytes) -> str: detection = chardet.detect(content_bytes) encoding = detection["encoding"] confidence = detection["confidence"] if not encoding or confidence < 0.5: raise HTTPException( status_code=400, detail="Unable to determine file encoding or low confidence", ) return encoding def get_github_oauth_token(self, uid: str) -> str: user = self.db.query(User).filter(User.uid == uid).first() if user is None: raise HTTPException(status_code=404, detail="User not found") return user.provider_info["access_token"] def _parse_link_header(self, link_header: str) -> Dict[str, str]: """Parse GitHub Link header to extract pagination URLs.""" links = {} if not link_header: return links for link in link_header.split(","): parts = link.strip().split(";") if len(parts) < 2: continue url = parts[0].strip()[1:-1] # Remove < and > for p in parts[1:]: if "rel=" in p: rel = p.strip().split("=")[1].strip('"') links[rel] = url break return links async def get_repos_for_user(self, user_id: str): if self.is_development_mode: return {"repositories": []} import time start_time = time.time() # Start timing the entire method try: user = self.db.query(User).filter(User.uid == user_id).first() if user is None: raise HTTPException(status_code=404, detail="User not found") firebase_uid = user.uid github_username = user.provider_username if not github_username: raise HTTPException( status_code=400, detail="GitHub username not found for this user" ) github_oauth_token = self.get_github_oauth_token(firebase_uid) if not github_oauth_token: raise HTTPException( status_code=400, detail="GitHub OAuth token not found for this user" ) user_github = Github(github_oauth_token) user_orgs = user_github.get_user().get_orgs() org_logins = [org.login.lower() for org in user_orgs] private_key = ( "-----BEGIN RSA PRIVATE KEY-----\n" + config_provider.get_github_key() + "\n-----END RSA PRIVATE KEY-----\n" ) app_id = os.environ["GITHUB_APP_ID"] auth = AppAuth(app_id=app_id, private_key=private_key) jwt = auth.create_jwt() all_installations = [] base_url = "https://api.github.com/app/installations" headers = { "Accept": "application/vnd.github+json", "Authorization": f"Bearer {jwt}", "X-GitHub-Api-Version": "2022-11-28", } async with aiohttp.ClientSession() as session: # Get first page to determine total pages async with session.get( f"{base_url}?per_page=100", headers=headers ) as response: if response.status != 200: error_text = await response.text() logger.error( f"Failed to get installations.
Response: {error_text}" ) raise HTTPException( status_code=response.status, detail=f"Failed to get installations: {error_text}", ) # Extract last page number from Link header last_page = 1 if "Link" in response.headers: links = self._parse_link_header(response.headers["Link"]) if "last" in links: last_url = links["last"] match = re.search(r"[?&]page=(\d+)", last_url) if match: last_page = int(match.group(1)) first_page_data = await response.json() all_installations.extend(first_page_data) # Generate remaining page URLs (skip page 1) page_urls = [ f"{base_url}?page={page}&per_page=100" for page in range(2, last_page + 1) ] # Process URLs in batches of 10 async def fetch_page(url): try: async with session.get(url, headers=headers) as response: if response.status == 200: installations = await response.json() return installations else: error_text = await response.text() logger.error( f"Failed to fetch page {url}. Response: {error_text}" ) return [] except Exception as e: logger.error(f"Error fetching page {url}: {str(e)}") return [] # Process URLs in batches of 10 for i in range(0, len(page_urls), 10): batch = page_urls[i : i + 10] batch_tasks = [fetch_page(url) for url in batch] batch_results = await asyncio.gather(*batch_tasks) for installations in batch_results: all_installations.extend(installations) # Filter installations user_installations = [] for installation in all_installations: account = installation["account"] account_login = account["login"].lower() account_type = account["type"] if ( account_type == "User" and account_login == github_username.lower() ): user_installations.append(installation) elif account_type == "Organization" and account_login in org_logins: user_installations.append(installation) # Fetch repositories for each installation repos = [] for installation in user_installations: app_auth = auth.get_installation_auth(installation["id"]) repos_url = installation["repositories_url"] github = Github(auth=app_auth) # do not remove this line auth_headers = {"Authorization": f"Bearer {app_auth.token}"} async with session.get( f"{repos_url}?per_page=100", headers=auth_headers ) as response: if response.status != 200: logger.error( f"Failed to fetch repositories for installation ID {installation['id']}. Response: {await response.text()}" ) continue first_page_data = await response.json() repos.extend(first_page_data.get("repositories", [])) # Get last page from Link header last_page = 1 if "Link" in response.headers: links = self._parse_link_header(response.headers["Link"]) if "last" in links: last_url = links["last"] match = re.search(r"[?&]page=(\d+)", last_url) if match: last_page = int(match.group(1)) if last_page > 1: # Generate remaining page URLs (skip page 1) page_urls = [ f"{repos_url}?page={page}&per_page=100" for page in range(2, last_page + 1) ] # Process URLs in batches of 10 for i in range(0, len(page_urls), 10): batch = page_urls[i : i + 10] tasks = [ session.get(url, headers=auth_headers) for url in batch ] responses = await asyncio.gather(*tasks) for response in responses: async with response: if response.status == 200: page_data = await response.json() repos.extend( page_data.get("repositories", []) ) else: logger.error( f"Failed to fetch repositories page. 
Response: {await response.text()}" ) # Remove duplicate repositories unique_repos = {repo["id"]: repo for repo in repos}.values() repo_list = [ { "id": repo["id"], "name": repo["name"], "full_name": repo["full_name"], "private": repo["private"], "url": repo["html_url"], "owner": repo["owner"]["login"], } for repo in unique_repos ] return {"repositories": repo_list} except Exception as e: logger.error(f"Failed to fetch repositories: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail=f"Failed to fetch repositories: {str(e)}" ) finally: total_duration = time.time() - start_time # Calculate total duration logger.info( f"get_repos_for_user executed in {total_duration:.2f} seconds" ) # Log total duration async def get_combined_user_repos(self, user_id: str): subquery = ( self.db.query(Project.repo_name, func.min(Project.id).label("min_id")) .filter(Project.user_id == user_id) .group_by(Project.repo_name) .subquery() ) projects = ( self.db.query(Project) .join( subquery, (Project.repo_name == subquery.c.repo_name) & (Project.id == subquery.c.min_id), ) .all() ) project_list = ( [ { "id": project.id, "name": project.repo_name.split("/")[-1], "full_name": ( project.repo_name if not self.is_development_mode else project.repo_path ), "private": False, "url": f"https://github.com/{project.repo_name}", "owner": project.repo_name.split("/")[0], } for project in projects ] if projects is not None else [] ) user_repo_response = await self.get_repos_for_user(user_id) user_repos = user_repo_response["repositories"] db_project_full_names = {project["full_name"] for project in project_list} filtered_user_repos = [ {**user_repo, "private": True} for user_repo in user_repos if user_repo["full_name"] not in db_project_full_names # Only include unique user repos ] combined_repos = list(reversed(project_list + filtered_user_repos)) return {"repositories": combined_repos} async def get_branch_list(self, repo_name: str): try: # Check if repo_name is a path to a local repository if os.path.exists(repo_name) and os.path.isdir(repo_name): try: # Handle local repository local_repo = git.Repo(repo_name) # Get the default branch try: default_branch = local_repo.git.symbolic_ref( "refs/remotes/origin/HEAD" ).split("/")[-1] except git.GitCommandError: # If no remote HEAD is found, use the current branch default_branch = local_repo.active_branch.name # Get all local branches branches = [ branch.name for branch in local_repo.heads if branch.name != default_branch ] return {"branches": [default_branch] + branches} except git.InvalidGitRepositoryError: raise HTTPException( status_code=404, detail=f"Not a valid git repository: {repo_name}", ) except Exception as e: logger.error( f"Error fetching branches for local repo {repo_name}: {str(e)}", exc_info=True, ) raise HTTPException( status_code=500, detail=f"Error fetching branches for local repo: {str(e)}", ) else: # Handle GitHub repository (existing functionality) github, repo = self.get_repo(repo_name) default_branch = repo.default_branch branches = repo.get_branches() branch_list = [ branch.name for branch in branches if branch.name != default_branch ] return {"branches": [default_branch] + branch_list} except HTTPException as he: raise he except Exception as e: logger.error( f"Error fetching branches for repo {repo_name}: {str(e)}", exc_info=True ) raise HTTPException( status_code=404, detail=f"Repository not found or error fetching branches: {str(e)}", ) @classmethod def get_public_github_instance(cls): if not cls.gh_token_list: cls.initialize_tokens() 
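# The token pool comes from the comma-separated GH_TOKEN_LIST environment
# variable (parsed in initialize_tokens above); picking one at random spreads
# GitHub API rate-limit quota across the pool. Illustrative value only:
#
#     GH_TOKEN_LIST=ghp_tokenA,ghp_tokenB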
token = random.choice(cls.gh_token_list) return Github(token) def get_repo(self, repo_name: str) -> Tuple[Github, Any]: try: # Try authenticated access first github, _, _ = self.get_github_repo_details(repo_name) repo = github.get_repo(repo_name) return github, repo except Exception as private_error: logger.info( f"Failed to access private repo {repo_name}: {str(private_error)}" ) # If authenticated access fails, try public access try: github = self.get_public_github_instance() repo = github.get_repo(repo_name) return github, repo except Exception as public_error: logger.error( f"Failed to access public repo {repo_name}: {str(public_error)}" ) raise HTTPException( status_code=404, detail=f"Repository {repo_name} not found or inaccessible on GitHub", ) async def get_project_structure_async( self, project_id: str, path: Optional[str] = None ) -> str: logger.info( f"Fetching project structure for project ID: {project_id}, path: {path}" ) # Modify cache key to reflect that we're only caching the specific path cache_key = ( f"project_structure:{project_id}:exact_path_{path}:depth_{self.max_depth}" ) cached_structure = self.redis.get(cache_key) if cached_structure: logger.info( f"Project structure found in cache for project ID: {project_id}, path: {path}" ) return cached_structure.decode("utf-8") project = await self.project_manager.get_project_from_db_by_id(project_id) if not project: raise HTTPException(status_code=404, detail="Project not found") repo_name = project["project_name"] if not repo_name: raise HTTPException( status_code=400, detail="Project has no associated GitHub repository" ) try: github, repo = self.get_repo(repo_name) # If path is provided, verify it exists if path: try: # Check if the path exists in the repository repo.get_contents(path) except Exception: raise HTTPException( status_code=404, detail=f"Path {path} not found in repository" ) # Start structure fetch from the specified path with depth 0 structure = await self._fetch_repo_structure_async( repo, path or "", current_depth=0, base_path=path ) formatted_structure = self._format_tree_structure(structure) self.redis.setex(cache_key, 3600, formatted_structure) # Cache for 1 hour return formatted_structure except HTTPException as he: raise he except Exception as e: logger.error( f"Error fetching project structure for {repo_name}: {str(e)}", exc_info=True, ) raise HTTPException( status_code=500, detail=f"Failed to fetch project structure: {str(e)}" ) async def _fetch_repo_structure_async( self, repo: Any, path: str = "", current_depth: int = 0, base_path: Optional[str] = None, ) -> Dict[str, Any]: exclude_extensions = [ "png", "jpg", "jpeg", "gif", "bmp", "tiff", "webp", "ico", "svg", "mp4", "avi", "mov", "wmv", "flv", "ipynb", "zlib", ] # Calculate current depth relative to base_path if base_path: # If we have a base_path, calculate depth relative to it relative_path = path[len(base_path) :].strip("/") current_depth = len(relative_path.split("/")) if relative_path else 0 else: # If no base_path, calculate depth from root current_depth = len(path.split("/")) if path else 0 # If we've reached max depth, return truncated indicator if current_depth >= self.max_depth: return { "type": "directory", "name": path.split("/")[-1] or repo.name, "children": [{"type": "file", "name": "...", "path": "truncated"}], } structure = { "type": "directory", "name": path.split("/")[-1] or repo.name, "children": [], } try: contents = await asyncio.get_event_loop().run_in_executor( self.executor, repo.get_contents, path ) if not 
isinstance(contents, list): contents = [contents] # Filter out files with excluded extensions contents = [ item for item in contents if item.type == "dir" or not any(item.name.endswith(ext) for ext in exclude_extensions) ] tasks = [] for item in contents: # Only process items within the base_path if it's specified if base_path and not item.path.startswith(base_path): continue if item.type == "dir": task = self._fetch_repo_structure_async( repo, item.path, current_depth=current_depth, base_path=base_path, ) tasks.append(task) else: structure["children"].append( { "type": "file", "name": item.name, "path": item.path, } ) if tasks: children = await asyncio.gather(*tasks) structure["children"].extend(children) except Exception as e: logger.error(f"Error fetching contents for path {path}: {str(e)}") return structure def _format_tree_structure( self, structure: Dict[str, Any], root_path: str = "" ) -> str: """ Creates a clear hierarchical structure using simple nested dictionaries. Args: self: The instance object structure: Dictionary containing name and children root_path: Optional root path string (unused but kept for signature compatibility) """ def _format_node(node: Dict[str, Any], depth: int = 0) -> List[str]: output = [] indent = " " * depth if depth > 0: # Skip root name output.append(f"{indent}{node['name']}") if "children" in node: children = sorted(node.get("children", []), key=lambda x: x["name"]) for child in children: output.extend(_format_node(child, depth + 1)) return output return "\n".join(_format_node(structure)) async def check_public_repo(self, repo_name: str) -> bool: try: github = self.get_public_github_instance() github.get_repo(repo_name) return True except Exception: return False ``` ## /app/modules/code_provider/local_repo/local_repo_service.py ```py path="/app/modules/code_provider/local_repo/local_repo_service.py" import asyncio import logging import os import re from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict, List, Optional, Union import git from fastapi import HTTPException from sqlalchemy.orm import Session from app.modules.projects.projects_service import ProjectService logger = logging.getLogger(__name__) class LocalRepoService: def __init__(self, db: Session): self.db = db self.project_manager = ProjectService(db) self.projects_dir = os.path.join(os.getcwd(), "projects") self.max_workers = 10 self.max_depth = 4 self.executor = ThreadPoolExecutor(max_workers=self.max_workers) def get_repo(self, repo_path: str) -> git.Repo: if not os.path.exists(repo_path): raise HTTPException( status_code=404, detail=f"Local repository at {repo_path} not found" ) return git.Repo(repo_path) def get_file_content( self, repo_name: str, file_path: str, start_line: int, end_line: int, branch_name: str, project_id: str, ) -> str: logger.info( f"Attempting to access file: {file_path} for project ID: {project_id}" ) try: project = self.project_manager.get_project_from_db_by_id_sync(project_id) if not project: raise HTTPException(status_code=404, detail="Project not found") repo_path = project["repo_path"] if not repo_path: raise HTTPException( status_code=400, detail="Project has no associated local repository" ) repo = self.get_repo(repo_path) repo.git.checkout(branch_name) file_full_path = os.path.join(repo_path, file_path) with open(file_full_path, "r", encoding="utf-8") as file: lines = file.readlines() if (start_line == end_line == 0) or (start_line is None and end_line is None): return "".join(lines) start = start_line - 2 if start_line - 2 > 0 else 0
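# Line selection mirrors GithubService.get_file_content: start_line/end_line
# are treated as 1-based and inclusive, and start is pulled back two lines so
# a decorator or signature just above the requested range is kept. For
# example, a request for lines 10-12 returns lines[8:12], i.e. source lines
# 9 through 12.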
selected_lines = lines[start:end_line] return "".join(selected_lines) except Exception as e: logger.error( f"Error processing file content for project ID {project_id}, file {file_path}: {e}", exc_info=True, ) raise HTTPException( status_code=500, detail=f"Error processing file content: {str(e)}", ) async def get_project_structure_async( self, project_id: str, path: Optional[str] = None ) -> str: project = await self.project_manager.get_project_from_db_by_id(project_id) if not project: raise HTTPException(status_code=404, detail="Project not found") repo_path = project["repo_path"] if not repo_path: raise HTTPException( status_code=400, detail="Project has no associated local repository" ) try: repo = self.get_repo(repo_path) structure = await self._fetch_repo_structure_async( repo, repo_path or "", current_depth=0, base_path=path ) formatted_structure = self._format_tree_structure(structure) return formatted_structure except Exception as e: logger.error( f"Error fetching project structure for {repo_path}: {str(e)}", exc_info=True, ) raise HTTPException( status_code=500, detail=f"Failed to fetch project structure: {str(e)}" ) async def _fetch_repo_structure_async( self, repo: Any, path: str = "", current_depth: int = 0, base_path: Optional[str] = None, ) -> Dict[str, Any]: exclude_extensions = [ "png", "jpg", "jpeg", "gif", "bmp", "tiff", "webp", "ico", "svg", "mp4", "avi", "mov", "wmv", "flv", "ipynb", "zlib", ] # Calculate current depth relative to base_path if base_path: # If we have a base_path, calculate depth relative to it relative_path = path[len(base_path) :].strip("/") current_depth = len(relative_path.split("/")) if relative_path else 0 else: # If no base_path, calculate depth from root current_depth = len(path.split("/")) if path else 0 # If we've reached max depth, return truncated indicator if current_depth >= self.max_depth: return { "type": "directory", "name": path.split("/")[-1] or repo.name, "children": [{"type": "file", "name": "...", "path": "truncated"}], } structure = { "type": "directory", "name": path.split("/")[-1] or repo.name, "children": [], } try: contents = await asyncio.get_event_loop().run_in_executor( self.executor, self._get_contents, path ) if not isinstance(contents, list): contents = [contents] # Filter out files with excluded extensions contents = [ item for item in contents if item["type"] == "dir" or not any(item["name"].endswith(ext) for ext in exclude_extensions) ] tasks = [] for item in contents: # Only process items within the base_path if it's specified if base_path and not item["path"].startswith(base_path): continue if item["type"] == "dir": task = self._fetch_repo_structure_async( repo, item["path"], current_depth=current_depth, base_path=base_path, ) tasks.append(task) else: structure["children"].append( { "type": "file", "name": item["name"], "path": item["path"], } ) if tasks: children = await asyncio.gather(*tasks) structure["children"].extend(children) except Exception as e: logger.error(f"Error fetching contents for path {path}: {str(e)}") return structure def _format_tree_structure( self, structure: Dict[str, Any], root_path: str = "" ) -> str: """ Creates a clear hierarchical structure using simple nested dictionaries. 
Args: self: The instance object structure: Dictionary containing name and children root_path: Optional root path string (unused but kept for signature compatibility) """ def _format_node(node: Dict[str, Any], depth: int = 0) -> List[str]: output = [] indent = " " * depth if depth > 0: # Skip root name output.append(f"{indent}{node['name']}") if "children" in node: children = sorted(node.get("children", []), key=lambda x: x["name"]) for child in children: output.extend(_format_node(child, depth + 1)) return output return "\n".join(_format_node(structure)) def get_local_repo_diff(self, repo_path: str, branch_name: str) -> Dict[str, str]: try: repo = self.get_repo(repo_path) repo.git.checkout(branch_name) # Determine the default branch name default_branch_name = repo.git.symbolic_ref( "refs/remotes/origin/HEAD" ).split("/")[-1] # Get the diff between the current branch and the default branch diff = repo.git.diff(f"{default_branch_name}..{branch_name}", unified=0) patches_dict = self._parse_diff(diff) return patches_dict except Exception as e: logger.error( f"Error computing diff for local repo: {str(e)}", exc_info=True ) raise HTTPException( status_code=500, detail=f"Error computing diff for local repo: {str(e)}" ) def _parse_diff(self, diff: str) -> Dict[str, str]: """ Parses the git diff output and returns a dictionary of file patches. """ patches_dict = {} current_file = None patch_lines = [] for line in diff.splitlines(): if line.startswith("diff --git"): if current_file and patch_lines: patches_dict[current_file] = "\n".join(patch_lines) match = re.search(r"b/(.+)", line) current_file = match.group(1) if match else None patch_lines = [] elif current_file: patch_lines.append(line) if current_file and patch_lines: patches_dict[current_file] = "\n".join(patch_lines) return patches_dict def _get_contents(self, path: str) -> Union[List[dict], dict]: """ If the path is a directory, it returns a list of dictionaries, each representing a file or subdirectory. If the path is a file, its content is read and returned. :param path: Relative or absolute path within the local repository. :return: A dict if the path is a file (with file content loaded), or a list of dicts if the path is a directory. 
""" if not isinstance(path, str): raise TypeError(f"Expected path to be a string, got {type(path).__name__}") if path == "/": path = "" abs_path = os.path.abspath(path) if not os.path.exists(abs_path): raise FileNotFoundError(f"Path '{abs_path}' does not exist.") if os.path.isdir(abs_path): contents = [] for item in os.listdir(abs_path): item_path = os.path.join(abs_path, item) if os.path.isdir(item_path): contents.append( { "path": item_path, "name": item, "type": "dir", "content": None, # path is a dir, content is not loaded "completed": True, } ) elif os.path.isfile(item_path): contents.append( { "path": item_path, "name": item, "type": "file", "content": None, "completed": False, } ) else: contents.append( { "path": item_path, "name": item, "type": "other", "content": None, "completed": True, } ) return contents elif os.path.isfile(abs_path): with open(abs_path, "r", encoding="utf-8") as file: file_content = file.read() return { "path": abs_path, "name": os.path.basename(abs_path), "type": "file", "content": file_content, # path is a file, content is loaded "completed": True, } ``` ## /app/modules/conversations/access/access_schema.py ```py path="/app/modules/conversations/access/access_schema.py" from typing import List, Optional from pydantic import BaseModel, EmailStr from app.modules.conversations.conversation.conversation_model import Visibility class ShareChatRequest(BaseModel): conversation_id: str recipientEmails: Optional[List[EmailStr]] = None visibility: Visibility class ShareChatResponse(BaseModel): message: str sharedID: str class SharedChatResponse(BaseModel): chat: dict class RemoveAccessRequest(BaseModel): emails: List[EmailStr] ``` ## /app/modules/conversations/access/access_service.py ```py path="/app/modules/conversations/access/access_service.py" from typing import List from fastapi import HTTPException from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from app.modules.conversations.conversation.conversation_model import ( Conversation, Visibility, ) class ShareChatServiceError(Exception): """Base exception class for ShareChatService errors.""" class ShareChatService: def __init__(self, db: Session): self.db = db async def share_chat( self, conversation_id: str, user_id: str, recipient_emails: List[str] = None, visibility: Visibility = None, ) -> str: chat = ( self.db.query(Conversation) .filter_by(id=conversation_id, user_id=user_id) .first() ) if not chat: raise HTTPException( 404, "Chat does not exist or you are not authorized to access it." ) # Default to PRIVATE if visibility is not specified visibility = visibility or Visibility.PRIVATE try: # Update the visibility directly on the object chat.visibility = visibility if visibility == Visibility.PUBLIC: self.db.commit() return conversation_id # Handle PRIVATE visibility case if recipient_emails: existing_emails = chat.shared_with_emails or [] existing_emails_set = set(existing_emails) unique_new_emails_set = set(recipient_emails) to_share = unique_new_emails_set - existing_emails_set if to_share: updated_emails = existing_emails + list(to_share) chat.shared_with_emails = updated_emails # Always commit changes self.db.commit() return conversation_id except IntegrityError as e: self.db.rollback() raise ShareChatServiceError( "Failed to update shared chat due to a database integrity error." 
) from e except Exception as e: self.db.rollback() raise ShareChatServiceError(f"Failed to update shared chat: {str(e)}") async def get_shared_emails(self, conversation_id: str, user_id: str) -> List[str]: chat = ( self.db.query(Conversation) .filter_by(id=conversation_id, user_id=user_id) .first() ) if not chat: raise HTTPException( 404, "Chat does not exist or you are not authorized to access it." ) return chat.shared_with_emails or [] async def remove_access( self, conversation_id: str, user_id: str, emails_to_remove: List[str] ) -> bool: """Remove access for specified emails from a conversation.""" chat = ( self.db.query(Conversation) .filter_by(id=conversation_id, user_id=user_id) .first() ) if not chat: raise HTTPException( status_code=404, detail="Chat does not exist or you are not authorized to access it.", ) if not chat.shared_with_emails: raise ShareChatServiceError("Chat has no shared access to remove.") existing_emails = set(chat.shared_with_emails) emails_to_remove_set = set(emails_to_remove) # Check if any of the emails to remove actually have access if not emails_to_remove_set.intersection(existing_emails): raise ShareChatServiceError( "None of the specified emails have access to this chat." ) try: updated_emails = list(existing_emails - emails_to_remove_set) self.db.query(Conversation).filter_by(id=conversation_id).update( {Conversation.shared_with_emails: updated_emails}, synchronize_session=False, ) self.db.commit() return True except IntegrityError as e: self.db.rollback() raise ShareChatServiceError( "Failed to update shared chat due to a database integrity error." ) from e ``` ## /app/modules/conversations/conversation/conversation_controller.py ```py path="/app/modules/conversations/conversation/conversation_controller.py" from typing import AsyncGenerator, List from fastapi import HTTPException from sqlalchemy.orm import Session from app.modules.conversations.conversation.conversation_schema import ( ChatMessageResponse, ConversationInfoResponse, CreateConversationRequest, CreateConversationResponse, ) from app.modules.conversations.conversation.conversation_service import ( AccessTypeNotFoundError, AccessTypeReadError, ConversationNotFoundError, ConversationService, ConversationServiceError, ) from app.modules.conversations.message.message_model import MessageType from app.modules.conversations.message.message_schema import ( MessageRequest, MessageResponse, NodeContext, ) class ConversationController: def __init__(self, db: Session, user_id: str, user_email: str): self.user_email = user_email self.service = ConversationService.create(db, user_id, user_email) self.user_id = user_id async def create_conversation( self, conversation: CreateConversationRequest, hidden: bool = False ) -> CreateConversationResponse: try: conversation_id, message = await self.service.create_conversation( conversation, self.user_id, hidden ) return CreateConversationResponse( message=message, conversation_id=conversation_id ) except ConversationServiceError as e: raise HTTPException(status_code=500, detail=str(e)) async def delete_conversation(self, conversation_id: str) -> dict: try: return await self.service.delete_conversation(conversation_id, self.user_id) except ConversationNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) except AccessTypeReadError as e: raise HTTPException(status_code=401, detail=str(e)) except ConversationServiceError as e: raise HTTPException(status_code=500, detail=str(e)) async def get_conversation_info( self, conversation_id: str ) -> 
ConversationInfoResponse: try: return await self.service.get_conversation_info( conversation_id, self.user_id ) except ConversationNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) except AccessTypeNotFoundError as e: raise HTTPException(status_code=401, detail=str(e)) except ConversationServiceError as e: raise HTTPException(status_code=500, detail=str(e)) async def get_conversation_messages( self, conversation_id: str, start: int, limit: int ) -> List[MessageResponse]: try: return await self.service.get_conversation_messages( conversation_id, start, limit, self.user_id ) except ConversationNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) except AccessTypeNotFoundError as e: raise HTTPException(status_code=401, detail=str(e)) except ConversationServiceError as e: raise HTTPException(status_code=500, detail=str(e)) async def post_message( self, conversation_id: str, message: MessageRequest, stream: bool = True ) -> AsyncGenerator[ChatMessageResponse, None]: try: async for chunk in self.service.store_message( conversation_id, message, MessageType.HUMAN, self.user_id, stream ): yield chunk except ConversationNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) except AccessTypeReadError as e: raise HTTPException(status_code=403, detail=str(e)) except ConversationServiceError as e: raise HTTPException(status_code=500, detail=str(e)) async def regenerate_last_message( self, conversation_id: str, node_ids: List[NodeContext] = [], stream: bool = True, ) -> AsyncGenerator[ChatMessageResponse, None]: try: async for chunk in self.service.regenerate_last_message( conversation_id, self.user_id, node_ids, stream ): yield chunk except ConversationNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) except AccessTypeReadError as e: raise HTTPException(status_code=403, detail=str(e)) except ConversationServiceError as e: raise HTTPException(status_code=500, detail=str(e)) async def stop_generation(self, conversation_id: str) -> dict: try: return await self.service.stop_generation(conversation_id, self.user_id) except ConversationNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) except ConversationServiceError as e: raise HTTPException(status_code=500, detail=str(e)) async def rename_conversation(self, conversation_id: str, new_title: str) -> dict: try: return await self.service.rename_conversation( conversation_id, new_title, self.user_id ) except ConversationNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) except AccessTypeReadError as e: raise HTTPException(status_code=403, detail=str(e)) except ConversationServiceError as e: raise HTTPException(status_code=500, detail=str(e)) ``` ## /app/modules/conversations/conversation/conversation_model.py ```py path="/app/modules/conversations/conversation/conversation_model.py" import enum from sqlalchemy import ARRAY, TIMESTAMP, Column from sqlalchemy import Enum as SQLAEnum from sqlalchemy import ForeignKey, String, func from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import relationship from app.core.base_model import Base from app.modules.conversations.message.message_model import Message # noqa class ConversationStatus(enum.Enum): ACTIVE = "active" ARCHIVED = "archived" DELETED = "deleted" class Visibility(enum.Enum): PRIVATE = "private" PUBLIC = "public" class Conversation(Base): __tablename__ = "conversations" id = Column(String(255), primary_key=True, index=True) user_id = Column( String(255), 
ForeignKey("users.uid", ondelete="CASCADE"), nullable=False, index=True, ) title = Column(String(255), nullable=False) status = Column( SQLAEnum(ConversationStatus), default=ConversationStatus.ACTIVE, nullable=False ) project_ids = Column(ARRAY(String), nullable=False) agent_ids = Column(ARRAY(String), nullable=False) created_at = Column(TIMESTAMP(timezone=True), default=func.now(), nullable=False) updated_at = Column( TIMESTAMP(timezone=True), default=func.now(), onupdate=func.now(), nullable=False, ) shared_with_emails = Column(ARRAY(String), nullable=True) visibility = Column(SQLAEnum(Visibility), default=Visibility.PRIVATE, nullable=True) # Relationships user = relationship("User", back_populates="conversations") messages = relationship( "Message", back_populates="conversation", cascade="all, delete-orphan" ) @hybrid_property def projects(self): from app.core.database import SessionLocal from app.modules.projects.projects_model import Project with SessionLocal() as session: return session.query(Project).filter(Project.id.in_(self.project_ids)).all() ``` ## /app/modules/conversations/conversation/conversation_schema.py ```py path="/app/modules/conversations/conversation/conversation_schema.py" from datetime import datetime from enum import Enum from typing import Any, List, Optional from pydantic import BaseModel from app.modules.conversations.conversation.conversation_model import ( ConversationStatus, Visibility, ) class CreateConversationRequest(BaseModel): user_id: str title: str status: ConversationStatus project_ids: List[str] agent_ids: List[str] class ConversationAccessType(str, Enum): """ Enum for access type """ READ = "read" WRITE = "write" NOT_FOUND = "not_found" class CreateConversationResponse(BaseModel): message: str conversation_id: str class ConversationInfoResponse(BaseModel): id: str title: str status: ConversationStatus project_ids: List[str] created_at: datetime updated_at: datetime total_messages: int agent_ids: List[str] access_type: ConversationAccessType is_creator: bool creator_id: str visibility: Optional[Visibility] = None class Config: from_attributes = True class ChatMessageResponse(BaseModel): message: str citations: List[str] tool_calls: List[Any] # Resolve forward references ConversationInfoResponse.update_forward_refs() class RenameConversationRequest(BaseModel): title: str ``` ## /app/modules/conversations/conversation/conversation_service.py ```py path="/app/modules/conversations/conversation/conversation_service.py" import asyncio import json import logging from datetime import datetime, timezone from typing import AsyncGenerator, List from sqlalchemy import func from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlalchemy.orm import Session from uuid6 import uuid7 from app.modules.code_provider.code_provider_service import CodeProviderService from app.modules.conversations.conversation.conversation_model import ( Conversation, ConversationStatus, Visibility, ) from app.modules.conversations.conversation.conversation_schema import ( ChatMessageResponse, ConversationAccessType, ConversationInfoResponse, CreateConversationRequest, ) from app.modules.conversations.message.message_model import ( Message, MessageStatus, MessageType, ) from app.modules.intelligence.agents.custom_agents.custom_agent_model import CustomAgent from app.modules.conversations.message.message_schema import ( MessageRequest, MessageResponse, NodeContext, ) from app.modules.intelligence.agents.custom_agents.custom_agents_service import ( CustomAgentService, ) from 
app.modules.intelligence.agents.agents_service import AgentsService from app.modules.intelligence.agents.chat_agent import ChatContext from app.modules.intelligence.memory.chat_history_service import ChatHistoryService from app.modules.intelligence.provider.provider_service import ( ProviderService, ) from app.modules.projects.projects_service import ProjectService from app.modules.users.user_service import UserService from app.modules.utils.posthog_helper import PostHogClient from app.modules.intelligence.agents.chat_agents.adaptive_agent import ( PromptService, ) from app.modules.intelligence.tools.tool_service import ToolService logger = logging.getLogger(__name__) class ConversationServiceError(Exception): pass class ConversationNotFoundError(ConversationServiceError): pass class MessageNotFoundError(ConversationServiceError): pass class AccessTypeNotFoundError(ConversationServiceError): pass class AccessTypeReadError(ConversationServiceError): pass class ConversationService: def __init__( self, db: Session, user_id: str, user_email: str, project_service: ProjectService, history_manager: ChatHistoryService, provider_service: ProviderService, tools_service: ToolService, prompt_service: PromptService, agent_service: AgentsService, custom_agent_service: CustomAgentService, ): self.sql_db = db self.user_id = user_id self.user_email = user_email self.project_service = project_service self.history_manager = history_manager self.provider_service = provider_service self.tool_service = tools_service self.prompt_service = prompt_service self.agent_service = agent_service self.custom_agent_service = custom_agent_service @classmethod def create(cls, db: Session, user_id: str, user_email: str): project_service = ProjectService(db) history_manager = ChatHistoryService(db) provider_service = ProviderService(db, user_id) tool_service = ToolService(db, user_id) prompt_service = PromptService(db) agent_service = AgentsService( db, provider_service, prompt_service, tool_service ) custom_agent_service = CustomAgentService(db) return cls( db, user_id, user_email, project_service, history_manager, provider_service, tool_service, prompt_service, agent_service, custom_agent_service, ) async def check_conversation_access( self, conversation_id: str, user_email: str ) -> str: if not user_email: return ConversationAccessType.WRITE user_service = UserService(self.sql_db) user_id = user_service.get_user_id_by_email(user_email) # Retrieve the conversation conversation = ( self.sql_db.query(Conversation).filter_by(id=conversation_id).first() ) if not conversation: return ( ConversationAccessType.NOT_FOUND ) # Return 'not found' if conversation doesn't exist if not conversation.visibility: conversation.visibility = Visibility.PRIVATE if user_id == conversation.user_id: # Check if the user is the creator return ConversationAccessType.WRITE # Creator always has write access if conversation.visibility == Visibility.PUBLIC: return ConversationAccessType.READ # Public users get read access # Check if the conversation is shared if conversation.shared_with_emails: shared_user_ids = user_service.get_user_ids_by_emails( conversation.shared_with_emails ) if shared_user_ids is None: return ConversationAccessType.NOT_FOUND # Check if the current user ID is in the shared user IDs if user_id in shared_user_ids: return ConversationAccessType.READ # Shared users can only read return ConversationAccessType.NOT_FOUND async def create_conversation( self, conversation: CreateConversationRequest, user_id: str, hidden: bool = False, ) ->
tuple[str, str]: try: if not await self.agent_service.validate_agent_id( user_id, conversation.agent_ids[0] ): raise ConversationServiceError( f"Invalid agent_id: {conversation.agent_ids[0]}" ) project_name = await self.project_service.get_project_name( conversation.project_ids ) title = ( conversation.title.strip().replace("Untitled", project_name) if conversation.title else project_name ) conversation_id = self._create_conversation_record( conversation, title, user_id, hidden ) asyncio.create_task( CodeProviderService(self.sql_db).get_project_structure_async( conversation.project_ids[0] ) ) await self._add_system_message(conversation_id, project_name, user_id) return conversation_id, "Conversation created successfully." except IntegrityError as e: logger.error(f"IntegrityError in create_conversation: {e}", exc_info=True) self.sql_db.rollback() raise ConversationServiceError( "Failed to create conversation due to a database integrity error." ) from e except Exception as e: logger.error(f"Unexpected error in create_conversation: {e}", exc_info=True) self.sql_db.rollback() raise ConversationServiceError( "An unexpected error occurred while creating the conversation." ) from e def _create_conversation_record( self, conversation: CreateConversationRequest, title: str, user_id: str, hidden: bool = False, ) -> str: conversation_id = str(uuid7()) new_conversation = Conversation( id=conversation_id, user_id=user_id, title=title, status=ConversationStatus.ARCHIVED if hidden else ConversationStatus.ACTIVE, project_ids=conversation.project_ids, agent_ids=conversation.agent_ids, created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) self.sql_db.add(new_conversation) self.sql_db.commit() logger.info( f"Project id : {conversation.project_ids[0]} Created new conversation with ID: {conversation_id}, title: {title}, user_id: {user_id}, agent_id: {conversation.agent_ids[0]}, hidden: {hidden}" ) return conversation_id async def _add_system_message( self, conversation_id: str, project_name: str, user_id: str ): content = f"You can now ask questions about the {project_name} repository." try: self.history_manager.add_message_chunk( conversation_id, content, MessageType.SYSTEM_GENERATED, user_id ) self.history_manager.flush_message_buffer( conversation_id, MessageType.SYSTEM_GENERATED, user_id ) logger.info( f"Added system message to conversation {conversation_id} for user {user_id}" ) except Exception as e: logger.error( f"Failed to add system message to conversation {conversation_id}: {e}", exc_info=True, ) raise ConversationServiceError( "Failed to add system message to the conversation." 
) from e async def store_message( self, conversation_id: str, message: MessageRequest, message_type: MessageType, user_id: str, stream: bool = True, ) -> AsyncGenerator[ChatMessageResponse, None]: try: access_level = await self.check_conversation_access( conversation_id, self.user_email ) if access_level == ConversationAccessType.READ: raise AccessTypeReadError("Access denied.") self.history_manager.add_message_chunk( conversation_id, message.content, message_type, user_id ) self.history_manager.flush_message_buffer( conversation_id, message_type, user_id ) logger.info(f"Stored message in conversation {conversation_id}") if message_type == MessageType.HUMAN: conversation = await self._get_conversation_with_message_count( conversation_id ) if not conversation: raise ConversationNotFoundError( f"Conversation with id {conversation_id} not found" ) # Check if this is the first human message if conversation.human_message_count == 1: new_title = await self._generate_title( conversation, message.content ) await self._update_conversation_title(conversation_id, new_title) project_id = ( conversation.project_ids[0] if conversation.project_ids else None ) if not project_id: raise ConversationServiceError( "No project associated with this conversation" ) if stream: async for chunk in self._generate_and_stream_ai_response( message.content, conversation_id, user_id, message.node_ids ): yield chunk else: full_message = "" all_citations = [] async for chunk in self._generate_and_stream_ai_response( message.content, conversation_id, user_id, message.node_ids ): full_message += chunk.message all_citations = all_citations + chunk.citations yield ChatMessageResponse( message=full_message, citations=all_citations, tool_calls=[] ) except AccessTypeReadError: raise except Exception as e: logger.error( f"Error in store_message for conversation {conversation_id}: {e}", exc_info=True, ) raise ConversationServiceError( "Failed to store message or generate AI response." ) from e async def _get_conversation_with_message_count( self, conversation_id: str ) -> Conversation: result = ( self.sql_db.query( Conversation, func.count(Message.id) .filter(Message.type == MessageType.HUMAN) .label("human_message_count"), ) .outerjoin(Message, Conversation.id == Message.conversation_id) .filter(Conversation.id == conversation_id) .group_by(Conversation.id) .first() ) if result: conversation, human_message_count = result setattr(conversation, "human_message_count", human_message_count) return conversation return None async def _generate_title( self, conversation: Conversation, message_content: str ) -> str: agent_type = conversation.agent_ids[0] prompt = ( "Given an agent type '{agent_type}' and an initial message '{message}', " "generate a concise and relevant title for a conversation. " "The title should be no longer than 50 characters. Only return title string, do not wrap in quotes." ).format(agent_type=agent_type, message=message_content) messages = [ { "role": "system", "content": "You are a conversation title generator that creates concise and relevant titles.", }, {"role": "user", "content": prompt}, ] generated_title: str = await self.provider_service.call_llm( messages=messages, config_type="chat" ) # type: ignore if len(generated_title) > 50: generated_title = generated_title[:50].strip() + "..." 
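# e.g. a 60-character draft becomes its first 50 characters with trailing
# whitespace stripped plus "...", so a stored title can be up to 53
# characters long.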
return generated_title async def _update_conversation_title(self, conversation_id: str, new_title: str): self.sql_db.query(Conversation).filter_by(id=conversation_id).update( {"title": new_title, "updated_at": datetime.now(timezone.utc)} ) self.sql_db.commit() async def regenerate_last_message( self, conversation_id: str, user_id: str, node_ids: List[NodeContext] = [], stream: bool = True, ) -> AsyncGenerator[ChatMessageResponse, None]: try: access_level = await self.check_conversation_access( conversation_id, self.user_email ) if access_level != ConversationAccessType.WRITE: raise AccessTypeReadError( "Access denied. Only conversation creators can regenerate messages." ) last_human_message = await self._get_last_human_message(conversation_id) if not last_human_message: raise MessageNotFoundError("No human message found to regenerate from") await self._archive_subsequent_messages( conversation_id, last_human_message.created_at ) PostHogClient().send_event( user_id, "regenerate_conversation_event", {"conversation_id": conversation_id}, ) if stream: async for chunk in self._generate_and_stream_ai_response( last_human_message.content, conversation_id, user_id, node_ids ): yield chunk else: full_message = "" all_citations = [] async for chunk in self._generate_and_stream_ai_response( last_human_message.content, conversation_id, user_id, node_ids ): full_message += chunk.message all_citations = all_citations + chunk.citations yield ChatMessageResponse( message=full_message, citations=all_citations, tool_calls=[] ) except AccessTypeReadError: raise except MessageNotFoundError as e: logger.warning( f"No message to regenerate in conversation {conversation_id}: {e}" ) raise except Exception as e: logger.error( f"Error in regenerate_last_message for conversation {conversation_id}: {e}", exc_info=True, ) raise ConversationServiceError("Failed to regenerate last message.") from e async def _get_last_human_message(self, conversation_id: str): message = ( self.sql_db.query(Message) .filter_by(conversation_id=conversation_id, type=MessageType.HUMAN) .order_by(Message.created_at.desc()) .first() ) if not message: logger.warning(f"No human message found in conversation {conversation_id}") return message async def _archive_subsequent_messages( self, conversation_id: str, timestamp: datetime ): try: self.sql_db.query(Message).filter( Message.conversation_id == conversation_id, Message.created_at > timestamp, ).update( {Message.status: MessageStatus.ARCHIVED}, synchronize_session="fetch" ) self.sql_db.commit() logger.info( f"Archived subsequent messages in conversation {conversation_id}" ) except Exception as e: logger.error( f"Failed to archive messages in conversation {conversation_id}: {e}", exc_info=True, ) self.sql_db.rollback() raise ConversationServiceError( "Failed to archive subsequent messages." 
    def parse_str_to_message(self, chunk: str) -> ChatMessageResponse:
        try:
            data = json.loads(chunk)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse chunk as JSON: {e}")
            raise ConversationServiceError("Failed to parse AI response") from e

        # Extract the 'message' and 'citations'
        message: str = data.get("message", "")
        citations: List[str] = data.get("citations", [])
        tool_calls: List[dict] = data.get("tool_calls", [])
        return ChatMessageResponse(
            message=message, citations=citations, tool_calls=tool_calls
        )

    async def _generate_and_stream_ai_response(
        self,
        query: str,
        conversation_id: str,
        user_id: str,
        node_ids: List[NodeContext],
    ) -> AsyncGenerator[ChatMessageResponse, None]:
        conversation = (
            self.sql_db.query(Conversation).filter_by(id=conversation_id).first()
        )
        if not conversation:
            raise ConversationNotFoundError(
                f"Conversation with id {conversation_id} not found"
            )

        agent_id = conversation.agent_ids[0]
        project_id = conversation.project_ids[0] if conversation.project_ids else None

        try:
            history = self.history_manager.get_session_history(
                user_id, conversation_id
            )
            validated_history = [
                (f"{msg.type}: {msg.content}" if msg.content else msg)
                for msg in history
            ]
        except Exception:
            raise ConversationServiceError("Failed to get chat history")

        try:
            type = await self.agent_service.validate_agent_id(user_id, str(agent_id))
            if type is None:
                raise ConversationServiceError(f"Invalid agent_id {agent_id}")

            project_name = await self.project_service.get_project_name(
                project_ids=[project_id]
            )

            logger.info(
                f"conversation_id: {conversation_id} Running agent {agent_id} with query: {query}"
            )

            if type == "CUSTOM_AGENT":
                # Custom agent doesn't support streaming, so we'll yield the entire response at once
                response = (
                    await self.agent_service.custom_agent_service.execute_agent_runtime(
                        agent_id,
                        user_id,
                        query,
                        node_ids,
                        project_id,
                        project_name,
                        conversation.id,
                    )
                )
                yield ChatMessageResponse(
                    message=response["message"], citations=[], tool_calls=[]
                )
            else:
                res = self.agent_service.execute_stream(
                    ChatContext(
                        project_id=str(project_id),
                        project_name=project_name,
                        curr_agent_id=str(agent_id),
                        history=validated_history[-8:],
                        node_ids=[node.node_id for node in node_ids],
                        query=query,
                    )
                )
                async for chunk in res:
                    self.history_manager.add_message_chunk(
                        conversation_id,
                        chunk.response,
                        MessageType.AI_GENERATED,
                        citations=chunk.citations,
                    )
                    yield ChatMessageResponse(
                        message=chunk.response,
                        citations=chunk.citations,
                        tool_calls=[
                            tool_call.model_dump_json()
                            for tool_call in chunk.tool_calls
                        ],
                    )
                self.history_manager.flush_message_buffer(
                    conversation_id, MessageType.AI_GENERATED
                )

            logger.info(
                f"Generated and streamed AI response for conversation {conversation.id} for user {user_id} using agent {agent_id}"
            )
        except Exception as e:
            logger.error(
                f"Failed to generate and stream AI response for conversation {conversation.id}: {e}",
                exc_info=True,
            )
            raise ConversationServiceError(
                "Failed to generate and stream AI response."
            ) from e

    async def delete_conversation(self, conversation_id: str, user_id: str) -> dict:
        try:
            access_level = await self.check_conversation_access(
                conversation_id, self.user_email
            )
            if access_level == ConversationAccessType.READ:
                raise AccessTypeReadError("Access denied.")

            # Use a nested transaction if one is already in progress
            with self.sql_db.begin_nested():
                # Delete related messages first
                deleted_messages = (
                    self.sql_db.query(Message)
                    .filter(Message.conversation_id == conversation_id)
                    .delete(synchronize_session="fetch")
                )
                deleted_conversation = (
                    self.sql_db.query(Conversation)
                    .filter(Conversation.id == conversation_id)
                    .delete(synchronize_session="fetch")
                )
                if deleted_conversation == 0:
                    raise ConversationNotFoundError(
                        f"Conversation with id {conversation_id} not found"
                    )

            # If we get here, commit the transaction
            self.sql_db.commit()

            PostHogClient().send_event(
                user_id,
                "delete_conversation_event",
                {"conversation_id": conversation_id},
            )
            logger.info(
                f"Deleted conversation {conversation_id} and {deleted_messages} related messages"
            )
            return {
                "status": "success",
                "message": f"Conversation {conversation_id} and its messages have been permanently deleted.",
                "deleted_messages_count": deleted_messages,
            }
        except ConversationNotFoundError as e:
            logger.warning(str(e))
            self.sql_db.rollback()
            raise
        except AccessTypeReadError:
            raise
        except SQLAlchemyError as e:
            logger.error(f"Database error in delete_conversation: {e}", exc_info=True)
            self.sql_db.rollback()
            raise ConversationServiceError(
                f"Failed to delete conversation {conversation_id} due to a database error"
            ) from e
        except Exception as e:
            logger.error(
                f"Unexpected error in delete_conversation: {e}", exc_info=True
            )
            self.sql_db.rollback()
            raise ConversationServiceError(
                f"Failed to delete conversation {conversation_id} due to an unexpected error"
            ) from e

    async def get_conversation_info(
        self, conversation_id: str, user_id: str
    ) -> ConversationInfoResponse:
        try:
            conversation = (
                self.sql_db.query(Conversation).filter_by(id=conversation_id).first()
            )
            if not conversation:
                raise ConversationNotFoundError(
                    f"Conversation with id {conversation_id} not found"
                )
            is_creator = conversation.user_id == user_id
            access_type = await self.check_conversation_access(
                conversation_id, self.user_email
            )
            if access_type == ConversationAccessType.NOT_FOUND:
                raise AccessTypeNotFoundError("Access type not found")
            total_messages = (
                self.sql_db.query(Message)
                .filter_by(
                    conversation_id=conversation_id, status=MessageStatus.ACTIVE
                )
                .count()
            )
            agent_id = conversation.agent_ids[0] if conversation.agent_ids else None
            agent_ids = conversation.agent_ids
            if agent_id:
                system_agents = self.agent_service._system_agents(
                    self.provider_service, self.prompt_service, self.tool_service
                )
                if agent_id in system_agents.keys():
                    agent_ids = conversation.agent_ids
                else:
                    custom_agent = (
                        self.sql_db.query(CustomAgent).filter_by(id=agent_id).first()
                    )
                    if custom_agent:
                        agent_ids = [custom_agent.role]

            return ConversationInfoResponse(
                id=conversation.id,
                title=conversation.title,
                status=conversation.status,
                project_ids=conversation.project_ids,
                created_at=conversation.created_at,
                updated_at=conversation.updated_at,
                total_messages=total_messages,
                agent_ids=agent_ids,
                access_type=access_type,
                is_creator=is_creator,
                creator_id=conversation.user_id,
                visibility=conversation.visibility,
            )
        except ConversationNotFoundError as e:
            logger.warning(str(e))
            raise
        except AccessTypeNotFoundError:
            raise
        except Exception as e:
            logger.error(f"Error in get_conversation_info: {e}", exc_info=True)
            raise ConversationServiceError(
                f"Failed to get conversation info for {conversation_id}"
            ) from e

    async def get_conversation_messages(
        self, conversation_id: str, start: int, limit: int, user_id: str
    ) -> List[MessageResponse]:
        try:
            access_level = await self.check_conversation_access(
                conversation_id, self.user_email
            )
            if access_level == ConversationAccessType.NOT_FOUND:
                raise AccessTypeNotFoundError("Access denied.")
            conversation = (
                self.sql_db.query(Conversation).filter_by(id=conversation_id).first()
            )
            if not conversation:
                raise ConversationNotFoundError(
                    f"Conversation with id {conversation_id} not found"
                )

            messages = (
                self.sql_db.query(Message)
                .filter_by(conversation_id=conversation_id)
                .filter_by(status=MessageStatus.ACTIVE)
                .filter(Message.type != MessageType.SYSTEM_GENERATED)
                .order_by(Message.created_at)
                .offset(start)
                .limit(limit)
                .all()
            )

            return [
                MessageResponse(
                    id=message.id,
                    conversation_id=message.conversation_id,
                    content=message.content,
                    sender_id=message.sender_id,
                    type=message.type,
                    status=message.status,
                    created_at=message.created_at,
                    citations=(
                        message.citations.split(",") if message.citations else None
                    ),
                )
                for message in messages
            ]
        except ConversationNotFoundError as e:
            logger.warning(str(e))
            raise
        except AccessTypeNotFoundError:
            raise
        except Exception as e:
            logger.error(f"Error in get_conversation_messages: {e}", exc_info=True)
            raise ConversationServiceError(
                f"Failed to get messages for conversation {conversation_id}"
            ) from e

    async def stop_generation(self, conversation_id: str, user_id: str) -> dict:
        logger.info(
            f"Attempting to stop generation for conversation {conversation_id}"
        )
        return {"status": "success", "message": "Generation stop request received"}

    async def rename_conversation(
        self, conversation_id: str, new_title: str, user_id: str
    ) -> dict:
        try:
            access_level = await self.check_conversation_access(
                conversation_id, self.user_email
            )
            if access_level == ConversationAccessType.READ:
                raise AccessTypeReadError("Access denied.")
            conversation = (
                self.sql_db.query(Conversation)
                .filter_by(id=conversation_id, user_id=user_id)
                .first()
            )
            if not conversation:
                raise ConversationNotFoundError(
                    f"Conversation with id {conversation_id} not found"
                )

            conversation.title = new_title
            conversation.updated_at = datetime.now(timezone.utc)
            self.sql_db.commit()
            logger.info(
                f"Renamed conversation {conversation_id} to '{new_title}' by user {user_id}"
            )
            return {
                "status": "success",
                "message": f"Conversation renamed to '{new_title}'",
            }
        except SQLAlchemyError as e:
            logger.error(f"Database error in rename_conversation: {e}", exc_info=True)
            self.sql_db.rollback()
            raise ConversationServiceError(
                "Failed to rename conversation due to a database error"
            ) from e
        except AccessTypeReadError:
            raise
        except Exception as e:
            logger.error(
                f"Unexpected error in rename_conversation: {e}", exc_info=True
            )
            self.sql_db.rollback()
            raise ConversationServiceError(
                "Failed to rename conversation due to an unexpected error"
            ) from e
```
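
As a usage illustration, here is a minimal sketch of consuming the streaming path of `store_message`. The `service` instance and the IDs are placeholders (the service's constructor is defined earlier in this file and not shown here); only the call signature and the yielded `ChatMessageResponse` chunks come from the code above.

```py
# Hypothetical consumer; `service` is assumed to be an already-constructed
# ConversationService, and the conversation/user IDs are placeholders.
from app.modules.conversations.message.message_model import MessageType
from app.modules.conversations.message.message_schema import MessageRequest


async def consume_stream(service) -> str:
    request = MessageRequest(content="Summarize the auth module", node_ids=[])
    full_message = ""
    # store_message is an async generator; iterate it to drain the stream.
    async for chunk in service.store_message(
        "conversation-id", request, MessageType.HUMAN, "user-id", stream=True
    ):
        full_message += chunk.message
    return full_message
```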
## /app/modules/conversations/conversations_router.py

```py path="/app/modules/conversations/conversations_router.py"
import json
from typing import Any, AsyncGenerator, List

from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from app.core.database import get_db
from app.modules.auth.auth_service import AuthService
from app.modules.conversations.access.access_schema import (
    RemoveAccessRequest,
    ShareChatRequest,
    ShareChatResponse,
)
from app.modules.conversations.access.access_service import (
    ShareChatService,
    ShareChatServiceError,
)
from app.modules.conversations.conversation.conversation_controller import (
    ConversationController,
)
from app.modules.usage.usage_service import UsageService

from .conversation.conversation_schema import (
    ConversationInfoResponse,
    CreateConversationRequest,
    CreateConversationResponse,
    RenameConversationRequest,
)
from .message.message_schema import MessageRequest, MessageResponse, RegenerateRequest

router = APIRouter()


async def get_stream(data_stream: AsyncGenerator[Any, None]):
    async for chunk in data_stream:
        yield json.dumps(chunk.dict())


class ConversationAPI:
    @staticmethod
    @router.post("/conversations/", response_model=CreateConversationResponse)
    async def create_conversation(
        conversation: CreateConversationRequest,
        hidden: bool = Query(
            False, description="Whether to hide this conversation from the web UI"
        ),
        db: Session = Depends(get_db),
        user=Depends(AuthService.check_auth),
    ):
        user_id = user["user_id"]
        checked = await UsageService.check_usage_limit(user_id)
        if not checked:
            raise HTTPException(
                status_code=402,
                detail="Subscription required to create a conversation.",
            )
        user_email = user["email"]
        controller = ConversationController(db, user_id, user_email)
        return await controller.create_conversation(conversation, hidden)

    @staticmethod
    @router.get(
        "/conversations/{conversation_id}/info/",
        response_model=ConversationInfoResponse,
    )
    async def get_conversation_info(
        conversation_id: str,
        db: Session = Depends(get_db),
        user=Depends(AuthService.check_auth),
    ):
        user_id = user["user_id"]
        user_email = user["email"]
        controller = ConversationController(db, user_id, user_email)
        return await controller.get_conversation_info(conversation_id)

    @staticmethod
    @router.get(
        "/conversations/{conversation_id}/messages/",
        response_model=List[MessageResponse],
    )
    async def get_conversation_messages(
        conversation_id: str,
        start: int = Query(0, ge=0),
        limit: int = Query(10, ge=1),
        db: Session = Depends(get_db),
        user=Depends(AuthService.check_auth),
    ):
        user_id = user["user_id"]
        user_email = user["email"]
        controller = ConversationController(db, user_id, user_email)
        return await controller.get_conversation_messages(
            conversation_id, start, limit
        )

    @staticmethod
    @router.post("/conversations/{conversation_id}/message/")
    async def post_message(
        conversation_id: str,
        message: MessageRequest,
        stream: bool = Query(True, description="Whether to stream the response"),
        db: Session = Depends(get_db),
        user=Depends(AuthService.check_auth),
    ):
        if (
            message.content == ""
            or message.content is None
            or message.content.isspace()
        ):
            raise HTTPException(
                status_code=400, detail="Message content cannot be empty"
            )

        user_id = user["user_id"]
        user_email = user["email"]
        checked = await UsageService.check_usage_limit(user_id)
        if not checked:
            raise HTTPException(
                status_code=402,
                detail="Subscription required to create a conversation.",
            )
        controller = ConversationController(db, user_id, user_email)
        message_stream = controller.post_message(conversation_id, message, stream)
        if stream:
            return StreamingResponse(
                get_stream(message_stream), media_type="text/event-stream"
            )
        else:
            # TODO: fix this, add types. The stream below yields only one output.
            async for chunk in message_stream:
                return chunk

    @staticmethod
    @router.post(
        "/conversations/{conversation_id}/regenerate/",
        response_model=MessageResponse,
    )
    async def regenerate_last_message(
        conversation_id: str,
        request: RegenerateRequest,
        stream: bool = Query(True, description="Whether to stream the response"),
        db: Session = Depends(get_db),
        user=Depends(AuthService.check_auth),
    ):
        user_id = user["user_id"]
        checked = await UsageService.check_usage_limit(user_id)
        if not checked:
            raise HTTPException(
                status_code=402,
                detail="Subscription required to create a conversation.",
            )
        user_email = user["email"]
        controller = ConversationController(db, user_id, user_email)
        message_stream = controller.regenerate_last_message(
            conversation_id, request.node_ids, stream
        )
        if stream:
            return StreamingResponse(
                get_stream(message_stream), media_type="text/event-stream"
            )
        else:
            async for chunk in message_stream:
                return chunk

    @staticmethod
    @router.delete("/conversations/{conversation_id}/", response_model=dict)
    async def delete_conversation(
        conversation_id: str,
        db: Session = Depends(get_db),
        user=Depends(AuthService.check_auth),
    ):
        user_id = user["user_id"]
        user_email = user["email"]
        controller = ConversationController(db, user_id, user_email)
        return await controller.delete_conversation(conversation_id)

    @staticmethod
    @router.post("/conversations/{conversation_id}/stop/", response_model=dict)
    async def stop_generation(
        conversation_id: str,
        db: Session = Depends(get_db),
        user=Depends(AuthService.check_auth),
    ):
        user_id = user["user_id"]
        user_email = user["email"]
        controller = ConversationController(db, user_id, user_email)
        return await controller.stop_generation(conversation_id)

    @staticmethod
    @router.patch("/conversations/{conversation_id}/rename/", response_model=dict)
    async def rename_conversation(
        conversation_id: str,
        request: RenameConversationRequest,
        db: Session = Depends(get_db),
        user=Depends(AuthService.check_auth),
    ):
        user_id = user["user_id"]
        user_email = user["email"]
        controller = ConversationController(db, user_id, user_email)
        return await controller.rename_conversation(conversation_id, request.title)


@router.post("/conversations/share", response_model=ShareChatResponse, status_code=201)
async def share_chat(
    request: ShareChatRequest,
    db: Session = Depends(get_db),
    user=Depends(AuthService.check_auth),
):
    user_id = user["user_id"]
    service = ShareChatService(db)
    try:
        shared_conversation = await service.share_chat(
            request.conversation_id,
            user_id,
            request.recipientEmails,
            request.visibility,
        )
        return ShareChatResponse(
            message="Chat shared successfully!", sharedID=shared_conversation
        )
    except ShareChatServiceError as e:
        raise HTTPException(status_code=400, detail=str(e))


@router.get("/conversations/{conversation_id}/shared-emails", response_model=List[str])
async def get_shared_emails(
    conversation_id: str,
    db: Session = Depends(get_db),
    user=Depends(AuthService.check_auth),
):
    user_id = user["user_id"]
    service = ShareChatService(db)
    shared_emails = await service.get_shared_emails(conversation_id, user_id)
    return shared_emails


@router.delete("/conversations/{conversation_id}/access")
async def remove_access(
    conversation_id: str,
    request: RemoveAccessRequest,
    user=Depends(AuthService.check_auth),
    db: Session = Depends(get_db),
) -> dict:
    """Remove access for specified emails from a conversation."""
    share_service = ShareChatService(db)
    current_user_id = user["user_id"]
    try:
        await share_service.remove_access(
            conversation_id=conversation_id,
            user_id=current_user_id,
            emails_to_remove=request.emails,
        )
        return {"message": "Access removed successfully"}
    except ShareChatServiceError as e:
        raise HTTPException(status_code=400, detail=str(e))
```
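
For orientation, here is a hypothetical client call against the message endpoint above. The base URL and the bearer-token header are assumptions about deployment and `AuthService` configuration, not something this router defines; the route, query parameter, and payload shape come from the code above.

```py
# Hypothetical client sketch; adjust the URL and auth header to your setup.
import httpx


def post_message(conversation_id: str, token: str) -> None:
    url = f"http://localhost:8000/conversations/{conversation_id}/message/"
    payload = {"content": "What does the delete endpoint do?", "node_ids": []}
    headers = {"Authorization": f"Bearer {token}"}
    # stream=true on the query string selects the StreamingResponse branch.
    with httpx.stream(
        "POST", url, params={"stream": "true"}, json=payload, headers=headers
    ) as response:
        response.raise_for_status()
        for chunk in response.iter_text():
            # Each chunk is a JSON-encoded ChatMessageResponse fragment,
            # as produced by get_stream() above.
            print(chunk, end="", flush=True)
```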
## /app/modules/conversations/message/message_model.py

```py path="/app/modules/conversations/message/message_model.py"
import enum

from sqlalchemy import TIMESTAMP, CheckConstraint, Column
from sqlalchemy import Enum as SQLAEnum
from sqlalchemy import ForeignKey, String, Text, func
from sqlalchemy.orm import relationship

from app.core.base_model import Base


class MessageStatus(str, enum.Enum):
    ACTIVE = "ACTIVE"
    ARCHIVED = "ARCHIVED"
    DELETED = "DELETED"  # Possible future extension


class MessageType(str, enum.Enum):
    AI_GENERATED = "AI_GENERATED"
    HUMAN = "HUMAN"
    SYSTEM_GENERATED = "SYSTEM_GENERATED"


class Message(Base):
    __tablename__ = "messages"

    id = Column(String(255), primary_key=True)
    conversation_id = Column(
        String(255),
        ForeignKey("conversations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    content = Column(Text, nullable=False)
    sender_id = Column(String(255), nullable=True)
    type = Column(SQLAEnum(MessageType), nullable=False)
    status = Column(
        SQLAEnum(MessageStatus), default=MessageStatus.ACTIVE, nullable=False
    )
    created_at = Column(TIMESTAMP(timezone=True), default=func.now(), nullable=False)
    citations = Column(Text, nullable=True)

    conversation = relationship("Conversation", back_populates="messages")

    __table_args__ = (
        CheckConstraint(
            "(type = 'HUMAN' AND sender_id IS NOT NULL) OR "
            "(type IN ('AI_GENERATED', 'SYSTEM_GENERATED') AND sender_id IS NULL)",
            name="check_sender_id_for_type",
        ),
    )
```

## /app/modules/conversations/message/message_schema.py

```py path="/app/modules/conversations/message/message_schema.py"
from datetime import datetime
from typing import List, Optional

from pydantic import BaseModel

from app.modules.conversations.message.message_model import MessageStatus, MessageType


class NodeContext(BaseModel):
    node_id: str
    name: str


class MessageRequest(BaseModel):
    content: str
    node_ids: Optional[List[NodeContext]] = None


class DirectMessageRequest(BaseModel):
    content: str
    node_ids: Optional[List[NodeContext]] = None
    agent_id: str | None = None


class RegenerateRequest(BaseModel):
    node_ids: Optional[List[NodeContext]] = None


class MessageResponse(BaseModel):
    id: str
    conversation_id: str
    content: str
    sender_id: Optional[str] = None
    type: MessageType
    reason: Optional[str] = None
    created_at: datetime
    status: MessageStatus
    citations: Optional[List[str]] = None

    class Config:
        from_attributes = True
```
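
To make the `check_sender_id_for_type` constraint concrete, here is an illustrative sketch of the two row shapes it permits. The IDs are placeholders and session handling is omitted; the column values follow the model above.

```py
# Illustrative only: rows that satisfy check_sender_id_for_type.
from datetime import datetime, timezone

from app.modules.conversations.message.message_model import (
    Message,
    MessageStatus,
    MessageType,
)

# A HUMAN message must carry a sender_id...
human = Message(
    id="msg-1",
    conversation_id="conv-1",
    content="Hello",
    sender_id="user-1",
    type=MessageType.HUMAN,
    status=MessageStatus.ACTIVE,
    created_at=datetime.now(timezone.utc),
)

# ...while AI/system messages must leave sender_id NULL, or the insert fails.
ai = Message(
    id="msg-2",
    conversation_id="conv-1",
    content="Hi! How can I help?",
    sender_id=None,
    type=MessageType.AI_GENERATED,
    status=MessageStatus.ACTIVE,
    created_at=datetime.now(timezone.utc),
)
```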
## /app/modules/conversations/message/message_service.py

```py path="/app/modules/conversations/message/message_service.py"
import asyncio
import logging
from datetime import datetime, timezone
from typing import Optional

from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.orm import Session
from uuid6 import uuid7

from app.modules.conversations.message.message_model import (
    Message,
    MessageStatus,
    MessageType,
)

logger = logging.getLogger(__name__)


class MessageServiceError(Exception):
    """Base exception class for MessageService errors."""


class MessageNotFoundError(MessageServiceError):
    """Raised when a message is not found."""


class InvalidMessageError(MessageServiceError):
    """Raised when there's an issue with message creation parameters."""


class MessageService:
    def __init__(self, db: Session):
        self.db = db

    async def create_message(
        self,
        conversation_id: str,
        content: str,
        message_type: MessageType,
        sender_id: Optional[str] = None,
    ) -> Message:
        try:
            if (message_type == MessageType.HUMAN and sender_id is None) or (
                message_type
                in {MessageType.AI_GENERATED, MessageType.SYSTEM_GENERATED}
                and sender_id is not None
            ):
                raise InvalidMessageError(
                    "Invalid sender_id for the given message_type."
                )

            message_id = str(uuid7())
            new_message = Message(
                id=message_id,
                conversation_id=conversation_id,
                content=content,
                type=message_type,
                created_at=datetime.now(timezone.utc),
                sender_id=sender_id,
                status=MessageStatus.ACTIVE,
            )
            await asyncio.get_event_loop().run_in_executor(
                None, self._sync_create_message, new_message
            )
            logger.info(
                f"Created new message with ID: {message_id} for conversation: {conversation_id}"
            )
            return new_message
        except InvalidMessageError as e:
            logger.warning(f"Invalid message parameters: {str(e)}")
            raise
        except IntegrityError as e:
            logger.error(
                f"Database integrity error in create_message: {e}", exc_info=True
            )
            raise MessageServiceError(
                "Failed to create message due to a database integrity error"
            ) from e
        except Exception as e:
            logger.error(f"Unexpected error in create_message: {e}", exc_info=True)
            raise MessageServiceError(
                "An unexpected error occurred while creating the message"
            ) from e

    def _sync_create_message(self, new_message: Message):
        try:
            self.db.add(new_message)
            self.db.commit()
            self.db.refresh(new_message)
        except SQLAlchemyError:
            self.db.rollback()
            raise

    async def mark_message_archived(self, message_id: str) -> None:
        try:
            await asyncio.get_event_loop().run_in_executor(
                None, self._sync_mark_message_archived, message_id
            )
            # TODO: add conversation_id to the log
            logger.info(f"Marked message {message_id} as archived")
        except MessageNotFoundError as e:
            logger.warning(str(e))
            raise
        except SQLAlchemyError as e:
            logger.error(
                f"Database error in mark_message_archived: {e}", exc_info=True
            )
            raise MessageServiceError(
                f"Failed to archive message {message_id} due to a database error"
            ) from e
        except Exception as e:
            logger.error(
                f"Unexpected error in mark_message_archived: {e}", exc_info=True
            )
            raise MessageServiceError(
                f"An unexpected error occurred while archiving message {message_id}"
            ) from e

    def _sync_mark_message_archived(self, message_id: str):
        try:
            message = (
                self.db.query(Message).filter(Message.id == message_id).one_or_none()
            )
            if message:
                message.status = MessageStatus.ARCHIVED
                self.db.commit()
            else:
                raise MessageNotFoundError(f"Message with id {message_id} not found.")
        except SQLAlchemyError:
            self.db.rollback()
            raise
```

## /app/modules/intelligence/__init__.py

```py path="/app/modules/intelligence/__init__.py"
```
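
Closing with a small, hypothetical sketch of `MessageService` from message_service.py above. The `db` argument is assumed to be a live SQLAlchemy `Session` obtained elsewhere (e.g., via `get_db`), and the IDs are placeholders; the call signature and the sender_id rule come from the service itself.

```py
# Hypothetical usage; `db` is an assumed SQLAlchemy Session, IDs are placeholders.
from app.modules.conversations.message.message_model import MessageType
from app.modules.conversations.message.message_service import MessageService


async def create_human_message(db) -> None:
    service = MessageService(db)
    # HUMAN messages require a sender_id; AI_GENERATED / SYSTEM_GENERATED must
    # omit it, mirroring the check constraint on the messages table.
    message = await service.create_message(
        conversation_id="conv-1",
        content="Hello",
        message_type=MessageType.HUMAN,
        sender_id="user-1",
    )
    print(message.id)
```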