• 隐私计算 FATE - 多分类神经网络算法测试


    一、说明

    本文分享基于 Fate 使用 横向联邦 神经网络算法 对 多分类 的数据进行 模型训练,并使用该模型对数据进行 多分类预测

    • 二分类算法:是指待预测的 label 标签的取值只有两种;直白来讲就是每个实例的可能类别只有两种 (0 或者 1),例如性别只有  或者 ;此时的分类算法其实是在构建一个分类线将数据划分为两个类别。
    • 多分类算法:是指待预测的 label 标签的取值可能有多种情况,例如个人爱好可能有 篮球足球电影 等等多种类型。常见算法:Softmax、SVM、KNN、决策树。

    关于 Fate 的核心概念、单机部署、训练以及预测请参考以下相关文章:

    二、准备训练数据

    上传到 Fate 里的数据有两个字段名必需是规定的,分别是主键为 id 字段和分类字段为 y 字段,y 字段就是所谓的待预测的 label 标签;其他的特征字段 (属性) 可任意填写,例如下面例子中的 x0 - x9

    例如有一条用户数据为: 收入 : 10000,负债 : 5000,是否有还款能力 : 1 ;数据中的 收入 和 负债 就是特征字段,而 是否有还款能力 就是分类字段。

    本文只描述关键部分,关于详细的模型训练步骤,请查看文章《隐私计算 FATE - 模型训练

    2.1. guest 端

    10 条数据,包含 1 个分类字段 y 和 10 个标签字段 x0 - x9

    y 值有 0、1、2、3 四个分类

    上传到 Fate 中,表名为 muti_breast_homo_guest 命名空间为 experiment

    2.2. host 端

    10 条数据,字段与 guest 端一样,但是内容不一样

    上传到 Fate 中,表名为 muti_breast_homo_host 命名空间为 experiment

    三、执行训练任务

    3.1. 准备 dsl 文件

    创建文件 homo_nn_dsl.json 内容如下 :

    1. {
    2. "components": {
    3. "reader_0": {
    4. "module": "Reader",
    5. "output": {
    6. "data": [
    7. "data"
    8. ]
    9. }
    10. },
    11. "data_transform_0": {
    12. "module": "DataTransform",
    13. "input": {
    14. "data": {
    15. "data": [
    16. "reader_0.data"
    17. ]
    18. }
    19. },
    20. "output": {
    21. "data": [
    22. "data"
    23. ],
    24. "model": [
    25. "model"
    26. ]
    27. }
    28. },
    29. "homo_nn_0": {
    30. "module": "HomoNN",
    31. "input": {
    32. "data": {
    33. "train_data": [
    34. "data_transform_0.data"
    35. ]
    36. }
    37. },
    38. "output": {
    39. "data": [
    40. "data"
    41. ],
    42. "model": [
    43. "model"
    44. ]
    45. }
    46. }
    47. }
    48. }

    3.2. 准备 conf 文件

    创建文件 homo_nn_multi_label_conf.json 内容如下 :

    1. {
    2. "dsl_version": 2,
    3. "initiator": {
    4. "role": "guest",
    5. "party_id": 9999
    6. },
    7. "role": {
    8. "arbiter": [
    9. 10000
    10. ],
    11. "host": [
    12. 10000
    13. ],
    14. "guest": [
    15. 9999
    16. ]
    17. },
    18. "component_parameters": {
    19. "common": {
    20. "data_transform_0": {
    21. "with_label": true
    22. },
    23. "homo_nn_0": {
    24. "encode_label": true,
    25. "max_iter": 15,
    26. "batch_size": -1,
    27. "early_stop": {
    28. "early_stop": "diff",
    29. "eps": 0.0001
    30. },
    31. "optimizer": {
    32. "learning_rate": 0.05,
    33. "decay": 0.0,
    34. "beta_1": 0.9,
    35. "beta_2": 0.999,
    36. "epsilon": 1e-07,
    37. "amsgrad": false,
    38. "optimizer": "Adam"
    39. },
    40. "loss": "categorical_crossentropy",
    41. "metrics": [
    42. "accuracy"
    43. ],
    44. "nn_define": {
    45. "class_name": "Sequential",
    46. "config": {
    47. "name": "sequential",
    48. "layers": [
    49. {
    50. "class_name": "Dense",
    51. "config": {
    52. "name": "dense",
    53. "trainable": true,
    54. "batch_input_shape": [
    55. null,
    56. 18
    57. ],
    58. "dtype": "float32",
    59. "units": 5,
    60. "activation": "relu",
    61. "use_bias": true,
    62. "kernel_initializer": {
    63. "class_name": "GlorotUniform",
    64. "config": {
    65. "seed": null,
    66. "dtype": "float32"
    67. }
    68. },
    69. "bias_initializer": {
    70. "class_name": "Zeros",
    71. "config": {
    72. "dtype": "float32"
    73. }
    74. },
    75. "kernel_regularizer": null,
    76. "bias_regularizer": null,
    77. "activity_regularizer": null,
    78. "kernel_constraint": null,
    79. "bias_constraint": null
    80. }
    81. },
    82. {
    83. "class_name": "Dense",
    84. "config": {
    85. "name": "dense_1",
    86. "trainable": true,
    87. "dtype": "float32",
    88. "units": 4,
    89. "activation": "sigmoid",
    90. "use_bias": true,
    91. "kernel_initializer": {
    92. "class_name": "GlorotUniform",
    93. "config": {
    94. "seed": null,
    95. "dtype": "float32"
    96. }
    97. },
    98. "bias_initializer": {
    99. "class_name": "Zeros",
    100. "config": {
    101. "dtype": "float32"
    102. }
    103. },
    104. "kernel_regularizer": null,
    105. "bias_regularizer": null,
    106. "activity_regularizer": null,
    107. "kernel_constraint": null,
    108. "bias_constraint": null
    109. }
    110. }
    111. ]
    112. },
    113. "keras_version": "2.2.4-tf",
    114. "backend": "tensorflow"
    115. },
    116. "config_type": "keras"
    117. }
    118. },
    119. "role": {
    120. "host": {
    121. "0": {
    122. "reader_0": {
    123. "table": {
    124. "name": "muti_breast_homo_host",
    125. "namespace": "experiment"
    126. }
    127. }
    128. }
    129. },
    130. "guest": {
    131. "0": {
    132. "reader_0": {
    133. "table": {
    134. "name": "muti_breast_homo_guest",
    135. "namespace": "experiment"
    136. }
    137. }
    138. }
    139. }
    140. }
    141. }
    142. }

    注意 reader_0 组件的表名和命名空间需与上传数据时配置的一致。

    3.3. 提交任务

    执行以下命令:

    flow job submit -d homo_nn_dsl.json -c homo_nn_multi_label_conf.json
    

    执行成功后,查看 dashboard 显示:

    四、准备预测数据

    与前面训练的数据字段一样,但是内容不一样,y 值全为 0

    4.1. guest 端

    上传到 Fate 中,表名为 predict_muti_breast_homo_guest 命名空间为 experiment

    4.2. host 端

    上传到 Fate 中,表名为 predict_muti_breast_homo_host 命名空间为 experiment

    五、准备预测配置

    本文只描述关键部分,关于详细的预测步骤,请查看文章《隐私计算 FATE - 离线预测

    创建文件 homo_nn_multi_label_predict.json 内容如下 :

    1. {
    2. "dsl_version": 2,
    3. "initiator": {
    4. "role": "guest",
    5. "party_id": 9999
    6. },
    7. "role": {
    8. "arbiter": [
    9. 10000
    10. ],
    11. "host": [
    12. 10000
    13. ],
    14. "guest": [
    15. 9999
    16. ]
    17. },
    18. "job_parameters": {
    19. "common": {
    20. "model_id": "arbiter-10000#guest-9999#host-10000#model",
    21. "model_version": "202207061504081543620",
    22. "job_type": "predict"
    23. }
    24. },
    25. "component_parameters": {
    26. "role": {
    27. "guest": {
    28. "0": {
    29. "reader_0": {
    30. "table": {
    31. "name": "predict_muti_breast_homo_guest",
    32. "namespace": "experiment"
    33. }
    34. }
    35. }
    36. },
    37. "host": {
    38. "0": {
    39. "reader_0": {
    40. "table": {
    41. "name": "predict_muti_breast_homo_host",
    42. "namespace": "experiment"
    43. }
    44. }
    45. }
    46. }
    47. }
    48. }
    49. }

    注意以下两点:

    1. model_id 和 model_version 需修改为模型部署后的版本号。

    2. reader_0 组件的表名和命名空间需与上传数据时配置的一致。

    六、执行预测任务

    执行以下命令:

    flow job submit -c homo_nn_multi_label_predict.json
    

    执行成功后,查看 homo_nn_0 组件的数据输出:

    可以看到算法输出的预测结果。

  • 相关阅读:
    ChatGPT 桌面客户端正式发布
    cmd使用ssh连接Linux脚本
    小程序页面路由传参的方法?
    【T】03
    Julia绘图初步:Plots
    快速排序、归并排序、堆排序的C++实现_独家原创
    github双重身份验证2FAs扫码
    注释之背后:代码的解释者与保护者
    《SpringBoot 手册》:国际化组件 MessageSource
    Java Kafka实现消息的生产和消费
  • 原文地址:https://blog.csdn.net/Xiaohong0716/article/details/126247041