Skip to content

Commit

Permalink
HNSW BP reordering (apache#14097)
Browse files Browse the repository at this point in the history
  • Loading branch information
msokolov authored Jan 7, 2025
1 parent 72d93de commit 02a09f6
Show file tree
Hide file tree
Showing 34 changed files with 1,558 additions and 147 deletions.
2 changes: 1 addition & 1 deletion lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ API Changes

New Features
---------------------
(No changes)
* GITHUB#14097: Binary partitioning merge policy over float-valued vector field. (Mike Sokolov)

Improvements
---------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,11 @@ public int entryNode() {
throw new UnsupportedOperationException();
}

// Stub view: this graph implementation does not expose its fanout parameter,
// matching the other unsupported accessors in this class (e.g. entryNode above).
@Override
public int maxConn() {
  throw new UnsupportedOperationException();
}

@Override
public NodesIterator getNodesOnLevel(int level) {
throw new UnsupportedOperationException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ public int entryNode() {
throw new UnsupportedOperationException();
}

// Returns M, the per-node connection limit, from the captured field.
@Override
public int maxConn() {
  return maxConn;
}

@Override
public NodesIterator getNodesOnLevel(int level) {
throw new UnsupportedOperationException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,11 @@ public int entryNode() {
return entryNode;
}

/**
 * Returns M, the maximum number of connections for a node, recovered from the
 * per-node storage size: bytesForConns holds M neighbor ids plus one count slot,
 * each {@code Integer.BYTES} wide — presumably matching the writer's layout
 * (TODO confirm against the corresponding writer).
 */
@Override
public int maxConn() {
  final int intsPerNode = (int) bytesForConns / Integer.BYTES;
  return intsPerNode - 1;
}

@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ public int entryNode() {
return entryNode;
}

// Returns M, the per-node connection limit, from the captured field.
@Override
public int maxConn() {
  return maxConn;
}

@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,11 @@ public int entryNode() {
return entryNode;
}

/**
 * Returns M, the maximum number of connections for a node, derived from the
 * per-node storage footprint: one count slot plus M neighbor ids, each
 * {@code Integer.BYTES} wide — presumably matching the writer's layout
 * (TODO confirm against the corresponding writer).
 */
@Override
public int maxConn() {
  final int intsPerNode = (int) bytesForConns / Integer.BYTES;
  return intsPerNode - 1;
}

@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ private static final class OffHeapHnswGraph extends HnswGraph {
final int size;
final long bytesForConns;
final long bytesForConns0;
final int maxConn;

int arcCount;
int arcUpTo;
Expand All @@ -463,6 +464,7 @@ private static final class OffHeapHnswGraph extends HnswGraph {
this.bytesForConns = Math.multiplyExact(Math.addExact(entry.M, 1L), Integer.BYTES);
this.bytesForConns0 =
Math.multiplyExact(Math.addExact(Math.multiplyExact(entry.M, 2L), 1), Integer.BYTES);
maxConn = entry.M;
}

@Override
Expand Down Expand Up @@ -501,6 +503,11 @@ public int numLevels() {
return numLevels;
}

// Returns M (entry.M, captured in the constructor) — the per-node connection limit.
@Override
public int maxConn() {
  return maxConn;
}

@Override
public int entryNode() {
return entryNode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,11 @@ public int entryNode() throws IOException {
return entryNode;
}

/**
 * Returns M, the maximum number of connections for a node. The neighbor scratch
 * buffer is presumably sized for the level-0 connection limit (2*M), so M is half
 * its length — TODO(review): confirm the buffer sizing against the writer.
 */
@Override
public int maxConn() {
  final int bufferLength = currentNeighborsBuffer.length;
  return bufferLength / 2;
}

@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,11 @@ public int numLevels() {
return graph.numLevels();
}

/** Returns M by delegating to the wrapped graph. */
@Override
public int maxConn() {
  final int delegateMaxConn = graph.maxConn();
  return delegateMaxConn;
}

@Override
public int entryNode() {
throw new UnsupportedOperationException("Not supported on a mock graph");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,11 @@ public int entryNode() {
throw new UnsupportedOperationException("Not supported on a mock graph");
}

// Mock graph used in tests: the fanout parameter is deliberately unavailable,
// consistent with the other unsupported accessors (e.g. entryNode above).
@Override
public int maxConn() {
  throw new UnsupportedOperationException("Not supported on a mock graph");
}

@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ protected KnnVectorsReader() {}
* TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
*
* <p>The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link
* FieldInfo}. The return value is never {@code null}.
* FieldInfo}.
*
* @param field the vector field to search
* @param target the vector-valued query
Expand Down Expand Up @@ -103,7 +103,7 @@ public abstract void search(
* TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
*
* <p>The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link
* FieldInfo}. The return value is never {@code null}.
* FieldInfo}.
*
* @param field the vector field to search
* @param target the vector-valued query
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,8 @@ public void seek(int level, int targetOrd) throws IOException {
level == 0
? targetOrd
: Arrays.binarySearch(nodesByLevel[level], 0, nodesByLevel[level].length, targetOrd);
assert targetIndex >= 0;
assert targetIndex >= 0
: "seek level=" + level + " target=" + targetOrd + " not found: " + targetIndex;
// unsafe; no bounds checking
dataIn.seek(graphLevelNodeOffsets.get(targetIndex + graphLevelNodeIndexOffsets[level]));
arcCount = dataIn.readVInt();
Expand Down Expand Up @@ -526,6 +527,11 @@ public int numLevels() throws IOException {
return numLevels;
}

/**
 * Returns M, the maximum number of connections for a node. The neighbor scratch
 * buffer is presumably sized for the level-0 connection limit (2*M), so M is half
 * its length — TODO(review): confirm the buffer sizing against the writer.
 *
 * <p>Uses {@code / 2} rather than {@code >> 1} for readability and consistency
 * with the equivalent reader implementations (identical results for the
 * non-negative array length).
 */
@Override
public int maxConn() {
  return currentNeighborsBuffer.length / 2;
}

@Override
public int entryNode() throws IOException {
return entryNode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ public Lucene99HnswVectorsWriter(
this.numMergeWorkers = numMergeWorkers;
this.mergeExec = mergeExec;
segmentWriteState = state;

String metaFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION);
Expand Down Expand Up @@ -293,6 +292,11 @@ public int numLevels() {
return graph.numLevels();
}

/** Returns M by delegating to the wrapped graph. */
@Override
public int maxConn() {
  final int delegateMaxConn = graph.maxConn();
  return delegateMaxConn;
}

@Override
public int entryNode() {
throw new UnsupportedOperationException("Not supported on a mock graph");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
Expand All @@ -41,6 +42,7 @@
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph;

/**
* Enables per field numeric vector support.
Expand Down Expand Up @@ -189,7 +191,7 @@ public long ramBytesUsed() {
}

/** VectorReader that can wrap multiple delegate readers, selected by field. */
public static class FieldsReader extends KnnVectorsReader {
public static class FieldsReader extends KnnVectorsReader implements HnswGraphProvider {

private final IntObjectHashMap<KnnVectorsReader> fields = new IntObjectHashMap<>();
private final FieldInfos fieldInfos;
Expand Down Expand Up @@ -322,6 +324,17 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits
reader.search(field, target, knnCollector, acceptDocs);
}

/**
 * Returns the HNSW graph for {@code field} from the per-field delegate reader,
 * or {@code null} when that reader is not an {@link HnswGraphProvider}.
 */
@Override
public HnswGraph getGraph(String field) throws IOException {
  final FieldInfo fieldInfo = fieldInfos.fieldInfo(field);
  final KnnVectorsReader fieldReader = fields.get(fieldInfo.number);
  if (!(fieldReader instanceof HnswGraphProvider)) {
    // Delegate cannot expose a graph for this field.
    return null;
  }
  return ((HnswGraphProvider) fieldReader).getGraph(field);
}

@Override
public void close() throws IOException {
List<KnnVectorsReader> readers = new ArrayList<>(fields.size());
Expand Down
15 changes: 10 additions & 5 deletions lucene/core/src/java/org/apache/lucene/index/IndexSorter.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,16 @@
/**
* Handles how documents should be sorted in an index, both within a segment and between segments.
*
* <p>Implementers must provide the following methods: {@link #getDocComparator(LeafReader,int)} -
* an object that determines how documents within a segment are to be sorted {@link
* #getComparableProviders(List)} - an array of objects that return a sortable long value per
* document and segment {@link #getProviderName()} - the SPI-registered name of a {@link
* SortFieldProvider} to serialize the sort
* <p>Implementers must provide the following methods:
*
* <ul>
* <li>{@link #getDocComparator(LeafReader,int)} - an object that determines how documents within
* a segment are to be sorted
* <li>{@link #getComparableProviders(List)} - an array of objects that return a sortable long
* value per document and segment
* <li>{@link #getProviderName()} - the SPI-registered name of a {@link SortFieldProvider} to
* serialize the sort
* </ul>
*
* <p>The companion {@link SortFieldProvider} should be registered with SPI via {@code
* META-INF/services}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,10 @@ protected HnswBuilder createBuilder(KnnVectorValues mergedVectorValues, int maxO
} else {
initializedNodes = new FixedBitSet(maxOrd);
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorValues, initializedNodes);
graph =
InitializedHnswGraphBuilder.initGraph(M, initializerGraph, oldToNewOrdinalMap, maxOrd);
graph = InitializedHnswGraphBuilder.initGraph(initializerGraph, oldToNewOrdinalMap, maxOrd);
}
}
return new HnswConcurrentMergeBuilder(
taskExecutor, numWorker, scorerSupplier, M, beamWidth, graph, initializedNodes);
taskExecutor, numWorker, scorerSupplier, beamWidth, graph, initializedNodes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ public HnswConcurrentMergeBuilder(
TaskExecutor taskExecutor,
int numWorker,
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
OnHeapHnswGraph hnsw,
BitSet initializedNodes)
Expand All @@ -62,7 +61,6 @@ public HnswConcurrentMergeBuilder(
workers[i] =
new ConcurrentMergeWorker(
scorerSupplier.copy(),
M,
beamWidth,
HnswGraphBuilder.randSeed,
hnsw,
Expand Down Expand Up @@ -149,7 +147,6 @@ private static final class ConcurrentMergeWorker extends HnswGraphBuilder {

private ConcurrentMergeWorker(
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
long seed,
OnHeapHnswGraph hnsw,
Expand All @@ -159,7 +156,6 @@ private ConcurrentMergeWorker(
throws IOException {
super(
scorerSupplier,
M,
beamWidth,
seed,
hnsw,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ public int maxNodeId() {
/** Returns the number of levels of the graph */
public abstract int numLevels() throws IOException;

/** Returns M, the maximum number of connections for a node. */
public abstract int maxConn() throws IOException;

/** Returns graph's entry point on the top level * */
public abstract int entryNode() throws IOException;

Expand Down Expand Up @@ -118,6 +121,11 @@ public int numLevels() {
return 0;
}

// This stub graph reports no connections, consistent with numLevels() returning 0 above.
@Override
public int maxConn() {
  return 0;
}

@Override
public int entryNode() {
return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,14 @@ public static HnswGraphBuilder create(
protected HnswGraphBuilder(
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize)
throws IOException {
this(scorerSupplier, M, beamWidth, seed, new OnHeapHnswGraph(M, graphSize));
this(scorerSupplier, beamWidth, seed, new OnHeapHnswGraph(M, graphSize));
}

protected HnswGraphBuilder(
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
long seed,
OnHeapHnswGraph hnsw)
RandomVectorScorerSupplier scorerSupplier, int beamWidth, long seed, OnHeapHnswGraph hnsw)
throws IOException {
this(
scorerSupplier,
M,
beamWidth,
seed,
hnsw,
Expand All @@ -125,29 +120,26 @@ protected HnswGraphBuilder(
* ordinals, using the given hyperparameter settings, and returns the resulting graph.
*
* @param scorerSupplier a supplier to create vector scorer from ordinals.
* @param M – graph fanout parameter used to calculate the maximum number of connections a node
* can have – M on upper layers, and M * 2 on the lowest level.
* @param beamWidth the size of the beam search to use when finding nearest neighbors.
* @param seed the seed for a random number generator used during graph construction. Provide this
* to ensure repeatable construction.
* @param hnsw the graph to build, can be previously initialized
*/
protected HnswGraphBuilder(
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
long seed,
OnHeapHnswGraph hnsw,
HnswLock hnswLock,
HnswGraphSearcher graphSearcher)
throws IOException {
if (M <= 0) {
if (hnsw.maxConn() <= 0) {
throw new IllegalArgumentException("M (max connections) must be positive");
}
if (beamWidth <= 0) {
throw new IllegalArgumentException("beamWidth must be positive");
}
this.M = M;
this.M = hnsw.maxConn();
this.scorerSupplier =
Objects.requireNonNull(scorerSupplier, "scorer supplier must not be null");
// normalization factor for level generation; currently not configurable
Expand Down
Loading

0 comments on commit 02a09f6

Please sign in to comment.