Skip to content

Commit

Permalink
experiment changes, working
Browse files Browse the repository at this point in the history
  • Loading branch information
jslhcl committed Feb 7, 2024
1 parent 44e494b commit f726034
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 3 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ endmacro()

if(OCOS_USE_CUDA)
include(ext_cuda)
include(cutlass)
endif()

#######################################################################################################################
Expand Down Expand Up @@ -557,6 +558,10 @@ target_include_directories(ocos_operators PUBLIC
${PROJECT_SOURCE_DIR}/base
${PROJECT_SOURCE_DIR}/operators)

if (OCOS_USE_CUDA)
target_include_directories(ocos_operators PUBLIC ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples)
endif()

set(ocos_libraries)
set(OCOS_COMPILE_DEFINITIONS)

Expand Down
11 changes: 11 additions & 0 deletions cmake/externals/cutlass.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Fetch the NVIDIA CUTLASS sources, pinned to the v3.1.0 release tag,
# at CMake configure time.
include(FetchContent)
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
GIT_TAG v3.1.0
)

# Download/populate only -- CUTLASS is consumed purely through its include
# directories (the consumer adds ${cutlass_SOURCE_DIR}/include and /examples),
# so FetchContent_MakeAvailable is deliberately NOT used here: it would run
# add_subdirectory() and pull CUTLASS's own build targets into this project.
FetchContent_GetProperties(cutlass)
if(NOT cutlass_POPULATED)
FetchContent_Populate(cutlass)
endif()
18 changes: 17 additions & 1 deletion includes/custom_op_lite.h
Original file line number Diff line number Diff line change
Expand Up @@ -585,18 +585,30 @@ struct Variadic : public TensorBase {

// Resource ids passed to OrtApi::KernelContext_GetResource to fetch CUDA
// execution-provider resources.
// NOTE(review): these values form an ABI contract with onnxruntime's own
// CudaResource enum (base value 10000) -- confirm they stay in sync with the
// onnxruntime headers whenever either side changes.
enum CudaResource {
cuda_handle_t = 10000,  // presumably the CUDA stream (stored into CudaContext::cuda_stream)
cudnn_handle_t,  // 10001
cublas_handle_t,  // 10002 (stored into CudaContext::cublas)
deferred_cpu_allocator_t,  // 10003
// below are cuda ep options
device_id_t, // 10004
};

// Holds per-kernel-invocation CUDA resources (stream and cuBLAS handle)
// fetched from the onnxruntime kernel context.
struct CudaContext {
  static const int cuda_resource_ver = 1;

  // Populates cuda_stream and cublas from the kernel context.
  //
  // If the very first resource query returns an error status (e.g. a runtime
  // that does not support resource queries), Init releases the status and
  // returns with the handles left null -- best effort, matching the original
  // early-return. If a query succeeds but yields a null handle, it throws
  // ORT_RUNTIME_EXCEPTION.
  //
  // Fix: OrtStatusPtr values returned by KernelContext_GetResource were
  // previously leaked (the first on the early return, the second ignored
  // entirely); every non-null status must be passed to ReleaseStatus.
  void Init(const OrtW::CustomOpApi& api, const OrtKernelContext& ctx) {
    const auto& ort_api = api.GetOrtApi();
    OrtStatusPtr status = ort_api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream);
    if (status) {
      ort_api.ReleaseStatus(status);  // was leaked before the early return
      return;
    }
    if (!cuda_stream) {
      ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION);
    }

    status = ort_api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas);
    if (status) {
      ort_api.ReleaseStatus(status);  // was silently ignored (and leaked) before
    }
    // On error, cublas stays null and we throw, as the original code did.
    if (!cublas) {
      ORTX_CXX_API_THROW("Failed to fetch cublas handle from context", ORT_RUNTIME_EXCEPTION);
    }
  }
  void* cuda_stream = {};  // opaque handle; consumers cast to cudaStream_t
  void* cublas = {};       // opaque handle; consumers cast to cublasHandle_t
};

