当前位置: 首页 > news >正文

第二章 pytorch回归问题

文章目录

  • 一、梯度下降算法
  • 二、线性回归问题与逻辑回归问题
  • 三、线性回归实战
  • 四、分类问题
  • 五、手写数字识别体验

一、梯度下降算法

  • 用于反向传播时求解权重矩阵的最优解

二、线性回归问题与逻辑回归问题

  • 主要是使用使用激活函数将函数值限定在有限范围内(即:[0,1] or [-1,1] .etc)

三、线性回归实战

# -*- coding: UTF-8 -*-
'''
@version: 1.0
@PackageName: pytorch_learning - regression_demo.py
@author: yonghao
@Description: 实现线性回归
@since 2021/01/24 21:29
'''
import numpy as npdef compute_error_for_line_given_points(w, b, points) -> float:
'''
计算样本的平均损失
:paramw: 权重值
:paramb: 误差值
:parampoints: 样本点
:return: 返回样本平均损失值
'''
total_error = 0
for i in range(0, len(points)):
x = points[i, 0]
y = points[i, 1]
total_error += (y - (w * x + b)) ** 2
return total_error / float(len(points))def step_gradient(b_current, w_current, points, learning_rate):
'''
计算当前样本的梯度值
:paramb_current:当前的拟合线性方程的常数值
:paramw_current:当前的拟合线性方程的斜率
:parampoints: 样本
:paramlearning_rate:学习率
:return: 返回梯度下降后的参数
'''
b_gradient = 0
w_gradient = 0
N = float(len(points))
for i in range(0, len(points)):
x = points[i, 0]
y = points[i, 1]
b_gradient += -(2 / N) * (w_current * x + b_current - y)
w_gradient += -(2 / N) * (w_current * x + b_current - y) * xnew_w = w_current - learning_rate * w_gradientnew_b = b_current - learning_rate * b_gradientreturn new_w, new_bdef gradient_descent_runner(points, starting_b, starting_w, learning_rate, num_iterations):
'''
梯度下降算法
:parampoints:样本点集
:paramstarting_b: 起始的b
:paramstarting_w: 起始的w
:paramlearning_rate: 学习率
:paramnum_iterations: 迭代次数
:return: 计算得到的权重系数与误差系数
'''
w = starting_wb = starting_bfor i in range(num_iterations):
w, b = step_gradient(b, w, points, learning_rate)
return w, bdef run():
points = np.random.uniform(0, 5, (100, 2))
learning_rate = 0.0001
initial_b = 0
initial_w = 0
num_iterations = 1000
print('Starting gradient descent at b = {},w = {},error={}'
.format(initial_b, initial_w,
compute_error_for_line_given_points(initial_w,
initial_b,
points)))
w, b = gradient_descent_runner(points, initial_b, initial_w,
learning_rate, num_iterations)
print('After gradient descent at b = {},w = {},error={}'
.format(b, w,
compute_error_for_line_given_points(w,
b,
points)))if __name__ == '__main__':
run()

四、分类问题

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述


五、手写数字识别体验

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

