Skip to content
223 changes: 222 additions & 1 deletion src/query/agent.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import { WeaviateClient } from "weaviate-client";
import { QueryAgent } from "./agent.js";
import { ApiQueryAgentResponse } from "./response/api-response.js";
import { QueryAgentResponse } from "./response/response.js";
import { QueryAgentResponse, ComparisonOperator } from "./response/response.js";
import { ApiSearchModeResponse } from "./response/api-response.js";
import { QueryAgentError } from "./response/error.js";

it("runs the query agent", async () => {
const mockClient = {
Expand Down Expand Up @@ -93,3 +95,222 @@ it("runs the query agent", async () => {
display: expect.any(Function),
});
});

it("search-only mode success: caches searches and sends on subsequent request", async () => {
const mockClient = {
getConnectionDetails: jest.fn().mockResolvedValue({
host: "test-cluster",
bearerToken: "test-token",
headers: { "X-Provider": "test-key" },
}),
} as unknown as WeaviateClient;

const capturedBodies: ApiSearchModeResponse[] = [];

const apiSuccess: ApiSearchModeResponse = {
original_query: "Test this search only mode!",
searches: [
{
queries: ["search query"],
filters: [
[
{
filter_type: "integer",
property_name: "test_property",
operator: ComparisonOperator.GreaterThan,
value: 0,
},
],
],
filter_operators: "AND",
collection: "test_collection",
},
],
usage: {
requests: 0,
request_tokens: undefined,
response_tokens: undefined,
total_tokens: undefined,
details: undefined,
},
total_time: 1.5,
search_results: {
objects: [
{
uuid: "e6dc0a31-76f8-4bd3-b563-677ced6eb557",
metadata: {
creation_time: null,
update_time: null,
distance: null,
certainty: null,
score: 0.8,
explain_score: null,
rerank_score: null,
is_consistent: null,
},
references: null,
vector: {},
properties: {
test_property: 1.0,
text: "hello",
},
collection: "test_collection",
},
{
uuid: "cf5401cc-f4f1-4eb9-a6a1-173d34f94339",
metadata: {
creation_time: null,
update_time: null,
distance: null,
certainty: null,
score: 0.5,
explain_score: null,
rerank_score: null,
is_consistent: null,
},
references: null,
vector: {},
properties: {
test_property: 2.0,
text: "world!",
},
collection: "test_collection",
},
],
},
};

// Mock the API response, and capture the request body to assert later
global.fetch = jest.fn((url, init?: RequestInit) => {
if (init && init.body) {
capturedBodies.push(
JSON.parse(init.body as string) as ApiSearchModeResponse,
);
}
return Promise.resolve({
ok: true,
json: () => Promise.resolve(apiSuccess),
} as Response);
}) as jest.Mock;

const agent = new QueryAgent(mockClient);

const first = await agent.search("test query", {
limit: 2,
collections: ["test_collection"],
});
expect(first).toMatchObject({
originalQuery: apiSuccess.original_query,
searches: [
{
collection: "test_collection",
queries: ["search query"],
filters: [
[
{
filterType: "integer",
propertyName: "test_property",
operator: ComparisonOperator.GreaterThan,
value: 0,
},
],
],
filterOperators: "AND",
},
],
usage: {
requests: 0,
requestTokens: undefined,
responseTokens: undefined,
totalTokens: undefined,
details: undefined,
},
totalTime: 1.5,
searchResults: {
objects: [
{
uuid: "e6dc0a31-76f8-4bd3-b563-677ced6eb557",
metadata: {
score: 0.8,
},
vectors: {},
properties: {
test_property: 1.0,
text: "hello",
},
collection: "test_collection",
},
{
uuid: "cf5401cc-f4f1-4eb9-a6a1-173d34f94339",
metadata: {
score: 0.5,
},
vectors: {},
properties: {
test_property: 2.0,
text: "world!",
},
collection: "test_collection",
},
],
},
});
expect(typeof first.next).toBe("function");

// First request should have searches: null (generation request)
expect(capturedBodies[0].searches).toBeNull();

// Second request uses the next method on the first response
const second = await first.next({ limit: 2, offset: 1 });
// Second request should include the original searches (execution request)
expect(capturedBodies[1].searches).toEqual(apiSuccess.searches);
// Response mapping should be the same (because response is mocked)
expect(second).toMatchObject({
originalQuery: apiSuccess.original_query,
searches: first.searches,
usage: first.usage,
totalTime: first.totalTime,
searchResults: first.searchResults,
});
expect(typeof second.next).toBe("function");
});

it("search-only mode failure propagates QueryAgentError", async () => {
const mockClient = {
getConnectionDetails: jest.fn().mockResolvedValue({
host: "test-cluster",
bearerToken: "test-token",
headers: { "X-Provider": "test-key" },
}),
} as unknown as WeaviateClient;

const errorJson = {
error: {
message: "Test error message",
code: "test_error_code",
details: { info: "test detail" },
},
};

global.fetch = jest.fn(() =>
Promise.resolve({
ok: false,
text: () => Promise.resolve(JSON.stringify(errorJson)),
} as Response),
) as jest.Mock;

const agent = new QueryAgent(mockClient);
try {
await agent.search("test query", {
limit: 2,
collections: ["test_collection"],
});
} catch (err) {
expect(err).toBeInstanceOf(QueryAgentError);
expect(err).toMatchObject({
message: "Test error message",
code: "test_error_code",
details: { info: "test detail" },
});
}
});
29 changes: 29 additions & 0 deletions src/query/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import { mapApiResponse } from "./response/api-response-mapping.js";
import { fetchServerSentEvents } from "./response/server-sent-events.js";
import { mapCollections, QueryAgentCollectionConfig } from "./collection.js";
import { handleError } from "./response/error.js";
import { QueryAgentSearcher } from "./search.js";
import { SearchModeResponse } from "./response/response.js";

/**
* An agent for executing agentic queries against Weaviate.
Expand Down Expand Up @@ -185,6 +187,25 @@ export class QueryAgent {
yield output;
}
}

/**
* Run the Query Agent search-only mode.
*
* Sends the initial search request and returns the first page of results.
* The returned response includes a `next` method for pagination which
* reuses the same underlying searches to ensure consistency across pages.
*/
async search(
query: string,
{ limit = 20, collections }: QueryAgentSearchOnlyOptions = {},
): Promise<SearchModeResponse> {
const searcher = new QueryAgentSearcher(this.client, query, {
collections: collections ?? this.collections,
systemPrompt: this.systemPrompt,
agentsHost: this.agentsHost,
});
return searcher.run({ limit, offset: 0 });
}
}

