• AI4 决策树的生成与训练-信息增益


    1. # -*- coding: UTF-8 -*-
    2. from math import log
    3. import pandas as pd
    4. dataSet = pd.read_csv('dataSet.csv', header=None).values.tolist()
    5. # 给定一个数据集,calcInfoEnt可以用于计算一个数据集的信息熵,可直接调用
    6. # 也可不使用,通过自己的方式计算信息增益
    7. def calcInfoEnt(data):
    8. numEntres = len(data)
    9. labelcnt = {} # 用于统计正负样本的个数
    10. for item in data:
    11. if item[-1] not in labelcnt:
    12. labelcnt[item[-1]] = 0
    13. labelcnt[item[-1]] += 1
    14. infoEnt = 0.0
    15. for item in labelcnt: # 根据信息熵的公式计算信息熵
    16. curr_info_entr = float(labelcnt[item]) / numEntres
    17. infoEnt = infoEnt - curr_info_entr * log(curr_info_entr, 2)
    18. return infoEnt
    19. # 返回值 infoEnt 为数据集的信息熵
    20. # 给定一个数据集,用于切分一个子集,可直接用于计算某一特征的信息增益
    21. # 也可不使用,通过自己的方式计算信息增益
    22. # dataSet是要划分的数据集,i 代表第i个特征的索引index
    23. # value对应该特征的某一取值
    24. def create_sub_dataset(dataSet, i, value):
    25. res = []
    26. for item in dataSet:
    27. if item[i] == value:
    28. curr_data = item[:i] + item[i + 1:]
    29. res.append(curr_data)
    30. return res
    31. def calc_max_info_gain(dataSet): # 计算所有特征的最大信息增益,dataSet为给定的数据集
    32. n = len(dataSet[0]) - 1 # n 是特征的数量,-1 的原因是最后一列是分类标签
    33. total_entropy = calcInfoEnt(dataSet) # 整体数据集的信息熵
    34. max_info_gain = [0, 0] # 返回值初始化
    35. # code start here
    36. best_feature=-1
    37. for i in range(n):
    38. featList=[example[i] for example in dataSet]
    39. uniqueVals=set(featList)
    40. newEntropy=0.0
    41. for value in uniqueVals:
    42. subDataSet=create_sub_dataset(dataSet,i,value)
    43. prob=len(subDataSet)/float(len(dataSet))
    44. newEntropy+=prob*calcInfoEnt(subDataSet)
    45. infoGain=total_entropy-newEntropy
    46. if(infoGain>max_info_gain[1]):
    47. max_info_gain[1]=infoGain
    48. max_info_gain[0]=i
    49. best_feature=1
    50. # code end here
    51. return max_info_gain
    52. if __name__ == '__main__':
    53. info_res = calc_max_info_gain(dataSet)
    54. print("信息增益最大的特征索引为:{0},对应的信息增益为{1}".format(info_res[0], info_res[1]))

  • 相关阅读:
    【并发编程五】c++进程通信——共享内存(shared memmory)
    【测试沉思录】11. 如何进行基准测试?
    1.4.16 实验16:ABR汇总
    Makefile从入门到入门
    期货量化交易客户端开源教学第九节——新用户注册
    springmvc异常处理解析#ExceptionHandlerExceptionResolver
    WSL2 ubuntu18.04安装ROS
    Mybatis的XML配置文件
    MaskRcnn训练自己的数据集
    python 数据挖掘与机器学习核心技术
  • 原文地址:https://blog.csdn.net/kling_bling/article/details/126365047