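The example first builds a custom dataset class, Word2VecDataset, which generates the training data as (center word, context word) pairs. It then defines the Skip-gram model and trains it with a cross-entropy loss and the Adam optimizer.
In each training epoch, the loop iterates over the data loader and, for every batch, runs the forward pass, computes the loss, backpropagates, and updates the weights. After training, the learned word vectors are read from the embedding layer, and word_vector holds the vector representation of a specific word.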
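As a rough illustration of the pair-generation step described above, the sketch below builds (center, context) pairs for a single sentence in plain Python. The helper name skipgram_pairs is hypothetical and is not part of the example code; it assumes a symmetric window of 2 words on each side, matching the window_size hyperparameter used in this section.

```python
def skipgram_pairs(sentence, window_size=2):
    """Return (center, context) pairs for one tokenized sentence.

    Hypothetical helper, for illustration only.
    """
    pairs = []
    for center_pos, center in enumerate(sentence):
        # Clip the window at the sentence boundaries
        start = max(0, center_pos - window_size)
        end = min(len(sentence), center_pos + window_size + 1)
        for context_pos in range(start, end):
            if context_pos != center_pos:  # a word is not its own context
                pairs.append((center, sentence[context_pos]))
    return pairs


sentence = ['She', 'is', 'a', 'good', 'dancer']
pairs = skipgram_pairs(sentence)
print(len(pairs))  # 14
print(pairs[:2])   # [('She', 'is'), ('She', 'a')]
```

Applied to all three sentences of the example corpus, the same procedure gives 22 + 22 + 14 = 58 pairs, which is two batches per epoch at a batch size of 32.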
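Make sure PyTorch is installed before running the code; you can install it with pip install torch.
Note that the code selects the GPU when one is available. The device is chosen with torch.device('cuda' if torch.cuda.is_available() else 'cpu'), which already falls back to the CPU when no GPU is present, so the .to(...) calls can stay in place on a CPU-only machine. The model and the input tensors must be on the same device.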
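One common way to handle device placement, sketched here under the assumption that a single device is used for the whole run, is to pick the device once and move both the model and each batch to it. The small nn.Embedding layer and the variable names below are placeholders for illustration, not the model from this section.

```python
import torch
import torch.nn as nn

# Pick the device once; this falls back to the CPU when CUDA is unavailable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Placeholder model: 10 words, 4-dimensional vectors (illustrative sizes).
model = nn.Embedding(10, 4).to(device)

# Inputs must live on the same device as the model's parameters.
batch = torch.tensor([1, 5, 7]).to(device)
vectors = model(batch)

# Move results back to the CPU before converting them to NumPy.
vectors_np = vectors.detach().cpu().numpy()
print(vectors_np.shape)  # (3, 4)
```

Keeping the device in one variable means the same script runs unchanged on machines with and without a GPU.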
The following is example code that implements the Skip-gram model in PyTorch:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Hyperparameters
embedding_dim = 100
window_size = 2
learning_rate = 0.001
epochs = 100
batch_size = 32

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Example corpus
corpus = [['I', 'enjoy', 'playing', 'football', 'with', 'my', 'friends'],
          ['We', 'like', 'to', 'play', 'tennis', 'on', 'weekends'],
          ['She', 'is', 'a', 'good', 'dancer']]

# Create vocabulary (sorted so that word indices are stable across runs)
vocab = sorted(set(word for sentence in corpus for word in sentence))
vocab_size = len(vocab)
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for idx, word in enumerate(vocab)}

# Generate training data: one (center, context) index pair per sample
class Word2VecDataset(Dataset):
    def __init__(self, corpus, word2idx, window_size):
        self.data = []
        for sentence in corpus:
            word_indices = [word2idx[word] for word in sentence]
            for center_pos, center_word in enumerate(word_indices):
                # Context positions within window_size of the center word
                start = max(0, center_pos - window_size)
                end = min(center_pos + window_size + 1, len(word_indices))
                for context_pos in range(start, end):
                    if context_pos != center_pos:
                        context_word = word_indices[context_pos]
                        self.data.append((center_word, context_word))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

dataset = Word2VecDataset(corpus, word2idx, window_size)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define Skip-gram model
class SkipGramModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, center_word):
        embedded = self.embedding(center_word)  # (batch, embedding_dim)
        output = self.linear(embedded)          # (batch, vocab_size) logits
        return output

# The model must be on the same device as the input tensors
model = SkipGramModel(vocab_size, embedding_dim).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training
for epoch in range(epochs):
    running_loss = 0.0
    for center_word, context_word in dataloader:
        center_word = center_word.to(device)
        context_word = context_word.to(device)

        optimizer.zero_grad()
        output = model(center_word)
        loss = criterion(output, context_word)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    average_loss = running_loss / len(dataloader)
    print(f'Epoch {epoch+1}/{epochs}, Loss: {average_loss:.4f}')

# Get trained word embeddings (moved to the CPU before converting to NumPy)
trained_embeddings = model.embedding.weight.detach().cpu().numpy()

# Example usage - Getting word vector for a word
word = 'football'
word_vector = trained_embeddings[word2idx[word]]
print(f"Word vector for '{word}': {word_vector}")
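The output of one run is shown below (exact values vary between runs because of random initialization and shuffling):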
Epoch 1/100, Loss: 3.1324
Epoch 2/100, Loss: 3.0791
Epoch 3/100, Loss: 2.9902
Epoch 4/100, Loss: 2.9392
Epoch 5/100, Loss: 2.8870
Epoch 6/100, Loss: 2.8166
Epoch 7/100, Loss: 2.7615
Epoch 8/100, Loss: 2.7017
Epoch 9/100, Loss: 2.6500
Epoch 10/100, Loss: 2.5993
Epoch 11/100, Loss: 2.5496
Epoch 12/100, Loss: 2.5013
Epoch 13/100, Loss: 2.4621
Epoch 14/100, Loss: 2.4079
Epoch 15/100, Loss: 2.3660
Epoch 16/100, Loss: 2.3229
Epoch 17/100, Loss: 2.2795
Epoch 18/100, Loss: 2.2398
Epoch 19/100, Loss: 2.1998
Epoch 20/100, Loss: 2.1582
Epoch 21/100, Loss: 2.1278
Epoch 22/100, Loss: 2.1023
Epoch 23/100, Loss: 2.0569
Epoch 24/100, Loss: 2.0245
Epoch 25/100, Loss: 1.9936
Epoch 26/100, Loss: 1.9639
Epoch 27/100, Loss: 1.9344
Epoch 28/100, Loss: 1.9137
Epoch 29/100, Loss: 1.8888
Epoch 30/100, Loss: 1.8586
Epoch 31/100, Loss: 1.8352
Epoch 32/100, Loss: 1.8200
Epoch 33/100, Loss: 1.7815
Epoch 34/100, Loss: 1.7685
Epoch 35/100, Loss: 1.7531
Epoch 36/100, Loss: 1.7209
Epoch 37/100, Loss: 1.7049
Epoch 38/100, Loss: 1.6881
Epoch 39/100, Loss: 1.6775
Epoch 40/100, Loss: 1.6517
Epoch 41/100, Loss: 1.6390
Epoch 42/100, Loss: 1.6238
Epoch 43/100, Loss: 1.6077
Epoch 44/100, Loss: 1.5939
Epoch 45/100, Loss: 1.5745
Epoch 46/100, Loss: 1.5703
Epoch 47/100, Loss: 1.5574
Epoch 48/100, Loss: 1.5458
Epoch 49/100, Loss: 1.5308
Epoch 50/100, Loss: 1.5215
Epoch 51/100, Loss: 1.5122
Epoch 52/100, Loss: 1.4988
Epoch 53/100, Loss: 1.4958
Epoch 54/100, Loss: 1.4773
Epoch 55/100, Loss: 1.4746
Epoch 56/100, Loss: 1.4618
Epoch 57/100, Loss: 1.4560
Epoch 58/100, Loss: 1.4506
Epoch 59/100, Loss: 1.4380
Epoch 60/100, Loss: 1.4266
Epoch 61/100, Loss: 1.4257
Epoch 62/100, Loss: 1.4148
Epoch 63/100, Loss: 1.4090
Epoch 64/100, Loss: 1.4070
Epoch 65/100, Loss: 1.3940
Epoch 66/100, Loss: 1.3890
Epoch 67/100, Loss: 1.3846
Epoch 68/100, Loss: 1.3813
Epoch 69/100, Loss: 1.3738
Epoch 70/100, Loss: 1.3717
Epoch 71/100, Loss: 1.3681
Epoch 72/100, Loss: 1.3594
Epoch 73/100, Loss: 1.3593
Epoch 74/100, Loss: 1.3504
Epoch 75/100, Loss: 1.3447
Epoch 76/100, Loss: 1.3439
Epoch 77/100, Loss: 1.3397
Epoch 78/100, Loss: 1.3315
Epoch 79/100, Loss: 1.3260
Epoch 80/100, Loss: 1.3253
Epoch 81/100, Loss: 1.3229
Epoch 82/100, Loss: 1.3215
Epoch 83/100, Loss: 1.3148
Epoch 84/100, Loss: 1.3160
Epoch 85/100, Loss: 1.3072
Epoch 86/100, Loss: 1.3105
Epoch 87/100, Loss: 1.3104
Epoch 88/100, Loss: 1.3018
Epoch 89/100, Loss: 1.2912
Epoch 90/100, Loss: 1.2950
Epoch 91/100, Loss: 1.2938
Epoch 92/100, Loss: 1.2951
Epoch 93/100, Loss: 1.2859
Epoch 94/100, Loss: 1.2902
Epoch 95/100, Loss: 1.2840
Epoch 96/100, Loss: 1.2748
Epoch 97/100, Loss: 1.2840
Epoch 98/100, Loss: 1.2763
Epoch 99/100, Loss: 1.2772
Epoch 100/100, Loss: 1.2746
Word vector for 'football':
[-1.2727762 0.8401019 -0.5115612 2.0667355 1.1854529 -0.7444803
-1.9658612 -1.0488677 0.98938674 -1.1675086 1.582392 1.7414839
-0.4892138 -1.2149098 0.15343344 -1.8318586 0.41794038 0.25481498
0.6008032 -0.23904797 0.80143225 -1.0495795 -1.0174142 -0.01827855
2.7477944 -0.9574399 1.025569 2.4843202 -0.2796719 -0.4390253
-1.4423424 -1.8073392 0.1897556 0.90259725 2.7565296 -0.28331178
-1.8443514 0.77545553 -1.0289538 0.71483964 1.1801128 -0.22635305
0.5960759 0.6690206 -1.9100318 1.2388043 -0.68522704 0.92120373
1.0252377 -1.4376261 -0.6595934 0.31699112 0.6751458 0.99656415
0.40565705 -1.0904227 -0.3513346 -0.66078615 1.1834346 -1.0899751
-1.4925232 -0.30818892 1.4249563 0.06006899 -3.2386255 0.96192694
-1.1045157 0.5540482 -1.5388466 -0.8721646 1.1221852 1.6488599
0.44869688 1.1519432 -1.4588032 -0.04230021 -0.33113605 1.1316347
-0.7425484 -0.11400439 0.37237874 -0.34573358 0.4140474 -0.04413145
0.6157635 -1.0094129 -1.2208599 -0.7154122 0.9412035 0.9452426
-0.0973389 -0.23566085 0.34300375 -0.95858365 0.8764276 -0.5669889
-1.933235 0.22371146 1.6641699 1.3258857 ]
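Two reference values can help in reading the loss curve; the following is a back-of-the-envelope check, not part of the original code. A model that assigns equal probability to each of the 19 vocabulary words has a cross-entropy of ln 19. At the other end, every center word in this corpus has several different context words, so even a perfect model has to split its probability among them and the loss cannot reach zero. Assuming each word occurs exactly once, as in the example corpus, the best achievable average loss is the mean of ln k over all training pairs, where k is the number of context words of that pair's center word. The function name loss_bounds is hypothetical.

```python
import math

corpus = [['I', 'enjoy', 'playing', 'football', 'with', 'my', 'friends'],
          ['We', 'like', 'to', 'play', 'tennis', 'on', 'weekends'],
          ['She', 'is', 'a', 'good', 'dancer']]


def loss_bounds(corpus, window_size=2):
    """Return (uniform_loss, lower_bound) for the skip-gram pairs.

    Hypothetical helper; assumes every word appears exactly once in the corpus.
    """
    vocab_size = len({word for sentence in corpus for word in sentence})
    uniform_loss = math.log(vocab_size)

    total_pairs = 0
    total_loss = 0.0
    for sentence in corpus:
        n = len(sentence)
        for pos in range(n):
            # Number of context words inside the window for this center word
            k = min(n, pos + window_size + 1) - max(0, pos - window_size) - 1
            if k > 0:
                total_pairs += k
                # Best case: probability 1/k on each context, loss ln k per pair
                total_loss += k * math.log(k)
    return uniform_loss, total_loss / total_pairs


uniform_loss, lower_bound = loss_bounds(corpus)
print(f'{uniform_loss:.4f}')  # 2.9444
print(f'{lower_bound:.4f}')   # 1.1536
```

In the log above, the loss starts at 3.1324, close to the uniform value, and levels off around 1.27, a little above the floor of about 1.15. The logged figure is an average of per-batch means taken while the weights are still changing, so it is only roughly comparable to these bounds.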