• 【信号处理】基于CNN自编码器的心电信号异常检测识别(tensorflow)


    关于

    本项目使用卷积自编码器对异常心电（ECG）信号进行检测和识别：模型只用正常心拍训练，再以重建误差的大小判别信号是否异常，属于无监督学习在生理信号检测中的典型方法之一。

    工具

     

    方法实现

    读取心电信号
    # The PTBDB CSVs ship without a header row (one beat per row), so read with
    # header=None; otherwise pandas consumes the first beat as column names.
    # The last column is dropped — presumably the class label; confirm.
    normal_df = pd.read_csv("/heartbeat/ptbdb_normal.csv", header=None).iloc[:, :-1]
    anomaly_df = pd.read_csv("/heartbeat/ptbdb_abnormal.csv", header=None).iloc[:, :-1]
    normal_df.head()

    信号可视化

    def plot_sample(normal, anomaly, n_cases=2):
        """Plot randomly chosen beats from each class in two side-by-side panels.

        Args:
            normal: DataFrame of normal beats, one beat per row.
            anomaly: DataFrame of abnormal beats, one beat per row.
            n_cases: Number of random rows drawn (with replacement) from each frame.
        """
        fig, ax = plt.subplots(1, 2, sharey=True, figsize=(10, 4))
        panels = ((ax[0], normal, "Normal"), (ax[1], anomaly, "Anomaly"))
        for axis, frame, title in panels:
            # Sample from each frame's own length so the indices are always valid,
            # even when the two frames contain different numbers of beats.
            index = np.random.randint(0, len(frame), n_cases)
            for idx in index:
                axis.plot(frame.iloc[idx, :].values, label=f"Case {idx}")
            axis.legend(shadow=True, frameon=True, facecolor="inherit", loc=1, fontsize=9)
            axis.set_title(title)
        plt.tight_layout()
        plt.show()


    plot_sample(normal_df, anomaly_df)

     

     信号均值计算及可视化
    def plot_smoothed_mean(data, class_name="normal", step_size=5, ax=None):
        """Plot a rolling mean of a 1-D signal with a +/- 3 rolling-std band.

        Args:
            data: 1-D array-like signal (here, the column-wise mean beat of a class).
            class_name: Title shown on the axes.
            step_size: Rolling window length, in samples.
            ax: Matplotlib axes to draw on; a new figure/axes is created when None.

        Returns:
            The axes that was drawn on.
        """
        if ax is None:
            _, ax = plt.subplots()
        roll = pd.DataFrame(data).rolling(step_size)
        # The first step_size-1 windows are incomplete (NaN): drop them and re-index from 0.
        smoothed_mean = roll.mean().dropna().reset_index(drop=True)
        # NOTE: this std is taken over neighbouring points of the same curve, not
        # across samples, so the band shows local variability of the signal.
        smoothed_std = roll.std().dropna().reset_index(drop=True)
        margin = 3 * smoothed_std
        lower_bound = (smoothed_mean - margin).values.flatten()
        upper_bound = (smoothed_mean + margin).values.flatten()
        ax.plot(smoothed_mean.index, smoothed_mean)
        ax.fill_between(smoothed_mean.index, lower_bound, y2=upper_bound, alpha=0.3, color="red")
        ax.set_title(class_name, fontsize=9)
        return ax


    fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharey=True)
    axes = axes.flatten()
    # df and CLASS_NAMES are defined elsewhere; df is assumed to hold all beats
    # plus a "target" label column. Group once, outside the loop.
    data_group = df.groupby("target")
    for axis, label in zip(axes, CLASS_NAMES):
        # Drop the label column so its value is not averaged into the signal curve.
        group = data_group.get_group(label).drop(columns="target")
        data = group.mean(axis=0, numeric_only=True).to_numpy()
        plot_smoothed_mean(data, class_name=label, step_size=20, ax=axis)
    fig.suptitle("Plot of smoothed mean for each class", y=0.95, weight="bold")
    plt.tight_layout()

     训练/测试数据划分
    # Remove a label column if one is present (in place), then work on raw arrays.
    for frame in (normal_df, anomaly_df):
        frame.drop(columns="target", errors="ignore", inplace=True)
    normal, anomaly = normal_df.to_numpy(), anomaly_df.to_numpy()

    # Only normal beats are split: the autoencoder is trained on normal data alone.
    X_train, X_test = train_test_split(
        normal,
        test_size=0.15,
        random_state=45,
        shuffle=True,
    )
    print(f"Train shape: {X_train.shape}, Test shape: {X_test.shape}, anomaly shape: {anomaly.shape}")
    搭建自编码器
    class AutoEncoder(Model):
        """1-D convolutional autoencoder for fixed-length signals.

        Encoder: three Conv1D + BatchNorm + MaxPool(2) stages, each halving the
        sequence length. Decoder: three stride-1 Conv1DTranspose + BatchNorm
        stages (which keep the pooled length), then Flatten + Dense(input_dim)
        to project back to the original number of samples.

        Args:
            input_dim: Number of samples per input signal.
            latent_dim: Channel count of the bottleneck convolution.
            filters: Channel count of the non-bottleneck convolutions.
            kernel_size: Kernel width of every (transposed) convolution.
        """

        def __init__(self, input_dim, latent_dim, filters=128, kernel_size=3):
            super().__init__()
            self.input_dim = input_dim
            self.latent_dim = latent_dim
            self.encoder = tf.keras.Sequential([
                layers.Input(shape=(input_dim,)),
                layers.Reshape((input_dim, 1)),  # Conv1D needs a trailing channel axis
                layers.Conv1D(filters, kernel_size, strides=1, activation="relu", padding="same"),
                layers.BatchNormalization(),
                layers.MaxPooling1D(2, padding="same"),
                layers.Conv1D(filters, kernel_size, strides=1, activation="relu", padding="same"),
                layers.BatchNormalization(),
                layers.MaxPooling1D(2, padding="same"),
                layers.Conv1D(latent_dim, kernel_size, strides=1, activation="relu", padding="same"),
                layers.BatchNormalization(),
                layers.MaxPooling1D(2, padding="same"),
            ])
            # Stride-1 transposed convolutions do not upsample; the final Dense
            # layer is what restores the input_dim-sample output.
            self.decoder = tf.keras.Sequential([
                layers.Conv1DTranspose(latent_dim, kernel_size, strides=1, activation="relu", padding="same"),
                layers.BatchNormalization(),
                layers.Conv1DTranspose(filters, kernel_size, strides=1, activation="relu", padding="same"),
                layers.BatchNormalization(),
                layers.Conv1DTranspose(filters, kernel_size, strides=1, activation="relu", padding="same"),
                layers.BatchNormalization(),
                layers.Flatten(),
                layers.Dense(input_dim),
            ])

        def call(self, X, training=None):
            """Return the reconstruction of batch X.

            ``training`` is forwarded so the BatchNormalization layers switch
            between batch statistics (training) and moving averages (inference).
            """
            encoded = self.encoder(X, training=training)
            return self.decoder(encoded, training=training)


    input_dim = X_train.shape[-1]
    latent_dim = 32
    model = AutoEncoder(input_dim, latent_dim)
    model.build((None, input_dim))
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss="mae")
    model.summary()
    模型训练
    epochs, batch_size = 100, 128
    # Stop once val_loss fails to improve by 1e-3 for 10 epochs; keep the best weights.
    early_stopping = EarlyStopping(
        monitor="val_loss",
        patience=10,
        min_delta=1e-3,
        restore_best_weights=True,
    )
    history = model.fit(
        X_train,
        X_train,  # autoencoder: each input is its own reconstruction target
        epochs=epochs,
        batch_size=batch_size,
        validation_split=0.1,
        callbacks=[early_stopping],
    )
    训练可视化
    # Per-epoch losses recorded by model.fit; validation drawn dashed.
    curves = (
        ("loss", "Training loss", {}),
        ("val_loss", "Validation loss", {"ls": "--"}),
    )
    for key, curve_label, extra in curves:
        plt.plot(history.history[key], label=curve_label, **extra)
    plt.legend(shadow=True, frameon=True, facecolor="inherit", loc="best", fontsize=9)
    plt.title("Training loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

     

    信号重建可视化
    fig, axes = plt.subplots(2, 5, sharey=True, sharex=True, figsize=(12, 6))
    n_cols = axes.shape[1]
    rows = (
        (X_train, "Normal"),
        (anomaly, "anomaly"),
    )
    for row_axes, (dataset, title) in zip(axes, rows):
        # Draw indices from each dataset's own length so they are always in range.
        random_indexes = np.random.randint(0, len(dataset), size=n_cols)
        for axis, idx in zip(row_axes, random_indexes):
            # Double brackets keep a 2-D batch containing a single row.
            plot_examples(model, dataset[[idx]], ax=axis, title=title)
    plt.tight_layout()
    fig.suptitle("Sample plots (Actual vs Reconstructed by the CNN autoencoder)", y=1.04, weight="bold")
    # y=1.04 puts the title above the figure area; bbox_inches="tight" grows the
    # saved image to include it instead of clipping it off.
    fig.savefig("autoencoder.png", bbox_inches="tight")
    plt.show()

    计算重建MAE误差
    # model.evaluate returns the compiled loss ("mae") averaged over each dataset.
    train_mae = model.evaluate(X_train, X_train, verbose=0)
    test_mae = model.evaluate(X_test, X_test, verbose=0)
    # Use the NumPy array like the other two calls; it holds the same values as anomaly_df.
    anomaly_mae = model.evaluate(anomaly, anomaly, verbose=0)
    print("Training dataset error: ", train_mae)
    print("Testing dataset error: ", test_mae)
    print("Anomaly dataset error: ", anomaly_mae)

     异常检测阈值选取

    异常检测阈值取训练集（仅含正常信号）逐样本重建 MAE 的均值加一倍标准差：$\text{threshold} = \mu_{\text{MAE}} + \sigma_{\text{MAE}}$。重建 MAE 超过该阈值的信号判为异常心电信号，否则判为正常信号。

    def predict(model, X):
        """Reconstruct X with the autoencoder and score each sample.

        Args:
            model: Trained model exposing a Keras-style ``predict``.
            X: 2-D batch of signals, one row per sample.

        Returns:
            Tuple ``(pred, loss)``: the reconstructions and the per-sample
            reconstruction error from ``mae`` (assumed to be Keras' mean absolute
            error, reduced over the last axis; confirm against the file's imports).
        """
        pred = model.predict(X, verbose=0)
        # Keras loss functions take (y_true, y_pred); the original input is the target.
        loss = mae(X, pred)
        return pred, loss
    _, train_loss = predict(model, X_train)
    _, test_loss = predict(model, X_test)
    _, anomaly_loss = predict(model, anomaly)
    # Decision threshold: one standard deviation above the mean reconstruction
    # error of the normal-only training set.
    threshold = np.mean(train_loss) + np.std(train_loss)

    bins = 40
    plt.figure(figsize=(9, 5), dpi=100)
    histograms = (
        (train_loss, "Train Normal"),
        (test_loss, "Test Normal"),
        (anomaly_loss, "anomaly"),
    )
    for losses, hist_label in histograms:
        # Clip to [0, 0.5] so large errors pile up at the edge instead of stretching the x-axis.
        sns.histplot(np.clip(losses, 0, 0.5), bins=bins, kde=True, label=hist_label)

    ax = plt.gca()
    ylim = ax.get_ylim()
    top = ylim[-1]
    plt.vlines(threshold, 0, top, color="k", ls="--")
    plt.annotate(
        f"Threshold: {threshold:.3f}",
        xy=(threshold, top),
        xytext=(threshold + 0.009, top),
        arrowprops=dict(facecolor='black', shrink=0.05),
        fontsize=9,
    )
    plt.legend(shadow=True, frameon=True, facecolor="inherit", loc="best", fontsize=9)
    plt.show()

    模型评估
    # Same model/datasets/threshold are fed to both evaluation helpers.
    eval_args = (model, X_train, X_test, anomaly)
    plot_confusion_matrix(*eval_args, threshold=threshold)
    ytrue, ypred = prepare_labels(*eval_args, threshold=threshold)
    report = classification_report(ytrue, ypred, target_names=CLASS_NAMES)
    print(report)

     

    代码获取

    相关项目开发和问题,欢迎后台沟通交流。

  • 相关阅读:
    Nginx 高性能架构解析
    自动驾驶——软件和云服务介绍
    getBoundingClientRect使用场景(table固定表头)
    6、堆(新生区,永久区,堆内存调优(jvm调优))
    Redis使用lua脚本实现库存扣减
    Linux学习-52-Linux工作管理-后台命令运行管理
    WPF Binding对象、数据校验、数据转换
    rebase 和 merge合并代码
    DevOps|1024程序员节怎么做?介绍下我的思路
    java毕业生设计预约挂号系统演示录像2021计算机源码+系统+mysql+调试部署+lw
  • 原文地址:https://blog.csdn.net/YINTENAXIONGNAIER/article/details/138034878