Skip to content

Commit

Permalink
Add a generic tokenizer operator to support Hugging Face Tokenizer wi…
Browse files Browse the repository at this point in the history
…th JSON data files. (#859)

* add the hf json file embedded tokenizer

* add a unit test

* fix the build break

* update the ort version

* build/test break fixes
  • Loading branch information
wenbingl authored Dec 12, 2024
1 parent 588f235 commit 1a21d45
Show file tree
Hide file tree
Showing 18 changed files with 2,227 additions and 2,119 deletions.
2 changes: 1 addition & 1 deletion cmake/ext_ortlib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ else()
if (OCOS_ONNXRUNTIME_VERSION)
set(ONNXRUNTIME_VER ${OCOS_ONNXRUNTIME_VERSION})
else()
set(ONNXRUNTIME_VER "1.17.1")
set(ONNXRUNTIME_VER "1.19.2") # need to check if android package of this version is available too.
endif()

if (ANDROID)
Expand Down
2 changes: 1 addition & 1 deletion include/ort_c_to_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ inline bool IsScalarOr1ElementVector(size_t num_dimensions, int64_t shape_size)
return false;
}

#define ORTX_RETURN_IF_ERROR(expr) \
#define ORTW_RETURN_IF_ERROR(expr) \
do { \
auto _status = (expr); \
if (_status != nullptr) { \
Expand Down
4 changes: 2 additions & 2 deletions include/ortx_tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ struct OrtxTokenizerBlob {
#ifdef __cplusplus
OrtxTokenizerBlob(const std::string_view& config_json_blob,
const std::string_view& vocab_json_blob,
const std::string_view& token_module_blob,
const std::string_view& raw_model_blob)
const std::string_view& token_module_blob = {},
const std::string_view& raw_model_blob = {})
: config_json_blob(config_json_blob.data()), vocab_json_blob(vocab_json_blob.data()),
token_module_blob(token_module_blob.data()), raw_model_blob(raw_model_blob.data()),
config_blob_len(config_json_blob.size()),
Expand Down
8 changes: 8 additions & 0 deletions include/ortx_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ typedef OrtxObject OrtxTensorResult;
// C, instead of C++ doesn't cast automatically,
// so we need to use a macro to cast the object to the correct type
#define ORTX_DISPOSE(obj) OrtxDispose((OrtxObject**)&obj)
// Propagates an OrtxStatus failure: evaluates `expr` exactly once and, if the
// resulting status reports !IsOk(), returns that status from the enclosing
// function. Note this is distinct from ORTW_RETURN_IF_ERROR (ort_c_to_cpp.h),
// which tests an OrtStatusPtr against nullptr instead of calling IsOk().
#define ORTX_RETURN_IF_ERROR(expr) \
  do { \
    auto _status = (expr); \
    if (!_status.IsOk()) { \
      return _status; \
    } \
  } while (0)


typedef uint32_t extTokenId_t;

Expand Down
16 changes: 8 additions & 8 deletions operators/tokenizer/bpe_decoder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,26 +49,26 @@ struct KernelBpeDecoder {
}

std::string added_tokens;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "added_tokens", added_tokens));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "added_tokens", added_tokens));
if (!added_tokens.empty()) {
auto um = ParseId2String(added_tokens);
added_tokens_ = std::map<int64_t, std::string>(um.begin(), um.end());
}

std::string all_special_ids;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "all_special_ids", all_special_ids));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "all_special_ids", all_special_ids));
if (!all_special_ids.empty()) {
auto um = ParseId2String(all_special_ids);
std::transform(um.begin(), um.end(), std::inserter(all_special_ids_, all_special_ids_.end()),
[](const auto& p) { return p.first; });
}

ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "en_normalization", en_normalization_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "skip_special_tokens", skip_special_tokens_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "whitespace_token", whitespace_token_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "bos_token", bos_token_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "eos_token", eos_token_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "unk_token", unk_token_));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "en_normalization", en_normalization_));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "skip_special_tokens", skip_special_tokens_));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "whitespace_token", whitespace_token_));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "bos_token", bos_token_));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "eos_token", eos_token_));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "unk_token", unk_token_));

return status;
}
Expand Down
10 changes: 5 additions & 5 deletions operators/tokenizer/bpe_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,24 +124,24 @@ KernelBpeTokenizer::KernelBpeTokenizer(const BpeModelConf& conf)
OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
// note: if the attribute doesn't exist in op node, GetOpAttribute doesn't return a failed status;
std::string vocab;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "vocab", vocab));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "vocab", vocab));
if (vocab.empty()) {
return OrtW::CreateStatus("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
}

std::string merges;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "merges", merges));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "merges", merges));
if (merges.empty()) {
return OrtW::CreateStatus("merges shouldn't be empty.", ORT_INVALID_ARGUMENT);
}

ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "padding_length", padding_length_));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "padding_length", padding_length_));
if (padding_length_ != -1 && padding_length_ <= 0) {
return OrtW::CreateStatus("padding_length should be more than 0 or equal -1", ORT_INVALID_ARGUMENT);
}

