当前位置: 首页 > news >正文

24/8/17算法笔记 模仿学习算法

模仿学习(Imitation Learning,IL)算法是强化学习领域的一个分支,它关注于让智能体通过模仿专家的行为来学习任务。模仿学习通常用于学习复杂任务,尤其是当通过传统的强化学习算法直接学习效率较低或成本较高时。以下是一些常见的模仿学习算法:

  1. 行为克隆(Behavioral Cloning, BC):

    • 这是最直接的模仿学习方法,通过简单地训练一个策略网络来模仿专家的决策。通常使用监督学习的方法,将专家的(状态,动作)对作为训练数据。
  2. 逆强化学习(Inverse Reinforcement Learning, IRL):

    • 在这种方法中,算法试图推断出专家行为背后的奖励函数。一旦学习到奖励函数,就可以使用传统的强化学习算法来训练策略。
  3. 模仿学习(Apprenticeship Learning):

    • 这种方法结合了行为克隆和逆强化学习的思想,通过在行为克隆的基础上对策略进行改进,以更好地适应任务。
  4. 对抗性模仿学习(Adversarial Imitation Learning):

    • 这种方法使用生成对抗网络(GANs)的思想,通过训练一个生成模型来模仿专家的行为,同时训练一个判别模型来区分真实和生成的行为。
  5. 专家演示再参数化(Reparameterization of Expert Demonstrations):

    • 在这种方法中,专家的演示被重新参数化,以便于策略网络可以学习如何生成类似专家的行为。
  6. 最大熵模仿学习(Maximum Entropy Inverse Reinforcement Learning):

    • 这种方法在逆强化学习的框架下,通过最大化策略的熵来鼓励探索性行为,同时学习奖励函数。
  7. 序列级模仿学习(Sequence-Level Imitation Learning):

    • 这种方法关注于模仿一系列动作,而不仅仅是单个动作,通常使用序列模型如RNN或LSTM来捕捉时间依赖性。
  8. 元模仿学习(Meta-Learning for Imitation Learning):

    • 这种方法通过元学习技术,使智能体能够快速适应新的模仿任务,即使只有少量的专家演示。

模仿学习算法的关键优势在于它们可以利用专家的知识来加速学习过程,减少试错的次数。然而,这些算法也面临一些挑战,例如如何确保学习到的策略在未见过的状态上也能表现良好,以及如何处理专家演示中的噪声或次优行为。

模仿学习在自动驾驶、机器人控制、游戏AI等领域有广泛的应用。通过模仿专家的行为,智能体可以在这些复杂任务中快速获得有效的策略。

import gym
from matplotlib import pyplot as plt
%matplotlib inline
#创建环境
env = gym.make('CartPole-v1')
env.reset()#打印游戏
def show():plt.imshow(env.render(mode='rgb_array'))plt.show()

定义ppo模型


class PPO:def __init__(self, env, model, model_td):self.env = envself.model = modelself.model_td = model_tdself.optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)self.optimizer_td = torch.optim.Adam(model_td.parameters(), lr=1e-2)self.loss_fn = torch.nn.MSELoss()def get_action(self, state):state = torch.FloatTensor(state).reshape(1, 4)prob = self.model(state)action = random.choice(range(2), weights=prob[0].tolist(), k=1)[0]return actiondef get_data(self):states, rewards, actions, next_states, overs = [], [], [], [], []state = self.env.reset()over = Falsewhile not over:action = self.get_action(state)next_state, reward, over, _ = self.env.step(action)states.append(state)rewards.append(reward)actions.append(action)next_states.append(next_state)overs.append(over)state = next_statereturn (torch.FloatTensor(states).reshape(-1, 4),torch.FloatTensor(rewards).reshape(-1, 1),torch.LongTensor(actions).reshape(-1, 1),torch.FloatTensor(next_states).reshape(-1, 4),torch.LongTensor(overs).reshape(-1, 1))def test(self, play=False):state = self.env.reset()reward_sum = 0over = Falsewhile not over:action = self.get_action(state)state, reward, over, _ = self.env.step(action)reward_sum += rewardif play and random.random() < 0.2:display.clear_output(wait=True)# 假设有一个函数 show() 用于显示游戏状态# show()return reward_sumdef train(self, epochs=1000):for epoch in range(epochs):states, rewards, actions, next_states, overs = self.get_data()values = self.model_td(states)with torch.no_grad():targets = self.model_td(next_states)targets = targets * 0.98targets = (1 - overs) * targets + rewardsdeltas = targets - valuesadvantages = self.get_advantages(deltas)advantages = torch.FloatTensor(advantages).reshape(-1, 1)old_probs = self.model(states).gather(dim=1, index=actions)for _ in range(10):new_probs = self.model(states)ratios = new_probs / old_probssurr1 = ratios * advantagessurr2 = torch.clamp(ratios, 1 - 0.1, 1 + 0.1) * advantagesloss = -torch.min(surr1, surr2).mean()self.optimizer.zero_grad()loss.backward()self.optimizer.step()# 更新价值网络loss_td = self.loss_fn(self.model_td(states), targets)self.optimizer_td.zero_grad()loss_td.backward()self.optimizer_td.step()if epoch % 100 == 0:test_result = sum([self.test(play=False) for _ in range(10)]) / 10print(epoch, test_result)

