当前位置: 首页 > news >正文

NeRF学习——复现训练中的问题记录

代码复现的框架是基于:pengsida 的 Learning NeRF

希望各位可以通过学习 NeRF-Pytorch 的源码来自己复现一下试试看!

文章目录

  • 1 Windows bug
    • 1.1 DataLoader 的多进程 pickle
    • 1.2 imageio 输出图片
    • 1.3 I/O
  • 2 训练问题
    • 2.1 Evaluate 显存爆炸
    • 2.2 尝试一
    • 2.3 尝试二
    • 2.4 尝试三 (最终)
    • 2.5 其余小问题
      • 2.5.1 psnr 不上升
      • 2.5.2 加载再训练

1 Windows bug

源代码框架是基于 Linux 的,我在 Windows 上进行复现有以下的 bug,Windows 上 bug 修复的框架版本:Learning NeRF

1.1 DataLoader 的多进程 pickle

image-20240317200544474

在刚开始执行时就遇到了这个错误,开始以为是环境问题,结果鼓捣半天没用,仔细分析了一下错误原因,发现主要是这句话的原因

for iteration, batch in enumerate(data_loader):

也就是序列化的问题,后面还给出了多进程的报错,就应该是多进程的pickle问题,网上一搜,还真是,可能是Windows系统的原因导致的,在 configs 中的配置文件中修改 num_workers = 0,不使用多进程,就解决了报错(速度应该没啥影响)

1.2 imageio 输出图片

在执行到 evaluate 时,输出图片出了下面的错误:

envs\NeRFlearning\Lib\site-packages\PIL\Image.py", line 3102, in fromarray raise TypeError(msg) from e TypeError: Cannot handle this data type: (1, 1, 3), <f4

这是因为 imageio.imwrite 函数需要接收的图像数据类型为 uint8 ,而原始的 pred_rgbgt_rgb 可能是浮点数类型的数据。因此,我们需要将它们乘以255(将范围从0-1转换为0-255),然后使用 astype 函数将它们转换为 uint8 类型

# 将数据类型转换为 uint8
pred_rgb = (pred_rgb * 255).astype(np.uint8)
gt_rgb = (gt_rgb * 255).astype(np.uint8)
# 需要添加以上两行,否则报错
imageio.imwrite(save_path, img_utils.horizon_concate(gt_rgb, pred_rgb))

这里还有一个问题,我发现在已经生成过一张图片后,再次执行到这里,输出的新图片无法覆盖之前的旧图片,所以我加上了时间和当前周期作为扩展名

now = datetime.datetime.now()
now_str = now.strftime('%Y-%m-%d_%Hh%Mm%Ss')
base_name, ext = os.path.splitext(save_path)
save_path = f"{base_name}_{now_str}_epoch_{cfg.train.epoch}{ext}"

这样就会以这样的 res_2024-03-18_09h09m40s_epoch_10.jpg 格式正确输出每次的图像了

1.3 I/O

此项目是由 Linux 开发的,在Windows系统上,免不了出现各种麻烦。特别是,该项目的所有 I/O 都是 Linux 的 I/O 格式,所以要进行全面修改:

  • 使用 Python 的 os 模块中的 makedirs 函数来替换

    os.system('mkdir -p ' + model_dir)
    # |
    # V
    os.makedirs(model_dir, exist_ok=True)
    
  • shutil 模块中的 rmtree 函数来替换

    os.system('rm -rf {}'.format(model_dir))
    # |
    # V
    import shutil
    if os.path.exists(model_dir):shutil.rmtree(model_dir)
    
  • Python 的 os 模块中的remove函数来替换

    os.system('rm {}'.format(os.path.join(model_dir, '{}.pth'.format(min(pths)))))
    # |
    # V
    os.remove(os.path.join(model_dir, '{}.pth'.format(min(pths))))
    
  • 使用 exit(0) 来结束当前进程

    os.system('kill -9 {}'.format(os.getpid()))
    # |
    # V
    exit(0)
    

