Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/nice-hairs-fetch.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"agents": minor
---

Type-safe react stub streaming calls
9 changes: 6 additions & 3 deletions packages/agents/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@ export type AgentClientOptions<State = unknown> = Omit<
/**
* Options for streaming RPC calls
*/
export type StreamOptions = {
export type StreamOptions<
OnChunkT extends unknown | SerializableValue = unknown,
OnDoneT extends unknown | SerializableValue = unknown
> = {
/** Called when a chunk of data is received */
onChunk?: (chunk: unknown) => void;
onChunk?: (chunk: OnChunkT) => void;
/** Called when the stream ends */
onDone?: (finalChunk: unknown) => void;
onDone?: (finalChunk: OnDoneT) => void;
/** Called when an error occurs */
onError?: (error: string) => void;
};
Expand Down
10 changes: 7 additions & 3 deletions packages/agents/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import type { TransportType } from "./mcp/types";
import { genericObservability, type Observability } from "./observability";
import { DisposableStore } from "./core/events";
import { MessageType } from "./ai-types";
import type { SerializableValue } from "./serializable";

export type { Connection, ConnectionContext, WSMessage } from "partyserver";

Expand Down Expand Up @@ -1982,7 +1983,10 @@ export async function getAgentByName<
/**
* A wrapper for streaming responses in callable methods
*/
export class StreamingResponse {
export class StreamingResponse<
OnChunkT extends SerializableValue | unknown = unknown,
OnDoneT extends SerializableValue | unknown = unknown
> {
private _connection: Connection;
private _id: string;
private _closed = false;
Expand All @@ -1996,7 +2000,7 @@ export class StreamingResponse {
* Send a chunk of data to the client
* @param chunk The data to send
*/
send(chunk: unknown) {
send(chunk: OnChunkT) {
if (this._closed) {
throw new Error("StreamingResponse is already closed");
}
Expand All @@ -2014,7 +2018,7 @@ export class StreamingResponse {
* End the stream and send the final chunk (if any)
* @param finalChunk Optional final chunk of data to send
*/
end(finalChunk?: unknown) {
end(finalChunk?: OnDoneT) {
if (this._closed) {
throw new Error("StreamingResponse is already closed");
}
Expand Down
176 changes: 166 additions & 10 deletions packages/agents/src/react.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ import { usePartySocket } from "partysocket/react";
import { useCallback, useRef, use, useMemo, useEffect } from "react";
import type { Agent, MCPServersState, RPCRequest, RPCResponse } from "./";
import type { StreamOptions } from "./client";
import type { Method, RPCMethod } from "./serializable";
import { MessageType } from "./ai-types";
import type {
AllSerializableValues,
SerializableReturnValue,
SerializableValue
} from "./serializable";

/**
* Convert a camelCase string to a kebab-case string
Expand Down Expand Up @@ -130,17 +134,76 @@ export type UseAgentOptions<State = unknown> = Omit<
onMcpUpdate?: (mcpServers: MCPServersState) => void;
};

// biome-ignore lint: suppressions/parse
type Method = (...args: any[]) => any;

type NonStreamingRPCMethod<T extends Method> =
AllSerializableValues<Parameters<T>> extends true
? ReturnType<T> extends SerializableReturnValue
? T
: never
: never;

interface StreamingResponse<
Chunk extends SerializableValue | unknown = unknown,
Done extends SerializableValue | unknown = unknown
> {
send(chunk: Chunk): void;
end(finalChunk?: Done): void;
}

type StreamingRPCMethod<T extends Method> = T extends (
arg: infer A,
...rest: infer R
) => void | Promise<void>
? A extends StreamingResponse<SerializableValue, SerializableValue>
? AllSerializableValues<R> extends true
? T
: never
: never
: never;

type RPCMethod<T extends Method> =
T extends NonStreamingRPCMethod<T>
? NonStreamingRPCMethod<T>
: T extends StreamingRPCMethod<T>
? StreamingRPCMethod<T>
: never;

type RPCMethods<T> = {
[K in keyof T as T[K] extends Method ? K : never]: T[K] extends Method
? RPCMethod<T[K]>
: never;
};

type AllOptional<T> = T extends [infer A, ...infer R]
? undefined extends A
? AllOptional<R>
: false
: true; // no params means optional by default

type RPCMethods<T> = {
[K in keyof T as T[K] extends RPCMethod<T[K]> ? K : never]: RPCMethod<T[K]>;
};
type StreamOptionsFrom<StreamingResponseT> =
StreamingResponseT extends StreamingResponse<
infer T extends SerializableValue,
infer U extends SerializableValue
>
? StreamOptions<T, U>
: never;

type ReturnAndChunkTypesFrom<StreamingResponseT extends StreamingResponse> =
StreamingResponseT extends StreamingResponse<
infer Chunk extends SerializableValue,
infer Done extends SerializableValue
>
? [Chunk, Done]
: never;

type RestParameters<T extends Method> =
Parameters<StreamingRPCMethod<T>> extends [unknown, ...infer Rest]
? Rest
: never;

type OptionalParametersMethod<T extends RPCMethod> =
type OptionalParametersMethod<T extends Method> =
AllOptional<Parameters<T>> extends true ? T : never;

// all methods of the Agent, excluding the ones that are declared in the base Agent class
Expand All @@ -160,6 +223,14 @@ type RequiredAgentMethods<T> = Omit<
keyof OptionalAgentMethods<T>
>;

type StreamingAgentMethods<T> = {
[K in keyof AgentMethods<T> as AgentMethods<T>[K] extends StreamingRPCMethod<
AgentMethods<T>[K]
>
? K
: never]: StreamingRPCMethod<AgentMethods<T>[K]>;
};

type AgentPromiseReturnType<T, K extends keyof AgentMethods<T>> =
// biome-ignore lint: suppressions/parse
ReturnType<AgentMethods<T>[K]> extends Promise<any>
Expand All @@ -182,7 +253,18 @@ type RequiredArgsAgentMethodCall<AgentT> = <
streamOptions?: StreamOptions
) => AgentPromiseReturnType<AgentT, K>;

type AgentMethodCall<AgentT> = OptionalArgsAgentMethodCall<AgentT> &
type StreamingAgentMethodCall<AgentT> = <
K extends keyof StreamingAgentMethods<AgentT>
>(
method: K,
args: RestParameters<StreamingAgentMethods<AgentT>[K]>,
streamOptions: StreamOptionsFrom<
Parameters<StreamingAgentMethods<AgentT>[K]>[0]
>
) => void;

type AgentMethodCall<AgentT> = StreamingAgentMethodCall<AgentT> &
OptionalArgsAgentMethodCall<AgentT> &
RequiredArgsAgentMethodCall<AgentT>;

type UntypedAgentMethodCall = <T = unknown>(
Expand All @@ -192,13 +274,35 @@ type UntypedAgentMethodCall = <T = unknown>(
) => Promise<T>;

type AgentStub<T> = {
[K in keyof AgentMethods<T>]: (
...args: Parameters<AgentMethods<T>[K]>
) => AgentPromiseReturnType<AgentMethods<T>, K>;
[K in keyof AgentMethods<T>]: AgentMethods<T>[K] extends NonStreamingRPCMethod<
AgentMethods<T>[K]
>
? (
...args: Parameters<AgentMethods<T>[K]>
) => AgentPromiseReturnType<AgentMethods<T>, K>
: never;
};

type AgentStreamingStub<T> = {
[K in keyof AgentMethods<T>]: AgentMethods<T>[K] extends StreamingRPCMethod<
AgentMethods<T>[K]
>
? (
...args: RestParameters<AgentMethods<T>[K]>
) => AsyncGenerator<
ReturnAndChunkTypesFrom<
Parameters<StreamingRPCMethod<AgentMethods<T>[K]>>[0]
>[0],
ReturnAndChunkTypesFrom<
Parameters<StreamingRPCMethod<AgentMethods<T>[K]>>[0]
>[1]
>
: never;
};

// we neet to use Method instead of RPCMethod here for retro-compatibility
type UntypedAgentStub = Record<string, Method>;
type UntypedAgentStreamingStub = StreamingAgentMethods<unknown>;

/**
* React hook for connecting to an Agent
Expand All @@ -211,6 +315,7 @@ export function useAgent<State = unknown>(
setState: (state: State) => void;
call: UntypedAgentMethodCall;
stub: UntypedAgentStub;
streamingStub: UntypedAgentStreamingStub;
};
export function useAgent<
AgentT extends {
Expand All @@ -225,14 +330,15 @@ export function useAgent<
setState: (state: State) => void;
call: AgentMethodCall<AgentT>;
stub: AgentStub<AgentT>;
streamingStub: AgentStreamingStub<AgentT>;
};
export function useAgent<State>(
options: UseAgentOptions<unknown>
): PartySocket & {
agent: string;
name: string;
setState: (state: State) => void;
call: UntypedAgentMethodCall | AgentMethodCall<unknown>;
call: UntypedAgentMethodCall;
stub: UntypedAgentStub;
} {
const agentNamespace = camelCaseToKebabCase(options.agent);
Expand Down Expand Up @@ -381,6 +487,7 @@ export function useAgent<State>(
setState: (state: State) => void;
call: UntypedAgentMethodCall;
stub: UntypedAgentStub;
streamingStub: UntypedAgentStreamingStub;
};
// Create the call method
const call = useCallback(
Expand Down Expand Up @@ -429,6 +536,55 @@ export function useAgent<State>(
}
}
);
// biome-ignore lint: suppressions/parse
agent.streamingStub = new Proxy<any>(
{},
{
get: (_target, method) => {
return async function* (...args: unknown[]) {
let resolve: (value: unknown) => void;
let reject: (reason: unknown) => void;
let promise = new Promise((res, rej) => {
resolve = res;
reject = rej;
});

// 4. State flags
let isDone = false;

// 5. Callback implementation
const streamOptions: StreamOptions = {
onChunk: (chunk: unknown) => {
resolve(chunk);
promise = new Promise((res, rej) => {
resolve = res;
reject = rej;
});
},
onError: (error: unknown) => {
isDone = true;
reject(error);
},
onDone: (done: unknown) => {
isDone = true;
resolve(done);
}
};

call(method as string, args, streamOptions);

while (!isDone) {
const result = await promise;
if (isDone) {
return result;
} else {
yield result;
}
}
};
}
}
);

// warn if agent isn't in lowercase
if (agent.agent !== agent.agent.toLowerCase()) {
Expand Down
15 changes: 1 addition & 14 deletions packages/agents/src/serializable.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,8 @@ export type SerializableReturnValue =
| Promise<SerializableValue>
| Promise<void>;

type AllSerializableValues<A> = A extends [infer First, ...infer Rest]
export type AllSerializableValues<A> = A extends [infer First, ...infer Rest]
? First extends SerializableValue
? AllSerializableValues<Rest>
: false
: true; // no params means serializable by default

// biome-ignore lint: suspicious/noExplicitAny
export type Method = (...args: any[]) => any;

export type RPCMethod<T = Method> = T extends Method
? T extends (...arg: infer A) => infer R
? AllSerializableValues<A> extends true
? R extends SerializableReturnValue
? T
: never
: never
: never
: never;
Loading
Loading