下面是用 TripletLoss 优化 BERT 排序（ranking）模型的 demo：
-
import torch
from sklearn.metrics.pairwise import pairwise_distances
from torch.utils.data import DataLoader, Dataset
from transformers import BertConfig, BertModel, BertTokenizer
-
class TripletRankingDataset(Dataset):
    """Pre-tokenized (query, positive doc, negative doc) triplets.

    Every text is padded/truncated to ``max_length``, so each stored tensor
    (``input_ids_q``, ``attention_masks_q``, ``input_ids_p``, ...) has shape
    ``(num_triplets, max_length)``.

    Args:
        queries: anchor texts.
        positive_docs: texts that should end up close to the matching query.
        negative_docs: texts that should end up far from the matching query.
        tokenizer: a Hugging Face-style tokenizer, called on a list of strings
            with ``padding='max_length'`` and ``return_tensors='pt'``.
        max_length: fixed sequence length used for padding and truncation.

    Raises:
        ValueError: if the three text sequences do not have the same length.
    """

    def __init__(self, queries, positive_docs, negative_docs, tokenizer, max_length):
        queries = list(queries)
        positive_docs = list(positive_docs)
        negative_docs = list(negative_docs)
        # zip() would silently drop unmatched items; a length mismatch almost
        # always means misaligned triplets, so fail loudly instead.
        if not len(queries) == len(positive_docs) == len(negative_docs):
            raise ValueError(
                'queries, positive_docs and negative_docs must have the same length, '
                f'got {len(queries)}, {len(positive_docs)} and {len(negative_docs)}'
            )

        self.input_ids_q, self.attention_masks_q = self._encode(queries, tokenizer, max_length)
        self.input_ids_p, self.attention_masks_p = self._encode(positive_docs, tokenizer, max_length)
        self.input_ids_n, self.attention_masks_n = self._encode(negative_docs, tokenizer, max_length)

    @staticmethod
    def _encode(texts, tokenizer, max_length):
        """Tokenize *texts* in one batch and return ``(input_ids, attention_mask)``.

        An empty list yields two empty ``(0, max_length)`` long tensors
        instead of calling the tokenizer on nothing.
        """
        if not texts:
            return (
                torch.zeros((0, max_length), dtype=torch.long),
                torch.zeros((0, max_length), dtype=torch.long),
            )
        encoded = tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        )
        return encoded['input_ids'], encoded['attention_mask']

    def __len__(self):
        """Number of triplets."""
        return len(self.input_ids_q)

    def __getitem__(self, idx):
        """Return ``(ids_q, mask_q, ids_p, mask_p, ids_n, mask_n)`` for triplet *idx*."""
        return (
            self.input_ids_q[idx],
            self.attention_masks_q[idx],
            self.input_ids_p[idx],
            self.attention_masks_p[idx],
            self.input_ids_n[idx],
            self.attention_masks_n[idx],
        )
-
class BERTTripletRankingModel(torch.nn.Module):
    """BERT encoder followed by a linear head that maps each text to one score.

    The pooled BERT output (``outputs[1]``) passes through dropout and a
    ``Linear(hidden_size, 1)`` layer, so each input sequence is represented by
    a single scalar.

    Args:
        bert_model_name: name or path passed to ``BertModel.from_pretrained``.
        hidden_size: width of the pooled BERT output. When omitted it is taken
            from the loaded model's ``config.hidden_size``.
    """

    def __init__(self, bert_model_name, hidden_size=None):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        if hidden_size is None:
            hidden_size = self.bert.config.hidden_size
        self.dropout = torch.nn.Dropout(0.1)
        self.fc = torch.nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return one score per input sequence.

        ``input_ids`` / ``attention_mask`` are presumably ``(batch, seq_len)``
        tokenizer outputs (TODO confirm); the result then has shape ``(batch,)``.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(outputs[1])
        logits = self.fc(pooled_output)
        # Drop only the size-1 feature dim: a bare squeeze() would also drop
        # the batch dim when batch == 1 and return a 0-d tensor.
        return logits.squeeze(-1)
