当前位置: 首页 > news >正文

【PYG】Cora数据集分析argmax(dim=1)输出

简单版本

  • argmax(dim=1):这是一个张量操作,用于在指定维度上找到最大值的索引。dim=1 表示我们在类别维度上进行操作,因此对于每个节点或样本,argmax(dim=1) 将返回具有最大对数概率的类别索引。
  • 通过 argmax(dim=1),我们从模型的输出中提取每个节点或样本的预测类别。对于分类任务,这一步是必要的,因为模型的输出通常是每个类别的对数概率分布,我们需要从中选出概率最大的类别作为预测结果。
  • 假设 out 的输出如下:
out = torch.tensor([[0.1, 2.3, 0.4],[1.2, 0.8, 3.5],[0.6, 0.7, 0.2]])

使用 out.argmax(dim=1) 提取预测类别:

pred = out.argmax(dim=1)
print(pred)

输出将是:

tensor([1, 2, 1])

这表示模型预测第一个样本的类别为1,第二个样本的类别为2,第三个样本的类别为1。通过这种方式,我们可以从模型的输出中提取每个样本的预测类别,用于后续的评估或其他处理。

Cora数据集结合验证函数分析

out[0]与out[0:1,:],同样是取出一行,out[0]得到的是一维向量torch.Size([7]),out[0:1,:]得到的是二维向量torch.Size([1, 7]),二维向量就可以在这一行使用argmax(dim=1)返回最大数的索引
结合输出看 tensor([[-1.9416, -1.9553, -1.9600, -1.9397, -1.9333, -1.9581, -1.9338]]中最大的数是-1.9333,第五个数索引是4,所以test out[0:1,:].argmax(dim=1)的输出是tensor([4], device=‘cuda:0’)

test out torch.Size([2708, 7])
test out[0] torch.Size([7]) tensor([-1.9416, -1.9553, -1.9600, -1.9397, -1.9333, -1.9581, -1.9338],device='cuda:0', grad_fn=<SelectBackward0>)
test out[0:1,:] torch.Size([1, 7]) tensor([[-1.9416, -1.9553, -1.9600, -1.9397, -1.9333, -1.9581, -1.9338]],device='cuda:0', grad_fn=<SliceBackward0>)
test out[0:1,:].argmax(dim=1) tensor([4], device='cuda:0')
test pred torch.Size([1000]) 
def test():model.eval()out = model(data)print(f"test out {out.shape}")print(f"test out[0] {out[0].shape} {out[0]}")print(f"test out[0:1,:] {out[0:1,:].shape} {out[0:1,:]}")print(f"test out[0:1,:].argmax(dim=1) {out[0:1,:].argmax(dim=1)}")pred = out.argmax(dim=1)print(f"test pred {pred[data.test_mask].shape} {pred[data.test_mask]}")print(f"data {data.y[data.test_mask].shape} {data.y[data.test_mask]}")correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()acc = int(correct) / int(data.test_mask.sum())return acc

前置知识:布尔索引
out:这是模型的输出,通常是形状为 [num_nodes, num_classes] 的张量。例如,[2708,7]表明有2708个节点,7个类别
data.test_mask:这是一个布尔张量,形状为 [num_nodes],用于指示哪些节点属于验证集。True 表示该节点属于验证集,False 表示该节点不属于验证集。统计test_mask中True的个数可以确认训练集的个数,该例中是1000,num_test = data.test_mask.sum().item()

out[data.train_mask]:使用布尔索引从 out 中选择训练集中节点的输出。这将返回一个形状为 [num_train_nodes, num_classes] 的张量,其中 num_train_nodes 是训练集中节点的数量。out[data.train_mask] 的形状为 [140, 7],因为 data.train_mask 选择了 140 个节点进行训练,每个节点对应 7 个类别的未归一化分数

完整代码(胡乱改改)

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures# 加载Cora数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
data = dataset[0]# 定义GCN模型
class GCN(torch.nn.Module):def __init__(self):super(GCN, self).__init__()self.conv1 = GCNConv(dataset.num_node_features, 16)self.conv2 = GCNConv(16, dataset.num_classes)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = F.dropout(x, training=self.training)x = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)# 初始化模型和优化器
model = GCN()
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
data = data.to('cuda')
model = model.to('cuda')print(f"data.train_mask{data.train_mask}")# 训练模型
def train():model.train()optimizer.zero_grad()out = model(data)# print(f"out[data.train_mask] {data.train_mask.shape} {out[data.train_mask].shape} {out[data.train_mask]}")loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return loss.item()# 评估模型
def test():model.eval()out = model(data)print(f"test out {out.shape}")print(f"test out[0] {out[0].shape} {out[0]}")print(f"test out[0:1,:] {out[0:1,:].shape} {out[0:1,:]}")print(f"test out[0:1,:].argmax(dim=1) {out[0:1,:].argmax(dim=1)}")pred = out.argmax(dim=1)print(f"test pred {pred[data.test_mask].shape} {pred[data.test_mask]}")print(f"data {data.y[data.test_mask].shape} {data.y[data.test_mask]}")correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()acc = int(correct) / int(data.test_mask.sum())return accfor epoch in range(1):loss = train()acc = test()print(f'Epoch {epoch+1}, Loss: {loss:.4f}, Accuracy: {acc:.4f}')

