Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 51 additions & 30 deletions src/query/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@ import {
} from "./response/response-mapping.js";
import { mapApiResponse } from "./response/api-response-mapping.js";
import { fetchServerSentEvents } from "./response/server-sent-events.js";
import { mapCollections, QueryAgentCollectionConfig } from "./collection.js";
import {
mapCollections,
QueryAgentCollection,
QueryAgentCollectionConfig,
} from "./collection.js";
import { handleError } from "./response/error.js";
import { QueryAgentSearcher } from "./search.js";
import { SearchModeResponse } from "./response/response.js";
import { getHeaders } from "./connection.js";

/**
* An agent for executing agentic queries against Weaviate.
Expand Down Expand Up @@ -97,32 +102,22 @@ export class QueryAgent {
/**
* Ask query agent a question.
*
* @param query - The natural language query string for the agent.
* @param query - The natural language query string or conversation context for the agent.
* @param options - Additional options for the run.
* @returns The response from the query agent.
*/
async ask(
query: string,
query: QueryAgentQuery,
{ collections }: QueryAgentAskOptions = {},
): Promise<QueryAgentResponse> {
const targetCollections = collections ?? this.collections;
if (!targetCollections) {
throw Error("No collections provided to the query agent.");
}

const { host, bearerToken, headers } =
await this.client.getConnectionDetails();
const targetCollections = this.validateCollections(collections);
const { requestHeaders, connectionHeaders } = await getHeaders(this.client);

const response = await fetch(`${this.agentsHost}/agent/query`, {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: bearerToken!,
"X-Weaviate-Cluster-Url": host,
"X-Agent-Request-Origin": "typescript-client",
},
headers: requestHeaders,
body: JSON.stringify({
headers,
headers: connectionHeaders,
query,
collections: mapCollections(targetCollections),
system_prompt: this.systemPrompt,
Expand Down Expand Up @@ -238,40 +233,40 @@ export class QueryAgent {
/**
* Ask query agent a question and stream the response.
*
* @param query - The natural language query string for the agent.
* @param query - The natural language query string or conversation context for the agent.
* @param options - Additional options for the run.
* @returns The response from the query agent.
*/
askStream(
query: string,
query: QueryAgentQuery,
options: QueryAgentAskStreamOptions & {
includeProgress: false;
includeFinalState: false;
},
): AsyncGenerator<StreamedTokens>;
askStream(
query: string,
query: QueryAgentQuery,
options: QueryAgentAskStreamOptions & {
includeProgress: false;
includeFinalState?: true;
},
): AsyncGenerator<StreamedTokens | QueryAgentResponse>;
askStream(
query: string,
query: QueryAgentQuery,
options: QueryAgentAskStreamOptions & {
includeProgress?: true;
includeFinalState: false;
},
): AsyncGenerator<ProgressMessage | StreamedTokens>;
askStream(
query: string,
query: QueryAgentQuery,
options?: QueryAgentAskStreamOptions & {
includeProgress?: true;
includeFinalState?: true;
},
): AsyncGenerator<ProgressMessage | StreamedTokens | QueryAgentResponse>;
async *askStream(
query: string,
query: QueryAgentQuery,
{
collections,
includeProgress,
Expand Down Expand Up @@ -336,28 +331,54 @@ export class QueryAgent {
* reuses the same underlying searches to ensure consistency across pages.
*/
async search(
query: string,
query: QueryAgentQuery,
{ limit = 20, collections }: QueryAgentSearchOnlyOptions = {},
): Promise<SearchModeResponse> {
const searcher = new QueryAgentSearcher(this.client, query, {
collections: collections ?? this.collections,
systemPrompt: this.systemPrompt,
agentsHost: this.agentsHost,
});
const searcher = new QueryAgentSearcher(
this.client,
query,
this.validateCollections(collections),
this.systemPrompt,
this.agentsHost,
);

return searcher.run({ limit, offset: 0 });
}

private validateCollections = (
collections?: QueryAgentCollection[],
): QueryAgentCollection[] => {
const targetCollections = collections ?? this.collections;

if (!targetCollections) {
throw Error("No collections provided to the query agent.");
}

return targetCollections;
};
}

/** Options for the QueryAgent. */
export type QueryAgentOptions = {
/** List of collections to query. Will be overriden if passed in the `run` method. */
collections?: (string | QueryAgentCollectionConfig)[];
collections?: QueryAgentCollection[];
/** System prompt to guide the agent's behavior. */
systemPrompt?: string;
/** Host of the agents service. */
agentsHost?: string;
};

export type QueryAgentQuery = string | ConversationContext;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to make this string | ChatMessage[] and then wrap in ConversationContext just for the request, so the user doesn't need to include the {messages: ...} (to match the python client)?


export type ConversationContext = {
messages: ChatMessage[];
};

export type ChatMessage = {
role: "user" | "assistant";
content: string;
};

/** Options for the QueryAgent run. */
export type QueryAgentRunOptions = {
/** List of collections to query. Will override any collections if passed in the constructor. */
Expand Down
2 changes: 2 additions & 0 deletions src/query/collection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ export const mapCollections = (
},
);

export type QueryAgentCollection = string | QueryAgentCollectionConfig;

/** Configuration for a collection to query. */
export type QueryAgentCollectionConfig = {
/** The name of the collection to query. */
Expand Down
18 changes: 18 additions & 0 deletions src/query/connection.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { WeaviateClient } from "weaviate-client";

export const getHeaders = async (client: WeaviateClient) => {
const {
host,
bearerToken,
headers: connectionHeaders,
} = await client.getConnectionDetails();

const requestHeaders = {
"Content-Type": "application/json",
Authorization: bearerToken!,
"X-Weaviate-Cluster-Url": host,
"X-Agent-Request-Origin": "typescript-client",
};

return { requestHeaders, connectionHeaders };
};
6 changes: 4 additions & 2 deletions src/query/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
export * from "./agent.js";
export { QueryAgentCollectionConfig } from "./collection.js";
export {
QueryAgentCollection,
QueryAgentCollectionConfig,
} from "./collection.js";
export * from "./response/index.js";
export * from "./search.js";
46 changes: 10 additions & 36 deletions src/query/search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ import {
SearchModeResponse,
} from "./response/response.js";
import { mapSearchOnlyResponse } from "./response/response-mapping.js";
import { mapCollections, QueryAgentCollectionConfig } from "./collection.js";
import { mapCollections } from "./collection.js";
import { handleError } from "./response/error.js";
import {
ApiSearchModeResponse,
ApiSearchResult,
} from "./response/api-response.js";
import { QueryAgentQuery } from "./agent.js";
import { QueryAgentCollection } from "./collection.js";
import { getHeaders } from "./connection.js";

/**
* A configured searcher for the QueryAgent.
Expand All @@ -25,44 +28,15 @@ import {
* For more information, see the [Weaviate Query Agent Docs](https://weaviate.io/developers/agents/query)
*/
export class QueryAgentSearcher {
private agentsHost: string;
private query: string;
private collections: (string | QueryAgentCollectionConfig)[];
private systemPrompt?: string;
private cachedSearches?: ApiSearchResult[];

constructor(
private client: WeaviateClient,
query: string,
{
collections = [],
systemPrompt,
agentsHost = "https://api.agents.weaviate.io",
}: {
collections?: (string | QueryAgentCollectionConfig)[];
systemPrompt?: string;
agentsHost?: string;
} = {},
) {
this.query = query;
this.collections = collections;
this.systemPrompt = systemPrompt;
this.agentsHost = agentsHost;
this.cachedSearches = undefined;
}

private async getHeaders() {
const { host, bearerToken, headers } =
await this.client.getConnectionDetails();
const requestHeaders = {
"Content-Type": "application/json",
Authorization: bearerToken!,
"X-Weaviate-Cluster-Url": host,
"X-Agent-Request-Origin": "typescript-client",
};
const connectionHeaders = headers;
return { requestHeaders, connectionHeaders };
}
private query: QueryAgentQuery,
private collections: QueryAgentCollection[],
private systemPrompt: string | undefined,
private agentsHost: string,
) {}

private buildRequestBody(
limit: number,
Expand Down Expand Up @@ -108,7 +82,7 @@ export class QueryAgentSearcher {
if (!this.collections || this.collections.length === 0) {
throw Error("No collections provided to the query agent.");
}
const { requestHeaders, connectionHeaders } = await this.getHeaders();
const { requestHeaders, connectionHeaders } = await getHeaders(this.client);

const response = await fetch(`${this.agentsHost}/agent/search_only`, {
method: "POST",
Expand Down