差在一个 sigmoid 函数上
见下面的代码
import torch from torch.nn import functional as F logits = torch.rand(16,7) ys = torch.randint(0,2,(16,7)) F.binary_cross_entropy_with_logits(logits,ys.float()) == F.binary_cross_entropy(torch.sigmoid(logits),ys.float())
京公网安备 11010502049817号