• PyTorch的数据处理



    💥今天看一下 PyTorch数据通常的处理方法~

    一般我们会将dataset用来封装自己的数据集,dataloader用于读取数据 

    Dataset格式说明 

    💬dataset定义了这个数据集的总长度,以及会返回哪些参数,模板:

    1. from torch.utils.data import Dataset
    2. class MyDataset(Dataset):
    3. def __init__(self, ):
    4. # 定义数据集包含的数据和标签
    5. def __len__(self):
    6. return len(...)
    7. def __getitem__(self, index):
    8. # 当数据集被读取时,返回一个包含数据和标签的元组
    9. return self.x_data[index], self.y_data[index]

    DataLoader格式说明

    1. my_dataset = DataLoader(mydataset, batch_size=2, shuffle=True,num_workers=4)
    2. # num_workers:多进程读取数据

    导入两个列表到Dataset

    1. class MyDataset(Dataset):
    2. def __init__(self, ):
    3. # 定义数据集包含的数据和标签
    4. self.x_data = [i for i in range(10)]
    5. self.y_data = [2*i for i in range(10)]
    6. def __len__(self):
    7. return len(self.x_data)
    8. def __getitem__(self, index):
    9. # 当数据集被读取时,返回一个包含数据和标签的元组
    10. return self.x_data[index], self.y_data[index]
    11. mydataset = MyDataset()
    12. my_dataset = DataLoader(mydataset)
    13. for x_i ,y_i in my_dataset:
    14. print(x_i,y_i)

    💬输出:

    1. tensor([0]) tensor([0])
    2. tensor([1]) tensor([2])
    3. tensor([2]) tensor([4])
    4. tensor([3]) tensor([6])
    5. tensor([4]) tensor([8])
    6. tensor([5]) tensor([10])
    7. tensor([6]) tensor([12])
    8. tensor([7]) tensor([14])
    9. tensor([8]) tensor([16])
    10. tensor([9]) tensor([18])

    💬如果修改batch_size为2,则输出:

    1. tensor([0, 1]) tensor([0, 2])
    2. tensor([2, 3]) tensor([4, 6])
    3. tensor([4, 5]) tensor([ 8, 10])
    4. tensor([6, 7]) tensor([12, 14])
    5. tensor([8, 9]) tensor([16, 18])
    • 我们可以看出,这是管理每次输出的批次的
    • 还可以控制用多少个线程来加速读取数据(Num Workers),这参数和电脑cpu核心数有关系,尽量不超过电脑的核心数

    导入Excel数据到Dataset中

    💥dataset只是一个类,因此数据可以从外部导入,我们也可以在dataset中规定数据在返回时进行更多的操作,数据在返回时也不一定是有两个。

    1. pip install pandas
    2. pip install openpyxl
    1. class myDataset(Dataset):
    2. def __init__(self, data_loc):
    3. data = pd.read_ecl(data_loc)
    4. self.x1,self.x2,self.x3,self.x4,self.y = data['x1'],data['x2'],data['x3'] ,data['x4'],data['y']
    5. def __len__(self):
    6. return len(self.x1)
    7. def __getitem__(self, idx):
    8. return self.x1[idx],self.x2[idx],self.x3[idx],self.x4[idx],self.y[idx]
    9. mydataset = myDataset(data_loc='e:\pythonProject Pytorch1\data.xls')
    10. my_dataset = DataLoader(mydataset,batch_size=2)
    11. for x1_i ,x2_i,x3_i,x4_i,y_i in my_dataset:
    12. print(x1_i,x2_i,x3_i,x4_i,y_i)

    导入图像数据集到Dataset

    需要安装opencv

    pip install opencv-python

    💯加载官方数据集 

    有一些数据集是PyTorch自带的,它被保存在TorchVision中,以mnist数据集为例进行加载:

  • 相关阅读:
    Linux入门教程||Linux系统目录结构
    WPF 控件分辨率自适应问题
    Pytest参数化:简化测试用例编写的利器
    【无标题】
    按摩 推拿上门服务小程序源码 家政上门服务系统源码
    MySQL:主从复制-基础复制(6)
    Android日历提醒增删改查事件、添加天数不对问题
    中仑网络全站 Dubbo 2 迁移 Dubbo 3 总结
    linux redis自启动
    【网络安全】护网
  • 原文地址:https://blog.csdn.net/qq_64685283/article/details/139235108