当前位置: 首页 > news >正文

手动构建线性回归(PyTorch)

import torch
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt
import random
#1.构建数据
#构建数据集
def create_dataset():x,y,coef=make_regression(n_samples=100,n_features=1,random_state=0,noise=10,coef=True,bias=14.5)#将构建数据转换为张量类型x=torch.tensor(x)y=torch.tensor(y)return x,y#构建数据加载器
def data_loader(x,y, batch_size):#计算下样本的数量data_len = len(y)#构建数据索引data_index=list(range(data_len))random.shuffle(data_index)#计算总的batch数量batch_number=data_len//batch_sizefor idx in range(batch_number):start=idx+batch_sizeend=start+batch_sizebatch_train_x=x[start:end]batch_train_y=y[start:end]yield batch_train_x,batch_train_ydef test01():x,y=create_dataset()plt.scatter(x,y)plt.show()for x,y in data_loader(x,y,batch_size=10):print(y)
#2.假设函数、损失函数、优化方法
#损失函数:平均损失
#优化方法:梯度下降
#假设函数
w=torch.tensor(0.1,requires_grad=True,dtype=torch.float64)
b=torch.tensor(0.1,requires_grad=True,dtype=torch.float64)def linear_regression(x):return w*x+b#损失函数
def square_loss(y_pred,y_true):return torch.square(y_pred - y_true)#优化方法
def sqd(lr=1e-2):#除以16是使用的是批次样本的平均梯度w.data=w.data-lr*w.grad.data/16b.data=b.data-lr*b.grad.data/16if __name__ == '__main__':test01()

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 如何搭建一个RADIUS服务器?
  • vue3封装el-table及实现表头自定义筛选
  • CSP-J模拟赛day1
  • 【计算机网络】0 课程主要内容(自顶向下方法,中科大郑烇、杨坚)(待)
  • Flutter 开源库学习
  • 【Linux】shell简单模拟实现
  • Adobe Premiere Pro(Pr)安装包软件下载
  • 2024年热门硬盘数据恢复软件大盘点:高效恢复您的宝贵数据
  • 【数据结构】二叉树OJ题_对称二叉树_另一棵的子树
  • NAS新品“翻车”后,绿联科技要上市了
  • Chapter 5: 二叉树详解
  • docker默认存储地址 var/lib/docker 满了,换个存储地址操作流程
  • HardeningMeter:一款针对二进制文件和系统安全强度的开源工具
  • 项目收获总结--MyBatis的知识收获
  • linux在ssh的时候询问,yes or no 如何关闭
  • 《Java编程思想》读书笔记-对象导论
  • Angular6错误 Service: No provider for Renderer2
  • CSS选择器——伪元素选择器之处理父元素高度及外边距溢出
  • gf框架之分页模块(五) - 自定义分页
  • IDEA 插件开发入门教程
  • React-Native - 收藏集 - 掘金
  • 从 Android Sample ApiDemos 中学习 android.animation API 的用法
  • 给Prometheus造假数据的方法
  • 关于Java中分层中遇到的一些问题
  • ​flutter 代码混淆
  • ​TypeScript都不会用,也敢说会前端?
  • ​什么是bug?bug的源头在哪里?
  • ​油烟净化器电源安全,保障健康餐饮生活
  • #我与Java虚拟机的故事#连载16:打开Java世界大门的钥匙
  • $(selector).each()和$.each()的区别
  • (day18) leetcode 204.计数质数
  • (js)循环条件满足时终止循环
  • (MonoGame从入门到放弃-1) MonoGame环境搭建
  • (solr系列:一)使用tomcat部署solr服务
  • (vue)页面文件上传获取:action地址
  • (Windows环境)FFMPEG编译,包含编译x264以及x265
  • (板子)A* astar算法,AcWing第k短路+八数码 带注释
  • (博弈 sg入门)kiki's game -- hdu -- 2147
  • (六)软件测试分工
  • (免费分享)基于springboot,vue疗养中心管理系统
  • (南京观海微电子)——示波器使用介绍
  • .Net Core webapi RestFul 统一接口数据返回格式
  • .net 简单实现MD5
  • .NET多线程执行函数
  • @EventListener注解使用说明
  • [ vulhub漏洞复现篇 ] Apache APISIX 默认密钥漏洞 CVE-2020-13945
  • [18] Opencv_CUDA应用之 基于颜色的对象检测与跟踪
  • [AIGC] Redis基础命令集详细介绍
  • [Android] Android ActivityManager
  • [C#] 基于 yield 语句的迭代器逻辑懒执行
  • [C++]: std::move
  • [C++]——带你学习类和对象
  • [C++]运行时,如何确保一个对象是只读的
  • [CareerCup] 12.3 Test Move Method in a Chess Game 测试象棋游戏中的移动方法
  • [CareerCup][Google Interview] 实现一个具有get_min的Queue