当前位置: 首页 > news >正文

Siamese Network(孪生网络/连体网络) (few-shot learning)

孪生网络很像对比学习,但它是few-shot learning中的内容

简单来说,Siamese network就是“连体的神经网络”,神经网络的“连体”是通过共享权值来实现的

 

大家可能会有疑问:共享权值是什么意思?左右两个神经网络的权重一模一样?

答:是的,在代码实现的时候,甚至可以是同一个网络,不用实现另外一个,因为权值都一样。对于siamese network,两边可以是lstm或者cnn,都可以。

孪生神经网络的用途是什么?

简单来说,衡量两个输入的相似程度。孪生神经网络有两个输入(Input1 and Input2),将两个输入feed进入两个神经网络(Network1 and Network2),这两个神经网络分别将输入映射到新的空间形成输入在新的空间中的表示。通过Loss的计算,评价两个输入的相似度。

训练数据有正负标签

模型

训练的时候

可以看到,算的是同一个模型对两个输入的差值,同类输入就缩小插值,不同类输入就扩大差值

class Siamese(nn.Module):
    def __init__(self, input_shape, pretrained=False):
        super(Siamese, self).__init__()
        self.vgg = VGG16(pretrained, 3)
        del self.vgg.avgpool
        del self.vgg.classifier
        
        flat_shape = 512 * get_img_output_length(input_shape[1], input_shape[0])
        self.fully_connect1 = torch.nn.Linear(flat_shape, 512)
        self.fully_connect2 = torch.nn.Linear(512, 1)

    def forward(self, x):
        x1, x2 = x
        #------------------------------------------#
        #   我们将两个输入传入到主干特征提取网络
        #------------------------------------------#
        x1 = self.vgg.features(x1)
        x2 = self.vgg.features(x2)   
        #-------------------------#
        #   相减取绝对值,取l1距离
        #-------------------------#     
        x1 = torch.flatten(x1, 1)
        x2 = torch.flatten(x2, 1)
        x = torch.abs(x1 - x2)
        #-------------------------#
        #   进行两次全连接
        #-------------------------#
        x = self.fully_connect1(x)
        x = self.fully_connect2(x)
        return x

Few-Shot Learning (2/3): Siamese Network (孪生网络)_哔哩哔哩_bilibili

相关文章:

  • inplace=True (原地操作)
  • 服务器内存泄漏
  • linux快速目录切换(cd - ,dirs, pushd, popd)
  • Python同时输出到屏幕和文件(Logger)
  • 图像通道转换——tensor从[h, w, c]转为[c, h, w] (permutetranspose和view的区别)(reshape和view)
  • linux服务器精确kill掉占用某几张卡的显存的程序
  • onnx 跨框架的模型中间表达框架(onnx.js)
  • Linux下删除文件后变成.nfsxxxxxx
  • linux lsof命令(查看哪个进程在占用文件)
  • TensorRT(GIE)
  • tensor与PIL.Image转换
  • numpy array与PIL.Image的转换
  • PyTorch Lightning (pl)
  • torch.jit (Python JIT) (Just-In-Time 即时编译器) (动态图转为静态图)
  • TorchScript (将动态图转为静态图)(模型部署)(jit)(torch.jit.trace)
  • 《深入 React 技术栈》
  • docker容器内的网络抓包
  • Git 使用集
  • javascript 总结(常用工具类的封装)
  • Transformer-XL: Unleashing the Potential of Attention Models
  • Vue实战(四)登录/注册页的实现
  • Webpack4 学习笔记 - 01:webpack的安装和简单配置
  • 彻底搞懂浏览器Event-loop
  • 成为一名优秀的Developer的书单
  • 番外篇1:在Windows环境下安装JDK
  • 高程读书笔记 第六章 面向对象程序设计
  • 如何在GitHub上创建个人博客
  • 探索 JS 中的模块化
  • 学习笔记TF060:图像语音结合,看图说话
  • #前后端分离# 头条发布系统
  • (22)C#传智:复习,多态虚方法抽象类接口,静态类,String与StringBuilder,集合泛型List与Dictionary,文件类,结构与类的区别
  • (C语言)输入一个序列,判断是否为奇偶交叉数
  • (k8s中)docker netty OOM问题记录
  • (十八)用JAVA编写MP3解码器——迷你播放器
  • (一一四)第九章编程练习
  • (原創) 如何刪除Windows Live Writer留在本機的文章? (Web) (Windows Live Writer)
  • (转)Linux整合apache和tomcat构建Web服务器
  • (转)菜鸟学数据库(三)——存储过程
  • *** 2003
  • .chm格式文件如何阅读
  • .cn根服务器被攻击之后
  • .NET委托:一个关于C#的睡前故事
  • @Bean注解详解
  • @property @synthesize @dynamic 及相关属性作用探究
  • [2010-8-30]
  • [ActionScript][AS3]小小笔记
  • [BZOJ3757] 苹果树
  • [C++] Boost智能指针——boost::scoped_ptr(使用及原理分析)
  • [C++参考]拷贝构造函数的参数必须是引用类型
  • [CISCN2019 华北赛区 Day1 Web5]CyberPunk --不会编程的崽
  • [codevs 1515]跳 【解题报告】
  • [Google Guava] 1.1-使用和避免null
  • [HOW TO]怎么在iPhone程序中实现可多选可搜索按字母排序的联系人选择器
  • [ISITDTU 2019]EasyPHP
  • [JS] node.js 入门