From 4b485fb2695971fcabaf6aa2cb9c54fe125eac8c Mon Sep 17 00:00:00 2001 From: Wenbing Li Date: Thu, 6 Mar 2025 01:43:54 +0000 Subject: [PATCH 1/3] apply_chat_template API design --- include/ortx_tokenizer.h | 22 +++++++++++++++++++++- shared/api/c_api_tokenizer.cc | 23 +++++++++++++++++++++++ test/pp_api_test/test_tokenizer.cc | 15 +++++++++++++++ 3 files changed, 59 insertions(+), 1 deletion(-) diff --git a/include/ortx_tokenizer.h b/include/ortx_tokenizer.h index c6b98a8df..248f1d2f7 100644 --- a/include/ortx_tokenizer.h +++ b/include/ortx_tokenizer.h @@ -5,6 +5,7 @@ #pragma once +#include #include "ortx_utils.h" #ifdef __cplusplus @@ -76,7 +77,6 @@ extError_t ORTX_API_CALL OrtxCreateTokenizerFromBlob(OrtxTokenizer** tokenizer, extError_t ORTX_API_CALL OrtxTokenize( const OrtxTokenizer* tokenizer, const char* input[], size_t batch_size, OrtxTokenId2DArray** output); - /** * Converts a token to its corresponding ID. * @@ -171,6 +171,26 @@ extError_t ORTX_API_CALL OrtxTokenId2DArrayGetBatch(const OrtxTokenId2DArray* to extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem( const OrtxTokenId2DArray* token_id_2d_array, size_t index, const extTokenId_t** item, size_t* length); +/** + * @brief Applies a chat template to the given input. + * + * This function processes the specified template with the provided input using the + * tokenizer, and outputs the resulting string array. Optionally, it can include a + * generation prompt in the output. The chat template can be provided as a string or + * be retrieved from a loaded tokenizer json file which contains the chat template its json file. + * if both tokenizer and template_str are provided, the template_str will supersede the tokenizer. + * + * @param tokenizer Pointer to an OrtxTokenizer used for template processing + * @param template_str Null-terminated string representing the chat template, can be null if tokenizer.json has one. + * @param input Null-terminated string containing the input to be processed. + * @param output Double pointer to an OrtxStringArray that will be populated with the output strings. + * @param add_generation_prompt Indicates whether to add a generation prompt to the output (defaults to true). + * @return extError_t Returns an error code indicating success or the type of failure. + */ +extError_t ORTX_API_CALL OrtxApplyChatTemplate(const OrtxTokenizer* tokenizer, const char* template_str, + const char* input, OrtxStringArray** output, + bool add_generation_prompt); + #ifdef __cplusplus } #endif diff --git a/shared/api/c_api_tokenizer.cc b/shared/api/c_api_tokenizer.cc index 6f10ee735..e27a49cac 100644 --- a/shared/api/c_api_tokenizer.cc +++ b/shared/api/c_api_tokenizer.cc @@ -304,5 +304,28 @@ extError_t ORTX_API_CALL OrtxDetokenizeCached(const OrtxTokenizer* tokenizer, Or *text_out = cache_ptr->last_text_.c_str(); } + return status.Code(); +} + +extError_t ORTX_API_CALL OrtxApplyChatTemplate(const OrtxTokenizer* tokenizer, const char* template_str, + const char* input, OrtxStringArray** output, + bool add_generation_prompt) { + if (tokenizer == nullptr || output == nullptr) { + ReturnableStatus::last_error_message_ = "both tokenizer and template_str are null, no template to apply"; + return kOrtxErrorInvalidArgument; + } + + const auto token_ptr = static_cast(tokenizer); + ReturnableStatus status(token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer)); + if (!status.IsOk()) { + return status.Code(); + } + + + auto result = std::make_unique().release(); + result->SetStrings(std::vector({"[INST] hello [/INST]response[INST] again [/INST]response"})); + *output = static_cast(result); + + return status.Code(); } diff --git a/test/pp_api_test/test_tokenizer.cc b/test/pp_api_test/test_tokenizer.cc index 705ef2bfc..a571cb4b1 100644 --- a/test/pp_api_test/test_tokenizer.cc +++ b/test/pp_api_test/test_tokenizer.cc @@ -611,3 +611,18 @@ TEST(OrtxTokenizerTest, AddedTokensTest) { DumpTokenIds(token_ids); EXPECT_EQ(token_ids[0], EXPECTED_IDS_0); } + +TEST(OrtxTokenizerTest, Llama2ChatTemplate) { + OrtxObjectPtr tokenizer(OrtxCreateTokenizer, "data/llama2"); + ASSERT_EQ(tokenizer.Code(), kOrtxOK) << "Failed to create tokenizer, stopping the test."; + + OrtxObjectPtr templated_text; + auto err = OrtxApplyChatTemplate( + tokenizer.get(), nullptr, + "{\"role\": \"user\", \"content\": \"hello\"},", templated_text.ToBeAssigned(), true); + + ASSERT_EQ(err, kOrtxOK) << "Failed to apply chat template, stopping the test."; + const char* text = nullptr; + OrtxStringArrayGetItem(templated_text.get(), 0, &text); + EXPECT_STREQ(text, "[INST] hello [/INST]response[INST] again [/INST]response"); +} From 34260ac55879ba68667fd61f441c6e27d9b76427 Mon Sep 17 00:00:00 2001 From: Wenbing Li <10278425+wenbingl@users.noreply.github.com> Date: Wed, 5 Mar 2025 18:01:46 -0800 Subject: [PATCH 2/3] Fix null check condition in OrtxApplyChatTemplate --- shared/api/c_api_tokenizer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shared/api/c_api_tokenizer.cc b/shared/api/c_api_tokenizer.cc index e27a49cac..2a5ebd934 100644 --- a/shared/api/c_api_tokenizer.cc +++ b/shared/api/c_api_tokenizer.cc @@ -310,7 +310,7 @@ extError_t ORTX_API_CALL OrtxDetokenizeCached(const OrtxTokenizer* tokenizer, Or extError_t ORTX_API_CALL OrtxApplyChatTemplate(const OrtxTokenizer* tokenizer, const char* template_str, const char* input, OrtxStringArray** output, bool add_generation_prompt) { - if (tokenizer == nullptr || output == nullptr) { + if (tokenizer == nullptr && template_str == nullptr) { ReturnableStatus::last_error_message_ = "both tokenizer and template_str are null, no template to apply"; return kOrtxErrorInvalidArgument; } From 91bd4c8b699d49d3963e5426841149fd42eecc83 Mon Sep 17 00:00:00 2001 From: Wenbing Li <10278425+wenbingl@users.noreply.github.com> Date: Wed, 5 Mar 2025 18:07:39 -0800 Subject: [PATCH 3/3] Fix parameter type in OrtxApplyChatTemplate function. --- include/ortx_tokenizer.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ortx_tokenizer.h b/include/ortx_tokenizer.h index 248f1d2f7..e2928b7b1 100644 --- a/include/ortx_tokenizer.h +++ b/include/ortx_tokenizer.h @@ -183,7 +183,7 @@ extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem( * @param tokenizer Pointer to an OrtxTokenizer used for template processing * @param template_str Null-terminated string representing the chat template, can be null if tokenizer.json has one. * @param input Null-terminated string containing the input to be processed. - * @param output Double pointer to an OrtxStringArray that will be populated with the output strings. + * @param output an OrtxStringArray that will be populated with the output strings. * @param add_generation_prompt Indicates whether to add a generation prompt to the output (defaults to true). * @return extError_t Returns an error code indicating success or the type of failure. */