From e6431530944d032d6d45a632011aa639d65988bc Mon Sep 17 00:00:00 2001 From: Sebastian Gallese <140911+sgallese@users.noreply.github.com> Date: Thu, 1 Aug 2024 08:17:03 -0700 Subject: [PATCH 1/2] Fix arguments for transformers_and_bert --- .../tools/add_pre_post_processing_to_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime_extensions/tools/add_pre_post_processing_to_model.py b/onnxruntime_extensions/tools/add_pre_post_processing_to_model.py index 66cd3b01e..9dba21b79 100644 --- a/onnxruntime_extensions/tools/add_pre_post_processing_to_model.py +++ b/onnxruntime_extensions/tools/add_pre_post_processing_to_model.py @@ -522,7 +522,7 @@ def main(): else: if args.vocab_file is None or args.nlp_task_type is None or args.tokenizer_type is None: parser.error("Please provide vocab file/nlp_task_type/tokenizer_type.") - transformers_and_bert(model_path, new_model_path, args.tokenizer_type, args.vocab_file, args.nlp_task_type) + transformers_and_bert(model_path, new_model_path, args.vocab_file, args.tokenizer_type, args.nlp_task_type) if __name__ == "__main__": From afa9a37d3e0af1729f4e7e06c2b71ad294ec9881 Mon Sep 17 00:00:00 2001 From: Sebastian Gallese <140911+sgallese@users.noreply.github.com> Date: Thu, 1 Aug 2024 08:54:52 -0700 Subject: [PATCH 2/2] Add test --- test/test_tools_add_pre_post_processing_to_model.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/test_tools_add_pre_post_processing_to_model.py b/test/test_tools_add_pre_post_processing_to_model.py index bec043b17..809c829c2 100644 --- a/test/test_tools_add_pre_post_processing_to_model.py +++ b/test/test_tools_add_pre_post_processing_to_model.py @@ -393,6 +393,13 @@ def test_qatask_with_tokenizer(self): self.assertEqual(result[0][0], ref_output[0][0]) + def test_transformers_and_bert(self): + input_model = os.path.join(test_data_dir, "../bert_qa_decoder_base.onnx") + output_model = (self.temp4onnx / "bert_qa.updated.onnx").resolve() + vocab_file = os.path.join(test_data_dir, "../bert.vocab") + + add_ppp.transformers_and_bert(Path(input_model), Path(output_model), Path(vocab_file), "BertTokenizer", "QuestionAnswering") + # Corner Case def test_debug_step(self): import onnx