当前位置: 首页 > news >正文

Reinforced Causal Explainer for GNN论文笔记

论文:TPAMI 2023 图神经网络的强化因果解释器

论文代码地址:代码

目录

Abstract

Introduction

PRELIMINARIES

Causal Attribution of a Holistic Subgraph​

individual causal effect (ICE)​

*Causal Screening of an Edge Sequence

Reinforced Causal Explainer (RC-Explainer)​

Policy Network

Policy Gradient Training

Discussion

EXPERIMENTS

Evaluation Metrics

Evaluation of Explanations​


Abstract

Motivation:解释图神经网络(GNNs)预测结果来理解模型决策背后的原因。现有Feature attribution忽略了边之间的依赖关系,尤其是协同效应。

Method引入Reinforced Causal Explainer(RC-Explainer)实现因果筛选策略, 策略网络学习边序列生成策略(每个边缘被选中的概率),在每step选择一个潜在边缘作为action,获得由每个边的组合子图因果属性组成的reward,可突出解释边的依赖性、边的联盟的影响。

策略梯度来优化策略网络,并通过对GNN全局理解,RC-Explainer能为每个图实例提供模型级解释,并泛化到未见过的图。

Conclusion:在解释三个图分类数据集上不同的GNN时,RC-Explainerpredictive accuracycontrastivity等两个定量指标上实现了与最先进方法相当或更好的性能,并通过了合理性检查(sanity checks)视觉检查(visual inspections)

 一、Introduction

PRELIMINARIES

相关代码实现:Mutag_gnn.py

节点表示:

#获取节点表示def get_node_reps(self, x, edge_index, edge_attr, batch):node_x = self.node_emb(x)#节点嵌入层edge_attr = self.edge_emb(edge_attr)#边嵌入层# 对于每个 GINConv 单元for conv, batch_norm, ReLU in \zip(self.convs, self.batch_norms, self.relus):node_x = conv(node_x, edge_index, edge_attr)              #节点表示传递给GINConv层进行信息聚合node_x = ReLU(batch_norm(node_x))#标准化,激活函数return node_x

最终用于预测的表示: 

def get_graph_rep(self, x, edge_index, edge_attr, batch):node_x = self.get_node_reps(x, edge_index, edge_attr, batch)graph_x = global_mean_pool(node_x, batch)return graph_x
def get_pred(self, graph_x):pred = self.relu(self.lin1(graph_x))#线性层,relu处理图表示pred = self.lin2(pred)#预测self.readout = self.softmax(pred)return pred

Causal Attribution of a Holistic Subgraph

individual causal effect (ICE)

论文代码中对于互信息的实现,在reward的计算中

def get_reward(full_subgraph_pred, new_subgraph_pred, target_y, pre_reward, mode='mutual_info'):if mode in ['mutual_info']:#计算互信息,衡量完整子图预测值和新子图预测值之间的相似度# full_subgraph_pred:[batch_size, num_classes] reward:[batch_size]reward = torch.sum(full_subgraph_pred * torch.log(new_subgraph_pred + EPS), dim=1)#对每个样本,新子图预测的最大类别与目标类别相同+1;否则-1reward += 2 * (target_y == new_subgraph_pred.argmax(dim=1)).float() - 1.# print('reward2',reward)elif mode in ['binary']:# 新子图预测的最大类别与目标类别相同,奖励+1;否则-1reward = (target_y == new_subgraph_pred.argmax(dim=1)).float()reward = 2. * reward - 1.elif mode in ['cross_entropy']:# 交叉熵作为奖励,衡量完整子图预测值与目标类别之间的差异reward = torch.log(new_subgraph_pred + EPS)[:, target_y]# reward += pre_rewardreward += 0.97 * pre_rewardreturn reward

*Causal Screening of an Edge Sequence

Reinforced Causal Explainer (RC-Explainer)

 主要流程框架:train_test_pool_batch3.py

