From 81270043fffa425a016f6a186a7d3f08e2c14fa2 Mon Sep 17 00:00:00 2001 From: PaulusParssinen Date: Wed, 22 Jan 2025 02:49:58 +0200 Subject: [PATCH 1/2] Initial BITOP refactor --- .../TensorPrimitives.IBinaryOperator.cs | 1108 +++++++++++++++++ libs/server/Resp/Bitmap/BitmapManagerBitOp.cs | 696 +++-------- test/Garnet.test/GarnetBitmapTests.cs | 104 +- 3 files changed, 1314 insertions(+), 594 deletions(-) create mode 100644 libs/common/Numerics/TensorPrimitives.IBinaryOperator.cs diff --git a/libs/common/Numerics/TensorPrimitives.IBinaryOperator.cs b/libs/common/Numerics/TensorPrimitives.IBinaryOperator.cs new file mode 100644 index 0000000000..6091e0ab30 --- /dev/null +++ b/libs/common/Numerics/TensorPrimitives.IBinaryOperator.cs @@ -0,0 +1,1108 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System.Diagnostics; +using System.Numerics; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; + +namespace Garnet.common.Numerics +{ + public static unsafe partial class TensorPrimitives + { + /// x & y + public readonly struct BitwiseAndOperator : IBinaryOperator where T : IBitwiseOperators + { + public static T Invoke(T x, T y) => x & y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x & y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x & y; + public static Vector512 Invoke(Vector512 x, Vector512 y) => x & y; + } + + /// x | y + public readonly struct BitwiseOrOperator : IBinaryOperator where T : IBitwiseOperators + { + public static T Invoke(T x, T y) => x | y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x | y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x | y; + public static Vector512 Invoke(Vector512 x, Vector512 y) => x | y; + } + + /// x ^ y + public readonly struct BitwiseXorOperator : IBinaryOperator where T : IBitwiseOperators + { + public static T Invoke(T x, T y) => x ^ y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x ^ y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x ^ y; + public static Vector512 Invoke(Vector512 x, Vector512 y) => x ^ y; + } + + /// Operator that takes two input values and returns a single value. + public interface IBinaryOperator + { + static abstract T Invoke(T x, T y); + static abstract Vector128 Invoke(Vector128 x, Vector128 y); + static abstract Vector256 Invoke(Vector256 x, Vector256 y); + static abstract Vector512 Invoke(Vector512 x, Vector512 y); + } + + // TODO: Remove, no attempt to use yet in this PR + public static void UnsafeInvokeOperator( + T* xPtr, T* yPtr, T* dPtr, int length) + where T : unmanaged + where TBinaryOperator : struct, IBinaryOperator + { + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + nuint remainder = (uint)length; + + if (Vector512.IsHardwareAccelerated && Vector512.IsSupported) + { + if (remainder >= (uint)Vector512.Count) + { + Vectorized512(ref xPtr, ref yPtr, ref dPtr, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. 
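+                        // VectorizedSmall dispatches on sizeof(T); for byte-sized elements the remainder here is at
+                        // most 63 and is finished with a single full vector, two overlapping vectors, or a scalar
+                        // fall-through once fewer than 16 bytes remain.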
+ + VectorizedSmall(ref xPtr, ref yPtr, ref dPtr, remainder); + } + + return; + } + + if (Vector256.IsHardwareAccelerated && Vector256.IsSupported) + { + if (remainder >= (uint)Vector256.Count) + { + Vectorized256(ref xPtr, ref yPtr, ref dPtr, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xPtr, ref yPtr, ref dPtr, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated && Vector128.IsSupported) + { + if (remainder >= (uint)Vector128.Count) + { + Vectorized128(ref xPtr, ref yPtr, ref dPtr, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xPtr, ref yPtr, ref dPtr, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(xPtr, yPtr, dPtr, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(T* xPtr, T* yPtr, T* dPtr, nuint length) + { + for (nuint i = 0; i < length; i++) + { + *(dPtr + i) = TBinaryOperator.Invoke(*(xPtr + i), *(yPtr + i)); + } + } + + static void Vectorized128(ref T* xPtr, ref T* yPtr, ref T* dPtr, nuint remainder) + { + ref T* dPtrBeg = ref dPtr; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TBinaryOperator.Invoke(Vector128.Load(xPtr), Vector128.Load(yPtr)); + Vector128 end = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)Vector128.Count), + Vector128.Load(yPtr + remainder - (uint)Vector128.Count)); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
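+                            // For example, for byte elements with dPtr % 16 == 6, the line below gives
+                            // misalignment = (16 - 6) / 1 = 10, so all three pointers advance 10 elements to the next
+                            // 16-byte boundary; the skipped bytes are still covered by "beg", stored in case 0 below.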
+ + nuint misalignment = ((uint)sizeof(Vector128) - ((nuint)dPtr % (uint)sizeof(Vector128))) / (uint)sizeof(T); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector128)) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
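+                // For example, with byte elements (a Vector128 holds 16) and remainder == 40, the rounding below
+                // yields 48, so the switch enters at case 3; the final 8 bytes that do not fill a vector are
+                // written by "end" in case 1, whose store overlaps the one made in case 2.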
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)Vector128.Count) + { + case 8: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)(Vector128.Count * 8)), + Vector128.Load(yPtr + remainder - (uint)(Vector128.Count * 8))); + vector.Store(dPtr + remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + remainder - (uint)(Vector128.Count * 7))); + vector.Store(dPtr + remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + remainder - (uint)(Vector128.Count * 6))); + vector.Store(dPtr + remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + remainder - (uint)(Vector128.Count * 5))); + vector.Store(dPtr + remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + remainder - (uint)(Vector128.Count * 4))); + vector.Store(dPtr + remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + remainder - (uint)(Vector128.Count * 3))); + vector.Store(dPtr + remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + remainder - (uint)(Vector128.Count * 2))); + vector.Store(dPtr + remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.Store(dPtr + endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.Store(dPtrBeg); + break; + } + } + } + + static void Vectorized256(ref T* xPtr, ref T* yPtr, ref T* dPtr, nuint remainder) + { + ref T* dPtrBeg = ref dPtr; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TBinaryOperator.Invoke(Vector256.Load(xPtr), Vector256.Load(yPtr)); + Vector256 end = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)Vector256.Count), + Vector256.Load(yPtr + remainder - (uint)Vector256.Count)); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
+ + nuint misalignment = ((uint)sizeof(Vector256) - ((nuint)dPtr % (uint)sizeof(Vector256))) / (uint)sizeof(T); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector256)) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)Vector256.Count) + { + case 8: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)(Vector256.Count * 8)), + Vector256.Load(yPtr + remainder - (uint)(Vector256.Count * 8))); + vector.Store(dPtr + remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + remainder - (uint)(Vector256.Count * 7))); + vector.Store(dPtr + remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + remainder - (uint)(Vector256.Count * 6))); + vector.Store(dPtr + remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + remainder - (uint)(Vector256.Count * 5))); + vector.Store(dPtr + remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + remainder - (uint)(Vector256.Count * 4))); + vector.Store(dPtr + remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + remainder - (uint)(Vector256.Count * 3))); + vector.Store(dPtr + remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + remainder - (uint)(Vector256.Count * 2))); + vector.Store(dPtr + remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.Store(dPtr + endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.Store(dPtrBeg); + break; + } + } + } + + static void Vectorized512(ref T* xPtr, ref T* yPtr, ref T* dPtr, nuint remainder) + { + ref T* dPtrBeg = ref dPtr; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TBinaryOperator.Invoke(Vector512.Load(xPtr), Vector512.Load(yPtr)); + Vector512 end = TBinaryOperator.Invoke(Vector512.Load(xPtr + remainder - (uint)Vector512.Count), + Vector512.Load(yPtr + remainder - (uint)Vector512.Count)); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
+ + nuint misalignment = ((uint)sizeof(Vector512) - ((nuint)dPtr % (uint)sizeof(Vector512))) / (uint)sizeof(T); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector512)) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)Vector512.Count) + { + case 8: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.Load(xPtr + remainder - (uint)(Vector512.Count * 8)), + Vector512.Load(yPtr + remainder - (uint)(Vector512.Count * 8))); + vector.Store(dPtr + remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.Load(xPtr + remainder - (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + remainder - (uint)(Vector512.Count * 7))); + vector.Store(dPtr + remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.Load(xPtr + remainder - (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + remainder - (uint)(Vector512.Count * 6))); + vector.Store(dPtr + remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.Load(xPtr + remainder - (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + remainder - (uint)(Vector512.Count * 5))); + vector.Store(dPtr + remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.Load(xPtr + remainder - (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + remainder - (uint)(Vector512.Count * 4))); + vector.Store(dPtr + remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.Load(xPtr + remainder - (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + remainder - (uint)(Vector512.Count * 3))); + vector.Store(dPtr + remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.Load(xPtr + remainder - (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + remainder - (uint)(Vector512.Count * 2))); + vector.Store(dPtr + remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.Store(dPtr + endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.Store(dPtrBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref T* xPtr, ref T* yPtr, ref T* dPtr, nuint remainder) + { + if (sizeof(T) == 1) + { + VectorizedSmall1(ref xPtr, ref yPtr, ref dPtr, remainder); + } + else if (sizeof(T) == 2) + { + VectorizedSmall2(ref xPtr, ref yPtr, ref dPtr, remainder); + } + else if (sizeof(T) == 4) + { + VectorizedSmall4(ref xPtr, ref yPtr, ref dPtr, remainder); + } + else + { + Debug.Assert(sizeof(T) == 8); + VectorizedSmall8(ref xPtr, ref yPtr, ref dPtr, remainder); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall1(ref T* xPtr, ref T* yPtr, ref T* dPtr, nuint remainder) + { + Debug.Assert(sizeof(T) == 1); + + switch (remainder) + { + // Two Vector256's worth of data, with at least one element overlapping. 
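+                        // For example, remainder == 40 is handled as "beg" over bytes [0, 32) and "end" over
+                        // bytes [8, 40); both results are computed before either store, so the overlapping bytes
+                        // are simply written twice with the same value.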
+ case 63: + case 62: + case 61: + case 60: + case 59: + case 58: + case 57: + case 56: + case 55: + case 54: + case 53: + case 52: + case 51: + case 50: + case 49: + case 48: + case 47: + case 46: + case 45: + case 44: + case 43: + case 42: + case 41: + case 40: + case 39: + case 38: + case 37: + case 36: + case 35: + case 34: + case 33: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.Load(xPtr), Vector256.Load(yPtr)); + Vector256 end = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)Vector256.Count), + Vector256.Load(yPtr + remainder - (uint)Vector256.Count)); + + beg.Store(dPtr); + end.Store(dPtr + remainder - (uint)Vector256.Count); + + break; + } + + // One Vector256's worth of data. + case 32: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.Load(xPtr), Vector256.Load(yPtr)); + beg.Store(dPtr); + + break; + } + + // Two Vector128's worth of data, with at least one element overlapping. + case 31: + case 30: + case 29: + case 28: + case 27: + case 26: + case 25: + case 24: + case 23: + case 22: + case 21: + case 20: + case 19: + case 18: + case 17: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.Load(xPtr), Vector128.Load(yPtr)); + Vector128 end = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)Vector128.Count), + Vector128.Load(yPtr + remainder - (uint)Vector128.Count)); + + beg.Store(dPtr); + end.Store(dPtr + remainder - (uint)Vector128.Count); + + break; + } + + // One Vector128's worth of data. + case 16: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.Load(xPtr), Vector128.Load(yPtr)); + beg.Store(dPtr); + + break; + } + + // Cases that are smaller than a single vector. No SIMD; just jump to the length and fall through each + // case to unroll the whole processing. 
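+                    // For example, remainder == 3 jumps to case 3 and falls through cases 2 and 1, writing
+                    // elements 2, 1, and 0 with one scalar Invoke each before exiting at case 0.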
+ case 15: + *(dPtr + 14) = TBinaryOperator.Invoke(*(xPtr + 14), *(yPtr + 14)); + goto case 14; + + case 14: + *(dPtr + 13) = TBinaryOperator.Invoke(*(xPtr + 13), *(yPtr + 13)); + goto case 13; + + case 13: + *(dPtr + 12) = TBinaryOperator.Invoke(*(xPtr + 12), *(yPtr + 12)); + goto case 12; + + case 12: + *(dPtr + 11) = TBinaryOperator.Invoke(*(xPtr + 11), *(yPtr + 11)); + goto case 11; + + case 11: + *(dPtr + 10) = TBinaryOperator.Invoke(*(xPtr + 10), *(yPtr + 10)); + goto case 10; + + case 10: + *(dPtr + 9) = TBinaryOperator.Invoke(*(xPtr + 9), *(yPtr + 9)); + goto case 9; + + case 9: + *(dPtr + 8) = TBinaryOperator.Invoke(*(xPtr + 8), *(yPtr + 8)); + goto case 8; + + case 8: + *(dPtr + 7) = TBinaryOperator.Invoke(*(xPtr + 7), *(yPtr + 7)); + goto case 7; + + case 7: + *(dPtr + 6) = TBinaryOperator.Invoke(*(xPtr + 6), *(yPtr + 6)); + goto case 6; + + case 6: + *(dPtr + 5) = TBinaryOperator.Invoke(*(xPtr + 5), *(yPtr + 5)); + goto case 5; + + case 5: + *(dPtr + 4) = TBinaryOperator.Invoke(*(xPtr + 4), *(yPtr + 4)); + goto case 4; + + case 4: + *(dPtr + 3) = TBinaryOperator.Invoke(*(xPtr + 3), *(yPtr + 3)); + goto case 3; + + case 3: + *(dPtr + 2) = TBinaryOperator.Invoke(*(xPtr + 2), *(yPtr + 2)); + goto case 2; + + case 2: + *(dPtr + 1) = TBinaryOperator.Invoke(*(xPtr + 1), *(yPtr + 1)); + goto case 1; + + case 1: + *dPtr = TBinaryOperator.Invoke(*xPtr, *yPtr); + goto case 0; + + case 0: + break; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall2(ref T* xPtr, ref T* yPtr, ref T* dPtr, nuint remainder) + { + Debug.Assert(sizeof(T) == 2); + + switch (remainder) + { + // Two Vector256's worth of data, with at least one element overlapping. + case 31: + case 30: + case 29: + case 28: + case 27: + case 26: + case 25: + case 24: + case 23: + case 22: + case 21: + case 20: + case 19: + case 18: + case 17: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.Load(xPtr), Vector256.Load(yPtr)); + Vector256 end = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)Vector256.Count), + Vector256.Load(yPtr + remainder - (uint)Vector256.Count)); + + beg.Store(dPtr); + end.Store(dPtr + remainder - (uint)Vector256.Count); + + break; + } + + // One Vector256's worth of data. + case 16: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.Load(xPtr), Vector256.Load(yPtr)); + beg.Store(dPtr); + + break; + } + + // Two Vector128's worth of data, with at least one element overlapping. + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.Load(xPtr), Vector128.Load(yPtr)); + Vector128 end = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)Vector128.Count), + Vector128.Load(yPtr + remainder - (uint)Vector128.Count)); + + beg.Store(dPtr); + end.Store(dPtr + remainder - (uint)Vector128.Count); + + break; + } + + // One Vector128's worth of data. + case 8: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.Load(xPtr), Vector128.Load(yPtr)); + beg.Store(dPtr); + + break; + } + + // Cases that are smaller than a single vector. No SIMD; just jump to the length and fall through each + // case to unroll the whole processing. 
+ case 7: + *(dPtr + 6) = TBinaryOperator.Invoke(*(xPtr + 6), *(yPtr + 6)); + goto case 6; + + case 6: + *(dPtr + 5) = TBinaryOperator.Invoke(*(xPtr + 5), *(yPtr + 5)); + goto case 5; + + case 5: + *(dPtr + 4) = TBinaryOperator.Invoke(*(xPtr + 4), *(yPtr + 4)); + goto case 4; + + case 4: + *(dPtr + 3) = TBinaryOperator.Invoke(*(xPtr + 3), *(yPtr + 3)); + goto case 3; + + case 3: + *(dPtr + 2) = TBinaryOperator.Invoke(*(xPtr + 2), *(yPtr + 2)); + goto case 2; + + case 2: + *(dPtr + 1) = TBinaryOperator.Invoke(*(xPtr + 1), *(yPtr + 1)); + goto case 1; + + case 1: + *dPtr = TBinaryOperator.Invoke(*xPtr, *yPtr); + goto case 0; + + case 0: + break; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall4(ref T* xPtr, ref T* yPtr, ref T* dPtr, nuint remainder) + { + Debug.Assert(sizeof(T) == 4); + + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.Load(xPtr), Vector256.Load(yPtr)); + Vector256 end = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)Vector256.Count), + Vector256.Load(yPtr + remainder - (uint)Vector256.Count)); + + beg.Store(dPtr); + end.Store(dPtr + remainder - (uint)Vector256.Count); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.Load(xPtr), Vector256.Load(yPtr)); + beg.Store(dPtr); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.Load(xPtr), Vector128.Load(yPtr)); + Vector128 end = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)Vector128.Count), + Vector128.Load(yPtr + remainder - (uint)Vector128.Count)); + + beg.Store(dPtr); + end.Store(dPtr + remainder - (uint)Vector128.Count); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.Load(xPtr), Vector128.Load(yPtr)); + beg.Store(dPtr); + + break; + } + + case 3: + { + *(dPtr + 2) = TBinaryOperator.Invoke(*(xPtr + 2), *(yPtr + 2)); + goto case 2; + } + + case 2: + { + *(dPtr + 1) = TBinaryOperator.Invoke(*(xPtr + 1), *(yPtr + 1)); + goto case 1; + } + + case 1: + { + *dPtr = TBinaryOperator.Invoke(*xPtr, *yPtr); + goto case 0; + } + + case 0: + { + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall8(ref T* xPtr, ref T* yPtr, ref T* dPtr, nuint remainder) + { + Debug.Assert(sizeof(T) == 8); + + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.Load(xPtr), Vector256.Load(yPtr)); + Vector256 end = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)Vector256.Count), + Vector256.Load(yPtr + remainder - (uint)Vector256.Count)); + + beg.Store(dPtr); + end.Store(dPtr + remainder - (uint)Vector256.Count); + + break; + } + + case 4: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.Load(xPtr), Vector256.Load(yPtr)); + beg.Store(dPtr); + + break; + } + + case 3: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.Load(xPtr), Vector128.Load(yPtr)); + Vector128 end = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)Vector128.Count), + Vector128.Load(yPtr + remainder - 
(uint)Vector128.Count)); + + beg.Store(dPtr); + end.Store(dPtr + remainder - (uint)Vector128.Count); + + break; + } + + case 2: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.Load(xPtr), Vector128.Load(yPtr)); + beg.Store(dPtr); + + break; + } + + case 1: + { + *dPtr = TBinaryOperator.Invoke(*xPtr, *yPtr); + goto case 0; + } + + case 0: + { + break; + } + } + } + } + } +} \ No newline at end of file diff --git a/libs/server/Resp/Bitmap/BitmapManagerBitOp.cs b/libs/server/Resp/Bitmap/BitmapManagerBitOp.cs index 6917e25993..4f6a3d8b33 100644 --- a/libs/server/Resp/Bitmap/BitmapManagerBitOp.cs +++ b/libs/server/Resp/Bitmap/BitmapManagerBitOp.cs @@ -2,12 +2,16 @@ // Licensed under the MIT license. using System.Runtime.Intrinsics; -using System.Runtime.Intrinsics.X86; using Garnet.common; - +using static Garnet.common.Numerics.TensorPrimitives; namespace Garnet.server { + // TODO: Guard Vector### logic behind IsSupported & IsHardwareAccelerated + // TODO: Add Vector512 & Vector128 paths atleast + // TODO: Get rid of "IBinaryOperator" scalar logic? + // FÒLLOW-UP: Non-temporal stores after sizes larger than 256KB (like in TensorPrimitives) + // FÒLLOW-UP: Investigate alignment -> overlapping & jump-table (like in TensorPrimitives) public unsafe partial class BitmapManager { /// @@ -26,16 +30,16 @@ public static bool BitOpMainUnsafeMultiKey(byte* dstPtr, int dstLen, byte** srcS switch (bitop) { case (byte)BitmapOperation.NOT: - __bitop_multikey_simdX256_not(dstPtr, dstLen, srcStartPtrs[0], srcEndPtrs[0] - srcStartPtrs[0]); + InvokeSingleKeyBitwiseNot(dstPtr, dstLen, srcStartPtrs[0], srcEndPtrs[0] - srcStartPtrs[0]); break; case (byte)BitmapOperation.AND: - __bitop_multikey_simdX256_and(dstPtr, dstLen, srcStartPtrs, srcEndPtrs, srcKeyCount, minSize); + InvokeMultiKeyBitwise, BitwiseAndOperator>(dstPtr, dstLen, srcStartPtrs, srcEndPtrs, srcKeyCount, minSize); break; case (byte)BitmapOperation.OR: - __bitop_multikey_simdX256_or(dstPtr, dstLen, srcStartPtrs, srcEndPtrs, srcKeyCount, minSize); + InvokeMultiKeyBitwise, BitwiseOrOperator>(dstPtr, dstLen, srcStartPtrs, srcEndPtrs, srcKeyCount, minSize); break; case (byte)BitmapOperation.XOR: - __bitop_multikey_simdX256_xor(dstPtr, dstLen, srcStartPtrs, srcEndPtrs, srcKeyCount, minSize); + InvokeMultiKeyBitwise, BitwiseXorOperator>(dstPtr, dstLen, srcStartPtrs, srcEndPtrs, srcKeyCount, minSize); break; default: throw new GarnetException("Unsupported BitOp command"); @@ -44,285 +48,87 @@ public static bool BitOpMainUnsafeMultiKey(byte* dstPtr, int dstLen, byte** srcS } /// - /// Negation bitop implementation using 256-wide SIMD registers. + /// Invokes unary bitwise-NOT operation for single source key using hardware accelerated SIMD intrinsics when possible. /// /// Output buffer to write BitOp result /// Output buffer length. /// Pointer to source bitmap. /// Source bitmap length. 
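+        /// <remarks>
+        /// The Vector256 loop is not yet guarded by IsHardwareAccelerated (see the TODO above). After the unrolled
+        /// 8-vector batches, the tail is completed with a single-vector loop, 8-byte scalar steps, and finally
+        /// byte-wide negation.
+        /// </remarks>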
- private static void __bitop_multikey_simdX256_not(byte* dstPtr, long dstLen, byte* srcBitmap, long srcLen) + private static void InvokeSingleKeyBitwiseNot(byte* dstPtr, long dstLen, byte* srcBitmap, long srcLen) { - int batchSize = 8 * 32; long slen = srcLen; - long stail = slen & (batchSize - 1); + long remainder = slen & ((Vector256.Count * 8) - 1); //iterate using srcBitmap because always dstLen >= srcLen byte* srcCurr = srcBitmap; - byte* srcEnd = srcCurr + (slen - stail); + byte* srcEnd = srcCurr + (slen - remainder); byte* dstCurr = dstPtr; - #region 8x32 - while (srcCurr < srcEnd) - { - Vector256 d00 = Avx.LoadVector256(srcCurr); - Vector256 d01 = Avx.LoadVector256(srcCurr + 32); - Vector256 d02 = Avx.LoadVector256(srcCurr + 64); - Vector256 d03 = Avx.LoadVector256(srcCurr + 96); - Vector256 d04 = Avx.LoadVector256(srcCurr + 128); - Vector256 d05 = Avx.LoadVector256(srcCurr + 160); - Vector256 d06 = Avx.LoadVector256(srcCurr + 192); - Vector256 d07 = Avx.LoadVector256(srcCurr + 224); - - Avx.Store(dstCurr, Avx2.Xor(d00, Vector256.AllBitsSet)); - Avx.Store(dstCurr + 32, Avx2.Xor(d01, Vector256.AllBitsSet)); - Avx.Store(dstCurr + 64, Avx2.Xor(d02, Vector256.AllBitsSet)); - Avx.Store(dstCurr + 96, Avx2.Xor(d03, Vector256.AllBitsSet)); - Avx.Store(dstCurr + 128, Avx2.Xor(d04, Vector256.AllBitsSet)); - Avx.Store(dstCurr + 160, Avx2.Xor(d05, Vector256.AllBitsSet)); - Avx.Store(dstCurr + 192, Avx2.Xor(d06, Vector256.AllBitsSet)); - Avx.Store(dstCurr + 224, Avx2.Xor(d07, Vector256.AllBitsSet)); - - srcCurr += batchSize; - dstCurr += batchSize; - } - if (stail == 0) return; - #endregion - - #region 1x32 - slen = stail; - batchSize = 1 * 32; - stail = slen & (batchSize - 1); - srcEnd = srcCurr + (slen - stail); while (srcCurr < srcEnd) { - Vector256 d00 = Avx.LoadVector256(srcCurr); - Avx.Store(dstCurr, Avx2.Xor(d00, Vector256.AllBitsSet)); - srcCurr += batchSize; - dstCurr += batchSize; + var d00 = Vector256.Load(srcCurr); + var d01 = Vector256.Load(srcCurr + Vector256.Count); + var d02 = Vector256.Load(srcCurr + (Vector256.Count * 2)); + var d03 = Vector256.Load(srcCurr + (Vector256.Count * 3)); + var d04 = Vector256.Load(srcCurr + (Vector256.Count * 4)); + var d05 = Vector256.Load(srcCurr + (Vector256.Count * 5)); + var d06 = Vector256.Load(srcCurr + (Vector256.Count * 6)); + var d07 = Vector256.Load(srcCurr + (Vector256.Count * 7)); + + Vector256.Store(~d00, dstCurr); + Vector256.Store(~d01, dstCurr + Vector256.Count); + Vector256.Store(~d02, dstCurr + Vector256.Count * 2); + Vector256.Store(~d03, dstCurr + Vector256.Count * 3); + Vector256.Store(~d04, dstCurr + Vector256.Count * 4); + Vector256.Store(~d05, dstCurr + Vector256.Count * 5); + Vector256.Store(~d06, dstCurr + Vector256.Count * 6); + Vector256.Store(~d07, dstCurr + Vector256.Count * 7); + + srcCurr += Vector256.Count * 8; + dstCurr += Vector256.Count * 8; } - if (stail == 0) return; - #endregion - - #region 4x8 - slen = stail; - batchSize = 4 * 8; - stail = slen & (batchSize - 1); - srcEnd = srcCurr + (slen - stail); + if (remainder == 0) return; + + slen = remainder; + remainder = slen & (Vector256.Count - 1); + srcEnd = srcCurr + (slen - remainder); while (srcCurr < srcEnd) { - long d00 = *(long*)(srcCurr); - long d01 = *(long*)(srcCurr + 8); - long d02 = *(long*)(srcCurr + 16); - long d03 = *(long*)(srcCurr + 24); - - *(long*)dstCurr = ~d00; - *(long*)(dstCurr + 8) = ~d01; - *(long*)(dstCurr + 16) = ~d02; - *(long*)(dstCurr + 24) = ~d03; - - srcCurr += batchSize; - dstCurr += batchSize; + 
Vector256.Store(~Vector256.Load(srcCurr), dstCurr); + + srcCurr += Vector256.Count; + dstCurr += Vector256.Count; } - if (stail == 0) return; - #endregion - - #region 1x8 - slen = stail; - batchSize = 8; - stail = slen & (batchSize - 1); - srcEnd = srcCurr + (slen - stail); + if (remainder == 0) return; + + slen = remainder; + remainder = slen & (sizeof(ulong) - 1); + srcEnd = srcCurr + (slen - remainder); while (srcCurr < srcEnd) { - long d00 = *(long*)(srcCurr); + *(ulong*)dstCurr = ~*(ulong*)srcCurr; - *(long*)dstCurr = ~d00; - - srcCurr += batchSize; - dstCurr += batchSize; + srcCurr += sizeof(ulong); + dstCurr += sizeof(ulong); } - if (stail == 0) return; - #endregion - - if (stail >= 7) dstCurr[6] = (byte)(~srcCurr[6]); - if (stail >= 6) dstCurr[5] = (byte)(~srcCurr[5]); - if (stail >= 5) dstCurr[4] = (byte)(~srcCurr[4]); - if (stail >= 4) dstCurr[3] = (byte)(~srcCurr[3]); - if (stail >= 3) dstCurr[2] = (byte)(~srcCurr[2]); - if (stail >= 2) dstCurr[1] = (byte)(~srcCurr[1]); - if (stail >= 1) dstCurr[0] = (byte)(~srcCurr[0]); + if (remainder == 0) return; + + if (remainder >= 7) dstCurr[6] = (byte)~srcCurr[6]; + if (remainder >= 6) dstCurr[5] = (byte)~srcCurr[5]; + if (remainder >= 5) dstCurr[4] = (byte)~srcCurr[4]; + if (remainder >= 4) dstCurr[3] = (byte)~srcCurr[3]; + if (remainder >= 3) dstCurr[2] = (byte)~srcCurr[2]; + if (remainder >= 2) dstCurr[1] = (byte)~srcCurr[1]; + if (remainder >= 1) dstCurr[0] = (byte)~srcCurr[0]; } - /// - /// AND bitop implementation using 256-wide SIMD registers. - /// - /// Output buffer to write BitOp result - /// Output buffer length. - /// Pointer to start of bitmap sources. - /// Pointer to end of bitmap sources - /// Number of source keys. - /// Minimum size of source bitmaps. - private static void __bitop_multikey_simdX256_and(byte* dstPtr, int dstLen, byte** srcStartPtrs, byte** srcEndPtrs, int srcKeyCount, int minSize) + public static void GenericCodeGenDebugAid(int dstLen, int srcKeyCount, int minSize) { - int batchSize = 8 * 32; - long slen = minSize; - long stail = slen & (batchSize - 1); - - byte* dstCurr = dstPtr; - byte* dstEnd = dstCurr + (slen - stail); - - #region 8x32 - while (dstCurr < dstEnd) - { - Vector256 d00 = Avx.LoadVector256(srcStartPtrs[0]); - Vector256 d01 = Avx.LoadVector256(srcStartPtrs[0] + 32); - Vector256 d02 = Avx.LoadVector256(srcStartPtrs[0] + 64); - Vector256 d03 = Avx.LoadVector256(srcStartPtrs[0] + 96); - Vector256 d04 = Avx.LoadVector256(srcStartPtrs[0] + 128); - Vector256 d05 = Avx.LoadVector256(srcStartPtrs[0] + 160); - Vector256 d06 = Avx.LoadVector256(srcStartPtrs[0] + 192); - Vector256 d07 = Avx.LoadVector256(srcStartPtrs[0] + 224); - srcStartPtrs[0] += batchSize; - for (int i = 1; i < srcKeyCount; i++) - { - Vector256 s00 = Avx.LoadVector256(srcStartPtrs[i]); - Vector256 s01 = Avx.LoadVector256(srcStartPtrs[i] + 32); - Vector256 s02 = Avx.LoadVector256(srcStartPtrs[i] + 64); - Vector256 s03 = Avx.LoadVector256(srcStartPtrs[i] + 96); - Vector256 s04 = Avx.LoadVector256(srcStartPtrs[i] + 128); - Vector256 s05 = Avx.LoadVector256(srcStartPtrs[i] + 160); - Vector256 s06 = Avx.LoadVector256(srcStartPtrs[i] + 192); - Vector256 s07 = Avx.LoadVector256(srcStartPtrs[i] + 224); - - d00 = Avx2.And(d00, s00); - d01 = Avx2.And(d01, s01); - d02 = Avx2.And(d02, s02); - d03 = Avx2.And(d03, s03); - d04 = Avx2.And(d04, s04); - d05 = Avx2.And(d05, s05); - d06 = Avx2.And(d06, s06); - d07 = Avx2.And(d07, s07); - srcStartPtrs[i] += batchSize; - } - - Avx.Store(dstCurr, d00); - Avx.Store(dstCurr + 32, d01); - 
Avx.Store(dstCurr + 64, d02); - Avx.Store(dstCurr + 96, d03); - Avx.Store(dstCurr + 128, d04); - Avx.Store(dstCurr + 160, d05); - Avx.Store(dstCurr + 192, d06); - Avx.Store(dstCurr + 224, d07); - - dstCurr += batchSize; - } - if (stail == 0) goto fillTail; - #endregion - - #region 1x32 - slen = stail; - batchSize = 1 * 32; - stail = slen & (batchSize - 1); - dstEnd = dstCurr + (slen - stail); - - while (dstCurr < dstEnd) - { - Vector256 d00 = Avx.LoadVector256(srcStartPtrs[0]); - srcStartPtrs[0] += batchSize; - for (int i = 1; i < srcKeyCount; i++) - { - Vector256 s00 = Avx.LoadVector256(srcStartPtrs[i]); - d00 = Avx2.And(d00, s00); - srcStartPtrs[i] += batchSize; - } - Avx.Store(dstCurr, d00); - dstCurr += batchSize; - } - if (stail == 0) goto fillTail; - #endregion - - #region scalar_4x8 - slen = stail; - batchSize = 4 * 8; - stail = slen & (batchSize - 1); - dstEnd = dstCurr + (slen - stail); - while (dstCurr < dstEnd) - { - long d00 = *(long*)(srcStartPtrs[0]); - long d01 = *(long*)(srcStartPtrs[0] + 8); - long d02 = *(long*)(srcStartPtrs[0] + 16); - long d03 = *(long*)(srcStartPtrs[0] + 24); - srcStartPtrs[0] += batchSize; - for (int i = 1; i < srcKeyCount; i++) - { - d00 &= *(long*)(srcStartPtrs[i]); - d01 &= *(long*)(srcStartPtrs[i] + 8); - d02 &= *(long*)(srcStartPtrs[i] + 16); - d03 &= *(long*)(srcStartPtrs[i] + 24); - srcStartPtrs[i] += batchSize; - } - - *(long*)dstCurr = d00; - *(long*)(dstCurr + 8) = d01; - *(long*)(dstCurr + 16) = d02; - *(long*)(dstCurr + 24) = d03; - dstCurr += batchSize; - } - if (stail == 0) goto fillTail; - #endregion - - #region scalar_1x8 - slen = stail; - batchSize = 8; - stail = slen & (batchSize - 1); - dstEnd = dstCurr + (slen - stail); - while (dstCurr < dstEnd) - { - long d00 = *(long*)(srcStartPtrs[0]); - srcStartPtrs[0] += batchSize; - for (int i = 1; i < srcKeyCount; i++) - { - d00 &= *(long*)(srcStartPtrs[i]); - srcStartPtrs[i] += batchSize; - } - *(long*)dstCurr = d00; - dstCurr += batchSize; - } - #endregion - - fillTail: - #region scalar_1x1 - byte* dstMaxEnd = dstPtr + dstLen; - int offset = 0; - while (dstCurr < dstMaxEnd) - { - byte d00; - if (srcStartPtrs[0] + offset < srcEndPtrs[0]) - d00 = srcStartPtrs[0][offset]; - else - { - d00 = 0; - goto writeBack; - } - - for (int i = 1; i < srcKeyCount; i++) - { - if (srcStartPtrs[i] + offset < srcEndPtrs[i]) - d00 &= srcStartPtrs[i][offset]; - else - { - d00 = 0; - goto writeBack; - } - } - writeBack: - *dstCurr++ = d00; - offset++; - } - #endregion + InvokeMultiKeyBitwise, BitwiseAndOperator>((byte*)0, dstLen, (byte**)0, (byte**)0, srcKeyCount, minSize); } /// - /// OR bitop implementation using 256-wide SIMD registers. + /// Invokes bitwise bit-operation for multiple keys using hardware accelerated SIMD intrinsics when possible. /// /// Output buffer to write BitOp result /// Output buffer length. @@ -330,321 +136,149 @@ private static void __bitop_multikey_simdX256_and(byte* dstPtr, int dstLen, byte /// Pointer to end of bitmap sources /// Number of source keys. /// Minimum size of source bitmaps. 
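+        /// <typeparam name="TBinaryOperator">Bitwise operator (AND, OR or XOR) applied to the Vector256 batches and to the final byte-wide tail.</typeparam>
+        /// <typeparam name="TBinaryOperator2">The same bitwise operator used for the 8-byte (ulong) scalar loop.</typeparam>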
- private static void __bitop_multikey_simdX256_or(byte* dstPtr, int dstLen, byte** srcStartPtrs, byte** srcEndPtrs, int srcKeyCount, int minSize) + private static void InvokeMultiKeyBitwise(byte* dstPtr, int dstLen, byte** srcStartPtrs, byte** srcEndPtrs, int srcKeyCount, int minSize) + where TBinaryOperator : struct, IBinaryOperator + where TBinaryOperator2 : struct, IBinaryOperator { - int batchSize = 8 * 32; long slen = minSize; - long stail = slen & (batchSize - 1); + var remainder = slen & ((Vector256.Count * 8) - 1); - byte* dstCurr = dstPtr; - byte* dstEnd = dstCurr + (slen - stail); + var dstEndPtr = dstPtr + dstLen; + var dstBatchEndPtr = dstPtr + (slen - remainder); - #region 8x32 - while (dstCurr < dstEnd) + ref var firstKeyPtr = ref srcStartPtrs[0]; + + while (dstPtr < dstBatchEndPtr) { - Vector256 d00 = Avx.LoadVector256(srcStartPtrs[0]); - Vector256 d01 = Avx.LoadVector256(srcStartPtrs[0] + 32); - Vector256 d02 = Avx.LoadVector256(srcStartPtrs[0] + 64); - Vector256 d03 = Avx.LoadVector256(srcStartPtrs[0] + 96); - Vector256 d04 = Avx.LoadVector256(srcStartPtrs[0] + 128); - Vector256 d05 = Avx.LoadVector256(srcStartPtrs[0] + 160); - Vector256 d06 = Avx.LoadVector256(srcStartPtrs[0] + 192); - Vector256 d07 = Avx.LoadVector256(srcStartPtrs[0] + 224); - srcStartPtrs[0] += batchSize; - for (int i = 1; i < srcKeyCount; i++) + var d00 = Vector256.Load(firstKeyPtr); + var d01 = Vector256.Load(firstKeyPtr + Vector256.Count); + var d02 = Vector256.Load(firstKeyPtr + (Vector256.Count * 2)); + var d03 = Vector256.Load(firstKeyPtr + (Vector256.Count * 3)); + var d04 = Vector256.Load(firstKeyPtr + (Vector256.Count * 4)); + var d05 = Vector256.Load(firstKeyPtr + (Vector256.Count * 5)); + var d06 = Vector256.Load(firstKeyPtr + (Vector256.Count * 6)); + var d07 = Vector256.Load(firstKeyPtr + (Vector256.Count * 7)); + + firstKeyPtr += Vector256.Count * 8; + + for (var i = 1; i < srcKeyCount; i++) { - Vector256 s00 = Avx.LoadVector256(srcStartPtrs[i]); - Vector256 s01 = Avx.LoadVector256(srcStartPtrs[i] + 32); - Vector256 s02 = Avx.LoadVector256(srcStartPtrs[i] + 64); - Vector256 s03 = Avx.LoadVector256(srcStartPtrs[i] + 96); - Vector256 s04 = Avx.LoadVector256(srcStartPtrs[i] + 128); - Vector256 s05 = Avx.LoadVector256(srcStartPtrs[i] + 160); - Vector256 s06 = Avx.LoadVector256(srcStartPtrs[i] + 192); - Vector256 s07 = Avx.LoadVector256(srcStartPtrs[i] + 224); - - d00 = Avx2.Or(d00, s00); - d01 = Avx2.Or(d01, s01); - d02 = Avx2.Or(d02, s02); - d03 = Avx2.Or(d03, s03); - d04 = Avx2.Or(d04, s04); - d05 = Avx2.Or(d05, s05); - d06 = Avx2.Or(d06, s06); - d07 = Avx2.Or(d07, s07); - srcStartPtrs[i] += batchSize; + ref var keyStartPtr = ref srcStartPtrs[i]; + + var s00 = Vector256.Load(keyStartPtr); + var s01 = Vector256.Load(keyStartPtr + Vector256.Count); + var s02 = Vector256.Load(keyStartPtr + (Vector256.Count * 2)); + var s03 = Vector256.Load(keyStartPtr + (Vector256.Count * 3)); + var s04 = Vector256.Load(keyStartPtr + (Vector256.Count * 4)); + var s05 = Vector256.Load(keyStartPtr + (Vector256.Count * 5)); + var s06 = Vector256.Load(keyStartPtr + (Vector256.Count * 6)); + var s07 = Vector256.Load(keyStartPtr + (Vector256.Count * 7)); + + d00 = TBinaryOperator.Invoke(d00, s00); + d01 = TBinaryOperator.Invoke(d01, s01); + d02 = TBinaryOperator.Invoke(d02, s02); + d03 = TBinaryOperator.Invoke(d03, s03); + d04 = TBinaryOperator.Invoke(d04, s04); + d05 = TBinaryOperator.Invoke(d05, s05); + d06 = TBinaryOperator.Invoke(d06, s06); + d07 = TBinaryOperator.Invoke(d07, s07); + + keyStartPtr += 
Vector256.Count * 8; } - Avx.Store(dstCurr, d00); - Avx.Store(dstCurr + 32, d01); - Avx.Store(dstCurr + 64, d02); - Avx.Store(dstCurr + 96, d03); - Avx.Store(dstCurr + 128, d04); - Avx.Store(dstCurr + 160, d05); - Avx.Store(dstCurr + 192, d06); - Avx.Store(dstCurr + 224, d07); + Vector256.Store(d00, dstPtr); + Vector256.Store(d01, dstPtr + Vector256.Count); + Vector256.Store(d02, dstPtr + Vector256.Count * 2); + Vector256.Store(d03, dstPtr + Vector256.Count * 3); + Vector256.Store(d04, dstPtr + Vector256.Count * 4); + Vector256.Store(d05, dstPtr + Vector256.Count * 5); + Vector256.Store(d06, dstPtr + Vector256.Count * 6); + Vector256.Store(d07, dstPtr + Vector256.Count * 7); - dstCurr += batchSize; + dstPtr += Vector256.Count * 8; } - if (stail == 0) goto fillTail; - #endregion + if (remainder == 0) goto fillTail; - #region 1x32 - slen = stail; - batchSize = 1 * 32; - stail = slen & (batchSize - 1); - dstEnd = dstCurr + (slen - stail); + slen = remainder; + remainder = slen & (Vector256.Count - 1); + dstBatchEndPtr = dstPtr + (slen - remainder); - while (dstCurr < dstEnd) + while (dstPtr < dstBatchEndPtr) { - Vector256 d00 = Avx.LoadVector256(srcStartPtrs[0]); - srcStartPtrs[0] += batchSize; - for (int i = 1; i < srcKeyCount; i++) - { - Vector256 s00 = Avx.LoadVector256(srcStartPtrs[i]); - d00 = Avx2.Or(d00, s00); - srcStartPtrs[i] += batchSize; - } - Avx.Store(dstCurr, d00); - dstCurr += batchSize; - } - if (stail == 0) goto fillTail; - #endregion - - #region scalar_4x8 - slen = stail; - batchSize = 4 * 8; - stail = slen & (batchSize - 1); - dstEnd = dstCurr + (slen - stail); - while (dstCurr < dstEnd) - { - long d00 = *(long*)(srcStartPtrs[0]); - long d01 = *(long*)(srcStartPtrs[0] + 8); - long d02 = *(long*)(srcStartPtrs[0] + 16); - long d03 = *(long*)(srcStartPtrs[0] + 24); - srcStartPtrs[0] += batchSize; - for (int i = 1; i < srcKeyCount; i++) - { - d00 |= *(long*)(srcStartPtrs[i]); - d01 |= *(long*)(srcStartPtrs[i] + 8); - d02 |= *(long*)(srcStartPtrs[i] + 16); - d03 |= *(long*)(srcStartPtrs[i] + 24); - srcStartPtrs[i] += batchSize; - } + var d00 = Vector256.Load(firstKeyPtr); + firstKeyPtr += Vector256.Count; - *(long*)dstCurr = d00; - *(long*)(dstCurr + 8) = d01; - *(long*)(dstCurr + 16) = d02; - *(long*)(dstCurr + 24) = d03; - dstCurr += batchSize; - } - if (stail == 0) goto fillTail; - #endregion - - #region scalar_1x8 - slen = stail; - batchSize = 8; - stail = slen & (batchSize - 1); - dstEnd = dstCurr + (slen - stail); - while (dstCurr < dstEnd) - { - long d00 = *(long*)(srcStartPtrs[0]); - srcStartPtrs[0] += batchSize; - for (int i = 1; i < srcKeyCount; i++) + for (var i = 1; i < srcKeyCount; i++) { - d00 |= *(long*)(srcStartPtrs[i]); - srcStartPtrs[i] += batchSize; - } - *(long*)dstCurr = d00; - dstCurr += batchSize; - } - #endregion - - fillTail: - #region scalar_1x1 - byte* dstMaxEnd = dstPtr + dstLen; - int offset = 0; - while (dstCurr < dstMaxEnd) - { - byte d00 = 0; - if (srcStartPtrs[0] + offset < srcEndPtrs[0]) - { - d00 = srcStartPtrs[0][offset]; - if (d00 == 0xff) goto writeBack; - } - - for (int i = 1; i < srcKeyCount; i++) - { - if (srcStartPtrs[i] + offset < srcEndPtrs[i]) - { - d00 |= srcStartPtrs[i][offset]; - if (d00 == 0xff) goto writeBack; - } - } - writeBack: - *dstCurr++ = d00; - offset++; - } - #endregion - } - - /// - /// XOR bitop implementation using 256-wide SIMD registers. - /// - /// Output buffer to write BitOp result - /// Output buffer length. - /// Pointer to start of bitmap sources. 
- /// Pointer to end of bitmap sources - /// Number of source keys. - /// Minimum size of source bitmaps. - private static void __bitop_multikey_simdX256_xor(byte* dstPtr, int dstLen, byte** srcStartPtrs, byte** srcEndPtrs, int srcKeyCount, int minSize) - { - int batchSize = 8 * 32; - long slen = minSize; - long stail = slen & (batchSize - 1); - - byte* dstCurr = dstPtr; - byte* dstEnd = dstCurr + (slen - stail); + var s00 = Vector256.Load(srcStartPtrs[i]); + d00 = TBinaryOperator.Invoke(d00, s00); - #region 8x32 - while (dstCurr < dstEnd) - { - Vector256 d00 = Avx.LoadVector256(srcStartPtrs[0]); - Vector256 d01 = Avx.LoadVector256(srcStartPtrs[0] + 32); - Vector256 d02 = Avx.LoadVector256(srcStartPtrs[0] + 64); - Vector256 d03 = Avx.LoadVector256(srcStartPtrs[0] + 96); - Vector256 d04 = Avx.LoadVector256(srcStartPtrs[0] + 128); - Vector256 d05 = Avx.LoadVector256(srcStartPtrs[0] + 160); - Vector256 d06 = Avx.LoadVector256(srcStartPtrs[0] + 192); - Vector256 d07 = Avx.LoadVector256(srcStartPtrs[0] + 224); - srcStartPtrs[0] += batchSize; - for (int i = 1; i < srcKeyCount; i++) - { - Vector256 s00 = Avx.LoadVector256(srcStartPtrs[i]); - Vector256 s01 = Avx.LoadVector256(srcStartPtrs[i] + 32); - Vector256 s02 = Avx.LoadVector256(srcStartPtrs[i] + 64); - Vector256 s03 = Avx.LoadVector256(srcStartPtrs[i] + 96); - Vector256 s04 = Avx.LoadVector256(srcStartPtrs[i] + 128); - Vector256 s05 = Avx.LoadVector256(srcStartPtrs[i] + 160); - Vector256 s06 = Avx.LoadVector256(srcStartPtrs[i] + 192); - Vector256 s07 = Avx.LoadVector256(srcStartPtrs[i] + 224); - - d00 = Avx2.Xor(d00, s00); - d01 = Avx2.Xor(d01, s01); - d02 = Avx2.Xor(d02, s02); - d03 = Avx2.Xor(d03, s03); - d04 = Avx2.Xor(d04, s04); - d05 = Avx2.Xor(d05, s05); - d06 = Avx2.Xor(d06, s06); - d07 = Avx2.Xor(d07, s07); - srcStartPtrs[i] += batchSize; + srcStartPtrs[i] += Vector256.Count; } - Avx.Store(dstCurr, d00); - Avx.Store(dstCurr + 32, d01); - Avx.Store(dstCurr + 64, d02); - Avx.Store(dstCurr + 96, d03); - Avx.Store(dstCurr + 128, d04); - Avx.Store(dstCurr + 160, d05); - Avx.Store(dstCurr + 192, d06); - Avx.Store(dstCurr + 224, d07); + Vector256.Store(d00, dstPtr); - dstCurr += batchSize; + dstPtr += Vector256.Count; } - #endregion + if (remainder == 0) goto fillTail; - #region 1x32 - slen = stail; - batchSize = 1 * 32; - stail = slen & (batchSize - 1); - dstEnd = dstCurr + (slen - stail); + slen = remainder; + remainder = slen & (sizeof(ulong) - 1); + dstBatchEndPtr = dstPtr + (slen - remainder); - while (dstCurr < dstEnd) + while (dstPtr < dstBatchEndPtr) { - Vector256 d00 = Avx.LoadVector256(srcStartPtrs[0]); - srcStartPtrs[0] += batchSize; - for (int i = 1; i < srcKeyCount; i++) - { - Vector256 s00 = Avx.LoadVector256(srcStartPtrs[i]); - d00 = Avx2.Xor(d00, s00); - srcStartPtrs[i] += batchSize; - } - Avx.Store(dstCurr, d00); - dstCurr += batchSize; - } - #endregion - - #region scalar_4x8 - slen = stail; - batchSize = 4 * 8; - stail = slen & (batchSize - 1); - dstEnd = dstCurr + (slen - stail); - while (dstCurr < dstEnd) - { - long d00 = *(long*)(srcStartPtrs[0]); - long d01 = *(long*)(srcStartPtrs[0] + 8); - long d02 = *(long*)(srcStartPtrs[0] + 16); - long d03 = *(long*)(srcStartPtrs[0] + 24); - srcStartPtrs[0] += batchSize; - for (int i = 1; i < srcKeyCount; i++) - { - d00 ^= *(long*)(srcStartPtrs[i]); - d01 ^= *(long*)(srcStartPtrs[i] + 8); - d02 ^= *(long*)(srcStartPtrs[i] + 16); - d03 ^= *(long*)(srcStartPtrs[i] + 24); - srcStartPtrs[i] += batchSize; - } + ulong d00 = *(ulong*)firstKeyPtr; + firstKeyPtr += sizeof(ulong); - 
*(long*)dstCurr = d00; - *(long*)(dstCurr + 8) = d01; - *(long*)(dstCurr + 16) = d02; - *(long*)(dstCurr + 24) = d03; - dstCurr += batchSize; - } - if (stail == 0) goto fillTail; - #endregion - - #region scalar_1x8 - slen = stail; - batchSize = 8; - stail = slen & (batchSize - 1); - dstEnd = dstCurr + (slen - stail); - while (dstCurr < dstEnd) - { - long d00 = *(long*)(srcStartPtrs[0]); - srcStartPtrs[0] += batchSize; - for (int i = 1; i < srcKeyCount; i++) + for (var i = 1; i < srcKeyCount; i++) { - d00 ^= *(long*)(srcStartPtrs[i]); - srcStartPtrs[i] += batchSize; + d00 = TBinaryOperator2.Invoke(d00, *(ulong*)srcStartPtrs[i]); + srcStartPtrs[i] += sizeof(ulong); } - *(long*)dstCurr = d00; - dstCurr += batchSize; + + *(ulong*)dstPtr = d00; + dstPtr += sizeof(ulong); } - #endregion fillTail: - #region scalar_1x1 - byte* dstMaxEnd = dstPtr + dstLen; - while (dstCurr < dstMaxEnd) + while (dstPtr < dstEndPtr) { byte d00 = 0; - if (srcStartPtrs[0] < srcEndPtrs[0]) + + if (firstKeyPtr < srcEndPtrs[0]) { - d00 = *srcStartPtrs[0]; - srcStartPtrs[0]++; + d00 = *firstKeyPtr; + firstKeyPtr++; } - for (int i = 1; i < srcKeyCount; i++) + for (var i = 1; i < srcKeyCount; i++) { if (srcStartPtrs[i] < srcEndPtrs[i]) { - d00 ^= *srcStartPtrs[i]; + d00 = TBinaryOperator.Invoke(d00, *srcStartPtrs[i]); srcStartPtrs[i]++; } + else + { + if (typeof(TBinaryOperator) == typeof(BitwiseAndOperator)) + { + d00 = 0; + } + else if (typeof(TBinaryOperator) == typeof(BitwiseOrOperator)) + { + // nop + } + else if (typeof(TBinaryOperator) == typeof(BitwiseXorOperator)) + { + // TODO: I _think_ there's a error in this logic and we should have here: + // d00 ^= 0; + } + } } - *dstCurr++ = d00; + + *dstPtr++ = d00; } - #endregion } - } } \ No newline at end of file diff --git a/test/Garnet.test/GarnetBitmapTests.cs b/test/Garnet.test/GarnetBitmapTests.cs index 627a4984d4..11a2327417 100644 --- a/test/Garnet.test/GarnetBitmapTests.cs +++ b/test/Garnet.test/GarnetBitmapTests.cs @@ -32,14 +32,9 @@ public void TearDown() TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } - private long LongRandom() => ((long)this.r.Next() << 32) | (long)this.r.Next(); + private long LongRandom() => r.NextInt64(); - private ulong ULongRandom() - { - ulong lsb = (ulong)(this.r.Next()); - ulong msb = (ulong)(this.r.Next()) << 32; - return (msb | lsb); - } + private ulong ULongRandom() => (ulong)r.NextInt64(long.MinValue, long.MaxValue); private unsafe long ResponseToLong(byte[] response, int offset) { @@ -879,16 +874,18 @@ public void BitmapSimpleBitOpTests() } } - private static void InitBitmap(ref byte[] dst, byte[] srcA, bool invert = false) + private static byte[] CopyBitmap(byte[] sourceBitmap, bool invert = false) { - dst = new byte[srcA.Length]; + var dst = new byte[sourceBitmap.Length]; if (invert) - for (int i = 0; i < srcA.Length; i++) dst[i] = (byte)~srcA[i]; + for (int i = 0; i < sourceBitmap.Length; i++) dst[i] = (byte)~sourceBitmap[i]; else - for (int i = 0; i < srcA.Length; i++) dst[i] = srcA[i]; + sourceBitmap.AsSpan().CopyTo(dst); + + return dst; } - private static void ApplyBitop(ref byte[] dst, byte[] srcA, Func f8) + private static void ApplyBitop(ref byte[] dst, byte[] srcA, Func op) { if (dst.Length < srcA.Length) { @@ -899,12 +896,12 @@ private static void ApplyBitop(ref byte[] dst, byte[] srcA, Func f8 = null; - switch (bitwiseOps[j]) - { - case Bitwise.And: - f8 = (a, b) => (byte)(a & b); - break; - case Bitwise.Or: - f8 = (a, b) => (byte)(a | b); - break; - case Bitwise.Xor: - f8 = (a, b) => (byte)(a ^ b); - break; - } 
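An editorial note on the TODO in the fillTail loop above: under BITOP's zero-padding semantics (keys shorter than the longest source behave as if padded with zero bytes), leaving d00 untouched when a source is exhausted is already equivalent to d00 ^= 0, because x ^ 0 == x; only the AND branch needs to force the accumulator to zero. A minimal scalar reference sketch of those semantics, with a hypothetical method name and independent of the SIMD code in this patch, is:

    using System;

    // Reference semantics only: each source shorter than the longest one is
    // treated as if padded with 0x00 up to the output length.
    static byte[] BitOpReference(Func<byte, byte, byte> op, byte[][] srcs)
    {
        int maxLen = 0;
        foreach (var s in srcs) maxLen = Math.Max(maxLen, s.Length);

        var dst = new byte[maxLen];
        for (int i = 0; i < maxLen; i++)
        {
            byte acc = i < srcs[0].Length ? srcs[0][i] : (byte)0;
            for (int k = 1; k < srcs.Length; k++)
            {
                byte b = i < srcs[k].Length ? srcs[k][i] : (byte)0;
                acc = op(acc, b);   // AND sees 0 and clears acc; OR/XOR with 0 leave acc unchanged
            }
            dst[i] = acc;
        }
        return dst;
    }

Passing op = (a, b) => (byte)(a ^ b) reproduces the behaviour of the XOR "nop" branch above, which suggests the existing logic is already correct for XOR.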
+ Func op = bitwiseOps[j] switch + { + Bitwise.And => static (a, b) => (byte)(a & b), + Bitwise.Or => static (a, b) => (byte)(a | b), + Bitwise.Xor => static (a, b) => (byte)(a ^ b) + }; - dataX = null; - InitBitmap(ref dataX, dataA); - ApplyBitop(ref dataX, dataB, f8); - ApplyBitop(ref dataX, dataC, f8); - ApplyBitop(ref dataX, dataD, f8); + byte[] dataX = CopyBitmap(dataA); + ApplyBitop(ref dataX, dataB, op); + ApplyBitop(ref dataX, dataC, op); + ApplyBitop(ref dataX, dataD, op); long size = db.StringBitOperation(bitwiseOps[j], x, keys); ClassicAssert.AreEqual(size, dataX.Length); @@ -1032,7 +1020,7 @@ public void BitmapSimpleBitOpVarLenGrowingSizeTests() string x = "x"; byte[] dataA, dataB, dataC, dataD; - byte[] dataX; + int minSize = 512; Bitwise[] bitwiseOps = [Bitwise.And, Bitwise.Or, Bitwise.Xor, Bitwise.And, Bitwise.Or, Bitwise.Xor]; RedisKey[] keys = [a, b, c, d]; @@ -1042,15 +1030,14 @@ public void BitmapSimpleBitOpVarLenGrowingSizeTests() { dataA = new byte[r.Next(minSize, minSize + 32)]; r.NextBytes(dataA); - db.StringSet(a, dataA); + byte[] expectedX = CopyBitmap(dataA, invert: true); - dataX = null; - InitBitmap(ref dataX, dataA, true); + db.StringSet(a, dataA); long size = db.StringBitOperation(Bitwise.Not, x, a); - ClassicAssert.AreEqual(size, dataX.Length); + ClassicAssert.AreEqual(expectedX.Length, size); - byte[] expectedX = db.StringGet(x); - ClassicAssert.AreEqual(dataX, expectedX); + byte[] actualX = db.StringGet(x); + ClassicAssert.AreEqual(expectedX, actualX); } //Test AND, OR, XOR @@ -1062,8 +1049,7 @@ public void BitmapSimpleBitOpVarLenGrowingSizeTests() dataB = new byte[r.Next(minSize, minSize + 16)]; minSize = dataB.Length; dataC = new byte[r.Next(minSize, minSize + 16)]; minSize = dataC.Length; dataD = new byte[r.Next(minSize, minSize + 16)]; minSize = dataD.Length; - minSize = 17; - + r.NextBytes(dataA); r.NextBytes(dataB); r.NextBytes(dataC); @@ -1074,32 +1060,24 @@ public void BitmapSimpleBitOpVarLenGrowingSizeTests() db.StringSet(c, dataC); db.StringSet(d, dataD); - Func f8 = null; - switch (bitwiseOps[j]) + Func op = bitwiseOps[j] switch { - case Bitwise.And: - f8 = (a, b) => (byte)(a & b); - break; - case Bitwise.Or: - f8 = (a, b) => (byte)(a | b); - break; - case Bitwise.Xor: - f8 = (a, b) => (byte)(a ^ b); - break; - } + Bitwise.And => static (a, b) => (byte)(a & b), + Bitwise.Or => static (a, b) => (byte)(a | b), + Bitwise.Xor => static (a, b) => (byte)(a ^ b) + }; - dataX = null; - InitBitmap(ref dataX, dataA); - ApplyBitop(ref dataX, dataB, f8); - ApplyBitop(ref dataX, dataC, f8); - ApplyBitop(ref dataX, dataD, f8); + byte[] expectedX = CopyBitmap(dataA); + ApplyBitop(ref expectedX, dataB, op); + ApplyBitop(ref expectedX, dataC, op); + ApplyBitop(ref expectedX, dataD, op); long size = db.StringBitOperation(bitwiseOps[j], x, keys); - ClassicAssert.AreEqual(size, dataX.Length); - byte[] expectedX = db.StringGet(x); + ClassicAssert.AreEqual(expectedX.Length, size); + byte[] dataX = db.StringGet(x); - ClassicAssert.AreEqual(expectedX.Length, dataX.Length); - ClassicAssert.AreEqual(dataX, expectedX); + ClassicAssert.AreEqual(expectedX.Length, expectedX.Length); + ClassicAssert.AreEqual(expectedX, dataX); } } } From e691692a0b2a74a6a54c48b6bf08f1d3d2ab9824 Mon Sep 17 00:00:00 2001 From: PaulusParssinen Date: Wed, 22 Jan 2025 03:08:15 +0200 Subject: [PATCH 2/2] Remove TensorPrimitives invoke logic for now --- .../TensorPrimitives.IBinaryOperator.cs | 1059 ----------------- 1 file changed, 1059 deletions(-) diff --git 
a/libs/common/Numerics/TensorPrimitives.IBinaryOperator.cs b/libs/common/Numerics/TensorPrimitives.IBinaryOperator.cs index 6091e0ab30..3c94b67b4a 100644 --- a/libs/common/Numerics/TensorPrimitives.IBinaryOperator.cs +++ b/libs/common/Numerics/TensorPrimitives.IBinaryOperator.cs @@ -45,1064 +45,5 @@ public interface IBinaryOperator static abstract Vector256 Invoke(Vector256 x, Vector256 y); static abstract Vector512 Invoke(Vector512 x, Vector512 y); } - - // TODO: Remove, no attempt to use yet in this PR - public static void UnsafeInvokeOperator( - T* xPtr, T* yPtr, T* dPtr, int length) - where T : unmanaged - where TBinaryOperator : struct, IBinaryOperator - { - // Since every branch has a cost and since that cost is - // essentially lost for larger inputs, we do branches - // in a way that allows us to have the minimum possible - // for small sizes - - nuint remainder = (uint)length; - - if (Vector512.IsHardwareAccelerated && Vector512.IsSupported) - { - if (remainder >= (uint)Vector512.Count) - { - Vectorized512(ref xPtr, ref yPtr, ref dPtr, remainder); - } - else - { - // We have less than a vector and so we can only handle this as scalar. To do this - // efficiently, we simply have a small jump table and fallthrough. So we get a simple - // length check, single jump, and then linear execution. - - VectorizedSmall(ref xPtr, ref yPtr, ref dPtr, remainder); - } - - return; - } - - if (Vector256.IsHardwareAccelerated && Vector256.IsSupported) - { - if (remainder >= (uint)Vector256.Count) - { - Vectorized256(ref xPtr, ref yPtr, ref dPtr, remainder); - } - else - { - // We have less than a vector and so we can only handle this as scalar. To do this - // efficiently, we simply have a small jump table and fallthrough. So we get a simple - // length check, single jump, and then linear execution. - - VectorizedSmall(ref xPtr, ref yPtr, ref dPtr, remainder); - } - - return; - } - - if (Vector128.IsHardwareAccelerated && Vector128.IsSupported) - { - if (remainder >= (uint)Vector128.Count) - { - Vectorized128(ref xPtr, ref yPtr, ref dPtr, remainder); - } - else - { - // We have less than a vector and so we can only handle this as scalar. To do this - // efficiently, we simply have a small jump table and fallthrough. So we get a simple - // length check, single jump, and then linear execution. - - VectorizedSmall(ref xPtr, ref yPtr, ref dPtr, remainder); - } - - return; - } - - // This is the software fallback when no acceleration is available - // It requires no branches to hit - - SoftwareFallback(xPtr, yPtr, dPtr, remainder); - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void SoftwareFallback(T* xPtr, T* yPtr, T* dPtr, nuint length) - { - for (nuint i = 0; i < length; i++) - { - *(dPtr + i) = TBinaryOperator.Invoke(*(xPtr + i), *(yPtr + i)); - } - } - - static void Vectorized128(ref T* xPtr, ref T* yPtr, ref T* dPtr, nuint remainder) - { - ref T* dPtrBeg = ref dPtr; - - // Preload the beginning and end so that overlapping accesses don't negatively impact the data - - Vector128 beg = TBinaryOperator.Invoke(Vector128.Load(xPtr), Vector128.Load(yPtr)); - Vector128 end = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)Vector128.Count), - Vector128.Load(yPtr + remainder - (uint)Vector128.Count)); - - if (remainder > (uint)(Vector128.Count * 8)) - { - // We need to the ensure the underlying data can be aligned and only align - // it if it can. It is possible we have an unaligned ref, in which case we - // can never achieve the required SIMD alignment. 
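For orientation, the destination-alignment step that the removed Vectorized128 body performs just below can be illustrated with concrete numbers; this is an editorial sketch with an assumed address, not code from the patch:

    using System;

    // Worked example of the misalignment arithmetic for byte elements and a
    // 16-byte Vector128. The address 0x1009 is purely hypothetical.
    nuint vectorSize = 16;                                            // sizeof(Vector128<byte>)
    nuint address = 0x1009u;                                          // assumed unaligned destination
    nuint misalignment = (vectorSize - (address % vectorSize)) / 1;   // sizeof(byte) == 1

    Console.WriteLine(misalignment);                           // prints 7
    Console.WriteLine((address + misalignment) % vectorSize);  // prints 0: now 16-byte aligned

Advancing xPtr, yPtr and dPtr by those 7 elements leaves the destination aligned for the unrolled stores; the bytes skipped here are covered later by the preloaded beg block.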
- - bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; - - if (canAlign) - { - // Compute by how many elements we're misaligned and adjust the pointers accordingly - // - // Noting that we are only actually aligning dPtr. This is because unaligned stores - // are more expensive than unaligned loads and aligning both is significantly more - // complex. - - nuint misalignment = ((uint)sizeof(Vector128) - ((nuint)dPtr % (uint)sizeof(Vector128))) / (uint)sizeof(T); - - xPtr += misalignment; - yPtr += misalignment; - dPtr += misalignment; - - Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector128)) == 0); - - remainder -= misalignment; - } - - Vector128 vector1; - Vector128 vector2; - Vector128 vector3; - Vector128 vector4; - - while (remainder >= (uint)(Vector128.Count * 8)) - { - // We load, process, and store the first four vectors - - vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 0))); - vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 1))); - vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 2))); - vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 3))); - - vector1.Store(dPtr + (uint)(Vector128.Count * 0)); - vector2.Store(dPtr + (uint)(Vector128.Count * 1)); - vector3.Store(dPtr + (uint)(Vector128.Count * 2)); - vector4.Store(dPtr + (uint)(Vector128.Count * 3)); - - // We load, process, and store the next four vectors - - vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 4))); - vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 5))); - vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 6))); - vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), - Vector128.Load(yPtr + (uint)(Vector128.Count * 7))); - - vector1.Store(dPtr + (uint)(Vector128.Count * 4)); - vector2.Store(dPtr + (uint)(Vector128.Count * 5)); - vector3.Store(dPtr + (uint)(Vector128.Count * 6)); - vector4.Store(dPtr + (uint)(Vector128.Count * 7)); - - // We adjust the source and destination references, then update - // the count of remaining elements to process. - - xPtr += (uint)(Vector128.Count * 8); - yPtr += (uint)(Vector128.Count * 8); - dPtr += (uint)(Vector128.Count * 8); - - remainder -= (uint)(Vector128.Count * 8); - } - } - - // Process the remaining [Count, Count * 8] elements via a jump table - // - // Unless the original length was an exact multiple of Count, then we'll - // end up reprocessing a couple elements in case 1 for end. We'll also - // potentially reprocess a few elements in case 0 for beg, to handle any - // data before the first aligned address. 
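As a concrete illustration of the reprocessing described above (editorial, with assumed numbers): for byte elements Vector128<byte>.Count is 16, so with remainder == 20 the code records endIndex = 20 and rounds remainder up to (20 + 15) & ~15 == 32. The switch then enters at case 32 / 16 == 2: case 2 stores a freshly computed vector at offset 0, case 1 stores the preloaded end block at offset endIndex - 16 == 4 (overlapping the previous store by 12 elements), and case 0 stores the preloaded beg block at the original destination. The overlapping writes are harmless because every store produces the same operator results for the same positions.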
- - nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); - - switch (remainder / (uint)Vector128.Count) - { - case 8: - { - Vector128 vector = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)(Vector128.Count * 8)), - Vector128.Load(yPtr + remainder - (uint)(Vector128.Count * 8))); - vector.Store(dPtr + remainder - (uint)(Vector128.Count * 8)); - goto case 7; - } - - case 7: - { - Vector128 vector = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)(Vector128.Count * 7)), - Vector128.Load(yPtr + remainder - (uint)(Vector128.Count * 7))); - vector.Store(dPtr + remainder - (uint)(Vector128.Count * 7)); - goto case 6; - } - - case 6: - { - Vector128 vector = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)(Vector128.Count * 6)), - Vector128.Load(yPtr + remainder - (uint)(Vector128.Count * 6))); - vector.Store(dPtr + remainder - (uint)(Vector128.Count * 6)); - goto case 5; - } - - case 5: - { - Vector128 vector = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)(Vector128.Count * 5)), - Vector128.Load(yPtr + remainder - (uint)(Vector128.Count * 5))); - vector.Store(dPtr + remainder - (uint)(Vector128.Count * 5)); - goto case 4; - } - - case 4: - { - Vector128 vector = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)(Vector128.Count * 4)), - Vector128.Load(yPtr + remainder - (uint)(Vector128.Count * 4))); - vector.Store(dPtr + remainder - (uint)(Vector128.Count * 4)); - goto case 3; - } - - case 3: - { - Vector128 vector = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)(Vector128.Count * 3)), - Vector128.Load(yPtr + remainder - (uint)(Vector128.Count * 3))); - vector.Store(dPtr + remainder - (uint)(Vector128.Count * 3)); - goto case 2; - } - - case 2: - { - Vector128 vector = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)(Vector128.Count * 2)), - Vector128.Load(yPtr + remainder - (uint)(Vector128.Count * 2))); - vector.Store(dPtr + remainder - (uint)(Vector128.Count * 2)); - goto case 1; - } - - case 1: - { - // Store the last block, which includes any elements that wouldn't fill a full vector - end.Store(dPtr + endIndex - (uint)Vector128.Count); - goto case 0; - } - - case 0: - { - // Store the first block, which includes any elements preceding the first aligned block - beg.Store(dPtrBeg); - break; - } - } - } - - static void Vectorized256(ref T* xPtr, ref T* yPtr, ref T* dPtr, nuint remainder) - { - ref T* dPtrBeg = ref dPtr; - - // Preload the beginning and end so that overlapping accesses don't negatively impact the data - - Vector256 beg = TBinaryOperator.Invoke(Vector256.Load(xPtr), Vector256.Load(yPtr)); - Vector256 end = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)Vector256.Count), - Vector256.Load(yPtr + remainder - (uint)Vector256.Count)); - - if (remainder > (uint)(Vector256.Count * 8)) - { - // We need to the ensure the underlying data can be aligned and only align - // it if it can. It is possible we have an unaligned ref, in which case we - // can never achieve the required SIMD alignment. - - bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; - - if (canAlign) - { - // Compute by how many elements we're misaligned and adjust the pointers accordingly - // - // Noting that we are only actually aligning dPtr. This is because unaligned stores - // are more expensive than unaligned loads and aligning both is significantly more - // complex. 
- - nuint misalignment = ((uint)sizeof(Vector256) - ((nuint)dPtr % (uint)sizeof(Vector256))) / (uint)sizeof(T); - - xPtr += misalignment; - yPtr += misalignment; - dPtr += misalignment; - - Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector256)) == 0); - - remainder -= misalignment; - } - - Vector256 vector1; - Vector256 vector2; - Vector256 vector3; - Vector256 vector4; - - while (remainder >= (uint)(Vector256.Count * 8)) - { - // We load, process, and store the first four vectors - - vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 0))); - vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 1))); - vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 2))); - vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 3))); - - vector1.Store(dPtr + (uint)(Vector256.Count * 0)); - vector2.Store(dPtr + (uint)(Vector256.Count * 1)); - vector3.Store(dPtr + (uint)(Vector256.Count * 2)); - vector4.Store(dPtr + (uint)(Vector256.Count * 3)); - - // We load, process, and store the next four vectors - - vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 4))); - vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 5))); - vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 6))); - vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), - Vector256.Load(yPtr + (uint)(Vector256.Count * 7))); - - vector1.Store(dPtr + (uint)(Vector256.Count * 4)); - vector2.Store(dPtr + (uint)(Vector256.Count * 5)); - vector3.Store(dPtr + (uint)(Vector256.Count * 6)); - vector4.Store(dPtr + (uint)(Vector256.Count * 7)); - - // We adjust the source and destination references, then update - // the count of remaining elements to process. - - xPtr += (uint)(Vector256.Count * 8); - yPtr += (uint)(Vector256.Count * 8); - dPtr += (uint)(Vector256.Count * 8); - - remainder -= (uint)(Vector256.Count * 8); - } - } - - // Process the remaining [Count, Count * 8] elements via a jump table - // - // Unless the original length was an exact multiple of Count, then we'll - // end up reprocessing a couple elements in case 1 for end. We'll also - // potentially reprocess a few elements in case 0 for beg, to handle any - // data before the first aligned address. 
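The preloaded beg and end vectors above exist so the unaligned head and the partial tail can be finished with overlapping stores. A minimal standalone sketch of that idea follows, as an editorial illustration only: the method name is hypothetical, it assumes equal-length inputs of at least Vector256<byte>.Count bytes with hardware Vector256 support, and it skips the alignment and unrolling the real helper performs.

    using System;
    using System.Runtime.InteropServices;
    using System.Runtime.Intrinsics;

    static void OrBlocks(ReadOnlySpan<byte> x, ReadOnlySpan<byte> y, Span<byte> d)
    {
        nuint count = (nuint)Vector256<byte>.Count;   // 32 bytes per block
        nuint len = (nuint)d.Length;

        ref byte xRef = ref MemoryMarshal.GetReference(x);
        ref byte yRef = ref MemoryMarshal.GetReference(y);
        ref byte dRef = ref MemoryMarshal.GetReference(d);

        // Compute the final block before writing anything, so a destination that
        // aliases a source cannot clobber bytes we still need to read.
        Vector256<byte> end = Vector256.LoadUnsafe(ref xRef, len - count) | Vector256.LoadUnsafe(ref yRef, len - count);

        nuint i = 0;
        for (; i + count <= len; i += count)
        {
            var v = Vector256.LoadUnsafe(ref xRef, i) | Vector256.LoadUnsafe(ref yRef, i);
            v.StoreUnsafe(ref dRef, i);
        }

        // Overlapping store for any leftover tail: it rewrites up to 31 bytes that
        // were already written above with identical values, which is harmless.
        if (i < len)
        {
            end.StoreUnsafe(ref dRef, len - count);
        }
    }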
- - nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); - - switch (remainder / (uint)Vector256.Count) - { - case 8: - { - Vector256 vector = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)(Vector256.Count * 8)), - Vector256.Load(yPtr + remainder - (uint)(Vector256.Count * 8))); - vector.Store(dPtr + remainder - (uint)(Vector256.Count * 8)); - goto case 7; - } - - case 7: - { - Vector256 vector = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)(Vector256.Count * 7)), - Vector256.Load(yPtr + remainder - (uint)(Vector256.Count * 7))); - vector.Store(dPtr + remainder - (uint)(Vector256.Count * 7)); - goto case 6; - } - - case 6: - { - Vector256 vector = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)(Vector256.Count * 6)), - Vector256.Load(yPtr + remainder - (uint)(Vector256.Count * 6))); - vector.Store(dPtr + remainder - (uint)(Vector256.Count * 6)); - goto case 5; - } - - case 5: - { - Vector256 vector = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)(Vector256.Count * 5)), - Vector256.Load(yPtr + remainder - (uint)(Vector256.Count * 5))); - vector.Store(dPtr + remainder - (uint)(Vector256.Count * 5)); - goto case 4; - } - - case 4: - { - Vector256 vector = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)(Vector256.Count * 4)), - Vector256.Load(yPtr + remainder - (uint)(Vector256.Count * 4))); - vector.Store(dPtr + remainder - (uint)(Vector256.Count * 4)); - goto case 3; - } - - case 3: - { - Vector256 vector = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)(Vector256.Count * 3)), - Vector256.Load(yPtr + remainder - (uint)(Vector256.Count * 3))); - vector.Store(dPtr + remainder - (uint)(Vector256.Count * 3)); - goto case 2; - } - - case 2: - { - Vector256 vector = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)(Vector256.Count * 2)), - Vector256.Load(yPtr + remainder - (uint)(Vector256.Count * 2))); - vector.Store(dPtr + remainder - (uint)(Vector256.Count * 2)); - goto case 1; - } - - case 1: - { - // Store the last block, which includes any elements that wouldn't fill a full vector - end.Store(dPtr + endIndex - (uint)Vector256.Count); - goto case 0; - } - - case 0: - { - // Store the first block, which includes any elements preceding the first aligned block - beg.Store(dPtrBeg); - break; - } - } - } - - static void Vectorized512(ref T* xPtr, ref T* yPtr, ref T* dPtr, nuint remainder) - { - ref T* dPtrBeg = ref dPtr; - - // Preload the beginning and end so that overlapping accesses don't negatively impact the data - - Vector512 beg = TBinaryOperator.Invoke(Vector512.Load(xPtr), Vector512.Load(yPtr)); - Vector512 end = TBinaryOperator.Invoke(Vector512.Load(xPtr + remainder - (uint)Vector512.Count), - Vector512.Load(yPtr + remainder - (uint)Vector512.Count)); - - if (remainder > (uint)(Vector512.Count * 8)) - { - // We need to the ensure the underlying data can be aligned and only align - // it if it can. It is possible we have an unaligned ref, in which case we - // can never achieve the required SIMD alignment. - - bool canAlign = ((nuint)dPtr % (nuint)sizeof(T)) == 0; - - if (canAlign) - { - // Compute by how many elements we're misaligned and adjust the pointers accordingly - // - // Noting that we are only actually aligning dPtr. This is because unaligned stores - // are more expensive than unaligned loads and aligning both is significantly more - // complex. 
- - nuint misalignment = ((uint)sizeof(Vector512) - ((nuint)dPtr % (uint)sizeof(Vector512))) / (uint)sizeof(T); - - xPtr += misalignment; - yPtr += misalignment; - dPtr += misalignment; - - Debug.Assert(((nuint)dPtr % (uint)sizeof(Vector512)) == 0); - - remainder -= misalignment; - } - - Vector512 vector1; - Vector512 vector2; - Vector512 vector3; - Vector512 vector4; - - while (remainder >= (uint)(Vector512.Count * 8)) - { - // We load, process, and store the first four vectors - - vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 0))); - vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 1))); - vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 2))); - vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 3))); - - vector1.Store(dPtr + (uint)(Vector512.Count * 0)); - vector2.Store(dPtr + (uint)(Vector512.Count * 1)); - vector3.Store(dPtr + (uint)(Vector512.Count * 2)); - vector4.Store(dPtr + (uint)(Vector512.Count * 3)); - - // We load, process, and store the next four vectors - - vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 4))); - vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 5))); - vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 6))); - vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), - Vector512.Load(yPtr + (uint)(Vector512.Count * 7))); - - vector1.Store(dPtr + (uint)(Vector512.Count * 4)); - vector2.Store(dPtr + (uint)(Vector512.Count * 5)); - vector3.Store(dPtr + (uint)(Vector512.Count * 6)); - vector4.Store(dPtr + (uint)(Vector512.Count * 7)); - - // We adjust the source and destination references, then update - // the count of remaining elements to process. - - xPtr += (uint)(Vector512.Count * 8); - yPtr += (uint)(Vector512.Count * 8); - dPtr += (uint)(Vector512.Count * 8); - - remainder -= (uint)(Vector512.Count * 8); - } - } - - // Process the remaining [Count, Count * 8] elements via a jump table - // - // Unless the original length was an exact multiple of Count, then we'll - // end up reprocessing a couple elements in case 1 for end. We'll also - // potentially reprocess a few elements in case 0 for beg, to handle any - // data before the first aligned address. 
- - nuint endIndex = remainder; - remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); - - switch (remainder / (uint)Vector512.Count) - { - case 8: - { - Vector512 vector = TBinaryOperator.Invoke(Vector512.Load(xPtr + remainder - (uint)(Vector512.Count * 8)), - Vector512.Load(yPtr + remainder - (uint)(Vector512.Count * 8))); - vector.Store(dPtr + remainder - (uint)(Vector512.Count * 8)); - goto case 7; - } - - case 7: - { - Vector512 vector = TBinaryOperator.Invoke(Vector512.Load(xPtr + remainder - (uint)(Vector512.Count * 7)), - Vector512.Load(yPtr + remainder - (uint)(Vector512.Count * 7))); - vector.Store(dPtr + remainder - (uint)(Vector512.Count * 7)); - goto case 6; - } - - case 6: - { - Vector512 vector = TBinaryOperator.Invoke(Vector512.Load(xPtr + remainder - (uint)(Vector512.Count * 6)), - Vector512.Load(yPtr + remainder - (uint)(Vector512.Count * 6))); - vector.Store(dPtr + remainder - (uint)(Vector512.Count * 6)); - goto case 5; - } - - case 5: - { - Vector512 vector = TBinaryOperator.Invoke(Vector512.Load(xPtr + remainder - (uint)(Vector512.Count * 5)), - Vector512.Load(yPtr + remainder - (uint)(Vector512.Count * 5))); - vector.Store(dPtr + remainder - (uint)(Vector512.Count * 5)); - goto case 4; - } - - case 4: - { - Vector512 vector = TBinaryOperator.Invoke(Vector512.Load(xPtr + remainder - (uint)(Vector512.Count * 4)), - Vector512.Load(yPtr + remainder - (uint)(Vector512.Count * 4))); - vector.Store(dPtr + remainder - (uint)(Vector512.Count * 4)); - goto case 3; - } - - case 3: - { - Vector512 vector = TBinaryOperator.Invoke(Vector512.Load(xPtr + remainder - (uint)(Vector512.Count * 3)), - Vector512.Load(yPtr + remainder - (uint)(Vector512.Count * 3))); - vector.Store(dPtr + remainder - (uint)(Vector512.Count * 3)); - goto case 2; - } - - case 2: - { - Vector512 vector = TBinaryOperator.Invoke(Vector512.Load(xPtr + remainder - (uint)(Vector512.Count * 2)), - Vector512.Load(yPtr + remainder - (uint)(Vector512.Count * 2))); - vector.Store(dPtr + remainder - (uint)(Vector512.Count * 2)); - goto case 1; - } - - case 1: - { - // Store the last block, which includes any elements that wouldn't fill a full vector - end.Store(dPtr + endIndex - (uint)Vector512.Count); - goto case 0; - } - - case 0: - { - // Store the first block, which includes any elements preceding the first aligned block - beg.Store(dPtrBeg); - break; - } - } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void VectorizedSmall(ref T* xPtr, ref T* yPtr, ref T* dPtr, nuint remainder) - { - if (sizeof(T) == 1) - { - VectorizedSmall1(ref xPtr, ref yPtr, ref dPtr, remainder); - } - else if (sizeof(T) == 2) - { - VectorizedSmall2(ref xPtr, ref yPtr, ref dPtr, remainder); - } - else if (sizeof(T) == 4) - { - VectorizedSmall4(ref xPtr, ref yPtr, ref dPtr, remainder); - } - else - { - Debug.Assert(sizeof(T) == 8); - VectorizedSmall8(ref xPtr, ref yPtr, ref dPtr, remainder); - } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void VectorizedSmall1(ref T* xPtr, ref T* yPtr, ref T* dPtr, nuint remainder) - { - Debug.Assert(sizeof(T) == 1); - - switch (remainder) - { - // Two Vector256's worth of data, with at least one element overlapping. 
- case 63: - case 62: - case 61: - case 60: - case 59: - case 58: - case 57: - case 56: - case 55: - case 54: - case 53: - case 52: - case 51: - case 50: - case 49: - case 48: - case 47: - case 46: - case 45: - case 44: - case 43: - case 42: - case 41: - case 40: - case 39: - case 38: - case 37: - case 36: - case 35: - case 34: - case 33: - { - Debug.Assert(Vector256.IsHardwareAccelerated); - - Vector256 beg = TBinaryOperator.Invoke(Vector256.Load(xPtr), Vector256.Load(yPtr)); - Vector256 end = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)Vector256.Count), - Vector256.Load(yPtr + remainder - (uint)Vector256.Count)); - - beg.Store(dPtr); - end.Store(dPtr + remainder - (uint)Vector256.Count); - - break; - } - - // One Vector256's worth of data. - case 32: - { - Debug.Assert(Vector256.IsHardwareAccelerated); - - Vector256 beg = TBinaryOperator.Invoke(Vector256.Load(xPtr), Vector256.Load(yPtr)); - beg.Store(dPtr); - - break; - } - - // Two Vector128's worth of data, with at least one element overlapping. - case 31: - case 30: - case 29: - case 28: - case 27: - case 26: - case 25: - case 24: - case 23: - case 22: - case 21: - case 20: - case 19: - case 18: - case 17: - { - Debug.Assert(Vector128.IsHardwareAccelerated); - - Vector128 beg = TBinaryOperator.Invoke(Vector128.Load(xPtr), Vector128.Load(yPtr)); - Vector128 end = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)Vector128.Count), - Vector128.Load(yPtr + remainder - (uint)Vector128.Count)); - - beg.Store(dPtr); - end.Store(dPtr + remainder - (uint)Vector128.Count); - - break; - } - - // One Vector128's worth of data. - case 16: - { - Debug.Assert(Vector128.IsHardwareAccelerated); - - Vector128 beg = TBinaryOperator.Invoke(Vector128.Load(xPtr), Vector128.Load(yPtr)); - beg.Store(dPtr); - - break; - } - - // Cases that are smaller than a single vector. No SIMD; just jump to the length and fall through each - // case to unroll the whole processing. 
- case 15: - *(dPtr + 14) = TBinaryOperator.Invoke(*(xPtr + 14), *(yPtr + 14)); - goto case 14; - - case 14: - *(dPtr + 13) = TBinaryOperator.Invoke(*(xPtr + 13), *(yPtr + 13)); - goto case 13; - - case 13: - *(dPtr + 12) = TBinaryOperator.Invoke(*(xPtr + 12), *(yPtr + 12)); - goto case 12; - - case 12: - *(dPtr + 11) = TBinaryOperator.Invoke(*(xPtr + 11), *(yPtr + 11)); - goto case 11; - - case 11: - *(dPtr + 10) = TBinaryOperator.Invoke(*(xPtr + 10), *(yPtr + 10)); - goto case 10; - - case 10: - *(dPtr + 9) = TBinaryOperator.Invoke(*(xPtr + 9), *(yPtr + 9)); - goto case 9; - - case 9: - *(dPtr + 8) = TBinaryOperator.Invoke(*(xPtr + 8), *(yPtr + 8)); - goto case 8; - - case 8: - *(dPtr + 7) = TBinaryOperator.Invoke(*(xPtr + 7), *(yPtr + 7)); - goto case 7; - - case 7: - *(dPtr + 6) = TBinaryOperator.Invoke(*(xPtr + 6), *(yPtr + 6)); - goto case 6; - - case 6: - *(dPtr + 5) = TBinaryOperator.Invoke(*(xPtr + 5), *(yPtr + 5)); - goto case 5; - - case 5: - *(dPtr + 4) = TBinaryOperator.Invoke(*(xPtr + 4), *(yPtr + 4)); - goto case 4; - - case 4: - *(dPtr + 3) = TBinaryOperator.Invoke(*(xPtr + 3), *(yPtr + 3)); - goto case 3; - - case 3: - *(dPtr + 2) = TBinaryOperator.Invoke(*(xPtr + 2), *(yPtr + 2)); - goto case 2; - - case 2: - *(dPtr + 1) = TBinaryOperator.Invoke(*(xPtr + 1), *(yPtr + 1)); - goto case 1; - - case 1: - *dPtr = TBinaryOperator.Invoke(*xPtr, *yPtr); - goto case 0; - - case 0: - break; - } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void VectorizedSmall2(ref T* xPtr, ref T* yPtr, ref T* dPtr, nuint remainder) - { - Debug.Assert(sizeof(T) == 2); - - switch (remainder) - { - // Two Vector256's worth of data, with at least one element overlapping. - case 31: - case 30: - case 29: - case 28: - case 27: - case 26: - case 25: - case 24: - case 23: - case 22: - case 21: - case 20: - case 19: - case 18: - case 17: - { - Debug.Assert(Vector256.IsHardwareAccelerated); - - Vector256 beg = TBinaryOperator.Invoke(Vector256.Load(xPtr), Vector256.Load(yPtr)); - Vector256 end = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)Vector256.Count), - Vector256.Load(yPtr + remainder - (uint)Vector256.Count)); - - beg.Store(dPtr); - end.Store(dPtr + remainder - (uint)Vector256.Count); - - break; - } - - // One Vector256's worth of data. - case 16: - { - Debug.Assert(Vector256.IsHardwareAccelerated); - - Vector256 beg = TBinaryOperator.Invoke(Vector256.Load(xPtr), Vector256.Load(yPtr)); - beg.Store(dPtr); - - break; - } - - // Two Vector128's worth of data, with at least one element overlapping. - case 15: - case 14: - case 13: - case 12: - case 11: - case 10: - case 9: - { - Debug.Assert(Vector128.IsHardwareAccelerated); - - Vector128 beg = TBinaryOperator.Invoke(Vector128.Load(xPtr), Vector128.Load(yPtr)); - Vector128 end = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)Vector128.Count), - Vector128.Load(yPtr + remainder - (uint)Vector128.Count)); - - beg.Store(dPtr); - end.Store(dPtr + remainder - (uint)Vector128.Count); - - break; - } - - // One Vector128's worth of data. - case 8: - { - Debug.Assert(Vector128.IsHardwareAccelerated); - - Vector128 beg = TBinaryOperator.Invoke(Vector128.Load(xPtr), Vector128.Load(yPtr)); - beg.Store(dPtr); - - break; - } - - // Cases that are smaller than a single vector. No SIMD; just jump to the length and fall through each - // case to unroll the whole processing. 
- case 7: - *(dPtr + 6) = TBinaryOperator.Invoke(*(xPtr + 6), *(yPtr + 6)); - goto case 6; - - case 6: - *(dPtr + 5) = TBinaryOperator.Invoke(*(xPtr + 5), *(yPtr + 5)); - goto case 5; - - case 5: - *(dPtr + 4) = TBinaryOperator.Invoke(*(xPtr + 4), *(yPtr + 4)); - goto case 4; - - case 4: - *(dPtr + 3) = TBinaryOperator.Invoke(*(xPtr + 3), *(yPtr + 3)); - goto case 3; - - case 3: - *(dPtr + 2) = TBinaryOperator.Invoke(*(xPtr + 2), *(yPtr + 2)); - goto case 2; - - case 2: - *(dPtr + 1) = TBinaryOperator.Invoke(*(xPtr + 1), *(yPtr + 1)); - goto case 1; - - case 1: - *dPtr = TBinaryOperator.Invoke(*xPtr, *yPtr); - goto case 0; - - case 0: - break; - } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void VectorizedSmall4(ref T* xPtr, ref T* yPtr, ref T* dPtr, nuint remainder) - { - Debug.Assert(sizeof(T) == 4); - - switch (remainder) - { - case 15: - case 14: - case 13: - case 12: - case 11: - case 10: - case 9: - { - Debug.Assert(Vector256.IsHardwareAccelerated); - - Vector256 beg = TBinaryOperator.Invoke(Vector256.Load(xPtr), Vector256.Load(yPtr)); - Vector256 end = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)Vector256.Count), - Vector256.Load(yPtr + remainder - (uint)Vector256.Count)); - - beg.Store(dPtr); - end.Store(dPtr + remainder - (uint)Vector256.Count); - - break; - } - - case 8: - { - Debug.Assert(Vector256.IsHardwareAccelerated); - - Vector256 beg = TBinaryOperator.Invoke(Vector256.Load(xPtr), Vector256.Load(yPtr)); - beg.Store(dPtr); - - break; - } - - case 7: - case 6: - case 5: - { - Debug.Assert(Vector128.IsHardwareAccelerated); - - Vector128 beg = TBinaryOperator.Invoke(Vector128.Load(xPtr), Vector128.Load(yPtr)); - Vector128 end = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)Vector128.Count), - Vector128.Load(yPtr + remainder - (uint)Vector128.Count)); - - beg.Store(dPtr); - end.Store(dPtr + remainder - (uint)Vector128.Count); - - break; - } - - case 4: - { - Debug.Assert(Vector128.IsHardwareAccelerated); - - Vector128 beg = TBinaryOperator.Invoke(Vector128.Load(xPtr), Vector128.Load(yPtr)); - beg.Store(dPtr); - - break; - } - - case 3: - { - *(dPtr + 2) = TBinaryOperator.Invoke(*(xPtr + 2), *(yPtr + 2)); - goto case 2; - } - - case 2: - { - *(dPtr + 1) = TBinaryOperator.Invoke(*(xPtr + 1), *(yPtr + 1)); - goto case 1; - } - - case 1: - { - *dPtr = TBinaryOperator.Invoke(*xPtr, *yPtr); - goto case 0; - } - - case 0: - { - break; - } - } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void VectorizedSmall8(ref T* xPtr, ref T* yPtr, ref T* dPtr, nuint remainder) - { - Debug.Assert(sizeof(T) == 8); - - switch (remainder) - { - case 7: - case 6: - case 5: - { - Debug.Assert(Vector256.IsHardwareAccelerated); - - Vector256 beg = TBinaryOperator.Invoke(Vector256.Load(xPtr), Vector256.Load(yPtr)); - Vector256 end = TBinaryOperator.Invoke(Vector256.Load(xPtr + remainder - (uint)Vector256.Count), - Vector256.Load(yPtr + remainder - (uint)Vector256.Count)); - - beg.Store(dPtr); - end.Store(dPtr + remainder - (uint)Vector256.Count); - - break; - } - - case 4: - { - Debug.Assert(Vector256.IsHardwareAccelerated); - - Vector256 beg = TBinaryOperator.Invoke(Vector256.Load(xPtr), Vector256.Load(yPtr)); - beg.Store(dPtr); - - break; - } - - case 3: - { - Debug.Assert(Vector128.IsHardwareAccelerated); - - Vector128 beg = TBinaryOperator.Invoke(Vector128.Load(xPtr), Vector128.Load(yPtr)); - Vector128 end = TBinaryOperator.Invoke(Vector128.Load(xPtr + remainder - (uint)Vector128.Count), - Vector128.Load(yPtr + remainder - 
(uint)Vector128.Count)); - - beg.Store(dPtr); - end.Store(dPtr + remainder - (uint)Vector128.Count); - - break; - } - - case 2: - { - Debug.Assert(Vector128.IsHardwareAccelerated); - - Vector128 beg = TBinaryOperator.Invoke(Vector128.Load(xPtr), Vector128.Load(yPtr)); - beg.Store(dPtr); - - break; - } - - case 1: - { - *dPtr = TBinaryOperator.Invoke(*xPtr, *yPtr); - goto case 0; - } - - case 0: - { - break; - } - } - } - } } } \ No newline at end of file
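With PATCH 2/2 applied, TensorPrimitives.IBinaryOperator.cs keeps only the bitwise operator structs and the IBinaryOperator interface. For orientation, a minimal consumer of those static abstract operators could look like the following; this is an editorial sketch under the assumption that the structs and interface are generic over T as their constraints indicate, and the span-based helper is hypothetical (the refactored BitmapManagerBitOp in PATCH 1/2 works over raw pointers and multiple source keys instead):

    using System;
    using Garnet.common.Numerics;

    static class BitOpExample
    {
        // TOp is a struct type argument, so each instantiation lets the JIT inline the operator.
        static void Apply<TOp>(ReadOnlySpan<byte> x, ReadOnlySpan<byte> y, Span<byte> d)
            where TOp : struct, TensorPrimitives.IBinaryOperator<byte>
        {
            for (int i = 0; i < d.Length; i++)
            {
                d[i] = TOp.Invoke(x[i], y[i]);
            }
        }

        static void Demo()
        {
            Span<byte> d = stackalloc byte[4];
            Apply<TensorPrimitives.BitwiseXorOperator<byte>>(
                new byte[] { 1, 2, 3, 4 }, new byte[] { 255, 0, 255, 0 }, d);
            // d is now { 254, 2, 252, 4 }
        }
    }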