• 6. VGG16: a classic CNN network model explained (PyTorch implementation)


    1. First, create a model.py file containing the network definition:

    import torch
    import torch.nn as nn

    class My_VGG16(nn.Module):
        def __init__(self, num_classes=5, init_weight=True):
            super(My_VGG16, self).__init__()
            # Feature extraction: the 13 convolutional layers of VGG16,
            # grouped into 5 blocks, each ending in 2x2 max pooling.
            # VGG16 applies a ReLU activation after every convolution.
            self.features = nn.Sequential(
                # Block 1: 3 -> 64
                nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
                # Block 2: 64 -> 128
                nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
                # Block 3: 128 -> 256
                nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
                # Block 4: 256 -> 512
                nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
                # Block 5: 512 -> 512
                nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            )
            # Classifier: the five poolings reduce a 224x224 input to 7x7x512.
            self.classifier = nn.Sequential(
                nn.Linear(in_features=7*7*512, out_features=4096),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(in_features=4096, out_features=4096),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(in_features=4096, out_features=num_classes),
            )
            # Weight initialization
            if init_weight:
                for m in self.modules():
                    if isinstance(m, nn.Conv2d):
                        # Kaiming initialization for conv layers
                        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                        # Zero the bias if present
                        if m.bias is not None:
                            nn.init.constant_(m.bias, 0)
                    elif isinstance(m, nn.Linear):
                        # Normal initialization for linear layers
                        nn.init.normal_(m.weight, 0, 0.01)
                        # Zero the bias
                        nn.init.constant_(m.bias, 0)

        def forward(self, x):
            x = self.features(x)
            x = torch.flatten(x, 1)  # flatten all dims except batch
            result = self.classifier(x)
            return result
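
    As a quick sanity check, a minimal sketch (the batch size of 2 and the __main__ guard are illustrative choices, not part of the training pipeline) can be appended to the bottom of model.py: a random 224x224 batch should yield one logit per class, since five 2x2 poolings reduce 224x224 down to 7x7.

    if __name__ == "__main__":
        model = My_VGG16(num_classes=5)
        dummy = torch.randn(2, 3, 224, 224)  # two random RGB "images"
        out = model(dummy)
        print(out.shape)                     # expected: torch.Size([2, 5])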

    2. Download the dataset. The flower_photos archive is available at:

    DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'
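
    The archive can be fetched in a browser, or with a short script such as the sketch below. Note that the flower_data/ target folder is an assumption chosen to match the paths used by spile_data.py in the next step.

    # download_data.py -- sketch: fetch and extract the flower dataset.
    import os
    import tarfile
    import urllib.request

    DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'

    os.makedirs('flower_data', exist_ok=True)
    archive = 'flower_data/flower_photos.tgz'
    if not os.path.exists(archive):
        urllib.request.urlretrieve(DATA_URL, archive)  # download the .tgz
    with tarfile.open(archive) as tar:
        tar.extractall('flower_data')  # creates flower_data/flower_photos/
    print("done")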
    

    3. After downloading, write a spile_data.py script that splits the dataset into training and validation sets:

    # spile_data.py
    import os
    from shutil import copy
    import random

    def mkfile(file):
        if not os.path.exists(file):
            os.makedirs(file)

    file = 'flower_data/flower_photos'
    # Class names are the sub-folder names; skip the LICENSE .txt file
    flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla]

    # Create one sub-folder per class under train/ and val/
    mkfile('flower_data/train')
    for cla in flower_class:
        mkfile('flower_data/train/' + cla)
    mkfile('flower_data/val')
    for cla in flower_class:
        mkfile('flower_data/val/' + cla)

    split_rate = 0.1  # 10% of each class goes to the validation set
    for cla in flower_class:
        cla_path = file + '/' + cla + '/'
        images = os.listdir(cla_path)
        num = len(images)
        eval_index = random.sample(images, k=int(num * split_rate))
        for index, image in enumerate(images):
            image_path = cla_path + image
            if image in eval_index:
                new_path = 'flower_data/val/' + cla
            else:
                new_path = 'flower_data/train/' + cla
            copy(image_path, new_path)
            print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # progress bar
        print()
    print("processing done!")

    After the split, the flower_data directory should contain the original flower_photos folder plus train/ and val/ folders, each holding one sub-folder per flower class.
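
    To verify the split, a small counting sketch (the paths match the script above):

    # check_split.py -- sketch: count images per class after splitting.
    import os

    for subset in ('train', 'val'):
        root = os.path.join('flower_data', subset)
        for cla in sorted(os.listdir(root)):
            n = len(os.listdir(os.path.join(root, cla)))
            print("{}/{}: {} images".format(subset, cla, n))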

    4. Finally, write a train.py script to train the model:

    import torch
    import torch.nn as nn
    from torchvision import transforms, datasets
    import json
    import os
    import torch.optim as optim
    from model import My_VGG16

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # ImageNet mean/std normalization; random crop + flip for train-time augmentation
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    data_root = os.getcwd()  # get data root path
    image_path = data_root + "/flower_data/"  # flower data set path
    train_dataset = datasets.ImageFolder(root=image_path + "train",
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write the index-to-class dict into a json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=0)
    validate_dataset = datasets.ImageFolder(root=image_path + "val",
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=0)

    net = My_VGG16(num_classes=5)  # trained from scratch; no pretrained weights are loaded
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    best_acc = 0.0
    save_path = './vgg16_train.pth'
    for epoch in range(5):
        # train
        net.train()
        running_loss = 0.0
        for step, data in enumerate(train_loader, start=0):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()
            # accumulate statistics
            running_loss += loss.item()
            # print train progress
            rate = (step + 1) / len(train_loader)
            a = "*" * int(rate * 50)
            b = "." * int((1 - rate) * 50)
            print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate * 100), a, b, loss), end="")
        print()

        # validate
        net.eval()
        acc = 0.0  # accumulate the number of correct predictions per epoch
        with torch.no_grad():
            for val_data in validate_loader:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += (predict_y == val_labels.to(device)).sum().item()
            val_accurate = acc / val_num
            # keep the checkpoint with the best validation accuracy
            if val_accurate > best_acc:
                best_acc = val_accurate
                torch.save(net.state_dict(), save_path)
            print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
                  (epoch + 1, running_loss / (step + 1), val_accurate))
    print('Finished Training')
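
    After training, the best weights sit in vgg16_train.pth and the index-to-class mapping in class_indices.json. The sketch below shows one way to run single-image inference with them; the file name predict.py and the image path test.jpg are placeholders, not part of the original pipeline.

    # predict.py -- sketch: classify a single image with the trained weights.
    import json
    import torch
    from PIL import Image
    from torchvision import transforms
    from model import My_VGG16

    transform = transforms.Compose([transforms.Resize(256),
                                    transforms.CenterCrop(224),
                                    transforms.ToTensor(),
                                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # json keys were serialized as strings ("0".."4")
    with open('class_indices.json') as f:
        class_indict = json.load(f)

    net = My_VGG16(num_classes=5)
    net.load_state_dict(torch.load('./vgg16_train.pth', map_location='cpu'))
    net.eval()

    img = transform(Image.open('test.jpg').convert('RGB')).unsqueeze(0)  # add batch dim
    with torch.no_grad():
        probs = torch.softmax(net(img), dim=1).squeeze(0)
        idx = int(torch.argmax(probs))
    print(class_indict[str(idx)], float(probs[idx]))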

• Source: https://blog.csdn.net/weixin_71719718/article/details/138165753