当前位置: 首页 > news >正文

在PyG上构建自己的数据集

PyG构建自己数据集

PyG简介

PyG(PyTorch Geometric)是一个建立在 PyTorch 基础上的库,用于轻松编写和训练图神经网络(GNN),用于与结构化数据相关的广泛应用。

它包括在图和其他不规则结构上进行深度学习的各种方法,也被称为几何深度学习,来自各种已发表的论文。此外,它还包括易于使用的迷你批量加载器(mini-batch loaders),用于在许多小型和单一的巨型图形上操作;多 GPU 支持、大量常见的基准数据集(基于简单的接口来创建你自己的数据集);以及有用的变换,既可以在任意图形上学习,也可以在 3D 网格或点云上学习。

数据集介绍

本部分用到的也是Cora数据集,但是不是官方版本的数据集,而是非常平易近人的风格,拿来就可以使用,格式如下:
cora.cites
在这里插入图片描述
cora.cites文件格式非常简单,就是两列,代表两个具备边关系的节点。
cora.content
在这里插入图片描述
在这里插入图片描述
cora.content文件内容也很简单,第一列是节点id,最后一列是每个节点的标签,中间的数值是每个节点的特征值。

代码实现

PyG构建数据集,氛围两类,一种是针对小数据集的in_memory_dataset,这种形式可以直接将所用的数据集都加载到内存当中;另一种是针对大数据集的Dataset,这种形式主要是可以对大数据集进行索引,进行batch合并,减少每次内存的数据量。实际业务中,我们大多是用大数据集,因此,就以这个作为例子。

from torch_geometric.data import Dataset, Data
# 定义自己的数据集类
class mydataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(mydataset, self).__init__(root, transform, pre_transform)

    # 原始文件位置
    @property
    def raw_file_names(self):
        return ['cora.content', 'cora.cites']

    # 文件保存位置
    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        pass

    # 数据处理逻辑
    def process(self):
        idx_features_labels = np.genfromtxt(self.raw_paths[0])
        x = idx_features_labels[:, 1:-1]
        x = torch.tensor(x, dtype=torch.float32)
        y, label_dict = self.encode_labels(np.genfromtxt(self.raw_paths[0], dtype='str', usecols=(-1,)))
        y = torch.tensor(y)
        idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
        id_node = {j: i for i, j in enumerate(idx)}

        edges_unordered = np.genfromtxt(self.raw_paths[1], dtype=np.int32)
        edge_str = [id_node[each[0]] for each in edges_unordered]
        edge_end = [id_node[each[1]] for each in edges_unordered]
        edge_index = torch.tensor([edge_str, edge_end], dtype=torch.long)

        data = Data(x=x, edge_index=edge_index, y=y)

        torch.save(data, os.path.join(self.processed_dir, f'data.pt'))

    def encode_labels(self, labels):
        classes = sorted(list(set(labels)))
        labels_id = [classes.index(i) for i in labels]
        label_dict = {i: c for i, c in enumerate(classes)}
        return labels_id, label_dict

    # 定义总数据长度
    def len(self):
        idx_features_labels = np.genfromtxt(self.raw_paths[0], dtype=np.int32)
        uid = idx_features_labels[:, 0:1]
        return len(uid)

    # 定义获取数据方法
    def get(self, idx):
        data = torch.load(os.path.join(self.processed_dir, f'data.pt'))
        return data
dataset = mydataset('../data/')
data = dataset[0].to(device)

首先,我们定义了自己的一个类,mydataset类,其继承了一个父类-Dataset,这个Dataset类是PyG框架自己定义好的,其中包括数据集下载、数据预处理、数据文件保存、数据检索等等功能,大家可以详细了解一下,我们只对用到的进行解释。

# 原始文件位置
@property
def raw_file_names(self):
    return ['cora.content', 'cora.cites']

raw_file_names:指向自己的文件目录下的文件名,这个可以将你用到的文件按照列表的形式进行展现,如果用cora.content,那就是0,用cora.cites,那就是1;

@property
def processed_file_names(self):
    return 'data.pt'

processed_file_names:指向处理后的数据文件保存文件名称,可以在下次加载数据的时候,直接读取该文件;

def download(self):
    pass

download:该函数是需要去下载数据集的,因为我们是自建数据集,因此,不用;

def process(self):
	#读取cora.content文件
    idx_features_labels = np.genfromtxt(self.raw_paths[0])
    #获取节点特征
    x = idx_features_labels[:, 1:-1]
    #转为tensor,并指定数据类型
    x = torch.tensor(x, dtype=torch.float32)
    #获取每个节点的标签
    y, label_dict = self.encode_labels(np.genfromtxt(self.raw_paths[0], dtype='str', usecols=(-1,)))
    #tensor化
    y = torch.tensor(y)
    #获取每个节点
    idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
    #将每个节点映射为id(从0开始)
    id_node = {j: i for i, j in enumerate(idx)}
	#读取cora.cites
    edges_unordered = np.genfromtxt(self.raw_paths[1], dtype=np.int32)
    #获取每个节点对应的id
    #第一列节点-->id
    edge_str = [id_node[each[0]] for each in edges_unordered]
    #第二列节点-->id
    edge_end = [id_node[each[1]] for each in edges_unordered]
    #将边转为tensor
    edge_index = torch.tensor([edge_str, edge_end], dtype=torch.long)
	#将所有数据加载至Data对象中
    data = Data(x=x, edge_index=edge_index, y=y)
	#保存处理好的图数据,下次可以直接加载
    torch.save(data, os.path.join(self.processed_dir, f'data.pt'))

