• 机器学习 Q-Learning


    对马尔可夫奖励的理解

    看的这个教程

    • 公式:V(s) = R(s) + γ * V(s’)
      V(s) 代表当前状态 s 的价值。
      R(s) 代表从状态 s 到下一个状态 s’ 执行某个动作后所获得的即时奖励。
      γ 是折扣因子,它表示未来奖励的重要性,通常取值在 0 到 1 之间。
      V(s’) 代表下一个状态 s’ 的价值。
    • 理解
      这个图片是个例子,便于讲解问题
    1. 如果折扣因子γ为1,那么从现在开始,一直到结束,所有的即时奖励加在一起就是当前状态的价值。所以,现在的价值是以后的所有即时奖励决定的。但是,实际中,γ是0到1的一个小数。就是说,相同的动作,离现在越远,带来的收益越小。还有,我发现,终点是没有价值的,或者他的价值对于算法没有帮助,只是终点前一步到终点这个动作,或者状态转移产生了一个大的奖励。不知道对不对。请大家提出意见。
    2. 假设我们把所有的状态价值放在一个shape为(16,4)的表格里,我们把它称为Q表。16代表16个格子,4代表每一个动作。(数字是16,4是因为图片有16个格子,每个格子都能执行四个动作,这里只是举个简单的例子,你有多少种状态和有几个动作都没有关系,可以随便改,只要合理)。初始值都为0。就是说当前所有位置的所有动作的价值都为0。
    3. 在这个格子里,我们的目的是走到终点。规则是,每次任意方向走一步,走到终点胜利,走到陷阱,就失败。胜利与失败就结束游戏。胜利,这次游戏的一分,失败则是得-100分。每走一步扣一分。
    4. 要知道,Q表的所有格子初始值为0,是不符合现实的,那么,怎么把值逐步更改为现实中对应的值呢?
    5. 假设,我们走对了一次,倒数第二个格子,在向终点方向的那个动作就有了价值(不是0了,而且大于0)。
    6. 假设,我们走错了一次,那么走错的倒数第二个格子,向陷阱走的那个动作就有了价值,(不是0,并且小于0)。这样打完一局游戏,不论走对还是走错,都会产生1个有价值的格子。如果这个格子不是起点,那就肯定还有倒数第三个格子,根据公式,倒数第三个格子的那个方向价值也能算出来。如果倒数第三个格子不是起点…就这样,一点一点的“辐射”。所有的,走过的格子都有了价值。
    7. 假如走到了一个格子,我们只要查Q表,就能知道,往哪里走比较安全,能通向终点,往哪里走比较危险,会掉进陷阱。所以Q表会指引我们,走向正确的道路,避开危险的道路。
    8. 算法成立的前提是,有过走成功的经历,这样才会把最终的那个奖励,“扩散”到起点。
    9. 实际上,我们不是直接从终点扩散的,而是直接采样足够多的样本,一点点更新Q表。比如,我们采样到一步数据,拿Q表查询当前状态的当前动作的价值(V(s) )计作A,还有查询下一个状态的价值(V(s’))计作B。再拿到这一步的奖励R(s)计作R,假设折扣是0.9,那么A = R+0.9*B 。看到没有,是未来的价值决定现在的价值。如果Q表是正确的,这个等式就成立,但是我们会发现有误差,所以,我们得计算出误差(等式右边减去左边),误差 = (R+0.9*B - A)0.1,0.1是学习率,再拿这个误差更新A,就是Q表中,当前的状态这个动作的价值。这样,Q表就会距离理想中的绝对正确的Q表更进一步了。至于为什么有学习率,我的理解是,R+0.9*B这个东西也是估算出来的,不是真正的值,(但是按道理他是和奖励R决定A的),所以只取用他的影响*,不取用他真正的值。**(大家可以谈谈自己的看法,本人能力尚浅)**什么是影响,我也不清楚,可能在这个领域有他的名字,只是我不知道,或者没有察觉出是哪个概念。

    关于陷阱的作用

    在这里插入图片描述

    1. 加入把打叉的都变成陷阱,那么,我们就会更快的到达终点,因为走进陷阱后,Q表就不会让他再次掉进陷阱。所以说,陷阱在某种程度上,帮助我们接近终点。有不同意见,可以提出来,让大家讨论。

    代码,上面的链接里有完整版。还有视频,我也是从B站找到的

    • 这个代码在2023-10-11 跑成功过
    • gym== 0.26.2
    • python == 3.9
    • ipython == 8.16.1
    • ipython-genutils == 0.2.0 (不确定有没有用到)
    • 用的conda(这个倒是无所谓)
    import random
    
    import gym
    import numpy as np
    from IPython import display
    
    
    class NasWrapper(gym.Wrapper):
        def __init__(self):
            env = gym.make('FrozenLake-v1',
                           render_mode='rgb_array',
                           is_slippery=False)
            super(NasWrapper, self).__init__(env)
            self.env = env
    
        def reset(self):
            state, _ = self.env.reset()
            return state
    
        def step(self, action):
            state, reward, terminated, truncated, info = self.env.step(action)
            over = terminated or truncated
            if not over:
                reward = -1
            # elif reward == 1:
            #     reward = 100
            if over and reward == 0:
                reward = -100
            return state, reward, over
    
        def show(self):
            from matplotlib import pyplot as plt
            plt.figure(figsize=(3, 3))
            plt.imshow(self.env.render())
            plt.show()
    
    
    nw = NasWrapper()
    Q = np.zeros((16, 4))
    
    
    def play(isShow=False):
        data = []
        reword_sum = 0
        state = nw.reset()
        over = False
        while not over:
            action = Q[state].argmax()
            if random.random() < 0.1:
                action = nw.action_space.sample()
    
            next_state, reward, over = nw.step(action)
            reword_sum += reward
            data.append((state, action, reward, next_state, over))
    
            state = next_state
    
            if isShow:
                display.clear_output(wait=True)
                nw.show()
    
        return data, reword_sum
    
    
    class Pool():
        def __init__(self):
            self.pool = []
    
        def __len__(self):
            return len(self.pool)
    
        def __getitem__(self, item):
            return self.pool[item]
    
        def update(self):
            old_len = len(pool)
            while len(pool) - old_len < 200:
                self.pool.extend(play()[0])
            self.pool = self.pool[-10000:]
    
        # 获取一批数据样本
        def sample(self):
            return random.choice(self.pool)
    
    
    pool = Pool()
    
    
    # pool.update()
    
    
    def train():
        for epoch in range(100):
            pool.update()
            for i in range(100):
                state, action, reward, next_state, over = pool.sample()
    
                value = Q[state, action]
                target = Q[next_state].max() * 0.9 + reward
                update = (target - value) * 0.1
                Q[state, action] += update
            if epoch % 100 == 0:
                print(epoch, len(pool), play()[-1])
    
    
    train()
    print("train ok")
    print(Q)
    play(isShow=True)
    # nw.reset()
    # while True:
    #     inputNumber = input()
    #     print("---")
    #     nw.step(int(inputNumber))
    #     nw.show()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
  • 相关阅读:
    Kubernetes入门到精通-基础知识
    nginx 配置少一个‘/‘引起 “detail“:“Not Found“
    IP地址与在线教育平台资源分配优化
    03-Nginx性能调优与零拷贝
    SpringBoot中HttpClient的使用
    WPF知识小结(2)
    LayaBox---TypeScript---泛型
    python绘制混淆矩阵
    单向链表浅析(小学生都能看懂)
    alpha模型:打开量化投资的黑箱;附创业板布林带策略代码:年化15%。
  • 原文地址:https://blog.csdn.net/sinat_40387150/article/details/133790507