当前位置: 首页 > news >正文

概率神经网络_深度学习算法(第33期)强化学习之神经网络策略学习平衡车

上期我们一起学习了强化学习入门的相关知识,深度学习算法(第32期)----强化学习入门必读  

今天我们学习下OpenAI工具包以及神经网络策略学习平衡车的相关知识。

OpenAI Gym 介绍

强化学习的一个挑战就是为了训练智能体,我们必须有一个工作环境。如果想整一个学习玩Atari游戏的智能体,那么必须得有一个Atati游戏的模拟器。如果我们想训练一个自动行走的机器人,那么这个环境就是现实世界,我们可以直接去训练它去适应这个环境,但是这有一定的局限性,如果机器人从悬崖上掉下来,你不能仅仅点击“撤消”。你也不能加快时间;增加更多的计算能力不会让机器人移动得更快。一般来说,同时训练1000个机器人是非常昂贵的。简而言之,训练在现实世界中是困难和缓慢的,所以你通常需要一个模拟环境,至少需要引导训练。

OpenAI gym是一个提供各种游戏模拟环境(包括Atari游戏,棋盘游戏,2D-3D物理模拟等)的工具包。所以我们可以训练智能体,比较并开发新的强化学习算法。

首先,我们用pip安装OpenAI gym:

$ pip install --upgrade gym

接下来就是创建环境了:

>>> import gym
>>> env = gym.make("CartPole-v0")
[2016-10-14 16:03:23,199] Making new env: MsPacman-v0
>>> obs = env.reset()
>>> obs
array([-0.03799846,-0.03288115,0.02337094,0.00720711])
>>> env.render()

上面的make()函数创建了一个新环境,这是一个CartPole环境。这是一个2D的模拟环境,其中可以通过左右加速小推车,来平衡小车上的平衡杆,如下图: c8d6ac3d84de5a579b32c6004f893310.png环境创建之后,我们需要用reset()函数去初始化,这将返回第一次观测结果,观测结果取决于环境类型。对于这个CartPole环境,每一次观测都是一个1x4的数组,4个浮点数据分别表示小车的水平位置(原点为中心),速度,平衡杆的角度(竖直为0度)及其角速度。最后的render()函数显示如上图的环境。

如果你想让render()将返回图像以一个Numpy数组格式返回,可以将mode参数设置为rgb_array(注意其他环境可能支持不同的模式):

>>> img = env.render(mode="rgb_array")
>>> img.shape # height, width, channels (3=RGB)
(400, 600, 3)

然而,即使将mode参数设置为rgb_array,CartPole(和其他一些环境)还是会将将图像呈现到屏幕上。避免这种情况的唯一方式是使用一个 fake X 服务器,如 Xvfb 或 XDimMy。例如,可以安装 Xvfb 并且用下命令启动 Python:

xvfb-run -s "screen 0 1400x900x24" python

或者使用xvfbwrapper包。

现在我们有了这样一个环境,那么在环境中我们能做什么呢?试试如下指令:

>>> env.action_space
Discrete(2)

Discrete(2)意味着可能的操作是0和1,分别代表左右加速。其他环境可能有更多的动作。或者其他类型的动作(比如:连续的)。由于平衡杆向右倾斜,所以我们让小车加速向右。

>>> action = 1  # accelerate right
>>> obs, reward, done, info = env.step(action)
>>> obs
array([-0.03865608, 0.16189797, 0.02351508, -0.27801135])
>>> reward
1.0
>>> done
False
>>> info
{}

这个step()函数执行给定的动作,并返回4个值:

  • obs新的观测值,小车现在正在向右移动(obs[1]>0,注:当前速度为正,向右为正)。平衡杆仍然向右倾斜(obs[2]>0),但是他的角速度现在为负(obs[3]<0),所以它在下一步后可能会向左倾斜。

  • reward在这个环境中,无论你做什么,每一步都会得到1.0奖励,所以游戏的目标就是尽可能长的运行。

  • done当游戏结束时这个值会为True。当平衡杆倾斜太多时会发生这种情况。之后,必须重新设置环境才能重新使用。

  • info在其他环境中该字典可以提供额外的调试信息。但是这些数据不应该用于训练(这是作弊)。

接下来我们编码一个简单的策略,当杆向左倾斜时加速左边,当杆向右倾斜时加速。我们使用这个策略来获得超过500步的平均奖励:

def basic_policy(obs):
angle = obs[2]
return 0 if angle < 0 else 1

