当前位置: 首页 > news >正文

【pytorch】封装 optimizer实现 “梯度截断” 与 “学习率下调”

文章目录

  • 参考代码
  • 初始化
  • 梯度截断
  • 下调学习率

参考代码

https://github.com/laiguokun/LSTNet

初始化

import math
import torch.optim as optim

class Optim(object):

    def _makeOptimizer(self):
        if self.method == 'sgd':
            self.optimizer = optim.SGD(self.params, lr=self.lr)
        elif self.method == 'adagrad':
            self.optimizer = optim.Adagrad(self.params, lr=self.lr)
        elif self.method == 'adadelta':
            self.optimizer = optim.Adadelta(self.params, lr=self.lr)
        elif self.method == 'adam':
            self.optimizer = optim.Adam(self.params, lr=self.lr)
        else:
            raise RuntimeError("Invalid optim method: " + self.method)

    def __init__(self, params, method, lr, max_grad_norm, lr_decay=1, start_decay_at=None):
        self.params = list(params)  # careful: params may be a generator
        self.last_ppl = None
        self.lr = lr
        self.max_grad_norm = max_grad_norm
        self.method = method
        self.lr_decay = lr_decay
        self.start_decay_at = start_decay_at
        self.start_decay = False

        self._makeOptimizer()

梯度截断

使得梯度向量的 L 2 L2 L2 范数不会超过 self.max_grad_norm

    def step(self):
        # Compute gradients norm.
        grad_norm = 0
        for param in self.params:
            grad_norm += math.pow(param.grad.data.norm(), 2)

        grad_norm = math.sqrt(grad_norm)
        if grad_norm > 0:
            shrinkage = self.max_grad_norm / grad_norm
        else:
            shrinkage = 1.

        for param in self.params:
            if shrinkage < 1:
                param.grad.data.mul_(shrinkage)

        self.optimizer.step()
        return grad_norm

下调学习率

只有当 达到下调时间点self.start_decay_at或 验证集上表现没有提高ppl > self.last_ppl 时下调学习率

    # decay learning rate if validation performance does not improve or we hit the start_decay_at limit
    def updateLearningRate(self, ppl, epoch):
        if self.start_decay_at is not None and epoch >= self.start_decay_at:
            self.start_decay = True
        if self.last_ppl is not None and ppl > self.last_ppl:
            self.start_decay = True

        if self.start_decay:
            self.lr = self.lr * self.lr_decay
            print("Decaying learning rate to %g" % self.lr)
        #only decay for one epoch
        self.start_decay = False

        self.last_ppl = ppl

        self._makeOptimizer()

使用时

 for epoch in tqdm(range(1, args.epochs+1)):

	train_loss = train(Data, Data.train[0], Data.train[1], model, criterion, optim, args.batch_size)
	val_loss, val_rae, val_corr = evaluate(Data, Data.valid[0], Data.valid[1], model, evaluateL2, evaluateL1, args.batch_size);
	print('| end of epoch {:3d} | time: {:5.2f}s | train_loss {:5.4f} | valid rse {:5.4f} | valid rae {:5.4f} | valid corr  {:5.4f}'.format(epoch, (time.time() - epoch_start_time), train_loss, val_loss, val_rae, val_corr))
      
	# Save the model if the validation loss is the best we've seen so far.
	if val_loss < best_val:
		with open(args.save, 'wb') as f:
			torch.save(model, f)
		best_val = val_loss
      
	if epoch % 5 == 0:
		test_acc, test_rae, test_corr  = evaluate(Data, Data.test[0], Data.test[1], model, evaluateL2, evaluateL1, args.batch_size);
		print ("test rse {:5.4f} | test rae {:5.4f} | test corr {:5.4f}".format(test_acc, test_rae, test_corr))

	optim.updateLearningRate(val_loss, epoch)  # <<< here!

相关文章:

  • 赫连勃勃
  • LSTNet
  • Windows Embedded征文比赛
  • 【pytorch】用 GRU 做时间序列预测
  • 时间序列问题与自然语言处理的区别
  • 周日-购书记录---五道口光合作用
  • 时间序列特征提取 —— 获取日期相关的协变量
  • c#中高效的excel导入sqlserver的方法
  • DeepGLO
  • 《梦断代码》上市
  • Multi-Horizon Time Series Forecasting with Temporal Attention Learning
  • 网络互联设备之区别详解
  • Quantile RNN
  • 正式开始homeR的计划
  • 非线性状态空间模型与非线性自回归模型的联系
  • 【node学习】协程
  • 0基础学习移动端适配
  • Android Studio:GIT提交项目到远程仓库
  • Docker: 容器互访的三种方式
  • docker-consul
  • extjs4学习之配置
  • iOS 颜色设置看我就够了
  • Kibana配置logstash,报表一体化
  • Laravel核心解读--Facades
  • MyEclipse 8.0 GA 搭建 Struts2 + Spring2 + Hibernate3 (测试)
  • Redis提升并发能力 | 从0开始构建SpringCloud微服务(2)
  • SSH 免密登录
  • thinkphp5.1 easywechat4 微信第三方开放平台
  • Webpack入门之遇到的那些坑,系列示例Demo
  • weex踩坑之旅第一弹 ~ 搭建具有入口文件的weex脚手架
  • 半理解系列--Promise的进化史
  • 函数式编程与面向对象编程[4]:Scala的类型关联Type Alias
  • 缓存与缓冲
  • 力扣(LeetCode)22
  • 如何用Ubuntu和Xen来设置Kubernetes?
  • 设计模式 开闭原则
  • 实战:基于Spring Boot快速开发RESTful风格API接口
  • 数据库写操作弃用“SELECT ... FOR UPDATE”解决方案
  • 因为阿里,他们成了“杭漂”
  • ### Error querying database. Cause: com.mysql.jdbc.exceptions.jdbc4.CommunicationsException
  • (12)目标检测_SSD基于pytorch搭建代码
  • (2)(2.10) LTM telemetry
  • (6)【Python/机器学习/深度学习】Machine-Learning模型与算法应用—使用Adaboost建模及工作环境下的数据分析整理
  • (a /b)*c的值
  • (day 2)JavaScript学习笔记(基础之变量、常量和注释)
  • (附源码)springboot课程在线考试系统 毕业设计 655127
  • (力扣)1314.矩阵区域和
  • (力扣题库)跳跃游戏II(c++)
  • (十一)c52学习之旅-动态数码管
  • .360、.halo勒索病毒的最新威胁:如何恢复您的数据?
  • .NET 8.0 发布到 IIS
  • .NET(C#、VB)APP开发——Smobiler平台控件介绍:Bluetooth组件
  • /run/containerd/containerd.sock connect: connection refused
  • @Transactional类内部访问失效原因详解
  • [ 隧道技术 ] 反弹shell的集中常见方式(二)bash反弹shell