当前位置: 首页 > news >正文

句向量模型之SimCSE——Pytorch

文章目录

  • 模型简介
    • Unsupervised SimCSE
      • 数据
      • 模型
      • 效果
    • Supervised SimCSE
      • 数据
      • 模型
      • 效果
  • 总结

模型简介

SimCSE模型主要分为两大块,一个是无监督的部分,一个是有监督的部分。整体结构如下图所示:

请添加图片描述

论文地址: https://arxiv.org/pdf/2104.08821.pdf

Unsupervised SimCSE

数据

对于无监督的部分, 最巧妙的是采用Dropout做数据增强, 来构建正例, 从而构建一个正样本对, 而负样本对则是在同一个batch中的其他句子.

那么有人会问了, 为何一个句子, 输入到模型两次, 会得到两个不同的向量呢?

这是因为: 模型中存在dropout层, 神经元随机失活会导致同一个句子在训练阶段输入到模型中得到的输出会不一样.

通过代码来看, 更直观一点:

class TrainDataset(Dataset):
    def __init__(self, data, tokenizer, model_type="unsup"):
        self.data = data
        self.tokenizer = tokenizer
        self.model_type = model_type

    def text2id(self, text):
        if self.model_type == "unsup":
            text_ids = self.tokenizer([text, text], max_length=MAXLEN, truncation=True, padding='max_length', return_tensors='pt')
        elif self.model_type == "sup":
            text_ids = self.tokenizer([text[0], text[1], text[2]], max_length=MAXLEN, truncation=True, padding='max_length', return_tensors='pt')

        return text_ids

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.text2id(self.data[index])

可见, 同一句话, 走两次Bert的Encoder, 生成两个相似的句向量, 当作正例.

模型

class SimcseUnsupModel(nn.Module):
    def __init__(self, pretrained_bert_path, drop_out) -> None:
        super(SimcseUnsupModel, self).__init__()

        self.pretrained_bert_path = pretrained_bert_path
        config = BertConfig.from_pretrained(self.pretrained_bert_path)
        config.attention_probs_dropout_prob = drop_out
        config.hidden_dropout_prob = drop_out
        self.bert = BertModel.from_pretrained(self.pretrained_bert_path, config=config)
    
    def forward(self, input_ids, attention_mask, token_type_ids, pooling="cls"):
        out = self.bert(input_ids, attention_mask, token_type_ids, output_hidden_states=True)

        if pooling == "cls":
            return out.last_hidden_state[:, 0]
        if pooling == "pooler":
            return out.pooler_output
        if pooling == 'last-avg':
            last = out.last_hidden_state.transpose(1, 2)
            return torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)
        if self.pooling == 'first-last-avg':
            first = out.hidden_states[1].transpose(1, 2)
            last = out.hidden_states[-1].transpose(1, 2)
            first_avg = torch.avg_pool1d(first, kernel_size=last.shape[-1]).squeeze(-1)
            last_avg = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)
            avg = torch.cat((first_avg.unsqueeze(1), last_avg.unsqueeze(1)), dim=1)
            return torch.avg_pool1d(avg.transpose(1, 2), kernel_size=2).squeeze(-1)
        
        # 有实验表明cls的pooling方式效果最好

