apply_chat_template C API design in tokenizer #906

Status: Draft, wants to merge 4 commits into base: main
Changes from 1 commit
22 changes: 21 additions & 1 deletion include/ortx_tokenizer.h
@@ -5,6 +5,7 @@

#pragma once

#include <stdbool.h>
#include "ortx_utils.h"

#ifdef __cplusplus
@@ -76,7 +77,6 @@ extError_t ORTX_API_CALL OrtxCreateTokenizerFromBlob(OrtxTokenizer** tokenizer,
extError_t ORTX_API_CALL OrtxTokenize(
const OrtxTokenizer* tokenizer, const char* input[], size_t batch_size, OrtxTokenId2DArray** output);


/**
* Converts a token to its corresponding ID.
*
@@ -171,6 +171,26 @@ extError_t ORTX_API_CALL OrtxTokenId2DArrayGetBatch(const OrtxTokenId2DArray* to
extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(
const OrtxTokenId2DArray* token_id_2d_array, size_t index, const extTokenId_t** item, size_t* length);

/**
* @brief Applies a chat template to the given input.
*
* This function renders the specified chat template with the provided input using the
* tokenizer and populates the resulting string array. Optionally, it can append a
* generation prompt to the output. The chat template can be provided as a string, or it
* can be retrieved from a loaded tokenizer whose JSON configuration contains a chat template.
* If both tokenizer and template_str are provided, template_str supersedes the template from the tokenizer.
*
* @param tokenizer Pointer to an OrtxTokenizer used for template processing.
* @param template_str Null-terminated string containing the chat template; may be NULL if the tokenizer's JSON configuration provides one.
* @param input Null-terminated string containing the input (chat messages) to be processed.
* @param output Double pointer to an OrtxStringArray that will be populated with the output strings.
* @param add_generation_prompt Whether to append a generation prompt to the rendered output (callers typically pass true).
* @return extError_t Returns an error code indicating success or the type of failure.
*/
extError_t ORTX_API_CALL OrtxApplyChatTemplate(const OrtxTokenizer* tokenizer, const char* template_str,
                                               const char* input, OrtxStringArray** output,
                                               bool add_generation_prompt);

Review comment (Member Author): A tokenizer object is needed because the bos_token, the eos_token, and the chat template stored in the tokenizer's JSON configuration can only be retrieved from the tokenizer object.

#ifdef __cplusplus
}
#endif
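
A minimal caller-side sketch of the proposed API follows, for review purposes only; it is not part of this diff. It assumes the existing OrtxCreateTokenizer, OrtxStringArrayGetItem, and OrtxDispose helpers from this C API, a local "data/llama2" tokenizer directory as in the unit test below, and the single-message JSON input used there. Passing NULL for template_str exercises the tokenizer-provided template path described in the doc comment above.

#include <stdio.h>
#include "ortx_tokenizer.h"

int main(void) {
  OrtxTokenizer* tokenizer = NULL;
  if (OrtxCreateTokenizer(&tokenizer, "data/llama2") != kOrtxOK) {
    return 1;
  }

  // template_str is NULL, so the chat template embedded in the tokenizer's JSON
  // configuration is used; this is why the tokenizer object is required here.
  OrtxStringArray* rendered = NULL;
  extError_t err = OrtxApplyChatTemplate(
      tokenizer, NULL, "{\"role\": \"user\", \"content\": \"hello\"},", &rendered, true);
  if (err == kOrtxOK) {
    const char* text = NULL;
    OrtxStringArrayGetItem(rendered, 0, &text);
    printf("%s\n", text);
  }

  // Assumed cleanup via the generic OrtxDispose helper declared in ortx_utils.h.
  OrtxDispose((OrtxObject**)&rendered);
  OrtxDispose((OrtxObject**)&tokenizer);
  return 0;
}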
23 changes: 23 additions & 0 deletions shared/api/c_api_tokenizer.cc
@@ -304,5 +304,28 @@ extError_t ORTX_API_CALL OrtxDetokenizeCached(const OrtxTokenizer* tokenizer, Or
    *text_out = cache_ptr->last_text_.c_str();
  }

  return status.Code();
}

extError_t ORTX_API_CALL OrtxApplyChatTemplate(const OrtxTokenizer* tokenizer, const char* template_str,
                                               const char* input, OrtxStringArray** output,
                                               bool add_generation_prompt) {
  if (tokenizer == nullptr || output == nullptr) {
    ReturnableStatus::last_error_message_ = "the tokenizer and output arguments must not be null";
    return kOrtxErrorInvalidArgument;
  }

  const auto token_ptr = static_cast<const TokenizerImpl*>(tokenizer);
  ReturnableStatus status(token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer));
  if (!status.IsOk()) {
    return status.Code();
  }

  // Draft placeholder: template rendering is not implemented yet; a fixed string matching
  // the Llama2 unit test is returned so the C API shape can be reviewed first.
  auto result = std::make_unique<ort_extensions::StringArray>().release();
  result->SetStrings(std::vector<std::string>({"<s>[INST] hello [/INST]response</s>[INST] again [/INST]response</s>"}));
  *output = static_cast<OrtxStringArray*>(result);

  return status.Code();
}
15 changes: 15 additions & 0 deletions test/pp_api_test/test_tokenizer.cc
@@ -611,3 +611,18 @@ TEST(OrtxTokenizerTest, AddedTokensTest) {
  DumpTokenIds(token_ids);
  EXPECT_EQ(token_ids[0], EXPECTED_IDS_0);
}

TEST(OrtxTokenizerTest, Llama2ChatTemplate) {
  OrtxObjectPtr<OrtxTokenizer> tokenizer(OrtxCreateTokenizer, "data/llama2");
  ASSERT_EQ(tokenizer.Code(), kOrtxOK) << "Failed to create tokenizer, stopping the test.";

  OrtxObjectPtr<OrtxStringArray> templated_text;
  auto err = OrtxApplyChatTemplate(
      tokenizer.get(), nullptr,
      "{\"role\": \"user\", \"content\": \"hello\"},", templated_text.ToBeAssigned(), true);

  ASSERT_EQ(err, kOrtxOK) << "Failed to apply chat template, stopping the test.";
  const char* text = nullptr;
  OrtxStringArrayGetItem(templated_text.get(), 0, &text);
  EXPECT_STREQ(text, "<s>[INST] hello [/INST]response</s>[INST] again [/INST]response</s>");
}