
block-recurrent-transformer-pytorch learning notes

Contents

Option 1: has dependencies

No dependencies, no usage example

Option 2: no dependencies


Option 1: has dependencies

GitHub - dashstander/block-recurrent-transformer: Pytorch implementation of "Block Recurrent Transformers" (Hutchins & Schlag et al., 2022)

No dependencies, no usage example

GitHub - jskinn/pytorch-block-recurrent-transformer: Pytorch implementation of the Block-Recurrent Transformer. Official JAX implementation here: https://github.com/google-research/meliad

Option 2: no dependencies

GitHub - lucidrains/block-recurrent-transformer-pytorch: Implementation of Block Recurrent Transformer - Pytorch
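
These notes follow this last implementation. Assuming the package is published on PyPI under the same name as the repository (an assumption, not verified here), it can be installed and imported directly instead of copying the source; the full main module is reproduced below for reference.

# assumption: the PyPI package name matches the repository name
#   pip install block-recurrent-transformer-pytorch
from block_recurrent_transformer_pytorch import BlockRecurrentTransformer, RecurrentTrainerWrapper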

import math
from random import random
from functools import wraps, partial
from itertools import zip_longest
from collections import namedtuple, defaultdict
from packaging import version

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

from beartype import beartype
from beartype.typing import Optional, List, Tuple

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def is_empty(t: torch.Tensor):
    return t.numel() == 0

def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else ((t,) * length)

def all_unique(arr):
    return len(arr) == len(set(arr))

def eval_decorator(fn):
    def inner(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        out = fn(self, *args, **kwargs)
        self.train(was_training)
        return out
    return inner

def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)

def compact(arr):
    return [*filter(exists, arr)]

def and_reduce(arr: List[torch.Tensor]):
    if len(arr) == 0:
        return None
    head, *rest = arr
    for t in rest:
        head = head & t
    return head

def safe_cat(*args, dim = 1):
    args = compact(args)
    if len(args) == 0:
        return None
    return torch.cat(args, dim = dim)

def divisible_by(numer, denom):
    return (numer % denom) == 0

def l2norm(t):
    return F.normalize(t, dim = -1)

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

def pad_at_dim(t, pad, dim = -1, value = 0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

# bias-less layernorm

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# sampling helpers

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)

def top_k(logits, thres = 0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# rotary positional embedding w/ xpos
# https://arxiv.org/abs/2104.09864
# https://arxiv.org/abs/2212.10554v1

class RotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        width,
        scale_base = 512,
        theta = 10000
    ):
        super().__init__()
        self.width = width

        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent = False)

        self.scale_base = scale_base
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.register_buffer('scale', scale, persistent = False)

        self.register_buffer('cached_freqs', None, persistent = False)
        self.register_buffer('cached_scales', None, persistent = False)

    @property
    def device(self):
        return next(self.buffers()).device

    def forward(self):
        device, seq_len = self.device, self.width

        if exists(self.cached_freqs):
            cached_seq_len = self.cached_freqs.shape[-2]
            if cached_seq_len >= seq_len:
                return self.cached_freqs[:seq_len], self.cached_scales[:seq_len]

        t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)

        power = (t - (seq_len // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

        self.register_buffer('cached_freqs', freqs, persistent = False)
        self.register_buffer('cached_scales', scale, persistent = False)

        return freqs, scale

def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(t, pos, scale = 1.):
    scale = default(scale, 1.)

    seq_len = t.shape[-2]
    assert pos.shape[-2] >= seq_len
    pos = pos[-seq_len:]

    if isinstance(scale, torch.Tensor):
        assert scale.shape[-2] >= seq_len
        scale = scale[-seq_len:]

    return (t * pos.cos() * scale) + (rotate_half(t) * pos.sin() * scale)

# memory management

class MemoryManager(nn.Module):
    def __init__(
        self,
        dim,
        *,
        layers = 1,
        mem_lengths = 512,
        compress_factors = 1
    ):
        super().__init__()
        mem_lengths = cast_tuple(mem_lengths)
        compress_factors = cast_tuple(compress_factors)

        assert all([mem_length > 0 for mem_length in mem_lengths])
        assert len(mem_lengths) == len(compress_factors)
        assert layers >= 1

        self.mem_lengths = mem_lengths
        self.compress_factors = compress_factors

        self.layers = nn.ModuleList([])

        for _ in range(layers):
            compress_fns = nn.ModuleList([])

            for compress_factor in compress_factors:
                compress_fn = nn.Identity()

                if compress_factor > 1:
                    compress_fn = nn.Sequential(
                        Rearrange('b n d -> b d n'),
                        nn.Conv1d(
                            dim * 2,
                            dim * 2,
                            compress_factor,
                            stride = compress_factor,
                            groups = 2
                        ),
                        Rearrange('b d n -> b n d'),
                    )

                compress_fns.append(compress_fn)

            self.layers.append(compress_fns)

    def forward(
        self,
        past_memories: List[torch.Tensor],
        new_memories: List[torch.Tensor]
    ):
        next_memories = []

        for past_memory, new_memory, compress_fns in zip_longest(past_memories, new_memories, self.layers):

            # edge case if neither memories exist

            if not (exists(past_memory) or exists(new_memory)):
                next_memories.append(None)
                continue

            next_memory = None

            for mem_length, compress_factor, compress_fn in zip(self.mem_lengths, self.compress_factors, compress_fns):

                # first get the memories for the given compression factor "current_memory"

                current_memory = None
                if exists(past_memory):
                    past_memory, current_memory = past_memory[..., :-mem_length, :], past_memory[..., -mem_length:, :]

                # compress the new memories coming in, based on the compression factors set at init

                if (not is_empty(new_memory)) and compress_factor > 1:
                    # make sure memory length is divisible by compression factor

                    new_mem_length = new_memory.shape[-2]
                    curtailed_length = (new_mem_length // compress_factor) * compress_factor
                    curtailed_slice = slice(-curtailed_length, None) if curtailed_length > 0 else slice(0, 0)

                    new_memory = new_memory[..., curtailed_slice, :]

                    # compress the memory pushed to the next stage

                    if new_memory.shape[-2] > 0:
                        new_memory = rearrange(new_memory, 'm b n d -> b n (m d)')
                        new_memory = compress_fn(new_memory)
                        new_memory = rearrange(new_memory, 'b n (m d) -> m b n d', m = 2)

                # fifo memory queue
                # add the new memory on the right

                current_memory = safe_cat(current_memory, new_memory, dim = -2)

                # "new" memory is new with respect to the next compressed segment

                new_memory, current_memory = current_memory[..., :-mem_length, :], current_memory[..., -mem_length:, :]

                # concat the new memory to the left into the past

                next_memory = safe_cat(current_memory, next_memory, dim = -2)

            next_memories.append(next_memory)

        return next_memories

# maybe flash attention, if using pytorch 2.0

# constants

Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# state container

class StateContainer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_state_vectors,
        dim_head = 64,
        heads = 8,
        qk_rmsnorm = False,
        qk_rmsnorm_scale = 8,
        use_flash_attn = False
    ):
        super().__init__()
        assert num_state_vectors > 0
        self.heads = heads
        inner_dim = dim_head * heads

        self.state_norm = LayerNorm(dim)

        self.q_to_state = nn.Linear(dim, inner_dim, bias = False)
        self.q_from_state = nn.Linear(dim, inner_dim, bias = False)

        self.state_to_q = nn.Linear(dim, inner_dim, bias = False)
        self.state_to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.init_state = nn.Parameter(torch.randn(num_state_vectors, dim))
        self.state_pos_ids = nn.Parameter(torch.randn(num_state_vectors, dim))

        self.to_state_out = nn.Linear(inner_dim * 2, dim, bias = False)

        self.to_state_cross_attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)

        self.state_self_attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)
        self.from_state_cross_attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)

        # gating related parameters - using the fixed simple config

        self.state_out_to_gate = nn.Linear(dim, dim)
        self.learned_ema_beta = nn.Parameter(torch.randn(dim))

        # since each read should be followed by a write, just store cache in the container

        self.cache = None
        self.next_read_state = None

    def set_next_read_state(
        self,
        states
    ):
        if not exists(states):
            states = self.init_state

        self.next_read_state = (states,)

    def read(self, x):
        assert exists(self.next_read_state), 'states to be read must be set with .set_next_read_state'

        states, = self.next_read_state
        self.next_read_state = None

        # pre norm state for attention

        normed_states = self.state_norm(states)

        # add the positional ids, as stated in the paper critical for it to work

        normed_states = normed_states + self.state_pos_ids

        # get queries for cross attention, which they do not share, although they share key / values. another intriguing detail

        q_to_state = self.q_to_state(x)
        q_to_state = rearrange(q_to_state, '... n (h d) -> ... h n d', h = self.heads)

        # self attention qkv for states

        state_k, state_v = self.state_to_kv(normed_states).chunk(2, dim = -1)

        # cross attend to the past states key values

        to_state_out = self.to_state_cross_attn(q_to_state, state_k, state_v)

        to_state_out = rearrange(to_state_out, 'b h n d -> b n (h d)')

        # cache for next write

        self.cache = (states, normed_states, state_k, state_v)

        return to_state_out

    def write(
        self,
        *,
        memories
    ):
        assert exists(self.cache)

        k, v = memories
        batch = k.shape[0]

        # get cached values from the previous read

        states, normed_states, state_k, state_v = self.cache

        self.cache = None

        # derive queries

        q_from_state = self.q_from_state(normed_states)
        q_from_state = rearrange(q_from_state, '... n (h d) -> ... h n d', h = self.heads)

        state_q = self.state_to_q(normed_states)
        state_q_einsum = 'n (h d)' if state_q.ndim == 2 else 'b n (h d)'

        state_q = repeat(state_q, f'{state_q_einsum} -> b h n d', h = self.heads, b = batch)

        # states must also undergo self attention

        if q_from_state.ndim == 3:
            q_from_state = repeat(q_from_state, '... -> b ...', b = batch)

        state_out = self.state_self_attn(state_q, state_k, state_v)

        from_state_out = self.from_state_cross_attn(q_from_state, k, v)

        state_out = torch.cat((state_out, from_state_out), dim = -1)
        state_out = rearrange(state_out, 'b h n d -> b n (h d)')

        state_out = self.to_state_out(state_out)

        # use the best performing configuration
        # fixed simple gate - nothing more than a learned EMA with some resemblance to highway networks

        z = self.state_out_to_gate(state_out)
        learned_ema_decay = self.learned_ema_beta.sigmoid()

        # set new state with the learned EMA gating

        return learned_ema_decay * z + (1 - learned_ema_decay) * states

    def forward(self, x):
        raise NotImplementedError

# main class

class Attend(nn.Module):
    def __init__(
        self,
        causal = False,
        use_flash_attn = False
    ):
        super().__init__()
        self.causal = causal
        self.register_buffer("mask", None, persistent = False)

        self.use_flash_attn = use_flash_attn
        assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not use_flash_attn:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = Config(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = Config(False, True, True)

    def get_mask(self, n, device):
        if exists(self.mask) and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device = device, dtype = torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent = False)
        return mask

    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # Recommended for multi-query single-key-value attention by Tri Dao
        # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])

        if k.ndim == 3:
            k = repeat(k, 'b ... -> b h ...', h = q.shape[1])

        if v.ndim == 3:
            v = repeat(v, 'b ... -> b h ...', h = q.shape[1])

        # Check if mask exists and expand to compatible shape
        # The mask is B L, so it would have to be expanded to B H N L

        masks = []

        if self.causal:
            i, j = q_len, k_len
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = q.device).triu(j - i + 1)
            masks.append(~causal_mask)

        if exists(mask):
            if mask.ndim != 2:
                mask = repeat(mask, 'w ... -> (b w) ...', b = q.shape[0] // mask.shape[0])

            masks.append(mask)

        attn_mask = and_reduce(masks)

        # Check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = attn_mask
            )

        return out

    def forward(self, q, k, v, mask = None, use_flash_attn = None):
        use_flash_attn = default(use_flash_attn, self.use_flash_attn)

        b, n, device = q.shape[0], q.shape[-2], q.device

        q, ps = pack_one(q, '* h n d')
        k, _ = pack_one(k, '* n d')
        v, _ = pack_one(v, '* n d')

        if use_flash_attn:
            out = self.flash_attn(q, k, v, mask = mask)
            return unpack_one(out, ps, '* h n d')

        scale = q.shape[-1] ** -0.5

        k_einsum = 'b j d' if k.ndim == 3 else 'b h j d'
        v_einsum = 'b j d' if v.ndim == 3 else 'b h j d'

        # similarity

        sim = einsum(f"b h i d, {k_einsum} -> b h i j", q, k) * scale

        # key padding mask

        if exists(mask):
            if mask.ndim != 2:
                mask = repeat(mask, 'w ... -> (b w) ...', b = b)

            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = q.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim = -1)

        # aggregate values

        out = einsum(f"b h i j, {v_einsum} -> b h i d", attn, v)

        return unpack_one(out, ps, '* h n d')

# geglu feedforward

class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return F.gelu(gate) * x

def FeedForward(dim, mult = 4):
    inner_dim = int(dim * mult * 2 / 3)
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, inner_dim * 2, bias = False),
        GEGLU(),
        nn.Linear(inner_dim, dim, bias = False)
    )

# attention

class Attention(nn.Module):
    def __init__(
        self,
        dim_head,
        causal = False,
        qk_rmsnorm = False,
        qk_rmsnorm_scale = 8,
        use_flash_attn = False
    ):
        super().__init__()
        self.causal = causal

        self.qk_rmsnorm = qk_rmsnorm
        self.qk_rmsnorm_scale = qk_rmsnorm_scale

        self.attend = Attend(causal = causal, use_flash_attn = use_flash_attn)

        if qk_rmsnorm:
            self.q_scale = nn.Parameter(torch.ones(dim_head))
            self.k_scale = nn.Parameter(torch.ones(dim_head))

    def forward(
        self,
        q, k, v,
        mask = None,
        rotary_pos_emb = None,
        xpos_scale = None
    ):
        scale = q.shape[-1] ** -0.5

        if self.qk_rmsnorm:
            q, k = map(l2norm, (q, k))
            scale = self.qk_rmsnorm_scale

        if self.qk_rmsnorm:
            q = q * self.q_scale
            k = k * self.k_scale

        # rotary positional embedding with xpos for length extrapolation

        if exists(rotary_pos_emb):
            q = apply_rotary_pos_emb(q, rotary_pos_emb, xpos_scale)
            k = apply_rotary_pos_emb(k, rotary_pos_emb, xpos_scale ** -1)

        # attention

        out = self.attend(q, k, v, mask = mask)

        return out

class AttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        block_width,
        dim_head = 64,
        heads = 8,
        qk_rmsnorm = False,
        qk_rmsnorm_scale = 8,
        use_flash_attn = False,
        num_state_vectors = 0,
        num_external_state_reads = 0,
        state_read_before_write = True  # this will be defaulted to on as in the paper, but will be turned off in the case the researcher wants to test out reading the state at a lower layer
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads

        self.norm = LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)

        self.block_width = block_width
        self.is_recurrent_layer = num_state_vectors > 0

        # decide how many states this attention layer is going to read from

        num_state_reads = int(self.is_recurrent_layer and state_read_before_write) + num_external_state_reads

        self.to_out = nn.Linear(inner_dim * (1 + num_state_reads), dim, bias = False)

        if not self.is_recurrent_layer:
            return

        self.state_read_before_write = state_read_before_write

        self.state_container = StateContainer(
            dim,
            dim_head = dim_head,
            heads = heads,
            num_state_vectors = num_state_vectors,
            qk_rmsnorm = qk_rmsnorm,
            qk_rmsnorm_scale = qk_rmsnorm_scale,
            use_flash_attn = use_flash_attn
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        x,
        rotary_pos_emb = None,
        xpos_scale = None,
        attn_mask = None,
        xl_memories: Optional[torch.Tensor] = None,
        read_from_state_containers: List[StateContainer] = []
    ):
        batch, seq_len, _, width, device = *x.shape, self.block_width, self.device

        # pre normalization

        x = self.norm(x)

        # queries, keys, values and split out heads

        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        split_head = partial(rearrange, pattern = 'b n (h d) -> b h n d', h = self.heads)

        q = split_head(q)

        # save the last key / values as memories for recurrence

        memories = torch.stack((k, v))

        mem_len = 0

        if exists(xl_memories):
            # if past memories are passed in, concat as the first bucket
            mem_len = xl_memories.shape[-2]
            past_k, past_v = xl_memories
            k = torch.cat((past_k, k), dim = 1)
            v = torch.cat((past_v, v), dim = 1)

        # handle cropping of attention mask and positional embeddings

        if exists(attn_mask):
            attn_mask = attn_mask[:seq_len, :seq_len]
            attn_mask = F.pad(attn_mask, (mem_len, 0), value = True)

        # attention, but of course

        out = self.attn(
            q, k, v,
            rotary_pos_emb = rotary_pos_emb,
            xpos_scale = xpos_scale,
            mask = attn_mask
        )

        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')

        # early return if not a recurrent layer

        if not self.is_recurrent_layer and len(read_from_state_containers) == 0:
            return self.to_out(out), memories, None

        # whether to read from own state container, default to on, but may pass in more

        if self.is_recurrent_layer and self.state_read_before_write:
            read_from_state_containers = [self.state_container, *read_from_state_containers]

        for read_state_container in read_from_state_containers:
            # read from the states ...

            to_state_out = read_state_container.read(x)

            # and concat it to the output of self-attention

            out = torch.cat((out, to_state_out), dim = -1)

        new_states = None

        if self.is_recurrent_layer:
            # then write to the states as well if need be

            new_states = self.state_container.write(memories = memories)

        return self.to_out(out), memories, new_states

# classes

@beartype
class BlockRecurrentTransformer(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        all_layers_qk_rmsnorm = False,
        ff_mult = 4,
        max_seq_len = 1024,
        block_width = 512,
        recurrent_layers: Optional[Tuple[int, ...]] = None,
        read_recurrent_layers: Optional[Tuple[int, ...]] = None,
        num_state_vectors = None,
        ignore_index = -100,
        use_flash_attn = False,
        use_compressed_mem = False,
        compressed_mem_factor = 4
    ):
        super().__init__()
        num_state_vectors = default(num_state_vectors, block_width)

        # set recurrent layers

        recurrent_layers = default(recurrent_layers, (depth // 2,)) # default to one recurent layer at middle of the network

        assert all([0 < layer <= depth for layer in recurrent_layers]), f'recurrent layers must range from 1 to the depth {depth}'
        assert all_unique(recurrent_layers), 'recurrent layers must be all unique. no duplicate layers'

        self.recurrent_layers = recurrent_layers

        # set read recurrent layers

        read_recurrent_layers = default(read_recurrent_layers, recurrent_layers)

        assert all([read_layer <= write_layer for read_layer, write_layer in zip(read_recurrent_layers, recurrent_layers)]), 'the recurrent read layer must be always less than or equal to the write layer'
        assert all([0 < layer <= depth for layer in read_recurrent_layers])
        assert len(read_recurrent_layers) == len(recurrent_layers)

        self.read_recurrent_layers = read_recurrent_layers

        # token embedding

        self.token_emb = nn.Embedding(num_tokens, dim)

        self.rotary_pos_emb = RotaryEmbedding(dim = dim_head, width = (2 if not use_compressed_mem else 3) * block_width)

        self.layers = nn.ModuleList([])

        self.write_to_read_map = {write_layer: read_layer for write_layer, read_layer in zip(recurrent_layers, read_recurrent_layers)}

        self.read_state_router = defaultdict(list)

        for layer in range(1, depth + 1):
            is_recurrent_layer = layer in self.recurrent_layers

            layer_num_state_vectors = num_state_vectors if is_recurrent_layer else 0

            num_external_state_reads = sum([int(layer == read_layer) for read_layer in read_recurrent_layers])

            # only layers with xl memories
            # or has recurrence in horizontal direction
            # use qk rmsnorm (in paper, they use cosine sim attention, but i think qk rmsnorm is more proven given Vit 22B paper)
            # one can also override to use all qk rmsnorm by setting all_layers_qk_rmsnorm = True

            qk_rmsnorm = all_layers_qk_rmsnorm or is_recurrent_layer

            attn_block = AttentionBlock(
                dim,
                block_width = block_width,
                dim_head = dim_head,
                heads = heads,
                qk_rmsnorm = qk_rmsnorm,
                num_state_vectors = layer_num_state_vectors,
                use_flash_attn = use_flash_attn,
                num_external_state_reads = num_external_state_reads,
                state_read_before_write = False,
            )

            ff_block = FeedForward(dim, mult = ff_mult)

            if is_recurrent_layer:
                read_layer = self.write_to_read_map[layer]
                self.read_state_router[read_layer].append(attn_block.state_container)

            self.layers.append(nn.ModuleList([
                attn_block,
                ff_block
            ]))

        # (compressed) memory management

        self.mem_manager = MemoryManager(
            dim = dim_head,
            layers = depth,
            mem_lengths = block_width if not use_compressed_mem else (block_width, block_width // 2),
            compress_factors = 1 if not use_compressed_mem else (1, compressed_mem_factor)
        )

        # to logits

        self.to_logits = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, num_tokens, bias = False)
        )

        self.max_seq_len = max_seq_len
        self.block_width = block_width

        assert divisible_by(max_seq_len, block_width)

        self.ignore_index = ignore_index

        self.register_buffer('cached_causal_attn_mask', None, persistent = False)

    @property
    def device(self):
        return next(self.parameters()).device

    def get_causal_attn_mask(self, width):
        if exists(self.cached_causal_attn_mask):
            cached_mask = self.cached_causal_attn_mask
            cached_width = cached_mask.shape[-2]
            padding = (width - cached_width) // 2
            j_slice = Ellipsis if padding == 0 else slice(padding, -padding)
            return cached_mask[:cached_width, j_slice]

        device = self.device
        causal_mask = torch.ones((width, width), device = device, dtype = torch.bool).triu(1)
        return ~causal_mask

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        prime,
        length = None,
        xl_memories: List[torch.Tensor] = [],
        states: List[torch.Tensor] = [],
        temperature = 1.,
        filter_thres = 0.9,
        return_memories_and_states = False
    ):
        length = default(length, self.max_seq_len + 1)
        start_len = prime.shape[-1]

        assert start_len < self.max_seq_len
        assert length <= (self.max_seq_len + 1)
        assert start_len < length

        output = prime

        memories = []

        for ind in range(length - start_len):

            logits, next_memories, next_states = self.forward(
                output,
                xl_memories = xl_memories,
                states = states
            )

            logits = logits[:, -1]

            filtered_logits = top_k(logits, thres = filter_thres)
            sampled = gumbel_sample(filtered_logits, temperature = temperature)
            sampled = rearrange(sampled, 'b -> b 1')

            output = torch.cat((output, sampled), dim = -1)

            if divisible_by(output.shape[-1] - 1, self.max_seq_len): # on the sampling of the last token in the current window, set new memories and states
                memories = next_memories
                states = next_states

        output = output[:, start_len:]

        if return_memories_and_states:
            return output, memories, states

        return output

    def forward(
        self,
        x,
        return_loss = False,
        xl_memories: List[torch.Tensor] = [],
        states: List[torch.Tensor] = [],
        return_memories_and_states = None  # can force to either return memory + state or not. by default will only return when number of tokens == max_seq_len
    ):
        device = x.device

        if return_loss:
            x, labels = x[:, :-1], x[:, 1:]

        # get sequence length i and j for dynamic pos bias

        assert x.shape[-1] <= self.max_seq_len

        w = self.block_width

        # token embedding

        x = self.token_emb(x)

        # dynamic pos bias

        attn_mask = self.get_causal_attn_mask(w)
        rotary_pos_emb, xpos_scale = self.rotary_pos_emb()

        # only return memories and state if at the full block width, but can be overridden

        return_memories_and_states = default(return_memories_and_states, self.max_seq_len == x.shape[-2])

        # ready output tensor, to be concatted to block by block

        batch, _, dim = x.shape
        out = torch.empty(batch, 0, dim, dtype = x.dtype, device = self.device)

        # split input into blocks of width w

        input_blocks = x.split(w, dim = -2)

        # process each block at a time

        for input_block in input_blocks:
            input_block_length = input_block.shape[-2]

            # ready xl memories and states

            iter_xl_memories = iter(xl_memories)
            iter_states = iter(states)

            next_xl_memories = []
            next_states = []

            # set the states on the appropriate state containers

            for attn, _ in self.layers:
                if not attn.is_recurrent_layer:
                    continue

                attn.state_container.set_next_read_state(next(iter_states, None))

            # go through layers

            for ind, (attn, ff) in enumerate(self.layers):

                # determine if the layer requires transformer xl memories

                layer = ind + 1

                # whether to pass in xl memories

                attn_kwargs = dict(
                    rotary_pos_emb = rotary_pos_emb,
                    xpos_scale = xpos_scale,
                    attn_mask = attn_mask,
                    xl_memories = next(iter_xl_memories, None),
                    read_from_state_containers = self.read_state_router[layer]
                )

                # attention layer

                residual = input_block
                attn_branch_out, layer_xl_memories, layer_next_states = attn(input_block, **attn_kwargs)

                if exists(layer_xl_memories):
                    next_xl_memories.append(layer_xl_memories)

                if exists(layer_next_states):
                    next_states.append(layer_next_states)

                input_block = attn_branch_out + residual

                # feedforward layer

                input_block = ff(input_block) + input_block

            # concat to output

            out = torch.cat((out, input_block), dim = -2)

            # set new xl memories and states

            states = next_states

            if input_block_length == w:
                xl_memories = self.mem_manager(xl_memories, next_xl_memories)

        # project to logits

        logits = self.to_logits(out)

        # detach the states and memories

        returned_next_states = list(map(torch.detach, states)) if return_memories_and_states else None
        returned_next_xl_memories = list(map(torch.detach, xl_memories)) if return_memories_and_states else None

        # whether to return logits

        if not return_loss:
            return logits, returned_next_xl_memories, returned_next_states

        # cross entropy loss

        logits = rearrange(logits, 'b n c -> b c n')
        loss = F.cross_entropy(logits, labels, ignore_index = self.ignore_index)

        return loss, returned_next_xl_memories, returned_next_states

# recurrent trainer wrapper

@beartype
class RecurrentTrainerWrapper(nn.Module):
    def __init__(
        self,
        transformer: BlockRecurrentTransformer,
        xl_memories_dropout = 0.,
        state_dropout = 0.
    ):
        super().__init__()
        self.transformer = transformer
        self.seq_len = transformer.max_seq_len

        self.xl_memories_dropout = xl_memories_dropout
        self.state_dropout = state_dropout

    @eval_decorator
    @torch.no_grad()
    def generate(
        self,
        prime,
        length,
        **kwargs
    ):
        seq_len = self.seq_len
        start_len = prime.shape[-1]
        assert start_len < length

        output = prime
        current_len = start_len

        memories = []
        states = []

        # determine lengths

        has_remainder = not divisible_by(length, seq_len)
        remainder_amount = length % seq_len
        total_segments = math.ceil(length / seq_len)

        if not has_remainder:
            lengths = (*((seq_len + 1,) * (total_segments - 1)), seq_len)
        elif remainder_amount == 1:
            lengths = (seq_len + 1,) * (total_segments - 1)
        else:
            lengths = (*((seq_len + 1,) * (total_segments - 1)), remainder_amount)

        # loop through lengths

        for next_length in lengths:
            segment_output, memories, states = self.transformer.generate(
                output[:, -current_len:],
                length = next_length,
                xl_memories = memories,
                states = states,
                return_memories_and_states = True,
                **kwargs
            )

            output = torch.cat((output, segment_output), dim = -1)
            current_len = 1

        return output[:, start_len:]

    def forward(
        self,
        x,
        return_memories_and_states = False
    ):
        total_seq_len, seq_len = x.shape[1], self.seq_len
        assert divisible_by(total_seq_len - 1, seq_len), f'length of sequence ({total_seq_len}) must be equal to a multiple of {seq_len} + 1 (one extra token) during training'
        segments = total_seq_len // seq_len

        total_loss = 0.

        memories = []
        states = []

        for ind in range(segments):
            start = ind * seq_len
            end = start + seq_len + 1

            if self.training and random() < self.xl_memories_dropout:
                memories.clear()

            if self.training and random() < self.state_dropout:
                states.clear()

            loss, memories, states = self.transformer(
                x[:, start:end],
                xl_memories = memories,
                states = states,
                return_loss = True
            )

            total_loss = total_loss + (loss / segments)

        if return_memories_and_states:
            return total_loss, memories, states

        return total_loss

if __name__ == '__main__':
    model = BlockRecurrentTransformer(
        num_tokens = 20000,          # vocab size
        dim = 512,                   # model dimensions
        depth = 6,                   # depth
        dim_head = 64,               # attention head dimensions
        heads = 8,                   # number of attention heads
        max_seq_len = 1024,          # the total receptive field of the transformer, in the paper this was 2 * block size
        block_width = 512,           # block size - total receptive field is max_seq_len, 2 * block size in paper. the block furthest forwards becomes the new cached xl memories, which is a block size of 1 (please open an issue if i am wrong)
        num_state_vectors = 512,     # number of state vectors, i believe this was a single block size in the paper, but can be any amount
        recurrent_layers = (4,),     # where to place the recurrent layer(s) for states with fixed simple gating
        use_compressed_mem = False,  # whether to use compressed memories of a single block width, from https://arxiv.org/abs/1911.05507
        compressed_mem_factor = 4,   # compression factor of compressed memories
        use_flash_attn = False       # use flash attention, if on pytorch 2.0
    )

    seq = torch.randint(0, 2000, (1, 512))

    out, mems1, states1 = model(seq)
    out, mems2, states2 = model(seq, xl_memories = mems1, states = states1)
    out, mems3, states3 = model(seq, xl_memories = mems2, states = states2)
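
The __main__ block above only runs the bare model three times, passing memories and states forward by hand. For training on sequences longer than max_seq_len, the source also provides RecurrentTrainerWrapper, which splits a long sequence into segments and carries the XL memories and recurrent states across them. Below is a minimal sketch of that usage; the hyperparameters, learning rate, and random token batch are illustrative placeholders, not values taken from the paper or the repository.

import torch
from torch.optim import Adam

# minimal training sketch using the classes defined above
# (BlockRecurrentTransformer, RecurrentTrainerWrapper);
# hyperparameters and the random token batch are placeholders

model = BlockRecurrentTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    max_seq_len = 1024,
    block_width = 512,
    num_state_vectors = 512,
    recurrent_layers = (4,)
)

trainer = RecurrentTrainerWrapper(
    model,
    xl_memories_dropout = 0.1,  # occasionally drop the carried xl memories between segments
    state_dropout = 0.1         # occasionally drop the carried recurrent states between segments
)

optimizer = Adam(model.parameters(), lr = 2e-4)

# training length must be a multiple of max_seq_len plus one extra token,
# since each segment is trained with next-token prediction
seq = torch.randint(0, 20000, (2, 4096 + 1))

loss = trainer(seq)  # the wrapper loops over segments, carrying memories and states
loss.backward()
optimizer.step()
optimizer.zero_grad()

# generation beyond a single window also goes through the wrapper
generated = trainer.generate(seq[:, :128], length = 2048)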
