目录
在这个例子中,每张图片所在的上一级目录的名称(例如 cats)就是该目录下所有图片共同的label
- import os
- from torch.utils.data import Dataset
- from PIL import Image
-
class MyDataset(Dataset):
    """Dataset over one class folder whose folder name is the label.

    Every entry found in ``root_dir/label_dir`` is treated as one sample, and
    ``label_dir`` itself is returned as the label of each sample.
    """

    def __init__(self, root_dir, label_dir):
        """
        :param root_dir: root directory that contains the class folders
        :param label_dir: name of the class folder; also used as the label
        """
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(root_dir, label_dir)
        # os.listdir() returns entries in arbitrary order; sort them so that
        # a given index refers to the same file on every run and platform.
        self.image_path_list = sorted(os.listdir(self.path))

    def __getitem__(self, idx):
        """
        :param idx: index of the image inside the class folder
        :return: ``(image, label)`` - the object returned by ``Image.open``
            (it can be shown with ``image.show()`` or processed further) and
            the class folder name
        """
        img_name = self.image_path_list[idx]
        image = Image.open(os.path.join(self.path, img_name))
        return image, self.label_dir

    def __len__(self):
        """Return the number of entries in the class folder."""
        return len(self.image_path_list)
-
# Demo: open the first picture of the 'cats' folder and print its label.
# A raw string keeps the backslashes of the Windows path from being read as
# escape sequences; the resulting value is identical to the original literal.
root_dir = r'G:\python_files\深度学习代码库\cats_and_dogs_small\train'
label_dir = 'cats'
my_data = MyDataset(root_dir, label_dir)
first_pic, label = my_data[0]  # indexing calls __getitem__(self, idx)
first_pic.show()
print("当前图片中动物所属label", label)
F:\Anaconda\envs\py38\python.exe G:/python_files/深度学习代码库/dataset/MyDataSet.py
当前图片中动物所属label cats
-
- import os
- from torch.utils.data import Dataset
- from PIL import Image
-
class MyLabelData:
    """Generate one label file per image of a class folder.

    For every entry in ``root_dir/target_dir`` a text file is written into
    ``root_dir/label_dir``. The file is named after the image (everything
    before the first ``.jpg``) and contains ``label_name``.
    """

    def __init__(self, root_dir, target_dir, label_dir, label_name):
        """
        :param root_dir: root directory
        :param target_dir: folder inside ``root_dir`` holding the images to label
        :param label_dir: folder inside ``root_dir`` that receives the label files
        :param label_name: text written into every label file
        """
        self.root_dir = root_dir
        self.target_dir = target_dir
        self.label_dir = label_dir
        self.label_name = label_name
        self.image_name_list = os.listdir(os.path.join(root_dir, target_dir))

    def label(self):
        """Write a label file for every image name collected in ``__init__``."""
        label_path = os.path.join(self.root_dir, self.label_dir)
        # The target folder does not change per image, so create it once;
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(label_path, exist_ok=True)
        for name in self.image_name_list:
            file_name = name.split(".jpg", 1)[0]
            # The with-block closes the file; no explicit close() is needed.
            with open(os.path.join(label_path, file_name), 'w') as f:
                f.write(self.label_name)
# Demo: write a 'cat' label file into train/cats_label for every picture in
# train/cats. A raw string keeps the backslashes of the Windows path from
# being read as escape sequences; the value equals the original literal.
root_dir = r'G:\python_files\深度学习代码库\cats_and_dogs_small\train'
target_dir = 'cats'
label_dir = 'cats_label'
label_name = 'cat'
label = MyLabelData(root_dir, target_dir, label_dir, label_name)
label.label()
这样,上面代码中训练集目录(train/cats)下的每一个样本,都会在train的cats_label目录下生成一个与之对应的分类标签文件
每一个标签文件中都写有一个cat字符串(或者其他动物的分类名称),用来确定该样本到底是哪一种动物
# View the results with:
#   tensorboard --logdir=深度学习代码库/logs --port=2001
import numpy as np
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('logs')

# Scalar curve y = 3 * x for x in [0, 100).
for i in range(100):
    writer.add_scalar('当前的函数表达式y=3*x', i * 3, i)

