• Keras中reset_states对stateful的影响探究


    一、reset_states与stateful关系

            在上一篇文章中,我们比较详细的解释了如何正确理解Keras接口中的stateful参数,文章如下:

    Keras中stateful的正确理解Keras中stateful的正确理解https://forecast.blog.csdn.net/article/details/126882142?spm=1001.2014.3001.5502

            同时与stateful紧密联系的还有一个参数是reset_states,这个函数本身作用很简单,就是清除网络中的隐藏状态,其中model.reset_states()是清除整个模型的隐藏状态,layer.reset_states()是清除模型中某一层的隐藏状态。reset_states与stateful在不同情境结合使用时,会出现很多出人意料的结果,而对结果的分析,也更有利于我们去理解stateful的原理和效果。本文也算对stateful的扩展知识。

            这里强调一下,reset_states重置的是每次fit()执行完后LSTM神经单元里的状态(输入门,输出门,遗忘门),而不是重置我们训练的权重矩阵W的参数。也不是重置每个batch之间的状态,因为一个fit()执行完后,所有的batch都已经完成了训练。

            我们通过对一组数据的训练和预测,来观察reset_states与stateful之间的关系,我们使用的训练数据如下所示,显然目的是希望模型能学会字母的前后顺序关系:

    A -> B
    B -> C
    C -> D
    D -> E
    E -> F
    F -> G
    G -> H
    H -> I
    I -> J
    J -> K
    K -> L
    L -> M
    M -> N
    N -> O
    O -> P
    P -> Q
    Q -> R
    R -> S
    S -> T
    T -> U
    U -> V
    V -> W
    W -> X
    X -> Y
    Y -> Z

    二、使用reset_states的效果

    2.1.训练时不使用reset_states

            首先我们写一段正常的LSTM模型代码,这里使用stateful=True,同时在每轮循环中不调用model.reset_states()函数,观察一下结果。

    1. # Stateful LSTM to learn one-char to one-char mapping
    2. import numpy
    3. from keras.models import Sequential
    4. from keras.layers import Dense
    5. from keras.layers import LSTM
    6. from keras.utils import np_utils
    7. # fix random seed for reproducibility
    8. numpy.random.seed(7)
    9. # define the raw dataset
    10. alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    11. # create mapping of characters to integers (0-25) and the reverse
    12. char_to_int = dict((c, i) for i, c in enumerate(alphabet))
    13. int_to_char = dict((i, c) for i, c in enumerate(alphabet))
    14. # prepare the dataset of input to output pairs encoded as integers
    15. seq_length = 1
    16. dataX = []
    17. dataY = []
    18. for i in range(0, len(alphabet) - seq_length, 1):
    19. seq_in = alphabet[i:i + seq_length]
    20. seq_out = alphabet[i + seq_length]
    21. dataX.append([char_to_int[char] for char in seq_in])
    22. dataY.append(char_to_int[seq_out])
    23. print (seq_in, '->', seq_out)
    24. # reshape X to be [samples, time steps, features]
    25. X = numpy.reshape(dataX, (len(dataX), seq_length, 1))
    26. # normalize
    27. X = X / float(len(alphabet))
    28. # one hot encode the output variable
    29. y = np_utils.to_categorical(dataY)
    30. # create and fit the model
    31. batch_size = 1
    32. model = Sequential()
    33. model.add(LSTM(16, batch_input_shape=(batch_size, X.shape[1], X.shape[2]), stateful=True))
    34. model.add(Dense(y.shape[1], activation='softmax'))
    35. model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    36. for i in range(3):
    37. model.fit(X, y, epochs=1, batch_size=batch_size, verbose=2, shuffle=False)
    38. # 每次循环不对模型参数进行重置
    39. #model.reset_states()
    40. # summarize performance of the model
    41. scores = model.evaluate(X, y, batch_size=batch_size, verbose=0)
    42. print("Model Accuracy: %.2f%%" % (scores[1]*100))

            训练过程输出的参数如下,观察可发现,如果每个周期后不执行reset_states,收敛的过程很混乱,推测:应该是因为上次fit的最终状态没有清理,被强行带入了下次fit中,但其实上次fit和本次fit之间没有顺序关系,状态信息传递的混乱也导致权重矩阵W的更新混乱,因此导致无法收敛或收敛缓慢。

    loss: 1.1953 - accuracy: 0.9600

    loss: 1.1738 - accuracy: 1.0000

    loss: 1.1764 - accuracy: 0.9600

    loss: 1.1952 - accuracy: 0.9200

    loss: 1.1817 - accuracy: 0.9600

    loss: 1.1534 - accuracy: 1.0000

    loss: 1.1518 - accuracy: 0.9600

    loss: 1.1786 - accuracy: 0.9200

    loss: 1.1708 - accuracy: 0.9600

    loss: 1.1324 - accuracy: 1.0000

    loss: 1.1267 - accuracy: 0.9600

    loss: 1.1659 - accuracy: 0.8400

    loss: 1.1661 - accuracy: 0.8400

    loss: 1.1109 - accuracy: 1.0000

    loss: 1.1017 - accuracy: 1.0000

    loss: 1.1621 - accuracy: 0.7600

    loss: 1.1776 - accuracy: 0.7600

    loss: 1.0928 - accuracy: 1.0000

    loss: 1.0851 - accuracy: 0.9600

    loss: 1.1787 - accuracy: 0.7200

    loss: 1.2398 - accuracy: 0.4400

    loss: 1.0954 - accuracy: 0.9600

    loss: 1.1317 - accuracy: 0.8000

    loss: 1.1977 - accuracy: 0.6400

    loss: 1.4574 - accuracy: 0.0800

    loss: 1.1274 - accuracy: 0.8400

    2.2.预测前不使用reset_states,顺序预测

            最终在完成300个周期的训练后,我们尝试写一段代码进行预测,本次预测前先不调用reset_states函数重置网络状态:

    1. seed = [char_to_int[alphabet[0]]]
    2. for i in range(0, len(alphabet)-1):
    3. x = numpy.reshape(seed, (1, len(seed), 1))
    4. x = x / float(len(alphabet))
    5. prediction = model.predict(x, verbose=0)
    6. index = numpy.argmax(prediction)
    7. print (int_to_char[seed[0]], "->", int_to_char[index])
    8. seed = [index]

            从预测结果来看,正常从A开始预测,可以观察到预测结果总体好像没有大问题,好像仅仅第一条和最后一条预测不准确,预测结果如下:

    A -> B

    B -> C

    C -> D

    D -> E

    E -> F

    F -> G

    G -> H

    H -> I

    I -> J

    J -> K

    K -> L

    L -> M

    M -> N

    N -> O

    O -> P

    P -> Q

    Q -> R

    R -> S

    S -> T

    T -> U

    U -> V

    V -> W

    W -> X

    X -> Y

    Y -> Z

    Z -> Z

    2.3.预测前不使用reset_states,随机预测

            这次我们尝试对随机字母进行预测,测试代码如下:

    1. letter = "G"
    2. seed = [char_to_int[letter]]
    3. print ("New start: ", letter)
    4. for i in range(0, 5):
    5. x = numpy.reshape(seed, (1, len(seed), 1))
    6. x = x / float(len(alphabet))
    7. prediction = model.predict(x, verbose=0)
    8. index = numpy.argmax(prediction)
    9. print (int_to_char[seed[0]], "->", int_to_char[index])
    10. seed = [index+2]

            然而观测下面的预测结果发现不管输入X是什么,其实结果仅仅是按照字母表顺序在输出,并非正确的X=>y结果。所以刚刚上面的预测结果是一个假象,只是我们的输入X恰巧是从字母A开始。

            因此推测:模型并没有学到X=>y的映射关系,只学到了y的顺序(因为他总是从B开始按照字母表顺序输出值)这其实也是一种过拟合的体现。而导致这种过拟合的原因,暂时猜测是因为我们使用了stateful参数,使得历史信息被持续的向后传播,所以模型没学到当前的输入X和输出y之间的映射关系,反而学到了历史输出序列y的顺序关系。

            这里仍有一个疑问,既然输入Xy都是按照字母表顺序排序,那模型学到的到底是输入X的顺序,还是输出y的顺序?作者这里猜测学到的是输出y的顺序,是因为y是从B开始,而预测值的输出每次也都是从B开始,我们继续通过后续实验来验证。

    New start:  G

    G -> B

    D -> C

    E -> D

    F -> E

    G -> F

    2.4.预测前先进性reset_states,随机预测

            为了验证上面的推测,我们尝试在预测前先调用reset_states()函数,重置模型的状态后,再看看预测结果如何,代码如下:

    1. model.reset_states()
    2. seed = [char_to_int[alphabet[0]]]
    3. for i in range(0, len(alphabet)-1):
    4. x = numpy.reshape(seed, (1, len(seed), 1))
    5. x = x / float(len(alphabet))
    6. prediction = model.predict(x, verbose=0)
    7. index = numpy.argmax(prediction)
    8. print (int_to_char[seed[0]], "->", int_to_char[index])
    9. seed = [index]

            代码在执行预测方法前先调用reset_states()函数重置网络状态,发现此时的预测值就开始混乱了,不再从B开始预测。显然模型中保留的历史状态消失后,预测没有了历史信息做依据,结果就失去了规律,虽然这不能完全说明stateful的机制原理,但能让我们看到stateful确实在不同batch之间的历史状态传播中起了很大作用。

    A -> L

    L -> N

    N -> N

    N -> T

    T -> V

    V -> Y

    Y -> Y

    Y -> Z

    Z -> C

    C -> C

    C -> E

    E -> F

    F -> G

    G -> H

    H -> I

    I -> J

    J -> K

    K -> L

    L -> M

    M -> N

    N -> O

    O -> P

    P -> Q

    Q -> R

    R -> S

    2.5.结果分析

            输入和输出并没有连续关系,模型每次执行新的fit()前没重新reset_states网络状态,导致本次fit()携带上次fit()的最后状态,因此模型的权重矩阵W参数更新也会面临上次fit()带来的干扰,这也解释了为什么在训练过程中模型无法很好收敛。

            每轮模型训练中都残留很多历史状态,当下输入的X信息被弱化变得不再重要,或许这也导致了权重矩阵W只能勉强记住y的顺序,而不是X=>y的映射关系,导致最终模型没有泛化能力。 

    三、在不同位置使用reset_states的效果

    3.1.训练时每轮fit()后使用reset_states

            下面进行更进一步的研究,在训练时每次执行完fit()就调用一次reset_states重置网络状态,使得上次fit()的结果不会对本次fit()产生影响,但仍然保留stateful=True,即每个fit()内部batch之间的状态还是会正常传播。代码如下:

    1. import numpy
    2. from keras.models import Sequential
    3. from keras.layers import Dense
    4. from keras.layers import LSTM
    5. from keras.utils import np_utils
    6. # fix random seed for reproducibility
    7. numpy.random.seed(7)
    8. # define the raw dataset
    9. alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    10. # create mapping of characters to integers (0-25) and the reverse
    11. char_to_int = dict((c, i) for i, c in enumerate(alphabet))
    12. int_to_char = dict((i, c) for i, c in enumerate(alphabet))
    13. # prepare the dataset of input to output pairs encoded as integers
    14. seq_length = 1
    15. dataX = []
    16. dataY = []
    17. for i in range(0, len(alphabet) - seq_length, 1):
    18. seq_in = alphabet[i:i + seq_length]
    19. seq_out = alphabet[i + seq_length]
    20. dataX.append([char_to_int[char] for char in seq_in])
    21. dataY.append(char_to_int[seq_out])
    22. print (seq_in, '->', seq_out)
    23. # reshape X to be [samples, time steps, features]
    24. X = numpy.reshape(dataX, (len(dataX), seq_length, 1))
    25. # normalize
    26. X = X / float(len(alphabet))
    27. # one hot encode the output variable
    28. y = np_utils.to_categorical(dataY)
    29. # create and fit the model
    30. batch_size = 1
    31. model = Sequential()
    32. model.add(LSTM(16, batch_input_shape=(batch_size, X.shape[1], X.shape[2]), stateful=True))
    33. model.add(Dense(y.shape[1], activation='softmax'))
    34. model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    35. for i in range(3):
    36. model.fit(X, y, epochs=1, batch_size=batch_size, verbose=2, shuffle=False)
    37. # 每次循环不对模型参数进行重置
    38. model.reset_states()
    39. # summarize performance of the model
    40. scores = model.evaluate(X, y, batch_size=batch_size, verbose=0)
    41. print("Model Accuracy: %.2f%%" % (scores[1]*100))

            此时训练结果看起来是平稳收敛的,说明reset_states能避免上轮fit()的信息对本轮fit()产生干扰,从而有效的更新权重矩阵W,较快的达到稳定收敛。

    loss: 1.0001 - accuracy: 0.8400

    loss: 0.9941 - accuracy: 0.8400

    loss: 0.9881 - accuracy: 0.8400

    loss: 0.9821 - accuracy: 0.8400

    loss: 0.9763 - accuracy: 0.8400

    loss: 0.9704 - accuracy: 0.8400

    loss: 0.9647 - accuracy: 0.8800

    loss: 0.9589 - accuracy: 0.8800

    loss: 0.9533 - accuracy: 0.8800

    loss: 0.9476 - accuracy: 0.8800

    loss: 0.9420 - accuracy: 0.8800

    loss: 0.9365 - accuracy: 0.8800

    loss: 0.9310 - accuracy: 0.8800

    loss: 0.9255 - accuracy: 0.8800

    loss: 0.9201 - accuracy: 0.9200

    loss: 0.9147 - accuracy: 0.9200

    loss: 0.9093 - accuracy: 0.9200

    loss: 0.9040 - accuracy: 0.9200

    loss: 0.8987 - accuracy: 0.9200

    loss: 0.8934 - accuracy: 0.9200

    loss: 0.8882 - accuracy: 0.9200

    loss: 0.8830 - accuracy: 0.9200

    loss: 0.8778 - accuracy: 0.9200

    loss: 0.8726 - accuracy: 0.9200

    loss: 0.8674 - accuracy: 0.9200

    loss: 0.8623 - accuracy: 0.9600

    loss: 0.8572 - accuracy: 0.9600

    loss: 0.8520 - accuracy: 0.9600

    3.2.预测前不使用reset_states,顺序预测

            首先测试不重置网络状态的情况,代码如下:

    1. seed = [char_to_int[alphabet[0]]]
    2. for i in range(0, len(alphabet)-1):
    3. x = numpy.reshape(seed, (1, len(seed), 1))
    4. x = x / float(len(alphabet))
    5. prediction = model.predict(x, verbose=0)
    6. index = numpy.argmax(prediction)
    7. print (int_to_char[seed[0]], "->", int_to_char[index])
    8. seed = [index]

            从下面结果可以看出,从A开始预测,观察结果发现前半部分预测的全是Z,后半部分预测值又开始慢慢按照字母表顺序输出。这里有一个细节,在预测之前,我们调用了model.evaluate()函数,但在此之后没有重新reset_states。

            因此推测:下面的预测结果输出很多Z,是因为当前预测时受之前的model.evaluate()的影响,因为model.evaluate()的最后一次预测值是Z,所以在没有reset_states的情况下,新一轮的预测参数延续了上一轮Z的输出状态,导致结果多次输出为Z。下面继续进行验证这一推测:

    A -> Z

    Z -> Z

    Z -> Z

    Z -> Z

    Z -> Z

    Z -> Z

    Z -> Z

    Z -> F

    F -> F

    F -> F

    F -> H

    H -> H

    H -> I

    I -> J

    J -> L

    L -> L

    L -> M

    M -> N

    N -> O

    O -> P

    P -> Q

    Q -> R

    R -> S

    S -> T

    T -> U

    3.3.预测前使用reset_states,顺序预测+随机预测对比实验

            为了避免model.evaluate()函数的干扰,我们在预测前调用reset_states(),同时为了避免出现类似上面2.2节的问题,我们先按顺序预测18个字母,然后在随机预测10个字母,代码如下:

    1. # 先按顺序预测18个字母
    2. model.reset_states()
    3. seed = [char_to_int[alphabet[0]]]
    4. for i in range(0, len(alphabet)-1):
    5. x = numpy.reshape(seed, (1, len(seed), 1))
    6. x = x / float(len(alphabet))
    7. prediction = model.predict(x, verbose=0)
    8. index = numpy.argmax(prediction)
    9. print (int_to_char[seed[0]], "->", int_to_char[index])
    10. seed = [index]
    11. if i == 18 :
    12. break
    13. print ("---分割线---")
    14. # 再随机预测10个字母
    15. letter = "C"
    16. seed = [char_to_int[letter]]
    17. print ("New start: ", letter)
    18. for i in range(0, 10):
    19. x = numpy.reshape(seed, (1, len(seed), 1))
    20. x = x / float(len(alphabet))
    21. prediction = model.predict(x, verbose=0)
    22. index = numpy.argmax(prediction)
    23. print (int_to_char[seed[0]], "->", int_to_char[index])
    24. seed = [i+2]

            观察预测结果发现,前18个字母预测看起来是正常的,但后面即使我们改用随机字母,预测结果仍是在按照字母表顺序,延续上一次的输出结果进行输出。这说明和上面2.2节类似,模型学到的其实还是字母表的顺序,没有学到X=>y的映射关系。

            从这个例子中我们能看出reset_states()函数再重置模型参数时产生的效果。并且我们可以得出一个结论,使用stateful参数时,输出结果受历史输入序列的影响十分严重。因为此时模型并不太在意当前输入是什么,而更在意历史信息的输入。

    #前18次预测

    A -> B

    B -> C

    C -> D

    D -> E

    E -> F

    F -> G

    G -> H

    H -> I

    I -> J

    J -> K

    K -> L

    L -> M

    M -> N

    N -> O

    O -> P

    P -> Q

    Q -> R

    R -> S

    S -> T

    ---分割线---

    #后10次随机预测

    New start:  C

    C -> U

    C -> V

    D -> W

    E -> X

    F -> Y

    G -> Z

    H -> Z

    I -> Z

    J -> Z

    K -> Z

    3.4.每预测一个字母,进行一次reset_states

            为了进一步观察reset_states()函数的效果,我们尝试最后一个例子,在每次预测完成后,都调用一次model.reset_states(),重置模型参数,代码如下:

    1. letter = "C"
    2. seed = [char_to_int[letter]]
    3. print ("New start: ", letter)
    4. for i in range(0, 10):
    5. x = numpy.reshape(seed, (1, len(seed), 1))
    6. x = x / float(len(alphabet))
    7. model1.reset_states()
    8. prediction = model.predict(x, verbose=0)
    9. index = numpy.argmax(prediction)
    10. print (int_to_char[seed[0]], "->", int_to_char[index])
    11. seed = [i+2]

            观察如下结果符合预期,即每次执行reset_states()重置模型状态后,预测的结果都会从y的起始点B开始,和当前的输入无关。也进一步说明了前文推测的正确性。

            我们分析出现这种情况的原因大概率是因为训练集过少导致出现过拟合,模型没有泛化能力,只记住了字母表的顺序,没有推演出映射规律。

    New start:  C

    C -> B

    C -> B

    D -> B

    E -> B

    F -> B

    G -> B

    H -> B

    I -> B

    J -> B

    K -> B

    四、结论

    本次实验得出两个结论:

    • stateful能保留历史状态向后传递,在训练时使用stateful参数,LSTM单元会受到历史状态较大的影响,所以训练过程中必要的时候对模型进行reset_states很重要。
    • 在使用stateful时发现,如果每个batch输出结果过于有规律,训练数据很少的情况下,很可能导致模型学到每个batch间输出的顺序,而非输入和输出间的映射关系,出现过拟合。所以这也体现出另一个参数shuffle的重要。
    • 在网上别人建议说如果不是必须,不要用stateful,其实我觉得大部分人应该是用不好stateful,所以不要想当然的乱用。如果一定要用,建议多做几组实验进行参照,才能看出stateful是不是真的有效。

    参考文献:

    LSTM之Keras中Stateful参数 - 光彩照人 - 博客园

    深入理解Keras中LSTM的stateful和stateless应用区别 - 光彩照人 - 博客园

    https://www.freesion.com/article/9460826868/

     https://forecast.blog.csdn.net/article/details/126882142?spm=1001.2014.3001.5502

  • 相关阅读:
    linux(centos7)配置SSH免密登录
    [论文笔记]GLM
    使用多线程实现批处理过程
    Linux 下搭建 Hive 环境
    计算机网络笔记5 传输层
    hid-ft260驱动学习笔记 5 - ft260_i2c_probe
    关于windows下ffmpeg视频(libx264,h264_qsv,h264_cuvid,h264_amf)编码参数纪要
    Docker的基本使用
    基于SSH(Struts+Spring+Hibernate)实现汽修管理系统《建议收藏:附完整源码+数据库》
    每日一个C库函数-#2-memmove()
  • 原文地址:https://blog.csdn.net/yangwohenmai1/article/details/127792427