Skip to content

Commit

Permalink
the new tokenizer API
Browse files Browse the repository at this point in the history
  • Loading branch information
wenbingl committed Jan 24, 2024
1 parent 44e494b commit 386515d
Show file tree
Hide file tree
Showing 25 changed files with 3,079 additions and 1 deletion.
13 changes: 12 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ include(CheckLanguage)

option(CC_OPTIMIZE "Allow compiler optimizations, Set to OFF to disable" ON)
option(OCOS_ENABLE_PYTHON "Enable Python component building, (deprecated)" OFF)
option(OCOS_ENABLE_CTEST "Enable C++ test" OFF)
option(OCOS_ENABLE_CTEST "Enable C++ test" ON)
option(OCOS_ENABLE_CPP_EXCEPTIONS "Enable C++ Exception" ON)
option(OCOS_ENABLE_TF_STRING "Enable String Operator Set" ON)
option(OCOS_ENABLE_RE2_REGEX "Enable StringRegexReplace and StringRegexSplit" ON)
Expand All @@ -64,6 +64,7 @@ option(OCOS_ENABLE_CV2 "Enable the operators in `operators/cv2`" ON)
option(OCOS_ENABLE_VISION "Enable the operators in `operators/vision`" ON)
option(OCOS_ENABLE_AUDIO "Enable the operators for audio processing" ON)
option(OCOS_ENABLE_AZURE "Enable the operators for azure execution provider" OFF)
option(OCOS_ENABLE_TOKENIZER_API "Enable building the tokenizer API" ON)

option(OCOS_ENABLE_STATIC_LIB "Enable generating static library" OFF)
option(OCOS_ENABLE_SELECTED_OPLIST "Enable including the selected_ops tool file" OFF)
Expand Down Expand Up @@ -712,6 +713,16 @@ else()
set(_BUILD_SHARED_LIBRARY TRUE)
endif()

# Optional tokenizer API: adds the tfmtok/ subdirectory and configures the `tfmtok` target.
if (OCOS_ENABLE_TOKENIZER_API)
message(STATUS "Build the tokenizer API")
# Brings in the simdjson external (presumably cmake/externals/simdjson.cmake via the module path).
include(simdjson)
add_subdirectory(tfmtok)
# Public include paths: the tokenizer headers, simdjson's single-header directory,
# and the include directories held in spm_INCLUDE_DIRS.
# NOTE(review): spm_INCLUDE_DIRS is set outside this block; confirm it is defined whenever this option is ON.
target_include_directories(tfmtok PUBLIC
${PROJECT_SOURCE_DIR}/includes/tfmtok
${simdjson_SOURCE_DIR}/singleheader
${spm_INCLUDE_DIRS})
endif()

if(OCOS_ENABLE_AZURE)
if (ANDROID)
# find_package calls were made immediately after `include(curl)` so we know CURL and OpenSSL are available.
Expand Down
8 changes: 8 additions & 0 deletions cmake/externals/simdjson.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Declare simdjson v3.6.3 as an external dependency, pinned to a release archive
# and verified against the SHA1 hash below.
FetchContent_Declare(
simdjson
URL https://github.com/simdjson/simdjson/archive/refs/tags/v3.6.3.zip
URL_HASH SHA1=2b063a2e81f74a5d1cb937fadf3d2fca0f1edb09
)

# Populate the content and make it available to the build (sets simdjson_SOURCE_DIR).
FetchContent_MakeAvailable(simdjson)

158 changes: 158 additions & 0 deletions includes/tfmtok/tfmtok.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "tfmtok_c.h"

#include <cstddef>
#include <memory>
#include <string>
#include <string_view>
#include <type_traits>
#include <vector>

/**
 * @brief Status value returned by the tokenizer API: an error code plus an error message.
 *
 * The state is held behind `rep_`; a null `rep_` means the status is OK (see ok()).
 * All non-inline members are declared here and defined elsewhere.
 */
