Pytorch加载模型不完全匹配 只加载部分参数权重 load
加载模型不完全匹配
model.load_state_dict(torch.load(weight_path), strict=False)
当权重中的key和网络中匹配就加载,不匹配就跳过
如果strict是True,那必须完全匹配,不然就报错
默认是True
但是注意,如果是像英文模型迁移到中文,改了class num的话,例如由26改为3600,这时模型不匹配用它是解决不了的,因为此时模型的key名字是对应的上的,只是权重的size不同 看只加载部分参数权重
如果发生上述情况的话,那就需要把加载到的模型的中,不匹配的那几项删掉,然后加载其他项
x = torch.load(self.weight) del x['char_recognizer.classifier.bias'] del x['char_recognizer.classifier.weight'] self.load_state_dict(x, strict=False)
或者
# Use when some parts of pretrained model are not needed # pretrained_dict = checkpoint['state_dict'] # model_dict = model.state_dict() # # 1. filter out unnecessary keys # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # # 2. overwrite entries in the existing state dict # model_dict.update(pretrained_dict) # # 3. load the new state dict # model.load_state_dict(model_dict)