def align_predictions(predictions: np.ndarray, label_ids: np.ndarray) -> Tuple[List[list], List[list]]:
    """Convert predicted and gold label ids to label names, skipping ignored positions.

    Args:
        predictions: 2-D array of predicted class ids, one row per sequence.
            The argmax over the class axis must already have been applied;
            this function no longer recomputes it.
        label_ids: Gold label ids, indexed the same way as ``predictions``.
            Positions equal to the loss ``ignore_index`` are dropped.

    Returns:
        ``(preds_list, out_label_list)``: each is a list with one inner list
        per sequence, holding the ``label_map`` entries for the kept
        positions. ``label_map`` is a module-level lookup defined outside
        this function.

    Raises:
        ValueError: If ``predictions.shape`` does not unpack into exactly
            two dimensions (for example when raw 3-D logits are passed).
    """
    # The ignore sentinel is a constant of the loss class (-100 by default),
    # so read it once instead of building a loss module for every token.
    ignore_index = nn.CrossEntropyLoss().ignore_index

    batch_size, seq_len = predictions.shape
    out_label_list = [[] for _ in range(batch_size)]
    preds_list = [[] for _ in range(batch_size)]

    for i in range(batch_size):
        for j in range(seq_len):
            gold_id = label_ids[i, j]
            if gold_id != ignore_index:
                out_label_list[i].append(label_map[gold_id])
                preds_list[i].append(label_map[predictions[i][j]])

    return preds_list, out_label_list
# Added to keep the collected numpy arrays from growing too large: reduce the logits ahead of time.
def preprocess_logits_for_metrics(logits, labels):
    """Reduce model logits to predicted class ids.

    Args:
        logits: Either the logits object itself or a tuple whose first
            element is the logits; any further tuple elements are discarded.
        labels: Unused here; accepted so the function matches the callback
            signature it is registered under.

    Returns:
        The result of ``argmax(dim=-1)`` on the logits.
    """
    # Some models/configs return extra tensors (e.g. past_key_values) after
    # the logits; the logits always come first.
    scores = logits[0] if isinstance(logits, tuple) else logits
    return scores.argmax(dim=-1)
# Build the Trainer. `preprocess_logits_for_metrics` is registered together
# with `compute_metrics` so that the logits are reduced to class ids before
# the metric function sees them (see the Trainer documentation for when the
# hook is invoked).
trainer = Trainer(
    model=model,
    args=training_args,
    # Metric computation and its logits-reduction hook.
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    # Data used for training and evaluation.
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
当然,如果你能调整虚拟内存设置,从硬盘空间中划出一部分作为虚拟内存(建议设为物理内存的 2 倍),那代码运行速度会更快!
参考的设置方法:https://zhuanlan.zhihu.com/p/37332255