From 66976c5e16731282ff4f2019596c304a60599138 Mon Sep 17 00:00:00 2001 From: MPolaris <540492239@qq.com> Date: Sun, 18 Sep 2022 10:29:44 +0800 Subject: [PATCH] fix cutoff bug --- utils/builder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/utils/builder.py b/utils/builder.py index 8375031..0e31239 100644 --- a/utils/builder.py +++ b/utils/builder.py @@ -94,7 +94,6 @@ def keras_builder(onnx_model, new_input_nodes:list=None, new_output_nodes:list=N outputs_node_names.append(node_outputs[0]) if new_output_nodes is not None and len(outputs_node_names) == len(new_output_nodes): break - input_nodes = [] if new_input_nodes is None: input_nodes = [tf_tensor[x.name] for x in model_graph.input] @@ -104,6 +103,9 @@ def keras_builder(onnx_model, new_input_nodes:list=None, new_output_nodes:list=N if new_output_nodes is None: outputs_nodes = [tf_tensor[x.name] for x in model_graph.output] else: + for node in model_graph.output: + if node.name in new_output_nodes: + outputs_node_names.append(node.name) outputs_nodes = [tf_tensor[x] for x in outputs_node_names] keras_model = keras.Model(inputs=input_nodes, outputs=outputs_nodes)