#pic_center =400x
系列文章:
torch.utils.data.Dataset 提供了数据集的基类,我们只需要继承这个基类并重写其中的方法,即可完成数据集的加载与读取。
需要重写两个方法:__len__(返回数据集的样本数量)和 __getitem__(根据索引返回一条样本)。
根据自己数据集的情况修改以下三个方法(__init__、__len__、__getitem__):
from torch.utils.data import DataLoader, Dataset
class MyDataset(Dataset):
    """Map-style dataset over the SMS Spam Collection text file.

    Every line of the file is one sample: a class label followed by the raw
    message. Indexing yields the pair ``(label, content)`` as plain strings.
    """

    def __init__(self, path="datasets/smsspamcollection/SMSSpamCollection"):
        """Read all lines of the dataset file into memory.

        Args:
            path: Location of the dataset file. Defaults to the location
                used by the original snippet, so ``MyDataset()`` still works.

        Raises:
            OSError: If the file cannot be opened.
        """
        # A context manager closes the handle as soon as the lines are read;
        # a bare ``open(...).readlines()`` leaves closing to the garbage
        # collector.
        with open(path, encoding="utf-8") as f:
            self.lines = f.readlines()

    def __len__(self):
        """Return the number of samples, i.e. the number of lines read."""
        return len(self.lines)

    def __getitem__(self, index):
        """Return the ``(label, content)`` pair for the line at ``index``.

        The label is the first whitespace-delimited token of the line and the
        content is everything after it, both without surrounding whitespace.
        A blank line yields ``("", "")``.

        Raises:
            IndexError: If ``index`` is outside the stored lines.
        """
        # Split on the first run of whitespace instead of slicing a fixed four
        # characters: the fixed slice kept the separator on 3-letter labels,
        # kept it at the front of the message for 4-letter labels, and left
        # the trailing newline on the message.
        parts = self.lines[index].strip().split(maxsplit=1)
        label = parts[0] if parts else ""
        content = parts[1] if len(parts) > 1 else ""
        return label, content
# Build the dataset, then show its first sample followed by its sample count.
my_dataset = MyDataset()
first_sample = my_dataset[0]
print(first_sample)
sample_count = len(my_dataset)
print(sample_count)
本文使用的数据集为开源的文本分类数据集 SMS Spam Collection Data Set,下载地址为 https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection
数据集中包含从 Grumbletext 网站手动提取的 425 条 SMS 垃圾短信,整个数据集由一个文本文件构成,其中每一行都由一个类别标签和其后的原始消息构成。
DataLoader 的主要作用是将 Dataset 处理后的数据集加载并整合成 batch,用于后续训练,使用方法如下:
from torch.utils.data import DataLoader, Dataset
class MyDataset(Dataset):
    """Map-style dataset over the SMS Spam Collection text file.

    Every line of the file is one sample: a class label followed by the raw
    message. Indexing yields the pair ``(label, content)`` as plain strings.
    """

    def __init__(self, path="datasets/smsspamcollection/SMSSpamCollection"):
        """Read all lines of the dataset file into memory.

        Args:
            path: Location of the dataset file. Defaults to the location
                used by the original snippet, so ``MyDataset()`` still works.

        Raises:
            OSError: If the file cannot be opened.
        """
        # A context manager closes the handle as soon as the lines are read;
        # a bare ``open(...).readlines()`` leaves closing to the garbage
        # collector.
        with open(path, encoding="utf-8") as f:
            self.lines = f.readlines()

    def __len__(self):
        """Return the number of samples, i.e. the number of lines read."""
        return len(self.lines)

    def __getitem__(self, index):
        """Return the ``(label, content)`` pair for the line at ``index``.

        The label is the first whitespace-delimited token of the line and the
        content is everything after it, both without surrounding whitespace.
        A blank line yields ``("", "")``.

        Raises:
            IndexError: If ``index`` is outside the stored lines.
        """
        # Split on the first run of whitespace instead of slicing a fixed four
        # characters: the fixed slice kept the separator on 3-letter labels,
        # kept it at the front of the message for 4-letter labels, and left
        # the trailing newline on the message.
        parts = self.lines[index].strip().split(maxsplit=1)
        label = parts[0] if parts else ""
        content = parts[1] if len(parts) > 1 else ""
        return label, content
my_dataset = MyDataset()

if __name__ == "__main__":
    # Batches of five shuffled samples, fetched through one worker process.
    # The loader is only built when this file is run as a script, not when it
    # is imported.
    data_loader = DataLoader(
        dataset=my_dataset,
        batch_size=5,
        shuffle=True,
        num_workers=1,
    )
    # Show only the first batch, then stop iterating.
    for batch in data_loader:
        print(batch)
        break
实际项目中通常使用 enumerate 函数,在读取每一个 batch 内容的同时也返回该 batch 的索引:
from torch.utils.data import DataLoader, Dataset
class MyDataset(Dataset):
    """Map-style dataset over the SMS Spam Collection text file.

    Every line of the file is one sample: a class label followed by the raw
    message. Indexing yields the pair ``(label, content)`` as plain strings.
    """

    def __init__(self, path="datasets/smsspamcollection/SMSSpamCollection"):
        """Read all lines of the dataset file into memory.

        Args:
            path: Location of the dataset file. Defaults to the location
                used by the original snippet, so ``MyDataset()`` still works.

        Raises:
            OSError: If the file cannot be opened.
        """
        # A context manager closes the handle as soon as the lines are read;
        # a bare ``open(...).readlines()`` leaves closing to the garbage
        # collector.
        with open(path, encoding="utf-8") as f:
            self.lines = f.readlines()

    def __len__(self):
        """Return the number of samples, i.e. the number of lines read."""
        return len(self.lines)

    def __getitem__(self, index):
        """Return the ``(label, content)`` pair for the line at ``index``.

        The label is the first whitespace-delimited token of the line and the
        content is everything after it, both without surrounding whitespace.
        A blank line yields ``("", "")``.

        Raises:
            IndexError: If ``index`` is outside the stored lines.
        """
        # Split on the first run of whitespace instead of slicing a fixed four
        # characters: the fixed slice kept the separator on 3-letter labels,
        # kept it at the front of the message for 4-letter labels, and left
        # the trailing newline on the message.
        parts = self.lines[index].strip().split(maxsplit=1)
        label = parts[0] if parts else ""
        content = parts[1] if len(parts) > 1 else ""
        return label, content
my_dataset = MyDataset()

if __name__ == "__main__":
    # Batches of three shuffled samples, fetched through one worker process.
    # The loader is only built when this file is run as a script, not when it
    # is imported.
    data_loader = DataLoader(
        dataset=my_dataset,
        batch_size=3,
        shuffle=True,
        num_workers=1,
    )
    # enumerate() pairs each batch with its running index; every batch is
    # unpacked into its label part and its content part. Only the first batch
    # is shown.
    for batch_idx, (labels, contents) in enumerate(data_loader):
        print(batch_idx)
        print(labels)
        print(contents)
        break