当前位置: 首页 > news >正文

从ANN到SNN的转换:实现、原理及两种归一化方法【MINIST、实战】

从ANN到SNN的转换:实现、原理及两种归一化方法

引言

随着神经形态计算的迅猛发展,脉冲神经网络(Spiking Neural Networks, SNNs)作为一种仿生神经计算模型,逐渐展现出其在低功耗和事件驱动计算领域的巨大潜力。不同于传统的人工神经网络(Artificial Neural Networks, ANNs),SNN通过二值化的脉冲信号进行信息传递,从而更接近生物神经元的行为。其离散时间、事件触发的处理模式使得SNN在能效和计算效率上具有天然的优势,尤其在神经形态硬件上更为适合。

尽管SNN具备诸多优点,但由于脉冲神经元的异质性以及神经元发放模式的离散性,直接训练SNN模型存在较大挑战。为此,基于ANN到SNN的转换方法成为了当前热门的研究方向。通过将预先训练好的ANN转换为SNN,研究人员能够在保留ANN性能的前提下,充分利用SNN的能效优势。本文将介绍如何通过一套系统的方法实现ANN到SNN的转换,并深入探讨两种归一化方法:MaxNorm和RobustNorm,帮助我们更好地理解这一过程的细节。

1. ANN2SNN转换概述

1.1 ANN与SNN的核心差异

**人工神经网络(ANN)**中的神经元采用连续的激活函数,如ReLU、Sigmoid或Tanh等,激活值可以是任意实数。这种方式虽然能够实现复杂的非线性映射,但其计算能耗较高,且不具备生物神经元的事件驱动特性。

**脉冲神经网络(SNN)**的工作原理与ANN有显著区别。SNN的神经元使用脉冲(spike)作为信息载体,激活方式通过离散脉冲的形式表现。每个神经元的发放过程是基于输入电压的累积,当累积的电压达到某个阈值时,神经元会“发放”脉冲信号。SNN中的常见神经元模型包括:

  • 积分发放神经元(Integrate-and-Fire, IF Neuron):IF神经元通过累积输入电压,当电压超过阈值时,神经元发放脉冲,随后电压重置为初始值。
  • 泄露积分发放神经元(Leaky Integrate-and-Fire, LIF Neuron):在IF模型的基础上增加了泄露机制,使得神经元的电压在没有持续输入时会随时间衰减,更加接近生物神经元的动态特性。

由于ANN和SNN在信息传递机制上的本质差异,直接将ANN的权重应用于SNN是不可行的。因此,在实现ANN到SNN的转换时,需要对神经元的行为和模型的结构进行调整。具体来说,主要挑战在于如何将ANN中的连续激活值有效地映射到SNN中的脉冲发放行为上。

1.2 ANN2SNN的转换流程

ANN到SNN的转换是一个系统化的过程,核心步骤包括训练ANN、激活值归一化处理以及神经元替换。整个流程可概括为以下五个步骤:

  1. 训练ANN模型:首先使用标准的机器学习框架(如PyTorch、TensorFlow)训练一个高性能的ANN模型。通常采用卷积神经网络(CNN)架构,在任务(如图像分类)上进行训练。
  2. 激活值记录:在ANN的训练过程中,插入电压钩子(Voltage Hook)以记录每层网络的激活值。这一步的目的是获取每层神经元的激活范围,便于后续的归一化处理。
  3. 归一化处理:对每层神经元的激活值进行归一化,确保ANN中的权重在SNN中依然能产生合理的神经元发放行为。最常用的两种归一化方法是基于最大值的MaxNorm和基于分位数的RobustNorm。
  4. 替换为脉冲神经元:将ANN中的连续激活函数(如ReLU)替换为SNN中的脉冲神经元(如IF或LIF神经元),并应用归一化系数对输入电压进行缩放。
  5. SNN仿真与验证:在多步时间仿真下运行SNN模型,并在特定任务(如图像分类)上验证SNN的性能。

2. 两种归一化方法

归一化处理是ANN2SNN转换中的关键步骤。
在这里插入图片描述

可以发现,两者的曲线几乎一致。需要注意的是,脉冲频率不可能高于1,因此IF神经元无法拟合ANN中ReLU的输入大于1的情况。

