当前位置: 首页 > news >正文

[pytorch] --- pytorch基础之tensorboard使用

0 tensorboard介绍

TensorBoard是一个用于可视化机器学习实验结果的工具,可以帮助我们更好地理解和调试训练过程中的模型。

在PyTorch中,我们可以使用TensorBoardX库来与TensorBoard进行交互。TensorBoardX 是一个PyTorch的扩展,它允许我们将PyTorch的训练中的关键指标和摘要写入TensorBoard的事件文件中。

1 tensorboard使用步骤

1.1 安装TensorBoard

确保你已安装TensorBoard。对于PyTorch用户,TensorBoard也可以独立安装:

pip install tensorboard

1.2 在你的代码中配置TensorBoard

使用PyTorch时,你可以通过torch.utils.tensorboard模块来使用TensorBoard。首先,导入SummaryWriter来记录事件:

    from torch.utils.tensorboard import SummaryWriter# 初始化SummaryWriterwriter = SummaryWriter('runs/experiment_name')

然后,在你的训练循环中,使用writer.add_scalar等方法来记录你感兴趣的信息,例如损失和准确率:

    for epoch in range(num_epochs):# 训练模型...loss = ...accuracy = ...# 记录损失和准确率writer.add_scalar('Loss/train', loss, epoch)writer.add_scalar('Accuracy/train', accuracy, epoch)# 关闭writerwriter.close()

1.3 在PyCharm中启动TensorBoard

接下来,有两种方法在PyCharm中查看TensorBoard:

方法一:使用Terminal
1> 打开PyCharm的Terminal。
2> 导航到你的项目目录。
3> 使用以下命令启动TensorBoard:

tensorboard --logdir=runs/

Note: 一定要进入对应的虚拟环境才可使用命令打开Tenssorboard

方法二:
配置PyCharm运行配置
1> 在PyCharm中,点击右上角的“Add Configuration”。
2> 点击"+“,选择"Python”。
3> 在"Script path"中,找到并输入tensorboard的执行文件路径。
4> 在"Parameters"字段中,输入--logdir=runs/,确保路径与你的TensorBoard日志目录匹配。
5> 保存配置,然后你可以通过点击运行按钮来启动TensorBoard。

1.4 浏览TensorBoard

在TensorBoard启动后,通过浏览器访问TensorBoard界面,你可以看到损失、准确率、图像示例等多种类型的日志信息,这些都可以帮助你分析和改进你的模型。
Note:

  • 当使用PyTorch时,SummaryWriter的路径(例如runs/experiment_name)定义了TensorBoard日志的存储位置。确保每次实验使用不同的名称,以便在TensorBoard中清晰地区分它们。
  • 利用TensorBoard的高级特性,如图像、图表和直方图记录,可以提供更多关于模型训练过程和结果的洞察。

2 实例演示tensborboard使用

步骤一:创建PyTorch模型

首先,我们定义一个简单的线性回归模型。

import torch
import torch.nn as nn
import numpy as np
from torch.utils.tensorboard import SummaryWriter# 定义模型
class LinearRegressionModel(nn.Module):def __init__(self):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(1, 1)  # 输入和输出都是1维def forward(self, x):return self.linear(x)

步骤2: 训练模型并记录日志

接着,我们将准备数据、定义损失函数和优化器,并在训练循环中使用SummaryWriter来记录损失:

# 准备数据
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168], [9.779], [6.182], [7.59], [2.167], [7.042], [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573], [3.366], [2.596], [2.53], [1.221], [2.827], [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)# 初始化模型
model = LinearRegressionModel()# 损失和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 初始化SummaryWriter
writer = SummaryWriter('runs/linear_regression_experiment')# 训练模型
num_epochs = 100
for epoch in range(num_epochs):# 转换为tensorinputs = x_traintargets = y_train# 前向传播outputs = model(inputs)loss = criterion(outputs, targets)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()# 记录损失writer.add_scalar('Loss/train', loss.item(), epoch)if (epoch+1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 关闭SummaryWriter
writer.close()

步骤三:命令行中驱动tensorboard

使用以下命令启动TensorBoard:

tensorboard --logdir=runs/

在这里插入图片描述

步骤4: 观察TensorBoard

点击生成的链接http://localhost:6006/即可查看结果:
在这里插入图片描述

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • Vue 登录状态判断与跳转指南
  • 一.海量数据实时分析-Doris入门和安装
  • JMeter之上传文件同时带有参数
  • Python计算机视觉四章-照相机模型与增强现实
  • Spring Cloud全解析:网关之GateWay过滤器
  • RASA使用长文记录以及一些bug整理
  • 鸿蒙启动框架配置文件(StartUpTask)
  • 学习记录:js算法(二十一):字符串的排列、替换后的最长重复字符
  • YOLOv9改进策略【模型轻量化】| MoblieNetV3:基于搜索技术和新颖架构设计的轻量型网络模型
  • 前端内存泄露案例与解决方案
  • Ubuntu 安装个人热点
  • 字符集介绍
  • 八、2 DMA数据转运 DMA函数介绍
  • 使用 streamlink 把 m3u8 转为 mp4
  • 如何使用IDEA搭建Mybatis框架环境(详细教程)
  • 《深入 React 技术栈》
  • 【编码】-360实习笔试编程题(二)-2016.03.29
  • 0基础学习移动端适配
  • codis proxy处理流程
  • css选择器
  • FineReport中如何实现自动滚屏效果
  • gops —— Go 程序诊断分析工具
  • HTTP中GET与POST的区别 99%的错误认识
  • JavaScript类型识别
  • LeetCode算法系列_0891_子序列宽度之和
  • Logstash 参考指南(目录)
  • nodejs实现webservice问题总结
  • Python打包系统简单入门
  • React组件设计模式(一)
  • ⭐ Unity 开发bug —— 打包后shader失效或者bug (我这里用Shader做两张图片的合并发现了问题)
  • Web设计流程优化:网页效果图设计新思路
  • 创建一个Struts2项目maven 方式
  • 前端之Sass/Scss实战笔记
  • 实战:基于Spring Boot快速开发RESTful风格API接口
  • 我与Jetbrains的这些年
  • ​用户画像从0到100的构建思路
  • # 消息中间件 RocketMQ 高级功能和源码分析(七)
  • #laravel 通过手动安装依赖PHPExcel#
  • (done) ROC曲线 和 AUC值 分别是什么?
  • (六)库存超卖案例实战——使用mysql分布式锁解决“超卖”问题
  • (牛客腾讯思维编程题)编码编码分组打印下标题目分析
  • (求助)用傲游上csdn博客时标签栏和网址栏一直显示袁萌 的头像
  • (十六)Flask之蓝图
  • (十三)MipMap
  • (一)模式识别——基于SVM的道路分割实验(附资源)
  • (原)本想说脏话,奈何已放下
  • (原創) 是否该学PetShop将Model和BLL分开? (.NET) (N-Tier) (PetShop) (OO)
  • (转)eclipse内存溢出设置 -Xms212m -Xmx804m -XX:PermSize=250M -XX:MaxPermSize=356m
  • (转)Google的Objective-C编码规范
  • (转)ObjectiveC 深浅拷贝学习
  • (自适应手机端)响应式服装服饰外贸企业网站模板
  • .NET Core 成都线下面基会拉开序幕
  • .net 发送邮件
  • .NET开发者必备的11款免费工具
  • .net网站发布-允许更新此预编译站点