# -----------------------------------------------------------
# Log a picture from a numpy array. The raw string keeps the backslashes of
# the Windows path from being read as escape sequences (same value as before).
image_PIL = Image.open(r'G:\python_files\深度学习代码库\cats_and_dogs_small\train\cats\cat.1.jpg')
image_numpy = np.array(image_PIL)
print(type(image_numpy))
print(image_numpy.shape)
# The array is passed as height x width x channel, so dataformats must say
# 'HWC' instead of the default 'CHW'.
writer.add_image('cat图片', image_numpy, 2, dataformats='HWC')

# Close only after the last record so that the image event is flushed as well.
writer.close()
这里使用tensorboard是为了更好地展示数据。对于函数的使用,比如上面add_image中的参数,最好的方式是点击源码查看其对应的参数类型,然后根据实际需要把它所需的数据类型传给add_image。源码中该函数要求图片必须是tensor类型或者numpy数组,所以想用tensorboard展示图片,必须先用numpy(np.array)或者transforms.ToTensor把图片转化为numpy数组或tensor,然后再传给add_image函数
还有一点需要注意:使用add_image函数时,图片的tensor或numpy数组的维度排列必须和dataformats参数一致(默认是CHW);如果不一致(例如np.array得到的是HWC),就根据图片实际的维度排列修改后面的dataformats参数
import numpy as np
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

# ToTensor turns the PIL picture into a tensor that add_image accepts directly.
tran = transforms.ToTensor()
# Raw string: the backslashes of the Windows path are never read as escape
# sequences; the value is identical to the original literal.
PIL_image = Image.open(r'G:\python_files\深度学习代码库\cats\cat\cat.11.jpg')
tensor_pic = tran(PIL_image)
print(tensor_pic)
print(tensor_pic.shape)

# The writer is not closed here; later snippets keep logging through `write`.
write = SummaryWriter('logs')
write.add_image('Tensor_picture', tensor_pic)
tensor([[[0.9216, 0.9059, 0.8353, ..., 0.2392, 0.2275, 0.2078],
[0.9765, 0.9216, 0.8118, ..., 0.2431, 0.2392, 0.2235],
[0.9490, 0.8745, 0.7608, ..., 0.2471, 0.2471, 0.2314],
...,
[0.3490, 0.4902, 0.6667, ..., 0.7804, 0.7804, 0.7804],
[0.3412, 0.4431, 0.5216, ..., 0.7765, 0.7922, 0.7882],
[0.3490, 0.4510, 0.5294, ..., 0.7765, 0.7922, 0.7882]],
[[0.9451, 0.9294, 0.8706, ..., 0.2980, 0.2863, 0.2667],
[1.0000, 0.9451, 0.8471, ..., 0.3020, 0.2980, 0.2824],
[0.9725, 0.8980, 0.7961, ..., 0.2980, 0.2980, 0.2824],
...,
[0.3725, 0.5137, 0.6902, ..., 0.8431, 0.8431, 0.8431],
[0.3647, 0.4667, 0.5451, ..., 0.8392, 0.8549, 0.8510],
[0.3608, 0.4627, 0.5412, ..., 0.8392, 0.8549, 0.8510]],
[[0.9294, 0.9137, 0.8588, ..., 0.2235, 0.2118, 0.1922],
[0.9922, 0.9373, 0.8353, ..., 0.2275, 0.2235, 0.2078],
[0.9725, 0.8980, 0.7922, ..., 0.2275, 0.2275, 0.2118],
...,
[0.4196, 0.5608, 0.7373, ..., 0.9412, 0.9412, 0.9333],
[0.4196, 0.5216, 0.6000, ..., 0.9373, 0.9529, 0.9412],
[0.4196, 0.5216, 0.6000, ..., 0.9373, 0.9529, 0.9412]]])
torch.Size([3, 410, 431])
# Normalize works per channel and takes one mean and one std for each of the
# three channels:
#   output[channel] = (input[channel] - mean[channel]) / std[channel]
nor = transforms.Normalize(
    [0.5, 0.5, 0.5],
    [10, 0.5, 0.5],
)
print(tensor_pic[0][0][0])
x_nor = nor(tensor_pic)
write.add_image('nor_picture:', x_nor)
# Same element printed again, for comparison with the value printed before
# the normalization was applied.
print(tensor_pic[0][0][0])
write.close()
打开源码查看(Normalize.forward),可以看到必须传入的是tensor数据类型:
def forward(self, tensor: Tensor) -> Tensor:
    """
    Args:
        tensor (Tensor): Tensor image to be normalized.

    Returns:
        Tensor: Normalized Tensor image.
    """
    return F.normalize(tensor, self.mean, self.std, self.inplace)
