当前位置: 首页 > news >正文

pytorch pyro更高阶的优化器会使用更高阶的导数,比如二阶导数(Hessian矩阵)

在机器学习和深度学习中,优化器是用来更新模型参数以最小化损失函数的算法。通常,优化器会计算损失函数相对于参数的一阶导数(梯度),然后根据这些梯度来更新参数。但是,更高阶的优化器会使用更高阶的导数,比如二阶导数(Hessian矩阵),来指导参数的更新

 

关于使用更高阶导数的优化器基类的描述。在机器学习和深度学习中,优化器是用来更新模型参数以最小化损失函数的算法。通常,优化器会计算损失函数相对于参数的一阶导数(梯度),然后根据这些梯度来更新参数。但是,更高阶的优化器会使用更高阶的导数,比如二阶导数(Hessian矩阵),来指导参数的更新。

这段描述中的关键点包括:

  1. 使用torch.autograd.grad而不是torch.Tensor.backwardtorch.autograd.grad是PyTorch中的一个函数,它可以用来计算张量相对于其他张量的导数。这与torch.Tensor.backward不同,后者是自动求导机制的一部分,通常用于计算梯度。

  2. 不同的接口:由于高阶优化器需要计算更高阶的导数,它们需要一个不同的接口。在这个接口中,step方法接受一个损失张量作为输入,并在优化器内部触发一次或多次反向传播。

  3. 派生类必须实现step方法:这意味着任何从这个基类派生的优化器类都需要提供自己的step方法实现,以计算导数并就地更新参数。

  4. 示例代码:示例展示了如何使用这种优化器。首先,通过poutine.trace获取模型的跟踪,然后计算负对数概率之和作为损失。接着,从跟踪中提取参数,并调用优化器的step方法来更新这些参数。