2 训练问题

训练从尝试一开始,记录了一步一步修改改进,到尝试三就基本完成复现!

2.1 Evaluate 显存爆炸

在运行 evaluate 模块时,我用模型对一张图片进行推理,发现显存爆炸增长,经过一番寻找之后,找到了问题:

  • 在运行 run.py 的 evaluate 模块时,evaluator.evaluate(output, batch)这句话没有在 torch.no_grad() 中导致的。修改只需要加一个缩进即可

更新:重构了evaluate的代码,之前是我理解出现了问题,在evaluate中再一次进行了模型推理,现在不用了,代码恢复为以前的

2.2 尝试一

训练中:在 Dateset 的 __getitem__ 函数中进行 shuffle,导致每次迭代都会 shuffle 整个数据一次,耗费大量时间

  • 在经过10000次迭代(20个epoch)后:

    • psnr 整体是呈现一个上升的趋势,在10000次迭代后在24左右

      image-20240430195650221
    • loss 从一开始整体下降,之后就一直在0.1225左右徘徊(感觉出现问题)没有明显的下降趋势

      image-20240430195519373
  • 分别在5000次和10000次迭代的时候 evaluate 了两次:

    • 第一次:loss = 0.1201,psnr = 22.94,mse = 0.005223
    • 第二次:loss = 0.1191,psnr = 23.8,mse = 0.004276

    推理得到的图像如下:

其实在训练时获取光线时本来就是打乱了的,传入 __getitem__index 就是随机的,但是发现在移除所有光线的 shuffle 后,训练出现问题(一直 psnr = infloss = 0),具体原因还未知、

2.3 尝试二

现在我将 shuffle 放到了 __init__ 中去,只在创建数据集时进行一次 shuffle,训练正常进行,时间相比以前有了提升!

但出过拟合问题!

  • 在2500次迭代(5个epoch)后:

    • loss 持续下降到 0.000292,停止后再次开始训练时断崖上升

      image-20240430203600663
    • psnr 持续上升到40,停止后再次开始训练时断崖下降

      image-20240430203714312
  • 使用 run.py 的 evaluate 模块进行测试:

    输出的图像与指标不符,loss = 0.055psnr = 15.4939mse = 0.0282;具体生成图像也是比较糟糕,只能看到大致雏形

    image-20240430212522124

问题:

发现了问题出现在数据集的 __getitem__ 上,每次传入的 index 是随机图片索引,只有只有 1~200,只能得到前面的 1~200*1024 的数据太狭隘了,导致了过拟合。

目前解决方案在 index 的基础上乘上图像的宽高:

index = index * self.H * self.W

2.4 尝试三 (最终)

经过上述的修改,并且参考 nerf-pytorch 的代码在每隔一段时间(我现在暂时设定self.N_rays * cfg.ep_iter 次,即 1024*500)就会重新打乱一遍所有的32000000条光线

我将代码放到了阿里云的服务器上进行训练,经过10.48小时的训练,总共迭代了71,000次(142个epoch)

  • 训练:

    • loss 持续下降,目前在0.003左右波动

      image-20240501085109274
    • psnr 持续上升,目前在32左右波动

      image-20240501085258470
  • evaluate:我设定了每隔2500次迭代(5个epoch)就进行一次 evaluate,每次为了节省时间只用10张图片进行测试

    loss 持续下降到0.006mse 持续下降到0.0018psnr 持续上升到27.6134

    image-20240501085755386

    得到的图片:

    009

    训练完后对整体(200张图片)进行了一次 evaluate:得到 psnr = 28.4860

  • 与尝试一进行对比:在10000次迭代左右,时间大大减少,且loss、psnr都更加优秀

    image-20240501090741066

2.5 其余小问题

2.5.1 psnr 不上升

有时开始训练时的psnr从9左右开始就会导致 psnr 不上升一直徘徊在9左右,loss 正常下降,重新开始训练就有可能回归正常!

image-20240430212144837

