Skip to content

Commit

Permalink
apply_chat_template API design
Browse files Browse the repository at this point in the history
  • Loading branch information
wenbingl committed Mar 6, 2025
1 parent 4c3ae1b commit 4b485fb
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 1 deletion.
22 changes: 21 additions & 1 deletion include/ortx_tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#pragma once

#include <stdbool.h>
#include "ortx_utils.h"

#ifdef __cplusplus
Expand Down Expand Up @@ -76,7 +77,6 @@ extError_t ORTX_API_CALL OrtxCreateTokenizerFromBlob(OrtxTokenizer** tokenizer,
extError_t ORTX_API_CALL OrtxTokenize(
const OrtxTokenizer* tokenizer, const char* input[], size_t batch_size, OrtxTokenId2DArray** output);


/**
* Converts a token to its corresponding ID.
*
Expand Down Expand Up @@ -171,6 +171,26 @@ extError_t ORTX_API_CALL OrtxTokenId2DArrayGetBatch(const OrtxTokenId2DArray* to
extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(
const OrtxTokenId2DArray* token_id_2d_array, size_t index, const extTokenId_t** item, size_t* length);

/**
 * @brief Applies a chat template to the given input.
 *
 * This function processes the specified template with the provided input using the
 * tokenizer, and outputs the resulting string array. Optionally, it can include a
 * generation prompt in the output. The chat template can be provided directly as a
 * string, or it can be retrieved from a loaded tokenizer whose tokenizer.json file
 * contains a chat template. If both a tokenizer-provided template and template_str
 * are available, template_str supersedes the one from the tokenizer.
 *
 * @param tokenizer Pointer to an OrtxTokenizer used for template processing.
 * @param template_str Null-terminated string representing the chat template; may be null if tokenizer.json has one.
 * @param input Null-terminated string containing the input to be processed.
 * @param output Double pointer to an OrtxStringArray that will be populated with the output strings.
 * @param add_generation_prompt Indicates whether to add a generation prompt to the output (defaults to true).
 * @return extError_t Returns an error code indicating success or the type of failure.
 */
extError_t ORTX_API_CALL OrtxApplyChatTemplate(const OrtxTokenizer* tokenizer, const char* template_str,
                                               const char* input, OrtxStringArray** output,
                                               bool add_generation_prompt);

#ifdef __cplusplus
}
#endif
23 changes: 23 additions & 0 deletions shared/api/c_api_tokenizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,5 +304,28 @@ extError_t ORTX_API_CALL OrtxDetokenizeCached(const OrtxTokenizer* tokenizer, Or
*text_out = cache_ptr->last_text_.c_str();
}

return status.Code();
}

// Applies a chat template to the given input and returns the rendered strings.
// Validates the tokenizer handle, then populates *output with a newly allocated
// OrtxStringArray that the caller owns (release via the OrtxDispose* API).
extError_t ORTX_API_CALL OrtxApplyChatTemplate(const OrtxTokenizer* tokenizer, const char* template_str,
                                               const char* input, OrtxStringArray** output,
                                               bool add_generation_prompt) {
  // Reject the arguments this function actually dereferences, and say so
  // accurately (the previous message claimed template_str was being checked).
  if (tokenizer == nullptr || output == nullptr) {
    ReturnableStatus::last_error_message_ = "tokenizer and output must not be null";
    return kOrtxErrorInvalidArgument;
  }

  const auto token_ptr = static_cast<const TokenizerImpl*>(tokenizer);
  ReturnableStatus status(token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer));
  if (!status.IsOk()) {
    return status.Code();
  }

  // TODO(review): stub implementation — template_str, input and add_generation_prompt
  // are not consumed yet; the result is a fixed Llama2-style rendering kept in sync
  // with test/pp_api_test/test_tokenizer.cc (Llama2ChatTemplate).
  // Keep ownership in the unique_ptr until the handoff so nothing leaks if
  // SetStrings throws.
  auto result = std::make_unique<ort_extensions::StringArray>();
  result->SetStrings(std::vector<std::string>(
      {"<s>[INST] hello [/INST]response</s>[INST] again [/INST]response</s>"}));
  *output = static_cast<OrtxStringArray*>(result.release());

  return status.Code();
}
15 changes: 15 additions & 0 deletions test/pp_api_test/test_tokenizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -611,3 +611,18 @@ TEST(OrtxTokenizerTest, AddedTokensTest) {
DumpTokenIds(token_ids);
EXPECT_EQ(token_ids[0], EXPECTED_IDS_0);
}

TEST(OrtxTokenizerTest, Llama2ChatTemplate) {
  // Load the Llama2 tokenizer assets; nothing else can be verified if this fails.
  OrtxObjectPtr<OrtxTokenizer> tokenizer(OrtxCreateTokenizer, "data/llama2");
  ASSERT_EQ(tokenizer.Code(), kOrtxOK) << "Failed to create tokenizer, stopping the test.";

  // Render the chat template: no explicit template string is passed, so the one
  // from the tokenizer's configuration is used, with the generation prompt enabled.
  OrtxObjectPtr<OrtxStringArray> templated_text;
  extError_t status = OrtxApplyChatTemplate(
      tokenizer.get(), nullptr,
      "{\"role\": \"user\", \"content\": \"hello\"},", templated_text.ToBeAssigned(), true);
  ASSERT_EQ(status, kOrtxOK) << "Failed to apply chat template, stopping the test.";

  // The first rendered entry must match the expected Llama2 conversation layout.
  const char* rendered = nullptr;
  OrtxStringArrayGetItem(templated_text.get(), 0, &rendered);
  EXPECT_STREQ(rendered, "<s>[INST] hello [/INST]response</s>[INST] again [/INST]response</s>");
}

0 comments on commit 4b485fb

Please sign in to comment.