class TfmStatus final {
public:
TfmStatus();
// Declared out of line because Rep is an incomplete type in this header.
~TfmStatus();
// Builds a status from an error code and its message.
TfmStatus(tfmError_t code, std::string_view error_message);
TfmStatus(const TfmStatus& s);
TfmStatus& operator=(const TfmStatus& s);
bool operator==(const TfmStatus& s) const;
bool operator!=(const TfmStatus& s) const;
// True when no error state is stored, i.e. rep_ is null.
[[nodiscard]] inline bool ok() const { return rep_ == nullptr; }

void SetErrorMessage(const char* str);
[[nodiscard]] const char* error_message() const;
// Alias for error_message().
[[nodiscard]] const char* message() const { return error_message(); }
[[nodiscard]] tfmError_t code() const;
[[nodiscard]] std::string ToString() const;

private:
// Opaque state, defined in the implementation file.
struct Rep;
std::unique_ptr<Rep> rep_;

public:
// Factory for an OK status.
static TfmStatus OK();
};

/**
 * @brief C++ base class for objects handed out through the C API.
 *
 * The object kind is stored in the inherited TfmObject::tfm_kind_ field.
 */
class TfmObjectImpl : public TfmObject {
 public:
  /**
   * @brief Tags the object with the given kind.
   *
   * @param kind The object kind; defaults to kTfmKindUnknown.
   */
  explicit TfmObjectImpl(tfmObjectKind_t kind = tfmObjectKind_t::kTfmKindUnknown) : TfmObject() {
    tfm_kind_ = static_cast<int>(kind);
  }
  virtual ~TfmObjectImpl() = default;

  [[nodiscard]] TfmStatus IsInstanceOf(tfmObjectKind_t kind) const;

  /**
   * @brief Returns the stored kind, validated against the known range.
   *
   * @return The stored kind when it lies in [kTfmKindBegin, kTfmKindEnd),
   *         otherwise kTfmKindUnknown.
   */
  [[nodiscard]] tfmObjectKind_t tfm_kind() const {
    constexpr int lower_bound = static_cast<int>(tfmObjectKind_t::kTfmKindBegin);
    constexpr int upper_bound = static_cast<int>(tfmObjectKind_t::kTfmKindEnd);
    const bool in_range = tfm_kind_ >= lower_bound && tfm_kind_ < upper_bound;
    return in_range ? static_cast<tfmObjectKind_t>(tfm_kind_) : tfmObjectKind_t::kTfmKindUnknown;
  }
};

namespace tfm {

class TokenConfig;

/**
 * @brief Minimal non-owning view over a contiguous sequence of T.
 *
 * The view stores only a pointer and an element count; the viewed storage
 * must outlive the span.
 *
 * @tparam T Element type; use a const-qualified T for a read-only view.
 */
template <typename T>
class span {
 public:
  using value_type = std::remove_cv_t<T>;

  /// Views `s` elements starting at `d`.
  span(T* d, size_t s) : data_(d), size_(s) {}

  /// Views the elements of a mutable vector (intentionally implicit).
  span(std::vector<value_type>& v) : data_(v.data()), size_(v.size()) {}

  /// Views the elements of a const vector; only available for read-only spans (const T).
  template <typename U = T, std::enable_if_t<std::is_const_v<U>, int> = 0>
  span(const std::vector<value_type>& v) : data_(v.data()), size_(v.size()) {}

  /// A span over a temporary vector would dangle immediately, so reject it at compile time.
  span(std::vector<value_type>&&) = delete;

  T* data() const { return data_; }
  [[nodiscard]] size_t size() const { return size_; }
  T* begin() const { return data_; }
  T* end() const { return data_ + size_; }

