框架是fedml
我是一名初学者,若是有也研究联邦学习的朋友看见这篇博文,欢迎私信或者加我好友,一起讨论一起学习。

我从client开始学习
- from fedml.constants import FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL, FEDML_CROSS_SILO_SCENARIO_HORIZONTAL
- from .fedml_client_master_manager import ClientMasterManager
- from .fedml_trainer_dist_adapter import TrainerDistAdapter
-
-
def init_client(
    args,
    device,
    comm,
    client_rank,
    client_num,
    model,
    train_data_num,
    train_data_local_num_dict,
    train_data_local_dict,
    test_data_local_dict,
    model_trainer=None,
):
    """Initialize and run the cross-silo FL client for this process.

    Args:
        args: parsed runtime arguments; must provide ``backend`` and ``scenario``.
        device: torch device the local model runs on.
        comm: communication object of the chosen backend (e.g. the MPI
            communicator) -- NOT the number of rounds (that is ``args.comm_round``).
        client_rank: rank of this client in the federation.
        client_num: total number of clients.
        model: model to be trained locally.
        train_data_num: total number of training samples.
        train_data_local_num_dict: per-client training sample counts.
        train_data_local_dict: per-client training data.
        test_data_local_dict: per-client test data.
        model_trainer: optional custom trainer; created from ``args`` when None.

    Raises:
        Exception: if ``args.scenario`` is not a supported cross-silo scenario.
    """
    backend = args.backend
    trainer_dist_adapter = get_trainer_dist_adapter(
        args,
        device,
        client_rank,
        model,
        train_data_num,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        model_trainer,
    )
    if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
        # Hierarchical cross-silo: a silo may run several processes; only the
        # silo-local rank-0 process talks to the server, the rest are slaves.
        if args.proc_rank_in_silo == 0:
            client_manager = get_client_manager_master(
                args, trainer_dist_adapter, comm, client_rank, client_num, backend
            )
        else:
            client_manager = get_client_manager_salve(args, trainer_dist_adapter)
    elif args.scenario == FEDML_CROSS_SILO_SCENARIO_HORIZONTAL:
        # Horizontal cross-silo: one process per silo, always a master.
        client_manager = get_client_manager_master(
            args, trainer_dist_adapter, comm, client_rank, client_num, backend
        )
    else:
        raise Exception(
            "we do not support {}. Please check whether this is typo.".format(
                args.scenario
            )
        )
    client_manager.run()  # manager fully configured; start the message loop
def get_trainer_dist_adapter(
    args,
    device,
    client_rank,
    model,
    train_data_num,
    train_data_local_num_dict,
    train_data_local_dict,
    test_data_local_dict,
    model_trainer,
):
    """Build the adapter that bridges the trainer to the (distributed) silo.

    Fix: constructs ``TrainerDistAdapter`` -- the class actually imported at
    the top of the file; the original referenced the undefined name
    ``TrainDistAdapter`` and raised NameError when called.
    """
    return TrainerDistAdapter(
        args,
        device,
        client_rank,
        model,
        train_data_num,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        model_trainer,
    )
def get_client_manager_master(
    args, trainer_dist_adapter, comm, client_rank, client_num, backend
):
    """Build the master-side manager that exchanges models with the server."""
    manager = ClientMasterManager(
        args, trainer_dist_adapter, comm, client_rank, client_num, backend
    )
    return manager
def get_client_manager_salve(args, trainer_dist_adapter):
    """Build a slave-side manager for non-master processes within a silo.

    NOTE(review): the public name keeps the original misspelling ("salve")
    so existing callers keep working; prefer ``get_client_manager_slave``.
    """
    # Imported lazily: only hierarchical silos ever need the slave manager.
    from .fedml_client_slave_manager import ClientSlaveManager

    return ClientSlaveManager(args, trainer_dist_adapter)


