diff --git a/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java b/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java index 5e5f0e46363..1b6b975c985 100644 --- a/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java +++ b/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java @@ -409,6 +409,8 @@ public enum CassandraRelevantProperties SAI_VECTOR_FLUSH_THRESHOLD_MAX_ROWS("cassandra.sai.vector_flush_threshold_max_rows", "-1"), // Use non-positive value to disable it. Period in millis to trigger a flush for SAI vector memtable index. SAI_VECTOR_FLUSH_PERIOD_IN_MILLIS("cassandra.sai.vector_flush_period_in_millis", "-1"), + // Use nvq when building graphs in compaction + SAI_VECTOR_USE_NVQ("cassandra.sai.vector_search.use_nvq", "true"), /** * Whether to disable auto-compaction */ diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java index c2339861037..8c872a3c3f8 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java +++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java @@ -45,18 +45,21 @@ import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter; import io.github.jbellis.jvector.graph.disk.OrdinalMapper; +import io.github.jbellis.jvector.graph.disk.feature.NVQ; import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; import io.github.jbellis.jvector.quantization.BQVectors; import io.github.jbellis.jvector.quantization.BinaryQuantization; import io.github.jbellis.jvector.quantization.MutableBQVectors; import io.github.jbellis.jvector.quantization.MutableCompressedVectors; import io.github.jbellis.jvector.quantization.MutablePQVectors; +import io.github.jbellis.jvector.quantization.NVQuantization; import io.github.jbellis.jvector.quantization.PQVectors; import 
io.github.jbellis.jvector.quantization.ProductQuantization; import io.github.jbellis.jvector.quantization.VectorCompressor; import io.github.jbellis.jvector.util.Accountable; import io.github.jbellis.jvector.util.RamUsageEstimator; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorUtil; import io.github.jbellis.jvector.vector.VectorizationProvider; import io.github.jbellis.jvector.vector.types.VectorFloat; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; @@ -67,6 +70,7 @@ import net.openhft.chronicle.map.ChronicleMapBuilder; import org.agrona.collections.Int2ObjectHashMap; import org.apache.cassandra.concurrent.NamedThreadFactory; +import org.apache.cassandra.config.CassandraRelevantProperties; import org.apache.cassandra.db.Keyspace; import org.apache.cassandra.db.marshal.VectorType; import org.apache.cassandra.exceptions.InvalidRequestException; @@ -87,6 +91,7 @@ import org.apache.cassandra.io.util.File; import org.apache.cassandra.io.util.FileUtils; import org.apache.cassandra.service.StorageService; +import org.apache.lucene.index.FloatVectorValues; import static java.lang.Math.max; import static java.lang.Math.min; @@ -136,6 +141,8 @@ public class CompactionGraph implements Closeable, Accountable private MutableCompressedVectors compressedVectors; private GraphIndexBuilder builder; + private final VectorFloat globalMean; + public CompactionGraph(IndexComponents.ForWrite perIndexComponents, VectorCompressor compressor, boolean unitVectors, long keyCount, boolean allRowsHaveVectors) throws IOException { this.perIndexComponents = perIndexComponents; @@ -203,6 +210,10 @@ else if (compressor instanceof BinaryQuantization) logger.warn("Hierarchical graphs configured but node configured with V3OnDiskFormat.JVECTOR_VERSION {}. 
" + "Skipping setting for {}", jvectorVersion, indexConfig.getIndexName()); + // Running sum of all vectors added to the graph; used at flush time to derive NVQ's global mean. Null when NVQ is disabled. + globalMean = CassandraRelevantProperties.SAI_VECTOR_USE_NVQ.getBoolean() ? vts.createFloatVector(new float[dimension]) + : null; + builder = new GraphIndexBuilder(bsp, dimension, indexConfig.getAnnMaxDegree(), @@ -218,16 +229,16 @@ else if (compressor instanceof BinaryQuantization) termsOffset = (termsFile.exists() ? termsFile.length() : 0) + SAICodecUtils.headerSize(); // placeholder writer, will be replaced at flush time when we finalize the index contents - writer = createTermsWriter(new OrdinalMapper.IdentityMapper(maxRowsInGraph)); + writer = createTermsWriter(new OrdinalMapper.IdentityMapper(maxRowsInGraph), new InlineVectors(dimension)); writer.getOutput().seek(termsFile.length()); // position at the end of the previous segment before writing our own header SAICodecUtils.writeHeader(SAICodecUtils.toLuceneOutput(writer.getOutput()), perIndexComponents.version()); } - private OnDiskGraphIndexWriter createTermsWriter(OrdinalMapper ordinalMapper) throws IOException + private OnDiskGraphIndexWriter createTermsWriter(OrdinalMapper ordinalMapper, Feature feature) throws IOException { return new OnDiskGraphIndexWriter.Builder(builder.getGraph(), termsFile.toPath()) .withStartOffset(termsOffset) - .with(new InlineVectors(dimension)) + .with(feature) .withVersion(context.version().onDiskFormat().jvectorFileFormatVersion()) .withMapper(ordinalMapper) .build(); } @@ -335,7 +346,14 @@ public InsertionResult maybeAddVector(ByteBuffer term, int segmentRowId) throws pqFinetuned = true; } - writer.writeInline(ordinal, Feature.singleState(FeatureId.INLINE_VECTORS, new InlineVectors.State(vector))); + // When NVQ is enabled, we write the quantized vectors to disk later. When NVQ is not enabled, we write + // the full precision vectors to disk eagerly. 
+ if (globalMean != null) + // Update the global mean + VectorUtil.addInPlace(globalMean, vector); + else + writer.writeInline(ordinal, Feature.singleState(FeatureId.INLINE_VECTORS, new InlineVectors.State(vector))); + // Fill in any holes in the pqVectors (setZero has the side effect of increasing the count) while (compressedVectors.count() < ordinal) compressedVectors.setZero(compressedVectors.count()); @@ -449,14 +467,40 @@ public SegmentMetadata.ComponentMetadataMap flush() throws IOException es.shutdown(); } - // Recreate the writer with the final ordinalMapper - writer = createTermsWriter(ordinalMapper.get()); + + NVQuantization nvq = null; + // Turn the running vector sum into the global mean, build the NVQ quantizer from it, and attach the NVQ feature + if (globalMean != null) + { + VectorUtil.scale(globalMean, 1.0f / compressedVectors.count()); + nvq = NVQuantization.compute(globalMean, 2); + writer = createTermsWriter(ordinalMapper.get(), new NVQ(nvq)); + } + else + { + // Recreate the writer with the final ordinalMapper + writer = createTermsWriter(ordinalMapper.get(), new InlineVectors(dimension)); + } // write the graph edge lists and optionally fused adc features var start = System.nanoTime(); - // Required becuase jvector 3 wrote the fused adc map here. We no longer write jvector 3, but we still - write out the empty map. - writer.write(Map.of()); + if (writer.getFeatureSet().contains(FeatureId.NVQ_VECTORS)) + { + try (var view = builder.getGraph().getView()) + { + // Lambdas may only capture (effectively) final locals, so snapshot nvq before using it below + final var nvqFinal = nvq; + // NVQ-encode each graph vector on the fly as the writer requests it + var supplier = Feature.singleStateFactory(FeatureId.NVQ_VECTORS, ordinal -> new NVQ.State(nvqFinal.encode(view.getVector(ordinal)))); + writer.write(supplier); + } + } + else + { + writer.write(Map.of()); + } + SAICodecUtils.writeFooter(writer.getOutput(), writer.checksum()); logger.info("Writing graph took {}ms", (System.nanoTime() - start) / 1_000_000); long termsLength = writer.getOutput().position() - termsOffset;