• Dataset类分批加载数据集


    在做NLP任务的时候,需要分批加载数据集进行训练,这个时候可以继承pytorch.utils.data中的Dataset类,就可以进行分批加载数据,并且可以将数据转换成tensor对象数据.
    处理流程:
    image.png

    1.自定义Dataset类

    这个类要配合的torch.utils.data 中的DataLoader类才可以发挥作用

    # 因为我在数据预处理的时候将转换成id的数据集全部持久化处理了,所以需要这个方法加载数据
    # 获取文件
    def load_pkl(path,obj_name):
        print(f'get{obj_name} in {path}')
        with codecs.open(path,'rb')as f:
            data=pkl.load(f)
        return data
    
    # 第三方库
    import torch
    from torch.utils.data import Dataset
    
    # 自定义库
    from BruceNRE.utils import load_pkl
    # 数据集的加载
    class CustomDataset(Dataset):
        def __init__(self,file_path,obj_name):
            self.file=load_pkl(file_path,obj_name)
    
        def __getitem__(self, item):
            sample=self.file[item]
            return sample
    
        def __len__(self):
            return len(self.file)
    
    # 这个方法负责将数据进行填充,并且转换成tensor对象
    def collate_fn(batch):
    # 把这个批次中的数据按照list长度由高到低排序
        batch.sort(key=lambda data: len(data[0]),reverse=True)
    # 将这个批次中数据长度放到len集合中
        lens=[len(data[0])for data in batch]
    # 获得最大的长度
        max_len=max(lens)
    
        sent_list=[]
        head_pos_list=[]
        tail_pos_list=[]
        mask_pos_list=[]
        relation_list=[]
    
        # 填充数据,都用0来填充
        def _padding(x,max_len):
            return x+[0]*(max_len-len(x))
    # 把数据集转换成tensor对象,然后封装到对应的list中
        for data in batch:
            sent,head_pos,tail_pos,mask_pos,relation=data
            sent_list.append(_padding(sent,max_len))
            head_pos_list.append(_padding(tail_pos,max_len))
            tail_pos_list.append(_padding(tail_pos,max_len))
            mask_pos_list.append(_padding(mask_pos,max_len))
            relation_list.append(relation)
    
        # 将numpy转换为tensor
        return torch.tensor(sent_list),torch.tensor(head_pos_list),torch.tensor(tail_pos_list),torch.tensor(mask_pos_list),torch.tensor(relation_list)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55

    这个类解释一下作用:

    • init方法:把所有数据集加载进来
    • getitem:如果设置suffle为True就会打乱数据,传递数据的索引给getitem,就是item,然后根据索引加载数据.
    • len:获取数据集的索引长度
    • collate_fn:因为使用DataLoader这个类要求每一个批次中的数据的长度必须要一样,所以这个方法有两个作用,第一个作用就是把数据集全部用0填充到相同的长度,然后将数据集(是转换成字典标志位的数据集)转换成tensor对象

    2.使用Dataset类

    # 调用Dataset实现类
    train_dataset=CustomDataset(train_data_path,'train-data')
    # 将train_dataset放到DataLoader中,才可以使用
    train_dataloader=DataLoader(
            dataset=train_dataset,
            batch_size=128,
            shuffle=True,
            drop_last=True,
            collate_fn=collate_fn
        )
    
        for batch_idx,batch in enumerate(train_dataloader):
            *x,y=[data.to(device) for data in batch]
        print('dataloader测试完成')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    参数解析:
    dataset:Dataset类封装的数据集
    batch_size:每个批次处理的数据量,一般128或者64
    shuffle:是否打乱顺序
    drop_last:丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。
    collate_fn:处理数据集成一样的长度,并且转换成tensor对象的方法

    ==============================================================

    3.再看个例子:

    我的原始数据格式:

    体验2D巅峰 倚天屠龙记十大创新概览	8
    60年铁树开花形状似玉米芯(组图)	5
    同步A股首秀:港股缩量回调	2
    中青宝sg现场抓拍 兔子舞热辣表演	8
    锌价难续去年辉煌	0
    2岁男童爬窗台不慎7楼坠下获救(图)	5
    布拉特:放球员一条生路吧 FIFA能消化俱乐部的攻击	7
    金科西府 名墅天成	1
    状元心经:考前一周重点是回顾和整理	3
    发改委治理涉企收费每年为企业减负超百亿	6
    一年网事扫荡10年纷扰开心网李鬼之争和平落幕	4
    2010英国新政府“三把火”或影响留学业	3
    俄达吉斯坦共和国一名区长被枪杀	6
    朝鲜要求日本对过去罪行道歉和赔偿	6
    《口袋妖怪 黑白》日本首周贩售255万	8
    图文:借贷成本上涨致俄罗斯铝业净利下滑21%	2
    组图:新《三国》再曝海量剧照 火战场面极震撼	9
    麻辣点评:如何走出“被留学”的尴尬	3
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 创建一个Dataset的子类来处理数据
    import torch
    from torch.utils.data import Dataset
    from transformers import BertTokenizer,BertConfig,BertModel
    bert_model='./bert-base-chinese'
    myconfig = BertConfig.from_pretrained("./bert-base-chinese")
    tokenizer=BertTokenizer.from_pretrained(bert_model)
    MAX_LEN = 256 - 2
    
    class ElementDataset(Dataset):
        def __init__(self, f_path):
            sents, label_li = [], []  # list of lists
            with open(f_path, 'r', encoding='utf-8') as fr:
                for line in fr:
                    if len(line) < 10:
                        continue
                    entries = line.strip().split('\t')
                    words = entries[0]
                    label = entries[1:]
                    label = list(map(int, label))
                    sents.append(words)
                    label_li.append(label)
            self.sents, self.label_li = sents, label_li
    
        def __getitem__(self, item):
            words,tags=self.sents[item],self.label_li[item]
            inputs=tokenizer.encode_plus(words)
            label=tags
            seqlen = len(inputs['input_ids'])
            sample=(inputs,label,seqlen)
            return sample
    
        def __len__(self):
            print('sents')
            return len(self.sents)
    
        # 填充
    def collate_fn(batch):
        all_input_ids=[]
        all_attention_mask=[]
        all_token_type_ids=[]
        all_labels=[]
        lens=[data[2] for data in batch]
        max_len=max(lens)
        def padding(input,max_len,pad_token):
            return input+[pad_token]*(max_len-len(input))
    
        for data in batch:
            input,tags,_=data
            all_input_ids.append(padding(input['input_ids'],max_len,1))
            all_token_type_ids.append(padding(input['token_type_ids'],max_len,0))
            all_attention_mask.append(padding(input['attention_mask'],max_len,0))
            all_labels.append(tags)
        return torch.tensor(all_input_ids),torch.tensor(all_token_type_ids),torch.tensor(all_attention_mask),all_labels
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 然后再调用的时候使用DataLoader加载数据
    train_data=ElementDataset(args.Train)
        test_data=ElementDataset(args.Test)
    
        train_iter=DataLoader(dataset=train_data,
                                   batch_size=10,
                                   shuffle=True,
                                   drop_last=True,
                                   collate_fn=collate_fn)
    
        test_iter =DataLoader(dataset=test_data,
                                     batch_size=10,
                                     shuffle=True,
                                     drop_last=True,
                                     collate_fn=collate_fn)
    # 可以使用一个for循环查看数据
        for i, batch in enumerate(iterator):
            input_ids,token_type_ids,attention_mask,labels= batch
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    batch就是每一个批次的数据,我设置的这个批次的数据是10个,则这个10个的数据的长度就是一样的长度,并且都是tensor格式.

  • 相关阅读:
    Linux设备使用阿里云盘终极方案
    Java基础(二十四):MySQL
    抗丙型肝炎病毒化合物库
    Java高手的30k之路|面试宝典|精通JVM(二)
    px4+vio实现无人机室内定位
    GCC 指令详解及动态库、静态库的使用
    彻底搞懂dfs与回溯
    解码自然语言处理之 Transformers
    考研分享第3期 | 211本378分上岸大连理工电子信息经验贴
    GRPC编译安装、各种语言插件及测试
  • 原文地址:https://blog.csdn.net/qq_35653657/article/details/126003653