-
Notifications
You must be signed in to change notification settings - Fork 413
[inference provider] Add wavespeed.ai as an inference provider #1424
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
arabot777
wants to merge
33
commits into
huggingface:main
Choose a base branch
from
arabot777:feat/wavespeedai
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
a4d8504
add wavespeed.ai as an inference provider
arabot777 686931e
delete debug log
arabot777 0e71b88
Merge branch 'main' into feat/wavespeedai
arabot777 4461225
Merge branch 'main' into feat/wavespeedai
arabot777 e0bf580
Merge branch 'main' into feat/wavespeedai
arabot777 07af35f
Merge branch 'main' into feat/wavespeedai
arabot777 fa3afa4
Merge branch 'main' into feat/wavespeedai
arabot777 214ff99
support lora
arabot777 47c64c6
code review
arabot777 7270c5c
Merge branch 'main' into feat/wavespeedai
arabot777 ba35791
code review
arabot777 ca35eab
Merge branch 'main' into feat/wavespeedai
arabot777 80d4640
delete unused import
arabot777 77be0c6
Merge branch 'main' into feat/wavespeedai
arabot777 0c77b3b
Update packages/inference/src/lib/getProviderHelper.ts
arabot777 3ab254e
Update packages/inference/src/lib/getProviderHelper.ts
arabot777 a8fe74c
Merge branch 'main' into feat/wavespeedai
arabot777 f706e02
Merge branch 'main' into feat/wavespeedai
arabot777 47f41f0
code review modification
arabot777 0cfefe8
Merge branch 'main' into feat/wavespeedai
arabot777 f162e89
Merge branch 'main' into feat/wavespeedai
arabot777 b23a000
Merge branch 'main' into feat/wavespeedai
arabot777 6cabc5a
import js file
arabot777 71e4939
Merge branch 'main' into feat/wavespeedai
arabot777 6341233
Merge branch 'main' into feat/wavespeedai
arabot777 bf5ccb4
lora optimize and image-to-image getresponse use header
arabot777 554bd19
Merge branch 'main' into feat/wavespeedai
arabot777 8507385
Merge branch 'main' into feat/wavespeedai
arabot777 054ecb9
Merge branch 'main' into feat/wavespeedai
arabot777 1a1f672
Merge branch 'main' into feat/wavespeedai
arabot777 1b407f3
Merge branch 'main' into feat/wavespeedai
arabot777 4a71a4b
handle inference error; upgrade api v2 -> v3
arabot777 839e940
recode import
arabot777 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
import type { TextToImageArgs } from "../tasks/cv/textToImage.js"; | ||
import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js"; | ||
import type { TextToVideoArgs } from "../tasks/cv/textToVideo.js"; | ||
import type { BodyParams, RequestArgs, UrlParams } from "../types.js"; | ||
import { delay } from "../utils/delay.js"; | ||
import { omit } from "../utils/omit.js"; | ||
import { base64FromBytes } from "../utils/base64FromBytes.js"; | ||
import type { TextToImageTaskHelper, TextToVideoTaskHelper, ImageToImageTaskHelper } from "./providerHelper.js"; | ||
import { TaskProviderHelper } from "./providerHelper.js"; | ||
import { | ||
InferenceClientInputError, | ||
InferenceClientProviderApiError, | ||
InferenceClientProviderOutputError, | ||
} from "../errors.js"; | ||
|
||
const WAVESPEEDAI_API_BASE_URL = "https://api.wavespeed.ai"; | ||
|
||
/** | ||
* Response structure for task status and results | ||
*/ | ||
interface WaveSpeedAITaskResponse { | ||
id: string; | ||
model: string; | ||
outputs: string[]; | ||
urls: { | ||
get: string; | ||
}; | ||
has_nsfw_contents: boolean[]; | ||
status: "created" | "processing" | "completed" | "failed"; | ||
created_at: string; | ||
error: string; | ||
executionTime: number; | ||
timings: { | ||
inference: number; | ||
}; | ||
} | ||
|
||
/** | ||
* Response structure for initial task submission | ||
*/ | ||
interface WaveSpeedAISubmitResponse { | ||
id: string; | ||
urls: { | ||
get: string; | ||
}; | ||
} | ||
|
||
/** | ||
* Response structure for WaveSpeed AI API | ||
*/ | ||
interface WaveSpeedAIResponse { | ||
code: number; | ||
message: string; | ||
data: WaveSpeedAITaskResponse; | ||
} | ||
|
||
/** | ||
* Response structure for WaveSpeed AI API with submit response data | ||
*/ | ||
interface WaveSpeedAISubmitTaskResponse { | ||
code: number; | ||
message: string; | ||
data: WaveSpeedAISubmitResponse; | ||
} | ||
|
||
abstract class WavespeedAITask extends TaskProviderHelper { | ||
constructor(url?: string) { | ||
super("wavespeed-ai", url || WAVESPEEDAI_API_BASE_URL); | ||
} | ||
|
||
makeRoute(params: UrlParams): string { | ||
return `/api/v3/${params.model}`; | ||
} | ||
|
||
preparePayload(params: BodyParams<ImageToImageArgs | TextToImageArgs | TextToVideoArgs>): Record<string, unknown> { | ||
const payload: Record<string, unknown> = { | ||
...omit(params.args, ["inputs", "parameters"]), | ||
...params.args.parameters, | ||
prompt: params.args.inputs, | ||
}; | ||
// Add LoRA support if adapter is specified in the mapping | ||
if (params.mapping?.adapter === "lora") { | ||
payload.loras = [ | ||
{ | ||
path: params.mapping.hfModelId, | ||
scale: 1, // Default scale value | ||
}, | ||
]; | ||
} | ||
return payload; | ||
} | ||
|
||
override async getResponse( | ||
response: WaveSpeedAISubmitTaskResponse, | ||
url?: string, | ||
headers?: Record<string, string> | ||
): Promise<Blob> { | ||
if (!headers) { | ||
throw new InferenceClientInputError("Headers are required for WaveSpeed AI API calls"); | ||
} | ||
|
||
const resultUrl = response.data.urls.get; | ||
|
||
// Poll for results until completion | ||
while (true) { | ||
const resultResponse = await fetch(resultUrl, { headers }); | ||
|
||
if (!resultResponse.ok) { | ||
throw new InferenceClientProviderApiError( | ||
"Failed to fetch response status from WaveSpeed AI API", | ||
{ url: resultUrl, method: "GET" }, | ||
{ | ||
requestId: resultResponse.headers.get("x-request-id") ?? "", | ||
status: resultResponse.status, | ||
body: await resultResponse.text(), | ||
} | ||
); | ||
} | ||
|
||
const result: WaveSpeedAIResponse = await resultResponse.json(); | ||
if (result.code !== 200) { | ||
throw new InferenceClientProviderOutputError( | ||
`API request to WaveSpeed AI API failed with code ${result.code}: ${result.message}` | ||
); | ||
} | ||
|
||
const taskResult = result.data; | ||
|
||
switch (taskResult.status) { | ||
case "completed": { | ||
// Get the media data from the first output URL | ||
if (!taskResult.outputs?.[0]) { | ||
throw new InferenceClientProviderOutputError( | ||
"Received malformed response from WaveSpeed AI API: No output URL in completed response" | ||
); | ||
} | ||
const mediaResponse = await fetch(taskResult.outputs[0]); | ||
if (!mediaResponse.ok) { | ||
throw new InferenceClientProviderApiError( | ||
"Failed to fetch response status from WaveSpeed AI API", | ||
{ url: taskResult.outputs[0], method: "GET" }, | ||
{ | ||
requestId: mediaResponse.headers.get("x-request-id") ?? "", | ||
status: mediaResponse.status, | ||
body: await mediaResponse.text(), | ||
} | ||
); | ||
} | ||
return await mediaResponse.blob(); | ||
} | ||
case "failed": { | ||
throw new InferenceClientProviderOutputError(taskResult.error || "Task failed"); | ||
} | ||
|
||
default: { | ||
// Wait before polling again | ||
await delay(500); | ||
continue; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
export class WavespeedAITextToImageTask extends WavespeedAITask implements TextToImageTaskHelper { | ||
constructor() { | ||
super(WAVESPEEDAI_API_BASE_URL); | ||
} | ||
} | ||
|
||
export class WavespeedAITextToVideoTask extends WavespeedAITask implements TextToVideoTaskHelper { | ||
constructor() { | ||
super(WAVESPEEDAI_API_BASE_URL); | ||
} | ||
} | ||
|
||
export class WavespeedAIImageToImageTask extends WavespeedAITask implements ImageToImageTaskHelper { | ||
constructor() { | ||
super(WAVESPEEDAI_API_BASE_URL); | ||
} | ||
|
||
hanouticelina marked this conversation as resolved.
Show resolved
Hide resolved
|
||
async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> { | ||
return { | ||
...args, | ||
inputs: args.parameters?.prompt, | ||
hanouticelina marked this conversation as resolved.
Show resolved
Hide resolved
|
||
image: base64FromBytes( | ||
new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer()) | ||
), | ||
}; | ||
} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.