当前位置: 首页 > news >正文

pytorch线性/非线性回归拟合

一、线性回归

1. 导入依赖库

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.autograd import Variable
  • numpy:用来构建数据
  • matplotlib.pyplot: 将构建好的数据可视化
  • torch.nn:包含了torch已经准备好的层,激活函数、全连接层等
  • torch.optim:提供了神经网络的一系列优化算法,如 SGD、Adam 等
  • torch.autograd:用来自动求导,计算梯度。其中Variable用来包装张量,使得张量能够支持自动求导,但在 PyTorch 0.4 及以后,已经被 Tensor 对象取代。

2. 构建数据

        首先确定一个线性函数,例如y_data = 0.1 * x_data + 0.2。然后在这条直线上加一些噪点,最后看神经网络是否能抵抗这些干扰点,拟合出正确的线性函数。

        只要做神经网络相关的数据处理,就一定要把数据转为张量(tensor)类型。然后想要实现梯度下降算法,就要把张量类型再转为Variable类型。

x_data = np.random.rand(100)
noise = np.random.normal(0, 0.01, x_data.shape)  # 构建正态分布噪点
y_data = x_data * 0.1 + 0.2 + noisex_data = x_data.reshape(-1, 1)  # 把原始数据更改形状,自动匹配任意行,1列
y_data = y_data.reshape(-1, 1)x_data = torch.FloatTensor(x_data)  # 把numpy类型转为tensor类型
y_data = torch.FloatTensor(y_data)
inputs = Variable(x_data)  # 变成variable类型才可以自动求导操作
target = Variable(y_data)

 3. 构建神经网络模型

        构建神经网络模型通常遵循一个相对固定的模板。这种模板不仅让代码结构清晰,还能利用 PyTorch 提供的模块化设计,使得网络的定义、训练、推理更加简洁。

        这里我们定义一个一对一的全连接层即可。使用MSE代价函数,SGD优化算法。

class LinearRegression(nn.Module):# 定义网络结构def __init__(self):super(LinearRegression, self).__init__()  # 固定写法,初始化父类self.fc = nn.Linear(1, 1)  # 定义一个全连接层,且一对一# 定义网络计算(前向传播)def forward(self, x):out = self.fc(x)  # 将输入传递给全连接层return outmodel = LinearRegression()  # 定义模型
mse_loss = nn.MSELoss()  # 使用均方差代价函数
optimizer = optim.SGD(model.parameters(), lr=0.1)  # 使用随机梯度下降法优化模型

4. 模型训练

         在模型训练上,几乎也是一个固定套路。之前写的,inputs和target即x_data和y_data的Variable类型。那么当模型(model)获得输入值(inputs),通过前向传播(forward)就会获得一个输出值(out)。然后通过MSE代价函数就能计算出损失(loss),最后经过计算梯度,优化权值,就完成了一轮训练。共训练1000次,期间可以每隔200次看一下损失值。通过输出结果可以看到loss值在一直变小,训练还不错!

for i in range(1001):out = model(inputs)loss = mse_loss(out, target)  # 计算损失optimizer.zero_grad()  # 梯度清0loss.backward()  # 计算梯度optimizer.step()  # 优化权值if i % 200 == 0:print('第{}次,loss值为:{}'.format(i, loss.item()))

        如果我们查看看最后拟合后的权重值(weight)和偏置值(bias),可以发现和我们之前设计好的的 y_data = 0.1 * x_data + 0.2 几乎非常吻合。

for name, param in model.named_parameters():print('name:{}\nparam:{}\n'.format(name, param))

5. 绘图查看结果

         首先利用scatter画出散点图,然后用plot绘出神经网络的拟合结果。

y_pred = model(inputs)
plt.scatter(x_data, y_data)
plt.plot(x_data, y_pred.data.numpy(), color='red')
plt.show()

二、非线性回归

         构建非线性回归时,思路和线性回归几乎一致,只需要把数据改为非线性数据,然后神经网络模型增加一个隐藏层即可。    

1. 构建非线性数据 

        首先事先设计一个非线性函数:y_data = x_data²,然后再加入一些噪点干扰神经网络。

x_data = np.linspace(-2, 2, 200)[:, np.newaxis]  # linspace(起始点,终止点,分割点总数),然后增加维度到(200, 1)
noise = np.random.normal(0, 0.2, x_data.shape)
y_data = np.square(x_data) + noise

2. 修改神经网络模型 

         一般情况下,只有隐藏层使用激活函数才可用来拟合非线性数据,如sigmoid、relu、tanh等。这里可以先确定10个隐藏神经元看效果如何。

