Skip to content

Commit

Permalink
refactor ORT-Extension for the coming GroupQueryAttention work (#674)
Browse files Browse the repository at this point in the history
* refactor ORT-Extension for the coming GroupQueryAttention work

* fix typo and add #if ORT_API_VERSION >= 15 for GetOrtAllocator

* fix cuda build
  • Loading branch information
jslhcl authored Mar 20, 2024
1 parent 2321329 commit 2234001
Show file tree
Hide file tree
Showing 17 changed files with 326 additions and 134 deletions.
117 changes: 117 additions & 0 deletions base/ortx_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <locale>
#include <optional>
#include <string>
#include <sstream>
#include "string_utils.h"
#ifdef _WIN32
#include <Windows.h>
#endif

#define ORTX_RETURN_IF_ERROR(expr) \
do { \
auto _status = (expr); \
if (_status != nullptr) { \
return _status; \
} \
} while (0)

template <typename T>
bool TryParseStringWithClassicLocale(std::string_view str, T& value) {
if constexpr (std::is_integral<T>::value && std::is_unsigned<T>::value) {
// if T is unsigned integral type, reject negative values which will wrap
if (!str.empty() && str[0] == '-') {
return false;
}
}

// don't allow leading whitespace
if (!str.empty() && std::isspace(str[0], std::locale::classic())) {
return false;
}

std::istringstream is{std::string{str}};
is.imbue(std::locale::classic());
T parsed_value{};

const bool parse_successful =
is >> parsed_value &&
is.get() == std::istringstream::traits_type::eof(); // don't allow trailing characters
if (!parse_successful) {
return false;
}

value = std::move(parsed_value);
return true;
}

inline bool TryParseStringWithClassicLocale(std::string_view str, std::string& value) {
value = str;
return true;
}

inline bool TryParseStringWithClassicLocale(std::string_view str, bool& value) {
if (str == "0" || str == "False" || str == "false") {
value = false;
return true;
}

if (str == "1" || str == "True" || str == "true") {
value = true;
return true;
}

return false;
}

template <typename T>
std::optional<T> ParseEnvironmentVariable(const std::string& name) {
std::string buffer;
#ifdef _WIN32
constexpr size_t kBufferSize = 32767;

// Create buffer to hold the result
buffer.resize(kBufferSize, '\0');

// The last argument is the size of the buffer pointed to by the lpBuffer parameter, including the null-terminating character, in characters.
// If the function succeeds, the return value is the number of characters stored in the buffer pointed to by lpBuffer, not including the terminating null character.
// Therefore, If the function succeeds, kBufferSize should be larger than char_count.
auto char_count = GetEnvironmentVariableA(name.c_str(), buffer.data(), kBufferSize);

if (kBufferSize > char_count) {
buffer.resize(char_count);
} else {
// Else either the call was failed, or the buffer wasn't large enough.
// TODO: Understand the reason for failure by calling GetLastError().
// If it is due to the specified environment variable being found in the environment block,
// GetLastError() returns ERROR_ENVVAR_NOT_FOUND.
// For now, we assume that the environment variable is not found.
buffer.clear();
}
#else
char* val = getenv(name.c_str());
buffer = (val == nullptr) ? std::string() : std::string(val);
#endif
T parsed_value;
if (!TryParseStringWithClassicLocale(buffer, parsed_value)) {
OrtW::Exception(MakeString("Failed to parse environment variable - name: ", name, ", value: ", buffer), OrtErrorCode::ORT_FAIL);
}
return parsed_value;
}

template <typename T>
T ParseEnvironmentVariableWithDefault(const std::string& name, const T& default_value) {
const auto parsed = ParseEnvironmentVariable<T>(name);
if (parsed.has_value()) {
return *parsed;
}

return default_value;
}

inline bool IsScalarOr1ElementVector(size_t num_dimensions, int64_t shape_size) {
if (num_dimensions == 0 || (num_dimensions == 1 && shape_size == 1)) return true;
return false;
}
2 changes: 1 addition & 1 deletion base/string_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#pragma once
#include <sstream>
#include <vector>
#include "ocos.h"
#include "onnxruntime_cpp_api_legacy.hpp"

template <typename T>
inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {
Expand Down
2 changes: 1 addition & 1 deletion docs/development.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ The package contains all custom operators and some Python scripts to manipulate
- no-opencv: disable operators based on OpenCV in build.
- cc-debug: Generate debug info for extensions binaries and disable C/C++ compiler optimization.

For example:`pip install --config-settings "ortx-user-option=use-cuda,cc-debug" `, This command builds CUDA kernels into the package and installs it, accompanied by the generation of debug information.
For example:`pip install . --config-settings "ortx-user-option=use-cuda,cc-debug" `, This command builds CUDA kernels into the package and installs it, accompanied by the generation of debug information.

Test:

Expand Down
17 changes: 17 additions & 0 deletions includes/custom_op_lite.h
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,11 @@ struct Variadic : public TensorBase {

enum CudaResource {
cuda_handle_t = 10000,
cudnn_handle_t,
cublas_handle_t,
deferred_cpu_allocator_t,
// below are cuda ep options
device_id_t,
};

struct CudaContext {
Expand All @@ -595,8 +600,20 @@ struct CudaContext {
if (!cuda_stream) {
ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION);
}
ort_api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas);
if (!cublas) {
ORTX_CXX_API_THROW("Failed to fetch cublas handle from context", ORT_RUNTIME_EXCEPTION);
}
void* resource = nullptr;
OrtStatusPtr result = ort_api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::device_id_t, &resource);
if (result) {
ORTX_CXX_API_THROW("Failed to fetch device id from context", ORT_RUNTIME_EXCEPTION);
}
memcpy(&device_id, &resource, sizeof(int));
}
void* cuda_stream = {};
void* cublas = {};
int device_id = 0;
};

