当前位置: 首页 > news >正文

图像分割实战-系列教程7:unet医学细胞分割实战5(医学数据集、图像分割、语义分割、unet网络、代码逐行解读)

🍁🍁🍁图像分割实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传

unet医学细胞分割实战1
unet医学细胞分割实战2
unet医学细胞分割实战3
unet医学细胞分割实战4
unet医学细胞分割实战5
unet医学细胞分割实战6

9 模型架构类----archs.py解读

这部分内容主要解析本任务使用的网络,主要有两个网络可以选择,一个是Unet另一个是NestedUNet,实际上就是UNet++,这两个网络的都是主要调用了VGG块来进行网络的构建

9.1 VGGBlock

import torch
from torch import nn
__all__ = ['UNet', 'NestedUNet']
class VGGBlock(nn.Module):def __init__(self, in_channels, middle_channels, out_channels):super().__init__()self.relu = nn.ReLU(inplace=True)self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)self.bn1 = nn.BatchNorm2d(middle_channels)self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)return out

首先来看看一个VGG块,实际上就是数据经过几个卷积relu:

  1. 输入数据
  2. 经过一个3*3的卷积
  3. 经过一个batchNormalization
  4. 经过一个relu
  5. 再次经过一个3*3的卷积
  6. 再次经过一个batchNormalization
  7. 再次经过一个relu
  8. 得到输出

这就是一个VGG块的过程,其中每次进入的数据的长宽、输出通道都是在调用VGG块的时候进行定义的,每一个VGG块有三个参数需要指定,分别是输入通道数、中间通道数、输出通道数

9.2 Unet

class UNet(nn.Module):def __init__(self, num_classes, input_channels=3, **kwargs):super().__init__()nb_filter = [32, 64, 128, 256, 512]self.pool = nn.MaxPool2d(2, 2)self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])self.conv2_2 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])self.conv1_3 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])self.conv0_4 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)def forward(self, input):x0_0 = self.conv0_0(input)x1_0 = self.conv1_0(self.pool(x0_0))x2_0 = self.conv2_0(self.pool(x1_0))x3_0 = self.conv3_0(self.pool(x2_0))x4_0 = self.conv4_0(self.pool(x3_0))x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1))x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1))x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1))output = self.final(x0_4)return output

Unet网络,主要都是调用VGG块来构建的:

  1. 首先输入数据
  2. 进入一个定义好的VGG块conv0_0 ,得到x0_0
  3. x1_0、x2_0、x3_0、x4_0都是先经过一个(2,2)的maxpooling后,再经过一个定义好的VGG块
  4. 而x3_1、x2_2、x1_3、x0_4都是先与其对应的数据进行拼接后再经过一个定义好的VGG块,具体原理可以参考这篇文章
  5. 最后把x0_4的输出经过一个二维卷积得到最终的输出

9.3 NestedUNet

9.3.1 构造函数

class NestedUNet(nn.Module):def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):super().__init__()nb_filter = [32, 64, 128, 256, 512]self.deep_supervision = deep_supervisionself.pool = nn.MaxPool2d(2, 2)self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])if self.deep_supervision:self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)else:self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

9.3.2 前向传播

    def forward(self, input):x0_0 = self.conv0_0(input)x1_0 = self.conv1_0(self.pool(x0_0))x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))x2_0 = self.conv2_0(self.pool(x1_0))x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))x3_0 = self.conv3_0(self.pool(x2_0))x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))x4_0 = self.conv4_0(self.pool(x3_0))x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))if self.deep_supervision:output1 = self.final1(x0_1)output2 = self.final2(x0_2)output3 = self.final3(x0_3)output4 = self.final4(x0_4)return [output1, output2, output3, output4]else:output = self.final(x0_4)return output

NestedUNet即UNet++,与UNet大同小异,关于UNet++的解析在这里

  1. 首先输入数据
  2. 先经过一个VGG块得到x0_0
  3. x0_0 经过一个maxpooling后再经过一个VGG块得到x1_0
  4. 拼接x1_0 和上采样后的x0_0 后再经过一个VGG块得到x0_1
  5. x1_0 经过一个maxpooling后再经过一个VGG块得到x2_0
  6. 拼接x1_0 和上采样后的x2_0 后再经过一个VGG块得到x1_1
  7. 最终分别得到x0_1、x0_2、x0_3、x0_4,这4个都可以作为输出

