仅记录学习过程,有问题欢迎讨论
本Demo新增了embedding层、池化层、丢弃层
对样本数据能够更好的处理。
详情请看注释。
import json
import random
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.utils.data as Data
"""
基于pytorch的网络编写
实现一个网络完成一个简单nlp任务
判断文本中是否有某些特定字符出现
"""
"""
词表示成向量输入神经网络,一种简单的方法是使用二进制独热编码(one-hot向量)
embedding层:可以降维;从稀疏矩阵到密集矩阵的过程,叫做Embedding
文本数值化 : abc ---> 3*5的矩阵=== 文本长度*向量长度
池化层:主要用于减少数据量和计算复杂度,同时保留重要的特征信息。它可以提高模型的训练速度和泛化能力,避免过拟合等问题
"""
class TorchModel(nn.Module):
def __init__(self, vector_dim, sentence_length, vocab):
super(TorchModel, self).__init__()
# embedding层 第一个参数为字符集长度 第二个为输入x的维度
self.embedding = nn.Embedding(len(vocab)+1, vector_dim, padding_idx=0)
#
self.layer = nn.Linear(vector_dim, vector_dim)
# 池化层 保留重要特征 提升效率 1维池化(缩减模型大小) 入参为文本长度 压缩
# eg: 3*4*5 ===> 3*5*4 =====> 3*5*1 压缩为1维
# 注意池化的是文本长度 规则为向下池化
self.pool = nn.MaxPool1d(sentence_length)
# 线性层
self.linear = nn.Linear(vector_dim, 3)
# sigmoid做激活函数
self.activation = torch.sigmoid
# 0.2~0.5 丢弃层 防止过拟合
self.dropout = nn.Dropout(0.4)
self.loss = nn.functional.cross_entropy
def forward(self, x, y=None):
x = self.embedding(x) # input shape:(batch_size, sen_len) (20,6)
x = self.layer(x) # input shape:(batch_size, sen_len, input_dim) (20,6,20)
x = self.dropout(x) # input shape:(batch_size, sen_len, input_dim)
# 如果是 最大池 先归一化 效果好很多
x = self.activation(x) # input shape:(batch_size, sen_len, input_dim)
x = self.pool(x.transpose(1, 2)).squeeze() # input shape:(batch_size, sen_len, input_dim)
x = self.linear(x) # input shape:(batch_size, input_dim) (20,3)
y_pred = self.activation(x)
if y is None:
return y_pred
else:
return self.loss(y_pred, y)
# 构建一个字符集的对应关系
def build_vocab():
chars = "abcdefghijklmnopqrstuvwxyz" # 字符集
# 字典格式{}
vocab = {}
for index, char in enumerate(chars):
# a == 1 b == 2
vocab[char] = index + 1
vocab['unk'] = len(vocab) + 1
return vocab
# 建立数据
def build_simple(vocab, sentence_length):
# 生成全是字母的 x
x = [random.choice(list(vocab.keys())) for _ in range(sentence_length)]
if set("abc") & set(x) and not set("xyz") & set(x):
y = 1
elif not set("abc") & set(x) and set("xyz") & set(x):
y = 2
else:
y = 0
# x 转化为数字
x = [vocab.get(word, vocab['unk']) for word in x]
return x, y
# 合并数据
def build_dataset(train_simple, vocab, sentence_length):
dataset_x = []
dataset_y = []
for i in range(train_simple):
X, Y = build_simple(vocab, sentence_length)
dataset_x.append(X)
dataset_y.append(Y)
return torch.LongTensor(dataset_x), torch.LongTensor(dataset_y)
# 评估效果
def evaluate(model, vocab, sentence_length):
model.eval()
x, y = build_dataset(200, vocab, sentence_length)
correct, wrong = 0, 0
print("0类样本数量:%d, 1类样本数量:%d, 2类样本数量:%d" % (y.tolist().count(0), y.tolist().count(1), y.tolist().count(2)))
with torch.no_grad():
y_pred = model(x)
for y_p, y_t in zip(y_pred, y): # 与真实标签进行对比
if int(torch.argmax(y_p)) == int(y_t):
correct += 1 # 正样本判断正确
else:
wrong += 1
print("正确预测个数:%d / %d, 正确率:%f" % (correct, correct + wrong, correct / (correct + wrong)))
return correct / (correct + wrong)
def main():
batch_size = 20
batch_num = 20
lr = 0.005
sentence_length = 6
train_simple = 1000
# 每个字符的维度
input_dim = 20
vocab = build_vocab()
# 构建数据
x, y = build_dataset(train_simple, vocab, sentence_length)
dataset = Data.TensorDataset(x, y)
dataiter = Data.DataLoader(dataset, batch_size, shuffle=True)
model = TorchModel(input_dim, sentence_length, vocab)
optim = torch.optim.Adam(model.parameters(), lr=lr)
log = []
for epoch in range(batch_num):
model.train()
epoch_loss = []
for x, y_true in dataiter:
optim.zero_grad()
loss = model(x, y_true)
loss.backward()
optim.step()
epoch_loss.append(loss.item())
print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(epoch_loss)))
acc = evaluate(model, vocab, sentence_length) # 测试本轮模型结果
log.append([acc, np.mean(epoch_loss)])
plt.plot(range(len(log)), [l[0] for l in log]) # 画acc曲线
plt.plot(range(len(log)), [l[1] for l in log]) # 画loss曲线
plt.show()
# 保存模型
torch.save(model.state_dict(), "model.pth")
# 保存词表
writer = open("vocab.json", "w", encoding="utf8")
writer.write(json.dumps(vocab, ensure_ascii=False, indent=2))
writer.close()
return
if __name__ == '__main__':
main()