diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e43b0f9
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+.DS_Store
diff --git a/cmd/streamer/main.go b/cmd/streamer/main.go
new file mode 100644
index 0000000..462effb
--- /dev/null
+++ b/cmd/streamer/main.go
@@ -0,0 +1,101 @@
+package main
+
+import (
+	"log"
+	"math/rand"
+	"net/http"
+	"strings"
+	"time"
+
+	"github.com/ivanvanderbyl/smith/pkg/streaming"
+	"github.com/rs/cors"
+)
+
+var exampleResponse = `You are a parliamentary advisor in Australia whose role is to help the general public understand the discussions and debates that have taken place in parliament.
+	Your objective is to provide clear, concise, and impartial summaries of parliamentary proceedings to facilitate increased civic engagement and informed participation in the electoral process.
+	When responding to inquiries, aim to present the key points, arguments, and decisions made by parliamentarians in an easily understandable manner, while maintaining a neutral stance and avoiding any personal bias or opinion.
+	Your explanations should be accessible to a broad audience, regardless of their prior knowledge of the political system or the specific issues being addressed.
+	If no key points are present, do not return any text, and only refer to the original text and not your own interpretation.
+	Always include the Name of the speaker and their Party in the response if quoted.
+
+	Given the following information, answer the question.
+	Every response should be in Australian English. Dates and Numeric values are always in metric format.
+	Include key points and arguments made by parliamentarians. Ensure the response is clear, concise, and impartial.
+	Always refer to Members of Parliament by their full name and party affiliation in the format: "**First name Last name** _(Party, Electorate)_".
+
+	Your audience has a post graduate level of education and is interested in understanding the key points of the parliamentary proceedings.
+
+	Context information is below.`
+
+// Annotation is the message-annotation payload streamed to the Vercel
+// @ai-sdk client. Timestamp is in milliseconds since the Unix epoch,
+// matching the wire example below.
+type Annotation struct {
+	Type      string `json:"type"`
+	Timestamp int64  `json:"timestamp"`
+	Data      Event  `json:"data"`
+}
+
+// Event is the inner event carried by an Annotation.
+type Event struct {
+	Type string `json:"type"`
+}
+
+// Example annotation frame on the wire:
+// :[{"type":"agent","timestamp":1719466594099,"data":{"type":"chat-start"}}
+
+// streamChat is a simple HTTP handler that streams a response to the client.
+// It is used to demonstrate how to stream responses to the client in a format that
+// is compatible with the Vercel @ai-sdk.
+func streamChat(w http.ResponseWriter, req *http.Request) {
+	// Check if the client supports flushing.
+	flusher, ok := w.(http.Flusher)
+	if !ok {
+		http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
+		return
+	}
+
+	// Use the chunked transfer encoding to stream the response to the client.
+	w.Header().Set("Transfer-Encoding", "chunked")
+	w.WriteHeader(http.StatusOK)
+
+	// Stream an empty response to the client to signal the start of the chat.
+	b, err := streaming.Marshal(streaming.TextPart{Value: ""})
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+	w.Write(b)
+	flusher.Flush()
+
+	// Announce chat start. UnixMilli matches the millisecond timestamps
+	// shown in the wire example above (Unix() would be off by a factor of 1000).
+	b, err = streaming.Marshal(streaming.AnnotationPart[Annotation]{
+		Value: []Annotation{
+			{Type: "agent", Timestamp: time.Now().UnixMilli(), Data: Event{Type: "chat-start"}},
+		},
+	})
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+	w.Write(b)
+	flusher.Flush()
+
+	time.Sleep(1 * time.Second)
+
+	// Stream the canned response word by word with a small random delay
+	// between words to simulate token-by-token generation.
+	for _, str := range strings.Split(exampleResponse, " ") {
+		str = str + " "
+		b, err := streaming.Marshal(streaming.TextPart{Value: str})
+		if err != nil {
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+		w.Write(b)
+		flusher.Flush()
+
+		randomDuration := rand.Intn(250) // random delay in [0, 250) milliseconds
+		time.Sleep(time.Duration(randomDuration) * time.Millisecond)
+	}
+}
+
+func main() {
+	mux := http.NewServeMux()
+	mux.HandleFunc("/chat", streamChat)
+
+	handler := cors.Default().Handler(mux)
+	log.Fatal(http.ListenAndServe(":8090", handler))
+}
diff --git a/go.mod b/go.mod
index f387b50..0b1c3a3 100644
--- a/go.mod
+++ b/go.mod
@@ -4,6 +4,7 @@ go 1.20
 
 require (
 	github.com/invopop/jsonschema v0.12.0
+	github.com/rs/cors v1.11.0
 	github.com/stretchr/testify v1.8.2
 )
 
diff --git a/go.sum b/go.sum
index 4e96623..0260c21 100644
--- a/go.sum
+++ b/go.sum
@@ -17,6 +17,8 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0
 github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/rs/cors v1.11.0 h1:0B9GE/r9Bc2UxRMMtymBkHTenPkHDv0CW4Y98GBY+po=
+github.com/rs/cors v1.11.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
 github.com/sashabaranov/go-openai v1.26.0 h1:upM565hxdqvCxNzuAcEBZ1XsfGehH0/9kgk9rFVpDxQ=
 github.com/sashabaranov/go-openai v1.26.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go
new file mode 100644
index 0000000..3dfec25
--- /dev/null
+++ b/pkg/agent/agent.go
@@ -0,0 +1,75 @@
+package agent
+
+import (
+	"github.com/ivanvanderbyl/smith/pkg/function"
+	"github.com/sashabaranov/go-openai"
+)
+
+const MAX_TOOL_CALLS = 10
+
+type (
+	Agent struct {
+		Tasks        []Task
+		SystemPrompt string
+		llm          *openai.Client
+		Tools        function.AnyFunctions
+	}
+
+	AgentOption func(*Agent)
+
+	Task struct {
+		History       []ChatMessage
+		ToolCallCount int
+	}
+
+	Step struct {
+		ID           string
+		PreviousStep *Step
+		NextSteps    []*Step
+	}
+
+	ChatMessage struct {
+		Content string
+		Role    Role
+	}
+
+	Role string
+)
+
+const (
+	RoleUser      Role = "user"
+	RoleAssistant Role = "assistant"
+	RoleSystem    Role = "system"
+	RoleMemory    Role = "memory"
+)
+
+// ShouldContinue returns true if the task should continue to run.
+func (t *Task) ShouldContinue() bool {
+	return t.ToolCallCount < MAX_TOOL_CALLS
+}
+
+func WithSystemPrompt(prompt string) AgentOption {
+	return func(a *Agent) {
+		a.SystemPrompt = prompt
+	}
+}
+
+func (a *Agent) CreateTask() *Task {
+	initialMessages := []ChatMessage{
+		{
+			Content: a.SystemPrompt,
+			Role:    RoleSystem,
+		},
+
+		// Include previous history
+	}
+
+	return &Task{
+		History:       initialMessages,
+		ToolCallCount: 0,
+	}
+}
+
+func (s *Step) IsLast() bool {
+	return len(s.NextSteps) == 0
+}
diff --git a/pkg/llm/llm.go b/pkg/llm/llm.go
new file mode 100644
index 0000000..f0e3930
--- /dev/null
+++ b/pkg/llm/llm.go
@@ -0,0 +1,13 @@
+package llm
+
+type (
+	LLM interface {
+		CreateChatCompletion(input ChatCompletionParams) (ChatResponse, error)
+		// CreateCompletionStream(input CompletionRequest) (CompletionStream, error)
+	}
+
+	ChatCompletionParams struct {
+	}
+
+	ChatResponse struct{}
+)
diff --git a/pkg/streaming/formatter.go b/pkg/streaming/formatter.go
new file mode 100644
index 0000000..162b1f3
--- /dev/null
+++ b/pkg/streaming/formatter.go
@@ -0,0 +1,122 @@
+// Package streaming implements a stream encoder that is compatible with the Vercel @ai-sdk.
+package streaming
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"strings"
+)
+
+type (
+	StreamPart interface {
+		isStreamPart()
+		Code() Code
+		fmt.Stringer
+	}
+
+	Code int8
+
+	TextPart struct {
+		Value string
+	}
+
+	FunctionCallPart[T any] struct {
+		Name string `json:"name"`
+		Args T      `json:"arguments"`
+	}
+
+	AnnotationPart[T any] struct {
+		Value []T
+	}
+)
+
+func (TextPart) isStreamPart() {}
+
+func (TextPart) Code() Code {
+	return TextPartCode
+}
+
+func (t TextPart) String() string {
+	b, err := json.Marshal(t.Value)
+	if err != nil {
+		return ""
+	}
+
+	return string(b)
+}
+
+func (FunctionCallPart[T]) isStreamPart() {}
+
+func (FunctionCallPart[T]) Code() Code {
+	return FunctionCallPartCode
+}
+
+type functionCall struct {
+	Name string `json:"name"`
+	Args string `json:"arguments"`
+}
+
+type functionCallWrapper[T any] struct {
+	Fn functionCall `json:"function_call"`
+}
+
+func (f FunctionCallPart[T]) String() string {
+	v := functionCallWrapper[T]{
+		Fn: functionCall{
+			Name: f.Name,
+		},
+	}
+
+	b, err := json.Marshal(f.Args)
+	if err != nil {
+		return ""
+	}
+	v.Fn.Args = string(b)
+
+	buf := new(bytes.Buffer)
+	_ = json.NewEncoder(buf).Encode(v)
+	return buf.String()
+}
+
+func (AnnotationPart[T]) isStreamPart() {}
+
+func (AnnotationPart[T]) Code() Code {
+	return MessageAnnotationsPartCode
+}
+
+func (a AnnotationPart[T]) String() string {
+	buf := new(bytes.Buffer)
+	_ = json.NewEncoder(buf).Encode(a.Value)
+	return buf.String()
+}
+
+const (
+	TextPartCode Code = iota
+	FunctionCallPartCode
+	DataPartCode
+	ErrorPartCode
+	AssistantMessagePartCode
+	AssistantControlDataPartCode
+	DataMessagePartCode
+	ToolCallsPartCode
+	MessageAnnotationsPartCode
+	ToolCallPartCode
+	ToolResultPartCode
+)
+
+// Marshal formats a stream part into a string
+func Marshal(part StreamPart) ([]byte, error) {
+	out := bytes.NewBuffer(nil)
+
+	value := part.String()
+	if strings.HasSuffix(value, "\n") {
+		value = strings.TrimSuffix(value, "\n")
+	}
+
+	_, err := fmt.Fprintf(out, "%d:%s\n", part.Code(), value)
+	if err != nil {
+		return nil, err
+	}
+	return out.Bytes(), nil
+}
diff --git a/pkg/streaming/formatter_test.go b/pkg/streaming/formatter_test.go
new file mode 100644
index 0000000..99633b7
--- /dev/null
+++ b/pkg/streaming/formatter_test.go
@@ -0,0 +1,65 @@
+package streaming_test
+
+import (
+	"bytes"
+	"encoding/json"
+	"strings"
+	"testing"
+
+	"github.com/ivanvanderbyl/smith/pkg/streaming"
+	"github.com/stretchr/testify/assert"
+)
+
+type WeatherArgs struct {
+	Location string `json:"location"`
+	Units    string `json:"units"`
+}
+
+type AgentEvent struct {
+	Type string          `json:"type"`
+	Data json.RawMessage `json:"data"`
+}
+
+type Anot struct {
+	Test string `json:"test"`
+}
+
+func TestEncoding(t *testing.T) {
+	a := assert.New(t)
+
+	tests := []struct {
+		name     string
+		part     streaming.StreamPart
+		expected string
+	}{
+		{
+			name:     "TextPart",
+			part:     streaming.TextPart{Value: "Hello, World!"},
+			expected: "0:\"Hello, World!\"\n",
+		},
+		{
+			name:     "FunctionCallPart",
+			part:     streaming.FunctionCallPart[string]{Name: "GET_WEATHER", Args: "Sydney, Australia"},
+			expected: `1:{"function_call":{"name":"GET_WEATHER","arguments":"\"Sydney, Australia\""}}`,
+		},
+		{
+			name:     "FunctionCallPart with object",
+			part:     streaming.FunctionCallPart[WeatherArgs]{Name: "GET_WEATHER", Args: WeatherArgs{Location: "Sydney, Australia", Units: "celsius"}},
+			expected: `1:{"function_call":{"name":"GET_WEATHER","arguments":"{\"location\":\"Sydney, Australia\",\"units\":\"celsius\"}"}}`,
+		},
+		{
+			name:     "AnnotationPart",
+			part:     streaming.AnnotationPart[Anot]{Value: []Anot{{Test: "value"}}},
+			expected: `8:[{"test":"value"}]`,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			encoded, err := streaming.Marshal(tt.part)
+			a.NoError(err)
+			a.Equal(strings.TrimSpace(tt.expected), strings.TrimSpace(string(encoded)))
+			a.True(bytes.HasSuffix(encoded, []byte("\n")))
+		})
+	}
+}