• pytorch学习(二):transforms使用


    作用:将特定的图片通过transforms工具处理,得到我们想要的结果。

    ToTensor写成如下会报错:

    1. img=cv2.imread('./data/train/ants/0013035.jpg')
    2. tensor_img=transforms.ToTensor(img)
    3. print(tensor_img)

    正确形式:

    1. img=cv2.imread('./data/train/ants/0013035.jpg')
    2. tensor_trans=transforms.ToTensor()
    3. tensor_img=tensor_trans(img)
    4. print(tensor_img)

     原因:

    PyTorchtorchvision.transforms模块中,ToTensor是一个用于图像预处理的类,而不是一个函数。因此,你不能直接调用transforms.ToTensor(img)来转换图像,因为这不是这个类设计的使用方式。

    ToTensor类是用来创建一个预处理对象的,该对象可以被调用以将PIL图像或NumPy ndarray转换为PyTorch张量(tensor)。

    正确的使用方式是:

    1. 首先,实例化ToTensor类,得到一个转换对象。
    2. 然后,使用该转换对象来转换图像。

    遇见小错误:tensorboard无法显示图片,原因:忘记 writer.close()

    __call__函数和普通函数的区别:

    1. class Person:
    2. def __call__(self,name):
    3. print('__call__'+' hello'+name)
    4. def hello(self,name):
    5. print(' hello' + name)
    6. person=Person()
    7. person('cool')
    8. person.hello('cool')

    在Python的类中,__call__函数和forward函数(后者并不是Python的一个内置特殊方法,但经常在神经网络或计算图中见到)有不同的调用规则。

    __call__ 函数

    __call__ 是一个特殊方法,允许类的实例像函数那样被调用。当你尝试调用一个类的实例时,Python会自动查找并调用该实例的__call__方法。

    1. class CallableClass:
    2. def __call__(self, *args, **kwargs):
    3. print("CallableClass is being called!")
    4. print(f"Arguments: {args}, Keyword arguments: {kwargs}")
    5. # 创建一个实例
    6. instance = CallableClass()
    7. # 调用实例,就像调用一个函数
    8. instance(1, 2, 3, a=4, b=5)
    9. # 输出:
    10. # CallableClass is being called!
    11. # Arguments: (1, 2, 3), Keyword arguments: {'a': 4, 'b': 5}

    forward 函数

    forward 函数通常不是Python内置的一部分,但在某些库(如PyTorch)中,它被用作神经网络模块(如nn.Module)的前向传播方法。当你定义一个继承自nn.Module的类时,你通常需要实现一个forward方法,该方法描述了数据通过网络的前向传播。

    1. import torch
    2. import torch.nn as nn
    3. import torch.nn.functional as F
    4. class SimpleNeuralNet(nn.Module):
    5. def __init__(self):
    6. super(SimpleNeuralNet, self).__init__()
    7. self.fc = nn.Linear(10, 1) # 一个简单的全连接层
    8. def forward(self, x):
    9. # 定义前向传播
    10. x = F.relu(self.fc(x))
    11. return x
    12. # 创建一个实例
    13. net = SimpleNeuralNet()
    14. # 创建一个随机的输入张量
    15. input_tensor = torch.randn(1, 10)
    16. # 调用网络实例,就像调用一个函数
    17. output_tensor = net(input_tensor) # 这里实际上调用的是 net.forward(input_tensor)
    18. # 但我们通常不需要显式地调用 forward 方法

    Randomcrop函数:增加样本数据,用于学习。

    1. #RandomCrop
    2. trans_random=transforms.RandomCrop(512)
    3. trans_compose_2=transforms.Compose([trans_random,tensor_trans])
    4. for i in range(10):
    5. img_crop=trans_compose_2(img_PIL)
    6. writer.add_image('RandomCrop',img_crop,i)

    所有代码:

    1. from torchvision import transforms
    2. import cv2
    3. from torch.utils.tensorboard import SummaryWriter
    4. from PIL import Image
    5. writer=SummaryWriter("logs")
    6. img=cv2.imread('./data/train/ants/0013035.jpg')
    7. #ToTensor
    8. tensor_trans=transforms.ToTensor()
    9. img_tensor=tensor_trans(img)
    10. writer.add_image('tensor_img',img_tensor)
    11. #Normalize
    12. print(img_tensor[0][0][0])
    13. tensor_norm=transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])#main均值,std标准差
    14. img_norm=tensor_norm(img_tensor)#output[channel] = (input[channel] - mean[channel]) / std[channel]
    15. print(img_norm[0][0][0])
    16. writer.add_image('Normalize',img_norm,0)
    17. #Resize 输入必须是PIL格式或者是tensor格式
    18. img_PIL=Image.open('./data/train/ants/0013035.jpg')
    19. print(img_PIL.size)
    20. trans_resize=transforms.Resize((512,512))#高和宽
    21. img_resize=trans_resize(img_PIL)#返回值仍然是PIL格式或者tensor格式
    22. img_resize=tensor_trans(img_resize)
    23. writer.add_image('Resize',img_resize,0)
    24. print(img_resize)
    25. trans_resize_2=transforms.Resize(512)#只输入一个参数代表最短的那个边输出的像素点的数量
    26. trans_compose=transforms.Compose([trans_resize_2,tensor_trans])
    27. img_resize_2=trans_compose(img_PIL)
    28. writer.add_image('Resize_2',img_resize_2,0)
    29. #RandomCrop
    30. trans_random=transforms.RandomCrop(512)
    31. trans_compose_2=transforms.Compose([trans_random,tensor_trans])
    32. for i in range(10):
    33. img_crop=trans_compose_2(img_PIL)
    34. writer.add_image('RandomCrop',img_crop,i)
    35. writer.close()

  • 相关阅读:
    centos或aws linux部署java应用,环境搭建shell
    手把手教你做K均值聚类分析
    矩阵上下翻转
    Linux基础教程:10、进程通讯(管道通讯)
    (mac M1)Flutter环境搭建
    windows11系统封装教程
    面试题 ⑥
    C#操作PPT动画窗格并插入音频文件的一些思路
    《Happy Birthday》游戏开发记录(送给朋友的小礼物)
    Git系列之移动文件
  • 原文地址:https://blog.csdn.net/weixin_52307528/article/details/138768126