diff --git a/packages/inference/README.md b/packages/inference/README.md
index 5425204d05..7a714fa425 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -63,6 +63,7 @@ Currently, we support the following providers:
 - [Cohere](https://cohere.com)
 - [Cerebras](https://cerebras.ai/)
 - [Groq](https://groq.com)
+- [Wavespeed.ai](https://wavespeed.ai/)
 
 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. The default value of the `provider` parameter is "auto", which will select the first of the providers available for the model, sorted by your preferred order in https://hf.co/settings/inference-providers.
 
@@ -97,6 +98,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
 - [Groq supported models](https://console.groq.com/docs/models)
 - [Novita AI supported models](https://huggingface.co/api/partners/novita/models)
+- [Wavespeed.ai supported models](https://huggingface.co/api/partners/wavespeed-ai/models)
 
 ❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type. This is not an issue for LLMs as everyone converged on the OpenAI API anyways, but can be more tricky for other tasks like "text-to-image" or "automatic-speech-recognition" where there exists no standard API. Let us know if any help is needed or if we can make things easier for you!
diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index ad0eb89a6d..34d6a718cc 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -47,6 +47,7 @@ import type {
 import * as Replicate from "../providers/replicate.js";
 import * as Sambanova from "../providers/sambanova.js";
 import * as Together from "../providers/together.js";
+import * as WavespeedAI from "../providers/wavespeed-ai.js";
 import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from "../types.js";
 import { InferenceClientInputError } from "../errors.js";
 
@@ -148,6 +149,11 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
 		conversational: new Together.TogetherConversationalTask(),
 		"text-generation": new Together.TogetherTextGenerationTask(),
 	},
+	"wavespeed-ai": {
+		"text-to-image": new WavespeedAI.WavespeedAITextToImageTask(),
+		"text-to-video": new WavespeedAI.WavespeedAITextToVideoTask(),
+		"image-to-image": new WavespeedAI.WavespeedAIImageToImageTask(),
+	},
 };
 
 export function getProviderHelper(
diff --git a/packages/inference/src/providers/wavespeed-ai.ts b/packages/inference/src/providers/wavespeed-ai.ts
new file mode 100644
--- /dev/null
+++ b/packages/inference/src/providers/wavespeed-ai.ts
@@ -0,0 +1,173 @@
+import {
+	InferenceClientInputError,
+	InferenceClientProviderApiError,
+	InferenceClientProviderOutputError,
+} from "../errors.js";
+import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js";
+import type { BodyParams, UrlParams } from "../types.js";
+import { delay } from "../utils/delay.js";
+import { omit } from "../utils/omit.js";
+import { base64FromBytes } from "../utils/base64FromBytes.js";
+import type { ImageToImageTaskHelper, TextToImageTaskHelper, TextToVideoTaskHelper } from "./providerHelper.js";
+import { TaskProviderHelper } from "./providerHelper.js";
+
+const WAVESPEEDAI_API_BASE_URL = "https://api.wavespeed.ai";
+
+/**
+ * Minimal shape of a WaveSpeed AI task, as consumed by the polling logic below.
+ */
+interface WaveSpeedAITaskData {
+	outputs?: string[];
+	urls: {
+		get: string;
+	};
+	status: "created" | "processing" | "completed" | "failed";
+	error?: string;
+}
+
+/** Response returned when a task is first submitted. */
+interface WaveSpeedAISubmitTaskResponse {
+	data: WaveSpeedAITaskData;
+}
+
+/** Response returned when polling a task's result URL. */
+interface WaveSpeedAIResponse {
+	code: number;
+	message: string;
+	data: WaveSpeedAITaskData;
+}
+
+abstract class WavespeedAITask extends TaskProviderHelper {
+	constructor(url?: string) {
+		super("wavespeed-ai", url || WAVESPEEDAI_API_BASE_URL);
+	}
+
+	makeRoute(params: UrlParams): string {
+		// NOTE: version segment of the submission route assumed from WaveSpeed's public API docs.
+		return `/api/v3/${params.model}`;
+	}
+
+	override preparePayload(params: BodyParams): Record<string, unknown> {
+		const payload: Record<string, unknown> = {
+			...omit(params.args, ["inputs", "parameters"]),
+			...params.args.parameters,
+			prompt: params.args.inputs,
+		};
+		// Add LoRA support if adapter is specified in the mapping
+		if (params.mapping?.adapter === "lora") {
+			payload.loras = [
+				{
+					path: params.mapping.hfModelId,
+					scale: 1, // Default scale value
+				},
+			];
+		}
+		return payload;
+	}
+
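+	/**
+	 * WaveSpeed AI serves results asynchronously: submission returns a task
+	 * record, and the task's `urls.get` endpoint is polled until it completes
+	 * or fails.
+	 */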
"", + status: resultResponse.status, + body: await resultResponse.text(), + } + ); + } + + const result: WaveSpeedAIResponse = await resultResponse.json(); + if (result.code !== 200) { + throw new InferenceClientProviderOutputError( + `API request to WaveSpeed AI API failed with code ${result.code}: ${result.message}` + ); + } + + const taskResult = result.data; + + switch (taskResult.status) { + case "completed": { + // Get the media data from the first output URL + if (!taskResult.outputs?.[0]) { + throw new InferenceClientProviderOutputError( + "Received malformed response from WaveSpeed AI API: No output URL in completed response" + ); + } + const mediaResponse = await fetch(taskResult.outputs[0]); + if (!mediaResponse.ok) { + throw new InferenceClientProviderApiError( + "Failed to fetch response status from WaveSpeed AI API", + { url: taskResult.outputs[0], method: "GET" }, + { + requestId: mediaResponse.headers.get("x-request-id") ?? "", + status: mediaResponse.status, + body: await mediaResponse.text(), + } + ); + } + return await mediaResponse.blob(); + } + case "failed": { + throw new InferenceClientProviderOutputError(taskResult.error || "Task failed"); + } + + default: { + // Wait before polling again + await delay(500); + continue; + } + } + } + } +} + +export class WavespeedAITextToImageTask extends WavespeedAITask implements TextToImageTaskHelper { + constructor() { + super(WAVESPEEDAI_API_BASE_URL); + } +} + +export class WavespeedAITextToVideoTask extends WavespeedAITask implements TextToVideoTaskHelper { + constructor() { + super(WAVESPEEDAI_API_BASE_URL); + } +} + +export class WavespeedAIImageToImageTask extends WavespeedAITask implements ImageToImageTaskHelper { + constructor() { + super(WAVESPEEDAI_API_BASE_URL); + } + + async preparePayloadAsync(args: ImageToImageArgs): Promise { + return { + ...args, + inputs: args.parameters?.prompt, + image: base64FromBytes( + new Uint8Array(args.inputs instanceof ArrayBuffer ? 
+	async preparePayloadAsync(args: ImageToImageArgs): Promise<Record<string, unknown>> {
+		return {
+			...args,
+			inputs: args.parameters?.prompt,
+			image: base64FromBytes(
+				new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer())
+			),
+		};
+	}
+}
diff --git a/packages/inference/src/tasks/cv/imageToImage.ts b/packages/inference/src/tasks/cv/imageToImage.ts
index 4405dd2cb2..1266cc5452 100644
--- a/packages/inference/src/tasks/cv/imageToImage.ts
+++ b/packages/inference/src/tasks/cv/imageToImage.ts
@@ -3,6 +3,7 @@
 import { resolveProvider } from "../../lib/getInferenceProviderMapping.js";
 import { getProviderHelper } from "../../lib/getProviderHelper.js";
 import type { BaseArgs, Options } from "../../types.js";
 import { innerRequest } from "../../utils/request.js";
+import { makeRequestOptions } from "../../lib/makeRequestOptions.js";
 
 export type ImageToImageArgs = BaseArgs & ImageToImageInput;
 
@@ -18,5 +19,6 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
 		...options,
 		task: "image-to-image",
 	});
-	return providerHelper.getResponse(res);
+	const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "image-to-image" });
+	return providerHelper.getResponse(res, url, info.headers as Record<string, string>);
 }
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index e2f9682c97..69e94c5a26 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -55,6 +55,7 @@ export const INFERENCE_PROVIDERS = [
 	"replicate",
 	"sambanova",
 	"together",
+	"wavespeed-ai",
 ] as const;
 
 export const PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"] as const;
diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index 602e034cd4..4389123790 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -2023,4 +2023,112 @@ describe.skip("InferenceClient", () => {
 		},
 		TIMEOUT
 	);
+	describe.concurrent(
+		"Wavespeed AI",
+		() => {
+			const client = new InferenceClient(env.HF_WAVESPEED_KEY ?? "dummy");
+
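+			// No live Hub routing exists for these models in CI, so the tests pin
+			// the model -> provider mapping the Hub would normally resolve.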
"dummy"); + + HARDCODED_MODEL_INFERENCE_MAPPING["wavespeed-ai"] = { + "wavespeed-ai/flux-schnell": { + hfModelId: "wavespeed-ai/flux-schnell", + providerId: "wavespeed-ai/flux-schnell", + status: "live", + task: "text-to-image", + }, + "wavespeed-ai/wan-2.1/t2v-480p": { + hfModelId: "wavespeed-ai/wan-2.1/t2v-480p", + providerId: "wavespeed-ai/wan-2.1/t2v-480p", + status: "live", + task: "text-to-video", + }, + "wavespeed-ai/hidream-e1-full": { + hfModelId: "wavespeed-ai/hidream-e1-full", + providerId: "wavespeed-ai/hidream-e1-full", + status: "live", + task: "image-to-image", + }, + "openfree/flux-chatgpt-ghibli-lora": { + hfModelId: "openfree/flux-chatgpt-ghibli-lora", + providerId: "wavespeed-ai/flux-dev-lora", + status: "live", + task: "text-to-image", + adapter: "lora", + adapterWeightsPath: "openfree/flux-chatgpt-ghibli-lora", + }, + "linoyts/yarn_art_Flux_LoRA": { + hfModelId: "linoyts/yarn_art_Flux_LoRA", + providerId: "wavespeed-ai/flux-dev-lora-ultra-fast", + status: "live", + task: "text-to-image", + adapter: "lora", + adapterWeightsPath: "linoyts/yarn_art_Flux_LoRA", + }, + }; + + it(`textToImage - wavespeed-ai/flux-schnell`, async () => { + const res = await client.textToImage({ + model: "wavespeed-ai/flux-schnell", + provider: "wavespeed-ai", + inputs: + "Cute boy with a hat, exploring nature, holding a telescope, backpack, surrounded by flowers, cartoon style, vibrant colors.", + }); + expect(res).toBeInstanceOf(Blob); + }); + + it(`textToImage - openfree/flux-chatgpt-ghibli-lora`, async () => { + const res = await client.textToImage({ + model: "openfree/flux-chatgpt-ghibli-lora", + provider: "wavespeed-ai", + inputs: + "Cute boy with a hat, exploring nature, holding a telescope, backpack, surrounded by flowers, cartoon style, vibrant colors.", + }); + expect(res).toBeInstanceOf(Blob); + }); + + it(`textToImage - linoyts/yarn_art_Flux_LoRA`, async () => { + const res = await client.textToImage({ + model: "linoyts/yarn_art_Flux_LoRA", + provider: "wavespeed-ai", + inputs: + "Cute boy with a hat, exploring nature, holding a telescope, backpack, surrounded by flowers, cartoon style, vibrant colors.", + }); + expect(res).toBeInstanceOf(Blob); + }); + + it(`textToVideo - wavespeed-ai/wan-2.1/t2v-480p`, async () => { + const res = await client.textToVideo({ + model: "wavespeed-ai/wan-2.1/t2v-480p", + provider: "wavespeed-ai", + inputs: + "A cool street dancer, wearing a baggy hoodie and hip-hop pants, dancing in front of a graffiti wall, night neon background, quick camera cuts, urban trends.", + parameters: { + guidance_scale: 5, + num_inference_steps: 30, + seed: -1, + }, + duration: 5, + enable_safety_checker: true, + flow_shift: 2.9, + size: "480*832", + }); + expect(res).toBeInstanceOf(Blob); + }); + + it(`imageToImage - wavespeed-ai/hidream-e1-full`, async () => { + const res = await client.imageToImage({ + model: "wavespeed-ai/hidream-e1-full", + provider: "wavespeed-ai", + inputs: new Blob([readTestFile("cheetah.png")], { type: "image / png" }), + parameters: { + prompt: "The leopard chases its prey", + guidance_scale: 5, + num_inference_steps: 30, + seed: -1, + }, + }); + expect(res).toBeInstanceOf(Blob); + }); + }, + 60000 * 5 + ); });