Add initial tiktoken and Phi3SmallTokenizer support #729

Merged
merged 25 commits into main from sayanshaw/tiktoken
Aug 2, 2024
Changes from all commits · 25 commits
44e4d7d
add initial tiktoken support
May 24, 2024
4618c5b
Merge branch 'main' of https://github.com/microsoft/onnxruntime-exten…
May 24, 2024
4df908a
add vector hash and equal for bpe ranks map
May 24, 2024
a06e25d
Merge branch 'main' of https://github.com/microsoft/onnxruntime-exten…
May 28, 2024
0addca4
change lambda comparator
May 28, 2024
1abc5b2
Merge branch 'main' of https://github.com/microsoft/onnxruntime-exten…
May 29, 2024
9f9af88
Merge branch 'main' of https://github.com/microsoft/onnxruntime-exten…
May 30, 2024
0a499ca
Merge branch 'main' of https://github.com/microsoft/onnxruntime-exten…
Jun 3, 2024
ef88a21
Merge branch 'main' of https://github.com/microsoft/onnxruntime-exten…
Jun 3, 2024
888055b
move phi-3-small files
Jun 4, 2024
3720245
final changes
Jun 4, 2024
6456499
Merge branch 'main' of https://github.com/microsoft/onnxruntime-exten…
Jun 6, 2024
63c8445
move tiktoken files from data2 to data
Jun 6, 2024
4d0b35e
add unit test
Jun 6, 2024
625c274
Merge branch 'main' into sayanshaw/tiktoken
wenbingl Jun 13, 2024
27196c7
Merge branch 'main' of https://github.com/microsoft/onnxruntime-exten…
Jun 26, 2024
469ad3a
add tokenizer module
Jun 26, 2024
ecd134f
Merge branch 'sayanshaw/tiktoken' of https://github.com/microsoft/onn…
Jun 26, 2024
910743b
merge json and tiktoken impl
Jun 27, 2024
6f95616
Merge branch 'main' of https://github.com/microsoft/onnxruntime-exten…
Aug 1, 2024
3581ac2
fix tiktoken encoding problem
Aug 1, 2024
9bd937a
Merge branch 'main' of https://github.com/microsoft/onnxruntime-exten…
Aug 1, 2024
b284c92
address comments
Aug 1, 2024
9054900
Merge branch 'main' of https://github.com/microsoft/onnxruntime-exten…
Aug 2, 2024
520a4e2
remove dummy tokens
Aug 2, 2024
23 changes: 22 additions & 1 deletion operators/tokenizer/bpe_json.hpp
@@ -30,7 +30,23 @@ class TokenJsonConfig final {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open a json file: " + file_path.string());
}

auto vocab_file_path = path(json_path) / "tokenizer.json";
vocab_path_ = vocab_file_path.string();
std::ifstream vocab_fs = vocab_file_path.open();
if (!vocab_fs.is_open()) {
// No tokenizer.json file present; search for tokenizer module file
auto module_file_path = path(json_path) / "tokenizer_module.json";
module_path_ = module_file_path.string();
std::ifstream tok_module_ifs = module_file_path.open();
if (!tok_module_ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "No tokenizer.json or tokenizer_module.json file found.");
} else {
nlohmann::json tok_module_json_config = nlohmann::json::parse(tok_module_ifs);
auto tiktoken_path = tok_module_json_config.value("tiktoken_file", "");
vocab_file_path = path(json_path) / tiktoken_path.c_str();
vocab_path_ = vocab_file_path.string();
}
}
nlohmann::json json_config = nlohmann::json::parse(ifs);
add_bos_token_ = json_config.value("add_bos_token", false);
add_eos_token_ = json_config.value("add_eos_token", false);
@@ -66,6 +82,10 @@ class TokenJsonConfig final {

const std::string& GetVocabDataFile() const { return vocab_path_; }

const std::string& GetTikTokenModuleFile() const { return module_path_; }

public:
bool add_bos_token_{};
bool add_eos_token_{};
@@ -80,6 +100,7 @@

private:
std::string vocab_path_;
std::string module_path_;
};

} // namespace ort_extensions::bpe
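For reference, a minimal tokenizer_module.json that satisfies this lookup needs only the single key the loader reads; the vocabulary file name below is a hypothetical example:

{
  "tiktoken_file": "cl100k_base.tiktoken"
}

The named .tiktoken file is then resolved relative to the same directory, taking the place of tokenizer.json as the vocabulary source.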
162 changes: 150 additions & 12 deletions operators/tokenizer/bpe_kernels.cc
@@ -8,6 +8,8 @@
#include "bpe_json.hpp"
#include "bpe_tokenizer.hpp"

#include "base64.h"

#include <optional>
#include <limits>

@@ -552,13 +554,140 @@ SpmTokenizer::SpmTokenizer()

JsonFastTokenizer::JsonFastTokenizer() : KernelBpeTokenizer(kGPT2Configuration) {}

/*
Read more here: https://github.com/huggingface/transformers/blob/60bb571e993b7d73257fb64044726b569fef9403/src/transformers/convert_slow_tokenizer.py#L1454

Note: this is similar to the BPE CreateByteEncoder; however, for decoding the .tiktoken bytes
we need to store the strings rather than their IDs, and therefore need a separate map.
*/
void JsonFastTokenizer::CreateUnicodeByteEncoder() {
char32_t index = 256;
for (char32_t i = 0; i < 256; ++i) {
if (i < 33 || (i >= 127 && i < 161) || i == 173) {
unicode_byte_encoder_[i] = ustring::EncodeUTF8Char(index++);
} else {
unicode_byte_encoder_[i] = ustring::EncodeUTF8Char(i);
}
}
}
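// Illustrative result (GPT-2-style byte-to-unicode mapping): printable bytes
// map to themselves, e.g. 'A' (0x41) -> "A", while bytes in the excluded ranges
// are shifted into U+0100 and up, e.g. 0x20 (space) -> U+0120 ("Ġ") and
// 0x0A (newline) -> U+010A ("Ċ"), so every byte has a printable stand-in.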

std::string JsonFastTokenizer::TokenBytesToString(std::vector<uint8_t>& bytes) {
std::string result;
for (auto c : bytes) {
result += unicode_byte_encoder_[static_cast<unsigned char>(c)];
}
return result;
}

// Custom hash function for the vector key
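// (0x9e3779b9 below is the 32-bit golden-ratio constant from the widely used
// boost::hash_combine recipe for mixing element hashes into a seed)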
struct VectorHash {
size_t operator()(const std::vector<uint8_t>& v) const {
std::hash<uint8_t> hasher;
size_t seed = 0;
for (uint8_t i : v) {
seed ^= hasher(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
};

// Custom equality function for the vector key
struct VectorEqual {
bool operator()(const std::vector<uint8_t>& a, const std::vector<uint8_t>& b) const {
return a == b;
}
};

OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& config) {
std::string voc_file = config.GetVocabDataFile();
std::ifstream ifs = path(voc_file).open();
if (!ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open json file: " + voc_file);
}

// consider using a SAX parser for large JSON files
nlohmann::json tok_json;
std::ifstream module_ifs;

// The following vocab and merges are only used in the tiktoken case, but they are accessed outside the scope below
std::unordered_map<std::string, uint32_t> vocab;
std::vector<std::pair<std::string, std::string>> merges;

if (tiktoken_) {
std::string module_file = config.GetTikTokenModuleFile();

module_ifs = path(module_file).open();
if (!module_ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open module file: " + module_file);
}

std::unordered_map<std::vector<uint8_t>, uint32_t, VectorHash, VectorEqual> bpe_ranks;

std::string line;
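// Each non-empty line of the .tiktoken file is "<base64-encoded token> <rank>",
// e.g. "IQ== 0" maps the single byte '!' to rank 0 (illustrative values).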
while (std::getline(ifs, line)) {
if (!line.empty()) {
std::istringstream lineStream(line);
std::string token;
uint32_t rank;
while (lineStream >> token >> rank) {
// Decode the base64 token (the rank was parsed as an integer above)
std::vector<uint8_t> decoded_token;
base64_decode(token, decoded_token);
// Store bpe token and rank
bpe_ranks[decoded_token] = rank;
}
}
}

std::vector<std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>> byte_merges;

bbpe_tokenizer_ = std::make_unique<BpeModel>();
JsonFastTokenizer::CreateUnicodeByteEncoder();

for (const auto& item : bpe_ranks) {
std::vector<uint8_t> token = item.first;
uint32_t rank = item.second;
vocab[JsonFastTokenizer::TokenBytesToString(token)] = rank;

if (token.size() == 1) {
continue;
}

std::vector<std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>> local;
for (size_t index = 1; index < token.size(); index++) {
std::vector<uint8_t> piece_l(token.begin(), token.begin() + index);
std::vector<uint8_t> piece_r(token.begin() + index, token.end());
if (bpe_ranks.count(piece_l) && bpe_ranks.count(piece_r)) {
local.emplace_back(piece_l, piece_r, rank);
}
}

auto compare_bpe_tuples = [&](const std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>& a,
const std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>& b) {
// Compare tuples by the ranks of their left and right pieces in bpe_ranks
return bpe_ranks[std::get<0>(a)] < bpe_ranks[std::get<0>(b)] ||
(bpe_ranks[std::get<0>(a)] == bpe_ranks[std::get<0>(b)] && bpe_ranks[std::get<1>(a)] < bpe_ranks[std::get<1>(b)]);
};

std::sort(local.begin(), local.end(), compare_bpe_tuples);

byte_merges.insert(byte_merges.end(), local.begin(), local.end());
}

// Custom comparator that compares the third element of the tuples
auto compare_merge_tuples = [&](const std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>& a,
const std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>& b) {
return std::get<2>(a) < std::get<2>(b);
};

std::sort(byte_merges.begin(), byte_merges.end(), compare_merge_tuples);

// Populate merges
for (auto& val : byte_merges) {
merges.push_back({JsonFastTokenizer::TokenBytesToString(std::get<0>(val)), JsonFastTokenizer::TokenBytesToString(std::get<1>(val))});
}
}

const char token_sub[] = "Tokenizer";
model_name_ = config.tokenizer_class_.substr(0, config.tokenizer_class_.find(token_sub));
json_conf_.name_ = model_name_.c_str();
@@ -570,18 +699,27 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& c
// re-bind the configuration object
bpe_conf_ = json_conf_;

OrtxStatus status;
if (tiktoken_) {
status = bbpe_tokenizer_->Load(vocab,
merges,
bpe_conf_.get().GetSpecialTokens().c_str(),
false);

module_ifs >> tok_json;
} else {
ifs >> tok_json;
auto model_node = tok_json.find("model");
if (model_node == tok_json.end()) {
return OrtxStatus(kOrtxErrorCorruptData, "Failed to get model node from tokenizer.json");
}

bbpe_tokenizer_ = std::make_unique<BpeModel>();
status = bbpe_tokenizer_->Load(*model_node,
bpe_conf_.get().GetSpecialTokens().c_str(),
IsSpmModel(ModelName()));
}


auto added_tokens = tok_json.find("added_tokens");
if (added_tokens != tok_json.end()) {
@@ -640,4 +778,4 @@ OrtxStatus JsonFastTokenizer::Compute(const ortc::Tensor<std::string>& input,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
24 changes: 24 additions & 0 deletions operators/tokenizer/bpe_kernels.h
@@ -107,6 +107,10 @@ struct SpmTokenizer : KernelBpeTokenizer {
class JsonFastTokenizer : KernelBpeTokenizer {
public:
JsonFastTokenizer();
bool tiktoken_ = false;
std::string unicode_byte_encoder_[256] = {};
void CreateUnicodeByteEncoder();
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
@@ -121,3 +125,23 @@ class JsonFastTokenizer : KernelBpeTokenizer {
BpeModelConf json_conf_;
std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
};

class TikTokenizer : KernelBpeTokenizer {
public:
TikTokenizer();
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;

public:
const auto& GetAddedTokens() const { return added_tokens_; }
const ort_extensions::BpeModel& GetEncoder() const { return *bbpe_tokenizer_; }

private:
std::unique_ptr<ort_extensions::BpeModel> bbpe_tokenizer_;
BpeModelConf json_conf_;
std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
};
41 changes: 41 additions & 0 deletions operators/tokenizer/bpe_tokenizer.hpp
@@ -194,6 +194,47 @@ class BpeModel {
return {};
}

OrtxStatus Load(std::unordered_map<std::string, uint32_t>& vocab,
std::vector<std::pair<std::string, std::string>>& merges,
const char* /* special_tokens */,
bool spm_converted) {
vocab_map_ = vocab;

if (spm_converted) {
UpdateSpmByteToken(vocab_map_);
} else {
CreateByteEncoder();
}

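// merges arrive as ordered (left, right) string pairs, e.g. ("Ġ", "t") merging
// into "Ġt"; earlier pairs take precedence, recorded by the running index below.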
uint32_t index = 0;
for (auto& tuple : merges) {
std::string w1 = tuple.first;
std::string w2 = tuple.second;
int token_length = ort_extensions::narrow<int>(w1.length() + w2.length());
if (w2.find("</w>") != std::string::npos || w1.find("</w>") != std::string::npos) {
token_length -= 4;
}
auto iw1 = GetTokenId(w1);
auto iw2 = GetTokenId(w2);
auto iww = GetTokenId(w1 + w2);
BpeNode value{iww, index++, token_length};
bpe_rank_[GetRankKey(iw1, iw2)] = value;
}

id2token_map_.resize(vocab_map_.size());
for (const auto& [t, i] : vocab_map_) {
if (i > static_cast<uint32_t>((std::numeric_limits<int32_t>::max)())) {
continue; // id exceeds the int32 range; skip it for safety.
}
if (i >= id2token_map_.size()) {  // >= so that index i is valid after the resize
id2token_map_.resize(static_cast<size_t>(i) + 1);
}
id2token_map_[i] = t;
}

return {};
}

OrtxStatus LoadAddedTokens(const char* added_tokens) {
int id = bpe::kInvalidTokenId;
std::istringstream strm_tokens(added_tokens);
24 changes: 18 additions & 6 deletions shared/api/tokenizer_impl.cc
@@ -18,12 +18,24 @@ OrtxStatus TokenizerImpl::Load(const std::string& dir) {
return status;
}

auto vocab_file_path = path(dir) / "tokenizer.json";
std::ifstream vocab_fs = vocab_file_path.open();

tokenizer_ = std::make_unique<JsonFastTokenizer>();
if (!vocab_fs.is_open()) {
// No tokenizer.json file present; use TikToken tokenizer
tokenizer_->tiktoken_ = true;

// load the tokenizer from a config
status = tokenizer_->Load(*tok_config_);
} else {
// load the tokenizer from a config
status = tokenizer_->Load(*tok_config_);

if (status.IsOk()) {
detokenizer_ = std::make_unique<BpeStreamingDecoder>();
status = detokenizer_->Load(tok_config_, *tokenizer_);
}
}

return status;
@@ -34,7 +46,7 @@
for (const auto& s : input) {
ortc::Tensor<int64_t> ts_output(&CppAllocator::Instance());
ortc::Tensor<std::string> ts_input = ortc::Tensor<std::string>(std::vector<std::string>{std::string(s)});
OrtxStatus status = tokenizer_->Compute(ts_input, ts_output, std::nullopt, std::nullopt);

if (!status.IsOk()) {
return status;
1 change: 1 addition & 0 deletions shared/api/tokenizer_impl.h
@@ -50,6 +50,7 @@ class TokenizerImpl : public OrtxObjectImpl {
std::vector<std::vector<extTokenId_t>>& t_ids) const;

private:
bool tiktoken = false;
std::string tokenizer_dir_;
std::shared_ptr<ort_extensions::bpe::TokenJsonConfig> tok_config_;
std::unique_ptr<JsonFastTokenizer> tokenizer_;
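To see how the pieces fit together, here is a minimal usage sketch. It is not code from this PR: it assumes TokenizerImpl can be constructed directly (in the library it may only be obtainable through the OrtxObject factory), and the directory path and input text are placeholders. The directory is expected to hold the tokenizer config JSON plus either tokenizer.json or a tokenizer_module.json pointing at a .tiktoken vocabulary; Load switches to the tiktoken path automatically when tokenizer.json is absent.

// Hypothetical end-to-end sketch; error handling elided.
TokenizerImpl tokenizer;
OrtxStatus status = tokenizer.Load("/path/to/phi-3-small-tokenizer");
if (status.IsOk()) {
  std::vector<std::vector<extTokenId_t>> token_ids;
  // BatchEncode fills one vector of token ids per input string.
  status = tokenizer.BatchEncode({"hello world"}, token_ids);
}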