当前位置: 首页 > news >正文

深入解析torch.load中的【map_location】参数

深入解析torch.load中的map_location参数


🌵文章目录🌵

  • 🌳引言🌳
  • 🌳map_location参数详解🌳
    • map_location参数的数据类型
    • map_location参数的使用场景
  • 🌳代码实战(详细注释)🌳
  • 🌳参考文档🌳
  • 🌳结尾🌳


🌳引言🌳

在PyTorch中,torch.load()函数是用于加载保存模型或张量数据的重要工具。当我们训练好一个深度学习模型后,通常需要将模型的参数(或称为状态字典,state_dict)保存下来,以便后续进行模型评估、继续训练或部署到其他环境中。在加载这些保存的数据时,map_location参数为我们提供了极大的灵活性,以决定这些数据应该被加载到哪个设备上。本文将详细解析map_location参数的功能和使用方法,并通过实战案例来展示其在不同场景下的应用。

🌳map_location参数详解🌳

map_location参数在torch.load()函数中扮演着至关重要的角色。它决定了从保存的文件中加载数据时应将它们映射到哪个设备上。在PyTorch中,设备可以是CPU或GPU,而GPU可以有多个,每个都有其独立的索引。map_location的灵活使用能够让我们轻松地在不同设备之间迁移模型,从而充分利用不同设备的计算优势。

map_location参数的数据类型

map_location参数的数据类型可以是:

参数类型描述示例
字符串(str)预定义的设备字符串,指定目标设备。1. 'cpu':加载到CPU上;
2. 'cuda:X':加载到索引为X的GPU上。
torch.device对象一个表示目标设备的torch.device对象。1.torch.device('cpu'):加载到CPU上;
2. torch.device('cuda:1'):加载到索引为1的GPU上。
可调用对象(callable)一个接收存储路径并返回新位置的函数。lambda storage, loc: storage.cuda(1):将每个存储对象移动到索引为1的GPU上。
字典(dict)一个将存储路径映射到新位置的字典。{'cuda:1':'cuda:0'}:将原本在GPU 1上的张量加载到GPU 0上。

map_location参数的使用场景

  1. CPU加载:当你想在CPU上加载模型时,可以设置map_location='cpu'。这适用于那些不需要GPU加速的推理任务,或者在没有GPU的环境中部署模型。

  2. 指定GPU加载:如果你有多个GPU,并且想将模型加载到特定的GPU上,可以使用'cuda:X'格式的字符串,其中X是GPU的索引。这在多GPU环境中非常有用,可以确保模型加载到指定的设备上。

  3. 自动选择GPU:如果你只想在GPU上加载模型,但不关心具体是哪一个GPU,可以设置map_location=torch.device('cuda')。这会自动选择第一个可用的GPU来加载模型。

  4. 保持原始设备:如果你想保持模型在加载时的原始设备(即如果模型原先是在GPU上训练的,就仍然在GPU上加载;如果是在CPU上,就在CPU上加载),可以使用map_location=Nonemap_location=torch.device('cpu')(对于CPU模型)和map_location=torch.device('cuda')(对于GPU模型)。

  5. 自定义映射逻辑:通过传递一个可调用对象,你可以实现更复杂的映射逻辑。例如,你可以编写一个函数,根据存储路径或模型结构来决定将模型加载到哪个设备上。这在需要根据特定条件动态选择加载设备时非常有用。

🌳代码实战(详细注释)🌳

下面将通过几个实战案例来展示map_location参数在不同场景下的应用。

案例1:从文件加载张量到CPU

# 案例1:从文件加载张量到CPU
# 使用torch.load()函数加载tensors.pt文件中的所有张量到CPU上
tensors = torch.load('tensors.pt')

案例2:指定设备加载张量

# 案例2:指定设备加载张量
# 使用torch.load()函数并指定map_location参数为CPU设备,加载tensors.pt文件中的所有张量到CPU上
tensors_on_cpu = torch.load('tensors.pt', map_location=torch.device('cpu'))

案例3:使用匿名函数指定加载位置

# 案例3:使用函数指定加载位置
# 使用torch.load()函数和map_location参数为一个lambda函数,该函数不做任何改变,保持张量原始位置(通常是CPU)
tensors_original_location = torch.load('tensors.pt', map_location=lambda storage, loc: storage)

案例4:将张量加载到指定GPU

# 案例4:将张量加载到指定GPU
# 使用torch.load()函数和map_location参数为一个lambda函数,该函数将张量移动到索引为1的GPU上
tensors_on_gpu1 = torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))

案例5:张量从一个GPU映射到另一个GPU

# 案例5:张量从一个GPU映射到另一个GPU
# 使用torch.load()函数和map_location参数为一个字典,将原本在GPU 1上的张量映射到GPU 0上
tensors_mapped = torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})

案例6:从io.BytesIO对象加载张量

