• 用 TripletLoss 优化 BERT ranking


    下面是用 TripletLoss 优化 BERT ranking 的 demo:

    import torch
    from sklearn.metrics.pairwise import pairwise_distances
    from torch.utils.data import DataLoader, Dataset
    from transformers import BertConfig, BertModel, BertTokenizer
    class TripletRankingDataset(Dataset):
        """Pre-tokenized (query, positive doc, negative doc) triplets.

        All texts are tokenized once in the constructor and kept as stacked
        tensors, so ``__getitem__`` is a plain row lookup.

        Args:
            queries: Iterable of query strings.
            positive_docs: Iterable of relevant documents, aligned with ``queries``.
            negative_docs: Iterable of irrelevant documents, aligned with ``queries``.
            tokenizer: Object exposing ``encode_plus(text, padding=..., truncation=...,
                max_length=..., return_tensors='pt')`` that returns a mapping with
                ``'input_ids'`` and ``'attention_mask'`` tensors whose first
                dimension is the (size-1) batch dimension.
            max_length: Length every text is padded / truncated to.

        Raises:
            ValueError: If the three inputs do not contain the same number of items.
        """

        def __init__(self, queries, positive_docs, negative_docs, tokenizer, max_length):
            # Materialize first so generators are accepted and lengths can be compared.
            queries = list(queries)
            positive_docs = list(positive_docs)
            negative_docs = list(negative_docs)

            # zip() would silently drop the unmatched tail; misaligned triplets are
            # a data error, so fail loudly instead.
            if not len(queries) == len(positive_docs) == len(negative_docs):
                raise ValueError(
                    "queries, positive_docs and negative_docs must have the same length, "
                    f"got {len(queries)}, {len(positive_docs)} and {len(negative_docs)}"
                )

            self.input_ids_q, self.attention_masks_q = self._encode_all(
                queries, tokenizer, max_length
            )
            self.input_ids_p, self.attention_masks_p = self._encode_all(
                positive_docs, tokenizer, max_length
            )
            self.input_ids_n, self.attention_masks_n = self._encode_all(
                negative_docs, tokenizer, max_length
            )

        @staticmethod
        def _encode_all(texts, tokenizer, max_length):
            """Tokenize ``texts`` and stack the results along dim 0.

            Returns:
                Tuple ``(input_ids, attention_mask)`` with one row per text.
            """
            if not texts:
                # torch.cat() rejects an empty list; return well-formed empty tensors.
                return (
                    torch.empty((0, max_length), dtype=torch.long),
                    torch.empty((0, max_length), dtype=torch.long),
                )

            encoded = [
                tokenizer.encode_plus(
                    text,
                    padding='max_length',
                    truncation=True,
                    max_length=max_length,
                    return_tensors='pt',
                )
                for text in texts
            ]
            input_ids = torch.cat([item['input_ids'] for item in encoded], dim=0)
            attention_masks = torch.cat([item['attention_mask'] for item in encoded], dim=0)
            return input_ids, attention_masks

        def __len__(self):
            """Return the number of triplets."""
            return len(self.input_ids_q)

        def __getitem__(self, idx):
            """Return the ``idx``-th triplet.

            Returns:
                Tuple ``(input_ids_q, attention_mask_q, input_ids_p,
                attention_mask_p, input_ids_n, attention_mask_n)``.
            """
            return (
                self.input_ids_q[idx],
                self.attention_masks_q[idx],
                self.input_ids_p[idx],
                self.attention_masks_p[idx],
                self.input_ids_n[idx],
                self.attention_masks_n[idx],
            )
    class BERTTripletRankingModel(torch.nn.Module):
        """BERT encoder followed by dropout and a linear head producing one score per input.

        Args:
            bert_model_name: Name or path handed to ``BertModel.from_pretrained``.
            hidden_size: Input width of the linear head. When ``None`` (default) it
                is read from the loaded BERT model's ``config.hidden_size``.
        """

        def __init__(self, bert_model_name, hidden_size=None):
            super().__init__()
            self.bert = BertModel.from_pretrained(bert_model_name)
            if hidden_size is None:
                # Derive the head width from the encoder so the two cannot disagree.
                hidden_size = self.bert.config.hidden_size
            self.dropout = torch.nn.Dropout(0.1)
            self.fc = torch.nn.Linear(hidden_size, 1)

        def forward(self, input_ids, attention_mask):
            """Score each sequence in the batch.

            Args:
                input_ids: Token ids for the batch.
                attention_mask: Attention mask matching ``input_ids``.

            Returns:
                Tensor with one score per sequence, shape ``(batch,)``.
            """
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            # Index 1 of the encoder output is used as the pooled sentence vector.
            pooled_output = self.dropout(outputs[1])
            logits = self.fc(pooled_output)
            # Drop only the trailing size-1 feature dim; a bare squeeze() would also
            # remove the batch dim when the batch holds a single example.
            return logits.squeeze(-1)
    def triplet_loss(anchor, positive, negative, margin):
        """Mean hinge triplet loss over pairwise distances.

        Computes ``relu(d(anchor, positive) - d(anchor, negative) + margin)`` with
        ``torch.nn.functional.pairwise_distance`` as ``d`` and averages the result.

        Args:
            anchor: Anchor embeddings.
            positive: Embeddings that should lie close to the anchor.
            negative: Embeddings that should lie far from the anchor.
            margin: Required gap between the negative and positive distances.

        Returns:
            Scalar tensor holding the mean loss.
        """
        pdist = torch.nn.functional.pairwise_distance
        gap = pdist(anchor, positive) - pdist(anchor, negative)
        return torch.relu(gap + margin).mean()
    # Initialise the BERT tokenizer.
    bert_model_name = 'bert-base-uncased'
    tokenizer = BertTokenizer.from_pretrained(bert_model_name)

    # Example input data: queries[i], positive_docs[i] and negative_docs[i] form one triplet.
    queries = ['I like cats', 'The sun is shining']
    positive_docs = ['I like dogs', 'The weather is beautiful']
    negative_docs = ['Snakes are dangerous', 'It is raining']

    # Hyperparameters
    batch_size = 8
    max_length = 128
    learning_rate = 1e-5
    num_epochs = 5
    margin = 1.0

    # Build the dataset and the data loader.
    dataset = TripletRankingDataset(queries, positive_docs, negative_docs, tokenizer, max_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Build the model on top of the pretrained weights. The head width is read from
    # the checkpoint's config because `model` does not exist yet at this point.
    hidden_size = BertConfig.from_pretrained(bert_model_name).hidden_size
    model = BERTTripletRankingModel(bert_model_name, hidden_size=hidden_size)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Train the model.
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for input_ids_q, attention_masks_q, input_ids_p, attention_masks_p, input_ids_n, attention_masks_n in dataloader:
            optimizer.zero_grad()
            # The model emits one score per example. Reshape to (batch, 1) so
            # pairwise_distance yields one distance per triplet instead of treating
            # the whole batch as a single vector.
            embeddings_q = model(input_ids_q, attention_masks_q).reshape(-1, 1)
            embeddings_p = model(input_ids_p, attention_masks_p).reshape(-1, 1)
            embeddings_n = model(input_ids_n, attention_masks_n).reshape(-1, 1)
            loss = triplet_loss(embeddings_q, embeddings_p, embeddings_n, margin)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {total_loss:.4f}")

    # Inference: score every query and every positive document with the trained model.
    model.eval()
    with torch.no_grad():
        query_scores = model(dataset.input_ids_q, dataset.attention_masks_q).reshape(-1, 1)
        doc_scores = model(dataset.input_ids_p, dataset.attention_masks_p).reshape(-1, 1)
    # distance_matrix[i][j] is the distance between query i and document j. A separate
    # name is used so the imported pairwise_distances function is not overwritten.
    distance_matrix = pairwise_distances(query_scores.numpy(), doc_scores.numpy())

    # Print the results.
    for i, query in enumerate(queries):
        print(f"Query: {query}")
        print("Documents:")
        for j, doc in enumerate(positive_docs):
            doc_idx = j
            doc_dist = distance_matrix[i][j]
            print(f"Document index: {doc_idx}, Distance: {doc_dist:.4f}")
            print(f"Document: {doc}")
            print("")
        print("---------")

  • 相关阅读:
    CodeTalker 踩坑实录
    计算机视觉与深度学习实战,Python工具,深度学习的视觉场景识别
    【电商运营】如何吸引客户?经典WhatsApp营销案例分享!
    Vue 3.0前的 TypeScript 最佳入门实践
    RPA机器人的10大基础功能与2大类型
    【深度学习实验】卷积神经网络(七):实现深度残差神经网络ResNet
    ⑮、企业快速开发平台Spring Cloud之HTML 速查列表
    wazuh自定义规则-检测内网扫描行为
    微信小程序-起步
    SpringBoot MongoDB操作封装
  • 原文地址:https://blog.csdn.net/jp_666/article/details/132759960