# Correctly spelled, backward-compatible alias for the function above.
get_client_manager_slave = get_client_manager_salve
FEDML_CROSS_SILO_SCENARIO_HORIZONTAL 分支用于横向(每个 silo 单进程)的 cross-silo 训练;FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL 是分层场景(每个 silo 内有多个进程),并不是纵向联邦学习。因为我要做的是横向,所以我需要 HORIZONTAL 这个分支,要求后台可以支持 mpi 多机并行运算。之前一直不知道参数 comm 是什么意思,看了这个文件可以确定:comm 是通信后端的通信器对象(例如 MPI 的 communicator),并不是全局迭代轮数;全局迭代轮数是 args.comm_round(见后面 ClientMasterManager 里的 self.num_rounds = args.comm_round)。
这个是我认为比较重要的文件之一,在这个上面做一些改动
- import json
- import logging
- import platform
- import time
- import torch.distributed as dist
-
- from fedml.constants import FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL
- from .message_define import MyMessage
- from .utils import convert_model_params_from_ddp, convert_model_params_to_ddp
- from ...core.distributed.client.client_manager import ClientManager
- from ...core.distributed.communication.message import Message
- from ...core.mlops.mlops_metrics import MLOpsMetrics
- from ...core.mlops.mlops_profiler_event import MLOpsProfilerEvent
-
class ClientMasterManager(ClientManager):
    """Master-side client manager: talks to the FL server for one silo.

    Receives the global model, triggers local training through
    ``trainer_dist_adapter`` and sends the updated weights back, repeating
    for ``args.comm_round`` rounds.

    Fixes vs. the original: ``handle_message_init`` was misspelled
    ``handle_messsage_init`` (so registration raised AttributeError),
    ``send_model_to_server`` was misspelled ``send_model_to_sever`` (so
    ``__train`` raised AttributeError), and ``has_sent_online_msg`` was
    read before ever being initialized.
    """

    def __init__(
        self, args, trainer_dist_adapter, comm=None, rank=0, size=0, backend="MPI"
    ):
        super().__init__(args, comm, rank, size, backend)
        self.trainer_dist_adapter = trainer_dist_adapter
        self.args = args
        self.num_rounds = args.comm_round  # total number of global FL rounds
        self.round_idx = 0
        self.rank = rank
        # Read this client's real id(s) from the configured id list.
        self.client_real_ids = json.loads(args.client_id_list)
        logging.info("self.client_real_ids = {}".format(self.client_real_ids))
        # for the client, len(self.client_real_ids)==1: we only specify its client id in the list, not including others.
        self.client_real_id = self.client_real_ids[0]
        # Fix: flag must exist before handle_message_connection_ready reads it.
        self.has_sent_online_msg = False
        # Enable MLOps reporting only when configured.
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            self.mlops_metrics = MLOpsMetrics()
            self.mlops_metrics.set_messenger(self.com_manager_status, args)
            self.mlops_event = MLOpsProfilerEvent(self.args)

    def register_message_receive_handlers(self):
        """Register the callback for every server-to-client message type."""
        self.register_message_receive_handler(
            MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready
        )
        self.register_message_receive_handler(
            MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.handle_message_check_status
        )
        self.register_message_receive_handler(
            MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init
        )
        self.register_message_receive_handler(
            MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT,
            self.handle_message_receive_model_from_server,
        )
        self.register_message_receive_handler(
            MyMessage.MSG_TYPE_S2C_FINISH, self.handle_message_finish,
        )

    def handle_message_connection_ready(self, msg_params):
        """Announce this client as ONLINE once (idempotent via the flag)."""
        logging.info("Connection is ready!")
        if not self.has_sent_online_msg:
            self.has_sent_online_msg = True
            self.send_client_status(0)

            if hasattr(self.args, "using_mlops") and self.args.using_mlops:
                # Notify MLOps with training status.
                self.report_training_status(
                    MyMessage.MSG_MLOPS_CLIENT_STATUS_INITIALIZING
                )
                # Open new process for report system performances to MQTT server
                MLOpsMetrics.report_sys_perf(self.args)

    def handle_message_check_status(self, msg_params):
        """Reply to a server status probe."""
        self.send_client_status(0)

    def handle_message_init(self, msg_params):
        """Receive the initial global model / data index and start round 0.

        Fix: renamed from ``handle_messsage_init`` so the name matches the
        handler registered for MSG_TYPE_S2C_INIT_CONFIG.
        """
        global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
        data_silo_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX)
        logging.info("data_silo_index = %s" % str(data_silo_index))
        self.report_training_status(MyMessage.MSG_MLOPS_CLIENT_STATUS_TRAINING)

        if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            # The local model is DDP-wrapped under cross-silo-hi, so convert
            # the incoming state dict and broadcast it to the process group.
            global_model_params = convert_model_params_to_ddp(global_model_params)
            self.sync_process_group(0, global_model_params, data_silo_index)

        self.trainer_dist_adapter.update_model(global_model_params)
        self.trainer_dist_adapter.update_dataset(int(data_silo_index))
        self.round_idx = 0
        self.__train()

    def handle_message_receive_model_from_server(self, msg_params):
        """Receive the aggregated global model and run the next round."""
        logging.info("handle_message_receive_model_from_server.")
        model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
        client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX)
        if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            model_params = convert_model_params_to_ddp(model_params)
            self.sync_process_group(self.round_idx, model_params, client_index)

        self.trainer_dist_adapter.update_model(model_params)
        self.trainer_dist_adapter.update_dataset(int(client_index))
        if self.round_idx == self.num_rounds - 1:
            # Final round already trained: report completion and stop.
            if hasattr(self.args, "using_mlops") and self.args.using_mlops:
                self.mlops_metrics.report_client_id_status(
                    self.args.run_id,
                    self.client_real_id,
                    MyMessage.MSG_MLOPS_CLIENT_STATUS_FINISHED,
                )
            return
        self.round_idx += 1
        self.__train()

    def handle_message_finish(self, msg_params):
        """Server says we're done: tear everything down."""
        logging.info(" ====================cleanup ====================")
        self.cleanup()

    def cleanup(self):
        """Stop optional MLOps reporting and finish the communication loop."""
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            # mlops_metrics = MLOpsMetrics()
            # mlops_metrics.set_sys_reporting_status(False)
            pass
        self.finish()

    def send_model_to_server(self, receive_id, weights, local_sample_num):
        """Send the locally trained weights and sample count to the server.

        Fix: renamed from ``send_model_to_sever`` so the call in
        ``__train`` resolves.
        """
        tick = time.time()
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            self.mlops_event.log_event_started(
                "comm_c2s", event_value=str(self.round_idx)
            )
        message = Message(
            MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER,
            self.client_real_id,
            receive_id,
        )
        message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights)
        message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num)
        self.send_message(message)
        MLOpsProfilerEvent.log_to_wandb(
            {"Communication/Send_Total": time.time() - tick}
        )
        # Report client model to MLOps
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            model_url = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL)
            model_info = {
                "run_id": self.args.run_id,
                "edge_id": self.client_real_id,
                "round_idx": self.round_idx + 1,
                "client_model_s3_address": model_url,
            }
            self.mlops_metrics.report_client_model_info(model_info)

    def send_client_status(self, receive_id, status="ONLINE"):
        """Report this client's status and OS to the server."""
        logging.info("send_client_status")
        message = Message(
            MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, self.client_real_id, receive_id
        )
        sys_name = platform.system()
        if sys_name == "Darwin":
            sys_name = "Mac"
        # Debug for simulation mobile system
        # sys_name = MyMessage.MSG_CLIENT_OS_ANDROID

        message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_STATUS, status)
        message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, sys_name)
        self.send_message(message)

    def report_training_status(self, status):
        """Forward the training status to MLOps (no-op unless enabled)."""
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            self.mlops_metrics.set_messenger(self.com_manager_status, self.args)
            self.mlops_metrics.report_client_training_status(
                self.client_real_id, status
            )

    def sync_process_group(
        self, round_idx, model_params=None, client_index=None, src=0
    ):
        """Broadcast round number, model and index to the silo process group."""
        logging.info("sending round number to pg")
        round_number = [round_idx, model_params, client_index]
        dist.broadcast_object_list(
            round_number,
            src=src,
            group=self.trainer_dist_adapter.process_group_manager.get_process_group(),
        )
        logging.info("round number %d broadcast to process group" % round_number[0])

    def __train(self):
        """Run one local training round and push the result to the server."""
        logging.info("#######training########### round_id = %d" % self.round_idx)
        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            self.mlops_event.log_event_started("train", event_value=str(self.round_idx))

        weights, local_sample_num = self.trainer_dist_adapter.train(self.round_idx)

        if hasattr(self.args, "using_mlops") and self.args.using_mlops:
            self.mlops_event.log_event_ended("train", event_value=str(self.round_idx))

        # the current model is still DDP-wrapped under cross-silo-hi setting
        if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            weights = convert_model_params_from_ddp(weights)

        self.send_model_to_server(0, weights, local_sample_num)

    def run(self):
        super().run()