细心的同学已经发现了, 什么simcse, 明明就是bert嘛.
没错, 与Bert相比, Simcse只改变了drop_out, 利用Bert做数据增强, 但在计算Loss时, Simcse引入了对比Loss

    def train(self, train_dataloader, dev_dataloader):
        self.model.train()
        for batch_idx, source in enumerate(tqdm(train_dataloader), start=1):
            real_batch_num = source.get('input_ids').shape[0] # source.get('input_ids').shape [64, 2, 64]
            input_ids = source.get('input_ids').view(real_batch_num * 2, -1).to(self.device) # shape[128, 64]
            attention_mask = source.get('attention_mask').view(real_batch_num * 2, -1).to(self.device) # shape[128, 64]
            token_type_ids = source.get('token_type_ids').view(real_batch_num * 2, -1).to(self.device) # shape[128, 64]

            out = self.model(input_ids, attention_mask, token_type_ids) # out.shape [128, 768]  
            loss = self.simcse_unsup_loss(out)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if batch_idx % 10 == 0:     
                logger.info(f'loss: {loss.item():.4f}')
                corrcoef = self.eval(dev_dataloader)
                self.model.train()
                if self.best_loss > corrcoef:
                    self.best_loss = corrcoef
                    torch.save(self.model.state_dict(), self.model_save_path)
                    logger.info(f"higher corrcoef: {self.best_loss:.4f} in batch: {batch_idx}, save model")


    def simcse_unsup_loss(self, y_pred):
        y_true = torch.arange(y_pred.shape[0], device=self.device)
        y_true = (y_true - y_true % 2 * 2) + 1
        sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
        sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12
        sim = sim / 0.05
        loss = F.cross_entropy(sim, y_true)
        return loss

train函数中的source为Bert的tokenizer输出, 包含三个字短段: input_ids, token_type_ids, attention_mask
如下图:
请添加图片描述
input_ids的第一维为batch_size, 第二维是输入的句子数量, 输入了两个句子(同一个句子输入bert两次), 所以第二维是2, 第三维是句子的max_length

接下来我们看loss的计算过程, 将每一步拆解开:

1、给128个句子, 生成0-127的索引

y_true = torch.arange(y_pred.shape[0], device=self.device)

请添加图片描述

2、生成每个句子对应的真实的标签

y_true = (y_true - y_true % 2 * 2) + 1

请添加图片描述
注意看这一步的y_true与第一步y_true的区别.

这里的y_true, 实际上是每个句子对应的正例在这一个batch中的索引, 比如:

与第0个句子相似的句子索引为1
与第1个句子相似的句子索引为0

与第2个句子相似的句子索引为3
与第2个句子相似的句子索引为2

注意我是从第0个句子开始算的

3、两两计算相似度

sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)

y_pred维度为[128, 768]

sim的维度为[128, 128]

每一行表示当前句子与其他句子的相似度, 此时, 对角线上的数值应该为1请添加图片描述

4、将对角线上的数值放大为一个较大的数, 消除对角线上自身loss的影响(负无穷计算交叉熵时几乎为0)

sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12

5、乘以超参数温度系数, 至于为什么是0.05, 只能说实验表明, 0.05效果好

sim = sim / 0.05

6、用交叉熵损失表示对比损失, 将相似句子看作分类, 拉近与正例的距离, 拉远与负例的距离, 同一个batch中, 除了输入bert两次的那个句子互为正例, 其他句子都是负例

loss = F.cross_entropy(sim, y_true)

效果

请添加图片描述

Supervised SimCSE

数据

与无监督不同, 无监督的输入为单个text句子, 而有监督的数据集为 [text, text+, text-]的三元组
请添加图片描述

模型

模型部分与有监督一样, 也是利用bert的encode做编码, 取cls句向量

我们重点看一下不一样的部分, loss计算:

    def simcse_sup_loss(self, y_pred):
        y_true = torch.arange(y_pred.shape[0], device=self.device)
        use_row = torch.where((y_true + 1) % 3 != 0)[0]
        y_true = (use_row - use_row % 3 * 2) + 1
        sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
        sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12
        sim = torch.index_select(sim, 0, use_row)
        sim = sim / 0.05
        loss = F.cross_entropy(sim, y_true)
        return loss

1、生成0-191的索引

y_true = torch.arange(y_pred.shape[0], device=self.device)

2、选择使用的索引, 每第三句没有label, 第三句为负例, 不使用第三句, 把同一个batch内的其他句子当作负例

use_row = torch.where((y_true + 1) % 3 != 0)[0]

3、丢弃第三句后的真实label

y_true = (use_row - use_row % 3 * 2) + 1

请添加图片描述

