Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.annotations.Experimental;
import io.github.jbellis.jvector.annotations.VisibleForTesting;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.ImmutableGraphIndex.NodeAtLevel;
Expand Down Expand Up @@ -325,6 +326,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
this.parallelExecutor = parallelExecutor;

this.graph = new OnHeapGraphIndex(maxDegrees, neighborOverflow, new VamanaDiversityProvider(scoreProvider, alpha));

this.searchers = ExplicitThreadLocal.withInitial(() -> {
var gs = new GraphSearcher(graph);
gs.usePruning(false);
Expand All @@ -338,6 +340,58 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
this.rng = new Random(0);
}

/**
 * Create this builder from an existing {@link io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex}. This is useful when we just loaded a graph from disk,
 * copied it into an {@link OnHeapGraphIndex}, and want to start mutating it with minimal overhead of recreating the mutable {@link OnHeapGraphIndex} used in the new GraphIndexBuilder object.
 *
 * @param buildScoreProvider the provider responsible for calculating build scores.
 * @param dimension the dimensionality of the vectors indexed by the graph.
 * @param mutableGraphIndex a mutable graph index.
 * @param beamWidth the width of the beam used during the graph building process.
 * @param neighborOverflow the factor determining how many additional neighbors are allowed beyond the configured limit.
 * @param alpha the weight factor for balancing score computations.
 * @param addHierarchy whether to add hierarchical structures while building the graph.
 * @param refineFinalGraph whether to perform a refinement step on the final graph structure.
 * @param simdExecutor the ForkJoinPool executor used for SIMD tasks during graph building.
 * @param parallelExecutor the ForkJoinPool executor used for general parallelization during graph building.
 *
 * @throws IllegalArgumentException if beamWidth, alpha, or dimension is non-positive, or neighborOverflow is less than 1.0.
 */
private GraphIndexBuilder(BuildScoreProvider buildScoreProvider, int dimension, MutableGraphIndex mutableGraphIndex, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy, boolean refineFinalGraph, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) {
    if (beamWidth <= 0) {
        throw new IllegalArgumentException("beamWidth must be positive");
    }
    if (neighborOverflow < 1.0f) {
        throw new IllegalArgumentException("neighborOverflow must be >= 1.0");
    }
    if (alpha <= 0) {
        throw new IllegalArgumentException("alpha must be positive");
    }
    // validated here for consistency with the other parameter checks above
    if (dimension <= 0) {
        throw new IllegalArgumentException("dimension must be positive");
    }

    this.scoreProvider = buildScoreProvider;
    this.neighborOverflow = neighborOverflow;
    this.dimension = dimension;
    this.alpha = alpha;
    this.addHierarchy = addHierarchy;
    this.refineFinalGraph = refineFinalGraph;
    this.beamWidth = beamWidth;
    this.simdExecutor = simdExecutor;
    this.parallelExecutor = parallelExecutor;

    // adopt the already-populated graph rather than building a fresh OnHeapGraphIndex
    this.graph = mutableGraphIndex;

    // per-thread searchers; pruning is disabled during construction
    this.searchers = ExplicitThreadLocal.withInitial(() -> {
        var gs = new GraphSearcher(graph);
        gs.usePruning(false);
        return gs;
    });

    // in scratch, we store candidates in reverse order: worse candidates are first
    this.naturalScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1)));
    this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1)));

    // fixed seed for deterministic builds
    this.rng = new Random(0);
}