def test_policy_all_with_gnd(rc_explainer, model, test_loader, topN=None):rc_explainer.eval()model.eval()topK_ratio_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]acc_count_list = np.zeros(len(topK_ratio_list))precision_topN_count = 0.recall_topN_count = 0.with torch.no_grad():for graph in iter(test_loader):graph = graph.to(device)max_budget = graph.num_edges#最大预算state = torch.zeros(max_budget, dtype=torch.bool)#当前状态# 根据 top K 比率列表计算出需要检查准确率的预算列表check_budget_list = [max(int(_topK * max_budget), 1) for _topK in topK_ratio_list]valid_budget = max(int(0.9 * max_budget), 1)#有效预算for budget in range(valid_budget):#每一个预算available_actions = state[~state].clone()#可用的动作# 获取下一步的动作_, _, make_action_id, _ = rc_explainer(graph=graph, state=state, train_flag=False)# 将推断的动作应用到可用动作列表中available_actions[make_action_id] = Truestate[~state] = available_actions.clone()#更新当前状态# 如果当前预算需要检查准确率if (budget + 1) in check_budget_list:check_idx = check_budget_list.index(budget + 1)#查找当前预算在 check_budget_list 中的索引subgraph = relabel_graph(graph, state)# 用模型对子图进行预测subgraph_pred = model(subgraph.x, subgraph.edge_index, subgraph.edge_attr, subgraph.batch)# 计算准确率并累加到对应的位置acc_count_list[check_idx] += sum(graph.y == subgraph_pred.argmax(dim=1))print('graph.ground_truth_mask[0]',graph.ground_truth_mask[0])# 指定了 topN & 当前预算=topN-1if topN is not None and budget == topN - 1:print('graph.ground_truth_mask[0]',graph.ground_truth_mask[0])# 累加前N个动作的精度precision_topN_count += torch.sum(state*graph.ground_truth_mask[0])/topNrecall_topN_count += torch.sum(state*graph.ground_truth_mask[0])/sum(graph.ground_truth_mask[0])acc_count_list[-1] = len(test_loader)acc_count_list = np.array(acc_count_list)/len(test_loader)precision_topN_count = precision_topN_count / len(test_loader)recall_topN_count = recall_topN_count / len(test_loader)if topN is not None:print('\nACC-AUC: %.4f, Precision@5: %.4f, Recall@5: %.4f' %(acc_count_list.mean(), precision_topN_count, recall_topN_count))else:print('\nACC-AUC: %.4f' % acc_count_list.mean())print(acc_count_list)return acc_count_list.mean(), acc_count_list, precision_topN_count, recall_topN_count

 

其中这四步的实现: rc_explainer_pool.py

