Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 156 additions & 0 deletions src/query/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ export class QueryAgent {
/**
* Run the query agent.
*
* @deprecated Use {@link ask} instead.
* @param query - The natural language query string for the agent.
* @param options - Additional options for the run.
* @returns The response from the query agent.
Expand Down Expand Up @@ -93,9 +94,52 @@ export class QueryAgent {
return mapResponse(await response.json());
}

/**
* Ask query agent a question.
*
* @param query - The natural language query string for the agent.
* @param options - Additional options for the run.
* @returns The response from the query agent.
*/
async ask(
query: string,
{ collections }: QueryAgentAskOptions = {},
): Promise<QueryAgentResponse> {
const targetCollections = collections ?? this.collections;
if (!targetCollections) {
throw Error("No collections provided to the query agent.");
}

const { host, bearerToken, headers } =
await this.client.getConnectionDetails();

const response = await fetch(`${this.agentsHost}/agent/query`, {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: bearerToken!,
"X-Weaviate-Cluster-Url": host,
"X-Agent-Request-Origin": "typescript-client",
},
body: JSON.stringify({
headers,
query,
collections: mapCollections(targetCollections),
system_prompt: this.systemPrompt,
}),
});

if (!response.ok) {
await handleError(await response.text());
}

return mapResponse(await response.json());
}

/**
* Stream responses from the query agent.
*
* @deprecated Use {@link askStream} instead.
* @param query - The natural language query string for the agent.
* @param options - Additional options for the run.
* @returns The response from the query agent.
Expand All @@ -107,20 +151,23 @@ export class QueryAgent {
includeFinalState: false;
},
): AsyncGenerator<StreamedTokens>;
/** @deprecated Use {@link askStream} instead. */
stream(
query: string,
options: QueryAgentStreamOptions & {
includeProgress: false;
includeFinalState?: true;
},
): AsyncGenerator<StreamedTokens | QueryAgentResponse>;
/** @deprecated Use {@link askStream} instead. */
stream(
query: string,
options: QueryAgentStreamOptions & {
includeProgress?: true;
includeFinalState: false;
},
): AsyncGenerator<ProgressMessage | StreamedTokens>;
/** @deprecated Use {@link askStream} instead. */
stream(
query: string,
options?: QueryAgentStreamOptions & {
Expand Down Expand Up @@ -188,6 +235,99 @@ export class QueryAgent {
}
}

/**
* Ask query agent a question and stream the response.
*
* @param query - The natural language query string for the agent.
* @param options - Additional options for the run.
* @returns The response from the query agent.
*/
askStream(
query: string,
options: QueryAgentAskStreamOptions & {
includeProgress: false;
includeFinalState: false;
},
): AsyncGenerator<StreamedTokens>;
askStream(
query: string,
options: QueryAgentAskStreamOptions & {
includeProgress: false;
includeFinalState?: true;
},
): AsyncGenerator<StreamedTokens | QueryAgentResponse>;
askStream(
query: string,
options: QueryAgentAskStreamOptions & {
includeProgress?: true;
includeFinalState: false;
},
): AsyncGenerator<ProgressMessage | StreamedTokens>;
askStream(
query: string,
options?: QueryAgentAskStreamOptions & {
includeProgress?: true;
includeFinalState?: true;
},
): AsyncGenerator<ProgressMessage | StreamedTokens | QueryAgentResponse>;
async *askStream(
query: string,
{
collections,
includeProgress,
includeFinalState,
}: QueryAgentAskStreamOptions = {},
): AsyncGenerator<ProgressMessage | StreamedTokens | QueryAgentResponse> {
const targetCollections = collections ?? this.collections;

if (!targetCollections) {
throw Error("No collections provided to the query agent.");
}

const { host, bearerToken, headers } =
await this.client.getConnectionDetails();

const sseStream = fetchServerSentEvents(
`${this.agentsHost}/agent/stream_query`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: bearerToken!,
"X-Weaviate-Cluster-Url": host,
"X-Agent-Request-Origin": "typescript-client",
},
body: JSON.stringify({
headers,
query,
collections: mapCollections(targetCollections),
system_prompt: this.systemPrompt,
include_progress: includeProgress ?? true,
include_final_state: includeFinalState ?? true,
}),
},
);

for await (const event of sseStream) {
if (event.event === "error") {
await handleError(event.data);
}

let output: ProgressMessage | StreamedTokens | QueryAgentResponse;
if (event.event === "progress_message") {
output = mapProgressMessageFromSSE(event);
} else if (event.event === "streamed_tokens") {
output = mapStreamedTokensFromSSE(event);
} else if (event.event === "final_state") {
output = mapResponseFromSSE(event);
} else {
throw new Error(`Unexpected event type: ${event.event}: ${event.data}`);
}

yield output;
}
}

/**
* Run the Query Agent search-only mode.
*
Expand Down Expand Up @@ -226,6 +366,12 @@ export type QueryAgentRunOptions = {
context?: QueryAgentResponse;
};

/** Options for the QueryAgent ask. */
export type QueryAgentAskOptions = {
/** List of collections to query. Will override any collections if passed in the constructor. */
collections?: (string | QueryAgentCollectionConfig)[];
};

/** Options for the QueryAgent stream. */
export type QueryAgentStreamOptions = {
/** List of collections to query. Will override any collections if passed in the constructor. */
Expand All @@ -238,6 +384,16 @@ export type QueryAgentStreamOptions = {
includeFinalState?: boolean;
};

/** Options for the QueryAgent askStream. */
export type QueryAgentAskStreamOptions = {
/** List of collections to query. Will override any collections if passed in the constructor. */
collections?: (string | QueryAgentCollectionConfig)[];
/** Include progress messages in the stream. */
includeProgress?: boolean;
/** Include final state in the stream. */
includeFinalState?: boolean;
};

/** Options for the QueryAgent search-only run. */
export type QueryAgentSearchOnlyOptions = {
/** The maximum number of results to return. */
Expand Down