From 7e6e4b67f9c37bb7e573840c0d4742b0165051c7 Mon Sep 17 00:00:00 2001
From: Sayan Shaw
Date: Fri, 7 Mar 2025 16:06:30 -0800
Subject: [PATCH] add base llama 3 support

---
 shared/api/tokenizer_impl.cc | 71 ++++++++++++++++++++++++++++--------
 shared/api/tokenizer_impl.h  |  2 +
 2 files changed, 57 insertions(+), 16 deletions(-)

diff --git a/shared/api/tokenizer_impl.cc b/shared/api/tokenizer_impl.cc
index f8d6d506..fc815757 100644
--- a/shared/api/tokenizer_impl.cc
+++ b/shared/api/tokenizer_impl.cc
@@ -201,7 +201,7 @@ OrtxStatus TokenizerImpl::Phi3_5ChatTemplate(std::string* output, bool add_gener
     }
   }
 
-  // Add generation prompt or eos_token
+  // Add generation prompt or EOS token
   if (add_generation_prompt) {
     *output += "<|assistant|>\n";
   } else {
@@ -211,6 +211,43 @@ OrtxStatus TokenizerImpl::Phi3_5ChatTemplate(std::string* output, bool add_gener
   return OrtxStatus(kOrtxOK, "Created Phi-3.5 chat template.");
 }
 
+OrtxStatus TokenizerImpl::Llama3ChatTemplate(
+    std::string* output,
+    bool add_generation_prompt = true,
+    const std::string& eos_token = "<|eot_id|>",
+    const std::string& bos_token = "<|begin_of_text|>") {
+
+  // Clear the output string before starting
+  output->clear();
+
+  // Iterate over the messages to construct the template
+  for (size_t i = 0; i < messages.size(); ++i) {
+    const auto& message = messages[i];
+    std::string role = message.at("role");
+    std::string content = message.at("content");
+
+    // Build the message with header and content
+    std::string formatted_content = "<|start_header_id|>" + role + "<|end_header_id|>\n\n" + content + eos_token;
+
+    // Add BOS token only to the first message
+    if (i == 0) {
+      formatted_content = bos_token + formatted_content;
+    }
+
+    // Append the formatted message to the output
+    *output += formatted_content;
+  }
+
+  // Add generation prompt or EOS token at the end
+  if (add_generation_prompt) {
+    *output += "<|start_header_id|>assistant<|end_header_id|>\n\n";
+  } else {
+    *output += eos_token;
+  }
+
+  return OrtxStatus(kOrtxOK, "Created Llama 3 chat template.");
+}
+
 OrtxStatus TokenizerImpl::Llama3_2ChatTemplate(
     std::string* output,
     bool add_generation_prompt = true,
@@ -218,13 +255,13 @@ OrtxStatus TokenizerImpl::Llama3_2ChatTemplate(
     const std::string& eos_token = "<|eot_id|>",
     const std::vector<std::string>& custom_tools = {},
     bool tools_in_user_message = true,
     const std::string& strftime_now = "",
-    const std::string& bos_token = "<|begin_of_text|>") {  // Add bos_token as a parameter
+    const std::string& bos_token = "<|begin_of_text|>") {
   // Clear the output string before starting
   output->clear();
 
   // Prepend BOS token at the start of the output
-  *output += bos_token;  // BOS token goes first
+  *output += bos_token;
 
   // Initialize date_string with default value
   std::string date_string = "26 Jul 2024";  // Default date
@@ -298,10 +335,10 @@ OrtxStatus TokenizerImpl::Llama3_2ChatTemplate(
   if (add_generation_prompt) {
     *output += "<|start_header_id|>assistant<|end_header_id|>\n\n";
   } else {
-    *output += eos_token;  // Add the EOS token instead
+    *output += eos_token;
   }
 
-  return OrtxStatus(kOrtxOK, "Created Llama3 chat template.");
+  return OrtxStatus(kOrtxOK, "Created Llama 3.2 chat template.");
 }
 
 OrtxStatus TokenizerImpl::Llama3_3ChatTemplate(
@@ -309,16 +346,16 @@ OrtxStatus TokenizerImpl::Llama3_3ChatTemplate(
     std::string* output,
     bool add_generation_prompt = true,
     const std::string& eos_token = "<|eot_id|>",
     const std::vector<std::string>& custom_tools = {},
-    const std::vector<std::string>& builtin_tools = {},  // Added builtin_tools as parameter
+    const std::vector<std::string>& builtin_tools = {},
     bool tools_in_user_message = true,
-    const std::string& date_string = "26 Jul 2024",  // Default date string parameter
-    const std::string& bos_token = "<|begin_of_text|>") {  // BOS token as a parameter
+    const std::string& date_string = "26 Jul 2024",
+    const std::string& bos_token = "<|begin_of_text|>") {
   // Clear the output string before starting
   output->clear();
 
   // Prepend BOS token at the start of the output
-  *output += bos_token;  // BOS token goes first
+  *output += bos_token;
 
   // Loop through messages and process each one
   for (const auto& message : messages) {
@@ -408,7 +445,7 @@ OrtxStatus TokenizerImpl::Llama3_3ChatTemplate(
       if (!builtin_tools.empty()) {
         *output += "<|eom_id|>";
       } else {
-        *output += eos_token;  // Replaced <|eot_id|> with eos_token
+        *output += eos_token;
       }
     }
 
@@ -418,7 +455,7 @@ OrtxStatus TokenizerImpl::Llama3_3ChatTemplate(
         *output += "<|start_header_id|>" + role + "<|end_header_id|>\n\n";
       }
       *output += content;
-      *output += eos_token;  // Replaced <|eot_id|> with eos_token
+      *output += eos_token;
     }
   }
 
@@ -426,17 +463,17 @@ OrtxStatus TokenizerImpl::Llama3_3ChatTemplate(
   if (add_generation_prompt) {
     *output += "<|start_header_id|>assistant<|end_header_id|>\n\n";
   } else {
-    *output += eos_token;  // Replaced <|eot_id|> with eos_token
+    *output += eos_token;
   }
 
-  return OrtxStatus(kOrtxOK, "Created chat template.");
+  return OrtxStatus(kOrtxOK, "Created Llama 3.1/3.3 chat template.");  // Llama 3.1 and 3.3 have the same chat template
 }
 
 OrtxStatus TokenizerImpl::DeepSeekChatTemplate(
     std::string* output,
     bool add_generation_prompt = false,
     const std::string& eos_token = "<|end▁of▁sentence|>",
-    const std::string& bos_token = "<|begin▁of▁sentence|>") {  // Add bos_token as a parameter
+    const std::string& bos_token = "<|begin▁of▁sentence|>") {
   // Clear the output string before starting
   output->clear();
 
@@ -466,7 +503,7 @@ OrtxStatus TokenizerImpl::DeepSeekChatTemplate(
   // Process each message in the conversation
   for (const auto& message : messages) {
     std::string role = message.at("role");
-    std::string content = message.at("content");  // Now content is correctly defined here
+    std::string content = message.at("content");
 
     // Handle user message
     if (role == "user") {
@@ -541,7 +578,7 @@ OrtxStatus TokenizerImpl::DeepSeekChatTemplate(
   if (add_generation_prompt && !is_tool) {
     *output += "<|Assistant|>\n";
   } else {
-    *output += eos_token;  // Add the EOS token instead
+    *output += eos_token;
  }
 
   return OrtxStatus(kOrtxOK, "Created DeepSeek chat template.");
@@ -558,6 +595,8 @@ OrtxStatus TokenizerImpl::ApplyChatTemplate(std::vector

diff --git a/shared/api/tokenizer_impl.h b/shared/api/tokenizer_impl.h
--- a/shared/api/tokenizer_impl.h
+++ b/shared/api/tokenizer_impl.h
@@ ... @@
+  OrtxStatus Llama3ChatTemplate(std::string* output, bool add_generation_prompt, const std::string& eos_token, const std::string& bos_token);
+
   OrtxStatus Llama3_2ChatTemplate(std::string* output, bool add_generation_prompt, const std::string& eos_token, const std::vector<std::string>& custom_tools, bool tools_in_user_message, const std::string& strftime_now, const std::string& bos_token);
   OrtxStatus Llama3_3ChatTemplate(std::string* output, bool add_generation_prompt, const std::string& eos_token, const std::vector<std::string>& custom_tools, const std::vector<std::string>& builtin_tools, bool tools_in_user_message, const std::string& date_string, const std::string& bos_token);
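Note for reviewers: the snippet below is a standalone sketch, not part of the patch, that mirrors the formatting logic of the new Llama3ChatTemplate so the produced prompt can be inspected without building the project. Message, Llama3Prompt, and the main() driver are illustrative stand-ins for the tokenizer's internal messages state and member function; the special-token strings are the patch's defaults.

// Standalone illustration of the Llama 3 prompt this patch produces.
// Message and Llama3Prompt are hypothetical local stand-ins, not
// identifiers from tokenizer_impl.cc.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

using Message = std::unordered_map<std::string, std::string>;

std::string Llama3Prompt(const std::vector<Message>& messages,
                         bool add_generation_prompt = true,
                         const std::string& eos_token = "<|eot_id|>",
                         const std::string& bos_token = "<|begin_of_text|>") {
  std::string output;
  for (size_t i = 0; i < messages.size(); ++i) {
    const auto& message = messages[i];
    // Each turn: role header, blank line, content, then the EOS token.
    std::string formatted = "<|start_header_id|>" + message.at("role") +
                            "<|end_header_id|>\n\n" + message.at("content") + eos_token;
    // BOS is attached only to the first turn.
    if (i == 0) {
      formatted = bos_token + formatted;
    }
    output += formatted;
  }
  if (add_generation_prompt) {
    // Open an assistant turn for the model to complete.
    output += "<|start_header_id|>assistant<|end_header_id|>\n\n";
  } else {
    output += eos_token;
  }
  return output;
}

int main() {
  std::vector<Message> messages = {
      {{"role", "system"}, {"content", "You are a helpful assistant."}},
      {{"role", "user"}, {"content", "What is the capital of France?"}},
  };
  std::cout << Llama3Prompt(messages) << '\n';
  return 0;
}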
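Compiled with any C++11 compiler, the sketch prints one role-header/content/<|eot_id|> block per turn, led by <|begin_of_text|> and followed by an open <|start_header_id|>assistant<|end_header_id|> header. Two design points in the new function are worth calling out: BOS is prepended only to the first message, so the template is applied to the whole conversation in a single pass, and when add_generation_prompt is false the output is closed with a final EOS token instead of an open assistant header. Unlike the Llama 3.1/3.3 template above, the base Llama 3 template has no tool-calling branch, so <|eom_id|> is never emitted.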