def update(self, a, b):
n = self.num_classes
if self.mat is None:
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
with torch.no_grad():
k = (a >= 0) & (a < n) # 0,1
inds = n * a[k].to(torch.int64) + b[k]
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
这是语义分割生成混淆矩阵的代码,详细流程:
1、传入参数a为 target.flatten(),b为 output.argmax(1).flatten(),b为模型预测值,找到概率最大的索引值作为最终预测值,所以会用到.argmax(1)
2、初始化混淆矩阵,大小为n×n
3、k返回的是一个布尔类型的一维向量,超过或小于类别范围返回False
4、根据k中False的索引值,将真实标签对应位置元素丢掉,即a[k]的长度可能会小于a,并且根据有意义的真实标签索引值找到对应预测值进行计算,乘以n相当于flatten的逆过程,方便reshape之后变成一个n×n矩阵