• Playing Flappy Bird with Deep Reinforcement Learning


    Contents

    Demo Video

    Core Code


    Demo Video

    [Video: deep reinforcement learning playing Flappy Bird]

    Core Code

    The Q-network (src/deep_q_network.py). Three convolutional layers followed by two fully connected layers map a stacked state of four 84x84 frames to one Q-value per action (do nothing / flap):

    import torch.nn as nn

    class DeepQNetwork(nn.Module):
        def __init__(self):
            super(DeepQNetwork, self).__init__()
            # Input: 4 stacked 84x84 preprocessed frames
            self.conv1 = nn.Sequential(nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True))
            self.conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True))
            self.conv3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True))
            # Spatial size shrinks 84 -> 20 -> 9 -> 7, hence 7 * 7 * 64 flattened features
            self.fc1 = nn.Sequential(nn.Linear(7 * 7 * 64, 512), nn.ReLU(inplace=True))
            self.fc2 = nn.Linear(512, 2)  # one Q-value per action
            self._create_weights()

        def _create_weights(self):
            # Small uniform weights, zero biases for every conv/linear layer
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    nn.init.uniform_(m.weight, -0.01, 0.01)
                    nn.init.constant_(m.bias, 0)

        def forward(self, input):
            output = self.conv1(input)
            output = self.conv2(output)
            output = self.conv3(output)
            output = output.view(output.size(0), -1)  # flatten to (batch, 7 * 7 * 64)
            output = self.fc1(output)
            output = self.fc2(output)
            return output
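
    A quick sanity check of the shapes (a minimal sketch; it assumes the DeepQNetwork class above is in scope, and the 84x84 input size comes from the --image_size default used later): each conv layer shrinks 84x84 to 20x20, 9x9, then 7x7, which is why fc1 expects 7 * 7 * 64 inputs.

    import torch

    net = DeepQNetwork()
    x = torch.zeros(1, 4, 84, 84)  # one state = 4 stacked preprocessed frames
    q = net(x)
    print(q.shape)  # torch.Size([1, 2]): Q-values for [do nothing, flap]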
    The game environment (src/flappy_bird.py). It wraps pygame: next_frame(action) advances the game by one step and returns the raw screen pixels, the reward (+0.1 per surviving frame, +1 for passing a pipe, -1 on collision), and a terminal flag:

    from itertools import cycle
    from numpy.random import randint
    from pygame import Rect, init, time, display
    from pygame.event import pump
    from pygame.image import load
    from pygame.surfarray import array3d, pixels_alpha
    from pygame.transform import rotate
    import numpy as np

    class FlappyBird(object):
        init()
        fps_clock = time.Clock()
        screen_width = 288
        screen_height = 512
        screen = display.set_mode((screen_width, screen_height))
        display.set_caption('Deep Q-Network Flappy Bird')
        base_image = load('assets/sprites/base.png').convert_alpha()
        background_image = load('assets/sprites/background-black.png').convert()
        # Upper pipe is the lower pipe sprite rotated 180 degrees
        pipe_images = [rotate(load('assets/sprites/pipe-green.png').convert_alpha(), 180),
                       load('assets/sprites/pipe-green.png').convert_alpha()]
        bird_images = [load('assets/sprites/redbird-upflap.png').convert_alpha(),
                       load('assets/sprites/redbird-midflap.png').convert_alpha(),
                       load('assets/sprites/redbird-downflap.png').convert_alpha()]
        # number_images = [load('assets/sprites/{}.png'.format(i)).convert_alpha() for i in range(10)]
        # Per-pixel alpha masks for pixel-perfect collision detection
        bird_hitmask = [pixels_alpha(image).astype(bool) for image in bird_images]
        pipe_hitmask = [pixels_alpha(image).astype(bool) for image in pipe_images]
        fps = 30
        pipe_gap_size = 100
        pipe_velocity_x = -4
        # Parameters for the bird
        min_velocity_y = -8
        max_velocity_y = 10
        downward_speed = 1
        upward_speed = -9
        bird_index_generator = cycle([0, 1, 2, 1])  # wing-flap animation order

        def __init__(self):
            self.iter = self.bird_index = self.score = 0
            self.bird_width = self.bird_images[0].get_width()
            self.bird_height = self.bird_images[0].get_height()
            self.pipe_width = self.pipe_images[0].get_width()
            self.pipe_height = self.pipe_images[0].get_height()
            self.bird_x = int(self.screen_width / 5)
            self.bird_y = int((self.screen_height - self.bird_height) / 2)
            self.base_x = 0
            self.base_y = self.screen_height * 0.79
            self.base_shift = self.base_image.get_width() - self.background_image.get_width()
            pipes = [self.generate_pipe(), self.generate_pipe()]
            pipes[0]["x_upper"] = pipes[0]["x_lower"] = self.screen_width
            pipes[1]["x_upper"] = pipes[1]["x_lower"] = self.screen_width * 1.5
            self.pipes = pipes
            self.current_velocity_y = 0
            self.is_flapped = False

        def generate_pipe(self):
            x = self.screen_width + 10
            gap_y = randint(2, 10) * 10 + int(self.base_y / 5)
            return {"x_upper": x, "y_upper": gap_y - self.pipe_height,
                    "x_lower": x, "y_lower": gap_y + self.pipe_gap_size}

        def is_collided(self):
            # Check if the bird touches the ground
            if self.bird_height + self.bird_y + 1 >= self.base_y:
                return True
            bird_bbox = Rect(self.bird_x, self.bird_y, self.bird_width, self.bird_height)
            pipe_boxes = []
            for pipe in self.pipes:
                pipe_boxes.append(Rect(pipe["x_upper"], pipe["y_upper"], self.pipe_width, self.pipe_height))
                pipe_boxes.append(Rect(pipe["x_lower"], pipe["y_lower"], self.pipe_width, self.pipe_height))
            # Check if the bird's bounding box overlaps the bounding box of any pipe
            if bird_bbox.collidelist(pipe_boxes) == -1:
                return False
            # Pixel-perfect check against the first pipe's upper and lower boxes
            for i in range(2):
                cropped_bbox = bird_bbox.clip(pipe_boxes[i])
                min_x1 = cropped_bbox.x - bird_bbox.x
                min_y1 = cropped_bbox.y - bird_bbox.y
                min_x2 = cropped_bbox.x - pipe_boxes[i].x
                min_y2 = cropped_bbox.y - pipe_boxes[i].y
                if np.any(self.bird_hitmask[self.bird_index][min_x1:min_x1 + cropped_bbox.width,
                          min_y1:min_y1 + cropped_bbox.height] * self.pipe_hitmask[i][
                          min_x2:min_x2 + cropped_bbox.width,
                          min_y2:min_y2 + cropped_bbox.height]):
                    return True
            return False

        def next_frame(self, action):
            pump()
            reward = 0.1  # small reward for surviving the frame
            terminal = False
            # Check input action: 1 = flap, 0 = do nothing
            if action == 1:
                self.current_velocity_y = self.upward_speed
                self.is_flapped = True
            # Update score: +1 reward when the bird's center passes a pipe's center
            bird_center_x = self.bird_x + self.bird_width / 2
            for pipe in self.pipes:
                pipe_center_x = pipe["x_upper"] + self.pipe_width / 2
                if pipe_center_x < bird_center_x < pipe_center_x + 5:
                    self.score += 1
                    reward = 1
                    break
            # Update animation index and iteration
            if (self.iter + 1) % 3 == 0:
                self.bird_index = next(self.bird_index_generator)
                self.iter = 0
            self.base_x = -((-self.base_x + 100) % self.base_shift)
            # Update the bird's position (gravity applies unless the bird just flapped)
            if self.current_velocity_y < self.max_velocity_y and not self.is_flapped:
                self.current_velocity_y += self.downward_speed
            if self.is_flapped:
                self.is_flapped = False
            self.bird_y += min(self.current_velocity_y, self.bird_y - self.current_velocity_y - self.bird_height)
            if self.bird_y < 0:
                self.bird_y = 0
            # Update the pipes' positions
            for pipe in self.pipes:
                pipe["x_upper"] += self.pipe_velocity_x
                pipe["x_lower"] += self.pipe_velocity_x
            # Spawn a new pipe and drop the one that has left the screen
            if 0 < self.pipes[0]["x_lower"] < 5:
                self.pipes.append(self.generate_pipe())
            if self.pipes[0]["x_lower"] < -self.pipe_width:
                del self.pipes[0]
            if self.is_collided():
                terminal = True
                reward = -1
                self.__init__()  # reset the game on collision
            # Draw everything
            self.screen.blit(self.background_image, (0, 0))
            self.screen.blit(self.base_image, (self.base_x, self.base_y))
            self.screen.blit(self.bird_images[self.bird_index], (self.bird_x, self.bird_y))
            for pipe in self.pipes:
                self.screen.blit(self.pipe_images[0], (pipe["x_upper"], pipe["y_upper"]))
                self.screen.blit(self.pipe_images[1], (pipe["x_lower"], pipe["y_lower"]))
            image = array3d(display.get_surface())  # grab raw RGB pixels before updating
            display.update()
            self.fps_clock.tick(self.fps)
            return image, reward, terminal
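
    A minimal usage sketch of the environment API (it assumes the project's assets/sprites directory is present): drive the game with random actions and observe the returned image/reward/terminal triple. Note that pygame's array3d is indexed (x, y), so the image has shape (288, 512, 3), which is why the scripts below crop with image[:screen_width, :base_y].

    from random import randint

    game = FlappyBird()
    for _ in range(100):
        # 0 = do nothing, 1 = flap
        image, reward, terminal = game.next_frame(randint(0, 1))
        if terminal:
            # The environment resets itself on collision
            print("Crashed; last reward:", reward)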
    The test script. It loads the trained model and plays greedily (always taking the argmax of the predicted Q-values); the state is a rolling stack of the four most recent preprocessed frames:

    import argparse
    import torch
    from src.deep_q_network import DeepQNetwork  # needed so torch.load can unpickle the model
    from src.flappy_bird import FlappyBird
    from src.utils import pre_processing

    def get_args():
        parser = argparse.ArgumentParser(
            """Implementation of Deep Q Network to play Flappy Bird""")
        parser.add_argument("--image_size", type=int, default=84, help="The common width and height for all images")
        parser.add_argument("--saved_path", type=str, default="trained_models")
        args = parser.parse_args()
        return args

    def q_test(opt):
        if torch.cuda.is_available():
            torch.cuda.manual_seed(123)
        else:
            torch.manual_seed(123)
        if torch.cuda.is_available():
            model = torch.load("{}/flappy_bird".format(opt.saved_path))
        else:
            model = torch.load("{}/flappy_bird".format(opt.saved_path), map_location=lambda storage, loc: storage)
        model.eval()
        game_state = FlappyBird()
        image, reward, terminal = game_state.next_frame(0)
        # Crop away the ground, then resize to image_size x image_size
        image = pre_processing(image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size, opt.image_size)
        image = torch.from_numpy(image)
        if torch.cuda.is_available():
            model.cuda()
            image = image.cuda()
        # Initial state: the first frame repeated four times
        state = torch.cat(tuple(image for _ in range(4)))[None, :, :, :]
        while True:
            prediction = model(state)[0]
            action = torch.argmax(prediction)  # greedy: no exploration at test time
            next_image, reward, terminal = game_state.next_frame(action)
            next_image = pre_processing(next_image[:game_state.screen_width, :int(game_state.base_y)],
                                        opt.image_size, opt.image_size)
            next_image = torch.from_numpy(next_image)
            if torch.cuda.is_available():
                next_image = next_image.cuda()
            # Drop the oldest frame, append the newest
            next_state = torch.cat((state[0, 1:, :, :], next_image))[None, :, :, :]
            state = next_state

    if __name__ == "__main__":
        opt = get_args()
        q_test(opt)
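
    src/utils.pre_processing is imported above but not listed in the post. A plausible sketch of it, assuming it grayscales the cropped frame, resizes it, binarizes it against the black background, and adds a leading channel axis (the exact thresholding is an assumption, not taken from the post):

    import cv2
    import numpy as np

    def pre_processing(image, width, height):
        # Grayscale and resize to width x height (84 x 84 by default)
        image = cv2.cvtColor(cv2.resize(image, (width, height)), cv2.COLOR_BGR2GRAY)
        # Binarize: anything brighter than the black background becomes white
        _, image = cv2.threshold(image, 1, 255, cv2.THRESH_BINARY)
        # Add a channel axis so four frames stack into a (4, H, W) state
        return image[None, :, :].astype(np.float32)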
    The training script. Training follows standard deep Q-learning: an epsilon-greedy policy with epsilon annealed linearly from initial_epsilon to final_epsilon, a replay memory of recent transitions, and Bellman targets y = r for terminal transitions and y = r + gamma * max Q(s', a') otherwise. The post omits the imports, so they are reconstructed below from what the code uses (the SummaryWriter import assumes torch.utils.tensorboard):

    import argparse
    import os
    import shutil
    from random import random, randint, sample

    import numpy as np
    import torch
    import torch.nn as nn
    from torch.utils.tensorboard import SummaryWriter

    from src.deep_q_network import DeepQNetwork
    from src.flappy_bird import FlappyBird
    from src.utils import pre_processing

    def get_args():
        parser = argparse.ArgumentParser(
            """Implementation of Deep Q Network to play Flappy Bird""")
        parser.add_argument("--image_size", type=int, default=84, help="The common width and height for all images")
        parser.add_argument("--batch_size", type=int, default=32, help="The number of images per batch")
        parser.add_argument("--optimizer", type=str, choices=["sgd", "adam"], default="adam")
        parser.add_argument("--lr", type=float, default=1e-6)
        parser.add_argument("--gamma", type=float, default=0.99)
        parser.add_argument("--initial_epsilon", type=float, default=0.1)
        parser.add_argument("--final_epsilon", type=float, default=1e-4)
        parser.add_argument("--num_iters", type=int, default=2000000)
        parser.add_argument("--replay_memory_size", type=int, default=50000,
                            help="Maximum number of transitions kept in replay memory")
        parser.add_argument("--log_path", type=str, default="tensorboard")
        parser.add_argument("--saved_path", type=str, default="trained_models")
        args = parser.parse_args()
        return args

    def train(opt):
        if torch.cuda.is_available():
            torch.cuda.manual_seed(123)
        else:
            torch.manual_seed(123)
        model = DeepQNetwork()
        if os.path.isdir(opt.log_path):
            shutil.rmtree(opt.log_path)
        os.makedirs(opt.log_path)
        writer = SummaryWriter(opt.log_path)
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
        criterion = nn.MSELoss()
        game_state = FlappyBird()
        image, reward, terminal = game_state.next_frame(0)
        image = pre_processing(image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size,
                               opt.image_size)
        image = torch.from_numpy(image)
        if torch.cuda.is_available():
            model.cuda()
            image = image.cuda()
        state = torch.cat(tuple(image for _ in range(4)))[None, :, :, :]
        replay_memory = []
        iter = 0
        while iter < opt.num_iters:
            prediction = model(state)[0]
            # Exploration or exploitation: epsilon anneals linearly over training
            epsilon = opt.final_epsilon + (
                    (opt.num_iters - iter) * (opt.initial_epsilon - opt.final_epsilon) / opt.num_iters)
            u = random()
            random_action = u <= epsilon
            if random_action:
                print("Perform a random action")
                action = randint(0, 1)
            else:
                action = torch.argmax(prediction)
            next_image, reward, terminal = game_state.next_frame(action)
            next_image = pre_processing(next_image[:game_state.screen_width, :int(game_state.base_y)],
                                        opt.image_size, opt.image_size)
            next_image = torch.from_numpy(next_image)
            if torch.cuda.is_available():
                next_image = next_image.cuda()
            next_state = torch.cat((state[0, 1:, :, :], next_image))[None, :, :, :]
            # Store the transition and cap the replay memory size
            replay_memory.append([state, action, reward, next_state, terminal])
            if len(replay_memory) > opt.replay_memory_size:
                del replay_memory[0]
            batch = sample(replay_memory, min(len(replay_memory), opt.batch_size))
            state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = zip(*batch)
            state_batch = torch.cat(tuple(state for state in state_batch))
            # One-hot encode the actions
            action_batch = torch.from_numpy(
                np.array([[1, 0] if action == 0 else [0, 1] for action in action_batch], dtype=np.float32))
            reward_batch = torch.from_numpy(np.array(reward_batch, dtype=np.float32)[:, None])
            next_state_batch = torch.cat(tuple(state for state in next_state_batch))
            if torch.cuda.is_available():
                state_batch = state_batch.cuda()
                action_batch = action_batch.cuda()
                reward_batch = reward_batch.cuda()
                next_state_batch = next_state_batch.cuda()
            current_prediction_batch = model(state_batch)
            next_prediction_batch = model(next_state_batch)
            # Bellman targets: r for terminal states, r + gamma * max Q(s', a') otherwise
            y_batch = torch.cat(
                tuple(reward if terminal else reward + opt.gamma * torch.max(prediction)
                      for reward, terminal, prediction in
                      zip(reward_batch, terminal_batch, next_prediction_batch)))
            q_value = torch.sum(current_prediction_batch * action_batch, dim=1)
            optimizer.zero_grad()
            y_batch = y_batch.detach()  # targets are constants: no gradient through next-state predictions
            loss = criterion(q_value, y_batch)
            loss.backward()
            optimizer.step()
            state = next_state
            iter += 1
            print("Iteration: {}/{}, Action: {}, Loss: {}, Epsilon {}, Reward: {}, Q-value: {}".format(
                iter,
                opt.num_iters,
                action,
                loss,
                epsilon, reward, torch.max(prediction)))
            writer.add_scalar('Train/Loss', loss, iter)
            writer.add_scalar('Train/Epsilon', epsilon, iter)
            writer.add_scalar('Train/Reward', reward, iter)
            writer.add_scalar('Train/Q-value', torch.max(prediction), iter)
            if (iter + 1) % 1000000 == 0:
                torch.save(model, "{}/flappy_bird_{}".format(opt.saved_path, iter + 1))
        torch.save(model, "{}/flappy_bird".format(opt.saved_path))

    if __name__ == "__main__":
        opt = get_args()
        train(opt)
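
    A quick check of the linear epsilon schedule used above (a standalone sketch of the same formula as in train()): the agent starts exploring with probability initial_epsilon and ends at final_epsilon.

    def epsilon_at(it, num_iters=2000000, initial_epsilon=0.1, final_epsilon=1e-4):
        # Linear anneal from initial_epsilon down to final_epsilon
        return final_epsilon + (num_iters - it) * (initial_epsilon - final_epsilon) / num_iters

    print(epsilon_at(0))        # 0.1
    print(epsilon_at(1000000))  # ~0.05
    print(epsilon_at(2000000))  # 0.0001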

  • Original post: https://blog.csdn.net/timberman666/article/details/132590406