Skip to content

Commit

Permalink
add phi 3 small and medium support
Browse files Browse the repository at this point in the history
  • Loading branch information
Sayan Shaw committed Mar 8, 2025
1 parent 231093c commit 34b8ed7
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 2 deletions.
57 changes: 55 additions & 2 deletions shared/api/tokenizer_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ OrtxStatus TokenizerImpl::BatchDecode(const std::vector<span<extTokenId_t const>
// Constant string variable to store predefined chat template strings for popular supported models
const std::string PHI_VISION_CHAT_TEMPLATE = R"({% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %})";
const std::string PHI3_CHAT_TEMPLATE = R"({% for message in messages %}{% if message['role'] == 'system' %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %})";
const std::string PHI3_SMALL_CHAT_TEMPLATE = R"({{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %})";
const std::string PHI3_MEDIUM_CHAT_TEMPLATE = R"({% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %})";
const std::string PHI3_5_CHAT_TEMPLATE = R"({% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %})";
const std::string PHI4_CHAT_TEMPLATE = R"({% for message in messages %}{% if message['role'] == 'system' and 'tools' in message and message['tools'] is not none %}{{ '<|' + message['role'] + '|>' + message['content'] + '<|tool|>' + message['tools'] + '<|/tool|>' + '<|end|>' }}{% else %}{{ '<|' + message['role'] + '|>' + message['content'] + '<|end|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>' }}{% else %}{{ eos_token }}{% endif %})";
const std::string LLAMA2_CHAT_TEMPLATE = R"({% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %})";
Expand Down Expand Up @@ -169,7 +171,7 @@ OrtxStatus TokenizerImpl::PhiVisionChatTemplate(std::string* output, bool add_ge
*output += "<|assistant|>\n";
}

return OrtxStatus(kOrtxOK, "Created Phi Vision chat template.");
return OrtxStatus(kOrtxOK, "Created Phi vision chat template.");
}

// Note Phi-3 and Phi-3.5 have slightly different chat template strings but share the same functionality so this method can be used for both.
Expand Down Expand Up @@ -202,7 +204,54 @@ OrtxStatus TokenizerImpl::Phi3ChatTemplate(std::string* output, bool add_generat
*output += eos_token;
}

return OrtxStatus(kOrtxOK, "Created Phi-3.5 chat template.");
return OrtxStatus(kOrtxOK, "Created Phi-3/3.5 chat template.");
}

OrtxStatus TokenizerImpl::Phi3SmallChatTemplate(std::string* output, bool add_generation_prompt = true, const std::string& eos_token = "<|endoftext|>", const std::string& bos_token = "<|startoftext|>") {

// Clear the output string before starting
output->clear();

// Add the beginning-of-sequence token
*output += bos_token;

// Iterate over the messages
for (const auto& message : messages) {
std::string role = message.at("role");
std::string content = message.at("content");

// Format the message according to the role
*output += "<|" + role + "|>\n" + content + "<|end|>\n";
}

// Add the generation prompt or eos_token
if (add_generation_prompt) {
*output += "<|assistant|>\n";
} else {
*output += eos_token;
}

return OrtxStatus(kOrtxOK, "Created Phi-3-small chat template.");
}

OrtxStatus TokenizerImpl::Phi3MediumChatTemplate(std::string* output) {
// Clear the output string before starting
output->clear();

// Process the messages
for (const auto& message : messages) {
std::string role = message.at("role");
std::string content = message.at("content");

// Format based on role (user/assistant)
if (role == "user") {
*output += "<|user|>\n" + content + "<|end|>\n<|assistant|>\n";
} else if (role == "assistant") {
*output += content + "<|end|>\n";
}
}

return OrtxStatus(kOrtxOK, "Created Phi-3-medium chat template.");
}

OrtxStatus TokenizerImpl::Phi4ChatTemplate(std::string* output, bool add_generation_prompt = true, const std::string& eos_token = "<|endoftext|>") {
Expand Down Expand Up @@ -671,6 +720,10 @@ OrtxStatus TokenizerImpl::ApplyChatTemplate(std::vector<std::unordered_map<std::
return Phi4ChatTemplate(output, add_generation_prompt);
} else if (chat_template == PHI3_CHAT_TEMPLATE || chat_template == PHI3_5_CHAT_TEMPLATE) {
return Phi3ChatTemplate(output, add_generation_prompt);
} else if (chat_template == PHI3_SMALL_CHAT_TEMPLATE) {
return Phi3SmallChatTemplate(output, add_generation_prompt);
} else if (chat_template == PHI3_MEDIUM_CHAT_TEMPLATE) {
return Phi3MediumChatTemplate(output);
} else if (chat_template == PHI_VISION_CHAT_TEMPLATE) {
return PhiVisionChatTemplate(output, add_generation_prompt);
} else if (chat_template == LLAMA2_CHAT_TEMPLATE) {
Expand Down
6 changes: 6 additions & 0 deletions shared/api/tokenizer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class TokenizerImpl : public OrtxObjectImpl {

const std::string PHI_VISION_CHAT_TEMPLATE;
const std::string PHI3_CHAT_TEMPLATE;
const std::string PHI3_SMALL_CHAT_TEMPLATE;
const std::string PHI3_MEDIUM_CHAT_TEMPLATE;
const std::string PHI3_5_CHAT_TEMPLATE;
const std::string PHI4_CHAT_TEMPLATE;
const std::string LLAMA2_CHAT_TEMPLATE;
Expand All @@ -70,6 +72,10 @@ class TokenizerImpl : public OrtxObjectImpl {
OrtxStatus PhiVisionChatTemplate(std::string* output, bool add_generation_prompt);

OrtxStatus Phi3ChatTemplate(std::string* output, bool add_generation_prompt, const std::string& eos_token);

OrtxStatus Phi3SmallChatTemplate(std::string* output, bool add_generation_prompt, const std::string& eos_token, const std::string& bos_token);

OrtxStatus Phi3MediumChatTemplate(std::string* output);

OrtxStatus Phi4ChatTemplate(std::string* output, bool add_generation_prompt, const std::string& eos_token);

Expand Down

0 comments on commit 34b8ed7

Please sign in to comment.