#endif
Expand Down
2 changes: 1 addition & 1 deletion includes/onnxruntime_cpp_api_legacy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#pragma once
#include <vector>
#include "onnxruntime_c_api.h"
#include "exceptions.h"

//
// DEPRECATED: All new custom OPs should not use any class/struct/functions from this file.
Expand Down
118 changes: 14 additions & 104 deletions includes/onnxruntime_customop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,118 +15,16 @@
#include <utility>
#include <type_traits>
#include <optional>
#include <functional>

#include "onnxruntime_c_api.h"
#include "exceptions.h"
#include "onnxruntime_no_customop.h"
#include "onnxruntime_cpp_api_legacy.hpp"
#include "onnxruntime_extensions.h"
#include "custom_op_lite.h"

#define MIN_ORT_VERSION_SUPPORTED 11

// namespace of ORT ABI Wrapper
namespace OrtW {

class API {
// To use ONNX C ABI in a way like OrtW::API::CreateStatus.
public:
static API& instance(const OrtApi* ort_api = nullptr) noexcept {
static API self(ort_api);
return self;
}

static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept {
return instance()->CreateStatus(code, msg);
}

static void ReleaseStatus(OrtStatusPtr ptr) noexcept {
instance()->ReleaseStatus(ptr);
}

template <typename T>
static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;

static void ThrowOnError(OrtStatusPtr ptr) {
OrtW::ThrowOnError(instance().api_, ptr);
}

private:
const OrtApi* operator->() const {
return &api_;
}

API(const OrtApi* api) : api_(*api) {
if (api == nullptr) {
ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
}
}

const OrtApi& api_;
};

template <>
inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
}

template <>
inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept {
return instance()->KernelInfoGetAttribute_float(&info, name, &value);
}

template <>
inline OrtStatusPtr API::KernelInfoGetAttribute<std::string>(const OrtKernelInfo& info, const char* name, std::string& value) noexcept {
size_t size = 0;
std::string out;
// Feed nullptr for the data buffer to query the true size of the string attribute
OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size);
if (status == nullptr) {
out.resize(size);
status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size);
out.resize(size - 1); // remove the terminating character '\0'
}

if (status == nullptr) {
value = std::move(out);
}

return status;
}

template <class T>
inline OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
// Ideally, we should know which kind of error code can be ignored, but it is not available now.
// Just ignore all of them.
API::ReleaseStatus(status);
}

return nullptr;
}

inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
return API::CreateStatus(code, msg);
}

inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) {
return API::CreateStatus(code, msg.c_str());
}

inline void ReleaseStatus(OrtStatusPtr& status) {
API::ReleaseStatus(status);
status = nullptr;
}

} // namespace OrtW

#define ORTX_RETURN_IF_ERROR(expr) \
do { \
auto _status = (expr); \
if (_status != nullptr) { \
return _status; \
} \
} while (0)

namespace Ort {
namespace Custom {

Expand Down Expand Up @@ -164,6 +62,12 @@ struct ComputeArgsList<OrtStatusPtr (C::*)(Args...) const> {
using MemberFunctionType = OrtStatusPtr (C::*)(Args...) const;
};

template <typename T, typename = void>
struct CustomOp_defined_getInputMemoryType : std::false_type {};

template <typename T>
struct CustomOp_defined_getInputMemoryType<T, std::void_t<decltype(&T::GetInputMemoryType)>> : std::true_type {};

template <typename CustomOpKernel>
struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
using ComputeFunction = decltype(&CustomOpKernel::Compute);
Expand Down Expand Up @@ -236,6 +140,12 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
OrtCustomOp::CreateKernel = nullptr;
OrtCustomOp::KernelCompute = nullptr;

if constexpr (CustomOp_defined_getInputMemoryType<CustomOpKernel>::value) {
OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* /*this_*/, size_t index) -> OrtMemType {
return CustomOpKernel::GetInputMemoryType(index);
};
}

OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_,
const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
if (api == nullptr) {
Expand Down
Loading

0 comments on commit 2234001

Please sign in to comment.