#endif
Expand All @@ -623,6 +635,10 @@ struct OrtLiteCustomOp : public OrtCustomOp {
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const CudaContext&>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
//OrtMemoryInfo* memory_info = nullptr;
//(*api).GetOrtApi().CreateMemoryInfo("Cuda", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info);
//OrtAllocator* allocator = nullptr;
//(*api).GetOrtApi().KernelContext_GetAllocator(context, memory_info, &allocator);
thread_local CudaContext cuda_context;
cuda_context.Init(*api, *context);
std::tuple<T> current = std::tuple<const CudaContext&>{cuda_context};
Expand Down
18 changes: 18 additions & 0 deletions includes/onnxruntime_customop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,24 @@ class API {
instance()->ReleaseStatus(ptr);
}

// Thin wrapper over OrtApi::KernelContext_GetInputCount: writes the number of
// kernel inputs to *out. Returns nullptr on success; a non-null status must be
// released by the caller (see ReleaseStatus above in this class).
static OrtStatusPtr GetInputCount(const OrtKernelContext* context, size_t* out) noexcept {
return instance()->KernelContext_GetInputCount(context, out);
}

// Creates an OrtMemoryInfo via OrtApi::CreateMemoryInfo.
// Returns the new memory info, or nullptr on failure.
// Fix: the non-null OrtStatusPtr was previously dropped without being
// released, leaking the status object on every failure.
static OrtMemoryInfo* CreateOrtMemoryInfo(const char* name, enum OrtAllocatorType type, int id, enum OrtMemType mem_type) noexcept {
  OrtMemoryInfo* ret = nullptr;
  OrtStatusPtr status = instance()->CreateMemoryInfo(name, type, id, mem_type, &ret);
  if (status) {
    instance()->ReleaseStatus(status);  // must release; returning early leaked it
    return nullptr;
  }
  return ret;
}

// Fetches the allocator matching mem_info from the kernel context via
// OrtApi::KernelContext_GetAllocator. Returns nullptr on failure.
// Fixes: the error OrtStatusPtr was leaked, and the function lacked the
// noexcept marker its sibling wrappers in this class carry.
static OrtAllocator* GetOrtAllocator(const OrtKernelContext* context, const OrtMemoryInfo* mem_info) noexcept {
  OrtAllocator* ret = nullptr;
  OrtStatusPtr status = instance()->KernelContext_GetAllocator(context, mem_info, &ret);
  if (status) {
    instance()->ReleaseStatus(status);  // must release; returning early leaked it
    return nullptr;
  }
  return ret;
}

template <typename T>
static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;

Expand Down
34 changes: 32 additions & 2 deletions operators/contrib/cuda/fast_gelu.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
#pragma once
#include "ocos.h"
#include "fast_gelu_impl.cuh"
#include "cuda_fp16.h"
#include <cuda_fp16.h>
#include <cublas_v2.h>
//#include "41_fused_multi_head_attention/kernel_forward.h"
#include "cute/arch/copy_sm90_desc.hpp"

namespace contrib {

Expand All @@ -18,16 +21,43 @@ struct CudaT<ortc::MFloat16> {
using MappedType = half;
};

// unique_ptr with a type-erased deleter, so the deleter can capture an
// OrtAllocator* and hand the memory back to it on destruction.
template <typename T>
using IAllocatorUniquePtr = std::unique_ptr<T, std::function<void(T*)>>;

// Wraps an already-allocated raw buffer in an IAllocatorUniquePtr whose
// deleter returns the memory to the given OrtAllocator (allocator->Free).
// The allocator pointer must outlive the returned smart pointer.
// NOTE(review): the "Scrach" spelling is kept -- renaming would break callers.
template <typename T>
inline IAllocatorUniquePtr<T> GetScrachBuffer(void* p, OrtAllocator* allocator) {
  // Capture the allocator by plain value: std::move on a raw pointer is a
  // no-op copy (performance-move-const-arg). Also avoid shadowing param `p`.
  return IAllocatorUniquePtr<T>{static_cast<T*>(p), [allocator](T* buffer) {
                                  allocator->Free(allocator, buffer);
                                }};
}

template <typename T>
struct FastGelu {
// Called once when the custom op is attached to a model. FastGelu reads no
// attributes from the kernel info, so there is nothing to initialize;
// returning nullptr signals success (no error status).
OrtStatusPtr OnModelAttach(const OrtApi& /*api*/,
const OrtKernelInfo& /*info*/) {
return nullptr;
}
OrtStatusPtr Compute(const Ort::Custom::CudaContext& ctx,
OrtStatusPtr Compute(OrtKernelContext* kernel_context,
const Ort::Custom::CudaContext& ctx,
const ortc::Tensor<T>& input,
std::optional<const ortc::Tensor<T>*> bias,
ortc::Tensor<T>& output) const {
if (kernel_context == nullptr) return nullptr;
size_t input_count = 0;
auto hr = OrtW::API::GetInputCount(kernel_context, &input_count);
if (hr || input_count == 0) return nullptr;
cublasHandle_t cublas = reinterpret_cast<cublasHandle_t>(ctx.cublas);
if (!cublas) return nullptr;
//bool supportsDropout = AttentionKernel::kSupportsDropout;
auto value = cute::TMA::SmemSwizzleBits::B32;
OrtMemoryInfo* mem_info = OrtW::API::CreateOrtMemoryInfo("Cuda", OrtDeviceAllocator, 0, OrtMemTypeDefault);
OrtAllocator* allocator = OrtW::API::GetOrtAllocator(kernel_context, mem_info);
void* p_raw = allocator->Alloc(allocator, 3);
if (!p_raw) return nullptr;
{
IAllocatorUniquePtr<int> p = GetScrachBuffer<int>(p_raw, allocator);
}

const T* input_data = input.Data();
T* output_data = output.Allocate(input.Shape());
auto input_length = input.NumberOfElement();
Expand Down
1 change: 1 addition & 0 deletions test/cuda/test_cudaops.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def _create_negpos_test_model(domain='ai.onnx.contrib'):

def test_cuda_negpos(self):
so = _ort.SessionOptions()
print('lib_path:'+_get_library_path())
so.register_custom_ops_library(_get_library_path())
onnx_model = self._create_negpos_test_model()
self.assertIn('op_type: "NegPos"', str(onnx_model))
Expand Down

0 comments on commit f726034

Please sign in to comment.