当前位置: 首页 > news >正文

Linear Regression with PyTorch

Linear Regression with PyTorch

Problem Description

初始化一组数据 \((x,y)\),使其满足这样的线性关系 \(y = w x + b\) 。然后基于反向传播法,用均方误差(mean squared error)
\[ MSE = \frac{1}{n} \sum_{n} (y- \hat y)^{2} \]

去拟合这组数据。

衡量两个分布之间的距离,最直接的方法是用交叉熵。

我们用最简单的一元变量去拟合这组数据,其实一元线性回归的表达式 \(y = wx + b\) 用神经网络的形式可表示成如下图所示

528745-20180611202649179-916805629.png

该神经网络有一个输入、一个输出、不使用任何激活函数。这就是一元线性回归的神经网络表示结果。相比较于下图这种神经网络的形式化表示,上图是一种简单的特例。

528745-20180612131036870-1998475050.png

Key Points

torch.unsqueeze

重塑一个张量的 size,见下面代码

>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1,  2,  3,  4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
        [ 2],
        [ 3],
        [ 4]])

torch.linspace

得到一个在 start 和 end 之间等距的一维张量,见下面代码

>>> torch.linspace(1, 6, steps=3)
tensor([ 1.0000,  3.5000,  6.0000])

torch.rand

返回一个满足 size 维度要求的随机数组,随机数服从0-1均匀分布。

torch.nn.Linear(1,1)

self.prediction = torch.nn.Linear(1, 1)

这一行代码,实际是维护了两个变量,其描述了这样的一种关系:

\[prediction_{1\times1} = weight_{1\times1} \times input_{1\times1} + bias_{1\times1}\]

其中,每个参数都是 \(1\times1\) 维的。

Code

import torch


epoch = 10000
lr = 0.01
w = 10
b = 5

x = torch.unsqueeze(torch.linspace(1, 10, 20), 1)
y = w*x + b + torch.rand(x.size())


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.prediction = torch.nn.Linear(1, 1)

    def forward(self, x):
        out = self.prediction(x)
        return out


net = Net()
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
criticism = torch.nn.MSELoss()


for i in range(epoch):
    y_pred = net(x)
    loss = criticism(y_pred, y)  # 先是 y_pred 然后是 y_true 参数顺序不能乱

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print("%.5f" % loss.data)
print(net.state_dict()['prediction.weight'])
print(net.state_dict()['prediction.bias'])

输出:

0.08882
tensor([[ 9.9713]])
tensor([ 5.6524])

Results Analysis

输出显示:

  1. 均方误差(MSE)为 0.0882
  2. \(weight\) 的拟合结果为 9.9713
  3. \(bias\) 的拟合结果为 5.6524

分析:

  1. 因为我主动引入了误差(服从0-1均匀分布),而且是线性拟合,所以 MSE 几乎不能减小到零;
  2. 9.9713 的拟合值已经非常接近真实值 10 了;5.6524 的拟合值较真实值 5 的距离较大(距离约为自身的 10%)

转载于:https://www.cnblogs.com/fengyubo/p/9164970.html

相关文章:

  • JavaScript -- 45 经典技巧以及注意点
  • mybatits
  • win7x64下的redis安装与使用
  • 【Revit API】FamilyInstance、FamilySymbol、Family的寻找关系
  • 「日常训练」 Soldier and Number Game (CFR304D2D)
  • 变量的经典
  • 从cookies 获取token
  • python - Linux C调用Python 函数
  • IIS 7 应用程序池自动回收关闭的解决方案
  • FullScreenPopNavigationController
  • tp5多条件查询
  • 本地电脑与远程服务器之间不能复制粘贴解决方法
  • 八 原型prototype和__proto__
  • SQL存储过程解密
  • 数据库可视化工具简介以及pymysql的使用
  • 【JavaScript】通过闭包创建具有私有属性的实例对象
  • 【node学习】协程
  • iOS 颜色设置看我就够了
  • JavaScript 奇技淫巧
  • Javascript设计模式学习之Observer(观察者)模式
  • jQuery(一)
  • JS基础篇--通过JS生成由字母与数字组合的随机字符串
  • mongodb--安装和初步使用教程
  • Python学习之路16-使用API
  • spring + angular 实现导出excel
  • Web设计流程优化:网页效果图设计新思路
  • 从输入URL到页面加载发生了什么
  • 记录一下第一次使用npm
  • 微信小程序开发问题汇总
  • 想使用 MongoDB ,你应该了解这8个方面!
  • 小程序button引导用户授权
  • 异步
  • 正则表达式
  • ​第20课 在Android Native开发中加入新的C++类
  • (k8s中)docker netty OOM问题记录
  • (ZT)薛涌:谈贫说富
  • (ZT)一个美国文科博士的YardLife
  • (二)PySpark3:SparkSQL编程
  • (分布式缓存)Redis分片集群
  • (附源码)springboot码头作业管理系统 毕业设计 341654
  • (力扣记录)235. 二叉搜索树的最近公共祖先
  • (亲测)设​置​m​y​e​c​l​i​p​s​e​打​开​默​认​工​作​空​间...
  • (转)EOS中账户、钱包和密钥的关系
  • (转)Linux整合apache和tomcat构建Web服务器
  • (转)Scala的“=”符号简介
  • (轉貼) 寄發紅帖基本原則(教育部禮儀司頒布) (雜項)
  • .bat批处理(八):各种形式的变量%0、%i、%%i、var、%var%、!var!的含义和区别
  • .helper勒索病毒的最新威胁:如何恢复您的数据?
  • .NET Compact Framework 3.5 支持 WCF 的子集
  • .net core 3.0 linux,.NET Core 3.0 的新增功能
  • .NET Core 和 .NET Framework 中的 MEF2
  • .NET Core 通过 Ef Core 操作 Mysql
  • .NET处理HTTP请求
  • @Autowired 与@Resource的区别
  • @RequestMapping-占位符映射