• 使用StanfordCoreNLP的句法树以及NLTK的Tree建立DGL的图数据结构


    文章目录

    前言

    因为其实Tree-LSTM的数据类型构造起来也是比较麻烦的,还是要记录一下,以防之后忘记

    构建思路

    因为前面一篇写的Tree-LSTM是用DGL实现的,那么图数据结构自然也要用DGL来操作一下,不然喂不进去啊。

    构建数据集的样子是按照DGL自带的SSTDataset(mode=‘tiny’)来的,因为这样我就可以直接放进去训练了。

    我是使用句法树来构建的,因为句法树并不止二叉,因此模型使用Child-sum Tree-LSTM会比较合适。
    例如:
    在这里插入图片描述
    这么一颗句法树,因为DGL使用序号来标注节点的,那么其实我也需要转成序号才方便赋值,因此使用NLTK的Tree操作一下,可以将句法树变成这样:
    在这里插入图片描述
    然后我再从叶子节点往上加入边,知道根节点停止,那么一个图就建立完成了。

    但是,这只是一句话,我需要对一段话进行情感分类,而句法树只能分析一句话,那么我可以把一段话切句,然后将每句话的根节点连接到一个假想的节点上,然后用这个节点的信息作为最终的分类输出,代表整段话的含义,然后用来分类即可。
    在这里插入图片描述
    就像这样,构造一片句法森林,然后用0节点的信息来分类即可。

    代码如下:
    首先获取图数据结构,以及记录每段话叶子节点的词

    from stanfordcorenlp import StanfordCoreNLP
    from nltk import Tree
    from nltk.tokenize import sent_tokenize
    import dgl
    import torch
    import pandas as pd
    from tqdm import tqdm
    
    nlp = StanfordCoreNLP(r'D:\stanfordnlp\stanford-corenlp-4.4.0', lang="en")
    row_data_path = "demo.tsv" # 每段话的tsv文件
    data = pd.read_csv(row_data_path, sep="\t")
    graph_data = []  # 存放句法树组成的图
    tree_word = []  # 存放句法树的叶子节点
    
    for index, d in tqdm(data.iterrows()):
        sentence = sent_tokenize(d["text"])
        j = 1  # 节点编号,取下面那个为根节点,标号为1
        tree = []  # 存u->v的元组
        leaves_pos = []  # 存叶子节点的位置信息
        leaves_word = []  # 存叶子节点的词
        node_sum = -1  # 节点数
        for sen in sentence:
            tree.append((j, 0))  # 每句话句法树连到0节点
    
            parse = nlp.parse(sen)
            t = Tree.fromstring(parse)
            # t.draw()
    
            pos = t.treepositions()  # 所有节点的位置
            leaves = t.treepositions('leaves')  # 所有叶子节点的位置
            num = len(pos) - len(leaves)  # 非叶子节点的位置
    
            for i in range(len(pos)):  # 给节点打标号,跳过ROOT
                if i == 0:
                    continue
                if type(t[pos[i]]) == str:
                    continue
                else:
                    t[pos[i]].set_label(j)
                    j += 1
            # t.draw()
            # for pos in t.treepositions('leaves'):  # 获取所有叶子节点的位置
            #     print(t[pos[:-1]].label())
    
            
            for pos in t.treepositions('leaves'):  # 获取所有叶子节点的位置
                leaves_word.append(t[pos])  # 存叶子节点的词
                node_num = len(pos)
                labels = []  # 存此叶子节点到根节点的所有节点的序号
                for i in range(1, node_num, 1):
                    labels.append(t[pos[:-i]].label())
                leaves_pos.append(labels[0])  # 第一个为叶子节点
                for k in range(len(labels)):
                    if labels[k] > node_sum:  # 记录最大节点编号
                        node_sum = labels[k]
                    try:
                        u_v = (labels[k], labels[k+1])
                        if u_v not in tree:
                            tree.append(u_v)
                        else:
                            break
                    except:
                        break
    
        # print(tree)
        # print(leaves_pos)
        # print(node_sum)
        # print(leaves_word)
    
        tree_graph = dgl.graph(tree)
        mask = torch.zeros(node_sum + 1)  # 不是叶子节点的被遮掩
        mask[leaves_pos] = 1
        tree_graph.ndata["mask"] = mask.long()  # 节点给上掩码
        node_pos = torch.arange(0, node_sum  + 1)
        tree_graph.ndata["node_pos"] = node_pos  # 节点标记位置 
        if d["label"] == 0:
            label = torch.zeros(node_sum + 1)  # 存句子的标签0
        else:
            label = torch.ones(node_sum + 1)  # 存句子的标签1
    
        tree_graph.ndata["y"] = label.long() 
        tree_graph.ndata["x"] = label.long()
    
        # print(tree_graph)
        graph_data.append(tree_graph)
        tree_word.append("#$#".join(leaves_word))
        if index % 1000 == 0: # 每隔1000次保存一次
            print(index)
            torch.save(graph_data, "graph_data_" + str(index) + ".pt") # 存图数据
            pd.DataFrame(tree_word).to_csv("tree_word_" + str(index) + ".tsv", sep="\t", index=0) # 存叶子节点的词
    
    nlp.close()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93

    当所有数据分析完成后(没有完成的话字典是建立不出来的,下面代码没有建立字典的过程,需要自己补充),通过叶子节点的记录建立对应的字典,然后将图数据结构中的叶子节点对应的词等等特征补充完整

    graph_data = torch.load("graph_data_0.pt")
    
    import json
    with open("train_data_dic.json") as f:
        dic = json.load(f)
    
    import pandas as pd
    import torch
    f = pd.read_csv(r"tree_word_0.tsv", sep="\t")
    for i in range(len(f)): # 赋值叶子节点(node_pos)
        words = f.loc[i][0].split("#$#")
        mask = graph_data[i].ndata["mask"]
        wordid = torch.zeros(len(mask))
    
        k = 0 # 词的下标
        for j in range(len(mask)):
            if mask[j] == 1:  # 没被mask就是对应的词
                try:
                    wordid[j] = dic[words[k]]
                except:
                    wordid[j] = dic["<unk>"]
                k += 1
        graph_data[i].ndata["x"] = wordid.long()
    print(graph_data)
    torch.save(graph_data, "graph_data.pt")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25

    例如:
    demo.tsv:

    text	label
    although i am not a golf fan, i attended a sneak preview of this movie and absolutely loved it.	1
    
    • 1
    • 2

    运行完成后得到的图数据graph_data如下:

    [Graph(num_nodes=38, num_edges=37,
          ndata_schemes={'mask': Scheme(shape=(), dtype=torch.int64), 'node_pos': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), 'x': Scheme(shape=(), dtype=torch.int64)}
          edata_schemes={})]
    
    • 1
    • 2
    • 3

    元素0为这个段话的图,共38个节点,37条边,节点的特征中有mask(用于屏蔽不是叶子节点的位置),node_pos节点标记位置 ,x存词的id, y存标签。

  • 相关阅读:
    MAC地址注册的网络安全影响和措施分析
    Spring AOP(JavaEE进阶系列5)
    第三章 内存管理 十三、页面置换算法(最佳置换算法、先进先出置换算法、最近最久未使用置换算法、时钟置换算法、改进型的时钟置换算法)
    STC51单片机30——单个数码管显示
    Hadoop之HDFS重点架构原理简介
    如何使用 Yolov4 训练人脸口罩检测模型
    不用Swagger,那我用啥?
    iPhone开发--Xcode中的ld64和-ld_classic是什么意思
    哈希表专项练习 LeetCode
    HTML5的 WebSocket, wedworker,RPC模拟的介绍和使用
  • 原文地址:https://blog.csdn.net/qq_52785473/article/details/125633716