当前位置: 首页 > news >正文

深度学习 —— 个人学习笔记20(转置卷积、全卷积网络)

声明

  本文章为个人学习使用,版面观感若有不适请谅解,文中知识仅代表个人观点,若出现错误,欢迎各位批评指正。

三十九、转置卷积

import torch
from torch import nndef trans_conv(X, K):h, w = K.shapeY = torch.zeros((X.shape[0] + h - 1, X.shape[1] + w - 1))for i in range(X.shape[0]):for j in range(X.shape[1]):Y[i: i + h, j: j + w] += X[i, j] * Kreturn YX = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
K = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
print(f'基本的二维转置卷积运算 : {trans_conv(X, K)}')X, K = X.reshape(1, 1, 2, 2), K.reshape(1, 1, 2, 2)
tconv = nn.ConvTranspose2d(1, 1, kernel_size=2, bias=False)
tconv.weight.data = K
print(f'输入输出都是四维张量时 : {tconv(X)}')# 在转置卷积中,填充被应用于输出(常规卷积将填充应用于输入)。
# 当将高和宽两侧的填充数指定为 1 时,转置卷积的输出中将删除第一和最后的行与列。
tconv = nn.ConvTranspose2d(1, 1, kernel_size=2, padding=1, bias=False)
tconv.weight.data = K
print(f'padding=1 时 : {tconv(X)}')# 在转置卷积中,步幅被指定为中间结果(输出),而不是输入。
tconv = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2, bias=False)
tconv.weight.data = K
print(f'stride=2 时 : {tconv(X)}')X = torch.rand(size=(1, 10, 16, 16))
conv = nn.Conv2d(10, 20, kernel_size=5, padding=2, stride=3)
tconv = nn.ConvTranspose2d(20, 10, kernel_size=5, padding=2, stride=3)
print(f'先代入卷积,再代入转置卷积形状不变 : {tconv(conv(X)).shape == X.shape}')def corr2d(X, K):reduce_sum = lambda x, *args, **kwargs: x.sum(*args, **kwargs)h, w = K.shapeY = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))for i in range(Y.shape[0]):for j in range(Y.shape[1]):Y[i, j] = reduce_sum((X[i: i + h, j: j + w] * K))return YX = torch.arange(9.0).reshape(3, 3)
K = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
Y = corr2d(X, K)
print(f'二维卷积运算 : {Y}')def kernel2matrix(K):k, W = torch.zeros(5), torch.zeros((4, 9))k[:2], k[3:5] = K[0, :], K[1, :]W[0, :5], W[1, 1:6], W[2, 3:8], W[3, 4:] = k, k, k, kreturn WW = kernel2matrix(K)
print(f'稀疏权重矩阵 : {W}')print(f'使用矩阵乘法实现卷积 : {Y == torch.matmul(W, X.reshape(-1)).reshape(2, 2)}')Z = trans_conv(Y, K)
print(f'使用矩阵乘法实现转置卷积 : {Z == torch.matmul(W.T, Y.reshape(-1)).reshape(3, 3)}')

四十、全卷积网络( FCN )

