当前位置: 首页 > news >正文

samout 结构再优化 收敛速度再加快

代码

import torch
import numpy as npclass MaxState(torch.nn.Module):def __init__(self, hidden_dim, heads, win):super(MaxState, self).__init__()assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."self.head_size = hidden_dim // headsself.head = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)self.head_num = headsself.win = winself.hidden = hidden_dimself.mask = torch.triu(torch.ones([win, win])).to("cuda")self.layer_nor = torch.nn.LayerNorm(hidden_dim)def forward(self, input_data, state=None):# self.head.to("cuda")b, s, k, h, w = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size, self.winwindow = torch.ones([1, w]).to("cuda")out = self.head(input_data)out = out.unsqueeze(-1) @ windowout = out.permute([0, 2, 1, 3])one_list = []if state is None:state = torch.ones([out.shape[0], out.shape[1], 1, 1]) * float("-inf")state = state.to("cuda")for i in range(0, s, w):state.reshape([state.shape[0], -1])j = w + ione = out[:, :, i:j]_, _, r, c = one.shapeif r != self.win:one = torch.where(self.mask[:r, :] == 1, one, torch.Tensor([-float('inf')]).to("cuda"))else:one = torch.where(self.mask == 1, one, torch.Tensor([-float('inf')]).to("cuda"))if i == 0:one = torch.concat([one, state @ window], axis=2)state, _ = torch.max(one, axis=2, keepdim=True)else:state1, _ = torch.max(one, axis=2, keepdim=True)# state = torch.sin(self.state(state1.reshape([state1.shape[0], -1]))*state.reshape([state.shape[0], -1]))state1 = self.state(state1.permute([0, 3, 1, 2]).reshape([state1.shape[0], -1, state1.shape[1]]))state = state1.permute([0, 2, 1]).unsqueeze(-2) + state# state = state.reshape(state1.shape)one = torch.concat([one, state], axis=2)state, _ = torch.max(one, axis=2, keepdim=True)one = state.reshape([b, k, h, w])state = state[..., -1:]if r != self.win:one = one[..., :r]one = one.permute([0, 3, 1, 2])one_list.append(one)out = torch.concat(one_list, 1)out = out.reshape([b, s, -1])return out, stateclass FeedForward(torch.nn.Module):def __init__(self, hidden_size):super(FeedForward, self).__init__()self.ffn1 = torch.nn.Linear(hidden_size, hidden_size * 2)self.ffn2 = torch.nn.Linear(hidden_size * 2, hidden_size)self.gate = torch.nn.Linear(hidden_size, hidden_size * 2)self.relu = torch.nn.ReLU()def forward(self, x):x1 = self.ffn1(x)x2 = self.relu(self.gate(x))x = x1 * x2x = self.ffn2(x)return xclass DecoderLayer(torch.nn.Module):def __init__(self, hidden_size, num_heads):super(DecoderLayer, self).__init__()# self.self_attention = MaskMultiHeadAttention(hidden_size, num_heads)self.self_attention = MaxState(hidden_size, num_heads, 8)self.ffn = FeedForward(hidden_size)self.layer_norm = torch.nn.LayerNorm(hidden_size)def forward(self, x, state=None, seq_len=None):x1, state = self.self_attention(x, state)x = self.layer_norm(self.ffn(x1) + x)  # Feed-Forward with residual connectionreturn x, stateclass SamOut(torch.nn.Module):def __init__(self, voc_size, hidden_size, num_heads, num_layers):super(SamOut, self).__init__()self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)self.pos = torch.nn.Embedding(1024, hidden_size)self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])self.head = torch.nn.Linear(hidden_size, voc_size)self.head_state = torch.nn.Linear(hidden_size, num_layers)self.layer_nor=torch.nn.LayerNorm(hidden_size)self.down=torch.nn.ModuleList([torch.nn.Linear(2*hidden_size,hidden_size) for _ in range(num_layers)])def forward(self, x, state=None, seq_len=None):x = self.em(x)if x.shape[1] >= 1024:pos = self.pos(torch.range(0, x.shape[1]-1).long() // 1024).unsqueeze(0)pos = self.pos(torch.range(0, x.shape[1]-1).long() % 1024).unsqueeze(0) + poselse:pos = self.pos(torch.range(0, x.shape[1]-1).long().to("cuda")).unsqueeze(0)if state is None:state = [None] * len(self.decoder_layers)i = 0for decoder_layer in self.decoder_layers:x1, state[i] = decoder_layer(self.down[i](torch.concat([torch.zeros([x.shape[0],1,1]).to("cuda")+pos , x],-1)), state[i])x = x1 + xi += 1state_data = self.head_state((torch.concat(state, -1).squeeze(-2)).permute([0, 2, 1]))return self.head(x), state, state_dataif __name__ == '__main__':net = SamOut(235, 256, 16, 4)net(torch.randint(0, 200, [2, 3000]))

解释

这段代码定义了一个基于PyTorch的神经网络模型,该模型包含自定义的解码器层和输出层,用于处理序列数据。下面是代码的逐行解析:

import torch
import numpy as np
  • 导入PyTorch库和NumPy库。
