当前位置: 首页 > news >正文

Pytorch搭建自定义神经网络

简介

在日常的实验中,为了改变网络模型,我们一般会自己搭建网络,在基础上进行修改创新,本文主要介绍如何利用Pytorch搭建自定义的神经网络

nn.Module类

在Pytorch中搭建自己的网络模型,首先要继承nn.Module,通过继承的方式定义自己的模型。

init函数

在Python中__init__函数就是构造函数,在__init__函数中我们一般会定义模型中的一些结构,比如卷积层,池化层,全连接层等。这里会用到一个函数nn.Sequential,这个函数的目的是为了拼接各种卷积,全连接等(当然你也可以把里面的卷积等结构拆开写)。

forward函数

forward函数是重写了nn.Module中的forward函数,目的在于训练/预测数据,forward函数中有个参数x表示输入的数据,如果__init__函数中定义的参数没有问题,那么在forward中直接调用就可以,最后返回输出的结果即可。
上面的文字理解或许不是很好理解,下面我们看一下利用Pytorch搭建的AlextNet网络模型。

import torch.nn as nn


class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()
        # 卷积层
        self.layers = nn.Sequential(
            # 第一层
            nn.Conv2d(3, 96, kernel_size=(3, 3)),
            nn.BatchNorm2d(96),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            # 第二层
            nn.Conv2d(96, 256, kernel_size=(3, 3)),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            # 第三层
            nn.Conv2d(256, 384, kernel_size=(3, 3), padding=1),
            nn.ReLU(True),

            # 第四层
            nn.Conv2d(384, 384, kernel_size=(3, 3), padding=1),
            nn.ReLU(True),

            # 第五层
            nn.Conv2d(384, 256, kernel_size=(3, 3), padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        # 全连接层
        self.fc = nn.Sequential(
            nn.Linear(1024, 2048),
            nn.Dropout(0.5),
            nn.Linear(2048, 2048),
            nn.Dropout(0.5),
            nn.Linear(2048, 10)
        )

    def forward(self, x):
        x = self.layers(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

分析

根据上面的介绍,首先继承nn.Module,然后编写__init__函数和forward函数。这里我详细的分析一下每行代码的意义。

  • nn.Conv2d:表示一个2D卷积层,nn.Conv2d(3, 96, kernel_size=(3, 3))表示这是一个2D卷积层,其中输入的是3通道,输出的是96通道,卷积核大小是(3,3)
  • nn.MaxPool2d:表示一个2D最大池化层,nn.MaxPool2d(kernel_size=3, stride=2)表示池化层大小为(3,3),步长为2(注意:池化层不改变输入的维度)
  • nn.ReLU:表示使用relu作为激活函数
  • nn.BatchNorm2d:表示批归一化处理
  • nn.Linear():全连接层,nn.Linear(2048, 10)表示输入为2048输出为10
  • nn.Dropout()nn.Dropout(0.5)表示舍弃50%的参数

利用nn.Sequential把这些结构进行整合,最后在forward函数中进行调用,返回处理后的值即可。因此利用Pytorch搭建网络可以分为以下几步:

  1. 定义网络名称继承nn.Module
  2. 在构造方法中定义需要的结构
  3. 在forward函数中调用,并返回处理后的值

预测

利用刚才建立的网络模型在cifar10数据集上进行测试,加载数据集的方式不再赘述,下面主要介绍一下如何训练+预测。

  1. 生成已经建立的网络模型并利用GPU加速:alexNet = AlexNet().to(device)
  2. 定义优化器optimize = torch.optim.Adam(alexNet.parameters(), lr=0.01)
  3. 定义损失函数loss_function = nn.CrossEntropyLoss()
  4. 定义迭代次数开始预测
  5. 反向传播+优化器优化

预测代码如下:

import os

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision import datasets
import torch.utils.data
from AlexNet import AlexNet

device = torch.device("cuda")

if __name__ == '__main__':
    download = False

    if 'res' not in os.listdir():
        download = True

    train_set = datasets.CIFAR10(root='./res/data', transform=transforms.ToTensor(), train=True, download=download)
    test_set = datasets.CIFAR10(root='./res/data', transform=transforms.ToTensor(), train=False, download=download)
    train_set = torch.utils.data.DataLoader(dataset=train_set, batch_size=64, shuffle=True, num_workers=4)

    test_set = torch.utils.data.DataLoader(dataset=test_set, batch_size=64, shuffle=True, num_workers=4)
    alexNet = AlexNet().to(device)

    optimize = torch.optim.Adam(alexNet.parameters(), lr=0.01)
    loss_function = nn.CrossEntropyLoss()
    epochs = 5
    for epoch in range(epochs):
        loss_sum = 0.0
        for i, data in enumerate(train_set):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimize.zero_grad()
            outputs = alexNet(inputs)
            loss_per = loss_function(outputs, labels)
            loss_per.backward()
            optimize.step()
            loss_sum += loss_per.item()

            if i % 100 == 99:
                print('[Epoch:%d, batch:%d] train loss: %.03f' % (epoch + 1, i + 1, loss_sum / 100))
                loss_sum = 0.0

        total, right = 0, 0
        for i, data in enumerate(test_set):
            test_inputs, test_labels = data
            test_inputs, test_labels = test_inputs.to(device), test_labels.to(device)
            test_outputs = alexNet(test_inputs)
            test_outputs = torch.max(test_outputs.data, 1)[1]
            total += test_outputs.size(0)
            right += (test_labels == test_outputs).sum()
        print("第{}轮的准确率为:{:.2f}%".format(epoch + 1, 100.0 * right.item() / total))


相关文章:

  • Python数据分析——基础数据结构
  • TestNG-常用注解介绍
  • STM32时钟系统和TIMER配置(溢出中断/PWM)实例
  • 随想录一刷Day04——链表
  • 【javaweb简单教程】2.JSP实现数据传递和保存(含四大作用域及简单示例)
  • 7.ROS2笔记-节点
  • 【C++】类和对象(下篇)(万字)
  • 【牛客 - 剑指offer】JZ67 把字符串转换成整数 Java实现
  • python采集火热弹幕数据并做词云图可视化分析
  • 【小程序从0到1】模版与配置|数据绑定|事件绑定
  • NetSuite SuiteQL Query Tool
  • 功能异常强大,推荐这款 Python 时序异常检测神器
  • 串的存储结构 --王道
  • React路由规则的定义、声明式导航、编程式导航
  • Java_四种内部类
  • ➹使用webpack配置多页面应用(MPA)
  • 2018一半小结一波
  • Android Studio:GIT提交项目到远程仓库
  • Django 博客开发教程 8 - 博客文章详情页
  • docker-consul
  • Docker容器管理
  • mysql innodb 索引使用指南
  • python学习笔记 - ThreadLocal
  • socket.io+express实现聊天室的思考(三)
  • ubuntu 下nginx安装 并支持https协议
  • VirtualBox 安装过程中出现 Running VMs found 错误的解决过程
  • vue 个人积累(使用工具,组件)
  • 基于Dubbo+ZooKeeper的分布式服务的实现
  • 微服务入门【系列视频课程】
  • 写代码的正确姿势
  • 一些css基础学习笔记
  • 走向全栈之MongoDB的使用
  • 白色的风信子
  • #HarmonyOS:基础语法
  • #ifdef 的技巧用法
  • #我与Java虚拟机的故事#连载07:我放弃了对JVM的进一步学习
  • (javascript)再说document.body.scrollTop的使用问题
  • (Matlab)遗传算法优化的BP神经网络实现回归预测
  • (补)B+树一些思想
  • (附源码)计算机毕业设计ssm基于Internet快递柜管理系统
  • (免费领源码)Python#MySQL图书馆管理系统071718-计算机毕业设计项目选题推荐
  • (四)Controller接口控制器详解(三)
  • (一)kafka实战——kafka源码编译启动
  • .NET CF命令行调试器MDbg入门(二) 设备模拟器
  • .NET delegate 委托 、 Event 事件,接口回调
  • .NET 简介:跨平台、开源、高性能的开发平台
  • .NET/C# 将一个命令行参数字符串转换为命令行参数数组 args
  • @RequestBody与@ModelAttribute
  • [].slice.call()将类数组转化为真正的数组
  • [Asp.net MVC]Asp.net MVC5系列——Razor语法
  • [BUUCTF]-Reverse:reverse3解析
  • [BZOJ]4817: [Sdoi2017]树点涂色
  • [C++]STL之map
  • [C语言][C++][时间复杂度详解分析]二分查找——杨氏矩阵查找数字详解!!!
  • [DNS网络] 网页无法打开、显示不全、加载卡顿缓慢 | 解决方案