• 6. VGG16——CNN 经典网络模型详解(PyTorch 实现)


    1. 首先建立一个 model.py 文件,用来定义 VGG16 网络结构,代码如下:

    # model.py
    import torch
    import torch.nn as nn


    class My_VGG16(nn.Module):
        """VGG16 classifier: 13 conv layers + 3 fully-connected layers, no batch norm.

        Args:
            num_classes: number of logits produced by the last linear layer.
            init_weight: if True, initialise conv weights with Kaiming-normal
                (fan_out, ReLU gain), linear weights with N(0, 0.01) and all
                biases with 0; otherwise keep PyTorch's default initialisation.

        Every 3x3 conv is followed by a ReLU, and each of the five stages ends
        with a 2x2 max-pool, so a 224x224 input gives a 512x7x7 feature map.
        An adaptive average pool then fixes the map at 7x7 for other input
        sizes (at least 32x32), which the first linear layer (7*7*512 inputs)
        requires. Module indices inside ``features`` / ``classifier`` follow
        the torchvision ``vgg16`` layout (conv, ReLU, ..., pool).
        """

        # Output channels of each 3x3 conv, in order; "M" marks a 2x2 max-pool.
        _CFG = (64, 64, "M",
                128, 128, "M",
                256, 256, 256, "M",
                512, 512, 512, "M",
                512, 512, 512, "M")

        def __init__(self, num_classes=5, init_weight=True):
            super().__init__()
            # Feature extractor. A ReLU follows every convolution; without it,
            # consecutive convolutions collapse into a single linear map.
            layers = []
            in_channels = 3
            for v in self._CFG:
                if v == "M":
                    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                else:
                    layers.append(nn.Conv2d(in_channels, v, kernel_size=3, stride=1, padding=1))
                    layers.append(nn.ReLU(inplace=True))
                    in_channels = v
            self.features = nn.Sequential(*layers)
            # Identity for 224x224 inputs; lets other input sizes reach the
            # 7*7*512 features the classifier expects.
            self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
            # Classifier head.
            self.classifier = nn.Sequential(
                nn.Linear(in_features=7 * 7 * 512, out_features=4096),
                nn.ReLU(inplace=True),
                nn.Dropout(0.5),
                nn.Linear(in_features=4096, out_features=4096),
                nn.ReLU(inplace=True),
                nn.Dropout(0.5),
                nn.Linear(in_features=4096, out_features=num_classes),
            )
            # Parameter initialisation.
            if init_weight:
                for m in self.modules():
                    if isinstance(m, nn.Conv2d):
                        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                        if m.bias is not None:
                            nn.init.constant_(m.bias, 0)
                    elif isinstance(m, nn.Linear):
                        nn.init.normal_(m.weight, 0, 0.01)
                        nn.init.constant_(m.bias, 0)

        def forward(self, x):
            """Map an image batch to class logits of shape (batch, num_classes).

            Assumes x is (batch, 3, H, W) with H, W >= 32 — TODO confirm callers.
            """
            x = self.features(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            return self.classifier(x)

    2. 下载花卉分类数据集 flower_photos,并解压到 flower_data/flower_photos 目录下(下一步的划分脚本从该路径读取图片):

    DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'
    

    3. 下载完后写一个 spile_data.py 文件,将数据集按 9:1 的比例随机划分为训练集(train)和验证集(val):

    # spile_data.py
    """Split flower_data/flower_photos into per-class train/ and val/ folders.

    Run from the directory that contains ``flower_data/``. For each class
    folder of the source, ``split_rate`` of its images are copied to val/ and
    the rest to train/.
    """
    import os
    import random
    from shutil import copy, rmtree


    def mkfile(file):
        """Create an empty directory *file*, deleting any existing one first.

        Starting from scratch matters when the script is rerun: copying a new
        random split on top of the old one would leave the same image in both
        train/ and val/.
        """
        if os.path.exists(file):
            rmtree(file)
        os.makedirs(file)


    def split_dataset(src_root='flower_data/flower_photos', dst_root='flower_data',
                      split_rate=0.1, seed=None):
        """Copy images from src_root/<class>/ into dst_root/{train,val}/<class>/.

        Args:
            src_root: folder with one sub-folder per class; other entries
                (e.g. a LICENSE.txt file) are ignored.
            dst_root: folder in which train/ and val/ are recreated from scratch.
            split_rate: fraction of each class (rounded down) copied to val/.
            seed: optional seed for a reproducible split.

        Returns:
            (number of images copied to train/, number copied to val/).
        """
        rng = random.Random(seed)
        flower_class = sorted(cla for cla in os.listdir(src_root)
                              if os.path.isdir(os.path.join(src_root, cla)))
        train_root = os.path.join(dst_root, 'train')
        val_root = os.path.join(dst_root, 'val')
        mkfile(train_root)
        mkfile(val_root)
        num_train = num_val = 0
        for cla in flower_class:
            mkfile(os.path.join(train_root, cla))
            mkfile(os.path.join(val_root, cla))
            cla_path = os.path.join(src_root, cla)
            # Sorted so that a given seed always yields the same split.
            images = sorted(img for img in os.listdir(cla_path)
                            if os.path.isfile(os.path.join(cla_path, img)))
            num = len(images)
            # A set makes the per-image membership test O(1) instead of O(num).
            eval_index = set(rng.sample(images, k=int(num * split_rate)))
            for index, image in enumerate(images):
                if image in eval_index:
                    new_path = os.path.join(val_root, cla)
                    num_val += 1
                else:
                    new_path = os.path.join(train_root, cla)
                    num_train += 1
                copy(os.path.join(cla_path, image), new_path)
                print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar
            print()
        print("processing done!")
        return num_train, num_val


    if __name__ == '__main__':
        split_dataset()

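    之后 flower_data 目录结构应该是这样(原文此处为一张目录截图):

        flower_data/
        ├── flower_photos/    原始图片,每个类别一个子文件夹
        ├── train/            每个类别约 90% 的图片
        │   ├── daisy/
        │   └── ...
        └── val/              每个类别约 10% 的图片
            ├── daisy/
            └── ...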

    4. 再写一个 train.py 文件,用来训练模型(脚本从运行目录下的 flower_data/train 和 flower_data/val 读取数据,预训练权重文件 ./vgg16.pth 也放在运行目录下):

    # train.py
    import json
    import os

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torchvision import transforms, datasets

    from model import My_VGG16

    # Standard ImageNet per-channel (RGB) mean / std used for normalisation.
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]


    def build_transforms():
        """Return the random-augmentation "train" and deterministic "val" pipelines."""
        return {
            "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ToTensor(),
                                         transforms.Normalize(MEAN, STD)]),
            "val": transforms.Compose([transforms.Resize(256),
                                       transforms.CenterCrop(224),
                                       transforms.ToTensor(),
                                       transforms.Normalize(MEAN, STD)]),
        }


    def load_pretrained(net, weight_path):
        """Copy the compatible tensors of the checkpoint at *weight_path* into *net*.

        Only entries whose name exists in ``net.state_dict()`` with the same
        shape are loaded; the rest (e.g. a classifier head trained for a
        different number of classes) keep their fresh initialisation. A missing
        file only prints a warning. Returns the number of tensors loaded.
        """
        if not os.path.exists(weight_path):
            print("pretrained weights {} not found, training from scratch".format(weight_path))
            return 0
        # map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only
        # machine; load_state_dict copies the values into net's own tensors.
        pre_weights = torch.load(weight_path, map_location="cpu")
        model_state = net.state_dict()
        matched = {k: v for k, v in pre_weights.items()
                   if k in model_state and v.shape == model_state[k].shape}
        net.load_state_dict(matched, strict=False)
        print("loaded {}/{} pretrained tensors from {}".format(len(matched), len(model_state), weight_path))
        return len(matched)


    def train_one_epoch(net, loader, loss_function, optimizer, device):
        """Run one optimisation pass over *loader*; return the mean per-batch loss."""
        net.train()
        running_loss = 0.0
        num_steps = len(loader)
        for step, (images, labels) in enumerate(loader):
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()
            loss_value = loss.item()
            running_loss += loss_value
            # Text progress bar: '*' for finished batches, '.' for remaining ones.
            rate = (step + 1) / num_steps
            a = "*" * int(rate * 50)
            b = "." * int((1 - rate) * 50)
            print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate * 100), a, b, loss_value), end="")
        print()
        # Divide by the number of batches, not by the last 0-based step index.
        return running_loss / max(num_steps, 1)


    def evaluate(net, loader, device):
        """Return top-1 accuracy of *net* over *loader* (0.0 if it is empty)."""
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for val_images, val_labels in loader:
                outputs = net(val_images.to(device))
                predict_y = outputs.argmax(dim=1)
                correct += (predict_y == val_labels.to(device)).sum().item()
                total += val_labels.size(0)
        return correct / total if total else 0.0


    def main(epochs=5, batch_size=16, lr=0.0001,
             model_weight_path="./vgg16.pth", save_path="./vgg16_train.pth"):
        """Train My_VGG16 on ./flower_data and keep the best-val-accuracy weights."""
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        data_transform = build_transforms()

        image_path = os.path.join(os.getcwd(), "flower_data")  # flower data set path
        train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                             transform=data_transform["train"])
        # class_to_idx maps folder name -> label index, e.g. {'daisy': 0, ...};
        # the inverted map is saved so predictions can be turned back into names.
        cla_dict = {idx: name for name, idx in train_dataset.class_to_idx.items()}
        with open('class_indices.json', 'w') as json_file:
            json_file.write(json.dumps(cla_dict, indent=4))

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size, shuffle=True,
                                                   num_workers=0)
        validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                                transform=data_transform["val"])
        validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                      batch_size=batch_size, shuffle=False,
                                                      num_workers=0)

        net = My_VGG16(num_classes=len(cla_dict))
        load_pretrained(net, model_weight_path)
        net.to(device)
        loss_function = nn.CrossEntropyLoss()
        optimizer = optim.Adam(net.parameters(), lr=lr)

        best_acc = 0.0
        for epoch in range(epochs):
            train_loss = train_one_epoch(net, train_loader, loss_function, optimizer, device)
            val_accurate = evaluate(net, validate_loader, device)
            if val_accurate > best_acc:
                best_acc = val_accurate
                torch.save(net.state_dict(), save_path)
            print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %
                  (epoch + 1, train_loss, val_accurate))
        print('Finished Training')


    if __name__ == '__main__':
        main()

  • 相关阅读:
    套接口通用发送缓存区限定
    Ubuntu 安装最新版python
    kind 安装 k8s 集群
    RL gym 环境(1)—— 安装和基础使用
    鸿鹄工程项目管理系统em Spring Cloud+Spring Boot+前后端分离构建工程项目管理系统
    【高性能计算是如何变现的?】
    正火热的人机协作,优势揭晓!
    浅谈一下前端字符编码
    shopify 如何进行二次开发~起航篇
    新零售系统主要功能有哪些?新零售系统开发公司推荐
  • 原文地址:https://blog.csdn.net/weixin_71719718/article/details/138165753