当前位置: 首页 > news >正文

实时瞳孔分割算法-RITnet论文复现

源代码在这里RITnet-Github

        这个模型比较小众,我们实验室使用了官方提供的模型进行瞳孔中心位置提取,以实现视线追踪,效果很好


一、数据集准备

        RITnet也是那一届openEDS数据集挑战赛的冠军模型,openEDS数据集可以从Kaggle上下载,一共9G,大家不用去找openEDS论文里提供的数据集下载地址,那个需要Facebook账户去进行申请,很麻烦

        从Kaggle上下载下来的数据集长这样:

        只需要openEDS,它包括了以下内容:

        可以看到有很多文件夹,但代码中只用到了traintest文件夹中的文件,所以其它的先不用管

看看瞳孔图像的样子:

展示了上述内容后,相信大家对数据集这块有足够的了解了!最重要的一点:只需要把Semantic_Segmentation_Dataset拷贝进代码项目中(不拷贝也行,重定义一下路径就好了),文件摆放格式什么的都不用改,妥妥的保姆级


二、环境配置

        我所用的 IDE 是 Pycharm 。从 Github 上下载下来的项目文件中包含2个对项目所需环境的描述文件:requirements.txt、environment.yml,前者只是记录了项目所用的包的名字,并没有各个包的具体版本信息;后者是一种便捷的环境打包文件,记录的就是原作者运行代码时的全部环境,但我在浅浅尝试之后就放弃了这种方法,因为我不会而且也用不好

        所以我自己新创建了一个 conda 虚拟环境,对新手更友好,条理也非常清晰!

1.创建 python = 3.8 的 Anaconda 虚拟环境

        没有经验的同学可以参考这篇帖子:从零开始创建conda环境及pycharm配置项目环境

        打开 Anaconda Prompt,使用下面这条指令就能创建新的虚拟环境了

# Success 是自定义的环境名,python=3.8也是自定义的python版本
conda create -n Success python=3.8

        再使用下面这条指令看看环境有没有创建好,出现下图就是好消息

activate Success

这个环境的具体位置在你之前下载的 Anaconda 文件夹下,看看我的:

Anaconda2023 是我当时在安装 Anaconda 时新建的文件夹名,而 envs 中就存放着我们建立过的所有虚拟环境啦

创建好之后,把这个环境配置进 Pycharm 里。步骤为:

左上角 -> 设置

项目:RITNet -> Python解释器 -> 添加解释器 -> 添加本地解释器

Virtualenv环境 -> 上面提到的 Anaconda 所处位置 -> envs -> 虚拟环境文件夹 -> python.exe(目标文件),选定 python.exe 点“确定”就好了

右下角变成这样就大功告成了

2.在当前虚拟环境中安装所需包

        这块是最复杂最关键的部分,所用篇幅较长

a.Pytorch-GPU 安装

        在 RITnet 项目中用到了 GPU 来加速模型训练,需要安装 GPU 版本的 Pytorch,这里有很多坑,但经过我长时间的摸索已经总结出一套必杀技:

        大家应该经常在网上看到这2条指令 nvcc -V 、nvidia-smi 它们都是用来查看自己电脑上所安装的CUDA版本的,区别在于 nvidia-smi 所查看是自己电脑本身的 CUDA 版本,而 nvcc -V 指令只有你的电脑上有 Pytorch 时才能运行成功(不论GPU版本还是CPU版本),这2条语句的差别很大,而我们只需使用 nvidia-smi 来查看自己电脑的 CUDA 信息即可,如:

        我电脑的 CUDA 版本是 12.3,我们需要根据它来安装 CUDA ,进而安装对应版本的 Pytorch-GPU.

        CUDA 安装总体可以参考这篇帖子:全网最详细的安装pytorch GPU方法全网最详细之如何安装gpu版的pytorch,但别完全参考,有小坑

     而在安装 CUDA 时我遇到了 “NVIDIA 安装程序失败的问题”,很棘手,但解决了,解决方式如下:

        在选择组件(自定义安装)的时候,将 CUDA 中的 Nsight VSEVisual Studio Integration 取消勾选,后选择下一步,即可安装成功。此招式来自NVIDIA安装CUDA在安装阶段提示NVIDIA安装程序失败超级有用(给磕了)


        安装好 CUDA 后,不要使用 pip install 指令直接从 Pytorch 官网下载 GPU 版本的 Pytorch,我下了超多次,结果下下来的都是 CPU 版本的,参考这个全网最详细之如何安装gpu版的pytorch

        我将安装的 CUDA 和 torch 等包的对应关系放进下方的表格里,完全按照这个来就行

CUDA12.0.0
torch2.0.0
torchvision0.15.1
torchaudio2.0.1

       点击进入下载地址,我们需要根据上述表格下载以下3个文件(需要VPN,不然超慢)

torch 下载

torchvision 下载

torchaudio 下载

        将上述3个文件都下载好后,进入文件所在目录(切记),在上面的搜索栏里输入 cmd,然后回车即可进入 dos

        在 dos 里激活刚刚创建好的虚拟环境后,使用下面这3行指令将上述3个文件都安装进我们的虚拟环境

pip install "torch-2.0.0+cu118-cp38-cp38-win_amd64.whl"pip install "torchvision-0.15.1+cu118-cp38-cp38-win_amd64.whl"pip install "torchaudio-2.0.1+cu118-cp38-cp38-win_amd64.whl"

        结束了,一切都结束了!最终用下面这段测试代码美美验证一下 Pytorch-GPU 到底装好没有!

