From 133861e7885989ffbc509bd164d4de0570273f47 Mon Sep 17 00:00:00 2001 From: Umesh Madan Date: Thu, 6 Mar 2025 17:09:11 -0800 Subject: [PATCH] knowpro: Message Indexing (#796) * Message Indexing * Refactored and re-arranged fuzzy/embedding indexing for better robustness. * Return partial results to allow restarts * Cleaner batch behavior * Classic RAG command * Setting up for persisting message index serialization. --- ts/examples/chat/src/memory/knowproMemory.ts | 199 ++++++---- ts/examples/chat/src/memory/knowproPrinter.ts | 7 + ts/packages/knowPro/src/conversationThread.ts | 5 +- ts/packages/knowPro/src/fuzzyIndex.ts | 360 ++++++++++++++---- ts/packages/knowPro/src/import.ts | 10 +- ts/packages/knowPro/src/index.ts | 1 + ts/packages/knowPro/src/interfaces.ts | 24 ++ ts/packages/knowPro/src/messageIndex.ts | 132 +++++++ ts/packages/knowPro/src/secondaryIndexes.ts | 4 +- ts/packages/knowPro/src/textLocationIndex.ts | 148 +++---- .../memory/conversation/src/importPodcast.ts | 25 +- ts/packages/typeagent/src/lib/array.ts | 56 +++ 12 files changed, 736 insertions(+), 235 deletions(-) create mode 100644 ts/packages/knowPro/src/messageIndex.ts diff --git a/ts/examples/chat/src/memory/knowproMemory.ts b/ts/examples/chat/src/memory/knowproMemory.ts index d19b845fa..4fd137a1d 100644 --- a/ts/examples/chat/src/memory/knowproMemory.ts +++ b/ts/examples/chat/src/memory/knowproMemory.ts @@ -69,6 +69,7 @@ export async function createKnowproCommands( commands.kpSearchTerms = searchTerms; commands.kpSearchV1 = searchV1; commands.kpSearch = search; + commands.kpPodcastRag = podcastRag; commands.kpEntities = entities; commands.kpPodcastBuildIndex = podcastBuildIndex; commands.kpPodcastBuildMessageIndex = podcastBuildMessageIndex; @@ -454,7 +455,7 @@ export async function createKnowproCommands( function searchDef(): CommandMetadata { return { description: - "Search using natural language and knowlege-processor search filters", + "Search using natural language and old knowlege-processor search filters", args: { query: arg("Search query"), }, @@ -560,88 +561,37 @@ export async function createKnowproCommands( } } - function createSearchGroup( - termArgs: string[], - namedArgs: NamedArgs, - commandDef: CommandMetadata, - andTerms: boolean = false, - ): kp.SearchTermGroup { - const searchTerms = parseQueryTerms(termArgs); - const propertyTerms = propertyTermsFromNamedArgs(namedArgs, commandDef); + function ragDef(): CommandMetadata { return { - booleanOp: andTerms ? "and" : "or", - terms: [...searchTerms, ...propertyTerms], + description: "Classic rag", + args: { + query: arg("Search query"), + }, + options: { + maxToDisplay: argNum("Maximum matches to display", 25), + }, }; } - - function propertyTermsFromNamedArgs( - namedArgs: NamedArgs, - commandDef: CommandMetadata, - ): kp.PropertySearchTerm[] { - return createPropertyTerms(namedArgs, commandDef); - } - - function createPropertyTerms( - namedArgs: NamedArgs, - commandDef: CommandMetadata, - nameFilter?: (name: string) => boolean, - ): kp.PropertySearchTerm[] { - const keyValues = keyValuesFromNamedArgs(namedArgs, commandDef); - const propertyNames = nameFilter - ? Object.keys(keyValues).filter(nameFilter) - : Object.keys(keyValues); - const propertySearchTerms: kp.PropertySearchTerm[] = []; - for (const propertyName of propertyNames) { - const allValues = splitTermValues(keyValues[propertyName]); - for (const value of allValues) { - propertySearchTerms.push( - kp.createPropertySearchTerm(propertyName, value), - ); - } + commands.kpPodcastRag.metadata = ragDef(); + async function podcastRag(args: string[]): Promise { + if (!ensureConversationLoaded()) { + return; } - return propertySearchTerms; - } - - function whenFilterFromNamedArgs( - namedArgs: NamedArgs, - commandDef: CommandMetadata, - ): kp.WhenFilter { - let filter: kp.WhenFilter = { - knowledgeType: namedArgs.ktype, - }; - const conv: kp.IConversation | undefined = - context.podcast ?? context.images; - const dateRange = kp.getTimeRangeForConversation(conv!); - if (dateRange) { - let startDate: Date | undefined; - let endDate: Date | undefined; - // Did they provide an explicit date range? - if (namedArgs.startDate || namedArgs.endDate) { - startDate = argToDate(namedArgs.startDate) ?? dateRange.start; - endDate = argToDate(namedArgs.endDate) ?? dateRange.end; - } else { - // They may have provided a relative date range - if (namedArgs.startMinute >= 0) { - startDate = dateTime.addMinutesToDate( - dateRange.start, - namedArgs.startMinute, - ); - } - if (namedArgs.endMinute > 0) { - endDate = dateTime.addMinutesToDate( - dateRange.start, - namedArgs.endMinute, - ); - } - } - if (startDate) { - filter.dateRange = { - start: startDate, - end: endDate, - }; - } + const messageIndex = + context.conversation?.secondaryIndexes?.messageIndex; + if (!messageIndex) { + context.printer.writeError( + "No message text index. Run kpPodcastBuildMessageIndex", + ); + return; } - return filter; + const namedArgs = parseNamedArguments(args, ragDef()); + const matches = await messageIndex.lookupMessages(namedArgs.query); + context.printer.writeScoredMessages( + matches, + context.conversation?.messages!, + namedArgs.maxToDisplay, + ); } function entitiesDef(): CommandMetadata { @@ -728,6 +678,9 @@ export async function createKnowproCommands( commands.kpPodcastBuildMessageIndex.metadata = podcastBuildMessageIndexDef(); async function podcastBuildMessageIndex(args: string[]): Promise { + if (!ensureConversationLoaded()) { + return; + } const namedArgs = parseNamedArguments( args, podcastBuildMessageIndexDef(), @@ -736,7 +689,7 @@ export async function createKnowproCommands( `Indexing ${context.conversation?.messages.length} messages`, ); let progress = new ProgressBar(context.printer, namedArgs.maxMessages); - await context.podcast?.buildMessageIndex( + const result = await context.podcast!.buildMessageIndex( createIndexingEventHandler( context, progress, @@ -745,8 +698,12 @@ export async function createKnowproCommands( namedArgs.batchSize, ); progress.complete(); + context.printer.writeListIndexingResult(result); } + //------------------------- + // Index Image Building + //-------------------------- function imageCollectionBuildIndexDef(): CommandMetadata { return { description: "Build image collection index", @@ -789,6 +746,90 @@ export async function createKnowproCommands( End COMMANDS ------------*/ + function createSearchGroup( + termArgs: string[], + namedArgs: NamedArgs, + commandDef: CommandMetadata, + andTerms: boolean = false, + ): kp.SearchTermGroup { + const searchTerms = parseQueryTerms(termArgs); + const propertyTerms = propertyTermsFromNamedArgs(namedArgs, commandDef); + return { + booleanOp: andTerms ? "and" : "or", + terms: [...searchTerms, ...propertyTerms], + }; + } + + function propertyTermsFromNamedArgs( + namedArgs: NamedArgs, + commandDef: CommandMetadata, + ): kp.PropertySearchTerm[] { + return createPropertyTerms(namedArgs, commandDef); + } + + function createPropertyTerms( + namedArgs: NamedArgs, + commandDef: CommandMetadata, + nameFilter?: (name: string) => boolean, + ): kp.PropertySearchTerm[] { + const keyValues = keyValuesFromNamedArgs(namedArgs, commandDef); + const propertyNames = nameFilter + ? Object.keys(keyValues).filter(nameFilter) + : Object.keys(keyValues); + const propertySearchTerms: kp.PropertySearchTerm[] = []; + for (const propertyName of propertyNames) { + const allValues = splitTermValues(keyValues[propertyName]); + for (const value of allValues) { + propertySearchTerms.push( + kp.createPropertySearchTerm(propertyName, value), + ); + } + } + return propertySearchTerms; + } + + function whenFilterFromNamedArgs( + namedArgs: NamedArgs, + commandDef: CommandMetadata, + ): kp.WhenFilter { + let filter: kp.WhenFilter = { + knowledgeType: namedArgs.ktype, + }; + const conv: kp.IConversation | undefined = + context.podcast ?? context.images; + const dateRange = kp.getTimeRangeForConversation(conv!); + if (dateRange) { + let startDate: Date | undefined; + let endDate: Date | undefined; + // Did they provide an explicit date range? + if (namedArgs.startDate || namedArgs.endDate) { + startDate = argToDate(namedArgs.startDate) ?? dateRange.start; + endDate = argToDate(namedArgs.endDate) ?? dateRange.end; + } else { + // They may have provided a relative date range + if (namedArgs.startMinute >= 0) { + startDate = dateTime.addMinutesToDate( + dateRange.start, + namedArgs.startMinute, + ); + } + if (namedArgs.endMinute > 0) { + endDate = dateTime.addMinutesToDate( + dateRange.start, + namedArgs.endMinute, + ); + } + } + if (startDate) { + filter.dateRange = { + start: startDate, + end: endDate, + }; + } + } + return filter; + } + function ensureConversationLoaded(): kp.IConversation | undefined { if (context.conversation) { return context.conversation; diff --git a/ts/examples/chat/src/memory/knowproPrinter.ts b/ts/examples/chat/src/memory/knowproPrinter.ts index 26d79e1f9..c9f1e93a2 100644 --- a/ts/examples/chat/src/memory/knowproPrinter.ts +++ b/ts/examples/chat/src/memory/knowproPrinter.ts @@ -382,6 +382,13 @@ export class KnowProPrinter extends ChatPrinter { return this; } + public writeListIndexingResult(result: kp.ListIndexingResult) { + this.writeLine(`Indexed ${result.numberCompleted} items`); + if (result.error) { + this.writeError(result.error); + } + } + public writeSearchFilter( action: knowLib.conversation.GetAnswerWithTermsActionV2, ) { diff --git a/ts/packages/knowPro/src/conversationThread.ts b/ts/packages/knowPro/src/conversationThread.ts index 96cb5575b..b3ce8aa4b 100644 --- a/ts/packages/knowPro/src/conversationThread.ts +++ b/ts/packages/knowPro/src/conversationThread.ts @@ -13,6 +13,7 @@ import { TextEmbeddingIndex, TextEmbeddingIndexSettings, } from "./fuzzyIndex.js"; +import { NormalizedEmbedding } from "typeagent"; export interface IConversationThreadData { threads?: IThreadDataItem[] | undefined; @@ -91,13 +92,15 @@ export class ConversationThreads implements IConversationThreads { if (data.threads) { this.threads = []; this.embeddingIndex.clear(); + const embeddings: NormalizedEmbedding[] = []; for (let i = 0; i < data.threads.length; ++i) { this.threads.push(data.threads[i].thread); const embedding = deserializeEmbedding( data.threads[i].embedding, ); - this.embeddingIndex.add(embedding); + embeddings.push(embedding); } + this.embeddingIndex.deserialize(embeddings); } } } diff --git a/ts/packages/knowPro/src/fuzzyIndex.ts b/ts/packages/knowPro/src/fuzzyIndex.ts index e1c670027..a9492eeed 100644 --- a/ts/packages/knowPro/src/fuzzyIndex.ts +++ b/ts/packages/knowPro/src/fuzzyIndex.ts @@ -15,95 +15,91 @@ import { openai, TextEmbeddingModel } from "aiclient"; import * as levenshtein from "fast-levenshtein"; import { createEmbeddingCache } from "knowledge-processor"; import { Scored } from "./common.js"; -import { IndexingEventHandlers } from "./interfaces.js"; +import { ListIndexingResult, IndexingEventHandlers } from "./interfaces.js"; +import { error, Result, success } from "typechat"; -export class TextEmbeddingIndex { +export class EmbeddingIndex { private embeddings: NormalizedEmbedding[]; - constructor(public settings: TextEmbeddingIndexSettings) { - this.embeddings = []; + constructor(embeddings?: NormalizedEmbedding[]) { + this.embeddings = embeddings ?? []; } public get size(): number { return this.embeddings.length; } - public async addText(texts: string | string[]): Promise { - if (Array.isArray(texts)) { - const embeddings = await generateTextEmbeddingsWithRetry( - this.settings.embeddingModel, - texts, - ); + public get(pos: number): NormalizedEmbedding { + return this.embeddings[pos]; + } + + public push(embeddings: NormalizedEmbedding | NormalizedEmbedding[]): void { + if (Array.isArray(embeddings)) { this.embeddings.push(...embeddings); } else { - const embedding = await generateEmbedding( - this.settings.embeddingModel, - texts, - ); - this.embeddings.push(embedding); + this.embeddings.push(embeddings); } } - public async addTextBatch( - textToIndex: string[], - eventHandler?: IndexingEventHandlers, - batchSize?: number, - ): Promise { - for (const batch of getIndexingBatches( - textToIndex, - batchSize ?? this.settings.batchSize, - )) { - if ( - eventHandler?.onEmbeddingsCreated && - !eventHandler.onEmbeddingsCreated( - textToIndex, - batch.values, - batch.startAt, - ) - ) { - break; - } - // TODO: return IndexingResult to track how far we got before a non-recoverable failure - await this.addText(batch.values); + public insertAt( + index: number, + embeddings: NormalizedEmbedding | NormalizedEmbedding[], + ): void { + if (Array.isArray(embeddings)) { + this.embeddings.splice(index, 0, ...embeddings); + } else { + this.embeddings.splice(index, 0, embeddings); } } - public get(pos: number): NormalizedEmbedding { - return this.embeddings[pos]; - } - - public add(embedding: NormalizedEmbedding): void { - this.embeddings.push(embedding); - } - - public async getIndexesOfNearest( - text: string | NormalizedEmbedding, + public getIndexesOfNearest( + embedding: NormalizedEmbedding, maxMatches?: number, minScore?: number, - ): Promise { - const textEmbedding = await generateEmbedding( - this.settings.embeddingModel, - text, + ): Scored[] { + return this.indexesOfNearest( + this.embeddings, + embedding, + maxMatches, + minScore, ); - return this.indexesOfNearestText(textEmbedding, maxMatches, minScore); } - public async getIndexesOfNearestMultiple( - textArray: string[], + /** + * Finds the indexes of the nearest embeddings within a specified subset. + * + * This function searches for the nearest embeddings to a given embedding + * within a subset of the embeddings array, defined by the provided indices. + * + * @param {NormalizedEmbedding} embedding - The embedding to compare against. + * @param {number[]} indicesToSearch - An array of indices specifying the subset of embeddings to search. + * @param {number} [maxMatches] - Optional. The maximum number of matches to return. If not specified, all matches are returned. + * @param {number} [minScore] - Optional. The minimum similarity score required for a match to be considered valid. + * @returns {Scored[]} An array of objects, each containing the index of a matching embedding and its similarity score. + */ + public getIndexesOfNearestInSubset( + embedding: NormalizedEmbedding, + indicesToSearch: number[], maxMatches?: number, minScore?: number, - ): Promise { - const textEmbeddings = await generateTextEmbeddings( - this.settings.embeddingModel, - textArray, + ): Scored[] { + const embeddingsToSearch = indicesToSearch.map( + (i) => this.embeddings[i], ); - const results = []; - for (const embedding of textEmbeddings) { - results.push( - await this.getIndexesOfNearest(embedding, maxMatches, minScore), - ); - } - return results; + // This gives us the offsets within in the embeddingsToSearch array + const nearestInSubset = this.indexesOfNearest( + embeddingsToSearch, + embedding, + maxMatches, + minScore, + ); + // We need to map back to actual positions + return nearestInSubset.map((match) => { + return { + item: indicesToSearch[match.item], + score: match.score, + }; + }); } public removeAt(pos: number): void { @@ -122,26 +118,25 @@ export class TextEmbeddingIndex { this.embeddings = embeddings; } - private indexesOfNearestText( - textEmbedding: NormalizedEmbedding, + private indexesOfNearest( + embeddingsToSearch: NormalizedEmbedding[], + embedding: NormalizedEmbedding, maxMatches?: number, minScore?: number, ): Scored[] { - maxMatches ??= this.settings.maxMatches; - minScore ??= this.settings.minScore; let matches: Scored[]; if (maxMatches && maxMatches > 0) { matches = indexesOfNearest( - this.embeddings, - textEmbedding, + embeddingsToSearch, + embedding, maxMatches, SimilarityType.Dot, minScore, ); } else { matches = indexesOfAllNearest( - this.embeddings, - textEmbedding, + embeddingsToSearch, + embedding, SimilarityType.Dot, minScore, ); @@ -150,6 +145,227 @@ export class TextEmbeddingIndex { } } +export async function generateTextEmbeddingsForIndex( + embeddingModel: TextEmbeddingModel, + texts: string | string[], +): Promise> { + try { + let embeddings: NormalizedEmbedding[]; + const textsToEmbed = Array.isArray(texts) ? texts : [texts]; + embeddings = await generateTextEmbeddingsWithRetry( + embeddingModel, + textsToEmbed, + ); + return success(embeddings); + } catch (ex) { + return error(`generateTExtEmbeddingsForIndex failed: ${ex}`); + } +} + +export async function addTextToEmbeddingIndex( + embeddingIndex: EmbeddingIndex, + embeddingModel: TextEmbeddingModel, + textToIndex: string[], +): Promise { + let result: ListIndexingResult = { numberCompleted: 0 }; + const embeddingResult = await generateTextEmbeddingsForIndex( + embeddingModel, + textToIndex, + ); + if (embeddingResult.success) { + embeddingIndex.push(embeddingResult.data); + result.numberCompleted = textToIndex.length; + } else { + result.error = embeddingResult.message; + } + return result; +} + +export async function addTextBatchToEmbeddingIndex( + embeddingIndex: EmbeddingIndex, + embeddingModel: TextEmbeddingModel, + textToIndex: string[], + batchSize: number, + eventHandler?: IndexingEventHandlers, +): Promise { + let result: ListIndexingResult = { numberCompleted: 0 }; + for (const batch of getIndexingBatches(textToIndex, batchSize)) { + if ( + eventHandler?.onEmbeddingsCreated && + !eventHandler.onEmbeddingsCreated( + textToIndex, + batch.values, + batch.startAt, + ) + ) { + break; + } + const batchResult = await generateTextEmbeddingsForIndex( + embeddingModel, + batch.values, + ); + if (!batchResult.success) { + result.error = batchResult.message; + break; + } + embeddingIndex.push(batchResult.data); + result.numberCompleted = batch.startAt + batch.values.length; + } + return result; +} + +export async function indexOfNearestTextInIndex( + embeddingIndex: EmbeddingIndex, + embeddingModel: TextEmbeddingModel, + text: string, + maxMatches?: number, + minScore?: number, +): Promise { + const textEmbedding = await generateEmbedding(embeddingModel, text); + return embeddingIndex.getIndexesOfNearest( + textEmbedding, + maxMatches, + minScore, + ); +} + +export async function indexOfNearestTextInIndexSubset( + embeddingIndex: EmbeddingIndex, + embeddingModel: TextEmbeddingModel, + text: string, + indicesToSearch: number[], + maxMatches?: number, + minScore?: number, +): Promise { + const textEmbedding = await generateEmbedding(embeddingModel, text); + return embeddingIndex.getIndexesOfNearestInSubset( + textEmbedding, + indicesToSearch, + maxMatches, + minScore, + ); +} + +export async function indexesOfNearestTextBatchInIndex( + embeddingIndex: EmbeddingIndex, + embeddingModel: TextEmbeddingModel, + textArray: string[], + maxMatches?: number, + minScore?: number, +): Promise { + const textEmbeddings = await generateTextEmbeddings( + embeddingModel, + textArray, + ); + const results = []; + for (const embedding of textEmbeddings) { + results.push( + embeddingIndex.getIndexesOfNearest(embedding, maxMatches, minScore), + ); + } + return results; +} + +export class TextEmbeddingIndex { + private embeddingIndex: EmbeddingIndex; + + constructor(public settings: TextEmbeddingIndexSettings) { + this.embeddingIndex = new EmbeddingIndex(); + } + + public get size(): number { + return this.embeddingIndex.size; + } + + /** + * Convert text into embeddings and add them to the internal index. + * This can throw + * @param textToIndex + */ + public async addText( + textToIndex: string | string[], + ): Promise { + return addTextToEmbeddingIndex( + this.embeddingIndex, + this.settings.embeddingModel, + Array.isArray(textToIndex) ? textToIndex : [textToIndex], + ); + } + + /** + * Add text to the index in batches + * @param textToIndex + * @param eventHandler + * @param batchSize + * @returns Returns the index of the last item in textToIndex which was successfully completed + */ + public async addTextBatch( + textToIndex: string[], + eventHandler?: IndexingEventHandlers, + batchSize?: number, + ): Promise { + return addTextBatchToEmbeddingIndex( + this.embeddingIndex, + this.settings.embeddingModel, + textToIndex, + batchSize ?? this.settings.batchSize, + eventHandler, + ); + } + + public get(pos: number): NormalizedEmbedding { + return this.embeddingIndex.get(pos); + } + + public async getIndexesOfNearest( + text: string, + maxMatches?: number, + minScore?: number, + ): Promise { + maxMatches ??= this.settings.maxMatches; + minScore ??= this.settings.minScore; + return indexOfNearestTextInIndex( + this.embeddingIndex, + this.settings.embeddingModel, + text, + maxMatches, + minScore, + ); + } + + public async getIndexesOfNearestMultiple( + textBatch: string[], + maxMatches?: number, + minScore?: number, + ): Promise { + maxMatches ??= this.settings.maxMatches; + minScore ??= this.settings.minScore; + return indexesOfNearestTextBatchInIndex( + this.embeddingIndex, + this.settings.embeddingModel, + textBatch, + maxMatches, + minScore, + ); + } + + public removeAt(pos: number): void { + this.embeddingIndex.removeAt(pos); + } + + public clear(): void { + this.embeddingIndex.clear(); + } + + public serialize(): Float32Array[] { + return this.embeddingIndex.serialize(); + } + + public deserialize(embeddings: Float32Array[]): void { + this.embeddingIndex.deserialize(embeddings); + } +} + export function serializeEmbedding(embedding: NormalizedEmbedding): number[] { return Array.from(embedding); } diff --git a/ts/packages/knowPro/src/import.ts b/ts/packages/knowPro/src/import.ts index 8874556d7..92ba6b4d1 100644 --- a/ts/packages/knowPro/src/import.ts +++ b/ts/packages/knowPro/src/import.ts @@ -7,19 +7,23 @@ import { TextEmbeddingIndexSettings, } from "./fuzzyIndex.js"; import { RelatedTermIndexSettings } from "./relatedTermsIndex.js"; +import { MessageTextIndexSettings } from "./messageIndex.js"; export type ConversationSettings = { relatedTermIndexSettings: RelatedTermIndexSettings; threadSettings: TextEmbeddingIndexSettings; + messageTextIndexSettings: MessageTextIndexSettings; }; export function createConversationSettings(): ConversationSettings { - const embeddingIndexSettings = createTextEmbeddingIndexSettings(); return { relatedTermIndexSettings: { - embeddingIndexSettings, + embeddingIndexSettings: createTextEmbeddingIndexSettings(), + }, + threadSettings: createTextEmbeddingIndexSettings(), + messageTextIndexSettings: { + embeddingIndexSettings: createTextEmbeddingIndexSettings(), }, - threadSettings: embeddingIndexSettings, }; } diff --git a/ts/packages/knowPro/src/index.ts b/ts/packages/knowPro/src/index.ts index 6d1aa7f4d..2d3302394 100644 --- a/ts/packages/knowPro/src/index.ts +++ b/ts/packages/knowPro/src/index.ts @@ -17,3 +17,4 @@ export * from "./dateTimeSchema.js"; export * from "./searchSchema.js"; export * from "./searchTranslator.js"; export * from "./textLocationIndex.js"; +export * from "./messageIndex.js"; diff --git a/ts/packages/knowPro/src/interfaces.ts b/ts/packages/knowPro/src/interfaces.ts index 81b9168d2..2b6bf88e4 100644 --- a/ts/packages/knowPro/src/interfaces.ts +++ b/ts/packages/knowPro/src/interfaces.ts @@ -116,6 +116,7 @@ export interface IConversationSecondaryIndexes { timestampIndex?: ITimestampToTextRangeIndex | undefined; termToRelatedTermsIndex?: ITermToRelatedTermsIndex | undefined; threads?: IConversationThreads | undefined; + messageIndex?: IMessageTextIndex | undefined; } /** @@ -201,6 +202,24 @@ export interface IConversationThreads { removeThread(threadIndex: ThreadIndex): void; } +export interface IMessageTextIndex { + addMessages( + messages: IMessage[], + eventHandler?: IndexingEventHandlers, + ): Promise; + lookupMessages( + messageText: string, + maxMatches?: number, + thresholdScore?: number, + ): Promise; + lookupMessagesInSubset( + messageText: string, + indicesToSearch: MessageIndex[], + maxMatches?: number, + thresholdScore?: number, + ): Promise; +} + //------------------------ // Serialization formats //------------------------ @@ -243,3 +262,8 @@ export type IndexingResults = { chunksIndexedUpto?: TextLocation | undefined; error?: string | undefined; }; + +export type ListIndexingResult = { + numberCompleted: number; + error?: string | undefined; +}; diff --git a/ts/packages/knowPro/src/messageIndex.ts b/ts/packages/knowPro/src/messageIndex.ts new file mode 100644 index 000000000..e017cae56 --- /dev/null +++ b/ts/packages/knowPro/src/messageIndex.ts @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +import { TextEmbeddingIndexSettings } from "./fuzzyIndex.js"; +import { + IMessage, + MessageIndex, + IndexingEventHandlers, + TextLocation, + ListIndexingResult, + ScoredMessageIndex, + IConversation, + IMessageTextIndex, +} from "./interfaces.js"; +import { + ITextToTextLocationIndexData, + TextToTextLocationIndex, +} from "./textLocationIndex.js"; + +export type MessageTextIndexSettings = { + embeddingIndexSettings: TextEmbeddingIndexSettings; +}; + +export interface IMessageTextIndexData { + indexData?: ITextToTextLocationIndexData | undefined; +} + +export class MessageTextIndex implements IMessageTextIndex { + private textLocationIndex: TextToTextLocationIndex; + + constructor(public settings: MessageTextIndexSettings) { + this.textLocationIndex = new TextToTextLocationIndex( + settings.embeddingIndexSettings, + ); + } + + public get size(): number { + return this.textLocationIndex.size; + } + + public addMessages( + messages: IMessage[], + eventHandler?: IndexingEventHandlers, + ): Promise { + const baseMessageIndex: MessageIndex = this.size; + const allChunks: [string, TextLocation][] = []; + // Collect everything so we can batch efficiently + for (let i = 0; i < messages.length; ++i) { + const message = messages[i]; + let messageIndex = baseMessageIndex + i; + for ( + let chunkIndex = 0; + chunkIndex < message.textChunks.length; + ++chunkIndex + ) { + allChunks.push([ + message.textChunks[chunkIndex], + { messageIndex, chunkIndex }, + ]); + } + } + return this.textLocationIndex.addTextLocations(allChunks, eventHandler); + } + + public async lookupMessages( + messageText: string, + maxMatches?: number, + thresholdScore?: number, + ): Promise { + const scoredLocations = await this.textLocationIndex.lookupText( + messageText, + maxMatches, + thresholdScore, + ); + return scoredLocations.map((sl) => { + return { + messageIndex: sl.textLocation.messageIndex, + score: sl.score, + }; + }); + } + + public async lookupMessagesInSubset( + messageText: string, + indicesToSearch: MessageIndex[], + maxMatches?: number, + thresholdScore?: number, + ): Promise { + const scoredLocations = await this.textLocationIndex.lookupTextInSubset( + messageText, + indicesToSearch, + maxMatches, + thresholdScore, + ); + return scoredLocations.map((sl) => { + return { + messageIndex: sl.textLocation.messageIndex, + score: sl.score, + }; + }); + } + + public serialize(): IMessageTextIndexData { + return { + indexData: this.textLocationIndex.serialize(), + }; + } + + public deserialize(data: IMessageTextIndexData): void { + if (data.indexData) { + this.textLocationIndex.clear(); + this.textLocationIndex.deserialize(data.indexData); + } + } +} + +export async function buildMessageIndex( + conversation: IConversation, + settings: MessageTextIndexSettings, + eventHandler?: IndexingEventHandlers, +): Promise { + if (conversation.secondaryIndexes) { + conversation.secondaryIndexes.messageIndex ??= new MessageTextIndex( + settings, + ); + const messageIndex = conversation.secondaryIndexes.messageIndex; + const messages = conversation.messages; + return messageIndex.addMessages(messages, eventHandler); + } + return { + numberCompleted: 0, + }; +} diff --git a/ts/packages/knowPro/src/secondaryIndexes.ts b/ts/packages/knowPro/src/secondaryIndexes.ts index c57ee96b9..48699ccfd 100644 --- a/ts/packages/knowPro/src/secondaryIndexes.ts +++ b/ts/packages/knowPro/src/secondaryIndexes.ts @@ -9,6 +9,7 @@ import { IndexingEventHandlers, Term, } from "./interfaces.js"; +import { IMessageTextIndexData } from "./messageIndex.js"; import { PropertyIndex, buildPropertyIndex } from "./propertyIndex.js"; import { buildRelatedTermsIndex, @@ -69,5 +70,6 @@ export interface ITextEmbeddingIndexData { export interface IConversationDataWithIndexes extends IConversationData { relatedTermsIndexData?: ITermsToRelatedTermsIndexData | undefined; - threadData?: IConversationThreadData; + threadData?: IConversationThreadData | undefined; + messageIndexData?: IMessageTextIndexData | undefined; } diff --git a/ts/packages/knowPro/src/textLocationIndex.ts b/ts/packages/knowPro/src/textLocationIndex.ts index a0719226b..6dfc64aa9 100644 --- a/ts/packages/knowPro/src/textLocationIndex.ts +++ b/ts/packages/knowPro/src/textLocationIndex.ts @@ -1,10 +1,14 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -import { IMessage, MessageIndex, TextLocation } from "./interfaces.js"; +import { ListIndexingResult, TextLocation } from "./interfaces.js"; import { IndexingEventHandlers } from "./interfaces.js"; import { - TextEmbeddingIndex, + addTextBatchToEmbeddingIndex, + addTextToEmbeddingIndex, + EmbeddingIndex, + indexOfNearestTextInIndex, + indexOfNearestTextInIndexSubset, TextEmbeddingIndexSettings, } from "./fuzzyIndex.js"; @@ -13,12 +17,15 @@ export type ScoredTextLocation = { textLocation: TextLocation; }; -export interface ITextToTextLocationIndexFuzzy { - addTextLocation(text: string, textLocation: TextLocation): Promise; - addTextLocationsBatched( +export interface ITextToTextLocationIndex { + addTextLocation( + text: string, + textLocation: TextLocation, + ): Promise; + addTextLocations( textAndLocations: [string, TextLocation][], eventHandler?: IndexingEventHandlers, - ): Promise; + ): Promise; lookupText( text: string, maxMatches?: number, @@ -34,36 +41,58 @@ export interface ITextToTextLocationIndexData { embeddings: Float32Array[]; } -export class TextToTextLocationIndexFuzzy - implements ITextToTextLocationIndexFuzzy -{ +export class TextToTextLocationIndex implements ITextToTextLocationIndex { private textLocations: TextLocation[]; - private embeddingIndex: TextEmbeddingIndex; + private embeddingIndex: EmbeddingIndex; - constructor(settings: TextEmbeddingIndexSettings) { + constructor(public settings: TextEmbeddingIndexSettings) { this.textLocations = []; - this.embeddingIndex = new TextEmbeddingIndex(settings); + this.embeddingIndex = new EmbeddingIndex(); + } + + public get size(): number { + return this.embeddingIndex.size; + } + + public get(pos: number): TextLocation { + return this.textLocations[pos]; } public async addTextLocation( text: string, textLocation: TextLocation, - ): Promise { - await this.embeddingIndex.addText(text); - this.textLocations.push(textLocation); + ): Promise { + const result = await addTextToEmbeddingIndex( + this.embeddingIndex, + this.settings.embeddingModel, + [text], + ); + if (result.numberCompleted > 0) { + this.textLocations.push(textLocation); + } + return result; } - public async addTextLocationsBatched( + public async addTextLocations( textAndLocations: [string, TextLocation][], eventHandler?: IndexingEventHandlers, batchSize?: number, - ): Promise { - await this.embeddingIndex.addTextBatch( + ): Promise { + const result = await addTextBatchToEmbeddingIndex( + this.embeddingIndex, + this.settings.embeddingModel, textAndLocations.map((tl) => tl[0]), + batchSize ?? this.settings.batchSize, eventHandler, - batchSize, ); - this.textLocations.push(...textAndLocations.map((tl) => tl[1])); + if (result.numberCompleted > 0) { + textAndLocations = + result.numberCompleted === textAndLocations.length + ? textAndLocations + : textAndLocations.slice(0, result.numberCompleted); + this.textLocations.push(...textAndLocations.map((tl) => tl[1])); + } + return result; } public async lookupText( @@ -71,8 +100,32 @@ export class TextToTextLocationIndexFuzzy maxMatches?: number, thresholdScore?: number, ): Promise { - const matches = await this.embeddingIndex.getIndexesOfNearest( + const matches = await indexOfNearestTextInIndex( + this.embeddingIndex, + this.settings.embeddingModel, + text, + maxMatches, + thresholdScore, + ); + return matches.map((m) => { + return { + textLocation: this.textLocations[m.item], + score: m.score, + }; + }); + } + + public async lookupTextInSubset( + text: string, + indicesToSearch: number[], + maxMatches?: number, + thresholdScore?: number, + ): Promise { + const matches = await indexOfNearestTextInIndexSubset( + this.embeddingIndex, + this.settings.embeddingModel, text, + indicesToSearch, maxMatches, thresholdScore, ); @@ -84,6 +137,11 @@ export class TextToTextLocationIndexFuzzy }); } + public clear(): void { + this.textLocations = []; + this.embeddingIndex.clear(); + } + public serialize(): ITextToTextLocationIndexData { return { textLocations: this.textLocations, @@ -101,51 +159,3 @@ export class TextToTextLocationIndexFuzzy this.embeddingIndex.deserialize(data.embeddings); } } - -export async function addMessagesToIndex( - textLocationIndex: TextToTextLocationIndexFuzzy, - messages: IMessage[], - baseMessageIndex: MessageIndex, - eventHandler?: IndexingEventHandlers, - batchSize?: number, -): Promise { - const allChunks: [string, TextLocation][] = []; - // Collect everything so we can batch efficiently - for (let i = 0; i < messages.length; ++i) { - const message = messages[i]; - let messageIndex = baseMessageIndex + i; - for ( - let chunkIndex = 0; - chunkIndex < message.textChunks.length; - ++chunkIndex - ) { - allChunks.push([ - message.textChunks[chunkIndex], - { messageIndex, chunkIndex }, - ]); - } - } - // Todo: return an IndexingResult - await textLocationIndex.addTextLocationsBatched( - allChunks, - eventHandler, - batchSize, - ); -} - -export async function buildMessageIndex( - messages: IMessage[], - settings: TextEmbeddingIndexSettings, - eventHandler?: IndexingEventHandlers, - batchSize?: number, -) { - const textLocationIndex = new TextToTextLocationIndexFuzzy(settings); - await addMessagesToIndex( - textLocationIndex, - messages, - 0, - eventHandler, - batchSize, - ); - return textLocationIndex; -} diff --git a/ts/packages/memory/conversation/src/importPodcast.ts b/ts/packages/memory/conversation/src/importPodcast.ts index d8035c9b3..6eac03835 100644 --- a/ts/packages/memory/conversation/src/importPodcast.ts +++ b/ts/packages/memory/conversation/src/importPodcast.ts @@ -20,8 +20,10 @@ import { IConversationDataWithIndexes, writeConversationDataToFile, readConversationDataFromFile, - TextToTextLocationIndexFuzzy, buildMessageIndex, + MessageTextIndexSettings, + MessageTextIndex, + ListIndexingResult, } from "knowpro"; import { conversation as kpLib, split } from "knowledge-processor"; import { collections, dateTime, getFileName, readAllText } from "typeagent"; @@ -109,10 +111,7 @@ export class Podcast implements IConversation { public settings: ConversationSettings; public semanticRefIndex: ConversationIndex; public secondaryIndexes: PodcastSecondaryIndexes; - /** - * Work in progress - */ - public messageIndex?: TextToTextLocationIndexFuzzy | undefined; + public messageIndex?: MessageTextIndex | undefined; constructor( public nameTag: string = "", @@ -162,13 +161,19 @@ export class Podcast implements IConversation { public async buildMessageIndex( eventHandler?: IndexingEventHandlers, batchSize?: number, - ): Promise { - this.messageIndex = await buildMessageIndex( - this.messages, - this.settings.relatedTermIndexSettings.embeddingIndexSettings!, + ): Promise { + const settings: MessageTextIndexSettings = { + ...this.settings.messageTextIndexSettings, + }; + if (batchSize !== undefined && batchSize > 0) { + settings.embeddingIndexSettings.batchSize = batchSize; + } + const indexingResult = await buildMessageIndex( + this, + settings, eventHandler, - batchSize, ); + return indexingResult; } public async serialize(): Promise { diff --git a/ts/packages/typeagent/src/lib/array.ts b/ts/packages/typeagent/src/lib/array.ts index 7e3f24664..f1f24a745 100644 --- a/ts/packages/typeagent/src/lib/array.ts +++ b/ts/packages/typeagent/src/lib/array.ts @@ -355,3 +355,59 @@ export class CircularArray implements Iterable { --this.count; } } + +export class SortedArray implements Iterable { + private array: T[]; + private sorted: boolean; + + constructor(public itemComparer: (x: T, y: T) => number, array?: T[] | undefined, isSorted: boolean = false) { + this.array = array ?? []; + this.sorted = isSorted; + } + public get length(): number { + return this.array.length; + } + + public get(index: number): T { + this.ensureSorted(); + return this.array[index]; + } + + public set(index: number, value: T): void { + this.array[index] = value; + } + + public push(value: T | T[]): void { + if (Array.isArray(value)) { + this.array.push(...value); + } + else { + this.array.push(value); + } + this.sorted = false; + } + + public indexOf(value: T): number { + this.ensureSorted(); + return binarySearch(this.array, value, this.itemComparer); + } + + public findIndex(value: T, compareFn: (x: T, other: V) => number): number { + this.ensureSorted(); + return binarySearch(this.array, value, compareFn); + } + + public *[Symbol.iterator](): Iterator { + this.ensureSorted(); + for (let i = 0; i < this.array.length; ++i) { + yield this.array[i]; + } + } + + private ensureSorted(): void { + if (!this.sorted) { + this.array.sort(this.itemComparer); + this.sorted = true; + } + } +} \ No newline at end of file