• 强化学习笔记之【SAC算法】


    强化学习笔记之【SAC算法】


    前言:

    本文为强化学习笔记第四篇,第一篇讲的是Q-learning和DQN,第二篇DDPG,第三篇TD3

    TD3比DDPG少了一个target_actor网络,其它地方有点小改动

    CSDN主页:https://blog.csdn.net/rvdgdsva

    博客园主页:https://www.cnblogs.com/hassle


    STAND ALONE COMPLEX = S . A . C

    首先,我们需要明确,Q-learning算法发展成DQN算法,DQN算法发展成为DDPG算法,而DDPG算法发展成TD3算法,TD3算法发展成SAC算法

    Soft Actor-Critic (SAC) 是一种基于策略梯度的深度强化学习算法,它具有最大化奖励与最大化熵(探索性)的双重目标。SAC 通过引入熵正则项,使策略在决策时具有更大的随机性,从而提高探索能力。

    一、SAC算法

    OK,先用伪代码让你们感受一下SAC算法

    # 定义 SAC 超参数
    alpha = 0.2               # 熵正则项系数
    gamma = 0.99              # 折扣因子
    tau = 0.005               # 目标网络软更新参数
    lr = 3e-4                 # 学习率
    
    # 初始化 Actor、Critic、Target Critic 网络和优化器
    actor = ActorNetwork()                      # 策略网络 π(s)
    critic1 = CriticNetwork()                   # 第一个 Q 网络 Q1(s, a)
    critic2 = CriticNetwork()                   # 第二个 Q 网络 Q2(s, a)
    target_critic1 = CriticNetwork()            # 目标 Q 网络 1
    target_critic2 = CriticNetwork()            # 目标 Q 网络 2
    
    # 将目标 Q 网络的参数设置为与 Critic 网络相同
    target_critic1.load_state_dict(critic1.state_dict())
    target_critic2.load_state_dict(critic2.state_dict())
    
    # 初始化优化器
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=lr)
    critic1_optimizer = torch.optim.Adam(critic1.parameters(), lr=lr)
    critic2_optimizer = torch.optim.Adam(critic2.parameters(), lr=lr)
    
    # 经验回放池(Replay Buffer)
    replay_buffer = ReplayBuffer()
    
    # SAC 训练循环
    for each iteration:
        # Step 1: 从 Replay Buffer 中采样一个批次 (state, action, reward, next_state)
        batch = replay_buffer.sample()
        state, action, reward, next_state, done = batch
    
        # Step 2: 计算目标 Q 值 (y)
        with torch.no_grad():
            # 从 Actor 网络中获取 next_state 的下一个动作
            next_action, next_log_prob = actor.sample(next_state)
            
            # 目标 Q 值的计算:使用目标 Q 网络的最小值 + 熵项
            target_q1_value = target_critic1(next_state, next_action)
            target_q2_value = target_critic2(next_state, next_action)
            min_target_q_value = torch.min(target_q1_value, target_q2_value)
    
            # 目标 Q 值 y = r + γ * (最小目标 Q 值 - α * next_log_prob)
            target_q_value = reward + gamma * (1 - done) * (min_target_q_value - alpha * next_log_prob)
    
        # Step 3: 更新 Critic 网络
        # Critic 1 损失
        current_q1_value = critic1(state, action)
        critic1_loss = F.mse_loss(current_q1_value, target_q_value)
    
        # Critic 2 损失
        current_q2_value = critic2(state, action)
        critic2_loss = F.mse_loss(current_q2_value, target_q_value)
    
        # 反向传播并更新 Critic 网络参数
        critic1_optimizer.zero_grad()
        critic1_loss.backward()
        critic1_optimizer.step()
    
        critic2_optimizer.zero_grad()
        critic2_loss.backward()
        critic2_optimizer.step()
    
        # Step 4: 更新 Actor 网络
        # 通过 Actor 网络生成新的动作及其 log 概率
        new_action, log_prob = actor.sample(state)
    
        # 计算 Actor 的目标损失:L = α * log_prob - Q1(s, π(s))
        q1_value = critic1(state, new_action)
        actor_loss = (alpha * log_prob - q1_value).mean()
    
        # 反向传播并更新 Actor 网络参数
        actor_optimizer.zero_grad()
        actor_loss.backward()
        actor_optimizer.step()
    
        # Step 5: 软更新目标 Q 网络参数
        with torch.no_grad():
            for param, target_param in zip(critic1.parameters(), target_critic1.parameters()):
                target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
    
            for param, target_param in zip(critic2.parameters(), target_critic2.parameters()):
                target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
    

    二、SAC算法Latex解释

    1、初始化 Actor、Critic1、Critic2、TargetCritic1 、TargetCritic2 网络
    2、Buffer中采样 (state, action, reward, next_state)

    3、Actor 输入 next_state 对应输出 next_action 和 next_log_prob
    4、Actor 输入 state 对应输出 new_action 和 log_prob
    5、Critic1 和 Critic2 分别输入next_state 和 next_action 取其中较小输出经熵正则计算得 target_q_value

    6、使用 MSE_loss(Critic1(state, action), target_q_value) 更新 Critic1
    7、使用 MSE_loss(Critic2(state, action), target_q_value) 更新 Critic2
    8、使用 (alpha * log_prob - critic1(state, new_action)).mean() 更新 Actor


    三、SAC五大网络和模块

    SAC 算法 中,Actor、Critic1、Critic2、Target Critic1 和 Target Critic2 网络是核心模块,它们分别用于输出动作、评估状态-动作对的价值,并通过目标网络进行稳定的更新。

    3.1 Actor 网络

    Actor 网络用于在给定状态下输出一个高斯分布的均值和标准差(即策略)。它是通过神经网络近似的随机策略。用于选择动作。

    import torch
    import torch.nn as nn
    
    class ActorNetwork(nn.Module):
        def __init__(self, state_dim, action_dim):
            super(ActorNetwork, self).__init__()
            self.fc1 = nn.Linear(state_dim, 256)
            self.fc2 = nn.Linear(256, 256)
            self.mean_layer = nn.Linear(256, action_dim)  # 输出动作的均值
            self.log_std_layer = nn.Linear(256, action_dim)  # 输出动作的log标准差
    
        def forward(self, state):
            x = torch.relu(self.fc1(state))
            x = torch.relu(self.fc2(x))
            mean = self.mean_layer(x)  # 输出动作均值
            log_std = self.log_std_layer(x)  # 输出 log 标准差
            log_std = torch.clamp(log_std, min=-20, max=2)  # 限制标准差范围
            return mean, log_std
    
        def sample(self, state):
            mean, log_std = self.forward(state)
            std = torch.exp(log_std)  # 将 log 标准差转为标准差
            normal = torch.distributions.Normal(mean, std)
            action = normal.rsample()  # 通过重参数化技巧进行采样
            log_prob = normal.log_prob(action).sum(-1)  # 计算 log 概率
            return action, log_prob
    
    

    3.2 Critic1 和 Critic2 网络

    Critic 网络用于计算状态-动作对的 Q 值,SAC 使用两个 Critic 网络(Critic1 和 Critic2)来缓解 Q 值的过估计问题。

    class CriticNetwork(nn.Module):
        def __init__(self, state_dim, action_dim):
            super(CriticNetwork, self).__init__()
            self.fc1 = nn.Linear(state_dim + action_dim, 256)
            self.fc2 = nn.Linear(256, 256)
            self.q_value_layer = nn.Linear(256, 1)  # 输出 Q 值
    
        def forward(self, state, action):
            x = torch.cat([state, action], dim=-1)  # 将 state 和 action 作为输入
            x = torch.relu(self.fc1(x))
            x = torch.relu(self.fc2(x))
            q_value = self.q_value_layer(x)  # 输出 Q 值
            return q_value
    
    

    3.3 Target Critic1 和 Target Critic2 网络

    Target Critic 网络的结构与 Critic 网络相同,用于稳定 Q 值更新。它们通过软更新(即在每次训练后慢慢接近 Critic 网络的参数)来保持训练的稳定性。

    class TargetCriticNetwork(nn.Module):
        def __init__(self, state_dim, action_dim):
            super(TargetCriticNetwork, self).__init__()
            self.fc1 = nn.Linear(state_dim + action_dim, 256)
            self.fc2 = nn.Linear(256, 256)
            self.q_value_layer = nn.Linear(256, 1)  # 输出 Q 值
    
        def forward(self, state, action):
            x = torch.cat([state, action], dim=-1)  # 将 state 和 action 作为输入
            x = torch.relu(self.fc1(x))
            x = torch.relu(self.fc2(x))
            q_value = self.q_value_layer(x)  # 输出 Q 值
            return q_value
    

    3.4 软更新模块

    在 SAC 中,目标网络会通过软更新逐渐逼近 Critic 网络的参数。每次更新后,目标网络参数会按照 ττ 的比例向 Critic 网络的参数靠拢。

    def soft_update(critic, target_critic, tau=0.005):
        for param, target_param in zip(critic.parameters(), target_critic.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
    

    3.5 总结

    1. 初始化网络和参数:
      • Actor 网络:用于选择动作。
      • Critic 1 和 Critic 2 网络:用于估计 Q 值。
      • Target Critic 1 和 Target Critic 2:与 Critic 网络架构相同,用于生成更稳定的目标 Q 值。
    2. 目标 Q 值计算:
      • 使用目标网络计算下一状态下的 Q 值。
      • 取两个 Q 网络输出的最小值,防止 Q 值的过估计。
      • 引入熵正则项,计算公式:y=r+γmin(Q1,Q2)αlogπ(a|s)
    3. 更新 Critic 网络:
      • 最小化目标 Q 值与当前 Q 值的均方误差 (MSE)。
    4. 更新 Actor 网络:
      • 最大化目标损失:L=αlogπ(a|s)Q1(s,π(s)),即在保证探索的情况下选择高价值动作。
    5. 软更新目标网络:
      • 软更新目标 Q 网络参数,使得目标网络参数缓慢向当前网络靠近,避免振荡。


    __EOF__

  • 本文作者: El Psy Kongroo!
  • 本文链接: https://www.cnblogs.com/hassle/p/18459320
  • 关于博主: 研二计算机遥感方向转强化学习方向,喜欢英国源神、杀戮尖塔、香蕉锁头、galgame,和下午的一杯红茶。
  • 版权声明: 本博客所有文章除特别声明外,均采用BY-NC-SA 许可协议。转载需要注明出处
  • 声援博主: 点个赞再走吧,初音未来会护佑每一位虔诚的信徒!
  • 相关阅读:
    Spark SQL之IDEA中的应用
    云原生之DevOps
    《Java基础知识》Java transient关键字详解
    VSCode编译运行C代码
    大数据与人工智能人脸识别
    并行前缀和计算——MPI SCAN算法的C语言实现
    yamot:一款功能强大的基于Web的服务器安全监控工具
    API接口采集商品详情页面数据(H5端和PC端)item_get-获得淘宝商品详情
    测试用例的8大设计原则
    SEO 对企业的重要性
  • 原文地址:https://www.cnblogs.com/hassle/p/18459320