当前位置: 首页 > news >正文

PyTorch 基础学习(1) - 快速入门

系列文章:
PyTorch 基础学习(1) - 快速入门
PyTorch 基础学习(2)- 张量 Tensors
PyTorch 基础学习(3) - 张量的数学操作
PyTorch 基础学习(4)- 张量的类型
PyTorch 基础学习(5)- 神经网络

介绍

PyTorch学习,我们从AI的hello world程序(线性回归)开始,线性回归就通过一系列输入x和输出y,计算y = wx + b公式,w和b的合理值,训练后,就可以通过模型计算任意输入x,得到y值。

PyTorch 快速入门:线性回归

1. 环境准备
  • 安装 PyTorch

    • 访问 PyTorch 官网,选择适合你的操作系统和 Python 版本的安装命令。通常可以通过 pip 安装:
    pip install torch
    
2. 基本概念
  • 张量 (Tensor):

    • 类似于 NumPy 数组,但可以在 GPU 上运行以加速计算。
    • 是 PyTorch 中数据操作的基本单位。
    • 就是一维或多位数组,在线性回归中就是输入x数组和输出y数组
  • 自动微分 (Autograd):

    • 自动计算张量的梯度,这在训练机器学习模型时非常有用。
    • 就是通过对比结果调整w和b的值,逐渐趋于合理。
  • 神经网络 (Neural Networks):

    • 使用 torch.nn 模块来构建和训练模型。
    • 就是加载模型,torch集成了线性回顾模型
3. 线性回归模型实例

我们将通过一个简单的线性回归示例来介绍 PyTorch 的基本用法。目标是拟合一个线性关系 ( y = 2x )。

线性回归背景

线性回归是用于建模两个变量之间关系的基本统计技术。它假设两个变量之间的关系是线性的,可以用下面的公式表示:

[ y = wx + b ]

  • ( w ) 是权重(slope),决定了输入 ( x ) 如何影响输出 ( y )。
  • ( b ) 是偏置(intercept),是输出值在输入为 0 时的偏移量。
代码实现与详细注解
import torch
import torch.nn as nn
import torch.optim as optim# 数据
# 我们创建了一些简单的训练数据。x 是输入张量,y 是目标输出张量。
x = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float32)
y = torch.tensor([[2.0], [4.0], [6.0], [8.0]], dtype=torch.float32)# 模型
# 定义一个简单的线性模型,输入和输出都是一维的。
# nn.Linear(1, 1) 意味着我们有一个输入特征和一个输出特征。
model = nn.Linear(1, 1)# 损失和优化器
# 损失函数用于衡量模型预测值与实际值之间的差异。
# 这里使用均方误差损失(MSE),适合用于回归问题。
criterion = nn.MSELoss()# 优化器用于更新模型的参数以减少损失。
# 这里使用随机梯度下降(SGD)优化器,学习率设为 0.01。
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练
epochs = 1000  # 训练的轮数
for epoch in range(epochs):# 前向传播:通过模型计算预测值outputs = model(x)# 计算损失:比较预测值和实际值loss = criterion(outputs, y)# 反向传播和优化optimizer.zero_grad()  # 清除之前的梯度loss.backward()  # 反向传播计算梯度optimizer.step()  # 更新模型参数# 每 100 个 epoch 打印一次损失if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')# 测试
# 关闭梯度计算,以提高计算效率,因为我们不再需要计算梯度
with torch.no_grad():# 使用训练好的模型进行预测predicted = model(torch.tensor([[5.0]]))# 打印预测结果print(f'Predicted value for input 5.0: {predicted.item():.4f}')

输出:

......
Epoch [800/1000], Loss: 0.0006
Epoch [900/1000], Loss: 0.0003
Epoch [1000/1000], Loss: 0.0002
Predicted value for input 5.0: 9.9769

