• 神经网络中间层特征图可视化(输入为音频)(二)


    相比方法一个人感觉这种方法更好

    import librosa
    import numpy as np
    import utils
    import torch
    import matplotlib.pyplot as plt
    
    class Hook:
        def __init__(self):
            self.features = None
    
        def hook_fn(self, module, input, output):
            self.features = output
    
    # 创建钩子的实例
    hook = Hook()
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    def extract_mbe(_y, _sr, _nfft, _nb_mel):
        #梅尔频谱
        spec = librosa.core.spectrum._spectrogram(y=_y, n_fft=_nfft, hop_length=_nfft // 2, power=1)[0]
        mel_basis = librosa.filters.mel(sr=_sr, n_fft=_nfft, n_mels=_nb_mel)
        mel_spec = np.log(np.dot(mel_basis, spec).T)
        return mel_spec       #最后必须是[frames, dimensions]
    
    def preprocess_data(X, seq_len, nb_ch):
        # split into sequences
        X = utils.split_in_seqs(X, seq_len)
        X = utils.split_multi_channels(X, nb_ch)
        # Convert to PyTorch tensors
        X = torch.Tensor(X)
        X = X.permute(0,1,3,2)   #x形状为[709,2,40,256],【总样本数,通道数,特征维度,像素宽度】
        return X
    
    # 提取梅尔频谱特征
    audio_path = "a011.wav"
    y, sr = librosa.load(audio_path, sr=44100)
    mel = extract_mbe(y, sr, 2048, 64)
    
    value = preprocess_data(mel, 256, 1).to(device)     #value 为输入模型的样本特征
    
    
    model = torch.load(f'best_model_2.pth')
    
    # 将钩子注册到需要的层
    model.cnn1.register_forward_hook(hook.hook_fn)
    
    # 假设`input_data`是你的输入张量
    output = model(value)
    
    # 访问存储的特征
    retnet_features = hook.features
    #print(retnet_features.shape)
    # 可视化特征(假设retnet_features是一个张量)
    retnet_features = retnet_features.permute(0, 2, 1, 3)
    #retnet_features = retnet_features.transpose(1, 2)
    #print(retnet_features.shape)
    retnet_features = torch.cat([retnet_features[i] for i in range(10)], dim=2)
    #print(retnet_features.shape)
    
    # 可视化批次中第一个样本的特定通道
    plt.imshow(retnet_features.sum(1).detach().cpu().numpy(), cmap='viridis', origin='lower')   #[高,通道, 宽]
    # plt.imshow(retnet_features.detach().cpu().numpy(), cmap='viridis', origin='lower')   #[高,宽]
    plt.show()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
  • 相关阅读:
    微信小程序入门2
    uniapp 全部权限
    SpringBoot整合SpringSecurity [超详细] (二)获取用户信息
    大型 APP 的性能优化思路
    华清 c++ day5 9月12
    2023-11-14 mysql-LOGICAL_CLOCK 并行复制原理及实现分析
    Check for degenerate boxes检查退化框
    【APP VTable】和市面上的 Table 组件一样,都是接收表格[] 以及数据源[]
    学信息系统项目管理师第4版系列20_风险管理
    Nginx 优化
  • 原文地址:https://blog.csdn.net/qq358660877/article/details/134556277