Overall Network Construction in CogView
I'm a beginner taking notes on what I learn, hoping they also help other newcomers; corrections from more experienced readers are very welcome. (Content will be removed on request in case of infringement.)
Table of Contents
I. Construction Diagram
II. Code Walkthrough
1. __init__
(1) Parameter settings
(2) Word embeddings (parallel)
(3) Transformer
2. forward
(1) Word embeddings (parallel)
(2) Transformer
(3) Parallel logits
(4) Serial or parallel output
I. Construction Diagram
II. Code Walkthrough
This part of the code is in model/gpt2_modeling.py.
1. __init__
(1) Parameter settings
- num_layers: number of transformer layers;
- vocab_size: vocabulary size;
- hidden_size: hidden dimension;
- num_attention_heads: number of attention heads;
- embedding_dropout_prob: dropout probability for the embeddings;
- attention_dropout_prob: dropout probability for self-attention;
- output_dropout_prob: dropout probability for the output;
- max_sequence_length: maximum sequence length (length of each input sequence);
- max_memory_length: maximum length of the cached memories (hidden states carried over between steps);
- checkpoint_activations: whether to enable activation checkpointing;
- checkpoint_num_layers: number of layers per checkpoint chunk;
- parallel_output: whether the output is serial or parallel;
- query_window: window size for sparse attention;
- key_window_times: controls the number of windows in sparse attention;
- num_pivot: total number of pivot tokens in sparse attention;
class GPT2Model(torch.nn.Module):
    """GPT-2 language model.

    The output of the forward method is the logits (parallel or
    serial depending on the `parallel_output` flag).
    """

    def __init__(self,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 max_sequence_length,
                 max_memory_length,
                 checkpoint_activations,
                 checkpoint_num_layers=1,
                 parallel_output=True,
                 query_window=128,
                 key_window_times=6,
                 num_pivot=768
                 ):
        super(GPT2Model, self).__init__()

        self.parallel_output = parallel_output

        init_method = init_method_normal(std=0.02)  # Gaussian init (mean 0, standard deviation 0.02)
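The initializer can be pictured as follows; this is a minimal sketch of what `init_method_normal` does (the real helper lives in the project's utilities and may differ in detail):

```python
import torch

def init_method_normal(std=0.02):
    """Return an initializer that fills a tensor in place from N(0, std^2)."""
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_

# The returned closure is passed to layers, which apply it to their weights.
init_method = init_method_normal(std=0.02)
```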
(2) Word embeddings (parallel)
        # Word embeddings (parallel).
        self.word_embeddings = mpu.VocabParallelEmbedding(
            vocab_size, hidden_size, init_method=init_method)
See the post "CogView中的Word embeddings (parallel)" (tt丫's CSDN blog) for details.
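In rough terms, `VocabParallelEmbedding` splits the embedding table across model-parallel ranks by rows of the vocabulary: each rank looks up only the tokens that fall in its own vocabulary range, zeros out the rest, and the partial results are summed across ranks (an all-reduce in the real code). The function below is a hypothetical single-process sketch of that idea, not the mpu implementation:

```python
import torch

def vocab_parallel_embedding_sketch(input_ids, full_weight, rank, world_size):
    """Single-process sketch: embed only the tokens whose ids fall in
    this rank's vocabulary shard; zero out all other positions."""
    vocab_size, hidden = full_weight.shape
    per_rank = vocab_size // world_size
    start, end = rank * per_rank, (rank + 1) * per_rank
    local_weight = full_weight[start:end]             # this rank's rows
    mask = (input_ids < start) | (input_ids >= end)   # tokens owned elsewhere
    local_ids = (input_ids - start).clamp(0, per_rank - 1)
    out = local_weight[local_ids]                     # lookup in the shard
    out[mask] = 0.0                                   # zero out foreign tokens
    return out
```

Summing the partial outputs of all ranks reproduces a plain (non-parallel) embedding lookup.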
(3) Transformer
        # Transformer
        self.transformer = mpu.GPT2ParallelTransformer(num_layers,
                                                       hidden_size,
                                                       num_attention_heads,
                                                       max_sequence_length,
                                                       max_memory_length,
                                                       embedding_dropout_prob,
                                                       attention_dropout_prob,
                                                       output_dropout_prob,
                                                       checkpoint_activations,
                                                       checkpoint_num_layers,
                                                       query_window=query_window,
                                                       key_window_times=key_window_times,
                                                       num_pivot=num_pivot
                                                       )
See the post "CogView中的Transformer" (tt丫's CSDN blog) for details.
2. forward
    def forward(self, input_ids, position_ids, attention_mask, txt_indices_bool, img_indices_bool, is_sparse, *mems):
(1) Word embeddings (parallel)
The resulting shape is (b, s, h).
Note: b is the batch size; s the sequence length; h the hidden_size.
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings
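The shape bookkeeping can be checked with a plain `torch.nn.Embedding` standing in for the parallel version:

```python
import torch

# Toy stand-in for the parallel embedding: (b, s) token ids -> (b, s, h).
emb = torch.nn.Embedding(num_embeddings=1000, embedding_dim=64)  # v=1000, h=64
input_ids = torch.randint(0, 1000, (2, 16))                      # b=2, s=16
words_embeddings = emb(input_ids)
print(words_embeddings.shape)  # torch.Size([2, 16, 64])
```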
(2) Transformer
        # Transformer.
        transformer_output = self.transformer(embeddings, position_ids, attention_mask, txt_indices_bool, img_indices_bool, is_sparse, *mems)
        logits, *hidden_layers = transformer_output  # logits is the output; hidden_layers holds the mems
(3) Parallel logits
        # Parallel logits.
        logits_parallel = mpu.copy_to_model_parallel_region(
            logits)  # copy into this rank's model-parallel region
        logits_parallel = F.linear(logits_parallel,
                                   self.word_embeddings.weight)  # project onto the word-embedding matrix
The final shape is (b, s, h) × (v/p, h)^T = (b, s, v/p),
where v is vocab_size and p is the number of model-parallel partitions.
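The shape arithmetic is easy to verify: `F.linear(x, W)` computes `x @ W.T`, so multiplying by this rank's (v/p, h) slice of the embedding weight yields per-partition logits (the sizes below are made up for illustration):

```python
import torch
import torch.nn.functional as F

b, s, h, v_per_p = 2, 16, 64, 125       # v/p = vocab rows held by this partition
logits = torch.randn(b, s, h)
weight_shard = torch.randn(v_per_p, h)  # this rank's word_embeddings.weight
logits_parallel = F.linear(logits, weight_shard)  # (b, s, h) @ (v/p, h)^T
print(logits_parallel.shape)  # torch.Size([2, 16, 125])
```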
(4) Serial or parallel output
        if self.parallel_output:  # parallel output
            return (logits_parallel, *hidden_layers)
        return (mpu.gather_from_model_parallel_region(logits_parallel), *hidden_layers)  # serial output
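For intuition, here is a hypothetical single-process stand-in for what `gather_from_model_parallel_region` does on the serial path (the real version is an all-gather across model-parallel ranks): each partition's (b, s, v/p) logits are concatenated along the vocabulary dimension to recover the full (b, s, v) logits.

```python
import torch

def gather_logits_sketch(shards):
    """Single-process sketch of gathering model-parallel logits:
    concatenate each partition's (b, s, v/p) slice along the last
    (vocab) dim to get the full (b, s, v) logits."""
    return torch.cat(shards, dim=-1)

# Three partitions with v/p = 5 each -> full vocab of 15.
shards = [torch.randn(2, 4, 5) for _ in range(3)]
full = gather_logits_sketch(shards)
print(full.shape)  # torch.Size([2, 4, 15])
```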
Feedback and corrections in the comments are very welcome, thank you!