• Seq2Seq + Attention code
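
The listing below builds a minimal sequence-to-sequence model with attention in TensorFlow 1.x: a BasicRNNCell encoder and decoder over one-hot word vectors, a bilinear (Luong-style "general") attention score, and a plot of the learned attention matrix. It overfits a single sentence pair, "ich mochte ein bier" → "i want a beer", which is enough to watch the attention weights line up. The code uses the TF 1.x API (placeholders, sessions); on TensorFlow 2.x it should run under the v1 compatibility layer, e.g.:

    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()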


    import tensorflow as tf
    import matplotlib.pyplot as plt
    import numpy as np
    
    tf.reset_default_graph()
    # S: Symbol that marks the start of the decoder input
    # E: Symbol that marks the end of the decoder output
    # P: Padding symbol used to fill a sequence that is shorter than the number of time steps
    sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E']
    
    word_list = " ".join(sentences).split()
    word_list = list(set(word_list))
    word_dict = {w: i for i, w in enumerate(word_list)}
    number_dict = {i: w for i, w in enumerate(word_list)}
    n_class = len(word_dict)  # vocabulary size
    
    # Parameter
    n_step = 5  # maximum number of words in one sentence (= number of time steps)
    n_hidden = 128
    
    def make_batch(sentences):
        # One-hot encode the encoder and decoder inputs; targets stay as word indices
        input_batch = [np.eye(n_class)[[word_dict[n] for n in sentences[0].split()]]]   # [1, n_step, n_class]
        output_batch = [np.eye(n_class)[[word_dict[n] for n in sentences[1].split()]]]  # [1, n_step, n_class]
        target_batch = [[word_dict[n] for n in sentences[2].split()]]                   # [1, n_step]
        return input_batch, output_batch, target_batch
    
    # Model
    enc_inputs = tf.placeholder(tf.float32, [None, None, n_class])  # [batch_size, n_step, n_class]
    dec_inputs = tf.placeholder(tf.float32, [None, None, n_class])  # [batch_size, n_step, n_class]
    targets = tf.placeholder(tf.int64, [1, n_step])  # [batch_size, n_step], not one-hot
    
    # Trainable parameters: bilinear attention matrix and output projection
    attn = tf.Variable(tf.random_normal([n_hidden, n_hidden]))
    out = tf.Variable(tf.random_normal([n_hidden * 2, n_class]))
    
    def get_att_score(dec_output, enc_output):  # enc_output: [1, n_hidden], the encoder output at one time step
        score = tf.squeeze(tf.matmul(enc_output, attn), 0)  # h_i * W : [n_hidden]
        dec_output = tf.squeeze(dec_output, [0, 1])  # H : [n_hidden]
        return tf.tensordot(dec_output, score, 1)  # scalar score H . (h_i * W)
    
    def get_att_weight(dec_output, enc_outputs):
        # dec_output: decoder output at the current time step, H[i]
        # enc_outputs: encoder outputs for every time step, [batch_size, n_step, n_hidden] = h
        attn_scores = []  # list of attention scalars : [n_step]
        enc_outputs = tf.transpose(enc_outputs, [1, 0, 2])  # enc_outputs : [n_step, batch_size, n_hidden]
        for i in range(n_step):
            attn_scores.append(get_att_score(dec_output, enc_outputs[i]))
        # Normalize the scores into weights between 0 and 1
        return tf.reshape(tf.nn.softmax(attn_scores), [1, 1, -1])  # [1, 1, n_step]
    
    model = []      # collects the per-step logits
    Attention = []  # collects the per-step attention weights
    with tf.variable_scope('encode'):
        enc_cell = tf.nn.rnn_cell.BasicRNNCell(n_hidden)
        enc_cell = tf.nn.rnn_cell.DropoutWrapper(enc_cell, output_keep_prob=0.5)
        # enc_outputs : [batch_size(=1), n_step(=encoder steps), n_hidden(=128)]
        # enc_hidden : [batch_size(=1), n_hidden(=128)]
        enc_outputs, enc_hidden = tf.nn.dynamic_rnn(enc_cell, enc_inputs, dtype=tf.float32)
    
    with tf.variable_scope('decode'):
        dec_cell = tf.nn.rnn_cell.BasicRNNCell(n_hidden)
        dec_cell = tf.nn.rnn_cell.DropoutWrapper(dec_cell, output_keep_prob=0.5)
    
        inputs = tf.transpose(dec_inputs, [1, 0, 2])  # [n_step, batch_size, n_class]
        hidden = enc_hidden
        for i in range(n_step):  # hidden starts as the encoder's final state C; dec_output = H[i]; enc_outputs = h
            # time_major=True means the inputs have shape [max_time, batch_size, ...]
            dec_output, hidden = tf.nn.dynamic_rnn(dec_cell, tf.expand_dims(inputs[i], 1),
                                                   initial_state=hidden, dtype=tf.float32, time_major=True)
            attn_weights = get_att_weight(dec_output, enc_outputs)  # attn_weights : [1, 1, n_step]
            Attention.append(tf.squeeze(attn_weights))
    
            # matrix-matrix product of matrices [1, 1, n_step] x [1, n_step, n_hidden] = [1, 1, n_hidden]
            context = tf.matmul(attn_weights, enc_outputs)  # weighted sum of encoder outputs, e.g. 0.1*h1 + 0.5*h2 + 0.4*h3
            dec_output = tf.squeeze(dec_output, 0)  # [1, n_hidden]
            context = tf.squeeze(context, 1)  # [1, n_hidden]
    
            model.append(tf.matmul(tf.concat((dec_output, context), 1), out))  # one step: [batch_size(=1), n_class]; the list stacks to [n_step, 1, n_class]
    
    trained_attn = tf.stack(Attention, 0)  # [n_step, n_step] attention matrix, for visualization
    model = tf.transpose(model, [1, 0, 2])  # model : [batch_size(=1), n_step, n_class]
    prediction = tf.argmax(model, 2)
    
    cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=model, labels=targets))
    optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)
    
    # Training and Test
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        for epoch in range(2000):
            input_batch, output_batch, target_batch = make_batch(sentences)
            _, loss, attention = sess.run([optimizer, cost, trained_attn],
                                          feed_dict={enc_inputs: input_batch, dec_inputs: output_batch, targets: target_batch})
    
            if (epoch + 1) % 400 == 0:
                print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
    
        # At test time the decoder inputs are all padding symbols (no teacher forcing)
        predict_batch = [np.eye(n_class)[[word_dict[n] for n in 'P P P P P'.split()]]]
        result = sess.run(prediction, feed_dict={enc_inputs: input_batch, dec_inputs: predict_batch})
        print(sentences[0].split(), '->', [number_dict[n] for n in result[0]])
    
        # Show Attention
        fig = plt.figure(figsize=(5, 5))
        ax = fig.add_subplot(1, 1, 1)
        ax.matshow(attention, cmap='viridis')
        ax.set_xticklabels([''] + sentences[0].split(), fontdict={'fontsize': 14})
        ax.set_yticklabels([''] + sentences[2].split(), fontdict={'fontsize': 14})
        plt.show()
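
To sanity-check the attention arithmetic outside the graph, here is a minimal NumPy sketch of what get_att_score and get_att_weight compute for a single decoder step; the names H, h, and W are illustrative stand-ins for the decoder output, the encoder outputs, and the attn variable above.

    import numpy as np

    n_step, n_hidden = 5, 128
    rng = np.random.default_rng(0)
    h = rng.normal(size=(n_step, n_hidden))    # encoder outputs, one row per time step
    H = rng.normal(size=(n_hidden,))           # decoder output at the current step
    W = rng.normal(size=(n_hidden, n_hidden))  # plays the role of 'attn'

    scores = np.array([H @ (h[i] @ W) for i in range(n_step)])  # one bilinear score per encoder step
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                   # softmax over the n_step scores
    context = weights @ h                      # weighted sum of encoder outputs, shape [n_hidden]

In the graph above, this context vector is then concatenated with the decoder output and projected by out to produce the logits for one time step.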
    
  • Original article: https://blog.csdn.net/m0_67084346/article/details/128157429