diff --git a/utils/builder.py b/utils/builder.py index 8375031..0e31239 100644 --- a/utils/builder.py +++ b/utils/builder.py @@ -94,7 +94,6 @@ def keras_builder(onnx_model, new_input_nodes:list=None, new_output_nodes:list=N outputs_node_names.append(node_outputs[0]) if new_output_nodes is not None and len(outputs_node_names) == len(new_output_nodes): break - input_nodes = [] if new_input_nodes is None: input_nodes = [tf_tensor[x.name] for x in model_graph.input] @@ -104,6 +103,9 @@ def keras_builder(onnx_model, new_input_nodes:list=None, new_output_nodes:list=N if new_output_nodes is None: outputs_nodes = [tf_tensor[x.name] for x in model_graph.output] else: + for node in model_graph.output: + if node.name in new_output_nodes: + outputs_node_names.append(node.name) outputs_nodes = [tf_tensor[x] for x in outputs_node_names] keras_model = keras.Model(inputs=input_nodes, outputs=outputs_nodes)