diff --git a/build.xml b/build.xml
index 28152182a84c..e14e35b0e304 100644
--- a/build.xml
+++ b/build.xml
@@ -741,7 +741,7 @@
-
+
diff --git a/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java b/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java
index 436f8448332c..059ea0728596 100644
--- a/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java
+++ b/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java
@@ -406,6 +406,8 @@ public enum CassandraRelevantProperties
SAI_VECTOR_FLUSH_THRESHOLD_MAX_ROWS("cassandra.sai.vector_flush_threshold_max_rows", "-1"),
// Use non-positive value to disable it. Period in millis to trigger a flush for SAI vector memtable index.
SAI_VECTOR_FLUSH_PERIOD_IN_MILLIS("cassandra.sai.vector_flush_period_in_millis", "-1"),
+ // Whether flush and compaction should build vector indexes using fused ADC (FusedPQ)
+ SAI_VECTOR_ENABLE_FUSED("cassandra.sai.vector.enable_fused", "true"),
/**
* Whether to disable auto-compaction
*/
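
For context: the flag is read once through the standard CassandraRelevantProperties accessor, so operators can disable fused writes with -Dcassandra.sai.vector.enable_fused=false at startup. A minimal sketch of consuming it (the surrounding class is illustrative, not part of the patch):

    import org.apache.cassandra.config.CassandraRelevantProperties;

    class FusedFlagProbe
    {
        public static void main(String[] args)
        {
            // defaults to "true" per the declaration above
            boolean fused = CassandraRelevantProperties.SAI_VECTOR_ENABLE_FUSED.getBoolean();
            System.out.println("fused vector index writes: " + (fused ? "enabled" : "disabled"));
        }
    }
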
diff --git a/src/java/org/apache/cassandra/index/sai/IndexContext.java b/src/java/org/apache/cassandra/index/sai/IndexContext.java
index f6eb4a974f94..020a027b50d2 100644
--- a/src/java/org/apache/cassandra/index/sai/IndexContext.java
+++ b/src/java/org/apache/cassandra/index/sai/IndexContext.java
@@ -38,7 +38,6 @@
import org.slf4j.LoggerFactory;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
-import org.apache.cassandra.config.CassandraRelevantProperties;
import org.apache.cassandra.cql3.Operator;
import org.apache.cassandra.cql3.statements.schema.IndexTarget;
import org.apache.cassandra.db.ClusteringComparator;
diff --git a/src/java/org/apache/cassandra/index/sai/disk/format/Version.java b/src/java/org/apache/cassandra/index/sai/disk/format/Version.java
index 1017cb7a18c7..4cfc978a893c 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/format/Version.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/format/Version.java
@@ -35,6 +35,7 @@
import org.apache.cassandra.index.sai.disk.v5.V5OnDiskFormat;
import org.apache.cassandra.index.sai.disk.v6.V6OnDiskFormat;
import org.apache.cassandra.index.sai.disk.v7.V7OnDiskFormat;
+import org.apache.cassandra.index.sai.disk.v8.V8OnDiskFormat;
import org.apache.cassandra.index.sai.utils.TypeUtil;
import org.apache.cassandra.io.sstable.format.SSTableFormat;
import org.apache.cassandra.schema.SchemaConstants;
@@ -69,10 +70,12 @@ public class Version implements Comparable<Version>
public static final Version EC = new Version("ec", V7OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "ec"));
// total terms count serialization in index metadata, enables ANN_USE_SYNTHETIC_SCORE by default
public static final Version ED = new Version("ed", V7OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "ed"));
+ // jvector file format version 6 (skipped 5)
+ public static final Version FA = new Version("fa", V8OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "fa"));
// These are in reverse-chronological order so that the latest version is first. Version matching tests
// are more likely to match the latest version, so we want to test that one first.
- public static final List<Version> ALL = Lists.newArrayList(ED, EC, EB, DC, DB, CA, BA, AA);
+ public static final List<Version> ALL = Lists.newArrayList(FA, ED, EC, EB, DC, DB, CA, BA, AA);
public static final Version EARLIEST = AA;
public static final Version VECTOR_EARLIEST = BA;
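
Note on the ordering contract: ALL is reverse-chronological, so any scan that resolves a version identifier tries FA first. A hedged sketch of that lookup pattern (the identifier accessor is assumed for illustration; the real class may expose it differently):

    static Version resolve(String id)
    {
        for (Version v : Version.ALL)            // FA, ED, EC, ... AA
            if (v.toString().equals(id))         // assumes toString() yields the two-letter id
                return v;
        throw new IllegalArgumentException("unknown SAI on-disk version: " + id);
    }
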
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v8/V8OnDiskFormat.java b/src/java/org/apache/cassandra/index/sai/disk/v8/V8OnDiskFormat.java
new file mode 100644
index 000000000000..aca3d530ce97
--- /dev/null
+++ b/src/java/org/apache/cassandra/index/sai/disk/v8/V8OnDiskFormat.java
@@ -0,0 +1,32 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.cassandra.index.sai.disk.v8;
+
+import org.apache.cassandra.index.sai.disk.v7.V7OnDiskFormat;
+
+public class V8OnDiskFormat extends V7OnDiskFormat
+{
+ public static final V8OnDiskFormat instance = new V8OnDiskFormat();
+
+ @Override
+ public int jvectorFileFormatVersion()
+ {
+ return 6;
+ }
+}
\ No newline at end of file
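
The subclass changes exactly one knob: the jvector file format version. Everything downstream gates fused writes on that number, so the check reduces to this sketch (assuming the current version resolves to FA, which selects V8OnDiskFormat):

    int jvectorVersion = Version.current().onDiskFormat().jvectorFileFormatVersion(); // 6 with V8OnDiskFormat
    boolean canWriteFused = jvectorVersion >= 6; // FusedPQ requires jvector file format 6+
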
diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraDiskAnn.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraDiskAnn.java
index 92544f07e4bc..3f3a74c27fd9 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraDiskAnn.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraDiskAnn.java
@@ -27,7 +27,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import io.github.jbellis.jvector.graph.GraphIndex;
+import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
import io.github.jbellis.jvector.graph.GraphSearcher;
import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
@@ -72,7 +72,7 @@ public class CassandraDiskAnn
private final FileHandle graphHandle;
private final OnDiskOrdinalsMap ordinalsMap;
private final Set<FeatureId> features;
- private final GraphIndex graph;
+ private final ImmutableGraphIndex graph;
private final VectorSimilarityFunction similarityFunction;
@Nullable
private final CompressedVectors compressedVectors;
@@ -94,7 +94,7 @@ public CassandraDiskAnn(SSTableContext sstableContext, SegmentMetadata.Component
SegmentMetadata.ComponentMetadata termsMetadata = this.componentMetadatas.get(IndexComponentType.TERMS_DATA);
graphHandle = indexFiles.termsData();
- var rawGraph = OnDiskGraphIndex.load(graphHandle::createReader, termsMetadata.offset);
+ var rawGraph = OnDiskGraphIndex.load(graphHandle::createReader, termsMetadata.offset, false);
features = rawGraph.getFeatureSet();
graph = rawGraph;
@@ -117,7 +117,7 @@ public CassandraDiskAnn(SSTableContext sstableContext, SegmentMetadata.Component
}
VectorCompression.CompressionType compressionType = VectorCompression.CompressionType.values()[reader.readByte()];
- if (features.contains(FeatureId.FUSED_ADC))
+ if (features.contains(FeatureId.FUSED_PQ))
{
assert compressionType == VectorCompression.CompressionType.PRODUCT_QUANTIZATION;
compressedVectors = null;
@@ -231,11 +231,9 @@ public CloseableIterator<SearchResult.NodeScore> search(VectorFloat<?> queryVector,
searcher.usePruning(usePruning);
try
{
- var view = (GraphIndex.ScoringView) searcher.getView();
+ var view = (ImmutableGraphIndex.ScoringView) searcher.getView();
SearchScoreProvider ssp;
- // FusedADC can no longer be written due to jvector upgrade. However, it's possible these index files
- // still exist, so we have to support them.
- if (features.contains(FeatureId.FUSED_ADC))
+ if (features.contains(FeatureId.FUSED_PQ))
{
var asf = view.approximateScoreFunctionFor(queryVector, similarityFunction);
var rr = isRerankless ? null : view.rerankerFor(queryVector, similarityFunction);
@@ -311,9 +309,9 @@ public OrdinalsView getOrdinalsView()
return ordinalsMap.getOrdinalsView();
}
- public GraphIndex.ScoringView getView()
+ public ImmutableGraphIndex.ScoringView getView()
{
- return (GraphIndex.ScoringView) graph.getView();
+ return (ImmutableGraphIndex.ScoringView) graph.getView();
}
public boolean containsUnitVectors()
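
Search with a fused index is two-phase: fused-PQ approximate scores drive the graph traversal, and full-precision reranking refines the candidates. The two score functions come straight from the scoring view, as in the hunk above:

    // approximate scores come from PQ data fused into the edge lists (no separate PQ file reads)
    var approximate = view.approximateScoreFunctionFor(queryVector, similarityFunction);
    // exact scores reread the inline full-precision vectors; skipped entirely for rerankless queries
    var reranker = view.rerankerFor(queryVector, similarityFunction);
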
diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java
index c42178d8aec1..6406ec11b9c0 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java
@@ -25,15 +25,16 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
+import java.util.EnumMap;
import java.util.Map;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
+import java.util.function.IntFunction;
import java.util.function.IntUnaryOperator;
import java.util.function.ToIntFunction;
-import java.util.stream.IntStream;
import com.google.common.annotations.VisibleForTesting;
import org.cliffc.high_scale_lib.NonBlockingHashMap;
@@ -42,17 +43,20 @@
import io.github.jbellis.jvector.graph.GraphIndexBuilder;
import io.github.jbellis.jvector.graph.GraphSearcher;
+import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.SearchResult;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter;
import io.github.jbellis.jvector.graph.disk.OrdinalMapper;
import io.github.jbellis.jvector.graph.disk.feature.Feature;
import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
+import io.github.jbellis.jvector.graph.disk.feature.FusedPQ;
import io.github.jbellis.jvector.graph.disk.feature.InlineVectors;
import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider;
-import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
import io.github.jbellis.jvector.quantization.BinaryQuantization;
import io.github.jbellis.jvector.quantization.CompressedVectors;
+import io.github.jbellis.jvector.quantization.ImmutablePQVectors;
+import io.github.jbellis.jvector.quantization.PQVectors;
import io.github.jbellis.jvector.quantization.ProductQuantization;
import io.github.jbellis.jvector.quantization.VectorCompressor;
import io.github.jbellis.jvector.util.Accountable;
@@ -63,9 +67,11 @@
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
+import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import org.agrona.collections.IntHashSet;
+import org.apache.cassandra.config.CassandraRelevantProperties;
import org.apache.cassandra.db.compaction.CompactionSSTable;
import org.apache.cassandra.db.marshal.VectorType;
import org.apache.cassandra.db.memtable.Memtable;
@@ -81,7 +87,6 @@
import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata;
import org.apache.cassandra.index.sai.disk.v2.V2VectorIndexSearcher;
import org.apache.cassandra.index.sai.disk.v2.V2VectorPostingsWriter;
-import org.apache.cassandra.index.sai.disk.v3.V3OnDiskFormat;
import org.apache.cassandra.index.sai.disk.v5.V5OnDiskFormat;
import org.apache.cassandra.index.sai.disk.v5.V5VectorPostingsWriter;
import org.apache.cassandra.index.sai.disk.v5.V5VectorPostingsWriter.Structure;
@@ -89,6 +94,7 @@
import org.apache.cassandra.index.sai.metrics.ColumnQueryMetrics;
import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey;
import org.apache.cassandra.index.sai.utils.SAICodecUtils;
+import org.apache.cassandra.io.util.File;
import org.apache.cassandra.io.util.SequentialWriter;
import org.apache.cassandra.service.StorageService;
import org.apache.cassandra.tracing.Tracing;
@@ -104,6 +110,9 @@ public enum PQVersion {
V1, // includes unit vector calculation
}
+ /** whether to use fused ADC when writing indexes (assuming all other conditions are met) */
+ private static boolean ENABLE_FUSED = CassandraRelevantProperties.SAI_VECTOR_ENABLE_FUSED.getBoolean();
+
/** minimum number of rows to perform PQ codebook generation */
public static final int MIN_PQ_ROWS = 1024;
@@ -125,6 +134,7 @@ public enum PQVersion {
private final InvalidVectorBehavior invalidVectorBehavior;
private final IntHashSet deletedOrdinals;
private volatile boolean hasDeletions;
+ private volatile boolean unitVectors;
// we don't need to explicitly close these since only on-heap resources are involved
private final ThreadLocal<GraphSearcher> searchers;
@@ -157,6 +167,9 @@ public CassandraOnHeapGraph(IndexContext context, boolean forSearching, Memtable
vectorsByKey = forSearching ? new NonBlockingHashMap<>() : null;
invalidVectorBehavior = forSearching ? InvalidVectorBehavior.FAIL : InvalidVectorBehavior.IGNORE;
+ // Start by assuming the vectors are unit length; the assumption is corrected in add() when a non-unit vector is seen.
+ unitVectors = true;
+
int jvectorVersion = Version.current().onDiskFormat().jvectorFileFormatVersion();
// This is only a warning since it's not a fatal error to write without hierarchy
if (indexConfig.isHierarchyEnabled() && jvectorVersion < 4)
@@ -269,6 +282,12 @@ public long add(ByteBuffer term, T key)
var success = postingsByOrdinal.compareAndPut(ordinal, null, postings);
assert success : "postingsByOrdinal already contains an entry for ordinal " + ordinal;
bytesUsed += builder.addGraphNode(ordinal, vector);
+
+ // The node is in the graph; verify unit length unless the model is already known to produce unit vectors
+ if (!sourceModel.hasKnownUnitLengthVectors() && unitVectors)
+ if (!(Math.abs(VectorUtil.dotProduct(vector, vector) - 1.0f) < 0.01))
+ unitVectors = false;
+
return bytesUsed;
}
else
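
The unit-length test above uses a dot-product tolerance: a vector v is accepted as unit length when |v·v - 1| < 0.01. Worked in isolation with jvector's vector types (values chosen so the norm is exactly 1):

    var vts = VectorizationProvider.getInstance().getVectorTypeSupport();
    VectorFloat<?> v = vts.createFloatVector(new float[] { 0.6f, 0.8f });
    boolean unit = Math.abs(VectorUtil.dotProduct(v, v) - 1.0f) < 0.01;  // 0.36 + 0.64 = 1.0 -> true
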
@@ -438,24 +457,18 @@ public SegmentMetadata.ComponentMetadataMap flush(IndexComponents.ForWrite perIn
if (indexFile.exists())
termsOffset += indexFile.length();
try (var pqOutput = perIndexComponents.addOrGet(IndexComponentType.PQ).openOutput(true);
- var postingsOutput = perIndexComponents.addOrGet(IndexComponentType.POSTING_LISTS).openOutput(true);
- var indexWriter = new OnDiskGraphIndexWriter.Builder(builder.getGraph(), indexFile.toPath())
- .withStartOffset(termsOffset)
- .withVersion(Version.current().onDiskFormat().jvectorFileFormatVersion())
- .withMapper(ordinalMapper)
- .with(new InlineVectors(vectorValues.dimension()))
- .build())
+ var postingsOutput = perIndexComponents.addOrGet(IndexComponentType.POSTING_LISTS).openOutput(true))
{
SAICodecUtils.writeHeader(pqOutput);
SAICodecUtils.writeHeader(postingsOutput);
- indexWriter.getOutput().seek(indexFile.length()); // position at the end of the previous segment before writing our own header
- SAICodecUtils.writeHeader(SAICodecUtils.toLuceneOutput(indexWriter.getOutput()));
- assert indexWriter.getOutput().position() == termsOffset : "termsOffset " + termsOffset + " != " + indexWriter.getOutput().position();
+
+ // Attempt fused writing only when it is enabled and the on-disk format supports jvector file format 6+
+ boolean attemptWritingFused = ENABLE_FUSED && Version.current().onDiskFormat().jvectorFileFormatVersion() >= 6;
// compute and write PQ
long pqOffset = pqOutput.getFilePointer();
- long pqPosition = writePQ(pqOutput.asSequentialWriter(), remappedPostings, perIndexComponents.context());
- long pqLength = pqPosition - pqOffset;
+ var compressor = writePQ(pqOutput.asSequentialWriter(), remappedPostings, perIndexComponents.context(), attemptWritingFused);
+ long pqLength = pqOutput.asSequentialWriter().position() - pqOffset;
// write postings
long postingsOffset = postingsOutput.getFilePointer();
@@ -474,21 +487,42 @@ public SegmentMetadata.ComponentMetadataMap flush(IndexComponents.ForWrite perIn
}
long postingsLength = postingsPosition - postingsOffset;
- // write the graph
- var start = System.nanoTime();
- var suppliers = Feature.singleStateFactory(FeatureId.INLINE_VECTORS, nodeId -> new InlineVectors.State(vectorValues.getVector(nodeId)));
- indexWriter.write(suppliers);
- SAICodecUtils.writeFooter(indexWriter.getOutput(), indexWriter.checksum());
- logger.info("Writing graph took {}ms", (System.nanoTime() - start) / 1_000_000);
- long termsLength = indexWriter.getOutput().position() - termsOffset;
+ try (var indexWriter = createIndexWriter(indexFile, termsOffset, perIndexComponents.context(), ordinalMapper, compressor);
+ var view = builder.getGraph().getView())
+ {
+ indexWriter.getOutput().seek(indexFile.length()); // position at the end of the previous segment before writing our own header
+ SAICodecUtils.writeHeader(SAICodecUtils.toLuceneOutput(indexWriter.getOutput()));
+ assert indexWriter.getOutput().position() == termsOffset : "termsOffset " + termsOffset + " != " + indexWriter.getOutput().position();
+
+ // write the graph
+ var start = System.nanoTime();
+ indexWriter.write(suppliers(view, compressor));
+ SAICodecUtils.writeFooter(indexWriter.getOutput(), indexWriter.checksum());
+ logger.info("Writing graph took {}ms", (System.nanoTime() - start) / 1_000_000);
+ long termsLength = indexWriter.getOutput().position() - termsOffset;
+
+ // write remaining footers/checksums
+ SAICodecUtils.writeFooter(pqOutput);
+ SAICodecUtils.writeFooter(postingsOutput);
+
+ // add components to the metadata map
+ return createMetadataMap(termsOffset, termsLength, postingsOffset, postingsLength, pqOffset, pqLength);
+ }
+ }
+ }
+
+ private OnDiskGraphIndexWriter createIndexWriter(File indexFile, long termsOffset, IndexContext context, OrdinalMapper ordinalMapper, VectorCompressor<?> compressor) throws IOException
+ {
+ var indexWriterBuilder = new OnDiskGraphIndexWriter.Builder(builder.getGraph(), indexFile.toPath())
+ .withStartOffset(termsOffset)
+ .withVersion(Version.current().onDiskFormat().jvectorFileFormatVersion())
+ .withMapper(ordinalMapper)
+ .with(new InlineVectors(vectorValues.dimension()));
- // write remaining footers/checksums
- SAICodecUtils.writeFooter(pqOutput);
- SAICodecUtils.writeFooter(postingsOutput);
+ if (ENABLE_FUSED && compressor instanceof ProductQuantization && Version.current().onDiskFormat().jvectorFileFormatVersion() >= 6)
+ indexWriterBuilder.with(new FusedPQ(context.getIndexWriterConfig().getAnnMaxDegree(), (ProductQuantization) compressor));
- // add components to the metadata map
- return createMetadataMap(termsOffset, termsLength, postingsOffset, postingsLength, pqOffset, pqLength);
- }
+ return indexWriterBuilder.build();
}
static SegmentMetadata.ComponentMetadataMap createMetadataMap(long termsOffset, long termsLength, long postingsOffset, long postingsLength, long pqOffset, long pqLength)
@@ -501,6 +535,22 @@ static SegmentMetadata.ComponentMetadataMap createMetadataMap(long termsOffset,
return metadataMap;
}
+ private EnumMap<FeatureId, IntFunction<Feature.State>> suppliers(ImmutableGraphIndex.View view, VectorCompressor<?> compressor)
+ {
+ var features = new EnumMap<FeatureId, IntFunction<Feature.State>>(FeatureId.class);
+ features.put(FeatureId.INLINE_VECTORS, nodeId -> new InlineVectors.State(vectorValues.getVector(nodeId)));
+ if (ENABLE_FUSED && Version.current().onDiskFormat().jvectorFileFormatVersion() >= 6)
+ {
+ if (compressor instanceof ProductQuantization)
+ {
+ ProductQuantization quantization = (ProductQuantization) compressor;
+ IntFunction<ByteSequence<?>> func = (oldNodeId) -> quantization.encode(vectorValues.getVector(oldNodeId));
+ features.put(FeatureId.FUSED_PQ, nodeId -> new FusedPQ.State(view, func, nodeId));
+ }
+ }
+ return features;
+ }
+
/**
* Return the best previous CompressedVectors for this column that matches the `matcher` predicate.
* "Best" means the most recent one that hits the row count target of {@link ProductQuantization#MAX_PQ_TRAINING_SET_SIZE},
@@ -553,14 +603,13 @@ public static PqInfo getPqIfPresent(IndexContext indexContext, Function
- private long writePQ(SequentialWriter writer, V5VectorPostingsWriter.RemappedPostings remapped, IndexContext indexContext) throws IOException
+ private VectorCompressor<?> writePQ(SequentialWriter writer, V5VectorPostingsWriter.RemappedPostings remapped, IndexContext indexContext, boolean attemptWritingFused) throws IOException
{
var preferredCompression = sourceModel.compressionProvider.apply(vectorValues.dimension());
// Build encoder and compress vectors
VectorCompressor<?> compressor; // will be null if we can't compress
CompressedVectors cv = null;
- boolean containsUnitVectors;
// limit the PQ computation and encoding to one index at a time -- goal during flush is to
// evict from memory ASAP so better to do the PQ build (in parallel) one at a time
synchronized (CassandraOnHeapGraph.class)
@@ -578,23 +627,24 @@ private long writePQ(SequentialWriter writer, V5VectorPostingsWriter.RemappedPos
}
assert !vectorValues.isValueShared();
// encode (compress) the vectors to save
- if (compressor != null)
+ if ((compressor instanceof ProductQuantization && !attemptWritingFused) || compressor instanceof BinaryQuantization)
cv = compressor.encodeAll(new RemappedVectorValues(remapped, remapped.maxNewOrdinal, vectorValues));
-
- containsUnitVectors = IntStream.range(0, vectorValues.size())
- .parallel()
- .mapToObj(vectorValues::getVector)
- .allMatch(v -> Math.abs(VectorUtil.dotProduct(v, v) - 1.0f) < 0.01);
}
var actualType = compressor == null ? CompressionType.NONE : preferredCompression.type;
- writePqHeader(writer, containsUnitVectors, actualType);
+ writePqHeader(writer, unitVectors, actualType);
if (actualType == CompressionType.NONE)
- return writer.position();
+ return null;
+
+ if (attemptWritingFused && compressor instanceof ProductQuantization)
+ {
+ // fused path: persist only the PQ codebooks; the encoded vectors are written fused into the graph
+ compressor.write(writer, Version.current().onDiskFormat().jvectorFileFormatVersion());
+ return compressor;
+ }
// save (outside the synchronized block, this is io-bound not CPU)
cv.write(writer, Version.current().onDiskFormat().jvectorFileFormatVersion());
- return writer.position();
+ return null; // Don't need compressor in this case
}
static void writePqHeader(DataOutput writer, boolean unitVectors, CompressionType type)
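
Taken together, the restructured writePQ leaves the PQ component in one of three states, and its return value tells the caller whether fused writing is still in play. An informal summary derived from the paths above:

    // compressor              attemptWritingFused   PQ component contents         returns
    // NONE (null)             any                   header only                   null
    // ProductQuantization     false                 header + CompressedVectors    null
    // ProductQuantization     true                  header + PQ codebooks only    compressor
    // BinaryQuantization      any                   header + CompressedVectors    null
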
diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/CloseableReranker.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CloseableReranker.java
index b85b33c81b55..0415b9eb4bcc 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/vector/CloseableReranker.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CloseableReranker.java
@@ -20,21 +20,21 @@
import java.io.Closeable;
-import io.github.jbellis.jvector.graph.GraphIndex;
+import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import org.apache.cassandra.io.util.FileUtils;
/**
- * An ExactScoreFunction that closes the underlying {@link GraphIndex.ScoringView} when closed.
+ * An ExactScoreFunction that closes the underlying {@link ImmutableGraphIndex.ScoringView} when closed.
*/
public class CloseableReranker implements ScoreFunction.ExactScoreFunction, Closeable
{
- private final GraphIndex.ScoringView view;
+ private final ImmutableGraphIndex.ScoringView view;
private final ExactScoreFunction scoreFunction;
- public CloseableReranker(VectorSimilarityFunction similarityFunction, VectorFloat<?> queryVector, GraphIndex.ScoringView view)
+ public CloseableReranker(VectorSimilarityFunction similarityFunction, VectorFloat<?> queryVector, ImmutableGraphIndex.ScoringView view)
{
this.view = view;
this.scoreFunction = view.rerankerFor(queryVector, similarityFunction);
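
Usage is unchanged apart from the view type; the reranker is an ExactScoreFunction that owns its view. Illustrative only (graph, queryVector, and the ordinal are stand-ins):

    try (var reranker = new CloseableReranker(similarityFunction, queryVector, graph.getView()))
    {
        float exact = reranker.similarityTo(nodeOrdinal); // full-precision score for one graph node
    } // closing the reranker closes the underlying ScoringView
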
diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java
index 2443e272340c..206f90aa7304 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java
@@ -41,6 +41,7 @@
import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
import io.github.jbellis.jvector.graph.disk.feature.Feature;
import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
+import io.github.jbellis.jvector.graph.disk.feature.FusedPQ;
import io.github.jbellis.jvector.graph.disk.feature.InlineVectors;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter;
@@ -67,6 +68,7 @@
import net.openhft.chronicle.map.ChronicleMapBuilder;
import org.agrona.collections.Int2ObjectHashMap;
import org.apache.cassandra.concurrent.NamedThreadFactory;
+import org.apache.cassandra.config.CassandraRelevantProperties;
import org.apache.cassandra.db.Keyspace;
import org.apache.cassandra.db.marshal.VectorType;
import org.apache.cassandra.exceptions.InvalidRequestException;
@@ -109,6 +111,8 @@ public class CompactionGraph implements Closeable, Accountable
@VisibleForTesting
public static int PQ_TRAINING_SIZE = ProductQuantization.MAX_PQ_TRAINING_SET_SIZE;
+ private static boolean ENABLE_FUSED = CassandraRelevantProperties.SAI_VECTOR_ENABLE_FUSED.getBoolean();
+
private final VectorType.VectorSerializer serializer;
private final VectorSimilarityFunction similarityFunction;
private final ChronicleMap<VectorFloat<?>, CompactionVectorPostings> postingsMap;
@@ -225,12 +229,14 @@ else if (compressor instanceof BinaryQuantization)
private OnDiskGraphIndexWriter createTermsWriter(OrdinalMapper ordinalMapper) throws IOException
{
- return new OnDiskGraphIndexWriter.Builder(builder.getGraph(), termsFile.toPath())
+ var writerBuilder = new OnDiskGraphIndexWriter.Builder(builder.getGraph(), termsFile.toPath())
.withStartOffset(termsOffset)
.with(new InlineVectors(dimension))
.withVersion(Version.current().onDiskFormat().jvectorFileFormatVersion())
- .withMapper(ordinalMapper)
- .build();
+ .withMapper(ordinalMapper);
+ if (ENABLE_FUSED && compressor instanceof ProductQuantization && Version.current().onDiskFormat().jvectorFileFormatVersion() >= 6)
+ writerBuilder.with(new FusedPQ(context.getIndexWriterConfig().getAnnMaxDegree(), (ProductQuantization) compressor));
+ return writerBuilder.build();
}
@Override
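
The builder wiring is the same shape in both writers: inline vectors always, plus a FusedPQ feature when the compressor is PQ and the format allows it. A condensed sketch (jvectorVersion, fusedEnabled, and maxDegree are stand-ins for the fields and config values used above):

    OnDiskGraphIndexWriter.Builder b = new OnDiskGraphIndexWriter.Builder(builder.getGraph(), termsFile.toPath())
        .withStartOffset(termsOffset)
        .with(new InlineVectors(dimension))
        .withVersion(jvectorVersion)       // must be >= 6 for the fused feature
        .withMapper(ordinalMapper);
    if (fusedEnabled && compressor instanceof ProductQuantization)
        b.with(new FusedPQ(maxDegree, (ProductQuantization) compressor)); // fused blocks are sized by max degree
    OnDiskGraphIndexWriter writer = b.build();
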
@@ -372,7 +378,7 @@ public long addGraphNode(InsertionResult result)
public SegmentMetadata.ComponentMetadataMap flush() throws IOException
{
// header is required to write the postings, but we need to recreate the writer after that with an accurate OrdinalMapper
- writer.writeHeader();
+ writer.writeHeader(builder.getGraph().getView());
writer.close();
int nInProgress = builder.insertsInProgress();
@@ -408,7 +414,7 @@ public SegmentMetadata.ComponentMetadataMap flush() throws IOException
var es = Executors.newSingleThreadExecutor(new NamedThreadFactory("CompactionGraphPostingsWriter"));
long postingsLength;
try (var indexHandle = perIndexComponents.get(IndexComponentType.TERMS_DATA).createIndexBuildTimeFileHandle();
- var index = OnDiskGraphIndex.load(indexHandle::createReader, termsOffset))
+ var index = OnDiskGraphIndex.load(indexHandle::createReader, termsOffset, false))
{
var postingsFuture = es.submit(() -> {
// V2 doesn't support ONE_TO_MANY so force it to ZERO_OR_ONE_TO_MANY if necessary;
@@ -453,9 +459,18 @@ public SegmentMetadata.ComponentMetadataMap flush() throws IOException
// write the graph edge lists and optionally fused adc features
var start = System.nanoTime();
- // Required becuase jvector 3 wrote the fused adc map here. We no longer write jvector 3, but we still
- // write out the empty map.
- writer.write(Map.of());
+ if (writer.getFeatureSet().contains(FeatureId.FUSED_PQ))
+ {
+ try (var view = builder.getGraph().getView())
+ {
+ var supplier = Feature.singleStateFactory(FeatureId.FUSED_PQ, ordinal -> new FusedPQ.State(view, (PQVectors) compressedVectors, ordinal));
+ writer.write(supplier);
+ }
+ }
+ else
+ {
+ writer.write(Map.of());
+ }
SAICodecUtils.writeFooter(writer.getOutput(), writer.checksum());
logger.info("Writing graph took {}ms", (System.nanoTime() - start) / 1_000_000);
long termsLength = writer.getOutput().position() - termsOffset;
diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorSourceModel.java b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorSourceModel.java
index bf08896591dd..bccfd8e1352a 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorSourceModel.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorSourceModel.java
@@ -31,18 +31,19 @@
import static org.apache.cassandra.index.sai.disk.vector.VectorCompression.CompressionType.BINARY_QUANTIZATION;
import static org.apache.cassandra.index.sai.disk.vector.VectorCompression.CompressionType.NONE;
import static org.apache.cassandra.index.sai.disk.vector.VectorCompression.CompressionType.PRODUCT_QUANTIZATION;
-
public enum VectorSourceModel
{
- ADA002((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25),
- OPENAI_V3_SMALL((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.5),
- OPENAI_V3_LARGE((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25),
- BERT(COSINE, (dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.25), __ -> 1.0),
- GECKO((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25),
- NV_QA_4((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25),
- COHERE_V3((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25),
-
- OTHER(COSINE, VectorSourceModel::genericCompressionFor, VectorSourceModel::genericOverquery);
+ ADA002((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25, true),
+ OPENAI_V3_SMALL((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.5, true),
+ OPENAI_V3_LARGE((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25, true),
+ // BERT is not known to have unit length vectors in all cases
+ BERT(COSINE, (dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.25), __ -> 1.0, false),
+ GECKO((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25, true),
+ NV_QA_4((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25, false),
+ // Cohere does not officially say they have unit length vectors, but some users report that they do
+ COHERE_V3((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25, false),
+
+ OTHER(COSINE, VectorSourceModel::genericCompressionFor, VectorSourceModel::genericOverquery, false);
/**
* Default similarity function for this model.
@@ -58,18 +59,33 @@ public enum VectorSourceModel
*/
public final Function<Integer, Double> overqueryProvider;
- VectorSourceModel(Function<Integer, VectorCompression> compressionProvider, double overqueryFactor)
+ /**
+ * Indicates that the model is known to produce unit-length vectors. When false, vectors are checked at
+ * ingest time, per graph, until a non-unit-length vector is found.
+ */
+ private final boolean knownUnitLength;
+
+ VectorSourceModel(Function<Integer, VectorCompression> compressionProvider,
+ double overqueryFactor,
+ boolean knownUnitLength)
{
- this(DOT_PRODUCT, compressionProvider, __ -> overqueryFactor);
+ this(DOT_PRODUCT, compressionProvider, __ -> overqueryFactor, knownUnitLength);
}
VectorSourceModel(VectorSimilarityFunction defaultSimilarityFunction,
Function<Integer, VectorCompression> compressionProvider,
- Function<Integer, Double> overqueryProvider)
+ Function<Integer, Double> overqueryProvider,
+ boolean knownUnitLength)
{
this.defaultSimilarityFunction = defaultSimilarityFunction;
this.compressionProvider = compressionProvider;
this.overqueryProvider = overqueryProvider;
+ this.knownUnitLength = knownUnitLength;
+ }
+
+ public boolean hasKnownUnitLengthVectors()
+ {
+ return knownUnitLength;
}
public static VectorSourceModel fromString(String value)
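
The flag is queryable per model; only the OpenAI-style and Gecko embeddings are trusted outright. A quick enumeration (values follow the declarations above):

    for (VectorSourceModel m : VectorSourceModel.values())
        System.out.println(m + " -> " + m.hasKnownUnitLengthVectors());
    // true:  ADA002, OPENAI_V3_SMALL, OPENAI_V3_LARGE, GECKO
    // false: BERT, NV_QA_4, COHERE_V3, OTHER
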
diff --git a/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java b/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java
index 8fc9b37c17e8..748a6f29c364 100644
--- a/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java
@@ -43,7 +43,7 @@
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
-public class VectorSiftSmallTest extends VectorTester
+public class VectorSiftSmallTest extends VectorTester.Versioned
{
private static final String DATASET = "siftsmall"; // change to "sift" for larger dataset. requires manual download
@@ -156,6 +156,7 @@ public void testCompaction() throws Throwable
assertTrue("Pre-compaction recall is " + recall, recall > 0.975);
}
+ compact();
compact();
for (int topK : List.of(1, 100))
{
@@ -313,7 +314,7 @@ private void createTable()
private void createIndex()
{
// we need a long timeout because we are adding many vectors
- String index = createIndexAsync("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex' WITH OPTIONS = {'similarity_function' : 'euclidean'}");
+ String index = createIndexAsync("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex' WITH OPTIONS = {'similarity_function' : 'euclidean', 'enable_hierarchy': 'true'}");
waitForIndexQueryable(KEYSPACE, index, 5, TimeUnit.MINUTES);
}
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIteratorTest.java b/test/unit/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIteratorTest.java
index 22658da82ac7..ff025379bf52 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIteratorTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIteratorTest.java
@@ -19,10 +19,11 @@
package org.apache.cassandra.index.sai.disk.vector;
import java.util.NoSuchElementException;
+import java.util.function.Function;
import org.junit.Test;
-import io.github.jbellis.jvector.graph.GraphIndex;
+import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
import io.github.jbellis.jvector.graph.NodeQueue;
import io.github.jbellis.jvector.graph.NodesIterator;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
@@ -63,7 +64,7 @@ public void testBruteForceRowIdIteratorForEmptyPQAndTopKEqualsLimit()
assertTrue(view.isClosed);
}
- private static class TestView implements GraphIndex.ScoringView
+ private static class TestView implements ImmutableGraphIndex.ScoringView
{
private boolean isClosed = false;
@@ -95,6 +96,12 @@ public NodesIterator getNeighborsIterator(int i, int i1)
throw new UnsupportedOperationException();
}
+ @Override
+ public void processNeighbors(int i, int i1, ScoreFunction scoreFunction, Function function, ImmutableGraphIndex.NeighborProcessor neighborProcessor)
+ {
+
+ }
+
@Override
public int size()
{
@@ -102,7 +109,7 @@ public int size()
}
@Override
- public GraphIndex.NodeAtLevel entryNode()
+ public ImmutableGraphIndex.NodeAtLevel entryNode()
{
throw new UnsupportedOperationException();
}
@@ -112,5 +119,11 @@ public Bits liveNodes()
{
throw new UnsupportedOperationException();
}
+
+ @Override
+ public boolean contains(int i, int i1)
+ {
+ return false;
+ }
}
}