Skip to content

Commit

Permalink
ANN_BENCH enhanced dataset support (#624)
Browse files Browse the repository at this point in the history
Refactor the dataset module in the benchmark utility to add missing functionality:
  - [x] Data in managed memory (in addition to host/device/mmap/pinned)
  - [x] Basic filtering support: randomly generated bitset by setting 'filtering_rate' in the dataset config
  - [x] Partial support within CUVS algorithms (bitset only)
  - [ ] Support in all algorithms
  - [ ] Using files for bitset filtering
  - [x] Adjusting ground truth to account for the filtered data
  - [ ] Fine-grained control over where the bitset is located (like there is for the base set and query set)
  - [ ] Expose 2MB huge-pages support via config/cmd arguments
  - [ ] Add quantization as a dataset property

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #624
  • Loading branch information
achirkin authored Feb 1, 2025
1 parent 888a34f commit 88f0dfc
Show file tree
Hide file tree
Showing 18 changed files with 1,174 additions and 543 deletions.
17 changes: 16 additions & 1 deletion cpp/bench/ann/src/common/ann_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ enum class MemoryType {
kHostMmap,
kHostPinned,
kDevice,
kManaged,
};

/**
 * Request 2MB huge pages support for an allocation.
 *
 * The values form an escalation ladder: each level reacts more strongly than
 * the previous one when huge pages cannot be obtained.
 * NOTE(review): whether huge pages are actually honored presumably depends on
 * the OS/allocator backing the allocation — confirm against the allocating code.
 */
enum class HugePages {
/** Do not request huge pages; use regular pages. */
kDisable = 0,
/** Enable huge pages if possible; silently fall back to regular pages otherwise. */
kAsk = 1,
/** Enable huge pages if possible; warn the user on fallback to regular pages. */
kRequire = 2,
/** Force enable huge pages; throw an exception if they cannot be obtained. */
kDemand = 3
};

enum class Metric {
Expand Down Expand Up @@ -65,6 +78,8 @@ inline auto parse_memory_type(const std::string& memory_type) -> MemoryType
return MemoryType::kHostPinned;
} else if (memory_type == "device") {
return MemoryType::kDevice;
} else if (memory_type == "managed") {
return MemoryType::kManaged;
} else {
throw std::runtime_error("invalid memory type: '" + memory_type + "'");
}
Expand Down Expand Up @@ -130,7 +145,7 @@ class algo : public algo_base {

virtual void build(const T* dataset, size_t nrow) = 0;

virtual void set_search_param(const search_param& param) = 0;
virtual void set_search_param(const search_param& param, const void* filter_bitset) = 0;
// TODO(snanditale): this assumes that an algorithm can always return k results.
// This is not always possible.
virtual void search(const T* queries,
Expand Down
59 changes: 45 additions & 14 deletions cpp/bench/ann/src/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ void bench_search(::benchmark::State& state,
}
}
try {
a->set_search_param(*search_param);
a->set_search_param(*search_param,
dataset->filter_bitset(current_algo_props->dataset_memory_type));
} catch (const std::exception& ex) {
state.SkipWithError("An error occurred setting search parameters: " + std::string(ex.what()));
return;
Expand Down Expand Up @@ -359,13 +360,19 @@ void bench_search(::benchmark::State& state,
// Each thread calculates recall on their partition of queries.
// evaluate recall
if (dataset->max_k() >= k) {
const std::int32_t* gt = dataset->gt_set();
const std::int32_t* gt = dataset->gt_set();
const std::uint32_t* filter_bitset = dataset->filter_bitset(MemoryType::kHostMmap);
auto filter = [filter_bitset](std::int32_t i) -> bool {
if (filter_bitset == nullptr) { return true; }
auto word = filter_bitset[i >> 5];
return word & (1 << (i & 31));
};
const std::uint32_t max_k = dataset->max_k();
result_buf.transfer_data(MemoryType::kHost, current_algo_props->query_memory_type);
auto* neighbors_host = reinterpret_cast<index_type*>(result_buf.data(MemoryType::kHost));
std::size_t rows = std::min(queries_processed, query_set_size);
std::size_t match_count = 0;
std::size_t total_count = rows * static_cast<size_t>(k);
std::size_t total_count = 0;

// We go through the groundtruth with same stride as the benchmark loop.
size_t out_offset = 0;
Expand All @@ -375,22 +382,44 @@ void bench_search(::benchmark::State& state,
size_t i_orig_idx = batch_offset + i;
size_t i_out_idx = out_offset + i;
if (i_out_idx < rows) {
for (std::uint32_t j = 0; j < k; j++) {
auto act_idx = static_cast<std::int32_t>(neighbors_host[i_out_idx * k + j]);
for (std::uint32_t l = 0; l < k; l++) {
auto exp_idx = gt[i_orig_idx * max_k + l];
/* NOTE: recall correctness & filtering
In the loop below, we filter the ground truth values on-the-fly.
We need enough ground truth values to compute recall correctly though.
But the ground truth file only contains `max_k` values per row; if there are less valid
values than k among them, we overestimate the recall. Essentially, we compare the first
`filter_pass_count` values of the algorithm output, and this counter can be less than `k`.
In the extreme case of very high filtering rate, we may be bypassing entire rows of
results. However, this is still better than no recall estimate at all.
TODO: consider generating the filtered ground truth on-the-fly
*/
uint32_t filter_pass_count = 0;
for (std::uint32_t l = 0; l < max_k && filter_pass_count < k; l++) {
auto exp_idx = gt[i_orig_idx * max_k + l];
if (!filter(exp_idx)) { continue; }
filter_pass_count++;
for (std::uint32_t j = 0; j < k; j++) {
auto act_idx = static_cast<std::int32_t>(neighbors_host[i_out_idx * k + j]);
if (act_idx == exp_idx) {
match_count++;
break;
}
}
}
total_count += filter_pass_count;
}
}
out_offset += n_queries;
batch_offset = (batch_offset + queries_stride) % query_set_size;
}
double actual_recall = static_cast<double>(match_count) / static_cast<double>(total_count);
/* NOTE: recall in the throughput mode & filtering
When filtering is enabled, `total_count` may vary between individual threads, but we still take
the simple average across in-thread recalls. Strictly speaking, this is incorrect, but it's good
enough under assumption that the filtering is more-or-less uniform.
*/
state.counters.insert({"Recall", {actual_recall, benchmark::Counter::kAvgThreads}});
}
}
Expand Down Expand Up @@ -515,13 +544,15 @@ void dispatch_benchmark(std::string cmdline,
auto query_file = combine_path(data_prefix, dataset_conf.query_file);
auto gt_file = dataset_conf.groundtruth_neighbors_file;
if (gt_file.has_value()) { gt_file.emplace(combine_path(data_prefix, gt_file.value())); }
auto dataset = std::make_shared<bin_dataset<T>>(dataset_conf.name,
base_file,
dataset_conf.subset_first_row,
dataset_conf.subset_size,
query_file,
dataset_conf.distance,
gt_file);
auto dataset =
std::make_shared<bench::dataset<T>>(dataset_conf.name,
base_file,
dataset_conf.subset_first_row,
dataset_conf.subset_size,
query_file,
dataset_conf.distance,
gt_file,
search_mode ? dataset_conf.filtering_rate : std::nullopt);
::benchmark::AddCustomContext("dataset", dataset_conf.name);
::benchmark::AddCustomContext("distance", dataset_conf.distance);
std::vector<configuration::index> indices = conf.get_indices();
Expand Down
Loading

0 comments on commit 88f0dfc

Please sign in to comment.