相关文章:

  • 初学51单片机之简易电子密码锁及PWM应用扩展
  • 二维码登录的原理
  • vue根据文字长短展示跑马灯效果
  • Kafka-服务端-副本同步-源码流程
  • 编程入门:从零开始学习编程的方法与步骤
  • Java List操作详解及常用方法
  • 【Llama 2的使用方法】
  • 大学生放学后一定要做的4件事情
  • PO模式简介
  • 什么是有效的电子签名?PDF电子签名怎样具备法律效力?
  • 发电机保护屏的作用及其重要性
  • 亚马逊等跨境电商测评怎么做?
  • Chapter8 透明效果——Shader入门精要学习笔记
  • 【愤怒的小方块案例 Objective-C语言】
  • Java实现数据结构——不带头单链表
  • -------------------- 第二讲-------- 第一节------在此给出链表的基本操作
  • 2017届校招提前批面试回顾
  • 230. Kth Smallest Element in a BST
  • JDK9: 集成 Jshell 和 Maven 项目.
  • React中的“虫洞”——Context
  • Redash本地开发环境搭建
  • Spring Boot MyBatis配置多种数据库
  • Sublime text 3 3103 注册码
  • TiDB 源码阅读系列文章(十)Chunk 和执行框架简介
  • VirtualBox 安装过程中出现 Running VMs found 错误的解决过程
  • WinRAR存在严重的安全漏洞影响5亿用户
  • 彻底搞懂浏览器Event-loop
  • 等保2.0 | 几维安全发布等保检测、等保加固专版 加速企业等保合规
  • 第十八天-企业应用架构模式-基本模式
  • 计算机常识 - 收藏集 - 掘金
  • 京东美团研发面经
  • 如何优雅的使用vue+Dcloud(Hbuild)开发混合app
  • 想晋级高级工程师只知道表面是不够的!Git内部原理介绍
  • 延迟脚本的方式
  • 云大使推广中的常见热门问题
  • 400多位云计算专家和开发者,加入了同一个组织 ...
  • 阿里云重庆大学大数据训练营落地分享
  • 东超科技获得千万级Pre-A轮融资,投资方为中科创星 ...
  • ​【已解决】npm install​卡主不动的情况
  • ​configparser --- 配置文件解析器​
  • ​软考-高级-系统架构设计师教程(清华第2版)【第20章 系统架构设计师论文写作要点(P717~728)-思维导图】​
  • #我与Java虚拟机的故事#连载09:面试大厂逃不过的JVM
  • #我与虚拟机的故事#连载20:周志明虚拟机第 3 版:到底值不值得买?
  • $$$$GB2312-80区位编码表$$$$
  • (¥1011)-(一千零一拾一元整)输出
  • (AtCoder Beginner Contest 340) -- F - S = 1 -- 题解
  • (补)B+树一些思想
  • (附程序)AD采集中的10种经典软件滤波程序优缺点分析
  • (附源码)ssm高校运动会管理系统 毕业设计 020419
  • (六)什么是Vite——热更新时vite、webpack做了什么
  • (三维重建学习)已有位姿放入colmap和3D Gaussian Splatting训练
  • (四)七种元启发算法(DBO、LO、SWO、COA、LSO、KOA、GRO)求解无人机路径规划MATLAB
  • (转)关于如何学好游戏3D引擎编程的一些经验
  • .NET Core6.0 MVC+layui+SqlSugar 简单增删改查
  • .NET NPOI导出Excel详解