• BERT listwise ranking demo


    下面是一个使用 BERT 训练 listwise 排序(ranking)模型的示例(demo):

    1. import torch
    2. from torch.utils.data import DataLoader, Dataset
    3. from transformers import BertModel, BertTokenizer
    4. from sklearn.metrics import pairwise_distances_argmin_min
    class ListwiseRankingDataset(Dataset):
        """Flat dataset of tokenized (query, document) pairs with their labels.

        Each query is paired with each of its candidate documents. Every pair is
        tokenized with ``padding='max_length'`` and ``truncation=True`` and stored
        as one row, query by query, in input order.

        Args:
            queries: Sequence of query strings.
            documents: Sequence of document lists; ``documents[i]`` holds the
                candidates for ``queries[i]``.
            labels: Sequence of label lists aligned with ``documents``.
            tokenizer: Callable invoked as ``tokenizer(query, doc, padding=...,
                truncation=..., max_length=..., return_tensors='pt')`` returning a
                mapping with ``'input_ids'`` and ``'attention_mask'`` tensors
                (assumed to have shape ``(1, max_length)`` -- confirm for
                tokenizers other than the Hugging Face ones).
            max_length: Length passed to the tokenizer for padding/truncation.

        Note:
            ``zip`` is used at both levels, so surplus queries, documents or
            labels beyond the shortest sequence are silently ignored.
        """

        def __init__(self, queries, documents, labels, tokenizer, max_length):
            input_ids = []
            attention_masks = []
            flat_labels = []
            for query, doc_list, label_list in zip(queries, documents, labels):
                for doc, label in zip(doc_list, label_list):
                    encoded_pair = tokenizer(
                        query,
                        doc,
                        padding='max_length',
                        truncation=True,
                        max_length=max_length,
                        return_tensors='pt',
                    )
                    input_ids.append(encoded_pair['input_ids'])
                    attention_masks.append(encoded_pair['attention_mask'])
                    flat_labels.append(label)

            if input_ids:
                self.input_ids = torch.cat(input_ids, dim=0)
                self.attention_masks = torch.cat(attention_masks, dim=0)
            else:
                # torch.cat raises on an empty list, so build explicit empty
                # tensors to let an empty dataset behave like any other.
                self.input_ids = torch.empty((0, max_length), dtype=torch.long)
                self.attention_masks = torch.empty((0, max_length), dtype=torch.long)
            self.labels = torch.tensor(flat_labels)

        def __len__(self):
            """Return the number of stored (query, document) pairs."""
            return len(self.input_ids)

        def __getitem__(self, idx):
            """Return ``(input_ids, attention_mask, label)`` for pair ``idx``."""
            return self.input_ids[idx], self.attention_masks[idx], self.labels[idx]
    class BERTListwiseRankingModel(torch.nn.Module):
        """BERT encoder followed by dropout and a linear head giving one score per input.

        Args:
            bert_model_name: Name or path passed to ``BertModel.from_pretrained``.
        """

        def __init__(self, bert_model_name):
            super().__init__()
            self.bert = BertModel.from_pretrained(bert_model_name)
            self.dropout = torch.nn.Dropout(0.1)
            self.fc = torch.nn.Linear(self.bert.config.hidden_size, 1)

        def forward(self, input_ids, attention_mask):
            """Score a batch of encoded (query, document) pairs.

            Args:
                input_ids: Token ids for the batch.
                attention_mask: Attention mask matching ``input_ids``.

            Returns:
                A 1-D tensor with one raw (un-normalized) score per batch row.
            """
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            # Element 1 of the BertModel output is the pooled sequence representation.
            pooled_output = self.dropout(outputs[1])
            logits = self.fc(pooled_output)
            # Drop only the trailing size-1 feature dim. A bare squeeze() would
            # also remove the batch dim for a batch of one, yielding a 0-d tensor
            # whose shape no longer matches the labels in the loss.
            return logits.squeeze(-1)
    # Pretrained checkpoint shared by the tokenizer and the ranking model
    bert_model_name = 'bert-base-uncased'
    tokenizer = BertTokenizer.from_pretrained(bert_model_name)

    # Hyperparameters
    batch_size, max_length = 8, 128
    learning_rate, num_epochs = 1e-5, 5

    # Toy input data: documents[i] and labels[i] are the candidate documents
    # and their labels for queries[i]
    queries = ['I like cats', 'The sun is shining']
    documents = [
        ['I like dogs', 'Dogs are cute'],
        ['It is raining', 'Rainy weather is gloomy'],
    ]
    labels = [[1, 0], [0, 1]]

    # Dataset and shuffling data loader
    dataset = ListwiseRankingDataset(queries, documents, labels, tokenizer, max_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Model (loads the pretrained BERT weights) and its optimizer
    model = BERTListwiseRankingModel(bert_model_name)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    # Train the model.
    # NOTE: the objective below is pointwise -- every (query, document) pair is
    # treated as an independent binary example -- not a listwise loss over each
    # query's full candidate list.
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        # The batch labels get their own name so the module-level `labels`
        # list is not overwritten by the loop variable.
        for input_ids, attention_masks, batch_labels in dataloader:
            optimizer.zero_grad()
            # reshape(-1) guarantees 1-D logits matching batch_labels, even for
            # a batch that holds a single pair.
            logits = model(input_ids, attention_masks).reshape(-1)
            # Binary cross-entropy on the raw scores
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                logits, batch_labels.float()
            )
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
        # total_loss is the sum of the per-batch losses for this epoch
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {total_loss:.4f}")
    # Inference: score every (query, document) pair with the trained model.
    # Distances between raw word embeddings would ignore both the encoder and
    # the ranking head, so the model's own scores are used instead.
    model.eval()
    score_chunks = []
    with torch.no_grad():
        # Score in chunks of batch_size rather than in one forward pass
        for start in range(0, len(dataset), batch_size):
            stop = start + batch_size
            batch_logits = model(dataset.input_ids[start:stop], dataset.attention_masks[start:stop])
            # sigmoid maps the raw scores to (0, 1); reshape(-1) keeps each
            # chunk 1-D even when it holds a single pair
            score_chunks.append(torch.sigmoid(batch_logits).reshape(-1))
    scores = torch.cat(score_chunks) if score_chunks else torch.empty(0)

    # Print results: each query's documents, highest score first.
    # The dataset stores pairs query by query in input order, so each query owns
    # a contiguous slice of `scores`. This assumes every document has a label,
    # so that no pair was dropped when the dataset was built.
    offset = 0
    for query, doc_list in zip(queries, documents):
        query_scores = scores[offset:offset + len(doc_list)]
        offset += len(doc_list)
        print(f"Query: {query}")
        print("Documents:")
        ranking = torch.argsort(query_scores, descending=True).tolist()
        for doc_idx in ranking:
            doc_score = query_scores[doc_idx].item()
            print(f"Document index: {doc_idx}, Score: {doc_score:.4f}")
            print(f"Document: {doc_list[doc_idx]}")
            print("")
        print("---------")

  • 相关阅读:
    【每日一题Day44】LC1779找到最近的相同X和相同Y的点 | 模拟
    设计模式-责任链模式
    7-爬虫-中间件和下载中间件(加代理,加请求头,加cookie)、scrapy集成selenium、源码去重规则(布隆过滤器)、分布式爬虫
    Linux 和 macOS 下 rename 批量重命名文件
    计算机等级考试—信息安全三级真题九
    AppInfo应用信息查看V1.0.2测试版
    消费行业分析:我国可降解餐具现状分析及未来发展趋势预测
    这才是真正的Spring全家桶:Spring+SpringData+MVC+Boot+Cloud
    YOLOv5项目实战(4)— 简单三步,教你按比例划分数据集
    JVM相关概念
  • 原文地址:https://blog.csdn.net/jp_666/article/details/132759643