Merged
25 commits
18f7291
vector scorer
mccullocht Sep 4, 2025
149965a
offheap vv
mccullocht Sep 4, 2025
aedce9c
writer
mccullocht Sep 5, 2025
89a4228
reader
mccullocht Sep 8, 2025
a55fee9
wrap up and test flat vector format for sq8
mccullocht Sep 8, 2025
b7abe2d
hnsw codec
mccullocht Sep 8, 2025
7cec9e8
enum for scalar encoding
mccullocht Sep 8, 2025
7b76f3d
fix most of the write path
mccullocht Sep 8, 2025
cf4fdef
packing without testing
mccullocht Sep 8, 2025
132c8ee
flat vectors test
mccullocht Sep 8, 2025
e90b6d1
fix license
mccullocht Sep 9, 2025
f9cc396
CHANGES
mccullocht Sep 9, 2025
e6e6b6a
handle boundary cases with nibble encoding -- unpacked must always ha…
mccullocht Sep 9, 2025
01a9748
resilience to small floating point errors
mccullocht Sep 9, 2025
2e5f89d
tidy--
mccullocht Sep 9, 2025
b30731c
remove unnecessary default
mccullocht Sep 10, 2025
bb89c01
tidy
mccullocht Sep 10, 2025
11a978a
unpack bytes during updateable scoring
mccullocht Sep 10, 2025
d8ad448
Merge remote-tracking branch 'origin/main' into sq-to-osq
mccullocht Sep 10, 2025
a9720ca
Merge remote-tracking branch 'origin/main' into sq-to-osq
mccullocht Sep 12, 2025
bc5d385
Apply suggestion from @benwtrent
mccullocht Sep 15, 2025
30350e4
add 7 bit representation
mccullocht Sep 15, 2025
93209dd
mark existing 99 formats as deprecated
mccullocht Sep 15, 2025
1b73a6f
Merge branch 'sq-to-osq' of github.com:mccullocht/lucene into sq-to-osq
mccullocht Sep 15, 2025
2de7d81
fix some missing 7 bit checks
mccullocht Sep 15, 2025
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
@@ -130,6 +130,8 @@ New Features
* GITHUB#15176: Add `[Float|Byte]VectorValues#rescorer(element[])` interface to allow optimized rescoring of vectors.
(Ben Trent)

* GITHUB#15169: Add codecs for 4 and 8 bit Optimized Scalar Quantization vectors (Trevor McCulloch)

Improvements
---------------------
* GITHUB#15148: Add support for uint8 distance and allow 8 bit scalar quantization (Trevor McCulloch)
4 changes: 3 additions & 1 deletion lucene/core/src/java/module-info.java
@@ -87,7 +87,9 @@
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat,
org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat,
org.apache.lucene.codecs.lucene102.Lucene102HnswBinaryQuantizedVectorsFormat,
org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat;
org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat,
org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat,
org.apache.lucene.codecs.lucene104.Lucene104HnswScalarQuantizedVectorsFormat;
provides org.apache.lucene.codecs.PostingsFormat with
org.apache.lucene.codecs.lucene104.Lucene104PostingsFormat;
provides org.apache.lucene.index.SortFieldProvider with
163 changes: 163 additions & 0 deletions lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java
@@ -0,0 +1,163 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene104;

import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;

import java.io.IOException;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.util.hnsw.HnswGraph;

/**
* A vectors format that uses an HNSW graph to store and search for vectors, where vectors are
* scalar quantized using {@link Lucene104ScalarQuantizedVectorsFormat} before being stored in the
* graph.
*/
public class Lucene104HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {

public static final String NAME = "Lucene104HnswScalarQuantizedVectorsFormat";

/**
* Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
* {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
*/
private final int maxConn;

/**
* The number of candidate neighbors to track while searching the graph for each newly inserted
* node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph}
* for details.
*/
private final int beamWidth;

/** The format for storing, reading, and merging vectors on disk */
private final Lucene104ScalarQuantizedVectorsFormat flatVectorsFormat;

private final int numMergeWorkers;
private final TaskExecutor mergeExec;

/** Constructs a format using default graph construction parameters and unsigned byte encoding */
public Lucene104HnswScalarQuantizedVectorsFormat() {
this(
ScalarEncoding.UNSIGNED_BYTE,
DEFAULT_MAX_CONN,
DEFAULT_BEAM_WIDTH,
DEFAULT_NUM_MERGE_WORKER,
null);
}

/**
* Constructs a format using the given graph construction parameters.
*
* @param maxConn the maximum number of connections to a node in the HNSW graph
* @param beamWidth the size of the queue maintained during graph construction.
*/
public Lucene104HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) {
this(ScalarEncoding.UNSIGNED_BYTE, maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
}

/**
* Constructs a format using the given graph construction parameters and scalar quantization.
*
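* @param encoding the scalar quantization encoding to use for stored vectors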
* @param maxConn the maximum number of connections to a node in the HNSW graph
* @param beamWidth the size of the queue maintained during graph construction.
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
* larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
* generated by this format to do the merge
*/
public Lucene104HnswScalarQuantizedVectorsFormat(
ScalarEncoding encoding,
int maxConn,
int beamWidth,
int numMergeWorkers,
ExecutorService mergeExec) {
super(NAME);
flatVectorsFormat = new Lucene104ScalarQuantizedVectorsFormat(encoding);
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
throw new IllegalArgumentException(
"maxConn must be positive and less than or equal to "
+ MAXIMUM_MAX_CONN
+ "; maxConn="
+ maxConn);
}
if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
throw new IllegalArgumentException(
"beamWidth must be positive and less than or equal to "
+ MAXIMUM_BEAM_WIDTH
+ "; beamWidth="
+ beamWidth);
}
this.maxConn = maxConn;
this.beamWidth = beamWidth;
if (numMergeWorkers == 1 && mergeExec != null) {
throw new IllegalArgumentException(
"No executor service is needed as we'll use single thread to merge");
}
this.numMergeWorkers = numMergeWorkers;
if (mergeExec != null) {
this.mergeExec = new TaskExecutor(mergeExec);
} else {
this.mergeExec = null;
}
}

@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99HnswVectorsWriter(
state,
maxConn,
beamWidth,
flatVectorsFormat.fieldsWriter(state),
numMergeWorkers,
mergeExec);
}

@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
}

@Override
public int getMaxDimensions(String fieldName) {
return 1024;
}

@Override
public String toString() {
return "Lucene104HnswScalarQuantizedVectorsFormat(name=Lucene104HnswScalarQuantizedVectorsFormat, maxConn="
+ maxConn
+ ", beamWidth="
+ beamWidth
+ ", flatVectorFormat="
+ flatVectorsFormat
+ ")";
}
}
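For reviewers who want to try the new format, here is a minimal indexing sketch. It is not part of this diff: the field name, encoding choice, and codec name are illustrative, and in real usage the codec name must be resolvable via SPI (or the format plugged into the default codec's per-field hook) for the index to be readable later.

import java.nio.file.Path;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene104.Lucene104HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;

public class Lucene104SqExample {
  public static void main(String[] args) throws Exception {
    // 4 bit packed encoding; the no-arg constructor defaults to UNSIGNED_BYTE.
    KnnVectorsFormat knnFormat =
        new Lucene104HnswScalarQuantizedVectorsFormat(
            ScalarEncoding.PACKED_NIBBLE,
            16, // maxConn
            100, // beamWidth
            1, // numMergeWorkers: single-threaded merging...
            null); // ...so no merge executor is needed
    IndexWriterConfig iwc = new IndexWriterConfig();
    iwc.setCodec(
        new FilterCodec("Lucene104SqExample", Codec.getDefault()) {
          @Override
          public KnnVectorsFormat knnVectorsFormat() {
            return knnFormat;
          }
        });
    try (Directory dir = FSDirectory.open(Path.of("sq-index"));
        IndexWriter writer = new IndexWriter(dir, iwc)) {
      Document doc = new Document();
      doc.add(new KnnFloatVectorField("vec", new float[] {0.1f, 0.2f, 0.3f, 0.4f}));
      writer.addDocument(doc);
    }
  }
}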
206 changes: 206 additions & 0 deletions lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java
@@ -0,0 +1,206 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene104;

import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;

import java.io.IOException;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;

/** Vector scorer over OptimizedScalarQuantized vectors */
public class Lucene104ScalarQuantizedVectorScorer implements FlatVectorsScorer {
private final FlatVectorsScorer nonQuantizedDelegate;

public Lucene104ScalarQuantizedVectorScorer(FlatVectorsScorer nonQuantizedDelegate) {
this.nonQuantizedDelegate = nonQuantizedDelegate;
}

@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
throws IOException {
if (vectorValues instanceof QuantizedByteVectorValues qv) {
return new ScalarQuantizedVectorScorerSupplier(qv, similarityFunction);
}
// It is possible to get to this branch during initial indexing and flush
return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
}

@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
throws IOException {
if (vectorValues instanceof QuantizedByteVectorValues qv) {
OptimizedScalarQuantizer quantizer = qv.getQuantizer();
byte[] targetQuantized =
new byte
[OptimizedScalarQuantizer.discretize(
target.length, qv.getScalarEncoding().getDimensionsPerByte())];
// We make a copy as the quantization process mutates the input
float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length);
if (similarityFunction == COSINE) {
VectorUtil.l2normalize(copy);
}
target = copy;
var targetCorrectiveTerms =
quantizer.scalarQuantize(
target, targetQuantized, qv.getScalarEncoding().getBits(), qv.getCentroid());
return new RandomVectorScorer.AbstractRandomVectorScorer(qv) {
@Override
public float score(int node) throws IOException {
return quantizedScore(
targetQuantized, targetCorrectiveTerms, qv, node, similarityFunction);
}
};
}
// It is possible to get to this branch during initial indexing and flush
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
}

@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
throws IOException {
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
}

@Override
public String toString() {
return "Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate="
+ nonQuantizedDelegate
+ ")";
}

private static final class ScalarQuantizedVectorScorerSupplier
implements RandomVectorScorerSupplier {
private final QuantizedByteVectorValues targetValues;
private final QuantizedByteVectorValues values;
private final VectorSimilarityFunction similarity;

public ScalarQuantizedVectorScorerSupplier(
QuantizedByteVectorValues values, VectorSimilarityFunction similarity) throws IOException {
this.targetValues = values.copy();
this.values = values;
this.similarity = similarity;
}

@Override
public UpdateableRandomVectorScorer scorer() throws IOException {
return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) {
private byte[] targetVector;
private OptimizedScalarQuantizer.QuantizationResult targetCorrectiveTerms;

@Override
public float score(int node) throws IOException {
return quantizedScore(targetVector, targetCorrectiveTerms, values, node, similarity);
}

@Override
public void setScoringOrdinal(int node) throws IOException {
var rawTargetVector = targetValues.vectorValue(node);
switch (values.getScalarEncoding()) {
case UNSIGNED_BYTE, SEVEN_BIT -> targetVector = rawTargetVector;
case PACKED_NIBBLE -> {
if (targetVector == null) {
targetVector = new byte[OptimizedScalarQuantizer.discretize(values.dimension(), 2)];
}
OffHeapScalarQuantizedVectorValues.unpackNibbles(rawTargetVector, targetVector);
}
}
targetCorrectiveTerms = targetValues.getCorrectiveTerms(node);
}
};
}

@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new ScalarQuantizedVectorScorerSupplier(values.copy(), similarity);
}
}

private static final float[] SCALE_LUT =
new float[] {
1f,
1f / ((1 << 2) - 1),
1f / ((1 << 3) - 1),
1f / ((1 << 4) - 1),
1f / ((1 << 5) - 1),
1f / ((1 << 6) - 1),
1f / ((1 << 7) - 1),
1f / ((1 << 8) - 1),
};

private static float quantizedScore(
byte[] quantizedQuery,
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
QuantizedByteVectorValues targetVectors,
int targetOrd,
VectorSimilarityFunction similarityFunction)
throws IOException {
var scalarEncoding = targetVectors.getScalarEncoding();
byte[] quantizedDoc = targetVectors.vectorValue(targetOrd);
float qcDist =
switch (scalarEncoding) {
case UNSIGNED_BYTE -> VectorUtil.uint8DotProduct(quantizedQuery, quantizedDoc);
case SEVEN_BIT -> VectorUtil.dotProduct(quantizedQuery, quantizedDoc);
case PACKED_NIBBLE -> VectorUtil.int4DotProductPacked(quantizedQuery, quantizedDoc);
};
OptimizedScalarQuantizer.QuantizationResult indexCorrections =
targetVectors.getCorrectiveTerms(targetOrd);
float scale = SCALE_LUT[scalarEncoding.getBits() - 1];
float x1 = indexCorrections.quantizedComponentSum();
float ax = indexCorrections.lowerInterval();
// Here we must scale according to the bits
float lx = (indexCorrections.upperInterval() - ax) * scale;
float ay = queryCorrections.lowerInterval();
float ly = (queryCorrections.upperInterval() - ay) * scale;
float y1 = queryCorrections.quantizedComponentSum();
float score =
ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist;
// For euclidean, we need to invert the score and apply the additional correction, which is
// assumed to be the squared l2norm of the centroid centered vectors.
if (similarityFunction == EUCLIDEAN) {
score =
queryCorrections.additionalCorrection()
+ indexCorrections.additionalCorrection()
- 2 * score;
return Math.max(1 / (1f + score), 0);
} else {
// For cosine and max inner product, we need to apply the additional correction, which is
// assumed to be the non-centered dot-product between the vector and the centroid
score +=
queryCorrections.additionalCorrection()
+ indexCorrections.additionalCorrection()
- targetVectors.getCentroidDP();
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
return VectorUtil.scaleMaxInnerProductScore(score);
}
return Math.max((1f + score) / 2f, 0);
}
}
}
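For reference, the corrective-term algebra in quantizedScore follows directly from dequantization. With lower interval $a$ and per-bit scale $l = (u - a) / (2^b - 1)$ (the SCALE_LUT factor), each component dequantizes as $x_i \approx a_x + l_x q_{x,i}$, so the dot product of two dequantized vectors of dimension $d$ expands to

$$\langle x, y \rangle \approx \sum_{i=1}^{d} (a_x + l_x q_{x,i})(a_y + l_y q_{y,i}) = a_x a_y d + a_y l_x x_1 + a_x l_y y_1 + l_x l_y \sum_i q_{x,i} q_{y,i},$$

where $x_1 = \sum_i q_{x,i}$ and $y_1 = \sum_i q_{y,i}$ are the quantized component sums; this is exactly `ax * ay * dimension + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist`. The two branches then undo the centroid centering, using $\|x - y\|^2 = \|x - c\|^2 + \|y - c\|^2 - 2\langle x - c, y - c \rangle$ for EUCLIDEAN and $\langle x, y \rangle = \langle x - c, y - c \rangle + \langle x, c \rangle + \langle y, c \rangle - \langle c, c \rangle$ for the dot-product metrics, with getCentroidDP() supplying $\langle c, c \rangle$.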