当前位置: 首页 > news >正文

【深度学习】PyTorch深度学习笔记02-线性模型

1. 监督学习

2. 数据集的划分

3. 平均平方误差MSE

4. 线性模型Linear Model - y = x * w

用穷举法确定线性模型的参数

import numpy as np
import matplotlib.pyplot as pltx_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]def forward(x):return x * wdef loss(x, y):y_pred = forward(x)return (y_pred - y) * (y_pred - y)w_list = []
mse_list = []for w in np.arange(0.0, 4.0, 0.1):print('w=', w)l_sum = 0for x_val, y_val in zip(x_data, y_data):  y_pred_val = forward(x_val)loss_val = loss(x_val, y_val)  l_sum += loss_valprint('\t', x_val, y_val, y_pred_val, loss_val)print('MSE=', l_sum / len(x_data))  w_list.append(w)mse_list.append(l_sum / len(x_data))plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()

详细过程

    本课程的主要任务是构建一个完整的线性模型:
        导入numpy和matplotlib库;
        导入数据 x_data 和 y_data;
        定义前向传播函数:
            forward:输出是预测值y_hat
        定义损失函数:
            loss:平方误差
        创建两个空列表,后面绘图的时候要用:
            分别是横轴的w_list和纵轴的mse_list
        开始计算(这里没有训练的概念,只是单纯的计算每一个数据对应的预测值,然后让预测值跟真实y值求MSE):
            外层循环:
                在0.0~4.0之间均匀取点,步长0.1,作为n个横坐标自变量,用w表示;
            内层循环:核心计算内容
                从数据集中,按数据对取出自变量x_val和真实值y_val;
                先调用forward函数,计算y的预测值 w*x
                调用loss函数,计算单个数据的平方误差;
                累加损失;
                打印想要看到的数值;
                在外层循环中,把计算的结果放进之前的空列表,用于绘图;
    在获得了打印所需的数据列表之后,模式化地打印图像:

运行结果

ps:

visdom库可用于可视化

np.meshgrid()可用于绘制三维图

5. 线性模型Linear Model - y = x * w + b

有w,b两个参数,穷举最小值

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3Dx_data = [1.0, 2.0, 3.0]
y_data = [3.0, 4.0, 6.0]def forward(x, w, b):return x * w + bdef loss(x, y, w, b):y_pred = forward(x, w, b)loss = (y_pred - y) * (y_pred - y)return lossw_list = np.arange(0.0, 4.1, 0.1)
b_list = np.arange(-2.0, 2.1, 0.1)# np.zeros(): 返回给定维度的全零数组; mse_matrix用于存储不同 w,b 组合下的均方误差损失
mse_matrix = np.zeros((len(w_list), len(b_list)))for i, w in enumerate(w_list):for j, b in enumerate(b_list):l_sum = 0for x_val, y_val in zip(x_data, y_data):l_sum += loss(x_val, y_val, w, b)mse_matrix[i, j] = l_sum / len(x_data)W, B = np.meshgrid(w_list, b_list)
fig = plt.figure('Linear Model Cost Value')
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(W, B, mse_matrix.T, cmap='viridis')
ax.set_xlabel('w')
ax.set_ylabel('b')
ax.set_zlabel('loss')
plt.show()

可以得出,穷举法算法的时间复杂度 随着参数的个数增大 而变得很大,因此使用穷举法找到最优解,很不合理。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 百度安全大模型智能体实践入选信通院“安全守卫者计划”优秀案例
  • 专业条码二维码扫描设备和手机二维码扫描软件的区别?
  • 【Java--数据结构】栈:不仅仅是数据存储,它是编程的艺术
  • Docker 容器出现 IP 冲突
  • 深度加速器 为游戏而生
  • 【ARM】CCI缓存一致性整理
  • [论文笔记]RAPTOR: RECURSIVE ABSTRACTIVE PROCESSING FOR TREE-ORGANIZED RETRIEVAL
  • 【LeetCode】2187. 完成旅途的最少时间
  • 基于Python/MATLAB长时间序列遥感数据处理及在全球变化、植被物候提取、植被变绿与生态系统固碳分析、生物量估算与趋势分析应用
  • Three.js相机简明教程
  • 期货量化交易客户端开源教学第三节——键盘通信协议
  • CSS相对定位和绝对定位的区别
  • 了解Maven
  • stm32中断详解
  • LabVIEW滤波器性能研究
  • 9月CHINA-PUB-OPENDAY技术沙龙——IPHONE
  • 【跃迁之路】【444天】程序员高效学习方法论探索系列(实验阶段201-2018.04.25)...
  • 【跃迁之路】【735天】程序员高效学习方法论探索系列(实验阶段492-2019.2.25)...
  • 2017 前端面试准备 - 收藏集 - 掘金
  • Angular 2 DI - IoC DI - 1
  • Angular 响应式表单 基础例子
  • Computed property XXX was assigned to but it has no setter
  • If…else
  • Java精华积累:初学者都应该搞懂的问题
  • JS创建对象模式及其对象原型链探究(一):Object模式
  • Linux下的乱码问题
  • mysql innodb 索引使用指南
  • Redux系列x:源码分析
  • spring cloud gateway 源码解析(4)跨域问题处理
  • zookeeper系列(七)实战分布式命名服务
  • 从重复到重用
  • 订阅Forge Viewer所有的事件
  • 基于OpenResty的Lua Web框架lor0.0.2预览版发布
  • 理解在java “”i=i++;”所发生的事情
  • 猫头鹰的深夜翻译:JDK9 NotNullOrElse方法
  • 浅谈JavaScript的面向对象和它的封装、继承、多态
  • 如何抓住下一波零售风口?看RPA玩转零售自动化
  • 入手阿里云新服务器的部署NODE
  • 【干货分享】dos命令大全
  • ​什么是bug?bug的源头在哪里?
  • #!/usr/bin/python与#!/usr/bin/env python的区别
  • ###51单片机学习(1)-----单片机烧录软件的使用,以及如何建立一个工程项目
  • #mysql 8.0 踩坑日记
  • #NOIP 2014#day.2 T1 无限网络发射器选址
  • #pragam once 和 #ifndef 预编译头
  • #我与Java虚拟机的故事#连载13:有这本书就够了
  • $LayoutParams cannot be cast to android.widget.RelativeLayout$LayoutParams
  • ()、[]、{}、(())、[[]]命令替换
  • (2)nginx 安装、启停
  • (C语言)fread与fwrite详解
  • (C语言)求出1,2,5三个数不同个数组合为100的组合个数
  • (MATLAB)第五章-矩阵运算
  • (三) prometheus + grafana + alertmanager 配置Redis监控
  • (原创) cocos2dx使用Curl连接网络(客户端)
  • (中等) HDU 4370 0 or 1,建模+Dijkstra。