# 案例6:从io.BytesIO对象加载张量
# 打开tensor.pt文件并读取内容到BytesIO缓冲区
with open('tensor.pt', 'rb') as f:buffer = io.BytesIO(f.read())# 使用torch.load()函数从BytesIO缓冲区加载张量
tensors_from_buffer = torch.load(buffer)

案例7:使用ASCII编码加载模块

# 案例7:使用ASCII编码加载模块
# 使用torch.load()函数和encoding参数为'ascii',加载module.pt文件中的模块(如神经网络模型)
model = torch.load('module.pt', encoding='ascii')

这些案例代码和注释展示了如何使用torch.load()函数的不同map_location参数和编码设置来加载张量和模型。这些设置对于控制数据加载的位置和格式非常重要,特别是在跨设备或跨平台加载数据时。


🌳参考文档🌳

[1] PyTorch官方文档


🌳结尾🌳

亲爱的读者,首先感谢抽出宝贵的时间来阅读我们的博客。我们真诚地欢迎您留下评论和意见💬
俗话说,当局者迷,旁观者清。的客观视角对于我们发现博文的不足、提升内容质量起着不可替代的作用。
如果博文给您带来了些许帮助,那么,希望能为我们点个免费的赞👍👍/收藏👇👇您的支持和鼓励👏👏是我们持续创作✍️✍️的动力
我们会持续努力创作✍️✍️,并不断优化博文质量👨‍💻👨‍💻,只为给带来更佳的阅读体验。
如果有任何疑问或建议,请随时在评论区留言,我们将竭诚为你解答~
愿我们共同成长🌱🌳,共享智慧的果实🍎🍏!


万分感谢🙏🙏点赞👍👍、收藏⭐🌟、评论💬🗯️、关注❤️💚~

相关文章:

  • 安全基础~通用漏洞4
  • Flink流式数据倾斜
  • 案例:爬取豆瓣电影 Top250 的数据
  • VBA技术资料MF117:测试显示器大小
  • 深度学习自然语言处理(NLP)模型BERT:从理论到Pytorch实战
  • 设计模式1-访问者模式
  • Linux 命令行速查表
  • Android 11 访问 Android/data/或者getExternalCacheDir() 非root方式
  • vim常用命令以及配置文件
  • centos安装inpanel
  • 按键扫描16Hz-单片机通用模板
  • PostgreSQL 与 MySQL 相比,优势何在?
  • containerd中文翻译系列(十九)cri插件
  • Java开发IntelliJ IDEA2023
  • Vue 进阶系列丨实现简易VueRouter
  • [Vue CLI 3] 配置解析之 css.extract
  • 【Redis学习笔记】2018-06-28 redis命令源码学习1
  • 【跃迁之路】【735天】程序员高效学习方法论探索系列(实验阶段492-2019.2.25)...
  • 78. Subsets
  • canvas实际项目操作,包含:线条,圆形,扇形,图片绘制,图片圆角遮罩,矩形,弧形文字...
  • CNN 在图像分割中的简史:从 R-CNN 到 Mask R-CNN
  • ECMAScript 6 学习之路 ( 四 ) String 字符串扩展
  • Java方法详解
  • Js基础知识(四) - js运行原理与机制
  • laravel with 查询列表限制条数
  • Python十分钟制作属于你自己的个性logo
  • React 快速上手 - 07 前端路由 react-router
  • Spring技术内幕笔记(2):Spring MVC 与 Web
  • TypeScript迭代器
  • 半理解系列--Promise的进化史
  • 动手做个聊天室,前端工程师百无聊赖的人生
  • 对象引论
  • 技术攻略】php设计模式(一):简介及创建型模式
  • 力扣(LeetCode)56
  • 面试遇到的一些题
  • 前端设计模式
  • 通过几道题目学习二叉搜索树
  • 小程序开发之路(一)
  • 数据可视化之下发图实践
  • ​​​​​​​GitLab 之 GitLab-Runner 安装,配置与问题汇总
  • ​什么是bug?bug的源头在哪里?
  • #、%和$符号在OGNL表达式中经常出现
  • #define MODIFY_REG(REG, CLEARMASK, SETMASK)
  • #etcd#安装时出错
  • $con= MySQL有关填空题_2015年计算机二级考试《MySQL》提高练习题(10)
  • (八)光盘的挂载与解挂、挂载CentOS镜像、rpm安装软件详细学习笔记
  • (论文阅读笔记)Network planning with deep reinforcement learning
  • (每日持续更新)jdk api之StringBufferInputStream基础、应用、实战
  • .bat批处理(十一):替换字符串中包含百分号%的子串
  • .gitignore文件—git忽略文件
  • .net 4.0发布后不能正常显示图片问题
  • .net Application的目录
  • .NET Core实战项目之CMS 第一章 入门篇-开篇及总体规划
  • .Net 高效开发之不可错过的实用工具
  • .NET版Word处理控件Aspose.words功能演示:在ASP.NET MVC中创建MS Word编辑器