import os
import torch
import time
import torchvision
from PIL import Image
from IPython import display
from torch import nn
import matplotlib.pyplot as plt
from torch.nn import functional as F
from matplotlib_inline import backend_inlinedef accuracy(y_hat, y):                                                           # 定义一个函数来为预测正确的数量计数"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == y                                                # bool 类型,若预测结果与实际结果一致,则为 Truereturn float(cmp.type(y.dtype).sum())def evaluate_accuracy_gpu(net, data_iter, device=None):"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):axes.set_xlabel(xlabel), axes.set_ylabel(ylabel)axes.set_xscale(xscale), axes.set_yscale(yscale)axes.set_xlim(xlim),     axes.set_ylim(ylim)if legend:axes.legend(legend)axes.grid()class Accumulator:                                                                # 定义一个实用程序类 Accumulator,用于对多个变量进行累加"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]class Animator:                                                                   # 定义一个在动画中绘制数据的实用程序类 Animator"""在动画中绘制数据"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):# 增量地绘制多条线if legend is None:legend = []backend_inline.set_matplotlib_formats('svg')self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# 使用lambda函数捕获参数self.config_axes = lambda: set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef add(self, x, y):# Add multiple data points into the figureif not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fmt)self.config_axes()display.display(self.fig)# 通过以下两行代码实现了在PyCharm中显示动图# plt.draw()# plt.pause(interval=0.001)display.clear_output(wait=True)plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']class Timer:def __init__(self):self.times = []self.start()def start(self):self.tik = time.time()def stop(self):self.times.append(time.time() - self.tik)return self.times[-1]def sum(self):"""Return the sum of time."""return sum(self.times)def read_voc_images(voc_dir, is_train=True):txt_fname = os.path.join(voc_dir, 'ImageSets', 'Segmentation','train.txt' if is_train else 'val.txt')mode = torchvision.io.image.ImageReadMode.RGBwith open(txt_fname, 'r') as f:images = f.read().split()features, labels = [], []for i, fname in enumerate(images):features.append(torchvision.io.read_image(os.path.join(voc_dir, 'JPEGImages', f'{fname}.jpg')))labels.append(torchvision.io.read_image(os.path.join(voc_dir, 'SegmentationClass' ,f'{fname}.png'), mode))return features, labelsVOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],[0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],[64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],[64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],[0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],[0, 64, 128]]def voc_colormap2label():colormap2label = torch.zeros(256 ** 3, dtype=torch.long)for i, colormap in enumerate(VOC_COLORMAP):colormap2label[(colormap[0] * 256 + colormap[1]) * 256 + colormap[2]] = ireturn colormap2labeldef voc_rand_crop(feature, label, height, width):rect = torchvision.transforms.RandomCrop.get_params(feature, (height, width))feature = torchvision.transforms.functional.crop(feature, *rect)label = torchvision.transforms.functional.crop(label, *rect)return feature, labeldef voc_label_indices(colormap, colormap2label):colormap = colormap.permute(1, 2, 0).numpy().astype('int32')idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256+ colormap[:, :, 2])return colormap2label[idx]class VOCSegDataset(torch.utils.data.Dataset):def __init__(self, is_train, crop_size, voc_dir):self.transform = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])self.crop_size = crop_sizefeatures, labels = read_voc_images(voc_dir, is_train=is_train)self.features = [self.normalize_image(feature)for feature in self.filter(features)]self.labels = self.filter(labels)self.colormap2label = voc_colormap2label()if is_train:print('train : read ' + str(len(self.features)) + ' examples')else:print('validation : read ' + str(len(self.features)) + ' examples')def normalize_image(self, img):return self.transform(img.float() / 255)def filter(self, imgs):return [img for img in imgs if (img.shape[1] >= self.crop_size[0] andimg.shape[2] >= self.crop_size[1])]def __getitem__(self, idx):feature, label = voc_rand_crop(self.features[idx], self.labels[idx],*self.crop_size)return (feature, voc_label_indices(label, self.colormap2label))def __len__(self):return len(self.features)def set_figsize(figsize=(8.5, 6.5)):backend_inline.set_matplotlib_formats('svg')plt.rcParams['figure.figsize'] = figsizeplt.rcParams['font.sans-serif'] = ['Microsoft YaHei']def try_all_gpus():return [torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())]def load_data_voc(batch_size, crop_size):voc_dir = 'E:\\dayLily'train_iter = torch.utils.data.DataLoader(VOCSegDataset(True, crop_size, voc_dir), batch_size,shuffle=True, drop_last=True)test_iter = torch.utils.data.DataLoader(VOCSegDataset(False, crop_size, voc_dir), batch_size,drop_last=True)return train_iter, test_iterdef show_images(imgs, num_rows, num_cols, suptitle=None, titles=None, scale=1.5):numpy = lambda x, *args, **kwargs: x.detach().numpy(*args, **kwargs)figsize = (num_cols * scale, num_rows * scale)_, axes = plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):try:img = numpy(img)except:passax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])elif suptitle:plt.suptitle(suptitle)plt.show()return axespretrained_net = torchvision.models.resnet18(pretrained=True)
print(list(pretrained_net.children())[-3:])net = nn.Sequential(*list(pretrained_net.children())[:-2])X = torch.rand(size=(1, 3, 320, 480))
print(f'net 的前向传播将输入的高和宽减小至原来的 1/32 : {net(X).shape}')num_classes = 21
net.add_module('final_conv', nn.Conv2d(512, num_classes, kernel_size=1))
net.add_module('transpose_conv', nn.ConvTranspose2d(num_classes, num_classes,kernel_size=64, padding=16, stride=32))def bilinear_kernel(in_channels, out_channels, kernel_size):factor = (kernel_size + 1) // 2if kernel_size % 2 == 1:center = factor - 1else:center = factor - 0.5og = (torch.arange(kernel_size).reshape(-1, 1),torch.arange(kernel_size).reshape(1, -1))filt = (1 - torch.abs(og[0] - center) / factor) * \(1 - torch.abs(og[1] - center) / factor)weight = torch.zeros((in_channels, out_channels,kernel_size, kernel_size))weight[range(in_channels), range(out_channels), :, :] = filtreturn weightconv_trans = nn.ConvTranspose2d(3, 3, kernel_size=4, padding=1, stride=2,bias=False)
conv_trans.weight.data.copy_(bilinear_kernel(3, 3, 4))img = torchvision.transforms.ToTensor()(Image.open('E:\\dayLily\\JPEGImages\\2024_959.jpg'))
X = img.unsqueeze(0)
Y = conv_trans(X)
out_img = Y[0].permute(1, 2, 0).detach()set_figsize()
print('input image shape:', img.permute(1, 2, 0).shape)
plt.imshow(img.permute(1, 2, 0))
print('output image shape:', out_img.shape)
plt.imshow(out_img)
plt.title('转置卷积层将图像的高和宽分别放大了 2 倍')
plt.show()W = bilinear_kernel(num_classes, num_classes, 64)
net.transpose_conv.weight.data.copy_(W)batch_size, crop_size = 32, (320, 480)
train_iter, test_iter = load_data_voc(batch_size, crop_size)def loss(inputs, targets):return F.cross_entropy(inputs, targets, reduction='none').mean(1).mean(1)def train_batch(net, X, y, loss, trainer, devices):if isinstance(X, list):X = [x.to(devices[0]) for x in X]else:X = X.to(devices[0])y = y.to(devices[0])net.train()trainer.zero_grad()pred = net(X)l = loss(pred, y)l.sum().backward()trainer.step()train_loss_sum = l.sum()train_acc_sum = accuracy(pred, y)return train_loss_sum, train_acc_sumdef train(net, train_iter, test_iter, loss, trainer, num_epochs,devices=try_all_gpus()):timer, num_batches = Timer(), len(train_iter)animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],legend=['train loss', 'train acc', 'test acc'])net = nn.DataParallel(net, device_ids=devices).to(devices[0])for epoch in range(num_epochs):# Sum of training loss, sum of training accuracy, no. of examples,# no. of predictionsmetric = Accumulator(4)for i, (features, labels) in enumerate(train_iter):timer.start()l, acc = train_batch(net, features, labels, loss, trainer, devices)metric.add(l, acc, labels.shape[0], labels.numel())timer.stop()if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(metric[0] / metric[2], metric[1] / metric[3],None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))plt.title(f'loss {metric[0] / metric[2]:.3f}, train acc 'f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}\n'f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on 'f'{str(torch.cuda.get_device_name())}')plt.show()num_epochs, lr, wd, devices = 5, 0.001, 1e-3, try_all_gpus()
trainer = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=wd)
train(net, train_iter, test_iter, loss, trainer, num_epochs, devices)##### 预测 #####
def predict(img):X = test_iter.dataset.normalize_image(img).unsqueeze(0)pred = net(X.to(devices[0])).argmax(dim=1)return pred.reshape(pred.shape[1], pred.shape[2])def label2image(pred):colormap = torch.tensor(VOC_COLORMAP, device=devices[0])X = pred.long()return colormap[X, :]voc_dir = 'E:\\dayLily'
test_images, test_labels = read_voc_images(voc_dir, False)
n, imgs = 4, []
for i in range(n):crop_rect = (0, 0, 500, 500)X = torchvision.transforms.functional.crop(test_images[i], *crop_rect)pred = label2image(predict(X))imgs += [X.permute(1,2,0), pred.cpu(),torchvision.transforms.functional.crop(test_labels[i], *crop_rect).permute(1,2,0)]
show_images(imgs[::3] + imgs[1::3] + imgs[2::3], 3, n, scale=2, suptitle='第一行为原图,第二行为预测结果,第三行为真实结果')




  文中部分知识参考:B 站 —— 跟李沐学AI;百度百科

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 【大数据】6:MapReduce YARN 初体验
  • DAMA学习笔记(十五)-数据管理组织与角色期望
  • 模拟三层--控制层、业务层和数据访问层
  • 抓包分析排查利器TCPdump
  • Qt读写sysfs
  • 8月13日学习笔记 LVS
  • 代码随想录算法训练营day42|动态规划part09
  • 【中等】 猿人学web第一届 第5题 js混淆-乱码增强
  • HAProxy原理及实例
  • 51单片机学习记录-数码管操作
  • Unity 流光shader的思路
  • 开源模型应用落地-LangChain高阶-记忆组件-RedisChatMessageHistory正确使用(八)
  • java注解(实现原理及自定义注解)
  • Flask获取请求信息
  • Stable Diffusion绘画 | ControlNet应用-NormalMap(法线贴图)
  • 2018一半小结一波
  • css布局,左右固定中间自适应实现
  • Dubbo 整合 Pinpoint 做分布式服务请求跟踪
  • isset在php5.6-和php7.0+的一些差异
  • JavaScript 一些 DOM 的知识点
  • JS基础篇--通过JS生成由字母与数字组合的随机字符串
  • Mac转Windows的拯救指南
  • NLPIR语义挖掘平台推动行业大数据应用服务
  • React-redux的原理以及使用
  • ubuntu 下nginx安装 并支持https协议
  • vue从创建到完整的饿了么(11)组件的使用(svg图标及watch的简单使用)
  • Web设计流程优化:网页效果图设计新思路
  • 如何借助 NoSQL 提高 JPA 应用性能
  • 时间复杂度与空间复杂度分析
  • 小程序01:wepy框架整合iview webapp UI
  • 用jquery写贪吃蛇
  • Semaphore
  • #{}和${}的区别是什么 -- java面试
  • #NOIP 2014#Day.2 T3 解方程
  • (2024最新)CentOS 7上在线安装MySQL 5.7|喂饭级教程
  • (done) ROC曲线 和 AUC值 分别是什么?
  • (力扣题库)跳跃游戏II(c++)
  • (五)activiti-modeler 编辑器初步优化
  • (五)c52学习之旅-静态数码管
  • (源码版)2024美国大学生数学建模E题财产保险的可持续模型详解思路+具体代码季节性时序预测SARIMA天气预测建模
  • (转)Linux下编译安装log4cxx
  • .[hudsonL@cock.li].mkp勒索病毒数据怎么处理|数据解密恢复
  • .equals()到底是什么意思?
  • .htaccess 强制https 单独排除某个目录
  • .java 指数平滑_转载:二次指数平滑法求预测值的Java代码
  • .locked1、locked勒索病毒解密方法|勒索病毒解决|勒索病毒恢复|数据库修复
  • .NET Core WebAPI中使用Log4net 日志级别分类并记录到数据库
  • .NET Core 成都线下面基会拉开序幕
  • .net FrameWork简介,数组,枚举
  • .net 程序 换成 java,NET程序员如何转行为J2EE之java基础上(9)
  • .NET/C# 利用 Walterlv.WeakEvents 高性能地定义和使用弱事件
  • .NET/MSBuild 中的发布路径在哪里呢?如何在扩展编译的时候修改发布路径中的文件呢?
  • .netcore如何运行环境安装到Linux服务器
  • .NET精简框架的“无法找到资源程序集”异常释疑
  • @EnableWebMvc介绍和使用详细demo