• CNN-RNN Based Medical Text Generation


    🐱 CNN-RNN Based Medical Text Generation

    This project extracts image features with a resnet101 network pretrained on
    ImageNet and feeds them into an LSTM to generate textual descriptions of the
    radiographs.

    It is a first, simple implementation of image-to-text generation.


    📖 0 Background

    With the rapid progress of deep learning in recent years, the technology has shown great potential in healthcare. If computers could take over the mechanical work of drafting imaging reports, that would both reduce the misdiagnoses that less-experienced radiologists make when reading images and free senior radiologists from repetitive work, letting them devote more time to patient care.

    Automatic generation of medical imaging reports is an emerging intersection of computing and medical imaging. Current report-generation models mostly borrow the Encoder-Decoder framework from machine translation: a Convolutional Neural Network (CNN) extracts the image features, and a Recurrent Neural Network (RNN) then generates the textual description of the image.

    📌 1 Dataset

    The Indiana University chest X-ray collection (IU X-ray) is a set of chest X-ray images paired with diagnostic reports. The dataset contains 7,470 image-report pairs (split 6470:500:500). Each report consists of the following sections: impression, findings, tags, comparison, and indication. On average, each image is associated with 2.2 tags and 5.7 sentences, and each sentence contains 6.5 words.

    This project uses only the FINDINGS section as the generation target for each image.

    Reference code:

    • https://github.com/chenyuntc/pytorch-book/tree/master/chapter10-image_caption

    🐰 2 Dataset Construction

    🐅 2.1 Generating the Medical-Text CSV


    Unzip the raw data, parse the XML files, extract each image's filename and its corresponding FINDINGS section, and write them to a CSV file.
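
    For orientation, here is a minimal sketch of the XML structure the parser below relies on. The tag names (AbstractText/Label, parentImage, figureId) come from the parsing code; the values are illustrative placeholders, not actual dataset content.

    # Sketch: the elements read by xml2csv below, on a made-up document.
    from xml.dom import minidom

    sample = """<eCitation>
      <AbstractText Label="FINDINGS">the heart is normal in size</AbstractText>
      <parentImage id="CXR0_IM-0000-0001"><figureId>F1</figureId></parentImage>
    </eCitation>"""

    dom = minidom.parseString(sample)
    finding = dom.getElementsByTagName('AbstractText')[0].childNodes[0].nodeValue
    image_id = dom.getElementsByTagName('parentImage')[0].getAttribute('id')
    print(finding, '|', image_id)  # the heart is normal in size | CXR0_IM-0000-0001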

    # Unzip the dataset
    !unzip  -o  data/data123482/IU数据集.zip -d /home/aistudio/work/
    
      inflating: /home/aistudio/work/IU数据集/NLMCXR_reports/ecgen-radiology/1504.xml  
    
    # Suppress warning output
    import warnings
    warnings.filterwarnings("ignore")
    
    ## Build the CSV dataset
    # average caption length: 31.64992700729927 words
    import os
    import glob
    import pandas as pd
    from xml.dom import minidom
    import re
    import numpy as np
    LENGTH = []
    
    def EmptyDrop(data):
        for i in range(len(data)):
            if data.loc[i,'dir'] ==[] or data.loc[i,'caption'] ==[]:
                # drop the row if either field is empty
                data.drop([i],axis = 0,inplace = True) 
            else:
                data.loc[i,'dir'] = data.loc[i,'dir'][0]
                data.loc[i,'caption'] = data.loc[i,'caption'][0]
        data.reset_index(drop = True,inplace = True)
        return data
    
    def clean_text(origin_text):
        # strip punctuation and any non-letter characters
        text = re.sub("[^a-zA-Z]"," ",origin_text)
        # lowercase
        cleaned_text = text.lower()
        return cleaned_text

    def xml2csv(path):
        num = 0
        column_name = ['dir','caption']
        xml_csv = pd.DataFrame(columns = column_name)
        # directory containing the PNG images
        pic_path =  'work/IU数据集/NLMCXR_png'
        for xml_file in glob.glob(path+'/*.xml'):
    
            # everything to record for this xml: figure id, path, FINDINGS
            xml_list = []
            # open the xml document
            dom = minidom.parse(xml_file)
            # get the document root element
            root = dom.documentElement
            # image paths
            itemlists=root.getElementsByTagName('parentImage')
            # record the paths
            dirAll = []
            for itemlist in itemlists:
                figureId=itemlist.getElementsByTagName('figureId')
                # figure ID of this image
                figure = figureId[0].childNodes[0].nodeValue
                # name of this image
                ID= itemlist.getAttribute('id')
                IdPath = ID
                # frontal & lateral views
                figurePath = [figure+' '+IdPath]
                dirAll.extend(figurePath)
            xml_list.append(dirAll)
    
            # record FINDINGS (IMPRESSION could be added the same way)
            CaptionAll = []
            itemlists=root.getElementsByTagName('AbstractText')
            for i in range(len(itemlists)):
                Label= itemlists[i].getAttribute('Label')
                if Label == 'FINDINGS': # or Label == 'IMPRESSION':
                    # content is not empty
                    if len(itemlists[i].childNodes)!=0:
                        text = itemlists[i].childNodes[0].nodeValue
                        # lowercase and strip invalid characters
                        text = clean_text(text)
                        text = text.replace('.','')
                        text = text.replace(',','')
                        text = [text+'']                    
                        CaptionAll.extend(text)
            if len(CaptionAll)>=1:
                LENGTH.append(len(CaptionAll[0].split(' ')))
            xml_list.append(CaptionAll)
            xml_csv.loc[num] = [item for item in xml_list]
            num = num + 1
            print('file[{}/{}]'.format(num,len(glob.glob(path+'/*.xml'))))
       # print(np.mean(LENGTH))
        return xml_csv
    
    def main():
        xml_path = os.path.join('work','IU数据集','NLMCXR_reports','ecgen-radiology')
        csv = xml2csv(xml_path)
        csv1 = EmptyDrop(csv)
        csv1.to_csv('work/IUxRay.csv',index=None)
        
    if __name__ == '__main__':
        main()
    
    

    🐅 2.2 Image Feature Extraction


    • Extract image features with a resnet101 model pretrained on ImageNet (the final fully connected layer is removed and replaced with an identity mapping).

    • Save the features to an h5 file.

    ## Extract image features with the pretrained resnet101
    import os
    import numpy as np
    import pandas as pd
    import paddle
    from paddle.vision.models import resnet101
    import h5py
    import cv2
    
    csv_file = pd.read_csv('work/IUxRay.csv')
    h5_png_file =list(csv_file['dir'])
    
    # create the output directory
    save_path = 'work/util_IUxRay'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    
    # load resnet101 with ImageNet-pretrained weights
    model = resnet101(pretrained=True)
    # drop the final fully connected layer and replace it with an identity
    del model.fc
    model.fc = lambda x:x
    
    h5f = h5py.File(os.path.join(save_path,'resnet101_festures.h5'), 'w')
    
    for idx,item in enumerate(h5_png_file):
        # read every PNG of this report (F1, F2, ...)
        print(idx,len(h5_png_file))
        item_all = item.split(',')
        for item_t in item_all:
            item_t = item_t.replace('\'','').replace('[','').replace(']','')
            # distinguish images by view/orientation tag
            for orie in ['F1','F2','F3','F4']:
                if orie in item_t:
                    orie_fin = orie
                    item_fin = item_t.replace(orie,'').replace(' ','')
                    item_fin_png = item_fin + '.png'
                    print(orie_fin + '_' + item_fin)
                    # read the file, run the model, store the feature in h5
                    img = cv2.imread(os.path.join('work/IU数据集/NLMCXR_png',item_fin_png))
                    # BGR to RGB, and HWC to CHW
                    img = img[:,:,::-1].transpose((2,0,1))
                    # add a batch dimension
                    img = np.expand_dims(img, 0)
                    img_tensor = paddle.to_tensor(img,dtype='float32' ,place=paddle.CPUPlace())
                    # extract the feature
                    out = model(img_tensor)
                    data = out.numpy().astype('float32')
                    # store the 2048-dim vector under "<view>_<image name>"
                    save_path_h5 = data[0]
                    h5f.create_dataset(orie_fin + '_' + item_fin, data=save_path_h5)
    h5f.close()
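
    As a quick sanity check (a sketch with a random input; the real loop above feeds full-resolution PNGs), the fc-replaced network should now output the 2048-dim pooled feature instead of 1000-way classification logits:

    # Sketch: with fc replaced by an identity, resnet101 returns the pooled feature.
    dummy = paddle.randn([1, 3, 224, 224])
    feat = model(dummy)
    print(feat.shape)  # expected: [1, 2048]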
    
    
    # Read the h5 file back
    import h5py
    h5f = h5py.File('work/util_IUxRay/resnet101_festures.h5','r')
    data = h5f['F1_CXR3027_IM-1402-1001'] # before the first underscore: view tag; after it: original image name
    print(np.array(data).shape) # each image is stored as a 2048-dim vector
    h5f.close() 
    

    🐅 2.3 Vocabulary Construction


    • Scan the training data and build a vocabulary at word granularity.

    • Vocabulary pruning: remove words that appear only once in the dataset. (A quick encoding check follows the code.)

    # Scan the training data and build a word-level vocabulary
    import pandas as pd
    import numpy as np
    import re
    
    csv_file = pd.read_csv('work/IUxRay.csv')
    csv_file.head()
    
    CaptionWordAll = []
    CaptionWordLength = []
    for idx,data_ in enumerate(csv_file.iterrows()):
        caption = data_[1][1]
        CaptionWordLength.append(len(caption.split(' ')))
        CaptionWordAll.extend(caption.split(' '))
    
    print('Average sentence length:',np.mean(CaptionWordLength))
    print('Maximum sentence length:',np.max(CaptionWordLength))
    print('Minimum sentence length:',np.min(CaptionWordLength))
    
    print('Total word count:',len(CaptionWordAll))
    print('Vocabulary size:',len(set(CaptionWordAll)))
    
    from collections import Counter
    # Count frequencies and sort from high to low, so the most frequent
    # tokens get the smallest ids and are found quickly
    counts = Counter(CaptionWordAll)
    count_sorted = counts.most_common()
    count_sorted_ = {k: v for k, v in count_sorted if v > 1}
    
    # Build the dictionaries
    # Reserve four special tokens: <pad> 0, <unk> 1, <start> 2, <end> 3
    
    word2id_dict={'<pad>':0,'<unk>':1,'<start>':2,'<end>':3}
    id2word_dict={0:'<pad>',1:'<unk>',2:'<start>',3:'<end>'}
    
    for idx,item in enumerate(count_sorted_):
        idx_ = idx+4 # offset by 4 for the reserved tokens
        item_ = item
        word2id_dict[item_] = idx_
        id2word_dict[idx_] = item_
    
    # words that appear only once have been dropped
    print('Vocabulary size after pruning:',len(word2id_dict))
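
    A quick check of the mapping (a sketch; the sample sentence is made up). Any word pruned from the vocabulary, or never seen, falls back to <unk>:

    # Sketch: encode a made-up sentence with the dictionaries built above.
    sample = 'the heart is normal in size'.split(' ')
    ids = [word2id_dict.get(w, word2id_dict['<unk>']) for w in sample]
    print(ids)
    print([id2word_dict[i] for i in ids])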
    
    

    🥝 3 Data Loading


    • Split the data 8:2 into training and validation sets.

    • Map the text through the vocabulary. Unlike a translation task, the <start> position is replaced by the image feature (see stackInput below; a worked example of the shift follows the code).

    ## Data loading
    
    import paddle
    from paddle.io import Dataset
    import numpy as np
    from sklearn.model_selection import train_test_split
    
    # Custom dataset class
    class CaptionDataset(Dataset):
        # constructor
        def __init__(self,csvData,word2id_dict,h5f,maxlength = 40,mode = 'train'):
            self.mode = mode
            self.w2i_dict = word2id_dict
            self.maxlength = maxlength # maximum input length in tokens
            self.padid = 0 # 0 is the padding id
            self.h5f = h5f
            # split the data according to mode (train/test)
            train,test =csvData.iloc[:int(0.8*len(csvData)),:],csvData.iloc[int(0.8*len(csvData)):,:] #train_test_split(csvData,train_size=0.8,random_state=10)
    
            if self.mode == 'train':
                self.data = train.reset_index(drop=True)
            else:
                self.data = test.reset_index(drop=True)
    
        # __getitem__: return one sample (image feature, target ids, length, image name)
        def __getitem__(self, index):
    
            path_name, trg_  = self.data.iloc[index,:]
            
            # load the image feature
            temp  = path_name.split(' ')
            names = '_'.join(temp)
            img_feature = np.array(self.h5f[names]) # before the first underscore: view tag; after it: original image name
    
            # convert the caption to ids
            trg,trg_length = self.generIdxList(trg_)
            img_name = temp[-1]
            return img_feature,trg,trg_length,img_name
    
        # __len__: total number of samples
        def __len__(self):
            return len(self.data)
        
        def generIdxList(self,tdata):
            # turn the input string into a list of ids
            data = tdata.split(' ')
            data_out = []
            # cap the length and add '<start>' / '<end>'
            data = ['<start>'] + data 
            
            if len(data)>self.maxlength-1:
                data = data[:self.maxlength-1] # leave one slot for '<end>'
                data = data + ['<end>']
    
            else:
                # padding
                occupy_ = ['<pad>'] * (self.maxlength - 1 - len(data))
                data = data + ['<end>']
                data = data + occupy_
    
            # word 2 index
            for word in data:
    
                if self.w2i_dict.get(word)!= None: # word is in the vocabulary
                    id_ = self.w2i_dict[word]
                    data_out.append(id_)
                else:
                    id_ = self.w2i_dict['<unk>']
                    data_out.append(id_)
    
            length = len(data_out)-1
    
            return data_out,length
    
    def stackInput(inputs):
    
        img_features = np.stack([inputsub[0] for inputsub in inputs], axis=0)
        trg = np.stack([inputsub[1] for inputsub in inputs], axis=0)
        trg_length = np.stack([inputsub[2] for inputsub in inputs], axis=0)
    
    
        trg_mask =(trg[:,:-1]!=0).astype(paddle.get_default_dtype())
        trg_ = trg[:,1:] # drop <start>: the image feature takes its place
        return img_features,trg_length,trg_[:,:-1],trg[:,1:,np.newaxis],trg_mask
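
    To make the input/label shift in stackInput concrete, here is a toy example with made-up token ids and maxlength 8. The decoder input drops <start> (the image feature takes its place inside the model), so with the feature prepended the predictions line up with the label, one step ahead:

    # Sketch: one caption after generIdxList with maxlength = 8:
    #   [<start>=2, w1, w2, w3, <end>=3, <pad>=0, <pad>, <pad>]
    import numpy as np
    trg = np.array([[2, 10, 11, 12, 3, 0, 0, 0]])
    decoder_input = trg[:, 1:][:, :-1]   # [[10 11 12  3  0  0]]
    label         = trg[:, 1:]           # [[10 11 12  3  0  0  0]]
    mask          = (trg[:, :-1] != 0)   # masks steps whose input token is <pad>
    print(decoder_input, label, mask, sep='\n')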
    
    # Test the data pipeline
    import pandas as pd
    import numpy as np
    import h5py
    from sklearn.model_selection import train_test_split
    csvData = pd.read_csv('work/IUxRay.csv')
    h5f = h5py.File('work/util_IUxRay/resnet101_festures.h5','r')
    
    maxlength = 40
    
    dataset=CaptionDataset(csvData,word2id_dict,h5f,maxlength,'train')
    data_loader = paddle.io.DataLoader(dataset, batch_size=1,collate_fn = stackInput, shuffle=False)
    
    for item in data_loader:
        print(item[0].shape,item[1].shape,item[2].shape,item[3].shape,item[4].shape)
        break
    

    💡 4 Model Definition


    • Define an LSTM model for text generation.

    • Define a beam search algorithm to improve the generated results.

    # Define the model
    import paddle.nn as nn
    import paddle
    class CaptionModel(paddle.nn.Layer):
    
        def __init__(self, vocab_size,embedding_dim,hidden_size,num_layers,word2id_dict,id2word_dict):
            
            super(CaptionModel,self).__init__()
    
            self.hidden_size=hidden_size
            self.num_layers=num_layers        
            self.fc = paddle.nn.Linear(2048,embedding_dim)
            self.embedding=paddle.nn.Embedding(vocab_size,embedding_dim)
            self.rnn=paddle.nn.LSTM(input_size=embedding_dim,
                                    hidden_size=hidden_size,
                                    num_layers=num_layers)
            self.word2ix = word2id_dict
            self.ix2word = id2word_dict
            self.classifier = paddle.nn.Linear(hidden_size,vocab_size)
    
        def forward(self,img_features,trg,trg_length):
            img_features = paddle.unsqueeze(self.fc(img_features),axis = 1)
            embeddings = self.embedding(trg)
            inputs = paddle.concat([img_features,embeddings],axis = 1)
            outputs,state = self.rnn(inputs,sequence_length  = trg_length)
            predict = self.classifier(outputs)
    
            return predict
    
        def generate(self, img_feat, eos_token='<end>',
                     beam_size=2,
                     max_caption_length=40,
                     length_normalization_factor=0.0):
            """
        Generate a caption for an image, using beam search to obtain a better description.
            """
            cap_gen = CaptionGenerator(embedder=self.embedding,
                                       rnn=self.rnn,
                                       classifier=self.classifier,
                                       eos_id=self.word2ix[eos_token],
                                       beam_size=beam_size,
                                       max_caption_length=max_caption_length,
                                       length_normalization_factor=length_normalization_factor)
    
            img_feat = paddle.unsqueeze(img_feat,axis = 0)
            img = paddle.unsqueeze(self.fc(img_feat),axis = 0)
    
            sentences, score = cap_gen.beam_search(img)
    
            sentences = [' '.join([self.ix2word[int(idx)] for idx in sent])
                         for sent in sentences]
            return sentences
    
    # Beam Search
    import paddle.nn as nn
    import heapq
    
    class TopN(object):
        """Maintains the top n elements of an incrementally provided set."""
    
        def __init__(self, n):
            self._n = n
            self._data = []
    
        def size(self):
            assert self._data is not None
            return len(self._data)
    
        def push(self, x):
            """Pushes a new element."""
            assert self._data is not None
            if len(self._data) < self._n:
                heapq.heappush(self._data, x)
            else:
                heapq.heappushpop(self._data, x)
    
        def extract(self, sort=False):
            """
            Extracts all elements from the TopN. This is a destructive operation.
            The only method that can be called immediately after extract() is reset().
            Args:
              sort: Whether to return the elements in descending sorted order.
            Returns:
              A list of data; the top n elements provided to the set.
            """
            assert self._data is not None
            data = self._data
            self._data = None
            if sort:
                data.sort(reverse=True)
            return data
    
        def reset(self):
            """Returns the TopN to an empty state."""
            self._data = []
    
    class Caption(object):
        """Represents a complete or partial caption."""
    
        def __init__(self, sentence, state, logprob, score, metadata=None):
            """Initializes the Caption.
            Args:
              sentence: List of word ids in the caption.
              state: Model state after generating the previous word.
              logprob: Log-probability of the caption.
              score: Score of the caption.
              metadata: Optional metadata associated with the partial sentence. If not
                None, a list of strings with the same length as 'sentence'.
            """
            self.sentence = sentence
            self.state = state
            self.logprob = logprob
            self.score = score
            self.metadata = metadata
    
        def __cmp__(self, other):
            """Compares Captions by score."""
            assert isinstance(other, Caption)
            if self.score == other.score:
                return 0
            elif self.score < other.score:
                return -1
            else:
                return 1
    
        # For Python 3 compatibility (__cmp__ is deprecated).
        def __lt__(self, other):
            assert isinstance(other, Caption)
            return self.score < other.score
    
        # Also for Python 3 compatibility.
        def __eq__(self, other):
            assert isinstance(other, Caption)
            return self.score == other.score
    
    class CaptionGenerator(object):
        """Class to generate captions from an image-to-text model."""
    
        def __init__(self,
                     embedder,
                     rnn,
                     classifier,
                     eos_id,
                     beam_size=3,
                     max_caption_length=100,
                     length_normalization_factor=0.0):
            """Initializes the generator.
            Args:
              model: recurrent model, with inputs: (input, state) and outputs len(vocab) values
              beam_size: Beam size to use when generating captions.
              max_caption_length: The maximum caption length before stopping the search.
              length_normalization_factor: If != 0, a number x such that captions are
                scored by logprob/length^x, rather than logprob. This changes the
                relative scores of captions depending on their lengths. For example, if
                x > 0 then longer captions will be favored.
            """
            self.embedder = embedder
            self.rnn = rnn
            self.classifier = classifier
            self.eos_id = eos_id
            self.beam_size = beam_size
            self.max_caption_length = max_caption_length
            self.length_normalization_factor = length_normalization_factor
    
        def beam_search(self, rnn_input, initial_state=None):
            """Runs beam search caption generation on a single image.
            Args:
              initial_state: An initial state for the recurrent model
            Returns:
              A list of Caption sorted by descending score.
            """
    
            def get_topk_words(embeddings, state):
                output, new_states = self.rnn(embeddings, state)
                output = self.classifier(paddle.squeeze(output,axis=0))
                logprobs = nn.functional.log_softmax(output, axis=-1)
    
                if len(logprobs.shape) == 3:
                    logprobs = paddle.squeeze(logprobs)
                logprobs, words = logprobs.topk(self.beam_size, 1)
                return words, logprobs, new_states
    
            partial_captions  = TopN(self.beam_size)
            complete_captions = TopN(self.beam_size)
    
            words, logprobs, new_state = get_topk_words(rnn_input, initial_state)
    
            for k in range(self.beam_size):
                cap = Caption(
                    sentence=[words[0, k]],
                    state=new_state,
                    logprob=logprobs[0, k],
                    score=logprobs[0, k])
    
                partial_captions.push(cap)
    
            # Run beam search.
    
            for _ in range(self.max_caption_length - 1):
                partial_captions_list = partial_captions.extract()
                partial_captions.reset()
                input_feed =[c.sentence[-1] for c in partial_captions_list]
    
                input_feed = paddle.to_tensor(input_feed)
                state_feed = [c.state for c in partial_captions_list]
                if isinstance(state_feed[0], tuple):
                    state_feed_h, state_feed_c = zip(*state_feed)
                    state_feed = (paddle.concat(state_feed_h, 1),
                                  paddle.concat(state_feed_c, 1))
                else:
                    state_feed = paddle.concat(state_feed, 1)
                
    
                embeddings = self.embedder(input_feed)
    
                words, logprobs, new_states = get_topk_words( 
                    embeddings, state_feed)
                    
                for i, partial_caption in enumerate(partial_captions_list):
                    if isinstance(new_states, tuple):
                        state = (paddle.slice(new_states[0],axes=[1],starts=[i],ends = [i+1]),
                                 paddle.slice(new_states[1],axes=[1],starts=[i],ends = [i+1]))
                    else:
                        state = new_states[i]
                    for k in range(self.beam_size):
                        w = words[i, k]
                        sentence = partial_caption.sentence + [w]
                        logprob = partial_caption.logprob + logprobs[i, k]
                        score = logprob
                        if w == self.eos_id:
                            if self.length_normalization_factor > 0:
                                score /= len(sentence)**self.length_normalization_factor
                            beam = Caption(sentence, state, logprob, score)
                            complete_captions.push(beam)
                        else:
                            beam = Caption(sentence, state, logprob, score)
                            partial_captions.push(beam)
                if partial_captions.size() == 0:
                    # We have run out of partial candidates; happens when beam_size
                    # = 1.
                    break
    
            # If we have no complete captions then fall back to the partial captions.
            # But never output a mixture of complete and partial captions because a
            # partial caption could have a higher score than all the complete
            # captions.
            if not complete_captions.size():
                complete_captions = partial_captions
    
            caps = complete_captions.extract(sort=True)
    
            return [c.sentence for c in caps], [c.score for c in caps]
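
    Once the model has been instantiated (Section 6), generate can be exercised directly; a sketch with a random placeholder feature vector (untrained weights, so the text itself is meaningless). Candidates come back sorted by descending score, via extract(sort=True):

    # Sketch: run beam search on a random 2048-dim "image feature".
    fake_feat = paddle.randn([2048])
    for sentence in model.generate(fake_feat, beam_size=3):
        print(sentence)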
    
    

    🥝 5 Loss Function


    • Use a basic cross-entropy loss.

    • Apply the trg_mask defined above so that no loss is computed over padding. (A toy check follows the code.)

    # Define the loss function
    class CrossEntropy(paddle.nn.Layer):
        def __init__(self):
            super(CrossEntropy,self).__init__()
    
        def forward(self,pre,real,trg_mask):
    
            cost=paddle.nn.functional.softmax_with_cross_entropy(logits=pre,label=real)
            
            # squeeze out the size-1 dimension at axis=2
            cost=paddle.squeeze(cost,axis=[2])
            
            # trg_mask has shape [batch_size, sequence_len]
            masked_cost=cost*trg_mask
            
            return paddle.mean(paddle.mean(masked_cost,axis=[0]))
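
    A toy check of the masking (made-up logits; shapes follow the loss above: pre [B, T, vocab], real [B, T, 1], trg_mask [B, T]):

    # Sketch: the mask zeroes the loss at <pad> steps, so padding
    # contributes no gradient.
    pre  = paddle.randn([1, 4, 10])                  # batch=1, 4 steps, vocab=10
    real = paddle.to_tensor([[[5], [7], [3], [0]]])  # last step is <pad>
    mask = paddle.to_tensor([[1., 1., 1., 0.]])      # ignore the pad step
    print(CrossEntropy()(pre, real, mask))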
    
    

    🦃 6 Parameters and Training


    • Add perplexity as an evaluation metric (a quick numeric sketch follows this list).

    • Set the training hyperparameters.
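
    Perplexity is the exponential of the average per-token cross-entropy, so it moves in lockstep with the loss. A quick numeric sketch with made-up probabilities:

    # Sketch: perplexity = exp(mean negative log-likelihood). A model that puts
    # probability 0.25 on every correct token is "as confused as" a uniform
    # choice among 4 words.
    import numpy as np
    token_probs = np.array([0.25, 0.25, 0.25, 0.25])
    print(np.exp(-np.log(token_probs).mean()))  # 4.0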

    # Hyperparameters
    import h5py
    
    epochs=60
    word_size = 1151
    eos_id=word2id_dict['<end>']
    num_layers=32
    hidden_size=512
    embedding_dim=512
    lr=1e-3
    maxlength=40
    model_path='./output'
    
    csvData = pd.read_csv('work/IUxRay.csv')
    h5f = h5py.File('work/util_IUxRay/resnet101_festures.h5','r')
    
    
    import paddlenlp
    
    model=CaptionModel(word_size,embedding_dim,hidden_size,num_layers,word2id_dict,id2word_dict)
    
    train_dataset=CaptionDataset(csvData,word2id_dict,h5f,maxlength,'train')
    train_loader = paddle.io.DataLoader(train_dataset, batch_size=128,collate_fn = stackInput, shuffle=True)
    
    val_dataset=CaptionDataset(csvData,word2id_dict,h5f,maxlength,'test')
    val_loader = paddle.io.DataLoader(val_dataset, batch_size=64,collate_fn = stackInput, shuffle=True)
    
    
    # optimizer
    optimizer=paddle.optimizer.Adam(learning_rate=lr,parameters=model.parameters())
    # loss function
    loss_fn = CrossEntropy()
    # perplexity metric
    perplexity = paddlenlp.metrics.Perplexity()
    
    model.train()
    
    for epoch in range(epochs):
        for batch_id, data in enumerate(train_loader()):
    
            img_features,trg_length,inputs,label,label_mask = data[0],data[1],data[2],data[3], data[4]  # unpack the batch
    
            predicts = model(img_features,inputs,trg_length)    # forward pass
    
            # compute the loss (equivalent to the loss passed to Model.prepare)
            loss = loss_fn(predicts, label , label_mask)
    
            # compute perplexity (equivalent to the metrics passed to Model.prepare)
            correct = perplexity.compute(predicts, label)
            perplexity.update(correct.numpy())
            ppl = perplexity.accumulate()
            
            # The backward pass, logging, parameter update and gradient reset
            # below are what Model.fit() would otherwise encapsulate
    
            # backward pass
            loss.backward()
    
            if (batch_id+1) % 20 == 0:
                print("epoch: {}, batch_id: {}, loss is: {}, ppl is: {}".format(epoch+1, batch_id+1, loss.item(), ppl))
    
                # save the model parameters to LSTM_model.pdparams
                paddle.save(model.state_dict(), 'work/LSTM_model.pdparams')
    
            # update parameters
            optimizer.step()
    
            # reset gradients
            optimizer.clear_grad()
    
    
    model.eval()
    
    for batch_id, data in enumerate(val_loader()):
    
        img_features,trg_length,inputs,label,label_mask = data[0],data[1],data[2],data[3], data[4]  # unpack the batch
    
        predicts = model(img_features,inputs,trg_length)    # forward pass
    
        # compute the loss
        loss = loss_fn(predicts , label , label_mask)
    
        # compute perplexity
        correct = perplexity.compute(predicts, label)
        perplexity.update(correct.numpy())
        ppl = perplexity.accumulate()
    
        if (batch_id+1) % 1 == 0:
            print(" batch_id: {}, loss is: {}, ppl is: {}".format( batch_id+1, loss.item(), ppl))
    

    🍓 7 Inference

    
    # Evaluate on the validation split
    from IPython.display import display
    from PIL import Image
    import numpy as np
    from tqdm import tqdm
    from nltk.translate.bleu_score import sentence_bleu
    
    path = 'work/IU数据集/NLMCXR_png/'
    csvData = pd.read_csv('work/IUxRay.csv')
    h5f = h5py.File('work/util_IUxRay/resnet101_festures.h5','r')
    data = csvData.iloc[int(0.8*len(csvData)):,:]
    scores = []
    
    Beam_Size = 3
    for idx,data_ in tqdm(enumerate(data.iterrows())):
        F_name = data_[1][0]
        F_text = data_[1][1]
    
        img_name = F_name.split(' ')[-1]
        h5f_name = '_'.join(F_name.split(' '))
        img_feature = np.array(h5f[h5f_name])
        img_path = path + img_name + '.png'
        
        img_feature = paddle.to_tensor(img_feature)
        results = model.generate(img_feature,beam_size=Beam_Size)
    
        #print('Prediction:',results[Beam_Size-1])
        #print('Ground truth:',F_text)
    
        #img = Image.open(img_path).convert('RGB')
        #display(img, Image.BILINEAR)
        
        # compute BLEU
        reference = [F_text.split(' ')]
        candidate = results[Beam_Size-1].split(' ')
        score = sentence_bleu(reference,candidate)
        scores.append(score)
    
    print('Prediction:',results[Beam_Size-1])
    print('Ground truth:',F_text)
    print('BLEU:',np.mean(scores))
    img = Image.open(img_path).convert('RGB')
    display(img, Image.BILINEAR)
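
    One caveat: nltk's sentence_bleu defaults to uniform 4-gram weights, so short or poorly matching candidates can score near zero. Smoothing and custom weights are optional refinements, not used in the loop above; a sketch with made-up sentences:

    # Sketch: BLEU with bigram weights and smoothing.
    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
    reference = [['the', 'heart', 'is', 'normal', 'in', 'size']]
    candidate = ['the', 'heart', 'is', 'normal']
    smooth = SmoothingFunction().method1
    print(sentence_bleu(reference, candidate, weights=(0.5, 0.5),
                        smoothing_function=smooth))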
    

    🎖️ 8 Summary

    • The project demonstrates chest X-ray report generation with a CNN+RNN pipeline.

    • The beam search code originally contained a small bug, so the output was effectively just the maximum-probability sequence. This has been fixed, and the Beam Size parameter can now be passed normally.

    • The project is a simple implementation of the image-captioning task in the medical-text domain.

    • All code and data are presented as a notebook, which keeps it easy to follow.

    • BLEU is used to evaluate the results.


    Note: this project was inspired by Chapter 10 of《深度学习框架Pytorch入门与实践》(Deep Learning Framework PyTorch: Introduction and Practice).


    If you have any questions, feel free to leave a comment.
    
  • Original article: https://blog.csdn.net/Magic_Zsir/article/details/125429965