当前位置: 首页 > news >正文

pytorch学习之pytorch构建模型的流程

PyTorch 是一个基于 Python 的科学计算库,提供了两个主要特征:第一,它是一个 GPU 加速的张量计算库,提供类似于 NumPy 的操作接口,可以在 GPU 上进行加速计算;第二,它是一个自动微分系统,可以用于深度学习模型的开发和训练。

PyTorch 的主要模块包括:

  1. torch:包含了张量数据类型、数学运算以及用于构建神经网络的函数等等。
  2. torch.nn:包含了定义神经网络层、损失函数、优化器等等的类和函数。
  3. torch.autograd:实现了自动微分功能,用于计算梯度。 torch.optim:包含了定义优化器的类和函数。
  4. torch.utils.data:用于处理数据集和数据加载的工具类和函数。
  5. torchvision:提供了常见的计算机视觉数据集、模型架构、预训练模型等等。

下面简单介绍一下这些模块的主要功能和使用方法:

1.torch:这个模块包含了很多操作张量的函数,例如张量的创建、数学运算、转换、切片等等。可以将其看作是 NumPy 的一个扩展,但是支持 GPU 加速,也支持自动微分。
2.torch.nn:这个模块提供了很多用于定义神经网络的类和函数,包括了各种不同类型的层、激活函数、损失函数等等。用户可以使用这些类和函数来构建自己的神经网络。
3.torch.autograd:这个模块实现了自动微分功能,用于计算梯度。用户只需要将神经网络中的变量设置为可求导的,PyTorch 就可以自动地计算出其梯度。在计算图中,这些变量被称为叶子节点。
4.torch.optim:这个模块包含了各种不同类型的优化器,例如随机梯度下降(SGD)、Adam、Adagrad 等等。用户可以使用这些优化器来更新神经网络的参数。
5.torch.utils.data:这个模块提供了各种用于处理数据集和数据加载的工具类和函数。例如 DataLoader 类可以用于批量加载数据,Dataset 类可以用于处理自定义数据集,Transforms 类可以用于数据增强等等。
6.torchvision:这个模块提供了常见的计算机视觉数据集、模型架构、预训练模型等等。例如可以使用其中的 ImageFolder 类来加载图像数据集,也可以使用其中的 ResNet 类来构建一个 ResNet 神经网络。

  1. 创建一个张量:
import torch

x = torch.tensor([1, 2, 3])

  1. 定义一个简单的神经网络:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x

net = Net()

  1. 训练模型:
import torch.optim as optim

criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

for epoch in range(100):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('Epoch %d loss: %.3f' % (epoch + 1, running_loss / len(trainloader)))

  1. 保存和加载模型:
# 保存模型
PATH = './my_model.pth'
torch.save(net.state_dict(), PATH)

# 加载模型
net = Net()
net.load_state_dict(torch.load(PATH))

这只是PyTorch的一小部分功能,它还包括了很多其他特性,例如数据加载器、自动微分、分布式训练等。如果你需要更多关于PyTorch的帮助,可以查阅官方文档:https://pytorch.org/docs/stable/index.html。

相关文章:

  • react-swipeable-views轮播图实现下方的切换点控制组件
  • Java线程知识点总结
  • Android Compose——一个简单的Bilibili APP
  • 世界顶级五大女程序媛,不仅技术强还都是美女
  • 2023年再不会Redis,就要被淘汰了
  • 【学习笔记】深入理解JVM之垃圾回收机制
  • 【数据结构】链式二叉树
  • 自学大数据第三天~终于轮到hadoop了
  • 应用层协议 HTTP HTTPS
  • Linux内核学习笔记——页表的那些事。
  • 一文带你入门,领略angular风采(上)!!!
  • 嵌入式学习笔记——STM32硬件基础知识
  • 2023年“网络安全”赛项浙江省金华市选拔赛 任务书
  • Qt安装与使用经验分享;无.pro文件;无QTextCodec file;Qt小试;界面居中;无缝;更换Qt图标;更换Qt标题。
  • MyBatis常用的俩种分页方式
  • .pyc 想到的一些问题
  • 【140天】尚学堂高淇Java300集视频精华笔记(86-87)
  • js写一个简单的选项卡
  • JS字符串转数字方法总结
  • NLPIR语义挖掘平台推动行业大数据应用服务
  • Redash本地开发环境搭建
  • vue.js框架原理浅析
  • vue-router的history模式发布配置
  • 关键词挖掘技术哪家强(一)基于node.js技术开发一个关键字查询工具
  • 模仿 Go Sort 排序接口实现的自定义排序
  • 译自由幺半群
  • #define 用法
  • (01)ORB-SLAM2源码无死角解析-(66) BA优化(g2o)→闭环线程:Optimizer::GlobalBundleAdjustemnt→全局优化
  • (arch)linux 转换文件编码格式
  • (C语言)求出1,2,5三个数不同个数组合为100的组合个数
  • (层次遍历)104. 二叉树的最大深度
  • (附源码)php投票系统 毕业设计 121500
  • (附源码)springboot高校宿舍交电费系统 毕业设计031552
  • (十一)图像的罗伯特梯度锐化
  • (转)EOS中账户、钱包和密钥的关系
  • (转)nsfocus-绿盟科技笔试题目
  • .helper勒索病毒的最新威胁:如何恢复您的数据?
  • .Net Web项目创建比较不错的参考文章
  • .net 中viewstate的原理和使用
  • .Net(C#)自定义WinForm控件之小结篇
  • .NET/MSBuild 中的发布路径在哪里呢?如何在扩展编译的时候修改发布路径中的文件呢?
  • .NET的微型Web框架 Nancy
  • .net流程开发平台的一些难点(1)
  • .NET微信公众号开发-2.0创建自定义菜单
  • .net项目IIS、VS 附加进程调试
  • .NET中统一的存储过程调用方法(收藏)
  • .project文件
  • [ SNOI 2013 ] Quare
  • [17]JAVAEE-HTTP协议
  • [8481302]博弈论 斯坦福game theory stanford week 1
  • [DAX] MAX函数 | MAXX函数
  • [ES-5.6.12] x-pack ssl
  • [IOI2018] werewolf 狼人
  • [jQuery]10 Things I Learned from the jQuery Source
  • [LeetCode] Binary Tree Preorder Traversal 二叉树的先序遍历