• Implementing the Word2Vec Skip-gram Model with PyTorch


    First, a custom dataset is built with the Word2VecDataset class to generate the (center word, context word) training pairs. A Skip-gram model is then defined and trained with a cross-entropy loss and the Adam optimizer. A tiny sketch of the pairs this produces is shown below.
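    For intuition, here is a minimal sketch (not part of the original code; variable names are illustrative only) of the (center, context) pairs such a dataset produces for one of the example sentences with window_size = 2:

    sentence = ['She', 'is', 'a', 'good', 'dancer']
    window_size = 2
    pairs = []
    for i, center in enumerate(sentence):
        # every word within window_size positions of the center word becomes a context word
        for j in range(max(0, i - window_size), min(i + window_size + 1, len(sentence))):
            if j != i:
                pairs.append((center, sentence[j]))
    print(pairs[:4])  # [('She', 'is'), ('She', 'a'), ('is', 'She'), ('is', 'a')]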

    In each training epoch we iterate over the data loader and, for every batch, run a forward pass, compute the loss, backpropagate, and update the weights. After training, the learned embeddings can be read out of the embedding layer, and word_vector holds the vector representation of a specific word.

    Make sure PyTorch is installed before running; it can be installed with pip install torch. The code selects a GPU automatically when torch.cuda.is_available() returns True and falls back to the CPU otherwise, moving both the model and each batch to the chosen device, so no changes are needed on a CPU-only machine.
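    As a minimal sketch of the device-selection pattern used in the code below (assuming only that PyTorch is installed):

    import torch

    # Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)  # prints 'cuda' or 'cpu'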

    Below is example code implementing the Skip-gram model with PyTorch:

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import Dataset, DataLoader

    # Hyperparameters
    embedding_dim = 100
    window_size = 2
    learning_rate = 0.001
    epochs = 100
    batch_size = 32

    # Run on the GPU when one is available, otherwise on the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Example corpus
    corpus = [['I', 'enjoy', 'playing', 'football', 'with', 'my', 'friends'],
              ['We', 'like', 'to', 'play', 'tennis', 'on', 'weekends'],
              ['She', 'is', 'a', 'good', 'dancer']]

    # Create vocabulary
    vocab = list(set([word for sentence in corpus for word in sentence]))
    vocab_size = len(vocab)
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for idx, word in enumerate(vocab)}

    # Generate training data: one (center word, context word) pair per sample
    class Word2VecDataset(Dataset):
        def __init__(self, corpus, word2idx, window_size):
            self.data = []
            for sentence in corpus:
                word_indices = [word2idx[word] for word in sentence]
                for center_pos, center_word in enumerate(word_indices):
                    # Every word within the window around the center word is a context word
                    for context_pos in range(max(0, center_pos - window_size),
                                             min(center_pos + window_size + 1, len(word_indices))):
                        if context_pos != center_pos:
                            context_word = word_indices[context_pos]
                            self.data.append((center_word, context_word))

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]

    dataset = Word2VecDataset(corpus, word2idx, window_size)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Define Skip-gram model: an embedding lookup for the center word followed by
    # a linear layer that scores every vocabulary word as a candidate context word
    class SkipGramModel(nn.Module):
        def __init__(self, vocab_size, embedding_dim):
            super(SkipGramModel, self).__init__()
            self.embedding = nn.Embedding(vocab_size, embedding_dim)
            self.linear = nn.Linear(embedding_dim, vocab_size)

        def forward(self, center_word):
            embedded = self.embedding(center_word)
            output = self.linear(embedded)
            return output

    model = SkipGramModel(vocab_size, embedding_dim).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Training
    for epoch in range(epochs):
        running_loss = 0.0
        for center_word, context_word in dataloader:
            center_word = center_word.to(device)
            context_word = context_word.to(device)

            optimizer.zero_grad()
            output = model(center_word)
            loss = criterion(output, context_word)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        average_loss = running_loss / len(dataloader)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {average_loss:.4f}')

    # Get the trained word embeddings (move to the CPU before converting to NumPy)
    trained_embeddings = model.embedding.weight.data.cpu().numpy()

    # Example usage - get the word vector for a specific word
    word = 'football'
    word_vector = trained_embeddings[word2idx[word]]
    print(f"Word vector for '{word}': {word_vector}")

    Running the script produces output like the following:

    Epoch 1/100, Loss: 3.1324
    Epoch 2/100, Loss: 3.0791
    Epoch 3/100, Loss: 2.9902
    Epoch 4/100, Loss: 2.9392
    Epoch 5/100, Loss: 2.8870
    Epoch 6/100, Loss: 2.8166
    Epoch 7/100, Loss: 2.7615
    Epoch 8/100, Loss: 2.7017
    Epoch 9/100, Loss: 2.6500
    Epoch 10/100, Loss: 2.5993
    Epoch 11/100, Loss: 2.5496
    Epoch 12/100, Loss: 2.5013
    Epoch 13/100, Loss: 2.4621
    Epoch 14/100, Loss: 2.4079
    Epoch 15/100, Loss: 2.3660
    Epoch 16/100, Loss: 2.3229
    Epoch 17/100, Loss: 2.2795
    Epoch 18/100, Loss: 2.2398
    Epoch 19/100, Loss: 2.1998
    Epoch 20/100, Loss: 2.1582
    Epoch 21/100, Loss: 2.1278
    Epoch 22/100, Loss: 2.1023
    Epoch 23/100, Loss: 2.0569
    Epoch 24/100, Loss: 2.0245
    Epoch 25/100, Loss: 1.9936
    Epoch 26/100, Loss: 1.9639
    Epoch 27/100, Loss: 1.9344
    Epoch 28/100, Loss: 1.9137
    Epoch 29/100, Loss: 1.8888
    Epoch 30/100, Loss: 1.8586
    Epoch 31/100, Loss: 1.8352
    Epoch 32/100, Loss: 1.8200
    Epoch 33/100, Loss: 1.7815
    Epoch 34/100, Loss: 1.7685
    Epoch 35/100, Loss: 1.7531
    Epoch 36/100, Loss: 1.7209
    Epoch 37/100, Loss: 1.7049
    Epoch 38/100, Loss: 1.6881
    Epoch 39/100, Loss: 1.6775
    Epoch 40/100, Loss: 1.6517
    Epoch 41/100, Loss: 1.6390
    Epoch 42/100, Loss: 1.6238
    Epoch 43/100, Loss: 1.6077
    Epoch 44/100, Loss: 1.5939
    Epoch 45/100, Loss: 1.5745
    Epoch 46/100, Loss: 1.5703
    Epoch 47/100, Loss: 1.5574
    Epoch 48/100, Loss: 1.5458
    Epoch 49/100, Loss: 1.5308
    Epoch 50/100, Loss: 1.5215
    Epoch 51/100, Loss: 1.5122
    Epoch 52/100, Loss: 1.4988
    Epoch 53/100, Loss: 1.4958
    Epoch 54/100, Loss: 1.4773
    Epoch 55/100, Loss: 1.4746
    Epoch 56/100, Loss: 1.4618
    Epoch 57/100, Loss: 1.4560
    Epoch 58/100, Loss: 1.4506
    Epoch 59/100, Loss: 1.4380
    Epoch 60/100, Loss: 1.4266
    Epoch 61/100, Loss: 1.4257
    Epoch 62/100, Loss: 1.4148
    Epoch 63/100, Loss: 1.4090
    Epoch 64/100, Loss: 1.4070
    Epoch 65/100, Loss: 1.3940
    Epoch 66/100, Loss: 1.3890
    Epoch 67/100, Loss: 1.3846
    Epoch 68/100, Loss: 1.3813
    Epoch 69/100, Loss: 1.3738
    Epoch 70/100, Loss: 1.3717
    Epoch 71/100, Loss: 1.3681
    Epoch 72/100, Loss: 1.3594
    Epoch 73/100, Loss: 1.3593
    Epoch 74/100, Loss: 1.3504
    Epoch 75/100, Loss: 1.3447
    Epoch 76/100, Loss: 1.3439
    Epoch 77/100, Loss: 1.3397
    Epoch 78/100, Loss: 1.3315
    Epoch 79/100, Loss: 1.3260
    Epoch 80/100, Loss: 1.3253
    Epoch 81/100, Loss: 1.3229
    Epoch 82/100, Loss: 1.3215
    Epoch 83/100, Loss: 1.3148
    Epoch 84/100, Loss: 1.3160
    Epoch 85/100, Loss: 1.3072
    Epoch 86/100, Loss: 1.3105
    Epoch 87/100, Loss: 1.3104
    Epoch 88/100, Loss: 1.3018
    Epoch 89/100, Loss: 1.2912
    Epoch 90/100, Loss: 1.2950
    Epoch 91/100, Loss: 1.2938
    Epoch 92/100, Loss: 1.2951
    Epoch 93/100, Loss: 1.2859
    Epoch 94/100, Loss: 1.2902
    Epoch 95/100, Loss: 1.2840
    Epoch 96/100, Loss: 1.2748
    Epoch 97/100, Loss: 1.2840
    Epoch 98/100, Loss: 1.2763
    Epoch 99/100, Loss: 1.2772
    Epoch 100/100, Loss: 1.2746


    Word vector for 'football':

    [-1.2727762   0.8401019  -0.5115612   2.0667355   1.1854529  -0.7444803
     -1.9658612  -1.0488677   0.98938674 -1.1675086   1.582392    1.7414839
     -0.4892138  -1.2149098   0.15343344 -1.8318586   0.41794038  0.25481498
      0.6008032  -0.23904797  0.80143225 -1.0495795  -1.0174142  -0.01827855
      2.7477944  -0.9574399   1.025569    2.4843202  -0.2796719  -0.4390253
     -1.4423424  -1.8073392   0.1897556   0.90259725  2.7565296  -0.28331178
     -1.8443514   0.77545553 -1.0289538   0.71483964  1.1801128  -0.22635305
      0.5960759   0.6690206  -1.9100318   1.2388043  -0.68522704  0.92120373
      1.0252377  -1.4376261  -0.6595934   0.31699112  0.6751458   0.99656415
      0.40565705 -1.0904227  -0.3513346  -0.66078615  1.1834346  -1.0899751
     -1.4925232  -0.30818892  1.4249563   0.06006899 -3.2386255   0.96192694
     -1.1045157   0.5540482  -1.5388466  -0.8721646   1.1221852   1.6488599
      0.44869688  1.1519432  -1.4588032  -0.04230021 -0.33113605  1.1316347
     -0.7425484  -0.11400439  0.37237874 -0.34573358  0.4140474  -0.04413145
      0.6157635  -1.0094129  -1.2208599  -0.7154122   0.9412035   0.9452426
     -0.0973389  -0.23566085  0.34300375 -0.95858365  0.8764276  -0.5669889
     -1.933235    0.22371146  1.6641699   1.3258857 ]
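    With trained_embeddings, word2idx, and idx2word in hand, the vectors can be compared directly. The following is a minimal sketch (not from the original post; most_similar is a hypothetical helper) that ranks the words closest to 'football' by cosine similarity. With a corpus this small the ranking is not very meaningful, but the same code works unchanged for larger corpora:

    import numpy as np

    def most_similar(word, embeddings, word2idx, idx2word, topn=3):
        # Cosine similarity between the query word's vector and every word vector
        query = embeddings[word2idx[word]]
        norms = np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query)
        sims = embeddings @ query / norms
        ranked = np.argsort(-sims)
        return [(idx2word[i], float(sims[i])) for i in ranked if idx2word[i] != word][:topn]

    print(most_similar('football', trained_embeddings, word2idx, idx2word))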

  • Original article: https://blog.csdn.net/Metal1/article/details/132886936