目录
用深度强化学习来玩Chrome小恐龙快跑

- import os
- import cv2
- from pygame import RLEACCEL
- from pygame.image import load
- from pygame.sprite import Sprite, Group, collide_mask
- from pygame import Rect, init, time, display, mixer, transform, Surface
- from pygame.surfarray import array3d
- import torch
- from random import randrange, choice
- import numpy as np
-
# Register the mixer defaults before pygame itself is initialised.
mixer.pre_init(44100, -16, 2, 2048)
init()

# Window geometry, frame-rate cap and the per-frame gravity applied to jumps.
width, height = 600, 150
scr_size = (width, height)
FPS = 60
gravity = 0.6

# RGB colours used when drawing.
black = (0, 0, 0)
white = (255, 255, 255)
background_col = (235, 235, 235)

high_score = 0

# Shared drawing surface and the clock that throttles each step to FPS.
screen = display.set_mode(scr_size)
clock = time.Clock()
display.set_caption("T-Rex Rush")
-
-
def load_image(
    name,
    sizex=-1,
    sizey=-1,
    colorkey=None,
):
    """Load one image from ``assets/sprites`` and optionally rescale it.

    Args:
        name: File name relative to ``assets/sprites``.
        sizex: Target width in pixels, or -1 to keep the loaded width.
        sizey: Target height in pixels, or -1 to keep the loaded height.
        colorkey: Transparent colour. ``None`` disables colour-keying and -1
            uses the colour of the top-left pixel.

    Returns:
        A ``(surface, rect)`` tuple for the loaded image.
    """
    fullname = os.path.join("assets/sprites", name)
    image = load(fullname).convert()

    if colorkey is not None:
        # Compare by value; ``is -1`` only worked through int caching.
        if colorkey == -1:
            colorkey = image.get_at((0, 0))
        image.set_colorkey(colorkey, RLEACCEL)

    if sizex != -1 or sizey != -1:
        # A single -1 means "keep this dimension", never a negative size.
        target_w = image.get_width() if sizex == -1 else sizex
        target_h = image.get_height() if sizey == -1 else sizey
        image = transform.scale(image, (target_w, target_h))

    return (image, image.get_rect())
-
-
def load_sprite_sheet(
    sheetname,
    nx,
    ny,
    scalex=-1,
    scaley=-1,
    colorkey=None,
):
    """Cut a sprite sheet from ``assets/sprites`` into individual frames.

    ``nx`` is either an int (each of the ``ny`` rows holds ``nx`` equally wide
    frames) or a sequence of ints. In the sequence form every entry ``k``
    contributes ``k`` frames of width ``sheet_width / k``, measured from the
    left edge of the row, and ``scalex`` may give one target width per entry.

    Args:
        sheetname: File name relative to ``assets/sprites``.
        nx: Frame count per row, or a sequence of frame counts.
        ny: Number of rows in the sheet.
        scalex: Target frame width, a sequence of widths (one per ``nx``
            entry), or -1 to keep the cut width. A scalar is applied to every
            ``nx`` entry.
        scaley: Target frame height, or -1 to keep the cut height.
        colorkey: Transparent colour. ``None`` disables colour-keying; -1
            samples the top-left pixel of the first frame and reuses that
            colour for all later frames.

    Returns:
        ``(sprites, rect)`` where ``rect`` is the rect of the first frame.
    """
    fullname = os.path.join("assets/sprites", sheetname)
    sheet = load(fullname).convert()
    sheet_rect = sheet.get_rect()

    # Normalise the scalar form to the sequence form so one loop serves both.
    if isinstance(nx, int):
        counts = [nx]
    else:
        counts = list(nx)
    if isinstance(scalex, (list, tuple)):
        widths = list(scalex)
    else:
        widths = [scalex] * len(counts)

    sizey = sheet_rect.height / ny
    sprites = []
    for i in range(ny):
        for count, i_scalex in zip(counts, widths):
            sizex = sheet_rect.width / count
            for j in range(count):
                rect = Rect((j * sizex, i * sizey, sizex, sizey))
                image = Surface(rect.size).convert()
                image.blit(sheet, (0, 0), rect)

                if colorkey is not None:
                    # Compare by value; ``is -1`` only worked through int caching.
                    if colorkey == -1:
                        colorkey = image.get_at((0, 0))
                    image.set_colorkey(colorkey, RLEACCEL)

                if i_scalex != -1 or scaley != -1:
                    # A single -1 keeps that dimension of the cut frame.
                    target_w = rect.width if i_scalex == -1 else i_scalex
                    target_h = rect.height if scaley == -1 else scaley
                    image = transform.scale(image, (target_w, target_h))

                sprites.append(image)

    sprite_rect = sprites[0].get_rect()

    return sprites, sprite_rect
