```
├── .dockerignore
├── .env.example
├── .github/
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE/
│   │   ├── bug_report.md
│   │   ├── config.yml
│   │   └── feature_request.md
│   └── workflows/
│       ├── docker-image-amd64.yml
│       ├── docker-image-arm64.yml
│       ├── linux-release.yml
│       ├── macos-release.yml
│       └── windows-release.yml
├── .gitignore
├── BT.md
├── Dockerfile
├── LICENSE
├── Midjourney.md
├── README.en.md
├── README.md
├── Rerank.md
├── Suno.md
├── VERSION
├── bin/
│   ├── migration_v0.2-v0.3.sql
│   ├── migration_v0.3-v0.4.sql
│   └── time_test.sh
├── common/
│   ├── constants.go
│   ├── crypto.go
│   ├── custom-event.go
│   ├── database.go
│   ├── email-outlook-auth.go
│   ├── email.go
│   ├── embed-file-system.go
│   ├── env.go
│   ├── gin.go
│   ├── go-channel.go
│   ├── gopool.go
│   ├── init.go
│   ├── json.go
│   ├── limiter/
│   │   ├── limiter.go
│   │   └── lua/
│   │       └── rate_limit.lua
│   ├── logger.go
│   ├── pprof.go
│   ├── rate-limit.go
│   ├── redis.go
│   ├── str.go
│   ├── topup-ratio.go
│   ├── utils.go
│   ├── validate.go
│   └── verification.go
├── constant/
│   ├── cache_key.go
│   ├── channel_setting.go
│   ├── context_key.go
│   ├── env.go
│   ├── finish_reason.go
│   ├── midjourney.go
│   ├── setup.go
│   ├── task.go
│   └── user_setting.go
├── controller/
│   ├── billing.go
│   ├── channel-billing.go
│   ├── channel-test.go
│   ├── channel.go
│   ├── github.go
│   ├── group.go
│   ├── image.go
│   ├── linuxdo.go
│   ├── log.go
│   ├── midjourney.go
│   ├── misc.go
│   ├── model.go
│   ├── oidc.go
│   ├── option.go
│   ├── playground.go
│   ├── pricing.go
│   ├── redemption.go
│   ├── relay.go
│   ├── setup.go
│   ├── task.go
│   ├── telegram.go
│   ├── token.go
│   ├── topup.go
│   ├── usedata.go
│   ├── user.go
```

## /.dockerignore

```dockerignore path="/.dockerignore"
.github
.git
*.md
.vscode
.gitignore
Makefile
docs
```

## /.env.example

```example path="/.env.example"
# Port
# PORT=3000
# Frontend base URL
# FRONTEND_BASE_URL=https://your-frontend-url.com

# Debugging
# Enable pprof
# ENABLE_PPROF=true

# Database
# Database connection string
# SQL_DSN=user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
# Log database connection string
# LOG_SQL_DSN=user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
# SQLite database path
# SQLITE_PATH=/path/to/sqlite.db
# Maximum number of idle database connections
# SQL_MAX_IDLE_CONNS=100
# Maximum number of open database connections
# SQL_MAX_OPEN_CONNS=1000
# Maximum lifetime of a database connection (seconds)
# SQL_MAX_LIFETIME=60

# Cache
# Redis connection string
# REDIS_CONN_STRING=redis://user:password@localhost:6379/0
# Sync frequency (seconds)
# SYNC_FREQUENCY=60
# Enable in-memory cache
# MEMORY_CACHE_ENABLED=true
# Channel update frequency (seconds)
# CHANNEL_UPDATE_FREQUENCY=30
# Enable batch updates
# BATCH_UPDATE_ENABLED=true
# Batch update interval (seconds)
# BATCH_UPDATE_INTERVAL=5

# Tasks and features
# Enable task updates
# UPDATE_TASK=true

# Session secret
# SESSION_SECRET=random_string

# Other
# Channel test frequency (seconds)
# CHANNEL_TEST_FREQUENCY=10
# Generate a default token
# GENERATE_DEFAULT_TOKEN=false
# Cohere safety setting
# COHERE_SAFETY_SETTING=NONE
# Whether to count image tokens
# GET_MEDIA_TOKEN=true
# Whether to count image tokens for non-streaming requests (stream=false)
# GET_MEDIA_TOKEN_NOT_STREAM=true
# Whether the Dify channel outputs workflow and node information to the client
# DIFY_DEBUG=true
# Timeout for a single streaming response (seconds)
# STREAMING_TIMEOUT=90
# Node type
# Set to master if this is the master node
# NODE_TYPE=master
```

## /.github/FUNDING.yml

```yml path="/.github/FUNDING.yml"
# These are supported funding model platforms

github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
custom: ['https://afdian.com/a/new-api'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
```

## /.github/ISSUE_TEMPLATE/bug_report.md

---
name: Report an issue
about: Describe the problem you encountered in concise, detailed language
title: ''
labels: bug
assignees: ''

---

**Routine checks**

[//]: # (Delete the existing space inside each box and fill in an x)
+ [ ] I have confirmed that there is no similar existing issue
+ [ ] I have confirmed that I have upgraded to the latest version
+ [ ] I have read the project README in full, especially the FAQ section
+ [ ] I understand and am willing to follow up on this issue, help with testing, and provide feedback
+ [ ] I understand and accept the above, and I understand that the maintainers have limited time; **issues that do not follow the rules may be ignored or closed directly**

**Problem description**

**Steps to reproduce**

**Expected result**

**Related screenshots**
If there are none, please delete this section.

## /.github/ISSUE_TEMPLATE/config.yml

```yml path="/.github/ISSUE_TEMPLATE/config.yml"
blank_issues_enabled: false
contact_links:
  - name: Project group chat
    url: https://private-user-images.githubusercontent.com/61247483/283011625-de536a8a-0161-47a7-a0a2-66ef6de81266.jpeg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDIyMjQzOTAsIm5iZiI6MTcwMjIyNDA5MCwicGF0aCI6Ii82MTI0NzQ4My8yODMwMTE2MjUtZGU1MzZhOGEtMDE2MS00N2E3LWEwYTItNjZlZjZkZTgxMjY2LmpwZWc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBSVdOSllBWDRDU1ZFSDUzQSUyRjIwMjMxMjEwJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDIzMTIxMFQxNjAxMzBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02MGIxYmM3ZDQyYzBkOTA2ZTYyYmVmMzQ1NjY4NjM1YjY0NTUzNTM5NjE1NDZkYTIzODdhYTk4ZjZjODJmYzY2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.TJ8CTfOSwR0-CHS1KLfomqgL0e4YH1luy8lSLrkv5Zg
    about: 'QQ group: 629454374'
```

## /.github/ISSUE_TEMPLATE/feature_request.md

---
name: Feature request
about: Describe the new feature you would like added in concise, detailed language
title: ''
labels: enhancement
assignees: ''

---

**Routine checks**

[//]: # (Delete the existing space inside each box and fill in an x)
+ [ ] I have confirmed that there is no similar existing issue
+ [ ] I have confirmed that I have upgraded to the latest version
+ [ ] I have read the project README in full and confirmed that the current version cannot meet my needs
+ [ ] I understand and am willing to follow up on this issue, help with testing, and provide feedback
+ [ ] I understand and accept the above, and I understand that the maintainers have limited time; **issues that do not follow the rules may be ignored or closed directly**

**Feature description**

**Use case**

## /.github/workflows/docker-image-amd64.yml

```yml path="/.github/workflows/docker-image-amd64.yml"
name: Publish Docker image (amd64)

on:
  push:
    tags:
      - '*'
  workflow_dispatch:
    inputs:
      name:
        description: 'reason'
        required: false
jobs:
  push_to_registries:
    name: Push Docker image to multiple registries
    runs-on: ubuntu-latest
    permissions:
      packages: write
      contents: read
    steps:
      - name: Check out the repo
        uses: actions/checkout@v4

      - name: Save version info
        run: |
          git describe --tags > VERSION

      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      - name: Log in to the Container registry
        uses: docker/login-action@v3
        with:
          registry: ghcr.io
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}

      - name: Extract metadata (tags, labels) for Docker
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: |
            calciumion/new-api
            ghcr.io/${{ github.repository }}

      - name: Build and push Docker images
        uses: docker/build-push-action@v5
        with:
          context: .
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
```

## /.github/workflows/docker-image-arm64.yml

```yml path="/.github/workflows/docker-image-arm64.yml"
name: Publish Docker image (arm64)

on:
  push:
    tags:
      - '*'
  workflow_dispatch:
    inputs:
      name:
        description: 'reason'
        required: false
jobs:
  push_to_registries:
    name: Push Docker image to multiple registries
    runs-on: ubuntu-latest
    permissions:
      packages: write
      contents: read
    steps:
      - name: Check out the repo
        uses: actions/checkout@v4

      - name: Save version info
        run: |
          git describe --tags > VERSION

      - name: Set up QEMU
        uses: docker/setup-qemu-action@v3

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      - name: Log in to the Container registry
        uses: docker/login-action@v3
        with:
          registry: ghcr.io
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}

      - name: Extract metadata (tags, labels) for Docker
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: |
            calciumion/new-api
            ghcr.io/${{ github.repository }}

      - name: Build and push Docker images
        uses: docker/build-push-action@v5
        with:
          context: .
          platforms: linux/amd64,linux/arm64
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
```

## /.github/workflows/linux-release.yml

```yml path="/.github/workflows/linux-release.yml"
name: Linux Release
permissions:
  contents: write

on:
  push:
    tags:
      - '*'
      - '!*-alpha*'
jobs:
  release:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v3
        with:
          fetch-depth: 0
      - uses: actions/setup-node@v3
        with:
          node-version: 18
      - name: Build Frontend
        env:
          CI: ""
        run: |
          cd web
          npm install
          REACT_APP_VERSION=$(git describe --tags) npm run build
          cd ..
      - name: Set up Go
        uses: actions/setup-go@v3
        with:
          go-version: '>=1.18.0'
      - name: Build Backend (amd64)
        run: |
          go mod download
          go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api

      - name: Build Backend (arm64)
        run: |
          sudo apt-get update
          DEBIAN_FRONTEND=noninteractive sudo apt-get install -y gcc-aarch64-linux-gnu
          CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api-arm64

      - name: Release
        uses: softprops/action-gh-release@v1
        if: startsWith(github.ref, 'refs/tags/')
        with:
          files: |
            one-api
            one-api-arm64
          draft: true
          generate_release_notes: true
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
```

## /.github/workflows/macos-release.yml

```yml path="/.github/workflows/macos-release.yml"
name: macOS Release
permissions:
  contents: write

on:
  push:
    tags:
      - '*'
      - '!*-alpha*'
jobs:
  release:
    runs-on: macos-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v3
        with:
          fetch-depth: 0
      - uses: actions/setup-node@v3
        with:
          node-version: 18
      - name: Build Frontend
        env:
          CI: ""
        run: |
          cd web
          npm install
          REACT_APP_VERSION=$(git describe --tags) npm run build
          cd ..
      - name: Set up Go
        uses: actions/setup-go@v3
        with:
          go-version: '>=1.18.0'
      - name: Build Backend
        run: |
          go mod download
          go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos
      - name: Release
        uses: softprops/action-gh-release@v1
        if: startsWith(github.ref, 'refs/tags/')
        with:
          files: one-api-macos
          draft: true
          generate_release_notes: true
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
```

## /.github/workflows/windows-release.yml

```yml path="/.github/workflows/windows-release.yml"
name: Windows Release
permissions:
  contents: write

on:
  push:
    tags:
      - '*'
      - '!*-alpha*'
jobs:
  release:
    runs-on: windows-latest
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout
        uses: actions/checkout@v3
        with:
          fetch-depth: 0
      - uses: actions/setup-node@v3
        with:
          node-version: 18
      - name: Build Frontend
        env:
          CI: ""
        run: |
          cd web
          npm install
          REACT_APP_VERSION=$(git describe --tags) npm run build
          cd ..
      - name: Set up Go
        uses: actions/setup-go@v3
        with:
          go-version: '>=1.18.0'
      - name: Build Backend
        run: |
          go mod download
          go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe
      - name: Release
        uses: softprops/action-gh-release@v1
        if: startsWith(github.ref, 'refs/tags/')
        with:
          files: one-api.exe
          draft: true
          generate_release_notes: true
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
```

## /.gitignore

```gitignore path="/.gitignore"
.idea
.vscode
upload
*.exe
*.db
build
*.db-journal
logs
web/dist
.env
one-api
.DS_Store
tiktoken_cache
```

## /BT.md

The secret key is the `SESSION_SECRET` environment variable.

![8285bba413e770fe9620f1bf9b40d44e](https://github.com/user-attachments/assets/7a6fc03e-c457-45e4-b8f9-184508fc26b0)

## /Dockerfile

``` path="/Dockerfile"
FROM oven/bun:latest AS builder

WORKDIR /build
COPY web/package.json .
RUN bun install
COPY ./web .
COPY ./VERSION .
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build

FROM golang:alpine AS builder2

ENV GO111MODULE=on \
    CGO_ENABLED=0 \
    GOOS=linux

WORKDIR /build

ADD go.mod go.sum ./
RUN go mod download

COPY . .
COPY --from=builder /build/dist ./web/dist
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-api

FROM alpine

RUN apk update \
    && apk upgrade \
    && apk add --no-cache ca-certificates tzdata ffmpeg \
    && update-ca-certificates

COPY --from=builder2 /build/one-api /
EXPOSE 3000
WORKDIR /data
ENTRYPOINT ["/one-api"]
```

## /LICENSE

``` path="/LICENSE"
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
```

## /Midjourney.md

# Midjourney Proxy API Documentation

**Introduction**: Midjourney Proxy API documentation

## Endpoint List

The following endpoints are supported:

+ [x] /mj/submit/imagine
+ [x] /mj/submit/change
+ [x] /mj/submit/blend
+ [x] /mj/submit/describe
+ [x] /mj/image/{id} (retrieve images through this endpoint; **you must fill in the server address in System Settings!**)
+ [x] /mj/task/{id}/fetch (image URLs returned by this endpoint are proxied through One API)
+ [x] /task/list-by-condition
+ [x] /mj/submit/action (supported only by midjourney-proxy-plus; the same applies to the endpoints below)
+ [x] /mj/submit/modal
+ [x] /mj/submit/shorten
+ [x] /mj/task/{id}/image-seed
+ [x] /mj/insight-face/swap (InsightFace)

## Model List

### Supported by midjourney-proxy

- mj_imagine (image generation)
- mj_variation (variation)
- mj_reroll (reroll)
- mj_blend (blend)
- mj_upscale (upscale)
- mj_describe (image to text)

### Supported only by midjourney-proxy-plus

- mj_zoom (zoom out by a preset ratio)
- mj_shorten (prompt shortening)
- mj_modal (modal submission; inpainting and custom zoom must be added together with mj_modal)
- mj_inpaint (inpainting submission; must be added together with mj_modal)
- mj_custom_zoom (custom zoom; must be added together with mj_modal)
- mj_high_variation (strong variation)
- mj_low_variation (subtle variation)
- mj_pan (pan)
- swap_face (face swap)

## Model Price Settings (configured under Settings - Operation Settings - Model Fixed Price Settings)

```json
{
  "mj_imagine": 0.1,
  "mj_variation": 0.1,
  "mj_reroll": 0.1,
  "mj_blend": 0.1,
  "mj_modal": 0.1,
  "mj_zoom": 0.1,
  "mj_shorten": 0.1,
  "mj_high_variation": 0.1,
  "mj_low_variation": 0.1,
  "mj_pan": 0.1,
  "mj_inpaint": 0,
  "mj_custom_zoom": 0,
  "mj_describe": 0.05,
  "mj_upscale": 0.05,
  "swap_face": 0.05
}
```

mj_inpaint and mj_custom_zoom are priced at 0 because both must be used together with mj_modal, so their cost is determined by the mj_modal price.

## Channel Settings

### Connecting to midjourney-proxy(plus)

1. Deploy Midjourney-Proxy and configure your Midjourney account and related settings (setting a secret key is strongly recommended). [Project link](https://github.com/novicezk/midjourney-proxy)
2. Add a channel in Channel Management and select **Midjourney Proxy** as the channel type, or **Midjourney Proxy Plus** for the plus version. For models, refer to the model list above.
3. For **Proxy**, enter the address where midjourney-proxy is deployed, for example: http://localhost:8080
4. For the key, enter the midjourney-proxy secret key. If you did not set one, any value will work.

### Connecting to an upstream new api

1. Add a channel in Channel Management and select **Midjourney Proxy Plus** as the channel type. For models, refer to the model list above.
2. For **Proxy**, enter the address of the upstream new api, for example: http://localhost:3000
3. For the key, enter the upstream new api key.

## /README.en.md
![new-api](/web/public/logo.png)

# New API

🍥 Next-Generation LLM Gateway and AI Asset Management System


## 📝 Project Description

> [!NOTE]
> This is an open-source project developed on top of [One API](https://github.com/songquanpeng/one-api).

> [!IMPORTANT]
> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and with applicable laws and regulations, and must not use this project for illegal purposes.
> - This project is for personal learning only. Stability is not guaranteed, and no technical support is provided.

## ✨ Key Features

1. 🎨 New UI (some pages pending update)
2. 🌍 Multi-language support (work in progress)
3. 🎨 Added [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface support, [Integration Guide](Midjourney.md)
4. 💰 Online top-up, configurable in System Settings:
   - [x] EasyPay
5. 🔍 Query usage quota by key:
   - Works with [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)
6. 📑 Configurable number of items per page
7. 🔄 Compatible with the original One API database (one-api.db)
8. 💵 Per-request model pricing, configurable in `System Settings - Operation Settings`
9. ⚖️ **Weighted random** channel selection
10. 📈 Data dashboard (console)
11. 🔒 Configurable model access per token
12. 🤖 Telegram authorization login:
    1. Enable `System Settings - Configure Login Registration - Allow Telegram Login`
    2. Send the /setdomain command to [@Botfather](https://t.me/botfather)
    3. Select your bot, then enter http(s)://your-website/login
    4. The Telegram Bot name is the bot username without the @
13. 🎵 Added [Suno API](https://github.com/Suno-API/Suno-API) interface support, [Integration Guide](Suno.md)
14. 🔄 Rerank model support, compatible with Cohere and Jina and usable from Dify, [Integration Guide](Rerank.md)
15. ⚡ **[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** support, including Azure channels
16. 🧠 Set reasoning effort through a model-name suffix:
    - Suffix `-high` sets high reasoning effort (e.g., `o3-mini-high`)
    - Suffix `-medium` sets medium reasoning effort
    - Suffix `-low` sets low reasoning effort
17. 🔄 Thinking-to-content option `thinking_to_content` in `Channel->Edit->Channel Extra Settings`, default `false`. When `true`, the thinking content from `reasoning_content` is wrapped in `<think>` tags and concatenated into the returned content.
18. 🔄 Model rate limiting: set a total request limit and a successful request limit in `System Settings->Rate Limit Settings`
19. 💰 Cache billing: when enabled, cache hits are charged at a configurable ratio:
    1. Set `Prompt Cache Ratio` in `System Settings -> Operation Settings`
    2. Set `Prompt Cache Ratio` in the channel settings, range 0-1 (e.g., 0.5 means cache hits are charged at 50%)
    3. Supported channels:
       - [x] OpenAI
       - [x] Azure
       - [x] DeepSeek
       - [ ] Claude

## Model Support

This version additionally supports:

1. Third-party model **gpts** (gpt-4-gizmo-*)
2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [Integration Guide](Midjourney.md)
3. Custom channels with full API URL support
4. [Suno API](https://github.com/Suno-API/Suno-API) interface, [Integration Guide](Suno.md)
5. Rerank models, supporting [Cohere](https://cohere.ai/) and [Jina](https://jina.ai/), [Integration Guide](Rerank.md)
6. Dify

You can add the custom models gpt-4-gizmo-* in channels. These are third-party models and cannot be called with official OpenAI keys.
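The model-name suffix convention from feature 16 can be illustrated with a small Go sketch. It assumes only what the list states: a trailing `-high`, `-medium` or `-low` selects a reasoning effort and is stripped before the request goes upstream. The helper name `splitReasoningEffort` is hypothetical, and this is not new-api's actual implementation.

```go
package main

import (
	"fmt"
	"strings"
)

// splitReasoningEffort is a hypothetical helper. It strips a trailing
// "-high", "-medium" or "-low" suffix from a model name and returns the
// base model name together with the requested reasoning effort.
// If no known suffix is present, effort is the empty string.
func splitReasoningEffort(model string) (base, effort string) {
	for _, e := range []string{"high", "medium", "low"} {
		suffix := "-" + e
		if strings.HasSuffix(model, suffix) {
			return strings.TrimSuffix(model, suffix), e
		}
	}
	return model, ""
}

func main() {
	for _, m := range []string{"o3-mini-high", "o3-mini-medium", "o3-mini"} {
		base, effort := splitReasoningEffort(m)
		fmt.Printf("%s -> model=%s effort=%q\n", m, base, effort)
	}
}
```

A real gateway would presumably apply this rule only to models known to support reasoning effort. Otherwise a model whose actual name happens to end in one of these suffixes would be rewritten.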
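The prompt cache ratio from feature 19 can be read as a discount applied only to the cached part of the prompt. The Go sketch below assumes a simple linear model: uncached prompt tokens are billed in full, and cached tokens are multiplied by the ratio. new-api's real billing also applies other ratios (model, group, completion) that are not shown here, so `billablePromptTokens` is a hypothetical illustration only.

```go
package main

import "fmt"

// billablePromptTokens is a hypothetical helper showing how a prompt cache
// ratio in [0, 1] could discount cache hits. Uncached tokens are billed in
// full; cached tokens are billed at cacheRatio of the normal price.
func billablePromptTokens(promptTokens, cachedTokens int, cacheRatio float64) float64 {
	// Cached tokens are a subset of the prompt, so clamp them to [0, promptTokens].
	if cachedTokens < 0 {
		cachedTokens = 0
	}
	if cachedTokens > promptTokens {
		cachedTokens = promptTokens
	}
	// The README documents the ratio range as 0-1.
	if cacheRatio < 0 {
		cacheRatio = 0
	} else if cacheRatio > 1 {
		cacheRatio = 1
	}
	uncached := promptTokens - cachedTokens
	return float64(uncached) + float64(cachedTokens)*cacheRatio
}

func main() {
	// 1000-token prompt, 800 tokens served from cache, ratio 0.5:
	// 200 full-price tokens + 800 * 0.5 = 600 billable tokens.
	fmt.Println(billablePromptTokens(1000, 800, 0.5))
}
```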
## Additional Configurations Beyond One API

- `GENERATE_DEFAULT_TOKEN`: Generate an initial token for new users, default `false`
- `STREAMING_TIMEOUT`: Timeout for streaming responses, default 60 seconds
- `DIFY_DEBUG`: Output workflow and node information to the client for Dify channels, default `true`
- `FORCE_STREAM_OPTION`: Override the client's `stream_options` parameter, default `true`
- `GET_MEDIA_TOKEN`: Count image tokens, default `true`
- `GET_MEDIA_TOKEN_NOT_STREAM`: Count image tokens in non-streaming mode, default `true`
- `UPDATE_TASK`: Update asynchronous tasks (Midjourney, Suno), default `true`
- `GEMINI_MODEL_MAP`: Specify the Gemini API version (v1/v1beta) per model, in the format `model:version`, comma-separated
- `COHERE_SAFETY_SETTING`: Cohere model [safety setting](https://docs.cohere.com/docs/safety-modes#overview), options: `NONE`, `CONTEXTUAL`, `STRICT`, default `NONE`
- `GEMINI_VISION_MAX_IMAGE_NUM`: Maximum number of images per Gemini request, default `16`; set to `-1` to disable the limit
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
- `CRYPTO_SECRET`: Encryption key used to encrypt database content
- `AZURE_DEFAULT_API_VERSION`: Default API version for Azure channels when none is specified in the channel settings, default `2024-12-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Length of the notification rate-limit window in minutes, default `10`
- `NOTIFY_LIMIT_COUNT`: Maximum number of notifications per user within that window, default `2`

## Deployment

> [!TIP]
> Latest Docker image: `calciumion/new-api:latest`
> Default account: root, password: 123456

### Multi-Server Deployment

- You must set the `SESSION_SECRET` environment variable; otherwise login sessions will not be consistent across servers.
- If the servers share a Redis instance, you must also set the `CRYPTO_SECRET` environment variable; otherwise content stored in Redis cannot be read by the other servers.
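Why must `SESSION_SECRET` be identical on every server? Session cookies are typically signed with a server-side secret, and a node can only verify a cookie that was signed with the same key. The sketch below illustrates this principle with Go's standard `crypto/hmac` package. It is not new-api's session code; `sign` and `verify` are hypothetical helpers, and the secrets and cookie value are placeholders.

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// sign returns a hex-encoded HMAC-SHA256 tag for value under secret,
// standing in for how a node might sign a session cookie.
func sign(secret, value string) string {
	mac := hmac.New(sha256.New, []byte(secret))
	mac.Write([]byte(value))
	return hex.EncodeToString(mac.Sum(nil))
}

// verify recomputes the tag with this node's secret and compares it to
// the presented tag in constant time.
func verify(secret, value, tag string) bool {
	expected, err := hex.DecodeString(tag)
	if err != nil {
		return false
	}
	mac := hmac.New(sha256.New, []byte(secret))
	mac.Write([]byte(value))
	return hmac.Equal(mac.Sum(nil), expected)
}

func main() {
	cookie := "user_id=1"
	tag := sign("shared-secret", cookie) // issued by node A

	// Node B configured with the same secret accepts the cookie...
	fmt.Println(verify("shared-secret", cookie, tag))
	// ...while a node with a different (e.g. per-process random) secret rejects it.
	fmt.Println(verify("other-secret", cookie, tag))
}
```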
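The `GEMINI_MODEL_MAP` entry in the configuration list above takes comma-separated `model:version` pairs. Below is a minimal Go sketch of parsing that format. It assumes that whitespace around entries should be ignored, that malformed entries are skipped, and that the version follows the last colon. `parseGeminiModelMap` is a hypothetical name, not new-api's code, and the model names are example values.

```go
package main

import (
	"fmt"
	"strings"
)

// parseGeminiModelMap is a hypothetical parser for a value such as
// "gemini-1.5-pro-latest:v1beta,gemini-1.0-pro:v1". It returns a map from
// model name to API version and silently skips malformed entries.
func parseGeminiModelMap(raw string) map[string]string {
	result := make(map[string]string)
	for _, entry := range strings.Split(raw, ",") {
		entry = strings.TrimSpace(entry)
		// Split on the last colon so the version is always the final field.
		i := strings.LastIndex(entry, ":")
		if i <= 0 || i == len(entry)-1 {
			continue // empty entry, missing model, or missing version
		}
		model := strings.TrimSpace(entry[:i])
		version := strings.TrimSpace(entry[i+1:])
		if model == "" || version == "" {
			continue
		}
		result[model] = version
	}
	return result
}

func main() {
	m := parseGeminiModelMap("gemini-1.5-pro-latest:v1beta, gemini-1.0-pro:v1")
	fmt.Println(m["gemini-1.5-pro-latest"], m["gemini-1.0-pro"])
}
```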
### Requirements

- Local database (default): SQLite (Docker deployments must mount the `/data` directory)
- Remote database: MySQL >= 5.7.8, PgSQL >= 9.6

### Deployment with BT Panel

Install BT Panel (**version 9.2.0** or later) from the [BT Panel Official Website](https://www.bt.cn/new/download.html) by downloading and running the stable-version install script.
After installation, log in to BT Panel and click Docker in the menu bar. On first access you will be prompted to install the Docker service; click Install Now and follow the prompts to complete the installation.
Then find **New-API** in the app store, click Install, and configure the basic options to finish. [Pictorial Guide](BT.md)

### Docker Deployment

### Using Docker Compose (Recommended)

```shell
# Clone the project
git clone https://github.com/Calcium-Ion/new-api.git
cd new-api
# Edit docker-compose.yml as needed
# nano docker-compose.yml
# vim docker-compose.yml
# Start
docker-compose up -d
```

#### Update Version

```shell
docker-compose pull
docker-compose up -d
```

### Direct Docker Image Usage

```shell
# SQLite deployment:
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest

# MySQL deployment: add -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" and adjust the connection parameters as needed.
# Note: inside the container, localhost refers to the container itself, so point the DSN at an address the container can reach.
# Example:
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```

#### Update Version

```shell
# Pull the latest image
docker pull calciumion/new-api:latest
# Stop and remove the old container
docker stop new-api
docker rm new-api
# Run the new container with the same parameters as before
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```

Alternatively, you can use Watchtower for automatic updates (not recommended, as it may cause database incompatibilities):

```shell
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
```

## Channel Retry

Channel retry is supported and can be configured in `Settings->Operation Settings->General Settings`. Enabling a cache is **recommended**.
When retry is enabled, a failed request is automatically retried on the channel with the next priority.

### Cache Configuration

1. `REDIS_CONN_STRING`: Use Redis as the cache
   + Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
2. `MEMORY_CACHE_ENABLED`: Enable the in-memory cache, default `false`
   + Example: `MEMORY_CACHE_ENABLED=true`

### Why Some Errors Don't Retry

Requests that fail with status codes 400, 504, or 524 are not retried.

### To Enable Retry for 400

In `Channel->Edit`, set `Status Code Override` to the following. This remaps an upstream 400 to 500, which is retried:

```json
{
  "400": "500"
}
```

## Integration Guides

- [Midjourney Integration](Midjourney.md)
- [Suno Integration](Suno.md)

## Related Projects

- [One API](https://github.com/songquanpeng/one-api): Original project
- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy): Midjourney interface support
- [chatnio](https://github.com/Deeptrain-Community/chatnio): Next-gen AI B/C solution
- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool): Query usage quota by key

## 🌟 Star History

[![Star History Chart](https://api.star-history.com/svg?repos=Calcium-Ion/new-api&type=Date)](https://star-history.com/#Calcium-Ion/new-api&Date)

## /README.md

中文 | English

![new-api](/web/public/logo.png)

# New API

🍥 Next-generation LLM gateway and AI asset management system


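## 📝 Project Description

> [!NOTE]
> This is an open-source project, developed as a derivative of [One API](https://github.com/songquanpeng/one-api)

> [!IMPORTANT]
> - This project is intended for personal learning only. Stability is not guaranteed, and no technical support is provided.
> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and all **applicable laws and regulations**, and must not use it for illegal purposes.
> - In accordance with the [Interim Measures for the Administration of Generative Artificial Intelligence Services](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), do not provide any unregistered generative AI services to the public in China.

## 📚 Documentation

For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)

## ✨ Key Features

New API offers a rich set of features. For details, see [Features](https://docs.newapi.pro/wiki/features-introduction):

1. 🎨 Brand-new UI
2. 🌍 Multi-language support
3. 💰 Online top-up (EPay, 易支付)
4. 🔍 Query usage quota by key (with [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool))
5. 🔄 Compatible with the original One API database
6. 💵 Per-request model pricing
7. ⚖️ Weighted random channel selection
8. 📈 Data dashboard (console)
9. 🔒 Token grouping and model restrictions
10. 🤖 Additional login methods (LinuxDO, Telegram, OIDC)
11. 🔄 Rerank models (Cohere and Jina), [API docs](https://docs.newapi.pro/api/jinaai-rerank)
12. ⚡ OpenAI Realtime API (including Azure channels), [API docs](https://docs.newapi.pro/api/openai-realtime)
13. ⚡ Claude Messages format, [API docs](https://docs.newapi.pro/api/anthropic-chat)
14. Access the chat interface via the `/chat2link` route
15. 🧠 Set reasoning effort via a model-name suffix:
    1. OpenAI o-series models
       - Add the suffix `-high` for high reasoning effort (e.g. `o3-mini-high`)
       - Add the suffix `-medium` for medium reasoning effort (e.g. `o3-mini-medium`)
       - Add the suffix `-low` for low reasoning effort (e.g. `o3-mini-low`)
    2. Claude thinking models
       - Add the suffix `-thinking` to enable thinking mode (e.g. `claude-3-7-sonnet-20250219-thinking`)
16. 🔄 Thinking-to-content conversion
17. 🔄 Per-user model rate limiting
18. 💰 Cache billing: when enabled, cache hits are billed at a configured ratio:
    1. Set the `Prompt Cache Ratio` option in `System Settings -> Operation Settings`
    2. Set `Prompt Cache Ratio` on the channel, in the range 0-1; for example, 0.5 means cache hits are billed at 50%
    3. Supported channels:
       - [x] OpenAI
       - [x] Azure
       - [x] DeepSeek
       - [x] Claude

## Model Support

This version supports a variety of models. For details, see [API Docs - Relay API](https://docs.newapi.pro/api):

1. Third-party model **gpts** (gpt-4-gizmo-*)
2. Third-party channel [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [API docs](https://docs.newapi.pro/api/midjourney-proxy-image)
3. Third-party channel [Suno API](https://github.com/Suno-API/Suno-API) interface, [API docs](https://docs.newapi.pro/api/suno-music)
4. Custom channels, with support for entering a full request URL
5.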
   Rerank models ([Cohere](https://cohere.ai/) and [Jina](https://jina.ai/)), [API docs](https://docs.newapi.pro/api/jinaai-rerank)
6. Claude Messages format, [API docs](https://docs.newapi.pro/api/anthropic-chat)
7. Dify (currently only chatflow is supported)

## Environment Variables

For detailed configuration, see [Installation Guide - Environment Variables](https://docs.newapi.pro/installation/environment-variables):

- `GENERATE_DEFAULT_TOKEN`: whether to generate an initial token for newly registered users, default `false`
- `STREAMING_TIMEOUT`: streaming response timeout, default 60 seconds
- `DIFY_DEBUG`: whether the Dify channel outputs workflow and node information, default `true`
- `FORCE_STREAM_OPTION`: whether to override the client's `stream_options` parameter, default `true`
- `GET_MEDIA_TOKEN`: whether to count image tokens, default `true`
- `GET_MEDIA_TOKEN_NOT_STREAM`: whether to count image tokens in non-streaming mode, default `true`
- `UPDATE_TASK`: whether to update asynchronous tasks (Midjourney, Suno), default `true`
- `COHERE_SAFETY_SETTING`: Cohere model safety setting, one of `NONE`, `CONTEXTUAL`, `STRICT`, default `NONE`
- `GEMINI_VISION_MAX_IMAGE_NUM`: maximum number of images for Gemini models, default `16`
- `MAX_FILE_DOWNLOAD_MB`: maximum file download size in MB, default `20`
- `CRYPTO_SECRET`: encryption key used to encrypt database contents
- `AZURE_DEFAULT_API_VERSION`: default API version for Azure channels, default `2024-12-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`: notification limit window, default `10` minutes
- `NOTIFY_LIMIT_COUNT`: maximum number of user notifications within that window, default `2`

## Deployment

For a detailed deployment guide, see [Installation Guide - Deployment](https://docs.newapi.pro/installation):

> [!TIP]
> Latest Docker image: `calciumion/new-api:latest`

### Multi-node Deployment Notes
- The `SESSION_SECRET` environment variable must be set; otherwise login state will be inconsistent across nodes
- If the nodes share a Redis instance, `CRYPTO_SECRET` must be set; otherwise the nodes cannot read the Redis contents

### Requirements
- Local database (default): SQLite (Docker deployments must mount the `/data` directory)
- Remote database: MySQL >= 5.7.8, PgSQL >= 9.6

### Deployment Methods

#### Using the BT Panel Docker Feature
Install BT Panel (**version 9.2.0** or above), then find **New-API** in the app store and install it.

[Pictorial Guide](BT.md)

#### Using Docker Compose (Recommended)

```shell
# Download the project
git clone https://github.com/Calcium-Ion/new-api.git
cd new-api
# Edit docker-compose.yml as needed
# Start
docker-compose up -d
```

#### Using the Docker Image Directly

```shell
# Using SQLite
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest

# Using MySQL
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```

## Channel Retry and Cache

Channel retry is implemented; the number of retries can be set in `Settings -> Operation Settings -> General Settings`. **Enabling the cache is recommended**.

### Cache Configuration
1. `REDIS_CONN_STRING`: use Redis as the cache
2. `MEMORY_CACHE_ENABLED`: enable the in-memory cache (no need to set this manually if Redis is configured)

## API Documentation

For detailed API documentation, see [API Docs](https://docs.newapi.pro/api):

- [Chat API](https://docs.newapi.pro/api/openai-chat)
- [Image API](https://docs.newapi.pro/api/openai-image)
- [Rerank API](https://docs.newapi.pro/api/jinaai-rerank)
- [Realtime API](https://docs.newapi.pro/api/openai-realtime)
- [Claude Chat API (messages)](https://docs.newapi.pro/api/anthropic-chat)

## Related Projects
- [One API](https://github.com/songquanpeng/one-api): original project
- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy): Midjourney interface support
- [chatnio](https://github.com/Deeptrain-Community/chatnio): next-generation one-stop AI B/C solution
- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool): query usage quota by key

Other projects based on New API:
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon): high-performance optimized edition of New API
- [VoAPI](https://github.com/VoAPI/VoAPI): front-end redesign based on New API

## Help and Support

If you run into problems, see [Help and Support](https://docs.newapi.pro/support):
- [Community](https://docs.newapi.pro/support/community-interaction)
- [Report Issues](https://docs.newapi.pro/support/feedback-issues)
- [FAQ](https://docs.newapi.pro/support/faq)

## 🌟 Star History

[![Star History Chart](https://api.star-history.com/svg?repos=Calcium-Ion/new-api&type=Date)](https://star-history.com/#Calcium-Ion/new-api&Date)

## /Rerank.md

# Rerank API Documentation

**Overview**: Rerank API documentation

## Integrating with Dify
Select Jina as the model provider and fill in the model information as required to integrate with Dify.

## Request

POST: /v1/rerank

Request:

```json
{
  "model": "jina-reranker-v2-base-multilingual",
  "query": "What is the capital of the United States?",
  "top_n": 3,
  "documents": [
    "Carson City is the capital city of the American state of Nevada.",
    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
    "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
    "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
    "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
  ]
}
```

Response:

```json
{
  "results": [
    {
      "document": {
        "text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
      },
      "index": 2,
      "relevance_score": 0.9999702
    },
    {
      "document": {
        "text": "Carson City is the capital city of the American state of Nevada."
      },
      "index": 0,
      "relevance_score": 0.67800725
    },
    {
      "document": {
        "text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages."
      },
      "index": 3,
      "relevance_score": 0.02800752
    }
  ],
  "usage": {
    "prompt_tokens": 158,
    "completion_tokens": 0,
    "total_tokens": 158
  }
}
```

## /Suno.md

# Suno API Documentation

**Overview**: Suno API documentation

## Endpoints

Supported endpoints:
+ [x] /suno/submit/music
+ [x] /suno/submit/lyrics
+ [x] /suno/fetch
+ [x] /suno/fetch/:id

## Models

### Suno API support
- suno_music (custom mode, inspiration mode, continuation)
- suno_lyrics (lyrics generation)

## Model Price Settings (configured in Settings -> Operation Settings -> Model Fixed Price Settings)

```json
{
  "suno_music": 0.3,
  "suno_lyrics": 0.01
}
```

## Channel Settings

### Connecting to Suno API
1. Deploy Suno API and configure the Suno account and related settings (setting a key is strongly recommended), [project repository](https://github.com/Suno-API/Suno-API)
2. Add a channel in Channel Management and select **Suno API** as the channel type; for models, see the model list above
3. Set **Proxy** to the address where Suno API is deployed, e.g. http://localhost:8080
4. Set the key to the Suno API key; if you did not set a key, any value will do

### Connecting to an upstream new api
1. Add a channel in Channel Management and select **Suno API** as the channel type (any type works, as long as its models include those in the model list above)
2. Set **Proxy** to the upstream new api address, e.g. http://localhost:3000
3.
   Set the key to the upstream new api key

## /VERSION

``` path="/VERSION"
```

## /bin/migration_v0.2-v0.3.sql

```sql path="/bin/migration_v0.2-v0.3.sql"
-- COALESCE keeps quota unchanged for users that have no tokens
-- (SUM over zero rows returns NULL, and quota + NULL would be NULL)
UPDATE users
SET quota = quota + COALESCE((
    SELECT SUM(remain_quota)
    FROM tokens
    WHERE tokens.user_id = users.id
), 0)
```

## /bin/migration_v0.3-v0.4.sql

```sql path="/bin/migration_v0.3-v0.4.sql"
INSERT INTO abilities (`group`, model, channel_id, enabled)
SELECT c.`group`, m.model, c.id, 1
FROM channels c
CROSS JOIN (
    SELECT 'gpt-3.5-turbo' AS model UNION ALL
    SELECT 'gpt-3.5-turbo-0301' AS model UNION ALL
    SELECT 'gpt-4' AS model UNION ALL
    SELECT 'gpt-4-0314' AS model
) AS m
WHERE c.status = 1
  AND NOT EXISTS (
    SELECT 1
    FROM abilities a
    WHERE a.`group` = c.`group`
      AND a.model = m.model
      AND a.channel_id = c.id
  );
```

## /bin/time_test.sh

```sh path="/bin/time_test.sh"
#!/bin/bash

if [ $# -lt 3 ]; then
  echo "Usage: time_test.sh <domain> <key> <count> [<model>]"
  exit 1
fi

domain=$1
key=$2
count=$3
model=${4:-"gpt-3.5-turbo"} # default model: gpt-3.5-turbo

total_time=0
times=()

for ((i=1; i<=count; i++)); do
  result=$(curl -o /dev/null -s -w "%{http_code} %{time_total}\\n" \
           https://"$domain"/v1/chat/completions \
           -H "Content-Type: application/json" \
           -H "Authorization: Bearer $key" \
           -d '{"messages": [{"content": "echo hi", "role": "user"}], "model": "'"$model"'", "stream": false, "max_tokens": 1}')
  http_code=$(echo "$result" | awk '{print $1}')
  time=$(echo "$result" | awk '{print $2}')
  echo "HTTP status code: $http_code, Time taken: $time"
  total_time=$(bc <<< "$total_time + $time")
  times+=("$time")
done

average_time=$(echo "scale=4; $total_time / $count" | bc)

sum_of_squares=0
for time in "${times[@]}"; do
  difference=$(echo "scale=4; $time - $average_time" | bc)
  square=$(echo "scale=4; $difference * $difference" | bc)
  sum_of_squares=$(echo "scale=4; $sum_of_squares + $square" | bc)
done

standard_deviation=$(echo "scale=4; sqrt($sum_of_squares / $count)" | bc)

echo "Average time: $average_time±$standard_deviation"
```

## /common/constants.go

```go path="/common/constants.go"
package common
import (
	//"os"
	//"strconv"
	"sync"
	"time"

	"github.com/google/uuid"
)

var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0"            // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "New API"
var Footer = ""
var Logo = ""
var TopUpLink = ""

// var ChatLink = ""
// var ChatLink2 = ""
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
var DrawingEnabled = true
var TaskEnabled = true
var DataExportEnabled = true
var DataExportInterval = 5         // unit: minute
var DataExportDefaultTime = "hour" // default time granularity
var DefaultCollapseSidebar = false // default value of collapse sidebar

// Any option whose key contains "Secret" or "Token" won't be returned by GetOptions

var SessionSecret = uuid.New().String()
var CryptoSecret = uuid.New().String()

var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex

var ItemsPerPage = 10
var MaxRecentItems = 100

var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var LinuxDOOAuthEnabled = false
var WeChatAuthEnabled = false
var TelegramOAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true

var EmailDomainRestrictionEnabled = false // whether email domain restriction is enabled
var EmailAliasRestrictionEnabled = false  // whether email alias restriction is enabled
var EmailDomainWhitelist = []string{
	"gmail.com",
	"163.com",
	"126.com",
	"qq.com",
	"outlook.com",
	"hotmail.com",
	"icloud.com",
	"yahoo.com",
	"foxmail.com",
}
var EmailLoginAuthServerList = []string{
	"smtp.sendcloud.net",
	"smtp.azurecomm.net",
}

var DebugEnabled bool
var MemoryCacheEnabled bool

var LogConsumeEnabled = true

var SMTPServer = ""
var SMTPPort = 587
var SMTPSSLEnabled = false
var SMTPAccount = ""
var SMTPFrom = ""
var SMTPToken = ""

var GitHubClientId = ""
var GitHubClientSecret = ""

var LinuxDOClientId = ""
var LinuxDOClientSecret = ""

var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""

var TurnstileSiteKey = ""
var TurnstileSecretKey = ""

var TelegramBotToken = ""
var TelegramBotName = ""

var QuotaForNewUser = 0
var QuotaForInviter = 0
var QuotaForInvitee = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var AutomaticEnableChannelEnabled = false
var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500

var RetryTimes = 0

//var RootUserEmail = ""

var IsMasterNode bool

var requestInterval int
var RequestInterval time.Duration

var SyncFrequency int // unit is second

var BatchUpdateEnabled = false
var BatchUpdateInterval int

var RelayTimeout int // unit is second

var GeminiSafetySetting string

// https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT
var CohereSafetySetting string

const (
	RequestIdKey = "X-Oneapi-Request-Id"
)

const (
	RoleGuestUser  = 0
	RoleCommonUser = 1
	RoleAdminUser  = 10
	RoleRootUser   = 100
)

func IsValidateRole(role int) bool {
	return role == RoleGuestUser || role == RoleCommonUser || role == RoleAdminUser || role == RoleRootUser
}

var (
	FileUploadPermission    = RoleGuestUser
	FileDownloadPermission  = RoleGuestUser
	ImageUploadPermission   = RoleGuestUser
	ImageDownloadPermission = RoleGuestUser
)

// All durations are in seconds
// Shouldn't be larger than RateLimitKeyExpirationDuration
var (
	GlobalApiRateLimitEnable   bool
	GlobalApiRateLimitNum      int
	GlobalApiRateLimitDuration int64

	GlobalWebRateLimitEnable   bool
	GlobalWebRateLimitNum      int
	GlobalWebRateLimitDuration int64

	UploadRateLimitNum            = 10
	UploadRateLimitDuration int64 = 60

	DownloadRateLimitNum            = 10
	DownloadRateLimitDuration int64 = 60

	CriticalRateLimitNum            = 20
	CriticalRateLimitDuration int64 = 20 * 60
)

var RateLimitKeyExpirationDuration = 20 * time.Minute

const (
	UserStatusEnabled  = 1 // don't use 0, 0 is the default value!
	UserStatusDisabled = 2 // also don't use 0
)

const (
	TokenStatusEnabled   = 1 // don't use 0, 0 is the default value!
	TokenStatusDisabled  = 2 // also don't use 0
	TokenStatusExpired   = 3
	TokenStatusExhausted = 4
)

const (
	RedemptionCodeStatusEnabled  = 1 // don't use 0, 0 is the default value!
	RedemptionCodeStatusDisabled = 2 // also don't use 0
	RedemptionCodeStatusUsed     = 3 // also don't use 0
)

const (
	ChannelStatusUnknown          = 0
	ChannelStatusEnabled          = 1 // don't use 0, 0 is the default value!
	ChannelStatusManuallyDisabled = 2 // also don't use 0
	ChannelStatusAutoDisabled     = 3
)

const (
	ChannelTypeUnknown        = 0
	ChannelTypeOpenAI         = 1
	ChannelTypeMidjourney     = 2
	ChannelTypeAzure          = 3
	ChannelTypeOllama         = 4
	ChannelTypeMidjourneyPlus = 5
	ChannelTypeOpenAIMax      = 6
	ChannelTypeOhMyGPT        = 7
	ChannelTypeCustom         = 8
	ChannelTypeAILS           = 9
	ChannelTypeAIProxy        = 10
	ChannelTypePaLM           = 11
	ChannelTypeAPI2GPT        = 12
	ChannelTypeAIGC2D         = 13
	ChannelTypeAnthropic      = 14
	ChannelTypeBaidu          = 15
	ChannelTypeZhipu          = 16
	ChannelTypeAli            = 17
	ChannelTypeXunfei         = 18
	ChannelType360            = 19
	ChannelTypeOpenRouter     = 20
	ChannelTypeAIProxyLibrary = 21
	ChannelTypeFastGPT        = 22
	ChannelTypeTencent        = 23
	ChannelTypeGemini         = 24
	ChannelTypeMoonshot       = 25
	ChannelTypeZhipu_v4       = 26
	ChannelTypePerplexity     = 27
	ChannelTypeLingYiWanWu    = 31
	ChannelTypeAws            = 33
	ChannelTypeCohere         = 34
	ChannelTypeMiniMax        = 35
	ChannelTypeSunoAPI        = 36
	ChannelTypeDify           = 37
	ChannelTypeJina           = 38
	ChannelCloudflare         = 39
	ChannelTypeSiliconFlow    = 40
	ChannelTypeVertexAi       = 41
	ChannelTypeMistral        = 42
	ChannelTypeDeepSeek       = 43
	ChannelTypeMokaAI         = 44
	ChannelTypeVolcEngine     = 45
	ChannelTypeBaiduV2        = 46
	ChannelTypeXinference     = 47
	ChannelTypeXai            = 48

	ChannelTypeDummy // this one is only for count, do not add any channel after this
)

var ChannelBaseURLs = []string{
	"", // 0
	"https://api.openai.com", // 1
	"https://oa.api2d.net", // 2
	"", // 3
	"http://localhost:11434", // 4
	"https://api.openai-sb.com", // 5
	"https://api.openaimax.com", // 6
	"https://api.ohmygpt.com", // 7
	"", // 8
	"https://api.caipacity.com", // 9
	"https://api.aiproxy.io", // 10
	"", // 11
	"https://api.api2gpt.com", // 12
	"https://api.aigc2d.com", // 13
	"https://api.anthropic.com", // 14
	"https://aip.baidubce.com", // 15
	"https://open.bigmodel.cn", // 16
	"https://dashscope.aliyuncs.com", // 17
	"", // 18
	"https://api.360.cn", // 19
	"https://openrouter.ai/api", // 20
	"https://api.aiproxy.io", // 21
	"https://fastgpt.run/api/openapi", // 22
	"https://hunyuan.tencentcloudapi.com", // 23
	"https://generativelanguage.googleapis.com", // 24
	"https://api.moonshot.cn", // 25
	"https://open.bigmodel.cn", // 26
	"https://api.perplexity.ai", // 27
	"", // 28
	"", // 29
	"", // 30
	"https://api.lingyiwanwu.com", // 31
	"", // 32
	"", // 33
	"https://api.cohere.ai", // 34
	"https://api.minimax.chat", // 35
	"", // 36
	"https://api.dify.ai", // 37
	"https://api.jina.ai", // 38
	"https://api.cloudflare.com", // 39
	"https://api.siliconflow.cn", // 40
	"", // 41
	"https://api.mistral.ai", // 42
	"https://api.deepseek.com", // 43
	"https://api.moka.ai", // 44
	"https://ark.cn-beijing.volces.com", // 45
	"https://qianfan.baidubce.com", // 46
	"", // 47
	"https://api.x.ai", // 48
}
```

## /common/crypto.go

```go path="/common/crypto.go"
package common

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"

	"golang.org/x/crypto/bcrypt"
)

func GenerateHMACWithKey(key []byte, data string) string {
	h := hmac.New(sha256.New, key)
	h.Write([]byte(data))
	return hex.EncodeToString(h.Sum(nil))
}

func GenerateHMAC(data string) string {
	h := hmac.New(sha256.New, []byte(CryptoSecret))
	h.Write([]byte(data))
	return hex.EncodeToString(h.Sum(nil))
}

func Password2Hash(password string) (string, error) {
	passwordBytes := []byte(password)
	hashedPassword, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost)
	return string(hashedPassword), err
}

func ValidatePasswordAndHash(password string, hash string) bool {
	err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
	return err == nil
}
```

## /common/custom-event.go

```go path="/common/custom-event.go"
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.

package common

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

type stringWriter interface {
	io.Writer
	writeString(string) (int, error)
}

type stringWrapper struct {
	io.Writer
}

func (w stringWrapper) writeString(str string) (int, error) {
	return w.Writer.Write([]byte(str))
}

func checkWriter(writer io.Writer) stringWriter {
	if w, ok := writer.(stringWriter); ok {
		return w
	} else {
		return stringWrapper{writer}
	}
}

// Server-Sent Events
// W3C Working Draft 29 October 2009
// http://www.w3.org/TR/2009/WD-eventsource-20091029/

var contentType = []string{"text/event-stream"}
var noCache = []string{"no-cache"}

var fieldReplacer = strings.NewReplacer(
	"\n", "\\n",
	"\r", "\\r")

var dataReplacer = strings.NewReplacer(
	"\n", "\n",
	"\r", "\\r")

type CustomEvent struct {
	Event string
	Id    string
	Retry uint
	Data  interface{}
}

func encode(writer io.Writer, event CustomEvent) error {
	w := checkWriter(writer)
	return writeData(w, event.Data)
}

func writeData(w stringWriter, data interface{}) error {
	dataReplacer.WriteString(w, fmt.Sprint(data))
	// Only string payloads starting with "data" get the SSE event terminator.
	// The checked type assertion avoids a panic when Data is not a string.
	if s, ok := data.(string); ok && strings.HasPrefix(s, "data") {
		w.writeString("\n\n")
	}
	return nil
}

func (r CustomEvent) Render(w http.ResponseWriter) error {
	r.WriteContentType(w)
	return encode(w, r)
}

func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
	header := w.Header()
	header["Content-Type"] = contentType

	if _, exist := header["Cache-Control"]; !exist {
		header["Cache-Control"] = noCache
	}
}
```

## /common/database.go

```go path="/common/database.go"
package common

var UsingSQLite = false
var UsingPostgreSQL = false
var UsingMySQL = false
var UsingClickHouse = false

var SQLitePath = "one-api.db?_busy_timeout=5000"
```

## /common/email-outlook-auth.go

```go path="/common/email-outlook-auth.go"
package common

import (
	"errors"
	"net/smtp"
	"strings"
)

type outlookAuth struct {
	username, password string
}

func LoginAuth(username, password string) smtp.Auth {
	return &outlookAuth{username, password}
}

func (a *outlookAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
	return "LOGIN", []byte{}, nil
}

func (a *outlookAuth) Next(fromServer []byte, more bool) ([]byte, error) {
	if more {
		switch string(fromServer) {
		case "Username:":
			return []byte(a.username), nil
		case "Password:":
			return []byte(a.password), nil
		default:
			return nil, errors.New("unknown fromServer")
		}
	}
	return nil, nil
}

func isOutlookServer(server string) bool {
	// Compatible with Outlook mailboxes in multiple regions and with ofb (onmicrosoft) mailboxes.
	// Ideally there should be an option to choose whether to log in with LOGIN auth;
	// this is a temporary compatibility workaround.
	return strings.Contains(server, "outlook") || strings.Contains(server, "onmicrosoft")
}
```

## /common/email.go

```go path="/common/email.go"
package common

import (
	"crypto/tls"
	"encoding/base64"
	"fmt"
	"net/smtp"
	"slices"
	"strings"
	"time"
)

func generateMessageID() (string, error) {
	split := strings.Split(SMTPFrom, "@")
	if len(split) < 2 {
		return "", fmt.Errorf("invalid SMTP account")
	}
	domain := split[1]
	return fmt.Sprintf("<%d.%s@%s>", time.Now().UnixNano(), GetRandomString(12), domain), nil
}

func SendEmail(subject string, receiver string, content string) error {
	if SMTPFrom == "" { // for compatibility
		SMTPFrom = SMTPAccount
	}
	id, err2 := generateMessageID()
	if err2 != nil {
		return err2
	}
	if SMTPServer == "" && SMTPAccount == "" {
		return fmt.Errorf("SMTP 服务器未配置")
	}
	encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
	mail := []byte(fmt.Sprintf("To: %s\r\n"+
		"From: %s<%s>\r\n"+
		"Subject: %s\r\n"+
		"Date: %s\r\n"+
		"Message-ID: %s\r\n"+ // add the Message-ID header
		"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
		receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), id, content))
	auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
	addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
	to := strings.Split(receiver, ";")
	var err error
	if SMTPPort == 465 || SMTPSSLEnabled {
		tlsConfig := &tls.Config{
			InsecureSkipVerify: true,
			ServerName:         SMTPServer,
		}
		conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig)
		if err != nil {
			return err
		}
		client, err := smtp.NewClient(conn, SMTPServer)
		if err != nil {
			return err
		}
		defer client.Close()
		if err = client.Auth(auth); err != nil {
			return err
		}
		if err = client.Mail(SMTPFrom); err != nil {
			return err
		}
		receiverEmails := strings.Split(receiver, ";")
		for _, receiver := range receiverEmails {
			if err = client.Rcpt(receiver); err != nil {
				return err
			}
		}
		w, err := client.Data()
		if err != nil {
			return err
		}
		_, err = w.Write(mail)
		if err != nil {
			return err
		}
		err = w.Close()
		if err != nil {
			return err
		}
	} else if isOutlookServer(SMTPAccount) || slices.Contains(EmailLoginAuthServerList, SMTPServer) {
		auth = LoginAuth(SMTPAccount, SMTPToken)
		err = smtp.SendMail(addr, auth, SMTPFrom, to, mail)
	} else {
		err = smtp.SendMail(addr, auth, SMTPFrom, to, mail)
	}
	return err
}
```

## /common/embed-file-system.go

```go path="/common/embed-file-system.go"
package common

import (
	"embed"
	"github.com/gin-contrib/static"
	"io/fs"
	"net/http"
)

// Credit: https://github.com/gin-contrib/static/issues/19

type embedFileSystem struct {
	http.FileSystem
}

func (e embedFileSystem) Exists(prefix string, path string) bool {
	f, err := e.Open(path)
	if err != nil {
		return false
	}
	// Close the probe handle so Exists does not leak open files.
	_ = f.Close()
	return true
}

func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {
	efs, err := fs.Sub(fsEmbed, targetPath)
	if err != nil {
		panic(err)
	}
	return embedFileSystem{
		FileSystem: http.FS(efs),
	}
}
```

## /common/env.go

```go path="/common/env.go"
package common

import (
	"fmt"
	"os"
	"strconv"
)

func GetEnvOrDefault(env string, defaultValue int) int {
	if env == "" || os.Getenv(env) == "" {
		return defaultValue
	}
	num, err := strconv.Atoi(os.Getenv(env))
	if err != nil {
		SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
		return defaultValue
	}
	return num
}

func GetEnvOrDefaultString(env string, defaultValue string) string {
	if env == "" || os.Getenv(env) == "" {
		return defaultValue
	}
	return os.Getenv(env)
}

func GetEnvOrDefaultBool(env string, defaultValue bool) bool {
	if env == "" || os.Getenv(env) == "" {
		return defaultValue
	}
	b, err := strconv.ParseBool(os.Getenv(env))
	if err != nil {
		SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %t", env, err.Error(), defaultValue))
		return defaultValue
	}
	return b
}
```

## /common/gin.go

```go path="/common/gin.go"
package common

import (
	"bytes"
	"encoding/json"
	"github.com/gin-gonic/gin"
	"io"
	"strings"
)

const KeyRequestBody = "key_request_body"

func GetRequestBody(c *gin.Context) ([]byte, error) {
	requestBody, _ := c.Get(KeyRequestBody)
	if requestBody != nil {
		return requestBody.([]byte), nil
	}
	requestBody, err := io.ReadAll(c.Request.Body)
	if err != nil {
		return nil, err
	}
	_ = c.Request.Body.Close()
	c.Set(KeyRequestBody, requestBody)
	return requestBody.([]byte), nil
}

func UnmarshalBodyReusable(c *gin.Context, v any) error {
	requestBody, err := GetRequestBody(c)
	if err != nil {
		return err
	}
	contentType := c.Request.Header.Get("Content-Type")
	if strings.HasPrefix(contentType, "application/json") {
		err = json.Unmarshal(requestBody, &v)
	} else {
		// skip for now
		// TODO: if non-JSON requests ever carry a model field, this branch will need to be implemented
	}
	if err != nil {
		return err
	}
	// Reset request body
	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
	return nil
}
```

## /common/go-channel.go

```go path="/common/go-channel.go"
package common

import (
	"time"
)

func SafeSendBool(ch chan bool, value bool) (closed bool) {
	defer func() {
		// Recover from panic if one occurred. A panic would mean the channel was closed.
		if recover() != nil {
			closed = true
		}
	}()

	// This will panic if the channel is closed.
	ch <- value

	// If the code reaches here, then the channel was not closed.
	return false
}

func SafeSendString(ch chan string, value string) (closed bool) {
	defer func() {
		// Recover from panic if one occurred.
		// A panic would mean the channel was closed.
		if recover() != nil {
			closed = true
		}
	}()

	// This will panic if the channel is closed.
	ch <- value

	// If the code reaches here, then the channel was not closed.
	return false
}

// SafeSendStringTimeout tries to send value on ch, waiting at most timeout seconds.
// It returns true if the value was sent, and false on timeout or if the channel is closed
// (despite its name, the named result here means "sent successfully").
func SafeSendStringTimeout(ch chan string, value string, timeout int) (closed bool) {
	defer func() {
		// Recover from panic if one occurred. A panic would mean the channel was closed.
		if recover() != nil {
			closed = false
		}
	}()

	// This will panic if the channel is closed.
	select {
	case ch <- value:
		return true
	case <-time.After(time.Duration(timeout) * time.Second):
		return false
	}
}
```

## /common/gopool.go

```go path="/common/gopool.go"
package common

import (
	"context"
	"fmt"
	"github.com/bytedance/gopkg/util/gopool"
	"math"
)

var relayGoPool gopool.Pool

func init() {
	relayGoPool = gopool.NewPool("gopool.RelayPool", math.MaxInt32, gopool.NewConfig())
	relayGoPool.SetPanicHandler(func(ctx context.Context, i interface{}) {
		if stopChan, ok := ctx.Value("stop_chan").(chan bool); ok {
			SafeSendBool(stopChan, true)
		}
		SysError(fmt.Sprintf("panic in gopool.RelayPool: %v", i))
	})
}

func RelayCtxGo(ctx context.Context, f func()) {
	relayGoPool.CtxGo(ctx, f)
}
```

## /common/init.go

```go path="/common/init.go"
package common

import (
	"flag"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"strconv"
	"time"
)

var (
	Port         = flag.Int("port", 3000, "the listening port")
	PrintVersion = flag.Bool("version", false, "print version and exit")
	PrintHelp    = flag.Bool("help", false, "print help and exit")
	LogDir       = flag.String("log-dir", "./logs", "specify the log directory")
)

func printHelp() {
	fmt.Println("New API " + Version + " - All in one API service for OpenAI API.")
	fmt.Println("Copyright (C) 2023 JustSong. All rights reserved.")
	fmt.Println("GitHub: https://github.com/songquanpeng/one-api")
	fmt.Println("Usage: one-api [--port <port>] [--log-dir <log dir>] [--version] [--help]")
}

func LoadEnv() {
	flag.Parse()

	if *PrintVersion {
		fmt.Println(Version)
		os.Exit(0)
	}

	if *PrintHelp {
		printHelp()
		os.Exit(0)
	}

	if os.Getenv("SESSION_SECRET") != "" {
		ss := os.Getenv("SESSION_SECRET")
		if ss == "random_string" {
			log.Println("WARNING: SESSION_SECRET is set to the default value 'random_string', please change it to a random string.")
			log.Println("警告:SESSION_SECRET被设置为默认值'random_string',请修改为随机字符串。")
			log.Fatal("Please set SESSION_SECRET to a random string.")
		} else {
			SessionSecret = ss
		}
	}
	if os.Getenv("CRYPTO_SECRET") != "" {
		CryptoSecret = os.Getenv("CRYPTO_SECRET")
	} else {
		CryptoSecret = SessionSecret
	}
	if os.Getenv("SQLITE_PATH") != "" {
		SQLitePath = os.Getenv("SQLITE_PATH")
	}
	if *LogDir != "" {
		var err error
		*LogDir, err = filepath.Abs(*LogDir)
		if err != nil {
			log.Fatal(err)
		}
		if _, err := os.Stat(*LogDir); os.IsNotExist(err) {
			err = os.Mkdir(*LogDir, 0777)
			if err != nil {
				log.Fatal(err)
			}
		}
	}

	// Initialize variables from constants.go that were using environment variables
	DebugEnabled = os.Getenv("DEBUG") == "true"
	MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
	IsMasterNode = os.Getenv("NODE_TYPE") != "slave"

	// Parse requestInterval and set RequestInterval
	requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
	RequestInterval = time.Duration(requestInterval) * time.Second

	// Initialize variables with GetEnvOrDefault
	SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
	BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
	RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)

	// Initialize string variables with GetEnvOrDefaultString
	GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
	CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")

	// Initialize rate limit variables
	GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
	GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
	GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))

	GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
	GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
	GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
}
```

## /common/json.go

```go path="/common/json.go"
package common

import (
	"bytes"
	"encoding/json"
)

func DecodeJson(data []byte, v any) error {
	return json.NewDecoder(bytes.NewReader(data)).Decode(v)
}

func DecodeJsonStr(data string, v any) error {
	return DecodeJson(StringToByteSlice(data), v)
}

func EncodeJson(v any) ([]byte, error) {
	return json.Marshal(v)
}
```

## /common/limiter/limiter.go

```go path="/common/limiter/limiter.go"
package limiter

import (
	"context"
	_ "embed"
	"fmt"
	"github.com/go-redis/redis/v8"
	"one-api/common"
	"sync"
)

//go:embed lua/rate_limit.lua
var rateLimitScript string

type RedisLimiter struct {
	client         *redis.Client
	limitScriptSHA string
}

var (
	instance *RedisLimiter
	once     sync.Once
)

func New(ctx context.Context, r *redis.Client) *RedisLimiter {
	once.Do(func() {
		// Preload the Lua script so later calls can use EVALSHA
		limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result()
		if err != nil {
			common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
		}
		instance = &RedisLimiter{
			client:         r,
			limitScriptSHA: limitSHA,
		}
	})

	return instance
}

func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) {
	// Default configuration
	config := &Config{
		Capacity:  10,
		Rate:      1,
		Requested: 1,
	}

	// Apply the functional options
	for _, opt := range opts {
		opt(config)
	}

	// Run the rate-limit script
	result, err := rl.client.EvalSha(
		ctx,
		rl.limitScriptSHA,
		[]string{key},
		config.Requested,
		config.Rate,
		config.Capacity,
	).Int()

	if err != nil {
		return false, fmt.Errorf("rate limit failed: %w", err)
	}
	return result == 1, nil
}

// Config holds the settings applied through the functional options pattern
type Config struct {
	Capacity  int64
	Rate      int64
	Requested int64
}

type Option func(*Config)

func WithCapacity(c int64) Option {
	return func(cfg *Config) { cfg.Capacity = c }
}

func WithRate(r int64) Option {
	return func(cfg *Config) { cfg.Rate = r }
}

func WithRequested(n int64) Option {
	return func(cfg *Config) { cfg.Requested = n }
}
```

## /common/limiter/lua/rate_limit.lua

```lua path="/common/limiter/lua/rate_limit.lua"
-- Token bucket rate limiter
-- KEYS[1]: unique identifier of the limiter
-- ARGV[1]: number of tokens requested (usually 1)
-- ARGV[2]: token generation rate (per second)
-- ARGV[3]: bucket capacity

local key = KEYS[1]
local requested = tonumber(ARGV[1])
local rate = tonumber(ARGV[2])
local capacity = tonumber(ARGV[3])

-- Get the current time (Redis server time)
local now = redis.call('TIME')
local nowInSeconds = tonumber(now[1])

-- Get the bucket state
local bucket = redis.call('HMGET', key, 'tokens', 'last_time')
local tokens = tonumber(bucket[1])
local last_time = tonumber(bucket[2])

-- Initialize the bucket (first request or expired)
if not tokens or not last_time then
    tokens = capacity
    last_time = nowInSeconds
else
    -- Compute the newly generated tokens
    local elapsed = nowInSeconds - last_time
    local add_tokens = elapsed * rate
    tokens = math.min(capacity, tokens + add_tokens)
    last_time = nowInSeconds
end

-- Decide whether the request is allowed
local allowed = false
if tokens >= requested then
    tokens = tokens - requested
    allowed = true
end

---- Update the bucket state (key expiration is currently disabled; see the commented-out EXPIRE below)
redis.call('HMSET', key, 'tokens', tokens, 'last_time', last_time)
--redis.call('EXPIRE', key, math.ceil(capacity / rate) + 60) -- extend the expiration time appropriately

return allowed and 1 or 0
```

## /common/logger.go

```go path="/common/logger.go"
package common

import (
	"context"
	"encoding/json"
	"fmt"
	"github.com/bytedance/gopkg/util/gopool"
	"github.com/gin-gonic/gin"
	"io"
	"log"
	"os"
	"path/filepath"
	"sync"
	"time"
)

const (
	loggerINFO  = "INFO"
	loggerWarn  = "WARN"
	loggerError = "ERR"
)

const maxLogCount = 1000000

var logCount int
var setupLogLock sync.Mutex
var setupLogWorking bool

func SetupLogger() {
	if *LogDir != "" {
		ok := setupLogLock.TryLock()
		if !ok {
			log.Println("setup log is already working")
			return
		}
		defer func() {
			setupLogLock.Unlock()
			setupLogWorking = false
		}()
		logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
		if err != nil {
			log.Fatal("failed to open log file")
		}
		gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
	}
}

func SysLog(s string) {
	t := time.Now()
	_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}

func SysError(s string) {
	t := time.Now()
	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}

func LogInfo(ctx context.Context, msg string) {
	logHelper(ctx, loggerINFO, msg)
}

func LogWarn(ctx context.Context, msg string) {
	logHelper(ctx, loggerWarn, msg)
}

func LogError(ctx context.Context, msg string) {
	logHelper(ctx, loggerError, msg)
}

func logHelper(ctx context.Context, level string, msg string) {
	writer := gin.DefaultErrorWriter
	if level == loggerINFO {
		writer = gin.DefaultWriter
	}
	id := ctx.Value(RequestIdKey)
	now := time.Now()
	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
	logCount++ // we don't need accurate count, so no lock here
	if logCount > maxLogCount && !setupLogWorking {
		logCount = 0
		setupLogWorking = true
		gopool.Go(func() {
			SetupLogger()
		})
	}
}

func FatalLog(v ...any) {
	t := time.Now()
	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
	os.Exit(1)
}

func LogQuota(quota int) string {
	if DisplayInCurrencyEnabled {
		return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
	} else {
		return fmt.Sprintf("%d 点额度", quota)
	}
}

func FormatQuota(quota int) string {
	if DisplayInCurrencyEnabled {
		return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit)
	} else {
		return fmt.Sprintf("%d", quota)
	}
}

// LogJson is intended for testing only
func LogJson(ctx context.Context, msg string, obj any) {
	jsonStr, err := json.Marshal(obj)
	if err != nil {
		LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
		return
	}
	LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
}
```

## /common/pprof.go

```go path="/common/pprof.go"
package common

import (
	"fmt"
	"github.com/shirou/gopsutil/cpu"
	"os"
	"runtime/pprof"
	"time"
)

// Monitor periodically checks CPU usage and writes a pprof file when usage exceeds the threshold
func Monitor() {
	for {
		percent, err := cpu.Percent(time.Second, false)
		if err != nil {
			panic(err)
		}
		if percent[0] > 80 {
			fmt.Println("cpu usage too high")
			// write pprof file
			if _, err := os.Stat("./pprof"); os.IsNotExist(err) {
				err := os.Mkdir("./pprof", os.ModePerm)
				if err != nil {
					SysLog("创建pprof文件夹失败 " + err.Error())
					continue
				}
			}
			f, err := os.Create("./pprof/" + fmt.Sprintf("cpu-%s.pprof", time.Now().Format("20060102150405")))
			if err != nil {
				SysLog("创建pprof文件失败 " + err.Error())
				continue
			}
			err = pprof.StartCPUProfile(f)
			if err != nil {
				SysLog("启动pprof失败 " + err.Error())
				_ = f.Close() // don't leak the file handle when profiling fails to start
				continue
			}
			time.Sleep(10 * time.Second) // profile for 10 seconds
			pprof.StopCPUProfile()
			f.Close()
		}
		time.Sleep(30 * time.Second)
	}
}
```

## /common/rate-limit.go

```go path="/common/rate-limit.go"
package common

import (
	"sync"
	"time"
)

type InMemoryRateLimiter struct {
	store              map[string]*[]int64
	mutex              sync.Mutex
	expirationDuration time.Duration
}

func (l *InMemoryRateLimiter) Init(expirationDuration time.Duration) {
	if l.store == nil {
		l.mutex.Lock()
		if l.store == nil {
			l.store = make(map[string]*[]int64)
			l.expirationDuration = expirationDuration
			if expirationDuration > 0 {
				go l.clearExpiredItems()
			}
		}
		l.mutex.Unlock()
	}
}

func (l *InMemoryRateLimiter) clearExpiredItems() {
	for {
		time.Sleep(l.expirationDuration)
		l.mutex.Lock()
		now := time.Now().Unix()
		for key := range l.store {
			queue := l.store[key]
			size := len(*queue)
			if size == 0 || now-(*queue)[size-1] > int64(l.expirationDuration.Seconds()) {
				delete(l.store, key)
			}
		}
		l.mutex.Unlock()
	}
}

// Request parameter duration's unit is seconds
func (l *InMemoryRateLimiter) Request(key string, maxRequestNum int, duration int64)
bool { l.mutex.Lock() defer l.mutex.Unlock() // [old <-- new] queue, ok := l.store[key] now := time.Now().Unix() if ok { if len(*queue) < maxRequestNum { *queue = append(*queue, now) return true } else { if now-(*queue)[0] >= duration { *queue = (*queue)[1:] *queue = append(*queue, now) return true } else { return false } } } else { s := make([]int64, 0, maxRequestNum) l.store[key] = &s *(l.store[key]) = append(*(l.store[key]), now) } return true } ``` ## /common/redis.go ```go path="/common/redis.go" package common import ( "context" "errors" "fmt" "os" "reflect" "strconv" "time" "github.com/go-redis/redis/v8" "gorm.io/gorm" ) var RDB *redis.Client var RedisEnabled = true // InitRedisClient This function is called after init() func InitRedisClient() (err error) { if os.Getenv("REDIS_CONN_STRING") == "" { RedisEnabled = false SysLog("REDIS_CONN_STRING not set, Redis is not enabled") return nil } if os.Getenv("SYNC_FREQUENCY") == "" { SysLog("SYNC_FREQUENCY not set, use default value 60") SyncFrequency = 60 } SysLog("Redis is enabled") opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) if err != nil { FatalLog("failed to parse Redis connection string: " + err.Error()) } opt.PoolSize = GetEnvOrDefault("REDIS_POOL_SIZE", 10) RDB = redis.NewClient(opt) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _, err = RDB.Ping(ctx).Result() if err != nil { FatalLog("Redis ping test failed: " + err.Error()) } if DebugEnabled { SysLog(fmt.Sprintf("Redis connected to %s", opt.Addr)) SysLog(fmt.Sprintf("Redis database: %d", opt.DB)) } return err } func ParseRedisOption() *redis.Options { opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) if err != nil { FatalLog("failed to parse Redis connection string: " + err.Error()) } return opt } func RedisSet(key string, value string, expiration time.Duration) error { if DebugEnabled { SysLog(fmt.Sprintf("Redis SET: key=%s, value=%s, expiration=%v", key, value, expiration)) } ctx := 
context.Background() return RDB.Set(ctx, key, value, expiration).Err() } func RedisGet(key string) (string, error) { if DebugEnabled { SysLog(fmt.Sprintf("Redis GET: key=%s", key)) } ctx := context.Background() val, err := RDB.Get(ctx, key).Result() return val, err } //func RedisExpire(key string, expiration time.Duration) error { // ctx := context.Background() // return RDB.Expire(ctx, key, expiration).Err() //} // //func RedisGetEx(key string, expiration time.Duration) (string, error) { // ctx := context.Background() // return RDB.GetSet(ctx, key, expiration).Result() //} func RedisDel(key string) error { if DebugEnabled { SysLog(fmt.Sprintf("Redis DEL: key=%s", key)) } ctx := context.Background() return RDB.Del(ctx, key).Err() } func RedisHDelObj(key string) error { if DebugEnabled { SysLog(fmt.Sprintf("Redis HDEL: key=%s", key)) } ctx := context.Background() return RDB.HDel(ctx, key).Err() } func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error { if DebugEnabled { SysLog(fmt.Sprintf("Redis HSET: key=%s, obj=%+v, expiration=%v", key, obj, expiration)) } ctx := context.Background() data := make(map[string]interface{}) // 使用反射遍历结构体字段 v := reflect.ValueOf(obj).Elem() t := v.Type() for i := 0; i < v.NumField(); i++ { field := t.Field(i) value := v.Field(i) // Skip DeletedAt field if field.Type.String() == "gorm.DeletedAt" { continue } // 处理指针类型 if value.Kind() == reflect.Ptr { if value.IsNil() { data[field.Name] = "" continue } value = value.Elem() } // 处理布尔类型 if value.Kind() == reflect.Bool { data[field.Name] = strconv.FormatBool(value.Bool()) continue } // 其他类型直接转换为字符串 data[field.Name] = fmt.Sprintf("%v", value.Interface()) } txn := RDB.TxPipeline() txn.HSet(ctx, key, data) txn.Expire(ctx, key, expiration) _, err := txn.Exec(ctx) if err != nil { return fmt.Errorf("failed to execute transaction: %w", err) } return nil } func RedisHGetObj(key string, obj interface{}) error { if DebugEnabled { SysLog(fmt.Sprintf("Redis HGETALL: key=%s", key)) 
	}
	ctx := context.Background()

	result, err := RDB.HGetAll(ctx, key).Result()
	if err != nil {
		return fmt.Errorf("failed to load hash from Redis: %w", err)
	}

	if len(result) == 0 {
		return fmt.Errorf("key %s not found in Redis", key)
	}

	// obj must be a pointer to a struct
	val := reflect.ValueOf(obj)
	if val.Kind() != reflect.Ptr {
		return fmt.Errorf("obj must be a pointer to a struct, got %T", obj)
	}

	v := val.Elem()
	if v.Kind() != reflect.Struct {
		return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface())
	}

	t := v.Type()
	for i := 0; i < v.NumField(); i++ {
		field := t.Field(i)
		fieldName := field.Name
		if value, ok := result[fieldName]; ok {
			fieldValue := v.Field(i)

			// Handle pointer types
			if fieldValue.Kind() == reflect.Ptr {
				if value == "" {
					continue
				}
				if fieldValue.IsNil() {
					fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
				}
				fieldValue = fieldValue.Elem()
			}

			// Enhanced type handling for Token struct
			switch fieldValue.Kind() {
			case reflect.String:
				fieldValue.SetString(value)
			case reflect.Int, reflect.Int64:
				intValue, err := strconv.ParseInt(value, 10, 64)
				if err != nil {
					return fmt.Errorf("failed to parse int field %s: %w", fieldName, err)
				}
				fieldValue.SetInt(intValue)
			case reflect.Bool:
				boolValue, err := strconv.ParseBool(value)
				if err != nil {
					return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err)
				}
				fieldValue.SetBool(boolValue)
			case reflect.Struct:
				// Special handling for gorm.DeletedAt
				if fieldValue.Type().String() == "gorm.DeletedAt" {
					if value != "" {
						timeValue, err := time.Parse(time.RFC3339, value)
						if err != nil {
							return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err)
						}
						fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true}))
					}
				}
			default:
				return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName)
			}
		}
	}

	return nil
}

// RedisIncr adds delta to the integer stored at key inside a MULTI/EXEC
// transaction and restores the key's existing TTL. Keys that do not exist
// or have no TTL are left unchanged.
func RedisIncr(key string, delta int64) error {
	if DebugEnabled {
		SysLog(fmt.Sprintf("Redis INCR: key=%s, delta=%d", key, delta))
	}
	// Check the key's remaining time to live
	ttlCmd := RDB.TTL(context.Background(), key)
	ttl, err := ttlCmd.Result()
	if err != nil && !errors.Is(err, redis.Nil) {
		return fmt.Errorf("failed to get TTL: %w", err)
	}

	// Only update the key when it exists and has a TTL
	if ttl > 0 {
		ctx := context.Background()
		// Start a Redis transaction
		txn := RDB.TxPipeline()

		// Apply delta to the balance
		decrCmd := txn.IncrBy(ctx, key, delta)
		if err := decrCmd.Err(); err != nil {
			return err // return the error immediately if the command failed
		}

		// Restore the original TTL
		txn.Expire(ctx, key, ttl)

		// Execute the transaction
		_, err = txn.Exec(ctx)
		return err
	}
	return nil
}

func RedisHIncrBy(key, field string, delta int64) error {
	if DebugEnabled {
		SysLog(fmt.Sprintf("Redis HINCRBY: key=%s, field=%s, delta=%d", key, field, delta))
	}
	ttlCmd := RDB.TTL(context.Background(), key)
	ttl, err := ttlCmd.Result()
	if err != nil && !errors.Is(err, redis.Nil) {
		return fmt.Errorf("failed to get TTL: %w", err)
	}

	if ttl > 0 {
		ctx := context.Background()
		txn := RDB.TxPipeline()

		incrCmd := txn.HIncrBy(ctx, key, field, delta)
		if err := incrCmd.Err(); err != nil {
			return err
		}

		txn.Expire(ctx, key, ttl)

		_, err = txn.Exec(ctx)
		return err
	}
	return nil
}

func RedisHSetField(key, field string, value interface{}) error {
	if DebugEnabled {
		SysLog(fmt.Sprintf("Redis HSET field: key=%s, field=%s, value=%v", key, field, value))
	}
	ttlCmd := RDB.TTL(context.Background(), key)
	ttl, err := ttlCmd.Result()
	if err != nil && !errors.Is(err, redis.Nil) {
		return fmt.Errorf("failed to get TTL: %w", err)
	}

	if ttl > 0 {
		ctx := context.Background()
		txn := RDB.TxPipeline()

		hsetCmd := txn.HSet(ctx, key, field, value)
		if err := hsetCmd.Err(); err != nil {
			return err
		}

		txn.Expire(ctx, key, ttl)

		_, err = txn.Exec(ctx)
		return err
	}
	return nil
}
```

## /common/str.go

```go path="/common/str.go"
package common

import (
	"encoding/json"
	"math/rand"
	"strconv"
	"unsafe"
)

func GetStringIfEmpty(str string, defaultValue string) string {
	if str == "" {
		return defaultValue
	}
	return str
}

func GetRandomString(length int) string {
	//rand.Seed(time.Now().UnixNano())
	key := make([]byte, length)
	for i := 0; i < length; i++ {
		key[i] = keyChars[rand.Intn(len(keyChars))]
	}
	return string(key)
}

func MapToJsonStr(m map[string]interface{}) string {
	bytes, err := json.Marshal(m)
	if err != nil {
		return ""
	}
	return string(bytes)
}

func StrToMap(str string) map[string]interface{} {
	m := make(map[string]interface{})
	err := json.Unmarshal([]byte(str), &m)
	if err != nil {
		return nil
	}
	return m
}

func IsJsonStr(str string) bool {
	var js map[string]interface{}
	return json.Unmarshal([]byte(str), &js) == nil
}

func String2Int(str string) int {
	num, err := strconv.Atoi(str)
	if err != nil {
		return 0
	}
	return num
}

func StringsContains(strs []string, str string) bool {
	for _, s := range strs {
		if s == str {
			return true
		}
	}
	return false
}

// StringToByteSlice converts s to a []byte without copying. The result shares
// memory with s and must be treated as read-only: writing to it is undefined behavior.
func StringToByteSlice(s string) []byte {
	tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
	tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
	return *(*[]byte)(unsafe.Pointer(&tmp2))
}
```

## /common/topup-ratio.go

```go path="/common/topup-ratio.go"
package common

import (
	"encoding/json"
)

var TopupGroupRatio = map[string]float64{
	"default": 1,
	"vip":     1,
	"svip":    1,
}

func TopupGroupRatio2JSONString() string {
	jsonBytes, err := json.Marshal(TopupGroupRatio)
	if err != nil {
		SysError("error marshalling topup group ratio: " + err.Error())
	}
	return string(jsonBytes)
}

func UpdateTopupGroupRatioByJSONString(jsonStr string) error {
	// Decode into a fresh map so the current ratios survive a malformed input
	ratio := make(map[string]float64)
	if err := json.Unmarshal([]byte(jsonStr), &ratio); err != nil {
		return err
	}
	TopupGroupRatio = ratio
	return nil
}

func GetTopupGroupRatio(name string) float64 {
	ratio, ok := TopupGroupRatio[name]
	if !ok {
		SysError("topup group ratio not found: " + name)
		return 1
	}
	return ratio
}
```

## /common/utils.go

```go path="/common/utils.go"
package common

import (
	"bytes"
	"context"
	crand "crypto/rand"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"html/template"
	"io"
	"log"
	"math/big"
	"math/rand"
	"net"
	"os"
	"os/exec"
	"runtime"
	"strconv"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/pkg/errors"
)

func OpenBrowser(url string) {
	var err error

	switch runtime.GOOS {
	case "linux":
		err = exec.Command("xdg-open", url).Start()
	case "windows":
		err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
	case "darwin":
		err = exec.Command("open", url).Start()
	}
	if err != nil {
		log.Println(err)
	}
}

// GetIp returns the first private IPv4 address of a local network interface, or "" if none is found.
func GetIp() (ip string) {
	ips, err := net.InterfaceAddrs()
	if err != nil {
		log.Println(err)
		return ip
	}

	for _, a := range ips {
		if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
			// For IPv4, IsPrivate matches 10.0.0.0/8, 172.16.0.0/12 and 192.168.0.0/16
			if ipNet.IP.To4() != nil && ipNet.IP.IsPrivate() {
				ip = ipNet.IP.String()
				return
			}
		}
	}
	return
}

var sizeKB = 1024
var sizeMB = sizeKB * 1024
var sizeGB = sizeMB * 1024

func Bytes2Size(num int64) string {
	numStr := ""
	unit := "B"
	if num/int64(sizeGB) >= 1 {
		numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
		unit = "GB"
	} else if num/int64(sizeMB) >= 1 {
		numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
		unit = "MB"
	} else if num/int64(sizeKB) >= 1 {
		numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
		unit = "KB"
	} else {
		numStr = fmt.Sprintf("%d", num)
	}
	return numStr + " " + unit
}

func Seconds2Time(num int) (time string) {
	if num/31104000 > 0 {
		time += strconv.Itoa(num/31104000) + " 年 " // years (360 days)
		num %= 31104000
	}
	if num/2592000 > 0 {
		time += strconv.Itoa(num/2592000) + " 个月 " // months (30 days)
		num %= 2592000
	}
	if num/86400 > 0 {
		time += strconv.Itoa(num/86400) + " 天 " // days
		num %= 86400
	}
	if num/3600 > 0 {
		time += strconv.Itoa(num/3600) + " 小时 " // hours
		num %= 3600
	}
	if num/60 > 0 {
		time += strconv.Itoa(num/60) + " 分钟 " // minutes
		num %= 60
	}
	time += strconv.Itoa(num) + " 秒" // seconds
	return
}

func Interface2String(inter interface{}) string {
	switch v := inter.(type) {
	case string:
		return v
	case int:
		return fmt.Sprintf("%d", v)
	case float64:
		return fmt.Sprintf("%f", v)
	}
	return "Not Implemented"
}

func UnescapeHTML(x string) interface{} {
	return template.HTML(x)
}

func IntMax(a int, b int) int {
	if a >= b {
		return a
	} else {
		return b
	}
}

func IsIP(s string) bool {
	ip := net.ParseIP(s)
	return ip != nil
}

func GetUUID() string {
	code := uuid.New().String()
	code = strings.Replace(code, "-", "", -1)
	return code
}

const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

// Since Go 1.20 the top-level math/rand functions are seeded randomly at
// program start, so no explicit seeding is needed here.

func GenerateRandomCharsKey(length int) (string, error) {
	b := make([]byte, length)
	maxI := big.NewInt(int64(len(keyChars)))

	for i := range b {
		n, err := crand.Int(crand.Reader, maxI)
		if err != nil {
			return "", err
		}
		b[i] = keyChars[n.Int64()]
	}

	return string(b), nil
}

func GenerateRandomKey(length int) (string, error) {
	bytes := make([]byte, length*3/4) // for a 48-character output this is 36 bytes
	if _, err := crand.Read(bytes); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(bytes), nil
}

func GenerateKey() (string, error) {
	//rand.Seed(time.Now().UnixNano())
	return GenerateRandomCharsKey(48)
}

func GetRandomInt(max int) int {
	//rand.Seed(time.Now().UnixNano())
	return rand.Intn(max)
}

func GetTimestamp() int64 {
	return time.Now().Unix()
}

func GetTimeString() string {
	now := time.Now()
	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}

func Max(a int, b int) int {
	if a >= b {
		return a
	} else {
		return b
	}
}

func MessageWithRequestId(message string, id string) string {
	return fmt.Sprintf("%s (request id: %s)", message, id)
}

func RandomSleep() {
	// Sleep for 0-3000 ms
	time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
}

func GetPointer[T any](v T) *T {
	return &v
}

func Any2Type[T any](data any) (T, error) {
	var zero T
	bytes, err := json.Marshal(data)
	if err != nil {
		return zero, err
	}
	var res T
	err = json.Unmarshal(bytes, &res)
	if err != nil {
		return zero, err
	}
	return res, nil
}

// SaveTmpFile saves data to a temporary file.
// A random string is appended to the filename.
func SaveTmpFile(filename string, data io.Reader) (string, error) {
	f, err := os.CreateTemp(os.TempDir(), filename)
	if err != nil {
		return "", errors.Wrapf(err, "failed to create temporary file %s", filename)
	}
	defer f.Close()

	_, err = io.Copy(f, data)
	if err != nil {
		return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename)
	}

	return f.Name(), nil
}

// GetAudioDuration returns the duration of an audio file in seconds.
// It runs ffprobe, which must be available on PATH.
func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
	// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
	c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
	output, err := c.Output()
	if err != nil {
		return 0, errors.Wrap(err, "failed to get audio duration")
	}

	return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
}
```

## /common/validate.go

```go path="/common/validate.go"
package common

import "github.com/go-playground/validator/v10"

var Validate *validator.Validate

func init() {
	Validate = validator.New()
}
```

## /common/verification.go

```go path="/common/verification.go"
package common

import (
	"github.com/google/uuid"
	"strings"
	"sync"
	"time"
)

type verificationValue struct {
	code string
	time time.Time
}

const (
	EmailVerificationPurpose = "v"
	PasswordResetPurpose     = "r"
)

var verificationMutex sync.Mutex
var verificationMap map[string]verificationValue
var verificationMapMaxSize = 10
var VerificationValidMinutes = 10

func GenerateVerificationCode(length int) string {
	code := uuid.New().String()
	code = strings.Replace(code, "-", "", -1)
	if length == 0 {
		return code
	}
	return code[:length]
}

func RegisterVerificationCodeWithKey(key string, code string, purpose string) {
	verificationMutex.Lock()
	defer verificationMutex.Unlock()
	verificationMap[purpose+key] = verificationValue{
		code: code,
		time: time.Now(),
	}
	if len(verificationMap) > verificationMapMaxSize {
		removeExpiredPairs()
	}
}

func VerifyCodeWithKey(key string, code string, purpose string) bool {
	verificationMutex.Lock()
	defer verificationMutex.Unlock()
	value, okay := verificationMap[purpose+key]
	now := time.Now()
	if !okay || int(now.Sub(value.time).Seconds()) >= VerificationValidMinutes*60 {
		return false
	}
	return code == value.code
}

func DeleteKey(key string, purpose string) {
	verificationMutex.Lock()
	defer verificationMutex.Unlock()
	delete(verificationMap, purpose+key)
}

// no lock inside, so the caller must lock the verificationMap before calling!
func removeExpiredPairs() {
	now := time.Now()
	for key := range verificationMap {
		if int(now.Sub(verificationMap[key].time).Seconds()) >= VerificationValidMinutes*60 {
			delete(verificationMap, key)
		}
	}
}

func init() {
	verificationMutex.Lock()
	defer verificationMutex.Unlock()
	verificationMap = make(map[string]verificationValue)
}
```

## /constant/cache_key.go

```go path="/constant/cache_key.go"
package constant

import "one-api/common"

var (
	TokenCacheSeconds         = common.SyncFrequency
	UserId2GroupCacheSeconds  = common.SyncFrequency
	UserId2QuotaCacheSeconds  = common.SyncFrequency
	UserId2StatusCacheSeconds = common.SyncFrequency
)

// Cache keys
const (
	UserGroupKeyFmt    = "user_group:%d"
	UserQuotaKeyFmt    = "user_quota:%d"
	UserEnabledKeyFmt  = "user_enabled:%d"
	UserUsernameKeyFmt = "user_name:%d"
)

const (
	TokenFiledRemainQuota = "RemainQuota"
	TokenFieldGroup       = "Group"
)
```

## /constant/channel_setting.go

```go path="/constant/channel_setting.go"
package constant

var (
	ForceFormat                     = "force_format"        // ForceFormat: force the response into OpenAI format
	ChanelSettingProxy              = "proxy"               // Proxy
	ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent
)
```

## /constant/context_key.go

```go path="/constant/context_key.go"
package constant

const (
	ContextKeyRequestStartTime = "request_start_time"
	ContextKeyUserSetting      = "user_setting"
	ContextKeyUserQuota        = "user_quota"
	ContextKeyUserStatus       = "user_status"
	ContextKeyUserEmail        = "user_email"
	ContextKeyUserGroup        = "user_group"
)
```

## /constant/env.go

```go path="/constant/env.go"
package constant

import (
	"one-api/common"
)

var StreamingTimeout int
var DifyDebug bool
var MaxFileDownloadMB int
var ForceStreamOption bool
var GetMediaToken bool
var GetMediaTokenNotStream bool
var UpdateTask bool
var AzureDefaultAPIVersion string
var GeminiVisionMaxImageNum int
var NotifyLimitCount int
var NotificationLimitDurationMinute int
var GenerateDefaultToken bool

//var GeminiModelMap = map[string]string{
//	"gemini-1.0-pro": "v1",
//}

func InitEnv() {
	StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
	DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
	MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
	// ForceStreamOption overrides the request parameters to force usage information to be returned
	ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
	GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
	GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
	UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
	AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
	GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
	NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
	NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
	// GenerateDefaultToken controls whether an initial token is generated; disabled by default.
	GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)

	//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
	//if modelVersionMapStr == "" {
	//	return
	//}
	//for _, pair := range strings.Split(modelVersionMapStr, ",") {
	//	parts := strings.Split(pair, ":")
	//	if len(parts) == 2 {
	//		GeminiModelMap[parts[0]] = parts[1]
	//	} else {
	//		common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
	//	}
	//}
}
```

## /constant/finish_reason.go

```go path="/constant/finish_reason.go"
package constant

var (
	FinishReasonStop          = "stop"
	FinishReasonToolCalls     = "tool_calls"
	FinishReasonLength        = "length"
	FinishReasonFunctionCall  = "function_call"
	FinishReasonContentFilter = "content_filter"
)
```

## /constant/midjourney.go

```go path="/constant/midjourney.go"
package constant

const (
	MjErrorUnknown = 5
	MjRequestError = 4
)

const (
	MjActionImagine       = "IMAGINE"
	MjActionDescribe      = "DESCRIBE"
	MjActionBlend         = "BLEND"
	MjActionUpscale       = "UPSCALE"
	MjActionVariation     = "VARIATION"
	MjActionReRoll        = "REROLL"
	MjActionInPaint       = "INPAINT"
	MjActionModal         = "MODAL"
	MjActionZoom          = "ZOOM"
	MjActionCustomZoom    = "CUSTOM_ZOOM"
	MjActionShorten       = "SHORTEN"
	MjActionHighVariation = "HIGH_VARIATION"
	MjActionLowVariation  = "LOW_VARIATION"
	MjActionPan           = "PAN"
	MjActionSwapFace      = "SWAP_FACE"
	MjActionUpload        = "UPLOAD"
)

var MidjourneyModel2Action = map[string]string{
	"mj_imagine":        MjActionImagine,
	"mj_describe":       MjActionDescribe,
	"mj_blend":          MjActionBlend,
	"mj_upscale":        MjActionUpscale,
	"mj_variation":      MjActionVariation,
	"mj_reroll":         MjActionReRoll,
	"mj_modal":          MjActionModal,
	"mj_inpaint":        MjActionInPaint,
	"mj_zoom":           MjActionZoom,
	"mj_custom_zoom":    MjActionCustomZoom,
	"mj_shorten":        MjActionShorten,
	"mj_high_variation": MjActionHighVariation,
	"mj_low_variation":  MjActionLowVariation,
	"mj_pan":            MjActionPan,
	"swap_face":         MjActionSwapFace,
	"mj_upload":         MjActionUpload,
}
```

## /constant/setup.go

```go path="/constant/setup.go"
package constant

var Setup = false
```

## /constant/task.go

```go path="/constant/task.go"
package constant

type TaskPlatform string

const (
	TaskPlatformSuno       TaskPlatform = "suno"
	TaskPlatformMidjourney              = "mj"
)

const (
	SunoActionMusic  = "MUSIC"
	SunoActionLyrics = "LYRICS"
)

var SunoModel2Action = map[string]string{
	"suno_music":  SunoActionMusic,
	"suno_lyrics": SunoActionLyrics,
}
```

## /constant/user_setting.go

```go path="/constant/user_setting.go"
package constant

var (
	UserSettingNotifyType            = "notify_type"                    // QuotaWarningType: quota warning notification type
	UserSettingQuotaWarningThreshold = "quota_warning_threshold"        // QuotaWarningThreshold: quota warning threshold
	UserSettingWebhookUrl            = "webhook_url"                    // WebhookUrl: webhook URL
	UserSettingWebhookSecret         = "webhook_secret"                 // WebhookSecret: webhook secret
	UserSettingNotificationEmail     = "notification_email"             // NotificationEmail: notification email address
	UserAcceptUnsetRatioModel        = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel: whether to accept models with no price configured
)

var (
	NotifyTypeEmail   = "email"   // Email
	NotifyTypeWebhook = "webhook" // Webhook
)
```

## /controller/billing.go

```go path="/controller/billing.go"
package controller

import (
	"github.com/gin-gonic/gin"
	"one-api/common"
	"one-api/dto"
	"one-api/model"
)

func GetSubscription(c *gin.Context) {
	var remainQuota int
	var usedQuota int
	var err error
	var token *model.Token
	var expiredTime int64
	if common.DisplayTokenStatEnabled {
		tokenId := c.GetInt("token_id")
		token, err = model.GetTokenById(tokenId)
		// Do not dereference token if the lookup failed
		if err == nil {
			expiredTime = token.ExpiredTime
			remainQuota = token.RemainQuota
			usedQuota = token.UsedQuota
		}
	} else {
		userId := c.GetInt("id")
		remainQuota, err = model.GetUserQuota(userId, false)
		// Do not let the second lookup overwrite an error from the first
		if err == nil {
			usedQuota, err = model.GetUserUsedQuota(userId)
		}
	}
	if expiredTime <= 0 {
		expiredTime = 0
	}
	if err != nil {
		openAIError := dto.OpenAIError{
			Message: err.Error(),
			Type:    "upstream_error",
		}
		c.JSON(200, gin.H{
			"error": openAIError,
		})
		return
	}
	quota := remainQuota + usedQuota
	amount := float64(quota)
	if common.DisplayInCurrencyEnabled {
		amount /= common.QuotaPerUnit
	}
	if token != nil && token.UnlimitedQuota {
		amount = 100000000
	}
	subscription := OpenAISubscriptionResponse{
		Object:             "billing_subscription",
		HasPaymentMethod:   true,
		SoftLimitUSD:       amount,
		HardLimitUSD:       amount,
		SystemHardLimitUSD: amount,
		AccessUntil:        expiredTime,
	}
	c.JSON(200, subscription)
}

func GetUsage(c *gin.Context) {
	var quota int
	var err error
	var token *model.Token
	if common.DisplayTokenStatEnabled {
		tokenId := c.GetInt("token_id")
		token, err = model.GetTokenById(tokenId)
		// Do not dereference token if the lookup failed
		if err == nil {
			quota = token.UsedQuota
		}
	} else {
		userId := c.GetInt("id")
		quota, err = model.GetUserUsedQuota(userId)
	}
	if err != nil {
		openAIError := dto.OpenAIError{
			Message: err.Error(),
			Type:    "new_api_error",
		}
		c.JSON(200, gin.H{
			"error": openAIError,
		})
		return
	}
	amount := float64(quota)
	if common.DisplayInCurrencyEnabled {
		amount /= common.QuotaPerUnit
	}
	usage := OpenAIUsageResponse{
		Object:     "list",
		TotalUsage: amount * 100,
	}
	c.JSON(200, usage)
}
```

## /controller/channel-billing.go

```go path="/controller/channel-billing.go"
package controller

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"one-api/common"
	"one-api/model"
	"one-api/service"
	"strconv"
	"time"

	"github.com/gin-gonic/gin"
)

// https://github.com/songquanpeng/one-api/issues/79

type OpenAISubscriptionResponse struct {
	Object             string  `json:"object"`
	HasPaymentMethod   bool    `json:"has_payment_method"`
	SoftLimitUSD       float64 `json:"soft_limit_usd"`
	HardLimitUSD       float64 `json:"hard_limit_usd"`
	SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
	AccessUntil        int64   `json:"access_until"`
}

type OpenAIUsageDailyCost struct {
	Timestamp float64 `json:"timestamp"`
	LineItems []struct {
		Name string  `json:"name"`
		Cost float64 `json:"cost"`
	}
}

type OpenAICreditGrants struct {
	Object         string  `json:"object"`
	TotalGranted   float64 `json:"total_granted"`
	TotalUsed      float64 `json:"total_used"`
	TotalAvailable float64 `json:"total_available"`
}

type OpenAIUsageResponse struct {
	Object string `json:"object"`
	//DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
	TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
}

type OpenAISBUsageResponse struct {
	Msg  string `json:"msg"`
	Data *struct {
		Credit string `json:"credit"`
	} `json:"data"`
}

type AIProxyUserOverviewResponse struct {
	Success   bool   `json:"success"`
	Message   string `json:"message"`
	ErrorCode int    `json:"error_code"`
	Data      struct {
		TotalPoints float64 `json:"totalPoints"`
	} `json:"data"`
}

type API2GPTUsageResponse struct {
	Object         string  `json:"object"`
	TotalGranted   float64 `json:"total_granted"`
	TotalUsed      float64 `json:"total_used"`
	TotalRemaining float64 `json:"total_remaining"`
}

type APGC2DGPTUsageResponse struct {
	//Grants interface{} `json:"grants"`
	Object         string  `json:"object"`
	TotalAvailable float64 `json:"total_available"`
	TotalGranted   float64 `json:"total_granted"`
	TotalUsed      float64 `json:"total_used"`
}

type SiliconFlowUsageResponse struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Status  bool   `json:"status"`
	Data    struct {
		ID            string `json:"id"`
		Name          string `json:"name"`
		Image         string `json:"image"`
		Email         string `json:"email"`
		IsAdmin       bool   `json:"isAdmin"`
		Balance       string `json:"balance"`
		Status        string `json:"status"`
		Introduction  string `json:"introduction"`
		Role          string `json:"role"`
		ChargeBalance string `json:"chargeBalance"`
		TotalBalance  string `json:"totalBalance"`
		Category      string `json:"category"`
	} `json:"data"`
}

type DeepSeekUsageResponse struct {
	IsAvailable  bool `json:"is_available"`
	BalanceInfos []struct {
		Currency        string `json:"currency"`
		TotalBalance    string `json:"total_balance"`
		GrantedBalance  string `json:"granted_balance"`
		ToppedUpBalance string `json:"topped_up_balance"`
	} `json:"balance_infos"`
}

// GetAuthHeader returns a header that carries token as a Bearer Authorization credential.
func GetAuthHeader(token string) http.Header {
	h := http.Header{}
	h.Add("Authorization", fmt.Sprintf("Bearer %s", token))
	return h
}

// GetResponseBody sends a body-less request and returns the response body when the status is 200 OK.
func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
	req, err := http.NewRequest(method, url, nil)
	if err != nil {
		return nil, err
	}
	for k := range headers {
		req.Header.Add(k, headers.Get(k))
	}
	res, err := service.GetHttpClient().Do(req)
	if err != nil {
		return nil, err
	}
	// Close the body on every path, including non-200 responses
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("status code: %d", res.StatusCode)
	}
	body, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, err
	}
	return body, nil
}

func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	response := OpenAICreditGrants{}
	err = json.Unmarshal(body, &response)
	if err != nil {
		return 0, err
	}
	channel.UpdateBalance(response.TotalAvailable)
	return response.TotalAvailable, nil
}

func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
	url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	response := OpenAISBUsageResponse{}
	err = json.Unmarshal(body, &response)
	if err != nil {
		return 0, err
	}
	if response.Data == nil {
		return 0, errors.New(response.Msg)
	}
	balance, err := strconv.ParseFloat(response.Data.Credit, 64)
	if err != nil {
		return 0, err
	}
	channel.UpdateBalance(balance)
	return balance, nil
}

func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) {
	url := "https://aiproxy.io/api/report/getUserOverview"
	headers := http.Header{}
	headers.Add("Api-Key", channel.Key)
	body, err := GetResponseBody("GET", url, channel, headers)
	if err != nil {
		return 0, err
	}
	response := AIProxyUserOverviewResponse{}
	err = json.Unmarshal(body, &response)
	if err != nil {
		return 0, err
	}
	if !response.Success {
		return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
	}
	channel.UpdateBalance(response.Data.TotalPoints)
	return response.Data.TotalPoints, nil
}

func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) {
	url := "https://api.api2gpt.com/dashboard/billing/credit_grants"
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	response := API2GPTUsageResponse{}
	err = json.Unmarshal(body, &response)
	if err != nil {
		return 0, err
	}
	channel.UpdateBalance(response.TotalRemaining)
	return response.TotalRemaining, nil
}

func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
	url := "https://api.siliconflow.cn/v1/user/info"
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	response := SiliconFlowUsageResponse{}
	err = json.Unmarshal(body, &response)
	if err != nil {
		return 0, err
	}
	if response.Code != 20000 {
		return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
	}
	balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64)
	if err != nil {
		return 0, err
	}
	channel.UpdateBalance(balance)
	return balance, nil
}

func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) {
	url := "https://api.deepseek.com/user/balance"
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	response := DeepSeekUsageResponse{}
	err = json.Unmarshal(body, &response)
	if err != nil {
		return 0, err
	}
	index := -1
	for i, balanceInfo := range response.BalanceInfos {
		if balanceInfo.Currency == "CNY" {
			index = i
			break
		}
	}
	if index == -1 {
		return 0, errors.New("currency CNY not found")
	}
	balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64)
	if err != nil {
		return 0, err
	}
	channel.UpdateBalance(balance)
	return balance, nil
}

func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
	url := "https://api.aigc2d.com/dashboard/billing/credit_grants"
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	response := APGC2DGPTUsageResponse{}
	err = json.Unmarshal(body, &response)
	if err != nil {
		return 0, err
	}
	channel.UpdateBalance(response.TotalAvailable)
	return response.TotalAvailable, nil
}

func updateChannelBalance(channel *model.Channel) (float64, error) {
	baseURL := common.ChannelBaseURLs[channel.Type]
	if channel.GetBaseURL() == "" {
		channel.BaseURL = &baseURL
	}
	switch channel.Type {
	case common.ChannelTypeOpenAI:
		if channel.GetBaseURL() != "" {
			baseURL = channel.GetBaseURL()
		}
	case common.ChannelTypeAzure:
		return 0, errors.New("尚未实现") // "not implemented yet"
	case common.ChannelTypeCustom:
		baseURL = channel.GetBaseURL()
	//case common.ChannelTypeOpenAISB:
	//	return updateChannelOpenAISBBalance(channel)
	case common.ChannelTypeAIProxy:
		return updateChannelAIProxyBalance(channel)
	case common.ChannelTypeAPI2GPT:
		return updateChannelAPI2GPTBalance(channel)
	case common.ChannelTypeAIGC2D:
		return updateChannelAIGC2DBalance(channel)
	case common.ChannelTypeSiliconFlow:
		return updateChannelSiliconFlowBalance(channel)
	case common.ChannelTypeDeepSeek:
		return updateChannelDeepSeekBalance(channel)
	default:
		return 0, errors.New("尚未实现")
	}
	url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)

	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	subscription := OpenAISubscriptionResponse{}
	err = json.Unmarshal(body, &subscription)
	if err != nil {
		return 0, err
	}
	now := time.Now()
	startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
	endDate := now.Format("2006-01-02")
	if !subscription.HasPaymentMethod {
		startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
	}
	url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
	body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		return 0, err
	}
	usage := OpenAIUsageResponse{}
	err = json.Unmarshal(body, &usage)
	if err != nil {
		return 0, err
	}
	// TotalUsage is reported in US cents, HardLimitUSD in dollars.
	balance := subscription.HardLimitUSD - usage.TotalUsage/100
	channel.UpdateBalance(balance)
	return balance, nil
}

func UpdateChannelBalance(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	channel, err := model.GetChannelById(id, true)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	balance, err := updateChannelBalance(channel)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"balance": balance,
	})
	return
}

func updateAllChannelsBalance() error {
	channels, err := model.GetAllChannels(0, 0, true, false)
	if err != nil {
		return err
	}
	for _, channel := range channels {
		if channel.Status != common.ChannelStatusEnabled {
			continue
		}
		// TODO: support Azure
		//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
		//	continue
		//}
		balance, err := updateChannelBalance(channel)
		if err != nil {
			continue
		}
		// A nil error with balance <= 0 means the quota is used up.
		if balance <= 0 {
			service.DisableChannel(channel.Id, channel.Name, "余额不足") // "insufficient balance"
		}
		time.Sleep(common.RequestInterval)
	}
	return nil
}

func UpdateAllChannelsBalance(c *gin.Context) {
	// TODO: make it async
	err := updateAllChannelsBalance()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
	return
}

func AutomaticallyUpdateChannels(frequency int) {
	for {
		time.Sleep(time.Duration(frequency) * time.Minute)
		common.SysLog("updating all channels")
		_ = updateAllChannelsBalance()
		common.SysLog("channels update done")
	}
}
```

## /controller/channel-test.go

```go path="/controller/channel-test.go"
package controller

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"math"
	"net/http"
	"net/http/httptest"
	"net/url"
	"one-api/common"
	"one-api/dto"
	"one-api/middleware"
	"one-api/model"
	"one-api/relay"
	relaycommon "one-api/relay/common"
	"one-api/relay/constant"
	"one-api/relay/helper"
	"one-api/service"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/bytedance/gopkg/util/gopool"
	"github.com/gin-gonic/gin"
)

func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
	tik := time.Now()
	if channel.Type == common.ChannelTypeMidjourney {
		return errors.New("midjourney channel test is not supported"), nil
	}
	if channel.Type == common.ChannelTypeMidjourneyPlus {
		return errors.New("midjourney plus channel test is not supported!!!"), nil
	}
	if channel.Type == common.ChannelTypeSunoAPI {
		return errors.New("suno channel test is not supported"), nil
	}
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)

	// Resolve the model under test first, so the embedding check below
	// classifies the resolved test model rather than an empty string.
	if testModel == "" {
		if channel.TestModel != nil && *channel.TestModel != "" {
			testModel = *channel.TestModel
		} else {
			if len(channel.GetModels()) > 0 {
				testModel = channel.GetModels()[0]
			} else {
				testModel = "gpt-4o-mini"
			}
		}
	}

	requestPath := "/v1/chat/completions"

	// Check for embedding models first.
	if strings.Contains(strings.ToLower(testModel), "embedding") ||
		strings.HasPrefix(testModel, "m3e") || // m3e family
		strings.Contains(testModel, "bge-") || // bge family
		strings.Contains(testModel, "embed") ||
		channel.Type == common.ChannelTypeMokaAI { // other embedding models
		requestPath = "/v1/embeddings" // switch to the embeddings endpoint
	}

	c.Request = &http.Request{
		Method: "POST",
		URL:    &url.URL{Path: requestPath}, // path selected above
		Body:   nil,
		Header: make(http.Header),
	}

	cache, err := model.GetUserCache(1)
	if err != nil {
		return err, nil
	}
	cache.WriteContext(c)

	c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
	c.Request.Header.Set("Content-Type", "application/json")
	c.Set("channel", channel.Type)
	c.Set("base_url", channel.GetBaseURL())
	group, _ := model.GetUserGroup(1, false)
	c.Set("group", group)

	middleware.SetupContextForSelectedChannel(c, channel, testModel)

	info := relaycommon.GenRelayInfo(c)

	err = helper.ModelMappedHelper(c, info)
	if err != nil {
		return err, nil
	}
	testModel = info.UpstreamModelName

	apiType, _ := constant.ChannelType2APIType(channel.Type)
	adaptor := relay.GetAdaptor(apiType)
	if adaptor == nil {
		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
	}

	request := buildTestRequest(testModel)
	// Log a copy of info with the ApiKey cleared.
	logInfo := *info
	logInfo.ApiKey = ""
	common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))

	priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
	if err != nil {
		return err, nil
	}

	adaptor.Init(info)

	convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
	if err != nil {
		return err, nil
	}
	jsonData, err := json.Marshal(convertedRequest)
	if err != nil {
		return err, nil
	}
	requestBody := bytes.NewBuffer(jsonData)
	c.Request.Body = io.NopCloser(requestBody)
	resp, err := adaptor.DoRequest(c, info, requestBody)
	if err != nil {
		return err, nil
	}
	var httpResp *http.Response
	if resp != nil {
		httpResp = resp.(*http.Response)
		if httpResp.StatusCode != http.StatusOK {
			err := service.RelayErrorHandler(httpResp, true)
			return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
		}
	}
	usageA, respErr := adaptor.DoResponse(c, httpResp, info)
	if respErr != nil {
		return fmt.Errorf("%s", respErr.Error.Message), respErr
	}
	if usageA == nil {
		return errors.New("usage is nil"), nil
	}
	usage := usageA.(*dto.Usage)
	result := w.Result()
	respBody, err := io.ReadAll(result.Body)
	if err != nil {
		return err, nil
	}
	info.PromptTokens = usage.PromptTokens

	quota := 0
	if !priceData.UsePrice {
		quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
		quota = int(math.Round(float64(quota) * priceData.ModelRatio))
		if priceData.ModelRatio != 0 && quota <= 0 {
			quota = 1
		}
	} else {
		quota = int(priceData.ModelPrice * common.QuotaPerUnit)
	}
	tok := time.Now()
	milliseconds := tok.Sub(tik).Milliseconds()
	consumedTime := float64(milliseconds) / 1000.0
	other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
		usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice)
	// "模型测试" = "model test"
	model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
		quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
	common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
	return nil, nil
}

func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
	testRequest := &dto.GeneralOpenAIRequest{
		Model:  "", // this will be set later
		Stream: false,
	}

	// Check for embedding models first; keep these checks in line with the
	// request-path selection in testChannel.
	if strings.Contains(strings.ToLower(model), "embedding") || // generic embedding models
		strings.HasPrefix(model, "m3e") || // m3e family
		strings.Contains(model, "bge-") || // bge family
		strings.Contains(model, "embed") {
		testRequest.Model = model
		// Embedding request
		testRequest.Input = []string{"hello world"}
		return testRequest
	}
	// Not an embedding model: send a minimal chat request.
	if strings.HasPrefix(model, "o") {
		testRequest.MaxCompletionTokens = 10
	} else if strings.Contains(model, "thinking") {
		if !strings.Contains(model, "claude") {
			testRequest.MaxTokens = 50
		}
	} else if strings.Contains(model, "gemini") {
		testRequest.MaxTokens = 300
	} else {
		testRequest.MaxTokens = 10
	}
	content, _ := json.Marshal("hi")
	testMessage := dto.Message{
		Role:    "user",
		Content: content,
	}
	testRequest.Model = model
	testRequest.Messages = append(testRequest.Messages, testMessage)
	return testRequest
}

func TestChannel(c *gin.Context) {
	channelId, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	channel, err := model.GetChannelById(channelId, true)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	testModel := c.Query("model")
	tik := time.Now()
	err, _ = testChannel(channel, testModel)
	tok := time.Now()
	milliseconds := tok.Sub(tik).Milliseconds()
	go channel.UpdateResponseTime(milliseconds)
	consumedTime := float64(milliseconds) / 1000.0
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
			"time":    consumedTime,
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"time":    consumedTime,
	})
	return
}

var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool

func testAllChannels(notify bool) error {
	testAllChannelsLock.Lock()
	if testAllChannelsRunning {
		testAllChannelsLock.Unlock()
		return errors.New("测试已在运行中") // "a test is already running"
	}
	testAllChannelsRunning = true
	testAllChannelsLock.Unlock()
	channels, err := model.GetAllChannels(0, 0, true, false)
	if err != nil {
		// Clear the flag, or every later run would be rejected as already running.
		testAllChannelsLock.Lock()
		testAllChannelsRunning = false
		testAllChannelsLock.Unlock()
		return err
	}
	var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
	if disableThreshold == 0 {
		disableThreshold = 10000000 // an impossible value
	}
	gopool.Go(func() {
		for _, channel := range channels {
			isChannelEnabled := channel.Status == common.ChannelStatusEnabled
			tik := time.Now()
			err, openaiWithStatusErr := testChannel(channel, "")
			tok := time.Now()
			milliseconds := tok.Sub(tik).Milliseconds()

			shouldBanChannel := false

			// request error disables the channel
			if openaiWithStatusErr != nil {
				oaiErr := openaiWithStatusErr.Error
				err = fmt.Errorf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message)
				shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
			}

			if milliseconds > disableThreshold {
				// "response time %.2fs exceeds the threshold %.2fs"
				err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
				shouldBanChannel = true
			}

			// disable channel
			if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
				service.DisableChannel(channel.Id, channel.Name, err.Error())
			}

			// enable channel
			if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
				service.EnableChannel(channel.Id, channel.Name)
			}

			channel.UpdateResponseTime(milliseconds)
			time.Sleep(common.RequestInterval)
		}
		testAllChannelsLock.Lock()
		testAllChannelsRunning = false
		testAllChannelsLock.Unlock()
		if notify {
			service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
		}
	})
	return nil
}

func TestAllChannels(c *gin.Context) {
	err := testAllChannels(true)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
	return
}

func AutomaticallyTestChannels(frequency int) {
	for {
		time.Sleep(time.Duration(frequency) * time.Minute)
		common.SysLog("testing all channels")
		_ = testAllChannels(false)
		common.SysLog("channel test finished")
	}
}
```

## /controller/channel.go

```go path="/controller/channel.go"
package controller

import (
	"encoding/json"
	"fmt"
	"net/http"
	"one-api/common"
	"one-api/model"
	"strconv"
	"strings"

	"github.com/gin-gonic/gin"
)

type OpenAIModel struct {
	ID         string `json:"id"`
	Object     string `json:"object"`
	Created    int64  `json:"created"`
	OwnedBy    string `json:"owned_by"`
	Permission []struct {
		ID                 string `json:"id"`
		Object             string `json:"object"`
		Created            int64  `json:"created"`
		AllowCreateEngine  bool   `json:"allow_create_engine"`
		AllowSampling      bool   `json:"allow_sampling"`
		AllowLogprobs      bool   `json:"allow_logprobs"`
		AllowSearchIndices bool   `json:"allow_search_indices"`
		AllowView          bool   `json:"allow_view"`
		AllowFineTuning    bool   `json:"allow_fine_tuning"`
		Organization       string `json:"organization"`
		Group              string `json:"group"`
		IsBlocking         bool   `json:"is_blocking"`
	} `json:"permission"`
	Root   string `json:"root"`
	Parent string `json:"parent"`
}

type OpenAIModelsResponse struct {
	Data    []OpenAIModel `json:"data"`
	Success bool          `json:"success"`
}

func GetAllChannels(c *gin.Context) {
	p, _ := strconv.Atoi(c.Query("p"))
	pageSize, _ := strconv.Atoi(c.Query("page_size"))
	if p < 0 {
		p = 0
	}
	if pageSize < 0 {
		pageSize = common.ItemsPerPage
	}
	channelData := make([]*model.Channel, 0)
	idSort, _ := strconv.ParseBool(c.Query("id_sort"))
	enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
	if enableTagMode {
		tags, err := model.GetPaginatedTags(p*pageSize, pageSize)
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
		for _, tag := range tags {
			if tag != nil && *tag != "" {
				tagChannel, err := model.GetChannelsByTag(*tag, idSort)
				if err == nil {
					channelData = append(channelData, tagChannel...)
				}
			}
		}
	} else {
		channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort)
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
		channelData = channels
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    channelData,
	})
	return
}

func FetchUpstreamModels(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	channel, err := model.GetChannelById(id, true)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	//if channel.Type != common.ChannelTypeOpenAI {
	//	c.JSON(http.StatusOK, gin.H{
	//		"success": false,
	//		"message": "仅支持 OpenAI 类型渠道",
	//	})
	//	return
	//}
	baseURL := common.ChannelBaseURLs[channel.Type]
	if channel.GetBaseURL() != "" {
		baseURL = channel.GetBaseURL()
	}
	url := fmt.Sprintf("%s/v1/models", baseURL)
	if channel.Type == common.ChannelTypeGemini {
		url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
	}
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	var result OpenAIModelsResponse
	if err = json.Unmarshal(body, &result); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
		})
		return
	}

	var ids []string
	for _, model := range result.Data {
		id := model.ID
		if channel.Type == common.ChannelTypeGemini {
			id = strings.TrimPrefix(id, "models/")
		}
		ids = append(ids, id)
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    ids,
	})
}

func FixChannelsAbilities(c *gin.Context) {
	count, err := model.FixAbility()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    count,
	})
}

func SearchChannels(c *gin.Context) {
	keyword := c.Query("keyword")
	group := c.Query("group")
	modelKeyword := c.Query("model")
	idSort, _ := strconv.ParseBool(c.Query("id_sort"))
	enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
	channelData := make([]*model.Channel, 0)
	if enableTagMode {
		tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
		for _, tag := range tags {
			if tag != nil && *tag != "" {
				tagChannel, err := model.GetChannelsByTag(*tag, idSort)
				if err == nil {
					channelData = append(channelData, tagChannel...)
				}
			}
		}
	} else {
		channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
		channelData = channels
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    channelData,
	})
	return
}

func GetChannel(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	channel, err := model.GetChannelById(id, false)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    channel,
	})
	return
}

func AddChannel(c *gin.Context) {
	channel := model.Channel{}
	err := c.ShouldBindJSON(&channel)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	channel.CreatedTime = common.GetTimestamp()
	keys := strings.Split(channel.Key, "\n")
	if channel.Type == common.ChannelTypeVertexAi {
		if channel.Other == "" {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "部署地区不能为空",
			})
			return
		} else {
			if common.IsJsonStr(channel.Other) {
				// must have default
				regionMap := common.StrToMap(channel.Other)
				if regionMap["default"] == nil {
					c.JSON(http.StatusOK, gin.H{
						"success": false,
						"message": "部署地区必须包含default字段",
					})
					return
				}
			}
		}
		keys = []string{channel.Key}
	}
	channels := make([]model.Channel, 0, len(keys))
	for _, key := range keys {
		if key == "" {
			continue
		}
		localChannel := channel
		localChannel.Key = key
		// Validate the length of the model name
		models := strings.Split(localChannel.Models, ",")
		for _, model := range models {
			if len(model) > 255 {
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": fmt.Sprintf("模型名称过长: %s", model),
				})
				return
			}
		}
		channels = append(channels, localChannel)
	}
	err = model.BatchInsertChannels(channels)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
	return
}

func DeleteChannel(c *gin.Context) {
	id, _ := strconv.Atoi(c.Param("id"))
	channel := model.Channel{Id: id}
	err := channel.Delete()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
	return
}

func DeleteDisabledChannel(c *gin.Context) {
	rows, err := model.DeleteDisabledChannel()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    rows,
	})
	return
}

type ChannelTag struct {
	Tag          string  `json:"tag"`
	NewTag       *string `json:"new_tag"`
	Priority     *int64  `json:"priority"`
	Weight       *uint   `json:"weight"`
	ModelMapping *string `json:"model_mapping"`
	Models       *string `json:"models"`
	Groups       *string `json:"groups"`
}

func DisableTagChannels(c *gin.Context) {
	channelTag := ChannelTag{}
	err := c.ShouldBindJSON(&channelTag)
	if err != nil || channelTag.Tag == "" {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}
	err = model.DisableChannelByTag(channelTag.Tag)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
	return
}
func EnableTagChannels(c *gin.Context) {
	channelTag := ChannelTag{}
	err := c.ShouldBindJSON(&channelTag)
	if err != nil || channelTag.Tag == "" {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}
	err = model.EnableChannelByTag(channelTag.Tag)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
	return
}

func EditTagChannels(c *gin.Context) {
	channelTag := ChannelTag{}
	err := c.ShouldBindJSON(&channelTag)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}
	if channelTag.Tag == "" {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "tag不能为空",
		})
		return
	}
	err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
	return
}

type ChannelBatch struct {
	Ids []int   `json:"ids"`
	Tag *string `json:"tag"`
}

func DeleteChannelBatch(c *gin.Context) {
	channelBatch := ChannelBatch{}
	err := c.ShouldBindJSON(&channelBatch)
	if err != nil || len(channelBatch.Ids) == 0 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}
	err = model.BatchDeleteChannels(channelBatch.Ids)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    len(channelBatch.Ids),
	})
	return
}

func UpdateChannel(c *gin.Context) {
	channel := model.Channel{}
	err := c.ShouldBindJSON(&channel)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	if channel.Type == common.ChannelTypeVertexAi {
		if channel.Other == "" {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "部署地区不能为空",
			})
			return
		} else {
			if common.IsJsonStr(channel.Other) {
				// must have default
				regionMap := common.StrToMap(channel.Other)
				if regionMap["default"] == nil {
					c.JSON(http.StatusOK, gin.H{
						"success": false,
						"message": "部署地区必须包含default字段",
					})
					return
				}
			}
		}
	}
	err = channel.Update()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    channel,
	})
	return
}

func FetchModels(c *gin.Context) {
	var req struct {
		BaseURL string `json:"base_url"`
		Type    int    `json:"type"`
		Key     string `json:"key"`
	}

	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"message": "Invalid request",
		})
		return
	}

	baseURL := req.BaseURL
	if baseURL == "" {
		baseURL = common.ChannelBaseURLs[req.Type]
	}

	client := &http.Client{}
	url := fmt.Sprintf("%s/v1/models", baseURL)

	request, err := http.NewRequest("GET", url, nil)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	// Strip leading and trailing whitespace, including line breaks.
	key := strings.TrimSpace(req.Key)
	// If the key contains a line break, only take the first part.
	key = strings.Split(key, "\n")[0]
	request.Header.Set("Authorization", "Bearer "+key)

	response, err := client.Do(request)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	// Close the body on every path, including the non-200 branch below.
	defer response.Body.Close()
	// Check the status code.
	if response.StatusCode != http.StatusOK {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": "Failed to fetch models",
		})
		return
	}

	var result struct {
		Data []struct {
			ID string `json:"id"`
		} `json:"data"`
	}

	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	var models []string
	for _, model := range result.Data {
		models = append(models, model.ID)
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    models,
	})
}

func BatchSetChannelTag(c *gin.Context) {
	channelBatch := ChannelBatch{}
	err := c.ShouldBindJSON(&channelBatch)
	if err != nil || len(channelBatch.Ids) == 0 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}
	err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    len(channelBatch.Ids),
	})
	return
}
```

## /controller/github.go

```go path="/controller/github.go"
package controller

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"github.com/gin-contrib/sessions"
	"github.com/gin-gonic/gin"
	"net/http"
	"one-api/common"
	"one-api/model"
	"strconv"
	"time"
)

type GitHubOAuthResponse struct {
	AccessToken string `json:"access_token"`
	Scope       string `json:"scope"`
	TokenType   string `json:"token_type"`
}

type GitHubUser struct {
	Login string `json:"login"`
	Name  string `json:"name"`
	Email string `json:"email"`
}

func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
	if code == "" {
		return nil, errors.New("无效的参数")
	}
	values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
	jsonData, err := json.Marshal(values)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequest("POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	client := http.Client{
		Timeout: 5 * time.Second,
	}
	res, err := client.Do(req)
	if err != nil {
		common.SysLog(err.Error())
		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
	}
	defer res.Body.Close()
	var oAuthResponse GitHubOAuthResponse
	err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
	if err != nil {
		return nil, err
	}
	req, err = http.NewRequest("GET", "https://api.github.com/user", nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
	res2, err := client.Do(req)
	if err != nil {
		common.SysLog(err.Error())
		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
	}
	defer res2.Body.Close()
	var githubUser GitHubUser
	err = json.NewDecoder(res2.Body).Decode(&githubUser)
	if err != nil {
		return nil, err
	}
	if githubUser.Login == "" {
		return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
	}
	return &githubUser, nil
}

func GitHubOAuth(c *gin.Context) {
	session := sessions.Default(c)
	state := c.Query("state")
	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
		c.JSON(http.StatusForbidden, gin.H{
			"success": false,
			"message": "state is empty or not same",
		})
		return
	}
	username := session.Get("username")
	if username != nil {
		GitHubBind(c)
		return
	}

	if !common.GitHubOAuthEnabled {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "管理员未开启通过 GitHub 登录以及注册",
		})
		return
	}
	code := c.Query("code")
	githubUser, err := getGitHubUserInfoByCode(code)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	user := model.User{
		GitHubId: githubUser.Login,
	}
	// IsGitHubIdAlreadyTaken is unscoped
	if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
		// FillUserByGitHubId is scoped
		err := user.FillUserByGitHubId()
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
		// user.Id == 0 means the user has been deleted
		if user.Id == 0 {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "用户已注销",
			})
			return
		}
	} else {
		if common.RegisterEnabled {
			user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
			if githubUser.Name != "" {
				user.DisplayName = githubUser.Name
			} else {
				user.DisplayName = "GitHub User"
			}
			user.Email = githubUser.Email
			user.Role = common.RoleCommonUser
			user.Status = common.UserStatusEnabled
			affCode := session.Get("aff")
			inviterId := 0
			if affCode != nil {
				inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
			}

			if err := user.Insert(inviterId); err != nil {
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": err.Error(),
				})
				return
			}
		} else {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "管理员关闭了新用户注册",
			})
			return
		}
	}

	if user.Status != common.UserStatusEnabled {
		c.JSON(http.StatusOK, gin.H{
			"message": "用户已被封禁",
			"success": false,
		})
		return
	}
	setupLogin(&user, c)
}

func GitHubBind(c *gin.Context) {
	if !common.GitHubOAuthEnabled {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "管理员未开启通过 GitHub 登录以及注册",
		})
		return
	}
	code := c.Query("code")
	githubUser, err := getGitHubUserInfoByCode(code)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	user := model.User{
		GitHubId: githubUser.Login,
	}
	if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "该 GitHub 账户已被绑定",
		})
		return
	}
	session := sessions.Default(c)
	id := session.Get("id")
	// id := c.GetInt("id")  // critical bug!
	user.Id = id.(int)
	err = user.FillUserById()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	user.GitHubId = githubUser.Login
	err = user.Update(false)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "bind",
	})
	return
}

func GenerateOAuthCode(c *gin.Context) {
	session := sessions.Default(c)
	state := common.GetRandomString(12)
	affCode := c.Query("aff")
	if affCode != "" {
		session.Set("aff", affCode)
	}
	session.Set("oauth_state", state)
	err := session.Save()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    state,
	})
}
```

## /controller/group.go

```go path="/controller/group.go"
package controller

import (
	"github.com/gin-gonic/gin"
	"net/http"
	"one-api/model"
	"one-api/setting"
)

func GetGroups(c *gin.Context) {
	groupNames := make([]string, 0)
	for groupName := range setting.GetGroupRatioCopy() {
		groupNames = append(groupNames, groupName)
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    groupNames,
	})
}

func GetUserGroups(c *gin.Context) {
	usableGroups := make(map[string]map[string]interface{})
	userId := c.GetInt("id")
	userGroup, _ := model.GetUserGroup(userId, false)
	// userUsableGroups contains the groups that the user can use; it does not
	// depend on the loop variable, so it is looked up once.
	userUsableGroups := setting.GetUserUsableGroups(userGroup)
	for groupName, ratio := range setting.GetGroupRatioCopy() {
		if desc, ok := userUsableGroups[groupName]; ok {
			usableGroups[groupName] = map[string]interface{}{
				"ratio": ratio,
				"desc":  desc,
			}
		}
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    usableGroups,
	})
}
```

## /controller/image.go

```go path="/controller/image.go"
package controller

import (
	"github.com/gin-gonic/gin"
)

func GetImage(c *gin.Context) {
}
```

## /controller/linuxdo.go

```go path="/controller/linuxdo.go"
package controller

import (
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"one-api/common"
	"one-api/model"
	"strconv"
	"strings"
	"time"

	"github.com/gin-contrib/sessions"
	"github.com/gin-gonic/gin"
)

type LinuxdoUser struct {
	Id         int    `json:"id"`
	Username   string `json:"username"`
	Name       string `json:"name"`
	Active     bool   `json:"active"`
	TrustLevel int    `json:"trust_level"`
	Silenced   bool   `json:"silenced"`
}

func LinuxDoBind(c *gin.Context) {
	if !common.LinuxDOOAuthEnabled {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "管理员未开启通过 Linux DO 登录以及注册",
		})
		return
	}

	code := c.Query("code")
	linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	user := model.User{
		LinuxDOId: strconv.Itoa(linuxdoUser.Id),
	}

	if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "该 Linux DO 账户已被绑定",
		})
		return
	}

	session := sessions.Default(c)
	id := session.Get("id")
	user.Id = id.(int)

	err = user.FillUserById()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
	err = user.Update(false)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "bind",
	})
}

func getLinuxdoUserInfoByCode(code string, c *gin.Context) (*LinuxdoUser, error) {
	if code == "" {
		return nil, errors.New("invalid code")
	}

	// Get access token using Basic auth
	tokenEndpoint := "https://connect.linux.do/oauth2/token"
	credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
	basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))

	// Get redirect URI from request
	scheme := "http"
	if c.Request.TLS != nil {
		scheme = "https"
	}
	redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)

	data := url.Values{}
	data.Set("grant_type", "authorization_code")
	data.Set("code", code)
	data.Set("redirect_uri", redirectURI)

	req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(data.Encode()))
	if err != nil {
		return nil, err
	}

	req.Header.Set("Authorization", basicAuth)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	client := http.Client{Timeout: 5 * time.Second}
	res, err := client.Do(req)
	if err != nil {
		return nil, errors.New("failed to connect to Linux DO server")
	}
	defer res.Body.Close()

	var tokenRes struct {
		AccessToken string `json:"access_token"`
		Message     string `json:"message"`
	}
	if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
		return nil, err
	}

	if tokenRes.AccessToken == "" {
		return nil, fmt.Errorf("failed to get access token: %s", tokenRes.Message)
	}

	// Get user info
	userEndpoint := "https://connect.linux.do/api/user"
	req, err = http.NewRequest("GET", userEndpoint, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken)
	req.Header.Set("Accept", "application/json")

	res2, err := client.Do(req)
	if err != nil {
		return nil, errors.New("failed to get user info from Linux DO")
	}
	defer res2.Body.Close()

	var linuxdoUser LinuxdoUser
	if err := json.NewDecoder(res2.Body).Decode(&linuxdoUser); err != nil {
		return nil, err
	}

	if linuxdoUser.Id == 0 {
		return nil, errors.New("invalid user info returned")
	}

	return &linuxdoUser, nil
}

func LinuxdoOAuth(c *gin.Context) {
	session := sessions.Default(c)

	errorCode := c.Query("error")
	if errorCode != "" {
		errorDescription := c.Query("error_description")
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": errorDescription,
		})
		return
	}

	state := c.Query("state")
	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
		c.JSON(http.StatusForbidden, gin.H{
			"success": false,
			"message": "state is empty or not same",
		})
return } username := session.Get("username") if username != nil { LinuxDoBind(c) return } if !common.LinuxDOOAuthEnabled { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "管理员未开启通过 Linux DO 登录以及注册", }) return } code := c.Query("code") linuxdoUser, err := getLinuxdoUserInfoByCode(code, c) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } user := model.User{ LinuxDOId: strconv.Itoa(linuxdoUser.Id), } // Check if user exists if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) { err := user.FillUserByLinuxDOId() if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } if user.Id == 0 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "用户已注销", }) return } } else { if common.RegisterEnabled { user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1) user.DisplayName = linuxdoUser.Name user.Role = common.RoleCommonUser user.Status = common.UserStatusEnabled affCode := session.Get("aff") inviterId := 0 if affCode != nil { inviterId, _ = model.GetUserIdByAffCode(affCode.(string)) } if err := user.Insert(inviterId); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } } else { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "管理员关闭了新用户注册", }) return } } if user.Status != common.UserStatusEnabled { c.JSON(http.StatusOK, gin.H{ "message": "用户已被封禁", "success": false, }) return } setupLogin(&user, c) } ``` ## /controller/log.go ```go path="/controller/log.go" package controller import ( "net/http" "one-api/common" "one-api/model" "strconv" "github.com/gin-gonic/gin" ) func GetAllLogs(c *gin.Context) { p, _ := strconv.Atoi(c.Query("p")) pageSize, _ := strconv.Atoi(c.Query("page_size")) if p < 1 { p = 1 } if pageSize < 0 { pageSize = common.ItemsPerPage } logType, _ := strconv.Atoi(c.Query("type")) startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) endTimestamp, _ := 
strconv.ParseInt(c.Query("end_timestamp"), 10, 64) username := c.Query("username") tokenName := c.Query("token_name") modelName := c.Query("model_name") channel, _ := strconv.Atoi(c.Query("channel")) group := c.Query("group") logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel, group) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": map[string]any{ "items": logs, "total": total, "page": p, "page_size": pageSize, }, }) } func GetUserLogs(c *gin.Context) { p, _ := strconv.Atoi(c.Query("p")) pageSize, _ := strconv.Atoi(c.Query("page_size")) if p < 1 { p = 1 } if pageSize < 0 { pageSize = common.ItemsPerPage } if pageSize > 100 { pageSize = 100 } userId := c.GetInt("id") logType, _ := strconv.Atoi(c.Query("type")) startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) tokenName := c.Query("token_name") modelName := c.Query("model_name") group := c.Query("group") logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize, group) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": map[string]any{ "items": logs, "total": total, "page": p, "page_size": pageSize, }, }) return } func SearchAllLogs(c *gin.Context) { keyword := c.Query("keyword") logs, err := model.SearchAllLogs(keyword) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": logs, }) return } func SearchUserLogs(c *gin.Context) { keyword := c.Query("keyword") userId := c.GetInt("id") logs, err := 
model.SearchUserLogs(userId, keyword) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": logs, }) return } func GetLogByKey(c *gin.Context) { key := c.Query("key") logs, err := model.GetLogByKey(key) if err != nil { c.JSON(200, gin.H{ "success": false, "message": err.Error(), }) return } c.JSON(200, gin.H{ "success": true, "message": "", "data": logs, }) } func GetLogsStat(c *gin.Context) { logType, _ := strconv.Atoi(c.Query("type")) startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) tokenName := c.Query("token_name") username := c.Query("username") modelName := c.Query("model_name") channel, _ := strconv.Atoi(c.Query("channel")) group := c.Query("group") stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": gin.H{ "quota": stat.Quota, "rpm": stat.Rpm, "tpm": stat.Tpm, }, }) return } func GetLogsSelfStat(c *gin.Context) { username := c.GetString("username") logType, _ := strconv.Atoi(c.Query("type")) startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) tokenName := c.Query("token_name") modelName := c.Query("model_name") channel, _ := strconv.Atoi(c.Query("channel")) group := c.Query("group") quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) c.JSON(200, gin.H{ "success": true, "message": "", "data": gin.H{ "quota": quotaNum.Quota, "rpm": quotaNum.Rpm, "tpm": quotaNum.Tpm, 
//"token": tokenNum, }, }) return } func DeleteHistoryLogs(c *gin.Context) { targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64) if targetTimestamp == 0 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "target timestamp is required", }) return } count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": count, }) return } ``` ## /controller/midjourney.go ```go path="/controller/midjourney.go" package controller import ( "bytes" "context" "encoding/json" "fmt" "github.com/gin-gonic/gin" "io" "log" "net/http" "one-api/common" "one-api/dto" "one-api/model" "one-api/service" "one-api/setting" "strconv" "time" ) func UpdateMidjourneyTaskBulk() { //imageModel := "midjourney" ctx := context.TODO() for { time.Sleep(time.Duration(15) * time.Second) tasks := model.GetAllUnFinishTasks() if len(tasks) == 0 { continue } common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) taskChannelM := make(map[int][]string) taskM := make(map[string]*model.Midjourney) nullTaskIds := make([]int, 0) for _, task := range tasks { if task.MjId == "" { // Collect unfinished tasks that never received an mj_id (their submission failed) nullTaskIds = append(nullTaskIds, task.Id) continue } taskM[task.MjId] = task taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId) } if len(nullTaskIds) > 0 { err := model.MjBulkUpdateByTaskIds(nullTaskIds, map[string]any{ "status": "FAILURE", "progress": "100%", }) if err != nil { common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err)) } else { common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds)) } } if len(taskChannelM) == 0 { continue } for channelId, taskIds := range taskChannelM { common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) if len(taskIds) == 0 { continue } midjourneyChannel, err :=
model.CacheGetChannel(channelId) if err != nil { common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err)) err := model.MjBulkUpdate(taskIds, map[string]any{ "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), "status": "FAILURE", "progress": "100%", }) if err != nil { common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) } continue } requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL) body, _ := json.Marshal(map[string]any{ "ids": taskIds, }) req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body)) if err != nil { common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err)) continue } // Bound the upstream request with a 15-second timeout; use reqCtx so the outer ctx is not shadowed timeout := time.Second * 15 reqCtx, cancel := context.WithTimeout(context.Background(), timeout) // Attach the timeout context to the request req = req.WithContext(reqCtx) req.Header.Set("Content-Type", "application/json") req.Header.Set("mj-api-secret", midjourneyChannel.Key) resp, err := service.GetHttpClient().Do(req) if err != nil { cancel() common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) continue } // Read and close the body, then release the context, before any early continue so neither leaks responseBody, err := io.ReadAll(resp.Body) resp.Body.Close() cancel() if resp.StatusCode != http.StatusOK { common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) continue } if err != nil { common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) continue } var responseItems []dto.MidjourneyDto err = json.Unmarshal(responseBody, &responseItems) if err != nil { common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) continue } for _, responseItem := range responseItems { task := taskM[responseItem.MjId] // Skip ids that were not part of this batch if task == nil { continue } useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime // If more than one hour has passed and progress is not 100%, treat the task as failed if useTime > 3600000 && task.Progress != "100%" { responseItem.FailReason = "上游任务超时(超过1小时)" responseItem.Status = "FAILURE" } if !checkMjTaskNeedUpdate(task, responseItem) { continue } task.Code =
1 task.Progress = responseItem.Progress task.PromptEn = responseItem.PromptEn task.State = responseItem.State task.SubmitTime = responseItem.SubmitTime task.StartTime = responseItem.StartTime task.FinishTime = responseItem.FinishTime task.ImageUrl = responseItem.ImageUrl task.Status = responseItem.Status task.FailReason = responseItem.FailReason if responseItem.Properties != nil { propertiesStr, _ := json.Marshal(responseItem.Properties) task.Properties = string(propertiesStr) } if responseItem.Buttons != nil { buttonStr, _ := json.Marshal(responseItem.Buttons) task.Buttons = string(buttonStr) } shouldReturnQuota := false if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") { common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) task.Progress = "100%" if task.Quota != 0 { shouldReturnQuota = true } } err = task.Update() if err != nil { common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) } else { if shouldReturnQuota { err = model.IncreaseUserQuota(task.UserId, task.Quota, false) if err != nil { common.LogError(ctx, "fail to increase user quota: "+err.Error()) } logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota)) model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } } } } } } func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool { if oldTask.Code != 1 { return true } if oldTask.Progress != newTask.Progress { return true } if oldTask.PromptEn != newTask.PromptEn { return true } if oldTask.State != newTask.State { return true } if oldTask.SubmitTime != newTask.SubmitTime { return true } if oldTask.StartTime != newTask.StartTime { return true } if oldTask.FinishTime != newTask.FinishTime { return true } if oldTask.ImageUrl != newTask.ImageUrl { return true } if oldTask.Status != newTask.Status { return true } if oldTask.FailReason != newTask.FailReason { return true } if oldTask.FinishTime != newTask.FinishTime { 
return true } if oldTask.Progress != "100%" && newTask.FailReason != "" { return true } return false } func GetAllMidjourney(c *gin.Context) { p, _ := strconv.Atoi(c.Query("p")) if p < 0 { p = 0 } // Parse the remaining query parameters queryParams := model.TaskQueryParams{ ChannelID: c.Query("channel_id"), MjID: c.Query("mj_id"), StartTimestamp: c.Query("start_timestamp"), EndTimestamp: c.Query("end_timestamp"), } logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams) if logs == nil { logs = make([]*model.Midjourney, 0) } if setting.MjForwardUrlEnabled { for i, midjourney := range logs { midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId logs[i] = midjourney } } c.JSON(200, gin.H{ "success": true, "message": "", "data": logs, }) } func GetUserMidjourney(c *gin.Context) { p, _ := strconv.Atoi(c.Query("p")) if p < 0 { p = 0 } userId := c.GetInt("id") log.Printf("userId = %d \n", userId) queryParams := model.TaskQueryParams{ MjID: c.Query("mj_id"), StartTimestamp: c.Query("start_timestamp"), EndTimestamp: c.Query("end_timestamp"), } logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams) if logs == nil { logs = make([]*model.Midjourney, 0) } if setting.MjForwardUrlEnabled { for i, midjourney := range logs { midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId logs[i] = midjourney } } c.JSON(200, gin.H{ "success": true, "message": "", "data": logs, }) } ``` ## /controller/misc.go ```go path="/controller/misc.go" package controller import ( "encoding/json" "fmt" "net/http" "one-api/common" "one-api/constant" "one-api/model" "one-api/setting" "one-api/setting/operation_setting" "one-api/setting/system_setting" "strings" "github.com/gin-gonic/gin" ) func TestStatus(c *gin.Context) { err := model.PingDB() if err != nil { c.JSON(http.StatusServiceUnavailable, gin.H{ "success": false, "message": "数据库连接失败", }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "Server is
running", }) return } func GetStatus(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": gin.H{ "version": common.Version, "start_time": common.StartTime, "email_verification": common.EmailVerificationEnabled, "github_oauth": common.GitHubOAuthEnabled, "github_client_id": common.GitHubClientId, "linuxdo_oauth": common.LinuxDOOAuthEnabled, "linuxdo_client_id": common.LinuxDOClientId, "telegram_oauth": common.TelegramOAuthEnabled, "telegram_bot_name": common.TelegramBotName, "system_name": common.SystemName, "logo": common.Logo, "footer_html": common.Footer, "wechat_qrcode": common.WeChatAccountQRCodeImageURL, "wechat_login": common.WeChatAuthEnabled, "server_address": setting.ServerAddress, "price": setting.Price, "min_topup": setting.MinTopUp, "turnstile_check": common.TurnstileCheckEnabled, "turnstile_site_key": common.TurnstileSiteKey, "top_up_link": common.TopUpLink, "docs_link": operation_setting.GetGeneralSetting().DocsLink, "quota_per_unit": common.QuotaPerUnit, "display_in_currency": common.DisplayInCurrencyEnabled, "enable_batch_update": common.BatchUpdateEnabled, "enable_drawing": common.DrawingEnabled, "enable_task": common.TaskEnabled, "enable_data_export": common.DataExportEnabled, "data_export_default_time": common.DataExportDefaultTime, "default_collapse_sidebar": common.DefaultCollapseSidebar, "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "", "mj_notify_enabled": setting.MjNotifyEnabled, "chats": setting.Chats, "demo_site_enabled": operation_setting.DemoSiteEnabled, "self_use_mode_enabled": operation_setting.SelfUseModeEnabled, "oidc_enabled": system_setting.GetOIDCSettings().Enabled, "oidc_client_id": system_setting.GetOIDCSettings().ClientId, "oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint, "setup": constant.Setup, }, }) return } func GetNotice(c *gin.Context) { common.OptionMapRWMutex.RLock() defer 
common.OptionMapRWMutex.RUnlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": common.OptionMap["Notice"], }) return } func GetAbout(c *gin.Context) { common.OptionMapRWMutex.RLock() defer common.OptionMapRWMutex.RUnlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": common.OptionMap["About"], }) return } func GetMidjourney(c *gin.Context) { common.OptionMapRWMutex.RLock() defer common.OptionMapRWMutex.RUnlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": common.OptionMap["Midjourney"], }) return } func GetHomePageContent(c *gin.Context) { common.OptionMapRWMutex.RLock() defer common.OptionMapRWMutex.RUnlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": common.OptionMap["HomePageContent"], }) return } func SendEmailVerification(c *gin.Context) { email := c.Query("email") if err := common.Validate.Var(email, "required,email"); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无效的参数", }) return } parts := strings.Split(email, "@") if len(parts) != 2 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无效的邮箱地址", }) return } localPart := parts[0] domainPart := parts[1] if common.EmailDomainRestrictionEnabled { allowed := false for _, domain := range common.EmailDomainWhitelist { // Domain names are case-insensitive if strings.EqualFold(domainPart, domain) { allowed = true break } } if !allowed { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "The administrator has enabled an email domain whitelist, and your email domain is not on it.", }) return } } if common.EmailAliasRestrictionEnabled { containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Contains(localPart, ".") if containsSpecialSymbols { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "管理员已启用邮箱地址别名限制,您的邮箱地址由于包含特殊符号而被拒绝。", }) return } } if model.IsEmailAlreadyTaken(email) { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "邮箱地址已被占用",
}) return } code := common.GenerateVerificationCode(6) common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose) subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName) content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+ "<p>您的验证码为: <strong>%s</strong></p>"+ "<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes) err := common.SendEmail(subject, email, content) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", }) return } func SendPasswordResetEmail(c *gin.Context) { email := c.Query("email") if err := common.Validate.Var(email, "required,email"); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无效的参数", }) return } if !model.IsEmailAlreadyTaken(email) { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "该邮箱地址未注册", }) return } code := common.GenerateVerificationCode(0) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", setting.ServerAddress, email, code) subject := fmt.Sprintf("%s密码重置", common.SystemName) content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ "<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ "<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+ "<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes) err := common.SendEmail(subject, email, content) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", }) return } type PasswordResetRequest struct { Email string `json:"email"` Token string `json:"token"` } func ResetPassword(c *gin.Context) { var req PasswordResetRequest err := json.NewDecoder(c.Request.Body).Decode(&req) if err != nil || req.Email == "" || req.Token == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无效的参数", }) return } if !common.VerifyCodeWithKey(req.Email, req.Token, common.PasswordResetPurpose) { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "重置链接非法或已过期", }) return } password := common.GenerateVerificationCode(12) err = model.ResetUserPasswordByEmail(req.Email, password) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } common.DeleteKey(req.Email, common.PasswordResetPurpose) c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": password, }) return } ``` ## /controller/model.go ```go path="/controller/model.go" package controller import ( "fmt" "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/constant" "one-api/dto" "one-api/model" "one-api/relay" "one-api/relay/channel/ai360" "one-api/relay/channel/lingyiwanwu" "one-api/relay/channel/minimax" "one-api/relay/channel/moonshot" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" ) // https://platform.openai.com/docs/api-reference/models/list var openAIModels []dto.OpenAIModels var openAIModelsMap map[string]dto.OpenAIModels var channelId2Models map[int][]string func getPermission() []dto.OpenAIModelPermission { var permission []dto.OpenAIModelPermission permission = append(permission, dto.OpenAIModelPermission{ Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ", Object: "model_permission", Created: 1626777600,
AllowCreateEngine: true, AllowSampling: true, AllowLogprobs: true, AllowSearchIndices: false, AllowView: true, AllowFineTuning: false, Organization: "*", Group: nil, IsBlocking: false, }) return permission } func init() { // https://platform.openai.com/docs/models/model-endpoint-compatibility permission := getPermission() for i := 0; i < relayconstant.APITypeDummy; i++ { if i == relayconstant.APITypeAIProxyLibrary { continue } adaptor := relay.GetAdaptor(i) channelName := adaptor.GetChannelName() modelNames := adaptor.GetModelList() for _, modelName := range modelNames { openAIModels = append(openAIModels, dto.OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, OwnedBy: channelName, Permission: permission, Root: modelName, Parent: nil, }) } } for _, modelName := range ai360.ModelList { openAIModels = append(openAIModels, dto.OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, OwnedBy: ai360.ChannelName, Permission: permission, Root: modelName, Parent: nil, }) } for _, modelName := range moonshot.ModelList { openAIModels = append(openAIModels, dto.OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, OwnedBy: moonshot.ChannelName, Permission: permission, Root: modelName, Parent: nil, }) } for _, modelName := range lingyiwanwu.ModelList { openAIModels = append(openAIModels, dto.OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, OwnedBy: lingyiwanwu.ChannelName, Permission: permission, Root: modelName, Parent: nil, }) } for _, modelName := range minimax.ModelList { openAIModels = append(openAIModels, dto.OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, OwnedBy: minimax.ChannelName, Permission: permission, Root: modelName, Parent: nil, }) } for modelName, _ := range constant.MidjourneyModel2Action { openAIModels = append(openAIModels, dto.OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, OwnedBy: "midjourney", Permission: permission, Root: modelName, Parent: nil, }) } 
openAIModelsMap = make(map[string]dto.OpenAIModels) for _, aiModel := range openAIModels { openAIModelsMap[aiModel.Id] = aiModel } channelId2Models = make(map[int][]string) for i := 1; i <= common.ChannelTypeDummy; i++ { apiType, success := relayconstant.ChannelType2APIType(i) if !success || apiType == relayconstant.APITypeAIProxyLibrary { continue } meta := &relaycommon.RelayInfo{ChannelType: i} adaptor := relay.GetAdaptor(apiType) adaptor.Init(meta) channelId2Models[i] = adaptor.GetModelList() } } func ListModels(c *gin.Context) { userOpenAiModels := make([]dto.OpenAIModels, 0) permission := getPermission() modelLimitEnable := c.GetBool("token_model_limit_enabled") if modelLimitEnable { s, ok := c.Get("token_model_limit") var tokenModelLimit map[string]bool if ok { tokenModelLimit = s.(map[string]bool) } else { tokenModelLimit = map[string]bool{} } for allowModel, _ := range tokenModelLimit { if _, ok := openAIModelsMap[allowModel]; ok { userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel]) } else { userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{ Id: allowModel, Object: "model", Created: 1626777600, OwnedBy: "custom", Permission: permission, Root: allowModel, Parent: nil, }) } } } else { userId := c.GetInt("id") userGroup, err := model.GetUserGroup(userId, true) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "get user group failed", }) return } group := userGroup tokenGroup := c.GetString("token_group") if tokenGroup != "" { group = tokenGroup } models := model.GetGroupModels(group) for _, s := range models { if _, ok := openAIModelsMap[s]; ok { userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s]) } else { userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{ Id: s, Object: "model", Created: 1626777600, OwnedBy: "custom", Permission: permission, Root: s, Parent: nil, }) } } } c.JSON(200, gin.H{ "success": true, "data": userOpenAiModels, }) } func ChannelListModels(c *gin.Context) { 
	c.JSON(200, gin.H{
		"success": true,
		"data":    openAIModels,
	})
}

func DashboardListModels(c *gin.Context) {
	c.JSON(200, gin.H{
		"success": true,
		"data":    channelId2Models,
	})
}

func EnabledListModels(c *gin.Context) {
	c.JSON(200, gin.H{
		"success": true,
		"data":    model.GetEnabledModels(),
	})
}

func RetrieveModel(c *gin.Context) {
	modelId := c.Param("model")
	if aiModel, ok := openAIModelsMap[modelId]; ok {
		c.JSON(200, aiModel)
	} else {
		openAIError := dto.OpenAIError{
			Message: fmt.Sprintf("The model '%s' does not exist", modelId),
			Type:    "invalid_request_error",
			Param:   "model",
			Code:    "model_not_found",
		}
		c.JSON(200, gin.H{
			"error": openAIError,
		})
	}
}
```

## /controller/oidc.go

```go path="/controller/oidc.go"
package controller

import (
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"one-api/common"
	"one-api/model"
	"one-api/setting"
	"one-api/setting/system_setting"
	"strconv"
	"strings"
	"time"

	"github.com/gin-contrib/sessions"
	"github.com/gin-gonic/gin"
)

type OidcResponse struct {
	AccessToken  string `json:"access_token"`
	IDToken      string `json:"id_token"`
	RefreshToken string `json:"refresh_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int    `json:"expires_in"`
	Scope        string `json:"scope"`
}

type OidcUser struct {
	OpenID            string `json:"sub"`
	Email             string `json:"email"`
	Name              string `json:"name"`
	PreferredUsername string `json:"preferred_username"`
	Picture           string `json:"picture"`
}

func getOidcUserInfoByCode(code string) (*OidcUser, error) {
	if code == "" {
		return nil, errors.New("无效的参数")
	}
	values := url.Values{}
	values.Set("client_id", system_setting.GetOIDCSettings().ClientId)
	values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
	values.Set("code", code)
	values.Set("grant_type", "authorization_code")
	values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress))
	formData := values.Encode()
	req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")
	client := http.Client{
		Timeout: 5 * time.Second,
	}
	res, err := client.Do(req)
	if err != nil {
		common.SysLog(err.Error())
		return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
	}
	defer res.Body.Close()
	var oidcResponse OidcResponse
	err = json.NewDecoder(res.Body).Decode(&oidcResponse)
	if err != nil {
		return nil, err
	}
	if oidcResponse.AccessToken == "" {
		common.SysError("OIDC 获取 Token 失败,请检查设置!")
		return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
	}
	req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
	res2, err := client.Do(req)
	if err != nil {
		common.SysLog(err.Error())
		return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
	}
	defer res2.Body.Close()
	if res2.StatusCode != http.StatusOK {
		common.SysError("OIDC 获取用户信息失败!请检查设置!")
		return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
	}
	var oidcUser OidcUser
	err = json.NewDecoder(res2.Body).Decode(&oidcUser)
	if err != nil {
		return nil, err
	}
	if oidcUser.OpenID == "" || oidcUser.Email == "" {
		common.SysError("OIDC 获取用户信息为空!请检查设置!")
		return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
	}
	return &oidcUser, nil
}

func OidcAuth(c *gin.Context) {
	session := sessions.Default(c)
	state := c.Query("state")
	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
		c.JSON(http.StatusForbidden, gin.H{
			"success": false,
			"message": "state is empty or not same",
		})
		return
	}
	username := session.Get("username")
	if username != nil {
		OidcBind(c)
		return
	}
	if !system_setting.GetOIDCSettings().Enabled {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "管理员未开启通过 OIDC 登录以及注册",
		})
		return
	}
	code := c.Query("code")
	oidcUser, err := getOidcUserInfoByCode(code)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	user := model.User{
		OidcId: oidcUser.OpenID,
	}
	if model.IsOidcIdAlreadyTaken(user.OidcId) {
		err := user.FillUserByOidcId()
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
	} else {
		if common.RegisterEnabled {
			user.Email = oidcUser.Email
			if oidcUser.PreferredUsername != "" {
				user.Username = oidcUser.PreferredUsername
			} else {
				user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
			}
			if oidcUser.Name != "" {
				user.DisplayName = oidcUser.Name
			} else {
				user.DisplayName = "OIDC User"
			}
			err := user.Insert(0)
			if err != nil {
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": err.Error(),
				})
				return
			}
		} else {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "管理员关闭了新用户注册",
			})
			return
		}
	}
	if user.Status != common.UserStatusEnabled {
		c.JSON(http.StatusOK, gin.H{
			"message": "用户已被封禁",
			"success": false,
		})
		return
	}
	setupLogin(&user, c)
}

func OidcBind(c *gin.Context) {
	if !system_setting.GetOIDCSettings().Enabled {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "管理员未开启通过 OIDC 登录以及注册",
		})
		return
	}
	code := c.Query("code")
	oidcUser, err := getOidcUserInfoByCode(code)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	user := model.User{
		OidcId: oidcUser.OpenID,
	}
	if model.IsOidcIdAlreadyTaken(user.OidcId) {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "该 OIDC 账户已被绑定",
		})
		return
	}
	session := sessions.Default(c)
	id := session.Get("id")
	// id := c.GetInt("id") // critical bug!
	user.Id = id.(int)
	err = user.FillUserById()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	user.OidcId = oidcUser.OpenID
	err = user.Update(false)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "bind",
	})
	return
}
```

## /controller/option.go

```go path="/controller/option.go"
package controller

import (
	"encoding/json"
	"net/http"
	"one-api/common"
	"one-api/model"
	"one-api/setting"
	"one-api/setting/system_setting"
	"strings"

	"github.com/gin-gonic/gin"
)

func GetOptions(c *gin.Context) {
	var options []*model.Option
	// the map is only read here, so a shared read lock is sufficient
	common.OptionMapRWMutex.RLock()
	for k, v := range common.OptionMap {
		// never expose credentials (tokens, secrets, keys) to the client
		if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") || strings.HasSuffix(k, "Key") {
			continue
		}
		options = append(options, &model.Option{
			Key:   k,
			Value: common.Interface2String(v),
		})
	}
	common.OptionMapRWMutex.RUnlock()
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    options,
	})
	return
}

func UpdateOption(c *gin.Context) {
	var option model.Option
	err := json.NewDecoder(c.Request.Body).Decode(&option)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"message": "无效的参数",
		})
		return
	}
	switch option.Key {
	case "GitHubOAuthEnabled":
		if option.Value == "true" && common.GitHubClientId == "" {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!",
			})
			return
		}
	case "oidc.enabled":
		if option.Value == "true" && system_setting.GetOIDCSettings().ClientId == "" {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "无法启用 OIDC 登录,请先填入 OIDC Client Id 以及 OIDC Client Secret!",
			})
			return
		}
	case "LinuxDOOAuthEnabled":
		if option.Value == "true" && common.LinuxDOClientId == "" {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "无法启用 LinuxDO OAuth,请先填入 LinuxDO Client Id 以及 LinuxDO Client Secret!",
			})
			return
		}
	case "EmailDomainRestrictionEnabled":
		if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
			})
			return
		}
	case "WeChatAuthEnabled":
		if option.Value == "true" && common.WeChatServerAddress == "" {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "无法启用微信登录,请先填入微信登录相关配置信息!",
			})
			return
		}
	case "TurnstileCheckEnabled":
		if option.Value == "true" && common.TurnstileSiteKey == "" {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",
			})
			return
		}
	case "TelegramOAuthEnabled":
		if option.Value == "true" && common.TelegramBotToken == "" {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "无法启用 Telegram OAuth,请先填入 Telegram Bot Token!",
			})
			return
		}
	case "GroupRatio":
		err = setting.CheckGroupRatio(option.Value)
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
	}
	err = model.UpdateOption(option.Key, option.Value)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
	return
}
```

## /controller/playground.go

```go path="/controller/playground.go"
package controller

import (
	"errors"
	"fmt"
	"github.com/gin-gonic/gin"
	"net/http"
	"one-api/common"
	"one-api/constant"
	"one-api/dto"
	"one-api/middleware"
	"one-api/model"
	"one-api/service"
	"one-api/setting"
	"time"
)

func Playground(c *gin.Context) {
	var openaiErr *dto.OpenAIErrorWithStatusCode
	defer func() {
		if openaiErr != nil {
			c.JSON(openaiErr.StatusCode, gin.H{
				"error": openaiErr.Error,
			})
		}
	}()
	useAccessToken := c.GetBool("use_access_token")
	if useAccessToken {
		openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
		return
	}
	playgroundRequest := &dto.PlayGroundRequest{}
	err := common.UnmarshalBodyReusable(c, playgroundRequest)
	if err != nil {
		openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
		return
	}
	if playgroundRequest.Model == "" {
		openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
		return
	}
	c.Set("original_model", playgroundRequest.Model)
	group := playgroundRequest.Group
	userGroup := c.GetString("group")
	if group == "" {
		group = userGroup
	} else {
		if !setting.GroupInUserUsableGroups(group) && group != userGroup {
			openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
			return
		}
		c.Set("group", group)
	}
	c.Set("token_name", "playground-"+group)
	channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
	if err != nil {
		message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
		openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
		return
	}
	middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
	c.Set(constant.ContextKeyRequestStartTime, time.Now())
	Relay(c)
}
```

## /controller/pricing.go

```go path="/controller/pricing.go"
package controller

import (
	"github.com/gin-gonic/gin"
	"one-api/model"
	"one-api/setting"
	"one-api/setting/operation_setting"
)

func GetPricing(c *gin.Context) {
	pricing := model.GetPricing()
	userId, exists := c.Get("id")
	groupRatio := map[string]float64{}
	for s, f := range setting.GetGroupRatioCopy() {
		groupRatio[s] = f
	}
	var group string
	if exists {
		user, err := model.GetUserCache(userId.(int))
		if err == nil {
			group = user.Group
		}
	}
	usableGroup := setting.GetUserUsableGroups(group)
	// keep only the ratios of groups this user is allowed to use
	// (deleting entries while ranging over the same map is safe in Go)
	for g := range groupRatio {
		if _, ok := usableGroup[g]; !ok {
			delete(groupRatio, g)
		}
	}
	c.JSON(200, gin.H{
		"success":      true,
		"data":         pricing,
		"group_ratio":  groupRatio,
		"usable_group": usableGroup,
	})
}

func ResetModelRatio(c *gin.Context) {
	defaultStr := operation_setting.DefaultModelRatio2JSONString()
	err := model.UpdateOption("ModelRatio", defaultStr)
	if err != nil {
		c.JSON(200, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	err = operation_setting.UpdateModelRatioByJSONString(defaultStr)
	if err != nil {
		c.JSON(200, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(200, gin.H{
		"success": true,
		"message": "重置模型倍率成功",
	})
}
```

## /controller/redemption.go

```go path="/controller/redemption.go"
package controller

import (
	"net/http"
	"one-api/common"
	"one-api/model"
	"strconv"

	"github.com/gin-gonic/gin"
)

func GetAllRedemptions(c *gin.Context) {
	p, _ := strconv.Atoi(c.Query("p"))
	pageSize, _ := strconv.Atoi(c.Query("page_size"))
	// pages are 1-based; clamp so the offset (p-1)*pageSize is never negative
	if p < 1 {
		p = 1
	}
	if pageSize < 1 {
		pageSize = common.ItemsPerPage
	}
	redemptions, total, err := model.GetAllRedemptions((p-1)*pageSize, pageSize)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data": gin.H{
			"items":     redemptions,
			"total":     total,
			"page":      p,
			"page_size": pageSize,
		},
	})
	return
}

func SearchRedemptions(c *gin.Context) {
	keyword := c.Query("keyword")
	p, _ := strconv.Atoi(c.Query("p"))
	pageSize, _ := strconv.Atoi(c.Query("page_size"))
	// pages are 1-based; clamp so the offset (p-1)*pageSize is never negative
	if p < 1 {
		p = 1
	}
	if pageSize < 1 {
		pageSize = common.ItemsPerPage
	}
	redemptions, total, err := model.SearchRedemptions(keyword, (p-1)*pageSize, pageSize)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data": gin.H{
			"items":     redemptions,
			"total":     total,
			"page":      p,
			"page_size": pageSize,
		},
	})
	return
}

func GetRedemption(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	redemption, err := model.GetRedemptionById(id)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    redemption,
	})
	return
}

func AddRedemption(c *gin.Context) {
	redemption := model.Redemption{}
	err := c.ShouldBindJSON(&redemption)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	if len(redemption.Name) == 0 || len(redemption.Name) > 20 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "兑换码名称长度必须在1-20之间",
		})
		return
	}
	if redemption.Count <= 0 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "兑换码个数必须大于0",
		})
		return
	}
	if redemption.Count > 100 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "一次兑换码批量生成的个数不能大于 100",
		})
		return
	}
	var keys []string
	for i := 0; i < redemption.Count; i++ {
		key := common.GetUUID()
		cleanRedemption := model.Redemption{
			UserId:      c.GetInt("id"),
			Name:        redemption.Name,
			Key:         key,
			CreatedTime: common.GetTimestamp(),
			Quota:       redemption.Quota,
		}
		err = cleanRedemption.Insert()
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
				"data":    keys,
			})
			return
		}
		keys = append(keys, key)
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    keys,
	})
	return
}

func DeleteRedemption(c *gin.Context) {
	id, _ := strconv.Atoi(c.Param("id"))
	err := model.DeleteRedemptionById(id)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
	return
}

func UpdateRedemption(c *gin.Context) {
	statusOnly := c.Query("status_only")
	redemption := model.Redemption{}
	err := c.ShouldBindJSON(&redemption)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	cleanRedemption, err := model.GetRedemptionById(redemption.Id)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	if statusOnly != "" {
		cleanRedemption.Status = redemption.Status
	} else {
		// If you add more fields, please also update redemption.Update()
		cleanRedemption.Name = redemption.Name
		cleanRedemption.Quota = redemption.Quota
	}
	err = cleanRedemption.Update()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    cleanRedemption,
	})
	return
}
```

## /controller/relay.go

```go path="/controller/relay.go"
package controller

import (
	"bytes"
	"errors"
	"fmt"
	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
	"io"
	"log"
	"net/http"
	"one-api/common"
	"one-api/dto"
	"one-api/middleware"
	"one-api/model"
	"one-api/relay"
	"one-api/relay/constant"
	relayconstant "one-api/relay/constant"
	"one-api/relay/helper"
	"one-api/service"
	"strings"
)

func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
	var err *dto.OpenAIErrorWithStatusCode
	switch relayMode {
	case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
		err = relay.ImageHelper(c)
	case relayconstant.RelayModeAudioSpeech:
		fallthrough
	case relayconstant.RelayModeAudioTranslation:
		fallthrough
	case relayconstant.RelayModeAudioTranscription:
		err = relay.AudioHelper(c)
	case relayconstant.RelayModeRerank:
		err = relay.RerankHelper(c, relayMode)
	case relayconstant.RelayModeEmbeddings:
		err = relay.EmbeddingHelper(c)
	default:
		err = relay.TextHelper(c)
	}
	if err != nil {
		// save the error log to MySQL
		userId := c.GetInt("id")
		tokenName := c.GetString("token_name")
		modelName := c.GetString("original_model")
		tokenId := c.GetInt("token_id")
		userGroup := c.GetString("group")
		channelId := c.GetInt("channel_id")
		other := make(map[string]interface{})
		other["error_type"] = err.Error.Type
		other["error_code"] = err.Error.Code
		other["status_code"] = err.StatusCode
		other["channel_id"] = channelId
		other["channel_name"] = c.GetString("channel_name")
		other["channel_type"] = c.GetInt("channel_type")
		model.RecordErrorLog(c, userId,
			channelId, modelName, tokenName, err.Error.Message, tokenId, 0, false, userGroup, other)
	}
	return err
}

func Relay(c *gin.Context) {
	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
	requestId := c.GetString(common.RequestIdKey)
	group := c.GetString("group")
	originalModel := c.GetString("original_model")
	var openaiErr *dto.OpenAIErrorWithStatusCode
	for i := 0; i <= common.RetryTimes; i++ {
		channel, err := getChannel(c, group, originalModel, i)
		if err != nil {
			common.LogError(c, err.Error())
			openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
			break
		}
		openaiErr = relayRequest(c, relayMode, channel)
		if openaiErr == nil {
			return // request handled successfully; return immediately
		}
		go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
		if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
			break
		}
	}
	useChannel := c.GetStringSlice("use_channel")
	if len(useChannel) > 1 {
		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
		common.LogInfo(c, retryLogStr)
	}
	if openaiErr != nil {
		if openaiErr.StatusCode == http.StatusTooManyRequests {
			common.LogError(c, fmt.Sprintf("origin 429 error: %s", openaiErr.Error.Message))
			openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
		}
		openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
		c.JSON(openaiErr.StatusCode, gin.H{
			"error": openaiErr.Error,
		})
	}
}

var upgrader = websocket.Upgrader{
	// Subprotocols accepted in the WS handshake: any protocol a client sends in
	// Sec-WebSocket-Protocol must be declared here. TODO: add other protocols
	Subprotocols: []string{"realtime"},
	CheckOrigin: func(r *http.Request) bool {
		return true // allow cross-origin requests
	},
}

func WssRelay(c *gin.Context) {
	// upgrade the HTTP connection to a WebSocket connection
	ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// ws is nil on failure and Upgrade has already replied with an HTTP error,
		// so only log here; deferring ws.Close() or writing to ws would panic
		common.LogError(c, "websocket upgrade failed: "+err.Error())
		return
	}
	defer ws.Close()
	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
	requestId := c.GetString(common.RequestIdKey)
	group := c.GetString("group")
	//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
	originalModel := c.GetString("original_model")
	var openaiErr *dto.OpenAIErrorWithStatusCode
	for i := 0; i <= common.RetryTimes; i++ {
		channel, err := getChannel(c, group, originalModel, i)
		if err != nil {
			common.LogError(c, err.Error())
			openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
			break
		}
		openaiErr = wssRequest(c, ws, relayMode, channel)
		if openaiErr == nil {
			return // request handled successfully; return immediately
		}
		go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
		if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
			break
		}
	}
	useChannel := c.GetStringSlice("use_channel")
	if len(useChannel) > 1 {
		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
		common.LogInfo(c, retryLogStr)
	}
	if openaiErr != nil {
		if openaiErr.StatusCode == http.StatusTooManyRequests {
			openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
		}
		openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
		helper.WssError(c, ws, openaiErr.Error)
	}
}

func RelayClaude(c *gin.Context) {
	//relayMode := constant.Path2RelayMode(c.Request.URL.Path)
	requestId := c.GetString(common.RequestIdKey)
	group := c.GetString("group")
	originalModel := c.GetString("original_model")
	var claudeErr *dto.ClaudeErrorWithStatusCode
	for i := 0; i <= common.RetryTimes; i++ {
		channel, err := getChannel(c, group, originalModel, i)
		if err != nil {
			common.LogError(c, err.Error())
			claudeErr = service.ClaudeErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
			break
		}
		claudeErr = claudeRequest(c, channel)
		if claudeErr == nil {
			return // request handled successfully; return immediately
		}
		openaiErr := service.ClaudeErrorToOpenAIError(claudeErr)
		go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
		if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
			break
		}
	}
	useChannel := c.GetStringSlice("use_channel")
	if len(useChannel) > 1 {
		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
		common.LogInfo(c, retryLogStr)
	}
	if claudeErr != nil {
		claudeErr.Error.Message = common.MessageWithRequestId(claudeErr.Error.Message, requestId)
		c.JSON(claudeErr.StatusCode, gin.H{
			"type":  "error",
			"error": claudeErr.Error,
		})
	}
}

func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
	addUsedChannel(c, channel.Id)
	requestBody, _ := common.GetRequestBody(c)
	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
	return relayHandler(c, relayMode)
}

func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
	addUsedChannel(c, channel.Id)
	requestBody, _ := common.GetRequestBody(c)
	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
	return relay.WssHelper(c, ws)
}

func claudeRequest(c *gin.Context, channel *model.Channel) *dto.ClaudeErrorWithStatusCode {
	addUsedChannel(c, channel.Id)
	requestBody, _ := common.GetRequestBody(c)
	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
	return relay.ClaudeHelper(c)
}

func addUsedChannel(c *gin.Context, channelId int) {
	useChannel := c.GetStringSlice("use_channel")
	useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
	c.Set("use_channel", useChannel)
}

func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
	if retryCount == 0 {
		autoBan := c.GetBool("auto_ban")
		autoBanInt := 1
		if !autoBan {
			autoBanInt = 0
		}
		return &model.Channel{
			Id:      c.GetInt("channel_id"),
			Type:    c.GetInt("channel_type"),
			Name:    c.GetString("channel_name"),
			AutoBan: &autoBanInt,
		}, nil
	}
	channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
	if err != nil {
		return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
	}
	middleware.SetupContextForSelectedChannel(c, channel, originalModel)
	return channel, nil
}

func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
	if openaiErr == nil {
		return false
	}
	if openaiErr.LocalError {
		return false
	}
	if retryTimes <= 0 {
		return false
	}
	if _, ok := c.Get("specific_channel_id"); ok {
		return false
	}
	if openaiErr.StatusCode == http.StatusTooManyRequests {
		return true
	}
	if openaiErr.StatusCode == 307 {
		return true
	}
	if openaiErr.StatusCode/100 == 5 {
		// timeouts (504, 524) are not retried
		if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
			return false
		}
		return true
	}
	if openaiErr.StatusCode == http.StatusBadRequest {
		channelType := c.GetInt("channel_type")
		if channelType == common.ChannelTypeAnthropic {
			return true
		}
		return false
	}
	if openaiErr.StatusCode == 408 {
		// Azure processing timeout: do not retry
		return false
	}
	if openaiErr.StatusCode/100 == 2 {
		return false
	}
	return true
}

func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
	// Do not read channel info from the context here: when this runs asynchronously,
	// the context may already describe a different channel.
	common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
	if service.ShouldDisableChannel(channelType, err) && autoBan {
		service.DisableChannel(channelId, channelName, err.Error.Message)
	}
}

func RelayMidjourney(c *gin.Context) {
	relayMode := c.GetInt("relay_mode")
	var err *dto.MidjourneyResponse
	switch relayMode {
	case relayconstant.RelayModeMidjourneyNotify:
		err = relay.RelayMidjourneyNotify(c)
	case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
		err = relay.RelayMidjourneyTask(c, relayMode)
	case relayconstant.RelayModeMidjourneyTaskImageSeed:
		err = relay.RelayMidjourneyTaskImageSeed(c)
	case relayconstant.RelayModeSwapFace:
		err = relay.RelaySwapFace(c)
	default:
		err = relay.RelayMidjourneySubmit(c, relayMode)
	}
	//err = relayMidjourneySubmit(c, relayMode)
	log.Println(err)
	if err != nil {
		statusCode := http.StatusBadRequest
		if err.Code == 30 {
			err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
			statusCode = http.StatusTooManyRequests
		}
		c.JSON(statusCode, gin.H{
			"description": fmt.Sprintf("%s %s", err.Description, err.Result),
			"type":        "upstream_error",
			"code":        err.Code,
		})
		channelId := c.GetInt("channel_id")
		common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
	}
}

func RelayNotImplemented(c *gin.Context) {
	err := dto.OpenAIError{
		Message: "API not implemented",
		Type:    "new_api_error",
		Param:   "",
		Code:    "api_not_implemented",
	}
	c.JSON(http.StatusNotImplemented, gin.H{
		"error": err,
	})
}

func RelayNotFound(c *gin.Context) {
	err := dto.OpenAIError{
		Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
		Type:    "invalid_request_error",
		Param:   "",
		Code:    "",
	}
	c.JSON(http.StatusNotFound, gin.H{
		"error": err,
	})
}

func RelayTask(c *gin.Context) {
	retryTimes := common.RetryTimes
	channelId := c.GetInt("channel_id")
	relayMode := c.GetInt("relay_mode")
	group := c.GetString("group")
	originalModel := c.GetString("original_model")
	c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
	taskErr := taskRelayHandler(c, relayMode)
	if taskErr == nil {
		retryTimes = 0
	}
	for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
		channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
		if err != nil {
			common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
			break
		}
		channelId = channel.Id
		useChannel := c.GetStringSlice("use_channel")
		useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
		c.Set("use_channel", useChannel)
		// i counts up from 0, so the attempts remaining (including this one) are retryTimes-i
		common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryTimes-i))
		middleware.SetupContextForSelectedChannel(c, channel, originalModel)
		requestBody, _ := common.GetRequestBody(c)
		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
		taskErr = taskRelayHandler(c, relayMode)
	}
	useChannel := c.GetStringSlice("use_channel")
	if len(useChannel) > 1 {
		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
		common.LogInfo(c, retryLogStr)
	}
	if taskErr != nil {
		if taskErr.StatusCode == http.StatusTooManyRequests {
			taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
		}
		c.JSON(taskErr.StatusCode, taskErr)
	}
}

func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
	var err *dto.TaskError
	switch relayMode {
	case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
		err = relay.RelayTaskFetch(c, relayMode)
	default:
		err = relay.RelayTaskSubmit(c, relayMode)
	}
	return err
}

func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
	if taskErr == nil {
		return false
	}
	if retryTimes <= 0 {
		return false
	}
	if _, ok := c.Get("specific_channel_id"); ok {
		return false
	}
	if taskErr.StatusCode == http.StatusTooManyRequests {
		return true
	}
	if taskErr.StatusCode == 307 {
		return true
	}
	if taskErr.StatusCode/100 == 5 {
		// timeouts (504, 524) are not retried
		if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
			return false
		}
		return true
	}
	if taskErr.StatusCode == http.StatusBadRequest {
		return false
	}
	if taskErr.StatusCode == 408 {
		// Azure processing timeout: do not retry
		return false
	}
	if taskErr.LocalError {
		return false
	}
	if taskErr.StatusCode/100 == 2 {
		return false
	}
	return true
}
```

## /controller/setup.go

```go path="/controller/setup.go"
package controller

import (
	"github.com/gin-gonic/gin"
	"one-api/common"
	"one-api/constant"
	"one-api/model"
	"one-api/setting/operation_setting"
	"time"
)

type Setup struct {
	Status       bool   `json:"status"`
	RootInit     bool   `json:"root_init"`
	DatabaseType string `json:"database_type"`
}

type SetupRequest struct {
	Username           string `json:"username"`
	Password           string `json:"password"`
	ConfirmPassword    string `json:"confirmPassword"`
	SelfUseModeEnabled bool   `json:"SelfUseModeEnabled"`
	DemoSiteEnabled    bool   `json:"DemoSiteEnabled"`
}

func GetSetup(c *gin.Context) {
	setup := Setup{
		Status: constant.Setup,
	}
	if constant.Setup {
		c.JSON(200, gin.H{
			"success": true,
			"data":    setup,
		})
		return
	}
	setup.RootInit = model.RootUserExists()
	if common.UsingMySQL {
		setup.DatabaseType = "mysql"
	}
	if common.UsingPostgreSQL {
		setup.DatabaseType = "postgres"
	}
	if common.UsingSQLite {
		setup.DatabaseType = "sqlite"
	}
	c.JSON(200, gin.H{
		"success": true,
		"data":    setup,
	})
}

func PostSetup(c *gin.Context) {
	// Check if setup is already completed
	if constant.Setup {
		c.JSON(400, gin.H{
			"success": false,
			"message": "系统已经初始化完成",
		})
		return
	}
	// Check if root user already exists
	rootExists := model.RootUserExists()
	var req SetupRequest
	err := c.ShouldBindJSON(&req)
	if err != nil {
		c.JSON(400, gin.H{
			"success": false,
			"message": "请求参数有误",
		})
		return
	}
	// If root doesn't exist, validate and create admin account
	if !rootExists {
		// Validate password
		if req.Password != req.ConfirmPassword {
			c.JSON(400, gin.H{
				"success": false,
				"message": "两次输入的密码不一致",
			})
			return
		}
		if len(req.Password) < 8 {
			c.JSON(400, gin.H{
				"success": false,
				"message": "密码长度至少为8个字符",
			})
			return
		}
		// Create root user
		hashedPassword, err := common.Password2Hash(req.Password)
		if err != nil {
			c.JSON(500, gin.H{
				"success": false,
				"message": "系统错误: " + err.Error(),
			})
			return
		}
		rootUser := model.User{
			Username:    req.Username,
			Password:    hashedPassword,
			Role:        common.RoleRootUser,
			Status:      common.UserStatusEnabled,
			DisplayName: "Root User",
			AccessToken: nil,
			Quota:       100000000,
		}
		err = model.DB.Create(&rootUser).Error
		if err != nil {
			c.JSON(500, gin.H{
				"success": false,
				"message": "创建管理员账号失败: " + err.Error(),
			})
			return
		}
	}
	// Set operation modes
	operation_setting.SelfUseModeEnabled = req.SelfUseModeEnabled
	operation_setting.DemoSiteEnabled = req.DemoSiteEnabled
	// Save operation modes to database for persistence
	err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled))
	if err != nil {
		c.JSON(500, gin.H{
			"success": false,
			"message": "保存自用模式设置失败: " + err.Error(),
		})
		return
	}
	err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled))
	if err != nil {
		c.JSON(500, gin.H{
			"success": false,
			"message": "保存演示站点模式设置失败: " + err.Error(),
		})
		return
	}
	// Persist the setup record first and only then mark setup as completed,
	// so a failed insert does not leave the in-memory flag set
	setup := model.Setup{
		Version:       common.Version,
		InitializedAt: time.Now().Unix(),
	}
	err = model.DB.Create(&setup).Error
	if err != nil {
		c.JSON(500, gin.H{
			"success": false,
			"message": "系统初始化失败: " + err.Error(),
		})
		return
	}
	constant.Setup = true
	c.JSON(200, gin.H{
		"success": true,
		"message": "系统初始化成功",
	})
}

func boolToString(b bool) string {
	if b {
		return "true"
	}
	return "false"
}
```

## /controller/task.go

```go path="/controller/task.go"
package controller

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"github.com/gin-gonic/gin"
	"github.com/samber/lo"
	"io"
	"net/http"
	"one-api/common"
	"one-api/constant"
	"one-api/dto"
	"one-api/model"
	"one-api/relay"
	"sort"
	"strconv"
	"time"
)

func UpdateTaskBulk() {
	//recover
	//imageModel := "midjourney"
	for {
		time.Sleep(time.Duration(15) * time.Second)
		common.SysLog("任务进度轮询开始")
		ctx := context.TODO()
		allTasks := model.GetAllUnFinishSyncTasks(500)
		platformTask := make(map[constant.TaskPlatform][]*model.Task)
		for _, t := range allTasks {
			platformTask[t.Platform] = append(platformTask[t.Platform], t)
		}
		for platform, tasks := range platformTask {
			if len(tasks) == 0 {
				continue
			}
			taskChannelM := make(map[int][]string)
			taskM := make(map[string]*model.Task)
			nullTaskIds := make([]int64, 0)
			for _, task := range tasks {
				if task.TaskID == "" {
					// unfinished tasks without a task ID have failed; collect them
					nullTaskIds = append(nullTaskIds, task.ID)
					continue
				}
				taskM[task.TaskID] = task
				taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID)
			}
			if len(nullTaskIds) > 0 {
				err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
					"status":   "FAILURE",
					"progress": "100%",
				})
				if err != nil {
					common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
				} else {
					common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
				}
			}
			if len(taskChannelM) == 0 {
				continue
			}
			UpdateTaskByPlatform(platform, taskChannelM, taskM)
		}
		common.SysLog("任务进度轮询完成")
	}
}

func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
	switch platform {
	case constant.TaskPlatformMidjourney:
		//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
	case constant.TaskPlatformSuno:
		_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
	default:
		common.SysLog("未知平台")
	}
}

func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
	for channelId, taskIds := range taskChannelM {
		err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
		if err != nil {
			common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
		}
	}
	return nil
}

func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
	common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
	if len(taskIds) == 0 {
		return nil
	}
	channel, err := model.CacheGetChannel(channelId)
	if err != nil {
		common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
		err = model.TaskBulkUpdate(taskIds, map[string]any{
			"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
			"status":      "FAILURE",
			"progress":    "100%",
		})
		if err != nil {
			common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
		}
		return err
	}
	adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno)
	if adaptor == nil {
		return errors.New("adaptor not found")
	}
	resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
		"ids": taskIds,
	})
	if err != nil {
		common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
		return err
	}
	// close the body on every path, including non-200 responses
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
		return fmt.Errorf("Get Task status code: %d", resp.StatusCode)
	}
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
		return err
	}
	var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
	err = json.Unmarshal(responseBody, &responseItems)
	if err != nil {
		common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
		return err
	}
	if !responseItems.IsSuccess() {
		common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 获取任务失败: %s", channelId, len(taskIds), string(responseBody)))
		// err is nil at this point, so return an explicit error instead
		return fmt.Errorf("fetch task failed for channel #%d", channelId)
	}
	for _, responseItem := range responseItems.Data {
		// skip task IDs we did not request; a missing entry would be a nil *model.Task
		task, ok := taskM[responseItem.TaskID]
		if !ok || !checkTaskNeedUpdate(task, responseItem) {
			continue
		}
		task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
		task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
		task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
		task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
		task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
		if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
			common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
			task.Progress = "100%"
			//err = model.CacheUpdateUserQuota(task.UserId) ?
			// The quota-cache refresh above is disabled, so the refund does not depend on it.
			quota := task.Quota
			if quota != 0 {
				err = model.IncreaseUserQuota(task.UserId, quota, false)
				if err != nil {
					common.LogError(ctx, "fail to increase user quota: "+err.Error())
				}
				logContent := fmt.Sprintf("Async task %s failed, refunded %s", task.TaskID, common.LogQuota(quota))
				model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
			}
		}
		if responseItem.Status == model.TaskStatusSuccess {
			task.Progress = "100%"
		}
		task.Data = responseItem.Data

		err = task.Update()
		if err != nil {
			common.SysError("UpdateSunoTask task error: " + err.Error())
		}
	}
	return nil
}

func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
	if oldTask.SubmitTime != newTask.SubmitTime {
		return true
	}
	if oldTask.StartTime != newTask.StartTime {
		return true
	}
	if oldTask.FinishTime != newTask.FinishTime {
		return true
	}
	if string(oldTask.Status) != newTask.Status {
		return true
	}
	if oldTask.FailReason != newTask.FailReason {
		return true
	}
	if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
		return true
	}

	oldData, _ := json.Marshal(oldTask.Data)
	newData, _ := json.Marshal(newTask.Data)

	// Sorting the raw JSON bytes makes the comparison ignore key and element order;
	// it can also treat some genuinely different payloads as equal.
	sort.Slice(oldData, func(i, j int) bool {
		return oldData[i] < oldData[j]
	})
	sort.Slice(newData, func(i, j int) bool {
		return newData[i] < newData[j]
	})

	return string(oldData) != string(newData)
}

func GetAllTask(c *gin.Context) {
	p, _ := strconv.Atoi(c.Query("p"))
	if p < 0 {
		p = 0
	}

	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
	// Parse the remaining query parameters
	queryParams := model.SyncTaskQueryParams{
		Platform:       constant.TaskPlatform(c.Query("platform")),
		TaskID:         c.Query("task_id"),
		Status:         c.Query("status"),
		Action:         c.Query("action"),
		StartTimestamp: startTimestamp,
		EndTimestamp:   endTimestamp,
	}

	logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
	if logs == nil {
		logs = make([]*model.Task, 0)
	}

	c.JSON(200, gin.H{
		"success": true,
		"message": "",
		"data":    logs,
	})
}

func GetUserTask(c *gin.Context) {
	p, _ := strconv.Atoi(c.Query("p"))
	if p < 0 {
		p = 0
	}

	userId := c.GetInt("id")

	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)

	queryParams := model.SyncTaskQueryParams{
		Platform:       constant.TaskPlatform(c.Query("platform")),
		TaskID:         c.Query("task_id"),
		Status:         c.Query("status"),
		Action:         c.Query("action"),
		StartTimestamp: startTimestamp,
		EndTimestamp:   endTimestamp,
	}

	logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
	if logs == nil {
		logs = make([]*model.Task, 0)
	}

	c.JSON(200, gin.H{
		"success": true,
		"message": "",
		"data":    logs,
	})
}
```

## /controller/telegram.go

```go path="/controller/telegram.go"
package controller

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"io"
	"net/http"
	"one-api/common"
	"one-api/model"
	"sort"

	"github.com/gin-contrib/sessions"
	"github.com/gin-gonic/gin"
)

func TelegramBind(c *gin.Context) {
	if !common.TelegramOAuthEnabled {
		c.JSON(200, gin.H{
			"message": "The administrator has not enabled login and registration via Telegram",
			"success": false,
		})
		return
	}
	params := c.Request.URL.Query()
	if !checkTelegramAuthorization(params, common.TelegramBotToken) {
		c.JSON(200, gin.H{
			"message": "Invalid request",
			"success": false,
		})
		return
	}
	telegramId := params["id"][0]
	if model.IsTelegramIdAlreadyTaken(telegramId) {
		c.JSON(200, gin.H{
			"message": "This Telegram account is already bound",
			"success": false,
		})
		return
	}

	session := sessions.Default(c)
	id := session.Get("id")
	user := model.User{Id: id.(int)}
	if err := user.FillUserById(); err != nil {
		c.JSON(200, gin.H{
			"message": err.Error(),
			"success": false,
		})
		return
	}
	if user.Id == 0 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "This user account has been deleted",
		})
		return
	}
	user.TelegramId = telegramId
	if err := user.Update(false); err != nil {
		c.JSON(200, gin.H{
			"message": err.Error(),
			"success": false,
		})
		return
	}

	c.Redirect(302, "/setting")
}

func TelegramLogin(c *gin.Context) {
	if !common.TelegramOAuthEnabled {
		c.JSON(200, gin.H{
			"message": "The administrator has not enabled login and registration via Telegram",
			"success": false,
		})
		return
	}
	params := c.Request.URL.Query()
	if !checkTelegramAuthorization(params, common.TelegramBotToken) {
		c.JSON(200, gin.H{
			"message": "Invalid request",
			"success": false,
		})
		return
	}

	telegramId := params["id"][0]
	user := model.User{TelegramId: telegramId}
	if err := user.FillUserByTelegramId(); err != nil {
		c.JSON(200, gin.H{
			"message": err.Error(),
			"success": false,
		})
		return
	}
	setupLogin(&user, c)
}

func checkTelegramAuthorization(params map[string][]string, token string) bool {
	strs := []string{}
	var hash = ""
	for k, v := range params {
		if k == "hash" {
			hash = v[0]
			continue
		}
		strs = append(strs, k+"="+v[0])
	}
	sort.Strings(strs)
	var imploded = ""
	for _, s := range strs {
		if imploded != "" {
			imploded += "\n"
		}
		imploded += s
	}
	sha256hash := sha256.New()
	io.WriteString(sha256hash, token)
	hmachash := hmac.New(sha256.New, sha256hash.Sum(nil))
	io.WriteString(hmachash, imploded)
	ss := hex.EncodeToString(hmachash.Sum(nil))
	// Compare in constant time so the expected hash is not leaked through timing.
	return hmac.Equal([]byte(hash), []byte(ss))
}
```

## /controller/token.go

```go path="/controller/token.go"
package controller

import (
	"github.com/gin-gonic/gin"
	"net/http"
	"one-api/common"
	"one-api/model"
	"strconv"
)

func GetAllTokens(c *gin.Context) {
	userId := c.GetInt("id")
	p, _ := strconv.Atoi(c.Query("p"))
	size, _ := strconv.Atoi(c.Query("size"))
	if p < 0 {
		p = 0
	}
	if size <= 0 {
		size = common.ItemsPerPage
	} else if size > 100 {
		size = 100
	}
	tokens, err := model.GetAllUserTokens(userId, p*size, size)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    tokens,
	})
	return
}

func SearchTokens(c *gin.Context) {
	userId := c.GetInt("id")
	keyword := c.Query("keyword")
	token := c.Query("token")
	tokens, err := model.SearchUserTokens(userId, keyword, token)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    tokens,
	})
	return
}

func GetToken(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	userId := c.GetInt("id")
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	token, err := model.GetTokenByIds(id, userId)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    token,
	})
	return
}

func GetTokenStatus(c *gin.Context) {
	tokenId := c.GetInt("token_id")
	userId := c.GetInt("id")
	token, err := model.GetTokenByIds(tokenId, userId)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	expiredAt := token.ExpiredTime
	if expiredAt == -1 {
		expiredAt = 0
	}
	c.JSON(http.StatusOK, gin.H{
		"object":          "credit_summary",
		"total_granted":   token.RemainQuota,
		"total_used":      0, // not supported currently
		"total_available": token.RemainQuota,
		"expires_at":      expiredAt * 1000,
	})
}

func AddToken(c *gin.Context) {
	token := model.Token{}
	err := c.ShouldBindJSON(&token)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	if len(token.Name) > 30 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "Token name is too long",
		})
		return
	}
	key, err := common.GenerateKey()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "Failed to generate token",
		})
		common.SysError("failed to generate token key: " + err.Error())
		return
	}
	cleanToken := model.Token{
		UserId:             c.GetInt("id"),
		Name:               token.Name,
		Key:                key,
		CreatedTime:        common.GetTimestamp(),
		AccessedTime:       common.GetTimestamp(),
		ExpiredTime:        token.ExpiredTime,
		RemainQuota:        token.RemainQuota,
		UnlimitedQuota:     token.UnlimitedQuota,
		ModelLimitsEnabled: token.ModelLimitsEnabled,
		ModelLimits:        token.ModelLimits,
		AllowIps:           token.AllowIps,
		Group:              token.Group,
	}
	err = cleanToken.Insert()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
	return
}

func DeleteToken(c *gin.Context) {
	id, _ := strconv.Atoi(c.Param("id"))
	userId := c.GetInt("id")
	err := model.DeleteTokenById(id, userId)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
	return
}

func UpdateToken(c *gin.Context) {
	userId := c.GetInt("id")
	statusOnly := c.Query("status_only")
	token := model.Token{}
	err := c.ShouldBindJSON(&token)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	if len(token.Name) > 30 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "Token name is too long",
		})
		return
	}
	cleanToken, err := model.GetTokenByIds(token.Id, userId)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	if token.Status == common.TokenStatusEnabled {
		if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "The token has expired and cannot be enabled. Change its expiration time first, or set it to never expire",
			})
			return
		}
		if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "The token's quota is exhausted and it cannot be enabled. Increase its remaining quota first, or set it to unlimited",
			})
			return
		}
	}
	if statusOnly != "" {
		cleanToken.Status = token.Status
	} else {
		// If you add more fields, please also update token.Update()
		cleanToken.Name = token.Name
		cleanToken.ExpiredTime = token.ExpiredTime
		cleanToken.RemainQuota = token.RemainQuota
		cleanToken.UnlimitedQuota = token.UnlimitedQuota
		cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
		cleanToken.ModelLimits = token.ModelLimits
		cleanToken.AllowIps = token.AllowIps
		cleanToken.Group = token.Group
	}
	err = cleanToken.Update()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    cleanToken,
	})
	return
}
```

## /controller/topup.go

```go path="/controller/topup.go"
package controller

import (
	"fmt"
	"log"
	"net/url"
	"one-api/common"
	"one-api/model"
	"one-api/service"
	"one-api/setting"
	"strconv"
	"sync"
	"time"

	"github.com/Calcium-Ion/go-epay/epay"
	"github.com/gin-gonic/gin"
	"github.com/samber/lo"
	"github.com/shopspring/decimal"
)

type EpayRequest struct {
	Amount        int64  `json:"amount"`
	PaymentMethod string `json:"payment_method"`
	TopUpCode     string `json:"top_up_code"`
}

type AmountRequest struct {
	Amount    int64  `json:"amount"`
	TopUpCode string `json:"top_up_code"`
}

func GetEpayClient() *epay.Client {
	if setting.PayAddress == "" || setting.EpayId == "" || setting.EpayKey == "" {
		return nil
	}
	withUrl, err := epay.NewClient(&epay.Config{
		PartnerID: setting.EpayId,
		Key:       setting.EpayKey,
	}, setting.PayAddress)
	if err != nil {
		return nil
	}
	return withUrl
}

func getPayMoney(amount int64, group string) float64 {
	dAmount := decimal.NewFromInt(amount)

	if !common.DisplayInCurrencyEnabled {
		dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
		dAmount = dAmount.Div(dQuotaPerUnit)
	}

	topupGroupRatio := common.GetTopupGroupRatio(group)
	if topupGroupRatio == 0 {
		topupGroupRatio = 1
	}

	dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio)
	dPrice := decimal.NewFromFloat(setting.Price)

	payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio)

	return payMoney.InexactFloat64()
}

func getMinTopup() int64 {
	minTopup := setting.MinTopUp
	if !common.DisplayInCurrencyEnabled {
		dMinTopup := decimal.NewFromInt(int64(minTopup))
		dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
		minTopup = int(dMinTopup.Mul(dQuotaPerUnit).IntPart())
	}
	return int64(minTopup)
}

func RequestEpay(c *gin.Context) {
	var req EpayRequest
	err := c.ShouldBindJSON(&req)
	if err != nil {
		c.JSON(200, gin.H{"message": "error", "data": "Invalid parameters"})
		return
	}
	if req.Amount < getMinTopup() {
		c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("Top-up amount cannot be less than %d", getMinTopup())})
		return
	}

	id := c.GetInt("id")
	group, err := model.GetUserGroup(id, true)
	if err != nil {
		c.JSON(200, gin.H{"message": "error", "data": "Failed to get user group"})
		return
	}
	payMoney := getPayMoney(req.Amount, group)
	if payMoney < 0.01 {
		c.JSON(200, gin.H{"message": "error", "data": "Top-up amount is too low"})
		return
	}

	payType := "wxpay"
	if req.PaymentMethod == "zfb" {
		payType = "alipay"
	}
	if req.PaymentMethod == "wx" {
		req.PaymentMethod = "wxpay"
		payType = "wxpay"
	}
	callBackAddress := service.GetCallbackAddress()
	returnUrl, _ := url.Parse(setting.ServerAddress + "/log")
	notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
	tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
	tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
	client := GetEpayClient()
	if client == nil {
		c.JSON(200, gin.H{"message": "error", "data": "The administrator has not configured payment settings"})
		return
	}
	uri, params, err := client.Purchase(&epay.PurchaseArgs{
		Type:           payType,
		ServiceTradeNo: tradeNo,
		Name:           fmt.Sprintf("TUC%d", req.Amount),
		Money:          strconv.FormatFloat(payMoney, 'f', 2, 64),
		Device:         epay.PC,
		NotifyUrl:      notifyUrl,
		ReturnUrl:      returnUrl,
	})
	if err != nil {
		c.JSON(200, gin.H{"message": "error", "data": "Failed to initiate payment"})
		return
	}
	amount := req.Amount
	if !common.DisplayInCurrencyEnabled {
		dAmount := decimal.NewFromInt(amount)
		dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
		amount = dAmount.Div(dQuotaPerUnit).IntPart()
	}
	topUp := &model.TopUp{
		UserId:     id,
		Amount:     amount,
		Money:      payMoney,
		TradeNo:    tradeNo,
		CreateTime: time.Now().Unix(),
		Status:     "pending",
	}
	err = topUp.Insert()
	if err != nil {
		c.JSON(200, gin.H{"message": "error", "data": "Failed to create order"})
		return
	}
	c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
}

// Per-order locks, keyed by tradeNo
var orderLocks sync.Map
var createLock sync.Mutex

// LockOrder acquires the lock for the given trade number, creating it on first use.
func LockOrder(tradeNo string) {
	lock, ok := orderLocks.Load(tradeNo)
	if !ok {
		createLock.Lock()
		lock, ok = orderLocks.Load(tradeNo)
		if !ok {
			lock = new(sync.Mutex)
			orderLocks.Store(tradeNo, lock)
		}
		// Release createLock before blocking on the per-order lock, so waiting
		// on one order does not stall lock creation for every other order.
		createLock.Unlock()
	}
	lock.(*sync.Mutex).Lock()
}

// UnlockOrder releases the lock for the given trade number.
func UnlockOrder(tradeNo string) {
	lock, ok := orderLocks.Load(tradeNo)
	if ok {
		lock.(*sync.Mutex).Unlock()
	}
}

func EpayNotify(c *gin.Context) {
	params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
		r[t] = c.Request.URL.Query().Get(t)
		return r
	}, map[string]string{})
	client := GetEpayClient()
	if client == nil {
		log.Println("Epay callback failed: payment configuration not found")
		_, err := c.Writer.Write([]byte("fail"))
		if err != nil {
			log.Println("Epay callback: failed to write response")
		}
		// client is nil, so stop here instead of calling client.Verify below.
		return
	}
	verifyInfo, err := client.Verify(params)
	if err == nil && verifyInfo.VerifyStatus {
		_, err := c.Writer.Write([]byte("success"))
		if err != nil {
			log.Println("Epay callback: failed to write response")
		}
	} else {
		_, err := c.Writer.Write([]byte("fail"))
		if err != nil {
			log.Println("Epay callback: failed to write response")
		}
		log.Println("Epay callback: signature verification failed")
		return
	}

	if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
		log.Println(verifyInfo)
		LockOrder(verifyInfo.ServiceTradeNo)
		defer UnlockOrder(verifyInfo.ServiceTradeNo)
		topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
		if topUp == nil {
			log.Printf("Epay callback: order not found: %v", verifyInfo)
			return
		}
		if topUp.Status == "pending" {
			topUp.Status = "success"
			err := topUp.Update()
			if err != nil {
				log.Printf("Epay callback: failed to update order: %v", topUp)
				return
			}
			//user, _ := model.GetUserById(topUp.UserId, false)
			//user.Quota += topUp.Amount * 500000
			dAmount := decimal.NewFromInt(int64(topUp.Amount))
			dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
			quotaToAdd := int(dAmount.Mul(dQuotaPerUnit).IntPart())
			err = model.IncreaseUserQuota(topUp.UserId, quotaToAdd, true)
			if err != nil {
				log.Printf("Epay callback: failed to update user: %v", topUp)
				return
			}
			log.Printf("Epay callback: user updated successfully %v", topUp)
			model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("Online top-up succeeded, amount credited: %v, amount paid: %f", common.LogQuota(quotaToAdd), topUp.Money))
		}
	} else {
		log.Printf("Epay callback: abnormal trade status: %v", verifyInfo)
	}
}

func RequestAmount(c *gin.Context) {
	var req AmountRequest
	err := c.ShouldBindJSON(&req)
	if err != nil {
		c.JSON(200, gin.H{"message": "error", "data": "Invalid parameters"})
		return
	}

	if req.Amount < getMinTopup() {
		c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("Top-up amount cannot be less than %d", getMinTopup())})
		return
	}
	id := c.GetInt("id")
	group, err := model.GetUserGroup(id, true)
	if err != nil {
		c.JSON(200, gin.H{"message": "error", "data": "Failed to get user group"})
		return
	}
	payMoney := getPayMoney(req.Amount, group)
	if payMoney < 0.01 {
		c.JSON(200, gin.H{"message": "error", "data": "Top-up amount is too low"})
		return
	}
	c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
}
```

## /controller/usedata.go

```go path="/controller/usedata.go"
package controller

import (
	"github.com/gin-gonic/gin"
	"net/http"
	"one-api/model"
	"strconv"
)

func GetAllQuotaDates(c *gin.Context) {
	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
	username := c.Query("username")
	dates, err := model.GetAllQuotaDates(startTimestamp, endTimestamp, username)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    dates,
	})
	return
}

func GetUserQuotaDates(c *gin.Context) {
	userId := c.GetInt("id")
	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
	// Reject ranges longer than one month (2592000 seconds = 30 days)
	if endTimestamp-startTimestamp > 2592000 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "The time range cannot exceed 1 month",
		})
		return
	}
	dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    dates,
	})
	return
}
```