classification_report 加入 top-k 计算


    参考:https://blog.csdn.net/dipizhong7224/article/details/104579159
    scikit-learn 源码(classification_report 的实现):https://github.com/scikit-learn/scikit-learn/blob/7f9bad99d6e0a3e8ddf92a7e5561245224dab102/sklearn/metrics/_classification.py#L1551

    def classification_report_topk(y_true, y_pred, topk=1, labelnames=None, digits=2, output_dict=False,):
        """Build a scikit-learn style classification report with top-k matching.

        A sample is a hit for its true label when that label is among the first
        ``topk`` entries of the sample's prediction row.  Per label:
        precision = TP / (TP + FP) and recall = TP / (TP + FN).  Here TP + FP is
        the number of samples whose top-k predictions contain the label.  TP + FN
        is the number of samples whose true label is the label.

        Example::

            y_true     = [1, 1, 2, 3]
            y_pred     = [[1, 3], [3, 2], [2, 3], [1, 2]]
            labelnames = [1, 2, 3]

        Args:
            y_true: iterable of true labels (list, tuple or 1-D numpy array).
            y_pred: one ranked prediction row per sample (nested sequence or
                2-D numpy array).  Only the first ``topk`` entries are used.
            topk: number of leading predictions per row to consider.
            labelnames: labels to report on.  When ``None`` they are inferred
                with ``sklearn.utils.multiclass.unique_labels`` from ``y_true``
                and every entry of ``y_pred``.
            digits: decimals shown in the text report.
            output_dict: return a nested dict instead of a formatted string.

        Returns:
            The formatted report string.  When ``output_dict`` is true, returns a
            dict instead.  It maps each row name (``str(label)``, ``'accuracy'``,
            ``'macro avg'``, ``'weighted avg'``) to a dict of float
            ``'precision'``, ``'recall'``, ``'f1-score'`` and ``'support'``.

        Raises:
            ValueError: if ``y_true`` and ``y_pred`` differ in length.
            AssertionError: if ``topk`` exceeds the length of the first row.
        """
        import numpy as np

        y_true = list(y_true)  # also accept numpy arrays / tuples
        if len(y_pred) != len(y_true):
            raise ValueError(
                'y_true and y_pred must have the same length, got %d and %d'
                % (len(y_true), len(y_pred)))
        if len(y_pred):
            assert topk <= len(y_pred[0]), 'topk out of bounds!'

        # `is None` instead of `== None`: the latter is elementwise (and then
        # ambiguous inside `if`) when labelnames is a numpy array.
        if labelnames is None:
            from sklearn.utils.multiclass import unique_labels
            if isinstance(y_pred, np.ndarray):
                all_pred = y_pred.flatten()
            else:
                all_pred = [p for row in y_pred for p in row]
            labelnames = unique_labels(y_true, all_pred)

        true_counts, pred_counts, hit_counts = _topk_counts(y_true, y_pred, topk)

        rows = []
        tp_sum = 0  # hits over the reported labels only; feeds the accuracy row
        for label in labelnames:
            tp = hit_counts[label]
            tp_fp = pred_counts[label]
            tp_fn = true_counts[label]
            tp_sum += tp
            precision = tp / tp_fp if tp_fp else 0.0
            recall = tp / tp_fn if tp_fn else 0.0
            # tp > 0 implies precision > 0 and recall > 0; otherwise F1 is 0.
            f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
            rows.append([str(label), precision, recall, f1, tp_fn])

        label_rows = list(rows)
        n_samples = len(y_true)
        accuracy = tp_sum / n_samples if n_samples else 0.0
        rows.append(['accuracy', accuracy, accuracy, accuracy, n_samples])

        supports = [row[4] for row in label_rows]
        for avg_name, weights in (('macro', None), ('weighted', supports)):
            p, r, f1 = (_topk_safe_average([row[col] for row in label_rows], weights)
                        for col in (1, 2, 3))
            rows.append([avg_name + ' avg', p, r, f1, n_samples])

        headers = ["precision", "recall", "f1-score", "support"]
        if output_dict:
            return {row[0]: dict(zip(headers, [float(v) for v in row[1:]])) for row in rows}
        return _topk_format_report(rows, headers, digits)

    def _topk_counts(y_true, y_pred, topk):
        """Count per label: true occurrences (TP+FN), top-k predictions (TP+FP), hits (TP).

        Each sample contributes at most once per label, even if its top-k row
        repeats that label.
        """
        from collections import Counter

        true_counts = Counter(y_true)
        pred_counts = Counter()
        hit_counts = Counter()
        for truth, row in zip(y_true, y_pred):
            top = set(row[:topk])
            pred_counts.update(top)
            if truth in top:
                hit_counts[truth] += 1
        return true_counts, pred_counts, hit_counts

    def _topk_safe_average(values, weights=None):
        """Return ``np.average(values, weights=weights)``, or 0.0 when undefined.

        The average is undefined when there are no values or when the weights
        sum to zero.
        """
        import numpy as np

        if not values or (weights is not None and sum(weights) == 0):
            return 0.0
        return np.average(values, weights=weights)

    def _topk_format_report(rows, headers, digits):
        """Render report rows as an aligned text table in sklearn's layout."""
        longest_last_line_heading = "weighted avg"
        name_width = max(len(row[0]) for row in rows)
        width = max(name_width, len(longest_last_line_heading), digits)
        head_fmt = "{:>{width}s} " + " {:>9}" * len(headers)
        row_fmt = "{:>{width}s} " + " {:>9.{digits}f}" * 3 + " {:>9}\n"
        report = head_fmt.format("", *headers, width=width) + "\n\n"
        report += "".join(row_fmt.format(*row, width=width, digits=digits) for row in rows)
        return report + "\n"
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
  • 相关阅读:
    cmake-format使用教程
    群晖linux ——设置短密码、免密码登录、多个群晖免密登录
    [七夕节]——html炫酷七夕情人节表白动画特效
    在Win11上部署ChatGLM2-6B详细步骤--(下)开始部署
    AP5193 DC-DC恒流转换器 消防应急 灯汽车灯 应急日光灯太阳能灯驱动IC
    Python数据分析11——Seaborn绘图
    从车窗升降一探 Android 车机的重要 API:车辆属性 CarProperty
    Map接口遍历方法
    Lightrun还可以这样用?
    wordpress网站搭建(centos stream 9)
  • 原文地址:https://blog.csdn.net/joyce_peng/article/details/132737026