-
-
def extractDigits(number):
    """Split a non-negative integer into its decimal digits.

    Args:
        number: The value to split.

    Returns:
        The digits, most significant first, left-padded with zeros to at
        least five entries. Numbers with more than five digits return all of
        their digits. Negative input returns ``None``.
    """
    if number > -1:
        number = int(number)
        digits = []
        # Integer arithmetic only: true division here used to run the loop one
        # step too far and prepend a spurious zero (12345 came back as 012345).
        while number >= 10:
            number, digit = divmod(number, 10)
            digits.append(digit)
        digits.append(number)
        digits.extend([0] * (5 - len(digits)))
        digits.reverse()
        return digits
    return None
-
-
def pre_processing(image, w=84, h=84):
    """Reduce a raw frame to a binarised single-channel network input.

    The frame is cropped to its first 300 entries along axis 0, resized to
    ``(w, h)``, converted to greyscale and thresholded at 127.

    Returns:
        A float32 array with one extra leading axis of length 1.
    """
    cropped = image[:300, :, :]
    resized = cv2.resize(cropped, (w, h))
    gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
    binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)[1]
    return binary[None, :, :].astype(np.float32)
-
-
class Dino():
    """The player character: animation frames, jump physics and score."""

    def __init__(self, sizex=-1, sizey=-1):
        self.images, self.rect = load_sprite_sheet("dino.png", 5, 1, sizex, sizey, -1)
        self.images1, self.rect1 = load_sprite_sheet("dino_ducking.png", 2, 1, 59, sizey, -1)
        self.rect.bottom = int(0.98 * height)
        self.rect.left = width / 15
        self.image = self.images[0]

        # Animation and scoring counters.
        self.index = 0
        self.counter = 0
        self.score = 0

        # State flags.
        self.isJumping = False
        self.isDead = False
        self.isDucking = False
        self.isBlinking = False

        # Per-frame displacement and initial upward jump speed.
        self.movement = [0, 0]
        self.jumpSpeed = 11.5

        # Rect widths for the standing and ducking poses.
        self.stand_pos_width = self.rect.width
        self.duck_pos_width = self.rect1.width

    def draw(self):
        """Blit the current frame onto the shared screen."""
        screen.blit(self.image, self.rect)

    def checkbounds(self):
        """Clamp the dino to the ground line and end the jump on landing."""
        floor = int(0.98 * height)
        if self.rect.bottom > floor:
            self.rect.bottom = floor
            self.isJumping = False

    def update(self):
        """Advance animation, position, score and frame counter by one tick."""
        tick = self.counter

        if self.isJumping:
            # Gravity only acts while airborne; the jump shows frame 0.
            self.movement[1] += gravity
            self.index = 0
        elif self.isBlinking:
            # Frame 0 is held for 400 ticks, the other frame for 20.
            period = 400 if self.index == 0 else 20
            if tick % period == period - 1:
                self.index = (self.index + 1) % 2
        elif tick % 5 == 0:
            # Ducking alternates frames 0/1; running alternates frames 2/3.
            flipped = (self.index + 1) % 2
            self.index = flipped if self.isDucking else flipped + 2

        if self.isDead:
            self.index = 4

        if self.isDucking:
            self.image = self.images1[self.index % 2]
            self.rect.width = self.duck_pos_width
        else:
            self.image = self.images[self.index]
            self.rect.width = self.stand_pos_width

        self.rect = self.rect.move(self.movement)
        self.checkbounds()

        if tick % 7 == 6 and not self.isDead and not self.isBlinking:
            self.score += 1

        self.counter += 1
