当前位置: 首页 > news >正文

Pytorch获取特征图

简单加载官方预训练模型

  • torchvision.models预定义了很多公开的模型结构
  • 如果pretrained参数设置为False,那么仅仅设定模型结构;如果设置为True,那么会启动一个下载流程,下载预训练参数
  • 如果只想调用模型,不想训练,那么设置model.eval()model.requires_grad_(False)
  • 想查看模型参数可以使用modulesnamed_modules,其中named_modules是一个长度为2的tuple,第一个变量是name,第二个变量是module本身。
# -*- coding: utf-8 -*-
from torch import nn
from torchvision import models

# load model. If pretrained is True, there will be a downloading process
model = models.vgg19(pretrained=True)
model.eval()
model.requires_grad_(False)

# get model component
features = model.features
modules = features.modules()
named_modules = features.named_modules()

# print modules
for module in modules:
    if isinstance(module, nn.Conv2d):
        weight = module.weight
        bias = module.bias
        print(module, weight.shape, bias.shape,
              weight.requires_grad, bias.requires_grad)
    elif isinstance(module, nn.ReLU):
        print(module)

print()
for named_module in named_modules:
    name = named_module[0]
    module = named_module[1]
    if isinstance(module, nn.Conv2d):
        weight = module.weight
        bias = module.bias
        print(name, module, weight.shape, bias.shape,
              weight.requires_grad, bias.requires_grad)
    elif isinstance(module, nn.ReLU):
        print(name, module)

图片预处理

  • 使用opencv和pil读图都可以使用transforms.ToTensor()把原本[H, W, 3]的数据转成[3, H, W]的tensor。但opencv要注意把数据改成RGB顺序。
  • vgg系列模型需要做normalization,建议配合torchvision.transforms来实现。
  • mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].

参考:https://pytorch.org/hub/pytorch_vision_vgg/

# -*- coding: utf-8 -*-
from PIL import Image
import cv2
import torch
from torchvision import transforms

# transforms for preprocess
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# load image using cv2
image_cv2 = cv2.imread('lena_std.bmp')
image_cv2 = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
image_cv2 = preprocess(image_cv2)

# load image using pil
image_pil = Image.open('lena_std.bmp')
image_pil = preprocess(image_pil)

# check whether image_cv2 and image_pil are same
print(torch.all(image_cv2 == image_pil))
print(image_cv2.shape, image_pil.shape)

提取单个特征图

如果只提取单层特征图,可以把模型截断,以节省算力和显存消耗。
下面索引之所以有+1是因为pytorch预训练模型里面第一个索引的module总是完整模块结构,第二个才开始子模块。

# -*- coding: utf-8 -*-
from PIL import Image
from torchvision import models
from torchvision import transforms

# load model. If pretrained is True, there will be a downloading process
model = models.vgg19(pretrained=True)
model = model.features[:16 + 1]  # 16 = conv3_4
model.eval()
model.requires_grad_(False)
model.to('cuda')
print(model)

# load and preprocess image
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.Resize(size=(224, 224))
])
image = Image.open('lena_std.bmp')
image = preprocess(image)
inputs = image.unsqueeze(0)  # add batch dimension
inputs = inputs.cuda()

# forward
output = model(inputs)
print(output.shape)

提取多个特征图

  • 第一种方式:逐层运行model,如果碰到了需要保存的feature map就存下来。
  • 第二种方式:使用register_forward_hook,使用这种方式需要用一个类把feature map以成员变量的形式缓存下来。
  • 两种方式的运行效率差不多
  • 第一种方式简单直观,但是只能处理类似VGG这种没有跨层连接的网络;第二种方式更加通用。
# -*- coding: utf-8 -*-
from PIL import Image
import torch
from torchvision import models
from torchvision import transforms

# load model. If pretrained is True, there will be a downloading process
model = models.vgg19(pretrained=True)
model = model.features[:16 + 1]  # 16 = conv3_4
model.eval()
model.requires_grad_(False)
model.to('cuda')

# check module name
for named_module in model.named_modules():
    name = named_module[0]
    module = named_module[1]
    print('-------- %s --------' % name)
    print(module)
    print()

# load and preprocess image
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.Resize(size=(224, 224))
])
image = Image.open('lena_std.bmp')
image = preprocess(image)
inputs = image.unsqueeze(0)  # add batch dimension
inputs = inputs.cuda()

# forward - 1
layers = [2, 7, 8, 9, 16]
layers = sorted(set(layers))
feature_maps = {}
feature = inputs
for i in range(max(layers) + 1):
    feature = model[i](feature)
    if i in layers:
        feature_maps[i] = feature
for key in feature_maps:
    print(key, feature_maps.get(key).shape)


# forward - 2
class FeatureHook:
    def __init__(self, module):
        self.inputs = None
        self.output = None
        self.hook = module.register_forward_hook(self.get_features)

    def get_features(self, module, inputs, output):
        self.inputs = inputs
        self.output = output


layer_names = ['2', '7', '8', '9', '16']
hook_modules = []
for named_module in model.named_modules():
    name = named_module[0]
    module = named_module[1]
    if name in layer_names:
        hook_modules.append(module)

hooks = [FeatureHook(module) for module in hook_modules]
output = model(inputs)
features = [hook.output for hook in hooks]
for feature in features:
    print(feature.shape)

