当前位置: 首页 > news >正文

如何查看resnet网络的中间输出特征和卷积核的参数

查看中间层的特征,需要在定义Model时,在forward时,将中间要显示的层输出。

    def forward(self, x):outputs = []x = self.conv1(x)outputs.append(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)outputs.append(x)# x = self.layer2(x)# x = self.layer3(x)# x = self.layer4(x)## if self.include_top:#     x = self.avgpool(x)#     x = torch.flatten(x, 1)#     x = self.fc(x)return outputs

这里在convert1后和layer1后添加到一个列表中,然后输出。后面的就不进行卷积操作了。

 

import torch.nn as nn
import torchclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channel, out_channel, stride=1, downsample=None):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channel)self.relu = nn.ReLU()self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channel)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += identityout = self.relu(out)return outclass Bottleneck(nn.Module):expansion = 4def __init__(self, in_channel, out_channel, stride=1, downsample=None):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=1, stride=1, bias=False)  # squeeze channelsself.bn1 = nn.BatchNorm2d(out_channel)# -----------------------------------------self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, stride=stride, bias=False, padding=1)self.bn2 = nn.BatchNorm2d(out_channel)# -----------------------------------------self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,kernel_size=1, stride=1, bias=False)  # unsqueeze channelsself.bn3 = nn.BatchNorm2d(out_channel*self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += identityout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, blocks_num, num_classes=1000, include_top=True):super(ResNet, self).__init__()self.include_top = include_topself.in_channel = 64self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,padding=3, bias=False)self.bn1 = nn.BatchNorm2d(self.in_channel)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, blocks_num[0])self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)if self.include_top:self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')def _make_layer(self, block, channel, block_num, stride=1):downsample = Noneif stride != 1 or self.in_channel != channel * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(channel * block.expansion))layers = []layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))self.in_channel = channel * block.expansionfor _ in range(1, block_num):layers.append(block(self.in_channel, channel))return nn.Sequential(*layers)def forward(self, x):outputs = []x = self.conv1(x)outputs.append(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)outputs.append(x)# x = self.layer2(x)# x = self.layer3(x)# x = self.layer4(x)## if self.include_top:#     x = self.avgpool(x)#     x = torch.flatten(x, 1)#     x = self.fc(x)return outputsdef resnet34(num_classes=1000, include_top=True):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet101(num_classes=1000, include_top=True):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)

 然后就可以在预测的时候输出中间层。

import torch
from alexnet_model import AlexNet
from resnet_model import resnet34
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from torchvision import transformsdata_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# data_transform = transforms.Compose(
#     [transforms.Resize(256),
#      transforms.CenterCrop(224),
#      transforms.ToTensor(),
#      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# create model
model = AlexNet(num_classes=5)
# model = resnet34(num_classes=5)
# load model weights
model_weight_path = "./AlexNet.pth"  # "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path))
print(model)# load image
img = Image.open("../tulip.jpg")
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)# forward
out_put = model(img)
for feature_map in out_put:# [N, C, H, W] -> [C, H, W]im = np.squeeze(feature_map.detach().numpy())# [C, H, W] -> [H, W, C]im = np.transpose(im, [1, 2, 0])# show top 12 feature mapsplt.figure()for i in range(12):ax = plt.subplot(3, 4, i+1)# [H, W, C]plt.imshow(im[:, :, i], cmap='gray')plt.show()

输出卷积核的参数

import torch
from alexnet_model import AlexNet
from resnet_model import resnet34
import matplotlib.pyplot as plt
import numpy as np# create model
model = AlexNet(num_classes=5)
# model = resnet34(num_classes=5)
# load model weights
model_weight_path = "./AlexNet.pth"  # "resNet34.pth"
model.load_state_dict(torch.load(model_weight_path))
print(model)weights_keys = model.state_dict().keys()
for key in weights_keys:# remove num_batches_tracked para(in bn)if "num_batches_tracked" in key:continue# [kernel_number, kernel_channel, kernel_height, kernel_width]weight_t = model.state_dict()[key].numpy()# read a kernel information# k = weight_t[0, :, :, :]# calculate mean, std, min, maxweight_mean = weight_t.mean()weight_std = weight_t.std(ddof=1)weight_min = weight_t.min()weight_max = weight_t.max()print("mean is {}, std is {}, min is {}, max is {}".format(weight_mean,weight_std,weight_max,weight_min))# plot hist imageplt.close()weight_vec = np.reshape(weight_t, [-1])plt.hist(weight_vec, bins=50)plt.title(key)plt.show()

