• AI作诗,模仿周杰伦创作歌词<->实战项目


    点击上方码农的后花园”,选择星标” 公众号

    精选文章,第一时间送达

    很久以来,我们都想让机器自己创作诗歌,当无数作家、编辑还没有抬起笔时,AI已经完成了数千篇文章。现在,这里是第一步....

    8bc4caa4a23777a620708baace439a64.png这诗做的很有感觉啊,这都是勤奋的结果啊,基本上学习了全唐诗的所有精华才有了这么牛逼的能力,这一般人能做到?

    甚至还可以模仿周杰伦创作歌词 !!955f467cf7261b4e0819695a2cd35389.png怎么说,目前由于缺乏训练文本,导致我们的AI做的歌词有点....额,还好啦,有那么一点忧郁之风。

    1.下载代码和数据集

    Github地址: https://github.com/jinfagang/tensorflow_poems

    数据集: 存放于项目的data文件夹内

    fa4fb44039883035f308f81c9320ac2d.png

    2.环境导入

    1. import os
    2. import tensorflow as tf
    3. from poems.model import rnn_model
    4. from poems.poems import process_poems, generate_batch
    5. import argparse
    6. from pathlib import Path

    3.参数设置

    1. parser = argparse.ArgumentParser()
    2. #type是要传入的参数的数据类型  help是该参数的提示信息
    3. parser.add_argument('--batch_size'type=int, help='batch_size',default=64)
    4. parser.add_argument('--learning_rate'type=float, help='learning_rate',default=0.0001)
    5. parser.add_argument('--model_dir'type=Path, help='model save path.',default='./model')
    6. parser.add_argument('--file_path'type=Path, help='file name of poems.',default='./data/poems.txt')
    7. parser.add_argument('--model_prefix'type=str, help='model save prefix.',default='poems')
    8. parser.add_argument('--epochs'type=int, help='train how many epochs.',default=126)
    9. args = parser.parse_args(args=[])

    4.训练

    下载的代码中的./model/中包含最新的训练模型,再次训练会接着训练。如果训练路径报错,需要删除./model的模型,重新开始训练。

    1. def run_training():
    2.     if not os.path.exists(args.model_dir):
    3.         os.makedirs(args.model_dir)
    4.     poems_vector, word_to_int, vocabularies = process_poems(args.file_path)
    5.     batches_inputs, batches_outputs = generate_batch(args.batch_size, poems_vector, word_to_int)
    6.     input_data = tf.placeholder(tf.int32, [args.batch_size, None])
    7.     output_targets = tf.placeholder(tf.int32, [args.batch_size, None])
    8.     end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len(
    9.         vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=args.learning_rate)
    10.     saver = tf.train.Saver(tf.global_variables())
    11.     init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    12.     with tf.Session() as sess:
    13.         # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
    14.         # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
    15.         sess.run(init_op)
    16.         start_epoch = 0
    17.         checkpoint = tf.train.latest_checkpoint(args.model_dir)
    18.         if checkpoint:
    19.             saver.restore(sess, checkpoint)
    20.             print("## restore from the checkpoint {0}".format(checkpoint))
    21.             start_epoch += int(checkpoint.split('-')[-1])
    22.         print('## start training...')
    23.         try:
    24.             n_chunk = len(poems_vector) // args.batch_size
    25.             for epoch in range(start_epoch, args.epochs):
    26.                 n = 0
    27.                 for batch in range(n_chunk):
    28.                     loss, _, _ = sess.run([
    29.                         end_points['total_loss'],
    30.                         end_points['last_state'],
    31.                         end_points['train_op']
    32.                     ], feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]})
    33.                     n += 1
    34.                 print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
    35.                 #if epoch % 5 == 0:
    36.                 saver.save(sess, os.path.join(args.model_dir, args.model_prefix), global_step=epoch)
    37.         except KeyboardInterrupt:
    38.             print('## Interrupt manually, try saving checkpoint for now...')
    39.             saver.save(sess, os.path.join(args.model_dir, args.model_prefix), global_step=epoch)
    40.             print('## Last epoch were saved, next time will start from epoch {}.'.format(epoch))
    41.             
    42. run_training()

    5.诗词生成

    1. import tensorflow as tf
    2. from poems.model import rnn_model
    3. from poems.poems import process_poems
    4. import numpy as np
    5. start_token = 'B'
    6. end_token = 'E'
    7. model_dir = './model/'
    8. corpus_file = './data/poems.txt'
    9. lr = 0.0002
    10. def to_word(predict, vocabs):
    11.     predict = predict[0]       
    12.     predict /= np.sum(predict)
    13.     sample = np.random.choice(np.arange(len(predict)), p=predict)
    14.     if sample > len(vocabs):
    15.         return vocabs[-1]
    16.     else:
    17.         return vocabs[sample]
    18. def gen_poem(begin_word):
    19.     batch_size = 1
    20.     print('## loading corpus from %s' % model_dir)
    21.     tf.reset_default_graph()
    22.     
    23.     poems_vector, word_int_map, vocabularies = process_poems(corpus_file)
    24.     input_data = tf.placeholder(tf.int32, [batch_size, None])
    25.     end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=len(
    26.         vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=lr)
    27.     saver = tf.train.Saver(tf.global_variables())
    28.     init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    29.     with tf.Session() as sess:
    30.         sess.run(init_op)
    31.         checkpoint = tf.train.latest_checkpoint(model_dir)
    32.         saver.restore(sess, checkpoint)
    33.         x = np.array([list(map(word_int_map.get, start_token))])
    34.         [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
    35.                                          feed_dict={input_data: x})
    36.         word = begin_word or to_word(predict, vocabularies)
    37.         poem_ = ''
    38.         i = 0
    39.         while word != end_token:
    40.             poem_ += word
    41.             i += 1
    42.             if i > 24:
    43.                 break
    44.             x = np.array([[word_int_map[word]]])
    45.             [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
    46.                                              feed_dict={input_data: x, end_points['initial_state']: last_state})
    47.             word = to_word(predict, vocabularies)
    48.         return poem_
    49. def pretty_print_poem(poem_):
    50.     poem_sentences = poem_.split('。')
    51.     for s in poem_sentences:
    52.         if s != '' and len(s) > 10:
    53.             print(s + '。')

    6.测试运行

    5c5f84595f5a8c12d4461dae750357da.png

    7. 运行环境

    本次使用框架TensorFlow1.13.1,本项目可以在华为提供的JupyterLab环境中运行。参考华为的实践案例:《AI作诗》:http://su.modelarts.club/dqTT https://developer.huaweicloud.com/signup/e4240e984d1c4d20bfcc83e7f7648b6c?

    后台回复关键字:项目实战,可下载完整代码。

    ed0131cc7caf4482de1a18ef98944741.png

    13125fdc342eb0446b6ed9607364de08.gif

    ·················END·················

  • 相关阅读:
    六安RapidSSL泛域名https能保护几个域名
    Linux 本地 Docker Registry本地镜像仓库远程连接【内网穿透】
    9.19~9.20elf论文(浮点数的二进制表示&确定擦除尾随0的数量)
    Java---Java Web---JSP
    Vue2.0新手入门-模板语法-计算属性与监听属性的介绍和差异
    【Python学习笔记】列表、元组
    【LeetCode】完全二叉树的节点个数 [M](递归)
    122 买卖股票的最佳时机II
    基于springboot的房产销售系统
    接口自动化测试实战:JMeter+Ant+Jenkins+钉钉机器人群通知完美结合
  • 原文地址:https://blog.csdn.net/weixin_45192980/article/details/126277302