Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.DS_Store
99 changes: 99 additions & 0 deletions cmd/streamer/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package main

import (
"math/rand"
"net/http"
"strings"
"time"

"github.com/ivanvanderbyl/smith/pkg/streaming"
"github.com/rs/cors"
)

var exampleResponse = `You are a parliamentary advisor in Australia whose role is to help the general public understand the discussions and debates that have taken place in parliament.
Your objective is to provide clear, concise, and impartial summaries of parliamentary proceedings to facilitate increased civic engagement and informed participation in the electoral process.
When responding to inquiries, aim to present the key points, arguments, and decisions made by parliamentarians in an easily understandable manner, while maintaining a neutral stance and avoiding any personal bias or opinion.
Your explanations should be accessible to a broad audience, regardless of their prior knowledge of the political system or the specific issues being addressed.
If no key points are present, do not return any text, and only refer to the original text and not your own interpretation.
Always include the Name of the speaker and their Party in the response if quoted.

Given the following information, answer the question.
Every response should be in Australian English. Dates and Numeric values are always in metric format.
Include key points and arguments made by parliamentarians. Ensure the response is clear, concise, and impartial.
Always refer to Members of Parliament by their full name and party affiliation in the format: "**First name Last name** _(Party, Electorate)_".

Your audience has a post graduate level of education and is interested in understanding the key points of the parliamentary proceedings.

Context information is below.`

type Annotation struct {
Type string `json:"type"`
Timestamp int64 `json:"timestamp"`
Data Event `json:"data"`
}

type Event struct {
Type string `json:"type"`
}

// :[{"type":"agent","timestamp":1719466594099,"data":{"type":"chat-start"}}

// streamChat is a simple HTTP handler that streams a response to the client.
// It is used to demonstrate how to stream responses to the client in a format that
// is compatible with the Vercel @ai-sdk.
func streamChat(w http.ResponseWriter, req *http.Request) {
// Check if the client supports flushing.
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
return
}

// Use the chunked transfer encoding to stream the response to the client.
w.Header().Set("Transfer-Encoding", "chunked")
w.WriteHeader(http.StatusOK)

// Stream an empty response to the client to signal the start of the chat.
b, err := streaming.Marshal(streaming.TextPart{Value: string("")})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Write(b)
flusher.Flush()

b, err = streaming.Marshal(streaming.AnnotationPart[Annotation]{
Value: []Annotation{
{Type: "agent", Timestamp: time.Now().Unix(), Data: Event{Type: "chat-start"}},
},
})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Write(b)
flusher.Flush()

time.Sleep(1 * time.Second)

for _, str := range strings.Split(exampleResponse, " ") {
str = str + " "
b, err := streaming.Marshal(streaming.TextPart{Value: string(str)})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Write(b)
flusher.Flush()

randomDuration := rand.Intn(250) // 101 to include 100
time.Sleep(time.Duration(randomDuration) * time.Millisecond)
}
}