简而言之,这段代码描述了一个用于高级优化的基类,它允许开发者实现使用更高阶导数的自定义优化器。这种类型的优化器可能在某些情况下比传统的一阶优化器更有效,尤其是在参数更新需要更精细控制的场景中。

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0from typing import Dict, Listimport torchfrom pyro.ops.newton import newton_step
from pyro.optim.optim import PyroOptimclass MultiOptimizer:"""Base class of optimizers that make use of higher-order derivatives.Higher-order optimizers generally use :func:`torch.autograd.grad` ratherthan :meth:`torch.Tensor.backward`, and therefore require a differentinterface from usual Pyro and PyTorch optimizers. In this interface,the :meth:`step` method inputs a ``loss`` tensor to be differentiated,and backpropagation is triggered one or more times inside the optimizer.Derived classes must implement :meth:`step` to compute derivatives andupdate parameters in-place.Example::tr = poutine.trace(model).get_trace(*args, **kwargs)loss = -tr.log_prob_sum()params = {name: site['value'].unconstrained()for name, site in tr.nodes.items()if site['type'] == 'param'}optim.step(loss, params)"""def step(self, loss: torch.Tensor, params: Dict) -> None:"""Performs an in-place optimization step on parameters given adifferentiable ``loss`` tensor.Note that this detaches the updated tensors.:param torch.Tensor loss: A differentiable tensor to be minimized.Some optimizers require this to be differentiable multiple times.:param dict params: A dictionary mapping param name to unconstrainedvalue as stored in the param store."""updated_values = self.get_step(loss, params)for name, value in params.items():with torch.no_grad():# we need to detach because updated_value may depend on valuevalue.copy_(updated_values[name].detach())def get_step(self, loss: torch.Tensor, params: Dict) -> Dict:"""Computes an optimization step of parameters given a differentiable``loss`` tensor, returning the updated values.Note that this preserves derivatives on the updated tensors.:param torch.Tensor loss: A differentiable tensor to be minimized.Some optimizers require this to be differentiable multiple times.:param dict params: A dictionary mapping param name to unconstrainedvalue as stored in the param store.:return: A dictionary mapping param name to updated unconstrainedvalue.:rtype: dict"""raise NotImplementedErrorclass PyroMultiOptimizer(MultiOptimizer):"""Facade to wrap :class:`~pyro.optim.optim.PyroOptim` objectsin a :class:`MultiOptimizer` interface."""def __init__(self, optim: PyroOptim) -> None:if not isinstance(optim, PyroOptim):raise TypeError("Expected a PyroOptim object but got a {}".format(type(optim)))self.optim = optimdef step(self, loss: torch.Tensor, params: Dict) -> None:values = params.values()grads = torch.autograd.grad(loss, values, create_graph=True)  # type: ignorefor x, g in zip(values, grads):x.grad = gself.optim(values)class TorchMultiOptimizer(PyroMultiOptimizer):"""Facade to wrap :class:`~torch.optim.Optimizer` objectsin a :class:`MultiOptimizer` interface."""def __init__(self, optim_constructor: torch.optim.Optimizer, optim_args: Dict):optim = PyroOptim(optim_constructor, optim_args)super().__init__(optim)class MixedMultiOptimizer(MultiOptimizer):"""Container class to combine different :class:`MultiOptimizer` instances fordifferent parameters.:param list parts: A list of ``(names, optim)`` pairs, where each``names`` is a list of parameter names, and each ``optim`` is a:class:`MultiOptimizer` or :class:`~pyro.optim.optim.PyroOptim` objectto be used for the named parameters. Together the ``names`` shouldpartition up all desired parameters to optimize.:raises ValueError: if any name is optimized by multiple optimizers."""def __init__(self, parts: List) -> None:optim_dict: Dict = {}self.parts = []for names_part, optim in parts:if isinstance(optim, PyroOptim):optim = PyroMultiOptimizer(optim)for name in names_part:if name in optim_dict:raise ValueError("Attempted to optimize parameter '{}' by two different optimizers: ""{} vs {}".format(name, optim_dict[name], optim))optim_dict[name] = optimself.parts.append((names_part, optim))def step(self, loss: torch.Tensor, params: Dict):for names_part, optim in self.parts:optim.step(loss, {name: params[name] for name in names_part})def get_step(self, loss: torch.Tensor, params: Dict) -> Dict:updated_values = {}for names_part, optim in self.parts:updated_values.update(optim.get_step(loss, {name: params[name] for name in names_part}))return updated_valuesclass Newton(MultiOptimizer):"""Implementation of :class:`MultiOptimizer` that performs a Newton updateon batched low-dimensional variables, optionally regularizing via aper-parameter ``trust_radius``. See :func:`~pyro.ops.newton.newton_step`for details.The result of :meth:`get_step` will be differentiable, however theupdated values from :meth:`step` will be detached.:param dict trust_radii: a dict mapping parameter name to radius of trustregion. Missing names will use unregularized Newton update, equivalentto infinite trust radius."""def __init__(self, trust_radii: Dict = {}):self.trust_radii = trust_radiidef get_step(self, loss: torch.Tensor, params: Dict):updated_values = {}for name, value in params.items():trust_radius = self.trust_radii.get(name)  # type: ignoreupdated_value, cov = newton_step(loss, value, trust_radius)updated_values[name] = updated_valuereturn updated_values

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • HTTP、Session、Token及Cookie详解
  • 【Unity优化】优化Android平台拖动地图表现
  • 使用ElementUI + Vue框架实现学生管理系统前端页面设计
  • java中数据访问层userdao接口怎么写
  • IDEA项目启动在不同端口的方法,服务多端口启动
  • CSS3 文本效果(text-shadow,box-shadow,white-space等)文本溢出隐藏并且显示省略号
  • 系统分析师6:计算机网络
  • 46. 把数字翻译成字符串【难】
  • 【软件测试专栏】测试分类篇
  • 【Android】 工具篇:ProxyPin抓包详解---夜神模拟器
  • Elasticsearch检索原理
  • 详解Asp.Net Core管道模型中的五种过滤器的适用场景与用法
  • 人活着的意义是什么
  • [NOI2014] 魔法森林(LCT维护MST)
  • Spring Boot 多数据源配置(JPA)
  • HashMap ConcurrentHashMap
  • JavaSE小实践1:Java爬取斗图网站的所有表情包
  • Js实现点击查看全文(类似今日头条、知乎日报效果)
  • scrapy学习之路4(itemloder的使用)
  • Spark VS Hadoop:两大大数据分析系统深度解读
  • Vue2 SSR 的优化之旅
  • vue2.0一起在懵逼的海洋里越陷越深(四)
  • 产品三维模型在线预览
  • 互联网大裁员:Java程序员失工作,焉知不能进ali?
  • 目录与文件属性:编写ls
  • 算法---两个栈实现一个队列
  • 限制Java线程池运行线程以及等待线程数量的策略
  • 新版博客前端前瞻
  • 延迟脚本的方式
  • 一、python与pycharm的安装
  • 做一名精致的JavaScripter 01:JavaScript简介
  • No resource identifier found for attribute,RxJava之zip操作符
  • 没有任何编程基础可以直接学习python语言吗?学会后能够做什么? ...
  • ​iOS实时查看App运行日志
  • ​水经微图Web1.5.0版即将上线
  • # 飞书APP集成平台-数字化落地
  • #、%和$符号在OGNL表达式中经常出现
  • #ifdef 的技巧用法
  • (CPU/GPU)粒子继承贴图颜色发射
  • (python)数据结构---字典
  • (react踩过的坑)Antd Select(设置了labelInValue)在FormItem中initialValue的问题
  • (八)Spring源码解析:Spring MVC
  • (二)WCF的Binding模型
  • (非本人原创)史记·柴静列传(r4笔记第65天)
  • (附源码)计算机毕业设计ssm电影分享网站
  • (汇总)os模块以及shutil模块对文件的操作
  • (九)One-Wire总线-DS18B20
  • (算法)Game
  • (贪心) LeetCode 45. 跳跃游戏 II
  • (小白学Java)Java简介和基本配置
  • (译) 函数式 JS #1:简介
  • (转)es进行聚合操作时提示Fielddata is disabled on text fields by default
  • (转)VC++中ondraw在什么时候调用的
  • ******之网络***——物理***
  • *Algs4-1.5.25随机网格的倍率测试-(未读懂题)