Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor improvements for arrow_filter_policy #654

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/cuco/bloom_filter_policies.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace cuco {
* By default, cuco::xxhash_64 hasher will be used.
*
*/
template <class Key, class XXHash64 = cuco::xxhash_64<Key>>
template <typename Key, template <typename> class XXHash64 = cuco::xxhash_64>
using arrow_filter_policy = detail::arrow_filter_policy<Key, XXHash64>;

/**
Expand Down
23 changes: 9 additions & 14 deletions include/cuco/detail/bloom_filter/arrow_filter_policy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ namespace cuco::detail {
* void bulk_insert_and_eval_arrow_policy_bloom_filter(device_vector<KeyType> const& positive_keys,
* device_vector<KeyType> const& negative_keys)
* {
* using xxhash_64 = cuco::xxhash_64<KeyType>;
* using policy_type = cuco::arrow_filter_policy<KeyType, xxhash_64>;
* using policy_type = cuco::arrow_filter_policy<KeyType, cuco::xxhash_64>;
*
* // Warn or throw if the number of filter blocks is greater than maximum used by Arrow policy.
* static_assert(NUM_FILTER_BLOCKS <= policy_type::max_filter_blocks, "NUM_FILTER_BLOCKS must be
Expand Down Expand Up @@ -81,14 +80,13 @@ namespace cuco::detail {
* @tparam Key The type of the values to generate a fingerprint for.
* @tparam XXHash64 64-bit XXHash hasher implementation for fingerprint generation.
*/
template <class Key, class XXHash64>
template <class Key, template <typename> class XXHash64>
class arrow_filter_policy {
public:
using hasher = XXHash64; ///< 64-bit XXHash hasher for Arrow bloom filter policy
using word_type = std::uint32_t; ///< uint32_t for Arrow bloom filter policy
using hash_argument_type = typename hasher::argument_type; ///< Hash function input type
using hash_result_type = decltype(std::declval<hasher>()(
std::declval<hash_argument_type>())); ///< hash function output type
using hasher = XXHash64<Key>; ///< 64-bit XXHash hasher for Arrow bloom filter policy
using word_type = std::uint32_t; ///< uint32_t for Arrow bloom filter policy
using key_type = Key; ///< Hash function input type
using hash_value_type = std::uint64_t; ///< hash function output type

static constexpr uint32_t bits_set_per_block = 8; ///< hardcoded bits set per Arrow filter block
static constexpr uint32_t words_per_block = 8; ///< hardcoded words per Arrow filter block
Expand Down Expand Up @@ -135,10 +133,7 @@ class arrow_filter_policy {
*
* @return The hash value of the key
*/
__device__ constexpr hash_result_type hash(hash_argument_type const& key) const
{
return hash_(key);
}
__device__ constexpr hash_value_type hash(key_type const& key) const { return hash_(key); }

/**
* @brief Determines the filter block a key is added into.
Expand All @@ -155,7 +150,7 @@ class arrow_filter_policy {
* @return The block index for the given key's hash value
*/
template <class Extent>
__device__ constexpr auto block_index(hash_result_type hash, Extent num_blocks) const
__device__ constexpr auto block_index(hash_value_type hash, Extent num_blocks) const
{
constexpr auto hash_bits = cuda::std::numeric_limits<word_type>::digits;
// TODO: assert if num_blocks > max_filter_blocks
Expand All @@ -173,7 +168,7 @@ class arrow_filter_policy {
*
* @return The bit pattern for the word/segment in the filter block
*/
__device__ constexpr word_type word_pattern(hash_result_type hash, std::uint32_t word_index) const
__device__ constexpr word_type word_pattern(hash_value_type hash, std::uint32_t word_index) const
{
// SALT array to calculate bit indexes for the current word
auto constexpr salt = SALT();
Expand Down
Loading