Skip to content

Commit 24200b3

Browse files
committed
support precomputed distances
1 parent f09556d commit 24200b3

File tree

11 files changed

+726
-335
lines changed

11 files changed

+726
-335
lines changed

README.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,19 +127,23 @@ Computing $k$-MSTs using KDTrees can be expensive on some datasets. We provide a
127127
version of the algorithm based on Nearest Neighbor Descent for quicker
128128
approximations. We combined Boruvka's algorithm with NNDescent to find neighbors
129129
that are not already connected in the MST being built. This variant supports all
130-
distance metrics implemented in `pynndescent`. Combined with `fast_hdbscan`'s
131-
cluster selection, it can greatly speed up computing (approximate) clusters on
132-
high dimensional data-sets!
130+
distance metrics implemented in `pynndescent` and precomputed distances.
131+
Combined with `fast_hdbscan`'s cluster selection, it can greatly speed up
132+
computing (approximate) clusters on high dimensional data-sets!
133133

134134

135135
```python
136136
import matplotlib.pyplot as plt
137137
import matplotlib.collections as mc
138138
from sklearn.datasets import make_swiss_roll
139139
from multi_mst import KMSTDescent
140+
from scipy.spatial.distance import pdist, squareform
140141

141142
X, t = make_swiss_roll(n_samples=2000, noise=0.5, hole=True)
142-
model = KMSTDescent(num_neighbors=3, epsilon=2.0).fit(X)
143+
model = KMSTDescent(num_neighbors=3, epsilon=2.0).fit(X) # or
144+
model = KMSTDescent(
145+
num_neighbors=3, epsilon=2.0, metric="precomputed"
146+
).fit(squareform(pdist(X)))
143147
projector = model.umap(repulsion_strength=1.0)
144148

145149
# Draw the network

src/multi_mst/base.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def fit(self, X, y=None, **fit_params):
9494
self: MultiMSTMixin
9595
The fitted estimator.
9696
"""
97-
X = check_array(X, ensure_all_finite=False)
97+
X = check_array(X, ensure_all_finite=self.metric == "precomputed")
9898
self._raw_data = X
9999

100100
self._all_finite = np.all(np.isfinite(X))
@@ -581,6 +581,10 @@ def tsne(
581581
"barnes_hut algorithm as it relies on "
582582
"quad-tree or oct-tree."
583583
)
584+
if init == "pca" and self.metric == "precomputed":
585+
raise ValueError(
586+
"PCA initialization not supported with precomputed metric."
587+
)
584588

585589
# Extract raw data
586590
X = self._raw_data
@@ -885,7 +889,7 @@ def hbcc(
885889
if hop_type == "metric":
886890
raise ValueError(
887891
'BoundaryClusterDetector requires "euclidean" '
888-
'metric with `hop_type="manifold".'
892+
'metric with `hop_type="metric".'
889893
)
890894
if not boundary_use_reachability:
891895
raise ValueError(
@@ -1063,6 +1067,8 @@ def branch_detector(
10631067
flare-sensitive clustering algorithm. PeerJ Computer Science 11:e2792
10641068
https://doi.org/10.7717/peerj-cs.2792.
10651069
"""
1070+
if self.metric == "precomputed":
1071+
raise ValueError("BranchDetector cannot be used with precomputed metric.")
10661072

10671073
return BranchDetector(
10681074
metric=self.metric,
@@ -1204,7 +1210,7 @@ def graphviz_layout(self, prog="sfdp", **kwargs):
12041210
The graphviz program to run.
12051211
**kwargs
12061212
Additional arguments to `networkx.nx_agraph.graphviz_layout`.
1207-
1213+
12081214
Returns
12091215
-------
12101216
coords : ndarray of shape (num_points, 2)

src/multi_mst/kmst_descent.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
from .base import MultiMSTMixin
9-
from .lib import multi_boruvka, DescentIndex, make_csr_graph
9+
from .lib import multi_boruvka, DescentIndex, PrecomputedIndex, make_csr_graph
1010
from .kmst import validate_parameters as validate_parameters_kmst
1111

1212

@@ -159,7 +159,9 @@ def kMSTDescent(
159159
it must be a numba njit compiled function. See the pynndescent docs for
160160
supported metrics. Metrics that take arguments (such as minkowski,
161161
mahalanobis etc.) can have arguments passed via the metric_kwds
162-
dictionary.
162+
dictionary. Precomputed distances can be passed to `data` as a 1D
163+
condensed or 2D square array, in which case the metric must be
164+
'precomputed'.
163165
164166
metric_kwds: dict (optional, default {})
165167
Arguments to pass on to the metric, such as the ``p`` value for
@@ -203,15 +205,24 @@ def kMSTDescent(
203205
data, epsilon = validate_parameters(
204206
data, num_neighbors, min_samples, epsilon, min_descent_neighbors
205207
)
206-
index = DescentIndex(
207-
data,
208-
metric,
209-
metric_kwds,
210-
num_neighbors,
211-
min_samples,
212-
min_descent_neighbors,
213-
nn_kwargs,
214-
)
208+
if metric == "precomputed":
209+
index = PrecomputedIndex(
210+
data,
211+
num_neighbors,
212+
min_samples,
213+
min_descent_neighbors,
214+
nn_kwargs,
215+
)
216+
else:
217+
index = DescentIndex(
218+
data,
219+
metric,
220+
metric_kwds,
221+
num_neighbors,
222+
min_samples,
223+
min_descent_neighbors,
224+
nn_kwargs,
225+
)
215226
mst_edges, k_edges, neighbors, distances = multi_boruvka(index, epsilon)
216227
graph = make_csr_graph(mst_edges, k_edges, neighbors.shape[0])
217228
return (graph, mst_edges, neighbors, distances)

src/multi_mst/lib/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .kdtree import KDTreeIndex
33
from .kdtree_noisy import NoisyKDTreeIndex
44
from .descent import DescentIndex
5+
from .precomputed_descent import PrecomputedIndex
56
from .descent_recall import DescentRecallIndex
67
from .branches import BranchDetector
78
from .graph import (
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
import numba
2+
import numpy as np
3+
from sklearn.utils import check_random_state
4+
from scipy.spatial.distance import squareform
5+
from pynndescent.pynndescent_ import INT32_MIN, INT32_MAX
6+
from pynndescent.utils import (
7+
tau_rand_int,
8+
make_heap,
9+
deheap_sort,
10+
simple_heap_push,
11+
checked_heap_push,
12+
)
13+
14+
from .descent import (
15+
_sample_in_candidates,
16+
_sample_out_candidates,
17+
_apply_graph_updates,
18+
_group_indices_per_component,
19+
)
20+
21+
22+
class PrecomputedIndex(object):
23+
def __init__(
24+
self,
25+
distances,
26+
num_neighbors=5,
27+
min_samples=1,
28+
min_descent_neighbors=12,
29+
nn_kwargs=None,
30+
):
31+
if nn_kwargs is None:
32+
nn_kwargs = {}
33+
34+
self.distances = distances
35+
self.num_points = distances.shape[0]
36+
self.num_neighbors = num_neighbors
37+
self.min_samples = min_samples
38+
self.n_threads = numba.get_num_threads()
39+
self.descent_neighbors = max(num_neighbors + 1, min_descent_neighbors)
40+
self.n_iters = nn_kwargs.get(
41+
"n_iters", max(5, int(round(np.log2(self.num_points))))
42+
)
43+
self.delta = nn_kwargs.get("delta", 0.001)
44+
self.rng_state = nn_kwargs.get(
45+
"rng_state",
46+
check_random_state(None).randint(INT32_MIN, INT32_MAX, 3).astype(np.int64),
47+
)
48+
49+
def neighbors(self):
50+
neighbors = np.argpartition(
51+
self.distances, np.arange(self.num_neighbors + 1), axis=1
52+
)[:, : self.num_neighbors + 1]
53+
distances = np.take_along_axis(self.distances, neighbors, axis=1)
54+
55+
self.in_graph = neighbors[:, 1:]
56+
self._neighbors = neighbors[:, 1:]
57+
self.core_distances = distances.T[self.min_samples]
58+
self._distances = np.maximum(
59+
distances[:, 1:],
60+
np.maximum(
61+
self.core_distances[:, None],
62+
self.core_distances[self._neighbors],
63+
),
64+
)
65+
66+
return (
67+
distances[:, : self.num_neighbors + 1],
68+
neighbors[:, : self.num_neighbors + 1],
69+
)
70+
71+
def query(self, point_components):
72+
heap_graph, remapped_components = initialize_out_graph(
73+
self.distances,
74+
self.core_distances,
75+
(self._neighbors, self._distances),
76+
point_components,
77+
self.rng_state,
78+
)
79+
precomputed_descent(
80+
self.distances,
81+
self.core_distances,
82+
self.in_graph,
83+
heap_graph,
84+
remapped_components,
85+
min(60, self.descent_neighbors),
86+
3 * self.n_iters,
87+
self.delta,
88+
self.rng_state,
89+
self.n_threads,
90+
)
91+
92+
self._neighbors, self._distances = deheap_sort(*heap_graph[:2])
93+
return self._distances, self._neighbors
94+
95+
def correction(self, distances):
96+
return distances
97+
98+
99+
@numba.njit(
    parallel=True,
    locals={
        "cnt": numba.int32,
        "idx": numba.int32,
        "size": numba.int32,
        "k": numba.int32,
        "d": numba.float32,
    },
)
def initialize_out_graph(distances, core_distances, graph, point_components, rng_state):
    """Build the initial out-graph heap with neighbors from other components.

    Neighbors in ``graph`` that belong to another component are kept; the
    remaining slots are filled with randomly drawn points from other
    components.

    Parameters
    ----------
    distances : 2D array
        Precomputed pairwise distances.
    core_distances : 1D array
        Core distance per point.
    graph : tuple (indices, distances)
        Current neighbor graph; negative indices are skipped.
    point_components : 1D array
        Component label per point.
    rng_state : array
        State passed to ``tau_rand_int``.

    Returns
    -------
    heap, labels
        The filled heap (all flags set to 1) and the remapped component
        labels returned by ``_group_indices_per_component``.
    """
    members, labels = _group_indices_per_component(point_components)
    n_points = distances.shape[0]
    n_slots = graph[0].shape[1]
    max_attempts = 2 * n_slots
    heap = make_heap(n_points, n_slots)

    for i in numba.prange(n_points):
        # Keep existing neighbors that already lie in another component.
        cnt = 0
        for j, d in zip(graph[0][i], graph[1][i]):
            if j >= 0 and labels[i] != labels[j]:
                simple_heap_push(heap[1][i], heap[0][i], d, j)
                cnt += 1

        # Top up with random points outside i's component; the attempts are
        # bounded because a draw may repeat a point that is already present.
        n_outside = n_points - len(members[labels[i]])
        attempts = 0
        while cnt < n_slots and attempts < max_attempts:
            attempts += 1

            # Draw a rank in [0, n_outside) ...
            idx = np.abs(tau_rand_int(rng_state)) % n_outside

            # ... and walk the other components to turn it into a point index.
            for k, indices in enumerate(members):
                if k == labels[i]:
                    continue
                size = np.int32(len(indices))
                if idx < size:
                    idx = indices[idx]
                    break
                idx -= size

            # The weight is the largest of the raw distance and both core distances.
            d = max(distances[idx, i], core_distances[idx], core_distances[i])
            cnt += checked_heap_push(heap[1][i], heap[0][i], d, idx)

    # Set every flag in the heap to 1.
    heap[2][:] = np.uint8(1)
    return heap, labels
162+
163+
164+
@numba.njit(cache=True)
def precomputed_descent(
    distances,
    core_distances,
    in_graph,
    out_graph,
    point_components,
    max_candidates,
    n_iters,
    delta,
    rng_state,
    n_threads,
):
    """Run the NN Descent variant that looks for neighbors in other components.

    Two graphs are involved:
    - the in-graph holds the ordinary nearest neighbors and stays fixed;
    - the out-graph is updated to hold the nearest neighbors that lie in
      other components.

    Each iteration samples neighbors in the out-graph (both directions) and
    compares their in-graph neighbors to find nearer neighbors in other
    components. This is closer to the originally described algorithm than to
    the local-join variant.

    Iteration stops after ``n_iters`` rounds, or earlier once a round applies
    at most ``delta * in_graph.shape[1] * distances.shape[0]`` updates.
    Nothing is returned; the result is left in ``out_graph``.
    """
    # The early-stop threshold does not change between rounds.
    min_updates = delta * in_graph.shape[1] * distances.shape[0]

    iteration = 0
    while iteration < n_iters:
        iteration += 1

        # Sample (undirected) candidates from the out-graph first, then the
        # in-graph; the order is kept because both draw from rng_state.
        sampled_out = _sample_out_candidates(
            out_graph, max_candidates, rng_state, n_threads
        )
        sampled_in = _sample_in_candidates(
            in_graph, max_candidates, rng_state, n_threads
        )

        # Propose updates from the two candidate sets, then apply them.
        proposals = _generate_graph_updates(
            distances,
            core_distances,
            point_components,
            out_graph[1][:, 0],
            sampled_in,
            sampled_out,
        )
        n_changed = _apply_graph_updates(out_graph, proposals, n_threads)

        if n_changed <= min_updates:
            break
210+
211+
212+
@numba.njit(parallel=True, cache=True)
def _generate_graph_updates(
    distances,
    core_distances,
    point_components,
    dist_thresholds,
    in_neighbors,
    out_neighbors,
):
    """Collect candidate out-graph updates for every vertex.

    For each vertex, the in-graph neighbors of its sampled out-graph neighbors
    are considered. A candidate in a different component is proposed when its
    distance (the largest of the raw distance and both core distances) does
    not exceed the larger of the two vertices' thresholds.

    Returns
    -------
    list of lists of (candidate_idx, distance) tuples, one list per vertex.
    """
    n_vertices = out_neighbors.shape[0]
    # One empty list per vertex; the range(0) comprehension presumably lets
    # numba infer the tuple element type of the otherwise empty lists.
    updates = [[(-1, np.inf) for _ in range(0)] for _ in range(n_vertices)]

    for current_idx in numba.prange(n_vertices):
        own_component = point_components[current_idx]
        own_core = core_distances[current_idx]

        for neighbor_idx in out_neighbors[current_idx]:
            if neighbor_idx < 0:
                continue

            for candidate_idx in in_neighbors[neighbor_idx]:
                if candidate_idx < 0:
                    continue
                # Components must be checked here: Descent may run on more
                # neighbors than the MST accepts, so the in-graph may contain
                # neighbors that are not yet connected.
                if point_components[candidate_idx] == own_component:
                    continue

                d = max(
                    distances[current_idx, candidate_idx],
                    core_distances[candidate_idx],
                    own_core,
                )
                limit = max(
                    dist_thresholds[current_idx],
                    dist_thresholds[candidate_idx],
                )
                if d <= limit:
                    updates[current_idx].append((candidate_idx, d))

    return updates

0 commit comments

Comments
 (0)