这个感觉基本不用动吧
- import os
- import subprocess
- import torch
- from fedml.arguments import load_arguments
- from fedml.constants import (
- FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL,
- FEDML_TRAINING_PLATFORM_CROSS_SILO,
- FEDML_CROSS_SILO_SCENARIO_HORIZONTAL,
- )
- from fedml.device import get_device_type
-
class CrossSiloLauncher:
    """Launch helpers for cross-silo client processes."""

    @staticmethod
    def launch_dist_trainer(torch_client_filename, inputs):
        """Dispatch to the launcher matching ``args.scenario``.

        Fix: the horizontal branch now reads ``args.scenario`` -- the
        original read the nonexistent ``args.scenarios`` and raised
        AttributeError for horizontal runs. ``@staticmethod`` added since
        the method takes no ``self``/``cls``.
        """
        args = load_arguments(FEDML_TRAINING_PLATFORM_CROSS_SILO)
        if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            CrossSiloLauncher._run_cross_silo_hierarchical(
                args, torch_client_filename, inputs
            )
        elif args.scenario == FEDML_CROSS_SILO_SCENARIO_HORIZONTAL:
            CrossSiloLauncher._run_cross_silo_horizontal(
                args, torch_client_filename, inputs
            )
        else:
            raise Exception(
                "we do not support {}, check whether this is typo in args.scenario".format(
                    args.scenario
                )
            )

    @staticmethod
    def _run_cross_silo_horizontal(args, torch_client_filename, inputs):
        """Run the client script directly with the current interpreter."""
        python_path = subprocess.run(
            ["which", "python"], capture_output=True, text=True
        ).stdout.strip()
        process_arguments = [python_path, torch_client_filename] + inputs
        subprocess.run(process_arguments)

    @staticmethod
    def _run_cross_silo_hierarchical(args, torch_client_filename, inputs):
        """Launch torchrun for every node of this silo (locally or via pdsh)."""

        def get_torchrun_arguments(node_rank):
            # Build the torchrun command line for one node of the silo.
            torchrun_path = subprocess.run(
                ["which", "torchrun"], capture_output=True, text=True
            ).stdout.strip()

            return [
                torchrun_path,
                f"--nnodes={args.n_node_in_silo}",
                f"--nproc_per_node={args.n_proc_per_node}",
                # "--rdzv_backend=c10d",
                f"--rdzv_endpoint={args.master_address}:{args.launcher_rdzv_port}",
                f"--node_rank={node_rank}",
                "--rdzv_id=hi_fl",
                torch_client_filename,
            ] + inputs

        network_interface = (
            None if not hasattr(args, "network_interface") else args.network_interface
        )
        print(
            f"Using network interface {network_interface} for process group and TRPC communication"
        )
        env_variables = {
            "OMP_NUM_THREADS": "4",
        }
        if network_interface:
            # Pin NCCL/GLOO traffic to the requested interface.
            env_variables = {
                **env_variables,
                "NCCL_SOCKET_IFNAME": network_interface,
                "GLOO_SOCKET_IFNAME": network_interface,
            }

        if args.n_node_in_silo == 1:
            # Single-node silo: launch locally, no pdsh needed.
            args.node_rank = 0
            args.manual_launch = True
            if not (hasattr(args, "n_proc_per_node") and args.n_proc_per_node):
                print("Number of processes per node not specified.")
                device_type = get_device_type(args)
                if torch.cuda.is_available() and device_type == "gpu":
                    # Default to one process per visible GPU.
                    gpu_count = torch.cuda.device_count()
                    print(f"Using number of GPUs ({gpu_count}) as number of processeses.")
                    args.n_proc_per_node = gpu_count
                else:
                    print(f"Using number 1 as number of processeses.")
                    args.n_proc_per_node = 1

        if hasattr(args, "manual_launch") and args.manual_launch:
            print(f"Manual Client Launcher")
            node_rank = args.node_rank
            torchrun_cmd_arguments = get_torchrun_arguments(node_rank)
            process_args = torchrun_cmd_arguments
            print(f"Launching node {node_rank} of silo {args.rank}")
            subprocess.run(process_args, env=dict(os.environ, **env_variables))

        else:
            print(f"Automatic Client Launcher")

            which_pdsh = subprocess.run(
                ["which", "pdsh"], capture_output=True, text=True
            ).stdout.strip()

            if not which_pdsh:
                raise Exception(
                    f"Silo {args.rank} has {args.n_node_in_silo} nodes. Automatic Client Launcher for more than 1 nodes requires PSDH."
                )

            print(f"Launching nodes using pdsh")

            os.environ["PDSH_RCMD_TYPE"] = "ssh"
            node_addresses = ",".join(args.node_addresses)
            pdsh_cmd_aruments = ["pdsh", "-w", node_addresses]

            # Export the env vars on the remote shell before running torchrun.
            exports = ""
            for key, val in env_variables.items():
                exports += "export {}={}; ".format(key, val)
            prerun_args = [
                exports,
                f"cd {os.path.abspath('.')};",
            ]

            # pdsh substitutes %n with each remote node's rank.
            node_rank = "%n"
            torchrun_cmd_arguments = get_torchrun_arguments(node_rank)
            process_args = pdsh_cmd_aruments + prerun_args + torchrun_cmd_arguments
            subprocess.run(process_args)
