• Trace a function defined in a model


    Reason

    The event file produced by tracing a whole model is too large to be opened in a browser, so tracing each function of the model separately is an alternative approach.

    Download pytorch_utils.py here.

    import torch
    import torch.nn as nn
    from pytorch_utils.modules import MLP
    
    # Small constant used to avoid log(0) and division by zero throughout.
    EPS = 1e-8
    
    
    class WorkingMemory(nn.Module):
        """Fixed-size slot memory run over a token sequence.

        Each of the ``num_cells`` slots holds a ``mem_size`` vector plus a
        scalar "usage" score. At every timestep the module scores whether the
        current hidden state is (part of) an entity, whether it corefers with
        an existing cell, or whether it should overwrite a (preferably unused)
        cell, and updates the memory vectors and usages accordingly.
        """

        def __init__(self, device='cpu', mem_type='vanilla', num_cells=10,
                     mem_size=300, mlp_size=300, dropout_rate=0.5,
                     key_size=20, usage_decay_rate=0.98, **kwargs):
            """Construct the memory module.

            Args:
                device: device string used for tensors created on the fly.
                mem_type: 'vanilla' (zero-initialized memory), 'learned'
                    (learned initial memory vectors) or 'key_val' (learned
                    key part + zero value part).
                num_cells: number of memory slots (M below).
                mem_size: dimensionality of each memory vector (H below).
                mlp_size: hidden size of the scoring MLPs.
                dropout_rate: dropout applied to the query in forward().
                key_size: size of the key section (only used for 'key_val').
                usage_decay_rate: multiplicative per-step decay of usage.
                **kwargs: ignored; absorbs extra configuration entries.
            """
            super(WorkingMemory, self).__init__()
            self.device = device
            self.mem_type = mem_type
            self.num_cells = num_cells
            self.mem_size = mem_size
            self.mlp_size = mlp_size
    
            self.usage_decay_rate = usage_decay_rate
            # Dropout module
            self.drop_module = nn.Dropout(p=dropout_rate, inplace=False)
    
            if self.mem_type == 'learned':
                self.init_mem = nn.Parameter(torch.zeros(self.num_cells, self.mem_size))
            elif self.mem_type == 'key_val':
                self.key_size = key_size
                # Initialize key and value vectors
                self.init_key = nn.Parameter(torch.zeros(self.num_cells, key_size))
    
            # MLP to determine entity or not
            self.entity_mlp = MLP(self.mem_size, mlp_size, 1, num_layers=2, bias=True)
            # MLP to merge past memory and current candidate to write new memory
            self.U_key = nn.Linear(2 * mem_size, mem_size, bias=True)
            # MLP to determine coref similarity between current token and memory
            self.sim_mlp = MLP(3 * self.mem_size + 1, mlp_size, 1, num_layers=2, bias=True)
    
            # Fixed (non-trainable) temperature for the Gumbel-softmax below.
            self.gumbel_temperature = nn.Parameter(torch.tensor([1.0]), requires_grad=False)
    
        def initialize_memory(self, batch_size):
            """Return (init_mem, init_usage) for a fresh batch.

            init_mem is B x M x H: zeros, learned vectors, or learned key +
            zero value section, depending on ``mem_type``. init_usage is
            B x M and all zero, i.e. no cell has been written yet.
            """
            init_mem = torch.zeros(batch_size, self.num_cells, self.mem_size).to(self.device)
            if self.mem_type == 'learned':
                init_mem = self.init_mem.unsqueeze(dim=0)
                init_mem = init_mem.repeat(batch_size, 1, 1)
            elif self.mem_type == 'key_val':
                init_val = torch.zeros(batch_size, self.num_cells,
                                       self.mem_size - self.key_size).to(self.device)
                init_key = self.init_key.unsqueeze(dim=0)
                init_key = init_key.repeat(batch_size, 1, 1)
    
                init_mem = torch.cat([init_key, init_val], dim=2)
    
            init_usage = torch.zeros(batch_size, self.num_cells).to(self.device)
            return (init_mem, init_usage)
    
        def sample_gumbel(self, shape, eps=EPS):
            """Sample Gumbel(0, 1) noise of ``shape`` (for the Gumbel-softmax)."""
            U = torch.rand(shape).to(self.device)
            return -torch.log(-torch.log(U + eps) + eps)
    
        def pick_overwrite_cell(self, usage, sim_score):
            """Pick cell to overwrite.
            - Prefer unused cells.
            - Break ties using similarity.

            Returns a B x M distribution over cells: soft (Gumbel-softmax)
            in training mode, hard one-hot at eval time.
            """
            norm_sim_score = nn.functional.softmax(sim_score, dim=-1)
            # Assign overwrite scores to each cell.
            # (1) Prefer cells which have not been used.
            # (2) Among the unused cells, prefer ones with the higher similarily score
            #     (Useful for memory with learned initialization).
            # (3) Otherwise prefer cells with least usage.
            overwrite_score = ((usage == 0.0).float() * norm_sim_score * 1e5) + (1 - usage)
    
            if self.training:
                # Differentiable choice via the Gumbel-softmax relaxation.
                logits = torch.log(overwrite_score * (1 - EPS) + EPS)
                gumbel_noise = self.sample_gumbel(usage.size())
                y = nn.functional.softmax(
                    (logits + gumbel_noise) / self.gumbel_temperature, dim=-1)
            else:
                # Hard one-hot argmax, breaking ties randomly among the
                # cells that share the maximum score.
                max_val = torch.max(overwrite_score, dim=-1, keepdim=True)[0]
                # Randomize the max
                index = torch.argmax(
                    (torch.empty(overwrite_score.shape).uniform_(0.01, 1).to(self.device)
                        * (overwrite_score == max_val).float()),
                    dim=-1, keepdim=True)
                y = torch.zeros_like(overwrite_score).scatter_(-1, index, 1.0)
            return y
    
        def predict_entity_prob(self, cur_hidden_state):
            """Predicts whether the current word is (part of) an entity or not.

            Returns (ent_score, ent_prob), each B x 1; the 2-way softmax
            against a constant 0 makes ent_prob equal sigmoid(ent_score).
            """
            ent_score = self.entity_mlp(cur_hidden_state)
    
            # Perform a softmax over scores of 0 and ent_score
            comb_score = torch.cat([torch.zeros_like(ent_score).to(self.device),
                                    ent_score], dim=1)
            # Numerically stable softmax
            max_score, _ = torch.max(comb_score, dim=1, keepdim=True)
            ent_prob = nn.functional.softmax(comb_score - max_score, dim=1)
            # We only care about the 2nd column i.e. corresponding to ent_score
            ent_prob = torch.unsqueeze(ent_prob[:, 1], dim=1)
    
            return ent_score, ent_prob
    
        def get_coref_mask(self, usage):
            """No coreference with empty cells."""
            # A cell is a valid coref target only once something was written to it.
            cell_mask = (usage > 0).float().to(self.device)
            return cell_mask
    
        def predict_coref_overwrite(self, mem_vectors, query_vector, usage,
                                    ent_prob):
            """Calculate similarity between query_vector and mem_vectors.
            query_vector: B x M x H
            mem_vectors: B x M x H

            Returns:
                indv_coref_prob: B x M probability of coref with each cell.
                overwrite_prob: B x 1 probability mass of the extra
                    "new entity" option (last column of the softmax).
            """
            pairwise_vec = torch.cat([mem_vectors, query_vector,
                                      query_vector * mem_vectors,
                                      torch.unsqueeze(usage, dim=2)], dim=-1)
            pairwise_score = self.sim_mlp(pairwise_vec)
    
            sim_score = pairwise_score  # B x M x1
            sim_score = torch.squeeze(sim_score, dim=-1)
    
            batch_size = query_vector.shape[0]
            # Constant-zero logit for the "new entity / overwrite" option.
            base_score = torch.zeros((batch_size, 1)).to(self.device)
            comb_score = torch.cat([sim_score, base_score], dim=1)
            # Bx(M+1)
            coref_mask = self.get_coref_mask(usage)  # B x M
            # Coref only possible when the cell is active
            mult_mask = torch.cat([coref_mask,
                                   torch.ones((batch_size, 1)).to(self.device)], dim=-1)
            # Zero out the inactive cell scores and then add a big negative value
            comb_score = comb_score * mult_mask + (1 - mult_mask) * (-1e4)
    
            # Numerically stable softmax
            max_cell_score, _ = torch.max(comb_score, dim=1, keepdim=True)
            init_probs = nn.functional.softmax(comb_score - max_cell_score, dim=1)
    
            # Make sure the inactive cells are really zero even after logit of -1e4
            masked_probs = init_probs * mult_mask
            norm_probs = (
                masked_probs/(torch.sum(masked_probs, dim=-1, keepdim=True) + EPS))
    
            # Scale by the probability that the token is an entity at all.
            coref_over_probs = ent_prob * norm_probs
            indv_coref_prob = coref_over_probs[:, :self.num_cells]
    
            overwrite_prob = coref_over_probs[:, self.num_cells]
            overwrite_prob = torch.unsqueeze(overwrite_prob, dim=1)
    
            return indv_coref_prob, overwrite_prob
    
        def forward(self, data_w):
            """Run the memory over a sequence of hidden states.

            Args:
                data_w: tuple (hidden_state_list, input_mask_list);
                    hidden_state_list: list of B x H tensors (one per step)
                    input_mask_list: list of B sized tensors

            Returns:
                Tuple of per-timestep lists (ent, usage, coref, overwrite),
                or None when the probability sanity check fails.
            """
            hidden_state_list = data_w[0]
            input_mask_list = data_w[1]
            batch_size = hidden_state_list[0].shape[0]
    
            if self.mem_type == 'key_val':
                # Get initialized key vectors
                init_key = self.init_key.unsqueeze(dim=0)
                init_key = init_key.repeat(batch_size, 1, 1)
    
            # Initialize memory
            mem_vectors, usage = self.initialize_memory(batch_size)
    
            # Store all updates
            ent_list, usage_list, coref_list, overwrite_list = [], [], [], []
    
            for t, (cur_hidden_state, cur_input_mask) in \
                    enumerate(zip(hidden_state_list, input_mask_list)):
                query_vector = self.drop_module(cur_hidden_state)
    
                # Entity probability, zeroed out for masked (padding) positions.
                ent_score, ent_prob = self.predict_entity_prob(query_vector)
                ent_prob = ent_prob * torch.unsqueeze(cur_input_mask, dim=1)
                ent_list.append(ent_prob * (1 - EPS) + EPS)
    
                rep_query_vector = query_vector.unsqueeze(dim=1)
                # B x M x H
                rep_query_vector = rep_query_vector.repeat(1, self.num_cells, 1)
    
                # NOTE: new_ent_prob is the "overwrite" mass returned as the
                # second element of predict_coref_overwrite.
                indv_coref_prob, new_ent_prob = self.predict_coref_overwrite(
                    mem_vectors=mem_vectors, query_vector=rep_query_vector,
                    usage=usage, ent_prob=ent_prob)
    
                coref_list.append(indv_coref_prob * (1 - EPS) + EPS)
    
                # Overwriting Prob - B x M
                pairwise_vec = torch.cat([mem_vectors, rep_query_vector,
                                          rep_query_vector * mem_vectors,
                                          torch.unsqueeze(usage, dim=2)], dim=-1)
                init_sim_score = torch.squeeze(self.sim_mlp(pairwise_vec), dim=-1)
                overwrite_prob = (
                    new_ent_prob * self.pick_overwrite_cell(usage, init_sim_score)
                )
                # Sanity check: all quantities below are probabilities and
                # must not exceed 1. On failure, bail out with None so the
                # caller can skip this batch instead of crashing.
                try:
                    assert (torch.max(overwrite_prob) <= 1)
                    assert (torch.max(indv_coref_prob) <= 1)
                    assert (torch.max(ent_prob) <= 1)
                except AssertionError:
                    print("Assertion Error happened! Trying best to recover")
                    return None
                    # raise
    
                overwrite_list.append(overwrite_prob * (1 - EPS) + EPS)
    
                comb_inp = torch.cat([rep_query_vector, mem_vectors], dim=-1)
                mem_candidate = torch.tanh(self.U_key(comb_inp))
                # B x M x H
                # Mixture per cell: overwrite with the query, keep the old
                # memory, or (on coref) write the merged candidate.
                updated_mem_vectors = (
                    torch.unsqueeze(overwrite_prob, dim=2) * rep_query_vector
                    + torch.unsqueeze(1 - overwrite_prob - indv_coref_prob, dim=2)
                    * mem_vectors
                    + torch.unsqueeze(indv_coref_prob, dim=2) * mem_candidate
                )
    
                if self.mem_type == 'key_val':
                    # Don't update the key dimensions. Only update the later dimensions.
                    updated_mem_vectors = torch.cat(
                        [init_key, updated_mem_vectors[:, :, self.key_size:]], dim=2)
    
                # Update usage
                # Decay the old usage and add this step's written mass,
                # clamped at 1.
                updated_usage = torch.min(
                    torch.FloatTensor([1.0]).to(self.device),
                    overwrite_prob + indv_coref_prob + self.usage_decay_rate * usage)
                usage_list.append(updated_usage)
                # Update memory
                mem_vectors, usage = updated_mem_vectors, updated_usage
    
            # return {'ent': ent_list, 'usage': usage_list,
            #         'coref': coref_list, 'overwrite': overwrite_list}
            return (ent_list, usage_list, coref_list, overwrite_list)
        
    class PICK(nn.Module):
        """Thin wrapper that exposes ``pick_overwrite_cell`` as a module.

        Wrapping the single function in an ``nn.Module`` makes it traceable
        on its own (e.g. with ``SummaryWriter.add_graph``).
        """

        def __init__(self, memory):
            super(PICK, self).__init__()
            # Object providing pick_overwrite_cell (e.g. a WorkingMemory).
            self.memory = memory

        def forward(self, pick_data):
            """Unpack (usage, sim_score) and delegate to the wrapped memory."""
            cell_usage, similarity = pick_data
            return self.memory.pick_overwrite_cell(cell_usage, similarity)
    
    if __name__ == "__main__":
    
        from torch.utils.tensorboard import SummaryWriter
    
        # Dummy sequence input: 99 timesteps of (batch=32, hidden=300)
        # states with matching all-zero input masks.
        hidden_state_list = tuple(list(torch.randn(99, 32, 300)))
        input_mask_list = tuple(list(torch.zeros(99, 32, )))
        data_w = (hidden_state_list, input_mask_list)
    
        # Wrap the memory so only pick_overwrite_cell is traced, keeping
        # the resulting event file small enough to open in a browser.
        memory = WorkingMemory()
        model = PICK(memory)
    
        pick_data = (torch.zeros(32, 10), torch.randn(32, 10))
        result = model(pick_data)
    
        # Writer that persists the traced graph for TensorBoard.
        writer = SummaryWriter('structure/pick_overwrite_cell')
        writer.add_graph(model, (pick_data, ))
        writer.add_text("text", "hello, this is a text info", global_step=2)
        writer.close()
    
    

    (Image placeholder — the original post shows a screenshot here.)

  • Related reading:
    分布式session的4种解决方案
    【每日一题Day360】LC1465切割后面积最大的蛋糕 | 贪心
    Unity-网格编程
    【实战案例】Python 信用卡欺诈检测其实特简单
    LeetCode LCP 06. 拿硬币【贪心,数学】简单
    python 中面向对象编程:深入理解封装、继承和多态
    java正则表达式 及应用场景爬虫,捕获分组非捕获分组
    【JDBC】Apache-DBUtils使用指南
    使用 HTML CSS 和 JavaScript 创建星级评分系统
    手机照片怎么恢复?10个照片恢复应用程序
  • Original article: https://blog.csdn.net/qq_45911550/article/details/134025387