当前位置: 首页 > news >正文

时间卷积网络(TCN):序列建模的强大工具(附Pytorch网络模型代码)

这里写目录标题

  • 1. 引言
  • 2. TCN的核心特性
    • 2.1 序列建模任务描述
    • 2.2 因果卷积
    • 2.3 扩张卷积
    • 2.4 残差连接
  • 3. TCN的网络结构
  • 4. TCN vs RNN
  • 5. TCN的应用
  • TCN的实现

1. 引言

引用自:Bai S, Kolter J Z, Koltun V. An empirical evaluation of generic convolutional and recurrent networks for sequence modeling. arXiv[J]. arXiv preprint arXiv:1803.01271, 2018, 10.

在这里插入图片描述

时间卷积网络(Temporal Convolutional Network,简称TCN)是一种专门用于处理序列数据的深度学习模型。它结合了卷积神经网络(CNN)的并行处理能力和循环神经网络(RNN)的长期依赖建模能力,成为序列建模任务中的强大工具。实验证明,对于某些任务下的长序LSTM和GRU等RNN架构,因此如果大家有多输入单输出(MISO)或多输入多输出(MIMO)序列建模任务,可以尝试使用TCN来作为创新点。
在这里插入图片描述

2. TCN的核心特性

在这里插入图片描述
图1所示。TCN中的架构元素。(a)一个扩张的因果卷积,其扩张因子d = 1,2,4,滤波器大小k = 3。接收野能够覆盖输入序列中的所有值。(b) TCN残余块。当剩余输入和输出具有不同的维数时,添加1x1卷积。© TCN中剩余连接的示例。蓝线是残差函数中的过滤器,绿线是恒等映射。

2.1 序列建模任务描述

在定义网络结构之前,我们先强调序列建模任务的核心特性。假设我们有输入序列 x 0 , … , x T x_0, \ldots, x_T x0,,xT,并希望在每个时间点预测对应的输出 y 0 , … , y T y_0, \ldots, y_T y0,,yT。关键约束在于,预测某个时间点 t t t 的输出 y t y_t yt 时,我们只能利用此前观察到的输入 x 0 , … , x t x_0, \ldots, x_t x0,,xt。形式上讲,序列建模网络是任何函数 f : X T + 1 → Y T + 1 f : X^{T+1} \rightarrow Y^{T+1} f:XT+1YT+1,它生成如下映射:

y ^ 0 , … , y ^ T = f ( x 0 , … , x T ) \hat{y}_0, \ldots, \hat{y}_T = f(x_0, \ldots, x_T) y^0,,y^T=f(x0,,xT)

若要满足因果性约束,即 y t y_t yt 只依赖于 x 0 , … , x t x_0, \ldots, x_t x0,,xt,而不依赖于任何“未来”的输入 x t + 1 , … , x T x_{t+1}, \ldots, x_T xt+1,,xT。在序列建模的学习目标中,是找到网络 f f f,使其最小化实际输出与预测值间的预期损失, L ( y 0 , … , y T , f ( x 0 , … , x T ) ) L(y_0, \ldots, y_T, f(x_0, \ldots, x_T)) L(y0,,yT,f(x0,,xT)),其中序列和输出根据某一概率分布抽取。

2.2 因果卷积

TCN使用因果卷积(Causal Convolution)来确保模型不会违反时间顺序。因果卷积即输出只依赖于当前时刻及其之前的输入,而不依赖于未来的输入(因为当前的你看不到未来的数据)。在标准的卷积操作中,每个输出值都基于其周围的输入值,包括未来的时间点。但在因果卷积中,权重仅应用于当前和过去的输入值,确保了信息流的方向性,避免了未来信息泄露到当前输出中。为了实现这一点,通常会在卷积核的右侧填充零(称为因果填充),这样只有当前和过去的信息被用于计算输出。

数学表示:

y ( t ) = ∑ i = 0 k − 1 f ( i ) ⋅ x ( t − i ) y(t) = \sum_{i=0}^{k-1} f(i) \cdot x(t-i) y(t)=i=0k1f(i)x(ti)