详细解读

  • 数据准备:

    • xy 是一组简单的线性数据对,表示 ( y = 2x ) 的关系。我们使用 PyTorch 的 tensor 创建张量数据。
  • 模型定义:

    • nn.Linear(1, 1) 定义了一个线性层,该层有一个输入节点和一个输出节点,表示简单的线性回归模型。
  • 损失函数和优化器:

    • 损失函数: nn.MSELoss() 用于计算模型预测值与真实目标值之间的均方误差。
    • 优化器: optim.SGD 是一种优化算法,用于更新模型参数。学习率 lr=0.01 决定了每次参数更新的步长。
  • 训练过程:

    • 通过一个循环进行多次迭代,每次迭代都包括前向传播、计算损失、反向传播和参数更新。
    • optimizer.zero_grad() 清除之前的梯度。
    • loss.backward() 计算当前损失相对于参数的梯度。
    • optimizer.step() 更新模型参数以最小化损失。
  • 测试模型:

    • torch.no_grad() 禁用梯度计算以提高计算效率,因为我们不需要更新参数。
    • 使用训练好的模型进行预测,并打印结果。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 从零开始搭建 LVS 高性能集群 (DR模式)
  • JAVA中的对象流ObjectInputStream
  • uniapp实现自定义弹窗组件,支持富文本传入内容
  • Linux:Linux环境基础开发工具使用
  • DIAdem 与 LabVIEW
  • 【数据结构篇】~顺序表
  • Golang | Leetcode Golang题解之第336题回文对
  • 分布式锁实现方案--redis、zookeeper、mysql
  • Java从zip文件中读取指定的csv文件使用EasyExcel解析出现流关闭异常Stream closed
  • 【常见算法题】斐波那契数列(矩阵快速幂)
  • 歌曲爬虫下载
  • java多线程(初阶)
  • 【Godot4自学手册】第四十五节用着色器(shader)制作水中效果
  • linux C语言remove函数及相关函数
  • 如何选择较为安全的第三方依赖版本?
  • [分享]iOS开发 - 实现UITableView Plain SectionView和table不停留一起滑动
  • 【每日笔记】【Go学习笔记】2019-01-10 codis proxy处理流程
  • go语言学习初探(一)
  • Idea+maven+scala构建包并在spark on yarn 运行
  • in typeof instanceof ===这些运算符有什么作用
  • Java 网络编程(2):UDP 的使用
  • Java,console输出实时的转向GUI textbox
  • Javascript 原型链
  • Java到底能干嘛?
  • Joomla 2.x, 3.x useful code cheatsheet
  • Lucene解析 - 基本概念
  • PHP 使用 Swoole - TaskWorker 实现异步操作 Mysql
  • php面试题 汇集2
  • V4L2视频输入框架概述
  • 百度贴吧爬虫node+vue baidu_tieba_crawler
  • 那些被忽略的 JavaScript 数组方法细节
  • 使用agvtool更改app version/build
  • 项目管理碎碎念系列之一:干系人管理
  • #pragma 指令
  • #数据结构 笔记三
  • (1)Hilt的基本概念和使用
  • (附源码)springboot社区居家养老互助服务管理平台 毕业设计 062027
  • (五)关系数据库标准语言SQL
  • (一)Neo4j下载安装以及初次使用
  • (转)拼包函数及网络封包的异常处理(含代码)
  • .NET Core/Framework 创建委托以大幅度提高反射调用的性能
  • .Net Winform开发笔记(一)
  • .NET运行机制
  • .php结尾的域名,【php】php正则截取url中域名后的内容
  • [AIGC 大数据基础]hive浅谈
  • [Angular 基础] - 数据绑定(databinding)
  • [Angular] 笔记 9:list/detail 页面以及@Output
  • [C#]winform利用seetaface6实现C#人脸检测活体检测口罩检测年龄预测性别判断眼睛状态检测
  • [Cloud Networking] Layer3 (Continue)
  • [COI2007] Sabor
  • [hadoop读书笔记] 第十五章 sqoop1.4.6小实验 - 将mysq数据导入HBASE
  • [HJ56 完全数计算]
  • [J2ME]url请求返回参数非法(java.lang.illegalArgument)
  • [Linux]如何理解kernel、shell、bash
  • [M二叉树] lc236. 二叉树的最近公共祖先(dfs+二叉搜索树)