由于SNN神经元的发放特性不同于ANN中的连续激活函数,为了保证模型在转换后的SNN中依旧具有良好的表现,需要对输入的电压或电流进行适当的缩放。本文讨论了两种归一化方法:MaxNorm和RobustNorm。

2.1 MaxNorm归一化

MaxNorm是最简单的归一化方式,适用于没有大量噪声或异常激活值的数据。该方法的核心思想是将每层神经元的输入电压缩放到其激活值的最大范围内,以确保神经元能够有效发放脉冲。

  1. 激活值的最大值收集:遍历训练数据集,记录每一层ReLU激活的最大值( s m a x s_{max} smax)。
  2. 转换为SNN:替换ReLU层为IF神经元,激活值通过一个比例缩放:
    输入 = 输入 s m a x 输出 = 输出 × s m a x \text{输入} = \frac{\text{输入}}{s_{max}} \quad \text{输出} = \text{输出} \times s_{max} 输入=smax输入输出=输出×smax
    即,输入电压缩放为 1 / s m a x 1/s_{max} 1/smax ,IF神经元发放脉冲后,再将输出电压放大回 s m a x s_{max} smax

这种归一化方法的优点在于简单高效,尤其是在输入数据比较规整、没有极端异常值的情况下,能够较好地保持ANN模型的性能。

代码示例:
model._modules[name] = nn.Sequential(VoltageScaler(1.0 / max_item),    # 缩放输入neuron.IFNode(v_threshold=1., v_reset=None),    # IF神经元VoltageScaler(max_item)    # 恢复输出
)

2.2 RobustNorm归一化

RobustNorm归一化是一种更加稳健的归一化策略,特别适用于数据中可能包含噪声或异常激活值的情况。与MaxNorm不同,RobustNorm不直接使用最大值进行归一化,而是使用激活值的某个高分位数(如99.9%)来确定归一化系数。
在这里插入图片描述

这种方法减少了极端激活值对归一化过程的影响,确保模型在数据分布复杂或含有噪声的情况下能够保持性能。

  1. 激活值的分位数收集:遍历训练数据集,记录每一层ReLU激活的某个高分位数(如99.9%)。
  2. 归一化权重和偏置:在替换神经元之前,对权重和偏置进行缩放,确保层与层之间的比例一致。
  3. 转换为SNN:类似MaxNorm,将激活值进行分位数缩放。

这种方法通过调整每一层的权重,进一步优化了层间的信息传递,减少了转换过程中精度的损失。

代码示例:
# 在替换神经元之前,调整权重
if self.prev_scale is not None:current_scale = max_itemprev_scale = self.prev_scalemodule.weight.data = module.weight.data * (prev_scale / current_scale)if hasattr(module, 'bias') and module.bias is not None:module.bias.data = module.bias.data * (prev_scale / current_scale)
self.prev_scale = max_item

3. 实现流程

在代码中,首先训练了一个具有较好精度的卷积神经网络(CNN)模型。随后使用VoltageHook来遍历训练数据,收集激活值的范围。根据收集到的最大激活值或分位数,进行归一化并替换成SNN中的IF神经元。

接下来详细解释代码中几个关键模块的功能,包括VoltageHookVoltageScalerConverter等。

3.1 VoltageHook

VoltageHook是一个自定义层,用于记录ANN中每一层的激活值。这个激活值在SNN中用于归一化(scaling)。在ANN的ReLU激活后,我们需要知道激活值的范围,以便后续归一化。

  • scale:保存激活层的尺度,用于后续的SNN模型归一化。
  • mode:决定使用最大值(MaxNorm)还是分位数(RobustNorm)来记录激活值。
class VoltageHook(nn.Module):def __init__(self, scale=1.0, mode='Max'):"""确定在ANN推理中激活的范围。"""super().__init__()self.register_buffer('scale', torch.tensor(scale))self.mode = modedef forward(self, x):if self.mode.lower() in ['max']:s_t = x.max().detach()  # 获取该层的最大激活值else:s_t = torch.tensor(np.percentile(x.detach().cpu(), float(self.mode[:-1])))  # 获取指定分位数的激活值self.scale = s_t  # 将激活值的最大值或分位数保存为该层的scalereturn x