其中, f f f是卷积核, k k k是卷积核大小, x x x是输入序列。

2.3 扩张卷积

为了增加感受野而不增加参数数量,TCN采用扩张卷积(Dilated Convolution)。扩张卷积,也被称为空洞卷积,是一种在卷积核之间插入空隙(即跳过某些输入单元)的卷积形式。这种技术允许模型在不增加参数数量的情况下捕获更大的感受野,从而更好地理解输入数据中的上下文信息。扩张因子(dilation factor)决定了卷积核中元素之间的间距,例如,如果扩张因子为2,则卷积核中的元素会间隔一个输入单元。

扩张卷积的数学表示:

y ( t ) = ∑ i = 0 k − 1 f ( i ) ⋅ x ( t − d ⋅ i ) y(t) = \sum_{i=0}^{k-1} f(i) \cdot x(t-d \cdot i) y(t)=i=0k1f(i)x(tdi)

其中, d d d是扩张率。

一个扩张的因果卷积如下图所示:
在这里插入图片描述

2.4 残差连接

TCN使用残差连接来缓解梯度消失问题并促进更深层网络的训练。残差连接是残差网络(ResNets)的关键组成部分,由何凯明等人提出。它的主要目的是解决深层神经网络训练中的梯度消失/爆炸问题,以及提高网络的训练效率和性能。在残差连接中,网络的某一层的输出直接加到几层之后的另一层上,形成所谓的“跳跃连接”。具体来说,假设有一个输入 x x x,经过几层后得到 F ( x ) F(x) F(x),那么最终的输出不是 F ( x ) F(x) F(x)而是 x + F ( x ) x+F(x) x+F(x),也就是输入+输出。这种结构允许梯度在反向传播时可以直接流回更早的层,减少了梯度消失的问题,并且使得网络能够有效地训练更深的架构。残差块的输出可以表示为:

o u t p u t = a c t i v a t i o n ( i n p u t + F ( i n p u t ) ) output = activation(input + F(input)) output=activation(input+F(input))

其中, F F F是卷积层和激活函数的组合,残差连接如下图所示:
在这里插入图片描述

3. TCN的网络结构

TCN的基本结构包括多个残差块,每个残差块包含:

  1. 一维因果卷积层
  2. 层归一化
  3. ReLU激活函数
  4. Dropout层

TCN的整体结构可以表示为:
在这里插入图片描述

4. TCN vs RNN

相比于RNN,TCN有以下优势:

  1. 并行计算:卷积操作可以并行执行,提高计算效率。
  2. 固定感受野:可以精确控制输出对过去输入的依赖范围。
  3. 灵活的感受野大小:通过调整网络深度和扩张率,可以轻松处理不同长度的序列。
  4. 稳定梯度:避免了RNN中的梯度消失/爆炸问题。

5. TCN的应用

TCN在多个领域表现出色,包括:

  • 时间序列预测
  • 语音合成
  • 机器翻译
  • 动作识别
  • 音频生成

本篇文章不靠卖代码赚取收益,麻烦给个点赞和关注,后续还会有开源的免费优化算法及其代码,栓Q!同时如果大家有想要的算法可以在评论区打出,如果有空的话我可以帮忙复现

TCN的实现

以下是使用PyTorch实现TCN核心组件的示例代码(可以直接调用):

