diff --git a/src/query/agent.test.ts b/src/query/agent.test.ts index bcafdc4..8445c8d 100644 --- a/src/query/agent.test.ts +++ b/src/query/agent.test.ts @@ -1,7 +1,9 @@ import { WeaviateClient } from "weaviate-client"; import { QueryAgent } from "./agent.js"; import { ApiQueryAgentResponse } from "./response/api-response.js"; -import { QueryAgentResponse } from "./response/response.js"; +import { QueryAgentResponse, ComparisonOperator } from "./response/response.js"; +import { ApiSearchModeResponse } from "./response/api-response.js"; +import { QueryAgentError } from "./response/error.js"; it("runs the query agent", async () => { const mockClient = { @@ -93,3 +95,222 @@ it("runs the query agent", async () => { display: expect.any(Function), }); }); + +it("search-only mode success: caches searches and sends on subsequent request", async () => { + const mockClient = { + getConnectionDetails: jest.fn().mockResolvedValue({ + host: "test-cluster", + bearerToken: "test-token", + headers: { "X-Provider": "test-key" }, + }), + } as unknown as WeaviateClient; + + const capturedBodies: ApiSearchModeResponse[] = []; + + const apiSuccess: ApiSearchModeResponse = { + original_query: "Test this search only mode!", + searches: [ + { + queries: ["search query"], + filters: [ + [ + { + filter_type: "integer", + property_name: "test_property", + operator: ComparisonOperator.GreaterThan, + value: 0, + }, + ], + ], + filter_operators: "AND", + collection: "test_collection", + }, + ], + usage: { + requests: 0, + request_tokens: undefined, + response_tokens: undefined, + total_tokens: undefined, + details: undefined, + }, + total_time: 1.5, + search_results: { + objects: [ + { + uuid: "e6dc0a31-76f8-4bd3-b563-677ced6eb557", + metadata: { + creation_time: null, + update_time: null, + distance: null, + certainty: null, + score: 0.8, + explain_score: null, + rerank_score: null, + is_consistent: null, + }, + references: null, + vector: {}, + properties: { + test_property: 1.0, + text: 
"hello", + }, + collection: "test_collection", + }, + { + uuid: "cf5401cc-f4f1-4eb9-a6a1-173d34f94339", + metadata: { + creation_time: null, + update_time: null, + distance: null, + certainty: null, + score: 0.5, + explain_score: null, + rerank_score: null, + is_consistent: null, + }, + references: null, + vector: {}, + properties: { + test_property: 2.0, + text: "world!", + }, + collection: "test_collection", + }, + ], + }, + }; + + // Mock the API response, and capture the request body to assert later + global.fetch = jest.fn((url, init?: RequestInit) => { + if (init && init.body) { + capturedBodies.push( + JSON.parse(init.body as string) as ApiSearchModeResponse, + ); + } + return Promise.resolve({ + ok: true, + json: () => Promise.resolve(apiSuccess), + } as Response); + }) as jest.Mock; + + const agent = new QueryAgent(mockClient); + + const first = await agent.search("test query", { + limit: 2, + collections: ["test_collection"], + }); + expect(first).toMatchObject({ + originalQuery: apiSuccess.original_query, + searches: [ + { + collection: "test_collection", + queries: ["search query"], + filters: [ + [ + { + filterType: "integer", + propertyName: "test_property", + operator: ComparisonOperator.GreaterThan, + value: 0, + }, + ], + ], + filterOperators: "AND", + }, + ], + usage: { + requests: 0, + requestTokens: undefined, + responseTokens: undefined, + totalTokens: undefined, + details: undefined, + }, + totalTime: 1.5, + searchResults: { + objects: [ + { + uuid: "e6dc0a31-76f8-4bd3-b563-677ced6eb557", + metadata: { + score: 0.8, + }, + vectors: {}, + properties: { + test_property: 1.0, + text: "hello", + }, + collection: "test_collection", + }, + { + uuid: "cf5401cc-f4f1-4eb9-a6a1-173d34f94339", + metadata: { + score: 0.5, + }, + vectors: {}, + properties: { + test_property: 2.0, + text: "world!", + }, + collection: "test_collection", + }, + ], + }, + }); + expect(typeof first.next).toBe("function"); + + // First request should have searches: null 
(generation request) + expect(capturedBodies[0].searches).toBeNull(); + + // Second request uses the next method on the first response + const second = await first.next({ limit: 2, offset: 1 }); + // Second request should include the original searches (execution request) + expect(capturedBodies[1].searches).toEqual(apiSuccess.searches); + // Response mapping should be the same (because response is mocked) + expect(second).toMatchObject({ + originalQuery: apiSuccess.original_query, + searches: first.searches, + usage: first.usage, + totalTime: first.totalTime, + searchResults: first.searchResults, + }); + expect(typeof second.next).toBe("function"); +}); + +it("search-only mode failure propagates QueryAgentError", async () => { + const mockClient = { + getConnectionDetails: jest.fn().mockResolvedValue({ + host: "test-cluster", + bearerToken: "test-token", + headers: { "X-Provider": "test-key" }, + }), + } as unknown as WeaviateClient; + + const errorJson = { + error: { + message: "Test error message", + code: "test_error_code", + details: { info: "test detail" }, + }, + }; + + global.fetch = jest.fn(() => + Promise.resolve({ + ok: false, + text: () => Promise.resolve(JSON.stringify(errorJson)), + } as Response), + ) as jest.Mock; + + const agent = new QueryAgent(mockClient); + try { + await agent.search("test query", { + limit: 2, + collections: ["test_collection"], + }); + } catch (err) { + expect(err).toBeInstanceOf(QueryAgentError); + expect(err).toMatchObject({ + message: "Test error message", + code: "test_error_code", + details: { info: "test detail" }, + }); + } +}); diff --git a/src/query/agent.ts b/src/query/agent.ts index 3bdcb83..1583aaa 100644 --- a/src/query/agent.ts +++ b/src/query/agent.ts @@ -14,6 +14,8 @@ import { mapApiResponse } from "./response/api-response-mapping.js"; import { fetchServerSentEvents } from "./response/server-sent-events.js"; import { mapCollections, QueryAgentCollectionConfig } from "./collection.js"; import { handleError } from 
"./response/error.js"; +import { QueryAgentSearcher } from "./search.js"; +import { SearchModeResponse } from "./response/response.js"; /** * An agent for executing agentic queries against Weaviate. @@ -185,6 +187,25 @@ export class QueryAgent { yield output; } } + + /** + * Run the Query Agent search-only mode. + * + * Sends the initial search request and returns the first page of results. + * The returned response includes a `next` method for pagination which + * reuses the same underlying searches to ensure consistency across pages. + */ + async search( + query: string, + { limit = 20, collections }: QueryAgentSearchOnlyOptions = {}, + ): Promise<SearchModeResponse> { + const searcher = new QueryAgentSearcher(this.client, query, { + collections: collections ?? this.collections, + systemPrompt: this.systemPrompt, + agentsHost: this.agentsHost, + }); + return searcher.run({ limit, offset: 0 }); + } } /** Options for the QueryAgent. */ @@ -216,3 +237,11 @@ export type QueryAgentStreamOptions = { /** Include final state in the stream. */ includeFinalState?: boolean; }; + +/** Options for the QueryAgent search-only run. */ +export type QueryAgentSearchOnlyOptions = { + /** The maximum number of results to return. */ + limit?: number; + /** List of collections to query. Will override any collections if passed in the constructor. 
*/ + collections?: (string | QueryAgentCollectionConfig)[]; +}; diff --git a/src/query/index.ts b/src/query/index.ts index 3ae6fd3..abb4b6c 100644 --- a/src/query/index.ts +++ b/src/query/index.ts @@ -1,3 +1,4 @@ export * from "./agent.js"; export { QueryAgentCollectionConfig } from "./collection.js"; export * from "./response/index.js"; +export * from "./search.js"; diff --git a/src/query/response/api-response.ts b/src/query/response/api-response.ts index bf420e7..2b36ea6 100644 --- a/src/query/response/api-response.ts +++ b/src/query/response/api-response.ts @@ -1,3 +1,5 @@ +import { Vectors, WeaviateField } from "weaviate-client"; + import { NumericMetrics, TextMetrics, @@ -177,3 +179,42 @@ export type ApiSource = { object_id: string; collection: string; }; + +export type ApiReturnMetadata = { + creation_time: Date | null; + update_time: Date | null; + distance: number | null; + certainty: number | null; + score: number | null; + explain_score: string | null; + rerank_score: number | null; + is_consistent: boolean | null; +}; + +export type ApiWeaviateObject = { + /** The returned properties of the object as untyped key-value pairs from the API. */ + properties: Record<string, WeaviateField>; + /** The returned metadata of the object. */ + metadata: ApiReturnMetadata; + /** The returned references of the object. */ + references: null; + /** The UUID of the object. */ + uuid: string; + /** The returned vectors of the object. */ + vector: Vectors; + /** The collection this object belongs to. */ + collection: string; +}; + +export type ApiWeaviateReturn = { + /** The objects that were found by the query. 
*/ + objects: ApiWeaviateObject[]; +}; + +export type ApiSearchModeResponse = { + original_query: string; + searches?: ApiSearchResult[]; + usage: ApiUsage; + total_time: number; + search_results: ApiWeaviateReturn; +}; diff --git a/src/query/response/response-mapping.ts b/src/query/response/response-mapping.ts index 6cb4e35..07dd111 100644 --- a/src/query/response/response-mapping.ts +++ b/src/query/response/response-mapping.ts @@ -1,3 +1,5 @@ +import { ReturnMetadata } from "weaviate-client"; + import { QueryAgentResponse, SearchResult, @@ -9,6 +11,9 @@ import { StreamedTokens, ProgressMessage, DateFilterValue, + WeaviateObjectWithCollection, + WeaviateReturnWithCollection, + SearchModeResponse, } from "./response.js"; import { @@ -20,6 +25,9 @@ import { ApiUsage, ApiSource, ApiDateFilterValue, + ApiSearchModeResponse, + ApiWeaviateObject, + ApiWeaviateReturn, } from "./api-response.js"; import { ServerSentEvent } from "./server-sent-events.js"; @@ -47,15 +55,16 @@ export const mapResponse = ( }; }; +const mapInnerSearches = (searches: ApiSearchResult[]): SearchResult[] => + searches.map((result) => ({ + collection: result.collection, + queries: result.queries, + filters: result.filters.map(mapPropertyFilters), + filterOperators: result.filter_operators, + })); + const mapSearches = (searches: ApiSearchResult[][]): SearchResult[][] => - searches.map((searchGroup) => - searchGroup.map((result) => ({ - collection: result.collection, - queries: result.queries, - filters: result.filters.map(mapPropertyFilters), - filterOperators: result.filter_operators, - })), - ); + searches.map((searchGroup) => mapInnerSearches(searchGroup)); const mapDatePropertyFilter = ( filterValue: ApiDateFilterValue, @@ -298,3 +307,69 @@ export const mapResponseFromSSE = ( display: () => display(properties), }; }; + +const mapWeaviateObject = ( + object: ApiWeaviateObject, +): WeaviateObjectWithCollection => { + const metadata: ReturnMetadata = { + creationTime: + 
object.metadata.creation_time !== null + ? object.metadata.creation_time + : undefined, + updateTime: + object.metadata.update_time !== null + ? object.metadata.update_time + : undefined, + distance: + object.metadata.distance !== null ? object.metadata.distance : undefined, + certainty: + object.metadata.certainty !== null + ? object.metadata.certainty + : undefined, + score: object.metadata.score !== null ? object.metadata.score : undefined, + explainScore: + object.metadata.explain_score !== null + ? object.metadata.explain_score + : undefined, + rerankScore: + object.metadata.rerank_score !== null + ? object.metadata.rerank_score + : undefined, + isConsistent: + object.metadata.is_consistent !== null + ? object.metadata.is_consistent + : undefined, + }; + + return { + properties: object.properties, + metadata, + references: undefined, + uuid: object.uuid, + vectors: object.vector, + collection: object.collection, + }; +}; + +export const mapWeviateSearchResults = ( + response: ApiWeaviateReturn, +): WeaviateReturnWithCollection => ({ + objects: response.objects.map(mapWeaviateObject), +}); + +export const mapSearchOnlyResponse = ( + response: ApiSearchModeResponse, +): { + mappedResponse: Omit<SearchModeResponse, "next">; + apiSearches: ApiSearchResult[] | undefined; +} => { + const apiSearches = response.searches; + const mappedResponse: Omit<SearchModeResponse, "next"> = { + originalQuery: response.original_query, + searches: apiSearches ? 
mapInnerSearches(apiSearches) : undefined, + usage: mapUsage(response.usage), + totalTime: response.total_time, + searchResults: mapWeviateSearchResults(response.search_results), + }; + return { mappedResponse, apiSearches }; +}; diff --git a/src/query/response/response.ts b/src/query/response/response.ts index e0400f7..6c64d66 100644 --- a/src/query/response/response.ts +++ b/src/query/response/response.ts @@ -1,3 +1,5 @@ +import { WeaviateReturn, WeaviateObject } from "weaviate-client"; + export type QueryAgentResponse = { outputType: "finalState"; originalQuery: string; @@ -260,3 +262,28 @@ export type StreamedTokens = { outputType: "streamedTokens"; delta: string; }; + +export type WeaviateObjectWithCollection = WeaviateObject & { + collection: string; +}; + +export type WeaviateReturnWithCollection = WeaviateReturn & { + objects: WeaviateObjectWithCollection[]; +}; + +/** Options for executing a prepared QueryAgent search. */ +export type SearchExecutionOptions = { + /** The maximum number of results to return. */ + limit?: number; + /** The offset of the results to return, for paginating through query result sets. 
*/ + offset: number; +}; + +export type SearchModeResponse = { + originalQuery: string; + searches?: SearchResult[]; + usage: Usage; + totalTime: number; + searchResults: WeaviateReturnWithCollection; + next: (options: SearchExecutionOptions) => Promise<SearchModeResponse>; +}; diff --git a/src/query/search.ts b/src/query/search.ts new file mode 100644 index 0000000..164345d --- /dev/null +++ b/src/query/search.ts @@ -0,0 +1,137 @@ +import { WeaviateClient } from "weaviate-client"; +import { + SearchExecutionOptions, + SearchModeResponse, +} from "./response/response.js"; +import { mapSearchOnlyResponse } from "./response/response-mapping.js"; +import { mapCollections, QueryAgentCollectionConfig } from "./collection.js"; +import { handleError } from "./response/error.js"; +import { + ApiSearchModeResponse, + ApiSearchResult, +} from "./response/api-response.js"; + +/** + * A configured searcher for the QueryAgent. + * + * This is used internally by the QueryAgent class to run search-mode queries. + * After the first request is made, the underlying searches are cached and can + * be reused for paginating through a consistent set of results. + * + * Warning: + * Weaviate Agents - Query Agent is an early stage alpha product. + * The API is subject to breaking changes. Please ensure you are using the latest version of the client. 
+ * + * For more information, see the [Weaviate Query Agent Docs](https://weaviate.io/developers/agents/query) + */ +export class QueryAgentSearcher { + private agentsHost: string; + private query: string; + private collections: (string | QueryAgentCollectionConfig)[]; + private systemPrompt?: string; + private cachedSearches?: ApiSearchResult[]; + + constructor( + private client: WeaviateClient, + query: string, + { + collections = [], + systemPrompt, + agentsHost = "https://api.agents.weaviate.io", + }: { + collections?: (string | QueryAgentCollectionConfig)[]; + systemPrompt?: string; + agentsHost?: string; + } = {}, + ) { + this.query = query; + this.collections = collections; + this.systemPrompt = systemPrompt; + this.agentsHost = agentsHost; + this.cachedSearches = undefined; + } + + private async getHeaders() { + const { host, bearerToken, headers } = + await this.client.getConnectionDetails(); + const requestHeaders = { + "Content-Type": "application/json", + Authorization: bearerToken!, + "X-Weaviate-Cluster-Url": host, + "X-Agent-Request-Origin": "typescript-client", + }; + const connectionHeaders = headers; + return { requestHeaders, connectionHeaders }; + } + + private buildRequestBody( + limit: number, + offset: number, + connectionHeaders: HeadersInit | undefined, + ) { + const base = { + headers: connectionHeaders, + original_query: this.query, + collections: mapCollections(this.collections), + limit, + offset, + } as const; + if (this.cachedSearches === undefined) { + return { + ...base, + searches: null, + system_prompt: this.systemPrompt || null, + }; + } + return { + ...base, + searches: this.cachedSearches, + }; + } + + /** + * Run the search-only agent with the given limit and offset values. + * + * Calling this method multiple times on the same QueryAgentSearcher instance will result + * in the same underlying searches being performed each time, allowing you to paginate + * over a consistent results set. 
+ * + * @param options - Options for executing the search + * @param [options.limit] - The maximum number of results to return. Defaults to 20 if not specified. + * @param options.offset - The offset to start from. + * @returns A SearchModeResponse object containing the results, usage, and underlying searches performed. + */ + async run({ + limit = 20, + offset, + }: SearchExecutionOptions): Promise<SearchModeResponse> { + if (!this.collections || this.collections.length === 0) { + throw Error("No collections provided to the query agent."); + } + const { requestHeaders, connectionHeaders } = await this.getHeaders(); + + const response = await fetch(`${this.agentsHost}/agent/search_only`, { + method: "POST", + headers: requestHeaders, + body: JSON.stringify( + this.buildRequestBody(limit, offset, connectionHeaders), + ), + }); + if (!response.ok) { + await handleError(await response.text()); + } + const parsedResponse = (await response.json()) as ApiSearchModeResponse; + const { mappedResponse, apiSearches } = + mapSearchOnlyResponse(parsedResponse); + // If we successfully mapped the searches, cache them for the next request. + // Since this cache is a private internal value, there's no point in mapping + // back and forth between the exported and API types, so we cache apiSearches + if (mappedResponse.searches) { + this.cachedSearches = apiSearches; + } + return { + ...mappedResponse, + next: async (options: SearchExecutionOptions) => this.run(options), + }; + } +}