• pytorch入门,deep-learning-for-image-processing-master的test3-vggnet


    首先是train部分

    import os
    import sys
    import json

    import torch
    import torch.nn as nn
    from torchvision import transforms, datasets
    import torch.optim as optim
    from tqdm import tqdm

    from model import vgg


    def main():
        """Train VGG-16 on the 5-class flower dataset and save the best checkpoint."""
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print("using {} device.".format(device))

        # Per-split preprocessing: random crop/flip augmentation for training,
        # plain resize for validation; both normalized to roughly [-1, 1].
        data_transform = {
            "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ToTensor(),
                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
            "val": transforms.Compose([transforms.Resize((224, 224)),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

        data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
        image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
        assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

        train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                             transform=data_transform["train"])
        train_num = len(train_dataset)

        # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
        flower_list = train_dataset.class_to_idx
        cla_dict = dict((val, key) for key, val in flower_list.items())
        # write dict into json file so predict.py can map indices back to names
        json_str = json.dumps(cla_dict, indent=4)
        with open('class_indices.json', 'w') as json_file:
            json_file.write(json_str)

        batch_size = 32
        nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
        print('Using {} dataloader workers every process'.format(nw))

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size, shuffle=True,
                                                   num_workers=nw)

        validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                                transform=data_transform["val"])
        val_num = len(validate_dataset)
        validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                      batch_size=batch_size, shuffle=False,
                                                      num_workers=nw)
        print("using {} images for training, {} images for validation.".format(train_num,
                                                                               val_num))
        # test_data_iter = iter(validate_loader)
        # test_image, test_label = test_data_iter.next()

        model_name = "vgg16"
        net = vgg(model_name=model_name, num_classes=5, init_weights=True)
        net.to(device)
        loss_function = nn.CrossEntropyLoss()
        optimizer = optim.Adam(net.parameters(), lr=0.0001)

        epochs = 30
        best_acc = 0.0
        save_path = './{}Net.pth'.format(model_name)
        train_steps = len(train_loader)
        for epoch in range(epochs):
            # train
            net.train()
            running_loss = 0.0
            train_bar = tqdm(train_loader, file=sys.stdout)
            for step, data in enumerate(train_bar):
                images, labels = data
                optimizer.zero_grad()
                outputs = net(images.to(device))
                loss = loss_function(outputs, labels.to(device))
                loss.backward()
                optimizer.step()

                # print statistics
                running_loss += loss.item()
                train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                         epochs,
                                                                         loss)

            # validate
            net.eval()
            acc = 0.0  # accumulate accurate number / epoch
            with torch.no_grad():
                val_bar = tqdm(validate_loader, file=sys.stdout)
                for val_data in val_bar:
                    val_images, val_labels = val_data
                    outputs = net(val_images.to(device))
                    predict_y = torch.max(outputs, dim=1)[1]
                    acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

            val_accurate = acc / val_num
            print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
                  (epoch + 1, running_loss / train_steps, val_accurate))

            if val_accurate > best_acc:
                best_acc = val_accurate
                torch.save(net.state_dict(), save_path)

        print('Finished Training')


    if __name__ == '__main__':
        main()
    def main():
        """Re-typed (study-notes) version of the training setup above.

        NOTE(review): this excerpt is incomplete — it stops after building the
        datasets; the loaders and training loop continue further down the page.
        """
        device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print("using {} device".format(device))

        # Training uses random crop + flip augmentation; validation only resizes.
        data_transform={
            "train":transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
            ]),
            "val":transforms.Compose([
                # fixed: Resize(224) resizes only the *short* side, so non-square
                # images keep their aspect ratio and cannot be batched (and the
                # classifier expects a 224x224 input). The reference listing
                # above uses an explicit (224, 224) target.
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
            ])
        }
        # Dataset root is resolved two levels above the working directory.
        data_root=os.path.abspath(os.path.join(os.getcwd(),"../.."))
        image_path=os.path.join(data_root,"data_set","flower_data")
        assert os.path.exists(image_path), "{} path is not exist. ".format(image_path)
        train_dataset=datasets.ImageFolder(root=os.path.join(image_path,"train"),
                                           transform=data_transform["train"])
        val_dataset=datasets.ImageFolder(root=os.path.join(image_path,"val"),
                                         transform=data_transform["val"])
        train_num=len(train_dataset)
    这里和之前一样,需要注意绝对路径;使用 datasets.ImageFolder 加载按文件夹分类的数据集很方便

    flower_list=train_dataset.class_to_idx  # 用 class_to_idx 可以很方便地拿到“类别名 -> 数字标签”的映射
    cla_dict=dict((val,key) for key,val in flower_list.items())  # 反转后打包成“数字标签 -> 类别名”的字典

    # Persist the index -> class-name mapping so the predict script can use it.
    json_str=json.dumps(cla_dict,indent=4)
    with open('class_indices.json','w') as json_file:
        json_file.write(json_str)
    以上几行把类别字典写进 class_indices.json 文件里

    接下来构建模型、损失函数和优化器

    # Build VGG-16 for the 5 flower classes and set up loss/optimizer.
    model_name='vgg16'
    net=vgg(model_name=model_name,num_classes=5,init_weights=True)
    net.to(device)
    loss_function=nn.CrossEntropyLoss()
    # fixed: lr restored to 1e-4 to match the reference listing above
    # (this excerpt had mistyped it as 0.00001)
    optimizer=optim.Adam(net.parameters(),lr=0.0001)
    epochs=30
    best_acc=0.0
    # fixed: extension is '.pth', not '.path' — predict.py asserts on
    # the existence of './vgg16Net.pth'
    save_path='./{}Net.pth'.format(model_name)
    # fixed: stray trailing 'k' removed (syntax error in the transcript)
    train_steps=len(train_loader)
    
    开始训练
    

    # Standard PyTorch training loop: one optimizer step per mini-batch.
    for epoch in range(epochs):
        net.train()  # enable dropout layers for training
        running_loss=0.0
        train_bar=tqdm(train_loader,file=sys.stdout)  # progress bar over batches
        for step,data in enumerate(train_bar):
            images,labels=data
            optimizer.zero_grad()  # clear gradients accumulated by the previous step
            outputs=net(images.to(device))
            loss=loss_function(outputs,labels.to(device))
            loss.backward()   # backpropagate
            optimizer.step()  # update weights
            running_loss+=loss.item()
            # show the current batch loss in the progress-bar description
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)
        然后做测试

    # Validation: count correct predictions over the whole val set.
    net.eval()  # disable dropout for evaluation
    acc=0.0  # number of correctly classified validation images
    with torch.no_grad():
        # NOTE(review): this excerpt never defines val_loader (only val_dataset
        # above) — a DataLoader over val_dataset is presumably intended; the
        # reference listing calls it validate_loader.
        val_bar=tqdm(val_loader,file=sys.stdout)
        for val_data in val_bar:
            val_images,val_labels=val_data
            outputs=net(val_images.to(device))
            predict_y=torch.max(outputs,dim=1)[1]  # index of the max logit
            # fixed: the closing paren belongs after torch.eq(...); the
            # transcript compared predictions against val_labels.sum() instead
            # of counting element-wise matches
            acc=acc+torch.eq(predict_y,val_labels.to(device)).sum().item()

    val_acc=acc/val_num
    # fixed: stray trailing 'b' removed (syntax error in the transcript)
    print('[epoch%d] train_loss: %.3f val_acc:%.3f'%(epoch+1,running_loss/train_steps,val_acc))
     
    保存模型:
    # Checkpoint only when validation accuracy improves on the best so far.
    if val_acc>best_acc:
        best_acc=val_acc
        torch.save(net.state_dict(),save_path)

    然后是预测脚本(predict.py):

    import os
    import json

    import torch
    from PIL import Image
    from torchvision import transforms
    import matplotlib.pyplot as plt

    from model import vgg


    def main():
        """Classify a single image with a trained VGG-16 and display the result."""
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Same deterministic preprocessing as the validation split at train time.
        data_transform = transforms.Compose(
            [transforms.Resize((224, 224)),
             transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        # load image
        img_path = "../tulip.jpg"
        assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
        img = Image.open(img_path)
        plt.imshow(img)
        # [N, C, H, W]
        img = data_transform(img)
        # expand batch dimension
        img = torch.unsqueeze(img, dim=0)

        # read class_indict (index -> class-name mapping written by train.py)
        json_path = './class_indices.json'
        assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
        with open(json_path, "r") as f:
            class_indict = json.load(f)

        # create model
        model = vgg(model_name="vgg16", num_classes=5).to(device)
        # load model weights
        weights_path = "./vgg16Net.pth"
        assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
        model.load_state_dict(torch.load(weights_path, map_location=device))

        model.eval()
        with torch.no_grad():
            # predict class: squeeze the batch dim, softmax over class logits
            output = torch.squeeze(model(img.to(device))).cpu()
            predict = torch.softmax(output, dim=0)
            predict_cla = torch.argmax(predict).numpy()

        print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
                                                   predict[predict_cla].numpy())
        plt.title(print_res)
        for i in range(len(predict)):
            print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
                                                    predict[i].numpy()))
        plt.show()


    if __name__ == '__main__':
        main()

    正常操作
    def main():
        """Re-typed (study-notes) version: load one image and prepare a batched
        tensor for inference. The json/model steps continue further down."""
        # fixed: "cuda: 0" (with a space) is not a valid device string and
        # raises RuntimeError; the correct form is "cuda:0"
        device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        data_transform=transforms.Compose([
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
        ])

        img_path="../tulip.jpg"
        assert os.path.exists(img_path) ,"{} is not exist".format(img_path)
        img=Image.open(img_path)
        # fixed: plt.show(img) does not draw the image; imshow does
        plt.imshow(img)
        # fixed: the transform must be *called* on the image, not assigned
        img=data_transform(img)
        # add batch dimension -> [1, C, H, W]
        img=torch.unsqueeze(img,dim=0)

    开始读json文件

    # Map predicted indices back to human-readable class names
    # (file written by the training script).
    json_path='./class_indices.json'
    assert os.path.exists(json_path),"{} is not exist".format(json_path)
    with open(json_path,"r") as f:
        class_indict=json.load(f)
        

    使用模型开始测试

    # Build the model, load trained weights, and run a single forward pass.
    # fixed: "vgg_16" is not a key in the cfgs dict (it is "vgg16"), and the
    # VGG constructor's keyword is num_classes, not num_class
    model=vgg(model_name="vgg16",num_classes=5).to(device)
    weights_path="./vgg16Net.pth"
    assert  os.path.exists(weights_path),"{} is not exist".format(weights_path)
    model.load_state_dict(torch.load(weights_path,map_location=device))
    model.eval()  # disable dropout for inference
    with torch.no_grad():
        output=model(img.to(device))
        output=torch.squeeze(output).cpu()       # drop the batch dimension
        predict=torch.softmax(output,dim=0)      # logits -> class probabilities
        predict_cla=torch.argmax(predict).numpy()  # index of the top class

    接下来是model

    import torch.nn as nn
    import torch

    # official pretrain weights
    model_urls = {
        'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
        'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
        'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
        'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
    }


    class VGG(nn.Module):
        """VGG: conv feature extractor (``features``) + 3-layer FC classifier."""

        def __init__(self, features, num_classes=1000, init_weights=False):
            super(VGG, self).__init__()
            self.features = features
            # Classifier input is fixed at 512*7*7, i.e. a 224x224 input image.
            self.classifier = nn.Sequential(
                nn.Linear(512*7*7, 4096),
                nn.ReLU(True),
                nn.Dropout(p=0.5),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(p=0.5),
                nn.Linear(4096, num_classes)
            )
            if init_weights:
                self._initialize_weights()

        def forward(self, x):
            # N x 3 x 224 x 224
            x = self.features(x)
            # N x 512 x 7 x 7
            x = torch.flatten(x, start_dim=1)
            # N x 512*7*7
            x = self.classifier(x)
            return x

        def _initialize_weights(self):
            """Xavier-initialize all conv and linear layers; zero the biases."""
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    nn.init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
                    # nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.constant_(m.bias, 0)


    def make_features(cfg: list):
        """Build the conv stack from a cfg list: ints are 3x3 conv out-channels
        (each followed by ReLU), "M" is a 2x2 stride-2 max-pool."""
        layers = []
        in_channels = 3
        for v in cfg:
            if v == "M":
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                layers += [conv2d, nn.ReLU(True)]
                in_channels = v
        return nn.Sequential(*layers)


    cfgs = {
        'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
        'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
        'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
        'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
    }


    def vgg(model_name="vgg16", **kwargs):
        """Instantiate a VGG variant by name; kwargs forwarded to VGG.__init__."""
        assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
        cfg = cfgs[model_name]
        model = VGG(make_features(cfg), **kwargs)
        return model
    class VGG(nn.Module):
        """Re-typed VGG skeleton: feature extractor + 3-layer FC classifier.

        NOTE(review): this copy calls self._initialize_weights() but the method
        is not defined in this excerpt (it appears only in the full listing
        above) — constructing with init_weights=True would raise AttributeError
        here. The forward method is also re-typed separately further down.
        """
        def __init__(self,features,num_classes=1000,init_weights=False):
            super(VGG,self).__init__()
            # `features` is the conv stack built by make_features(cfg)
            self.features=features
            # Classifier input is fixed at 512*7*7, i.e. a 224x224 input image.
            self.classifier=nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(p=0.5),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(p=0.5),
                nn.Linear(4096, num_classes)
            )
            if init_weights:
                self._initialize_weights()

    调整维度,扁平化成线性的

    def forward(self,x):
        """Run the conv features, flatten, then classify.

        NOTE(review): re-typed at top level here, but this is a method of the
        VGG class above (it reads self.features / self.classifier).
        """
        x=self.features(x)
        # flatten everything except the batch dimension: N x (512*7*7)
        x=torch.flatten(x,start_dim=1)
        x=self.classifier(x)
        return x
    def make_features(cfg:list):
        """Translate a VGG cfg list into an nn.Sequential feature extractor.

        Each numeric entry becomes a 3x3 same-padding convolution followed by
        an in-place ReLU; the string "M" becomes a 2x2 stride-2 max-pool.
        """
        modules=[]
        channels=3
        for item in cfg:
            if item=="M":
                modules.append(nn.MaxPool2d(kernel_size=2,stride=2))
            else:
                modules.extend([nn.Conv2d(channels,item,kernel_size=3,padding=1),
                                nn.ReLU(True)])
                channels=item
        return nn.Sequential(*modules)
  • 相关阅读:
    Shiro学习与笔记
    【力扣2154】将找到的值乘以 2
    Mobile App自动化测试技术及实现
    肖sir__mysql之索引__010
    (实用)页面在线QQ咨询html代码
    虚拟机安装Ubuntu太慢解决办法
    SAP ABAP Function Module 的动态调用方式使用方式介绍试读版
    每日学习2
    div中包含checkbox 点击事件重复问题
    制造业企业为什么需要数字化转型
  • 原文地址:https://blog.csdn.net/kling_bling/article/details/126445713