diff --git a/build.xml b/build.xml
index 28152182a84c..e14e35b0e304 100644
--- a/build.xml
+++ b/build.xml
@@ -741,7 +741,7 @@
-
+
diff --git a/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java b/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java
index 436f8448332c..059ea0728596 100644
--- a/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java
+++ b/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java
@@ -406,6 +406,8 @@ public enum CassandraRelevantProperties
     SAI_VECTOR_FLUSH_THRESHOLD_MAX_ROWS("cassandra.sai.vector_flush_threshold_max_rows", "-1"),
     // Use non-positive value to disable it. Period in millis to trigger a flush for SAI vector memtable index.
     SAI_VECTOR_FLUSH_PERIOD_IN_MILLIS("cassandra.sai.vector_flush_period_in_millis", "-1"),
+    // Whether compaction should build vector indexes using fused adc
+    SAI_VECTOR_ENABLE_FUSED("cassandra.sai.vector.enable_fused", "true"),
     /**
      * Whether to disable auto-compaction
      */
diff --git a/src/java/org/apache/cassandra/index/sai/IndexContext.java b/src/java/org/apache/cassandra/index/sai/IndexContext.java
index f6eb4a974f94..020a027b50d2 100644
--- a/src/java/org/apache/cassandra/index/sai/IndexContext.java
+++ b/src/java/org/apache/cassandra/index/sai/IndexContext.java
@@ -38,7 +38,6 @@
 import org.slf4j.LoggerFactory;
 
 import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
-import org.apache.cassandra.config.CassandraRelevantProperties;
 import org.apache.cassandra.cql3.Operator;
 import org.apache.cassandra.cql3.statements.schema.IndexTarget;
 import org.apache.cassandra.db.ClusteringComparator;
diff --git a/src/java/org/apache/cassandra/index/sai/disk/format/Version.java b/src/java/org/apache/cassandra/index/sai/disk/format/Version.java
index 1017cb7a18c7..4cfc978a893c 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/format/Version.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/format/Version.java
@@ -35,6 +35,7 @@
 import org.apache.cassandra.index.sai.disk.v5.V5OnDiskFormat;
 import org.apache.cassandra.index.sai.disk.v6.V6OnDiskFormat;
 import org.apache.cassandra.index.sai.disk.v7.V7OnDiskFormat;
+import org.apache.cassandra.index.sai.disk.v8.V8OnDiskFormat;
 import org.apache.cassandra.index.sai.utils.TypeUtil;
 import org.apache.cassandra.io.sstable.format.SSTableFormat;
 import org.apache.cassandra.schema.SchemaConstants;
@@ -69,10 +70,12 @@ public class Version implements Comparable<Version>
     public static final Version EC = new Version("ec", V7OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "ec"));
     // total terms count serialization in index metadata, enables ANN_USE_SYNTHETIC_SCORE by default
     public static final Version ED = new Version("ed", V7OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "ed"));
+    // jvector file format version 6 (skipped 5)
+    public static final Version FA = new Version("fa", V8OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "fa"));
 
     // These are in reverse-chronological order so that the latest version is first. Version matching tests
     // are more likely to match the latest version, so we want to test that one first.
-    public static final List<Version> ALL = Lists.newArrayList(ED, EC, EB, DC, DB, CA, BA, AA);
+    public static final List<Version> ALL = Lists.newArrayList(FA, ED, EC, EB, DC, DB, CA, BA, AA);
 
     public static final Version EARLIEST = AA;
     public static final Version VECTOR_EARLIEST = BA;
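The new property and the new `fa` index version combine into a single write-time gate that recurs throughout the rest of this patch: fused layouts are written only when the property is enabled and the active on-disk format emits jvector file format 6 or newer. A minimal sketch of that gate, using only names introduced above (the helper class itself is hypothetical):

    import org.apache.cassandra.config.CassandraRelevantProperties;
    import org.apache.cassandra.index.sai.disk.format.Version;

    final class FusedWriteGate
    {
        // Mirrors the checks repeated in CassandraOnHeapGraph and CompactionGraph below
        static boolean shouldWriteFused()
        {
            return CassandraRelevantProperties.SAI_VECTOR_ENABLE_FUSED.getBoolean()
                   && Version.current().onDiskFormat().jvectorFileFormatVersion() >= 6;
        }
    }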
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v8/V8OnDiskFormat.java b/src/java/org/apache/cassandra/index/sai/disk/v8/V8OnDiskFormat.java
new file mode 100644
index 000000000000..aca3d530ce97
--- /dev/null
+++ b/src/java/org/apache/cassandra/index/sai/disk/v8/V8OnDiskFormat.java
@@ -0,0 +1,32 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.cassandra.index.sai.disk.v8;
+
+import org.apache.cassandra.index.sai.disk.v7.V7OnDiskFormat;
+
+public class V8OnDiskFormat extends V7OnDiskFormat
+{
+    public static final V8OnDiskFormat instance = new V8OnDiskFormat();
+
+    @Override
+    public int jvectorFileFormatVersion()
+    {
+        return 6;
+    }
+}
\ No newline at end of file
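V8OnDiskFormat changes nothing but the jvector file format version; all other behavior is inherited from V7OnDiskFormat. A hypothetical sanity check of the mapping, assuming the classes above:

    import org.apache.cassandra.index.sai.disk.format.Version;
    import org.apache.cassandra.index.sai.disk.v8.V8OnDiskFormat;

    final class V8FormatCheck
    {
        public static void main(String[] args)
        {
            assert Version.FA.onDiskFormat() == V8OnDiskFormat.instance;
            // "fa" writes jvector file format 6; format 5 was skipped entirely
            assert V8OnDiskFormat.instance.jvectorFileFormatVersion() == 6;
        }
    }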
diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraDiskAnn.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraDiskAnn.java
index 92544f07e4bc..3f3a74c27fd9 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraDiskAnn.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraDiskAnn.java
@@ -27,7 +27,7 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import io.github.jbellis.jvector.graph.GraphIndex;
+import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
 import io.github.jbellis.jvector.graph.GraphSearcher;
 import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
 import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
@@ -72,7 +72,7 @@ public class CassandraDiskAnn
     private final FileHandle graphHandle;
     private final OnDiskOrdinalsMap ordinalsMap;
     private final Set<FeatureId> features;
-    private final GraphIndex graph;
+    private final ImmutableGraphIndex graph;
     private final VectorSimilarityFunction similarityFunction;
     @Nullable
     private final CompressedVectors compressedVectors;
@@ -94,7 +94,7 @@ public CassandraDiskAnn(SSTableContext sstableContext, SegmentMetadata.ComponentMetadataMap componentMetadatas,
         SegmentMetadata.ComponentMetadata termsMetadata = this.componentMetadatas.get(IndexComponentType.TERMS_DATA);
         graphHandle = indexFiles.termsData();
-        var rawGraph = OnDiskGraphIndex.load(graphHandle::createReader, termsMetadata.offset);
+        var rawGraph = OnDiskGraphIndex.load(graphHandle::createReader, termsMetadata.offset, false);
         features = rawGraph.getFeatureSet();
         graph = rawGraph;
@@ -117,7 +117,7 @@
             }
             VectorCompression.CompressionType compressionType = VectorCompression.CompressionType.values()[reader.readByte()];
-            if (features.contains(FeatureId.FUSED_ADC))
+            if (features.contains(FeatureId.FUSED_PQ))
             {
                 assert compressionType == VectorCompression.CompressionType.PRODUCT_QUANTIZATION;
                 compressedVectors = null;
@@ -231,11 +231,9 @@ public CloseableIterator search(VectorFloat<?> queryVector,
         searcher.usePruning(usePruning);
         try
         {
-            var view = (GraphIndex.ScoringView) searcher.getView();
+            var view = (ImmutableGraphIndex.ScoringView) searcher.getView();
             SearchScoreProvider ssp;
-            // FusedADC can no longer be written due to jvector upgrade. However, it's possible these index files
-            // still exist, so we have to support them.
-            if (features.contains(FeatureId.FUSED_ADC))
+            if (features.contains(FeatureId.FUSED_PQ))
             {
                 var asf = view.approximateScoreFunctionFor(queryVector, similarityFunction);
                 var rr = isRerankless ? null : view.rerankerFor(queryVector, similarityFunction);
@@ -311,9 +309,9 @@ public OrdinalsView getOrdinalsView()
         return ordinalsMap.getOrdinalsView();
     }
 
-    public GraphIndex.ScoringView getView()
+    public ImmutableGraphIndex.ScoringView getView()
     {
-        return (GraphIndex.ScoringView) graph.getView();
+        return (ImmutableGraphIndex.ScoringView) graph.getView();
     }
 
     public boolean containsUnitVectors()
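On the read path the only signal that a segment uses the fused layout is the feature set of the loaded graph; with this change, segments written with the old FUSED_ADC feature no longer take the fused branch. An illustrative probe, limited to calls that appear in this file:

    import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
    import io.github.jbellis.jvector.graph.disk.feature.FeatureId;

    final class FusedProbe
    {
        // True only for segments written by this patch's fused path ("fa" and later)
        static boolean hasFusedPq(OnDiskGraphIndex graph)
        {
            return graph.getFeatureSet().contains(FeatureId.FUSED_PQ);
        }
    }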
diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java
index c42178d8aec1..6406ec11b9c0 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java
@@ -25,15 +25,16 @@
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Comparator;
+import java.util.EnumMap;
 import java.util.Map;
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.ConcurrentSkipListMap;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Function;
+import java.util.function.IntFunction;
 import java.util.function.IntUnaryOperator;
 import java.util.function.ToIntFunction;
-import java.util.stream.IntStream;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.cliffc.high_scale_lib.NonBlockingHashMap;
@@ -42,17 +43,20 @@
 import io.github.jbellis.jvector.graph.GraphIndexBuilder;
 import io.github.jbellis.jvector.graph.GraphSearcher;
+import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
 import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
 import io.github.jbellis.jvector.graph.SearchResult;
 import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter;
 import io.github.jbellis.jvector.graph.disk.OrdinalMapper;
 import io.github.jbellis.jvector.graph.disk.feature.Feature;
 import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
+import io.github.jbellis.jvector.graph.disk.feature.FusedPQ;
 import io.github.jbellis.jvector.graph.disk.feature.InlineVectors;
 import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider;
-import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
 import io.github.jbellis.jvector.quantization.BinaryQuantization;
 import io.github.jbellis.jvector.quantization.CompressedVectors;
+import io.github.jbellis.jvector.quantization.ImmutablePQVectors;
+import io.github.jbellis.jvector.quantization.PQVectors;
 import io.github.jbellis.jvector.quantization.ProductQuantization;
 import io.github.jbellis.jvector.quantization.VectorCompressor;
 import io.github.jbellis.jvector.util.Accountable;
@@ -63,9 +67,11 @@
 import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
 import io.github.jbellis.jvector.vector.VectorUtil;
 import io.github.jbellis.jvector.vector.VectorizationProvider;
+import io.github.jbellis.jvector.vector.types.ByteSequence;
 import io.github.jbellis.jvector.vector.types.VectorFloat;
 import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
 import org.agrona.collections.IntHashSet;
+import org.apache.cassandra.config.CassandraRelevantProperties;
 import org.apache.cassandra.db.compaction.CompactionSSTable;
 import org.apache.cassandra.db.marshal.VectorType;
 import org.apache.cassandra.db.memtable.Memtable;
@@ -81,7 +87,6 @@
 import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata;
 import org.apache.cassandra.index.sai.disk.v2.V2VectorIndexSearcher;
 import org.apache.cassandra.index.sai.disk.v2.V2VectorPostingsWriter;
-import org.apache.cassandra.index.sai.disk.v3.V3OnDiskFormat;
 import org.apache.cassandra.index.sai.disk.v5.V5OnDiskFormat;
 import org.apache.cassandra.index.sai.disk.v5.V5VectorPostingsWriter;
 import org.apache.cassandra.index.sai.disk.v5.V5VectorPostingsWriter.Structure;
@@ -89,6 +94,7 @@
 import org.apache.cassandra.index.sai.metrics.ColumnQueryMetrics;
 import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey;
 import org.apache.cassandra.index.sai.utils.SAICodecUtils;
+import org.apache.cassandra.io.util.File;
 import org.apache.cassandra.io.util.SequentialWriter;
 import org.apache.cassandra.service.StorageService;
 import org.apache.cassandra.tracing.Tracing;
@@ -104,6 +110,9 @@ public enum PQVersion {
         V1, // includes unit vector calculation
     }
 
+    /** whether to use fused ADC when writing indexes (assuming all other conditions are met) */
+    private static boolean ENABLE_FUSED = CassandraRelevantProperties.SAI_VECTOR_ENABLE_FUSED.getBoolean();
+
     /** minimum number of rows to perform PQ codebook generation */
     public static final int MIN_PQ_ROWS = 1024;
@@ -125,6 +134,7 @@
     private final InvalidVectorBehavior invalidVectorBehavior;
     private final IntHashSet deletedOrdinals;
     private volatile boolean hasDeletions;
+    private volatile boolean unitVectors;
 
     // we don't need to explicitly close these since only on-heap resources are involved
     private final ThreadLocal<GraphSearcher> searchers;
@@ -157,6 +167,9 @@ public CassandraOnHeapGraph(IndexContext context, boolean forSearching, Memtable memtable)
         vectorsByKey = forSearching ? new NonBlockingHashMap<>() : null;
         invalidVectorBehavior = forSearching ? InvalidVectorBehavior.FAIL : InvalidVectorBehavior.IGNORE;
 
+        // We start by assuming the vectors are unit vectors and then if they are not, we will correct it.
+        unitVectors = true;
+
         int jvectorVersion = Version.current().onDiskFormat().jvectorFileFormatVersion();
         // This is only a warning since it's not a fatal error to write without hierarchy
         if (indexConfig.isHierarchyEnabled() && jvectorVersion < 4)
@@ -269,6 +282,12 @@ public long add(ByteBuffer term, T key)
             var success = postingsByOrdinal.compareAndPut(ordinal, null, postings);
             assert success : "postingsByOrdinal already contains an entry for ordinal " + ordinal;
             bytesUsed += builder.addGraphNode(ordinal, vector);
+
+            // We safely added to the graph, check if we need to check for unit length
+            if (sourceModel.hasKnownUnitLengthVectors() || unitVectors)
+                if (!(Math.abs(VectorUtil.dotProduct(vector, vector) - 1.0f) < 0.01))
+                    unitVectors = false;
+
             return bytesUsed;
         }
         else
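The incremental check above works because dot(v, v) is the squared L2 norm, so |v·v - 1| < 0.01 accepts vectors whose squared length is within 1% of unit; it replaces the parallel whole-graph IntStream pass that writePQ used to run at flush time. The same predicate on plain arrays, as a self-contained sketch:

    final class UnitLength
    {
        static boolean isUnit(float[] v)
        {
            float dot = 0;
            for (float x : v)
                dot += x * x;                    // dot(v, v) == squared L2 norm
            return Math.abs(dot - 1.0f) < 0.01;  // same tolerance as the patch
        }
    }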
@@ -438,24 +457,18 @@ public SegmentMetadata.ComponentMetadataMap flush(IndexComponents.ForWrite perIndexComponents)
         if (indexFile.exists())
             termsOffset += indexFile.length();
         try (var pqOutput = perIndexComponents.addOrGet(IndexComponentType.PQ).openOutput(true);
-             var postingsOutput = perIndexComponents.addOrGet(IndexComponentType.POSTING_LISTS).openOutput(true);
-             var indexWriter = new OnDiskGraphIndexWriter.Builder(builder.getGraph(), indexFile.toPath())
-                               .withStartOffset(termsOffset)
-                               .withVersion(Version.current().onDiskFormat().jvectorFileFormatVersion())
-                               .withMapper(ordinalMapper)
-                               .with(new InlineVectors(vectorValues.dimension()))
-                               .build())
+             var postingsOutput = perIndexComponents.addOrGet(IndexComponentType.POSTING_LISTS).openOutput(true))
         {
             SAICodecUtils.writeHeader(pqOutput);
             SAICodecUtils.writeHeader(postingsOutput);
-            indexWriter.getOutput().seek(indexFile.length()); // position at the end of the previous segment before writing our own header
-            SAICodecUtils.writeHeader(SAICodecUtils.toLuceneOutput(indexWriter.getOutput()));
-            assert indexWriter.getOutput().position() == termsOffset : "termsOffset " + termsOffset + " != " + indexWriter.getOutput().position();
+
+            // Write fused unless we don't meet some criteria
+            boolean attemptWritingFused = ENABLE_FUSED && Version.current().onDiskFormat().jvectorFileFormatVersion() >= 6;
 
             // compute and write PQ
             long pqOffset = pqOutput.getFilePointer();
-            long pqPosition = writePQ(pqOutput.asSequentialWriter(), remappedPostings, perIndexComponents.context());
-            long pqLength = pqPosition - pqOffset;
+            var compressor = writePQ(pqOutput.asSequentialWriter(), remappedPostings, perIndexComponents.context(), attemptWritingFused);
+            long pqLength = pqOutput.asSequentialWriter().position() - pqOffset;
 
             // write postings
             long postingsOffset = postingsOutput.getFilePointer();
@@ -474,21 +487,42 @@ public SegmentMetadata.ComponentMetadataMap flush(IndexComponents.ForWrite perIndexComponents)
             }
             long postingsLength = postingsPosition - postingsOffset;
 
-            // write the graph
-            var start = System.nanoTime();
-            var suppliers = Feature.singleStateFactory(FeatureId.INLINE_VECTORS, nodeId -> new InlineVectors.State(vectorValues.getVector(nodeId)));
-            indexWriter.write(suppliers);
-            SAICodecUtils.writeFooter(indexWriter.getOutput(), indexWriter.checksum());
-            logger.info("Writing graph took {}ms", (System.nanoTime() - start) / 1_000_000);
-            long termsLength = indexWriter.getOutput().position() - termsOffset;
+            try (var indexWriter = createIndexWriter(indexFile, termsOffset, perIndexComponents.context(), ordinalMapper, compressor);
+                 var view = builder.getGraph().getView())
+            {
+                indexWriter.getOutput().seek(indexFile.length()); // position at the end of the previous segment before writing our own header
+                SAICodecUtils.writeHeader(SAICodecUtils.toLuceneOutput(indexWriter.getOutput()));
+                assert indexWriter.getOutput().position() == termsOffset : "termsOffset " + termsOffset + " != " + indexWriter.getOutput().position();
+
+                // write the graph
+                var start = System.nanoTime();
+                indexWriter.write(suppliers(view, compressor));
+                SAICodecUtils.writeFooter(indexWriter.getOutput(), indexWriter.checksum());
+                logger.info("Writing graph took {}ms", (System.nanoTime() - start) / 1_000_000);
+                long termsLength = indexWriter.getOutput().position() - termsOffset;
+
+                // write remaining footers/checksums
+                SAICodecUtils.writeFooter(pqOutput);
+                SAICodecUtils.writeFooter(postingsOutput);
+
+                // add components to the metadata map
+                return createMetadataMap(termsOffset, termsLength, postingsOffset, postingsLength, pqOffset, pqLength);
+            }
+        }
+    }
+
+    private OnDiskGraphIndexWriter createIndexWriter(File indexFile, long termsOffset, IndexContext context, OrdinalMapper ordinalMapper, VectorCompressor<?> compressor) throws IOException
+    {
+        var indexWriterBuilder = new OnDiskGraphIndexWriter.Builder(builder.getGraph(), indexFile.toPath())
+                                 .withStartOffset(termsOffset)
+                                 .withVersion(Version.current().onDiskFormat().jvectorFileFormatVersion())
+                                 .withMapper(ordinalMapper)
+                                 .with(new InlineVectors(vectorValues.dimension()));
 
-            // write remaining footers/checksums
-            SAICodecUtils.writeFooter(pqOutput);
-            SAICodecUtils.writeFooter(postingsOutput);
+        if (ENABLE_FUSED && compressor instanceof ProductQuantization && Version.current().onDiskFormat().jvectorFileFormatVersion() >= 6)
+            indexWriterBuilder.with(new FusedPQ(context.getIndexWriterConfig().getAnnMaxDegree(), (ProductQuantization) compressor));
 
-            // add components to the metadata map
-            return createMetadataMap(termsOffset, termsLength, postingsOffset, postingsLength, pqOffset, pqLength);
-        }
+        return indexWriterBuilder.build();
     }
 
     static SegmentMetadata.ComponentMetadataMap createMetadataMap(long termsOffset, long termsLength, long postingsOffset, long postingsLength, long pqOffset, long pqLength)
@@ -501,6 +535,22 @@ static SegmentMetadata.ComponentMetadataMap createMetadataMap(long termsOffset,
         return metadataMap;
     }
 
+    private EnumMap<FeatureId, IntFunction<Feature.State>> suppliers(ImmutableGraphIndex.View view, VectorCompressor<?> compressor)
+    {
+        var features = new EnumMap<FeatureId, IntFunction<Feature.State>>(FeatureId.class);
+        features.put(FeatureId.INLINE_VECTORS, nodeId -> new InlineVectors.State(vectorValues.getVector(nodeId)));
+        if (ENABLE_FUSED && Version.current().onDiskFormat().jvectorFileFormatVersion() >= 6)
+        {
+            if (compressor instanceof ProductQuantization)
+            {
+                ProductQuantization quantization = (ProductQuantization) compressor;
+                IntFunction<ByteSequence<?>> func = (oldNodeId) -> quantization.encode(vectorValues.getVector(oldNodeId));
+                features.put(FeatureId.FUSED_PQ, nodeId -> new FusedPQ.State(view, func, nodeId));
+            }
+        }
+        return features;
+    }
+
     /**
      * Return the best previous CompressedVectors for this column that matches the `matcher` predicate.
      * "Best" means the most recent one that hits the row count target of {@link ProductQuantization#MAX_PQ_TRAINING_SET_SIZE},
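createIndexWriter and suppliers(...) are gated by the same predicate, so the features registered on the writer always line up with the supplier keys handed to write(...). A hypothetical assertion that makes the invariant explicit:

    import java.util.Map;
    import java.util.Set;
    import io.github.jbellis.jvector.graph.disk.feature.FeatureId;

    final class FeatureWiring
    {
        // A writer declaring a feature it gets no per-node state for would produce a broken graph file
        static void checkAgreement(Set<FeatureId> writerFeatures, Map<FeatureId, ?> suppliers)
        {
            assert writerFeatures.equals(suppliers.keySet())
                : "writer features " + writerFeatures + " != supplier keys " + suppliers.keySet();
        }
    }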
* "Best" means the most recent one that hits the row count target of {@link ProductQuantization#MAX_PQ_TRAINING_SET_SIZE}, @@ -553,14 +603,13 @@ public static PqInfo getPqIfPresent(IndexContext indexContext, Function writePQ(SequentialWriter writer, V5VectorPostingsWriter.RemappedPostings remapped, IndexContext indexContext, boolean attemptWritingFused) throws IOException { var preferredCompression = sourceModel.compressionProvider.apply(vectorValues.dimension()); // Build encoder and compress vectors VectorCompressor compressor; // will be null if we can't compress CompressedVectors cv = null; - boolean containsUnitVectors; // limit the PQ computation and encoding to one index at a time -- goal during flush is to // evict from memory ASAP so better to do the PQ build (in parallel) one at a time synchronized (CassandraOnHeapGraph.class) @@ -578,23 +627,24 @@ private long writePQ(SequentialWriter writer, V5VectorPostingsWriter.RemappedPos } assert !vectorValues.isValueShared(); // encode (compress) the vectors to save - if (compressor != null) + if ((compressor instanceof ProductQuantization && !attemptWritingFused) || compressor instanceof BinaryQuantization) cv = compressor.encodeAll(new RemappedVectorValues(remapped, remapped.maxNewOrdinal, vectorValues)); - - containsUnitVectors = IntStream.range(0, vectorValues.size()) - .parallel() - .mapToObj(vectorValues::getVector) - .allMatch(v -> Math.abs(VectorUtil.dotProduct(v, v) - 1.0f) < 0.01); } var actualType = compressor == null ? CompressionType.NONE : preferredCompression.type; - writePqHeader(writer, containsUnitVectors, actualType); + writePqHeader(writer, unitVectors, actualType); if (actualType == CompressionType.NONE) - return writer.position(); + return null; + + if (attemptWritingFused) + { + compressor.write(writer, Version.current().onDiskFormat().jvectorFileFormatVersion()); + return compressor; + } // save (outside the synchronized block, this is io-bound not CPU) cv.write(writer, Version.current().onDiskFormat().jvectorFileFormatVersion()); - return writer.position(); + return null; // Don't need compressor in this case } static void writePqHeader(DataOutput writer, boolean unitVectors, CompressionType type) diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/CloseableReranker.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CloseableReranker.java index b85b33c81b55..0415b9eb4bcc 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/vector/CloseableReranker.java +++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CloseableReranker.java @@ -20,21 +20,21 @@ import java.io.Closeable; -import io.github.jbellis.jvector.graph.GraphIndex; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.types.VectorFloat; import org.apache.cassandra.io.util.FileUtils; /** - * An ExactScoreFunction that closes the underlying {@link GraphIndex.ScoringView} when closed. + * An ExactScoreFunction that closes the underlying {@link ImmutableGraphIndex.ScoringView} when closed. 
diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/CloseableReranker.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CloseableReranker.java
index b85b33c81b55..0415b9eb4bcc 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/vector/CloseableReranker.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CloseableReranker.java
@@ -20,21 +20,21 @@
 
 import java.io.Closeable;
 
-import io.github.jbellis.jvector.graph.GraphIndex;
+import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
 import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
 import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
 import io.github.jbellis.jvector.vector.types.VectorFloat;
 import org.apache.cassandra.io.util.FileUtils;
 
 /**
- * An ExactScoreFunction that closes the underlying {@link GraphIndex.ScoringView} when closed.
+ * An ExactScoreFunction that closes the underlying {@link ImmutableGraphIndex.ScoringView} when closed.
  */
 public class CloseableReranker implements ScoreFunction.ExactScoreFunction, Closeable
 {
-    private final GraphIndex.ScoringView view;
+    private final ImmutableGraphIndex.ScoringView view;
     private final ExactScoreFunction scoreFunction;
 
-    public CloseableReranker(VectorSimilarityFunction similarityFunction, VectorFloat<?> queryVector, GraphIndex.ScoringView view)
+    public CloseableReranker(VectorSimilarityFunction similarityFunction, VectorFloat<?> queryVector, ImmutableGraphIndex.ScoringView view)
     {
         this.view = view;
         this.scoreFunction = view.rerankerFor(queryVector, similarityFunction);
diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java
index 2443e272340c..206f90aa7304 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java
@@ -41,6 +41,7 @@
 import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
 import io.github.jbellis.jvector.graph.disk.feature.Feature;
 import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
+import io.github.jbellis.jvector.graph.disk.feature.FusedPQ;
 import io.github.jbellis.jvector.graph.disk.feature.InlineVectors;
 import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
 import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter;
@@ -67,6 +68,7 @@
 import net.openhft.chronicle.map.ChronicleMapBuilder;
 import org.agrona.collections.Int2ObjectHashMap;
 import org.apache.cassandra.concurrent.NamedThreadFactory;
+import org.apache.cassandra.config.CassandraRelevantProperties;
 import org.apache.cassandra.db.Keyspace;
 import org.apache.cassandra.db.marshal.VectorType;
 import org.apache.cassandra.exceptions.InvalidRequestException;
@@ -109,6 +111,8 @@ public class CompactionGraph implements Closeable, Accountable
     @VisibleForTesting
     public static int PQ_TRAINING_SIZE = ProductQuantization.MAX_PQ_TRAINING_SET_SIZE;
 
+    private static boolean ENABLE_FUSED = CassandraRelevantProperties.SAI_VECTOR_ENABLE_FUSED.getBoolean();
+
     private final VectorType.VectorSerializer serializer;
     private final VectorSimilarityFunction similarityFunction;
     private final ChronicleMap<VectorFloat<?>, CompactionVectorPostings> postingsMap;
@@ -225,12 +229,14 @@ else if (compressor instanceof BinaryQuantization)
 
     private OnDiskGraphIndexWriter createTermsWriter(OrdinalMapper ordinalMapper) throws IOException
     {
-        return new OnDiskGraphIndexWriter.Builder(builder.getGraph(), termsFile.toPath())
+        var writerBuilder = new OnDiskGraphIndexWriter.Builder(builder.getGraph(), termsFile.toPath())
                .withStartOffset(termsOffset)
                .with(new InlineVectors(dimension))
                .withVersion(Version.current().onDiskFormat().jvectorFileFormatVersion())
-               .withMapper(ordinalMapper)
-               .build();
+               .withMapper(ordinalMapper);
+        if (ENABLE_FUSED && compressor instanceof ProductQuantization && Version.current().onDiskFormat().jvectorFileFormatVersion() >= 6)
+            writerBuilder.with(new FusedPQ(context.getIndexWriterConfig().getAnnMaxDegree(), (ProductQuantization) compressor));
+        return writerBuilder.build();
     }
 
     @Override
@@ -372,7 +378,7 @@ public long addGraphNode(InsertionResult result)
     public SegmentMetadata.ComponentMetadataMap flush() throws IOException
     {
         // header is required to write the postings, but we need to recreate the writer after that with an accurate OrdinalMapper
-        writer.writeHeader();
+        writer.writeHeader(builder.getGraph().getView());
         writer.close();
 
         int nInProgress = builder.insertsInProgress();
@@ -408,7 +414,7 @@ public SegmentMetadata.ComponentMetadataMap flush() throws IOException
         var es = Executors.newSingleThreadExecutor(new NamedThreadFactory("CompactionGraphPostingsWriter"));
         long postingsLength;
         try (var indexHandle = perIndexComponents.get(IndexComponentType.TERMS_DATA).createIndexBuildTimeFileHandle();
-             var index = OnDiskGraphIndex.load(indexHandle::createReader, termsOffset))
+             var index = OnDiskGraphIndex.load(indexHandle::createReader, termsOffset, false))
         {
             var postingsFuture = es.submit(() -> {
                 // V2 doesn't support ONE_TO_MANY so force it to ZERO_OR_ONE_TO_MANY if necessary;
@@ -453,9 +459,18 @@ public SegmentMetadata.ComponentMetadataMap flush() throws IOException
 
             // write the graph edge lists and optionally fused adc features
             var start = System.nanoTime();
-            // Required becuase jvector 3 wrote the fused adc map here. We no longer write jvector 3, but we still
-            // write out the empty map.
-            writer.write(Map.of());
+            if (writer.getFeatureSet().contains(FeatureId.FUSED_PQ))
+            {
+                try (var view = builder.getGraph().getView())
+                {
+                    var supplier = Feature.singleStateFactory(FeatureId.FUSED_PQ, ordinal -> new FusedPQ.State(view, (PQVectors) compressedVectors, ordinal));
+                    writer.write(supplier);
+                }
+            }
+            else
+            {
+                writer.write(Map.of());
+            }
             SAICodecUtils.writeFooter(writer.getOutput(), writer.checksum());
             logger.info("Writing graph took {}ms", (System.nanoTime() - start) / 1_000_000);
             long termsLength = writer.getOutput().position() - termsOffset;
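Compaction differs from flush in that the vectors are already PQ-encoded by the time the graph is written, so the FUSED_PQ states are built from the existing PQVectors instead of re-encoding each vector. A sketch of that wiring, assuming the patch's types:

    import java.io.IOException;
    import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
    import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter;
    import io.github.jbellis.jvector.graph.disk.feature.Feature;
    import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
    import io.github.jbellis.jvector.graph.disk.feature.FusedPQ;
    import io.github.jbellis.jvector.quantization.PQVectors;

    final class CompactionFusedStates
    {
        static void writeFused(OnDiskGraphIndexWriter writer, ImmutableGraphIndex.View view, PQVectors pq) throws IOException
        {
            // One Feature.State per node, read from vectors compressed earlier in the compaction
            var supplier = Feature.singleStateFactory(FeatureId.FUSED_PQ, ordinal -> new FusedPQ.State(view, pq, ordinal));
            writer.write(supplier);
        }
    }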
diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorSourceModel.java b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorSourceModel.java
index bf08896591dd..bccfd8e1352a 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorSourceModel.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorSourceModel.java
@@ -31,18 +31,19 @@
 import static org.apache.cassandra.index.sai.disk.vector.VectorCompression.CompressionType.BINARY_QUANTIZATION;
 import static org.apache.cassandra.index.sai.disk.vector.VectorCompression.CompressionType.NONE;
 import static org.apache.cassandra.index.sai.disk.vector.VectorCompression.CompressionType.PRODUCT_QUANTIZATION;
-
 public enum VectorSourceModel
 {
-    ADA002((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25),
-    OPENAI_V3_SMALL((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.5),
-    OPENAI_V3_LARGE((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25),
-    BERT(COSINE, (dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.25), __ -> 1.0),
-    GECKO((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25),
-    NV_QA_4((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25),
-    COHERE_V3((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25),
-
-    OTHER(COSINE, VectorSourceModel::genericCompressionFor, VectorSourceModel::genericOverquery);
+    ADA002((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25, true),
+    OPENAI_V3_SMALL((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.5, true),
+    OPENAI_V3_LARGE((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25, true),
+    // BERT is not known to have unit length vectors in all cases
+    BERT(COSINE, (dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.25), __ -> 1.0, false),
+    GECKO((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25, true),
+    NV_QA_4((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25, false),
+    // Cohere does not officially say they have unit length vectors, but some users report that they do
+    COHERE_V3((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25, false),
+
+    OTHER(COSINE, VectorSourceModel::genericCompressionFor, VectorSourceModel::genericOverquery, false);
 
     /**
      * Default similarity function for this model.
@@ -58,18 +59,33 @@ public enum VectorSourceModel
      */
     public final Function<Integer, Double> overqueryProvider;
 
-    VectorSourceModel(Function<Integer, VectorCompression> compressionProvider, double overqueryFactor)
+    /**
+     * Indicates that the model is known to have unit length vectors. When false, the runtime checks per graph
+     * until a non-unit length vector is found.
+     */
+    private final boolean knownUnitLength;
+
+    VectorSourceModel(Function<Integer, VectorCompression> compressionProvider,
+                      double overqueryFactor,
+                      boolean knownUnitLength)
     {
-        this(DOT_PRODUCT, compressionProvider, __ -> overqueryFactor);
+        this(DOT_PRODUCT, compressionProvider, __ -> overqueryFactor, knownUnitLength);
     }
 
     VectorSourceModel(VectorSimilarityFunction defaultSimilarityFunction,
                       Function<Integer, VectorCompression> compressionProvider,
-                      Function<Integer, Double> overqueryProvider)
+                      Function<Integer, Double> overqueryProvider,
+                      boolean knownUnitLength)
     {
         this.defaultSimilarityFunction = defaultSimilarityFunction;
         this.compressionProvider = compressionProvider;
         this.overqueryProvider = overqueryProvider;
+        this.knownUnitLength = knownUnitLength;
+    }
+
+    public boolean hasKnownUnitLengthVectors()
+    {
+        return knownUnitLength;
     }
 
     public static VectorSourceModel fromString(String value)
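The new flag feeds the incremental check in CassandraOnHeapGraph.add(...): models marked true are trusted to produce unit-length vectors, while everything else is verified vector by vector until a counterexample shows up. The enum arguments above pin down the expected behavior (illustrative):

    import org.apache.cassandra.index.sai.disk.vector.VectorSourceModel;

    final class UnitLengthPolicy
    {
        public static void main(String[] args)
        {
            assert VectorSourceModel.ADA002.hasKnownUnitLengthVectors();
            assert !VectorSourceModel.BERT.hasKnownUnitLengthVectors();   // not unit length in all cases
            assert !VectorSourceModel.OTHER.hasKnownUnitLengthVectors();  // unknown models are verified at runtime
        }
    }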
diff --git a/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java b/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java
index 8fc9b37c17e8..748a6f29c364 100644
--- a/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java
@@ -43,7 +43,7 @@
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 
-public class VectorSiftSmallTest extends VectorTester
+public class VectorSiftSmallTest extends VectorTester.Versioned
 {
     private static final String DATASET = "siftsmall"; // change to "sift" for larger dataset. requires manual download
@@ -156,6 +156,7 @@ public void testCompaction() throws Throwable
             assertTrue("Pre-compaction recall is " + recall, recall > 0.975);
         }
 
+        compact();
         compact();
         for (int topK : List.of(1, 100))
         {
@@ -313,7 +314,7 @@ private void createTable()
     private void createIndex()
     {
         // we need a long timeout because we are adding many vectors
-        String index = createIndexAsync("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex' WITH OPTIONS = {'similarity_function' : 'euclidean'}");
+        String index = createIndexAsync("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex' WITH OPTIONS = {'similarity_function' : 'euclidean', 'enable_hierarchy': 'true'}");
         waitForIndexQueryable(KEYSPACE, index, 5, TimeUnit.MINUTES);
     }
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIteratorTest.java b/test/unit/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIteratorTest.java
index 22658da82ac7..ff025379bf52 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIteratorTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIteratorTest.java
@@ -19,10 +19,11 @@
 package org.apache.cassandra.index.sai.disk.vector;
 
 import java.util.NoSuchElementException;
+import java.util.function.Function;
 
 import org.junit.Test;
 
-import io.github.jbellis.jvector.graph.GraphIndex;
+import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
 import io.github.jbellis.jvector.graph.NodeQueue;
 import io.github.jbellis.jvector.graph.NodesIterator;
 import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
@@ -63,7 +64,7 @@ public void testBruteForceRowIdIteratorForEmptyPQAndTopKEqualsLimit()
         assertTrue(view.isClosed);
     }
 
-    private static class TestView implements GraphIndex.ScoringView
+    private static class TestView implements ImmutableGraphIndex.ScoringView
     {
         private boolean isClosed = false;
 
@@ -95,6 +96,12 @@ public NodesIterator getNeighborsIterator(int i, int i1)
             throw new UnsupportedOperationException();
         }
 
+        @Override
+        public void processNeighbors(int i, int i1, ScoreFunction scoreFunction, Function function, ImmutableGraphIndex.NeighborProcessor neighborProcessor)
+        {
+
+        }
+
         @Override
         public int size()
         {
@@ -102,7 +109,7 @@ public int size()
         }
 
         @Override
-        public GraphIndex.NodeAtLevel entryNode()
+        public ImmutableGraphIndex.NodeAtLevel entryNode()
         {
             throw new UnsupportedOperationException();
         }
@@ -112,5 +119,11 @@ public Bits liveNodes()
         {
             throw new UnsupportedOperationException();
         }
+
+        @Override
+        public boolean contains(int i, int i1)
+        {
+            return false;
+        }
     }
 }