尝试了5000次的迭代(10个epoch),测试出来图片如下:

lQLPKHHImplXjiHNAlTNBKqwQyL9R_cVfYgGGp3Cg3rNAA_1194_596

目前原因可能是初始化导致的😭(小问题,重新开始训练就行)

2.5.2 加载再训练

当我将保存的模型加载接着训练的时候,我发现了 loss 相较于之前突然变大了再缓慢下降,psnr 也是相似的,一开始相较之前的要小再缓慢上升

应该是存储和读取模型时的问题

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 【全国大学生电子设计竞赛】2022年E题
  • TCP Analysis Flags 之 TCP Window Full
  • 解决 Vue 页面中地址栏参数变更不刷新的问题
  • react防抖和节流hooks封装
  • Hystrix 线程池策略时使用ThreadLocal
  • 【LeetCode】219.存在重复元素II
  • STM32卡死、跑飞如何调试确定问题
  • CMD运行指令
  • 鸿蒙系统开发【ASN.1密文转换】安全
  • 线程池工具类 Executors源代码详解
  • 基于Redis实现全局唯一id
  • 小试牛刀-Telebot区块链游戏机器人(TS升级)
  • 【Python】数据类型之详讲字符串(下)
  • 全球轻型汽车安全气囊面料市场规划预测:未来六年CAGR为4.3%
  • 1. 什么是操作系统
  • 【Leetcode】101. 对称二叉树
  • [分享]iOS开发-关于在xcode中引用文件夹右边出现问号的解决办法
  • 5、React组件事件详解
  • Java深入 - 深入理解Java集合
  • React-redux的原理以及使用
  • SAP云平台运行环境Cloud Foundry和Neo的区别
  • windows-nginx-https-本地配置
  • 测试开发系类之接口自动化测试
  • 第2章 网络文档
  • 关于 Linux 进程的 UID、EUID、GID 和 EGID
  • 解析 Webpack中import、require、按需加载的执行过程
  • 主流的CSS水平和垂直居中技术大全
  • 3月7日云栖精选夜读 | RSA 2019安全大会:企业资产管理成行业新风向标,云上安全占绝对优势 ...
  • LevelDB 入门 —— 全面了解 LevelDB 的功能特性
  • shell使用lftp连接ftp和sftp,并可以指定私钥
  • Spring第一个helloWorld
  • 东超科技获得千万级Pre-A轮融资,投资方为中科创星 ...
  • 没有任何编程基础可以直接学习python语言吗?学会后能够做什么? ...
  • ​如何在iOS手机上查看应用日志
  • ​软考-高级-系统架构设计师教程(清华第2版)【第15章 面向服务架构设计理论与实践(P527~554)-思维导图】​
  • # 消息中间件 RocketMQ 高级功能和源码分析(七)
  • ######## golang各章节终篇索引 ########
  • #C++ 智能指针 std::unique_ptr 、std::shared_ptr 和 std::weak_ptr
  • $(selector).each()和$.each()的区别
  • $jQuery 重写Alert样式方法
  • (C#)Windows Shell 外壳编程系列9 - QueryInfo 扩展提示
  • (zt)基于Facebook和Flash平台的应用架构解析
  • (八)Docker网络跨主机通讯vxlan和vlan
  • (编译到47%失败)to be deleted
  • (多级缓存)缓存同步
  • (附源码)springboot掌上博客系统 毕业设计063131
  • (六)vue-router+UI组件库
  • (七)glDrawArry绘制
  • (七)Java对象在Hibernate持久化层的状态
  • (提供数据集下载)基于大语言模型LangChain与ChatGLM3-6B本地知识库调优:数据集优化、参数调整、Prompt提示词优化实战
  • (一)Linux+Windows下安装ffmpeg
  • (一)WLAN定义和基本架构转
  • (转)C#调用WebService 基础
  • (转)shell调试方法
  • (转载)虚幻引擎3--【UnrealScript教程】章节一:20.location和rotation