当前位置: 首页 > news >正文

一维卷积神经网络的特征可视化

随着以深度学习为代表的人工智能技术的不断发展,许多具有重要意义的深度学习模型和算法被开发出来,应用于计算机视觉、自然语言处理、语音处理、生物医疗、金融应用等众多行业领域。深度学习先进的数据挖掘、训练和分析能力来源于深度神经网络的海量模型参数以及高度非线性。也正因为深度学习算法的高度复杂性,许多模型往往难以解释其内部工作原理,这导致这些模型被称为缺乏可解释性的“黑箱模型”。

随着AI应用渗透到各行各业,AI的科技伦理受到广泛的关注。而科技伦理的一个核心议题就是可解释人工智能XAI。从社会科学角度,可解释性是指人对决策原因的理解程度,可解释性越高,人就越能理解为什么做出这样的决策。对应于AI领域,可解释性是指能够在一定程度上揭示AI模型内部工作机制和对模型结果的进行解释,帮助用户理解模型是如何做出预测或决策的。

因此,本文简单地对一维卷积神经网络的特征进行可视化,运行环境为Python,研究对象为心电信号。

首先导入相关库

import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as pltimport signal_screen
import signal_screen_toolsfrom tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv1D, MaxPool1D, Flatten, BatchNormalization, Input
from tensorflow.keras.callbacks import ModelCheckpoint

数据导入及处理

# load data
data_train = pd.read_csv("mitbih_train.csv", sep=",", header=None).to_numpy()
data_test = pd.read_csv("mitbih_test.csv", sep=",", header=None).to_numpy()# get X and y
X_train, y_train = data_train[:, :data_train.shape[1]-2], data_train[:, -1]
X_test, y_test = data_test[:, :data_test.shape[1]-2], data_test[:, -1]# number of categories
num_of_categories = np.unique(y_train).shape[0]del data_train, data_test#indexing examples to show visualisations
examples_to_visualise = [np.where(y_test == i)[0][0] for i in range(5)]
titles = [ "nonectopic", "supraventricular ectopic beat", "ventricular ectopic beat", "fusion beat", "unknown"]# creation of tensors
X_train = np.expand_dims(tf.convert_to_tensor(X_train), axis=2)
X_test = np.expand_dims(tf.convert_to_tensor(X_test), axis=2)# one-hot encoding for 5 categories
y_train = tf.one_hot(y_train, num_of_categories)
y_test = tf.one_hot(y_test, num_of_categories)

建立模型并进行训练

# basic model
model = Sequential([Input(shape=[X_train.shape[1], 1]),Conv1D(filters=16, kernel_size=3, activation="relu"),BatchNormalization(),MaxPool1D(),Conv1D(filters=32, kernel_size=3, activation="relu"),BatchNormalization(),Conv1D(filters=64, kernel_size=3, activation="relu"),BatchNormalization(),Flatten(),Dense(20, activation="relu"),Dense(num_of_categories, activation="softmax")
]
)# train processmodel.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])checkPoint = ModelCheckpoint(filepath="model.h5", save_weights_only=False, monitor='val_accuracy',mode='max', save_best_only=True)model.fit(x=np.expand_dims(X_train, axis=2), y=y_train,batch_size=128, epochs=10, validation_data=(np.expand_dims(X_test, axis=2), y_test),callbacks=[checkPoint])model = tf.keras.models.load_model("model.h5")
loss, acc = model.evaluate(np.expand_dims(X_test, axis=2), y_test)

采用Occlusion Sensitivity方法进行可视化,相关的参考文献较多。

fig, axs = plt.subplots(nrows=5, ncols=1)
fig.suptitle("Occlusion sensitivity")
fig.tight_layout()
fig.set_size_inches(10, 10)
axs = axs.ravel()for c, row, ax, title in zip(range(5), examples_to_visualise, axs, titles):sensitivity, _ = signal_screen.calculate_occlusion_sensitivity(model=model,data=np.expand_dims(X_test[row, :], axis=(0, 2)),c=c,number_of_zeros=[15])# create gradient plotsignal_screen_tools.plot_with_gradient(ax=ax, y=X_test[row, :].ravel(), gradient=sensitivity[0], title=title)ax.set_xlabel("Samples[-]")ax.set_ylabel("ECG [-]")plt.show()