定义teacher模型

teacher = PPO()
teacher.train(*teacher.get_data())
teacher.get_action([1,2,3,4]),teacher.test(play=Fasle)

训练teacher模型

#训练teacher模型
for i in range(500):teacher.train(*teacher.get_data())if i%50==0:test_result = sum([teacher.test(play=False)for _ in range(10)])/10print(i,test_reslut)teacher.test(play=False)

获取教师数据,并且删除教师

#获取教师数据,并且删除教师
#使用训练好的模型获取一批教师数据
teacher_states, _,teacher_actions, _, _=teacher.get_date()#删除教师,只留下教师的数据就可以了
del teacherteacher_states.shape,teacher_actions.shape

创建学生模型

student = PPO()

创建鉴定器模型

#定义鉴别器网络,它的任务是鉴定一批数据是出自teaccher还是student
class Discriminator(torch.nn.Module):def __init__(self):super().__init__()self.sequential = torch.nn.Sequential(torch.nn.Linear(6,128),torch.nn.ReLU(),torch.nn.Linear(128,1),torch.NN.Sigmoid(),)def forward(self,states,actions):one_hot = torch.nn.functional.one_hot(actions.squeeze(dim = 1),num_classes = 2)cat = torch.cat([states,one_hot],dim=1)return self.sequential(cat)discriminator = Discriminator()
discriminator(torch,randn(2,4),torch.ones(2,1).log())

torch.nn.functional.one_hot 函数用于将离散的索引转换成 one-hot 编码形式。这种编码在处理分类问题或需要表示每个类别的独立特征时非常有用。

您提供的代码 one_hot = torch.nn.functional.one_hot(actions.squeeze(dim=1), num_classes=2) 执行以下操作:

  1. actions.squeeze(dim=1):这个调用首先将 actions 张量在第二个维度(dim=1)上进行挤压,移除维度为 1 的单维度。这通常用于准备 one-hot 编码,确保 actions 张量在进行 one-hot 编码之前是一维的。

  2. torch.nn.functional.one_hot(...):接着,对挤压后的 actions 张量应用 one-hot 编码。编码的类别数由 num_classes=2 指定,这意味着每个索引将被转换成一个长度为 2 的 one-hot 向量。

  3. one_hot:这是编码后得到的张量的变量名。对于每个索引值,它将在对应索引的位置上有一个 1,其余位置为 0。

例如,如果 actions 是一个包含 [0, 1] 的张量,one_hot 将是一个长度为 2 的张量,其中第一个元素是 [1, 0],第二个元素是 [0, 1]

One-hot 编码在强化学习中的策略梯度方法中很有用,特别是当你需要从连续的动作概率中采样动作时,它可以帮助你将概率转换为动作的 one-hot 编码表示。在某些实现中,这种方法可以用于计算策略网络输出的概率与实际采取的动作之间的差异。

模仿学习函数