 private:
  T* data_;
  size_t size_;
};

/**
* @brief The Tokenizer class is responsible for tokenizing and detokenizing text.
*
* It is an abstract base: the constructor is protected and Onload, BatchEncode and
* BatchDecode are pure virtual, so concrete tokenizers derive from it.
*/
class Tokenizer : public TfmObjectImpl {
public:
/**
* @brief Loads the token configuration data and tokenizer directory.
*
* @param token_cfg A unique pointer to the token configuration; passed by value, so
*                  ownership is transferred to this call.
* @param tokenizer_dir The directory path of the tokenizer.
* @return The status of the load operation.
*/
TfmStatus LoadData(std::unique_ptr<TokenConfig> token_cfg, const std::string& tokenizer_dir);

/**
* @brief Tokenizes the input text.
*
* @param input The vector of input strings to be tokenized.
* @param t_ids Output: the vector of token IDs for each input string.
* @return The result of the tokenization operation.
*/
TfmStatus Tokenize(const std::vector<std::string_view>& input, std::vector<std::vector<tfmTokenId_t>>& t_ids) const;

/**
* @brief Detokenizes the token IDs.
*
* @param t_ids The token ID sequences to be detokenized, one read-only span per sequence.
* @param t_text Output: the vector of detokenized text.
* @return The result of the detokenization operation.
*/
TfmStatus Detokenize(const std::vector<span<tfmTokenId_t const>>& t_ids, std::vector<std::string>& t_text) const;

// Hooks that every derived tokenizer class must implement.
protected:
/**
* @brief Default constructor for the Tokenizer class.
*
* Protected, so a Tokenizer can only be constructed through a derived class.
*/
Tokenizer();

/**
* @brief Callback function called during loading.
*
* NOTE(review): presumably invoked from LoadData — confirm in the implementation.
*
* @return The status of the onload operation.
*/
virtual TfmStatus Onload() = 0;

/**
* @brief Batch encodes the input text.
*
* @param input The vector of input strings to be encoded.
* @param t_ids Output: the vector of token IDs for each input string.
* @return The status of the batch encoding operation.
*/
virtual TfmStatus BatchEncode(const std::vector<std::string_view>& input, std::vector<std::vector<tfmTokenId_t>>& t_ids) const = 0;

/**
* @brief Batch decodes the token IDs.
*
* @param t_ids The token ID sequences to be decoded, one read-only span per sequence.
* @param t_text Output: the vector of decoded text.
* @return The status of the batch decoding operation.
*/
virtual TfmStatus BatchDecode(const std::vector<span<tfmTokenId_t const>>& t_ids, std::vector<std::string>& t_text) const = 0;
};

/**
* @brief This function creates a Tokenizer object based on the specified tokenizer path and type.
*
* @param tokenizer_path The path to the tokenizer.
* @param tokenizer_type The type of the tokenizer; if empty (the default), the type will be inferred from the tokenizer path.
* @param status Optional pointer to a TfmStatus object that stores the status of the tokenizer creation;
*               defaults to nullptr.
* @return A unique pointer to a Tokenizer object.
*/
std::unique_ptr<Tokenizer>
CreateTokenizer(const std::string& tokenizer_path,
const std::string& tokenizer_type = "",
TfmStatus* status = nullptr);

} // namespace tfm
187 changes: 187 additions & 0 deletions includes/tfmtok/tfmtok_c.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// C ABI header file for the tfmtok library

#pragma once

#include <stdint.h>
#include <stddef.h>

// TFM_API_CALL:        calling convention applied to every function exported by this header.
// TFM_MUST_USE_RESULT: expands to a "warn if the result is ignored" attribute where enabled,
//                      and to nothing elsewhere. Both macros are defined on every platform.
#if defined(_WIN32) || defined(__CYGWIN__) || defined(__MINGW32__)
// `__stdcall` is the documented spelling; the legacy `_stdcall` synonym is rejected under MSVC /Za.
#define TFM_API_CALL __stdcall
#define TFM_MUST_USE_RESULT
#elif defined(__APPLE__)
#define TFM_API_CALL
// Ask the compiler to warn when a caller ignores the returned value.
#define TFM_MUST_USE_RESULT __attribute__((warn_unused_result))
#else
#define TFM_API_CALL
#define TFM_MUST_USE_RESULT
#endif

