Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor ORT-Extension for the coming GroupQueryAttention work #674

Merged
merged 3 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions base/ortx_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <locale>
#include <optional>
#include <string>
#include <sstream>
#include "string_utils.h"
#ifdef _WIN32
#include <Windows.h>
#endif

#define ORTX_RETURN_IF_ERROR(expr) \
do { \
auto _status = (expr); \
if (_status != nullptr) { \
return _status; \
} \
} while (0)

template <typename T>
bool TryParseStringWithClassicLocale(std::string_view str, T& value) {
if constexpr (std::is_integral<T>::value && std::is_unsigned<T>::value) {
// if T is unsigned integral type, reject negative values which will wrap
if (!str.empty() && str[0] == '-') {
return false;
}
}

// don't allow leading whitespace
if (!str.empty() && std::isspace(str[0], std::locale::classic())) {
return false;
}

std::istringstream is{std::string{str}};
is.imbue(std::locale::classic());
T parsed_value{};

const bool parse_successful =
is >> parsed_value &&
is.get() == std::istringstream::traits_type::eof(); // don't allow trailing characters
if (!parse_successful) {
return false;
}

value = std::move(parsed_value);
return true;
}

inline bool TryParseStringWithClassicLocale(std::string_view str, std::string& value) {
value = str;
return true;
}

inline bool TryParseStringWithClassicLocale(std::string_view str, bool& value) {
if (str == "0" || str == "False" || str == "false") {
value = false;
return true;
}

if (str == "1" || str == "True" || str == "true") {
value = true;
return true;
}

return false;
}

template <typename T>
std::optional<T> ParseEnvironmentVariable(const std::string& name) {
std::string buffer;
#ifdef _WIN32
constexpr size_t kBufferSize = 32767;

// Create buffer to hold the result
buffer.resize(kBufferSize, '\0');

// The last argument is the size of the buffer pointed to by the lpBuffer parameter, including the null-terminating character, in characters.
// If the function succeeds, the return value is the number of characters stored in the buffer pointed to by lpBuffer, not including the terminating null character.
// Therefore, If the function succeeds, kBufferSize should be larger than char_count.
auto char_count = GetEnvironmentVariableA(name.c_str(), buffer.data(), kBufferSize);

if (kBufferSize > char_count) {
buffer.resize(char_count);
} else {
// Else either the call was failed, or the buffer wasn't large enough.
// TODO: Understand the reason for failure by calling GetLastError().
// If it is due to the specified environment variable being found in the environment block,
// GetLastError() returns ERROR_ENVVAR_NOT_FOUND.
// For now, we assume that the environment variable is not found.
buffer.clear();
}
#else
char* val = getenv(name.c_str());
buffer = (val == nullptr) ? std::string() : std::string(val);
#endif
T parsed_value;
if (!TryParseStringWithClassicLocale(buffer, parsed_value)) {
OrtW::Exception(MakeString("Failed to parse environment variable - name: ", name, ", value: ", buffer), OrtErrorCode::ORT_FAIL);
}
return parsed_value;
}

template <typename T>
T ParseEnvironmentVariableWithDefault(const std::string& name, const T& default_value) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where are these functions used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be used in the coming GroupQueryAttention code

const auto parsed = ParseEnvironmentVariable<T>(name);
if (parsed.has_value()) {
return *parsed;
}

return default_value;
}

inline bool IsScalarOr1ElementVector(size_t num_dimensions, int64_t shape_size) {
if (num_dimensions == 0 || (num_dimensions == 1 && shape_size == 1)) return true;
return false;
}
2 changes: 1 addition & 1 deletion base/string_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#pragma once
#include <sstream>
#include <vector>
#include "ocos.h"
#include "onnxruntime_cpp_api_legacy.hpp"

template <typename T>
inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {
Expand Down
2 changes: 1 addition & 1 deletion docs/development.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ The package contains all custom operators and some Python scripts to manipulate
- no-opencv: disable operators based on OpenCV in build.
- cc-debug: Generate debug info for extensions binaries and disable C/C++ compiler optimization.

For example:`pip install --config-settings "ortx-user-option=use-cuda,cc-debug" `, This command builds CUDA kernels into the package and installs it, accompanied by the generation of debug information.
For example:`pip install . --config-settings "ortx-user-option=use-cuda,cc-debug" `, This command builds CUDA kernels into the package and installs it, accompanied by the generation of debug information.

Test:

Expand Down
17 changes: 17 additions & 0 deletions includes/custom_op_lite.h
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,11 @@ struct Variadic : public TensorBase {

enum CudaResource {
cuda_handle_t = 10000,
cudnn_handle_t,
cublas_handle_t,
deferred_cpu_allocator_t,
// below are cuda ep options
device_id_t,
};

struct CudaContext {
Expand All @@ -595,8 +600,20 @@ struct CudaContext {
if (!cuda_stream) {
ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION);
}
ort_api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas);
if (!cublas) {
ORTX_CXX_API_THROW("Failed to fetch cublas handle from context", ORT_RUNTIME_EXCEPTION);
}
void* resource = nullptr;
OrtStatusPtr result = ort_api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::device_id_t, &resource);
if (result) {
ORTX_CXX_API_THROW("Failed to fetch device id from context", ORT_RUNTIME_EXCEPTION);
}
memcpy(&device_id, &resource, sizeof(int));
}
void* cuda_stream = {};
void* cublas = {};
int device_id = 0;
};