采用Saliency map方法进行可视化。

采用Grad-CAM方法进行可视化。

工学博士,担任《Mechanical System and Signal Processing》审稿专家,担任《中国电机工程学报》优秀审稿专家,《控制与决策》,《系统工程与电子技术》,《电力系统保护与控制》,《宇航学报》等EI期刊审稿专家。

擅长领域:现代信号处理,机器学习,深度学习,数字孪生,时间序列分析,设备缺陷检测、设备异常检测、设备智能故障诊断与健康管理PHM等。

相关文章:

  • MySQL日志探索——redo log和bin log的刷盘时机详解
  • 实景三维:城市数据要素的新维度
  • YOLOv2
  • C++核心高级编程 --- 3、函数提高
  • 2024年阿里云服务器2核8G、4核16G、8核32G配置收费标准
  • Spring使用(一)注解
  • 梨花带雨网页音乐播放器二开优化修复美化版全开源版本源码
  • qT 地图显示飞机轨迹
  • C语言_第一轮笔记_指针
  • 数据仓库——事实表
  • 03-MySQl数据库的-用户管理
  • Stable Diffusion扩散模型推导公式的基础知识
  • R语言颜色细分
  • Leaflet使用多面(MultiPolygon)进行遥感影像掩膜报错解决之道
  • 【讲解下go和java的区别】
  • 【挥舞JS】JS实现继承,封装一个extends方法
  • Flannel解读
  • IndexedDB
  • Java 23种设计模式 之单例模式 7种实现方式
  • Lsb图片隐写
  • PHP的类修饰符与访问修饰符
  • React-redux的原理以及使用
  • Spring Boot快速入门(一):Hello Spring Boot
  • TCP拥塞控制
  • Traffic-Sign Detection and Classification in the Wild 论文笔记
  • 第13期 DApp 榜单 :来,吃我这波安利
  • 好的网址,关于.net 4.0 ,vs 2010
  • 基于组件的设计工作流与界面抽象
  • 开源地图数据可视化库——mapnik
  • 前端_面试
  • 前端技术周刊 2019-02-11 Serverless
  • 我建了一个叫Hello World的项目
  • 小李飞刀:SQL题目刷起来!
  • 用简单代码看卷积组块发展
  • #1014 : Trie树
  • #define MODIFY_REG(REG, CLEARMASK, SETMASK)
  • (Demo分享)利用原生JavaScript-随机数-实现做一个烟花案例
  • (笔记)Kotlin——Android封装ViewBinding之二 优化
  • (仿QQ聊天消息列表加载)wp7 listbox 列表项逐一加载的一种实现方式,以及加入渐显动画...
  • (附源码)python房屋租赁管理系统 毕业设计 745613
  • (附源码)springboot 智能停车场系统 毕业设计065415
  • (附源码)springboot电竞专题网站 毕业设计 641314
  • (牛客腾讯思维编程题)编码编码分组打印下标题目分析
  • (十二)devops持续集成开发——jenkins的全局工具配置之sonar qube环境安装及配置
  • (转)原始图像数据和PDF中的图像数据
  • ***测试-HTTP方法
  • .NET CORE 3.1 集成JWT鉴权和授权2
  • .net core 6 redis操作类
  • .net core webapi 大文件上传到wwwroot文件夹
  • .NET Framework 服务实现监控可观测性最佳实践
  • .Net Memory Profiler的使用举例
  • .Net MVC + EF搭建学生管理系统
  • .net 验证控件和javaScript的冲突问题
  • .NET6 开发一个检查某些状态持续多长时间的类
  • .net开发时的诡异问题,button的onclick事件无效