# TensorFlow 常用代码片段
-
## 加载 pb（冻结的 GraphDef）文件
def load_graph(frozen_graph_filename, name=None):
    """Load a serialized GraphDef (.pb file) into a new ``tf.Graph``.

    Args:
        frozen_graph_filename: Path of the serialized ``GraphDef`` file. It is
            opened through ``tf.io.gfile.GFile``.
        name: Optional name prefix forwarded to ``tf.import_graph_def``.
            ``None`` (the default) keeps TensorFlow's default ``"import"``
            prefix. Pass ``""`` to keep the original node names.

    Returns:
        A new ``tf.Graph`` that contains the imported nodes.
    """
    # A GraphDef is a binary protobuf, so read raw bytes ("rb").
    with tf.io.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import into a fresh graph instead of the global default graph.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name=name)
    return graph
## 将 graph_def 写入文件
# Serialize `graph_def` (a GraphDef defined earlier; not shown in this
# snippet) to simplified.pb as a binary protobuf.
# tf.io.gfile.GFile is the correctly-capitalized, TF2-compatible name.
with tf.io.gfile.GFile('simplified.pb', 'wb') as fid:
    fid.write(graph_def.SerializeToString())
## 获取图的输入、输出节点（返回 Operation 对象，节点名可通过 `op.name` 获取）
def analyze_inputs_outputs(graph):
    """Find the likely input and output operations of a graph.

    An op counts as an input when it has no input tensors and is not a
    ``Const``. An op counts as an output when no op in the graph consumes
    any of its tensors. This includes ``Const`` ops that nothing uses.

    Args:
        graph: A ``tf.Graph``, or any object exposing ``get_operations()``
            whose ops have ``.type`` and ``.inputs``. Each input must carry
            the producing op in ``.op``.

    Returns:
        A tuple ``(inputs, outputs)`` of lists of operations. Both lists
        follow the order of ``graph.get_operations()``. Use ``op.name`` to
        get node names.
    """
    ops = graph.get_operations()
    # Assume every op is an output, then drop each op once we see another
    # op consume one of its tensors.
    outputs_set = set(ops)
    inputs = []
    for op in ops:
        if len(op.inputs) == 0 and op.type != 'Const':
            inputs.append(op)
        else:
            for input_tensor in op.inputs:
                outputs_set.discard(input_tensor.op)
    # Keep graph order. Iterating the set would give a hash-dependent order
    # that can change from run to run.
    outputs = [op for op in ops if op in outputs_set]
    return (inputs, outputs)
## 简化模型：把变量固化为常量，并删除仅用于训练的节点
# Freeze the graph: replace variables with constants that hold their
# current values in `sess`.
# NOTE(review): `sess`, `graph_def` and `out_name` must be defined earlier;
# they are not shown in this snippet.
graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
    sess, graph_def, [out_name])
# Remove training-only nodes such as Identity / CheckNumerics. Protect the
# output node so it is kept even if it is itself an Identity op.
graph_def = tf.compat.v1.graph_util.remove_training_nodes(
    graph_def, protected_nodes=[out_name])