// Error codes returned by the C API functions (and carried by TfmStatus in tfmtok.h).
// kTfmOK (0) is the only non-error value. The numeric values are explicit because
// they are part of the C ABI.
typedef enum {
kTfmOK = 0,
kTfmErrorInvalidArgument = 1,
kTfmErrorOutOfMemory = 2,
kTfmErrorInvalidFile = 3,
kTfmErrorNotFound = 4,
kTfmErrorAlreadyExists = 5,
kTfmErrorOutOfRange = 6,
kTfmErrorUnimplemented = 7,
kTfmErrorInternal = 8,
kTfmErrorUnknown = 1000
} tfmError_t;

// Object kind tags stored in TfmObject::tfm_kind_.
// TfmObjectImpl::tfm_kind() in tfmtok.h treats values in [kTfmKindBegin, kTfmKindEnd)
// as valid and maps everything else to kTfmKindUnknown.
typedef enum {
kTfmKindUnknown = 0,

kTfmKindBegin = 0x7788, // starting from a number to help validate the object
kTfmKindTokenizer = kTfmKindBegin,
kTfmKindStringArray = 0x7789,
kTfmKindTokenId2DArray = 0x778A,
kTfmKindDetokenizerCache = 0x778B,
kTfmKindEnd = 0x7999 // exclusive upper bound of the valid range
} tfmObjectKind_t;


// All objects managed by the library are 'derived' from this struct and are
// eventually released with TfmDispose (C++) or TFM_DISPOSE (C).
typedef struct {
int tfm_kind_; // object kind tag, see tfmObjectKind_t
}TfmObject;

// Version number of the C API declared in this header.
// `static` gives the constant internal linkage in C as well as in C++; without it, a C
// program that includes this header from several translation units would define the
// same external symbol more than once.
static const int API_VERSION = 1;

// Aliases of TfmObject, one per object kind. They avoid a flood of per-type
// create/dispose functions and make the API more C++ friendly with less casting.
typedef TfmObject TfmTokenizer;
typedef TfmObject TfmStringArray;
typedef TfmObject TfmTokenId2DArray;
typedef TfmObject TfmDetokenizerCache;

// C, instead of C++ doesn't cast automatically,
// so we need to use a macro to cast the object to the correct type
#define TFM_DISPOSE(obj) TfmDispose((TfmObject**)&obj)

// Integer type used for token IDs throughout the API (32-bit unsigned).
typedef uint32_t tfmTokenId_t;

