当前位置: 首页 > news >正文

【Python深度学习】深度学习框架搭建模版

深度学习框架搭建模版

文章目录

  • 深度学习框架搭建模版
  • 一、框架搭建四部曲
    • 1.导入包
    • 2.定义类和函数
    • 3.定义网络层
    • 4.实例化网络
  • 二、完整代码
  • 三、运行结果

一、框架搭建四部曲

1.导入包

首先是导入包因为使用的是pytorch框架所以倒入torch相关包,summary是可以获得自己搭建模型的参数、各层特征图大小、以及各层的参数所占内存的包作用效果如p2

安装方法:pip install torchsummary

'''
导入包
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

2.定义类和函数

定义类Class以及def super,这些是类的继承最基础的知识啦如果不懂原理就按模版记下即可;接着开始搭建层,这里采用nn.Sequential,相当于一个大容器可以放入任意量的网络层p1中放入一个卷积层;接着进入线性层依然使用nn.Sequential;

class Net(nn.Module):
    def __init__(self, num_classes=10):
        super(Net, self).__init__()
        self.fetures = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64,
                                                kernel_size=3, stride=1, padding=1))
        self.classify = nn.Sequential(nn.Linear(32 * 32 * 64, 20),
                                      nn.Linear(20, num_classes))

3.定义网络层

定义好网络层就可以定义层之间的计算过程啦首先进入卷积层接着需要将卷积层的形状从四维变成二维,在这里使用了view函数,接着传入线性层得到return

def forward(self, x):
        x = self.fetures(x)
        x = x.view(x.size(0), -1)
        x = self.classify(x)
        return x

4.实例化网络

实例化网络;假设输入大小为(10, 3, 32,32),将输入传入网络就得到输出结果的尺寸啦!其中10代表每一次输入的图像张数;3是通道数3, 32, 32为输入图片的宽高。调用summary检查网络结构,此时只需输入(3, 32, 32)即可因为summary中只需输入通道数以及宽高即可。

Modle = Net()
input = torch.ones([10, 3, 32, 32])
result = Modle(input)
print(result.shape)
summary(Modle.to("cuda"), (3, 32, 32))

二、完整代码

完整代码如下,如果运行后出现错误可以在评论区里写下你的看法和建议:

'''
Aouther:LiuZhenming
Time:2022-09-25
'''
# 导入包
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

# 定义类和函数
class Net(nn.Module):
    def __init__(self, num_classes=10):
        super(Net, self).__init__()
        self.fetures = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64,
                                                kernel_size=3, stride=1, padding=1))
        self.classify = nn.Sequential(nn.Linear(32 * 32 * 64, 20),
                                      nn.Linear(20, num_classes))
    # 定义网络层
    def forward(self, x):
        x = self.fetures(x)
        x = x.view(x.size(0), -1)
        x = self.classify(x)
        return x
# 实例化网络
Modle = Net()
input = torch.ones([10, 3, 32, 32])
result = Modle(input)
print(result.shape)
summary(Modle.to("cuda"), (3, 32, 32))

三、运行结果


~

torch .Size([10,10])

–––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
Layer (type) Output Shape Param #

==============================================================
Conv2d-1 [-1,64,32,32] 1,792
Linear-3 [-1,10] 210

==============================================================
Total params: 1, 312, 742
Trainable params: 1, 312, 742
Non-trainable params: 0

–––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
Input size (MB) : 0.01
Forward/ backward pass size (MB) : 0.50

在这里插入图片描述



THE END


这篇深度学习框架搭建模版到这里就结束了,大家如果喜欢就给个三连呗!




相关文章:

  • 双绞线连接网卡和集线器时的制作步骤
  • vue实战-mockjs模拟数据
  • 深度学习优化算法之动量法[公式推导](MXNet)
  • tomcat面试和Spring的面试题
  • 网课查题公众号接口
  • 基于Hive的搜狗搜索日志与结果Python可视化设计
  • vue+echarts项目四:地区销量趋势(堆叠折线图)
  • SpringSecurity实战-第6-8章
  • 前端 基础知识
  • 【极客时间2】左耳听风
  • 六级高频词汇——Group07
  • C++类和对象详解(中篇)
  • java五位随机验证码的实现。要求前四位是随机大小写的字母,最后一位是数字的组合。例如qWrY4
  • 《关于我摸鱼一天后搞定PyCharm这档事》Python环境配置
  • 公众号网课搜题系统
  • 【刷算法】求1+2+3+...+n
  • android百种动画侧滑库、步骤视图、TextView效果、社交、搜房、K线图等源码
  • leetcode388. Longest Absolute File Path
  • Nacos系列:Nacos的Java SDK使用
  • PHP CLI应用的调试原理
  • vue2.0项目引入element-ui
  • windows-nginx-https-本地配置
  • WinRAR存在严重的安全漏洞影响5亿用户
  • 代理模式
  • 来,膜拜下android roadmap,强大的执行力
  • 聊聊hikari连接池的leakDetectionThreshold
  • 面试题:给你个id,去拿到name,多叉树遍历
  • 入门到放弃node系列之Hello Word篇
  • 算法---两个栈实现一个队列
  • 我有几个粽子,和一个故事
  • 小程序滚动组件,左边导航栏与右边内容联动效果实现
  • 新书推荐|Windows黑客编程技术详解
  • 3月7日云栖精选夜读 | RSA 2019安全大会:企业资产管理成行业新风向标,云上安全占绝对优势 ...
  • 从如何停掉 Promise 链说起
  • 分布式关系型数据库服务 DRDS 支持显示的 Prepare 及逻辑库锁功能等多项能力 ...
  • 组复制官方翻译九、Group Replication Technical Details
  • ​ 无限可能性的探索:Amazon Lightsail轻量应用服务器引领数字化时代创新发展
  • ​html.parser --- 简单的 HTML 和 XHTML 解析器​
  • ​linux启动进程的方式
  • # 日期待t_最值得等的SUV奥迪Q9:空间比MPV还大,或搭4.0T,香
  • # 透过事物看本质的能力怎么培养?
  • #git 撤消对文件的更改
  • (MonoGame从入门到放弃-1) MonoGame环境搭建
  • (Redis使用系列) SpringBoot 中对应2.0.x版本的Redis配置 一
  • (附源码)springboot助农电商系统 毕业设计 081919
  • (附源码)ssm智慧社区管理系统 毕业设计 101635
  • (三)模仿学习-Action数据的模仿
  • (未解决)jmeter报错之“请在微信客户端打开链接”
  • (转)Mysql的优化设置
  • (转)微软牛津计划介绍——屌爆了的自然数据处理解决方案(人脸/语音识别,计算机视觉与语言理解)...
  • (轉貼) UML中文FAQ (OO) (UML)
  • ..回顾17,展望18
  • .htaccess配置常用技巧
  • .md即markdown文件的基本常用编写语法
  • .NET CORE 第一节 创建基本的 asp.net core