Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"typescript": "^5.2.2"
},
"dependencies": {
"zod": "^3.22.2"
"zod": "^3.22.2",
"zod-to-json-schema": "^3.21.4"
}
}
11 changes: 11 additions & 0 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions src/Chat.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import OpenAI from "openai";
import { PromptBuilder } from "./PromptBuilder";
import { ExtractArgs, ExtractChatArgs, ReplaceChatArgs } from "./types";
import { ToolBuilder } from "./ToolBuilder";
import { Tool, ToolType } from './Tool'

export class Chat<
const ToolNames extends string,
TMessages extends
| []
| [
Expand All @@ -14,8 +17,15 @@ export class Chat<
constructor(
public messages: TMessages,
public args: TSuppliedInputArgs,
public tools = {} as Record<ToolNames, Tool<ToolNames, ToolType, any, any>>,
public mustUseTool: boolean = false
) {}

toJSONSchema() {
const tools = Object.values(this.tools) as Tool<ToolNames, ToolType, any, any>[];
return tools.reduce((acc, t) => ({ ...acc, ...t.toJSONSchema()}), {})
}

toArray() {
return this.messages.map((m) => ({
role: m.role,
Expand Down
3 changes: 2 additions & 1 deletion src/ChatBuilder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
TypeToZodShape,
ReplaceChatArgs,
} from "./types";
import { ToolBuilder } from "./ToolBuilder";

export class ChatBuilder<
TMessages extends
Expand Down Expand Up @@ -90,7 +91,7 @@ export class ChatBuilder<
build<const TSuppliedInputArgs extends TExpectedInput>(
args: TSuppliedInputArgs,
) {
return new Chat<TMessages, TSuppliedInputArgs>(
return new Chat<"", TMessages, TSuppliedInputArgs>(
this.messages as any,
args,
).toArray();
Expand Down
50 changes: 50 additions & 0 deletions src/Tool.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import { z, ZodType } from 'zod'
import { zodToJsonSchema } from 'zod-to-json-schema'

export const ToolType = z.enum(["query", "mutation"])
export type ToolType = z.infer<typeof ToolType>


export class Tool<
TName extends string,
TType extends "query" | "mutation",
const TExpectedInput extends { [key: string]: string },
TExpectedOutput
> {
constructor(
public name: TName,
public description: string,
public type: TType,
public use: (input: TExpectedInput) => TExpectedOutput,
public input?: ZodType<TExpectedInput>,
public output?: ZodType<TExpectedOutput>,
) {}

toJSONSchema() {
if (!this.input) {
throw new Error('Tool has no input schema. Please use ToolBuilder.addZodInputValidation to set.')
}
const schema = zodToJsonSchema(this.input) as any;
delete schema.$schema;
if (!schema.additionalProperties) delete schema.additionalProperties;
return {
name: this.name,
description: this.description,
parameters: schema,
};
}

validateInput(args: unknown): args is TExpectedInput {
if (!this.input) {
throw new Error('Tool has no input schema. Please use ToolBuilder.addZodInputValidation to set.')
}
return this.input.safeParse(args).success
}

validateOutput(args: unknown): args is TExpectedOutput {
if (!this.output) {
throw new Error('Tool has no output schema. Please use ToolBuilder.addZodInputValidation to set.')
}
return this.output.safeParse(args).success
}
}
94 changes: 94 additions & 0 deletions src/ToolBuilder.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import { z, AnyZodObject, infer as _infer, ZodType } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";
import { TypeToZodShape } from "./types";
import { Tool } from "./Tool";

export class ToolBuilder<
TName extends string,
TType extends "query" | "mutation",
const TExpectedInput extends Record<string, any>,
TExpectedOutput
> {
constructor(
public name: TName,
public description: string = "",
public type: TType = "query" as TType,
public implementation?: (input: TExpectedInput) => TExpectedOutput
) {}

addZodInputValidation<TShape extends TExpectedInput>(
shape: TypeToZodShape<TShape>
): ToolBuilder<TName, TType, TShape, TExpectedOutput> {
const zodValidator = z.object(shape as any);
return new (class extends ToolBuilder<
TName,
TType,
TShape,
TExpectedOutput
> {
validate(args: unknown): args is TShape {
return zodValidator.safeParse(args).success;
}

query(queryFunction: (input: TExpectedInput) => TShape) {
// zodValidator.parse(args);
return new Tool(this.name, "query", queryFunction);
}

mutation(mutationFunction: (input: TExpectedInput) => TShape) {
return new Tool(this.name, "mutation", mutationFunction);
}
})(this.name, this.description, this.type, this.implementation);
}

addZodOutputValidation<TShape extends TExpectedOutput>(shape: ZodType<TShape>) {
const zodValidator = z.object(shape as any);
return new (class extends ToolBuilder<
TName,
TType,
TExpectedInput,
TShape
> {
validateOutput(output: unknown): output is TShape {
return zodValidator.safeParse(output).success;
}

query(queryFunction: (input: TExpectedInput) => TShape) {
return new Tool(this.name, this.description, "query", queryFunction);
}

mutation(mutationFunction: (input: TExpectedInput) => TShape) {
return new Tool(this.name, this.description, "mutation", mutationFunction);
}
})(this.name, this.description, this.type, this.implementation as any);
}

query(queryFunction: (input: any) => any) {
return new Tool(this.name, this.description, "query", queryFunction);
}

mutation(mutationFunction: (input: any) => any) {
return new Tool(this.name, this.description, "mutation", mutationFunction);
}

toJSONSchema() {
// // const fns: any[] = [];
// // const { params, ...rest } = this.implementation[key];
const schema = zodToJsonSchema(
z.object({
name: z.string(),
})
);
delete schema.$schema;
// if (!schema.additionalProperties) delete schema.additionalProperties;
// // fns.push();
return {
name: this.name,
parameters: schema,
};
}

build<TShape extends TExpectedInput>(input: TShape) {
return new Tool(this.name, this.description, this.type, this.implementation!);
}
}
113 changes: 100 additions & 13 deletions src/__tests__/Chat.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ import { strict as assert } from "node:assert";
import { Chat } from "../Chat";
import { system, user, assistant } from "../ChatHelpers";
import { Equal, Expect } from "./types.test";
import { ToolBuilder } from "../ToolBuilder";
import { Tool } from '../Tool'
import { z } from "zod";

describe("Chat", () => {
it("should allow empty array", () => {
const chat = new Chat([], {}).toArray();
const chat = new Chat([], {}, {}).toArray();
type test = Expect<Equal<typeof chat, []>>;
assert.deepEqual(chat, []);
});
Expand All @@ -14,7 +17,7 @@ describe("Chat", () => {
const chat = new Chat(
[user("Tell me a {{jokeType}} joke")],
// @ts-expect-error
{},
{}
).toArray();
type test = Expect<
Equal<typeof chat, [{ role: "user"; content: `Tell me a ${any} joke` }]>
Expand All @@ -31,6 +34,9 @@ describe("Chat", () => {
assert.deepEqual(chat, [usrMsg]);
});

const usrMsg = user("Tell me a funny joke");
const astMsg = assistant("foo joke?");
const sysMsg = system("joke? bar");
it("should allow chat of all diffent types", () => {
const chat = new Chat(
[
Expand All @@ -42,28 +48,109 @@ describe("Chat", () => {
jokeType1: "funny",
var2: "foo",
var3: "bar",
},
}
).toArray();
const usrMsg = user("Tell me a funny joke");
const astMsg = assistant("foo joke?");
const sysMsg = system("joke? bar");
type test = Expect<
Equal<typeof chat, [typeof usrMsg, typeof astMsg, typeof sysMsg]>
>;
assert.deepEqual(chat, [usrMsg, astMsg, sysMsg]);
});

it("should allow chat of all diffent types with no args", () => {
const chat = new Chat(
[user(`Tell me a joke`), assistant(`joke?`), system(`joke?`)],
{},
).toArray();
const usrMsg = user("Tell me a joke");
const astMsg = assistant("joke?");
const sysMsg = system("joke?");
const chat = new Chat([usrMsg, astMsg, sysMsg], {}).toArray();
type test = Expect<
Equal<typeof chat, [typeof usrMsg, typeof astMsg, typeof sysMsg]>
>;
assert.deepEqual(chat, [usrMsg, astMsg, sysMsg]);
});

it("should allow me to pass in tools", () => {
const google = new ToolBuilder("google")
.addZodInputValidation({ query: z.string() })
.addZodOutputValidation(z.object({ results: z.array(z.string()) }))
.query(({ query }) => {
return {
results: ["foo", "bar"],
};
});
const wikipedia = new ToolBuilder("wikipedia")
.addZodInputValidation({ page: z.string() })
.addZodOutputValidation(z.object({ results: z.array(z.string()) }))
.query(({ page }) => {
return {
results: ["foo", "bar"],
};
});

const sendEmail = new ToolBuilder("sendEmail")
.addZodInputValidation({
to: z.string(),
subject: z.string(),
body: z.string(),
})
.addZodOutputValidation(z.object({ success: z.boolean() }))
.mutation(({ to, subject, body }) => {
return {
success: true,
};
});
const tools = {
google,
wikipedia,
sendEmail,
};
const chat = new Chat([usrMsg, astMsg, sysMsg], {}, tools);

type tests = [
Expect<
Equal<
typeof chat,
Chat<
keyof typeof tools,
[typeof usrMsg, typeof astMsg, typeof sysMsg],
{}
>
>
>,
Expect<
Equal<
typeof tools,
{
google: Tool<
"google",
"query",
{
query: string;
},
{
results: string[];
}
>;
wikipedia: Tool<
"wikipedia",
"query",
{
page: string;
},
{
results: string[];
}
>;
sendEmail: Tool<
"sendEmail",
"mutation",
{
to: string;
subject: string;
body: string;
},
{
success: boolean;
}
>;
}
>
>
];
});
});
Loading