diff --git a/CMakeLists.txt b/CMakeLists.txt
index f67e31e49..09e70a8da 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -797,7 +797,7 @@ if(_BUILD_SHARED_LIBRARY)
   standardize_output_folder(extensions_shared)

   if(LINUX OR ANDROID)
-    set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
+    # set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
     # strip if not a debug build
     if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
       set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,-s")
diff --git a/build.sh b/build.sh
index 3b7379c4c..4522e66a5 100755
--- a/build.sh
+++ b/build.sh
@@ -27,4 +27,4 @@ if [ -n "$cuda_arch" ]; then
   param="$@ -DCMAKE_CUDA_ARCHITECTURE=$cuda_arch ../../.."
 fi
 # it looks the parallel build on CI pipeline machine causes crashes.
-cmake $param && cmake --build . --config $BUILD_FLAVOR --parallel "${CPU_NUMBER}"
+cmake "$@" ../../.. "-DOCOS_USE_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" && cmake --build . --config $BUILD_FLAVOR --parallel "${CPU_NUMBER}"
diff --git a/cmake/ext_tests.cmake b/cmake/ext_tests.cmake
index 95d257dcc..cf6a36ad8 100644
--- a/cmake/ext_tests.cmake
+++ b/cmake/ext_tests.cmake
@@ -165,7 +165,9 @@ else()
     LIBRARIES ${extensions_test_libraries}
     TEST_DATA_DIRECTORIES ${TEST_SRC_DIR}/data)

-  target_include_directories(extensions_test PRIVATE ${spm_INCLUDE_DIRS})
+  target_include_directories(extensions_test PRIVATE ${spm_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
+
+  target_link_libraries(extensions_test PRIVATE ocos_operators)

   target_compile_definitions(extensions_test PUBLIC ${OCOS_COMPILE_DEFINITIONS})
   if(use_extensions_shared_library)
diff --git a/includes/custom_op_lite.h b/includes/custom_op_lite.h
index 784e2b2bd..b0143af38 100644
--- a/includes/custom_op_lite.h
+++ b/includes/custom_op_lite.h
@@ -2,202 +2,107 @@
 // Licensed under the MIT License.
#pragma once -#include "onnxruntime_customop.hpp" -#include "onnxruntime_f16.h" + #include #include +#include "tensor_api.h" +#include "onnxruntime_cpp_api_legacy.hpp" namespace Ort { namespace Custom { -class TensorBase { - public: - TensorBase(const OrtW::CustomOpApi& api, - OrtKernelContext& ctx, - size_t indice, - bool is_input) : api_(api), - ctx_(ctx), - indice_(indice), - is_input_(is_input) {} - - virtual ~TensorBase() = default; - operator bool() const { - return shape_.has_value(); - } - const std::vector& Shape() const { - if (shape_.has_value()) { - return *shape_; - } else { - ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION); - } - } - ONNXTensorElementDataType Type() const { - return type_; - } - int64_t NumberOfElement() const { - if (shape_.has_value()) { - return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies()); - } else { - ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION); - } - } - std::string Shape2Str() const { - if (shape_.has_value()) { - std::string shape_str; - for (const auto& dim : *shape_) { - shape_str.append(std::to_string(dim)); - shape_str.append(", "); +class OrtKernelArg { +public: + OrtKernelArg(const OrtW::CustomOpApi& api, + OrtKernelContext& ctx, + size_t indice, + bool is_input) : api_(api), ctx_(ctx), indice_(indice) { + if (is_input) { + const OrtValue* const_value = api.KernelContext_GetInput(&ctx, indice); + const OrtMemoryInfo* mem_info = {}; + api.ThrowOnError(api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info)); + if (mem_info) { + api.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_)); } - return shape_str; - } else { - return "empty"; } } + bool IsCpuTensor() const { return strcmp("Cpu", mem_type_) == 0; } - virtual const void* DataRaw() const = 0; - virtual size_t SizeInBytes() const = 0; - protected: +protected: const OrtW::CustomOpApi& api_; OrtKernelContext& ctx_; size_t indice_; - bool is_input_; - std::optional> shape_; - ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; const char* mem_type_ = "Cpu"; }; -template -struct Span { - const T* data_ = {}; - size_t size_ = {}; - void Assign(const T* data, size_t size) { - data_ = data; - size_ = size; - } - size_t size() const { return size_; } - T operator[](size_t indice) const { - return data_[indice]; - } - const T* data() const { return data_; } -}; - -#if ORT_API_VERSION >= 16 - -template <> -struct Span { - const MFloat16* data_ = {}; - size_t size_ = {}; - void Assign(const MFloat16* data, size_t size) { - data_ = data; - size_ = size; - } - size_t size() const { return size_; } - MFloat16 operator[](size_t indice) const { - return data_[indice]; - } - const MFloat16* data() const { return data_; } -}; - -template <> -struct Span { - const BFloat16* data_ = {}; - size_t size_ = {}; - void Assign(const BFloat16* data, size_t size) { - data_ = data; - size_ = size; - } - size_t size() const { return size_; } - BFloat16 operator[](size_t indice) const { - return data_[indice]; - } - const BFloat16* data() const { return data_; } -}; - -#endif - -template -class Tensor : public TensorBase { - public: - using TT = typename std::remove_reference::type; - Tensor(const OrtW::CustomOpApi& api, - OrtKernelContext& ctx, - size_t indice, - bool is_input) : TensorBase(api, - ctx, - indice, - is_input) { - if (is_input) { - auto input_count = api_.KernelContext_GetInputCount(&ctx_); +class OrtKernelContextStorage : public ITensorStorage { +public: + 
OrtKernelContextStorage(const OrtW::CustomOpApi& api, + OrtKernelContext& ctx, + size_t indice, + bool is_input) : api_(api), ctx_(ctx), indice_(indice) { + if (is_input){ + auto input_count = api.KernelContext_GetInputCount(&ctx); if (indice >= input_count) { ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION); } - const_value_ = api_.KernelContext_GetInput(&ctx_, indice); - auto* info = api_.GetTensorTypeAndShape(const_value_); - shape_ = api_.GetTensorShape(info); - type_ = api_.GetTensorElementType(info); - api_.ReleaseTensorTypeAndShapeInfo(info); - const OrtMemoryInfo* mem_info = {}; - api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info)); - if (mem_info) { - api_.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_)); - } + const_value_ = api.KernelContext_GetInput(&ctx, indice); + auto* info = api.GetTensorTypeAndShape(const_value_); + shape_ = api.GetTensorShape(info); + api.ReleaseTensorTypeAndShapeInfo(info); } } - const TT* Data() const { - return api_.GetTensorData(const_value_); - } - const void* DataRaw() const override { - return reinterpret_cast(Data()); + const std::vector& Shape() const override { + if (!IsInitialized()) + ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION); + return *shape_; } - size_t SizeInBytes() const override { - return NumberOfElement() * sizeof(TT); + virtual bool IsInitialized() const override { + return shape_.has_value(); } - TT* Allocate(const std::vector& shape) { - if (!data_) { - OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size()); - shape_ = shape; - data_ = api_.GetTensorMutableData(out); - } - return data_; - } - const Span& AsSpan() { - if (!shape_.has_value() || shape_->size() != 1) { - ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); - } - span_.Assign(Data(), (*shape_)[0]); - return span_; + const void* DataRaw() const override { + return api_.GetTensorRawData(const_value_); } - const T& AsScalar() { - if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) { - ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); + + void* Initialize(const std::vector& shape, size_t element_size) override { + if (!const_value_) { + const_value_ = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size()); } - return *Data(); + return api_.GetTensorMutableRawData(const_cast(const_value_)); } - private: +private: + const OrtW::CustomOpApi& api_; + OrtKernelContext& ctx_; + size_t indice_; const OrtValue* const_value_{}; // for input - TT* data_{}; // for output - Span span_; + std::optional> shape_; }; -template <> -class Tensor : public TensorBase { - public: - using strings = std::vector; +template +class OrtTensor : public OrtKernelArg, public Tensor { +public: + OrtTensor(const OrtW::CustomOpApi& api, + OrtKernelContext& ctx, + size_t indice, + bool is_input) : OrtKernelArg(api, ctx, indice, is_input), + Tensor(std::make_unique(api, ctx, indice, is_input)){ + } +}; - Tensor(const OrtW::CustomOpApi& api, - OrtKernelContext& ctx, - size_t indice, - bool is_input) : TensorBase(api, - ctx, - indice, - is_input) { +class OrtStringTensorStorage : public IStringTensorStorage{ +public: + using strings = std::vector; + OrtStringTensorStorage(const OrtW::CustomOpApi& api, + OrtKernelContext& ctx, + size_t indice, + bool is_input) : api_(api), ctx_(ctx), indice_(indice){ if (is_input) { auto 
input_count = api_.KernelContext_GetInputCount(&ctx_); if (indice >= input_count) { @@ -207,14 +112,14 @@ class Tensor : public TensorBase { auto* const_value = api_.KernelContext_GetInput(&ctx_, indice); auto* info = api_.GetTensorTypeAndShape(const_value); shape_ = api_.GetTensorShape(info); - type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; api_.ReleaseTensorTypeAndShapeInfo(info); size_t num_chars; OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars)); std::vector chars(num_chars + 1, '\0'); - auto num_strings = NumberOfElement(); - std::vector offsets(NumberOfElement()); + assert((*shape_).size() == 1); + auto num_strings = (*shape_)[0]; + std::vector offsets((*shape_)[0]); OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value, (void*)chars.data(), num_chars, @@ -230,22 +135,25 @@ class Tensor : public TensorBase { } } } - const strings& Data() const { - return input_strings_; + + const std::vector& Shape() const override { + if (!IsInitialized()) + ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION); + return *shape_; } - const void* DataRaw() const override { + + virtual const void* DataRaw() const override { if (input_strings_.size() != 1) { ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); } return reinterpret_cast(input_strings_[0].c_str()); } - size_t SizeInBytes() const override { - if (input_strings_.size() != 1) { - ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); - } - return input_strings_[0].size(); + + virtual bool IsInitialized() const override { + return shape_.has_value(); } - void SetStringOutput(const strings& ss, const std::vector& dims) { + + virtual void SetStringOutput(const strings& ss, const std::vector& dims) override { std::vector raw; for (const auto& s : ss) { raw.push_back(s.data()); @@ -253,38 +161,33 @@ class Tensor : public TensorBase { auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size()); OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, raw.data(), raw.size())); } - void SetStringOutput(const std::vector& ss, const std::vector& dims) { + + virtual void SetStringOutput(const std::vector& ss, const std::vector& dims) override { auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size()); OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, ss.data(), ss.size())); } - const Span& AsSpan() { - ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION); - } - const std::string& AsScalar() { - if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) { - ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); - } - return input_strings_[0]; + + const strings& Data() const override { + return input_strings_; } - private: - std::vector input_strings_; // for input +private: + const OrtW::CustomOpApi& api_; + OrtKernelContext& ctx_; + size_t indice_; + std::vector input_strings_; + std::optional> shape_; }; -template <> -class Tensor : public TensorBase { - public: - using strings = std::vector; - using string_views = std::vector; - Tensor(const OrtW::CustomOpApi& api, - OrtKernelContext& ctx, - size_t indice, - bool is_input) : TensorBase(api, - ctx, - indice, - is_input) { - if (is_input_) { +class OrtStringViewTensorStorage : public IStringTensorStorage{ 
+public: + using strings = std::vector; + OrtStringViewTensorStorage(const OrtW::CustomOpApi& api, + OrtKernelContext& ctx, + size_t indice, + bool is_input) : api_(api), ctx_(ctx), indice_(indice){ + if (is_input) { auto input_count = api_.KernelContext_GetInputCount(&ctx_); if (indice >= input_count) { ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION); @@ -292,14 +195,13 @@ class Tensor : public TensorBase { auto* const_value = api_.KernelContext_GetInput(&ctx_, indice); auto* info = api_.GetTensorTypeAndShape(const_value); shape_ = api_.GetTensorShape(info); - type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; api_.ReleaseTensorTypeAndShapeInfo(info); size_t num_chars; OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars)); chars_.resize(num_chars + 1, '\0'); - auto num_strings = static_cast(NumberOfElement()); + auto num_strings = static_cast((*shape_)[0]); if (num_strings) { std::vector offsets(num_strings); OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value, @@ -314,188 +216,82 @@ class Tensor : public TensorBase { } } } - int64_t NumberOfElement() const { - if (shape_.has_value()) { - return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies()); - } else { - return 0; - } - } - const string_views& Data() const { - return input_string_views_; + + const std::vector& Shape() const override { + if (!IsInitialized()) + ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION); + return *shape_; } - const void* DataRaw() const override { + + virtual const void* DataRaw() const override { if (input_string_views_.size() != 1) { ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); } return reinterpret_cast(input_string_views_[0].data()); } - size_t SizeInBytes() const override { - if (input_string_views_.size() != 1) { - ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); - } - return input_string_views_[0].size(); + + virtual bool IsInitialized() const override { + return shape_.has_value(); } - const Span& AsSpan() { - ORTX_CXX_API_THROW("span for TensorT of string view not implemented", ORT_RUNTIME_EXCEPTION); + + virtual void SetStringOutput(const strings& ss, const std::vector& dims) override { + ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION); } - std::string_view AsScalar() { - if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) { - ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); - } - return input_string_views_[0]; + + virtual void SetStringOutput(const std::vector& ss, const std::vector& dims) override { + ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION); } - private: + const strings& Data() const override { + return input_string_views_; + } + +private: + const OrtW::CustomOpApi& api_; + OrtKernelContext& ctx_; + size_t indice_; std::vector chars_; // for input std::vector input_string_views_; // for input + std::optional> shape_; }; -#if ORT_API_VERSION >= 16 - +// to make the metaprogramming magic happy. 
template <> -struct Tensor : public TensorBase { - Tensor(const OrtW::CustomOpApi& api, +class OrtTensor : public OrtKernelArg, + public Tensor{ +public: + OrtTensor(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, - bool is_input) : TensorBase(api, - ctx, - indice, - is_input) { - type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; - if (is_input_) { - auto input_count = api_.KernelContext_GetInputCount(&ctx_); - if (indice >= input_count) { - ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION); - } - const_value_ = api_.KernelContext_GetInput(&ctx_, indice); - auto* info = api_.GetTensorTypeAndShape(const_value_); - shape_ = api_.GetTensorShape(info); - type_ = api_.GetTensorElementType(info); - api_.ReleaseTensorTypeAndShapeInfo(info); - const OrtMemoryInfo* mem_info = {}; - api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info)); - if (mem_info) { - api_.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_)); - } - } - } - - const MFloat16* Data() const { - return reinterpret_cast(api_.GetTensorData(const_value_)); - } - - MFloat16* Allocate(const std::vector& shape) { - if (!data_) { - OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size()); - shape_ = shape; - data_ = reinterpret_cast(api_.GetTensorMutableData(out)); - } - return data_; - } - - const Span& AsSpan() { - ORTX_CXX_API_THROW("AsSpan for MFloat16 not implemented", ORT_RUNTIME_EXCEPTION); - } - - const MFloat16& AsScalar() { - ORTX_CXX_API_THROW("AsScalar for MFloat16 not implemented", ORT_RUNTIME_EXCEPTION); - } - - const void* DataRaw() const override { - return reinterpret_cast(Data()); - } - - virtual size_t SizeInBytes() const override { - return NumberOfElement() * sizeof(uint16_t); - } - - private: - const OrtValue* const_value_{}; // for input - MFloat16* data_{}; // for output + bool is_input) : OrtKernelArg(api, ctx, indice, is_input), + Tensor(std::make_unique(api, ctx, indice, is_input)) {} }; template <> -struct Tensor : public TensorBase { - Tensor(const OrtW::CustomOpApi& api, +class OrtTensor : public OrtKernelArg, + public Tensor{ +public: + OrtTensor(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, - bool is_input) : TensorBase(api, - ctx, - indice, - is_input) { - type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; - if (is_input_) { - auto input_count = api_.KernelContext_GetInputCount(&ctx_); - if (indice >= input_count) { - ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION); - } - const_value_ = api_.KernelContext_GetInput(&ctx_, indice); - auto* info = api_.GetTensorTypeAndShape(const_value_); - shape_ = api_.GetTensorShape(info); - type_ = api_.GetTensorElementType(info); - api_.ReleaseTensorTypeAndShapeInfo(info); - const OrtMemoryInfo* mem_info = {}; - api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info)); - if (mem_info) { - api_.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_)); - } - } - } - - const BFloat16* Data() const { - return reinterpret_cast(api_.GetTensorData(const_value_)); - } - - BFloat16* Allocate(const std::vector& shape) { - if (!data_) { - OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size()); - shape_ = shape; - data_ = reinterpret_cast(api_.GetTensorMutableData(out)); - } - return data_; - } - - const Span& AsSpan() { - ORTX_CXX_API_THROW("AsSpan for BFloat16 not implemented", ORT_RUNTIME_EXCEPTION); - } - - const BFloat16& AsScalar() { - ORTX_CXX_API_THROW("AsScalar for BFloat16 
not implemented", ORT_RUNTIME_EXCEPTION); - } - - const void* DataRaw() const override { - return reinterpret_cast(Data()); - } - - virtual size_t SizeInBytes() const override { - return NumberOfElement() * sizeof(uint16_t); - } - - private: - const OrtValue* const_value_{}; // for input - BFloat16* data_{}; // for output + bool is_input) : OrtKernelArg(api, ctx, indice, is_input), + Tensor(std::make_unique(api, ctx, indice, is_input)) {} }; -#endif - -using TensorPtr = std::unique_ptr; +using TensorPtr = std::unique_ptr; using TensorPtrs = std::vector; // Represent variadic input or output -struct Variadic : public TensorBase { +struct Variadic : public OrtKernelArg, public Arg { Variadic(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, - bool is_input) : TensorBase(api, - ctx, - indice, - is_input) { + bool is_input) : OrtKernelArg(api, ctx, indice, is_input) { #if ORT_API_VERSION < 14 ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION); #endif if (is_input) { - auto input_count = api_.KernelContext_GetInputCount(&ctx_); + auto input_count = api.KernelContext_GetInputCount(&ctx_); for (size_t ith_input = 0; ith_input < input_count; ++ith_input) { auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input); auto* info = api_.GetTensorTypeAndShape(const_value); @@ -504,40 +300,40 @@ struct Variadic : public TensorBase { TensorPtr tensor; switch (type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api, ctx, ith_input, true); break; default: ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION); @@ -551,22 +347,22 @@ struct Variadic : public TensorBase { } 
template T* AllocateOutput(size_t ith_output, const std::vector& shape) { - auto tensor = std::make_unique>(api_, ctx_, ith_output, false); + auto tensor = std::make_unique>(api_, ctx_, ith_output, false); auto raw_output = tensor.get()->Allocate(shape); tensors_.emplace_back(tensor.release()); return raw_output; } Tensor& AllocateStringTensor(size_t ith_output) { - auto tensor = std::make_unique>(api_, ctx_, ith_output, false); + auto tensor = std::make_unique>(api_, ctx_, ith_output, false); Tensor& output = *tensor; tensors_.emplace_back(tensor.release()); return output; } - const void* DataRaw() const override { + const void* DataRaw() const { ORTX_CXX_API_THROW("DataRaw() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION); return nullptr; } - size_t SizeInBytes() const override { + size_t SizeInBytes() const { ORTX_CXX_API_THROW("SizeInBytes() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION); return 0; } @@ -581,6 +377,36 @@ struct Variadic : public TensorBase { TensorPtrs tensors_; }; +class OrtGraphKernelContext : public KernelContext { +public: + OrtGraphKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) { + OrtMemoryInfo* info; + OrtW::ThrowOnError(api, api.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info)); + OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, info, &allocator_)); + api.ReleaseMemoryInfo(info); + } + + virtual ~OrtGraphKernelContext(){ + if (allocator_){ + api_.ReleaseAllocator(allocator_); + } + } + + void* AllocScratchBuffer(size_t size) override{ + return allocator_->Alloc(allocator_, size); + } + + void FreeScratchBuffer(void* p) override { + if (p){ + allocator_->Free(allocator_, p); + } + } + +private: + const OrtApi& api_; + OrtAllocator* allocator_; +}; + #ifdef USE_CUDA enum CudaResource { @@ -616,6 +442,89 @@ struct CudaContext { int device_id = 0; }; + +class OrtGraphCudaKernelContext : public CUDAKernelContext { +public: + static const int cuda_resource_ver = 1; + + OrtGraphCudaKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) { + api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream_); + if (!cuda_stream_) { + ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION); + } + api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas_); + if (!cublas_) { + ORTX_CXX_API_THROW("Failed to fetch cublas handle from context", ORT_RUNTIME_EXCEPTION); + } + void* resource = nullptr; + OrtStatusPtr result = api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::device_id_t, &resource); + if (result) { + ORTX_CXX_API_THROW("Failed to fetch device id from context", ORT_RUNTIME_EXCEPTION); + } + memcpy(&device_id_, &resource, sizeof(int)); + + OrtMemoryInfo* info; + OrtW::ThrowOnError(api, api.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info)); + OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, info, &cpu_allocator_)); + api.ReleaseMemoryInfo(info); + + OrtMemoryInfo* cuda_mem_info; + OrtW::ThrowOnError(api, api.CreateMemoryInfo("GPU", OrtAllocatorType::OrtArenaAllocator, device_id_, OrtMemType::OrtMemTypeDefault, &cuda_mem_info)); + OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, cuda_mem_info, &cuda_allocator_)); + api.ReleaseMemoryInfo(cuda_mem_info); + + } + + virtual ~OrtGraphCudaKernelContext(){ + if (cpu_allocator_){ + api_.ReleaseAllocator(cpu_allocator_); + } 
+ if (cuda_allocator_){ + api_.ReleaseAllocator(cuda_allocator_); + } + } + + void* AllocScratchBuffer(size_t size) override{ + return cpu_allocator_->Alloc(cpu_allocator_, size); + } + + void FreeScratchBuffer(void* p) override { + if (p){ + cpu_allocator_->Free(cpu_allocator_, p); + } + } + + void* AllocCudaScratchBuffer(size_t size) override { + return cuda_allocator_->Alloc(cuda_allocator_, size); + } + + void FreeCudaScratchBuffer(void* p) override { + if (p){ + cuda_allocator_->Free(cuda_allocator_, p); + } + } + + void* GetCudaStream() const override { + return cuda_stream_; + } + + void* GetCublasHandle() const override { + return cublas_; + } + + int GetCudaDeviceId() const override { + return device_id_; + } + +private: + const OrtApi& api_; + OrtAllocator* cpu_allocator_; + OrtAllocator* cuda_allocator_; + void* cuda_stream_ = {}; + void* cublas_ = {}; + int device_id_ = 0; +}; + #endif // using mf16_t = uint16_t; @@ -648,6 +557,24 @@ struct OrtLiteCustomOp : public OrtCustomOp { } #endif + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + tensors.push_back(std::make_unique(api->GetOrtApi(), *context)); + std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; + auto next = CreateTuple(api, context, tensors, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + tensors.push_back(std::make_unique(api->GetOrtApi(), *context)); + std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; + auto next = CreateTuple(api, context, tensors, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + #if ORT_API_VERSION >= 14 template static typename std::enable_if::value, std::tuple>::type @@ -690,7 +617,7 @@ struct OrtLiteCustomOp : public OrtCustomOp { template \ static typename std::enable_if*>::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ + tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ std::tuple current = std::tuple{reinterpret_cast(tensors.back().get())}; \ auto next = CreateTuple(api, context, tensors, num_input, num_output, ep); \ return std::tuple_cat(current, next); \ @@ -698,7 +625,7 @@ struct OrtLiteCustomOp : public OrtCustomOp { template \ static typename std::enable_if&>::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ + tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; \ auto next = CreateTuple(api, context, tensors, num_input, num_output, ep); \ return std::tuple_cat(current, next); \ @@ -707,7 +634,7 @@ struct OrtLiteCustomOp : public OrtCustomOp { static typename std::enable_if*>>::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* 
context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ if (ith_input < num_input) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ + tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())}; \ auto next = CreateTuple(api, context, tensors, num_input, num_output, ep); \ return std::tuple_cat(current, next); \ @@ -720,8 +647,8 @@ struct OrtLiteCustomOp : public OrtCustomOp { template \ static typename std::enable_if*>::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ - if (!tensors.back()->IsCpuTensor()) { \ + tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ + if (!reinterpret_cast*>(tensors.back().get())->IsCpuTensor()) { \ ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \ } \ std::tuple current = std::tuple{&reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ @@ -731,8 +658,8 @@ struct OrtLiteCustomOp : public OrtCustomOp { template \ static typename std::enable_if&>::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ - if (!tensors.back()->IsCpuTensor()) { \ + tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ + if (!reinterpret_cast*>(tensors.back().get())->IsCpuTensor()) { \ ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \ } \ std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ @@ -743,8 +670,8 @@ struct OrtLiteCustomOp : public OrtCustomOp { static typename std::enable_if*>>::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ if (ith_input < num_input) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ - if (!tensors.back()->IsCpuTensor()) { \ + tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ + if (!reinterpret_cast*>(tensors.back().get())->IsCpuTensor()) { \ ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \ } \ std::tuple current = std::tuple{&reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ @@ -759,8 +686,8 @@ struct OrtLiteCustomOp : public OrtCustomOp { template \ static typename std::enable_if::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ - if (!tensors.back()->IsCpuTensor()) { \ + tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ + if (!reinterpret_cast*>(tensors.back().get())->IsCpuTensor()) { \ ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL); \ } \ std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsScalar()}; \ @@ -771,8 +698,8 @@ struct OrtLiteCustomOp : public OrtCustomOp { static typename std::enable_if>::value, std::tuple>::type \ CreateTuple(const 
OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ if (ith_input < num_input) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ - if (!tensors.back()->IsCpuTensor()) { \ + tensors.push_back(std::make_unique>(*api, *context, ith_input, true)); \ + if (!reinterpret_cast*>(tensors.back().get())->IsCpuTensor()) { \ ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL); \ } \ std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsScalar()}; \ @@ -788,7 +715,7 @@ struct OrtLiteCustomOp : public OrtCustomOp { template \ static typename std::enable_if*>::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_output, false)); \ + tensors.push_back(std::make_unique>(*api, *context, ith_output, false)); \ std::tuple current = std::tuple{reinterpret_cast(tensors.back().get())}; \ auto next = CreateTuple(api, context, tensors, num_input, num_output, ep); \ return std::tuple_cat(current, next); \ @@ -796,7 +723,7 @@ struct OrtLiteCustomOp : public OrtCustomOp { template \ static typename std::enable_if&>::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_output, false)); \ + tensors.push_back(std::make_unique>(*api, *context, ith_output, false)); \ std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; \ auto next = CreateTuple(api, context, tensors, num_input, num_output, ep); \ return std::tuple_cat(current, next); \ @@ -805,7 +732,7 @@ struct OrtLiteCustomOp : public OrtCustomOp { static typename std::enable_if*>>::value, std::tuple>::type \ CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ if (ith_output < num_output) { \ - tensors.push_back(std::make_unique>(*api, *context, ith_output, false)); \ + tensors.push_back(std::make_unique>(*api, *context, ith_output, false)); \ std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())}; \ auto next = CreateTuple(api, context, tensors, num_input, num_output, ep); \ return std::tuple_cat(current, next); \ @@ -857,6 +784,18 @@ struct OrtLiteCustomOp : public OrtCustomOp { } #endif + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + ParseArgs(input_types, output_types); + } + #if ORT_API_VERSION >= 14 template static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type @@ -1088,6 +1027,20 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { ComputeFn compute_fn_; }; +class OrtAttributeReader { +public: + OrtAttributeReader(const OrtApi& api, const OrtKernelInfo& info) : base_kernel_(api, info) { + } + + template + T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept { + return base_kernel_.TryToGetAttributeWithDefault(name, 
default_value); + } + +private: + BaseKernel base_kernel_; +}; + template struct OrtLiteCustomStruct : public OrtLiteCustomOp { template @@ -1124,7 +1077,13 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { auto kernel = std::make_unique(); - kernel->custom_op_ = std::make_unique(*ort_api, *info); + + if constexpr (std::is_constructible::value){ + kernel->custom_op_ = std::make_unique(*ort_api, *info); + } + else { + kernel->custom_op_ = std::make_unique(OrtAttributeReader(*ort_api, *info)); + } auto self = static_cast(this_); kernel->ep_ = self->execution_provider_; kernel->api_ = std::make_unique(*ort_api); diff --git a/includes/kernel_context.h b/includes/kernel_context.h new file mode 100644 index 000000000..520056503 --- /dev/null +++ b/includes/kernel_context.h @@ -0,0 +1,34 @@ +#pragma once +#include +#include +#include + +namespace Ort { +namespace Custom { + +// this is for the ORT custom op template magic +class Arg { +}; + +class KernelContext : public Arg{ +public: + virtual void* AllocScratchBuffer(size_t size) = 0; + virtual void FreeScratchBuffer(void* p) = 0; + // TODO: threadpool? +}; + +#ifdef USE_CUDA +class CUDAKernelContext : public KernelContext { +public: + virtual void* AllocCudaScratchBuffer(size_t size) = 0; + virtual void FreeCudaScratchBuffer(void* p) = 0; + virtual void* GetCudaStream() const = 0; + virtual void* GetCublasHandle() const = 0; + virtual int GetCudaDeviceId() const = 0; +}; +#endif + +// TODO: helper func to create context from global ORT env. + +} +} \ No newline at end of file diff --git a/includes/onnxruntime_cpp_api_legacy.hpp b/includes/onnxruntime_cpp_api_legacy.hpp index f967b0b46..ddacb70d1 100644 --- a/includes/onnxruntime_cpp_api_legacy.hpp +++ b/includes/onnxruntime_cpp_api_legacy.hpp @@ -30,6 +30,9 @@ struct CustomOpApi { template const T* GetTensorData(_Inout_ const OrtValue* value) const; + void* GetTensorMutableRawData(_Inout_ OrtValue* value) const; + const void* GetTensorRawData(_Inout_ const OrtValue* value) const; + std::vector GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const; void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const; size_t KernelContext_GetInputCount(const OrtKernelContext* context) const; @@ -37,7 +40,7 @@ struct CustomOpApi { size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const; OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) const; - + void ThrowOnError(OrtStatus* status) const { OrtW::ThrowOnError(api_, status); } @@ -162,6 +165,16 @@ inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const return GetTensorMutableData(const_cast(value)); } +inline void* CustomOpApi::GetTensorMutableRawData(_Inout_ OrtValue* value) const { + void* data = nullptr; + ThrowOnError(api_.GetTensorMutableData(value, &data)); + return data; +} + +inline const void* CustomOpApi::GetTensorRawData(_Inout_ const OrtValue* value) const { + return GetTensorMutableRawData(const_cast(value)); +} + inline std::vector CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const { std::vector output(GetDimensionsCount(info)); GetDimensions(info, output.data(), output.size()); diff --git a/includes/tensor_api.h b/includes/tensor_api.h new file mode 100644 index 000000000..17f456189 --- /dev/null +++ b/includes/tensor_api.h @@ -0,0 +1,492 @@ +#pragma once +#include 
+#include +#include +#include "onnxruntime_customop.hpp" +#include "onnxruntime_f16.h" +#include "kernel_context.h" + +namespace Ort { +namespace Custom { + +template +struct Span { + const T* data_ = {}; + size_t size_ = {}; + void Assign(const T* data, size_t size) { + data_ = data; + size_ = size; + } + size_t size() const { return size_; } + T operator[](size_t indice) const { + return data_[indice]; + } + const T* data() const { return data_; } +}; + + +#if ORT_API_VERSION >= 16 + +template <> +struct Span { + const MFloat16* data_ = {}; + size_t size_ = {}; + void Assign(const MFloat16* data, size_t size) { + data_ = data; + size_ = size; + } + size_t size() const { return size_; } + MFloat16 operator[](size_t indice) const { + return data_[indice]; + } + const MFloat16* data() const { return data_; } +}; + +template <> +struct Span { + const BFloat16* data_ = {}; + size_t size_ = {}; + void Assign(const BFloat16* data, size_t size) { + data_ = data; + size_ = size; + } + size_t size() const { return size_; } + BFloat16 operator[](size_t indice) const { + return data_[indice]; + } + const BFloat16* data() const { return data_; } +}; + +#endif + +class ITensorStorage{ +public: + virtual const std::vector& Shape() const = 0; + virtual const void* DataRaw() const = 0; + virtual bool IsInitialized() const = 0; + virtual void* Initialize(const std::vector& shape, size_t element_size) = 0; +}; + + +class IAllocator { +public: + virtual void* Alloc(size_t size) = 0; + virtual void Free(void* p) = 0; +}; + +// TODO: remove this +class TestAllocator : public IAllocator { +public: + void* Alloc(size_t size) override { + return malloc(size); + } + + void Free(void* p) override { + if (p){ + free(p); + } + } +}; + +class OrtEagerTensorStorage : public ITensorStorage { +public: + OrtEagerTensorStorage(const std::vector& shape, + void* buffer) : buffer_(buffer), shape_(shape){ + + } + + OrtEagerTensorStorage(IAllocator* allocator) : allocator_(allocator){ + } + + virtual ~OrtEagerTensorStorage(){ + if (allocator_ && buffer_) + allocator_->Free(buffer_); + } + + const std::vector& Shape() const override { + if (!IsInitialized()) + ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION); + return *shape_; + } + + virtual bool IsInitialized() const override { + return shape_.has_value(); + } + + const void* DataRaw() const override { + return buffer_; + } + + void* Initialize(const std::vector& shape, size_t element_size) override { + if (IsInitialized()) + return buffer_; + assert(allocator_); + shape_ = shape; + int64_t n_elem = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies()); + auto buffer_size = n_elem * element_size; + buffer_ = allocator_->Alloc(buffer_size); + return buffer_; + } + +private: + void* buffer_ {}; + std::optional> shape_; + // caller need to make sure the allocator is alive + IAllocator* allocator_; +}; + +template +class Tensor : public Arg { + public: + using TT = typename std::remove_reference::type; + Tensor(std::unique_ptr tensor_storage) : storage_(std::move(tensor_storage)){ + } + + Tensor(const std::vector& shape, void* buffer) : Tensor(std::make_unique(shape, buffer)) {} + + Tensor(IAllocator* allocator) : storage_(std::make_unique(allocator)){} + + virtual ~Tensor() = default; + + operator bool() const { + return storage_->IsInitialized(); + } + + const std::vector& Shape() const { + return storage_->Shape(); + } + + int64_t NumberOfElement() const { + auto& shape = storage_->Shape(); + return std::accumulate(shape.begin(), shape.end(), 
1LL, std::multiplies()); + } + + std::string Shape2Str() const { + if (storage_->IsInitialized()) { + std::string shape_str; + auto& shape = storage_->Shape(); + for (const auto& dim : shape) { + shape_str.append(std::to_string(dim)); + shape_str.append(", "); + } + return shape_str; + } else { + return "empty"; + } + } + + const TT* Data() const { +#if ORT_API_VERSION >= 16 + if constexpr (std::is_same::value || std::is_same::value) + return reinterpret_cast(storage_->DataRaw()); + else +#endif + return static_cast(storage_->DataRaw()); + } + + const void* DataRaw() const { + return storage_->DataRaw(); + } + + size_t SizeInBytes() const { + return NumberOfElement() * sizeof(TT); + } + + TT* Allocate(const std::vector& shape) { + // it should be OK to allocate multiple times + void* buffer = storage_->Initialize(shape, sizeof(TT)); +#if ORT_API_VERSION >= 16 + if constexpr (std::is_same::value || std::is_same::value) + return reinterpret_cast(buffer); + else +#endif + return static_cast(buffer); + } + + const Span& AsSpan() { +#if ORT_API_VERSION >= 16 + if constexpr (std::is_same::value || std::is_same::value) { + ORTX_CXX_API_THROW("AsSpan for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION); + } + else{ +#endif + auto& shape = storage_->Shape(); + if (shape.size() != 1) { + ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); + } + span_.Assign(Data(), shape[0]); + return span_; +#if ORT_API_VERSION >= 16 + } +#endif + } + + const T& AsScalar() { +#if ORT_API_VERSION >= 16 + if constexpr (std::is_same::value || std::is_same::value) { + ORTX_CXX_API_THROW("AsScalar for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION); + } + else{ +#endif + auto& shape = storage_->Shape(); + if ((shape.size() == 1 && shape[0] != 1) || shape.size() > 1) { + ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); + } + return *Data(); +#if ORT_API_VERSION >= 16 + } +#endif + } + + private: + std::unique_ptr storage_; + Span span_; +}; + +template +class IStringTensorStorage{ +public: + using strings = std::vector; + virtual const std::vector& Shape() const = 0; + virtual const void* DataRaw() const = 0; + virtual const strings& Data() const = 0; + virtual bool IsInitialized() const = 0; + virtual void SetStringOutput(const strings& ss, const std::vector& dims) = 0; + virtual void SetStringOutput(const std::vector& ss, const std::vector& dims) = 0; +}; + +template +class EagerStringTensorStorage : public IStringTensorStorage{ +public: + using strings = std::vector; + EagerStringTensorStorage(const strings& ss) : input_strings_(ss), shape_(std::vector{static_cast(ss.size())}){} + + EagerStringTensorStorage() {} + + const std::vector& Shape() const override { + if (!IsInitialized()) + ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION); + return *shape_; + } + + virtual const void* DataRaw() const override { + if (input_strings_.size() != 1) { + ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + if constexpr (std::is_same::value) + return reinterpret_cast(input_strings_[0].data()); + else + return reinterpret_cast(input_strings_[0].c_str()); + } + + virtual bool IsInitialized() const override { + return shape_.has_value(); + } + + virtual void SetStringOutput(const strings& ss, const std::vector& dims) override { + if constexpr (std::is_same::value) + ORTX_CXX_API_THROW("Set output for string view tensor is not 
supported", ORT_RUNTIME_EXCEPTION); + input_strings_.assign(ss.begin(), ss.end()); + shape_ = dims; + } + + const strings& Data() const override { + return input_strings_; + } + + virtual void SetStringOutput(const std::vector& ss, const std::vector& dims) override { + if constexpr (std::is_same::value) + ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION); + + for (const char* s : ss){ + input_strings_.push_back(s); + } + shape_ = dims; + } + +private: + std::vector input_strings_; + std::optional> shape_; +}; + +template <> +class Tensor : public Arg { + public: + using strings = std::vector; + + Tensor(std::unique_ptr> storage) : storage_(std::move(storage)) {} + + Tensor(const strings& ss) : storage_(std::make_unique>(ss)) {} + + Tensor() : storage_(std::make_unique>()) {} + + const strings& Data() const { + return storage_->Data(); + } + + const std::vector& Shape() const { + return storage_->Shape(); + } + + int64_t NumberOfElement() const { + auto& shape = storage_->Shape(); + return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies()); + } + + std::string Shape2Str() const { + if (storage_->IsInitialized()) { + std::string shape_str; + auto& shape = storage_->Shape(); + for (const auto& dim : shape) { + shape_str.append(std::to_string(dim)); + shape_str.append(", "); + } + return shape_str; + } else { + return "empty"; + } + } + + const void* DataRaw() const { + return storage_->DataRaw(); + } + + size_t SizeInBytes() const { + auto& ss = storage_->Data(); + if (ss.size() != 1) { + ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return ss[0].size(); + } + + void SetStringOutput(const strings& ss, const std::vector& dims) { + storage_->SetStringOutput(ss, dims); + } + void SetStringOutput(const std::vector& ss, const std::vector& dims) { + storage_->SetStringOutput(ss, dims); + } + const Span& AsSpan() { + ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION); + } + const std::string& AsScalar() { + auto& ss = storage_->Data(); + if (ss.size() != 1) { + ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); + } + return ss[0]; + } + + private: + std::unique_ptr> storage_; +}; + + +template <> +class Tensor : public Arg { + public: + using strings = std::vector; + + Tensor(std::unique_ptr> storage) : storage_(std::move(storage)) {} + + Tensor(const strings& ss) : storage_(std::make_unique>(ss)) {} + + const strings& Data() const { + return storage_->Data(); + } + + const std::vector& Shape() const { + return storage_->Shape(); + } + + int64_t NumberOfElement() const { + auto& shape = storage_->Shape(); + return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies()); + } + + std::string Shape2Str() const { + if (storage_->IsInitialized()) { + std::string shape_str; + auto& shape = storage_->Shape(); + for (const auto& dim : shape) { + shape_str.append(std::to_string(dim)); + shape_str.append(", "); + } + return shape_str; + } else { + return "empty"; + } + } + + const void* DataRaw() const { + return storage_->DataRaw(); + } + + size_t SizeInBytes() const { + auto& ss = storage_->Data(); + if (ss.size() != 1) { + ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return ss[0].size(); + } + + void SetStringOutput(const strings& ss, const std::vector& dims) { + storage_->SetStringOutput(ss, dims); + } + void SetStringOutput(const 
std::vector& ss, const std::vector& dims) { + storage_->SetStringOutput(ss, dims); + } + const Span& AsSpan() { + ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION); + } + const std::string_view& AsScalar() { + auto& ss = storage_->Data(); + if (ss.size() != 1) { + ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); + } + return ss[0]; + } + + private: + std::unique_ptr> storage_; +}; + + +template +class NamedArgumentDict{ +public: + using ValueTuple = std::tuple; + + NamedArgumentDict(const std::vector& keys, const std::tuple& args) : names_(keys), entries_(args) { + } + + template + T TryToGetAttributeWithDefault(const char* name, const T& default_value) const { + return TryToGetAttributeWithDefaultInternal<0>(name, default_value); + } + +private: + template + typename std::enable_if::type + TryToGetAttributeWithDefaultInternal(const char* name, const T& default_value) const { + return default_value; + } + + template + typename std::enable_if::type + TryToGetAttributeWithDefaultInternal(const char* name, const T& default_value) const { + if (names_[I] == name){ + if constexpr (std::is_same, T>::value) + return std::get(entries_); + else + throw std::runtime_error("name matched but type is not"); + } + return TryToGetAttributeWithDefaultInternal(name, default_value); + } + + std::vector names_; + std::tuple entries_; + +}; + +} +} diff --git a/operators/math/cuda/negpos_def.cc b/operators/math/cuda/negpos_def.cc index b1a78b8be..9d9c6e16c 100644 --- a/operators/math/cuda/negpos_def.cc +++ b/operators/math/cuda/negpos_def.cc @@ -4,7 +4,7 @@ #include #include -OrtStatusPtr neg_pos_cuda(const Ort::Custom::CudaContext& ctx, +OrtStatusPtr neg_pos_cuda(Ort::Custom::CUDAKernelContext& ctx, const ortc::Tensor& input, ortc::Tensor& out0_tensor, ortc::Tensor& out1_tensor) { @@ -13,6 +13,6 @@ OrtStatusPtr neg_pos_cuda(const Ort::Custom::CudaContext& ctx, float* out1 = out1_tensor.Allocate(input.Shape()); const float* X = input.Data(); - neg_pos_impl(reinterpret_cast(ctx.cuda_stream), X, out0, out1, size); + neg_pos_impl(reinterpret_cast(ctx.GetCudaStream()), X, out0, out1, size); return nullptr; } diff --git a/operators/math/cuda/negpos_def.h b/operators/math/cuda/negpos_def.h index 3ae0f4ef9..5479c7ada 100644 --- a/operators/math/cuda/negpos_def.h +++ b/operators/math/cuda/negpos_def.h @@ -4,7 +4,7 @@ #pragma once #include "ocos.h" -OrtStatusPtr neg_pos_cuda(const Ort::Custom::CudaContext& ctx, +OrtStatusPtr neg_pos_cuda(Ort::Custom::CUDAKernelContext& ctx, const ortc::Tensor& input, ortc::Tensor& out0_tensor, ortc::Tensor& out1_tensor); diff --git a/operators/math/negpos.hpp b/operators/math/negpos.hpp index 62dc4f34e..83b82534f 100644 --- a/operators/math/negpos.hpp +++ b/operators/math/negpos.hpp @@ -24,3 +24,4 @@ OrtStatusPtr neg_pos(const ortc::Tensor& input, return nullptr; } + diff --git a/operators/tokenizer/basic_tokenizer.cc b/operators/tokenizer/basic_tokenizer.cc index 8c9a11f8d..3a2a9e06c 100644 --- a/operators/tokenizer/basic_tokenizer.cc +++ b/operators/tokenizer/basic_tokenizer.cc @@ -81,16 +81,16 @@ std::vector BasicTokenizer::Tokenize(ustring text) { return result; } -KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) { - bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true); - bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true); - bool strip_accents = 
TryToGetAttributeWithDefault("strip_accents", false); - bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false); - bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true); - - tokenizer_ = std::make_shared(do_lower_case, tokenize_chinese_chars, strip_accents, - tokenize_punctuation, remove_control_chars); -} +// KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) { +// bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true); +// bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true); +// bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false); +// bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false); +// bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true); + +// tokenizer_ = std::make_shared(do_lower_case, tokenize_chinese_chars, strip_accents, +// tokenize_punctuation, remove_control_chars); +// } void KernelBasicTokenizer::Compute(std::string_view input, ortc::Tensor& output) const { diff --git a/operators/tokenizer/basic_tokenizer.hpp b/operators/tokenizer/basic_tokenizer.hpp index 713bd956f..85c71fab8 100644 --- a/operators/tokenizer/basic_tokenizer.hpp +++ b/operators/tokenizer/basic_tokenizer.hpp @@ -21,8 +21,20 @@ class BasicTokenizer { bool remove_control_chars_; }; -struct KernelBasicTokenizer : BaseKernel { - KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info); +struct KernelBasicTokenizer { + + template + KernelBasicTokenizer(const T& dict) { + bool do_lower_case = dict.TryToGetAttributeWithDefault("do_lower_case", true); + bool tokenize_chinese_chars = dict.TryToGetAttributeWithDefault("tokenize_chinese_chars", true); + bool strip_accents = dict.TryToGetAttributeWithDefault("strip_accents", false); + bool tokenize_punctuation = dict.TryToGetAttributeWithDefault("tokenize_punctuation", false); + bool remove_control_chars = dict.TryToGetAttributeWithDefault("remove_control_chars", true); + + tokenizer_ = std::make_shared(do_lower_case, tokenize_chinese_chars, strip_accents, + tokenize_punctuation, remove_control_chars); + } + void Compute(std::string_view input, ortc::Tensor& output) const; diff --git a/operators/tokenizer/bert_tokenizer_decoder.cc b/operators/tokenizer/bert_tokenizer_decoder.cc index e8742c53c..d03131fd3 100644 --- a/operators/tokenizer/bert_tokenizer_decoder.cc +++ b/operators/tokenizer/bert_tokenizer_decoder.cc @@ -119,22 +119,24 @@ bool BertTokenizerDecoder::RemoveTokenizeSpace(int64_t pre_token_id, int64_t new return false; } -KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) { - std::string vocab = ort_.KernelInfoGetAttribute(&info, "vocab_file"); - std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]")); - std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]")); - std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]")); - std::string cls_token = TryToGetAttributeWithDefault("cls_token", std::string("[CLS]")); - std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]")); - std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##")); - - use_indices_ = TryToGetAttributeWithDefault("use_indices", false); - skip_special_tokens_ = 
diff --git a/operators/tokenizer/bert_tokenizer_decoder.cc b/operators/tokenizer/bert_tokenizer_decoder.cc
index e8742c53c..d03131fd3 100644
--- a/operators/tokenizer/bert_tokenizer_decoder.cc
+++ b/operators/tokenizer/bert_tokenizer_decoder.cc
@@ -119,22 +119,24 @@ bool BertTokenizerDecoder::RemoveTokenizeSpace(int64_t pre_token_id, int64_t new_token_id) {
   return false;
 }
 
-KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
-  std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab_file");
-  std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
-  std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
-  std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
-  std::string cls_token = TryToGetAttributeWithDefault("cls_token", std::string("[CLS]"));
-  std::string mask_token = TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
-  std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
-
-  use_indices_ = TryToGetAttributeWithDefault("use_indices", false);
-  skip_special_tokens_ = TryToGetAttributeWithDefault("skip_special_tokens", false);
-  clean_up_tokenization_spaces_ = TryToGetAttributeWithDefault("clean_up_tokenization_spaces", true);
-
-  decoder_ = std::make_shared<BertTokenizerDecoder>(vocab, unk_token, sep_token, pad_token,
-                                                    cls_token, mask_token, suffix_indicator);
-}
+// template <typename T>
+// KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(const T& dict) {
+//   //std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab_file");
+//   std::string vocab = dict.TryToGetAttributeWithDefault("vocab_file", std::string(""));
+//   std::string unk_token = dict.TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
+//   std::string sep_token = dict.TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
+//   std::string pad_token = dict.TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
+//   std::string cls_token = dict.TryToGetAttributeWithDefault("cls_token", std::string("[CLS]"));
+//   std::string mask_token = dict.TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
+//   std::string suffix_indicator = dict.TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
+
+//   use_indices_ = dict.TryToGetAttributeWithDefault("use_indices", false);
+//   skip_special_tokens_ = dict.TryToGetAttributeWithDefault("skip_special_tokens", false);
+//   clean_up_tokenization_spaces_ = dict.TryToGetAttributeWithDefault("clean_up_tokenization_spaces", true);
+
+//   decoder_ = std::make_shared<BertTokenizerDecoder>(vocab, unk_token, sep_token, pad_token,
+//                                                     cls_token, mask_token, suffix_indicator);
+// }
 
 void KernelBertTokenizerDecoder::Compute(const ortc::Tensor<int64_t>& ids,
                                          const ortc::Tensor<int64_t>& positions,
diff --git a/operators/tokenizer/bert_tokenizer_decoder.hpp b/operators/tokenizer/bert_tokenizer_decoder.hpp
index 16441c484..e5aa34859 100644
--- a/operators/tokenizer/bert_tokenizer_decoder.hpp
+++ b/operators/tokenizer/bert_tokenizer_decoder.hpp
@@ -29,8 +29,27 @@ class BertTokenizerDecoder {
   bool RemoveTokenizeSpace(int64_t pre_token_id, int64_t new_token_id);
 };
 
-struct KernelBertTokenizerDecoder : BaseKernel {
-  KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo& info);
+struct KernelBertTokenizerDecoder {
+
+  template <typename T>
+  KernelBertTokenizerDecoder(const T& dict) {
+    //std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab_file");
+    std::string vocab = dict.TryToGetAttributeWithDefault("vocab_file", std::string(""));
+    std::string unk_token = dict.TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
+    std::string sep_token = dict.TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
+    std::string pad_token = dict.TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
+    std::string cls_token = dict.TryToGetAttributeWithDefault("cls_token", std::string("[CLS]"));
+    std::string mask_token = dict.TryToGetAttributeWithDefault("mask_token", std::string("[MASK]"));
+    std::string suffix_indicator = dict.TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
+
+    use_indices_ = dict.TryToGetAttributeWithDefault("use_indices", false);
+    skip_special_tokens_ = dict.TryToGetAttributeWithDefault("skip_special_tokens", false);
+    clean_up_tokenization_spaces_ = dict.TryToGetAttributeWithDefault("clean_up_tokenization_spaces", true);
+
+    decoder_ = std::make_shared<BertTokenizerDecoder>(vocab, unk_token, sep_token, pad_token,
+                                                      cls_token, mask_token, suffix_indicator);
+  }
+
 
 void Compute(const ortc::Tensor<int64_t>& ids,
              const ortc::Tensor<int64_t>& positions,
             ortc::Tensor<std::string>& output) const;
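Note that, unlike the basic tokenizer, the decoder's new dictionary-based constructor is not exercised by any test added in this change. A hedged sketch of driving it eagerly is shown below; the attribute values (in particular the vocabulary content passed through "vocab_file", assumed here to be a newline-separated placeholder) are illustrative and do not appear anywhere in this patch:

    // Hypothetical eager construction of the decoder kernel (values are placeholders).
    ortc::NamedArgumentDict dict(
        {"vocab_file", "use_indices", "skip_special_tokens"},
        std::make_tuple(std::string("[PAD]\n[UNK]\n[CLS]\n[SEP]\n[MASK]\nhello\nworld"), true, true));
    KernelBertTokenizerDecoder decoder(dict);
    // Attributes not listed above ("unk_token", "cls_token", ...) fall back to the
    // defaults declared in the constructor.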
diff --git a/test/shared_test/test_ortops_cuda.cc b/test/shared_test/test_ortops_cuda.cc
index dc9ae35b2..c62977990 100644
--- a/test/shared_test/test_ortops_cuda.cc
+++ b/test/shared_test/test_ortops_cuda.cc
@@ -6,8 +6,12 @@
 #include "gtest/gtest.h"
 #include "ocos.h"
 #include "test_kernel.hpp"
+#include "kernel_context.h"
 
 #ifdef USE_CUDA
+#include "operators/math/cuda/negpos_def.h"
+#include <cuda.h>
+#include <cuda_runtime.h>
 
 TEST(CudaOp, test_fastgelu) {
   auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
@@ -35,4 +39,57 @@ TEST(CudaOp, test_fastgelu) {
   TestInference(*ort_env, model_path.c_str(), inputs, outputs);
 }
 
+class MockCudaKernelContext : public Ort::Custom::CUDAKernelContext {
+public:
+  MockCudaKernelContext() { cudaStreamCreate(&stream); }
+  ~MockCudaKernelContext() { cudaStreamDestroy(stream); }
+  void* AllocScratchBuffer(size_t size) override { return nullptr; }
+  void FreeScratchBuffer(void* p) override {}
+  void* AllocCudaScratchBuffer(size_t size) override { return nullptr; }
+  void FreeCudaScratchBuffer(void* p) override {}
+  void* GetCudaStream() const override { return static_cast<void*>(stream); }
+  void* GetCublasHandle() const override { return nullptr; }
+  int GetCudaDeviceId() const override { return 0; }
+
+private:
+  cudaStream_t stream;
+};
+
+class CudaAllocator : public Ort::Custom::IAllocator {
+public:
+  void* Alloc(size_t size) override {
+    void* p = nullptr;
+    cudaMalloc((void**)&p, size);
+    return p;
+  }
+  void Free(void* p) override { cudaFree(p); }
+};
+
+TEST(CudaOp, test_eager_negpos) {
+  MockCudaKernelContext mock_cuda_kc;
+  std::vector<float> input_data = {0.0f, 0.2f, -1.3f, 1.5f};
+  std::unique_ptr<CudaAllocator> cuda_alloc = std::make_unique<CudaAllocator>();
+  void* device_input = cuda_alloc->Alloc(sizeof(float) * input_data.size());
+  cudaMemcpyAsync(device_input, input_data.data(), sizeof(float) * input_data.size(), cudaMemcpyHostToDevice, static_cast<cudaStream_t>(mock_cuda_kc.GetCudaStream()));
+
+  ortc::Tensor<float> input(std::vector<int64_t>{2, 2}, device_input);
+  ortc::Tensor<float> output1(cuda_alloc.get());
+  ortc::Tensor<float> output2(cuda_alloc.get());
+  neg_pos_cuda(mock_cuda_kc, input, output1, output2);
+
+  float* host_output1 = (float*)malloc(sizeof(float) * input_data.size());
+  float* host_output2 = (float*)malloc(sizeof(float) * input_data.size());
+  cudaMemcpyAsync(host_output1, output1.DataRaw(), sizeof(float) * input_data.size(), cudaMemcpyDeviceToHost, static_cast<cudaStream_t>(mock_cuda_kc.GetCudaStream()));
+  cudaMemcpyAsync(host_output2, output2.DataRaw(), sizeof(float) * input_data.size(), cudaMemcpyDeviceToHost, static_cast<cudaStream_t>(mock_cuda_kc.GetCudaStream()));
+  // make sure the async copies have completed before reading the host buffers
+  cudaStreamSynchronize(static_cast<cudaStream_t>(mock_cuda_kc.GetCudaStream()));
+  ASSERT_NEAR(host_output1[1], input_data[1], 0.01f);
+  ASSERT_NEAR(host_output2[2], input_data[2], 0.01f);
+  ASSERT_NEAR(host_output1[3], input_data[3], 0.01f);
+
+  cuda_alloc->Free(device_input);
+  free(host_output1);
+  free(host_output2);
+}
+
 #endif
\ No newline at end of file
diff --git a/test/shared_test/test_ortops_math.cc b/test/shared_test/test_ortops_math.cc
index 72da1f985..c8168083c 100644
--- a/test/shared_test/test_ortops_math.cc
+++ b/test/shared_test/test_ortops_math.cc
@@ -6,6 +6,21 @@
 #include "ocos.h"
 #include "test_kernel.hpp"
 
+#include "operators/math/negpos.hpp"
+
+TEST(math_operator, eager_poc) {
+  auto test_allocator = std::make_unique();
+  std::vector<float> input_data = {0.0f, 0.2f, -1.3f, 1.5f};
+
+  ortc::Tensor<float> input(std::vector<int64_t>{2, 2}, input_data.data());
+
+  ortc::Tensor<float> output1(test_allocator.get());
+  ortc::Tensor<float> output2(test_allocator.get());
+
+  auto result = neg_pos(input, output1, output2);
+  assert(!result);
+  assert(output1.Shape() == input.Shape() && output2.Shape() == input.Shape());
+}
 
 TEST(math_operator, segment_extraction) {
   auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
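The eager CPU test above constructs its output tensors from a host-side allocator whose concrete type name is not shown in this patch text (the make_unique call is type-elided); all it needs to be is an implementation of the Ort::Custom::IAllocator interface that the CUDA test overrides with Alloc/Free. A minimal sketch, under a hypothetical name, would be:

    // Hypothetical CPU counterpart of the CudaAllocator above (name is illustrative).
    // Needs <cstdlib> for malloc/free.
    class HostTestAllocator : public Ort::Custom::IAllocator {
     public:
      void* Alloc(size_t size) override { return malloc(size); }
      void Free(void* p) override { free(p); }
    };

With such a type, auto test_allocator = std::make_unique<HostTestAllocator>(); would complete the otherwise type-elided allocator construction in the eager_poc test.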
diff --git a/test/shared_test/test_ortops_tokenizer.cc b/test/shared_test/test_ortops_tokenizer.cc
index 5933f358f..b87e3d557 100644
--- a/test/shared_test/test_ortops_tokenizer.cc
+++ b/test/shared_test/test_ortops_tokenizer.cc
@@ -7,6 +7,23 @@
 #include "ocos.h"
 #include "test_kernel.hpp"
 
+#include "operators/tokenizer/basic_tokenizer.hpp"
+
+TEST(basic_tokenizer, eager) {
+  std::string test_case = "I mean, you’ll need something to talk about next Sunday, right?";
+  std::vector<std::string> expect_result = {"I", "mean", ",", "you", "’", "ll", "need", "something", "to", "talk", "about", "next", "Sunday", ",", "right", "?"};
+
+  ortc::NamedArgumentDict dict({"do_lower_case", "tokenize_chinese_chars", "strip_accents", "tokenize_punctuation", "remove_control_chars"},
+                               std::make_tuple(false, true, true, true, true));
+
+  KernelBasicTokenizer tokenizer(dict);
+
+  //ortc::Tensor<std::string> input(std::vector<std::string>{test_case});
+  ortc::Tensor<std::string> output;
+  tokenizer.Compute(test_case, output);
+  EXPECT_EQ(output.Data(), expect_result);
+}
+
 TEST(tokenizer_opertors, test_bert_tokenizer) {
   auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");