import torch
import torch.nn as nn
from torch.nn.utils import weight_normclass Chomp1d(nn.Module):def __init__(self, chomp_size):super(Chomp1d, self).__init__()self.chomp_size = chomp_sizedef forward(self, x):return x[:, :, :-self.chomp_size].contiguous()class TemporalBlock(nn.Module):def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):super(TemporalBlock, self).__init__()self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,stride=stride, padding=padding, dilation=dilation))self.chomp1 = Chomp1d(padding)self.relu1 = nn.ReLU()self.dropout1 = nn.Dropout(dropout)self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,stride=stride, padding=padding, dilation=dilation))self.chomp2 = Chomp1d(padding)self.relu2 = nn.ReLU()self.dropout2 = nn.Dropout(dropout)self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,self.conv2, self.chomp2, self.relu2, self.dropout2)self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else Noneself.relu = nn.ReLU()self.init_weights()def init_weights(self):self.conv1.weight.data.normal_(0, 0.01)self.conv2.weight.data.normal_(0, 0.01)if self.downsample is not None:self.downsample.weight.data.normal_(0, 0.01)def forward(self, x):out = self.net(x)res = x if self.downsample is None else self.downsample(x)return self.relu(out + res)class TemporalConvNet(nn.Module):def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):super(TemporalConvNet, self).__init__()layers = []num_levels = len(num_channels)for i in range(num_levels):dilation_size = 2 ** iin_channels = num_inputs if i == 0 else num_channels[i-1]out_channels = num_channels[i]layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,padding=(kernel_size-1) * dilation_size, dropout=dropout)]self.network = nn.Sequential(*layers)def forward(self, x):return self.network(x)

相关文章:

  • 入门 git
  • MySQL:SELECT 语句
  • Android 11 HAL层集成FFMPEG
  • Flink源码学习资料
  • 机体坐标系和导航坐标系
  • 【中项】系统集成项目管理工程师-第2章 信息技术发展-2.1信息技术及其发展-2.1.1计算机软硬件与2.1.2计算机网络
  • springboot防止重复提交的方案有哪些
  • [2019红帽杯]Snake
  • 纯前端导出xlsx表格
  • 深入理解并使用 MySQL 的 SUBSTRING_INDEX 函数
  • STM32中PC13引脚可以当做普通引脚使用吗?如何配置STM32的TAMPER?
  • docker搭建普罗米修斯监控gpu
  • 基于 Three.js 的 3D 模型加载优化
  • Python实现人脸识别
  • 【IEEE出版,会议历史良好、论文录用检索快】第四届计算机科学与区块链国际学术会议 (CCSB 2024,9月6-8)
  • Google 是如何开发 Web 框架的
  • 【跃迁之路】【585天】程序员高效学习方法论探索系列(实验阶段342-2018.09.13)...
  • JavaScript 奇技淫巧
  • MD5加密原理解析及OC版原理实现
  • niucms就是以城市为分割单位,在上面 小区/乡村/同城论坛+58+团购
  • Python - 闭包Closure
  • Sublime Text 2/3 绑定Eclipse快捷键
  • Vue 重置组件到初始状态
  • web标准化(下)
  • 阿里云Kubernetes容器服务上体验Knative
  • 纯 javascript 半自动式下滑一定高度,导航栏固定
  • 从 Android Sample ApiDemos 中学习 android.animation API 的用法
  • 精益 React 学习指南 (Lean React)- 1.5 React 与 DOM
  • 前端面试之CSS3新特性
  • 小程序 setData 学问多
  • 移动互联网+智能运营体系搭建=你家有金矿啊!
  • 用Python写一份独特的元宵节祝福
  • 源码之下无秘密 ── 做最好的 Netty 源码分析教程
  • Hibernate主键生成策略及选择
  • 正则表达式-基础知识Review
  • ​Python 3 新特性:类型注解
  • ‌分布式计算技术与复杂算法优化:‌现代数据处理的基石
  • #pragma once与条件编译
  • #window11设置系统变量#
  • #大学#套接字
  • (ZT) 理解系统底层的概念是多么重要(by趋势科技邹飞)
  • (ZT)一个美国文科博士的YardLife
  • (算法设计与分析)第一章算法概述-习题
  • (一)Dubbo快速入门、介绍、使用
  • (转)GCC在C语言中内嵌汇编 asm __volatile__
  • (转)jQuery 基础
  • .cfg\.dat\.mak(持续补充)
  • .NET Core 2.1路线图
  • .Net Core缓存组件(MemoryCache)源码解析
  • .Net MVC4 上传大文件,并保存表单
  • .net redis定时_一场由fork引发的超时,让我们重新探讨了Redis的抖动问题
  • .NET/C# 使用反射注册事件
  • .NET技术成长路线架构图
  • .NET牛人应该知道些什么(2):中级.NET开发人员
  • .NET正则基础之——正则委托