当前位置: 首页 > news >正文

pytorch实战(2)-----回归例子

一、回归任务介绍:

拟合一个二元函数 y = x ^ 2.

二、步骤:

  1. 导入包
  2. 创建数据
  3. 构建网络
  4. 设置优化器和损失函数
  5. 前向和后向传播训练网络
  6. 画图

三、代码:

导入包:

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt

创建数据

#torch中的数据要是二维的,unsqueeze是将一维数据转化成二维数据
tmp = torch.linspace(-1,1,100)
x = torch.unsqueeze(tmp,dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size())

print(tmp)  #torch.Size([100])
print(x)  #torch.Size([100, 1])
#转成向量
x,y = Variable(x),Variable(y)

   查看数据图像:

plt.scatter(x.data.numpy(),y.data.numpy())
plt.show()

构建网络

#Net类继承了Module这个模块
class Net(torch.nn.Module):
    def __init__(self,n_feature,n_hidden,n_output):
        #在搭建模型之前需要继承的一些信息,super表示继承nn.Module的信息,此步骤必须有
        super(Net,self).__init__()
        self.hidden = torch.nn.Linear(n_feature,n_hidden)
        self.predict = torch.nn.Linear(n_hidden,n_output)
    #神经网络前向传递的一个过程,流程图
    def forward(self,x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x
net = Net(1,10,1)
plt.ion()
plt.show()
#可以看到搭建的图流程
print(net)
 打印的结果:
Net(
  (hidden): Linear(in_features=1, out_features=10, bias=True)
  (predict): Linear(in_features=10, out_features=1, bias=True)
)

 设置优化器和损失函数

optimizer = torch.optim.SGD(net.parameters(),lr = 0.5)  #传入网络的参数来优化它们
loss_func = torch.nn.MSELoss()

前向和后向传播训练网络

for t in range(100):
    
    #forward
    prediction = net(x)
    loss = loss_func(prediction,y)  #预测值pre在前,实际值y在后,不然结果会不一样
    
    #backward()
    optimizer.zero_grad()   #梯度全部设为0
    loss.backward()  #loss计算参数的梯度
    optimizer.step()  #采用优化器以lr=0.5来优化梯度
    
###########################以下为可视化过程##################################
    if t % 5 == 0:
        plt.cla()
        plt.scatter(x.data.numpy(),y.data.numpy())
        plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
        plt.text(0.5,0,'Loss=%.4f' % loss.data[0],fontdict={'size':20,'color':'red'})
        plt.pause(0.1)
plt.ioff()
plt.show()

训练结果:

第一次:

最后一次:

 

转载于:https://www.cnblogs.com/Lee-yl/p/9885011.html

相关文章:

  • 柔宇科技发售可折叠柔性屏手机 平板与手机从此二合一
  • 浅析宽带接入技术
  • 一个网工的Linux学习过程
  • 数据结构(算法)-图(最短距离Dijkstra)
  • Jessica Kerr:高绩效团队简史
  • Windows操作系统查看电脑开关机记录
  • 实现图元及属性的算法---椭圆生成算法
  • 大快搜索数据爬虫技术实例安装教学篇
  • 解决项目不编译4大clean
  • 迭代器 /生成器 yield
  • mysql表与表之间的关系
  • 对标汽车之家,新势力杉车网的另类崛起
  • RabbitMq集群搭建
  • vue-cli2使用cdn方式引入cytoscape
  • VS2015 提示 无法启动 IIS Express Web 服务器
  • [译]如何构建服务器端web组件,为何要构建?
  • 【翻译】babel对TC39装饰器草案的实现
  • 2018天猫双11|这就是阿里云!不止有新技术,更有温暖的社会力量
  • ECMAScript 6 学习之路 ( 四 ) String 字符串扩展
  • fetch 从初识到应用
  • Flannel解读
  • jquery ajax学习笔记
  • Laravel深入学习6 - 应用体系结构:解耦事件处理器
  • Linux Process Manage
  • maven工程打包jar以及java jar命令的classpath使用
  • Phpstorm怎样批量删除空行?
  • Python 基础起步 (十) 什么叫函数?
  • Quartz实现数据同步 | 从0开始构建SpringCloud微服务(3)
  • Webpack4 学习笔记 - 01:webpack的安装和简单配置
  • 从零到一:用Phaser.js写意地开发小游戏(Chapter 3 - 加载游戏资源)
  • 第十八天-企业应用架构模式-基本模式
  • 类orAPI - 收藏集 - 掘金
  • 利用DataURL技术在网页上显示图片
  • 面试题:给你个id,去拿到name,多叉树遍历
  • 如何抓住下一波零售风口?看RPA玩转零售自动化
  • 少走弯路,给Java 1~5 年程序员的建议
  • 新手搭建网站的主要流程
  • 优化 Vue 项目编译文件大小
  • No resource identifier found for attribute,RxJava之zip操作符
  • 宾利慕尚创始人典藏版国内首秀,2025年前实现全系车型电动化 | 2019上海车展 ...
  • 策略 : 一文教你成为人工智能(AI)领域专家
  • ​什么是bug?bug的源头在哪里?
  • #stm32整理(一)flash读写
  • (1)常见O(n^2)排序算法解析
  • (Forward) Music Player: From UI Proposal to Code
  • (安卓)跳转应用市场APP详情页的方式
  • (二) Windows 下 Sublime Text 3 安装离线插件 Anaconda
  • (附源码)springboot 房产中介系统 毕业设计 312341
  • (附源码)基于SSM多源异构数据关联技术构建智能校园-计算机毕设 64366
  • (三)docker:Dockerfile构建容器运行jar包
  • (四)TensorRT | 基于 GPU 端的 Python 推理
  • (转)http协议
  • ****** 二 ******、软设笔记【数据结构】-KMP算法、树、二叉树
  • *上位机的定义
  • .htaccess 强制https 单独排除某个目录