模型中间部分的卷积可视化
整体代码如下:
def forward(self, x):
    """Return the feature map produced by the first convolution.

    This is a visualisation-only variant of ``forward``: it stops right
    after ``self.conv1`` and hands the result back to the caller instead of
    running the remaining layers.

    Args:
        x: Input batch. NOTE(review): the accompanying script feeds a
            (1, 3, 512, 512) tensor -- confirm for other callers.

    Returns:
        The ``conv1`` output with the leading batch axis removed when that
        axis has size 1.
    """
    x = self.conv1(x)
    # squeeze(0) removes only the batch axis; a bare squeeze() would also
    # drop any other axis that happens to have size 1 (e.g. one channel).
    out_img3 = x.squeeze(0)
    print(out_img3.shape)
    print("经过第一个卷积之后的输出:", x.shape)
    # The caller indexes the result per channel, so it must be returned.
    return out_img3
import yaml
from omegaconf import OmegaConf
from pathlib import Path
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# Load the image and turn it into a (1, 3, 512, 512) batch tensor.
img_path = "E:/fasterrcnn/input_images/cat_dog.png"
# A PNG may be RGBA, palette or grayscale; force three channels so the
# tensor matches the 3-channel input the model is fed.
image = Image.open(img_path).convert("RGB")
transform = transforms.ToTensor()
tensor_image = transform(image)
transform_size = transforms.Resize((512, 512))
out = transform_size(tensor_image)
# Add the leading batch axis expected by the model.
tensor_image_batch = out.unsqueeze(0)

# NOTE(review): PoseHigherResolutionNet is not imported among the visible
# imports of this file; it has to be imported from the project that defines it.
yaml_file_path = 'F:/code/DEKR-main/experiments/coco/w32/w32_4x_reg03_bs10_512_adam_lr1e-3_coco_x140.yaml'
yaml_file = Path(yaml_file_path)
if yaml_file.exists():
    # Explicit encoding so the config parses the same on every platform.
    with yaml_file.open('r', encoding='utf-8') as file:
        cfg_dict = yaml.safe_load(file)
    cfg = OmegaConf.create(cfg_dict)

    model = PoseHigherResolutionNet(cfg)
    # forward() is expected to return the per-channel feature maps of the
    # first convolution with the batch axis removed.
    out_img3 = model.forward(tensor_image_batch)

    # Indices of the channels to visualise, e.g. channels 0, 15, 30, 45, 63.
    selected_channels = [0, 15, 30, 45, 63]

    # One figure with one subplot per selected channel.
    fig, axes = plt.subplots(1, len(selected_channels), figsize=(15, 5))

    # Draw every selected channel in its own subplot.
    for ax, channel in zip(axes, selected_channels):
        channel_image = out_img3[channel]  # feature map of this channel
        print(channel_image.shape)
        # Convolution activations are not limited to [0, 1]; converting
        # them through ToPILImage scales by 255 and casts to uint8, which
        # corrupts out-of-range values. Hand the raw array to imshow and
        # let matplotlib normalise it instead.
        feature_map = channel_image.detach().cpu().numpy()
        ax.imshow(feature_map, cmap='viridis')  # viridis colour map
        ax.set_title(f'Channel {channel}')
        ax.axis('off')  # hide the axes

    # Show the figure.
    plt.tight_layout()  # fit the subplots into the figure area
    plt.show()
else:
    print(f"Error: The YAML file {yaml_file_path} does not exist.")
首先加载图像
img_path = "E:/fasterrcnn/input_images/cat_dog.png"
# A PNG may be RGBA, palette or grayscale; force three channels so the
# resulting tensor matches the 3-channel input the model is fed.
image = Image.open(img_path).convert("RGB")
然后将图像转换为tensor,也就是数组的形式
# Build the torchvision ToTensor transform and apply it to the loaded image.
transform = transforms.ToTensor()
tensor_image = transform(image)
之后resize大小为(3,512,512)的形式
# Resize the image tensor to 512 x 512.
transform_size = transforms.Resize((512,512))
out = transform_size(tensor_image)
到了这里还不能直接作为模型的输入:模型的输入是四维的,形状为(batch_size, input_channel, height, width),
所以要在最前面加一个batch维度
# Insert a new axis of size 1 at position 0 (the batch axis).
tensor_image_batch = out.unsqueeze(0)
调用模型
# Build the network from the parsed configuration.
model = PoseHigherResolutionNet(cfg)
# forward() is expected to return the feature maps to visualise. The
# unused random `input` tensor was dropped: it shadowed the builtin and
# relied on `torch`, which is not among the file's imports.
out_img3 = model.forward(tensor_image_batch)
我们取前向传播的第一个卷积输出,输出的大小为(1,64,112,112),将输出x定义为x1,并用squeeze将其batch_size去掉变成(64,112,112)
def forward(self, x):
    """Return the feature map produced by the first convolution.

    Visualisation-only variant of ``forward``: it stops right after
    ``self.conv1`` and returns that output to the caller.

    Args:
        x: Input batch tensor passed to ``self.conv1``.

    Returns:
        The ``conv1`` output with the leading batch axis removed when that
        axis has size 1.
    """
    x = self.conv1(x)
    # squeeze(0) removes only the batch axis; a bare squeeze() would also
    # drop any other axis that happens to have size 1.
    out_img3 = x.squeeze(0)
    print(out_img3.shape)
    # The caller indexes the result per channel, so it must be returned.
    return out_img3