import torch    # 测试是否安装完成torch模块
import torchvision  # 测试是否安装完成torchvision模块
import osif __name__ == '__main__':print("安装torch版本为: ", torch.__version__)print("是否安装完成Pytorch-GPU : ", torch.cuda.is_available())arr = torch.zeros(5, 5)print("生成全零矩阵:\n", arr)

        可喜可贺,实在是可喜可贺!

b.其它包的安装

        全部使用 pip 进行安装(一定要关掉 VPN)

pip install scikit-learn
pip install numpy
pip install opencv-python  # 即cv2
pip install pillow
pip install matplotlib
pip install tqdm
pip install torchsummary
pip install argparse

三、代码调整

        想要运行 train.py,我们还需对其它 .py 文件进行调整.由于给原始瞳孔图像添加星爆图像这部分一直报错(维度问题),尝试解决了很多次,但都没能成功解决,于是把这部分图像预处理代码注释掉了,如果后续有哪位朋友解决了这个问题,还请和大家分享分享!

        1.将 dataset.py 中的 Starburst_augment 类全部注释,同时在 IrisDataset 类中的 __getitem__ 函数中注释掉调用 Starburst_augment 类的代码

        2.将 __init__中的一行代码转移到__getitem__ 中去.具体操作见下方:

class IrisDataset(Dataset):def __init__(self, filepath, split='train', transform=None, **args):self.transform = transformself.filepath = osp.join(filepath, split)self.split = splitlistall = []for file in os.listdir(osp.join(self.filepath, 'images')):if file.endswith(".png"):listall.append(file.strip(".png"))self.list_files = listallself.testrun = args.get('testrun')# 将下面这行代码换个位置,换到下方的__getitem__函数中去self.clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))def __len__(self):if self.testrun:return 10return len(self.list_files)def __getitem__(self, idx):# 从init移过来的self.clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))imagepath = osp.join(self.filepath, 'images', self.list_files[idx] + '.png')pilimg = Image.open(imagepath).convert("L")H, W = pilimg.width, pilimg.height

        整体复现流程就是这样,如果大家自己在复现过程中遇到解决不了的问题可以来私信我(复现论文真的好麻烦好累但别无它选

相关文章:

  • 【Python】超时请求或计算的处理
  • Dify快速接入微信
  • C++之map
  • OVF(Open Virtualization Format)或OVA(Open Virtual Appliance)格式有什么区别
  • OpenGL 和 DirectX 矩阵 乘法 左乘和右乘,glm和DirectXMath。 OpenGL用列矩阵 在 glsl中反转矩阵
  • C语言(数据存储)
  • 计算机网络学习实践:模拟PPP协议验证虚拟局域网(VLAN)
  • 【人工智能Ⅱ】实验8:生成对抗网络
  • 做外贸,怎么选国外服务器?
  • “神经网络之父”和“深度学习鼻祖”Geoffrey Hinton
  • 字节裁员!开启裁员新模式。。
  • 鸿蒙开发接口资源调度:【@ohos.backgroundTaskManager (后台任务管理)】
  • uni微信小程序editor富文本组件如何插入图片
  • Linux离线一键安装Docker及docker-compose环境
  • Double Add
  • ----------
  • 【译】理解JavaScript:new 关键字
  • 〔开发系列〕一次关于小程序开发的深度总结
  • 230. Kth Smallest Element in a BST
  • Android开源项目规范总结
  • el-input获取焦点 input输入框为空时高亮 el-input值非法时
  • ES学习笔记(12)--Symbol
  • JavaScript实现分页效果
  • js ES6 求数组的交集,并集,还有差集
  • js对象的深浅拷贝
  • JS基础之数据类型、对象、原型、原型链、继承
  • JS题目及答案整理
  • Linux Process Manage
  • Lsb图片隐写
  • miniui datagrid 的客户端分页解决方案 - CS结合
  • NLPIR语义挖掘平台推动行业大数据应用服务
  • PAT A1092
  • 聊聊spring cloud的LoadBalancerAutoConfiguration
  • 模型微调
  • 我的zsh配置, 2019最新方案
  • # 数仓建模:如何构建主题宽表模型?
  • #NOIP 2014# day.1 生活大爆炸版 石头剪刀布
  • #QT(TCP网络编程-服务端)
  • #鸿蒙生态创新中心#揭幕仪式在深圳湾科技生态园举行
  • (12)Hive调优——count distinct去重优化
  • (13)Hive调优——动态分区导致的小文件问题
  • (5)STL算法之复制
  • (8)Linux使用C语言读取proc/stat等cpu使用数据
  • (done) 声音信号处理基础知识(2) (重点知识:pitch)(Sound Waveforms)
  • (NO.00004)iOS实现打砖块游戏(十二):伸缩自如,我是如意金箍棒(上)!
  • (黑马点评)二、短信登录功能实现
  • (数据大屏)(Hadoop)基于SSM框架的学院校友管理系统的设计与实现+文档
  • (图文详解)小程序AppID申请以及在Hbuilderx中运行
  • .NET 8 跨平台高性能边缘采集网关
  • .net core webapi 大文件上传到wwwroot文件夹
  • .NET Core 和 .NET Framework 中的 MEF2
  • .NET 实现 NTFS 文件系统的硬链接 mklink /J(Junction)
  • ;号自动换行
  • ?.的用法
  • @RequestMapping用法详解