3.2 VoltageScaler

VoltageScaler的作用是在SNN中对输入和输出进行电压的缩放。由于SNN神经元的行为与ANN不同,我们需要根据先前收集到的激活值对神经元输入和输出电压进行缩放。

  • scale:用于缩放输入电压或恢复输出电压。
  • forward:将输入乘以scale进行缩放。
class VoltageScaler(nn.Module):def __init__(self, scale=1.0):"""缩放SNN推理中电流"""super().__init__()self.register_buffer('scale', torch.tensor(scale))def forward(self, x):return x * self.scale  # 对输入电压进行缩放

3.3 Converter

Converter类负责从ANN到SNN的转换,并处理激活值归一化的过程。它包含三个主要功能:

  • 设置VoltageHook:遍历模型的每一层,并在ReLU激活层后插入VoltageHook,用于收集激活值。
  • 数据收集:通过训练数据集,计算每一层的激活值最大值或分位数。
  • 替换为IFNode:将ReLU层替换为SNN的IF神经元,并根据之前收集的scale进行电压的归一化。
class Converter(nn.Module):def __init__(self, dataloader, mode='Max'):super().__init__()self.mode = modeself.dataloader = dataloaderself.device = Noneself.prev_scale = None  # 添加一个变量,用于存储前一层的最大激活值def forward(self, origin_model):# 创建模型的副本relu_model = copy.deepcopy(origin_model)if self.device is None:self.device = next(relu_model.parameters()).devicerelu_model.eval()# 插入 VoltageHookmodel = self.set_voltagehook(relu_model, mode=self.mode).to(self.device)# 使用训练数据集遍历模型,收集激活值for _, (imgs, _) in enumerate(tqdm(self.dataloader)):model(imgs.to(self.device))# 替换为 IFNodemodel = self.replace_by_ifnode(model)return model@staticmethoddef set_voltagehook(model, mode='MaxNorm'):"""在每个ReLU层后插入 VoltageHook,用于收集该层的激活值"""for name, module in model._modules.items():if hasattr(module, "_modules"):model._modules[name] = Converter.set_voltagehook(module, mode=mode)if module.__class__.__name__ == 'ReLU':model._modules[name] = nn.Sequential(nn.ReLU(),VoltageHook(mode=mode)  # 插入 VoltageHook)return modeldef replace_by_ifnode(self, model):"""将每层 ReLU 层替换为 IFNode 神经元"""for name, module in model._modules.items():if hasattr(module, "_modules"):model._modules[name] = self.replace_by_ifnode(module)# 检查是否为ReLU层,并且有VoltageHookif module.__class__.__name__ == 'Sequential' and len(module) == 2 and \module[0].__class__.__name__ == 'ReLU' and \module[1].__class__.__name__ == 'VoltageHook':max_item = module[1].scale.item()  # 获取 VoltageHook 中记录的最大激活值# 替换为 SNN 层model._modules[name] = nn.Sequential(VoltageScaler(1.0 / max_item),  # 归一化输入neuron.IFNode(v_threshold=1., v_reset=None),  # 替换为 IFNodeVoltageScaler(max_item)  # 恢复输出电压)return model

3.4 核心流程概述

  1. 训练ANN模型:首先,训练一个卷积神经网络(CNN)在MNIST数据集上进行分类。
  2. 激活值收集:通过VoltageHook层,记录每一层ReLU的激活值,采用两种归一化方式:MaxNorm和RobustNorm。
  3. 模型转换:使用收集到的激活值,将ANN模型转换为SNN,将ReLU替换为IF神经元。
  4. SNN模拟与验证:通过多个时间步仿真SNN,并评估其精度。

通过上述流程,用户可以将ANN模型转换为SNN,并根据不同的归一化方式对其性能进行比较。

4. 实验结果

在实验中,采用了两种归一化方式进行SNN转换,分别为MaxNorm和RobustNorm。转换后的SNN通过50个时间步进行仿真,并比较了两种方法的精度随时间步长的变化。

  • MaxNorm:SNN模型使用最大值归一化,随着时间步长的增加,SNN的精度逐渐提高,最终在50个时间步的仿真下达到较高精度。
  • RobustNorm:基于99.9%分位数的归一化方式,SNN精度表现类似,但对异常激活值的敏感性较低。
    在这里插入图片描述