class MaxState(torch.nn.Module):def __init__(self, hidden_dim, heads, win):super(MaxState, self).__init__()
  • 定义一个名为MaxState的类,继承自torch.nn.Module。这是自定义的一个模块,用于处理状态的最大化。
        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."
  • 断言检查,确保隐藏层的维度可以被注意力头的数量整除。
        self.head_size = hidden_dim // headsself.head = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)self.head_num = headsself.win = winself.hidden = hidden_dimself.mask = torch.triu(torch.ones([win, win])).to("cuda")self.layer_nor = torch.nn.LayerNorm(hidden_dim)
  • 初始化类成员变量,包括线性层、注意力头数量、窗口大小、掩码和层归一化。
    def forward(self, input_data, state=None):# self.head.to("cuda")b, s, k, h, w = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size, self.win
  • 前向传播方法,获取输入数据的形状参数。
        window = torch.ones([1, w]).to("cuda")
  • 创建一个窗口张量,并将其移动到CUDA设备。
        out = self.head(input_data)
  • 应用线性层到输入数据。
        out = out.unsqueeze(-1) @ window
  • 扩展维度并进行矩阵乘法。
        out = out.permute([0, 2, 1, 3])
  • 调整输出的维度顺序。
        one_list = []if state is None:state = torch.ones([out.shape[0], out.shape[1], 1, 1]) * float("-inf")state = state.to("cuda")
  • 初始化状态张量,如果状态为None,则创建一个初始状态。
        for i in range(0, s, w):# ... (省略中间代码)
  • 循环处理每个窗口大小的数据块。
        return out, state
  • 返回处理后的输出和状态。
    接下来是FeedForwardDecoderLayerSamOut类的定义,这些类分别实现了前馈网络、解码器层和整个模型的输出部分。代码结构与MaxState类类似,包含了初始化和前向传播方法。
if __name__ == '__main__':net = SamOut(235, 256, 16, 4)net(torch.randint(0, 200, [2, 3000]))
  • 如果当前脚本作为主程序运行,创建一个SamOut模型实例,并使用随机整数张量作为输入进行测试。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 【Unity】处理碰撞体(Collider)
  • 前端高頻面試題(一)
  • 数据结构第20节 快速排序以及优化
  • 【PostgreSQL】Spring boot + Mybatis-plus + PostgreSQL 处理json类型情况
  • 【项目实战】深入解析HTTP状态码:401 Unauthorized
  • 谷粒商城实战笔记-24-分布式组件-SpringCloud Alibaba-Nacos配置中心-命名空间与配置分组
  • vscode gitee问题
  • Proteus + Keil单片机仿真教程(五)多位LED数码管的静态显示
  • 机器学习(V)--无监督学习(三)EM算法
  • 抖音短视频矩阵管理系统搭建全攻略:功能详解与实战应用
  • Linux 渗透测试基础:使用Metasploit、Nmap等工具进行渗透测试
  • LeetCode LCR027.回文链表 C写法
  • MacOS 开发 — Packages 程序 macOS新版本 演示选项卡无法显示
  • mvvm模式
  • 华贝甄选干细胞科技,揭秘生命修复的奥秘
  • ES6语法详解(一)
  • golang中接口赋值与方法集
  • HTTP那些事
  • JS题目及答案整理
  • Linux链接文件
  • MySQL的数据类型
  • PermissionScope Swift4 兼容问题
  • React-redux的原理以及使用
  • 产品三维模型在线预览
  • 从零开始的webpack生活-0x009:FilesLoader装载文件
  • 理解IaaS, PaaS, SaaS等云模型 (Cloud Models)
  • 罗辑思维在全链路压测方面的实践和工作笔记
  • 前嗅ForeSpider中数据浏览界面介绍
  • 如何将自己的网站分享到QQ空间,微信,微博等等
  • 一道面试题引发的“血案”
  • 用mpvue开发微信小程序
  • - 转 Ext2.0 form使用实例
  • Python 之网络式编程
  • 智能情侣枕Pillow Talk,倾听彼此的心跳
  • ​如何使用QGIS制作三维建筑
  • #ubuntu# #git# repository git config --global --add safe.directory
  • #window11设置系统变量#
  • (145)光线追踪距离场柔和阴影
  • (C语言)二分查找 超详细
  • (LNMP) How To Install Linux, nginx, MySQL, PHP
  • (二) 初入MySQL 【数据库管理】
  • (非本人原创)史记·柴静列传(r4笔记第65天)
  • (分类)KNN算法- 参数调优
  • (附源码)springboot课程在线考试系统 毕业设计 655127
  • (附源码)基于SpringBoot和Vue的厨到家服务平台的设计与实现 毕业设计 063133
  • (三分钟了解debug)SLAM研究方向-Debug总结
  • (图文详解)小程序AppID申请以及在Hbuilderx中运行
  • (学习日记)2024.04.10:UCOSIII第三十八节:事件实验
  • (转) 深度模型优化性能 调参
  • .NET 8 跨平台高性能边缘采集网关
  • .NET delegate 委托 、 Event 事件
  • .net 使用$.ajax实现从前台调用后台方法(包含静态方法和非静态方法调用)
  • .net6 core Worker Service项目,使用Exchange Web Services (EWS) 分页获取电子邮件收件箱列表,邮件信息字段
  • .NET中使用Redis (二)
  • .sh 的运行