这里有一个train函数可能需要改
- import time
-
- from ...constants import FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL
- from ...core.mlops.mlops_profiler_event import MLOpsProfilerEvent
- from fedml.data import split_data_for_dist_trainers
-
class FedMLTrainer(object):
    """Holds the model trainer plus this client's local data partitions."""

    def __init__(
        self,
        client_index,
        train_data_local_dict,
        train_data_local_num_dict,
        test_data_local_dict,
        train_data_num,
        device,
        args,
        model_trainer,
    ):
        self.trainer = model_trainer
        self.client_index = client_index
        if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            # Each process inside the silo gets its own shard of the data.
            self.train_data_local_dict = split_data_for_dist_trainers(
                train_data_local_dict, args.n_proc_in_silo
            )
        else:
            self.train_data_local_dict = train_data_local_dict
        self.train_data_local_num_dict = train_data_local_num_dict
        self.test_data_local_dict = test_data_local_dict
        self.all_train_data_num = train_data_num
        self.train_local = None
        self.local_sample_number = None
        self.test_local = None
        self.device = device
        self.args = args

    def update_model(self, weights):
        """Load the given global weights into the local trainer."""
        self.trainer.set_model_params(weights)

    def update_dataset(self, client_index):
        """Select the local train/test partitions for ``client_index``."""
        self.client_index = client_index
        if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            # Pick this process's shard within the client's partition.
            self.train_local = self.train_data_local_dict[client_index][
                self.args.proc_rank_in_silo
            ]
        else:
            self.train_local = self.train_data_local_dict[client_index]
        self.local_sample_number = self.train_data_local_num_dict[client_index]
        self.test_local = self.test_data_local_dict[client_index]

    def train(self, round_idx=None):
        """Run one round of local training; return (weights, sample_count)."""
        self.args.round_idx = round_idx
        tick = time.time()
        self.trainer.train(self.train_local, self.device, self.args)
        MLOpsProfilerEvent.log_to_wandb(
            {"Train/Time": time.time() - tick, "round": round_idx}
        )
        weights = self.trainer.get_model_params()
        return weights, self.local_sample_number

    def test(self):
        """Evaluate on both the local train and test splits.

        Fix: uses ``self.trainer.test`` for the train split; the original
        called ``self.train.test`` (a bound method has no ``test``
        attribute), which raised AttributeError.
        """
        train_metrics = self.trainer.test(self.train_local, self.device, self.args)
        train_tot_correct, train_num_sample, train_loss = (
            train_metrics["test_correct"],
            train_metrics["test_total"],
            train_metrics["test_loss"],
        )

        test_metrics = self.trainer.test(self.test_local, self.device, self.args)
        test_tot_correct, test_num_sample, test_loss = (
            test_metrics["test_correct"],
            test_metrics["test_total"],
            test_metrics["test_loss"],
        )

        return (
            train_tot_correct,
            train_loss,
            train_num_sample,
            test_tot_correct,
            test_loss,
            test_num_sample,
        )