class RC_Explainer_Batch_star(RC_Explainer_Batch):def __init__(self, _model, _num_labels, _hidden_size, _use_edge_attr=False):super(RC_Explainer_Batch_star, self).__init__(_model, _num_labels, _hidden_size, _use_edge_attr=False)# 单层MLPdef build_edge_action_prob_generator(self):edge_action_prob_generator = nn.ModuleList()for i in range(self.num_labels):i_explainer = Sequential(Linear(self.hidden_size * (2 + self.use_edge_attr), self.hidden_size * 2),ELU(),Linear(self.hidden_size * 2, self.hidden_size),ELU(),Linear(self.hidden_size, 1)).to(device)edge_action_prob_generator.append(i_explainer)return edge_action_prob_generatordef forward(self, graph, state, train_flag=False):#整个图表示 graph_rep-->torch.Size([64, 32])graph_rep = self.model.get_graph_rep(graph.x, graph.edge_index, graph.edge_attr, graph.batch)#若不存在已使用的边,创建全0子图表示if len(torch.where(state==True)[0]) == 0:subgraph_rep = torch.zeros(graph_rep.size()).to(device)else:subgraph = relabel_graph(graph, state)#根据状态重新标记图subgraph_rep = self.model.get_graph_rep(subgraph.x, subgraph.edge_index, subgraph.edge_attr, subgraph.batch)# 可用边索引、属性 ava_edge_index = graph.edge_index.T[~state].T #torch.Size([2, 3666])ava_edge_attr = graph.edge_attr[~state]#torch.Size([3362, 3])#未使用边对应的节点表示->torch.Size([2153, 32])ava_node_reps = self.model.get_node_reps(graph.x, ava_edge_index, ava_edge_attr, graph.batch)# 学习每个候选动作表示if self.use_edge_attr:#使用边属性信息,将未使用边嵌入可用边表示ava_edge_reps = self.model.edge_emb(ava_edge_attr)ava_action_reps = torch.cat([ava_node_reps[ava_edge_index[0]],ava_node_reps[ava_edge_index[1]],ava_edge_reps], dim=1).to(device)else:ava_action_reps = torch.cat([ava_node_reps[ava_edge_index[0]],ava_node_reps[ava_edge_index[1]]], dim=1).to(device)#torch.Size([3824, 64])#边动作表示生成器ava_action_reps = self.edge_action_rep_generator(ava_action_reps)#torch.Size([3760, 32])#未使用边所属图ava_action_batch = graph.batch[ava_edge_index[0]]#[ 0,  0,  0,  ..., 63, 63, 63] torch.Size([4016])#图标签ava_y_batch = graph.y[ava_action_batch]#[0, 0, 0,  ..., 1, 1, 1] torch.Size([3794])# get the unique elements in batch, in cases where some batches are out of actions.unique_batch, ava_action_batch = torch.unique(ava_action_batch, return_inverse=True)#[64],[3760]#选择一个动作,预测未使用的边的动作概率ava_action_probs = self.predict_star(graph_rep, subgraph_rep, ava_action_reps, ava_y_batch, ava_action_batch)# print(ava_action_probs,ava_action_probs.size())# assert len(ava_action_probs) == sum(~state)#每个图中最大概率及动作added_action_probs, added_actions = scatter_max(ava_action_probs, ava_action_batch)if train_flag:#训练rand_action_probs = torch.rand(ava_action_probs.size()).to(device)# 生成一个与未使用的边的动作概率相同大小的随机概率张量#每个图中最大的随机概率动作_, rand_actions = scatter_max(rand_action_probs, ava_action_batch)return ava_action_probs, ava_action_probs[rand_actions], rand_actions, unique_batchreturn ava_action_probs, added_action_probs, added_actions, unique_batchdef predict_star(self, graph_rep, subgraph_rep, ava_action_reps, target_y, ava_action_batch):action_graph_reps = graph_rep - subgraph_rep#可用图表示action_graph_reps = action_graph_reps[ava_action_batch]#索引可用图表示#未使用边动作表示拼接动作图表示->完整的动作表示action_graph_reps = torch.cat([ava_action_reps, action_graph_reps], dim=1)action_probs = []for i_explainer in self.edge_action_prob_generator:#对于每个标签的动作解释器i_action_probs = i_explainer(action_graph_reps)#当前标签的动作解释器预测动作概率action_probs.append(i_action_probs)action_probs = torch.cat(action_probs, dim=1)#每个标签的动作概率连接,每一列->一个标签的动作概率#从预测的动作概率中索引标签对应的概率action_probs = action_probs.gather(1, target_y.view(-1,1))action_probs = action_probs.reshape(-1)#一维# action_probs = softmax(action_probs, ava_action_batch)# action_probs = F.sigmoid(action_probs)return action_probs

Policy Network

 论文相关代码实现:rc_explainer_pool.py  RC_Explainer_Batch_star()

ava_node_reps = self.model.get_node_reps(graph.x, ava_edge_index, ava_edge_attr, graph.batch)# 学习每个候选动作表示if self.use_edge_attr:#使用边属性信息,将未使用边嵌入可用边表示ava_edge_reps = self.model.edge_emb(ava_edge_attr)ava_action_reps = torch.cat([ava_node_reps[ava_edge_index[0]],ava_node_reps[ava_edge_index[1]],ava_edge_reps], dim=1).to(device)else:ava_action_reps = torch.cat([ava_node_reps[ava_edge_index[0]],ava_node_reps[ava_edge_index[1]]], dim=1).to(device)#torch.Size([3824, 64])#边动作表示生成器ava_action_reps = self.edge_action_rep_generator(ava_action_reps)#torch.Size([3760, 32])

论文相关代码实现:rc_explainer_pool.py 

