当前位置: 首页 > news >正文

Pytorch将标签转为One-Hot编码

一、标签映射与One-Hot编码过程

先进行标签映射,要为每个分类建立一个整数索引,对于每个样本的标签,使用整数索引创建一个长度为类别总数的二进制向量。这个向量的所有元素都是0,除了与整数索引相对应的位置,该位置的值为1。

二、pytorch的官方实现

在pytorch中实现了one hot编码,就在torch.nn.functional里面,下面是它的注释当中的示例,我们开看看:

Examples:>>> F.one_hot(torch.arange(0, 5) % 3)tensor([[1, 0, 0],[0, 1, 0],[0, 0, 1],[1, 0, 0],[0, 1, 0]])>>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5)tensor([[1, 0, 0, 0, 0],[0, 1, 0, 0, 0],[0, 0, 1, 0, 0],[1, 0, 0, 0, 0],[0, 1, 0, 0, 0]])>>> F.one_hot(torch.arange(0, 6).view(3,2) % 3)tensor([[[1, 0, 0],[0, 1, 0]],[[0, 0, 1],[1, 0, 0]],[[0, 1, 0],[0, 0, 1]]])

我们可以根据那自己实现的与它给出的这个示例进行比对,一样就当然没问题了。

三、手写实现

首先,在原先的函数(one_hot)当中numclass=-1,类别当然不能为1,说明这里是自动进行了计算,大家普遍使用的方式都是创建一个全零矩阵,使用 scatter_ 函数进行独热编码,作用是按照给定的索引,在指定的维度上进行赋值。

def one_hot(labels, num_classes=-1):"""将标签转为独热编码, 经过测试与torch.nn.functional里面的函数测试相同:param labels: 标签:param num_classes: 默认为-1, 表示进行自动计算类别最大的那个Examples:>>> label_1 = torch.arange(0, 5) % 3# tensor([0, 1, 2, 0, 1])>>> label_2 = torch.arange(0, 6).view(3, 2) % 3# tensor([[0, 1], [2, 0], [1, 2]])>>> print(one_hot(label_1))tensor([[1, 0, 0],[0, 1, 0],[0, 0, 1],[1, 0, 0],[0, 1, 0]])>>> print(one_hot(label_1, 5))tensor([[1, 0, 0, 0, 0],[0, 1, 0, 0, 0],[0, 0, 1, 0, 0],[1, 0, 0, 0, 0],[0, 1, 0, 0, 0]])>>> print(one_hot(label_2))tensor([[[1, 0, 0],[0, 1, 0]],[[0, 0, 1],[1, 0, 0]],[[0, 1, 0],[0, 0, 1]]])"""if num_classes == -1:num_classes = int(labels.max()) + 1one_hot_tensor = torch.zeros(labels.size() + (num_classes,), dtype=torch.int64)one_hot_tensor.scatter_(-1, labels.unsqueeze(-1).to(torch.int64), 1)return one_hot_tensorlabel_1 = torch.arange(0, 5) % 3
# tensor([0, 1, 2, 0, 1])
label_2 = torch.arange(0, 6).view(3, 2) % 3
# tensor([[0, 1], [2, 0], [1, 2]])
print(one_hot(label_1))
print(one_hot(label_1, 5))
print(one_hot(label_2))

首先是判断分类数是不是为-1,如果是就根据其中的最大值+1进行自动计算。然后创建一个契合分类数量的全零矩阵。

在这里,labels.unsqueeze(-1)用于在标签的最后一个维度上添加一个维度,以便与独热编码张量进行广播操作。

假设原始的 labels 张量的形状为 (batch_size,),那么经过 unsqueeze(-1) 操作后,形状变为 (batch_size, 1)。这样,每个样本的标签都被表示为一个列向量,而不再是一个标量。scatter_函数在最后一个维度进行操作,也就是对类别总数的维度进行操作,而 1 是要赋给相应位置的值。

labels.unsqueeze(-1) 已经确保了与 one_hot_tensor 的形状匹配,所以在这里能够正确地进行广播和赋值操作。