#ifdef __cplusplus
extern "C" {
#endif

/** \brief Get the current C ABI version of this library
 *
 * Declared with `(void)` so that C compilers see a real prototype, not an
 * unspecified parameter list.
 *
 * \return The API version number
 */
int TFM_API_CALL TfmGetAPIVersion(void);

/** \brief Get the last error message generated by the library
 *
 * Takes no arguments; declared with `(void)` so that C compilers see a real prototype.
 * NOTE(review): the lifetime and ownership of the returned string are not visible in
 * this header; confirm against the implementation before caching the pointer.
 *
 * \return Pointer to the last error message
 */
const char* TFM_API_CALL TfmGetLastErrorMessage(void);

/** \brief Create a new object of the specified kind
*
* \param kind The kind of object to create (a tfmObjectKind_t value)
* \param object Address of a TfmObject* that receives the created object
* \param ... Additional arguments based on the kind of object
* \return Error code indicating the success or failure of the operation
*/
tfmError_t TFM_API_CALL TfmCreate(tfmObjectKind_t kind, TfmObject** object, ...);

/** \brief Dispose the specified object
*
* C callers can use the TFM_DISPOSE macro, which adds the TfmObject** cast.
*
* \param object Address of the pointer to the object to dispose
* \return Error code indicating the success or failure of the operation
*/
tfmError_t TFM_API_CALL TfmDispose(TfmObject** object);

/** \brief Create a tokenizer object with the specified tokenizer path
*
* \param tokenizer Address of a TfmTokenizer* that receives the created tokenizer object
* \param tokenizer_path The path to the tokenizer
* \return Error code indicating the success or failure of the operation
*/
tfmError_t TFM_API_CALL TfmCreateTokenizer(TfmTokenizer** tokenizer, const char* tokenizer_path);

/** \brief Tokenize the input using the specified tokenizer
*
* \param tokenizer Pointer to the tokenizer object
* \param input Array of input strings (batch_size entries)
* \param batch_size Number of input strings in the batch
* \param output Address of a TfmTokenId2DArray* that receives the tokenized result
* \return Error code indicating the success or failure of the operation
*/
tfmError_t TFM_API_CALL TfmTokenize(const TfmTokenizer* tokenizer, const char* input[], size_t batch_size, TfmTokenId2DArray** output);

/** \brief Detokenize the input using the specified tokenizer
*
* \param tokenizer Pointer to the tokenizer object
* \param input Pointer to the input token IDs, as a TfmTokenId2DArray
* \param output Address of a TfmStringArray* that receives the detokenized result
* \return Error code indicating the success or failure of the operation
*/
tfmError_t TFM_API_CALL TfmDetokenize(const TfmTokenizer* tokenizer, const TfmTokenId2DArray* input, TfmStringArray** output);

/** \brief Detokenize the input using the specified tokenizer (1D version)
*
* \param tokenizer Pointer to the tokenizer object
* \param input Pointer to the first element of the input token ID array
* \param len Number of token IDs in the input array
* \param output Address of a TfmStringArray* that receives the detokenized result
* \return Error code indicating the success or failure of the operation
*/
tfmError_t TFM_API_CALL TfmDetokenize1D(const TfmTokenizer* tokenizer, const tfmTokenId_t* input, size_t len, TfmStringArray** output);

/** \brief Detokenize the input using the specified tokenizer with caching
*
* \param tokenizer Pointer to the tokenizer object
* \param cache Pointer to the detokenizer cache (non-const, so the call may update it)
* \param next_id Next token ID to detokenize
* \param text_out Address of a const char* that receives the detokenized text
*                 (NOTE(review): lifetime of the returned text is not visible here; confirm)
* \return Error code indicating the success or failure of the operation
*/
tfmError_t TFM_API_CALL TfmDetokenizeCached(const TfmTokenizer* tokenizer, TfmDetokenizerCache* cache, tfmTokenId_t next_id, const char** text_out);

/** \brief Get the length of the string array
*
* \param string_array Pointer to the string array
* \param length Out-parameter that receives the length of the string array
* \return Error code indicating the success or failure of the operation
*/
tfmError_t TFM_API_CALL TfmStringArrayGetBatch(const TfmStringArray* string_array, size_t* length);

/** \brief Get the item at the specified index from the string array
*
* \param string_array Pointer to the string array
* \param index Index of the item to retrieve
* \param item Address of a const char* that receives the retrieved item
* \return Error code indicating the success or failure of the operation
*/
tfmError_t TFM_API_CALL TfmStringArrayGetItem(const TfmStringArray* string_array, size_t index, const char** item);

/** \brief Get the batch size of the token ID 2D array
*
* \param token_id_2d_array Pointer to the token ID 2D array
* \param length Out-parameter that receives the batch size
* \return Error code indicating the success or failure of the operation
*/
tfmError_t TFM_API_CALL TfmTokenId2DArrayGetBatch(const TfmTokenId2DArray* token_id_2d_array, size_t* length);

/** \brief Get the item at the specified index from the token ID 2D array
*
* \param token_id_2d_array Pointer to the token ID 2D array
* \param index Index of the item to retrieve
* \param item Address of a const tfmTokenId_t* that receives the retrieved item
* \param length Out-parameter that receives the length of the item
* \return Error code indicating the success or failure of the operation
*/
tfmError_t TFM_API_CALL TfmTokenId2DArrayGetItem(const TfmTokenId2DArray* token_id_2d_array, size_t index, const tfmTokenId_t** item, size_t* length);

#ifdef __cplusplus
}
#endif
Loading

0 comments on commit 386515d

Please sign in to comment.