当前位置: 首页 > news >正文

昇思25天学习打卡营第16天 | Vision Transformer图像分类

昇思25天学习打卡营第16天 | Vision Transformer图像分类

文章目录

  • 昇思25天学习打卡营第16天 | Vision Transformer图像分类
    • Vision Transform(ViT)模型
      • Transformer
        • Attention模块
        • Encoder模块
      • ViT模型输入
    • 模型构建
      • Multi-Head Attention模块
      • Encoder模块
      • Patch Embedding模块
      • ViT网络
    • 总结
    • 打卡

Vision Transform(ViT)模型

ViT是NLP和CV领域的融合,可以在不依赖于卷积操作的情况下在图像分类任务上达到很好的效果。

ViT模型的主体结构是基于Transformer的Encoder部分。

Transformer

Transformer由很多Encoder和Decoder模块构成,包括多头注意力(Multi-Head Attention)层,Feed Forward层,Normalization层和残差连接(Residual Connection)。
encoder-decoder
多头注意力结构基于自注意力机制(Self-Attention),是多个Self-Attention的并行组成。

Attention模块

Attention的核心在于为输入向量的每个单词学习一个权重。

  1. 最初的输入向量首先经过Embedding层映射为Q(Query),K(Key),V(Value)三个向量。
  2. 通过将Q和所有K进行点乘初一维度平方根,得到向量间的相似度,通过softmax获取每词向量之间的关系权重。
  3. 利用关系权重对词向量的V加权求和,得到自注意力值。
    self-attention
    多头注意力机制只是对self-attention的并行化:
    multi-head-attention
Encoder模块

ViT中的Encoder相对于标准Transformer,主要在于将Normolization放在self-attention和Feed Forward之前,其他结构与标准Transformer相同。
vit-encoder

ViT模型输入

传统Transformer主要应用于自然语言处理的一维词向量,而图像时二维矩阵的堆叠。
在ViT中:

  1. 通过卷积将输入图像在每个channel上划分为 16 × 16 16\times 16 16×16个patch。如果输入 224 × 224 224\times224 224×224的图像,则每一个patch的大小为 14 × 14 14\times 14 14×14
  2. 将每一个patch拉伸为一个一维向量,得到近似词向量堆叠的效果。如将 14 × 14 14\times14 14×14展开为 196 196 196的向量。
    这一部分Patch Embedding用来替换Transformer中Word Embedding,用作网络中的图像输入。

模型构建

Multi-Head Attention模块

from mindspore import nn, opsclass Attention(nn.Cell):def __init__(self,dim: int,num_heads: int = 8,keep_prob: float = 1.0,attention_keep_prob: float = 1.0):super(Attention, self).__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = ms.Tensor(head_dim ** -0.5)self.qkv = nn.Dense(dim, dim * 3)self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)self.out = nn.Dense(dim, dim)self.out_drop = nn.Dropout(p=1.0-keep_prob)self.attn_matmul_v = ops.BatchMatMul()self.q_matmul_k = ops.BatchMatMul(transpose_b=True)self.softmax = nn.Softmax(axis=-1)def construct(self, x):"""Attention construct."""b, n, c = x.shapeqkv = self.qkv(x)qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))q, k, v = ops.unstack(qkv, axis=0)attn = self.q_matmul_k(q, k)attn = ops.mul(attn, self.scale)attn = self.softmax(attn)attn = self.attn_drop(attn)out = self.attn_matmul_v(attn, v)out = ops.transpose(out, (0, 2, 1, 3))out = ops.reshape(out, (b, n, c))out = self.out(out)out = self.out_drop(out)return out

Encoder模块