- import logging
-
- from fedml.constants import FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL
- from .fedml_trainer import FedMLTrainer
- from .trainer.trainer_creator import create_model_trainer
-
class TrainerDistAdapter:
    """Adapter wiring a model trainer into the (possibly distributed) silo.

    Fix: renamed from ``TrainaDistAdapter`` (typo) to match the name other
    modules import as ``TrainerDistAdapter``; an alias below keeps the old
    spelling working.
    """

    def __init__(
        self,
        args,
        device,
        client_rank,
        model,
        train_data_num,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        model_trainer,
    ):
        model.to(device)
        if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            # Multi-process silo: set up the process group and wrap in DDP.
            from torch.nn.parallel import DistributedDataParallel as DDP
            from .process_group_manager import ProcessGroupManager

            only_gpu = args.using_gpu
            self.process_group_manager = ProcessGroupManager(
                args.proc_rank_in_silo,
                args.n_proc_in_silo,
                args.pg_master_address,
                args.pg_master_port,
                only_gpu,
            )
            model = DDP(model, device_ids=[device] if only_gpu else None)

        if model_trainer is None:
            model_trainer = create_model_trainer(args, model)
        else:
            model_trainer.model = model

        # Client ranks are 1-based; dataset indices are 0-based.
        client_index = client_rank - 1

        model_trainer.set_id(client_index)

        logging.info("Initiating Trainer")
        trainer = self.get_trainer(
            client_index,
            train_data_local_dict,
            train_data_local_num_dict,
            test_data_local_dict,
            train_data_num,
            device,
            args,
            model_trainer,
        )
        self.client_index = client_index
        self.client_rank = client_rank
        self.device = device
        self.trainer = trainer
        self.args = args

    def get_trainer(
        self,
        client_index,
        train_data_local_dict,
        train_data_local_num_dict,
        test_data_local_dict,
        train_data_num,
        device,
        args,
        model_trainer,
    ):
        """Create the FedMLTrainer that drives local optimization."""
        return FedMLTrainer(
            client_index,
            train_data_local_dict,
            train_data_local_num_dict,
            test_data_local_dict,
            train_data_num,
            device,
            args,
            model_trainer,
        )

    def train(self, round_idx):
        """Run one local training round; return (weights, local_sample_num)."""
        weights, local_sample_num = self.trainer.train(round_idx)
        return weights, local_sample_num

    def update_model(self, model_params):
        """Load new global weights into the trainer."""
        self.trainer.update_model(model_params)

    def update_dataset(self, client_index=None):
        """Switch the trainer to ``client_index``'s data partition.

        Fixes: forwards to ``trainer.update_dataset`` (the original called
        ``update_model`` with an int), and treats index 0 as a valid value
        (the original ``or`` fallback silently discarded 0).
        """
        _client_index = self.client_index if client_index is None else client_index
        self.trainer.update_dataset(int(_client_index))

    def cleanup_pg(self):
        """Tear down the silo process group (hierarchical scenario only)."""
        if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
            logging.info(
                "Cleaningup process group for client %s in silo %s"
                % (self.args.proc_rank_in_silo, self.args.rank_in_node)
            )
            self.process_group_manager.cleanup()


# Backward-compatible alias for the original misspelled class name.
TrainaDistAdapter = TrainerDistAdapter
其他的文件我觉得是不需要动的