// used by Cassandra when it fine-tunes the PQ codebook
public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvider newProvider) {
var newBuilder = new GraphIndexBuilder(newProvider,
Expand Down Expand Up @@ -450,13 +504,13 @@ public void cleanup() {
// clean up overflowed neighbor lists
parallelExecutor.submit(() -> {
IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(id -> {
for (int layer = 0; layer <= graph.getMaxLevel(); layer++) {
for (int level = 0; level <= graph.getMaxLevel(); level++) {
graph.enforceDegree(id);
}
});
}).join();

graph.allMutationsCompleted();
graph.setAllMutationsCompleted();
}

private void improveConnections(int node) {
Expand Down Expand Up @@ -825,6 +879,9 @@ public void load(RandomAccessReader in) throws IOException {
loadV3(in, size);
} else {
version = in.readInt();
if (version != 4) {
throw new IOException("Unsupported version: " + version);
}
loadV4(in);
}
}
Expand All @@ -836,15 +893,18 @@ private void loadV4(RandomAccessReader in) throws IOException {
}

int layerCount = in.readInt();
int entryNode = in.readInt();
var layerDegrees = new ArrayList<Integer>(layerCount);
for (int level = 0; level < layerCount; level++) {
layerDegrees.add(in.readInt());
}

int entryNode = in.readInt();

Map<Integer, Integer> nodeLevelMap = new HashMap<>();

// Read layer info
for (int level = 0; level < layerCount; level++) {
int layerSize = in.readInt();
layerDegrees.add(in.readInt());
for (int i = 0; i < layerSize; i++) {
int nodeId = in.readInt();
int nNeighbors = in.readInt();
Expand All @@ -860,6 +920,7 @@ private void loadV4(RandomAccessReader in) throws IOException {
var ca = new NodeArray(nNeighbors);
for (int j = 0; j < nNeighbors; j++) {
int neighbor = in.readInt();
float score = in.readFloat();
ca.addInOrder(neighbor, sf.similarityTo(neighbor));
}
graph.connectNode(level, nodeId, ca);
Expand Down Expand Up @@ -909,4 +970,61 @@ private void loadV3(RandomAccessReader in, int size) throws IOException {
graph.updateEntryNode(new NodeAtLevel(0, entryNode));
graph.setDegrees(List.of(maxDegree));
}

/**
 * Convenience method to build a new graph from an existing one, with the addition of new nodes.
 * This is useful when we want to merge a new set of vectors into an existing graph that is already on disk.
 *
 * @param in a reader from which to read the on-heap graph.
 * @param newVectors a super set RAVV containing the new vectors to be added to the graph as well as the old ones that are already in the graph
 * @param buildScoreProvider the provider responsible for calculating build scores.
 * @param startingNodeOffset the offset in the newVectors RAVV where the new vectors start
 * @param graphToRavvOrdMap a mapping from the old graph's node ids to the newVectors RAVV node ids
 * @param beamWidth the width of the beam used during the graph building process.
 * @param overflowRatio the ratio of extra neighbors to allow temporarily when inserting a node.
 * @param alpha the weight factor for balancing score computations.
 * @param addHierarchy whether to add hierarchical structures while building the graph.
 *
 * @return the in-memory representation of the graph index.
 * @throws IOException if an I/O error occurs during the graph loading or conversion process.
 */
@Experimental
public static ImmutableGraphIndex buildAndMergeNewNodes(RandomAccessReader in,
                                                        RandomAccessVectorValues newVectors,
                                                        BuildScoreProvider buildScoreProvider,
                                                        int startingNodeOffset,
                                                        int[] graphToRavvOrdMap,
                                                        int beamWidth,
                                                        float overflowRatio,
                                                        float alpha,
                                                        boolean addHierarchy) throws IOException {

    var diversityProvider = new VamanaDiversityProvider(buildScoreProvider, alpha);

    // try-with-resources guarantees the loaded graph's resources are released even if the build fails
    // (note: no trailing semicolon in the resource list)
    try (MutableGraphIndex graph = OnHeapGraphIndex.load(in, overflowRatio, diversityProvider)) {

        GraphIndexBuilder builder = new GraphIndexBuilder(
                buildScoreProvider,
                newVectors.dimension(),
                graph,
                beamWidth,
                overflowRatio,
                alpha,
                addHierarchy,
                true, // refineFinalGraph: always refine after merging new nodes
                PhysicalCoreExecutor.pool(),
                ForkJoinPool.commonPool()
        );

        var vv = newVectors.threadLocalSupplier();

        // parallel graph construction from the merge documents Ids; only ordinals at/after
        // startingNodeOffset are new — earlier ordinals are already present in the loaded graph
        PhysicalCoreExecutor.pool().submit(() -> IntStream.range(startingNodeOffset, newVectors.size()).parallel().forEach(ord -> {
            builder.addGraphNode(ord, vv.get().getVector(graphToRavvOrdMap[ord]));
        })).join();

        builder.cleanup();
        return builder.getGraph();
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -166,5 +166,10 @@ interface MutableGraphIndex extends ImmutableGraphIndex {
* Signals that all mutations have been completed and the graph will not be mutated any further.
* Should be called by the builder after all mutations are completed (during cleanup).
*/
void allMutationsCompleted();
void setAllMutationsCompleted();

/**
* Returns true if all mutations have been completed. This is signaled by calling setAllMutationsCompleted.
*/
boolean allMutationsCompleted();
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.annotations.Experimental;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.ConcurrentNeighborMap.Neighbors;
import io.github.jbellis.jvector.graph.diversity.DiversityProvider;
import io.github.jbellis.jvector.util.Accountable;
Expand All @@ -37,9 +39,10 @@

import java.io.DataOutput;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicIntegerArray;
Expand Down Expand Up @@ -367,10 +370,14 @@ public void setDegrees(List<Integer> layerDegrees) {
}

@Override
public void allMutationsCompleted() {
public void setAllMutationsCompleted() {
allMutationsCompleted = true;
}

@Override
public boolean allMutationsCompleted() {
return allMutationsCompleted;
}

/**
* A concurrent View of the graph that is safe to search concurrently with updates and with other
Expand Down Expand Up @@ -490,44 +497,101 @@ public String toString() {
/**
* Saves the graph to the given DataOutput for reloading into memory later
*/
@Experimental
@Deprecated
public void save(DataOutput out) {
if (deletedNodes.cardinality() > 0) {
throw new IllegalStateException("Cannot save a graph that has deleted nodes. Call cleanup() first");
}

try (var view = getView()) {
out.writeInt(OnHeapGraphIndex.MAGIC); // the magic number
out.writeInt(4); // The version

// Write graph-level properties.
out.writeInt(layers.size());
assert view.entryNode().level == getMaxLevel();
out.writeInt(view.entryNode().node);

for (int level = 0; level < layers.size(); level++) {
out.writeInt(size(level));
out.writeInt(getDegree(level));

// Save neighbors from the layer.
var baseLayer = layers.get(level);
baseLayer.forEach((nodeId, neighbors) -> {
try {
NodesIterator iterator = neighbors.iterator();
out.writeInt(nodeId);
out.writeInt(iterator.size());
for (int n = 0; n < iterator.size(); n++) {
out.writeInt(iterator.nextInt());
}
assert !iterator.hasNext();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
});
public void save(DataOutput out) throws IOException {
    // Refuse to serialize a graph that may still change underneath us; the builder
    // signals completion via setAllMutationsCompleted() during cleanup().
    if (!allMutationsCompleted()) {
        throw new IllegalStateException("Cannot save a graph with pending mutations. Call cleanup() first");
    }

    out.writeInt(OnHeapGraphIndex.MAGIC); // the magic number
    out.writeInt(4); // The version

    // Write graph-level properties.
    // Header layout (must mirror load()): layer count, per-layer degree, entry node id.
    out.writeInt(layers.size());
    for (int level = 0; level < layers.size(); level++) {
        out.writeInt(getDegree(level));
    }

    var entryNode = entryPoint.get();
    // the entry point is expected to live on the topmost layer
    assert entryNode.level == getMaxLevel();
    out.writeInt(entryNode.node);

    // Per layer: node count, then for each node its id, neighbor count,
    // and each neighbor as (node id, similarity score) pairs.
    for (int level = 0; level < layers.size(); level++) {
        out.writeInt(size(level));

        // Save neighbors from the layer.
        var it = nodeStream(level).iterator();
        while (it.hasNext()) {
            int nodeId = it.nextInt();
            var neighbors = layers.get(level).get(nodeId);
            out.writeInt(nodeId);
            out.writeInt(neighbors.size());

            // persisting scores lets load() rebuild NodeArrays without recomputing similarities
            for (int n = 0; n < neighbors.size(); n++) {
                out.writeInt(neighbors.getNode(n));
                out.writeFloat(neighbors.getScore(n));
            }
        }
    }
}

/**
* Saves the graph to the given DataOutput for reloading into memory later
*/
@Experimental
@Deprecated
public static OnHeapGraphIndex load(RandomAccessReader in, double overflowRatio, DiversityProvider diversityProvider) throws IOException {
int magic = in.readInt(); // the magic number
if (magic != OnHeapGraphIndex.MAGIC) {
throw new IOException("Unsupported magic number: " + magic);
}

int version = in.readInt(); // The version
if (version != 4) {
throw new IOException("Unsupported version: " + version);
}

// Write graph-level properties.
int layerCount = in.readInt();
var layerDegrees = new ArrayList<Integer>(layerCount);
for (int level = 0; level < layerCount; level++) {
layerDegrees.add(in.readInt());
}

int entryNode = in.readInt();

var graph = new OnHeapGraphIndex(layerDegrees, overflowRatio, diversityProvider);

Map<Integer, Integer> nodeLevelMap = new HashMap<>();

for (int level = 0; level < layerCount; level++) {
int layerSize = in.readInt();

for (int i = 0; i < layerSize; i++) {
int nodeId = in.readInt();
int nNeighbors = in.readInt();

var ca = new NodeArray(nNeighbors);
for (int j = 0; j < nNeighbors; j++) {
int neighbor = in.readInt();
float score = in.readFloat();
ca.addInOrder(neighbor, score);
}
graph.connectNode(level, nodeId, ca);
nodeLevelMap.put(nodeId, level);
}
} catch (IOException e) {
throw new UncheckedIOException(e);
}

for (var k : nodeLevelMap.keySet()) {
NodeAtLevel nal = new NodeAtLevel(nodeLevelMap.get(k), k);
graph.markComplete(nal);
}

graph.setDegrees(layerDegrees);
graph.updateEntryNode(new NodeAtLevel(graph.getMaxLevel(), entryNode));

return graph;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,6 @@ public Set<FeatureId> getFeatureSet() {
return features.keySet();
}

public int getDimension() {
Comment on lines 228 to -230
Copy link
Member

@michaeljmarshall michaeljmarshall Oct 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are using this in C*, why did this get removed?

Copy link
Member

@michaeljmarshall michaeljmarshall Oct 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I submitted a PR proposing we re-add it. #550

return dimension;
}

@Override
public int size(int level) {
return layerInfo.get(level).size;
Expand Down
Loading