class NonLinearRegression(nn.Module):# 定义网络结构def __init__(self):super(NonLinearRegression, self).__init__()  # 固定写法,初始化父类self.fc1 = nn.Linear(1, 10)  #   定义隐藏层,10个隐藏神经元self.tanh = nn.Tanh()  # 激活函数self.fc2 = nn.Linear(10, 1)# 定义网络计算(前向传播)def forward(self, x):x = self.fc1(x)x = self.tanh(x)x = self.fc2(x)return x

        如果想要较短时间的训练来获取一个相对较好的结果,可以尝试 Adam 自适应矩阵优化算法。虽然 Adam 算法可以自动调整学习率,但是一般默认初始值是0.001,最后训练情况不理想,所以这里设置为0.05的初始值。而且这个算法容易过拟合,需要正则化 weight_decay 来提高模型的泛化性。

        注意:这里的代价函数不可以修改为交叉熵(CrossEntropyLoss),因为交叉熵大多用于分类任务。

model = NonLinearRegression()
mse_loss = nn.MSELoss()  # 均方差代价函数
optimizer = optim.Adam(model.parameters(), lr=0.05, weight_decay=0.001)  # 设置L2正则化,防止过拟合

3. 查看拟合结果 

相关文章:

  • Leetcode 3302. Find the Lexicographically Smallest Valid Sequence
  • 数据库中的表添加uuid字段
  • spring 实用小技巧
  • 编程题 7-12 两个数的简单计算器【PAT】
  • Linux:磁盘管理
  • ps aux | grep smart_webrtc这条指令代表什么意思
  • SQLite3模块使用详解
  • 【Android 14源码分析】Activity启动流程-1
  • 大数据复习知识点5
  • linux服务器部署filebeat
  • [Everything] 文件搜索工具的下载及详细安装使用过程(附有下载文件)
  • Hadoop三大组件之HDFS(一)
  • 在树莓派上部署开源监控系统 ZoneMinder
  • 基于php的幸运舞蹈课程工作室管理系统
  • 黑名单与ip禁令是同一个东西吗
  • 「译」Node.js Streams 基础
  • Bytom交易说明(账户管理模式)
  • docker python 配置
  • iOS编译提示和导航提示
  • JavaScript标准库系列——Math对象和Date对象(二)
  • Java深入 - 深入理解Java集合
  • JS专题之继承
  • LeetCode18.四数之和 JavaScript
  • SQLServer之创建显式事务
  • SSH 免密登录
  • uva 10370 Above Average
  • 阿里云ubuntu14.04 Nginx反向代理Nodejs
  • 测试开发系类之接口自动化测试
  • - 概述 - 《设计模式(极简c++版)》
  • 基于Javascript, Springboot的管理系统报表查询页面代码设计
  • 用jquery写贪吃蛇
  • 这几个编码小技巧将令你 PHP 代码更加简洁
  • shell使用lftp连接ftp和sftp,并可以指定私钥
  • ‌内网穿透技术‌总结
  • #控制台大学课堂点名问题_课堂随机点名
  • (01)ORB-SLAM2源码无死角解析-(66) BA优化(g2o)→闭环线程:Optimizer::GlobalBundleAdjustemnt→全局优化
  • (1综述)从零开始的嵌入式图像图像处理(PI+QT+OpenCV)实战演练
  • (20)docke容器
  • (AtCoder Beginner Contest 340) -- F - S = 1 -- 题解
  • (C#)获取字符编码的类
  • (delphi11最新学习资料) Object Pascal 学习笔记---第5章第5节(delphi中的指针)
  • (done) 两个矩阵 “相似” 是什么意思?
  • (原创)可支持最大高度的NestedScrollView
  • (转载)Linux网络编程入门
  • .NET delegate 委托 、 Event 事件
  • .net6 core Worker Service项目,使用Exchange Web Services (EWS) 分页获取电子邮件收件箱列表,邮件信息字段
  • .ui文件相关
  • @RestControllerAdvice异常统一处理类失效原因
  • [ 隧道技术 ] 反弹shell的集中常见方式(四)python反弹shell
  • [10] CUDA程序性能的提升 与 流
  • [AIGC] 广度优先搜索(Breadth-First Search,BFS)详解
  • [Android]Tool-Systrace
  • [Android]通过PhoneLookup读取所有电话号码
  • [Android学习笔记]ScrollView的使用
  • [C# WPF] 如何给控件添加边框(Border)?