from typing import Optional, Dictclass FeedForward(nn.Cell):def __init__(self,in_features: int,hidden_features: Optional[int] = None,out_features: Optional[int] = None,activation: nn.Cell = nn.GELU,keep_prob: float = 1.0):super(FeedForward, self).__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.dense1 = nn.Dense(in_features, hidden_features)self.activation = activation()self.dense2 = nn.Dense(hidden_features, out_features)self.dropout = nn.Dropout(p=1.0-keep_prob)def construct(self, x):"""Feed Forward construct."""x = self.dense1(x)x = self.activation(x)x = self.dropout(x)x = self.dense2(x)x = self.dropout(x)return xclass ResidualCell(nn.Cell):def __init__(self, cell):super(ResidualCell, self).__init__()self.cell = celldef construct(self, x):"""ResidualCell construct."""return self.cell(x) + xclass TransformerEncoder(nn.Cell):def __init__(self,dim: int,num_layers: int,num_heads: int,mlp_dim: int,keep_prob: float = 1.,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: nn.Cell = nn.LayerNorm):super(TransformerEncoder, self).__init__()layers = []for _ in range(num_layers):normalization1 = norm((dim,))normalization2 = norm((dim,))attention = Attention(dim=dim,num_heads=num_heads,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob)feedforward = FeedForward(in_features=dim,hidden_features=mlp_dim,activation=activation,keep_prob=keep_prob)layers.append(nn.SequentialCell([ResidualCell(nn.SequentialCell([normalization1, attention])),ResidualCell(nn.SequentialCell([normalization2, feedforward]))]))self.layers = nn.SequentialCell(layers)def construct(self, x):"""Transformer construct."""return self.layers(x)

Patch Embedding模块

class PatchEmbedding(nn.Cell):MIN_NUM_PATCHES = 4def __init__(self,image_size: int = 224,patch_size: int = 16,embed_dim: int = 768,input_channels: int = 3):super(PatchEmbedding, self).__init__()self.image_size = image_sizeself.patch_size = patch_sizeself.num_patches = (image_size // patch_size) ** 2self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)def construct(self, x):"""Path Embedding construct."""x = self.conv(x)b, c, h, w = x.shapex = ops.reshape(x, (b, c, h * w))x = ops.transpose(x, (0, 2, 1))return x

ViT网络

from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameterdef init(init_type, shape, dtype, name, requires_grad):"""Init."""initial = initializer(init_type, shape, dtype).init_data()return Parameter(initial, name=name, requires_grad=requires_grad)class ViT(nn.Cell):def __init__(self,image_size: int = 224,input_channels: int = 3,patch_size: int = 16,embed_dim: int = 768,num_layers: int = 12,num_heads: int = 12,mlp_dim: int = 3072,keep_prob: float = 1.0,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: Optional[nn.Cell] = nn.LayerNorm,pool: str = 'cls') -> None:super(ViT, self).__init__()self.patch_embedding = PatchEmbedding(image_size=image_size,patch_size=patch_size,embed_dim=embed_dim,input_channels=input_channels)num_patches = self.patch_embedding.num_patchesself.cls_token = init(init_type=Normal(sigma=1.0),shape=(1, 1, embed_dim),dtype=ms.float32,name='cls',requires_grad=True)self.pos_embedding = init(init_type=Normal(sigma=1.0),shape=(1, num_patches + 1, embed_dim),dtype=ms.float32,name='pos_embedding',requires_grad=True)self.pool = poolself.pos_dropout = nn.Dropout(p=1.0-keep_prob)self.norm = norm((embed_dim,))self.transformer = TransformerEncoder(dim=embed_dim,num_layers=num_layers,num_heads=num_heads,mlp_dim=mlp_dim,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob,drop_path_keep_prob=drop_path_keep_prob,activation=activation,norm=norm)self.dropout = nn.Dropout(p=1.0-keep_prob)self.dense = nn.Dense(embed_dim, num_classes)def construct(self, x):"""ViT construct."""x = self.patch_embedding(x)cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))x = ops.concat((cls_tokens, x), axis=1)x += self.pos_embeddingx = self.pos_dropout(x)x = self.transformer(x)x = self.norm(x)x = x[:, 0]if self.training:x = self.dropout(x)x = self.dense(x)return x

