当前位置: 首页 > news >正文

基于transformers框架实践Bert系列5-阅读理解(文本摘要)

本系列用于Bert模型实践实际场景,分别包括分类器、命名实体识别、选择题、文本摘要等等。(关于Bert的结构和详细这里就不做讲解,但了解Bert的基本结构是做实践的基础,因此看本系列之前,最好了解一下transformers和Bert等)
本篇主要讲解阅读理解文本摘要)应用场景。本系列代码和数据集都上传到GitHub上:https://github.com/forever1986/bert_task

目录

  • 1 环境说明
  • 2 前期准备
    • 2.1 了解Bert的输入输出
    • 2.2 数据集与模型
    • 2.3 任务说明
    • 2.4 实现关键
  • 3 关键代码
    • 3.1 数据集处理
    • 3.2 模型加载
    • 3.3 评估函数
  • 4 整体代码
  • 5 运行效果

1 环境说明

1)本次实践的框架采用torch-2.1+transformer-4.37
2)另外还采用或依赖其它一些库,如:evaluate、pandas、datasets、accelerate、nltk等

2 前期准备

Bert模型是一个只包含transformer的encoder部分,并采用双向上下文和预测下一句训练而成的预训练模型。可以基于该模型做很多下游任务。

2.1 了解Bert的输入输出

Bert的输入:input_ids(使用tokenizer将句子向量化),attention_mask,token_type_ids(句子序号)、labels(结果)
Bert的输出:
last_hidden_state:最后一层encoder的输出;大小是(batch_size, sequence_length, hidden_size)(注意:这是关键输出,本次任务就需要获取该值,并进行一次线性层处理
pooler_output:这是序列的第一个token(classification token)的最后一层的隐藏状态,输出的大小是(batch_size, hidden_size),它是由线性层和Tanh激活函数进一步处理的。(通常用于句子分类,至于是使用这个表示,还是使用整个输入序列的隐藏状态序列的平均化或池化,视情况而定)。
hidden_states: 这是输出的一个可选项,如果输出,需要指定config.output_hidden_states=True,它也是一个元组,它的第一个元素是embedding,其余元素是各层的输出,每个元素的形状是(batch_size, sequence_length, hidden_size)
attentions:这是输出的一个可选项,如果输出,需要指定config.output_attentions=True,它也是一个元组,它的元素是每一层的注意力权重,用于计算self-attention heads的加权平均值。

2.2 数据集与模型

1)数据集来自:cmrc2018
2)模型权重使用:bert-base-chinese

2.3 任务说明

1)文本摘要就是让模型学会文本内容,然后根据提问给出答案或者总结。阅读理解中的文本摘要包括:抽取式和生成式
2)本次将的主要完成问答任务中抽取式(对于生成式文本摘要,单单使用BERT是无法完成的,因为BERT只有编码器,编码器就是给多少返回多少,并不会自动生成,一般可以结合解码器一起使用,那么就不如直接使用自回归或者seq2seq模型了),这里对抽取式数据集做一下说明,就是给定一个文本context,然后再给一个问题question以及一个答案answer(answer来自contex中的一小段文字)
3)片段抽取式任务就是让模型学会从context+question中获取到answer,answer来自contex中的一小段文字。

2.4 实现关键

1)首先我们先看看数据集结构

{"answers": {"answer_start": [11, 11],"text": ["光荣和ω-force", "光荣和ω-force"]},"context": "\"《战国无双3》()是由光荣和ω-force开发的战国无双系列的正统第三续作。本作以三大故事为主轴,分别是以武田信玄等人为主的《关东三国志》,织田信长等人为主的《战国三杰》,石田三成等人为主的《关原的年轻武者》,丰富游戏内的剧情。此部份专门介绍角色,欲知武...","id": "DEV_0_QUERY_0","question": "《战国无双3》是由哪两个公司合作开发的?"
}
  • 其中answers是答案,一个map,其中answer_start是答案的起始位置,text是答案内容,这里2个字段都是list,是因为可能有多个答案或者多个问题的答案
  • context是文本内容
  • question是问题
  • id是数据id

2)我们需要将任务转化为:问题question和文本内容context作为输入,然后label是answer的开始位置结束位置,记住是token后的开始位置和结束位置,因为有的词token后可能是多个token

