@@ -205,6 +205,60 @@ func TestListModels(t *testing.T) {
205205 assert .Equal (t , "codellama-13b" , modelsPage .Results [1 ].Name )
206206}
207207
208+ func TestSearchModels (t * testing.T ) {
209+ mockServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
210+ assert .Equal (t , "/models" , r .URL .Path )
211+ assert .Equal (t , "QUERY" , r .Method )
212+ assert .Equal (t , "text/plain" , r .Header .Get ("Content-Type" ))
213+
214+ body , err := io .ReadAll (r .Body )
215+ if err != nil {
216+ t .Fatal (err )
217+ }
218+ defer r .Body .Close ()
219+
220+ assert .Equal (t , "stable diffusion" , string (body ))
221+
222+ response := replicate.Page [replicate.Model ]{
223+ Results : []replicate.Model {
224+ {
225+ Owner : "stability-ai" ,
226+ Name : "sdxl" ,
227+ Description : "A text-to-image generative AI model that creates beautiful 1024x1024 images" ,
228+ },
229+ {
230+ Owner : "stability-ai" ,
231+ Name : "stable-diffusion" ,
232+ Description : "Stable Diffusion is a text-to-image diffusion model" ,
233+ },
234+ },
235+ }
236+
237+ w .Header ().Set ("Content-Type" , "application/json" )
238+ w .WriteHeader (http .StatusOK )
239+ json .NewEncoder (w ).Encode (response )
240+ }))
241+ defer mockServer .Close ()
242+
243+ client , err := replicate .NewClient (
244+ replicate .WithToken ("test-token" ),
245+ replicate .WithBaseURL (mockServer .URL ),
246+ )
247+ require .NotNil (t , client )
248+ require .NoError (t , err )
249+
250+ ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
251+ defer cancel ()
252+
253+ modelsPage , err := client .SearchModels (ctx , "stable diffusion" )
254+ assert .NoError (t , err )
255+ assert .Equal (t , 2 , len (modelsPage .Results ))
256+ assert .Equal (t , "stability-ai" , modelsPage .Results [0 ].Owner )
257+ assert .Equal (t , "sdxl" , modelsPage .Results [0 ].Name )
258+ assert .Equal (t , "stability-ai" , modelsPage .Results [1 ].Owner )
259+ assert .Equal (t , "stable-diffusion" , modelsPage .Results [1 ].Name )
260+ }
261+
208262func TestGetModel (t * testing.T ) {
209263 mockServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
210264 assert .Equal (t , "/models/replicate/hello-world" , r .URL .Path )
0 commit comments