当前位置: 首页 > news >正文

心法利器[118] | 向量检索组件(含代码)

心法利器

本栏目主要和大家一起讨论近期自己学习的心得和体会。具体介绍:仓颉专项:飞机大炮我都会,利器心法我还有。

2023年新的文章合集已经发布,获取方式看这里:又添十万字-CS的陋室2023年文章合集来袭,更有历史文章合集,欢迎下载。

往期回顾

向量召回这个东西我在好多文章里都有说过,包括代码在内,在很多带有代码的文章里也有提到,类似这些:

  • 心法利器[114] | 通用大模型文本分类实践(含代码)

  • 心法利器[104] | 基础RAG-向量检索模块(含代码)

但是目前还没有单独做成组件,这次我就把这个东西做成一个单独的组件开源出来,供大家快速修改使用吧。

代码地址:https://github.com/ZBayes/poc_project/tree/main/vec_searcher

有关什么是向量检索,以及对应向量表征模型的关系,见这篇文章:心法利器[16] | 向量表征和向量召回,此处不再赘述。

叠个甲,这只是最简单,平时可以用来做实验的案例,受限于性能、存储大小等因素,主要是因为索引通常对硬盘、内存都有一定要求,尤其是在数据比较大的情况下,生产场景下,更推荐用elasticsearch、milvus等分布式工具,FAISS大部分情况主要只用在单机,当然,数据量不大,单机也可以(十万级都还可以)。

核心代码

代码结构

代码的结构如下:

|-- data
|   |-- index
|   |   `-- vec_index_toutiao_20240702_DEBUG
|   |       |-- forward_index.txt
|   |       `-- invert_index.faiss
|   `-- toutiao_cat_data
|       `-- toutiao_cat_data.txt
|-- script
|   `-- build_vec_index.py
|-- searcher.py
|-- utils
|   `-- data_processing.py
|-- vec_model
|   |-- simcse_model.py
|   `-- vec_model.py
`-- vec_searcher|-- vec_index.py`-- vec_searcher.py

这里其实最关键的就是vec_searcher文件夹里的内容,就是最简单的向量搜索引擎,如果要把向量表征模型囊括进来,那就是searcher,这个包了一层,把向量检索引擎和模型都放一块,build_vec_index.py是把数据读入构造向量召回的脚本。

此处为了举例,使用的是同一个项目下的数据:https://github.com/aceimnorstuvwxz/toutiao-text-classfication-dataset。

向量索引核心

首先是比较关键的代码,vec_searcher/vec_index.py是向量索引的核心,里面包装了FAISS中最常用的几个功能函数。

import faiss
from loguru import loggerclass VecIndex:def __init__(self) -> None:self.index = ""def build(self, index_dim):description = "HNSW64"measure = faiss.METRIC_L2self.index = faiss.index_factory(index_dim, description, measure)def insert(self, vec):self.index.add(vec)def batch_insert(self, vecs):self.index.add(vecs)def load(self, read_path):# read_path: XXX.indexself.index = faiss.read_index(read_path)def save(self, save_path):# save_path: XXX.indexfaiss.write_index(self.index, save_path)def search(self, vec, num):# id, distancereturn self.index.search(vec, num)

看具体函数名应该能理解具体含义,内部可能改造空间比较大的就是build里面对FAISS索引的基本控制,具体的其他配置可以看FAISS的官方文档做扩展使用:https://github.com/facebookresearch/faiss/wiki

然后是在其基础上包裹的新的一层,vec_searcher/vec_searcher.py,这个模块除了对FAISS的基础工作做了包装,最重要的是把索引对应的详情信息放了进来,即搜索系统中的常说的正排,一般情况,除了比较显著的文本外,还有很多别的信息,例如标题、更新时间这些,虽说不需要一起做向量表征进入索引,但是还是希望存起来的,正排就是为了存储这个信息,一般是通过Key-value的形式来存储的,这点大家可以重点看代码里forward有关的处理过程,直接看代码吧。

import os, json
from loguru import logger
from vec_searcher.vec_index import VecIndexclass VecSearcher:def __init__(self):self.invert_index = VecIndex() # 检索倒排,使用的是索引是VecIndexself.forward_index = [] # 检索正排,实质上只是个list,通过ID获取对应的内容self.INDEX_FOLDER_PATH_TEMPLATE = "data/index/{}"def build(self, index_dim, index_name):self.index_name = index_nameself.index_folder_path = self.INDEX_FOLDER_PATH_TEMPLATE.format(index_name)if not os.path.exists(self.index_folder_path) or not os.path.isdir(self.index_folder_path):os.mkdir(self.index_folder_path)self.invert_index = VecIndex()self.invert_index.build(index_dim)self.forward_index = []def insert(self, vec, doc):self.invert_index.insert(vec)# self.invert_index.batch_insert(vecs)self.forward_index.append(doc)def save(self):with open(self.index_folder_path + "/forward_index.txt", "w", encoding="utf8") as f:for data in self.forward_index:f.write("{}\n".format(json.dumps(data, ensure_ascii=False)))self.invert_index.save(self.index_folder_path + "/invert_index.faiss")def load(self, index_name):self.index_name = index_nameself.index_folder_path = self.INDEX_FOLDER_PATH_TEMPLATE.format(index_name)self.invert_index = VecIndex()self.invert_index.load(self.index_folder_path + "/invert_index.faiss")self.forward_index = []with open(self.index_folder_path + "/forward_index.txt", encoding="utf8") as f:for line in f:self.forward_index.append(json.loads(line.strip()))def search(self, vecs, nums = 5):search_res = self.invert_index.search(vecs, nums)recall_list = []for idx in range(nums):# recall_list_idx, recall_list_detail, distancerecall_list.append([search_res[1][0][idx], self.forward_index[search_res[1][0][idx]], search_res[0][0][idx]])# recall_list = list(filter(lambda x: x[2] < 100, result))return recall_list

这两个应该是向量检索里最有关的两块内容了,但一般向量搜索是离不开向量表征的,最常见的把文本转为向量,就需要用向量表征模型,一遍比较常用的是BGE系列的有关模型,但在个人的实验上,simcse仍旧非常能打,此处我就用这个模型为例讲讲如何把向量模型包进来。

把模型纳入向量搜索系统

首先是模型的定义,此处我弄了两层,一层是最基础的模型vec_model/simcse_model.py,一个是通用的模型接口vec_model/vec_model.py,在更换模型的时候,直接更换基础模型这个文件应该就够了。

首先得vec_model/simcse_model.py里面,只需要俩,一个加载一个推理即可,当然了,更粗暴的可以把tokenizer也放进来,此处我没放。

import torch
import torch.nn as nn
from loguru import logger
from tqdm import tqdm
from transformers import BertConfig, BertModel, BertTokenizerclass SimcseModel(nn.Module):# https://blog.csdn.net/qq_44193969/article/details/126981581def __init__(self, pretrained_bert_path, pooling="cls") -> None:super(SimcseModel, self).__init__()self.pretrained_bert_path = pretrained_bert_pathself.config = BertConfig.from_pretrained(self.pretrained_bert_path)self.model = BertModel.from_pretrained(self.pretrained_bert_path, config=self.config)self.model.eval()# self.model = Noneself.pooling = poolingdef forward(self, input_ids, attention_mask, token_type_ids):out = self.model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)return out.last_hidden_state[:, 0]

然后是vec_model/vec_model.py,接口相对通用,然后内部可以自己切换想要的模型。

import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import loggerfrom transformers import BertTokenizerfrom vec_model.simcse_model import SimcseModelimport onnxruntime as ortclass VectorizeModel:def __init__(self, ptm_model_path, device = "cpu") -> None:self.tokenizer = BertTokenizer.from_pretrained(ptm_model_path)self.model = SimcseModel(pretrained_bert_path=ptm_model_path, pooling="cls")# print(self.model)self.model.eval()self.DEVICE = torch.device('cuda' if torch.cuda.is_available() else "cpu")# self.DEVICE = devicelogger.info(self.DEVICE)self.model.to(self.DEVICE)self.pdist = nn.PairwiseDistance(2)def predict_vec(self,query):q_id = self.tokenizer(query, max_length = 200, truncation=True, padding="max_length", return_tensors='pt')with torch.no_grad():q_id_input_ids = q_id["input_ids"].squeeze(1).to(self.DEVICE)q_id_attention_mask = q_id["attention_mask"].squeeze(1).to(self.DEVICE)q_id_token_type_ids = q_id["token_type_ids"].squeeze(1).to(self.DEVICE)q_id_pred = self.model(q_id_input_ids, q_id_attention_mask, q_id_token_type_ids)return q_id_preddef predict_vec_request(self, query):q_id_pred = self.predict_vec(query)return q_id_pred.cpu().numpy().tolist()def predict_sim(self, q1, q2):q1_v = self.predict_vec(q1)q2_v = self.predict_vec(q2)sim = F.cosine_similarity(q1_v[0], q2_v[0], dim=-1)return sim.cpu().numpy().tolist()class VectorizeModel_v2(VectorizeModel):def __init__(self, ptm_model_path, onnx_path, providers=['CUDAExecutionProvider']) -> None:# ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']self.tokenizer = BertTokenizer.from_pretrained(ptm_model_path)self.model = ort.InferenceSession(onnx_path, providers=providers)self.pdist = nn.PairwiseDistance(2)def _to_numpy(self, tensor):return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()def predict_vec(self,query):q_id = self.tokenizer(query, max_length = 200, truncation=True, padding="max_length", return_tensors='pt')input_feed = {self.model.get_inputs()[0].name: self._to_numpy(q_id["input_ids"]),self.model.get_inputs()[1].name: self._to_numpy(q_id["attention_mask"]),self.model.get_inputs()[2].name: self._to_numpy(q_id["token_type_ids"]),}return torch.tensor(self.model.run(None, input_feed=input_feed)[0])def predict_sim(self, q1, q2):q1_v = self.predict_vec(q1)q2_v = self.predict_vec(q2)sim = F.cosine_similarity(q1_v[0], q2_v[0], dim=-1)return sim.numpy().tolist()if __name__ == "__main__":import time,randomfrom tqdm import tqdmdevice = torch.device('cuda' if torch.cuda.is_available() else "cpu")# device = ""# vec_model = VectorizeModel('C:/work/tool/huggingface/models/simcse-chinese-roberta-wwm-ext', device=device)vec_model = VectorizeModel_v2('C:/work/tool/huggingface/models/simcse-chinese-roberta-wwm-ext',"./data/model_simcse_roberta_output_20240211.onnx",providers=['CUDAExecutionProvider'])# 单测# q = ["你好啊"]# print(vec_model.predict_vec(q))# print(vec_model.predict_sim("你好呀","你好啊"))tmp_queries = ["你好啊", "今天天气怎么样", "我要暴富"]# 开始批跑batch_sizes = [1,2,4,8,16]for b in batch_sizes:for i in tqdm(range(100),desc="warmup"):tmp_q = []for i in range(b):tmp_q.append(random.choice(tmp_queries))vec_model.predict_vec(tmp_q)for i in tqdm(range(1000),desc="batch_size={}".format(b)):tmp_q = []for i in range(b):tmp_q.append(random.choice(tmp_queries))vec_model.predict_vec(tmp_q)

另外,代码里还提供了ONNX模型的加载和使用接口,也可以参考使用。

有了模型,就可以和前面的向量检索合并成为一个比较完整但简单的检索系统了,我们看searcher.py,这里除了用FAISS做了向量召回外,还做了个简单的RANK,当然了,这个RANK还可以用别的模型代替,这里从简就用了原本的模型做余弦了。

import copy
from loguru import logger
from vec_searcher.vec_searcher import VecSearcher
from vec_model.vec_model import VectorizeModelclass Searcher:def __init__(self, model_path, vec_search_path):self.vec_model = VectorizeModel(model_path)logger.info("load vec_model done")self.vec_searcher = VecSearcher()self.vec_searcher.load(vec_search_path)logger.info("load vec_searcher done")def rank(self, query, recall_result):rank_result = []for idx in range(len(recall_result)):new_sim = self.vec_model.predict_sim(query, recall_result[idx][1][0])rank_item = copy.deepcopy(recall_result[idx])rank_item.append(new_sim)rank_result.append(copy.deepcopy(rank_item))rank_result.sort(key=lambda x: x[3], reverse=True)return rank_resultdef search(self, query, nums=3):logger.info("request: {}".format(query))q_vec = self.vec_model.predict_vec(query).cpu().numpy()recall_result = self.vec_searcher.search(q_vec, nums)rank_result = self.rank(query, recall_result)# rank_result = list(filter(lambda x:x[4] > 0.8, rank_result))logger.info("response: {}".format(rank_result))return rank_resultif __name__ == "__main__":VEC_MODEL_PATH = "C:/work/tool/huggingface/models/simcse-chinese-roberta-wwm-ext"VEC_INDEX_DATA = "vec_index_toutiao_20240702_DEBUG"searcher = Searcher(VEC_MODEL_PATH, VEC_INDEX_DATA)q = "小产权房"print(searcher.search(q))

使用案例

为了展示用法,我是用了项目内同一个数据集作为测试,即这个数据集:https://github.com/aceimnorstuvwxz/toutiao-text-classfication-dataset。反正只要有数据能存进去然后能搜出来就行了对吧(doge)。

上面有一段没提到但是挺重要的脚本,就是入库构造索引的脚本,这块因为会涉及具体需求、数据结构会有变化,所以没写在上面的通用脚本里,所以在大家使用时需要微调的,我给出在这个项目数据下的具体操作:

# coding=utf-8
# Filename:    build_vec_index.py
# Author:      ZENGGUANRONG
# Date:        2024-09-07
# description: 构造向量索引脚本import json,torch,copy,random
from tqdm import tqdm
from loguru import loggerfrom utils.data_processing import load_toutiao_data
from vec_model.vec_model import VectorizeModel
from vec_searcher.vec_searcher import VecSearcher if __name__ == "__main__":# 0. 必要配置MODE = "DEBUG"VERSION = "20240907"VEC_MODEL_PATH = "C:/work/tool/huggingface/models/simcse-chinese-roberta-wwm-ext"SOURCE_INDEX_DATA_PATH = "./data/toutiao_cat_data/toutiao_cat_data.txt" # 数据来源:https://github.com/aceimnorstuvwxz/toutiao-text-classfication-datasetVEC_INDEX_DATA = "vec_index_toutiao_{}_{}".format(VERSION,MODE)# TESE_DATA_PATH = "./data/toutiao_cat_data/test_set_{}_{}.txt".format(VERSION,MODE)RANDOM_SEED = 100DEVICE = torch.device('cuda' if torch.cuda.is_available() else "cpu")# TEST_SIZE = 0.1# 类目体系CLASS_INFO = [["100", '民生-故事', 'news_story'],["101", '文化-文化', 'news_culture'],["102", '娱乐-娱乐', 'news_entertainment'],["103", '体育-体育', 'news_sports'],["104", '财经-财经', 'news_finance'],# ["105", '时政 新时代', 'nineteenth'],["106", '房产-房产', 'news_house'],["107", '汽车-汽车', 'news_car'],["108", '教育-教育', 'news_edu' ],["109", '科技-科技', 'news_tech'],["110", '军事-军事', 'news_military'],# ["111" 宗教 无,凤凰佛教等来源],["112", '旅游-旅游', 'news_travel'],["113", '国际-国际', 'news_world'],["114", '证券-股票', 'stock'],["115", '农业-三农', 'news_agriculture'],["116", '电竞-游戏', 'news_game']]ID2CN_MAPPING = {}for idx in range(len(CLASS_INFO)):ID2CN_MAPPING[CLASS_INFO[idx][0]] = CLASS_INFO[idx][1]# 1. 加载数据、模型# 1.1 加载模型vec_model = VectorizeModel(VEC_MODEL_PATH, DEVICE)index_dim = len(vec_model.predict_vec("你好啊")[0])# 1.2 加载数据toutiao_index_data = load_toutiao_data(SOURCE_INDEX_DATA_PATH)source_index_data = copy.deepcopy(toutiao_index_data)logger.info("load data done: {}".format(len(source_index_data)))if MODE == "DEBUG":random.shuffle(source_index_data)source_index_data = source_index_data[:10000]# 2. 创建索引并灌入数据# 2.1 构造索引vec_searcher = VecSearcher()vec_searcher.build(index_dim, VEC_INDEX_DATA)# 2.2 推理向量vectorize_result = []for q in tqdm(source_index_data, desc="VEC MODEL RUNNING"):vec = vec_model.predict_vec(q[0]).cpu().numpy()tmp_result = copy.deepcopy(q)tmp_result.append(vec)vectorize_result.append(copy.deepcopy(tmp_result))# 2.3 开始存入for idx in tqdm(range(len(vectorize_result)), desc="INSERT INTO INDEX"):vec_searcher.insert(vectorize_result[idx][2], vectorize_result[idx][:2])# 3. 保存# 3.1 索引保存vec_searcher.save()logger.info("build done: {}".format(VEC_INDEX_DATA))

这里的整体流程我都有中文注释,我就不展开说了,大家按照实际情况调整就行,实际上可能只有0和1有变化而已。

在入库以后,就可以开始进行搜索了,此时用searcher就能把内容给搜出来了,回头看下面这串代码,就是加载并且尝试使用了。

if __name__ == "__main__":VEC_MODEL_PATH = "C:/work/tool/huggingface/models/simcse-chinese-roberta-wwm-ext"VEC_INDEX_DATA = "vec_index_toutiao_20240702_DEBUG"searcher = Searcher(VEC_MODEL_PATH, VEC_INDEX_DATA)q = "小产权房"print(searcher.search(q))

af90b2014281c074215cbe4a7013d8c2.png

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • [论文笔记] t-SNE数据可视化
  • 数字逻辑设计基础
  • 数据结构——单链表相关操作
  • 和服务端系统的通信
  • 正则表达式之grep
  • [C#学习笔记]注释
  • 信息学奥赛初赛天天练-86-NOIP2014普及组-基础题5-球盒问题、枚举算法、单源最短路、Dijkstra算法、Bellman-Ford算法
  • 营养方案调整执行流程 第十篇
  • Spring Batch
  • FPGA开发:Verilog数字设计基础
  • [论文笔记]QLoRA: Efficient Finetuning of Quantized LLMs
  • ios免签H5
  • tiptap parseHTML renderHTML 使用
  • 系统架构师考试学习笔记第三篇——架构设计高级知识(19)嵌入式系统架构设计理论与实践
  • 安卓下载工具箱_3.8.1/去浏览器跳转登录就是会员
  • php的引用
  • 【干货分享】SpringCloud微服务架构分布式组件如何共享session对象
  • angular2 简述
  • C++入门教程(10):for 语句
  • iOS 颜色设置看我就够了
  • JAVA多线程机制解析-volatilesynchronized
  • Laravel5.4 Queues队列学习
  • MD5加密原理解析及OC版原理实现
  • mysql常用命令汇总
  • ReactNativeweexDeviceOne对比
  • Redis中的lru算法实现
  • Vue--数据传输
  • 从0实现一个tiny react(三)生命周期
  • 规范化安全开发 KOA 手脚架
  • 入口文件开始,分析Vue源码实现
  • 深入浏览器事件循环的本质
  • 使用阿里云发布分布式网站,开发时候应该注意什么?
  • 适配iPhoneX、iPhoneXs、iPhoneXs Max、iPhoneXr 屏幕尺寸及安全区域
  • 想晋级高级工程师只知道表面是不够的!Git内部原理介绍
  • Salesforce和SAP Netweaver里数据库表的元数据设计
  • 回归生活:清理微信公众号
  • 继 XDL 之后,阿里妈妈开源大规模分布式图表征学习框架 Euler ...
  • ​力扣解法汇总1802. 有界数组中指定下标处的最大值
  • ######## golang各章节终篇索引 ########
  • #设计模式#4.6 Flyweight(享元) 对象结构型模式
  • #数据结构 笔记三
  • #我与Java虚拟机的故事#连载02:“小蓝”陪伴的日日夜夜
  • (1)SpringCloud 整合Python
  • (C语言)输入一个序列,判断是否为奇偶交叉数
  • (el-Date-Picker)操作(不使用 ts):Element-plus 中 DatePicker 组件的使用及输出想要日期格式需求的解决过程
  • (Oracle)SQL优化基础(三):看懂执行计划顺序
  • (创新)基于VMD-CNN-BiLSTM的电力负荷预测—代码+数据
  • (非本人原创)史记·柴静列传(r4笔记第65天)
  • (附表设计)不是我吹!超级全面的权限系统设计方案面世了
  • (含react-draggable库以及相关BUG如何解决)固定在左上方某盒子内(如按钮)添加可拖动功能,使用react hook语法实现
  • (学习总结16)C++模版2
  • (已解决)报错:Could not load the Qt platform plugin “xcb“
  • (游戏设计草稿) 《外卖员模拟器》 (3D 科幻 角色扮演 开放世界 AI VR)
  • (转)shell中括号的特殊用法 linux if多条件判断
  • .equals()到底是什么意思?