std::string model_name;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "model_name", model_name));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "model_name", model_name));
if (!model_name.empty()) {
model_name_ = model_name;
}
Expand All @@ -159,7 +159,7 @@ OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKerne
}

std::string added_token;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "added_token", added_token));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "added_token", added_token));
status = bbpe_tokenizer_->LoadAddedTokens(added_token.c_str());
if (!status.IsOk()) {
return (OrtStatusPtr)status;
Expand Down
2 changes: 1 addition & 1 deletion operators/tokenizer/sentencepiece_decoder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
struct KernelSentencepieceDecoder {
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
std::string model_blob;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "model", model_blob));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "model", model_blob));

sentencepiece::ModelProto model_proto;
model_proto.ParseFromArray(model_blob.data(), static_cast<int>(model_blob.size()));
Expand Down
2 changes: 1 addition & 1 deletion operators/tokenizer/sentencepiece_tokenizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

OrtStatusPtr KernelSentencepieceTokenizer::OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
std::string model_as_string;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "model", model_as_string));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "model", model_as_string));

sentencepiece::ModelProto model_proto;
std::vector<uint8_t> model_as_bytes;
Expand Down
32 changes: 32 additions & 0 deletions operators/tokenizer/tokenizer_jsconfig.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,30 @@

namespace ort_extensions {

// Tokenizer algorithm families recognized by the JSON-config loader:
// kBPE for byte-pair-encoding style tokenizers, kUnigram for
// SentencePiece-Unigram style tokenizers, kUnknown for anything not
// listed in kTokenizerDict.
enum class TokenType {
  kUnknown, kUnigram, kBPE
};

// Maps a tokenizer class name (TokenJsonConfig::tokenizer_class_, read from
// the Hugging Face JSON configuration) to the tokenizer family that loads it.
// An empty class name deliberately falls back to the Unigram loader.
constexpr std::pair<const char*, TokenType> kTokenizerDict[] = {
    // BPE (byte-pair-encoding) family
    {"PreTrainedTokenizerFast", TokenType::kBPE},
    {"CLIPTokenizer", TokenType::kBPE},
    {"WhisperTokenizer", TokenType::kBPE},
    {"GemmaTokenizer", TokenType::kBPE},
    {"LlamaTokenizer", TokenType::kBPE},
    {"Phi3Tokenizer", TokenType::kBPE},
    {"CodeLlamaTokenizer", TokenType::kBPE},
    {"CodeGenTokenizer", TokenType::kBPE},
    {"GPT2Tokenizer", TokenType::kBPE},
    {"Qwen2Tokenizer", TokenType::kBPE},
    {"BaichuanTokenizer", TokenType::kBPE},

    // SentencePiece-Unigram family
    {"", TokenType::kUnigram},
    {"T5Tokenizer", TokenType::kUnigram},
    {"ChatGLMTokenizer", TokenType::kUnigram},
    {"XLMRobertaTokenizer", TokenType::kUnigram}
};


// TokenJsonConfig: Handles loading and parsing of JSON configuration files for tokenizers
class TokenJsonConfig final {
public:
Expand Down Expand Up @@ -230,6 +254,14 @@ class TokenJsonConfig final {
return added_token;
}

static TokenType GetTokenType(const std::string& tok) {
static const std::unordered_map<std::string, TokenType> dict {
std::begin(kTokenizerDict), std::end(kTokenizerDict) };

auto iter = dict.find(tok);
return iter == dict.end() ? TokenType::kUnknown : iter->second;
}

private:
void LoadAddedTokens(const json& tok_json) {
auto added_tokens = tok_json.find("added_tokens");
Expand Down
56 changes: 56 additions & 0 deletions operators/tokenizer/tokenizer_op_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <memory>
#include <optional>
#include <string>
#include <variant>

#include "bpe_kernels.h"
#include "ugm_kernels.hpp"

#include "ext_status.h"
#include "op_def_struct.h"
#include "ort_c_to_cpp.h"
#include "ortx_tokenizer.h"
#include "tokenizer_jsconfig.hpp"

namespace ort_extensions {

// Generic custom-op kernel ("HfJsonTokenizer", see tokenizers.cc) that builds
// a tokenizer from Hugging Face JSON data embedded as node attributes and
// dispatches to either a BPE or a Unigram implementation, chosen from the
// configured tokenizer class name.
class JsonTokenizerOpKernel {
 public:
  // Reads the "tokenizer_config" and "tokenizer_vocab" JSON string attributes,
  // parses them, selects the matching tokenizer implementation, and loads it.
  // Returns nullptr on success, or an error status on failure.
  OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
    std::string config_json;
    ORTW_RETURN_IF_ERROR(OrtW::API::GetOpAttributeString(api, info, "tokenizer_config", config_json));

    std::string vocab_json;
    ORTW_RETURN_IF_ERROR(OrtW::API::GetOpAttributeString(api, info, "tokenizer_vocab", vocab_json));

    TokenJsonConfig cfg;
    // Only the config/vocab blobs are supplied; the token-module and raw-model
    // blob arguments use their empty defaults (see OrtxTokenizerBlob).
    OrtxTokenizerBlob blob({config_json.c_str(), config_json.length()},
                           {vocab_json.c_str(), vocab_json.length()});

    // ORTX_RETURN_IF_ERROR checks OrtxStatus::IsOk(); the returned OrtxStatus
    // presumably converts to OrtStatusPtr (see ext_status.h) — confirm.
    ORTX_RETURN_IF_ERROR(cfg.LoadFromBlob(blob));

    // Pick the implementation from the tokenizer class name in the config.
    auto type = TokenJsonConfig::GetTokenType(cfg.tokenizer_class_);
    if (type == TokenType::kUnigram) {
      tokenizer_ = std::make_unique<SpmUgmTokenizer>();
    } else if (type == TokenType::kBPE) {
      tokenizer_ = std::make_unique<JsonFastTokenizer>();
    } else {
      return OrtxStatus(kOrtxErrorCorruptData, "Unknown tokenizer type");
    }

    // Load whichever alternative is now active in the variant.
    return std::visit([&](auto& ptr) { return ptr->Load(cfg); }, tokenizer_);
  }

  // Tokenizes the input strings with the loaded tokenizer. The optional
  // attention-mask / offset-mapping outputs are forwarded as-is; the Unigram
  // implementation rejects them (see SpmUgmTokenizer::Compute).
  // Assumes OnModelAttach succeeded, so the active unique_ptr is non-null.
  OrtxStatus Compute(const ortc::Tensor<std::string>& input,
                     ortc::Tensor<int64_t>& tokenize_output,
                     std::optional<ortc::Tensor<int64_t>*> attention_mask = std::nullopt,
                     std::optional<ortc::Tensor<int64_t>*> offset_mapping = std::nullopt) const {

    return std::visit([&](auto& ptr) {
      return ptr->Compute(input, tokenize_output, attention_mask, offset_mapping);
    }, tokenizer_);
  }

 private:
  // Holds exactly one of the two tokenizer implementations once loaded.
  std::variant<std::unique_ptr<JsonFastTokenizer>, std::unique_ptr<SpmUgmTokenizer>> tokenizer_;
};

} // namespace ort_extensions
3 changes: 3 additions & 0 deletions operators/tokenizer/tokenizers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "bpe_kernels.h"
#include "bpe_tokenizer_model.hpp"
#include "bpe_decoder.hpp"
#include "tokenizer_op_impl.hpp"
using namespace ort_extensions;
#endif

#ifdef ENABLE_SPM_TOKENIZER
Expand Down Expand Up @@ -40,6 +42,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer = []() -> CustomOpArray& {
CustomCpuStructV2("RobertaTokenizer", RobertaTokenizer),
CustomCpuStructV2("BpeDecoder", KernelBpeDecoder),
CustomCpuStructV2("SpmTokenizer", SpmTokenizer),
CustomCpuStructV2("HfJsonTokenizer", JsonTokenizerOpKernel),
#endif

#ifdef ENABLE_SPM_TOKENIZER
Expand Down
4 changes: 2 additions & 2 deletions operators/tokenizer/trie_tokenizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ struct KernelTrieTokenizer {
public:
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
std::string text_tokens;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "vocab", text_tokens));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "vocab", text_tokens));
tokenizer = std::make_shared<TrieTokenizer>(text_tokens);
return nullptr;
};
Expand Down Expand Up @@ -156,7 +156,7 @@ struct KernelTrieDetokenizer {
public:
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
std::string text_tokens;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "vocab", text_tokens));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "vocab", text_tokens));
tokenizer = std::make_shared<TrieTokenizer>(text_tokens);
return nullptr;
};
Expand Down
8 changes: 7 additions & 1 deletion operators/tokenizer/ugm_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,13 @@ struct SpmUgmTokenizer {
return std::get<0>(iter->second);
}

OrtxStatus Compute(const ortc::Tensor<std::string>& input, ortc::Tensor<int64_t>& tokenize_output) const {
OrtxStatus Compute(const ortc::Tensor<std::string>& input, ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask = std::nullopt,
std::optional<ortc::Tensor<int64_t>*> offset_mapping = std::nullopt) const {
if (attention_mask.has_value() || offset_mapping.has_value()) {
return {kOrtxErrorInvalidArgument, "attention-mask or offset-mapping was supported in unigram tokenizer"};
}

if (input.Shape().size() != 1) {
return OrtxStatus(extError_t::kOrtxErrorInvalidArgument, "Input tensor must have rank 1.");
}
Expand Down
Loading

0 comments on commit 1a21d45

Please sign in to comment.