diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc
index 8fc3105b9..607803bc4 100644
--- a/operators/cuda/cuda_ops.cc
+++ b/operators/cuda/cuda_ops.cc
@@ -9,6 +9,7 @@
 #include "cuda/mul_sigmoid.h"
 #include "cuda/negxplus1.h"
 #include "cuda/replace_zero.h"
+#include "cuda/rotary.h"
 #include "cuda/scatter_nd_of_shape.h"
 #include "cuda/transpose_cast.h"
 #endif
@@ -36,6 +37,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
       CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid),
       CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1),
       CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero),
+      CustomCudaStructV2("Rotary", contrib::Rotary),
       CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape),
 
 #if ORT_API_VERSION >= 16
@@ -48,6 +50,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
       CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid),
       CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1),
       CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero),
+      CustomCudaStructV2("Rotary", contrib::Rotary),
       CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape),
       CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
       CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
diff --git a/operators/cuda/rotary.h b/operators/cuda/rotary.h
new file mode 100644
index 000000000..f143323ac
--- /dev/null
+++ b/operators/cuda/rotary.h
@@ -0,0 +1,81 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "ocos.h"
+#include "rotary_impl.cuh"
+#include "ortx_common.h"
+
+namespace contrib {
+
+/**
+ * Y = Rotary(X) is equivalent to the following Python code if side == LEFT:
+ *
+ *   N = X.shape[-1]
+ *   Y = X.copy()
+ *   Y[..., :N/2] = X[..., N/2:]
+ *   Y[..., N/2:] = -X[..., :N/2]
+ *
+ * and to the opposite if side == RIGHT:
+ *
+ *   N = X.shape[-1]
+ *   Y = X.copy()
+ *   Y[..., :N/2] = -X[..., N/2:]
+ *   Y[..., N/2:] = X[..., :N/2]
+ */
+template <typename T>
+struct Rotary {
+  template <typename TDict>
+  OrtxStatus OnModelAttach(const TDict& dict) {
+    std::string empty;
+    std::string side = dict.TryToGetAttributeWithDefault("side", empty);
+    if (side == "left") {
+      side_ = RotarySide::LEFT;
+    } else if (side == "right") {
+      side_ = RotarySide::RIGHT;
+    } else {
+      return {kOrtxErrorInvalidArgument, "side must be 'left' or 'right'."};
+    }
+
+    return {};
+  }
+
+  OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, const ortc::Tensor<T>& input,
+                     const ortc::Tensor<int64_t>& split, ortc::Tensor<T>& output) const {
+    const T* input_data = input.Data();
+    auto input_shape = input.Shape();
+    T* output_data = output.Allocate(input_shape);
+    auto input_length = input.NumberOfElement();
+    if (0 == input_length) {
+      return {};
+    }
+
+    auto shape_split = split.Shape();
+    if (shape_split.size() != 1 || shape_split[0] != 2) {
+      return {kOrtxErrorInvalidArgument, "Rotary only works when there are two sides."};
+    }
+    const int64_t* split_data = split.Data();
+    if (split_data[0] != split_data[1]) {
+      return {kOrtxErrorInvalidArgument, "Only an equal split is allowed."};
+    }
+    if (split_data[0] + split_data[1] != input_shape[input_shape.size() - 1]) {
+      return {kOrtxErrorInvalidArgument, "The sum of the splits is not equal to the last dimension."};
+    }
+
+    LaunchRotaryKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()), input_length,
+                          static_cast<int>(input_shape[input_shape.size() - 1]), input_data, split_data,
+                          output_data, side_);
+    return {};
+  }
+
+  static OrtMemType GetInputMemoryType(size_t input_index) {
+    if (input_index == 1)  // the split sizes are validated on CPU before the kernel is launched
+      return OrtMemType::OrtMemTypeCPUInput;
+    return OrtMemType::OrtMemTypeDefault;
+  }
+
+ private:
+  RotarySide side_;
+};
+
+}  // namespace contrib
\ No newline at end of file
diff --git a/operators/cuda/rotary_impl.cu b/operators/cuda/rotary_impl.cu
new file mode 100644
index 000000000..f177ebd87
--- /dev/null
+++ b/operators/cuda/rotary_impl.cu
@@ -0,0 +1,81 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "device_prop.cuh"
+#include "utils.cuh"
+#include "rotary_impl.cuh"
+#include "cuda_type.h"
+
+#ifndef CUDA_LONG
+#define CUDA_LONG int32_t
+#endif
+
+using namespace Ort::Custom;
+
+template <typename T>
+__device__ __inline__ T _neg(const T x) { return -x; }
+
+#if __CUDA_ARCH__ < 700
+// half has no operator- before sm_70; negate through float instead.
+template <>
+__device__ __inline__ half _neg(const half x) {
+  return __float2half(-__half2float(x));
+}
+#endif
+
+template <typename T, RotarySide side>
+__global__ void RotaryKernel(T* output_data, const T* input_data, CUDA_LONG half_N, CUDA_LONG half_stride) {
+  CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
+  if (id >= half_N)
+    return;
+  // Map a flat index over the first halves back onto the full tensor layout,
+  // then swap the two halves of the last dimension, negating one of them.
+  CUDA_LONG last = id % half_stride;
+  id = (id - last) * 2 + last;
+  if (side == RotarySide::RIGHT) {
+    output_data[id + half_stride] = input_data[id];
+    output_data[id] = _neg(input_data[id + half_stride]);
+  } else {
+    output_data[id + half_stride] = _neg(input_data[id]);
+    output_data[id] = input_data[id + half_stride];
+  }
+}
+
+template <typename T>
+cudaError_t _LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim,
+                                const T* input_data, const int64_t* /* split_data */, T* output_data,
+                                RotarySide side) {
+  if (input_length == 0)
+    return cudaGetLastError();
+  using TT = typename contrib::CudaT<T>::MappedType;
+
+  CUDA_LONG N = static_cast<CUDA_LONG>(input_length);
+  CUDA_LONG stride = static_cast<CUDA_LONG>(last_dim);
+
+  // One thread per swapped pair of elements, i.e. N / 2 threads in total.
+  const int num_threads_per_block = 256;
+  const int num_blocks = (N / 2 + num_threads_per_block - 1) / num_threads_per_block;
+
+  switch (side) {
+    case RotarySide::LEFT:
+      RotaryKernel<TT, RotarySide::LEFT><<<num_blocks, num_threads_per_block, 0, stream>>>(
+          reinterpret_cast<TT*>(output_data), reinterpret_cast<const TT*>(input_data), N / 2, stride / 2);
+      break;
+    case RotarySide::RIGHT:
+      RotaryKernel<TT, RotarySide::RIGHT><<<num_blocks, num_threads_per_block, 0, stream>>>(
+          reinterpret_cast<TT*>(output_data), reinterpret_cast<const TT*>(input_data), N / 2, stride / 2);
+      break;
+  }
+  return cudaGetLastError();
+}
+
+template <>
+cudaError_t LaunchRotaryKernel<float>(cudaStream_t stream, int input_length, int last_dim,
+                                      const float* input_data, const int64_t* split_data, float* output_data,
+                                      RotarySide side) {
+  return _LaunchRotaryKernel(stream, input_length, last_dim, input_data, split_data, output_data, side);
+}
+
+template <>
+cudaError_t LaunchRotaryKernel<ortc::MFloat16>(cudaStream_t stream, int input_length, int last_dim,
+                                               const ortc::MFloat16* input_data, const int64_t* split_data,
+                                               ortc::MFloat16* output_data, RotarySide side) {
+  return _LaunchRotaryKernel(stream, input_length, last_dim, input_data, split_data, output_data, side);
+}
diff --git a/operators/cuda/rotary_impl.cuh b/operators/cuda/rotary_impl.cuh
new file mode 100644
index 000000000..11c7a116d
--- /dev/null
+++ b/operators/cuda/rotary_impl.cuh
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+enum class RotarySide : int {
+  LEFT = 1,
+  RIGHT = 2,
+};
+
+template <typename T>
+cudaError_t LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim,
+                               const T* input_data, const int64_t* split_data, T* output_data, RotarySide side);
diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py
index 43233a26b..740593cef 100644
--- a/test/cuda/test_cudaops.py
+++ b/test/cuda/test_cudaops.py
@@ -596,6 +596,67 @@ def test_masked_scatternd_of_shape_standalone_cuda_big(self):
         self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT, True)
         self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16, True)
 
+    def _rotary_cuda(self, itype, side, input_shape=(3, 2, 3, 4)):
+        model2 = helper.make_model(
+            helper.make_graph(
+                [
+                    helper.make_node(
+                        "Rotary",
+                        ["X", "splits"],
+                        ["Y"],
+                        domain="ai.onnx.contrib",
+                        side=side,
+                    )
+                ],
+                "nd",
+                [
+                    helper.make_tensor_value_info("X", itype, [None, None, None, None]),
+                    helper.make_tensor_value_info("splits", TensorProto.INT64, [2]),
+                ],
+                [helper.make_tensor_value_info("Y", itype, [None, None, None, None])],
+            ),
+            opset_imports=[
+                helper.make_opsetid("", 18),
+                helper.make_opsetid("ai.onnx.contrib", 1),
+            ],
+            ir_version=9,
+        )
+
+        dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
+        x = (np.arange(np.prod(input_shape)) + 1).reshape(input_shape).astype(dtype)
+        splits = np.array([x.shape[-1] // 2, x.shape[-1] // 2], dtype=np.int64)
+
+        expected = x.copy()
+        half = x.shape[-1] // 2
+        if side == "left":
+            expected[:, :, :, :half] = x[:, :, :, half:]
+            expected[:, :, :, half:] = -x[:, :, :, :half]
+        else:
+            expected[:, :, :, :half] = -x[:, :, :, half:]
+            expected[:, :, :, half:] = x[:, :, :, :half]
+
+        feeds = dict(X=x, splits=splits)
+        opts = _ort.SessionOptions()
+        opts.register_custom_ops_library(_get_library_path())
+        sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
+        got = sess.run(None, feeds)[0]
+        assert_almost_equal(expected, got)
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_rotary_cuda(self):
+        self._rotary_cuda(TensorProto.FLOAT, "left")
+        self._rotary_cuda(TensorProto.FLOAT, "right")
+        self._rotary_cuda(TensorProto.FLOAT16, "left")
+        self._rotary_cuda(TensorProto.FLOAT16, "right")
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_bigger_rotary_cuda(self):
+        sh = (2, 2, 1024, 8)
+        self._rotary_cuda(TensorProto.FLOAT, "left", input_shape=sh)
+        self._rotary_cuda(TensorProto.FLOAT, "right", input_shape=sh)
+        self._rotary_cuda(TensorProto.FLOAT16, "left", input_shape=sh)
+        self._rotary_cuda(TensorProto.FLOAT16, "right", input_shape=sh)
+
     def _transpose_cast_cuda(self, itype):
         dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
         itype2 = TensorProto.FLOAT if itype == TensorProto.FLOAT16 else TensorProto.FLOAT16
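For reference, a minimal standalone NumPy sketch of the transform the new operator computes, mirroring the doc-comment in rotary.h (the function name rotary_reference is illustrative and not part of this change):

    import numpy as np

    def rotary_reference(x: np.ndarray, side: str = "left") -> np.ndarray:
        # Swap the two halves of the last dimension, negating one half
        # depending on `side` (see the doc-comment in rotary.h).
        half = x.shape[-1] // 2
        y = np.empty_like(x)
        if side == "left":
            y[..., :half] = x[..., half:]
            y[..., half:] = -x[..., :half]
        else:  # side == "right"
            y[..., :half] = -x[..., half:]
            y[..., half:] = x[..., :half]
        return y

This is the same computation the expected values in test_cudaops.py are built from.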