diff --git a/benchmarks-jmh/pom.xml b/benchmarks-jmh/pom.xml index c82ee2707..78654edcc 100644 --- a/benchmarks-jmh/pom.xml +++ b/benchmarks-jmh/pom.xml @@ -94,6 +94,21 @@ + + org.apache.maven.plugins + maven-javadoc-plugin + + + --add-modules=jdk.incubator.vector + + 22 + false + true + + io.github.jbellis:* + + + \ No newline at end of file diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/IndexConstructionWithRandomSetBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/IndexConstructionWithRandomSetBenchmark.java index ed0402e6d..4558c0559 100644 --- a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/IndexConstructionWithRandomSetBenchmark.java +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/IndexConstructionWithRandomSetBenchmark.java @@ -37,7 +37,23 @@ import io.github.jbellis.jvector.vector.types.VectorTypeSupport; - +/** + * JMH benchmark for measuring graph index construction performance using randomly generated vectors. + * This benchmark evaluates the time required to build a graph index with configurable parameters + * including vector dimensionality, dataset size, and optional Product Quantization (PQ) compression. + * + *

The benchmark tests various configurations to assess how different factors affect index + * construction time, including the impact of using PQ compression during the build process.

+ * + *

+ * <p>Key parameters: vector dimensionality ({@code originalDimension}), dataset size ({@code numBaseVectors}), and PQ subspace count ({@code numberOfPQSubspaces}).

+ * + */ @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.MILLISECONDS) @State(Scope.Thread) @@ -48,17 +64,45 @@ public class IndexConstructionWithRandomSetBenchmark { private static final Logger log = LoggerFactory.getLogger(IndexConstructionWithRandomSetBenchmark.class); private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport(); + + /** The vector values to be indexed, initialized during setup. */ private RandomAccessVectorValues ravv; + + /** The score provider used during graph construction, either exact or PQ-based. */ private BuildScoreProvider buildScoreProvider; - private int M = 32; // graph degree + + /** The maximum degree of the graph (number of neighbors per node). */ + private int M = 32; + + /** The beam width used during graph construction searches. */ private int beamWidth = 100; + + /** The dimensionality of vectors being indexed. */ @Param({"768", "1536"}) private int originalDimension; + + /** The number of vectors in the dataset to be indexed. */ @Param({/*"10000",*/ "100000"/*, "1000000"*/}) int numBaseVectors; + + /** The number of PQ subspaces to use, or 0 for no compression. */ @Param({"0", "16"}) private int numberOfPQSubspaces; + /** + * Constructs a new benchmark instance. JMH will instantiate this class + * and populate the @Param fields before calling setup methods. + */ + public IndexConstructionWithRandomSetBenchmark() { + // JMH-managed lifecycle + } + + /** + * Initializes the benchmark state by generating random vectors and configuring + * the appropriate score provider based on whether PQ compression is enabled. + * + * @throws IOException if an error occurs during setup + */ @Setup(Level.Trial) public void setup() throws IOException { @@ -86,11 +130,25 @@ public void setup() throws IOException { } + /** + * Tears down resources after each benchmark invocation. + * Currently performs no operations but is included for future resource cleanup needs. 
+ * + * @throws IOException if an error occurs during teardown + */ @TearDown(Level.Invocation) public void tearDown() throws IOException { } + /** + * The main benchmark method that measures the time to build a graph index. + * Constructs a complete graph index from the configured vectors using the + * specified parameters and score provider. + * + * @param blackhole JMH blackhole to prevent dead code elimination + * @throws IOException if an error occurs during index construction + */ @Benchmark public void buildIndexBenchmark(Blackhole blackhole) throws IOException { // score provider using the raw, in-memory vectors @@ -100,6 +158,13 @@ public void buildIndexBenchmark(Blackhole blackhole) throws IOException { } } + /** + * Creates a random vector with the specified dimensionality. + * Each component is randomly generated using {@link Math#random()}. + * + * @param dimension the number of dimensions in the vector + * @return a newly created random vector + */ private VectorFloat createRandomVector(int dimension) { VectorFloat vector = VECTOR_TYPE_SUPPORT.createFloatVector(dimension); for (int i = 0; i < dimension; i++) { diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/IndexConstructionWithStaticSetBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/IndexConstructionWithStaticSetBenchmark.java index 5e643a986..4d616652a 100644 --- a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/IndexConstructionWithStaticSetBenchmark.java +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/IndexConstructionWithStaticSetBenchmark.java @@ -31,6 +31,23 @@ import java.util.List; import java.util.concurrent.TimeUnit; +/** + * JMH benchmark for measuring graph index construction performance using the SIFT dataset. + * This benchmark evaluates index construction time with a fixed, real-world dataset, + * testing various combinations of graph degree (M) and beam width parameters. + * + *

Unlike {@link IndexConstructionWithRandomSetBenchmark}, this benchmark uses the + * actual SIFT dataset loaded from disk, providing more realistic performance measurements + * that account for real data characteristics.

+ * + *

+ * <p>Key parameters: graph degree ({@code M}) and beam width ({@code beamWidth}).

+ * + */ @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.MILLISECONDS) @State(Scope.Thread) @@ -40,17 +57,47 @@ @Threads(1) public class IndexConstructionWithStaticSetBenchmark { private static final Logger log = LoggerFactory.getLogger(IndexConstructionWithStaticSetBenchmark.class); + + /** The vector values to be indexed, loaded from the SIFT dataset. */ private RandomAccessVectorValues ravv; + + /** The base vectors from the SIFT dataset. */ private List> baseVectors; + + /** The query vectors from the SIFT dataset (loaded but not used in this benchmark). */ private List> queryVectors; + + /** The ground truth nearest neighbors (loaded but not used in this benchmark). */ private List> groundTruth; + + /** The score provider used during graph construction. */ private BuildScoreProvider bsp; + + /** The maximum degree of the graph (number of neighbors per node). */ @Param({"16", "32", "64"}) - private int M; // graph degree + private int M; + + /** The beam width used during graph construction searches. */ @Param({"10", "100"}) private int beamWidth; + + /** The dimensionality of vectors in the dataset. */ int originalDimension; + /** + * Constructs a new benchmark instance. JMH will instantiate this class + * and populate the @Param fields before calling setup methods. + */ + public IndexConstructionWithStaticSetBenchmark() { + // JMH-managed lifecycle + } + + /** + * Initializes the benchmark state by loading the SIFT dataset from disk + * and configuring the score provider. + * + * @throws IOException if an error occurs loading the dataset files + */ @Setup public void setup() throws IOException { var siftPath = "siftsmall"; @@ -67,6 +114,11 @@ public void setup() throws IOException { bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, VectorSimilarityFunction.EUCLIDEAN); } + /** + * Cleans up resources after the benchmark completes by clearing all vector collections. 
+ * + * @throws IOException if an error occurs during teardown + */ @TearDown public void tearDown() throws IOException { baseVectors.clear(); @@ -74,6 +126,13 @@ public void tearDown() throws IOException { groundTruth.clear(); } + /** + * The main benchmark method that measures the time to build a graph index + * from the loaded SIFT dataset using the configured parameters. + * + * @param blackhole JMH blackhole to prevent dead code elimination + * @throws IOException if an error occurs during index construction + */ @Benchmark public void buildIndexBenchmark(Blackhole blackhole) throws IOException { // score provider using the raw, in-memory vectors diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQDistanceCalculationBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQDistanceCalculationBenchmark.java index 59342e41a..40048d62a 100644 --- a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQDistanceCalculationBenchmark.java +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQDistanceCalculationBenchmark.java @@ -50,28 +50,76 @@ @Threads(1) public class PQDistanceCalculationBenchmark { private static final Logger log = LoggerFactory.getLogger(PQDistanceCalculationBenchmark.class); + + /** + * Creates a new benchmark instance. + *

+ * This constructor is invoked by JMH and should not be called directly. + */ + public PQDistanceCalculationBenchmark() { + } private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport(); private final VectorSimilarityFunction vsf = VectorSimilarityFunction.EUCLIDEAN; + /** The base vectors used for distance calculations. */ private List> vectors; + + /** Product-quantized versions of the base vectors, or null if M=0. */ private PQVectors pqVectors; + + /** Query vectors used to test distance calculations. */ private List> queryVectors; + + /** The Product Quantization model, or null if M=0. */ private ProductQuantization pq; + + /** Score provider configured for either full precision or PQ-based scoring. */ private BuildScoreProvider buildScoreProvider; - + + /** + * The dimensionality of the vectors. + *

+ * Default value: 1536 (typical for modern embedding models). + */ @Param({"1536"}) private int dimension; - + + /** + * The number of base vectors to create for the dataset. + *

+ * Default value: 10000 + */ @Param({"10000"}) private int vectorCount; - + + /** + * The number of query vectors to test against the dataset. + *

+ * Default value: 100 + */ @Param({"100"}) private int queryCount; - + + /** + * The number of subspaces for Product Quantization. + *

+ * When M=0, uses full precision vectors without quantization. + * When M>0, splits each vector into M subspaces for compression. + * Values: 0 (no PQ), 16, 64, 192 + */ @Param({"0", "16", "64", "192"}) - private int M; // Number of subspaces for PQ + private int M; + /** + * Sets up the benchmark by creating random vectors and configuring score providers. + *

+ * This method creates the specified number of base vectors and query vectors with random + * values. If M>0, it also computes Product Quantization and creates PQ-encoded vectors. + * The appropriate score provider is then configured based on whether PQ is used. + * + * @throws IOException if there is an error during setup + */ @Setup public void setup() throws IOException { log.info("Creating dataset with dimension: {}, vector count: {}, query count: {}", dimension, vectorCount, queryCount); @@ -100,6 +148,16 @@ public void setup() throws IOException { log.info("Created dataset with dimension: {}, vector count: {}, query count: {}", dimension, vectorCount, queryCount); } + /** + * Benchmarks distance calculation using cached search score providers. + *

+ * This benchmark measures the performance of calculating distances between query vectors + * and all base vectors using a search score provider that caches precomputed values for + * the query vector. This represents the typical search scenario where a query is compared + * against many candidates. + * + * @param blackhole JMH blackhole to prevent dead code elimination + */ @Benchmark public void cachedDistanceCalculation(Blackhole blackhole) { float totalSimilarity = 0; @@ -115,6 +173,16 @@ public void cachedDistanceCalculation(Blackhole blackhole) { blackhole.consume(totalSimilarity); } + /** + * Benchmarks distance calculation for diversity scoring. + *

+ * This benchmark measures the performance of calculating distances between base vectors + * using diversity score providers. This represents the scenario where vectors in the + * dataset are compared against each other to assess diversity, such as during graph + * construction or result reranking. + * + * @param blackhole JMH blackhole to prevent dead code elimination + */ @Benchmark public void diversityCalculation(Blackhole blackhole) { float totalSimilarity = 0; @@ -130,6 +198,15 @@ public void diversityCalculation(Blackhole blackhole) { blackhole.consume(totalSimilarity); } + /** + * Creates a random vector with the specified dimension. + *

+ * Each component of the vector is assigned a random floating-point value + * between 0.0 (inclusive) and 1.0 (exclusive). + * + * @param dimension the number of dimensions for the vector + * @return a new random vector + */ private VectorFloat createRandomVector(int dimension) { VectorFloat vector = VECTOR_TYPE_SUPPORT.createFloatVector(dimension); for (int i = 0; i < dimension; i++) { diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQTrainingWithRandomVectorsBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQTrainingWithRandomVectorsBenchmark.java index 8e1fa403b..ad350c438 100644 --- a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQTrainingWithRandomVectorsBenchmark.java +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQTrainingWithRandomVectorsBenchmark.java @@ -34,6 +34,21 @@ import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED; +/** + * Benchmark for measuring the performance of Product Quantization training on randomly generated vectors. + *

+ * This benchmark evaluates the time required to compute Product Quantization (PQ) codebooks from + * a dataset of random vectors. PQ training involves clustering vectors in each subspace using k-means, + * which is a computationally intensive operation. The benchmark tests various configurations of + * subspace counts (M) to understand the trade-off between compression ratio and training time. + *

+ * Key aspects measured: + *

    + *
+ * <ul>
+ *   <li>K-means clustering performance across multiple subspaces</li>
+ *   <li>Impact of increasing M (number of subspaces) on training time</li>
+ *   <li>Scalability with dataset size</li>
+ * </ul>
+ */ @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.MILLISECONDS) @State(Scope.Thread) @@ -43,15 +58,54 @@ @Threads(1) public class PQTrainingWithRandomVectorsBenchmark { private static final Logger log = LoggerFactory.getLogger(PQTrainingWithRandomVectorsBenchmark.class); + + /** + * Creates a new benchmark instance. + *

+ * This constructor is invoked by JMH and should not be called directly. + */ + public PQTrainingWithRandomVectorsBenchmark() { + } private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport(); + + /** Random access wrapper for the pre-created vector dataset. */ private RandomAccessVectorValues ravv; + + /** + * The number of subspaces for Product Quantization. + *

+ * Higher values of M provide more accurate quantization but increase training time + * and memory usage. Values: 16, 32, 64 + */ @Param({"16", "32", "64"}) - private int M; // Number of subspaces + private int M; + + /** + * The dimensionality of the vectors. + *

+ * Default value: 768 (common for many embedding models). + */ @Param({"768"}) int originalDimension; + + /** + * The number of vectors in the training dataset. + *

+ * Default value: 100000 + */ @Param({"100000"}) int vectorCount; + /** + * Sets up the benchmark by pre-creating a dataset of random vectors. + *

+ * This method generates the specified number of random vectors with the configured + * dimensionality. The vectors are wrapped in a RandomAccessVectorValues instance + * for use during PQ training. Pre-creating all vectors ensures the benchmark + * measures only the PQ training time, not vector generation. + * + * @throws IOException if there is an error during setup + */ @Setup public void setup() throws IOException { log.info("Pre-creating vector dataset with original dimension: {}, vector count: {}", originalDimension, vectorCount); @@ -69,11 +123,35 @@ public void setup() throws IOException { log.info("Pre-created vector dataset with original dimension: {}, vector count: {}", originalDimension, vectorCount); } + /** + * Tears down the benchmark state. + *

+ * This method is a placeholder for any cleanup operations that may be needed + * in future implementations. + * + * @throws IOException if there is an error during teardown + * @throws InterruptedException if the thread is interrupted during teardown + */ @TearDown public void tearDown() throws IOException, InterruptedException { } + /** + * Benchmarks the computation of Product Quantization codebooks. + *

+ * This benchmark measures the time required to train a Product Quantization model + * on the pre-created vector dataset. The training process involves: + *

    + *
+ * <ul>
+ *   <li>Splitting each vector into M subspaces</li>
+ *   <li>Running k-means clustering (256 centroids) in each subspace</li>
+ *   <li>Centering the dataset to improve quantization accuracy</li>
+ * </ul>
+ * The resulting PQ model provides a compression ratio based on M and the original dimension. + * + * @param blackhole JMH blackhole to prevent dead code elimination + * @throws IOException if there is an error during PQ computation + */ @Benchmark public void productQuantizationComputeBenchmark(Blackhole blackhole) throws IOException { // Compress the original vectors using PQ. this represents a compression ratio of 128 * 4 / 16 = 32x diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQTrainingWithSiftBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQTrainingWithSiftBenchmark.java index 14eea932c..c99185ec2 100644 --- a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQTrainingWithSiftBenchmark.java +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQTrainingWithSiftBenchmark.java @@ -29,6 +29,28 @@ import java.util.List; import java.util.concurrent.TimeUnit; +/** + * Benchmark for measuring the performance of Product Quantization training on the SIFT dataset. + *

+ * This benchmark evaluates the time required to compute Product Quantization (PQ) codebooks from + * the SIFT Small dataset, which consists of real-world image feature vectors. Unlike random vectors, + * SIFT vectors have realistic distributions and correlations, making this benchmark more representative + * of actual production workloads. + *

+ * The SIFT Small dataset contains: + *

    + *
+ * <ul>
+ *   <li>10,000 base vectors (128-dimensional)</li>
+ *   <li>100 query vectors</li>
+ *   <li>Ground truth nearest neighbors for evaluation</li>
+ * </ul>
+ * <p>
+ * Key aspects measured: + *

    + *
+ * <ul>
+ *   <li>PQ training performance on real-world data with natural clustering</li>
+ *   <li>Impact of different M values on training time with realistic vectors</li>
+ *   <li>Comparison with random vector training to understand data distribution effects</li>
+ * </ul>
+ */ @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.MILLISECONDS) @State(Scope.Thread) @@ -38,14 +60,48 @@ @Threads(1) public class PQTrainingWithSiftBenchmark { private static final Logger log = LoggerFactory.getLogger(PQTrainingWithSiftBenchmark.class); + + /** + * Creates a new benchmark instance. + *

+ * This constructor is invoked by JMH and should not be called directly. + */ + public PQTrainingWithSiftBenchmark() { + } + + /** Random access wrapper for the SIFT base vectors. */ private RandomAccessVectorValues ravv; + + /** The SIFT base vectors used for training. */ private List> baseVectors; + + /** The SIFT query vectors (loaded but not used in this benchmark). */ private List> queryVectors; + + /** Ground truth nearest neighbors (loaded but not used in this benchmark). */ private List> groundTruth; + + /** + * The number of subspaces for Product Quantization. + *

+ * Higher values of M provide more accurate quantization but increase training time. + * Values: 16, 32, 64 + */ @Param({"16", "32", "64"}) - private int M; // Number of subspaces + private int M; + + /** The dimensionality of the SIFT vectors (128). */ int originalDimension; + /** + * Sets up the benchmark by loading the SIFT Small dataset. + *

+ * This method loads the SIFT base vectors, query vectors, and ground truth from the + * local filesystem. The base vectors are wrapped in a RandomAccessVectorValues instance + * for use during PQ training. + * + * @throws IOException if there is an error loading the SIFT dataset files + */ @Setup public void setup() throws IOException { var siftPath = "siftsmall"; @@ -59,6 +115,13 @@ public void setup() throws IOException { ravv = new ListRandomAccessVectorValues(baseVectors, originalDimension); } + /** + * Tears down the benchmark by clearing the loaded vectors. + *

+ * This method releases memory by clearing all loaded vectors and ground truth data. + * + * @throws IOException if there is an error during teardown + */ @TearDown public void tearDown() throws IOException { baseVectors.clear(); @@ -66,6 +129,21 @@ public void tearDown() throws IOException { groundTruth.clear(); } + /** + * Benchmarks the computation of Product Quantization codebooks on SIFT vectors. + *

+ * This benchmark measures the time required to train a Product Quantization model + * on the SIFT Small dataset. The training process involves: + *

    + *
+ * <ul>
+ *   <li>Splitting each 128-dimensional SIFT vector into M subspaces</li>
+ *   <li>Running k-means clustering (256 centroids) in each subspace</li>
+ *   <li>Centering the dataset to improve quantization accuracy</li>
+ * </ul>
+ * The resulting PQ model provides a compression ratio of 128 * 4 / M bytes per vector. + * + * @param blackhole JMH blackhole to prevent dead code elimination + * @throws IOException if there is an error during PQ computation + */ @Benchmark public void productQuantizationComputeBenchmark(Blackhole blackhole) throws IOException { // Compress the original vectors using PQ. this represents a compression ratio of 128 * 4 / 16 = 32x diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/RecallWithRandomVectorsBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/RecallWithRandomVectorsBenchmark.java index b71591f33..0813fe2c0 100644 --- a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/RecallWithRandomVectorsBenchmark.java +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/RecallWithRandomVectorsBenchmark.java @@ -38,6 +38,26 @@ import java.util.Set; import java.util.concurrent.TimeUnit; +/** + * Benchmark for measuring graph search recall with and without Product Quantization on random vectors. + *

+ * This benchmark evaluates the quality and performance of approximate nearest neighbor (ANN) search + * using a hierarchical navigable small world (HNSW) graph index. It measures recall by comparing + * search results against exact nearest neighbors computed via brute force. The benchmark tests both + * full-precision vectors and Product Quantized (PQ) vectors to understand the accuracy-speed trade-off. + *

+ * Key metrics tracked via auxiliary counters: + *

    + *
+ * <ul>
+ *   <li>avgRecall: The fraction of true nearest neighbors found in the search results</li>
+ *   <li>avgReRankedCount: Number of candidates re-ranked with exact distances</li>
+ *   <li>avgVisitedCount: Number of graph nodes visited during search</li>
+ *   <li>avgExpandedCount: Number of graph nodes expanded (neighbors examined)</li>
+ *   <li>avgExpandedCountBaseLayer: Number of nodes expanded in the base layer</li>
+ * </ul>
+ *

+ * The benchmark builds a graph index once during setup and then performs searches with multiple + * query vectors, measuring both search time and recall quality. + */ @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.MILLISECONDS) @State(Scope.Thread) @@ -47,30 +67,112 @@ @Threads(1) public class RecallWithRandomVectorsBenchmark { private static final Logger log = LoggerFactory.getLogger(RecallWithRandomVectorsBenchmark.class); + + /** + * Creates a new benchmark instance. + *

+ * This constructor is invoked by JMH and should not be called directly. + */ + public RecallWithRandomVectorsBenchmark() { + } private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport(); + + /** Random access wrapper for the base vectors. */ private RandomAccessVectorValues ravv; + + /** The base vectors in the searchable dataset. */ private ArrayList> baseVectors; + + /** Query vectors used to test search recall. */ private ArrayList> queryVectors; + + /** Builder used to construct the graph index. */ private GraphIndexBuilder graphIndexBuilder; + + /** The constructed graph index used for ANN search. */ private ImmutableGraphIndex graphIndex; + + /** Product-quantized versions of the base vectors, or null if numberOfPQSubspaces=0. */ private PQVectors pqVectors; - // Add ground truth storage + /** Ground truth nearest neighbors for each query, computed via brute force. */ private ArrayList groundTruth; + /** + * The dimensionality of the vectors. + *

+ * Default value: 1536 (typical for modern embedding models). + */ @Param({"1536"}) int originalDimension; + + /** + * The number of base vectors in the dataset. + *

+ * Default value: 100000 + */ @Param({"100000"}) int numBaseVectors; + + /** + * The number of query vectors to test. + *

+ * Default value: 10 + */ @Param({"10"}) int numQueryVectors; + + /** + * The number of subspaces for Product Quantization. + *

+ * When numberOfPQSubspaces=0, uses full precision vectors without quantization. + * When numberOfPQSubspaces>0, applies PQ compression with approximate scoring. + * Values: 0 (no PQ), 16, 32, 64, 96, 192 + */ @Param({"0", "16", "32", "64", "96", "192"}) int numberOfPQSubspaces; - @Param({/*"10",*/ "50"}) // Add different k values for recall calculation + + /** + * The number of nearest neighbors to retrieve (k). + *

+ * Default value: 50 + */ + @Param({/*"10",*/ "50"}) int k; + + /** + * The over-query factor for PQ searches. + *

+ * When using PQ, the search retrieves k * overQueryFactor candidates using approximate + * distances, then re-ranks them with exact distances to select the final k results. + * Only applies when numberOfPQSubspaces > 0. + *

+ * Default value: 5 + */ @Param({"5"}) int overQueryFactor; + /** + * Sets up the benchmark by creating random vectors, building the graph index, and computing ground truth. + *

+ * This method performs the following steps: + *

    + *
+ * <ol>
+ *   <li>Generates random base vectors and query vectors</li>
+ *   <li>Optionally computes Product Quantization if numberOfPQSubspaces > 0</li>
+ *   <li>Builds an HNSW graph index for ANN search</li>
+ *   <li>Computes exact nearest neighbors via brute force for recall measurement</li>
+ * </ol>
+ * The graph is configured with: + *
    + *
+ * <ul>
+ *   <li>Degree: 16 (max edges per node)</li>
+ *   <li>Construction depth: 100 (beam width during construction)</li>
+ *   <li>Alpha: 1.2 (degree overflow allowance)</li>
+ *   <li>Diversity alpha: 1.2 (neighbor diversity requirement)</li>
+ *   <li>Hierarchy: enabled</li>
+ * </ul>
+ * + * @throws IOException if there is an error during setup + */ @Setup public void setup() throws IOException { baseVectors = new ArrayList<>(numBaseVectors); @@ -112,6 +214,15 @@ public void setup() throws IOException { calculateGroundTruth(); } + /** + * Creates a random vector with the specified dimension. + *

+ * Each component of the vector is assigned a random floating-point value + * between 0.0 (inclusive) and 1.0 (exclusive). + * + * @param dimension the number of dimensions for the vector + * @return a new random vector + */ private VectorFloat createRandomVector(int dimension) { VectorFloat vector = VECTOR_TYPE_SUPPORT.createFloatVector(dimension); for (int i = 0; i < dimension; i++) { @@ -120,6 +231,14 @@ private VectorFloat createRandomVector(int dimension) { return vector; } + /** + * Tears down the benchmark by releasing resources. + *

+ * This method clears all vectors and closes the graph index builder to release + * any associated resources. + * + * @throws IOException if there is an error during teardown + */ @TearDown public void tearDown() throws IOException { baseVectors.clear(); @@ -127,21 +246,68 @@ public void tearDown() throws IOException { graphIndexBuilder.close(); } + /** + * Auxiliary counters for tracking recall and search statistics across benchmark iterations. + *

+ * These counters accumulate metrics from each benchmark iteration and compute running averages. + * JMH reports these values as additional benchmark results alongside timing measurements. + */ @AuxCounters(AuxCounters.Type.EVENTS) @State(Scope.Thread) public static class RecallCounters { + /** + * Creates a new counter instance. + *

+ * This constructor is invoked by JMH and should not be called directly. + */ + public RecallCounters() { + } + + /** The average recall across all iterations. */ public double avgRecall = 0; + + /** The average number of candidates re-ranked per query across all iterations. */ public double avgReRankedCount = 0; + + /** The average number of graph nodes visited per query across all iterations. */ public double avgVisitedCount = 0; + + /** The average number of graph nodes expanded per query across all iterations. */ public double avgExpandedCount = 0; + + /** The average number of base layer nodes expanded per query across all iterations. */ public double avgExpandedCountBaseLayer = 0; + + /** The number of benchmark iterations completed. */ private int iterations = 0; + + /** The cumulative recall across all iterations. */ private double totalRecall = 0; + + /** The cumulative re-ranked count across all iterations. */ private double totalReRankedCount = 0; + + /** The cumulative visited count across all iterations. */ private double totalVisitedCount = 0; + + /** The cumulative expanded count across all iterations. */ private double totalExpandedCount = 0; + + /** The cumulative base layer expanded count across all iterations. */ private double totalExpandedCountBaseLayer = 0; + /** + * Adds results from a single benchmark iteration and updates running averages. + *

+ * This method is called after each benchmark iteration to accumulate statistics + * and compute new average values. + * + * @param avgIterationRecall the average recall for this iteration + * @param avgIterationReRankedCount the average re-ranked count for this iteration + * @param avgIterationVisitedCount the average visited count for this iteration + * @param avgIterationExpandedCount the average expanded count for this iteration + * @param avgIterationExpandedCountBaseLayer the average base layer expanded count for this iteration + */ public void addResults(double avgIterationRecall, double avgIterationReRankedCount, double avgIterationVisitedCount, double avgIterationExpandedCount, double avgIterationExpandedCountBaseLayer) { log.info("adding results avgIterationRecall: {}, avgIterationReRankedCount: {}, avgIterationVisitedCount: {}, avgIterationExpandedCount: {}, avgIterationExpandedCountBaseLayer: {}", avgIterationRecall, avgIterationReRankedCount, avgIterationVisitedCount, avgIterationExpandedCount, avgIterationExpandedCountBaseLayer); totalRecall += avgIterationRecall; @@ -159,6 +325,29 @@ public void addResults(double avgIterationRecall, double avgIterationReRankedCou } + /** + * Benchmarks ANN search with recall measurement on random vectors. + *

+ * This benchmark performs graph searches for all query vectors and measures: + *

    + *
+ * <ul>
+ *   <li>Search time (via JMH timing)</li>
+ *   <li>Recall quality (fraction of true nearest neighbors found)</li>
+ *   <li>Search statistics (nodes visited, expanded, re-ranked)</li>
+ * </ul>
+ *

+ * The search behavior depends on the numberOfPQSubspaces parameter: + *

    + *
  • When numberOfPQSubspaces=0: Uses exact distance calculations throughout
  • + *
  • When numberOfPQSubspaces>0: Uses PQ approximate distances for initial search, + * then re-ranks top candidates with exact distances
  • + *
+ *

+ * Recall is computed by comparing search results against ground truth exact nearest neighbors. + * + * @param blackhole JMH blackhole to prevent dead code elimination + * @param counters auxiliary counters for accumulating recall and search statistics + * @throws IOException if there is an error during search + */ @Benchmark public void testOnHeapRandomVectorsWithRecall(Blackhole blackhole, RecallCounters counters) throws IOException { double totalRecall = 0.0; @@ -218,6 +407,14 @@ public void testOnHeapRandomVectorsWithRecall(Blackhole blackhole, RecallCounter } + /** + * Calculates exact nearest neighbors for all query vectors via brute force. + *

+ * This method computes ground truth by performing exhaustive distance calculations + * between each query vector and all base vectors. The top-k nearest neighbors for + * each query are stored for later recall computation. This is computationally expensive + * but provides the true nearest neighbors needed to measure search quality. + */ private void calculateGroundTruth() { groundTruth = new ArrayList<>(queryVectors.size()); @@ -242,6 +439,18 @@ private void calculateGroundTruth() { } } + /** + * Calculates recall by comparing predicted results against ground truth. + *

+ * Recall is the fraction of true nearest neighbors that appear in the search results. + * This method compares the node IDs from the search results against the ground truth + * nearest neighbors and counts how many matches are found. + * + * @param predicted the set of node IDs returned by the search + * @param groundTruth the array of true nearest neighbor node IDs + * @param k the number of neighbors to consider + * @return the recall value between 0.0 and 1.0 + */ private double calculateRecall(Set predicted, int[] groundTruth, int k) { int hits = 0; int actualK = Math.min(k, Math.min(predicted.size(), groundTruth.length)); diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/StaticSetVectorsBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/StaticSetVectorsBenchmark.java index a3651aabc..3dc903b9e 100644 --- a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/StaticSetVectorsBenchmark.java +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/StaticSetVectorsBenchmark.java @@ -31,6 +31,26 @@ import java.util.List; import java.util.concurrent.TimeUnit; +/** + * Benchmark for measuring graph search performance on the SIFT dataset. + *

+ * <p>This benchmark evaluates the performance of approximate nearest neighbor (ANN) search
+ * using a hierarchical navigable small world (HNSW) graph index on the SIFT Small dataset.
+ * Unlike the random vector benchmarks, this uses real-world image feature vectors with
+ * realistic distributions and correlations.</p>
+ *
+ * <p>The benchmark builds a graph index once during setup using the SIFT base vectors,
+ * then measures search time using random query vectors. This focuses purely on search
+ * performance without recall measurement.</p>
+ *
+ * <p>Key characteristics:</p>
+ * <ul>
+ *   <li>Uses SIFT Small dataset (10,000 base vectors, 128-dimensional)</li>
+ *   <li>Full precision vectors (no quantization)</li>
+ *   <li>Random query vectors generated at search time</li>
+ *   <li>Measures pure search throughput</li>
+ * </ul>
+ */ @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.MILLISECONDS) @State(Scope.Thread) @@ -40,14 +60,57 @@ @Threads(1) public class StaticSetVectorsBenchmark { private static final Logger log = LoggerFactory.getLogger(StaticSetVectorsBenchmark.class); + + /** + * Creates a new benchmark instance. + *

+ * This constructor is invoked by JMH and should not be called directly. + */ + public StaticSetVectorsBenchmark() { + } + + /** Random access wrapper for the SIFT base vectors. */ private RandomAccessVectorValues ravv; + + /** The SIFT base vectors in the searchable dataset. */ private List> baseVectors; + + /** The SIFT query vectors (loaded but not used in this benchmark). */ private List> queryVectors; + + /** Ground truth nearest neighbors (loaded but not used in this benchmark). */ private List> groundTruth; + + /** Builder used to construct the graph index. */ private GraphIndexBuilder graphIndexBuilder; + + /** The constructed graph index used for ANN search. */ private ImmutableGraphIndex graphIndex; + + /** The dimensionality of the SIFT vectors (128). */ int originalDimension; + /** + * Sets up the benchmark by loading the SIFT dataset and building the graph index. + *

+ * <p>This method performs the following steps:</p>
+ * <ol>
+ *   <li>Loads SIFT base vectors, query vectors, and ground truth from the filesystem</li>
+ *   <li>Wraps the base vectors in a RandomAccessVectorValues instance</li>
+ *   <li>Creates a BuildScoreProvider for exact distance calculations</li>
+ *   <li>Builds an HNSW graph index with the following configuration:
+ *     <ul>
+ *       <li>Degree: 16 (max edges per node)</li>
+ *       <li>Construction depth: 100 (beam width during construction)</li>
+ *       <li>Alpha: 1.2 (degree overflow allowance)</li>
+ *       <li>Diversity alpha: 1.2 (neighbor diversity requirement)</li>
+ *       <li>Hierarchy: enabled</li>
+ *     </ul>
+ *   </li>
+ * </ol>
+ * + * @throws IOException if there is an error loading the SIFT dataset or building the index + */ @Setup public void setup() throws IOException { var siftPath = "siftsmall"; @@ -73,6 +136,14 @@ public void setup() throws IOException { graphIndex = graphIndexBuilder.build(ravv); } + /** + * Tears down the benchmark by releasing resources. + *

+ * This method clears all loaded vectors and closes the graph index builder to release + * any associated resources. + * + * @throws IOException if there is an error during teardown + */ @TearDown public void tearDown() throws IOException { baseVectors.clear(); @@ -81,6 +152,19 @@ public void tearDown() throws IOException { graphIndexBuilder.close(); } + /** + * Benchmarks graph search performance using random query vectors. + *

+ * This benchmark measures the time to perform a single ANN search using a randomly + * generated query vector. The search uses exact distance calculations (no quantization) + * and returns the 10 nearest neighbors from the SIFT base vectors. + *

+ * Each benchmark iteration generates a new random query vector with the same dimensionality + * as the SIFT vectors (128), ensuring that the search operates on fresh data each time. + * + * @param blackhole JMH blackhole to prevent dead code elimination + * @throws IOException if there is an error during search + */ @Benchmark public void testOnHeapWithRandomQueryVectors(Blackhole blackhole) throws IOException { var queryVector = SiftSmall.randomVector(originalDimension); diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/package-info.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/package-info.java new file mode 100644 index 000000000..fd502d832 --- /dev/null +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/package-info.java @@ -0,0 +1,122 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * JMH benchmarks for measuring the performance of JVector's core components. + *

+ * <p>This package contains Java Microbenchmark Harness (JMH) benchmarks that evaluate
+ * various aspects of JVector's vector search functionality, including:</p>
+ * <ul>
+ *   <li>Product Quantization (PQ) training and distance calculations</li>
+ *   <li>Graph construction and search performance</li>
+ *   <li>Recall quality measurement</li>
+ *   <li>Performance comparisons between full-precision and quantized vectors</li>
+ * </ul>
+ *
+ * <h2>Benchmark Categories</h2>
+ *
+ * <h3>Product Quantization Benchmarks</h3>
+ * <ul>
+ *   <li>{@link io.github.jbellis.jvector.bench.PQDistanceCalculationBenchmark} - Compares distance
+ * calculation performance between full-precision and PQ-compressed vectors</li>
+ *   <li>{@link io.github.jbellis.jvector.bench.PQTrainingWithRandomVectorsBenchmark} - Measures
+ * PQ training time on randomly generated vectors</li>
+ *   <li>{@link io.github.jbellis.jvector.bench.PQTrainingWithSiftBenchmark} - Measures PQ training
+ * time on the SIFT Small dataset with real-world vectors</li>
+ * </ul>
+ *
+ * <h3>Graph Search Benchmarks</h3>
+ * <ul>
+ *   <li>{@link io.github.jbellis.jvector.bench.RecallWithRandomVectorsBenchmark} - Evaluates
+ * search performance and recall quality with and without PQ compression on random vectors</li>
+ *   <li>{@link io.github.jbellis.jvector.bench.StaticSetVectorsBenchmark} - Measures pure search
+ * throughput on the SIFT Small dataset</li>
+ * </ul>
+ *
+ * <h2>Running Benchmarks</h2>
+ *
+ * <p>
+ * These benchmarks are packaged as a standalone executable JAR using Maven Shade Plugin.
+ * To build and run the benchmarks:
+ * </p>
+ *
+ * <pre>
+ * # Build the shaded JAR
+ * mvn clean package -pl benchmarks-jmh
+ *
+ * # Run all benchmarks
+ * java -jar benchmarks-jmh/target/benchmarks-jmh-*.jar
+ *
+ * # Run a specific benchmark
+ * java -jar benchmarks-jmh/target/benchmarks-jmh-*.jar PQDistanceCalculationBenchmark
+ *
+ * # Run with custom JMH options
+ * java -jar benchmarks-jmh/target/benchmarks-jmh-*.jar -h  # Show help
+ * </pre>
+ *
+ * <h2>Benchmark Configuration</h2>
+ *
+ * <p>
+ * Most benchmarks use JMH annotations to configure:
+ * </p>
+ * <ul>
+ *   <li>Mode: Typically AverageTime (measures average execution time per operation)</li>
+ *   <li>Time Units: Typically MILLISECONDS for reporting results</li>
+ *   <li>Warmup: Multiple iterations to allow JVM warmup and JIT compilation</li>
+ *   <li>Measurement: Multiple iterations for statistically significant results</li>
+ *   <li>Fork: Usually 1 fork for faster execution, increase for production-grade results</li>
+ *   <li>Parameters: {@code @Param} annotations define benchmark variants (dimensions, vector counts, etc.)</li>
+ * </ul>
+ *
+ * <h2>Data Sources</h2>
+ *
+ * <p>
+ * The benchmarks use two types of vector datasets:
+ * </p>
+ * <ul>
+ *   <li>Random Vectors: Generated programmatically with uniform random values,
+ * useful for controlled testing and scalability evaluation</li>
+ *   <li>SIFT Small Dataset: Real-world image feature vectors (10,000 base vectors,
+ * 100 queries, 128 dimensions), available at
+ * <a href="http://corpus-texmex.irisa.fr/">http://corpus-texmex.irisa.fr/</a></li>
+ * </ul>
+ *
+ * <h2>Understanding Results</h2>
+ *
+ * <p>
+ * JMH produces detailed output including:
+ * </p>
+ * <ul>
+ *   <li>Score: Average time per operation (or ops/sec depending on mode)</li>
+ *   <li>Error: 99.9% confidence interval (lower is better)</li>
+ *   <li>Auxiliary Counters: Some benchmarks report additional metrics like recall,
+ * visited nodes, etc.</li>
+ * </ul>
+ *
+ * <h2>Best Practices</h2>
+ * <ul>
+ *   <li>Run benchmarks on a quiet system with minimal background processes</li>
+ *   <li>Use multiple forks ({@code -f 3}) for production-grade measurements</li>
+ *   <li>Increase warmup and measurement iterations for stable results</li>
+ *   <li>Be aware that vector incubator module performance may vary by platform</li>
+ *   <li>Consider using JMH profilers ({@code -prof}) to understand hotspots</li>
+ * </ul>
+ * + * @see JMH Documentation + * @see SIFT Dataset + */ +package io.github.jbellis.jvector.bench; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/annotations/package-info.java b/jvector-base/src/main/java/io/github/jbellis/jvector/annotations/package-info.java new file mode 100644 index 000000000..d975c5985 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/annotations/package-info.java @@ -0,0 +1,63 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides annotation types for documenting API stability and visibility constraints. + *

+ * This package contains marker annotations used throughout JVector to communicate + * API stability guarantees and visibility intentions to library users. + * + *

+ * <h2>Available Annotations</h2>
+ * <ul>
+ *   <li>{@link io.github.jbellis.jvector.annotations.Experimental} - Marks APIs that are
+ * experimental and may change or be removed in future releases without prior notice.
+ * Users should avoid depending on experimental APIs in production code.</li>
+ *   <li>{@link io.github.jbellis.jvector.annotations.VisibleForTesting} - Marks classes,
+ * methods, or fields that are made visible (typically package-private or public)
+ * solely for testing purposes. These elements are internal implementation details
+ * and may change without warning despite their visibility level.</li>
+ * </ul>
+ *
+ * <h2>Usage Guidelines</h2>
+ *
+ * <p>
+ * When using JVector as a library:
+ * </p>
+ * <ul>
+ *   <li>Avoid using APIs marked with {@code @Experimental} in production code, as they
+ * may be modified or removed in any release.</li>
+ *   <li>Do not rely on APIs marked with {@code @VisibleForTesting}, even if they are
+ * technically accessible. These are implementation details that may change without
+ * following semantic versioning rules.</li>
+ * </ul>
+ *
+ * <h2>Example Usage</h2>
+ * <pre>{@code
+ * @Experimental
+ * public class NewFeature {
+ *     // This feature is experimental and may change
+ * }
+ *
+ * public class GraphIndexBuilder {
+ *     @VisibleForTesting
+ *     public void setEntryPoint(int level, int node) {
+ *         // Made public for testing but intended as internal API
+ *     }
+ * }
+ * }</pre>
+ * + * @see io.github.jbellis.jvector.annotations.Experimental + * @see io.github.jbellis.jvector.annotations.VisibleForTesting + */ +package io.github.jbellis.jvector.annotations; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/BufferedRandomAccessWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/BufferedRandomAccessWriter.java index 8e13df0e7..9511944f4 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/BufferedRandomAccessWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/BufferedRandomAccessWriter.java @@ -35,14 +35,27 @@ public class BufferedRandomAccessWriter implements RandomAccessWriter { private final RandomAccessFile raf; private final DataOutputStream stream; + /** + * Creates a BufferedRandomAccessWriter for the specified file path. + * @param path the path to the file to write to + * @throws FileNotFoundException if the file cannot be created or opened for writing + */ public BufferedRandomAccessWriter(Path path) throws FileNotFoundException { raf = new RandomAccessFile(path.toFile(), "rw"); stream = new DataOutputStream(new BufferedOutputStream(new RandomAccessOutputStream(raf))); } + /** + * An OutputStream wrapper around a RandomAccessFile. + * This allows the RandomAccessFile to be buffered using standard Java I/O classes. + */ private static class RandomAccessOutputStream extends OutputStream { private final RandomAccessFile raf; + /** + * Creates a RandomAccessOutputStream that writes to the given RandomAccessFile. + * @param raf the RandomAccessFile to write to + */ public RandomAccessOutputStream(RandomAccessFile raf) { this.raf = raf; } @@ -88,10 +101,8 @@ public void flush() throws IOException { } /** - * return the CRC32 checksum for the range [startOffset .. endOffset) - *

- * the file pointer will be left at endOffset. - *

+ * Computes and returns the CRC32 checksum for the range [startOffset .. endOffset). + * The file pointer will be left at endOffset. */ @Override public long checksum(long startOffset, long endOffset) throws IOException { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ByteBufferReader.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ByteBufferReader.java index f74ec76eb..a8b11e15e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ByteBufferReader.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ByteBufferReader.java @@ -23,8 +23,13 @@ * RandomAccessReader that reads from a ByteBuffer */ public class ByteBufferReader implements RandomAccessReader { + /** The underlying ByteBuffer for reading data. */ protected final ByteBuffer bb; + /** + * Creates a ByteBufferReader that reads from the given ByteBuffer. + * @param sourceBB the ByteBuffer to read from + */ public ByteBufferReader(ByteBuffer sourceBB) { bb = sourceBB; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/IndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/IndexWriter.java index 9a214425b..d1adcde4d 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/IndexWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/IndexWriter.java @@ -20,9 +20,18 @@ import java.io.DataOutput; import java.io.IOException; +/** + * An interface for writing index data that combines DataOutput and Closeable capabilities + * with position tracking. + *

+ * This interface is used by index writers to provide sequential write access with + * the ability to query the current write position. + */ public interface IndexWriter extends DataOutput, Closeable { /** + * Returns the current position in the output stream. * @return the current position in the output + * @throws IOException if an I/O error occurs */ long position() throws IOException; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/RandomAccessReader.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/RandomAccessReader.java index a09081a28..2e69cc604 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/RandomAccessReader.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/RandomAccessReader.java @@ -32,32 +32,99 @@ * uses the ReaderSupplier API to create a RandomAccessReader per thread, as needed. */ public interface RandomAccessReader extends AutoCloseable { + /** + * Seeks to the specified position in the reader. + * @param offset the offset position to seek to + * @throws IOException if an I/O error occurs + */ void seek(long offset) throws IOException; + /** + * Returns the current position in the reader. + * @return the current position in the reader + * @throws IOException if an I/O error occurs + */ long getPosition() throws IOException; + /** + * Reads and returns a 32-bit integer. + * @return the integer value read + * @throws IOException if an I/O error occurs + */ int readInt() throws IOException; + /** + * Reads and returns a 32-bit float. + * @return the float value read + * @throws IOException if an I/O error occurs + */ float readFloat() throws IOException; + /** + * Reads and returns a 64-bit long. + * @return the long value read + * @throws IOException if an I/O error occurs + */ long readLong() throws IOException; + /** + * Reads bytes to completely fill the specified byte array. 
+ * @param bytes the byte array to fill + * @throws IOException if an I/O error occurs + */ void readFully(byte[] bytes) throws IOException; + /** + * Reads bytes to completely fill the specified ByteBuffer. + * @param buffer the ByteBuffer to fill + * @throws IOException if an I/O error occurs + */ void readFully(ByteBuffer buffer) throws IOException; + /** + * Reads floats to completely fill the specified float array. + * @param floats the float array to fill + * @throws IOException if an I/O error occurs + */ default void readFully(float[] floats) throws IOException { read(floats, 0, floats.length); } + /** + * Reads longs to completely fill the specified long array. + * @param vector the long array to fill + * @throws IOException if an I/O error occurs + */ void readFully(long[] vector) throws IOException; + /** + * Reads a specified number of integers into an array starting at the given offset. + * @param ints the array to read integers into + * @param offset the starting position in the array + * @param count the number of integers to read + * @throws IOException if an I/O error occurs + */ void read(int[] ints, int offset, int count) throws IOException; + /** + * Reads a specified number of floats into an array starting at the given offset. + * @param floats the array to read floats into + * @param offset the starting position in the array + * @param count the number of floats to read + * @throws IOException if an I/O error occurs + */ void read(float[] floats, int offset, int count) throws IOException; + /** + * Closes this reader and releases any system resources associated with it. + * @throws IOException if an I/O error occurs + */ void close() throws IOException; - // Length of the reader slice + /** + * Returns the length of the reader slice. 
+ * @return the length of the reader slice + * @throws IOException if an I/O error occurs + */ long length() throws IOException; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/RandomAccessWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/RandomAccessWriter.java index ef7202894..3c1e57a82 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/RandomAccessWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/RandomAccessWriter.java @@ -21,12 +21,31 @@ import java.io.IOException; /** - * A DataOutput that adds methods for random access writes + * A DataOutput that adds methods for random access writes. + *

+ * This interface extends IndexWriter to provide seek capability and checksum computation, + * enabling efficient random access write patterns and data integrity verification. */ public interface RandomAccessWriter extends IndexWriter { + /** + * Seeks to the specified position in the output. + * @param position the position to seek to + * @throws IOException if an I/O error occurs + */ void seek(long position) throws IOException; + /** + * Flushes any buffered data to the underlying storage. + * @throws IOException if an I/O error occurs + */ void flush() throws IOException; + /** + * Computes and returns a CRC32 checksum for the specified byte range. + * @param startOffset the starting offset of the range (inclusive) + * @param endOffset the ending offset of the range (exclusive) + * @return the CRC32 checksum value + * @throws IOException if an I/O error occurs + */ long checksum(long startOffset, long endOffset) throws IOException; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ReaderSupplier.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ReaderSupplier.java index 8f8d2ae2c..f87e13474 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ReaderSupplier.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ReaderSupplier.java @@ -23,11 +23,20 @@ */ public interface ReaderSupplier extends AutoCloseable { /** - * @return a new reader. It is up to the caller to re-use these readers or close them, + * Returns a new reader. It is up to the caller to re-use these readers or close them, * the ReaderSupplier is not responsible for caching them. + * + * @return a new RandomAccessReader instance + * @throws IOException if an I/O error occurs */ RandomAccessReader get() throws IOException; + /** + * Closes this ReaderSupplier and releases any resources. + * The default implementation does nothing. 
+ * + * @throws IOException if an I/O error occurs + */ default void close() throws IOException { } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ReaderSupplierFactory.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ReaderSupplierFactory.java index dd0a22659..f298c76af 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ReaderSupplierFactory.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ReaderSupplierFactory.java @@ -22,11 +22,31 @@ import java.util.logging.Level; import java.util.logging.Logger; +/** + * Factory for creating ReaderSupplier instances with automatic fallback based on available implementations. + * Tries in order: MemorySegmentReader (JDK 20+), MMapReader (requires native library), MappedChunkReader (fallback). + */ public class ReaderSupplierFactory { private static final Logger LOG = Logger.getLogger(ReaderSupplierFactory.class.getName()); private static final String MEMORY_SEGMENT_READER_CLASSNAME = "io.github.jbellis.jvector.disk.MemorySegmentReader$Supplier"; private static final String MMAP_READER_CLASSNAME = "io.github.jbellis.jvector.example.util.MMapReader$Supplier"; + /** + * Private constructor to prevent instantiation of this utility class. + */ + private ReaderSupplierFactory() { + throw new AssertionError("ReaderSupplierFactory should not be instantiated"); + } + + /** + * Opens a ReaderSupplier for the given path, using the best available implementation. + * Attempts to use MemorySegmentReader first (JDK 20+), then MMapReader (native library), + * and finally falls back to MappedChunkReader. 
+ * + * @param path the path to the file to read + * @return a ReaderSupplier for accessing the file + * @throws IOException if the file cannot be opened + */ public static ReaderSupplier open(Path path) throws IOException { try { // prefer MemorySegmentReader (available under JDK 20+) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleMappedReader.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleMappedReader.java index f55688e94..3559ca6c8 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleMappedReader.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleMappedReader.java @@ -45,7 +45,11 @@ private static Unsafe getUnsafe() { } } - + /** + * Constructs a SimpleMappedReader wrapping the specified memory-mapped buffer. + * + * @param mbb the memory-mapped byte buffer to read from + */ SimpleMappedReader(MappedByteBuffer mbb) { super(mbb); } @@ -55,10 +59,24 @@ public void close() { // Individual readers don't close anything } + /** + * Supplier that creates SimpleMappedReader instances from a memory-mapped file. + * The file is mapped into memory once during construction and shared across all readers. + */ public static class Supplier implements ReaderSupplier { + /** The shared memory-mapped buffer for this file. */ private final MappedByteBuffer buffer; + /** Unsafe instance for invoking the buffer cleaner when closing. */ private static final Unsafe unsafe = getUnsafe(); + /** + * Constructs a Supplier that memory-maps the file at the specified path. + * The entire file is loaded into memory. Files larger than 2GB are not supported. 
+ * + * @param path the path to the file to map + * @throws IOException if an I/O error occurs + * @throws RuntimeException if the file is larger than 2GB + */ public Supplier(Path path) throws IOException { try (var raf = new RandomAccessFile(path.toString(), "r")) { if (raf.length() > Integer.MAX_VALUE) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleReader.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleReader.java index 07d9488a1..c22da637d 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleReader.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleReader.java @@ -26,9 +26,20 @@ // TODO what are the low-hanging optimization options here? // The requirement that we need to read from a file that is potentially changing in length // limits our options. +/** + * Simple implementation of RandomAccessReader using RandomAccessFile. + * Suitable for files that may be changing in length during access. + */ public class SimpleReader implements RandomAccessReader { + /** The underlying random access file. */ RandomAccessFile raf; + /** + * Constructs a SimpleReader for the file at the specified path. + * + * @param path the path to the file to read + * @throws FileNotFoundException if the file does not exist + */ public SimpleReader(Path path) throws FileNotFoundException { raf = new RandomAccessFile(path.toFile(), "r"); } @@ -105,9 +116,19 @@ public long length() throws IOException { return raf.length(); } + /** + * Supplier that creates SimpleReader instances for a given file path. + * Each call to {@link #get()} creates a new SimpleReader with its own file handle. + */ public static class Supplier implements ReaderSupplier { + /** The path to the file to read. */ private final Path path; + /** + * Constructs a Supplier for the file at the specified path. 
+ * + * @param path the path to the file to read + */ public Supplier(Path path) { this.path = path; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleWriter.java index e6462a373..6fe7c5349 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleWriter.java @@ -30,6 +30,11 @@ public class SimpleWriter implements IndexWriter { private final FileOutputStream fos; private final DataOutputStream dos; + /** + * Creates a new SimpleWriter that writes to the specified path. + * @param path the path to write to + * @throws IOException if an I/O error occurs opening the file + */ public SimpleWriter(Path path) throws IOException { fos = new FileOutputStream(path.toFile()); dos = new DataOutputStream(fos); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/package-info.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/package-info.java new file mode 100644 index 000000000..dd28660b6 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/package-info.java @@ -0,0 +1,111 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides low-level I/O abstractions for reading and writing binary data. + *

+ * This package contains interfaces and implementations for efficient random access I/O operations + * used throughout JVector. These abstractions support both memory-mapped and traditional buffered + * I/O strategies, enabling optimal performance across different use cases and storage backends. + * + *

+ * <h2>Core Abstractions</h2>
+ *
+ * <h3>Reader Interfaces</h3>
+ * <ul>
+ *   <li>{@link io.github.jbellis.jvector.disk.RandomAccessReader} - Interface for reading data
+ * with seek capability. Supports reading primitive types (int, long, float) and bulk reads
+ * into arrays and buffers. Designed for sequential reads after seeking to a position.</li>
+ *   <li>{@link io.github.jbellis.jvector.disk.ReaderSupplier} - Factory interface for creating
+ * {@code RandomAccessReader} instances. Used to provide thread-local readers since
+ * {@code RandomAccessReader} implementations are stateful and not thread-safe.</li>
+ *   <li>{@link io.github.jbellis.jvector.disk.ReaderSupplierFactory} - Factory for creating
+ * {@code ReaderSupplier} instances from files. Provides the recommended entry point for
+ * opening files for reading.</li>
+ * </ul>
+ *

+ * <h3>Writer Interfaces</h3>
+ * <ul>
+ *   <li>{@link io.github.jbellis.jvector.disk.IndexWriter} - Base interface for sequential data
+ * writing with position tracking.</li>
+ *   <li>{@link io.github.jbellis.jvector.disk.RandomAccessWriter} - Extends {@code IndexWriter}
+ * with seek capability for random access writes and checksum computation.</li>
+ * </ul>
+ *

+ * <h2>Implementations</h2>
+ *
+ * <h3>Readers</h3>
+ * <ul>
+ *   <li>{@link io.github.jbellis.jvector.disk.SimpleReader} - Buffered file reader using
+ * {@code FileChannel}</li>
+ *   <li>{@link io.github.jbellis.jvector.disk.SimpleMappedReader} - Memory-mapped file reader
+ * for optimal performance with large files</li>
+ *   <li>{@link io.github.jbellis.jvector.disk.ByteBufferReader} - Reader backed by a
+ * {@code ByteBuffer}, useful for in-memory data</li>
+ *   <li>{@link io.github.jbellis.jvector.disk.MappedChunkReader} - Chunked memory-mapped reader
+ * for handling files larger than the maximum mapping size</li>
+ * </ul>
+ *

+ * <h3>Writers</h3>
+ * <ul>
+ *   <li>{@link io.github.jbellis.jvector.disk.SimpleWriter} - Basic file writer using
+ * {@code FileChannel}</li>
+ *   <li>{@link io.github.jbellis.jvector.disk.BufferedRandomAccessWriter} - Buffered writer with
+ * random access and checksum support, recommended for most writing scenarios</li>
+ * </ul>
+ *

+ * <h2>Usage Pattern</h2>
+ *
+ * <p>
+ * The recommended usage pattern for reading is:
+ * <pre>{@code
+ * // Open a file with a ReaderSupplierFactory
+ * try (ReaderSupplier readerSupplier = ReaderSupplierFactory.open(path)) {
+ *     // Get a thread-local reader
+ *     try (RandomAccessReader reader = readerSupplier.get()) {
+ *         reader.seek(offset);
+ *         int value = reader.readInt();
+ *         float[] vector = new float[dimension];
+ *         reader.readFully(vector);
+ *     }
+ * }
+ * }</pre>
+ *

+ * <p>
+ * For writing:
+ * <pre>{@code
+ * try (RandomAccessWriter writer = new BufferedRandomAccessWriter(path)) {
+ *     writer.writeInt(42);
+ *     writer.writeFloat(3.14f);
+ *     long position = writer.getPosition();
+ *     writer.seek(0);  // Go back and update header
+ *     writer.writeLong(position);
+ * }
+ * }</pre>
+ *

+ * <h2>Thread Safety</h2>
+ * <ul>
+ *   <li>{@code RandomAccessReader} implementations are not thread-safe. Use
+ * {@code ReaderSupplier} to create separate instances per thread.</li>
+ *   <li>{@code RandomAccessWriter} implementations are not thread-safe. Coordinate
+ * access externally if needed.</li>
+ *   <li>{@code ReaderSupplier} implementations are typically thread-safe and can be shared
+ * across threads to create per-thread readers.</li>
+ * </ul>
+ * + * @see io.github.jbellis.jvector.disk.RandomAccessReader + * @see io.github.jbellis.jvector.disk.ReaderSupplierFactory + * @see io.github.jbellis.jvector.disk.RandomAccessWriter + */ +package io.github.jbellis.jvector.disk; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/exceptions/ThreadInterruptedException.java b/jvector-base/src/main/java/io/github/jbellis/jvector/exceptions/ThreadInterruptedException.java index 0445252a1..a629cc2a8 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/exceptions/ThreadInterruptedException.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/exceptions/ThreadInterruptedException.java @@ -24,7 +24,16 @@ package io.github.jbellis.jvector.exceptions; +/** + * Wraps an {@link InterruptedException} as an unchecked exception. + * This exception is thrown when a thread is interrupted during an operation + * that does not declare {@code InterruptedException} in its signature. + */ public final class ThreadInterruptedException extends RuntimeException { + /** + * Constructs a new ThreadInterruptedException wrapping the given InterruptedException. + * @param ie the InterruptedException that caused this exception + */ public ThreadInterruptedException(InterruptedException ie) { super(ie); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/exceptions/package-info.java b/jvector-base/src/main/java/io/github/jbellis/jvector/exceptions/package-info.java new file mode 100644 index 000000000..cb81431a6 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/exceptions/package-info.java @@ -0,0 +1,61 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides custom exception types used throughout JVector. + *

+ * This package contains specialized exception classes that represent error conditions + * specific to JVector operations. These exceptions extend standard Java exception types + * to provide more specific error handling capabilities. + * + *

+ * <h2>Exception Types</h2>
+ * <ul>
+ *   <li>{@link io.github.jbellis.jvector.exceptions.ThreadInterruptedException} - An unchecked
+ * exception that wraps {@link InterruptedException}. This is used in contexts where
+ * methods cannot declare checked exceptions but need to propagate thread interruption
+ * signals. The wrapped {@code InterruptedException} is preserved as the cause.</li>
+ * </ul>
+ *

+ * <h2>Usage Guidelines</h2>
+ *
+ * <p>
+ * {@code ThreadInterruptedException} is typically thrown by JVector when:
+ * <ul>
+ *   <li>A thread is interrupted during graph construction or search operations</li>
+ *   <li>The operation is running in a context that does not allow checked exceptions
+ * (such as lambda expressions or stream operations)</li>
+ * </ul>
+ *

+ * <h2>Exception Handling Example</h2>
+ * <pre>{@code
+ * try {
+ *     GraphIndexBuilder builder = new GraphIndexBuilder(...);
+ *     builder.build(vectors);
+ * } catch (ThreadInterruptedException e) {
+ *     // Thread was interrupted during graph construction
+ *     Thread.currentThread().interrupt(); // Restore interrupt status
+ *     logger.warn("Graph construction was interrupted", e);
+ * }
+ * }</pre>
+ *
+ * <p>
+ * When catching {@code ThreadInterruptedException}, it is generally recommended to restore + * the thread's interrupt status by calling {@code Thread.currentThread().interrupt()} unless + * you are certain the interruption has been properly handled. + * + * @see io.github.jbellis.jvector.exceptions.ThreadInterruptedException + * @see java.lang.InterruptedException + */ +package io.github.jbellis.jvector.exceptions; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java index 891fda756..6330ace9c 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java @@ -41,10 +41,26 @@ public class ConcurrentNeighborMap { /** the maximum number of neighbors a node can have temporarily during construction */ public final int maxOverflowDegree; + /** + * Constructs a new ConcurrentNeighborMap with default initial capacity. + * + * @param diversityProvider the provider for diversity calculations + * @param maxDegree the maximum number of neighbors desired per node + * @param maxOverflowDegree the maximum number of neighbors a node can have temporarily during construction + */ public ConcurrentNeighborMap(DiversityProvider diversityProvider, int maxDegree, int maxOverflowDegree) { this(new DenseIntMap<>(1024), diversityProvider, maxDegree, maxOverflowDegree); } + /** + * Constructs a new ConcurrentNeighborMap with the given neighbors map. 
+ * + * @param the type parameter (unused, for compatibility) + * @param neighbors the neighbors map to use + * @param diversityProvider the provider for diversity calculations + * @param maxDegree the maximum number of neighbors desired per node + * @param maxOverflowDegree the maximum number of neighbors a node can have temporarily during construction + */ public ConcurrentNeighborMap(IntMap neighbors, DiversityProvider diversityProvider, int maxDegree, int maxOverflowDegree) { assert maxDegree <= maxOverflowDegree : String.format("maxDegree %d exceeds maxOverflowDegree %d", maxDegree, maxOverflowDegree); this.neighbors = neighbors; @@ -53,6 +69,15 @@ public ConcurrentNeighborMap(IntMap neighbors, DiversityProvider this.maxOverflowDegree = maxOverflowDegree; } + /** + * Inserts an edge from one node to another with the given score and overflow factor. + * This method is thread-safe and uses compare-and-swap to update the neighbor set. + * + * @param fromId the source node id + * @param toId the target node id + * @param score the similarity score for this edge + * @param overflow the factor by which to allow exceeding maxDegree temporarily + */ public void insertEdge(int fromId, int toId, float score, float overflow) { while (true) { var old = neighbors.get(fromId); @@ -63,6 +88,14 @@ public void insertEdge(int fromId, int toId, float score, float overflow) { } } + /** + * Inserts an edge without enforcing diversity constraints. This should only be called + * during cleanup operations after initial graph construction. 
+ * + * @param fromId the source node id + * @param toId the target node id + * @param score the similarity score for this edge + */ public void insertEdgeNotDiverse(int fromId, int toId, float score) { while (true) { var old = neighbors.get(fromId); @@ -74,6 +107,10 @@ public void insertEdgeNotDiverse(int fromId, int toId, float score) { } /** + * Enforces the maximum degree constraint on a node by pruning to maxDegree neighbors + * using diversity selection. + * + * @param nodeId the node id to enforce degree on * @return the fraction of short edges, i.e., neighbors within alpha=1.0 */ public double enforceDegree(int nodeId) { @@ -91,6 +128,15 @@ public double enforceDegree(int nodeId) { } } + /** + * Replaces deleted neighbors with new candidates from the provided NodeArray. + * Filters out deleted nodes and merges remaining neighbors with candidates, selecting + * diverse neighbors up to maxDegree. + * + * @param nodeId the node id whose neighbors to update + * @param toDelete a BitSet indicating which nodes have been deleted + * @param candidates candidate nodes to consider as replacement neighbors + */ public void replaceDeletedNeighbors(int nodeId, BitSet toDelete, NodeArray candidates) { while (true) { var old = neighbors.get(nodeId); @@ -101,6 +147,14 @@ public void replaceDeletedNeighbors(int nodeId, BitSet toDelete, NodeArray candi } } + /** + * Inserts diverse neighbors for a node from the provided candidates. Merges the candidates + * with existing neighbors and selects up to maxDegree diverse neighbors. + * + * @param nodeId the node id to update + * @param candidates candidate nodes to consider as neighbors + * @return the updated Neighbors object + */ public Neighbors insertDiverse(int nodeId, NodeArray candidates) { while (true) { var old = neighbors.get(nodeId); @@ -112,10 +166,21 @@ public Neighbors insertDiverse(int nodeId, NodeArray candidates) { } } + /** + * Returns the Neighbors object for the given node. 
+ * + * @param node the node id to retrieve neighbors for + * @return the Neighbors object, or null if the node does not exist + */ public Neighbors get(int node) { return neighbors.get(node); } + /** + * Returns the number of nodes in this neighbor map. + * + * @return the number of nodes + */ public int size() { return neighbors.size(); } @@ -130,18 +195,40 @@ void addNode(int nodeId, NodeArray nodes) { } } + /** + * Adds a new node with no initial neighbors to this map. + * + * @param nodeId the node id to add + */ public void addNode(int nodeId) { addNode(nodeId, new NodeArray(0)); } + /** + * Removes a node from this map and returns its neighbors. + * + * @param node the node id to remove + * @return the Neighbors object that was removed, or null if the node did not exist + */ public Neighbors remove(int node) { return neighbors.remove(node); } + /** + * Checks if a node exists in this map. + * + * @param nodeId the node id to check + * @return true if the node exists, false otherwise + */ public boolean contains(int nodeId) { return neighbors.containsKey(nodeId); } + /** + * Iterates over all nodes and their neighbors in this map. + * + * @param consumer the consumer to apply to each node id and its Neighbors + */ public void forEach(DenseIntMap.IntBiConsumer consumer) { neighbors.forEach(consumer); } @@ -152,8 +239,12 @@ int nodeArrayLength() { } /** - * Add a link from every node in the NodeArray to the target toId. - * If overflow is > 1.0, allow the number of neighbors to exceed maxConnections temporarily. + * Adds a link from every node in the NodeArray to the target node. + * If overflow is greater than 1.0, allows the number of neighbors to exceed maxDegree temporarily. 
+ * + * @param nodes the nodes to backlink from + * @param toId the target node id to link to + * @param overflow the factor by which to allow exceeding maxDegree temporarily */ public void backlink(NodeArray nodes, int toId, float overflow) { for (int i = 0; i < nodes.size(); i++) { @@ -192,6 +283,11 @@ private Neighbors(int nodeId, NodeArray nodeArray) { this.diverseBefore = size(); } + /** + * Returns an iterator over the neighbor node ids. + * + * @return a NodesIterator for iterating over neighbors + */ public NodesIterator iterator() { return new NeighborIterator(this); } @@ -322,6 +418,12 @@ private Neighbors insert(int neighborId, float score, float overflow, Concurrent return next; } + /** + * Estimates the RAM bytes used by a Neighbors object with the given capacity. + * + * @param count the capacity of the neighbors array + * @return the estimated RAM usage in bytes + */ public static long ramBytesUsed(int count) { return NodeArray.ramBytesUsed(count) // includes our object header + Integer.BYTES // nodeId diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index 2cff2de4a..e4cbde521 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -90,6 +90,7 @@ public class GraphIndexBuilder implements Closeable { * * @param vectorValues the vectors whose relations are represented by the graph - must provide a * different view over those vectors than the one used to add via addGraphNode. + * @param similarityFunction the vector similarity function to use for comparing vectors * @param M – the maximum number of connections a node can have * @param beamWidth the size of the beam search to use when finding nearest neighbors. 
* @param neighborOverflow the ratio of extra neighbors to allow temporarily when inserting a @@ -123,6 +124,7 @@ public GraphIndexBuilder(RandomAccessVectorValues vectorValues, * * @param vectorValues the vectors whose relations are represented by the graph - must provide a * different view over those vectors than the one used to add via addGraphNode. + * @param similarityFunction the vector similarity function to use for comparing vectors * @param M – the maximum number of connections a node can have * @param beamWidth the size of the beam search to use when finding nearest neighbors. * @param neighborOverflow the ratio of extra neighbors to allow temporarily when inserting a @@ -159,6 +161,7 @@ public GraphIndexBuilder(RandomAccessVectorValues vectorValues, * By default, refineFinalGraph = true. * * @param scoreProvider describes how to determine the similarities between vectors + * @param dimension the dimensionality of the vectors in the graph * @param M the maximum number of connections a node can have * @param beamWidth the size of the beam search to use when finding nearest neighbors. * @param neighborOverflow the ratio of extra neighbors to allow temporarily when inserting a @@ -185,6 +188,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, * Default executor pools are used. * * @param scoreProvider describes how to determine the similarities between vectors + * @param dimension the dimensionality of the vectors in the graph * @param M the maximum number of connections a node can have * @param beamWidth the size of the beam search to use when finding nearest neighbors. * @param neighborOverflow the ratio of extra neighbors to allow temporarily when inserting a @@ -212,6 +216,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, * ordinals, using the given hyperparameter settings, and returns the resulting graph. 
* * @param scoreProvider describes how to determine the similarities between vectors + * @param dimension the dimensionality of the vectors in the graph * @param M the maximum number of connections a node can have * @param beamWidth the size of the beam search to use when finding nearest neighbors. * @param neighborOverflow the ratio of extra neighbors to allow temporarily when inserting a @@ -245,6 +250,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, * Default executor pools are used. * * @param scoreProvider describes how to determine the similarities between vectors + * @param dimension the dimensionality of the vectors in the graph * @param maxDegrees the maximum number of connections a node can have in each layer; if fewer entries * * are specified than the number of layers, the last entry is used for all remaining layers. * @param beamWidth the size of the beam search to use when finding nearest neighbors. @@ -273,6 +279,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, * ordinals, using the given hyperparameter settings, and returns the resulting graph. * * @param scoreProvider describes how to determine the similarities between vectors + * @param dimension the dimensionality of the vectors in the graph * @param maxDegrees the maximum number of connections a node can have in each layer; if fewer entries * are specified than the number of layers, the last entry is used for all remaining layers. * @param beamWidth the size of the beam search to use when finding nearest neighbors. @@ -338,7 +345,14 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, this.rng = new Random(0); } - // used by Cassandra when it fine-tunes the PQ codebook + /** + * Creates a new GraphIndexBuilder by rescoring an existing graph with a different score provider. + * Used by Cassandra when it fine-tunes the PQ codebook. 
+ * + * @param other the existing GraphIndexBuilder to rescore + * @param newProvider the new score provider to use for rescoring + * @return a new GraphIndexBuilder with the same graph structure but new scores + */ public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvider newProvider) { var newBuilder = new GraphIndexBuilder(newProvider, other.dimension, @@ -384,6 +398,12 @@ public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvi return newBuilder; } + /** + * Builds a complete graph index from the given vector values. + * + * @param ravv the random access vector values to build the index from + * @return an immutable graph index containing all vectors + */ public ImmutableGraphIndex build(RandomAccessVectorValues ravv) { var vv = ravv.threadLocalSupplier(); int size = ravv.size(); @@ -489,6 +509,11 @@ private void improveConnections(int node) { } } + /** + * Returns the current state of the graph being built. + * + * @return the mutable graph index + */ public ImmutableGraphIndex getGraph() { return graph; } @@ -497,11 +522,21 @@ public ImmutableGraphIndex getGraph() { * Number of inserts in progress, across all threads. Useful as a sanity check * when calling non-threadsafe methods like cleanup(). (Do not use it to try to * _prevent_ races, only to detect them.) + * + * @return the number of insertions currently in progress */ public int insertsInProgress() { return insertionsInProgress.size(); } + /** + * Adds a node to the graph by retrieving its vector from the given vector values. 
+ * + * @param node the node identifier to add + * @param ravv the vector values to retrieve the vector from + * @return an estimate of the number of extra bytes used by the graph after adding the node + * @deprecated Use {@link #addGraphNode(int, VectorFloat)} directly + */ @Deprecated public long addGraphNode(int node, RandomAccessVectorValues ravv) { return addGraphNode(node, ravv.getVector(node)); @@ -622,11 +657,23 @@ private void updateNeighborsOneLayer(int level, int node, NodeScore[] neighbors, updateNeighbors(level, node, natural, concurrent); } + /** + * Sets the entry point of the graph to a specific node at a given level. + * This method is visible for testing purposes only. + * + * @param level the level of the entry point + * @param node the node identifier to use as the entry point + */ @VisibleForTesting public void setEntryPoint(int level, int node) { graph.updateEntryNode(new NodeAtLevel(level, node)); } + /** + * Marks a node as deleted. The node will be removed from the graph during the next cleanup. + * + * @param node the node identifier to mark as deleted + */ public void markNodeDeleted(int node) { graph.markDeleted(node); } @@ -797,19 +844,42 @@ public void close() throws IOException { } } + /** + * A Bits implementation that excludes a single specified index. + * Used during graph construction to exclude the node being inserted from neighbor candidates. + */ private static class ExcludingBits implements Bits { private final int excluded; + /** + * Creates a new ExcludingBits that excludes the specified index. + * + * @param excluded the index to exclude + */ public ExcludingBits(int excluded) { this.excluded = excluded; } + /** + * Returns true if the index is not the excluded index. + * + * @param index the index to check + * @return true if the index is not excluded, false otherwise + */ @Override public boolean get(int index) { return index != excluded; } } + /** + * Loads a graph from the given input stream. 
+ * + * @param in the input stream to read from + * @throws IOException if an I/O error occurs + * @throws IllegalStateException if the graph is not empty + * @deprecated This method is deprecated and will be removed in a future version + */ @Deprecated public void load(RandomAccessReader in) throws IOException { if (graph.size(0) != 0) { @@ -829,6 +899,14 @@ public void load(RandomAccessReader in) throws IOException { } } + /** + * Loads a version 4 format graph from the given input stream. + * + * @param in the input stream to read from + * @throws IOException if an I/O error occurs + * @throws IllegalStateException if the graph is not empty + * @deprecated This method is deprecated and will be removed in a future version + */ @Deprecated private void loadV4(RandomAccessReader in) throws IOException { if (graph.size(0) != 0) { @@ -876,6 +954,15 @@ private void loadV4(RandomAccessReader in) throws IOException { graph.updateEntryNode(new NodeAtLevel(graph.getMaxLevel(), entryNode)); } + /** + * Loads a version 3 format graph from the given input stream. 
+ * + * @param in the input stream to read from + * @param size the number of nodes in the graph + * @throws IOException if an I/O error occurs + * @throws IllegalStateException if the graph is not empty + * @deprecated This method is deprecated and will be removed in a future version + */ @Deprecated private void loadV3(RandomAccessReader in, int size) throws IOException { if (graph.size() != 0) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java index 60a91c9ae..98d241d27 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java @@ -69,14 +69,18 @@ public class GraphSearcher implements Closeable { private int expandedCountBaseLayer; /** - * Creates a new graph searcher from the given GraphIndex + * Creates a new graph searcher from the given GraphIndex. + * + * @param graph the graph index to search */ public GraphSearcher(ImmutableGraphIndex graph) { this(graph.getView()); } /** - * Creates a new graph searcher from the given GraphIndex.View + * Creates a new graph searcher from the given GraphIndex.View. + * + * @param view the view of the graph index to search */ protected GraphSearcher(ImmutableGraphIndex.View view) { this.view = view; @@ -90,14 +94,29 @@ protected GraphSearcher(ImmutableGraphIndex.View view) { this.scoreTrackerFactory = new ScoreTracker.ScoreTrackerFactory(); } + /** + * Returns the number of nodes visited during the most recent search. + * + * @return the count of visited nodes + */ protected int getVisitedCount() { return visitedCount; } + /** + * Returns the number of nodes expanded during the most recent search. 
+ * + * @return the count of expanded nodes + */ protected int getExpandedCount() { return expandedCount; } + /** + * Returns the number of nodes expanded in the base layer during the most recent search. + * + * @return the count of expanded nodes in the base layer + */ protected int getExpandedCountBaseLayer() { return expandedCountBaseLayer; } @@ -112,6 +131,11 @@ private void initializeScoreProvider(SearchScoreProvider scoreProvider) { cachingReranker = new CachingReranker(scoreProvider); } + /** + * Returns the current view of the graph being searched. + * + * @return the graph index view + */ public ImmutableGraphIndex.View getView() { return view; } @@ -128,6 +152,14 @@ public void usePruning(boolean usage) { /** * Convenience function for simple one-off searches. It is caller's responsibility to make sure that it * is the unique owner of the vectors instance passed in here. + * + * @param queryVector the query vector to search for + * @param topK the number of nearest neighbors to return + * @param vectors the vector values to search + * @param similarityFunction the similarity function to use + * @param graph the graph index to search + * @param acceptOrds a Bits instance indicating which nodes are acceptable results + * @return a SearchResult containing the topK results and search statistics */ public static SearchResult search(VectorFloat queryVector, int topK, RandomAccessVectorValues vectors, VectorSimilarityFunction similarityFunction, ImmutableGraphIndex graph, Bits acceptOrds) { try (var searcher = new GraphSearcher(graph)) { @@ -141,6 +173,15 @@ public static SearchResult search(VectorFloat queryVector, int topK, RandomAc /** * Convenience function for simple one-off searches. It is caller's responsibility to make sure that it * is the unique owner of the vectors instance passed in here. 
+ * + * @param queryVector the query vector to search for + * @param topK the number of nearest neighbors to return + * @param rerankK the number of candidates to rerank before returning topK results + * @param vectors the vector values to search + * @param similarityFunction the similarity function to use + * @param graph the graph index to search + * @param acceptOrds a Bits instance indicating which nodes are acceptable results + * @return a SearchResult containing the topK results and search statistics */ public static SearchResult search(VectorFloat queryVector, int topK, int rerankK, RandomAccessVectorValues vectors, VectorSimilarityFunction similarityFunction, ImmutableGraphIndex graph, Bits acceptOrds) { try (var searcher = new GraphSearcher(graph)) { @@ -165,26 +206,46 @@ public void setView(ImmutableGraphIndex.View view) { } /** - * Call GraphSearcher constructor instead + * Call GraphSearcher constructor instead. + * + * @deprecated Use {@link GraphSearcher#GraphSearcher(ImmutableGraphIndex.View)} directly */ @Deprecated public static class Builder { private final ImmutableGraphIndex.View view; + /** + * Creates a new Builder for constructing a GraphSearcher. + * + * @param view the view of the graph index to search + */ public Builder(ImmutableGraphIndex.View view) { this.view = view; } + /** + * Configures the builder for concurrent updates. This method is deprecated and has no effect. + * + * @return this builder instance + * @deprecated This method has no effect and will be removed + */ public Builder withConcurrentUpdates() { return this; } + /** + * Builds and returns a new GraphSearcher instance. + * + * @return a new GraphSearcher + */ public GraphSearcher build() { return new GraphSearcher(view); } } /** + * Performs a comprehensive search of the graph with full control over reranking and thresholding. 
+ * * @param scoreProvider provides functions to return the similarity of a given node to the query vector * @param topK the number of results to look for. With threshold=0, the search will continue until at least * `topK` results have been found, or until the entire graph has been searched. @@ -269,6 +330,8 @@ protected void internalSearch(SearchScoreProvider scoreProvider, } /** + * Performs a search of the graph with threshold-based filtering. + * * @param scoreProvider provides functions to return the similarity of a given node to the query vector * @param topK the number of results to look for. With threshold=0, the search will continue until at least * `topK` results have been found, or until the entire graph has been searched. @@ -291,6 +354,8 @@ public SearchResult search(SearchScoreProvider scoreProvider, /** + * Performs a basic search of the graph for the top K nearest neighbors. + * * @param scoreProvider provides functions to return the similarity of a given node to the query vector * @param topK the number of results to look for. With threshold=0, the search will continue until at least * `topK` results have been found, or until the entire graph has been searched. @@ -307,6 +372,10 @@ public SearchResult search(SearchScoreProvider scoreProvider, return search(scoreProvider, topK, 0.0f, acceptOrds); } + /** + * Updates entry points for the next layer by pushing current candidates back onto the queue. + * This allows reusing computed similarities and potentially connecting to more distant clusters. + */ void setEntryPointsFromPreviousLayer() { // push the candidates seen so far back onto the queue for the next layer // at worst we save recomputing the similarity; at best we might connect to a more distant cluster @@ -316,6 +385,13 @@ void setEntryPointsFromPreviousLayer() { approximateResults.clear(); } + /** + * Initializes the internal state for a new search operation. 
+ * + * @param scoreProvider provides functions to compute similarity scores + * @param entry the entry point node and level to start the search + * @param rawAcceptOrds a Bits instance indicating which nodes are acceptable results + */ void initializeInternal(SearchScoreProvider scoreProvider, NodeAtLevel entry, Bits rawAcceptOrds) { // save search parameters for potential later resume initializeScoreProvider(scoreProvider); @@ -462,6 +538,14 @@ private void searchLayer0(int topK, int rerankK, float threshold) { searchOneLayer(scoreProvider, rerankK, threshold, 0, acceptOrds); } + /** + * Performs reranking of the approximate search results to produce the final topK results. + * + * @param topK the number of final results to return + * @param rerankK the number of approximate results to consider for reranking + * @param rerankFloor the minimum approximate score threshold for reranking + * @return a SearchResult containing the topK results after reranking + */ SearchResult reranking(int topK, int rerankK, float rerankFloor) { // rerank results assert approximateResults.size() <= rerankK; @@ -500,6 +584,15 @@ SearchResult reranking(int topK, int rerankK, float rerankFloor) { return new SearchResult(nodes, visitedCount, expandedCount, expandedCountBaseLayer, reranked, worstApproximateInTopK); } + /** + * Resumes a previous search to find additional results. 
+ * + * @param topK the number of final results to return + * @param rerankK the number of approximate results to consider for reranking + * @param threshold the minimum similarity threshold for accepting results + * @param rerankFloor the minimum approximate score threshold for reranking + * @return a SearchResult containing the additional topK results + */ SearchResult resume(int topK, int rerankK, float threshold, float rerankFloor) { searchLayer0(topK, rerankK, threshold); return reranking(topK, rerankK, rerankFloor); @@ -531,6 +624,10 @@ private void addTopCandidate(int topCandidateNode, float topCandidateScore, int * `search`, but `resume` may be called as many times as desired once the search is initialized. *

* SearchResult.visitedCount resets with each call to `search` or `resume`. + * + * @param additionalK the number of additional results to find + * @param rerankK the number of approximate results to consider for reranking + * @return a SearchResult containing the additional results and search statistics */ @Experimental public SearchResult resume(int additionalK, int rerankK) { @@ -545,6 +642,10 @@ public void close() throws IOException { view.close(); } + /** + * A caching wrapper for exact score functions that memoizes computed similarities. + * This cache persists across resume() calls to avoid recomputing the same scores. + */ private static class CachingReranker implements ScoreFunction.ExactScoreFunction { // this cache never gets cleared out (until a new search reinitializes it), // but we expect resume() to be called at most a few times so it's fine @@ -552,12 +653,23 @@ private static class CachingReranker implements ScoreFunction.ExactScoreFunction private final SearchScoreProvider scoreProvider; private int rerankCalls; + /** + * Creates a new CachingReranker that wraps the reranker from the given score provider. + * + * @param scoreProvider the score provider whose reranker will be cached + */ public CachingReranker(SearchScoreProvider scoreProvider) { this.scoreProvider = scoreProvider; cachedScores = new Int2ObjectHashMap<>(); rerankCalls = 0; } + /** + * Returns the exact similarity to the given node, using the cached value if available. + * + * @param node2 the node to compute similarity to + * @return the exact similarity score + */ @Override public float similarityTo(int node2) { if (cachedScores.containsKey(node2)) { @@ -569,6 +681,11 @@ public float similarityTo(int node2) { return score; } + /** + * Returns the number of times the underlying reranker was called (cache misses). 
+ * + * @return the count of rerank calls + */ public int getRerankCalls() { return rerankCalls; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java index 088f9a1af..0341e6737 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java @@ -47,7 +47,12 @@ * in a View that should be created per accessing thread. */ public interface ImmutableGraphIndex extends AutoCloseable, Accountable { - /** Returns the number of nodes in the graph */ + /** + * Returns the number of nodes in the graph. + * + * @return the number of nodes in the graph + * @deprecated Use {@link #size(int)} with level 0 instead + */ @Deprecated default int size() { return size(0); @@ -57,6 +62,7 @@ default int size() { * Get all node ordinals included in the graph. The nodes are NOT guaranteed to be * presented in any particular order. * + * @param level the level of the graph to get nodes from * @return an iterator over nodes where {@code nextInt} returns the next node. */ NodesIterator getNodes(int level); @@ -70,17 +76,31 @@ default int size() { * concurrently modified. Thus, it is good (and encouraged) to re-use Views for * on-disk, read-only graphs, but for in-memory graphs, it is better to create a new * View per search. + * + * @return a View for navigating the graph */ View getView(); /** - * @return the maximum number of edges per node across any layer + * Returns the maximum number of edges per node across any layer. + * + * @return the maximum degree */ int maxDegree(); + /** + * Returns a list of maximum degrees for each layer of the graph. + * If fewer entries are specified than the number of layers, the last entry applies to all remaining layers. 
+ * + * @return a list of maximum degrees per layer + */ List maxDegrees(); /** + * Returns the first ordinal greater than all node ids in the graph. Equal to size() in simple cases. + * May be different from size() if nodes are being added concurrently, or if nodes have been + * deleted (and cleaned up). + * * @return the first ordinal greater than all node ids in the graph. Equal to size() in simple cases; * May be different from size() if nodes are being added concurrently, or if nodes have been * deleted (and cleaned up). @@ -90,16 +110,26 @@ default int getIdUpperBound() { } /** + * Returns true if and only if the graph contains the node with the given ordinal id. + * + * @param nodeId the node identifier to check * @return true iff the graph contains the node with the given ordinal id */ default boolean containsNode(int nodeId) { return nodeId >= 0 && nodeId < size(); } + /** + * Closes this graph index and releases any resources. + * + * @throws IOException if an I/O error occurs + */ @Override void close() throws IOException; /** + * Returns the maximum (coarser) level that contains a vector in the graph. + * * @return The maximum (coarser) level that contains a vector in the graph. */ int getMaxLevel(); @@ -134,18 +164,26 @@ interface View extends Closeable { /** * Iterator over the neighbors of a given node. Only the most recently instantiated iterator * is guaranteed to be valid. + * + * @param level the level of the graph + * @param node the node whose neighbors to iterate + * @return an iterator over the neighbors of the node */ NodesIterator getNeighborsIterator(int level, int node); /** * This method is deprecated as most View usages should not need size. * Where they do, they could access the graph. + * * @return the number of nodes in the graph + * @deprecated Use the graph's size() method instead */ @Deprecated int size(); /** + * Returns the node of the graph to start searches at. 
+ * * @return the node of the graph to start searches at */ NodeAtLevel entryNode(); @@ -153,10 +191,14 @@ interface View extends Closeable { /** * Return a Bits instance indicating which nodes are live. The result is undefined for * ordinals that do not correspond to nodes in the graph. + * + * @return a Bits instance indicating which nodes are live */ Bits liveNodes(); /** + * Returns the largest ordinal id in the graph. May be different from size() if nodes have been deleted. + * * @return the largest ordinal id in the graph. May be different from size() if nodes have been deleted. */ default int getIdUpperBound() { @@ -165,6 +207,10 @@ default int getIdUpperBound() { /** * Whether the given node is present in the given layer of the graph. + * + * @param level the level to check + * @param node the node to check + * @return true if the node is present in the layer, false otherwise */ boolean contains(int level, int node); } @@ -174,10 +220,31 @@ default int getIdUpperBound() { * except for OnHeapGraphIndex.ConcurrentGraphIndexView.) */ interface ScoringView extends View { + /** + * Returns an exact score function for reranking results. + * + * @param queryVector the query vector to compute scores against + * @param vsf the vector similarity function to use + * @return an exact score function + */ ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat queryVector, VectorSimilarityFunction vsf); + + /** + * Returns an approximate score function for initial candidate scoring. + * + * @param queryVector the query vector to compute scores against + * @param vsf the vector similarity function to use + * @return an approximate score function + */ ScoreFunction.ApproximateScoreFunction approximateScoreFunctionFor(VectorFloat queryVector, VectorSimilarityFunction vsf); } + /** + * Returns a human-readable string representation of the graph structure showing all nodes and their neighbors. 
+ * + * @param graph the graph index to format + * @return a formatted string representation of the graph + */ static String prettyPrint(ImmutableGraphIndex graph) { StringBuilder sb = new StringBuilder(); sb.append(graph); @@ -204,11 +271,22 @@ static String prettyPrint(ImmutableGraphIndex graph) { return sb.toString(); } - // Comparable b/c it gets used in ConcurrentSkipListMap + /** + * Represents a node at a specific level in the hierarchical graph structure. + * Comparable to support use in ConcurrentSkipListMap. + */ final class NodeAtLevel implements Comparable { + /** The level in the hierarchy where this node exists */ public final int level; + /** The node identifier */ public final int node; + /** + * Creates a new NodeAtLevel instance. + * + * @param level the level in the hierarchy (must be non-negative) + * @param node the node identifier (must be non-negative) + */ public NodeAtLevel(int level, int node) { assert level >= 0 : level; assert node >= 0 : node; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MapRandomAccessVectorValues.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MapRandomAccessVectorValues.java index b95456903..feceb7f60 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MapRandomAccessVectorValues.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MapRandomAccessVectorValues.java @@ -33,6 +33,12 @@ public class MapRandomAccessVectorValues implements RandomAccessVectorValues { private final Map> map; private final int dimension; + /** + * Constructs a MapRandomAccessVectorValues with the given map and dimension. 
+ * + * @param map the map from node IDs to vectors + * @param dimension the dimension of the vectors + */ public MapRandomAccessVectorValues(Map> map, int dimension) { this.map = map; this.dimension = dimension; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java index 9650cece6..4991f53eb 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java @@ -41,25 +41,43 @@ * i.e. the most-similar nodes are first. */ public class NodeArray { + /** An empty NodeArray singleton. */ public static final NodeArray EMPTY = new NodeArray(0); private int size; private float[] scores; private int[] nodes; + /** + * Constructs a NodeArray with the specified initial capacity. + * + * @param initialSize the initial capacity for node and score arrays + */ public NodeArray(int initialSize) { nodes = new int[initialSize]; scores = new float[initialSize]; } - // this idiosyncratic constructor exists for the benefit of subclass ConcurrentNeighborMap + /** + * Constructs a NodeArray that shares the internal arrays from another NodeArray. + * This idiosyncratic constructor exists for the benefit of subclass ConcurrentNeighborMap. + * + * @param nodeArray the NodeArray to share internal arrays from + */ protected NodeArray(NodeArray nodeArray) { this.size = nodeArray.size(); this.nodes = nodeArray.nodes; this.scores = nodeArray.scores; } - /** always creates a new NodeArray to return, even when a1 or a2 is empty */ + /** + * Merges two NodeArrays into a new NodeArray, removing duplicate nodes. + * Always creates a new NodeArray to return, even when a1 or a2 is empty. 
+ * + * @param a1 the first NodeArray to merge + * @param a2 the second NodeArray to merge + * @return a new NodeArray containing all unique nodes from both arrays, sorted by score + */ static NodeArray merge(NodeArray a1, NodeArray a2) { NodeArray merged = new NodeArray(a1.size() + a2.size()); int i = 0, j = 0; @@ -143,8 +161,11 @@ static NodeArray merge(NodeArray a1, NodeArray a2) { } /** - * Add a new node to the NodeArray. The new node must be worse than all previously stored - * nodes. + * Adds a new node to the NodeArray. The new node must be worse than all previously stored + * nodes (i.e., have a score less than or equal to the last node's score). + * + * @param newNode the node id to add + * @param newScore the score of the node */ public void addInOrder(int newNode, float newScore) { if (size == nodes.length) { @@ -166,6 +187,10 @@ public void addInOrder(int newNode, float newScore) { /** * Returns the index at which the given node should be inserted to maintain sorted order, * or -1 if the node already exists in the array (with the same score). + * + * @param newNode the node ID to find insertion point for + * @param newScore the score of the node + * @return the insertion index, or -1 if a duplicate exists */ int insertionPoint(int newNode, float newScore) { int insertionPoint = descSortFindRightMostInsertionPoint(newScore); @@ -173,9 +198,11 @@ int insertionPoint(int newNode, float newScore) { } /** - * Add a new node to the NodeArray into a correct sort position according to its score. + * Adds a new node to the NodeArray into a correct sort position according to its score. * Duplicate node + score pairs are ignored. 
* + * @param newNode the node id to insert + * @param newScore the score of the node * @return the insertion point of the new node, or -1 if it already existed */ public int insertSorted(int newNode, float newScore) { @@ -191,7 +218,11 @@ public int insertSorted(int newNode, float newScore) { } /** - * Add a new node to the NodeArray into the specified insertion point. + * Adds a new node to the NodeArray at the specified insertion point. + * + * @param insertionPoint the index at which to insert the node + * @param newNode the node ID to insert + * @param newScore the score of the node */ void insertAt(int insertionPoint, int newNode, float newScore) { if (size == nodes.length) { @@ -200,6 +231,14 @@ void insertAt(int insertionPoint, int newNode, float newScore) { insertInternal(insertionPoint, newNode, newScore); } + /** + * Performs the actual insertion of a node at the given index. + * + * @param insertionPoint the index at which to insert + * @param newNode the node ID + * @param newScore the score + * @return the insertion point + */ private int insertInternal(int insertionPoint, int newNode, float newScore) { System.arraycopy(nodes, insertionPoint, nodes, insertionPoint + 1, size - insertionPoint); System.arraycopy(scores, insertionPoint, scores, insertionPoint + 1, size - insertionPoint); @@ -209,6 +248,14 @@ private int insertInternal(int insertionPoint, int newNode, float newScore) { return insertionPoint; } + /** + * Checks if a duplicate node with the same score exists near the insertion point. 
+ * + * @param insertionPoint the index to check around + * @param newNode the node ID to check for + * @param newScore the score to check for + * @return true if a duplicate exists, false otherwise + */ private boolean duplicateExistsNear(int insertionPoint, int newNode, float newScore) { // Check to the left for (int i = insertionPoint - 1; i >= 0 && scores[i] == newScore; i--) { @@ -230,9 +277,8 @@ private boolean duplicateExistsNear(int insertionPoint, int newNode, float newSc /** * Retains only the elements in the current NodeArray whose corresponding index * is set in the given BitSet. - *

+ * * This modifies the array in place, preserving the relative order of the elements retained. - *

* * @param selected A BitSet where the bit at index i is set if the i-th element should be retained. * (Thus, the elements of selected represent positions in the NodeArray, NOT node ids.) @@ -255,10 +301,21 @@ public void retain(Bits selected) { size = writeIdx; } + /** + * Creates a copy of this NodeArray with the same capacity. + * + * @return a new NodeArray with the same size and contents + */ public NodeArray copy() { return copy(size); } + /** + * Creates a copy of this NodeArray with the specified capacity. + * + * @param newSize the capacity of the new array (must be greater than or equal to current size) + * @return a new NodeArray with the specified capacity + */ public NodeArray copy(int newSize) { if (size > newSize) { throw new IllegalArgumentException(String.format("Cannot copy %d nodes to a smaller size %d", size, newSize)); @@ -271,23 +328,42 @@ public NodeArray copy(int newSize) { return copy; } + /** + * Grows the internal arrays to accommodate more nodes. + */ protected final void growArrays() { nodes = ArrayUtil.grow(nodes); scores = ArrayUtil.growExact(scores, nodes.length); } + /** + * Returns the number of nodes currently stored in this array. + * + * @return the size of this array + */ public int size() { return size; } + /** + * Removes all nodes from this array. + */ public void clear() { size = 0; } + /** + * Removes the last node from this array. + */ public void removeLast() { size--; } + /** + * Removes the node at the specified index. + * + * @param idx the index of the node to remove + */ public void removeIndex(int idx) { System.arraycopy(nodes, idx + 1, nodes, idx, size - idx - 1); System.arraycopy(scores, idx + 1, scores, idx, size - idx - 1); @@ -305,6 +381,12 @@ public String toString() { return sb.toString(); } + /** + * Finds the rightmost insertion point for a score in the descending-sorted array. 
+ * + * @param newScore the score to find insertion point for + * @return the index where the score should be inserted + */ protected final int descSortFindRightMostInsertionPoint(float newScore) { int start = 0; int end = size - 1; @@ -316,6 +398,12 @@ protected final int descSortFindRightMostInsertionPoint(float newScore) { return start; } + /** + * Estimates the RAM usage in bytes for a NodeArray of the given size. + * + * @param size the number of nodes + * @return estimated RAM usage in bytes + */ public static long ramBytesUsed(int size) { int REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF; int AH_BYTES = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; @@ -329,7 +417,11 @@ public static long ramBytesUsed(int size) { } /** + * Checks if the given node is present in this array. * Caution! This performs a linear scan. + * + * @param node the node ID to search for + * @return true if the node is present, false otherwise */ @VisibleForTesting boolean contains(int node) { @@ -341,33 +433,64 @@ boolean contains(int node) { return false; } + /** + * Returns a dense copy of the nodes array containing only the active elements. + * + * @return a copy of the nodes array with length equal to size + */ @VisibleForTesting int[] copyDenseNodes() { return Arrays.copyOf(nodes, size); } + /** + * Returns a dense copy of the scores array containing only the active elements. + * + * @return a copy of the scores array with length equal to size + */ @VisibleForTesting float[] copyDenseScores() { return Arrays.copyOf(scores, size); } /** - * Insert a new node, without growing the array. If the array is full, drop the worst existing node to make room. + * Inserts a new node without growing the array. If the array is full, drops the worst existing node to make room. * (Even if the worst existing one is better than newNode!) 
+ * + * @param newNode the node id to insert + * @param newScore the score of the node + * @return the insertion point of the new node, or -1 if it already existed */ protected int insertOrReplaceWorst(int newNode, float newScore) { size = min(size, nodes.length - 1); return insertSorted(newNode, newScore); } + /** + * Returns the score at the specified index. + * + * @param i the index + * @return the score at index i + */ public float getScore(int i) { return scores[i]; } + /** + * Returns the node ID at the specified index. + * + * @param i the index + * @return the node ID at index i + */ public int getNode(int i) { return nodes[i]; } + /** + * Returns the capacity of the internal arrays. + * + * @return the length of the internal arrays + */ protected int getArrayLength() { return nodes.length; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java index f35761bf9..41afa457e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java @@ -40,6 +40,9 @@ * or unbounded operations, depending on the implementation subclasses, and either maxheap or minheap behavior. */ public class NodeQueue { + /** + * Ordering for the heap: MIN_HEAP keeps smallest values at the top, MAX_HEAP keeps largest values at the top. + */ public enum Order { /** Smallest values at the top of the heap */ MIN_HEAP { @@ -64,12 +67,20 @@ long apply(long v) { private final AbstractLongHeap heap; private final Order order; + /** + * Constructs a NodeQueue with the specified heap and ordering. + * + * @param heap the underlying heap to store encoded node/score pairs + * @param order the heap ordering (MIN_HEAP or MAX_HEAP) + */ public NodeQueue(AbstractLongHeap heap, Order order) { this.heap = heap; this.order = order; } /** + * Returns the number of elements in the heap. 
+ * * @return the number of elements in the heap */ public int size() { @@ -128,20 +139,40 @@ private long encode(int node, float score) { (((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node)); } + /** + * Decodes the score from the encoded heap value. + * + * @param heapValue the encoded long value from the heap + * @return the decoded score + */ private float decodeScore(long heapValue) { return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32)); } + /** + * Decodes the node ID from the encoded heap value. + * + * @param heapValue the encoded long value from the heap + * @return the decoded node ID + */ private int decodeNodeId(long heapValue) { return (int) ~(order.apply(heapValue)); } - /** Removes the top element and returns its node id. */ + /** + * Removes the top element and returns its node id. + * + * @return the node ID of the top element + */ public int pop() { return decodeNodeId(heap.pop()); } - /** Returns a copy of the internal nodes array. Not sorted by score! */ + /** + * Returns a copy of the internal nodes array. Not sorted by score! + * + * @return an array of node IDs in heap order (not score order) + */ public int[] nodesCopy() { int size = size(); int[] nodes = new int[size]; @@ -152,10 +183,17 @@ public int[] nodesCopy() { } /** - * Rerank results and return the worst approximate score that made it into the topK. - * The topK results will be placed into `reranked`, and the remainder into `unused`. + * Reranks results and returns the worst approximate score that made it into the topK. + * The topK results will be placed into {@code reranked}, and the remainder into {@code unused}. *

- * Only the best result or results whose approximate score is at least `rerankFloor` will be reranked. + * Only the best result or results whose approximate score is at least {@code rerankFloor} will be reranked. + * + * @param topK the number of top results to rerank + * @param reranker the exact score function to use for reranking + * @param rerankFloor the minimum approximate score threshold for reranking + * @param reranked the queue to receive the reranked top results + * @param unused the collection to receive nodes that were not included in the top results + * @return the worst approximate score among the topK results */ public float rerank(int topK, ScoreFunction.ExactScoreFunction reranker, float rerankFloor, NodeQueue reranked, NodesUnsorted unused) { // Rescore the nodes whose approximate score meets the floor. Nodes that do not will be marked as -1 @@ -229,7 +267,11 @@ public float rerank(int topK, ScoreFunction.ExactScoreFunction reranker, float r return worstApproximateInTopK; } - /** Returns the top element's node id. */ + /** + * Returns the top element's node id. + * + * @return the node ID of the top element + */ public int topNode() { return decodeNodeId(heap.top()); } @@ -237,17 +279,25 @@ public int topNode() { /** * Returns the top element's node score. For the min heap this is the minimum score. For the max * heap this is the maximum score. + * + * @return the score of the top element */ public float topScore() { return decodeScore(heap.top()); } + /** + * Removes all elements from this queue. + */ public void clear() { heap.clear(); } /** - * Set the max size of the underlying heap. Only valid when NodeQueue was created with BoundedLongHeap. + * Sets the maximum size of the underlying heap. Only valid when NodeQueue was created with BoundedLongHeap. 
+ * + * @param maxSize the new maximum size for the heap + * @throws ClassCastException if the underlying heap is not a BoundedLongHeap */ public void setMaxSize(int maxSize) { ((BoundedLongHeap) heap).setMaxSize(maxSize); @@ -258,6 +308,12 @@ public String toString() { return "Nodes[" + heap.size() + "]"; } + /** + * Applies the given consumer to each node/score pair in this queue. + * The order of iteration is not guaranteed to be sorted by score. + * + * @param consumer the consumer to apply to each node/score pair + */ public void foreach(NodeConsumer consumer) { for (int i = 0; i < heap.size(); i++) { long heapValue = heap.get(i + 1); @@ -265,27 +321,51 @@ public void foreach(NodeConsumer consumer) { } } + /** + * A consumer that accepts node ID and score pairs. + */ @FunctionalInterface public interface NodeConsumer { + /** + * Accepts a node ID and its associated score. + * + * @param node the node ID + * @param score the score associated with the node + */ void accept(int node, float score); } - /** Iterator over node and score pairs. */ + /** + * Iterator over node and score pairs. + */ public interface NodeScoreIterator { - /** @return true if there are more elements */ + /** + * Checks if there are more elements to iterate over. + * + * @return true if there are more elements, false otherwise + */ boolean hasNext(); - /** @return the next node id and advance the iterator */ + /** + * Returns the next node ID and advances the iterator. + * + * @return the next node ID + */ int pop(); - /** @return the next node score */ + /** + * Returns the score of the next node without advancing the iterator. + * + * @return the next node score + */ float topScore(); } /** * Copies the other NodeQueue to this one. If its order (MIN_HEAP or MAX_HEAP) is the same as this, - * it is copied verbatim. If it differs, every lement is re-inserted into this. - * @param other the other node queue. + * it is copied verbatim. If it differs, every element is re-inserted into this. 
+ * + * @param other the other node queue to copy from */ public void copyFrom(NodeQueue other) { if (this.order == other.order) { @@ -304,6 +384,12 @@ private static class NodeScoreIteratorConverter implements PrimitiveIterator.OfL private final NodeScoreIterator it; private final NodeQueue queue; + /** + * Constructs a converter that wraps the given iterator. + * + * @param it the node score iterator to wrap + * @param queue the node queue used for encoding + */ public NodeScoreIteratorConverter(NodeScoreIterator it, NodeQueue queue) { this.it = it; this.queue = queue; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodesIterator.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodesIterator.java index 83981ee91..546cbbeb2 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodesIterator.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodesIterator.java @@ -34,10 +34,19 @@ */ public interface NodesIterator extends PrimitiveIterator.OfInt { /** - * The number of elements in this iterator * + * Returns the number of elements in this iterator. + * + * @return the size of this iterator */ int size(); + /** + * Creates a NodesIterator from a primitive iterator and size. + * + * @param iterator the primitive iterator to wrap + * @param size the number of elements + * @return a NodesIterator wrapping the given iterator + */ static NodesIterator fromPrimitiveIterator(PrimitiveIterator.OfInt iterator, int size) { return new NodesIterator() { @Override @@ -57,12 +66,20 @@ public boolean hasNext() { }; } + /** + * An iterator over an array of node IDs. + */ class ArrayNodesIterator implements NodesIterator { private final int[] nodes; private int cur = 0; private final int size; - /** Constructor for iterator based on integer array representing nodes */ + /** + * Constructs an iterator based on an integer array representing nodes. 
+ * + * @param nodes the array of node IDs + * @param size the number of valid elements in the array + */ public ArrayNodesIterator(int[] nodes, int size) { assert nodes != null; assert size <= nodes.length; @@ -75,6 +92,11 @@ public int size() { return size; } + /** + * Constructs an iterator for the entire array. + * + * @param nodes the array of node IDs + */ public ArrayNodesIterator(int[] nodes) { this(nodes, nodes.length); } @@ -97,9 +119,18 @@ public boolean hasNext() { } } + /** + * A singleton empty node iterator. + */ EmptyNodeIterator EMPTY_NODE_ITERATOR = new EmptyNodeIterator(); + /** + * An empty node iterator implementation. + */ class EmptyNodeIterator implements NodesIterator { + /** Package-private constructor. */ + EmptyNodeIterator() { + } @Override public int size() { return 0; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodesUnsorted.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodesUnsorted.java index 4597618a8..e3dc37312 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodesUnsorted.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodesUnsorted.java @@ -31,18 +31,26 @@ * NodesUnsorted contains scored node ids in insertion order. */ public class NodesUnsorted { + /** The number of nodes currently stored. */ protected int size; float[] score; int[] node; + /** + * Constructs a NodesUnsorted with the specified initial capacity. + * + * @param initialSize the initial capacity for node and score arrays + */ public NodesUnsorted(int initialSize) { node = new int[initialSize]; score = new float[initialSize]; } /** - * Add a new node to the NodeArray. The new node must be worse than all previously stored - * nodes. + * Adds a new node to this collection in insertion order (unsorted). 
+ * + * @param newNode the node ID to add + * @param newScore the score of the node */ public void add(int newNode, float newScore) { if (size == node.length) { @@ -53,19 +61,35 @@ public void add(int newNode, float newScore) { ++size; } + /** + * Grows the internal arrays to accommodate more nodes. + */ protected final void growArrays() { node = ArrayUtil.grow(node); score = ArrayUtil.growExact(score, node.length); } + /** + * Returns the number of nodes currently stored. + * + * @return the size + */ public int size() { return size; } + /** + * Removes all nodes from this collection. + */ public void clear() { size = 0; } + /** + * Applies the given consumer to each node/score pair in this collection. + * + * @param consumer the consumer to apply to each node/score pair + */ public void foreach(NodeConsumer consumer) { for (int i = 0; i < size; i++) { consumer.accept(node[i], score[i]); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index 7ddbf7897..e2d555e05 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -55,8 +55,8 @@ * For searching, use a view obtained from {@link #getView()} which supports level–aware operations. */ public class OnHeapGraphIndex implements MutableGraphIndex { - // Used for saving and loading OnHeapGraphIndex - public static final int MAGIC = 0x75EC4012; // JVECTOR, with some imagination + /** Magic number used for saving and loading OnHeapGraphIndex (JVECTOR with some imagination). 
*/ + public static final int MAGIC = 0x75EC4012; // The current entry node for searches private final AtomicReference entryPoint; @@ -76,6 +76,13 @@ public class OnHeapGraphIndex implements MutableGraphIndex { private volatile boolean allMutationsCompleted = false; + /** + * Constructs an OnHeapGraphIndex with the specified parameters. + * + * @param maxDegrees the maximum degree for each layer + * @param overflowRatio the multiplicative ratio for temporary overflow during construction + * @param diversityProvider provider for diversity-based neighbor selection + */ OnHeapGraphIndex(List maxDegrees, double overflowRatio, DiversityProvider diversityProvider) { this.overflowRatio = overflowRatio; this.maxDegrees = new IntArrayList(); @@ -140,10 +147,21 @@ public int size(int level) { return layers.get(level).size(); } + /** + * Adds a node at the specified level to the graph. + * + * @param nodeLevel the node and level to add + */ public void addNode(NodeAtLevel nodeLevel) { addNode(nodeLevel.level, nodeLevel.node); } + /** + * Adds a node at the specified level to the graph. + * + * @param level the layer to add the node to + * @param node the node ID to add + */ public void addNode(int level, int node) { ensureLayersExist(level); @@ -377,9 +395,16 @@ public void allMutationsCompleted() { * searches. The View provides a limited kind of snapshot isolation: only nodes completely added * to the graph at the time the View was created will be visible (but the connections between them * are allowed to change, so you could potentially get different top K results from the same query - * if concurrent updates are in progress.) + * if concurrent updates are in progress). */ public class ConcurrentGraphIndexView extends FrozenView { + /** + * Constructs a ConcurrentGraphIndexView with snapshot isolation based on the current completion timestamp. 
+ */ + public ConcurrentGraphIndexView() { + super(); + } + // It is tempting, but incorrect, to try to provide "adequate" isolation by // (1) keeping a bitset of complete nodes and giving that to the searcher as nodes to // accept -- but we need to keep incomplete nodes out of the search path entirely, @@ -441,7 +466,18 @@ public boolean hasNext() { } } + /** + * A frozen view of the graph that provides read-only access without snapshot isolation. + * This view is used when all mutations have been completed and the graph structure is stable. + */ private class FrozenView implements View { + /** + * Constructs a FrozenView for this graph. + */ + FrozenView() { + // Default constructor + } + @Override public NodesIterator getNeighborsIterator(int level, int node) { return OnHeapGraphIndex.this.getNeighborsIterator(level, node); @@ -488,7 +524,10 @@ public String toString() { } /** - * Saves the graph to the given DataOutput for reloading into memory later + * Saves the graph to the given DataOutput for reloading into memory later. + * + * @param out the DataOutput to write the graph to + * @deprecated This method is deprecated and may be removed in a future version */ @Deprecated public void save(DataOutput out) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RandomAccessVectorValues.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RandomAccessVectorValues.java index eb8f6df24..78934c9e3 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RandomAccessVectorValues.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RandomAccessVectorValues.java @@ -37,39 +37,55 @@ * implementations of KNN search. */ public interface RandomAccessVectorValues { + /** Logger for RandomAccessVectorValues operations. */ Logger LOG = Logger.getLogger(RandomAccessVectorValues.class.getName()); /** - * Return the number of vector values. + * Returns the number of vector values. *

* All copies of a given RAVV should have the same size. Typically this is achieved by either - * (1) implementing a threadsafe, un-shared RAVV, where `copy` returns `this`, or + * (1) implementing a threadsafe, un-shared RAVV, where {@code copy} returns {@code this}, or * (2) implementing a fixed-size RAVV. + * + * @return the number of vectors */ int size(); - /** Return the dimension of the returned vector values */ + /** + * Returns the dimension of the returned vector values. + * + * @return the vector dimension + */ int dimension(); /** - * Return the vector value indexed at the given ordinal. + * Returns the vector value indexed at the given ordinal. * *

For performance, implementations are free to re-use the same object across invocations. * That is, you will get back the same VectorFloat<?> * reference (for instance) for every requested ordinal. If you want to use those values across * calls, you should make a copy. * - * @param nodeId a valid ordinal, ≥ 0 and < {@link #size()}. + * @param nodeId a valid ordinal, ≥ 0 and < {@link #size()} + * @return the vector at the given ordinal */ VectorFloat getVector(int nodeId); + /** + * Returns the vector value indexed at the given ordinal. + * + * @deprecated Use {@link #getVector(int)} instead + * @param targetOrd a valid ordinal + * @return the vector at the given ordinal + */ @Deprecated default VectorFloat vectorValue(int targetOrd) { return getVector(targetOrd); } /** - * Retrieve the vector associated with a given node, and store it in the destination vector at the given offset. + * Retrieves the vector associated with a given node and stores it in the destination vector at the given offset. + * * @param node the node to retrieve * @param destinationVector the vector to store the result in * @param offset the offset in the destination vector to store the result @@ -79,7 +95,9 @@ default void getVectorInto(int node, VectorFloat destinationVector, int offse } /** - * @return true iff the vector returned by `getVector` is shared. A shared vector will + * Checks if the vector returned by {@code getVector} is shared. + * + * @return true iff the vector returned by {@code getVector} is shared. A shared vector will * only be valid until the next call to getVector overwrites it. */ boolean isValueShared(); @@ -89,12 +107,16 @@ default void getVectorInto(int node, VectorFloat destinationVector, int offse * access different values at once, to avoid overwriting the underlying float vector returned by * a shared {@link RandomAccessVectorValues#getVector}. *

- * Un-shared implementations may simply return `this`. + * Un-shared implementations may simply return {@code this}. + * + * @return a copy of this RandomAccessVectorValues */ RandomAccessVectorValues copy(); /** * Returns a supplier of thread-local copies of the RAVV. + * + * @return a supplier that provides thread-local copies */ default Supplier threadLocalSupplier() { if (!isValueShared()) { @@ -110,6 +132,10 @@ default Supplier threadLocalSupplier() { /** * Convenience method to create an ExactScoreFunction for reranking. The resulting function is NOT thread-safe. + * + * @param queryVector the query vector + * @param vsf the vector similarity function to use + * @return an ExactScoreFunction for reranking */ default ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat queryVector, VectorSimilarityFunction vsf) { return new ScoreFunction.ExactScoreFunction() { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java index 21be9e3b0..873eadbd5 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java @@ -23,7 +23,15 @@ import static io.github.jbellis.jvector.util.NumericUtils.floatToSortableInt; import static io.github.jbellis.jvector.util.NumericUtils.sortableIntToFloat; +/** + * Interface for tracking similarity scores during graph search to determine when to stop early. + * Implementations can track score distributions to predict when continuing the search is unlikely to improve results. + */ public interface ScoreTracker { + /** + * Factory for creating and reusing ScoreTracker instances. + * Maintains instances of each tracker type to avoid repeated allocation. 
+ */ class ScoreTrackerFactory { private TwoPhaseTracker twoPhaseTracker; private RelaxedMonotonicityTracker relaxedMonotonicityTracker; @@ -35,6 +43,14 @@ class ScoreTrackerFactory { noOpTracker = null; } + /** + * Returns an appropriate ScoreTracker based on the search parameters. + * + * @param pruneSearch whether to prune the search early + * @param rerankK the number of candidates to rerank + * @param threshold the minimum score threshold for threshold queries + * @return a ScoreTracker instance appropriate for the given parameters + */ public ScoreTracker getScoreTracker(boolean pruneSearch, int rerankK, float threshold) { // track scores to predict when we are done with threshold queries final ScoreTracker scoreTracker; @@ -65,13 +81,34 @@ public ScoreTracker getScoreTracker(boolean pruneSearch, int rerankK, float thre } } + /** A no-op tracker instance that never triggers early termination. */ ScoreTracker NO_OP = new NoOpTracker(); + /** + * Records a similarity score observed during search. + * + * @param score the similarity score to track + */ void track(float score); + /** + * Returns whether the search should stop early based on tracked scores. + * + * @return true if the search should stop, false otherwise + */ boolean shouldStop(); + /** + * A no-op tracker that never triggers early termination. + */ class NoOpTracker implements ScoreTracker { + /** + * Constructs a NoOpTracker. 
+ */ + NoOpTracker() { + // Default constructor + } + @Override public void track(float score) { } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/SearchResult.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/SearchResult.java index e1b9e5506..f82431006 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/SearchResult.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/SearchResult.java @@ -23,13 +23,29 @@ * Container class for results of an ANN search, along with associated metrics about the behavior of the search. */ public final class SearchResult { + /** The closest neighbors discovered by the search, sorted best-first. */ private final NodeScore[] nodes; + /** The total number of graph nodes visited while performing the search. */ private final int visitedCount; + /** The total number of graph nodes expanded while performing the search. */ private final int expandedCount; + /** The number of graph nodes expanded while performing the search in the base layer. */ private final int expandedCountL0; + /** The number of nodes that were reranked during the search. */ private final int rerankedCount; + /** The worst approximate score of the top K nodes in the search result. */ private final float worstApproximateScoreInTopK; + /** + * Constructs a SearchResult with the specified search results and metrics. 
+ * + * @param nodes the closest neighbors discovered by the search, sorted best-first + * @param visitedCount the total number of graph nodes visited while performing the search + * @param expandedCount the total number of graph nodes expanded while performing the search + * @param expandedCountL0 the number of graph nodes expanded in the base layer + * @param rerankedCount the number of nodes that were reranked during the search + * @param worstApproximateScoreInTopK the worst approximate score in the top K results, or Float.POSITIVE_INFINITY if no reranking occurred + */ public SearchResult(NodeScore[] nodes, int visitedCount, int expandedCount, int expandedCountL0, int rerankedCount, float worstApproximateScoreInTopK) { this.nodes = nodes; this.visitedCount = visitedCount; @@ -40,6 +56,8 @@ public SearchResult(NodeScore[] nodes, int visitedCount, int expandedCount, int } /** + * Returns the closest neighbors discovered by the search. + * * @return the closest neighbors discovered by the search, sorted best-first */ public NodeScore[] getNodes() { @@ -47,6 +65,8 @@ public NodeScore[] getNodes() { } /** + * Returns the total number of graph nodes visited during the search. + * * @return the total number of graph nodes visited while performing the search */ public int getVisitedCount() { @@ -54,6 +74,8 @@ public int getVisitedCount() { } /** + * Returns the total number of graph nodes expanded during the search. + * * @return the total number of graph nodes expanded while performing the search */ public int getExpandedCount() { @@ -61,6 +83,8 @@ public int getExpandedCount() { } /** + * Returns the number of graph nodes expanded in the base layer during the search. + * * @return the number of graph nodes expanded while performing the search in the base layer */ public int getExpandedCountBaseLayer() { @@ -68,6 +92,8 @@ public int getExpandedCountBaseLayer() { } /** + * Returns the number of nodes that were reranked during the search. 
+ * * @return the number of nodes that were reranked during the search */ public int getRerankedCount() { @@ -75,6 +101,9 @@ public int getRerankedCount() { } /** + * Returns the worst approximate score of the top K nodes in the search result. + * Useful for passing to rerankFloor during search across multiple indexes. + * * @return the worst approximate score of the top K nodes in the search result. Useful * for passing to rerankFloor during search across multiple indexes. Will be * Float.POSITIVE_INFINITY if no reranking was performed or no results were found. @@ -83,10 +112,21 @@ public float getWorstApproximateScoreInTopK() { return worstApproximateScoreInTopK; } + /** + * Represents a node and its associated similarity score in a search result. + */ public static final class NodeScore implements Comparable { + /** The node identifier. */ public final int node; + /** The similarity score for this node. */ public final float score; + /** + * Constructs a NodeScore with the specified node ID and score. + * + * @param node the node identifier + * @param score the similarity score for this node + */ public NodeScore(int node, float score) { this.node = node; this.score = score; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/AbstractGraphIndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/AbstractGraphIndexWriter.java index 761024ff8..84d0df5c0 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/AbstractGraphIndexWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/AbstractGraphIndexWriter.java @@ -29,23 +29,90 @@ import java.util.function.IntFunction; import java.util.stream.Collectors; +/** + * Abstract base class for writing graph indexes to disk in various formats. + *

+ * This writer handles the serialization of graph structure, including nodes, edges, and associated + * features (such as vectors) to a persistent storage format. It supports both inline features + * (written alongside graph nodes) and separated features (written in a dedicated section). + *

+ * The on-disk format consists of:
+ * <ul>
+ *   <li>A header containing metadata and feature information</li>
+ *   <li>Graph nodes with inline features and edge lists</li>
+ *   <li>Sparse levels for hierarchical graphs (if applicable)</li>
+ *   <li>Separated feature data (if any)</li>
+ *   <li>A footer containing the header offset and magic number</li>
+ * </ul>
+ * <p>
+ * Subclasses must implement the specific writing strategy (e.g., sequential or random access).
+ * <p>
+ * Thread safety: This class uses synchronized methods where necessary but is not designed + * for concurrent writes from multiple threads. The {@code maxOrdinalWritten} field is volatile + * to support visibility across threads. + * + * @param the type of {@link IndexWriter} used for output operations + */ public abstract class AbstractGraphIndexWriter implements GraphIndexWriter { + /** Magic number written at the end of the index file to identify valid JVector graph files. */ public static final int FOOTER_MAGIC = 0x4a564244; // "EOF magic" - public static final int FOOTER_OFFSET_SIZE = Long.BYTES; // The size of the offset in the footer - public static final int FOOTER_MAGIC_SIZE = Integer.BYTES; // The size of the magic number in the footer - public static final int FOOTER_SIZE = FOOTER_MAGIC_SIZE + FOOTER_OFFSET_SIZE; // The total size of the footer + + /** Size in bytes of the header offset field in the footer. */ + public static final int FOOTER_OFFSET_SIZE = Long.BYTES; + + /** Size in bytes of the magic number field in the footer. */ + public static final int FOOTER_MAGIC_SIZE = Integer.BYTES; + + /** Total size in bytes of the footer (magic number plus offset). */ + public static final int FOOTER_SIZE = FOOTER_MAGIC_SIZE + FOOTER_OFFSET_SIZE; + + /** The format version number for this graph index. */ final int version; + + /** The immutable graph structure to be written to disk. */ final ImmutableGraphIndex graph; + + /** Maps between original graph ordinals and the ordinals written to disk. */ final OrdinalMapper ordinalMapper; + + /** The dimensionality of the vectors stored in this index. */ final int dimension; - // we don't use Map features but EnumMap is the best way to make sure we don't - // accidentally introduce an ordering bug in the future + + /** + * Map of features to be written with this index. + *

+ * Uses {@code EnumMap} to ensure consistent ordering and avoid ordering bugs, + * even though map-specific features are not utilized. + */ final EnumMap featureMap; - final T out; /* output for graph nodes and inline features */ + + /** Output writer for graph nodes and inline features. */ + final T out; + + /** The size in bytes of the index header. */ final int headerSize; + + /** + * The maximum ordinal that has been written so far, or -1 if no ordinals have been written yet. + *

+ * This field is volatile to ensure visibility across threads. + */ volatile int maxOrdinalWritten = -1; + + /** List of features that are written inline with graph nodes (not separated). */ final List inlineFeatures; + /** + * Constructs an abstract graph index writer with the specified configuration. + * + * @param out the output writer for graph nodes and inline features + * @param version the format version number for this graph index + * @param graph the immutable graph structure to be written to disk + * @param oldToNewOrdinals maps original graph ordinals to new ordinals for writing + * @param dimension the dimensionality of the vectors stored in this index + * @param features map of features to be written with this index + * @throws IllegalArgumentException if attempting to write a multilayer graph with version less than 4 + */ AbstractGraphIndexWriter(T out, int version, ImmutableGraphIndex graph, @@ -72,16 +139,33 @@ public abstract class AbstractGraphIndexWriter implements } /** + * Returns the maximum ordinal that has been written so far. + * * @return the maximum ordinal written so far, or -1 if no ordinals have been written yet */ public int getMaxOrdinal() { return maxOrdinalWritten; } + /** + * Returns the set of feature IDs that will be written with this index. + * + * @return an unmodifiable set of {@link FeatureId} values configured for this writer + */ public Set getFeatureSet() { return featureMap.keySet(); } + /** + * Calculates the byte offset where inline features for a given ordinal begin in the output stream. + *

+ * This calculation accounts for the header size, all previous nodes' data (ordinals, inline features, + * and edges), and the ordinal field of the target node. + * + * @param startOffset the starting offset in the output stream where the graph data begins + * @param ordinal the node ordinal for which to calculate the feature offset + * @return the absolute byte offset where the inline features for the specified ordinal are located + */ long featureOffsetForOrdinal(long startOffset, int ordinal) { int edgeSize = Integer.BYTES * (1 + graph.getDegree(0)); long inlineBytes = ordinal * (long) (Integer.BYTES + inlineFeatures.stream().mapToInt(Feature::featureSize).sum() + edgeSize); @@ -91,15 +175,30 @@ long featureOffsetForOrdinal(long startOffset, int ordinal) { + Integer.BYTES; // the ordinal of the node whose features we're about to write } + /** + * Checks whether a feature should be written separately from the main graph data. + * + * @param feature the feature to check + * @return {@code true} if the feature is a {@link SeparatedFeature}, {@code false} otherwise + */ boolean isSeparated(Feature feature) { return feature instanceof SeparatedFeature; } /** - * @return a Map of old to new graph ordinals where the new ordinals are sequential starting at 0, - * while preserving the original relative ordering in `graph`. That is, for all node ids i and j, - * if i < j in `graph` then map[i] < map[j] in the returned map. "Holes" left by + * Creates a sequential renumbering map that eliminates gaps in ordinal numbering. + *

+ * Returns a map of old to new graph ordinals where the new ordinals are sequential starting at 0, + * while preserving the original relative ordering in the graph. That is, for all node ids i and j, + * if i < j in the graph then map[i] < map[j] in the returned map. "Holes" left by * deleted nodes are filled in by shifting down the new ordinals. + *

+ * This is useful for creating compact on-disk representations where deleted nodes do not + * leave unused space in the ordinal range. + * + * @param graph the immutable graph to renumber + * @return a map from original ordinals to sequential new ordinals + * @throws RuntimeException if an exception occurs while accessing the graph view */ public static Map sequentialRenumbering(ImmutableGraphIndex graph) { try (var view = graph.getView()) { @@ -117,17 +216,22 @@ public static Map sequentialRenumbering(ImmutableGraphIndex gr } /** - * Write the {@link Header} as a footer for the graph index. + * Writes the graph header as a footer at the end of the index file. + *

+ * The footer format enables efficient index reading by storing the header at the end, + * allowing readers to locate and parse metadata without scanning the entire file. *

- * To read the graph later, we will perform the following steps: + * To read the graph later, the following steps are performed: *

 * <ol>
- *   <li>Find the magic number at the end of the slice
- *   <li>Read the header offset from the end of the slice
- *   <li>Read the header
- *   <li>Read the neighbors offsets and graph metadata
+ *   <li>Find the magic number at the end of the file</li>
+ *   <li>Read the header offset from the end of the file</li>
+ *   <li>Seek to the header offset and read the header</li>
+ *   <li>Parse the graph metadata and feature information from the header</li>
 * </ol>
- * @param headerOffset the offset of the header in the slice - * @throws IOException IOException + * + * @param view the graph view containing the entry node and other metadata + * @param headerOffset the byte offset where the header begins in the output stream + * @throws IOException if an I/O error occurs while writing the footer */ void writeFooter(ImmutableGraphIndex.View view, long headerOffset) throws IOException { var layerInfo = CommonHeader.LayerInfo.fromGraph(graph, ordinalMapper); @@ -145,11 +249,19 @@ void writeFooter(ImmutableGraphIndex.View view, long headerOffset) throws IOExce } /** - * Writes the index header, including the graph size, so that OnDiskGraphIndex can open it. - * The output IS flushed. + * Writes the index header at the beginning of the output stream. *

- * Public so that you can write the index size (and thus usefully open an OnDiskGraphIndex against the index) - * to read Features from it before writing the edges. + * The header includes graph metadata such as version, dimension, entry node, layer information, + * and feature configuration. After writing, the output is flushed to ensure the header is + * persisted to disk. + *

+ * This method is public to allow writing the header early in the process, enabling + * {@code OnDiskGraphIndex} to open the index and read features before the edge data + * is fully written. This is useful for incremental or staged writing scenarios. + * + * @param view the graph view containing the entry node and other metadata + * @param startOffset the byte offset in the output stream where the header should begin + * @throws IOException if an I/O error occurs while writing the header */ public synchronized void writeHeader(ImmutableGraphIndex.View view, long startOffset) throws IOException { // graph-level properties @@ -164,6 +276,22 @@ public synchronized void writeHeader(ImmutableGraphIndex.View view, long startOf assert out.position() == startOffset + headerSize : String.format("%d != %d", out.position(), startOffset + headerSize); } + /** + * Writes the sparse levels of a hierarchical graph to the output stream. + *

+ * For graphs with multiple layers (levels > 0), this method writes each sparse level + * sequentially. Each level contains only a subset of nodes that participate in that level + * of the hierarchy. For each node in a level, the method writes: + *

+ * <ul>
+ *   <li>The remapped node ordinal</li>
+ *   <li>The number of neighbors at this level</li>
+ *   <li>The remapped neighbor ordinals, padded to the level's degree with -1 values</li>
+ * </ul>
+ * + * @param view the graph view providing access to nodes and neighbors at each level + * @throws IOException if an I/O error occurs while writing the sparse levels + * @throws IllegalStateException if the number of nodes written does not match the expected layer size + */ void writeSparseLevels(ImmutableGraphIndex.View view) throws IOException { // write sparse levels for (int level = 1; level <= graph.getMaxLevel(); level++) { @@ -195,6 +323,26 @@ void writeSparseLevels(ImmutableGraphIndex.View view) throws IOException { } } + /** + * Writes separated features to dedicated sections in the output stream. + *

+ * Separated features are stored apart from the main graph node data, which can improve + * cache locality and enable more efficient access patterns for certain use cases. + * This method iterates through all features marked as {@link SeparatedFeature} and writes + * their data sequentially for each node ordinal. + *

+ * For each separated feature: + *

+ * <ul>
+ *   <li>Records the current output position as the feature's offset</li>
+ *   <li>Writes the feature data for each node in ordinal order</li>
+ *   <li>Writes zero-padding for ordinals that have been omitted (deleted nodes)</li>
+ * </ul>
+ * + * @param featureStateSuppliers a map from feature IDs to functions that provide feature state + * for a given node ordinal; must contain suppliers for all separated features + * @throws IOException if an I/O error occurs while writing feature data + * @throws IllegalStateException if a supplier is missing for a separated feature + */ void writeSeparatedFeatures(Map> featureStateSuppliers) throws IOException { for (var featureEntry : featureMap.entrySet()) { if (isSeparated(featureEntry.getValue())) { @@ -225,20 +373,55 @@ void writeSeparatedFeatures(Map> featureSt } /** - * Builder for {@link AbstractGraphIndexWriter}, with optional features. + * Builder for constructing {@link AbstractGraphIndexWriter} instances with configurable features. + *

+ * This builder provides a fluent API for configuring graph index writers. It allows specifying: + *

+ * <ul>
+ *   <li>Format version</li>
+ *   <li>Features to include (vectors, compression, etc.)</li>
+ *   <li>Ordinal mapping strategy</li>
+ * </ul>
*

- * Subclasses should implement `reallyBuild` to return the appropriate type. + * The builder performs validation to ensure the requested configuration is compatible with + * the specified format version. For example, version 3 and earlier only support inline vectors, + * while version 4 and later are required for multilayer graphs. *

- * K - the type of the writer to build - * T - the type of the output stream + * Subclasses must implement {@link #reallyBuild(int)} to construct the appropriate writer type. + *

+ * Example usage: + *

+ * <pre>{@code
+     * var writer = new MyGraphIndexWriter.Builder(graph, output)
+     *     .withVersion(4)
+     *     .with(new InlineVectors(dimension))
+     *     .withMapper(ordinalMapper)
+     *     .build();
+     * }</pre>
+ * + * @param the concrete type of {@link AbstractGraphIndexWriter} to build + * @param the type of {@link IndexWriter} used for output operations */ public abstract static class Builder, T extends IndexWriter> { + /** The immutable graph index to be written. */ final ImmutableGraphIndex graphIndex; + + /** Map of features to include in the index, keyed by feature ID. */ final EnumMap features; + + /** The output writer for graph data. */ final T out; + + /** Optional ordinal mapper for renumbering nodes; defaults to sequential renumbering if not set. */ OrdinalMapper ordinalMapper; + + /** The format version to use; defaults to {@link OnDiskGraphIndex#CURRENT_VERSION}. */ int version; + /** + * Constructs a new builder for the specified graph and output writer. + * + * @param graphIndex the immutable graph to write to disk + * @param out the output writer for graph data + */ public Builder(ImmutableGraphIndex graphIndex, T out) { this.graphIndex = graphIndex; this.out = out; @@ -246,6 +429,20 @@ public Builder(ImmutableGraphIndex graphIndex, T out) { this.version = OnDiskGraphIndex.CURRENT_VERSION; } + /** + * Sets the format version for the index. + *

+ * Different versions support different features: + *

+ * <ul>
+ *   <li>Version 1-2: Basic graph structure with inline vectors only</li>
+ *   <li>Version 3: Support for multiple feature types</li>
+ *   <li>Version 4+: Required for multilayer graphs</li>
+ * </ul>
+ * + * @param version the format version to use + * @return this builder for method chaining + * @throws IllegalArgumentException if the version is greater than {@link OnDiskGraphIndex#CURRENT_VERSION} + */ public Builder withVersion(int version) { if (version > OnDiskGraphIndex.CURRENT_VERSION) { throw new IllegalArgumentException("Unsupported version: " + version); @@ -255,16 +452,56 @@ public Builder withVersion(int version) { return this; } + /** + * Adds a feature to be written with the index. + *

+ * Features include vector storage (inline or separated), compression schemes, + * and other node-associated data. Each feature is identified by its {@link FeatureId}. + * If a feature with the same ID is already registered, it will be replaced. + * + * @param feature the feature to add + * @return this builder for method chaining + */ public Builder with(Feature feature) { features.put(feature.id(), feature); return this; } + /** + * Sets the ordinal mapper for renumbering nodes during writing. + *

+ * The ordinal mapper controls how node ordinals in the source graph are mapped to + * ordinals in the written index. This is useful for eliminating gaps from deleted nodes + * or for mapping to external identifiers (e.g., database row IDs). + *

+ * If no mapper is specified, {@link #build()} will use {@link #sequentialRenumbering(ImmutableGraphIndex)} + * to create a mapper that eliminates gaps. + * + * @param ordinalMapper the ordinal mapper to use + * @return this builder for method chaining + */ public Builder withMapper(OrdinalMapper ordinalMapper) { this.ordinalMapper = ordinalMapper; return this; } + /** + * Builds and returns the configured graph index writer. + *

+ * This method performs validation to ensure the configuration is valid: + *

+ * <ul>
+ *   <li>Version 3 and earlier must use only {@code INLINE_VECTORS}</li>
+ *   <li>At least one vector feature must be present (inline, separated, or compressed)</li>
+ *   <li>If no ordinal mapper is set, sequential renumbering is applied</li>
+ * </ul>
+ * <p>
+ * The vector dimension is extracted from whichever vector feature is configured. + * + * @return the configured graph index writer + * @throws IOException if an I/O error occurs during writer initialization + * @throws IllegalArgumentException if the configuration is invalid for the specified version, + * or if no vector feature is provided + */ public K build() throws IOException { if (version < 3 && (!features.containsKey(FeatureId.INLINE_VECTORS) || features.size() > 1)) { throw new IllegalArgumentException("Only INLINE_VECTORS is supported until version 3"); @@ -289,12 +526,37 @@ public K build() throws IOException { return reallyBuild(dimension); } + /** + * Constructs the concrete writer instance with the specified dimension. + *

+ * This method is called by {@link #build()} after validation and dimension extraction + * are complete. Subclasses must implement this to instantiate their specific writer type. + * + * @param dimension the vector dimensionality extracted from the configured features + * @return the concrete graph index writer instance + * @throws IOException if an I/O error occurs during writer construction + */ protected abstract K reallyBuild(int dimension) throws IOException; + /** + * Sets the ordinal mapper using a map from old to new ordinals. + *

+ * This is a convenience method equivalent to calling + * {@code withMapper(new OrdinalMapper.MapMapper(oldToNewOrdinals))}. + * + * @param oldToNewOrdinals a map from original graph ordinals to new ordinals for writing + * @return this builder for method chaining + */ public Builder withMap(Map oldToNewOrdinals) { return withMapper(new OrdinalMapper.MapMapper(oldToNewOrdinals)); } + /** + * Returns the feature associated with the specified feature ID. + * + * @param featureId the ID of the feature to retrieve + * @return the feature with the specified ID, or {@code null} if no such feature is configured + */ public Feature getFeature(FeatureId featureId) { return features.get(featureId); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CommonHeader.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CommonHeader.java index 5d0a1aecb..d5d29db44 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CommonHeader.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CommonHeader.java @@ -61,10 +61,19 @@ public class CommonHeader { private static final int V4_MAX_LAYERS = 32; + /** The graph index format version */ public final int version; + + /** The vector dimension */ public final int dimension; + + /** The entry node id for graph traversal */ public final int entryNode; + + /** Information about each layer in the graph */ public final List layerInfo; + + /** The upper bound of node IDs (maximum node ID + 1) */ public final int idUpperBound; CommonHeader(int version, int dimension, int entryNode, List layerInfo, int idUpperBound) { @@ -162,16 +171,36 @@ int size() { return size * Integer.BYTES; } + /** + * Information about a single layer in a multi-layer graph. 
+ */ @VisibleForTesting public static class LayerInfo { + /** The number of nodes in this layer */ public final int size; + + /** The maximum degree (number of neighbors) for nodes in this layer */ public final int degree; + /** + * Constructs layer information with the given size and degree. + * + * @param size the number of nodes in this layer + * @param degree the maximum degree for nodes in this layer + */ public LayerInfo(int size, int degree) { this.size = size; this.degree = degree; } + /** + * Creates a list of LayerInfo from a graph, extracting size and degree information + * for each layer. + * + * @param graph the graph to extract layer information from + * @param mapper the ordinal mapper (currently unused but kept for API compatibility) + * @return a list of LayerInfo objects, one per layer + */ public static List fromGraph(ImmutableGraphIndex graph, OrdinalMapper mapper) { return IntStream.rangeClosed(0, graph.getMaxLevel()) .mapToObj(i -> new LayerInfo(graph.size(i), graph.getDegree(i))) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/GraphIndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/GraphIndexWriter.java index ac67900fe..64155f942 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/GraphIndexWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/GraphIndexWriter.java @@ -36,6 +36,7 @@ public interface GraphIndexWriter extends Closeable { * Each supplier takes a node ordinal and returns a FeatureState suitable for Feature.writeInline. 
* * @param featureStateSuppliers a map of FeatureId to a function that returns a Feature.State + * @throws IOException if an I/O error occurs during writing */ void write(Map> featureStateSuppliers) throws IOException; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java index 30eb0a6ba..8b97ce009 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java @@ -58,14 +58,14 @@ /** * A class representing a graph index stored on disk. The base graph contains only graph structure. - *

- * The base graph
- *
+ * <p>
* This graph may be extended with additional features, which are stored inline in the graph and in headers. * At runtime, this class may choose the best way to use these features. */ public class OnDiskGraphIndex implements ImmutableGraphIndex, AutoCloseable, Accountable { private static final Logger logger = LoggerFactory.getLogger(OnDiskGraphIndex.class); + /** The current serialization version for on-disk graph indices. */ public static final int CURRENT_VERSION = 5; static final int MAGIC = 0xFFFF0D61; // FFFF to distinguish from old graphs, which should never start with a negative size "ODGI" static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); @@ -83,6 +83,14 @@ public class OnDiskGraphIndex implements ImmutableGraphIndex, AutoCloseable, Acc /** For layers > 0, store adjacency fully in memory. */ private final AtomicReference>> inMemoryNeighbors; + /** + * Constructs an OnDiskGraphIndex from a reader supplier, header, and neighbors offset. + * This constructor is package-private and used internally by the load methods. + * + * @param readerSupplier supplies readers for accessing the graph data + * @param header the parsed header containing graph metadata and features + * @param neighborsOffset the file offset where layer 0 adjacency data begins + */ OnDiskGraphIndex(ReaderSupplier readerSupplier, Header header, long neighborsOffset) { this.readerSupplier = readerSupplier; @@ -106,6 +114,14 @@ public class OnDiskGraphIndex implements ImmutableGraphIndex, AutoCloseable, Acc inMemoryNeighbors = new AtomicReference<>(null); } + /** + * Returns the in-memory representation of higher-layer adjacency data (layers 1+). + * Loads the data on first access and caches it for subsequent calls. 
+ * + * @param in the reader to use for loading data if not already cached + * @return a list of maps from node ID to neighbor arrays, one per layer (with null at index 0) + * @throws IOException if an I/O error occurs during loading + */ private List> getInMemoryLayers(RandomAccessReader in) throws IOException { return inMemoryNeighbors.updateAndGet(current -> { if (current != null) { @@ -119,6 +135,14 @@ private List> getInMemoryLayers(RandomAccessReader in) }); } + /** + * Loads the higher-layer (1+) adjacency data into memory from disk. + * Layer 0 is not loaded since it remains on disk for efficient random access. + * + * @param in the reader to use for loading the layer data + * @return a list of maps from node ID to neighbor arrays, one per layer (with null at index 0) + * @throws IOException if an I/O error occurs during loading + */ private List> loadInMemoryLayers(RandomAccessReader in) throws IOException { var imn = new ArrayList>(layerInfo.size()); // For levels > 0, we load adjacency into memory @@ -153,11 +177,12 @@ private List> loadInMemoryLayers(RandomAccessReader in) } /** - * Load an index from the given reader supplier where header and graph are located on the same file, - * where the index starts at `offset`. + * Loads an index from the given reader supplier where header and graph are located on the same file, + * where the index starts at the specified offset. * - * @param readerSupplier the reader supplier to use to read the graph and index. - * @param offset the offset in bytes from the start of the file where the index starts. 
+ * @param readerSupplier the reader supplier to use to read the graph and index + * @param offset the offset in bytes from the start of the file where the index starts + * @return the loaded OnDiskGraphIndex instance */ public static OnDiskGraphIndex load(ReaderSupplier readerSupplier, long offset) { try (var reader = readerSupplier.get()) { @@ -180,20 +205,23 @@ public static OnDiskGraphIndex load(ReaderSupplier readerSupplier, long offset) } /** - * Load an index from the given reader supplier where header and graph are located on the same file at offset 0. + * Loads an index from the given reader supplier where header and graph are located on the same file at offset 0. * - * @param readerSupplier the reader supplier to use to read the graph index. + * @param readerSupplier the reader supplier to use to read the graph index + * @return the loaded OnDiskGraphIndex instance */ public static OnDiskGraphIndex load(ReaderSupplier readerSupplier) { return load(readerSupplier, 0); } /** - * Load an index from the given reader supplier where we will use the footer of the file to find the header. - * In this implementation we will assume that the {@link ReaderSupplier} must vend slices of IndexOutput that contain the graph index and nothing else. - * @param readerSupplier the reader supplier to use to read the graph index. - * This reader supplier must vend slices of IndexOutput that contain the graph index and nothing else. - * @return the loaded index. + * Loads an index from the given reader supplier using the footer of the file to locate the header. + * This method assumes the reader supplier vends slices that contain only the graph index data. 
+ * + * @param readerSupplier the reader supplier to use to read the graph index + * @param neighborsOffset the offset where layer 0 adjacency data begins + * @return the loaded OnDiskGraphIndex instance + * @throws RuntimeException if the footer is invalid or an I/O error occurs */ private static OnDiskGraphIndex loadFromFooter(ReaderSupplier readerSupplier, long neighborsOffset) { try (var in = readerSupplier.get()) { @@ -223,10 +251,20 @@ private static OnDiskGraphIndex loadFromFooter(ReaderSupplier readerSupplier, lo } } + /** + * Returns the set of feature IDs present in this index. + * + * @return a set of feature IDs + */ public Set getFeatureSet() { return features.keySet(); } + /** + * Returns the dimensionality of vectors stored in this index. + * + * @return the vector dimension + */ public int getDimension() { return dimension; } @@ -297,12 +335,23 @@ public NodesIterator getNodes(int level) { } } + /** + * Returns the approximate memory usage in bytes of this index. + * This includes overhead for index structures but not the full graph data on disk. + * + * @return the memory usage in bytes + */ @Override public long ramBytesUsed() { return Long.BYTES + 6 * Integer.BYTES + RamUsageEstimator.NUM_BYTES_OBJECT_REF + (long) 2 * RamUsageEstimator.NUM_BYTES_OBJECT_REF * FeatureId.values().length; } + /** + * Closes this index. Note that the caller is responsible for closing the ReaderSupplier. + * + * @throws IOException if an I/O error occurs + */ public void close() throws IOException { // caller is responsible for closing ReaderSupplier } @@ -313,11 +362,21 @@ public String toString() { features.keySet().stream().map(Enum::name).collect(Collectors.joining(","))); } + /** + * Returns the maximum level (highest layer number) in this graph. + * + * @return the maximum level + */ @Override public int getMaxLevel() { return entryNode.level; } + /** + * Returns the maximum degree across all layers in this graph. 
+ * + * @return the maximum degree + */ @Override public int maxDegree() { return layerInfo.stream().mapToInt(li -> li.degree).max().orElseThrow(); @@ -345,10 +404,22 @@ public double getAverageDegree(int level) { return (double) sum / it.size(); } + /** + * A view for accessing graph data with a dedicated reader. + * This class implements multiple interfaces to provide vector access, scoring, and feature reading. + * Each view maintains its own reader for thread-safe access to the graph data. + */ public class View implements FeatureSource, ScoringView, RandomAccessVectorValues { + /** The reader for accessing graph data from disk. */ protected final RandomAccessReader reader; + /** Reusable array for reading neighbor lists from disk. */ private final int[] neighbors; + /** + * Constructs a View with the given reader. + * + * @param reader the reader to use for accessing graph data + */ public View(RandomAccessReader reader) { this.reader = reader; this.neighbors = new int[layerInfo.stream().mapToInt(li -> li.degree).max().orElse(0)]; @@ -370,6 +441,13 @@ public RandomAccessVectorValues copy() { throw new UnsupportedOperationException(); // need to copy reader } + /** + * Computes the file offset for accessing a specific feature of a given node. + * + * @param node the node ID + * @param featureId the feature to access + * @return the file offset in bytes + */ private long offsetFor(int node, FeatureId featureId) { Feature feature = features.get(featureId); @@ -385,6 +463,13 @@ private long offsetFor(int node, FeatureId featureId) { return baseNodeOffsetFor(node) + skipInNode; } + /** + * Computes the file offset for accessing the neighbors of a node at layer 0. 
+ * + * @param level must be 0 (higher layers are in memory) + * @param node the node ID + * @return the file offset in bytes + */ private long neighborsOffsetFor(int level, int node) { assert level == 0; // higher layers are in memory @@ -393,6 +478,12 @@ private long neighborsOffsetFor(int level, int node) { return baseNodeOffsetFor(node) + skipInline; } + /** + * Computes the base file offset for a node's data block at layer 0. + * + * @param node the node ID + * @return the file offset in bytes + */ private long baseNodeOffsetFor(int node) { int degree = layerInfo.get(0).degree; @@ -438,6 +529,14 @@ public void getVectorInto(int node, VectorFloat vector, int offset) { } } + /** + * Returns an iterator over the neighbors of a node at the specified level. + * For layer 0, neighbors are read from disk. For higher layers, they are read from memory. + * + * @param level the layer number + * @param node the node ID + * @return an iterator over the node's neighbors + */ public NodesIterator getNeighborsIterator(int level, int node) { try { if (level == 0) { @@ -460,22 +559,46 @@ public NodesIterator getNeighborsIterator(int level, int node) { } } + /** + * Returns the number of nodes at layer 0. + * + * @return the size of layer 0 + */ @Override public int size() { // For vector operations we only care about layer 0 return OnDiskGraphIndex.this.size(0); } + /** + * Returns the entry node for graph traversal. + * + * @return the entry node with its level + */ @Override public NodeAtLevel entryNode() { return entryNode; } + /** + * Returns the upper bound on node IDs in this graph. + * + * @return the ID upper bound + */ @Override public int getIdUpperBound() { return idUpperBound; } + /** + * Checks whether a node exists at the specified level. + * For layer 0, checks if the node ID is within bounds. + * For higher layers, checks the in-memory layer data. 
+ * + * @param level the layer number + * @param node the node ID + * @return true if the node exists at the specified level + */ @Override public boolean contains(int level, int node) { try { @@ -491,11 +614,22 @@ public boolean contains(int level, int node) { } } + /** + * Returns a Bits instance indicating which nodes are live (not deleted). + * For on-disk graphs, all nodes are considered live. + * + * @return Bits.ALL indicating all nodes are live + */ @Override public Bits liveNodes() { return Bits.ALL; } + /** + * Closes this view and its associated reader. + * + * @throws IOException if an I/O error occurs + */ @Override public void close() throws IOException { reader.close(); @@ -522,12 +656,27 @@ public ScoreFunction.ApproximateScoreFunction approximateScoreFunctionFor(Vector } } - /** Convenience function for writing a vanilla DiskANN-style index with no extra Features. */ + /** + * Convenience function for writing a vanilla DiskANN-style index with no extra Features. + * + * @param graph the graph to write + * @param vectors the vectors to include in the index + * @param path the output file path + * @throws IOException if an I/O error occurs during writing + */ public static void write(ImmutableGraphIndex graph, RandomAccessVectorValues vectors, Path path) throws IOException { write(graph, vectors, OnDiskGraphIndexWriter.sequentialRenumbering(graph), path); } - /** Convenience function for writing a vanilla DiskANN-style index with no extra Features. */ + /** + * Convenience function for writing a vanilla DiskANN-style index with no extra Features. 
+ * + * @param graph the graph to write + * @param vectors the vectors to include in the index + * @param oldToNewOrdinals mapping from original ordinals to sequential ordinals + * @param path the output file path + * @throws IOException if an I/O error occurs during writing + */ public static void write(ImmutableGraphIndex graph, RandomAccessVectorValues vectors, Map oldToNewOrdinals, diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java index a8515c191..9bfe480dd 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java @@ -67,6 +67,18 @@ public class OnDiskGraphIndexWriter extends AbstractGraphIndexWriter { private final long startOffset; + /** + * Constructs an OnDiskGraphIndexWriter with the specified parameters. + * This constructor is package-private and used by the Builder. + * + * @param randomAccessWriter the writer to output the graph to + * @param version the serialization version to use + * @param startOffset the offset in the file where the graph starts + * @param graph the graph to serialize + * @param oldToNewOrdinals mapper for converting between original and on-disk ordinals + * @param dimension the dimensionality of vectors in the graph + * @param features the features to include in the serialized graph + */ OnDiskGraphIndexWriter(RandomAccessWriter randomAccessWriter, int version, long startOffset, @@ -89,22 +101,28 @@ public synchronized void close() throws IOException { } /** + * Returns the underlying RandomAccessWriter for direct access. * Caller should synchronize on this OnDiskGraphIndexWriter instance if mixing usage of the * output with calls to any of the synchronized methods in this class. - *

- * Provided for callers (like Cassandra) that want to add their own header/footer to the output. + * This method is provided for callers (like Cassandra) that want to add their own header/footer to the output. + * + * @return the underlying RandomAccessWriter */ public RandomAccessWriter getOutput() { return out; } /** - * Write the inline features of the given ordinal to the output at the correct offset. - * Nothing else is written (no headers, no edges). The output IS NOT flushed. + * Writes the inline features of the given ordinal to the output at the correct offset. + * Nothing else is written (no headers, no edges). The output IS NOT flushed. *

* Note: the ordinal given is implicitly a "new" ordinal in the sense of the OrdinalMapper, * but since no nodes or edges are involved (we just write the given State to the index file), * the mapper is not invoked. + * + * @param ordinal the ordinal to write features for + * @param stateMap map of feature IDs to their state objects + * @throws IOException if an I/O error occurs during writing */ public synchronized void writeInline(int ordinal, Map stateMap) throws IOException { @@ -128,10 +146,25 @@ public synchronized void writeInline(int ordinal, Map maxOrdinalWritten = Math.max(maxOrdinalWritten, ordinal); } + /** + * Computes the file offset where inline features for a given ordinal should be written. + * + * @param ordinal the node ordinal + * @return the file offset in bytes + */ private long featureOffsetForOrdinal(int ordinal) { return super.featureOffsetForOrdinal(startOffset, ordinal); } + /** + * Writes the entire graph index to disk, including headers, features, and adjacency data. + * This method writes layer 0 data, higher layer data, separated features, and the footer. + * + * @param featureStateSuppliers functions that provide feature state for each node ordinal + * @throws IOException if an I/O error occurs during writing + * @throws IllegalArgumentException if the graph contains deleted nodes or if a feature is not configured + * @throws IllegalStateException if the ordinal mapper doesn't cover all nodes or if nodes/neighbors are invalid + */ public synchronized void write(Map> featureStateSuppliers) throws IOException { if (graph instanceof OnHeapGraphIndex) { @@ -231,9 +264,11 @@ public synchronized void write(Map> featur } /** - * Write the index header and completed edge lists to the given output. - * Unlike the super method, this method flushes the output and also assumes it's using a RandomAccessWriter that can - * seek to the startOffset and re-write the header. + * Writes the index header to the output at the start offset. 
+ * Unlike the super method, this method flushes the output and uses the RandomAccessWriter's + * ability to seek to the startOffset and re-write the header. + * + * @param view the graph view to write header information from * @throws IOException if there is an error writing the header */ public synchronized void writeHeader(ImmutableGraphIndex.View view) throws IOException { @@ -243,7 +278,12 @@ public synchronized void writeHeader(ImmutableGraphIndex.View view) throws IOExc out.flush(); } - /** CRC32 checksum of bytes written since the starting offset */ + /** + * Computes the CRC32 checksum of bytes written since the starting offset. + * + * @return the CRC32 checksum + * @throws IOException if an I/O error occurs + */ public synchronized long checksum() throws IOException { long endOffset = out.position(); return out.checksum(startOffset, endOffset); @@ -255,23 +295,46 @@ public synchronized long checksum() throws IOException { public static class Builder extends AbstractGraphIndexWriter.Builder { private long startOffset = 0L; + /** + * Constructs a Builder for writing a graph index to a file. + * + * @param graphIndex the graph to write + * @param outPath the output file path + * @throws FileNotFoundException if the output file cannot be created + */ public Builder(ImmutableGraphIndex graphIndex, Path outPath) throws FileNotFoundException { this(graphIndex, new BufferedRandomAccessWriter(outPath)); } + /** + * Constructs a Builder for writing a graph index using a custom writer. + * + * @param graphIndex the graph to write + * @param out the output writer to use + */ public Builder(ImmutableGraphIndex graphIndex, RandomAccessWriter out) { super(graphIndex, out); } /** - * Set the starting offset for the graph index in the output file. This is useful if you want to - * append the index to an existing file. + * Sets the starting offset for the graph index in the output file. + * This is useful when appending the index to an existing file. 
+ * + * @param startOffset the byte offset where the graph index should start + * @return this Builder instance for method chaining */ public Builder withStartOffset(long startOffset) { this.startOffset = startOffset; return this; } + /** + * Creates the OnDiskGraphIndexWriter instance with the configured parameters. + * + * @param dimension the dimensionality of vectors in the graph + * @return a new OnDiskGraphIndexWriter instance + * @throws IOException if an I/O error occurs during initialization + */ @Override protected OnDiskGraphIndexWriter reallyBuild(int dimension) throws IOException { return new OnDiskGraphIndexWriter(out, version, startOffset, graphIndex, ordinalMapper, dimension, features); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskSequentialGraphIndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskSequentialGraphIndexWriter.java index e7dd69476..a983e5d62 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskSequentialGraphIndexWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskSequentialGraphIndexWriter.java @@ -55,6 +55,16 @@ */ public class OnDiskSequentialGraphIndexWriter extends AbstractGraphIndexWriter { + /** + * Constructs an OnDiskSequentialGraphIndexWriter with the specified parameters. + * + * @param out the output writer + * @param version the serialization version + * @param graph the graph to serialize + * @param oldToNewOrdinals mapper for converting between original and on-disk ordinals + * @param dimension the dimensionality of vectors in the graph + * @param features the features to include in the serialized graph + */ OnDiskSequentialGraphIndexWriter(IndexWriter out, int version, ImmutableGraphIndex graph, @@ -71,12 +81,18 @@ public synchronized void close() throws IOException { } /** + * Writes the entire graph index to disk sequentially, including headers, features, and adjacency data. 
* Note: There are several limitations you should be aware of when using: *

- * <ul>
- *   <li>This method doesn't persist (e.g. flush) the output streams. The caller is responsible for doing so.</li>
- *   <li>This method does not support writing to "holes" in the ordinal space. If your ordinal mapper
+ * <ul>
+ *   <li>This method doesn't persist (e.g. flush) the output streams. The caller is responsible for doing so.</li>
+ *   <li>This method does not support writing to "holes" in the ordinal space. If your ordinal mapper
 *       maps a new ordinal to an old ordinal that does not exist in the graph, an exception will be thrown.</li>
 * </ul>
+ *
+ * + * @param oldOrdinal the original ordinal in the graph + * @return the new ordinal to use when writing to disk */ int oldToNew(int oldOrdinal); /** - * Map new ordinals (written to disk) to old ordinals (in the graph as constructed). + * Maps new ordinals (written to disk) to old ordinals (in the graph as constructed). * May return OMITTED if there is a "hole" at the new ordinal. + * + * @param newOrdinal the new ordinal written to disk + * @return the original ordinal in the graph, or OMITTED if there is a hole */ int newToOld(int newOrdinal); /** * A mapper that leaves the original ordinals unchanged. + * This is the simplest implementation where old and new ordinals are identical. */ class IdentityMapper implements OrdinalMapper { private final int maxOrdinal; + /** + * Constructs an IdentityMapper with the specified maximum ordinal. + * + * @param maxOrdinal the maximum ordinal value (inclusive) + */ public IdentityMapper(int maxOrdinal) { this.maxOrdinal = maxOrdinal; } + /** + * Returns the maximum ordinal value. + * + * @return the maximum ordinal + */ @Override public int maxOrdinal() { return maxOrdinal; } + /** + * Maps an old ordinal to a new ordinal. For IdentityMapper, returns the same value. + * + * @param oldOrdinal the original ordinal + * @return the same ordinal unchanged + */ @Override public int oldToNew(int oldOrdinal) { return oldOrdinal; } + /** + * Maps a new ordinal to an old ordinal. For IdentityMapper, returns the same value. + * + * @param newOrdinal the new ordinal + * @return the same ordinal unchanged + */ @Override public int newToOld(int newOrdinal) { return newOrdinal; @@ -78,12 +109,19 @@ public int newToOld(int newOrdinal) { /** * Converts a Map of old to new ordinals into an OrdinalMapper. + * This implementation allows for arbitrary remapping and supports gaps (omitted ordinals). 
*/ class MapMapper implements OrdinalMapper { private final int maxOrdinal; private final Map oldToNew; private final Int2IntHashMap newToOld; + /** + * Constructs a MapMapper from a map of old to new ordinals. + * The mapper builds a reverse mapping and determines the maximum new ordinal. + * + * @param oldToNew a map from original ordinals to new ordinals + */ public MapMapper(Map oldToNew) { this.oldToNew = oldToNew; this.newToOld = new Int2IntHashMap(oldToNew.size(), 0.65f, OMITTED); @@ -91,16 +129,34 @@ public MapMapper(Map oldToNew) { this.maxOrdinal = oldToNew.values().stream().mapToInt(i -> i).max().orElse(-1); } + /** + * Returns the maximum new ordinal value. + * + * @return the maximum ordinal + */ @Override public int maxOrdinal() { return maxOrdinal; } + /** + * Maps an old ordinal to its corresponding new ordinal. + * + * @param oldOrdinal the original ordinal + * @return the new ordinal corresponding to the old ordinal + */ @Override public int oldToNew(int oldOrdinal) { return oldToNew.get(oldOrdinal); } + /** + * Maps a new ordinal back to its original ordinal. + * Returns OMITTED if there is no mapping for the new ordinal. + * + * @param newOrdinal the new ordinal + * @return the original ordinal, or OMITTED if there is a gap + */ @Override public int newToOld(int newOrdinal) { return newToOld.get(newOrdinal); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/Feature.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/Feature.java index a72ff10b6..33b15dbec 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/Feature.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/Feature.java @@ -25,28 +25,74 @@ * A feature of an on-disk graph index. Information to use a feature is stored in the header on-disk. */ public interface Feature { + /** + * Returns the unique identifier for this feature. 
+ * + * @return the FeatureId of this feature + */ FeatureId id(); + /** + * Returns the size in bytes of this feature's header data. + * + * @return the header size in bytes + */ int headerSize(); + /** + * Returns the size in bytes of this feature's per-node data. + * + * @return the feature size in bytes per node + */ int featureSize(); + /** + * Writes this feature's header data to the output stream. + * + * @param out the output stream to write to + * @throws IOException if an I/O error occurs + */ void writeHeader(DataOutput out) throws IOException; + /** + * Writes inline feature data for a node to the output stream. + * Default implementation is a no-op for features that don't support inline storage. + * + * @param out the output stream to write to + * @param state the state containing the data to write + * @throws IOException if an I/O error occurs + */ default void writeInline(DataOutput out, State state) throws IOException { // default no-op } - // Feature implementations should implement a State as well for use with writeInline/writeSeparately + /** + * Marker interface for feature-specific state used during writing. + * Feature implementations should implement this interface for their specific state. + */ interface State { } + /** + * Creates a single-entry map associating a FeatureId with a state factory function. + * + * @param id the feature identifier + * @param stateFactory the factory function to create state instances + * @return an EnumMap containing the single mapping + */ static EnumMap> singleStateFactory(FeatureId id, IntFunction stateFactory) { EnumMap> map = new EnumMap<>(FeatureId.class); map.put(id, stateFactory); return map; } + /** + * Creates a single-entry map associating a FeatureId with a state instance. 
+ * + * @param id the feature identifier + * @param state the state instance + * @return an EnumMap containing the single mapping + */ static EnumMap singleState(FeatureId id, State state) { EnumMap map = new EnumMap<>(FeatureId.class); map.put(id, state); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FeatureId.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FeatureId.java index dd0857834..32f7aced5 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FeatureId.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FeatureId.java @@ -31,12 +31,22 @@ * These are typically mapped to a Feature. */ public enum FeatureId { + /** Vectors stored inline with the graph structure */ INLINE_VECTORS(InlineVectors::load), + + /** Fused asymmetric distance computation for efficient similarity search */ FUSED_ADC(FusedADC::load), + + /** Vectors compressed using Neighborhood Vector Quantization */ NVQ_VECTORS(NVQ::load), + + /** Vectors stored separately from the graph structure */ SEPARATED_VECTORS(SeparatedVectors::load), + + /** NVQ-compressed vectors stored separately from the graph structure */ SEPARATED_NVQ(SeparatedNVQ::load); + /** A set containing all available feature IDs */ public static final Set ALL = Collections.unmodifiableSet(EnumSet.allOf(FeatureId.class)); private final BiFunction loader; @@ -45,10 +55,23 @@ public enum FeatureId { this.loader = loader; } + /** + * Loads the Feature implementation associated with this FeatureId from disk. + * + * @param header the common header containing graph metadata + * @param reader the reader for accessing the on-disk data + * @return the loaded Feature instance + */ public Feature load(CommonHeader header, RandomAccessReader reader) { return loader.apply(header, reader); } + /** + * Deserializes a set of FeatureIds from a bitfield representation. 
+ * + * @param bitflags the bitfield where each bit represents the presence of a feature + * @return an EnumSet containing the features indicated by the bitfield + */ public static EnumSet deserialize(int bitflags) { EnumSet set = EnumSet.noneOf(FeatureId.class); for (int n = 0; n < values().length; n++) { @@ -58,6 +81,12 @@ public static EnumSet deserialize(int bitflags) { return set; } + /** + * Serializes a set of FeatureIds into a bitfield representation. + * + * @param flags the set of features to serialize + * @return a bitfield where each bit represents the presence of a feature + */ public static int serialize(EnumSet flags) { int i = 0; for (FeatureId flag : flags) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FeatureSource.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FeatureSource.java index b8b24a949..64c3c929c 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FeatureSource.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FeatureSource.java @@ -21,6 +21,17 @@ import java.io.Closeable; import java.io.IOException; +/** + * A source for reading feature data associated with graph nodes. + */ public interface FeatureSource extends Closeable { + /** + * Returns a reader for accessing the feature data of a specific node. 
+ * + * @param node the node id to read feature data for + * @param featureId the type of feature to read + * @return a RandomAccessReader positioned to read the node's feature data + * @throws IOException if an I/O error occurs + */ RandomAccessReader featureReaderForNode(int node, FeatureId featureId) throws IOException; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedADC.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedADC.java index 59ca11564..f523a433e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedADC.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedADC.java @@ -46,6 +46,13 @@ public class FusedADC implements Feature { private final ExplicitThreadLocal> reusableNeighbors; private ByteSequence compressedNeighbors = null; + /** + * Creates a new FusedADC feature with the given parameters. + * + * @param maxDegree the maximum degree of the graph (must be 32) + * @param pq the product quantization to use for compression (must have 256 clusters) + * @throws IllegalArgumentException if maxDegree is not 32 or if PQ cluster count is not 256 + */ public FusedADC(int maxDegree, ProductQuantization pq) { if (maxDegree != 32) { throw new IllegalArgumentException("maxDegree must be 32 for FusedADC. This limitation may be removed in future releases"); @@ -74,6 +81,14 @@ public int featureSize() { return pq.compressedVectorSize() * maxDegree; } + /** + * Loads a FusedADC feature from the given reader. 
+ * + * @param header the common header containing graph metadata + * @param reader the reader to load from + * @return a new FusedADC instance + * @throws UncheckedIOException if an I/O error occurs + */ static FusedADC load(CommonHeader header, RandomAccessReader reader) { // TODO doesn't work with different degrees try { @@ -83,6 +98,15 @@ static FusedADC load(CommonHeader header, RandomAccessReader reader) { } } + /** + * Creates an approximate score function for the given query vector. + * + * @param queryVector the query vector to compute scores against + * @param vsf the vector similarity function to use + * @param view the view of the on-disk graph index + * @param esf the exact score function for fallback computations + * @return an approximate score function that uses fused ADC + */ public ScoreFunction.ApproximateScoreFunction approximateScoreFunctionFor(VectorFloat queryVector, VectorSimilarityFunction vsf, OnDiskGraphIndex.View view, ScoreFunction.ExactScoreFunction esf) { var neighbors = new PackedNeighbors(view); return FusedADCPQDecoder.newDecoder(neighbors, pq, queryVector, reusableResults.get(), vsf, esf); @@ -117,11 +141,24 @@ public void writeInline(DataOutput out, Feature.State state_) throws IOException vectorTypeSupport.writeByteSequence(out, compressedNeighbors); } + /** + * Encapsulates the state required for writing FusedADC feature data for a node. + */ public static class State implements Feature.State { + /** The view of the graph index */ public final ImmutableGraphIndex.View view; + /** The product quantized vectors */ public final PQVectors pqVectors; + /** The node identifier */ public final int nodeId; + /** + * Creates a new State instance. 
+ * + * @param view the view of the graph index + * @param pqVectors the product quantized vectors + * @param nodeId the node identifier + */ public State(ImmutableGraphIndex.View view, PQVectors pqVectors, int nodeId) { this.view = view; this.pqVectors = pqVectors; @@ -129,13 +166,28 @@ public State(ImmutableGraphIndex.View view, PQVectors pqVectors, int nodeId) { } } + /** + * Provides access to packed neighbors for efficient bulk similarity computations. + */ public class PackedNeighbors { private final OnDiskGraphIndex.View view; + /** + * Creates a new PackedNeighbors instance. + * + * @param view the view of the on-disk graph index + */ public PackedNeighbors(OnDiskGraphIndex.View view) { this.view = view; } + /** + * Returns the packed neighbors for the given node as a byte sequence. + * + * @param node the node identifier + * @return the packed neighbors as a byte sequence + * @throws RuntimeException if an I/O error occurs + */ public ByteSequence getPackedNeighbors(int node) { try { var reader = view.featureReaderForNode(node, FeatureId.FUSED_ADC); @@ -147,6 +199,11 @@ public ByteSequence getPackedNeighbors(int node) { } } + /** + * Returns the maximum degree of the graph. 
+ * + * @return the maximum degree + */ public int maxDegree() { return maxDegree; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/InlineVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/InlineVectors.java index 59e2b359c..2caba51d8 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/InlineVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/InlineVectors.java @@ -32,6 +32,11 @@ public class InlineVectors implements Feature { private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); private final int dimension; + /** + * Constructs an InlineVectors feature with the specified dimension. + * + * @param dimension the vector dimension + */ public InlineVectors(int dimension) { this.dimension = dimension; } @@ -46,14 +51,31 @@ public int headerSize() { return 0; } + /** + * Returns the size in bytes of each inline vector. + * + * @return the feature size in bytes + */ public int featureSize() { return dimension * Float.BYTES; } + /** + * Returns the dimension of the stored vectors. + * + * @return the vector dimension + */ public int dimension() { return dimension; } + /** + * Loads an InlineVectors feature from the reader. + * + * @param header the common header containing dimension information + * @param reader the reader (not used, dimension comes from header) + * @return a new InlineVectors instance + */ static InlineVectors load(CommonHeader header, RandomAccessReader reader) { return new InlineVectors(header.dimension); } @@ -68,9 +90,18 @@ public void writeInline(DataOutput out, Feature.State state) throws IOException vectorTypeSupport.writeFloatVector(out, ((InlineVectors.State) state).vector); } + /** + * State holder for an inline vector being written. + */ public static class State implements Feature.State { + /** The vector to be written inline. 
*/ public final VectorFloat vector; + /** + * Constructs a State with the given vector. + * + * @param vector the vector to store + */ public State(VectorFloat vector) { this.vector = vector; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/NVQ.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/NVQ.java index 2489ada21..3873f67ac 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/NVQ.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/NVQ.java @@ -38,6 +38,11 @@ public class NVQ implements Feature { private final NVQScorer scorer; private final ThreadLocal reusableQuantizedVector; + /** + * Constructs an NVQ feature with the given NVQuantization compressor. + * + * @param nvq the NVQuantization instance to use for encoding/decoding + */ public NVQ(NVQuantization nvq) { this.nvq = nvq; scorer = new NVQScorer(this.nvq); @@ -54,13 +59,30 @@ public int headerSize() { return nvq.compressorSize(); } + /** + * Returns the size in bytes of a single NVQ-quantized vector. + * + * @return the feature size in bytes + */ @Override public int featureSize() { return nvq.compressedVectorSize();} + /** + * Returns the dimensionality of the original uncompressed vectors. + * + * @return the vector dimension + */ public int dimension() { return nvq.globalMean.length(); } + /** + * Loads an NVQ feature from a reader. + * + * @param header the common header (unused but required by signature) + * @param reader the reader to load from + * @return the loaded NVQ feature + */ static NVQ load(CommonHeader header, RandomAccessReader reader) { try { return new NVQ(NVQuantization.load(reader)); @@ -80,14 +102,33 @@ public void writeInline(DataOutput out, Feature.State state_) throws IOException state.vector.write(out); } + /** + * Represents the state of an NVQ-quantized vector for a single node. 
+ */ public static class State implements Feature.State { + /** + * The quantized vector. + */ public final QuantizedVector vector; + /** + * Constructs a State with the given quantized vector. + * + * @param vector the quantized vector + */ public State(QuantizedVector vector) { this.vector = vector; } } + /** + * Creates a reranking score function that loads NVQ vectors from disk and computes exact scores. + * + * @param queryVector the query vector + * @param vsf the vector similarity function to use + * @param source the source to read NVQ vectors from + * @return an exact score function for reranking + */ public ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat queryVector, VectorSimilarityFunction vsf, FeatureSource source) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedFeature.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedFeature.java index d90aee603..81cd134ad 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedFeature.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedFeature.java @@ -19,9 +19,34 @@ import java.io.DataOutput; import java.io.IOException; +/** + * A feature whose data is stored separately from the main graph structure. + * Separated features write their data to a separate location on disk, with only + * the offset information stored in the graph header. This is useful for large + * features that would make inline storage inefficient. + */ public interface SeparatedFeature extends Feature { + /** + * Sets the file offset where this feature's data begins. + * + * @param offset the file offset in bytes + */ void setOffset(long offset); + + /** + * Returns the file offset where this feature's data begins. + * + * @return the file offset in bytes + */ long getOffset(); + /** + * Writes this feature's data to the specified output, separate from the graph structure. 
+ * This method is called during graph serialization to write feature data to its dedicated location. + * + * @param out the output to write to + * @param state the feature state containing the data to write + * @throws IOException if an I/O error occurs + */ void writeSeparately(DataOutput out, State state) throws IOException; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedNVQ.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedNVQ.java index b5d4cc476..29d97e73e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedNVQ.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedNVQ.java @@ -29,12 +29,26 @@ import java.io.IOException; import java.io.UncheckedIOException; +/** + * A separated feature implementation for Neighborhood Vector Quantization (NVQ) compressed vectors. + * Stores quantized vector data separately from the graph structure for efficient storage and access. + */ public class SeparatedNVQ implements SeparatedFeature { + /** The NVQ quantization scheme used for compressing vectors. */ private final NVQuantization nvq; + /** Scorer for computing similarities between quantized vectors. */ private final NVQScorer scorer; + /** Thread-local storage for reusable quantized vector instances to avoid repeated allocation. */ private final ThreadLocal reusableQuantizedVector; + /** The file offset where the separated NVQ data begins. */ private long offset; + /** + * Constructs a SeparatedNVQ feature with the specified quantization and offset. 
+ * + * @param nvq the NVQ quantization scheme + * @param offset the file offset where the NVQ data begins + */ public SeparatedNVQ(NVQuantization nvq, long offset) { this.nvq = nvq; this.offset = offset; @@ -86,6 +100,14 @@ public void writeSeparately(DataOutput out, State state_) throws IOException { // Using NVQ.State + /** + * Loads a SeparatedNVQ feature from the specified reader. + * + * @param header the common header (unused but kept for API consistency) + * @param reader the reader to load from + * @return the loaded SeparatedNVQ instance + * @throws UncheckedIOException if an I/O error occurs during loading + */ static SeparatedNVQ load(CommonHeader header, RandomAccessReader reader) { try { var nvq = NVQuantization.load(reader); @@ -96,10 +118,25 @@ static SeparatedNVQ load(CommonHeader header, RandomAccessReader reader) { } } + /** + * Returns the dimensionality of the vectors stored by this feature. + * + * @return the vector dimension + */ public int dimension() { return nvq.globalMean.length(); } + /** + * Creates an exact score function for reranking using the quantized vectors. + * The returned function reads quantized vectors from the feature source and computes + * exact similarities to the query vector. 
+ * + * @param queryVector the query vector to compare against + * @param vsf the vector similarity function to use for scoring + * @param source the feature source for reading node data + * @return an exact score function for reranking + */ ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat queryVector, VectorSimilarityFunction vsf, FeatureSource source) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedVectors.java index f6bff8472..a23e0e5d1 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedVectors.java @@ -25,11 +25,24 @@ import java.io.IOException; import java.io.UncheckedIOException; +/** + * A separated feature implementation for full-resolution float vectors. + * Stores vector data separately from the graph structure, with only offset information in the header. + */ public class SeparatedVectors implements SeparatedFeature { + /** Vectorization support for reading and writing vectors efficiently. */ private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); + /** The dimensionality of the vectors stored by this feature. */ private final int dimension; + /** The file offset where the separated vector data begins. */ private long offset; + /** + * Constructs a SeparatedVectors feature with the specified dimension and offset. 
+ * + * @param dimension the dimensionality of the vectors + * @param offset the file offset where the vector data begins + */ public SeparatedVectors(int dimension, long offset) { this.dimension = dimension; this.offset = offset; @@ -80,6 +93,15 @@ public void writeSeparately(DataOutput out, State state_) throws IOException { // Using InlineVectors.State + /** + * Loads a SeparatedVectors feature from the specified reader. + * The dimension is taken from the common header. + * + * @param header the common header containing dimension information + * @param reader the reader to load from + * @return the loaded SeparatedVectors instance + * @throws UncheckedIOException if an I/O error occurs during loading + */ static SeparatedVectors load(CommonHeader header, RandomAccessReader reader) { try { long offset = reader.readLong(); @@ -89,6 +111,11 @@ static SeparatedVectors load(CommonHeader header, RandomAccessReader reader) { } } + /** + * Returns the dimensionality of the vectors stored by this feature. + * + * @return the vector dimension + */ public int dimension() { return dimension; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/package-info.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/package-info.java new file mode 100644 index 000000000..6bf06346d --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/package-info.java @@ -0,0 +1,188 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides feature types for storing additional data with on-disk graph indexes. + *

+ * This package contains implementations of features that can be stored alongside graph nodes + * in persistent indexes. Features represent additional per-node data such as vectors, compressed + * vectors, or other metadata. The feature system supports both inline storage (data stored with + * each node) and separated storage (data stored in a dedicated section for better cache locality). + * + *

Feature Architecture

+ *

+ * Features are identified by {@link io.github.jbellis.jvector.graph.disk.feature.FeatureId} + * and implement the {@link io.github.jbellis.jvector.graph.disk.feature.Feature} interface. + * During graph writing, features are serialized to the index file. During reading, features + * are loaded from the header and provide access to per-node data. + * + *

Core Abstractions

+ *
    + *
  • {@link io.github.jbellis.jvector.graph.disk.feature.Feature} - Base interface for all + * features. Defines methods for: + *
      + *
    • Writing header metadata (dimensions, compression parameters, etc.)
    • + *
    • Writing per-node inline data
    • + *
    • Querying feature size and storage layout
    • + *
    + *
  • + *
  • {@link io.github.jbellis.jvector.graph.disk.feature.FeatureId} - Enum identifying + * available feature types. New features should be added to the end to maintain + * serialization compatibility.
  • + *
  • {@link io.github.jbellis.jvector.graph.disk.feature.FeatureSource} - Source interface + * for reading per-node feature data from disk during search and reranking.
  • + *
+ * + *

Available Features

+ * + *

Vector Storage

+ *
    + *
  • {@link io.github.jbellis.jvector.graph.disk.feature.InlineVectors} - Stores full-precision + * vectors inline with each graph node. Best for: + *
      + *
    • Small to medium dimensional vectors (< 512 dimensions)
    • + *
    • When exact similarity computation is always required
    • + *
    + *
  • + *
  • {@link io.github.jbellis.jvector.graph.disk.feature.SeparatedVectors} - Stores vectors + * in a dedicated section separate from the graph structure. Benefits: + *
      + *
    • Better cache locality during graph traversal (when vectors aren't needed)
    • + *
    • More efficient when using approximate scoring during search
    • + *
    + *
  • + *
+ * + *

Compressed Vector Storage

+ *
    + *
  • {@link io.github.jbellis.jvector.graph.disk.feature.NVQ} - Stores vectors compressed + * using Neighborhood Vector Quantization (NVQ). Inline storage variant.
  • + *
  • {@link io.github.jbellis.jvector.graph.disk.feature.SeparatedNVQ} - Separated storage + * variant of NVQ compression. Recommended for most use cases combining compression with + * cache-friendly layout.
  • + *
+ * + *

Specialized Features

+ *
    + *
  • {@link io.github.jbellis.jvector.graph.disk.feature.FusedADC} - Combines compressed + * vectors with precomputed query-dependent data for faster similarity computation. Used + * with Product Quantization (PQ) for asymmetric distance computation.
  • + *
+ * + *

Storage Strategies

+ * + *

Inline Storage

+ *

+ * Inline features store data directly with each graph node. This provides: + *

    + *
  • Advantages: Single random access to get both graph structure and feature data
  • + *
  • Disadvantages: Larger per-node size reduces cache efficiency during graph traversal
  • + *
+ * + *

Separated Storage

+ *

+ * Separated features ({@link io.github.jbellis.jvector.graph.disk.feature.SeparatedFeature}) + * store data in a dedicated section. This provides: + *

    + *
  • Advantages: + *
      + *
    • Smaller per-node size improves cache utilization during traversal
    • + *
    • Feature data only accessed when needed (e.g., for reranking)
    • + *
    • Better suited for approximate + exact scoring workflows
    • + *
    + *
  • + *
  • Disadvantages: Requires additional seek for feature access
  • + *
+ * + *

Feature Selection Guidelines

+ * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
Feature Selection by Use Case
Use CaseRecommended FeatureRationale
High-dimensional vectors (> 512d)SeparatedVectors or SeparatedNVQReduces per-node size for better cache efficiency
Memory-constrained environmentsSeparatedNVQ or FusedADCCompression reduces memory footprint
Low-dimensional vectors (< 128d)InlineVectorsMinimal overhead, single access pattern
Approximate + exact rerankingSeparatedNVQ + SeparatedVectorsUse compressed for search, exact for reranking
+ * + *

Usage Example

+ *
{@code
+ * // Writing a graph with separated vectors
+ * try (var writer = new OnDiskGraphIndexWriter.Builder(graph, output)
+ *         .with(new SeparatedVectors(dimension, 0))
+ *         .build()) {
+ *     // Create feature state for each node
+ *     var features = Feature.singleStateFactory(
+ *         FeatureId.SEPARATED_VECTORS,
+ *         nodeId -> new InlineVectors.State(vectors.getVector(nodeId))
+ *     );
+ *     writer.write(features);
+ * }
+ *
+ * // Reading and accessing feature data
+ * var reader = OnDiskGraphIndex.load(...);
+ * try (var view = reader.getView()) {
+ *     VectorFloat vector = view.getVector(nodeId);
+ * }
+ * }
+ * + *

Adding New Features

+ *

+ * To add a new feature type: + *

    + *
  1. Add a new entry to {@code FeatureId} enum (at the end to maintain compatibility)
  2. + *
  3. Implement the {@code Feature} interface with: + *
      + *
    • Header serialization (metadata like dimensions, compression parameters)
    • + *
    • Per-node data serialization (inline or separated)
    • + *
    • Loading logic to reconstruct from disk
    • + *
    + *
  4. + *
  5. Update graph writers to support the new feature
  6. + *
  7. Update graph readers to provide access to the feature data
  8. + *
+ * + *

Thread Safety

+ *
    + *
  • Feature instances are typically immutable after construction and thread-safe
  • + *
  • Feature.State instances are per-write operation and not thread-safe
  • + *
  • Feature data access through graph views is thread-safe
  • + *
+ * + * @see io.github.jbellis.jvector.graph.disk.feature.Feature + * @see io.github.jbellis.jvector.graph.disk.feature.FeatureId + * @see io.github.jbellis.jvector.graph.disk + */ +package io.github.jbellis.jvector.graph.disk.feature; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/package-info.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/package-info.java new file mode 100644 index 000000000..c96e151e7 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/package-info.java @@ -0,0 +1,113 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides classes for reading and writing graph indexes to persistent storage. + *

+ * This package contains the core infrastructure for serializing and deserializing vector search + * graph indexes. It supports both sequential and random-access writing strategies, multiple + * format versions, and flexible feature storage (inline or separated). + * + *

Key Components

+ * + *

Writers

+ *
    + *
  • {@link io.github.jbellis.jvector.graph.disk.AbstractGraphIndexWriter} - Abstract base class + * for all graph index writers, providing common functionality for header/footer writing, + * feature handling, and ordinal mapping
  • + *
  • {@link io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter} - Random-access writer + * that can write nodes in any order
  • + *
  • {@link io.github.jbellis.jvector.graph.disk.OnDiskSequentialGraphIndexWriter} - Sequential + * writer optimized for writing nodes in ordinal order
  • + *
  • {@link io.github.jbellis.jvector.graph.disk.GraphIndexWriter} - Interface defining the + * contract for writing graph indexes
  • + *
+ * + *

Reader

+ *
    + *
  • {@link io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex} - Memory-mapped reader for + * accessing on-disk graph indexes efficiently
  • + *
+ * + *

Utilities

+ *
    + *
  • {@link io.github.jbellis.jvector.graph.disk.OrdinalMapper} - Maps between original graph + * ordinals and on-disk ordinals, useful for compacting deleted nodes or mapping to external IDs
  • + *
  • {@link io.github.jbellis.jvector.graph.disk.Header} - Encapsulates the index header format
  • + *
  • {@link io.github.jbellis.jvector.graph.disk.CommonHeader} - Common header information + * shared across format versions
  • + *
+ * + *

On-Disk Format

+ *

+ * The on-disk format consists of the following sections: + *

    + *
  1. Header - Contains metadata about the graph (version, dimension, entry node, layer info) + * and feature configuration
  2. + *
  3. Dense Level (Level 0) - All graph nodes with their inline features and neighbor lists
  4. + *
  5. Sparse Levels - For hierarchical graphs, additional levels containing only nodes + * participating in those levels
  6. + *
  7. Separated Features - Optional section containing feature data that is stored + * separately from nodes for better cache locality
  8. + *
  9. Footer - Contains the header offset (allowing the header to be located) and a + * magic number for file validation
  10. + *
+ * + *

Format Versions

+ *
    + *
  • Version 1-2: Basic format with inline vectors only
  • + *
  • Version 3: Support for multiple feature types
  • + *
  • Version 4+: Support for multilayer (hierarchical) graphs
  • + *
+ * + *

Features

+ *

+ * Features represent additional data stored with graph nodes (e.g., vectors, compressed vectors). + * Features can be stored inline (with each node) or separated (in a dedicated section). + * See the {@link io.github.jbellis.jvector.graph.disk.feature} package for available feature types. + * + *

Usage Example

+ *
{@code
+ * // Writing a graph index
+ * try (var output = new BufferedRandomAccessWriter(...)) {
+ *     var writer = new OnDiskGraphIndexWriter.Builder(graph, output)
+ *         .withVersion(4)
+ *         .with(new InlineVectors(dimension))
+ *         .build();
+ *
+ *     writer.write(featureSuppliers);
+ * }
+ *
+ * // Reading a graph index
+ * var reader = OnDiskGraphIndex.load(...);
+ * try (var view = reader.getView()) {
+ *     // Access nodes and neighbors
+ *     var neighbors = view.getNeighborsIterator(level, ordinal);
+ *     // Read features
+ *     var vector = view.getVector(ordinal);
+ * }
+ * }
+ * + *

Thread Safety

+ *

+ * Writers are not thread-safe for concurrent writes to the same instance. + * Readers ({@code OnDiskGraphIndex}) are thread-safe and support concurrent read access + * through separate views. + * + * @see io.github.jbellis.jvector.graph.disk.feature + * @see io.github.jbellis.jvector.graph.ImmutableGraphIndex + */ +package io.github.jbellis.jvector.graph.disk; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/DiversityProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/DiversityProvider.java index 7551aec71..a18e8b0ff 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/DiversityProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/DiversityProvider.java @@ -24,9 +24,18 @@ import static java.lang.Math.min; +/** + * Provides diversity selection functionality for graph neighbors. + * Implementations determine which neighbors to retain to maintain graph quality. + */ public interface DiversityProvider { /** - * update `selected` with the diverse members of `neighbors`. `neighbors` is not modified + * Updates {@code selected} with the diverse members of {@code neighbors}. The {@code neighbors} array is not modified. 
+ * + * @param neighbors the candidate neighbors to select from + * @param maxDegree the maximum number of neighbors to retain + * @param diverseBefore the index before which neighbors are already diverse and don't need re-checking + * @param selected a BitSet to update with the indices of selected diverse neighbors * @return the fraction of short edges (neighbors within alpha=1.0) */ double retainDiverse(NodeArray neighbors, int maxDegree, int diverseBefore, BitSet selected); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/VamanaDiversityProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/VamanaDiversityProvider.java index 0bdc6415f..34fe95430 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/VamanaDiversityProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/VamanaDiversityProvider.java @@ -24,6 +24,12 @@ import static java.lang.Math.min; +/** + * Provides diversity selection using the Vamana algorithm's diversity heuristic. + * The alpha parameter controls the diversity threshold, where higher values + * encourage more diverse neighbor selection at the cost of potentially longer + * paths in the graph. + */ public class VamanaDiversityProvider implements DiversityProvider { /** the diversity threshold; 1.0 is equivalent to HNSW; Vamana uses 1.2 or more */ public final float alpha; @@ -31,7 +37,11 @@ public class VamanaDiversityProvider implements DiversityProvider { /** used to compute diversity */ public final BuildScoreProvider scoreProvider; - /** Create a new diversity provider */ + /** + * Creates a new Vamana diversity provider. 
+ * @param scoreProvider the score provider used to compute diversity + * @param alpha the diversity threshold (1.0 is equivalent to HNSW; Vamana uses 1.2 or more) + */ public VamanaDiversityProvider(BuildScoreProvider scoreProvider, float alpha) { this.scoreProvider = scoreProvider; this.alpha = alpha; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/package-info.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/package-info.java new file mode 100644 index 000000000..3a5020093 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/diversity/package-info.java @@ -0,0 +1,103 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides neighbor diversity selection strategies for graph construction. + *

+ * This package contains implementations of diversity providers that determine which neighbors + * to retain during graph construction. Diversity selection is critical for building high-quality + * proximity graphs that balance local connectivity with long-range edges. + * + *

+ * <h2>Diversity and Graph Quality</h2>
+ * <p>

+ * In graph-based vector search, simply connecting each node to its k nearest neighbors can lead + * to poor search performance. Diversity selection addresses two key issues: + *

    + *
+ * <ul>
+ *   <li>Clustering: Without diversity, nodes in dense regions may connect only to their
+ *       immediate cluster, making it difficult to reach distant regions of the vector space.</li>
+ *   <li>Graph traversability: A diverse neighbor set includes both short edges (for local
+ *       precision) and longer edges (for efficient navigation across the space).</li>
+ * </ul>
+ *

+ * <h2>Core Abstractions</h2>
+ * <ul>
+ *   <li>{@link io.github.jbellis.jvector.graph.diversity.DiversityProvider} - Interface for
+ *       diversity selection algorithms. Implementations select which neighbors to retain from
+ *       a candidate set while maintaining graph quality constraints.</li>
+ *   <li>{@link io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider} - Implementation
+ *       based on the DiskANN/Vamana Robust Prune algorithm. Uses an alpha parameter to control
+ *       the trade-off between short edges (high recall) and longer edges (better graph
+ *       navigability).</li>
+ * </ul>
+ *

+ * <h2>Vamana Diversity Algorithm</h2>
+ * <p>

+ * The Vamana diversity provider implements the Robust Prune algorithm from the DiskANN paper: + *

    + *
+ * <ol>
+ *   <li>Start with a candidate set of potential neighbors</li>
+ *   <li>For each candidate, check if adding it would create a "shortcut" - that is, if there's
+ *       already a neighbor that's closer to both the target node and the candidate</li>
+ *   <li>The alpha parameter controls how strict this test is:
+ *     <ul>
+ *       <li>alpha = 1.0: Only keep edges where no existing neighbor is closer (strictest)</li>
+ *       <li>alpha &gt; 1.0: Allow longer edges even when shortcuts exist (recommended)</li>
+ *       <li>Higher alpha values create more diverse graphs with better long-range connectivity</li>
+ *     </ul>
+ *   </li>
+ *   <li>Select up to maxDegree diverse neighbors</li>
+ * </ol>
+ * + *

+ * <h2>Usage in Graph Construction</h2>
+ * <p>

+ * Diversity providers are used by {@link io.github.jbellis.jvector.graph.GraphIndexBuilder} + * during graph construction: + *

+ * <pre>{@code
+ * // Create a diversity provider with alpha=1.2 for balanced diversity
+ * BuildScoreProvider scoreProvider = ...;
+ * DiversityProvider diversityProvider = new VamanaDiversityProvider(scoreProvider, 1.2f);
+ *
+ * // The GraphIndexBuilder uses the diversity provider internally
+ * GraphIndexBuilder builder = new GraphIndexBuilder(
+ *     scoreProvider,
+ *     dimension,
+ *     maxDegree,
+ *     beamWidth,
+ *     neighborOverflow,
+ *     1.2f,  // alpha passed to diversity provider
+ *     addHierarchy
+ * );
+ * }</pre>
+ * + *

+ * <h2>Alpha Parameter Guidelines</h2>
+ * <ul>
+ *   <li>alpha = 1.0: Creates an HNSW-like graph at the base layer (not recommended for
+ *       JVector's Vamana-based approach)</li>
+ *   <li>alpha = 1.2 (default): Good balance between recall and build efficiency</li>
+ *   <li>alpha &gt; 1.5: More diverse graphs with better long-range connectivity but
+ *       potentially lower recall for small beam widths</li>
+ * </ul>
+ *

+ * <h2>Thread Safety</h2>
+ * <p>

+ * {@code DiversityProvider} implementations are typically stateless (beyond immutable + * configuration) and thread-safe. The same instance can be shared across multiple threads + * during concurrent graph construction. + * + * @see io.github.jbellis.jvector.graph.diversity.DiversityProvider + * @see io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider + * @see io.github.jbellis.jvector.graph.GraphIndexBuilder + */ +package io.github.jbellis.jvector.graph.diversity; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/package-info.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/package-info.java new file mode 100644 index 000000000..9ce0b6138 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/package-info.java @@ -0,0 +1,167 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides core graph-based approximate nearest neighbor (ANN) search implementations. + *

+ * This package contains the primary graph data structures and algorithms for building and + * searching vector similarity indexes. JVector implements a hybrid approach combining + * DiskANN-inspired graph construction with optional HNSW-style hierarchical layers. + * + *

Core Concepts

+ * + *

Graph Index

+ *

+ * The graph index is a proximity graph where nodes represent vectors and edges connect + * similar vectors. JVector uses a Vamana-based construction algorithm that builds a + * high-quality base layer, with optional hierarchical layers for faster entry point selection. + * + *

Key Interfaces and Classes

+ * + *

Graph Representations

+ *
    + *
  • {@link io.github.jbellis.jvector.graph.ImmutableGraphIndex} - Immutable view of a graph + * index. All graph implementations provide this interface for thread-safe read access.
  • + *
  • {@link io.github.jbellis.jvector.graph.MutableGraphIndex} - Mutable graph index interface + * that supports adding nodes and edges.
  • + *
  • {@link io.github.jbellis.jvector.graph.OnHeapGraphIndex} - In-memory graph index + * implementation supporting concurrent construction and search.
  • + *
  • {@link io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex} - Memory-mapped graph + * index loaded from persistent storage (see {@link io.github.jbellis.jvector.graph.disk} + * package).
  • + *
+ * + *

Graph Construction

+ *
    + *
  • {@link io.github.jbellis.jvector.graph.GraphIndexBuilder} - Builder for constructing + * graph indexes. Supports concurrent graph construction with configurable parameters: + *
      + *
    • M - maximum edges per node (degree)
    • + *
    • beamWidth - search beam width during construction
    • + *
    • neighborOverflow - temporary overflow ratio during insertion
    • + *
    • alpha - diversity pruning parameter (controls edge length distribution)
    • + *
    • addHierarchy - whether to build HNSW-style hierarchical layers
    • + *
    + *
  • + *
+ * + *

Graph Search

+ *
    + *
  • {@link io.github.jbellis.jvector.graph.GraphSearcher} - Performs beam search on graph + * indexes to find approximate nearest neighbors. Supports: + *
      + *
    • Multi-layer hierarchical search
    • + *
    • Result reranking with exact distances
    • + *
    • Filtered search using {@link io.github.jbellis.jvector.util.Bits}
    • + *
    + *
  • + *
  • {@link io.github.jbellis.jvector.graph.SearchResult} - Encapsulates search results with + * node IDs, scores, and search statistics.
  • + *
+ * + *

Vector Access

+ *
    + *
  • {@link io.github.jbellis.jvector.graph.RandomAccessVectorValues} - Interface for random + * access to vectors by ordinal. Supports both shared and unshared implementations.
  • + *
  • {@link io.github.jbellis.jvector.graph.ListRandomAccessVectorValues} - In-memory vector + * storage backed by a List.
  • + *
  • {@link io.github.jbellis.jvector.graph.MapRandomAccessVectorValues} - Vector storage + * backed by a Map, useful for sparse vector sets.
  • + *
+ * + *

Data Structures

+ *
    + *
  • {@link io.github.jbellis.jvector.graph.NodeArray} - Specialized array for storing node + * IDs and scores, supporting efficient sorted insertion.
  • + *
  • {@link io.github.jbellis.jvector.graph.NodeQueue} - Priority queue for graph search.
  • + *
  • {@link io.github.jbellis.jvector.graph.NodesIterator} - Iterator over node ordinals.
  • + *
  • {@link io.github.jbellis.jvector.graph.ConcurrentNeighborMap} - Thread-safe neighbor + * storage for concurrent graph construction.
  • + *
+ * + *

Graph Construction Algorithm

+ *

+ * JVector's graph construction is based on the DiskANN/Vamana algorithm with extensions: + *

    + *
  1. For each new node: + *
      + *
    • Assign a hierarchical level (if hierarchy enabled)
    • + *
    • Search for approximate nearest neighbors using beam search
    • + *
    • Connect to diverse neighbors using robust pruning
    • + *
    • Update existing nodes' neighbor lists (backlinks)
    • + *
    + *
  2. + *
  3. Concurrent insertions track in-progress nodes to maintain consistency
  4. + *
  5. After all insertions, cleanup phase: + *
      + *
    • Remove deleted nodes and update connections
    • + *
    • Optionally refine connections for improved recall
    • + *
    • Enforce degree constraints
    • + *
    + *
  6. + *
+ * + *

+ * <h2>Usage Example</h2>
+ * <pre>{@code
+ * // Build a graph index
+ * RandomAccessVectorValues vectors = new ListRandomAccessVectorValues(vectorList, dimension);
+ * GraphIndexBuilder builder = new GraphIndexBuilder(
+ *     vectors,
+ *     VectorSimilarityFunction.COSINE,
+ *     16,    // M (max degree)
+ *     100,   // beamWidth
+ *     1.2f,  // neighborOverflow
+ *     1.2f,  // alpha
+ *     true   // addHierarchy
+ * );
+ * ImmutableGraphIndex graph = builder.build(vectors);
+ *
+ * // Search the graph
+ * try (var view = graph.getView();
+ *      var searcher = new GraphSearcher(graph)) {
+ *     VectorFloat query = ...;
+ *     SearchScoreProvider ssp = BuildScoreProvider
+ *         .randomAccessScoreProvider(vectors, VectorSimilarityFunction.COSINE)
+ *         .searchProviderFor(query);
+ *     SearchResult result = searcher.search(ssp, 10, Bits.ALL);
+ *     for (SearchResult.NodeScore ns : result.getNodes()) {
+ *         System.out.printf("Node %d: score %.4f%n", ns.node, ns.score);
+ *     }
+ * }
+ * }</pre>
+ * + *

+ * <h2>Thread Safety</h2>
+ * <ul>
+ *   <li>{@code ImmutableGraphIndex} and its views are thread-safe for concurrent reads</li>
+ *   <li>{@code GraphIndexBuilder} supports concurrent insertions via {@code addGraphNode}</li>
+ *   <li>{@code GraphSearcher} instances are stateful and not thread-safe; create one per thread</li>
+ *   <li>{@code RandomAccessVectorValues} implementations may be shared or unshared; check
+ *       {@code isValueShared()} and use {@code threadLocalSupplier()} for thread-safe access</li>
+ * </ul>
+ *

+ * <h2>Related Packages</h2>
+ * <ul>
+ *   <li>{@link io.github.jbellis.jvector.graph.disk} - On-disk graph persistence</li>
+ *   <li>{@link io.github.jbellis.jvector.graph.similarity} - Similarity scoring abstractions</li>
+ *   <li>{@link io.github.jbellis.jvector.graph.diversity} - Diversity providers for neighbor selection</li>
+ * </ul>
+ * + * @see io.github.jbellis.jvector.graph.GraphIndexBuilder + * @see io.github.jbellis.jvector.graph.GraphSearcher + * @see io.github.jbellis.jvector.graph.ImmutableGraphIndex + * @see io.github.jbellis.jvector.graph.MutableGraphIndex + */ +package io.github.jbellis.jvector.graph; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java index b8ec5fa5f..2c2f82d6c 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java @@ -29,48 +29,51 @@ * Encapsulates comparing node distances for GraphIndexBuilder. */ public interface BuildScoreProvider { + /** Vector type support for creating and manipulating vectors. */ VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); /** - * @return true if the primary score functions used for construction are exact. This - * is modestly redundant, but it saves having to allocate new Search/Diversity provider + * Returns true if the primary score functions used for construction are exact. + * This is modestly redundant, but it saves having to allocate new Search/Diversity provider * objects in some hot construction loops. + * @return true if the primary score functions are exact, false if they are approximate */ boolean isExact(); /** - * @return the approximate centroid of the known nodes. We use the closest node - * to this centroid as the graph entry point, so this is called when the entry point is deleted + * Returns the approximate centroid of the known nodes. The closest node to this centroid + * is used as the graph entry point, so this is called when the entry point is deleted * or every time the graph size doubles. *

* This is not called on a path that blocks searches or modifications, so it is okay for it to be O(N). + * @return the approximate centroid of the known nodes */ VectorFloat approximateCentroid(); /** - * Create a search score provider to use *internally* during construction. + * Creates a search score provider to use internally during construction. *

* "Internally" means that this may differ from a typical SSP in that it may use - * approximate scores *without* reranking. (In this case, reranking will be done + * approximate scores without reranking. (In this case, reranking will be done * separately by the ConcurrentNeighborSet diversity code.) - *

* @param vector the query vector to provide similarity scores against + * @return a SearchScoreProvider for the given query vector */ SearchScoreProvider searchProviderFor(VectorFloat vector); /** - * Create a search score provider to use *internally* during construction. + * Creates a search score provider to use internally during construction. *

* "Internally" means that this may differ from a typical SSP in that it may use - * approximate scores *without* reranking. (In this case, reranking will be done + * approximate scores without reranking. (In this case, reranking will be done * separately by the ConcurrentNeighborSet diversity code.) - *

* @param node1 the graph node to provide similarity scores against + * @return a SearchScoreProvider for the given node */ SearchScoreProvider searchProviderFor(int node1); /** - * Create a score provider to use internally during construction. + * Creates a score provider to use internally during construction. *

* The difference between the diversity provider and the search provider is * that the diversity provider is only expected to be used a few dozen times per node, @@ -78,11 +81,16 @@ public interface BuildScoreProvider { *

* When scoring is approximate, the scores from the search and diversity provider * must be consistent, i.e. mixing different types of CompressedVectors will cause problems. + * @param node1 the graph node to provide diversity scores against + * @return a SearchScoreProvider for diversity computation */ SearchScoreProvider diversityProviderFor(int node1); /** * Returns a BSP that performs exact score comparisons using the given RandomAccessVectorValues and VectorSimilarityFunction. + * @param ravv the RandomAccessVectorValues providing access to vectors + * @param similarityFunction the similarity function to use for scoring + * @return a BuildScoreProvider that performs exact score comparisons */ static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, VectorSimilarityFunction similarityFunction) { // We need two sources of vectors in order to perform diversity check comparisons without @@ -138,7 +146,11 @@ public SearchScoreProvider diversityProviderFor(int node1) { * with reranking performed using RandomAccessVectorValues (which is intended to be * InlineVectorValues for building incrementally, but should technically * work with any RAVV implementation). - * This class is not thread safe, we should never publish its results to another thread. + *

+ * This implementation is not thread safe; results should never be published to another thread. + * @param vsf the vector similarity function to use + * @param pqv the product quantized vectors + * @return a BuildScoreProvider that uses product quantization for approximate scoring */ static BuildScoreProvider pqBuildScoreProvider(VectorSimilarityFunction vsf, PQVectors pqv) { int dimension = pqv.getOriginalSize() / Float.BYTES; @@ -179,6 +191,13 @@ public VectorFloat approximateCentroid() { }; } + /** + * Returns a BSP that performs approximate score comparisons using binary quantization. + * Binary quantization compresses vectors to single bits and uses Hamming distance for + * similarity computation, providing a fast approximation suitable for cosine similarity. + * @param bqv the binary quantized vectors + * @return a BuildScoreProvider that uses binary quantization for approximate scoring + */ static BuildScoreProvider bqBuildScoreProvider(BQVectors bqv) { return new BuildScoreProvider() { @Override diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java index 0754b39d7..9dbe77bfd 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java @@ -26,22 +26,25 @@ public final class DefaultSearchScoreProvider implements SearchScoreProvider { private final ScoreFunction.ExactScoreFunction reranker; /** - * @param scoreFunction the primary, fast scoring function - *

+ * Constructs a DefaultSearchScoreProvider with only a primary scoring function. * No reranking is performed. + * + * @param scoreFunction the primary, fast scoring function */ public DefaultSearchScoreProvider(ScoreFunction scoreFunction) { this(scoreFunction, null); } /** - * @param scoreFunction the primary, fast scoring function - * @param reranker optional reranking function - * Generally, reranker will be null iff scoreFunction is an ExactScoreFunction. However, + * Constructs a DefaultSearchScoreProvider with a primary scoring function and optional reranking. + * Generally, reranker will be null iff scoreFunction is an ExactScoreFunction. However, * it is allowed, and sometimes useful, to only perform approximate scoring without reranking. - *

- * Most often it will be convenient to get the reranker either using `RandomAccessVectorValues.rerankerFor` - * or `ScoringView.rerankerFor`. + * + * Most often it will be convenient to get the reranker either using {@code RandomAccessVectorValues.rerankerFor} + * or {@code ScoringView.rerankerFor}. + * + * @param scoreFunction the primary, fast scoring function + * @param reranker optional reranking function (may be null) */ public DefaultSearchScoreProvider(ScoreFunction scoreFunction, ScoreFunction.ExactScoreFunction reranker) { assert scoreFunction != null; @@ -64,9 +67,14 @@ public ScoreFunction.ExactScoreFunction exactScoreFunction() { } /** - * A SearchScoreProvider for a single-pass search based on exact similarity. + * Creates a SearchScoreProvider for a single-pass search based on exact similarity. * Generally only suitable when your RandomAccessVectorValues is entirely in-memory, * e.g. during construction. + * + * @param v the query vector + * @param vsf the vector similarity function to use + * @param ravv the random access vector values to search + * @return a DefaultSearchScoreProvider configured for exact search */ public static DefaultSearchScoreProvider exact(VectorFloat v, VectorSimilarityFunction vsf, RandomAccessVectorValues ravv) { // don't use ESF.reranker, we need thread safety here diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/ScoreFunction.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/ScoreFunction.java index 0dcb95823..9b8fdf667 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/ScoreFunction.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/ScoreFunction.java @@ -21,7 +21,7 @@ import io.github.jbellis.jvector.vector.types.VectorTypeSupport; /** - * Provides an API for encapsulating similarity to another node or vector. Used both for + * Provides an API for encapsulating similarity to another node or vector. 
Used both for * building the graph (as part of NodeSimilarity) or for searching it (used standalone, * with a reference to the query vector). *

@@ -29,40 +29,68 @@ * can be defined as a simple lambda. */ public interface ScoreFunction { + /** Vector type support for creating vector instances. */ VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); /** + * Returns true if the ScoreFunction returns exact, full-resolution scores. + * * @return true if the ScoreFunction returns exact, full-resolution scores */ boolean isExact(); /** + * Returns the similarity to one other node. + * + * @param node2 the node to compute similarity to * @return the similarity to one other node */ float similarityTo(int node2); /** - * @return the similarity to all of the nodes that `node2` has an edge towards. - * Used when expanding the neighbors of a search candidate. + * Returns the similarity to all of the nodes that the given node has an edge towards. + * Used when expanding the neighbors of a search candidate for bulk similarity computations. + * + * @param node2 the node whose neighbors should be scored + * @return a vector containing similarity scores to each neighbor + * @throws UnsupportedOperationException if bulk similarity is not supported */ default VectorFloat edgeLoadingSimilarityTo(int node2) { throw new UnsupportedOperationException("bulk similarity not supported"); } /** + * Returns true if edge loading similarity is supported (i.e., if edgeLoadingSimilarityTo can be called). + * * @return true if `edgeLoadingSimilarityTo` is supported */ default boolean supportsEdgeLoadingSimilarity() { return false; } + /** + * A score function that returns exact, full-resolution similarity scores. + */ interface ExactScoreFunction extends ScoreFunction { + /** + * Returns true to indicate this is an exact score function. + * + * @return true + */ default boolean isExact() { return true; } } + /** + * A score function that returns approximate similarity scores, potentially using compressed vectors. 
+ */ interface ApproximateScoreFunction extends ScoreFunction { + /** + * Returns false to indicate this is an approximate score function. + * + * @return false + */ default boolean isExact() { return false; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/SearchScoreProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/SearchScoreProvider.java index 4122f7105..dfea53793 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/SearchScoreProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/SearchScoreProvider.java @@ -19,9 +19,24 @@ /** Encapsulates comparing node distances to a specific vector for GraphSearcher. */ public interface SearchScoreProvider { + /** + * Returns the primary score function for fast approximate scoring. + * + * @return the score function + */ ScoreFunction scoreFunction(); + /** + * Returns the optional reranking function for more accurate scoring. + * + * @return the reranker, or null if no reranking is performed + */ ScoreFunction.ExactScoreFunction reranker(); + /** + * Returns the exact score function, either the primary function if it is exact, or the reranker. + * + * @return the exact score function + */ ScoreFunction.ExactScoreFunction exactScoreFunction(); } \ No newline at end of file diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/package-info.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/package-info.java new file mode 100644 index 000000000..8b6b7bfba --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/package-info.java @@ -0,0 +1,156 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides abstractions for vector similarity scoring during graph construction and search. + *

+ * This package defines a layered abstraction for computing similarity scores between vectors, + * supporting both exact and approximate scoring strategies, optional reranking, and various + * optimization techniques like quantization and caching. + * + *

+ * <h2>Scoring Architecture</h2>
+ * <p>

+ * The package provides a three-level scoring hierarchy: + *

    + *
  1. BuildScoreProvider - Top-level provider configured during graph construction. + * Creates search-specific score providers for each node or query.
  2. + *
  3. SearchScoreProvider - Per-query or per-node provider that creates actual + * score functions and manages approximate/exact scoring strategies.
  4. + *
  5. ScoreFunction - Performs the actual similarity computations between vectors.
  6. + *
+ * + *

Core Interfaces

+ * + *

Score Providers

+ *
    + *
  • {@link io.github.jbellis.jvector.graph.similarity.BuildScoreProvider} - Top-level + * interface for creating score providers. Maintains shared state like vector data and + * quantization codebooks. Factory methods support various use cases: + *
      + *
    • {@code randomAccessScoreProvider()} - For in-memory vectors
    • + *
    • {@code pqBuildScoreProvider()} - For Product Quantization
    • + *
    • Other variants for NVQ, binary quantization, and fused approaches
    • + *
    + *
  • + *
  • {@link io.github.jbellis.jvector.graph.similarity.SearchScoreProvider} - Per-query + * interface that creates score functions. Supports: + *
      + *
    • Approximate scoring (using quantized vectors)
    • + *
    • Exact scoring (using full-precision vectors)
    • + *
    • Optional reranking to improve precision
    • + *
    + *
  • + *
  • {@link io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider} - Default + * implementation wrapping a single score function.
  • + *
+ * + *

Score Functions

+ *
    + *
  • {@link io.github.jbellis.jvector.graph.similarity.ScoreFunction} - Core interface for + * computing similarity scores. Methods: + *
      + *
    • {@code similarityTo(int)} - Compute similarity to a single node
    • + *
    • {@code edgeLoadingSimilarityTo(int)} - Bulk similarity for all neighbors (optional)
    • + *
    • {@code isExact()} - Indicates if scores are exact or approximate
    • + *
    + *
  • + *
  • {@link io.github.jbellis.jvector.graph.similarity.ScoreFunction.ExactScoreFunction} - + * Marker interface for exact scoring implementations.
  • + *
  • {@link io.github.jbellis.jvector.graph.similarity.ScoreFunction.ApproximateScoreFunction} - + * Marker interface for approximate scoring implementations (typically using quantization).
  • + *
+ * + *

Utility Classes

+ *
    + *
  • {@link io.github.jbellis.jvector.graph.similarity.CachingVectorValues} - Wrapper that + * caches vector access to improve performance when vectors are accessed multiple times.
  • + *
+ * + *

Approximate vs. Exact Scoring

+ *

+ * JVector supports a two-phase scoring strategy: + *

    + *
  1. Approximate scoring: Used during graph traversal to quickly identify candidates. + * Typically uses quantized vectors (PQ, NVQ, or binary quantization) for speed.
  2. + *
  3. Exact scoring: Optional reranking of top candidates using full-precision vectors + * for better accuracy. Controlled by {@code rerankFloor} parameter in search.
  4. + *
+ * + *

Usage Examples

+ * + *

Simple Exact Scoring

+ *
{@code
+ * // Create a score provider for exact scoring
+ * RandomAccessVectorValues vectors = ...;
+ * BuildScoreProvider buildProvider = BuildScoreProvider.randomAccessScoreProvider(
+ *     vectors,
+ *     VectorSimilarityFunction.COSINE
+ * );
+ *
+ * // Create a search provider for a query
+ * VectorFloat query = ...;
+ * SearchScoreProvider searchProvider = buildProvider.searchProviderFor(query);
+ *
+ * // Get a score function and compute similarities
+ * ScoreFunction scoreFunction = searchProvider.scoreFunction();
+ * float score = scoreFunction.similarityTo(nodeId);
+ * }
+ * + *

Approximate Scoring with Reranking

+ *
{@code
+ * // Create a PQ-based score provider
+ * ProductQuantization pq = ProductQuantization.compute(vectors, 16, 256);
+ * BuildScoreProvider buildProvider = BuildScoreProvider.pqBuildScoreProvider(
+ *     vectors,
+ *     VectorSimilarityFunction.COSINE,
+ *     pq
+ * );
+ *
+ * // During search, use approximate scores for traversal
+ * SearchScoreProvider searchProvider = buildProvider.searchProviderFor(query);
+ * ScoreFunction.ApproximateScoreFunction approx = searchProvider.scoreFunction();
+ *
+ * // Rerank top candidates with exact scores
+ * ScoreFunction.ExactScoreFunction exact = searchProvider.exactScoreFunction();
+ * for (int candidate : topCandidates) {
+ *     float exactScore = exact.similarityTo(candidate);
+ * }
+ * }
+ * + *

Performance Considerations

+ *
    + *
  • Approximate scoring: 5-10x faster than exact scoring with quantization, at the + * cost of some precision loss.
  • + *
  • Reranking: Adding exact reranking typically improves recall by 1-5% with minimal + * performance impact when reranking only top-k candidates.
  • + *
  • Edge loading: Bulk similarity computation can be 2-3x faster than individual + * queries when supported by the quantization method.
  • + *
+ * + *

Thread Safety

+ *
    + *
  • {@code BuildScoreProvider} implementations are typically thread-safe and can be shared.
  • + *
  • {@code SearchScoreProvider} instances are lightweight and can be created per query.
  • + *
  • {@code ScoreFunction} instances are typically not thread-safe and should be + * created per thread (or per search operation).
  • + *
+ * + * @see io.github.jbellis.jvector.graph.similarity.BuildScoreProvider + * @see io.github.jbellis.jvector.graph.similarity.SearchScoreProvider + * @see io.github.jbellis.jvector.graph.similarity.ScoreFunction + * @see io.github.jbellis.jvector.quantization + */ +package io.github.jbellis.jvector.graph.similarity; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BQVectors.java index b89bd9c4c..ad19384f1 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BQVectors.java @@ -28,10 +28,24 @@ import java.util.Arrays; import java.util.Objects; +/** + * Abstract base class for collections of binary quantized vectors. + *

+ * Binary quantization compresses each float vector into a compact bit representation, + * where each float is represented by a single bit. Similarity is computed using Hamming + * distance, which provides a fast approximation particularly suitable for cosine similarity. + */ public abstract class BQVectors implements CompressedVectors { + /** The binary quantization compressor used by this instance. */ protected final BinaryQuantization bq; + + /** The compressed vector data, stored as arrays of longs. */ protected long[][] compressedVectors; + /** + * Constructs a BQVectors instance with the given binary quantization compressor. + * @param bq the binary quantization compressor + */ protected BQVectors(BinaryQuantization bq) { this.bq = bq; } @@ -55,6 +69,13 @@ public void write(DataOutput out, int version) throws IOException { } } + /** + * Loads binary quantized vectors from the given RandomAccessReader at the specified offset. + * @param in the RandomAccessReader to load from + * @param offset the offset position to start reading from + * @return a BQVectors instance containing the loaded vectors + * @throws IOException if an I/O error occurs or the data format is invalid + */ public static BQVectors load(RandomAccessReader in, long offset) throws IOException { in.seek(offset); @@ -113,10 +134,22 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat q, }; } + /** + * Computes the similarity between two binary quantized vectors using Hamming distance. + * The similarity is normalized to the range [0, 1], where 1 represents identical vectors. + * @param encoded1 the first encoded vector + * @param encoded2 the second encoded vector + * @return the similarity score between 0 and 1 + */ public float similarityBetween(long[] encoded1, long[] encoded2) { return 1 - (float) VectorUtil.hammingDistance(encoded1, encoded2) / bq.getOriginalDimension(); } + /** + * Returns the compressed vector at the specified index. 
+ * @param i the index of the vector to retrieve + * @return the compressed vector as an array of longs + */ public long[] get(int i) { return compressedVectors[i]; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java index 356353605..d4c7a776c 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java @@ -37,12 +37,19 @@ public class BinaryQuantization implements VectorCompressor { private final int dimension; + /** + * Creates a BinaryQuantization instance for vectors of the specified dimension. + * @param dimension the dimension of the vectors to be quantized + */ public BinaryQuantization(int dimension) { this.dimension = dimension; } /** - * Use BQ constructor instead + * Creates a BinaryQuantization instance from the given RandomAccessVectorValues. + * @param ravv the RandomAccessVectorValues to create quantization from + * @return a BinaryQuantization instance + * @deprecated use {@link #BinaryQuantization(int)} constructor instead */ @Deprecated public static BinaryQuantization compute(RandomAccessVectorValues ravv) { @@ -50,7 +57,11 @@ public static BinaryQuantization compute(RandomAccessVectorValues ravv) { } /** - * Use BQ constructor instead + * Creates a BinaryQuantization instance from the given RandomAccessVectorValues. 
+ * @param ravv the RandomAccessVectorValues to create quantization from + * @param parallelExecutor the ForkJoinPool for parallel execution (unused) + * @return a BinaryQuantization instance + * @deprecated use {@link #BinaryQuantization(int)} constructor instead */ @Deprecated public static BinaryQuantization compute(RandomAccessVectorValues ravv, ForkJoinPool parallelExecutor) { @@ -128,10 +139,20 @@ public void write(DataOutput out, int version) throws IOException { vts.writeFloatVector(out, vts.createFloatVector(dimension)); } + /** + * Returns the original dimension of the vectors being quantized. + * @return the original dimension + */ public int getOriginalDimension() { return dimension; } + /** + * Loads a BinaryQuantization instance from the given RandomAccessReader. + * @param in the RandomAccessReader to load from + * @return a BinaryQuantization instance + * @throws IOException if an I/O error occurs + */ public static BinaryQuantization load(RandomAccessReader in) throws IOException { int dimension = in.readInt(); // We used to record the center of the dataset but this actually degrades performance. diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/CompressedVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/CompressedVectors.java index ee60859b7..86327e605 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/CompressedVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/CompressedVectors.java @@ -25,45 +25,92 @@ import java.io.DataOutput; import java.io.IOException; +/** + * Represents a collection of vectors that have been compressed using a {@link VectorCompressor}. + *

+ * This interface provides methods for serialization, size information, and creating score functions + * for similarity comparisons. Compressed vectors trade precision for reduced memory footprint, + * enabling efficient approximate nearest neighbor search. + */ public interface CompressedVectors extends Accountable { /** - * Write the compressed vectors to the given DataOutput + * Writes the compressed vectors to the given DataOutput using the specified serialization version. * @param out the DataOutput to write to - * @param version the serialization version. versions 2 and 3 are supported + * @param version the serialization version; versions 2 and 3 are supported + * @throws IOException if an I/O error occurs during writing */ void write(DataOutput out, int version) throws IOException; /** - * Write the compressed vectors to the given DataOutput at the current serialization version + * Writes the compressed vectors to the given DataOutput at the current serialization version. + * @param out the DataOutput to write to + * @throws IOException if an I/O error occurs during writing */ default void write(DataOutput out) throws IOException { write(out, OnDiskGraphIndex.CURRENT_VERSION); } - /** @return the original size of each vector, in bytes, before compression */ + /** + * Returns the original size of each vector in bytes, before compression. + * @return the original size of each vector, in bytes + */ int getOriginalSize(); - /** @return the compressed size of each vector, in bytes */ + /** + * Returns the compressed size of each vector in bytes. + * @return the compressed size of each vector, in bytes + */ int getCompressedSize(); - /** @return the compressor used by this instance */ + /** + * Returns the compressor used by this instance. 
+ * @return the compressor used by this instance + */ VectorCompressor getCompressor(); - /** precomputes partial scores for the given query with every centroid; suitable for most searches */ + /** + * Creates an approximate score function with precomputed partial scores for the query vector + * against every centroid. This is suitable for most search operations where precomputation + * cost can be amortized across many score comparisons. + * @param q the query vector + * @param similarityFunction the similarity function to use for scoring + * @return an approximate score function with precomputed scores + */ ScoreFunction.ApproximateScoreFunction precomputedScoreFunctionFor(VectorFloat q, VectorSimilarityFunction similarityFunction); - /** no precomputation; suitable when just a handful of score computations are performed */ + /** + * Creates an approximate score function without precomputation, suitable for diversity checks + * where only a handful of score computations are performed per node. + * @param nodeId the node ID to compute scores against + * @param similarityFunction the similarity function to use for scoring + * @return an approximate score function without precomputation + */ ScoreFunction.ApproximateScoreFunction diversityFunctionFor(int nodeId, VectorSimilarityFunction similarityFunction); - /** no precomputation; suitable when just a handful of score computations are performed */ + /** + * Creates an approximate score function without precomputation, suitable when only a small number + * of score computations are performed. + * @param q the query vector + * @param similarityFunction the similarity function to use for scoring + * @return an approximate score function without precomputation + */ ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat q, VectorSimilarityFunction similarityFunction); - + /** + * Creates an approximate score function for the given query vector. 
+ * @param q the query vector + * @param similarityFunction the similarity function to use for scoring + * @return an approximate score function with precomputed scores + * @deprecated use {@link #precomputedScoreFunctionFor(VectorFloat, VectorSimilarityFunction)} instead + */ @Deprecated default ScoreFunction.ApproximateScoreFunction approximateScoreFunctionFor(VectorFloat q, VectorSimilarityFunction similarityFunction) { return precomputedScoreFunctionFor(q, similarityFunction); } - /** the number of vectors */ + /** + * Returns the number of compressed vectors in this collection. + * @return the number of vectors + */ int count(); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedADCPQDecoder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedADCPQDecoder.java index d55ffbd8c..0878f7927 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedADCPQDecoder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedADCPQDecoder.java @@ -33,28 +33,52 @@ */ public abstract class FusedADCPQDecoder implements ScoreFunction.ApproximateScoreFunction { private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); + /** The product quantization configuration */ protected final ProductQuantization pq; + /** The query vector */ protected final VectorFloat query; + /** The exact score function for fallback computations */ protected final ExactScoreFunction esf; + /** Quantized partial sums for efficient similarity computation */ protected final ByteSequence partialQuantizedSums; - // connected to the Graph View by caller + /** Provides access to packed neighbors for bulk operations */ protected final FusedADC.PackedNeighbors neighbors; - // caller passes this to us for re-use across calls + /** Reusable vector to store bulk similarity results */ protected final VectorFloat results; - // decoder state + /** Partial sums 
computed from codebooks */ protected final VectorFloat partialSums; + /** Best possible distances for each subspace */ protected final VectorFloat partialBestDistances; + /** Number of invocations before switching to quantized similarity mode */ protected final int invocationThreshold; + /** Current number of invocations */ protected int invocations = 0; + /** Best distance seen so far */ protected float bestDistance; + /** Worst distance seen so far */ protected float worstDistance; + /** Delta value for quantization */ protected float delta; + /** Whether quantized similarity mode is enabled */ protected boolean supportsQuantizedSimilarity = false; + /** The vector similarity function being used */ protected final VectorSimilarityFunction vsf; - // Implements section 3.4 of "Quicker ADC : Unlocking the Hidden Potential of Product Quantization with SIMD" - // The main difference is that since our graph structure rapidly converges towards the best results, - // we don't need to scan K values to have enough confidence that our worstDistance bound is reasonable. + /** + * Creates a new FusedADCPQDecoder for efficient approximate similarity computations. + *

+ * Implements section 3.4 of "Quicker ADC : Unlocking the Hidden Potential of Product Quantization with SIMD". + * The main difference is that since our graph structure rapidly converges towards the best results, + * we don't need to scan K values to have enough confidence that our worstDistance bound is reasonable. + * + * @param pq the product quantization to use for decoding + * @param query the query vector + * @param invocationThreshold the number of invocations before switching to quantized similarity mode + * @param neighbors provides access to packed neighbors for bulk operations + * @param results reusable vector to store bulk similarity results + * @param esf the exact score function for fallback computations + * @param vsf the vector similarity function to use + */ protected FusedADCPQDecoder(ProductQuantization pq, VectorFloat query, int invocationThreshold, FusedADC.PackedNeighbors neighbors, VectorFloat results, ExactScoreFunction esf, VectorSimilarityFunction vsf) { this.pq = pq; this.query = query; @@ -133,11 +157,34 @@ public float similarityTo(int node2) { return esf.similarityTo(node2); } + /** + * Converts a distance value to a similarity score based on the similarity function. + * + * @param distance the distance value to convert + * @return the similarity score + */ protected abstract float distanceToScore(float distance); + /** + * Updates the worst distance observed so far during search. + * + * @param distance the new distance value to consider + */ protected abstract void updateWorstDistance(float distance); + /** + * Decoder specialized for dot product similarity function. + */ static class DotProductDecoder extends FusedADCPQDecoder { + /** + * Creates a new DotProductDecoder. 
+ * + * @param neighbors provides access to packed neighbors for bulk operations + * @param pq the product quantization to use for decoding + * @param query the query vector + * @param results reusable vector to store bulk similarity results + * @param esf the exact score function for fallback computations + */ public DotProductDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat query, VectorFloat results, ExactScoreFunction esf) { super(pq, query, neighbors.maxDegree(), neighbors, results, esf, VectorSimilarityFunction.DOT_PRODUCT); worstDistance = Float.MAX_VALUE; // initialize at best value, update as we search @@ -154,7 +201,19 @@ protected void updateWorstDistance(float distance) { } } + /** + * Decoder specialized for Euclidean distance similarity function. + */ static class EuclideanDecoder extends FusedADCPQDecoder { + /** + * Creates a new EuclideanDecoder. + * + * @param neighbors provides access to packed neighbors for bulk operations + * @param pq the product quantization to use for decoding + * @param query the query vector + * @param results reusable vector to store bulk similarity results + * @param esf the exact score function for fallback computations + */ public EuclideanDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat query, VectorFloat results, ExactScoreFunction esf) { super(pq, query, neighbors.maxDegree(), neighbors, results, esf, VectorSimilarityFunction.EUCLIDEAN); worstDistance = 0; // initialize at best value, update as we search @@ -172,9 +231,13 @@ protected void updateWorstDistance(float distance) { } - // CosineDecoder differs from DotProductDecoder/EuclideanDecoder because there are two different tables of quantized fragments to sum: query to codebook entry dot products, - // and codebook entry to codebook entry dot products. The latter can be calculated once per ProductQuantization, but for lookups to go at the appropriate speed, they must - // also be quantized. 
We use a similar quantization to partial sums, but we know exactly the worst/best bounds, so overflow does not matter. + /** + * Decoder specialized for cosine similarity function. + *

+ * CosineDecoder differs from DotProductDecoder/EuclideanDecoder because there are two different tables of quantized fragments to sum: query to codebook entry dot products, + * and codebook entry to codebook entry dot products. The latter can be calculated once per ProductQuantization, but for lookups to go at the appropriate speed, they must + * also be quantized. We use a similar quantization to partial sums, but we know exactly the worst/best bounds, so overflow does not matter. + */ static class CosineDecoder extends FusedADCPQDecoder { private final float queryMagnitudeSquared; private final VectorFloat partialSquaredMagnitudes; @@ -186,6 +249,15 @@ static class CosineDecoder extends FusedADCPQDecoder { private float minSquaredMagnitude; private float squaredMagnitudeDelta; + /** + * Creates a new CosineDecoder. + * + * @param neighbors provides access to packed neighbors for bulk operations + * @param pq the product quantization to use for decoding + * @param query the query vector + * @param results reusable vector to store bulk similarity results + * @param esf the exact score function for fallback computations + */ protected CosineDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat query, VectorFloat results, ExactScoreFunction esf) { super(pq, query, neighbors.maxDegree(), neighbors, results, esf, VectorSimilarityFunction.COSINE); worstDistance = Float.MAX_VALUE; // initialize at best value, update as we search @@ -301,6 +373,18 @@ protected void updateWorstDistance(float distance) { }; } + /** + * Factory method that creates the appropriate decoder based on the similarity function. 
+ * + * @param neighbors provides access to packed neighbors for bulk operations + * @param pq the product quantization to use for decoding + * @param query the query vector + * @param results reusable vector to store bulk similarity results + * @param similarityFunction the vector similarity function to use + * @param esf the exact score function for fallback computations + * @return a new decoder instance appropriate for the similarity function + * @throws IllegalArgumentException if the similarity function is not supported + */ public static FusedADCPQDecoder newDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat query, VectorFloat results, VectorSimilarityFunction similarityFunction, ExactScoreFunction esf) { switch (similarityFunction) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutableBQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutableBQVectors.java index 4acc4744d..b98acbe6f 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutableBQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutableBQVectors.java @@ -16,12 +16,27 @@ package io.github.jbellis.jvector.quantization; +/** + * An immutable collection of binary quantized vectors. + * This class provides read-only access to a fixed set of compressed vectors. + */ public class ImmutableBQVectors extends BQVectors { + /** + * Creates a new ImmutableBQVectors instance with the given quantization and compressed vectors. + * + * @param bq the binary quantization configuration + * @param compressedVectors the array of compressed vector data + */ public ImmutableBQVectors(BinaryQuantization bq, long[][] compressedVectors) { super(bq); this.compressedVectors = compressedVectors; } + /** + * Returns the number of vectors in this collection. 
+ * + * @return the count of compressed vectors + */ @Override public int count() { return compressedVectors.length; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutablePQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutablePQVectors.java index 8df05b675..d3c016405 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutablePQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutablePQVectors.java @@ -25,6 +25,9 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +/** + * An immutable implementation of PQVectors with precomputed codebook partial sums for efficient diversity scoring. + */ public class ImmutablePQVectors extends PQVectors { private final int vectorCount; private final Map> codebookPartialSumsMap; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/KMeansPlusPlusClusterer.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/KMeansPlusPlusClusterer.java index 13d683327..665e89480 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/KMeansPlusPlusClusterer.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/KMeansPlusPlusClusterer.java @@ -38,6 +38,7 @@ public class KMeansPlusPlusClusterer { private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); + /** Sentinel value indicating unweighted (isotropic) L2 distance should be used for clustering. */ public static final float UNWEIGHTED = -1.0f; // number of centroids to compute @@ -58,15 +59,23 @@ public class KMeansPlusPlusClusterer { private final VectorFloat[] centroidNums; // the sum of all points assigned to each cluster /** - * Constructs a KMeansPlusPlusFloatClusterer with the specified points and number of clusters. 
+ * Constructs a KMeansPlusPlusClusterer with the specified points and number of clusters, using unweighted L2 distance. * * @param points the points to cluster (points[n][i] is the ith component of the nth point) - * @param k number of clusters. + * @param k the number of clusters */ public KMeansPlusPlusClusterer(VectorFloat[] points, int k) { this(points, chooseInitialCentroids(points, k), UNWEIGHTED); } + /** + * Constructs a KMeansPlusPlusClusterer with the specified points, number of clusters, and anisotropic threshold. + * + * @param points the points to cluster (points[n][i] is the ith component of the nth point) + * @param k the number of clusters + * @param anisotropicThreshold the threshold of relevance for anisotropic angular distance shaping, giving + * higher priority to parallel error. Use {@link #UNWEIGHTED} for normal isotropic L2 distance. + */ public KMeansPlusPlusClusterer(VectorFloat[] points, int k, float anisotropicThreshold) { this(points, chooseInitialCentroids(points, k), anisotropicThreshold); } @@ -108,10 +117,14 @@ public KMeansPlusPlusClusterer(VectorFloat[] points, VectorFloat centroids } /** - * Compute the parallel cost multiplier for a given threshold and squared norm. + * Computes the parallel cost multiplier for a given threshold and squared norm. *

* This uses the approximation derived in Theorem 3.4 of * "Accelerating Large-Scale Inference with Anisotropic Vector Quantization". + * + * @param threshold the dot product threshold + * @param dimensions the number of dimensions in the vectors + * @return the parallel cost multiplier */ static float computeParallelCostMultiplier(double threshold, int dimensions) { assert Double.isFinite(threshold) : "threshold=" + threshold; @@ -126,7 +139,9 @@ static float computeParallelCostMultiplier(double threshold, int dimensions) { /** * Performs clustering on the provided set of points. * - * @return a VectorFloat of cluster centroids. + * @param unweightedIterations the number of unweighted clustering iterations to perform + * @param anisotropicIterations the number of anisotropic clustering iterations to perform + * @return a VectorFloat of cluster centroids */ public VectorFloat cluster(int unweightedIterations, int anisotropicIterations) { // Always cluster unweighted first, it is significantly faster @@ -148,11 +163,22 @@ public VectorFloat cluster(int unweightedIterations, int anisotropicIteration return centroids; } - // This is broken out as a separate public method to allow implementing OPQ efficiently + /** + * Performs one iteration of unweighted clustering. + * This is broken out as a separate public method to allow implementing OPQ efficiently. + * + * @return the number of points that changed clusters + */ public int clusterOnceUnweighted() { updateCentroidsUnweighted(); return updateAssignedPointsUnweighted(); } + + /** + * Performs one iteration of anisotropic clustering. + * + * @return the number of points that changed clusters + */ public int clusterOnceAnisotropic() { updateCentroidsAnisotropic(); return updateAssignedPointsAnisotropic(); @@ -311,7 +337,14 @@ private int updateAssignedPointsAnisotropic() { } /** - * Calculates the weighted distance between two data points. 
+ * Calculates the weighted distance between a data point and a centroid, using anisotropic distance shaping. + * + * @param x the data point + * @param centroid the index of the centroid + * @param parallelCostMultiplier the parallel cost multiplier + * @param cNormSquared the squared norm of the centroid + * @param xNormSquared the squared norm of the data point + * @return the weighted distance */ private float weightedDistance(VectorFloat x, int centroid, float parallelCostMultiplier, float cNormSquared, float xNormSquared) { float cDotX = VectorUtil.dotProduct(centroids, centroid * x.length(), x, 0, x.length()); @@ -324,7 +357,10 @@ private float weightedDistance(VectorFloat x, int centroid, float parallelCos } /** - * Return the index of the closest centroid to the given point + * Returns the index of the closest centroid to the given point. + * + * @param point the point to find the nearest cluster for + * @return the index of the nearest cluster */ private int getNearestCluster(VectorFloat point) { float minDistance = Float.MAX_VALUE; @@ -341,6 +377,12 @@ private int getNearestCluster(VectorFloat point) { return nearestCluster; } + /** + * Asserts that all elements of the vector are finite (not NaN or infinite). + * This assertion is only checked when assertions are enabled. + * + * @param vector the vector to check + */ @SuppressWarnings({"AssertWithSideEffects", "ConstantConditions"}) private static void assertFinite(VectorFloat vector) { boolean assertsEnabled = false; @@ -354,7 +396,7 @@ private static void assertFinite(VectorFloat vector) { } /** - * Calculates centroids from centroidNums/centroidDenoms updated during point assignment + * Calculates centroids from centroidNums/centroidDenoms updated during point assignment. */ private void updateCentroidsUnweighted() { for (int i = 0; i < k; i++) { @@ -370,12 +412,20 @@ private void updateCentroidsUnweighted() { } } + /** + * Initializes a centroid to a random point from the dataset. 
+ * + * @param i the index of the centroid to initialize + */ private void initializeCentroidToRandomPoint(int i) { var random = ThreadLocalRandom.current(); centroids.copyFrom(points[random.nextInt(points.length)], 0, i * points[0].length(), points[0].length()); } - // Uses the algorithm given in appendix 7.5 of "Accelerating Large-Scale Inference with Anisotropic Vector Quantization" + /** + * Updates centroids using anisotropic clustering. + * Uses the algorithm given in appendix 7.5 of "Accelerating Large-Scale Inference with Anisotropic Vector Quantization". + */ private void updateCentroidsAnisotropic() { int dimensions = points[0].length(); float pcm = computeParallelCostMultiplier(anisotropicThreshold, dimensions); @@ -432,6 +482,9 @@ private void updateCentroidsAnisotropic() { /** * Computes the centroid of a list of points. + * + * @param points the list of points + * @return the centroid vector */ public static VectorFloat centroidOf(List> points) { if (points.isEmpty()) { @@ -444,6 +497,11 @@ public static VectorFloat centroidOf(List> points) { return centroid; } + /** + * Returns the centroids computed by the clustering algorithm. + * + * @return the centroids as a single flattened vector + */ public VectorFloat getCentroids() { return centroids; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutableBQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutableBQVectors.java index 4256e3700..2fbc4a708 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutableBQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutableBQVectors.java @@ -18,11 +18,15 @@ import io.github.jbellis.jvector.vector.types.VectorFloat; +/** + * A threadsafe mutable BQVectors implementation that grows dynamically as needed. 
+ */ @SuppressWarnings("unused") public class MutableBQVectors extends BQVectors implements MutableCompressedVectors> { private static final int INITIAL_CAPACITY = 1024; private static final float GROWTH_FACTOR = 1.5f; - + + /** The number of vectors currently stored. */ protected int vectorCount; /** @@ -36,6 +40,11 @@ public MutableBQVectors(BinaryQuantization bq) { this.vectorCount = 0; } + /** + * Ensures the internal array has capacity to store a vector at the given ordinal. + * + * @param ordinal the ordinal to ensure capacity for + */ private void ensureCapacity(int ordinal) { if (ordinal >= compressedVectors.length) { int newCapacity = Math.max(ordinal + 1, (int)(compressedVectors.length * GROWTH_FACTOR)); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutableCompressedVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutableCompressedVectors.java index 49c4b0790..574bbb77f 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutableCompressedVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutableCompressedVectors.java @@ -16,6 +16,11 @@ package io.github.jbellis.jvector.quantization; +/** + * Interface for mutable compressed vector storage that allows adding and modifying vectors. + * + * @param the type of vectors to compress + */ public interface MutableCompressedVectors extends CompressedVectors { /** * Encode the given vector and set it at the given ordinal. Done without unnecessary copying. 
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutablePQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutablePQVectors.java index b47d2b9ce..6e61641ca 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutablePQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutablePQVectors.java @@ -35,6 +35,7 @@ public class MutablePQVectors extends PQVectors implements MutableCompressedVect private static final int INITIAL_CHUNKS = 10; private static final float GROWTH_FACTOR = 1.5f; + /** The number of vectors currently stored. */ protected AtomicInteger vectorCount; /** diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQScorer.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQScorer.java index a82520871..2e8d421c1 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQScorer.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQScorer.java @@ -20,16 +20,30 @@ import io.github.jbellis.jvector.vector.VectorUtil; import io.github.jbellis.jvector.vector.types.VectorFloat; +/** + * Provides scoring functions for comparing NVQ-quantized vectors with query vectors. + * Supports dot product, Euclidean, and cosine similarity functions. + */ public class NVQScorer { final NVQuantization nvq; /** - * Initialize the NVQScorer with an instance of NVQuantization. + * Constructs an NVQScorer with the given NVQuantization instance. + * + * @param nvq the NVQuantization instance to use for scoring */ public NVQScorer(NVQuantization nvq) { this.nvq = nvq; } + /** + * Creates a score function for comparing the query vector against NVQ-quantized vectors. 
+ * + * @param query the query vector + * @param similarityFunction the similarity function to use + * @return a score function for the given query and similarity function + * @throws IllegalArgumentException if the similarity function is not supported + */ public NVQScoreFunction scoreFunctionFor(VectorFloat query, VectorSimilarityFunction similarityFunction) { switch (similarityFunction) { case DOT_PRODUCT: @@ -43,6 +57,13 @@ public NVQScoreFunction scoreFunctionFor(VectorFloat query, VectorSimilarityF } } + /** + * Creates a dot product score function for the given query vector. + * + * @param query the query vector + * @return a dot product score function + * @throws IllegalArgumentException if the bits per dimension is not supported + */ private NVQScoreFunction dotProductScoreFunctionFor(VectorFloat query) { /* Each sub-vector of query vector (full resolution) will be compared to NVQ quantized sub-vectors that were * first de-meaned by subtracting the global mean. @@ -72,6 +93,14 @@ private NVQScoreFunction dotProductScoreFunctionFor(VectorFloat query) { } } + /** + * Creates a Euclidean distance score function for the given query vector. + * The score is converted to a similarity using 1 / (1 + distance). + * + * @param query the query vector + * @return a Euclidean similarity score function + * @throws IllegalArgumentException if the bits per dimension is not supported + */ private NVQScoreFunction euclideanScoreFunctionFor(VectorFloat query) { /* Each sub-vector of query vector (full resolution) will be compared to NVQ quantized sub-vectors that were * first de-meaned by subtracting the global mean. @@ -103,6 +132,13 @@ private NVQScoreFunction euclideanScoreFunctionFor(VectorFloat query) { } } + /** + * Creates a cosine similarity score function for the given query vector. 
+ * + * @param query the query vector + * @return a cosine similarity score function + * @throws IllegalArgumentException if the bits per dimension is not supported + */ private NVQScoreFunction cosineScoreFunctionFor(VectorFloat query) { float queryNorm = (float) Math.sqrt(VectorUtil.dotProduct(query, query)); var querySubVectors = this.nvq.getSubVectors(query); @@ -136,9 +172,15 @@ private NVQScoreFunction cosineScoreFunctionFor(VectorFloat query) { } } + /** + * A functional interface for computing similarity between a query and an NVQ-quantized vector. + */ public interface NVQScoreFunction { /** - * @return the similarity to another vector + * Computes the similarity score to another quantized vector. + * + * @param vector2 the quantized vector to compare against + * @return the similarity score */ float similarityTo(NVQuantization.QuantizedVector vector2); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQVectors.java index bf8019d9d..d6b89a10b 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQVectors.java @@ -27,14 +27,21 @@ import java.util.Arrays; import java.util.Objects; +/** + * A collection of vectors compressed using NVQ (Non-uniform Vector Quantization). + * This class implements CompressedVectors and provides scoring and serialization capabilities. + */ public class NVQVectors implements CompressedVectors { final NVQuantization nvq; final NVQScorer scorer; final NVQuantization.QuantizedVector[] compressedVectors; /** - * Initialize the NVQVectors with an initial array of vectors. This array may be + * Initializes the NVQVectors with an initial array of vectors. This array may be * mutated, but caller is responsible for thread safety issues when doing so. 
+ * + * @param nvq the NVQuantization compressor used for these vectors + * @param compressedVectors the array of quantized vectors */ public NVQVectors(NVQuantization nvq, NVQuantization.QuantizedVector[] compressedVectors) { this.nvq = nvq; @@ -42,11 +49,23 @@ public NVQVectors(NVQuantization nvq, NVQuantization.QuantizedVector[] compresse this.compressedVectors = compressedVectors; } + /** + * Returns the number of compressed vectors in this collection. + * + * @return the count of compressed vectors + */ @Override public int count() { return compressedVectors.length; } + /** + * Serializes this NVQVectors instance to a DataOutput. + * + * @param out the DataOutput to write to + * @param version the serialization version to use + * @throws IOException if an I/O error occurs during writing + */ @Override public void write(DataOutput out, int version) throws IOException { @@ -60,6 +79,13 @@ public void write(DataOutput out, int version) throws IOException } } + /** + * Deserializes an NVQVectors instance from a RandomAccessReader. + * + * @param in the RandomAccessReader to read from + * @return the deserialized NVQVectors instance + * @throws IOException if an I/O error occurs during reading or if the vector count is invalid + */ public static NVQVectors load(RandomAccessReader in) throws IOException { var nvq = NVQuantization.load(in); @@ -77,6 +103,14 @@ public static NVQVectors load(RandomAccessReader in) throws IOException { return new NVQVectors(nvq, compressedVectors); } + /** + * Deserializes an NVQVectors instance from a RandomAccessReader starting at a specific offset. 
+ * + * @param in the RandomAccessReader to read from + * @param offset the byte offset to start reading from + * @return the deserialized NVQVectors instance + * @throws IOException if an I/O error occurs during reading + */ public static NVQVectors load(RandomAccessReader in, long offset) throws IOException { in.seek(offset); return load(in); @@ -113,10 +147,21 @@ public ScoreFunction.ApproximateScoreFunction diversityFunctionFor(int node1, Ve throw new UnsupportedOperationException(); } + /** + * Returns the quantized vector at the specified ordinal. + * + * @param ordinal the index of the vector to retrieve + * @return the quantized vector at the specified index + */ public NVQuantization.QuantizedVector get(int ordinal) { return compressedVectors[ordinal]; } + /** + * Returns the NVQuantization compressor used by this collection. + * + * @return the NVQuantization instance + */ public NVQuantization getNVQuantization() { return nvq; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java index 0354e82be..d76771422 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java @@ -40,7 +40,14 @@ * It divides each vector in subvectors and then quantizes each one individually using a non-uniform quantizer. */ public class NVQuantization implements VectorCompressor, Accountable { + /** + * Enum representing the number of bits used per dimension during NVQ quantization. + * Determines the precision and compression ratio of the quantization. + */ public enum BitsPerDimension { + /** + * Eight bits per dimension (one byte per dimension). 
+ */ EIGHT { @Override public int getInt() { @@ -52,6 +59,9 @@ public ByteSequence createByteSequence(int nDimensions) { return vectorTypeSupport.createByteSequence(nDimensions); } }, + /** + * Four bits per dimension (half a byte per dimension). + */ FOUR { @Override public int getInt() { @@ -74,7 +84,9 @@ public void write(DataOutput out) throws IOException { } /** - * Returns the integer 4 for FOUR and 8 for EIGHT + * Returns the integer value representing the number of bits. + * + * @return 4 for FOUR, 8 for EIGHT */ public abstract int getInt(); @@ -87,8 +99,11 @@ public void write(DataOutput out) throws IOException { /** * Loads the BitsPerDimension from a RandomAccessReader. - * @param in the RandomAccessReader to read from. - * @throws IOException if there is a problem reading from in. + * + * @param in the RandomAccessReader to read from + * @return the loaded BitsPerDimension value + * @throws IOException if there is a problem reading from in + * @throws IllegalArgumentException if an unsupported bits per dimension value is encountered */ public static BitsPerDimension load(RandomAccessReader in) throws IOException { int nBitsPerDimension = in.readInt(); @@ -103,26 +118,39 @@ public static BitsPerDimension load(RandomAccessReader in) throws IOException { private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); - // How many bits to use for each dimension when quantizing the vector: + /** + * The number of bits to use for each dimension when quantizing the vector. + */ public final BitsPerDimension bitsPerDimension; - // We subtract the global mean vector to make it robust against center datasets with a large mean: + /** + * The global mean vector, subtracted to make quantization robust against datasets with large mean values. 
+ */ public final VectorFloat globalMean; - // The number of dimensions of the original (uncompressed) vectors: + /** + * The number of dimensions of the original (uncompressed) vectors. + */ public final int originalDimension; - // A matrix that stores the size and starting point of each subvector: + /** + * A matrix that stores the size and starting point of each subvector. + * Each row contains [size, offset] for a subvector. + */ public final int[][] subvectorSizesAndOffsets; - // Whether we want to skip the optimization of the NVQ parameters. Here for debug purposes only. + /** + * Whether to optimize the NVQ parameters during quantization. Set to false to skip optimization (for testing). + */ @VisibleForTesting public boolean learn = true; /** - * Class constructor. - * @param subvectorSizesAndOffsets a matrix that stores the size and starting point of each subvector - * @param globalMean the mean of the database (its average vector) + * Constructs an NVQuantization instance with the specified subvector configuration and global mean. + * + * @param subvectorSizesAndOffsets a matrix where each row contains [size, offset] for a subvector + * @param globalMean the mean vector of the dataset, used to center the data before quantization + * @throws IllegalArgumentException if the global mean length does not match the total vector dimensionality */ private NVQuantization(int[][] subvectorSizesAndOffsets, VectorFloat globalMean) { this.bitsPerDimension = BitsPerDimension.EIGHT; @@ -137,10 +165,12 @@ private NVQuantization(int[][] subvectorSizesAndOffsets, VectorFloat globalMe } /** - * Computes the global mean vector and the data structures used to divide each vector into subvectors. + * Computes an NVQuantization instance by calculating the global mean vector and determining + * the optimal division of vectors into subvectors. 
* * @param ravv the vectors to quantize - * @param nSubVectors number of subvectors + * @param nSubVectors the number of subvectors to divide each vector into + * @return a new NVQuantization instance configured for the given vectors and subvector count */ public static NVQuantization compute(RandomAccessVectorValues ravv, int nSubVectors) { var subvectorSizesAndOffsets = getSubvectorSizesAndOffsets(ravv.dimension(), nSubVectors); @@ -156,13 +186,23 @@ public static NVQuantization compute(RandomAccessVectorValues ravv, int nSubVect } + /** + * Creates a CompressedVectors instance from an array of compressed vectors. + * + * @param compressedVectors an array of QuantizedVector objects + * @return a new NVQVectors instance wrapping the compressed vectors + */ @Override public CompressedVectors createCompressedVectors(Object[] compressedVectors) { return new NVQVectors(this, (QuantizedVector[]) compressedVectors); } /** - * Encodes the given vectors in parallel using NVQ. + * Encodes all vectors in the given collection in parallel using NVQ. + * + * @param ravv the vectors to encode + * @param parallelExecutor the thread pool to use for parallel encoding + * @return a new NVQVectors instance containing all encoded vectors */ @Override public NVQVectors encodeAll(RandomAccessVectorValues ravv, ForkJoinPool parallelExecutor) { @@ -180,8 +220,10 @@ public NVQVectors encodeAll(RandomAccessVectorValues ravv, ForkJoinPool parallel } /** - * Encodes the input vector using NVQ. - * @return one subvector per subspace + * Encodes a single vector using NVQ by dividing it into subvectors and quantizing each independently. 
+ * + * @param vector the vector to encode + * @return a QuantizedVector containing one quantized subvector per subspace */ @Override public QuantizedVector encode(VectorFloat vector) { @@ -191,7 +233,11 @@ public QuantizedVector encode(VectorFloat vector) { } /** - * Encodes the input vector using NVQ into dest + * Encodes a single vector using NVQ and writes the result into the provided destination. + * This method avoids allocating a new QuantizedVector. + * + * @param v the vector to encode + * @param dest the destination QuantizedVector to write the encoded result into */ @Override public void encodeTo(VectorFloat v, NVQuantization.QuantizedVector dest) { @@ -200,7 +246,11 @@ public void encodeTo(VectorFloat v, NVQuantization.QuantizedVector dest) { } /** - * Creates an array of subvectors from a given vector + * Divides a vector into subvectors according to the configured subvector sizes and offsets. + * Each subvector is a slice of the original vector. + * + * @param vector the vector to divide into subvectors + * @return an array of subvectors, one for each configured subspace */ public VectorFloat[] getSubVectors(VectorFloat vector) { VectorFloat[] subvectors = new VectorFloat[subvectorSizesAndOffsets.length]; @@ -217,7 +267,13 @@ public VectorFloat[] getSubVectors(VectorFloat vector) { } /** - * Splits the vector dimension into M subvectors of roughly equal size. + * Computes the sizes and offsets for dividing a vector into M subvectors of roughly equal size. + * Any remainder dimensions are distributed among the first subvectors. 
+ * + * @param dimensions the total number of dimensions in the vector + * @param M the number of subvectors to create + * @return a matrix where each row contains [size, offset] for a subvector + * @throws IllegalArgumentException if M is greater than dimensions */ static int[][] getSubvectorSizesAndOffsets(int dimensions, int M) { if (M > dimensions) { @@ -237,10 +293,12 @@ static int[][] getSubvectorSizesAndOffsets(int dimensions, int M) { } /** - * Writes the instance to a DataOutput. - * @param out DataOutput to write to - * @param version serialization version. - * @throws IOException fails if we cannot write to the DataOutput + * Serializes this NVQuantization instance to a DataOutput. + * + * @param out the DataOutput to write to + * @param version the serialization version to use + * @throws IOException if an I/O error occurs during writing + * @throws IllegalArgumentException if the version is greater than the current supported version */ public void write(DataOutput out, int version) throws IOException { @@ -263,7 +321,9 @@ public void write(DataOutput out, int version) throws IOException } /** - * Returns the size in bytes of this class when writing it using the write method. + * Returns the size in bytes required to serialize this NVQuantization instance. + * This includes the version, global mean, bits per dimension, and subvector configuration. + * * @return the size in bytes */ @Override @@ -279,10 +339,11 @@ public int compressorSize() { } /** - * Loads an instance from a RandomAccessReader - * @param in the RandomAccessReader - * @return the instance - * @throws IOException fails if we cannot read from the RandomAccessReader + * Deserializes an NVQuantization instance from a RandomAccessReader. 
+ * + * @param in the RandomAccessReader to read from + * @return the deserialized NVQuantization instance + * @throws IOException if an I/O error occurs during reading */ public static NVQuantization load(RandomAccessReader in) throws IOException { int version = in.readInt(); @@ -327,6 +388,11 @@ public int hashCode() { return result; } + /** + * Returns the size in bytes of a single compressed vector. + * + * @return the size in bytes of a compressed vector + */ @Override public int compressedVectorSize() { int size = Integer.BYTES; // number of subvectors @@ -336,6 +402,11 @@ public int compressedVectorSize() { return size; } + /** + * Returns the approximate memory usage in bytes of this NVQuantization instance. + * + * @return the memory usage in bytes + */ @Override public long ramBytesUsed() { return globalMean.ramBytesUsed(); @@ -347,16 +418,20 @@ public String toString() { } /** - * A NuVeQ vector. + * Represents a vector that has been quantized using NVQ (Non-uniform Vector Quantization). + * A quantized vector consists of multiple quantized subvectors, one for each subspace. */ public static class QuantizedVector { + /** The array of quantized subvectors, one per subspace. */ public final QuantizedSubVector[] subVectors; /** - * Class constructor. - * @param subVectors receives the subvectors to quantize - * @param bitsPerDimension the number of bits per dimension - * @param learn whether to use optimization to find the parameters of the nonlinearity + * Quantizes an array of subvectors and writes the result into the destination QuantizedVector. 
+ * + * @param subVectors the array of subvectors to quantize + * @param bitsPerDimension the number of bits to use for each dimension + * @param learn whether to optimize the quantization parameters for each subvector + * @param dest the destination QuantizedVector to write the quantized subvectors into */ public static void quantizeTo(VectorFloat[] subVectors, BitsPerDimension bitsPerDimension, boolean learn, QuantizedVector dest) { for (int i = 0; i < subVectors.length; i++) { @@ -365,17 +440,22 @@ public static void quantizeTo(VectorFloat[] subVectors, BitsPerDimension bits } /** - * Constructs an instance from existing subvectors. Used when loading from a RandomAccessReader. - * @param subVectors the subvectors + * Constructs a QuantizedVector from an array of existing quantized subvectors. + * This constructor is typically used when deserializing from a RandomAccessReader. + * + * @param subVectors the array of quantized subvectors */ private QuantizedVector(QuantizedSubVector[] subVectors) { this.subVectors = subVectors; } /** - * Create an empty instance. Meant to be used as scratch space in conjunction with loadInto - * @param subvectorSizesAndOffsets the array containing the sizes for the subvectors + * Creates an empty QuantizedVector with uninitialized data. + * This is intended to be used as scratch space in conjunction with loadInto. 
+ * + * @param subvectorSizesAndOffsets the matrix defining subvector sizes and offsets * @param bitsPerDimension the number of bits per dimension + * @return a new empty QuantizedVector ready to be populated */ public static QuantizedVector createEmpty(int[][] subvectorSizesAndOffsets, BitsPerDimension bitsPerDimension) { var subVectors = new QuantizedSubVector[subvectorSizesAndOffsets.length]; @@ -387,9 +467,10 @@ public static QuantizedVector createEmpty(int[][] subvectorSizesAndOffsets, Bits /** - * Write the instance to a DataOutput - * @param out the DataOutput - * @throws IOException fails if we cannot write to the DataOutput + * Serializes this QuantizedVector to a DataOutput. + * + * @param out the DataOutput to write to + * @throws IOException if an I/O error occurs during writing */ public void write(DataOutput out) throws IOException { out.writeInt(subVectors.length); @@ -400,9 +481,11 @@ public void write(DataOutput out) throws IOException { } /** - * Read the instance from a RandomAccessReader - * @param in the RandomAccessReader - * @throws IOException fails if we cannot read from the RandomAccessReader + * Deserializes a QuantizedVector from a RandomAccessReader by allocating a new instance. + * + * @param in the RandomAccessReader to read from + * @return the deserialized QuantizedVector + * @throws IOException if an I/O error occurs during reading */ public static QuantizedVector load(RandomAccessReader in) throws IOException { int length = in.readInt(); @@ -415,9 +498,12 @@ public static QuantizedVector load(RandomAccessReader in) throws IOException { } /** - * Read the instance from a RandomAccessReader - * @param in the RandomAccessReader - * @throws IOException fails if we cannot read from the RandomAccessReader + * Deserializes a QuantizedVector from a RandomAccessReader into an existing instance. + * This avoids allocating a new QuantizedVector instance. 
+ * + * @param in the RandomAccessReader to read from + * @param qvector the existing QuantizedVector to populate with deserialized data + * @throws IOException if an I/O error occurs during reading */ public static void loadInto(RandomAccessReader in, QuantizedVector qvector) throws IOException { in.readInt(); @@ -437,29 +523,52 @@ public boolean equals(Object o) { } /** - * A NuVeQ sub-vector. + * Represents a single quantized subvector within an NVQ-encoded vector. + * Each subvector stores the quantized bytes along with parameters needed for dequantization. */ public static class QuantizedSubVector { - // The byte sequence that stores the quantized subvector + /** + * The byte sequence that stores the quantized subvector. + */ public ByteSequence bytes; - // The number of bits for each dimension of the input uncompressed subvector + /** + * The number of bits used for each dimension of the input uncompressed subvector. + */ public BitsPerDimension bitsPerDimension; - // The NVQ parameters + /** + * The growth rate parameter for the non-uniform quantization function. + */ public float growthRate; + + /** + * The midpoint parameter for the non-uniform quantization function. + */ public float midpoint; + + /** + * The maximum value in the original subvector. + */ public float maxValue; + + /** + * The minimum value in the original subvector. + */ public float minValue; - // The number of dimensions of the input uncompressed subvector + /** + * The number of dimensions of the input uncompressed subvector. + */ public int originalDimensions; /** - * Return the number of bytes occupied by the serialization of a QuantizedSubVector - * @param nDims the number fof dimensions of the subvector - * @param bitsPerDimension the number of bits per dimensions - * @return the size in bytes of the quantized subvector + * Computes the size in bytes required to serialize a QuantizedSubVector. 
+ * + * @param nDims the number of dimensions in the subvector + * @param bitsPerDimension the number of bits used per dimension + * @return the total size in bytes of the serialized quantized subvector + * @throws IllegalArgumentException if an unsupported bits per dimension value is provided */ public static int compressedVectorSize(int nDims, BitsPerDimension bitsPerDimension) { // Here we assume that an enum takes 4 bytes @@ -470,11 +579,14 @@ public static int compressedVectorSize(int nDims, BitsPerDimension bitsPerDimens } /** - * Quantize the vector using NVQ into dest + * Quantizes a float subvector using NVQ and writes the result into the destination. + * If learn is true, optimizes the quantization parameters to minimize reconstruction error. + * * @param vector the subvector to quantize - * @param bitsPerDimension the number of bits per dimension - * @param learn whether to use optimization to find the parameters of the nonlinearity - * @param dest the destination subvector + * @param bitsPerDimension the number of bits to use per dimension + * @param learn whether to optimize quantization parameters for this specific subvector + * @param dest the destination QuantizedSubVector to write the result into + * @throws IllegalArgumentException if an unsupported bits per dimension value is provided */ public static void quantizeTo(VectorFloat vector, BitsPerDimension bitsPerDimension, boolean learn, QuantizedSubVector dest) { var minValue = VectorUtil.min(vector); @@ -533,7 +645,16 @@ public static void quantizeTo(VectorFloat vector, BitsPerDimension bitsPerDim } /** - * Constructor used when loading from a RandomAccessReader. It takes its member fields. + * Constructs a QuantizedSubVector from its component fields. + * This constructor is typically used when deserializing from a RandomAccessReader. 
+ * + * @param bytes the quantized byte representation + * @param originalDimensions the number of dimensions in the original unquantized subvector + * @param bitsPerDimension the number of bits used per dimension + * @param minValue the minimum value in the original subvector + * @param maxValue the maximum value in the original subvector + * @param growthRate the growth rate parameter for the non-uniform quantization function + * @param midpoint the midpoint parameter for the non-uniform quantization function */ private QuantizedSubVector(ByteSequence bytes, int originalDimensions, BitsPerDimension bitsPerDimension, float minValue, float maxValue, @@ -548,9 +669,10 @@ private QuantizedSubVector(ByteSequence bytes, int originalDimensions, BitsPe } /** - * Write the instance to a DataOutput - * @param out the DataOutput - * @throws IOException fails if we cannot write to the DataOutput + * Serializes this QuantizedSubVector to a DataOutput. + * + * @param out the DataOutput to write to + * @throws IOException if an I/O error occurs during writing */ public void write(DataOutput out) throws IOException { bitsPerDimension.write(out); @@ -565,9 +687,12 @@ public void write(DataOutput out) throws IOException { } /** - * Create an empty instance. Meant to be used as scratch space in conjunction with loadInto + * Creates an empty QuantizedSubVector with uninitialized data. + * This is intended to be used as scratch space in conjunction with loadInto. 
+ * * @param bitsPerDimension the number of bits per dimension - * @param length the number of dimensions + * @param length the number of dimensions in the subvector + * @return a new empty QuantizedSubVector ready to be populated */ public static QuantizedSubVector createEmpty(BitsPerDimension bitsPerDimension, int length) { ByteSequence bytes = bitsPerDimension.createByteSequence(length); @@ -575,9 +700,11 @@ public static QuantizedSubVector createEmpty(BitsPerDimension bitsPerDimension, } /** - * Read the instance from a RandomAccessReader - * @param in the RandomAccessReader - * @throws IOException fails if we cannot read from the RandomAccessReader + * Deserializes a QuantizedSubVector from a RandomAccessReader by allocating a new instance. + * + * @param in the RandomAccessReader to read from + * @return the deserialized QuantizedSubVector + * @throws IOException if an I/O error occurs during reading */ public static QuantizedSubVector load(RandomAccessReader in) throws IOException { BitsPerDimension bitsPerDimension = BitsPerDimension.load(in); @@ -594,9 +721,12 @@ public static QuantizedSubVector load(RandomAccessReader in) throws IOException } /** - * Read the instance from a RandomAccessReader - * @param in the RandomAccessReader - * @throws IOException fails if we cannot read from the RandomAccessReader + * Deserializes a QuantizedSubVector from a RandomAccessReader into an existing instance. + * This avoids allocating a new QuantizedSubVector instance. 
+ * + * @param in the RandomAccessReader to read from + * @param quantizedSubVector the existing QuantizedSubVector to populate with deserialized data + * @throws IOException if an I/O error occurs during reading */ public static void loadInto(RandomAccessReader in, QuantizedSubVector quantizedSubVector) throws IOException { quantizedSubVector.bitsPerDimension = BitsPerDimension.load(in); @@ -626,8 +756,9 @@ public boolean equals(Object o) { } /** - * The loss used to optimize for the NVQ hyperparameters - * We use the ratio between the loss given by the uniform quantization and the NVQ loss. + * Loss function used to optimize NVQ hyperparameters (growth rate and midpoint). + * The loss is computed as the ratio between uniform quantization loss and NVQ loss, + * where higher values indicate better quantization quality. */ private static class NonuniformQuantizationLossFunction { final private BitsPerDimension bitsPerDimension; @@ -636,10 +767,22 @@ private static class NonuniformQuantizationLossFunction { private float maxValue; private float baseline; + /** + * Constructs a loss function for the given quantization resolution. + * + * @param bitsPerDimension the number of bits per dimension for quantization + */ public NonuniformQuantizationLossFunction(BitsPerDimension bitsPerDimension) { this.bitsPerDimension = bitsPerDimension; } + /** + * Sets the vector to optimize quantization parameters for and computes the baseline loss. + * + * @param vector the vector to quantize + * @param minValue the minimum value in the vector + * @param maxValue the maximum value in the vector + */ public void setVector(VectorFloat vector, float minValue, float maxValue) { this.vector = vector; this.minValue = minValue; @@ -647,10 +790,22 @@ public void setVector(VectorFloat vector, float minValue, float maxValue) { baseline = VectorUtil.nvqUniformLoss(vector, minValue, maxValue, bitsPerDimension.getInt()); } + /** + * Computes the raw NVQ loss for the given parameters. 
+ * + * @param x an array containing [growthRate, midpoint] + * @return the raw loss value (lower is better) + */ public float computeRaw(float[] x) { return VectorUtil.nvqLoss(vector, x[0], x[1], minValue, maxValue, bitsPerDimension.getInt()); } + /** + * Computes the normalized loss as a ratio of baseline to NVQ loss. + * + * @param x an array containing [growthRate, midpoint] + * @return the normalized loss value (higher is better) + */ public float compute(float[] x) { return baseline / computeRaw(x); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java index f66a2c6e4..917b8b666 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java @@ -34,17 +34,36 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.stream.IntStream; +/** + * Abstract base class for Product Quantization compressed vectors. + * Stores vectors compressed using Product Quantization (PQ) in chunks to avoid exceeding array size limits. + */ public abstract class PQVectors implements CompressedVectors { private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); + /** The ProductQuantization used for encoding and decoding. */ final ProductQuantization pq; + /** The compressed data chunks storing encoded vectors. */ protected ByteSequence[] compressedDataChunks; + /** The number of vectors stored per chunk. */ protected int vectorsPerChunk; + /** + * Constructs a PQVectors with the given ProductQuantization. + * + * @param pq the ProductQuantization to use for encoding and decoding vectors + */ protected PQVectors(ProductQuantization pq) { this.pq = pq; } + /** + * Loads PQVectors from the given reader. 
+ * + * @param in the reader to load from + * @return the loaded ImmutablePQVectors + * @throws IOException if an I/O error occurs + */ public static ImmutablePQVectors load(RandomAccessReader in) throws IOException { // pq codebooks var pq = ProductQuantization.load(in); @@ -68,6 +87,14 @@ public static ImmutablePQVectors load(RandomAccessReader in) throws IOException return new ImmutablePQVectors(pq, chunks, vectorCount, layout.fullChunkVectors); } + /** + * Loads PQVectors from the given reader at the specified offset. + * + * @param in the reader to load from + * @param offset the offset to seek to before loading + * @return the loaded PQVectors + * @throws IOException if an I/O error occurs + */ public static PQVectors load(RandomAccessReader in, long offset) throws IOException { in.seek(offset); return load(in); @@ -130,7 +157,9 @@ public void write(DataOutput out, int version) throws IOException } /** - * @return the number of chunks that have actually been allocated ({@code <= compressedDataChunks.length}) + * Returns the number of chunks that have actually been allocated. + * + * @return the number of chunks ({@code <= compressedDataChunks.length}) */ protected abstract int validChunkCount(); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java index 0e98bbf32..e86260d48 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java @@ -56,11 +56,15 @@ public class ProductQuantization implements VectorCompressor>, Accountable { private static final int MAGIC = 0x75EC4012; // JVECTOR, with some imagination + /** Logger for Product Quantization operations. 
*/ protected static final Logger LOG = Logger.getLogger(ProductQuantization.class.getName()); private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); - static final int DEFAULT_CLUSTERS = 256; // number of clusters per subspace = one byte's worth + /** Default number of clusters per subspace (256 = one byte's worth). */ + static final int DEFAULT_CLUSTERS = 256; + /** Number of K-means iterations to run during codebook creation. */ static final int K_MEANS_ITERATIONS = 6; + /** Maximum number of vectors to use for training PQ codebooks. */ public static final int MAX_PQ_TRAINING_SET_SIZE = 128000; final VectorFloat[] codebooks; // array of codebooks, where each codebook is a VectorFloat consisting of k contiguous subvectors each of length M @@ -76,21 +80,35 @@ public class ProductQuantization implements VectorCompressor>, A private final ThreadLocal> partialQuantizedSums; // for quantized sums during fused ADC private final AtomicReference> partialSquaredMagnitudes; // for cosine partials private final AtomicReference> partialQuantizedSquaredMagnitudes; // for quantized squared magnitude partials during cosine fused ADC - protected volatile float squaredMagnitudeDelta = 0; // for cosine fused ADC squared magnitude quantization delta (since this is invariant for a given PQ) - protected volatile float minSquaredMagnitude = 0; // for cosine fused ADC minimum squared magnitude (invariant for a given PQ) + /** Squared magnitude delta for cosine fused ADC quantization (invariant for a given PQ). */ + protected volatile float squaredMagnitudeDelta = 0; + /** Minimum squared magnitude for cosine fused ADC (invariant for a given PQ). */ + protected volatile float minSquaredMagnitude = 0; /** * Initializes the codebooks by clustering the input data using Product Quantization. 
* * @param ravv the vectors to quantize * @param M number of subspaces + * @param clusterCount number of clusters per subspace * @param globallyCenter whether to center the vectors globally before quantization * (not recommended when using the quantization for dot product) + * @return a new ProductQuantization instance */ public static ProductQuantization compute(RandomAccessVectorValues ravv, int M, int clusterCount, boolean globallyCenter) { return compute(ravv, M, clusterCount, globallyCenter, UNWEIGHTED, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool()); } + /** + * Initializes the codebooks by clustering the input data using Product Quantization. + * + * @param ravv the vectors to quantize + * @param M number of subspaces + * @param clusterCount number of clusters per subspace + * @param globallyCenter whether to center the vectors globally before quantization + * @param anisotropicThreshold the threshold for anisotropic angular distance shaping + * @return a new ProductQuantization instance + */ public static ProductQuantization compute(RandomAccessVectorValues ravv, int M, int clusterCount, boolean globallyCenter, float anisotropicThreshold) { return compute(ravv, M, clusterCount, globallyCenter, anisotropicThreshold, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool()); } @@ -110,6 +128,7 @@ public static ProductQuantization compute(RandomAccessVectorValues ravv, int M, * @param simdExecutor ForkJoinPool instance for SIMD operations, best is to use a pool with the size of * the number of physical cores. 
* @param parallelExecutor ForkJoinPool instance for parallel stream operations + * @return a new ProductQuantization instance */ public static ProductQuantization compute(RandomAccessVectorValues ravv, int M, @@ -140,6 +159,13 @@ public static ProductQuantization compute(RandomAccessVectorValues ravv, return new ProductQuantization(codebooks, clusterCount, subvectorSizesAndOffsets, globalCentroid, anisotropicThreshold); } + /** + * Extracts a subset of vectors to use for training the PQ codebooks. + * + * @param ravv the source vectors + * @param parallelExecutor executor for parallel extraction + * @return a list of training vectors + */ static List> extractTrainingVectors(RandomAccessVectorValues ravv, ForkJoinPool parallelExecutor) { // limit the number of vectors we train on var P = min(1.0f, MAX_PQ_TRAINING_SET_SIZE / (float) ravv.size()); @@ -156,17 +182,25 @@ static List> extractTrainingVectors(RandomAccessVectorValues ravv } /** - * Create a new PQ by fine-tuning this one with the data in `ravv` + * Creates a new PQ by fine-tuning this one with the data in the given vectors. + * Uses default parameters: 1 Lloyd's round, unweighted (isotropic) distance, and default executors. + * + * @param ravv the vectors to use for fine-tuning + * @return a new refined ProductQuantization instance */ public ProductQuantization refine(RandomAccessVectorValues ravv) { return refine(ravv, 1, UNWEIGHTED, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool()); } /** - * Create a new PQ by fine-tuning this one with the data in `ravv` + * Creates a new PQ by fine-tuning this one with the data in the given vectors. * - * @param lloydsRounds number of Lloyd's iterations to run against - * the new data. Suggested values are 1 or 2. + * @param ravv the vectors to use for fine-tuning + * @param lloydsRounds number of Lloyd's iterations to run against the new data. 
Suggested values are 1 or 2 + * @param anisotropicThreshold the threshold for anisotropic angular distance shaping + * @param simdExecutor ForkJoinPool instance for SIMD operations + * @param parallelExecutor ForkJoinPool instance for parallel stream operations + * @return a new refined ProductQuantization instance */ public ProductQuantization refine(RandomAccessVectorValues ravv, int lloydsRounds, @@ -429,6 +463,9 @@ public void encodeTo(VectorFloat vector, ByteSequence dest) { /** * Decodes the quantized representation (ByteSequence) to its approximate original vector. + * + * @param encoded the quantized vector representation + * @param target the destination vector to write the decoded result into */ public void decode(ByteSequence encoded, VectorFloat target) { decodeCentered(encoded, target); @@ -450,14 +487,18 @@ void decodeCentered(ByteSequence encoded, VectorFloat target) { } /** - * @return how many bytes we are compressing to + * Returns the number of subspaces (subvectors) the vector is divided into. + * + * @return the number of subspaces, equivalent to the compressed vector size in bytes */ public int getSubspaceCount() { return M; } /** - * @return number of clusters per subspace + * Returns the number of clusters per subspace. + * + * @return the number of clusters per subspace */ public int getClusterCount() { return clusterCount; @@ -499,6 +540,11 @@ int closestCentroidIndex(VectorFloat subvector, int m, VectorFloat codeboo /** * Extracts the m-th subvector from a single vector. 
+ * + * @param vector the full vector to extract a subvector from + * @param m the subvector index + * @param subvectorSizeAndOffset the matrix containing size and offset information for each subvector + * @return the extracted subvector */ static VectorFloat getSubVector(VectorFloat vector, int m, int[][] subvectorSizeAndOffset) { VectorFloat subvector = vectorTypeSupport.createFloatVector(subvectorSizeAndOffset[m][0]); @@ -508,6 +554,12 @@ static VectorFloat getSubVector(VectorFloat vector, int m, int[][] subvect /** * Splits the vector dimension into M subvectors of roughly equal size. + * Any remainder dimensions are distributed among the first subvectors. + * + * @param dimensions the total number of dimensions in the vector + * @param M the number of subvectors to create + * @return a matrix where each row contains [size, offset] for a subvector + * @throws IllegalArgumentException if M is greater than dimensions */ @VisibleForTesting static int[][] getSubvectorSizesAndOffsets(int dimensions, int M) { @@ -527,22 +579,47 @@ static int[][] getSubvectorSizesAndOffsets(int dimensions, int M) { return sizes; } + /** + * Returns a thread-local reusable vector for storing partial sums. + * + * @return a reusable vector for partial sums + */ VectorFloat reusablePartialSums() { return partialSums.get(); } + /** + * Returns a thread-local reusable byte sequence for storing quantized partial sums. + * + * @return a reusable byte sequence for quantized partial sums + */ ByteSequence reusablePartialQuantizedSums() { return partialQuantizedSums.get(); } + /** + * Returns a thread-local reusable vector for storing partial best distances. + * + * @return a reusable vector for partial best distances + */ VectorFloat reusablePartialBestDistances() { return partialBestDistances.get(); } + /** + * Returns the atomic reference to partial squared magnitudes for cosine similarity. 
+ * + * @return an atomic reference to the partial squared magnitudes vector + */ AtomicReference> partialSquaredMagnitudes() { return partialSquaredMagnitudes; } + /** + * Returns the atomic reference to quantized partial squared magnitudes for cosine similarity. + * + * @return an atomic reference to the quantized partial squared magnitudes + */ AtomicReference> partialQuantizedSquaredMagnitudes() { return partialQuantizedSquaredMagnitudes; } @@ -594,6 +671,7 @@ public void write(DataOutput out, int version) throws IOException * Since the dot product is commutative, we only need to store the upper triangle of the matrix. * There are M codebooks, and each codebook has k centroids, so the total number of partial sums is M * k * (k+1) / 2. * + * @param vectorSimilarityFunction the similarity function to use for computing partial sums * @return a vector to hold partial sums for a single codebook */ public VectorFloat createCodebookPartialSums(VectorSimilarityFunction vectorSimilarityFunction) { @@ -636,6 +714,13 @@ public int compressorSize() { return size; } + /** + * Deserializes a ProductQuantization instance from a RandomAccessReader. + * + * @param in the RandomAccessReader to read from + * @return the deserialized ProductQuantization instance + * @throws IOException if an I/O error occurs during reading + */ public static ProductQuantization load(RandomAccessReader in) throws IOException { int maybeMagic = in.readInt(); int version; @@ -705,7 +790,10 @@ public int hashCode() { } /** - * @return the centroid of the codebooks + * Returns the global centroid of this PQ, or computes the centroid of the codebooks if no global centroid was set. + * The global centroid is used to center vectors before quantization when {@code globallyCenter} was set to true during PQ creation. 
+ * + * @return the centroid vector */ public VectorFloat getOrComputeCentroid() { if (globalCentroid != null) { @@ -766,6 +854,11 @@ private static void checkClusterCount(int clusterCount) { } } + /** + * Returns the dimensionality of the original uncompressed vectors. + * + * @return the original vector dimension + */ public int getOriginalDimension() { return originalDimension; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/VectorCompressor.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/VectorCompressor.java index 09eb1e035..77e83cd16 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/VectorCompressor.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/VectorCompressor.java @@ -30,9 +30,16 @@ /** * Interface for vector compression. T is the encoded (compressed) vector type; * it will be an array type. + * + * @param the encoded (compressed) vector type, which will be an array type */ public interface VectorCompressor { + /** + * Encodes all vectors in the RandomAccessVectorValues using the default executor. + * @param ravv the vectors to encode + * @return the compressed vectors + */ default CompressedVectors encodeAll(RandomAccessVectorValues ravv) { return encodeAll(ravv, PhysicalCoreExecutor.pool()); } @@ -46,32 +53,64 @@ default CompressedVectors encodeAll(RandomAccessVectorValues ravv) { */ CompressedVectors encodeAll(RandomAccessVectorValues ravv, ForkJoinPool simdExecutor); + /** + * Encodes a single vector into the compressed format. + * + * @param v the vector to encode + * @return the encoded (compressed) vector + */ T encode(VectorFloat v); + /** + * Encodes a single vector into the compressed format, storing the result in the provided destination. 
+ * + * @param v the vector to encode + * @param dest the destination array to store the encoded vector + */ void encodeTo(VectorFloat v, T dest); /** + * Writes the compressor configuration to the given output stream. + * * @param out DataOutput to write to * @param version serialization version. Versions 2 and 3 are supported + * @throws IOException if an I/O error occurs during writing */ void write(DataOutput out, int version) throws IOException; - /** Write with the current serialization version */ + /** + * Writes the compressor configuration to the given output stream using the current serialization version. + * + * @param out DataOutput to write to + * @throws IOException if an I/O error occurs during writing + */ default void write(DataOutput out) throws IOException { write(out, OnDiskGraphIndex.CURRENT_VERSION); } /** + * Creates a CompressedVectors instance from an array of compressed vectors. + * * @param compressedVectors must match the type T for this VectorCompressor, but * it is declared as Object because we want callers to be able to use this * without committing to a specific type T. + * @return a CompressedVectors instance wrapping the given array + * @deprecated Use {@link #encodeAll(RandomAccessVectorValues)} instead */ @Deprecated CompressedVectors createCompressedVectors(Object[] compressedVectors); - /** the size of the serialized compressor itself (NOT the size of compressed vectors) */ + /** + * Returns the size of the serialized compressor itself (NOT the size of compressed vectors). + * + * @return the size in bytes of the compressor configuration when serialized + */ int compressorSize(); - /** the size of a compressed vector */ + /** + * Returns the size of a single compressed vector. 
+ * + * @return the size in bytes of a compressed vector + */ int compressedVectorSize(); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/package-info.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/package-info.java new file mode 100644 index 000000000..32f447398 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/package-info.java @@ -0,0 +1,60 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides vector quantization implementations for reducing memory footprint and improving search performance. + *

+ * This package contains multiple quantization techniques: + *

    + *
+ * <ul>
+ *   <li>Binary Quantization (BQ) - Compresses vectors to binary representations using
+ * {@link io.github.jbellis.jvector.quantization.BinaryQuantization}. This provides the highest
+ * compression ratio at the cost of some accuracy. Binary quantized vectors are stored in
+ * {@link io.github.jbellis.jvector.quantization.BQVectors}.</li>
+ *   <li>Product Quantization (PQ) - Divides vectors into subvectors and quantizes each independently
+ * using {@link io.github.jbellis.jvector.quantization.ProductQuantization}. This balances compression
+ * ratio and accuracy. Product quantized vectors are stored in
+ * {@link io.github.jbellis.jvector.quantization.PQVectors}.</li>
+ *   <li>Nonuniform Vector Quantization (NVQ) - Quantizes each vector with a learned nonuniform
+ * mapping whose parameters are optimized per vector, implemented in
+ * {@link io.github.jbellis.jvector.quantization.NVQuantization}.
+ * NVQ vectors are stored in {@link io.github.jbellis.jvector.quantization.NVQVectors}.</li>
+ * </ul>
+ *

+ * All quantization methods implement the {@link io.github.jbellis.jvector.quantization.VectorCompressor} + * interface, which provides methods for encoding vectors and persisting the compressed representation. + *

+ * The {@link io.github.jbellis.jvector.quantization.CompressedVectors} interface represents the
+ * compressed form of vectors and provides score functions for computing the similarity of a
+ * query vector to the stored compressed vectors.
+ *

+ * Usage Example: + *

{@code
+ * // Create a Product Quantization compressor with 16 subvectors and 256 clusters per subvector
 * ProductQuantization pq = ProductQuantization.compute(vectors, 16, 256, false);
+ *
+ * // Encode all vectors
+ * CompressedVectors compressed = pq.encodeAll(vectors);
+ *
+ * // Perform similarity scoring
 * float score = compressed.precomputedScoreFunctionFor(queryVector, similarityFunction)
 *                         .similarityTo(vectorOrdinal);
+ * }
+ * + * @see io.github.jbellis.jvector.quantization.VectorCompressor + * @see io.github.jbellis.jvector.quantization.CompressedVectors + * @see io.github.jbellis.jvector.quantization.BinaryQuantization + * @see io.github.jbellis.jvector.quantization.ProductQuantization + * @see io.github.jbellis.jvector.quantization.NVQuantization + */ +package io.github.jbellis.jvector.quantization; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/AbstractLongHeap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/AbstractLongHeap.java index 0780500ed..0999b49ef 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/AbstractLongHeap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/AbstractLongHeap.java @@ -29,7 +29,7 @@ /** * A min heap that stores longs; a primitive priority queue that like all priority queues maintains - * a partial ordering of its elements such that the leastbo element can always be found in constant + * a partial ordering of its elements such that the least element can always be found in constant * time. Push()'s and pop()'s require log(size). {@link #push(long)} may either grow the heap or * replace the worst element, depending on the subclass implementation. *

@@ -37,7 +37,15 @@ */ public abstract class AbstractLongHeap { + /** + * The array-based heap storage. This array is 1-indexed; element at index 0 is unused. + * Child nodes of node at index i are at indices 2*i and 2*i+1. + */ protected long[] heap; + + /** + * The number of elements currently stored in the heap. + */ protected int size = 0; /** @@ -60,6 +68,7 @@ public AbstractLongHeap(int initialSize) { /** * Adds a value to an LongHeap in log(size) time. * + * @param element the value to add to the heap * @return true if the new value was added. (A fixed-size heap will not add the new value * if it is full, and the new value is worse than the existing ones.) */ @@ -74,6 +83,13 @@ public AbstractLongHeap(int initialSize) { */ public abstract void pushMany(PrimitiveIterator.OfLong elements, int elementsSize); + /** + * Adds an element to the heap, growing the underlying array if necessary. + * After insertion, performs upHeap operation to maintain heap property. + * + * @param element the value to add to the heap + * @return the top (minimum) element of the heap after insertion + */ protected long add(long element) { size++; if (size == heap.length) { @@ -123,6 +139,8 @@ protected void addMany(PrimitiveIterator.OfLong elements, int elementsSize) { * Returns the least element of the LongHeap in constant time. It is up to the caller to verify * that the heap is not empty; no checking is done, and if no elements have been added, 0 is * returned. + * + * @return the minimum element in the heap, or 0 if the heap is empty */ public final long top() { return heap[1]; @@ -131,7 +149,8 @@ public final long top() { /** * Removes and returns the least element of the PriorityQueue in log(size) time. * - * @throws IllegalStateException if the LongHeap is empty. 
+ * @return the minimum element that was removed from the heap + * @throws IllegalStateException if the LongHeap is empty */ public final long pop() { if (size > 0) { @@ -145,7 +164,11 @@ public final long pop() { } } - /** Returns the number of elements currently stored in the PriorityQueue. */ + /** + * Returns the number of elements currently stored in the PriorityQueue. + * + * @return the current number of elements in the heap + */ public final int size() { return size; } @@ -155,6 +178,13 @@ public final void clear() { size = 0; } + /** + * Restores the min-heap property by moving an element up the heap. + * Starting from the given position, compares the element with its parent and swaps if necessary, + * continuing until the heap property is satisfied or the root is reached. + * + * @param origPos the position in the heap array from which to start the upheap operation + */ protected void upHeap(int origPos) { int i = origPos; long value = heap[i]; // save bottom value @@ -167,6 +197,13 @@ protected void upHeap(int origPos) { heap[i] = value; // install saved value } + /** + * Restores the min-heap property by moving an element down the heap. + * Starting from the given position, compares the element with its smaller child and swaps if necessary, + * continuing until the heap property is satisfied or a leaf is reached. + * + * @param i the position in the heap array from which to start the downheap operation + */ protected void downHeap(int i) { long value = heap[i]; // save top value int j = i << 1; // find smaller child @@ -187,20 +224,33 @@ protected void downHeap(int i) { } /** - * Return the element at the ith location in the heap array. Use for iterating over elements when + * Returns the element at the ith location in the heap array. Use for iterating over elements when * the order doesn't matter. Note that the valid arguments range from [1, size]. 
+ * + * @param i the index in the heap array (must be in range [1, size]) + * @return the element at the specified position */ public long get(int i) { return heap[i]; } + /** + * Returns the internal heap array for testing purposes. + * The array is 1-indexed with element at index 0 unused. + * + * @return the internal heap array + */ @VisibleForTesting long[] getHeapArray() { return heap; } /** - * Copies the contents and current size from `other`. Does NOT copy subclass field like BLH's maxSize + * Copies the contents and current size from another heap. + * Ensures this heap has sufficient capacity, then copies the heap array and size. + * Note: Does NOT copy subclass-specific fields such as BoundedLongHeap's maxSize. + * + * @param other the heap to copy from */ public void copyFrom(AbstractLongHeap other) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/Accountable.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/Accountable.java index ca5023d12..e5c52252c 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/Accountable.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/Accountable.java @@ -24,6 +24,17 @@ package io.github.jbellis.jvector.util; +/** + * An interface for objects that can report their memory usage. + * This allows tracking of RAM consumption for data structures and cached objects. + */ public interface Accountable { + /** + * Returns an estimate of the memory usage of this object in bytes. + * The estimate should include the object itself and any referenced objects + * that are not shared with other data structures. 
+ * + * @return the estimated memory usage in bytes + */ long ramBytesUsed(); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/ArrayUtil.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/ArrayUtil.java index 9b949b09d..1b4c8d0b7 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/ArrayUtil.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/ArrayUtil.java @@ -30,6 +30,10 @@ * Methods for manipulating arrays. */ public final class ArrayUtil { + /** + * The maximum length of an array that can be allocated. + * This accounts for the array header size to prevent integer overflow. + */ public static final int MAX_ARRAY_LENGTH = Integer.MAX_VALUE - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; @@ -48,8 +52,10 @@ private ArrayUtil() {} // no instance * specifies the radix to use when parsing the value. * * @param chars a string representation of an int quantity. + * @param offset the starting position in the char array. + * @param len the number of characters to parse. * @param radix the base to use for conversion. - * @return int the value represented by the argument + * @return the value represented by the argument * @throws NumberFormatException if the argument could not be parsed as an int quantity. */ public static int parseInt(char[] chars, int offset, int len, int radix) @@ -72,6 +78,21 @@ public static int parseInt(char[] chars, int offset, int len, int radix) return parse(chars, offset, len, radix, negative); } + /** + * Internal parsing implementation that converts characters to an integer value. + *

+ * This method performs the actual numeric conversion after initial validation has been done. + * All arithmetic is performed in negative space to handle the full range of int values including + * Integer.MIN_VALUE. + * + * @param chars the character array containing the numeric representation + * @param offset the starting position in the char array + * @param len the number of characters to parse + * @param radix the base to use for conversion (must be between Character.MIN_RADIX and Character.MAX_RADIX) + * @param negative whether the result should be negative + * @return the parsed integer value + * @throws NumberFormatException if the characters could not be parsed as an int quantity + */ private static int parse(char[] chars, int offset, int len, int radix, boolean negative) throws NumberFormatException { int max = Integer.MIN_VALUE / radix; @@ -118,6 +139,7 @@ private static int parse(char[] chars, int offset, int len, int radix, boolean n * @param minTargetSize Minimum required value to be returned. * @param bytesPerElement Bytes used by each element of the array. See constants in {@link * RamUsageEstimator}. + * @return the calculated array size that is at least minTargetSize, optimized for memory alignment. */ public static int oversize(int minTargetSize, int bytesPerElement) { @@ -202,7 +224,12 @@ public static int oversize(int minTargetSize, int bytesPerElement) { } /** - * Returns a new array whose size is exact the specified {@code newLength} without over-allocating + * Returns a new array whose size is exact the specified {@code newLength} without over-allocating. 
+ * + * @param the component type of the array + * @param array the original array to grow + * @param newLength the exact size of the new array + * @return a new array with the specified length containing the original array's elements */ public static T[] growExact(T[] array, int newLength) { Class type = array.getClass(); @@ -217,7 +244,12 @@ public static T[] growExact(T[] array, int newLength) { /** * Returns an array whose size is at least {@code minSize}, generally over-allocating - * exponentially + * exponentially. + * + * @param the component type of the array + * @param array the original array to grow + * @param minSize the minimum required size + * @return the original array if it is already large enough, otherwise a new larger array */ public static T[] grow(T[] array, int minSize) { assert minSize >= 0 : "size must be positive (got " + minSize + "): likely integer overflow?"; @@ -228,7 +260,11 @@ public static T[] grow(T[] array, int minSize) { } /** - * Returns a new array whose size is exact the specified {@code newLength} without over-allocating + * Returns a new array whose size is exact the specified {@code newLength} without over-allocating. + * + * @param array the original array to grow + * @param newLength the exact size of the new array + * @return a new array with the specified length containing the original array's elements */ public static short[] growExact(short[] array, int newLength) { short[] copy = new short[newLength]; @@ -238,7 +274,11 @@ public static short[] growExact(short[] array, int newLength) { /** * Returns an array whose size is at least {@code minSize}, generally over-allocating - * exponentially + * exponentially. 
+ * + * @param array the original array to grow + * @param minSize the minimum required size + * @return the original array if it is already large enough, otherwise a new larger array */ public static short[] grow(short[] array, int minSize) { assert minSize >= 0 : "size must be positive (got " + minSize + "): likely integer overflow?"; @@ -248,7 +288,11 @@ public static short[] grow(short[] array, int minSize) { } /** - * Returns a new array whose size is exact the specified {@code newLength} without over-allocating + * Returns a new array whose size is exact the specified {@code newLength} without over-allocating. + * + * @param array the original array to grow + * @param newLength the exact size of the new array + * @return a new array with the specified length containing the original array's elements */ public static float[] growExact(float[] array, int newLength) { float[] copy = new float[newLength]; @@ -258,7 +302,11 @@ public static float[] growExact(float[] array, int newLength) { /** * Returns an array whose size is at least {@code minSize}, generally over-allocating - * exponentially + * exponentially. + * + * @param array the original array to grow + * @param minSize the minimum required size + * @return the original array if it is already large enough, otherwise a new larger array */ public static float[] grow(float[] array, int minSize) { assert minSize >= 0 : "size must be positive (got " + minSize + "): likely integer overflow?"; @@ -270,7 +318,11 @@ public static float[] grow(float[] array, int minSize) { } /** - * Returns a new array whose size is exact the specified {@code newLength} without over-allocating + * Returns a new array whose size is exact the specified {@code newLength} without over-allocating. 
+ * + * @param array the original array to grow + * @param newLength the exact size of the new array + * @return a new array with the specified length containing the original array's elements */ public static double[] growExact(double[] array, int newLength) { double[] copy = new double[newLength]; @@ -280,7 +332,11 @@ public static double[] growExact(double[] array, int newLength) { /** * Returns an array whose size is at least {@code minSize}, generally over-allocating - * exponentially + * exponentially. + * + * @param array the original array to grow + * @param minSize the minimum required size + * @return the original array if it is already large enough, otherwise a new larger array */ public static double[] grow(double[] array, int minSize) { assert minSize >= 0 : "size must be positive (got " + minSize + "): likely integer overflow?"; @@ -290,7 +346,11 @@ public static double[] grow(double[] array, int minSize) { } /** - * Returns a new array whose size is exact the specified {@code newLength} without over-allocating + * Returns a new array whose size is exact the specified {@code newLength} without over-allocating. + * + * @param array the original array to grow + * @param newLength the exact size of the new array + * @return a new array with the specified length containing the original array's elements */ public static int[] growExact(int[] array, int newLength) { int[] copy = new int[newLength]; @@ -300,7 +360,11 @@ public static int[] growExact(int[] array, int newLength) { /** * Returns an array whose size is at least {@code minSize}, generally over-allocating - * exponentially + * exponentially. 
+ * + * @param array the original array to grow + * @param minSize the minimum required size + * @return the original array if it is already large enough, otherwise a new larger array */ public static int[] grow(int[] array, int minSize) { assert minSize >= 0 : "size must be positive (got " + minSize + "): likely integer overflow?"; @@ -309,13 +373,22 @@ public static int[] grow(int[] array, int minSize) { } else return array; } - /** Returns a larger array, generally over-allocating exponentially */ + /** + * Returns a larger array, generally over-allocating exponentially. + * + * @param array the original array to grow + * @return a new array larger than the original + */ public static int[] grow(int[] array) { return grow(array, 1 + array.length); } /** - * Returns a new array whose size is exact the specified {@code newLength} without over-allocating + * Returns a new array whose size is exact the specified {@code newLength} without over-allocating. + * + * @param array the original array to grow + * @param newLength the exact size of the new array + * @return a new array with the specified length containing the original array's elements */ public static long[] growExact(long[] array, int newLength) { long[] copy = new long[newLength]; @@ -325,7 +398,11 @@ public static long[] growExact(long[] array, int newLength) { /** * Returns an array whose size is at least {@code minSize}, generally over-allocating - * exponentially + * exponentially. 
+ * + * @param array the original array to grow + * @param minSize the minimum required size + * @return the original array if it is already large enough, otherwise a new larger array */ public static long[] grow(long[] array, int minSize) { assert minSize >= 0 : "size must be positive (got " + minSize + "): likely integer overflow?"; @@ -335,7 +412,11 @@ public static long[] grow(long[] array, int minSize) { } /** - * Returns a new array whose size is exact the specified {@code newLength} without over-allocating + * Returns a new array whose size is exact the specified {@code newLength} without over-allocating. + * + * @param array the original array to grow + * @param newLength the exact size of the new array + * @return a new array with the specified length containing the original array's elements */ public static byte[] growExact(byte[] array, int newLength) { byte[] copy = new byte[newLength]; @@ -345,7 +426,11 @@ public static byte[] growExact(byte[] array, int newLength) { /** * Returns an array whose size is at least {@code minSize}, generally over-allocating - * exponentially + * exponentially. + * + * @param array the original array to grow + * @param minSize the minimum required size + * @return the original array if it is already large enough, otherwise a new larger array */ public static byte[] grow(byte[] array, int minSize) { assert minSize >= 0 : "size must be positive (got " + minSize + "): likely integer overflow?"; @@ -355,7 +440,11 @@ public static byte[] grow(byte[] array, int minSize) { } /** - * Returns a new array whose size is exact the specified {@code newLength} without over-allocating + * Returns a new array whose size is exact the specified {@code newLength} without over-allocating. 
+ * + * @param array the original array to grow + * @param newLength the exact size of the new array + * @return a new array with the specified length containing the original array's elements */ public static char[] growExact(char[] array, int newLength) { char[] copy = new char[newLength]; @@ -365,7 +454,11 @@ public static char[] growExact(char[] array, int newLength) { /** * Returns an array whose size is at least {@code minSize}, generally over-allocating - * exponentially + * exponentially. + * + * @param array the original array to grow + * @param minSize the minimum required size + * @return the original array if it is already large enough, otherwise a new larger array */ public static char[] grow(char[] array, int minSize) { assert minSize >= 0 : "size must be positive (got " + minSize + "): likely integer overflow?"; @@ -380,6 +473,7 @@ public static char[] grow(char[] array, int minSize) { * @param array the input array * @param from the initial index of range to be copied (inclusive) * @param to the final index of range to be copied (exclusive) + * @return a new array containing the specified range from the input array */ public static byte[] copyOfSubArray(byte[] array, int from, int to) { final byte[] copy = new byte[to - from]; @@ -393,6 +487,7 @@ public static byte[] copyOfSubArray(byte[] array, int from, int to) { * @param array the input array * @param from the initial index of range to be copied (inclusive) * @param to the final index of range to be copied (exclusive) + * @return a new array containing the specified range from the input array */ public static int[] copyOfSubArray(int[] array, int from, int to) { final int[] copy = new int[to - from]; @@ -406,6 +501,7 @@ public static int[] copyOfSubArray(int[] array, int from, int to) { * @param array the input array * @param from the initial index of range to be copied (inclusive) * @param to the final index of range to be copied (exclusive) + * @return a new array containing the specified range 
from the input array */ public static float[] copyOfSubArray(float[] array, int from, int to) { final float[] copy = new float[to - from]; @@ -416,9 +512,11 @@ public static float[] copyOfSubArray(float[] array, int from, int to) { /** * Copies the specified range of the given array into a new sub array. * + * @param the component type of the array * @param array the input array * @param from the initial index of range to be copied (inclusive) * @param to the final index of range to be copied (exclusive) + * @return a new array containing the specified range from the input array */ public static T[] copyOfSubArray(T[] array, int from, int to) { final int subLength = to - from; @@ -438,6 +536,7 @@ public static T[] copyOfSubArray(T[] array, int from, int to) { * @param array the input array * @param from the initial index of range to be copied (inclusive) * @param to the final index of range to be copied (exclusive) + * @return a new array containing the specified range from the input array */ public static long[] copyOfSubArray(long[] array, int from, int to) { final long[] copy = new long[to - from]; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/AtomicFixedBitSet.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/AtomicFixedBitSet.java index e17d9d5ce..f63cc0dc7 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/AtomicFixedBitSet.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/AtomicFixedBitSet.java @@ -31,6 +31,11 @@ public class AtomicFixedBitSet extends BitSet { private final AtomicLongArray storage; + /** + * Creates an AtomicFixedBitSet with the specified number of bits. + * All bits are initially unset (false). 
+ * @param numBits the number of bits in the set + */ public AtomicFixedBitSet(int numBits) { int numLongs = (numBits + 63) >>> 6; storage = new AtomicLongArray(numLongs); @@ -186,6 +191,10 @@ public long ramBytesUsed() { return BASE_RAM_BYTES_USED + storageSize; } + /** + * Creates a copy of this AtomicFixedBitSet. + * @return a new AtomicFixedBitSet with the same bit values + */ public AtomicFixedBitSet copy() { AtomicFixedBitSet copy = new AtomicFixedBitSet(length()); for (int i = 0; i < storage.length(); i++) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/BitSet.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/BitSet.java index 524c04153..298e5a894 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/BitSet.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/BitSet.java @@ -28,6 +28,8 @@ * Base implementation for a bit set. */ public abstract class BitSet implements Bits, Accountable { + /** Creates a BitSet instance. */ + public BitSet() {} /** * Clear all the bits of the set. * @@ -38,16 +40,29 @@ public void clear() { clear(0, length()); } - /** The number of bits in the set. */ + /** + * Returns the number of bits in the set. + * @return the number of bits in the set + */ public abstract int length(); - /** Set the bit at i. */ + /** + * Sets the bit at the specified index. + * @param i the index of the bit to set + */ public abstract void set(int i); - /** Set the bit at i, returning true if it was previously set. */ + /** + * Sets the bit at the specified index and returns its previous value. + * @param i the index of the bit to set + * @return {@code true} if the bit was previously set, {@code false} otherwise + */ public abstract boolean getAndSet(int i); - /** Clear the bit at i. */ + /** + * Clears the bit at the specified index. 
+ * @param i the index of the bit to clear + */ public abstract void clear(int i); /** @@ -58,25 +73,33 @@ public void clear() { */ public abstract void clear(int startIndex, int endIndex); - /** Return the number of bits that are set. NOTE: this method is likely to run in linear time */ + /** + * Returns the number of bits that are set. + *
<p>
+ * NOTE: this method is likely to run in linear time. + * @return the number of bits that are set + */ public abstract int cardinality(); /** - * Return an approximation of the cardinality of this set. Some implementations may trade accuracy + * Returns an approximation of the cardinality of this set. Some implementations may trade accuracy * for speed if they have the ability to estimate the cardinality of the set without iterating * over all the data. The default implementation returns {@link #cardinality()}. + * @return an approximation of the number of bits that are set */ public abstract int approximateCardinality(); /** - * Returns the index of the last set bit before or on the index specified. -1 is returned if there - * are no more set bits. + * Returns the index of the last set bit before or on the index specified. + * @param index the index to start searching backwards from (inclusive) + * @return the index of the previous set bit, or -1 if there are no more set bits */ public abstract int prevSetBit(int index); /** - * Returns the index of the first set bit starting at the index specified. {@link - * DocIdSetIterator#NO_MORE_DOCS} is returned if there are no more set bits. + * Returns the index of the first set bit starting at the index specified. + * @param index the index to start searching from (inclusive) + * @return the index of the next set bit, or {@link DocIdSetIterator#NO_MORE_DOCS} if there are no more set bits */ public abstract int nextSetBit(int index); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/Bits.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/Bits.java index c48f2de0a..88301be16 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/Bits.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/Bits.java @@ -25,10 +25,17 @@ package io.github.jbellis.jvector.util; /** - * Interface for Bitset-like structures. 
+ * Interface for Bitset-like structures that provide read-only bit access. + *
<p>
+ * This interface is used for filtering operations where bits represent the presence + * or absence of elements. It provides constant instances for common cases and utility + * methods for combining Bits instances. */ public interface Bits { + /** A Bits instance where all bits are set. */ Bits ALL = new MatchAllBits(); + + /** A Bits instance where no bits are set. */ Bits NONE = new MatchNoBits(); /** @@ -41,7 +48,10 @@ public interface Bits { boolean get(int index); /** - * Returns a Bits that is true when `bits` is false, and false when `bits` is true + * Returns a Bits instance that is the inverse of the given Bits. + * The result is {@code true} when {@code bits} is {@code false}, and vice versa. + * @param bits the Bits to invert + * @return a Bits instance representing the inverse */ static Bits inverseOf(Bits bits) { return new Bits() { @@ -53,7 +63,11 @@ public boolean get(int index) { } /** - * Return a Bits that is set for a given ordinal iff both it is set in both `a` and `b`. + * Returns a Bits instance representing the intersection of two Bits instances. + * A bit is set in the result if and only if it is set in both {@code a} and {@code b}. + * @param a the first Bits instance + * @param b the second Bits instance + * @return a Bits instance representing the intersection of {@code a} and {@code b} */ static Bits intersectionOf(Bits a, Bits b) { if (a instanceof MatchAllBits) { @@ -78,16 +92,26 @@ public boolean get(int index) { }; } - /** Bits with all bits set. */ + /** + * A Bits implementation where all bits are set. + */ class MatchAllBits implements Bits { + /** Creates a MatchAllBits instance. */ + public MatchAllBits() {} + @Override public boolean get(int index) { return true; } } - /** Bits with no bits set. */ + /** + * A Bits implementation where no bits are set. + */ class MatchNoBits implements Bits { + /** Creates a MatchNoBits instance. 
*/ + public MatchNoBits() {} + @Override public boolean get(int index) { return false; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/BoundedLongHeap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/BoundedLongHeap.java index 62b6d6dd5..d00118b76 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/BoundedLongHeap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/BoundedLongHeap.java @@ -43,11 +43,21 @@ public BoundedLongHeap(int maxSize) { this(maxSize, maxSize); } + /** + * Creates an empty heap with the specified initial and maximum sizes. + * @param initialSize the initial capacity of the heap + * @param maxSize the maximum size the heap can grow to + */ public BoundedLongHeap(int initialSize, int maxSize) { super(initialSize); this.maxSize = maxSize; } + /** + * Sets the maximum size of the heap. + * @param maxSize the new maximum size + * @throws IllegalArgumentException if maxSize is smaller than the current size + */ public void setMaxSize(int maxSize) { if (size > maxSize) { throw new IllegalArgumentException("Cannot set maxSize smaller than current size"); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java index 683cfb5dc..c8c2ae438 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java @@ -30,6 +30,8 @@ *
<p>
* "Dense-ish" means that space is allocated for all keys from 0 to the highest key, but * it is valid to have gaps in the keys. The value associated with "gap" keys is null. + * + * @param the type of values stored in this map */ public class DenseIntMap implements IntMap { // locking strategy: @@ -40,6 +42,11 @@ public class DenseIntMap implements IntMap { private volatile AtomicReferenceArray objects; private final AtomicInteger size; + /** + * Constructs a new DenseIntMap with the specified initial capacity. + * + * @param initialCapacity the initial capacity of the map + */ public DenseIntMap(int initialCapacity) { objects = new AtomicReferenceArray<>(initialCapacity); size = new AtomicInteger(); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/DocIdSetIterator.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/DocIdSetIterator.java index 1cd9f153f..178d2a557 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/DocIdSetIterator.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/DocIdSetIterator.java @@ -24,6 +24,19 @@ package io.github.jbellis.jvector.util; +/** + * Utility class for document ID iteration. + * Provides constants used during iteration over document sets. + */ public class DocIdSetIterator { + /** + * Sentinel value indicating that there are no more documents to iterate over. + */ public static final int NO_MORE_DOCS = Integer.MAX_VALUE; + + /** + * Constructs a DocIdSetIterator. 
+ */ + public DocIdSetIterator() { + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExceptionUtils.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExceptionUtils.java index d5dc9c350..f15027e93 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExceptionUtils.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExceptionUtils.java @@ -18,7 +18,25 @@ import java.io.IOException; +/** + * Utility methods for exception handling. + */ public class ExceptionUtils { + /** + * Private constructor to prevent instantiation. + */ + private ExceptionUtils() { + } + + /** + * Rethrows the given throwable as an IOException or RuntimeException. + * If the throwable is already an IOException, it is thrown directly. + * If it's a RuntimeException or Error, it is also thrown directly. + * Otherwise, it is wrapped in a RuntimeException. + * + * @param t the throwable to rethrow + * @throws IOException if t is an IOException + */ public static void throwIoException(Throwable t) throws IOException { if (t instanceof RuntimeException) { throw (RuntimeException) t; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExplicitThreadLocal.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExplicitThreadLocal.java index 1833818bd..027f1c8a1 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExplicitThreadLocal.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExplicitThreadLocal.java @@ -36,8 +36,16 @@ * ExplicitThreadLocal also implements AutoCloseable to cleanup non-GC'd resources. *
<p>
* ExplicitThreadLocal is a drop-in replacement for ThreadLocal, and is used in the same way. + * + * @param the type of thread-local values stored in this instance */ public abstract class ExplicitThreadLocal implements AutoCloseable { + /** + * Constructs an ExplicitThreadLocal. + */ + protected ExplicitThreadLocal() { + } + // thread id -> instance private final ConcurrentHashMap map = new ConcurrentHashMap<>(); @@ -46,10 +54,22 @@ public abstract class ExplicitThreadLocal implements AutoCloseable { // it just once here as a field instead. private final Function initialSupplier = k -> initialValue(); + /** + * Returns the current thread's copy of this thread-local variable. + * If this is the first call by the current thread, initializes the value by calling {@link #initialValue()}. + * + * @return the current thread's value of this thread-local + */ public U get() { return map.computeIfAbsent(Thread.currentThread().getId(), initialSupplier); } + /** + * Returns the initial value for this thread-local variable. + * This method will be invoked the first time a thread accesses the variable with {@link #get()}. + * + * @return the initial value for this thread-local + */ protected abstract U initialValue(); /** @@ -67,6 +87,13 @@ public void close() throws Exception { map.clear(); } + /** + * Creates an explicit thread local variable with the given initial value supplier. 
+ * + * @param the type of the thread local's value + * @param initialValue the supplier to be used to determine the initial value + * @return a new ExplicitThreadLocal instance + */ public static ExplicitThreadLocal withInitial(Supplier initialValue) { return new ExplicitThreadLocal<>() { @Override diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/FixedBitSet.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/FixedBitSet.java index f290f7328..9836a5b33 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/FixedBitSet.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/FixedBitSet.java @@ -45,6 +45,10 @@ public final class FixedBitSet extends BitSet { *
<p>
NOTE: the returned bitset reuses the underlying {@code long[]} of the given {@code * bits} if possible. Also, calling {@link #length()} on the returned bits may return a value * greater than {@code numBits}. + * + * @param bits the existing FixedBitSet to check capacity of + * @param numBits the number of bits needed + * @return the given bits if large enough, otherwise a new FixedBitSet with sufficient capacity */ public static FixedBitSet ensureCapacity(FixedBitSet bits, int numBits) { if (numBits < bits.numBits) { @@ -61,7 +65,12 @@ public static FixedBitSet ensureCapacity(FixedBitSet bits, int numBits) { } } - /** returns the number of 64 bit words it would take to hold numBits */ + /** + * Returns the number of 64-bit words it would take to hold numBits. + * + * @param numBits the number of bits + * @return the number of long words needed to hold numBits + */ public static int bits2words(int numBits) { // I.e.: get the word-offset of the last bit and add one (make sure to use >> so 0 // returns 0!) @@ -71,6 +80,10 @@ public static int bits2words(int numBits) { /** * Returns the popcount or cardinality of the intersection of the two sets. Neither set is * modified. + * + * @param a the first bitset + * @param b the second bitset + * @return the number of bits set in both sets */ public static long intersectionCount(FixedBitSet a, FixedBitSet b) { // Depends on the ghost bits being clear! @@ -82,7 +95,13 @@ public static long intersectionCount(FixedBitSet a, FixedBitSet b) { return tot; } - /** Returns the popcount or cardinality of the union of the two sets. Neither set is modified. */ + /** + * Returns the popcount or cardinality of the union of the two sets. Neither set is modified. + * + * @param a the first bitset + * @param b the second bitset + * @return the number of bits set in either or both sets + */ public static long unionCount(FixedBitSet a, FixedBitSet b) { // Depends on the ghost bits being clear! 
long tot = 0; @@ -102,6 +121,10 @@ public static long unionCount(FixedBitSet a, FixedBitSet b) { /** * Returns the popcount or cardinality of "a and not b" or "intersection(a, not(b))". Neither set * is modified. + * + * @param a the first bitset + * @param b the second bitset + * @return the number of bits set in a but not in b */ public static long andNotCount(FixedBitSet a, FixedBitSet b) { // Depends on the ghost bits being clear! @@ -176,7 +199,11 @@ public int length() { return numBits; } - /** Expert. */ + /** + * Returns the backing long[] array for expert use. + * + * @return the backing long array + */ public long[] getBits() { return bits; } @@ -184,6 +211,8 @@ public long[] getBits() { /** * Returns number of set bits. NOTE: this visits every long in the backing bits array, and the * result is not internally cached! + * + * @return the number of bits set to true in this bitset */ @Override public int cardinality() { @@ -257,6 +286,12 @@ public void clear(int index) { bits[wordNum] &= ~bitmask; } + /** + * Gets the bit at the specified index and clears it atomically. + * + * @param index the bit index + * @return the previous value of the bit + */ public boolean getAndClear(int index) { assert index >= 0 && index < numBits : "index=" + index + ", numBits=" + numBits; int wordNum = index >> 6; // div 64 @@ -312,7 +347,11 @@ public int prevSetBit(int index) { return -1; } - /** this = this OR other */ + /** + * Performs this = this OR other. + * + * @param other the bitset to OR with this one + */ public void or(FixedBitSet other) { or(0, other.bits, other.numWords); } @@ -331,7 +370,11 @@ private void or(final int otherOffsetWords, final long[] otherArr, final int oth } } - /** this = this XOR other */ + /** + * Performs this = this XOR other. 
+ * + * @param other the bitset to XOR with this one + */ public void xor(FixedBitSet other) { xor(other.bits, other.numWords); } @@ -345,7 +388,12 @@ private void xor(long[] otherBits, int otherNumWords) { } } - /** returns true if the sets have any elements in common */ + /** + * Checks if the sets have any elements in common. + * + * @param other the bitset to check for intersection with + * @return true if the sets have any elements in common + */ public boolean intersects(FixedBitSet other) { // Depends on the ghost bits being clear! int pos = Math.min(numWords, other.numWords); @@ -355,7 +403,11 @@ public boolean intersects(FixedBitSet other) { return false; } - /** this = this AND other */ + /** + * Performs this = this AND other. + * + * @param other the bitset to AND with this one + */ public void and(FixedBitSet other) { and(other.bits, other.numWords); } @@ -371,7 +423,11 @@ private void and(final long[] otherArr, final int otherNumWords) { } } - /** this = this AND NOT other */ + /** + * Performs this = this AND NOT other. + * + * @param other the bitset to AND NOT with this one + */ public void andNot(FixedBitSet other) { andNot(0, other.bits, other.numWords); } @@ -448,7 +504,11 @@ public void flip(int startIndex, int endIndex) { bits[endWord] ^= endmask; } - /** Flip the bit at the provided index. */ + /** + * Flips the bit at the provided index. 
+ * + * @param index the bit index to flip + */ public void flip(int index) { assert index >= 0 && index < numBits : "index=" + index + " numBits=" + numBits; int wordNum = index >> 6; // div 64 diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/GrowableBitSet.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/GrowableBitSet.java index 9c42ecb69..0e52a4db4 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/GrowableBitSet.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/GrowableBitSet.java @@ -28,10 +28,20 @@ public class GrowableBitSet extends BitSet { private final java.util.BitSet bitSet; + /** + * Creates a GrowableBitSet wrapping the given BitSet. + * + * @param bitSet the BitSet to wrap + */ public GrowableBitSet(java.util.BitSet bitSet) { this.bitSet = bitSet; } + /** + * Creates a GrowableBitSet with the specified initial size. + * + * @param initialBits the initial number of bits to allocate + */ public GrowableBitSet(int initialBits) { this.bitSet = new java.util.BitSet(initialBits); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/IntMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/IntMap.java index 713e9a3ab..0caa476c6 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/IntMap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/IntMap.java @@ -20,41 +20,73 @@ import java.util.stream.IntStream; +/** + * A map with integer keys that provides atomic compare-and-put operations. + * + * @param the type of values stored in the map + */ public interface IntMap { /** + * Atomically sets the value for the given key if the current value matches the expected existing value. 
+ * * @param key ordinal - * @return true if successful, false if the current value != `existing` + * @param existing the expected current value (may be null) + * @param value the new value to set + * @return true if successful, false if the current value != {@code existing} */ boolean compareAndPut(int key, T existing, T value); /** + * Returns the number of items that have been added to this map. + * * @return number of items that have been added */ int size(); /** + * Returns the value associated with the given key. + * * @param key ordinal * @return the value of the key, or null if not set */ T get(int key); /** + * Removes the mapping for the given key from this map if present. + * + * @param key the key to remove * @return the former value of the key, or null if it was not set */ T remove(int key); /** + * Checks if this map contains a mapping for the given key. + * + * @param key the key to check * @return true iff the given key is set in the map */ boolean containsKey(int key); /** * Iterates keys in ascending order and calls the consumer for each non-null key-value pair. + * + * @param consumer the consumer to call for each key-value pair */ void forEach(IntBiConsumer consumer); + /** + * A functional interface for consuming key-value pairs where the key is an int. + * + * @param the type of the value + */ @FunctionalInterface interface IntBiConsumer { + /** + * Consumes a key-value pair. 
+ * + * @param key the integer key + * @param value the value associated with the key + */ void consume(int key, T2 value); } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/MathUtil.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/MathUtil.java index 7515d100c..1ff4d0b82 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/MathUtil.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/MathUtil.java @@ -16,8 +16,20 @@ package io.github.jbellis.jvector.util; +/** + * Utility methods for mathematical operations. + */ public class MathUtil { - // looks silly at first but it really does make code more readable + /** Private constructor to prevent instantiation. */ + private MathUtil() { + } + /** + * Squares the given float value. + * While this may look silly at first, it really does make code more readable. + * + * @param a the value to square + * @return the square of a + */ public static float square(float a) { return a * a; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/NumericUtils.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/NumericUtils.java index 43606f5b7..6eb53d7be 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/NumericUtils.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/NumericUtils.java @@ -44,6 +44,8 @@ private NumericUtils() {} // no instance! * reduced, but the value can easily used as an int. The sort order (including {@link Float#NaN}) * is defined by {@link Float#compareTo}; {@code NaN} is greater than positive infinity. * + * @param value the float value to convert + * @return the sortable int representation * @see #sortableIntToFloat */ public static int floatToSortableInt(float value) { @@ -53,13 +55,20 @@ public static int floatToSortableInt(float value) { /** * Converts a sortable int back to a float. 
* + * @param encoded the sortable int to convert + * @return the original float value * @see #floatToSortableInt */ public static float sortableIntToFloat(int encoded) { return Float.intBitsToFloat(sortableFloatBits(encoded)); } - /** Converts IEEE 754 representation of a float to sortable order (or back to the original) */ + /** + * Converts IEEE 754 representation of a float to sortable order (or back to the original). + * + * @param bits the IEEE 754 float bits to convert + * @return the converted bits in sortable order + */ public static int sortableFloatBits(int bits) { return bits ^ (bits >> 31) & 0x7fffffff; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/PhysicalCoreExecutor.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/PhysicalCoreExecutor.java index 0f757f036..59818c8e3 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/PhysicalCoreExecutor.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/PhysicalCoreExecutor.java @@ -23,39 +23,68 @@ import java.util.function.Supplier; /** - * A fork join pool which is sized to match the number of physical cores on the machine (avoiding hyper-thread count) + * A fork join pool which is sized to match the number of physical cores on the machine (avoiding hyper-thread count). *

* This is important for heavily vectorized sections of the code since it can easily saturate memory bandwidth. + *

+ * Knowing how many physical cores a machine has is left to the operator (however the default of 1/2 cores is today often correct). + * The physical core count can be configured via the {@code jvector.physical_core_count} system property. * * @see ProductQuantization * @see GraphIndexBuilder - * - * Knowing how many physical cores a machine has is left to the operator (however the default of 1/2 cores is today often correct). */ public class PhysicalCoreExecutor implements Closeable { private static final int physicalCoreCount = Integer.getInteger("jvector.physical_core_count", Math.max(1, Runtime.getRuntime().availableProcessors()/2)); + /** The shared PhysicalCoreExecutor instance. */ public static final PhysicalCoreExecutor instance = new PhysicalCoreExecutor(physicalCoreCount); + /** + * Returns the shared ForkJoinPool instance. + * + * @return the ForkJoinPool configured for physical cores + */ public static ForkJoinPool pool() { return instance.pool; } - + private final ForkJoinPool pool; + /** + * Constructs a PhysicalCoreExecutor with the specified number of cores. + * + * @param cores the number of physical cores to use + */ private PhysicalCoreExecutor(int cores) { assert cores > 0 && cores <= Runtime.getRuntime().availableProcessors() : "Invalid core count: " + cores; this.pool = new ForkJoinPool(cores); } + /** + * Executes the given runnable task and waits for completion. + * + * @param run the task to execute + */ public void execute(Runnable run) { pool.submit(run).join(); } + /** + * Submits a task that returns a result and waits for completion. + * + * @param run the task to execute + * @param the result type + * @return the result of the task + */ public T submit(Supplier run) { return pool.submit(run::get).join(); } + /** + * Returns the configured physical core count. 
+ * + * @return the number of physical cores used by this executor + */ public static int getPhysicalCoreCount() { return physicalCoreCount; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/RamUsageEstimator.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/RamUsageEstimator.java index 0bdb1763b..a3b4b4822 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/RamUsageEstimator.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/RamUsageEstimator.java @@ -187,69 +187,125 @@ private RamUsageEstimator() {} // hash tables need to be oversized to avoid collisions, assume 2x capacity (2L * NUM_BYTES_OBJECT_REF) * 2; - /** Aligns an object size to be the next multiple of {@link #NUM_BYTES_OBJECT_ALIGNMENT}. */ + /** + * Aligns an object size to be the next multiple of {@link #NUM_BYTES_OBJECT_ALIGNMENT}. + * + * @param size the size to align + * @return the aligned size + */ public static long alignObjectSize(long size) { size += (long) NUM_BYTES_OBJECT_ALIGNMENT - 1L; return size - (size % NUM_BYTES_OBJECT_ALIGNMENT); } /** - * Return the shallow size of the provided {@link Integer} object. Ignores the possibility that - * this object is part of the VM IntegerCache + * Returns the shallow size of the provided {@link Integer} object. Ignores the possibility that + * this object is part of the VM IntegerCache. + * + * @param ignored the Integer object (parameter value is not used) + * @return the shallow size in bytes of an Integer object */ public static long sizeOf(Integer ignored) { return INTEGER_SIZE; } /** - * Return the shallow size of the provided {@link Long} object. Ignores the possibility that this - * object is part of the VM LongCache + * Returns the shallow size of the provided {@link Long} object. Ignores the possibility that this + * object is part of the VM LongCache. 
+ * + * @param ignored the Long object (parameter value is not used) + * @return the shallow size in bytes of a Long object */ public static long sizeOf(Long ignored) { return LONG_SIZE; } - /** Returns the size in bytes of the byte[] object. */ + /** + * Returns the size in bytes of the byte[] object. + * + * @param arr the byte array + * @return the size in bytes including array header and data + */ public static long sizeOf(byte[] arr) { return alignObjectSize((long) NUM_BYTES_ARRAY_HEADER + arr.length); } - /** Returns the size in bytes of the boolean[] object. */ + /** + * Returns the size in bytes of the boolean[] object. + * + * @param arr the boolean array + * @return the size in bytes including array header and data + */ public static long sizeOf(boolean[] arr) { return alignObjectSize((long) NUM_BYTES_ARRAY_HEADER + arr.length); } - /** Returns the size in bytes of the char[] object. */ + /** + * Returns the size in bytes of the char[] object. + * + * @param arr the char array + * @return the size in bytes including array header and data + */ public static long sizeOf(char[] arr) { return alignObjectSize((long) NUM_BYTES_ARRAY_HEADER + (long) Character.BYTES * arr.length); } - /** Returns the size in bytes of the short[] object. */ + /** + * Returns the size in bytes of the short[] object. + * + * @param arr the short array + * @return the size in bytes including array header and data + */ public static long sizeOf(short[] arr) { return alignObjectSize((long) NUM_BYTES_ARRAY_HEADER + (long) Short.BYTES * arr.length); } - /** Returns the size in bytes of the int[] object. */ + /** + * Returns the size in bytes of the int[] object. + * + * @param arr the int array + * @return the size in bytes including array header and data + */ public static long sizeOf(int[] arr) { return alignObjectSize((long) NUM_BYTES_ARRAY_HEADER + (long) Integer.BYTES * arr.length); } - /** Returns the size in bytes of the float[] object. 
*/ + /** + * Returns the size in bytes of the float[] object. + * + * @param arr the float array + * @return the size in bytes including array header and data + */ public static long sizeOf(float[] arr) { return alignObjectSize((long) NUM_BYTES_ARRAY_HEADER + (long) Float.BYTES * arr.length); } - /** Returns the size in bytes of the long[] object. */ + /** + * Returns the size in bytes of the long[] object. + * + * @param arr the long array + * @return the size in bytes including array header and data + */ public static long sizeOf(long[] arr) { return alignObjectSize((long) NUM_BYTES_ARRAY_HEADER + (long) Long.BYTES * arr.length); } - /** Returns the size in bytes of the double[] object. */ + /** + * Returns the size in bytes of the double[] object. + * + * @param arr the double array + * @return the size in bytes including array header and data + */ public static long sizeOf(double[] arr) { return alignObjectSize((long) NUM_BYTES_ARRAY_HEADER + (long) Double.BYTES * arr.length); } - /** Returns the size in bytes of the String[] object. */ + /** + * Returns the size in bytes of the String[] object including the size of all string elements. + * + * @param arr the String array + * @return the total size in bytes including array header, references, and string contents + */ public static long sizeOf(String[] arr) { long size = shallowSizeOf(arr); for (String s : arr) { @@ -261,6 +317,12 @@ public static long sizeOf(String[] arr) { return size; } + /** + * Returns the RAM bytes used by the given Accountable object. + * + * @param a the Accountable object + * @return the number of bytes used in RAM + */ public static long sizeOf(Accountable a) { return a.ramBytesUsed(); } @@ -349,7 +411,12 @@ private static long sizeOfObject(Object o, int depth, long defSize) { return size; } - /** Returns the size in bytes of the String object. */ + /** + * Returns the size in bytes of the String object including the character array. 
+ * + * @param s the String object + * @return the size in bytes, or 0 if the string is null + */ public static long sizeOf(String s) { if (s == null) { return 0; @@ -361,8 +428,13 @@ public static long sizeOf(String s) { return alignObjectSize(size); } - /** Returns the shallow size in bytes of the Object[] object. */ - // Use this method instead of #shallowSizeOf(Object) to avoid costly reflection + /** + * Returns the shallow size in bytes of the Object[] object. + * Use this method instead of {@link #shallowSizeOf(Object)} to avoid costly reflection. + * + * @param arr the Object array + * @return the shallow size in bytes including array header and object references + */ public static long shallowSizeOf(Object[] arr) { return alignObjectSize( (long) NUM_BYTES_ARRAY_HEADER + (long) NUM_BYTES_OBJECT_REF * arr.length); @@ -374,6 +446,9 @@ public static long shallowSizeOf(Object[] arr) { * memory taken by the fields. * *

JVM object alignments are also applied. + * + * @param obj the object to measure + * @return the shallow size in bytes, or 0 if the object is null */ public static long shallowSizeOf(Object obj) { if (obj == null) return 0; @@ -390,8 +465,10 @@ public static long shallowSizeOf(Object obj) { * works with all conventional classes and primitive types, but not with arrays (the size then * depends on the number of elements and varies from object to object). * - * @see #shallowSizeOf(Object) + * @param clazz the class to measure + * @return the shallow size in bytes of an instance of the given class * @throws IllegalArgumentException if {@code clazz} is an array class. + * @see #shallowSizeOf(Object) */ public static long shallowSizeOfInstance(Class clazz) { if (clazz.isArray()) @@ -442,11 +519,13 @@ private static long shallowSizeOfArray(Object array) { } /** - * This method returns the maximum representation size of an object. sizeSoFar is the - * object's size measured so far. f is the field being probed. + * Returns the maximum representation size of an object accounting for a field being probed. + * The returned offset will be the maximum of whatever was measured so far and the field's + * offset and representation size (unaligned). * - *

The returned offset will be the maximum of whatever was measured so far and f - * field's offset and representation size (unaligned). + * @param sizeSoFar the object's size measured so far + * @param f the field being probed + * @return the updated size including the field's contribution */ public static long adjustForField(long sizeSoFar, final Field f) { final Class type = f.getType(); @@ -454,7 +533,13 @@ public static long adjustForField(long sizeSoFar, final Field f) { return sizeSoFar + fsize; } - /** Returns size in human-readable units (GB, MB, KB or bytes). */ + /** + * Returns the size in human-readable units (GB, MB, KB or bytes). + * + * @param bytes the size in bytes + * @param df the DecimalFormat to use for formatting + * @return a human-readable string representation of the size + */ public static String humanReadableUnits(long bytes, DecimalFormat df) { if (bytes / ONE_GB > 0) { return df.format((float) bytes / ONE_GB) + " GB"; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseBits.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseBits.java index 303a3ac23..a34919c6a 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseBits.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseBits.java @@ -20,23 +20,41 @@ /** * Implements the membership parts of an updatable BitSet (but not prev/next bits) + * using a sparse hash set for efficient storage when few bits are set. */ public class SparseBits implements Bits { private final IntHashSet set = new IntHashSet(); + /** + * Creates a new SparseBits instance with an empty set. + */ + public SparseBits() { + } + @Override public boolean get(int index) { return set.contains(index); } + /** + * Sets the bit at the specified index. + * @param index the index of the bit to set + */ public void set(int index) { set.add(index); } + /** + * Clears all bits in this set. 
+ */ public void clear() { set.clear(); } + /** + * Returns the number of bits set to true. + * @return the number of bits set + */ public int cardinality() { return set.size(); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseFixedBitSet.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseFixedBitSet.java index f6ee8b935..13c65801d 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseFixedBitSet.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseFixedBitSet.java @@ -68,6 +68,7 @@ private static int blockCount(int length) { /** * Create a {@link SparseFixedBitSet} that can contain bits between 0 included and * length excluded. + * @param length the number of bits this set can hold */ public SparseFixedBitSet(int length) { if (length < 1) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseIntMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseIntMap.java index a8fc555e5..4492a0d0d 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseIntMap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseIntMap.java @@ -21,9 +21,18 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.stream.IntStream; +/** + * A thread-safe sparse map from integer keys to values, backed by a ConcurrentHashMap. + * This implementation is suitable for sparse key spaces where only a small fraction + * of possible keys are actually used. + * @param the type of values stored in this map + */ public class SparseIntMap implements IntMap { private final ConcurrentHashMap map; + /** + * Creates a new empty SparseIntMap. + */ public SparseIntMap() { this.map = new ConcurrentHashMap<>(); } @@ -62,6 +71,10 @@ public boolean containsKey(int key) { return map.containsKey(key); } + /** + * Returns a stream of all keys in this map. 
+ * @return an IntStream of keys + */ public IntStream keysStream() { return map.keySet().stream().mapToInt(key -> key); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/ThreadSafeGrowableBitSet.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/ThreadSafeGrowableBitSet.java index d5314a747..f3b1bb82f 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/ThreadSafeGrowableBitSet.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/ThreadSafeGrowableBitSet.java @@ -32,10 +32,18 @@ public class ThreadSafeGrowableBitSet extends BitSet { private final java.util.BitSet bitSet; private final ReadWriteLock lock = new ReentrantReadWriteLock(); + /** + * Creates a new ThreadSafeGrowableBitSet wrapping an existing BitSet. + * @param bitSet the BitSet to wrap + */ public ThreadSafeGrowableBitSet(java.util.BitSet bitSet) { this.bitSet = bitSet; } + /** + * Creates a new ThreadSafeGrowableBitSet with the specified initial capacity. + * @param initialBits the initial number of bits + */ public ThreadSafeGrowableBitSet(int initialBits) { this.bitSet = new java.util.BitSet(initialBits); } @@ -167,6 +175,10 @@ public long ramBytesUsed() { throw new UnsupportedOperationException(); } + /** + * Creates a copy of this ThreadSafeGrowableBitSet. + * @return a new ThreadSafeGrowableBitSet with the same bits set + */ public ThreadSafeGrowableBitSet copy() { lock.readLock().lock(); try { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/package-info.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/package-info.java new file mode 100644 index 000000000..56b165b47 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/package-info.java @@ -0,0 +1,66 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides utility classes for array manipulation, bit operations, memory estimation, and + * concurrent data structures used throughout JVector. + * + *

This package contains low-level utility classes adapted from Apache Lucene and extended + * for JVector's needs. The utilities focus on: + * + *

    + *
  • Array operations: {@link io.github.jbellis.jvector.util.ArrayUtil} provides + * efficient methods for growing and copying arrays of various primitive types and objects, + * with memory-aligned size calculations for optimal performance. + *
  • Bit manipulation: {@link io.github.jbellis.jvector.util.BitUtil}, + * {@link io.github.jbellis.jvector.util.FixedBitSet}, + * {@link io.github.jbellis.jvector.util.SparseFixedBitSet}, and + * {@link io.github.jbellis.jvector.util.GrowableBitSet} offer various bit set implementations + * and bitwise operations optimized for different use cases. + *
  • Memory estimation: {@link io.github.jbellis.jvector.util.RamUsageEstimator} provides + * utilities for estimating object sizes and memory overhead. + *
  • Data structures: Specialized collections including + * {@link io.github.jbellis.jvector.util.BoundedLongHeap}, + * {@link io.github.jbellis.jvector.util.DenseIntMap}, + * {@link io.github.jbellis.jvector.util.SparseIntMap} for efficient storage and retrieval. + *
  • Threading utilities: {@link io.github.jbellis.jvector.util.PhysicalCoreExecutor} + * and {@link io.github.jbellis.jvector.util.ExplicitThreadLocal} for managing concurrent + * operations. + *
+ * + *

Usage Example: + *

{@code
+ * // Growing an array with optimal size calculation
+ * float[] vectors = new float[100];
+ * vectors = ArrayUtil.grow(vectors, 200); // Grows to >= 200, over-allocating for efficiency
+ *
+ * // Using bit sets for neighbor tracking
+ * FixedBitSet visited = new FixedBitSet(graphSize);
+ * visited.set(nodeId);
+ * if (visited.get(neighborId)) {
+ *     // neighbor already visited
+ * }
+ * }
+ * + *

Most classes in this package are final and provide only static methods. The implementations + * prioritize performance and memory efficiency, making them suitable for use in vector search + * operations where arrays and bit sets are manipulated frequently. + * + * @see io.github.jbellis.jvector.util.ArrayUtil + * @see io.github.jbellis.jvector.util.RamUsageEstimator + * @see io.github.jbellis.jvector.util.FixedBitSet + */ +package io.github.jbellis.jvector.util; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArraySliceByteSequence.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArraySliceByteSequence.java index 231325440..27e02ab4d 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArraySliceByteSequence.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArraySliceByteSequence.java @@ -24,10 +24,22 @@ * A {@link ByteSequence} implementation that represents a slice of another {@link ByteSequence}. */ public class ArraySliceByteSequence implements ByteSequence { + /** The underlying byte sequence from which this slice is taken */ private final ByteSequence data; + /** The offset within the underlying sequence where this slice begins */ private final int offset; + /** The length of this slice in bytes */ private final int length; + /** + * Creates a new byte sequence slice from an existing byte sequence. 
+ * + * @param data the underlying byte sequence to slice from + * @param offset the starting position within the underlying sequence + * @param length the number of bytes in this slice + * @throws IllegalArgumentException if offset or length are negative, or if the slice + * extends beyond the bounds of the underlying sequence + */ public ArraySliceByteSequence(ByteSequence data, int offset, int length) { if (offset < 0 || length < 0 || offset + length > data.length()) { throw new IllegalArgumentException("Invalid offset or length"); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorizationProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorizationProvider.java index ca540b916..8dc6028e6 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorizationProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorizationProvider.java @@ -32,7 +32,9 @@ final public class DefaultVectorizationProvider extends VectorizationProvider { private final VectorUtilSupport vectorUtilSupport; private final VectorTypeSupport vectorTypes; - + /** + * Constructs a DefaultVectorizationProvider with scalar implementations. + */ public DefaultVectorizationProvider() { vectorUtilSupport = new DefaultVectorUtilSupport(); vectorTypes = new ArrayVectorProvider(); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/Matrix.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/Matrix.java index a038ff6f1..f6d66386b 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/Matrix.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/Matrix.java @@ -30,10 +30,23 @@ public class Matrix { VectorFloat[] data; + /** + * Constructs a zero-initialized matrix with the specified dimensions. 
+ * + * @param m the number of rows + * @param n the number of columns + */ public Matrix(int m, int n) { this(m, n, true); } + /** + * Constructs a matrix with the specified dimensions. + * + * @param m the number of rows + * @param n the number of columns + * @param allocateZeroed if true, allocate and zero-initialize the matrix; if false, leave rows unallocated + */ public Matrix(int m, int n, boolean allocateZeroed) { data = new VectorFloat[m]; if (allocateZeroed) { @@ -43,14 +56,34 @@ public Matrix(int m, int n, boolean allocateZeroed) { } } + /** + * Returns the value at the specified position in the matrix. + * + * @param i the row index + * @param j the column index + * @return the value at position (i, j) + */ public float get(int i, int j) { return data[i].get(j); } + /** + * Sets the value at the specified position in the matrix. + * + * @param i the row index + * @param j the column index + * @param value the value to set + */ public void set(int i, int j, float value) { data[i].set(j, value); } + /** + * Checks if this matrix has the same dimensions as another matrix. + * + * @param other the matrix to compare with + * @return true if both matrices have the same number of rows and columns + */ public boolean isIsomorphicWith(Matrix other) { return data.length == other.data.length && data[0].length() == other.data[0].length(); } @@ -122,10 +155,23 @@ public Matrix invert() { return inverse; } + /** + * Adds a delta value to the element at the specified position. + * + * @param i the row index + * @param j the column index + * @param delta the value to add + */ public void addTo(int i, int j, float delta) { data[i].set(j, data[i].get(j) + delta); } + /** + * Adds another matrix to this matrix in place. 
+ * + * @param other the matrix to add + * @throws IllegalArgumentException if the matrices have different dimensions + */ public void addInPlace(Matrix other) { if (!this.isIsomorphicWith(other)) { throw new IllegalArgumentException("matrix dimensions differ for " + this + "!=" + other); @@ -136,6 +182,13 @@ public void addInPlace(Matrix other) { } } + /** + * Multiplies this matrix by a vector. + * + * @param v the vector to multiply by + * @return the resulting vector + * @throws IllegalArgumentException if the matrix or vector is empty + */ public VectorFloat multiply(VectorFloat v) { if (data.length == 0) { throw new IllegalArgumentException("Cannot multiply empty matrix"); @@ -151,6 +204,13 @@ public VectorFloat multiply(VectorFloat v) { return result; } + /** + * Computes the outer product of two vectors. + * + * @param a the first vector + * @param b the second vector + * @return a matrix representing the outer product of a and b + */ public static Matrix outerProduct(VectorFloat a, VectorFloat b) { var result = new Matrix(a.length(), b.length(), false); @@ -163,6 +223,11 @@ public static Matrix outerProduct(VectorFloat a, VectorFloat b) { return result; } + /** + * Scales all elements of this matrix by the given multiplier. + * + * @param multiplier the scaling factor + */ public void scale(float multiplier) { for (var row : data) { VectorUtil.scale(row, multiplier); @@ -186,6 +251,12 @@ public boolean equals(Object obj) { return true; } + /** + * Creates a matrix from a 2D float array. 
+ * + * @param values the 2D array to convert to a matrix + * @return a new Matrix containing the given values + */ public static Matrix from(float[][] values) { var result = new Matrix(values.length, values[0].length, false); for (int i = 0; i < values.length; i++) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java index e7e8b068f..fb1727b52 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java @@ -40,6 +40,9 @@ private VectorUtil() {} /** * Returns the vector dot product of the two vectors. * + * @param a the first vector + * @param b the second vector + * @return the dot product * @throws IllegalArgumentException if the vectors' dimensions differ. */ public static float dotProduct(VectorFloat a, VectorFloat b) { @@ -51,6 +54,16 @@ public static float dotProduct(VectorFloat a, VectorFloat b) { return r; } + /** + * Returns the vector dot product of the two vectors, or subvectors, of the given length. + * + * @param a the first vector + * @param aoffset the starting offset in the first vector + * @param b the second vector + * @param boffset the starting offset in the second vector + * @param length the number of elements to compute the dot product over + * @return the dot product + */ public static float dotProduct(VectorFloat a, int aoffset, VectorFloat b, int boffset, int length) { //This check impacts FLOPS /*if ( length > Math.min(a.length - aoffset, b.length - boffset) ) { @@ -65,6 +78,9 @@ public static float dotProduct(VectorFloat a, int aoffset, VectorFloat b, /** * Returns the cosine similarity between the two vectors. * + * @param a the first vector + * @param b the second vector + * @return the cosine similarity * @throws IllegalArgumentException if the vectors' dimensions differ. 
*/ public static float cosine(VectorFloat a, VectorFloat b) { @@ -79,6 +95,9 @@ public static float cosine(VectorFloat a, VectorFloat b) { /** * Returns the sum of squared differences of the two vectors. * + * @param a the first vector + * @param b the second vector + * @return the sum of squared differences * @throws IllegalArgumentException if the vectors' dimensions differ. */ public static float squareL2Distance(VectorFloat a, VectorFloat b) { @@ -92,6 +111,13 @@ public static float squareL2Distance(VectorFloat a, VectorFloat b) { /** * Returns the sum of squared differences of the two vectors, or subvectors, of the given length. + * + * @param a the first vector + * @param aoffset the starting offset in the first vector + * @param b the second vector + * @param boffset the starting offset in the second vector + * @param length the number of elements to compare + * @return the sum of squared differences */ public static float squareL2Distance(VectorFloat a, int aoffset, VectorFloat b, int boffset, int length) { float r = impl.squareDistance(a, aoffset, b, boffset, length); @@ -114,6 +140,13 @@ public static void l2normalize(VectorFloat v) { scale(v, (float) (1.0 / length)); } + /** + * Returns the sum of the given vectors. + * + * @param vectors the list of vectors to sum + * @return a new vector containing the sum of all input vectors + * @throws IllegalArgumentException if the input list is empty + */ public static VectorFloat sum(List> vectors) { if (vectors.isEmpty()) { throw new IllegalArgumentException("Input list cannot be empty"); @@ -122,62 +155,183 @@ public static VectorFloat sum(List> vectors) { return impl.sum(vectors); } + /** + * Returns the sum of all components in the vector. + * + * @param vector the vector to sum + * @return the sum of all elements in the vector + */ public static float sum(VectorFloat vector) { return impl.sum(vector); } + /** + * Multiplies each element of the vector by the given multiplier, modifying the vector in place. 
+ * + * @param vector the vector to scale (modified in place) + * @param multiplier the scalar value to multiply each element by + */ public static void scale(VectorFloat vector, float multiplier) { impl.scale(vector, multiplier); } + /** + * Adds v2 to v1 element-wise, modifying v1 in place. + * + * @param v1 the vector to add to (modified in place) + * @param v2 the vector to add + */ public static void addInPlace(VectorFloat v1, VectorFloat v2) { impl.addInPlace(v1, v2); } + /** + * Adds a scalar value to each element of v1, modifying v1 in place. + * + * @param v1 the vector to add to (modified in place) + * @param value the scalar value to add to each element + */ public static void addInPlace(VectorFloat v1, float value) { impl.addInPlace(v1, value); } + /** + * Subtracts v2 from v1 element-wise, modifying v1 in place. + * + * @param v1 the vector to subtract from (modified in place) + * @param v2 the vector to subtract + */ public static void subInPlace(VectorFloat v1, VectorFloat v2) { impl.subInPlace(v1, v2); } + /** + * Subtracts a scalar value from each element of the vector, modifying the vector in place. + * + * @param vector the vector to subtract from (modified in place) + * @param value the scalar value to subtract from each element + */ public static void subInPlace(VectorFloat vector, float value) { impl.subInPlace(vector, value); } + /** + * Returns a new vector containing the element-wise difference of lhs and rhs. + * + * @param lhs the left-hand side vector + * @param rhs the right-hand side vector + * @return a new vector containing lhs - rhs + */ public static VectorFloat sub(VectorFloat lhs, VectorFloat rhs) { return impl.sub(lhs, rhs); } + /** + * Returns a new vector containing the result of subtracting a scalar value from each element of lhs. 
+ * + * @param lhs the left-hand side vector + * @param value the scalar value to subtract from each element + * @return a new vector containing lhs - value + */ public static VectorFloat sub(VectorFloat lhs, float value) { return impl.sub(lhs, value); } + /** + * Returns a new vector containing the element-wise difference of two subvectors. + * + * @param a the first vector + * @param aOffset the starting offset in the first vector + * @param b the second vector + * @param bOffset the starting offset in the second vector + * @param length the number of elements to subtract + * @return a new vector containing a[aOffset:aOffset+length] - b[bOffset:bOffset+length] + */ public static VectorFloat sub(VectorFloat a, int aOffset, VectorFloat b, int bOffset, int length) { return impl.sub(a, aOffset, b, bOffset, length); } + /** + * Computes the element-wise minimum of distances1 and distances2, modifying distances1 in place. + * + * @param distances1 the first vector (modified in place to contain the minimum values) + * @param distances2 the second vector + */ public static void minInPlace(VectorFloat distances1, VectorFloat distances2) { impl.minInPlace(distances1, distances2); } + /** + * Assembles values from data using indices in dataOffsets and returns their sum. + * + * @param data the vector containing all data points + * @param dataBase the base index in the data vector + * @param dataOffsets byte sequence containing offsets from the base index + * @return the sum of the assembled values + */ public static float assembleAndSum(VectorFloat data, int dataBase, ByteSequence dataOffsets) { return impl.assembleAndSum(data, dataBase, dataOffsets); } + /** + * Assembles values from data using a subset of indices in dataOffsets and returns their sum. 
+ * + * @param data the vector containing all data points + * @param dataBase the base index in the data vector + * @param dataOffsets byte sequence containing offsets from the base index + * @param dataOffsetsOffset the starting offset in the dataOffsets sequence + * @param dataOffsetsLength the number of offsets to use + * @return the sum of the assembled values + */ public static float assembleAndSum(VectorFloat data, int dataBase, ByteSequence dataOffsets, int dataOffsetsOffset, int dataOffsetsLength) { return impl.assembleAndSum(data, dataBase, dataOffsets, dataOffsetsOffset, dataOffsetsLength); } + /** + * Computes the distance between two product-quantized vectors using precomputed partial results. + * + * @param data the vector of product quantization partial sums + * @param subspaceCount the number of PQ subspaces + * @param dataOffsets1 the ordinals specifying centroids for the first vector + * @param dataOffsetsOffset1 the starting offset in dataOffsets1 + * @param dataOffsets2 the ordinals specifying centroids for the second vector + * @param dataOffsetsOffset2 the starting offset in dataOffsets2 + * @param clusterCount the number of clusters per subspace + * @return the sum of the partial results + */ public static float assembleAndSumPQ(VectorFloat data, int subspaceCount, ByteSequence dataOffsets1, int dataOffsetsOffset1, ByteSequence dataOffsets2, int dataOffsetsOffset2, int clusterCount) { return impl.assembleAndSumPQ(data, subspaceCount, dataOffsets1, dataOffsetsOffset1, dataOffsets2, dataOffsetsOffset2, clusterCount); } + /** + * Computes similarity scores for multiple product-quantized vectors using quantized partial results. 
+ * + * @param shuffles the transposed PQ-encoded vectors + * @param codebookCount the number of codebooks used in PQ encoding + * @param quantizedPartials the quantized precomputed score fragments + * @param delta the quantization delta value + * @param minDistance the minimum distance used in quantization + * @param results the output vector to store similarity scores (modified in place) + * @param vsf the vector similarity function to use + */ public static void bulkShuffleQuantizedSimilarity(ByteSequence shuffles, int codebookCount, ByteSequence quantizedPartials, float delta, float minDistance, VectorFloat results, VectorSimilarityFunction vsf) { impl.bulkShuffleQuantizedSimilarity(shuffles, codebookCount, quantizedPartials, delta, minDistance, vsf, results); } + /** + * Computes cosine similarity scores for multiple product-quantized vectors using quantized partial results. + * + * @param shuffles the transposed PQ-encoded vectors + * @param codebookCount the number of codebooks used in PQ encoding + * @param quantizedPartialSums the quantized precomputed dot product fragments + * @param sumDelta the delta used to quantize the partial sums + * @param minDistance the minimum distance used in quantization + * @param quantizedPartialMagnitudes the quantized precomputed squared magnitudes + * @param magnitudeDelta the delta used to quantize the magnitudes + * @param minMagnitude the minimum magnitude used in quantization + * @param queryMagnitudeSquared the squared magnitude of the query vector + * @param results the output vector to store similarity scores (modified in place) + */ public static void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles, int codebookCount, ByteSequence quantizedPartialSums, float sumDelta, float minDistance, ByteSequence quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, @@ -185,18 +339,58 @@ public static void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles 
impl.bulkShuffleQuantizedSimilarityCosine(shuffles, codebookCount, quantizedPartialSums, sumDelta, minDistance, quantizedPartialMagnitudes, magnitudeDelta, minMagnitude, queryMagnitudeSquared, results); } + /** + * Computes the Hamming distance between two bit vectors represented as long arrays. + * + * @param v1 the first bit vector + * @param v2 the second bit vector + * @return the Hamming distance (number of differing bits) + */ public static int hammingDistance(long[] v1, long[] v2) { return impl.hammingDistance(v1, v2); } + /** + * Calculates partial sums for product quantization, storing results in partialSums and partialBestDistances. + * + * @param codebook the PQ codebook vectors + * @param codebookIndex the starting index in the codebook + * @param size the size of each codebook entry + * @param clusterCount the number of clusters per subspace + * @param query the query vector + * @param offset the offset in the query vector + * @param vsf the vector similarity function + * @param partialSums the output vector for partial sums (modified in place) + * @param partialBestDistances the output vector for best distances (modified in place) + */ public static void calculatePartialSums(VectorFloat codebook, int codebookIndex, int size, int clusterCount, VectorFloat query, int offset, VectorSimilarityFunction vsf, VectorFloat partialSums, VectorFloat partialBestDistances) { impl.calculatePartialSums(codebook, codebookIndex, size, clusterCount, query, offset, vsf, partialSums, partialBestDistances); } + /** + * Calculates partial sums for product quantization, storing results in partialSums. 
+ * + * @param codebook the PQ codebook vectors + * @param codebookIndex the starting index in the codebook + * @param size the size of each codebook entry + * @param clusterCount the number of clusters per subspace + * @param query the query vector + * @param offset the offset in the query vector + * @param vsf the vector similarity function + * @param partialSums the output vector for partial sums (modified in place) + */ public static void calculatePartialSums(VectorFloat codebook, int codebookIndex, int size, int clusterCount, VectorFloat query, int offset, VectorSimilarityFunction vsf, VectorFloat partialSums) { impl.calculatePartialSums(codebook, codebookIndex, size, clusterCount, query, offset, vsf, partialSums); } + /** + * Quantizes partial sum values into unsigned 16-bit integers stored as bytes. + * + * @param delta the quantization delta (divisor) + * @param partials the values to quantize + * @param partialBase the base values to subtract before quantization + * @param quantizedPartials the output byte sequence for quantized values (modified in place) + */ public static void quantizePartials(float delta, VectorFloat partials, VectorFloat partialBase, ByteSequence quantizedPartials) { impl.quantizePartials(delta, partials, partialBase, quantizedPartials); } @@ -219,38 +413,129 @@ public static float min(VectorFloat v) { return impl.min(v); } + /** + * Computes the cosine similarity between a query and a product-quantized vector. 
+ * + * @param encoded the PQ-encoded vector + * @param clusterCount the number of clusters per subspace + * @param partialSums the precomputed partial dot products with codebook centroids + * @param aMagnitude the precomputed partial magnitudes of codebook centroids + * @param bMagnitude the magnitude of the query vector + * @return the cosine similarity + */ public static float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { return impl.pqDecodedCosineSimilarity(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); } + /** + * Computes the cosine similarity between a query and a subset of a product-quantized vector. + * + * @param encoded the PQ-encoded vector + * @param encodedOffset the starting offset in the encoded vector + * @param encodedLength the number of encoded values to use + * @param clusterCount the number of clusters per subspace + * @param partialSums the precomputed partial dot products with codebook centroids + * @param aMagnitude the precomputed partial magnitudes of codebook centroids + * @param bMagnitude the magnitude of the query vector + * @return the cosine similarity + */ public static float pqDecodedCosineSimilarity(ByteSequence encoded, int encodedOffset, int encodedLength, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { return impl.pqDecodedCosineSimilarity(encoded, encodedOffset, encodedLength, clusterCount, partialSums, aMagnitude, bMagnitude); } + /** + * Computes the dot product between a vector and an 8-bit NVQ quantized vector. 
+ * + * @param vector the query vector + * @param bytes the 8-bit quantized vector + * @param growthRate the growth rate parameter of the logistic quantization function + * @param midpoint the midpoint parameter of the logistic quantization function + * @param minValue the minimum value of the quantized subvector + * @param maxValue the maximum value of the quantized subvector + * @return the dot product + */ public static float nvqDotProduct8bit(VectorFloat vector, ByteSequence bytes, float growthRate, float midpoint, float minValue, float maxValue) { return impl.nvqDotProduct8bit(vector, bytes, growthRate, midpoint, minValue, maxValue); } + /** + * Computes the squared Euclidean distance between a vector and an 8-bit NVQ quantized vector. + * + * @param vector the query vector + * @param bytes the 8-bit quantized vector + * @param growthRate the growth rate parameter of the logistic quantization function + * @param midpoint the midpoint parameter of the logistic quantization function + * @param minValue the minimum value of the quantized subvector + * @param maxValue the maximum value of the quantized subvector + * @return the squared Euclidean distance + */ public static float nvqSquareL2Distance8bit(VectorFloat vector, ByteSequence bytes, float growthRate, float midpoint, float minValue, float maxValue) { return impl.nvqSquareL2Distance8bit(vector, bytes, growthRate, midpoint, minValue, maxValue); } + /** + * Computes the cosine similarity between a vector and an 8-bit NVQ quantized vector. 
+ * + * @param vector the query vector + * @param bytes the 8-bit quantized vector + * @param growthRate the growth rate parameter of the logistic quantization function + * @param midpoint the midpoint parameter of the logistic quantization function + * @param minValue the minimum value of the quantized subvector + * @param maxValue the maximum value of the quantized subvector + * @param centroid the global mean vector used to re-center the quantized subvectors + * @return an array containing the cosine similarity components + */ public static float[] nvqCosine8bit(VectorFloat vector, ByteSequence bytes, float growthRate, float midpoint, float minValue, float maxValue, VectorFloat centroid) { return impl.nvqCosine8bit(vector, bytes, growthRate, midpoint, minValue, maxValue, centroid); } + /** + * Shuffles a query vector in place to optimize NVQ quantized vector unpacking performance. + * + * @param vector the vector to shuffle (modified in place) + */ public static void nvqShuffleQueryInPlace8bit(VectorFloat vector) { impl.nvqShuffleQueryInPlace8bit(vector); } + /** + * Quantizes a vector as an 8-bit NVQ quantized vector. + * + * @param vector the vector to quantize + * @param growthRate the growth rate parameter of the logistic quantization function + * @param midpoint the midpoint parameter of the logistic quantization function + * @param minValue the minimum value of the subvector + * @param maxValue the maximum value of the subvector + * @param destination the byte sequence to store the quantized values (modified in place) + */ public static void nvqQuantize8bit(VectorFloat vector, float growthRate, float midpoint, float minValue, float maxValue, ByteSequence destination) { impl.nvqQuantize8bit(vector, growthRate, midpoint, minValue, maxValue, destination); } + /** + * Computes the squared error (loss) of quantizing a vector with NVQ. 
+ * + * @param vector the vector to quantize + * @param growthRate the growth rate parameter of the logistic quantization function + * @param midpoint the midpoint parameter of the logistic quantization function + * @param minValue the minimum value of the subvector + * @param maxValue the maximum value of the subvector + * @param nBits the number of bits per dimension + * @return the squared error + */ public static float nvqLoss(VectorFloat vector, float growthRate, float midpoint, float minValue, float maxValue, int nBits) { return impl.nvqLoss(vector, growthRate, midpoint, minValue, maxValue, nBits); } + /** + * Computes the squared error (loss) of quantizing a vector with a uniform quantizer. + * + * @param vector the vector to quantize + * @param minValue the minimum value of the subvector + * @param maxValue the maximum value of the subvector + * @param nBits the number of bits per dimension + * @return the squared error + */ public static float nvqUniformLoss(VectorFloat vector, float minValue, float maxValue, int nBits) { return impl.nvqUniformLoss(vector, minValue, maxValue, nBits); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java index cc1f74f1b..42e0038d9 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java @@ -35,55 +35,144 @@ */ public interface VectorUtilSupport { - /** Calculates the dot product of the given float arrays. */ + /** + * Calculates the dot product of the given float arrays. 
+ * @param a the first vector + * @param b the second vector + * @return the dot product + */ float dotProduct(VectorFloat a, VectorFloat b); - /** Calculates the dot product of float arrays of differing sizes, or a subset of the data */ + /** + * Calculates the dot product of float arrays of differing sizes, or a subset of the data. + * @param a the first vector + * @param aoffset the starting offset in the first vector + * @param b the second vector + * @param boffset the starting offset in the second vector + * @param length the number of elements to compute the dot product over + * @return the dot product + */ float dotProduct(VectorFloat a, int aoffset, VectorFloat b, int boffset, int length); - /** Returns the cosine similarity between the two vectors. */ + /** + * Returns the cosine similarity between the two vectors. + * @param v1 the first vector + * @param v2 the second vector + * @return the cosine similarity + */ float cosine(VectorFloat v1, VectorFloat v2); - /** Calculates the cosine similarity of VectorFloats of differing sizes, or a subset of the data */ + /** + * Calculates the cosine similarity of VectorFloats of differing sizes, or a subset of the data. + * @param a the first vector + * @param aoffset the starting offset in the first vector + * @param b the second vector + * @param boffset the starting offset in the second vector + * @param length the number of elements to compute the cosine similarity over + * @return the cosine similarity + */ float cosine(VectorFloat a, int aoffset, VectorFloat b, int boffset, int length); - /** Returns the sum of squared differences of the two vectors. */ + /** + * Returns the sum of squared differences of the two vectors. 
+ * @param a the first vector + * @param b the second vector + * @return the sum of squared differences + */ float squareDistance(VectorFloat a, VectorFloat b); - /** Calculates the sum of squared differences of float arrays of differing sizes, or a subset of the data */ + /** + * Calculates the sum of squared differences of float arrays of differing sizes, or a subset of the data. + * @param a the first vector + * @param aoffset the starting offset in the first vector + * @param b the second vector + * @param boffset the starting offset in the second vector + * @param length the number of elements to compare + * @return the sum of squared differences + */ float squareDistance(VectorFloat a, int aoffset, VectorFloat b, int boffset, int length); - /** returns the sum of the given vectors. */ + /** + * Returns the sum of the given vectors. + * @param vectors the list of vectors to sum + * @return a new vector containing the sum + */ VectorFloat sum(List> vectors); - /** return the sum of the components of the vector */ + /** + * Returns the sum of the components of the vector. + * @param vector the vector to sum + * @return the sum of all elements + */ float sum(VectorFloat vector); - /** Multiply vector by multiplier, in place (vector will be modified) */ + /** + * Multiplies vector by multiplier, in place (vector will be modified). + * @param vector the vector to scale + * @param multiplier the scalar multiplier + */ void scale(VectorFloat vector, float multiplier); - /** Adds v2 into v1, in place (v1 will be modified) */ + /** + * Adds v2 into v1, in place (v1 will be modified). + * @param v1 the vector to add to + * @param v2 the vector to add + */ void addInPlace(VectorFloat v1, VectorFloat v2); - /** Adds value to each element of v1, in place (v1 will be modified) */ + /** + * Adds value to each element of v1, in place (v1 will be modified). 
+ * @param v1 the vector to add to + * @param value the scalar value to add + */ void addInPlace(VectorFloat v1, float value); - /** Subtracts v2 from v1, in place (v1 will be modified) */ + /** + * Subtracts v2 from v1, in place (v1 will be modified). + * @param v1 the vector to subtract from + * @param v2 the vector to subtract + */ void subInPlace(VectorFloat v1, VectorFloat v2); - /** Subtracts value from each element of v1, in place (v1 will be modified) */ + /** + * Subtracts value from each element of v1, in place (v1 will be modified). + * @param vector the vector to subtract from + * @param value the scalar value to subtract + */ void subInPlace(VectorFloat vector, float value); - /** @return a - b, element-wise */ + /** + * Computes a - b, element-wise. + * @param a the left-hand side vector + * @param b the right-hand side vector + * @return a new vector containing a - b + */ VectorFloat sub(VectorFloat a, VectorFloat b); - /** Subtracts value from each element of a */ + /** + * Subtracts value from each element of a. + * @param a the vector to subtract from + * @param value the scalar value to subtract + * @return a new vector containing a - value + */ VectorFloat sub(VectorFloat a, float value); - /** @return a - b, element-wise, starting at aOffset and bOffset respectively */ + /** + * Computes a - b, element-wise, starting at aOffset and bOffset respectively. 
+ * @param a the first vector + * @param aOffset the starting offset in the first vector + * @param b the second vector + * @param bOffset the starting offset in the second vector + * @param length the number of elements to subtract + * @return a new vector containing a[aOffset:aOffset+length] - b[bOffset:bOffset+length] + */ VectorFloat sub(VectorFloat a, int aOffset, VectorFloat b, int bOffset, int length); - /** Calculates the minimum value for every corresponding lane values in v1 and v2, in place (v1 will be modified) */ + /** + * Calculates the minimum value for every corresponding lane values in v1 and v2, in place (v1 will be modified). + * @param v1 the first vector (modified in place) + * @param v2 the second vector + */ void minInPlace(VectorFloat v1, VectorFloat v2); /** @@ -130,6 +219,12 @@ public interface VectorUtilSupport { */ float assembleAndSumPQ(VectorFloat codebookPartialSums, int subspaceCount, ByteSequence vector1Ordinals, int vector1OrdinalOffset, ByteSequence node2Ordinals, int node2OrdinalOffset, int clusterCount); + /** + * Computes the Hamming distance between two bit vectors. + * @param v1 the first bit vector + * @param v2 the second bit vector + * @return the Hamming distance (number of differing bits) + */ int hammingDistance(long[] v1, long[] v2); // default implementation used here because Panama SIMD can't express necessary SIMD operations and degrades to scalar @@ -143,6 +238,8 @@ public interface VectorUtilSupport { * @param quantizedPartials The quantized precomputed score fragments for each codebook entry. These are stored as a contiguous vector of all * the fragments for one codebook, followed by all the fragments for the next codebook, and so on. These have been * quantized by quantizePartialSums. + * @param delta The quantization delta value used to dequantize the partial results. + * @param minDistance The minimum distance used in quantization. * @param vsf The similarity function to use. 
* @param results The output vector to store the similarity scores. This should be pre-allocated to the same size as the number of shuffles. */ @@ -219,8 +316,31 @@ default void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles, int } } + /** + * Calculates partial sums for product quantization. + * @param codebook the PQ codebook vectors + * @param codebookIndex the starting index in the codebook + * @param size the size of each codebook entry + * @param clusterCount the number of clusters per subspace + * @param query the query vector + * @param offset the offset in the query vector + * @param vsf the vector similarity function + * @param partialSums the output vector for partial sums (modified in place) + */ void calculatePartialSums(VectorFloat codebook, int codebookIndex, int size, int clusterCount, VectorFloat query, int offset, VectorSimilarityFunction vsf, VectorFloat partialSums); + /** + * Calculates partial sums and minimum values for product quantization. + * @param codebook the PQ codebook vectors + * @param codebookIndex the starting index in the codebook + * @param size the size of each codebook entry + * @param clusterCount the number of clusters per subspace + * @param query the query vector + * @param offset the offset in the query vector + * @param vsf the vector similarity function + * @param partialSums the output vector for partial sums (modified in place) + * @param partialMins the output vector for minimum values (modified in place) + */ void calculatePartialSums(VectorFloat codebook, int codebookIndex, int size, int clusterCount, VectorFloat query, int offset, VectorSimilarityFunction vsf, VectorFloat partialSums, VectorFloat partialMins); /** @@ -237,14 +357,45 @@ default void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles, int */ void quantizePartials(float delta, VectorFloat partials, VectorFloat partialBases, ByteSequence quantizedPartials); + /** + * Returns the maximum value in the vector. 
+ * @param v the vector + * @return the maximum value + */ float max(VectorFloat v); + + /** + * Returns the minimum value in the vector. + * @param v the vector + * @return the minimum value + */ float min(VectorFloat v); + /** + * Computes the cosine similarity between a query and a product-quantized vector. + * @param encoded the PQ-encoded vector + * @param clusterCount the number of clusters per subspace + * @param partialSums the precomputed partial dot products with codebook centroids + * @param aMagnitude the precomputed partial magnitudes of codebook centroids + * @param bMagnitude the magnitude of the query vector + * @return the cosine similarity + */ default float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { return pqDecodedCosineSimilarity(encoded, 0, encoded.length(), clusterCount, partialSums, aMagnitude, bMagnitude); } + /** + * Computes the cosine similarity between a query and a subset of a product-quantized vector. 
+ * @param encoded the PQ-encoded vector + * @param encodedOffset the starting offset in the encoded vector + * @param encodedLength the number of encoded values to use + * @param clusterCount the number of clusters per subspace + * @param partialSums the precomputed partial dot products with codebook centroids + * @param aMagnitude the precomputed partial magnitudes of codebook centroids + * @param bMagnitude the magnitude of the query vector + * @return the cosine similarity + */ default float pqDecodedCosineSimilarity(ByteSequence encoded, int encodedOffset, int encodedLength, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { float sum = 0.0f; @@ -326,6 +477,7 @@ default float pqDecodedCosineSimilarity(ByteSequence encoded, int encodedOffs * @param minValue The minimum value of the subvector * @param maxValue The maximum value of the subvector * @param nBits the number of bits per dimension + * @return the squared error (loss) */ float nvqLoss(VectorFloat vector, float growthRate, float midpoint, float minValue, float maxValue, int nBits); @@ -335,6 +487,7 @@ default float pqDecodedCosineSimilarity(ByteSequence encoded, int encodedOffs * @param minValue The minimum value of the subvector * @param maxValue The maximum value of the subvector * @param nBits the number of bits per dimension + * @return the squared error (loss) */ float nvqUniformLoss(VectorFloat vector, float minValue, float maxValue, int nBits); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorizationProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorizationProvider.java index 1ec46443d..146e257a7 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorizationProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorizationProvider.java @@ -46,12 +46,17 @@ public abstract class VectorizationProvider { /** * Returns the default instance of the provider 
matching vectorization possibilities of actual * runtime. + * + * @return the VectorizationProvider instance for the current runtime */ public static VectorizationProvider getInstance() { return Objects.requireNonNull( Holder.INSTANCE, "call to getInstance() from subclass of VectorizationProvider"); } + /** + * Protected constructor for subclasses. + */ protected VectorizationProvider() { } @@ -59,17 +64,22 @@ protected VectorizationProvider() { /** * Returns a singleton (stateless) {@link VectorUtilSupport} to support SIMD usage in {@link * VectorUtil}. + * + * @return the VectorUtilSupport implementation */ public abstract VectorUtilSupport getVectorUtilSupport(); /** * Returns a singleton (stateless) {@link VectorTypeSupport} which works with the corresponding {@link VectorUtilSupport} - * implementation + * implementation. + * + * @return the VectorTypeSupport implementation */ public abstract VectorTypeSupport getVectorTypeSupport(); // *** Lookup mechanism: *** + /** Logger for vectorization provider */ protected static final Logger LOG = Logger.getLogger(VectorizationProvider.class.getName()); /** The minimal version of Java that has the bugfix for JDK-8301190. */ diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/package-info.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/package-info.java new file mode 100644 index 000000000..dd7dc57b7 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/package-info.java @@ -0,0 +1,68 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides vector data structures and operations for high-performance vector similarity search. + *

+ * This package contains the core abstractions and implementations for representing and manipulating + * vectors in JVector. The design supports both standard array-based implementations and pluggable + * SIMD-accelerated operations through the {@link io.github.jbellis.jvector.vector.VectorizationProvider} + * interface. + *

+ * Key Components: + *

    + *
  • Vector Representations - The {@link io.github.jbellis.jvector.vector.types.VectorFloat} + * interface (in the types subpackage) defines the contract for floating-point vectors, with + * {@link io.github.jbellis.jvector.vector.ArrayVectorFloat} providing the standard array-based + * implementation.
  • + *
  • Byte Sequences - The {@link io.github.jbellis.jvector.vector.types.ByteSequence} + * interface (in the types subpackage) represents sequences of bytes, used for compressed vectors + * and other byte-level operations. Implementations include + * {@link io.github.jbellis.jvector.vector.ArrayByteSequence} and + * {@link io.github.jbellis.jvector.vector.ArraySliceByteSequence}.
  • + *
  • Vectorization - {@link io.github.jbellis.jvector.vector.VectorizationProvider} defines + * the interface for SIMD-accelerated vector operations. The default implementation is + * {@link io.github.jbellis.jvector.vector.DefaultVectorizationProvider}, which uses standard + * Java array operations. SIMD-accelerated implementations are provided in separate modules + * using the Panama Vector API.
  • + *
  • Similarity Functions - {@link io.github.jbellis.jvector.vector.VectorSimilarityFunction} + * enumerates the supported similarity metrics (DOT_PRODUCT, COSINE, EUCLIDEAN) and provides + * methods for computing vector similarity scores.
  • + *
  • Vector Utilities - {@link io.github.jbellis.jvector.vector.VectorUtil} provides static + * utility methods for common vector operations, delegating to the appropriate + * {@link io.github.jbellis.jvector.vector.VectorUtilSupport} implementation for performance.
  • + *
  • Matrix Operations - {@link io.github.jbellis.jvector.vector.Matrix} provides matrix + * operations for vectors, used in quantization and other linear algebra operations.
  • + *
+ *

+ * Usage Example: + *

{@code
+ * // Create a vector
+ * VectorFloat vector = ArrayVectorFloat.create(new float[]{1.0f, 2.0f, 3.0f});
+ *
+ * // Compute similarity
+ * float similarity = VectorSimilarityFunction.COSINE.compare(vector1, vector2);
+ *
+ * // Use vector utilities
+ * float norm = VectorUtil.norm(vector);
+ * }
+ * + * @see io.github.jbellis.jvector.vector.types + * @see io.github.jbellis.jvector.vector.VectorizationProvider + * @see io.github.jbellis.jvector.vector.VectorSimilarityFunction + * @see io.github.jbellis.jvector.vector.VectorUtil + */ +package io.github.jbellis.jvector.vector; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/ByteSequence.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/ByteSequence.java index 1ebbe8196..77bfc8809 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/ByteSequence.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/ByteSequence.java @@ -19,39 +19,179 @@ import io.github.jbellis.jvector.util.Accountable; import java.util.Objects; +/** + * A generic interface for accessing and manipulating byte sequences backed by various storage types. + *

+ * This interface provides a uniform abstraction over different byte storage implementations, + * allowing efficient access to byte data through a common API. The storage type {@code T} + * represents the underlying backing storage (e.g., byte arrays, direct memory buffers, etc.). + *

+ * Implementations support: + *

    + *
  • Random access to individual bytes via {@link #get(int)} and {@link #set(int, byte)}
  • + *
  • Little-endian short operations via {@link #setLittleEndianShort(int, short)}
  • + *
  • Bulk operations including {@link #copyFrom(ByteSequence, int, int, int)} and {@link #zero()}
  • + *
  • Sequence slicing and copying for efficient memory management
  • + *
  • Value-based equality comparison through {@link #equalTo(Object)}
  • + *
+ *

+ * ByteSequence is designed to be used in performance-critical contexts where direct byte + * manipulation is required, such as vector operations and low-level data processing. + * + * @param the type of the backing storage + */ public interface ByteSequence extends Accountable { /** - * @return entire sequence backing storage + * Returns the entire backing storage for this byte sequence. + *

+ * The returned object represents the underlying storage implementation, + * which may be a byte array, ByteBuffer, or other storage mechanism + * depending on the concrete implementation. + * + * @return the backing storage object of type {@code T} */ T get(); + /** + * Returns the offset within the backing storage where this sequence begins. + *

+ * This offset is used in conjunction with {@link #length()} to define the + * valid range of bytes in the backing storage that belong to this sequence. + * For a sequence that starts at the beginning of its backing storage, this + * method returns 0. + * + * @return the starting offset in bytes, zero-based + */ int offset(); + /** + * Returns the number of bytes in this sequence. + *

+ * Valid indices for {@link #get(int)} and {@link #set(int, byte)} operations + * range from 0 (inclusive) to the value returned by this method (exclusive). + * + * @return the length of this sequence in bytes + */ int length(); + /** + * Returns the byte value at the specified index within this sequence. + *

+ * The index is relative to the beginning of this sequence, not the underlying + * backing storage. Valid indices range from 0 to {@link #length()} - 1. + * + * @param i the index of the byte to retrieve, zero-based + * @return the byte value at the specified index + * @throws IndexOutOfBoundsException if the index is negative or greater than or equal to {@link #length()} + */ byte get(int i); + /** + * Sets the byte value at the specified index within this sequence. + *

+ * The index is relative to the beginning of this sequence, not the underlying + * backing storage. Valid indices range from 0 to {@link #length()} - 1. + * + * @param i the index where the byte should be set, zero-based + * @param value the byte value to set + * @throws IndexOutOfBoundsException if the index is negative or greater than or equal to {@link #length()} + */ void set(int i, byte value); /** - * @param shortIndex index (as if this was a short array) inside the sequence to set the short value - * @param value short value to set + * Sets a short value in little-endian byte order at the specified index. + *

+ * This method treats the byte sequence as an array of shorts, where each short + * occupies 2 bytes. The {@code shortIndex} parameter specifies which short position + * to write to (e.g., shortIndex=0 writes to bytes 0-1, shortIndex=1 writes to bytes 2-3). + * The value is stored in little-endian format (least significant byte first). + * + * @param shortIndex the index in short positions (not bytes) where the value should be set + * @param value the short value to set in little-endian byte order + * @throws IndexOutOfBoundsException if the short position would exceed the sequence bounds */ void setLittleEndianShort(int shortIndex, short value); + /** + * Sets all bytes in this sequence to zero. + *

+ * This method efficiently clears the entire byte sequence by writing zero to each + * position from 0 to {@link #length()} - 1. + */ void zero(); + /** + * Copies bytes from another ByteSequence into this sequence. + *

+ * This method performs a bulk copy operation, transferring {@code length} bytes + * from the source sequence starting at {@code srcOffset} to this sequence starting + * at {@code destOffset}. The source and destination may use different backing storage + * types, as indicated by the wildcard parameter type. + *

+ * The source and destination regions must not overlap if both sequences share the + * same backing storage. The behavior of overlapping copies is implementation-dependent. + * + * @param src the source ByteSequence to copy from + * @param srcOffset the starting offset in the source sequence + * @param destOffset the starting offset in this sequence + * @param length the number of bytes to copy + * @throws IndexOutOfBoundsException if the copy operation would read beyond the source + * sequence bounds or write beyond this sequence bounds + * @throws NullPointerException if {@code src} is null + */ void copyFrom(ByteSequence src, int srcOffset, int destOffset, int length); + /** + * Creates an independent copy of this ByteSequence. + *

+ * The returned sequence contains the same byte values as this sequence but uses + * a separate backing storage. Modifications to the copy will not affect this + * sequence, and vice versa. The copy has the same length as the original. + * + * @return a new ByteSequence containing a copy of this sequence's data + */ ByteSequence copy(); + /** + * Creates a new ByteSequence that represents a subsequence of this sequence. + *

+ * The returned slice shares the same backing storage as this sequence but has + * different offset and length values. This allows efficient sub-sequence access + * without copying data. Modifications to the slice will affect the original + * sequence and vice versa. + *

+ * The slice's valid byte range starts at the specified {@code offset} within + * this sequence and extends for {@code length} bytes. + * + * @param offset the starting position within this sequence for the slice + * @param length the number of bytes to include in the slice + * @return a new ByteSequence view representing the specified subsequence + * @throws IndexOutOfBoundsException if {@code offset} is negative, {@code length} + * is negative, or {@code offset + length} exceeds this sequence's length + */ ByteSequence slice(int offset, int length); /** - * Two ByteSequences are equal if they have the same length and the same bytes at each position. - * @param o the other object to compare to - * @return true if the two ByteSequences are equal + * Compares this ByteSequence to another object for byte-wise equality. + *

+ * Two ByteSequences are considered equal if and only if: + *

+ * <ul>
+ *   <li>They have the same {@link #length()}</li>
+ *   <li>They contain the same byte value at each corresponding position</li>
+ * </ul>
+ *

+ * This method performs value-based comparison rather than reference equality. + * It can compare ByteSequences with different backing storage types, as it only + * examines the logical byte content. + *

+ * Note: This is a utility method for value comparison. Implementations should + * not override {@code Object.equals()} with this logic to maintain proper + * collection behavior if needed. + * + * @param o the object to compare to, which may be any type + * @return {@code true} if {@code o} is a ByteSequence with identical length + * and byte content; {@code false} otherwise */ default boolean equalTo(Object o) { if (this == o) return true; @@ -65,7 +205,22 @@ default boolean equalTo(Object o) { } /** - * @return a hash code for this ByteSequence + * Computes a hash code for this ByteSequence based on its byte content. + *

+ * The hash code is calculated by iterating through all bytes in the sequence + * and combining their values using a standard polynomial rolling hash algorithm + * (multiplying by 31 and adding each non-zero byte value). This ensures that: + *

+ * <ul>
+ *   <li>ByteSequences with identical content produce the same hash code</li>
+ *   <li>Zero bytes are optimized out to improve performance for sparse sequences</li>
+ *   <li>The hash code is consistent with {@link #equalTo(Object)}</li>
+ * </ul>
+ *

+ * Note: Like {@link #equalTo(Object)}, this is a utility method. Implementations + * should not override {@code Object.hashCode()} with this logic to maintain proper + * collection behavior if needed. + * + * @return a hash code value for this ByteSequence based on its content */ default int getHashCode() { int result = 1; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorFloat.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorFloat.java index 636fa7d4d..ba56cb537 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorFloat.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorFloat.java @@ -18,29 +18,76 @@ import io.github.jbellis.jvector.util.Accountable; +/** + * Represents a vector of float values with a generic backing storage type. + *

+ * This interface provides abstraction over different vector storage implementations, + * allowing for optimized memory layouts and access patterns. The type parameter {@code T} + * represents the underlying storage mechanism (e.g., float array, ByteBuffer, etc.). + * @param the type of the backing storage + */ public interface VectorFloat extends Accountable { /** - * @return entire vector backing storage + * Returns the entire vector backing storage. + * @return the backing storage */ T get(); + /** + * Returns the length of the vector. + * @return the number of elements in the vector + */ int length(); + /** + * Returns the offset for the element at the specified index in the backing storage. + * The default implementation returns the index itself. + * @param i the logical index + * @return the offset in the backing storage + */ default int offset(int i) { return i; } + /** + * Creates a copy of this vector. + * @return a new VectorFloat instance with the same values + */ VectorFloat copy(); + /** + * Copies elements from another vector into this vector. + * @param src the source vector to copy from + * @param srcOffset the starting offset in the source vector + * @param destOffset the starting offset in this vector + * @param length the number of elements to copy + */ void copyFrom(VectorFloat src, int srcOffset, int destOffset, int length); + /** + * Returns the float value at the specified index. + * @param i the index + * @return the float value at the index + */ float get(int i); + /** + * Sets the float value at the specified index. + * @param i the index + * @param value the value to set + */ void set(int i, float value); + /** + * Sets all elements in the vector to zero. + */ void zero(); + /** + * Computes a hash code for this vector based on its non-zero elements. 
+ * @return the hash code + */ default int getHashCode() { int result = 1; for (int i = 0; i < length(); i++) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorTypeSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorTypeSupport.java index 409389370..511983203 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorTypeSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorTypeSupport.java @@ -21,6 +21,11 @@ import java.io.DataOutput; import java.io.IOException; +/** + * Provides support for creating, reading, and writing vector types. + * Implementations of this interface handle the low-level details of vector storage + * and I/O operations for both float vectors and byte sequences. + */ public interface VectorTypeSupport { /** * Create a vector from the given data. @@ -42,7 +47,7 @@ public interface VectorTypeSupport { * @param r the reader to read the vector from. * @param size the size of the vector to read. * @return the vector. - * @throws IOException + * @throws IOException if an I/O error occurs */ VectorFloat readFloatVector(RandomAccessReader r, int size) throws IOException; @@ -52,7 +57,7 @@ public interface VectorTypeSupport { * @param size the size of the vector to read. * @param vector the vector to store the read data in. * @param offset the offset in the vector to store the read data at. - * @throws IOException + * @throws IOException if an I/O error occurs */ void readFloatVector(RandomAccessReader r, int size, VectorFloat vector, int offset) throws IOException; @@ -60,7 +65,7 @@ public interface VectorTypeSupport { * Write the given vector to the given DataOutput. * @param out the output to write the vector to. * @param vector the vector to write. 
- * @throws IOException + * @throws IOException if an I/O error occurs */ void writeFloatVector(DataOutput out, VectorFloat vector) throws IOException; @@ -79,9 +84,28 @@ public interface VectorTypeSupport { */ ByteSequence createByteSequence(int length); + /** + * Read a byte sequence from the given RandomAccessReader. + * @param r the reader to read the sequence from + * @param size the size of the sequence to read + * @return the byte sequence + * @throws IOException if an I/O error occurs + */ ByteSequence readByteSequence(RandomAccessReader r, int size) throws IOException; + /** + * Read a byte sequence from the given RandomAccessReader and store it in the given sequence. + * @param r the reader to read the sequence from + * @param sequence the sequence to store the read data in + * @throws IOException if an I/O error occurs + */ void readByteSequence(RandomAccessReader r, ByteSequence sequence) throws IOException; + /** + * Write the given byte sequence to the given DataOutput. + * @param out the output to write the sequence to + * @param sequence the sequence to write + * @throws IOException if an I/O error occurs + */ void writeByteSequence(DataOutput out, ByteSequence sequence) throws IOException; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/package-info.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/package-info.java new file mode 100644 index 000000000..9a7b928ef --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/package-info.java @@ -0,0 +1,78 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Provides type abstractions and utilities for vector and byte sequence operations. + *

+ * This package defines interfaces and utilities that enable flexible and efficient + * manipulation of vector data and byte sequences in the JVector library. The primary + * goals are to: + *

+ * <ul>
+ *   <li>Abstract over different storage implementations (arrays, buffers, memory-mapped files)</li>
+ *   <li>Provide a uniform API for vector and byte operations</li>
+ *   <li>Enable performance optimizations through pluggable storage backends</li>
+ *   <li>Support both primitive and object-based vector representations</li>
+ * </ul>
+ *

+ * <h2>Key Types</h2>
+ * <dl>
+ *   <dt>{@link io.github.jbellis.jvector.vector.types.ByteSequence}</dt>
+ *   <dd>A generic interface for accessing and manipulating byte sequences with various
+ * backing storage types. Supports random access, bulk operations, slicing, and
+ * value-based equality comparison. Used extensively for low-level vector data
+ * storage and manipulation.</dd>
+ *   <dt>{@link io.github.jbellis.jvector.vector.types.VectorFloat}</dt>
+ *   <dd>Provides abstraction over float vector representations, allowing implementations
+ * to choose between different storage strategies optimized for their use case.</dd>
+ *   <dt>{@link io.github.jbellis.jvector.vector.types.VectorTypeSupport}</dt>
+ *   <dd>Factory and utility class for creating and managing vector type implementations.
+ * Serves as the primary entry point for obtaining vector and byte sequence instances.</dd>
+ * </dl>
+ *

+ * <h2>Usage Example</h2>
+ * <pre>{@code
+ * // Creating a byte sequence
+ * ByteSequence sequence = VectorTypeSupport.createByteSequence(1024);
+ *
+ * // Setting values
+ * sequence.set(0, (byte) 42);
+ * sequence.setLittleEndianShort(1, (short) 1000);
+ *
+ * // Creating a slice view
+ * ByteSequence slice = sequence.slice(100, 200);
+ *
+ * // Copying data
+ * ByteSequence copy = sequence.copy();
+ * copy.copyFrom(sequence, 0, 512, 512);
+ * }</pre>
+ * + *

+ * <h2>Design Principles</h2>
+ * <ul>
+ *   <li><b>Abstraction:</b> Interfaces abstract over concrete storage to allow
+ * flexibility in implementation choice without affecting client code.</li>
+ *   <li><b>Zero-copy operations:</b> Methods like {@link io.github.jbellis.jvector.vector.types.ByteSequence#slice(int, int)}
+ * enable efficient sub-sequence access without data duplication.</li>
+ *   <li><b>Performance-first:</b> All APIs are designed with performance-critical
+ * use cases in mind, minimizing overhead and enabling JIT optimizations.</li>
+ *   <li><b>Type safety:</b> Generic type parameters ensure compile-time type safety
+ * while maintaining flexibility.</li>
+ * </ul>
+ *
+ * @param args the command line arguments + * @throws IOException if an I/O error occurs + */ public static void main(String[] args) throws IOException { // Check for --output argument (required for this class) String outputPath = null; diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java index 4623cbe9d..ec4655aac 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java @@ -39,6 +39,16 @@ * Tests GraphIndexes against vectors from various datasets */ public class Bench { + /** + * Constructs a Bench. + */ + public Bench() {} + + /** + * Main entry point for the benchmark. + * @param args the command line arguments + * @throws IOException if an I/O error occurs + */ public static void main(String[] args) throws IOException { System.out.println("Heap space available is " + Runtime.getRuntime().maxMemory()); diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench2D.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench2D.java index dc639f5ea..2d9b4df44 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench2D.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench2D.java @@ -33,6 +33,16 @@ * Tests GraphIndexes against vectors from a 2D dataset */ public class Bench2D { + /** + * Constructs a Bench2D. + */ + public Bench2D() {} + + /** + * Main entry point for the 2D benchmark. 
+ * @param args the command line arguments + * @throws IOException if an I/O error occurs + */ public static void main(String[] args) throws IOException { System.out.println("Heap space available is " + Runtime.getRuntime().maxMemory()); diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/BenchResult.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/BenchResult.java index 5eeeff736..65995b524 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/BenchResult.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/BenchResult.java @@ -17,12 +17,28 @@ import java.util.Map; +/** + * Benchmark result data container. + */ public class BenchResult { + /** The dataset name. */ public String dataset; + /** The benchmark parameters. */ public Map parameters; + /** The benchmark metrics. */ public Map metrics; + /** + * Constructs a BenchResult. + */ public BenchResult() {} + + /** + * Constructs a BenchResult with the specified values. + * @param dataset the dataset name + * @param parameters the benchmark parameters + * @param metrics the benchmark metrics + */ public BenchResult(String dataset, Map parameters, Map metrics) { this.dataset = dataset; this.parameters = parameters; diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/BenchYAML.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/BenchYAML.java index e81a84863..3f8390e66 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/BenchYAML.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/BenchYAML.java @@ -32,6 +32,16 @@ * Tests GraphIndexes against vectors from various datasets */ public class BenchYAML { + /** + * Constructs a BenchYAML. + */ + public BenchYAML() {} + + /** + * Main entry point for the YAML-based benchmark. 
+ * @param args the command line arguments + * @throws IOException if an I/O error occurs + */ public static void main(String[] args) throws IOException { // args is one of: // - a list of regexes, possibly needing to be split by whitespace. diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/DistancesNVQ.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/DistancesNVQ.java index f3869069c..c934872db 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/DistancesNVQ.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/DistancesNVQ.java @@ -28,8 +28,21 @@ import static java.lang.Math.abs; -// this class uses explicit typing instead of `var` for easier reading when excerpted for instructional use +/** + * Tests NVQ encodings with various datasets. + * This class uses explicit typing instead of var for easier reading when excerpted for instructional use. + */ public class DistancesNVQ { + private DistancesNVQ() { + } + + /** + * Tests NVQ encodings for the given dataset. + * @param filenameBase the base vectors file + * @param filenameQueries the query vectors file + * @param vsf the similarity function + * @throws IOException if an error occurs + */ public static void testNVQEncodings(String filenameBase, String filenameQueries, VectorSimilarityFunction vsf) throws IOException { List> vectors = SiftLoader.readFvecs(filenameBase); List> queries = SiftLoader.readFvecs(filenameQueries); @@ -111,6 +124,10 @@ public static void testNVQEncodings(String filenameBase, String filenameQueries, System.out.println("--"); } + /** + * Runs NVQ test on SIFT dataset. + * @throws IOException if an error occurs + */ public static void runSIFT() throws IOException { System.out.println("Running siftsmall"); @@ -119,6 +136,10 @@ public static void runSIFT() throws IOException { testNVQEncodings(baseVectors, queryVectors, VectorSimilarityFunction.COSINE); } + /** + * Runs NVQ test on ADA dataset. 
+ * @throws IOException if an error occurs + */ public static void runADA() throws IOException { System.out.println("Running ada_002"); @@ -127,6 +148,10 @@ public static void runADA() throws IOException { testNVQEncodings(baseVectors, queryVectors, VectorSimilarityFunction.COSINE); } + /** + * Runs NVQ test on Colbert dataset. + * @throws IOException if an error occurs + */ public static void runColbert() throws IOException { System.out.println("Running colbertv2"); @@ -135,6 +160,10 @@ public static void runColbert() throws IOException { testNVQEncodings(baseVectors, queryVectors, VectorSimilarityFunction.COSINE); } + /** + * Runs NVQ test on OpenAI 3072 dataset. + * @throws IOException if an error occurs + */ public static void runOpenai3072() throws IOException { System.out.println("Running text-embedding-3-large_3072"); @@ -143,6 +172,11 @@ public static void runOpenai3072() throws IOException { testNVQEncodings(baseVectors, queryVectors, VectorSimilarityFunction.COSINE); } + /** + * Main entry point. + * @param args command line arguments + * @throws IOException if an error occurs + */ public static void main(String[] args) throws IOException { runSIFT(); runADA(); diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java index a4d62645f..91fdf1638 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java @@ -75,10 +75,20 @@ import java.util.stream.IntStream; /** - * Tests a grid of configurations against a dataset + * Tests a grid of configurations against a dataset. + * This class provides utilities for running comprehensive benchmark sweeps across multiple + * configuration parameters including graph construction settings, compression strategies, + * and search parameters. 
*/ public class Grid { + /** + * Private constructor to prevent instantiation of this utility class. + */ + private Grid() { + throw new AssertionError("Grid is a utility class and should not be instantiated"); + } + private static final String pqCacheDir = "pq_cache"; private static final String dirPrefix = "BenchGraphDir"; @@ -327,6 +337,10 @@ private static BuilderWithSuppliers builderWithSuppliers(Set features return new BuilderWithSuppliers(builder, suppliers); } + /** + * Sets the diagnostic level for benchmarks. + * @param diagLevel the diagnostic level + */ public static void setDiagnosticLevel(int diagLevel) { diagnostic_level = diagLevel; } @@ -346,10 +360,23 @@ private static DiagnosticLevel getDiagnosticLevel() { } } + /** + * Pairs an OnDiskGraphIndexWriter builder with feature state suppliers. + * This class associates graph writer configuration with the functions that provide + * feature-specific state during graph construction. + */ private static class BuilderWithSuppliers { + /** The graph index writer builder. */ public final OnDiskGraphIndexWriter.Builder builder; + /** Map of feature IDs to their state supplier functions. */ public final Map> suppliers; + /** + * Constructs a BuilderWithSuppliers pairing a builder with its state suppliers. + * + * @param builder the graph index writer builder + * @param suppliers map of feature IDs to state supplier functions + */ public BuilderWithSuppliers(OnDiskGraphIndexWriter.Builder builder, Map> suppliers) { this.builder = builder; this.suppliers = suppliers; @@ -533,6 +560,21 @@ private static List setupBenchmarks(Map> be return benchmarks; } + /** + * Runs all configurations and collects benchmark results. 
+ * @param ds the dataset + * @param mGrid the M parameter grid + * @param efConstructionGrid the efConstruction parameter grid + * @param neighborOverflowGrid the neighbor overflow parameter grid + * @param addHierarchyGrid the add hierarchy parameter grid + * @param featureSets the feature sets to test + * @param buildCompressors the build compressor functions + * @param compressionGrid the compression parameter grid + * @param topKGrid the topK parameter grid + * @param usePruningGrid the use pruning parameter grid + * @return the list of benchmark results + * @throws IOException if an error occurs + */ public static List runAllAndCollectResults( DataSet ds, List mGrid, @@ -655,6 +697,9 @@ private static VectorCompressor getCompressor(Function queryVector, ImmutableGraphIndex.View view) { // if we're not compressing then just use the exact score function if (cv == null) { @@ -689,10 +740,18 @@ public SearchScoreProvider scoreProviderFor(VectorFloat queryVector, Immutabl return new DefaultSearchScoreProvider(asf, rr); } + /** + * Gets the graph searcher for this thread. + * @return the graph searcher + */ public GraphSearcher getSearcher() { return searchers.get(); } + /** + * Gets the dataset. + * @return the dataset + */ public DataSet getDataSet() { return ds; } diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/HelloVectorWorld.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/HelloVectorWorld.java index a09d1a0e7..70441212a 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/HelloVectorWorld.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/HelloVectorWorld.java @@ -26,6 +26,14 @@ * Tests GraphIndexes against vectors from various datasets */ public class HelloVectorWorld { + private HelloVectorWorld() { + } + + /** + * Main entry point. 
+ * @param args command line arguments + * @throws IOException if an error occurs + */ public static void main(String[] args) throws IOException { System.out.println("Heap space available is " + Runtime.getRuntime().maxMemory()); diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/IPCService.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/IPCService.java index 3c125fd2b..883a1ea4e 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/IPCService.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/IPCService.java @@ -361,6 +361,10 @@ static void help() { System.exit(1); } + /** + * Main entry point. + * @param args command line arguments + */ public static void main(String[] args) { String socketFile = System.getProperty("java.io.tmpdir") + "/jvector.sock"; if (args.length > 1) diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java index e0785e28b..13aa837d5 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java @@ -66,11 +66,28 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; -// this class uses explicit typing instead of `var` for easier reading when excerpted for instructional use +/** + * Demonstration examples showing various ways to build and search graph indexes with JVector. + * This class uses explicit typing instead of var for easier reading when excerpted for instructional use. + * Each method demonstrates a different approach from simple in-memory indexes to complex + * on-disk indexes with compression. 
+ */ public class SiftSmall { private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); - // hello world + /** + * Private constructor to prevent instantiation of this example class. + */ + private SiftSmall() { + throw new AssertionError("SiftSmall is an example class and should not be instantiated"); + } + + /** + * Demonstrates the simplest case: building an in-memory graph index and performing a search. + * + * @param baseVectors the vectors to index + * @throws IOException if an I/O error occurs + */ public static void siftInMemory(List> baseVectors) throws IOException { // infer the dimensionality from the first vector int originalDimension = baseVectors.get(0).length(); @@ -105,7 +122,12 @@ public static void siftInMemory(List> baseVectors) throws IOExcep } } - // show how to use explicit GraphSearcher objects + /** + * Shows how to use explicit GraphSearcher objects for more control over search operations. + * + * @param baseVectors the vectors to index + * @throws IOException if an I/O error occurs + */ public static void siftInMemoryWithSearcher(List> baseVectors) throws IOException { int originalDimension = baseVectors.get(0).length(); RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(baseVectors, originalDimension); @@ -126,7 +148,14 @@ public static void siftInMemoryWithSearcher(List> baseVectors) th } } - // call out to testRecall instead of doing manual searches + /** + * Demonstrates measuring search quality using recall metrics against ground truth results. 
+ * + * @param baseVectors the vectors to index + * @param queryVectors the query vectors to search for + * @param groundTruth the ground truth nearest neighbors for each query + * @throws IOException if an I/O error occurs + */ public static void siftInMemoryWithRecall(List> baseVectors, List> queryVectors, List> groundTruth) throws IOException { int originalDimension = baseVectors.get(0).length(); RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(baseVectors, originalDimension); @@ -140,7 +169,15 @@ public static void siftInMemoryWithRecall(List> baseVectors, List } } - // write and load index to and from disk + /** + * Demonstrates writing an index to disk and loading it back for searching. + * This shows the basic persistence capabilities of JVector indexes. + * + * @param baseVectors the vectors to index + * @param queryVectors the query vectors to search for + * @param groundTruth the ground truth nearest neighbors for each query + * @throws IOException if an I/O error occurs + */ public static void siftPersisted(List> baseVectors, List> queryVectors, List> groundTruth) throws IOException { int originalDimension = baseVectors.get(0).length(); RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(baseVectors, originalDimension); @@ -164,7 +201,16 @@ public static void siftPersisted(List> baseVectors, List> baseVectors, List> queryVectors, List> groundTruth) throws IOException { int originalDimension = baseVectors.get(0).length(); RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(baseVectors, originalDimension); @@ -209,6 +255,16 @@ public static void siftDiskAnn(List> baseVectors, List> baseVectors, List> queryVectors, List> groundTruth) throws IOException { int originalDimension = baseVectors.get(0).length(); RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(baseVectors, originalDimension); @@ -270,6 +326,15 @@ public static void siftDiskAnnLTM(List> baseVectors, List> baseVectors, List> queryVectors, 
List> groundTruth) throws IOException { int originalDimension = baseVectors.get(0).length(); RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(baseVectors, originalDimension); @@ -333,10 +398,13 @@ public static void siftDiskAnnLTMWithNVQ(List> baseVectors, List< } } - // - // Utilities and main() harness - // - + /** + * Generates a random unit vector with the specified dimension. + * The vector is L2-normalized to have unit length. + * + * @param dim the dimension of the vector to generate + * @return a random L2-normalized vector + */ public static VectorFloat randomVector(int dim) { Random R = ThreadLocalRandom.current(); VectorFloat vec = vts.createFloatVector(dim); @@ -378,6 +446,14 @@ private static void testRecall(ImmutableGraphIndex graph, System.out.printf("(%s) Recall: %.4f%n", graphType, recall); } + /** + * Main entry point demonstrating all the example use cases for building and searching + * graph indexes with JVector. Loads the SIFT dataset and runs through various indexing + * strategies from simple in-memory to complex on-disk with compression. 
+ * + * @param args command line arguments (not used) + * @throws IOException if an error occurs reading the dataset files + */ public static void main(String[] args) throws IOException { var siftPath = "siftsmall"; var baseVectors = SiftLoader.readFvecs(String.format("%s/siftsmall_base.fvecs", siftPath)); diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/AbstractQueryBenchmark.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/AbstractQueryBenchmark.java index f328aabe7..8fa5acb37 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/AbstractQueryBenchmark.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/AbstractQueryBenchmark.java @@ -16,4 +16,12 @@ package io.github.jbellis.jvector.example.benchmarks; -public abstract class AbstractQueryBenchmark implements QueryBenchmark {} +/** + * Abstract base class for query benchmarks. + */ +public abstract class AbstractQueryBenchmark implements QueryBenchmark { + /** + * Constructs an AbstractQueryBenchmark. + */ + protected AbstractQueryBenchmark() {} +} diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/AccuracyBenchmark.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/AccuracyBenchmark.java index a99aca6f8..101c3d3ae 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/AccuracyBenchmark.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/AccuracyBenchmark.java @@ -36,10 +36,18 @@ public class AccuracyBenchmark extends AbstractQueryBenchmark { private String formatRecall; private String formatMAP; + /** + * Creates an AccuracyBenchmark with default settings. 
+ * @return the AccuracyBenchmark + */ public static AccuracyBenchmark createDefault() { return new AccuracyBenchmark(true, false, DEFAULT_FORMAT, DEFAULT_FORMAT); } + /** + * Creates an empty AccuracyBenchmark with no metrics enabled. + * @return the AccuracyBenchmark + */ public static AccuracyBenchmark createEmpty() { return new AccuracyBenchmark(false, false, DEFAULT_FORMAT, DEFAULT_FORMAT); } @@ -51,20 +59,38 @@ private AccuracyBenchmark(boolean computeRecall, boolean computeMAP, String form this.formatMAP = formatMAP; } + /** + * Enables recall display with default format. + * @return this AccuracyBenchmark + */ public AccuracyBenchmark displayRecall() { return displayRecall(DEFAULT_FORMAT); } + /** + * Enables recall display with the specified format. + * @param format the format string + * @return this AccuracyBenchmark + */ public AccuracyBenchmark displayRecall(String format) { this.computeRecall = true; this.formatRecall = format; return this; } + /** + * Enables MAP display with default format. + * @return this AccuracyBenchmark + */ public AccuracyBenchmark displayMAP() { return displayMAP(DEFAULT_FORMAT); } + /** + * Enables MAP display with the specified format. + * @param format the format string + * @return this AccuracyBenchmark + */ public AccuracyBenchmark displayMAP(String format) { this.computeMAP = true; this.formatMAP = format; diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/BenchmarkTablePrinter.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/BenchmarkTablePrinter.java index cba0f839f..1c65d89da 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/BenchmarkTablePrinter.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/BenchmarkTablePrinter.java @@ -33,6 +33,9 @@ public class BenchmarkTablePrinter { private String headerFmt; private String rowFmt; + /** + * Constructs a BenchmarkTablePrinter. 
+ */ public BenchmarkTablePrinter() { headerFmt = null; rowFmt = null; diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/CountBenchmark.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/CountBenchmark.java index d4fe68456..01058cc3f 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/CountBenchmark.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/CountBenchmark.java @@ -37,10 +37,18 @@ public class CountBenchmark extends AbstractQueryBenchmark { private String formatAvgNodesExpanded; private String formatAvgNodesExpandedBaseLayer; + /** + * Creates a CountBenchmark with default settings. + * @return the CountBenchmark + */ public static CountBenchmark createDefault() { return new CountBenchmark(true, false, false, DEFAULT_FORMAT, DEFAULT_FORMAT, DEFAULT_FORMAT); } + /** + * Creates an empty CountBenchmark with no metrics enabled. + * @return the CountBenchmark + */ public static CountBenchmark createEmpty() { return new CountBenchmark(false, false, false, DEFAULT_FORMAT, DEFAULT_FORMAT, DEFAULT_FORMAT); } @@ -55,30 +63,57 @@ private CountBenchmark(boolean computeAvgNodesVisited, boolean computeAvgNodesEx this.formatAvgNodesExpandedBaseLayer = formatAvgNodesExpandedBaseLayer; } + /** + * Enables display of average nodes visited metric. + * @return this CountBenchmark + */ public CountBenchmark displayAvgNodesVisited() { return displayAvgNodesVisited(DEFAULT_FORMAT); } + /** + * Enables display of average nodes visited metric with custom format. + * @param format the format string + * @return this CountBenchmark + */ public CountBenchmark displayAvgNodesVisited(String format) { this.computeAvgNodesVisited = true; this.formatAvgNodesVisited = format; return this; } + /** + * Enables display of average nodes expanded metric. 
+ * @return this CountBenchmark + */ public CountBenchmark displayAvgNodesExpanded() { return displayAvgNodesExpanded(DEFAULT_FORMAT); } + /** + * Enables display of average nodes expanded metric with custom format. + * @param format the format string + * @return this CountBenchmark + */ public CountBenchmark displayAvgNodesExpanded(String format) { this.computeAvgNodesExpanded = true; this.formatAvgNodesExpanded = format; return this; } + /** + * Enables display of average nodes expanded base layer metric. + * @return this CountBenchmark + */ public CountBenchmark displayAvgNodesExpandedBaseLayer() { return displayAvgNodesExpandedBaseLayer(DEFAULT_FORMAT); } + /** + * Enables display of average nodes expanded base layer metric with custom format. + * @param format the format string + * @return this CountBenchmark + */ public CountBenchmark displayAvgNodesExpandedBaseLayer(String format) { this.computeAvgNodesExpandedBaseLayer = true; this.formatAvgNodesExpandedBaseLayer = format; diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/ExecutionTimeBenchmark.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/ExecutionTimeBenchmark.java index 449a8409f..560caeb09 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/ExecutionTimeBenchmark.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/ExecutionTimeBenchmark.java @@ -30,6 +30,10 @@ public class ExecutionTimeBenchmark extends AbstractQueryBenchmark { private static volatile long SINK; private String format; + /** + * Creates an ExecutionTimeBenchmark with default format. + * @return the ExecutionTimeBenchmark + */ public static ExecutionTimeBenchmark createDefault() { return new ExecutionTimeBenchmark(DEFAULT_FORMAT); } @@ -38,6 +42,11 @@ private ExecutionTimeBenchmark(String format) { this.format = format; } + /** + * Sets the output format. 
+ * @param format the format string + * @return this ExecutionTimeBenchmark + */ public ExecutionTimeBenchmark setFormat(String format) { this.format = format; return this; diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/LatencyBenchmark.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/LatencyBenchmark.java index 861a8d2be..6a051f453 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/LatencyBenchmark.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/LatencyBenchmark.java @@ -39,10 +39,18 @@ public class LatencyBenchmark extends AbstractQueryBenchmark { private static volatile long SINK; + /** + * Creates a LatencyBenchmark with default settings. + * @return the LatencyBenchmark + */ public static LatencyBenchmark createDefault() { return new LatencyBenchmark(true, false, false, DEFAULT_FORMAT, DEFAULT_FORMAT, DEFAULT_FORMAT); } + /** + * Creates an empty LatencyBenchmark with no metrics enabled. + * @return the LatencyBenchmark + */ public static LatencyBenchmark createEmpty() { return new LatencyBenchmark(false, false, false, DEFAULT_FORMAT, DEFAULT_FORMAT, DEFAULT_FORMAT); } @@ -57,30 +65,57 @@ private LatencyBenchmark(boolean computeAvgLatency, boolean computeLatencySTD, b this.formatP999Latency = formatP999Latency; } + /** + * Enables display of average latency metric. + * @return this LatencyBenchmark + */ public LatencyBenchmark displayAvgLatency() { return displayAvgLatency(DEFAULT_FORMAT); } + /** + * Enables display of average latency metric with custom format. + * @param format the format string + * @return this LatencyBenchmark + */ public LatencyBenchmark displayAvgLatency(String format) { this.computeAvgLatency = true; this.formatAvgLatency = format; return this; } + /** + * Enables display of latency standard deviation metric. 
+ * @return this LatencyBenchmark + */ public LatencyBenchmark displayLatencySTD() { return displayLatencySTD(DEFAULT_FORMAT); } + /** + * Enables display of latency standard deviation metric with custom format. + * @param format the format string + * @return this LatencyBenchmark + */ public LatencyBenchmark displayLatencySTD(String format) { this.computeLatencySTD = true; this.formatLatencySTD = format; return this; } + /** + * Enables display of P999 latency metric. + * @return this LatencyBenchmark + */ public LatencyBenchmark displayP999Latency() { return displayP999Latency(DEFAULT_FORMAT); } + /** + * Enables display of P999 latency metric with custom format. + * @param format the format string + * @return this LatencyBenchmark + */ public LatencyBenchmark displayP999Latency(String format) { this.computeP999Latency = true; this.formatP999Latency = format; diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/Metric.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/Metric.java index 3cbb62b30..e8fd4b459 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/Metric.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/Metric.java @@ -36,10 +36,31 @@ private Metric(String header, String fmtSpec, double value) { this.value = value; } + /** + * Gets the header. + * @return the header + */ public String getHeader() { return header; } + + /** + * Gets the format specification. + * @return the format specification + */ public String getFmtSpec() { return fmtSpec; } + + /** + * Gets the value. + * @return the value + */ public double getValue() { return value; } + /** + * Creates a new Metric. 
+ * @param header the header + * @param fmtSpec the format specification + * @param value the value + * @return the Metric + */ public static Metric of(String header, String fmtSpec, double value) { return new Metric(header, fmtSpec, value); } diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/QueryBenchmark.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/QueryBenchmark.java index c5039569b..b0af787d2 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/QueryBenchmark.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/QueryBenchmark.java @@ -24,8 +24,21 @@ * A common interface for all search benchmarks. */ public interface QueryBenchmark { + /** + * Gets the name of this benchmark. + * @return the benchmark name + */ String getBenchmarkName(); + /** + * Runs the benchmark. + * @param cs the configured system + * @param topK the topK parameter + * @param rerankK the rerankK parameter + * @param usePruning the usePruning parameter + * @param queryRuns the queryRuns parameter + * @return the list of metrics + */ List runBenchmark( ConfiguredSystem cs, int topK, diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/QueryExecutor.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/QueryExecutor.java index 9ec728808..4f57f0412 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/QueryExecutor.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/QueryExecutor.java @@ -21,7 +21,18 @@ import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.vector.types.VectorFloat; +/** + * Utility class for executing queries against a configured system. + * Provides methods to execute single queries with various parameters. 
+ */ public class QueryExecutor { + + /** + * Private constructor to prevent instantiation of this utility class. + */ + private QueryExecutor() { + throw new AssertionError("QueryExecutor is a utility class and should not be instantiated"); + } /** * Executes the query at index i using the given parameters. * @@ -40,7 +51,17 @@ public static SearchResult executeQuery(ConfiguredSystem cs, int topK, int reran return searcher.search(sf, topK, rerankK, 0.0f, 0.0f, Bits.ALL); } - // Overload to allow single query injection (e.g., for warm-up with random vectors) + /** + * Executes a query using a provided query vector instead of retrieving it from the dataset. + * This overload allows single query injection for operations like warm-up with random vectors. + * + * @param cs the configured system + * @param topK the number of top results to return + * @param rerankK the number of candidates for reranking + * @param usePruning whether to use pruning during search + * @param queryVector the query vector to search with + * @return the SearchResult for the given query vector + */ public static SearchResult executeQuery(ConfiguredSystem cs, int topK, int rerankK, boolean usePruning, VectorFloat queryVector ) { var searcher = cs.getSearcher(); diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/QueryTester.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/QueryTester.java index 969e25b27..5f3073c9c 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/QueryTester.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/QueryTester.java @@ -24,13 +24,16 @@ import io.github.jbellis.jvector.example.Grid.ConfiguredSystem; /** - * Orchestrates running a set of QueryBenchmark instances - * and collects their summary results. + * Orchestrates running a set of QueryBenchmark instances and collects their summary results. 
+ * Provides a simple interface for executing multiple benchmarks sequentially and gathering + * their metrics for analysis and comparison. */ public class QueryTester { private final List benchmarks; /** + * Constructs a QueryTester with the specified benchmarks to execute. + * * @param benchmarks the benchmarks to run, in the order provided */ public QueryTester(List benchmarks) { @@ -38,14 +41,14 @@ public QueryTester(List benchmarks) { } /** - * Run each benchmark once and return a map from each Summary class - * to its returned summary instance. + * Runs each benchmark once and returns the collected metrics. * * @param cs the configured system under test - * @param topK the top‑K parameter for all benchmarks - * @param rerankK the rerank‑K parameter + * @param topK the top-K parameter for all benchmarks + * @param rerankK the rerank-K parameter * @param usePruning whether to enable pruning * @param queryRuns number of runs for each benchmark + * @return a list of metrics from all benchmarks */ public List run( ConfiguredSystem cs, diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/ThroughputBenchmark.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/ThroughputBenchmark.java index 27b99fa71..f0b28e4ff 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/ThroughputBenchmark.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/ThroughputBenchmark.java @@ -51,8 +51,15 @@ public class ThroughputBenchmark extends AbstractQueryBenchmark { private String formatMaxQps; private BenchmarkDiagnostics diagnostics; + /** Vector type support instance for creating and manipulating vectors during benchmarking. */ VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); + /** + * Creates a default throughput benchmark instance with standard settings. 
+ * Configures 3 warmup runs, 3 test runs, and displays average QPS. + * + * @return a new ThroughputBenchmark with default configuration + */ public static ThroughputBenchmark createDefault() { return new ThroughputBenchmark(3, 3, true, false, false, @@ -60,6 +67,14 @@ public static ThroughputBenchmark createDefault() { DiagnosticLevel.NONE); } + /** + * Creates an empty throughput benchmark with no metrics initially enabled. + * Metrics can be configured using the display methods after creation. + * + * @param numWarmupRuns the number of warmup runs to perform before measurement + * @param numTestRuns the number of test runs to perform for measurement + * @return a new ThroughputBenchmark with specified run counts and no metrics enabled + */ public static ThroughputBenchmark createEmpty(int numWarmupRuns, int numTestRuns) { return new ThroughputBenchmark(numWarmupRuns, numTestRuns, false, false, false, @@ -82,30 +97,63 @@ private ThroughputBenchmark(int numWarmupRuns, int numTestRuns, this.diagnostics = new BenchmarkDiagnostics(diagnosticLevel); } + /** + * Enables display of average queries per second (QPS) using the default format. + * + * @return this benchmark instance for method chaining + */ public ThroughputBenchmark displayAvgQps() { return displayAvgQps(DEFAULT_FORMAT); } + /** + * Enables display of average queries per second (QPS) with a custom format string. + * + * @param format the format string for displaying the average QPS value (e.g., ".1f") + * @return this benchmark instance for method chaining + */ public ThroughputBenchmark displayAvgQps(String format) { this.computeAvgQps = true; this.formatAvgQps = format; return this; } + /** + * Enables display of median queries per second (QPS) using the default format. 
+ * + * @return this benchmark instance for method chaining + */ public ThroughputBenchmark displayMedianQps() { return displayMedianQps(DEFAULT_FORMAT); } + /** + * Enables display of median queries per second (QPS) with a custom format string. + * + * @param format the format string for displaying the median QPS value (e.g., ".1f") + * @return this benchmark instance for method chaining + */ public ThroughputBenchmark displayMedianQps(String format) { this.computeMedianQps = true; this.formatMedianQps = format; return this; } + /** + * Enables display of maximum queries per second (QPS) using the default format. + * + * @return this benchmark instance for method chaining + */ public ThroughputBenchmark displayMaxQps() { return displayMaxQps(DEFAULT_FORMAT); } + /** + * Enables display of maximum queries per second (QPS) with a custom format string. + * + * @param format the format string for displaying the maximum QPS value (e.g., ".1f") + * @return this benchmark instance for method chaining + */ public ThroughputBenchmark displayMaxQps(String format) { this.computeMaxQps = true; this.formatMaxQps = format; @@ -113,7 +161,11 @@ public ThroughputBenchmark displayMaxQps(String format) { } /** - * Configure the diagnostic level for this benchmark + * Configures the diagnostic level for this benchmark. + * Higher diagnostic levels provide more detailed performance analysis and recommendations. + * + * @param level the diagnostic level to use during benchmark execution + * @return this benchmark instance for method chaining */ public ThroughputBenchmark withDiagnostics(DiagnosticLevel level) { this.diagnostics = new BenchmarkDiagnostics(level); @@ -125,6 +177,19 @@ public String getBenchmarkName() { return "ThroughputBenchmark"; } + /** + * Executes the throughput benchmark against the configured system. + * Performs warmup runs followed by measured test runs, collecting QPS statistics + * and optional diagnostics. 
+ * + * @param cs the configured system to benchmark + * @param topK the number of top results to return + * @param rerankK the number of candidates to rerank + * @param usePruning whether to use pruning during search + * @param queryRuns the number of query runs (not used in this benchmark) + * @return a list of computed metrics including QPS statistics + * @throws RuntimeException if no metrics are enabled for display + */ @Override public List runBenchmark( ConfiguredSystem cs, diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/diagnostics/BenchmarkDiagnostics.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/diagnostics/BenchmarkDiagnostics.java index b0f71ad06..67bb1bf7b 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/diagnostics/BenchmarkDiagnostics.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/diagnostics/BenchmarkDiagnostics.java @@ -32,6 +32,10 @@ public class BenchmarkDiagnostics { private final List snapshots; private final List timingAnalyses; + /** + * Constructs a BenchmarkDiagnostics with the specified diagnostic level. 
+ * @param level the diagnostic level + */ public BenchmarkDiagnostics(DiagnosticLevel level) { this.level = level; this.systemMonitor = new SystemMonitor(); @@ -42,6 +46,7 @@ public BenchmarkDiagnostics(DiagnosticLevel level) { /** * Creates a BenchmarkDiagnostics instance with BASIC level diagnostics + * @return the BenchmarkDiagnostics instance */ public static BenchmarkDiagnostics createBasic() { return new BenchmarkDiagnostics(DiagnosticLevel.BASIC); @@ -49,6 +54,7 @@ public static BenchmarkDiagnostics createBasic() { /** * Creates a BenchmarkDiagnostics instance with DETAILED level diagnostics + * @return the BenchmarkDiagnostics instance */ public static BenchmarkDiagnostics createDetailed() { return new BenchmarkDiagnostics(DiagnosticLevel.DETAILED); @@ -56,6 +62,7 @@ public static BenchmarkDiagnostics createDetailed() { /** * Creates a BenchmarkDiagnostics instance with VERBOSE level diagnostics + * @return the BenchmarkDiagnostics instance */ public static BenchmarkDiagnostics createVerbose() { return new BenchmarkDiagnostics(DiagnosticLevel.VERBOSE); @@ -63,6 +70,7 @@ public static BenchmarkDiagnostics createVerbose() { /** * Captures system state before starting a benchmark phase + * @param phase the phase name */ public void capturePrePhaseSnapshot(String phase) { if (level == DiagnosticLevel.NONE) return; @@ -78,6 +86,7 @@ public void capturePrePhaseSnapshot(String phase) { /** * Captures system state after completing a benchmark phase and logs changes + * @param phase the phase name */ public void capturePostPhaseSnapshot(String phase) { if (level == DiagnosticLevel.NONE) return; @@ -98,6 +107,7 @@ public void capturePostPhaseSnapshot(String phase) { /** * Records the execution time of a single query (for detailed timing analysis) + * @param nanoTime the query time in nanoseconds */ public void recordQueryTime(long nanoTime) { if (level == DiagnosticLevel.DETAILED || level == DiagnosticLevel.VERBOSE) { @@ -107,6 +117,7 @@ public void 
recordQueryTime(long nanoTime) { /** * Analyzes and logs timing data for a phase + * @param phase the phase name */ public void analyzePhaseTimings(String phase) { if (level == DiagnosticLevel.DETAILED || level == DiagnosticLevel.VERBOSE) { @@ -119,6 +130,10 @@ public void analyzePhaseTimings(String phase) { /** * Executes a benchmark phase with full diagnostic monitoring + * @param the result type + * @param phase the phase name + * @param benchmarkCode the benchmark code to execute + * @return the result from the benchmark code */ public T monitorPhase(String phase, Supplier benchmarkCode) { capturePrePhaseSnapshot(phase); @@ -138,6 +153,10 @@ public T monitorPhase(String phase, Supplier benchmarkCode) { /** * Executes a benchmark phase with detailed query timing + * @param the result type + * @param phase the phase name + * @param benchmarkCode the benchmark code to execute + * @return the result from the benchmark code */ public T monitorPhaseWithQueryTiming(String phase, QueryTimingBenchmark benchmarkCode) { capturePrePhaseSnapshot(phase); @@ -156,6 +175,10 @@ public T monitorPhaseWithQueryTiming(String phase, QueryTimingBenchmark b return result; } + /** + * Logs a message to the console if diagnostics are enabled. 
+ * @param s the message to log + */ public void console(String s) { if (level != DiagnosticLevel.NONE ) { System.out.println(s); @@ -164,6 +187,8 @@ public void console(String s) { /** * Compares performance between different phases + * @param baselinePhase the baseline phase name + * @param currentPhase the current phase name */ public void comparePhases(String baselinePhase, String currentPhase) { if (timingAnalyses.size() < 2) return; @@ -223,6 +248,7 @@ public void logSummary() { /** * Checks if warmup appears to be effective based on performance stabilization + * @return true if warmup is effective */ public boolean isWarmupEffective() { if (timingAnalyses.size() < 2) return true; @@ -279,6 +305,9 @@ public void provideRecommendations() { /** * Compares performance between runs and identifies significant changes + * @param baseline the baseline timing analysis + * @param current the current timing analysis + * @return the performance comparison */ public static PerformanceAnalyzer.PerformanceComparison compareRuns(PerformanceAnalyzer.TimingAnalysis baseline, PerformanceAnalyzer.TimingAnalysis current) { double p50Change = calculatePercentageChange(baseline.p50, current.p50); @@ -295,6 +324,12 @@ public static PerformanceAnalyzer.PerformanceComparison compareRuns(PerformanceA ); } + /** + * Calculates the percentage change between baseline and current values. + * @param baseline the baseline value + * @param current the current value + * @return the percentage change + */ public static double calculatePercentageChange(long baseline, long current) { if (baseline == 0) return current == 0 ? 
0.0 : 100.0; return ((double)(current - baseline) / baseline) * 100.0; @@ -302,6 +337,7 @@ public static double calculatePercentageChange(long baseline, long current) { /** * Logs performance comparison results + * @param comparison the performance comparison to log */ public static void logComparison(PerformanceAnalyzer.PerformanceComparison comparison) { System.out.printf("[%s vs %s] Performance Comparison:%n", @@ -318,9 +354,15 @@ public static void logComparison(PerformanceAnalyzer.PerformanceComparison compa /** * Functional interface for benchmark code that needs query timing + * @param the result type */ @FunctionalInterface public interface QueryTimingBenchmark { + /** + * Executes the benchmark with query timing. + * @param recorder the query time recorder + * @return the result + */ T execute(QueryTimeRecorder recorder); } @@ -329,6 +371,10 @@ public interface QueryTimingBenchmark { */ @FunctionalInterface public interface QueryTimeRecorder { + /** + * Records a query time. + * @param nanoTime the query time in nanoseconds + */ void recordTime(long nanoTime); } } diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/diagnostics/PerformanceAnalyzer.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/diagnostics/PerformanceAnalyzer.java index 74104ce02..4e3f942d2 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/diagnostics/PerformanceAnalyzer.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/diagnostics/PerformanceAnalyzer.java @@ -31,7 +31,15 @@ public class PerformanceAnalyzer { private final AtomicLong totalTime = new AtomicLong(0); /** - * Records the execution time of a single query + * Constructs a new PerformanceAnalyzer with empty timing data. + */ + public PerformanceAnalyzer() { + } + + /** + * Records the execution time of a single query. 
+ * + * @param nanoTime the query execution time in nanoseconds */ public void recordQueryTime(long nanoTime) { queryTimes.offer(nanoTime); @@ -40,7 +48,10 @@ public void recordQueryTime(long nanoTime) { } /** - * Analyzes collected timing data and returns performance statistics + * Analyzes collected timing data and returns performance statistics. + * + * @param phase the name of the phase being analyzed + * @return the timing analysis results */ public TimingAnalysis analyzeTimings(String phase) { List times = new ArrayList<>(queryTimes); @@ -77,7 +88,11 @@ public void reset() { } /** - * Compares performance between runs and identifies significant changes + * Compares performance between runs and identifies significant changes. + * + * @param baseline the baseline timing analysis + * @param current the current timing analysis + * @return the performance comparison results */ public static PerformanceComparison compareRuns(TimingAnalysis baseline, TimingAnalysis current) { double p50Change = calculatePercentageChange(baseline.p50, current.p50); @@ -100,7 +115,9 @@ private static double calculatePercentageChange(long baseline, long current) { } /** - * Logs timing analysis results + * Logs timing analysis results to standard output. + * + * @param analysis the timing analysis to log */ public void logAnalysis(TimingAnalysis analysis) { System.out.printf("[%s] Query Timing Analysis:%n", analysis.phase); @@ -124,7 +141,9 @@ public void logAnalysis(TimingAnalysis analysis) { } /** - * Logs performance comparison results + * Logs performance comparison results to standard output. + * + * @param comparison the performance comparison to log */ public static void logComparison(PerformanceComparison comparison) { System.out.printf("[%s vs %s] Performance Comparison:%n", @@ -139,17 +158,39 @@ public static void logComparison(PerformanceComparison comparison) { } } - // Data classes + /** + * Contains timing analysis statistics for a benchmark phase. 
+ */ public static class TimingAnalysis { + /** The name of the phase being analyzed. */ public final String phase; + /** The minimum query time in nanoseconds. */ public final long min; + /** The maximum query time in nanoseconds. */ public final long max; + /** The 50th percentile (median) query time in nanoseconds. */ public final long p50; + /** The 95th percentile query time in nanoseconds. */ public final long p95; + /** The 99th percentile query time in nanoseconds. */ public final long p99; + /** The mean query time in nanoseconds. */ public final long mean; + /** List of outlier query times exceeding 3x the median. */ public final List outliers; + /** + * Constructs a TimingAnalysis with the specified statistics. + * + * @param phase the name of the phase being analyzed + * @param min the minimum query time in nanoseconds + * @param max the maximum query time in nanoseconds + * @param p50 the 50th percentile query time in nanoseconds + * @param p95 the 95th percentile query time in nanoseconds + * @param p99 the 99th percentile query time in nanoseconds + * @param mean the mean query time in nanoseconds + * @param outliers list of outlier query times + */ public TimingAnalysis(String phase, long min, long max, long p50, long p95, long p99, long mean, List outliers) { this.phase = phase; @@ -163,15 +204,36 @@ public TimingAnalysis(String phase, long min, long max, long p50, long p95, long } } + /** + * Contains performance comparison results between two benchmark runs. + */ public static class PerformanceComparison { + /** The name of the baseline phase. */ public final String baselinePhase; + /** The name of the current phase. */ public final String currentPhase; + /** The percentage change in 50th percentile time. */ public final double p50Change; + /** The percentage change in 95th percentile time. */ public final double p95Change; + /** The percentage change in 99th percentile time. 
*/ public final double p99Change; + /** The percentage change in mean time. */ public final double meanChange; + /** Whether a significant performance regression was detected. */ public final boolean significantRegression; + /** + * Constructs a PerformanceComparison with the specified metrics. + * + * @param baselinePhase the name of the baseline phase + * @param currentPhase the name of the current phase + * @param p50Change the percentage change in 50th percentile time + * @param p95Change the percentage change in 95th percentile time + * @param p99Change the percentage change in 99th percentile time + * @param meanChange the percentage change in mean time + * @param significantRegression whether a significant regression was detected + */ public PerformanceComparison(String baselinePhase, String currentPhase, double p50Change, double p95Change, double p99Change, double meanChange, boolean significantRegression) { diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/diagnostics/SystemMonitor.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/diagnostics/SystemMonitor.java index 2dfbfbfac..d2a32923d 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/diagnostics/SystemMonitor.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/diagnostics/SystemMonitor.java @@ -29,8 +29,12 @@ public class SystemMonitor { private final List gcBeans; private final OperatingSystemMXBean osBean; private final ThreadMXBean threadBean; + /** Platform-specific OS bean for extended metrics. */ private final com.sun.management.OperatingSystemMXBean sunOsBean; + /** + * Constructs a SystemMonitor that initializes connections to system management beans. 
+ */ public SystemMonitor() { this.memoryBean = ManagementFactory.getMemoryMXBean(); this.gcBeans = ManagementFactory.getGarbageCollectorMXBeans(); @@ -40,7 +44,9 @@ public SystemMonitor() { } /** - * Captures current system state snapshot + * Captures the current system state snapshot including GC, memory, CPU, and thread statistics. + * + * @return a snapshot of the current system state */ public SystemSnapshot captureSnapshot() { return new SystemSnapshot( @@ -99,7 +105,11 @@ private ThreadStats captureThreadStats() { } /** - * Logs the difference between two snapshots + * Logs the difference between two snapshots to standard output. + * + * @param phase the name of the phase being measured + * @param before the snapshot taken before the phase + * @param after the snapshot taken after the phase */ public void logDifference(String phase, SystemSnapshot before, SystemSnapshot after) { System.out.printf("[%s] System Changes:%n", phase); @@ -133,7 +143,9 @@ public void logDifference(String phase, SystemSnapshot before, SystemSnapshot af } /** - * Logs detailed GC information + * Logs detailed garbage collection information to standard output. + * + * @param phase the name of the phase to include in the output */ public void logDetailedGCStats(String phase) { System.out.printf("[%s] Detailed GC Stats:%n", phase); @@ -143,14 +155,30 @@ public void logDetailedGCStats(String phase) { } } - // Inner classes for data structures + /** + * Contains a complete snapshot of system state at a point in time. + */ public static class SystemSnapshot { + /** The timestamp when this snapshot was captured (milliseconds). */ public final long timestamp; + /** Garbage collection statistics. */ public final GCStats gcStats; + /** Memory usage statistics. */ public final MemoryStats memoryStats; + /** CPU usage statistics. */ public final CPUStats cpuStats; + /** Thread statistics. */ public final ThreadStats threadStats; + /** + * Constructs a SystemSnapshot with the specified metrics. 
+ * + * @param timestamp the timestamp when captured + * @param gcStats garbage collection statistics + * @param memoryStats memory usage statistics + * @param cpuStats CPU usage statistics + * @param threadStats thread statistics + */ public SystemSnapshot(long timestamp, GCStats gcStats, MemoryStats memoryStats, CPUStats cpuStats, ThreadStats threadStats) { this.timestamp = timestamp; @@ -161,17 +189,36 @@ public SystemSnapshot(long timestamp, GCStats gcStats, MemoryStats memoryStats, } } + /** + * Contains garbage collection statistics. + */ public static class GCStats { + /** Total number of garbage collections. */ public final long totalCollections; + /** Total time spent in garbage collection (milliseconds). */ public final long totalCollectionTime; + /** Number of garbage collectors. */ public final int gcCount; + /** + * Constructs GCStats with the specified metrics. + * + * @param totalCollections total number of collections + * @param totalCollectionTime total time spent in collections (ms) + * @param gcCount number of garbage collectors + */ public GCStats(long totalCollections, long totalCollectionTime, int gcCount) { this.totalCollections = totalCollections; this.totalCollectionTime = totalCollectionTime; this.gcCount = gcCount; } + /** + * Computes the difference between this and another GCStats. + * + * @param other the GCStats to subtract from this one + * @return a new GCStats representing the difference + */ public GCStats subtract(GCStats other) { return new GCStats( this.totalCollections - other.totalCollections, @@ -181,15 +228,36 @@ public GCStats subtract(GCStats other) { } } + /** + * Contains memory usage statistics. + */ public static class MemoryStats { + /** Heap memory currently used (bytes). */ public final long heapUsed; + /** Maximum heap memory available (bytes). */ public final long heapMax; + /** Heap memory committed by the JVM (bytes). */ public final long heapCommitted; + /** Non-heap memory used (bytes). 
*/ public final long nonHeapUsed; + /** Free memory in the runtime (bytes). */ public final long freeMemory; + /** Total memory in the runtime (bytes). */ public final long totalMemory; + /** Maximum memory the runtime can use (bytes). */ public final long maxMemory; + /** + * Constructs MemoryStats with the specified metrics. + * + * @param heapUsed heap memory used (bytes) + * @param heapMax maximum heap memory (bytes) + * @param heapCommitted heap memory committed (bytes) + * @param nonHeapUsed non-heap memory used (bytes) + * @param freeMemory free memory (bytes) + * @param totalMemory total memory (bytes) + * @param maxMemory maximum memory (bytes) + */ public MemoryStats(long heapUsed, long heapMax, long heapCommitted, long nonHeapUsed, long freeMemory, long totalMemory, long maxMemory) { this.heapUsed = heapUsed; @@ -202,13 +270,30 @@ public MemoryStats(long heapUsed, long heapMax, long heapCommitted, long nonHeap } } + /** + * Contains CPU usage statistics. + */ public static class CPUStats { + /** System-wide CPU load (0.0 to 1.0). */ public final double systemCpuLoad; + /** Process CPU load (0.0 to 1.0). */ public final double processCpuLoad; + /** System load average. */ public final double systemLoadAverage; + /** Number of available processors. */ public final int availableProcessors; + /** Free physical memory size (bytes). */ public final long freePhysicalMemory; + /** + * Constructs CPUStats with the specified metrics. 
+ * + * @param systemCpuLoad system-wide CPU load (0.0-1.0) + * @param processCpuLoad process CPU load (0.0-1.0) + * @param systemLoadAverage system load average + * @param availableProcessors number of available processors + * @param freePhysicalMemory free physical memory (bytes) + */ public CPUStats(double systemCpuLoad, double processCpuLoad, double systemLoadAverage, int availableProcessors, long freePhysicalMemory) { this.systemCpuLoad = systemCpuLoad; @@ -219,11 +304,24 @@ public CPUStats(double systemCpuLoad, double processCpuLoad, double systemLoadAv } } + /** + * Contains thread statistics. + */ public static class ThreadStats { + /** Number of active threads. */ public final int activeThreads; + /** Peak number of threads. */ public final int peakThreads; + /** Total number of threads started since JVM start. */ public final long totalStartedThreads; + /** + * Constructs ThreadStats with the specified metrics. + * + * @param activeThreads number of active threads + * @param peakThreads peak number of threads + * @param totalStartedThreads total threads started + */ public ThreadStats(int activeThreads, int peakThreads, long totalStartedThreads) { this.activeThreads = activeThreads; this.peakThreads = peakThreads; diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/AccuracyMetrics.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/AccuracyMetrics.java index ba537ed06..a6a2bf19e 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/AccuracyMetrics.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/AccuracyMetrics.java @@ -27,6 +27,10 @@ * Computes accuracy metrics, such as recall and mean average precision. */ public class AccuracyMetrics { + /** + * Constructs an AccuracyMetrics. 
+ */ + public AccuracyMetrics() {} /** * Compute kGT-recall@kRetrieved, which is the fraction of * the kGT ground-truth nearest neighbors that are in the kRetrieved diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/BenchmarkSummarizer.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/BenchmarkSummarizer.java index dba6064ab..9b84da157 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/BenchmarkSummarizer.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/BenchmarkSummarizer.java @@ -25,7 +25,11 @@ * across all configurations. */ public class BenchmarkSummarizer { - + /** + * Constructs a BenchmarkSummarizer. + */ + public BenchmarkSummarizer() {} + /** * Summary statistics for benchmark results */ @@ -37,10 +41,27 @@ public static class SummaryStats { private final int totalConfigurations; private final double qpsStdDev; + /** + * Constructs SummaryStats with the specified values. + * @param avgRecall the avgRecall parameter + * @param avgQps the avgQps parameter + * @param avgLatency the avgLatency parameter + * @param indexConstruction the indexConstruction parameter + * @param totalConfigurations the totalConfigurations parameter + */ public SummaryStats(double avgRecall, double avgQps, double avgLatency, double indexConstruction, int totalConfigurations) { this(avgRecall, avgQps, avgLatency, indexConstruction, totalConfigurations, 0.0); } + /** + * Constructs SummaryStats with the specified values including QPS standard deviation. 
+ * @param avgRecall the avgRecall parameter + * @param avgQps the avgQps parameter + * @param avgLatency the avgLatency parameter + * @param indexConstruction the indexConstruction parameter + * @param totalConfigurations the totalConfigurations parameter + * @param qpsStdDev the qpsStdDev parameter + */ public SummaryStats(double avgRecall, double avgQps, double avgLatency, double indexConstruction, int totalConfigurations, double qpsStdDev) { this.avgRecall = avgRecall; this.avgQps = avgQps; @@ -50,24 +71,48 @@ public SummaryStats(double avgRecall, double avgQps, double avgLatency, double i this.qpsStdDev = qpsStdDev; } + /** + * Gets the average recall. + * @return the average recall + */ public double getAvgRecall() { return avgRecall; } + /** + * Gets the average QPS. + * @return the average QPS + */ public double getAvgQps() { return avgQps; } + /** + * Gets the average latency. + * @return the average latency + */ public double getAvgLatency() { return avgLatency; } + /** + * Gets the index construction time. + * @return the index construction time + */ public double getIndexConstruction() { return indexConstruction; } + /** + * Gets the total number of configurations. + * @return the total number of configurations + */ public int getTotalConfigurations() { return totalConfigurations; } + /** + * Gets the QPS standard deviation. 
+ * @return the QPS standard deviation + */ public double getQpsStdDev() { return qpsStdDev; } @Override diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/CheckpointManager.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/CheckpointManager.java index 4145100b2..79dbee02c 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/CheckpointManager.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/CheckpointManager.java @@ -122,6 +122,7 @@ public Set getCompletedDatasets() { /** * Returns the list of completed BenchResults. + * @return the list of completed BenchResults */ public List getCompletedResults() { return new ArrayList<>(completedResults); diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/CompressorParameters.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/CompressorParameters.java index e1ffebb9b..4fcb921a4 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/CompressorParameters.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/CompressorParameters.java @@ -21,26 +21,59 @@ import io.github.jbellis.jvector.quantization.ProductQuantization; import io.github.jbellis.jvector.quantization.VectorCompressor; +/** + * Base class for compressor parameters. + */ public abstract class CompressorParameters { + /** + * Constructs a CompressorParameters. + */ + public CompressorParameters() {} + + /** No compression constant. */ public static final CompressorParameters NONE = new NoCompressionParameters(); + /** + * Checks if this compressor supports caching. + * @return true if caching is supported + */ public boolean supportsCaching() { return false; } + /** + * Gets the ID string for the specified dataset. 
+ * @param ds the dataset + * @return the ID string + */ public String idStringFor(DataSet ds) { // only required when supportsCaching() is true throw new UnsupportedOperationException(); } + /** + * Computes the compressor for the specified dataset. + * @param ds the dataset + * @return the vector compressor + */ public abstract VectorCompressor computeCompressor(DataSet ds); + /** + * Product quantization parameters. + */ public static class PQParameters extends CompressorParameters { private final int m; private final int k; private final boolean isCentered; private final float anisotropicThreshold; + /** + * Constructs PQParameters. + * @param m the m parameter + * @param k the k parameter + * @param isCentered the isCentered parameter + * @param anisotropicThreshold the anisotropicThreshold parameter + */ public PQParameters(int m, int k, boolean isCentered, float anisotropicThreshold) { this.m = m; this.k = k; @@ -64,16 +97,30 @@ public boolean supportsCaching() { } } + /** + * Binary quantization parameters. + */ public static class BQParameters extends CompressorParameters { + /** + * Constructs BQParameters. + */ + public BQParameters() {} @Override public VectorCompressor computeCompressor(DataSet ds) { return new BinaryQuantization(ds.getDimension()); } } + /** + * NVQ parameters. + */ public static class NVQParameters extends CompressorParameters { private final int nSubVectors; + /** + * Constructs NVQParameters. 
+ * @param nSubVectors the number of sub-vectors + */ public NVQParameters(int nSubVectors) { this.nSubVectors = nSubVectors; } diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DataSet.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DataSet.java index e193cd6ad..663856abc 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DataSet.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DataSet.java @@ -30,14 +30,30 @@ import java.util.Set; import java.util.TreeSet; +/** + * A dataset containing base vectors, query vectors, and ground truth results. + */ public class DataSet { + /** The name of the dataset. */ public final String name; + /** The similarity function for this dataset. */ public final VectorSimilarityFunction similarityFunction; + /** The base vectors in the dataset. */ public final List> baseVectors; + /** The query vectors for searching. */ public final List> queryVectors; + /** The ground truth results for queries. */ public final List> groundTruth; private RandomAccessVectorValues baseRavv; + /** + * Creates a new DataSet. + * @param name the name parameter + * @param similarityFunction the similarityFunction parameter + * @param baseVectors the baseVectors parameter + * @param queryVectors the queryVectors parameter + * @param groundTruth the groundTruth parameter + */ public DataSet(String name, VectorSimilarityFunction similarityFunction, List> baseVectors, @@ -74,6 +90,12 @@ public DataSet(String name, /** * Return a dataset containing the given vectors, scrubbed free from zero vectors and normalized to unit length. * Note: This only scrubs and normalizes for dot product similarity. 
+ * @param pathStr the pathStr parameter + * @param vsf the vsf parameter + * @param baseVectors the baseVectors parameter + * @param queryVectors the queryVectors parameter + * @param groundTruth the groundTruth parameter + * @return the scrubbed dataset */ public static DataSet getScrubbedDataSet(String pathStr, VectorSimilarityFunction vsf, @@ -153,10 +175,18 @@ private static float normOf(VectorFloat baseVector) { return (float) Math.sqrt(norm); } + /** + * Gets the dimension of the vectors. + * @return the dimension + */ public int getDimension() { return baseVectors.get(0).length(); } + /** + * Gets the base vectors as RandomAccessVectorValues. + * @return the base vectors + */ public RandomAccessVectorValues getBaseRavv() { if (baseRavv == null) { baseRavv = new ListRandomAccessVectorValues(baseVectors, getDimension()); diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DataSetCreator.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DataSetCreator.java index 1cd532160..41b569605 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DataSetCreator.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DataSetCreator.java @@ -30,9 +30,15 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; +/** + * Utility for creating test datasets. + */ public class DataSetCreator { private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); + private DataSetCreator() { + } + /** * Creates a 2D grid of vectors, query vectors, and ground truth data for a given grid width. 
* diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DataSetLoader.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DataSetLoader.java index e90a6f275..43ce3af46 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DataSetLoader.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DataSetLoader.java @@ -18,7 +18,19 @@ import java.io.IOException; +/** + * Utility for loading datasets from various file formats. + */ public class DataSetLoader { + private DataSetLoader() { + } + + /** + * Loads a dataset from the specified file. + * @param fileName the file name + * @return the loaded dataset + * @throws IOException if an error occurs + */ public static DataSet loadDataSet(String fileName) throws IOException { DataSet ds; if (fileName.endsWith(".hdf5")) { diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/Deep1BLoader.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/Deep1BLoader.java index deea90a19..5575d38a3 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/Deep1BLoader.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/Deep1BLoader.java @@ -24,7 +24,20 @@ import java.util.List; import java.util.stream.IntStream; +/** + * Utility for loading Deep1B dataset files. + */ public class Deep1BLoader { + private Deep1BLoader() { + } + + /** + * Reads vectors from a binary file. + * @param filePath the file path + * @param count the number of vectors to read + * @return the list of vectors + * @throws IOException if an error occurs + */ public static List readFBin(String filePath, int count) throws IOException { var vectors = new float[count][]; @@ -66,6 +79,11 @@ public static List readFBin(String filePath, int count) throws IOExcept return List.of(vectors); } + /** + * Reads ground truth data from a file. 
+ * @param filePath the file path + * @return the ground truth data + */ public static ArrayList> readGT(String filePath) { var groundTruthTopK = new ArrayList>(); diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DownloadHelper.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DownloadHelper.java index 8725a6f65..0af29c98b 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DownloadHelper.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/DownloadHelper.java @@ -38,6 +38,9 @@ import java.util.List; import java.util.Set; +/** + * Utility for downloading datasets from S3 and HTTP sources. + */ public class DownloadHelper { private static final String bucketName = "astra-vector"; private static final String infraBucketName = "jvector-datasets-infratest"; @@ -46,6 +49,9 @@ public class DownloadHelper { private static final String fvecDir = "fvec"; private final static Set infraDatasets = Set.of("dpr-1M", "dpr-10M", "cap-1M", "cap-6M", "cohere-english-v3-1M", "cohere-english-v3-10M"); + private DownloadHelper() { + } + private static S3AsyncClientBuilder s3AsyncClientBuilder() { return S3AsyncClient.builder() .region(Region.US_EAST_1) @@ -55,6 +61,11 @@ private static S3AsyncClientBuilder s3AsyncClientBuilder() { .credentialsProvider(AnonymousCredentialsProvider.create()); } + /** + * Downloads fvec dataset files if not already present. + * @param name the dataset name + * @return the multi-file datasource + */ public static MultiFileDatasource maybeDownloadFvecs(String name) { String bucket = infraDatasets.contains(name) ? infraBucketName : bucketName; var mfd = MultiFileDatasource.byName.get(name); @@ -111,6 +122,10 @@ public static MultiFileDatasource maybeDownloadFvecs(String name) { return mfd; } + /** + * Downloads HDF5 dataset file if not already present. 
+ * @param datasetName the dataset name + */ public static void maybeDownloadHdf5(String datasetName) { Path path = Path.of(Hdf5Loader.HDF5_DIR); var localPath = path.resolve(datasetName); diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/FilteredForkJoinPool.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/FilteredForkJoinPool.java index fb1a550c6..834095f2c 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/FilteredForkJoinPool.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/FilteredForkJoinPool.java @@ -25,6 +25,14 @@ * to make them identifiable for thread leak detection. */ public class FilteredForkJoinPool extends ForkJoinPool { + + /** + * Constructs a FilteredForkJoinPool using the default parallelism level. + * Use the static factory methods instead of this constructor. + */ + public FilteredForkJoinPool() { + super(); + } /** * Creates a ForkJoinPool with the same parallelism as {@link ForkJoinPool#commonPool()} @@ -69,6 +77,11 @@ public ForkJoinWorkerThread newThread(ForkJoinPool pool) { * Custom worker thread class that can be easily identified for thread leak detection. */ private static class JVectorForkJoinWorkerThread extends ForkJoinWorkerThread { + /** + * Creates a new worker thread attached to the given pool. 
+ * + * @param pool the pool this thread will work in + */ protected JVectorForkJoinWorkerThread(ForkJoinPool pool) { super(pool); } diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/Hdf5Loader.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/Hdf5Loader.java index 7dfdccc07..11e284bb1 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/Hdf5Loader.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/Hdf5Loader.java @@ -30,10 +30,22 @@ import java.util.List; import java.util.stream.IntStream; +/** + * Utility for loading datasets from HDF5 files. + */ public class Hdf5Loader { + /** The directory containing HDF5 files. */ public static final String HDF5_DIR = "hdf5/"; private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); + private Hdf5Loader() { + } + + /** + * Loads a dataset from an HDF5 file. + * @param filename the file name + * @return the loaded dataset + */ public static DataSet load(String filename) { // infer the similarity VectorSimilarityFunction similarityFunction; diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MMapRandomAccessVectorValues.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MMapRandomAccessVectorValues.java index 0fda16df1..f39e08b6b 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MMapRandomAccessVectorValues.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MMapRandomAccessVectorValues.java @@ -29,15 +29,31 @@ import java.nio.ByteOrder; import java.nio.channels.FileChannel; +/** + * Memory-mapped implementation of RandomAccessVectorValues that provides efficient + * random access to vectors stored in a file. Uses memory mapping for fast I/O. 
+ */ public class MMapRandomAccessVectorValues implements RandomAccessVectorValues, Closeable { private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); + /** The dimension of each vector. */ final int dimension; + /** The number of vectors in the file. */ final int rows; + /** The file containing the vectors. */ final File file; + /** Reusable buffer for reading vector values. */ final float[] valueBuffer; + /** The memory-mapped file reader. */ final MMapBuffer fileReader; + /** + * Constructs a MMapRandomAccessVectorValues for the specified file. + * + * @param f the file containing vectors, which must exist and be readable + * @param dimension the dimension of each vector in the file + * @throws IOError if an I/O error occurs during initialization + */ public MMapRandomAccessVectorValues(File f, int dimension) { assert f != null && f.exists() && f.canRead(); assert f.length() % ((long) dimension * Float.BYTES) == 0; diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MMapReader.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MMapReader.java index 79a921a46..537feaf4b 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MMapReader.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MMapReader.java @@ -25,12 +25,24 @@ import java.nio.channels.FileChannel; import java.nio.file.Path; +/** + * Memory-mapped implementation of RandomAccessReader that provides efficient + * random access to data stored in a file. Uses memory mapping for fast I/O operations. + */ @SuppressWarnings("unused") public class MMapReader implements RandomAccessReader { + /** The memory-mapped buffer for reading data. */ private final MMapBuffer buffer; + /** The current read position in the buffer. */ private long position; + /** Reusable scratch buffer for bulk read operations. 
*/ private byte[] scratch = new byte[0]; + /** + * Constructs a MMapReader wrapping the specified memory-mapped buffer. + * + * @param buffer the memory-mapped buffer to read from + */ MMapReader(MMapBuffer buffer) { this.buffer = buffer; } @@ -146,9 +158,20 @@ public void close() { // don't close buffer, let the Supplier handle that } + /** + * Supplier that creates MMapReader instances from a memory-mapped file. + * The file is mapped into memory once during construction and shared across all readers. + */ public static class Supplier implements ReaderSupplier { + /** The shared memory-mapped buffer for this file. */ private final MMapBuffer buffer; + /** + * Constructs a Supplier that memory-maps the file at the specified path. + * + * @param path the path to the file to map + * @throws IOException if an I/O error occurs during mapping + */ public Supplier(Path path) throws IOException { buffer = new MMapBuffer(path, FileChannel.MapMode.READ_ONLY, ByteOrder.BIG_ENDIAN); } diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MultiFileDatasource.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MultiFileDatasource.java index 6f875e23c..baa11c8ff 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MultiFileDatasource.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MultiFileDatasource.java @@ -25,13 +25,30 @@ import java.util.List; import java.util.Map; +/** + * Represents a dataset composed of multiple files: base vectors, query vectors, and ground truth. + * Provides access to dataset files and loading functionality. + */ public class MultiFileDatasource { + /** The name of the dataset. */ public final String name; + /** Path to the base vectors file. */ public final Path basePath; + /** Path to the query vectors file. */ public final Path queriesPath; + /** Path to the ground truth file. 
*/ public final Path groundTruthPath; + /** Optional hash prefix for dataset paths from environment variable. */ private final static String DATASET_HASH = System.getenv("DATASET_HASH"); + /** + * Constructs a MultiFileDatasource with the specified file paths. + * + * @param name the name of the dataset + * @param basePath the path to the base vectors file + * @param queriesPath the path to the query vectors file + * @param groundTruthPath the path to the ground truth file + */ public MultiFileDatasource(String name, String basePath, String queriesPath, String groundTruthPath) { this.name = name; this.basePath = Paths.get(basePath); @@ -39,14 +56,30 @@ public MultiFileDatasource(String name, String basePath, String queriesPath, Str this.groundTruthPath = Paths.get(groundTruthPath); } + /** + * Returns the parent directory containing the base vectors file. + * + * @return the directory path + */ public Path directory() { return basePath.getParent(); } + /** + * Returns all paths associated with this dataset. + * + * @return an iterable of paths (base, queries, ground truth) + */ public Iterable paths() { return List.of(basePath, queriesPath, groundTruthPath); } + /** + * Loads the dataset from its constituent files. + * + * @return the loaded DataSet + * @throws IOException if an I/O error occurs during loading + */ public DataSet load() throws IOException { var baseVectors = SiftLoader.readFvecs("fvec/" + basePath); var queryVectors = SiftLoader.readFvecs("fvec/" + queriesPath); @@ -54,6 +87,10 @@ public DataSet load() throws IOException { return DataSet.getScrubbedDataSet(name, VectorSimilarityFunction.COSINE, baseVectors, queryVectors, gtVectors); } + /** + * Map of dataset names to their corresponding MultiFileDatasource instances. + * Provides convenient access to predefined benchmark datasets. 
+ */ public static Map byName = new HashMap<>() {{ put("degen-200k", new MultiFileDatasource("degen-200k", "ada-degen/degen_base_vectors.fvec", diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/SiftLoader.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/SiftLoader.java index 46f144418..4655b1123 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/SiftLoader.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/SiftLoader.java @@ -31,9 +31,28 @@ import java.util.HashSet; import java.util.List; +/** + * Utility class for loading SIFT dataset files in fvec and ivec formats. + * These formats are commonly used in vector similarity search benchmarks. + */ public class SiftLoader { private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); + /** + * Private constructor to prevent instantiation of this utility class. + */ + private SiftLoader() { + throw new AssertionError("SiftLoader is a utility class and should not be instantiated"); + } + + /** + * Reads float vectors from an fvec file. + * The fvec format stores vectors as: [dimension:int][value1:float]...[valueN:float]. + * + * @param filePath the path to the fvec file + * @return a list of float vectors + * @throws IOException if an I/O error occurs + */ public static List> readFvecs(String filePath) throws IOException { var vectors = new ArrayList>(); try (var dis = new DataInputStream(new BufferedInputStream(new FileInputStream(filePath)))) { @@ -53,6 +72,14 @@ public static List> readFvecs(String filePath) throws IOException return vectors; } + /** + * Reads integer vectors from an ivec file. + * The ivec format stores vectors as: [dimension:int][value1:int]...[valueN:int]. + * Typically used for ground truth neighbor lists. 
+ * + * @param filename the path to the ivec file + * @return a list of integer vectors (each vector is a list of integers) + */ public static List> readIvecs(String filename) { var groundTruthTopK = new ArrayList>(); diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/UpdatableRandomAccessVectorValues.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/UpdatableRandomAccessVectorValues.java index 0bbb17695..d74862e81 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/UpdatableRandomAccessVectorValues.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/UpdatableRandomAccessVectorValues.java @@ -22,15 +22,37 @@ import java.util.ArrayList; import java.util.List; +/** + * A mutable implementation of {@link RandomAccessVectorValues} that allows vectors to be + * added dynamically. This implementation stores vectors in memory using an {@link ArrayList} + * and is suitable for scenarios where the vector collection needs to grow over time. + * + *

This class is thread-safe for read operations but not for concurrent modifications. + * Multiple threads can safely read vectors using {@link #getVector(int)}, but adding vectors + * via {@link #add(VectorFloat)} should be externally synchronized if concurrent access is needed.

+ */ public class UpdatableRandomAccessVectorValues implements RandomAccessVectorValues { private final List> data; private final int dimensions; + /** + * Creates a new updatable vector collection with the specified dimensionality. + * Initializes the internal storage with a capacity of 1024 vectors. + * + * @param dimensions the dimensionality of vectors that will be stored + */ public UpdatableRandomAccessVectorValues(int dimensions) { this.data = new ArrayList<>(1024); this.dimensions = dimensions; } + /** + * Adds a vector to this collection. The vector must have the same dimensionality + * as specified in the constructor. + * + * @param vector the vector to add to this collection + * @throws IllegalArgumentException if the vector's dimension does not match the expected dimension + */ public void add(VectorFloat vector) { data.add(vector); } diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/CommonParameters.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/CommonParameters.java index d89642059..c25c87b97 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/CommonParameters.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/CommonParameters.java @@ -23,9 +23,22 @@ import java.util.function.Function; import java.util.stream.Collectors; +/** + * Common parameters shared across benchmark configurations. + */ public class CommonParameters { + /** List of compression configurations. */ public List compression; + /** + * Constructs a CommonParameters. + */ + public CommonParameters() {} + + /** + * Gets the compressor parameters as functions. 
+ * @return the list of compressor parameter functions + */ public List> getCompressorParameters() { return compression.stream().map(Compression::getCompressorParameters).collect(Collectors.toList()); } diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/Compression.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/Compression.java index fe7d4d82d..f39dcf772 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/Compression.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/Compression.java @@ -23,10 +23,24 @@ import java.util.Map; import java.util.function.Function; +/** + * Compression configuration for benchmarks. + */ public class Compression { + /** The compression type. */ public String type; + /** The compression parameters. */ public Map parameters; + /** + * Constructs a Compression. + */ + public Compression() {} + + /** + * Gets the compressor parameters as a function. + * @return the compressor parameters function + */ public Function getCompressorParameters() { switch (type) { case "None": diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/ConstructionParameters.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/ConstructionParameters.java index 23fd75e03..6bceb94e7 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/ConstructionParameters.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/ConstructionParameters.java @@ -22,16 +22,34 @@ import java.util.List; import java.util.stream.Collectors; - +/** + * Construction parameters for graph index building. + */ public class ConstructionParameters extends CommonParameters { + /** List of out-degree values. */ public List outDegree; + /** List of efConstruction values. */ public List efConstruction; + /** List of neighbor overflow values. 
*/ public List neighborOverflow; + /** List of add hierarchy flags. */ public List addHierarchy; + /** List of refine final graph flags. */ public List refineFinalGraph; + /** List of reranking strategies. */ public List reranking; + /** Flag to use saved index if exists. */ public Boolean useSavedIndexIfExists; + /** + * Constructs a ConstructionParameters. + */ + public ConstructionParameters() {} + + /** + * Gets the feature sets based on reranking strategies. + * @return the list of feature sets + */ public List> getFeatureSets() { return reranking.stream().map(item -> { switch (item) { diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/DatasetCollection.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/DatasetCollection.java index 59ac2e4ee..82a4325bf 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/DatasetCollection.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/DatasetCollection.java @@ -25,25 +25,44 @@ import java.util.List; import java.util.Map; +/** + * Collection of dataset names loaded from YAML configuration. + */ public class DatasetCollection { private static final String defaultFile = "./jvector-examples/yaml-configs/datasets.yml"; + /** Map of dataset categories to dataset names. */ public final Map> datasetNames; private DatasetCollection(Map> datasetNames) { this.datasetNames = datasetNames; } + /** + * Loads dataset collection from default file. + * @return the dataset collection + * @throws IOException if an error occurs + */ public static DatasetCollection load() throws IOException { return load(defaultFile); } + /** + * Loads dataset collection from specified file. 
+ * @param file the file path + * @return the dataset collection + * @throws IOException if an error occurs + */ public static DatasetCollection load(String file) throws IOException { InputStream inputStream = new FileInputStream(file); Yaml yaml = new Yaml(); return new DatasetCollection(yaml.load(inputStream)); } + /** + * Gets all dataset names from all categories. + * @return the list of all dataset names + */ public List getAll() { List allDatasetNames = new ArrayList<>(); for (var key : datasetNames.keySet()) { diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/MultiConfig.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/MultiConfig.java index ec4ad7992..b042ddddd 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/MultiConfig.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/MultiConfig.java @@ -24,14 +24,38 @@ import java.io.FileNotFoundException; import java.io.InputStream; +/** + * Configuration container for benchmark runs, loaded from YAML files. + * Includes dataset selection, construction parameters, and search parameters. + */ public class MultiConfig { + /** Default directory for YAML configuration files. */ private static final String defaultDirectory = "./jvector-examples/yaml-configs/"; + /** The version of the configuration format. */ private int version; + /** The name of the dataset to use. */ public String dataset; + /** Parameters for graph construction. */ public ConstructionParameters construction; + /** Parameters for search operations. */ public SearchParameters search; + /** + * Constructs an empty MultiConfig. + */ + public MultiConfig() { + } + + /** + * Loads the default configuration for the specified dataset. + * If a dataset-specific config file exists, it is used; otherwise the default.yml is loaded + * and the dataset name is set. 
+ * + * @param datasetName the name of the dataset + * @return the loaded configuration + * @throws FileNotFoundException if neither dataset-specific nor default config is found + */ public static MultiConfig getDefaultConfig(String datasetName) throws FileNotFoundException { var name = defaultDirectory + datasetName; if (!name.endsWith(".yml")) { @@ -50,11 +74,25 @@ public static MultiConfig getDefaultConfig(String datasetName) throws FileNotFou return config; } + /** + * Loads a configuration from the specified file name. + * + * @param configName the path to the configuration file + * @return the loaded configuration + * @throws FileNotFoundException if the configuration file is not found + */ public static MultiConfig getConfig(String configName) throws FileNotFoundException { File configFile = new File(configName); return getConfig(configFile); } + /** + * Loads a configuration from the specified file. + * + * @param configFile the configuration file to load + * @return the loaded configuration + * @throws FileNotFoundException if the configuration file is not found + */ public static MultiConfig getConfig(File configFile) throws FileNotFoundException { if (!configFile.exists()) { throw new FileNotFoundException(configFile.getAbsolutePath()); @@ -64,10 +102,22 @@ public static MultiConfig getConfig(File configFile) throws FileNotFoundExceptio return yaml.loadAs(inputStream, MultiConfig.class); } + /** + * Returns the configuration format version. + * + * @return the version number + */ public int getVersion() { return version; } + /** + * Sets the configuration format version. + * The version must match the current OnDiskGraphIndex version. 
+ * + * @param version the version number to set + * @throws IllegalArgumentException if the version does not match CURRENT_VERSION + */ public void setVersion(int version) { if (version != OnDiskGraphIndex.CURRENT_VERSION) { throw new IllegalArgumentException("Invalid version: " + version); diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/SearchParameters.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/SearchParameters.java index 9a027076e..6d99f1e1f 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/SearchParameters.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/SearchParameters.java @@ -19,8 +19,22 @@ import java.util.List; import java.util.Map; +/** + * Configuration parameters for search operations in benchmarks. + * Extends CommonParameters to include search-specific settings like topK, overquery ratios, + * pruning options, and benchmark configurations. + */ public class SearchParameters extends CommonParameters { + /** Map of topK values to lists of overquery ratios. */ public Map> topKOverquery; + /** List of boolean flags indicating whether to use search pruning. */ public List useSearchPruning; + /** Map of benchmark names to their configuration options. */ public Map> benchmarks; + + /** + * Constructs an empty SearchParameters instance. + */ + public SearchParameters() { + } } \ No newline at end of file diff --git a/jvector-native/pom.xml b/jvector-native/pom.xml index daf84fe6a..24763d4e3 100644 --- a/jvector-native/pom.xml +++ b/jvector-native/pom.xml @@ -49,6 +49,16 @@
+ + org.apache.maven.plugins + maven-javadoc-plugin + + + --add-modules=jdk.incubator.vector + + 22 + + diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/disk/MemorySegmentReader.java b/jvector-native/src/main/java/io/github/jbellis/jvector/disk/MemorySegmentReader.java index c9ac39c21..d13028133 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/disk/MemorySegmentReader.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/disk/MemorySegmentReader.java @@ -145,10 +145,66 @@ public void close() { // Individual readers don't close the shared memory } + /** + * Factory for creating multiple {@link MemorySegmentReader} instances that share the same + * memory-mapped file. + *

+ * This supplier implementation allows multiple readers to be created for concurrent access + * to the same underlying memory-mapped file, which is particularly useful for multi-threaded + * vector search operations. All readers share a single memory-mapped region, reducing memory + * overhead and improving cache efficiency. + *

+ * The supplier manages the lifecycle of the shared {@link Arena} and {@link MemorySegment}, + * ensuring that the memory mapping remains valid while any reader might be using it. When + * {@link #close()} is called, the arena is closed and all associated memory mappings are + * released. + *

+ * Performance optimizations: + *

    + *
+ * <ul>
+ *   <li>Uses {@code posix_madvise} with {@code MADV_RANDOM} advice to optimize for random
+ *       access patterns typical of vector similarity searches</li>
+ *   <li>Creates a shared arena that allows the memory segment to be accessed from multiple
+ *       threads safely</li>
+ *   <li>Avoids the 2GB file size limitation of {@code ByteBuffer}-based implementations</li>
+ * </ul>
+ * + * @see MemorySegmentReader + * @see Arena#ofShared() + */ public static class Supplier implements ReaderSupplier { private final Arena arena; private final MemorySegment memory; + /** + * Creates a new supplier that memory-maps the specified file for random access. + *

+ * This constructor performs the following operations: + *

    + *
+ * <ol>
+ *   <li>Creates a shared {@link Arena} for managing the memory segment lifecycle</li>
+ *   <li>Opens the file as a {@link FileChannel} in read-only mode</li>
+ *   <li>Memory-maps the entire file using {@link FileChannel#map(MapMode, long, long, Arena)}</li>
+ *   <li>Applies {@code MADV_RANDOM} advice via {@code posix_madvise} to hint that the
+ *       file will be accessed in random order, which is typical for vector similarity
+ *       search workloads. This advice tells the OS not to perform aggressive read-ahead.</li>
+ * </ol>
+ * The {@code posix_madvise} call is performed using the Foreign Function & Memory API + * to directly invoke the native function. If {@code posix_madvise} is not available + * (e.g., on non-POSIX systems), a warning is logged but construction proceeds normally. + *

+ * Thread safety: The created supplier is thread-safe and can be used + * to create readers from multiple threads. Each call to {@link #get()} returns a new + * reader instance, but all readers share the same underlying memory mapping. + *

+ * Resource management: The caller is responsible for calling + * {@link #close()} when done with the supplier. Failing to close the supplier will + * prevent the memory mapping from being released, potentially causing resource leaks. + * + * @param path the path to the file to be memory-mapped; must be readable and should + * contain vector data in the expected format + * @throws IOException if an I/O error occurs opening the file, mapping the memory, + * or if {@code posix_madvise} returns a non-zero error code + * @throws RuntimeException if an unexpected error occurs during native method invocation + */ public Supplier(Path path) throws IOException { this.arena = Arena.ofShared(); try (var ch = FileChannel.open(path, StandardOpenOption.READ)) { diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProvider.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProvider.java index a098ced7a..5868f1985 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProvider.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProvider.java @@ -26,10 +26,46 @@ import java.nio.Buffer; /** - * VectorTypeSupport using MemorySegments. + * Implementation of {@link VectorTypeSupport} that uses off-heap memory segments for vector storage. + *

+ * This provider leverages Java's Foreign Function & Memory API (introduced in Java 19 and + * finalized in Java 22) to store vector data in native memory segments rather than on-heap arrays. + * This approach provides several advantages for high-performance vector operations: + *

    + *
+ * <ul>
+ *   <li>Reduced GC pressure: Vector data is stored off-heap, minimizing
+ *       garbage collection overhead for large vector datasets</li>
+ *   <li>Direct memory access: Enables efficient interoperability with native
+ *       SIMD implementations through {@link java.lang.foreign.MemorySegment} pointers</li>
+ *   <li>Memory-mapped I/O: Supports zero-copy access to memory-mapped vector
+ *       files via {@link io.github.jbellis.jvector.disk.MemorySegmentReader}</li>
+ *   <li>Deterministic cleanup: Memory segments are explicitly managed through
+ *       the arena API, providing predictable memory lifecycle</li>
+ * </ul>
+ * This provider is typically used in conjunction with {@link NativeVectorizationProvider} to + * enable native SIMD acceleration for vector similarity computations. It can handle both + * float vectors ({@link VectorFloat}) and byte sequences ({@link ByteSequence}) backed by + * native memory. + *

+ * Thread safety: This provider is thread-safe for creating new vectors. + * However, individual vector instances are not thread-safe unless they are backed by + * memory segments from a shared arena. + * + * @see MemorySegmentVectorFloat + * @see MemorySegmentByteSequence + * @see VectorTypeSupport */ public class MemorySegmentVectorProvider implements VectorTypeSupport { + /** + * Constructs a new MemorySegmentVectorProvider with default settings. + *

+ * This provider creates vectors backed by off-heap memory segments managed by + * the Foreign Function & Memory API. The provider itself is stateless and + * lightweight - memory management is handled at the individual vector level + * through their associated arenas. + */ + public MemorySegmentVectorProvider() { + } @Override public VectorFloat createFloatVector(Object data) { diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorizationProvider.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorizationProvider.java index 3c876e517..c29591ac8 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorizationProvider.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorizationProvider.java @@ -30,6 +30,9 @@ public class NativeVectorizationProvider extends VectorizationProvider { private final VectorUtilSupport vectorUtilSupport; private final VectorTypeSupport vectorTypeSupport; + /** + * Constructs a NativeVectorizationProvider. + */ public NativeVectorizationProvider() { var libraryLoaded = LibraryLoader.loadJvector(); if (!libraryLoaded) { diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/LibraryLoader.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/LibraryLoader.java index 3eb95f455..38bcb8542 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/LibraryLoader.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/LibraryLoader.java @@ -20,11 +20,59 @@ import java.nio.file.Files; /** - * This class is used to load supporting native libraries. First, it tries to load the library from the system path. - * If that fails, it tries to load the library from the classpath (using the usual copying to a tmp directory route). 
+ * Utility class for loading the JVector native library, which provides SIMD-accelerated + * vector operations through native code implementations. + *

+ * This class implements a fallback loading strategy to maximize compatibility across + * different deployment environments: + *

    + *
+ * <ol>
+ *   <li>First attempts to load the library from the system library path using
+ *       {@link System#loadLibrary(String)}, which checks standard locations like
+ *       {@code java.library.path}</li>
+ *   <li>If that fails, attempts to load the library from the classpath by:
+ *     <ul>
+ *       <li>Extracting the platform-specific library (e.g., {@code libjvector.so},
+ *           {@code libjvector.dylib}, or {@code jvector.dll}) from JAR resources</li>
+ *       <li>Copying it to a temporary directory</li>
+ *       <li>Loading it using {@link System#load(String)} with the absolute path</li>
+ *     </ul>
+ *   </li>
+ * </ol>
+ * This dual-strategy approach allows the native library to be bundled within the JAR + * for ease of distribution while still supporting system-installed libraries for + * production deployments. + *

+ * The class uses a private constructor to prevent instantiation, as it only provides + * static utility methods. + * + * @see System#loadLibrary(String) + * @see System#load(String) */ public class LibraryLoader { private LibraryLoader() {} + + /** + * Attempts to load the JVector native library using a fallback strategy. + *

+ * The method first tries to load {@code libjvector} from the system library path. + * If that fails (typically when the library is not installed system-wide), it attempts + * to extract and load the library from the classpath. + *

+ * The classpath loading process: + *

    + *
+ * <ol>
+ *   <li>Maps the library name to the platform-specific filename using
+ *       {@link System#mapLibraryName(String)}</li>
+ *   <li>Creates a temporary file with the appropriate extension</li>
+ *   <li>Extracts the library resource from the JAR to the temporary file</li>
+ *   <li>Loads the library from the temporary file's absolute path</li>
+ * </ol>
+ * Any errors during the loading process are silently caught, making this method + * suitable for optional native library loading where fallback implementations exist. + * + * @return {@code true} if the library was successfully loaded from either the system + * path or the classpath; {@code false} if both loading strategies failed or + * the library resource could not be found in the classpath + */ public static boolean loadJvector() { try { System.loadLibrary("jvector"); diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java index 2005d0d5f..99cd147e6 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java @@ -12,6 +12,9 @@ import static java.lang.foreign.ValueLayout.*; import static java.lang.foreign.MemoryLayout.PathElement.*; +/** + * Native SIMD operations for vector similarity computations. + */ public class NativeSimdOps { NativeSimdOps() { @@ -58,21 +61,31 @@ static MemoryLayout align(MemoryLayout layout, long align) { static final SymbolLookup SYMBOL_LOOKUP = SymbolLookup.loaderLookup() .or(Linker.nativeLinker().defaultLookup()); + /** C boolean type layout. */ public static final ValueLayout.OfBoolean C_BOOL = ValueLayout.JAVA_BOOLEAN; + /** C char type layout. */ public static final ValueLayout.OfByte C_CHAR = ValueLayout.JAVA_BYTE; + /** C short type layout. */ public static final ValueLayout.OfShort C_SHORT = ValueLayout.JAVA_SHORT; + /** C int type layout. */ public static final ValueLayout.OfInt C_INT = ValueLayout.JAVA_INT; + /** C long long type layout. */ public static final ValueLayout.OfLong C_LONG_LONG = ValueLayout.JAVA_LONG; + /** C float type layout. */ public static final ValueLayout.OfFloat C_FLOAT = ValueLayout.JAVA_FLOAT; + /** C double type layout. 
*/ public static final ValueLayout.OfDouble C_DOUBLE = ValueLayout.JAVA_DOUBLE; + /** C pointer type layout. */ public static final AddressLayout C_POINTER = ValueLayout.ADDRESS .withTargetLayout(MemoryLayout.sequenceLayout(java.lang.Long.MAX_VALUE, JAVA_BYTE)); + /** C long type layout. */ public static final ValueLayout.OfLong C_LONG = ValueLayout.JAVA_LONG; private static final int true_ = (int)1L; /** * {@snippet lang=c : * #define true 1 * } + * @return the value 1 */ public static int true_() { return true_; @@ -82,6 +95,7 @@ public static int true_() { * {@snippet lang=c : * #define false 0 * } + * @return the value 0 */ public static int false_() { return false_; @@ -91,6 +105,7 @@ public static int false_() { * {@snippet lang=c : * #define __bool_true_false_are_defined 1 * } + * @return the value 1 */ public static int __bool_true_false_are_defined() { return __bool_true_false_are_defined; @@ -110,6 +125,7 @@ private static class check_compatibility { * {@snippet lang=c : * _Bool check_compatibility() * } + * @return the function descriptor */ public static FunctionDescriptor check_compatibility$descriptor() { return check_compatibility.DESC; @@ -120,6 +136,7 @@ private static class check_compatibility { * {@snippet lang=c : * _Bool check_compatibility() * } + * @return the method handle */ public static MethodHandle check_compatibility$handle() { return check_compatibility.HANDLE; @@ -130,15 +147,32 @@ private static class check_compatibility { * {@snippet lang=c : * _Bool check_compatibility() * } + * @return the memory segment address */ public static MemorySegment check_compatibility$address() { return check_compatibility.ADDR; } /** + * Checks whether the native SIMD library is compatible with the current CPU architecture. + *

+ * This method verifies that the underlying native library was compiled with SIMD instructions + * that are supported by the current processor. It should be called before using any other + * native SIMD operations to ensure the library will function correctly. + *

+ * The compatibility check typically verifies CPU features such as: + *

    + *
+ * <ul>
+ *   <li>AVX2 support for 256-bit vector operations</li>
+ *   <li>AVX-512 support for 512-bit vector operations (if the library was compiled with AVX-512)</li>
+ *   <li>FMA (Fused Multiply-Add) instruction support</li>
+ * </ul>
* {@snippet lang=c : * _Bool check_compatibility() * } + * + * @return {@code true} if the native library is compatible with the current CPU and can be + * safely used; {@code false} if the CPU lacks required SIMD instruction sets and + * native operations should not be used */ public static boolean check_compatibility() { var mh$ = check_compatibility.HANDLE; @@ -173,6 +207,7 @@ private static class dot_product_f32 { * {@snippet lang=c : * float dot_product_f32(int preferred_size, const float *a, int aoffset, const float *b, int boffset, int length) * } + * @return the function descriptor */ public static FunctionDescriptor dot_product_f32$descriptor() { return dot_product_f32.DESC; @@ -183,6 +218,7 @@ private static class dot_product_f32 { * {@snippet lang=c : * float dot_product_f32(int preferred_size, const float *a, int aoffset, const float *b, int boffset, int length) * } + * @return the method handle */ public static MethodHandle dot_product_f32$handle() { return dot_product_f32.HANDLE; @@ -193,15 +229,36 @@ private static class dot_product_f32 { * {@snippet lang=c : * float dot_product_f32(int preferred_size, const float *a, int aoffset, const float *b, int boffset, int length) * } + * @return the memory segment address */ public static MemorySegment dot_product_f32$address() { return dot_product_f32.ADDR; } /** + * Computes the dot product of two float32 vectors using native SIMD instructions. + *

+ * The dot product is calculated as: sum(a[i] * b[i]) for i in [0, length). + * This native implementation uses hardware SIMD instructions (AVX2/AVX-512) for + * significantly better performance than pure Java implementations, especially for + * large vectors. + *

+ * The vectors are accessed from their respective memory segments starting at the specified + * offsets. Both vectors must have at least {@code length} elements available from their + * starting offsets. * {@snippet lang=c : * float dot_product_f32(int preferred_size, const float *a, int aoffset, const float *b, int boffset, int length) * } + * + * @param preferred_size the preferred SIMD vector width in bits (e.g., 256 for AVX2, + * 512 for AVX-512); this hint allows the native code to select + * the optimal SIMD instruction set for the hardware + * @param a memory segment containing the first vector's float32 data + * @param aoffset starting offset in the first vector (element index, not byte offset) + * @param b memory segment containing the second vector's float32 data + * @param boffset starting offset in the second vector (element index, not byte offset) + * @param length number of elements to include in the dot product computation + * @return the dot product of the two vectors as a float32 value */ public static float dot_product_f32(int preferred_size, MemorySegment a, int aoffset, MemorySegment b, int boffset, int length) { var mh$ = dot_product_f32.HANDLE; @@ -236,6 +293,7 @@ private static class euclidean_f32 { * {@snippet lang=c : * float euclidean_f32(int preferred_size, const float *a, int aoffset, const float *b, int boffset, int length) * } + * @return the function descriptor */ public static FunctionDescriptor euclidean_f32$descriptor() { return euclidean_f32.DESC; @@ -246,6 +304,7 @@ private static class euclidean_f32 { * {@snippet lang=c : * float euclidean_f32(int preferred_size, const float *a, int aoffset, const float *b, int boffset, int length) * } + * @return the method handle */ public static MethodHandle euclidean_f32$handle() { return euclidean_f32.HANDLE; @@ -256,15 +315,36 @@ private static class euclidean_f32 { * {@snippet lang=c : * float euclidean_f32(int preferred_size, const float *a, int aoffset, const float *b, int boffset, int 
length) * } + * @return the memory segment address */ public static MemorySegment euclidean_f32$address() { return euclidean_f32.ADDR; } /** + * Computes the squared Euclidean distance between two float32 vectors using native SIMD instructions. + *

+ * The squared Euclidean distance is calculated as: sum((a[i] - b[i])^2) for i in [0, length). + * This method returns the squared distance rather than the actual Euclidean distance to avoid + * the computational cost of the square root operation. For distance comparisons and nearest + * neighbor searches, the squared distance provides the same ordering as the actual distance. + *

+ * This native implementation uses hardware SIMD instructions (AVX2/AVX-512) to compute + * multiple differences and squared differences in parallel, providing significant performance + * improvements over pure Java implementations. * {@snippet lang=c : * float euclidean_f32(int preferred_size, const float *a, int aoffset, const float *b, int boffset, int length) * } + * + * @param preferred_size the preferred SIMD vector width in bits (e.g., 256 for AVX2, + * 512 for AVX-512); this hint allows the native code to select + * the optimal SIMD instruction set for the hardware + * @param a memory segment containing the first vector's float32 data + * @param aoffset starting offset in the first vector (element index, not byte offset) + * @param b memory segment containing the second vector's float32 data + * @param boffset starting offset in the second vector (element index, not byte offset) + * @param length number of elements to include in the distance computation + * @return the squared Euclidean distance between the two vectors as a float32 value */ public static float euclidean_f32(int preferred_size, MemorySegment a, int aoffset, MemorySegment b, int boffset, int length) { var mh$ = euclidean_f32.HANDLE; @@ -298,6 +378,7 @@ private static class bulk_quantized_shuffle_dot_f32_512 { * {@snippet lang=c : * void bulk_quantized_shuffle_dot_f32_512(const unsigned char *shuffles, int codebookCount, const char *quantizedPartials, float delta, float minDistance, float *results) * } + * @return the function descriptor */ public static FunctionDescriptor bulk_quantized_shuffle_dot_f32_512$descriptor() { return bulk_quantized_shuffle_dot_f32_512.DESC; @@ -308,6 +389,7 @@ private static class bulk_quantized_shuffle_dot_f32_512 { * {@snippet lang=c : * void bulk_quantized_shuffle_dot_f32_512(const unsigned char *shuffles, int codebookCount, const char *quantizedPartials, float delta, float minDistance, float *results) * } + * @return the method handle */ public static 
MethodHandle bulk_quantized_shuffle_dot_f32_512$handle() { return bulk_quantized_shuffle_dot_f32_512.HANDLE; @@ -318,15 +400,47 @@ private static class bulk_quantized_shuffle_dot_f32_512 { * {@snippet lang=c : * void bulk_quantized_shuffle_dot_f32_512(const unsigned char *shuffles, int codebookCount, const char *quantizedPartials, float delta, float minDistance, float *results) * } + * @return the memory segment address */ public static MemorySegment bulk_quantized_shuffle_dot_f32_512$address() { return bulk_quantized_shuffle_dot_f32_512.ADDR; } /** + * Performs bulk similarity scoring for Product Quantization (PQ) compressed vectors using + * dot product similarity with native AVX-512 SIMD instructions. + *

+ * This method is optimized for batch scoring of multiple quantized vectors against a query. + * It processes vectors that have been compressed using Product Quantization, where each + * vector is represented as a sequence of codebook indices (shuffles) that reference + * pre-computed partial dot products. + *

+ * The similarity score for each vector is reconstructed by: + *

    + * <ol>
+ *   <li>Using the shuffle indices to gather the corresponding quantized partial values</li>
+ *   <li>Dequantizing these values using: {@code partialValue = quantizedValue * delta + minDistance}</li>
+ *   <li>Summing the dequantized partials to produce the final dot product score</li>
+ * </ol>
+ * The AVX-512 implementation processes multiple vectors in parallel using 512-bit wide + * vector registers for maximum throughput. * {@snippet lang=c : * void bulk_quantized_shuffle_dot_f32_512(const unsigned char *shuffles, int codebookCount, const char *quantizedPartials, float delta, float minDistance, float *results) * } + * + * @param shuffles memory segment containing the codebook indices for each vector; organized + * as a flat array where each group of {@code codebookCount} consecutive bytes + * represents one compressed vector + * @param codebookCount number of codebooks (subquantizers) used in the Product Quantization; + * also the number of bytes per compressed vector + * @param quantizedPartials memory segment containing the quantized partial dot product values; + * these are the pre-computed dot products between the query and each + * codebook centroid, stored in quantized (int8) form + * @param delta the dequantization scale factor; used to convert quantized int8 values back + * to float32: {@code floatValue = int8Value * delta + minDistance} + * @param minDistance the dequantization offset (minimum value); added after scaling during + * dequantization + * @param results output memory segment where the computed similarity scores will be written; + * must have sufficient capacity for all result values */ public static void bulk_quantized_shuffle_dot_f32_512(MemorySegment shuffles, int codebookCount, MemorySegment quantizedPartials, float delta, float minDistance, MemorySegment results) { var mh$ = bulk_quantized_shuffle_dot_f32_512.HANDLE; @@ -360,6 +474,7 @@ private static class bulk_quantized_shuffle_euclidean_f32_512 { * {@snippet lang=c : * void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char *shuffles, int codebookCount, const char *quantizedPartials, float delta, float minDistance, float *results) * } + * @return the function descriptor */ public static FunctionDescriptor bulk_quantized_shuffle_euclidean_f32_512$descriptor() { 
return bulk_quantized_shuffle_euclidean_f32_512.DESC; @@ -370,6 +485,7 @@ private static class bulk_quantized_shuffle_euclidean_f32_512 { * {@snippet lang=c : * void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char *shuffles, int codebookCount, const char *quantizedPartials, float delta, float minDistance, float *results) * } + * @return the method handle */ public static MethodHandle bulk_quantized_shuffle_euclidean_f32_512$handle() { return bulk_quantized_shuffle_euclidean_f32_512.HANDLE; @@ -380,15 +496,47 @@ private static class bulk_quantized_shuffle_euclidean_f32_512 { * {@snippet lang=c : * void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char *shuffles, int codebookCount, const char *quantizedPartials, float delta, float minDistance, float *results) * } + * @return the memory segment address */ public static MemorySegment bulk_quantized_shuffle_euclidean_f32_512$address() { return bulk_quantized_shuffle_euclidean_f32_512.ADDR; } /** + * Performs bulk distance scoring for Product Quantization (PQ) compressed vectors using + * squared Euclidean distance with native AVX-512 SIMD instructions. + *

+ * This method is optimized for batch distance computation of multiple quantized vectors + * against a query. It processes vectors that have been compressed using Product Quantization, + * where each vector is represented as a sequence of codebook indices (shuffles) that reference + * pre-computed partial squared distances. + *

+ * The squared Euclidean distance for each vector is reconstructed by: + *

    + * <ol>
+ *   <li>Using the shuffle indices to gather the corresponding quantized partial values</li>
+ *   <li>Dequantizing these values using: {@code partialValue = quantizedValue * delta + minDistance}</li>
+ *   <li>Summing the dequantized partials to produce the final squared distance</li>
+ * </ol>
+ * The AVX-512 implementation processes multiple vectors in parallel using 512-bit wide + * vector registers for maximum throughput. * {@snippet lang=c : * void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char *shuffles, int codebookCount, const char *quantizedPartials, float delta, float minDistance, float *results) * } + * + * @param shuffles memory segment containing the codebook indices for each vector; organized + * as a flat array where each group of {@code codebookCount} consecutive bytes + * represents one compressed vector + * @param codebookCount number of codebooks (subquantizers) used in the Product Quantization; + * also the number of bytes per compressed vector + * @param quantizedPartials memory segment containing the quantized partial squared distance values; + * these are the pre-computed squared distances between the query and each + * codebook centroid, stored in quantized (int8) form + * @param delta the dequantization scale factor; used to convert quantized int8 values back + * to float32: {@code floatValue = int8Value * delta + minDistance} + * @param minDistance the dequantization offset (minimum value); added after scaling during + * dequantization + * @param results output memory segment where the computed distance scores will be written; + * must have sufficient capacity for all result values */ public static void bulk_quantized_shuffle_euclidean_f32_512(MemorySegment shuffles, int codebookCount, MemorySegment quantizedPartials, float delta, float minDistance, MemorySegment results) { var mh$ = bulk_quantized_shuffle_euclidean_f32_512.HANDLE; @@ -426,6 +574,7 @@ private static class bulk_quantized_shuffle_cosine_f32_512 { * {@snippet lang=c : * void bulk_quantized_shuffle_cosine_f32_512(const unsigned char *shuffles, int codebookCount, const char *quantizedPartialSums, float sumDelta, float minDistance, const char *quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float 
*results) * } + * @return the function descriptor */ public static FunctionDescriptor bulk_quantized_shuffle_cosine_f32_512$descriptor() { return bulk_quantized_shuffle_cosine_f32_512.DESC; @@ -436,6 +585,7 @@ private static class bulk_quantized_shuffle_cosine_f32_512 { * {@snippet lang=c : * void bulk_quantized_shuffle_cosine_f32_512(const unsigned char *shuffles, int codebookCount, const char *quantizedPartialSums, float sumDelta, float minDistance, const char *quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float *results) * } + * @return the method handle */ public static MethodHandle bulk_quantized_shuffle_cosine_f32_512$handle() { return bulk_quantized_shuffle_cosine_f32_512.HANDLE; @@ -446,15 +596,56 @@ private static class bulk_quantized_shuffle_cosine_f32_512 { * {@snippet lang=c : * void bulk_quantized_shuffle_cosine_f32_512(const unsigned char *shuffles, int codebookCount, const char *quantizedPartialSums, float sumDelta, float minDistance, const char *quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float *results) * } + * @return the memory segment address */ public static MemorySegment bulk_quantized_shuffle_cosine_f32_512$address() { return bulk_quantized_shuffle_cosine_f32_512.ADDR; } /** + * Performs bulk similarity scoring for Product Quantization (PQ) compressed vectors using + * cosine similarity with native AVX-512 SIMD instructions. + *

+ * This method is optimized for batch cosine similarity computation of multiple quantized + * vectors against a query. Cosine similarity requires both dot products and vector magnitudes, + * so this method processes two separate quantized components for each vector. + *

+ * The cosine similarity for each vector is computed as: + *

+     * cosine = dotProduct / (queryMagnitude * vectorMagnitude)
+     * 
+ * Since we work with squared magnitudes, the actual computation is: + *
+     * cosine = dotProduct / sqrt(queryMagnitudeSquared * vectorMagnitudeSquared)
+     * 
+ * Both the dot products and magnitudes are reconstructed from quantized partial values + * using shuffle-based lookup and dequantization. The AVX-512 implementation processes + * multiple vectors in parallel using 512-bit wide vector registers. * {@snippet lang=c : * void bulk_quantized_shuffle_cosine_f32_512(const unsigned char *shuffles, int codebookCount, const char *quantizedPartialSums, float sumDelta, float minDistance, const char *quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float *results) * } + * + * @param shuffles memory segment containing the codebook indices for each vector; organized + * as a flat array where each group of {@code codebookCount} consecutive bytes + * represents one compressed vector + * @param codebookCount number of codebooks (subquantizers) used in the Product Quantization; + * also the number of bytes per compressed vector + * @param quantizedPartialSums memory segment containing the quantized partial dot product values; + * these are pre-computed dot products between the query and each + * codebook centroid, stored in quantized (int8) form + * @param sumDelta the dequantization scale factor for dot products; used to convert quantized + * int8 sum values back to float32 + * @param minDistance the dequantization offset for dot products; added after scaling + * @param quantizedPartialMagnitudes memory segment containing the quantized partial magnitude values; + * these are pre-computed squared magnitudes of each codebook + * centroid, stored in quantized (int8) form + * @param magnitudeDelta the dequantization scale factor for magnitudes; used to convert + * quantized int8 magnitude values back to float32 + * @param minMagnitude the dequantization offset for magnitudes; added after scaling + * @param queryMagnitudeSquared the squared magnitude of the query vector; used in the + * denominator of the cosine similarity formula + * @param results output memory segment where the 
computed cosine similarity scores will be + * written; must have sufficient capacity for all result values */ public static void bulk_quantized_shuffle_cosine_f32_512(MemorySegment shuffles, int codebookCount, MemorySegment quantizedPartialSums, float sumDelta, float minDistance, MemorySegment quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, MemorySegment results) { var mh$ = bulk_quantized_shuffle_cosine_f32_512.HANDLE; @@ -488,6 +679,7 @@ private static class assemble_and_sum_f32_512 { * {@snippet lang=c : * float assemble_and_sum_f32_512(const float *data, int dataBase, const unsigned char *baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) * } + * @return the function descriptor */ public static FunctionDescriptor assemble_and_sum_f32_512$descriptor() { return assemble_and_sum_f32_512.DESC; @@ -498,6 +690,7 @@ private static class assemble_and_sum_f32_512 { * {@snippet lang=c : * float assemble_and_sum_f32_512(const float *data, int dataBase, const unsigned char *baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) * } + * @return the method handle */ public static MethodHandle assemble_and_sum_f32_512$handle() { return assemble_and_sum_f32_512.HANDLE; @@ -508,15 +701,43 @@ private static class assemble_and_sum_f32_512 { * {@snippet lang=c : * float assemble_and_sum_f32_512(const float *data, int dataBase, const unsigned char *baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) * } + * @return the memory segment address */ public static MemorySegment assemble_and_sum_f32_512$address() { return assemble_and_sum_f32_512.ADDR; } /** + * Assembles vector elements from multiple locations using offset-based gathering and computes + * their sum using native AVX-512 SIMD instructions. + *

+ * This method performs a gather-and-sum operation where elements are collected from a data + * array using a sequence of byte offsets, then summed together. This is commonly used in + * Product Quantization to reconstruct and sum partial values from codebook centroids. + *

+ * The operation computes: + *

+     * sum = 0
+     * for i in [0, baseOffsetsLength):
+     *     offset = baseOffsets[baseOffsetsOffset + i]
+     *     sum += data[dataBase + offset]
+     * 
+ * The AVX-512 implementation uses gather instructions to collect multiple elements in parallel + * and SIMD horizontal sum operations for efficient summation. * {@snippet lang=c : * float assemble_and_sum_f32_512(const float *data, int dataBase, const unsigned char *baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) * } + * + * @param data memory segment containing the source float32 data array from which elements + * will be gathered + * @param dataBase base index into the data array; offsets are added to this base to compute + * the actual element indices + * @param baseOffsets memory segment containing unsigned byte offsets; each offset specifies + * a relative position from {@code dataBase} + * @param baseOffsetsOffset starting position in the baseOffsets array + * @param baseOffsetsLength number of offsets to process; determines how many elements will + * be gathered and summed + * @return the sum of all gathered float32 values */ public static float assemble_and_sum_f32_512(MemorySegment data, int dataBase, MemorySegment baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) { var mh$ = assemble_and_sum_f32_512.HANDLE; @@ -552,6 +773,7 @@ private static class pq_decoded_cosine_similarity_f32_512 { * {@snippet lang=c : * float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) * } + * @return the function descriptor */ public static FunctionDescriptor pq_decoded_cosine_similarity_f32_512$descriptor() { return pq_decoded_cosine_similarity_f32_512.DESC; @@ -562,6 +784,7 @@ private static class pq_decoded_cosine_similarity_f32_512 { * {@snippet lang=c : * float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) * } + * @return the method handle */ 
public static MethodHandle pq_decoded_cosine_similarity_f32_512$handle() { return pq_decoded_cosine_similarity_f32_512.HANDLE; @@ -572,15 +795,43 @@ private static class pq_decoded_cosine_similarity_f32_512 { * {@snippet lang=c : * float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) * } + * @return the memory segment address */ public static MemorySegment pq_decoded_cosine_similarity_f32_512$address() { return pq_decoded_cosine_similarity_f32_512.ADDR; } /** + * Computes the cosine similarity between a query and a Product Quantization compressed vector + * using pre-computed partial dot products and magnitudes with native AVX-512 SIMD instructions. + *

+ * This method reconstructs and computes the cosine similarity for a single compressed vector + * by gathering the appropriate partial values using the vector's codebook indices (baseOffsets). + * The cosine similarity formula is: + *

+     * cosine = dotProduct / sqrt(vectorMagnitude * queryMagnitude)
+     * 
+ * Both the dot product and the vector magnitude are reconstructed by gathering and summing + * partial values from pre-computed tables using the codebook indices. * {@snippet lang=c : * float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) * } + * + * @param baseOffsets memory segment containing the codebook indices that represent the + * compressed vector; each byte is an index into a codebook + * @param baseOffsetsOffset starting position in the baseOffsets array for this vector + * @param baseOffsetsLength number of codebook indices for this vector (typically equal to + * the number of subquantizers) + * @param clusterCount number of centroids in each codebook; used to calculate offsets into + * the partialSums and aMagnitude arrays + * @param partialSums memory segment containing pre-computed partial dot products between the + * query and all codebook centroids; organized as a 2D array with dimensions + * [numSubquantizers][clusterCount] + * @param aMagnitude memory segment containing pre-computed partial squared magnitudes for all + * codebook centroids; organized as a 2D array with dimensions + * [numSubquantizers][clusterCount] + * @param bMagnitude the magnitude of the query vector (second operand in the similarity) + * @return the cosine similarity between the query and the compressed vector, as a float32 value */ public static float pq_decoded_cosine_similarity_f32_512(MemorySegment baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, MemorySegment partialSums, MemorySegment aMagnitude, float bMagnitude) { var mh$ = pq_decoded_cosine_similarity_f32_512.HANDLE; @@ -615,6 +866,7 @@ private static class calculate_partial_sums_dot_f32_512 { * {@snippet lang=c : * void calculate_partial_sums_dot_f32_512(const float *codebook, int codebookBase, int size, int clusterCount, const 
float *query, int queryOffset, float *partialSums) * } + * @return the function descriptor */ public static FunctionDescriptor calculate_partial_sums_dot_f32_512$descriptor() { return calculate_partial_sums_dot_f32_512.DESC; @@ -625,6 +877,7 @@ private static class calculate_partial_sums_dot_f32_512 { * {@snippet lang=c : * void calculate_partial_sums_dot_f32_512(const float *codebook, int codebookBase, int size, int clusterCount, const float *query, int queryOffset, float *partialSums) * } + * @return the method handle */ public static MethodHandle calculate_partial_sums_dot_f32_512$handle() { return calculate_partial_sums_dot_f32_512.HANDLE; @@ -635,15 +888,46 @@ private static class calculate_partial_sums_dot_f32_512 { * {@snippet lang=c : * void calculate_partial_sums_dot_f32_512(const float *codebook, int codebookBase, int size, int clusterCount, const float *query, int queryOffset, float *partialSums) * } + * @return the memory segment address */ public static MemorySegment calculate_partial_sums_dot_f32_512$address() { return calculate_partial_sums_dot_f32_512.ADDR; } /** + * Computes partial dot products between a query vector and all centroids in a Product + * Quantization codebook using native AVX-512 SIMD instructions. + *

+ * This method is a key operation in Product Quantization-based approximate nearest neighbor + * search. It pre-computes the dot products between the query vector and all codebook centroids + * for a single subquantizer. These partial dot products are later used to quickly estimate + * the full dot product between the query and compressed vectors by summing the appropriate + * partial values. + *

+ * For each centroid in the codebook, the method computes: + *

+     * partialSums[i] = dotProduct(query[queryOffset:queryOffset+size], codebook[i*size:(i+1)*size])
+     * 
+ * where i ranges from 0 to clusterCount-1. The AVX-512 implementation processes multiple + * dot products in parallel using 512-bit wide vector registers. * {@snippet lang=c : * void calculate_partial_sums_dot_f32_512(const float *codebook, int codebookBase, int size, int clusterCount, const float *query, int queryOffset, float *partialSums) * } + * + * @param codebook memory segment containing the codebook centroids; organized as a flat + * array where each group of {@code size} consecutive floats represents + * one centroid vector + * @param codebookBase starting index in the codebook array; allows processing a subset of + * the codebook + * @param size dimensionality of each codebook centroid (subspace dimension); also the number + * of elements from the query vector to use + * @param clusterCount number of centroids in the codebook; determines the size of the output + * partialSums array + * @param query memory segment containing the query vector's float32 data + * @param queryOffset starting offset in the query vector; specifies which subspace of the + * query to use for this codebook + * @param partialSums output memory segment where the computed partial dot products will be + * written; must have capacity for at least {@code clusterCount} float values */ public static void calculate_partial_sums_dot_f32_512(MemorySegment codebook, int codebookBase, int size, int clusterCount, MemorySegment query, int queryOffset, MemorySegment partialSums) { var mh$ = calculate_partial_sums_dot_f32_512.HANDLE; @@ -678,6 +962,7 @@ private static class calculate_partial_sums_euclidean_f32_512 { * {@snippet lang=c : * void calculate_partial_sums_euclidean_f32_512(const float *codebook, int codebookBase, int size, int clusterCount, const float *query, int queryOffset, float *partialSums) * } + * @return the function descriptor */ public static FunctionDescriptor calculate_partial_sums_euclidean_f32_512$descriptor() { return calculate_partial_sums_euclidean_f32_512.DESC; @@ 
-688,6 +973,7 @@ private static class calculate_partial_sums_euclidean_f32_512 { * {@snippet lang=c : * void calculate_partial_sums_euclidean_f32_512(const float *codebook, int codebookBase, int size, int clusterCount, const float *query, int queryOffset, float *partialSums) * } + * @return the method handle */ public static MethodHandle calculate_partial_sums_euclidean_f32_512$handle() { return calculate_partial_sums_euclidean_f32_512.HANDLE; @@ -698,15 +984,46 @@ private static class calculate_partial_sums_euclidean_f32_512 { * {@snippet lang=c : * void calculate_partial_sums_euclidean_f32_512(const float *codebook, int codebookBase, int size, int clusterCount, const float *query, int queryOffset, float *partialSums) * } + * @return the memory segment address */ public static MemorySegment calculate_partial_sums_euclidean_f32_512$address() { return calculate_partial_sums_euclidean_f32_512.ADDR; } /** + * Computes partial squared Euclidean distances between a query vector and all centroids in a + * Product Quantization codebook using native AVX-512 SIMD instructions. + *

+ * This method is a key operation in Product Quantization-based approximate nearest neighbor + * search using Euclidean distance. It pre-computes the squared distances between the query + * vector and all codebook centroids for a single subquantizer. These partial distances are + * later used to quickly estimate the full squared Euclidean distance between the query and + * compressed vectors by summing the appropriate partial values. + *

+ * For each centroid in the codebook, the method computes: + *

+     * partialSums[i] = sum((query[j] - codebook[i][j])^2) for j in [0, size)
+     * 
+ * where i ranges from 0 to clusterCount-1. The AVX-512 implementation processes multiple + * distance calculations in parallel using 512-bit wide vector registers. * {@snippet lang=c : * void calculate_partial_sums_euclidean_f32_512(const float *codebook, int codebookBase, int size, int clusterCount, const float *query, int queryOffset, float *partialSums) * } + * + * @param codebook memory segment containing the codebook centroids; organized as a flat + * array where each group of {@code size} consecutive floats represents + * one centroid vector + * @param codebookBase starting index in the codebook array; allows processing a subset of + * the codebook + * @param size dimensionality of each codebook centroid (subspace dimension); also the number + * of elements from the query vector to use + * @param clusterCount number of centroids in the codebook; determines the size of the output + * partialSums array + * @param query memory segment containing the query vector's float32 data + * @param queryOffset starting offset in the query vector; specifies which subspace of the + * query to use for this codebook + * @param partialSums output memory segment where the computed partial squared distances will be + * written; must have capacity for at least {@code clusterCount} float values */ public static void calculate_partial_sums_euclidean_f32_512(MemorySegment codebook, int codebookBase, int size, int clusterCount, MemorySegment query, int queryOffset, MemorySegment partialSums) { var mh$ = calculate_partial_sums_euclidean_f32_512.HANDLE; @@ -742,6 +1059,7 @@ private static class calculate_partial_sums_best_dot_f32_512 { * {@snippet lang=c : * void calculate_partial_sums_best_dot_f32_512(const float *codebook, int codebookBase, int size, int clusterCount, const float *query, int queryOffset, float *partialSums, float *partialBestDistances) * } + * @return the function descriptor */ public static FunctionDescriptor calculate_partial_sums_best_dot_f32_512$descriptor() { 
return calculate_partial_sums_best_dot_f32_512.DESC; @@ -752,6 +1070,7 @@ private static class calculate_partial_sums_best_dot_f32_512 { * {@snippet lang=c : * void calculate_partial_sums_best_dot_f32_512(const float *codebook, int codebookBase, int size, int clusterCount, const float *query, int queryOffset, float *partialSums, float *partialBestDistances) * } + * @return the method handle */ public static MethodHandle calculate_partial_sums_best_dot_f32_512$handle() { return calculate_partial_sums_best_dot_f32_512.HANDLE; @@ -762,15 +1081,49 @@ private static class calculate_partial_sums_best_dot_f32_512 { * {@snippet lang=c : * void calculate_partial_sums_best_dot_f32_512(const float *codebook, int codebookBase, int size, int clusterCount, const float *query, int queryOffset, float *partialSums, float *partialBestDistances) * } + * @return the memory segment address */ public static MemorySegment calculate_partial_sums_best_dot_f32_512$address() { return calculate_partial_sums_best_dot_f32_512.ADDR; } /** + * Computes partial dot products and identifies the maximum dot product (best match) between + * a query vector and all centroids in a Product Quantization codebook using native AVX-512 + * SIMD instructions. + *

+ * This method extends {@link #calculate_partial_sums_dot_f32_512} by additionally tracking + * the maximum dot product value among all centroids. This is useful for optimizations in + * Product Quantization where knowing the best possible match helps with early termination + * or pruning strategies. + *

+ * For each centroid in the codebook, the method computes: + *

+     * partialSums[i] = dotProduct(query[queryOffset:queryOffset+size], codebook[i*size:(i+1)*size])
+     * partialBestDistances[0] = max(partialSums[0], partialSums[1], ..., partialSums[clusterCount-1])
+     * 
+ * The AVX-512 implementation processes multiple dot products in parallel and uses SIMD + * horizontal max operations to efficiently find the maximum value. * {@snippet lang=c : * void calculate_partial_sums_best_dot_f32_512(const float *codebook, int codebookBase, int size, int clusterCount, const float *query, int queryOffset, float *partialSums, float *partialBestDistances) * } + * + * @param codebook memory segment containing the codebook centroids; organized as a flat + * array where each group of {@code size} consecutive floats represents + * one centroid vector + * @param codebookBase starting index in the codebook array; allows processing a subset of + * the codebook + * @param size dimensionality of each codebook centroid (subspace dimension); also the number + * of elements from the query vector to use + * @param clusterCount number of centroids in the codebook; determines the size of the output + * partialSums array + * @param query memory segment containing the query vector's float32 data + * @param queryOffset starting offset in the query vector; specifies which subspace of the + * query to use for this codebook + * @param partialSums output memory segment where the computed partial dot products will be + * written; must have capacity for at least {@code clusterCount} float values + * @param partialBestDistances output memory segment where the maximum dot product value will + * be written; only the first element is written */ public static void calculate_partial_sums_best_dot_f32_512(MemorySegment codebook, int codebookBase, int size, int clusterCount, MemorySegment query, int queryOffset, MemorySegment partialSums, MemorySegment partialBestDistances) { var mh$ = calculate_partial_sums_best_dot_f32_512.HANDLE; @@ -806,6 +1159,7 @@ private static class calculate_partial_sums_best_euclidean_f32_512 { * {@snippet lang=c : * void calculate_partial_sums_best_euclidean_f32_512(const float *codebook, int codebookBase, int size, int clusterCount, const float 
*query, int queryOffset, float *partialSums, float *partialBestDistances) * } + * @return the function descriptor */ public static FunctionDescriptor calculate_partial_sums_best_euclidean_f32_512$descriptor() { return calculate_partial_sums_best_euclidean_f32_512.DESC; @@ -816,6 +1170,7 @@ private static class calculate_partial_sums_best_euclidean_f32_512 { * {@snippet lang=c : * void calculate_partial_sums_best_euclidean_f32_512(const float *codebook, int codebookBase, int size, int clusterCount, const float *query, int queryOffset, float *partialSums, float *partialBestDistances) * } + * @return the method handle */ public static MethodHandle calculate_partial_sums_best_euclidean_f32_512$handle() { return calculate_partial_sums_best_euclidean_f32_512.HANDLE; @@ -826,15 +1181,49 @@ private static class calculate_partial_sums_best_euclidean_f32_512 { * {@snippet lang=c : * void calculate_partial_sums_best_euclidean_f32_512(const float *codebook, int codebookBase, int size, int clusterCount, const float *query, int queryOffset, float *partialSums, float *partialBestDistances) * } + * @return the memory segment address */ public static MemorySegment calculate_partial_sums_best_euclidean_f32_512$address() { return calculate_partial_sums_best_euclidean_f32_512.ADDR; } /** + * Computes partial squared Euclidean distances and identifies the minimum distance (best match) + * between a query vector and all centroids in a Product Quantization codebook using native + * AVX-512 SIMD instructions. + *

+ * This method extends {@link #calculate_partial_sums_euclidean_f32_512} by additionally tracking + * the minimum squared distance value among all centroids. This is useful for optimizations in + * Product Quantization where knowing the best possible match helps with early termination + * or pruning strategies in nearest neighbor search. + *

 + * For each centroid in the codebook, the method computes:
 + * <pre>
 +     * partialSums[i] = sum((query[j] - codebook[i][j])^2) for j in [0, size)
 +     * partialBestDistances[0] = min(partialSums[0], partialSums[1], ..., partialSums[clusterCount-1])
 +     * </pre>
+ * The AVX-512 implementation processes multiple distance calculations in parallel and uses + * SIMD horizontal min operations to efficiently find the minimum value. * {@snippet lang=c : * void calculate_partial_sums_best_euclidean_f32_512(const float *codebook, int codebookBase, int size, int clusterCount, const float *query, int queryOffset, float *partialSums, float *partialBestDistances) * } + * + * @param codebook memory segment containing the codebook centroids; organized as a flat + * array where each group of {@code size} consecutive floats represents + * one centroid vector + * @param codebookBase starting index in the codebook array; allows processing a subset of + * the codebook + * @param size dimensionality of each codebook centroid (subspace dimension); also the number + * of elements from the query vector to use + * @param clusterCount number of centroids in the codebook; determines the size of the output + * partialSums array + * @param query memory segment containing the query vector's float32 data + * @param queryOffset starting offset in the query vector; specifies which subspace of the + * query to use for this codebook + * @param partialSums output memory segment where the computed partial squared distances will be + * written; must have capacity for at least {@code clusterCount} float values + * @param partialBestDistances output memory segment where the minimum squared distance value + * will be written; only the first element is written */ public static void calculate_partial_sums_best_euclidean_f32_512(MemorySegment codebook, int codebookBase, int size, int clusterCount, MemorySegment query, int queryOffset, MemorySegment partialSums, MemorySegment partialBestDistances) { var mh$ = calculate_partial_sums_best_euclidean_f32_512.HANDLE; diff --git a/jvector-twenty/pom.xml b/jvector-twenty/pom.xml index ae6aa659b..afedaf5fc 100644 --- a/jvector-twenty/pom.xml +++ b/jvector-twenty/pom.xml @@ -39,6 +39,16 @@ + + org.apache.maven.plugins + 
maven-javadoc-plugin + + + --add-modules=jdk.incubator.vector + + 22 + + diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorizationProvider.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorizationProvider.java index ae0952ff7..1234f708c 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorizationProvider.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorizationProvider.java @@ -31,6 +31,23 @@ public class PanamaVectorizationProvider extends VectorizationProvider private final VectorUtilSupport vectorUtilSupport; private final VectorTypeSupport vectorTypeSupport; + /** + * Constructs a new PanamaVectorizationProvider that utilizes the Panama Vector API + * for hardware-accelerated SIMD operations. + *

 + * This constructor initializes the vectorization provider with:
 + * <ul>
 + *   <li>{@link PanamaVectorUtilSupport} for SIMD-accelerated vector operations including
 + *       dot product, cosine similarity, and Euclidean distance calculations</li>
 + *   <li>{@link ArrayVectorProvider} for on-heap array-backed vector storage</li>
 + * </ul>
+ * The constructor also disables out-of-bounds checking for vector access operations via + * the system property {@code jdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK} to maximize + * performance in production environments. + *

+ * The preferred SIMD vector width is logged at construction time and depends on the + * underlying hardware capabilities (e.g., AVX2, AVX-512). + */ public PanamaVectorizationProvider() { this.vectorUtilSupport = new PanamaVectorUtilSupport(); LOG.info("Preferred f32 species is " + FloatVector.SPECIES_PREFERRED.vectorBitSize());