From 1c9c4a4476ed59f8ec6ad786aff817aaa8338705 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 6 Jun 2024 10:04:42 +0000 Subject: [PATCH 1/5] draf --- operators/cuda/roatry_impl.cuh | 15 +++++++ operators/cuda/rotary.h | 77 +++++++++++++++++++++++++++++++++ operators/cuda/rotary_impl.cu | 79 ++++++++++++++++++++++++++++++++++ 3 files changed, 171 insertions(+) create mode 100644 operators/cuda/roatry_impl.cuh create mode 100644 operators/cuda/rotary.h create mode 100644 operators/cuda/rotary_impl.cu diff --git a/operators/cuda/roatry_impl.cuh b/operators/cuda/roatry_impl.cuh new file mode 100644 index 000000000..9d50b5313 --- /dev/null +++ b/operators/cuda/roatry_impl.cuh @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include + +enum class RotarySide : int { + LEFT = 1, + RIGHT = 2, +}; + +template +cudaError_t LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim, + const T* input, const int64_t* split_data, T* output, RotarySide side); diff --git a/operators/cuda/rotary.h b/operators/cuda/rotary.h new file mode 100644 index 000000000..2ece1bf0e --- /dev/null +++ b/operators/cuda/rotary.h @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "ocos.h" +#include "rotary_impl.cuh" +#include "ortx_common.h" + +namespace contrib { + +template +struct Rotary { + template + OrtxStatus OnModelAttach(OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { + std::string side; + auto status = OrtW::GetOpAttribute(info, "side", side); + if (!status) { + return {kOrtxErrorInvalidArgument, "Missing or wrong argument side."}; + } + if (side == "left") { + side_ = RotarySide::LEFT; + } + else if (side == "right") { + side_ = RotarySide::RIGHT; + } + else { + return {kOrtxErrorInvalidArgument, "side must be 'left' or 'right'."}; + } + + return {}; + } + OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, + const ortc::Tensor& input, + const ortc::Tensor& split, + ortc::Tensor& output) const { + const T* input_data = input.Data(); + auto input_shape = input.Shape(); + T* output_data = output.Allocate(input_shape); + auto input_length = input.NumberOfElement(); + if (0 == input_length) { + return {}; + } + + auto shape_split = split.Shape(); + if (shape_split.size() != 1 || shape_split[0] != 2) { + return {kOrtxErrorInvalidArgument, "Rotary only works when there are two sides."}; + } + if (shape_split[0] != shape_split[1]) { + return {kOrtxErrorInvalidArgument, "Only equal split are allowed."}; + } + if (shape_split[0] * 2 != input_shape[input_shape.size()-1]) { + return {kOrtxErrorInvalidArgument, "Sum of the splits are not equal to the last dimension."}; + } + + const int64_t* split_data = split.Data(); + + LaunchRotaryKernel(reinterpret_cast(ctx->GetCudaStream()), + input_length, + static_cast(input_shape[input_shape.size()-1]), + input_data, + split_data, + output_data, + side_); + return {}; + } + + static OrtMemType GetInputMemoryType(size_t input_index) { + if (input_index == 1) // split + return OrtMemType::OrtMemTypeCPUInput; + return OrtMemType::OrtMemTypeDefault; + } + + private: + RotarySide side_; +}; + +} // namespace contrib \ No newline at end of file diff --git 
a/operators/cuda/rotary_impl.cu b/operators/cuda/rotary_impl.cu new file mode 100644 index 000000000..d8928b2d4 --- /dev/null +++ b/operators/cuda/rotary_impl.cu @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "device_prop.cuh" +#include "utils.cuh" +#include "Rotary_impl.cuh" +#include "cuda_type.h" + +using namespace Ort::Custom; + +template __device__ __inline__ T _neg(const T x) { return -x; } + +#if __CUDA_ARCH__ < 700 +template <> __device__ __inline__ half _neg(const half x) { + return __float2half(-__half2float(x)); +} +#endif + +template +__global__ void RotaryKernel(T *output_data, const T *input_data, CUDA_LONG half_N, CUDA_LONG half_stride) { + CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x; + if (id >= half_N) + return; + CUDA_LONG last = id % half_stride; + id = (id - last) * 2 + last; + if (side == RotarySide::RIGHT) { + output_data[id + half_stride] = input_data[id]; + output_data[id] = _neg(input_data[id + half_stride]); + } else { + output_data[id + half_stride] = _neg(input_data[id]); + output_data[id] = input_data[id + half_stride]; + } +} + +template +cudaError_t _LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim, + const T* input, const int64_t* split_data, T* output, RotarySide side) { + constexpr int blockSize = 256; + const int gridSize = (input_length + blockSize - 1) / blockSize; + if (input_length == 0) + return; + using TT = typename contrib::CudaT::MappedType; + + CUDA_LONG N = static_cast(count); + CUDA_LONG stride = static_cast(last_dim); + + const int num_threads_per_block = GridDim::maxThreadsPerBlock; + const int num_elements_per_thread = + (N / 2 + num_threads_per_block - 1) / num_threads_per_block; + + switch (side) { + case RotarySide::LEFT: + RotaryKernel + <<>>(output_data, input_data, + N / 2, stride / 2); + break; + case RotarySide::RIGHT: + RotaryKernel + <<>>(output_data, input_data, + N / 2, stride / 2); + break; + } 
+ + RotaryKernel<<>>(reinterpret_cast(output), reinterpret_cast(input), input_length); + return cudaGetLastError(); +} + +template <> +cudaError_t LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim, + const float* input, const int64_t* split_data, float* output, RotarySide side) { + return _LaunchRotaryKernel(stream, input_length, last_dim, input, split_data, output, side); +} + +template <> +cudaError_t LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim, + const ortc::MFloat16* input, const int64_t* split_data, + ortc::MFloat16* output, RotarySide side) { + return _LaunchRotaryKernel(stream, input_length, last_dim, input, split_data, output, side); +} From 52f351cb90d339fe2888387ad809932d6e592edc Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 6 Jun 2024 13:46:20 +0000 Subject: [PATCH 2/5] Fix implementation of Rotary --- operators/cuda/cuda_ops.cc | 5 +- operators/cuda/rotary.h | 16 ++--- operators/cuda/rotary_impl.cu | 38 +++++----- .../cuda/{roatry_impl.cuh => rotary_impl.cuh} | 2 +- test/cuda/test_cudaops.py | 69 +++++++++++++++++-- 5 files changed, 95 insertions(+), 35 deletions(-) rename operators/cuda/{roatry_impl.cuh => rotary_impl.cuh} (73%) diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc index f8269302a..3b0d572eb 100644 --- a/operators/cuda/cuda_ops.cc +++ b/operators/cuda/cuda_ops.cc @@ -7,6 +7,7 @@ #include "cuda/add_mul.h" #include "cuda/fast_gelu.h" #include "cuda/negxplus1.h" +#include "cuda/rotary.h" #endif FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { @@ -28,13 +29,15 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type), CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1), + CustomCudaStructV2("Rotary", contrib::Rotary), #if ORT_API_VERSION >= 16 CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type), 
CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type), - CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1) + CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1), + CustomCudaStructV2("Rotary", contrib::Rotary) #endif #endif ); diff --git a/operators/cuda/rotary.h b/operators/cuda/rotary.h index 2ece1bf0e..a365c2470 100644 --- a/operators/cuda/rotary.h +++ b/operators/cuda/rotary.h @@ -11,12 +11,9 @@ namespace contrib { template struct Rotary { template - OrtxStatus OnModelAttach(OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { - std::string side; - auto status = OrtW::GetOpAttribute(info, "side", side); - if (!status) { - return {kOrtxErrorInvalidArgument, "Missing or wrong argument side."}; - } + OrtxStatus OnModelAttach(const TDict& dict) { + std::string empty; + std::string side = dict.TryToGetAttributeWithDefault("side", empty); if (side == "left") { side_ = RotarySide::LEFT; } @@ -45,15 +42,14 @@ struct Rotary { if (shape_split.size() != 1 || shape_split[0] != 2) { return {kOrtxErrorInvalidArgument, "Rotary only works when there are two sides."}; } - if (shape_split[0] != shape_split[1]) { + const int64_t* split_data = split.Data(); + if (split_data[0] != split_data[1]) { return {kOrtxErrorInvalidArgument, "Only equal split are allowed."}; } - if (shape_split[0] * 2 != input_shape[input_shape.size()-1]) { + if (split_data[0] * 2 != input_shape[input_shape.size()-1]) { return {kOrtxErrorInvalidArgument, "Sum of the splits are not equal to the last dimension."}; } - const int64_t* split_data = split.Data(); - LaunchRotaryKernel(reinterpret_cast(ctx->GetCudaStream()), input_length, static_cast(input_shape[input_shape.size()-1]), diff --git a/operators/cuda/rotary_impl.cu b/operators/cuda/rotary_impl.cu index d8928b2d4..f177ebd87 100644 --- a/operators/cuda/rotary_impl.cu +++ b/operators/cuda/rotary_impl.cu @@ -3,9 +3,13 @@ #include 
"device_prop.cuh" #include "utils.cuh" -#include "Rotary_impl.cuh" +#include "rotary_impl.cuh" #include "cuda_type.h" +#ifndef CUDA_LONG +#define CUDA_LONG int32_t +#endif + using namespace Ort::Custom; template __device__ __inline__ T _neg(const T x) { return -x; } @@ -34,46 +38,44 @@ __global__ void RotaryKernel(T *output_data, const T *input_data, CUDA_LONG half template cudaError_t _LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim, - const T* input, const int64_t* split_data, T* output, RotarySide side) { - constexpr int blockSize = 256; - const int gridSize = (input_length + blockSize - 1) / blockSize; + const T* input_data, const int64_t* /* split_data */, T* output_data, RotarySide side) { if (input_length == 0) - return; + return cudaGetLastError(); using TT = typename contrib::CudaT::MappedType; - CUDA_LONG N = static_cast(count); + CUDA_LONG N = static_cast(input_length); CUDA_LONG stride = static_cast(last_dim); - const int num_threads_per_block = GridDim::maxThreadsPerBlock; + const int num_threads_per_block = 256; const int num_elements_per_thread = (N / 2 + num_threads_per_block - 1) / num_threads_per_block; switch (side) { case RotarySide::LEFT: - RotaryKernel - <<>>(output_data, input_data, + RotaryKernel + <<>>(reinterpret_cast(output_data), + reinterpret_cast(input_data), N / 2, stride / 2); break; case RotarySide::RIGHT: - RotaryKernel - <<>>(output_data, input_data, + RotaryKernel + <<>>(reinterpret_cast(output_data), + reinterpret_cast(input_data), N / 2, stride / 2); break; } - - RotaryKernel<<>>(reinterpret_cast(output), reinterpret_cast(input), input_length); return cudaGetLastError(); } template <> cudaError_t LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim, - const float* input, const int64_t* split_data, float* output, RotarySide side) { - return _LaunchRotaryKernel(stream, input_length, last_dim, input, split_data, output, side); + const float* input_data, const int64_t* split_data, float* 
output_data, RotarySide side) { + return _LaunchRotaryKernel(stream, input_length, last_dim, input_data, split_data, output_data, side); } template <> cudaError_t LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim, - const ortc::MFloat16* input, const int64_t* split_data, - ortc::MFloat16* output, RotarySide side) { - return _LaunchRotaryKernel(stream, input_length, last_dim, input, split_data, output, side); + const ortc::MFloat16* input_data, const int64_t* split_data, + ortc::MFloat16* output_data, RotarySide side) { + return _LaunchRotaryKernel(stream, input_length, last_dim, input_data, split_data, output_data, side); } diff --git a/operators/cuda/roatry_impl.cuh b/operators/cuda/rotary_impl.cuh similarity index 73% rename from operators/cuda/roatry_impl.cuh rename to operators/cuda/rotary_impl.cuh index 9d50b5313..11c7a116d 100644 --- a/operators/cuda/roatry_impl.cuh +++ b/operators/cuda/rotary_impl.cuh @@ -12,4 +12,4 @@ enum class RotarySide : int { template cudaError_t LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim, - const T* input, const int64_t* split_data, T* output, RotarySide side); + const T* input_data, const int64_t* split_data, T* output_data, RotarySide side); diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py index df1ab2c47..44e1f4d2d 100644 --- a/test/cuda/test_cudaops.py +++ b/test/cuda/test_cudaops.py @@ -151,8 +151,6 @@ def test_cuda_negxplus1(self): self._negxplus1_cuda(TensorProto.FLOAT16) def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, 2, 3), shapec=(3, 2, 3)): - from onnx_extended.ortops.optim.cuda import get_ort_ext_libs - model1 = helper.make_model( helper.make_graph( [ @@ -181,7 +179,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, f"{op_type}SharedInput", ["X", "Y", "Z"], ["XY", "XZ"], - domain="onnx_extended.ortops.optim.cuda", + domain="ai.onnx.contrib", ) ], "nd", @@ -197,7 +195,7 @@ def 
_addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, ), opset_imports=[ helper.make_opsetid("", 18), - helper.make_opsetid("onnx_extended.ortops.optim.cuda", 1), + helper.make_opsetid("ai.onnx.contrib", 1), ], ir_version=9, ) @@ -212,7 +210,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, expected = ref.run(None, feeds1) opts = _ort.SessionOptions() - opts.register_custom_ops_library(get_ort_ext_libs()[0]) + opts.register_custom_ops_library(_get_library_path()) sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"]) got = sess.run(None, feeds1) for i in range(2): @@ -262,6 +260,67 @@ def test_add_shared_input_cuda_broadcast2(self): shapec=(3, 2, 3), ) + def _rotary_cuda(self, itype, side, input_shape=(3, 2, 3, 4)): + model2 = helper.make_model( + helper.make_graph( + [ + helper.make_node( + "Rotary", + ["X", "splits"], + ["Y"], + domain="ai.onnx.contrib", + side=side, + ) + ], + "nd", + [ + helper.make_tensor_value_info("X", itype, [None, None, None, None]), + helper.make_tensor_value_info("splits", TensorProto.INT64, [2]), + ], + [helper.make_tensor_value_info("Y", itype, [None, None, None, None])], + ), + opset_imports=[ + helper.make_opsetid("", 18), + helper.make_opsetid("ai.onnx.contrib", 1), + ], + ir_version=9, + ) + + dtype = np.float32 if itype == TensorProto.FLOAT else np.float16 + x = (np.arange(np.prod(input_shape)) + 1).reshape(input_shape).astype(dtype) + splits = np.array([x.shape[-1] // 2, x.shape[-1] // 2], dtype=np.int64) + + expected = x.copy() + half = x.shape[-1] // 2 + if side == "left": + expected[:, :, :, :half] = x[:, :, :, half:] + expected[:, :, :, half:] = -x[:, :, :, :half] + else: + expected[:, :, :, :half] = -x[:, :, :, half:] + expected[:, :, :, half:] = x[:, :, :, :half] + + feeds = dict(X=x, splits=splits) + opts = _ort.SessionOptions() + opts.register_custom_ops_library(_get_library_path()) + sess = 
_ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"]) + got = sess.run(None, feeds)[0] + assert_almost_equal(expected, got) + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_rotary_cuda(self): + self._rotary_cuda(TensorProto.FLOAT, "left") + self._rotary_cuda(TensorProto.FLOAT, "right") + self._rotary_cuda(TensorProto.FLOAT16, "left") + self._rotary_cuda(TensorProto.FLOAT16, "right") + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_bigger_rotary_cuda(self): + sh = (2, 2, 1024, 8) + self._rotary_cuda(TensorProto.FLOAT, "left", input_shape=sh) + self._rotary_cuda(TensorProto.FLOAT, "right", input_shape=sh) + self._rotary_cuda(TensorProto.FLOAT16, "left", input_shape=sh) + self._rotary_cuda(TensorProto.FLOAT16, "right", input_shape=sh) + if __name__ == "__main__": unittest.main() From ec3887a4dcee3984c009e59d67e4ea54e38f31f3 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 11 Jun 2024 10:05:58 +0000 Subject: [PATCH 3/5] comment --- operators/cuda/rotary.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/operators/cuda/rotary.h b/operators/cuda/rotary.h index a365c2470..b5e27b877 100644 --- a/operators/cuda/rotary.h +++ b/operators/cuda/rotary.h @@ -8,6 +8,16 @@ namespace contrib { +/** +* Y = Rotary(X) is equivalent to if side == LEFT: +* +* N = X.shape[-1] +* Y = X.copy() +* Y[...,:N/2] = X[...,N/2:] +* Y[...,N/2:] = -X[...,:N/2] +* +* And the opposite if side == RIGHT. 
+*/ template struct Rotary { template From c406f42baf61651f59e4b6b7ec3a27d2b1e645d2 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Wed, 12 Jun 2024 10:31:37 +0200 Subject: [PATCH 4/5] fix merge conflicts --- operators/cuda/cuda_ops.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc index b42b65527..470d18153 100644 --- a/operators/cuda/cuda_ops.cc +++ b/operators/cuda/cuda_ops.cc @@ -8,11 +8,8 @@ #include "cuda/fast_gelu.h" #include "cuda/mul_sigmoid.h" #include "cuda/negxplus1.h" -<<<<<<< HEAD #include "cuda/rotary.h" -======= #include "cuda/scatter_nd_of_shape.h" ->>>>>>> f5055466d5376059c2ea74e3cea46e16a537bc0d #include "cuda/transpose_cast.h" #endif From dfeafa578b342589e8339db2f01fc00c58ea75b5 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Thu, 27 Jun 2024 10:07:07 +0200 Subject: [PATCH 5/5] minor fixes --- operators/cuda/rotary.h | 52 ++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/operators/cuda/rotary.h b/operators/cuda/rotary.h index b5e27b877..f143323ac 100644 --- a/operators/cuda/rotary.h +++ b/operators/cuda/rotary.h @@ -9,15 +9,20 @@ namespace contrib { /** -* Y = Rotary(X) is equivalent to if side == LEFT: -* -* N = X.shape[-1] -* Y = X.copy() -* Y[...,:N/2] = X[...,N/2:] -* Y[...,N/2:] = -X[...,:N/2] -* -* And the opposite if side == RIGHT. 
-*/ + * Y = Rotary(X) is equivalent to the following if side == LEFT: + * + * N = X.shape[-1] + * Y = X.copy() + * Y[...,:N/2] = X[...,N/2:] + * Y[...,N/2:] = -X[...,:N/2] + * + * And the opposite if side == RIGHT: + * + * N = X.shape[-1] + * Y = X.copy() + * Y[...,:N/2] = -X[...,N/2:] + * Y[...,N/2:] = X[...,:N/2] + */ template struct Rotary { template @@ -26,20 +31,17 @@ struct Rotary { std::string side = dict.TryToGetAttributeWithDefault("side", empty); if (side == "left") { side_ = RotarySide::LEFT; - } - else if (side == "right") { + } else if (side == "right") { side_ = RotarySide::RIGHT; - } - else { + } else { return {kOrtxErrorInvalidArgument, "side must be 'left' or 'right'."}; } return {}; } - OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, - const ortc::Tensor& input, - const ortc::Tensor& split, - ortc::Tensor& output) const { + + OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, const ortc::Tensor& input, + const ortc::Tensor& split, ortc::Tensor& output) const { const T* input_data = input.Data(); auto input_shape = input.Shape(); T* output_data = output.Allocate(input_shape); @@ -54,19 +56,15 @@ struct Rotary { } const int64_t* split_data = split.Data(); if (split_data[0] != split_data[1]) { - return {kOrtxErrorInvalidArgument, "Only equal split are allowed."}; + return {kOrtxErrorInvalidArgument, "Only equal split is allowed."}; } - if (split_data[0] * 2 != input_shape[input_shape.size()-1]) { + if (split_data[0] * 2 != input_shape[input_shape.size() - 1]) { return {kOrtxErrorInvalidArgument, "Sum of the splits are not equal to the last dimension."}; } - LaunchRotaryKernel(reinterpret_cast(ctx->GetCudaStream()), - input_length, - static_cast(input_shape[input_shape.size()-1]), - input_data, - split_data, - output_data, - side_); + LaunchRotaryKernel(reinterpret_cast(ctx->GetCudaStream()), input_length, + static_cast(input_shape[input_shape.size() - 1]), input_data, split_data, output_data, + side_); return {}; } @@ -76,7 +74,7 @@ 
struct Rotary { return OrtMemType::OrtMemTypeDefault; } - private: + private: RotarySide side_; };