当前位置: 首页 > news >正文

用Pytorch实现一个线性回归

文章目录

  • 问题描述
  • 代码与注释
  • 总结
  • 参考资料

问题描述

假设学生在期末考试中,如果他们花x个小时在一门课程上,他们将得到y分。

x (hours)y (points)
12
24
36
4

问题是在这门课上花费4个小时时,得到的分数是多少?


很显然这是一个回归的问题。
下面结局这个问题的方法,是机器学习训练任务的基本方法论,可以基于此构建出更为复杂的模型去处理更为复杂的任务。


代码与注释

import torch

x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

class LinearModel(torch.nn.Module):	#将模型构建成一个类,并且继承自Moudle
    def __init__(self):	#构造函数
        super(LinearModel, self).__init__() #调用父类的构造,固定写法,这一步必须要有
        self.linear = torch.nn.Linear(1, 1) #Linear是一个类,类后面加括号意思是构建了一个对象,括号里面是的参数是权重和偏置
        
    def forward(self, x):	#前向传播函数
        y_pred = self.linear(x) #在一个对象的后面加括号,实现了一个可调用的对象,x送入Linear对象,执行w * x + b
        return y_pred
model = LinearModel()

# 损失函数,将向量里的损失进行求和,得到一个标量的损失值,MSELoss也继承自nn.Moudle
criterion = torch.nn.MSELoss(size_average=False)
# 优化器,不继承自Moudle,不会构建计算图,构建出的优化器知道要对哪些参数做优化,并且知道学习率
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 训练
for epoch in range(100):
    y_pred = model(x_data)	#先计算y_hat
    loss = criterion(y_pred, y_data) #计算损失
    print(epoch, loss) #loss是一个标量,一个对象,会自动调用__str__()函数,不会产生计算图,是安全的
    optimizer.zero_grad()	#梯度归0
    loss.backward()	#反向传播
    optimizer.step()	#step()用来做更新,根据预先设置的参数以及包含的梯度和学习率自动进行更新
    
# 输出 weight 和 bias
print('w = ', model.linear.weight.item()) #weight是一个矩阵,加上.item()让其只显示数值
print('b = ', model.linear.bias.item())

# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)

print('y_pred = ', y_test.data)

训练结果:
在这里插入图片描述
可以看到,通过不断的迭代,参数中的weight趋近于2,b趋近于0,最终的预测值趋近于8。
上面用的nn.Linear是一个线性模型,如下图所示:
在这里插入图片描述


总结

对于一个模型,都要将其构建成一个类,并且继承自Module,至少定义两个方法,一个是构造函数__init__(self),另一个是forward()。而backward会由Module里自动根据计算图计算。

对于整个训练搭建和训练的的步骤,总结为以下四步:

  1. Prepare dataset
  2. Design model using Class
  3. Construct loss and optimizer
  4. Training cycle

然后就是不断的前馈—反馈—更新----前馈—反馈—更新最后使Loss收敛。


参考资料

[1] https://www.bilibili.com/video/BV1Y7411d7Ys?p=5

相关文章:

  • 【C++】二叉搜索树set/map
  • 最短路径查找Dijkstra算法
  • [数字媒体] Photoshop基础之图像校正、抠图(证件照)和融合
  • 【毕业设计】基于的单片机的移动硬盘设计与实现 - stm32 嵌入式 物联网
  • 使用Python的requests库发送SOAP请求,错误码415
  • Python爬虫技术系列-02HTML解析-lxml+BS4
  • 今日头条——机器学习算法岗1234面
  • 【笔记】快速理解傅里叶级数
  • 宣布发布 .NET 7 Release Candidate 1
  • 8万Star,这个开源项目有点强
  • 数据批处理速度慢?不妨试试这个
  • 透过安全事件剖析黑客组织攻击技术(2FA/MA的攻击手法)
  • java毕业设计——基于Java+AI的五子棋游戏设计与实现(毕业论文+程序源码)——五子棋游戏
  • 29、Java 中的接口详解
  • mysql中怎么防止数据丢失
  • css选择器
  • MySQL数据库运维之数据恢复
  • OpenStack安装流程(juno版)- 添加网络服务(neutron)- controller节点
  • Rancher-k8s加速安装文档
  • Spring Cloud(3) - 服务治理: Spring Cloud Eureka
  • 从setTimeout-setInterval看JS线程
  • 浅析微信支付:申请退款、退款回调接口、查询退款
  • 让你的分享飞起来——极光推出社会化分享组件
  • 责任链模式的两种实现
  • #NOIP 2014#Day.2 T3 解方程
  • #QT(串口助手-界面)
  • (32位汇编 五)mov/add/sub/and/or/xor/not
  • (HAL)STM32F103C6T8——软件模拟I2C驱动0.96寸OLED屏幕
  • (二)linux使用docker容器运行mysql
  • (附源码)spring boot车辆管理系统 毕业设计 031034
  • (五)网络优化与超参数选择--九五小庞
  • (一)Mocha源码阅读: 项目结构及命令行启动
  • (原創) 如何使用ISO C++讀寫BMP圖檔? (C/C++) (Image Processing)
  • (转)Google的Objective-C编码规范
  • (转)iOS字体
  • (转)为C# Windows服务添加安装程序
  • ..thread“main“ com.fasterxml.jackson.databind.JsonMappingException: Jackson version is too old 2.3.1
  • .mysql secret在哪_MySQL如何使用索引
  • .net on S60 ---- Net60 1.1发布 支持VS2008以及新的特性
  • .NET 中使用 Mutex 进行跨越进程边界的同步
  • .NET实现之(自动更新)
  • .NET运行机制
  • @Controller和@RestController的区别?
  • @SuppressWarnings注解
  • @synthesize和@dynamic分别有什么作用?
  • [ solr入门 ] - 利用solrJ进行检索
  • [20190113]四校联考
  • [2021ICPC济南 L] Strange Series (Bell 数 多项式exp)
  • [ACM] hdu 1201 18岁生日
  • [Angular 基础] - 指令(directives)
  • [Angular] 笔记 16:模板驱动表单 - 选择框与选项
  • [APUE]进程关系(下)
  • [AUTOSAR][诊断管理][ECU][$37] 请求退出传输。终止数据传输的(上传/下载)
  • [Bzoj4722]由乃(线段树好题)(倍增处理模数小快速幂)
  • [C++]打开新世界的大门之C++入门