Simple-STNDT: Learning Spike Signal Representations with Transformers (Part 3: Training and Evaluation)

Contents

    • 1. Evaluation metric
    • 2. Training setup
    • 3. Debug test
    • 4. Train and validation functions

1. Evaluation metric

Model quality is measured in bits per spike: the difference (in base-2 logarithm) between the Poisson log-likelihood of the predicted rates and that of a null model that simply predicts each neuron's mean firing rate, divided by the total number of spikes. Under a Poisson model, each bin with predicted rate r and observed count n contributes r - n*log(r) + log(n!) to the negative log-likelihood:

import numpy as np
from scipy.special import gammaln
import torch

def neg_log_likelihood(rates, spikes, zero_warning=True):
    """Calculates Poisson negative log likelihood given rates and spikes.

    formula: -log(e^(-r) / n! * r^n)
           = r - n*log(r) + log(n!)

    Parameters
    ----------
    rates : np.ndarray
        numpy array containing rate predictions
    spikes : np.ndarray
        numpy array containing true spike counts
    zero_warning : bool, optional
        Whether to print out warning about 0 rate predictions or not

    Returns
    -------
    float
        Total negative log-likelihood of the data
    """
    assert spikes.shape == rates.shape, \
        f"neg_log_likelihood: Rates and spikes should be of the same shape. spikes: {spikes.shape}, rates: {rates.shape}"
    if np.any(np.isnan(spikes)):
        mask = np.isnan(spikes)
        rates = rates[~mask]
        spikes = spikes[~mask]
    assert not np.any(np.isnan(rates)), \
        "neg_log_likelihood: NaN rate predictions found"
    assert np.all(rates >= 0), \
        "neg_log_likelihood: Negative rate predictions found"
    if np.any(rates == 0):
        rates[rates == 0] = 1e-9
    result = rates - spikes * np.log(rates) + gammaln(spikes + 1.0)
    return np.sum(result)

def bits_per_spike(rates, spikes):
    """Computes bits per spike of rate predictions given spikes.

    Bits per spike is equal to the difference between the log-likelihoods (in base 2)
    of the rate predictions and the null model (i.e. predicting mean firing rate of each neuron)
    divided by the total number of spikes.

    Parameters
    ----------
    rates : np.ndarray
        3d numpy array containing rate predictions
    spikes : np.ndarray
        3d numpy array containing true spike counts

    Returns
    -------
    float
        Bits per spike of rate predictions
    """
    nll_model = neg_log_likelihood(rates, spikes)
    nll_null = neg_log_likelihood(
        np.tile(np.nanmean(spikes, axis=(0, 1), keepdims=True),
                (spikes.shape[0], spikes.shape[1], 1)),
        spikes, zero_warning=False)
    return (nll_null - nll_model) / np.nansum(spikes) / np.log(2)
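As a quick sanity check, here is a minimal sketch on toy arrays (assuming the two functions above live in metric.py): the null model itself should score roughly 0 bits per spike, while rates that track the true counts should score above 0.

import numpy as np
from metric import bits_per_spike

# Toy data: 2 trials, 3 time bins, 1 neuron.
spikes = np.array([[[0.], [1.], [2.]],
                   [[1.], [0.], [1.]]])

# Null-model rates (each neuron's mean firing rate) score ~0 bits/spike.
null_rates = np.full_like(spikes, spikes.mean())
print(bits_per_spike(null_rates, spikes))   # ~0.0

# Rates closer to the true counts score above 0.
good_rates = np.clip(spikes, 0.1, None)
print(bits_per_spike(good_rates, spikes))   # > 0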

2. Training setup

Build the training and validation datasets and dataloaders, instantiate the model, and configure the AdamW optimizer:

from torch.utils.data import DataLoader
from dataset import make_datasets, mask_batch
from model import SpatioTemporalNDT
from metric import bits_per_spike
import torch
from torch.optim import AdamW
from torch import nn

batch_size = 16
lr = 1e-3
train_dataset, val_dataset = make_datasets()
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size*2, shuffle=False)
trial_length = 160
neuron_num = 130
model = SpatioTemporalNDT(trial_length, neuron_num)
num_epochs = 50
optim = AdamW(model.parameters(), lr=lr)
log_interval = 20
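The two hard-coded sizes come from the dataset layout seen in the debug test below: 120 held-in time bins plus 40 forward-prediction bins gives trial_length = 160, and 98 held-in neurons plus 32 held-out neurons gives neuron_num = 130. A hedged sketch that derives them from the data instead of hard-coding (assuming train_dataset[0] returns the same triple the dataloader yields):

spikes, heldout_spikes, forward_spikes = train_dataset[0]   # hypothetical per-trial access
trial_length = spikes.shape[0] + forward_spikes.shape[0]    # 120 + 40 = 160
neuron_num = spikes.shape[-1] + heldout_spikes.shape[-1]    # 98 + 32 = 130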

3. Debug test

A quick smoke test: pull one batch, run it through masking and the model's forward pass, then through validation, printing tensor shapes along the way.

def param_num(model):
    return sum(param.numel() for param in model.parameters() if param.requires_grad)

def debug_test():
    spikes, heldout_spikes, forward_spikes = next(iter(train_dataloader))
    print(spikes.shape)             # [16, 120, 98]
    print(heldout_spikes.shape)     # [16, 120, 32]
    print(forward_spikes.shape)     # [16, 40, 130]
    masked_spikes, labels = mask_batch(spikes, heldout_spikes, forward_spikes)
    print(masked_spikes.shape)      # [16, 160, 130]
    print(labels.shape)             # [16, 160, 130]
    print(param_num(model))         # 256886
    loss, decoder_rates = model.forward(masked_spikes, labels)
    print(loss)                     # tensor(1.2356, grad_fn=<MeanBackward0>)
    print(decoder_rates.shape)      # torch.Size([16, 160, 130])
    val_loss, val_score = valid(val_dataloader, model)
    print(val_loss)
    print(val_score)
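The jump from [16, 120, 98] to [16, 160, 130] is zero-padding: zeros are concatenated for the 32 held-out neurons and then for the 40 forward time bins before the mask is applied (the same concatenation valid() performs below). A standalone sketch of that bookkeeping:

import torch

x = torch.zeros(16, 120, 98)                           # held-in spikes
x = torch.cat([x, torch.zeros(16, 120, 32)], dim=-1)   # + held-out neurons -> [16, 120, 130]
x = torch.cat([x, torch.zeros(16, 40, 130)], dim=1)    # + forward bins     -> [16, 160, 130]
print(x.shape)                                         # torch.Size([16, 160, 130])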

4. Train and validation functions

Training is a standard masked-modeling loop with gradient clipping. During validation, the held-in spikes are zero-padded to the full 130 neurons and 160 time bins, and the held-out and forward positions of the labels are set to -100 so they are excluded from the loss; bits per spike is then computed from the exponentiated predicted rates at the held-out neurons within the held-in time window.

def train(model, dataloader, val_dataloader, num_epochs, optim):
    for epoch in range(num_epochs):
        print(f"--------- Epoch{epoch:2d} ----------")
        model.train()  # valid() leaves the model in eval mode, so switch back each epoch
        train_loss = []
        for i, (spikes, heldout_spikes, forward_spikes) in enumerate(dataloader):
            masked_spikes, labels = mask_batch(spikes, heldout_spikes, forward_spikes)
            loss, decoder_rates = model(masked_spikes, labels)
            optim.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 200.0)
            optim.step()
            with torch.no_grad():
                train_loss.append(loss.item())
                if i % log_interval == 0:
                    print(f"Train loss: {sum(train_loss)/len(train_loss)}")
        val_loss, val_score = valid(val_dataloader, model)
        print(f"val loss: {float(val_loss)}")
        print(f"val score: {float(val_score)}")
        print()

def valid(val_dataloader, model):
    model.eval()
    pred_rates = []
    heldout_spikes_full = []
    loss_list = []
    with torch.no_grad():
        for spikes, heldout_spikes, forward_spikes in val_dataloader:
            no_mask_labels = spikes.clone()
            no_mask_labels = torch.cat([no_mask_labels, torch.zeros_like(heldout_spikes)], -1)
            no_mask_labels = torch.cat([no_mask_labels, torch.zeros_like(forward_spikes)], 1)
            no_mask_labels[:, :, -heldout_spikes.size(-1):] = -100  # unmasked_label
            no_mask_labels[:, -forward_spikes.size(1):, :] = -100   # unmasked_label
            spikes = torch.cat([spikes, torch.zeros_like(heldout_spikes)], -1)
            spikes = torch.cat([spikes, torch.zeros_like(forward_spikes)], 1)
            loss, batch_rates = model(spikes, no_mask_labels)
            pred_rates.append(batch_rates)
            heldout_spikes_full.append(heldout_spikes)
            loss_list.append(loss)
    heldout_spikes = torch.cat(heldout_spikes_full, dim=0)
    pred_rates = torch.cat(pred_rates, dim=0)
    eval_rates_heldout = torch.exp(pred_rates[:, :heldout_spikes.size(1), -heldout_spikes.size(-1):]).numpy().astype('float')
    eval_spikes_heldout = heldout_spikes.numpy().astype('float')
    # print(eval_rates_heldout.shape)     # (270, 120, 32)
    # print(eval_spikes_heldout.shape)    # (270, 120, 32)
    return sum(loss_list), float(bits_per_spike(eval_rates_heldout, eval_spikes_heldout))
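A minimal main.py entry point (a sketch; the original file layout isn't shown) wires everything together:

if __name__ == "__main__":
    # debug_test()  # optional one-batch smoke test
    train(model, train_dataloader, val_dataloader, num_epochs, optim)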

Finally, start training:

(stndt) D:\STNDT>python main.py
--------- Epoch 0 ----------
Train loss: 1.2486777305603027
Train loss: 0.5138219218878519
Train loss: 0.32351083744589876
val loss: 0.8636534214019775
val score: -0.39136893422272767

--------- Epoch 1 ----------
Train loss: 0.09501783549785614
Train loss: 0.09383604036910194
Train loss: 0.09296295773692248
val loss: 0.8206770420074463
val score: -0.09666108663240561

--------- Epoch 2 ----------
Train loss: 0.09622671455144882
Train loss: 0.09049306774423235
Train loss: 0.08994600358532696
val loss: 0.812911331653595
val score: -0.04202061410637105

--------- Epoch 3 ----------
Train loss: 0.09225568175315857
Train loss: 0.09019481816462108
Train loss: 0.08970968806888999
val loss: 0.8099062442779541
val score: -0.019777008609723395

--------- Epoch 4 ----------
Train loss: 0.08371596038341522
Train loss: 0.08918796905449458
Train loss: 0.0894875490083927
val loss: 0.8083348274230957
val score: -0.008896993842432857

--------- Epoch 5 ----------
Train loss: 0.09019782394170761
Train loss: 0.08884035441137496
Train loss: 0.08963883395602064
val loss: 0.8072853088378906
val score: -0.0026569800293788507

--------- Epoch 6 ----------
Train loss: 0.09667835384607315
Train loss: 0.09060979953833989
Train loss: 0.08956735653848183
val loss: 0.8064565658569336
val score: 0.0003163842262874261

--------- Epoch 7 ----------
Train loss: 0.08744495362043381
Train loss: 0.08888665941499528
Train loss: 0.08930287855427439
val loss: 0.8058080077171326
val score: 0.005321093845270125

--------- Epoch 8 ----------
Train loss: 0.10221674293279648
Train loss: 0.09078312771660942
Train loss: 0.08951869806865366
val loss: 0.8044026494026184
val score: 0.007113516568588765

--------- Epoch 9 ----------
Train loss: 0.09160886704921722
Train loss: 0.08984803798652831
Train loss: 0.0897282888976539
val loss: 0.803226113319397
val score: 0.01217366049067505

--------- Epoch10 ----------
Train loss: 0.09165512025356293
Train loss: 0.08854220310846965
Train loss: 0.08920388268988307
val loss: 0.8014105558395386
val score: 0.015657932109121083

--------- Epoch11 ----------
Train loss: 0.07934647053480148
Train loss: 0.08873837547642845
Train loss: 0.08900632345821799
val loss: 0.7992606163024902
val score: 0.017361369978752348

--------- Epoch12 ----------
Train loss: 0.08641393482685089
Train loss: 0.0893486404702777
Train loss: 0.08927923113834567
val loss: 0.7964036464691162
val score: 0.026846927269458674

--------- Epoch13 ----------
Train loss: 0.08859497308731079
Train loss: 0.08794442635206949
Train loss: 0.08938420000599652
val loss: 0.7929846048355103
val score: 0.033583528051411037

--------- Epoch14 ----------
Train loss: 0.08901184052228928
Train loss: 0.08875668652000882
Train loss: 0.08939630665430208
val loss: 0.7878748178482056
val score: 0.04465469491549107

--------- Epoch15 ----------
Train loss: 0.09487541764974594
Train loss: 0.08885077848320916
Train loss: 0.08909488651083737
val loss: 0.7851467728614807
val score: 0.046395409621300066

--------- Epoch16 ----------
Train loss: 0.0839885026216507
Train loss: 0.08959413000515529
Train loss: 0.08932711874566428
val loss: 0.7806612253189087
val score: 0.05012596379845563

--------- Epoch17 ----------
Train loss: 0.09544813632965088
Train loss: 0.08826960552306402
Train loss: 0.0890249778948179
val loss: 0.7787002325057983
val score: 0.05084565441331739

--------- Epoch18 ----------
Train loss: 0.09305278211832047
Train loss: 0.08740198683171045
Train loss: 0.08877205539767336
val loss: 0.7735776305198669
val score: 0.06808317309022775

--------- Epoch19 ----------
Train loss: 0.08946727961301804
Train loss: 0.0880857486100424
Train loss: 0.08832225821367125
val loss: 0.7722467184066772
val score: 0.0741929715804975

--------- Epoch20 ----------
Train loss: 0.09155283123254776
Train loss: 0.08762263329256148
Train loss: 0.08867140041618812
val loss: 0.774036705493927
val score: 0.06465988606612133

--------- Epoch21 ----------
Train loss: 0.08425123244524002
Train loss: 0.08848933414334342
Train loss: 0.08806171540806933
val loss: 0.7706096768379211
val score: 0.06233272968330965

--------- Epoch22 ----------
Train loss: 0.08672144263982773
Train loss: 0.08736556342669896
Train loss: 0.08800865782470238
val loss: 0.7690156698226929
val score: 0.07570956489538153

--------- Epoch23 ----------
Train loss: 0.09086063504219055
Train loss: 0.0895571896717662
Train loss: 0.08793148053128545
val loss: 0.7725724577903748
val score: 0.045295719065139656

--------- Epoch24 ----------
Train loss: 0.08895140141248703
Train loss: 0.08862598595165071
Train loss: 0.08853605389595032
val loss: 0.7674567103385925
val score: 0.07400126493414798

--------- Epoch25 ----------
Train loss: 0.08059882372617722
Train loss: 0.08788907066697166
Train loss: 0.08830737322568893
val loss: 0.7654385566711426
val score: 0.0783971076192251

--------- Epoch26 ----------
Train loss: 0.0904078260064125
Train loss: 0.08821353883970351
Train loss: 0.08813101125926506
val loss: 0.7648967504501343
val score: 0.06579874206738114

--------- Epoch27 ----------
Train loss: 0.0888797715306282
Train loss: 0.08781595457167853
Train loss: 0.08853465282335514
val loss: 0.765023946762085
val score: 0.06403537205845905

--------- Epoch28 ----------
Train loss: 0.0925334170460701
Train loss: 0.08814156835987455
Train loss: 0.08763645026015073
val loss: 0.7604566216468811
val score: 0.08386773786224676

--------- Epoch29 ----------
Train loss: 0.09102518111467361
Train loss: 0.08881006035066787
Train loss: 0.08800200536483671
val loss: 0.7639309167861938
val score: 0.05987701272594979

--------- Epoch30 ----------
Train loss: 0.08757702261209488
Train loss: 0.08790529945066997
Train loss: 0.08796896276677527
val loss: 0.7679344415664673
val score: 0.04645880716520806

--------- Epoch31 ----------
Train loss: 0.09563669562339783
Train loss: 0.08776313385793141
Train loss: 0.08768014010132813
val loss: 0.7532508969306946
val score: 0.09419951931221196

--------- Epoch32 ----------
Train loss: 0.08262639492750168
Train loss: 0.08920836945374806
Train loss: 0.08818964242208295
val loss: 0.7534663081169128
val score: 0.07980706821661744

--------- Epoch33 ----------
Train loss: 0.09010934829711914
Train loss: 0.08798151392312277
Train loss: 0.08814984251086305
val loss: 0.7573298215866089
val score: 0.0587445179781999

--------- Epoch34 ----------
Train loss: 0.09029105305671692
Train loss: 0.08793160106454577
Train loss: 0.087826013383342
val loss: 0.7541366219520569
val score: 0.04576204364697583

--------- Epoch35 ----------
Train loss: 0.09183177351951599
Train loss: 0.08813220936627615
Train loss: 0.08824214902592868
val loss: 0.7545167803764343
val score: 0.043795136749962035

--------- Epoch36 ----------
Train loss: 0.08738738298416138
Train loss: 0.08769806651842027
Train loss: 0.08801802520344897
val loss: 0.7475957870483398
val score: 0.07046052509968409

--------- Epoch37 ----------
Train loss: 0.08695636689662933
Train loss: 0.08928513243084862
Train loss: 0.08794533206922252
val loss: 0.7405006885528564
val score: 0.08250606459379788

--------- Epoch38 ----------
Train loss: 0.08741921186447144
Train loss: 0.08701477554582414
Train loss: 0.08772314776007722
val loss: 0.7421612739562988
val score: 0.07261544623998699

--------- Epoch39 ----------
Train loss: 0.08897516131401062
Train loss: 0.08884722207273756
Train loss: 0.08827457195375024
val loss: 0.7383261919021606
val score: 0.05041364027920663

--------- Epoch40 ----------
Train loss: 0.08877569437026978
Train loss: 0.08783218938679922
Train loss: 0.08838088319795888
val loss: 0.7311040759086609
val score: 0.05160266134263263

--------- Epoch41 ----------
Train loss: 0.0751330778002739
Train loss: 0.0872439131850288
Train loss: 0.08815818952351082
val loss: 0.723595917224884
val score: 0.08080731948303856

--------- Epoch42 ----------
Train loss: 0.09519665688276291
Train loss: 0.0866984451810519
Train loss: 0.08742059876279133
val loss: 0.7205336689949036
val score: 0.08327377202054256

--------- Epoch43 ----------
Train loss: 0.08966871351003647
Train loss: 0.08703825693754923
Train loss: 0.08704596176380064
val loss: 0.7158994078636169
val score: 0.05753987849499046

--------- Epoch44 ----------
Train loss: 0.08914705365896225
Train loss: 0.08722686128956932
Train loss: 0.08729714445951509
val loss: 0.7021420001983643
val score: 0.08133226152944593

--------- Epoch45 ----------
Train loss: 0.08485537022352219
Train loss: 0.08770599854843956
Train loss: 0.08782925693000235
val loss: 0.705651044845581
val score: 0.07325790592903407

--------- Epoch46 ----------
Train loss: 0.08972616493701935
Train loss: 0.088348921920572
Train loss: 0.08801035510330665
val loss: 0.6982176303863525
val score: 0.06009563284716213

--------- Epoch47 ----------
Train loss: 0.08506552129983902
Train loss: 0.08846274834303629
Train loss: 0.08772453265946085
val loss: 0.684754490852356
val score: 0.10142577749520322

--------- Epoch48 ----------
Train loss: 0.08494629710912704
Train loss: 0.08716638279812676
Train loss: 0.08738453831614518
val loss: 0.6825719475746155
val score: 0.087609587353269

--------- Epoch49 ----------
Train loss: 0.08093467354774475
Train loss: 0.08778195899157297
Train loss: 0.08736045422350489
val loss: 0.6823106408119202
val score: 0.06519610685639747
