```
├── .cursor/
│   └── rules/
│       └── weclone-rules.mdc
├── .gitignore
├── README.md
├── dataset/
│   ├── blocked_words.json
│   ├── res_csv/
│   │   ├── pt/
│   │   │   └── dataset_info.json
│   │   └── sft/
│   │       └── dataset_info.json
│   └── test_data.json
├── ds_config.json
├── pyproject.toml
├── settings.json
├── tests/
│   ├── README.md
│   ├── test_full_pipeline.py
│   ├── test_get_sample_audio.py
│   ├── test_old_csv_to_json.py
│   ├── test_qa_generator.py
│   └── test_weclone_pipeline_mock.py
├── weclone-audio/
│   ├── README.md
│   └── src/
│       ├── Llasa/
│       │   ├── infer.py
│       │   └── text_to_speech.py
│       ├── SparkTTS.py
│       ├── __init__.py
│       ├── get_sample_audio.py
│       ├── infer.py
│       ├── sample.wav
│       └── server未完工/
│           ├── .env.example
│           ├── handle_text.py
│           ├── requirements.txt
│           ├── server.py
│           ├── tts_handler.py
│           └── utils.py
└── weclone/
    ├── __init__.py
    ├── cli.py
    ├── data/
    │   ├── __init__.py
    │   ├── models.py
    │   ├── qa_generator.py
    │   └── strategies.py
    ├── eval/
    │   ├── __init__.py
    │   ├── cli_demo.py
    │   ├── evaluate.py
    │   ├── test_model.py
    │   └── web_demo.py
    ├── server/
    │   ├── __init__.py
    │   └── api_service.py
    ├── train/
    │   ├── __init__.py
    │   ├── export_model.py
    │   ├── train_pt.py
    │   └── train_sft.py
    └── utils/
        ├── __init__.py
        ├── config.py
        ├── length_cdf.py
        ├── log.py
        └── tools.py
```

## /.cursor/rules/weclone-rules.mdc

```mdc path="/.cursor/rules/weclone-rules.mdc"
---
description:
globs:
alwaysApply: true
---

# Your rule content

- You can @ files here
- The project uses uv as the package manager and pyproject.toml as the project configuration file.
```

## /.gitignore

```gitignore path="/.gitignore"
wandb/
weclone_archive-my/
**/__pycache__/
events.out.tfevents.*
归档/
*.pt
*.npz
*nohup.out
*log.txt
*cookie.bin
*.gradio/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

*.zip
LLaMA-Factory
chatglm3-6b
cache
archive
model_output*
data/test
.vscode
*-my.*
*.csv
*test.*
*users.json
Spark-TTS-0.5B/
uv.lock
output*
Qwen*/
```

## /README.md

![download](https://github.com/user-attachments/assets/5842e84e-004f-4afd-9373-af64e9575b78)

🚀 A one-stop solution for creating a digital avatar from your chat history 💡

[![GitHub stars](https://img.shields.io/github/stars/xming521/WeClone?style=for-the-badge&logo=github&label=Stars&logoColor=white&color=ffda65)](https://github.com/xming521/WeClone/stargazers) [![GitHub release](https://img.shields.io/github/v/release/xming521/WeClone?style=for-the-badge&logo=github&label=Release&logoColor=white&color=06d094)](https://github.com/xming521/WeClone/releases)
## Core features ✨

- 💫 A full pipeline for building a digital avatar: chat data export, preprocessing, model training, and deployment
- 💬 Fine-tune an LLM on your WeChat chat history so the model picks up your personal "flavor"
- 🎙️ High-quality voice cloning from WeChat voice messages plus a 0.5B model 👉 [WeClone-audio](https://github.com/xming521/WeClone/tree/master/weclone-audio)
- 🔗 Bind your digital avatar to WeChat, QQ, Telegram, WeCom, or Feishu bots

## Features and notes 📋

> [!TIP]
> New feature: the [WeClone-audio](https://github.com/xming521/WeClone/tree/master/weclone-audio) module supports cloning WeChat voice messages.

> [!IMPORTANT]
> Version 0.2.0 is a full refactor: the dataset directory and all script paths have changed. After pulling the new code, place the `csv` folder under `dataset`, and reinstall the dependencies.
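For clarity, here is a minimal sketch of the expected layout after the move, based on the data-preparation instructions below; the per-contact folder names are placeholders:

```
dataset/
└── csv/
    ├── <contact_or_group_1>/   # one folder per exported contact or group chat
    │   └── *.csv
    └── <contact_or_group_2>/
        └── *.csv
```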

> [!IMPORTANT]
> - WeClone is still iterating rapidly; current results are not representative of the final quality.
> - Fine-tuning quality depends heavily on the model size and on the quantity and quality of your chat data; in principle, the larger the model and the more data, the better the result.
> - The Windows environment has not been rigorously tested; WSL can be used as the runtime environment.

### Hardware requirements

The project defaults to the Qwen2.5-7B-Instruct model with LoRA fine-tuning for the SFT stage, which needs roughly 16GB of VRAM. Other models and methods supported by [LLaMA Factory](https://github.com/hiyouga/LLaMA-Factory/blob/main/README_zh.md#%E6%A8%A1%E5%9E%8B) can also be used.

Estimated VRAM requirements:

| Method                          | Bits | 7B    | 14B   | 30B   | 70B    | `x`B    |
| ------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
| Full (`bf16` or `fp16`)         | 32   | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
| Full (`pure_bf16`)              | 16   | 60GB  | 120GB | 300GB | 600GB  | `8x`GB  |
| Freeze/LoRA/GaLore/APOLLO/BAdam | 16   | 16GB  | 32GB  | 64GB  | 160GB  | `2x`GB  |
| QLoRA                           | 8    | 10GB  | 20GB  | 40GB  | 80GB   | `x`GB   |
| QLoRA                           | 4    | 6GB   | 12GB  | 24GB  | 48GB   | `x/2`GB |
| QLoRA                           | 2    | 4GB   | 8GB   | 16GB  | 24GB   | `x/4`GB |

### Environment setup

CUDA installation (skip if already installed): [LLaMA Factory](https://llamafactory.readthedocs.io/zh-cn/latest/getting_started/installation.html#cuda)

[uv](https://docs.astral.sh/uv/), a very fast Python environment manager, is recommended. After installing uv, create a new Python environment and install the dependencies with the commands below (this does not include the dependencies for voice cloning):

```bash
git clone https://github.com/xming521/WeClone.git
cd WeClone
uv venv .venv --python=3.10
source .venv/bin/activate
uv pip install --group main -e .
```

Verify that CUDA is configured correctly and visible to PyTorch (not needed on Mac):

```bash
python -c "import torch; print('CUDA available:', torch.cuda.is_available());"
```

(Optional) Install FlashAttention to speed up training and inference: `uv pip install flash-attn --no-build-isolation`

> [!NOTE]
> All training and inference configuration is centralized in [settings.json](settings.json).

### Data preparation

Use [PyWxDump](https://github.com/xaoyaoo/PyWxDump) to extract your WeChat chat history. Migrating (backing up) your phone's chat history to the PC first yields more data. After downloading the tool and decrypting the database, click "chat backup" and choose CSV as the export type; you can export multiple contacts or group chats. Then place the `csv` folder found under `wxdump_tmp/export` into the `./dataset` directory, i.e. the per-contact chat folders all live together under `./dataset/csv`.

### Data preprocessing

- By default the project strips phone numbers, ID-card numbers, emails, and URLs from the data. It also ships a blocked-word list, [blocked_words](dataset/blocked_words.json), to which you can add words and phrases to filter (any sentence containing a blocked word is dropped in full).
- Run the following command to process the data; adjust `make_dataset_args` in settings.json to match your own chat style:

```bash
python weclone/data/qa_generator.py
```

- Currently only the time-window strategy is supported: consecutive messages from the same sender are joined with commas into one sentence according to `single_combine_time_window`, and question-answer pairs are matched according to `qa_match_time_window`. LLM-based data cleaning will be added later. An illustrative record of the resulting dataset format is sketched further below, just before the deployment section.

### Model download

```bash
git lfs install
git clone https://www.modelscope.cn/Qwen/Qwen2.5-7B-Instruct.git
```

### Configure parameters and fine-tune the model

- (Optional) Change `model_name_or_path` and `template` in [settings.json](settings.json) to use another locally downloaded model.
- Adjust `per_device_train_batch_size` and `gradient_accumulation_steps` to control VRAM usage.
- Tune `lora_rank`, `lora_dropout`, and similar parameters according to the quantity and quality of your dataset.

#### Single-GPU training

```bash
python weclone/train/train_sft.py
```

#### Multi-GPU training

Uncomment the `deepspeed` line in `settings.json`, then train on multiple GPUs with:

```bash
uv pip install deepspeed
deepspeed --num_gpus=<number of GPUs> weclone/train/train_sft.py
```

### Simple inference with the browser demo

Use this step to find suitable `temperature` and `top_p` values, then write them into `infer_args` in settings.json for later inference.

```bash
python weclone/eval/web_demo.py
```

### Inference via the API

```bash
python weclone/server/api_service.py
```

(A sample request against this service is sketched below, after the fine-tuning results.)

### Test with common chat questions

Some answers come out rather vague, mainly because the training data did not cover them; this will later be addressed with RAG. Test results are written to test_result-my.txt.

```bash
python weclone/server/api_service.py
python weclone/eval/test_model.py
```

### Fine-tuning results

With the Qwen2.5-14B-Instruct model and roughly 30k processed samples, the loss dropped to around 3.5.
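As a quick sanity check of the API service mentioned above, a request like the following should work. This is a hedged sketch: it assumes `api_service.py` serves an OpenAI-compatible `/v1/chat/completions` endpoint on port 8005 (the port and model name come from the AstrBot deployment step below); the host and message contents are illustrative, and the sampling values mirror `infer_args` in settings.json:

```bash
# Hedged sketch: host/port and message contents are assumptions, not guaranteed defaults.
curl http://127.0.0.1:8005/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer any-key" \
  -d '{
    "model": "gpt-3.5-turbo",
    "messages": [
      {"role": "system", "content": "请你扮演一名人类,不要说自己是人工智能"},
      {"role": "user", "content": "吃了吗?"}
    ],
    "temperature": 0.5,
    "top_p": 0.65
  }'
```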
Screenshots *(four result screenshots omitted here)*
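For reference, based on the column mapping declared in `dataset/res_csv/sft/dataset_info.json`, a single record in the generated `sft-my.json` presumably looks like the sketch below. The message values are invented for illustration, and the `history` field only applies when the with-history dataset variant is used:

```json
// Hypothetical example record (values invented; structure inferred from dataset_info.json)
{
  "instruction": "吃了吗?",
  "output": "刚吃完,你呢",
  "system": "请你扮演一名人类,不要说自己是人工智能",
  "history": []
}
```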
### Deploy to a chatbot

[AstrBot](https://github.com/AstrBotDevs/AstrBot) is an easy-to-use multi-platform LLM chatbot and development framework ✨, supporting QQ, QQ Channels, Telegram, WeChat, WeCom, and Feishu.

Steps:
1. Deploy AstrBot.
2. Deploy a messaging platform in AstrBot.
3. Run `python weclone/server/api_service.py` to start the API service.
4. Add a service provider in AstrBot: choose OpenAI as the type, set the API Base URL according to how AstrBot is deployed (e.g. http://172.17.0.1:8005/v1 for a Docker deployment), set the model to gpt-3.5-turbo, and fill in any value as the API key.
5. Tool calls are not supported after fine-tuning. Disable the default tools first by sending `/tool off all` on the messaging platform; otherwise the fine-tuned behavior will not come through.
6. Set the system prompt in AstrBot to the `default_system` used during fine-tuning.

![5](https://github.com/user-attachments/assets/19de7072-076a-4cdf-8ae6-46b9b89f536a)

> [!IMPORTANT]
> Check the api_service logs and keep the request parameters as close as possible to those used during fine-tuning, with all tool plugin capabilities disabled.

7. Adjust sampling parameters such as temperature, top_p, and top_k: [configure custom model parameters](https://astrbot.app/config/model-config.html#%E9%85%8D%E7%BD%AE%E8%87%AA%E5%AE%9A%E4%B9%89%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%8F%82%E6%95%B0)

### Troubleshooting

- Fine-tuning issues: [LLaMA-Factory | FAQs | common questions](https://github.com/hiyouga/LLaMA-Factory/issues/4614)

### ❤️ Contributing

Issues and pull requests are welcome! You can contribute by browsing the Issues or helping review PRs. For new features, please open an Issue to discuss first.

Run `uv pip install --group dev -e .` to install the development dependencies.

The project uses `pytest` for testing, `pyright` for type checking, and `ruff` for linting and formatting.

### Disclaimer

> [!CAUTION]
> Do not use this project for illegal purposes; you bear all consequences otherwise.
1. Purpose of use
   * This project is for learning and exchange only. **Do not use it for illegal purposes**; you bear all consequences otherwise.
   * The user understands and agrees that any act violating laws and regulations or infringing on the lawful rights of others is unrelated to this project and its developers, and the consequences are borne by the user alone.
2. Period of use
   * You should delete the source code and program of this project within 24 hours of downloading and saving it; any use beyond that period is unrelated to this project and its developers.
3. Operating rules
   * This project only permits training on data you are authorized to use; use for illegal purposes is strictly prohibited, and you bear all resulting responsibility yourself. Any legal liability arising from a user's violation of this rule rests with the user alone and is unrelated to this project and its developers.
   * Using this project to steal others' privacy is strictly prohibited; you bear all resulting responsibility yourself.
4. Acceptance of this disclaimer
   * By downloading, saving, or further browsing the source code, or by downloading, installing, compiling, and using this program, you indicate that you agree to this warning and promise to abide by it.
5. No illegal testing or penetration
   * Using the techniques in this project for illegal testing or penetration, or using its code or techniques for any illegal work, is prohibited; any adverse consequences arising therefrom are unrelated to this project and its developers.
   * Any such adverse consequences, including but not limited to data leaks, system failures, and privacy violations, are unrelated to this project and its developers and are the user's own responsibility.
6. Changes to this disclaimer
   * This disclaimer may be revised and adjusted according to the project's operation and changes in laws and regulations. Users should check this page regularly for the latest version and abide by it when using this project.
7. Miscellaneous
   * Beyond this disclaimer, users must observe the relevant laws, regulations, and ethical norms when using this project. This project and its developers accept no responsibility for any dispute or loss arising from a user's violation of the relevant rules.
   * Please read and understand this disclaimer carefully, and make sure to strictly observe the relevant rules when using this project.


### ⭐ Star History

> [!TIP]
> If this project helps you, or if you care about its future development, please give it a Star. Thank you!
[![Star History Chart](https://api.star-history.com/svg?repos=xming521/WeClone&type=Date)](https://www.star-history.com/#xming521/WeClone&Date)
Clone us, and keep that fragrance of the soul.
## /dataset/blocked_words.json ```json path="/dataset/blocked_words.json" { "blocked_words": [ "例如 姓名", "例如 地址", "//....." ] } ``` ## /dataset/res_csv/pt/dataset_info.json ```json path="/dataset/res_csv/pt/dataset_info.json" {"wechat-pt":{ "file_name": "./pt-my.json", "columns": { "prompt": "c" } }} ``` ## /dataset/res_csv/sft/dataset_info.json ```json path="/dataset/res_csv/sft/dataset_info.json" { "wechat-sft": { "file_name": "./sft-my.json", "columns": { "prompt": "instruction", "response": "output", "system": "system" } }, "wechat-sft-with-history": { "file_name": "./sft-my.json", "columns": { "prompt": "instruction", "response": "output", "system": "system", "history": "history" } } } ``` ## /dataset/test_data.json ```json path="/dataset/test_data.json" { "questions": [ [ "吃了吗?", "吃的什么啊", "好吃吗", "多少钱啊", "可以请我吃吗" ], [ "你多大了?" ], [ "你有什么爱好吗?" ], [ "你的理想是什么?", "你觉得你离你的理想还有多远?" ], [ "你最近在忙什么?", "工作/学习顺利吗?", "有什么有趣的事情发生吗?" ], [ "你喜欢看什么类型的电影?", "最近看过什么好看的电影吗?", "你最喜欢的电影是什么?" ], [ "你平时喜欢听什么音乐?", "有推荐的歌手或乐队吗?", "最近有喜欢的歌曲吗?" ], [ "你喜欢旅游吗?", "去过哪些地方?", "最喜欢的旅游地是哪里?" ], [ "你喜欢读书吗?", "最近在读什么书?", "最喜欢的书是哪本?" ], [ "你平时喜欢运动吗?", "喜欢做哪些运动?", "有固定去锻炼吗?" ], [ "周末一般都做些什么?", "有没有什么特别的计划?", "周末喜欢宅在家还是出去玩?" ], [ "你喜欢宠物吗?", "有养宠物吗?", "最喜欢什么动物?" ], [ "你喜欢吃什么类型的食物?", "有推荐的餐厅吗?", "最喜欢的菜是什么?" ], [ "你喜欢什么样的天气?", "最喜欢的季节是哪一个?", "你觉得今天的天气怎么样?" ], [ "你有看电视剧的习惯吗?", "最近在追哪部剧?", "最喜欢的电视剧是哪部?" ], [ "你喜欢玩游戏吗?", "最近在玩什么游戏?", "有推荐的好玩的游戏吗?" ], [ "你会做饭吗?", "平时喜欢做哪些菜?", "有没有特别拿手的菜?" ], [ "你喜欢购物吗?", "最近买了什么新东西?", "有推荐的购物网站或店铺吗?" ], [ "你平时怎么放松自己?", "有特别的解压方式吗?", "最喜欢的放松活动是什么?" ], [ "你喜欢和朋友出去玩吗?", "平时会和朋友去哪玩?", "最近有没有和朋友聚会的计划?" ], [ "你喜欢喝咖啡还是茶?", "有没有特别喜欢的咖啡馆或茶馆?", "最喜欢的饮品是什么?" ], [ "你有兄弟姐妹吗?", "和他们关系怎么样?", "经常联系吗?" ], [ "你喜欢读什么类型的杂志?", "最近有看什么有趣的文章吗?", "有订阅的杂志吗?" ], [ "你喜欢看体育比赛吗?", "最喜欢的运动项目是什么?", "有没有特别支持的球队或运动员?" ], [ "你会说其他语言吗?", "最想学的语言是什么?", "学习语言有什么技巧吗?" ], [ "你对科技产品感兴趣吗?", "最近有没有关注什么新科技?", "最喜欢的电子产品是什么?" ], [ "你喜欢喝什么样的饮料?", "有没有自己调饮料的习惯?", "最喜欢的饮品品牌是什么?" ], [ "你平时用社交媒体吗?", "常用哪些平台?", "在社交媒体上做什么?" ], [ "你对艺术感兴趣吗?", "最喜欢的艺术家是谁?", "有去过哪些艺术展览?" ], [ "你喜欢DIY吗?", "平时做些什么手工?", "有没有完成的作品可以分享?" ], [ "你喜欢种植植物吗?", "有养什么植物?", "最喜欢的植物是什么?" ], [ "你喜欢拍照吗?", "喜欢拍什么样的照片?", "有没有用什么特别的摄影设备?" ], [ "你喜欢听播客吗?", "常听哪些主题的播客?", "有没有推荐的播客?" ], [ "你对历史感兴趣吗?", "最喜欢哪个历史时期?", "有没有特别喜欢的历史人物?" ], [ "你喜欢画画吗?", "平时画什么类型的画?", "有参加过画展吗?" ], [ "你喜欢写作吗?", "平时写什么类型的文章?", "有没有发表过作品?" ], [ "你喜欢钓鱼吗?", "平时去哪里钓鱼?", "有没有钓到过什么大鱼?" ], [ "你喜欢露营吗?", "平时会去哪里露营?", "有没有什么难忘的露营经历?" ], [ "你喜欢摄影吗?", "最喜欢拍什么题材?", "有没有特别喜欢的摄影师?" ], [ "你喜欢喝酒吗?", "喜欢什么类型的酒?", "有没有推荐的酒吧或品牌?" ], [ "你喜欢滑雪吗?", "平时去哪里滑雪?", "有没有什么滑雪技巧分享?" ], [ "你喜欢海边还是山里?", "最喜欢去哪个地方度假?", "有没有什么特别推荐的景点?" ], [ "你喜欢参加音乐节吗?", "参加过哪些音乐节?", "最喜欢的音乐节是哪一个?" ], [ "你喜欢跑步吗?", "平时跑多长距离?", "有没有参加过马拉松?" ], [ "你喜欢参加聚会吗?", "平时和朋友聚会做什么?", "有没有什么有趣的聚会游戏?" ], [ "你喜欢收集东西吗?", "收集什么类型的物品?", "有没有什么特别的收藏?" 
] ] } ``` ## /ds_config.json ```json path="/ds_config.json" { "fp16": { "enabled": "auto", "loss_scale": 0, "loss_scale_window": 1000, "initial_scale_power": 16, "hysteresis": 2, "min_loss_scale": 1 }, "bf16": { "enabled": "auto" }, "zero_optimization": { "stage": 2, "allgather_partitions": true, "allgather_bucket_size": 5e8, "overlap_comm": true, "reduce_scatter": true, "reduce_bucket_size": 5e8, "contiguous_gradients": true }, "gradient_accumulation_steps": "auto", "gradient_clipping": "auto", "steps_per_print": 2000, "train_batch_size": "auto", "train_micro_batch_size_per_gpu": "auto", "wall_clock_breakdown": false } ``` ## /pyproject.toml ```toml path="/pyproject.toml" [project] name = "WeClone" version = "0.2.0" description = "从聊天记录创造数字分身的一站式解决方案" authors = [{ name = "xming521" }] readme = "README.md" requires-python = ">=3.10,<3.11" dependencies = [ "pandas", "commentjson", "pydantic==2.10.6", "setuptools>=78.1.0", "loguru>=0.7.3", "torch>=2.5.1", "transformers==4.49.0", ] [dependency-groups] # xcodec = ["xcodec2==0.1.3"] sparktts = [ "einops>=0.8.1", "einx>=0.3.0", "numpy==1.26.4", "omegaconf>=2.3.0", "packaging>=24.2", "safetensors>=0.5.2", "soundfile>=0.12.1", "soxr>=0.5.0.post1", "torchaudio>=2.5.1", "tqdm>=4.66.5", ] main = ["llamafactory>=0.9.2", "openai==0.28.0"] dev = ["pytest", "pyright", "ruff"] [tool.uv] conflicts = [ # [{ group = "wx" }, { group = "xcodec" }], ] [tool.uv.sources] torch = [ { index = "pytorch-cu121", marker = "platform_system == 'Windows'" }, { index = "pytorch-cu121", marker = "platform_system == 'Linux'" }, ] torchaudio = [ { index = "pytorch-cu121", marker = "platform_system == 'Windows'" }, { index = "pytorch-cu121", marker = "platform_system == 'Linux'" }, ] torchvision = [ { index = "pytorch-cu121", marker = "platform_system == 'Windows'" }, { index = "pytorch-cu121", marker = "platform_system == 'Linux'" }, ] [[tool.uv.index]] url = "https://pypi.tuna.tsinghua.edu.cn/simple/" default = true [[tool.uv.index]] name = "pytorch-cu121" url = "https://download.pytorch.org/whl/cu121" explicit = true [tool.setuptools.packages.find] where = ["."] # 表示在项目根目录开始查找 include = ["weclone*"] # 只包含名为 weclone 的目录及其子包 exclude = ["*tests*", "*archive*"] # 可以选择性排除其他模式,比如测试目录 [tool.pyright] typeCheckingMode = "basic" include = ["weclone/data"] exclude = ["**/archive", "**/tests"] ignore = ["**/archive"] reportMissingImports = "error" reportMissingTypeStubs = false pythonVersion = "3.10" pythonPlatform = "Linux" [tool.ruff] exclude = [ "**/archive", "**/tests", "weclone-audio/src/server未完工", "weclone-audio/src/Spark-TTS", ] line-length = 120 lint.ignore = ["F403", "F405", "E501", "E402"] lint.select = [ "F", # Pyflakes "W", # pycodestyle warnings "E", # pycodestyle errors "ASYNC", # flake8-async "C4", # flake8-comprehensions "Q", # flake8-quotes ] target-version = "py310" ``` ## /settings.json ```json path="/settings.json" { "train_pt_args": { "stage": "pt", "dataset": "wechat-pt", "dataset_dir": "./dataset/res_csv/pt", "lora_target": "q_proj,v_proj", "lora_rank": 2, "lora_dropout": 0.1, "output_dir": "model_output", "overwrite_cache": true, "per_device_train_batch_size": 1, "gradient_accumulation_steps": 1, "lr_scheduler_type": "cosine", "logging_steps": 10, "save_steps": 1000, "learning_rate": 0.001, "num_train_epochs": 30, "plot_loss": true, "fp16": true }, "train_sft_args": { "stage": "sft", "dataset": "wechat-sft", "dataset_dir": "./dataset/res_csv/sft", "use_fast_tokenizer": true, "lora_target": "q_proj,v_proj", "lora_rank": 4, "lora_dropout": 0.4, "weight_decay": 
0.1, "overwrite_cache": true, "per_device_train_batch_size": 8, "gradient_accumulation_steps": 4, "lr_scheduler_type": "cosine", "cutoff_len": 256, "logging_steps": 10, "save_steps": 100, "learning_rate": 1e-4, "warmup_ratio": 0.1, "num_train_epochs": 3, "plot_loss": true, "fp16": true, "flash_attn": "fa2", // "deepspeed": "ds_config.json" //多卡训练 }, "infer_args": { "repetition_penalty": 1.2, "temperature": 0.5, "max_length": 50, "top_p": 0.65 }, "make_dataset_args": { // "enable_vision_model": false,//后续实现 // "include_type": [ // "文本" // ], "single_combine_strategy": "time_window", // 单人组成单句策略 "qa_match_strategy": "time_window", // 组成qa策略 "single_combine_time_window": 2, // 单人组成单句时间窗口(分钟), "qa_match_time_window": 5, // 组成qa时间窗口(分钟), "combine_msg_max_length": 256, // 组合后消息最大长度 配合cutoff_len 使用 "prompt_with_history": false // 是否在prompt中包含历史对话 }, "common_args": { "model_name_or_path": "./Qwen2.5-7B-Instruct", "adapter_name_or_path": "./model_output", //同时做为train_sft_args的output_dir "template": "qwen", "default_system": "请你扮演一名人类,不要说自己是人工智能", "finetuning_type": "lora", "trust_remote_code": true } } ``` ## /tests/README.md # WEClone 测试指南 本目录包含WEClone项目的测试文件,用于确保项目各个组件正常工作。 ## 测试文件说明 - `test_weclone_pipeline.py`: 全流程测试,按顺序测试数据生成、训练、API服务和模型评估 - `test_qa_generator.py`: 测试QA生成器功能 ## 运行全流程测试 要运行完整的测试流程,请执行以下命令: ```bash # 在项目根目录下执行 python -m tests.test_weclone_pipeline ``` ## 测试流程说明 全流程测试按照以下顺序测试项目的主要组件: 1. **数据生成**:测试 `weclone/data/qa_generator.py` 模块,模拟微信聊天记录的处理和QA对的生成 2. **模型训练**:测试 `weclone/train/train_sft.py` 模块,模拟使用生成的数据进行模型的SFT训练 3. **API服务**:测试 `weclone/server/api_service.py` 模块,模拟启动API服务 4. **模型评估**:测试 `weclone/eval/test_model.py` 模块,模拟对训练后的模型进行评估 ## 注意事项 - 测试使用Python的unittest框架和mock库,模拟各个组件的运行环境和依赖 - 测试不会修改实际的数据文件或模型文件,所有操作都在临时目录中进行 - 要运行单独的测试方法,可以使用以下命令: ```bash # 例如,只运行QA生成器测试 python -m unittest tests.test_weclone_pipeline.TestWeclonePipeline.test_qa_generator ``` ## /tests/test_full_pipeline.py ```py path="/tests/test_full_pipeline.py" import subprocess import sys import os import time import shutil import threading # 导入 threading from typing import Optional, Union, IO # 导入 IO import torch from loguru import logger from subprocess import Popen #TODO 放弃了改成测cli吧 # 配置 Loguru logger.remove() # 移除默认处理器 current_time = time.strftime('%Y%m%d_%H%M%S') log_file_path = os.path.join(os.path.dirname(__file__), f"pipeline_test_{current_time}.log") # 日志文件名包含执行时间 logger.add(log_file_path, rotation="10 MB", encoding='utf-8', level="DEBUG", enqueue=True) # 文件记录 DEBUG 级别 logger.add(sys.stdout, colorize=True, format="[test] {time:YYYY-MM-DD HH:mm:ss} | {level.name[0]} | {message}", level="INFO", enqueue=True) # 控制台保持 INFO 级别 project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) logger.info(f"项目根目录: {project_root}") qa_script = "weclone/data/qa_generator.py" train_script = "weclone/train/train_sft.py" api_service_script = "weclone/server/api_service.py" eval_script = "weclone/eval/test_model.py" web_demo_script = "weclone/eval/web_demo.py" DEFAULT_TIMEOUT: Optional[Union[int, float]] = 45 API_STARTUP_WAIT = 20 API_TERMINATE_WAIT = 15 WEB_DEMO_STARTUP_WAIT = 20 WEB_DEMO_TERMINATE_WAIT = 15 STEP_QA = "QA 数据生成" STEP_TRAIN = "SFT 训练" STEP_COPY_CKPT = "Checkpoint 复制" STEP_API_START = "API 服务启动" STEP_EVAL = "模型评估" STEP_WEB_DEMO = "Web Demo 启动" # Mapping from identifiers (script paths or custom keys) to step names step_identifiers = { qa_script: STEP_QA, train_script: STEP_TRAIN, "copy_checkpoint": STEP_COPY_CKPT, # Custom key for non-script step api_service_script: STEP_API_START, # Script 
associated with starting API eval_script: STEP_EVAL, web_demo_script: STEP_WEB_DEMO, # Script associated with starting Web Demo } # Order for fallback logic step_order = [STEP_QA, STEP_TRAIN, STEP_COPY_CKPT, STEP_API_START, STEP_EVAL, STEP_WEB_DEMO] #todo 需要测试前替换成测试的settings.json 测试完再替换回来 class PipelineStepError(Exception): """自定义异常类,用于表示 Pipeline 步骤执行失败。""" pass # --- 辅助函数:用于在线程中读取和记录流 --- def log_stream(stream: Optional[IO[str]], log_func): """读取流并使用指定的 log 函数记录每一行。""" if stream is None: return try: for line in iter(stream.readline, ''): if line: log_func(line.strip()) # 去除末尾换行符 except ValueError: # 当 Popen 的 stream 在另一线程中被关闭时,readline 可能会抛出 ValueError logger.warning("日志流在读取时似乎已被关闭。") except Exception as e: # 捕获其他潜在的读取错误 logger.warning(f"日志流读取时发生未预料的错误: {e}") finally: if stream: try: stream.close() # 确保流被关闭 except Exception as close_e: logger.warning(f"关闭日志流时发生错误: {close_e}") # --- 新增:启动日志流线程的辅助函数 --- def _start_stream_logging_threads(process: Popen, stdout_log_func=logger.info, stderr_log_func=logger.error) -> tuple[threading.Thread, threading.Thread]: """为给定的进程启动 stdout 和 stderr 的日志记录线程。""" stdout_thread = threading.Thread( target=log_stream, args=(process.stdout, stdout_log_func), daemon=True ) stderr_thread = threading.Thread( target=log_stream, args=(process.stderr, stderr_log_func), daemon=True ) stdout_thread.start() stderr_thread.start() return stdout_thread, stderr_thread def run_script(script_relative_path: str, timeout: Optional[Union[int, float]] = DEFAULT_TIMEOUT, ignore_timeout_error: bool = False, env: Optional[dict] = None): """使用 Popen 执行脚本,通过线程实时记录 stdout/stderr 到 loguru。""" script_full_path = os.path.join(project_root, script_relative_path) timeout_str = '无限制' if timeout is None else f'{timeout}s' env_str = f" (环境变量: {env})" if env else "" logger.info(f"--- 开始执行 (流式): {script_relative_path} (超时: {timeout_str}){env_str} ---") if not os.path.exists(script_full_path): error_msg = f"脚本文件不存在 {script_full_path}" logger.error(error_msg) raise PipelineStepError(error_msg) process: Optional[Popen] = None stdout_thread: Optional[threading.Thread] = None stderr_thread: Optional[threading.Thread] = None # 准备环境变量 run_env = os.environ.copy() if env: run_env.update(env) try: process = Popen( [sys.executable, script_full_path], cwd=project_root, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, encoding='utf-8', bufsize=1, # 行缓冲 env=run_env # 传递环境变量 ) # 使用辅助函数启动日志线程 stdout_thread, stderr_thread = _start_stream_logging_threads(process, logger.debug, logger.debug) # stdout/stderr 都用 debug # 等待子进程完成或超时 try: return_code = process.wait(timeout=timeout) except subprocess.TimeoutExpired: warn_msg = f"{script_relative_path} 执行超时 ({timeout}s)。" logger.warning(warn_msg) # 尝试优雅地关闭流(可能已被 log_stream 关闭) if process.stdout: process.stdout.close() if process.stderr: process.stderr.close() process.kill() # 强制终止超时进程 logger.warning(f"已强制终止进程 {process.pid}") # 等待 I/O 线程完成(即使进程被 kill,也要尝试读取剩余输出) if stdout_thread: stdout_thread.join(timeout=5) if stderr_thread: stderr_thread.join(timeout=5) if not ignore_timeout_error: error_msg = f"{script_relative_path} 执行超时 ({timeout}s) 且未忽略。" logger.error(error_msg) raise PipelineStepError(error_msg) else: logger.info("--- 根据设置,超时不视为错误,继续执行后续步骤。 ---") return # 忽略超时,函数正常返回 # 等待日志线程完成(确保所有输出都被记录) if stdout_thread: stdout_thread.join() if stderr_thread: stderr_thread.join() # 检查返回码 if return_code != 0: error_msg = f"{script_relative_path} 执行失败,返回码 {return_code}" logger.error(error_msg) raise PipelineStepError(error_msg) else: logger.success(f"--- 
{script_relative_path} 执行成功 ---") except FileNotFoundError: error_msg = f"Python 解释器 '{sys.executable}' 或脚本 '{script_full_path}' 未找到。" logger.error(error_msg) raise PipelineStepError(error_msg) except Exception as e: # 捕获其他潜在错误 (例如 Popen 本身失败) error_msg = f"执行 {script_relative_path} 时发生意外错误: {e}" logger.error(error_msg) # 尝试确保进程和线程被清理 if process and process.poll() is None: try: if process.stdout: process.stdout.close() if process.stderr: process.stderr.close() process.kill() logger.warning(f"因异常 {e},强制终止进程 {process.pid}") except Exception as kill_e: logger.error(f"清理过程中强制终止进程失败: {kill_e}") if stdout_thread and stdout_thread.is_alive(): stdout_thread.join(timeout=1) if stderr_thread and stderr_thread.is_alive(): stderr_thread.join(timeout=1) raise PipelineStepError(error_msg) def start_api_service_background() -> Popen: """在后台启动 API 服务脚本,实时记录启动日志,失败时抛出 PipelineStepError。""" script_full_path = os.path.join(project_root, api_service_script) logger.info(f"--- 尝试在后台启动: {api_service_script} ---") if not os.path.exists(script_full_path): error_msg = f"脚本文件不存在 {script_full_path}" logger.error(error_msg) raise PipelineStepError(error_msg) process: Optional[Popen] = None stdout_thread: Optional[threading.Thread] = None stderr_thread: Optional[threading.Thread] = None try: logger.info(f"启动命令: {[sys.executable, script_full_path]}") process = Popen( [sys.executable, script_full_path], cwd=project_root, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, encoding='utf-8', bufsize=1 # 行缓冲 ) # 使用辅助函数启动日志线程 stdout_thread, stderr_thread = _start_stream_logging_threads(process, logger.debug, logger.debug) # stdout/stderr 都用 debug logger.info(f"等待 {API_STARTUP_WAIT} 秒让服务初步启动 (日志将实时显示)...") time.sleep(API_STARTUP_WAIT) # 检查进程是否仍在运行 if process.poll() is None: logger.success(f"--- {api_service_script} 似乎已在后台启动 (进程 PID: {process.pid}) ---") # 注意:不 join 日志线程,让它们继续运行 return process else: # 进程过早退出 logger.error(f"{api_service_script} 启动后在 {API_STARTUP_WAIT} 秒内过早退出,返回码 {process.returncode}") # 尝试等待日志线程结束以捕获最后输出 if stdout_thread: stdout_thread.join(timeout=2) if stderr_thread: stderr_thread.join(timeout=2) # 读取 communicate 获取可能遗漏的最终输出 (虽然理论上线程应该读完了) try: # 设置短超时,因为进程已退出,communicate 应该立即返回 stdout, stderr = process.communicate(timeout=1) except subprocess.TimeoutExpired: logger.warning("等待 communicate 超时,可能没有更多输出了。") stdout, stderr = "", "" # 假设没有更多输出 except Exception as comm_e: logger.warning(f"调用 communicate 获取最后输出时出错: {comm_e}") stdout, stderr = "", "" error_message = f'''--- EARLY EXIT STDOUT --- {stdout} --- EARLY EXIT STDERR --- {stderr}''' logger.error(error_message) raise PipelineStepError(f"{api_service_script} 启动失败并过早退出。") except FileNotFoundError: error_msg = f"Python 解释器 '{sys.executable}' 或脚本 '{script_full_path}' 未找到。" logger.error(error_msg) raise PipelineStepError(error_msg) except Exception as e: # 捕获其他启动错误 error_msg = f"启动 {api_service_script} 时发生意外错误: {e}" logger.error(error_msg) if process and process.poll() is None: logger.warning("捕获到异常,尝试强制终止进程...") try: if process.stdout: process.stdout.close() if process.stderr: process.stderr.close() process.kill() except Exception as kill_e: logger.error(f"强制终止进程时出错: {kill_e}") # 尝试join线程 if stdout_thread and stdout_thread.is_alive(): stdout_thread.join(timeout=1) if stderr_thread and stderr_thread.is_alive(): stderr_thread.join(timeout=1) raise PipelineStepError(error_msg) def stop_api_service(process: Optional[Popen]): """停止指定的 API 服务进程,采用更健壮的终止和清理逻辑。""" if process and process.poll() is None: pid = process.pid # Get PID for logging logger.info(f"--- 
尝试停止 API 服务 (PID: {pid}) ---") try: logger.info(f"发送 SIGTERM 信号到进程 {pid}...") process.terminate() try: logger.info(f"等待最多 {API_TERMINATE_WAIT} 秒让进程 {pid} 优雅终止...") process.wait(timeout=API_TERMINATE_WAIT) logger.info(f"API 服务进程 {pid} 已优雅终止,返回码: {process.returncode}") # 进程已终止,尝试获取最终输出 try: stdout, stderr = process.communicate(timeout=2) if stdout: logger.debug(f"进程 {pid} 最终 STDOUT:\n{stdout.strip()}") if stderr: logger.debug(f"进程 {pid} 最终 STDERR:\n{stderr.strip()}") except subprocess.TimeoutExpired: logger.warning(f"获取进程 {pid} 最终输出时超时。") except Exception as comm_e: logger.warning(f"获取进程 {pid} 最终输出时出错: {comm_e}") except subprocess.TimeoutExpired: logger.warning(f"进程 {pid} 优雅终止超时 ({API_TERMINATE_WAIT}s),发送 SIGKILL 信号...") process.kill() logger.info(f"等待进程 {pid} 被强制终止...") # 在 kill 后等待,应该很快返回。增加安全超时。 try: process.wait(timeout=5) logger.info(f"API 服务进程 {pid} 已被强制终止。") except subprocess.TimeoutExpired: logger.error(f"进程 {pid} 在发送 SIGKILL 后仍然没有终止!") except Exception as wait_kill_e: logger.error(f"等待强制终止进程 {pid} 时发生错误: {wait_kill_e}") # 尝试在 kill 后获取输出 try: # 在 kill 后也使用 communicate,它隐式处理等待 stdout, stderr = process.communicate(timeout=2) if stdout: logger.warning(f"来自进程 {pid} 的 Kill 后输出 (STDOUT):\n{stdout.strip()}") if stderr: logger.warning(f"来自进程 {pid} 的 Kill 后输出 (STDERR):\n{stderr.strip()}") except Exception as comm_e: logger.warning(f"获取进程 {pid} (强制终止后) 输出时出错: {comm_e}") except Exception as e: logger.error(f"停止 API 服务 (PID: {pid if process else '未知'}) 时发生意外错误: {e}") # 如果进程仍然存活,尝试最后一次强制 kill if process and process.poll() is None: logger.warning(f"最终尝试强制终止进程 {pid}...") try: process.kill() process.wait(timeout=5) except Exception as final_kill_e: logger.error(f"最终强制终止进程 {pid} 时出错: {final_kill_e}") elif process: logger.info(f"--- API 服务进程 (PID: {process.pid}) 在尝试停止前已经退出。 ---") else: logger.debug("--- 无需停止 API 服务 (进程不存在或已为 None) ---") def start_web_demo_background() -> Popen: """在后台启动 Web Demo 脚本,实时记录启动日志,失败时抛出 PipelineStepError。""" script_full_path = os.path.join(project_root, web_demo_script) logger.info(f"--- 尝试在后台启动: {web_demo_script} ---") if not os.path.exists(script_full_path): error_msg = f"脚本文件不存在 {script_full_path}" logger.error(error_msg) raise PipelineStepError(error_msg) process: Optional[Popen] = None stdout_thread: Optional[threading.Thread] = None stderr_thread: Optional[threading.Thread] = None try: logger.info(f"启动命令: {[sys.executable, script_full_path]}") process = Popen( [sys.executable, script_full_path], cwd=project_root, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, encoding='utf-8', bufsize=1 # 行缓冲 ) # 使用辅助函数启动日志线程 (stdout/stderr 都用 info) stdout_thread, stderr_thread = _start_stream_logging_threads(process, logger.debug, logger.debug) # stdout/stderr 都用 debug logger.info(f"等待 {WEB_DEMO_STARTUP_WAIT} 秒让 Web Demo 初步启动 (日志将实时显示)...") time.sleep(WEB_DEMO_STARTUP_WAIT) # 检查进程是否仍在运行 if process.poll() is None: logger.success(f"--- {web_demo_script} 似乎已在后台启动 (进程 PID: {process.pid}) ---") # 注意:不 join 日志线程 return process else: # 进程过早退出 logger.error(f"{web_demo_script} 启动后在 {WEB_DEMO_STARTUP_WAIT} 秒内过早退出,返回码 {process.returncode}") if stdout_thread: stdout_thread.join(timeout=2) if stderr_thread: stderr_thread.join(timeout=2) try: stdout, stderr = process.communicate(timeout=1) except subprocess.TimeoutExpired: logger.warning("等待 communicate 超时,可能没有更多输出了。") stdout, stderr = "", "" except Exception as comm_e: logger.warning(f"调用 communicate 获取最后输出时出错: {comm_e}") stdout, stderr = "", "" error_message = f'''--- EARLY EXIT STDOUT --- {stdout} --- EARLY EXIT STDERR --- 
{stderr}''' logger.error(error_message) raise PipelineStepError(f"{web_demo_script} 启动失败并过早退出。") except FileNotFoundError: error_msg = f"Python 解释器 '{sys.executable}' 或脚本 '{script_full_path}' 未找到。" logger.error(error_msg) raise PipelineStepError(error_msg) except Exception as e: error_msg = f"启动 {web_demo_script} 时发生意外错误: {e}" logger.error(error_msg) if process and process.poll() is None: logger.warning("捕获到异常,尝试强制终止进程...") try: if process.stdout: process.stdout.close() if process.stderr: process.stderr.close() process.kill() except Exception as kill_e: logger.error(f"强制终止进程时出错: {kill_e}") if stdout_thread and stdout_thread.is_alive(): stdout_thread.join(timeout=1) if stderr_thread and stderr_thread.is_alive(): stderr_thread.join(timeout=1) raise PipelineStepError(error_msg) def stop_web_demo(process: Optional[Popen]): """停止指定的 Web Demo 进程,采用更健壮的终止和清理逻辑。""" if process and process.poll() is None: pid = process.pid # Get PID for logging logger.info(f"--- 尝试停止 Web Demo 服务 (PID: {pid}) ---") try: logger.info(f"发送 SIGTERM 信号到进程 {pid}...") process.terminate() try: logger.info(f"等待最多 {WEB_DEMO_TERMINATE_WAIT} 秒让进程 {pid} 优雅终止...") process.wait(timeout=WEB_DEMO_TERMINATE_WAIT) logger.info(f"Web Demo 服务进程 {pid} 已优雅终止,返回码: {process.returncode}") # 进程已终止,尝试获取最终输出 try: stdout, stderr = process.communicate(timeout=2) if stdout: logger.debug(f"进程 {pid} 最终 STDOUT:\n{stdout.strip()}") if stderr: logger.debug(f"进程 {pid} 最终 STDERR:\n{stderr.strip()}") except subprocess.TimeoutExpired: logger.warning(f"获取进程 {pid} 最终输出时超时。") except Exception as comm_e: logger.warning(f"获取进程 {pid} 最终输出时出错: {comm_e}") except subprocess.TimeoutExpired: logger.warning(f"进程 {pid} 优雅终止超时 ({WEB_DEMO_TERMINATE_WAIT}s),发送 SIGKILL 信号...") process.kill() logger.info(f"等待进程 {pid} 被强制终止...") try: process.wait(timeout=5) logger.info(f"Web Demo 服务进程 {pid} 已被强制终止。") except subprocess.TimeoutExpired: logger.error(f"进程 {pid} 在发送 SIGKILL 后仍然没有终止!") except Exception as wait_kill_e: logger.error(f"等待强制终止进程 {pid} 时发生错误: {wait_kill_e}") # 尝试在 kill 后获取输出 try: stdout, stderr = process.communicate(timeout=2) if stdout: logger.warning(f"来自进程 {pid} 的 Kill 后输出 (STDOUT):\n{stdout.strip()}") if stderr: logger.warning(f"来自进程 {pid} 的 Kill 后输出 (STDERR):\n{stderr.strip()}") except Exception as comm_e: logger.warning(f"获取进程 {pid} (强制终止后) 输出时出错: {comm_e}") except Exception as e: logger.error(f"停止 Web Demo 服务 (PID: {pid if process else '未知'}) 时发生意外错误: {e}") # 如果进程仍然存活,尝试最后一次强制 kill if process and process.poll() is None: logger.warning(f"最终尝试强制终止进程 {pid}...") try: process.kill() process.wait(timeout=5) except Exception as final_kill_e: logger.error(f"最终强制终止进程 {pid} 时出错: {final_kill_e}") elif process: logger.info(f"--- Web Demo 服务进程 (PID: {process.pid}) 在尝试停止前已经退出。 ---") else: logger.debug("--- 无需停止 Web Demo 服务 (进程不存在或已为 None) ---") # --- 新增:监控 Checkpoint 目录的函数 --- def monitor_checkpoints(process: Popen, model_output_dir: str, stop_event: threading.Event, check_interval: float = 5.0): """ 在后台线程中监控指定目录,如果发现 checkpoint* 目录,则尝试终止目标进程。 """ logger.info(f"[Monitor] 开始监控目录 {model_output_dir} 的 checkpoint...") while not stop_event.is_set(): if not os.path.isdir(model_output_dir): # 目录可能尚未创建,等待下一个间隔 time.sleep(check_interval) continue try: found_checkpoint = False for item in os.listdir(model_output_dir): item_path = os.path.join(model_output_dir, item) if os.path.isdir(item_path) and item.startswith("checkpoint"): logger.warning(f"[Monitor] 检测到 Checkpoint 目录: {item_path}。尝试停止训练进程 (PID: {process.pid})...") found_checkpoint = True break # 找到一个就足够了 if found_checkpoint: # 
发送终止信号 try: logger.info(f"[Monitor] 发送 SIGTERM 到进程 {process.pid}...") process.terminate() # 给进程一点时间响应 SIGTERM try: process.wait(timeout=5) logger.info(f"[Monitor] 进程 {process.pid} 已通过 SIGTERM 终止。") except subprocess.TimeoutExpired: logger.warning(f"[Monitor] 进程 {process.pid} 未在 5 秒内响应 SIGTERM,发送 SIGKILL...") process.kill() process.wait(timeout=5) # 等待 SIGKILL 生效 logger.info(f"[Monitor] 进程 {process.pid} 已通过 SIGKILL 终止。") except Exception as term_err: logger.error(f"[Monitor] 尝试终止进程 {process.pid} 时出错: {term_err}") finally: stop_event.set() # 通知主线程停止等待 logger.info("[Monitor] 已设置停止事件,监控结束。") return # 找到 checkpoint 并处理后,监控任务完成 except FileNotFoundError: # 目录可能在检查时被删除,忽略 pass except Exception as e: logger.error(f"[Monitor] 监控时发生错误: {e}") # 出现错误也设置停止信号,防止无限循环或未处理的异常 stop_event.set() return # 如果没有找到,且进程仍在运行,则等待下一个检查周期 if process.poll() is None: time.sleep(check_interval) else: # 如果进程已经结束(无论何种原因),监控也应结束 logger.info(f"[Monitor] 训练进程 {process.pid} 似乎已结束,停止监控。") break # 进程已结束,退出循环 logger.info("[Monitor] 监控循环正常结束。") # --- 新增:运行训练并进行监控的函数 --- def run_train_with_checkpoint_monitoring( script_relative_path: str, model_output_dir: str, timeout: Optional[Union[int, float]] = DEFAULT_TIMEOUT, ignore_timeout_error: bool = False, env: Optional[dict] = None ) -> str: """ 执行训练脚本,同时启动一个后台线程监控 checkpoint 目录。 如果检测到 checkpoint,会尝试停止训练进程。 返回执行状态: "success", "stopped_by_monitor", "timeout", "failed" """ script_full_path = os.path.join(project_root, script_relative_path) timeout_str = '无限制' if timeout is None else f'{timeout}s' env_str = f" (环境变量: {env})" if env else "" logger.info(f"--- 开始执行 (带监控): {script_relative_path} (超时: {timeout_str}){env_str} ---") if not os.path.exists(script_full_path): error_msg = f"脚本文件不存在 {script_full_path}" logger.error(error_msg) raise PipelineStepError(error_msg) process: Optional[Popen] = None stdout_thread: Optional[threading.Thread] = None stderr_thread: Optional[threading.Thread] = None monitor_thread: Optional[threading.Thread] = None stop_event = threading.Event() status = "failed" # 默认状态 # 准备环境变量 run_env = os.environ.copy() if env: run_env.update(env) try: process = Popen( [sys.executable, script_full_path], cwd=project_root, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, encoding='utf-8', bufsize=1, # 行缓冲 env=run_env ) # 启动日志线程 stdout_thread, stderr_thread = _start_stream_logging_threads(process, logger.debug, logger.debug) # 启动监控线程 monitor_thread = threading.Thread( target=monitor_checkpoints, args=(process, model_output_dir, stop_event), daemon=True ) monitor_thread.start() # 等待进程完成、被监控停止或超时 start_time = time.time() wait_interval = 1 # seconds to wait between checks while True: # 检查进程是否结束 return_code = process.poll() if return_code is not None: # 进程已结束 if stop_event.is_set(): # 如果是监控线程停止的 logger.warning(f"{script_relative_path} 被监控线程停止。返回码可能为 {return_code}。") status = "stopped_by_monitor" elif return_code == 0: logger.success(f"{script_relative_path} 成功完成。") status = "success" else: logger.error(f"{script_relative_path} 执行失败,返回码 {return_code}") status = "failed" stop_event.set() # 确保监控线程也会退出 break # 检查是否被监控线程要求停止 if stop_event.is_set(): logger.warning(f"{script_relative_path} 被监控线程标记为停止。") # 进程可能仍在运行,等待 monitor_checkpoints 中的终止逻辑生效 # 但我们这里也应该退出等待循环 status = "stopped_by_monitor" # 不需要再次 kill,monitor 线程会处理 break # 检查是否超时 if timeout is not None and (time.time() - start_time) > timeout: warn_msg = f"{script_relative_path} 执行超时 ({timeout}s)。" logger.warning(warn_msg) stop_event.set() # 通知监控线程停止 # 尝试优雅地关闭流 if process.stdout: process.stdout.close() if process.stderr: 
process.stderr.close() process.kill() # 强制终止超时进程 logger.warning(f"已强制终止进程 {process.pid}") status = "timeout" break # 短暂休眠后继续检查 time.sleep(wait_interval) # --- 等待所有线程完成 --- logger.info("等待日志和监控线程完成...") if stdout_thread: stdout_thread.join(timeout=5) if stderr_thread: stderr_thread.join(timeout=5) if monitor_thread: monitor_thread.join(timeout=5) # 监控线程也需要 join # 处理最终状态 if status == "failed": raise PipelineStepError(f"{script_relative_path} 执行失败。") elif status == "timeout": if not ignore_timeout_error: error_msg = f"{script_relative_path} 执行超时 ({timeout}s) 且未忽略。" logger.error(error_msg) raise PipelineStepError(error_msg) else: logger.info("--- 根据设置,超时不视为错误,继续执行后续步骤。 ---") # 即使忽略超时错误,状态仍然是 "timeout" return status # 返回 "timeout" 状态 # 对于 success 和 stopped_by_monitor,直接返回状态 logger.info(f"--- {script_relative_path} 执行结束,状态: {status} ---") return status except FileNotFoundError: error_msg = f"Python 解释器 '{sys.executable}' 或脚本 '{script_full_path}' 未找到。" logger.error(error_msg) raise PipelineStepError(error_msg) except Exception as e: # 捕获其他潜在错误 (例如 Popen 本身失败) error_msg = f"执行 {script_relative_path} (带监控) 时发生意外错误: {e}" logger.error(error_msg) stop_event.set() # 确保监控线程停止 # 尝试确保进程和线程被清理 if process and process.poll() is None: try: if process.stdout: process.stdout.close() if process.stderr: process.stderr.close() process.kill() logger.warning(f"因异常 {e},强制终止进程 {process.pid}") except Exception as kill_e: logger.error(f"清理过程中强制终止进程失败: {kill_e}") if stdout_thread and stdout_thread.is_alive(): stdout_thread.join(timeout=1) if stderr_thread and stderr_thread.is_alive(): stderr_thread.join(timeout=1) if monitor_thread and monitor_thread.is_alive(): monitor_thread.join(timeout=1) raise PipelineStepError(error_msg) if __name__ == "__main__": logger.info("="*20 + " 开始执行 WeClone Pipeline 脚本 " + "="*20) is_cuda_available = torch.cuda.is_available() logger.info("--- CUDA 可用性检查 ---") if is_cuda_available: gpu_count = torch.cuda.device_count() logger.success(f"CUDA 可用 (找到 {gpu_count} 个 GPU)") for i in range(gpu_count): logger.info(f" - GPU {i}: {torch.cuda.get_device_name(i)}") else: logger.warning("CUDA 不可用,将使用 CPU (如果适用)。") logger.info("-" * 25) steps_completed = [] api_process: Optional[Popen] = None web_demo_process: Optional[Popen] = None # 设置哪些步骤需要运行 run_qa = True run_train = True run_copy_checkpoint = True # 依赖于 run_train run_api = True run_eval = True # 依赖于 run_api run_web_demo = True # 不依赖于 run_api try: # 步骤 1: QA Generator if run_qa: logger.info("-" * 10 + " 步骤 1: QA 数据生成 " + "-" * 10) run_script(qa_script) steps_completed.append(f"{STEP_QA}: 成功") else: logger.info(f"{STEP_QA}: 跳过 (配置)") steps_completed.append(f"{STEP_QA}: 跳过") # 步骤 2: Train SFT (with monitoring) if run_train: logger.info("-" * 10 + " 步骤 2: SFT 训练 (带 Checkpoint 监控) " + "-" * 10) model_output_dir = os.path.join(project_root, "model_output") # --- 删除 model_output 目录 --- if os.path.exists(model_output_dir): logger.info(f"删除现有的 model_output 目录: {model_output_dir}") try: shutil.rmtree(model_output_dir) logger.success("成功删除 model_output 目录") except Exception as e: logger.error(f"删除 model_output 目录时出错: {e}") # Treat failure to delete as a critical error before training raise PipelineStepError(f"删除 model_output 目录失败: {e} ###step_id:{train_script}###") # --- 执行训练脚本并进行监控 --- train_status = run_train_with_checkpoint_monitoring( train_script, model_output_dir, # Pass the directory to monitor timeout=2000, ignore_timeout_error=True, env={'TQDM_DISABLE': '1'} ) # --- 根据训练状态更新完成列表 --- if train_status == "success": steps_completed.append(f"{STEP_TRAIN}: 
成功") elif train_status == "stopped_by_monitor": steps_completed.append(f"{STEP_TRAIN}: 已停止 (检测到 Checkpoint)") elif train_status == "timeout": steps_completed.append(f"{STEP_TRAIN}: 超时 (已忽略)") else: # "failed" or other unexpected status handled by exception steps_completed.append(f"{STEP_TRAIN}: 失败") # Should be caught by exception, but added for completeness # 步骤 2.1: 复制 Checkpoint (只有在训练 *成功* 完成后才执行) if run_copy_checkpoint: if train_status == "success": logger.info("-" * 10 + " 步骤 2.1: 复制 Checkpoint 到 model_output " + "-" * 10) source_dir = os.path.join(project_root, "model_output", "checkpoint-2") # Note: Still assumes checkpoint-2 specifically. dest_dir = os.path.join(project_root, "model_output") if os.path.isdir(source_dir): try: logger.info(f"开始将 {source_dir} 的内容复制到 {dest_dir}...") shutil.copytree(source_dir, dest_dir, dirs_exist_ok=True) logger.success(f"--- {STEP_COPY_CKPT} 成功 ---") steps_completed.append(f"{STEP_COPY_CKPT}: 成功") except Exception as e: error_msg = f"{STEP_COPY_CKPT} 时发生错误: {e}" logger.error(error_msg) raise PipelineStepError(f"{error_msg} ###step_id:copy_checkpoint###") else: logger.warning(f"训练成功后,源 Checkpoint 目录 {source_dir} 不存在或不是目录,跳过复制。") steps_completed.append(f"{STEP_COPY_CKPT}: 跳过 (源不存在)") # Consider if missing checkpoint-2 after successful training is an error # raise PipelineStepError(f"训练成功但必需的源 Checkpoint 目录 {source_dir} 不存在") else: # If training didn't succeed (stopped, timeout, failed), skip copy logger.info(f"{STEP_COPY_CKPT}: 跳过 (训练未成功完成,状态: {train_status})") steps_completed.append(f"{STEP_COPY_CKPT}: 跳过 (训练未成功)") else: logger.info(f"{STEP_COPY_CKPT}: 跳过 (配置)") steps_completed.append(f"{STEP_COPY_CKPT}: 跳过 (配置)") else: # If run_train is false logger.info(f"{STEP_TRAIN}: 跳过 (配置)") steps_completed.append(f"{STEP_TRAIN}: 跳过 (配置)") logger.info(f"{STEP_COPY_CKPT}: 跳过 (训练未运行)") steps_completed.append(f"{STEP_COPY_CKPT}: 跳过 (训练未运行)") # 步骤 3: Start API Service if run_api: logger.info("-" * 10 + " 步骤 3: 启动 API 服务 " + "-" * 10) api_process = start_api_service_background() steps_completed.append(f"{STEP_API_START}: 成功") else: logger.info(f"{STEP_API_START}: 跳过 (配置)") steps_completed.append(f"{STEP_API_START}: 跳过 (配置)") # 步骤 4: Eval Model (依赖 API 服务) if run_eval: if not run_api: logger.info("-" * 10 + " 步骤 4: 模型评估 " + "-" * 10) logger.warning("--- 因 API 服务配置为不运行,跳过执行: weclone/eval/test_model.py ---") steps_completed.append(f"{STEP_EVAL}: 跳过 (API未配置运行)") elif api_process is None or api_process.poll() is not None: # 检查进程是否已退出 error_msg = "尝试运行评估,但 API 服务进程不存在或已退出。" logger.error(error_msg) raise PipelineStepError(error_msg) else: logger.info("-" * 10 + " 步骤 4: 模型评估 " + "-" * 10) # 在调用评估脚本时禁用 tqdm run_script(eval_script, timeout=9999, env={'TQDM_DISABLE': '1'}) steps_completed.append(f"{STEP_EVAL}: 成功") stop_api_service(api_process) # 评估完成后停止API api_process = None # 标记为已停止 else: logger.info(f"{STEP_EVAL}: 跳过 (配置)") steps_completed.append(f"{STEP_EVAL}: 跳过 (配置)") if api_process: # 如果API在运行但评估被跳过,也停止API logger.info("评估被跳过,停止 API 服务...") stop_api_service(api_process) api_process = None # 步骤 5: Start Web Demo (不依赖 API 服务) if run_web_demo: logger.info("-" * 10 + " 步骤 5: 启动 Web Demo " + "-" * 10) web_demo_process = start_web_demo_background() steps_completed.append(f"{STEP_WEB_DEMO}: 成功") logger.info("--- Web Demo 已启动,测试流程继续... 
---") else: logger.info(f"{STEP_WEB_DEMO}: 跳过 (配置)") steps_completed.append(f"{STEP_WEB_DEMO}: 跳过 (配置)") # Pipeline 成功完成所有请求的步骤 logger.info("="*20 + " Pipeline 执行摘要 " + "="*20) for step in steps_completed: logger.info(f"- {step}") logger.success("✅ 所有请求执行的 Pipeline 步骤均成功完成!") skipped_steps = [s for s in steps_completed if "跳过" in s] if skipped_steps: logger.warning("注意: 以下步骤被设置为跳过,如需执行请修改脚本顶部的 run_xxx 变量:") for skipped in skipped_steps: logger.warning(f" - {skipped.split(':')[0]}") except PipelineStepError as e: logger.error("="*20 + " Pipeline 执行失败 " + "="*20) failing_step = "未知步骤" error_details = str(e) cleaned_error_details = error_details # Store original/cleaned details for logging # Attempt 1: Check for explicit marker (e.g., from copy_checkpoint) marker_prefix = "###step_id:" marker_suffix = "###" marker_start = error_details.find(marker_prefix) if marker_start != -1: marker_end = error_details.find(marker_suffix, marker_start + len(marker_prefix)) if marker_end != -1: step_id = error_details[marker_start + len(marker_prefix):marker_end] failing_step = step_identifiers.get(step_id, f"未知标记 ({step_id})") # Clean the marker from the displayed error details cleaned_error_details = error_details[:marker_start].strip() # Attempt 2: Check for known script paths in the error message if marker not found if failing_step == "未知步骤": found_script = False # Iterate through potential script paths stored as keys in step_identifiers for identifier, step_name in step_identifiers.items(): # Check if the identifier looks like a path and is in the error message if isinstance(identifier, str) and ('/' in identifier or '\\\\' in identifier) and identifier in error_details: failing_step = step_name found_script = True break # Found the most likely script # Attempt 3: Fallback based on last completed step (if still unknown) if failing_step == "未知步骤": if steps_completed: last_completed = steps_completed[-1].split(':')[0] try: last_completed_index = step_order.index(last_completed) if last_completed_index + 1 < len(step_order): # Assume the next step in the defined order failed failing_step = step_order[last_completed_index + 1] + " (推断)" else: failing_step = "Pipeline末尾或未知 (推断)" # Error after the last known step except ValueError: # Last completed step name wasn't found in our defined order failing_step = f"未知 (最后完成: {last_completed})" else: failing_step = "初始化期间" # No steps completed logger.error(f"错误发生在步骤: {failing_step}") logger.error(f"错误详情: {cleaned_error_details}") # Log the cleaned error details logger.info("--- 已完成步骤 ---") for step in steps_completed: logger.info(f"- {step}") logger.error("="*50) sys.exit(1) # 测试失败时退出码为 1 finally: logger.info("--- Pipeline 结束,开始清理后台服务 ---") # 确保在 finally 块中总是尝试停止服务 stop_web_demo(web_demo_process) stop_api_service(api_process) # 即使评估步骤停止了它,这里也尝试停止,无害 logger.info("--- 后台服务清理完成 ---") # 如果 Pipeline 成功,确保退出码为 0 sys.exit(0) ``` ## /tests/test_get_sample_audio.py ```py path="/tests/test_get_sample_audio.py" import os import subprocess import sys import pytest # 获取 weclone-audio/src 目录的绝对路径 # 这假设 tests 目录和 weclone-audio 在同一个父目录下 SRC_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'weclone-audio', 'src')) SCRIPT_PATH = os.path.join(SRC_DIR, 'get_sample_audio.py') # --- 测试配置 --- # 请将下面的路径替换为你的测试数据库文件的实际路径 # 最好放在 tests/data 目录下,并使用相对路径 TEST_DB_PATH = r"D:\projects\python projects\WeClone-data\wxdump_work\wxid_d6wwiru2zsmo22\merge_all.db"# <--- 修改这里 # 请将下面的 ID 替换为测试数据库中一个有效的音频消息的 MsgSvrID TEST_MSG_SVR_ID = "3269716813078873653" # <--- 修改这里 # ---------------- 
@pytest.fixture(scope="module") def setup_test_environment(): """确保测试所需的文件和目录存在""" if not os.path.exists(TEST_DB_PATH): pytest.fail(f"测试数据库文件未找到: {TEST_DB_PATH}。请提供一个有效的测试数据库。") if not os.path.exists(SCRIPT_PATH): pytest.fail(f"待测试的脚本未找到: {SCRIPT_PATH}") # 可以添加其他设置,例如创建测试数据目录 def test_audio_extraction(tmp_path, setup_test_environment): """ 测试 get_sample_audio.py 是否能成功提取音频并保存为 wav 文件。 """ output_filename = "test_output.wav" output_path = tmp_path / output_filename # 使用 pytest 的 tmp_path fixture 创建临时输出路径 # 构建命令行参数 cmd = [ sys.executable, # 使用当前的 Python 解释器 SCRIPT_PATH, "--db-path", TEST_DB_PATH, "--MsgSvrID", TEST_MSG_SVR_ID, "--save-path", str(output_path), "--rate", "24000" # 可以根据需要调整 ] # 运行脚本 # 注意:脚本中的 'key' 可能需要根据实际情况调整,或者修改脚本以允许通过参数传递 key # 目前脚本中硬编码了 key="test1" result = subprocess.run(cmd, capture_output=True, text=True, check=False) # check=False 允许我们检查返回码 # 打印输出以便调试 (如果测试失败) print("STDOUT:", result.stdout) print("STDERR:", result.stderr) # 断言脚本成功运行 assert result.returncode == 0, f"脚本执行失败,错误信息: {result.stderr}" # 断言输出文件已创建 assert output_path.exists(), f"输出文件 {output_path} 未被创建" # (可选) 断言文件大小大于 0 assert output_path.stat().st_size > 0, f"输出文件 {output_path} 为空" # (可选) 更复杂的检查,例如使用 wave 库检查文件头或内容 # import wave # try: # with wave.open(str(output_path), 'rb') as wf: # assert wf.getnchannels() == 1 # 假设是单声道 # assert wf.getframerate() == 24000 # 检查采样率 # except wave.Error as e: # pytest.fail(f"无法读取输出的 WAV 文件: {e}") def main_debug(): """用于直接运行和调试的主要函数""" print("--- 开始调试运行 ---") # 检查基本环境 if not os.path.exists(TEST_DB_PATH): print(f"错误: 测试数据库文件未找到: {TEST_DB_PATH}") return if not os.path.exists(SCRIPT_PATH): print(f"错误: 待测试的脚本未找到: {SCRIPT_PATH}") return if TEST_MSG_SVR_ID == "YOUR_TEST_MSG_SVR_ID": print(f"警告: TEST_MSG_SVR_ID 似乎未配置 ({TEST_MSG_SVR_ID})") # 可以选择在这里 return 或继续执行 # 定义调试输出路径 debug_output_dir = os.path.join(os.path.dirname(__file__), "debug_output") os.makedirs(debug_output_dir, exist_ok=True) # 创建输出目录(如果不存在) debug_output_path = os.path.join(debug_output_dir, "debug_sample.wav") print(f"脚本路径: {SCRIPT_PATH}") print(f"数据库路径: {TEST_DB_PATH}") print(f"消息 ID: {TEST_MSG_SVR_ID}") print(f"输出路径: {debug_output_path}") # 构建命令行参数 cmd = [ sys.executable, SCRIPT_PATH, "--db-path", TEST_DB_PATH, "--MsgSvrID", TEST_MSG_SVR_ID, "--save-path", debug_output_path, "--rate", "24000" ] print(f"执行命令: {' '.join(cmd)}") # 运行脚本 try: result = subprocess.run(cmd, capture_output=True, text=True, check=False, timeout=30) # 添加超时 print("\\n--- 脚本执行结果 ---") print("返回码:", result.returncode) print("STDOUT:") print(result.stdout) print("STDERR:") print(result.stderr) # 检查结果 if result.returncode == 0: print("\\n--- 结果检查 ---") if os.path.exists(debug_output_path): print(f"[成功] 输出文件已创建: {debug_output_path}") if os.path.getsize(debug_output_path) > 0: print(f"[成功] 输出文件大小 > 0 ({os.path.getsize(debug_output_path)} bytes)") else: print(f"[失败] 输出文件为空: {debug_output_path}") else: print(f"[失败] 输出文件未找到: {debug_output_path}") else: print("\\n[失败] 脚本执行失败。") except subprocess.TimeoutExpired: print("\\n[失败] 脚本执行超时。") except Exception as e: print(f"\\n[失败] 执行命令时发生异常: {e}") print("\\n--- 调试运行结束 ---") if __name__ == "__main__": # 确保在直接运行时正确设置了测试数据路径 # 注意:这里仍然使用文件顶部的 TEST_DB_PATH 和 TEST_MSG_SVR_ID # 请确保它们已经被修改为有效值! 
if TEST_DB_PATH == "tests/data/your_test_db.sqlite" or TEST_MSG_SVR_ID == "YOUR_TEST_MSG_SVR_ID": print("*"*40) print("警告:请先在脚本顶部修改 TEST_DB_PATH 和 TEST_MSG_SVR_ID 为有效的测试值!") print("*"*40) # sys.exit(1) # 可以取消注释以强制退出,如果未配置 main_debug() ``` ## /tests/test_old_csv_to_json.py ```py path="/tests/test_old_csv_to_json.py" import csv import json import os import re import sys import pandas as pd from collections import deque current_dir = os.path.dirname(os.path.abspath(__file__)) root_dir = os.path.dirname(current_dir) sys.path.append(root_dir) from make_dataset.qa_generator import DataProcessor csv_folder = "./data/csv" # csv_folder = './data/test' os.chdir(root_dir) print(f"当前处理目录{csv_folder}") def handle_pt_csv(csvfile): chat_df = pd.read_csv(csvfile) # 选择type_name为文本的行、is_sender为1的行 chat_df = chat_df[chat_df["type_name"] == "文本"] chat_df = chat_df[chat_df["is_sender"] == 1] # 对每一行的content进行处理 转为dict 再取'msg'字段 chat_df["content"] = chat_df["content"].apply(lambda x: json.loads(x)["msg"]) # 如果content 包含 手机号、身份证号、邮箱、网址则删除这行 chat_df = chat_df[~chat_df["content"].str.contains("1\d{10}")] chat_df = chat_df[~chat_df["content"].str.contains("\d{18}")] chat_df = chat_df[~chat_df["content"].str.contains("\w+@\w+")] chat_df = chat_df[~chat_df["content"].str.contains("http")] chat_df = chat_df[~chat_df["content"].str.contains(r"\\xa0")] chat_df = chat_df[~chat_df["content"].str.contains(r"\\u")] # 纯content chat_df = chat_df["content"] chat_df = chat_df.dropna() return chat_df def make_pt_dataset(): csv_res = [] # csv文件夹里全是不同聊天对象文件夹 每个文件夹里是csv文件 先遍历不同聊天对象文件夹 再遍历聊天对象的csv文件 for chat_obj_folder in os.listdir(csv_folder): chat_obj_folder_path = os.path.join(csv_folder, chat_obj_folder) for csvfile in os.listdir(chat_obj_folder_path): if not csvfile.endswith(".csv"): continue csvfile_path = os.path.join(chat_obj_folder_path, csvfile) chat_df = handle_pt_csv(csvfile_path) csv_res.append(chat_df) csv_res = pd.concat(csv_res) csv_res = csv_res.apply(lambda x: {"c": x}) # 设置数据集prompt键为c csv_res.to_json("./data/res_csv/pt-my.json", orient="records", force_ascii=False) def handle_sft_csv(csvfile): chat_df = pd.read_csv(csvfile) blocked_words = json.load( open("./make_dataset/blocked_words.json", encoding="utf-8") )["blocked_words"] # 选择type_name为文本的行、is_sender为1的行 # 需要保留的type_name字段名 type_list = [ "文本", "图片", "视频", "合并转发的聊天记录", "语音", "(分享)音乐", "(分享)卡片式链接", "(分享)笔记", "(分享)小程序", "(分享)收藏夹", "(分享)小说(猜)", "(分享)视频号名片", "(分享)视频号视频", "粘贴的文本", # 无法解析的分享链接 ] chat_df = chat_df[chat_df["type_name"].isin(values=type_list)] # chat_df['content'] = chat_df['content'].apply(func=lambda x: json.loads(x)['msg']) chat_df["content"] = chat_df["msg"] # 如果type_name为文本 并且content 包含 手机号、身份证号、邮箱、网址则删除这行 for i in chat_df.index: if chat_df.loc[i, "type_name"] == "文本": if ( re.search(r"1\d{10}", chat_df.loc[i, "content"]) or re.search(r"\d{18}", chat_df.loc[i, "content"]) or re.search(r"\w+@\w+", chat_df.loc[i, "content"]) or "http" in chat_df.loc[i, "content"] or r"\\xa0" in chat_df.loc[i, "content"] or r"\\u" in chat_df.loc[i, "content"] ): chat_df = chat_df.drop(index=i) continue for blocked_word in blocked_words: if blocked_word in chat_df.loc[i, "content"]: chat_df = chat_df.drop(index=i) break else: chat_df.loc[i, "content"] = "" chat_df = chat_df[["is_sender", "type_name", "content", "CreateTime"]] chat_df = chat_df.dropna() # 时间格式 2021-07-07 10:27:23 # 遍历行 相同is_sender的行合并content()遇到不同is_sender就重新开始 # CreateTime字段保留最后的CreateTime chat_df["CreateTime"] = pd.to_datetime(chat_df["CreateTime"]) # 改到这了 type_list.remove("文本") skip_list = 
type_list res_df = [] last_is_sender = chat_df.iloc[0]["is_sender"] last_content: str = chat_df.iloc[0]["content"] last_CreateTime = chat_df.iloc[0]["CreateTime"] # 超时处理 半天没说话就重新开始 # 注意这里只是处理了组装成一个句子 最后封装对话、配对在make_sft_dataset # 遇到图片 连接 直接封装成一个句子 for i, row in chat_df.iterrows(): if row["type_name"] in skip_list: if last_content != "": if last_content[-1] == ",": last_content = last_content[:-1] elif last_content[-1] not in ["。", "!", "?", "…", "."]: last_content += "" res_df.append( { "is_sender": last_is_sender, "content": last_content, "CreateTime": last_CreateTime, } ) last_CreateTime = row["CreateTime"] last_content = "" # cut表示被skip字段截断 res_df.append( { "is_sender": row["is_sender"], "content": "cut", "CreateTime": row["CreateTime"], } ) continue if last_content == "": # 重新开始 last_content = row["content"] last_is_sender = row["is_sender"] last_CreateTime = row["CreateTime"] continue if row["is_sender"] == last_is_sender: if row["CreateTime"] - last_CreateTime > pd.Timedelta(value="2m"): # 如果超时 前面的添加到res_df 并重新开始 if last_content[-1] == ",": last_content = last_content[:-1] elif last_content[-1] not in ["。", "!", "?", "…", "."]: last_content += "" res_df.append( { "is_sender": last_is_sender, "content": last_content, "CreateTime": last_CreateTime, } ) last_content = row["content"] last_CreateTime = row["CreateTime"] continue # 如果content的结尾没有标点符号则添加逗号,最后结尾是句号 if last_content[-1] not in ["。", "!", "?", "…", ","]: last_content += "," last_content = last_content + row["content"] last_CreateTime = row["CreateTime"] else: if last_content[-1] == ",": last_content = last_content[:-1] elif last_content[-1] not in ["。", "!", "?", "…", "."]: last_content += "" res_df.append( { "is_sender": last_is_sender, "content": last_content, "CreateTime": last_CreateTime, } ) last_is_sender = row["is_sender"] last_content = row["content"] last_CreateTime = row["CreateTime"] res_df = pd.DataFrame(res_df) return res_df def make_sft_dataset(): processor = DataProcessor() csv_files = processor.get_csv_files() csv_concat = [] csv_res = [] for csvfile_path in csv_files: chat_df = handle_sft_csv(csvfile_path) csv_concat.append(chat_df) # 后续代码保持不变 csv_concat = pd.concat(csv_concat) # 更全面地处理cut标记 # 1. 将连续的cut标记合并为一个 # 2. 
标记数据区块的开始和结束 processed_rows = [] skip_row = False last_row_was_cut = False for i in range(len(csv_concat)): if skip_row: skip_row = False continue current_row = csv_concat.iloc[i].copy() # 处理当前行是cut的情况 if current_row["content"] == "cut": # 如果上一行已经是cut,则跳过当前行 if last_row_was_cut: continue # 查找连续的cut j = i + 1 while j < len(csv_concat) and csv_concat.iloc[j]["content"] == "cut": j += 1 # 如果有连续的cut,只保留最后一个 if j > i + 1: current_row = csv_concat.iloc[j - 1].copy() skip_row = True last_row_was_cut = True else: last_row_was_cut = False processed_rows.append(current_row) # 创建新的DataFrame csv_concat = pd.DataFrame(processed_rows) # csv_res里is_sender必须是01 01 01 的顺序 csv_concat里不一定是01 01 # 相差超过1小时的时间戳分为不同的对话 # temp_res为一个长度为2的队列 # 将合并后的数据保存到CSV文件中 output_dir = "./test_output" # 生成带时间戳的文件名 import datetime now = datetime.datetime.now() output_file = os.path.join(output_dir, f"csv_old_.csv") # 保存合并后的数据 # csv_concat.to_csv(output_file, index=False, encoding="utf-8-sig") # print(f"已将合并后的数据保存到: {output_file}") # print(f"合并后数据总量: {len(csv_concat)} 条记录") temp_res = deque(maxlen=2) # 6种情况 # temp_res 为空 遇到 0入队 遇到1不处理 遇到cut不处理 # temp_res 有0 遇到0清空队列再入队 遇到1相差超过1小时清空队列 没有相差一小时入队再全部出队 遇到cut清空队列 for i, row in csv_concat.iterrows(): if len(temp_res) == 0: if row["content"] == "cut": continue if row["is_sender"] == 0: temp_res.append(row["content"]) last_CreateTime = row["CreateTime"] else: continue elif len(temp_res) == 1: if row["content"] == "cut": temp_res.clear() last_CreateTime = row["CreateTime"] elif row["is_sender"] == 0: # 遇到0 清空队列再入队 temp_res.clear() temp_res.append(row["content"]) last_CreateTime = row["CreateTime"] else: if row["CreateTime"] - last_CreateTime > pd.Timedelta("5m"): # 相差超过1小时清空队列 temp_res.clear() last_CreateTime = row["CreateTime"] else: # 没有相差一小时入队再全部出队 temp_res.append(row["content"]) csv_res.append({"instruction": temp_res[0], "output": temp_res[1]}) temp_res.clear() last_CreateTime = row["CreateTime"] csv_res_df = pd.DataFrame(csv_res) print(f"处理后数据量:{csv_res_df.shape[0]}") csv_res_df.to_json('./data/res_csv/sft/sft-old-my.json', orient='records', force_ascii=False) if __name__ == "__main__": # make_pt_dataset() make_sft_dataset() ``` ## /tests/test_qa_generator.py ```py path="/tests/test_qa_generator.py" import sys import os import pytest from datetime import datetime, timedelta import pandas as pd # 添加项目根目录到sys.path current_dir = os.path.dirname(os.path.abspath(__file__)) root_dir = os.path.dirname(current_dir) sys.path.append(root_dir) from make_dataset.models import ChatMessage, CutMessage from make_dataset.qa_generator import DataProcessor # 将当前工作目录更改为项目根目录 os.chdir(root_dir) # # 测试数据处理器类的初始化和配置加载 # def test_data_processor_init(): # """测试DataProcessor初始化""" # processor = DataProcessor() # assert processor.csv_folder == "./data/csv" # assert "文本" not in processor.skip_type_list # assert len(processor.type_list) == 8 class MockDataProcessor(DataProcessor): def __init__(self): super().__init__() @pytest.fixture def processor(): """创建一个测试用的处理器实例""" return MockDataProcessor() def test_empty_messages(processor): """测试空消息列表的情况""" messages = [] result = processor.group_consecutive_messages(messages) assert result == [] def test_single_message(processor): """测试单条消息的情况""" now = datetime.now() message = ChatMessage( id=1, MsgSvrID=1001, type_name="文本", is_sender=0, talker="user1", room_name="testroom", msg="你好", src="", CreateTime=now, ) result = processor.group_consecutive_messages([message]) assert len([msg for msg in result if isinstance(msg, ChatMessage)]) == 1 assert len([msg for msg in 
result if isinstance(msg, CutMessage)]) == 0 assert result[0].msg == "你好" def test_consecutive_messages_same_sender(processor): """测试同一发送者的连续消息""" now = datetime.now() messages = [ ChatMessage( id=1, MsgSvrID=1001, type_name="文本", is_sender=0, talker="user1", room_name="testroom", msg="你好", src="", CreateTime=now, ), ChatMessage( id=2, MsgSvrID=1002, type_name="文本", is_sender=0, talker="user1", room_name="testroom", msg="最近怎么样", src="", CreateTime=now + timedelta(minutes=5), ), ChatMessage( id=3, MsgSvrID=1003, type_name="文本", is_sender=0, talker="user1", room_name="testroom", msg="我想问个问题", src="", CreateTime=now + timedelta(minutes=10), ), ] result = processor.group_consecutive_messages(messages) assert len([msg for msg in result if isinstance(msg, ChatMessage)]) == 1 assert len([msg for msg in result if isinstance(msg, CutMessage)]) == 0 assert result[0].msg == "你好,最近怎么样,我想问个问题" def test_messages_different_senders(processor): """测试不同发送者的消息""" now = datetime.now() messages = [ ChatMessage( id=1, MsgSvrID=1001, type_name="文本", is_sender=0, talker="user1", room_name="testroom", msg="你好", src="", CreateTime=now, ), ChatMessage( id=2, MsgSvrID=1002, type_name="文本", is_sender=1, talker="user2", room_name="testroom", msg="你好,有什么可以帮你的", src="", CreateTime=now + timedelta(minutes=5), ), ChatMessage( id=3, MsgSvrID=1003, type_name="文本", is_sender=0, talker="user1", room_name="testroom", msg="我想问个问题", src="", CreateTime=now + timedelta(minutes=10), ), ] result = processor.group_consecutive_messages(messages) assert len([msg for msg in result if isinstance(msg, ChatMessage)]) == 3 assert len([msg for msg in result if isinstance(msg, CutMessage)]) == 0 assert result[0].msg == "你好" assert result[1].msg == "你好,有什么可以帮你的" assert result[2].msg == "我想问个问题" def test_skip_non_text_messages(processor): """测试跳过非文本消息""" now = datetime.now() messages = [ ChatMessage( id=1, MsgSvrID=1001, type_name="文本", is_sender=0, talker="user1", room_name="testroom", msg="你好", src="", CreateTime=now, ), ChatMessage( id=2, MsgSvrID=1002, type_name="文本", is_sender=0, talker="user1", room_name="testroom", msg="先生", src="image.jpg", CreateTime=now + timedelta(minutes=9.9), ), ChatMessage( id=2, MsgSvrID=1002, type_name="图片", is_sender=0, talker="user1", room_name="testroom", msg="", src="image.jpg", CreateTime=now + timedelta(minutes=1+9.9), ), ChatMessage( id=3, MsgSvrID=1003, type_name="文本", is_sender=0, talker="user1", room_name="testroom", msg="看到图片了吗", src="", CreateTime=now + timedelta(minutes=1+9.9), ), ] result = processor.group_consecutive_messages(messages) chat_messages = [msg for msg in result if isinstance(msg, ChatMessage)] cut_messages = [msg for msg in result if isinstance(msg, CutMessage)] assert len(chat_messages) == 2 assert len(cut_messages) == 1 assert chat_messages[0].msg == "你好,先生" def test_time_window_limit(processor): """测试时间窗口限制(超过1小时的消息不会合并)""" now = datetime.now() messages = [ ChatMessage( id=1, MsgSvrID=1001, type_name="文本", is_sender=0, talker="user1", room_name="testroom", msg="你好", src="", CreateTime=now, ), ChatMessage( id=2, MsgSvrID=1002, type_name="文本", is_sender=0, talker="user1", room_name="testroom", msg="晚上好", src="", CreateTime=now + timedelta(hours=2), # 超过1小时 ), ] result = processor.group_consecutive_messages(messages) assert len([msg for msg in result if isinstance(msg, ChatMessage)]) == 2 assert len([msg for msg in result if isinstance(msg, CutMessage)]) == 0 assert result[0].msg == "你好" assert result[1].msg == "晚上好" def test_consecutive_messages_to_csv(): """ 
测试使用DataProcessor的main函数从CSV文件中读取数据, 应用group_consecutive_messages函数,并将结果保存为CSV """ processor = MockDataProcessor() # 获取CSV文件列表 csv_files = processor.get_csv_files() # 如果没有找到CSV文件,创建一个模拟的CSV文件供测试使用 if not csv_files: print("警告:未找到CSV文件,请确保数据目录中有CSV文件") return "无法找到CSV文件" # 存储所有处理后的消息 all_grouped_messages = [] # 处理每个CSV文件 for csv_file in csv_files: print(f"处理文件: {csv_file}") # 加载CSV文件中的消息 chat_messages = processor.load_csv(csv_file) print(f"加载了 {len(chat_messages)} 条消息") # 应用group_consecutive_messages函数 grouped_messages = processor.group_consecutive_messages(messages=chat_messages) print(f"分组后得到 {len(grouped_messages)} 条消息") # 添加到结果列表 all_grouped_messages.extend(grouped_messages) # 如果没有处理到任何消息,提前返回 if not all_grouped_messages: print("警告:未处理到任何消息") return "未处理到任何消息" # 将结果转换为DataFrame messages_dict = [] for msg in all_grouped_messages: if isinstance(msg, ChatMessage): messages_dict.append( { "id": msg.id, "MsgSvrID": msg.MsgSvrID, "type_name": msg.type_name, "is_sender": msg.is_sender, "talker": msg.talker, "room_name": msg.room_name, "msg": msg.msg, "src": msg.src, "CreateTime": msg.CreateTime, } ) elif hasattr(msg, "cut_type"): # 处理CutMessage对象 messages_dict.append( { "id": None, "MsgSvrID": None, "type_name": msg.cut_type, "is_sender": msg.is_sender, "talker": None, "room_name": None, "msg": f"cut", "src": None, "CreateTime": msg.CreateTime, } ) # 创建DataFrame df = pd.DataFrame(messages_dict) # 确保输出目录存在 output_dir = "./test_output" os.makedirs(output_dir, exist_ok=True) # 保存为CSV文件 import datetime now = datetime.datetime.now() output_file = os.path.join(output_dir, f"grouped_messages_.csv") # 使用utf-8-sig编码保存,添加BOM标记以解决中文乱码问题 df.to_csv(output_file, index=False, encoding="utf-8-sig") # 验证结果 assert os.path.exists(output_file) print(f"已成功保存分组消息到: {output_file}") print(f"共保存了 {len(messages_dict)} 条消息") # 显示前5条消息示例 if len(messages_dict) > 0: print("\n消息示例:") for i, msg in enumerate(messages_dict[:5]): print( f"{i+1}. {'用户' if msg['is_sender'] == 0 else '对方'}: {msg['msg'][:50]}..." 
) return output_file if __name__ == "__main__": output_file = test_consecutive_messages_to_csv() print(f"测试完成,消息已保存到 {output_file}") ``` ## /tests/test_weclone_pipeline_mock.py ```py path="/tests/test_weclone_pipeline_mock.py" import os import sys import json import shutil import unittest import tempfile from unittest.mock import patch, MagicMock import pandas as pd # 添加项目根目录到系统路径 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # 导入需要测试的模块 from weclone.data.qa_generator import DataProcessor from weclone.utils.config import load_config class TestWeclonePipeline(unittest.TestCase): @classmethod def setUpClass(cls): """设置测试环境""" # 创建临时目录用于测试 cls.test_dir = tempfile.mkdtemp() cls.test_data_dir = os.path.join(cls.test_dir, "data") cls.test_model_dir = os.path.join(cls.test_dir, "model_output") cls.test_eval_dir = os.path.join(cls.test_dir, "eval_output") # 创建必要的目录 os.makedirs(cls.test_data_dir, exist_ok=True) os.makedirs(cls.test_model_dir, exist_ok=True) os.makedirs(cls.test_eval_dir, exist_ok=True) # 创建测试数据集结构 cls.csv_folder = os.path.join(cls.test_data_dir, "csv") os.makedirs(cls.csv_folder, exist_ok=True) # 创建示例聊天文件夹和CSV文件 chat_folder = os.path.join(cls.csv_folder, "test_chat") os.makedirs(chat_folder, exist_ok=True) # 创建简单的测试CSV数据 cls._create_test_csv(os.path.join(chat_folder, "test_chat.csv")) # 创建测试用的settings.json cls._create_test_settings() # 创建测试用的test_data.json用于模型评估 cls._create_test_eval_data() @classmethod def tearDownClass(cls): """清理测试环境""" # 删除临时目录 shutil.rmtree(cls.test_dir, ignore_errors=True) @classmethod def _create_test_csv(cls, file_path): """创建测试用CSV文件""" import pandas as pd # 创建简单的聊天记录数据 data = { "id": list(range(1, 5)), "MsgSvrID": list(range(1001, 1005)), "type": ["1", "1", "1", "1"], # 文本类型 "is_sender": [0, 1, 0, 1], # 0=对方发送,1=自己发送 "talker": ["test_user", "me", "test_user", "me"], "room_name": ["", "", "", ""], "content": ["你好,请问你是谁?", "我是你的微信助手", "你能帮我做什么?", "我可以回答问题,提供信息和帮助你完成各种任务"], "src": ["", "", "", ""], "CreateTime": [1609459200, 1609459220, 1609459240, 1609459260] # 时间戳 } # 创建DataFrame并保存为CSV df = pd.DataFrame(data) df.to_csv(file_path, index=False) @classmethod def _create_test_settings(cls): """创建测试用的settings.json""" # 简化版的设置文件,只包含测试所需的最小配置 settings = { "train_sft_args": { "stage": "sft", "dataset": "wechat-sft", "dataset_dir": cls.test_data_dir + "/res_csv/sft", "lora_target": "query_key_value", "lora_rank": 4, "lora_dropout": 0.5, "overwrite_cache": True, "per_device_train_batch_size": 1, "gradient_accumulation_steps": 1, "lr_scheduler_type": "cosine", "logging_steps": 1, "save_steps": 1, "learning_rate": 0.0001, "num_train_epochs": 1, "plot_loss": False, "fp16": False }, "infer_args": { "repetition_penalty": 1.2, "temperature": 0.5, "max_length": 50, "top_p": 0.65 }, "make_dataset_args": { "single_combine_strategy": "time_window", "qa_match_strategy": "time_window", "single_combine_time_window": 2, "qa_match_time_window": 5, "prompt_with_history": False }, "common_args": { "model_name_or_path": "./chatglm3-6b", # 假设已有模型 "adapter_name_or_path": cls.test_model_dir, "template": "chatglm3-weclone", "finetuning_type": "lora", "trust_remote_code": True } } # 保存到临时目录 with open(os.path.join(cls.test_dir, "settings.json"), "w", encoding="utf-8") as f: json.dump(settings, f, indent=4) @classmethod def _create_test_eval_data(cls): """创建测试用的评估数据""" test_data = { "questions": [ ["你好", "你是谁"], ["你能做什么"] ] } # 确保目录存在 data_dir = os.path.join(cls.test_dir, "data") os.makedirs(data_dir, exist_ok=True) # 保存测试数据 with open(os.path.join(data_dir, 
"test_data.json"), "w", encoding="utf-8") as f: json.dump(test_data, f, ensure_ascii=False, indent=4) @patch('weclone.data.qa_generator.DataProcessor.get_csv_files') @patch('weclone.data.qa_generator.DataProcessor.load_csv') @patch('weclone.data.qa_generator.DataProcessor.save_result') def test_qa_generator(self, mock_save_result, mock_load_csv, mock_get_csv_files): """测试QA生成器""" print("\n测试QA生成器...") # 准备模拟数据 from weclone.data.models import ChatMessage mock_get_csv_files.return_value = ["test_csv_file.csv"] # 模拟从CSV加载的消息 mock_messages = [ ChatMessage(id=1, MsgSvrID=1001, type_name="文本", is_sender=0, talker="test_user", room_name="", msg="你好,请问你是谁?", src="", CreateTime=pd.Timestamp(1609459200, unit='s')), ChatMessage(id=2, MsgSvrID=1002, type_name="文本", is_sender=1, talker="me", room_name="", msg="我是你的微信助手", src="", CreateTime=pd.Timestamp(1609459220, unit='s')), ChatMessage(id=3, MsgSvrID=1003, type_name="文本", is_sender=0, talker="test_user", room_name="", msg="你能帮我做什么?", src="", CreateTime=pd.Timestamp(1609459240, unit='s')), ChatMessage(id=4, MsgSvrID=1004, type_name="文本", is_sender=1, talker="me", room_name="", msg="我可以回答问题,提供信息和帮助你完成各种任务", src="", CreateTime=pd.Timestamp(1609459260, unit='s')) ] mock_load_csv.return_value = mock_messages # 创建DataProcessor实例 with patch('weclone.utils.config.load_config') as mock_load_config: # 模拟配置 mock_config = { "single_combine_strategy": "time_window", "qa_match_strategy": "time_window", "single_combine_time_window": 2, "qa_match_time_window": 5, "prompt_with_history": False } mock_load_config.return_value = mock_config # 执行QA生成 processor = DataProcessor() processor.csv_folder = self.csv_folder # 设置为测试目录 processor.main() # 验证是否调用了预期的方法 mock_get_csv_files.assert_called_once() mock_load_csv.assert_called_once() mock_save_result.assert_called_once() # 验证结果格式 # 获取保存的结果 call_args = mock_save_result.call_args[0][0] self.assertTrue(isinstance(call_args, list)) self.assertEqual(len(call_args), 2) # 应该有两个QA对 # 验证QA对的结构 for qa in call_args: self.assertTrue("instruction" in qa) self.assertTrue("output" in qa) print("QA生成器测试成功") def test_train_sft(self): """测试SFT训练过程""" print("\n测试SFT训练过程...") # 由于训练需要实际的模型和数据,这里我们只模拟调用 with patch('llamafactory.train.tuner.run_exp') as mock_run_exp: # 导入训练模块并运行 from weclone.train.train_sft import run_exp # 验证是否正确调用了训练函数 self.assertTrue(mock_run_exp.called) print("SFT训练过程测试成功") def test_api_service(self): """测试API服务""" print("\n测试API服务...") # 模拟服务器进程 with patch('uvicorn.run') as mock_run: # 导入API服务模块 from weclone.server.api_service import main, create_app, ChatModel # 模拟配置和模型 with patch('weclone.utils.config.load_config') as mock_load_config: mock_config = {"model_path": "test_model_path"} mock_load_config.return_value = mock_config # 模拟ChatModel with patch('llamafactory.chat.ChatModel') as MockChatModel: mock_chat_model = MagicMock() MockChatModel.return_value = mock_chat_model # 运行API服务 main() # 验证服务是否正确启动 mock_run.assert_called_once() call_args = mock_run.call_args[1] self.assertEqual(call_args["host"], "0.0.0.0") self.assertEqual(call_args["port"], 8005) # 默认端口 self.assertEqual(call_args["workers"], 1) print("API服务测试成功") def test_model_evaluation(self): """测试模型评估""" print("\n测试模型评估...") # 模拟OpenAI API调用 with patch('openai.ChatCompletion.create') as mock_create: # 设置模拟返回值 mock_response = MagicMock() mock_response.choices = [MagicMock()] mock_response.choices[0].message.content = "这是模型的测试回复" mock_create.return_value = mock_response # 运行评估脚本 with patch('builtins.open', create=True) as mock_open: # 模拟打开测试数据文件 test_data_content = 
'{"questions": [["你好", "你是谁"], ["你能做什么"]]}' mock_file = MagicMock() mock_file.read.return_value = test_data_content mock_open.return_value.__enter__.return_value = mock_file # 导入并运行评估模块 from weclone.eval.test_model import main # 执行评估 main() # 验证API调用次数(应该是测试问题的数量) self.assertEqual(mock_create.call_count, 3) # 3个测试问题 print("模型评估测试成功") def test_full_pipeline(self): """测试完整流程""" print("\n测试完整流程...") # 这个测试方法会依次调用上面的各个测试方法,模拟完整的流程 # 1. 测试QA生成器 self.test_qa_generator() # 2. 测试SFT训练 self.test_train_sft() # 3. 测试API服务 self.test_api_service() # 4. 测试模型评估 self.test_model_evaluation() print("完整流程测试完成") if __name__ == "__main__": unittest.main() ``` ## /weclone-audio/README.md # WeClone-audio 模块 WeClone-audio 是一个使用微信语音消息克隆声音的模块,使用模型实现高质量语音合成。 ### 显存需求 **Spark-TTS** 推荐 - **0.5B 模型**: 约 4GB 显存 **Llasa** (已弃用) - **3B 模型**: 约 16GB 显存 - **1B 模型**: 约 9GB 显存 ## 1. 导出微信语音数据 ### 1.1 准备工作 - 使用 [PyWxDump](https://github.com/xaoyaoo/PyWxDump) 提取微信聊天记录 - 下载软件并解密数据库 - 点击聊天备份,导出类型选择"解密文件" ### 1.2 环境配置 语音导出仅支持Windows环境 WeClone Audio使用uv作为包管理器。 ```bash # 为 PyWxDump 创建 Python 环境和安装依赖 # uv venv .venv-wx --python=3.10 .venv-wx\Scripts\activate uv pip install pywxdump ``` ### 1.3 导出语音文件 ```bash python weclone-audio/src/get_sample_audio.py --db-path "导出数据库路径" --MsgSvrID "导出聊天记录的MsgSvrID字段" ``` ## 2. 语音合成推理 ### Spark-TTS模型 **环境安装** 可不创建新环境,直接安装`sparktts`依赖组到WeClone共主环境 ```bash uv venv .venv-sparktts --python=3.10 source .venv-sparktts/bin/activate uv pip install --group sparktts -e . git clone https://github.com/SparkAudio/Spark-TTS.git weclone-audio/src/Spark-TTS ``` **模型下载** 通过python下载: ```python from huggingface_hub import snapshot_download # 假设此 Python 代码在 weclone-audio 目录下运行 模型将下载到 weclone-audio/pretrained_models/Spark-TTS-0.5B snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir="pretrained_models/Spark-TTS-0.5B") ``` 或通过git下载: ```bash # 假设当前在 weclone-audio 目录 mkdir -p pretrained_models # Make sure you have git-lfs installed (https://git-lfs.com) git lfs install git clone https://huggingface.co/SparkAudio/Spark-TTS-0.5B pretrained_models/Spark-TTS-0.5B ``` 使用代码推理 ```python import os import SparkTTS import soundfile as sf import torch from SparkTTS import SparkTTS # 假设此 Python 代码在 weclone-audio 目录下运行 # 模型路径相对于当前目录 model_path = "pretrained_models/Spark-TTS-0.5B" sample_audio = "sample.wav" output_audio = "output.wav" model = SparkTTS(model_path, "cuda") with torch.no_grad(): wav = model.inference( text="晚上好啊,小可爱们,该睡觉了哦", prompt_speech_path=sample_audio, # 使用相对路径 prompt_text="对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。", ) sf.write(output_audio, wav, samplerate=16000) # 使用相对路径 ``` ### Llasa模型 (已弃用) ### 2.1 环境配置 ```bash # 创建并配置推理环境 ## 可不创建新环境,与LLaMA-Factory环境共用 uv venv .venv-xcodec --python=3.9 source .venv-xcodec/bin/activate uv pip install --group xcodec -e . 
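# optional sanity check (illustrative): confirm the xcodec group installed correctly
python -c "import xcodec2; print('xcodec2 OK')"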
# 退出环境 deactivate # 系统依赖安装(如果需要) sudo apt install python3-dev sudo apt install build-essential ``` ### 2.2 使用代码推理 如果遇到问题,请尝试将参考音频转换为WAV或MP3格式,将其裁剪至15秒以内,并缩短提示文本。 ```python import os import soundfile as sf # 假设 text_to_speech.py 位于 src/ 或其他可导入的位置 from text_to_speech import TextToSpeech sample_audio_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。" # 示例音频文本 # 假设此 Python 代码在 weclone-audio 目录下运行 # 示例音频路径相对于当前目录 sample_audio_path = "sample.wav" output_audio = "output.wav" tts = TextToSpeech(sample_audio_path, sample_audio_text) target_text = "晚上好啊" # 生成目标文本 result = tts.infer(target_text) sf.write(output_audio, result[1], result[0]) # 使用相对路径 ``` ## /weclone-audio/src/Llasa/infer.py ```py path="/weclone-audio/src/Llasa/infer.py" import os import soundfile as sf from text_to_speech import TextToSpeech sample_audio_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。" # 示例音频文本 sample_audio_path = os.path.join(os.path.dirname(__file__), "sample.wav") # 示例音频路径 tts = TextToSpeech(sample_audio_path, sample_audio_text) target_text = "晚上好啊" # 生成目标文本 result = tts.infer(target_text) sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), result[1], result[0]) # 保存生成音频 ``` ## /weclone-audio/src/Llasa/text_to_speech.py ```py path="/weclone-audio/src/Llasa/text_to_speech.py" import os from transformers import AutoTokenizer, AutoModelForCausalLM import torch import soundfile as sf from xcodec2.modeling_xcodec2 import XCodec2Model import torchaudio class TextToSpeech: def __init__(self, sample_audio_path, sample_audio_text): self.sample_audio_text = sample_audio_text # 初始化模型 llasa_3b = "HKUSTAudio/Llasa-3B" xcodec2 = "HKUSTAudio/xcodec2" self.tokenizer = AutoTokenizer.from_pretrained(llasa_3b) self.llasa_3b_model = AutoModelForCausalLM.from_pretrained( llasa_3b, trust_remote_code=True, device_map="auto", ) self.llasa_3b_model.eval() self.xcodec_model = XCodec2Model.from_pretrained(xcodec2) self.xcodec_model.eval().cuda() # 处理音频 waveform, sample_rate = torchaudio.load(sample_audio_path) if len(waveform[0]) / sample_rate > 15: print("已将音频裁剪至前15秒。") waveform = waveform[:, : sample_rate * 15] # 检查音频是否为立体声 if waveform.size(0) > 1: waveform_mono = torch.mean(waveform, dim=0, keepdim=True) else: waveform_mono = waveform self.prompt_wav = torchaudio.transforms.Resample( orig_freq=sample_rate, new_freq=16000 )(waveform_mono) # Encode the prompt wav vq_code_prompt = self.xcodec_model.encode_code(input_waveform=self.prompt_wav) vq_code_prompt = vq_code_prompt[0, 0, :] self.speech_ids_prefix = self.ids_to_speech_tokens(vq_code_prompt) self.speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>") def ids_to_speech_tokens(self, speech_ids): speech_tokens_str = [] for speech_id in speech_ids: speech_tokens_str.append(f"<|s_{speech_id}|>") return speech_tokens_str def extract_speech_ids(self, speech_tokens_str): speech_ids = [] for token_str in speech_tokens_str: if token_str.startswith("<|s_") and token_str.endswith("|>"): num_str = token_str[4:-2] num = int(num_str) speech_ids.append(num) else: print(f"Unexpected token: {token_str}") return speech_ids @torch.inference_mode() def infer(self, target_text): if len(target_text) == 0: return None elif len(target_text) > 300: print("文本过长,请保持在300字符以内。") target_text = target_text[:300] input_text = self.sample_audio_text + " " + target_text formatted_text = ( f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>" ) chat = [ { "role": "user", "content": "Convert the text to speech:" + formatted_text, }, { "role": "assistant", "content": 
"<|SPEECH_GENERATION_START|>" + "".join(self.speech_ids_prefix), }, ] input_ids = self.tokenizer.apply_chat_template( chat, tokenize=True, return_tensors="pt", continue_final_message=True ) input_ids = input_ids.to("cuda") outputs = self.llasa_3b_model.generate( input_ids, max_length=2048, eos_token_id=self.speech_end_id, do_sample=True, top_p=1, temperature=0.8, ) generated_ids = outputs[0][input_ids.shape[1] - len(self.speech_ids_prefix): -1] speech_tokens = self.tokenizer.batch_decode( generated_ids, skip_special_tokens=True ) speech_tokens = self.extract_speech_ids(speech_tokens) speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0) gen_wav = self.xcodec_model.decode_code(speech_tokens) gen_wav = gen_wav[:, :, self.prompt_wav.shape[1]:] return (16000, gen_wav[0, 0, :].cpu().numpy()) if __name__ == "__main__": # 如果遇到问题,请尝试将参考音频转换为WAV或MP3格式,将其裁剪至15秒以内,并缩短提示文本。 sample_audio_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。" sample_audio_path = os.path.join(os.path.dirname(__file__), "sample.wav") tts = TextToSpeech(sample_audio_path, sample_audio_text) target_text = "晚上好啊,吃了吗您" result = tts.infer(target_text) sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), result[1], result[0]) target_text = "我是老北京正黄旗!" result = tts.infer(target_text) sf.write(os.path.join(os.path.dirname(__file__), "output1.wav"), result[1], result[0]) ``` ## /weclone-audio/src/SparkTTS.py ```py path="/weclone-audio/src/SparkTTS.py" import re import torch from typing import Tuple from pathlib import Path from transformers import AutoTokenizer, AutoModelForCausalLM import os import sys sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "./Spark-TTS"))) from sparktts.utils.file import load_config from sparktts.models.audio_tokenizer import BiCodecTokenizer from sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP class SparkTTS: """ Spark-TTS for text-to-speech generation. """ def __init__(self, model_dir: Path, device: torch.device = torch.device("cuda:0")): """ Initializes the SparkTTS model with the provided configurations and device. Args: model_dir (Path): Directory containing the model and config files. device (torch.device): The device (CPU/GPU) to run the model on. """ self.device = device self.model_dir = model_dir self.configs = load_config(f"{model_dir}/config.yaml") self.sample_rate = self.configs["sample_rate"] self._initialize_inference() def _initialize_inference(self): """Initializes the tokenizer, model, and audio tokenizer for inference.""" self.tokenizer = AutoTokenizer.from_pretrained(f"{self.model_dir}/LLM") self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM") self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device) self.model.to(self.device) def process_prompt( self, text: str, prompt_speech_path: Path, prompt_text: str = None, ) -> Tuple[str, torch.Tensor]: """ Process input for voice cloning. Args: text (str): The text input to be converted to speech. prompt_speech_path (Path): Path to the audio file used as a prompt. prompt_text (str, optional): Transcript of the prompt audio. 
Return: Tuple[str, torch.Tensor]: Input prompt; global tokens """ global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize( prompt_speech_path ) global_tokens = "".join( [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()] ) # Prepare the input tokens for the model if prompt_text is not None: semantic_tokens = "".join( [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()] ) inputs = [ TASK_TOKEN_MAP["tts"], "<|start_content|>", prompt_text, text, "<|end_content|>", "<|start_global_token|>", global_tokens, "<|end_global_token|>", "<|start_semantic_token|>", semantic_tokens, ] else: inputs = [ TASK_TOKEN_MAP["tts"], "<|start_content|>", text, "<|end_content|>", "<|start_global_token|>", global_tokens, "<|end_global_token|>", ] inputs = "".join(inputs) return inputs, global_token_ids def process_prompt_control( self, gender: str, pitch: str, speed: str, text: str, ): """ Process input for voice creation. Args: gender (str): female | male. pitch (str): very_low | low | moderate | high | very_high speed (str): very_low | low | moderate | high | very_high text (str): The text input to be converted to speech. Return: str: Input prompt """ assert gender in GENDER_MAP.keys() assert pitch in LEVELS_MAP.keys() assert speed in LEVELS_MAP.keys() gender_id = GENDER_MAP[gender] pitch_level_id = LEVELS_MAP[pitch] speed_level_id = LEVELS_MAP[speed] pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>" speed_label_tokens = f"<|speed_label_{speed_level_id}|>" gender_tokens = f"<|gender_{gender_id}|>" attribte_tokens = "".join( [gender_tokens, pitch_label_tokens, speed_label_tokens] ) control_tts_inputs = [ TASK_TOKEN_MAP["controllable_tts"], "<|start_content|>", text, "<|end_content|>", "<|start_style_label|>", attribte_tokens, "<|end_style_label|>", ] return "".join(control_tts_inputs) @torch.no_grad() def inference( self, text: str, prompt_speech_path: Path = None, prompt_text: str = None, gender: str = None, pitch: str = None, speed: str = None, temperature: float = 0.8, top_k: float = 50, top_p: float = 0.95, ) -> torch.Tensor: """ Performs inference to generate speech from text, incorporating prompt audio and/or text. Args: text (str): The text input to be converted to speech. prompt_speech_path (Path): Path to the audio file used as a prompt. prompt_text (str, optional): Transcript of the prompt audio. gender (str): female | male. pitch (str): very_low | low | moderate | high | very_high speed (str): very_low | low | moderate | high | very_high temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8. top_k (float, optional): Top-k sampling parameter. Default is 50. top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95. Returns: torch.Tensor: Generated waveform as a tensor. 
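
        Example (illustrative; model and audio paths follow src/infer.py and are assumptions):
            >>> model = SparkTTS("pretrained_models/Spark-TTS-0.5B", torch.device("cuda:0"))
            >>> wav = model.inference(
            ...     "晚上好啊",
            ...     prompt_speech_path="sample.wav",
            ...     prompt_text="对,这就是我万人敬仰的太乙真人",
            ... )
            >>> # wav can then be written out, e.g. soundfile.write("output.wav", wav, samplerate=16000)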
""" if gender is not None: prompt = self.process_prompt_control(gender, pitch, speed, text) else: prompt, global_token_ids = self.process_prompt( text, prompt_speech_path, prompt_text ) model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device) # Generate speech using the model generated_ids = self.model.generate( **model_inputs, max_new_tokens=3000, do_sample=True, top_k=top_k, top_p=top_p, temperature=temperature, ) # Trim the output tokens to remove the input tokens generated_ids = [ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) ] # Decode the generated tokens into text predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] # Extract semantic token IDs from the generated text pred_semantic_ids = ( torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)]) .long() .unsqueeze(0) ) if gender is not None: global_token_ids = ( torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)]) .long() .unsqueeze(0) .unsqueeze(0) ) # Convert semantic tokens back to waveform wav = self.audio_tokenizer.detokenize( global_token_ids.to(self.device).squeeze(0), pred_semantic_ids.to(self.device), ) return wav ``` ## /weclone-audio/src/__init__.py ```py path="/weclone-audio/src/__init__.py" ``` ## /weclone-audio/src/get_sample_audio.py ```py path="/weclone-audio/src/get_sample_audio.py" import os import argparse from pywxdump.db import MediaHandler def main(): parser = argparse.ArgumentParser(description="Extract audio from WeChat database") parser.add_argument("--db-path", type=str, required=True, help="Path to WeChat database file") parser.add_argument("--MsgSvrID", type=str, required=True, help="Message server ID of the audio") parser.add_argument("--save-path", type=str, default=os.path.join(os.path.dirname(__file__), "sample.wav"), help="Path to save the audio file (default: sample.wav in script directory)") parser.add_argument("--rate", type=int, default=24000, help="Sample rate for audio conversion (default: 24000)") args = parser.parse_args() config = { "key": "test1", "type": "sqlite", "path": args.db_path, } t1 = MediaHandler(config) t1.get_audio( MsgSvrID=args.MsgSvrID, is_play=True, is_wave=True, save_path=args.save_path, rate=args.rate, ) if __name__ == "__main__": main() ``` ## /weclone-audio/src/infer.py ```py path="/weclone-audio/src/infer.py" import os import soundfile as sf import torch from SparkTTS import SparkTTS model = SparkTTS("weclone-audio/pretrained_models/Spark-TTS-0.5B", "cuda") with torch.no_grad(): wav = model.inference( text="晚上好啊,小可爱们,该睡觉了哦", prompt_speech_path=os.path.join(os.path.dirname(__file__), "sample.wav"), prompt_text="对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。", ) sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), wav, samplerate=16000) print("生成成功!") ``` ## /weclone-audio/src/sample.wav Binary file available at https://raw.githubusercontent.com/xming521/WeClone/refs/heads/main/weclone-audio/src/sample.wav ## /weclone-audio/src/server未完工/.env.example ```example path="/weclone-audio/src/server未完工/.env.example" API_KEY=your_api_key_here PORT=5050 DEFAULT_VOICE=en-US-AvaNeural DEFAULT_RESPONSE_FORMAT=mp3 DEFAULT_SPEED=1.0 DEFAULT_LANGUAGE=en-US REQUIRE_API_KEY=True REMOVE_FILTER=False EXPAND_API=True ``` ## /weclone-audio/src/server未完工/handle_text.py ```py path="/weclone-audio/src/server未完工/handle_text.py" import re import emoji def prepare_tts_input_with_context(text: str) -> str: """ Prepares text for a TTS 
    API by cleaning Markdown and adding minimal contextual hints for certain
    Markdown elements like headers. Preserves paragraph separation.

    Args:
        text (str): The raw text containing Markdown or other formatting.

    Returns:
        str: Cleaned text with contextual hints suitable for TTS input.
    """
    # Remove emojis
    text = emoji.replace_emoji(text, replace='')

    # Add context for headers
    def header_replacer(match):
        level = len(match.group(1))  # Number of '#' symbols
        header_text = match.group(2).strip()
        if level == 1:
            return f"Title — {header_text}\n"
        elif level == 2:
            return f"Section — {header_text}\n"
        else:
            return f"Subsection — {header_text}\n"

    text = re.sub(r"^(#{1,6})\s+(.*)", header_replacer, text, flags=re.MULTILINE)

    # Announce links (currently commented out for potential future use)
    # text = re.sub(r"\[([^\]]+)\]\((https?:\/\/[^\)]+)\)", r"\1 (link: \2)", text)

    # Remove links while keeping the link text
    text = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", text)

    # Describe inline code
    text = re.sub(r"`([^`]+)`", r"code snippet: \1", text)

    # Remove bold/italic symbols but keep the content
    text = re.sub(r"(\*\*|__|\*|_)", '', text)

    # Remove code blocks (multi-line) with a description
    text = re.sub(r"\`\`\`([\s\S]+?)\`\`\`", r"(code block omitted)", text)

    # Remove image syntax but add alt text if available
    text = re.sub(r"!\[([^\]]*)\]\([^\)]+\)", r"Image: \1", text)

    # Remove HTML tags
    text = re.sub(r"<[^>]+(>|$)", '', text)

    # Normalize line breaks
    text = re.sub(r"\n{2,}", '\n\n', text)  # Ensure consistent paragraph separation

    # Replace multiple spaces within lines
    text = re.sub(r" {2,}", ' ', text)

    # Trim leading and trailing whitespace from the whole text
    text = text.strip()

    return text
```

## /weclone-audio/src/server未完工/requirements.txt

flask
gevent
python-dotenv
edge-tts
emoji

## /weclone-audio/src/server未完工/server.py

```py path="/weclone-audio/src/server未完工/server.py"
# server.py
from flask import Flask, request, send_file, jsonify
from gevent.pywsgi import WSGIServer
from dotenv import load_dotenv
import os
from handle_text import prepare_tts_input_with_context
from tts_handler import generate_speech, get_models, get_voices
from utils import getenv_bool, require_api_key, AUDIO_FORMAT_MIME_TYPES

app = Flask(__name__)
load_dotenv()

API_KEY = os.getenv('API_KEY', 'your_api_key_here')
PORT = int(os.getenv('PORT', 5050))

DEFAULT_VOICE = os.getenv('DEFAULT_VOICE', 'en-US-AvaNeural')
DEFAULT_RESPONSE_FORMAT = os.getenv('DEFAULT_RESPONSE_FORMAT', 'mp3')
DEFAULT_SPEED = float(os.getenv('DEFAULT_SPEED', 1.0))

REMOVE_FILTER = getenv_bool('REMOVE_FILTER', False)
EXPAND_API = getenv_bool('EXPAND_API', True)

# DEFAULT_MODEL = os.getenv('DEFAULT_MODEL', 'tts-1')


@app.route('/v1/audio/speech', methods=['POST'])
@app.route('/audio/speech', methods=['POST'])  # Add this line for the alias
@require_api_key
def text_to_speech():
    data = request.json
    if not data or 'input' not in data:
        return jsonify({"error": "Missing 'input' in request body"}), 400

    text = data.get('input')
    if not REMOVE_FILTER:
        text = prepare_tts_input_with_context(text)

    # model = data.get('model', DEFAULT_MODEL)
    voice = data.get('voice', DEFAULT_VOICE)
    response_format = data.get('response_format', DEFAULT_RESPONSE_FORMAT)
    speed = float(data.get('speed', DEFAULT_SPEED))

    mime_type = AUDIO_FORMAT_MIME_TYPES.get(response_format, "audio/mpeg")

    # Generate the audio file in the specified format with speed adjustment
    output_file_path = generate_speech(text, voice, response_format, speed)

    # Return the file with the correct MIME type
    return send_file(
        output_file_path,
        mimetype=mime_type,
        as_attachment=True,
        download_name=f"speech.{response_format}",
    )


@app.route('/v1/models', methods=['GET', 'POST'])
@app.route('/models', methods=['GET', 'POST'])
@require_api_key
def list_models():
    return jsonify({"data": get_models()})


@app.route('/v1/voices', methods=['GET', 'POST'])
@app.route('/voices', methods=['GET', 'POST'])
@require_api_key
def list_voices():
    specific_language = None
    data = request.args if request.method == 'GET' else request.json
    if data and ('language' in data or 'locale' in data):
        specific_language = data.get('language') if 'language' in data else data.get('locale')
    return jsonify({"voices": get_voices(specific_language)})


@app.route('/v1/voices/all', methods=['GET', 'POST'])
@app.route('/voices/all', methods=['GET', 'POST'])
@require_api_key
def list_all_voices():
    return jsonify({"voices": get_voices('all')})


"""
Support for ElevenLabs and Azure AI Speech (currently in beta)
"""


# http://localhost:5050/elevenlabs/v1/text-to-speech
# http://localhost:5050/elevenlabs/v1/text-to-speech/en-US-AndrewNeural
@app.route('/elevenlabs/v1/text-to-speech/<voice_id>', methods=['POST'])
@require_api_key
def elevenlabs_tts(voice_id):
    if not EXPAND_API:
        return jsonify({"error": "Endpoint not allowed"}), 500

    # Parse the incoming JSON payload
    try:
        payload = request.json
        if not payload or 'text' not in payload:
            return jsonify({"error": "Missing 'text' in request body"}), 400
    except Exception as e:
        return jsonify({"error": f"Invalid JSON payload: {str(e)}"}), 400

    text = payload['text']
    if not REMOVE_FILTER:
        text = prepare_tts_input_with_context(text)
    voice = voice_id  # ElevenLabs uses the voice_id in the URL

    # Use default settings for edge-tts
    response_format = 'mp3'
    speed = DEFAULT_SPEED  # Optional customization via payload.get('speed', DEFAULT_SPEED)

    # Generate speech using edge-tts
    try:
        output_file_path = generate_speech(text, voice, response_format, speed)
    except Exception as e:
        return jsonify({"error": f"TTS generation failed: {str(e)}"}), 500

    # Return the generated audio file
    return send_file(output_file_path, mimetype="audio/mpeg", as_attachment=True, download_name="speech.mp3")


# tts.speech.microsoft.com/cognitiveservices/v1
# https://{region}.tts.speech.microsoft.com/cognitiveservices/v1
# http://localhost:5050/azure/cognitiveservices/v1
@app.route('/azure/cognitiveservices/v1', methods=['POST'])
@require_api_key
def azure_tts():
    if not EXPAND_API:
        return jsonify({"error": "Endpoint not allowed"}), 500

    # Parse the SSML payload
    try:
        ssml_data = request.data.decode('utf-8')
        if not ssml_data:
            return jsonify({"error": "Missing SSML payload"}), 400

        # Extract the text and voice from SSML
        from xml.etree import ElementTree as ET

        root = ET.fromstring(ssml_data)
        text = root.find('.//{http://www.w3.org/2001/10/synthesis}voice').text
        voice = root.find('.//{http://www.w3.org/2001/10/synthesis}voice').get('name')
    except Exception as e:
        return jsonify({"error": f"Invalid SSML payload: {str(e)}"}), 400

    # Use default settings for edge-tts
    response_format = 'mp3'
    speed = DEFAULT_SPEED
    if not REMOVE_FILTER:
        text = prepare_tts_input_with_context(text)

    # Generate speech using edge-tts
    try:
        output_file_path = generate_speech(text, voice, response_format, speed)
    except Exception as e:
        return jsonify({"error": f"TTS generation failed: {str(e)}"}), 500

    # Return the generated audio file
    return send_file(output_file_path, mimetype="audio/mpeg", as_attachment=True, download_name="speech.mp3")


print(f" Edge TTS (Free Azure TTS) Replacement for OpenAI's TTS API")
print(f" ")
print(f" *
Serving OpenAI Edge TTS") print(f" * Server running on http://localhost:{PORT}") print(f" * TTS Endpoint: http://localhost:{PORT}/v1/audio/speech") print(f" ") if __name__ == '__main__': http_server = WSGIServer(('0.0.0.0', PORT), app) http_server.serve_forever() ``` ## /weclone-audio/src/server未完工/tts_handler.py ```py path="/weclone-audio/src/server未完工/tts_handler.py" import edge_tts import asyncio import tempfile import subprocess import os from pathlib import Path # Language default (environment variable) DEFAULT_LANGUAGE = os.getenv('DEFAULT_LANGUAGE', 'en-US') # OpenAI voice names mapped to edge-tts equivalents voice_mapping = { 'alloy': 'en-US-AvaNeural', 'echo': 'en-US-AndrewNeural', 'fable': 'en-GB-SoniaNeural', 'onyx': 'en-US-EricNeural', 'nova': 'en-US-SteffanNeural', 'shimmer': 'en-US-EmmaNeural' } def is_ffmpeg_installed(): """Check if FFmpeg is installed and accessible.""" try: subprocess.run(['ffmpeg', '-version'], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) return True except (subprocess.CalledProcessError, FileNotFoundError): return False async def _generate_audio(text, voice, response_format, speed): """Generate TTS audio and optionally convert to a different format.""" # Determine if the voice is an OpenAI-compatible voice or a direct edge-tts voice edge_tts_voice = voice_mapping.get(voice, voice) # Use mapping if in OpenAI names, otherwise use as-is # Generate the TTS output in mp3 format first temp_output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") # Convert speed to SSML rate format try: speed_rate = speed_to_rate(speed) # Convert speed value to "+X%" or "-X%" except Exception as e: print(f"Error converting speed: {e}. Defaulting to +0%.") speed_rate = "+0%" # Generate the MP3 file communicator = edge_tts.Communicate(text=text, voice=edge_tts_voice, rate=speed_rate) await communicator.save(temp_output_file.name) # If the requested format is mp3, return the generated file directly if response_format == "mp3": return temp_output_file.name # Check if FFmpeg is installed if not is_ffmpeg_installed(): print("FFmpeg is not available. 
Returning unmodified mp3 file.") return temp_output_file.name # Create a new temporary file for the converted output converted_output_file = tempfile.NamedTemporaryFile(delete=False, suffix=f".{response_format}") # Build the FFmpeg command ffmpeg_command = [ "ffmpeg", "-i", temp_output_file.name, # Input file "-c:a", { "aac": "aac", "mp3": "libmp3lame", "wav": "pcm_s16le", "opus": "libopus", "flac": "flac" }.get(response_format, "aac"), # Default to AAC if unknown "-b:a", "192k" if response_format != "wav" else None, # Bitrate not needed for WAV "-f", { "aac": "mp4", # AAC in MP4 container "mp3": "mp3", "wav": "wav", "opus": "ogg", "flac": "flac" }.get(response_format, response_format), # Default to matching format "-y", # Overwrite without prompt converted_output_file.name # Output file ] try: # Run FFmpeg command and ensure no errors occur subprocess.run(ffmpeg_command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) except subprocess.CalledProcessError as e: raise RuntimeError(f"FFmpeg error during audio conversion: {e}") # Clean up the original temporary file Path(temp_output_file.name).unlink(missing_ok=True) return converted_output_file.name def generate_speech(text, voice, response_format, speed=1.0): return asyncio.run(_generate_audio(text, voice, response_format, speed)) def get_models(): return [ {"id": "tts-1", "name": "Text-to-speech v1"}, {"id": "tts-1-hd", "name": "Text-to-speech v1 HD"} ] async def _get_voices(language=None): # List all voices, filter by language if specified all_voices = await edge_tts.list_voices() language = language or DEFAULT_LANGUAGE # Use default if no language specified filtered_voices = [ {"name": v['ShortName'], "gender": v['Gender'], "language": v['Locale']} for v in all_voices if language == 'all' or language is None or v['Locale'] == language ] return filtered_voices def get_voices(language=None): return asyncio.run(_get_voices(language)) def speed_to_rate(speed: float) -> str: """ Converts a multiplicative speed value to the edge-tts "rate" format. Args: speed (float): The multiplicative speed value (e.g., 1.5 for +50%, 0.5 for -50%). Returns: str: The formatted "rate" string (e.g., "+50%" or "-50%"). 
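
    Examples (illustrative):
        >>> speed_to_rate(1.5)
        '+50%'
        >>> speed_to_rate(0.75)
        '-25%'
        >>> speed_to_rate(1.0)
        '+0%'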
""" if speed < 0 or speed > 2: raise ValueError("Speed must be between 0 and 2 (inclusive).") # Convert speed to percentage change percentage_change = (speed - 1) * 100 # Format with a leading "+" or "-" as required return f"{percentage_change:+.0f}%" ``` ## /weclone-audio/src/server未完工/utils.py ```py path="/weclone-audio/src/server未完工/utils.py" # utils.py from flask import request, jsonify from functools import wraps import os from dotenv import load_dotenv load_dotenv() def getenv_bool(name: str, default: bool = False) -> bool: return os.getenv(name, str(default)).lower() in ("yes", "y", "true", "1", "t") API_KEY = os.getenv('API_KEY', 'your_api_key_here') REQUIRE_API_KEY = getenv_bool('REQUIRE_API_KEY', True) def require_api_key(f): @wraps(f) def decorated_function(*args, **kwargs): if not REQUIRE_API_KEY: return f(*args, **kwargs) auth_header = request.headers.get('Authorization') if not auth_header or not auth_header.startswith('Bearer '): return jsonify({"error": "Missing or invalid API key"}), 401 token = auth_header.split('Bearer ')[1] if token != API_KEY: return jsonify({"error": "Invalid API key"}), 401 return f(*args, **kwargs) return decorated_function # Mapping of audio format to MIME type AUDIO_FORMAT_MIME_TYPES = { "mp3": "audio/mpeg", "opus": "audio/ogg", "aac": "audio/aac", "flac": "audio/flac", "wav": "audio/wav", "pcm": "audio/L16" } ``` ## /weclone/__init__.py ```py path="/weclone/__init__.py" ``` ## /weclone/cli.py ```py path="/weclone/cli.py" from weclone.data.qa_generator import DataProcessor ``` ## /weclone/data/__init__.py ```py path="/weclone/data/__init__.py" ``` ## /weclone/data/models.py ```py path="/weclone/data/models.py" from dataclasses import dataclass from pandas import Timestamp @dataclass class ChatMessage: id: int MsgSvrID: int type_name: str is_sender: int talker: str room_name: str msg: str src: str CreateTime: Timestamp @dataclass class CutMessage: is_sender: int cut_type: str CreateTime: Timestamp skip_type_list = [ "添加好友", "推荐公众号", "动画表情", "位置", "文件", "位置共享", "接龙", "引用回复", "视频号直播或直播回放", "用户上传的GIF表情", "文件(猜)", "群公告", "视频号直播或直播回放等", "游戏相关", "转账", "赠送红包封面", "语音通话", "企业微信打招呼(猜)", "企业微信添加好友(猜)", "系统通知", "消息撤回1", "拍一拍", "消息撤回5", "消息撤回6", "消息撤回33", "消息撤回36", "消息撤回57", "邀请加群", "未知-11000,0", ] # 没处理的类型 unprocessed_type_list = [] ``` ## /weclone/data/qa_generator.py ```py path="/weclone/data/qa_generator.py" import os from typing import Dict, List import re import pandas as pd import json from weclone.utils.config import load_config from weclone.utils.log import logger from weclone.data.models import ChatMessage, CutMessage, skip_type_list from weclone.data.strategies import TimeWindowStrategy, LLMStrategy from weclone.utils.length_cdf import length_cdf class DataProcessor: def __init__(self): self.config = load_config(arg_type="make_dataset") self.csv_folder = "./dataset/csv" self.system_prompt = self.config["default_system"] self.cut_type_list = [ "图片", "视频", "合并转发的聊天记录", "语音", "(分享)音乐", "(分享)卡片式链接", "(分享)笔记", "(分享)小程序", "(分享)收藏夹", "(分享)小说(猜)", "(分享)视频号名片", "(分享)视频号视频", "粘贴的文本", # 无法解析的分享链接 ] if self.config["single_combine_strategy"] == "time_window": self.single_combine_strategy = TimeWindowStrategy( time_window=self.config["single_combine_time_window"] * 60, is_single_chat=True, ) elif self.config["single_combine_strategy"] == "llm": self.single_combine_strategy = LLMStrategy( is_single_chat=True, ) if self.config["qa_match_strategy"] == "time_window": self.qa_match_strategy = TimeWindowStrategy( time_window=self.config["qa_match_time_window"] * 60, 
is_single_chat=False, ) elif self.config["qa_match_strategy"] == "llm": self.qa_match_strategy = LLMStrategy(is_single_chat=False) self.c = self.config def main(self): if not os.path.exists(self.csv_folder) or not os.listdir(self.csv_folder): logger.error(f"错误:目录 '{self.csv_folder}' 不存在或为空,请检查路径并确保其中包含 CSV 聊天数据文件。") return csv_files = self.get_csv_files() message_list: List[ChatMessage] = [] for csv_file in csv_files: chat_messages = self.load_csv(csv_file) message_list.extend(self.group_consecutive_messages(messages=chat_messages)) # self.process_by_msgtype(chat_message) qa_res = self.match_qa(message_list) if self.c["prompt_with_history"]: qa_res = self.add_history_to_qa(qa_res) self.save_result(qa_res) length_cdf( model_name_or_path=self.c["model_name_or_path"], dataset=self.c["dataset"], dataset_dir=self.c["dataset_dir"], template=self.c["template"], interval=self.c["cutoff_len"], ) def get_csv_files(self): """遍历文件夹获取所有CSV文件路径""" csv_files = [] for chat_obj_folder in os.listdir(self.csv_folder): chat_obj_folder_path = os.path.join(self.csv_folder, chat_obj_folder) for csvfile in os.listdir(chat_obj_folder_path): if not csvfile.endswith(".csv"): continue csvfile_path = os.path.join(chat_obj_folder_path, csvfile) csv_files.append(csvfile_path) return csv_files def match_qa(self, messages: List[ChatMessage]) -> List[Dict]: """ 匹配问答对 Args: messages: 消息列表 Returns: List[Dict]: 包含指令和输出的问答对列表 """ # 状态定义 WAITING_INSTRUCTION = "waiting_instruction" # 等待指令 WAITING_RESPONSE = "waiting_response" # 等待回复 current_state = WAITING_INSTRUCTION qa_res = [] last_message = None current_instruction = None for msg in messages: # 检查是否为CutMessage if isinstance(msg, CutMessage): current_state = WAITING_INSTRUCTION current_instruction = None last_message = None if self.c["prompt_with_history"]: qa_res.append(msg) continue if current_state == WAITING_INSTRUCTION: if msg.is_sender == 0: # 收到对方消息 current_instruction = msg.msg last_message = msg current_state = WAITING_RESPONSE elif current_state == WAITING_RESPONSE: if msg.is_sender == 0: # 收到对方消息 current_instruction = msg.msg last_message = msg # 状态保持不变 else: # 自己的回复 使用策略判断是否属于同一对话 if last_message and self.qa_match_strategy.is_same_conversation([last_message], msg): qa_res.append( {"instruction": current_instruction, "output": msg.msg, "system": self.system_prompt} ) else: if self.c["prompt_with_history"]: qa_res.append( CutMessage( is_sender=msg.is_sender, cut_type=msg.type_name, CreateTime=msg.CreateTime, ) ) # 无论是否匹配,都重置状态 current_state = WAITING_INSTRUCTION current_instruction = None last_message = None return qa_res def add_history_to_qa(self, qa_res: List[Dict]) -> List[Dict]: qa_res_with_history = [] last_res = {"instruction": "", "output": "", "history": [], "system": self.system_prompt} for _, qa in enumerate(qa_res): if isinstance(qa, CutMessage): if len(last_res["history"]) == 0: continue else: if len(last_res["history"]) == 1: last_res = { "system": self.system_prompt, "instruction": last_res["history"][0][0], "output": last_res["history"][0][1], "history": [], } else: last_res = { "system": self.system_prompt, "instruction": last_res["history"][-1][0], "output": last_res["history"][-1][1], "history": last_res["history"][:-1], } qa_res_with_history.append(last_res) last_res = {"instruction": "", "output": "", "history": [], "system": self.system_prompt} else: last_res["history"].append([qa["instruction"], qa["output"]]) return qa_res_with_history def group_consecutive_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]: """ 
将同一个人连续发送的多条消息组合成一条消息,遇到cut_type添加cut Args: messages: 消息列表 Returns: List[ChatMessage]: 组合后的消息列表 """ if not messages: return [] def _combine_text(messages: List[ChatMessage]) -> ChatMessage: """ 合并多条消息为一条 Args: messages: 要合并的消息列表 Returns: ChatMessage: 合并后的消息 """ base_msg = messages[0] combined_content = messages[0].msg for i in messages[1:]: content = i.msg if not content: continue if combined_content and combined_content[-1] not in ["。", "!", "?", "…", ",", "."]: combined_content += "," combined_content += content if len(combined_content) > self.c["combine_msg_max_length"]: logger.warning(f"组合后消息长度超过{self.c['combine_msg_max_length']}将截断:\n {combined_content}") combined_content = combined_content[: self.c["combine_msg_max_length"]] combined_message = ChatMessage( id=base_msg.id, MsgSvrID=base_msg.MsgSvrID, type_name=base_msg.type_name, is_sender=base_msg.is_sender, talker=base_msg.talker, room_name=base_msg.room_name, msg=combined_content, src=base_msg.src, CreateTime=messages[-1].CreateTime, # 使用最后一条消息的时间 ) return combined_message def _create_cut_message(message: ChatMessage) -> CutMessage: return CutMessage( is_sender=message.is_sender, cut_type=message.type_name, CreateTime=message.CreateTime, ) def _combine_current_group(group): """ 处理当前消息组并添加到grouped_messages Args: group: 当前消息组 """ if len(group) > 1: combined_msg = _combine_text(group) grouped_messages.append(combined_msg) else: grouped_messages.append(group[0]) grouped_messages = [] current_group = [] for _, current_msg in enumerate(messages): if current_msg.type_name in self.cut_type_list: if current_group: # 当前组有消息,合并当前组,并添加一条cut _combine_current_group(current_group) current_group = [] cut_msg = _create_cut_message(current_msg) grouped_messages.append(cut_msg) else: # 当前组没消息,检查上一个组 if grouped_messages: if not isinstance(grouped_messages[-1], CutMessage): cut_msg = _create_cut_message(current_msg) grouped_messages.append(cut_msg) # 如果上一个组没消息或最后一条是CutMessage,直接continue continue if not current_group: current_group = [current_msg] continue last_msg = current_group[-1] # 判断是否是同一个人的连续消息 if ( current_msg.is_sender == last_msg.is_sender and current_msg.talker == last_msg.talker and self.single_combine_strategy.is_same_conversation([last_msg], current_msg) ): current_group.append(current_msg) else: # 不是同一个人的消息,处理当前组并开始新组 _combine_current_group(current_group) # 开始新组 current_group = [current_msg] # 处理最后一组消息 if current_group: _combine_current_group(current_group) return grouped_messages def process_by_msgtype(self, chat_message: ChatMessage): if chat_message.type_name == "文本": self.process_text(chat_message) # elif chat_message.type_name == "图片": # self.process_image(chat_message) def load_csv(self, file_path) -> List[ChatMessage]: """ 做整体第一次预处理,过滤不符合条件的行 """ df = pd.read_csv(file_path, encoding="utf-8", dtype={"msg": str}) blocked_words = json.load(open("./dataset/blocked_words.json", encoding="utf-8"))["blocked_words"] df = df[~df["type_name"].isin(values=skip_type_list)] # 如果type_name为文本 并且msg 包含 手机号、身份证号、邮箱、网址则删除这行 for i in df.index: if df.loc[i, "type_name"] == "文本": msg_str = str(df.loc[i, "msg"]) if ( re.search(r"1\d{10}", msg_str) or re.search(r"\d{18}", msg_str) or re.search(r"\w+@\w+", msg_str) or "http" in msg_str or r"\\xa0" in msg_str or r"\\u" in msg_str ): df = df.drop(index=i) continue for blocked_word in blocked_words: if blocked_word in msg_str: df = df.drop(index=i) break else: df.loc[i, "msg"] = "" df = df.dropna(how="all") # 时间格式 2021-07-07 10:27:23 # 遍历行 相同is_sender的行合并msg()遇到不同is_sender就重新开始 df["CreateTime"] = 
pd.to_datetime(df["CreateTime"]) return [ChatMessage(*row) for row in df.values] def process_text(self, chat_message: ChatMessage): pass def save_result(self, qa_res: List[Dict]): # 保存结果 with open( "./dataset/res_csv/sft/sft-my.json", "w", encoding="utf-8", ) as f: json.dump(qa_res, f, ensure_ascii=False) logger.success(f"聊天记录处理成功,共{len(qa_res)}条,保存到 {f.name}") if __name__ == "__main__": processor = DataProcessor() processor.main() ``` ## /weclone/data/strategies.py ```py path="/weclone/data/strategies.py" from dataclasses import dataclass from typing import List from .models import ChatMessage from abc import ABC, abstractmethod @dataclass class ConversationStrategy(ABC): """对话策略的抽象基类""" is_single_chat: bool @abstractmethod def is_same_conversation( self, history_msg: List[ChatMessage], current_msg: ChatMessage ) -> bool: """判断两条消息是否属于同一个对话""" pass @dataclass class TimeWindowStrategy(ConversationStrategy): """基于时间窗口的判断策略""" time_window: int # 时间窗口(分钟) def is_same_conversation( self, history_msg: List[ChatMessage], current_msg: ChatMessage ) -> bool: time_diff = abs( (current_msg.CreateTime - history_msg[-1].CreateTime) ).total_seconds() return time_diff <= self.time_window @dataclass class LLMStrategy(ConversationStrategy): """基于大模型判断策略""" def is_same_conversation( self, history_msg: List[ChatMessage], current_msg: ChatMessage ) -> bool: # 修复user_id错误,使用talker字段代替user_id return current_msg.talker == history_msg[-1].talker if history_msg else False @dataclass class CompositeStrategy(ConversationStrategy): """组合多个策略的复合策略""" strategies: List[ConversationStrategy] require_all: bool = True # True表示所有策略都满足,False表示任一策略满足即可 def is_same_conversation( self, history_msg: List[ChatMessage], current_msg: ChatMessage ) -> bool: results = [ s.is_same_conversation(history_msg, current_msg) for s in self.strategies ] return all(results) if self.require_all else any(results) ``` ## /weclone/eval/__init__.py ```py path="/weclone/eval/__init__.py" ``` ## /weclone/eval/cli_demo.py ```py path="/weclone/eval/cli_demo.py" from llamafactory.chat import ChatModel from llamafactory.extras.misc import torch_gc try: import platform if platform.system() != "Windows": import readline # noqa: F401 except ImportError: print("Install `readline` for a better experience.") def main(): chat_model = ChatModel() messages = [] print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.") while True: try: query = input("\nUser: ") except UnicodeDecodeError: print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.") continue except Exception: raise if query.strip() == "exit": break if query.strip() == "clear": messages = [] torch_gc() print("History has been removed.") continue messages.append({"role": "user", "content": query}) print("Assistant: ", end="", flush=True) response = "" for new_text in chat_model.stream_chat(messages): print(new_text, end="", flush=True) response += new_text print() messages.append({"role": "assistant", "content": response}) if __name__ == "__main__": main() ``` ## /weclone/eval/evaluate.py ```py path="/weclone/eval/evaluate.py" from llamafactory.eval.evaluator import Evaluator def main(): evaluator = Evaluator() evaluator.eval() if __name__ == "__main__": main() ``` ## /weclone/eval/test_model.py ```py path="/weclone/eval/test_model.py" import json import openai from tqdm import tqdm from typing import List, Dict from weclone.utils.config import load_config config = load_config("web_demo") config = { "default_prompt": 
config["default_system"], "model": "gpt-3.5-turbo", "history_len": 15, } config = type("Config", (object,), config)() openai.api_key = """sk-test""" openai.api_base = "http://127.0.0.1:8005/v1" def handler_text(content: str, history: List[Dict[str, str]], config): messages = [{"role": "system", "content": f"{config.default_prompt}"}] for item in history: messages.append(item) messages.append({"role": "user", "content": content}) history.append({"role": "user", "content": content}) try: response = openai.ChatCompletion.create(model=config.model, messages=messages, max_tokens=50) except openai.APIError as e: history.pop() return "AI接口出错,请重试\n" + str(e) resp = str(response.choices[0].message.content) # type: ignore resp = resp.replace("\n ", "") history.append({"role": "assistant", "content": resp}) return resp def main(): test_list = json.loads(open("dataset/test_data.json", "r", encoding="utf-8").read())["questions"] res = [] for questions in tqdm(test_list, desc=" Testing..."): history = [] for q in questions: handler_text(q, history=history, config=config) res.append(history) res_file = open("test_result-my.txt", "w") for r in res: for i in r: res_file.write(i["content"] + "\n") res_file.write("\n") if __name__ == "__main__": main() ``` ## /weclone/eval/web_demo.py ```py path="/weclone/eval/web_demo.py" from llamafactory.webui.interface import create_web_demo from weclone.utils.config import load_config config = load_config("web_demo") def main(): demo = create_web_demo() demo.queue() demo.launch(server_name="0.0.0.0", share=True, inbrowser=True) if __name__ == "__main__": main() ``` ## /weclone/server/__init__.py ```py path="/weclone/server/__init__.py" ``` ## /weclone/server/api_service.py ```py path="/weclone/server/api_service.py" import os import uvicorn from llamafactory.chat import ChatModel from llamafactory.api.app import create_app from weclone.utils.config import load_config config = load_config("api_service") def main(): chat_model = ChatModel(config) app = create_app(chat_model) print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8005))) uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8005)), workers=1) if __name__ == "__main__": main() ``` ## /weclone/train/__init__.py ```py path="/weclone/train/__init__.py" ``` ## /weclone/train/export_model.py ```py path="/weclone/train/export_model.py" from llamafactory.train.tuner import export_model def main(): export_model() if __name__ == "__main__": main() ``` ## /weclone/train/train_pt.py ```py path="/weclone/train/train_pt.py" from llamafactory.train.tuner import run_exp from weclone.utils.config import load_config config = load_config("train_pt") run_exp(config) ``` ## /weclone/train/train_sft.py ```py path="/weclone/train/train_sft.py" import os import sys from llamafactory.train.tuner import run_exp from llamafactory.extras.misc import get_current_device from weclone.utils.config import load_config from weclone.utils.log import logger def main(): config = load_config(arg_type="train_sft") device = get_current_device() if device == "cpu": logger.warning("请注意你正在使用CPU训练,非Mac设备可能会出现问题") sft_json_path = os.path.join(config["dataset_dir"], "sft-my.json") if not os.path.exists(sft_json_path): logger.error(f"错误:文件 '{sft_json_path}' 不存在,请确保数据处理步骤已正确生成该文件。") sys.exit(1) run_exp(config) if __name__ == "__main__": main() ``` ## /weclone/utils/__init__.py ```py path="/weclone/utils/__init__.py" ``` ## /weclone/utils/config.py ```py path="/weclone/utils/config.py" import os import 
## /weclone/train/__init__.py

```py path="/weclone/train/__init__.py"
```

## /weclone/train/export_model.py

```py path="/weclone/train/export_model.py"
from llamafactory.train.tuner import export_model


def main():
    export_model()


if __name__ == "__main__":
    main()
```

## /weclone/train/train_pt.py

```py path="/weclone/train/train_pt.py"
from llamafactory.train.tuner import run_exp

from weclone.utils.config import load_config

config = load_config("train_pt")

run_exp(config)
```

## /weclone/train/train_sft.py

```py path="/weclone/train/train_sft.py"
import os
import sys

from llamafactory.extras.misc import get_current_device
from llamafactory.train.tuner import run_exp

from weclone.utils.config import load_config
from weclone.utils.log import logger


def main():
    config = load_config(arg_type="train_sft")

    device = get_current_device()
    if device == "cpu":
        logger.warning("Note: you are training on CPU; devices other than Mac may run into problems")

    sft_json_path = os.path.join(config["dataset_dir"], "sft-my.json")
    if not os.path.exists(sft_json_path):
        logger.error(f"Error: file '{sft_json_path}' does not exist; make sure the data processing step generated it correctly.")
        sys.exit(1)

    run_exp(config)


if __name__ == "__main__":
    main()
```

## /weclone/utils/__init__.py

```py path="/weclone/utils/__init__.py"
```

## /weclone/utils/config.py

```py path="/weclone/utils/config.py"
import os
import sys

import commentjson

from .log import logger
from .tools import dict_to_argv


def load_config(arg_type: str):
    config_path = os.environ.get("WECLONE_CONFIG_PATH", "./settings.json")
    logger.info(f"Loading configuration from: {config_path}")  # Add logging to see which file is loaded
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            s_config: dict = commentjson.load(f)
    except FileNotFoundError:
        logger.error(f"Configuration file not found: {config_path}")
        sys.exit(1)  # Exit if config file is not found
    except Exception as e:
        logger.error(f"Error loading configuration file {config_path}: {e}")
        sys.exit(1)

    if arg_type == "web_demo" or arg_type == "api_service":
        # Merge infer_args with common_args
        config = {**s_config["infer_args"], **s_config["common_args"]}
    elif arg_type == "train_pt":
        config = {**s_config["train_pt_args"], **s_config["common_args"]}
    elif arg_type == "train_sft":
        config = {**s_config["train_sft_args"], **s_config["common_args"]}
        if s_config["make_dataset_args"]["prompt_with_history"]:
            dataset_info_path = os.path.join(config["dataset_dir"], "dataset_info.json")
            dataset_info = commentjson.load(open(dataset_info_path, "r", encoding="utf-8"))[config["dataset"]]
            if dataset_info["columns"].get("history") is None:
                logger.warning(f"The {config['dataset']} dataset has no history column, trying the wechat-sft-with-history dataset instead")
                config["dataset"] = "wechat-sft-with-history"
    elif arg_type == "make_dataset":
        config = {**s_config["make_dataset_args"], **s_config["common_args"]}
        config["dataset"] = s_config["train_sft_args"]["dataset"]
        config["dataset_dir"] = s_config["train_sft_args"]["dataset_dir"]
        config["cutoff_len"] = s_config["train_sft_args"]["cutoff_len"]
    else:
        raise ValueError("Unsupported argument type")

    if "train" in arg_type:
        config["output_dir"] = config["adapter_name_or_path"]
        config.pop("adapter_name_or_path")
        config["do_train"] = True

    sys.argv += dict_to_argv(config)

    return config
```
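`load_config` expects `settings.json` at the repository root (or wherever `WECLONE_CONFIG_PATH` points) to group its options into the sections merged above. A heavily abridged sketch of the shape it reads; every value here is invented for illustration, and the real file carries many more llamafactory options:

```jsonc
{
  // Shared by every command (all values illustrative)
  "common_args": {
    "model_name_or_path": "path_to_base_model",
    "adapter_name_or_path": "model_output",
    "template": "default"
  },
  "train_pt_args": {},
  "train_sft_args": {
    "dataset": "wechat-sft",
    "dataset_dir": "./dataset/res_csv/sft",
    "cutoff_len": 256
  },
  "infer_args": {
    "default_system": "an illustrative system prompt"
  },
  "make_dataset_args": {
    "prompt_with_history": false
  }
}
```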
## /weclone/utils/length_cdf.py

```py path="/weclone/utils/length_cdf.py"
from collections import defaultdict

from tqdm import tqdm

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer

from weclone.utils.log import logger


def length_cdf(
    model_name_or_path: str,
    dataset: str = "alpaca_en_demo",
    dataset_dir: str = "data",
    template: str = "default",
    interval: int = 1000,
):
    r"""Calculate the distribution of the input lengths in the dataset.

    Usage:
        export CUDA_VISIBLE_DEVICES=0
        python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
    """
    model_args, data_args, training_args, _, _ = get_train_args(
        dict(
            stage="sft",
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=1_000_000,
            preprocessing_num_workers=16,
            output_dir="dummy_dir",
            overwrite_cache=True,
            do_train=True,
        )
    )
    tokenizer_module = load_tokenizer(model_args)
    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)  # type: ignore
    trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"]  # type: ignore
    total_num = len(trainset)  # type: ignore
    length_dict = defaultdict(int)
    for sample in tqdm(trainset["input_ids"], desc="Collecting lengths"):  # type: ignore
        length_dict[len(sample) // interval * interval] += 1

    length_tuples = list(length_dict.items())
    length_tuples.sort()
    count_accu, prob_accu = 0, 0
    logger.info("Suggested cutoff_len settings:")
    for length, count in length_tuples:
        count_accu += count
        prob_accu += count / total_num * 100
        logger.info(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")
```

## /weclone/utils/log.py

```py path="/weclone/utils/log.py"
import sys

from loguru import logger

logger.remove()
logger.add(
    sys.stderr,
    format="[WeClone] {level.name[0]} | {time:HH:mm:ss} | {message}",
    colorize=True,
)
```

## /weclone/utils/tools.py

```py path="/weclone/utils/tools.py"
def dict_to_argv(d):
    """Flatten a config dict into a CLI-style argv list."""
    argv = []
    for k, v in d.items():
        argv.append("--" + k)
        if v is not None:
            argv.append(str(v))
    return argv
```
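`dict_to_argv` is what lets `load_config` hand the merged settings to llamafactory's argument parser by appending them to `sys.argv`. A small illustration of the mapping; the keys and values are made up:

```py
from weclone.utils.tools import dict_to_argv

# None values become bare flags; everything else is stringified.
print(dict_to_argv({"model_name_or_path": "path_to_model", "do_train": True, "resume": None}))
# -> ['--model_name_or_path', 'path_to_model', '--do_train', 'True', '--resume']
```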