Skip to content

Commit b7e77c1

Browse files
committed
rename small_hash_size in single-CTA.
1 parent be6024d commit b7e77c1

File tree

4 files changed

+20
-20
lines changed

4 files changed

+20
-20
lines changed

cpp/src/neighbors/detail/cagra/search_single_cta.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ struct search : search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_T, Outp
258258
smem_size,
259259
hash_bitlen,
260260
hashmap.data(),
261-
small_hash_bitlen,
261+
(small_hash_bitlen > 0 ? 1u : 0u), // 转换为uint32_t
262262
small_hash_reset_interval,
263263
num_seeds,
264264
sample_filter,

cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ namespace cuvs::neighbors::cagra::detail::single_cta_search {
3838
uint32_t smem_size, \
3939
int64_t hash_bitlen, \
4040
IndexT* hashmap_ptr, \
41-
size_t small_hash_bitlen, \
41+
uint32_t use_small_hash, \
4242
size_t small_hash_reset_interval, \
4343
uint32_t num_seeds, \
4444
SampleFilterT sample_filter, \

cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ RAFT_DEVICE_INLINE_FUNCTION void hashmap_restore(INDEX_T* const hashmap_ptr,
540540
* @param max_iteration Maximum number of iterations.
541541
* @param num_executed_iterations Pointer to the number of executed iterations [num_queries].
542542
* @param hash_bitlen Bit length of the hash.
543-
* @param small_hash_bitlen Bit length of the small hash.
543+
* @param use_small_hash Whether to use local small hash table (1) or global hash table (0).
544544
* @param small_hash_reset_interval Interval for resetting the small hash.
545545
* @param query_id sequential id of the query in the batch
546546
*/
@@ -569,7 +569,7 @@ __device__ void search_core(
569569
const std::uint32_t max_iteration,
570570
std::uint32_t* const num_executed_iterations, // [num_queries]
571571
const std::uint32_t hash_bitlen,
572-
const std::uint32_t small_hash_bitlen,
572+
const std::uint32_t use_small_hash,
573573
const std::uint32_t small_hash_reset_interval,
574574
const std::uint32_t query_id,
575575
SAMPLE_FILTER_T sample_filter)
@@ -607,7 +607,7 @@ __device__ void search_core(
607607
// |<--- result_buffer_size --->|
608608
const auto result_buffer_size = internal_topk + (search_width * graph_degree);
609609
const auto result_buffer_size_32 = raft::round_up_safe<uint32_t>(result_buffer_size, 32);
610-
const auto small_hash_size = hashmap::get_size(small_hash_bitlen);
610+
const auto small_hash_size = use_small_hash ? hashmap::get_size(hash_bitlen) : 0;
611611

612612
// Set smem working buffer for the distance calculation
613613
dataset_desc = dataset_desc->setup_workspace(smem, queries_ptr, query_id);
@@ -634,7 +634,7 @@ __device__ void search_core(
634634

635635
// Init hashmap
636636
INDEX_T* local_visited_hashmap_ptr;
637-
if (small_hash_bitlen) {
637+
if (use_small_hash) {
638638
local_visited_hashmap_ptr = visited_hash_buffer;
639639
} else {
640640
local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * blockIdx.y);
@@ -986,7 +986,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
986986
const std::uint32_t max_iteration,
987987
std::uint32_t* const num_executed_iterations, // [num_queries]
988988
const std::uint32_t hash_bitlen,
989-
const std::uint32_t small_hash_bitlen,
989+
const std::uint32_t use_small_hash,
990990
const std::uint32_t small_hash_reset_interval,
991991
SAMPLE_FILTER_T sample_filter)
992992
{
@@ -1013,7 +1013,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
10131013
max_iteration,
10141014
num_executed_iterations,
10151015
hash_bitlen,
1016-
small_hash_bitlen,
1016+
use_small_hash,
10171017
small_hash_reset_interval,
10181018
query_id,
10191019
sample_filter);
@@ -1102,7 +1102,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel_p(
11021102
const std::uint32_t max_iteration,
11031103
std::uint32_t* const num_executed_iterations, // [num_queries]
11041104
const std::uint32_t hash_bitlen,
1105-
const std::uint32_t small_hash_bitlen,
1105+
const std::uint32_t use_small_hash,
11061106
const std::uint32_t small_hash_reset_interval,
11071107
SAMPLE_FILTER_T sample_filter)
11081108
{
@@ -1169,7 +1169,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel_p(
11691169
max_iteration,
11701170
num_executed_iterations,
11711171
hash_bitlen,
1172-
small_hash_bitlen,
1172+
use_small_hash,
11731173
small_hash_reset_interval,
11741174
query_id,
11751175
sample_filter);
@@ -1779,7 +1779,7 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b
17791779
uint32_t block_size, //
17801780
uint32_t smem_size,
17811781
int64_t hash_bitlen,
1782-
size_t small_hash_bitlen,
1782+
uint32_t use_small_hash,
17831783
size_t small_hash_reset_interval,
17841784
uint32_t num_random_samplings,
17851785
uint64_t rand_xor_mask,
@@ -1793,7 +1793,7 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b
17931793
float persistent_device_usage) -> uint64_t
17941794
{
17951795
return uint64_t(graph.data_handle()) ^ dataset_desc.get().team_size ^ num_itopk_candidates ^
1796-
block_size ^ smem_size ^ hash_bitlen ^ small_hash_reset_interval ^ num_random_samplings ^
1796+
block_size ^ smem_size ^ hash_bitlen ^ use_small_hash ^ small_hash_reset_interval ^ num_random_samplings ^
17971797
rand_xor_mask ^ num_seeds ^ itopk_size ^ search_width ^ min_iterations ^ max_iterations ^
17981798
uint64_t(persistent_lifetime * 1000) ^ uint64_t(persistent_device_usage * 1000);
17991799
}
@@ -1805,7 +1805,7 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b
18051805
uint32_t block_size, //
18061806
uint32_t smem_size,
18071807
int64_t hash_bitlen,
1808-
size_t small_hash_bitlen,
1808+
uint32_t use_small_hash,
18091809
size_t small_hash_reset_interval,
18101810
uint32_t num_random_samplings,
18111811
uint64_t rand_xor_mask,
@@ -1832,7 +1832,7 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b
18321832
block_size,
18331833
smem_size,
18341834
hash_bitlen,
1835-
small_hash_bitlen,
1835+
use_small_hash,
18361836
small_hash_reset_interval,
18371837
num_random_samplings,
18381838
rand_xor_mask,
@@ -1883,7 +1883,7 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b
18831883
}
18841884

18851885
index_type* hashmap_ptr = nullptr;
1886-
if (small_hash_bitlen == 0) {
1886+
if (!use_small_hash) {
18871887
hashmap.resize(gs.y * hashmap::get_size(hash_bitlen), stream);
18881888
hashmap_ptr = hashmap.data();
18891889
}
@@ -1912,7 +1912,7 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b
19121912
&max_iterations,
19131913
&num_executed_iterations,
19141914
&hash_bitlen,
1915-
&small_hash_bitlen,
1915+
&use_small_hash,
19161916
&small_hash_reset_interval,
19171917
&sample_filter};
19181918
cuda::atomic_thread_fence(cuda::memory_order_seq_cst, cuda::thread_scope_system);
@@ -2092,7 +2092,7 @@ void select_and_run(const dataset_descriptor_host<DataT, IndexT, DistanceT>& dat
20922092
uint32_t smem_size,
20932093
int64_t hash_bitlen,
20942094
IndexT* hashmap_ptr,
2095-
size_t small_hash_bitlen,
2095+
uint32_t use_small_hash,
20962096
size_t small_hash_reset_interval,
20972097
uint32_t num_seeds,
20982098
SampleFilterT sample_filter,
@@ -2112,7 +2112,7 @@ control is returned in this thread (in persistent_runner_t constructor), so we'r
21122112
block_size,
21132113
smem_size,
21142114
hash_bitlen,
2115-
small_hash_bitlen,
2115+
use_small_hash,
21162116
small_hash_reset_interval,
21172117
ps.num_random_samplings,
21182118
ps.rand_xor_mask,
@@ -2153,7 +2153,7 @@ control is returned in this thread (in persistent_runner_t constructor), so we'r
21532153
ps.max_iterations,
21542154
num_executed_iterations,
21552155
hash_bitlen,
2156-
small_hash_bitlen,
2156+
use_small_hash,
21572157
small_hash_reset_interval,
21582158
sample_filter);
21592159
RAFT_CUDA_TRY(cudaPeekAtLastError());

cpp/src/neighbors/detail/cagra/search_single_cta_kernel.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ void select_and_run(const dataset_descriptor_host<DataT, IndexT, DistanceT>& dat
3737
uint32_t smem_size,
3838
int64_t hash_bitlen,
3939
IndexT* hashmap_ptr,
40-
size_t small_hash_bitlen,
40+
uint32_t use_small_hash,
4141
size_t small_hash_reset_interval,
4242
uint32_t num_seeds,
4343
SampleFilterT sample_filter,

0 commit comments

Comments
 (0)