Python onnx 模型打印显示所有节点及查看相互关系
最近使用onnx时,想把所有的节点的信息和权重参数显示出来,找了一下没找到类似的源码,官方介绍的pythonAPI都是些什么加载,保存,转换之类的,没有详细介绍怎么使用onnx分析模型的,只好自己写一个。
其实很简单,我只列些最基本的,具体分析还得看个人的需要,
import onnx
model_in_file = 'yolov5s-sim.onnx'
if __name__ == "__main__":
model = onnx.load(model_in_file)
nodes = model.graph.node
nodnum = len(nodes) # 205
for nid in range(nodnum):
if (nodes[nid].output[0] == 'stride_32'):
print('Found stride_32: index = ', nid)
else:
print(nodes[nid].output)
inits = model.graph.initializer
ininum = len(inits) #124
for iid in range(ininum):
el = inits[iid]
print('name:', el.name, ' dtype:', el.data_type, ' dim:', el.dims)
# el.raw_data for weights and biases
print(model.graph.output) # display all the output nodes
print('Done')
比如,我这显示出来的模型中的节点是这样的,
[input: "data"
input: "model.0.conv.weight"
input: "model.0.conv.bias"
output: "122"
name: "Conv_0"
op_type: "Conv"
attribute {
name: "dilations"
ints: 1
ints: 1
type: INTS
}
attribute {
name: "group"
i: 1
type: INT
}
attribute {
name: "kernel_shape"
ints: 6
ints: 6
type: INTS
}
attribute {
name: "pads"
ints: 2
ints: 2
ints: 2
ints: 2
type: INTS
}
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
,
。。。。。。
, input: "325"
input: "model.24.m.2.weight"
input: "model.24.m.2.bias"
output: "376"
name: "Conv_234"
op_type: "Conv"
attribute {
name: "dilations"
ints: 1
ints: 1
type: INTS
}
attribute {
name: "group"
i: 1
type: INT
}
attribute {
name: "kernel_shape"
ints: 1
ints: 1
type: INTS
}
attribute {
name: "pads"
ints: 0
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "strides"
ints: 1
ints: 1
type: INTS
}
, input: "376"
input: "398"
output: "399"
name: "Reshape_251"
op_type: "Reshape"
, input: "399"
output: "stride_32"
name: "Transpose_252"
op_type: "Transpose"
attribute {
name: "perm"
ints: 0
ints: 1
ints: 3
ints: 4
ints: 2
type: INTS
}
]
可以看出,在onnx模型中,结点之间用逗号隔开,输出和输出都分别列出,比如我这里最后一个节点的信息是
>>> nodes[204]
input: "399"
output: "stride_32"
name: "Transpose_252"
op_type: "Transpose"
attribute {
name: "perm"
ints: 0
ints: 1
ints: 3
ints: 4
ints: 2
type: INTS
}
>>> nodes[203]
input: "376"
input: "398"
output: "399"
name: "Reshape_251"
op_type: "Reshape"
其中node[204]input表示输入节点是399,也就是node[203]的输出;node[204]输出名称是stride_32,部署时就用这个名称来提取最终结果。像netron这样的工具,就是根据这些node之间的关系来绘制网络图的。
权重分析这里就不展开了。