• Reproducing the STC (Sparse Ternary Compression) data-compression algorithm


    1. Dataset introduction

    The MNIST dataset

    MNIST is a dataset of handwritten-digit images compiled under the auspices of the US National Institute of Standards and Technology (NIST). It contains handwritten digits collected from 250 different writers, 50% of them high-school students and 50% employees of the Census Bureau. The dataset was assembled so that algorithms for recognizing handwritten digits could be developed and evaluated.
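
    The code snippets in the following sections assume a common set of imports plus a device and data-path definition roughly like the sketch below; DATA_PATH and device are assumptions here and should be adjusted to your environment.

    # Shared imports assumed by the snippets below (a sketch, not part of the original listing).
    import os
    import random
    import time

    import numpy as np
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import Dataset

    DATA_PATH = "./data"  # where MNIST will be downloaded (assumption)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")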

    2. The logistic model

    class logistic(nn.Module):
        """
        Multinomial logistic-regression model for MNIST image classification.
        """

        def __init__(self, in_size=32 * 32 * 1, num_classes=10):
            super(logistic, self).__init__()
            self.linear = nn.Linear(in_size, num_classes)

        def forward(self, x):
            out = x.view(x.size(0), -1)  # flatten each image into a vector
            out = self.linear(out)
            return out
    
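    As a quick sanity check, the sketch below (assuming the shared imports and device above) runs the model on a dummy batch and confirms that it maps 32x32 grayscale images to 10 class logits.

    # Smoke test for the logistic model (dummy data, illustration only).
    model = logistic().to(device)
    dummy = torch.randn(4, 1, 32, 32).to(device)  # a batch of 4 fake 32x32 grayscale images
    logits = model(dummy)
    print(logits.shape)  # torch.Size([4, 10])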

    3. The distributed training device model

    class DistributedTrainingDevice(object):
        '''
        Distributed training device (a client or the server).
        dataloader: a PyTorch DataLoader over data points (x, y)
        model: a PyTorch neural network
        hyperparameters: a Python dict holding all hyperparameters
        experiment: the experiment configuration
        '''
    
        def __init__(self, dataloader, model, hyperparameters, experiment):
            self.hp = hyperparameters
            self.xp = experiment
            self.loader = dataloader
            self.model = model
            self.loss_fn = nn.CrossEntropyLoss()
    
        def copy(self, target, source):
            """Copy the parameters of source into target."""
            for name in target:
                target[name].data = source[name].data.clone()

        def add(self, target, source):
            """Element-wise add the parameters of source into target."""
            for name in target:
                target[name].data += source[name].data.clone()

        def subtract(self, target, source):
            """Element-wise subtract the parameters of source from target."""
            for name in target:
                target[name].data -= source[name].data.clone()

        def subtract_(self, target, minuend, subtrahend):
            """target = minuend - subtrahend, element-wise."""
            for name in target:
                target[name].data = minuend[name].data.clone() - subtrahend[name].data.clone()
    
        def approx_v(self, T, p, frac):
            """Approximate the threshold v that keeps the top fraction p of T,
            optionally estimated on a random contiguous sample covering a fraction frac of T."""
            if frac < 1.0:
                n_elements = T.numel()
                n_sample = min(int(max(np.ceil(n_elements * frac), np.ceil(100 / p))), n_elements)
                n_top = int(np.ceil(n_sample * p))

                if n_elements == n_sample:
                    i = 0
                else:
                    i = np.random.randint(n_elements - n_sample)

                topk, _ = torch.topk(T.flatten()[i:i + n_sample], n_top)
                if topk[-1] == 0.0 or topk[-1] == T.max():
                    return self.approx_v(T, p, 1.0)  # degenerate sample, fall back to the exact computation
            else:
                n_elements = T.numel()
                n_top = int(np.ceil(n_elements * p))
                topk, _ = torch.topk(T.flatten(), n_top)  # the n_top largest values of T

            return topk[-1], topk
    
        def stc(self, T, hp):
            """Sparse ternary compression (STC) of the tensor T."""
            hp_ = {'p': 0.001, 'approx': 1.0}
            hp_.update(hp)

            T_abs = torch.abs(T)

            v, topk = self.approx_v(T_abs, hp_["p"], hp_["approx"])
            mean = torch.mean(topk)  # mean magnitude of the kept top-k values

            out_ = torch.where(T >= v, mean, torch.Tensor([0.0]).to(device))  # entries >= v become +mean, the rest 0
            out = torch.where(T <= -v, -mean, out_)  # entries <= -v become -mean; everything else keeps out_

            return out

        def compress(self, target, source):
            '''
            Apply sparse ternary compression to each parameter tensor separately.
            '''
            for name in target:
                target[name].data = self.stc(source[name].data.clone(), self.hp)
    
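    To make the ternarization concrete, here is a small standalone sketch with made-up values; it applies the same rule as stc: keep roughly the top-p fraction of entries by magnitude, replace them with the (signed) mean magnitude of the kept entries, and zero out everything else.

    # Standalone sketch of sparse ternary compression on a toy tensor (values are made up).
    T = torch.tensor([0.9, -0.8, 0.1, -0.05, 0.02, 0.0])
    p = 0.4                                    # keep ~40% of the entries, for illustration
    n_top = int(np.ceil(T.numel() * p))        # -> 3
    topk, _ = torch.topk(T.abs().flatten(), n_top)
    v, mu = topk[-1], topk.mean()              # threshold and mean magnitude of the kept entries
    out = torch.where(T >= v, mu, torch.zeros_like(T))
    out = torch.where(T <= -v, -mu, out)
    print(out)  # tensor([ 0.6000, -0.6000,  0.6000,  0.0000,  0.0000,  0.0000])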

    4. The client model

    class Client(DistributedTrainingDevice):
        """
        Client class; inherits from DistributedTrainingDevice.
        """
    
        def __init__(self, dataloader, model, hyperparameters, experiment, id_num=0):
            super().__init__(dataloader, model, hyperparameters, experiment)
    
            self.id = id_num
    
            # Model parameters and the buffers for weight updates and residual accumulation
            self.W = {name: value for name, value in self.model.named_parameters()}
            self.W_old = {name: torch.zeros(value.shape).to(device) for name, value in self.W.items()}
            self.dW = {name: torch.zeros(value.shape).to(device) for name, value in self.W.items()}
            self.dW_compressed = {name: torch.zeros(value.shape).to(device) for name, value in self.W.items()}
            self.A = {name: torch.zeros(value.shape).to(device) for name, value in self.W.items()}
    
            self.n_params = sum([T.numel() for T in self.W.values()])
            self.bits_sent = []
    
            optimizer_object = getattr(optim, self.hp['optimizer'])
            optimizer_parameters = {k: v for k, v in self.hp.items() if k in optimizer_object.__init__.__code__.co_varnames}
    
            self.optimizer = optimizer_object(self.model.parameters(), **optimizer_parameters)
    
            # Learning-rate schedule
            self.scheduler = getattr(optim.lr_scheduler, self.hp['lr_decay'][0])(self.optimizer, **self.hp['lr_decay'][1])
    
            # Training state
            self.epoch = 0
            self.train_loss = 0.0
    
        def synchronize_with_server(self, server):
            # W_client = W_server
            self.copy(target=self.W, source=server.W)
    
        def train_cnn(self, iterations):
    
            running_loss = 0.0
            for i in range(iterations):
    
                try:  # Load new batch of data
                    x, y = next(self.epoch_loader)
                except (StopIteration, AttributeError):  # iterator exhausted or not yet created -> start the next epoch
                    self.epoch_loader = iter(self.loader)
                    self.epoch += 1
    
                    # Adjust the learning rate
                    if isinstance(self.scheduler, optim.lr_scheduler.LambdaLR):
                        self.scheduler.step()
                    if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau) and 'loss_test' in self.xp.results:
                        self.scheduler.step(self.xp.results['loss_test'][-1])
    
                    x, y = next(self.epoch_loader)
    
                x, y = x.to(device), y.to(device)
    
                self.optimizer.zero_grad()
    
                y_ = self.model(x)
    
                loss = self.loss_fn(y_, y)
                loss.backward()
                self.optimizer.step()
    
                running_loss += loss.item()
    
            return running_loss / iterations
    
        def compute_weight_update(self, iterations=1):
    
            # Switch the model to training mode
            self.model.train()
    
            # W_old = W
            self.copy(target=self.W_old, source=self.W)
    
            # W = SGD(W, D)
            self.train_loss = self.train_cnn(iterations)
    
            # dW = W - W_old
            self.subtract_(target=self.dW, minuend=self.W, subtrahend=self.W_old)
    
        def compress_weight_update_up(self, compression=None, accumulate=False, count_bits=False):

            if accumulate and compression[0] != "none":
                # Error feedback: accumulate the residual, compress it, and keep the untransmitted remainder
                self.add(target=self.A, source=self.dW)
                self.compress(target=self.dW_compressed, source=self.A)
                self.subtract(target=self.A, source=self.dW_compressed)

            else:
                # Compress dW directly, without residual accumulation
                self.compress(target=self.dW_compressed, source=self.dW)
    
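    The accumulate branch above is error feedback: whatever the compressor drops stays in the residual A and gets another chance to be transmitted in a later round. A minimal sketch with plain tensors and a hypothetical toy compressor (keep only the single largest-magnitude entry):

    # Error-feedback sketch; topk1 is a stand-in compressor, not part of the original code.
    def topk1(t):
        out = torch.zeros_like(t)
        i = t.abs().argmax()
        out[i] = t[i]
        return out

    A = torch.zeros(3)  # residual memory
    for dW in [torch.tensor([0.5, 0.2, 0.1]), torch.tensor([0.0, 0.3, 0.1])]:
        A += dW          # accumulate the new update
        sent = topk1(A)  # compress the accumulated update
        A -= sent        # keep only what was NOT transmitted
        print(sent, A)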

    5. The server model

    class Server(DistributedTrainingDevice):
        """
        Server class; inherits from DistributedTrainingDevice.
        """
    
        def __init__(self, dataloader, model, hyperparameters, experiment, stats):
            super().__init__(dataloader, model, hyperparameters, experiment)
    
            # Parameters
            self.W = {name: value for name, value in self.model.named_parameters()}
            self.dW_compressed = {name: torch.zeros(value.shape).to(device) for name, value in self.W.items()}
            self.dW = {name: torch.zeros(value.shape).to(device) for name, value in self.W.items()}
    
            self.A = {name: torch.zeros(value.shape).to(device) for name, value in self.W.items()}
    
            self.n_params = sum([T.numel() for T in self.W.values()])
            self.bits_sent = []
    
            self.client_sizes = torch.Tensor(stats["split"])
    
        def average(self, target, sources):
            """Average the parameters of sources element-wise and store the result in target."""
            for name in target:
                target[name].data = torch.mean(torch.stack([source[name].data for source in sources]), dim=0).clone()
    
        def aggregate_weight_updates(self, clients, aggregation="mean"):
            # dW = aggregate(dW_i, i=1,..,n)
            self.average(target=self.dW, sources=[client.dW_compressed for client in clients])
    
        def compress_weight_update_down(self, compression=None, accumulate=False, count_bits=False):
            if accumulate and compression[0] != "none":
                # Error feedback on the downlink: accumulate, compress, keep the remainder
                self.add(target=self.A, source=self.dW)
                self.compress(target=self.dW_compressed, source=self.A)
                self.subtract(target=self.A, source=self.dW_compressed)
    
            else:
                self.compress(target=self.dW_compressed, source=self.dW)
    
            self.add(target=self.W, source=self.dW_compressed)
    
        def evaluate(self, loader=None, max_samples=50000, verbose=True):
            """Evaluate the global model held by the server."""
            self.model.eval()
    
            eval_loss, correct, samples, iters = 0.0, 0, 0, 0
            if not loader:
                loader = self.loader
            with torch.no_grad():
                for i, (x, y) in enumerate(loader):
    
                    x, y = x.to(device), y.to(device)
                    y_ = self.model(x)
                    _, predicted = torch.max(y_.data, 1)
                    eval_loss += self.loss_fn(y_, y).item()
                    correct += (predicted == y).sum().item()
                    samples += y_.shape[0]
                    iters += 1
    
                    if samples >= max_samples:
                        break
                if verbose:
                    print("Evaluated on {} samples ({} batches)".format(samples, iters))
    
                results_dict = {'loss': eval_loss / iters, 'accuracy': correct / samples}
    
            return results_dict
    
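    The average call above is plain federated averaging over the clients' compressed updates. A toy sketch of the same reduction on made-up parameter dicts:

    # Element-wise mean over per-client update dicts (toy values).
    updates = [{"w": torch.tensor([1.0, 2.0])}, {"w": torch.tensor([3.0, 4.0])}]
    mean_update = {name: torch.mean(torch.stack([u[name] for u in updates]), dim=0)
                   for name in updates[0]}
    print(mean_update["w"])  # tensor([2., 3.])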

    6. Image dataset class (DataLoader wrapper)

    class CustomImageDataset(Dataset):
        '''
        Image dataset wrapper for use with a DataLoader.
        inputs : numpy array [n_data x shape]
        labels : numpy array [n_data (x 1)]
        '''
    
        def __init__(self, inputs, labels, transforms=None):
            assert inputs.shape[0] == labels.shape[0]
            self.inputs = torch.Tensor(inputs)
            self.labels = torch.Tensor(labels).long()
            self.transforms = transforms
    
        def __getitem__(self, index):
            img, label = self.inputs[index], self.labels[index]
    
            if self.transforms is not None:
                img = self.transforms(img)
    
            return (img, label)
    
        def __len__(self):
            return self.inputs.shape[0]
    
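    A brief usage sketch (random fake data, shapes chosen to match MNIST) showing how the wrapper plugs into a standard DataLoader:

    # Wrap random fake images/labels and pull one batch (illustration only).
    xs = np.random.rand(16, 1, 28, 28).astype("float32")
    ys = np.random.randint(0, 10, size=16)
    ds = CustomImageDataset(xs, ys)
    loader = torch.utils.data.DataLoader(ds, batch_size=4, shuffle=True)
    imgs, lbls = next(iter(loader))
    print(imgs.shape, lbls.shape)  # torch.Size([4, 1, 28, 28]) torch.Size([4])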

    7. Downloading and normalizing MNIST

    def get_mnist():
        '''Download the MNIST dataset and return it as numpy arrays.'''
        data_train = torchvision.datasets.MNIST(root=os.path.join(DATA_PATH, "MNIST"), train=True, download=True)
        data_test = torchvision.datasets.MNIST(root=os.path.join(DATA_PATH, "MNIST"), train=False, download=True)
    
        # .train_data/.train_labels are deprecated; .data/.targets are the current attribute names
        x_train, y_train = data_train.data.numpy().reshape(-1, 1, 28, 28) / 255, np.array(data_train.targets)
        x_test, y_test = data_test.data.numpy().reshape(-1, 1, 28, 28) / 255, np.array(data_test.targets)
    
        return x_train, y_train, x_test, y_test
    
    def get_default_data_transforms(name, train=True, verbose=True):
        """Return the (train, eval) preprocessing transforms for the given dataset name."""
        transforms_train = {
            'mnist': transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((32, 32)),
                # transforms.RandomCrop(32, padding=4),
                transforms.ToTensor(),
                transforms.Normalize((0.06078,), (0.1957,))
            ]),
        }
        transforms_eval = {
            'mnist': transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((32, 32)),
                transforms.ToTensor(),
                transforms.Normalize((0.06078,), (0.1957,))
            ]),
        }
    
        if verbose:
            print("\nData preprocessing: ")
            for transformation in transforms_train[name].transforms:
                print(' -', transformation)
            print()
    
        return (transforms_train[name], transforms_eval[name])
    
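    A short usage sketch of the two helpers above; it downloads MNIST on the first run and prints the raw array shapes together with the configured preprocessing pipeline.

    # Fetch MNIST and the matching transforms (sketch).
    x_train, y_train, x_test, y_test = get_mnist()
    print(x_train.shape, x_test.shape)  # (60000, 1, 28, 28) (10000, 1, 28, 28)
    t_train, t_eval = get_default_data_transforms('mnist', verbose=True)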

    8. Splitting the dataset among clients

    def split_image_data(data, labels, n_clients=10, classes_per_client=10, shuffle=True, verbose=True, balancedness=1.0):
        '''
        Split the dataset among the clients.
        data : [n_data x shape]
        labels : [n_data (x 1)], values from 0 to n_labels - 1
        '''
        # constants
        n_data = data.shape[0]
        n_labels = np.max(labels) + 1
    
        if balancedness >= 1.0:
            data_per_client = [n_data // n_clients] * n_clients
            data_per_client_per_class = [data_per_client[0] // classes_per_client] * n_clients
        else:
            fracs = balancedness ** np.linspace(0, n_clients - 1, n_clients)
            fracs /= np.sum(fracs)
            fracs = 0.1 / n_clients + (1 - 0.1) * fracs
            data_per_client = [np.floor(frac * n_data).astype('int') for frac in fracs]
    
            data_per_client = data_per_client[::-1]
    
            data_per_client_per_class = [np.maximum(1, nd // classes_per_client) for nd in data_per_client]
    
        if sum(data_per_client) > n_data:
            print("Impossible Split")
            exit()
    
        # sort for labels
        data_idcs = [[] for i in range(n_labels)]
        for j, label in enumerate(labels):
            data_idcs[label] += [j]
        if shuffle:
            for idcs in data_idcs:
                np.random.shuffle(idcs)
    
        # split data among clients
        clients_split = []
        c = 0
        for i in range(n_clients):
            client_idcs = []
            budget = data_per_client[i]
            c = np.random.randint(n_labels)
            while budget > 0:
                take = min(data_per_client_per_class[i], len(data_idcs[c]), budget)
    
                client_idcs += data_idcs[c][:take]
                data_idcs[c] = data_idcs[c][take:]
    
                budget -= take
                c = (c + 1) % n_labels
    
            clients_split += [(data[client_idcs], labels[client_idcs])]
    
        return clients_split
    
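    A quick sketch of what the split produces with toy settings (5 clients, 2 classes each, perfectly balanced); each client ends up with samples from only a couple of digit classes, which simulates a non-IID federated setting.

    # Non-IID split sketch: 5 clients, 2 classes per client.
    x_train, y_train, _, _ = get_mnist()
    split = split_image_data(x_train, y_train, n_clients=5, classes_per_client=2,
                             balancedness=1.0, verbose=False)
    for i, (x, y) in enumerate(split):
        print("client", i, "->", x.shape[0], "samples, classes:", np.unique(y))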

    9. Building the data loaders

    def get_data_loaders(hp, verbose=True):
        """Build the per-client, train, and test DataLoaders."""
        x_train, y_train, x_test, y_test = get_mnist()  # fetch the raw data

        transforms_train, transforms_eval = get_default_data_transforms(hp['dataset'], verbose=False)  # preprocessing transforms

        split = split_image_data(x_train, y_train, n_clients=hp['n_clients'],
                                 classes_per_client=hp['classes_per_client'], balancedness=hp['balancedness'],
                                 verbose=verbose)  # split the training data among the clients
        # Build the DataLoaders
        client_loaders = [torch.utils.data.DataLoader(CustomImageDataset(x, y, transforms_train),
                                                      batch_size=hp['batch_size'], shuffle=True) for x, y in split]
        train_loader = torch.utils.data.DataLoader(CustomImageDataset(x_train, y_train, transforms_eval), batch_size=100,
                                                   shuffle=False)
        test_loader = torch.utils.data.DataLoader(CustomImageDataset(x_test, y_test, transforms_eval), batch_size=100,
                                                  shuffle=False)
    
        stats = {"split": [x.shape[0] for x, y in split]}
    
        return client_loaders, train_loader, test_loader, stats
    

    10. Training the model

    def train():
        hp = {
            "communication_rounds": 20,
            "dataset": "mnist",
            "n_clients": 50,
            "classes_per_client": 10,
            "local_iterations": 1,
            "weight_decay": 0.0,
            "optimizer": "SGD",
            "log_frequency": -100,
            "count_bits": False,
            "participation_rate": 1.0,
            "balancedness": 1.0,
            "compression_up": ["stc", {"p": 0.001}],
            "compression_down": ["stc", {"p": 0.002}],
            "accumulation_up": True,
            "accumulation_down": True,
            "aggregation": "mean",
            'type': 'CNN', 'lr': 0.04,
            'batch_size': 100,
            'lr_decay': ['LambdaLR', {'lr_lambda': lambda epoch: 1.0}],
            'momentum': 0.0,
        }
        xp = {
            "iterations": 100,
            "participation_rate": 0.5,
            "momentum": 0.9,
            "compression": [
                "stc_updown",
                {
                    "p_up": 0.001,
                    "p_down": 0.002
                }
            ],
            "log_frequency": 30,
            "log_path": "results/trash/"
        }
        # Load the dataset and split it among the clients
        client_loaders, train_loader, test_loader, stats = get_data_loaders(hp)
        # Initialize the server/client neural-network model
        net = logistic().to(device)
        clients = [Client(loader, net, hp, xp, id_num=i) for i, loader in enumerate(client_loaders)]
        server = Server(test_loader, net, hp, xp, stats)
        # Start training
        print("Start Distributed Training..\n")
        t1 = time.time()
        for c_round in range(1, hp['communication_rounds'] + 1):
            # Randomly sample the clients that participate in this round
            participating_clients = random.sample(clients, int(len(clients) * hp['participation_rate']))
            # Client side
            for client in participating_clients:
                client.synchronize_with_server(server)  # load the current global model parameters
                client.compute_weight_update(hp['local_iterations'])  # local training (weight update)
                client.compress_weight_update_up(compression=hp['compression_up'], accumulate=hp['accumulation_up'],
                                                 count_bits=hp["count_bits"])  # compress the uplink update
    
            # Server side
            server.aggregate_weight_updates(participating_clients, aggregation=hp['aggregation'])  # aggregate the clients' updates
            server.compress_weight_update_down(compression=hp['compression_down'], accumulate=hp['accumulation_down'],
                                               count_bits=hp["count_bits"])  # compress the downlink update and apply it
            # Evaluate the global model
            print("Evaluate...")
            results_train = server.evaluate(max_samples=5000, loader=train_loader)
            results_test = server.evaluate(max_samples=10000)
            # Logging
            print({'communication_round': c_round, 'lr': clients[0].optimizer.__dict__['param_groups'][0]['lr'],
                    'epoch': clients[0].epoch, 'iteration': c_round * hp['local_iterations']})
            print({'client{}_loss'.format(client.id): client.train_loss for client in clients})
    
            print({key + '_train': value for key, value in results_train.items()})
            print({key + '_test': value for key, value in results_test.items()})
    
            print({'time': time.time() - t1})
            total_time = time.time() - t1
            avrg_time_per_c_round = (total_time) / c_round
            e = int(avrg_time_per_c_round * (hp['communication_rounds'] - c_round))
            print("Remaining Time (approx.):", '{:02d}:{:02d}:{:02d}'.format(e // 3600, (e % 3600 // 60), e % 60),
                  "[{:.2f}%]\n".format(c_round / hp['communication_rounds'] * 100))
    
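    To run the reproduction end to end, a minimal entry point (assuming all of the snippets above live in one script) could look like this; the seeding lines are optional and only make the run repeatable.

    if __name__ == '__main__':
        torch.manual_seed(0)  # optional: make the run repeatable
        np.random.seed(0)
        random.seed(0)
        train()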

    11. Results

    (Training-log screenshot omitted.)
    Source code: https://github.com/1957787636/FederalLearning

  • Original article: https://blog.csdn.net/qq_45724216/article/details/126030490