当前位置: 首页 > news >正文

pytorch深度学习基础 8(简单的神经网络替换线性模型)

接上一节的思路,这一节我们将使用神经网络来代替我们的之前的线性模型作为逼近函数。我们将保持其他的一切不变,只重新定义模型,小编这里构建的是最简单的神经网络,一个线性模块,一个激活函数,然后一个线性模块。

seq_model = nn.Sequential(OrderedDict([('hidden_linear', nn.Linear(1, 8)),('hidden_activation', nn.Tanh()),('output_linear', nn.Linear(8, 1))
]))

这个代码片段定义了一个简单的神经网络模型,使用nn.SequentialOrderedDict来组织模型的层。这个模型包含一个隐藏层和一个输出层,隐藏层使用Tanh激活函数。

模型结构

  1. 隐藏层 (hidden_linear):
    • 输入维度: 1
    • 输出维度: 8
    • 线性变换: nn.Linear(1, 8)
  2. 激活函数 (hidden_activation):
    • 激活函数: Tanh (nn.Tanh())
  3. 输出层 (output_linear):
    • 输入维度: 8
    • 输出维度: 1
    • 线性变换: nn.Linear(8, 1)
for name, param in seq_model.named_parameters():print(name, param.shape)

seq_model.named_parameters() 方法用于遍历模型中的所有参数,并返回每个参数的名称(name)和参数本身(params)

其他的没有变化

from collections import OrderedDictseq_model = nn.Sequential(OrderedDict([('hidden_linear', nn.Linear(1, 8)),('hidden_activation', nn.Tanh()),('output_linear', nn.Linear(8, 1))
]))optimizer = optim.SGD(seq_model.parameters(), lr=1e-4)  # <1>training_loop(n_epochs=100000,optimizer=optimizer,model=seq_model,loss_fn=nn.MSELoss(),t_u_train=t_un_train,t_u_val=t_un_val,t_c_train=t_c_train,t_c_val=t_c_val)print('output', seq_model(t_un_val))
print('answer', t_c_val)
print('hidden', seq_model.hidden_linear.weight.grad)from matplotlib import pyplot as pltt_range = torch.arange(20., 90.).unsqueeze(1)fig = plt.figure(dpi=100)
plt.xlabel("Fahrenheit")
plt.ylabel("Celsius")
plt.plot(t_u.numpy(), t_c.numpy(), 'o')
plt.plot(t_range.numpy(), seq_model(0.1 * t_range).detach().numpy(), 'c-')
plt.plot(t_u.numpy(), seq_model(0.1 * t_u).detach().numpy(), 'kx')
plt.show()

 可以看到即使使用神经网络进行训练还是有点过拟合的现象发生,总的来说做的还不错

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 52 mysql 启动过程中常见的相关报错信息
  • 讲述Navicat for MySQL定时备份数据库和数据恢复等功能
  • 【Sceneform-EQR】scenefrom-eqr中的几种背景实现(不仅用于AR、三维场景,在图片、视频播放器中也适用)
  • 【Docker项目实战】使用Docker部署miniPaint图片编辑器
  • 如何在AutoGen中使用自定义的大模型
  • 打卡53天------图论(应用题)
  • CRUD的最佳实践,联动前后端,包含微信小程序,API,HTML等(二)
  • 大模型企业应用落地系列》基于大模型的对话式推荐系统》技术架构设计全攻略
  • HarmonyOS应用开发者基础认证
  • IPv4和IPv6的区别是什么?什么是局域网和广域网,公网IP和私有IP?
  • Redis Cluster(无中心化设计)
  • 信号量笔记
  • pytorch FSDP分布式训练minist案例
  • java springboot 集成activeMQ(保姆级别教程)
  • C++学习笔记——交换值
  • python3.6+scrapy+mysql 爬虫实战
  • CentOS学习笔记 - 12. Nginx搭建Centos7.5远程repo
  • java取消线程实例
  • mockjs让前端开发独立于后端
  • REST架构的思考
  • spring security oauth2 password授权模式
  • Vue 2.3、2.4 知识点小结
  • vue2.0一起在懵逼的海洋里越陷越深(四)
  • 代理模式
  • 对象引论
  • 汉诺塔算法
  • 力扣(LeetCode)357
  • 买一台 iPhone X,还是创建一家未来的独角兽?
  • 手机app有了短信验证码还有没必要有图片验证码?
  • 微信小程序上拉加载:onReachBottom详解+设置触发距离
  • 一道闭包题引发的思考
  • 云大使推广中的常见热门问题
  • 深度学习之轻量级神经网络在TWS蓝牙音频处理器上的部署
  • 策略 : 一文教你成为人工智能(AI)领域专家
  • ​Redis 实现计数器和限速器的
  • ​草莓熊python turtle绘图代码(玫瑰花版)附源代码
  • #include到底该写在哪
  • ( )的作用是将计算机中的信息传送给用户,计算机应用基础 吉大15春学期《计算机应用基础》在线作业二及答案...
  • (2)nginx 安装、启停
  • (5)STL算法之复制
  • (独孤九剑)--文件系统
  • (四)软件性能测试
  • (新)网络工程师考点串讲与真题详解
  • .jks文件(JAVA KeyStore)
  • .NET 8 编写 LiteDB vs SQLite 数据库 CRUD 接口性能测试(准备篇)
  • .NET Core工程编译事件$(TargetDir)变量为空引发的思考
  • .net core使用EPPlus设置Excel的页眉和页脚
  • .NET Framework 3.5安装教程
  • .NET 解决重复提交问题
  • .NET大文件上传知识整理
  • /var/log/cvslog 太大
  • @antv/g6 业务场景:流程图
  • @FeignClient 调用另一个服务的test环境,实际上却调用了另一个环境testone的接口,这其中牵扯到k8s容器外容器内的问题,注册到eureka上的是容器外的旧版本...
  • @Mapper作用
  • @Not - Empty-Null-Blank