Skip to content

Commit

Permalink
Change brute_force api to match ivf*/cagra (#536)
Browse files Browse the repository at this point in the history
This changes the brute_force knn API to match that of ivf-* and cagra by adding search_params and index_params structures to the relevant calls.

This allows us to use the dynamic batching code on brute_force knn, as well as provide a more standardized API for our users.

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #536
  • Loading branch information
benfred authored Jan 8, 2025
1 parent e324412 commit 2a10353
Show file tree
Hide file tree
Showing 7 changed files with 259 additions and 89 deletions.
139 changes: 105 additions & 34 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#pragma once

#include "common.hpp"
#include <cuvs/neighbors/common.hpp>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdarray.hpp>
Expand All @@ -28,6 +27,10 @@

namespace cuvs::neighbors::brute_force {

/**
 * @brief Parameters used when building a brute-force index.
 *
 * Derives from cuvs::neighbors::index_params and declares no additional members of its own.
 */
struct index_params : cuvs::neighbors::index_params {};

/**
 * @brief Parameters used when searching a brute-force index.
 *
 * Derives from cuvs::neighbors::search_params and declares no additional members of its own.
 */
struct search_params : cuvs::neighbors::search_params {};

/**
* @defgroup bruteforce_cpp_index Bruteforce index
* @{
Expand All @@ -41,6 +44,11 @@ namespace cuvs::neighbors::brute_force {
*/
template <typename T, typename DistT = T>
struct index : cuvs::neighbors::index {
using index_params_type = brute_force::index_params;
using search_params_type = brute_force::search_params;
using index_type = int64_t;
using value_type = T;

public:
index(const index&) = delete;
index(index&&) = default;
Expand Down Expand Up @@ -181,83 +189,105 @@ struct index : cuvs::neighbors::index {
* @code{.cpp}
* using namespace cuvs::neighbors;
* // create and fill the index from a [N, D] dataset
* auto index = brute_force::build(handle, dataset, metric);
* brute_force::index_params index_params;
* auto index = brute_force::build(handle, index_params, dataset);
* @endcode
*
* @param[in] handle
* @param[in] index_params parameters such as the distance metric to use
* @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
* @param[in] metric cuvs::distance::DistanceType
* @param[in] metric_arg metric argument
*
* @return the constructed brute-force index
*/
auto build(raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<float, float>;
const cuvs::neighbors::brute_force::index_params& index_params,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset)
-> cuvs::neighbors::brute_force::index<float, float>;

/**
 * @brief Build a brute-force index from a row-major float dataset, passing the metric directly.
 *
 * @deprecated Prefer the `build` overload that takes a `brute_force::index_params` argument.
 *
 * @param[in] handle
 * @param[in] dataset a device matrix view of a row-major matrix [n_rows, dim]
 * @param[in] metric cuvs::distance::DistanceType (defaults to L2Unexpanded)
 * @param[in] metric_arg metric argument (defaults to 0)
 *
 * @return the constructed brute-force index
 */
[[deprecated("use the build() overload taking brute_force::index_params instead")]] auto build(
  raft::resources const& handle,
  raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
  cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
  float metric_arg                    = 0) -> cuvs::neighbors::brute_force::index<float, float>;
/**
* @brief Build the index from the dataset for efficient search.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
* // create and fill the index from a [N, D] dataset
* auto index = brute_force::build(handle, dataset, metric);
* brute_force::index_params index_params;
* auto index = brute_force::build(handle, index_params, dataset);
* @endcode
*
* @param[in] handle
* @param[in] index_params parameters such as the distance metric to use
* @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
* @param[in] metric cuvs::distance::DistanceType
* @param[in] metric_arg metric argument
*
* @return the constructed ivf-flat index
* @return the constructed brute force index
*/
auto build(raft::resources const& handle,
raft::device_matrix_view<const half, int64_t, raft::row_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<half, float>;
const cuvs::neighbors::brute_force::index_params& index_params,
raft::device_matrix_view<const half, int64_t, raft::row_major> dataset)
-> cuvs::neighbors::brute_force::index<half, float>;

/**
 * @brief Build a brute-force index from a row-major half dataset, passing the metric directly.
 *
 * @deprecated Prefer the `build` overload that takes a `brute_force::index_params` argument.
 *
 * @param[in] handle
 * @param[in] dataset a device matrix view of a row-major matrix [n_rows, dim]
 * @param[in] metric cuvs::distance::DistanceType (defaults to L2Unexpanded)
 * @param[in] metric_arg metric argument (defaults to 0)
 *
 * @return the constructed brute-force index
 */
[[deprecated("use the build() overload taking brute_force::index_params instead")]] auto build(
  raft::resources const& handle,
  raft::device_matrix_view<const half, int64_t, raft::row_major> dataset,
  cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
  float metric_arg                    = 0) -> cuvs::neighbors::brute_force::index<half, float>;

/**
* @brief Build the index from the dataset for efficient search.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
* // create and fill the index from a [N, D] dataset
* auto index = brute_force::build(handle, dataset, metric);
* brute_force::index_params index_params;
* auto index = brute_force::build(handle, index_params, dataset);
* @endcode
*
* @param[in] handle
* @param[in] dataset a device pointer to a col-major matrix [n_rows, dim]
* @param[in] metric cuvs::distance::DistanceType
* @param[in] metric_arg metric argument
* @param[in] index_params parameters such as the distance metric to use
* @param[in] dataset a device pointer to a col-major matrix [n_rows, dim]
*
* @return the constructed bruteforce index
* @return the constructed brute force index
*/
auto build(raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::col_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<float, float>;
const cuvs::neighbors::brute_force::index_params& index_params,
raft::device_matrix_view<const float, int64_t, raft::col_major> dataset)
-> cuvs::neighbors::brute_force::index<float, float>;

/**
 * @brief Build a brute-force index from a col-major float dataset, passing the metric directly.
 *
 * @deprecated Prefer the `build` overload that takes a `brute_force::index_params` argument.
 *
 * @param[in] handle
 * @param[in] dataset a device matrix view of a col-major matrix [n_rows, dim]
 * @param[in] metric cuvs::distance::DistanceType (defaults to L2Unexpanded)
 * @param[in] metric_arg metric argument (defaults to 0)
 *
 * @return the constructed brute-force index
 */
[[deprecated("use the build() overload taking brute_force::index_params instead")]] auto build(
  raft::resources const& handle,
  raft::device_matrix_view<const float, int64_t, raft::col_major> dataset,
  cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
  float metric_arg                    = 0) -> cuvs::neighbors::brute_force::index<float, float>;

/**
* @brief Build the index from the dataset for efficient search.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
* // create and fill the index from a [N, D] dataset
* auto index = brute_force::build(handle, dataset, metric);
* brute_force::index_params index_params;
* auto index = brute_force::build(handle, index_params, dataset);
* @endcode
*
* @param[in] handle
* @param[in] dataset a device pointer to a col-major matrix [n_rows, dim]
* @param[in] metric cuvs::distance::DistanceType
* @param[in] metric_arg metric argument
* @param[in] index_params parameters such as the distance metric to use
* @param[in] dataset a device pointer to a col-major matrix [n_rows, dim]
*
* @return the constructed bruteforce index
* @return the constructed brute force index
*/
auto build(raft::resources const& handle,
raft::device_matrix_view<const half, int64_t, raft::col_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<half, float>;
const cuvs::neighbors::brute_force::index_params& index_params,
raft::device_matrix_view<const half, int64_t, raft::col_major> dataset)
-> cuvs::neighbors::brute_force::index<half, float>;

/**
 * @brief Build a brute-force index from a col-major half dataset, passing the metric directly.
 *
 * @deprecated Prefer the `build` overload that takes a `brute_force::index_params` argument.
 *
 * @param[in] handle
 * @param[in] dataset a device matrix view of a col-major matrix [n_rows, dim]
 * @param[in] metric cuvs::distance::DistanceType (defaults to L2Unexpanded)
 * @param[in] metric_arg metric argument (defaults to 0)
 *
 * @return the constructed brute-force index
 */
[[deprecated("use the build() overload taking brute_force::index_params instead")]] auto build(
  raft::resources const& handle,
  raft::device_matrix_view<const half, int64_t, raft::col_major> dataset,
  cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
  float metric_arg                    = 0) -> cuvs::neighbors::brute_force::index<half, float>;
/**
* @}
*/
Expand Down Expand Up @@ -286,6 +316,7 @@ auto build(raft::resources const& handle,
* @endcode
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index brute-force constructed index
* @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
Expand All @@ -296,13 +327,22 @@ auto build(raft::resources const& handle,
* `index->size()` bits to indicate whether queries[0] should compute the distance with dataset.
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
const cuvs::neighbors::brute_force::index<float, float>& index,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

/**
 * @brief Search a brute-force index with row-major float queries, without search parameters.
 *
 * @deprecated Prefer the `search` overload that takes a `brute_force::search_params` argument.
 *
 * @param[in] handle
 * @param[in] index brute-force constructed index
 * @param[in] queries a device matrix view of a row-major matrix [n_queries, index->dim()]
 * @param[out] neighbors a device matrix view receiving the indices of the neighbors in the
 * source dataset
 * @param[out] distances a device matrix view receiving the distances to the selected neighbors
 * @param[in] sample_filter optional filter; defaults to `none_sample_filter`
 */
[[deprecated("use the search() overload taking brute_force::search_params instead")]] void search(
  raft::resources const& handle,
  const cuvs::neighbors::brute_force::index<float, float>& index,
  raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
  raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
  raft::device_matrix_view<float, int64_t, raft::row_major> distances,
  const cuvs::neighbors::filtering::base_filter& sample_filter =
    cuvs::neighbors::filtering::none_sample_filter{});

/**
* @brief Search ANN using the constructed index.
*
Expand All @@ -323,6 +363,7 @@ void search(raft::resources const& handle,
* @endcode
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index brute-force constructed index
* @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
Expand All @@ -332,18 +373,28 @@ void search(raft::resources const& handle,
* given
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
const cuvs::neighbors::brute_force::index<half, float>& index,
raft::device_matrix_view<const half, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

/**
 * @brief Search a brute-force index with row-major half queries, without search parameters.
 *
 * @deprecated Prefer the `search` overload that takes a `brute_force::search_params` argument.
 *
 * @param[in] handle
 * @param[in] index brute-force constructed index
 * @param[in] queries a device matrix view of a row-major matrix [n_queries, index->dim()]
 * @param[out] neighbors a device matrix view receiving the indices of the neighbors in the
 * source dataset
 * @param[out] distances a device matrix view receiving the distances to the selected neighbors
 * @param[in] sample_filter optional filter; defaults to `none_sample_filter`
 */
[[deprecated("use the search() overload taking brute_force::search_params instead")]] void search(
  raft::resources const& handle,
  const cuvs::neighbors::brute_force::index<half, float>& index,
  raft::device_matrix_view<const half, int64_t, raft::row_major> queries,
  raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
  raft::device_matrix_view<float, int64_t, raft::row_major> distances,
  const cuvs::neighbors::filtering::base_filter& sample_filter =
    cuvs::neighbors::filtering::none_sample_filter{});
/**
* @brief Search ANN using the constructed index.
*
* See the [brute_force::build](#brute_force::build) documentation for a usage example.
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index bruteforce constructed index
* @param[in] queries a device pointer to a col-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
Expand All @@ -353,18 +404,28 @@ void search(raft::resources const& handle,
* given query
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
const cuvs::neighbors::brute_force::index<float, float>& index,
raft::device_matrix_view<const float, int64_t, raft::col_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

/**
 * @brief Search a brute-force index with col-major float queries, without search parameters.
 *
 * @deprecated Prefer the `search` overload that takes a `brute_force::search_params` argument.
 *
 * @param[in] handle
 * @param[in] index brute-force constructed index
 * @param[in] queries a device matrix view of a col-major matrix [n_queries, index->dim()]
 * @param[out] neighbors a device matrix view receiving the indices of the neighbors in the
 * source dataset
 * @param[out] distances a device matrix view receiving the distances to the selected neighbors
 * @param[in] sample_filter optional filter; defaults to `none_sample_filter`
 */
[[deprecated("use the search() overload taking brute_force::search_params instead")]] void search(
  raft::resources const& handle,
  const cuvs::neighbors::brute_force::index<float, float>& index,
  raft::device_matrix_view<const float, int64_t, raft::col_major> queries,
  raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
  raft::device_matrix_view<float, int64_t, raft::row_major> distances,
  const cuvs::neighbors::filtering::base_filter& sample_filter =
    cuvs::neighbors::filtering::none_sample_filter{});
/**
* @brief Search ANN using the constructed index.
*
* See the [brute_force::build](#brute_force::build) documentation for a usage example.
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index bruteforce constructed index
* @param[in] queries a device pointer to a col-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
Expand All @@ -374,12 +435,21 @@ void search(raft::resources const& handle,
* given query
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::search_params& params,
const cuvs::neighbors::brute_force::index<half, float>& index,
raft::device_matrix_view<const half, int64_t, raft::col_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

/**
 * @brief Search a brute-force index with col-major half queries, without search parameters.
 *
 * @deprecated Prefer the `search` overload that takes a `brute_force::search_params` argument.
 *
 * @param[in] handle
 * @param[in] index brute-force constructed index
 * @param[in] queries a device matrix view of a col-major matrix [n_queries, index->dim()]
 * @param[out] neighbors a device matrix view receiving the indices of the neighbors in the
 * source dataset
 * @param[out] distances a device matrix view receiving the distances to the selected neighbors
 * @param[in] sample_filter optional filter; defaults to `none_sample_filter`
 */
[[deprecated("use the search() overload taking brute_force::search_params instead")]] void search(
  raft::resources const& handle,
  const cuvs::neighbors::brute_force::index<half, float>& index,
  raft::device_matrix_view<const half, int64_t, raft::col_major> queries,
  raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
  raft::device_matrix_view<float, int64_t, raft::row_major> distances,
  const cuvs::neighbors::filtering::base_filter& sample_filter =
    cuvs::neighbors::filtering::none_sample_filter{});
/**
* @}
*/
Expand Down Expand Up @@ -472,6 +542,7 @@ struct sparse_search_params {
* @brief Search the sparse bruteforce index for nearest neighbors
*
* @param[in] handle
* @param[in] params parameters configuring the search
* @param[in] index Sparse brute-force constructed index
* @param[in] queries a sparse CSR matrix on the device to query
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
Expand Down
Loading

0 comments on commit 2a10353

Please sign in to comment.