add base llama 3 support
Sayan Shaw committed Mar 8, 2025
1 parent 3a93e4a commit 7e6e4b6
Showing 2 changed files with 57 additions and 16 deletions.
71 changes: 55 additions & 16 deletions shared/api/tokenizer_impl.cc
@@ -201,7 +201,7 @@ OrtxStatus TokenizerImpl::Phi3_5ChatTemplate(std::string* output, bool add_gener
}
}

- // Add generation prompt or eos_token
+ // Add generation prompt or EOS token
if (add_generation_prompt) {
*output += "<|assistant|>\n";
} else {
@@ -211,20 +211,57 @@ OrtxStatus TokenizerImpl::Phi3_5ChatTemplate(std::string* output, bool add_gener
return OrtxStatus(kOrtxOK, "Created Phi-3.5 chat template.");
}

+ OrtxStatus TokenizerImpl::Llama3ChatTemplate(
+ std::string* output,
+ bool add_generation_prompt = true,
+ const std::string& eos_token = "<|eot_id|>",
+ const std::string& bos_token = "<|begin_of_text|>") {
+
+ // Clear the output string before starting
+ output->clear();
+
+ // Iterate over the messages to construct the template
+ for (size_t i = 0; i < messages.size(); ++i) {
+ const auto& message = messages[i];
+ std::string role = message.at("role");
+ std::string content = message.at("content");
+
+ // Build the message with header and content
+ std::string formatted_content = "<|start_header_id|>" + role + "<|end_header_id|>\n\n" + content + eos_token;
+
+ // Add BOS token only to the first message
+ if (i == 0) {
+ formatted_content = bos_token + formatted_content;
+ }
+
+ // Append the formatted message to the output
+ *output += formatted_content;
+ }
+
+ // Add generation prompt or eos_token at the end
+ if (add_generation_prompt) {
+ *output += "<|start_header_id|>assistant<|end_header_id|>\n\n";
+ } else {
+ *output += eos_token;
+ }
+
+ return OrtxStatus(kOrtxOK, "Created Llama 3 chat template.");
+ }
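For reference, here is a minimal sketch (not part of the commit) of what Llama3ChatTemplate emits for the single-message conversation {"role": "user", "content": "Hello"} with add_generation_prompt = true and the default special tokens:

<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n

With add_generation_prompt = false, the trailing assistant header is replaced by a second <|eot_id|>.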

OrtxStatus TokenizerImpl::Llama3_2ChatTemplate(
std::string* output,
bool add_generation_prompt = true,
const std::string& eos_token = "<|eot_id|>",
const std::vector<std::string>& custom_tools = {},
bool tools_in_user_message = true,
const std::string& strftime_now = "",
- const std::string& bos_token = "<|begin_of_text|>") { // Add bos_token as a parameter
+ const std::string& bos_token = "<|begin_of_text|>") {

// Clear the output string before starting
output->clear();

// Prepend BOS token at the start of the output
- *output += bos_token; // BOS token goes first
+ *output += bos_token;

// Initialize date_string with default value
std::string date_string = "26 Jul 2024"; // Default date
@@ -298,27 +335,27 @@ OrtxStatus TokenizerImpl::Llama3_2ChatTemplate(
if (add_generation_prompt) {
*output += "<|start_header_id|>assistant<|end_header_id|>\n\n";
} else {
- *output += eos_token; // Add the EOS token instead
+ *output += eos_token;
}

- return OrtxStatus(kOrtxOK, "Created Llama3 chat template.");
+ return OrtxStatus(kOrtxOK, "Created Llama 3.2 chat template.");
}

OrtxStatus TokenizerImpl::Llama3_3ChatTemplate(
std::string* output,
bool add_generation_prompt = true,
const std::string& eos_token = "<|eot_id|>",
const std::vector<std::string>& custom_tools = {},
- const std::vector<std::string>& builtin_tools = {}, // Added builtin_tools as parameter
+ const std::vector<std::string>& builtin_tools = {},
bool tools_in_user_message = true,
- const std::string& date_string = "26 Jul 2024", // Default date string parameter
- const std::string& bos_token = "<|begin_of_text|>") { // BOS token as a parameter
+ const std::string& date_string = "26 Jul 2024",
+ const std::string& bos_token = "<|begin_of_text|>") {

// Clear the output string before starting
output->clear();

// Prepend BOS token at the start of the output
- *output += bos_token; // BOS token goes first
+ *output += bos_token;

// Loop through messages and process each one
for (const auto& message : messages) {
@@ -408,7 +445,7 @@ OrtxStatus TokenizerImpl::Llama3_3ChatTemplate(
if (!builtin_tools.empty()) {
*output += "<|eom_id|>";
} else {
- *output += eos_token; // Replaced <|eot_id|> with eos_token
+ *output += eos_token;
}
}

@@ -418,25 +455,25 @@ OrtxStatus TokenizerImpl::Llama3_3ChatTemplate(
*output += "<|start_header_id|>" + role + "<|end_header_id|>\n\n";
}
*output += content;
- *output += eos_token; // Replaced <|eot_id|> with eos_token
+ *output += eos_token;
}
}

// Add generation prompt or eos_token at the end
if (add_generation_prompt) {
*output += "<|start_header_id|>assistant<|end_header_id|>\n\n";
} else {
- *output += eos_token; // Replaced <|eot_id|> with eos_token
+ *output += eos_token;
}

- return OrtxStatus(kOrtxOK, "Created chat template.");
+ return OrtxStatus(kOrtxOK, "Created Llama 3.1/3.3 chat template."); // Llama 3.1 and 3.3 have the same chat template
}

OrtxStatus TokenizerImpl::DeepSeekChatTemplate(
std::string* output,
bool add_generation_prompt = false,
const std::string& eos_token = "<|end▁of▁sentence|>",
- const std::string& bos_token = "<|begin▁of▁sentence|>") { // Add bos_token as a parameter
+ const std::string& bos_token = "<|begin▁of▁sentence|>") {

// Clear the output string before starting
output->clear();
@@ -466,7 +503,7 @@ OrtxStatus TokenizerImpl::DeepSeekChatTemplate(
// Process each message in the conversation
for (const auto& message : messages) {
std::string role = message.at("role");
- std::string content = message.at("content"); // Now content is correctly defined here
+ std::string content = message.at("content");

// Handle user message
if (role == "user") {
@@ -541,7 +578,7 @@ OrtxStatus TokenizerImpl::DeepSeekChatTemplate(
if (add_generation_prompt && !is_tool) {
*output += "<|Assistant|><think>\n";
} else {
- *output += eos_token; // Add the EOS token instead
+ *output += eos_token;
}

return OrtxStatus(kOrtxOK, "Created DeepSeek chat template.");
@@ -558,6 +595,8 @@ OrtxStatus TokenizerImpl::ApplyChatTemplate(std::vector<std::unordered_map<std::
return Phi4ChatTemplate(output, add_generation_prompt);
} else if (chat_template == PHI3_5_CHAT_TEMPLATE) {
return Phi3_5ChatTemplate(output, add_generation_prompt);
+ } else if (chat_template == LLAMA3_CHAT_TEMPLATE) {
+ return Llama3ChatTemplate(output, add_generation_prompt);
} else if (chat_template == LLAMA3_2_CHAT_TEMPLATE) {
return Llama3_2ChatTemplate(output, add_generation_prompt);
} else if (chat_template == LLAMA3_3_CHAT_TEMPLATE) {
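For context, a hypothetical caller of the dispatch above (the full ApplyChatTemplate signature is truncated in this hunk, so the message-list parameter type, argument order, and the tokenizer variable below are assumptions):

std::vector<std::unordered_map<std::string, std::string>> messages = {
    {{"role", "user"}, {"content", "Hello"}}};
std::string prompt;
// Assumes the loaded model's chat template matched LLAMA3_CHAT_TEMPLATE:
OrtxStatus status = tokenizer.ApplyChatTemplate(messages, &prompt, /*add_generation_prompt=*/true);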
2 changes: 2 additions & 0 deletions shared/api/tokenizer_impl.h
@@ -68,6 +68,8 @@ class TokenizerImpl : public OrtxObjectImpl {

OrtxStatus Phi3_5ChatTemplate(std::string* output, bool add_generation_prompt, const std::string& eos_token);

+ OrtxStatus Llama3ChatTemplate(std::string* output, bool add_generation_prompt, const std::string& eos_token, const std::string& bos_token);

OrtxStatus Llama3_2ChatTemplate(std::string* output, bool add_generation_prompt, const std::string& eos_token, const std::vector<std::string>& custom_tools, bool tools_in_user_message, const std::string& strftime_now, const std::string& bos_token);

OrtxStatus Llama3_3ChatTemplate(std::string* output, bool add_generation_prompt, const std::string& eos_token, const std::vector<std::string>& custom_tools, const std::vector<std::string>& builtin_tools, bool tools_in_user_message, const std::string& date_string, const std::string& bos_token);
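Note that the eos_token and bos_token defaults live on the definitions in tokenizer_impl.cc, not on these declarations, so a short call such as

Llama3ChatTemplate(output, add_generation_prompt);

compiles only inside tokenizer_impl.cc after the definition (which is where ApplyChatTemplate invokes it); code that sees only this header must pass every argument explicitly.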