精度结果展示了两种转换方式的优势和不足,MaxNorm简单直接,而RobustNorm更加稳健。

5. 结论

通过本文的分析和实验,我们展示了ANN到SNN转换的一般方法,以及两种不同的归一化策略。MaxNorm适合简单的场景,而RobustNorm在噪声较大的数据上具有更好的鲁棒性。SNN模型的转换不仅能提升计算效率,还能在硬件中实现低功耗的神经网络应用,为未来神经形态计算的发展提供了有效的路径。

附录:完整代码

from tqdm import tqdm
from spikingjelly.clock_driven import neuron
import copy
import torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
import numpy as npclass VoltageHook(nn.Module):def __init__(self, scale=1.0, mode='Max'):"""确定在ANN推理中激活的范围。"""super().__init__()self.register_buffer('scale', torch.tensor(scale))self.mode = modedef forward(self, x):if self.mode.lower() in ['max']:s_t = x.max().detach()else:s_t = torch.tensor(np.percentile(x.detach().cpu(), float(self.mode[:-1])))self.scale = s_treturn xclass VoltageScaler(nn.Module):def __init__(self, scale=1.0):"""缩放SNN推理中电流"""super().__init__()self.register_buffer('scale', torch.tensor(scale))def forward(self, x):return x * self.scale# def extra_repr(self):#     return '%f' % self.scale.item()class Converter(nn.Module):def __init__(self, dataloader, mode='Max'):super().__init__()self.mode = modeself.dataloader = dataloaderself.device = Noneself.prev_scale = None  # 添加一个变量,用于存储前一层的最大激活值def forward(self, origin_model):relu_model = copy.deepcopy(origin_model)if self.device is None:self.device = next(relu_model.parameters()).devicerelu_model.eval()model = self.set_voltagehook(relu_model, mode=self.mode).to(self.device)for _, (imgs, _) in enumerate(tqdm(self.dataloader)):model(imgs.to(self.device))model = self.replace_by_ifnode(model)return model@staticmethoddef set_voltagehook(model, mode='MaxNorm'):for name, module in model._modules.items():if hasattr(module, "_modules"):model._modules[name] = Converter.set_voltagehook(module, mode=mode)if module.__class__.__name__ == 'ReLU':model._modules[name] = nn.Sequential(nn.ReLU(),VoltageHook(mode=mode))return modeldef replace_by_ifnode(self, model):for name, module in model._modules.items():if hasattr(module, "_modules"):model._modules[name] = self.replace_by_ifnode(module)# 检查是否为ReLU层,并且有VoltageHookif module.__class__.__name__ == 'Sequential' and len(module) == 2 and \module[0].__class__.__name__ == 'ReLU' and \module[1].__class__.__name__ == 'VoltageHook':max_item = module[1].scale.item()# # 在替换神经元之前,调整权重# if self.prev_scale is not None:#     # 获取前一层的最大值 (𝜆^(𝑙−1)) 和当前层的最大值 (𝜆^𝑙)#     current_scale = max_item#     prev_scale = self.prev_scale##     # 按照 𝐖^𝑙 → 𝐖^𝑙 * (𝜆^(𝑙−1) / 𝜆^𝑙) 调整权重#     if hasattr(module, 'weight'):#         module.weight.data = module.weight.data * (prev_scale / current_scale)#     elif hasattr(module, 'bias') and module.bias is not None:#         module.bias.data = module.bias.data * (prev_scale / current_scale)## # 更新 prev_scale 为当前层的最大值# self.prev_scale = max_item# 替换为 SNN 层model._modules[name] = nn.Sequential(VoltageScaler(1.0 / max_item),neuron.IFNode(v_threshold=1., v_reset=None),VoltageScaler(max_item))return modelclass CNN(nn.Module):def __init__(self):super().__init__()self.network = nn.Sequential(nn.Conv2d(1, 32, 3, 1),  # 输入通道为1,输出通道为32,卷积核为3x3,步长为1nn.BatchNorm2d(32),  # 批归一化nn.ReLU(),  # ReLU激活nn.Conv2d(32, 32, 3, 1),  # 第二个卷积层,通道数不变nn.BatchNorm2d(32),nn.ReLU(),nn.MaxPool2d(2, 2),  # 2x2最大池化nn.Conv2d(32, 64, 3, 1),  # 第三个卷积层,输出通道为64nn.BatchNorm2d(64),nn.ReLU(),nn.Conv2d(64, 64, 3, 1),  # 第四个卷积层,通道数保持64nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2, 2),  # 2x2最大池化nn.Flatten(),  # 展平操作,将卷积层的输出展平为一维向量nn.Linear(64 * 4 * 4, 512),  # 全连接层,输入为卷积层输出的展平结果nn.ReLU(),  # ReLU激活nn.Linear(512, 10),  # 最后一层,全连接层,输出为10类nn.Softmax(dim=1)  # Softmax 输出)## def __init__(self):#     super().__init__()#     self.network = nn.Sequential(#         nn.Conv2d(1, 32, 3, 1),#         nn.BatchNorm2d(32),#         nn.ReLU(),#         nn.AvgPool2d(2, 2),##         nn.Conv2d(32, 32, 3, 1),#         nn.BatchNorm2d(32),#         nn.ReLU(),#         nn.AvgPool2d(2, 2),##         nn.Conv2d(32, 32, 3, 1),#         nn.BatchNorm2d(32),#         nn.ReLU(),#         nn.AvgPool2d(2, 2),##         nn.Flatten(),#         nn.Linear(32, 10)#     )def forward(self, x):x = self.network(x)return xdef val(net, device, data_loader, T=None):net.eval().to(device)correct = 0.0total = 0.0if T is not None:corrects = np.zeros(T)with torch.no_grad():for batch, (img, label) in enumerate(tqdm(data_loader)):img = img.to(device)if T is None:out = net(img)correct += (out.argmax(dim=1) == label.to(device)).float().sum().item()else:for m in net.modules():if hasattr(m, 'reset'):m.reset()for t in range(T):if t == 0:out = net(img)else:out += net(img)corrects[t] += (out.argmax(dim=1) == label.to(device)).float().sum().item()total += out.shape[0]return correct / total if T is None else corrects / totaldef main():torch.random.manual_seed(0)torch.cuda.manual_seed(0)device = 'cuda'dataset_dir = '../MNIST'batch_size = 100T = 50# 训练参数lr = 1e-3epochs = 10model = CNN().to(device)train_data_dataset = torchvision.datasets.MNIST(root=dataset_dir,train=True,transform=torchvision.transforms.ToTensor(),download=True)train_data_loader = torch.utils.data.DataLoader(dataset=train_data_dataset,batch_size=batch_size,shuffle=True,drop_last=False)test_data_dataset = torchvision.datasets.MNIST(root=dataset_dir,train=False,transform=torchvision.transforms.ToTensor(),download=True)test_data_loader = torch.utils.data.DataLoader(dataset=test_data_dataset,batch_size=50,shuffle=True,drop_last=False)## # 定义损失函数和优化器# loss_function = nn.CrossEntropyLoss()# optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)# # 开始训练模型# for epoch in range(epochs):#     model.train()#     running_loss = 0.0#     for (img, label) in train_data_loader:#         optimizer.zero_grad()#         out = model(img.to(device))#         loss = loss_function(out, label.to(device))#         loss.backward()#         optimizer.step()#         running_loss += loss.item()##     print(f'Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / len(train_data_loader):.4f}')##     # 保存模型#     torch.save(model.state_dict(), 'paper_mnist_cnn_model.pth')##     # 每个epoch后验证精度#     acc = val(model, device, test_data_loader)#     print(f'Validation Accuracy after epoch {epoch + 1}: {acc:.3f}')#     print()model.load_state_dict(torch.load('paper_mnist_cnn_model.pth'))acc = val(model, device, test_data_loader)print('ANN Validating Accuracy: %.4f' % (acc))# 使用转换后的模型print('---------------------------------------------')print('Converting using MaxNorm')model_converter = Converter(mode='max', dataloader=train_data_loader)snn_model = model_converter(model)print('Simulating...')mode_max_accs = val(snn_model, device, test_data_loader, T=T)print(f'SNN accuracy (simulation {T} time-steps): {mode_max_accs[-1]:.4f}')# 后续其他转换逻辑保持不变print('---------------------------------------------')print('Converting using RobustNorm')model_converter = Converter(mode='99.9%', dataloader=train_data_loader)snn_model = model_converter(model)print('Simulating...')mode_robust_accs = val(snn_model, device, test_data_loader, T=T)print(f'SNN accuracy (simulation {T} time-steps): {mode_robust_accs[-1]:.4f}')# 绘制不同转换方式下的精度随时间步长的变化fig = plt.figure()plt.plot(np.arange(0, T), mode_max_accs, label='mode: max')plt.plot(np.arange(0, T), mode_robust_accs, label='mode: 99.9%')plt.legend()plt.xlabel('t')plt.ylabel('Acc')plt.show()if __name__ == '__main__':main()

