```
├── .dockerignore
├── .env.example (omitted)
├── .gitattributes (omitted)
├── .gitignore (100 tokens)
├── AGENTS.md (100 tokens)
├── Dockerfile (100 tokens)
├── LICENSE (omitted)
├── README.md (1000 tokens)
├── alembic.ini (100 tokens)
├── alembic/
├── env.py (300 tokens)
├── script.py.mako (100 tokens)
├── versions/
├── 0001_initial_schema.py (1400 tokens)
├── 0002_prompt_class_and_versions.py (1500 tokens)
├── 0003_prompt_tags.py (500 tokens)
├── 0004_refactor_llm_providers.py (700 tokens)
├── 0005_drop_desc_params.py (200 tokens)
├── 07f7c967ab21_add_llm_usage_logs.py (500 tokens)
├── 2d8d3a4e0c8b_merge_prompt_test_heads.py (100 tokens)
├── 3152d7c2b4f0_add_llm_model_concurrency.py (100 tokens)
├── 6d6a1f6dfb41_create_system_settings_table.py (200 tokens)
├── 72f3f786c4a1_add_prompt_test_tables.py (1200 tokens)
├── 9b546f1b6f1a_add_test_run_batch_id.py (100 tokens)
├── a1b2c3d4e5f6_add_soft_delete_to_prompt_test_tasks.py (100 tokens)
├── ddc6143e8fb8_add_default_prompt_class.py (300 tokens)
├── efea0d0224c5_seed_default_prompt_tags.py (400 tokens)
├── f5e1a97c2e3d_seed_sample_prompt.py (1000 tokens)
├── app/
├── __init__.py
├── api/
├── __init__.py
├── v1/
├── __init__.py
├── api.py (200 tokens)
├── endpoints/
├── __init__.py
├── llms.py (7k tokens)
├── prompt_classes.py (800 tokens)
├── prompt_tags.py (600 tokens)
├── prompt_test_tasks.py (2000 tokens)
├── prompts.py (1700 tokens)
├── settings.py (300 tokens)
├── test_prompt.py (1300 tokens)
├── usage.py (800 tokens)
├── core/
├── __init__.py
├── config.py (400 tokens)
├── llm_provider_registry.py (1000 tokens)
├── logging_config.py (500 tokens)
├── middleware.py (400 tokens)
├── prompt_test_task_queue.py (2.4k tokens)
├── task_queue.py (700 tokens)
├── db/
├── __init__.py
├── session.py (100 tokens)
├── types.py (100 tokens)
├── main.py (300 tokens)
├── models/
├── __init__.py (200 tokens)
├── base.py
├── llm_provider.py (700 tokens)
├── metric.py (300 tokens)
├── prompt.py (1100 tokens)
├── prompt_test.py (1400 tokens)
├── result.py (300 tokens)
├── system_setting.py (200 tokens)
├── test_run.py (700 tokens)
├── usage.py (500 tokens)
├── schemas/
├── __init__.py (300 tokens)
├── llm_provider.py (700 tokens)
├── metric.py (100 tokens)
├── prompt.py (1000 tokens)
├── prompt_test.py (900 tokens)
├── result.py (100 tokens)
├── settings.py (200 tokens)
├── test_run.py (400 tokens)
├── usage.py (200 tokens)
├── services/
├── __init__.py
├── llm_usage.py (100 tokens)
├── prompt_test_engine.py (3.7k tokens)
├── system_settings.py (600 tokens)
├── test_run.py (3.4k tokens)
├── usage_dashboard.py (1200 tokens)
├── docker-compose.yml (300 tokens)
├── docker/
├── backend/
├── entrypoint.sh (100 tokens)
├── docs/
├── README_en.md (1700 tokens)
├── UPDATES.md (700 tokens)
├── logo.jpg
├── frontend/
├── .gitignore (100 tokens)
├── Dockerfile (100 tokens)
├── index.html (100 tokens)
├── nginx.conf (100 tokens)
├── package-lock.json (26.3k tokens)
├── package.json (100 tokens)
├── public/
├── logo.png
├── vite.svg (300 tokens)
├── src/
├── App.vue (2.3k tokens)
├── api/
├── http.ts (300 tokens)
├── llmProvider.ts (500 tokens)
├── prompt.ts (500 tokens)
├── promptClass.ts (300 tokens)
├── promptTag.ts (200 tokens)
├── promptTest.ts (400 tokens)
├── quickTest.ts (1200 tokens)
├── settings.ts (100 tokens)
├── testRun.ts (300 tokens)
├── usage.ts (300 tokens)
├── composables/
├── usePromptDetail.ts (400 tokens)
├── useTestingSettings.ts (300 tokens)
├── env.d.ts (omitted)
├── i18n/
├── index.ts (200 tokens)
├── messages.ts (13.6k tokens)
├── main.ts (100 tokens)
├── router/
├── index.ts (700 tokens)
├── style.css (200 tokens)
├── types/
├── llm.ts (400 tokens)
├── prompt.ts (100 tokens)
├── promptTest.ts (400 tokens)
├── testRun.ts (300 tokens)
├── utils/
├── promptTestResult.ts (1600 tokens)
├── views/
├── LLMManagementView.vue (5.8k tokens)
├── PromptClassManagementView.vue (1500 tokens)
├── PromptDetailView.vue (7k tokens)
├── PromptManagementView.vue (4k tokens)
├── PromptTagManagementView.vue (1500 tokens)
├── PromptTestTaskCreateView.vue (7.2k tokens)
├── PromptTestTaskResultView.vue (6.9k tokens)
├── PromptTestUnitResultView.vue (3.7k tokens)
├── PromptVersionCompareView.vue (1800 tokens)
├── PromptVersionCreateView.vue (1100 tokens)
├── QuickTestView.vue (11.7k tokens)
├── TestJobCreateView.vue (6.4k tokens)
├── TestJobManagementView.vue (6.8k tokens)
├── TestJobResultView.vue (4.8k tokens)
├── UsageManagementView.vue (3.1k tokens)
├── tsconfig.json (100 tokens)
├── vite.config.ts (100 tokens)
├── pyproject.toml (500 tokens)
├── tests/
├── __init__.py
├── conftest.py (400 tokens)
├── test_core_config.py (300 tokens)
├── test_core_middleware.py (400 tokens)
├── test_db_session.py (100 tokens)
├── test_llms.py (6.4k tokens)
├── test_prompt_classes.py (800 tokens)
├── test_prompt_schemas.py (300 tokens)
├── test_prompt_tags.py (500 tokens)
├── test_prompt_test_engine.py (2.2k tokens)
├── test_prompt_test_task_queue.py (1100 tokens)
├── test_prompts.py (2.7k tokens)
├── test_settings_api.py (200 tokens)
├── test_system_settings.py (400 tokens)
├── test_test_prompt.py (3.1k tokens)
├── test_test_run_service.py (2.3k tokens)
├── test_usage_api.py (1000 tokens)
├── test_usage_dashboard_service.py (1000 tokens)
├── uv.lock (omitted)
```
## /.dockerignore
```dockerignore path="/.dockerignore"
.git
.gitignore
.env
.venv
.uv_cache
.cache
__pycache__
*.pyc
*.pyo
tests/
docs/
frontend/node_modules
frontend/dist
# README.md is kept for the package description, so it is not ignored here
```
## /.gitignore
```gitignore path="/.gitignore"
# Byte-compiled / optimized files
__pycache__/
*.py[cod]
*$py.class
# Distribution / packaging
build/
dist/
*.egg-info/
# Virtual environments
.venv/
venv/
ENV/
env/
# Python tooling caches
.uv_cache/
.ruff_cache/
# Node.js / Vite build artifacts
node_modules/
frontend/dist/
frontend/.vite/
npm-debug.log*
pnpm-debug.log*
yarn-error.log*
# Frontend environment variables
frontend/.env.local
frontend/.env.*.local
# Environment variables
.env
# Testing / coverage
.coverage
.coverage.*
.pytest_cache/
htmlcov/
# Tooling
.mypy_cache/
.pyre/
.pytype/
.dmypy.json
# IDEs and editors
.vscode/
.idea/
*.swp
*.swo
# OS-specific
.DS_Store
Thumbs.db
# Logs
*.log
TODO.md
```
## /AGENTS.md
# Project Overview
PromptWorks is a full-stack solution focused on prompt asset management and LLM operations. The repository contains a FastAPI backend and a Vue + Element Plus frontend. The platform supports full prompt lifecycle management, model invocation, version comparison, and metric tracking, giving teams a unified console for collaboration and operations.
# Development Guidelines
1. The backend is built with Python + FastAPI; it uses uv for environment management, poe for task configuration, and pytest for testing.
2. After finishing backend work, write the corresponding test cases and make sure `uv run poe test-all` passes.
3. The frontend is built with Vue 3 + Element Plus; its code lives in ./frontend.
4. Files in the backend api folder should only contain endpoint definitions, type definitions and validation, and calls into business-logic functions; the business logic itself lives in the services folder (see the sketch after this list).
5. Once a development task is completed and tested, commit the code to the local git repository (do not push to a remote or merge into dev or main) with a short commit message in Chinese.
6. When asked to update README.md, also update the English version docs/README_en.md.
7. Use UTF-8 encoding throughout.
8. When asked to run git commit, include a Chinese commit message.
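As a rough illustration of guideline 4, a thin endpoint delegating to a service function might look like the sketch below. This is a hypothetical example, not code from the project: the route path, function names, and return shape are made up; only `get_db` mirrors the dependency used elsewhere in the codebase.
```python
# Hypothetical sketch of the api/services split described in guideline 4.
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

from app.db.session import get_db

router = APIRouter()


def fetch_example_prompts(db: Session) -> list[dict]:
    """Service-layer stand-in: in the real project this would live under app/services/."""
    # Query the database and shape the response here.
    return [{"id": 1, "name": "示例 Prompt"}]


@router.get("/example-prompts")
def list_example_prompts(db: Session = Depends(get_db)) -> list[dict]:
    """Endpoint layer: validate inputs and delegate; no business logic here."""
    return fetch_example_prompts(db)
```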
## /Dockerfile
``` path="/Dockerfile"
FROM python:3.11-slim AS base
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
PIP_NO_CACHE_DIR=1
WORKDIR /app
# Copy the required files first to speed up Docker layer caching
COPY pyproject.toml README.md alembic.ini /app/
COPY app /app/app
COPY alembic /app/alembic
RUN pip install --upgrade pip && \
pip install --no-cache-dir .
# Copy the entrypoint script
COPY docker/backend/entrypoint.sh /app/docker/backend/entrypoint.sh
RUN chmod +x /app/docker/backend/entrypoint.sh
EXPOSE 8000
ENTRYPOINT ["/app/docker/backend/entrypoint.sh"]
```
## /README.md

中文 | [English](docs/README_en.md) | [Update Log](docs/UPDATES.md)
# PromptWorks Project Overview
PromptWorks is a full-stack solution focused on prompt asset management and LLM operations. The repository contains a FastAPI backend and a Vue + Element Plus frontend. The platform supports full prompt lifecycle management, model configuration, version comparison, and evaluation experiments, giving teams a unified workspace for prompt collaboration and testing.
## ✨ Core Capabilities
- **Prompt management**: create prompts, iterate on versions, and organize them with tags, while keeping a complete audit trail.
- **Version comparison**: diff views make it easy to spot content changes introduced by prompt updates.
- **Model operations**: centrally manage available LLM services and invocation quotas, enabling A/B experiments.
- **Evaluation testing**: the backend exposes experiment execution and metric recording; the frontend ships with a test panel awaiting integration.
## 🧱 Tech Stack
- **Backend**: Python 3.10+, FastAPI, SQLAlchemy, Alembic, Redis, Celery.
- **Frontend**: Vite, Vue 3 (TypeScript), Vue Router, Element Plus.
- **Tooling**: uv for dependency and task management, PoeThePoet for unified development commands, pytest + coverage for quality assurance.
## 🏗️ System Architecture
- **Backend service**: lives in the `app/` directory, layered around FastAPI + SQLAlchemy, with business logic concentrated in `services/`.
- **Database and messaging**: PostgreSQL and Redis by default, with optional Celery task-queue support.
- **Frontend app**: the `frontend/` directory is built with Vite and provides the UI for prompt management and testing.
- **Unified configuration**: the root `.env` file and `VITE_`-prefixed frontend environment variables decouple per-environment differences.
## 🚀 Quick Start
### 0. Prerequisites
- Python 3.10+
- Node.js 18+
- PostgreSQL and Redis (recommended for production); locally, you can start quickly with the defaults from `.env.example`.
### 1. Backend Setup
```bash
# Sync backend dependencies (including dev tools)
uv sync --extra dev
# Initialize environment variables
cp .env.example .env
# On first run, create the database and user first (using the local postgres superuser as an example)
createuser promptworks -P # skip if a user with this name already exists
createdb promptworks -O promptworks
# Or run the following SQL:
# psql -U postgres -c "CREATE USER promptworks WITH PASSWORD 'promptworks';"
# psql -U postgres -c "CREATE DATABASE promptworks OWNER promptworks;"
# Apply the database schema
uv run alembic upgrade head
```
### 2. Install Frontend Dependencies
```bash
cd frontend
npm install
```
### 3. Start the Services
```bash
# Backend FastAPI dev server
uv run poe server
# Start the frontend dev server in a new terminal
cd frontend
npm run dev -- --host
## or
uv run poe frontend
```
The backend runs at `http://127.0.0.1:8000` by default (API docs at `/docs`); the frontend runs at `http://127.0.0.1:5173`.
### 4. Common Quality Checks
```bash
uv run poe format # unify code style
uv run poe lint # static type checking
uv run poe test # unit and integration tests
uv run poe test-all # run the three steps above in sequence
# Build the production bundle from the frontend directory
npm run build
```
## 🧪 Test Task Message Conventions
- If a test task's schema does not explicitly provide a `system` message, the platform injects the current prompt snapshot into the message list as a `user`-role message, for compatibility with models that only recognize user input.
- If the schema does contain a `system` message, the original order is preserved and the snapshot is not injected again.
- Questions from the test inputs (`inputs`/`test_inputs`) are still sent as subsequent `user` messages, supporting multi-turn replay (see the sketch after this list).
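A minimal sketch of this convention, purely for illustration: the function name and signature are hypothetical, and the actual assembly logic lives in the backend test engine.
```python
# Illustrative sketch of the message convention above; not the project's actual code.
def build_messages(
    schema_messages: list[dict],
    prompt_snapshot: str,
    test_inputs: list[str],
) -> list[dict]:
    messages = list(schema_messages)
    if not any(m.get("role") == "system" for m in messages):
        # No explicit system message: inject the prompt snapshot as a user-role message.
        messages.insert(0, {"role": "user", "content": prompt_snapshot})
    # Test inputs are always appended as follow-up user messages (multi-turn replay).
    messages.extend({"role": "user", "content": question} for question in test_inputs)
    return messages
```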
## 🐳 One-Command Docker Deployment
- **Prerequisites**: make sure Docker and Docker Compose are installed locally (Docker Desktop or NerdCTL both work).
- **Start command**:
```bash
docker compose up -d --build
```
- **Entry points**: the frontend is exposed at `http://localhost:18080` by default, the backend API at `http://localhost:8000/api/v1`, and the database and Redis on ports `15432` and `6379` respectively.
- **Stop / clean up**:
```bash
docker compose down # stop the containers
docker compose down -v # stop the containers and remove data volumes
```
### Container Orchestration
| Service | Description | Port | Notes |
| --- | --- | --- | --- |
| `postgres` | PostgreSQL database | 15432 | Default user, password, and database name are all `promptworks` |
| `redis` | Redis cache / message queue | 6379 | AOF enabled; suitable for development use |
| `backend` | FastAPI backend | 8000 | Runs `alembic upgrade head` automatically before starting |
| `frontend` | Frontend static files served by Nginx | 18080 | The backend address can be customized at build time via `VITE_API_BASE_URL` |
> Tip: to customize ports or the database password, adjust the corresponding environment variables and port mappings in `docker-compose.yml` (the current example uses `15432` and `18080`), then re-run `docker compose up -d --build`.
## ⚙️ Environment Variables
| Name | Required | Default | Description |
| --- | --- | --- | --- |
| `APP_ENV` | No | `development` | Controls the runtime environment, used for environment-specific behavior such as logging. |
| `APP_TEST_MODE` | No | `false` | When enabled, emits DEBUG-level logs; recommended for local debugging only. |
| `API_V1_STR` | No | `/api/v1` | Version prefix for the backend API. |
| `PROJECT_NAME` | No | `PromptWorks` | Display name of the system. |
| `DATABASE_URL` | Yes | `postgresql+psycopg://...` | PostgreSQL connection string; the database must be reachable. |
| `REDIS_URL` | No | `redis://localhost:6379/0` | Redis connection address, usable for caching or async tasks. |
| `BACKEND_CORS_ORIGINS` | No | `http://localhost:5173` | Frontend origins allowed for CORS; multiple URLs can be comma-separated. |
| `BACKEND_CORS_ALLOW_CREDENTIALS` | No | `true` | Whether credentials such as cookies are allowed. |
| `OPENAI_API_KEY` | No | empty | API key used when integrating OpenAI models. |
| `ANTHROPIC_API_KEY` | No | empty | API key used when integrating Anthropic models. |
| `VITE_API_BASE_URL` | Required for the frontend | `http://127.0.0.1:8000/api/v1` | Base URL the frontend uses to reach the backend; set it in `frontend/.env.local`. |
> Tip: after copying `.env.example` to `.env`, set the `VITE_`-prefixed variables in `frontend/.env.example` (to be created) or `.env.local` so that the build and runtime environments stay consistent.
## 🗂️ Project Structure
```
.
├── alembic/ # Database migration scripts
├── app/ # FastAPI application
│   ├── api/ # REST endpoint definitions and dependency injection
│   ├── core/ # Configuration, logging, CORS, and other infrastructure
│   ├── db/ # Database sessions and initialization
│   ├── models/ # SQLAlchemy models
│   ├── schemas/ # Pydantic serialization models
│   └── services/ # Business service layer
├── frontend/ # Vue 3 frontend project
│   ├── public/
│   ├── src/
│   │   ├── api/ # HTTP client and request wrappers
│   │   ├── router/ # Route configuration
│   │   ├── types/ # TypeScript type definitions
│   │   └── views/ # Page components
├── tests/ # pytest test cases
├── pyproject.toml # Backend dependencies and task configuration
├── README.md # Project documentation
└── .env.example # Environment variable template
```
## 📡 API and Frontend Integration
- The backend provides endpoints such as `/api/v1/prompts` and `/api/v1/test_prompt` for the frontend; the current frontend examples use local mock data, which can be replaced with the real API in later iterations (see the example request after this list).
- The prompt detail page already includes a version diff component and a test panel; once wired to the API, it enables an end-to-end prompt validation loop.
- The test task list shows the new task entry point by default; the legacy "New Test Task" button is hidden, and the new entry uses the unified label "New Test Task".
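As a quick illustration, the prompt list endpoint can be called directly once the backend is running. This is a minimal sketch; the response shape is defined by the backend's Pydantic schemas, so no specific fields are assumed here.
```python
# Minimal sketch: call the prompt list endpoint of a locally running backend.
import httpx

BASE_URL = "http://127.0.0.1:8000/api/v1"

with httpx.Client(base_url=BASE_URL, timeout=10.0) as client:
    resp = client.get("/prompts")
    resp.raise_for_status()
    # Print the raw JSON rather than assuming specific field names.
    print(resp.json())
```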
## 🤝 Contributing
1. Create a feature branch and follow the "format → type check → test" workflow.
2. After development, run `uv run poe test-all` to keep the quality baseline.
3. Open a Pull Request describing the scope of the change and how it was verified; local commit messages should be short descriptions in Chinese.
Issues and suggestions are welcome. Let's build PromptWorks together!
## Star History
[Star History Chart](https://www.star-history.com/#YellowSeaa/PromptWorks&type=date&legend=top-left)
## /alembic.ini
```ini path="/alembic.ini"
[alembic]
script_location = alembic
prepend_sys_path = .
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = INFO
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
```
## /alembic/env.py
```py path="/alembic/env.py"
from __future__ import annotations
import sys
from logging.config import fileConfig
from pathlib import Path
from alembic import context
from sqlalchemy import create_engine, pool
# Ensure the project root is on the Python path so app modules can be imported.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from app.core.config import settings
from app.models import Base # noqa: E402 (import after sys.path adjustment)
import app.models # noqa: F401,E402 (force model metadata registration)
config = context.config
if config.config_file_name is not None:
fileConfig(config.config_file_name)
target_metadata = Base.metadata
def get_url() -> str:
"""Return the database URL from application settings."""
return settings.DATABASE_URL
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode."""
url = get_url()
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=True,
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
connectable = create_engine(
get_url(),
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata, compare_type=True
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
```
## /alembic/script.py.mako
```mako path="/alembic/script.py.mako"
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa # noqa: F401
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Apply the migration."""
pass
def downgrade() -> None:
"""Revert the migration."""
pass
```
## /alembic/versions/0001_initial_schema.py
```py path="/alembic/versions/0001_initial_schema.py"
"""Create initial tables
Revision ID: 0001_initial_schema
Revises:
Create Date: 2024-10-06 00:00:00
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "0001_initial_schema"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
test_run_status_enum = postgresql.ENUM(
"pending",
"running",
"completed",
"failed",
name="test_run_status",
create_type=False,
)
def upgrade() -> None:
"""Apply the initial database schema."""
op.execute(
sa.text(
"""
            DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_type WHERE typname = 'test_run_status'
) THEN
CREATE TYPE test_run_status AS ENUM (
'pending',
'running',
'completed',
'failed'
);
END IF;
END
            $$;
"""
)
)
op.create_table(
"prompts",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column("version", sa.String(length=50), nullable=False),
sa.Column("content", sa.Text(), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("author", sa.String(length=100), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name", "version", name="uq_prompt_name_version"),
)
op.create_index("ix_prompts_id", "prompts", ["id"], unique=False)
op.create_table(
"llm_providers",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("provider_name", sa.String(length=100), nullable=False),
sa.Column("model_name", sa.String(length=150), nullable=False),
sa.Column("base_url", sa.String(length=255), nullable=True),
sa.Column("api_key", sa.Text(), nullable=False),
sa.Column("parameters", sa.JSON(), nullable=False),
sa.Column("is_custom", sa.Boolean(), nullable=False),
sa.Column("logo_url", sa.String(length=255), nullable=True),
sa.Column("logo_emoji", sa.String(length=16), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_llm_providers_id", "llm_providers", ["id"], unique=False)
op.create_table(
"test_runs",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("prompt_id", sa.Integer(), nullable=False),
sa.Column("model_name", sa.String(length=100), nullable=False),
sa.Column("model_version", sa.String(length=50), nullable=True),
sa.Column("temperature", sa.Float(), nullable=False),
sa.Column("top_p", sa.Float(), nullable=False),
sa.Column("repetitions", sa.Integer(), nullable=False),
sa.Column("schema", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column(
"status",
test_run_status_enum,
server_default=sa.text("'pending'::test_run_status"),
nullable=False,
),
sa.Column("notes", sa.Text(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.ForeignKeyConstraint(["prompt_id"], ["prompts.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_test_runs_id", "test_runs", ["id"], unique=False)
op.create_table(
"results",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("test_run_id", sa.Integer(), nullable=False),
sa.Column("run_index", sa.Integer(), nullable=False),
sa.Column("output", sa.Text(), nullable=False),
sa.Column(
"parsed_output", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
sa.Column("tokens_used", sa.Integer(), nullable=True),
sa.Column("latency_ms", sa.Integer(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.ForeignKeyConstraint(["test_run_id"], ["test_runs.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_results_id", "results", ["id"], unique=False)
op.create_table(
"metrics",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("result_id", sa.Integer(), nullable=False),
sa.Column("is_valid_json", sa.Boolean(), nullable=True),
sa.Column("schema_pass", sa.Boolean(), nullable=True),
sa.Column(
"missing_fields", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
sa.Column(
"type_mismatches", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
sa.Column("consistency_score", sa.Float(), nullable=True),
sa.Column("numeric_accuracy", sa.Float(), nullable=True),
sa.Column("boolean_accuracy", sa.Float(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.ForeignKeyConstraint(["result_id"], ["results.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_metrics_id", "metrics", ["id"], unique=False)
def downgrade() -> None:
"""Drop the initial database schema."""
op.drop_index("ix_metrics_id", table_name="metrics")
op.drop_table("metrics")
op.drop_index("ix_results_id", table_name="results")
op.drop_table("results")
op.drop_index("ix_test_runs_id", table_name="test_runs")
op.drop_table("test_runs")
op.drop_index("ix_llm_providers_id", table_name="llm_providers")
op.drop_table("llm_providers")
op.drop_index("ix_prompts_id", table_name="prompts")
op.drop_table("prompts")
op.execute(sa.text("DROP TYPE IF EXISTS test_run_status"))
```
## /alembic/versions/0002_prompt_class_and_versions.py
```py path="/alembic/versions/0002_prompt_class_and_versions.py"
"""introduce prompt classifications and versions
Revision ID: 0002_prompt_class_and_versions
Revises: 0001_initial_schema
Create Date: 2025-09-20 00:05:00
"""
from __future__ import annotations
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "0002_prompt_class_and_versions"
down_revision: Union[str, None] = "0001_initial_schema"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
connection = op.get_bind()
op.create_table(
"prompts_class",
sa.Column("id", sa.Integer(), primary_key=True, nullable=False),
sa.Column("name", sa.String(length=255), nullable=False, unique=True),
sa.Column("description", sa.Text(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
)
op.add_column(
"prompts",
sa.Column("class_id", sa.Integer(), nullable=True),
)
op.add_column(
"prompts",
sa.Column("current_version_id", sa.Integer(), nullable=True),
)
op.create_index("ix_prompts_class_id", "prompts", ["class_id"], unique=False)
op.create_index(
"ix_prompts_current_version_id", "prompts", ["current_version_id"], unique=False
)
op.create_table(
"prompts_versions",
sa.Column("id", sa.Integer(), primary_key=True, nullable=False),
sa.Column("prompt_id", sa.Integer(), nullable=False),
sa.Column("version", sa.String(length=50), nullable=False),
sa.Column("content", sa.Text(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.ForeignKeyConstraint(["prompt_id"], ["prompts.id"], ondelete="CASCADE"),
sa.UniqueConstraint("prompt_id", "version", name="uq_prompt_version"),
)
op.create_index(
"ix_prompts_versions_prompt_id", "prompts_versions", ["prompt_id"], unique=False
)
result = connection.execute(
sa.text(
"INSERT INTO prompts_class (name, description) VALUES (:name, :description) RETURNING id"
),
{
"name": "默认分类",
"description": "迁移自动创建的默认分类",
},
)
default_class_id = result.scalar_one()
connection.execute(
sa.text("UPDATE prompts SET class_id = :class_id"),
{"class_id": default_class_id},
)
connection.execute(
sa.text(
"""
INSERT INTO prompts_versions (prompt_id, version, content, created_at, updated_at)
SELECT id, version, content, created_at, updated_at
FROM prompts
"""
)
)
connection.execute(
sa.text(
"""
UPDATE prompts p
SET current_version_id = pv.id
FROM prompts_versions pv
WHERE pv.prompt_id = p.id AND pv.version = p.version
"""
)
)
op.drop_constraint("uq_prompt_name_version", "prompts", type_="unique")
op.create_unique_constraint("uq_prompt_class_name", "prompts", ["class_id", "name"])
op.create_foreign_key(
"prompts_class_id_fkey",
"prompts",
"prompts_class",
["class_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"prompts_current_version_id_fkey",
"prompts",
"prompts_versions",
["current_version_id"],
["id"],
)
op.drop_constraint("test_runs_prompt_id_fkey", "test_runs", type_="foreignkey")
op.alter_column("test_runs", "prompt_id", new_column_name="prompt_version_id")
connection.execute(
sa.text(
"""
UPDATE test_runs tr
SET prompt_version_id = pv.id
FROM prompts_versions pv
WHERE pv.prompt_id = tr.prompt_version_id
"""
)
)
op.create_foreign_key(
"test_runs_prompt_version_id_fkey",
"test_runs",
"prompts_versions",
["prompt_version_id"],
["id"],
ondelete="CASCADE",
)
op.alter_column(
"prompts",
"class_id",
existing_type=sa.Integer(),
nullable=False,
)
op.drop_column("prompts", "version")
op.drop_column("prompts", "content")
def downgrade() -> None:
connection = op.get_bind()
op.add_column(
"prompts",
sa.Column("version", sa.String(length=50), nullable=True),
)
op.add_column(
"prompts",
sa.Column("content", sa.Text(), nullable=True),
)
connection.execute(
sa.text(
"""
UPDATE prompts p
SET version = pv.version,
content = pv.content
FROM prompts_versions pv
WHERE pv.id = p.current_version_id
"""
)
)
op.drop_constraint(
"test_runs_prompt_version_id_fkey", "test_runs", type_="foreignkey"
)
connection.execute(
sa.text(
"""
UPDATE test_runs tr
SET prompt_version_id = pv.prompt_id
FROM prompts_versions pv
WHERE pv.id = tr.prompt_version_id
"""
)
)
op.alter_column("test_runs", "prompt_version_id", new_column_name="prompt_id")
op.create_foreign_key(
"test_runs_prompt_id_fkey",
"test_runs",
"prompts",
["prompt_id"],
["id"],
ondelete="CASCADE",
)
op.drop_constraint("prompts_current_version_id_fkey", "prompts", type_="foreignkey")
op.drop_constraint("prompts_class_id_fkey", "prompts", type_="foreignkey")
op.drop_constraint("uq_prompt_class_name", "prompts", type_="unique")
connection.execute(
sa.text("UPDATE prompts SET class_id = NULL, current_version_id = NULL"),
)
op.create_unique_constraint(
"uq_prompt_name_version", "prompts", ["name", "version"]
)
op.drop_index("ix_prompts_current_version_id", table_name="prompts")
op.drop_index("ix_prompts_class_id", table_name="prompts")
op.drop_column("prompts", "current_version_id")
op.drop_column("prompts", "class_id")
op.drop_index("ix_prompts_versions_prompt_id", table_name="prompts_versions")
op.drop_table("prompts_versions")
op.drop_table("prompts_class")
connection.execute(
sa.text(
"""
UPDATE prompts
SET version = COALESCE(version, 'v1'),
content = COALESCE(content, '')
"""
)
)
op.alter_column(
"prompts",
"version",
existing_type=sa.String(length=50),
nullable=False,
)
op.alter_column(
"prompts",
"content",
existing_type=sa.Text(),
nullable=False,
)
```
## /alembic/versions/0003_prompt_tags.py
```py path="/alembic/versions/0003_prompt_tags.py"
"""add prompt tags support
Revision ID: 0003_prompt_tags
Revises: 0002_prompt_class_and_versions
Create Date: 2025-09-20 00:30:00
"""
from __future__ import annotations
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "0003_prompt_tags"
down_revision: Union[str, None] = "0002_prompt_class_and_versions"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
TAG_SEED_DATA: list[tuple[str, str]] = [
("常规", "#2563EB"),
("测试", "#10B981"),
("紧急", "#F97316"),
("实验", "#8B5CF6"),
("归档", "#6B7280"),
]
def upgrade() -> None:
op.create_table(
"prompt_tags",
sa.Column("id", sa.Integer(), primary_key=True, nullable=False),
sa.Column("name", sa.String(length=100), nullable=False, unique=True),
sa.Column("color", sa.String(length=7), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
)
op.create_table(
"prompt_tag_links",
sa.Column("prompt_id", sa.Integer(), nullable=False),
sa.Column("tag_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["prompt_id"],
["prompts.id"],
ondelete="CASCADE",
name="prompt_tag_links_prompt_id_fkey",
),
sa.ForeignKeyConstraint(
["tag_id"],
["prompt_tags.id"],
ondelete="CASCADE",
name="prompt_tag_links_tag_id_fkey",
),
sa.PrimaryKeyConstraint("prompt_id", "tag_id"),
)
tags_table = sa.table(
"prompt_tags",
sa.column("name", sa.String(length=100)),
sa.column("color", sa.String(length=7)),
)
op.bulk_insert(
tags_table,
[{"name": name, "color": color} for name, color in TAG_SEED_DATA],
)
def downgrade() -> None:
op.drop_table("prompt_tag_links")
op.drop_table("prompt_tags")
```
## /alembic/versions/0004_refactor_llm_providers.py
```py path="/alembic/versions/0004_refactor_llm_providers.py"
"""refactor llm providers and add models table
Revision ID: 0004_refactor_llm_providers
Revises: efea0d0224c5
Create Date: 2025-10-05 10:32:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "0004_refactor_llm_providers"
down_revision: Union[str, None] = "efea0d0224c5"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"llm_providers",
sa.Column("provider_key", sa.String(length=100), nullable=True),
)
op.add_column(
"llm_providers",
sa.Column("description", sa.Text(), nullable=True),
)
op.add_column(
"llm_providers",
sa.Column(
"is_archived",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
op.add_column(
"llm_providers",
sa.Column("default_model_name", sa.String(length=150), nullable=True),
)
op.drop_column("llm_providers", "parameters")
op.drop_column("llm_providers", "model_name")
op.create_index(
"ix_llm_providers_provider_key",
"llm_providers",
["provider_key"],
unique=False,
)
op.create_table(
"llm_models",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("provider_id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(length=150), nullable=False),
sa.Column("capability", sa.String(length=120), nullable=True),
sa.Column("quota", sa.String(length=120), nullable=True),
sa.Column(
"parameters", sa.JSON(), nullable=False, server_default=sa.text("'{}'")
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
sa.ForeignKeyConstraint(
["provider_id"],
["llm_providers.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("provider_id", "name", name="uq_llm_model_provider_name"),
)
op.create_index(
"ix_llm_models_provider_id", "llm_models", ["provider_id"], unique=False
)
def downgrade() -> None:
op.drop_index("ix_llm_models_provider_id", table_name="llm_models")
op.drop_table("llm_models")
op.drop_index("ix_llm_providers_provider_key", table_name="llm_providers")
op.add_column(
"llm_providers",
sa.Column(
"model_name",
sa.String(length=150),
nullable=False,
server_default="default-model",
),
)
op.add_column(
"llm_providers",
sa.Column(
"parameters",
sa.JSON(),
nullable=False,
server_default=sa.text("'{}'"),
),
)
op.drop_column("llm_providers", "default_model_name")
op.drop_column("llm_providers", "is_archived")
op.drop_column("llm_providers", "description")
op.drop_column("llm_providers", "provider_key")
```
## /alembic/versions/0005_drop_desc_params.py
```py path="/alembic/versions/0005_drop_desc_params.py"
"""drop description and model parameters
Revision ID: 0005_drop_desc_params
Revises: 0004_refactor_llm_providers
Create Date: 2025-10-05 15:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "0005_drop_desc_params"
down_revision: Union[str, None] = "0004_refactor_llm_providers"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.drop_column("llm_providers", "description")
op.drop_column("llm_models", "parameters")
def downgrade() -> None:
op.add_column(
"llm_models",
sa.Column(
"parameters",
sa.JSON(),
nullable=False,
server_default=sa.text("'{}'"),
),
)
op.add_column(
"llm_providers",
sa.Column("description", sa.Text(), nullable=True),
)
```
## /alembic/versions/07f7c967ab21_add_llm_usage_logs.py
```py path="/alembic/versions/07f7c967ab21_add_llm_usage_logs.py"
"""add llm usage logs table
Revision ID: 07f7c967ab21
Revises: f5e1a97c2e3d
Create Date: 2025-10-06 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "07f7c967ab21"
down_revision: Union[str, None] = "f5e1a97c2e3d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"llm_usage_logs",
sa.Column("id", sa.Integer(), primary_key=True, nullable=False),
sa.Column(
"provider_id",
sa.Integer(),
sa.ForeignKey("llm_providers.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column(
"model_id",
sa.Integer(),
sa.ForeignKey("llm_models.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column("model_name", sa.String(length=150), nullable=False),
sa.Column(
"source",
sa.String(length=50),
nullable=False,
server_default=sa.text("'quick_test'"),
),
sa.Column(
"prompt_id",
sa.Integer(),
sa.ForeignKey("prompts.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column(
"prompt_version_id",
sa.Integer(),
sa.ForeignKey("prompts_versions.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column("messages", sa.JSON(), nullable=True),
sa.Column("parameters", sa.JSON(), nullable=True),
sa.Column("response_text", sa.Text(), nullable=True),
sa.Column("temperature", sa.Float(), nullable=True),
sa.Column("latency_ms", sa.Integer(), nullable=True),
sa.Column("prompt_tokens", sa.Integer(), nullable=True),
sa.Column("completion_tokens", sa.Integer(), nullable=True),
sa.Column("total_tokens", sa.Integer(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
)
op.create_index("ix_llm_usage_logs_provider_id", "llm_usage_logs", ["provider_id"])
op.create_index("ix_llm_usage_logs_created_at", "llm_usage_logs", ["created_at"])
def downgrade() -> None:
op.drop_index("ix_llm_usage_logs_created_at", table_name="llm_usage_logs")
op.drop_index("ix_llm_usage_logs_provider_id", table_name="llm_usage_logs")
op.drop_table("llm_usage_logs")
```
## /alembic/versions/2d8d3a4e0c8b_merge_prompt_test_heads.py
```py path="/alembic/versions/2d8d3a4e0c8b_merge_prompt_test_heads.py"
"""merge prompt test branch heads
Revision ID: 2d8d3a4e0c8b
Revises: 3152d7c2b4f0, 72f3f786c4a1
Create Date: 2025-10-11 17:05:00.000000
"""
from typing import Sequence, Union
from alembic import op # noqa: F401 - keep the conventional Alembic import
revision: str = "2d8d3a4e0c8b"
down_revision: Union[str, tuple[str, ...], None] = ("3152d7c2b4f0", "72f3f786c4a1")
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""该迁移仅用于合并分支,不包含业务变更。"""
def downgrade() -> None:
"""回滚时无需操作。"""
```
## /alembic/versions/3152d7c2b4f0_add_llm_model_concurrency.py
```py path="/alembic/versions/3152d7c2b4f0_add_llm_model_concurrency.py"
"""add concurrency limit to llm models
Revision ID: 3152d7c2b4f0
Revises: 9b546f1b6f1a
Create Date: 2025-10-06 18:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "3152d7c2b4f0"
down_revision: Union[str, None] = "9b546f1b6f1a"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"llm_models",
sa.Column(
"concurrency_limit",
sa.Integer(),
nullable=False,
server_default="5",
),
)
def downgrade() -> None:
op.drop_column("llm_models", "concurrency_limit")
```
## /alembic/versions/6d6a1f6dfb41_create_system_settings_table.py
```py path="/alembic/versions/6d6a1f6dfb41_create_system_settings_table.py"
"""create system settings table
Revision ID: 6d6a1f6dfb41
Revises: a1b2c3d4e5f6
Create Date: 2025-12-02 10:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "6d6a1f6dfb41"
down_revision: Union[str, None] = "a1b2c3d4e5f6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"system_settings",
sa.Column("key", sa.String(length=120), nullable=False),
sa.Column("value", sa.JSON(), nullable=True),
sa.Column("description", sa.Text(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.PrimaryKeyConstraint("key"),
)
def downgrade() -> None:
op.drop_table("system_settings")
```
## /alembic/versions/72f3f786c4a1_add_prompt_test_tables.py
```py path="/alembic/versions/72f3f786c4a1_add_prompt_test_tables.py"
"""add prompt test task/unit/experiment tables
Revision ID: 72f3f786c4a1
Revises: 9b546f1b6f1a
Create Date: 2025-02-14 10:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
revision: str = "72f3f786c4a1"
down_revision: Union[str, None] = "9b546f1b6f1a"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
bind = op.get_bind()
task_status_enum = postgresql.ENUM(
"draft",
"ready",
"running",
"completed",
"failed",
name="prompt_test_task_status",
create_type=False,
)
experiment_status_enum = postgresql.ENUM(
"pending",
"running",
"completed",
"failed",
"cancelled",
name="prompt_test_experiment_status",
create_type=False,
)
task_status_enum.create(bind, checkfirst=True)
experiment_status_enum.create(bind, checkfirst=True)
op.create_table(
"prompt_test_tasks",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("name", sa.String(length=120), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("prompt_version_id", sa.Integer(), nullable=True),
sa.Column("owner_id", sa.Integer(), nullable=True),
sa.Column("config", sa.JSON(), nullable=True),
sa.Column(
"status",
task_status_enum,
nullable=False,
server_default="draft",
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.ForeignKeyConstraint(
["prompt_version_id"], ["prompts_versions.id"], ondelete="SET NULL"
),
)
op.create_table(
"prompt_test_units",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("task_id", sa.Integer(), nullable=False),
sa.Column("prompt_version_id", sa.Integer(), nullable=True),
sa.Column("name", sa.String(length=120), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("model_name", sa.String(length=100), nullable=False),
sa.Column("llm_provider_id", sa.Integer(), nullable=True),
sa.Column("temperature", sa.Float(), nullable=False, server_default="0.7"),
sa.Column("top_p", sa.Float(), nullable=True),
sa.Column("rounds", sa.Integer(), nullable=False, server_default="1"),
sa.Column("prompt_template", sa.Text(), nullable=True),
sa.Column("variables", sa.JSON(), nullable=True),
sa.Column("parameters", sa.JSON(), nullable=True),
sa.Column("expectations", sa.JSON(), nullable=True),
sa.Column("tags", sa.JSON(), nullable=True),
sa.Column("extra", sa.JSON(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.ForeignKeyConstraint(
["task_id"], ["prompt_test_tasks.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(
["prompt_version_id"], ["prompts_versions.id"], ondelete="SET NULL"
),
sa.UniqueConstraint("task_id", "name", name="uq_prompt_test_unit_task_name"),
)
op.create_index(
"ix_prompt_test_units_task_id", "prompt_test_units", ["task_id"], unique=False
)
op.create_table(
"prompt_test_experiments",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("unit_id", sa.Integer(), nullable=False),
sa.Column("batch_id", sa.String(length=64), nullable=True),
sa.Column("sequence", sa.Integer(), nullable=False, server_default="1"),
sa.Column(
"status",
experiment_status_enum,
nullable=False,
server_default="pending",
),
sa.Column("outputs", sa.JSON(), nullable=True),
sa.Column("metrics", sa.JSON(), nullable=True),
sa.Column("error", sa.Text(), nullable=True),
sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("finished_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.ForeignKeyConstraint(
["unit_id"], ["prompt_test_units.id"], ondelete="CASCADE"
),
)
op.create_index(
"ix_prompt_test_experiments_unit_id",
"prompt_test_experiments",
["unit_id"],
unique=False,
)
op.create_index(
"ix_prompt_test_experiments_batch_id",
"prompt_test_experiments",
["batch_id"],
unique=False,
)
def downgrade() -> None:
op.drop_index(
"ix_prompt_test_experiments_batch_id",
table_name="prompt_test_experiments",
)
op.drop_index(
"ix_prompt_test_experiments_unit_id",
table_name="prompt_test_experiments",
)
op.drop_table("prompt_test_experiments")
op.drop_index("ix_prompt_test_units_task_id", table_name="prompt_test_units")
op.drop_table("prompt_test_units")
op.drop_table("prompt_test_tasks")
bind = op.get_bind()
sa.Enum(name="prompt_test_experiment_status").drop(bind, checkfirst=True)
sa.Enum(name="prompt_test_task_status").drop(bind, checkfirst=True)
```
## /alembic/versions/9b546f1b6f1a_add_test_run_batch_id.py
```py path="/alembic/versions/9b546f1b6f1a_add_test_run_batch_id.py"
"""add batch id to test runs
Revision ID: 9b546f1b6f1a
Revises: 07f7c967ab21
Create Date: 2025-10-06 15:30:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "9b546f1b6f1a"
down_revision: Union[str, None] = "07f7c967ab21"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"test_runs",
sa.Column("batch_id", sa.String(length=64), nullable=True),
)
op.create_index("ix_test_runs_batch_id", "test_runs", ["batch_id"])
def downgrade() -> None:
op.drop_index("ix_test_runs_batch_id", table_name="test_runs")
op.drop_column("test_runs", "batch_id")
```
## /alembic/versions/a1b2c3d4e5f6_add_soft_delete_to_prompt_test_tasks.py
```py path="/alembic/versions/a1b2c3d4e5f6_add_soft_delete_to_prompt_test_tasks.py"
"""add soft delete flag to prompt test tasks
Revision ID: a1b2c3d4e5f6
Revises: 2d8d3a4e0c8b
Create Date: 2025-10-25 12:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "a1b2c3d4e5f6"
down_revision: Union[str, Sequence[str], None] = "2d8d3a4e0c8b"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"prompt_test_tasks",
sa.Column(
"is_deleted",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)
def downgrade() -> None:
op.drop_column("prompt_test_tasks", "is_deleted")
```
## /alembic/versions/ddc6143e8fb8_add_default_prompt_class.py
```py path="/alembic/versions/ddc6143e8fb8_add_default_prompt_class.py"
"""add default prompt class
Revision ID: ddc6143e8fb8
Revises: 0003_prompt_tags
Create Date: 2025-10-03 13:50:27.116107
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "ddc6143e8fb8"
down_revision: Union[str, None] = "0003_prompt_tags"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
DEFAULT_CLASS_NAME = "默认分类"
DEFAULT_CLASS_DESCRIPTION = "系统自动创建的默认分类"
def upgrade() -> None:
"""确保数据库中存在默认 Prompt 分类。"""
connection = op.get_bind()
existing_id = connection.execute(
sa.text("SELECT id FROM prompts_class WHERE name = :name"),
{"name": DEFAULT_CLASS_NAME},
).scalar_one_or_none()
if existing_id is not None:
return
connection.execute(
sa.text(
"INSERT INTO prompts_class (name, description) VALUES (:name, :description)"
),
{"name": DEFAULT_CLASS_NAME, "description": DEFAULT_CLASS_DESCRIPTION},
)
def downgrade() -> None:
"""在安全的前提下移除默认 Prompt 分类。"""
connection = op.get_bind()
default_id = connection.execute(
sa.text("SELECT id FROM prompts_class WHERE name = :name"),
{"name": DEFAULT_CLASS_NAME},
).scalar_one_or_none()
if default_id is None:
return
prompt_count = connection.execute(
sa.text("SELECT COUNT(*) FROM prompts WHERE class_id = :class_id"),
{"class_id": default_id},
).scalar_one()
if prompt_count:
return
connection.execute(
sa.text("DELETE FROM prompts_class WHERE id = :class_id"),
{"class_id": default_id},
)
```
## /alembic/versions/efea0d0224c5_seed_default_prompt_tags.py
```py path="/alembic/versions/efea0d0224c5_seed_default_prompt_tags.py"
"""seed default prompt tags
Revision ID: efea0d0224c5
Revises: ddc6143e8fb8
Create Date: 2025-10-03 13:58:18.007775
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "efea0d0224c5"
down_revision: Union[str, None] = "ddc6143e8fb8"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
DEFAULT_TAGS = [
{"name": "通用运营", "color": "#409EFF"},
{"name": "客户关怀", "color": "#67C23A"},
{"name": "营销活动", "color": "#F56C6C"},
{"name": "产品公告", "color": "#E6A23C"},
{"name": "数据分析", "color": "#909399"},
]
def upgrade() -> None:
"""插入五个常用标签作为默认数据。"""
connection = op.get_bind()
for tag in DEFAULT_TAGS:
existing_id = connection.execute(
sa.text("SELECT id FROM prompt_tags WHERE name = :name"),
{"name": tag["name"]},
).scalar_one_or_none()
if existing_id is not None:
continue
connection.execute(
sa.text("INSERT INTO prompt_tags (name, color) VALUES (:name, :color)"),
tag,
)
def downgrade() -> None:
"""在无引用的情况下移除默认标签。"""
connection = op.get_bind()
for tag in DEFAULT_TAGS:
tag_id = connection.execute(
sa.text("SELECT id FROM prompt_tags WHERE name = :name"),
{"name": tag["name"]},
).scalar_one_or_none()
if tag_id is None:
continue
usage = connection.execute(
sa.text("SELECT COUNT(*) FROM prompt_tag_links WHERE tag_id = :tag_id"),
{"tag_id": tag_id},
).scalar_one()
if usage:
continue
connection.execute(
sa.text("DELETE FROM prompt_tags WHERE id = :tag_id"),
{"tag_id": tag_id},
)
```
## /alembic/versions/f5e1a97c2e3d_seed_sample_prompt.py
```py path="/alembic/versions/f5e1a97c2e3d_seed_sample_prompt.py"
"""seed sample prompt data
Revision ID: f5e1a97c2e3d
Revises: 0005_drop_desc_params
Create Date: 2025-10-05 18:30:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "f5e1a97c2e3d"
down_revision: Union[str, None] = "0005_drop_desc_params"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
DEFAULT_CLASS_NAME = "默认分类"
DEFAULT_CLASS_DESCRIPTION = "系统自动创建的默认分类"
SAMPLE_PROMPT_NAME = "示例:客服欢迎语"
SAMPLE_PROMPT_AUTHOR = "系统预置"
SAMPLE_PROMPT_DESCRIPTION = "面向客服首次接待的欢迎语示例"
SAMPLE_VERSION_NAME = "v1.0.0"
SAMPLE_VERSION_CONTENT = (
"你是一名资深客服代表,请使用亲切、专业的语气欢迎客户。\n"
"1. 简要自我介绍,并表明乐于协助。\n"
"2. 询问客户关注的问题点。\n"
"3. 告知可提供的帮助范围,并给出下一步指引。"
)
def _get_or_create_default_class(connection: sa.engine.Connection) -> int:
class_id = connection.execute(
sa.text("SELECT id FROM prompts_class WHERE name = :name"),
{"name": DEFAULT_CLASS_NAME},
).scalar_one_or_none()
if class_id is not None:
return class_id
connection.execute(
sa.text(
"INSERT INTO prompts_class (name, description) VALUES (:name, :description)"
),
{"name": DEFAULT_CLASS_NAME, "description": DEFAULT_CLASS_DESCRIPTION},
)
return connection.execute(
sa.text("SELECT id FROM prompts_class WHERE name = :name"),
{"name": DEFAULT_CLASS_NAME},
).scalar_one()
def _attach_default_tags(connection: sa.engine.Connection, prompt_id: int) -> None:
tag_ids = (
connection.execute(
sa.text("SELECT id FROM prompt_tags ORDER BY id ASC LIMIT 2")
)
.scalars()
.all()
)
for tag_id in tag_ids:
connection.execute(
sa.text(
"INSERT INTO prompt_tag_links (prompt_id, tag_id) VALUES (:prompt_id, :tag_id)"
),
{"prompt_id": prompt_id, "tag_id": tag_id},
)
def upgrade() -> None:
connection = op.get_bind()
prompt_exists = connection.execute(
sa.text("SELECT 1 FROM prompts WHERE name = :name"),
{"name": SAMPLE_PROMPT_NAME},
).scalar_one_or_none()
if prompt_exists:
return
class_id = _get_or_create_default_class(connection)
connection.execute(
sa.text(
"INSERT INTO prompts (class_id, name, description, author) "
"VALUES (:class_id, :name, :description, :author)"
),
{
"class_id": class_id,
"name": SAMPLE_PROMPT_NAME,
"description": SAMPLE_PROMPT_DESCRIPTION,
"author": SAMPLE_PROMPT_AUTHOR,
},
)
prompt_id = connection.execute(
sa.text("SELECT id FROM prompts WHERE class_id = :class_id AND name = :name"),
{"class_id": class_id, "name": SAMPLE_PROMPT_NAME},
).scalar_one()
connection.execute(
sa.text(
"INSERT INTO prompts_versions (prompt_id, version, content) "
"VALUES (:prompt_id, :version, :content)"
),
{
"prompt_id": prompt_id,
"version": SAMPLE_VERSION_NAME,
"content": SAMPLE_VERSION_CONTENT,
},
)
version_id = connection.execute(
sa.text(
"SELECT id FROM prompts_versions WHERE prompt_id = :prompt_id "
"AND version = :version"
),
{"prompt_id": prompt_id, "version": SAMPLE_VERSION_NAME},
).scalar_one()
connection.execute(
sa.text(
"UPDATE prompts SET current_version_id = :version_id WHERE id = :prompt_id"
),
{"version_id": version_id, "prompt_id": prompt_id},
)
_attach_default_tags(connection, prompt_id)
def downgrade() -> None:
connection = op.get_bind()
prompt_id = connection.execute(
sa.text("SELECT id FROM prompts WHERE name = :name"),
{"name": SAMPLE_PROMPT_NAME},
).scalar_one_or_none()
if prompt_id is None:
return
connection.execute(
sa.text("DELETE FROM prompt_tag_links WHERE prompt_id = :prompt_id"),
{"prompt_id": prompt_id},
)
connection.execute(
sa.text("DELETE FROM prompts_versions WHERE prompt_id = :prompt_id"),
{"prompt_id": prompt_id},
)
connection.execute(
sa.text("DELETE FROM prompts WHERE id = :prompt_id"),
{"prompt_id": prompt_id},
)
remaining = connection.execute(
sa.text(
"SELECT COUNT(*) FROM prompts "
"WHERE class_id = (SELECT id FROM prompts_class WHERE name = :name)"
),
{"name": DEFAULT_CLASS_NAME},
).scalar_one()
if remaining == 0:
connection.execute(
sa.text("DELETE FROM prompts_class WHERE name = :name"),
{"name": DEFAULT_CLASS_NAME},
)
```
## /app/__init__.py
```py path="/app/__init__.py"
```
## /app/api/__init__.py
```py path="/app/api/__init__.py"
```
## /app/api/v1/__init__.py
```py path="/app/api/v1/__init__.py"
```
## /app/api/v1/api.py
```py path="/app/api/v1/api.py"
from fastapi import APIRouter
from app.api.v1.endpoints import (
llms,
prompt_classes,
prompt_tags,
prompts,
test_prompt,
usage,
prompt_test_tasks,
settings,
)
api_router = APIRouter()
api_router.include_router(llms.router, prefix="/llm-providers", tags=["llm_providers"])
api_router.include_router(
prompt_classes.router, prefix="/prompt-classes", tags=["prompt_classes"]
)
api_router.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
api_router.include_router(
prompt_tags.router, prefix="/prompt-tags", tags=["prompt_tags"]
)
api_router.include_router(
test_prompt.router, prefix="/test_prompt", tags=["test_prompt"]
)
api_router.include_router(usage.router, prefix="/usage", tags=["usage"])
api_router.include_router(prompt_test_tasks.router)
api_router.include_router(settings.router)
```
## /app/api/v1/endpoints/__init__.py
```py path="/app/api/v1/endpoints/__init__.py"
```
## /app/api/v1/endpoints/llms.py
```py path="/app/api/v1/endpoints/llms.py"
from __future__ import annotations
import json
import time
from typing import Any, AsyncIterator, Iterator, Mapping, Sequence, cast
import httpx
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, selectinload
from app.core.llm_provider_registry import (
get_provider_defaults,
iter_common_providers,
)
from app.core.logging_config import get_logger
from app.db.session import get_db
from app.models.llm_provider import LLMModel, LLMProvider
from app.models.usage import LLMUsageLog
from app.schemas.llm_provider import (
KnownLLMProvider,
LLMModelCreate,
LLMModelUpdate,
LLMModelRead,
LLMProviderCreate,
LLMProviderRead,
LLMProviderUpdate,
LLMUsageLogRead,
LLMUsageMessage,
)
from app.services.llm_usage import list_quick_test_usage_logs
from app.services.system_settings import (
DEFAULT_QUICK_TEST_TIMEOUT,
get_testing_timeout_config,
)
router = APIRouter()
logger = get_logger("promptworks.api.llms")
DEFAULT_INVOKE_TIMEOUT = DEFAULT_QUICK_TEST_TIMEOUT
class ChatMessage(BaseModel):
role: str = Field(..., description="聊天消息的角色,例如 system、user、assistant")
content: Any = Field(..., description="遵循 OpenAI 聊天格式的消息内容")
class LLMInvocationRequest(BaseModel):
messages: list[ChatMessage]
parameters: dict[str, Any] = Field(
default_factory=dict, description="额外的 OpenAI 兼容参数"
)
model: str | None = Field(default=None, description="覆盖使用的模型名称")
model_id: int | None = Field(default=None, description="指定已配置模型的 ID")
temperature: float | None = Field(
default=None,
ge=0.0,
le=2.0,
description="对话生成温度,范围 0~2",
)
prompt_id: int | None = Field(
default=None, description="可选的 Prompt ID,便于溯源"
)
prompt_version_id: int | None = Field(
default=None, description="可选的 Prompt 版本 ID"
)
persist_usage: bool = Field(
default=False, description="是否将本次调用记录到用量日志"
)
class LLMStreamInvocationRequest(LLMInvocationRequest):
temperature: float = Field(
default=0.7,
ge=0.0,
le=2.0,
description="对话生成温度,范围 0~2",
)
def _normalize_key(value: str | None) -> str | None:
if value is None:
return None
normalized = value.strip().lower()
return normalized or None
def _normalize_base_url(base_url: str | None) -> str | None:
if not base_url:
return None
return base_url.rstrip("/")
def _mask_api_key(api_key: str) -> str:
if not api_key:
return ""
if len(api_key) <= 6:
return "*" * len(api_key)
prefix = api_key[:4]
suffix = api_key[-2:]
return f"{prefix}{'*' * (len(api_key) - 6)}{suffix}"
def _serialize_provider(provider: LLMProvider) -> LLMProviderRead:
models = [
LLMModelRead.model_validate(model, from_attributes=True)
for model in sorted(provider.models, key=lambda item: item.created_at)
]
defaults = get_provider_defaults(provider.provider_key)
resolved_base_url = provider.base_url or (defaults.base_url if defaults else None)
resolved_logo_url = provider.logo_url or (defaults.logo_url if defaults else None)
resolved_logo_emoji = provider.logo_emoji
if resolved_logo_emoji is None and defaults:
resolved_logo_emoji = defaults.logo_emoji
return LLMProviderRead(
id=provider.id,
provider_key=provider.provider_key,
provider_name=provider.provider_name,
base_url=resolved_base_url,
logo_emoji=resolved_logo_emoji,
logo_url=resolved_logo_url,
is_custom=provider.is_custom,
is_archived=provider.is_archived,
default_model_name=provider.default_model_name,
masked_api_key=_mask_api_key(provider.api_key),
models=models,
created_at=provider.created_at,
updated_at=provider.updated_at,
)
def _resolve_base_url_or_400(provider: LLMProvider) -> str:
base_url = provider.base_url or (
defaults.base_url
if (defaults := get_provider_defaults(provider.provider_key))
else None
)
if not base_url:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="该提供者未配置基础 URL。",
)
return cast(str, _normalize_base_url(base_url))
def _get_provider_or_404(
db: Session, provider_id: int, *, include_archived: bool = False
) -> LLMProvider:
provider = db.get(LLMProvider, provider_id)
if not provider or (provider.is_archived and not include_archived):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="未找到指定的提供者"
)
return provider
def _determine_model_for_invocation(
db: Session, provider: LLMProvider, payload: LLMInvocationRequest
) -> tuple[str, LLMModel | None]:
model_name: str | None = None
target_model: LLMModel | None = None
if payload.model_id is not None:
stmt = select(LLMModel).where(
LLMModel.id == payload.model_id, LLMModel.provider_id == provider.id
)
target_model = db.scalar(stmt)
if not target_model:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="指定的模型不存在",
)
elif payload.model:
stmt = select(LLMModel).where(
LLMModel.provider_id == provider.id, LLMModel.name == payload.model
)
target_model = db.scalar(stmt)
if not target_model:
model_name = payload.model
if target_model:
model_name = target_model.name
if not model_name:
if provider.default_model_name:
model_name = provider.default_model_name
else:
fallback_model = next(iter(provider.models), None)
if fallback_model:
model_name = fallback_model.name
if not model_name:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="未能确定调用模型,请在请求中指定 model 或先配置默认模型。",
)
return model_name, target_model
def _resolve_provider_defaults_for_create(
data: dict[str, Any],
) -> tuple[dict[str, Any], str | None]:
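    # Merge the request payload with registry defaults; a base URL is mandatory,
    # and custom providers must always supply their own.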
normalized_key = _normalize_key(data.get("provider_key"))
normalized_name = _normalize_key(data.get("provider_name"))
provider_key = normalized_key or normalized_name
defaults = get_provider_defaults(provider_key)
resolved_base_url = _normalize_base_url(
data.get("base_url") or (defaults.base_url if defaults else None)
)
resolved_is_custom = data.get("is_custom")
if resolved_is_custom is None:
resolved_is_custom = defaults is None
if resolved_is_custom and not resolved_base_url:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="自定义提供者必须配置基础 URL。",
)
if not resolved_is_custom and not resolved_base_url:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="该提供者需要配置基础 URL。",
)
provider_name = data.get("provider_name") or (defaults.name if defaults else None)
if not provider_name:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="请提供提供者名称。",
)
data.update(
{
"provider_key": defaults.key if defaults else provider_key,
"provider_name": provider_name,
"base_url": resolved_base_url,
"logo_emoji": data.get("logo_emoji")
or (defaults.logo_emoji if defaults else None),
"logo_url": data.get("logo_url")
or (defaults.logo_url if defaults else None),
"is_custom": resolved_is_custom,
}
)
return data, provider_key
@router.get("/common", response_model=list[KnownLLMProvider])
def list_common_providers() -> list[KnownLLMProvider]:
"""返回预置的常用提供方信息。"""
items = [
KnownLLMProvider(
key=provider.key,
name=provider.name,
description=provider.description,
base_url=provider.base_url,
logo_emoji=provider.logo_emoji,
logo_url=provider.logo_url,
)
for provider in iter_common_providers()
]
return items
@router.get("/quick-test/history", response_model=list[LLMUsageLogRead])
def list_quick_test_history(
*,
db: Session = Depends(get_db),
limit: int = Query(20, ge=1, le=100, description="返回的历史记录数量"),
offset: int = Query(0, ge=0, description="跳过的历史记录数量"),
) -> list[LLMUsageLogRead]:
"""返回快速测试产生的最近调用记录。"""
logs = list_quick_test_usage_logs(db, limit=limit, offset=offset)
history: list[LLMUsageLogRead] = []
for log in logs:
provider = log.provider
message_items: list[LLMUsageMessage] = []
if isinstance(log.messages, list):
for item in log.messages:
if not isinstance(item, dict):
continue
try:
message_items.append(LLMUsageMessage.model_validate(item))
except ValidationError:
role = str(item.get("role", "user"))
message_items.append(
LLMUsageMessage(role=role, content=item.get("content"))
)
history.append(
LLMUsageLogRead(
id=log.id,
provider_id=log.provider_id,
provider_name=provider.provider_name if provider else None,
provider_logo_emoji=provider.logo_emoji if provider else None,
provider_logo_url=provider.logo_url if provider else None,
model_id=log.model_id,
model_name=log.model_name,
response_text=log.response_text,
messages=message_items,
temperature=log.temperature,
latency_ms=log.latency_ms,
prompt_tokens=log.prompt_tokens,
completion_tokens=log.completion_tokens,
total_tokens=log.total_tokens,
prompt_id=log.prompt_id,
prompt_version_id=log.prompt_version_id,
created_at=log.created_at,
)
)
return history
@router.get("/", response_model=list[LLMProviderRead])
def list_llm_providers(
*,
db: Session = Depends(get_db),
provider_name: str | None = Query(default=None, alias="provider"),
limit: int = Query(default=50, ge=1, le=200),
offset: int = Query(default=0, ge=0),
) -> list[LLMProviderRead]:
"""返回 LLM 提供者列表,默认排除已归档记录。"""
logger.info(
"查询 LLM 提供者列表: provider=%s limit=%s offset=%s",
provider_name,
limit,
offset,
)
stmt = (
select(LLMProvider)
.where(LLMProvider.is_archived.is_(False))
.order_by(LLMProvider.updated_at.desc())
.options(selectinload(LLMProvider.models))
.offset(offset)
.limit(limit)
)
if provider_name:
stmt = stmt.where(LLMProvider.provider_name.ilike(f"%{provider_name}%"))
providers = list(db.scalars(stmt))
return [_serialize_provider(provider) for provider in providers]
@router.post("/", response_model=LLMProviderRead, status_code=status.HTTP_201_CREATED)
def create_llm_provider(
*,
db: Session = Depends(get_db),
payload: LLMProviderCreate,
) -> LLMProviderRead:
"""创建新的 LLM 提供者卡片,初始模型列表为空。"""
data = payload.model_dump()
data, provider_key = _resolve_provider_defaults_for_create(data)
duplicate_stmt = select(LLMProvider).where(
LLMProvider.provider_name == data["provider_name"],
LLMProvider.base_url == data["base_url"],
LLMProvider.is_archived.is_(False),
)
if db.scalar(duplicate_stmt):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="已存在相同提供方与基础地址的配置。",
)
provider = LLMProvider(**data)
db.add(provider)
db.commit()
db.refresh(provider)
logger.info(
"创建 LLM 提供者成功: id=%s provider=%s key=%s",
provider.id,
provider.provider_name,
provider_key,
)
return _serialize_provider(provider)
@router.get("/{provider_id}", response_model=LLMProviderRead)
def get_llm_provider(
*, db: Session = Depends(get_db), provider_id: int
) -> LLMProviderRead:
"""获取单个 LLM 提供者详情。"""
provider = _get_provider_or_404(db, provider_id, include_archived=True)
return _serialize_provider(provider)
@router.patch("/{provider_id}", response_model=LLMProviderRead)
def update_llm_provider(
*,
db: Session = Depends(get_db),
provider_id: int,
payload: LLMProviderUpdate,
) -> LLMProviderRead:
"""更新已有的 LLM 提供者配置。"""
provider = _get_provider_or_404(db, provider_id)
update_data = payload.model_dump(exclude_unset=True)
if "base_url" in update_data:
update_data["base_url"] = _normalize_base_url(update_data["base_url"])
if update_data["base_url"] is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="该提供者需要配置基础 URL。",
)
if "api_key" in update_data and update_data["api_key"]:
logger.info("更新 LLM 提供者密钥: provider_id=%s", provider.id)
for key, value in update_data.items():
setattr(provider, key, value)
defaults = get_provider_defaults(provider.provider_key)
if not provider.is_custom and not provider.base_url and defaults:
provider.base_url = defaults.base_url
if provider.is_custom and not provider.base_url:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="自定义提供者必须配置基础 URL。",
)
if not provider.is_custom and not provider.base_url:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="该提供者需要配置基础 URL。",
)
db.commit()
db.refresh(provider)
logger.info("更新 LLM 提供者成功: id=%s", provider.id)
return _serialize_provider(provider)
@router.post(
"/{provider_id}/models",
response_model=LLMModelRead,
status_code=status.HTTP_201_CREATED,
)
def create_llm_model(
*,
db: Session = Depends(get_db),
provider_id: int,
payload: LLMModelCreate,
) -> LLMModelRead:
"""为指定提供者新增模型。"""
provider = _get_provider_or_404(db, provider_id)
data = payload.model_dump()
model = LLMModel(provider_id=provider.id, **data)
db.add(model)
try:
db.commit()
except IntegrityError:
db.rollback()
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="该模型名称已存在,请勿重复添加。",
) from None
db.refresh(model)
logger.info("新增模型成功: provider_id=%s model=%s", provider.id, model.name)
return LLMModelRead.model_validate(model, from_attributes=True)
@router.patch(
"/{provider_id}/models/{model_id}",
response_model=LLMModelRead,
)
def update_llm_model(
*,
db: Session = Depends(get_db),
provider_id: int,
model_id: int,
payload: LLMModelUpdate,
) -> LLMModelRead:
"""更新模型属性,如并发配置。"""
_ = _get_provider_or_404(db, provider_id)
stmt = select(LLMModel).where(
LLMModel.id == model_id, LLMModel.provider_id == provider_id
)
model = db.scalar(stmt)
if not model:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="未找到指定的模型",
)
update_data = payload.model_dump(exclude_unset=True)
concurrency = update_data.get("concurrency_limit")
if concurrency is not None:
model.concurrency_limit = concurrency
if "capability" in update_data:
model.capability = update_data["capability"]
if "quota" in update_data:
model.quota = update_data["quota"]
db.commit()
db.refresh(model)
logger.info(
"更新模型成功: provider_id=%s model_id=%s concurrency=%s",
provider_id,
model_id,
model.concurrency_limit,
)
return LLMModelRead.model_validate(model, from_attributes=True)
@router.delete(
"/{provider_id}/models/{model_id}",
status_code=status.HTTP_204_NO_CONTENT,
response_class=Response,
)
def delete_llm_model(
*,
db: Session = Depends(get_db),
provider_id: int,
model_id: int,
) -> Response:
"""删除指定模型,若无剩余模型则自动归档提供者。"""
provider = _get_provider_or_404(db, provider_id, include_archived=True)
model_stmt = select(LLMModel).where(
LLMModel.id == model_id, LLMModel.provider_id == provider.id
)
model = db.scalar(model_stmt)
if not model:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="未找到指定的模型",
)
db.delete(model)
db.flush()
remaining = db.scalar(
select(func.count(LLMModel.id)).where(LLMModel.provider_id == provider.id)
)
if remaining == 0:
provider.is_archived = True
db.commit()
logger.info(
"删除模型成功: provider_id=%s model_id=%s remaining=%s",
provider.id,
model_id,
remaining,
)
return Response(status_code=status.HTTP_204_NO_CONTENT)
@router.delete(
"/{provider_id}",
status_code=status.HTTP_204_NO_CONTENT,
response_class=Response,
)
def delete_llm_provider(
*,
db: Session = Depends(get_db),
provider_id: int,
) -> Response:
"""删除整个提供方配置,级联清理其下所有模型。"""
provider = _get_provider_or_404(db, provider_id, include_archived=True)
db.delete(provider)
db.commit()
logger.info("删除 LLM 提供者成功: id=%s", provider_id)
return Response(status_code=status.HTTP_204_NO_CONTENT)
@router.post("/{provider_id}/invoke")
def invoke_llm(
*,
db: Session = Depends(get_db),
provider_id: int,
payload: LLMInvocationRequest,
) -> dict[str, Any]:
"""使用兼容 OpenAI Chat Completion 的方式调用目标 LLM。"""
provider = _get_provider_or_404(db, provider_id)
model_name, target_model = _determine_model_for_invocation(db, provider, payload)
base_url = _resolve_base_url_or_400(provider)
request_payload: dict[str, Any] = dict(payload.parameters)
if payload.temperature is not None:
request_payload.setdefault("temperature", payload.temperature)
request_payload["model"] = model_name
request_messages = [message.model_dump() for message in payload.messages]
request_payload["messages"] = request_messages
headers = {
"Authorization": f"Bearer {provider.api_key}",
"Content-Type": "application/json",
}
url = f"{base_url}/chat/completions"
logger.info("调用外部 LLM 接口: provider_id=%s url=%s", provider.id, url)
logger.debug("LLM 请求参数: %s", request_payload)
timeout_config = get_testing_timeout_config(db)
invoke_timeout = float(timeout_config.quick_test_timeout or DEFAULT_INVOKE_TIMEOUT)
start_time = time.perf_counter()
try:
response = httpx.post(
url,
headers=headers,
json=request_payload,
timeout=invoke_timeout,
)
except httpx.HTTPError as exc:
logger.error(
"调用外部 LLM 接口出现网络异常: provider_id=%s 错误=%s",
provider.id,
exc,
)
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)
) from exc
if response.status_code >= 400:
try:
error_payload = response.json()
except ValueError:
error_payload = {"message": response.text}
logger.error(
"外部 LLM 接口返回错误: provider_id=%s 状态码=%s 响应=%s",
provider.id,
response.status_code,
error_payload,
)
raise HTTPException(status_code=response.status_code, detail=error_payload)
elapsed = getattr(response, "elapsed", None)
if elapsed is not None:
latency_ms = int(elapsed.total_seconds() * 1000)
else:
latency_ms = int((time.perf_counter() - start_time) * 1000)
    # latency_ms is always non-negative at this point, so a single log statement suffices.
    logger.info(
        "外部 LLM 接口调用成功: provider_id=%s 耗时 %.2fms",
        provider.id,
        latency_ms,
    )
try:
response_payload = response.json()
except ValueError as exc:
logger.error(
"LLM 响应解析失败: provider_id=%s model=%s", provider.id, model_name
)
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail="LLM 响应解析失败。"
) from exc
if payload.persist_usage:
usage_raw = response_payload.get("usage")
usage_obj: Mapping[str, Any] | None = (
usage_raw if isinstance(usage_raw, Mapping) else None
)
prompt_value = usage_obj.get("prompt_tokens") if usage_obj else None
prompt_tokens = (
int(prompt_value) if isinstance(prompt_value, (int, float)) else None
)
completion_value = usage_obj.get("completion_tokens") if usage_obj else None
completion_tokens = (
int(completion_value)
if isinstance(completion_value, (int, float))
else None
)
total_value = usage_obj.get("total_tokens") if usage_obj else None
total_tokens = (
int(total_value) if isinstance(total_value, (int, float)) else None
)
if total_tokens is None and any(
value is not None for value in (prompt_tokens, completion_tokens)
):
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
choices = response_payload.get("choices")
generated_chunks: list[str] = []
if isinstance(choices, Sequence):
for choice in choices:
if not isinstance(choice, Mapping):
continue
message_obj = choice.get("message")
if isinstance(message_obj, Mapping) and isinstance(
message_obj.get("content"), str
):
generated_chunks.append(message_obj["content"])
continue
text_content = choice.get("text")
if isinstance(text_content, str):
generated_chunks.append(text_content)
response_text = "".join(generated_chunks)
original_parameters = dict(payload.parameters)
if payload.temperature is not None:
original_parameters.setdefault("temperature", payload.temperature)
log_entry = LLMUsageLog(
provider_id=provider.id,
model_id=target_model.id if target_model else None,
model_name=model_name,
source="quick_test",
prompt_id=payload.prompt_id,
prompt_version_id=payload.prompt_version_id,
messages=request_messages,
parameters=original_parameters or None,
response_text=response_text or None,
temperature=payload.temperature,
latency_ms=max(latency_ms, 0),
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
try:
db.add(log_entry)
db.commit()
logger.info(
"非流式调用已记录用量: provider_id=%s model=%s tokens=%s",
provider.id,
model_name,
{
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
},
)
        except Exception:  # pragma: no cover - defensive rollback
db.rollback()
logger.exception(
"保存 LLM 调用日志失败: provider_id=%s model=%s",
provider.id,
model_name,
)
return response_payload
@router.post(
"/{provider_id}/invoke/stream",
response_class=StreamingResponse,
)
async def stream_invoke_llm(
*,
db: Session = Depends(get_db),
provider_id: int,
payload: LLMStreamInvocationRequest,
) -> StreamingResponse:
"""以流式方式调用目标 LLM,并转发 OpenAI 兼容的事件流。"""
provider = _get_provider_or_404(db, provider_id)
model_name, target_model = _determine_model_for_invocation(db, provider, payload)
base_url = _resolve_base_url_or_400(provider)
request_payload: dict[str, Any] = dict(payload.parameters)
request_payload.pop("stream", None)
request_payload["temperature"] = payload.temperature
request_payload["model"] = model_name
request_payload["messages"] = [message.model_dump() for message in payload.messages]
request_payload["stream"] = True
stream_options = request_payload.get("stream_options")
if isinstance(stream_options, dict):
stream_options.setdefault("include_usage", True)
else:
request_payload["stream_options"] = {"include_usage": True}
headers = {
"Authorization": f"Bearer {provider.api_key}",
"Content-Type": "application/json",
}
url = f"{base_url}/chat/completions"
logger.info(
"启动流式 LLM 调用: provider_id=%s model=%s url=%s",
provider.id,
model_name,
url,
)
logger.debug("LLM 流式请求参数: %s", request_payload)
start_time = time.perf_counter()
usage_summary: dict[str, int | None] | None = None
generated_chunks: list[str] = []
should_persist = True
request_messages = [message.model_dump() for message in payload.messages]
original_parameters = dict(payload.parameters)
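    # Parse one buffered SSE event ("data:" lines) and re-emit it as compact JSON strings,
    # splitting each choice into per-character chunks and capturing any usage summary.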
def _process_event(lines: list[str]) -> list[str]:
nonlocal usage_summary
if not lines:
return []
data_segments: list[str] = []
for item in lines:
if item.startswith(":"):
continue
if item.startswith("data:"):
data_segments.append(item[5:].lstrip())
if not data_segments:
return []
data_str = "\n".join(data_segments).strip()
if not data_str:
return []
if data_str == "[DONE]":
return ["[DONE]"]
snippet = data_str if len(data_str) <= 200 else f"{data_str[:200]}…"
logger.info(
"接收到流式事件: provider_id=%s model=%s data=%s",
provider.id,
model_name,
snippet,
)
try:
payload_obj = json.loads(data_str)
except json.JSONDecodeError:
logger.debug("忽略无法解析的流式分片: %s", data_str)
return []
        usage_payload = payload_obj.get("usage")
        if isinstance(usage_payload, dict):
            usage_summary = {
                "prompt_tokens": usage_payload.get("prompt_tokens"),
                "completion_tokens": usage_payload.get("completion_tokens"),
                "total_tokens": usage_payload.get("total_tokens"),
            }
base_payload = {
key: value
for key, value in payload_obj.items()
if key not in ("choices", "usage")
}
def _split_choice(choice: Mapping[str, Any]) -> list[dict[str, Any]]:
"""将单个 choice 拆分为逐字符的子分片,确保前端逐字渲染。"""
pieces: list[dict[str, Any]] = []
common_fields = {
key: value
for key, value in choice.items()
if key not in ("delta", "message", "text")
}
delta_obj = choice.get("delta")
if isinstance(delta_obj, dict):
extra_delta = {k: v for k, v in delta_obj.items() if k != "content"}
content = delta_obj.get("content")
if isinstance(content, str) and content:
generated_chunks.append(content)
for symbol in content:
new_choice = dict(common_fields)
new_delta = dict(extra_delta)
new_delta["content"] = symbol
new_choice["delta"] = new_delta
pieces.append(new_choice)
return pieces
new_choice = dict(common_fields)
new_choice["delta"] = dict(delta_obj)
pieces.append(new_choice)
return pieces
message_obj = choice.get("message")
if isinstance(message_obj, dict):
extra_message = {k: v for k, v in message_obj.items() if k != "content"}
content = message_obj.get("content")
if isinstance(content, str) and content:
generated_chunks.append(content)
for symbol in content:
new_choice = dict(common_fields)
new_message = dict(extra_message)
new_message["content"] = symbol
new_choice["message"] = new_message
pieces.append(new_choice)
return pieces
new_choice = dict(common_fields)
new_choice["message"] = dict(message_obj)
pieces.append(new_choice)
return pieces
text_value = choice.get("text")
if isinstance(text_value, str) and text_value:
generated_chunks.append(text_value)
for symbol in text_value:
new_choice = dict(common_fields)
new_choice["text"] = symbol
pieces.append(new_choice)
return pieces
pieces.append(dict(choice))
return pieces
event_payloads: list[dict[str, Any]] = []
choices = payload_obj.get("choices")
if isinstance(choices, list):
for choice in choices:
if not isinstance(choice, Mapping):
continue
for piece in _split_choice(choice):
event_payloads.append({**base_payload, "choices": [piece]})
if not event_payloads:
payload_copy = dict(payload_obj)
return [json.dumps(payload_copy, ensure_ascii=False, separators=(",", ":"))]
if isinstance(usage_payload, dict):
event_payloads[-1]["usage"] = usage_payload
return [
json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
for payload in event_payloads
]
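    # Persist a quick-test usage log once the stream ends, unless an error disabled persistence.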
def _persist_usage() -> None:
if not should_persist:
return
summary = usage_summary or {}
response_text = "".join(generated_chunks)
if not response_text and not summary:
return
latency_ms = int((time.perf_counter() - start_time) * 1000)
log_entry = LLMUsageLog(
provider_id=provider.id,
model_id=target_model.id if target_model else None,
model_name=model_name,
source="quick_test",
prompt_id=payload.prompt_id,
prompt_version_id=payload.prompt_version_id,
messages=request_messages,
parameters=original_parameters or None,
response_text=response_text or None,
temperature=payload.temperature,
latency_ms=latency_ms,
prompt_tokens=summary.get("prompt_tokens"),
completion_tokens=summary.get("completion_tokens"),
total_tokens=summary.get("total_tokens"),
)
try:
db.add(log_entry)
db.commit()
logger.info(
"流式调用完成: provider_id=%s model=%s tokens=%s",
provider.id,
model_name,
summary,
)
        except Exception:  # pragma: no cover - defensive rollback
db.rollback()
logger.exception(
"保存 LLM 调用日志失败: provider_id=%s model=%s",
provider.id,
model_name,
)
timeout_config = get_testing_timeout_config(db)
invoke_timeout = float(timeout_config.quick_test_timeout or DEFAULT_INVOKE_TIMEOUT)
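    # Proxy the upstream SSE stream, re-chunking events for the client and persisting usage in `finally`.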
async def _event_stream() -> AsyncIterator[bytes]:
nonlocal should_persist
event_lines: list[str] = []
async with httpx.AsyncClient(timeout=invoke_timeout) as async_client:
try:
async with async_client.stream(
"POST",
url,
headers=headers,
json=request_payload,
) as response:
if response.status_code >= 400:
should_persist = False
error_body = await response.aread()
decoded = error_body.decode("utf-8", errors="ignore")
try:
error_payload = json.loads(decoded)
except ValueError:
error_payload = {"message": decoded}
logger.error(
"流式调用返回错误: provider_id=%s 状态码=%s 响应=%s",
provider.id,
response.status_code,
error_payload,
)
raise HTTPException(
status_code=response.status_code, detail=error_payload
)
async for line in response.aiter_lines():
if line is None:
continue
if line == "":
for payload in _process_event(event_lines):
if payload == "[DONE]":
yield b"data: [DONE]\n\n"
else:
yield f"data: {payload}\n\n".encode("utf-8")
event_lines = []
continue
event_lines.append(line)
if event_lines:
for payload in _process_event(event_lines):
if payload == "[DONE]":
yield b"data: [DONE]\n\n"
else:
yield f"data: {payload}\n\n".encode("utf-8")
except httpx.HTTPError as exc:
should_persist = False
logger.error(
"流式调用外部 LLM 出现异常: provider_id=%s 错误=%s",
provider.id,
exc,
)
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)
) from exc
finally:
await run_in_threadpool(_persist_usage)
headers_extra = {
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
"Connection": "keep-alive",
}
return StreamingResponse(
_event_stream(), media_type="text/event-stream", headers=headers_extra
)
```
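For orientation, a minimal client-side sketch of consuming the streaming endpoint above. The host, port, and the `/api/v1/llms` route prefix are assumptions about how this router is mounted, and `stream_quick_test` is a hypothetical helper; the body fields follow `LLMStreamInvocationRequest`.

```py
import json

import httpx


def stream_quick_test(provider_id: int, prompt: str) -> str:
    """Collect the streamed completion for a single user message (illustrative only)."""
    url = f"http://localhost:8000/api/v1/llms/{provider_id}/invoke/stream"  # assumed mount point
    body = {
        "messages": [{"role": "user", "content": prompt}],
        "parameters": {},
        "temperature": 0.7,
    }
    chunks: list[str] = []
    with httpx.stream("POST", url, json=body, timeout=60.0) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if not line.startswith("data:"):
                continue
            data = line[5:].strip()
            if data == "[DONE]":
                break
            event = json.loads(data)
            for choice in event.get("choices", []):
                content = (choice.get("delta") or {}).get("content")
                if isinstance(content, str):
                    chunks.append(content)
    return "".join(chunks)
```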
## /app/api/v1/endpoints/prompt_classes.py
```py path="/app/api/v1/endpoints/prompt_classes.py"
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.db.session import get_db
from app.models.prompt import Prompt, PromptClass
from app.schemas import (
PromptClassCreate,
PromptClassRead,
PromptClassStats,
PromptClassUpdate,
)
router = APIRouter()
@router.get("/", response_model=list[PromptClassStats])
def list_prompt_classes(
*,
db: Session = Depends(get_db),
q: str | None = Query(default=None, description="按名称模糊搜索分类"),
limit: int = Query(default=50, ge=1, le=200),
offset: int = Query(default=0, ge=0),
) -> list[PromptClassStats]:
"""按名称排序列出 Prompt 分类,并附带使用统计信息"""
prompt_count = func.count(Prompt.id)
latest_updated = func.max(Prompt.updated_at)
stmt = (
select(
PromptClass,
prompt_count.label("prompt_count"),
latest_updated.label("latest_prompt_updated_at"),
)
.outerjoin(Prompt, Prompt.class_id == PromptClass.id)
.group_by(PromptClass.id)
.order_by(PromptClass.name.asc())
)
if q:
term = q.strip()
if term:
stmt = stmt.where(PromptClass.name.ilike(f"%{term}%"))
stmt = stmt.offset(offset).limit(limit)
rows = db.execute(stmt).all()
return [
PromptClassStats(
id=row.PromptClass.id,
name=row.PromptClass.name,
description=row.PromptClass.description,
created_at=row.PromptClass.created_at,
updated_at=row.PromptClass.updated_at,
prompt_count=row.prompt_count or 0,
latest_prompt_updated_at=row.latest_prompt_updated_at,
)
for row in rows
]
@router.post("/", response_model=PromptClassRead, status_code=status.HTTP_201_CREATED)
def create_prompt_class(
*, db: Session = Depends(get_db), payload: PromptClassCreate
) -> PromptClass:
"""创建新的 Prompt 分类"""
prompt_class = PromptClass(name=payload.name, description=payload.description)
db.add(prompt_class)
try:
db.commit()
except IntegrityError as exc:
db.rollback()
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="同名分类已存在"
) from exc
db.refresh(prompt_class)
return prompt_class
@router.patch("/{class_id}", response_model=PromptClassRead)
def update_prompt_class(
*, db: Session = Depends(get_db), class_id: int, payload: PromptClassUpdate
) -> PromptClass:
"""更新指定 Prompt 分类"""
prompt_class = db.get(PromptClass, class_id)
if not prompt_class:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="分类不存在")
update_data = payload.model_dump(exclude_unset=True)
if not update_data:
return prompt_class
name = update_data.get("name")
if name is not None:
prompt_class.name = name.strip()
if "description" in update_data:
prompt_class.description = update_data["description"]
try:
db.commit()
except IntegrityError as exc:
db.rollback()
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="同名分类已存在"
) from exc
db.refresh(prompt_class)
return prompt_class
@router.delete(
"/{class_id}", status_code=status.HTTP_204_NO_CONTENT, response_class=Response
)
def delete_prompt_class(*, db: Session = Depends(get_db), class_id: int) -> Response:
"""删除指定 Prompt 分类,同时清理其下所有 Prompt。"""
prompt_class = db.get(PromptClass, class_id)
if not prompt_class:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="分类不存在")
db.delete(prompt_class)
db.commit()
return Response(status_code=status.HTTP_204_NO_CONTENT)
```
## /app/api/v1/endpoints/prompt_tags.py
```py path="/app/api/v1/endpoints/prompt_tags.py"
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, Response, status
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.db.session import get_db
from app.models.prompt import PromptTag, prompt_tag_association
from app.schemas import (
PromptTagCreate,
PromptTagListResponse,
PromptTagRead,
PromptTagStats,
)
router = APIRouter()
@router.get("/", response_model=PromptTagListResponse)
def list_prompt_tags(*, db: Session = Depends(get_db)) -> PromptTagListResponse:
"""按名称排序返回全部 Prompt 标签及其引用统计。"""
prompt_count = func.count(prompt_tag_association.c.prompt_id)
stmt = (
select(PromptTag, prompt_count.label("prompt_count"))
.select_from(PromptTag)
.outerjoin(
prompt_tag_association,
PromptTag.id == prompt_tag_association.c.tag_id,
)
.group_by(PromptTag.id)
.order_by(PromptTag.name.asc())
)
rows = db.execute(stmt).all()
items = [
PromptTagStats(
id=row.PromptTag.id,
name=row.PromptTag.name,
color=row.PromptTag.color,
created_at=row.PromptTag.created_at,
updated_at=row.PromptTag.updated_at,
prompt_count=row.prompt_count or 0,
)
for row in rows
]
tagged_prompt_total = db.scalar(
select(func.count(func.distinct(prompt_tag_association.c.prompt_id)))
)
return PromptTagListResponse(
items=items,
tagged_prompt_total=tagged_prompt_total or 0,
)
@router.post("/", response_model=PromptTagRead, status_code=status.HTTP_201_CREATED)
def create_prompt_tag(
*, db: Session = Depends(get_db), payload: PromptTagCreate
) -> PromptTag:
"""创建新的 Prompt 标签。"""
prompt_tag = PromptTag(name=payload.name, color=payload.color)
db.add(prompt_tag)
try:
db.commit()
except IntegrityError as exc:
db.rollback()
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="同名标签已存在"
) from exc
db.refresh(prompt_tag)
return prompt_tag
@router.delete(
"/{tag_id}", status_code=status.HTTP_204_NO_CONTENT, response_class=Response
)
def delete_prompt_tag(*, db: Session = Depends(get_db), tag_id: int) -> Response:
"""删除指定 Prompt 标签,若仍有关联则阻止删除。"""
prompt_tag = db.get(PromptTag, tag_id)
if not prompt_tag:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="标签不存在")
associated = db.scalar(
select(func.count())
.select_from(prompt_tag_association)
.where(prompt_tag_association.c.tag_id == tag_id)
)
if associated and associated > 0:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="仍有 Prompt 使用该标签,请先迁移或删除相关 Prompt",
)
db.delete(prompt_tag)
db.commit()
return Response(status_code=status.HTTP_204_NO_CONTENT)
```
## /app/api/v1/endpoints/prompt_test_tasks.py
```py path="/app/api/v1/endpoints/prompt_test_tasks.py"
from __future__ import annotations
from datetime import UTC, datetime
from typing import Sequence
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from sqlalchemy import func, select
from sqlalchemy.orm import Session, selectinload
from app.core.prompt_test_task_queue import enqueue_prompt_test_task
from app.db.session import get_db
from app.models.prompt_test import (
PromptTestExperiment,
PromptTestExperimentStatus,
PromptTestTask,
PromptTestTaskStatus,
PromptTestUnit,
)
from app.schemas.prompt_test import (
PromptTestExperimentCreate,
PromptTestExperimentRead,
PromptTestTaskCreate,
PromptTestTaskRead,
PromptTestTaskUpdate,
PromptTestUnitCreate,
PromptTestUnitRead,
PromptTestUnitUpdate,
)
from app.services.prompt_test_engine import (
PromptTestExecutionError,
execute_prompt_test_experiment,
)
router = APIRouter(prefix="/prompt-test", tags=["prompt-test"])
def _get_task_or_404(db: Session, task_id: int) -> PromptTestTask:
task = db.get(PromptTestTask, task_id)
if not task or task.is_deleted:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="测试任务不存在"
)
return task
def _get_unit_or_404(db: Session, unit_id: int) -> PromptTestUnit:
stmt = (
select(PromptTestUnit)
.where(PromptTestUnit.id == unit_id)
.options(selectinload(PromptTestUnit.task))
)
unit = db.execute(stmt).scalar_one_or_none()
if not unit or (unit.task and unit.task.is_deleted):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="测试单元不存在"
)
return unit
@router.get("/tasks", response_model=list[PromptTestTaskRead])
def list_prompt_test_tasks(
*,
db: Session = Depends(get_db),
status_filter: PromptTestTaskStatus | None = Query(default=None, alias="status"),
) -> Sequence[PromptTestTask]:
"""按状态筛选测试任务列表。"""
stmt = (
select(PromptTestTask)
.options(selectinload(PromptTestTask.units))
.where(PromptTestTask.is_deleted.is_(False))
.order_by(PromptTestTask.created_at.desc())
)
if status_filter:
stmt = stmt.where(PromptTestTask.status == status_filter)
return list(db.scalars(stmt))
@router.post(
"/tasks", response_model=PromptTestTaskRead, status_code=status.HTTP_201_CREATED
)
def create_prompt_test_task(
*, db: Session = Depends(get_db), payload: PromptTestTaskCreate
) -> PromptTestTask:
"""创建新的测试任务,可同时定义最小测试单元。"""
task_data = payload.model_dump(exclude={"units", "auto_execute"})
task = PromptTestTask(**task_data)
db.add(task)
db.flush()
units_payload = payload.units or []
for unit_payload in units_payload:
unit_data = unit_payload.model_dump(exclude_none=True)
unit_data["task_id"] = task.id
unit = PromptTestUnit(**unit_data)
db.add(unit)
if payload.auto_execute:
task.status = PromptTestTaskStatus.READY
db.commit()
db.refresh(task)
if payload.auto_execute:
enqueue_prompt_test_task(task.id)
return task
@router.get("/tasks/{task_id}", response_model=PromptTestTaskRead)
def get_prompt_test_task(
*, db: Session = Depends(get_db), task_id: int
) -> PromptTestTask:
"""获取单个测试任务详情。"""
task = _get_task_or_404(db, task_id)
return task
@router.patch("/tasks/{task_id}", response_model=PromptTestTaskRead)
def update_prompt_test_task(
*,
db: Session = Depends(get_db),
task_id: int,
payload: PromptTestTaskUpdate,
) -> PromptTestTask:
"""更新测试任务的基础信息或状态。"""
task = _get_task_or_404(db, task_id)
update_data = payload.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(task, key, value)
db.commit()
db.refresh(task)
return task
@router.delete(
"/tasks/{task_id}",
status_code=status.HTTP_204_NO_CONTENT,
response_class=Response,
)
def delete_prompt_test_task(*, db: Session = Depends(get_db), task_id: int) -> Response:
"""将测试任务标记为删除,但保留历史数据。"""
task = _get_task_or_404(db, task_id)
if not task.is_deleted:
task.is_deleted = True
db.commit()
return Response(status_code=status.HTTP_204_NO_CONTENT)
@router.get("/tasks/{task_id}/units", response_model=list[PromptTestUnitRead])
def list_units_for_task(
*, db: Session = Depends(get_db), task_id: int
) -> Sequence[PromptTestUnit]:
"""列出指定测试任务下的全部最小测试单元。"""
_get_task_or_404(db, task_id)
stmt = (
select(PromptTestUnit)
.where(PromptTestUnit.task_id == task_id)
.order_by(PromptTestUnit.created_at.asc())
)
return list(db.scalars(stmt))
@router.post(
"/tasks/{task_id}/units",
response_model=PromptTestUnitRead,
status_code=status.HTTP_201_CREATED,
)
def create_unit_for_task(
*,
db: Session = Depends(get_db),
task_id: int,
payload: PromptTestUnitCreate,
) -> PromptTestUnit:
"""为指定测试任务新增最小测试单元。"""
_get_task_or_404(db, task_id)
unit_data = payload.model_dump(exclude_none=True)
unit_data["task_id"] = task_id
unit = PromptTestUnit(**unit_data)
db.add(unit)
db.commit()
db.refresh(unit)
return unit
@router.get("/units/{unit_id}", response_model=PromptTestUnitRead)
def get_prompt_test_unit(
*, db: Session = Depends(get_db), unit_id: int
) -> PromptTestUnit:
"""获取最小测试单元详情。"""
unit = _get_unit_or_404(db, unit_id)
return unit
@router.patch("/units/{unit_id}", response_model=PromptTestUnitRead)
def update_prompt_test_unit(
*,
db: Session = Depends(get_db),
unit_id: int,
payload: PromptTestUnitUpdate,
) -> PromptTestUnit:
"""更新最小测试单元配置。"""
unit = _get_unit_or_404(db, unit_id)
update_data = payload.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(unit, key, value)
db.commit()
db.refresh(unit)
return unit
@router.get(
"/units/{unit_id}/experiments", response_model=list[PromptTestExperimentRead]
)
def list_experiments_for_unit(
*, db: Session = Depends(get_db), unit_id: int
) -> Sequence[PromptTestExperiment]:
"""列出指定测试单元下的实验记录。"""
unit = _get_unit_or_404(db, unit_id)
stmt = (
select(PromptTestExperiment)
.where(PromptTestExperiment.unit_id == unit.id)
.order_by(PromptTestExperiment.created_at.desc())
)
return list(db.scalars(stmt))
@router.post(
"/units/{unit_id}/experiments",
response_model=PromptTestExperimentRead,
status_code=status.HTTP_201_CREATED,
)
def create_experiment_for_unit(
*,
db: Session = Depends(get_db),
unit_id: int,
payload: PromptTestExperimentCreate,
) -> PromptTestExperiment:
"""为指定测试单元创建实验,可选择立即执行。"""
unit = _get_unit_or_404(db, unit_id)
sequence = payload.sequence
if sequence is None:
sequence = (
db.scalar(
select(func.max(PromptTestExperiment.sequence)).where(
PromptTestExperiment.unit_id == unit.id
)
)
or 0
) + 1
experiment_data = payload.model_dump(exclude={"auto_execute"}, exclude_none=True)
experiment_data["unit_id"] = unit.id
experiment_data["sequence"] = sequence
experiment = PromptTestExperiment(**experiment_data)
db.add(experiment)
db.flush()
if payload.auto_execute:
try:
execute_prompt_test_experiment(db, experiment)
except PromptTestExecutionError as exc:
experiment.status = PromptTestExperimentStatus.FAILED
experiment.error = str(exc)
experiment.finished_at = datetime.now(UTC)
db.flush()
db.commit()
db.refresh(experiment)
return experiment
@router.get("/experiments/{experiment_id}", response_model=PromptTestExperimentRead)
def get_prompt_test_experiment(
*, db: Session = Depends(get_db), experiment_id: int
) -> PromptTestExperiment:
"""获取实验结果详情。"""
stmt = (
select(PromptTestExperiment)
.where(PromptTestExperiment.id == experiment_id)
.options(
selectinload(PromptTestExperiment.unit).selectinload(PromptTestUnit.task)
)
)
experiment = db.execute(stmt).scalar_one_or_none()
if not experiment or experiment.unit is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="实验记录不存在"
)
if experiment.unit.task and experiment.unit.task.is_deleted:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="实验记录不存在"
)
return experiment
@router.post(
"/experiments/{experiment_id}/execute",
response_model=PromptTestExperimentRead,
)
def execute_existing_experiment(
*, db: Session = Depends(get_db), experiment_id: int
) -> PromptTestExperiment:
"""重新执行已存在的实验记录。"""
stmt = (
select(PromptTestExperiment)
.where(PromptTestExperiment.id == experiment_id)
.options(
selectinload(PromptTestExperiment.unit).selectinload(PromptTestUnit.task)
)
)
experiment = db.execute(stmt).scalar_one_or_none()
if not experiment or experiment.unit is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="实验记录不存在"
)
if experiment.unit.task and experiment.unit.task.is_deleted:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="实验记录不存在"
)
try:
execute_prompt_test_experiment(db, experiment)
except PromptTestExecutionError as exc:
experiment.status = PromptTestExperimentStatus.FAILED
experiment.error = str(exc)
experiment.finished_at = datetime.now(UTC)
db.commit()
db.refresh(experiment)
return experiment
__all__ = ["router"]
```
## /app/api/v1/endpoints/prompts.py
```py path="/app/api/v1/endpoints/prompts.py"
from __future__ import annotations
from typing import Sequence
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, joinedload, selectinload
from app.db.session import get_db
from app.models.prompt import Prompt, PromptClass, PromptTag, PromptVersion
from app.schemas.prompt import PromptCreate, PromptRead, PromptUpdate
router = APIRouter()
def _prompt_query():
return select(Prompt).options(
joinedload(Prompt.prompt_class),
joinedload(Prompt.current_version),
selectinload(Prompt.versions),
selectinload(Prompt.tags),
)
def _get_prompt_or_404(db: Session, prompt_id: int) -> Prompt:
stmt = _prompt_query().where(Prompt.id == prompt_id)
prompt = db.execute(stmt).unique().scalar_one_or_none()
if not prompt:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Prompt 不存在"
)
return prompt
def _resolve_prompt_class(
db: Session,
*,
class_id: int | None,
class_name: str | None,
class_description: str | None,
) -> PromptClass:
if class_id is not None:
prompt_class = db.get(PromptClass, class_id)
if not prompt_class:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="指定的 Prompt 分类不存在",
)
return prompt_class
assert class_name is not None
trimmed = class_name.strip()
stmt = select(PromptClass).where(PromptClass.name == trimmed)
prompt_class = db.scalar(stmt)
if prompt_class:
if class_description and not prompt_class.description:
prompt_class.description = class_description
return prompt_class
prompt_class = PromptClass(name=trimmed, description=class_description)
db.add(prompt_class)
db.flush()
return prompt_class
def _resolve_prompt_tags(db: Session, tag_ids: list[int]) -> list[PromptTag]:
if not tag_ids:
return []
unique_ids = list(dict.fromkeys(tag_ids))
stmt = select(PromptTag).where(PromptTag.id.in_(unique_ids))
tags = db.execute(stmt).scalars().all()
found_ids = {tag.id for tag in tags}
missing = [tag_id for tag_id in unique_ids if tag_id not in found_ids]
if missing:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"以下标签不存在: {missing}",
)
id_to_tag = {tag.id: tag for tag in tags}
return [id_to_tag[tag_id] for tag_id in unique_ids]
@router.get("/", response_model=list[PromptRead])
def list_prompts(
*,
db: Session = Depends(get_db),
q: str | None = Query(default=None, description="根据名称、作者或分类模糊搜索"),
limit: int = Query(default=50, ge=1, le=200),
offset: int = Query(default=0, ge=0),
) -> Sequence[Prompt]:
"""按更新时间倒序分页列出 Prompt。"""
stmt = (
_prompt_query().order_by(Prompt.updated_at.desc()).offset(offset).limit(limit)
)
if q:
like_term = f"%{q}%"
stmt = stmt.join(Prompt.prompt_class).where(
(Prompt.name.ilike(like_term))
| (Prompt.author.ilike(like_term))
| (PromptClass.name.ilike(like_term))
)
return list(db.execute(stmt).unique().scalars().all())
@router.post("/", response_model=PromptRead, status_code=status.HTTP_201_CREATED)
def create_prompt(*, db: Session = Depends(get_db), payload: PromptCreate) -> Prompt:
"""创建 Prompt 并写入首个版本,缺少分类时自动创建分类。"""
prompt_class = _resolve_prompt_class(
db,
class_id=payload.class_id,
class_name=payload.class_name,
class_description=payload.class_description,
)
stmt = select(Prompt).where(
Prompt.class_id == prompt_class.id, Prompt.name == payload.name
)
prompt = db.scalar(stmt)
created_new_prompt = False
if not prompt:
prompt = Prompt(
name=payload.name,
description=payload.description,
author=payload.author,
prompt_class=prompt_class,
)
db.add(prompt)
db.flush()
created_new_prompt = True
else:
if payload.description is not None:
prompt.description = payload.description
if payload.author is not None:
prompt.author = payload.author
if payload.tag_ids is not None:
prompt.tags = _resolve_prompt_tags(db, payload.tag_ids)
elif created_new_prompt:
prompt.tags = []
existing_version = db.scalar(
select(PromptVersion).where(
PromptVersion.prompt_id == prompt.id,
PromptVersion.version == payload.version,
)
)
if existing_version:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="该 Prompt 已存在同名版本",
)
prompt_version = PromptVersion(
prompt=prompt,
version=payload.version,
content=payload.content,
)
db.add(prompt_version)
db.flush()
prompt.current_version = prompt_version
try:
db.commit()
    except IntegrityError as exc:  # pragma: no cover - rollback on database integrity errors
db.rollback()
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="创建 Prompt 时发生数据冲突"
) from exc
return _get_prompt_or_404(db, prompt.id)
@router.get("/{prompt_id}", response_model=PromptRead)
def get_prompt(*, db: Session = Depends(get_db), prompt_id: int) -> Prompt:
"""根据 ID 获取 Prompt 详情,包含全部版本信息。"""
return _get_prompt_or_404(db, prompt_id)
@router.put("/{prompt_id}", response_model=PromptRead)
def update_prompt(
*, db: Session = Depends(get_db), prompt_id: int, payload: PromptUpdate
) -> Prompt:
"""更新 Prompt 及其元数据,可选择创建新版本或切换当前版本。"""
prompt = _get_prompt_or_404(db, prompt_id)
if payload.class_id is not None or (
payload.class_name and payload.class_name.strip()
):
prompt_class = _resolve_prompt_class(
db,
class_id=payload.class_id,
class_name=payload.class_name,
class_description=payload.class_description,
)
prompt.prompt_class = prompt_class
if payload.name is not None:
prompt.name = payload.name
if payload.description is not None:
prompt.description = payload.description
if payload.author is not None:
prompt.author = payload.author
if payload.tag_ids is not None:
prompt.tags = _resolve_prompt_tags(db, payload.tag_ids)
if payload.version is not None and payload.content is not None:
exists = db.scalar(
select(PromptVersion).where(
PromptVersion.prompt_id == prompt.id,
PromptVersion.version == payload.version,
)
)
if exists:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="同名版本已存在"
)
new_version = PromptVersion(
prompt=prompt,
version=payload.version,
content=payload.content,
)
db.add(new_version)
db.flush()
prompt.current_version = new_version
if payload.activate_version_id is not None:
if payload.version is not None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="activate_version_id 与 version/content 不能同时出现",
)
target_version = db.get(PromptVersion, payload.activate_version_id)
if not target_version or target_version.prompt_id != prompt.id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="目标版本不存在或不属于该 Prompt",
)
prompt.current_version = target_version
try:
db.commit()
    except IntegrityError as exc:  # pragma: no cover - rollback on database integrity errors
db.rollback()
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="更新 Prompt 失败"
) from exc
return _get_prompt_or_404(db, prompt_id)
@router.delete(
"/{prompt_id}", status_code=status.HTTP_204_NO_CONTENT, response_class=Response
)
def delete_prompt(*, db: Session = Depends(get_db), prompt_id: int) -> Response:
"""删除 Prompt 及其全部版本。"""
prompt = db.get(Prompt, prompt_id)
if not prompt:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Prompt 不存在"
)
db.delete(prompt)
db.commit()
return Response(status_code=status.HTTP_204_NO_CONTENT)
```
## /app/api/v1/endpoints/settings.py
```py path="/app/api/v1/endpoints/settings.py"
from __future__ import annotations
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.db.session import get_db
from app.schemas.settings import TestingTimeoutsRead, TestingTimeoutsUpdate
from app.services.system_settings import (
get_testing_timeout_config,
update_testing_timeout_config,
)
router = APIRouter(prefix="/settings", tags=["settings"])
@router.get(
"/testing",
response_model=TestingTimeoutsRead,
summary="获取快速测试与测试任务的超时时间配置",
)
def get_testing_timeouts(*, db: Session = Depends(get_db)) -> TestingTimeoutsRead:
config = get_testing_timeout_config(db)
return TestingTimeoutsRead(
quick_test_timeout=int(config.quick_test_timeout),
test_task_timeout=int(config.test_task_timeout),
updated_at=config.updated_at,
)
@router.put(
"/testing",
response_model=TestingTimeoutsRead,
summary="更新快速测试与测试任务的超时时间配置",
)
def update_testing_timeouts(
*,
db: Session = Depends(get_db),
payload: TestingTimeoutsUpdate,
) -> TestingTimeoutsRead:
config = update_testing_timeout_config(
db,
quick_test_timeout=payload.quick_test_timeout,
test_task_timeout=payload.test_task_timeout,
)
return TestingTimeoutsRead(
quick_test_timeout=int(config.quick_test_timeout),
test_task_timeout=int(config.test_task_timeout),
updated_at=config.updated_at,
)
__all__ = ["router"]
```
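As a usage sketch, updating the testing timeouts could look like the call below; the host and `/api/v1` mount prefix are assumptions, the field names come from the handlers above, and the values are interpreted as seconds by the invocation code.

```py
import httpx

# Example values only; adjust the base URL to your deployment.
resp = httpx.put(
    "http://localhost:8000/api/v1/settings/testing",
    json={"quick_test_timeout": 60, "test_task_timeout": 600},
)
resp.raise_for_status()
print(resp.json())  # {"quick_test_timeout": 60, "test_task_timeout": 600, "updated_at": "..."}
```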
## /app/api/v1/endpoints/test_prompt.py
```py path="/app/api/v1/endpoints/test_prompt.py"
from __future__ import annotations
from typing import Sequence
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload, selectinload
from app.db.session import get_db
from app.models.prompt import Prompt, PromptVersion
from app.models.result import Result
from app.models.test_run import TestRun, TestRunStatus
from app.schemas.result import ResultRead
from app.schemas.test_run import TestRunCreate, TestRunRead, TestRunUpdate
from app.core.task_queue import enqueue_test_run, task_queue
router = APIRouter()
def _test_run_query():
return select(TestRun).options(
joinedload(TestRun.prompt_version)
.joinedload(PromptVersion.prompt)
.joinedload(Prompt.prompt_class),
selectinload(TestRun.results).selectinload(Result.metrics),
)
@router.get("/", response_model=list[TestRunRead])
def list_test_prompts(
*,
db: Session = Depends(get_db),
status_filter: TestRunStatus | None = Query(default=None, alias="status"),
prompt_version_id: int | None = Query(default=None),
limit: int = Query(default=50, ge=1, le=200),
offset: int = Query(default=0, ge=0),
) -> Sequence[TestRun]:
"""按筛选条件列出 Prompt 测试任务。"""
stmt = (
_test_run_query()
.order_by(TestRun.created_at.desc())
.offset(offset)
.limit(limit)
)
if status_filter:
stmt = stmt.where(TestRun.status == status_filter)
if prompt_version_id:
stmt = stmt.where(TestRun.prompt_version_id == prompt_version_id)
return list(db.execute(stmt).unique().scalars().all())
@router.post("/", response_model=TestRunRead, status_code=status.HTTP_201_CREATED)
def create_test_prompt(
*, db: Session = Depends(get_db), payload: TestRunCreate
) -> TestRun:
"""为指定 Prompt 版本创建新的测试任务,并将其入队异步执行。"""
prompt_version = db.get(PromptVersion, payload.prompt_version_id)
if not prompt_version:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Prompt 版本不存在"
)
data = payload.model_dump(by_alias=True, exclude_none=True)
test_run = TestRun(**data)
test_run.prompt_version = prompt_version
db.add(test_run)
db.flush()
db.commit()
stmt = _test_run_query().where(TestRun.id == test_run.id)
created_run = db.execute(stmt).unique().scalar_one()
try:
enqueue_test_run(test_run.id)
        # For test stability, wait briefly after enqueueing so the latest status is immediately visible.
task_queue.wait_for_idle(timeout=0.05)
    except Exception as exc:  # pragma: no cover - defensive fallback
test_run_ref = db.get(TestRun, test_run.id)
if test_run_ref:
test_run_ref.status = TestRunStatus.FAILED
test_run_ref.last_error = "测试任务入队失败"
db.commit()
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="测试任务入队失败",
) from exc
return created_run
@router.get("/{test_prompt_id}", response_model=TestRunRead)
def get_test_prompt(*, db: Session = Depends(get_db), test_prompt_id: int) -> TestRun:
"""根据 ID 获取单个测试任务及其关联数据。"""
stmt = _test_run_query().where(TestRun.id == test_prompt_id)
test_run = db.execute(stmt).unique().scalar_one_or_none()
if not test_run:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Test run 不存在"
)
return test_run
@router.patch("/{test_prompt_id}", response_model=TestRunRead)
def update_test_prompt(
*,
db: Session = Depends(get_db),
test_prompt_id: int,
payload: TestRunUpdate,
) -> TestRun:
"""根据 ID 更新测试任务属性,可修改状态。"""
test_run = db.get(TestRun, test_prompt_id)
if not test_run:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Test run 不存在"
)
update_data = payload.model_dump(exclude_unset=True)
status_value = update_data.pop("status", None)
for key, value in update_data.items():
setattr(test_run, key, value)
if status_value is not None:
test_run.status = status_value
db.commit()
db.refresh(test_run)
return test_run
@router.delete(
"/{test_prompt_id}",
status_code=status.HTTP_204_NO_CONTENT,
response_class=Response,
)
def delete_test_prompt(
*, db: Session = Depends(get_db), test_prompt_id: int
) -> Response:
"""删除指定的测试任务及其结果记录。"""
test_run = db.get(TestRun, test_prompt_id)
if not test_run:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Test run 不存在"
)
db.delete(test_run)
db.commit()
return Response(status_code=status.HTTP_204_NO_CONTENT)
@router.get("/{test_prompt_id}/results", response_model=list[ResultRead])
def list_results_for_test_prompt(
*, db: Session = Depends(get_db), test_prompt_id: int
) -> Sequence[Result]:
"""列出指定测试任务的所有结果。"""
stmt = (
select(Result)
.where(Result.test_run_id == test_prompt_id)
.options(selectinload(Result.metrics))
.order_by(Result.run_index.asc())
)
return list(db.scalars(stmt))
@router.post("/{test_prompt_id}/retry", response_model=TestRunRead)
def retry_test_prompt(*, db: Session = Depends(get_db), test_prompt_id: int) -> TestRun:
"""重新入队执行失败的测试任务。"""
test_run = db.get(TestRun, test_prompt_id)
if not test_run:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Test run 不存在"
)
if test_run.status != TestRunStatus.FAILED:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT, detail="仅失败状态的测试任务可重试"
)
test_run.status = TestRunStatus.PENDING
test_run.last_error = None
db.flush()
db.commit()
stmt = _test_run_query().where(TestRun.id == test_run.id)
refreshed = db.execute(stmt).unique().scalar_one()
try:
enqueue_test_run(test_run.id)
    except Exception as exc:  # pragma: no cover - defensive fallback
failed_run = db.get(TestRun, test_run.id)
if failed_run:
failed_run.status = TestRunStatus.FAILED
failed_run.last_error = "测试任务重新入队失败"
db.commit()
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="测试任务重新入队失败",
) from exc
return refreshed
```
## /app/api/v1/endpoints/usage.py
```py path="/app/api/v1/endpoints/usage.py"
from __future__ import annotations
from datetime import date
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from app.db.session import get_db
from app.schemas import UsageModelSummary, UsageOverview, UsageTimeseriesPoint
from app.services.usage_dashboard import (
ModelUsageSummary as ModelUsageSummaryEntity,
UsageTimeseriesPoint as UsageTimeseriesPointEntity,
aggregate_usage_by_model,
calculate_usage_overview,
get_model_usage_timeseries,
)
router = APIRouter()
def _validate_date_range(start_date: date | None, end_date: date | None) -> None:
if start_date and end_date and end_date < start_date:
raise HTTPException(status_code=400, detail="结束日期必须晚于开始日期")
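# Model keys are serialized as "<provider_id>::<model_name>"; the literal "none" marks a missing provider.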
def _compose_model_key(provider_id: int | None, model_name: str) -> str:
prefix = str(provider_id) if provider_id is not None else "none"
return f"{prefix}::{model_name}"
def _parse_model_key(model_key: str) -> tuple[int | None, str]:
parts = model_key.split("::", 1)
if len(parts) != 2 or not parts[1]:
raise HTTPException(status_code=400, detail="无效的模型标识")
provider_part, model_name = parts
if provider_part == "none":
provider_id = None
else:
try:
provider_id = int(provider_part)
        except ValueError as exc:  # pragma: no cover - defensive check
raise HTTPException(status_code=400, detail="无效的模型标识") from exc
return provider_id, model_name
def _map_model_summary(entity: ModelUsageSummaryEntity) -> UsageModelSummary:
provider_name = entity.provider_name or "未命名提供商"
return UsageModelSummary(
model_key=_compose_model_key(entity.provider_id, entity.model_name),
model_name=entity.model_name,
provider=provider_name,
total_tokens=entity.total_tokens,
input_tokens=entity.input_tokens,
output_tokens=entity.output_tokens,
call_count=entity.call_count,
)
def _map_timeseries_point(entity: UsageTimeseriesPointEntity) -> UsageTimeseriesPoint:
return UsageTimeseriesPoint(
date=entity.date,
input_tokens=entity.input_tokens,
output_tokens=entity.output_tokens,
call_count=entity.call_count,
)
@router.get("/overview", response_model=UsageOverview | None)
def read_usage_overview(
*,
db: Session = Depends(get_db),
start_date: date | None = Query(default=None, description="开始日期"),
end_date: date | None = Query(default=None, description="结束日期"),
) -> UsageOverview | None:
"""汇总全局用量指标。"""
_validate_date_range(start_date, end_date)
overview = calculate_usage_overview(db, start_date=start_date, end_date=end_date)
if overview is None:
return None
return UsageOverview(
total_tokens=overview.total_tokens,
input_tokens=overview.input_tokens,
output_tokens=overview.output_tokens,
call_count=overview.call_count,
)
@router.get("/models", response_model=list[UsageModelSummary])
def read_model_usage(
*,
db: Session = Depends(get_db),
start_date: date | None = Query(default=None, description="开始日期"),
end_date: date | None = Query(default=None, description="结束日期"),
) -> list[UsageModelSummary]:
"""按模型聚合用量数据。"""
_validate_date_range(start_date, end_date)
summaries = aggregate_usage_by_model(db, start_date=start_date, end_date=end_date)
return [_map_model_summary(item) for item in summaries]
@router.get("/models/{model_key}/timeseries", response_model=list[UsageTimeseriesPoint])
def read_model_usage_timeseries(
*,
db: Session = Depends(get_db),
model_key: str,
start_date: date | None = Query(default=None, description="开始日期"),
end_date: date | None = Query(default=None, description="结束日期"),
) -> list[UsageTimeseriesPoint]:
"""获取指定模型的按日用量趋势。"""
_validate_date_range(start_date, end_date)
provider_id, model_name = _parse_model_key(model_key)
points = get_model_usage_timeseries(
db,
provider_id=provider_id,
model_name=model_name,
start_date=start_date,
end_date=end_date,
)
return [_map_timeseries_point(point) for point in points]
__all__ = ["router"]
```
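A hypothetical client call illustrating the model-key convention above; the host and `/api/v1/usage` prefix are assumptions, and the provider ID and model name are placeholders.

```py
import httpx

model_key = "3::gpt-4o-mini"  # "<provider_id>::<model_name>"
resp = httpx.get(
    f"http://localhost:8000/api/v1/usage/models/{model_key}/timeseries",
    params={"start_date": "2024-01-01", "end_date": "2024-01-31"},
)
resp.raise_for_status()
for point in resp.json():
    print(point["date"], point["input_tokens"], point["output_tokens"], point["call_count"])
```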
## /app/core/__init__.py
```py path="/app/core/__init__.py"
```
## /app/core/config.py
```py path="/app/core/config.py"
from functools import lru_cache
from typing import Any
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Application configuration loaded from environment variables."""
APP_ENV: str = "development"
    # Whether test mode is enabled; controls emission of DEBUG-level logs
APP_TEST_MODE: bool = False
API_V1_STR: str = "/api/v1"
PROJECT_NAME: str = "PromptWorks"
DATABASE_URL: str = (
"postgresql+psycopg://promptworks:promptworks@localhost:5432/promptworks"
)
REDIS_URL: str = "redis://localhost:6379/0"
OPENAI_API_KEY: str | None = None
ANTHROPIC_API_KEY: str | None = None
BACKEND_CORS_ORIGINS: list[str] | str = ["http://localhost:5173"]
BACKEND_CORS_ALLOW_CREDENTIALS: bool = True
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
@field_validator("DATABASE_URL")
@classmethod
def validate_database_url(cls, value: str) -> str:
if not value:
msg = "DATABASE_URL must be provided"
raise ValueError(msg)
return value
@field_validator("BACKEND_CORS_ORIGINS", mode="before")
@classmethod
def parse_cors_origins(cls, value: Any) -> list[str]:
if value is None:
return []
if isinstance(value, str):
return [origin.strip() for origin in value.split(",") if origin.strip()]
if isinstance(value, (list, tuple)):
return [str(origin).strip() for origin in value if str(origin).strip()]
raise TypeError(
"BACKEND_CORS_ORIGINS must be a list or a comma separated string"
)
@lru_cache
def get_settings() -> Settings:
"""Return a cached application settings instance."""
return Settings()
settings = get_settings()
```
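A minimal sketch of how the `BACKEND_CORS_ORIGINS` validator behaves; the origin values are placeholders.

```py
from app.core.config import Settings

# A comma-separated string is normalized into a list of trimmed origins.
s = Settings(BACKEND_CORS_ORIGINS="http://localhost:5173, http://localhost:4173")
assert s.BACKEND_CORS_ORIGINS == ["http://localhost:5173", "http://localhost:4173"]
```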
## /app/core/llm_provider_registry.py
```py path="/app/core/llm_provider_registry.py"
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Iterable
@dataclass(frozen=True)
class ProviderDefaults:
key: str
name: str
base_url: str | None
logo_emoji: str | None
description: str | None = None
logo_url: str | None = None
# Preset info for common providers so the frontend can render brand content directly
_COMMON_PROVIDERS: Dict[str, ProviderDefaults] = {
"openai": ProviderDefaults(
key="openai",
name="OpenAI",
base_url="https://api.openai.com/v1",
logo_emoji=None,
description="通用对话与代码生成能力强,官方模型接入通道。",
logo_url=(
"https://raw.githubusercontent.com/lobehub/lobe-icons/master/"
"packages/static-svg/icons/openai.svg"
),
),
"anthropic": ProviderDefaults(
key="anthropic",
name="Anthropic",
base_url="https://api.anthropic.com",
logo_emoji=None,
description="Claude 系列专注长文本与合规场景。",
logo_url=(
"https://raw.githubusercontent.com/lobehub/lobe-icons/master/"
"packages/static-svg/icons/anthropic.svg"
),
),
"azure-openai": ProviderDefaults(
key="azure-openai",
name="Azure OpenAI",
base_url="https://{resource-name}.openai.azure.com",
logo_emoji=None,
description="基于 Azure 的企业级 OpenAI 服务,需自定义资源域名。",
logo_url=(
"https://raw.githubusercontent.com/lobehub/lobe-icons/master/"
"packages/static-svg/icons/azure.svg"
),
),
"google": ProviderDefaults(
key="google",
name="Google",
base_url="https://generativelanguage.googleapis.com/v1beta",
logo_emoji=None,
description="Gemini 系列涵盖多模态推理与搜索增强。",
logo_url=(
"https://raw.githubusercontent.com/lobehub/lobe-icons/master/"
"packages/static-svg/icons/google.svg"
),
),
"deepseek": ProviderDefaults(
key="deepseek",
name="DeepSeek",
base_url="https://api.deepseek.com/v1",
logo_emoji=None,
description="国内团队自研的开源友好模型,突出推理与代码表现。",
logo_url=(
"https://raw.githubusercontent.com/lobehub/lobe-icons/master/"
"packages/static-svg/icons/deepseek.svg"
),
),
"dashscope": ProviderDefaults(
key="dashscope",
name="阿里云百炼",
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
logo_emoji=None,
description="通义大模型官方兼容接口,覆盖通用与行业场景。",
logo_url=(
"https://raw.githubusercontent.com/lobehub/lobe-icons/master/"
"packages/static-svg/icons/qwen.svg"
),
),
"siliconflow": ProviderDefaults(
key="siliconflow",
name="硅基流动",
base_url="https://api.siliconflow.cn/v1",
logo_emoji=None,
description="专注高性价比推理服务,提供丰富的开源模型托管能力。",
logo_url=(
"https://raw.githubusercontent.com/lobehub/lobe-icons/master/"
"packages/static-svg/icons/siliconcloud.svg"
),
),
"volcengine": ProviderDefaults(
key="volcengine",
name="火山引擎 Ark",
base_url="https://ark.cn-beijing.volces.com/api/v3",
logo_emoji=None,
description="字节跳动企业级模型平台,支持多模态与大规模并发。",
logo_url=(
"https://raw.githubusercontent.com/lobehub/lobe-icons/master/"
"packages/static-svg/icons/volcengine.svg"
),
),
"zhipu": ProviderDefaults(
key="zhipu",
name="智谱开放平台",
base_url="https://open.bigmodel.cn/api/paas/v4",
logo_emoji=None,
description="GLM 系列专注中文理解与工具调用,生态完整。",
logo_url=(
"https://raw.githubusercontent.com/lobehub/lobe-icons/master/"
"packages/static-svg/icons/zhipu.svg"
),
),
"moonshot": ProviderDefaults(
key="moonshot",
name="月之暗面 Moonshot",
base_url="https://api.moonshot.cn/v1",
logo_emoji=None,
description="国内率先开放 128K 以上上下文的高性能大模型。",
logo_url=(
"https://raw.githubusercontent.com/lobehub/lobe-icons/master/"
"packages/static-svg/icons/moonshot.svg"
),
),
"modelscope": ProviderDefaults(
key="modelscope",
name="魔搭 ModelScope",
base_url="https://api-inference.modelscope.cn/v1",
logo_emoji=None,
description="阿里云模型社区统一推理入口,便于快速体验模型。",
logo_url=(
"https://raw.githubusercontent.com/lobehub/lobe-icons/master/"
"packages/static-svg/icons/modelscope.svg"
),
),
"qianfan": ProviderDefaults(
key="qianfan",
name="百度云千帆",
base_url="https://qianfan.baidubce.com/v2",
logo_emoji=None,
description="百度智能云模型服务,提供文心家族与行业模型接入。",
logo_url=(
"https://raw.githubusercontent.com/lobehub/lobe-icons/master/"
"packages/static-svg/icons/baiducloud.svg"
),
),
}
def get_provider_defaults(provider_key: str | None) -> ProviderDefaults | None:
if not provider_key:
return None
return _COMMON_PROVIDERS.get(provider_key.lower())
def iter_common_providers() -> Iterable[ProviderDefaults]:
return _COMMON_PROVIDERS.values()
```
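A short illustration of the registry helpers above: the lookup lower-cases the key before matching, and unknown or empty keys return `None` instead of raising (how the `llms.py` endpoints consume these defaults is an assumption here).

```py
from app.core.llm_provider_registry import get_provider_defaults, iter_common_providers

defaults = get_provider_defaults("OpenAI")           # keys are lower-cased before lookup
assert defaults is not None and defaults.base_url == "https://api.openai.com/v1"
assert get_provider_defaults("unknown-vendor") is None

# All presets can be listed, e.g. to back a "known providers" endpoint.
print([provider.key for provider in iter_common_providers()])
```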
## /app/core/logging_config.py
```py path="/app/core/logging_config.py"
from __future__ import annotations
import logging
from zoneinfo import ZoneInfo
from datetime import datetime
from app.core.config import settings
# 定义北京时间时区对象,确保所有日志时间统一为北京时区
BEIJING_TZ = ZoneInfo("Asia/Shanghai")
# 定义公共日志格式,包含日志级别、时间、模块名称与具体消息内容
LOG_FORMAT = "[%(levelname)s] %(asctime)s %(name)s - %(message)s"
# 默认时间格式,便于读取
DEFAULT_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
class BeijingTimeFormatter(logging.Formatter):
"""自定义日志格式化器,用于强制日志时间显示为北京时间。"""
def __init__(
self, fmt: str, datefmt: str | None = None, timezone: ZoneInfo | None = None
) -> None:
# 记录目标时区,默认使用北京时区
super().__init__(fmt=fmt, datefmt=datefmt)
self._timezone = timezone or BEIJING_TZ
def formatTime(self, record: logging.LogRecord, datefmt: str | None = None) -> str:
# 使用记录生成时间与指定时区构造 datetime,以便格式化输出
current_time = datetime.fromtimestamp(record.created, tz=self._timezone)
return current_time.strftime(datefmt or DEFAULT_DATE_FORMAT)
def _build_console_handler() -> logging.Handler:
"""构建标准输出日志处理器,负责控制日志级别与格式。"""
console_handler = logging.StreamHandler()
# 测试模式下允许打印 DEBUG 日志,其他模式仅打印 INFO 及以上日志
handler_level = logging.DEBUG if settings.APP_TEST_MODE else logging.INFO
console_handler.setLevel(handler_level)
console_handler.setFormatter(BeijingTimeFormatter(LOG_FORMAT, DEFAULT_DATE_FORMAT))
# 标记该处理器,避免重复添加
setattr(console_handler, "_is_promptworks_handler", True)
return console_handler
def configure_logging() -> None:
"""初始化全局日志配置,仅在首次调用时生效。"""
root_logger = logging.getLogger()
# 如果已经存在我们自定义的处理器,则无需重复配置
for handler in root_logger.handlers:
if getattr(handler, "_is_promptworks_handler", False):
return
# 统一提升根日志器的级别为 DEBUG,交由处理器判断是否真正输出
root_logger.setLevel(logging.DEBUG)
root_logger.addHandler(_build_console_handler())
_disable_uvicorn_logs()
def _disable_uvicorn_logs() -> None:
"""关闭 FastAPI/Uvicorn 默认日志,避免重复输出。"""
for logger_name in ("uvicorn", "uvicorn.error", "uvicorn.access", "uvicorn.asgi"):
logger = logging.getLogger(logger_name)
logger.handlers.clear()
logger.propagate = False
logger.disabled = True
def get_logger(name: str) -> logging.Logger:
"""提供模块级日志记录器,确保使用统一配置。"""
configure_logging()
return logging.getLogger(name)
```
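A minimal sketch of how module code obtains a logger: `get_logger` configures the root logger exactly once, timestamps are rendered in Asia/Shanghai time, and whether DEBUG records reach the console depends on `APP_TEST_MODE`.

```py
from app.core.logging_config import get_logger

logger = get_logger("promptworks.example")
logger.info("service started")     # always emitted (handler level is at most INFO)
logger.debug("verbose details")    # emitted only when APP_TEST_MODE=true
```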
## /app/core/middleware.py
```py path="/app/core/middleware.py"
from __future__ import annotations
import logging
import time
from typing import Callable, Awaitable
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from starlette.types import ASGIApp
from app.core.logging_config import get_logger
class RequestLoggingMiddleware(BaseHTTPMiddleware):
"""负责在每次 HTTP 请求完成后记录关键访问信息的中间件。"""
def __init__(self, app: ASGIApp) -> None:
# 初始化父类并准备日志记录器
super().__init__(app)
self._logger = get_logger("promptworks.middleware.request")
async def dispatch(
self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
# 记录请求开始时间,方便统计耗时
start_time = time.perf_counter()
method = request.method
path = request.url.path
client_host = request.client.host if request.client else "unknown"
try:
response = await call_next(request)
except Exception:
# 出现异常时记录完整堆栈,协助定位问题
elapsed = (time.perf_counter() - start_time) * 1000
self._logger.exception(
"请求处理异常: %s %s 来自 %s 耗时 %.2fms",
method,
path,
client_host,
elapsed,
)
raise
elapsed = (time.perf_counter() - start_time) * 1000
# 标准请求日志,包含方法、路径、客户端、状态码与耗时
self._logger.info(
"请求完成: %s %s 来自 %s 状态码 %s 耗时 %.2fms",
method,
path,
client_host,
response.status_code,
elapsed,
)
# DEBUG 模式下附加更多请求上下文,便于调试
if self._logger.isEnabledFor(logging.DEBUG):
self._logger.debug("请求查询参数: %s", dict(request.query_params))
return response
```
## /app/core/prompt_test_task_queue.py
```py path="/app/core/prompt_test_task_queue.py"
from __future__ import annotations
import logging
import math
import threading
import time
from datetime import UTC, datetime
from queue import Empty, Queue
from collections.abc import Mapping, Sequence
from typing import Any
from sqlalchemy import func, select
from sqlalchemy.orm import Session, selectinload
from app.db import session as db_session
from app.models.prompt_test import (
PromptTestExperiment,
PromptTestExperimentStatus,
PromptTestTask,
PromptTestTaskStatus,
PromptTestUnit,
)
from app.services.prompt_test_engine import (
PromptTestExecutionError,
execute_prompt_test_experiment,
)
logger = logging.getLogger("promptworks.prompt_test_queue")
class PromptTestProgressTracker:
"""在任务执行期间追踪并持久化进度信息。"""
def __init__(
self,
session: Session,
task: PromptTestTask,
total_runs: int,
*,
step_percent: int = 5,
) -> None:
self._session = session
self._task = task
self._configured_total = max(1, int(total_runs)) if total_runs else 1
self._actual_total = max(1, total_runs)
self._step = max(1, step_percent)
self._completed = 0
self._last_percent = 0
self._next_threshold = self._step
self._initialized = False
def initialize(self) -> None:
config = dict(self._task.config) if isinstance(self._task.config, dict) else {}
progress_record = config.get("progress")
if not isinstance(progress_record, dict):
progress_record = {}
progress_record.update(
{
"current": 0,
"total": self._configured_total,
"percentage": 0,
"step": self._step,
}
)
config["progress"] = progress_record
config["progress_current"] = 0
config["progressCurrent"] = 0
config["progress_total"] = self._configured_total
config["progressTotal"] = self._configured_total
config["progress_percentage"] = 0
config["progressPercentage"] = 0
self._task.config = config
self._initialized = True
self._session.flush()
def advance(self, amount: int = 1) -> None:
if not self._initialized or amount <= 0:
return
self._completed = min(self._actual_total, self._completed + amount)
percent = self._calculate_percent(self._completed)
if self._completed < self._actual_total and percent < self._next_threshold:
return
self._write_progress(self._completed, percent)
self._last_percent = percent
if self._completed >= self._actual_total:
self._next_threshold = 100
else:
next_multiple = ((percent // self._step) + 1) * self._step
self._next_threshold = min(100, max(self._step, next_multiple))
self._session.commit()
def finish(self, force: bool = False) -> None:
if not self._initialized:
return
if not force and self._last_percent >= 100:
return
self._completed = self._actual_total
self._write_progress(self._actual_total, 100)
self._last_percent = 100
self._session.commit()
def _calculate_percent(self, completed: int) -> int:
ratio = completed / self._actual_total if self._actual_total else 0
percent = math.ceil(ratio * 100)
return min(100, max(0, percent))
def _write_progress(self, current: int, percent: int) -> None:
config = dict(self._task.config) if isinstance(self._task.config, dict) else {}
progress_record = config.get("progress")
if not isinstance(progress_record, dict):
progress_record = {}
progress_record.update(
{
"current": min(current, self._configured_total),
"total": self._configured_total,
"percentage": percent,
"step": self._step,
}
)
config["progress"] = progress_record
config["progress_current"] = min(current, self._configured_total)
config["progressCurrent"] = min(current, self._configured_total)
config["progress_total"] = self._configured_total
config["progressTotal"] = self._configured_total
config["progress_percentage"] = percent
config["progressPercentage"] = percent
self._task.config = config
self._session.flush()
def _count_variable_cases(value: Any) -> int:
if value is None:
return 1
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
return len(value) or 1
if isinstance(value, Mapping):
cases = value.get("cases")
if isinstance(cases, Sequence) and not isinstance(cases, (str, bytes, bytearray)):
return len(cases) or 1
rows = value.get("rows")
if isinstance(rows, Sequence) and not isinstance(rows, (str, bytes, bytearray)):
return len(rows) or 1
for key in ("data", "values"):
data = value.get(key)
if isinstance(data, Sequence) and not isinstance(
data, (str, bytes, bytearray)
):
return len(data) or 1
length = (
value.get("length")
or value.get("count")
or value.get("size")
or value.get("total")
)
if isinstance(length, (int, float)) and length > 0:
return int(length)
return 1
def _estimate_total_runs(units: Sequence[PromptTestUnit]) -> int:
total = 0
for unit in units:
rounds = unit.rounds or 1
case_count = _count_variable_cases(unit.variables)
total += max(1, int(rounds)) * max(1, int(case_count))
return max(total, 1)
class PromptTestTaskQueue:
"""Prompt 测试任务的串行执行队列。"""
def __init__(self) -> None:
self._queue: Queue[int] = Queue()
self._worker = threading.Thread(
target=self._worker_loop, name="prompt-test-task-queue", daemon=True
)
self._worker.start()
@staticmethod
def _update_task_last_error(task: PromptTestTask, message: str | None) -> None:
if message:
base = dict(task.config) if isinstance(task.config, dict) else {}
base["last_error"] = message
task.config = base
return
if isinstance(task.config, dict) and "last_error" in task.config:
cleaned = dict(task.config)
cleaned.pop("last_error", None)
task.config = cleaned
def enqueue(self, task_id: int) -> None:
"""将任务加入待执行队列。"""
self._queue.put_nowait(task_id)
logger.info("Prompt 测试任务 %s 已加入执行队列", task_id)
def wait_for_idle(self, timeout: float | None = None) -> bool:
"""等待队列清空,便于测试或调试。"""
if timeout is None:
self._queue.join()
return True
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
if self._queue.unfinished_tasks == 0:
return True
time.sleep(0.02)
return self._queue.unfinished_tasks == 0
def _worker_loop(self) -> None:
while True:
try:
task_id = self._queue.get()
except Empty: # pragma: no cover - Queue.get 默认阻塞
continue
try:
self._execute_task(task_id)
except Exception: # pragma: no cover - 防御性兜底
logger.exception(
"执行 Prompt 测试任务 %s 过程中发生未捕获异常", task_id
)
finally:
self._queue.task_done()
def _execute_task(self, task_id: int) -> None:
session = db_session.SessionLocal()
try:
task = session.execute(
select(PromptTestTask)
.where(PromptTestTask.id == task_id)
.options(selectinload(PromptTestTask.units))
).scalar_one_or_none()
if not task:
logger.warning("Prompt 测试任务 %s 不存在,跳过执行", task_id)
return
if task.is_deleted:
logger.info("Prompt 测试任务 %s 已被标记删除,跳过执行", task_id)
return
units = [unit for unit in task.units if isinstance(unit, PromptTestUnit)]
total_runs = _estimate_total_runs(units)
progress_tracker = PromptTestProgressTracker(session, task, total_runs)
progress_tracker.initialize()
if not units:
task.status = PromptTestTaskStatus.COMPLETED
self._update_task_last_error(task, None)
progress_tracker.finish(force=True)
logger.info(
"Prompt 测试任务 %s 无最小测试单元,自动标记为完成", task_id
)
return
task.status = PromptTestTaskStatus.RUNNING
self._update_task_last_error(task, None)
session.commit()
for unit in units:
sequence = (
session.scalar(
select(func.max(PromptTestExperiment.sequence)).where(
PromptTestExperiment.unit_id == unit.id
)
)
or 0
) + 1
experiment = PromptTestExperiment(
unit_id=unit.id,
sequence=sequence,
status=PromptTestExperimentStatus.PENDING,
)
session.add(experiment)
session.flush()
try:
execute_prompt_test_experiment(
session, experiment, progress_tracker.advance
)
except PromptTestExecutionError as exc:
session.refresh(experiment)
experiment.status = PromptTestExperimentStatus.FAILED
experiment.error = str(exc)
experiment.finished_at = datetime.now(UTC)
task.status = PromptTestTaskStatus.FAILED
self._update_task_last_error(task, str(exc))
progress_tracker.finish(force=True)
session.commit()
logger.warning(
"Prompt 测试任务 %s 的最小单元 %s 执行失败: %s",
task_id,
unit.id,
exc,
)
return
except Exception as exc: # pragma: no cover - 防御性兜底
session.refresh(experiment)
experiment.status = PromptTestExperimentStatus.FAILED
experiment.error = "执行测试任务失败"
experiment.finished_at = datetime.now(UTC)
task.status = PromptTestTaskStatus.FAILED
self._update_task_last_error(task, "执行测试任务失败")
progress_tracker.finish(force=True)
session.commit()
logger.exception(
"Prompt 测试任务 %s 的最小单元 %s 执行出现未知异常",
task_id,
unit.id,
)
return
session.commit()
task.status = PromptTestTaskStatus.COMPLETED
self._update_task_last_error(task, None)
progress_tracker.finish()
session.commit()
logger.info("Prompt 测试任务 %s 执行完成", task_id)
finally:
session.close()
task_queue = PromptTestTaskQueue()
def enqueue_prompt_test_task(task_id: int) -> None:
"""对外暴露的入队方法。"""
task_queue.enqueue(task_id)
__all__ = ["enqueue_prompt_test_task", "task_queue"]
```
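To make the progress arithmetic above concrete, here is an in-memory sketch of `_estimate_total_runs`: each unit contributes `rounds × variable-case count`, with both factors clamped to at least 1. The helpers are module-private, the model name is a placeholder, and nothing is persisted; this is purely illustrative.

```py
from app.core.prompt_test_task_queue import _estimate_total_runs
from app.models.prompt_test import PromptTestUnit

unit_a = PromptTestUnit(
    name="a", model_name="gpt-4o", rounds=3,
    variables={"cases": [{"q": "first"}, {"q": "second"}]},
)
unit_b = PromptTestUnit(name="b", model_name="gpt-4o", rounds=2, variables=None)

# 3 rounds x 2 cases + 2 rounds x 1 implicit case = 8 expected model calls
assert _estimate_total_runs([unit_a, unit_b]) == 8
```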
## /app/core/task_queue.py
```py path="/app/core/task_queue.py"
from __future__ import annotations
import logging
import threading
import time
from queue import Empty, Queue
from app.db import session as db_session
from app.models.test_run import TestRun, TestRunStatus
from app.services.test_run import TestRunExecutionError, execute_test_run
logger = logging.getLogger("promptworks.task_queue")
class TestRunTaskQueue:
"""简单的内存消息队列,用于串行执行测试任务。"""
def __init__(self) -> None:
self._queue: Queue[int] = Queue()
self._worker = threading.Thread(
target=self._worker_loop, name="test-run-queue", daemon=True
)
self._worker.start()
def enqueue(self, test_run_id: int) -> None:
"""将测试任务加入待执行队列。"""
self._queue.put_nowait(test_run_id)
logger.info("测试任务 %s 已加入执行队列", test_run_id)
def _worker_loop(self) -> None:
while True:
try:
test_run_id = self._queue.get()
except Empty: # pragma: no cover - Queue.get 默认阻塞,不会出现
continue
try:
self._execute_task(test_run_id)
except Exception: # pragma: no cover - 防御性兜底
logger.exception("执行测试任务 %s 过程中发生未捕获异常", test_run_id)
finally:
self._queue.task_done()
def _execute_task(self, test_run_id: int) -> None:
session = db_session.SessionLocal()
try:
test_run = session.get(TestRun, test_run_id)
if not test_run:
logger.warning("测试任务 %s 不存在,跳过执行", test_run_id)
return
nested_txn = session.begin_nested()
try:
execute_test_run(session, test_run)
except TestRunExecutionError as exc:
nested_txn.rollback()
session.expire(test_run)
failed_run = session.get(TestRun, test_run_id)
if not failed_run:
logger.warning("测试任务 %s 在回滚后不存在", test_run_id)
return
failed_run.last_error = str(exc)
failed_run.status = TestRunStatus.FAILED
session.commit()
logger.warning("测试任务 %s 执行失败: %s", test_run_id, exc)
return
except Exception as exc: # pragma: no cover - 防御性兜底
nested_txn.rollback()
session.expire_all()
failed_run = session.get(TestRun, test_run_id)
if failed_run:
failed_run.last_error = "执行测试任务失败"
failed_run.status = TestRunStatus.FAILED
session.commit()
logger.exception("测试任务 %s 执行出现未知异常", test_run_id)
return
else:
nested_txn.commit()
session.commit()
logger.info("测试任务 %s 执行完成", test_run_id)
finally:
session.close()
def wait_for_idle(self, timeout: float | None = None) -> bool:
"""等待队列清空,供测试或调试使用。"""
if timeout is None:
self._queue.join()
return True
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
if self._queue.unfinished_tasks == 0:
return True
time.sleep(0.02)
return self._queue.unfinished_tasks == 0
task_queue = TestRunTaskQueue()
def enqueue_test_run(test_run_id: int) -> None:
"""对外暴露的入队方法。"""
task_queue.enqueue(test_run_id)
__all__ = ["enqueue_test_run", "task_queue"]
```
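A hypothetical call site (the task id is a placeholder): API code enqueues a run by id, and tests can block until the in-memory queue drains.

```py
from app.core.task_queue import enqueue_test_run, task_queue

enqueue_test_run(42)                          # 42 stands in for an existing TestRun id
assert task_queue.wait_for_idle(timeout=30.0)
```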
## /app/db/__init__.py
```py path="/app/db/__init__.py"
```
## /app/db/session.py
```py path="/app/db/session.py"
from collections.abc import Generator
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from app.core.config import settings
engine = create_engine(settings.DATABASE_URL, future=True, pool_pre_ping=True)
SessionLocal = sessionmaker(
bind=engine, autoflush=False, autocommit=False, expire_on_commit=False
)
def get_db() -> Generator[Session, None, None]:
"""FastAPI dependency that yields a database session."""
db = SessionLocal()
try:
yield db
finally:
db.close()
```
## /app/db/types.py
```py path="/app/db/types.py"
"""Custom SQLAlchemy types used across the application."""
from __future__ import annotations
from sqlalchemy import JSON
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.types import TypeDecorator
class JSONBCompat(TypeDecorator):
"""Use PostgreSQL JSONB when available, fallback to generic JSON otherwise."""
impl = JSONB
cache_ok = True
def load_dialect_impl(self, dialect): # type: ignore[override]
if dialect.name == "postgresql":
return dialect.type_descriptor(JSONB())
return dialect.type_descriptor(JSON())
```
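A quick check of the dialect switch above: the same column type compiles to `JSONB` against PostgreSQL and to plain `JSON` elsewhere, e.g. SQLite (a sketch, assuming the stock SQLAlchemy dialects are importable in the environment).

```py
from sqlalchemy.dialects import postgresql, sqlite

from app.db.types import JSONBCompat

column_type = JSONBCompat()
print(column_type.compile(dialect=postgresql.dialect()))  # -> JSONB
print(column_type.compile(dialect=sqlite.dialect()))      # -> JSON
```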
## /app/main.py
```py path="/app/main.py"
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.v1.api import api_router
from app.core.config import settings
from app.core.logging_config import configure_logging, get_logger
from app.core.middleware import RequestLoggingMiddleware
from app.core.task_queue import task_queue as _test_run_task_queue # noqa: F401 - 确保队列初始化
def create_application() -> FastAPI:
"""Instantiate the FastAPI application."""
# 初始化日志系统并输出应用启动信息
configure_logging()
app_logger = get_logger("promptworks.app")
app_logger.info("FastAPI 应用初始化开始")
app = FastAPI(
title=settings.PROJECT_NAME,
version="0.1.0",
openapi_url=f"{settings.API_V1_STR}/openapi.json",
)
# 注册自定义请求日志中间件,捕获每一次请求信息
app.add_middleware(RequestLoggingMiddleware)
allowed_origins = settings.BACKEND_CORS_ORIGINS or ["http://localhost:5173"]
allow_credentials = settings.BACKEND_CORS_ALLOW_CREDENTIALS
if "*" in allowed_origins:
allow_credentials = False
app.add_middleware(
CORSMiddleware,
allow_origins=allowed_origins,
allow_credentials=allow_credentials,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(api_router, prefix=settings.API_V1_STR)
app_logger.info("FastAPI 应用初始化完成")
return app
app = create_application()
```
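One way to serve the factory locally (host, port and reload flag are assumptions; the Docker entrypoint may run it differently):

```py
import uvicorn

if __name__ == "__main__":
    # Serves the application instance created by create_application() above.
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
```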
## /app/models/__init__.py
```py path="/app/models/__init__.py"
from app.models.base import Base
from app.models.llm_provider import LLMModel, LLMProvider
from app.models.metric import Metric
from app.models.prompt import Prompt, PromptClass, PromptTag, PromptVersion
from app.models.result import Result
from app.models.usage import LLMUsageLog
from app.models.test_run import TestRun, TestRunStatus
from app.models.prompt_test import (
PromptTestTask,
PromptTestTaskStatus,
PromptTestUnit,
PromptTestExperiment,
PromptTestExperimentStatus,
)
from app.models.system_setting import SystemSetting
__all__ = [
"Base",
"PromptClass",
"Prompt",
"PromptTag",
"PromptVersion",
"TestRun",
"TestRunStatus",
"Result",
"Metric",
"LLMProvider",
"LLMModel",
"LLMUsageLog",
"PromptTestTask",
"PromptTestTaskStatus",
"PromptTestUnit",
"PromptTestExperiment",
"PromptTestExperimentStatus",
"SystemSetting",
]
```
## /app/models/base.py
```py path="/app/models/base.py"
from sqlalchemy.orm import DeclarativeBase
class Base(DeclarativeBase):
"""Base class for SQLAlchemy models."""
```
## /app/models/llm_provider.py
```py path="/app/models/llm_provider.py"
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import (
Boolean,
DateTime,
ForeignKey,
Integer,
String,
Text,
UniqueConstraint,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base
if TYPE_CHECKING: # pragma: no cover - 类型检查辅助
from app.models.usage import LLMUsageLog
class LLMProvider(Base):
__tablename__ = "llm_providers"
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
provider_key: Mapped[str | None] = mapped_column(
String(100), nullable=True, index=True
)
provider_name: Mapped[str] = mapped_column(String(150), nullable=False)
base_url: Mapped[str | None] = mapped_column(String(255), nullable=True)
api_key: Mapped[str] = mapped_column(Text, nullable=False)
is_custom: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
is_archived: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
logo_url: Mapped[str | None] = mapped_column(String(255), nullable=True)
logo_emoji: Mapped[str | None] = mapped_column(String(16), nullable=True)
default_model_name: Mapped[str | None] = mapped_column(String(150), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
models: Mapped[list["LLMModel"]] = relationship(
"LLMModel",
back_populates="provider",
cascade="all, delete-orphan",
lazy="selectin",
)
usage_logs: Mapped[list["LLMUsageLog"]] = relationship(
"LLMUsageLog",
back_populates="provider",
passive_deletes=True,
)
def __repr__(self) -> str: # pragma: no cover - 调试辅助
return (
"LLMProvider(id={id}, provider_name={provider}, base_url={base}, models={count})"
).format(
id=self.id,
provider=self.provider_name,
base=self.base_url,
count=len(self.models),
)
class LLMModel(Base):
__tablename__ = "llm_models"
__table_args__ = (
UniqueConstraint("provider_id", "name", name="uq_llm_model_provider_name"),
)
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
provider_id: Mapped[int] = mapped_column(
ForeignKey("llm_providers.id", ondelete="CASCADE"), nullable=False, index=True
)
name: Mapped[str] = mapped_column(String(150), nullable=False)
capability: Mapped[str | None] = mapped_column(String(120), nullable=True)
quota: Mapped[str | None] = mapped_column(String(120), nullable=True)
concurrency_limit: Mapped[int] = mapped_column(
Integer, nullable=False, default=5, server_default="5"
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
provider: Mapped[LLMProvider] = relationship("LLMProvider", back_populates="models")
def __repr__(self) -> str: # pragma: no cover - 调试辅助
return ("LLMModel(id={id}, provider_id={provider}, name={name})").format(
id=self.id, provider=self.provider_id, name=self.name
)
```
## /app/models/metric.py
```py path="/app/models/metric.py"
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import DateTime, ForeignKey, Integer, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.types import JSONBCompat
from app.models.base import Base
if TYPE_CHECKING:
from app.models.result import Result
class Metric(Base):
__tablename__ = "metrics"
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
result_id: Mapped[int] = mapped_column(
ForeignKey("results.id", ondelete="CASCADE"), nullable=False
)
is_valid_json: Mapped[bool | None] = mapped_column(nullable=True)
schema_pass: Mapped[bool | None] = mapped_column(nullable=True)
missing_fields: Mapped[dict | None] = mapped_column(JSONBCompat, nullable=True)
type_mismatches: Mapped[dict | None] = mapped_column(JSONBCompat, nullable=True)
consistency_score: Mapped[float | None] = mapped_column(nullable=True)
numeric_accuracy: Mapped[float | None] = mapped_column(nullable=True)
boolean_accuracy: Mapped[float | None] = mapped_column(nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
result: Mapped["Result"] = relationship("Result", back_populates="metrics")
```
## /app/models/prompt.py
```py path="/app/models/prompt.py"
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import (
Column,
DateTime,
ForeignKey,
Integer,
String,
Table,
Text,
UniqueConstraint,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
if TYPE_CHECKING:
from app.models.test_run import TestRun
from app.models.base import Base
prompt_tag_association = Table(
"prompt_tag_links",
Base.metadata,
Column(
"prompt_id",
ForeignKey("prompts.id", ondelete="CASCADE"),
primary_key=True,
),
Column(
"tag_id",
ForeignKey("prompt_tags.id", ondelete="CASCADE"),
primary_key=True,
),
)
class PromptClass(Base):
__tablename__ = "prompts_class"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
prompts: Mapped[list["Prompt"]] = relationship(
"Prompt",
back_populates="prompt_class",
cascade="all, delete-orphan",
)
class PromptTag(Base):
__tablename__ = "prompt_tags"
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
name: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
color: Mapped[str] = mapped_column(String(7), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
prompts: Mapped[list["Prompt"]] = relationship(
"Prompt",
secondary=prompt_tag_association,
back_populates="tags",
)
class Prompt(Base):
__tablename__ = "prompts"
__table_args__ = (
UniqueConstraint("class_id", "name", name="uq_prompt_class_name"),
)
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
class_id: Mapped[int] = mapped_column(
ForeignKey(
"prompts_class.id",
ondelete="CASCADE",
name="prompts_class_id_fkey",
),
nullable=False,
index=True,
)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
author: Mapped[str | None] = mapped_column(String(100), nullable=True)
current_version_id: Mapped[int | None] = mapped_column(
ForeignKey(
"prompts_versions.id",
name="prompts_current_version_id_fkey",
use_alter=True,
),
nullable=True,
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
prompt_class: Mapped[PromptClass] = relationship(
"PromptClass", back_populates="prompts"
)
versions: Mapped[list["PromptVersion"]] = relationship(
"PromptVersion",
back_populates="prompt",
cascade="all, delete-orphan",
order_by="PromptVersion.created_at.desc()",
foreign_keys="PromptVersion.prompt_id",
primaryjoin="Prompt.id == PromptVersion.prompt_id",
)
current_version: Mapped["PromptVersion | None"] = relationship(
"PromptVersion",
foreign_keys="Prompt.current_version_id",
primaryjoin="Prompt.current_version_id == PromptVersion.id",
post_update=True,
)
tags: Mapped[list["PromptTag"]] = relationship(
"PromptTag",
secondary=prompt_tag_association,
back_populates="prompts",
)
class PromptVersion(Base):
__tablename__ = "prompts_versions"
__table_args__ = (
UniqueConstraint("prompt_id", "version", name="uq_prompt_version"),
)
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
prompt_id: Mapped[int] = mapped_column(
ForeignKey(
"prompts.id",
ondelete="CASCADE",
name="prompts_versions_prompt_id_fkey",
),
nullable=False,
index=True,
)
version: Mapped[str] = mapped_column(String(50), nullable=False)
content: Mapped[str] = mapped_column(Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
prompt: Mapped["Prompt"] = relationship(
"Prompt",
back_populates="versions",
foreign_keys="PromptVersion.prompt_id",
primaryjoin="PromptVersion.prompt_id == Prompt.id",
)
test_runs: Mapped[list["TestRun"]] = relationship(
"TestRun", back_populates="prompt_version", cascade="all, delete-orphan"
)
__all__ = ["PromptClass", "Prompt", "PromptTag", "PromptVersion"]
```
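An in-memory sketch of the relationships above: `back_populates` keeps both sides in sync before anything is flushed, and `current_version` with `post_update=True` is what lets the circular FK between `prompts` and `prompts_versions` be resolved at flush time. All names and values below are placeholders.

```py
from app.models.prompt import Prompt, PromptClass, PromptTag, PromptVersion

demo_class = PromptClass(name="demo-class")
prompt = Prompt(name="demo-prompt", prompt_class=demo_class)
v1 = PromptVersion(version="1.0.0", content="Hello {{name}}", prompt=prompt)

prompt.current_version = v1
prompt.tags.append(PromptTag(name="demo", color="#FF8800"))

assert prompt.versions == [v1]
assert demo_class.prompts == [prompt]
```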
## /app/models/prompt_test.py
```py path="/app/models/prompt_test.py"
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING
from sqlalchemy import (
Boolean,
DateTime,
Enum as PgEnum,
ForeignKey,
Integer,
String,
Text,
UniqueConstraint,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.types import JSONBCompat
from app.models.base import Base
if TYPE_CHECKING:
from app.models.prompt import PromptVersion
class PromptTestTaskStatus(str, Enum):
"""测试任务的状态枚举。"""
__test__ = False
DRAFT = "draft"
READY = "ready"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
class PromptTestTask(Base):
"""测试任务表,描述一次测试活动的整体配置。"""
__tablename__ = "prompt_test_tasks"
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
name: Mapped[str] = mapped_column(String(120), nullable=False)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
prompt_version_id: Mapped[int | None] = mapped_column(
ForeignKey("prompts_versions.id", ondelete="SET NULL"), nullable=True
)
owner_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
config: Mapped[dict | None] = mapped_column(JSONBCompat, nullable=True)
status: Mapped[PromptTestTaskStatus] = mapped_column(
PgEnum(
PromptTestTaskStatus,
name="prompt_test_task_status",
values_callable=lambda enum: [member.value for member in enum],
),
nullable=False,
default=PromptTestTaskStatus.DRAFT,
server_default=PromptTestTaskStatus.DRAFT.value,
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
is_deleted: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False, server_default="false"
)
prompt_version: Mapped["PromptVersion | None"] = relationship("PromptVersion")
units: Mapped[list["PromptTestUnit"]] = relationship(
"PromptTestUnit",
back_populates="task",
cascade="all, delete-orphan",
passive_deletes=True,
)
class PromptTestUnit(Base):
"""最小测试单元,描述执行一次模型调用所需的上下文。"""
__tablename__ = "prompt_test_units"
__table_args__ = (
UniqueConstraint("task_id", "name", name="uq_prompt_test_unit_task_name"),
)
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
task_id: Mapped[int] = mapped_column(
ForeignKey("prompt_test_tasks.id", ondelete="CASCADE"), nullable=False
)
prompt_version_id: Mapped[int | None] = mapped_column(
ForeignKey("prompts_versions.id", ondelete="SET NULL"), nullable=True
)
name: Mapped[str] = mapped_column(String(120), nullable=False)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
model_name: Mapped[str] = mapped_column(String(100), nullable=False)
llm_provider_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
temperature: Mapped[float] = mapped_column(nullable=False, default=0.7)
top_p: Mapped[float | None] = mapped_column(nullable=True)
rounds: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
prompt_template: Mapped[str | None] = mapped_column(Text, nullable=True)
variables: Mapped[dict | None] = mapped_column(JSONBCompat, nullable=True)
parameters: Mapped[dict | None] = mapped_column(JSONBCompat, nullable=True)
expectations: Mapped[dict | None] = mapped_column(JSONBCompat, nullable=True)
tags: Mapped[list[str] | None] = mapped_column(JSONBCompat, nullable=True)
extra: Mapped[dict | None] = mapped_column(JSONBCompat, nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
task: Mapped["PromptTestTask"] = relationship(
"PromptTestTask", back_populates="units"
)
prompt_version: Mapped["PromptVersion | None"] = relationship("PromptVersion")
experiments: Mapped[list["PromptTestExperiment"]] = relationship(
"PromptTestExperiment",
back_populates="unit",
cascade="all, delete-orphan",
passive_deletes=True,
)
class PromptTestExperimentStatus(str, Enum):
"""记录实验执行状态,支持多轮执行与失败重试。"""
__test__ = False
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class PromptTestExperiment(Base):
"""实验执行结果,与最小测试单元关联。"""
__tablename__ = "prompt_test_experiments"
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
unit_id: Mapped[int] = mapped_column(
ForeignKey("prompt_test_units.id", ondelete="CASCADE"), nullable=False
)
batch_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
sequence: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
status: Mapped[PromptTestExperimentStatus] = mapped_column(
PgEnum(
PromptTestExperimentStatus,
name="prompt_test_experiment_status",
values_callable=lambda enum: [member.value for member in enum],
),
nullable=False,
default=PromptTestExperimentStatus.PENDING,
server_default=PromptTestExperimentStatus.PENDING.value,
)
outputs: Mapped[list[dict] | None] = mapped_column(
JSONBCompat, nullable=True, doc="多轮执行的返回结果列表"
)
metrics: Mapped[dict | None] = mapped_column(
JSONBCompat, nullable=True, doc="自动化评估指标与统计信息"
)
error: Mapped[str | None] = mapped_column(Text, nullable=True)
started_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
finished_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
unit: Mapped["PromptTestUnit"] = relationship(
"PromptTestUnit", back_populates="experiments"
)
__all__ = [
"PromptTestTask",
"PromptTestTaskStatus",
"PromptTestUnit",
"PromptTestExperiment",
"PromptTestExperimentStatus",
]
```
## /app/models/result.py
```py path="/app/models/result.py"
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import DateTime, ForeignKey, Integer, Text, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.types import JSONBCompat
from app.models.base import Base
if TYPE_CHECKING:
from app.models.metric import Metric
from app.models.test_run import TestRun
class Result(Base):
__tablename__ = "results"
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
test_run_id: Mapped[int] = mapped_column(
ForeignKey("test_runs.id", ondelete="CASCADE"), nullable=False
)
run_index: Mapped[int] = mapped_column(Integer, nullable=False)
output: Mapped[str] = mapped_column(Text, nullable=False)
parsed_output: Mapped[dict | None] = mapped_column(JSONBCompat, nullable=True)
tokens_used: Mapped[int | None] = mapped_column(Integer, nullable=True)
latency_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
test_run: Mapped["TestRun"] = relationship("TestRun", back_populates="results")
metrics: Mapped[list["Metric"]] = relationship(
"Metric",
back_populates="result",
cascade="all, delete-orphan",
passive_deletes=True,
)
```
## /app/models/system_setting.py
```py path="/app/models/system_setting.py"
from __future__ import annotations
from datetime import datetime
from sqlalchemy import DateTime, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
from app.db.types import JSONBCompat
from app.models.base import Base
class SystemSetting(Base):
"""存储全局配置项的键值对。"""
__tablename__ = "system_settings"
key: Mapped[str] = mapped_column(
String(120),
primary_key=True,
doc="配置项唯一标识",
)
value: Mapped[dict | list | str | int | float | bool | None] = mapped_column(
JSONBCompat,
nullable=True,
doc="配置项内容,统一使用 JSON 结构存储",
)
description: Mapped[str | None] = mapped_column(
Text,
nullable=True,
doc="配置项说明",
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
doc="创建时间",
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
onupdate=func.now(),
doc="最近更新时间",
)
__all__ = ["SystemSetting"]
```
## /app/models/test_run.py
```py path="/app/models/test_run.py"
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING
from sqlalchemy import DateTime, Enum as PgEnum, ForeignKey, Integer, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.types import JSONBCompat
from app.models.base import Base
if TYPE_CHECKING:
from app.models.prompt import Prompt, PromptVersion
from app.models.result import Result
class TestRunStatus(str, Enum):
__test__ = False
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
class TestRun(Base):
__test__ = False
__tablename__ = "test_runs"
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
prompt_version_id: Mapped[int] = mapped_column(
ForeignKey("prompts_versions.id", ondelete="CASCADE"), nullable=False
)
batch_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
model_name: Mapped[str] = mapped_column(String(100), nullable=False)
model_version: Mapped[str | None] = mapped_column(String(50), nullable=True)
temperature: Mapped[float] = mapped_column(nullable=False, default=0.7)
top_p: Mapped[float] = mapped_column(nullable=False, default=1.0)
repetitions: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
schema: Mapped[dict | None] = mapped_column(JSONBCompat, nullable=True)
status: Mapped[TestRunStatus] = mapped_column(
PgEnum(
TestRunStatus,
name="test_run_status",
values_callable=lambda enum: [member.value for member in enum],
),
nullable=False,
default=TestRunStatus.PENDING,
server_default=TestRunStatus.PENDING.value,
)
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
prompt_version: Mapped["PromptVersion"] = relationship(
"PromptVersion", back_populates="test_runs"
)
results: Mapped[list["Result"]] = relationship(
"Result",
back_populates="test_run",
cascade="all, delete-orphan",
passive_deletes=True,
)
@property
def last_error(self) -> str | None:
schema_data = self.schema
if isinstance(schema_data, dict):
raw_value = schema_data.get("last_error")
if isinstance(raw_value, str):
trimmed = raw_value.strip()
if trimmed:
return trimmed
return None
@last_error.setter
def last_error(self, message: str | None) -> None:
schema_data = dict(self.schema or {})
if message and message.strip():
schema_data["last_error"] = message.strip()
else:
schema_data.pop("last_error", None)
self.schema = schema_data or None
@property
def failure_reason(self) -> str | None:
return self.last_error
@failure_reason.setter
def failure_reason(self, message: str | None) -> None:
self.last_error = message
@property
def prompt(self) -> Prompt | None:
return self.prompt_version.prompt if self.prompt_version else None
__all__ = ["TestRun", "TestRunStatus"]
```
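A small sketch of the `last_error` property pair above (field values are placeholders): the message is not a dedicated column, it is trimmed and tucked into the JSON `schema` field, and clearing it removes the key again.

```py
from app.models.test_run import TestRun

run = TestRun(model_name="gpt-4o", prompt_version_id=1)

run.last_error = "  upstream timeout  "
assert run.schema == {"last_error": "upstream timeout"}   # stored trimmed
assert run.failure_reason == "upstream timeout"           # alias property

run.last_error = None
assert run.schema is None                                  # emptied back to NULL
```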
## /app/models/usage.py
```py path="/app/models/usage.py"
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING, Any
from sqlalchemy import (
DateTime,
Float,
ForeignKey,
Integer,
String,
Text,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.types import JSONBCompat
from app.models.base import Base
if TYPE_CHECKING: # pragma: no cover - 类型检查辅助
from app.models.llm_provider import LLMProvider
class LLMUsageLog(Base):
"""记录每次 LLM 调用的用量与上下文,供后续统计分析。"""
__tablename__ = "llm_usage_logs"
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
provider_id: Mapped[int | None] = mapped_column(
ForeignKey("llm_providers.id", ondelete="SET NULL"), nullable=True
)
model_id: Mapped[int | None] = mapped_column(
ForeignKey("llm_models.id", ondelete="SET NULL"), nullable=True
)
model_name: Mapped[str] = mapped_column(String(150), nullable=False)
source: Mapped[str] = mapped_column(
String(50), nullable=False, default="quick_test"
)
prompt_id: Mapped[int | None] = mapped_column(
ForeignKey("prompts.id", ondelete="SET NULL"), nullable=True
)
prompt_version_id: Mapped[int | None] = mapped_column(
ForeignKey("prompts_versions.id", ondelete="SET NULL"), nullable=True
)
messages: Mapped[list[dict[str, Any]] | None] = mapped_column(
JSONBCompat, nullable=True
)
parameters: Mapped[dict[str, Any] | None] = mapped_column(
JSONBCompat, nullable=True
)
response_text: Mapped[str | None] = mapped_column(Text, nullable=True)
temperature: Mapped[float | None] = mapped_column(Float, nullable=True)
latency_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
prompt_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
completion_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
total_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
provider: Mapped["LLMProvider"] = relationship(
"LLMProvider", back_populates="usage_logs", passive_deletes=True
)
def __repr__(self) -> str: # pragma: no cover - 调试辅助
return (
"LLMUsageLog(id={id}, provider_id={provider_id}, model={model}, tokens={tokens})"
).format(
id=self.id,
provider_id=self.provider_id,
model=self.model_name,
tokens=self.total_tokens,
)
__all__ = ["LLMUsageLog"]
```
## /app/schemas/__init__.py
```py path="/app/schemas/__init__.py"
from app.schemas.llm_provider import (
LLMProviderCreate,
LLMProviderRead,
LLMProviderUpdate,
LLMUsageLogRead,
LLMUsageMessage,
)
from app.schemas.metric import MetricCreate, MetricRead
from app.schemas.prompt import (
PromptClassRead,
PromptClassCreate,
PromptClassUpdate,
PromptClassStats,
PromptCreate,
PromptRead,
PromptTagCreate,
PromptTagListResponse,
PromptTagRead,
PromptTagStats,
PromptTagUpdate,
PromptUpdate,
PromptVersionCreate,
PromptVersionRead,
)
from app.schemas.result import ResultCreate, ResultRead
from app.schemas.test_run import TestRunCreate, TestRunRead, TestRunUpdate
from app.schemas.usage import UsageModelSummary, UsageOverview, UsageTimeseriesPoint
from app.schemas.settings import TestingTimeoutsRead, TestingTimeoutsUpdate
__all__ = [
"PromptClassRead",
"PromptClassCreate",
"PromptClassUpdate",
"PromptClassStats",
"PromptCreate",
"PromptUpdate",
"PromptRead",
"PromptTagCreate",
"PromptTagUpdate",
"PromptTagRead",
"PromptTagStats",
"PromptTagListResponse",
"PromptVersionCreate",
"PromptVersionRead",
"TestRunCreate",
"TestRunUpdate",
"TestRunRead",
"ResultCreate",
"ResultRead",
"MetricCreate",
"MetricRead",
"LLMProviderCreate",
"LLMProviderUpdate",
"LLMProviderRead",
"LLMUsageLogRead",
"LLMUsageMessage",
"UsageOverview",
"UsageModelSummary",
"UsageTimeseriesPoint",
"TestingTimeoutsRead",
"TestingTimeoutsUpdate",
]
```
## /app/schemas/llm_provider.py
```py path="/app/schemas/llm_provider.py"
from __future__ import annotations
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
class LLMModelBase(BaseModel):
name: str = Field(..., description="模型唯一名称")
capability: str | None = Field(
default=None, description="可选的能力标签,例如对话、推理"
)
quota: str | None = Field(default=None, description="配额或调用策略说明")
concurrency_limit: int = Field(
default=5,
ge=1,
le=50,
description="执行测试任务时的最大并发请求数",
)
class LLMModelCreate(LLMModelBase):
pass
class LLMModelUpdate(BaseModel):
capability: str | None = Field(default=None, description="可选的能力标签")
quota: str | None = Field(default=None, description="配额或调用策略说明")
concurrency_limit: int | None = Field(
default=None,
ge=1,
le=50,
description="执行测试任务时的最大并发请求数",
)
class LLMModelRead(LLMModelBase):
id: int
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
class LLMProviderBase(BaseModel):
provider_name: str = Field(..., description="展示用名称,例如 OpenAI")
provider_key: str | None = Field(
default=None, description="常用提供方标识,用于自动补全默认信息"
)
base_url: str | None = Field(default=None, description="调用使用的基础 URL")
logo_emoji: str | None = Field(
default=None, max_length=16, description="用于展示的表情符号"
)
logo_url: str | None = Field(default=None, description="可选的品牌 Logo URL")
is_custom: bool | None = Field(
default=None, description="是否为自定义提供方,未指定时由后端推断"
)
class LLMProviderCreate(LLMProviderBase):
api_key: str = Field(..., min_length=1, description="访问该提供方所需的密钥")
class LLMProviderUpdate(BaseModel):
provider_name: str | None = None
base_url: str | None = None
api_key: str | None = None
logo_emoji: str | None = Field(default=None, max_length=16)
logo_url: str | None = None
is_custom: bool | None = None
default_model_name: str | None = None
class LLMProviderRead(BaseModel):
id: int
provider_key: str | None
provider_name: str
base_url: str | None
logo_emoji: str | None
logo_url: str | None
is_custom: bool
is_archived: bool
default_model_name: str | None
masked_api_key: str
models: list[LLMModelRead] = Field(default_factory=list)
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
class KnownLLMProvider(BaseModel):
key: str
name: str
description: str | None = None
base_url: str | None = None
logo_emoji: str | None = None
logo_url: str | None = None
class LLMUsageMessage(BaseModel):
role: str = Field(..., description="消息角色,例如 user、assistant")
content: Any = Field(..., description="与 OpenAI 兼容的消息内容")
class LLMUsageLogRead(BaseModel):
id: int
provider_id: int | None
provider_name: str | None
provider_logo_emoji: str | None
provider_logo_url: str | None
model_id: int | None
model_name: str
response_text: str | None
messages: list[LLMUsageMessage] = Field(default_factory=list)
temperature: float | None
latency_ms: int | None
prompt_tokens: int | None
completion_tokens: int | None
total_tokens: int | None
prompt_id: int | None
prompt_version_id: int | None
created_at: datetime
model_config = ConfigDict(from_attributes=True)
```
## /app/schemas/metric.py
```py path="/app/schemas/metric.py"
from datetime import datetime
from pydantic import BaseModel, ConfigDict
class MetricBase(BaseModel):
is_valid_json: bool | None = None
schema_pass: bool | None = None
missing_fields: dict | None = None
type_mismatches: dict | None = None
consistency_score: float | None = None
numeric_accuracy: float | None = None
boolean_accuracy: float | None = None
class MetricCreate(MetricBase):
result_id: int
class MetricRead(MetricBase):
id: int
result_id: int
created_at: datetime
model_config = ConfigDict(from_attributes=True)
```
## /app/schemas/prompt.py
```py path="/app/schemas/prompt.py"
from datetime import datetime
from pydantic import BaseModel, ConfigDict, Field, model_validator
class PromptTagBase(BaseModel):
name: str = Field(..., max_length=100)
    color: str = Field(..., pattern=r"^#[0-9A-Fa-f]{6}$")
class PromptTagRead(PromptTagBase):
id: int
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
class PromptTagCreate(PromptTagBase):
"""Prompt 标签创建入参"""
@model_validator(mode="after")
def normalize_payload(self):
trimmed = self.name.strip()
if not trimmed:
raise ValueError("name 不能为空字符")
self.name = trimmed
self.color = self.color.upper()
return self
class PromptTagUpdate(BaseModel):
"""Prompt 标签更新入参"""
name: str | None = Field(default=None, max_length=100)
    color: str | None = Field(default=None, pattern=r"^#[0-9A-Fa-f]{6}$")
@model_validator(mode="after")
def normalize_payload(self):
if self.name is not None:
trimmed = self.name.strip()
if not trimmed:
raise ValueError("name 不能为空字符")
self.name = trimmed
if self.color is not None:
self.color = self.color.upper()
return self
class PromptTagStats(PromptTagRead):
prompt_count: int = Field(default=0, ge=0)
class PromptTagListResponse(BaseModel):
items: list[PromptTagStats]
tagged_prompt_total: int = Field(default=0, ge=0)
class PromptClassBase(BaseModel):
name: str = Field(..., max_length=255)
description: str | None = None
class PromptClassRead(PromptClassBase):
id: int
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
class PromptClassCreate(PromptClassBase):
"""Prompt 分类创建入参"""
@model_validator(mode="after")
def validate_payload(self):
trimmed = self.name.strip()
if not trimmed:
raise ValueError("name 不能为空字符")
self.name = trimmed
return self
class PromptClassUpdate(BaseModel):
"""Prompt 分类更新入参"""
name: str | None = Field(default=None, max_length=255)
description: str | None = None
@model_validator(mode="after")
def validate_payload(self):
if self.name is not None and not self.name.strip():
raise ValueError("name 不能为空字符")
return self
class PromptClassStats(PromptClassRead):
"""带统计信息的 Prompt 分类出参"""
prompt_count: int = Field(default=0, ge=0)
latest_prompt_updated_at: datetime | None = None
class PromptVersionBase(BaseModel):
version: str = Field(..., max_length=50)
content: str
class PromptVersionCreate(PromptVersionBase):
prompt_id: int
class PromptVersionRead(PromptVersionBase):
id: int
prompt_id: int
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
class PromptBase(BaseModel):
name: str = Field(..., max_length=255)
description: str | None = None
author: str | None = Field(default=None, max_length=100)
class PromptCreate(PromptBase):
class_id: int | None = Field(default=None, ge=1)
class_name: str | None = Field(default=None, max_length=255)
class_description: str | None = None
version: str = Field(..., max_length=50)
content: str
tag_ids: list[int] | None = Field(
default=None,
description="为 Prompt 选择的标签 ID 列表,未提供时保持既有设置",
)
@model_validator(mode="after")
def validate_class_reference(self):
if self.class_id is None and not (self.class_name and self.class_name.strip()):
raise ValueError("class_id 或 class_name 至少需要提供一个")
return self
class PromptUpdate(BaseModel):
name: str | None = Field(default=None, max_length=255)
description: str | None = None
author: str | None = Field(default=None, max_length=100)
class_id: int | None = Field(default=None, ge=1)
class_name: str | None = Field(default=None, max_length=255)
class_description: str | None = None
version: str | None = Field(default=None, max_length=50)
content: str | None = None
activate_version_id: int | None = Field(default=None, ge=1)
tag_ids: list[int] | None = Field(
default=None,
description="如果提供则覆盖 Prompt 的标签,传空列表代表清空标签",
)
@model_validator(mode="after")
def validate_version_payload(self):
if (self.version is None) != (self.content is None):
raise ValueError("更新版本时必须同时提供 version 与 content")
return self
@model_validator(mode="after")
def validate_class_reference(self):
if (
self.class_id is None
and self.class_name is not None
and not self.class_name.strip()
):
raise ValueError("class_name 不能为空字符串")
return self
class PromptRead(PromptBase):
id: int
prompt_class: PromptClassRead
current_version: PromptVersionRead | None = None
versions: list[PromptVersionRead] = Field(default_factory=list)
tags: list[PromptTagRead] = Field(default_factory=list)
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
```
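A sketch of the tag validators above: the name is trimmed, the colour is upper-cased, and anything that does not match `#RRGGBB` is rejected at the field level before the model validator runs.

```py
from pydantic import ValidationError

from app.schemas.prompt import PromptTagCreate

tag = PromptTagCreate(name="  demo  ", color="#ff8800")
assert tag.name == "demo" and tag.color == "#FF8800"

try:
    PromptTagCreate(name="demo", color="ff8800")   # missing leading '#'
except ValidationError:
    pass
```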
## /app/schemas/prompt_test.py
```py path="/app/schemas/prompt_test.py"
from __future__ import annotations
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from app.models.prompt_test import (
PromptTestExperimentStatus,
PromptTestTaskStatus,
)
class PromptTestTaskBase(BaseModel):
"""测试任务基础字段定义。"""
__test__ = False
name: str = Field(..., max_length=120)
description: str | None = None
prompt_version_id: int | None = None
owner_id: int | None = None
config: dict[str, Any] | None = None
class PromptTestTaskCreate(PromptTestTaskBase):
"""创建测试任务时使用的结构,可附带预置的测试单元。"""
__test__ = False
units: list["PromptTestUnitCreate"] | None = None
auto_execute: bool = False
class PromptTestTaskUpdate(BaseModel):
"""更新测试任务时可修改的字段。"""
__test__ = False
name: str | None = Field(default=None, max_length=120)
description: str | None = None
config: dict[str, Any] | None = None
status: PromptTestTaskStatus | None = None
class PromptTestTaskRead(PromptTestTaskBase):
"""返回给前端的测试任务结构。"""
__test__ = False
id: int
status: PromptTestTaskStatus
created_at: datetime
updated_at: datetime
units: list["PromptTestUnitRead"] | None = None
model_config = ConfigDict(from_attributes=True)
class PromptTestUnitBase(BaseModel):
"""最小测试单元基础字段。"""
__test__ = False
name: str = Field(..., max_length=120)
description: str | None = None
model_name: str = Field(..., max_length=100)
llm_provider_id: int | None = None
temperature: float = Field(default=0.7, ge=0.0, le=2.0)
top_p: float | None = Field(default=None, ge=0.0, le=1.0)
rounds: int = Field(default=1, ge=1, le=100)
prompt_template: str | None = None
variables: dict[str, Any] | list[Any] | None = None
parameters: dict[str, Any] | None = None
expectations: dict[str, Any] | None = None
tags: list[str] | None = None
extra: dict[str, Any] | None = None
class PromptTestUnitCreate(PromptTestUnitBase):
"""创建测试单元需要提供所属测试任务 ID。"""
__test__ = False
task_id: int | None = Field(default=None, ge=1)
prompt_version_id: int | None = None
class PromptTestUnitUpdate(BaseModel):
"""更新测试单元时允许修改的字段。"""
__test__ = False
description: str | None = None
temperature: float | None = Field(default=None, ge=0.0, le=2.0)
top_p: float | None = Field(default=None, ge=0.0, le=1.0)
rounds: int | None = Field(default=None, ge=1, le=100)
prompt_template: str | None = None
variables: dict[str, Any] | list[Any] | None = None
parameters: dict[str, Any] | None = None
expectations: dict[str, Any] | None = None
tags: list[str] | None = None
extra: dict[str, Any] | None = None
class PromptTestUnitRead(PromptTestUnitBase):
"""返回给前端的最小测试单元结构。"""
__test__ = False
id: int
task_id: int
prompt_version_id: int | None = None
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
class PromptTestExperimentBase(BaseModel):
"""新建实验需要提供的基础字段。"""
__test__ = False
unit_id: int | None = Field(default=None, ge=1)
batch_id: str | None = Field(default=None, max_length=64)
sequence: int | None = Field(default=None, ge=1, le=1000)
class PromptTestExperimentCreate(PromptTestExperimentBase):
"""创建实验的结构体,可选自动执行。"""
__test__ = False
auto_execute: bool = False
class PromptTestExperimentRead(BaseModel):
"""返回给前端的实验结果结构。"""
__test__ = False
id: int
unit_id: int
batch_id: str | None = None
sequence: int
status: PromptTestExperimentStatus
outputs: list[dict[str, Any]] | None = None
metrics: dict[str, Any] | None = None
error: str | None = None
started_at: datetime | None = None
finished_at: datetime | None = None
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
PromptTestTaskRead.model_rebuild()
PromptTestUnitRead.model_rebuild()
PromptTestExperimentRead.model_rebuild()
__all__ = [
"PromptTestTaskCreate",
"PromptTestTaskUpdate",
"PromptTestTaskRead",
"PromptTestUnitCreate",
"PromptTestUnitUpdate",
"PromptTestUnitRead",
"PromptTestExperimentCreate",
"PromptTestExperimentRead",
]
```
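`PromptTestTaskCreate` can carry its initial units inline, so a task and its minimal test units can be created in one request. A small sketch of the payload shape (the model name and case values are illustrative only):
```py
from app.schemas.prompt_test import PromptTestTaskCreate, PromptTestUnitCreate

task = PromptTestTaskCreate(
    name="JSON extraction regression",
    auto_execute=False,
    units=[
        PromptTestUnitCreate(
            name="baseline unit",
            model_name="gpt-4o-mini",  # illustrative model name
            rounds=3,
            variables={"cases": [{"topic": "refund policy"}, {"topic": "shipping delay"}]},
        )
    ],
)
print(task.model_dump(exclude_none=True))
```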
## /app/schemas/result.py
```py path="/app/schemas/result.py"
from datetime import datetime
from pydantic import BaseModel, ConfigDict, Field
from app.schemas.metric import MetricRead
class ResultBase(BaseModel):
run_index: int = Field(ge=0)
output: str
parsed_output: dict | None = None
tokens_used: int | None = Field(default=None, ge=0)
latency_ms: int | None = Field(default=None, ge=0)
class ResultCreate(ResultBase):
test_run_id: int
class ResultRead(ResultBase):
id: int
test_run_id: int
created_at: datetime
metrics: list[MetricRead] = []
model_config = ConfigDict(from_attributes=True)
```
## /app/schemas/settings.py
```py path="/app/schemas/settings.py"
from __future__ import annotations
from datetime import datetime
from typing import Annotated
from pydantic import BaseModel, Field
TimeoutSecondsType = Annotated[int, Field(ge=1, le=600)]
class TestingTimeoutsBase(BaseModel):
quick_test_timeout: TimeoutSecondsType = Field(
...,
description="快速测试的超时时间,单位:秒",
)
test_task_timeout: TimeoutSecondsType = Field(
...,
description="测试任务的超时时间,单位:秒",
)
class TestingTimeoutsUpdate(TestingTimeoutsBase):
"""更新快速测试/测试任务超时配置的请求体。"""
class TestingTimeoutsRead(TestingTimeoutsBase):
"""返回快速测试/测试任务超时配置的响应体。"""
updated_at: datetime | None = Field(
default=None,
description="配置最近更新时间,若尚未设置则为空",
)
__all__ = [
"TestingTimeoutsRead",
"TestingTimeoutsUpdate",
]
```
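`TimeoutSecondsType` constrains both timeouts to 1-600 seconds at the schema layer, so out-of-range values are rejected before the service code ever sees them. For example:
```py
from pydantic import ValidationError

from app.schemas.settings import TestingTimeoutsUpdate

# Values inside the 1-600 second window validate cleanly.
TestingTimeoutsUpdate(quick_test_timeout=30, test_task_timeout=120)

# Values outside the Annotated[int, Field(ge=1, le=600)] bounds are rejected.
try:
    TestingTimeoutsUpdate(quick_test_timeout=0, test_task_timeout=120)
except ValidationError as exc:
    print(exc.error_count())  # 1
```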
## /app/schemas/test_run.py
```py path="/app/schemas/test_run.py"
from datetime import datetime
from pydantic import BaseModel, ConfigDict, Field
from app.models.test_run import TestRunStatus
from app.schemas.prompt import PromptRead, PromptVersionRead
from app.schemas.result import ResultRead
class TestRunBase(BaseModel):
__test__ = False
model_name: str = Field(..., max_length=100)
model_version: str | None = Field(default=None, max_length=50)
temperature: float = Field(default=0.7, ge=0.0, le=2.0)
top_p: float = Field(default=1.0, ge=0.0, le=1.0)
repetitions: int = Field(default=1, ge=1, le=50)
schema_data: dict | None = Field(
default=None, alias="schema", serialization_alias="schema"
)
notes: str | None = None
batch_id: str | None = Field(default=None, max_length=64)
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
class TestRunCreate(TestRunBase):
__test__ = False
prompt_version_id: int = Field(..., ge=1)
class TestRunUpdate(BaseModel):
__test__ = False
model_name: str | None = Field(default=None, max_length=100)
model_version: str | None = Field(default=None, max_length=50)
temperature: float | None = Field(default=None, ge=0.0, le=2.0)
top_p: float | None = Field(default=None, ge=0.0, le=1.0)
repetitions: int | None = Field(default=None, ge=1, le=50)
schema_data: dict | None = Field(
default=None, alias="schema", serialization_alias="schema"
)
notes: str | None = None
status: TestRunStatus | None = None
batch_id: str | None = Field(default=None, max_length=64)
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
class TestRunRead(TestRunBase):
__test__ = False
id: int
prompt_version_id: int
status: TestRunStatus
failure_reason: str | None = None
created_at: datetime
updated_at: datetime
prompt_version: PromptVersionRead | None = None
prompt: PromptRead | None = None
results: list[ResultRead] = []
model_config = ConfigDict(
from_attributes=True, populate_by_name=True, serialize_by_alias=True
)
```
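`schema_data` is exposed to clients under the `schema` key via the field alias, while `populate_by_name` keeps the attribute name usable internally. A quick sketch of the round trip (IDs and model name are illustrative):
```py
from app.schemas.test_run import TestRunCreate

payload = {
    "prompt_version_id": 1,
    "model_name": "gpt-4o-mini",  # illustrative model name
    "schema": {"conversation": [{"role": "system", "content": "Answer tersely."}]},
}
run = TestRunCreate.model_validate(payload)
print(run.schema_data)                          # populated from the "schema" alias
print(run.model_dump(by_alias=True)["schema"])  # serialized back under "schema"
```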
## /app/schemas/usage.py
```py path="/app/schemas/usage.py"
from datetime import date
from pydantic import BaseModel, ConfigDict, Field
class UsageOverview(BaseModel):
total_tokens: int = Field(default=0, ge=0)
input_tokens: int = Field(default=0, ge=0)
output_tokens: int = Field(default=0, ge=0)
call_count: int = Field(default=0, ge=0)
class UsageModelSummary(BaseModel):
model_key: str
model_name: str
provider: str
total_tokens: int = Field(default=0, ge=0)
input_tokens: int = Field(default=0, ge=0)
output_tokens: int = Field(default=0, ge=0)
call_count: int = Field(default=0, ge=0)
model_config = ConfigDict(from_attributes=True)
class UsageTimeseriesPoint(BaseModel):
date: date
input_tokens: int = Field(default=0, ge=0)
output_tokens: int = Field(default=0, ge=0)
call_count: int = Field(default=0, ge=0)
model_config = ConfigDict(from_attributes=True)
__all__ = ["UsageOverview", "UsageModelSummary", "UsageTimeseriesPoint"]
```
## /app/services/__init__.py
```py path="/app/services/__init__.py"
```
## /app/services/llm_usage.py
```py path="/app/services/llm_usage.py"
from __future__ import annotations
from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload
from app.models.usage import LLMUsageLog
def list_quick_test_usage_logs(
db: Session, *, limit: int = 20, offset: int = 0
) -> list[LLMUsageLog]:
stmt = (
select(LLMUsageLog)
.options(selectinload(LLMUsageLog.provider))
.where(LLMUsageLog.source == "quick_test")
.order_by(LLMUsageLog.created_at.desc())
.offset(offset)
.limit(limit)
)
return list(db.scalars(stmt))
__all__ = ["list_quick_test_usage_logs"]
```
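`list_quick_test_usage_logs` is a thin pagination helper over `LLMUsageLog` that eagerly loads the provider relation. A usage sketch, assuming the project exposes its session factory as `SessionLocal` in `app.db.session` (that file is not shown here):
```py
from app.db.session import SessionLocal  # assumed session factory name
from app.services.llm_usage import list_quick_test_usage_logs

with SessionLocal() as db:
    # Second page of 20 quick-test records, newest first.
    for log in list_quick_test_usage_logs(db, limit=20, offset=20):
        print(log.model_name, log.total_tokens, log.provider.provider_name)
```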
## /app/services/prompt_test_engine.py
```py path="/app/services/prompt_test_engine.py"
from __future__ import annotations
import logging
import random
import statistics
import time
from collections.abc import Callable, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any
import httpx
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.llm_provider_registry import get_provider_defaults
from app.models.llm_provider import LLMModel, LLMProvider
from app.models.prompt_test import (
PromptTestExperiment,
PromptTestExperimentStatus,
PromptTestUnit,
)
from app.models.usage import LLMUsageLog
from app.services.test_run import (
REQUEST_SLEEP_RANGE,
_format_error_detail,
_try_parse_json,
)
from app.services.system_settings import (
DEFAULT_TEST_TASK_TIMEOUT,
get_testing_timeout_config,
)
logger = logging.getLogger("promptworks.prompt_test_engine")
_KNOWN_PARAMETER_KEYS = {
"max_tokens",
"presence_penalty",
"frequency_penalty",
"response_format",
"stop",
"logit_bias",
"top_k",
"seed",
"user",
"n",
"parallel_tool_calls",
"tool_choice",
"tools",
"metadata",
}
_NESTED_PARAMETER_KEYS = {"llm_parameters", "model_parameters", "parameters"}
class PromptTestExecutionError(Exception):
"""执行 Prompt 测试实验时抛出的业务异常。"""
__test__ = False
def __init__(self, message: str, *, status_code: int | None = None) -> None:
super().__init__(message)
self.status_code = status_code
def execute_prompt_test_experiment(
db: Session,
experiment: PromptTestExperiment,
progress_callback: Callable[[int], None] | None = None,
) -> PromptTestExperiment:
"""执行单个最小测试单元的实验,并存储结果。"""
if experiment.status not in {
PromptTestExperimentStatus.PENDING,
PromptTestExperimentStatus.RUNNING,
}:
return experiment
unit = experiment.unit
if unit is None:
raise PromptTestExecutionError("实验缺少关联的测试单元。")
provider, model = _resolve_provider_and_model(db, unit)
prompt_snapshot = _resolve_prompt_snapshot(unit)
parameters = _collect_parameters(unit)
context_template = unit.variables or {}
experiment.status = PromptTestExperimentStatus.RUNNING
experiment.started_at = datetime.now(UTC)
experiment.error = None
db.flush()
timeout_config = get_testing_timeout_config(db)
request_timeout = float(
timeout_config.test_task_timeout or DEFAULT_TEST_TASK_TIMEOUT
)
run_records: list[dict[str, Any]] = []
latencies: list[int] = []
token_totals: list[int] = []
json_success = 0
rounds_per_case = max(1, int(unit.rounds or 1))
case_count = _count_variable_cases(context_template)
total_runs = rounds_per_case * max(case_count, 1)
for run_index in range(1, total_runs + 1):
try:
run_record = _execute_single_round(
provider=provider,
model=model,
unit=unit,
prompt_snapshot=prompt_snapshot,
base_parameters=parameters,
context_template=context_template,
run_index=run_index,
request_timeout=request_timeout,
)
except PromptTestExecutionError as exc:
experiment.status = PromptTestExperimentStatus.FAILED
experiment.error = str(exc)
experiment.finished_at = datetime.now(UTC)
db.flush()
return experiment
run_records.append(run_record)
if progress_callback is not None:
try:
progress_callback(1)
            except Exception:  # pragma: no cover - defensive fallback
logger.exception("更新 Prompt 测试进度时出现异常")
usage_log = _build_usage_log(
provider=provider,
model=model,
unit=unit,
run_record=run_record,
)
db.add(usage_log)
latency = run_record.get("latency_ms")
if isinstance(latency, (int, float)):
latencies.append(int(latency))
tokens = run_record.get("total_tokens")
if isinstance(tokens, (int, float)):
token_totals.append(int(tokens))
if run_record.get("parsed_output") is not None:
json_success += 1
experiment.outputs = run_records
experiment.metrics = _aggregate_metrics(
latencies=latencies,
tokens=token_totals,
total_rounds=len(run_records),
json_success=json_success,
)
experiment.status = PromptTestExperimentStatus.COMPLETED
experiment.finished_at = datetime.now(UTC)
db.flush()
return experiment
def _resolve_provider_and_model(
db: Session, unit: PromptTestUnit
) -> tuple[LLMProvider, LLMModel | None]:
provider: LLMProvider | None = None
model: LLMModel | None = None
if isinstance(unit.llm_provider_id, int):
provider = db.get(LLMProvider, unit.llm_provider_id)
extra_data = unit.extra if isinstance(unit.extra, Mapping) else {}
provider_key = extra_data.get("provider_key")
if provider is None and isinstance(provider_key, str):
provider = db.scalar(
select(LLMProvider).where(LLMProvider.provider_key == provider_key)
)
model_id = extra_data.get("llm_model_id")
if provider and isinstance(model_id, int):
model = db.get(LLMModel, model_id)
if provider is None:
stmt = (
select(LLMProvider, LLMModel)
.join(LLMModel, LLMModel.provider_id == LLMProvider.id)
.where(LLMModel.name == unit.model_name)
)
record = db.execute(stmt).first()
if record:
provider, model = record
if provider is None:
provider = db.scalar(
select(LLMProvider).where(LLMProvider.provider_name == unit.model_name)
)
if provider is None:
raise PromptTestExecutionError("未找到合适的模型提供者配置。")
if model is None:
model = db.scalar(
select(LLMModel).where(
LLMModel.provider_id == provider.id,
LLMModel.name == unit.model_name,
)
)
return provider, model
def _resolve_prompt_snapshot(unit: PromptTestUnit) -> str:
if unit.prompt_template:
return str(unit.prompt_template)
if unit.prompt_version and unit.prompt_version.content:
return unit.prompt_version.content
return ""
def _collect_parameters(unit: PromptTestUnit) -> dict[str, Any]:
params: dict[str, Any] = {"temperature": unit.temperature}
if unit.top_p is not None:
params["top_p"] = unit.top_p
raw_parameters = unit.parameters if isinstance(unit.parameters, Mapping) else {}
for key in _NESTED_PARAMETER_KEYS:
nested = raw_parameters.get(key)
if isinstance(nested, Mapping):
params.update(dict(nested))
for key, value in raw_parameters.items():
if key in {"conversation", "messages"}:
continue
if key in _KNOWN_PARAMETER_KEYS and value is not None:
params[key] = value
return params
def _execute_single_round(
*,
provider: LLMProvider,
model: LLMModel | None,
unit: PromptTestUnit,
prompt_snapshot: str,
base_parameters: Mapping[str, Any],
context_template: Mapping[str, Any] | Sequence[Any],
run_index: int,
request_timeout: float,
) -> dict[str, Any]:
context = _resolve_context(context_template, run_index)
messages = _build_messages(unit, prompt_snapshot, context, run_index)
payload = {
"model": model.name if model else unit.model_name,
"messages": messages,
**base_parameters,
}
request_parameters = {
key: value for key, value in payload.items() if key not in {"model", "messages"}
}
base_url = _resolve_base_url(provider)
headers = {
"Authorization": f"Bearer {provider.api_key}",
"Content-Type": "application/json",
}
try:
sleep_lower, sleep_upper = REQUEST_SLEEP_RANGE
if sleep_upper > 0:
jitter = random.uniform(sleep_lower, sleep_upper)
if jitter > 0:
time.sleep(jitter)
    except Exception:  # pragma: no cover - fault-tolerance fallback
pass
start_time = time.perf_counter()
try:
response = httpx.post(
f"{base_url}/chat/completions",
headers=headers,
json=payload,
timeout=request_timeout,
)
    except httpx.HTTPError as exc:  # pragma: no cover - network failure fallback
raise PromptTestExecutionError(f"调用外部 LLM 失败: {exc}") from exc
if response.status_code >= 400:
try:
error_payload = response.json()
except ValueError:
error_payload = {"message": response.text}
detail = _format_error_detail(error_payload)
raise PromptTestExecutionError(
f"LLM 请求失败 (HTTP {response.status_code}): {detail}",
status_code=response.status_code,
)
try:
payload_obj = response.json()
    except ValueError as exc:  # pragma: no cover - response parsing failure
raise PromptTestExecutionError("LLM 响应解析失败。") from exc
elapsed = response.elapsed.total_seconds() * 1000 if response.elapsed else None
latency_ms = (
int(elapsed)
if elapsed is not None
else int((time.perf_counter() - start_time) * 1000)
)
latency_ms = max(latency_ms, 0)
output_text = _extract_output(payload_obj)
parsed_output = _try_parse_json(output_text)
usage = (
payload_obj.get("usage")
if isinstance(payload_obj.get("usage"), Mapping)
else {}
)
prompt_tokens = _safe_int(usage.get("prompt_tokens"))
completion_tokens = _safe_int(usage.get("completion_tokens"))
total_tokens = _safe_int(usage.get("total_tokens"))
if (
total_tokens is None
and prompt_tokens is not None
and completion_tokens is not None
):
total_tokens = prompt_tokens + completion_tokens
variables = _extract_variables(context)
return {
"run_index": run_index,
"messages": messages,
"parameters": request_parameters or None,
"variables": variables,
"output_text": output_text,
"parsed_output": parsed_output,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"latency_ms": latency_ms,
"raw_response": payload_obj,
}
def _resolve_context(
template: Mapping[str, Any] | Sequence[Any], run_index: int
) -> dict[str, Any]:
context: dict[str, Any] = {"run_index": run_index}
if isinstance(template, Mapping):
defaults = template.get("defaults")
if isinstance(defaults, Mapping):
context.update(
{k: v for k, v in defaults.items() if not isinstance(k, (int, float))}
)
cases = template.get("cases")
if isinstance(cases, Sequence) and cases:
selected = cases[(run_index - 1) % len(cases)]
if isinstance(selected, Mapping):
context.update(selected)
else:
context["case"] = selected
for key, value in template.items():
if key in {"defaults", "cases"}:
continue
if isinstance(value, Mapping):
continue
if isinstance(value, Sequence) and not isinstance(
value, (str, bytes, bytearray)
):
continue
context.setdefault(key, value)
elif (
isinstance(template, Sequence)
and not isinstance(template, (str, bytes, bytearray))
and template
):
selected = template[(run_index - 1) % len(template)]
if isinstance(selected, Mapping):
context.update(selected)
else:
context["value"] = selected
return context
def _count_variable_cases(template: Mapping[str, Any] | Sequence[Any] | None) -> int:
if isinstance(template, Mapping):
cases = template.get("cases")
if isinstance(cases, Sequence) and not isinstance(
cases, (str, bytes, bytearray)
):
return len(cases)
return 1 if template else 0
if isinstance(template, Sequence) and not isinstance(
template, (str, bytes, bytearray)
):
return len(template)
return 0
def _extract_variables(context: Mapping[str, Any] | None) -> dict[str, Any] | None:
if not isinstance(context, Mapping):
return None
sanitized: dict[str, Any] = {}
for key, value in context.items():
if key == "run_index":
continue
sanitized[key] = value
return sanitized or None
def _build_messages(
unit: PromptTestUnit,
prompt_snapshot: str,
context: Mapping[str, Any],
run_index: int,
) -> list[dict[str, Any]]:
conversation: Any = None
if isinstance(unit.parameters, Mapping):
conversation = unit.parameters.get("conversation") or unit.parameters.get(
"messages"
)
messages: list[dict[str, Any]] = []
if isinstance(conversation, Sequence):
for item in conversation:
if not isinstance(item, Mapping):
continue
role = str(item.get("role", "")).strip() or "user"
content = _format_text(item.get("content"), context, run_index)
if content is None:
continue
messages.append({"role": role, "content": content})
snapshot_message: dict[str, Any] | None = None
prompt_message = _format_text(prompt_snapshot, context, run_index)
if prompt_message:
if not messages:
snapshot_message = {"role": "user", "content": prompt_message}
messages.append(snapshot_message)
elif not any(msg["role"] == "system" for msg in messages):
snapshot_message = {"role": "user", "content": prompt_message}
messages.insert(0, snapshot_message)
user_template = unit.prompt_template or context.get("user_prompt")
user_message = _format_text(user_template, context, run_index)
has_user = any(
msg.get("role") == "user"
and (snapshot_message is None or msg is not snapshot_message)
for msg in messages
)
if user_message and not has_user:
messages.append({"role": "user", "content": user_message})
if not messages:
messages.append(
{
"role": "user",
"content": f"请生成第 {run_index} 次响应。",
}
)
return messages
def _format_text(
template: Any, context: Mapping[str, Any], run_index: int
) -> str | None:
if template is None:
return None
if not isinstance(template, str):
return str(template)
try:
replaced = template.replace("{{run_index}}", str(run_index))
return replaced.format(**context)
except Exception:
return template.replace("{{run_index}}", str(run_index))
def _extract_output(payload_obj: Mapping[str, Any]) -> str:
choices = payload_obj.get("choices")
if isinstance(choices, Sequence) and choices:
first = choices[0]
if isinstance(first, Mapping):
message = first.get("message")
if isinstance(message, Mapping) and isinstance(message.get("content"), str):
return message["content"]
text_value = first.get("text")
if isinstance(text_value, str):
return text_value
return ""
def _safe_int(value: Any) -> int | None:
if isinstance(value, (int, float)):
return int(value)
try:
if isinstance(value, str) and value.strip():
return int(float(value))
    except Exception:  # pragma: no cover - fault tolerance
return None
return None
def _aggregate_metrics(
*,
latencies: Sequence[int],
tokens: Sequence[int],
total_rounds: int,
json_success: int,
) -> dict[str, Any]:
metrics: dict[str, Any] = {
"rounds": total_rounds,
}
if latencies:
metrics["avg_latency_ms"] = statistics.fmean(latencies)
metrics["max_latency_ms"] = max(latencies)
metrics["min_latency_ms"] = min(latencies)
if tokens:
metrics["avg_total_tokens"] = statistics.fmean(tokens)
metrics["max_total_tokens"] = max(tokens)
metrics["min_total_tokens"] = min(tokens)
if total_rounds:
metrics["json_success_rate"] = round(json_success / total_rounds, 4)
return metrics
def _build_usage_log(
*,
provider: LLMProvider,
model: LLMModel | None,
unit: PromptTestUnit,
run_record: Mapping[str, Any],
) -> LLMUsageLog:
prompt_version = unit.prompt_version
prompt_id: int | None = None
prompt_version_id = unit.prompt_version_id
if (
prompt_version is not None
and getattr(prompt_version, "prompt_id", None) is not None
):
prompt_id = prompt_version.prompt_id
else:
task = getattr(unit, "task", None)
task_prompt_version = getattr(task, "prompt_version", None) if task else None
if prompt_version_id is None and task is not None:
prompt_version_id = getattr(task, "prompt_version_id", None)
if (
task_prompt_version is not None
and getattr(task_prompt_version, "prompt_id", None) is not None
):
prompt_id = task_prompt_version.prompt_id
def _safe_int_value(key: str) -> int | None:
value = run_record.get(key)
if isinstance(value, (int, float)):
return int(value)
return None
latency_value = _safe_int_value("latency_ms")
return LLMUsageLog(
provider_id=provider.id,
model_id=model.id if model else None,
model_name=model.name if model else unit.model_name,
source="prompt_test",
prompt_id=prompt_id,
prompt_version_id=prompt_version_id,
messages=run_record.get("messages"),
parameters=run_record.get("parameters"),
response_text=run_record.get("output_text"),
temperature=unit.temperature,
latency_ms=latency_value,
prompt_tokens=_safe_int_value("prompt_tokens"),
completion_tokens=_safe_int_value("completion_tokens"),
total_tokens=_safe_int_value("total_tokens"),
)
def _resolve_base_url(provider: LLMProvider) -> str:
defaults = get_provider_defaults(provider.provider_key)
base_url = provider.base_url or (defaults.base_url if defaults else None)
if not base_url:
raise PromptTestExecutionError("模型提供者缺少基础 URL 配置。")
return base_url.rstrip("/")
__all__ = ["execute_prompt_test_experiment", "PromptTestExecutionError"]
```
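The `variables` payload of a unit supports a `defaults` mapping plus a `cases` list; each run merges the defaults with one case, selected round-robin by `run_index`, and the case count multiplies `rounds` to give the total number of runs. A small sketch exercising the module's internal helpers directly:
```py
from app.services.prompt_test_engine import _count_variable_cases, _resolve_context

template = {
    "defaults": {"tone": "formal"},
    "cases": [{"topic": "refund policy"}, {"topic": "shipping delay"}],
}

print(_count_variable_cases(template))  # 2 -> total_runs = rounds * 2

# run_index walks the cases round-robin while keeping the defaults merged in.
for run_index in (1, 2, 3):
    print(_resolve_context(template, run_index))
```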
## /app/services/system_settings.py
```py path="/app/services/system_settings.py"
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from sqlalchemy.orm import Session
from app.models.system_setting import SystemSetting
TESTING_TIMEOUT_SETTING_KEY = "testing_timeout"
DEFAULT_QUICK_TEST_TIMEOUT = 30.0
DEFAULT_TEST_TASK_TIMEOUT = 30.0
@dataclass(slots=True)
class TestingTimeoutConfig:
"""封装快速测试与测试任务的超时配置。"""
quick_test_timeout: float
test_task_timeout: float
updated_at: datetime | None = None
def _coerce_timeout(value: Any, default: float) -> float:
"""将任意输入转换为合法的超时秒数。"""
numeric: float | None = None
if isinstance(value, (int, float)):
numeric = float(value)
elif isinstance(value, str):
try:
numeric = float(value.strip())
except ValueError:
numeric = None
if numeric is None or not numeric > 0:
return default
return float(numeric)
def get_testing_timeout_config(db: Session) -> TestingTimeoutConfig:
"""读取快速测试与测试任务的超时配置,若未设置则返回默认值。"""
record = db.get(SystemSetting, TESTING_TIMEOUT_SETTING_KEY)
if record is None:
return TestingTimeoutConfig(
quick_test_timeout=DEFAULT_QUICK_TEST_TIMEOUT,
test_task_timeout=DEFAULT_TEST_TASK_TIMEOUT,
updated_at=None,
)
value = record.value if isinstance(record.value, Mapping) else {}
quick_timeout = _coerce_timeout(
value.get("quick_test_timeout", DEFAULT_QUICK_TEST_TIMEOUT),
DEFAULT_QUICK_TEST_TIMEOUT,
)
task_timeout = _coerce_timeout(
value.get("test_task_timeout", DEFAULT_TEST_TASK_TIMEOUT),
DEFAULT_TEST_TASK_TIMEOUT,
)
return TestingTimeoutConfig(
quick_test_timeout=quick_timeout,
test_task_timeout=task_timeout,
updated_at=record.updated_at,
)
def update_testing_timeout_config(
db: Session,
*,
quick_test_timeout: float,
test_task_timeout: float,
) -> TestingTimeoutConfig:
"""更新快速测试与测试任务的超时配置。"""
sanitized_quick = _coerce_timeout(quick_test_timeout, DEFAULT_QUICK_TEST_TIMEOUT)
sanitized_task = _coerce_timeout(test_task_timeout, DEFAULT_TEST_TASK_TIMEOUT)
payload = {
"quick_test_timeout": sanitized_quick,
"test_task_timeout": sanitized_task,
}
record = db.get(SystemSetting, TESTING_TIMEOUT_SETTING_KEY)
if record is None:
record = SystemSetting(
key=TESTING_TIMEOUT_SETTING_KEY,
value=payload,
description="快速测试与测试任务的超时时间(秒)",
)
db.add(record)
else:
record.value = payload
db.flush()
db.commit()
db.refresh(record)
return TestingTimeoutConfig(
quick_test_timeout=sanitized_quick,
test_task_timeout=sanitized_task,
updated_at=record.updated_at,
)
__all__ = [
"TestingTimeoutConfig",
"DEFAULT_QUICK_TEST_TIMEOUT",
"DEFAULT_TEST_TASK_TIMEOUT",
"get_testing_timeout_config",
"update_testing_timeout_config",
]
```
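Both the read and write paths funnel through `_coerce_timeout`, so non-numeric or non-positive values silently fall back to the 30-second defaults rather than raising. A round-trip sketch, again assuming a `SessionLocal` session factory in `app.db.session`:
```py
from app.db.session import SessionLocal  # assumed session factory name
from app.services.system_settings import (
    get_testing_timeout_config,
    update_testing_timeout_config,
)

with SessionLocal() as db:
    # The negative value is coerced back to DEFAULT_TEST_TASK_TIMEOUT (30.0).
    config = update_testing_timeout_config(
        db, quick_test_timeout=45, test_task_timeout=-1
    )
    print(config.quick_test_timeout, config.test_task_timeout)  # 45.0 30.0
    print(get_testing_timeout_config(db).test_task_timeout)     # 30.0
```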
## /app/services/test_run.py
```py path="/app/services/test_run.py"
from __future__ import annotations
import json
import random
import time
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import Any
import httpx
from sqlalchemy import select
from sqlalchemy.orm import Session
from starlette import status
from app.core.llm_provider_registry import get_provider_defaults
from app.models.llm_provider import LLMModel, LLMProvider
from app.models.result import Result
from app.models.test_run import TestRun, TestRunStatus
from app.models.usage import LLMUsageLog
from app.services.system_settings import (
DEFAULT_TEST_TASK_TIMEOUT,
get_testing_timeout_config,
)
DEFAULT_TEST_TIMEOUT = DEFAULT_TEST_TASK_TIMEOUT
DEFAULT_CONCURRENCY_LIMIT = 5
REQUEST_SLEEP_RANGE = (0.05, 0.2)
_KNOWN_PARAMETER_KEYS = {
"max_tokens",
"presence_penalty",
"frequency_penalty",
"response_format",
"stop",
"logit_bias",
"top_k",
"seed",
"user",
"n",
"parallel_tool_calls",
"tool_choice",
"tools",
"metadata",
}
_NESTED_PARAMETER_KEYS = {"llm_parameters", "model_parameters", "parameters"}
@dataclass(frozen=True)
class RunRequestContext:
test_run_id: int
model_name: str
prompt_id: int | None
prompt_version_id: int | None
timeout_seconds: float
class TestRunExecutionError(Exception):
"""执行测试任务过程中出现的业务异常。"""
__test__ = False
def __init__(
self, message: str, *, status_code: int = status.HTTP_400_BAD_REQUEST
) -> None:
super().__init__(message)
self.status_code = status_code
def execute_test_run(db: Session, test_run: TestRun) -> TestRun:
"""调用外部 LLM 完成测试任务,并记录结果与用量。"""
if test_run.status not in {TestRunStatus.PENDING, TestRunStatus.RUNNING}:
return test_run
provider, model = _resolve_provider_and_model(db, test_run)
prompt_version = test_run.prompt_version
if not prompt_version:
raise TestRunExecutionError("测试任务缺少关联的 Prompt 版本。")
prompt_snapshot = prompt_version.content
schema_data = _ensure_mapping(test_run.schema)
schema_data.pop("last_error", None)
schema_data.pop("last_error_status", None)
schema_data.setdefault("prompt_snapshot", prompt_snapshot)
schema_data.setdefault("llm_provider_id", provider.id)
schema_data.setdefault("llm_provider_name", provider.provider_name)
if model:
schema_data.setdefault("llm_model_id", model.id)
test_run.schema = schema_data
parameters_template = _build_parameters(test_run, schema_data)
base_url = _resolve_base_url(provider)
headers = {
"Authorization": f"Bearer {provider.api_key}",
"Content-Type": "application/json",
}
test_run.status = TestRunStatus.RUNNING
db.flush()
timeout_config = get_testing_timeout_config(db)
request_timeout = float(timeout_config.test_task_timeout or DEFAULT_TEST_TIMEOUT)
context = RunRequestContext(
test_run_id=test_run.id,
model_name=test_run.model_name,
prompt_id=prompt_version.prompt_id,
prompt_version_id=test_run.prompt_version_id,
timeout_seconds=request_timeout,
)
concurrency_limit = DEFAULT_CONCURRENCY_LIMIT
if model and isinstance(model.concurrency_limit, int):
concurrency_limit = max(1, model.concurrency_limit)
def _execute_single(run_index: int) -> tuple[int, Result, LLMUsageLog]:
messages = _build_messages(schema_data, prompt_snapshot, run_index)
payload: dict[str, Any] = dict(parameters_template)
payload["model"] = model.name if model else test_run.model_name
payload["messages"] = messages
result, usage_log = _invoke_llm_once(
provider=provider,
model=model,
base_url=base_url,
headers=headers,
payload=payload,
context=context,
)
result.test_run_id = context.test_run_id
result.run_index = run_index
return run_index, result, usage_log
run_indices = range(1, test_run.repetitions + 1)
error_message: str | None = None
error_status_code: int | None = None
worker_count = max(1, min(concurrency_limit, test_run.repetitions))
with ThreadPoolExecutor(max_workers=worker_count) as executor:
future_map = {
executor.submit(_execute_single, index): index for index in run_indices
}
for future in as_completed(future_map):
try:
_, result_obj, usage_obj = future.result()
except TestRunExecutionError as exc:
if error_message is None:
error_message = str(exc)
error_status_code = getattr(exc, "status_code", None)
            except Exception as exc:  # pragma: no cover - defensive
if error_message is None:
error_message = f"执行测试任务失败: {exc}"
error_status_code = status.HTTP_502_BAD_GATEWAY
else:
_persist_run_artifacts(db, result_obj, usage_obj)
if error_message:
test_run.status = TestRunStatus.FAILED
test_run.last_error = error_message
if error_status_code is not None:
current_schema = _ensure_mapping(test_run.schema)
current_schema["last_error_status"] = error_status_code
test_run.schema = current_schema
else:
test_run.status = TestRunStatus.COMPLETED
test_run.last_error = None
current_schema = _ensure_mapping(test_run.schema)
if current_schema:
current_schema.pop("last_error_status", None)
test_run.schema = current_schema or None
db.flush()
return test_run
def ensure_completed(db: Session, runs: Sequence[TestRun]) -> None:
for run in runs:
execute_test_run(db, run)
db.flush()
def _resolve_provider_and_model(
db: Session, test_run: TestRun
) -> tuple[LLMProvider, LLMModel | None]:
schema_data = _ensure_mapping(test_run.schema)
provider: LLMProvider | None = None
model: LLMModel | None = None
provider_id = schema_data.get("llm_provider_id") or schema_data.get("provider_id")
model_id = schema_data.get("llm_model_id") or schema_data.get("model_id")
if isinstance(provider_id, int):
provider = db.get(LLMProvider, provider_id)
elif isinstance(provider_id, str) and provider_id.isdigit():
provider = db.get(LLMProvider, int(provider_id))
if provider and isinstance(model_id, int):
model = db.get(LLMModel, model_id)
elif provider and isinstance(model_id, str) and model_id.isdigit():
model = db.get(LLMModel, int(model_id))
if model and provider and model.provider_id != provider.id:
model = None
if provider is None and test_run.model_version:
provider = db.scalar(
select(LLMProvider).where(
LLMProvider.provider_name == test_run.model_version
)
)
if provider is None and isinstance(
schema_key := schema_data.get("provider_key"), str
):
provider = db.scalar(
select(LLMProvider).where(LLMProvider.provider_key == schema_key)
)
if provider is None:
stmt = (
select(LLMProvider, LLMModel)
.join(LLMModel, LLMModel.provider_id == LLMProvider.id)
.where(LLMModel.name == test_run.model_name)
)
record = db.execute(stmt).first()
if record:
provider, model = record
if provider is None:
raise TestRunExecutionError(
"未找到可用的模型提供者配置。", status_code=status.HTTP_404_NOT_FOUND
)
if model is None:
model = db.scalar(
select(LLMModel).where(
LLMModel.provider_id == provider.id,
LLMModel.name == test_run.model_name,
)
)
return provider, model
def _resolve_base_url(provider: LLMProvider) -> str:
defaults = get_provider_defaults(provider.provider_key)
base_url = provider.base_url or (defaults.base_url if defaults else None)
if not base_url:
raise TestRunExecutionError("模型提供者缺少基础 URL 配置。")
return base_url.rstrip("/")
def _ensure_mapping(raw: Any) -> dict[str, Any]:
if isinstance(raw, Mapping):
return dict(raw)
return {}
def _build_parameters(
test_run: TestRun, schema_data: Mapping[str, Any]
) -> dict[str, Any]:
parameters: dict[str, Any] = {
"temperature": test_run.temperature,
}
if test_run.top_p is not None:
parameters["top_p"] = test_run.top_p
for key in _NESTED_PARAMETER_KEYS:
nested = schema_data.get(key)
if isinstance(nested, Mapping):
parameters.update(dict(nested))
for key in _KNOWN_PARAMETER_KEYS:
if key in schema_data and schema_data[key] is not None:
parameters[key] = schema_data[key]
return parameters
def _render_content(content: Any, run_index: int) -> Any:
if isinstance(content, str):
return content.replace("{{run_index}}", str(run_index))
return content
def _build_messages(
schema_data: Mapping[str, Any], prompt_snapshot: str, run_index: int
) -> list[dict[str, Any]]:
raw_conversation = schema_data.get("conversation")
messages: list[dict[str, Any]] = []
if isinstance(raw_conversation, Sequence):
for item in raw_conversation:
if not isinstance(item, Mapping):
continue
role = str(item.get("role", "")).strip()
content = _render_content(item.get("content"), run_index)
if not role or content is None:
continue
messages.append({"role": role, "content": content})
snapshot_message: dict[str, Any] | None = None
if prompt_snapshot:
if not messages:
snapshot_message = {"role": "user", "content": prompt_snapshot}
messages.append(snapshot_message)
elif not any(message.get("role") == "system" for message in messages):
snapshot_message = {"role": "user", "content": prompt_snapshot}
messages.insert(0, snapshot_message)
has_user = any(
message.get("role") == "user"
and (snapshot_message is None or message is not snapshot_message)
for message in messages
)
if not has_user:
user_inputs = schema_data.get("inputs") or schema_data.get("test_inputs")
user_message: str
if isinstance(user_inputs, Sequence) and user_inputs:
index = (run_index - 1) % len(user_inputs)
candidate = user_inputs[index]
user_message = (
_render_content(candidate, run_index)
if isinstance(candidate, str)
else str(candidate)
)
else:
user_message = f"请根据提示生成第 {run_index} 次响应。"
messages.append({"role": "user", "content": user_message})
normalized: list[dict[str, Any]] = []
for message in messages:
role = str(message.get("role", "")).strip() or "user"
content = message.get("content")
normalized.append({"role": role, "content": content})
return normalized
def _invoke_llm_once(
*,
provider: LLMProvider,
model: LLMModel | None,
base_url: str,
headers: Mapping[str, str],
payload: dict[str, Any],
context: RunRequestContext,
) -> tuple[Result, LLMUsageLog]:
url = f"{base_url}/chat/completions"
start_time = time.perf_counter()
try:
sleep_lower, sleep_upper = REQUEST_SLEEP_RANGE
if sleep_upper > 0:
jitter = random.uniform(sleep_lower, sleep_upper)
if jitter > 0:
time.sleep(jitter)
    except Exception:  # pragma: no cover - fault tolerance
pass
try:
timeout_value = getattr(context, "timeout_seconds", DEFAULT_TEST_TIMEOUT)
response = httpx.post(
url, headers=dict(headers), json=payload, timeout=timeout_value
)
    except httpx.HTTPError as exc:  # pragma: no cover - network failure path
raise TestRunExecutionError(
f"调用外部 LLM 失败: {exc}", status_code=status.HTTP_502_BAD_GATEWAY
) from exc
if response.status_code >= 400:
try:
error_payload = response.json()
except ValueError:
error_payload = {"message": response.text}
detail_text = _format_error_detail(error_payload)
raise TestRunExecutionError(
f"LLM 请求失败 (HTTP {response.status_code}): {detail_text}",
status_code=response.status_code,
) from None
try:
payload_obj = response.json()
    except ValueError as exc:  # pragma: no cover - defensive
raise TestRunExecutionError(
"LLM 响应解析失败。", status_code=status.HTTP_502_BAD_GATEWAY
) from exc
choices = payload_obj.get("choices")
output_text = ""
if isinstance(choices, Sequence) and choices:
first = choices[0]
if isinstance(first, Mapping):
message = first.get("message")
if isinstance(message, Mapping) and isinstance(message.get("content"), str):
output_text = message["content"]
elif isinstance(first.get("text"), str):
output_text = str(first["text"])
output_text = str(output_text)
parsed_output = _try_parse_json(output_text)
usage = (
payload_obj.get("usage")
if isinstance(payload_obj.get("usage"), Mapping)
else {}
)
prompt_tokens = usage.get("prompt_tokens") if isinstance(usage, Mapping) else None
completion_tokens = (
usage.get("completion_tokens") if isinstance(usage, Mapping) else None
)
total_tokens = usage.get("total_tokens") if isinstance(usage, Mapping) else None
if total_tokens is None and any(
isinstance(value, (int, float)) for value in (prompt_tokens, completion_tokens)
):
total_tokens = 0
if isinstance(prompt_tokens, (int, float)):
total_tokens += int(prompt_tokens)
if isinstance(completion_tokens, (int, float)):
total_tokens += int(completion_tokens)
elapsed_delta = getattr(response, "elapsed", None)
if elapsed_delta is not None:
latency_ms = int(elapsed_delta.total_seconds() * 1000)
else:
latency_ms = int((time.perf_counter() - start_time) * 1000)
latency_ms = max(latency_ms, 0)
result = Result(
output=output_text,
parsed_output=parsed_output,
tokens_used=int(total_tokens)
if isinstance(total_tokens, (int, float))
else None,
latency_ms=latency_ms,
)
request_parameters = {
key: value for key, value in payload.items() if key not in {"model", "messages"}
}
usage_log = LLMUsageLog(
provider_id=provider.id,
model_id=model.id if model else None,
model_name=model.name if model else payload.get("model", context.model_name),
source="test_run",
prompt_id=context.prompt_id,
prompt_version_id=context.prompt_version_id,
messages=payload.get("messages"),
parameters=request_parameters or None,
response_text=output_text or None,
temperature=request_parameters.get("temperature"),
latency_ms=latency_ms,
prompt_tokens=int(prompt_tokens)
if isinstance(prompt_tokens, (int, float))
else None,
completion_tokens=int(completion_tokens)
if isinstance(completion_tokens, (int, float))
else None,
total_tokens=int(total_tokens)
if isinstance(total_tokens, (int, float))
else None,
)
return result, usage_log
def _format_error_detail(payload: Any) -> str:
if isinstance(payload, Mapping):
error_obj = payload.get("error")
if isinstance(error_obj, Mapping):
message_parts: list[str] = []
code = error_obj.get("code")
if isinstance(code, str) and code.strip():
message_parts.append(code.strip())
error_type = error_obj.get("type")
if isinstance(error_type, str) and error_type.strip():
message_parts.append(error_type.strip())
message = error_obj.get("message")
if isinstance(message, str) and message.strip():
prefix = " | ".join(message_parts)
return f"{prefix}: {message.strip()}" if prefix else message.strip()
message = payload.get("message")
if isinstance(message, str) and message.strip():
return message.strip()
try:
return json.dumps(payload, ensure_ascii=False)
        except Exception:  # pragma: no cover - fault tolerance
return str(payload)
return str(payload)
def _try_parse_json(text: str) -> Any:
try:
return json.loads(text)
except (TypeError, json.JSONDecodeError):
return None
def _persist_run_artifacts(db: Session, result: Result, usage_log: LLMUsageLog) -> None:
db.add(result)
db.add(usage_log)
db.flush()
__all__ = ["execute_test_run", "ensure_completed", "TestRunExecutionError"]
```
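The `schema` payload of a test run drives message construction: an optional `conversation` list is passed through, the prompt snapshot becomes the leading user message when no system message is present, and an `inputs` (or `test_inputs`) list is cycled across repetitions. A sketch calling the module's internal `_build_messages` helper (prompt text is illustrative):
```py
from app.services.test_run import _build_messages

schema_data = {"inputs": ["First question", "Second question"]}

# With no conversation, the snapshot leads and inputs are cycled per repetition.
for run_index in (1, 2, 3):
    print(_build_messages(schema_data, "You are a support assistant.", run_index))
```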
## /docker-compose.yml
```yml path="/docker-compose.yml"
services:
postgres:
image: postgres:15-alpine
container_name: promptworks-postgres
environment:
POSTGRES_DB: promptworks
POSTGRES_USER: promptworks
POSTGRES_PASSWORD: promptworks
ports:
- "15432:5432"
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U promptworks -d promptworks"]
interval: 5s
timeout: 5s
retries: 10
redis:
image: redis:7-alpine
container_name: promptworks-redis
command: ["redis-server", "--save", "60", "1", "--loglevel", "warning"]
ports:
- "6379:6379"
volumes:
- redis_data:/data
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 5s
timeout: 3s
retries: 10
backend:
build:
context: .
dockerfile: Dockerfile
container_name: promptworks-backend
environment:
APP_ENV: production
APP_TEST_MODE: "false"
API_V1_STR: /api/v1
PROJECT_NAME: PromptWorks
DATABASE_URL: postgresql+psycopg://promptworks:promptworks@postgres:5432/promptworks
REDIS_URL: redis://redis:6379/0
BACKEND_CORS_ORIGINS: '["http://localhost:18080"]'
BACKEND_CORS_ALLOW_CREDENTIALS: "true"
depends_on:
postgres:
condition: service_healthy
redis:
condition: service_healthy
ports:
- "8000:8000"
frontend:
build:
context: ./frontend
dockerfile: Dockerfile
args:
VITE_API_BASE_URL: http://localhost:8000/api/v1
container_name: promptworks-frontend
depends_on:
backend:
condition: service_started
ports:
- "18080:80"
volumes:
postgres_data:
redis_data:
```
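After `docker compose up -d --build`, the backend is published on host port 8000 and the frontend on 18080. A minimal smoke check in Python, assuming FastAPI's default `/docs` route has not been disabled:
```py
import httpx

# The backend service maps container port 8000 to the host (see the ports mapping above).
response = httpx.get("http://localhost:8000/docs", timeout=5.0)
print(response.status_code)  # expect 200 once migrations have finished
```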
## /docker/backend/entrypoint.sh
```sh path="/docker/backend/entrypoint.sh"
#!/usr/bin/env bash
set -euo pipefail
APP_HOST=${APP_HOST:-0.0.0.0}
APP_PORT=${APP_PORT:-8000}
echo "等待数据库迁移..."
until alembic upgrade head; do
echo "数据库暂未就绪,3 秒后重试..."
sleep 3
done
echo "启动 FastAPI 服务,监听 ${APP_HOST}:${APP_PORT}"
exec uvicorn app.main:app --host "${APP_HOST}" --port "${APP_PORT}"
```
## /docs/logo.jpg
Binary file available at https://raw.githubusercontent.com/YellowSeaa/PromptWorks/refs/heads/main/docs/logo.jpg
## /frontend/.gitignore
```gitignore path="/frontend/.gitignore"
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
dist
dist-ssr
*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?
```
## /frontend/Dockerfile
```dockerfile path="/frontend/Dockerfile"
FROM node:20-alpine AS build
WORKDIR /app
COPY package.json ./
COPY index.html ./
COPY tsconfig.json vite.config.ts ./
COPY src ./src
COPY public ./public
ARG VITE_API_BASE_URL="http://localhost:8000/api/v1"
ENV VITE_API_BASE_URL=${VITE_API_BASE_URL}
RUN npm install && npm run build
FROM nginx:1.27-alpine
COPY nginx.conf /etc/nginx/conf.d/default.conf
COPY --from=build /app/dist /usr/share/nginx/html
EXPOSE 80
CMD ["nginx", "-g", "daemon off;"]
```
## /frontend/index.html
```html path="/frontend/index.html"
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<link rel="icon" type="image/png" href="/logo.png" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>PromptWorks</title>
</head>
<body>
<div id="app"></div>
<script type="module" src="/src/main.ts"></script>
</body>
</html>
```
## /tests/__init__.py
```py path="/tests/__init__.py"
```