项目简介
本项目基于Python编程语言和PyTorch深度学习框架,聚焦于对ChatGLM-6B模型进行微调与验证。借助数据增强技术生成多样化训练数据,对预训练的ChatGLM-6B模型进行微调,使其适配特定任务需求,可用于问答系统、智能助手等场景。
项目的主要特性和功能
- 数据增广:运用同义词替换、近义词替换、随机删除、随机置换、回译等手段生成更多训练数据,增强模型泛化能力。
- 模型训练:对ChatGLM-6B模型进行单卡Lora微调,使其适配特定任务。
- 模型保存:可保存训练过程中的checkpoint,便于后续使用和继续训练。
- 性能验证:在部分验证集数据上对模型进行评估,确保准确性和可靠性。
- 结果输出:输出预测标签并保存为 "submission.json" 文件。
安装使用步骤
安装依赖
用户下载项目源码文件后,在项目根目录下执行以下命令安装相应的机器学习库:
pip install -r requirements.txt
数据增广
运行 DataAugmentation.py
文件生成训练集和验证集,可配置参数如下:
python
data_aug(label_trainset, syn_num=1, sim_num=1, del_num=1, exc_num=1, bkt_num=True)
其中:
- label_trainset
:带标签的训练集
- syn_num
:同义词替换,生成新数据个数
- sim_num
:nlpcda近义词替换,生成新数据个数
- del_num
:nlpcda随机删除,生成新数据个数
- exc_num
:nlpcda随机置换,生成新数据个数
- bkt_num
:是否进行回译,生成新数据
模型训练
运行 Train.py
文件进行单卡Lora微调ChatGLM-6B。
模型保存
运行 Save.py
文件保存训练过程中的checkpoint。
验证性能
运行 Validation.py
文件在部分验证集数据上进行评估。
输出结果
运行 Predict.py
文件,输出预测的标签并保存为 "submission.json"。
下载地址
点击下载 【提取码: 4003】【解压密码: www.makuang.net】