项目简介
本项目是基于深度学习的文本分类模型优化与部署项目。借助BERT模型,结合蒸馏、剪枝、优化和量化等技术优化模型,提高性能并减小模型大小,最终通过FastAPI框架完成模型的部署与推理,为实际应用场景提供高效文本分类解决方案。
项目的主要特性和功能
模型优化
- 蒸馏:运用TextBrewer工具开展模型蒸馏,提升模型性能。
- 剪枝:借助TextPruner工具对模型进行剪枝,减小模型体积。
- 量化:使用ONNX Runtime对模型进行量化,提升推理速度。
- ONNX转换:将PyTorch模型转换为ONNX格式,便于后续优化和部署。
模型部署
- FastAPI服务:通过FastAPI框架构建Web服务,支持单条和批量文本数据的预测。
- 接口支持:提供RESTful API接口,返回预测结果和概率分布。
- 命令行工具:提供命令行脚本,方便用户调用模型进行预测。
安装使用步骤
环境准备
- 确保已安装Python 3.7及以上版本。
- 安装项目依赖库:
bash pip install -r requirements.txt
数据准备
确保数据文件格式符合要求,表头为 index sentence label
,数据与表头之间用 \t
分隔。
模型训练
进入 train
文件夹,运行 run.sh
脚本进行模型训练:
bash
./run.sh
模型优化
- 蒸馏:进入
optimize/distill
文件夹,修改distill.sh
脚本中的参数,并运行:bash ./distill.sh
- 剪枝:使用TextPruner工具对模型进行剪枝,具体参数参考项目文档。
- 量化:进入
optimize/quantify
文件夹,运行quantize.py
脚本进行模型量化:bash python quantize.py
- ONNX转换:进入
optimize/acceleration
文件夹,运行export_pytorch2onnx.py
脚本将模型转换为ONNX格式:bash python export_pytorch2onnx.py
模型部署
- 进入
deploy
文件夹,运行run_app.sh
脚本启动FastAPI服务:bash ./run_app.sh
- 通过API接口进行文本分类预测,支持单条和批量预测。
注意事项
- 在进行模型优化时,建议按照蒸馏、剪枝、量化的顺序进行操作,并在每一步骤后进行模型评估,确保模型性能符合要求。
- 部署时需根据实际需求调整模型类型和配置参数。
下载地址
点击下载 【提取码: 4003】【解压密码: www.makuang.net】