def copy_learn():optimizer = torch.optim.Adam(discriminator.parameters(),lr = 1e-3)bce_loss = torch.nn.BCELoss()for i in range(500):#使用学生模型获取一局游戏的数据,不需要rewardstates,_,actions,next_states,overs = student.get_data()#使用鉴别器坚定两批数据是来自教师还是学生prob_teacher = discriminator(teacher_states,teacher_actions)prob_student = discriminator(states,actions)#老师的用0表示,学生的用1表示,计算二分类lossloss_teacher = bec_loss(prob_teacher,torch.zeros_like(prob_teacher))loss_student = bec_loss(prob_student,torch.ones_like(prob_student))#调整鉴别器的lossoptimizer.zero_grad()loss.backward()optimizer.step()#使用一批数据来自学生的概率作为reward,取loss再符号取反#因为鉴别器会把学生数据的概率贴近1,所以目标是让鉴别器无法分辨,这是一种对抗网络的思路reward =-prob_student.log().detach()#更新学生模型参数,使用PPO模型本身的更新方式student.train(states,rewards,actions,next_states,overs)if i%50==0:test_result =sum([student.test(play=False)for _ in range(10)])/10print(i,test_result)

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • Spring中AbstractAutowireCapableBeanFactory
  • Unity3D开发之OnCollisionXXX触发条件
  • Spring Boot集成Devtools实现热更新?
  • 8.15 day bug
  • 最佳薪酬管理系统盘点:9款优选推荐
  • 微信答题小程序产品研发-后端开发
  • 重复的子字符串 | LeetCode-459 | 字符串匹配 | KMP | 双指针
  • 融合创新:EasyCVR视频汇聚平台云计算技术与AI技术共筑雪亮工程智能防线
  • WEB漏洞-SQL注入之简要SQL注入
  • 零售业务产品系统应用架构设计(三)
  • 牛客网SQL 练习 一
  • 网络专线和IPsecVPN在使用上有什么区别?
  • TypeScript 构建工具之 webpack
  • Spring框架中ReflectionUtils类
  • 2024.8.18
  • .pyc 想到的一些问题
  • [js高手之路]搞清楚面向对象,必须要理解对象在创建过程中的内存表示
  • 【MySQL经典案例分析】 Waiting for table metadata lock
  • CSS 专业技巧
  • iOS 系统授权开发
  • java 多线程基础, 我觉得还是有必要看看的
  • java2019面试题北京
  • JS笔记四:作用域、变量(函数)提升
  • JS进阶 - JS 、JS-Web-API与DOM、BOM
  • 好的网址,关于.net 4.0 ,vs 2010
  • 基于OpenResty的Lua Web框架lor0.0.2预览版发布
  • 记一次用 NodeJs 实现模拟登录的思路
  • 区块链将重新定义世界
  • 日剧·日综资源集合(建议收藏)
  • 异常机制详解
  • 主流的CSS水平和垂直居中技术大全
  • ​ArcGIS Pro 如何批量删除字段
  • ​Spring Boot 分片上传文件
  • ​草莓熊python turtle绘图代码(玫瑰花版)附源代码
  • ​人工智能书单(数学基础篇)
  • ​十个常见的 Python 脚本 (详细介绍 + 代码举例)
  • #《AI中文版》V3 第 1 章 概述
  • #include到底该写在哪
  • #调用传感器数据_Flink使用函数之监控传感器温度上升提醒
  • (2/2) 为了理解 UWP 的启动流程,我从零开始创建了一个 UWP 程序
  • (八)Docker网络跨主机通讯vxlan和vlan
  • (论文阅读40-45)图像描述1
  • (算法)区间调度问题
  • (一一四)第九章编程练习
  • (转)Oracle存储过程编写经验和优化措施
  • (自适应手机端)响应式服装服饰外贸企业网站模板
  • *算法训练(leetcode)第四十五天 | 101. 孤岛的总面积、102. 沉没孤岛、103. 水流问题、104. 建造最大岛屿
  • .360、.halo勒索病毒的最新威胁:如何恢复您的数据?
  • .bat批处理(九):替换带有等号=的字符串的子串
  • .naturalWidth 和naturalHeight属性,
  • .net 8 发布了,试下微软最近强推的MAUI
  • .NET中 MVC 工厂模式浅析
  • @Conditional注解详解
  • @SuppressWarnings注解
  • @value 静态变量_Python彻底搞懂:变量、对象、赋值、引用、拷贝