Skip to content

Commit

Permalink
update custom op v2 struct to be able to invoke from eager mode (#700)
Browse files Browse the repository at this point in the history
Co-authored-by: Cheng Tang <[email protected]>
Co-authored-by: Wenbing Li <[email protected]>
  • Loading branch information
3 people authored Apr 30, 2024
1 parent 0175f90 commit 3b889fc
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 21 deletions.
39 changes: 37 additions & 2 deletions include/op_def_struct.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,35 @@ struct ComputeArgsList<RType (C::*)(Args...) const> {
using ResultType = RType;
};

// Compile-time detector: HasOnModelAttach<C, Ret(Args...)>::value is true
// exactly when C declares a member OnModelAttach callable with Args... whose
// return type is precisely Ret.
// Primary template: reached only when the second argument is not a function
// type, which is a usage error — fail loudly at compile time.
template <typename, typename T>
struct HasOnModelAttach {
  static_assert(
    std::integral_constant<T, false>::value,
    "Second template parameter needs to be of function type.");
};

// Specialization that performs the actual detection.

template <typename C, typename Ret, typename... Args>
struct HasOnModelAttach<C, Ret(Args...)> {
 private:
  // Preferred overload: well-formed only if the call expression compiles.
  // Yields std::true_type / std::false_type depending on whether the call's
  // return type matches Ret exactly.
  template <typename T>
  static constexpr auto Probe(T*) ->
      typename std::is_same<
          decltype(std::declval<T>().OnModelAttach(std::declval<Args>()...)),
          Ret>::type;

  // Fallback overload selected when the call expression is ill-formed.
  template <typename>
  static constexpr std::false_type Probe(...);

  using Result = decltype(Probe<C>(nullptr));

 public:
  static constexpr bool value = Result::value;
};

// Detection-idiom primary template: defaults to false_type. A partial
// specialization (outside this hunk — TODO confirm) flips to true_type when
// the custom op type defines GetInputMemoryType.
template <typename T, typename = void>
struct CustomOp_defined_getInputMemoryType : std::false_type {};

Expand All @@ -96,8 +125,14 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
// Initializes the kernel from ONNX Runtime's kernel info.
// Dispatches at compile time on which OnModelAttach overload the kernel
// defines: the classic (OrtApi, OrtKernelInfo) form, or the generic
// attribute-dictionary form that can also be invoked from eager mode.
// (The scraped diff had fused the deleted pre-commit body with the added
// one, leaving dead code after an unconditional return; this is the
// post-commit version.)
template <typename T>
static OrtStatusPtr InitKernel(KernelEx& kernel,
                               const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, T t) {
  if constexpr (HasOnModelAttach<KernelEx, OrtStatusPtr(const OrtApi&, const OrtKernelInfo&)>::value) {
    // Kernel implements the ORT-API based attach hook.
    auto status = kernel.OnModelAttach(api, info);
    return ToApiStatus(status);
  } else {
    // Kernel expects a dictionary-like attribute reader instead.
    auto status = kernel.OnModelAttach(OrtAttributeReader(api, info));
    return ToApiStatus(status);
  }
}

static OrtStatusPtr InitKernel(
Expand Down
4 changes: 2 additions & 2 deletions operators/contrib/cuda/fast_gelu.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ namespace contrib {

template <typename T>
struct FastGelu {
// Attach hook invoked once when the kernel is created. Takes any
// dictionary-like attribute container, so the op can be constructed both
// through the ORT kernel path and directly from eager mode.
// FastGelu reads no attributes here, so this is a no-op that reports success.
// (The scraped diff had fused the deleted (OrtApi, OrtKernelInfo) signature
// with the added templated one; this is the post-commit version.)
template <typename TDict>
OrtStatusPtr OnModelAttach(const TDict& /*dict*/) {
  return nullptr;
}
OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
Expand Down
71 changes: 54 additions & 17 deletions test/static_test/test_cuda_eager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,45 +9,50 @@

#ifdef USE_CUDA
#include "math/cuda/negpos_def.h"
#include "contrib/cuda/fast_gelu.h"
#include <cuda.h>
#include <cuda_runtime.h>


// Minimal device-memory allocator backed by cudaMalloc/cudaFree, used by the
// mock kernel context in these eager-mode tests.
class CudaAllocator : public Ort::Custom::IAllocator {
 public:
  // Returns a device pointer of `size` bytes (nullptr if cudaMalloc fails).
  void* Alloc(size_t size) override {
    void* device_ptr = nullptr;
    cudaMalloc(&device_ptr, size);
    return device_ptr;
  }
  // Releases a device pointer previously obtained from Alloc.
  void Free(void* p) override { cudaFree(p); }
};

class MockCudaKernelContext : public Ort::Custom::CUDAKernelContext {
public:
MockCudaKernelContext() { cudaStreamCreate(&stream); }
~MockCudaKernelContext() { cudaStreamDestroy(stream); }
void* AllocScratchBuffer(size_t size) override { return nullptr; }
void FreeScratchBuffer(void* p) override {}
void* AllocCudaScratchBuffer(size_t size) override { return nullptr; }
void FreeCudaScratchBuffer(void* p) override {}
void* AllocScratchBuffer(size_t size) override { return malloc(size); }
void FreeScratchBuffer(void* p) override { return free(p);}
void* AllocCudaScratchBuffer(size_t size) override { return cuda_alloc.Alloc(size); }
void FreeCudaScratchBuffer(void* p) override { return cuda_alloc.Free(p); }
void* GetCudaStream() const override { return static_cast<void*>(stream); }
void* GetCublasHandle() const override { return nullptr; }
int GetCudaDeviceId() const override { return 0; }

Ort::Custom::IAllocator* GetCudaAllocator() { return &cuda_alloc;};

private:
cudaStream_t stream;
};

class CudaAllocator : public Ort::Custom::IAllocator {
public:
void* Alloc(size_t size) override {
void* p = nullptr;
cudaMalloc((void**)&p, size);
return p;
}
void Free(void* p) override { cudaFree(p); }
CudaAllocator cuda_alloc;
};

TEST(CudaOp, test_eager_negpos) {
MockCudaKernelContext mock_cuda_kc;
std::vector<float> input_data = {0.0f, 0.2f, -1.3f, 1.5f};
std::unique_ptr<CudaAllocator> cuda_alloc = std::make_unique<CudaAllocator>();
auto cuda_alloc = mock_cuda_kc.GetCudaAllocator();
void* device_input = cuda_alloc->Alloc(sizeof(float) * input_data.size());
cudaMemcpyAsync(device_input, input_data.data(), sizeof(float)*input_data.size(), cudaMemcpyHostToDevice, static_cast<cudaStream_t>(mock_cuda_kc.GetCudaStream()));

ortc::Tensor<float> input(std::vector<int64_t>{2, 2}, device_input);
ortc::Tensor<float> output1(cuda_alloc.get());
ortc::Tensor<float> output2(cuda_alloc.get());
ortc::Tensor<float> output1(cuda_alloc);
ortc::Tensor<float> output2(cuda_alloc);
neg_pos_cuda(&mock_cuda_kc, input, output1, output2);

float* host_output1 = (float*)malloc(sizeof(float) * input_data.size());
Expand All @@ -63,4 +68,36 @@ TEST(CudaOp, test_eager_negpos) {
free(host_output2);
}

// Runs FastGelu in eager mode on a mock CUDA kernel context and checks the
// output against precomputed gelu(x + bias) values.
// Fixes vs. the original: synchronize the stream before reading the async
// device-to-host copy (otherwise the asserts race the copy), assert element 0
// as well, and release the host and device buffers.
TEST(CudaOp, test_fastgelu_eager) {
  MockCudaKernelContext mock_cuda_kc;
  cudaStream_t stream = static_cast<cudaStream_t>(mock_cuda_kc.GetCudaStream());

  // Stage inputs on the device.
  std::vector<float> x_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
  auto cuda_alloc = mock_cuda_kc.GetCudaAllocator();
  void* x_gpu_input = cuda_alloc->Alloc(sizeof(float) * x_data.size());
  cudaMemcpyAsync(x_gpu_input, x_data.data(), sizeof(float) * x_data.size(), cudaMemcpyHostToDevice, stream);

  std::vector<float> bias_data = {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f};
  void* bias_gpu_input = cuda_alloc->Alloc(sizeof(float) * bias_data.size());
  cudaMemcpyAsync(bias_gpu_input, bias_data.data(), sizeof(float) * bias_data.size(), cudaMemcpyHostToDevice, stream);

  // Attach via the eager-mode attribute dictionary (no half2 path).
  ortc::NamedArgumentDict dict({"use_half_2_"},
                               std::make_tuple(false));
  contrib::FastGelu<float> fast_gelu;
  fast_gelu.OnModelAttach(dict);

  ortc::Tensor<float> x(std::vector<int64_t>{6, }, x_gpu_input);
  ortc::Tensor<float> bias(std::vector<int64_t>{6, }, bias_gpu_input);
  ortc::Tensor<float> output(cuda_alloc);
  fast_gelu.Compute(&mock_cuda_kc, x, &bias, output);

  float* host_output = (float*)malloc(sizeof(float) * x_data.size());
  cudaMemcpyAsync(host_output, output.DataRaw(), sizeof(float) * x_data.size(), cudaMemcpyDeviceToHost, stream);
  // The copy above is asynchronous; wait for it before reading host_output.
  cudaStreamSynchronize(stream);
  ASSERT_NEAR(host_output[0], 0.0f, 0.01f);  // gelu(0 + 0) == 0
  ASSERT_NEAR(host_output[1], 0.9505811, 0.01f);
  ASSERT_NEAR(host_output[2], 2.1696784, 0.01f);
  ASSERT_NEAR(host_output[3], 3.298689, 0.01f);
  ASSERT_NEAR(host_output[4], 4.399991, 0.01f);
  ASSERT_NEAR(host_output[5], 5.5, 0.01f);

  // Release host and device buffers (the originals leaked all three).
  free(host_output);
  cuda_alloc->Free(bias_gpu_input);
  cuda_alloc->Free(x_gpu_input);
}

#endif

0 comments on commit 3b889fc

Please sign in to comment.