4、两两计算相似度, 此时sim的维度是[192, 192], 包含了第三句的负例

sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)

请添加图片描述
5、消除对角线上维度的影响

sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12

6、挑选出有用的行

sim = torch.index_select(sim, 0, use_row)

请添加图片描述

7、计算交叉熵损失, 与无监督的方法一致

loss = F.cross_entropy(sim, y_true)

效果

请添加图片描述

总结

大道至简

全部代码已上传至Github, 链接: https://github.com/seanzhang-zhichen/simcse-pytorch

数据集: 提取码: hlva

相关文章:

  • 简单旅游景点HTML网页设计作品 DIV布局故宫介绍网页模板代码 DW家乡网站制作成品 web网页制作与实现
  • 图解redis(四)——高可用篇
  • LQ0048 交换瓶子【无标题】
  • 《SpringBoot篇》11.JPA常用注解只需一个表
  • 不想手敲代码?Jupyter Notebook 又一利器 Visual Python
  • 【mysql体系结构】InnoDB索引页结构
  • Roson的Qt之旅 #123 QNetworkConfigurationManager网络配置管理
  • 【数据结构与算法】ArrayList的模拟实现
  • Spring5源码之IOC的Bean管理之xml
  • DHCP 服务
  • [架构之路-20]:目标系统 - 硬件平台 - 嵌入式系统硬件电路基础:架构、设计流程、总线、外设、基本电路、编码
  • 关系代数 运算
  • (附源码)计算机毕业设计ssm本地美食推荐平台
  • python使用cv2库实现图像的读取处理显示和保存
  • 二道题:分组顺序向下填充 和 标注数据整理
  • Angular 4.x 动态创建组件
  • Apache Spark Streaming 使用实例
  • flask接收请求并推入栈
  • MySQL-事务管理(基础)
  • opencv python Meanshift 和 Camshift
  • 反思总结然后整装待发
  • 简单易用的leetcode开发测试工具(npm)
  • 浅谈JavaScript的面向对象和它的封装、继承、多态
  • 它承受着该等级不该有的简单, leetcode 564 寻找最近的回文数
  • 移动端 h5开发相关内容总结(三)
  • 深度学习之轻量级神经网络在TWS蓝牙音频处理器上的部署
  • "无招胜有招"nbsp;史上最全的互…
  • #include到底该写在哪
  • #我与Java虚拟机的故事#连载13:有这本书就够了
  • $.ajax()参数及用法
  • (32位汇编 五)mov/add/sub/and/or/xor/not
  • (52)只出现一次的数字III
  • (javascript)再说document.body.scrollTop的使用问题
  • (pojstep1.3.1)1017(构造法模拟)
  • (二十四)Flask之flask-session组件
  • (一)插入排序
  • .bat批处理(十):从路径字符串中截取盘符、文件名、后缀名等信息
  • .NET C#版本和.NET版本以及VS版本的对应关系
  • .net core 6 redis操作类
  • .NET Core IdentityServer4实战-开篇介绍与规划
  • .NET MVC之AOP
  • .NET 使用 JustAssembly 比较两个不同版本程序集的 API 变化
  • .net打印*三角形
  • .NET导入Excel数据
  • @data注解_一枚 架构师 也不会用的Lombok注解,相见恨晚
  • @RequestParam,@RequestBody和@PathVariable 区别
  • @Transactional 详解
  • [20150321]索引空块的问题.txt
  • [C#]winform制作仪表盘好用的表盘控件和使用方法
  • [C/C++随笔] char与unsigned char区别
  • [CareerCup][Google Interview] 实现一个具有get_min的Queue
  • [CUDA手搓]从零开始用C++ CUDA搭建一个卷积神经网络(LeNet),了解神经网络各个层背后算法原理
  • [DM复习]Apriori算法-国会投票记录关联规则挖掘(上)
  • [hibernate]基本值类型映射之日期类型
  • [IT生活推荐]大家一起来玩游戏喽,来的都进!