3)另外注意的问题是拼接问题question和文本内容context可能超过我们的max_length,如果采用暴力截断,会出现有答案部分内容被截取掉了,因此如果我们不希望被截断,可以将文本分块输入到模型中,所以在做tokenizer时,使用参数return_overflowing_tokens和stride对文本进行分块。这里打个比方,比如你有个句子如下:

《战国无双3》是由光荣和ω-force开发的战国无双系列的正统第三续作。本作以三大故事为主轴,分别是以武田信玄等人为主的《关东三国志》,织田信长等人为主的《战国三杰》,石田三成等人为主的《关原的年轻武者》,丰富游戏内的剧情。

但你的max_length=40,stride=1,那么这个句子可以被切分为如下:

句子1:《战国无双3》是由光荣和ω-force开发的战国无双系列的正统第三续作。本作以三大故事为主
句子2:轴,分别是以武田信玄等人为主的《关东三国志》,织田信长等人为主的《战国三杰》,石
句子3:田三成等人为主的《关原的年轻武者》,丰富游戏内的剧情。

这时候虽然切分了,但是如果你的答案刚好在句子1和句子2之间,那么你还是得不到答案,因此使用stride参数,可以将部分内容重叠,比如我们设置stride=10,那么得到新的句子如下

句子1:《战国无双3》是由光荣和ω-force开发的战国无双系列的正统第三续作。本作以三大故事为主
句子2:。本作以三大故事为主轴,分别是以武田信玄等人为主的《关东三国志》,织田信长等人为
句子3:志》,织田信长等人为主的《战国三杰》,石田三成等人为主的《关原的年轻武者》,丰富
句子4:原的年轻武者》,丰富游戏内的剧情。

可以看到新切分的句子中会出现重叠的内容,这样可以一定程度上保证答案被包括,缺点就是stride越长,答案没有被切分概率越小,但是切分的句子会越多,变相增加了训练工作量。

3 关键代码

3.1 数据集处理

1)首先你需要做的就是拼接question和context,并且还有实现分块,好在tokenizer已经具备了

tokenized_datas = tokenizer(text=datas["question"],text_pair=datas["context"],return_offsets_mapping=True,  # 返回token与input_ids的位置映射return_overflowing_tokens=True,  # 设置将句子切分为多个句子(如果句子超过max_length的话)stride=128,  # 切分的重叠tokenmax_length=512, truncation="only_second", padding="max_length")

2)找到答案在contex被token后的开始位置和结束位置,我们知道分词的结果可能是一个词被分为多个token,因此要找出答案被token后的位置,需要借助offset_mapping。offset_mapping的作用,就是描述每个词被token后的token位置,如下:

[(0, 0), (0, 1), (1, 2), (2, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (0, 0)] 

从上面可以看出,里面某个词占据了(2,4)的范围。通过offset_mapping,我们就能将真正的答案在token后的位置找出来

def process_function(datas):tokenized_datas = tokenizer(text=datas["question"],text_pair=datas["context"],return_offsets_mapping=True,  # 返回token与input_ids的位置映射return_overflowing_tokens=True,  # 设置将句子切分为多个句子(如果句子超过max_length的话)stride=128,  # 切分的重叠tokenmax_length=512, truncation="only_second",  # 截断直接断第二个内容,也就是context,不能截断questionpadding="max_length")sample_mappings = tokenized_datas.pop("overflow_to_sample_mapping")start_positions = []  # 答案在token后的context的起始位置end_positions = []  # 答案在token后的context的结束位置data_ids = []for idx, _ in enumerate(sample_mappings):answer = datas["answers"][sample_mappings[idx]]answer_start = answer["answer_start"][0]answer_end = answer_start + len(answer["text"][0])context_start = tokenized_datas.sequence_ids(idx).index(1)context_end = tokenized_datas.sequence_ids(idx).index(None, context_start) - 1offset = tokenized_datas.get("offset_mapping")[idx]# 如果答案没有在context中if offset[context_end][1] < answer_start or offset[context_start][0] > answer_end:start_pos = 0end_pos = 0else:# 如果答案在context中token_index = context_startwhile token_index <= context_end and offset[token_index][0] < answer_start:token_index += 1start_pos = token_indextoken_index = context_endwhile token_index >= context_start and offset[token_index][1] > answer_end:token_index -= 1end_pos = token_indexstart_positions.append(start_pos)end_positions.append(end_pos)data_ids.append(datas["id"][sample_mappings[idx]])# 将question部分的offset_mapping设置为None,为了方便在评估时查找context时,快速过滤掉question部分tokenized_datas["offset_mapping"][idx] = [(o if tokenized_datas.sequence_ids(idx)[k] == 1 else None)for k, o in enumerate(tokenized_datas["offset_mapping"][idx])]tokenized_datas["data_ids"] = data_idstokenized_datas["start_positions"] = start_positionstokenized_datas["end_positions"] = end_positionsreturn tokenized_datas

3.2 模型加载

model = BertForQuestionAnswering.from_pretrained(model_path)

注意:这里使用的是transformers中的BertForQuestionAnswering,该类对bert模型进行封装。如果我们不使用该类,需要自己定义一个model,继承bert,增加分类线性层。另外使用AutoModelForQuestionAnswering也可以,其实AutoModel最终返回的也是BertForQuestionAnswering,它是根据你config中的model_type去匹配的。
这里列一下BertForQuestionAnswering的关键源代码说明一下transformers帮我们做了哪些关键事情

# 在__init__方法中增加增加了线性层
self.bert = BertModel(config, add_pooling_layer=False)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
# 将输出结果outputs取第一个返回值,也就是last_hidden_state
sequence_output = outputs[0]  #  shape为(batch_size, sequence_length, hidden_size)
# 将last_hidden_state输入到qa_outputs线性层中,获得logits
logits = self.qa_outputs(sequence_output)  # shape为(batch_size, sequence_length, 2)
# 将logits分为2个
start_logits, end_logits = logits.split(1, dim=-1)  # shape为(batch_size, sequence_length, 1)
# 分别得到start和end的位置
start_logits = start_logits.squeeze(-1).contiguous()  # shape为(batch_size, sequence_length)
end_logits = end_logits.squeeze(-1).contiguous()  # shape为(batch_size, sequence_length)

3.3 评估函数

评估函数来自大神:https://github.com/zyds/transformers-code,其中需要使用到nltk库

def evaluate_function(prepredictions):start_logits, end_logits = prepredictions[0]if start_logits.shape[0] == len(new_datasets["validation"]):p, r = get_result(start_logits, end_logits, datasets["validation"], new_datasets["validation"])else:p, r = get_result(start_logits, end_logits, datasets["test"], new_datasets["test"])return evaluate_cmrc(p, r)def get_result(start_logits, end_logits, datas, features):predictions = {}references = {}# datas 和 feature的映射example_to_feature = collections.defaultdict(list)for idx, example_id in enumerate(features["data_ids"]):example_to_feature[example_id].append(idx)# 最优答案候选n_best = 20# 最大答案长度max_answer_length = 30for example in datas:example_id = example["id"]context = example["context"]answers = []for feature_idx in example_to_feature[example_id]:start_logit = start_logits[feature_idx]  # 预测结果开始位置end_logit = end_logits[feature_idx]     # 预测结果结束位置offset = features[feature_idx]["offset_mapping"]  # 每个词与token的映射start_indexes = numpy.argsort(start_logit)[::-1][:n_best].tolist()end_indexes = numpy.argsort(end_logit)[::-1][:n_best].tolist()for start_index in start_indexes:for end_index in end_indexes:if offset[start_index] is None or offset[end_index] is None:continueif end_index < start_index or end_index - start_index + 1 > max_answer_length:continueanswers.append({"text": context[offset[start_index][0]: offset[end_index][1]],"score": start_logit[start_index] + end_logit[end_index]})if len(answers) > 0:best_answer = max(answers, key=lambda x: x["score"])predictions[example_id] = best_answer["text"]else:predictions[example_id] = ""references[example_id] = example["answers"]["text"]return predictions, references

4 整体代码

"""
基于BERT做阅读理解(问答任务)
1)数据集来自:cmrc2018
2)模型权重使用:bert-base-chinese
"""
# step 1 引入数据库
import nltk
import numpy
import collections
from datasets import DatasetDict
from evaluate.cmrc_eval import evaluate_cmrc
from transformers import BertForQuestionAnswering, TrainingArguments, Trainer, DefaultDataCollator, \pipeline, BertTokenizerFastnltk.download("punkt")  # 评估函数中使用的库
model_path = "./model/tiansz/bert-base-chinese"
data_path = "data/cmrc2018"# step 2 数据集处理
datasets = DatasetDict.load_from_disk(data_path)
tokenizer = BertTokenizerFast.from_pretrained(model_path)def process_function(datas):tokenized_datas = tokenizer(text=datas["question"],text_pair=datas["context"],return_offsets_mapping=True,  # 返回token与input_ids的位置映射return_overflowing_tokens=True,  # 设置将句子切分为多个句子(如果句子超过max_length的话)stride=128,  # 切分的重叠tokenmax_length=512, truncation="only_second", padding="max_length")sample_mappings = tokenized_datas.pop("overflow_to_sample_mapping")start_positions = []  # 答案在输入的context的起始位置end_positions = []  # 答案在输入的context的结束位置data_ids = []for idx, _ in enumerate(sample_mappings):answer = datas["answers"][sample_mappings[idx]]answer_start = answer["answer_start"][0]answer_end = answer_start + len(answer["text"][0])context_start = tokenized_datas.sequence_ids(idx).index(1)context_end = tokenized_datas.sequence_ids(idx).index(None, context_start) - 1offset = tokenized_datas.get("offset_mapping")[idx]# 如果答案没有在context中if offset[context_end][1] < answer_start or offset[context_start][0] > answer_end:start_pos = 0end_pos = 0else:# 如果答案在context中token_index = context_startwhile token_index <= context_end and offset[token_index][0] < answer_start:token_index += 1start_pos = token_indextoken_index = context_endwhile token_index >= context_start and offset[token_index][1] > answer_end:token_index -= 1end_pos = token_indexstart_positions.append(start_pos)end_positions.append(end_pos)data_ids.append(datas["id"][sample_mappings[idx]])# 将question部分的offset_mapping设置为None,为了方便在评估时查找context时,快速过滤掉question部分tokenized_datas["offset_mapping"][idx] = [(o if tokenized_datas.sequence_ids(idx)[k] == 1 else None)for k, o in enumerate(tokenized_datas["offset_mapping"][idx])]tokenized_datas["data_ids"] = data_idstokenized_datas["start_positions"] = start_positionstokenized_datas["end_positions"] = end_positionsreturn tokenized_datasnew_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)# step 3 加载模型
model = BertForQuestionAnswering.from_pretrained(model_path)# step 4 评估函数
def evaluate_function(prepredictions):start_logits, end_logits = prepredictions[0]if start_logits.shape[0] == len(new_datasets["validation"]):p, r = get_result(start_logits, end_logits, datasets["validation"], new_datasets["validation"])else:p, r = get_result(start_logits, end_logits, datasets["test"], new_datasets["test"])return evaluate_cmrc(p, r)def get_result(start_logits, end_logits, datas, features):predictions = {}references = {}# datas 和 feature的映射example_to_feature = collections.defaultdict(list)for idx, example_id in enumerate(features["data_ids"]):example_to_feature[example_id].append(idx)# 最优答案候选n_best = 20# 最大答案长度max_answer_length = 30for example in datas:example_id = example["id"]context = example["context"]answers = []for feature_idx in example_to_feature[example_id]:start_logit = start_logits[feature_idx]  # 预测结果开始位置end_logit = end_logits[feature_idx]     # 预测结果结束位置offset = features[feature_idx]["offset_mapping"]  # 每个词与token的映射start_indexes = numpy.argsort(start_logit)[::-1][:n_best].tolist()end_indexes = numpy.argsort(end_logit)[::-1][:n_best].tolist()for start_index in start_indexes:for end_index in end_indexes:if offset[start_index] is None or offset[end_index] is None:continueif end_index < start_index or end_index - start_index + 1 > max_answer_length:continueanswers.append({"text": context[offset[start_index][0]: offset[end_index][1]],"score": start_logit[start_index] + end_logit[end_index]})if len(answers) > 0:best_answer = max(answers, key=lambda x: x["score"])predictions[example_id] = best_answer["text"]else:predictions[example_id] = ""references[example_id] = example["answers"]["text"]return predictions, references# step 5 创建TrainingArguments
# 原先train是1002条数据,但是拆分后的数据量是1439,batch_size=32,因此每个epoch的step=45,总step=135
train_args = TrainingArguments(output_dir="./checkpoints",      # 输出文件夹per_device_train_batch_size=32,  # 训练时的batch_sizeper_device_eval_batch_size=32,    # 验证时的batch_sizenum_train_epochs=3,              # 训练轮数logging_steps=20,                # log 打印的频率evaluation_strategy="epoch",     # 评估策略save_strategy="epoch",           # 保存策略save_total_limit=3,              # 最大保存数load_best_model_at_end=True      # 训练完成后加载最优模型)# step 6 创建Trainer
trainer = Trainer(model=model,args=train_args,train_dataset=new_datasets["train"],eval_dataset=new_datasets["validation"],data_collator=DefaultDataCollator(),compute_metrics=evaluate_function,)# Step 7 模型训练
trainer.train()# step 8 模型评估
evaluate_result = trainer.evaluate(new_datasets["test"])
print(evaluate_result)# Step 9 模型预测
pipe = pipeline("question-answering", model=model, tokenizer=tokenizer)
result = pipe(question="乍都节公园位于什么地方?", context="乍都节公园位于泰国的首都曼谷的乍都节县,是拍凤裕庭路、威拍哇丽兰室路、甘烹碧路之间的一处公众游园地。")
print(result)

5 运行效果

在这里插入图片描述

注:本文参考来自大神:https://github.com/zyds/transformers-code

相关文章:

  • Redis常见数据类型(3)-String, Hash
  • 学习平台|基于Springboot+vue的学习平台系统的设计与实现(源码+数据库+文档)
  • c语言之运算符练习题
  • Spring Boot集成testcontainers快速入门Demo
  • 基于地理坐标的高阶几何编辑工具算法(5)——合并相交面
  • Python操作MySQL实战
  • 椋鸟C++笔记#3:类的默认成员函数
  • 【html】网页布局模板01---简谱风
  • Java_IO流学习
  • GESP 四级冲刺训练营(1):字符串
  • linux内核符号表
  • 踩坑——纪实
  • VUE 页面生命周期基本知识点
  • 瑞芯微RV1126——ffmpeg环境搭建
  • 国产linux系统(银河麒麟,统信uos)使用 PageOffice 国产版在线编辑word文件,同时保存数据和文件
  • [分享]iOS开发 - 实现UITableView Plain SectionView和table不停留一起滑动
  • CSS进阶篇--用CSS开启硬件加速来提高网站性能
  • hadoop入门学习教程--DKHadoop完整安装步骤
  • HTML-表单
  • js面向对象
  • Python学习笔记 字符串拼接
  • weex踩坑之旅第一弹 ~ 搭建具有入口文件的weex脚手架
  • Yeoman_Bower_Grunt
  • 给自己的博客网站加上酷炫的初音未来音乐游戏?
  • 基于阿里云移动推送的移动应用推送模式最佳实践
  • 极限编程 (Extreme Programming) - 发布计划 (Release Planning)
  • 使用common-codec进行md5加密
  • 我感觉这是史上最牛的防sql注入方法类
  • 大数据全解:定义、价值及挑战
  • 树莓派用上kodexplorer也能玩成私有网盘
  • # 详解 JS 中的事件循环、宏/微任务、Primise对象、定时器函数,以及其在工作中的应用和注意事项
  • #laravel 通过手动安装依赖PHPExcel#
  • (AtCoder Beginner Contest 340) -- F - S = 1 -- 题解
  • (阿里云在线播放)基于SpringBoot+Vue前后端分离的在线教育平台项目
  • (笔记)M1使用hombrew安装qemu
  • (附源码)spring boot球鞋文化交流论坛 毕业设计 141436
  • (译) 理解 Elixir 中的宏 Macro, 第四部分:深入化
  • (原創) 人會胖會瘦,都是自我要求的結果 (日記)
  • (转) ns2/nam与nam实现相关的文件
  • *++p:p先自+,然后*p,最终为3 ++*p:先*p,即arr[0]=1,然后再++,最终为2 *p++:值为arr[0],即1,该语句执行完毕后,p指向arr[1]
  • .a文件和.so文件
  • .Net Memory Profiler的使用举例
  • .Net MVC + EF搭建学生管理系统
  • .net Stream篇(六)
  • .NET WPF 抖动动画
  • .NET 给NuGet包添加Readme
  • .net 验证控件和javaScript的冲突问题
  • .Net6支持的操作系统版本(.net8已来,你还在用.netframework4.5吗)
  • .sh 的运行
  • ?php echo $logosrc[0];?,如何在一行中显示logo和标题?
  • @31省区市高考时间表来了,祝考试成功
  • @NoArgsConstructor和@AllArgsConstructor,@Builder
  • @Repository 注解
  • @vue/cli脚手架
  • [20171113]修改表结构删除列相关问题4.txt