-
-
class Cactus(Sprite):
    """Cactus obstacle that scrolls left and removes itself once off-screen.

    ``Cactus.containers`` must be assigned before instantiation; it names the
    sprite group(s) every new cactus joins.
    """

    def __init__(self, speed=5, sizex=-1, sizey=-1):
        """Create a cactus just beyond the right edge of the screen.

        Args:
            speed: Pixels moved to the left per update.
            sizex: Target widths forwarded to ``load_sprite_sheet``.
            sizey: Target height forwarded to ``load_sprite_sheet``.
        """
        super().__init__(self.containers)
        self.images, self.rect = load_sprite_sheet("cacti-small.png", [2, 3, 6], 1, sizex, sizey, -1)
        self.rect.bottom = int(0.98 * height)
        self.rect.left = width + self.rect.width
        # Bound the index by what the sheet actually produced, not a literal 11.
        # NOTE(review): self.rect comes from the first frame of the sheet, not
        # from the frame chosen here; confirm that is intended.
        self.image = self.images[randrange(0, len(self.images))]
        self.movement = [-1 * speed, 0]

    def draw(self):
        """Blit the cactus onto the shared screen."""
        screen.blit(self.image, self.rect)

    def update(self):
        """Scroll left and drop out of all groups once fully off-screen."""
        self.rect = self.rect.move(self.movement)

        if self.rect.right < 0:
            self.kill()
-
-
class Ptera(Sprite):
    """Flying obstacle with a two-frame flap animation.

    ``Ptera.containers`` must be assigned before instantiation.
    """

    def __init__(self, speed=5, sizex=-1, sizey=-1):
        super().__init__(self.containers)
        self.images, self.rect = load_sprite_sheet("ptera.png", 2, 1, sizex, sizey, -1)
        # Candidate flight heights, expressed as fractions of the screen height.
        self.ptera_height = [height * 0.82, height * 0.75, height * 0.60, height * 0.48]
        self.rect.centery = self.ptera_height[randrange(0, 4)]
        self.rect.left = width + self.rect.width
        self.image = self.images[0]
        self.movement = [-1 * speed, 0]
        self.index = 0
        self.counter = 0

    def draw(self):
        """Blit the current frame onto the shared screen."""
        screen.blit(self.image, self.rect)

    def update(self):
        """Flap every tenth tick, scroll left, and vanish once off-screen."""
        if self.counter % 10 == 0:
            self.index = (self.index + 1) % 2
        self.image = self.images[self.index]
        self.rect = self.rect.move(self.movement)
        self.counter += 1
        if self.rect.right < 0:
            self.kill()
-
-
class Ground():
    """Endlessly scrolling ground strip built from two copies of one image."""

    def __init__(self, speed=-5):
        self.image, self.rect = load_image("ground.png", -1, -1, -1)
        self.image1, self.rect1 = load_image("ground.png", -1, -1, -1)
        # Both tiles sit on the bottom edge; the second starts where the first ends.
        self.rect.bottom = height
        self.rect1.bottom = height
        self.rect1.left = self.rect.right
        self.speed = speed

    def draw(self):
        """Blit both ground tiles onto the shared screen."""
        for tile_image, tile_rect in ((self.image, self.rect), (self.image1, self.rect1)):
            screen.blit(tile_image, tile_rect)

    def update(self):
        """Shift both tiles and move a tile that left the screen behind the other."""
        for tile_rect in (self.rect, self.rect1):
            tile_rect.left += self.speed

        if self.rect.right < 0:
            self.rect.left = self.rect1.right

        if self.rect1.right < 0:
            self.rect1.left = self.rect.right
-
-
class Cloud(Sprite):
    """Decorative cloud drifting left at one pixel per update.

    ``Cloud.containers`` must be assigned before instantiation.
    """

    def __init__(self, x, y):
        super().__init__(self.containers)
        self.image, self.rect = load_image("cloud.png", int(90 * 30 / 42), 30, -1)
        self.speed = 1
        self.rect.left = x
        self.rect.top = y
        self.movement = [-1 * self.speed, 0]

    def draw(self):
        """Blit the cloud onto the shared screen."""
        screen.blit(self.image, self.rect)

    def update(self):
        """Drift left and leave all groups once fully off-screen."""
        self.rect = self.rect.move(self.movement)
        if self.rect.right < 0:
            self.kill()
