diff --git a/operators/cuda/add_mul.h b/operators/cuda/add_mul.h
index a4f1fabaf..badac483e 100644
--- a/operators/cuda/add_mul.h
+++ b/operators/cuda/add_mul.h
@@ -8,6 +8,36 @@ namespace contrib {
 
+inline void _FillOutputShape3Op(std::vector<int64_t>& dimsA,
+                                std::vector<int64_t>& dimsB,
+                                std::vector<int64_t>& dimsC,
+                                std::vector<int64_t>& output_dims) {
+  auto max_rank = std::max(dimsA.size(), std::max(dimsB.size(), dimsC.size()));
+  while (dimsA.size() < max_rank)
+    dimsA.insert(dimsA.begin(), 1);
+  while (dimsB.size() < max_rank)
+    dimsB.insert(dimsB.begin(), 1);
+  while (dimsC.size() < max_rank)
+    dimsC.insert(dimsC.begin(), 1);
+
+  output_dims.resize(dimsA.size());
+  for (size_t i = 0; i < dimsA.size(); ++i) {
+    output_dims[i] = std::max(std::max(dimsA[i], dimsB[i]), dimsC[i]);
+    if (output_dims[i] == 0) {
+      ORTX_CXX_API_THROW("One of the input dimensions is zero.", ORT_RUNTIME_EXCEPTION);
+    }
+  }
+}
+
+/**
+* AddOrMulSharedInput(A, B, C) = A + B, A + C if addition is true
+* AddOrMulSharedInput(A, B, C) = A * B, A * C if addition is false
+*
+* The operator supports broadcasting on the leading dimensions only:
+* A[1, J] + B[I, J] is supported,
+* A[1, J, 1] + B[I, J, K] is not supported.
+* In all other cases, all tensors must have the same shape.
+*/
 template <typename T, bool addition>
 struct AddOrMulSharedInput {
   template <typename TDict>
@@ -20,22 +50,19 @@ struct AddOrMulSharedInput {
                      const ortc::Tensor<T>& tensor_c,
                      ortc::Tensor<T>& output_ab,
                      ortc::Tensor<T>& output_ac) const {
-    const T* input_data_a = tensor_a.Data();
-    const T* input_data_b = tensor_b.Data();
-    const T* input_data_c = tensor_c.Data();
-
     auto length_a = tensor_a.NumberOfElement();
     auto length_b = tensor_b.NumberOfElement();
     auto length_c = tensor_c.NumberOfElement();
 
+    if (0 == length_a || 0 == length_b || 0 == length_c) {
+      return {};
+    }
+
     T* output_data_ab = output_ab.Allocate(length_a <= length_b ? tensor_b.Shape() : tensor_a.Shape());
     T* output_data_ac = output_ac.Allocate(length_a <= length_c ? tensor_c.Shape() : tensor_a.Shape());
 
-    if (0 == input_data_a || 0 == input_data_b || 0 == input_data_c) {
-      return {};
-    }
     LaunchAddOrMulSharedInputKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
-                                       input_data_a, input_data_b, input_data_c,
+                                       tensor_a.Data(), tensor_b.Data(), tensor_c.Data(),
                                        output_data_ab, output_data_ac,
                                        length_a, length_b, length_c,
                                        addition);
@@ -43,4 +70,166 @@ struct AddOrMulSharedInput {
   }
 };
 
+/**
+* AddOrMulTwice(A, B, C) = A + B + C if addition is true
+* AddOrMulTwice(A, B, C) = A * B * C if addition is false
+*
+* The operator supports broadcasting on the leading dimensions only:
+* A[1, J] + B[I, J] is supported,
+* A[1, J, 1] + B[I, J, K] is not supported.
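+* For example, A[1, 2, 3] + B[3, 2, 3] + C[3, 2, 3] is supported: A is broadcast
+* over the first axis and the output has shape [3, 2, 3].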
+*/
+template <typename T, bool addition>
+struct AddOrMulTwice {
+  template <typename TDict>
+  OrtxStatus OnModelAttach(const TDict& /*dict*/) {
+    return {};
+  }
+  OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
+                     const ortc::Tensor<T>& tensor_a,
+                     const ortc::Tensor<T>& tensor_b,
+                     const ortc::Tensor<T>& tensor_c,
+                     ortc::Tensor<T>& output) const {
+    auto length_a = tensor_a.NumberOfElement();
+    auto length_b = tensor_b.NumberOfElement();
+    auto length_c = tensor_c.NumberOfElement();
+
+    if (0 == length_a || 0 == length_b || 0 == length_c) {
+      return {};
+    }
+
+    std::vector<int64_t> dimsA = tensor_a.Shape();
+    std::vector<int64_t> dimsB = tensor_b.Shape();
+    std::vector<int64_t> dimsC = tensor_c.Shape();
+    std::vector<int64_t> output_dims;
+    _FillOutputShape3Op(dimsA, dimsB, dimsC, output_dims);
+
+    T* output_data = output.Allocate(output_dims);
+
+    LaunchAddOrMulTwiceKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
+                                 tensor_a.Data(), tensor_b.Data(), tensor_c.Data(),
+                                 output_data,
+                                 length_a, length_b, length_c,
+                                 addition);
+    return {};
+  }
+};
+
+/**
+* AddAndMul(A, B, C) = (A + B) * C if addition_first is true
+* AddAndMul(A, B, C) = A * B + C if addition_first is false
+*
+* The operator supports broadcasting on the leading dimensions only:
+* A[1, J] + B[I, J] is supported,
+* A[1, J, 1] + B[I, J, K] is not supported.
+*
+* If switchMiddleAxis is true, the two middle axes of the 4D output are swapped:
+* AddAndMul(A, B, C, switchMiddleAxis=1) = Transpose((A + B) * C, perm=[0, 2, 1, 3])
+*/
+template <typename T, bool addition_first>
+struct AddAndMul {
+  template <typename TDict>
+  OrtxStatus OnModelAttach(const TDict& dict) {
+    int64_t default_value = 0;
+    switchMiddleAxis_ = dict.TryToGetAttributeWithDefault("switchMiddleAxis", default_value) == 1;
+    return {};
+  }
+  OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
+                     const ortc::Tensor<T>& tensor_a,
+                     const ortc::Tensor<T>& tensor_b,
+                     const ortc::Tensor<T>& tensor_c,
+                     ortc::Tensor<T>& output) const {
+    auto length_a = tensor_a.NumberOfElement();
+    auto length_b = tensor_b.NumberOfElement();
+    auto length_c = tensor_c.NumberOfElement();
+
+    if (0 == length_a || 0 == length_b || 0 == length_c) {
+      return {};
+    }
+
+    std::vector<int64_t> dimsA = tensor_a.Shape();
+    std::vector<int64_t> dimsB = tensor_b.Shape();
+    std::vector<int64_t> dimsC = tensor_c.Shape();
+    std::vector<int64_t> output_dims;
+    _FillOutputShape3Op(dimsA, dimsB, dimsC, output_dims);
+
+    if (switchMiddleAxis_) {
+      if (output_dims.size() != 4) {
+        ORTX_CXX_API_THROW("switchMiddleAxis only works with 4D tensors", ORT_RUNTIME_EXCEPTION);
+      }
+      int64_t d4 = output_dims[output_dims.size() - 1];
+      int64_t d3 = output_dims[output_dims.size() - 2];
+      int64_t d2 = output_dims[output_dims.size() - 3];
+      output_dims[1] = d3;
+      output_dims[2] = d2;
+      T* output_data = output.Allocate(output_dims);
+      LaunchAddAndMulSwitchMiddleAxesKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
+                                               tensor_a.Data(), tensor_b.Data(), tensor_c.Data(),
+                                               output_data,
+                                               length_a, length_b, length_c,
+                                               addition_first, d2, d3, d4);
+    } else {
+      T* output_data = output.Allocate(output_dims);
+      LaunchAddAndMulKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
+                               tensor_a.Data(), tensor_b.Data(), tensor_c.Data(),
+                               output_data,
+                               length_a, length_b, length_c,
+                               addition_first);
+    }
+    return {};
+  }
+
+ private:
+  bool switchMiddleAxis_;
+};
+
+/**
+* SubAndMul(A, B, C) = (A - B) * C if subtract_first is true
+* SubAndMul(A, B, C) = A * B - C if subtract_first is false
+*
+* The operator supports broadcasting on the leading dimensions only:
+* A[1, J] + B[I, J] is supported,
+* A[1, J, 1] + B[I, J, K] is not supported.
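+* For example, with subtract_first=true, A[1, 2, 3], B[3, 2, 3] and C[1, 2, 3] give
+* (A - B) * C with A and C broadcast over the first axis.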
+*
+* If negative is true, the subtraction is flipped:
+* SubAndMul(A, B, C, negative=1) = (B - A) * C if subtract_first is true,
+* SubAndMul(A, B, C, negative=1) = C - A * B if subtract_first is false.
+*/
+template <typename T, bool subtract_first>
+struct SubAndMul {
+  template <typename TDict>
+  OrtxStatus OnModelAttach(const TDict& dict) {
+    int64_t default_value = 0;
+    negative_ = dict.TryToGetAttributeWithDefault("negative", default_value) == 1;
+    return {};
+  }
+  OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
+                     const ortc::Tensor<T>& tensor_a,
+                     const ortc::Tensor<T>& tensor_b,
+                     const ortc::Tensor<T>& tensor_c,
+                     ortc::Tensor<T>& output) const {
+    auto length_a = tensor_a.NumberOfElement();
+    auto length_b = tensor_b.NumberOfElement();
+    auto length_c = tensor_c.NumberOfElement();
+    if (0 == length_a || 0 == length_b || 0 == length_c) {
+      return {};
+    }
+
+    std::vector<int64_t> dimsA = tensor_a.Shape();
+    std::vector<int64_t> dimsB = tensor_b.Shape();
+    std::vector<int64_t> dimsC = tensor_c.Shape();
+    std::vector<int64_t> output_dims;
+    _FillOutputShape3Op(dimsA, dimsB, dimsC, output_dims);
+    T* output_data = output.Allocate(output_dims);
+
+    LaunchSubAndMulKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
+                             tensor_a.Data(), tensor_b.Data(), tensor_c.Data(),
+                             output_data,
+                             length_a, length_b, length_c,
+                             subtract_first, negative_);
+    return {};
+  }
+
+ private:
+  bool negative_;
+};
+
 }  // namespace contrib
\ No newline at end of file
diff --git a/operators/cuda/add_mul_impl.cu b/operators/cuda/add_mul_impl.cu
index 85f55bc77..20919c467 100644
--- a/operators/cuda/add_mul_impl.cu
+++ b/operators/cuda/add_mul_impl.cu
@@ -12,12 +12,12 @@
 
 using namespace Ort::Custom;
 
-__device__ __forceinline__ void _add3_op(float* ab, float* ac, const float a, const float b, const float c) {
+__device__ __forceinline__ void _add3_2_op(float* ab, float* ac, const float a, const float b, const float c) {
   *ab = a + b;
   *ac = a + c;
 }
 
-__device__ __forceinline__ void _add3_op(half* ab, half* ac, const half a, const half b, const half c) {
+__device__ __forceinline__ void _add3_2_op(half* ab, half* ac, const half a, const half b, const half c) {
 #if __CUDA_ARCH__ < 700
   *ab = __float2half(__half2float(a) + __half2float(b));
   *ac = __float2half(__half2float(a) + __half2float(c));
@@ -27,12 +27,12 @@ __device__ __forceinline__ void _add3_op(half* ab, half* ac, const half a, const
 #endif
 }
 
-__device__ __forceinline__ void _mul3_op(float* ab, float* ac, const float a, const float b, const float c) {
+__device__ __forceinline__ void _mul3_2_op(float* ab, float* ac, const float a, const float b, const float c) {
   *ab = a * b;
   *ac = a * c;
 }
 
-__device__ __forceinline__ void _mul3_op(half* ab, half* ac, const half a, const half b, const half c) {
+__device__ __forceinline__ void _mul3_2_op(half* ab, half* ac, const half a, const half b, const half c) {
 #if __CUDA_ARCH__ < 700
   *ab = __float2half(__half2float(a) * __half2float(b));
   *ac = __float2half(__half2float(a) * __half2float(c));
@@ -45,21 +45,21 @@ __device__ __forceinline__ void _mul3_op(half* ab, half* ac, const half a, const
 
 template <typename T>
 struct Mul3SharedOp {
   __device__ __forceinline__ void operator()(T* ab, T* ac, const T a, const T b, const T c) const {
-    _mul3_op(ab, ac, a, b, c);
+    _mul3_2_op(ab, ac, a, b, c);
   }
 };
 
 template <typename T>
 struct Add3SharedOp {
   __device__ __forceinline__ void operator()(T* ab, T* ac, const T a, const T b, const T c) const {
-    _add3_op(ab, ac, a, b, c);
+    _add3_2_op(ab, ac, a, b, c);
   }
 };
 
 template <typename T, typename TFunc, int32_t NumThreadsPerBlock, int32_t NumElementsPerThread>
-__global__ void AddMulKernel(T* output_ab, T* output_ac, const T* pA, const T* pB,
-                             const T* pC, CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC,
-                             CUDA_LONG N, const TFunc func) {
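+// Renamed from AddMulKernel. Each thread reads its inputs modulo nA, nB and nC,
+// which implements the limited leading-dimension broadcast documented in add_mul.h.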
+__global__ void AddMulSharedInputKernel(T* output_ab, T* output_ac, const T* pA, const T* pB,
+                                        const T* pC, CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC,
+                                        CUDA_LONG N, const TFunc func) {
   CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
   CUDA_LONG id = start;
 #pragma unroll
@@ -89,14 +89,14 @@ cudaError_t _LaunchAddOrMulSharedInputKernel(cudaStream_t stream,
   using TT = typename contrib::CudaT<T>::MappedType;
 
   if (addition) {
-    AddMulKernel<TT, Add3SharedOp<TT>, num_threads_per_block, num_elements_per_thread>
+    AddMulSharedInputKernel<TT, Add3SharedOp<TT>, num_threads_per_block, num_elements_per_thread>
         <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
             reinterpret_cast<TT*>(output_ab), reinterpret_cast<TT*>(output_ac),
             reinterpret_cast<const TT*>(pA), reinterpret_cast<const TT*>(pB),
             reinterpret_cast<const TT*>(pC), static_cast<CUDA_LONG>(countA),
             static_cast<CUDA_LONG>(countB), static_cast<CUDA_LONG>(countC),
             static_cast<CUDA_LONG>(max_count), Add3SharedOp<TT>());
   } else {
-    AddMulKernel<TT, Mul3SharedOp<TT>, num_threads_per_block, num_elements_per_thread>
+    AddMulSharedInputKernel<TT, Mul3SharedOp<TT>, num_threads_per_block, num_elements_per_thread>
         <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
             reinterpret_cast<TT*>(output_ab), reinterpret_cast<TT*>(output_ac),
             reinterpret_cast<const TT*>(pA), reinterpret_cast<const TT*>(pB),
             reinterpret_cast<const TT*>(pC), static_cast<CUDA_LONG>(countA),
@@ -107,15 +107,515 @@ cudaError_t _LaunchAddOrMulSharedInputKernel(cudaStream_t stream,
 }
 
 template <>
-cudaError_t LaunchAddOrMulSharedInputKernel(cudaStream_t stream, const float* input_a, const float* input_b, const float* input_c,
+cudaError_t LaunchAddOrMulSharedInputKernel<float>(cudaStream_t stream,
+                                                   const float* input_a, const float* input_b, const float* input_c,
                                             float* output_ab, float* output_ac,
                                             int64_t length_a, int64_t length_b, int64_t length_c, bool addition) {
-  return _LaunchAddOrMulSharedInputKernel(stream, input_a, input_b, input_c, output_ab, output_ac, length_a, length_b, length_c, addition);
+  return _LaunchAddOrMulSharedInputKernel(stream, input_a, input_b, input_c,
+                                          output_ab, output_ac,
+                                          length_a, length_b, length_c, addition);
 }
 
 template <>
-cudaError_t LaunchAddOrMulSharedInputKernel(cudaStream_t stream, const ortc::MFloat16* input_a, const ortc::MFloat16* input_b, const ortc::MFloat16* input_c,
+cudaError_t LaunchAddOrMulSharedInputKernel<ortc::MFloat16>(cudaStream_t stream,
+                                                            const ortc::MFloat16* input_a, const ortc::MFloat16* input_b, const ortc::MFloat16* input_c,
                                             ortc::MFloat16* output_ab, ortc::MFloat16* output_ac,
                                             int64_t length_a, int64_t length_b, int64_t length_c, bool addition) {
-  return _LaunchAddOrMulSharedInputKernel(stream, input_a, input_b, input_c, output_ab, output_ac, length_a, length_b, length_c, addition);
-}
+  return _LaunchAddOrMulSharedInputKernel(stream, input_a, input_b, input_c,
+                                          output_ab, output_ac,
+                                          length_a, length_b, length_c, addition);
+}
+
+__device__ __forceinline__ void _add3_op(float* address, const float a, const float b,
+                                         const float c) {
+  *address = a + b + c;
+}
+
+__device__ __forceinline__ void _add3_op(half* address, const half a, const half b,
+                                         const half c) {
+#if __CUDA_ARCH__ < 700
+  *address = __float2half(__half2float(a) + __half2float(b) + __half2float(c));
+#else
+  *address = a + b + c;
+#endif
+}
+
+__device__ __forceinline__ void _mul3_op(float* address, const float a, const float b,
+                                         const float c) {
+  *address = a * b * c;
+}
+
+__device__ __forceinline__ void _mul3_op(half* address, const half a, const half b,
+                                         const half c) {
+#if __CUDA_ARCH__ < 700
+  *address = __float2half(__half2float(a) * __half2float(b) * __half2float(c));
+#else
+  *address = a * b * c;
+#endif
+}
+
+template <typename T>
+struct Mul3Op {
+  __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
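+    // Forwards to the matching float/half overload above (computes a * b * c).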
+    _mul3_op(address, a, b, c);
+  }
+};
+
+template <typename T>
+struct Add3Op {
+  __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
+    _add3_op(address, a, b, c);
+  }
+};
+
+template <typename T, typename TFunc, int32_t NumThreadsPerBlock, int32_t NumElementsPerThread>
+__global__ void AddMulTwiceKernel(T* output_data, const T* pA, const T* pB, const T* pC,
+                                  CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC, CUDA_LONG N,
+                                  const TFunc func) {
+  CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
+  CUDA_LONG id = start;
+#pragma unroll
+  for (int i = 0; i < NumElementsPerThread; i++) {
+    if (id < N) {
+      func(output_data + id, pA[id % nA], pB[id % nB], pC[id % nC]);
+      id += NumThreadsPerBlock;
+    }
+  }
+}
+
+template <typename T>
+cudaError_t _LaunchAddOrMulTwiceKernel(cudaStream_t stream,
+                                       const T* pA, const T* pB, const T* pC,
+                                       T* output,
+                                       int64_t countA, int64_t countB, int64_t countC, bool addition) {
+  int64_t max_count = std::max(std::max(countA, countB), countC);
+  if (max_count == 0)  // special case where there's a dim value of 0 in the output shape
+    return cudaGetLastError();
+
+  const int num_elements_per_thread = 4;
+  const int num_threads_per_block = 256;
+  const int num_el_th = num_threads_per_block * num_elements_per_thread;
+
+  int blocksPerGrid = (max_count + num_el_th - 1) / num_el_th;
+
+  using TT = typename contrib::CudaT<T>::MappedType;
+
+  if (addition) {
+    AddMulTwiceKernel<TT, Add3Op<TT>, num_threads_per_block, num_elements_per_thread>
+        <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
+            reinterpret_cast<TT*>(output),
+            reinterpret_cast<const TT*>(pA), reinterpret_cast<const TT*>(pB), reinterpret_cast<const TT*>(pC),
+            static_cast<CUDA_LONG>(countA), static_cast<CUDA_LONG>(countB), static_cast<CUDA_LONG>(countC),
+            static_cast<CUDA_LONG>(max_count), Add3Op<TT>());
+  } else {
+    AddMulTwiceKernel<TT, Mul3Op<TT>, num_threads_per_block, num_elements_per_thread>
+        <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
+            reinterpret_cast<TT*>(output),
+            reinterpret_cast<const TT*>(pA), reinterpret_cast<const TT*>(pB), reinterpret_cast<const TT*>(pC),
+            static_cast<CUDA_LONG>(countA), static_cast<CUDA_LONG>(countB), static_cast<CUDA_LONG>(countC),
+            static_cast<CUDA_LONG>(max_count), Mul3Op<TT>());
+  }
+  return cudaGetLastError();
+}
+
+template <>
+cudaError_t LaunchAddOrMulTwiceKernel<float>(cudaStream_t stream,
+                                             const float* input_a, const float* input_b, const float* input_c,
+                                             float* output,
+                                             int64_t length_a, int64_t length_b, int64_t length_c,
+                                             bool addition) {
+  return _LaunchAddOrMulTwiceKernel(stream, input_a, input_b, input_c,
+                                    output,
+                                    length_a, length_b, length_c, addition);
+}
+
+template <>
+cudaError_t LaunchAddOrMulTwiceKernel<ortc::MFloat16>(cudaStream_t stream,
+                                                      const ortc::MFloat16* input_a, const ortc::MFloat16* input_b,
+                                                      const ortc::MFloat16* input_c,
+                                                      ortc::MFloat16* output,
+                                                      int64_t length_a, int64_t length_b, int64_t length_c,
+                                                      bool addition) {
+  return _LaunchAddOrMulTwiceKernel(stream, input_a, input_b, input_c,
+                                    output,
+                                    length_a, length_b, length_c, addition);
+}
+
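+// Fused per-element ops for the AddMul/MulAdd operators below: computing
+// (a + b) * c or a * b + c in one kernel avoids materializing the intermediate tensor.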
+__device__ __forceinline__ void _addmul_op(float* address, const float a, const float b,
+                                           const float c) {
+  *address = (a + b) * c;
+}
+
+__device__ __forceinline__ void _addmul_op(half* address, const half a, const half b,
+                                           const half c) {
+#if __CUDA_ARCH__ < 700
+  *address = __float2half((__half2float(a) + __half2float(b)) * __half2float(c));
+#else
+  *address = (a + b) * c;
+#endif
+}
+
+__device__ __forceinline__ void _muladd_op(float* address, const float a, const float b,
+                                           const float c) {
+  *address = a * b + c;
+}
+
+__device__ __forceinline__ void _muladd_op(half* address, const half a, const half b,
+                                           const half c) {
+#if __CUDA_ARCH__ < 700
+  *address = __float2half(__half2float(a) * __half2float(b) + __half2float(c));
+#else
+  *address = a * b + c;
+#endif
+}
+
+template <typename T>
+struct AddMul {
+  __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
+    _addmul_op(address, a, b, c);
+  }
+};
+
+template <typename T>
+struct MulAdd {
+  __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
+    _muladd_op(address, a, b, c);
+  }
+};
+
+template <typename T, typename TFunc, int32_t NumThreadsPerBlock, int32_t NumElementsPerThread>
+__global__ void AddAndMulKernel(T* output_data, const T* pA, const T* pB, const T* pC,
+                                CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC, CUDA_LONG N,
+                                const TFunc func) {
+  CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
+  CUDA_LONG id = start;
+#pragma unroll
+  for (int i = 0; i < NumElementsPerThread; i++) {
+    if (id < N) {
+      func(output_data + id, pA[id % nA], pB[id % nB], pC[id % nC]);
+      id += NumThreadsPerBlock;
+    }
+  }
+}
+
+template <typename T, typename TFunc, int32_t NumThreadsPerBlock, int32_t NumElementsPerThread>
+__global__ void AddAndMulSwitchMiddleAxesKernel(T* output_data, const T* pA, const T* pB,
+                                                const T* pC, CUDA_LONG nA, CUDA_LONG nB,
+                                                CUDA_LONG nC, CUDA_LONG N,
+                                                const TFunc func, CUDA_LONG d2,
+                                                CUDA_LONG d3, CUDA_LONG d4) {
+  CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
+  CUDA_LONG id = start;
+  CUDA_LONG k, j, ido;
+#pragma unroll
+  for (int i = 0; i < NumElementsPerThread; i++) {
+    if (id < N) {
+      k = (id / d4) % d3;
+      j = (id / (d4 * d3)) % d2;
+      ido = id + d4 * ((k * d2 + j) - (j * d3 + k));
+      func(output_data + ido, pA[id % nA], pB[id % nB], pC[id % nC]);
+      id += NumThreadsPerBlock;
+    }
+  }
+}
+
+template <typename T>
+cudaError_t _LaunchAddAndMulKernel(cudaStream_t stream,
+                                   const T* pA, const T* pB, const T* pC,
+                                   T* output,
+                                   int64_t countA, int64_t countB, int64_t countC,
+                                   bool addition_first) {
+  int64_t max_count = std::max(std::max(countA, countB), countC);
+  if (max_count == 0)  // special case where there's a dim value of 0 in the output shape
+    return cudaGetLastError();
+
+  const int num_elements_per_thread = 4;
+  const int num_threads_per_block = 256;
+  const int num_el_th = num_threads_per_block * num_elements_per_thread;
+
+  int blocksPerGrid = (max_count + num_el_th - 1) / num_el_th;
+
+  using TT = typename contrib::CudaT<T>::MappedType;
+
+  if (addition_first) {
+    AddAndMulKernel<TT, AddMul<TT>, num_threads_per_block, num_elements_per_thread>
+        <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
+            reinterpret_cast<TT*>(output),
+            reinterpret_cast<const TT*>(pA),
+            reinterpret_cast<const TT*>(pB),
+            reinterpret_cast<const TT*>(pC),
+            static_cast<CUDA_LONG>(countA),
+            static_cast<CUDA_LONG>(countB),
+            static_cast<CUDA_LONG>(countC),
+            static_cast<CUDA_LONG>(max_count),
+            AddMul<TT>());
+  } else {
+    AddAndMulKernel<TT, MulAdd<TT>, num_threads_per_block, num_elements_per_thread>
+        <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
+            reinterpret_cast<TT*>(output),
+            reinterpret_cast<const TT*>(pA),
+            reinterpret_cast<const TT*>(pB),
+            reinterpret_cast<const TT*>(pC),
+            static_cast<CUDA_LONG>(countA),
+            static_cast<CUDA_LONG>(countB),
+            static_cast<CUDA_LONG>(countC),
+            static_cast<CUDA_LONG>(max_count),
+            MulAdd<TT>());
+  }
+  return cudaGetLastError();
+}
+
+template <>
+cudaError_t LaunchAddAndMulKernel<float>(cudaStream_t stream, const float* input_a, const float* input_b, const float* input_c,
+                                         float* output, int64_t length_a, int64_t length_b, int64_t length_c,
+                                         bool addition_first) {
+  return _LaunchAddAndMulKernel(stream, input_a, input_b, input_c, output, length_a, length_b, length_c, addition_first);
+}
+
+template <>
+cudaError_t LaunchAddAndMulKernel<ortc::MFloat16>(cudaStream_t stream,
+                                                  const ortc::MFloat16* input_a, const ortc::MFloat16* input_b,
+                                                  const ortc::MFloat16* input_c,
+                                                  ortc::MFloat16* output, int64_t length_a, int64_t length_b, int64_t length_c,
+                                                  bool addition_first) {
+  return _LaunchAddAndMulKernel(stream, input_a, input_b, input_c, output, length_a, length_b, length_c, addition_first);
+}
+
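+// Launches AddAndMulSwitchMiddleAxesKernel, which writes each result directly at its
+// [0, 2, 1, 3]-transposed offset: for element id with middle-axis coordinates (j, k),
+// the output offset swaps j and k, so no separate transpose pass is needed.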
+template <typename T>
+cudaError_t _LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream,
+                                                   const T* pA, const T* pB, const T* pC,
+                                                   T* output,
+                                                   int64_t countA, int64_t countB, int64_t countC,
+                                                   bool addition_first, int64_t d2, int64_t d3, int64_t d4) {
+  int64_t max_count = std::max(std::max(countA, countB), countC);
+  if (max_count == 0)  // special case where there's a dim value of 0 in the output shape
+    return cudaGetLastError();
+
+  const int num_elements_per_thread = 4;
+  const int num_threads_per_block = 256;
+  const int num_el_th = num_threads_per_block * num_elements_per_thread;
+
+  int blocksPerGrid = (max_count + num_el_th - 1) / num_el_th;
+
+  using TT = typename contrib::CudaT<T>::MappedType;
+
+  if (addition_first) {
+    AddAndMulSwitchMiddleAxesKernel<TT, AddMul<TT>, num_threads_per_block, num_elements_per_thread>
+        <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
+            reinterpret_cast<TT*>(output),
+            reinterpret_cast<const TT*>(pA),
+            reinterpret_cast<const TT*>(pB),
+            reinterpret_cast<const TT*>(pC),
+            static_cast<CUDA_LONG>(countA),
+            static_cast<CUDA_LONG>(countB),
+            static_cast<CUDA_LONG>(countC),
+            static_cast<CUDA_LONG>(max_count),
+            AddMul<TT>(), d2, d3, d4);
+  } else {
+    AddAndMulSwitchMiddleAxesKernel<TT, MulAdd<TT>, num_threads_per_block, num_elements_per_thread>
+        <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
+            reinterpret_cast<TT*>(output),
+            reinterpret_cast<const TT*>(pA),
+            reinterpret_cast<const TT*>(pB),
+            reinterpret_cast<const TT*>(pC),
+            static_cast<CUDA_LONG>(countA),
+            static_cast<CUDA_LONG>(countB),
+            static_cast<CUDA_LONG>(countC),
+            static_cast<CUDA_LONG>(max_count),
+            MulAdd<TT>(), d2, d3, d4);
+  }
+  return cudaGetLastError();
+}
+
+template <>
+cudaError_t LaunchAddAndMulSwitchMiddleAxesKernel<float>(cudaStream_t stream, const float* input_a, const float* input_b, const float* input_c,
+                                                         float* output, int64_t length_a, int64_t length_b, int64_t length_c,
+                                                         bool addition_first,
+                                                         int64_t d2, int64_t d3, int64_t d4) {
+  return _LaunchAddAndMulSwitchMiddleAxesKernel(stream, input_a, input_b, input_c, output, length_a, length_b, length_c,
+                                                addition_first, d2, d3, d4);
+}
+
+template <>
+cudaError_t LaunchAddAndMulSwitchMiddleAxesKernel<ortc::MFloat16>(cudaStream_t stream, const ortc::MFloat16* input_a,
+                                                                  const ortc::MFloat16* input_b, const ortc::MFloat16* input_c,
+                                                                  ortc::MFloat16* output, int64_t length_a, int64_t length_b, int64_t length_c,
+                                                                  bool addition_first,
+                                                                  int64_t d2, int64_t d3, int64_t d4) {
+  return _LaunchAddAndMulSwitchMiddleAxesKernel(stream, input_a, input_b, input_c, output, length_a, length_b, length_c,
+                                                addition_first, d2, d3, d4);
+}
+
+__device__ __forceinline__ void _submul_op(float* address, const float a, const float b,
+                                           const float c) {
+  *address = (a - b) * c;
+}
+
+__device__ __forceinline__ void _submul_op(half* address, const half a, const half b,
+                                           const half c) {
+#if __CUDA_ARCH__ < 700
+  *address = __float2half((__half2float(a) - __half2float(b)) * __half2float(c));
+#else
+  *address = (a - b) * c;
+#endif
+}
+
+__device__ __forceinline__ void _submul_neg_op(float* address, const float a, const float b,
+                                               const float c) {
+  *address = (b - a) * c;
+}
+
+__device__ __forceinline__ void _submul_neg_op(half* address, const half a, const half b,
+                                               const half c) {
+#if __CUDA_ARCH__ < 700
+  *address = __float2half((__half2float(b) - __half2float(a)) * __half2float(c));
+#else
+  *address = (b - a) * c;
+#endif
+}
+
+__device__ __forceinline__ void _mulsub_op(float* address, const float a, const float b,
+                                           const float c) {
+  *address = a * b - c;
+}
+
+__device__ __forceinline__ void _mulsub_op(half* address, const half a, const half b,
+                                           const half c) {
+#if __CUDA_ARCH__ < 700
+  *address = __float2half(__half2float(a) * __half2float(b) - __half2float(c));
+#else
+  *address = a * b - c;
+#endif
+}
+
+__device__ __forceinline__ void _mulsub_neg_op(float* address, const float a, const float b,
+                                               const float c) {
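+  // With negative=1 the subtraction is flipped: computes c - a * b instead of a * b - c.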
+  *address = c - a * b;
+}
+
+__device__ __forceinline__ void _mulsub_neg_op(half* address, const half a, const half b,
+                                               const half c) {
+#if __CUDA_ARCH__ < 700
+  *address = __float2half(__half2float(c) - __half2float(a) * __half2float(b));
+#else
+  *address = c - a * b;
+#endif
+}
+
+template <typename T>
+struct SubMul {
+  __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
+    _submul_op(address, a, b, c);
+  }
+};
+
+template <typename T>
+struct MulSub {
+  __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
+    _mulsub_op(address, a, b, c);
+  }
+};
+
+template <typename T>
+struct SubMulNeg {
+  __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
+    _submul_neg_op(address, a, b, c);
+  }
+};
+
+template <typename T>
+struct MulSubNeg {
+  __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
+    _mulsub_neg_op(address, a, b, c);
+  }
+};
+
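+// The launcher below reuses AddAndMulKernel with these four functors: only the
+// per-element operation differs, the indexing and launch configuration are identical.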
+template <typename T>
+cudaError_t _LaunchSubAndMulKernel(cudaStream_t stream,
+                                   const T* pA, const T* pB, const T* pC,
+                                   T* output,
+                                   int64_t countA, int64_t countB, int64_t countC,
+                                   bool subtract_first, bool negative) {
+  int64_t max_count = std::max(std::max(countA, countB), countC);
+  if (max_count == 0)  // special case where there's a dim value of 0 in the output shape
+    return cudaGetLastError();
+
+  const int num_elements_per_thread = 4;
+  const int num_threads_per_block = 256;
+  const int num_el_th = num_threads_per_block * num_elements_per_thread;
+
+  int blocksPerGrid = (max_count + num_el_th - 1) / num_el_th;
+
+  using TT = typename contrib::CudaT<T>::MappedType;
+
+  if (subtract_first) {
+    if (negative) {
+      AddAndMulKernel<TT, SubMulNeg<TT>, num_threads_per_block, num_elements_per_thread>
+          <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
+              reinterpret_cast<TT*>(output),
+              reinterpret_cast<const TT*>(pA),
+              reinterpret_cast<const TT*>(pB),
+              reinterpret_cast<const TT*>(pC),
+              static_cast<CUDA_LONG>(countA),
+              static_cast<CUDA_LONG>(countB),
+              static_cast<CUDA_LONG>(countC),
+              static_cast<CUDA_LONG>(max_count),
+              SubMulNeg<TT>());
+    } else {
+      AddAndMulKernel<TT, SubMul<TT>, num_threads_per_block, num_elements_per_thread>
+          <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
+              reinterpret_cast<TT*>(output),
+              reinterpret_cast<const TT*>(pA),
+              reinterpret_cast<const TT*>(pB),
+              reinterpret_cast<const TT*>(pC),
+              static_cast<CUDA_LONG>(countA),
+              static_cast<CUDA_LONG>(countB),
+              static_cast<CUDA_LONG>(countC),
+              static_cast<CUDA_LONG>(max_count),
+              SubMul<TT>());
+    }
+  } else {
+    if (negative) {
+      AddAndMulKernel<TT, MulSubNeg<TT>, num_threads_per_block, num_elements_per_thread>
+          <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
+              reinterpret_cast<TT*>(output),
+              reinterpret_cast<const TT*>(pA),
+              reinterpret_cast<const TT*>(pB),
+              reinterpret_cast<const TT*>(pC),
+              static_cast<CUDA_LONG>(countA),
+              static_cast<CUDA_LONG>(countB),
+              static_cast<CUDA_LONG>(countC),
+              static_cast<CUDA_LONG>(max_count),
+              MulSubNeg<TT>());
+    } else {
+      AddAndMulKernel<TT, MulSub<TT>, num_threads_per_block, num_elements_per_thread>
+          <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
+              reinterpret_cast<TT*>(output),
+              reinterpret_cast<const TT*>(pA),
+              reinterpret_cast<const TT*>(pB),
+              reinterpret_cast<const TT*>(pC),
+              static_cast<CUDA_LONG>(countA),
+              static_cast<CUDA_LONG>(countB),
+              static_cast<CUDA_LONG>(countC),
+              static_cast<CUDA_LONG>(max_count),
+              MulSub<TT>());
+    }
+  }
+  return cudaGetLastError();
+}
+
+template <>
+cudaError_t LaunchSubAndMulKernel<float>(cudaStream_t stream, const float* input_a, const float* input_b, const float* input_c,
+                                         float* output, int64_t length_a, int64_t length_b, int64_t length_c,
+                                         bool subtract_first, bool negative) {
+  return _LaunchSubAndMulKernel(stream, input_a, input_b, input_c, output, length_a, length_b, length_c, subtract_first, negative);
+}
+
+template <>
+cudaError_t LaunchSubAndMulKernel<ortc::MFloat16>(cudaStream_t stream,
+                                                  const ortc::MFloat16* input_a, const ortc::MFloat16* input_b,
+                                                  const ortc::MFloat16* input_c,
+                                                  ortc::MFloat16* output, int64_t length_a, int64_t length_b, int64_t length_c,
+                                                  bool subtract_first, bool negative) {
+  return _LaunchSubAndMulKernel(stream, input_a, input_b, input_c, output, length_a, length_b, length_c, subtract_first, negative);
+}
diff --git a/operators/cuda/add_mul_impl.cuh b/operators/cuda/add_mul_impl.cuh
index 9bf3ad853..2c5b000ec 100644
--- a/operators/cuda/add_mul_impl.cuh
+++ b/operators/cuda/add_mul_impl.cuh
@@ -8,4 +8,24 @@
 template <typename T>
 cudaError_t LaunchAddOrMulSharedInputKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c,
                                             T* output_ab, T* output_ac,
-                                            int64_t length_a, int64_t length_b, int64_t length_c, bool addition);
\ No newline at end of file
+                                            int64_t length_a, int64_t length_b, int64_t length_c, bool addition);
+
+template <typename T>
+cudaError_t LaunchAddOrMulTwiceKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c,
+                                      T* output, int64_t length_a, int64_t length_b, int64_t length_c, bool addition);
+
+template <typename T>
+cudaError_t LaunchAddAndMulKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c,
+                                  T* output, int64_t length_a, int64_t length_b, int64_t length_c,
+                                  bool addition_first);
+
+template <typename T>
+cudaError_t LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c,
+                                                  T* output, int64_t length_a, int64_t length_b, int64_t length_c,
+                                                  bool addition_first,
+                                                  int64_t d2, int64_t d3, int64_t d4);
+
+template <typename T>
+cudaError_t LaunchSubAndMulKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c,
+                                  T* output, int64_t length_a, int64_t length_b, int64_t length_c,
+                                  bool subtract_first, bool negative);
diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc
index 8fc3105b9..fd4bb9f90 100644
--- a/operators/cuda/cuda_ops.cc
+++ b/operators/cuda/cuda_ops.cc
@@ -17,9 +17,28 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
   using AddSharedInputFloat32Type = typename contrib::AddOrMulSharedInput<float, true>;
   using MulSharedInputFloat32Type = typename contrib::AddOrMulSharedInput<float, false>;
 
+  using AddTwiceFloat32Type = typename contrib::AddOrMulTwice<float, true>;
+  using MulTwiceFloat32Type = typename contrib::AddOrMulTwice<float, false>;
+
+  using AddAndMulFloat32Type = typename contrib::AddAndMul<float, true>;
+  using MulAndAddFloat32Type = typename contrib::AddAndMul<float, false>;
+
+  using SubAndMulFloat32Type = typename contrib::SubAndMul<float, true>;
+  using MulAndSubFloat32Type = typename contrib::SubAndMul<float, false>;
+
 #if ORT_API_VERSION >= 16
   using AddSharedInputFloat16Type = typename contrib::AddOrMulSharedInput<ortc::MFloat16, true>;
   using MulSharedInputFloat16Type = typename contrib::AddOrMulSharedInput<ortc::MFloat16, false>;
+
+  using AddTwiceFloat16Type = typename contrib::AddOrMulTwice<ortc::MFloat16, true>;
+  using MulTwiceFloat16Type = typename contrib::AddOrMulTwice<ortc::MFloat16, false>;
+
+  using AddAndMulFloat16Type = typename contrib::AddAndMul<ortc::MFloat16, true>;
+  using MulAndAddFloat16Type = typename contrib::AddAndMul<ortc::MFloat16, false>;
+
+  using SubAndMulFloat16Type = typename contrib::SubAndMul<ortc::MFloat16, true>;
+  using MulAndSubFloat16Type = typename contrib::SubAndMul<ortc::MFloat16, false>;
+
   using Transpose2DCastFloat32ToFloat16Type = typename contrib::Transpose2DCast<float, ortc::MFloat16>;
   using Transpose2DCastFloat16ToFloat32Type = typename contrib::Transpose2DCast<ortc::MFloat16, float>;
 #endif
@@ -28,27 +47,39 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
       []() { return nullptr; }
 #ifdef USE_CUDA
       ,
+      CustomCudaStructV2("AddAdd", AddTwiceFloat32Type),
+      CustomCudaStructV2("AddMul", AddAndMulFloat32Type),
       CustomCudaStructV2("AddSharedInput", AddSharedInputFloat32Type),
       CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
       CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<float>),
-      CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
+      CustomCudaStructV2("MulAdd", MulAndAddFloat32Type),
+      CustomCudaStructV2("MulMul", MulTwiceFloat32Type),
       CustomCudaStructV2("MulMulSigmoid", contrib::MulMulSigmoid<float>),
+      CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
       CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<float>),
+      CustomCudaStructV2("MulSub", MulAndSubFloat32Type),
       CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
       CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero<float>),
       CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
+      CustomCudaStructV2("SubMul", SubAndMulFloat32Type),
 #if ORT_API_VERSION >= 16
+      CustomCudaStructV2("AddAdd", AddTwiceFloat16Type),
+      CustomCudaStructV2("AddMul", AddAndMulFloat16Type),
       CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type),
       CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
       CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
       CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<ortc::MFloat16>),
-      CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
+      CustomCudaStructV2("MulAdd", MulAndAddFloat16Type),
+      CustomCudaStructV2("MulMul", MulTwiceFloat16Type),
       CustomCudaStructV2("MulMulSigmoid", contrib::MulMulSigmoid<ortc::MFloat16>),
+      CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
       CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<ortc::MFloat16>),
+      CustomCudaStructV2("MulSub", MulAndSubFloat16Type),
       CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
       CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero<ortc::MFloat16>),
       CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>),
+      CustomCudaStructV2("SubMul", SubAndMulFloat16Type),
       CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
       CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
 #endif
diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py
index 43233a26b..fabc4e349 100644
--- a/test/cuda/test_cudaops.py
+++ b/test/cuda/test_cudaops.py
@@ -46,6 +46,83 @@ def _run(self, X):
 
 
 class TestCudaOps(unittest.TestCase):
+    def _addaddmulmul_cuda(self, itype, op_type, broad=False):
+        model1 = helper.make_model(
+            helper.make_graph(
+                [
+                    helper.make_node(op_type, ["X", "Y"], ["xy"]),
+                    helper.make_node(op_type, ["xy", "Z"], ["final"]),
+                ],
+                "nd",
+                [
+                    helper.make_tensor_value_info("X", itype, [None, None, None]),
+                    helper.make_tensor_value_info("Y", itype, [None, None, None]),
+                    helper.make_tensor_value_info("Z", itype, [None, None, None]),
+                ],
+                [helper.make_tensor_value_info("final", itype, [None, None, None])],
+            ),
+            opset_imports=[helper.make_opsetid("", 18)],
+            ir_version=9,
+        )
+
+        model2 = helper.make_model(
+            helper.make_graph(
+                [
+                    helper.make_node(
+                        f"{op_type}{op_type}",
+                        ["X", "Y", "Z"],
+                        ["final"],
+                        domain="ai.onnx.contrib",
+                    )
+                ],
+                "nd",
+                [
+                    helper.make_tensor_value_info("X", itype, [None, None, None]),
+                    helper.make_tensor_value_info("Y", itype, [None, None, None]),
+                    helper.make_tensor_value_info("Z", itype, [None, None, None]),
+                ],
+                [helper.make_tensor_value_info("final", itype, [None, None, None])],
+            ),
+            opset_imports=[
+                helper.make_opsetid("", 18),
+                helper.make_opsetid("ai.onnx.contrib", 1),
+            ],
+            ir_version=9,
+        )
+
+        dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
+        shapex = (1, 2, 3) if broad else (3, 2, 3)
+        shapey = (3, 2, 3)
+        shapez = (1, 2, 3) if broad else (3, 2, 3)
+        x = (np.arange(np.prod(shapex)) + 1).reshape(shapex).astype(dtype)
+        y = (np.arange(np.prod(shapey)) + 10).reshape(shapey).astype(dtype)
+        z = (np.arange(np.prod(shapez)) + 100).reshape(shapez).astype(dtype)
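+        # Offsets 1, 10 and 100 keep x, y and z distinct so that Add and Mul give
+        # different results and broadcasting mistakes show up in the comparison below.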
+
+        feeds1 = dict(X=x, Y=y, Z=z)
+        ref = ReferenceEvaluator(model1)
+        expected = ref.run(None, feeds1)[0]
+
+        opts = _ort.SessionOptions()
+        opts.register_custom_ops_library(_get_library_path())
+        sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
+        got = sess.run(None, feeds1)[0]
+        assert_almost_equal(expected, got)
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_mulmul_cuda(self):
+        self._addaddmulmul_cuda(TensorProto.FLOAT, "Mul")
+        self._addaddmulmul_cuda(TensorProto.FLOAT16, "Mul")
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_mulmul_cuda_broadcast(self):
+        self._addaddmulmul_cuda(TensorProto.FLOAT, "Mul", True)
+        self._addaddmulmul_cuda(TensorProto.FLOAT16, "Mul", True)
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_addadd_cuda(self):
+        self._addaddmulmul_cuda(TensorProto.FLOAT, "Add")
+        self._addaddmulmul_cuda(TensorProto.FLOAT16, "Add")
+
     @staticmethod
     def _create_negpos_test_model(domain="ai.onnx.contrib"):
         nodes = [
@@ -647,6 +724,110 @@ def _transpose_cast_cuda(self, itype):
         got = sess.run(None, feeds1)[0]
         assert_almost_equal(expected, got, decimal=5)
 
+    def _addmul_cuda(self, itype, op_type1, op_type2, broad=False, negative=False):
+        model1 = helper.make_model(
+            helper.make_graph(
+                [
+                    helper.make_node(op_type1, ["Y", "X"] if negative else ["X", "Y"], ["xy"]),
+                    helper.make_node(op_type2, ["Z", "xy"] if negative else ["xy", "Z"], ["final"]),
+                ],
+                "nd",
+                [
+                    helper.make_tensor_value_info("X", itype, [None, None, None]),
+                    helper.make_tensor_value_info("Y", itype, [None, None, None]),
+                    helper.make_tensor_value_info("Z", itype, [None, None, None]),
+                ],
+                [helper.make_tensor_value_info("final", itype, [None, None, None])],
+            ),
+            opset_imports=[helper.make_opsetid("", 18)],
+            ir_version=9,
+        )
+
+        kwargs = {"negative": 1} if negative else {}
+        model2 = helper.make_model(
+            helper.make_graph(
+                [
+                    helper.make_node(
+                        f"{op_type1}{op_type2}",
+                        ["X", "Y", "Z"],
+                        ["final"],
+                        domain="ai.onnx.contrib",
+                        **kwargs,
+                    )
+                ],
+                "nd",
+                [
+                    helper.make_tensor_value_info("X", itype, [None, None, None]),
+                    helper.make_tensor_value_info("Y", itype, [None, None, None]),
+                    helper.make_tensor_value_info("Z", itype, [None, None, None]),
+                ],
+                [helper.make_tensor_value_info("final", itype, [None, None, None])],
+            ),
+            opset_imports=[
+                helper.make_opsetid("", 18),
+                helper.make_opsetid("ai.onnx.contrib", 1),
+            ],
+            ir_version=9,
+        )
+
+        dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
+        shapex = (1, 2, 3) if broad else (3, 2, 3)
+        shapey = (3, 2, 3)
+        shapez = (1, 2, 3) if broad else (3, 2, 3)
+        x = (np.arange(np.prod(shapex)) + 1).reshape(shapex).astype(dtype)
+        y = (np.arange(np.prod(shapey)) + 1).reshape(shapey).astype(dtype)
+        z = (np.arange(np.prod(shapez)) + 1).reshape(shapez).astype(dtype)
+
+        feeds1 = dict(X=x, Y=y, Z=z)
+        ref = ReferenceEvaluator(model1, verbose=0)
+        expected = ref.run(None, feeds1)[0]
+
+        opts = _ort.SessionOptions()
+        opts.register_custom_ops_library(_get_library_path())
+        sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
+        got = sess.run(None, feeds1)[0]
+        assert_almost_equal(expected, got)
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_addmul_cuda(self):
+        self._addmul_cuda(TensorProto.FLOAT, "Add", "Mul")
+        self._addmul_cuda(TensorProto.FLOAT16, "Add", "Mul")
+
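+    # The *_broadcast variants use a leading dimension of 1 on X and Z so the
+    # modulo-indexing (broadcast) path of the fused kernels is exercised.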
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_addmul_cuda_broadcast(self):
+        self._addmul_cuda(TensorProto.FLOAT, "Add", "Mul", True)
+        self._addmul_cuda(TensorProto.FLOAT16, "Add", "Mul", True)
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_muladd_cuda(self):
+        self._addmul_cuda(TensorProto.FLOAT, "Mul", "Add")
+        self._addmul_cuda(TensorProto.FLOAT16, "Mul", "Add")
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_submul_cuda(self):
+        self._addmul_cuda(TensorProto.FLOAT, "Sub", "Mul")
+        self._addmul_cuda(TensorProto.FLOAT16, "Sub", "Mul")
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_submul_cuda_negative(self):
+        self._addmul_cuda(TensorProto.FLOAT, "Sub", "Mul", negative=True)
+        self._addmul_cuda(TensorProto.FLOAT16, "Sub", "Mul", negative=True)
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_submul_cuda_broadcast(self):
+        self._addmul_cuda(TensorProto.FLOAT, "Sub", "Mul", True)
+        self._addmul_cuda(TensorProto.FLOAT16, "Sub", "Mul", True)
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_mulsub_cuda(self):
+        self._addmul_cuda(TensorProto.FLOAT, "Mul", "Sub")
+        self._addmul_cuda(TensorProto.FLOAT16, "Mul", "Sub")
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_mulsub_cuda_negative(self):
+        self._addmul_cuda(TensorProto.FLOAT, "Mul", "Sub", negative=True)
+        self._addmul_cuda(TensorProto.FLOAT16, "Mul", "Sub", negative=True)
+
     @unittest.skipIf(not has_cuda(), reason="cuda not available")
     def test_transpose_cast_cuda(self):
         self._transpose_cast_cuda(TensorProto.FLOAT)