Improve filtering documentation (#568)
This PR adds a dedicated documentation page for filtering to the `Getting started` tab and adds the `cuvs::neighbors::filtering` namespace to the C++ documentation.

Authors:
  - Micka (https://github.com/lowener)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #568
lowener authored Jan 31, 2025
1 parent c778c88 commit 8eca524
Showing 13 changed files with 259 additions and 23 deletions.
7 changes: 7 additions & 0 deletions cpp/include/cuvs/neighbors/common.hpp
@@ -457,6 +457,11 @@ inline constexpr bool is_vpq_dataset_v = is_vpq_dataset<DatasetT>::value;

namespace filtering {

/**
* @defgroup neighbors_filtering Filtering for ANN Types
* @{
*/

enum class FilterType { None, Bitmap, Bitset };

struct base_filter {
@@ -567,6 +572,8 @@ struct bitset_filter : public base_filter {
void to_csr(raft::resources const& handle, csr_matrix_t& csr);
};

/** @} */ // end group neighbors_filtering

/**
* If the filtering depends on the index of a sample, then the following
* filter template can be used:
2 changes: 1 addition & 1 deletion cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh
@@ -1696,7 +1696,7 @@ auto build(raft::resources const& handle,
"Unsupported data type");

std::cout << "using ivf_pq::index_params nrows " << (int)dataset.extent(0) << ", dim "
<< (int)dataset.extent(1) << ", n_lits " << (int)params.n_lists << ", pq_dim "
<< (int)dataset.extent(1) << ", n_lists " << (int)params.n_lists << ", pq_dim "
<< (int)params.pq_dim << std::endl;
RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset");
RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists");
1 change: 1 addition & 0 deletions docs/source/cpp_api/neighbors.rst
@@ -12,6 +12,7 @@ Nearest Neighbors
neighbors_bruteforce.rst
neighbors_cagra.rst
neighbors_dynamic_batching.rst
neighbors_filter.rst
neighbors_hnsw.rst
neighbors_ivf_flat.rst
neighbors_ivf_pq.rst
2 changes: 1 addition & 1 deletion docs/source/cpp_api/neighbors_bruteforce.rst
@@ -7,7 +7,7 @@ The bruteforce method is running the KNN algorithm. It performs an extensive sea
:language: c++
:class: highlight

``#include <cuvs/neighbors/bruteforce.hpp>``
``#include <cuvs/neighbors/brute_force.hpp>``

namespace *cuvs::neighbors::bruteforce*

18 changes: 18 additions & 0 deletions docs/source/cpp_api/neighbors_filter.rst
@@ -0,0 +1,18 @@
Filtering
==========

All nearest neighbors search methods support filtering. Filtering is a method to reduce the number
of candidates that are considered for the nearest neighbors search.

.. role:: py(code)
:language: c++
:class: highlight

``#include <cuvs/neighbors/common.hpp>``

namespace *cuvs::neighbors::filtering*

.. doxygengroup:: neighbors_filtering
:project: cuvs
:members:
:content-only:
116 changes: 116 additions & 0 deletions docs/source/filtering.rst
@@ -0,0 +1,116 @@
.. _filtering:

~~~~~~~~~~~~~~~~~~~~~~~~
Filtering vector indexes
~~~~~~~~~~~~~~~~~~~~~~~~

cuVS supports different types of filtering depending on the vector index being used. The main method used in all of the vector indexes
is pre-filtering, a technique that takes the filter into account before computing a query's closest neighbors, which saves
the cost of computing distances to filtered vectors.

Bitset
======

A bitset is an array of bits where each bit can take one of two values, ``0`` or ``1``, which in the context of filtering indicate whether
a sample should be filtered or not. ``0`` means that the corresponding vector is filtered out and will therefore not appear in the search results.
This mechanism is optimized to take as little memory space as possible, and is available through the RAFT library
(check out RAFT's `bitset API documentation <https://docs.rapids.ai/api/raft/stable/cpp_api/core_bitset/>`_). When calling a search function of an ANN index, the
bitset length should match the number of vectors present in the database.
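
The bit semantics described above can be illustrated with a small host-only sketch. This is not the RAFT or cuVS bitset type: ``host_bitset`` and ``make_filter`` are hypothetical names introduced here, and the 32-bit word packing is an assumption chosen for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side model of a filtering bitset (not the RAFT/cuVS type).
// Bit i == 1: vector i is kept. Bit i == 0: vector i is filtered out.
struct host_bitset {
  std::size_t n_bits;
  std::vector<std::uint32_t> words;  // 32 bits packed per word (assumption)

  // Start with every bit set to 1, i.e. nothing is filtered.
  explicit host_bitset(std::size_t n) : n_bits(n), words((n + 31) / 32, 0xFFFFFFFFu) {}

  // Clear bit i so that vector i is excluded from search results.
  void remove(std::size_t i)
  {
    assert(i < n_bits);
    words[i / 32] &= ~(std::uint32_t{1} << (i % 32));
  }

  // Returns true when vector i may appear in search results.
  bool test(std::size_t i) const
  {
    assert(i < n_bits);
    return ((words[i / 32] >> (i % 32)) & 1u) != 0;
  }
};

// Build a bitset covering n_vectors entries from a list of indices to drop.
inline host_bitset make_filter(std::size_t n_vectors, const std::vector<std::uint32_t>& removed)
{
  host_bitset bits(n_vectors);
  for (auto idx : removed) { bits.remove(idx); }
  return bits;
}
```

Note that the bitset is sized by the number of vectors in the database, not by the number of removed indices, which mirrors the length requirement stated above.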

Bitmap
======

A bitmap is based on the same principle as a bitset, but in two dimensions. This allows users to provide a different bitset for each query
being searched. Check out RAFT's `bitmap API documentation <https://docs.rapids.ai/api/raft/stable/cpp_api/core_bitmap/>`_.
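
As a rough host-only sketch of the two-dimensional idea, a bitmap can be modeled as one flat bit array indexed by query and vector. ``host_bitmap`` is a hypothetical name, not the RAFT or cuVS type, and the row-major layout (one row of ``n_vectors`` bits per query) is an assumption made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side model of a filtering bitmap (not the RAFT/cuVS type).
// Row q holds the per-vector filter bits for query q; 1 keeps, 0 filters.
struct host_bitmap {
  std::size_t n_queries;
  std::size_t n_vectors;
  std::vector<std::uint32_t> words;  // n_queries * n_vectors bits, packed

  // Start with every bit set to 1, i.e. nothing is filtered for any query.
  host_bitmap(std::size_t q, std::size_t v)
    : n_queries(q), n_vectors(v), words((q * v + 31) / 32, 0xFFFFFFFFu)
  {
  }

  // Flat bit position of (query, vector), assuming a row-major layout.
  std::size_t flat(std::size_t q, std::size_t v) const
  {
    assert(q < n_queries && v < n_vectors);
    return q * n_vectors + v;
  }

  // Exclude vector v from the results of query q only.
  void remove(std::size_t q, std::size_t v)
  {
    const std::size_t i = flat(q, v);
    words[i / 32] &= ~(std::uint32_t{1} << (i % 32));
  }

  // Returns true when vector v may appear in the results of query q.
  bool test(std::size_t q, std::size_t v) const
  {
    const std::size_t i = flat(q, v);
    return ((words[i / 32] >> (i % 32)) & 1u) != 0;
  }
};
```

Removing a vector for one query leaves it visible to every other query, which is the difference from a bitset that applies the same filter to all queries.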

Examples
========

Using a Bitset filter on a CAGRA index
--------------------------------------

.. code-block:: c++

#include <cuvs/neighbors/cagra.hpp>
#include <cuvs/core/bitset.hpp>

using namespace cuvs::neighbors;
cagra::index index;

// ... build index ...

cagra::search_params search_params;
raft::device_resources res;
raft::device_matrix_view<float> queries = load_queries();
raft::device_matrix_view<uint32_t> neighbors = make_device_matrix_view<uint32_t>(n_queries, k);
raft::device_matrix_view<float> distances = make_device_matrix_view<float>(n_queries, k);

// Load a list of all the samples that will get filtered
std::vector<uint32_t> removed_indices_host = get_invalid_indices();
auto removed_indices_device =
raft::make_device_vector<uint32_t, uint32_t>(res, removed_indices_host.size());
// Copy this list to device
raft::copy(removed_indices_device.data_handle(), removed_indices_host.data(),
removed_indices_host.size(), raft::resource::get_cuda_stream(res));

// Create a bitset with the list of samples to filter.
cuvs::core::bitset<uint32_t, uint32_t> removed_indices_bitset(
res, removed_indices_device.view(), index.size());
// Use a `bitset_filter` in the `cagra::search` function call.
auto bitset_filter =
cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset.view());
cagra::search(res,
search_params,
index,
queries,
neighbors,
distances,
bitset_filter);


Using a Bitmap filter on a Brute-force index
--------------------------------------------

.. code-block:: c++

#include <cuvs/neighbors/brute_force.hpp>
#include <cuvs/core/bitmap.hpp>

using namespace cuvs::neighbors;
using indexing_dtype = int64_t;

// ... build index ...
brute_force::index_params index_params;
brute_force::search_params search_params;
raft::device_resources res;
raft::device_matrix<float, indexing_dtype> dataset = load_dataset(n_vectors, dim);
raft::device_matrix<float, indexing_dtype> queries = load_queries(n_queries, dim);
auto index = brute_force::build(res, index_params, raft::make_const_mdspan(dataset.view()));

// Load a list of all the samples that will get filtered
std::vector<uint32_t> removed_indices_host = get_invalid_indices();
auto removed_indices_device =
raft::make_device_vector<uint32_t, uint32_t>(res, removed_indices_host.size());
// Copy this list to device
raft::copy(removed_indices_device.data_handle(), removed_indices_host.data(),
removed_indices_host.size(), raft::resource::get_cuda_stream(res));

// Create a bitmap with the list of samples to filter.
cuvs::core::bitset<uint32_t, indexing_dtype> removed_indices_bitset(
res, removed_indices_device.view(), n_queries * n_vectors);
cuvs::core::bitmap_view<const uint32_t, indexing_dtype> removed_indices_bitmap(
removed_indices_bitset.data(), n_queries, n_vectors);

// Use a `bitmap_filter` in the `brute_force::search` function call.
auto bitmap_filter =
cuvs::neighbors::filtering::bitmap_filter(removed_indices_bitmap);

auto neighbors = raft::make_device_matrix<uint32_t, indexing_dtype>(res, n_queries, k);
auto distances = raft::make_device_matrix<float, indexing_dtype>(res, n_queries, k);
brute_force::search(res,
search_params,
index,
raft::make_const_mdspan(queries.view()),
neighbors.view(),
distances.view(),
bitmap_filter);
2 changes: 2 additions & 0 deletions docs/source/getting_started.rst
@@ -118,3 +118,5 @@ We always welcome patches for new features and bug fixes. Please read our `contr
indexes/indexes.rst
api_basics.rst
api_interoperability.rst
working_with_ann_indexes.rst
filtering.rst
6 changes: 3 additions & 3 deletions docs/source/indexes/bruteforce.rst
@@ -12,7 +12,7 @@ Brute-force can also be a good choice for heavily filtered queries where other a
when filtering out 90%-95% of the vectors from a search, the IVF methods could struggle to return anything at all with smaller number of probes and
graph-based algorithms with limited hash table memory could end up skipping over important unfiltered entries.

[ :doc:`C API <../c_api/neighbors_bruteforce_c>` | :doc:`C++ API <../cpp_api/neighbors_bruteforce>` | :doc:`Python API <../python_api/neighbors_bruteforce>` | :doc:`Rust API <../rust_api/index>` ]
[ :doc:`C API <../c_api/neighbors_bruteforce_c>` | :doc:`C++ API <../cpp_api/neighbors_bruteforce>` | :doc:`Python API <../python_api/neighbors_brute_force>` | :doc:`Rust API <../rust_api/index>` ]

Filtering considerations
------------------------
@@ -57,6 +57,6 @@ Memory footprint
Index footprint
~~~~~~~~~~~~~~~

Raw vectors: :math:`n_vectors * n_dimensions * precision`
Raw vectors: :math:`n\_vectors * n\_dimensions * precision`

Vector norms (for distances which require them): :math:`n_vectors * precision`
Vector norms (for distances which require them): :math:`n\_vectors * precision`
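
The two brute-force footprint formulas above can be evaluated with a short sketch. The function names are hypothetical, and ``precision`` is assumed to be the size of one element in bytes (for example 4 for float32).

```cpp
#include <cassert>
#include <cstdint>

// Raw vectors: n_vectors * n_dimensions * precision (bytes per element).
inline std::uint64_t raw_vectors_bytes(std::uint64_t n_vectors,
                                       std::uint64_t n_dimensions,
                                       std::uint64_t precision)
{
  assert(precision > 0);
  return n_vectors * n_dimensions * precision;
}

// Vector norms (only for distances that require them): n_vectors * precision.
inline std::uint64_t vector_norms_bytes(std::uint64_t n_vectors, std::uint64_t precision)
{
  assert(precision > 0);
  return n_vectors * precision;
}
```

With illustrative inputs of 1,000,000 float32 vectors and 128 dimensions, the raw vectors take 512,000,000 bytes and the norms add 4,000,000 bytes.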
24 changes: 14 additions & 10 deletions docs/source/indexes/cagra.rst
@@ -108,22 +108,22 @@ IVFPQ or NN-DESCENT can be used to build the graph (additions to the peak memory
Dataset on device (graph on host):
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Index memory footprint (device): :math:`n_index_vectors * n_dims * sizeof(T)`
Index memory footprint (device): :math:`n\_index\_vectors * n\_dims * sizeof(T)`

Index memory footprint (host): :math:`graph_degree * n_index_vectors * sizeof(T)``
Index memory footprint (host): :math:`graph\_degree * n\_index\_vectors * sizeof(T)``

Dataset on host (graph on host):
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Index memory footprint (host): :math:`n_index_vectors * n_dims * sizeof(T) + graph_degree * n_index_vectors * sizeof(T)`
Index memory footprint (host): :math:`n\_index\_vectors * n\_dims * sizeof(T) + graph\_degree * n\_index\_vectors * sizeof(T)`

Build peak memory usage:
~~~~~~~~~~~~~~~~~~~~~~~~

When built using NN-descent / IVF-PQ, the build process consists of two phases: (1) building an initial/(intermediate) graph and then (2) optimizing the graph. Key input parameters are n_vectors, intermediate_graph_degree, graph_degree.
The memory usage in the first phase (building) depends on the chosen method. The biggest allocation is the graph (n_vectors*intermediate_graph_degree), but it’s stored in the host memory.
Usually, the second phase (optimize) uses the most device memory. The peak memory usage is achieved during the pruning step (graph_core.cuh/optimize)
Optimize: formula for peak memory usage (device): :math:`n_vectors * (4 + (sizeof(IdxT) + 1) * intermediate_degree)``
Optimize: formula for peak memory usage (device): :math:`n\_vectors * (4 + (sizeof(IdxT) + 1) * intermediate_degree)``

Build with out-of-core IVF-PQ peak memory usage:
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -134,14 +134,18 @@ IVF-PQ Build:

.. math::
n_vectors / train_set_ratio * dim * sizeof(float) // trainset, may be in managed mem
+ n_vectors / train_set_ratio * sizeof(uint32_t) // labels, may be in managed mem
+ n_clusters * n_dim * sizeof(float) // cluster centers
n\_vectors / train\_set\_ratio * dim * sizeof_{float} // trainset, may be in managed mem
+ n\_vectors / train\_set\_ratio * sizeof(uint32_t) // labels, may be in managed mem
+ n\_clusters * n\_dim * sizeof_{float} // cluster centers
IVF-PQ Search (max batch size 1024 vectors on device at a time):

.. math::
[n_vectors * (pq_dim * pq_bits / 8 + sizeof(int64_t)) + O(n_clusters)]
+ [batch_size * n_dim * sizeof(float)] + [batch_size * intermediate_degree * sizeof(uint32_t)] +
[batch_size * intermediate_degree * sizeof(float)]
[n\_vectors * (pq\_dim * pq\_bits / 8 + sizeof_{int64\_t}) + O(n\_clusters)]
+ [batch\_size * n\_dim * sizeof_{float}] + [batch\_size * intermediate\_degree * sizeof_{uint32\_t}]
+ [batch\_size * intermediate\_degree * sizeof_{float}]
4 changes: 2 additions & 2 deletions docs/source/indexes/ivfflat.rst
@@ -86,7 +86,7 @@ Memory footprint
----------------

Each cluster is padded to at least 32 vectors (but potentially up to 1024). Assuming uniform random distribution of vectors/list, we would have
:math:`cluster\_overhead = (conservative\_memory\_allocation ? 16 : 512 ) * dim * sizeof_{float})`
:math:`cluster\_overhead = (conservative\_memory\_allocation ? 16 : 512 ) * dim * sizeof_{float}`

Note that each cluster is allocated as a separate allocation. If we use a `cuda_memory_resource`, that would grab memory in 1 MiB chunks, so on average we might have 0.5 MiB overhead per cluster. If we use tens of thousands of clusters, it becomes essential to use a pool allocator to avoid this overhead.

@@ -110,6 +110,6 @@ Index (device memory):
Peak device memory usage for index build:
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

:math:`workspace = min(1GB, n\_queries * [(n\_lists + 1 + n\_probes * (k + 1)) * sizeof_{float}) + n\_probes * k * sizeof_{idx}])`
:math:`workspace = min(1GB, n\_queries * [(n\_lists + 1 + n\_probes * (k + 1)) * sizeof_{float} + n\_probes * k * sizeof_{idx}])`

:math:`index\_size + workspace`
12 changes: 6 additions & 6 deletions docs/source/indexes/ivfpq.rst
@@ -97,22 +97,22 @@ Simple approximate formula: :math:`n\_vectors * (pq\_dim * \frac{pq\_bits}{8} +

The IVF lists end up being represented by a sparse data structure that stores the pointers to each list, an indices array that contains the indexes of each vector in each list, and an array with the encoded (and interleaved) data for each list.

IVF list pointers: :math:`n\_clusters * sizeof_{uint32_t}`
IVF list pointers: :math:`n\_clusters * sizeof_{uint32\_t}`

Indices: :math:`n\_vectors * sizeof_{idx}``
Indices: :math:`n\_vectors * sizeof_{idx}`

Encoded data (interleaved): :math:`n\_vectors * pq\_dim * \frac{pq\_bits}{8}`

Per subspace method: :math:`4 * pq\_dim * pq\_len * 2^pq\_bits`
Per subspace method: :math:`4 * pq\_dim * pq\_len * 2^{pq\_bits}`

Per cluster method: :math:`4 * n\_clusters * pq\_len * 2^pq\_bits`
Per cluster method: :math:`4 * n\_clusters * pq\_len * 2^{pq\_bits}`

Extras: :math:`n\_clusters * (20 + 8 * dim)`

Index (host memory):
~~~~~~~~~~~~~~~~~~~~

When refinement is used with the dataset on host, the original raw vectors are needed: :math:`n\_vectors * dims * sizeof_{Tloat}`
When refinement is used with the dataset on host, the original raw vectors are needed: :math:`n\_vectors * dims * sizeof_{float}`

Search peak memory usage (device);
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -128,7 +128,7 @@ Build peak memory usage (device):
\frac{n\_vectors}{trainset\_ratio * dims * sizeof_{float}}
+ \frac{n\_vectors}{trainset\_ratio * sizeof_{uint32_t}}
+ \frac{n\_vectors}{trainset\_ratio * sizeof_{uint32\_t}}
+ n\_clusters * dim * sizeof_{float}
4 changes: 4 additions & 0 deletions examples/cpp/CMakeLists.txt
@@ -36,6 +36,7 @@ set(BUILD_CUVS_C_LIBRARY OFF)
include(../cmake/thirdparty/get_cuvs.cmake)

# -------------- compile tasks ----------------- #
add_executable(BRUTE_FORCE_EXAMPLE src/brute_force_bitmap.cu)
add_executable(CAGRA_EXAMPLE src/cagra_example.cu)
add_executable(CAGRA_PERSISTENT_EXAMPLE src/cagra_persistent_example.cu)
add_executable(DYNAMIC_BATCHING_EXAMPLE src/dynamic_batching_example.cu)
@@ -48,6 +49,9 @@ add_executable(VAMANA_EXAMPLE src/vamana_example.cu)
add_library(rmm_logger OBJECT)
target_link_libraries(rmm_logger PRIVATE rmm::rmm_logger_impl)

target_link_libraries(
BRUTE_FORCE_EXAMPLE PRIVATE cuvs::cuvs $<TARGET_NAME_IF_EXISTS:conda_env> rmm_logger
)
target_link_libraries(
CAGRA_EXAMPLE PRIVATE cuvs::cuvs $<TARGET_NAME_IF_EXISTS:conda_env> rmm_logger
)