@@ -540,7 +540,7 @@ RAFT_DEVICE_INLINE_FUNCTION void hashmap_restore(INDEX_T* const hashmap_ptr,
540540 * @param max_iteration Maximum number of iterations.
541541 * @param num_executed_iterations Pointer to the number of executed iterations [num_queries].
542542 * @param hash_bitlen Bit length of the hash.
543- * @param small_hash_bitlen Bit length of the small hash.
543+ * @param use_small_hash Whether to use local small hash table (1) or global hash table (0) .
544544 * @param small_hash_reset_interval Interval for resetting the small hash.
545545 * @param query_id sequential id of the query in the batch
546546 */
@@ -569,7 +569,7 @@ __device__ void search_core(
569569 const std::uint32_t max_iteration,
570570 std::uint32_t * const num_executed_iterations, // [num_queries]
571571 const std::uint32_t hash_bitlen,
572- const std::uint32_t small_hash_bitlen ,
572+ const std::uint32_t use_small_hash ,
573573 const std::uint32_t small_hash_reset_interval,
574574 const std::uint32_t query_id,
575575 SAMPLE_FILTER_T sample_filter)
@@ -607,7 +607,7 @@ __device__ void search_core(
607607 // |<--- result_buffer_size --->|
608608 const auto result_buffer_size = internal_topk + (search_width * graph_degree);
609609 const auto result_buffer_size_32 = raft::round_up_safe<uint32_t >(result_buffer_size, 32 );
610- const auto small_hash_size = hashmap::get_size (small_hash_bitlen) ;
610+ const auto small_hash_size = use_small_hash ? hashmap::get_size (hash_bitlen) : 0 ;
611611
612612 // Set smem working buffer for the distance calculation
613613 dataset_desc = dataset_desc->setup_workspace (smem, queries_ptr, query_id);
@@ -634,7 +634,7 @@ __device__ void search_core(
634634
635635 // Init hashmap
636636 INDEX_T* local_visited_hashmap_ptr;
637- if (small_hash_bitlen ) {
637+ if (use_small_hash ) {
638638 local_visited_hashmap_ptr = visited_hash_buffer;
639639 } else {
640640 local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size (hash_bitlen) * blockIdx .y );
@@ -986,7 +986,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
986986 const std::uint32_t max_iteration,
987987 std::uint32_t * const num_executed_iterations, // [num_queries]
988988 const std::uint32_t hash_bitlen,
989- const std::uint32_t small_hash_bitlen ,
989+ const std::uint32_t use_small_hash ,
990990 const std::uint32_t small_hash_reset_interval,
991991 SAMPLE_FILTER_T sample_filter)
992992{
@@ -1013,7 +1013,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
10131013 max_iteration,
10141014 num_executed_iterations,
10151015 hash_bitlen,
1016- small_hash_bitlen ,
1016+ use_small_hash ,
10171017 small_hash_reset_interval,
10181018 query_id,
10191019 sample_filter);
@@ -1102,7 +1102,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel_p(
11021102 const std::uint32_t max_iteration,
11031103 std::uint32_t * const num_executed_iterations, // [num_queries]
11041104 const std::uint32_t hash_bitlen,
1105- const std::uint32_t small_hash_bitlen ,
1105+ const std::uint32_t use_small_hash ,
11061106 const std::uint32_t small_hash_reset_interval,
11071107 SAMPLE_FILTER_T sample_filter)
11081108{
@@ -1169,7 +1169,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel_p(
11691169 max_iteration,
11701170 num_executed_iterations,
11711171 hash_bitlen,
1172- small_hash_bitlen ,
1172+ use_small_hash ,
11731173 small_hash_reset_interval,
11741174 query_id,
11751175 sample_filter);
@@ -1779,7 +1779,7 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b
17791779 uint32_t block_size, //
17801780 uint32_t smem_size,
17811781 int64_t hash_bitlen,
1782- size_t small_hash_bitlen ,
1782+ uint32_t use_small_hash ,
17831783 size_t small_hash_reset_interval,
17841784 uint32_t num_random_samplings,
17851785 uint64_t rand_xor_mask,
@@ -1793,7 +1793,7 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b
17931793 float persistent_device_usage) -> uint64_t
17941794 {
17951795 return uint64_t (graph.data_handle ()) ^ dataset_desc.get ().team_size ^ num_itopk_candidates ^
1796- block_size ^ smem_size ^ hash_bitlen ^ small_hash_reset_interval ^ num_random_samplings ^
1796+ block_size ^ smem_size ^ hash_bitlen ^ use_small_hash ^ small_hash_reset_interval ^ num_random_samplings ^
17971797 rand_xor_mask ^ num_seeds ^ itopk_size ^ search_width ^ min_iterations ^ max_iterations ^
17981798 uint64_t (persistent_lifetime * 1000 ) ^ uint64_t (persistent_device_usage * 1000 );
17991799 }
@@ -1805,7 +1805,7 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b
18051805 uint32_t block_size, //
18061806 uint32_t smem_size,
18071807 int64_t hash_bitlen,
1808- size_t small_hash_bitlen ,
1808+ uint32_t use_small_hash ,
18091809 size_t small_hash_reset_interval,
18101810 uint32_t num_random_samplings,
18111811 uint64_t rand_xor_mask,
@@ -1832,7 +1832,7 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b
18321832 block_size,
18331833 smem_size,
18341834 hash_bitlen,
1835- small_hash_bitlen ,
1835+ use_small_hash ,
18361836 small_hash_reset_interval,
18371837 num_random_samplings,
18381838 rand_xor_mask,
@@ -1883,7 +1883,7 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b
18831883 }
18841884
18851885 index_type* hashmap_ptr = nullptr ;
1886- if (small_hash_bitlen == 0 ) {
1886+ if (!use_small_hash ) {
18871887 hashmap.resize (gs.y * hashmap::get_size (hash_bitlen), stream);
18881888 hashmap_ptr = hashmap.data ();
18891889 }
@@ -1912,7 +1912,7 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b
19121912 &max_iterations,
19131913 &num_executed_iterations,
19141914 &hash_bitlen,
1915- &small_hash_bitlen ,
1915+ &use_small_hash ,
19161916 &small_hash_reset_interval,
19171917 &sample_filter};
19181918 cuda::atomic_thread_fence (cuda::memory_order_seq_cst, cuda::thread_scope_system);
@@ -2092,7 +2092,7 @@ void select_and_run(const dataset_descriptor_host<DataT, IndexT, DistanceT>& dat
20922092 uint32_t smem_size,
20932093 int64_t hash_bitlen,
20942094 IndexT* hashmap_ptr,
2095- size_t small_hash_bitlen ,
2095+ uint32_t use_small_hash ,
20962096 size_t small_hash_reset_interval,
20972097 uint32_t num_seeds,
20982098 SampleFilterT sample_filter,
@@ -2112,7 +2112,7 @@ control is returned in this thread (in persistent_runner_t constructor), so we'r
21122112 block_size,
21132113 smem_size,
21142114 hash_bitlen,
2115- small_hash_bitlen ,
2115+ use_small_hash ,
21162116 small_hash_reset_interval,
21172117 ps.num_random_samplings ,
21182118 ps.rand_xor_mask ,
@@ -2153,7 +2153,7 @@ control is returned in this thread (in persistent_runner_t constructor), so we'r
21532153 ps.max_iterations ,
21542154 num_executed_iterations,
21552155 hash_bitlen,
2156- small_hash_bitlen ,
2156+ use_small_hash ,
21572157 small_hash_reset_interval,
21582158 sample_filter);
21592159 RAFT_CUDA_TRY (cudaPeekAtLastError ());
0 commit comments