#endif
Expand Down
2 changes: 1 addition & 1 deletion includes/onnxruntime_cpp_api_legacy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#pragma once
#include <vector>
#include "onnxruntime_c_api.h"
#include "exceptions.h"

//
// DEPRECATED: All new custom OPs should not use any class/struct/functions from this file.
Expand Down
118 changes: 14 additions & 104 deletions includes/onnxruntime_customop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,118 +15,16 @@
#include <utility>
#include <type_traits>
#include <optional>
#include <functional>

#include "onnxruntime_c_api.h"
#include "exceptions.h"
#include "onnxruntime_no_customop.h"
#include "onnxruntime_cpp_api_legacy.hpp"
#include "onnxruntime_extensions.h"
#include "custom_op_lite.h"

#define MIN_ORT_VERSION_SUPPORTED 11

// namespace of ORT ABI Wrapper
namespace OrtW {

class API {
// To use ONNX C ABI in a way like OrtW::API::CreateStatus.
public:
static API& instance(const OrtApi* ort_api = nullptr) noexcept {
static API self(ort_api);
return self;
}

static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept {
return instance()->CreateStatus(code, msg);
}

static void ReleaseStatus(OrtStatusPtr ptr) noexcept {
instance()->ReleaseStatus(ptr);
}

template <typename T>
static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;

static void ThrowOnError(OrtStatusPtr ptr) {
OrtW::ThrowOnError(instance().api_, ptr);
}

private:
const OrtApi* operator->() const {
return &api_;
}

API(const OrtApi* api) : api_(*api) {
if (api == nullptr) {
ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
}
}

const OrtApi& api_;
};

template <>
inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
}

template <>
inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept {
return instance()->KernelInfoGetAttribute_float(&info, name, &value);
}

template <>
inline OrtStatusPtr API::KernelInfoGetAttribute<std::string>(const OrtKernelInfo& info, const char* name, std::string& value) noexcept {
size_t size = 0;
std::string out;
// Feed nullptr for the data buffer to query the true size of the string attribute
OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size);
if (status == nullptr) {
out.resize(size);
status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size);
out.resize(size - 1); // remove the terminating character '\0'
}

if (status == nullptr) {
value = std::move(out);
}

return status;
}

template <class T>
inline OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
// Ideally, we should know which kind of error code can be ignored, but it is not available now.
// Just ignore all of them.
API::ReleaseStatus(status);
}

return nullptr;
}

inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
return API::CreateStatus(code, msg);
}

inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) {
return API::CreateStatus(code, msg.c_str());
}

inline void ReleaseStatus(OrtStatusPtr& status) {
API::ReleaseStatus(status);
status = nullptr;
}

} // namespace OrtW

#define ORTX_RETURN_IF_ERROR(expr) \
do { \
auto _status = (expr); \
if (_status != nullptr) { \
return _status; \
} \
} while (0)

namespace Ort {
namespace Custom {

Expand Down Expand Up @@ -164,6 +62,12 @@ struct ComputeArgsList<OrtStatusPtr (C::*)(Args...) const> {
using MemberFunctionType = OrtStatusPtr (C::*)(Args...) const;
};

template <typename T, typename = void>
struct CustomOp_defined_getInputMemoryType : std::false_type {};

template <typename T>
struct CustomOp_defined_getInputMemoryType<T, std::void_t<decltype(&T::GetInputMemoryType)>> : std::true_type {};

template <typename CustomOpKernel>
struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
using ComputeFunction = decltype(&CustomOpKernel::Compute);
Expand Down Expand Up @@ -236,6 +140,12 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
OrtCustomOp::CreateKernel = nullptr;
OrtCustomOp::KernelCompute = nullptr;

if constexpr (CustomOp_defined_getInputMemoryType<CustomOpKernel>::value) {
OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* /*this_*/, size_t index) -> OrtMemType {
return CustomOpKernel::GetInputMemoryType(index);
};
}

OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_,
const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
if (api == nullptr) {
Expand Down
Loading
Loading