• Learning federated-learning programming based on cross_silo


    The framework is fedml.

    I am a beginner. If anyone who also studies federated learning sees this post, feel free to message me or add me as a friend so we can discuss and learn together.

    I will start learning from the client side.

    First comes client initialization (client_initializer); the dispatch it performs is summarized right after the listing.

    from fedml.constants import FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL, FEDML_CROSS_SILO_SCENARIO_HORIZONTAL

    from .fedml_client_master_manager import ClientMasterManager
    from .fedml_trainer_dist_adapter import TrainerDistAdapter


    def init_client(
        args,
        device,
        comm,
        client_rank,
        client_num,
        model,
        train_data_num,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        model_trainer=None,
    ):
        backend = args.backend
        trainer_dist_adapter = get_trainer_dist_adapter(
            args,
            device,
            client_rank,
            model,
            train_data_num,
            train_data_local_num_dict,
            train_data_local_dict,
            test_data_local_dict,
            model_trainer,
        )
        if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:  # hierarchical scenario
            if args.proc_rank_in_silo == 0:
                client_manager = get_client_manager_master(
                    args, trainer_dist_adapter, comm, client_rank, client_num, backend
                )
            else:
                client_manager = get_client_manager_salve(args, trainer_dist_adapter)
        elif args.scenario == FEDML_CROSS_SILO_SCENARIO_HORIZONTAL:  # horizontal scenario
            client_manager = get_client_manager_master(
                args, trainer_dist_adapter, comm, client_rank, client_num, backend
            )
        else:
            raise Exception(
                "we do not support {}. Please check whether this is a typo.".format(
                    args.scenario
                )
            )
        client_manager.run()  # the client manager is configured; start running


    def get_trainer_dist_adapter(
        args,
        device,
        client_rank,
        model,
        train_data_num,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        model_trainer,
    ):
        return TrainerDistAdapter(
            args,
            device,
            client_rank,
            model,
            train_data_num,
            train_data_local_num_dict,
            train_data_local_dict,
            test_data_local_dict,
            model_trainer,
        )


    def get_client_manager_master(
        args, trainer_dist_adapter, comm, client_rank, client_num, backend
    ):
        return ClientMasterManager(  # this class shows up again later in the post
            args, trainer_dist_adapter, comm, client_rank, client_num, backend
        )


    def get_client_manager_salve(args, trainer_dist_adapter):
        from .fedml_client_slave_manager import ClientSlaveManager

        return ClientSlaveManager(args, trainer_dist_adapter)
    # both manager classes are described later in this post
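
    To make the dispatch above easier to see, here is a small standalone sketch (my own illustration, not FedML code) of the decision init_client makes: in the horizontal scenario every client process gets a ClientMasterManager, while in the hierarchical scenario only the process with rank 0 inside a silo becomes the master and the remaining processes become slaves.

    # Illustration only: mirrors the branching in init_client() above.
    # The two string values are placeholders standing in for fedml.constants.
    FEDML_CROSS_SILO_SCENARIO_HORIZONTAL = "horizontal"
    FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL = "hierarchical"


    def pick_manager_role(scenario, proc_rank_in_silo):
        """Return the manager role init_client() would assign to this process."""
        if scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            # only the first process in the silo talks to the server
            return "master" if proc_rank_in_silo == 0 else "slave"
        if scenario == FEDML_CROSS_SILO_SCENARIO_HORIZONTAL:
            return "master"
        raise ValueError("unsupported scenario: {}".format(scenario))


    if __name__ == "__main__":
        print(pick_manager_role("horizontal", 0))    # master
        print(pick_manager_role("hierarchical", 0))  # master
        print(pick_manager_role("hierarchical", 2))  # slave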

    2. fedml_client_master_manager.py

    This manager is the one used for horizontal federated learning; the slave manager (fedml_client_slave_manager.py) covers the other processes of the hierarchical scenario. Since what I want to do is horizontal FL, this is the class I need, and it requires a backend that supports MPI-style multi-machine parallel communication. There is a parameter comm whose meaning I did not understand before; from this file it is the communicator object handed to the base ClientManager, while the number of global rounds comes separately from args.comm_round. Right after the listing I summarize the args fields this class actually reads.

    This is one of the files I consider most important; it is where I will make some changes.

    import json
    import logging
    import platform
    import time

    import torch.distributed as dist

    from fedml.constants import FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL

    from .message_define import MyMessage
    from .utils import convert_model_params_from_ddp, convert_model_params_to_ddp
    from ...core.distributed.client.client_manager import ClientManager
    from ...core.distributed.communication.message import Message
    from ...core.mlops.mlops_metrics import MLOpsMetrics
    from ...core.mlops.mlops_profiler_event import MLOpsProfilerEvent


    class ClientMasterManager(ClientManager):
        def __init__(
            self, args, trainer_dist_adapter, comm=None, rank=0, size=0, backend="MPI"
        ):
            super().__init__(args, comm, rank, size, backend)
            self.trainer_dist_adapter = trainer_dist_adapter
            self.args = args
            self.num_rounds = args.comm_round
            self.round_idx = 0
            self.rank = rank
            self.client_real_ids = json.loads(args.client_id_list)  # read the client ids
            logging.info("self.client_real_ids = {}".format(self.client_real_ids))
            # for the client, len(self.client_real_ids)==1: we only specify its client id in the list, not including others.
            self.client_real_id = self.client_real_ids[0]
            self.has_sent_online_msg = False  # checked in handle_message_connection_ready
            # use MLOps reporting only when it is enabled in the config
            if hasattr(self.args, "using_mlops") and self.args.using_mlops:
                self.mlops_metrics = MLOpsMetrics()
                self.mlops_metrics.set_messenger(self.com_manager_status, args)
                self.mlops_event = MLOpsProfilerEvent(self.args)

        # register the handlers for incoming messages
        def register_message_receive_handlers(self):
            self.register_message_receive_handler(
                MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready
            )
            self.register_message_receive_handler(
                MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.handle_message_check_status
            )
            self.register_message_receive_handler(
                MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init
            )
            self.register_message_receive_handler(
                MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT,
                self.handle_message_receive_model_from_server,
            )
            self.register_message_receive_handler(
                MyMessage.MSG_TYPE_S2C_FINISH, self.handle_message_finish,
            )

        def handle_message_connection_ready(self, msg_params):
            logging.info("Connection is ready!")
            if not self.has_sent_online_msg:
                self.has_sent_online_msg = True
                self.send_client_status(0)
                if hasattr(self.args, "using_mlops") and self.args.using_mlops:
                    # Notify MLOps with training status.
                    self.report_training_status(
                        MyMessage.MSG_MLOPS_CLIENT_STATUS_INITIALIZING
                    )
                    # Open a new process to report system performance to the MQTT server
                    MLOpsMetrics.report_sys_perf(self.args)

        def handle_message_check_status(self, msg_params):
            self.send_client_status(0)

        def handle_message_init(self, msg_params):
            global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
            data_silo_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX)
            logging.info("data_silo_index = %s" % str(data_silo_index))
            self.report_training_status(MyMessage.MSG_MLOPS_CLIENT_STATUS_TRAINING)
            if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
                global_model_params = convert_model_params_to_ddp(global_model_params)
                self.sync_process_group(0, global_model_params, data_silo_index)
            self.trainer_dist_adapter.update_model(global_model_params)
            self.trainer_dist_adapter.update_dataset(int(data_silo_index))
            self.round_idx = 0
            self.__train()

        def handle_message_receive_model_from_server(self, msg_params):
            logging.info("handle_message_receive_model_from_server.")
            model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
            client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX)
            if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
                model_params = convert_model_params_to_ddp(model_params)
                self.sync_process_group(self.round_idx, model_params, client_index)
            self.trainer_dist_adapter.update_model(model_params)
            self.trainer_dist_adapter.update_dataset(int(client_index))
            if self.round_idx == self.num_rounds - 1:
                # this part may need to be changed
                # Notify MLOps with the finished message
                if hasattr(self.args, "using_mlops") and self.args.using_mlops:
                    self.mlops_metrics.report_client_id_status(
                        self.args.run_id,
                        self.client_real_id,
                        MyMessage.MSG_MLOPS_CLIENT_STATUS_FINISHED,
                    )
                return
            self.round_idx += 1
            self.__train()

        def handle_message_finish(self, msg_params):
            logging.info(" ====================cleanup ====================")
            self.cleanup()

        def cleanup(self):
            if hasattr(self.args, "using_mlops") and self.args.using_mlops:
                # mlops_metrics = MLOpsMetrics()
                # mlops_metrics.set_sys_reporting_status(False)
                pass
            self.finish()

        def send_model_to_server(self, receive_id, weights, local_sample_num):
            tick = time.time()
            if hasattr(self.args, "using_mlops") and self.args.using_mlops:
                self.mlops_event.log_event_started(
                    "comm_c2s", event_value=str(self.round_idx)
                )
            message = Message(
                MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER,
                self.client_real_id,
                receive_id,
            )
            message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights)
            message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num)
            self.send_message(message)
            MLOpsProfilerEvent.log_to_wandb(
                {"Communication/Send_Total": time.time() - tick}
            )
            # Report client model to MLOps
            if hasattr(self.args, "using_mlops") and self.args.using_mlops:
                model_url = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL)
                model_info = {
                    "run_id": self.args.run_id,
                    "edge_id": self.client_real_id,
                    "round_idx": self.round_idx + 1,
                    "client_model_s3_address": model_url,
                }
                self.mlops_metrics.report_client_model_info(model_info)

        def send_client_status(self, receive_id, status="ONLINE"):
            logging.info("send_client_status")
            message = Message(
                MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, self.client_real_id, receive_id
            )
            sys_name = platform.system()
            if sys_name == "Darwin":
                sys_name = "Mac"
            # Debug for simulation mobile system
            # sys_name = MyMessage.MSG_CLIENT_OS_ANDROID
            message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_STATUS, status)
            message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, sys_name)
            self.send_message(message)

        def report_training_status(self, status):
            if hasattr(self.args, "using_mlops") and self.args.using_mlops:
                self.mlops_metrics.set_messenger(self.com_manager_status, self.args)
                self.mlops_metrics.report_client_training_status(
                    self.client_real_id, status
                )

        def sync_process_group(
            self, round_idx, model_params=None, client_index=None, src=0
        ):
            logging.info("sending round number to pg")
            round_number = [round_idx, model_params, client_index]
            dist.broadcast_object_list(
                round_number,
                src=src,
                group=self.trainer_dist_adapter.process_group_manager.get_process_group(),
            )
            logging.info("round number %d broadcast to process group" % round_number[0])

        def __train(self):
            logging.info("#######training########### round_id = %d" % self.round_idx)
            if hasattr(self.args, "using_mlops") and self.args.using_mlops:
                self.mlops_event.log_event_started("train", event_value=str(self.round_idx))
            weights, local_sample_num = self.trainer_dist_adapter.train(self.round_idx)
            if hasattr(self.args, "using_mlops") and self.args.using_mlops:
                self.mlops_event.log_event_ended("train", event_value=str(self.round_idx))
            # the current model is still DDP-wrapped under cross-silo-hi setting
            if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
                weights = convert_model_params_from_ddp(weights)
            self.send_model_to_server(0, weights, local_sample_num)

        def run(self):
            super().run()
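
    Reading the handlers together, one round on the master side looks like this: the server pushes a model (MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT), the handler loads it into the trainer adapter, __train() runs local training, and send_model_to_server() uploads the weights together with the local sample count; the loop ends once round_idx reaches comm_round - 1. As a reading aid, here is a minimal mock (my own illustration, not FedML code) of the args fields this class actually touches, with hypothetical values:

    # Illustration only: a minimal args object carrying the fields read by
    # ClientMasterManager above. All values here are hypothetical.
    from types import SimpleNamespace

    args = SimpleNamespace(
        comm_round=50,           # becomes self.num_rounds
        client_id_list="[1]",    # JSON string parsed into self.client_real_ids
        scenario="horizontal",   # selects the horizontal vs hierarchical branches
        using_mlops=False,       # gates all of the MLOps reporting calls
        run_id="0",              # only read when using_mlops is True
    )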

    client_launch.py

    This one probably does not need to be changed; after the listing I show the torchrun command it ends up building in the hierarchical case.

    import os
    import subprocess

    import torch

    from fedml.arguments import load_arguments
    from fedml.constants import (
        FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL,
        FEDML_TRAINING_PLATFORM_CROSS_SILO,
        FEDML_CROSS_SILO_SCENARIO_HORIZONTAL,
    )
    from fedml.device import get_device_type


    class CrossSiloLauncher:
        @staticmethod
        def launch_dist_trainer(torch_client_filename, inputs):
            args = load_arguments(FEDML_TRAINING_PLATFORM_CROSS_SILO)
            if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
                CrossSiloLauncher._run_cross_silo_hierarchical(
                    args, torch_client_filename, inputs
                )
            elif args.scenario == FEDML_CROSS_SILO_SCENARIO_HORIZONTAL:
                CrossSiloLauncher._run_cross_silo_horizontal(
                    args, torch_client_filename, inputs
                )
            else:
                raise Exception(
                    "we do not support {}, check whether this is a typo in args.scenario".format(
                        args.scenario
                    )
                )

        @staticmethod
        def _run_cross_silo_horizontal(args, torch_client_filename, inputs):
            python_path = subprocess.run(
                ["which", "python"], capture_output=True, text=True
            ).stdout.strip()
            process_arguments = [python_path, torch_client_filename] + inputs
            subprocess.run(process_arguments)

        @staticmethod
        def _run_cross_silo_hierarchical(args, torch_client_filename, inputs):
            def get_torchrun_arguments(node_rank):
                torchrun_path = subprocess.run(
                    ["which", "torchrun"], capture_output=True, text=True
                ).stdout.strip()
                return [
                    torchrun_path,
                    f"--nnodes={args.n_node_in_silo}",
                    f"--nproc_per_node={args.n_proc_per_node}",
                    # "--rdzv_backend=c10d",
                    f"--rdzv_endpoint={args.master_address}:{args.launcher_rdzv_port}",
                    f"--node_rank={node_rank}",
                    "--rdzv_id=hi_fl",
                    torch_client_filename,
                ] + inputs

            network_interface = (
                None if not hasattr(args, "network_interface") else args.network_interface
            )
            print(
                f"Using network interface {network_interface} for process group and TRPC communication"
            )
            env_variables = {
                "OMP_NUM_THREADS": "4",
            }
            if network_interface:
                env_variables = {
                    **env_variables,
                    "NCCL_SOCKET_IFNAME": network_interface,
                    "GLOO_SOCKET_IFNAME": network_interface,
                }
            if args.n_node_in_silo == 1:
                args.node_rank = 0
                args.manual_launch = True
                if not (hasattr(args, "n_proc_per_node") and args.n_proc_per_node):
                    print("Number of processes per node not specified.")
                    device_type = get_device_type(args)
                    if torch.cuda.is_available() and device_type == "gpu":
                        gpu_count = torch.cuda.device_count()
                        print(f"Using number of GPUs ({gpu_count}) as number of processes.")
                        args.n_proc_per_node = gpu_count
                    else:
                        print("Using number 1 as number of processes.")
                        args.n_proc_per_node = 1
            if hasattr(args, "manual_launch") and args.manual_launch:
                print("Manual Client Launcher")
                node_rank = args.node_rank
                torchrun_cmd_arguments = get_torchrun_arguments(node_rank)
                process_args = torchrun_cmd_arguments
                print(f"Launching node {node_rank} of silo {args.rank}")
                subprocess.run(process_args, env=dict(os.environ, **env_variables))
            else:
                print("Automatic Client Launcher")
                which_pdsh = subprocess.run(
                    ["which", "pdsh"], capture_output=True, text=True
                ).stdout.strip()
                if not which_pdsh:
                    raise Exception(
                        f"Silo {args.rank} has {args.n_node_in_silo} nodes. Automatic Client Launcher for more than 1 nodes requires pdsh."
                    )
                print("Launching nodes using pdsh")
                os.environ["PDSH_RCMD_TYPE"] = "ssh"
                node_addresses = ",".join(args.node_addresses)
                pdsh_cmd_arguments = ["pdsh", "-w", node_addresses]
                exports = ""
                for key, val in env_variables.items():
                    exports += "export {}={}; ".format(key, val)
                prerun_args = [
                    exports,
                    f"cd {os.path.abspath('.')};",
                ]
                node_rank = "%n"
                torchrun_cmd_arguments = get_torchrun_arguments(node_rank)
                process_args = pdsh_cmd_arguments + prerun_args + torchrun_cmd_arguments
                subprocess.run(process_args)
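
    To understand the hierarchical branch, it helped me to see the concrete command that get_torchrun_arguments builds. The snippet below is only an illustration with made-up values (the node count, addresses, port, script name and CLI inputs are all hypothetical), reproducing the argument list from the listing:

    # Illustration only: rebuild the torchrun argument list from the listing
    # above with hypothetical silo settings, then print the resulting command.
    nnodes = 2                           # args.n_node_in_silo (made up)
    nproc_per_node = 4                   # args.n_proc_per_node (made up)
    master_address = "192.168.1.10"      # args.master_address (made up)
    launcher_rdzv_port = 29410           # args.launcher_rdzv_port (made up)
    node_rank = 0
    torch_client_filename = "torch_client.py"               # hypothetical script
    inputs = ["--cf", "fedml_config.yaml", "--rank", "1"]   # hypothetical CLI inputs

    cmd = [
        "torchrun",
        f"--nnodes={nnodes}",
        f"--nproc_per_node={nproc_per_node}",
        f"--rdzv_endpoint={master_address}:{launcher_rdzv_port}",
        f"--node_rank={node_rank}",
        "--rdzv_id=hi_fl",
        torch_client_filename,
    ] + inputs
    print(" ".join(cmd))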

    fedml_trainer.py

    There is a train function here that may need modification; a sketch of a custom trainer that plugs into it follows the listing.

    import time

    from ...constants import FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL
    from ...core.mlops.mlops_profiler_event import MLOpsProfilerEvent
    from fedml.data import split_data_for_dist_trainers


    class FedMLTrainer(object):
        def __init__(
            self,
            client_index,
            train_data_local_dict,
            train_data_local_num_dict,
            test_data_local_dict,
            train_data_num,
            device,
            args,
            model_trainer,
        ):
            self.trainer = model_trainer
            self.client_index = client_index
            if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
                self.train_data_local_dict = split_data_for_dist_trainers(
                    train_data_local_dict, args.n_proc_in_silo
                )
            else:
                self.train_data_local_dict = train_data_local_dict
            self.train_data_local_num_dict = train_data_local_num_dict
            self.test_data_local_dict = test_data_local_dict
            self.all_train_data_num = train_data_num
            self.train_local = None
            self.local_sample_number = None
            self.test_local = None
            self.device = device
            self.args = args

        def update_model(self, weights):
            self.trainer.set_model_params(weights)

        def update_dataset(self, client_index):
            self.client_index = client_index
            if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
                self.train_local = self.train_data_local_dict[client_index][
                    self.args.proc_rank_in_silo
                ]
            else:
                self.train_local = self.train_data_local_dict[client_index]
            self.local_sample_number = self.train_data_local_num_dict[client_index]
            self.test_local = self.test_data_local_dict[client_index]

        def train(self, round_idx=None):
            self.args.round_idx = round_idx
            tick = time.time()
            self.trainer.train(self.train_local, self.device, self.args)
            MLOpsProfilerEvent.log_to_wandb(
                {"Train/Time": time.time() - tick, "round": round_idx}
            )
            weights = self.trainer.get_model_params()
            return weights, self.local_sample_number

        def test(self):
            train_metrics = self.trainer.test(self.train_local, self.device, self.args)
            train_tot_correct, train_num_sample, train_loss = (
                train_metrics["test_correct"],
                train_metrics["test_total"],
                train_metrics["test_loss"],
            )
            test_metrics = self.trainer.test(self.test_local, self.device, self.args)
            test_tot_correct, test_num_sample, test_loss = (
                test_metrics["test_correct"],
                test_metrics["test_total"],
                test_metrics["test_loss"],
            )
            return (
                train_tot_correct,
                train_loss,
                train_num_sample,
                test_tot_correct,
                test_loss,
                test_num_sample,
            )
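
    This is where my own training logic would plug in. FedMLTrainer only ever calls set_model_params, get_model_params, train and test on self.trainer (and the adapter calls set_id), so a custom model_trainer just has to provide those methods. Below is a minimal sketch of such a trainer; it is my own illustration with hypothetical loss, optimizer and hyper-parameter defaults, and in practice one would likely subclass FedML's ClientTrainer base class rather than write a bare class:

    # Sketch of a custom model_trainer matching the calls made by FedMLTrainer
    # above. The loss, optimizer and hyper-parameter defaults are hypothetical.
    import torch
    import torch.nn as nn


    class MyModelTrainer:
        def __init__(self, model, args=None):
            self.model = model
            self.id = 0

        def set_id(self, trainer_id):
            self.id = trainer_id

        def get_model_params(self):
            # return weights on CPU so they can be serialized and sent to the server
            return self.model.cpu().state_dict()

        def set_model_params(self, model_parameters):
            self.model.load_state_dict(model_parameters)

        def train(self, train_data, device, args):
            model = self.model.to(device)
            model.train()
            criterion = nn.CrossEntropyLoss().to(device)
            optimizer = torch.optim.SGD(
                model.parameters(), lr=getattr(args, "learning_rate", 0.01)
            )
            for _ in range(getattr(args, "epochs", 1)):
                for x, labels in train_data:
                    x, labels = x.to(device), labels.to(device)
                    optimizer.zero_grad()
                    loss = criterion(model(x), labels)
                    loss.backward()
                    optimizer.step()

        def test(self, test_data, device, args):
            # must return the keys that FedMLTrainer.test() reads above
            model = self.model.to(device)
            model.eval()
            criterion = nn.CrossEntropyLoss().to(device)
            metrics = {"test_correct": 0, "test_loss": 0.0, "test_total": 0}
            with torch.no_grad():
                for x, labels in test_data:
                    x, labels = x.to(device), labels.to(device)
                    pred = model(x)
                    loss = criterion(pred, labels)
                    metrics["test_correct"] += (pred.argmax(dim=1) == labels).sum().item()
                    metrics["test_loss"] += loss.item() * labels.size(0)
                    metrics["test_total"] += labels.size(0)
            return metrics

    An instance like this would then be handed to init_client as the model_trainer argument (see the client initializer above), so create_model_trainer is skipped and this trainer is used instead.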

    fedml_trainer_dist_adapter.py

    import logging

    from fedml.constants import FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL

    from .fedml_trainer import FedMLTrainer
    from .trainer.trainer_creator import create_model_trainer


    class TrainerDistAdapter:
        def __init__(
            self,
            args,
            device,
            client_rank,
            model,
            train_data_num,
            train_data_local_num_dict,
            train_data_local_dict,
            test_data_local_dict,
            model_trainer,
        ):
            model.to(device)
            if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
                from torch.nn.parallel import DistributedDataParallel as DDP
                from .process_group_manager import ProcessGroupManager

                only_gpu = args.using_gpu
                self.process_group_manager = ProcessGroupManager(
                    args.proc_rank_in_silo,
                    args.n_proc_in_silo,
                    args.pg_master_address,
                    args.pg_master_port,
                    only_gpu,
                )
                model = DDP(model, device_ids=[device] if only_gpu else None)
            if model_trainer is None:
                model_trainer = create_model_trainer(args, model)
            else:
                model_trainer.model = model
            client_index = client_rank - 1  # convert the 1-based client rank to a 0-based index
            model_trainer.set_id(client_index)
            logging.info("Initiating Trainer")
            trainer = self.get_trainer(
                client_index,
                train_data_local_dict,
                train_data_local_num_dict,
                test_data_local_dict,
                train_data_num,
                device,
                args,
                model_trainer,
            )
            self.client_index = client_index
            self.client_rank = client_rank
            self.device = device
            self.trainer = trainer
            self.args = args

        def get_trainer(
            self,
            client_index,
            train_data_local_dict,
            train_data_local_num_dict,
            test_data_local_dict,
            train_data_num,
            device,
            args,
            model_trainer,
        ):
            return FedMLTrainer(
                client_index,
                train_data_local_dict,
                train_data_local_num_dict,
                test_data_local_dict,
                train_data_num,
                device,
                args,
                model_trainer,
            )

        def train(self, round_idx):
            weights, local_sample_num = self.trainer.train(round_idx)
            return weights, local_sample_num

        def update_model(self, model_params):
            self.trainer.update_model(model_params)

        def update_dataset(self, client_index=None):
            _client_index = client_index or self.client_index
            self.trainer.update_dataset(int(_client_index))

        def cleanup_pg(self):
            if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
                logging.info(
                    "Cleaning up process group for client %s in silo %s"
                    % (self.args.proc_rank_in_silo, self.args.rank_in_node)
                )
                self.process_group_manager.cleanup()
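
    One detail worth noting: in the hierarchical scenario the model gets wrapped in DistributedDataParallel, which prefixes every key of state_dict() with "module.". As far as I can tell, that is why the master manager calls convert_model_params_from_ddp before uploading and convert_model_params_to_ddp after downloading. The sketch below is only my assumption of what those helpers amount to, namely stripping or adding that prefix:

    # Sketch only: my assumption of what the convert_model_params_* helpers do,
    # i.e. translating between plain and DDP-style ("module."-prefixed) keys.
    from collections import OrderedDict


    def convert_model_params_from_ddp_sketch(model_params):
        # "module.layer.weight" -> "layer.weight"
        return OrderedDict(
            (key.replace("module.", "", 1), value) for key, value in model_params.items()
        )


    def convert_model_params_to_ddp_sketch(model_params):
        # "layer.weight" -> "module.layer.weight"
        return OrderedDict(
            ("module." + key, value) for key, value in model_params.items()
        )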

    I do not think the other files need to be changed.

  • Original post: https://blog.csdn.net/kling_bling/article/details/126729854