每天五分钟玩转深度学习PyTorch:获取神经网络模型的子网络模型
本文重点
本文主要为第二步(模型搭建)和第六步(优化器)服务,因为子网络是网络模型的一部分,我们如何获取自网络,需要了解网络的模型结构,然后再优化器部分我们需要获取指定网络模型部分的模型参数,所以本节课程很重要。
named_children
import torch
from torch import nn
class BasicNet(nn.Module):def __init__(self):super(BasicNet, self).__init__()self.net = nn.Linear(4, 3)def forward(self, x):#自定义的网络或者层一定要写forward方法return self.net(x)class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.net = nn.Sequential(BasicNet(),nn.ReLU(),nn.Linear(3, 2)) def forward(self,x):return self.net(x)net = Net()
for name, t in net.named_children():print('parameters:', name, t)print('parameters1:', name, t[0])print('parameters1:', nam