Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ New Features

Improvements
---------------------
# GITHUB#15148: Add support uint8 distance and allow 8 bit scalar quantization (Trevor McCulloch)
* GITHUB#15148: Add support uint8 distance and allow 8 bit scalar quantization (Trevor McCulloch)

* GITHUB#15184: Refactor internal HnswGraphBuilder APIs and avoid creating a new scorer for each call (Patrick Zhai)

Optimizations
---------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.packed.DirectMonotonicWriter;

/**
Expand Down Expand Up @@ -586,7 +585,6 @@ private static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
private int lastDocID = -1;
private int node = 0;
private final FlatFieldVectorsWriter<T> flatFieldVectorsWriter;
private UpdateableRandomVectorScorer scorer;

@SuppressWarnings("unchecked")
static FieldWriter<?> create(
Expand Down Expand Up @@ -642,7 +640,6 @@ static FieldWriter<?> create(
(List<float[]>) flatFieldVectorsWriter.getVectors(),
fieldInfo.getVectorDimension()));
};
this.scorer = scorerSupplier.scorer();
hnswGraphBuilder =
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(infoStream);
Expand All @@ -658,8 +655,7 @@ public void addValue(int docID, T vectorValue) throws IOException {
+ "\" appears more than once in this document (only one value is allowed per field)");
}
flatFieldVectorsWriter.addValue(docID, vectorValue);
scorer.setScoringOrdinal(node);
hnswGraphBuilder.addGraphNode(node, scorer);
hnswGraphBuilder.addGraphNode(node);
node++;
lastDocID = docID;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.lucene.util.hnsw;

import java.io.IOException;
import org.apache.lucene.internal.hppc.IntHashSet;
import org.apache.lucene.util.InfoStream;

/**
Expand All @@ -34,16 +35,21 @@ public interface HnswBuilder {
*/
OnHeapHnswGraph build(int maxOrd) throws IOException;

/** Inserts a doc with vector value to the graph */
/** Inserts a doc with a vector value to the graph */
void addGraphNode(int node) throws IOException;

/**
* Inserts a doc with a vector value to the graph, searching on level 0 with provided entry points
*/
void addGraphNode(int node, IntHashSet eps) throws IOException;

/** Set info-stream to output debugging information */
void setInfoStream(InfoStream infoStream);

OnHeapHnswGraph getGraph();

