项目简介
本项目聚焦医疗问答领域,是基于深度学习框架(可能是TensorFlow或PyTorch)构建的文本相似度分类系统。项目以平安医疗科技疾病问答迁移学习比赛的数据为处理对象,利用BERT预训练模型进行训练,并引入对抗训练提升模型鲁棒性。通过融合多个模型的预测结果,有效提高分类准确性,涵盖数据预处理、模型构建、训练和预测等完整流程。
项目的主要特性和功能
- 数据处理:拼接训练集和验证集数据,引入额外训练数据并去除缺失值,丰富训练数据。
- 数据生成器:自定义数据生成器,可设置数据是否shuffle,对数据编码后补长并迭代生成,减少内存占用。
- 模型结构改进:提供三种模型输出方式,使用dropout和全连接层,编译模型时采用二分类交叉熵损失和Adam优化器。
- 对抗训练:在模型compile之后添加对抗训练,增强模型鲁棒性。
- 模型训练:采用5折交叉验证,训练过程加入对抗训练和早停机制。
- 模型融合:使用BERT - wwm - ext、Ernie - 1.0、RoBERTa - large - pair三种模型,在训练阶段保存最优权重,预测阶段进行多次反sigmoid函数化及加权融合。
安装使用步骤
1. 安装依赖库
依据项目需求,安装必要的Python库,如TensorFlow或PyTorch,以及相关的深度学习模型和工具库。
2. 数据准备
将训练集和验证集数据置于正确文件夹,并按项目要求的格式组织。额外训练数据可从https://www.biendata.xyz/competition/chip2019/data/ 下载。
3. 下载模型文件
从指定地址下载BERT - wwm - ext、Ernie - 1.0、RoBERTa - large - pair的模型文件: - BERT - wwm - ext:https://drive.google.com/u/0/uc?id=1buMLEjdtrXE2c4G1rpsNGWEx7lUQ0RHi&export=download - RoBERTa - large - pair:https://drive.google.com/u/0/uc?id=1ykENKV7dIFAqRRQbZIh0mSb7Vjc2MeFA&export=download
4. 运行代码
运行主要的训练脚本(如train.py
或train_v2.py
),可根据需要调整参数。
5. 模型训练
按照设定参数进行模型训练,每个模型每一折跑5个epoch,以 ‘val_loss’ 为评估准则,保存最优权重。
6. 模型预测
使用训练阶段保存的模型文件对测试集进行预测,经多次反sigmoid函数化及加权融合后得到最终结果。
下载地址
点击下载 【提取码: 4003】【解压密码: www.makuang.net】