func main() {
mux := http.NewServeMux()
mux.HandleFunc("/chat", streamChat)
handler := cors.Default().Handler(mux)
http.ListenAndServe(":8090", handler)
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.20

require (
github.com/invopop/jsonschema v0.12.0
github.com/rs/cors v1.11.0
github.com/stretchr/testify v1.8.2
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/cors v1.11.0 h1:0B9GE/r9Bc2UxRMMtymBkHTenPkHDv0CW4Y98GBY+po=
github.com/rs/cors v1.11.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/sashabaranov/go-openai v1.26.0 h1:upM565hxdqvCxNzuAcEBZ1XsfGehH0/9kgk9rFVpDxQ=
github.com/sashabaranov/go-openai v1.26.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
Expand Down
75 changes: 75 additions & 0 deletions pkg/agent/agent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package agent

import (
"github.com/ivanvanderbyl/smith/pkg/function"
"github.com/sashabaranov/go-openai"
)

const MAX_TOOL_CALLS = 10

type (
Agent struct {
Tasks []Task
SystemPrompt string
llm *openai.Client
Tools function.AnyFunctions
}

AgentOption func(*Agent)

Task struct {
History []ChatMessage
ToolCallCount int
}

Step struct {
ID string
PreviousStep *Step
NextSteps []*Step
}

ChatMessage struct {
Content string
Role Role
}

Role string
)

const (
RoleUser Role = "user"
RoleAssistant Role = "assistant"
RoleSystem Role = "system"
RoleMemory Role = "memory"
)

// ShouldContinue returns true if the task should continue to run.
func (t *Task) ShouldContinue() bool {
return t.ToolCallCount < MAX_TOOL_CALLS
}

func WithSystemPrompt(prompt string) AgentOption {
return func(a *Agent) {
a.SystemPrompt = prompt
}
}

func (a *Agent) CreateTask() *Task {
initialMessages := []ChatMessage{
{
Content: a.SystemPrompt,
Role: RoleSystem,
},

// Include previous history
}

return &Task{
History: initialMessages,
ToolCallCount: 0,
}
}

func (s *Step) IsLast() bool {
return len(s.NextSteps) == 0
}
13 changes: 13 additions & 0 deletions pkg/llm/llm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package llm

type (
LLM interface {
CreateChatCompletion(input ChatCompletionParams) (ChatResponse, error)
// CreateCompletionStream(input CompletionRequest) (CompletionStream, error)
}

ChatCompletionParams struct {
}

ChatResponse struct{}
)
122 changes: 122 additions & 0 deletions pkg/streaming/formatter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// streaming implements a stream encoder that is compatible with the Vercel @ai-sdk.
package streaming

import (
"bytes"
"encoding/json"
"fmt"
"strings"
)

type (
StreamPart interface {
isStreamPart()
Code() Code
fmt.Stringer
}

Code int8

TextPart struct {
Value string
}

FunctionCallPart[T any] struct {
Name string `json:"name"`
Args T `json:"arguments"`
}

AnnotationPart[T any] struct {
Value []T
}
)

func (TextPart) isStreamPart() {}

func (TextPart) Code() Code {
return TextPartCode
}

func (t TextPart) String() string {
b, err := json.Marshal(t.Value)
if err != nil {
return ""
}

return string(b)
}

func (FunctionCallPart[T]) isStreamPart() {}

func (FunctionCallPart[T]) Code() Code {
return FunctionCallPartCode
}

type functionCall struct {
Name string `json:"name"`
Args string `json:"arguments"`
}

type functionCallWrapper[T any] struct {
Fn functionCall `json:"function_call"`
}

func (f FunctionCallPart[T]) String() string {
v := functionCallWrapper[T]{
Fn: functionCall{
Name: f.Name,
},
}

b, err := json.Marshal(f.Args)
if err != nil {
return ""
}
v.Fn.Args = string(b)

buf := new(bytes.Buffer)
_ = json.NewEncoder(buf).Encode(v)
return buf.String()
}

func (AnnotationPart[T]) isStreamPart() {}

func (AnnotationPart[T]) Code() Code {
return MessageAnnotationsPartCode
}

func (a AnnotationPart[T]) String() string {
buf := new(bytes.Buffer)
_ = json.NewEncoder(buf).Encode(a.Value)
return buf.String()
}

const (
TextPartCode Code = iota
FunctionCallPartCode
DataPartCode
ErrorPartCode
AssistantMessagePartCode
AssistantControlDataPartCode
DataMessagePartCode
ToolCallsPartCode
MessageAnnotationsPartCode
ToolCallPartCode
ToolResultPartCode
)

// Marshal formats a stream part into a string
func Marshal(part StreamPart) ([]byte, error) {
out := bytes.NewBuffer(nil)

value := part.String()
if strings.HasSuffix(value, "\n") {
value = strings.TrimSuffix(value, "\n")
}

_, err := fmt.Fprintf(out, "%d:%s\n", part.Code(), value)
if err != nil {
return nil, err
}
return out.Bytes(), nil
}
65 changes: 65 additions & 0 deletions pkg/streaming/formatter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package streaming_test

import (
"bytes"
"encoding/json"
"strings"
"testing"

"github.com/ivanvanderbyl/smith/pkg/streaming"
"github.com/stretchr/testify/assert"
)

type WeatherArgs struct {
Location string `json:"location"`
Units string `json:"units"`
}

type AgentEvent struct {
Type string `json:"type"`
Data json.RawMessage `json:"data"`
}

type Anot struct {
Test string `json:"test"`
}

func TestEncoding(t *testing.T) {
a := assert.New(t)

tests := []struct {
name string
part streaming.StreamPart
expected string
}{
{
name: "TextPart",
part: streaming.TextPart{Value: "Hello, World!"},
expected: "0:\"Hello, World!\"\n",
},
{
name: "FunctionCallPart",
part: streaming.FunctionCallPart[string]{Name: "GET_WEATHER", Args: "Sydney, Australia"},
expected: `1:"{\"function_call\":{\"name\":\"GET_WEATHER\",\"arguments\":\"{\\\"location\\\":\\\"Sydney, Australia\\\",\\\"units\\\":\\\"celsius\\\"}\"}}`,
},
{
name: "FunctionCallPart with object",
part: streaming.FunctionCallPart[WeatherArgs]{Name: "GET_WEATHER", Args: WeatherArgs{Location: "Sydney, Australia", Units: "celsius"}},
expected: `1:{"function_call":{"name":"GET_WEATHER","arguments":"{\"location\":\"Sydney, Australia\",\"units\":\"celsius\"}"}}`,
},
{
name: "AnnotationPart",
part: streaming.AnnotationPart[Anot]{Value: []Anot{{Test: "value"}}},
expected: `8:[{"test":"value"}]`,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
encoded, err := streaming.Marshal(tt.part)
a.NoError(err)
a.Equal(strings.TrimSpace(tt.expected), strings.TrimSpace(string(encoded)))
a.True(bytes.HasSuffix(encoded, []byte("\n")))
})
}
}