Skip to content

Commit a04bd6a

Browse files
authored
provide conversation context to the query agent (#31)
* provide conversation context to the query agent * simplify conversation context interface
1 parent 529666a commit a04bd6a

File tree

5 files changed

+85
-71
lines changed

5 files changed

+85
-71
lines changed

src/query/agent.ts

Lines changed: 49 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,15 @@ import {
1212
} from "./response/response-mapping.js";
1313
import { mapApiResponse } from "./response/api-response-mapping.js";
1414
import { fetchServerSentEvents } from "./response/server-sent-events.js";
15-
import { mapCollections, QueryAgentCollectionConfig } from "./collection.js";
15+
import {
16+
mapCollections,
17+
QueryAgentCollection,
18+
QueryAgentCollectionConfig,
19+
} from "./collection.js";
1620
import { handleError } from "./response/error.js";
1721
import { QueryAgentSearcher } from "./search.js";
1822
import { SearchModeResponse } from "./response/response.js";
23+
import { getHeaders } from "./connection.js";
1924

2025
/**
2126
* An agent for executing agentic queries against Weaviate.
@@ -97,33 +102,23 @@ export class QueryAgent {
97102
/**
98103
* Ask query agent a question.
99104
*
100-
* @param query - The natural language query string for the agent.
105+
* @param query - The natural language query string or conversation context for the agent.
101106
* @param options - Additional options for the run.
102107
* @returns The response from the query agent.
103108
*/
104109
async ask(
105-
query: string,
110+
query: QueryAgentQuery,
106111
{ collections }: QueryAgentAskOptions = {},
107112
): Promise<QueryAgentResponse> {
108-
const targetCollections = collections ?? this.collections;
109-
if (!targetCollections) {
110-
throw Error("No collections provided to the query agent.");
111-
}
112-
113-
const { host, bearerToken, headers } =
114-
await this.client.getConnectionDetails();
113+
const targetCollections = this.validateCollections(collections);
114+
const { requestHeaders, connectionHeaders } = await getHeaders(this.client);
115115

116116
const response = await fetch(`${this.agentsHost}/agent/query`, {
117117
method: "POST",
118-
headers: {
119-
"Content-Type": "application/json",
120-
Authorization: bearerToken!,
121-
"X-Weaviate-Cluster-Url": host,
122-
"X-Agent-Request-Origin": "typescript-client",
123-
},
118+
headers: requestHeaders,
124119
body: JSON.stringify({
125-
headers,
126-
query,
120+
headers: connectionHeaders,
121+
query: typeof query === "string" ? query : { messages: query },
127122
collections: mapCollections(targetCollections),
128123
system_prompt: this.systemPrompt,
129124
}),
@@ -238,40 +233,40 @@ export class QueryAgent {
238233
/**
239234
* Ask query agent a question and stream the response.
240235
*
241-
* @param query - The natural language query string for the agent.
236+
* @param query - The natural language query string or conversation context for the agent.
242237
* @param options - Additional options for the run.
243238
* @returns The response from the query agent.
244239
*/
245240
askStream(
246-
query: string,
241+
query: QueryAgentQuery,
247242
options: QueryAgentAskStreamOptions & {
248243
includeProgress: false;
249244
includeFinalState: false;
250245
},
251246
): AsyncGenerator<StreamedTokens>;
252247
askStream(
253-
query: string,
248+
query: QueryAgentQuery,
254249
options: QueryAgentAskStreamOptions & {
255250
includeProgress: false;
256251
includeFinalState?: true;
257252
},
258253
): AsyncGenerator<StreamedTokens | QueryAgentResponse>;
259254
askStream(
260-
query: string,
255+
query: QueryAgentQuery,
261256
options: QueryAgentAskStreamOptions & {
262257
includeProgress?: true;
263258
includeFinalState: false;
264259
},
265260
): AsyncGenerator<ProgressMessage | StreamedTokens>;
266261
askStream(
267-
query: string,
262+
query: QueryAgentQuery,
268263
options?: QueryAgentAskStreamOptions & {
269264
includeProgress?: true;
270265
includeFinalState?: true;
271266
},
272267
): AsyncGenerator<ProgressMessage | StreamedTokens | QueryAgentResponse>;
273268
async *askStream(
274-
query: string,
269+
query: QueryAgentQuery,
275270
{
276271
collections,
277272
includeProgress,
@@ -299,7 +294,7 @@ export class QueryAgent {
299294
},
300295
body: JSON.stringify({
301296
headers,
302-
query,
297+
query: typeof query === "string" ? query : { messages: query },
303298
collections: mapCollections(targetCollections),
304299
system_prompt: this.systemPrompt,
305300
include_progress: includeProgress ?? true,
@@ -336,28 +331,50 @@ export class QueryAgent {
336331
* reuses the same underlying searches to ensure consistency across pages.
337332
*/
338333
async search(
339-
query: string,
334+
query: QueryAgentQuery,
340335
{ limit = 20, collections }: QueryAgentSearchOnlyOptions = {},
341336
): Promise<SearchModeResponse> {
342-
const searcher = new QueryAgentSearcher(this.client, query, {
343-
collections: collections ?? this.collections,
344-
systemPrompt: this.systemPrompt,
345-
agentsHost: this.agentsHost,
346-
});
337+
const searcher = new QueryAgentSearcher(
338+
this.client,
339+
query,
340+
this.validateCollections(collections),
341+
this.systemPrompt,
342+
this.agentsHost,
343+
);
344+
347345
return searcher.run({ limit, offset: 0 });
348346
}
347+
348+
private validateCollections = (
349+
collections?: QueryAgentCollection[],
350+
): QueryAgentCollection[] => {
351+
const targetCollections = collections ?? this.collections;
352+
353+
if (!targetCollections) {
354+
throw Error("No collections provided to the query agent.");
355+
}
356+
357+
return targetCollections;
358+
};
349359
}
350360

351361
/** Options for the QueryAgent. */
352362
export type QueryAgentOptions = {
353363
/** List of collections to query. Will be overriden if passed in the `run` method. */
354-
collections?: (string | QueryAgentCollectionConfig)[];
364+
collections?: QueryAgentCollection[];
355365
/** System prompt to guide the agent's behavior. */
356366
systemPrompt?: string;
357367
/** Host of the agents service. */
358368
agentsHost?: string;
359369
};
360370

371+
export type QueryAgentQuery = string | ChatMessage[];
372+
373+
export type ChatMessage = {
374+
role: "user" | "assistant";
375+
content: string;
376+
};
377+
361378
/** Options for the QueryAgent run. */
362379
export type QueryAgentRunOptions = {
363380
/** List of collections to query. Will override any collections if passed in the constructor. */

src/query/collection.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ export const mapCollections = (
1212
},
1313
);
1414

15+
export type QueryAgentCollection = string | QueryAgentCollectionConfig;
16+
1517
/** Configuration for a collection to query. */
1618
export type QueryAgentCollectionConfig = {
1719
/** The name of the collection to query. */

src/query/connection.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import { WeaviateClient } from "weaviate-client";
2+
3+
export const getHeaders = async (client: WeaviateClient) => {
4+
const {
5+
host,
6+
bearerToken,
7+
headers: connectionHeaders,
8+
} = await client.getConnectionDetails();
9+
10+
const requestHeaders = {
11+
"Content-Type": "application/json",
12+
Authorization: bearerToken!,
13+
"X-Weaviate-Cluster-Url": host,
14+
"X-Agent-Request-Origin": "typescript-client",
15+
};
16+
17+
return { requestHeaders, connectionHeaders };
18+
};

src/query/index.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
export * from "./agent.js";
2-
export { QueryAgentCollectionConfig } from "./collection.js";
2+
export {
3+
QueryAgentCollection,
4+
QueryAgentCollectionConfig,
5+
} from "./collection.js";
36
export * from "./response/index.js";
4-
export * from "./search.js";

src/query/search.ts

Lines changed: 12 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@ import {
44
SearchModeResponse,
55
} from "./response/response.js";
66
import { mapSearchOnlyResponse } from "./response/response-mapping.js";
7-
import { mapCollections, QueryAgentCollectionConfig } from "./collection.js";
7+
import { mapCollections } from "./collection.js";
88
import { handleError } from "./response/error.js";
99
import {
1010
ApiSearchModeResponse,
1111
ApiSearchResult,
1212
} from "./response/api-response.js";
13+
import { QueryAgentQuery } from "./agent.js";
14+
import { QueryAgentCollection } from "./collection.js";
15+
import { getHeaders } from "./connection.js";
1316

1417
/**
1518
* A configured searcher for the QueryAgent.
@@ -25,44 +28,15 @@ import {
2528
* For more information, see the [Weaviate Query Agent Docs](https://weaviate.io/developers/agents/query)
2629
*/
2730
export class QueryAgentSearcher {
28-
private agentsHost: string;
29-
private query: string;
30-
private collections: (string | QueryAgentCollectionConfig)[];
31-
private systemPrompt?: string;
3231
private cachedSearches?: ApiSearchResult[];
3332

3433
constructor(
3534
private client: WeaviateClient,
36-
query: string,
37-
{
38-
collections = [],
39-
systemPrompt,
40-
agentsHost = "https://api.agents.weaviate.io",
41-
}: {
42-
collections?: (string | QueryAgentCollectionConfig)[];
43-
systemPrompt?: string;
44-
agentsHost?: string;
45-
} = {},
46-
) {
47-
this.query = query;
48-
this.collections = collections;
49-
this.systemPrompt = systemPrompt;
50-
this.agentsHost = agentsHost;
51-
this.cachedSearches = undefined;
52-
}
53-
54-
private async getHeaders() {
55-
const { host, bearerToken, headers } =
56-
await this.client.getConnectionDetails();
57-
const requestHeaders = {
58-
"Content-Type": "application/json",
59-
Authorization: bearerToken!,
60-
"X-Weaviate-Cluster-Url": host,
61-
"X-Agent-Request-Origin": "typescript-client",
62-
};
63-
const connectionHeaders = headers;
64-
return { requestHeaders, connectionHeaders };
65-
}
35+
private query: QueryAgentQuery,
36+
private collections: QueryAgentCollection[],
37+
private systemPrompt: string | undefined,
38+
private agentsHost: string,
39+
) {}
6640

6741
private buildRequestBody(
6842
limit: number,
@@ -71,7 +45,8 @@ export class QueryAgentSearcher {
7145
) {
7246
const base = {
7347
headers: connectionHeaders,
74-
original_query: this.query,
48+
original_query:
49+
typeof this.query === "string" ? this.query : { messages: this.query },
7550
collections: mapCollections(this.collections),
7651
limit,
7752
offset,
@@ -108,7 +83,7 @@ export class QueryAgentSearcher {
10883
if (!this.collections || this.collections.length === 0) {
10984
throw Error("No collections provided to the query agent.");
11085
}
111-
const { requestHeaders, connectionHeaders } = await this.getHeaders();
86+
const { requestHeaders, connectionHeaders } = await getHeaders(this.client);
11287

11388
const response = await fetch(`${this.agentsHost}/agent/search_only`, {
11489
method: "POST",

0 commit comments

Comments
 (0)