```python
log_prob = torch.log(self.policy_net(state).gather(1, action))
G = self.gamma * G + reward
loss = -log_prob * G  # policy-gradient loss at each step
```
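To see how these three lines are used, here is a minimal sketch of a full REINFORCE update over one finished episode. The function name `reinforce_update` and the assumption that the episode is stored as lists under the keys `states`, `actions` and `rewards` are illustrative choices, not necessarily the exact interface of the surrounding agent class.

```python
import torch

def reinforce_update(policy_net, optimizer, gamma, transition_dict):
    # Assumed format: transition_dict holds one complete episode as lists.
    states = transition_dict['states']
    actions = transition_dict['actions']
    rewards = transition_dict['rewards']

    G = 0
    optimizer.zero_grad()
    # Walk the episode backwards so the return can be accumulated
    # incrementally: G_t = r_t + gamma * G_{t+1}
    for i in reversed(range(len(rewards))):
        state = torch.tensor([states[i]], dtype=torch.float)
        action = torch.tensor([actions[i]]).view(-1, 1)
        log_prob = torch.log(policy_net(state).gather(1, action))
        G = gamma * G + rewards[i]
        loss = -log_prob * G   # per-step policy-gradient loss
        loss.backward()        # gradients accumulate across time steps
    optimizer.step()           # one gradient step per episode
```

Note that the per-step gradients accumulate across the repeated `backward()` calls, so the single `optimizer.step()` at the end applies the whole Monte Carlo policy-gradient update for the episode.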
For comparison, we can look at the corresponding lines of the DQN update:
```python
q_values = self.q_net(states).gather(1, actions)  # Q-values of the taken actions
# maximum Q-value of the next state
max_next_q_values = self.target_q_net(next_states).max(1)[0].view(-1, 1)
q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)  # TD target
dqn_loss = torch.mean(F.mse_loss(q_values, q_targets))  # mean squared error loss
```
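For the value-based side, the sketch below shows how these lines would sit inside one DQN gradient step on a sampled mini-batch; the function name `dqn_update` and the dictionary keys of `batch` are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

def dqn_update(q_net, target_q_net, optimizer, gamma, batch):
    # Assumed format: batch holds mini-batch transitions as lists/arrays.
    states = torch.tensor(batch['states'], dtype=torch.float)
    actions = torch.tensor(batch['actions']).view(-1, 1)
    rewards = torch.tensor(batch['rewards'], dtype=torch.float).view(-1, 1)
    next_states = torch.tensor(batch['next_states'], dtype=torch.float)
    dones = torch.tensor(batch['dones'], dtype=torch.float).view(-1, 1)

    q_values = q_net(states).gather(1, actions)                    # Q(s, a)
    # Bootstrapped estimate from the target network
    max_next_q_values = target_q_net(next_states).max(1)[0].view(-1, 1)
    q_targets = rewards + gamma * max_next_q_values * (1 - dones)  # TD target
    dqn_loss = torch.mean(F.mse_loss(q_values, q_targets))

    optimizer.zero_grad()
    dqn_loss.backward()
    optimizer.step()
    # Every fixed number of updates the caller would sync the target network,
    # e.g. target_q_net.load_state_dict(q_net.state_dict()).
```

Unlike REINFORCE, which needs a complete episode before it can compute the return G, this update only needs individual transitions, which is why DQN can learn from a replay buffer.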
```python
import torch
import torch.nn.functional as F


class PolicyNet(torch.nn.Module):
    ''' Policy network with a single hidden layer '''
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        # softmax over the action dimension: a probability distribution over actions
        return F.softmax(self.fc2(x), dim=1)
```
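Because `PolicyNet` ends with a softmax, its output is a probability distribution over actions, and the agent acts by sampling from that distribution rather than taking an argmax. A small usage sketch (the helper name `take_action` is an assumption):

```python
import torch

def take_action(policy_net, state):
    # state: a single observation (list or array of length state_dim)
    state = torch.tensor([state], dtype=torch.float)   # add batch dimension
    probs = policy_net(state)                          # shape (1, action_dim)
    action_dist = torch.distributions.Categorical(probs)
    return action_dist.sample().item()                 # stochastic action
```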
```python
class Qnet(torch.nn.Module):
    ''' Q-network with a single hidden layer '''
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(Qnet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))  # ReLU activation for the hidden layer
        return self.fc2(x)       # one Q-value per action
```
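`Qnet`, in contrast, outputs one scalar value per action, so action selection is an argmax over its output, usually wrapped in epsilon-greedy exploration. A minimal sketch, assuming a helper of the following form with `epsilon` and `action_dim` passed in:

```python
import random
import torch

def take_action_epsilon_greedy(q_net, state, epsilon, action_dim):
    if random.random() < epsilon:
        return random.randint(0, action_dim - 1)       # explore: random action
    state = torch.tensor([state], dtype=torch.float)   # add batch dimension
    return q_net(state).argmax().item()                # exploit: greedy action
```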