Add more tests for pre-processing C APIs (#793)
* initial api for tokenizer

* More fixes and test data refinement

* add a simple wrapper for pre-processing APIs

* fix the test issues

* test if the tokenizer is spm based

* fix the failing test cases

* json pointer does not work
wenbingl authored Aug 21, 2024
1 parent 85ffb94 commit 8f2c35f
Showing 15 changed files with 278 additions and 221 deletions.
4 changes: 4 additions & 0 deletions .pyproject/cmdclass.py
@@ -147,6 +147,7 @@ def initialize_options(self):
        self.no_azure = None
        self.no_opencv = None
        self.cc_debug = None
        self.pp_api = None
        self.cuda_archs = None
        self.ort_pkg_dir = None

@@ -210,6 +211,9 @@ def build_cmake(self, extension):
                       '-DOCOS_ENABLE_CV2=OFF',
                       '-DOCOS_ENABLE_VISION=OFF']

        if self.pp_api:
            cmake_args += ['-DOCOS_ENABLE_C_API=ON']

        if self.no_azure is not None:
            azure_flag = "OFF" if self.no_azure == 1 else "ON"
            cmake_args += ['-DOCOS_ENABLE_AZURE=' + azure_flag]
30 changes: 17 additions & 13 deletions base/ustring.h
@@ -11,11 +11,11 @@ class ustring : public std::u32string {
public:
ustring() = default;

explicit ustring(const char* str) { assign(FromUTF8(str)); }
explicit ustring(const char* str) { assign(std::move(FromUTF8(str))); }

explicit ustring(const std::string& str) { assign(FromUTF8(str)); }
explicit ustring(const std::string& str) { assign(std::move(FromUTF8(str))); }

explicit ustring(const std::string_view& str) { assign(FromUTF8(str)); }
explicit ustring(const std::string_view& str) { assign(std::move(FromUTF8(str))); }

explicit ustring(const char32_t* str) : std::u32string(str) {}

@@ -76,11 +76,15 @@ class ustring : public std::u32string {
}
}

static bool ValidateUTF8(const std::string& data) {
// return a negative value for the first invalid utf8 char position,
// otherwise the position of the terminating null character, which is the end of the string.
static ptrdiff_t ValidateUTF8(const std::string& data) {
const unsigned char* s = reinterpret_cast<const unsigned char*>(data.c_str());
const unsigned char* s_begin = s;
const unsigned char* s_end = s + data.size();

if (*s_end != '\0')
return false;
return 0;

while (*s) {
if (*s < 0x80)
@@ -89,45 +93,45 @@
else if ((s[0] & 0xe0) == 0xc0) {
/* 110XXXXx 10xxxxxx */
if (s + 1 >= s_end) {
return false;
return s_begin - s;
}
if ((s[1] & 0xc0) != 0x80 ||
(s[0] & 0xfe) == 0xc0) /* overlong? */
return false;
return s_begin - s;
else
s += 2;
} else if ((s[0] & 0xf0) == 0xe0) {
/* 1110XXXX 10Xxxxxx 10xxxxxx */
if (s + 2 >= s_end) {
return false;
return s_begin - s;
}
if ((s[1] & 0xc0) != 0x80 ||
(s[2] & 0xc0) != 0x80 ||
(s[0] == 0xe0 && (s[1] & 0xe0) == 0x80) || /* overlong? */
(s[0] == 0xed && (s[1] & 0xe0) == 0xa0) || /* surrogate? */
(s[0] == 0xef && s[1] == 0xbf &&
(s[2] & 0xfe) == 0xbe)) /* U+FFFE or U+FFFF? */
return false;
return s_begin - s;
else
s += 3;
} else if ((s[0] & 0xf8) == 0xf0) {
/* 11110XXX 10XXxxxx 10xxxxxx 10xxxxxx */
if (s + 3 >= s_end) {
return false;
return s_begin - s;
}
if ((s[1] & 0xc0) != 0x80 ||
(s[2] & 0xc0) != 0x80 ||
(s[3] & 0xc0) != 0x80 ||
(s[0] == 0xf0 && (s[1] & 0xf0) == 0x80) || /* overlong? */
(s[0] == 0xf4 && s[1] > 0x8f) || s[0] > 0xf4) /* > U+10FFFF? */
return false;
return s_begin - s;
else
s += 4;
} else
return false;
return s_begin - s;
}

return true;
return s - s_begin;
}

private:
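The rewritten `ValidateUTF8` changes its contract from a boolean to a position: it returns the byte length of the string on success, and a non-positive value whose magnitude is the offset of the first invalid byte on failure. A rough Python model of that contract (illustrative only, not part of the commit; Python's decoder is slightly looser than the C++ validator, which also rejects the noncharacters U+FFFE and U+FFFF):

```Python
def validate_utf8(data: bytes) -> int:
    # Mirror ustring::ValidateUTF8's contract: the byte length on success,
    # or the negated offset of the first invalid byte on failure.
    try:
        data.decode("utf-8")
        return len(data)
    except UnicodeDecodeError as e:
        return -e.start

assert validate_utf8("abc".encode()) == 3
assert validate_utf8(b"ab\xc0\xaf") == -2  # overlong encoding rejected at offset 2
```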
11 changes: 10 additions & 1 deletion docs/c_api.md
@@ -18,4 +18,13 @@ Most APIs accept raw data inputs such as audio, image compressed binary formats,

**Audio feature extraction:** `OrtxCreateSpeechFeatureExtractor` creates a speech feature extractor to obtain log mel spectrum data as input for the Whisper model. An example code snippet can be found [here](../test/pp_api_test/test_feature_extraction.cc#L16).

NB: If onnxruntime-extensions is to build as a shared library, which requires the OCOS_ENABLE_AUDIO OCOS_ENABLE_CV2 OCOS_ENABLE_OPENCV_CODECS OCOS_ENABLE_GPT2_TOKENIZER build flags are ON to have a full function of binary. Only onnxruntime-extensions static library can be used for a minimal build with the selected operators, so in that case, the shared library build can be switched off by `-DOCOS_BUILD_SHARED_LIB=OFF`.
**NB:** To build onnxruntime-extensions as a shared library with full functionality, the OCOS_ENABLE_AUDIO, OCOS_ENABLE_CV2, OCOS_ENABLE_OPENCV_CODECS, and OCOS_ENABLE_GPT2_TOKENIZER build flags must all be ON. Only the onnxruntime-extensions static library supports a minimal build with selected operators; in that case, the shared-library build can be switched off with `-DOCOS_BUILD_SHARED_LIB=OFF`.

There is a simple Python wrapper over these C APIs in [pp_api](../onnxruntime_extensions/pp_api.py), which gives easy access to them from Python code, for example:

```Python
from onnxruntime_extensions.pp_api import Tokenizer
# the name can be the same one used by Huggingface transformers.AutoTokenizer
pp_tok = Tokenizer('google/gemma-2-2b')
print(pp_tok.tokenize("what are you? \n 给 weiss ich, über was los ist \n"))
```
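A tokenize/detokenize round-trip through the same wrapper (a usage sketch building on the example above; the model name is the same illustrative one):

```Python
from onnxruntime_extensions.pp_api import Tokenizer

pp_tok = Tokenizer('google/gemma-2-2b')
ids = pp_tok.tokenize("what are you?")
# detokenize is the inverse mapping over the token ids returned by tokenize
print(pp_tok.detokenize(ids))
```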
9 changes: 0 additions & 9 deletions docs/custom_ops.md
@@ -531,15 +531,6 @@ expect(node, inputs=[inputs],
</details>


### BlingFireSentenceBreaker

TODO

### BpeTokenizer

TODO


## String operators

### StringEqual
1 change: 1 addition & 0 deletions docs/development.md
@@ -16,6 +16,7 @@ The package contains all custom operators and some Python scripts to manipulate
- no-azure: disable AzureOp kernel build in Python package.
- no-opencv: disable operators based on OpenCV in build.
- cc-debug: generate debug info for extensions binaries and disable C/C++ compiler optimization.
- pp_api: enable the pre-processing C ABI Python wrapper, `from onnxruntime_extensions.pp_api import *`
- cuda-archs: specify the CUDA architectures (like 70, 85, etc.); multiple values can be combined with semicolons. The default is the nvidia-smi output for GPU-0.
- ort\_pkg\_dir: specify the ONNXRuntime package directory the extension project depends on. This is helpful if you want to use a recent ONNXRuntime feature that has not yet been included in an official build.

62 changes: 60 additions & 2 deletions onnxruntime_extensions/pp_api.py
@@ -3,11 +3,69 @@
# license information.
###############################################################################

import os
from . import _extensions_pydll as _C
if not hasattr(_C, "create_processor"):
    raise ImportError("onnxruntime_extensions is not built with pre-processing API")
if not hasattr(_C, "delete_object"):
    raise ImportError(
        "onnxruntime_extensions is not built with pre-processing C API. "
        "To enable it, please build the package with --ortx-user-option=pp_api")

create_processor = _C.create_processor
load_images = _C.load_images
image_pre_process = _C.image_pre_process
tensor_result_get_at = _C.tensor_result_get_at

create_tokenizer = _C.create_tokenizer
batch_tokenize = _C.batch_tokenize
batch_detokenize = _C.batch_detokenize

delete_object = _C.delete_object


class Tokenizer:
    def __init__(self, tokenizer_dir):
        if os.path.isdir(tokenizer_dir):
            self.tokenizer = create_tokenizer(tokenizer_dir)
        else:
            try:
                from transformers.utils import cached_file
                resolved_full_file = cached_file(
                    tokenizer_dir, "tokenizer.json")
                resolved_config_file = cached_file(
                    tokenizer_dir, "tokenizer_config.json")
            except ImportError:
                raise ValueError(
                    f"Directory '{tokenizer_dir}' not found and transformers is not available")
            if not os.path.exists(resolved_full_file):
                raise FileNotFoundError(
                    f"Downloaded HF file '{resolved_full_file}' cannot be found")
            if (os.path.dirname(resolved_full_file) != os.path.dirname(resolved_config_file)):
                raise FileNotFoundError(
                    f"Downloaded HF files '{resolved_full_file}' and '{resolved_config_file}' are not in the same directory")

            tokenizer_dir = os.path.dirname(resolved_full_file)
            self.tokenizer = create_tokenizer(tokenizer_dir)

    def tokenize(self, text):
        return batch_tokenize(self.tokenizer, [text])[0]

    def detokenize(self, tokens):
        return batch_detokenize(self.tokenizer, [tokens])[0]

    def __del__(self):
        if delete_object and self.tokenizer:
            delete_object(self.tokenizer)
        self.tokenizer = None


class ImageProcessor:
    def __init__(self, processor_json):
        self.processor = create_processor(processor_json)

    def pre_process(self, images):
        return image_pre_process(self.processor, images)

    def __del__(self):
        if delete_object and self.processor:
            delete_object(self.processor)
        self.processor = None
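A combined usage sketch of the two wrappers and the module-level bindings above (the processor JSON and image files are hypothetical placeholders, and it is assumed here that `pre_process` returns a result handle whose tensors are fetched by index via `tensor_result_get_at`):

```Python
from onnxruntime_extensions.pp_api import (
    ImageProcessor, Tokenizer, load_images, tensor_result_get_at)

tok = Tokenizer('google/gemma-2-2b')  # a Hugging Face model id or a local directory
ids = tok.tokenize("hello world")
print(tok.detokenize(ids))

# processor.json and the image paths are placeholders for illustration
processor = ImageProcessor("processor.json")
images = load_images(["sample0.png", "sample1.png"])
result = processor.pre_process(images)
pixel_values = tensor_result_get_at(result, 0)  # fetch the first output tensor
```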
34 changes: 24 additions & 10 deletions operators/tokenizer/bpe_kernels.cc
@@ -27,11 +27,6 @@ static bool IsBosEosRequired(const std::string& model_name) {
return model_name != kModel_GPT2 && model_name != kModel_CodeGen;
}

static bool IsSpmModel(const std::string& model_name) {
return model_name == kModel_Llama ||
model_name == kModel_Gemma;
}

std::string BpeModelConf::GetSpecialTokens() const {
std::string special_tokens = unk_token_; // unk_token_ is required
auto add_token = [](std::string& sp, const char* tok) {
@@ -145,7 +140,7 @@ OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKerne
merges_stream,
bpe_conf_.get().unk_token_,
bpe_conf_.get().GetSpecialTokens().c_str(),
IsSpmModel(ModelName()));
bpe_conf_.get().spm_model_);
if (!status.IsOk()) {
return (OrtStatusPtr)status;
}
@@ -454,7 +449,7 @@ OrtxStatus KernelBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
}

auto tok_fun = &KernelBpeTokenizer::Tokenize;
if (IsSpmModel(ModelName())) {
if (bpe_conf_.get().spm_model_) {
tok_fun = &KernelBpeTokenizer::SpmTokenize;
}

@@ -556,7 +551,8 @@ static const auto kSpmConfiguration = BpeModelConf{
"<unk>", // unk_token
"<s>", // bos_token
"</s>", // eos_token
""}; // pad_token
"", // pad_token
true};

SpmTokenizer::SpmTokenizer()
: KernelBpeTokenizer(kSpmConfiguration) {}
@@ -718,15 +714,33 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& c
module_ifs >> tok_json;
} else {
ifs >> tok_json;
// auto decoders_node = tok_json.find("/decoder/decoders"_json_pointer);
auto decoders_node = tok_json.find("decoder");
if (decoders_node != tok_json.end()) {
decoders_node = decoders_node->find("decoders");
}

if (decoders_node->is_array()) {
for(auto step = decoders_node->begin(); step != decoders_node->end(); ++step) {
std::string type = step->value("type", "");
if (type == "Replace") {
std::string target = step->value("/pattern/String"_json_pointer, "");
if (target == "\xe2\x96\x81") {
json_conf_.spm_model_ = true;
break;
}
}
}
}
auto model_node = tok_json.find("model");
if (model_node == tok_json.end()) {
return OrtxStatus(kOrtxErrorCorruptData, "Failed to get model node from tokenizer.json");
}

bbpe_tokenizer_ = std::make_unique<BpeModel>();
status = bbpe_tokenizer_->Load(*model_node,
bpe_conf_.get().GetSpecialTokens().c_str(),
IsSpmModel(ModelName()));
bpe_conf_.get().GetSpecialTokens().c_str(),
bpe_conf_.get().spm_model_);
}


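The `JsonFastTokenizer::Load` hunk above replaces the model-name check with a structural heuristic: a tokenizer is treated as SentencePiece-based when the decoder chain in its tokenizer.json contains a `Replace` step whose pattern is the metaspace character `"\xe2\x96\x81"` (U+2581, `▁`). A standalone Python sketch of the same check (not part of the commit; the file path is illustrative):

```Python
import json

def is_spm_tokenizer(tokenizer_json_path: str) -> bool:
    # Treat the tokenizer as SentencePiece-based when its decoder chain
    # contains a Replace step whose pattern is the metaspace U+2581.
    with open(tokenizer_json_path, encoding="utf-8") as f:
        tok = json.load(f)
    decoders = (tok.get("decoder") or {}).get("decoders") or []
    for step in decoders:
        if step.get("type") == "Replace":
            pattern = (step.get("pattern") or {}).get("String", "")
            if pattern == "\u2581":  # the same bytes as "\xe2\x96\x81" in UTF-8
                return True
    return False
```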
27 changes: 5 additions & 22 deletions operators/tokenizer/bpe_kernels.h
@@ -20,6 +20,7 @@ struct BpeModelConf {
const char* eos_token_{"<|endoftext|>"};
const char* pad_token_{nullptr};

bool spm_model_{};
std::string GetSpecialTokens() const;
};

@@ -108,10 +109,6 @@ struct SpmTokenizer : KernelBpeTokenizer {
class JsonFastTokenizer : public KernelBpeTokenizer {
public:
JsonFastTokenizer();
bool tiktoken_ = false;
std::string unicode_byte_encoder_[256] = {};
void CreateUnicodeByteEncoder();
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
@@ -121,28 +118,14 @@
public:
const auto& GetAddedTokens() const { return added_tokens_; }
const ort_extensions::BpeModel& GetEncoder() const { return *bbpe_tokenizer_; }
bool IsSpmModel() const { return json_conf_.spm_model_; }
bool tiktoken_ = false;

private:
BpeModelConf json_conf_;
std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
};

class TikTokenizer : KernelBpeTokenizer {
public:
TikTokenizer();
void CreateUnicodeByteEncoder();
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;

public:
const auto& GetAddedTokens() const { return added_tokens_; }
const ort_extensions::BpeModel& GetEncoder() const { return *bbpe_tokenizer_; }

private:
std::unique_ptr<ort_extensions::BpeModel> bbpe_tokenizer_;
BpeModelConf json_conf_;
std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
std::string unicode_byte_encoder_[256] = {};
};