当前位置: 首页 > news >正文

从0开始深度学习(5)——线性回归的逐步实现

将从零开始实现整个方法, 包括数据流水线、模型、损失函数和小批量随机梯度下降优化器,但现代的深度学习框架几乎可以自动化地进行所有这些工作,但从零开始实现可以确保我们真正知道自己在做什么。

下一章会使用框架简洁的实现线性回归

# 提前导入的库
import random
import torch
import matplotlib.pyplot as plt

1 生成数据集

我们将根据带有噪声的线性模型构造一个人造数据集。
在这里插入图片描述

def synthetic_data(w, b, num_examples):  #@save"""生成y=Xw+b+噪声"""X = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(X, w) + by += torch.normal(0, 0.01, y.shape)return X, y.reshape((-1, 1))true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)# 生成1000个点# 绘制散点图,我们可以简单看下每个点趋近于哪条线
plt.scatter(features[:, 0].numpy(), labels.numpy(), 1.0)
plt.xlabel('Feature')
plt.ylabel('Label')
plt.title('Scatter Plot of Generated Data')
plt.show()

在这里插入图片描述

2 读取数据集

每次抽取一小批量样本,并使用它们来更新我们的模型。

def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))  # 创建一个包含所有样本索引的列表random.shuffle(indices)  # 随机打乱索引,以确保每次迭代时数据的顺序都是随机的# 遍历索引列表,步长为batch_sizefor i in range(0, num_examples, batch_size):# 根据当前的索引i和batch_size,计算出当前小批量的索引范围batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])# 使用当前小批量的索引从特征和标签中抽取对应的数据yield features[batch_indices], labels[batch_indices]# 设置小批量的大小
batch_size = 10for X, y in data_iter(batch_size, features, labels):# 打印第一个小批量的特征和标签print(X, '\n', y)break

第一个下批量的特征和标签
在这里插入图片描述
PS:在深度学习框架中实现的内置迭代器效率要高得多

3 定义模型

.matmul 函数是 PyTorch 中的一个方法,用于执行矩阵乘法。
定义一个简单的线性模型,即一个特征矩阵X和向量w进行矩阵-向量相乘后,再加上一个偏置参数b

def linreg(X, w, b):  #@save"""线性回归模型"""return torch.matmul(X, w) + b

4 定义损失函数

因为需要计算损失函数的梯度,所以我们应该先定义损失函数。这里使用均方误差(MSE)

def squared_loss(y_hat, y):  #@save"""均方损失"""return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

5 定义优化算法

线性模型有解析解,但是为了模拟其他没有解析解的模型,这里使用梯度下降,即在每一步中,使用从数据集中随机抽取的一个小批量,然后根据参数计算损失的梯度,接下来,朝着减少损失的方向更新我们的参数。每一步更新的大小由学习速率lr决定。

def sgd(params, lr, batch_size):  #@save"""小批量随机梯度下降"""with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()

6 初始化模型参数

初始化参数

import torch# 初始化权重w,使用正态分布,均值为0,标准差为0.01,形状为(2, 1)
w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
# 初始化偏置b,值为0,形状为(1,),这里使用reshape或者直接创建一个标量
b = torch.zeros(1, requires_grad=True)

7 训练

# 设置模型参数
lr = 0.03 # 学习率
num_epochs = 3 # 迭代周期(迭代几次)
net = linreg # 线性回归模型
loss = squared_loss # 平方损失函数for epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels): # 计算当前小批量的损失l = loss(net(X, w, b), y)  # 使用net函数计算预测值,然后计算与真实值y的损失l.sum().backward() # 将损失相加(因为损失是按小批量计算的),然后对所有参数求梯度sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数with torch.no_grad():train_l = loss(net(features, w, b), labels)print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}') # 打印每个epoch的损失

运行结果
在这里插入图片描述
比较真实参数和通过训练学到的参数来评估训练的成功程度。

print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差: {true_b - b}')

在这里插入图片描述

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 二进制方式安装K8S
  • 【Python】BeautifulSoup:HTML解析
  • H264结构及RTP封装
  • SQLite3 数据库
  • Linux中全局变量配置,/etc/profile.d还是/etc/profile
  • 数据结构(15)——哈希表(2)
  • C#从入门到精通(22)—Path类的使用
  • 2024 高教社杯 数学建模国赛 (C题)深度剖析|农作物的种植策略|数学建模完整代码+建模过程全解全析
  • 【项目一】基于pytest的自动化测试框架day1
  • CRE6959AM70V055S 超低待机功耗反激式开关电源芯片
  • CSS解析:盒模型
  • linux~~目录结构远程登录教程(xshell+xftp)
  • 鼠标控制dom元素的大小。采用ResizeObserver——监听元素大小的变化
  • HarmonyOS开发实战( Beta5版)合理使用动画丢帧规范实践
  • SpringBoot+Vue实现大文件上传(断点续传-后端控制(一))
  • Android单元测试 - 几个重要问题
  • Effective Java 笔记(一)
  • gitlab-ci配置详解(一)
  • Java 多线程编程之:notify 和 wait 用法
  • Javascript弹出层-初探
  • php ci框架整合银盛支付
  • PHP CLI应用的调试原理
  • QQ浏览器x5内核的兼容性问题
  • Sublime text 3 3103 注册码
  • thinkphp5.1 easywechat4 微信第三方开放平台
  • 从零开始学习部署
  • 大型网站性能监测、分析与优化常见问题QA
  • 高度不固定时垂直居中
  • 如何编写一个可升级的智能合约
  • 算法-图和图算法
  • 我的面试准备过程--容器(更新中)
  • 如何用纯 CSS 创作一个菱形 loader 动画
  • 我们雇佣了一只大猴子...
  • 直播平台建设千万不要忘记流媒体服务器的存在 ...
  • ‌‌雅诗兰黛、‌‌兰蔻等美妆大品牌的营销策略是什么?
  • # 计算机视觉入门
  • (1)(1.11) SiK Radio v2(一)
  • (16)Reactor的测试——响应式Spring的道法术器
  • (39)STM32——FLASH闪存
  • (C语言)深入理解指针2之野指针与传值与传址与assert断言
  • (第27天)Oracle 数据泵转换分区表
  • (二)springcloud实战之config配置中心
  • (附源码)ssm基于微信小程序的疫苗管理系统 毕业设计 092354
  • (黑马出品_高级篇_01)SpringCloud+RabbitMQ+Docker+Redis+搜索+分布式
  • (转)GCC在C语言中内嵌汇编 asm __volatile__
  • (转)利用ant在Mac 下自动化打包签名Android程序
  • (最简单,详细,直接上手)uniapp/vue中英文多语言切换
  • *算法训练(leetcode)第三十九天 | 115. 不同的子序列、583. 两个字符串的删除操作、72. 编辑距离
  • .bat批处理(三):变量声明、设置、拼接、截取
  • .gitignore文件—git忽略文件
  • .net core 依赖注入的基本用发
  • .NET MVC之AOP
  • .NET Remoting学习笔记(三)信道
  • .NET 中使用 Mutex 进行跨越进程边界的同步
  • .net和jar包windows服务部署