Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,8 @@ public enum CassandraRelevantProperties
SAI_VECTOR_FLUSH_THRESHOLD_MAX_ROWS("cassandra.sai.vector_flush_threshold_max_rows", "-1"),
// Use non-positive value to disable it. Period in millis to trigger a flush for SAI vector memtable index.
SAI_VECTOR_FLUSH_PERIOD_IN_MILLIS("cassandra.sai.vector_flush_period_in_millis", "-1"),
// Use nvq when building graphs in compaction
SAI_VECTOR_USE_NVQ("cassandra.sai.vector_search.use_nvq", "true"),
/**
* Whether to disable auto-compaction
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,21 @@
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter;
import io.github.jbellis.jvector.graph.disk.OrdinalMapper;
import io.github.jbellis.jvector.graph.disk.feature.NVQ;
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
import io.github.jbellis.jvector.quantization.BQVectors;
import io.github.jbellis.jvector.quantization.BinaryQuantization;
import io.github.jbellis.jvector.quantization.MutableBQVectors;
import io.github.jbellis.jvector.quantization.MutableCompressedVectors;
import io.github.jbellis.jvector.quantization.MutablePQVectors;
import io.github.jbellis.jvector.quantization.NVQuantization;
import io.github.jbellis.jvector.quantization.PQVectors;
import io.github.jbellis.jvector.quantization.ProductQuantization;
import io.github.jbellis.jvector.quantization.VectorCompressor;
import io.github.jbellis.jvector.util.Accountable;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
Expand All @@ -67,6 +70,7 @@
import net.openhft.chronicle.map.ChronicleMapBuilder;
import org.agrona.collections.Int2ObjectHashMap;
import org.apache.cassandra.concurrent.NamedThreadFactory;
import org.apache.cassandra.config.CassandraRelevantProperties;
import org.apache.cassandra.db.Keyspace;
import org.apache.cassandra.db.marshal.VectorType;
import org.apache.cassandra.exceptions.InvalidRequestException;
Expand All @@ -87,6 +91,7 @@
import org.apache.cassandra.io.util.File;
import org.apache.cassandra.io.util.FileUtils;
import org.apache.cassandra.service.StorageService;
import org.apache.lucene.index.FloatVectorValues;

import static java.lang.Math.max;
import static java.lang.Math.min;
Expand Down Expand Up @@ -136,6 +141,8 @@ public class CompactionGraph implements Closeable, Accountable
private MutableCompressedVectors compressedVectors;
private GraphIndexBuilder builder;

private final VectorFloat<?> globalMean;

public CompactionGraph(IndexComponents.ForWrite perIndexComponents, VectorCompressor<?> compressor, boolean unitVectors, long keyCount, boolean allRowsHaveVectors) throws IOException
{
this.perIndexComponents = perIndexComponents;
Expand Down Expand Up @@ -203,6 +210,10 @@ else if (compressor instanceof BinaryQuantization)
logger.warn("Hierarchical graphs configured but node configured with V3OnDiskFormat.JVECTOR_VERSION {}. " +
"Skipping setting for {}", jvectorVersion, indexConfig.getIndexName());

// Allocate the running vector sum only when NVQ is enabled; null means full-precision inline vectors are written instead
globalMean = CassandraRelevantProperties.SAI_VECTOR_USE_NVQ.getBoolean() ? vts.createFloatVector(new float[dimension])
: null;

builder = new GraphIndexBuilder(bsp,
dimension,
indexConfig.getAnnMaxDegree(),
Expand All @@ -218,16 +229,16 @@ else if (compressor instanceof BinaryQuantization)
termsOffset = (termsFile.exists() ? termsFile.length() : 0)
+ SAICodecUtils.headerSize();
// placeholder writer, will be replaced at flush time when we finalize the index contents
writer = createTermsWriter(new OrdinalMapper.IdentityMapper(maxRowsInGraph));
writer = createTermsWriter(new OrdinalMapper.IdentityMapper(maxRowsInGraph), new InlineVectors(dimension));
writer.getOutput().seek(termsFile.length()); // position at the end of the previous segment before writing our own header
SAICodecUtils.writeHeader(SAICodecUtils.toLuceneOutput(writer.getOutput()));
}

private OnDiskGraphIndexWriter createTermsWriter(OrdinalMapper ordinalMapper) throws IOException
private OnDiskGraphIndexWriter createTermsWriter(OrdinalMapper ordinalMapper, Feature feature) throws IOException
{
return new OnDiskGraphIndexWriter.Builder(builder.getGraph(), termsFile.toPath())
.withStartOffset(termsOffset)
.with(new InlineVectors(dimension))
.with(feature)
.withVersion(Version.current().onDiskFormat().jvectorFileFormatVersion())
.withMapper(ordinalMapper)
.build();
Expand Down Expand Up @@ -335,7 +346,14 @@ public InsertionResult maybeAddVector(ByteBuffer term, int segmentRowId) throws
pqFinetuned = true;
}

writer.writeInline(ordinal, Feature.singleState(FeatureId.INLINE_VECTORS, new InlineVectors.State(vector)));
// When NVQ is enabled, we write the quantized vectors to disk later. When NVQ is not enabled, we write
// the full precision vectors to disk eagerly.
if (globalMean != null)
// Accumulate the vector into the running sum; it is scaled into the actual mean at flush time
VectorUtil.addInPlace(globalMean, vector);
else
writer.writeInline(ordinal, Feature.singleState(FeatureId.INLINE_VECTORS, new InlineVectors.State(vector)));

// Fill in any holes in the pqVectors (setZero has the side effect of increasing the count)
while (compressedVectors.count() < ordinal)
compressedVectors.setZero(compressedVectors.count());
Expand Down Expand Up @@ -448,14 +466,40 @@ public SegmentMetadata.ComponentMetadataMap flush() throws IOException
es.shutdown();
}

// Recreate the writer with the final ordinalMapper
writer = createTermsWriter(ordinalMapper.get());

NVQuantization nvq = null;
// Write the NVQ feature
if (globalMean != null)
{
VectorUtil.scale(globalMean, 1.0f / compressedVectors.count());
nvq = NVQuantization.compute(globalMean, 2);
writer = createTermsWriter(ordinalMapper.get(), new NVQ(nvq));
}
else
{
// Recreate the writer with the final ordinalMapper
writer = createTermsWriter(ordinalMapper.get(), new InlineVectors(dimension));
}

// write the graph edge lists and optionally fused adc features
var start = System.nanoTime();
// Required because jvector 3 wrote the fused adc map here. We no longer write jvector 3, but we still
// write out the empty map.
writer.write(Map.of());
if (writer.getFeatureSet().contains(FeatureId.NVQ_VECTORS))
{
try (var view = builder.getGraph().getView())
{
var supplier = Feature.singleStateFactory(FeatureId.NVQ_VECTORS, ordinal -> {
postingsMap.get()
nvq.quantize(view.getVector(ordinal));
new NVQ.State(view, (PQVectors) compressedVectors, ordinal)
});
writer.write(supplier);
}
}
else
{
writer.write(Map.of());
}

SAICodecUtils.writeFooter(writer.getOutput(), writer.checksum());
logger.info("Writing graph took {}ms", (System.nanoTime() - start) / 1_000_000);
long termsLength = writer.getOutput().position() - termsOffset;
Expand Down
Loading