参考链接:
ANN转换SNN — spikingjelly alpha 文档
Frontiers | Conversion of Continuous-Valued Deep Networks to Efficient Event-Driven Networks for Image Classification

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 【JVM】类加载
  • 上海亚商投顾:沪指探底回升 华为产业链午后爆发
  • Js中的pick函数
  • 关于STM32项目面试题01:电源
  • sqli-labs靶场自动化利用工具——第1关
  • 深入理解 C++ 中的 static_assert 编译期断言
  • Linux下的简单TCP客户端和服务器
  • 爬虫逆向学习(六):补环境过某数四代
  • 用Python创建一个键盘输入捕获程序
  • 【JavaScript】数据结构之树
  • C# 禁止程序重复启动
  • CSS3 过渡
  • Qt控制开发板的LED
  • 【文件包含】——日志文件注入
  • Unity-Transform类-旋转
  • 《深入 React 技术栈》
  • 【402天】跃迁之路——程序员高效学习方法论探索系列(实验阶段159-2018.03.14)...
  • 【译】React性能工程(下) -- 深入研究React性能调试
  • Facebook AccountKit 接入的坑点
  • gulp 教程
  • If…else
  • JavaScript异步流程控制的前世今生
  • Java多线程(4):使用线程池执行定时任务
  • Laravel 菜鸟晋级之路
  • Laravel深入学习6 - 应用体系结构:解耦事件处理器
  • Swoft 源码剖析 - 代码自动更新机制
  • 半理解系列--Promise的进化史
  • 从伪并行的 Python 多线程说起
  • 代理模式
  • 服务器从安装到部署全过程(二)
  • 回顾 Swift 多平台移植进度 #2
  • 基于Javascript, Springboot的管理系统报表查询页面代码设计
  • 每个JavaScript开发人员应阅读的书【1】 - JavaScript: The Good Parts
  • 目录与文件属性:编写ls
  • 如何优雅的使用vue+Dcloud(Hbuild)开发混合app
  • 智能合约Solidity教程-事件和日志(一)
  • 正则表达式-基础知识Review
  • ​渐进式Web应用PWA的未来
  • ​学习笔记——动态路由——IS-IS中间系统到中间系统(报文/TLV)​
  • #pragma once
  • #QT(QCharts绘制曲线)
  • (16)Reactor的测试——响应式Spring的道法术器
  • (1综述)从零开始的嵌入式图像图像处理(PI+QT+OpenCV)实战演练
  • (24)(24.1) FPV和仿真的机载OSD(三)
  • (c语言版)滑动窗口 给定一个字符串,只包含字母和数字,按要求找出字符串中的最长(连续)子串的长度
  • (LeetCode 49)Anagrams
  • (第8天)保姆级 PL/SQL Developer 安装与配置
  • (动态规划)5. 最长回文子串 java解决
  • (附源码)计算机毕业设计SSM保险客户管理系统
  • (全部习题答案)研究生英语读写教程基础级教师用书PDF|| 研究生英语读写教程提高级教师用书PDF
  • (三)uboot源码分析
  • (贪心) LeetCode 45. 跳跃游戏 II
  • (一)SvelteKit教程:hello world
  • .config、Kconfig、***_defconfig之间的关系和工作原理
  • .gitignore文件忽略的内容不生效问题解决