# Resize rescales its input to the requested size (it does not crop) and is
# applied here both to a tensor and to a PIL image.
target_size = (512, 512)
size_tensor = transforms.Resize(target_size)
size_pic = transforms.Resize(target_size)

# Resize the PIL image and keep a numpy copy of it for TensorBoard.
image_size = size_pic(PIL_image)
np_image = np.array(image_size)
# Resize the tensor.
tensor_pic_size = size_tensor(tensor_pic)

print(image_size)
write.add_image('tensor_pic_size', tensor_pic_size)
print(tensor_pic_size.shape)
print('np_image.shape:', np_image.shape)
# The numpy copy is height x width x channel, hence dataformats='HWC'.
write.add_image('image_size', np_image, dataformats='HWC')
调用Resize的时候,对传入数据类型的要求(PIL Image或Tensor)可以查看源码,如下
def forward(self, img):
    """
    Args:
        img (PIL Image or Tensor): Image to be scaled.

    Returns:
        PIL Image or Tensor: Rescaled image.
    """
    return F.resize(img, self.size, self.interpolation)
torch.Size([3, 512, 512])
np_image.shape: (512, 512, 3)
nor = transforms.Normalize([0.5, 0.5, 0.5], [10, 0.5, 0.5])
# Compose applies its transforms in list order: first resize the PIL image to
# 64x64, then convert the result to a tensor.
trans_resize_2 = transforms.Resize((64, 64))
trans_to_tensor = transforms.ToTensor()
trans_compose = transforms.Compose([trans_resize_2, trans_to_tensor])
tensor_pic_compose = trans_compose(PIL_image)
write.add_image('tensor_pic_compose', tensor_pic_compose, dataformats='CHW')
# Last record written through this writer: close it so the pending image
# events are flushed to the event file.
write.close()
class Compose: """Composes several transforms together. This transform does not support torchscript. Please, see the note below. Args: transforms (list of ``Transform`` objects): list of transforms to compose. Example: >>> transforms.Compose([ >>> transforms.CenterCrop(10), >>> transforms.ToTensor(), >>> ]) .. note:: In order to script the transformations, please use ``torch.nn.Sequential`` as below. >>> transforms = torch.nn.Sequential( >>> transforms.CenterCrop(10), >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), >>> ) >>> scripted_transforms = torch.jit.script(transforms) Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require `lambda` functions or ``PIL.Image``. """ def __init__(self, transforms): self.transforms = transforms def __call__(self, img): for t in self.transforms: img = t(img) return img def __repr__(self): format_string = self.__class__.__name__ + '(' for t in self.transforms: format_string += '\n' format_string += ' {0}'.format(t) format_string += '\n)' return format_string
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

data_transform = transforms.Compose([transforms.ToTensor()])
# No transform is passed here, so samples keep the form the dataset stores.
train_data, test_data = (
    torchvision.datasets.CIFAR10('./data', train=is_train, download=True)
    for is_train in (True, False)
)
print("train_data", train_data)

