Skip to content

Commit

Permalink
new APIs for ORT-genai
Browse files Browse the repository at this point in the history
  • Loading branch information
jslhcl committed May 23, 2024
1 parent 474540d commit a43c782
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 1 deletion.
9 changes: 8 additions & 1 deletion include/custom_op/custom_op_lite.h
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ class OrtGraphCudaKernelContext : public CUDAKernelContext {
public:
static const int cuda_resource_ver = 1;

OrtGraphCudaKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) {
OrtGraphCudaKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api), kernel_context_(ctx) {
api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream_);
if (!cuda_stream_) {
ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION);
Expand Down Expand Up @@ -526,8 +526,15 @@ class OrtGraphCudaKernelContext : public CUDAKernelContext {
return device_id_;
}

void* GetScratchBufferUnderMultiStream(const OrtMemoryInfo* mem_info, size_t count_or_bytes) override {
void* ret = nullptr;
api_.KernelContext_GetScratchBuffer(&kernel_context_, mem_info, count_or_bytes, &ret);
return ret;
}

private:
const OrtApi& api_;
const OrtKernelContext& kernel_context_;
OrtAllocator* cpu_allocator_;
OrtAllocator* cuda_allocator_;
void* cuda_stream_ = {};
Expand Down
2 changes: 2 additions & 0 deletions include/custom_op/kernel_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <optional>
#include <numeric>
#include <type_traits>
#include "onnxruntime_c_api.h"

namespace Ort {
namespace Custom {
Expand All @@ -29,6 +30,7 @@ class CUDAKernelContext : public KernelContext {
virtual void* GetCudaStream() const = 0;
virtual void* GetCublasHandle() const = 0;
virtual int GetCudaDeviceId() const = 0;
virtual void* GetScratchBufferUnderMultiStream(const OrtMemoryInfo* , size_t ) { return nullptr; }
};
#endif

Expand Down
3 changes: 3 additions & 0 deletions include/ort_c_to_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ class API {
return instance()->KernelContext_GetAllocator(context, mem_info, out);
}
#endif
static void ReleaseMemoryInfo(OrtMemoryInfo* mem_info) {
return instance()->ReleaseMemoryInfo(mem_info);
}
private:
const OrtApi* operator->() const {
return &api_;
Expand Down

0 comments on commit a43c782

Please sign in to comment.