Skip to content

Commit 820e581

Browse files
danmichaeljonesDan Jones
andauthored
Add search-only mode (#29)
* Add new types for search only mode * Add search-only executor class * Add method to agent to build searcher * Lint * Make searcher generic to support generic WeaviateReturn * Update tests * Lint * Tidy up * Rename methods * Update to new search/next design * Lint * Address PR comments * Add mapping of search results * Remove MappedSearchModeResponse * Remove default options * Remove todo comments * Update docs * Lint * Update searcher docs * PR comments --------- Co-authored-by: Dan Jones <[email protected]>
1 parent 3db0d2d commit 820e581

File tree

7 files changed

+540
-9
lines changed

7 files changed

+540
-9
lines changed

src/query/agent.test.ts

Lines changed: 222 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import { WeaviateClient } from "weaviate-client";
22
import { QueryAgent } from "./agent.js";
33
import { ApiQueryAgentResponse } from "./response/api-response.js";
4-
import { QueryAgentResponse } from "./response/response.js";
4+
import { QueryAgentResponse, ComparisonOperator } from "./response/response.js";
5+
import { ApiSearchModeResponse } from "./response/api-response.js";
6+
import { QueryAgentError } from "./response/error.js";
57

68
it("runs the query agent", async () => {
79
const mockClient = {
@@ -93,3 +95,222 @@ it("runs the query agent", async () => {
9395
display: expect.any(Function),
9496
});
9597
});
98+
99+
it("search-only mode success: caches searches and sends on subsequent request", async () => {
100+
const mockClient = {
101+
getConnectionDetails: jest.fn().mockResolvedValue({
102+
host: "test-cluster",
103+
bearerToken: "test-token",
104+
headers: { "X-Provider": "test-key" },
105+
}),
106+
} as unknown as WeaviateClient;
107+
108+
const capturedBodies: ApiSearchModeResponse[] = [];
109+
110+
const apiSuccess: ApiSearchModeResponse = {
111+
original_query: "Test this search only mode!",
112+
searches: [
113+
{
114+
queries: ["search query"],
115+
filters: [
116+
[
117+
{
118+
filter_type: "integer",
119+
property_name: "test_property",
120+
operator: ComparisonOperator.GreaterThan,
121+
value: 0,
122+
},
123+
],
124+
],
125+
filter_operators: "AND",
126+
collection: "test_collection",
127+
},
128+
],
129+
usage: {
130+
requests: 0,
131+
request_tokens: undefined,
132+
response_tokens: undefined,
133+
total_tokens: undefined,
134+
details: undefined,
135+
},
136+
total_time: 1.5,
137+
search_results: {
138+
objects: [
139+
{
140+
uuid: "e6dc0a31-76f8-4bd3-b563-677ced6eb557",
141+
metadata: {
142+
creation_time: null,
143+
update_time: null,
144+
distance: null,
145+
certainty: null,
146+
score: 0.8,
147+
explain_score: null,
148+
rerank_score: null,
149+
is_consistent: null,
150+
},
151+
references: null,
152+
vector: {},
153+
properties: {
154+
test_property: 1.0,
155+
text: "hello",
156+
},
157+
collection: "test_collection",
158+
},
159+
{
160+
uuid: "cf5401cc-f4f1-4eb9-a6a1-173d34f94339",
161+
metadata: {
162+
creation_time: null,
163+
update_time: null,
164+
distance: null,
165+
certainty: null,
166+
score: 0.5,
167+
explain_score: null,
168+
rerank_score: null,
169+
is_consistent: null,
170+
},
171+
references: null,
172+
vector: {},
173+
properties: {
174+
test_property: 2.0,
175+
text: "world!",
176+
},
177+
collection: "test_collection",
178+
},
179+
],
180+
},
181+
};
182+
183+
// Mock the API response, and capture the request body to assert later
184+
global.fetch = jest.fn((url, init?: RequestInit) => {
185+
if (init && init.body) {
186+
capturedBodies.push(
187+
JSON.parse(init.body as string) as ApiSearchModeResponse,
188+
);
189+
}
190+
return Promise.resolve({
191+
ok: true,
192+
json: () => Promise.resolve(apiSuccess),
193+
} as Response);
194+
}) as jest.Mock;
195+
196+
const agent = new QueryAgent(mockClient);
197+
198+
const first = await agent.search("test query", {
199+
limit: 2,
200+
collections: ["test_collection"],
201+
});
202+
expect(first).toMatchObject({
203+
originalQuery: apiSuccess.original_query,
204+
searches: [
205+
{
206+
collection: "test_collection",
207+
queries: ["search query"],
208+
filters: [
209+
[
210+
{
211+
filterType: "integer",
212+
propertyName: "test_property",
213+
operator: ComparisonOperator.GreaterThan,
214+
value: 0,
215+
},
216+
],
217+
],
218+
filterOperators: "AND",
219+
},
220+
],
221+
usage: {
222+
requests: 0,
223+
requestTokens: undefined,
224+
responseTokens: undefined,
225+
totalTokens: undefined,
226+
details: undefined,
227+
},
228+
totalTime: 1.5,
229+
searchResults: {
230+
objects: [
231+
{
232+
uuid: "e6dc0a31-76f8-4bd3-b563-677ced6eb557",
233+
metadata: {
234+
score: 0.8,
235+
},
236+
vectors: {},
237+
properties: {
238+
test_property: 1.0,
239+
text: "hello",
240+
},
241+
collection: "test_collection",
242+
},
243+
{
244+
uuid: "cf5401cc-f4f1-4eb9-a6a1-173d34f94339",
245+
metadata: {
246+
score: 0.5,
247+
},
248+
vectors: {},
249+
properties: {
250+
test_property: 2.0,
251+
text: "world!",
252+
},
253+
collection: "test_collection",
254+
},
255+
],
256+
},
257+
});
258+
expect(typeof first.next).toBe("function");
259+
260+
// First request should have searches: null (generation request)
261+
expect(capturedBodies[0].searches).toBeNull();
262+
263+
// Second request uses the next method on the first response
264+
const second = await first.next({ limit: 2, offset: 1 });
265+
// Second request should include the original searches (execution request)
266+
expect(capturedBodies[1].searches).toEqual(apiSuccess.searches);
267+
// Response mapping should be the same (because response is mocked)
268+
expect(second).toMatchObject({
269+
originalQuery: apiSuccess.original_query,
270+
searches: first.searches,
271+
usage: first.usage,
272+
totalTime: first.totalTime,
273+
searchResults: first.searchResults,
274+
});
275+
expect(typeof second.next).toBe("function");
276+
});
277+
278+
it("search-only mode failure propagates QueryAgentError", async () => {
279+
const mockClient = {
280+
getConnectionDetails: jest.fn().mockResolvedValue({
281+
host: "test-cluster",
282+
bearerToken: "test-token",
283+
headers: { "X-Provider": "test-key" },
284+
}),
285+
} as unknown as WeaviateClient;
286+
287+
const errorJson = {
288+
error: {
289+
message: "Test error message",
290+
code: "test_error_code",
291+
details: { info: "test detail" },
292+
},
293+
};
294+
295+
global.fetch = jest.fn(() =>
296+
Promise.resolve({
297+
ok: false,
298+
text: () => Promise.resolve(JSON.stringify(errorJson)),
299+
} as Response),
300+
) as jest.Mock;
301+
302+
const agent = new QueryAgent(mockClient);
303+
try {
304+
await agent.search("test query", {
305+
limit: 2,
306+
collections: ["test_collection"],
307+
});
308+
} catch (err) {
309+
expect(err).toBeInstanceOf(QueryAgentError);
310+
expect(err).toMatchObject({
311+
message: "Test error message",
312+
code: "test_error_code",
313+
details: { info: "test detail" },
314+
});
315+
}
316+
});

src/query/agent.ts

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import { mapApiResponse } from "./response/api-response-mapping.js";
1414
import { fetchServerSentEvents } from "./response/server-sent-events.js";
1515
import { mapCollections, QueryAgentCollectionConfig } from "./collection.js";
1616
import { handleError } from "./response/error.js";
17+
import { QueryAgentSearcher } from "./search.js";
18+
import { SearchModeResponse } from "./response/response.js";
1719

1820
/**
1921
* An agent for executing agentic queries against Weaviate.
@@ -185,6 +187,25 @@ export class QueryAgent {
185187
yield output;
186188
}
187189
}
190+
191+
/**
192+
* Run the Query Agent search-only mode.
193+
*
194+
* Sends the initial search request and returns the first page of results.
195+
* The returned response includes a `next` method for pagination which
196+
* reuses the same underlying searches to ensure consistency across pages.
197+
*/
198+
async search(
199+
query: string,
200+
{ limit = 20, collections }: QueryAgentSearchOnlyOptions = {},
201+
): Promise<SearchModeResponse> {
202+
const searcher = new QueryAgentSearcher(this.client, query, {
203+
collections: collections ?? this.collections,
204+
systemPrompt: this.systemPrompt,
205+
agentsHost: this.agentsHost,
206+
});
207+
return searcher.run({ limit, offset: 0 });
208+
}
188209
}
189210

190211
/** Options for the QueryAgent. */
@@ -216,3 +237,11 @@ export type QueryAgentStreamOptions = {
216237
/** Include final state in the stream. */
217238
includeFinalState?: boolean;
218239
};
240+
241+
/** Options for the QueryAgent search-only run. */
242+
export type QueryAgentSearchOnlyOptions = {
243+
/** The maximum number of results to return. */
244+
limit?: number;
245+
/** List of collections to query. Will override any collections if passed in the constructor. */
246+
collections?: (string | QueryAgentCollectionConfig)[];
247+
};

src/query/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
export * from "./agent.js";
22
export { QueryAgentCollectionConfig } from "./collection.js";
33
export * from "./response/index.js";
4+
export * from "./search.js";

src/query/response/api-response.ts

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import { Vectors, WeaviateField } from "weaviate-client";
2+
13
import {
24
NumericMetrics,
35
TextMetrics,
@@ -177,3 +179,42 @@ export type ApiSource = {
177179
object_id: string;
178180
collection: string;
179181
};
182+
183+
export type ApiReturnMetadata = {
184+
creation_time: Date | null;
185+
update_time: Date | null;
186+
distance: number | null;
187+
certainty: number | null;
188+
score: number | null;
189+
explain_score: string | null;
190+
rerank_score: number | null;
191+
is_consistent: boolean | null;
192+
};
193+
194+
export type ApiWeaviateObject = {
195+
/** The returned properties of the object as untyped key-value pairs from the API. */
196+
properties: Record<string, WeaviateField>;
197+
/** The returned metadata of the object. */
198+
metadata: ApiReturnMetadata;
199+
/** The returned references of the object. */
200+
references: null;
201+
/** The UUID of the object. */
202+
uuid: string;
203+
/** The returned vectors of the object. */
204+
vector: Vectors;
205+
/** The collection this object belongs to. */
206+
collection: string;
207+
};
208+
209+
export type ApiWeaviateReturn = {
210+
/** The objects that were found by the query. */
211+
objects: ApiWeaviateObject[];
212+
};
213+
214+
export type ApiSearchModeResponse = {
215+
original_query: string;
216+
searches?: ApiSearchResult[];
217+
usage: ApiUsage;
218+
total_time: number;
219+
search_results: ApiWeaviateReturn;
220+
};

0 commit comments

Comments
 (0)