-
def triplet_loss(anchor, positive, negative, margin):
    """Mean triplet margin loss over a batch.

    For every example ``i`` the loss is
    ``max(0, d(anchor_i, positive_i) - d(anchor_i, negative_i) + margin)``,
    where ``d`` is ``torch.nn.functional.pairwise_distance`` (Euclidean).
    The per-example losses are then averaged.

    Args:
        anchor, positive, negative: tensors of identical shape. 2-D inputs are
            ``(batch, dim)`` embeddings. 1-D inputs are treated as a batch of
            scalar embeddings ``(batch,)``, and a 0-d input as a single scalar
            embedding, so distances are still computed per example.
        margin: required gap between the negative and positive distances.

    Returns:
        A 0-d tensor holding the mean loss.
    """
    def _as_rows(t):
        # pairwise_distance reduces over the last dimension, so a 1-D
        # (batch,) tensor would otherwise collapse the whole batch into a
        # single distance; make each example its own row instead.
        return t.reshape(-1, 1) if t.dim() < 2 else t

    anchor, positive, negative = (_as_rows(t) for t in (anchor, positive, negative))
    distance_positive = torch.nn.functional.pairwise_distance(anchor, positive)
    distance_negative = torch.nn.functional.pairwise_distance(anchor, negative)
    losses = torch.relu(distance_positive - distance_negative + margin)
    return torch.mean(losses)
-
# Initialize the BERT tokenizer (the model itself is built further down).
bert_model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_model_name)

# Toy data: the i-th query should end up closer to positive_docs[i] than to
# negative_docs[i].
queries = ['I like cats', 'The sun is shining']
positive_docs = ['I like dogs', 'The weather is beautiful']
negative_docs = ['Snakes are dangerous', 'It is raining']

# Hyper-parameters
batch_size = 8
max_length = 128
learning_rate = 1e-5
num_epochs = 5
margin = 1.0

# Build the dataset and data loader.
dataset = TripletRankingDataset(queries, positive_docs, negative_docs, tokenizer, max_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Build the model with pretrained weights. The hidden size is read from the
# checkpoint config because no model object exists yet at this point.
model = BERTTripletRankingModel(
    bert_model_name,
    hidden_size=BertConfig.from_pretrained(bert_model_name).hidden_size,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Training
model.train()

for epoch in range(num_epochs):
    total_loss = 0.0

    for batch in dataloader:
        input_ids_q, attention_masks_q, input_ids_p, attention_masks_p, input_ids_n, attention_masks_n = batch
        optimizer.zero_grad()

        embeddings_q = model(input_ids_q, attention_masks_q)
        embeddings_p = model(input_ids_p, attention_masks_p)
        embeddings_n = model(input_ids_n, attention_masks_n)

        loss = triplet_loss(embeddings_q, embeddings_p, embeddings_n, margin)
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    # Report the mean batch loss so the value does not grow with the number
    # of batches.
    mean_loss = total_loss / max(len(dataloader), 1)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {mean_loss:.4f}")


def _embed_texts(texts):
    """Run the trained model over *texts* and return one row per text."""
    encoded = tokenizer(
        list(texts),
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )
    outputs = model(encoded['input_ids'], encoded['attention_mask'])
    # The model may return (N,) scores or even a 0-d tensor for a single
    # text; normalize to a 2-D (N, d) matrix for distance computation.
    return outputs.reshape(len(texts), -1)


# Inference: rank every candidate document by its distance to each query,
# using the trained model (not the raw word-embedding table).
model.eval()
candidate_docs = positive_docs + negative_docs

with torch.no_grad():
    query_embeddings = _embed_texts(queries)
    doc_embeddings = _embed_texts(candidate_docs)

# Euclidean distances, shape (len(queries), len(candidate_docs)). Stored under
# a new name so the imported sklearn function is not shadowed.
distance_matrix = pairwise_distances(query_embeddings.numpy(), doc_embeddings.numpy())

# Print results, closest document first.
for i, query in enumerate(queries):
    print(f"Query: {query}")
    print("Documents:")

    for doc_idx in distance_matrix[i].argsort():
        doc_dist = distance_matrix[i][doc_idx]
        print(f"Document index: {doc_idx}, Distance: {doc_dist:.4f}")
        print(f"Document: {candidate_docs[doc_idx]}")
        print("")

    print("---------")