当前位置: 首页 > news >正文

《动手学深度学习(PyTorch版)》笔记3.2

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过。

Chapter3 Linear Neural Networks

3.2 Implementations of Linear Regression from Scratch

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import random
import torch
from d2l import torch as d2ldef synthetic_data(w, b, num_examples):  #@save"""Generate y = Xw + b + noise."""#generates a random matrix X with dimensions (num_examples, len(w)) using a normal distribution with a mean of 0 and standard deviation of 1.X = torch.normal(0, 1, (num_examples, len(w)),dtype=torch.float32) #calculates the target values y by multiplying the input matrix X with the weight vector w and adding the bias term b. y = torch.matmul(X, w) + b  #And then adds some random noise to the target values y. The noise is generated from a normal distribution with mean 0 and standard deviation 0.01.                    y += torch.normal(0, 0.01, y.shape)             return X, y.reshape((-1, 1)) #The -1 in the first dimension means that PyTorch should automatically infer the size of that dimension based on the total number of elements. In other words, it is used to ensure that the reshaped tensor has the same total number of elements as the original tensor.true_w=torch.tensor([2,-3.4],dtype=torch.float32)
true_b=4.2
features,labels=synthetic_data(true_w,true_b,1000)
print('features:',features[0],'\nlabel:',labels[0])d2l.set_figsize()
d2l.plt.scatter(features[:,(1)].detach().numpy(),labels.detach().numpy(),1)
#plt.show()#显示散点图
#"features[:, 1]" selects the second column of the features tensor. 
#The detach() method is used to create a new tensor that shares no memory with the original tensor, and numpy() is then called to convert it to a NumPy array.
#"1" is the size of the markers in the scatter plot.def data_iter(batch_size,features,labels):num_examples=len(features)indices=list(range(num_examples))#随机读取文本random.shuffle(indices)#"Shuffle the indices"意为打乱索引 for i in range(0,num_examples,batch_size):batch_indices=torch.tensor(indices[i:min(i+batch_size,num_examples)])#"min(i + batch_size, num_examples)" is used to handle the last batch, which might have fewer examples than batch_size.yield features[batch_indices],labels[batch_indices]#初始化参数。从均值为0,标准差为0.01的正态分布中抽取随机数来初始化权重,并将偏置量置为0       
w=torch.normal(0,0.01,size=(2,1),requires_grad=True)
b=torch.zeros(1,requires_grad=True)#定义线性回归模型
def linreg(X,w,b): #@savereturn torch.matmul(X,w)+b #广播机制:用一个向量加一个标量时,标量会加到向量的每一个分量上#定义均方损失函数
def squared_loss(y_hat,y): #@savereturn (y_hat-y.reshape(y_hat.shape))**2/2#定义优化算法:小批量随机梯度下降
def sgd(params,lr,batch_size): #@savewith torch.no_grad():for param in params:param-=lr*param.grad/batch_sizeparam.grad.zero_()#轮数num_epochs和学习率lr都是超参数,先分别设为3和0.03,具体方法后续讲解
lr=0.03
num_epochs=3
batch_size=10for epoch in range(num_epochs):for X,y in data_iter(batch_size,features,labels):l=squared_loss(linreg(X,w,b),y)l.sum().backward()#因为l是一个向量而不是标量,因此需要把l的所有元素加到一起来计算关于(w,b)的梯度sgd([w,b],lr,batch_size)with torch.no_grad():train_l=squared_loss(linreg(features,w,b),labels)print(f'epoch {epoch+1}:squared_loss {float(train_l.mean()):f}')
print(f'w的估计误差:{true_w-w.reshape(true_w.shape)}')
#结果中的grad_fn=<SubBackward0>表示这个tensor是由一个正向减法操作生成的
print(f'b的估计误差:{true_b-b}')#<RsubBackward1>表示由一个反向减法操作生成

相关文章:

  • 数据结构和算法笔记5:堆和优先队列
  • MYSQL数据库基本操作-DQL-基本查询
  • day34WEB 攻防-通用漏洞文件上传黑白盒审计逻辑中间件外部引用
  • CentOS 7 下安装 Docker 及配置阿里云加速服务
  • 浅析大数据汇总
  • PyTorch初探:基本函数与案例实践
  • HCIP之MPLS实验
  • TensorFlow2实战-系列教程4:数据增强:keras工具包/Data Augmentation
  • HTML — 区块元素
  • 嵌入式Linux系统下的智能家居能源管理系统的设计与实现
  • NIO-Selector详解
  • Java 基于SpringBoot+Vue的母婴商城系统,附源码,文档
  • React Hooks大全—useRef
  • Kafka-服务端-GroupCoordinator
  • 武忠祥2025高等数学,基础阶段的百度网盘+视频及PDF
  • 【108天】Java——《Head First Java》笔记(第1-4章)
  • CentOS 7 防火墙操作
  • Django 博客开发教程 16 - 统计文章阅读量
  • docker容器内的网络抓包
  • Hexo+码云+git快速搭建免费的静态Blog
  • IDEA 插件开发入门教程
  • JavaScript 是如何工作的:WebRTC 和对等网络的机制!
  • JavaScript学习总结——原型
  • Node项目之评分系统(二)- 数据库设计
  • PaddlePaddle-GitHub的正确打开姿势
  • Unix命令
  • Vue 动态创建 component
  • 聚簇索引和非聚簇索引
  • 可能是历史上最全的CC0版权可以免费商用的图片网站
  • 理清楚Vue的结构
  • 前端_面试
  • 前端每日实战:70# 视频演示如何用纯 CSS 创作一只徘徊的果冻怪兽
  • 提醒我喝水chrome插件开发指南
  • 写代码的正确姿势
  • 学习使用ExpressJS 4.0中的新Router
  • 移动端唤起键盘时取消position:fixed定位
  • 【运维趟坑回忆录】vpc迁移 - 吃螃蟹之路
  • #、%和$符号在OGNL表达式中经常出现
  • #Z0458. 树的中心2
  • #我与Java虚拟机的故事#连载10: 如何在阿里、腾讯、百度、及字节跳动等公司面试中脱颖而出...
  • (33)STM32——485实验笔记
  • (AtCoder Beginner Contest 340) -- F - S = 1 -- 题解
  • (C语言)求出1,2,5三个数不同个数组合为100的组合个数
  • (五)Python 垃圾回收机制
  • (轉貼) 蒼井そら挑戰筋肉擂台 (Misc)
  • *上位机的定义
  • .net 按比例显示图片的缩略图
  • .net(C#)中String.Format如何使用
  • .NET/MSBuild 中的发布路径在哪里呢?如何在扩展编译的时候修改发布路径中的文件呢?
  • .NET下ASPX编程的几个小问题
  • .NET学习教程二——.net基础定义+VS常用设置
  • .net用HTML开发怎么调试,如何使用ASP.NET MVC在调试中查看控制器生成的html?
  • .NET中的Exception处理(C#)
  • ::什么意思
  • [ MSF使用实例 ] 利用永恒之蓝(MS17-010)漏洞导致windows靶机蓝屏并获取靶机权限