diff --git a/operators/tokenizer/bpe_streaming.hpp b/operators/tokenizer/bpe_streaming.hpp
index 33b1cdee..53a87c77 100644
--- a/operators/tokenizer/bpe_streaming.hpp
+++ b/operators/tokenizer/bpe_streaming.hpp
@@ -47,8 +47,6 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
     return {};
   }
 
-
-
   OrtxStatus Id2Token(extTokenId_t id,
                       std::string& token,
                       bool skip_special_tokens,
@@ -95,17 +93,19 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
   }
 
   OrtxStatus SpmId2Token(extTokenId_t id, std::string& token, bool& f_special_last) const {
-
-    std::string piece = id < arr_vocab_.size() ? arr_vocab_[id] : "";
-    bool f_special = false;
-    if (piece.empty() || all_special_ids_.count(id)) {
-      token = "";
-      f_special = true;
-    } else if (IsSpmByteWord(piece)) {
-      char buf[3] = {piece[3], piece[4], 0};  // something like <0x20>
-      token = {static_cast<char>(strtol(buf, NULL, 16))};
+    bool f_special = all_special_ids_.count(id) ? true : false;
+    if (added_tokens_.count(id)) {
+      token = added_tokens_.at(id);
     } else {
-      token = ReplaceAll(piece, std::string(ort_extensions::spm_escaped_space), " ");
+      std::string piece = id < arr_vocab_.size() ? arr_vocab_[id] : "";
+      if (piece.empty()) {
+        token = unk_token_;
+      } else if (IsSpmByteWord(piece)) {
+        char buf[3] = {piece[3], piece[4], 0};  // something like <0x20>
+        token = {static_cast<char>(strtol(buf, NULL, 16))};
+      } else {
+        token = ReplaceAll(piece, std::string(ort_extensions::spm_escaped_space), " ");
+      }
     }
 
     if (!token.empty() && token[0] == ' ' && f_special_last && add_dummy_prefix_) {
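
The sketch below is a minimal, self-contained illustration of the lookup order the patched SpmId2Token uses: resolve added tokens before the regular vocab, decode <0xNN> byte pieces to raw bytes, map the SentencePiece escaped space back to ' ', and fall back to the unk token for empty or out-of-range pieces. It is an assumption-laden approximation, not the actual class: names such as SpmId2TokenSketch, kSpmEscapedSpace, and the plain std::map/std::set parameters are hypothetical stand-ins for the real members (added_tokens_, arr_vocab_, all_special_ids_, unk_token_).

// Hypothetical sketch of the id-to-token priority introduced in this diff.
// Not the real BpeStreamingDecoder API; all names here are illustrative.
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

namespace {

const std::string kSpmEscapedSpace = "\xe2\x96\x81";  // U+2581, the SentencePiece space marker

bool IsSpmByteWord(const std::string& piece) {
  // Matches pieces of the form "<0xNN>", e.g. "<0x20>".
  return piece.size() == 6 && piece[0] == '<' && piece[1] == '0' && piece[2] == 'x' && piece[5] == '>';
}

std::string ReplaceAll(std::string s, const std::string& from, const std::string& to) {
  for (size_t pos = 0; (pos = s.find(from, pos)) != std::string::npos; pos += to.size()) {
    s.replace(pos, from.size(), to);
  }
  return s;
}

std::string SpmId2TokenSketch(uint32_t id,
                              const std::vector<std::string>& vocab,
                              const std::map<uint32_t, std::string>& added_tokens,
                              const std::set<uint32_t>& special_ids,
                              const std::string& unk_token,
                              bool& f_special) {
  f_special = special_ids.count(id) != 0;
  // Added tokens take priority over the regular vocab lookup.
  auto it = added_tokens.find(id);
  if (it != added_tokens.end()) {
    return it->second;
  }
  std::string piece = id < vocab.size() ? vocab[id] : "";
  if (piece.empty()) {
    return unk_token;  // out-of-range id or empty piece falls back to <unk>
  }
  if (IsSpmByteWord(piece)) {
    char buf[3] = {piece[3], piece[4], 0};  // the two hex digits of "<0xNN>"
    return std::string(1, static_cast<char>(std::strtol(buf, nullptr, 16)));
  }
  return ReplaceAll(piece, kSpmEscapedSpace, " ");
}

}  // namespace

int main() {
  std::vector<std::string> vocab = {"<unk>", "<0x20>", kSpmEscapedSpace + "hello"};
  std::map<uint32_t, std::string> added = {{3u, "<|assistant|>"}};
  std::set<uint32_t> special = {3u};
  bool f_special = false;
  std::cout << '[' << SpmId2TokenSketch(2, vocab, added, special, "<unk>", f_special) << "]\n";  // [ hello]
  std::cout << '[' << SpmId2TokenSketch(1, vocab, added, special, "<unk>", f_special) << "]\n";  // [ ]
  std::cout << '[' << SpmId2TokenSketch(3, vocab, added, special, "<unk>", f_special) << "]\n";  // [<|assistant|>]
  std::cout << '[' << SpmId2TokenSketch(9, vocab, added, special, "<unk>", f_special) << "]\n";  // [<unk>]
  return 0;
}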