/**
* Once this method is called no further updates to the graph are accepted (addGraphNode will
* Once this method is called, no further updates to the graph are accepted (addGraphNode will
* throw IllegalStateException). Final modifications to the graph (eg patching up disconnected
* components, re-ordering node ids for better delta compression) may be triggered, so callers
* should expect this call to take some time.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import org.apache.lucene.internal.hppc.IntHashSet;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.FixedBitSet;
Expand Down Expand Up @@ -98,6 +99,11 @@ public void addGraphNode(int node) throws IOException {
throw new UnsupportedOperationException("This builder is for merge only");
}

@Override
public void addGraphNode(int node, IntHashSet eps) throws IOException {
throw new UnsupportedOperationException("This builder is for merge only");
}

@Override
public void setInfoStream(InfoStream infoStream) {
this.infoStream = infoStream;
Expand Down Expand Up @@ -142,7 +148,6 @@ private static final class ConcurrentMergeWorker extends HnswGraphBuilder {

private final BitSet initializedNodes;
private int batchSize = DEFAULT_BATCH_SIZE;
private final UpdateableRandomVectorScorer scorer;

private ConcurrentMergeWorker(
RandomVectorScorerSupplier scorerSupplier,
Expand All @@ -163,7 +168,6 @@ private ConcurrentMergeWorker(
new NeighborQueue(beamWidth, true), hnswLock, new FixedBitSet(hnsw.maxNodeId() + 1)));
this.workProgress = workProgress;
this.initializedNodes = initializedNodes;
this.scorer = scorerSupplier.scorer();
}

/**
Expand Down Expand Up @@ -192,21 +196,12 @@ private int getStartPos(int maxOrd) {
}
}

@Override
public void addGraphNode(int node, UpdateableRandomVectorScorer scorer) throws IOException {
if (initializedNodes != null && initializedNodes.get(node)) {
return;
}
super.addGraphNode(node, scorer);
}

@Override
public void addGraphNode(int node) throws IOException {
if (initializedNodes != null && initializedNodes.get(node)) {
return;
}
scorer.setScoringOrdinal(node);
addGraphNode(node, scorer);
super.addGraphNode(node);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
/**
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
* hyper-parameters.
*
* <p>Thread-safety: This class is NOT thread-safe and cannot be shared across threads. However,
* it IS safe for multiple HnswGraphBuilders to build the same graph, provided the graph's size
* is known at the beginning (as when doing a merge)
*/
public class HnswGraphBuilder implements HnswBuilder {

Expand All @@ -64,7 +68,7 @@ public class HnswGraphBuilder implements HnswBuilder {
private final double ml;

private final SplittableRandom random;
protected final RandomVectorScorerSupplier scorerSupplier;
private final UpdateableRandomVectorScorer scorer;
private final HnswGraphSearcher graphSearcher;
private final GraphBuilderKnnCollector entryCandidates; // for upper levels of graph search
private final GraphBuilderKnnCollector
Expand Down Expand Up @@ -144,8 +148,8 @@ protected HnswGraphBuilder(
throw new IllegalArgumentException("beamWidth must be positive");
}
this.M = hnsw.maxConn();
this.scorerSupplier =
Objects.requireNonNull(scorerSupplier, "scorer supplier must not be null");
this.scorer =
Objects.requireNonNull(scorerSupplier, "scorer supplier must not be null").scorer();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's change the error message to say "scorer must not be null"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's actually still asserting on scorerSupplier (which is passed in to ctor)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it, thanks. Sorry it took me a bit to respond I was basking on the shore of a great lake and deliberately cut off ...

// normalization factor for level generation; currently not configurable
this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
this.random = new SplittableRandom(seed);
Expand Down Expand Up @@ -196,10 +200,8 @@ protected void addVectors(int minOrd, int maxOrd) throws IOException {
if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message(HNSW_COMPONENT, "addVectors [" + minOrd + " " + maxOrd + ")");
}
UpdateableRandomVectorScorer scorer = scorerSupplier.scorer();
for (int node = minOrd; node < maxOrd; node++) {
scorer.setScoringOrdinal(node);
addGraphNode(node, scorer);
addGraphNode(node);
if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
t = printGraphBuildStatus(node, start, t);
}
Expand All @@ -210,10 +212,24 @@ private void addVectors(int maxOrd) throws IOException {
addVectors(0, maxOrd);
}

public void addGraphNode(int node, UpdateableRandomVectorScorer scorer) throws IOException {
addGraphNodeInternal(node, scorer, null);
}

/**
* Note: this implementation is thread safe when the graph size is fixed (e.g. when merging). The
* process of adding a node is roughly: 1. Add the node to all levels from top to bottom, but do
* not connect it to any other node, nor try to promote it to an entry node before the
* connections are done. (Unless the graph is empty and this is the first node; in that case we
* set the entry node and return.) 2. Search from top to bottom, remembering all the possible
* neighbors on each level the node is on. 3. Add the neighbors to the node from the bottom level
* up to the top. When adding a neighbor, we always add all the outgoing links first before
* adding an incoming link, such that when a search visits this node it can always find a way
* out. 4. If the node's level is less than or equal to the graph's max level, then we're done
* here. If the node's level is larger than the graph's max level, then we need to promote the
* node to be the entry node. If, while we add the node to the graph, the entry node has changed
* (which means the graph's max level has changed as well), we need to reinsert the node into the
* newly introduced levels (repeating steps 2 and 3 for the new levels) and again try to promote
* the node to entry node.
*
* @param eps0 If specified, we will use it as the entry points of the search on level 0, which
*     is useful when you have some prior knowledge, e.g. in {@link MergingHnswGraphBuilder}
*/
private void addGraphNodeInternal(int node, UpdateableRandomVectorScorer scorer, IntHashSet eps0)
throws IOException {
if (frozen) {
Expand All @@ -224,7 +240,8 @@ private void addGraphNodeInternal(int node, UpdateableRandomVectorScorer scorer,
for (int level = nodeLevel; level >= 0; level--) {
hnsw.addNode(level, node);
}
// then promote itself as entry node if entry node is not set
// then promote itself to the entry node if the entry node is not set (this is the first ever
// node of the graph)
if (hnsw.trySetNewEntryNode(node, nodeLevel)) {
return;
}
Expand All @@ -235,8 +252,12 @@ private void addGraphNodeInternal(int node, UpdateableRandomVectorScorer scorer,
int curMaxLevel;
do {
curMaxLevel = hnsw.numLevels() - 1;
// NOTE: the entry node and max level may not be paired, but because we get the level first
// NOTE: the entry node and max level are not retrieved atomically, so the entry node's level
// may differ from the graph's max level. But because we read the level first, the entry node
// we read afterwards is guaranteed to exist on curMaxLevel,
// i.e., curMaxLevel <= entryNode.level
int[] eps = new int[] {hnsw.entryNode()};

// we first do the search from top to bottom
Expand Down Expand Up @@ -271,15 +292,21 @@ private void addGraphNodeInternal(int node, UpdateableRandomVectorScorer scorer,
}
lowestUnsetLevel += scratchPerLevel.length;
assert lowestUnsetLevel == Math.min(nodeLevel, curMaxLevel) + 1;
if (lowestUnsetLevel > nodeLevel) {
if (lowestUnsetLevel == nodeLevel + 1) {
// we have already set all the levels we need for this node
return;
}
assert lowestUnsetLevel == curMaxLevel + 1 && nodeLevel > curMaxLevel;
// The node's level is higher than the graph's max level, so we need to
// try to promote this node as the graph's entry node
if (hnsw.tryPromoteNewEntryNode(node, nodeLevel, curMaxLevel)) {
return;
}
// If we're not able to promote, it means the graph must have already changed
// and has a new max level and some other entry node
if (hnsw.numLevels() == curMaxLevel + 1) {
// This should never happen if all the calculations are correct
// This is an impossible situation: if it happens, then one of the
// invariants above does not hold
throw new IllegalStateException(
"We're not able to promote node "
+ node
Expand All @@ -294,31 +321,12 @@ private void addGraphNodeInternal(int node, UpdateableRandomVectorScorer scorer,

@Override
public void addGraphNode(int node) throws IOException {
/*
* Note: this implementation is thread safe when graph size is fixed (e.g. when merging)
* The process of adding a node is roughly:
* 1. Add the node to all level from top to the bottom, but do not connect it to any other node,
* nor try to promote itself to an entry node before the connection is done. (Unless the graph is empty
* and this is the first node, in that case we set the entry node and return)
* 2. Do the search from top to bottom, remember all the possible neighbours on each level the node
* is on.
* 3. Add the neighbor to the node from bottom to top level, when adding the neighbour,
* we always add all the outgoing links first before adding incoming link such that
* when a search visits this node, it can always find a way out
* 4. If the node has level that is less or equal to graph level, then we're done here.
* If the node has level larger than graph level, then we need to promote the node
* as the entry node. If, while we add the node to the graph, the entry node has changed
* (which means the graph level has changed as well), we need to reinsert the node
* to the newly introduced levels (repeating step 2,3 for new levels) and again try to
* promote the node to entry node.
*/
UpdateableRandomVectorScorer scorer = scorerSupplier.scorer();
scorer.setScoringOrdinal(node);
addGraphNodeInternal(node, scorer, null);
}

public void addGraphNodeWithEps(int node, IntHashSet eps0) throws IOException {
UpdateableRandomVectorScorer scorer = scorerSupplier.scorer();
@Override
public void addGraphNode(int node, IntHashSet eps0) throws IOException {
scorer.setScoringOrdinal(node);
addGraphNodeInternal(node, scorer, eps0);
}
Expand Down Expand Up @@ -486,7 +494,6 @@ private boolean connectComponents(int level) throws IOException {
// while linking
GraphBuilderKnnCollector beam = new GraphBuilderKnnCollector(2);
int[] eps = new int[1];
UpdateableRandomVectorScorer scorer = scorerSupplier.scorer();
for (Component c : components) {
if (c != c0) {
if (c.start() == NO_MORE_DOCS) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,6 @@ public InitializedHnswGraphBuilder(
this.initializedNodes = initializedNodes;
}

@Override
public void addGraphNode(int node, UpdateableRandomVectorScorer scorer) throws IOException {
if (initializedNodes.get(node)) {
return;
}
super.addGraphNode(node, scorer);
}

@Override
public void addGraphNode(int node) throws IOException {
if (initializedNodes.get(node)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ private MergingHnswGraphBuilder(
* @param ordMaps the ordinal maps for the graphs
* @param totalNumberOfVectors the total number of vectors in the new graph, this should include
* all vectors expected to be added to the graph in the future
* @param initializedNodes the nodes will be initialized through the merging
* @param initializedNodes the nodes that will be initialized through the merging; if null, all
*     nodes should already be initialized after {@link #updateGraph(HnswGraph, int[])} is called
* @return a new HnswGraphBuilder that is initialized with the provided HnswGraph
* @throws IOException when reading the graph fails
*/
Expand Down Expand Up @@ -172,7 +173,7 @@ private void updateGraph(HnswGraph gS, int[] ordMapS) throws IOException {
}
}
}
addGraphNodeWithEps(ordMapS[u], eps);
addGraphNode(ordMapS[u], eps);
}
}
}
Loading