Skip to content

Commit

Permalink
Add chat_template loading from tokenizer JSON config and a unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Sayan Shaw committed Mar 5, 2025
1 parent db9f715 commit 9955f0a
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 12 deletions.
5 changes: 5 additions & 0 deletions operators/tokenizer/tokenizer_jsconfig.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class TokenJsonConfig final {
bos_token_ = "<s>";
eos_token_ = "</s>";
unk_token_ = "<unk>";
chat_template_ = ""; // can add default chat template
return {};
}

Expand All @@ -91,6 +92,8 @@ class TokenJsonConfig final {
parse_token(json_config, "eos_token", eos_token_);
parse_token(json_config, "unk_token", unk_token_);

parse_token(json_config, "chat_template", chat_template_);

auto pad_iter = json_config.find("pad_token");
if (pad_iter != json_config.end() && pad_iter->is_string()) {
pad_token_ = json_config.value("pad_token", "");
Expand Down Expand Up @@ -245,6 +248,8 @@ class TokenJsonConfig final {
std::string unk_token_;
std::string pad_token_;

std::string chat_template_;

AddedTokenMap added_tokens_;

static AddedToken ParseAddedToken(const json& token) {
Expand Down
28 changes: 17 additions & 11 deletions shared/api/tokenizer_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ OrtxStatus TokenizerImpl::Load(const std::string& tok_path) {
return status;
}

chat_template = tok_config_->chat_template_;

return LoadTokenizer();
}

Expand Down Expand Up @@ -143,7 +145,7 @@ std::vector<std::unordered_map<std::string, std::string>> messages;
std::string chat_template;

// Phi4ChatTemplate method to process messages and store result in output
OrtxStatus TokenizerImpl::Phi4ChatTemplate(std::string* output, bool add_generation_prompt = true, const std::string& eos_token = "<|eos|>") {
OrtxStatus TokenizerImpl::Phi4ChatTemplate(std::string* output, bool add_generation_prompt = true, const std::string& eos_token = "<|endoftext|>") {
// Clear the output string before starting
output->clear();

Expand All @@ -155,26 +157,26 @@ OrtxStatus TokenizerImpl::Phi4ChatTemplate(std::string* output, bool add_generat
// Check if "tools" is present in the message and is not empty for "system" role
if (role == "system" && message.find("tools") != message.end() && !message.at("tools").empty()) {
std::string tools = message.at("tools");
*output += "<|" + role + "|>\n";
*output += content + "<|tool|>" + tools + "<|/tool|>" + "<|end|>\n";
*output += "<|" + role + "|>";
*output += content + "<|tool|>" + tools + "<|/tool|>" + "<|end|>";
} else {
// For other messages, no tools
*output += "<|" + role + "|>\n";
*output += content + "<|end|>\n";
*output += "<|" + role + "|>";
*output += content + "<|end|>";
}
}

// Add generation prompt or eos_token
if (add_generation_prompt) {
*output += "<|assistant|>\n";
*output += "<|assistant|>";
} else {
*output += eos_token;
}

return OrtxStatus(kOrtxOK, "Created chat template.");
}

OrtxStatus TokenizerImpl::Phi3_5ChatTemplate(std::string* output, bool add_generation_prompt = true, const std::string& eos_token = "<|eos|>") {
OrtxStatus TokenizerImpl::Phi3_5ChatTemplate(std::string* output, bool add_generation_prompt = true, const std::string& eos_token = "<|endoftext|>") {
// Clear the output string before starting
output->clear();

Expand Down Expand Up @@ -299,14 +301,18 @@ OrtxStatus TokenizerImpl::Llama3ChatTemplate(
}

// ApplyChatTemplate method to choose the template logic based on chat_template
OrtxStatus TokenizerImpl::ApplyChatTemplate(std::string* output, bool add_generation_prompt = true, const std::string& eos_token = "<|eos|>") {
OrtxStatus TokenizerImpl::ApplyChatTemplate(std::vector<std::unordered_map<std::string, std::string>> message_list, std::string* output, bool add_generation_prompt = true) {

// Initialize messages
messages = message_list;

// Check if the chat_template matches any of the supported template strings and if so apply the corresponding template.
if (chat_template == PHI4_CHAT_TEMPLATE) {
return Phi4ChatTemplate(output, add_generation_prompt, eos_token);
return Phi4ChatTemplate(output, add_generation_prompt);
} else if (chat_template == PHI3_5_CHAT_TEMPLATE) {
return Phi3_5ChatTemplate(output, add_generation_prompt, eos_token);
return Phi3_5ChatTemplate(output, add_generation_prompt);
} else if (chat_template == LLAMA3_CHAT_TEMPLATE) {
return Llama3ChatTemplate(output, add_generation_prompt, eos_token);
return Llama3ChatTemplate(output, add_generation_prompt);
} else {
// Handle other templates or custom logic here
return OrtxStatus(kOrtxErrorNotImplemented, "The provided chat template is currently not supported. Custom template handling needed.");
Expand Down
9 changes: 8 additions & 1 deletion shared/api/tokenizer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,20 @@ class TokenizerImpl : public OrtxObjectImpl {

OrtxStatus BatchDecode(const std::vector<span<extTokenId_t const>>& t_ids, std::vector<std::string>& t_text) const;

const std::string PHI4_CHAT_TEMPLATE;
const std::string PHI3_5_CHAT_TEMPLATE;
const std::string LLAMA3_CHAT_TEMPLATE;

std::string chat_template;
std::vector<std::unordered_map<std::string, std::string>> messages;

OrtxStatus Phi4ChatTemplate(std::string* output, bool add_generation_prompt, const std::string& eos_token);

OrtxStatus Phi3_5ChatTemplate(std::string* output, bool add_generation_prompt, const std::string& eos_token);

OrtxStatus Llama3ChatTemplate(std::string* output, bool add_generation_prompt, const std::string& eos_token, const std::vector<std::string>& custom_tools, bool tools_in_user_message, const std::string& strftime_now, const std::string& bos_token);

OrtxStatus ApplyChatTemplate(std::string* output, bool add_generation_prompt, const std::string& eos_token);
OrtxStatus ApplyChatTemplate(std::vector<std::unordered_map<std::string, std::string>> messages, std::string* output, bool add_generation_prompt);

OrtxStatus Id2Token(extTokenId_t id, std::string& token, TokenizerDecodingState** state) const;

Expand Down
27 changes: 27 additions & 0 deletions test/pp_api_test/test_tokenizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -611,3 +611,30 @@ TEST(OrtxTokenizerTest, AddedTokensTest) {
DumpTokenIds(token_ids);
EXPECT_EQ(token_ids[0], EXPECTED_IDS_0);
}

// Verifies that ApplyChatTemplate renders a Phi-4 style conversation, including
// the <|tool|>...<|/tool|> section emitted for a system message carrying tools.
TEST(OrtxTokenizerTest, ChatTemplate) {
  auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();

  // Since we do not have local test files for Phi4/Llama3/DeepSeek, we simply manually
  // set the chat_template, but otherwise this will be loaded from the tokenizer config automatically.
  tokenizer->chat_template = tokenizer->PHI4_CHAT_TEMPLATE;

  const std::vector<std::unordered_map<std::string, std::string>> messages = {
      {{"role", "system"}, {"content", "You are a helpful assistant."}, {"tools", "Calculator"}},
      {{"role", "user"}, {"content", "How do I add two numbers?"}},
      {{"role", "assistant"}, {"content", "You can add numbers by using the '+' operator."}}};

  // Expected rendering taken from HuggingFace Python output for
  // 'microsoft/Phi-4-multimodal-instruct'.
  const std::string expected_output =
      "<|system|>You are a helpful assistant.<|tool|>Calculator<|/tool|><|end|>"
      "<|user|>How do I add two numbers?<|end|>"
      "<|assistant|>You can add numbers by using the '+' operator.<|end|>"
      "<|assistant|>";

  std::string output;
  auto status = tokenizer->ApplyChatTemplate(messages, &output, /*add_generation_prompt=*/true);

  // Fail fast with the status text; the original printed the status and then fell
  // through to a confusing string-diff failure on an empty/stale output.
  ASSERT_TRUE(status.IsOk()) << status.ToString();
  ASSERT_EQ(output, expected_output);
}

0 comments on commit 9955f0a

Please sign in to comment.