• 基于DTW算法的命令字识别


    DTW算法介绍

    DTW(Dynamic Time Warping,动态时间规整):按照距离最近的原则,在两个序列之间建立对应关系,并据此评估两个序列的相似性。

    要求:

    • 单向对应,不能回头;
    • 一一对应,不能有空;
    • 对应之后,距离最近。

     

    DTW代码实现

    1. import numpy as np
    def dis_abs(x, y):
        """Return the absolute difference between two one-element frames.

        The result of ``x - y`` is indexed with ``[0]``, so both arguments must
        be array-like with at least one element; only element 0 is returned.
        """
        return abs(x - y)[0]


    def estimate_twf(A, B, dis_func=dis_abs):
        """Align two sequences with dynamic time warping.

        Args:
            A: First sequence; ``A[i]`` is passed to ``dis_func``.
            B: Second sequence; ``B[j]`` is passed to ``dis_func``.
            dis_func: Callable returning the local distance between one frame
                of ``A`` and one frame of ``B``.

        Returns:
            A tuple ``(mean, path, D)`` where ``D`` is the accumulated-cost
            matrix of shape ``(len(A), len(B))``, ``path`` is the list of
            ``(i, j)`` index pairs from ``(0, 0)`` to the last cell, and
            ``mean`` is the sum of the per-step cost increments along the path
            divided by the number of path points.

        Raises:
            IndexError: If ``A`` or ``B`` is empty.
        """
        n_a = len(A)
        n_b = len(B)
        D = np.zeros([n_a, n_b])
        D[0, 0] = dis_func(A[0], B[0])
        # First column: the only predecessor is the cell above.
        for i in range(1, n_a):
            D[i, 0] = D[i - 1, 0] + dis_func(A[i], B[0])
        # First row: the only predecessor is the cell to the left.
        for j in range(1, n_b):
            D[0, j] = D[0, j - 1] + dis_func(A[0], B[j])
        # Interior: extend the cheapest of the three predecessors.
        for i in range(1, n_a):
            for j in range(1, n_b):
                D[i, j] = dis_func(A[i], B[j]) + min(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1])

        # Backtrack from the last cell to (0, 0).
        i, j = n_a - 1, n_b - 1
        cnt = 0
        # A path visits at most n_a + n_b - 1 cells, so this buffer always fits.
        d = np.zeros(max(n_a, n_b) * 3)
        path = []
        while True:
            path.append((i, j))
            if i == 0 and j == 0:
                d[cnt] = D[0, 0]
                cnt += 1
                break
            if i > 0 and j > 0:
                up, left, diag = D[i - 1, j], D[i, j - 1], D[i - 1, j - 1]
                m = min(up, left, diag)
                # Tie-break order: diagonal, then left, then up.
                if m == diag:
                    prev = (i - 1, j - 1)
                elif m == left:
                    prev = (i, j - 1)
                else:
                    # Unconditional fallback: with NaN costs no equality test
                    # succeeds, and the walk must still move to terminate.
                    prev = (i - 1, j)
            elif i == 0:
                prev = (i, j - 1)
            else:
                prev = (i - 1, j)
            d[cnt] = D[i, j] - D[prev]
            cnt += 1
            i, j = prev
        mean = np.sum(d) / cnt
        return mean, path[::-1], D
    # Demo: align two 10-sample integer sequences. Each one is reshaped into a
    # column so that every frame is a one-element array, as dis_abs indexes [0].
    a = np.array([1, 3, 4, 9, 8, 2, 1, 5, 7, 3]).reshape(-1, 1)
    b = np.array([1, 6, 2, 3, 0, 9, 4, 1, 6, 3]).reshape(-1, 1)
    dis, path, D = estimate_twf(a, b, dis_func=dis_abs)
    print(dis, path, D)
    7. >>:
    8. 1.0833333333333333
    9. [(0, 0), (1, 1), (1, 2), (1, 3), (2, 4), (3, 5), (4, 5), (5, 6), (6, 7), (7, 8), (8, 8), (9, 9)]
    10. [[ 0. 5. 6. 8. 9. 17. 20. 20. 25. 27.]
    11. [ 2. 3. 4. 4. 7. 13. 14. 16. 19. 19.]
    12. [ 5. 4. 5. 5. 8. 12. 12. 15. 17. 18.]
    13. [13. 7. 11. 11. 14. 8. 13. 20. 18. 23.]
    14. [20. 9. 13. 16. 19. 9. 12. 19. 20. 23.]
    15. [21. 13. 9. 10. 12. 16. 11. 12. 16. 17.]
    16. [21. 18. 10. 11. 11. 19. 14. 11. 16. 18.]
    17. [25. 19. 13. 12. 16. 15. 15. 15. 12. 14.]
    18. [31. 20. 18. 16. 19. 17. 18. 21. 13. 16.]
    19. [33. 23. 19. 16. 19. 23. 18. 20. 16. 13.]]

    基于DTW算法的命令字识别

    utils.py:

    1. # -*- coding:UTF-8 -*-
    2. import streamlit as st
    3. import pyaudio
    4. import wave
    5. import librosa
    6. import soundfile as sf
    7. import numpy as np
    8. import os
    9. import time
    # Distance between MFCC feature frames (MCD-style scaling)
    def euclideanDistance(a, b):
        """Return ``10 / ln(10) * sqrt(2 * sum((a - b) ** 2))`` for two frames."""
        delta = a - b
        squared_sum = np.sum(delta ** 2)
        return 10.0 / np.log(10) * np.sqrt(2.0 * squared_sum)
    # DTW matching distance
    class DTW:
        """Dynamic-time-warping distance between two frame sequences."""

        def __init__(self, disFunc=euclideanDistance):
            # Callable giving the local cost between one frame of each sequence.
            self.disFunc = disFunc

        def compute_distance(self, reference, test):
            """Return the accumulated DTW cost divided by the two lengths summed.

            Both arguments are indexed as ``x[k, :]`` and sized through
            ``x.shape[0]``, i.e. one frame per row.
            """
            n_ref, n_test = reference.shape[0], test.shape[0]
            acc = np.full([n_ref, n_test], np.inf)
            # Seed so that the first cell ends up holding only its local cost.
            acc[0, 0] = 0
            for row in range(n_ref):
                prev_row = max(row - 1, 0)
                for col in range(n_test):
                    prev_col = max(col - 1, 0)
                    frame_cost = self.disFunc(reference[row, :], test[col, :])
                    # Indices are clamped at the border; the still-infinite
                    # current cell is then never the minimum there.
                    acc[row, col] = frame_cost + min(
                        acc[prev_row, col],
                        acc[row, prev_col],
                        acc[prev_row, prev_col],
                    )
            return acc[-1, -1] / (n_test + n_ref)
    # Voice recording
    class wordRecorder:
        """Record a fixed-length mono 16-bit clip from the default input device."""

        def __init__(self, samplingFrequency=8000, threshold=20):
            self.samplingFrequency = samplingFrequency
            # Stored for callers; no method of this class reads it.
            self.threshold = threshold

        def record(self):
            """Capture about four seconds of audio.

            Returns:
                A list of raw byte chunks, each read as 1024 frames.
            """
            chunk = 1024
            n_chunks = int(self.samplingFrequency * 4 / chunk)
            p = pyaudio.PyAudio()
            try:
                stream = p.open(format=pyaudio.paInt16, channels=1, rate=self.samplingFrequency,
                                input=True, output=False, frames_per_buffer=chunk)
                try:
                    frames = [stream.read(chunk) for _ in range(n_chunks)]
                finally:
                    # Release the device even if a read fails part-way through.
                    stream.stop_stream()
                    stream.close()
            finally:
                p.terminate()
            return frames

        def record2File(self, path):
            """Record a clip and write it to ``path`` as a WAV file."""
            frames = self.record()
            with wave.open(path, 'wb') as wf:
                wf.setnchannels(1)
                # Module-level helper: no second PyAudio instance is created
                # (the previous one was never terminated).
                wf.setsampwidth(pyaudio.get_sample_size(pyaudio.paInt16))
                wf.setframerate(self.samplingFrequency)
                wf.writeframes(b''.join(frames))
            print('record finished!')
    # Extract MFCC features
    def getmfcc(audio, isfile=True):
        """Compute MFCC features with first- and second-order deltas.

        Args:
            audio: Passed to ``librosa.load`` when ``isfile`` is true; otherwise
                converted to samples with ``np.array``.
            isfile: Selects between loading from a file and using ``audio``
                directly as sample data.

        Returns:
            The transpose of the row-wise concatenation of the MFCCs (first row
            dropped), their deltas and their second-order deltas.
        """
        # NOTE(review): the listing this was recovered from had lost its
        # indentation. Silence removal is placed in the ``else`` branch because
        # the original comment attaches it to raw audio data - confirm.
        if isfile:
            # Read the audio file
            y, fs = librosa.load(audio, sr=8000)
        else:
            # Raw audio data: silence needs to be removed
            y = np.array(audio)
            intervals = librosa.effects.split(y, top_db=20)
            y = librosa.effects.remix(y, intervals)
        # Pre-emphasis
        y = librosa.effects.preemphasis(y)
        # Overrides whatever rate was bound above; 8000 matches the load call.
        fs = 8000
        N_fft = 256
        win_length = 256
        hop_length = 128
        n_mels = 23
        n_mfcc = 14
        # MFCC extraction
        mfcc = librosa.feature.mfcc(y=y, sr=fs, n_mfcc=n_mfcc, n_mels=n_mels, n_fft=N_fft, win_length=win_length,
                                    hop_length=hop_length)
        # Drop the first row of the result.
        mfcc = mfcc[1:, :]
        # Add the delta features
        mfcc_deta = librosa.feature.delta(mfcc)
        mfcc_deta2 = librosa.feature.delta(mfcc, order=2)
        # Stack the three feature sets along axis 0
        mfcc_d1_d2 = np.concatenate([mfcc, mfcc_deta, mfcc_deta2], axis=0)
        return mfcc_d1_d2.T
    # Count the recordings stored for one command word
    def check_file(name):
        """Return how many ``.wav`` files exist under ``data/<name>``.

        Both ``data`` and ``data/<name>`` are created if missing, and the count
        includes files in any nested subfolders.
        """
        os.makedirs('data', exist_ok=True)
        save_dir = os.path.join('data', name)
        os.makedirs(save_dir, exist_ok=True)
        return sum(
            1
            for _root, _dirs, filenames in os.walk(save_dir)
            for filename in filenames
            if filename.endswith('.wav')
        )
    @st.cache_resource  # prevents the models from being reloaded
    def model_load():
        """Build one ModelHotWord per command word, in up/down/left/right order."""
        labels = ['向上', '向下', '向左', '向右']
        return [ModelHotWord(os.path.join('data', label)) for label in labels]
    class ModelHotWord:
        """Template set for one command word."""

        def __init__(self, path):
            # Feature matrices returned by get_train_mfcc_list for ``path``.
            self.mfccs = get_train_mfcc_list(path)

        def get_score(self, ref_mfcc):
            """Score ``ref_mfcc`` against this word's stored templates."""
            templates = self.mfccs
            # Resolves to the module-level get_score, not this method.
            return get_score(ref_mfcc, templates)
    def get_train_mfcc_list(data_path):
        """Extract features for every ``.wav`` file under ``data_path``.

        Args:
            data_path: Directory to scan, including its subfolders.

        Returns:
            A list with one ``getmfcc`` result per file found; empty when the
            directory is missing or holds no ``.wav`` files.
        """
        mfccs = []
        for root, _dirs, files in os.walk(data_path):
            for file in files:
                # '.wav' matches the suffix test used by check_file.
                if file.endswith('.wav'):
                    # Join with the walked directory, not data_path, so files
                    # in subfolders resolve to an existing path.
                    mfccs.append(getmfcc(os.path.join(root, file)))
        return mfccs
    def get_score(ref_mfcc, list_mfccs):
        """Return the mean DTW distance from ``ref_mfcc`` to each template.

        Args:
            ref_mfcc: Feature matrix of the utterance being scored.
            list_mfccs: Feature matrices of the stored templates.

        Returns:
            The average of ``DTW().compute_distance`` over all templates, or
            ``inf`` when there are none, so a word without recordings can never
            be the best match (this used to raise ZeroDivisionError).
        """
        n_templates = len(list_mfccs)
        if n_templates == 0:
            return float('inf')
        m_dtw = DTW()
        total = 0
        for template in list_mfccs:
            total = total + m_dtw.compute_distance(ref_mfcc, template)
        return total / n_templates

    DTW.py:

    1. # -*- coding:UTF-8 -*-
    2. from utils import *
    st.title('基于DTW算法的命令字识别')
    tab1, tab2 = st.tabs(['音频录制', '识别演示'])
    # Recording tab: record, undo and play back the samples of one command word.
    with tab1:
        list_labs = ['向上', '向下', '向左', '向右']
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            name = st.selectbox('模型选择', list_labs)
        with col2:
            st.write('命令字录制')
            flag_record = st.button(label='录音')
        with col3:
            st.write('命令字重录')
            flag_cancel = st.button(label='撤销')
        with col4:
            st.write('试听')
            flag_show_audios = st.button(label='试听')
        # Placeholders so later branches can overwrite the messages in place.
        info_file_number = st.empty()
        info_file_number.write('命令字---%s--已有%d个样本'%(name, check_file(name)))
        info_audios = st.empty()
        info_success = st.empty()
        save_dir = os.path.join('data', name)
        if flag_record:
            info_audios.info('')
            info_success.success('')
            n_files = check_file(name)
            info_audios.info('开始录制---第%d个命令字---%s--请在2s内完成录制.....'%(n_files + 1, name))
            audio_name = os.path.join(save_dir, '%d.wav'%(n_files + 1))
            wRec = wordRecorder()
            wRec.record2File(audio_name)
            info_success.success('录制完成,保存为' + audio_name)
            # Refresh the counter, which was written before the new recording.
            info_file_number.write('命令字---%s--已有%d个样本'%(name, check_file(name)))
        if flag_cancel:
            n_files = check_file(name)
            file_del = os.path.join(save_dir, str(n_files)+'.wav')
            # With no samples this would point at a non-existent '0.wav'.
            if n_files > 0 and os.path.isfile(file_del):
                os.remove(file_del)
            info_file_number.write('命令字--%s--已有%d个样本'%(name, check_file(name)))
        if flag_show_audios:
            n_files = check_file(name)
            for i in range(n_files):
                # The context manager closes each file once it has been read.
                with open(os.path.join(save_dir, '%d.wav'%(i+1)), 'rb') as audio_file:
                    audio_bytes = audio_file.read()
                # 'audio/' is not a valid MIME type; the recordings are WAV.
                st.audio(audio_bytes, format='audio/wav')
    # Recognition tab: record repeatedly and display the detected command word.
    with tab2:
        # Scores at or above this value are treated as "no command detected".
        th = 125
        st.write('识别演示')
        if 'run' not in st.session_state:
            st.session_state['run'] = False

        def start_listening():
            st.session_state['run'] = True

        def stop_listening():
            st.session_state['run'] = False

        col1, col2 = st.columns(2)
        with col1:
            st.button('开始检测', on_click=start_listening)
        with col2:
            st.button('停止检测', on_click=stop_listening)
        det_word = st.empty()

        def make_callback(label):
            """Return a callback that shows ``label`` in the result placeholder."""
            def show():
                det_word.write(label)
            return show

        # Same order as the models returned by model_load().
        callbacks = [make_callback(label) for label in ('向上', '向下', '向左', '向右')]
        # Load the prediction models (pre-extracted MFCC templates)
        models = model_load()
        dic_labs = {'0': '向上', '1': '向下', '2': '向左', '3': '向右', '-1': ''}
        wRec = wordRecorder()
        while st.session_state['run']:  # detection loop
            wRec.record2File('data/test.wav')
            ref_mfcc = getmfcc('data/test.wav', True)
            # Score against every model; the lowest score is the candidate.
            scores = [model.get_score(ref_mfcc) for model in models]
            i_word = np.argmin(scores)
            score = np.min(scores)
            print(i_word, score)
            if score < th:
                i_det_word = i_word
                callbacks[i_det_word]()
                print('---------det word---------', dic_labs[str(i_det_word)])

    在命令行中运行 streamlit run DTW.py,即可在浏览器中打开 Web 界面,结果如下图所示:

    参考DTW关键字检测-代码实现_哔哩哔哩_bilibili

  • 相关阅读:
    最小体力消耗路径 -- dijkstra算法应用
    AIGC(生成式AI)试用 6 -- 从简单到复杂
    刷题笔记26——图论二分图判定
    Maven常用命令、坐标、依赖管理、依赖范围
    【Spring】Spring的JdbcTemplate
    使用kettle进行正则表达式组件日志分析
    Linux友人帐之网络编程基础FTP服务器
    LeetCode 450.删除二叉搜索树中的节点和669.修建二叉搜索树思路对比 及heap-use-after-free问题解决
    第十四章 配置国家语言支持 (NLS)
    记录扩充linux服务器centos-root目录过程
  • 原文地址:https://blog.csdn.net/qq_24946843/article/details/133357792