totals = []
for episode in range(500):
episode_rewards = 0
obs = env.reset()
for step in range(1000): # 最多1000步,我们不想让它永远运行下去
action = basic_policy(obs)
obs, reward, done, info = env.step(action)
episode_rewards += reward
if done:
break
totals.append(episode_rewards)

上面代码不多说,有注释,我们看看结果:

>>> import numpy as np
>>> np.mean(totals), np.std(totals), np.min(totals), np.max(totals)
(42.125999999999998, 9.1237121830974033, 24.0, 68.0)

即使有 500 次尝试,这一策略从未使平衡杆在超过68个连续的步骤里保持直立。这不太好。如果你看一下 Juyter Notebook 中的模拟,你会发现,推车越来越强烈地左右摆动,直到平衡杆倾斜太多。让我们看看神经网络是否能提出更好的策略。

神经网络策略

接下来我们创建一个神经网络策略,就像刚才写的一样。这个神经网络用观测值作为输入,将执行的动作作为输出。更精确的说,它将估计每个动作的概率,然后我们将根据估计的概率随机地选择一个动作如下图:e1ada34ab8f9eedff5ce17924156d0e4.png

在这个CartPole的环境中,只有两种可能的动作(左或右),所以我们只需要一个输出神经元。它输出动作0(向左)的概率p,动作1(向右)的概率显然将是1 - p。

例如,如果它输出0.7,那么我们将以70%的概率选择动作0,以30%的概率选择动作1。

你可能问为什么我们根据神经网络给出的概率来选择随机的动作,而不是选择最高分数的动作呢?这种概率选择的方法能够使智能体在探索新的行为和利用那些已知可行的行动之间找到正确的平衡。举个例子:假设你第一次去餐馆,所有的菜看起来同样吸引人,所以你随机挑选一个。如果菜好吃,你可以增加下一次点它的概率,但是你不应该把这个概率提高到 100%,否则你将永远不会尝试其他菜肴,其中一些甚至比你尝试的更好。

还需注意的是,在这个特定的环境中,过去的动作和观察可以被安全地忽略,因为每个观察都包含环境的完整状态。如果环境中有一些隐藏状态,那么我们就需要考虑过去的行为和观察。例如,如果环境仅仅揭示了推车的位置,而不是它的速度,那么你不仅要考虑当前的观测,还要考虑先前的观测,以便估计当前的速度。或者当观测是有噪声的情况下,通常是用过去几次的观察来估计最可能的当前状态。CartPole的问题是简单的;观测是无噪声的,而且它们包含环境的全部状态。

下面是用tensorflow来创建神经网络策略的代码:

import tensorflow as tf
from tensorflow.contrib.layers import fully_connected
# 1. 声明神经网络结构
n_inputs = 4 # == env.observation_space.shape[0]
n_hidden = 4 # 这只是个简单的测试,不需要过多的隐藏层
n_outputs = 1 # 只输出向左加速的概率
initializer = tf.contrib.layers.variance_scaling_initializer()
# 2. 建立神经网络
X = tf.placeholder(tf.float32, shape=[None, n_inputs]) hidden = fully_connected(X, n_hidden, activation_fn=tf.nn.elu,weights_initializer=initializer) # 隐层激活函数使用指数线性函数
logits = fully_connected(hidden, n_outputs, activation_fn=None,weights_initializer=initializer)
outputs = tf.nn.sigmoid(logits)
# 3. 在概率基础上随机选择动作
p_left_and_right = tf.concat(axis=1, values=[outputs, 1 - outputs])
action = tf.multinomial(tf.log(p_left_and_right), num_samples=1)
init = tf.global_variables_initializer()

我们仔细看下上面的代码:

  1. 导入库之后,我们定义了神经网络的结构,输入的数量是观测值的size,在CartPole环境中是4,我们设置了4个隐藏层,输出为1个向左加速的概率值。

  2. 接下来,我们构建了神经网络,在这个例子中,是带一个输出的多层感知机,注意输出层的激活函数用的是sigmoid激活函数,主要考虑到该函数可以输出0.0到1.0之间的概率值。如果有两个以上的可能动作,每个动作都会有一个输出神经元,相应的激活函数将使用Softmax函数。

  3. 最后,我们调用multinomial()函数来选择一个随机动作。给定每个整数的对数概率,该函数独立地采样一个(或多个)整数。例如,如果通过设置num_samples=5,令数组为[np.log(0.5), np.log(0.2), np.log(0.3)]来调用它,那么它将输出五个整数,每个整数都有 50% 的概率是0,20% 为1,30%为2。在我们的情况下,我们只需要一个整数来表示要采取的行动。由于输出仅包含向左的概率,为了概率选择,所以我们首先将 1 - output加进去,以得到包含左和右动作的概率的张量。

