当前位置: 首页 > news >正文

24/8/17算法笔记 DDPG算法

深度确定性策略梯度(DDPG)算法是一种用于解决连续动作空间强化学习问题的算法。它结合了确定性策略梯度(DPG)和深度学习技术的优点,通过Actor-Critic框架进行策略和价值函数的近似表示。DDPG算法的关键组成部分包括经验回放缓冲区、Actor-Critic神经网络、探索噪声、目标网络以及软目标更新。

DDPG算法使用两个神经网络,分别作为Actor和Critic。Actor网络负责生成策略,即在给定状态下选择最佳动作,而Critic网络评估当前策略的表现,通过Q值来衡量。经验回放缓冲区存储了与环境交互过程中产生的转换数据,这些数据用于训练网络,打破样本之间的时间相关性,提高学习效率。

DDPG算法的一个关键特性是目标网络的使用,它通过缓慢更新目标网络的参数来增加学习过程的稳定性。软更新是通过将目标网络参数设置为目标网络参数加上一小部分主网络参数的变化来实现的。

探索噪声是DDPG算法中用于平衡探索与利用的另一个重要组成部分。通过在Actor网络输出的动作上添加噪声,鼓励智能体探索环境,这有助于发现更好的策略。

在实现DDPG算法时,需要定义Actor和Critic网络,初始化目标网络,并设置优化器。训练过程中,通过从经验回放缓冲区中采样数据来更新网络参数。更新过程包括计算目标Q值、当前Q值,并分别更新Critic和Actor网络。

import gym
from matplotlib import pyplot as plt
%matplotlib inline
#创建环境
env = gym.make('Pendulum-v1')
env.reset()#打印游戏
def show():plt.imshow(env.render(mode='rgb_array'))plt.show()

action网络模型

