
[Karpathy's build-nanogpt] Take-Away Notes

Bilibili translated version: LINK

Personal Note

Andrej rebuilds GPT-2 from scratch in PyTorch.

Take Away Points

  • Before entering serious training, he uses Shakespeare's works as a small debugging dataset to check that the model can overfit. Overfitting on such a tiny dataset is the expected (and desired) result.
  • If we use TF32 or BF16 (PyTorch defaults to FP32, which takes more memory and bandwidth), the GPU can do the matrix math many times faster. Because the tensor cores compute so quickly, most of the time (often more than 40%) is spent waiting for memory allocation/transfer rather than on arithmetic. Every large matrix multiplication is executed by breaking it down into small tile multiplications (e.g. 4x4) on the tensor cores.
  • When timing GPU code, remember to call torch.cuda.synchronize() before reading the clock, because CUDA kernels run asynchronously (see the timing sketch after this list).
  • Watch the GPU: watch -n 0.1 nvidia-smi
  • torch.set_float32_matmul_precision('high') easily activates TF32 mode
    • 'highest' (default) -> FP32
    • 'high' -> TF32 if available (depends on the GPU)
  • Simply using torch.set_float32_matmul_precision('high') should theoretically give about an 8x speedup. However, we only achieve about 3x, because we are still memory-bound: moving data around still costs a lot.
  • BF16 autocast can only be used on Ampere (or newer) GPUs:
    use torch.autocast(device_type=device, dtype=torch.bfloat16) to wrap the forward pass. Inside this context some CUDA ops are autocast to BF16 while many others stay in FP32; matrix multiplications run in BF16 (see the precision/attention sketch after this list).
  • One debug technique: import code; code.interact(local=locals())
  • torch.compile! model = torch.compile(model)
  • Flash Attention (and FlashAttention-2), built on online softmax:
    use F.scaled_dot_product_attention(q, k, v, is_causal=True) instead of the hand-written attention.
  • Look for ugly numbers and round them up to nice ones: pad any awkward dimension up to a multiple of a power of two (for example, the GPT-2 vocab size 50257 is padded to 50304). Although the FLOPs increase, the wall-clock time decreases.
  • ALL OF THE ABOVE CHANGES MAKE THE TRAINING PROGRAM ROUGHLY 10x FASTER!!
  • Linear warmup + cosine learning-rate decay with a minimum learning rate; see the GPT-3 paper for details (a schedule sketch follows this list).
  • In the first stage of training the model is not yet differentiating between individual tokens; it is mainly learning which tokens can appear at all and driving the probabilities of the rest toward zero. That is why a small batch size is fine early on: the gradients would not look much different if you used the full batch size.
  • Split the parameters into those that should be weight-decayed and those that should not. WD: all weight matrices and embeddings (p.dim() >= 2); no WD: all biases and layernorm parameters (p.dim() < 2). See the optimizer sketch after this list.
  • AdamW's fused option (use_fused) accelerates the optimizer step.
  • Model size up, lr down, batchsize up.
  • Grad accumulation: Remember to normalize: loss /= grad_accum_steps
  • During evaluation/sampling, create a dedicated torch.Generator object and pass it to torch.multinomial(..., generator=...), so that sampling does not disturb the global random number generator used for training (see the sampling sketch after this list).
  • However, torch.compile has to be disabled for this to work, i.e. if you want to sample during the training run.
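
Timing sketch. A minimal example of the synchronize point above, assuming a placeholder model and batch (not from the original script). Without torch.cuda.synchronize(), time.time() would return before the queued kernels have actually finished.

import time
import torch
import torch.nn as nn

model = nn.Linear(1024, 1024).cuda()      # placeholder model
x = torch.randn(16, 1024, device='cuda')  # placeholder batch

t0 = time.time()
out = model(x)
loss = out.sum()
loss.backward()
torch.cuda.synchronize()                  # wait for all queued CUDA kernels to finish
dt = time.time() - t0
print(f"step time: {dt*1000:.2f} ms")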
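
Precision/attention sketch. A rough, self-contained illustration of the TF32, BF16-autocast, torch.compile and scaled_dot_product_attention points; the toy model and tensor shapes are made up for the example.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_float32_matmul_precision('high')  # 'highest' = FP32, 'high' = TF32 where supported

model = nn.Linear(1024, 1024).cuda()        # placeholder model
x = torch.randn(16, 1024, device='cuda')
model = torch.compile(model)                # fuse kernels, cut memory round-trips

# BF16 autocast around the forward pass (Ampere or newer):
# matmuls run in BF16, numerically sensitive ops stay in FP32
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    out = model(x)

# inside an attention block, replace the manual softmax(QK^T / sqrt(d)) V with the fused kernel
B, n_head, T, head_dim = 16, 12, 1024, 64
q = torch.randn(B, n_head, T, head_dim, device='cuda', dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)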
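
LR-schedule sketch. One way to write the linear-warmup + cosine-decay schedule with a floor; the constants (max_lr, warmup_steps, max_steps) are illustrative placeholders, not values from these notes.

import math

max_lr = 6e-4
min_lr = max_lr * 0.1   # the minimum learning rate (the floor)
warmup_steps = 10
max_steps = 50

def get_lr(step):
    # 1) linear warmup up to max_lr
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    # 2) after max_steps, stay at the floor
    if step > max_steps:
        return min_lr
    # 3) cosine decay from max_lr down to min_lr in between
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)

# in the training loop:
# lr = get_lr(step)
# for param_group in optimizer.param_groups:
#     param_group['lr'] = lr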
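
Optimizer sketch. One way to build the weight-decay split plus fused AdamW described above; the function name and the hyperparameters (betas, eps, weight_decay) are assumptions, not taken from these notes.

import inspect
import torch

def configure_optimizer(model, weight_decay=0.1, learning_rate=6e-4, device_type='cuda'):
    params = [p for p in model.parameters() if p.requires_grad]
    decay_params = [p for p in params if p.dim() >= 2]    # weight matrices + embeddings
    nodecay_params = [p for p in params if p.dim() < 2]   # biases, layernorm parameters
    optim_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]
    # use the fused AdamW kernel when this PyTorch build exposes it and we are on CUDA
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and device_type == 'cuda'
    return torch.optim.AdamW(optim_groups, lr=learning_rate,
                             betas=(0.9, 0.95), eps=1e-8, fused=use_fused)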
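
Sampling sketch. A rough top-k sampling loop that draws from its own torch.Generator so it never touches the training RNG; the model (assumed to return (logits, loss)), the placeholder prompt, and the seed are assumptions for illustration.

import torch
import torch.nn.functional as F

device = 'cuda'
sample_rng = torch.Generator(device=device)
sample_rng.manual_seed(42)   # e.g. 42 + ddp_rank so each process samples differently

# placeholder prompt: 4 sequences of 8 token ids
xgen = torch.full((4, 8), 1, dtype=torch.long, device=device)
max_length = 32

model.eval()
with torch.no_grad():
    while xgen.size(1) < max_length:
        logits, _ = model(xgen)            # assumed to return (logits, loss)
        logits = logits[:, -1, :]          # keep only the last position
        probs = F.softmax(logits, dim=-1)
        topk_probs, topk_idx = torch.topk(probs, 50, dim=-1)
        # sampling uses sample_rng, not the global RNG used for training
        ix = torch.multinomial(topk_probs, 1, generator=sample_rng)
        xcol = torch.gather(topk_idx, -1, ix)
        xgen = torch.cat((xgen, xcol), dim=1)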

SAMPLE CODE FOR DDP

# torchrun --standalone --nproc_per_node=<num_gpu_per_node> <your_training_script.py> <script_arguments>
# The torchrun command above only applies to single-node training.
import os
import time
import torch
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

# SETTINGS FOR EACH DIFFERENT RANK
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    assert torch.cuda.is_available()
    init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])              # global rank, unique for each process
    ddp_local_rank = int(os.environ['LOCAL_RANK'])  # local rank within this machine (node)
    ddp_world_size = int(os.environ['WORLD_SIZE'])  # how many GPUs (processes) in total
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0
else:
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    print(f"using device: {device}")
device_type = "cuda" if device.startswith("cuda") else "cpu"  # autocast wants the bare device type

# IF YOU USE GRAD ACCUMULATION
total_batch_size = 524288  # batch size measured in number of tokens
B = 16                     # micro batch size for each process
T = 1024                   # sequence length
assert total_batch_size % (B * T * ddp_world_size) == 0
grad_accum_steps = total_batch_size // (B * T * ddp_world_size)

# SET DATALOADER
dataloader = DataLoader(*args, ddp_world_size, ddp_rank)  # MUST! make each process deal with a different part of the dataset

# CREATE MODEL
model = createmodel()
model.to(device)
model = torch.compile(model)
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])  # this must be ddp_local_rank, not ddp_rank
raw_model = model.module if ddp else model

# FIX SEED
seed = 1337  # YOUR LUCKY NUMBER
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

# TRAIN
for step in range(max_steps):
    t0 = time.time()
    model.train()
    optimizer.zero_grad()
    loss_accum = 0.0
    for micro_step in range(grad_accum_steps):
        x, y = dataloader.next_batch()
        x, y = x.to(device), y.to(device)
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(x, y)
        loss = loss / grad_accum_steps
        loss_accum += loss.detach()
        if ddp:
            # Syncing gradients on every micro step wastes time, so only the last backward
            # in one accumulation cycle should synchronize. See the ddp.no_sync() context
            # manager for the official advice, or set the flag directly as shown here.
            model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
        loss.backward()
    if ddp:
        torch.distributed.all_reduce(loss_accum, op=torch.distributed.ReduceOp.AVG)
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    if step % 100 == 0:
        # start evaluation
        model.eval()
        with torch.no_grad():
            pass  # SOME EVALUATION CODE

if ddp:
    destroy_process_group()
