当前位置: 首页 > news >正文

Pytorch深度强化学习案例:基于DQN实现Flappy Bird游戏与分析

目录

  • 1 案例介绍
  • 2 构造深度Q网络
  • 3 经验回放与目标网络
  • 4 训练流程
  • 5 实验分析

1 案例介绍

Flappy Bird是一款由来自越南的独立游戏开发者Dong Nguyen所开发的作品,于2013年5月24日上线。

Flappy Bird中,玩家只需要用一根手指来操控:点击一次屏幕,小鸟就会往上飞一次,不断地点击就会使小鸟不断往高处飞。放松手指,小鸟则会快速下降。所以玩家要控制小鸟一直向前飞行,然后注意躲避途中高低不平的管子。小鸟每安全穿过一个水管得1分,若撞上水管则游戏失败。

如图所示是用强化学习模型DQN训练AI完成Flappy Bird游戏的案例,接下来具体分析如何实现这个案例

在这里插入图片描述

2 构造深度Q网络

深度Q网络(Deep Q-Network, DQN)的核心原理是通过

  • 经验回放池
  • 目标网络

拟合高维状态空间,是Q-Learning算法的深度学习版本。具体理论参考Pytorch深度强化学习(八):基于价值的强化学习——DQN算法

具体到Flappy Bird游戏,结构如图所示:设置网络输入为游戏的连续四帧图片,使用卷积神经网络提取状态特征,最后输出为一个布尔值,即小鸟选择的动作——向上飞或下降。

在这里插入图片描述
实现如下

class DeepQNetwork(nn.Module):
    def __init__(self):
        super(DeepQNetwork, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True))

        self.fc1 = nn.Sequential(nn.Linear(7 * 7 * 64, 512), nn.ReLU(inplace=True))
        self.fc2 = nn.Linear(512, 2)

    def forward(self, input):
        output = self.conv1(input)
        output = self.conv2(output)
        output = self.conv3(output)
        output = output.view(output.size(0), -1)
        output = self.fc1(output)
        output = self.fc2(output)

        return output

3 经验回放与目标网络

考虑到强化学习采样的是连续非静态样本,样本间的相关性导致网络参数并非独立同分布,使训练过程难以收敛,因此设置经验池存储样本,再通过随机采样去除相关性。经验回放池的设置、存储与采样如下

replay_memory = []

# 将<s, a, r, s'>添加到经验回放池
replay_memory.append([state, action, reward, next_state, terminal])
if len(replay_memory) > opt["replay_memory_size"]:
    del replay_memory[0]

# 采样一个batch的数据
batch = sample(replay_memory, min(len(replay_memory), opt["batch_size"]))
state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = zip(*batch)

考虑到若目标价值与当前价值 是同一个网络时会导致优化目标不断变化,产生模型振荡与发散,因此构建结构相同但慢于更新的独立目标网络来评估目标价值,使模型更稳定

# 采用的网络
self.model = DQN(env.observation_space.shape, env.action_space.n).to(self.device)
self.target_model = DQN(env.observation_space.shape, env.action_space.n).to(self.device)
for target_param, param in zip(self.target_model.parameters(), self.model.parameters()):
    target_param.data.copy_(param)

# 更新target网络
for target_param, param in zip(self.target_model.parameters(), self.model.parameters()):
    target_param.data.copy_(self.tau * param + (1 - self.tau) * target_param)

4 训练流程

除了与环境的交互采样强化学习思想,其余步骤与深度学习训练相同

# 实例化DQN模型
model = DeepQNetwork()

# 设置优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=opt["lr"])
criterion = nn.MSELoss()

# 初始化环境
game_state = FlappyBird()
image, reward, terminal = game_state.step(0)
image = preProcessing(image[:game_state.screen_width, :int(game_state.base_y)], opt["image_size"], opt["image_size"])
image = torch.from_numpy(image)

# 获得状态, 将图片化为batch x in_channel x h x w
state = torch.cat(tuple(image for _ in range(4)))[None, :, :, :]
replay_memory = []

# 开始迭代
with tqdm(range(opt["num_iters"])) as bar:
    for i in bar:
        prediction = model(state)[0]
        
        # 动态调整贪心概率并执行贪心算法
        epsilon = opt["final_epsilon"] + (
                (opt["num_iters"] - i) * (opt["initial_epsilon"] - opt["final_epsilon"]) / opt["num_iters"])
        action = randint(0, 1) if random() <= epsilon else torch.argmax(prediction)

        # 获取下一个状态(时序差分)
        next_image, reward, terminal = game_state.step(action)
        next_image = preProcessing(next_image[:game_state.screen_width, :int(game_state.base_y)],
                        opt["image_size"], opt["image_size"])
        next_image = torch.from_numpy(next_image)
        next_state = torch.cat((state[0, 1:, :, :], next_image))[None, :, :, :]

        # 将<s, a, r, s'>添加到经验回放池
        ...
        
        # 采样一个batch的数据
        ...

        # 目标网络为训练样本添加标注信息,并与当前值网络做损失
        current_prediction_batch = model(state_batch)
        next_prediction_batch = model(next_state_batch)
        y_batch = torch.cat(
            tuple(reward if terminal else reward + opt["gamma"] * torch.max(prediction) for reward, terminal, prediction in
                zip(reward_batch, terminal_batch, next_prediction_batch)))
        q_value = torch.sum(current_prediction_batch * action_batch, dim=1)

        # 梯度优化
        optimizer.zero_grad()
        # y_batch = y_batch.detach()
        loss = criterion(q_value, y_batch)
        loss.backward()
        optimizer.step()

        state = next_state