下面这一种是应用于分割网络当中,在保留输入标签张量形状的同时,将独热编码张量的最后一个维度设置为分类数num_classes,确保独热编码张量与输入标签张量具有相同的形状。

def get_one_hot(labels, num_classes=-1):"""用于分割网络的one hot"""labels = torch.as_tensor(labels)ones = one_hot(labels, num_classes)return ones.view(*labels.size(), num_classes)if __name__=="__main__":seg_labels = torch.randint(0, 3, size=[512, 512])print(get_one_hot(seg_labels))print(get_one_hot(seg_labels).shape)   # torch.Size([512, 512, 3])

你可以将这里应用于自定义dataset部分。

相关文章:

  • 模型的权值平均的原理和Pytorch的实现
  • Spark与云存储的集成:S3、Azure Blob Storage
  • 基于JavaWeb+BS架构+SpringBoot+Vue协同过滤算法的体育商品推荐系统的设计和实现
  • 2023年全国职业院校技能大赛(高职组)“云计算应用”赛项赛卷⑦
  • 【Qt之Quick模块】8. Quick基础、布局管理、布局管理器
  • U-Boot学习(2):U-Boot编译和.config配置文件生成分析
  • 一、Mybatis 简介
  • C //练习 5-4 编写函数strend(s, t)。如果字符串t出现在字符串s的尾部,该函数返回1;否则返回0。
  • 微信小程序:发送小程序订阅消息
  • PostgreSQL 低级错误集锦 (不定时更新)
  • 10个提高 Python Web 开发效率的VS Code插件
  • 大气精美网站APP官网HTML源码
  • HarmonyOS 容器组件(Column Row Flex)
  • 前端基础 keep-alive的使用(Vue)
  • 基于JAVA+SpringBoot的高校学术报告系统
  • 【391天】每日项目总结系列128(2018.03.03)
  • 【css3】浏览器内核及其兼容性
  • 【跃迁之路】【735天】程序员高效学习方法论探索系列(实验阶段492-2019.2.25)...
  • Bytom交易说明(账户管理模式)
  • CODING 缺陷管理功能正式开始公测
  • ES6之路之模块详解
  • Hexo+码云+git快速搭建免费的静态Blog
  • idea + plantuml 画流程图
  • JavaScript设计模式与开发实践系列之策略模式
  • Joomla 2.x, 3.x useful code cheatsheet
  • JSONP原理
  • nginx 负载服务器优化
  • nodejs调试方法
  • Protobuf3语言指南
  • Python学习笔记 字符串拼接
  • SpiderData 2019年2月16日 DApp数据排行榜
  • 猫头鹰的深夜翻译:JDK9 NotNullOrElse方法
  • 如何打造100亿SDK累计覆盖量的大数据系统
  • 深度学习在携程攻略社区的应用
  • 思考 CSS 架构
  • 线性表及其算法(java实现)
  • 转载:[译] 内容加速黑科技趣谈
  • 我们雇佣了一只大猴子...
  • ​水经微图Web1.5.0版即将上线
  • # C++之functional库用法整理
  • #pragma预处理命令
  • (10)Linux冯诺依曼结构操作系统的再次理解
  • (每日持续更新)jdk api之FileFilter基础、应用、实战
  • (三)mysql_MYSQL(三)
  • (小白学Java)Java简介和基本配置
  • (新)网络工程师考点串讲与真题详解
  • ***监测系统的构建(chkrootkit )
  • ... 是什么 ?... 有什么用处?
  • .bat批处理(十):从路径字符串中截取盘符、文件名、后缀名等信息
  • .h头文件 .lib动态链接库文件 .dll 动态链接库
  • .NET Compact Framework 多线程环境下的UI异步刷新
  • .net core 实现redis分片_基于 Redis 的分布式任务调度框架 earth-frost
  • .NET/C# 项目如何优雅地设置条件编译符号?
  • .net利用SQLBulkCopy进行数据库之间的大批量数据传递
  • .Net下的签名与混淆