/** Options for the QueryAgent. */
Expand Down Expand Up @@ -216,3 +237,11 @@ export type QueryAgentStreamOptions = {
/** Include final state in the stream. */
includeFinalState?: boolean;
};

/** Options for the QueryAgent search-only run. */
export type QueryAgentSearchOnlyOptions = {
/** The maximum number of results to return. */
limit?: number;
/** List of collections to query. Will override any collections if passed in the constructor. */
collections?: (string | QueryAgentCollectionConfig)[];
};
1 change: 1 addition & 0 deletions src/query/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export * from "./agent.js";
export { QueryAgentCollectionConfig } from "./collection.js";
export * from "./response/index.js";
export * from "./search.js";
41 changes: 41 additions & 0 deletions src/query/response/api-response.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { Vectors, WeaviateField } from "weaviate-client";

import {
NumericMetrics,
TextMetrics,
Expand Down Expand Up @@ -177,3 +179,42 @@ export type ApiSource = {
object_id: string;
collection: string;
};

export type ApiReturnMetadata = {
creation_time: Date | null;
update_time: Date | null;
distance: number | null;
certainty: number | null;
score: number | null;
explain_score: string | null;
rerank_score: number | null;
is_consistent: boolean | null;
};

export type ApiWeaviateObject = {
/** The returned properties of the object as untyped key-value pairs from the API. */
properties: Record<string, WeaviateField>;
/** The returned metadata of the object. */
metadata: ApiReturnMetadata;
/** The returned references of the object. */
references: null;
/** The UUID of the object. */
uuid: string;
/** The returned vectors of the object. */
vector: Vectors;
/** The collection this object belongs to. */
collection: string;
};

export type ApiWeaviateReturn = {
/** The objects that were found by the query. */
objects: ApiWeaviateObject[];
};

export type ApiSearchModeResponse = {
original_query: string;
searches?: ApiSearchResult[];
usage: ApiUsage;
total_time: number;
search_results: ApiWeaviateReturn;
};
Loading