• [强化学习总结6] actor-critic算法


    • actor:策略
    • critic:评估价值

    Actor-Critic 是囊括一系列算法的整体架构,目前很多高效的前沿算法都属于 Actor-Critic 算法,本章接下来将会介绍一种最简单的 Actor-Critic 算法。需要明确的是,Actor-Critic 算法本质上是基于策略的算法,因为这一系列算法的目标都是优化一个带参数的策略,只是会额外学习价值函数,从而帮助策略函数更好地学习。

    1 核心

    • 在 REINFORCE 算法中,目标函数的梯度中有一项轨迹回报(trajectory return),用于指导策略(policy, π(s | a) )的更新。REINFOCE 算法用蒙特卡洛方法来估计q(s, a)。
      • 其实就是用回报作为策略的加权值,所以这里可以推广出一个一般形式,只要一个值能作为aciton的好坏的判断,就可以做为权重。
      • 所以critic学的就是一个权重,输出是一个值。

      • 权重可以有以下这些:

    actor-critic优势:

    • 事实上,用q值或者v值本质上也是用奖励来进行指导,但是用神经网络进行估计的方法可以减小方差、提高鲁棒性。除此之外,REINFORCE 算法基于蒙特卡洛采样,只能在序列结束后进行更新,这同时也要求任务具有有限的步数,而 Actor-Critic 算法则可以在每一步之后都进行更新,并且不对任务的步数做限制。

    2 Actor-Critic

    我们将 Actor-Critic 分为两个部分:Actor(策略网络)和 Critic(价值网络)。

    • Actor 要做的是与环境交互,并在 Critic 价值函数的指导下用策略梯度学习一个更好的策略。
    • Critic 要做的是通过 Actor 与环境交互收集的数据学习一个价值函数,这个价值函数会用于判断在当前状态什么动作是好的,什么动作不是好的,进而帮助 Actor 进行策略更新。

    2.1 code

    说的再多,不如看看代码。Actor-Critic 算法

    网络

    1. class PolicyNet(torch.nn.Module):
    2. def __init__(self, state_dim, hidden_dim, action_dim):
    3. super(PolicyNet, self).__init__()
    4. self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
    5. self.fc2 = torch.nn.Linear(hidden_dim, action_dim)
    6. def forward(self, x):
    7. x = F.relu(self.fc1(x))
    8. return F.softmax(self.fc2(x), dim=1)
    9. class ValueNet(torch.nn.Module):
    10. def __init__(self, state_dim, hidden_dim):
    11. super(ValueNet, self).__init__()
    12. self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
    13. self.fc2 = torch.nn.Linear(hidden_dim, 1) ## 输出是1个值
    14. def forward(self, x):
    15. x = F.relu(self.fc1(x))
    16. return self.fc2(x)
    1. class ActorCritic:
    2. def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
    3. gamma, device):
    4. # 策略网络
    5. self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
    6. self.critic = ValueNet(state_dim, hidden_dim).to(device) # 价值网络
    7. # 策略网络优化器
    8. self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
    9. lr=actor_lr)
    10. self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
    11. lr=critic_lr) # 价值网络优化器
    12. self.gamma = gamma
    13. self.device = device
    14. def take_action(self, state):
    15. state = torch.tensor([state], dtype=torch.float).to(self.device)
    16. probs = self.actor(state)
    17. action_dist = torch.distributions.Categorical(probs)
    18. action = action_dist.sample()
    19. return action.item()
    20. def update(self, transition_dict):
    21. states = torch.tensor(transition_dict['states'],
    22. dtype=torch.float).to(self.device)
    23. actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
    24. self.device)
    25. rewards = torch.tensor(transition_dict['rewards'],
    26. dtype=torch.float).view(-1, 1).to(self.device)
    27. next_states = torch.tensor(transition_dict['next_states'],
    28. dtype=torch.float).to(self.device)
    29. dones = torch.tensor(transition_dict['dones'],
    30. dtype=torch.float).view(-1, 1).to(self.device)
    31. # 时序差分目标
    32. td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
    33. td_delta = td_target - self.critic(states) # 时序差分误差
    34. log_probs = torch.log(self.actor(states).gather(1, actions))
    35. actor_loss = torch.mean(-log_probs * td_delta.detach())
    36. # 均方误差损失函数
    37. critic_loss = torch.mean(
    38. F.mse_loss(self.critic(states), td_target.detach()))
    39. self.actor_optimizer.zero_grad()
    40. self.critic_optimizer.zero_grad()
    41. actor_loss.backward() # 计算策略网络的梯度
    42. critic_loss.backward() # 计算价值网络的梯度
    43. self.actor_optimizer.step() # 更新策略网络的参数
    44. self.critic_optimizer.step() # 更新价值网络的参数

  • 相关阅读:
    【时间序列分析】A Transformer-based Framework for Multivariate Time Series Representation Learning论文笔记
    机器学习(四十三):MLflow机器学习模型生命周期管理
    磁盘调度算法例题解析以及C语言实现
    react: scss使用样式
    Cadence OrCAD Capture 绘制总线的方法
    vue中使用高德地图的热力图方法1
    Unity AI Muse 基础教程
    Python实操:内存管理与优化策略
    条件构造器
    105. 从前序与中序遍历序列构造二叉树
  • 原文地址:https://blog.csdn.net/u012925804/article/details/127625683