Skip to content

Commit

Permalink
Merge branch 'branch-25.02' into java_mh_refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisHegarty authored Feb 5, 2025
2 parents 62583b3 + ddf1c8f commit 2b3f19a
Show file tree
Hide file tree
Showing 15 changed files with 771 additions and 6 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,7 @@ target_compile_definitions(cuvs::cuvs INTERFACE $<$<BOOL:${CUVS_NVTX}>:NVTX_ENAB
src/neighbors/ivf_pq_c.cpp
src/neighbors/cagra_c.cpp
$<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:src/neighbors/hnsw_c.cpp>
src/neighbors/nn_descent_c.cpp
src/neighbors/refine/refine_c.cpp
src/preprocessing/quantize/scalar_c.cpp
src/distance/pairwise_distance_c.cpp
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ void dispatch_benchmark(std::string cmdline,
}
}
std::swap(more_indices, indices); // update the config in case algorithms need to access it
register_build<T>(dataset, more_indices, force_overwrite, no_lap_sync);
register_build<T>(dataset, indices, force_overwrite, no_lap_sync);
} else if (search_mode) {
if (file_exists(query_file)) {
log_info("Using the query file '%s'", query_file.c_str());
Expand Down
4 changes: 2 additions & 2 deletions cpp/cmake/patches/faiss_override.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
{
"packages" : {
"faiss" : {
"version": "1.7.4",
"version": "1.10.0",
"git_url": "https://github.com/facebookresearch/faiss.git",
"git_tag": "main"
"git_tag": "v1.10.0"
}
}
}
181 changes: 181 additions & 0 deletions cpp/include/cuvs/neighbors/nn_descent.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cuvs/core/c_api.h>
#include <cuvs/distance/distance.h>
#include <dlpack/dlpack.h>
#include <stdbool.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

/**
* @defgroup nn_descent_c_index_params The nn-descent algorithm parameters.
* @{
*/
/**
* @brief Parameters used to build an nn-descent index
*
* Every field is forwarded unchanged to the C++
* `cuvs::neighbors::nn_descent::index_params` when the index is built.
*
* `metric`: The distance metric to use
* `metric_arg`: The argument used by distance metrics like Minkowski distance
* `graph_degree`: For an input dataset of dimensions (N, D),
* determines the final dimensions of the all-neighbors knn graph
* which turns out to be of dimensions (N, graph_degree)
* `intermediate_graph_degree`: Internally, nn-descent builds an
* all-neighbors knn graph of dimensions (N, intermediate_graph_degree)
* before selecting the final `graph_degree` neighbors. It's recommended
* that `intermediate_graph_degree` >= 1.5 * graph_degree
* `max_iterations`: The number of iterations that nn-descent will refine
* the graph for. More iterations produce a better quality graph at cost of performance
* `termination_threshold`: The delta at which nn-descent will terminate its iterations
* `return_distances`: Forwarded to `index_params::return_distances`
* (presumably controls whether neighbor distances are kept alongside the
* graph -- TODO confirm against the C++ documentation)
* `n_clusters`: Forwarded to `index_params::n_clusters`
* (presumably the number of partitions used for batched builds -- TODO confirm)
*/
struct cuvsNNDescentIndexParams {
cuvsDistanceType metric;
float metric_arg;
size_t graph_degree;
size_t intermediate_graph_degree;
size_t max_iterations;
float termination_threshold;
bool return_distances;
size_t n_clusters;
};

/* Handle type used by the C API: a pointer to a heap-allocated parameter struct. */
typedef struct cuvsNNDescentIndexParams* cuvsNNDescentIndexParams_t;

/**
* @brief Allocate NN-Descent Index params, and populate with default values
*
* The defaults are copied from a default-constructed C++
* `cuvs::neighbors::nn_descent::index_params`.
*
* @param[out] index_params receives the newly allocated cuvsNNDescentIndexParams_t;
* release it with cuvsNNDescentIndexParamsDestroy
* @return cuvsError_t
*/
cuvsError_t cuvsNNDescentIndexParamsCreate(cuvsNNDescentIndexParams_t* index_params);

/**
* @brief De-allocate NN-Descent Index params
*
* @param[in] index_params handle previously returned by cuvsNNDescentIndexParamsCreate
* @return cuvsError_t
*/
cuvsError_t cuvsNNDescentIndexParamsDestroy(cuvsNNDescentIndexParams_t index_params);
/**
* @}
*/

/**
* @defgroup nn_descent_c_index NN-Descent index
* @{
*/
/**
* @brief Struct to hold address of cuvs::neighbors::nn_descent::index and its active trained dtype
*
* `addr`: Address of the heap-allocated C++ index, stored as an integer.
* It is zero until cuvsNNDescentBuild succeeds.
* `dtype`: Type tag used to recover the C++ index type from `addr`.
* cuvsNNDescentBuild currently always sets it to `kDLUInt` with 32 bits,
* i.e. the graph's index type rather than the dataset's element type.
*/
typedef struct {
uintptr_t addr;
DLDataType dtype;
} cuvsNNDescentIndex;

/* Handle type used by the C API: a pointer to a heap-allocated cuvsNNDescentIndex. */
typedef cuvsNNDescentIndex* cuvsNNDescentIndex_t;

/**
* @brief Allocate NN-Descent index
*
* Only the zero-initialized handle struct is allocated here; the index itself
* is created by cuvsNNDescentBuild.
*
* @param[out] index receives the newly allocated cuvsNNDescentIndex_t
* @return cuvsError_t
*/
cuvsError_t cuvsNNDescentIndexCreate(cuvsNNDescentIndex_t* index);

/**
* @brief De-allocate NN-Descent index
*
* Frees both the underlying C++ index and the handle struct itself.
*
* @param[in] index cuvsNNDescentIndex_t to de-allocate
* @return cuvsError_t
*/
cuvsError_t cuvsNNDescentIndexDestroy(cuvsNNDescentIndex_t index);
/**
* @}
*/

/**
* @defgroup nn_descent_c_index_build NN-Descent index build
* @{
*/
/**
* @brief Build a NN-Descent index with a `DLManagedTensor` which has underlying
* `DLDeviceType` equal to `kDLCUDA`, `kDLCUDAHost`, `kDLCUDAManaged`,
* or `kDLCPU`. Also, acceptable underlying types are:
* 1. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
* 2. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 16`
* 3. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8`
* 4. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8`
*
* @code {.c}
* #include <cuvs/core/c_api.h>
* #include <cuvs/neighbors/nn_descent.h>
*
* // Create cuvsResources_t
* cuvsResources_t res;
* cuvsError_t res_create_status = cuvsResourcesCreate(&res);
*
* // Assume a populated `DLManagedTensor` type here
* DLManagedTensor dataset;
*
* // Create default index params
* cuvsNNDescentIndexParams_t index_params;
* cuvsError_t params_create_status = cuvsNNDescentIndexParamsCreate(&index_params);
*
* // Create NN-Descent index
* cuvsNNDescentIndex_t index;
* cuvsError_t index_create_status = cuvsNNDescentIndexCreate(&index);
*
* // Build the NN-Descent Index (NULL: no preallocated output graph is supplied)
* cuvsError_t build_status = cuvsNNDescentBuild(res, index_params, &dataset, NULL, index);
*
* // de-allocate `index_params`, `index` and `res`
* cuvsError_t params_destroy_status = cuvsNNDescentIndexParamsDestroy(index_params);
* cuvsError_t index_destroy_status = cuvsNNDescentIndexDestroy(index);
* cuvsError_t res_destroy_status = cuvsResourcesDestroy(res);
* @endcode
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] index_params cuvsNNDescentIndexParams_t used to build NN-Descent index
* @param[in] dataset DLManagedTensor* training dataset on host or device memory
* @param[inout] graph Optional preallocated graph on host memory to store output;
* may be NULL
* @param[out] index cuvsNNDescentIndex_t Newly built NN-Descent index; the handle
* must already have been allocated with cuvsNNDescentIndexCreate
* @return cuvsError_t
*/
cuvsError_t cuvsNNDescentBuild(cuvsResources_t res,
cuvsNNDescentIndexParams_t index_params,
DLManagedTensor* dataset,
DLManagedTensor* graph,
cuvsNNDescentIndex_t index);
/**
* @}
*/

/**
* @brief Get the KNN graph from a built NN-Descent index
*
* The graph stored in the index is copied into `graph`. The call fails if the
* number of rows or columns of `graph` differs from the index's graph.
*
* @param[in] index cuvsNNDescentIndex_t Built NN-Descent index
* @param[inout] graph Preallocated graph on host memory to store output
* (required; it is read as a row-major matrix of 32-bit unsigned integers)
* @return cuvsError_t
*/
cuvsError_t cuvsNNDescentIndexGetGraph(cuvsNNDescentIndex_t index, DLManagedTensor* graph);
#ifdef __cplusplus
}
#endif
167 changes: 167 additions & 0 deletions cpp/src/neighbors/nn_descent_c.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cstdint>
#include <dlpack/dlpack.h>

#include <raft/core/copy.hpp>
#include <raft/core/error.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/core/serialize.hpp>

#include <cuvs/core/c_api.h>
#include <cuvs/core/exceptions.hpp>
#include <cuvs/core/interop.hpp>
#include <cuvs/neighbors/nn_descent.h>
#include <cuvs/neighbors/nn_descent.hpp>

#include <fstream>

namespace {

/**
 * @brief Translate the C parameter struct and build a C++ nn-descent index.
 *
 * @tparam T    element type of the dataset
 * @tparam IdxT element type of the knn graph
 *
 * @param res            opaque C resources handle, reinterpreted as `raft::resources*`
 * @param params         C build parameters, copied field by field into the C++ params
 * @param dataset_tensor dataset in host- or device-accessible memory
 * @param graph_tensor   optional preallocated host graph; may be null
 * @return pointer to a heap-allocated `cuvs::neighbors::nn_descent::index<IdxT>`;
 *         the caller owns it and must `delete` it through the typed pointer
 *
 * Fails via RAFT_FAIL when the dataset is neither host nor device compatible.
 */
template <typename T, typename IdxT = uint32_t>
void* _build(cuvsResources_t res,
             cuvsNNDescentIndexParams params,
             DLManagedTensor* dataset_tensor,
             DLManagedTensor* graph_tensor)
{
  auto res_ptr = reinterpret_cast<raft::resources*>(res);

  auto build_params = cuvs::neighbors::nn_descent::index_params();
  build_params.metric =
    static_cast<cuvs::distance::DistanceType>(static_cast<int>(params.metric));
  build_params.metric_arg                = params.metric_arg;
  build_params.graph_degree              = params.graph_degree;
  build_params.intermediate_graph_degree = params.intermediate_graph_degree;
  build_params.max_iterations            = params.max_iterations;
  build_params.termination_threshold     = params.termination_threshold;
  build_params.return_distances          = params.return_distances;
  build_params.n_clusters                = params.n_clusters;

  // The output graph is optional: leave the optional empty when none is supplied.
  using graph_type = raft::host_matrix_view<IdxT, int64_t, raft::row_major>;
  std::optional<graph_type> graph;
  if (graph_tensor != nullptr) { graph = cuvs::core::from_dlpack<graph_type>(graph_tensor); }

  // The two branches differ only in the mdspan type (device vs host) handed to build().
  const auto& dl_tensor = dataset_tensor->dl_tensor;
  if (cuvs::core::is_dlpack_device_compatible(dl_tensor)) {
    using dataset_type = raft::device_matrix_view<T const, int64_t, raft::row_major>;
    auto dataset_view  = cuvs::core::from_dlpack<dataset_type>(dataset_tensor);
    auto index = cuvs::neighbors::nn_descent::build(*res_ptr, build_params, dataset_view, graph);
    return new cuvs::neighbors::nn_descent::index<IdxT>(std::move(index));
  } else if (cuvs::core::is_dlpack_host_compatible(dl_tensor)) {
    using dataset_type = raft::host_matrix_view<T const, int64_t, raft::row_major>;
    auto dataset_view  = cuvs::core::from_dlpack<dataset_type>(dataset_tensor);
    auto index = cuvs::neighbors::nn_descent::build(*res_ptr, build_params, dataset_view, graph);
    return new cuvs::neighbors::nn_descent::index<IdxT>(std::move(index));
  } else {
    RAFT_FAIL("dataset must be accessible on host or device memory");
  }
}
} // namespace

/**
 * Allocate a zero-initialized index handle and store it in `*index`.
 * Any exception raised while allocating is translated into a cuvsError_t.
 */
extern "C" cuvsError_t cuvsNNDescentIndexCreate(cuvsNNDescentIndex_t* index)
{
  auto allocate = [index] {
    auto* handle = new cuvsNNDescentIndex{};
    *index       = handle;
  };
  return cuvs::core::translate_exceptions(allocate);
}

/**
 * Release the C++ index referenced by the handle, then the handle itself.
 *
 * A handle whose `addr` is zero has no C++ index attached (it was created but
 * never successfully built), so only the handle struct is freed. Previously
 * such a handle carried a zero-initialized dtype, hit the "unsupported dtype"
 * failure below and the handle struct was leaked.
 *
 * Fails via RAFT_FAIL when a non-null index carries a dtype this API cannot
 * map back to a C++ index type; in that case nothing is freed.
 */
extern "C" cuvsError_t cuvsNNDescentIndexDestroy(cuvsNNDescentIndex_t index_c_ptr)
{
  return cuvs::core::translate_exceptions([=] {
    auto index = *index_c_ptr;
    if (index.addr == 0) {
      // Never built: there is no C++ index to release.
    } else if ((index.dtype.code == kDLUInt) && (index.dtype.bits == 32)) {
      auto index_ptr = reinterpret_cast<cuvs::neighbors::nn_descent::index<uint32_t>*>(index.addr);
      delete index_ptr;
    } else {
      RAFT_FAIL(
        "Unsupported nn-descent index dtype: %d and bits: %d", index.dtype.code, index.dtype.bits);
    }
    delete index_c_ptr;
  });
}

/**
 * Build an nn-descent index from `dataset_tensor` and attach it to `index`.
 *
 * The handle's dtype is tagged as 32-bit unsigned before the build starts.
 * The dataset's element type then selects the `_build` instantiation; an
 * unsupported element type fails via RAFT_FAIL and leaves `index->addr`
 * untouched.
 */
extern "C" cuvsError_t cuvsNNDescentBuild(cuvsResources_t res,
                                          cuvsNNDescentIndexParams_t params,
                                          DLManagedTensor* dataset_tensor,
                                          DLManagedTensor* graph_tensor,
                                          cuvsNNDescentIndex_t index)
{
  return cuvs::core::translate_exceptions([=] {
    index->dtype.code = kDLUInt;
    index->dtype.bits = 32;

    const auto elem_type = dataset_tensor->dl_tensor.dtype;
    const bool is_float  = (elem_type.code == kDLFloat);
    const bool is_int    = (elem_type.code == kDLInt);
    const bool is_uint   = (elem_type.code == kDLUInt);

    // Pick the instantiation first, then store the address in one place.
    void* built = nullptr;
    if (is_float && (elem_type.bits == 32)) {
      built = _build<float, uint32_t>(res, *params, dataset_tensor, graph_tensor);
    } else if (is_float && (elem_type.bits == 16)) {
      built = _build<half, uint32_t>(res, *params, dataset_tensor, graph_tensor);
    } else if (is_int && (elem_type.bits == 8)) {
      built = _build<int8_t, uint32_t>(res, *params, dataset_tensor, graph_tensor);
    } else if (is_uint && (elem_type.bits == 8)) {
      built = _build<uint8_t, uint32_t>(res, *params, dataset_tensor, graph_tensor);
    } else {
      RAFT_FAIL(
        "Unsupported nn-descent dataset dtype: %d and bits: %d", elem_type.code, elem_type.bits);
    }
    index->addr = reinterpret_cast<uintptr_t>(built);
  });
}

/**
 * Allocate a C parameter struct and fill it with the defaults of the C++
 * `cuvs::neighbors::nn_descent::index_params`, storing the result in `*params`.
 */
extern "C" cuvsError_t cuvsNNDescentIndexParamsCreate(cuvsNNDescentIndexParams_t* params)
{
  return cuvs::core::translate_exceptions([=] {
    // A default-constructed C++ params object is the single source of defaults.
    cuvs::neighbors::nn_descent::index_params defaults;

    auto* c_params                      = new cuvsNNDescentIndexParams{};
    c_params->metric                    = defaults.metric;
    c_params->metric_arg                = defaults.metric_arg;
    c_params->graph_degree              = defaults.graph_degree;
    c_params->intermediate_graph_degree = defaults.intermediate_graph_degree;
    c_params->max_iterations            = defaults.max_iterations;
    c_params->termination_threshold     = defaults.termination_threshold;
    c_params->return_distances          = defaults.return_distances;
    c_params->n_clusters                = defaults.n_clusters;

    *params = c_params;
  });
}

/**
 * Free a parameter struct previously allocated by cuvsNNDescentIndexParamsCreate.
 */
extern "C" cuvsError_t cuvsNNDescentIndexParamsDestroy(cuvsNNDescentIndexParams_t params)
{
  auto release = [params] { delete params; };
  return cuvs::core::translate_exceptions(release);
}

/**
 * Copy the knn graph held by a built index into a caller-provided host matrix.
 *
 * @param index built nn-descent index handle
 * @param graph preallocated row-major host matrix of uint32_t whose shape must
 *              match the index's graph exactly
 * @return cuvsError_t; an error is reported when `graph` is null, when the
 *         index has no C++ index attached (e.g. a build failed or never ran),
 *         when the shapes differ, or when the index dtype is unsupported
 */
extern "C" cuvsError_t cuvsNNDescentIndexGetGraph(cuvsNNDescentIndex_t index,
                                                  DLManagedTensor* graph)
{
  return cuvs::core::translate_exceptions([=] {
    // Unlike in cuvsNNDescentBuild, the output tensor is mandatory here.
    RAFT_EXPECTS(graph != nullptr, "Output graph must not be NULL");

    auto dtype = index->dtype;
    if ((dtype.code == kDLUInt) && (dtype.bits == 32)) {
      auto index_ptr = reinterpret_cast<cuvs::neighbors::nn_descent::index<uint32_t>*>(index->addr);
      // A failed build leaves the dtype tagged but the address still zero.
      RAFT_EXPECTS(index_ptr != nullptr, "nn-descent index has not been built");

      using output_mdspan_type = raft::host_matrix_view<uint32_t, int64_t, raft::row_major>;
      auto dst                 = cuvs::core::from_dlpack<output_mdspan_type>(graph);
      auto src                 = index_ptr->graph();

      RAFT_EXPECTS(src.extent(0) == dst.extent(0), "Output graph has incorrect number of rows");
      RAFT_EXPECTS(src.extent(1) == dst.extent(1), "Output graph has incorrect number of cols");
      std::copy(src.data_handle(), src.data_handle() + dst.size(), dst.data_handle());
    } else {
      RAFT_FAIL("Unsupported nn-descent index dtype: %d and bits: %d", dtype.code, dtype.bits);
    }
  });
}
4 changes: 2 additions & 2 deletions python/cuvs/cuvs/distance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .distance import DISTANCE_TYPES, pairwise_distance
from .distance import DISTANCE_NAMES, DISTANCE_TYPES, pairwise_distance

__all__ = ["DISTANCE_TYPES", "pairwise_distance"]
__all__ = ["DISTANCE_NAMES", "DISTANCE_TYPES", "pairwise_distance"]
2 changes: 2 additions & 0 deletions python/cuvs/cuvs/distance/distance.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ DISTANCE_TYPES = {
"dice": cuvsDistanceType.DiceExpanded,
}

DISTANCE_NAMES = {v: k for k, v in DISTANCE_TYPES.items()}

SUPPORTED_DISTANCES = ["euclidean", "l1", "cityblock", "l2", "inner_product",
"chebyshev", "minkowski", "canberra", "kl_divergence",
"correlation", "russellrao", "hellinger", "lp",
Expand Down
Loading

0 comments on commit 2b3f19a

Please sign in to comment.