
Implementing Text Matching with a Cross-Encoder (Reranking Model)

Introduction

The previous articles all trained BERT for text matching with representation-based (bi-encoder) methods, whereas this article uses an interaction-based method. Concretely, the two sentences to be matched are concatenated into a single input and fed to the BERT model, which then outputs a similarity score.

This wraps up the text-matching series for now. Everything so far has been supervised learning on sentence pairs; follow-up posts will cover contrastive learning with triplets (anchor, positive, negative) and unsupervised approaches when time permits.

Architecture

(Figure: Cross-Encoder architecture)

A Cross-Encoder uses self-attention to compute interactions (attention) between the two sentences layer after layer, then feeds the result into a classifier head that outputs a score (a logit) representing their similarity; this score can be passed through a sigmoid to turn it into a probability.
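As a minimal conceptual sketch of this flow (using the plain pre-trained checkpoint that is fine-tuned later in this article, so the untrained head makes the score itself meaningless here; the example sentences are made up):

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# the classification head is not yet trained, so this only illustrates the data flow
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-macbert-large")
model = AutoModelForSequenceClassification.from_pretrained(
    "hfl/chinese-macbert-large", num_labels=1
)

# both sentences are packed into ONE input: [CLS] sent1 [SEP] sent2 [SEP]
inputs = tokenizer("今天天气怎么样", "今天天气如何", return_tensors="pt")
with torch.no_grad():
    logit = model(**inputs).logits   # shape (1, 1)
score = torch.sigmoid(logit).item()  # similarity expressed as a probability in [0, 1]
print(score)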

Implementation

The implementation follows a Hugging Face-like layout: each folder holds one model, split into separate files such as modeling, arguments, and trainer, and different architectures live in different folders; the layout looks roughly like the sketch below.
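Under that convention, the folder for this model looks roughly like the following tree (the folder name is arbitrary; the file names match the files walked through in the rest of the article):

cross_encoder/
├── modeling.py     # SentenceBert: the cross-encoder model wrapper
├── arguments.py    # ModelArguments / DataArguments
├── dataset.py      # PairDataset for sentence pairs
├── trainer.py      # CrossTrainer with custom compute_loss / _save
├── utils.py        # data and metric helpers
├── train.py        # training entry point (driven by train.sh)
└── test.py         # evaluation entry point (driven by test.sh)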

modeling.py:

import torch
from torch import nn
import numpy as np
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForSequenceClassification,
)
from torch.utils.data import DataLoader
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.tokenization_utils_base import BatchEncoding
import logging

logger = logging.getLogger(__name__)


class SentenceBert(nn.Module):
    def __init__(
        self,
        model_name: str,
        max_length: int = None,
        trust_remote_code: bool = True,
    ) -> None:
        super().__init__()

        self.config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code
        )
        self.config.num_labels = 1
        # reranker: a sequence classification head with a single output logit
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name, config=self.config, trust_remote_code=trust_remote_code
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=trust_remote_code
        )
        self.max_length = max_length
        self.loss_fct = nn.BCEWithLogitsLoss()

    def batching_collate(self, batch: list[tuple[str, str]]) -> BatchEncoding:
        # transpose the batch of pairs into two lists and tokenize them jointly,
        # so each example becomes [CLS] sent1 [SEP] sent2 [SEP]
        texts = [[] for _ in range(len(batch[0]))]

        for example in batch:
            for idx, text in enumerate(example):
                texts[idx].append(text.strip())

        tokenized = self.tokenizer(
            *texts,
            padding=True,
            truncation="longest_first",
            return_tensors="pt",
            max_length=self.max_length,
        ).to(self.model.device)

        return tokenized

    def predict(
        self,
        sentences: list[tuple[str, str]],
        batch_size: int = 64,
        convert_to_tensor: bool = True,
        show_progress_bar: bool = False,
    ):
        dataloader = DataLoader(
            sentences,
            batch_size=batch_size,
            collate_fn=self.batching_collate,
            shuffle=False,
        )

        preds = []
        for batch in tqdm(
            dataloader, disable=not show_progress_bar, desc="Running Inference"
        ):
            with torch.no_grad():
                logits = self.model(**batch).logits
                logits = torch.sigmoid(logits)
                preds.extend(logits)

        if convert_to_tensor:
            preds = torch.stack(preds)
        else:
            preds = np.asarray(
                [pred.cpu().detach().float().numpy() for pred in preds]
            )

        return preds

    def forward(self, inputs, labels=None):
        outputs = self.model(**inputs, return_dict=True)

        if labels is not None:
            labels = labels.float()
            logits = outputs.logits
            logits = logits.view(-1)
            loss = self.loss_fct(logits, labels)
            return SequenceClassifierOutput(loss, **outputs)

        return outputs

    def save_pretrained(self, output_dir: str) -> None:
        state_dict = self.model.state_dict()
        state_dict = type(state_dict)(
            {k: v.clone().cpu().contiguous() for k, v in state_dict.items()}
        )
        self.model.save_pretrained(output_dir, state_dict=state_dict)

The entire model implementation lives in modeling.py.

Here we first set the number of classes to 1 (num_labels = 1); then AutoModelForSequenceClassification attaches a sequence classification head, whose core code is:

class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        # instantiate the BERT model
        self.bert = BertModel(config)
        # add a linear layer mapping hidden_size down to num_labels dimensions (1 here)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        # first get the outputs of the BERT model
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # this is actually the representation of the [CLS] token
        pooled_output = outputs[1]
        # produce a one-dimensional logit
        logits = self.classifier(pooled_output)

What the BERT model calls pooled_output is actually computed by:

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        # a linear map from hidden_size to another hidden_size space
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # followed by a tanh activation
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # take the first token ([CLS]) of the last hidden states
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

Back in our modeling.py, the forward method is used for training, while the predict method is used for inference and supports batching; its input is a list of tuples, each representing a sentence pair.
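For instance, inference could be invoked roughly like this (the sentence pairs are made up for illustration; in practice the path would point to a fine-tuned checkpoint rather than the base model):

from modeling import SentenceBert

# the base model is used here just to show the call; scores from an untrained head are meaningless
model = SentenceBert("hfl/chinese-macbert-large")
pairs = [
    ("今天天气怎么样", "今天天气如何"),
    ("今天天气怎么样", "我想吃火锅"),
]
scores = model.predict(pairs, batch_size=2, show_progress_bar=True)
print(scores)  # one sigmoid score per pair; higher means more similar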

arguments.py:

from dataclasses import dataclass, field
from typing import Optional
import os


@dataclass
class ModelArguments:
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface"}
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
    )


@dataclass
class DataArguments:
    train_data_path: str = field(
        default=None, metadata={"help": "Path to train corpus"}
    )
    eval_data_path: str = field(default=None, metadata={"help": "Path to eval corpus"})
    max_length: int = field(
        default=512,
        metadata={
            "help": "The maximum total input sequence length after tokenization for input text."
        },
    )

    def __post_init__(self):
        if not os.path.exists(self.train_data_path):
            raise FileNotFoundError(
                f"cannot find file: {self.train_data_path}, please set a true path"
            )

        if not os.path.exists(self.eval_data_path):
            raise FileNotFoundError(
                f"cannot find file: {self.eval_data_path}, please set a true path"
            )

This defines the model-related and data-related arguments.
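As a quick sketch of how these dataclasses get filled (the flag values below are assumptions, and DataArguments.__post_init__ will raise unless the data files actually exist):

from transformers import HfArgumentParser
from arguments import ModelArguments, DataArguments

# hypothetical command-line flags, passed explicitly instead of via sys.argv
parser = HfArgumentParser((ModelArguments, DataArguments))
model_args, data_args = parser.parse_args_into_dataclasses([
    "--model_name_or_path", "hfl/chinese-macbert-large",
    "--train_data_path", "data/train.txt",
    "--eval_data_path", "data/dev.txt",
])
print(model_args.model_name_or_path, data_args.max_length)  # max_length defaults to 512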

dataset.py:

from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, DataCollatorWithPadding
from datasets import Dataset as dt
from typing import Any

from utils import build_dataframe_from_csv


class PairDataset(Dataset):
    def __init__(
        self, data_path: str, tokenizer: PreTrainedTokenizer, max_len: int
    ) -> None:
        df = build_dataframe_from_csv(data_path)
        self.dataset = dt.from_pandas(df, split="train")
        self.total_len = len(self.dataset)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return self.total_len

    def __getitem__(self, index) -> dict[str, Any]:
        query1 = self.dataset[index]["query1"]
        query2 = self.dataset[index]["query2"]
        label = self.dataset[index]["label"]

        encoding = self.tokenizer.encode_plus(
            query1,
            query2,
            truncation="only_second",
            max_length=self.max_len,
            padding=False,
        )

        encoding["label"] = label

        return encoding

The dataset class handles the format of the LCQMC dataset, i.e. a pair of sentences plus a numeric label, something like:

Hello.	Hi.	1
Nice to see you.	Nice	0

The dataset handling here differs from the earlier articles: we call encode_plus to join the text pair into a single input and truncate only the second text (truncation="only_second").

No padding is applied here; that is left to DataCollatorWithPadding, as illustrated below.
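To make the deferred padding concrete, here is a small sketch (tokenizer name assumed, toy sentences made up) of how DataCollatorWithPadding pads a mini-batch of such pair encodings:

from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-macbert-large")

def encode(q1, q2, label):
    # same recipe as PairDataset.__getitem__: pair packed into one sequence, no padding yet
    enc = tokenizer.encode_plus(q1, q2, truncation="only_second", max_length=128, padding=False)
    enc["label"] = label
    return enc

features = [encode("今天天气怎么样", "今天天气如何", 1), encode("好热", "我想吃火锅", 0)]

collator = DataCollatorWithPadding(tokenizer)
batch = collator(features)  # pads input_ids/attention_mask/token_type_ids to the batch maximum
print(batch["input_ids"].shape)  # (2, longest_length_in_this_batch)
print(batch["labels"])           # "label" is renamed to "labels" for the Trainer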

trainer.py:

import torch
from transformers.trainer import Trainer
from typing import Optional
import os
import logging

TRAINING_ARGS_NAME = "training_args.bin"

from modeling import SentenceBert

logger = logging.getLogger(__name__)


class CrossTrainer(Trainer):
    def compute_loss(self, model: SentenceBert, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        return model(inputs, labels)["loss"]

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        self.model.save_pretrained(output_dir)

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

We subclass the 🤗 Transformers Trainer and override its compute_loss and _save methods.

With that in place we can use 🤗 Transformers to train our model.

utils.py:

import torch
import pandas as pd
from scipy.stats import pearsonr, spearmanr
from typing import Tuple


def build_dataframe_from_csv(dataset_csv: str) -> pd.DataFrame:
    df = pd.read_csv(
        dataset_csv,
        sep="\t",
        header=None,
        names=["query1", "query2", "label"],
    )
    return df


def compute_spearmanr(x, y):
    return spearmanr(x, y).correlation


def compute_pearsonr(x, y):
    return pearsonr(x, y)[0]


def find_best_acc_and_threshold(scores, labels, high_score_more_similar: bool):
    """Copied from https://github.com/UKPLab/sentence-transformers/tree/master"""
    assert len(scores) == len(labels)
    rows = list(zip(scores, labels))
    rows = sorted(rows, key=lambda x: x[0], reverse=high_score_more_similar)

    max_acc = 0
    best_threshold = -1
    # number of positive examples seen so far
    positive_so_far = 0
    # number of negative examples remaining
    remaining_negatives = sum(labels == 0)

    for i in range(len(rows) - 1):
        score, label = rows[i]
        if label == 1:
            positive_so_far += 1
        else:
            remaining_negatives -= 1

        acc = (positive_so_far + remaining_negatives) / len(labels)
        if acc > max_acc:
            max_acc = acc
            best_threshold = (rows[i][0] + rows[i + 1][0]) / 2

    return max_acc, best_threshold


def metrics(y: torch.Tensor, y_pred: torch.Tensor) -> Tuple[float, float, float, float]:
    TP = ((y_pred == 1) & (y == 1)).sum().float()  # True Positive
    TN = ((y_pred == 0) & (y == 0)).sum().float()  # True Negative
    FN = ((y_pred == 0) & (y == 1)).sum().float()  # False Negative
    FP = ((y_pred == 1) & (y == 0)).sum().float()  # False Positive
    p = TP / (TP + FP).clamp(min=1e-8)  # Precision
    r = TP / (TP + FN).clamp(min=1e-8)  # Recall
    F1 = 2 * r * p / (r + p).clamp(min=1e-8)  # F1 score
    acc = (TP + TN) / (TP + TN + FP + FN).clamp(min=1e-8)  # Accuracy
    return acc, p, r, F1


def compute_metrics(predicts, labels):
    return metrics(labels, predicts)

This defines some helper functions; find_best_acc_and_threshold, which searches for the accuracy-maximizing decision threshold, is copied from the sentence-transformers library.

Besides accuracy, we also compute the Spearman rank correlation between the predicted similarity scores and the ground-truth labels.
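For example, the threshold search and correlation helpers can be applied to the scores returned by predict roughly like this (the scores and labels below are made up):

import numpy as np
from utils import find_best_acc_and_threshold, compute_spearmanr, compute_pearsonr

# toy sigmoid scores and their gold labels, for illustration only
scores = np.array([0.95, 0.10, 0.80, 0.30, 0.60])
labels = np.array([1, 0, 1, 0, 1])

max_acc, best_threshold = find_best_acc_and_threshold(scores, labels, high_score_more_similar=True)
print(f"max_acc: {max_acc:.4f}, best_threshold: {best_threshold:.6f}")
print(f"spearman: {compute_spearmanr(scores, labels):.4f}, pearson: {compute_pearsonr(scores, labels):.4f}")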

Finally we define the training and test scripts.

train.py:

from transformers import (
    set_seed,
    HfArgumentParser,
    TrainingArguments,
    DataCollatorWithPadding,
)

import logging
import os
from pathlib import Path
from datetime import datetime

from modeling import SentenceBert
from trainer import CrossTrainer
from arguments import DataArguments, ModelArguments
from dataset import PairDataset

logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)


def main():
    parser = HfArgumentParser((TrainingArguments, DataArguments, ModelArguments))
    training_args, data_args, model_args = parser.parse_args_into_dataclasses()

    output_dir = f"{training_args.output_dir}/{model_args.model_name_or_path.replace('/', '-')}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    training_args.output_dir = output_dir

    logger.info(f"Training parameters {training_args}")
    logger.info(f"Data parameters {data_args}")
    logger.info(f"Model parameters {model_args}")

    set_seed(training_args.seed)

    model = SentenceBert(
        model_args.model_name_or_path,
        max_length=data_args.max_length,
        trust_remote_code=True,
    )
    tokenizer = model.tokenizer

    train_dataset = PairDataset(
        data_args.train_data_path,
        tokenizer,
        data_args.max_length,
    )
    eval_dataset = PairDataset(
        data_args.eval_data_path,
        tokenizer,
        data_args.max_length,
    )

    trainer = CrossTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=DataCollatorWithPadding(tokenizer),
        tokenizer=tokenizer,
    )
    Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)

    trainer.train()
    trainer.save_model()


if __name__ == "__main__":
    main()

Training

Based on train.py, train.sh passes in the relevant arguments:

timestamp=$(date +%Y%m%d%H%M)
logfile="train_${timestamp}.log"

# change CUDA_VISIBLE_DEVICES
CUDA_VISIBLE_DEVICES=1 nohup python train.py \
    --model_name_or_path=hfl/chinese-macbert-large \
    --output_dir=output \
    --train_data_path=data/train.txt \
    --eval_data_path=data/dev.txt \
    --num_train_epochs=3 \
    --save_total_limit=5 \
    --learning_rate=2e-5 \
    --weight_decay=0.01 \
    --warmup_ratio=0.01 \
    --bf16=True \
    --save_strategy=epoch \
    --per_device_train_batch_size=64 \
    --report_to="none" \
    --remove_unused_columns=False \
    --max_length=128 \
    > "$logfile" 2>&1 &

Adjust these arguments to your own environment; here we use HIT's chinese-macbert-large pre-trained model.

Notes:

  • Setting bf16=True speeds up training without hurting the results; if your hardware does not support it, try fp16 instead.
  • The other arguments can be tuned to your needs.
100%|██████████| 18655/18655 [1:15:47<00:00,  5.06it/s]
100%|██████████| 18655/18655 [1:15:47<00:00,  4.10it/s]
{'loss': 0.0464, 'grad_norm': 4.171152591705322, 'learning_rate': 1.6785791639592811e-07, 'epoch': 4.96}
{'train_runtime': 4547.2543, 'train_samples_per_second': 262.539, 'train_steps_per_second': 4.102, 'train_loss': 0.11396670312096753, 'epoch': 5.0}

We trained for 5 epochs to see how it would do, but found that the 3-epoch checkpoint actually performed slightly better, so that is the one used for testing.

Testing

test.py: see the complete code linked below for the test script.

test.sh:

# change CUDA_VISIBLE_DEVICES
CUDA_VISIBLE_DEVICES=0 python test.py \
    --model_name_or_path=output/checkpoint-11193 \
    --test_data_path=data/test.txt

Output:

TestArguments(model_name_or_path='output/checkpoint-11193', test_data_path='data/test.txt', max_length=64, batch_size=128)
Running Inference: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:19<00:00,  5.05it/s]
max_acc: 0.8954, best_threshold: 0.924443
spearman corr: 0.7996 |  pearson_corr corr: 0.8021 | compute time: 19.44s
accuracy=0.895 precision=0.911 recal=0.876 f1 score=0.8934

On the test set, accuracy reaches 89.5% and the Spearman coefficient 79.96 (×100), both the best results in this article series so far, though not as high as hoped. In practice a cross-encoder is usually used for fine-grained reranking, e.g. re-scoring a candidate list and keeping the top-k, as sketched below.
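To make that reranking use case concrete, here is a rough sketch of scoring a candidate list and keeping the top-k (the checkpoint path is the one from test.sh; the query and candidates are made up):

from modeling import SentenceBert

# trained cross-encoder checkpoint and hypothetical candidate texts
model = SentenceBert("output/checkpoint-11193")
query = "今天天气怎么样"
candidates = ["今天天气如何", "天气预报说今天晴", "我想吃火锅"]

pairs = [(query, cand) for cand in candidates]
scores = model.predict(pairs).view(-1).tolist()  # one relevance score per candidate
ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
top_k = ranked[:2]  # keep only the best 2 candidates
print(top_k)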

Below is a comparison of the training approaches covered recently:

Model (objective)            Accuracy (%)    Spearman (×100)    Pearson (×100)
Bi-Encoder (Classifier)      89.18           79.82              75.14
Bi-Encoder (Regression)      88.32           77.95              76.68
Bi-Encoder (Contrastive)     88.81           77.95              57.01
Bi-Encoder (CoSENT)          89.40           79.89              77.03
Cross-Encoder                89.54           79.96              80.21

Complete Code

Complete code: →click here←

References

  1. [Paper notes] Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks
