Skip to content

Commit 972c92e

Browse files
authored
Add support for models.search endpoint (#74)
1 parent 8328597 commit 972c92e

File tree

3 files changed

+94
-0
lines changed

3 files changed

+94
-0
lines changed

client_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,60 @@ func TestListModels(t *testing.T) {
205205
assert.Equal(t, "codellama-13b", modelsPage.Results[1].Name)
206206
}
207207

208+
func TestSearchModels(t *testing.T) {
209+
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
210+
assert.Equal(t, "/models", r.URL.Path)
211+
assert.Equal(t, "QUERY", r.Method)
212+
assert.Equal(t, "text/plain", r.Header.Get("Content-Type"))
213+
214+
body, err := io.ReadAll(r.Body)
215+
if err != nil {
216+
t.Fatal(err)
217+
}
218+
defer r.Body.Close()
219+
220+
assert.Equal(t, "stable diffusion", string(body))
221+
222+
response := replicate.Page[replicate.Model]{
223+
Results: []replicate.Model{
224+
{
225+
Owner: "stability-ai",
226+
Name: "sdxl",
227+
Description: "A text-to-image generative AI model that creates beautiful 1024x1024 images",
228+
},
229+
{
230+
Owner: "stability-ai",
231+
Name: "stable-diffusion",
232+
Description: "Stable Diffusion is a text-to-image diffusion model",
233+
},
234+
},
235+
}
236+
237+
w.Header().Set("Content-Type", "application/json")
238+
w.WriteHeader(http.StatusOK)
239+
json.NewEncoder(w).Encode(response)
240+
}))
241+
defer mockServer.Close()
242+
243+
client, err := replicate.NewClient(
244+
replicate.WithToken("test-token"),
245+
replicate.WithBaseURL(mockServer.URL),
246+
)
247+
require.NotNil(t, client)
248+
require.NoError(t, err)
249+
250+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
251+
defer cancel()
252+
253+
modelsPage, err := client.SearchModels(ctx, "stable diffusion")
254+
assert.NoError(t, err)
255+
assert.Equal(t, 2, len(modelsPage.Results))
256+
assert.Equal(t, "stability-ai", modelsPage.Results[0].Owner)
257+
assert.Equal(t, "sdxl", modelsPage.Results[0].Name)
258+
assert.Equal(t, "stability-ai", modelsPage.Results[1].Owner)
259+
assert.Equal(t, "stable-diffusion", modelsPage.Results[1].Name)
260+
}
261+
208262
func TestGetModel(t *testing.T) {
209263
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
210264
assert.Equal(t, "/models/replicate/hello-world", r.URL.Path)

example_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package replicate_test
33
import (
44
"context"
55
"fmt"
6+
"strings"
67

78
"github.com/replicate/replicate-go"
89
)
@@ -59,3 +60,26 @@ func ExampleClient_CreatePrediction() {
5960
fmt.Println(prediction.Status)
6061
// Output: succeeded
6162
}
63+
64+
func ExampleClient_SearchModels() {
65+
ctx := context.TODO()
66+
67+
r8, err := replicate.NewClient(replicate.WithTokenFromEnv())
68+
if err != nil {
69+
panic(err)
70+
}
71+
72+
query := "llama"
73+
modelsPage, err := r8.SearchModels(ctx, query)
74+
if err != nil {
75+
panic(err)
76+
}
77+
78+
for _, model := range modelsPage.Results {
79+
if model.Owner == "meta" && strings.HasPrefix(model.Name, "meta-llama-3") {
80+
fmt.Printf("Found Meta Llama 3 model")
81+
break
82+
}
83+
}
84+
// Output: Found Meta Llama 3 model
85+
}

model.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"fmt"
77
"net/http"
8+
"strings"
89
)
910

1011
type Model struct {
@@ -79,6 +80,21 @@ func (r *Client) ListModels(ctx context.Context) (*Page[Model], error) {
7980
return response, nil
8081
}
8182

83+
// SearchModels searches for public models.
84+
func (r *Client) SearchModels(ctx context.Context, query string) (*Page[Model], error) {
85+
response := &Page[Model]{}
86+
request, err := r.newRequest(ctx, "QUERY", "/models", strings.NewReader(query))
87+
if err != nil {
88+
return nil, fmt.Errorf("failed to create request: %w", err)
89+
}
90+
request.Header.Set("Content-Type", "text/plain")
91+
err = r.do(request, response)
92+
if err != nil {
93+
return nil, fmt.Errorf("failed to search models: %w", err)
94+
}
95+
return response, nil
96+
}
97+
8298
// GetModel retrieves information about a model.
8399
func (r *Client) GetModel(ctx context.Context, modelOwner string, modelName string) (*Model, error) {
84100
model := &Model{}

0 commit comments

Comments
 (0)