Skip to content

Commit

Permalink
fix deepseek issues
Browse the repository at this point in the history
  • Loading branch information
Sayan Shaw committed Mar 6, 2025
1 parent 3043d3c commit d08a617
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
24 changes: 12 additions & 12 deletions shared/api/tokenizer_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,8 @@ OrtxStatus TokenizerImpl::Llama3ChatTemplate(
OrtxStatus TokenizerImpl::DeepSeekChatTemplate(
std::string* output,
bool add_generation_prompt = false,
const std::string& eos_token = "<|eot_id|>",
const std::string& bos_token = "<|begin_of_text|>") { // Add bos_token as a parameter
const std::string& eos_token = "<|end▁of▁sentence|>",
const std::string& bos_token = "<|begin▁of▁sentence|>") { // Add bos_token as a parameter

// Clear the output string before starting
output->clear();
Expand Down Expand Up @@ -341,7 +341,7 @@ OrtxStatus TokenizerImpl::DeepSeekChatTemplate(
// Handle user message
if (role == "user") {
is_tool = false;
*output += "<|User|>" + content;
*output += "<User>" + content;
}

// Handle assistant message with tool calls
Expand All @@ -367,46 +367,46 @@ OrtxStatus TokenizerImpl::DeepSeekChatTemplate(

// Handle the first tool call differently
if (is_first) {
*output += "<|Assistant|><|tool_calls_begin|><|tool_call_begin|>" + tool_calls_json[0]["type"].get<std::string>() + "<|tool_sep|>" + tool_calls_json[0]["function"]["name"].get<std::string>() + "\njson\n" + tool_calls_json[0]["function"]["arguments"].dump() + "\n<|tool_call_end|>";
*output += "<Assistant|><|tool_calls_begin|><|tool_call_begin>" + tool_calls_json[0]["type"].get<std::string>() + "<tool_sep>" + tool_calls_json[0]["function"]["name"].get<std::string>() + "\njson\n" + tool_calls_json[0]["function"]["arguments"].dump() + "\n<tool_call_end>";
is_first = false; // Mark as first tool call processed
} else {
// Subsequent tool calls
*output += "\n<|tool_call_begin|>" + tool_calls_json[0]["type"].get<std::string>() + "<|tool_sep|>" + tool_calls_json[0]["function"]["name"].get<std::string>() + "\njson\n" + tool_calls_json[0]["function"]["arguments"].dump() + "\n<|tool_call_end|>";
*output += "\n<tool_call_begin>" + tool_calls_json[0]["type"].get<std::string>() + "<tool_sep>" + tool_calls_json[0]["function"]["name"].get<std::string>() + "\njson\n" + tool_calls_json[0]["function"]["arguments"].dump() + "\n<tool_call_end>";
}

*output += "<|tool_calls_end|><|end_of_sentence|>";
*output += "<tool_calls_end|><|end▁of▁sentence|>";
}

// Handle assistant message without tool calls
if (role == "assistant" && !content.empty()) {
if (is_tool) {
*output += "<|tool_outputs_end|>" + content + "<|end_of_sentence|>";
*output += "<tool_outputs_end>" + content + "<|end▁of▁sentence|>";
is_tool = false;
} else {
*output += "<|Assistant|>" + content + "<|end_of_sentence|>";
*output += "<Assistant>" + content + "<|end▁of▁sentence|>";
}
}

// Handle tool messages
if (role == "tool") {
is_tool = true;
if (is_output_first) {
*output += "<|tool_outputs_begin|><|tool_output_begin|>" + content + "<|tool_output_end|>";
*output += "<tool_outputs_begin|><|tool_output_begin>" + content + "<tool_output_end>";
is_output_first = false;
} else {
*output += "\n<|tool_output_begin|>" + content + "<|tool_output_end|>";
*output += "\n<tool_output_begin>" + content + "<tool_output_end>";
}
}
}

// If still in a tool message, close it
if (is_tool) {
*output += "<|tool_outputs_end|>";
*output += "<tool_outputs_end>";
}

// Add generation prompt or eos_token at the end
if (add_generation_prompt && !is_tool) {
*output += "<|Assistant|><think>\n";
*output += "<Assistant><think>\n";
} else {
*output += eos_token; // Add the EOS token instead
}
Expand Down
1 change: 1 addition & 0 deletions shared/api/tokenizer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class TokenizerImpl : public OrtxObjectImpl {
const std::string PHI4_CHAT_TEMPLATE;
const std::string PHI3_5_CHAT_TEMPLATE;
const std::string LLAMA3_CHAT_TEMPLATE;
const std::string DEEPSEEK_CHAT_TEMPLATE;

std::string chat_template;
std::vector<std::unordered_map<std::string, std::string>> messages;
Expand Down

0 comments on commit d08a617

Please sign in to comment.