diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index fd8a6aaa9..2482bfa83 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -128,7 +128,7 @@ jobs: build_type: pull-request node_type: "gpu-v100-latest-1" arch: "amd64" - container_image: "rapidsai/ci-conda:cuda12.8.0-ubuntu24.04-py3.12" + container_image: "rapidsai/ci-conda:latest" run_script: "ci/build_docs.sh" rust-build: needs: conda-cpp-build diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 3ed3227e0..65b1471f5 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -208,6 +208,9 @@ if(BUILD_SHARED_LIBS) src/neighbors/cagra_search_int8.cu src/neighbors/cagra_search_uint8.cu src/neighbors/detail/cagra/compute_distance.cu + src/neighbors/detail/cagra/compute_distance_standard_BitwiseHamming_u8_uint32_dim128_t8.cu + src/neighbors/detail/cagra/compute_distance_standard_BitwiseHamming_u8_uint32_dim256_t16.cu + src/neighbors/detail/cagra/compute_distance_standard_BitwiseHamming_u8_uint32_dim512_t32.cu src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_float_uint32_dim128_t8.cu src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_float_uint32_dim256_t16.cu src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_float_uint32_dim512_t32.cu @@ -469,6 +472,7 @@ if(BUILD_SHARED_LIBS) src/neighbors/vamana_serialize_uint8.cu src/neighbors/vamana_serialize_int8.cu src/preprocessing/quantize/scalar.cu + src/preprocessing/quantize/binary.cu src/selection/select_k_float_int64_t.cu src/selection/select_k_float_int32_t.cu src/selection/select_k_float_uint32_t.cu diff --git a/cpp/include/cuvs/distance/distance.h b/cpp/include/cuvs/distance/distance.h index 550221e8e..8d76d5f6d 100644 --- a/cpp/include/cuvs/distance/distance.h +++ b/cpp/include/cuvs/distance/distance.h @@ -62,6 +62,8 @@ typedef enum { RusselRaoExpanded = 18, /** Dice-Sorensen distance **/ DiceExpanded = 19, + /** Bitstring Hamming distance **/ + BitwiseHamming = 20, /** Precomputed (special value) **/ Precomputed = 100 } cuvsDistanceType; diff --git a/cpp/include/cuvs/preprocessing/quantize/binary.hpp b/cpp/include/cuvs/preprocessing/quantize/binary.hpp new file mode 100644 index 000000000..4f7d36048 --- /dev/null +++ b/cpp/include/cuvs/preprocessing/quantize/binary.hpp @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace cuvs::preprocessing::quantize::binary { + +/** + * @defgroup binary Binary quantizer utilities + * @{ + */ + +/** + * @brief Applies binary quantization transform to given dataset. If a dataset element is positive, + * set the corresponding bit to 1. + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * auto quantized_dataset = raft::make_device_matrix(handle, samples, + * features); cuvs::preprocessing::quantize::binary::transform(handle, dataset, + * quantized_dataset.view()); + * @endcode + * + * @param[in] res raft resource + * @param[in] dataset a row-major matrix view on device + * @param[out] out a row-major matrix view on device + * + */ +void transform(raft::resources const& res, + raft::device_matrix_view dataset, + raft::device_matrix_view out); + +/** + * @brief Applies binary quantization transform to given dataset. If a dataset element is positive, + * set the corresponding bit to 1. + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * auto quantized_dataset = raft::make_host_matrix(handle, samples, + * features); cuvs::preprocessing::quantize::binary::transform(handle, dataset, + * quantized_dataset.view()); + * @endcode + * + * @param[in] res raft resource + * @param[in] dataset a row-major matrix view on host + * @param[out] out a row-major matrix view on host + * + */ +void transform(raft::resources const& res, + raft::host_matrix_view dataset, + raft::host_matrix_view out); + +/** + * @brief Applies binary quantization transform to given dataset. If a dataset element is positive, + * set the corresponding bit to 1. + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * raft::device_matrix dataset = read_dataset(filename); + * int64_t quantized_dim = raft::div_rounding_up_safe(dataset.extent(1), sizeof(uint8_t) * 8); + * auto quantized_dataset = raft::make_device_matrix( + * handle, dataset.extent(0), quantized_dim); + * cuvs::preprocessing::quantize::binary::transform(handle, dataset, quantized_dataset.view()); + * @endcode + * + * @param[in] res raft resource + * @param[in] dataset a row-major matrix view on device + * @param[out] out a row-major matrix view on device + * + */ +void transform(raft::resources const& res, + raft::device_matrix_view dataset, + raft::device_matrix_view out); + +/** + * @brief Applies binary quantization transform to given dataset. If a dataset element is positive, + * set the corresponding bit to 1. + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * raft::host_matrix dataset = read_dataset(filename); + * int64_t quantized_dim = raft::div_rounding_up_safe(dataset.extent(1), sizeof(uint8_t) * 8); + * auto quantized_dataset = raft::make_host_matrix( + * handle, dataset.extent(0), quantized_dim); + * cuvs::preprocessing::quantize::binary::transform(handle, dataset, quantized_dataset.view()); + * @endcode + * + * @param[in] res raft resource + * @param[in] dataset a row-major matrix view on host + * @param[out] out a row-major matrix view on host + * + */ +void transform(raft::resources const& res, + raft::host_matrix_view dataset, + raft::host_matrix_view out); + +/** + * @brief Applies binary quantization transform to given dataset. If a dataset element is positive, + * set the corresponding bit to 1. + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * raft::device_matrix dataset = read_dataset(filename); + * int64_t quantized_dim = raft::div_rounding_up_safe(dataset.extent(1), sizeof(uint8_t) * 8); + * auto quantized_dataset = raft::make_device_matrix( + * handle, dataset.extent(0), quantized_dim); + * cuvs::preprocessing::quantize::binary::transform(handle, dataset, quantized_dataset.view()); + * @endcode + * + * @param[in] res raft resource + * @param[in] dataset a row-major matrix view on device + * @param[out] out a row-major matrix view on device + * + */ +void transform(raft::resources const& res, + raft::device_matrix_view dataset, + raft::device_matrix_view out); + +/** + * @brief Applies binary quantization transform to given dataset. If a dataset element is positive, + * set the corresponding bit to 1. + * + * Usage example: + * @code{.cpp} + * raft::handle_t handle; + * raft::host_matrix dataset = read_dataset(filename); + * int64_t quantized_dim = raft::div_rounding_up_safe(dataset.extent(1), sizeof(uint8_t) * 8); + * auto quantized_dataset = raft::make_host_matrix( + * handle, dataset.extent(0), quantized_dim); + * cuvs::preprocessing::quantize::binary::transform(handle, dataset, quantized_dataset.view()); + * @endcode + * + * @param[in] res raft resource + * @param[in] dataset a row-major matrix view on host + * @param[out] out a row-major matrix view on host + * + */ +void transform(raft::resources const& res, + raft::host_matrix_view dataset, + raft::host_matrix_view out); + +/** @} */ // end of group binary + +} // namespace cuvs::preprocessing::quantize::binary diff --git a/cpp/internal/cuvs_internal/neighbors/naive_knn.cuh b/cpp/internal/cuvs_internal/neighbors/naive_knn.cuh index 1bde550ab..5c9c01f54 100644 --- a/cpp/internal/cuvs_internal/neighbors/naive_knn.cuh +++ b/cpp/internal/cuvs_internal/neighbors/naive_knn.cuh @@ -45,19 +45,27 @@ RAFT_KERNEL naive_distance_kernel(EvalT* dist, for (IdxT i = 0; i < k; ++i) { IdxT xidx = i + midx * k; IdxT yidx = i + nidx * k; - auto xv = EvalT(x[xidx]); - auto yv = EvalT(y[yidx]); + auto xv = x[xidx]; + auto yv = y[yidx]; switch (metric) { case cuvs::distance::DistanceType::InnerProduct: { - acc += xv * yv; + acc += static_cast(xv) * static_cast(yv); + } break; + case cuvs::distance::DistanceType::CosineExpanded: { + acc += static_cast(xv) * static_cast(yv); } break; case cuvs::distance::DistanceType::L2SqrtExpanded: case cuvs::distance::DistanceType::L2SqrtUnexpanded: case cuvs::distance::DistanceType::L2Expanded: case cuvs::distance::DistanceType::L2Unexpanded: { - auto diff = xv - yv; + auto diff = static_cast(xv) - static_cast(yv); acc += diff * diff; } break; + case cuvs::distance::DistanceType::BitwiseHamming: { + if constexpr (std::is_same_v) { + acc += __popc(static_cast(xv ^ yv) & 0xff); + } + } break; default: break; } } diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index ee4b1444f..4d559a662 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -602,6 +602,11 @@ index build( knn_build_params = cagra::graph_build_params::ivf_pq_params(dataset.extents(), params.metric); } } + RAFT_EXPECTS( + params.metric != BitwiseHamming || + std::holds_alternative(knn_build_params), + "IVF_PQ and NN_DESCENT for CAGRA graph build do not support BitwiseHamming as a metric. Please " + "use the iterative CAGRA search build."); auto cagra_graph = raft::make_host_matrix(0, 0); diff --git a/cpp/src/neighbors/detail/cagra/compute_distance-ext.cuh b/cpp/src/neighbors/detail/cagra/compute_distance-ext.cuh index df447d196..faca96080 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance-ext.cuh +++ b/cpp/src/neighbors/detail/cagra/compute_distance-ext.cuh @@ -392,6 +392,24 @@ extern template struct vpq_descriptor_spec; +extern template struct standard_descriptor_spec; +extern template struct standard_descriptor_spec; +extern template struct standard_descriptor_spec; extern template struct instance_selector< standard_descriptor_spec, @@ -441,7 +459,10 @@ extern template struct instance_selector< standard_descriptor_spec, standard_descriptor_spec, vpq_descriptor_spec, - vpq_descriptor_spec>; + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec>; using descriptor_instances = instance_selector< standard_descriptor_spec, @@ -491,7 +512,10 @@ using descriptor_instances = instance_selector< standard_descriptor_spec, standard_descriptor_spec, vpq_descriptor_spec, - vpq_descriptor_spec>; + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec>; template auto dataset_descriptor_init(const cagra::search_params& params, diff --git a/cpp/src/neighbors/detail/cagra/compute_distance.cu b/cpp/src/neighbors/detail/cagra/compute_distance.cu index 45316e59b..415886346 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance.cu +++ b/cpp/src/neighbors/detail/cagra/compute_distance.cu @@ -77,6 +77,9 @@ template struct instance_selector< standard_descriptor_spec, standard_descriptor_spec, vpq_descriptor_spec, - vpq_descriptor_spec>; + vpq_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec, + standard_descriptor_spec>; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py b/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py index aef31d161..1c813f1a5 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py @@ -113,6 +113,24 @@ f.write(template.format(includes=includes, content=content)) cmake_list.append(f" src/neighbors/detail/cagra/{path}") +# CAGRA (Binary Hamming distance) +for (mxdim, team) in mxdim_team: + metric = 'BitwiseHamming' + type_path = 'u8_uint32' + idx_t = 'uint32_t' + distance_t = 'float' + data_t = 'uint8_t' + + path = f"compute_distance_standard_{metric}_{type_path}_dim{mxdim}_t{team}.cu" + includes = '#include "compute_distance_standard-impl.cuh"' + params = f"{metric_prefix}{metric}, {team}, {mxdim}, {data_t}, {idx_t}, {distance_t}" + spec = f"standard_descriptor_spec<{params}>" + content = f"""template struct {spec};""" + specs.append(spec) + with open(path, "w") as f: + f.write(template.format(includes=includes, content=content)) + cmake_list.append(f" src/neighbors/detail/cagra/{path}") + with open("compute_distance-ext.cuh", "w") as f: includes = ''' #pragma once diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh b/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh index 877d83fff..fdc873100 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh @@ -25,19 +25,30 @@ namespace cuvs::neighbors::cagra::detail { namespace { -template -RAFT_DEVICE_INLINE_FUNCTION constexpr auto dist_op(T a, T b) - -> std::enable_if_t +template +RAFT_DEVICE_INLINE_FUNCTION constexpr auto dist_op(DATA_T a, DATA_T b) + -> std::enable_if_t { - T diff = a - b; + DISTANCE_T diff = a - b; return diff * diff; } -template -RAFT_DEVICE_INLINE_FUNCTION constexpr auto dist_op(T a, T b) - -> std::enable_if_t +template +RAFT_DEVICE_INLINE_FUNCTION constexpr auto dist_op(DATA_T a, DATA_T b) + -> std::enable_if_t { - return -a * b; + return -static_cast(a) * static_cast(b); +} + +template +RAFT_DEVICE_INLINE_FUNCTION constexpr auto dist_op(DATA_T a, DATA_T b) + -> std::enable_if_t, + DISTANCE_T> +{ + // mask the result of xor for the integer promotion + const auto v = (a ^ b) & 0xffu; + return __popc(v); } } // namespace @@ -49,7 +60,8 @@ template struct standard_dataset_descriptor_t : public dataset_descriptor_base_t { using base_type = dataset_descriptor_base_t; - using QUERY_T = float; + using QUERY_T = typename std:: + conditional_t; using base_type::args; using base_type::smem_ws_size_in_bytes; using typename base_type::args_t; @@ -150,7 +162,7 @@ _RAFT_DEVICE __noinline__ auto setup_workspace_standard( if (i < dim) { buf[j] = cuvs::spatial::knn::detail::utils::mapping{}(queries_ptr[i]); } else { - buf[j] = 0.0; + buf[j] = 0; } } @@ -194,13 +206,13 @@ RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_standard_worker( // because: // - Above the last element (dataset_dim-1), the query array is filled with zeros. // - The data buffer has to be also padded with zeros. - DISTANCE_T d; + QUERY_T d; device::lds( d, query_smem_ptr + sizeof(QUERY_T) * device::swizzling(k + v)); - r += dist_op( - d, cuvs::spatial::knn::detail::utils::mapping{}(data[e][v])); + r += dist_op( + d, cuvs::spatial::knn::detail::utils::mapping{}(data[e][v])); } } } diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_BitwiseHamming_u8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_BitwiseHamming_u8_uint32_dim128_t8.cu new file mode 100644 index 000000000..477ecc664 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_BitwiseHamming_u8_uint32_dim128_t8.cu @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by compute_distance_00_generate.py + * + * Make changes there and run in this directory: + * + * > python compute_distance_00_generate.py + * + */ + +#include "compute_distance_standard-impl.cuh" + +namespace cuvs::neighbors::cagra::detail { + +using namespace cuvs::distance; +template struct standard_descriptor_spec; + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_BitwiseHamming_u8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_BitwiseHamming_u8_uint32_dim256_t16.cu new file mode 100644 index 000000000..3dfc2538f --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_BitwiseHamming_u8_uint32_dim256_t16.cu @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by compute_distance_00_generate.py + * + * Make changes there and run in this directory: + * + * > python compute_distance_00_generate.py + * + */ + +#include "compute_distance_standard-impl.cuh" + +namespace cuvs::neighbors::cagra::detail { + +using namespace cuvs::distance; +template struct standard_descriptor_spec; + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard_BitwiseHamming_u8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/compute_distance_standard_BitwiseHamming_u8_uint32_dim512_t32.cu new file mode 100644 index 000000000..2c9f48db7 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard_BitwiseHamming_u8_uint32_dim512_t32.cu @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by compute_distance_00_generate.py + * + * Make changes there and run in this directory: + * + * > python compute_distance_00_generate.py + * + */ + +#include "compute_distance_standard-impl.cuh" + +namespace cuvs::neighbors::cagra::detail { + +using namespace cuvs::distance; +template struct standard_descriptor_spec; + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/device_common.hpp b/cpp/src/neighbors/detail/cagra/device_common.hpp index e5886582d..882928add 100644 --- a/cpp/src/neighbors/detail/cagra/device_common.hpp +++ b/cpp/src/neighbors/detail/cagra/device_common.hpp @@ -281,6 +281,13 @@ RAFT_DEVICE_INLINE_FUNCTION void lds(half (&x)[4], uint32_t addr) : "r"(addr)); } +RAFT_DEVICE_INLINE_FUNCTION void lds(uint8_t& x, uint32_t addr) +{ + uint32_t res; + asm volatile("ld.shared.u8 {%0}, [%1];" : "=r"(res) : "r"(addr)); + x = static_cast(res); +} + RAFT_DEVICE_INLINE_FUNCTION void lds(uint32_t& x, uint32_t addr) { asm volatile("ld.shared.u32 {%0}, [%1];" : "=r"(x) : "r"(addr)); diff --git a/cpp/src/neighbors/ivf_common.cuh b/cpp/src/neighbors/ivf_common.cuh index fb73fb8a9..0c6f7c5ba 100644 --- a/cpp/src/neighbors/ivf_common.cuh +++ b/cpp/src/neighbors/ivf_common.cuh @@ -268,6 +268,7 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk] raft::linalg::unaryOp(out, in, len, raft::cast_op{}, stream); } } break; + case distance::DistanceType::BitwiseHamming: break; default: RAFT_FAIL("Unexpected metric."); } } diff --git a/cpp/src/preprocessing/quantize/binary.cu b/cpp/src/preprocessing/quantize/binary.cu new file mode 100644 index 000000000..fd04c7fe4 --- /dev/null +++ b/cpp/src/preprocessing/quantize/binary.cu @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "./detail/binary.cuh" + +#include + +namespace cuvs::preprocessing::quantize::binary { + +#define CUVS_INST_QUANTIZATION(T, QuantI) \ + void transform(raft::resources const& res, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view out) \ + { \ + detail::transform(res, dataset, out); \ + } \ + void transform(raft::resources const& res, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view out) \ + { \ + detail::transform(res, dataset, out); \ + } + +CUVS_INST_QUANTIZATION(double, uint8_t); +CUVS_INST_QUANTIZATION(float, uint8_t); +CUVS_INST_QUANTIZATION(half, uint8_t); + +#undef CUVS_INST_QUANTIZATION + +} // namespace cuvs::preprocessing::quantize::binary diff --git a/cpp/src/preprocessing/quantize/detail/binary.cuh b/cpp/src/preprocessing/quantize/detail/binary.cuh new file mode 100644 index 000000000..8ddc33c94 --- /dev/null +++ b/cpp/src/preprocessing/quantize/detail/binary.cuh @@ -0,0 +1,183 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cuvs::preprocessing::quantize::detail { + +template +_RAFT_HOST_DEVICE bool is_positive(const T& a) +{ + return a > 0; +} + +template <> +_RAFT_HOST_DEVICE bool is_positive(const half& a) +{ + return is_positive(static_cast(a)); +} + +template +RAFT_KERNEL binary_quantization_kernel(const T* const in_ptr, + const uint32_t ldi, + const size_t dataset_size, + const uint32_t dataset_dim, + pack_t* const out_ptr, + const uint32_t ldo) +{ + constexpr uint32_t warp_size = 32; + const uint32_t bits_per_pack = sizeof(pack_t) * 8; + const auto output_dim = raft::div_rounding_up_safe(dataset_dim, bits_per_pack); + + const auto vector_id = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size; + if (vector_id >= dataset_size) { return; } + + __shared__ pack_t smem[block_size]; + auto local_smem = smem + (threadIdx.x / warp_size) * warp_size; + + const auto lane_id = threadIdx.x % warp_size; + for (uint32_t j_offset = 0; j_offset < dataset_dim; j_offset += warp_size * bits_per_pack) { + // Load dataset vector elements with coalesce access. The mapping of the vector element position + // and the `pack` register is as follows: + // + // lane_id | LSB (pack(u8)) MSB + // 0 | 0, 32, 64, ..., 224 + // 1 | 1, 33, 65, ..., 225 + // ... + // 31 | 31, 63, 95, ..., 255 + pack_t pack = 0; + for (uint32_t bit_offset = 0; bit_offset < bits_per_pack; bit_offset++) { + const auto j = j_offset + lane_id + bit_offset * warp_size; + if (j < dataset_dim) { + const auto v = in_ptr[vector_id * ldi + j]; + if (is_positive(v)) { pack |= (1u << bit_offset); } + } + } + + // Store the local result in smem so that the other threads in the same warp can read + local_smem[lane_id] = pack; + + // Store the result with (a kind of) transposition so that the the coalesce access can be used. + // The mapping of the result `pack` register bit position and (smem_index, bit_position) is as + // follows: + // + // lane_id | LSB (pack(u8)) MSB + // 0 | ( 0,0), ( 1,0), ( 2,0), ..., ( 7,0) + // 1 | ( 8,0), ( 9,0), (10,0), ..., (15,0) + // ... + // 4 | ( 0,1), ( 1,1), ( 2,1), ..., ( 7,1) + // ... + // 31 | (24,7), (25,7), (26,7), ..., (31,7) + // + // The `bit_position`-th bit of `local_smem[smem_index]` is copied to the corresponding `pack` + // bit. By this mapping, the quantization result of 8*i-th ~ (8*(i+1)-1)-th vector elements is + // stored by the lane_id=i thread. + pack = 0; + const auto smem_start_i = (lane_id % (warp_size / bits_per_pack)) * bits_per_pack; + const auto mask = 1u << (lane_id / (warp_size / bits_per_pack)); + for (uint32_t j = 0; j < bits_per_pack; j++) { + if (local_smem[smem_start_i + j] & mask) { pack |= (1u << j); } + } + + const auto out_j = j_offset / bits_per_pack + lane_id; + if (out_j < output_dim) { out_ptr[vector_id * ldo + out_j] = pack; } + } +} + +template +void transform(raft::resources const& res, + raft::device_matrix_view dataset, + raft::device_matrix_view out) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(res); + const uint32_t bits_per_pack = sizeof(QuantI) * 8; + const uint32_t dataset_dim = dataset.extent(1); + const uint32_t out_dim = out.extent(1); + const size_t dataset_size = dataset.extent(0); + const size_t out_dataset_size = out.extent(0); + const uint32_t minimul_out_dim = raft::div_rounding_up_safe(dataset_dim, bits_per_pack); + RAFT_EXPECTS(out_dim >= minimul_out_dim, + "The quantized dataset dimension must be larger or equal to " + "%u but is %u passed", + minimul_out_dim, + out_dim); + RAFT_EXPECTS(out_dataset_size >= dataset_size, + "The quantized dataset size must be larger or equal to " + "the input dataset size (%lu) but is %lu passed", + dataset_size, + out_dataset_size); + + constexpr uint32_t warp_size = 32; + constexpr uint32_t block_size = 256; + constexpr uint32_t vecs_per_cta = block_size / warp_size; + const auto grid_size = + raft::div_rounding_up_safe(dataset_size, static_cast(vecs_per_cta)); + + binary_quantization_kernel + <<>>(dataset.data_handle(), + dataset.stride(0), + dataset_size, + dataset_dim, + out.data_handle(), + out.stride(0)); +} + +template +void transform(raft::resources const& res, + raft::host_matrix_view dataset, + raft::host_matrix_view out) +{ + const uint32_t bits_per_pack = sizeof(QuantI) * 8; + const uint32_t dataset_dim = dataset.extent(1); + const uint32_t out_dim = out.extent(1); + const size_t dataset_size = dataset.extent(0); + const size_t out_dataset_size = out.extent(0); + const uint32_t minimul_out_dim = raft::div_rounding_up_safe(dataset_dim, bits_per_pack); + RAFT_EXPECTS(out_dim >= minimul_out_dim, + "The quantized dataset dimension must be larger or equal to " + "%u but is %u passed", + minimul_out_dim, + out_dim); + RAFT_EXPECTS(out_dataset_size >= dataset_size, + "The quantized dataset size must be larger or equal to " + "the input dataset size (%lu) but is %lu passed", + dataset_size, + out_dataset_size); + +#pragma omp parallel for collapse(2) + for (size_t i = 0; i < dataset_size; ++i) { + for (uint32_t out_j = 0; out_j < minimul_out_dim; ++out_j) { + QuantI pack = 0; + for (uint32_t pack_j = 0; pack_j < bits_per_pack; ++pack_j) { + const uint32_t in_j = out_j * bits_per_pack + pack_j; + if (in_j < dataset_dim) { + if (is_positive(dataset(i, in_j))) { pack |= (1u << pack_j); } + } + } + out(i, out_j) = pack; + } + } +} +} // namespace cuvs::preprocessing::quantize::detail diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 3ed37175d..cea4e0ce6 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -248,7 +248,8 @@ if(BUILD_TESTS) ) ConfigureTest( - NAME PREPROCESSING_TEST PATH preprocessing/scalar_quantization.cu GPUS 1 PERCENT 100 + NAME PREPROCESSING_TEST PATH preprocessing/scalar_quantization.cu + preprocessing/binary_quantization.cu GPUS 1 PERCENT 100 ) ConfigureTest( diff --git a/cpp/tests/neighbors/ann_cagra.cuh b/cpp/tests/neighbors/ann_cagra.cuh index 99d1fd5cc..aedb11543 100644 --- a/cpp/tests/neighbors/ann_cagra.cuh +++ b/cpp/tests/neighbors/ann_cagra.cuh @@ -290,6 +290,7 @@ inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p) switch (dist) { case InnerProduct: return "InnerProduct"; case L2Expanded: return "L2"; + case BitwiseHamming: return "BitwiseHamming"; default: break; } return "Unknown"; @@ -328,6 +329,18 @@ class AnnCagraTest : public ::testing::TestWithParam { protected: void testCagra() { + // IVF_PQ and NN_DESCENT graph builds do not support BitwiseHamming + if (ps.metric == cuvs::distance::DistanceType::BitwiseHamming && + ((!std::is_same_v) || + (ps.build_algo != graph_build_algo::ITERATIVE_CAGRA_SEARCH))) + GTEST_SKIP(); + // If the dataset dimension is small and the dataset size is large, there can be a lot of + // dataset vectors that have the same distance to the query, especially in the binary Hamming + // distance, making it impossible to make a top-k ground truth. + if (ps.metric == cuvs::distance::DistanceType::BitwiseHamming && + (ps.k * ps.dim * 8 / 5 /*(=magic number)*/ < ps.n_rows)) + GTEST_SKIP(); + size_t queries_size = ps.n_queries * ps.k; std::vector indices_Cagra(queries_size); std::vector indices_naive(queries_size); @@ -507,6 +520,17 @@ class AnnCagraAddNodesTest : public ::testing::TestWithParam { // issue: https://github.com/rapidsai/raft/issues/2276 if (ps.metric == InnerProduct && ps.build_algo == graph_build_algo::NN_DESCENT) GTEST_SKIP(); if (ps.compression != std::nullopt) GTEST_SKIP(); + // IVF_PQ and NN_DESCENT graph builds do not support BitwiseHamming + if (ps.metric == cuvs::distance::DistanceType::BitwiseHamming && + ((!std::is_same_v) || + (ps.build_algo != graph_build_algo::ITERATIVE_CAGRA_SEARCH))) + GTEST_SKIP(); + // If the dataset dimension is small and the dataset size is large, there can be a lot of + // dataset vectors that have the same distance to the query, especially in the binary Hamming + // distance, making it impossible to make a top-k ground truth. + if (ps.metric == cuvs::distance::DistanceType::BitwiseHamming && + (ps.k * ps.dim * 8 / 5 /*(=magic number)*/ < ps.n_rows)) + GTEST_SKIP(); size_t queries_size = ps.n_queries * ps.k; std::vector indices_Cagra(queries_size); @@ -706,6 +730,17 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { if (ps.metric == cuvs::distance::DistanceType::InnerProduct && ps.build_algo == graph_build_algo::NN_DESCENT) GTEST_SKIP(); + // IVF_PQ and NN_DESCENT graph builds do not support BitwiseHamming + if (ps.metric == cuvs::distance::DistanceType::BitwiseHamming && + ((!std::is_same_v) || + (ps.build_algo != graph_build_algo::ITERATIVE_CAGRA_SEARCH))) + GTEST_SKIP(); + // If the dataset dimension is small and the dataset size is large, there can be a lot of + // dataset vectors that have the same distance to the query, especially in the binary Hamming + // distance, making it impossible to make a top-k ground truth. + if (ps.metric == cuvs::distance::DistanceType::BitwiseHamming && + (ps.k * ps.dim * 8 / 5 /*(=magic number)*/ < ps.n_rows)) + GTEST_SKIP(); size_t queries_size = ps.n_queries * ps.k; std::vector indices_Cagra(queries_size); @@ -896,7 +931,9 @@ inline std::vector generate_inputs() {1000}, {1, 8, 17}, {16}, // k - {graph_build_algo::NN_DESCENT, graph_build_algo::ITERATIVE_CAGRA_SEARCH}, + {graph_build_algo::NN_DESCENT, + graph_build_algo::ITERATIVE_CAGRA_SEARCH}, // build algo. ITERATIVE_CAGRA_SEARCH is needed to + // test BitwiseHamming {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, {0, 10}, // query size {0}, @@ -939,7 +976,9 @@ inline std::vector generate_inputs() {0}, {64}, {1}, - {cuvs::distance::DistanceType::L2Expanded, cuvs::distance::DistanceType::InnerProduct}, + {cuvs::distance::DistanceType::L2Expanded, + cuvs::distance::DistanceType::InnerProduct, + cuvs::distance::DistanceType::BitwiseHamming}, {false}, {true}, {0.995}); @@ -1053,21 +1092,21 @@ inline std::vector generate_inputs() inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); // Varying dim, adding non_owning_memory_buffer_flag - inputs2 = raft::util::itertools::product( - {100}, - {1000}, - {1, 5, 8, 64, 137, 256, 619, 1024}, // dim - {10}, - {graph_build_algo::IVF_PQ}, - {search_algo::AUTO}, - {10}, - {0}, // team_size - {64}, - {1}, - {cuvs::distance::DistanceType::L2Expanded, cuvs::distance::DistanceType::InnerProduct}, - {false}, - {false}, - {0.995}); + inputs2 = + raft::util::itertools::product({100}, + {1000}, + {1, 5, 8, 64, 137, 256, 619, 1024}, // dim + {10}, + {graph_build_algo::IVF_PQ}, + {search_algo::AUTO}, + {10}, + {0}, // team_size + {64}, + {1}, + {cuvs::distance::DistanceType::L2Expanded}, + {false}, + {false}, + {0.995}); for (auto input : inputs2) { input.non_owning_memory_buffer_flag = true; inputs.push_back(input); @@ -1079,21 +1118,23 @@ inline std::vector generate_inputs() inline std::vector generate_addnode_inputs() { // changing dim - std::vector inputs = raft::util::itertools::product( - {100}, - {1000}, - {1, 8, 17, 64, 128, 137, 512, 1024}, // dim - {16}, // k - {graph_build_algo::NN_DESCENT}, - {search_algo::AUTO}, - {10}, - {0}, - {64}, - {1}, - {cuvs::distance::DistanceType::L2Expanded, cuvs::distance::DistanceType::InnerProduct}, - {false}, - {true}, - {0.995}); + std::vector inputs = + raft::util::itertools::product({100}, + {1000}, + {1, 8, 17, 64, 128, 137, 512, 1024}, // dim + {16}, // k + {graph_build_algo::ITERATIVE_CAGRA_SEARCH}, + {search_algo::AUTO}, + {10}, + {0}, + {64}, + {1}, + {cuvs::distance::DistanceType::L2Expanded, + cuvs::distance::DistanceType::InnerProduct, + cuvs::distance::DistanceType::BitwiseHamming}, + {false}, + {true}, + {0.995}); // testing host and device datasets auto inputs2 = raft::util::itertools::product( @@ -1150,7 +1191,7 @@ inline std::vector generate_filtering_inputs() std::vector inputs = raft::util::itertools::product( {100}, {1000}, - {1, 8, 17}, + {1, 8, 17, 102}, {16}, // k {graph_build_algo::NN_DESCENT}, {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, @@ -1158,7 +1199,9 @@ inline std::vector generate_filtering_inputs() {0}, {256}, {1}, - {cuvs::distance::DistanceType::L2Expanded, cuvs::distance::DistanceType::InnerProduct}, + {cuvs::distance::DistanceType::L2Expanded, + cuvs::distance::DistanceType::InnerProduct, + cuvs::distance::DistanceType::BitwiseHamming}, {false}, {true}, {0.995}); diff --git a/cpp/tests/neighbors/naive_knn.cuh b/cpp/tests/neighbors/naive_knn.cuh index 553e667aa..484022c79 100644 --- a/cpp/tests/neighbors/naive_knn.cuh +++ b/cpp/tests/neighbors/naive_knn.cuh @@ -47,24 +47,29 @@ RAFT_KERNEL naive_distance_kernel(EvalT* dist, for (IdxT i = 0; i < k; ++i) { IdxT xidx = i + midx * k; IdxT yidx = i + nidx * k; - auto xv = EvalT(x[xidx]); - auto yv = EvalT(y[yidx]); + auto xv = x[xidx]; + auto yv = y[yidx]; switch (metric) { case cuvs::distance::DistanceType::InnerProduct: { - acc += xv * yv; + acc += static_cast(xv) * static_cast(yv); } break; case cuvs::distance::DistanceType::CosineExpanded: { - acc += xv * yv; - normX += xv * xv; - normY += yv * yv; + acc += static_cast(xv) * static_cast(yv); + normX += static_cast(xv) * static_cast(xv); + normY += static_cast(yv) * static_cast(yv); } break; case cuvs::distance::DistanceType::L2SqrtExpanded: case cuvs::distance::DistanceType::L2SqrtUnexpanded: case cuvs::distance::DistanceType::L2Expanded: case cuvs::distance::DistanceType::L2Unexpanded: { - auto diff = xv - yv; + auto diff = static_cast(xv) - static_cast(yv); acc += diff * diff; } break; + case cuvs::distance::DistanceType::BitwiseHamming: { + if constexpr (std::is_same_v) { + acc += __popc(static_cast(xv ^ yv) & 0xff); + } + } break; default: break; } } diff --git a/cpp/tests/preprocessing/binary_quantization.cu b/cpp/tests/preprocessing/binary_quantization.cu new file mode 100644 index 000000000..af1cc07e8 --- /dev/null +++ b/cpp/tests/preprocessing/binary_quantization.cu @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include +#include +#include +#include +#include +#include + +namespace cuvs::preprocessing::quantize::binary { + +template +struct BinaryQuantizationInputs { + int rows; + int cols; +}; + +template +std::ostream& operator<<(std::ostream& os, const BinaryQuantizationInputs& inputs) +{ + return os << "> rows:" << inputs.rows << " cols:" << inputs.cols; +} + +template +class BinaryQuantizationTest : public ::testing::TestWithParam> { + public: + BinaryQuantizationTest() + : params_(::testing::TestWithParam>::GetParam()), + stream(raft::resource::get_cuda_stream(handle)), + input_(0, stream) + { + } + + protected: + void testBinaryQuantization() + { + // dataset identical on host / device + auto dataset = raft::make_device_matrix_view( + (const T*)(input_.data()), rows_, cols_); + auto dataset_h = raft::make_host_matrix_view( + (const T*)(host_input_.data()), rows_, cols_); + + { + static_assert(std::is_same_v); + + const auto col_quantized = raft::div_rounding_up_safe(cols_, 8); + auto quantized_input_h = raft::make_host_matrix(rows_, cols_); + auto quantized_input_d = raft::make_device_matrix(handle, rows_, cols_); + cuvs::preprocessing::quantize::binary::transform(handle, dataset, quantized_input_d.view()); + cuvs::preprocessing::quantize::binary::transform(handle, dataset_h, quantized_input_h.view()); + + ASSERT_TRUE(devArrMatchHost(quantized_input_h.data_handle(), + quantized_input_d.data_handle(), + input_.size(), + cuvs::Compare(), + stream)); + } + } + + void SetUp() override + { + rows_ = params_.rows; + cols_ = params_.cols; + + int n_elements = rows_ * cols_; + input_.resize(n_elements, stream); + host_input_.resize(n_elements); + + // random input + unsigned long long int seed = 1234ULL; + raft::random::RngState r(seed); + uniform(handle, r, input_.data(), input_.size(), static_cast(-1), static_cast(1)); + + raft::update_host(host_input_.data(), input_.data(), input_.size(), stream); + + raft::resource::sync_stream(handle, stream); + } + + private: + raft::resources handle; + cudaStream_t stream; + + BinaryQuantizationInputs params_; + int rows_; + int cols_; + rmm::device_uvector input_; + std::vector host_input_; +}; + +template +const std::vector> inputs = { + {5, 5}, + {100, 7}, + {100, 128}, + {100, 1999}, + {1000, 1999}, +}; + +typedef BinaryQuantizationTest QuantizationTest_float_uint8t; +TEST_P(QuantizationTest_float_uint8t, BinaryQuantizationTest) { this->testBinaryQuantization(); } + +typedef BinaryQuantizationTest QuantizationTest_double_uint8t; +TEST_P(QuantizationTest_double_uint8t, BinaryQuantizationTest) { this->testBinaryQuantization(); } + +typedef BinaryQuantizationTest QuantizationTest_half_uint8t; +TEST_P(QuantizationTest_half_uint8t, BinaryQuantizationTest) { this->testBinaryQuantization(); } + +INSTANTIATE_TEST_CASE_P(BinaryQuantizationTest, + QuantizationTest_float_uint8t, + ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(BinaryQuantizationTest, + QuantizationTest_double_uint8t, + ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(BinaryQuantizationTest, + QuantizationTest_half_uint8t, + ::testing::ValuesIn(inputs)); + +} // namespace cuvs::preprocessing::quantize::binary