Skip to content

Commit 529666a

Browse files
authored
introduce ask/askStream and deprecate run/stream (#30)
* introduce ask/askStream and deprecate run/stream * remove context from ask functions
1 parent 820e581 commit 529666a

File tree

1 file changed

+156
-0
lines changed

1 file changed

+156
-0
lines changed

src/query/agent.ts

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ export class QueryAgent {
5353
/**
5454
* Run the query agent.
5555
*
56+
* @deprecated Use {@link ask} instead.
5657
* @param query - The natural language query string for the agent.
5758
* @param options - Additional options for the run.
5859
* @returns The response from the query agent.
@@ -93,9 +94,52 @@ export class QueryAgent {
9394
return mapResponse(await response.json());
9495
}
9596

97+
/**
98+
* Ask query agent a question.
99+
*
100+
* @param query - The natural language query string for the agent.
101+
* @param options - Additional options for the run.
102+
* @returns The response from the query agent.
103+
*/
104+
async ask(
105+
query: string,
106+
{ collections }: QueryAgentAskOptions = {},
107+
): Promise<QueryAgentResponse> {
108+
const targetCollections = collections ?? this.collections;
109+
if (!targetCollections) {
110+
throw Error("No collections provided to the query agent.");
111+
}
112+
113+
const { host, bearerToken, headers } =
114+
await this.client.getConnectionDetails();
115+
116+
const response = await fetch(`${this.agentsHost}/agent/query`, {
117+
method: "POST",
118+
headers: {
119+
"Content-Type": "application/json",
120+
Authorization: bearerToken!,
121+
"X-Weaviate-Cluster-Url": host,
122+
"X-Agent-Request-Origin": "typescript-client",
123+
},
124+
body: JSON.stringify({
125+
headers,
126+
query,
127+
collections: mapCollections(targetCollections),
128+
system_prompt: this.systemPrompt,
129+
}),
130+
});
131+
132+
if (!response.ok) {
133+
await handleError(await response.text());
134+
}
135+
136+
return mapResponse(await response.json());
137+
}
138+
96139
/**
97140
* Stream responses from the query agent.
98141
*
142+
* @deprecated Use {@link askStream} instead.
99143
* @param query - The natural language query string for the agent.
100144
* @param options - Additional options for the run.
101145
* @returns The response from the query agent.
@@ -107,20 +151,23 @@ export class QueryAgent {
107151
includeFinalState: false;
108152
},
109153
): AsyncGenerator<StreamedTokens>;
154+
/** @deprecated Use {@link askStream} instead. */
110155
stream(
111156
query: string,
112157
options: QueryAgentStreamOptions & {
113158
includeProgress: false;
114159
includeFinalState?: true;
115160
},
116161
): AsyncGenerator<StreamedTokens | QueryAgentResponse>;
162+
/** @deprecated Use {@link askStream} instead. */
117163
stream(
118164
query: string,
119165
options: QueryAgentStreamOptions & {
120166
includeProgress?: true;
121167
includeFinalState: false;
122168
},
123169
): AsyncGenerator<ProgressMessage | StreamedTokens>;
170+
/** @deprecated Use {@link askStream} instead. */
124171
stream(
125172
query: string,
126173
options?: QueryAgentStreamOptions & {
@@ -188,6 +235,99 @@ export class QueryAgent {
188235
}
189236
}
190237

238+
/**
239+
* Ask query agent a question and stream the response.
240+
*
241+
* @param query - The natural language query string for the agent.
242+
* @param options - Additional options for the run.
243+
* @returns The response from the query agent.
244+
*/
245+
askStream(
246+
query: string,
247+
options: QueryAgentAskStreamOptions & {
248+
includeProgress: false;
249+
includeFinalState: false;
250+
},
251+
): AsyncGenerator<StreamedTokens>;
252+
askStream(
253+
query: string,
254+
options: QueryAgentAskStreamOptions & {
255+
includeProgress: false;
256+
includeFinalState?: true;
257+
},
258+
): AsyncGenerator<StreamedTokens | QueryAgentResponse>;
259+
askStream(
260+
query: string,
261+
options: QueryAgentAskStreamOptions & {
262+
includeProgress?: true;
263+
includeFinalState: false;
264+
},
265+
): AsyncGenerator<ProgressMessage | StreamedTokens>;
266+
askStream(
267+
query: string,
268+
options?: QueryAgentAskStreamOptions & {
269+
includeProgress?: true;
270+
includeFinalState?: true;
271+
},
272+
): AsyncGenerator<ProgressMessage | StreamedTokens | QueryAgentResponse>;
273+
async *askStream(
274+
query: string,
275+
{
276+
collections,
277+
includeProgress,
278+
includeFinalState,
279+
}: QueryAgentAskStreamOptions = {},
280+
): AsyncGenerator<ProgressMessage | StreamedTokens | QueryAgentResponse> {
281+
const targetCollections = collections ?? this.collections;
282+
283+
if (!targetCollections) {
284+
throw Error("No collections provided to the query agent.");
285+
}
286+
287+
const { host, bearerToken, headers } =
288+
await this.client.getConnectionDetails();
289+
290+
const sseStream = fetchServerSentEvents(
291+
`${this.agentsHost}/agent/stream_query`,
292+
{
293+
method: "POST",
294+
headers: {
295+
"Content-Type": "application/json",
296+
Authorization: bearerToken!,
297+
"X-Weaviate-Cluster-Url": host,
298+
"X-Agent-Request-Origin": "typescript-client",
299+
},
300+
body: JSON.stringify({
301+
headers,
302+
query,
303+
collections: mapCollections(targetCollections),
304+
system_prompt: this.systemPrompt,
305+
include_progress: includeProgress ?? true,
306+
include_final_state: includeFinalState ?? true,
307+
}),
308+
},
309+
);
310+
311+
for await (const event of sseStream) {
312+
if (event.event === "error") {
313+
await handleError(event.data);
314+
}
315+
316+
let output: ProgressMessage | StreamedTokens | QueryAgentResponse;
317+
if (event.event === "progress_message") {
318+
output = mapProgressMessageFromSSE(event);
319+
} else if (event.event === "streamed_tokens") {
320+
output = mapStreamedTokensFromSSE(event);
321+
} else if (event.event === "final_state") {
322+
output = mapResponseFromSSE(event);
323+
} else {
324+
throw new Error(`Unexpected event type: ${event.event}: ${event.data}`);
325+
}
326+
327+
yield output;
328+
}
329+
}
330+
191331
/**
192332
* Run the Query Agent search-only mode.
193333
*
@@ -226,6 +366,12 @@ export type QueryAgentRunOptions = {
226366
context?: QueryAgentResponse;
227367
};
228368

369+
/** Options for the QueryAgent ask. */
370+
export type QueryAgentAskOptions = {
371+
/** List of collections to query. Will override any collections if passed in the constructor. */
372+
collections?: (string | QueryAgentCollectionConfig)[];
373+
};
374+
229375
/** Options for the QueryAgent stream. */
230376
export type QueryAgentStreamOptions = {
231377
/** List of collections to query. Will override any collections if passed in the constructor. */
@@ -238,6 +384,16 @@ export type QueryAgentStreamOptions = {
238384
includeFinalState?: boolean;
239385
};
240386

387+
/** Options for the QueryAgent askStream. */
388+
export type QueryAgentAskStreamOptions = {
389+
/** List of collections to query. Will override any collections if passed in the constructor. */
390+
collections?: (string | QueryAgentCollectionConfig)[];
391+
/** Include progress messages in the stream. */
392+
includeProgress?: boolean;
393+
/** Include final state in the stream. */
394+
includeFinalState?: boolean;
395+
};
396+
241397
/** Options for the QueryAgent search-only run. */
242398
export type QueryAgentSearchOnlyOptions = {
243399
/** The maximum number of results to return. */

0 commit comments

Comments
 (0)