-
-
class Scoreboard():
    """Five-glyph score display rendered from the ``numbers.png`` sheet."""

    def __init__(self, x=-1, y=-1):
        """Create the board.

        Args:
            x: Left position in pixels, or -1 for 89% of the screen width.
            y: Top position in pixels, or -1 for 10% of the screen height.
        """
        self.score = 0
        self.tempimages, self.temprect = load_sprite_sheet("numbers.png", 12, 1, 11, int(11 * 6 / 5), -1)
        # 55 px wide: room for exactly five 11 px glyphs.
        self.image = Surface((55, int(11 * 6 / 5)))
        self.rect = self.image.get_rect()
        if x == -1:
            self.rect.left = width * 0.89
        else:
            self.rect.left = x
        if y == -1:
            self.rect.top = height * 0.1
        else:
            self.rect.top = y

    def draw(self):
        """Blit the rendered score onto the shared screen."""
        screen.blit(self.image, self.rect)

    def update(self, score):
        """Re-render the board so that it shows ``score``."""
        score_digits = extractDigits(score)
        self.image.fill(background_col)
        # Keep the five least-significant digits whatever the list length is;
        # trimming only the exactly-six case let longer lists run off the surface.
        for s in score_digits[-5:]:
            self.image.blit(self.tempimages[s], self.temprect)
            self.temprect.left += self.temprect.width
        # Rewind the glyph cursor for the next render.
        self.temprect.left = 0
