• 6.Vgg16--CNN经典网络模型详解(pytorch实现)


    1.首先建立一个model.py文件,用来写神经网络,代码如下:

    1. import torch
    2. import torch.nn as nn
    3. class My_VGG16(nn.Module):
    4. def __init__(self,num_classes=5,init_weight=True):
    5. super(My_VGG16, self).__init__()
    6. # 特征提取层
    7. self.features = nn.Sequential(
    8. nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3,stride=1,padding=1),
    9. nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1,padding=1),
    10. nn.MaxPool2d(kernel_size=2,stride=2),
    11. nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
    12. nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
    13. nn.MaxPool2d(kernel_size=2, stride=2),
    14. nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
    15. nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
    16. nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
    17. nn.MaxPool2d(kernel_size=2,stride=2),
    18. nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
    19. nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
    20. nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
    21. nn.MaxPool2d(kernel_size=2, stride=2),
    22. nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
    23. nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
    24. nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
    25. nn.MaxPool2d(kernel_size=2, stride=2),
    26. )
    27. # 分类层
    28. self.classifier = nn.Sequential(
    29. nn.Linear(in_features=7*7*512,out_features=4096),
    30. nn.ReLU(),
    31. nn.Dropout(0.5),
    32. nn.Linear(in_features=4096,out_features=4096),
    33. nn.ReLU(),
    34. nn.Dropout(0.5),
    35. nn.Linear(in_features=4096,out_features=num_classes)
    36. )
    37. # 参数初始化
    38. if init_weight: # 如果进行参数初始化
    39. for m in self.modules(): # 对于模型的每一层
    40. if isinstance(m, nn.Conv2d): # 如果是卷积层
    41. # 使用kaiming初始化
    42. nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
    43. # 如果bias不为空,固定为0
    44. if m.bias is not None:
    45. nn.init.constant_(m.bias, 0)
    46. elif isinstance(m, nn.Linear):# 如果是线性层
    47. # 正态初始化
    48. nn.init.normal_(m.weight, 0, 0.01)
    49. # bias则固定为0
    50. nn.init.constant_(m.bias, 0)
    51. def forward(self,x):
    52. x = self.features(x)
    53. x = torch.flatten(x,1)
    54. result = self.classifier(x)
    55. return result

    2.下载数据集

    DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'
    

    3.下载完后写一个spile_data.py文件,将数据集进行分类

    1. #spile_data.py
    2. import os
    3. from shutil import copy
    4. import random
    5. def mkfile(file):
    6. if not os.path.exists(file):
    7. os.makedirs(file)
    8. file = 'flower_data/flower_photos'
    9. flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla]
    10. mkfile('flower_data/train')
    11. for cla in flower_class:
    12. mkfile('flower_data/train/'+cla)
    13. mkfile('flower_data/val')
    14. for cla in flower_class:
    15. mkfile('flower_data/val/'+cla)
    16. split_rate = 0.1
    17. for cla in flower_class:
    18. cla_path = file + '/' + cla + '/'
    19. images = os.listdir(cla_path)
    20. num = len(images)
    21. eval_index = random.sample(images, k=int(num*split_rate))
    22. for index, image in enumerate(images):
    23. if image in eval_index:
    24. image_path = cla_path + image
    25. new_path = 'flower_data/val/' + cla
    26. copy(image_path, new_path)
    27. else:
    28. image_path = cla_path + image
    29. new_path = 'flower_data/train/' + cla
    30. copy(image_path, new_path)
    31. print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="") # processing bar
    32. print()
    33. print("processing done!")

    分类完成后,flower_data 目录下会出现 train 和 val 两个子目录,每个子目录下再按花的类别(daisy、dandelion、roses、sunflowers、tulips)分别存放图片。(原文此处为目录结构截图,抓取后未能显示。)

    4.再写一个train.py文件,用来训练模型

    1. import torch
    2. import torch.nn as nn
    3. from torchvision import transforms, datasets
    4. import json
    5. import os
    6. import torch.optim as optim
    7. from model import My_VGG16
    8. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    9. data_transform = {
    10. "train": transforms.Compose([transforms.RandomResizedCrop(224),
    11. transforms.RandomHorizontalFlip(),
    12. transforms.ToTensor(),
    13. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    14. "val": transforms.Compose([transforms.Resize(256),
    15. transforms.CenterCrop(224),
    16. transforms.ToTensor(),
    17. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
    18. data_root = os.getcwd() # get data root path
    19. image_path = data_root + "/flower_data/" # flower data set path
    20. train_dataset = datasets.ImageFolder(root=image_path+"train",
    21. transform=data_transform["train"])
    22. train_num = len(train_dataset)
    23. # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    24. flower_list = train_dataset.class_to_idx
    25. cla_dict = dict((val, key) for key, val in flower_list.items())
    26. cla_dict = dict((val, key) for key, val in flower_list.items())
    27. # write dict into json file
    28. json_str = json.dumps(cla_dict, indent=4)
    29. with open('class_indices.json', 'w') as json_file:
    30. json_file.write(json_str)
    31. batch_size = 16
    32. train_loader = torch.utils.data.DataLoader(train_dataset,
    33. batch_size=batch_size, shuffle=True,
    34. num_workers=0)
    35. validate_dataset = datasets.ImageFolder(root=image_path + "val",
    36. transform=data_transform["val"])
    37. val_num = len(validate_dataset)
    38. validate_loader = torch.utils.data.DataLoader(validate_dataset,
    39. batch_size=batch_size, shuffle=False,
    40. num_workers=0)
    41. net = My_VGG16(num_classes=5)
    42. # load pretrain weights
    43. model_weight_path = "./vgg16.pth"
    44. pre_weights = torch.load(model_weight_path)
    45. net.to(device)
    46. loss_function = nn.CrossEntropyLoss()
    47. optimizer = optim.Adam(net.parameters(), lr=0.0001)
    48. best_acc = 0.0
    49. save_path = './vgg16_train.pth'
    50. for epoch in range(5):
    51. # train
    52. net.train()
    53. running_loss = 0.0
    54. for step, data in enumerate(train_loader, start=0):
    55. images, labels = data
    56. optimizer.zero_grad()
    57. logits = net(images.to(device))#.to(device)
    58. print("===>",logits.shape,labels.shape)
    59. loss = loss_function(logits, labels.to(device))
    60. loss.backward()
    61. optimizer.step()
    62. # print statistics
    63. running_loss += loss.item()
    64. # print train process
    65. rate = (step+1)/len(train_loader)
    66. a = "*" * int(rate * 50)
    67. b = "." * int((1 - rate) * 50)
    68. print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")
    69. print()
    70. # validate
    71. net.eval()
    72. acc = 0.0 # accumulate accurate number / epoch
    73. with torch.no_grad():
    74. for val_data in validate_loader:
    75. val_images, val_labels = val_data
    76. outputs = net(val_images.to(device)) # eval model only have last output layer
    77. # loss = loss_function(outputs, test_labels)
    78. predict_y = torch.max(outputs, dim=1)[1]
    79. acc += (predict_y == val_labels.to(device)).sum().item()
    80. val_accurate = acc / val_num
    81. if val_accurate > best_acc:
    82. best_acc = val_accurate
    83. torch.save(net.state_dict(), save_path)
    84. print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %
    85. (epoch + 1, running_loss / step, val_accurate))
    86. print('Finished Training')

  • 相关阅读:
    leetcode1537. 最大得分(动态规划-java)
    防火墙nat实验
    组网神器WireGuard安装与配置教程(超详细)
    js_for循环实操
    华特迪士尼公司与DivX签署知识产权许可协议
    Linux:shell编程1(内含:1.shell简介+2.shell实操+3.shell的变量介绍+4.shell变量的定义)
    航天航空VR科普展VR太空科技馆沉浸式遨游体验
    32:第三章:开发通行证服务:15:浏览器存储介质,简介;(cookie,Session Storage,Local Storage)
    贪心算法一:最优装载问题
    qml里使用组件的案例
  • 原文地址:https://blog.csdn.net/weixin_71719718/article/details/138165753