• python-pytorch: Training and Validating a Stacked Autoencoder with PyTorch


    1. Data Generation

    Randomly generate some data to stand in for the training and validation datasets:

    import torch
    
    # Randomly generate data
    n_samples = 1000
    n_features = 784  # e.g., the pixel count of a 28x28 image
    train_data = torch.rand(n_samples, n_features)
    val_data = torch.rand(int(n_samples * 0.1), n_features)
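
    The random tensors are just stand-ins. A minimal sketch of loading a real dataset instead, assuming torchvision is available (the root path and the flatten transform are illustrative):

    from torchvision import datasets, transforms

    # Flatten each 28x28 MNIST image into a 784-dimensional vector in [0, 1]
    flatten = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img.view(-1)),
    ])
    mnist_train = datasets.MNIST(root="./data", train=True, download=True, transform=flatten)
    # Note: a DataLoader over this dataset yields (image, label) pairs, so the
    # training loop defined below would need to unpack each batch accordingly.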
    

    2. Defining the Autoencoder Model

    import torch.nn as nn
    
    class Autoencoder(nn.Module):
        def __init__(self, input_size, hidden_size):
            super(Autoencoder, self).__init__()
            # Encoder: compress input_size -> hidden_size
            self.encoder = nn.Sequential(
                nn.Linear(input_size, hidden_size),
                nn.Tanh())
            # Decoder: reconstruct input_size from the hidden code
            self.decoder = nn.Sequential(
                nn.Linear(hidden_size, input_size),
                nn.Tanh())
    
        def forward(self, x):
            x = self.encoder(x)
            x = self.decoder(x)
            return x
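
    A quick shape check (the sizes mirror the first layer trained below) confirms the model reconstructs inputs at their original dimensionality. Note that the decoder's Tanh outputs values in [-1, 1]; for data in [0, 1] like the torch.rand inputs above, a Sigmoid output layer would match the range better.

    # Sanity check: a 784-dimensional batch should come back 784-dimensional
    ae = Autoencoder(input_size=784, hidden_size=400)
    batch = torch.rand(32, 784)
    print(ae(batch).shape)  # torch.Size([32, 784])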
    
    

    3. The Training Function

    Define a function to train an autoencoder:

    def train_ae(model, train_loader, val_loader, num_epochs, criterion, optimizer):
        for epoch in range(num_epochs):
            # Training
            model.train()
            train_loss = 0
            for batch_data in train_loader:
                optimizer.zero_grad()
                outputs = model(batch_data)
                loss = criterion(outputs, batch_data)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
            
            train_loss /= len(train_loader)
            print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}")
    
            # Validation
            model.eval()
            val_loss = 0
            with torch.no_grad():
                for batch_data in val_loader:
                    outputs = model(batch_data)
                    loss = criterion(outputs, batch_data)
                    val_loss += loss.item()
    
            val_loss /= len(val_loader)
            print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_loss:.4f}")
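
    As a quick smoke test, the function can be run for a couple of epochs on a tiny random dataset (a sketch; the sizes are arbitrary):

    from torch.utils.data import DataLoader

    tiny_train = DataLoader(torch.rand(64, 784), batch_size=16, shuffle=True)
    tiny_val = DataLoader(torch.rand(16, 784), batch_size=16)
    ae = Autoencoder(784, 32)
    # Should print one training and one validation loss line per epoch
    train_ae(ae, tiny_train, tiny_val, 2, nn.MSELoss(), torch.optim.Adam(ae.parameters()))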
    
    

    4. Training the Stacked Autoencoders

    Use the function defined above to train the autoencoders greedily, layer by layer; each new autoencoder learns to reconstruct the codes produced by the previous one:

    from torch.utils.data import DataLoader
    
    # DataLoaders over the raw 784-dimensional data
    batch_size = 32
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    
    # Train the first autoencoder
    ae1 = Autoencoder(input_size=784, hidden_size=400)
    optimizer = torch.optim.Adam(ae1.parameters(), lr=0.001)
    criterion = nn.MSELoss()
    train_ae(ae1, train_loader, val_loader, 10, criterion, optimizer)
    
    # Encode the data with the first autoencoder's encoder
    # (under no_grad, so the codes are plain tensors detached from ae1's graph)
    encoded_train_data = []
    with torch.no_grad():
        for data in train_loader:
            encoded_train_data.append(ae1.encoder(data))
    encoded_train_loader = DataLoader(torch.cat(encoded_train_data), batch_size=batch_size, shuffle=True)
    
    encoded_val_data = []
    with torch.no_grad():
        for data in val_loader:
            encoded_val_data.append(ae1.encoder(data))
    encoded_val_loader = DataLoader(torch.cat(encoded_val_data), batch_size=batch_size, shuffle=False)
    
    # Train the second autoencoder on the 400-dimensional codes
    ae2 = Autoencoder(input_size=400, hidden_size=200)
    optimizer = torch.optim.Adam(ae2.parameters(), lr=0.001)
    train_ae(ae2, encoded_train_loader, encoded_val_loader, 10, criterion, optimizer)
    
    # Encode again with the second autoencoder's encoder
    # (iterating over the ae1 codes, not the raw 784-dimensional data)
    encoded_train_data = []
    with torch.no_grad():
        for data in encoded_train_loader:
            encoded_train_data.append(ae2.encoder(data))
    encoded_train_loader = DataLoader(torch.cat(encoded_train_data), batch_size=batch_size, shuffle=True)
    
    encoded_val_data = []
    with torch.no_grad():
        for data in encoded_val_loader:
            encoded_val_data.append(ae2.encoder(data))
    encoded_val_loader = DataLoader(torch.cat(encoded_val_data), batch_size=batch_size, shuffle=False)
    
    # Train the third autoencoder on the 200-dimensional codes
    ae3 = Autoencoder(input_size=200, hidden_size=100)
    optimizer = torch.optim.Adam(ae3.parameters(), lr=0.001)
    train_ae(ae3, encoded_train_loader, encoded_val_loader, 10, criterion, optimizer)
    
    # Encode once more with the third autoencoder's encoder
    encoded_train_data = []
    with torch.no_grad():
        for data in encoded_train_loader:
            encoded_train_data.append(ae3.encoder(data))
    encoded_train_loader = DataLoader(torch.cat(encoded_train_data), batch_size=batch_size, shuffle=True)
    
    encoded_val_data = []
    with torch.no_grad():
        for data in encoded_val_loader:
            encoded_val_data.append(ae3.encoder(data))
    encoded_val_loader = DataLoader(torch.cat(encoded_val_data), batch_size=batch_size, shuffle=False)
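
    The three encode-and-rewrap passages above are identical in shape; a small helper, not part of the original code, could collapse the repetition:

    def encode_loader(encoder, loader, batch_size, shuffle):
        """Run every batch through a frozen encoder and wrap the codes in a new DataLoader."""
        chunks = []
        with torch.no_grad():
            for batch in loader:
                chunks.append(encoder(batch))
        return DataLoader(torch.cat(chunks), batch_size=batch_size, shuffle=shuffle)

    # e.g. encoded_train_loader = encode_loader(ae1.encoder, train_loader, batch_size, shuffle=True)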
    
    

    5. Cascading the Trained Autoencoders

    class StackedAutoencoder(nn.Module):
        def __init__(self, ae1, ae2, ae3):
            super(StackedAutoencoder, self).__init__()
            self.encoder = nn.Sequential(ae1.encoder, ae2.encoder, ae3.encoder)
            self.decoder = nn.Sequential(ae3.decoder, ae2.decoder, ae1.decoder)
    
        def forward(self, x):
            x = self.encoder(x)
            x = self.decoder(x)
            return x
    
    sae = StackedAutoencoder(ae1, ae2, ae3)
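
    A quick shape check confirms the wiring: the stacked encoder compresses 784 -> 400 -> 200 -> 100, and the decoder mirrors it back up:

    x = torch.rand(8, 784)
    print(sae.encoder(x).shape)  # torch.Size([8, 100])
    print(sae(x).shape)          # torch.Size([8, 784])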
    
    

    6. Fine-Tuning the Entire Stacked Autoencoder

    Finish by retraining the assembled stacked autoencoder end to end on the original data, which fine-tunes all layers jointly:

    optimizer = torch.optim.Adam(sae.parameters(), lr=0.001)
    train_ae(sae, train_loader, val_loader, 10, criterion, optimizer)
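
    After fine-tuning, it is usually the encoder alone that gets reused, e.g. as a feature extractor (a sketch; the downstream use is hypothetical):

    # 100-dimensional codes for the 100 validation samples
    with torch.no_grad():
        features = sae.encoder(val_data)
    print(features.shape)  # torch.Size([100, 100])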
    
    
  • Original article: https://blog.csdn.net/xfysq_/article/details/133431232