Playing Flappy Bird with Deep Reinforcement Learning
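
The implementation below is organized into four parts: the Q-network (src/deep_q_network.py), the pygame game environment (src/flappy_bird.py), a script for running a trained model, and the training script that ties them together with an epsilon-greedy policy and a replay memory. First, the network definition: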

```python
import torch.nn as nn


class DeepQNetwork(nn.Module):
    def __init__(self):
        super(DeepQNetwork, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True))

        self.fc1 = nn.Sequential(nn.Linear(7 * 7 * 64, 512), nn.ReLU(inplace=True))
        self.fc2 = nn.Linear(512, 2)
        self._create_weights()

    def _create_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.uniform_(m.weight, -0.01, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, input):
        output = self.conv1(input)
        output = self.conv2(output)
        output = self.conv3(output)
        output = output.view(output.size(0), -1)
        output = self.fc1(output)
        output = self.fc2(output)

        return output
```
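
The network takes a stack of four 84×84 grayscale frames (four input channels) and outputs one Q-value per action: do nothing or flap. The 7 * 7 * 64 size of the first fully connected layer follows from the convolution arithmetic: 84×84 shrinks to 20×20 (kernel 8, stride 4), then 9×9 (kernel 4, stride 2), then 7×7 (kernel 3, stride 1). A quick shape check, assuming the 84×84 inputs used later by the training script:

```python
import torch

from src.deep_q_network import DeepQNetwork

# Sanity check only: feed a dummy batch of four stacked 84x84 frames through the
# network and confirm we get two Q-values (do nothing, flap) per sample.
net = DeepQNetwork()
dummy_state = torch.zeros(1, 4, 84, 84)
q_values = net(dummy_state)
print(q_values.shape)  # expected: torch.Size([1, 2])
```

The game environment that produces these frames is defined next.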
```python
from itertools import cycle
from numpy.random import randint
from pygame import Rect, init, time, display
from pygame.event import pump
from pygame.image import load
from pygame.surfarray import array3d, pixels_alpha
from pygame.transform import rotate
import numpy as np


class FlappyBird(object):
    init()
    fps_clock = time.Clock()
    screen_width = 288
    screen_height = 512
    screen = display.set_mode((screen_width, screen_height))
    display.set_caption('Deep Q-Network Flappy Bird')
    base_image = load('assets/sprites/base.png').convert_alpha()
    background_image = load('assets/sprites/background-black.png').convert()

    pipe_images = [rotate(load('assets/sprites/pipe-green.png').convert_alpha(), 180),
                   load('assets/sprites/pipe-green.png').convert_alpha()]
    bird_images = [load('assets/sprites/redbird-upflap.png').convert_alpha(),
                   load('assets/sprites/redbird-midflap.png').convert_alpha(),
                   load('assets/sprites/redbird-downflap.png').convert_alpha()]
    # number_images = [load('assets/sprites/{}.png'.format(i)).convert_alpha() for i in range(10)]

    bird_hitmask = [pixels_alpha(image).astype(bool) for image in bird_images]
    pipe_hitmask = [pixels_alpha(image).astype(bool) for image in pipe_images]

    fps = 30
    pipe_gap_size = 100
    pipe_velocity_x = -4

    # parameters for the bird
    min_velocity_y = -8
    max_velocity_y = 10
    downward_speed = 1
    upward_speed = -9

    bird_index_generator = cycle([0, 1, 2, 1])

    def __init__(self):
        self.iter = self.bird_index = self.score = 0

        self.bird_width = self.bird_images[0].get_width()
        self.bird_height = self.bird_images[0].get_height()
        self.pipe_width = self.pipe_images[0].get_width()
        self.pipe_height = self.pipe_images[0].get_height()

        self.bird_x = int(self.screen_width / 5)
        self.bird_y = int((self.screen_height - self.bird_height) / 2)

        self.base_x = 0
        self.base_y = self.screen_height * 0.79
        self.base_shift = self.base_image.get_width() - self.background_image.get_width()

        pipes = [self.generate_pipe(), self.generate_pipe()]
        pipes[0]["x_upper"] = pipes[0]["x_lower"] = self.screen_width
        pipes[1]["x_upper"] = pipes[1]["x_lower"] = self.screen_width * 1.5
        self.pipes = pipes

        self.current_velocity_y = 0
        self.is_flapped = False

    def generate_pipe(self):
        x = self.screen_width + 10
        gap_y = randint(2, 10) * 10 + int(self.base_y / 5)
        return {"x_upper": x, "y_upper": gap_y - self.pipe_height, "x_lower": x, "y_lower": gap_y + self.pipe_gap_size}

    def is_collided(self):
        # Check if the bird touches the ground
        if self.bird_height + self.bird_y + 1 >= self.base_y:
            return True
        bird_bbox = Rect(self.bird_x, self.bird_y, self.bird_width, self.bird_height)
        pipe_boxes = []
        for pipe in self.pipes:
            pipe_boxes.append(Rect(pipe["x_upper"], pipe["y_upper"], self.pipe_width, self.pipe_height))
            pipe_boxes.append(Rect(pipe["x_lower"], pipe["y_lower"], self.pipe_width, self.pipe_height))
        # Check if the bird's bounding box overlaps with the bounding box of any pipe
        if bird_bbox.collidelist(pipe_boxes) == -1:
            return False
        for i in range(2):
            cropped_bbox = bird_bbox.clip(pipe_boxes[i])
            min_x1 = cropped_bbox.x - bird_bbox.x
            min_y1 = cropped_bbox.y - bird_bbox.y
            min_x2 = cropped_bbox.x - pipe_boxes[i].x
            min_y2 = cropped_bbox.y - pipe_boxes[i].y
            # Pixel-level collision test using the alpha-channel hit masks
            if np.any(self.bird_hitmask[self.bird_index][min_x1:min_x1 + cropped_bbox.width,
                      min_y1:min_y1 + cropped_bbox.height] * self.pipe_hitmask[i][min_x2:min_x2 + cropped_bbox.width,
                      min_y2:min_y2 + cropped_bbox.height]):
                return True
        return False

    def next_frame(self, action):
        pump()
        reward = 0.1
        terminal = False
        # Check the input action
        if action == 1:
            self.current_velocity_y = self.upward_speed
            self.is_flapped = True

        # Update the score
        bird_center_x = self.bird_x + self.bird_width / 2
        for pipe in self.pipes:
            pipe_center_x = pipe["x_upper"] + self.pipe_width / 2
            if pipe_center_x < bird_center_x < pipe_center_x + 5:
                self.score += 1
                reward = 1
                break

        # Update index and iteration
        if (self.iter + 1) % 3 == 0:
            self.bird_index = next(self.bird_index_generator)
            self.iter = 0
        self.base_x = -((-self.base_x + 100) % self.base_shift)

        # Update the bird's position
        if self.current_velocity_y < self.max_velocity_y and not self.is_flapped:
            self.current_velocity_y += self.downward_speed
        if self.is_flapped:
            self.is_flapped = False
        self.bird_y += min(self.current_velocity_y, self.bird_y - self.current_velocity_y - self.bird_height)
        if self.bird_y < 0:
            self.bird_y = 0

        # Update the pipes' position
        for pipe in self.pipes:
            pipe["x_upper"] += self.pipe_velocity_x
            pipe["x_lower"] += self.pipe_velocity_x
        # Add a new pipe when the first one nears the left edge, and drop it once it is off screen
        if 0 < self.pipes[0]["x_lower"] < 5:
            self.pipes.append(self.generate_pipe())
        if self.pipes[0]["x_lower"] < -self.pipe_width:
            del self.pipes[0]
        if self.is_collided():
            terminal = True
            reward = -1
            self.__init__()

        # Draw everything
        self.screen.blit(self.background_image, (0, 0))
        self.screen.blit(self.base_image, (self.base_x, self.base_y))
        self.screen.blit(self.bird_images[self.bird_index], (self.bird_x, self.bird_y))
        for pipe in self.pipes:
            self.screen.blit(self.pipe_images[0], (pipe["x_upper"], pipe["y_upper"]))
            self.screen.blit(self.pipe_images[1], (pipe["x_lower"], pipe["y_lower"]))

        image = array3d(display.get_surface())
        display.update()
        self.fps_clock.tick(self.fps)
        return image, reward, terminal
```
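
next_frame is the only interface the agent needs: it takes an action (0 = do nothing, 1 = flap), advances the game by one tick, and returns the raw 288×512 RGB screen together with the reward (+0.1 for surviving the frame, +1 for passing a pipe, -1 for crashing) and a terminal flag. A minimal random-agent loop, purely to illustrate the interface (it assumes the sprite files under assets/sprites/ are present):

```python
from random import randint

from src.flappy_bird import FlappyBird

# Illustrative only: drive the environment with random actions and print the rewards.
game = FlappyBird()
for _ in range(100):
    action = randint(0, 1)                              # 0 = do nothing, 1 = flap
    image, reward, terminal = game.next_frame(action)
    # The training code crops away the base strip before feeding the frame to the network
    frame = image[:game.screen_width, :int(game.base_y)]
    print(reward, terminal)
```

The next script loads a trained model and lets it play using exactly this interface.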
```python
import argparse
import torch

from src.deep_q_network import DeepQNetwork
from src.flappy_bird import FlappyBird
from src.utils import pre_processing


def get_args():
    parser = argparse.ArgumentParser(
        """Implementation of Deep Q Network to play Flappy Bird""")
    parser.add_argument("--image_size", type=int, default=84, help="The common width and height for all images")
    parser.add_argument("--saved_path", type=str, default="trained_models")

    args = parser.parse_args()
    return args


def q_test(opt):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
    if torch.cuda.is_available():
        model = torch.load("{}/flappy_bird".format(opt.saved_path))
    else:
        model = torch.load("{}/flappy_bird".format(opt.saved_path), map_location=lambda storage, loc: storage)
    model.eval()
    game_state = FlappyBird()
    image, reward, terminal = game_state.next_frame(0)
    image = pre_processing(image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size, opt.image_size)
    image = torch.from_numpy(image)
    if torch.cuda.is_available():
        model.cuda()
        image = image.cuda()
    state = torch.cat(tuple(image for _ in range(4)))[None, :, :, :]

    while True:
        prediction = model(state)[0]
        action = torch.argmax(prediction)

        next_image, reward, terminal = game_state.next_frame(action)
        next_image = pre_processing(next_image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size,
                                    opt.image_size)
        next_image = torch.from_numpy(next_image)
        if torch.cuda.is_available():
            next_image = next_image.cuda()
        next_state = torch.cat((state[0, 1:, :, :], next_image))[None, :, :, :]

        state = next_state


if __name__ == "__main__":
    opt = get_args()
    q_test(opt)
```
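
Both this script and the training script call pre_processing from src.utils, which is not reproduced above. From the way it is used — it receives the cropped RGB frame plus a target width and height, and its output is converted with torch.from_numpy and stacked into a (1, 4, 84, 84) state — a helper along the following lines would fit. This is a sketch of what such a function typically looks like (resize, grayscale, binarize), not the author's exact file:

```python
import cv2
import numpy as np


def pre_processing(image, width, height):
    # NOTE: plausible reconstruction of src/utils.pre_processing, not the original file.
    # Resize the cropped frame and convert it to grayscale.
    image = cv2.cvtColor(cv2.resize(image, (width, height)), cv2.COLOR_RGB2GRAY)
    # Binarize so the network only sees the bird/pipe silhouettes on the black background.
    _, image = cv2.threshold(image, 1, 255, cv2.THRESH_BINARY)
    # Return shape (1, height, width) so four frames can be stacked into a 4-channel state.
    return image[None, :, :].astype(np.float32)
```

The training script below applies the same preprocessing when it assembles its four-frame states.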
```python
import argparse
import os
import shutil
from random import random, randint, sample

import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter  # tensorboardX's SummaryWriter works here as well

from src.deep_q_network import DeepQNetwork
from src.flappy_bird import FlappyBird
from src.utils import pre_processing


def get_args():
    parser = argparse.ArgumentParser(
        """Implementation of Deep Q Network to play Flappy Bird""")
    parser.add_argument("--image_size", type=int, default=84, help="The common width and height for all images")
    parser.add_argument("--batch_size", type=int, default=32, help="The number of images per batch")
    parser.add_argument("--optimizer", type=str, choices=["sgd", "adam"], default="adam")
    parser.add_argument("--lr", type=float, default=1e-6)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--initial_epsilon", type=float, default=0.1)
    parser.add_argument("--final_epsilon", type=float, default=1e-4)
    parser.add_argument("--num_iters", type=int, default=2000000)
    parser.add_argument("--replay_memory_size", type=int, default=50000,
                        help="Maximum number of transitions kept in the replay memory")
    parser.add_argument("--log_path", type=str, default="tensorboard")
    parser.add_argument("--saved_path", type=str, default="trained_models")

    args = parser.parse_args()
    return args


def train(opt):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
    model = DeepQNetwork()
    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)
    os.makedirs(opt.log_path)
    writer = SummaryWriter(opt.log_path)
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    criterion = nn.MSELoss()
    game_state = FlappyBird()
    image, reward, terminal = game_state.next_frame(0)
    image = pre_processing(image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size, opt.image_size)
    image = torch.from_numpy(image)
    if torch.cuda.is_available():
        model.cuda()
        image = image.cuda()
    state = torch.cat(tuple(image for _ in range(4)))[None, :, :, :]

    replay_memory = []
    iter = 0
    while iter < opt.num_iters:
        prediction = model(state)[0]
        # Exploration or exploitation: epsilon decays linearly from initial_epsilon to final_epsilon
        epsilon = opt.final_epsilon + (
                (opt.num_iters - iter) * (opt.initial_epsilon - opt.final_epsilon) / opt.num_iters)
        u = random()
        random_action = u <= epsilon
        if random_action:
            print("Perform a random action")
            action = randint(0, 1)
        else:
            action = torch.argmax(prediction)

        next_image, reward, terminal = game_state.next_frame(action)
        next_image = pre_processing(next_image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size,
                                    opt.image_size)
        next_image = torch.from_numpy(next_image)
        if torch.cuda.is_available():
            next_image = next_image.cuda()
        next_state = torch.cat((state[0, 1:, :, :], next_image))[None, :, :, :]
        replay_memory.append([state, action, reward, next_state, terminal])
        if len(replay_memory) > opt.replay_memory_size:
            del replay_memory[0]
        batch = sample(replay_memory, min(len(replay_memory), opt.batch_size))
        state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = zip(*batch)

        state_batch = torch.cat(tuple(state for state in state_batch))
        action_batch = torch.from_numpy(
            np.array([[1, 0] if action == 0 else [0, 1] for action in action_batch], dtype=np.float32))
        reward_batch = torch.from_numpy(np.array(reward_batch, dtype=np.float32)[:, None])
        next_state_batch = torch.cat(tuple(state for state in next_state_batch))

        if torch.cuda.is_available():
            state_batch = state_batch.cuda()
            action_batch = action_batch.cuda()
            reward_batch = reward_batch.cuda()
            next_state_batch = next_state_batch.cuda()
        current_prediction_batch = model(state_batch)
        next_prediction_batch = model(next_state_batch)

        # One-step Q-learning target: r for terminal transitions, r + gamma * max_a' Q(s', a') otherwise
        y_batch = torch.cat(
            tuple(reward if terminal else reward + opt.gamma * torch.max(prediction) for reward, terminal, prediction in
                  zip(reward_batch, terminal_batch, next_prediction_batch)))

        # Q-value of the action actually taken (actions are one-hot encoded)
        q_value = torch.sum(current_prediction_batch * action_batch, dim=1)
        optimizer.zero_grad()
        # y_batch = y_batch.detach()
        loss = criterion(q_value, y_batch)
        loss.backward()
        optimizer.step()

        state = next_state
        iter += 1
        print("Iteration: {}/{}, Action: {}, Loss: {}, Epsilon {}, Reward: {}, Q-value: {}".format(
            iter,
            opt.num_iters,
            action,
            loss,
            epsilon, reward, torch.max(prediction)))
        writer.add_scalar('Train/Loss', loss, iter)
        writer.add_scalar('Train/Epsilon', epsilon, iter)
        writer.add_scalar('Train/Reward', reward, iter)
        writer.add_scalar('Train/Q-value', torch.max(prediction), iter)
        if (iter + 1) % 1000000 == 0:
            torch.save(model, "{}/flappy_bird_{}".format(opt.saved_path, iter + 1))
    torch.save(model, "{}/flappy_bird".format(opt.saved_path))


if __name__ == "__main__":
    opt = get_args()
    train(opt)
```
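
Two details in train() deserve a brief note. The exploration rate epsilon is annealed linearly from initial_epsilon (0.1) down to final_epsilon (1e-4) over num_iters steps, so the agent acts more and more greedily as training proceeds. And the y_batch / q_value / criterion lines implement the standard one-step Q-learning regression target, with B the sampled batch size:

```latex
y_i =
\begin{cases}
  r_i & \text{if the transition is terminal} \\
  r_i + \gamma \max_{a'} Q(s'_i, a') & \text{otherwise}
\end{cases}
\qquad
\mathcal{L} = \frac{1}{B} \sum_{i=1}^{B} \bigl(Q(s_i, a_i) - y_i\bigr)^2
```

Minimizing this mean squared error pulls the predicted Q-value of the chosen action toward the bootstrapped target.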