总结

这一节对Transformer进行介绍,包括Attention机制、并行化的Attention以及Encoder模块。由于传统Transformer主要作用于一维的词向量,因此二维图像需要被转换为类似的一维词向量堆叠,在ViT中通过将Patch Embedding解决这一问题,并用来代替传统Transformer中的Word Embedding作为网络的输入。

打卡

在这里插入图片描述

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • JavaWeb入门程序解析(Spring官方骨架、配置起步依赖、SpringBoot父工程、内嵌Tomcat)
  • 2、电脑各部件品牌介绍 - 计算机硬件品牌系列文章
  • 数据结构(Java):力扣 二叉树面试OJ题(二)【进阶】
  • NLP篇5:自然语言处理预训练
  • 【python】多种回归算法对比气温预测
  • 云计算监控减少网络安全事件的五种方法
  • LinuxShell编程1———shell基础命令
  • 打印室预约小程序的设计
  • 【C++】类和对象·this指针
  • 科技出海|百分点科技智慧政务解决方案亮相非洲展会
  • HLS加密技术:保障流媒体内容安全的利器
  • WebAssembly与JavaScript的交互(1)
  • Mongodb文本索引
  • react页面指定dom转pdf导出
  • 网络通信介绍
  • 【跃迁之路】【585天】程序员高效学习方法论探索系列(实验阶段342-2018.09.13)...
  • Angular6错误 Service: No provider for Renderer2
  • Centos6.8 使用rpm安装mysql5.7
  • ES6语法详解(一)
  • Java编程基础24——递归练习
  • opencv python Meanshift 和 Camshift
  • 猫头鹰的深夜翻译:JDK9 NotNullOrElse方法
  • 目录与文件属性:编写ls
  • 前端相关框架总和
  • 使用Maven插件构建SpringBoot项目,生成Docker镜像push到DockerHub上
  • 使用parted解决大于2T的磁盘分区
  • 使用权重正则化较少模型过拟合
  • 微服务核心架构梳理
  • 微信小程序设置上一页数据
  • 一道面试题引发的“血案”
  • 一加3T解锁OEM、刷入TWRP、第三方ROM以及ROOT
  • 优秀架构师必须掌握的架构思维
  • 职业生涯 一个六年开发经验的女程序员的心声。
  • #13 yum、编译安装与sed命令的使用
  • #mysql 8.0 踩坑日记
  • #ubuntu# #git# repository git config --global --add safe.directory
  • #Z2294. 打印树的直径
  • (145)光线追踪距离场柔和阴影
  • (2022版)一套教程搞定k8s安装到实战 | RBAC
  • (HAL)STM32F103C6T8——软件模拟I2C驱动0.96寸OLED屏幕
  • (pojstep1.1.2)2654(直叙式模拟)
  • (Redis使用系列) Springboot 整合Redisson 实现分布式锁 七
  • (代码示例)使用setTimeout来延迟加载JS脚本文件
  • (二)Kafka离线安装 - Zookeeper下载及安装
  • (附源码)php投票系统 毕业设计 121500
  • (六)激光线扫描-三维重建
  • (十)【Jmeter】线程(Threads(Users))之jp@gc - Stepping Thread Group (deprecated)
  • (十一)JAVA springboot ssm b2b2c多用户商城系统源码:服务网关Zuul高级篇
  • (原創) 人會胖會瘦,都是自我要求的結果 (日記)
  • (转载)微软数据挖掘算法:Microsoft 时序算法(5)
  • (轉貼) 2008 Altera 亞洲創新大賽 台灣學生成果傲視全球 [照片花絮] (SOC) (News)
  • ******之网络***——物理***
  • .NET CORE 3.1 集成JWT鉴权和授权2
  • .net core 6 集成和使用 mongodb
  • .net core IResultFilter 的 OnResultExecuted和OnResultExecuting的区别