项目简介
本项目聚焦于2024DCIC少样本条件下的社交平台话题识别任务,最终取得了rank9的成绩。项目采用m3e - large预训练模型,通过多分类微调并结合对抗训练提升模型泛化能力,同时运用两种不同粒度的相似度计算方法和动态阈值计算,实现对社交平台文本的高效话题识别。
项目的主要特性和功能
- 预训练模型应用:使用m3e - large预训练模型,为话题识别任务提供有力基础。
- 数据预处理:自动过滤文本中的特定标签和超链接,简化数据处理流程。
- 模型训练优化:在训练阶段进行多分类微调,并加入对抗训练,增强模型泛化能力。
- 相似度计算:采用全局相似度和局部相似度两种计算方法,综合评估文本相似度。
- 动态阈值计算:根据每个任务中支撑文本之间的相似度计算阈值,并人工调整,提高识别精度。
- API服务:提供API接口,方便用户以请求的方式进行推理预测。
安装使用步骤
环境搭建
下载项目源码文件后,打开终端,执行以下命令安装项目所需依赖:
sh
cd code
pip install -r requirements.txt
预训练模型准备
从 https://modelscope.cn/models/Jerry0/M3E-large/files 下载m3e - large权重。
数据预处理
项目已在 retrieval_dataset.py 中完成数据预处理,复现训练、推理以及B榜无需执行额外的预处理命令。
训练阶段
以m3e - large为预训练权重,在训练集上进行多分类微调。执行以下命令开始训练:
sh
sh train/run_pretrain.sh
run_pretrain.sh
主要参数含义如下:
sh
python main_pretrain.py \
--data_path 训练数据路径 \
--batch_size 8 \
--task_name 'models' \
--encoder_dir 预训练权重路径 \
--max_len 512 \
--random_lr 1e-4 \
--pretrained_lr 1e-5 \
--schedule_type 'none' \
--use_fp16 \
--valid_ratio 按类别划分的验证集比例 \
--epochs 6 \
--preprocess \
--use_fgm \
--output_dir 模型保存路径 \
推理阶段
执行以下命令启动服务:
sh
sh main.sh
通过request的POST方法请求服务,url为 http://0.0.0.0:8090/predict 进行推理。
下载地址
点击下载 【提取码: 4003】【解压密码: www.makuang.net】