diff --git a/src/query/agent.ts b/src/query/agent.ts index a37677e..e1d70dd 100644 --- a/src/query/agent.ts +++ b/src/query/agent.ts @@ -12,10 +12,15 @@ import { } from "./response/response-mapping.js"; import { mapApiResponse } from "./response/api-response-mapping.js"; import { fetchServerSentEvents } from "./response/server-sent-events.js"; -import { mapCollections, QueryAgentCollectionConfig } from "./collection.js"; +import { + mapCollections, + QueryAgentCollection, + QueryAgentCollectionConfig, +} from "./collection.js"; import { handleError } from "./response/error.js"; import { QueryAgentSearcher } from "./search.js"; import { SearchModeResponse } from "./response/response.js"; +import { getHeaders } from "./connection.js"; /** * An agent for executing agentic queries against Weaviate. @@ -97,33 +102,23 @@ export class QueryAgent { /** * Ask query agent a question. * - * @param query - The natural language query string for the agent. + * @param query - The natural language query string or conversation context for the agent. * @param options - Additional options for the run. * @returns The response from the query agent. */ async ask( - query: string, + query: QueryAgentQuery, { collections }: QueryAgentAskOptions = {}, ): Promise { - const targetCollections = collections ?? this.collections; - if (!targetCollections) { - throw Error("No collections provided to the query agent."); - } - - const { host, bearerToken, headers } = - await this.client.getConnectionDetails(); + const targetCollections = this.validateCollections(collections); + const { requestHeaders, connectionHeaders } = await getHeaders(this.client); const response = await fetch(`${this.agentsHost}/agent/query`, { method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: bearerToken!, - "X-Weaviate-Cluster-Url": host, - "X-Agent-Request-Origin": "typescript-client", - }, + headers: requestHeaders, body: JSON.stringify({ - headers, - query, + headers: connectionHeaders, + query: typeof query === "string" ? query : { messages: query }, collections: mapCollections(targetCollections), system_prompt: this.systemPrompt, }), @@ -238,40 +233,40 @@ export class QueryAgent { /** * Ask query agent a question and stream the response. * - * @param query - The natural language query string for the agent. + * @param query - The natural language query string or conversation context for the agent. * @param options - Additional options for the run. * @returns The response from the query agent. */ askStream( - query: string, + query: QueryAgentQuery, options: QueryAgentAskStreamOptions & { includeProgress: false; includeFinalState: false; }, ): AsyncGenerator; askStream( - query: string, + query: QueryAgentQuery, options: QueryAgentAskStreamOptions & { includeProgress: false; includeFinalState?: true; }, ): AsyncGenerator; askStream( - query: string, + query: QueryAgentQuery, options: QueryAgentAskStreamOptions & { includeProgress?: true; includeFinalState: false; }, ): AsyncGenerator; askStream( - query: string, + query: QueryAgentQuery, options?: QueryAgentAskStreamOptions & { includeProgress?: true; includeFinalState?: true; }, ): AsyncGenerator; async *askStream( - query: string, + query: QueryAgentQuery, { collections, includeProgress, @@ -299,7 +294,7 @@ export class QueryAgent { }, body: JSON.stringify({ headers, - query, + query: typeof query === "string" ? query : { messages: query }, collections: mapCollections(targetCollections), system_prompt: this.systemPrompt, include_progress: includeProgress ?? true, @@ -336,28 +331,50 @@ export class QueryAgent { * reuses the same underlying searches to ensure consistency across pages. */ async search( - query: string, + query: QueryAgentQuery, { limit = 20, collections }: QueryAgentSearchOnlyOptions = {}, ): Promise { - const searcher = new QueryAgentSearcher(this.client, query, { - collections: collections ?? this.collections, - systemPrompt: this.systemPrompt, - agentsHost: this.agentsHost, - }); + const searcher = new QueryAgentSearcher( + this.client, + query, + this.validateCollections(collections), + this.systemPrompt, + this.agentsHost, + ); + return searcher.run({ limit, offset: 0 }); } + + private validateCollections = ( + collections?: QueryAgentCollection[], + ): QueryAgentCollection[] => { + const targetCollections = collections ?? this.collections; + + if (!targetCollections) { + throw Error("No collections provided to the query agent."); + } + + return targetCollections; + }; } /** Options for the QueryAgent. */ export type QueryAgentOptions = { /** List of collections to query. Will be overriden if passed in the `run` method. */ - collections?: (string | QueryAgentCollectionConfig)[]; + collections?: QueryAgentCollection[]; /** System prompt to guide the agent's behavior. */ systemPrompt?: string; /** Host of the agents service. */ agentsHost?: string; }; +export type QueryAgentQuery = string | ChatMessage[]; + +export type ChatMessage = { + role: "user" | "assistant"; + content: string; +}; + /** Options for the QueryAgent run. */ export type QueryAgentRunOptions = { /** List of collections to query. Will override any collections if passed in the constructor. */ diff --git a/src/query/collection.ts b/src/query/collection.ts index f2eacf3..e1c4ffe 100644 --- a/src/query/collection.ts +++ b/src/query/collection.ts @@ -12,6 +12,8 @@ export const mapCollections = ( }, ); +export type QueryAgentCollection = string | QueryAgentCollectionConfig; + /** Configuration for a collection to query. */ export type QueryAgentCollectionConfig = { /** The name of the collection to query. */ diff --git a/src/query/connection.ts b/src/query/connection.ts new file mode 100644 index 0000000..fa716fb --- /dev/null +++ b/src/query/connection.ts @@ -0,0 +1,18 @@ +import { WeaviateClient } from "weaviate-client"; + +export const getHeaders = async (client: WeaviateClient) => { + const { + host, + bearerToken, + headers: connectionHeaders, + } = await client.getConnectionDetails(); + + const requestHeaders = { + "Content-Type": "application/json", + Authorization: bearerToken!, + "X-Weaviate-Cluster-Url": host, + "X-Agent-Request-Origin": "typescript-client", + }; + + return { requestHeaders, connectionHeaders }; +}; diff --git a/src/query/index.ts b/src/query/index.ts index abb4b6c..7533b62 100644 --- a/src/query/index.ts +++ b/src/query/index.ts @@ -1,4 +1,6 @@ export * from "./agent.js"; -export { QueryAgentCollectionConfig } from "./collection.js"; +export { + QueryAgentCollection, + QueryAgentCollectionConfig, +} from "./collection.js"; export * from "./response/index.js"; -export * from "./search.js"; diff --git a/src/query/search.ts b/src/query/search.ts index 164345d..3f75237 100644 --- a/src/query/search.ts +++ b/src/query/search.ts @@ -4,12 +4,15 @@ import { SearchModeResponse, } from "./response/response.js"; import { mapSearchOnlyResponse } from "./response/response-mapping.js"; -import { mapCollections, QueryAgentCollectionConfig } from "./collection.js"; +import { mapCollections } from "./collection.js"; import { handleError } from "./response/error.js"; import { ApiSearchModeResponse, ApiSearchResult, } from "./response/api-response.js"; +import { QueryAgentQuery } from "./agent.js"; +import { QueryAgentCollection } from "./collection.js"; +import { getHeaders } from "./connection.js"; /** * A configured searcher for the QueryAgent. @@ -25,44 +28,15 @@ import { * For more information, see the [Weaviate Query Agent Docs](https://weaviate.io/developers/agents/query) */ export class QueryAgentSearcher { - private agentsHost: string; - private query: string; - private collections: (string | QueryAgentCollectionConfig)[]; - private systemPrompt?: string; private cachedSearches?: ApiSearchResult[]; constructor( private client: WeaviateClient, - query: string, - { - collections = [], - systemPrompt, - agentsHost = "https://api.agents.weaviate.io", - }: { - collections?: (string | QueryAgentCollectionConfig)[]; - systemPrompt?: string; - agentsHost?: string; - } = {}, - ) { - this.query = query; - this.collections = collections; - this.systemPrompt = systemPrompt; - this.agentsHost = agentsHost; - this.cachedSearches = undefined; - } - - private async getHeaders() { - const { host, bearerToken, headers } = - await this.client.getConnectionDetails(); - const requestHeaders = { - "Content-Type": "application/json", - Authorization: bearerToken!, - "X-Weaviate-Cluster-Url": host, - "X-Agent-Request-Origin": "typescript-client", - }; - const connectionHeaders = headers; - return { requestHeaders, connectionHeaders }; - } + private query: QueryAgentQuery, + private collections: QueryAgentCollection[], + private systemPrompt: string | undefined, + private agentsHost: string, + ) {} private buildRequestBody( limit: number, @@ -71,7 +45,8 @@ export class QueryAgentSearcher { ) { const base = { headers: connectionHeaders, - original_query: this.query, + original_query: + typeof this.query === "string" ? this.query : { messages: this.query }, collections: mapCollections(this.collections), limit, offset, @@ -108,7 +83,7 @@ export class QueryAgentSearcher { if (!this.collections || this.collections.length === 0) { throw Error("No collections provided to the query agent."); } - const { requestHeaders, connectionHeaders } = await this.getHeaders(); + const { requestHeaders, connectionHeaders } = await getHeaders(this.client); const response = await fetch(`${this.agentsHost}/agent/search_only`, { method: "POST",