PyTorch
中的 Dataset
类是一个抽象类,用来表示数据集。通过继承 Dataset
类可以进行自定义数据集的格式、大小和其它属性,供后续使用;
可以看到官方封装好的数据集也是直接或间接的继承自 Dataset
类
import torch
from torch.utils.data import Dataset
# 自定义数据集继承 pytorch 内置的 Dataset 类
class GreenDataset(Dataset):
"""
重写构造函数
Args:
data_tensor 数据或数据集合
target_tensor 数据标签或数据标签集合
"""
def __init__(self, data_tensor, target_tensor):
self.data_tensor = data_tensor
self.target_tensor = target_tensor
# 重写 len 方法: return 数据集大小
def __len__(self):
return self.data_tensor.size(0)
# 重写 getitem 方法:基于索引,return 对应的数据及其标签,组合成 1 个元组返回
def __getitem__(self, index):
return self.data_tensor[index], self.target_tensor[index]
def test_data_set():
"""
自定义数据集测试
"""
# 生成数据集和标签集 (数据元素长度=标签元素长度)
# 10 行 3 列数据,可以理解为 10 个元素,每个元素是一维的 3个元素列表
data_tensor = torch.randn(10, 3)
# 对应方法 torch.randint(low, high, size)标签是 0 或 1 的 10 个元素
# low ( int , optional ) – 要从分布中提取的最小整数。默认值:0
# high ( int ) – 高于要从分布中提取的最高整数
# size ( tuple ) – 定义输出张量形状的元组
# 以下示例中 low 取默认值 0
target_tensor = torch.randint(2, (10,))
# 将数据封装成自定义数据集的 Dataset
my_dataset = GreenDataset(data_tensor, target_tensor)
# 调用方法:查看数据集大小
print('dataset size info:', len(my_dataset))
# 根据索引获取数据
print('tensor_data[0]: ', my_dataset[0])
# 打印数据集
for i, my_dataset in enumerate(my_dataset):
print('索引值:%s 数据:%s' % (i, my_dataset))
if __name__ == '__main__':
test_data_set()
torch.randn()
torch.randint()