5 实验分析

训练2000代的奖赏曲线如图所示,左侧是验证集曲线,右侧是训练集曲线,可见随着训练过程进行,模型得到的奖励在不断上升

在这里插入图片描述
刚开始训练时的效果可视化

在这里插入图片描述

模型收敛后的效果可视化(200万次迭代),AI已经可以很好地掌握这款游戏了

在这里插入图片描述

本文完整工程代码请联系下方博主名片获取


🔥 更多精彩专栏

  • 《ROS从入门到精通》
  • 《Pytorch深度学习实战》
  • 《机器学习强基计划》
  • 《路径规划实战精讲》

👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇

相关文章:

  • Spring Cloud Sentinel(限流、熔断、降级)、SpringBoot整合Sentinel、Sentinel的使用-60
  • 极限存在准则 两个重要极限——“高等数学”
  • 蓝桥杯重点(C/C++)(随时更新,更新时间:2023.1.27)
  • 【计算机网络】应用层体系
  • 极限运算法则——“高等数学”
  • PHP代码审计之MVC与ThinkPHP简介
  • 【教程】Python实时检测CPU和GPU的功耗
  • Windows压缩工具 “ Bandizip 与 7-zip ”
  • Windows卸载与清除工具 “ Geek 与 CCleaner ”
  • Git速成指南
  • 逆序遍历List集合
  • VSCode配置C/C++环境
  • 芒果改进YOLOv7系列:超越ConvNeXt结构,原创结合Conv2Former改进结构,Transformer 风格的卷积网络视觉基线模型,高效涨点
  • Opencv调参神器——trackBar控件
  • 自动驾驶环境感知——视觉传感器技术
  • python3.6+scrapy+mysql 爬虫实战
  • [译]Python中的类属性与实例属性的区别
  • AzureCon上微软宣布了哪些容器相关的重磅消息
  • github从入门到放弃(1)
  • Java 网络编程(2):UDP 的使用
  • js如何打印object对象
  • k8s如何管理Pod
  • Linux gpio口使用方法
  • Linux中的硬链接与软链接
  • Netty 框架总结「ChannelHandler 及 EventLoop」
  • node-glob通配符
  • SpiderData 2019年2月23日 DApp数据排行榜
  • Terraform入门 - 3. 变更基础设施
  • 读懂package.json -- 依赖管理
  • 二维平面内的碰撞检测【一】
  • 关于Java中分层中遇到的一些问题
  • 湖南卫视:中国白领因网络偷菜成当代最寂寞的人?
  • 检测对象或数组
  • 如何在 Tornado 中实现 Middleware
  • 使用API自动生成工具优化前端工作流
  • 数据可视化之 Sankey 桑基图的实现
  • 推荐一款sublime text 3 支持JSX和es201x 代码格式化的插件
  • 一个6年java程序员的工作感悟,写给还在迷茫的你
  • d²y/dx²; 偏导数问题 请问f1 f2是什么意思
  • HanLP分词命名实体提取详解
  • 如何正确理解,内页权重高于首页?
  • ​中南建设2022年半年报“韧”字当头,经营性现金流持续为正​
  • #多叉树深度遍历_结合深度学习的视频编码方法--帧内预测
  • #我与虚拟机的故事#连载20:周志明虚拟机第 3 版:到底值不值得买?
  • (八)Docker网络跨主机通讯vxlan和vlan
  • (超详细)2-YOLOV5改进-添加SimAM注意力机制
  • (附源码)springboot美食分享系统 毕业设计 612231
  • (附源码)ssm航空客运订票系统 毕业设计 141612
  • (附源码)ssm户外用品商城 毕业设计 112346
  • (附源码)小程序 交通违法举报系统 毕业设计 242045
  • (深入.Net平台的软件系统分层开发).第一章.上机练习.20170424
  • (转)fock函数详解
  • ./configure、make、make install 命令
  • .bat批处理(八):各种形式的变量%0、%i、%%i、var、%var%、!var!的含义和区别
  • .net core使用RPC方式进行高效的HTTP服务访问