相关文章:

  • 工厂模式~
  • 怎么培养孩子的学习习惯?
  • 【MybatisPlus】BaseMapper详解,举例说明
  • 探索React中的类组件和函数组件
  • 学AI,3种人,3种学法
  • QT 自定义信号
  • YoloV7改进策略:卷积改进|MogaNet——高效的多阶门控聚合网络
  • Claude 3 Sonnet 模型现已在亚马逊云科技的 Amazon Bedrock 正式可用!
  • GPT实战系列-LangChain如何构建基通义千问的多工具链
  • 数据库--SQL语言-1
  • 深入了解二叉搜索树:原理、实现与应用
  • C语言-写一个简单的Web服务器(一)
  • uniapp+node.js前后端做帖子模块:发布帖子评论(社区管理平台的小程序)
  • 链表中的经典问题——反转链表
  • C#拾遗补漏之goto跳转语句
  • 2017-08-04 前端日报
  • javascript数组去重/查找/插入/删除
  • js中的正则表达式入门
  • Linux编程学习笔记 | Linux IO学习[1] - 文件IO
  • node.js
  • vue总结
  • 那些年我们用过的显示性能指标
  • 配置 PM2 实现代码自动发布
  • 前端设计模式
  • 一个6年java程序员的工作感悟,写给还在迷茫的你
  • 一些基于React、Vue、Node.js、MongoDB技术栈的实践项目
  • 用 vue 组件自定义 v-model, 实现一个 Tab 组件。
  • 用quicker-worker.js轻松跑一个大数据遍历
  • Mac 上flink的安装与启动
  • PostgreSQL 快速给指定表每个字段创建索引 - 1
  • zabbix3.2监控linux磁盘IO
  • ​ ​Redis(五)主从复制:主从模式介绍、配置、拓扑(一主一从结构、一主多从结构、树形主从结构)、原理(复制过程、​​​​​​​数据同步psync)、总结
  • # 计算机视觉入门
  • ###51单片机学习(2)-----如何通过C语言运用延时函数设计LED流水灯
  • #中的引用型是什么意识_Java中四种引用有什么区别以及应用场景
  • $.proxy和$.extend
  • (1)(1.11) SiK Radio v2(一)
  • (4)(4.6) Triducer
  • (4)事件处理——(2)在页面加载的时候执行任务(Performing tasks on page load)...
  • (C语言版)链表(三)——实现双向链表创建、删除、插入、释放内存等简单操作...
  • (done) ROC曲线 和 AUC值 分别是什么?
  • (附源码)springboot 基于HTML5的个人网页的网站设计与实现 毕业设计 031623
  • (转)visual stdio 书签功能介绍
  • ***php进行支付宝开发中return_url和notify_url的区别分析
  • **登录+JWT+异常处理+拦截器+ThreadLocal-开发思想与代码实现**
  • .net core 6 集成 elasticsearch 并 使用分词器
  • .net core使用ef 6
  • .net6+aspose.words导出word并转pdf
  • .NET实现之(自动更新)
  • /bin/bash^M: bad interpreter: No such file or directory
  • @require_PUTNameError: name ‘require_PUT‘ is not defined 解决方法
  • @transactional 方法执行完再commit_当@Transactional遇到@CacheEvict,你的代码是不是有bug!...
  • [17]JAVAEE-HTTP协议
  • [2016.7.Test1] T1 三进制异或
  • [C#基础知识系列]专题十七:深入理解动态类型