当前位置: 首页 > news >正文

(pytorch进阶之路)CLIP模型 实现图像多模态检索任务

CLIP模型解决了一个多模态问题

代码地址:
https://github.com/yyz159756/CLIP-VIT-


文章目录

  • 概述
  • CLIP
  • 代码实现
    • 划分训练集和测试集
    • 统计所有图片的每个通道的均值和标准差
    • 搜索图片引擎
  • 边角料

概述

问题描述:输入一张照片,从数据库中找到最相近的一张照片

解决思路:将图片emb,用余弦 相似度计算图片emb,排序返回topk的图片

图片embbing预训练模型选择,图片表征模型那就有太多的选择了,一种是纯粹的图像分类的预训练模型,比如用resnet,VIT,VGG,纯粹基于图像识别得到的预训练模型,再或者是自监督学习的MAE得到的encoder emb,

这里我们使用CLIP模型,CLIP是基于图像和文本两个领域的数据训练出来的表征模型

为什么用CLIP模型,而不用视觉通用模型呢?
CLIP优点是同类型的文字和图像有着很高的相似度,所以可以完成一个多模态的搜索任务

CLIP项目地址:https://github.com/openai/CLIP
在这里插入图片描述
安装CLIP模型,注意要安装git
pip install git+https://github.com/openai/CLIP.git

CLIP是一个多模态的模型,它不仅仅能表征图像,还能够识别图像中的文字信息

果蔬分类数据集:
https://aistudio.baidu.com/aistudio/datasetdetail/119023/0
下载后解压到data/raw

只能下载数据集,只有train,因此我们还得自己构建测试集和验证集

CLIP

利用text信息监督视觉任务自训练,本质就是将分类任务化成了图文匹配任务,效果可与全监督方法相当

《Learning transferable visual models from natural language supervision》
在这里插入图片描述
核心思想:对比学习
I1是图片表征,T1是文字表征,我们希望In和Tn相似度尽可能的大,表示文字和图片是匹配的,要么相似要么不相似进行对比,学会如何将文字和图像连接起来的一个编码

一个是图像编码器(resnet或者vit模型),一个文本编码器(transformer模型),是一个经典的双塔模型

伪代码部分写的很清晰,图像经过图像编码器得到图像emb,文字经过编码器得到文字emb,经过l2正则化,直接对两个emb矩阵做dot矩阵乘法,做一个预测任务,预测这个图片是一个什么东西,loss计算交叉熵,总体模型十分的简单,所用的数据集比较大
在这里插入图片描述

应用的话,hugging face里面已经有CLIP的模型库了
https://huggingface.co/docs/transformers/model_doc/clip

或者用sentence_transformers库

代码实现

划分训练集和测试集

这部分代码挺常用了,一般图像分类任务数据集都是每个子文件名称表示类别名称,子文件夹中有具体的样本图片

预处理起码做两件事情,转化图片为RGB通道,图片大小规整到同样大小的形状如128×128(我电脑比较破,就64×64了)

split_data.py

统计所有图片的每个通道的均值和标准差

统计一个mean和var
statistic_mean_std.py
我提前统计好了
mean是[0.47043142, 0.43239759, 0.32576062]
var是[0.37251864, 0.35710091, 0.3417497]
因为resnet模型要求图片先归一化

搜索图片引擎

搜索图片引擎,顾名思义就是搜索图片用的,有两种搜索方式,一是以图搜图,二是文字搜图,我在github中实现的是以图搜图功能,文字搜图看最后的边角料就好了,官方给的例子就是以文字搜图的eg了,几行代码搞定,非常简单

下面开始以图搜图的实现

第一步,我们对图像计算表征向量

我们需要加载模型,加载模型可以使用timm和clip库

model = timm.create_model(args.model_name, pretrained=True)
model, preprocess = clip.load(model_name, device=device)

clip的多返回了一个preprocess预处理函数,我们就可以不用自己做预处理(Transformation)

加载好模型后我们需要对train数据集里面的所有的图片做一个表征

我们使用模型的某一层或者最后一层的特征作为图像的抽象表征

all_vectors = extract_features(args, model, image_path='data/val_images', preprocess=preprocess)