def predict_star(self, graph_rep, subgraph_rep, ava_action_reps, target_y, ava_action_batch):action_graph_reps = graph_rep - subgraph_rep#可用图表示action_graph_reps = action_graph_reps[ava_action_batch]#索引可用图表示#未使用边动作表示拼接动作图表示->完整的动作表示action_graph_reps = torch.cat([ava_action_reps, action_graph_reps], dim=1)action_probs = []for i_explainer in self.edge_action_prob_generator:#对于每个标签的动作解释器i_action_probs = i_explainer(action_graph_reps)#当前标签的动作解释器预测动作概率action_probs.append(i_action_probs)action_probs = torch.cat(action_probs, dim=1)#每个标签的动作概率连接,每一列->一个标签的动作概率#从预测的动作概率中索引标签对应的概率action_probs = action_probs.gather(1, target_y.view(-1,1))action_probs = action_probs.reshape(-1)#一维# action_probs = softmax(action_probs, ava_action_batch)# action_probs = F.sigmoid(action_probs)return action_probs

 

 

Policy Gradient Training

 论文相关代码实现:train_test_pool_batch3.py  train_policy()

# 批次损失(RL REINFORCE策略梯度)batch_loss += torch.mean(- torch.log(beam_action_probs_list + EPS) * beam_reward_list)

Discussion

EXPERIMENTS

Evaluation Metrics

论文相关代码实现:一、ACC train_test_pool_batch3.py test_policy_all_with_gnd()

# 计算准确率并累加到对应的位置acc_count_list[check_idx] += sum(graph.y == subgraph_pred.argmax(dim=1))

Evaluation of Explanations

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • python基础语法 005 函数1-2 函数作用域
  • Linux - 基础开发工具(yum、vim、gcc、g++、make/Makefile、git)
  • 使用Go编写的持续下行测速脚本,快速消耗流量且不伤硬盘
  • 【排序 】
  • 多元输出表达(MOE)
  • 静态搜索iOS动态链接函数的调用位置
  • 神经网络识别数字图像案例
  • 昇思训练营打卡第二十四天(LSTM+CRF序列标注)
  • uniapp 小程序注册全局弹窗组件(无需引入,无需写标签)
  • 缓存与分布式锁
  • T113-i 倒车低概率性无反应,没有进入倒车视频界面
  • Spring-Cache 缓存
  • Zookeeper背景优缺点,以及应用场景
  • 头歌资源库(32)n皇后问题
  • 【坑】微信小程序开发wx.uploadFile和wx.request的返回值格式不同
  • 【159天】尚学堂高琪Java300集视频精华笔记(128)
  • Babel配置的不完全指南
  • C++类中的特殊成员函数
  • css布局,左右固定中间自适应实现
  • IOS评论框不贴底(ios12新bug)
  • Javascript 原型链
  • JAVA并发编程--1.基础概念
  • Python socket服务器端、客户端传送信息
  • Python十分钟制作属于你自己的个性logo
  • Rancher-k8s加速安装文档
  • Vue学习第二天
  • Zsh 开发指南(第十四篇 文件读写)
  • 诡异!React stopPropagation失灵
  • 马上搞懂 GeoJSON
  • 如何利用MongoDB打造TOP榜小程序
  • 网络应用优化——时延与带宽
  • 微服务框架lagom
  • 微信小程序开发问题汇总
  • 1.Ext JS 建立web开发工程
  • 完善智慧办公建设,小熊U租获京东数千万元A+轮融资 ...
  • (4)STL算法之比较
  • (二)fiber的基本认识
  • (六)Hibernate的二级缓存
  • (每日一问)基础知识:堆与栈的区别
  • (四)c52学习之旅-流水LED灯
  • (转)可以带来幸福的一本书
  • (自适应手机端)响应式新闻博客知识类pbootcms网站模板 自媒体运营博客网站源码下载
  • .bat批处理(五):遍历指定目录下资源文件并更新
  • .net core 外观者设计模式 实现,多种支付选择
  • .NET Core中的去虚
  • .NET 同步与异步 之 原子操作和自旋锁(Interlocked、SpinLock)(九)
  • .NET 中 GetProcess 相关方法的性能
  • .NetCore实践篇:分布式监控Zipkin持久化之殇
  • @SuppressWarnings注解
  • @Transactional 详解
  • [2018/11/18] Java数据结构(2) 简单排序 冒泡排序 选择排序 插入排序
  • [ACTF2020 新生赛]Upload 1
  • [Algorithm][综合训练][kotori和n皇后][取金币][矩阵转置]详细讲解
  • [CSS]文字旁边的竖线以及布局知识
  • [cvpr 2024 目标检测 前沿研究 热点] cpvr 2024中与目标检测主题有关的论文