diff --git a/package.json b/package.json index b5566a7..2409e54 100644 --- a/package.json +++ b/package.json @@ -16,12 +16,13 @@ "@types/jest": "^29.5.1", "@types/node": "^20.2.4", "jest": "^29.5.0", - "openai": "^3.2.1", + "openai": "^3.3.0", "ts-jest": "^29.1.0", "ts-toolbelt": "^9.6.0", "typescript": "^5.0.4" }, "dependencies": { + "arktype": "1.0.14-alpha", "zod": "^3.21.4" } } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 13a3f5c..9869661 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -3,21 +3,23 @@ lockfileVersion: 5.4 specifiers: '@types/jest': ^29.5.1 '@types/node': ^20.2.4 + arktype: 1.0.14-alpha jest: ^29.5.0 - openai: ^3.2.1 + openai: ^3.3.0 ts-jest: ^29.1.0 ts-toolbelt: ^9.6.0 typescript: ^5.0.4 zod: ^3.21.4 dependencies: + arktype: 1.0.14-alpha zod: 3.21.4 devDependencies: '@types/jest': 29.5.1 '@types/node': 20.2.4 jest: 29.5.0_@types+node@20.2.4 - openai: 3.2.1 + openai: 3.3.0 ts-jest: 29.1.0_tobmchb5uviuq5lwsinkw5fvje ts-toolbelt: 9.6.0 typescript: 5.0.4 @@ -769,6 +771,11 @@ packages: sprintf-js: 1.0.3 dev: true + /arktype/1.0.14-alpha: + resolution: {integrity: sha512-theD5K4QrYCWMtQ52Masj169IgtMJ8Argld/MBS4lotEwR3b+GzfBvsqVJ1OIKhJDTdES02FLQTjcfRe00//mA==} + requiresBuild: true + dev: false + /asynckit/0.4.0: resolution: {integrity: sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==} dev: true @@ -1922,8 +1929,8 @@ packages: mimic-fn: 2.1.0 dev: true - /openai/3.2.1: - resolution: {integrity: sha512-762C9BNlJPbjjlWZi4WYK9iM2tAVAv0uUp1UmI34vb0CN5T2mjB/qM6RYBmNKMh/dN9fC+bxqPwWJZUTWW052A==} + /openai/3.3.0: + resolution: {integrity: sha512-uqxI/Au+aPRnsaQRe8CojU0eCR7I0mBiKjD3sNMzY6DaC1ZVrc85u98mtJW6voDug8fgGN+DIZmTDxTthxb7dQ==} dependencies: axios: 0.26.1 form-data: 4.0.0 diff --git a/src/Chat.ts b/src/Chat.ts index 8c8104a..2ebf19c 100644 --- a/src/Chat.ts +++ b/src/Chat.ts @@ -15,7 +15,9 @@ export class Chat< ) {} toArray() { - return (this.messages as any[]).map((m: ChatCompletionRequestMessage) => ({ + return (this.messages as any[]) + .filter((m) => m.content !== undefined) + .map((m) => ({ role: m.role, content: new PromptBuilder(m.content) .addInputValidation>() diff --git a/src/ChatBuilder.ts b/src/ChatBuilder.ts index bd1e2d5..ad5a12f 100644 --- a/src/ChatBuilder.ts +++ b/src/ChatBuilder.ts @@ -2,7 +2,7 @@ import { z } from "zod"; import { F } from "ts-toolbelt"; import { ChatCompletionRequestMessage } from "openai"; import { Chat } from "./Chat"; -import { user, assistant, system } from "./ChatHelpers"; +import { User, Assistant, System, Function } from "./ChatHelpers"; import { ExtractArgs, ExtractChatArgs, TypeToZodShape, ReplaceChatArgs } from "./types"; export class ChatBuilder< @@ -19,31 +19,45 @@ export class ChatBuilder< return new ChatBuilder(this.messages) as any; } - user( + User( str: TUserText ): ChatBuilder< [...TMessages, { role: "user"; content: TUserText }], F.Narrow & ExtractArgs > { - return new ChatBuilder([...this.messages, user(str)]) as any; + return new ChatBuilder([...this.messages, User(str)]) as any; } - system( + System( str: TSystemText ): ChatBuilder< [...TMessages, { role: "system"; content: TSystemText }], F.Narrow & ExtractArgs > { - return new ChatBuilder([...this.messages, system(str)]) as any; + return new ChatBuilder([...this.messages, System(str)]) as any; } - assistant( + Assistant( str: TAssistantText ): ChatBuilder< [...TMessages, { role: "assistant"; content: TAssistantText }], F.Narrow & ExtractArgs > { - return new ChatBuilder([...this.messages, assistant(str)]) as any; + return new ChatBuilder([...this.messages, Assistant(str)]) as any; + } + + // Backwards compadibility + user = this.User; + system = this.System; + assistant = this.Assistant; + + Function( + str: TAssistantText + ): ChatBuilder< + [...TMessages, { role: "function"; content: TAssistantText }], + F.Narrow & ExtractArgs + > { + return new ChatBuilder([...this.messages, Function(str)]) as any; } addZodInputValidation( diff --git a/src/ChatHelpers.ts b/src/ChatHelpers.ts index fb72c0f..1085f2b 100644 --- a/src/ChatHelpers.ts +++ b/src/ChatHelpers.ts @@ -1,7 +1,6 @@ // ChatMessage creation helpers -// Ideally these would Dedent their content, but ts is checker is way too slow // https://tinyurl.com/message-creators-literal-types -export function system( +export function System( literals: TemplateStringsArray | T, ...placeholders: unknown[] ) { @@ -10,7 +9,7 @@ export function system( content: dedent(literals, ...placeholders), }; } -export function user( +export function User( literals: TemplateStringsArray | T, ...placeholders: unknown[] ) { @@ -19,7 +18,7 @@ export function user( content: dedent(literals, ...placeholders), }; } -export function assistant( +export function Assistant( literals: TemplateStringsArray | T, ...placeholders: unknown[] ) { @@ -28,6 +27,20 @@ export function assistant( content: dedent(literals, ...placeholders), }; } +export function Function( + literals: TemplateStringsArray | T, + ...placeholders: unknown[] +) { + return { + role: "function" as const, + content: dedent(literals, ...placeholders), + }; +} + +// backwards compadibility +export const system = System; +export const user = User; +export const assistant = Assistant; export function dedent( templ: TemplateStringsArray | T, diff --git a/src/PromptBuilder.ts b/src/PromptBuilder.ts index a6c83c5..0f93e62 100644 --- a/src/PromptBuilder.ts +++ b/src/PromptBuilder.ts @@ -1,4 +1,5 @@ import { z } from "zod"; +import { Type } from "arktype"; import { F } from "ts-toolbelt"; import { Prompt } from "./Prompt"; import { ExtractArgs, ReplaceArgs, TypeToZodShape } from "./types"; @@ -18,23 +19,13 @@ export class PromptBuilder< addZodInputValidation( shape: TypeToZodShape ) { - const zodValidator = z.object(shape as any); - return new (class extends PromptBuilder { - validate(args: Record): args is TShape { - return zodValidator.safeParse(args).success; - } - - get type() { - return this.template as ReplaceArgs; - } + return new ZodPromptBuilder(this.template, shape); + } - build( - args: F.Narrow - ) { - zodValidator.parse(args); - return new Prompt(this.template, args).toString(); - } - })(this.template); + addArkTypeInputValidation( + shape: Type + ) { + return new ArkTypePromptBuilder(this.template, shape); } validate(args: Record): args is TExpectedInput { @@ -52,3 +43,65 @@ export class PromptBuilder< return new Prompt(this.template, args).toString(); } } + +class ZodPromptBuilder< + TPromptTemplate extends string, + TExpectedInput extends ExtractArgs +> extends PromptBuilder { + constructor( + public template: TPromptTemplate, + public shape: TypeToZodShape + ) { + super(template); + } + validate(args: Record): args is TExpectedInput { + const zodValidator = z.object(this.shape as any); + return zodValidator.safeParse(args).success; + } + + get type() { + return this.template as ReplaceArgs; + } + + build( + args: F.Narrow + ) { + const zodValidator = z.object(this.shape as any); + zodValidator.parse(args); + return new Prompt(this.template, args).toString(); + } +} + +class ArkTypePromptBuilder< + TPromptTemplate extends string, + TExpectedInput extends ExtractArgs +> extends PromptBuilder { + constructor( + public template: TPromptTemplate, + public shape: Type + ) { + super(template); + } + validate(args: Record): args is TExpectedInput { + try { + this.shape(args); + return true; + } catch (e) { + return false; + } + } + + get type() { + return this.template as ReplaceArgs; + } + + build( + args: F.Narrow + ) { + const { problems } = this.shape(args); + if (problems?.summary) { + throw new Error(problems.summary); + } + return new Prompt(this.template, args).toString(); + } +} diff --git a/src/__tests__/Chat.test.ts b/src/__tests__/Chat.test.ts index dfb6afc..59f70a6 100644 --- a/src/__tests__/Chat.test.ts +++ b/src/__tests__/Chat.test.ts @@ -1,6 +1,6 @@ import { strict as assert } from "node:assert"; import { Chat } from "../Chat"; -import { system, user, assistant } from "../ChatHelpers"; +import { System, User, Assistant } from "../ChatHelpers"; import { Equal, Expect } from "./types.test"; describe("Chat", () => { @@ -14,7 +14,7 @@ describe("Chat", () => { const chat = new Chat( [ // ^? - user("Tell me a {{jokeType}} joke"), + User("Tell me a {{jokeType}} joke"), ], // @ts-expect-error {} @@ -28,13 +28,13 @@ describe("Chat", () => { const chat = new Chat( [ // ^? - user(`Tell me a {{jokeType}} joke`), + User(`Tell me a {{jokeType}} joke`), ], { jokeType: "funny" as const, } ).toArray(); - const usrMsg = user("Tell me a funny joke"); + const usrMsg = User("Tell me a funny joke"); // ^? type test = Expect>; assert.deepEqual(chat, [usrMsg]); @@ -44,9 +44,9 @@ describe("Chat", () => { const chat = new Chat( [ // ^? - user(`Tell me a {{jokeType1}} joke`), - assistant(`{{var2}} joke?`), - system(`joke? {{var3}}`), + User(`Tell me a {{jokeType1}} joke`), + Assistant(`{{var2}} joke?`), + System(`joke? {{var3}}`), ], { jokeType1: "funny", @@ -54,9 +54,9 @@ describe("Chat", () => { var3: "bar", } as const ).toArray(); - const usrMsg = user("Tell me a funny joke"); - const astMsg = assistant("foo joke?"); - const sysMsg = system("joke? bar"); + const usrMsg = User("Tell me a funny joke"); + const astMsg = Assistant("foo joke?"); + const sysMsg = System("joke? bar"); type test = Expect< Equal >; @@ -67,15 +67,15 @@ describe("Chat", () => { const chat = new Chat( [ // ^? - user(`Tell me a joke`), - assistant(`joke?`), - system(`joke?`), + User(`Tell me a joke`), + Assistant(`joke?`), + System(`joke?`), ], {} ).toArray(); - const usrMsg = user("Tell me a joke"); - const astMsg = assistant("joke?"); - const sysMsg = system("joke?"); + const usrMsg = User("Tell me a joke"); + const astMsg = Assistant("joke?"); + const sysMsg = System("joke?"); type test = Expect< Equal >; diff --git a/src/__tests__/PromptBuilder.test.ts b/src/__tests__/PromptBuilder.test.ts index d837194..bdcf474 100644 --- a/src/__tests__/PromptBuilder.test.ts +++ b/src/__tests__/PromptBuilder.test.ts @@ -1,5 +1,6 @@ import { strict as assert } from "node:assert"; import { z, ZodError } from "zod"; +import { type } from 'arktype'; import { PromptBuilder } from "../PromptBuilder"; import { Equal, Expect } from "./types.test"; @@ -282,7 +283,9 @@ describe("PromptBuilder with input validation using Zod", () => { ); type BasicType = typeof promptBuilder.type; // ^? - type BasicTest = Expect>; + type BasicTest = Expect< + Equal + >; const tsValidatedPromptBuilder = promptBuilder.addInputValidation<{ jokeType: "funny" | "silly"; @@ -305,34 +308,37 @@ describe("PromptBuilder with input validation using Zod", () => { }); test("Can write a function that accepts the type of a PromptBuilder then accepts any output from that builder", () => { - const promptBuilder = new PromptBuilder("Tell me a {{jokeType}} joke.").addInputValidation<{ - jokeType: "funny" | "silly" + const promptBuilder = new PromptBuilder( + "Tell me a {{jokeType}} joke." + ).addInputValidation<{ + jokeType: "funny" | "silly"; }>(); function exampleFunction(input: typeof promptBuilder.type) {} exampleFunction(promptBuilder.build({ jokeType: "funny" })); - exampleFunction("Tell me a funny joke.") + exampleFunction("Tell me a funny joke."); exampleFunction(promptBuilder.build({ jokeType: "silly" })); - exampleFunction("Tell me a silly joke.") + exampleFunction("Tell me a silly joke."); // @ts-expect-error exampleFunction(promptBuilder.build({ jokeType: "bad" })); // @ts-expect-error - exampleFunction("Tell me a bad joke.") - }) - + exampleFunction("Tell me a bad joke."); + }); test("Can write a function that accepts the type of a PromptBuilder then accepts any output from that builder", () => { - const promptBuilder = new PromptBuilder("Tell me a {{jokeType}} joke.").addZodInputValidation({ + const promptBuilder = new PromptBuilder( + "Tell me a {{jokeType}} joke." + ).addZodInputValidation({ jokeType: z.union([z.literal("funny"), z.literal("silly")]), }); function exampleFunction(input: typeof promptBuilder.type) {} exampleFunction(promptBuilder.build({ jokeType: "funny" })); - exampleFunction("Tell me a funny joke.") + exampleFunction("Tell me a funny joke."); exampleFunction(promptBuilder.build({ jokeType: "silly" })); - exampleFunction("Tell me a silly joke.") + exampleFunction("Tell me a silly joke."); // @ts-expect-error - exampleFunction("Tell me a bad joke.") + exampleFunction("Tell me a bad joke."); assert.throws( () => { // @ts-expect-error @@ -350,5 +356,32 @@ describe("PromptBuilder with input validation using Zod", () => { return true; } ); - }) + }); + + test("Can add arktypes validation", () => { + const promptBuilder = new PromptBuilder( + "Tell me a {{jokeType}} joke." + ).addArkTypeInputValidation( + type({ + jokeType: "'funny' | 'silly'", + }) + ); + function exampleFunction(input: typeof promptBuilder.type) {} + + exampleFunction(promptBuilder.build({ jokeType: "funny" })); + exampleFunction("Tell me a funny joke."); + exampleFunction(promptBuilder.build({ jokeType: "silly" })); + exampleFunction("Tell me a silly joke."); + // @ts-expect-error + exampleFunction("Tell me a bad joke."); + + assert.throws( + () => { + + // @ts-expect-error + exampleFunction(promptBuilder.build({ jokeType: "bad" })); + } + ); + + }); }); diff --git a/src/index.ts b/src/index.ts index 907f022..33bc1a3 100644 --- a/src/index.ts +++ b/src/index.ts @@ -3,5 +3,6 @@ import { PromptBuilder } from "./PromptBuilder"; import { Chat } from "./Chat"; import { system, user, assistant } from "./ChatHelpers"; import { ChatBuilder } from "./ChatBuilder"; +import * as unstable from './unstable' -export { Prompt, PromptBuilder, Chat, ChatBuilder, system, user, assistant }; +export { Prompt, PromptBuilder, Chat, ChatBuilder, system, user, assistant, unstable }; diff --git a/src/unstable/AgentBuilder.ts b/src/unstable/AgentBuilder.ts new file mode 100644 index 0000000..3c43947 --- /dev/null +++ b/src/unstable/AgentBuilder.ts @@ -0,0 +1,92 @@ +import { z } from "zod"; +import { F } from "ts-toolbelt"; +import { ChatCompletionRequestMessage } from "openai"; +import { Chat } from "../Chat"; +import { user, assistant, system } from "../ChatHelpers"; +import { ExtractArgs, ExtractChatArgs, TypeToZodShape, ReplaceChatArgs } from "../types"; + +export class AgentBuilder< + TMessages extends + | [] + | [...ChatCompletionRequestMessage[], ChatCompletionRequestMessage], + TExpectedInput extends ExtractChatArgs +> { + constructor(public messages: TMessages) {} + + addInputValidation< + TSTypeValidator extends ExtractChatArgs + >(): AgentBuilder { + return new AgentBuilder(this.messages) as any; + } + + addOutputValidation< + TSTypeValidator extends ExtractChatArgs + >(): AgentBuilder { + + user( + str: TUserText + ): AgentBuilder< + [...TMessages, { role: "user"; content: TUserText }], + F.Narrow & ExtractArgs + > { + return new AgentBuilder([...this.messages, user(str)]) as any; + } + + system( + str: TSystemText + ): AgentBuilder< + [...TMessages, { role: "system"; content: TSystemText }], + F.Narrow & ExtractArgs + > { + return new AgentBuilder([...this.messages, system(str)]) as any; + } + + assistant( + str: TAssistantText + ): AgentBuilder< + [...TMessages, { role: "assistant"; content: TAssistantText }], + F.Narrow & ExtractArgs + > { + return new AgentBuilder([...this.messages, assistant(str)]) as any; + } + + addZodInputValidation( + shape: TypeToZodShape + ) { + const zodValidator = z.object(shape as any); + return new (class extends AgentBuilder { + validate(args: Record): args is TShape { + return zodValidator.safeParse(args).success; + } + + get type() { + return this.messages as ReplaceChatArgs; + } + + build( + args: F.Narrow + ) { + zodValidator.parse(args); + return super.build(args); + } + })(this.messages); + } + + validate(args: Record): args is TExpectedInput { + // Validate can only be called on a PromptBuilder with zod input validation + return false; + } + + get type() { + return this.messages as ReplaceChatArgs; + } + + build( + args: F.Narrow + ) { + return new Chat( + this.messages as any, + args + ).toArray(); + } +} diff --git a/src/unstable/ChatModelBuilder.ts b/src/unstable/ChatModelBuilder.ts new file mode 100644 index 0000000..e69de29 diff --git a/src/unstable/PromptModelBuilder.ts b/src/unstable/PromptModelBuilder.ts new file mode 100644 index 0000000..e69de29 diff --git a/src/unstable/index.ts b/src/unstable/index.ts new file mode 100644 index 0000000..4a1ec83 --- /dev/null +++ b/src/unstable/index.ts @@ -0,0 +1 @@ +export { AgentBuilder } from './AgentBuilder'; \ No newline at end of file