当前位置: 首页 > news >正文

pytorch实现单层线性回归模型

文章目录

    • 简述
      • 代码重构要点
    • 数学模型、运行结果
    • 数据构建与分批
    • 模型封装
    • 运行测试

简述

python使用 数值微分法 求梯度,实现单层线性回归-CSDN博客
python使用 计算图(forward与backward) 求梯度,实现单层线性回归-CSDN博客
数值微分求梯度、计算图求梯度,实现单层线性回归 模型速度差异及损失率比对-CSDN博客

上述文章都是使用python来实现求梯度的,是为了学习原理,实际使用上,pytorch实现了自动求导,原理也是(基于计算图的)链式求导,本文还就 “单层线性回归” 问题用pytorch实现。

代码重构要点

1.nn.Moudle

torch.nn.Module的继承、nn.Sequentialnn.Linear
torch.nn — PyTorch 2.4 documentation

对于nn.Sequential的理解可以看python使用 计算图(forward与backward) 求梯度,实现单层线性回归-CSDN博客一文代码的模型初始化与计算部分,如图:

在这里插入图片描述

nn.Sequential可以说是把图中标注的代码封装起来了,并且可以放多层。

2.torch.optim优化器

本例中使用随机梯度下降torch.optim.SGD()
torch.optim — PyTorch 2.4 documentation
SGD — PyTorch 2.4 documentation

3.数据构建与数据加载

data.TensorDatasetdata.DataLoader,之前为了实现数据分批,手动实现了data_iter,现在可以直接调用pytorch的data.DataLoader

对于data.DataLoader的参数num_workers,默认值为0,即在主线程中处理,但设置其它值时存在反而速度变慢的情况,以后再讨论。

数学模型、运行结果

y = X W + b y = XW + b y=XW+b

y为标量,X列数为2. 损失函数使用均方误差。

运行结果:

在这里插入图片描述

在这里插入图片描述

数据构建与分批

def build_data(weights, bias, num_examples):  x = torch.randn(num_examples, len(weights))  y = x.matmul(weights) + bias  # 给y加个噪声  y += torch.randn(1)  return x, y  def load_array(data_arrays, batch_size, num_workers=0, is_train=True):  """构造一个PyTorch数据迭代器"""  dataset = data.TensorDataset(*data_arrays)  return data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=is_train)

模型封装

class TorchLinearNet(torch.nn.Module):  def __init__(self):  super(TorchLinearNet, self).__init__()  model = nn.Sequential(Linear(in_features=2, out_features=1))  self.model = model  self.criterion = nn.MSELoss()  def predict(self, x):  return self.model(x)  def loss(self, y_predict, y):  return self.criterion(y_predict, y)

运行测试

if __name__ == '__main__':  start = time.perf_counter()  true_w1 = torch.rand(2, 1)  true_b1 = torch.rand(1)  x_train, y_train = build_data(true_w1, true_b1, 5000)  net = TorchLinearNet()  print(net)  init_loss = net.loss(net.predict(x_train), y_train)  loss_history = list()  loss_history.append(init_loss.item())  num_epochs = 3  batch_size = 50  learning_rate = 0.01  dataloader_workers = 6  data_loader = load_array((x_train, y_train), batch_size=batch_size, is_train=True)  optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)  for epoch in range(num_epochs):  # running_loss = 0.0  for x, y in data_loader:  y_pred = net.predict(x)  loss = net.loss(y_pred, y)  optimizer.zero_grad()  loss.backward()  optimizer.step()  # running_loss = running_loss + loss.item()  loss_history.append(loss.item())  end = time.perf_counter()  print(f"运行时间(不含绘图时间):{(end - start) * 1000}毫秒\n")  plt.title("pytorch实现单层线性回归模型", fontproperties="STSong")  plt.xlabel("epoch")  plt.ylabel("loss")  plt.plot(loss_history, linestyle='dotted')  plt.show()  print(f'初始损失值:{init_loss}')  print(f'最后一次损失值:{loss_history[-1]}\n')  print(f'正确参数: true_w1={true_w1}, true_b1={true_b1}')  print(f'预测参数:{net.model.state_dict()}')

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 探索工业互联网智能赋能智能制造算法综述
  • 工程数学线性代数(同济大学数学系)第六版(更新中)
  • 2024 年的 Node.js 生态系统
  • 输出倒闭输入
  • Android Studio报错 Failed to transform ‘...‘ using Jetifier. Reason null
  • 学生阅读行为与图书预定平台的设计与实现(全网独一无二,24年最新定做)
  • Qt 字符串类应用
  • 代码随想录算法训练营第二天 | 滑动窗口 + 螺旋矩阵
  • React概念理解
  • 使用yolov5实现目标检测简单案例(测试图片)
  • MySQL- 覆盖索引
  • STM32初识
  • 3.类和对象(中)
  • 《Techporters架构搭建》-Day06 国际化
  • 怎样在 SQL 中对一个包含销售数据的表按照销售额进行降序排序?
  • [deviceone开发]-do_Webview的基本示例
  • 【399天】跃迁之路——程序员高效学习方法论探索系列(实验阶段156-2018.03.11)...
  • co.js - 让异步代码同步化
  • iBatis和MyBatis在使用ResultMap对应关系时的区别
  • Java 网络编程(2):UDP 的使用
  • JS创建对象模式及其对象原型链探究(一):Object模式
  • Laravel深入学习6 - 应用体系结构:解耦事件处理器
  • node 版本过低
  • nodejs实现webservice问题总结
  • Shadow DOM 内部构造及如何构建独立组件
  • socket.io+express实现聊天室的思考(三)
  • Spring-boot 启动时碰到的错误
  • uva 10370 Above Average
  • ViewService——一种保证客户端与服务端同步的方法
  • Web设计流程优化:网页效果图设计新思路
  • 百度地图API标注+时间轴组件
  • -- 查询加强-- 使用如何where子句进行筛选,% _ like的使用
  • 关于Java中分层中遇到的一些问题
  • 基于Vue2全家桶的移动端AppDEMO实现
  • 首页查询功能的一次实现过程
  • 网页视频流m3u8/ts视频下载
  • 要让cordova项目适配iphoneX + ios11.4,总共要几步?三步
  • 用Node EJS写一个爬虫脚本每天定时给心爱的她发一封暖心邮件
  • 树莓派用上kodexplorer也能玩成私有网盘
  • ![CDATA[ ]] 是什么东东
  • #QT(一种朴素的计算器实现方法)
  • (2)(2.4) TerraRanger Tower/Tower EVO(360度)
  • (C语言)求出1,2,5三个数不同个数组合为100的组合个数
  • (day 12)JavaScript学习笔记(数组3)
  • (delphi11最新学习资料) Object Pascal 学习笔记---第13章第6节 (嵌套的Finally代码块)
  • (Demo分享)利用原生JavaScript-随机数-实现做一个烟花案例
  • (NSDate) 时间 (time )比较
  • (solr系列:一)使用tomcat部署solr服务
  • (编程语言界的丐帮 C#).NET MD5 HASH 哈希 加密 与JAVA 互通
  • (二)十分简易快速 自己训练样本 opencv级联lbp分类器 车牌识别
  • (仿QQ聊天消息列表加载)wp7 listbox 列表项逐一加载的一种实现方式,以及加入渐显动画...
  • (附源码)springboot金融新闻信息服务系统 毕业设计651450
  • (黑客游戏)HackTheGame1.21 过关攻略
  • (一)WLAN定义和基本架构转
  • (正则)提取页面里的img标签