• PyTorch下的5种不同神经网络-一.AlexNet


    1.导入模块

    导入所需的Python库,包括图像处理、深度学习模型和数据加载

    1. import os
    2. import torch
    3. import torch.nn as nn
    4. import torch.optim as optim
    5. from torch.utils.data import Dataset, DataLoader
    6. from PIL import Image
    7. from torchvision import models, transforms

    2.定义自定义图像数据集

    创建一个自定义的图像数据集类,用于加载和处理图像数据。

    1. class CustomImageDataset(Dataset):
    2.     def __init__(self, main_dir, transform=None):
    3.         self.main_dir = main_dir
    4.         self.transform = transform
    5.         self.files = []
    6.         self.labels = []
    7.         self.label_to_index = {}
    8.         for index, label in enumerate(os.listdir(main_dir)):
    9.             self.label_to_index[label] = index
    10.             label_dir = os.path.join(main_dir, label)
    11.             if os.path.isdir(label_dir):
    12.                 for file in os.listdir(label_dir):
    13.                     self.files.append(os.path.join(label_dir, file))
    14.                     self.labels.append(label)
    15.     def __len__(self):
    16.         return len(self.files)
    17.     def __getitem__(self, idx):
    18.         image = Image.open(self.files[idx])
    19.         label = self.labels[idx]
    20.         if self.transform:
    21.             image = self.transform(image)
    22.         return image, self.label_to_index[label]

    3.定义数据转换

    定义一个数据转换过程,包括图像大小调整、随机翻转、旋转、转换为张量以及标准化

    1. transform = transforms.Compose([
    2.     transforms.Resize((227, 227)),  # AlexNet的输入图像大小
    3.     transforms.RandomHorizontalFlip(),  # 随机水平翻转
    4.     transforms.RandomRotation(10),  # 随机旋转
    5.     transforms.ToTensor(),
    6.     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # AlexNet的标准化
    7. ])

    4.创建数据集

    使用自定义数据集类和定义的数据转换来创建数据集

    dataset = CustomImageDataset(main_dir="F:\\A-GX\\A-SJJ\\flower_photos\\flower_photos", transform=transform)

    5.创建数据加载器

    使用数据集创建一个数据加载器,用于批量加载和处理数据

    data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

    6.加载预训练的AlexNet模型

    从PyTorch库中加载预训练的AlexNet模型

    alexnet_model = models.alexnet(pretrained=True)

    7.修改最后几层以适应新的分类任务

    修改AlexNet模型的最后几层,以便它能够处理新的分类任务

    1. num_ftrs = alexnet_model.classifier[6].in_features
    2. alexnet_model.classifier[6] = nn.Linear(num_ftrs, len(dataset.label_to_index))

    8.定义损失函数和优化器

    定义用于训练模型的损失函数和优化器。

    1. criterion = nn.CrossEntropyLoss()
    2. optimizer = optim.Adam(alexnet_model.parameters(), lr=0.0001)

    9.模型并行化

    如果有多GPU,则使用nn.DataParallel来并行化模型

    1. if torch.cuda.device_count() > 1:
    2.     alexnet_model = nn.DataParallel(alexnet_model)

    10.将模型发送到GPU

    将模型发送到GPU进行训练。

    1. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    2. alexnet_model.to(device)

    11.训练模型

    数据加载器和定义的参数训练模型

    1. num_epochs = 10
    2. for epoch in range(num_epochs):
    3.     alexnet_model.train()
    4.     running_loss = 0.0
    5.     for images, labels in data_loader:
    6.         images, labels = images.to(device), labels.to(device)
    7.         # 前向传播
    8.         outputs = alexnet_model(images)
    9.         loss = criterion(outputs, labels)
    10.         # 反向传播和优化
    11.         optimizer.zero_grad()
    12.         loss.backward()
    13.         optimizer.step()
    14.         running_loss += loss.item()
    15.     # 在每个epoch结束后评估模型
    16.     train_accuracy = evaluate_model(alexnet_model, data_loader, device)
    17.     print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}, Train Accuracy: {train_accuracy:.2f}%')

    12.评估模型

    定义一个评估函数,用于评估模型的性能

    1. def evaluate_model(model, data_loader, device):
    2.     model.eval()  # 将模型设置为评估模式
    3.     correct = 0
    4.     total = 0
    5.     with torch.no_grad():  # 在这个块中,所有计算都不会计算梯度
    6.         for images, labels in data_loader:
    7.             images, labels = images.to(device), labels.to(device)
    8.             outputs = model(images)
    9.             _, predicted = torch.max(outputs.data, 1)
    10.             total += labels.size(0)
    11.             correct += (predicted == labels).sum().item()
    12.     accuracy = 100 * correct / total
    13. return accuracy

  • 相关阅读:
    数据结构实验四 线性表的基本操作及应用
    SpringData、SparkStreaming和Flink集成Elasticsearch
    基于自动化工具autox.js的抢票(猫眼)
    stable diffusion mode 的使用 invokeAI or stable diffusion web UI?
    获得店铺的所有商品 API 返回值说明
    [ Shell ] 两个 case 实现 GetOptions 效果
    项目整体管理
    DV SSL证书便宜吗?申请后多久签发?
    React源码分析(三):useState,useReducer
    Java 死,前端凉?!斗胆说点真话
  • 原文地址:https://blog.csdn.net/hhy020115/article/details/139847797