当前位置: 首页 > news >正文

CogView中网络结构的总体构建

入门小菜鸟,希望像做笔记记录自己学的东西,也希望能帮助到同样入门的人,更希望大佬们帮忙纠错啦~侵权立删。

目录

一、构建图

二、代码解析

1、__init__

(1)参数设定

(2)Word embeddings (parallel)

(3)Transformer

2、forward

(1)Word embeddings (parallel)

(2)Transformer

(3)Parallel logits

(4)串行 or 并行输出


一、构建图


二、代码解析

这部分代码在model/gpt2_modeling.py中

1、__init__

(1)参数设定

  •  num_layers:transformerLayer的层数;
  • vocab_size:词典大小;
  • hidden_size:输入层大小;
  • num_attention_heads:attention head的数目;
  • embedding_dropout_prob:embedding的dropout概率;
  • attention_dropout_prob:self attention的dropout概率;
  • output_dropout_prob:输出的的dropout概率;
  • max_sequence_length:最大序列长度(每次读入的序列长度);
  • checkpoint_activations:是否启用检查点激活;
  • checkpoint_num_layers:checkpoint层数;
  • parallel_output:output是串行or并行;
  • query_window:稀疏处理的窗口大小;
  • key_window_times:用于调节稀疏处理中的窗口数量;
  • num_pivot:稀疏处理中的token总数;
class GPT2Model(torch.nn.Module):
    """GPT-2 Language model.

    The output of the forward method are the logits (parallel or
    serial depending on the `parallel_output` flag.
    """

    def __init__(self,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 max_sequence_length,
                 max_memory_length,
                 checkpoint_activations,
                 checkpoint_num_layers=1,
                 parallel_output=True,
                 query_window=128,
                 key_window_times=6,
                 num_pivot=768
                 ):

        super(GPT2Model, self).__init__()

        self.parallel_output = parallel_output

        init_method = init_method_normal(std=0.02)#初始化方法为高斯分布(均值为0,方差为0.02)

(2)Word embeddings (parallel)

        # Word embeddings (parallel).
        self.word_embeddings = mpu.VocabParallelEmbedding(
            vocab_size, hidden_size, init_method=init_method)

详见CogView中的Word embeddings (parallel)_tt丫的博客-CSDN博客

(3)Transformer

        # Transformer
        self.transformer = mpu.GPT2ParallelTransformer(num_layers,
                                                       hidden_size,
                                                       num_attention_heads,
                                                       max_sequence_length,
                                                       max_memory_length,
                                                       embedding_dropout_prob,
                                                       attention_dropout_prob,
                                                       output_dropout_prob,
                                                       checkpoint_activations,
                                                       checkpoint_num_layers,
                                                       query_window=query_window,
                                                       key_window_times=key_window_times,
                                                       num_pivot=num_pivot
                                                       )

详见CogView中的Transformer_tt丫的博客-CSDN博客

2、forward

    def forward(self, input_ids, position_ids, attention_mask, txt_indices_bool, img_indices_bool, is_sparse, *mems):

(1)Word embeddings (parallel)

shape为(b,s,h)

补:b——batch size;s——sequence length;h——hidden_size;

        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings

(2)Transformer

        # Transformer.
        transformer_output = self.transformer(embeddings, position_ids, attention_mask, txt_indices_bool, img_indices_bool, is_sparse, *mems)
        logits, *hidden_layers = transformer_output#logits为output;*hidden_layers为*mem

(3)Parallel logits

        # Parallel logits.
        logits_parallel = mpu.copy_to_model_parallel_region(
            logits)#传递到模型并行区域
        logits_parallel = F.linear(logits_parallel,
                                   self.word_embeddings.weight)#线性变化

最终shape为(b,s,h)*(v/p,h)^T=(b,s,v/p)

v——vocab_size;p——number of partitions;

(4)串行 or 并行输出

        if self.parallel_output:#并行
            return (logits_parallel, *hidden_layers)

        return (mpu.gather_from_model_parallel_region(logits_parallel), *hidden_layers)#串行

欢迎大家在评论区批评指正,谢谢~

相关文章:

  • 25k的自动化测试面试题,原来都是这样~
  • 猿创征文|我的焚膏继晷之路
  • Linux (Ubuntu)磁盘管理与文件压缩解压(入门必看)
  • CentOS上安装Docker
  • 一文搞定IDEA中SpringBoot项目环境的热部署
  • Java运算符
  • HIS -- 医院信息管理系统业务流程
  • 【精讲】后台项目 采用vue2框架 完整版内含详细注释 1
  • UVA 10271 佳佳的筷子 Chopsticks [DP的基本运用]
  • 【计算机视觉】尺度不变特征变换(SIFT)
  • 计算机网络基础概念
  • Scala系列一:变量和数据类型
  • ROS从入门到精通3-4:urdf集成Gazebo联合仿真
  • 2、操作系统基本原理
  • (二十五)admin-boot项目之集成消息队列Rabbitmq
  • 【108天】Java——《Head First Java》笔记(第1-4章)
  • Brief introduction of how to 'Call, Apply and Bind'
  • CSS实用技巧
  • JavaScript设计模式之工厂模式
  • mongo索引构建
  • MYSQL如何对数据进行自动化升级--以如果某数据表存在并且某字段不存在时则执行更新操作为例...
  • node-glob通配符
  • SAP云平台里Global Account和Sub Account的关系
  • sublime配置文件
  • Xmanager 远程桌面 CentOS 7
  • 不发不行!Netty集成文字图片聊天室外加TCP/IP软硬件通信
  • 初探 Vue 生命周期和钩子函数
  • 从0实现一个tiny react(三)生命周期
  • 分享一个自己写的基于canvas的原生js图片爆炸插件
  • 力扣(LeetCode)21
  • 浏览器缓存机制分析
  • 名企6年Java程序员的工作总结,写给在迷茫中的你!
  • 前端面试题总结
  • 学习JavaScript数据结构与算法 — 树
  • mysql面试题分组并合并列
  • 如何用纯 CSS 创作一个菱形 loader 动画
  • ​LeetCode解法汇总2696. 删除子串后的字符串最小长度
  • ​无人机石油管道巡检方案新亮点:灵活准确又高效
  • #pragma 指令
  • #多叉树深度遍历_结合深度学习的视频编码方法--帧内预测
  • #我与Java虚拟机的故事#连载13:有这本书就够了
  • #在 README.md 中生成项目目录结构
  • (C语言)输入一个序列,判断是否为奇偶交叉数
  • (JSP)EL——优化登录界面,获取对象,获取数据
  • (第二周)效能测试
  • (多级缓存)缓存同步
  • (二)hibernate配置管理
  • (附源码)ssm航空客运订票系统 毕业设计 141612
  • (四)Linux Shell编程——输入输出重定向
  • (转)【Hibernate总结系列】使用举例
  • (转)Linux整合apache和tomcat构建Web服务器
  • (转)一些感悟
  • ***通过什么方式***网吧
  • .NET 设计模式—简单工厂(Simple Factory Pattern)
  • .NetCore Flurl.Http 升级到4.0后 https 无法建立SSL连接