-
-
class ChromeDino(object):
    """Chrome T-Rex game exposed as a frame-by-frame learning environment."""

    def __init__(self):
        """Build a fresh game; also called again to reset after a game over."""
        self.gamespeed = 5
        self.gameOver = False
        self.gameQuit = False
        self.playerDino = Dino(44, 47)
        self.new_ground = Ground(-1 * self.gamespeed)
        self.scb = Scoreboard()
        self.highsc = Scoreboard(width * 0.78)
        self.counter = 0

        self.cacti = Group()
        self.pteras = Group()
        self.clouds = Group()
        self.last_obstacle = Group()

        # New sprites add themselves to these groups on construction.
        Cactus.containers = self.cacti
        Ptera.containers = self.pteras
        Cloud.containers = self.clouds

        self.retbutton_image, self.retbutton_rect = load_image("replay_button.png", 35, 31, -1)
        self.gameover_image, self.gameover_rect = load_image("game_over.png", 190, 11, -1)

        # Pre-render the "HI" label from glyphs 10 and 11 of the number sheet.
        self.temp_images, self.temp_rect = load_sprite_sheet("numbers.png", 12, 1, 11, int(11 * 6 / 5), -1)
        self.HI_image = Surface((22, int(11 * 6 / 5)))
        self.HI_rect = self.HI_image.get_rect()
        self.HI_image.fill(background_col)
        self.HI_image.blit(self.temp_images[10], self.temp_rect)
        self.temp_rect.left += self.temp_rect.width
        self.HI_image.blit(self.temp_images[11], self.temp_rect)
        self.HI_rect.top = height * 0.1
        self.HI_rect.left = width * 0.73

    def _score_obstacles(self, obstacles, reward):
        """Sync obstacle speed and return the reward after hit/pass checks.

        Returns -1 (and marks the dino dead) on the first collision, 1 on the
        first obstacle whose right edge has just cleared the dino's left edge,
        otherwise ``reward`` unchanged. Obstacles after the deciding one are
        not visited.
        """
        dino = self.playerDino
        for obstacle in obstacles:
            obstacle.movement[0] = -1 * self.gamespeed
            if collide_mask(dino, obstacle):
                dino.isDead = True
                return -1
            if obstacle.rect.right < dino.rect.left < obstacle.rect.right + self.gamespeed + 1:
                return 1
        return reward

    def step(self, action, record=False):
        """Advance the game by one frame.

        Args:
            action: 0 does nothing, 1 jumps, 2 ducks.
            record: When true, also return the rendered colour frame.

        Returns:
            ``(state, reward, done)``, or ``(state, frame, reward, done)`` when
            ``record`` is true. ``state`` is the pre-processed frame as a
            tensor. ``reward`` is -1 on a collision, 1 when an obstacle has
            just been passed, and otherwise 0.1 (plus 0.01 for doing nothing).
            ``done`` is true whenever the reward is not positive. The game
            resets itself after a collision.
        """
        ground_y = int(0.98 * height)
        dino = self.playerDino

        reward = 0.1
        if action == 0:
            reward += 0.01
            dino.isDucking = False
        elif action == 1:
            dino.isDucking = False
            # A jump can only start from the ground.
            if dino.rect.bottom == ground_y:
                dino.isJumping = True
                dino.movement[1] = -1 * dino.jumpSpeed
        elif action == 2:
            if not (dino.isJumping and dino.isDead) and dino.rect.bottom == ground_y:
                dino.isDucking = True

        # Pteras are checked second, so their outcome overrides the cacti's.
        reward = self._score_obstacles(self.cacti, reward)
        reward = self._score_obstacles(self.pteras, reward)

        if len(self.cacti) < 2:
            if len(self.cacti) == 0 and len(self.pteras) == 0:
                self.last_obstacle.empty()
                self.last_obstacle.add(Cactus(self.gamespeed, [60, 40, 20], choice([40, 45, 50])))
            else:
                for last in self.last_obstacle:
                    if last.rect.right < width * 0.7 and randrange(0, 50) == 10:
                        self.last_obstacle.empty()
                        self.last_obstacle.add(Cactus(self.gamespeed, [60, 40, 20], choice([40, 45, 50])))

        if len(self.pteras) == 0 and len(self.cacti) < 2 and randrange(0, 50) == 10 and self.counter > 500:
            for last in self.last_obstacle:
                if last.rect.right < width * 0.8:
                    self.last_obstacle.empty()
                    self.last_obstacle.add(Ptera(self.gamespeed, 46, 40))

        if len(self.clouds) < 5 and randrange(0, 300) == 10:
            # Integer bounds: randrange rejects float arguments on Python 3.12+.
            Cloud(width, randrange(height // 5, height // 2))

        dino.update()
        self.cacti.update()
        self.pteras.update()
        self.clouds.update()
        self.new_ground.update()
        self.scb.update(dino.score)

        # Grab the surface reference; its pixels are read after drawing below.
        state = display.get_surface()
        screen.fill(background_col)
        self.new_ground.draw()
        self.clouds.draw(screen)
        self.scb.draw()
        self.cacti.draw(screen)
        self.pteras.draw(screen)
        dino.draw()

        display.update()
        clock.tick(FPS)

        if dino.isDead:
            self.gameOver = True

        self.counter += 1

        if self.gameOver:
            # Reset in place so the next step starts a new episode.
            self.__init__()

        state = array3d(state)
        done = not (reward > 0)
        if record:
            frame = np.transpose(cv2.cvtColor(state, cv2.COLOR_RGB2BGR), (1, 0, 2))
            return torch.from_numpy(pre_processing(state)), frame, reward, done
        return torch.from_numpy(pre_processing(state)), reward, done
- import torch.nn as nn
-
class DeepQNetwork(nn.Module):
    """Convolutional Q-network: 4 input channels in, 3 action values out.

    The first linear layer expects 7 * 7 * 64 features, which is what the
    three convolutions produce for 84 x 84 inputs.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(inplace=True),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
        )

        self.fc1 = nn.Sequential(
            nn.Linear(7 * 7 * 64, 512),
            nn.ReLU(inplace=True),
        )
        self.fc2 = nn.Linear(512, 3)
        self._initialize_weights()

    def _initialize_weights(self):
        """Draw conv/linear weights from U(-0.01, 0.01) and zero the biases."""
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                nn.init.uniform_(layer.weight, -0.01, 0.01)
                nn.init.constant_(layer.bias, 0)

    def forward(self, input):
        """Return one row of three action values per batch element."""
        features = input
        for stage in (self.conv1, self.conv2, self.conv3):
            features = stage(features)
        flat = features.view(features.size(0), -1)
        return self.fc2(self.fc1(flat))
- import argparse
- import torch
-
- from src.model import DeepQNetwork
- from src.env import ChromeDino
- import cv2
-
-
def get_args():
    """Parse the command-line options of the playback/recording script.

    Returns:
        The parsed namespace with ``saved_path``, ``fps`` and ``output``.
    """
    # Pass the text as ``description``; given positionally it becomes ``prog``
    # and replaces the script name in the usage line.
    parser = argparse.ArgumentParser(
        description="""Implementation of Deep Q Network to play Chrome Dino""")
    parser.add_argument("--saved_path", type=str, default="trained_models")
    parser.add_argument("--fps", type=int, default=60, help="frames per second")
    parser.add_argument("--output", type=str, default="output/chrome_dino.mp4", help="the path to output video")

    args = parser.parse_args()
    return args
-
-
def q_test(opt):
    """Play one episode greedily with a trained network and record it.

    Args:
        opt: Namespace with ``saved_path`` (folder holding
            ``chrome_dino.pth``), ``fps`` and ``output`` (video path).

    The episode ends on the first step whose ``done`` flag is true. The video
    writer is released even if playback fails.
    """
    import os  # local import: this script section does not import os itself

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
    device = torch.device("cuda" if use_cuda else "cpu")

    model = DeepQNetwork()
    checkpoint_path = "{}/chrome_dino.pth".format(opt.saved_path)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    env = ChromeDino()
    state, _, _, _ = env.step(0, True)
    # Seed the 4-frame stack by repeating the first frame.
    state = torch.cat(tuple(state for _ in range(4)))[None, :, :, :].to(device)

    output_dir = os.path.dirname(opt.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    out = cv2.VideoWriter(opt.output, cv2.VideoWriter_fourcc(*"MJPG"), opt.fps, (600, 150))
    try:
        done = False
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            while not done:
                prediction = model(state)[0]
                action = torch.argmax(prediction).item()
                next_state, raw_next_state, reward, done = env.step(action, True)
                out.write(raw_next_state)
                next_state = next_state.to(device)
                # Drop the oldest frame and append the newest.
                state = torch.cat((state[0, 1:, :, :], next_state))[None, :, :, :]
    finally:
        # Release the writer so the video file is finalised.
        out.release()
-
-
-
if __name__ == "__main__":
    # Script entry point: parse the CLI options and play one recorded episode.
    q_test(get_args())
- import argparse
- import os
- from random import random, randint, sample
- import pickle
- import numpy as np
- import torch
- import torch.nn as nn
-
- from src.model import DeepQNetwork
- from src.env import ChromeDino
-
-
def get_args():
    """Parse the command-line options of the training script.

    Returns:
        The parsed namespace with the training hyper-parameters and the
        folder used for checkpoints.
    """
    # Pass the text as ``description``; given positionally it becomes ``prog``
    # and replaces the script name in the usage line.
    parser = argparse.ArgumentParser(
        description="""Implementation of Deep Q Network to play Chrome Dino""")
    parser.add_argument("--batch_size", type=int, default=64, help="The number of images per batch")
    parser.add_argument("--optimizer", type=str, choices=["sgd", "adam"], default="adam")
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--initial_epsilon", type=float, default=0.1)
    parser.add_argument("--final_epsilon", type=float, default=1e-4)
    parser.add_argument("--num_decay_iters", type=float, default=2000000)
    parser.add_argument("--num_iters", type=int, default=2000000)
    parser.add_argument("--replay_memory_size", type=int, default=50000,
                        help="Maximum number of transitions kept in the replay memory")
    parser.add_argument("--saved_folder", type=str, default="trained_models")

    args = parser.parse_args()
    return args
-
-
def train(opt):
    """Train the Deep Q-Network on the Chrome Dino environment.

    Args:
        opt: Namespace with ``batch_size``, ``optimizer``, ``lr``, ``gamma``,
            ``initial_epsilon``, ``final_epsilon``, ``num_decay_iters``,
            ``num_iters``, ``replay_memory_size`` and ``saved_folder``.

    Training resumes from ``chrome_dino.pth`` and ``replay_memory.pkl`` in
    ``opt.saved_folder`` when those files exist. Both are rewritten every
    50000 iterations.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
    device = torch.device("cuda" if use_cuda else "cpu")

    model = DeepQNetwork().to(device)
    # Honour --optimizer; Adam used to be chosen regardless of the flag.
    if getattr(opt, "optimizer", "adam") == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    os.makedirs(opt.saved_folder, exist_ok=True)
    checkpoint_path = os.path.join(opt.saved_folder, "chrome_dino.pth")
    memory_path = os.path.join(opt.saved_folder, "replay_memory.pkl")
    if os.path.isfile(checkpoint_path):
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        iteration = checkpoint["iter"] + 1
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        print("Load trained model from iteration {}".format(iteration))
    else:
        iteration = 0
    if os.path.isfile(memory_path):
        # NOTE: unpickling runs arbitrary code; only resume from a replay
        # memory file that this script wrote itself.
        with open(memory_path, "rb") as f:
            replay_memory = pickle.load(f)
        print("Load replay memory")
    else:
        replay_memory = []

    criterion = nn.MSELoss()
    env = ChromeDino()
    state, _, _ = env.step(0)
    # Seed the 4-frame stack by repeating the first frame.
    state = torch.cat(tuple(state for _ in range(4)))[None, :, :, :]

    while iteration < opt.num_iters:
        # Action selection needs no gradients.
        with torch.no_grad():
            prediction = model(state.to(device))[0]

        # Exploration or exploitation: epsilon decays linearly to final_epsilon.
        epsilon = opt.final_epsilon + (
            max(opt.num_decay_iters - iteration, 0) * (opt.initial_epsilon - opt.final_epsilon)
            / opt.num_decay_iters)
        if random() <= epsilon:
            action = randint(0, 2)
        else:
            action = torch.argmax(prediction).item()

        next_state, reward, done = env.step(action)
        # Drop the oldest frame and append the newest.
        next_state = torch.cat((state[0, 1:, :, :], next_state))[None, :, :, :]
        replay_memory.append([state, action, reward, next_state, done])
        if len(replay_memory) > opt.replay_memory_size:
            del replay_memory[0]

        batch = sample(replay_memory, min(len(replay_memory), opt.batch_size))
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*batch)

        state_batch = torch.cat(state_batch).to(device)
        next_state_batch = torch.cat(next_state_batch).to(device)
        action_batch = torch.tensor(action_batch, dtype=torch.int64, device=device)
        reward_batch = torch.tensor(reward_batch, dtype=torch.float32, device=device)
        # 0 for terminal transitions so they contribute no bootstrapped value.
        not_done = torch.tensor([0.0 if d else 1.0 for d in done_batch],
                                dtype=torch.float32, device=device)

        # Q(s, a) for the actions that were actually taken.
        q_value = model(state_batch).gather(1, action_batch[:, None]).squeeze(1)
        # The TD target is a constant: no gradient may flow through it.
        with torch.no_grad():
            next_q_max = model(next_state_batch).max(dim=1)[0]
        y_batch = reward_batch + opt.gamma * next_q_max * not_done

        optimizer.zero_grad()
        loss = criterion(q_value, y_batch)
        loss.backward()
        optimizer.step()

        state = next_state
        iteration += 1
        # Report the number of completed iterations (was off by one).
        print("Iteration: {}/{}, Loss: {:.5f}, Epsilon {:.5f}, Reward: {}".format(
            iteration,
            opt.num_iters,
            loss.item(),
            epsilon, reward))
        if (iteration + 1) % 50000 == 0:
            checkpoint = {"iter": iteration,
                          "model_state_dict": model.state_dict(),
                          "optimizer": optimizer.state_dict()}
            torch.save(checkpoint, checkpoint_path)
            with open(memory_path, "wb") as f:
                pickle.dump(replay_memory, f, protocol=pickle.HIGHEST_PROTOCOL)
-
-
if __name__ == "__main__":
    # Script entry point: parse the CLI options and start training.
    train(get_args())