-
Notifications
You must be signed in to change notification settings - Fork 514
[inference provider] Add wavespeed.ai as an inference provider #1424
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 14 commits
a4d8504
686931e
0e71b88
4461225
e0bf580
07af35f
fa3afa4
214ff99
47c64c6
7270c5c
ba35791
ca35eab
80d4640
77be0c6
0c77b3b
3ab254e
a8fe74c
f706e02
47f41f0
0cfefe8
f162e89
b23a000
6cabc5a
71e4939
6341233
bf5ccb4
554bd19
8507385
054ecb9
1a1f672
1b407f3
4a71a4b
839e940
64a991d
fd20f75
98465e2
4e4ca9c
76344db
f145274
dc66fd4
da95bc3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,178 @@ | ||||||||||||||||||||||||||||||
import { InferenceOutputError } from "../lib/InferenceOutputError"; | ||||||||||||||||||||||||||||||
import type { ImageToImageArgs } from "../tasks"; | ||||||||||||||||||||||||||||||
import type { BodyParams, HeaderParams, RequestArgs, UrlParams } from "../types"; | ||||||||||||||||||||||||||||||
import { delay } from "../utils/delay"; | ||||||||||||||||||||||||||||||
import { omit } from "../utils/omit"; | ||||||||||||||||||||||||||||||
import { base64FromBytes } from "../utils/base64FromBytes"; | ||||||||||||||||||||||||||||||
import { | ||||||||||||||||||||||||||||||
TaskProviderHelper, | ||||||||||||||||||||||||||||||
TextToImageTaskHelper, | ||||||||||||||||||||||||||||||
TextToVideoTaskHelper, | ||||||||||||||||||||||||||||||
ImageToImageTaskHelper, | ||||||||||||||||||||||||||||||
} from "./providerHelper"; | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
const WAVESPEEDAI_API_BASE_URL = "https://api.wavespeed.ai"; | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||
* Common response structure for all WaveSpeed AI API responses | ||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||
interface WaveSpeedAICommonResponse<T> { | ||||||||||||||||||||||||||||||
code: number; | ||||||||||||||||||||||||||||||
message: string; | ||||||||||||||||||||||||||||||
data: T; | ||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
/** | |
* Common response structure for all WaveSpeed AI API responses | |
*/ | |
interface WaveSpeedAICommonResponse<T> { | |
code: number; | |
message: string; | |
data: T; | |
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It has been modified as suggested
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure this type alias is needed, can we remove it?
type WaveSpeedAIResponse<T = WaveSpeedAITaskResponse> = WaveSpeedAICommonResponse<T>; |
WaveSpeedAICommonResponse
can be renamed to WaveSpeedAIResponse
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This type is needed and will be used in two places. It's uncertain whether it will be used again in the future.
It follows the DRY (Don't Repeat Yourself) principle
It provides better type safety (through default generic parameters)
It makes the code more readable and maintainable
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Following the previous comment - let's remove one level of abstraction
type WaveSpeedAIResponse<T = WaveSpeedAITaskResponse> = WaveSpeedAICommonResponse<T>; | |
interface WaveSpeedAIResponse { | |
code: number; | |
message: string; | |
data: WaveSpeedAITaskResponse | |
} | |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It has been modified as suggested
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need to cast into Result<string, unknown>
if the params
have the proper type
ImageToImageArgs
, TextToImageArgs
, and TextToVideoArgs
need to be imported from "../tasks"
preparePayload(params: BodyParams): Record<string, unknown> { | |
const payload: Record<string, unknown> = { | |
...omit(params.args, ["inputs", "parameters"]), | |
...(params.args.parameters as Record<string, unknown>), | |
prompt: params.args.inputs, | |
}; | |
// Add LoRA support if adapter is specified in the mapping | |
preparePayload(params: BodyParams<ImageToImageArgs | TextToImageArgs | TextToVideoArgs>): Record<string, unknown> { | |
const payload: Record<string, unknown> = { | |
...omit(params.args, ["inputs", "parameters"]), | |
...params.args.parameters, | |
prompt: params.args.inputs, | |
}; | |
// Add LoRA support if adapter is specified in the mapping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It has been modified as suggested
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For reference, adapterWeightsPath
is the path to the LoRA weights inside the associated HF repo
eg, for nerijs/pixel-art-xl, it will be
"pixel-art-xl.safetensors"
Let's make sure that is indeed what your API is expecting when running LoRAs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here I see that fal is the endpoint that has been concatenated with hf.
Can I directly set the adapterWeightsPath to a lora http address? Or any other address.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the test cases, I conducted the test in this way. The adapterWeightsPath
was directly passed over as an input parameter of lora.
"wavespeed-ai/flux-dev-lora": {
hfModelId: "wavespeed-ai/flux-dev-lora",
providerId: "wavespeed-ai/flux-dev-lora",
status: "live",
task: "text-to-image",
adapter: "lora",
adapterWeightsPath:
"https://d32s1zkpjdc4b1.cloudfront.net/predictions/599f3739f5354afc8a76a12042736bfd/1.safetensors",
},
"wavespeed-ai/flux-dev-lora-ultra-fast": {
hfModelId: "wavespeed-ai/flux-dev-lora-ultra-fast",
providerId: "wavespeed-ai/flux-dev-lora-ultra-fast",
status: "live",
task: "text-to-image",
adapter: "lora",
adapterWeightsPath: "linoyts/yarn_art_Flux_LoRA",
},
However, I'm not sure whether the LoRA input parameter submitted by HF must be an abbreviated file path within the HF model repo that the code then concatenates with the HF address. If that is the required specification, I can implement it following the fal format.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think your API can just take the hf model id as the loras path, right?
path: params.mapping.adapterWeightsPath, | |
path: params.mapping.hfModelId, |
As mentioned by @SBrandeis, this part depends on what your API is expecting as inputs when using LoRAs weights.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, you're correct.
In the example, linoyts/yarn_art_Flux_LoRA
is the LoRA model address on HF. We will automatically match and download the HF model.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I completed the modification and ran the use case successfully
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the same behavior as the blanket implementation here:
https://github.com/arabot777/huggingface.js/blob/f706e02d6128f559bd5551072344ff6e31b9c4be/packages/inference/src/providers/providerHelper.ts#L114-L124
No need for an override IMO
override prepareHeaders(params: HeaderParams, isBinary: boolean): Record<string, string> { | |
this.accessToken = params.accessToken; | |
const headers: Record<string, string> = { Authorization: `Bearer ${params.accessToken}` }; | |
if (!isBinary) { | |
headers["Content-Type"] = "application/json"; | |
} | |
return headers; | |
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed this part of the logic at the beginning. However, the getResponse
method of imageToImage.ts
did not pass in header information.
I have to rewrite prepareHeaders here and, via the assignment
this.accessToken = params.accessToken;
ensure that the complete access-token information from the header can be passed along when calling getResponse
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd rather update ImageToImage
to be able to pass headers to getResponse
:
export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "image-to-image");
const payload = await providerHelper.preparePayloadAsync(args);
const { data: res } = await innerRequest<Blob>(payload, providerHelper, {
...options,
task: "image-to-image",
});
const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "image-to-image" });
return providerHelper.getResponse(res, url, info.headers as Record<string, string>);
}
rather than overriding prepareHeaders
and doing this.accessToken = params.accessToken
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Your suggestion makes sense. Initially, this was a common/public function, so I took a minimalistic approach and didn't modify it. Now, let me try making some changes here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I completed the modification and ran the use case successfully
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
throw new InferenceOutputError("Headers are required for WaveSpeed AI API calls"); | |
throw new InferenceClientInputError("Headers are required for WaveSpeed AI API calls"); |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
throw new InferenceOutputError(`Failed to get result: ${resultResponse.statusText}`); | |
throw new InferenceClientProviderApiError( | |
"Failed to fetch response status from WaveSpeed AI API", | |
{ url: resultUrl, method: "GET" }, | |
{ | |
requestId: resultResponse.headers.get("x-request-id") ?? "", | |
body: await resultResponse.text(), | |
} | |
); |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
throw new InferenceOutputError(`API request failed with code ${result.code}: ${result.message}`); | |
throw new InferenceClientProviderOutputError(`API request to WaveSpeed AI API failed with code ${result.code}: ${result.message}`); |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
throw new InferenceOutputError("No output URL in completed response"); | |
throw new InferenceClientProviderOutputError("Received malformed response from WaveSpeed AI API: No output URL in completed response"); |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
throw new InferenceOutputError("Failed to fetch output data"); | |
throw new InferenceClientProviderApiError( | |
"Failed to fetch response status from WaveSpeed AI API", | |
{ url: taskResult.outputs[0], method: "GET" }, | |
{ | |
requestId: mediaResponse.headers.get("x-request-id") ?? "", | |
body: await mediaResponse.text(), | |
} | |
); |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
throw new InferenceOutputError(taskResult.error || "Task failed"); | |
throw new InferenceClientProviderOutputError(taskResult.error || "Task failed"); |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
throw new InferenceOutputError(`Unknown status: ${taskResult.status}`); | |
throw new InferenceClientProviderOutputError(`Unknown status: ${taskResult.status}`); |
hanouticelina marked this conversation as resolved.
Show resolved
Hide resolved
hanouticelina marked this conversation as resolved.
Show resolved
Hide resolved
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -2023,4 +2023,113 @@ describe.skip("InferenceClient", () => { | |||||
}, | ||||||
TIMEOUT | ||||||
); | ||||||
describe.concurrent( | ||||||
"Wavespeed AI", | ||||||
() => { | ||||||
const client = new InferenceClient(env.HF_WAVESPEED_KEY ?? "dummy"); | ||||||
|
||||||
HARDCODED_MODEL_INFERENCE_MAPPING["wavespeed-ai"] = { | ||||||
"wavespeed-ai/flux-schnell": { | ||||||
hfModelId: "wavespeed-ai/flux-schnell", | ||||||
providerId: "wavespeed-ai/flux-schnell", | ||||||
status: "live", | ||||||
task: "text-to-image", | ||||||
}, | ||||||
"wavespeed-ai/wan-2.1/t2v-480p": { | ||||||
hfModelId: "wavespeed-ai/wan-2.1/t2v-480p", | ||||||
providerId: "wavespeed-ai/wan-2.1/t2v-480p", | ||||||
status: "live", | ||||||
task: "text-to-video", | ||||||
}, | ||||||
"wavespeed-ai/hidream-e1-full": { | ||||||
hfModelId: "wavespeed-ai/hidream-e1-full", | ||||||
providerId: "wavespeed-ai/hidream-e1-full", | ||||||
status: "live", | ||||||
task: "image-to-image", | ||||||
}, | ||||||
"wavespeed-ai/flux-dev-lora": { | ||||||
hfModelId: "wavespeed-ai/flux-dev-lora", | ||||||
providerId: "wavespeed-ai/flux-dev-lora", | ||||||
status: "live", | ||||||
task: "text-to-image", | ||||||
adapter: "lora", | ||||||
adapterWeightsPath: | ||||||
"https://d32s1zkpjdc4b1.cloudfront.net/predictions/599f3739f5354afc8a76a12042736bfd/1.safetensors", | ||||||
}, | ||||||
"wavespeed-ai/flux-dev-lora-ultra-fast": { | ||||||
hfModelId: "wavespeed-ai/flux-dev-lora-ultra-fast", | ||||||
providerId: "wavespeed-ai/flux-dev-lora-ultra-fast", | ||||||
status: "live", | ||||||
task: "text-to-image", | ||||||
adapter: "lora", | ||||||
adapterWeightsPath: "linoyts/yarn_art_Flux_LoRA", | ||||||
}, | ||||||
}; | ||||||
|
||||||
it(`textToImage - wavespeed-ai/flux-schnell`, async () => { | ||||||
const res = await client.textToImage({ | ||||||
model: "wavespeed-ai/flux-schnell", | ||||||
|
model: "wavespeed-ai/flux-schnell", | |
model: "black-forest-labs/FLUX.1-schnell", |
Uh oh!
There was an error while loading. Please reload this page.