# -*- coding: UTF-8 -*-
'''
@version: 1.0
@PackageName: pytorch_learning - classification.py
@author: yonghao
@Description: 实现手写数字的分类时间
@since 2021/01/24 22:38
'''
import torch
from torch import nn
from torch.nn import functional as F
from torch import optimimport torchvision
from matplotlib import pyplot as plt
from realwork.work2_handwrite_classification.utils import plot_curve, plot_image, one_hotbatch_size = 512# step1 .load dataset
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data', train=True, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=False)
# 测试显示数据集的函数
# x, y = next(iter(train_loader))
# print(x.shape, y.shape)
# print(test_loader)
# plot_image(x, y, 'image sample')
print("train num = {}".format(len(train_loader)))
print("test num = {}".format(len(test_loader)))class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# xw+b
self.fc1 = nn.Linear(28 * 28, 256)
self.fc2 = nn.Linear(256, 64)
self.fc3 = nn.Linear(64, 10) # 由于是10分类所以最后层输出一定是10def forward(self, x):
# x:[b,1,28,28]
# h1 = relu(xw1+b1)
x = F.relu(self.fc1(x))
# h2 = relu(h1w2+b2)
x = F.relu(self.fc2(x))
# h3 = h2w3+b3
x = self.fc3(x)
return x# 创建网络
net = Net()
# 定制的梯度下架算法计算器:[w1,b1,w2,b2,w3,b3]
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
# 保存梯度计算过程中的损失值
loss_values = []
for epoch in range(3):
for batch_idx, (x, y) in enumerate(train_loader):
# x:[b,1,28,28],y:[512]
# x:[b,1,28,28] -> [b,784]
x = x.view(x.size(0), x.size(2) * x.size(3))
# => [b,10]
out = net(x)
# [b,10]
y_onehot = one_hot(y)
# loss = mse(out,y_onehot)
loss = F.mse_loss(out, y_onehot)optimizer.zero_grad()
loss.backward()
# w' = w - lr*grad
optimizer.step()
loss_values.append(loss.item())
# 每10个batch显示一次loss值
# if batch_idx % 10 == 0:
#     print(epoch, batch_idx, loss.item())# we get optimal [w1,b1,w2,b2,w3,b3]
# 显示loss的变化情况
# plot_curve(loss_values)# 由测试集显示其准确度
total_correct = 0
for x, y in test_loader:
x = x.view(x.size(0), 28 * 28)
out = net(x)
# out: [b,10] => pred: [b]
# 将soft_one_hot值转换为hard_one_hot值,使其与真实标签一致
# 标注最大值的位置(将高维空间表示为一维中的位置从0开始)
pred = out.argmax(dim=1)
# 判断正确的总数
correct = pred.eq(y).sum().float().item()
total_correct += correcttotal_dataset = len(test_loader.dataset)
acc = total_correct / total_dataset
print("test acc:{}%".format(acc * 100))next(iter(test_loader))
x, _ = next(iter(test_loader))
out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • Java、python、php版的企业单位考勤打卡管理系统的设计与实现(源码、调试、LW、开题、PPT)
  • 深度学习·Pytorch
  • Java TCP练习1
  • 部署 K8s 图形化管理工具 Dashboard
  • 【与C++的邂逅】--- 类和对象(上)
  • 【数据结构-1】二叉树
  • haproxy负载均衡(twenty-eight day)
  • C# 重载运算符
  • web自动化测试Day5
  • 举例说明自然语言处理(NLP)技术。
  • Web前端:CSS篇(二)背景,文本,链接
  • 【ML】Image Augmentation)的作用、使用方法及其分类
  • UIScrollView 的 pagingEnabled属性(UIScrollView默认一次滑动多少距离?)
  • 掌握SQL的威力:批量更新与删除的艺术
  • 如何在 Windows/Mac/在线/iPhone/Android 上将 PDF 转换为 Word
  • C++类中的特殊成员函数
  • CSS3 变换
  • ES6之路之模块详解
  • JDK 6和JDK 7中的substring()方法
  • leetcode386. Lexicographical Numbers
  • Python十分钟制作属于你自己的个性logo
  • SpiderData 2019年2月23日 DApp数据排行榜
  • Spring-boot 启动时碰到的错误
  • vue和cordova项目整合打包,并实现vue调用android的相机的demo
  • webpack+react项目初体验——记录我的webpack环境配置
  • 创建一个Struts2项目maven 方式
  • 从 Android Sample ApiDemos 中学习 android.animation API 的用法
  • 基于组件的设计工作流与界面抽象
  • 看完九篇字体系列的文章,你还觉得我是在说字体?
  • 前端学习笔记之原型——一张图说明`prototype`和`__proto__`的区别
  • 微信开源mars源码分析1—上层samples分析
  • 微信小程序实战练习(仿五洲到家微信版)
  • 我感觉这是史上最牛的防sql注入方法类
  • 智能合约Solidity教程-事件和日志(一)
  • AI又要和人类“对打”,Deepmind宣布《星战Ⅱ》即将开始 ...
  • LevelDB 入门 —— 全面了解 LevelDB 的功能特性
  • 交换综合实验一
  • 你学不懂C语言,是因为不懂编写C程序的7个步骤 ...
  • ​2020 年大前端技术趋势解读
  • ​二进制运算符:(与运算)、|(或运算)、~(取反运算)、^(异或运算)、位移运算符​
  • ​人工智能书单(数学基础篇)
  • !! 2.对十份论文和报告中的关于OpenCV和Android NDK开发的总结
  • # Swust 12th acm 邀请赛# [ E ] 01 String [题解]
  • #Datawhale AI夏令营第4期#多模态大模型复盘
  • #Js篇:单线程模式同步任务异步任务任务队列事件循环setTimeout() setInterval()
  • #微信小程序:微信小程序常见的配置传旨
  • $(selector).each()和$.each()的区别
  • (4)事件处理——(2)在页面加载的时候执行任务(Performing tasks on page load)...
  • (附源码)计算机毕业设计ssm基于B_S的汽车售后服务管理系统
  • (全注解开发)学习Spring-MVC的第三天
  • (一)WLAN定义和基本架构转
  • (原创)可支持最大高度的NestedScrollView
  • (转)chrome浏览器收藏夹(书签)的导出与导入
  • (转)eclipse内存溢出设置 -Xms212m -Xmx804m -XX:PermSize=250M -XX:MaxPermSize=356m
  • .Net CoreRabbitMQ消息存储可靠机制