# check correctness
for i, layer in enumerate(layers):
    feature1 = feature_maps.get(layer)
    feature2 = features[i]
    print(torch.all(feature1 == feature2))

使用第二种方式(register_forward_hook),resnet特征图也可以顺利拿到。
而由于resnet的model已经不可以用model[i]的形式索引,所以无法使用第一种方式。

# -*- coding: utf-8 -*-
from PIL import Image
from torchvision import models
from torchvision import transforms

# load model. If pretrained is True, there will be a downloading process
model = models.resnet18(pretrained=True)
model.eval()
model.requires_grad_(False)
model.to('cuda')

# check module name
for named_module in model.named_modules():
    name = named_module[0]
    module = named_module[1]
    print('-------- %s --------' % name)
    print(module)
    print()

# load and preprocess image
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.Resize(size=(224, 224))
])
image = Image.open('lena_std.bmp')
image = preprocess(image)
inputs = image.unsqueeze(0)  # add batch dimension
inputs = inputs.cuda()


class FeatureHook:
    def __init__(self, module):
        self.inputs = None
        self.output = None
        self.hook = module.register_forward_hook(self.get_features)

    def get_features(self, module, inputs, output):
        self.inputs = inputs
        self.output = output


layer_names = [
    'conv1',
    'layer1.0.relu',
    'layer2.0.conv1'
]

hook_modules = []
for named_module in model.named_modules():
    name = named_module[0]
    module = named_module[1]
    if name in layer_names:
        hook_modules.append(module)

hooks = [FeatureHook(module) for module in hook_modules]
output = model(inputs)
features = [hook.output for hook in hooks]
for feature in features:
    print(feature.shape)

问题来了,resnet这种类型的网络结构怎么截断?
使用如下命令就可以,print查看需要截断到哪里,然后用nn.Sequential重组即可。
需注意重组后网络的module_name会发生变化

print(list(model.children())
model = torch.nn.Sequential(*list(model.children())[:6])

相关文章:

  • yaml文件格式说明及编写教程
  • java计算机毕业设计能源类网站平台源码+系统+数据库+lw文档+mybatis+运行部署
  • 个人做量化交易一定不靠谱?
  • 迅为RK3588开发板编译环境Ubuntu20.04编译配置-增加交换内存
  • 申报绿色工厂的条件是什么
  • Android面试官:入职大厂的Android程序员具备怎样的专业素养?
  • 六大设计原则
  • VMware vCenter Server 7 升级
  • Word控件Spire.Doc 【段落处理】教程(十二):如何在 C# 中管理 word 文档的分页
  • 在线批注审片工具有哪些?分秒帧团队版与个人版的主要区别
  • 中国内窥镜行业市场投资战略规划分析报告
  • flink 自定义序列化对象Sink/Source
  • 目前期货开户手续费比较透明
  • 深度跳转-scheme
  • 2022 全球 AI 模型周报
  • 时间复杂度分析经典问题——最大子序列和
  • [Vue CLI 3] 配置解析之 css.extract
  • 《Javascript数据结构和算法》笔记-「字典和散列表」
  • 【159天】尚学堂高琪Java300集视频精华笔记(128)
  • Dubbo 整合 Pinpoint 做分布式服务请求跟踪
  • exports和module.exports
  • happypack两次报错的问题
  • input的行数自动增减
  • java正则表式的使用
  • laravel 用artisan创建自己的模板
  • orm2 中文文档 3.1 模型属性
  • Quartz实现数据同步 | 从0开始构建SpringCloud微服务(3)
  • Redis 懒删除(lazy free)简史
  • VirtualBox 安装过程中出现 Running VMs found 错误的解决过程
  • Vue2.0 实现互斥
  • 大主子表关联的性能优化方法
  • 基于HAProxy的高性能缓存服务器nuster
  • 简单实现一个textarea自适应高度
  • 如何优雅的使用vue+Dcloud(Hbuild)开发混合app
  • 思否第一天
  • 微信开源mars源码分析1—上层samples分析
  • 优化 Vue 项目编译文件大小
  • NLPIR智能语义技术让大数据挖掘更简单
  • UI设计初学者应该如何入门?
  • (1)Nginx简介和安装教程
  • (Redis使用系列) Springboot 实现Redis 同数据源动态切换db 八
  • (附源码)spring boot儿童教育管理系统 毕业设计 281442
  • (附源码)计算机毕业设计SSM教师教学质量评价系统
  • (一)基于IDEA的JAVA基础12
  • (原创)攻击方式学习之(4) - 拒绝服务(DOS/DDOS/DRDOS)
  • (原創) 是否该学PetShop将Model和BLL分开? (.NET) (N-Tier) (PetShop) (OO)
  • (转)linux自定义开机启动服务和chkconfig使用方法
  • (转)Scala的“=”符号简介
  • (转)程序员技术练级攻略
  • **PHP分步表单提交思路(分页表单提交)
  • .form文件_一篇文章学会文件上传
  • .Net Core/.Net6/.Net8 ,启动配置/Program.cs 配置
  • .NET Windows:删除文件夹后立即判断,有可能依然存在
  • .NET 实现 NTFS 文件系统的硬链接 mklink /J(Junction)
  • .net 怎么循环得到数组里的值_关于js数组