import torchclass Model(torch.nn.Module):def __init__(self):super().__init__()self.sequential = torch.nn.Sequential(torch.nn.Linear(3,64),#第一层全连接层,将输入特征从3维映射到64维。torch.nn.ReLU(),     #ReLU激活函数,用于引入非线性。torch.nn.Linear(64,64), #第二层全连接层,将64维特征再次映射到64维。torch.nn.ReLU(),    #ReLU激活函数。torch.nn.Linear(64,1), #第三层全连接层,将64维特征映射到1维,即输出一个值。torch.nn.Tanh(), #Tanh激活函数,将输出值映射到-1到1之间。)def forward(self,state):  #是前向传播函数,它定义了如何计算网络的输出。在这里,它将输入 state 通过 self.sequential 进行处理,然后乘以2.0。return self.sequential(state)*2.0model_action = Model()
model_action_next = Model()model_action_next.load_state_dict(model_action.state_dict())
model_action(torch.randn(1,3))

value网络模型

model_value = torch.nn.Sequential(torch.nn.Linear(4,64),torch.nn.ReLU(),torch.nn.Linear(64,64),torch.nn.ReLU(),torch.nn.Linear(64,1),
)
model_value_next = torch.nn.Sequential(torch.nn.Linear(4,64),torch.nn.ReLU(),torch.nn.Linear(64,64),torch.nn.ReLU(),torch.nn.Linear(64,1),
)
model_value_next.load_state_dict(model_value.state_dict())
model_value(torch.randn(1,4))

动作函数

import random
import numpy as np
def get_action(state):state = torch.FloatTensor(state).reshape(1,3)action = model_action(state).item() #.item(): 这个方法通常用于将一个张量(tensor)转换成一个标准的Python数值。在PyTorch中,模型的输出通常是一个张量。如果你想要获取这个张量中的单个值,可以使用.item()方法。#给动作添加噪声,增加探索action +=random.normalvariate(mu=0,sigma=0.01)#高斯随机噪声return action

更新样本池函数,准备离线学习

#样本池
datas = []#向样本池中添加N条数据,删除M条最古老的数据
def update_data():#初始化游戏state = env.reset()#玩到游戏结束为止over = Falsewhile not over:#根据当前状态得到一个动作action = get_action(state)#执行动作,得到反馈next_state,reward,over,_ = env.step([action])#记录数据样本datas.append((states,action,reward,next_state,over))#更新游戏状态,开始下一个当作state = next_state#数据上限,超出时从最古老的开始删除while len(datas)>10000:datas.pop(0)

env.step(action) 是一个常用方法,用于执行给定的动作 action 并与之环境交互。

数据采样函数

#获取一批数据样本
def get_sample():samples = random.sample(datas,64)#[b,4]state = torch.FloatTensor([i[0]for i in samples]).reshape(-1,3)#[b,1]action = torch.LongTensor([i[1]for i in samples]).reshape(-1,1)#[b,1]reward = torch.FloatTensor([i[2]for i in samples]).reshape(-1,1)#[b,4]next_state = torch.FloatTensor([i[3]for i in samples]).reshape(-1,3)#[b,1]over = torch.LongTensor([i[4]for i in samples]).reshape(-1,1)return state,action,reward,next_state,overstate,action,reward,next_state,over=get_sample()state[:5],action[:5],reward[:5],next_state[:5],over[:5]

这些数据通常用于训练强化学习模型,其中状态 statenext_state 被用来输入到价值函数或策略网络中,action 是模型选择的动作,reward 是环境对动作的反馈,over 表示游戏是否结束,通常用于确定奖励的折扣因子。

测试函数

from IPython import displaydef test(play):#初始化游戏state = env.reset()#重置环境状态#记录反馈值的和,这个值越大越好reward_sum= 0#玩到游戏结束为止over =Falsewhile not over:#根据当前状态得到一个动作action = get_action(state)#执行动作,得到反馈staet,reward,over,_=env.step(action)reward_sum+=reward#打印动画if play and random.random()<0.2: #用于清除先前在输出区域的显示内容display.clear_output(wait=True)  #wait:设置为 True 时,clear_output 将等待所有异步输出完成之后再清除输出区域。这可以确保在清除之前所有输出都已经显示。    show()return  reward_sum
def get_value(state,action):#直接评估综合了state和action的valueinput = torch.cat([state,action],dim=1) #torch.cat 函数用于连接多个张量,dim 参数指定了沿着哪个维度进行连接。return model_value(input)def get_target(next_state,reward,over):#对next_state评估需要先把它对应的当作计算出来action = model_action_next(next_state)#和value的计算一样,action拼合进next_state里综合计算input = torch.act([next_state,action],dim=1)target = model_value_next(input)*0.98target *=(1-over)target +=rewardreturn target

action模型的loss

def get_loss_action(state):#首先把动作计算出来action  = model_action(state)#像value计算那里一样,拼合state和action综合计算input = torch.cat([state,action],dim = 1)#使用value网络评估动作的价值,价值越高越好#因为这里是在计算loss,loss是越小越好,所以符号取反loss =-model_value(input).mean()return loss

软更新函数,DQN使用硬更新

软更新(Soft Update)是深度强化学习中用于更新目标网络参数的一种技术。在某些强化学习算法,如深度确定性策略梯度(DDPG)算法中,会使用两个相似的网络:一个用于生成当前策略或价值函数的“主网络”(online network),以及一个“目标网络”(target network)。

目标网络的参数是主网络参数的慢速更新版本,这样做的目的是增加训练过程的稳定性。软更新的具体步骤如下:

1. **初始化**:开始时,目标网络的参数被复制或初始化为主网络参数的副本。

2. **慢速更新**:在每次训练迭代中,以一个小的比例(通常是一个小于1的因子,如0.001或0.005)更新目标网络的参数。这个更新过程可以表示为:
 
   其中, 是目标网络的参数, 是主网络的参数,而是更新比例(tau 系数)。

3. **逐步逼近**:通过这种方式,目标网络的参数会逐步逼近主网络的参数,但不会立即完全同步。这有助于减少训练过程中的震荡。

软更新的优点包括:

- **稳定性**:由于目标网络参数更新得更慢,它为训练过程提供了一定程度的稳定性。
- **减少震荡**:软更新减少了目标值的突然变化,这有助于避免训练中的大振荡。
- **平滑学习**:它允许模型在更新过程中保持平滑的学习曲线。

软更新通常在算法的每次迭代或每隔几个步骤执行一次,具体取决于算法的设计和所需的更新频率。
 

def train():model_action.train()  #设置模型为训练模式model_value.train()optimizer_action = torch.optim.Adan(model.parameters(),lr =5e-4) #创建优化器optimizer_value = torch.optim.Adam(model_td.parameters(),lr=5e-3)loss_fn = torch.nn.MSELoss()#玩N局游戏,每局训练一次for epoch in range(200):#更新N条数据update_data()for i in range(200):#玩一局游戏,得到数据states,rewards,actions,next_states,overs = get_sample()#计算values 和targetsvalues= get_value(state,action)targets = get_target(next_states,reward,over)#两者求差,计算loss,更新参数loss_value= loss_fn(values,targets)#更新参数optimizer.zero_grad()   #作用是清除(重置)模型参数的梯度loss.backward()       #反向传播计算梯度的标准方法optimizer.step()     #更新模型的参数#使用value网络评估action网络的loss,更新参数loss_action = get_loss_action(state)optimizer_td.zero_grad()  loss_td.backward()       optimizer_td.step()  #以一个小的比例更新soft_update(model_action,model_action_next)soft_update(model_value,model_value_next)if i %20 ==0:test_result = sum([test(play=False)for _ in range(10)])/10print(epoch,len(datas),test_result)

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • spark-sgg-java
  • 已解决Exception in thread “main“ java.lang.NullPointerException
  • 【数据结构题集(c语言版)】魔王语言解释 题解(字符串+栈)
  • 【JavaEE】文件操作
  • Shell——流程控制语句(if、case、for、while等)
  • SQLALchemy ORM 的关联关系之 ORM 中的一对一
  • 2024.8.17
  • 基于DPU云盘挂载的Spark优化解决方案
  • 【Linux网络】高级 I/O
  • 电脑监控怎样看回放视频?一键解锁电脑监控回放,守护安全不留死角!高效员工电脑监控,回放视频随时查!
  • mysql主从复制同步、mysql5.7版本安装配置、python操作mysql数据库、mycat读写分离实现
  • P2016 战略游戏
  • 【Python机器学习】利用PCA来简化数据——示例:利用PCA对半导体制造数据降维
  • 【书生大模型实战营(暑假场)闯关材料】基础岛:第1关 书生大模型全链路开源体系
  • Kubectl 常用命令汇总大全
  • 【编码】-360实习笔试编程题(二)-2016.03.29
  • C++入门教程(10):for 语句
  • DataBase in Android
  • eclipse的离线汉化
  • el-input获取焦点 input输入框为空时高亮 el-input值非法时
  • Java应用性能调优
  • PAT A1120
  • Python十分钟制作属于你自己的个性logo
  • React+TypeScript入门
  • Shadow DOM 内部构造及如何构建独立组件
  • spring + angular 实现导出excel
  • v-if和v-for连用出现的问题
  • Webpack4 学习笔记 - 01:webpack的安装和简单配置
  • 二维平面内的碰撞检测【一】
  • 机器学习 vs. 深度学习
  • 驱动程序原理
  • 使用API自动生成工具优化前端工作流
  • 想使用 MongoDB ,你应该了解这8个方面!
  • 走向全栈之MongoDB的使用
  • ionic异常记录
  • ​ubuntu下安装kvm虚拟机
  • ​什么是bug?bug的源头在哪里?
  • !!【OpenCV学习】计算两幅图像的重叠区域
  • # 深度解析 Socket 与 WebSocket:原理、区别与应用
  • # 数论-逆元
  • (02)vite环境变量配置
  • (3) cmake编译多个cpp文件
  • (Matalb时序预测)PSO-BP粒子群算法优化BP神经网络的多维时序回归预测
  • (Redis使用系列) Springboot 实现Redis消息的订阅与分布 四
  • (分布式缓存)Redis哨兵
  • (简单有案例)前端实现主题切换、动态换肤的两种简单方式
  • (三)Honghu Cloud云架构一定时调度平台
  • (四)stm32之通信协议
  • (算法)硬币问题
  • (学习日记)2024.04.04:UCOSIII第三十二节:计数信号量实验
  • (一)使用IDEA创建Maven项目和Maven使用入门(配图详解)
  • **《Linux/Unix系统编程手册》读书笔记24章**
  • .dat文件写入byte类型数组_用Python从Abaqus导出txt、dat数据
  • .NET MAUI学习笔记——2.构建第一个程序_初级篇
  • .Net 应用中使用dot trace进行性能诊断