# Every sample of the raw dataset holds one picture and the class it belongs to.
first_sample = train_data[0]
print("train_data[0]", first_sample)
print("train_data.classes", train_data.classes)
image, label = first_sample
print("label ", label)
image.show()
# The label is an index into train_data.classes.
print("train_data.classes[label]", train_data.classes[label])
train_data Dataset CIFAR10
Number of datapoints: 50000
Root location: ./data
Split: Train
train_data[0] (
train_data.classes ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
label 6
train_data.classes[label] frog
#%%
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

# Convert every sample of the dataset to a tensor.
data_transform1 = transforms.Compose([transforms.ToTensor()])
train_data, test_data1 = (
    torchvision.datasets.CIFAR10('./data', train=is_train, transform=data_transform1, download=True)
    for is_train in (True, False)
)

write = SummaryWriter('batch_picture')
# Log the first ten training pictures under one tag, one step per picture.
for i in range(10):
    # The transform above already turned the picture into a tensor.
    tensor_pic, label = train_data[i]
    print(tensor_pic.shape)
    write.add_image('batch_picture', tensor_pic, i)
write.close()
Files already downloaded and verified
Files already downloaded and verified
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'): """Add image data to summary. Note that this requires the ``pillow`` package. Args: tag (string): Data identifier img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data global_step (int): Global step value to record walltime (float): Optional override default walltime (time.time()) seconds after epoch of event Shape: img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to convert a batch of tensor into 3xHxW format or call ``add_images`` and let us do the job. Tensor with :math:`(1, H, W)`, :math:`(H, W)`, :math:`(H, W, 3)` is also suitable as long as corresponding ``dataformats`` argument is passed, e.g. ``CHW``, ``HWC``, ``HW``. Examples:: from torch.utils.tensorboard import SummaryWriter import numpy as np img = np.zeros((3, 100, 100)) img[0] = np.arange(0, 10000).reshape(100, 100) / 10000 img[1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000 img_HWC = np.zeros((100, 100, 3)) img_HWC[:, :, 0] = np.arange(0, 10000).reshape(100, 100) / 10000 img_HWC[:, :, 1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000 writer = SummaryWriter() writer.add_image('my_image', img, 0) # If you have non-default dimension setting, set the dataformats argument. writer.add_image('my_image_HWC', img_HWC, 0, dataformats='HWC') writer.close() Expected result: .. image:: _static/img/tensorboard/add_image.png :scale: 50 % """ torch._C._log_api_usage_once("tensorboard.logging.add_image") if self._check_caffe2_blob(img_tensor): from caffe2.python import workspace img_tensor = workspace.FetchBlob(img_tensor) self._get_file_writer().add_summary( image(tag, img_tensor, dataformats=dataformats), global_step, walltime)
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

# Convert every sample of the dataset to a tensor.
data_transform = transforms.Compose([transforms.ToTensor()])
train_data, test_data = (
    torchvision.datasets.CIFAR10('./data', train=is_train, transform=data_transform, download=True)
    for is_train in (True, False)
)

# DataLoader gathers the pictures of one batch together and their labels
# together, keeping both in matching order.
train_data_load = DataLoader(train_data, batch_size=64, shuffle=True)
write = SummaryWriter('dataLoad')

# One pass over the loader; each iteration handles one batch of up to 64 pictures.
for batch_id, data in enumerate(train_data_load):
    # Each batch is a pair: the stacked pictures and their class labels.
    print('data', data)
    batch_image, batch_label = data
    print("batch_id", batch_id)
    print("image.shape", batch_image.shape)
    print("label.shape", batch_label.shape)
    write.add_images('batch_load_picture', batch_image, batch_id, dataformats='NCHW')
write.close()
其中一个批次的输出结果展示 batch_id 646 image.shape torch.Size([64, 3, 32, 32]) label.shape torch.Size([64]) data [tensor([[[[0.2510, 0.3804, 0.5176, ..., 0.5529, 0.5451, 0.2980], [0.2706, 0.6000, 0.6667, ..., 0.5686, 0.3961, 0.1176], [0.2745, 0.6627, 0.6980, ..., 0.3961, 0.1608, 0.0824], ..., [0.6863, 0.6824, 0.5333, ..., 0.2941, 0.4863, 0.5059], [0.5804, 0.6784, 0.4902, ..., 0.1451, 0.2824, 0.3451], [0.4353, 0.4353, 0.5098, ..., 0.1373, 0.1529, 0.2902]], [[0.3020, 0.4549, 0.6078, ..., 0.6627, 0.6353, 0.3608], [0.3451, 0.6980, 0.7765, ..., 0.6745, 0.4706, 0.1647], [0.3490, 0.7529, 0.8039, ..., 0.4667, 0.2000, 0.1137], ..., [0.8196, 0.8157, 0.6157, ..., 0.3608, 0.5529, 0.5804], [0.7137, 0.8039, 0.5686, ..., 0.1922, 0.3373, 0.4078], [0.5412, 0.5333, 0.5765, ..., 0.1765, 0.2000, 0.3490]], [[0.3098, 0.5490, 0.7412, ..., 0.8314, 0.7373, 0.3765], [0.3765, 0.8392, 0.9569, ..., 0.7686, 0.4941, 0.1216], [0.3843, 0.9176, 1.0000, ..., 0.4627, 0.1490, 0.0588], ..., [0.9843, 0.9922, 0.7373, ..., 0.3882, 0.6353, 0.7255], [0.8039, 0.9373, 0.6745, ..., 0.1804, 0.3647, 0.4941], [0.6471, 0.6549, 0.6980, ..., 0.1569, 0.2000, 0.3961]]], [[[0.9608, 0.9490, 0.9529, ..., 0.8314, 0.8196, 0.8235], [0.9255, 0.9216, 0.9333, ..., 0.8275, 0.8196, 0.8235], [0.9137, 0.9137, 0.9294, ..., 0.8392, 0.8314, 0.8353], ..., [0.4118, 0.4353, 0.4431, ..., 0.4157, 0.4431, 0.4275], [0.4667, 0.4667, 0.4627, ..., 0.3961, 0.3804, 0.3882], [0.4392, 0.4235, 0.4235, ..., 0.5490, 0.4471, 0.4706]], [[0.9647, 0.9529, 0.9529, ..., 0.8745, 0.8667, 0.8667], [0.9294, 0.9255, 0.9333, ..., 0.8627, 0.8549, 0.8549], [0.9137, 0.9176, 0.9294, ..., 0.8627, 0.8588, 0.8549], ..., [0.4196, 0.4392, 0.4471, ..., 0.4314, 0.4627, 0.4510], [0.4745, 0.4745, 0.4706, ..., 0.4078, 0.4039, 0.4118], [0.4471, 0.4314, 0.4314, ..., 0.5608, 0.4667, 0.4863]], [[0.9765, 0.9686, 0.9647, ..., 0.9412, 0.9373, 0.9569], [0.9451, 0.9412, 0.9529, ..., 0.9216, 0.9216, 0.9373], [0.9451, 0.9451, 0.9569, ..., 0.9176, 0.9176, 0.9333], ..., [0.4078, 0.4314, 0.4353, ..., 
0.4353, 0.4706, 0.4588], [0.4627, 0.4627, 0.4588, ..., 0.4118, 0.4118, 0.4157], [0.4353, 0.4196, 0.4196, ..., 0.5569, 0.4627, 0.4863]]], [[[0.9569, 0.9569, 0.9647, ..., 0.8510, 0.8353, 0.8235], [0.9569, 0.9569, 0.9608, ..., 0.8627, 0.8431, 0.8392], [0.9804, 0.9725, 0.9725, ..., 0.8745, 0.8627, 0.8549], ..., [0.3725, 0.3882, 0.3922, ..., 0.3647, 0.3725, 0.3686], [0.3882, 0.4000, 0.4157, ..., 0.3882, 0.3804, 0.3608], [0.3882, 0.4000, 0.4118, ..., 0.3725, 0.3608, 0.3490]], [[0.9608, 0.9608, 0.9686, ..., 0.8706, 0.8549, 0.8392], [0.9608, 0.9608, 0.9686, ..., 0.8784, 0.8549, 0.8510], [0.9843, 0.9765, 0.9804, ..., 0.8863, 0.8745, 0.8627], ..., [0.3804, 0.3922, 0.3961, ..., 0.3255, 0.3529, 0.3686], [0.3961, 0.4078, 0.4235, ..., 0.3647, 0.3686, 0.3647], [0.3961, 0.4078, 0.4196, ..., 0.3843, 0.3686, 0.3569]], [[0.9843, 0.9765, 0.9804, ..., 0.9294, 0.9176, 0.9137], [0.9804, 0.9686, 0.9725, ..., 0.9216, 0.9059, 0.9098], [0.9961, 0.9804, 0.9765, ..., 0.9137, 0.9098, 0.9098], ..., [0.3725, 0.3882, 0.3922, ..., 0.2902, 0.3255, 0.3686], [0.3922, 0.4039, 0.4196, ..., 0.3412, 0.3490, 0.3608], [0.3922, 0.4039, 0.4157, ..., 0.3843, 0.3686, 0.3529]]], ..., [[[0.8902, 0.8863, 0.8824, ..., 0.8314, 0.8392, 0.8353], [0.8902, 0.8863, 0.8863, ..., 0.8353, 0.8431, 0.8392], [0.8902, 0.8863, 0.8902, ..., 0.8392, 0.8431, 0.8431], ..., [0.9569, 0.9529, 0.9569, ..., 0.5765, 0.5843, 0.5961], [0.9686, 0.9647, 0.9608, ..., 0.9412, 0.9255, 0.9255], [0.9804, 0.9765, 0.9725, ..., 0.9255, 0.9176, 0.9176]], [[0.9176, 0.9137, 0.9098, ..., 0.8667, 0.8745, 0.8706], [0.9176, 0.9137, 0.9137, ..., 0.8706, 0.8784, 0.8745], [0.9176, 0.9137, 0.9176, ..., 0.8784, 0.8824, 0.8784], ..., [0.9608, 0.9569, 0.9608, ..., 0.6392, 0.6667, 0.6706], [0.9765, 0.9725, 0.9647, ..., 0.9608, 0.9765, 0.9725], [0.9882, 0.9843, 0.9804, ..., 0.9255, 0.9451, 0.9490]], [[0.9412, 0.9373, 0.9333, ..., 0.9255, 0.9333, 0.9294], [0.9412, 0.9373, 0.9373, ..., 0.9294, 0.9373, 0.9333], [0.9412, 0.9373, 0.9412, ..., 0.9294, 0.9333, 0.9333], 
..., [0.9686, 0.9647, 0.9686, ..., 0.6667, 0.6824, 0.6863], [0.9725, 0.9686, 0.9647, ..., 0.9804, 0.9804, 0.9804], [0.9843, 0.9804, 0.9765, ..., 0.9373, 0.9451, 0.9490]]], [[[0.1725, 0.1725, 0.1804, ..., 0.1255, 0.1255, 0.1255], [0.1922, 0.1882, 0.1843, ..., 0.1333, 0.1373, 0.1333], [0.1961, 0.1922, 0.1882, ..., 0.1412, 0.1412, 0.1333], ..., [0.4471, 0.4902, 0.5137, ..., 0.5647, 0.5725, 0.5961], [0.4431, 0.4706, 0.4824, ..., 0.5608, 0.5529, 0.5569], [0.4275, 0.4431, 0.4392, ..., 0.6078, 0.5608, 0.5176]], [[0.0980, 0.0980, 0.1059, ..., 0.0353, 0.0353, 0.0392], [0.1137, 0.1137, 0.1098, ..., 0.0431, 0.0471, 0.0471], [0.1216, 0.1176, 0.1137, ..., 0.0549, 0.0549, 0.0549], ..., [0.2471, 0.2824, 0.3529, ..., 0.5490, 0.5451, 0.5608], [0.2510, 0.2980, 0.3765, ..., 0.5569, 0.5294, 0.5255], [0.2471, 0.3059, 0.3765, ..., 0.6078, 0.5451, 0.4902]], [[0.0431, 0.0431, 0.0510, ..., 0.0118, 0.0118, 0.0118], [0.0588, 0.0588, 0.0549, ..., 0.0118, 0.0118, 0.0118], [0.0667, 0.0627, 0.0588, ..., 0.0118, 0.0118, 0.0118], ..., [0.2431, 0.2745, 0.3176, ..., 0.5373, 0.5608, 0.5804], [0.2510, 0.2824, 0.3294, ..., 0.5490, 0.5412, 0.5412], [0.2510, 0.2863, 0.3216, ..., 0.6000, 0.5529, 0.4980]]], [[[0.6353, 0.6314, 0.6314, ..., 0.6157, 0.6157, 0.6157], [0.6353, 0.6314, 0.6314, ..., 0.6157, 0.6157, 0.6157], [0.6353, 0.6314, 0.6314, ..., 0.6157, 0.6157, 0.6157], ..., [0.6471, 0.6431, 0.6431, ..., 0.6392, 0.6392, 0.6392], [0.6471, 0.6431, 0.6431, ..., 0.6392, 0.6392, 0.6392], [0.6471, 0.6431, 0.6431, ..., 0.6392, 0.6392, 0.6392]], [[0.7804, 0.7765, 0.7765, ..., 0.7725, 0.7725, 0.7686], [0.7804, 0.7765, 0.7765, ..., 0.7725, 0.7725, 0.7686], [0.7804, 0.7765, 0.7765, ..., 0.7725, 0.7725, 0.7686], ..., [0.7922, 0.7882, 0.7882, ..., 0.7843, 0.7843, 0.7843], [0.7922, 0.7882, 0.7882, ..., 0.7843, 0.7843, 0.7843], [0.7922, 0.7882, 0.7882, ..., 0.7843, 0.7843, 0.7843]], [[0.9882, 0.9804, 0.9843, ..., 0.9765, 0.9765, 0.9765], [0.9882, 0.9804, 0.9843, ..., 0.9765, 0.9765, 0.9765], [0.9882, 0.9804, 0.9843, 
..., 0.9765, 0.9765, 0.9765], ..., [0.9961, 0.9882, 0.9922, ..., 0.9882, 0.9882, 0.9882], [0.9961, 0.9882, 0.9922, ..., 0.9882, 0.9882, 0.9882], [0.9961, 0.9882, 0.9922, ..., 0.9882, 0.9882, 0.9882]]]]), tensor([2, 8, 9, 6, 9, 3, 8, 3, 7, 7, 7, 3, 9, 2, 3, 1, 0, 1, 9, 6, 7, 6, 7, 9, 1, 1, 8, 9, 2, 7, 5, 0, 1, 5, 9, 4, 2, 5, 7, 6, 3, 2, 2, 9, 4, 2, 1, 1, 9, 5, 2, 5, 0, 8, 1, 7, 3, 5, 8, 0, 5, 0, 5, 0])]
使用add_images对所有批次的数据进行展示
def add_images(self, tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW'): """Add batched image data to summary. Note that this requires the ``pillow`` package. Args: tag (string): Data identifier img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data global_step (int): Global step value to record walltime (float): Optional override default walltime (time.time()) seconds after epoch of event dataformats (string): Image data format specification of the form NCHW, NHWC, CHW, HWC, HW, WH, etc. Shape: img_tensor: Default is :math:`(N, 3, H, W)`. If ``dataformats`` is specified, other shape will be accepted. e.g. NCHW or NHWC. Examples:: from torch.utils.tensorboard import SummaryWriter import numpy as np img_batch = np.zeros((16, 3, 100, 100)) for i in range(16): img_batch[i, 0] = np.arange(0, 10000).reshape(100, 100) / 10000 / 16 * i img_batch[i, 1] = (1 - np.arange(0, 10000).reshape(100, 100) / 10000) / 16 * i writer = SummaryWriter() writer.add_images('my_image_batch', img_batch, 0) writer.close() Expected result: .. image:: _static/img/tensorboard/add_images.png :scale: 30 % """ torch._C._log_api_usage_once("tensorboard.logging.add_images") if self._check_caffe2_blob(img_tensor): from caffe2.python import workspace img_tensor = workspace.FetchBlob(img_tensor) self._get_file_writer().add_summary( image(tag, img_tensor, dataformats=dataformats), global_step, walltime)在使用add_images时要注意默认的通道数是3,如果经过卷积层以后的图片通道数大于3,那么是无法使用该函数进行显示的,会显示断言错误的信息,所以此时要使用torch.reshape将通道数变为3,然后可以正常调用
对于还未涉及的方法也是这样:查看其对应的参数类型(使用ctrl+p,或者直接ctrl+鼠标点击相应的函数查看源码),把所需要的参数类型传给它使用就好