选择可视化的对应通道
# Indices of the channels to visualise, e.g. channels 0, 15, 30, 45 and 63.
selected_channels = [0, 15, 30, 45, 63]
创建一个包含多个通道子图的图形
# Create one figure holding a row of subplots, one per selected channel.
fig, axes = plt.subplots(1, len(selected_channels), figsize=(15, 5))
之后遍历通道并显示出来
# Walk the selected channels and draw each one in its own subplot.
for ax, channel in zip(axes, selected_channels):
    channel_image = out_img3[channel]  # feature map of this channel
    print(channel_image.shape)
    # Convolution activations are not limited to [0, 1]; converting them
    # through ToPILImage scales by 255 and casts to uint8, which corrupts
    # out-of-range values. Pass the raw array to imshow and let matplotlib
    # normalise it instead.
    feature_map = channel_image.detach().cpu().numpy()
    ax.imshow(feature_map, cmap='viridis')  # viridis colour map
    ax.set_title(f'Channel {channel}')  # subplot title
    ax.axis('off')  # hide the axes
显示
# Show the figure.
plt.tight_layout()  # adjust subplot spacing so the plots fill the figure
plt.show()
最后将上述步骤封装成函数,方便调用
def figure_trans(img_path="E:/fasterrcnn/input_images/cat_dog.png", size=(512, 512)):
    """Load an image and return it as a single-image batch tensor.

    Args:
        img_path: Path of the image file to load.
        size: Target (height, width) passed to ``transforms.Resize``.

    Returns:
        The resized image tensor with an extra leading batch axis of size 1.
    """
    # A PNG may be RGBA, palette or grayscale; force three channels so the
    # tensor has the 3-channel layout used for the model input.
    image = Image.open(img_path).convert("RGB")
    tensor_image = transforms.ToTensor()(image)
    resized = transforms.Resize(size)(tensor_image)
    return resized.unsqueeze(0)


def _plot_channels(feature_maps, selected_channels):
    """Draw the selected channels of ``feature_maps`` side by side."""
    # squeeze=False keeps ``axes`` two-dimensional even for one channel,
    # so the row can always be iterated.
    fig, axes = plt.subplots(
        1, len(selected_channels), figsize=(15, 5), squeeze=False
    )
    for ax, channel in zip(axes[0], selected_channels):
        channel_image = feature_maps[channel]
        print(channel_image.shape)
        # Convolution activations are not limited to [0, 1]; ToPILImage
        # would scale by 255 and cast to uint8, corrupting out-of-range
        # values. Let imshow normalise the raw array instead.
        ax.imshow(channel_image.detach().cpu().numpy(), cmap='viridis')
        ax.set_title(f'Channel {channel}')
        ax.axis('off')
    plt.tight_layout()
    plt.show()


def yml_plot(
    yaml_file_path='F:/code/DEKR-main/experiments/coco/w32/w32_4x_reg03_bs10_512_adam_lr1e-3_coco_x140.yaml',
    selected_channels=(0, 15, 30, 45, 63),
):
    """Build the model from a YAML config and plot first-conv feature maps.

    Prints an error message and returns when the config file is missing.

    Args:
        yaml_file_path: Path of the YAML configuration file.
        selected_channels: Indices of the feature-map channels to display.

    Returns:
        None.
    """
    yaml_file = Path(yaml_file_path)
    tensor_image_batch = figure_trans()
    if not yaml_file.exists():
        print(f"Error: The YAML file {yaml_file_path} does not exist.")
        return

    # Explicit encoding so the config parses the same on every platform.
    with yaml_file.open('r', encoding='utf-8') as file:
        cfg_dict = yaml.safe_load(file)
    cfg = OmegaConf.create(cfg_dict)

    model = PoseHigherResolutionNet(cfg)
    # forward() is expected to return the per-channel feature maps of the
    # first convolution with the batch axis removed.
    feature_maps = model.forward(tensor_image_batch)
    _plot_channels(feature_maps, selected_channels)


plt_figure = yml_plot()