这就是整个的模型架构,如果需要进行深入的掌握,建议把每一个前向传播的过程的数据维度打印出来

unet医学细胞分割实战1
unet医学细胞分割实战2
unet医学细胞分割实战3
unet医学细胞分割实战4
unet医学细胞分割实战5
unet医学细胞分割实战6

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • LeetCode 1758. 生成交替二进制字符串的最少操作数【字符串,模拟】1353
  • 2024年01月数据库流行度最新排名
  • 开源掌机是什么?
  • FA模板制作
  • R_handbook_统计分析
  • 数据结构:队列(链表和数组模拟实现)
  • 大数据机器学习GAN:生成对抗网络GAN全维度介绍与实战
  • 系统学习Python——装饰器:函数装饰器-[对方法进行装饰:基础知识]
  • 基础算法-归并排序
  • 20231228在Firefly的AIO-3399J开发板的Android11使用Firefly的DTS配置单前后摄像头ov13850
  • Pandas的apply方法的应用练习
  • 2023-12-12LeetCode每日一题(下一个更大元素 IV)
  • SDG大数据平台简介
  • [ 云计算 | AWS ] 对比分析:Amazon SNS 与 SQS 消息服务的异同与选择
  • Java——功能开发思路
  • 【159天】尚学堂高琪Java300集视频精华笔记(128)
  • 【Under-the-hood-ReactJS-Part0】React源码解读
  • C++类的相互关联
  • Dubbo 整合 Pinpoint 做分布式服务请求跟踪
  • FastReport在线报表设计器工作原理
  • Python打包系统简单入门
  • Redis在Web项目中的应用与实践
  • Spring Cloud Feign的两种使用姿势
  • 驱动程序原理
  • 通过git安装npm私有模块
  • 微服务入门【系列视频课程】
  • const的用法,特别是用在函数前面与后面的区别
  • # windows 安装 mysql 显示 no packages found 解决方法
  • #!/usr/bin/python与#!/usr/bin/env python的区别
  • #、%和$符号在OGNL表达式中经常出现
  • #周末课堂# 【Linux + JVM + Mysql高级性能优化班】(火热报名中~~~)
  • (02)Unity使用在线AI大模型(调用Python)
  • (2)(2.4) TerraRanger Tower/Tower EVO(360度)
  • (day 12)JavaScript学习笔记(数组3)
  • (动手学习深度学习)第13章 计算机视觉---微调
  • (二)Eureka服务搭建,服务注册,服务发现
  • (二)JAVA使用POI操作excel
  • (二十六)Java 数据结构
  • (附源码)ssm基于jsp高校选课系统 毕业设计 291627
  • (机器学习-深度学习快速入门)第三章机器学习-第二节:机器学习模型之线性回归
  • (三维重建学习)已有位姿放入colmap和3D Gaussian Splatting训练
  • (四)汇编语言——简单程序
  • (原創) X61用戶,小心你的上蓋!! (NB) (ThinkPad) (X61)
  • (转)如何上传第三方jar包至Maven私服让maven项目可以使用第三方jar包
  • (转)原始图像数据和PDF中的图像数据
  • .NET Conf 2023 回顾 – 庆祝社区、创新和 .NET 8 的发布
  • .NET Core 控制台程序读 appsettings.json 、注依赖、配日志、设 IOptions
  • .NET的微型Web框架 Nancy
  • .pings勒索病毒的威胁:如何应对.pings勒索病毒的突袭?
  • @Transient注解
  • [ Linux ] Linux信号概述 信号的产生
  • [ Python ]使用Charles对Python程序发出的Get与Post请求抓包-解决Python程序报错问题
  • [AIGC] Redis基础命令集详细介绍
  • [C]编译和预处理详解
  • [C++]模板与STL简介