项目简介
本项目是基于GPT2的中文对话生成模型,借助HuggingFace的transformers
库搭建,实现了基于文本的对话系统。具备数据预处理、模型训练、评估以及与用户交互等功能,还提供多种采样策略控制生成的多样性。
项目的主要特性和功能
- 基于GPT2的中文对话生成:采用GPT2模型架构,在中文对话数据集上训练,生成自然流畅的中文对话。
- 多样化生成策略:通过温度参数、Top - k采样和Nucleus Sampling策略,灵活控制生成内容的多样性。
- 数据预处理:提供脚本将原始对话数据转换为模型训练所需格式。
- 模型训练与评估:涵盖模型训练、验证、保存和加载功能,可计算训练损失、验证损失和准确率。
- 用户交互:提供脚本让用户通过命令行输入对话内容,模型生成相应回复。
安装使用步骤
环境准备
- 安装Python 3.6及以上版本。
- 安装依赖库:
pip install torch transformers
。
数据准备
- 准备中文对话数据集,保存为
data/train.txt
。 - 运行数据预处理脚本:
bash python preprocess.py --train_path data/train.txt --save_path data/train.pkl
模型训练
- 运行训练脚本:
bash python train.py --epochs 40 --batch_size 8 --device 0 --train_path data/train.pkl
- 训练过程中,模型会自动保存到
model
目录下。
模型加载与交互
- 使用训练好的模型进行对话:
bash python interact.py --model_path model/epoch40 --device 0
- 用户可通过命令行输入对话内容,模型将生成回复。
下载地址
点击下载 【提取码: 4003】【解压密码: www.makuang.net】