def encode_labels(self, labels):
    classes = sorted(list(set(labels)))
    labels_id = [classes.index(i) for i in labels]
    label_dict = {i: c for i, c in enumerate(classes)}
    return labels_id, label_dict

process:该函数是处理数据的逻辑函数,大家可以将处理数据的逻辑放在该函数中,主要是节点特征、节点标签、以及边的构成;
self.raw_paths:这个是raw_file_names返回的列表和文件路径拼接之后的结果,就是将文件名扩展为路径+文件名;

# 定义总数据长度
def len(self):
    idx_features_labels = np.genfromtxt(self.raw_paths[0], dtype=np.int32)
    uid = idx_features_labels[:, 0:1]
    return len(uid)

len:获取总数据的长度,为了进行数据分割做准备,可以自己定义;

def get(self, idx):
    data = torch.load(os.path.join(self.processed_dir, f'data.pt'))
    return data

get:制定获取图数据的方式,可以自己定义。

数据输出

在这里插入图片描述
我们可以看到,Data是一个包含所有属性的对象。
x:是27081433的矩阵,即2708个节点,每个节点有1433维;
edge_index:是一个2
5429的矩阵,表示共有5429条边;
y:表示节点的标签,共2708个节点。

数据集划分

我们构建好了自己的数据集格式,但是,进行训练的时候,必须有训练集、验证集和测试集,这块我曾经自己进行实现过,但是,实现起来比较复杂,这个时候发现,原来PyG框架,也把这块给实现了,还是很方便的。

data = T.RandomNodeSplit()(data)

在这里插入图片描述
我们可以看一下RandomNodeSplit,顾名思义,就是随机划分节点,是不是很简单,该函数可以自己划分数据集,自己也可以指定每个数据集的比例,替换其中的参数即可。
在这里插入图片描述
当我们加载完之后,可以看出Data对象中多出来三个,分别是train_mask、val_mask、test_mask,输出看的话,每个都是2708个,但是不同位置上有不同的bool值,就是为了表示该节点是否是训练集、验证集或者测试集。

结语

整体看下来,是不是对于PyG处理数据集有所了解呢,以上已经经过小编的实际运行啦,大家可以拿来改改,用在自己的开发数据集上。
当然,如果有问题或者需要补充的地方,大家可以随时联系我,QQ:1143948594。

相关文章:

  • Docker部署Logstash 7.2.0
  • Nginx -- -- 配置SSL证书
  • DID革命:详解PoP、SBT和VC三种去中心化身份方案
  • Redis与Python交互
  • 算法基础: 位运算
  • 记录一次坑 | 包版本不一致产生的问题的排查过程
  • SmartX Everoute 如何通过微分段技术实现 “零信任” | 社区成长营分享回顾
  • “相信美好,即将发生”——天泽智云
  • 面试阿里技术专家岗,对答如流,这些面试题你能答出多少
  • Spring AOP与事务
  • 时序与空间结构
  • 一幅长文细学TypeScript(一)——上手
  • DM JDBC
  • hadoop2.2.0开机启动的后台服务脚本(请结合上一篇学习)
  • java基于springboot+vue的学生成绩管理系统 elementui
  • [分享]iOS开发 - 实现UITableView Plain SectionView和table不停留一起滑动
  • [译]CSS 居中(Center)方法大合集
  • 2018天猫双11|这就是阿里云!不止有新技术,更有温暖的社会力量
  • Fastjson的基本使用方法大全
  • leetcode-27. Remove Element
  • Map集合、散列表、红黑树介绍
  • node和express搭建代理服务器(源码)
  • Odoo domain写法及运用
  • Rancher如何对接Ceph-RBD块存储
  • Spring Boot快速入门(一):Hello Spring Boot
  • vue脚手架vue-cli
  • 搭建gitbook 和 访问权限认证
  • 力扣(LeetCode)22
  • 力扣(LeetCode)357
  • 人脸识别最新开发经验demo
  • 日剧·日综资源集合(建议收藏)
  • 使用 QuickBI 搭建酷炫可视化分析
  • 微信小程序设置上一页数据
  • 写代码的正确姿势
  • 译自由幺半群
  • 3月7日云栖精选夜读 | RSA 2019安全大会:企业资产管理成行业新风向标,云上安全占绝对优势 ...
  • #stm32驱动外设模块总结w5500模块
  • $.proxy和$.extend
  • (11)MATLAB PCA+SVM 人脸识别
  • (附源码)计算机毕业设计ssm高校《大学语文》课程作业在线管理系统
  • (附源码)小程序 交通违法举报系统 毕业设计 242045
  • (解决办法)ASP.NET导出Excel,打开时提示“您尝试打开文件'XXX.xls'的格式与文件扩展名指定文件不一致
  • (利用IDEA+Maven)定制属于自己的jar包
  • (数位dp) 算法竞赛入门到进阶 书本题集
  • (原创) cocos2dx使用Curl连接网络(客户端)
  • (原創) 人會胖會瘦,都是自我要求的結果 (日記)
  • (转)Mysql的优化设置
  • ***原理与防范
  • **登录+JWT+异常处理+拦截器+ThreadLocal-开发思想与代码实现**
  • .[hudsonL@cock.li].mkp勒索加密数据库完美恢复---惜分飞
  • .jks文件(JAVA KeyStore)
  • .NET 表达式计算:Expression Evaluator
  • @Autowired自动装配
  • @DependsOn:解析 Spring 中的依赖关系之艺术
  • @Transactional类内部访问失效原因详解