Skip to content

Commit 696daff

Browse files
authored
Add RunWithOptions method that supports returning file output as io.ReadCloser (#77)
* Add RunWithOptions method that supports returning file output as bytes * Return io.ReadCloser instead of []byte * Add test coverage for RunWithOptions * Return custom FileOutput type that implements io.ReadCloser and provides URL
1 parent 3c5fd6b commit 696daff

File tree

2 files changed

+199
-1
lines changed

2 files changed

+199
-1
lines changed

client_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,6 +1475,79 @@ func TestAutomaticallyRetryPostRequests(t *testing.T) {
14751475
assert.ErrorContains(t, err, http.StatusText(http.StatusInternalServerError))
14761476
}
14771477

1478+
func TestRunWithOptions(t *testing.T) {
1479+
var mockServer *httptest.Server
1480+
mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1481+
switch r.URL.Path {
1482+
case "/predictions":
1483+
assert.Equal(t, http.MethodPost, r.Method)
1484+
prediction := replicate.Prediction{
1485+
ID: "gtsllfynndufawqhdngldkdrkq",
1486+
Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
1487+
Status: replicate.Starting,
1488+
}
1489+
json.NewEncoder(w).Encode(prediction)
1490+
case "/predictions/gtsllfynndufawqhdngldkdrkq":
1491+
assert.Equal(t, http.MethodGet, r.Method)
1492+
prediction := replicate.Prediction{
1493+
ID: "gtsllfynndufawqhdngldkdrkq",
1494+
Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
1495+
Status: replicate.Succeeded,
1496+
Output: map[string]interface{}{
1497+
"image": mockServer.URL + "/output.png",
1498+
"text": "Hello, world!",
1499+
},
1500+
}
1501+
json.NewEncoder(w).Encode(prediction)
1502+
case "/output.png":
1503+
w.Header().Set("Content-Type", "image/png")
1504+
w.Write([]byte("mock image data"))
1505+
default:
1506+
t.Fatalf("Unexpected request to %s", r.URL.Path)
1507+
}
1508+
}))
1509+
defer mockServer.Close()
1510+
1511+
client, err := replicate.NewClient(
1512+
replicate.WithToken("test-token"),
1513+
replicate.WithBaseURL(mockServer.URL),
1514+
)
1515+
require.NoError(t, err)
1516+
1517+
ctx := context.Background()
1518+
input := replicate.PredictionInput{"prompt": "A test image"}
1519+
1520+
// Test with WithFileOutput option
1521+
output, err := client.RunWithOptions(ctx, "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input, nil, replicate.WithFileOutput())
1522+
1523+
require.NoError(t, err)
1524+
assert.NotNil(t, output)
1525+
1526+
// Check if the image output is transformed to io.ReadCloser
1527+
imageOutput, ok := output.(map[string]interface{})["image"].(io.ReadCloser)
1528+
require.True(t, ok, "Expected image output to be io.ReadCloser")
1529+
1530+
imageData, err := io.ReadAll(imageOutput)
1531+
require.NoError(t, err)
1532+
assert.Equal(t, []byte("mock image data"), imageData)
1533+
1534+
// Check if the text output remains unchanged
1535+
textOutput, ok := output.(map[string]interface{})["text"].(string)
1536+
require.True(t, ok, "Expected text output to be string")
1537+
assert.Equal(t, "Hello, world!", textOutput)
1538+
1539+
// Test without WithFileOutput option
1540+
outputWithoutFileOption, err := client.RunWithOptions(ctx, "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input, nil)
1541+
1542+
require.NoError(t, err)
1543+
assert.NotNil(t, outputWithoutFileOption)
1544+
1545+
// Check if the image output remains a URL string
1546+
imageOutputURL, ok := outputWithoutFileOption.(map[string]interface{})["image"].(string)
1547+
require.True(t, ok, "Expected image output to be string")
1548+
assert.Equal(t, mockServer.URL+"/output.png", imageOutputURL)
1549+
}
1550+
14781551
func TestStream(t *testing.T) {
14791552
tokens := []string{"Alpha", "Bravo", "Charlie", "Delta", "Echo"}
14801553

run.go

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,45 @@
11
package replicate
22

33
import (
4+
"bytes"
45
"context"
6+
"encoding/base64"
57
"errors"
8+
"fmt"
9+
"io"
10+
"net/http"
11+
"net/url"
12+
"strings"
613
)
714

8-
func (r *Client) Run(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (PredictionOutput, error) {
15+
// RunOption is a function that modifies RunOptions
16+
type RunOption func(*runOptions)
17+
18+
// runOptions represents options for running a model
19+
type runOptions struct {
20+
useFileOutput bool
21+
}
22+
23+
// FileOutput is a custom type that implements io.ReadCloser and includes a URL field
24+
type FileOutput struct {
25+
io.ReadCloser
26+
URL string
27+
}
28+
29+
// WithFileOutput sets the UseFileOutput option to true
30+
func WithFileOutput() RunOption {
31+
return func(o *runOptions) {
32+
o.useFileOutput = true
33+
}
34+
}
35+
36+
// RunWithOptions runs a model with specified options
37+
func (r *Client) RunWithOptions(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook, opts ...RunOption) (PredictionOutput, error) {
38+
options := runOptions{}
39+
for _, opt := range opts {
40+
opt(&options)
41+
}
42+
943
id, err := ParseIdentifier(identifier)
1044
if err != nil {
1145
return nil, err
@@ -29,5 +63,96 @@ func (r *Client) Run(ctx context.Context, identifier string, input PredictionInp
2963
return nil, &ModelError{Prediction: prediction}
3064
}
3165

66+
if options.useFileOutput {
67+
return transformOutput(ctx, prediction.Output, r)
68+
}
69+
3270
return prediction.Output, nil
3371
}
72+
73+
// Run runs a model and returns the output
74+
func (r *Client) Run(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (PredictionOutput, error) {
75+
return r.RunWithOptions(ctx, identifier, input, webhook)
76+
}
77+
78+
func transformOutput(ctx context.Context, value interface{}, client *Client) (interface{}, error) {
79+
var err error
80+
switch v := value.(type) {
81+
case map[string]interface{}:
82+
for k, val := range v {
83+
v[k], err = transformOutput(ctx, val, client)
84+
if err != nil {
85+
return nil, err
86+
}
87+
}
88+
return v, nil
89+
case []interface{}:
90+
for i, val := range v {
91+
v[i], err = transformOutput(ctx, val, client)
92+
if err != nil {
93+
return nil, err
94+
}
95+
}
96+
return v, nil
97+
case string:
98+
if strings.HasPrefix(v, "data:") {
99+
return readDataURI(v)
100+
}
101+
if strings.HasPrefix(v, "https:") || strings.HasPrefix(v, "http:") {
102+
return readHTTP(ctx, v, client)
103+
}
104+
return v, nil
105+
}
106+
return value, nil
107+
}
108+
109+
func readDataURI(uri string) (*FileOutput, error) {
110+
u, err := url.Parse(uri)
111+
if err != nil {
112+
return nil, err
113+
}
114+
if u.Scheme != "data" {
115+
return nil, errors.New("not a data URI")
116+
}
117+
mediatype, data, found := strings.Cut(u.Opaque, ",")
118+
if !found {
119+
return nil, errors.New("invalid data URI format")
120+
}
121+
var reader io.Reader
122+
if strings.HasSuffix(mediatype, ";base64") {
123+
decoded, err := base64.StdEncoding.DecodeString(data)
124+
if err != nil {
125+
return nil, err
126+
}
127+
reader = bytes.NewReader(decoded)
128+
} else {
129+
reader = strings.NewReader(data)
130+
}
131+
return &FileOutput{
132+
ReadCloser: io.NopCloser(reader),
133+
URL: uri,
134+
}, nil
135+
}
136+
137+
func readHTTP(ctx context.Context, url string, client *Client) (*FileOutput, error) {
138+
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
139+
if err != nil {
140+
return nil, err
141+
}
142+
resp, err := client.c.Do(req)
143+
if err != nil {
144+
return nil, err
145+
}
146+
if resp == nil || resp.Body == nil {
147+
return nil, errors.New("HTTP request failed to get a response")
148+
}
149+
if resp.StatusCode != http.StatusOK {
150+
resp.Body.Close()
151+
return nil, fmt.Errorf("HTTP request failed with status code %d", resp.StatusCode)
152+
}
153+
154+
return &FileOutput{
155+
ReadCloser: resp.Body,
156+
URL: url,
157+
}, nil
158+
}

0 commit comments

Comments
 (0)