• 利用决策树找出最优特征组合


     

    目录

    一、基本流程

    二、实现代码

            2.1、导包

            2.2、读取数据

            2.3、特征聚合及衍生

            2.4、训练决策树

            2.5、解析树模型

    三、总结


            大致表达的意思是,对流水数据聚合特征。然后特征衍生,接着对衍生后的特征决策树训练,输出对应的特征组合。

            上图中,x1,x4,x8是叶子结点到根结点的路径,也就是说可以看成一组特征组合。同样,x1,x4,x9为一组。x1,x2,x6为一组。x1,x7,x6为一组,x1,x7,x4为一组。共四组,只需要解析出这几组数,然后去我们训练决策树的样本特征中把对应的特征提取出来,并返回就可以了。

     

    一、基本流程

            

     

    二、实现代码

            因为我这边是用的notebook写的,所以代码一段一段截取。

            2.1、导包

    1. import pymysql
    2. import sys
    3. from sqlalchemy import create_engine
    4. from sklearn.ensemble import GradientBoostingClassifier
    5. import pandas as pd
    6. import sys
    7. from sklearn.metrics import classification_report
    8. from sklearn import tree
    9. from sklearn.tree import _tree
    10. from graphviz import Source
    11. from ipywidgets import interactive
    12. from IPython.display import SVG, display
    13. from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, export_graphviz
    14. from sklearn import tree
    15. import graphviz
    16. from collections import deque
    17. import re

            2.2、读取数据

    1. def query_sql(cursor,sql):
    2. cursor.execute(sql)
    3. res = cursor.fetchall()
    4. col = [item[0] for item in cursor.description]
    5. crash_result = pd.DataFrame(res, columns=col)
    6. return crash_result

            2.3、特征聚合及衍生

    1. connection = pymysql.connect(host='ip', user='root', password='password', database='xxx',
    2. port=端口)
    3. cursor = connection.cursor()
    4. # 读取标签数据
    5. target_info = pd.read_excel(r'样本标签.xlsx')
    6. target_info = target_info[['卡号','是否可疑(人工)']]
    7. # 计算特征 —— 交易总金额
    8. sql = '''
    9. select `客户名称`,`查询账号`,`查询卡号`,sum(`交易金额`) `交易总金额`
    10. from `trade1`
    11. group by `客户名称`,`查询账号`,`查询卡号`
    12. '''
    13. df_sum_money = query_sql(cursor,sql)
    14. # 交易次数
    15. sql = '''
    16. select `客户名称`,`查询账号`,`查询卡号`,count(1) `交易笔数`
    17. from `trade1`
    18. group by `客户名称`,`查询账号`,`查询卡号`
    19. '''
    20. df_trade_count = query_sql(cursor,sql)
    21. # 入向交易笔数
    22. sql = '''
    23. select `客户名称`,`查询账号`,`查询卡号`,count(1) `入向交易笔数`
    24. from `trade1`
    25. where 借贷标志 = '进'
    26. group by `客户名称`,`查询账号`,`查询卡号`
    27. '''
    28. df_in_num = query_sql(cursor,sql)
    29. # 出向交易笔数
    30. sql = '''
    31. select `客户名称`,`查询账号`,`查询卡号`,count(1) `出向交易笔数`
    32. from `trade1`
    33. where 借贷标志 = '出'
    34. group by `客户名称`,`查询账号`,`查询卡号`
    35. '''
    36. df_out_num = query_sql(cursor,sql)
    37. # 交易总天数
    38. sql = '''
    39. select `客户名称`,`查询账号`,`查询卡号`,count(1) '交易总天数'
    40. from
    41. (
    42. select `客户名称`,`查询账号`,`查询卡号`,substring(`交易时间`,1,10)
    43. from trade1
    44. group by `客户名称`,`查询账号`,`查询卡号`,substring(`交易时间`,1,10)
    45. ) a
    46. group by `客户名称`,`查询账号`,`查询卡号`
    47. '''
    48. df_sum_day = query_sql(cursor,sql)
    49. merge1 = pd.merge(df_sum_money,df_trade_count,on=['客户名称', '查询账号', '查询卡号'], how='inner')
    50. merge2 = pd.merge(merge1,df_in_num,on=['客户名称', '查询账号', '查询卡号'], how='inner')
    51. merge3 = pd.merge(merge2,df_out_num,on=['客户名称', '查询账号', '查询卡号'], how='inner')
    52. merge = pd.merge(merge3,df_sum_day,on=['客户名称', '查询账号', '查询卡号'], how='inner')
    53. final_result = pd.merge(target_info,merge,left_on=['卡号'],right_on=['查询卡号'],how='inner')
    54. final_result = final_result.iloc[:,1:]
    55. final_result['是否可疑(人工)'] = final_result['是否可疑(人工)'].apply(lambda x:1 if x == '可疑' else 0)
    56. print(final_result)
    57. print(final_result.info())
    58. x = final_result[['交易总金额','交易笔数','入向交易笔数','出向交易笔数','交易总天数']]
    59. y = final_result['是否可疑(人工)']
    60. print(x[['交易总金额','交易笔数']])
    61. # 两两组合新特征
    62. new_x = pd.DataFrame()
    63. new_x['交易总金额_交易笔数'] = x['交易总金额']+x['交易笔数']
    64. new_x['交易总金额_入向交易笔数'] = x['交易总金额']+x['入向交易笔数']
    65. new_x['交易总金额_出向交易笔数'] = x['交易总金额']+x['出向交易笔数']
    66. new_x['交易总金额_交易总天数'] = x['交易总金额']+x['交易总天数']
    67. new_x['交易笔数_入向交易笔数'] = x['交易笔数']+x['入向交易笔数']
    68. new_x['交易笔数_出向交易笔数'] = x['交易笔数']+x['出向交易笔数']
    69. new_x['交易笔数_交易总天数'] = x['交易笔数']+x['交易总天数']
    70. new_x['入向交易笔数_出向交易笔数'] = x['入向交易笔数']+x['出向交易笔数']
    71. new_x['入向交易笔数_交易总天数'] = x['入向交易笔数']+x['交易总天数']
    72. new_x['出向交易笔数_交易总天数'] = x['出向交易笔数']+x['交易总天数']
    73. print(new_x)

            2.4、训练决策树

    1. Dtree = tree.DecisionTreeRegressor(max_depth=3,random_state=111)
    2. dtree = Dtree.fit(new_x, y)
    3. n_nodes = dtree.tree_.node_count
    4. children_left = dtree.tree_.children_left
    5. children_right = dtree.tree_.children_right
    6. feature = dtree.tree_.feature
    7. threshold = dtree.tree_.threshold
    8. dot_data = tree.export_graphvizdot_data = tree.export_graphviz(dtree, out_file=None, )
    9. graph = graphviz.Source(dot_data)
    10. graph.render("dt")
    11. '''
    12. 在文件的同级目录上生成一个dt和dt.pdf文件。
    13. dt.pdf文件存储树模型的可视化
    14. dt文件则是树模型对应的原始数据(非结构化)
    15. '''

            2.5、解析树模型

    1. relation = []
    2. node = []
    3. for c,i in enumerate(dot_data.replace('\n','').split(';')):
    4. if c == 0 or c == 1:
    5. continue
    6. if '->' in i:
    7. relation.append(i.split('[')[0].strip())
    8. else:
    9. if 'X[' not in i:
    10. i = i.split(' ')[0] + ' None'
    11. node.append(i)
    12. else:
    13. node.append(i)
    14. # 将结点信息转换为字典,方便后续拿取
    15. node_dict = {}
    16. for i in node:
    17. key = i[0:2]
    18. value = i[2:]
    19. node_dict[key.strip()] = value.strip()
    20. # 将结点指针信息转为字典,方便组合特征
    21. relation_dict = {}
    22. agg = []
    23. other = ''
    24. for c,i in enumerate(relation):
    25. i_info = i.replace(' ','').split('->')
    26. if i_info[0] == '0':
    27. agg.append(i_info[0])
    28. agg.append(i_info[-1])
    29. elif i_info[0] == agg[-1] and int(i_info[-1]) - int(i_info[0]) == 1:
    30. agg.append(i_info[-1])
    31. # elif int(i_info[-1]) - int(i_info[0]) != 1:
    32. else:
    33. # 同路径不同结点 与 其子结点
    34. other += f'{i_info[0]}' + f' {i_info[-1]}|'
    35. # 记录根结点位置
    36. node = []
    37. for c,i in enumerate(agg):
    38. if i == '0':
    39. node.append(c)
    40. node.append(len(agg))
    41. # 解析出组合
    42. agg_feature_index = []
    43. for c in range(0,len(node)-1,1):
    44. agg_feature_index.append(agg[node[c]:node[c+1]])
    45. # 合并不同路径下的所有相邻结点
    46. # 记录上一个结点的尾巴
    47. last = ['0']
    48. head = ['0']
    49. new_other = ''
    50. for i in other.split('|')[0:-1]:
    51. i = i.split(' ')
    52. if i[0] == last[-1]:
    53. new_other += f'{head[-1],i[0],i[-1]}|'
    54. else:
    55. new_other += f'{i[0],i[-1]}|'
    56. last.append(i[-1])
    57. head.append(i[0])
    58. # 筛选出重复的数据
    59. new_other_len = len(new_other.split('|'))
    60. # 因为有一个 | 。所以要多减去一个自然数
    61. all_other = []
    62. repetition = []
    63. for i in range(new_other_len-2):
    64. f1 = eval(tuple(new_other.split('|'))[i])
    65. f2 = eval(tuple(new_other.split('|'))[i+1])
    66. all_other.append(f1)
    67. if set(f1).issubset(f2) == True:
    68. repetition.append(f1)
    69. all_other.append(eval(tuple(new_other.split('|'))[new_other_len-2]))
    70. # 去掉重复的
    71. final_other = []
    72. for i in range(len(all_other)):
    73. if all_other[i] not in repetition:
    74. final_other.append(all_other[i])
    75. # 组合特征所在下标
    76. for i in final_other:
    77. try:
    78. head_index = agg_feature_index[0].index(i[0])
    79. result = agg_feature_index[0][:head_index+1]
    80. for j in i[1:]:
    81. result.append(j)
    82. except:
    83. head_index = agg_feature_index[1].index(i[0])
    84. result = agg_feature_index[1][:head_index+1]
    85. for j in i[1:]:
    86. result.append(j)
    87. finally:
    88. agg_feature_index.append(result)
    89. # 找出下标对应的特征
    90. feature_index = []
    91. for i in agg_feature_index:
    92. if i[:-1] in feature_index:
    93. continue
    94. else:
    95. feature_index.append(i[:-1])
    96. # 组合特征,结果输出
    97. feature = {}
    98. for c,i in enumerate(feature_index):
    99. f = []
    100. for j in i:
    101. for k in re.findall('\[.*\]',node_dict.get(j)):
    102. if len(k.split('\\n')[0].split('[')) == 3:
    103. # 取出对应下标
    104. fn = int(k.split('\\n')[0].replace('[label="','').split('<=')[0].strip().replace('X[','').replace(']',''))
    105. # 取出对应特征名称
    106. f.append(new_x.columns.tolist()[fn])
    107. feature[f'feature_{c+1}'] = f
    108. print(feature)
    109. # 最终输出从根结点到叶子结点的四种特征组合
    110. '''
    111. {'feature_1': ['交易总金额_入向交易笔数', '交易笔数_入向交易笔数', '入向交易笔数_交易总天数'], 'feature_2': ['交易总金额_入向交易笔数', '交易总金额_出向交易笔数', '交易笔数_交易总天数'], 'feature_3': ['交易总金额_入向交易笔数', '交易笔数_入向交易笔数', '出向交易笔数_交易总天数'], 'feature_4': ['交易总金额_入向交易笔数', '交易总金额_出向交易笔数', '交易笔数_入向交易笔数']}
    112. '''

    三、总结

            以上只不过是在做需求时想到的一些办法,具体效果不敢保证,只能说是提供一种特征组合的思想。但是具体得到那n个特征后,是加?是减?就不能确定了。

  • 相关阅读:
    【学习笔记01】node的认识和安装
    分享76个PHP源码总有一个适合你
    量产技术与成本比拼“升级”,谁能打赢4D成像雷达的规模化之战?
    P4_toturial练习1问题:ModuleNotFoundError: No module named ‘p4.tmp‘
    Linux学习记录——이십구 网络基础(2)
    TensorRT的结构
    墨西哥FBA海运头程货代,墨西哥海运几天到?
    这个开源项目超哇塞,手写照片在线生成
    ASP.net core 8.0网站发布
    C++:函数指针进阶二:指向对象成员函数的指针
  • 原文地址:https://blog.csdn.net/zkkkkkkkkkkkkk/article/details/126126407