Siamese Network(孪生网络/连体网络) (few-shot learning)
孪生网络很像对比学习,但它是few-shot learning中的内容
简单来说,Siamese network就是“连体的神经网络”,神经网络的“连体”是通过共享权值来实现的
大家可能会有疑问:共享权值是什么意思?左右两个神经网络的权重一模一样?
答:是的,在代码实现的时候,甚至可以是同一个网络,不用实现另外一个,因为权值都一样。对于siamese network,两边可以是lstm或者cnn,都可以。
孪生神经网络的用途是什么?
简单来说,衡量两个输入的相似程度。孪生神经网络有两个输入(Input1 and Input2),将两个输入feed进入两个神经网络(Network1 and Network2),这两个神经网络分别将输入映射到新的空间,形成输入在新的空间中的表示。通过Loss的计算,评价两个输入的相似度。
训练数据有正负标签
模型
训练的时候
可以看到,算的是同一个模型对两个输入的差值,同类输入就缩小插值,不同类输入就扩大差值
class Siamese(nn.Module): def __init__(self, input_shape, pretrained=False): super(Siamese, self).__init__() self.vgg = VGG16(pretrained, 3) del self.vgg.avgpool del self.vgg.classifier flat_shape = 512 * get_img_output_length(input_shape[1], input_shape[0]) self.fully_connect1 = torch.nn.Linear(flat_shape, 512) self.fully_connect2 = torch.nn.Linear(512, 1) def forward(self, x): x1, x2 = x #------------------------------------------# # 我们将两个输入传入到主干特征提取网络 #------------------------------------------# x1 = self.vgg.features(x1) x2 = self.vgg.features(x2) #-------------------------# # 相减取绝对值,取l1距离 #-------------------------# x1 = torch.flatten(x1, 1) x2 = torch.flatten(x2, 1) x = torch.abs(x1 - x2) #-------------------------# # 进行两次全连接 #-------------------------# x = self.fully_connect1(x) x = self.fully_connect2(x) return x
Few-Shot Learning (2/3): Siamese Network (孪生网络)_哔哩哔哩_bilibili