好了,现在我们有一个可以观察和输出动作的神经网络了,那我们怎么训练它呢?

至此,我们今天熟悉了OpenAI中平衡车的环境,以及学习了如何搭建神经网络策略,下期我们将使用Tensorflow来实现梯度策略算法,并且开始训练我们的神经网络策略。希望有些收获,欢迎留言或进社区共同交流,喜欢的话,就点个赞吧,您也可以置顶公众号,第一时间接收最新内容。


智能算法,与您携手,沉淀自己,引领AI!

8cc994142080c6f80e2cc6efca8bad0a.png

相关文章:

  • vasp 安装_安装Atomic Simulation Environment (ASE)
  • fastreport调用frf文件直接打印_不是我吹,20M的压缩文件我只用了1秒!
  • python属于哪类型的编程语言_python属于什么类型的语言
  • python excel接口测试_利用python和excel 搭建接口测试框架
  • sql 日期月转换到日_速来!9月新增雅思考点、考试日期(更新至8月16日)
  • python处理csv文件 pandas_Pandas操作CSV文件的读写实现方法
  • 组播应用场景_慧联应用研究 | 浅谈LoRaWAN组播技术和应用(市场篇)
  • python语言过渡到c语言_2020年1月编程语言排行榜:Python让位,C语言获得“2019 年度编程语言”...
  • redis 端口_Centos7 linux 下yum安装redis以及使用
  • python groupby用法_Python教程 | 数据分析系统步骤介绍!
  • input file文件上传_了解PHP文件上传相关知识
  • hystrix熔断和降级的区别_学习笔记32-Hystrix
  • python数据分析职位_python代写拉勾数据职位分析
  • sql 去除空格_10个Excel常用操作,SQL也能实现啦!附面试原题
  • java循环输入_自学JAVA每日记录(6)-欢迎指点欢迎共勉
  • Android开源项目规范总结
  • avalon2.2的VM生成过程
  • ES6简单总结(搭配简单的讲解和小案例)
  • iOS帅气加载动画、通知视图、红包助手、引导页、导航栏、朋友圈、小游戏等效果源码...
  • javascript从右向左截取指定位数字符的3种方法
  • Java多态
  • JS字符串转数字方法总结
  • miaov-React 最佳入门
  • node学习系列之简单文件上传
  • spring cloud gateway 源码解析(4)跨域问题处理
  • Spring Cloud(3) - 服务治理: Spring Cloud Eureka
  • Traffic-Sign Detection and Classification in the Wild 论文笔记
  • Vue ES6 Jade Scss Webpack Gulp
  • 关于使用markdown的方法(引自CSDN教程)
  • 世界上最简单的无等待算法(getAndIncrement)
  • 小程序上传图片到七牛云(支持多张上传,预览,删除)
  • 译自由幺半群
  • - 转 Ext2.0 form使用实例
  • [Shell 脚本] 备份网站文件至OSS服务(纯shell脚本无sdk) ...
  • 支付宝花15年解决的这个问题,顶得上做出十个支付宝 ...
  • ​VRRP 虚拟路由冗余协议(华为)
  • # Swust 12th acm 邀请赛# [ K ] 三角形判定 [题解]
  • # 数据结构
  • #我与Java虚拟机的故事#连载08:书读百遍其义自见
  • #我与Java虚拟机的故事#连载09:面试大厂逃不过的JVM
  • $con= MySQL有关填空题_2015年计算机二级考试《MySQL》提高练习题(10)
  • (04)odoo视图操作
  • (Arcgis)Python编程批量将HDF5文件转换为TIFF格式并应用地理转换和投影信息
  • (大众金融)SQL server面试题(1)-总销售量最少的3个型号的车及其总销售量
  • (论文阅读31/100)Stacked hourglass networks for human pose estimation
  • (译) 函数式 JS #1:简介
  • (转)fock函数详解
  • (转)nsfocus-绿盟科技笔试题目
  • (转)VC++中ondraw在什么时候调用的
  • (转)用.Net的File控件上传文件的解决方案
  • *_zh_CN.properties 国际化资源文件 struts 防乱码等
  • .jks文件(JAVA KeyStore)
  • .NET 使用配置文件
  • .net项目IIS、VS 附加进程调试
  • @Bean, @Component, @Configuration简析