Skip to content

Commit

Permalink
remove redundant input nodes of onnx
Browse files Browse the repository at this point in the history
  • Loading branch information
MPolaris committed Oct 14, 2022
1 parent f106b0c commit d744f67
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion utils/onnx_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,22 @@
from onnxsim import simplify

LOG = logging.getLogger("onnx_loader running:")
LOG.setLevel(logging.INFO)

def clean_model_input(model_proto):
inputs = model_proto.graph.input
name_to_input = {}
for input in inputs:
name_to_input[input.name] = input

names = []
for initializer in model_proto.graph.initializer:
if initializer.name in name_to_input:
inputs.remove(name_to_input[initializer.name])
names.append(initializer.name)
LOG.warning(f"[{len(names)}] redundant input nodes are removed.\n \
nodes name : {','.join(names)}")


def load_onnx_modelproto(onnx_model_path:str, need_simplify:bool=True):
if not os.path.exists(onnx_model_path):
Expand All @@ -19,10 +35,11 @@ def load_onnx_modelproto(onnx_model_path:str, need_simplify:bool=True):
if need_simplify:
success = False
try:
model_proto, success = simplify(model_proto, check_n=2, dynamic_input_shape=dynamic_input)
model_proto, success = simplify(model_proto, check_n=1, dynamic_input_shape=dynamic_input)
except:
success = False
if not success:
LOG.warning(f"onnxsim is failed, maybe make convert fails.")
model_proto = onnx.load(onnx_model_path)
clean_model_input(model_proto)
return model_proto

0 comments on commit d744f67

Please sign in to comment.