第二阶段,预测阶段,找到所有的测试图片,也做一个表征,对图片计算余弦相似度

similarity, keys = get_similarity_matrix(all_vectors)

我们都知道余弦相似度的计算公式
在这里插入图片描述
经过线性代数所学,分子可以矩阵×矩阵的转置
[num, dim] × [dim, num] = [num, num]
那么第一行就是第一张照片和所有照片表征的内积,以此类推

那么分母就是向量的模

我们用np的linalg.norm对每一列求一个范数([num, ]),同时维度保持不变,再和它自身转置相乘
那么第一行就是第一张照片的表征模和所有照片表征的模的乘积

计算完余弦相似度后,我们就是要对测试图片和其他图片的余弦相似度进行排序,排序后进行输出

以上是以图搜图的功能

边角料

如果要完成文字搜图的功能那是完全类似的,将文字编码,文字编码和所有图片做一个余弦相似度即可,我们返回topk个图片

下面直接给官方例子,非常的哇塞,也非常的简单有没有

import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

相关文章:

  • RabbitMQ 集群部署及配置
  • python @classmethod详解
  • JSP超市管理系统myeclipse定制开发SQLServer数据库网页模式java编程jdbc
  • FreeRTOS 软件定时器的使用
  • 软件测试培训到底值不值得参加?
  • IIC通信协议
  • FUP AMD300-27便携式拉曼食品安全分析仪 检测微痕量农兽药残留 非法添加
  • 高被引论文有什么特征?
  • MMDet 3.0:目标检测新基准与前沿
  • (入门自用)--C++--抽象类--多态原理--虚表--1020
  • 风险:一些Web3安全工具
  • 【RBF预测】基于时空 RBF-NN 实现混沌时间序列预测附matlab代码
  • spring 入门
  • 【Git】Git基本配置和常用命令
  • 软考中级(软件设计师)——数据库系统(上下午各占6-8分)
  • Android Studio:GIT提交项目到远程仓库
  • Asm.js的简单介绍
  • E-HPC支持多队列管理和自动伸缩
  • Fastjson的基本使用方法大全
  • java8 Stream Pipelines 浅析
  • 阿里云前端周刊 - 第 26 期
  • 百度地图API标注+时间轴组件
  • 道格拉斯-普克 抽稀算法 附javascript实现
  • 短视频宝贝=慢?阿里巴巴工程师这样秒开短视频
  • 分享一个自己写的基于canvas的原生js图片爆炸插件
  • 聊一聊前端的监控
  • 每天10道Java面试题,跟我走,offer有!
  • 排序算法之--选择排序
  • 事件委托的小应用
  • 一文看透浏览器架构
  • 用jquery写贪吃蛇
  • 与 ConTeXt MkIV 官方文档的接驳
  • kubernetes资源对象--ingress
  • PostgreSQL 快速给指定表每个字段创建索引 - 1
  • ​​​​​​​sokit v1.3抓手机应用socket数据包: Socket是传输控制层协议,WebSocket是应用层协议。
  • ​你们这样子,耽误我的工作进度怎么办?
  • # Apache SeaTunnel 究竟是什么?
  • #pragma 指令
  • #考研#计算机文化知识1(局域网及网络互联)
  • #控制台大学课堂点名问题_课堂随机点名
  • $(function(){})与(function($){....})(jQuery)的区别
  • (1)SpringCloud 整合Python
  • (iPhone/iPad开发)在UIWebView中自定义菜单栏
  • (ZT)北大教授朱青生给学生的一封信:大学,更是一个科学的保证
  • (二)pulsar安装在独立的docker中,python测试
  • (附源码)计算机毕业设计SSM教师教学质量评价系统
  • *** 2003
  • .locked1、locked勒索病毒解密方法|勒索病毒解决|勒索病毒恢复|数据库修复
  • .net 简单实现MD5
  • .Net 转战 Android 4.4 日常笔记(4)--按钮事件和国际化
  • .NET上SQLite的连接
  • .Net转Java自学之路—基础巩固篇十三(集合)
  • .pyc文件是什么?
  • @SentinelResource详解
  • [100天算法】-实现 strStr()(day 52)