当前位置: 首页 > news >正文

如何使用uer做多分类任务

如何使用uer做多分类任务

语料集下载
在这里插入图片描述
找到这里点击即可
里面是这有json文件的
在这里插入图片描述
因此我们对此要做一些处理,将其转为tsv格式

# -*- coding: utf-8 -*-
import json
import csv
import chardet# 检测文件编码
def detect_encoding(file_path):with open(file_path, 'rb') as f:raw_data = f.read()return chardet.detect(raw_data)['encoding']# 输入文件名
input_file = './datasets/iflytek/train.json'
# 输出文件名
output_file = './datasets/iflytek/train.tsv'# 检测输入文件的编码格式
file_encoding = detect_encoding(input_file)# 打开输入的 JSON 文件和输出的 TSV 文件
with open(input_file, 'r', encoding=file_encoding) as json_file, open(output_file, 'w', newline='', encoding='utf-8') as tsv_file:# 准备 TSV 写入器tsv_writer = csv.writer(tsv_file, delimiter='\t')# 写入表头(列表['label', 'label_des', 'sentence']中要注意根据json文件中的键值做更换)tsv_writer.writerow(['label', 'label_des', 'sentence'])# 逐行读取 JSON 文件for line in json_file:try:# 解析每一行的 JSON 数据json_data = json.loads(line.strip())# 写入到 TSV 文件中,(列表['label', 'label_des', 'sentence']中要注意根据json文件中的键值做更换)tsv_writer.writerow([json_data['label'], json_data['label_des'], json_data['sentence']])except json.JSONDecodeError as e:print(f"无法解析的行: {line.strip()}")print(f"错误信息: {e}")print(f"JSON 文件已成功转换为 TSV 文件,输入文件编码: {file_encoding}")

接着呢要把所有tsv文件的sentence表头名改成text_a,不然运行uer框架会报错,原因请看源代码逻辑

def read_dataset(args, path):dataset, columns = [], {}with open(path, mode="r", encoding="utf-8") as f:for line_id, line in enumerate(f):if line_id == 0:for i, column_name in enumerate(line.rstrip("\r\n").split("\t")):columns[column_name] = icontinueline = line.rstrip("\r\n").split("\t")tgt = int(line[columns["label"]])if args.soft_targets and "logits" in columns.keys():soft_tgt = [float(value) for value in line[columns["logits"]].split(" ")]if "text_b" not in columns:  # Sentence classification.text_a = line[columns["text_a"]]src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(text_a) + [SEP_TOKEN])seg = [1] * len(src)else:  # Sentence-pair classification.text_a, text_b = line[columns["text_a"]], line[columns["text_b"]]src_a = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(text_a) + [SEP_TOKEN])src_b = args.tokenizer.convert_tokens_to_ids(args.tokenizer.tokenize(text_b) + [SEP_TOKEN])src = src_a + src_bseg = [1] * len(src_a) + [2] * len(src_b)if len(src) > args.seq_length:src = src[: args.seq_length]seg = seg[: args.seq_length]if len(src) < args.seq_length:PAD_ID = args.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0]src += [PAD_ID] * (args.seq_length - len(src))seg += [0] * (args.seq_length - len(seg))if args.soft_targets and "logits" in columns.keys():dataset.append((src, tgt, seg, soft_tgt))else:dataset.append((src, tgt, seg))return dataset

这里规定好了表头名只有label,text_a,text_b
搞完之后进入训练代码,我的显存只有16G,因此

python finetune/run_classifier.py --pretrained_model_path models/cluecorpussmall_roberta_wwm_large_seq512_model.bin --vocab_path models/google_zh_vocab.txt --config_path models/bert/large_config.json --train_path datasets/iflytek/train.tsv --dev_path datasets/iflytek/dev.tsv --output_model_path models/iflytek_classifier_model.bin --epochs_num 3 --batch_size 16 --seq_length 128

在这里插入图片描述
在这里插入图片描述
这里可以看到只有61.49的正确率,其实是因为显存还不够,训练不了那么大的,标准的参数应该设置为batch_size=32 seq_length=256
有能力的可以更改参数进行训练
接着来预测

python inference/run_classifier_infer.py --load_model_path models/iflytek_classifier_model.bin --vocab_path models/google_zh_vocab.txt --config_path models/bert/large_config.json --test_path datasets/iflytek/test.tsv --prediction_path datasets/iflytek/prediction.tsv --seq_length 256 --labels_num 119

在这里插入图片描述
最后自行查看预测效果

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 刷题之多数元素(leetcode)
  • 11.索引_创建不同种类索引(primary+unique+复合....)
  • Spring MVC深入理解之源码实现
  • .net core Redis 使用有序集合实现延迟队列
  • 【环境准备】 Vue环境搭建
  • AngularJS API 深入解析
  • CTF php RCE (一)
  • 激光干涉仪可以完成哪些测量:全面应用解析
  • 北京大学长安汽车发布毫米波与相机融合模型RCBEVDet:最快能达到每秒28帧
  • 招投标信息采集系统:让您的企业始终站在行业前沿
  • 短链接day3
  • Socket网络通信流程
  • 昇思25天学习打卡营第6天|函数式自动微分
  • Docker安装遇到问题:curl: (7) Failed to connect to download.docker.com port 443: 拒绝连接
  • Nacos2.X 配置中心源码分析:客户端如何拉取配置、服务端配置发布客户端监听机制
  • 【React系列】如何构建React应用程序
  • 【译】React性能工程(下) -- 深入研究React性能调试
  • electron原来这么简单----打包你的react、VUE桌面应用程序
  • ES6系统学习----从Apollo Client看解构赋值
  • Java知识点总结(JavaIO-打印流)
  • Linux后台研发超实用命令总结
  • nodejs实现webservice问题总结
  • Protobuf3语言指南
  • Vue小说阅读器(仿追书神器)
  • 闭包,sync使用细节
  • 机器学习 vs. 深度学习
  • 深入 Nginx 之配置篇
  • 一文看透浏览器架构
  • 最简单的无缝轮播
  • Nginx实现动静分离
  • 测评:对于写作的人来说,Markdown是你最好的朋友 ...
  • #define用法
  • (Arcgis)Python编程批量将HDF5文件转换为TIFF格式并应用地理转换和投影信息
  • (arch)linux 转换文件编码格式
  • (Redis使用系列) Springboot 整合Redisson 实现分布式锁 七
  • (二)学习JVM —— 垃圾回收机制
  • (二十一)devops持续集成开发——使用jenkins的Docker Pipeline插件完成docker项目的pipeline流水线发布
  • (附源码)ssm码农论坛 毕业设计 231126
  • (七)c52学习之旅-中断
  • (十)T检验-第一部分
  • (十八)三元表达式和列表解析
  • (新)网络工程师考点串讲与真题详解
  • (一)springboot2.7.6集成activit5.23.0之集成引擎
  • (译)计算距离、方位和更多经纬度之间的点
  • (转)linux自定义开机启动服务和chkconfig使用方法
  • (转)shell调试方法
  • (最全解法)输入一个整数,输出该数二进制表示中1的个数。
  • ***php进行支付宝开发中return_url和notify_url的区别分析
  • ***汇编语言 实验16 编写包含多个功能子程序的中断例程
  • .360、.halo勒索病毒的最新威胁:如何恢复您的数据?
  • .dwp和.webpart的区别
  • .NET : 在VS2008中计算代码度量值
  • .NET C# 使用GDAL读取FileGDB要素类
  • .net core 3.0 linux,.NET Core 3.0 的新增功能
  • .NET 设计模式初探