-
Notifications
You must be signed in to change notification settings - Fork 411
[inference provider] Add wavespeed.ai as an inference provider #1424
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
a4d8504
686931e
0e71b88
4461225
e0bf580
07af35f
fa3afa4
214ff99
47c64c6
7270c5c
ba35791
ca35eab
80d4640
77be0c6
0c77b3b
3ab254e
a8fe74c
f706e02
47f41f0
0cfefe8
f162e89
b23a000
6cabc5a
71e4939
6341233
bf5ccb4
554bd19
8507385
054ecb9
1a1f672
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,193 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
import { InferenceOutputError } from "../lib/InferenceOutputError"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
import { ImageToImageArgs } from "../tasks"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
import type { BodyParams, HeaderParams, RequestArgs, UrlParams } from "../types"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
import { delay } from "../utils/delay"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
import { omit } from "../utils/omit"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
import { base64FromBytes } from "../utils/base64FromBytes"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
import { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
TaskProviderHelper, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
TextToImageTaskHelper, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
TextToVideoTaskHelper, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
ImageToImageTaskHelper, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} from "./providerHelper"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We use
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Modify as suggested |
||||||||||||||||||||||||||||||||||||||||||||||||||||
const WAVESPEEDAI_API_BASE_URL = "https://api.wavespeed.ai"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||||||||||||||||||||||||
* Common response structure for all WaveSpeed AI API responses | ||||||||||||||||||||||||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
interface WaveSpeedAICommonResponse<T> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
code: number; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
message: string; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
data: T; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This abstraction is not necessary IMO, let's remove it (see my other comment)
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It has been modified as suggested |
||||||||||||||||||||||||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||||||||||||||||||||||||
* Response structure for task status and results | ||||||||||||||||||||||||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
interface WaveSpeedAITaskResponse { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
id: string; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
model: string; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
outputs: string[]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
urls: { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
get: string; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
has_nsfw_contents: boolean[]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
status: "created" | "processing" | "completed" | "failed"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
created_at: string; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
error: string; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
executionTime: number; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
timings: { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
inference: number; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||||||||||||||||||||||||
* Response structure for initial task submission | ||||||||||||||||||||||||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
interface WaveSpeedAISubmitResponse { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
id: string; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
urls: { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
get: string; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
type WaveSpeedAIResponse<T = WaveSpeedAITaskResponse> = WaveSpeedAICommonResponse<T>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure this type alias is needed, can we remove it?
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This type is needed and will be used in two places. It's uncertain whether it will be used again in the future. |
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Following the previous comment - let's remove one level of abstraction
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It has been modified as suggested |
||||||||||||||||||||||||||||||||||||||||||||||||||||
abstract class WavespeedAITask extends TaskProviderHelper { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
private accessToken: string | undefined; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
constructor(url?: string) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
super("wavespeed-ai", url || WAVESPEEDAI_API_BASE_URL); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
makeRoute(params: UrlParams): string { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return `/api/v2/${params.model}`; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
preparePayload(params: BodyParams): Record<string, unknown> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
const payload: Record<string, unknown> = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
...omit(params.args, ["inputs", "parameters"]), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
...(params.args.parameters as Record<string, unknown>), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
prompt: params.args.inputs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// Add LoRA support if adapter is specified in the mapping | ||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need to cast into
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It has been modified as suggested |
||||||||||||||||||||||||||||||||||||||||||||||||||||
if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
payload.loras = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
path: params.mapping.adapterWeightsPath, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For reference,
Let's make sure that is indeed what your API is expecting when running LoRAs There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here I see that fal is the endpoint that has been concatenated with hf. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the test cases, I conducted the test in this way. The
However, I'm not sure whether the input parameters submitted by hf to lora must be the abbreviation of the file path of the hf model and then concatenated with the hf address in the code. If it is this kind of specification, I can complete it in the format of fal There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think your API can just take the hf model id as the loras path, right?
Suggested change
As mentioned by @SBrandeis, this part depends on what your API is expecting as inputs when using LoRAs weights. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, you're correct. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I completed the modification and ran the use case successfully |
||||||||||||||||||||||||||||||||||||||||||||||||||||
scale: 1, // Default scale value | ||||||||||||||||||||||||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return payload; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
override prepareHeaders(params: HeaderParams, isBinary: boolean): Record<string, string> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
this.accessToken = params.accessToken; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
const headers: Record<string, string> = { Authorization: `Bearer ${params.accessToken}` }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if (!isBinary) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
headers["Content-Type"] = "application/json"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return headers; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the same behavior as the blanket implementation here: No need for an override IMO
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I removed this part of the logic at the beginning. However, the I have to rewrite prepareHeaders here and by assignment There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd rather update export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "image-to-image");
const payload = await providerHelper.preparePayloadAsync(args);
const { data: res } = await innerRequest<Blob>(payload, providerHelper, {
...options,
task: "image-to-image",
});
const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "image-to-image" });
return providerHelper.getResponse(res, url, info.headers as Record<string, string>);
} rather than overriding There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Your suggestion makes sense. Initially, this was a common/public function, so I took a minimalistic approach and didn't modify it. Now, let me try making some changes here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I completed the modification and ran the use case successfully |
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
override async getResponse( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
response: WaveSpeedAIResponse<WaveSpeedAISubmitResponse>, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
url?: string, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
headers?: Record<string, string> | ||||||||||||||||||||||||||||||||||||||||||||||||||||
): Promise<Blob> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if (!headers && this.accessToken) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
headers = { Authorization: `Bearer ${this.accessToken}` }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if (!headers) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
throw new InferenceOutputError("Headers are required for WaveSpeed AI API calls"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
const resultUrl = response.data.urls.get; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
// Poll for results until completion | ||||||||||||||||||||||||||||||||||||||||||||||||||||
while (true) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
const resultResponse = await fetch(resultUrl, { headers }); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
if (!resultResponse.ok) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
throw new InferenceOutputError(`Failed to get result: ${resultResponse.statusText}`); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
const result: WaveSpeedAIResponse = await resultResponse.json(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if (result.code !== 200) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
throw new InferenceOutputError(`API request failed with code ${result.code}: ${result.message}`); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
const taskResult = result.data; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
switch (taskResult.status) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
case "completed": { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// Get the video data from the first output URL | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if (!taskResult.outputs?.[0]) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
throw new InferenceOutputError("No video URL in completed response"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
const videoResponse = await fetch(taskResult.outputs[0]); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if (!videoResponse.ok) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
throw new InferenceOutputError("Failed to fetch video data"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return await videoResponse.blob(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From what I understand, the payload can be something else than a video (eg an image) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, |
||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
case "failed": { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
throw new InferenceOutputError(taskResult.error || "Task failed"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
case "processing": | ||||||||||||||||||||||||||||||||||||||||||||||||||||
case "created": | ||||||||||||||||||||||||||||||||||||||||||||||||||||
// Wait before polling again | ||||||||||||||||||||||||||||||||||||||||||||||||||||
await delay(100); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
continue; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
default: { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
throw new InferenceOutputError(`Unknown status: ${taskResult.status}`); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
export class WavespeedAITextToImageTask extends WavespeedAITask implements TextToImageTaskHelper { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
constructor() { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
super(WAVESPEEDAI_API_BASE_URL); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
export class WavespeedAITextToVideoTask extends WavespeedAITask implements TextToVideoTaskHelper { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
constructor() { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
super(WAVESPEEDAI_API_BASE_URL); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
export class WavespeedAIImageToImageTask extends WavespeedAITask implements ImageToImageTaskHelper { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
constructor() { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
super(WAVESPEEDAI_API_BASE_URL); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
hanouticelina marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if (!args.parameters) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
...args, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
model: args.model, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
data: args.inputs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
...args, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
inputs: base64FromBytes( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does the wavespeed API support base64-encoded images as inputs? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes |
||||||||||||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
override preparePayload(params: BodyParams): Record<string, unknown> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
...omit(params.args, ["inputs", "parameters"]), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
...(params.args.parameters as Record<string, unknown>), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
image: params.args.inputs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think only one of the two ( cc @hanouticelina - would love your opinion on that specific point There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I only kept preparePayloadAsync func There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
yes agree! |
||||||||||||||||||||||||||||||||||||||||||||||||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2023,4 +2023,119 @@ describe.skip("InferenceClient", () => { | |
}, | ||
TIMEOUT | ||
); | ||
describe.concurrent( | ||
"Wavespeed AI", | ||
() => { | ||
const client = new InferenceClient(env.HF_WAVESPEED_KEY ?? "dummy"); | ||
|
||
HARDCODED_MODEL_INFERENCE_MAPPING["wavespeed-ai"] = { | ||
"wavespeed-ai/flux-schnell": { | ||
hfModelId: "wavespeed-ai/flux-schnell", | ||
providerId: "wavespeed-ai/flux-schnell", | ||
status: "live", | ||
task: "text-to-image", | ||
}, | ||
"wavespeed-ai/wan-2.1/t2v-480p": { | ||
hfModelId: "wavespeed-ai/wan-2.1/t2v-480p", | ||
providerId: "wavespeed-ai/wan-2.1/t2v-480p", | ||
status: "live", | ||
task: "text-to-video", | ||
}, | ||
"wavespeed-ai/hidream-e1-full": { | ||
hfModelId: "wavespeed-ai/hidream-e1-full", | ||
providerId: "wavespeed-ai/hidream-e1-full", | ||
status: "live", | ||
task: "image-to-image", | ||
}, | ||
"wavespeed-ai/wan-2.1/i2v-480p": { | ||
hfModelId: "wavespeed-ai/wan-2.1/i2v-480p", | ||
providerId: "wavespeed-ai/wan-2.1/i2v-480p", | ||
status: "live", | ||
task: "image-to-video", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this task is not supported in the client code - let's remove it for now There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. has deleted |
||
}, | ||
"wavespeed-ai/flux-dev-lora": { | ||
hfModelId: "wavespeed-ai/flux-dev-lora", | ||
providerId: "wavespeed-ai/flux-dev-lora", | ||
status: "live", | ||
task: "text-to-image", | ||
adapter: "lora", | ||
adapterWeightsPath: | ||
"https://d32s1zkpjdc4b1.cloudfront.net/predictions/599f3739f5354afc8a76a12042736bfd/1.safetensors", | ||
}, | ||
"wavespeed-ai/flux-dev-lora-ultra-fast": { | ||
hfModelId: "wavespeed-ai/flux-dev-lora-ultra-fast", | ||
providerId: "wavespeed-ai/flux-dev-lora-ultra-fast", | ||
status: "live", | ||
task: "text-to-image", | ||
adapter: "lora", | ||
adapterWeightsPath: "linoyts/yarn_art_Flux_LoRA", | ||
}, | ||
}; | ||
|
||
it(`textToImage - wavespeed-ai/flux-schnell`, async () => { | ||
const res = await client.textToImage({ | ||
model: "wavespeed-ai/flux-schnell", | ||
provider: "wavespeed-ai", | ||
inputs: | ||
"Cute boy with a hat, exploring nature, holding a telescope, backpack, surrounded by flowers, cartoon style, vibrant colors.", | ||
}); | ||
expect(res).toBeInstanceOf(Blob); | ||
}); | ||
|
||
it(`textToImage - wavespeed-ai/flux-dev-lora`, async () => { | ||
const res = await client.textToImage({ | ||
model: "wavespeed-ai/flux-dev-lora", | ||
provider: "wavespeed-ai", | ||
inputs: | ||
"Cute boy with a hat, exploring nature, holding a telescope, backpack, surrounded by flowers, cartoon style, vibrant colors.", | ||
}); | ||
expect(res).toBeInstanceOf(Blob); | ||
}); | ||
|
||
it(`textToImage - wavespeed-ai/flux-dev-lora-ultra-fast`, async () => { | ||
const res = await client.textToImage({ | ||
model: "wavespeed-ai/flux-dev-lora-ultra-fast", | ||
provider: "wavespeed-ai", | ||
inputs: | ||
"Cute boy with a hat, exploring nature, holding a telescope, backpack, surrounded by flowers, cartoon style, vibrant colors.", | ||
}); | ||
expect(res).toBeInstanceOf(Blob); | ||
}); | ||
|
||
it(`textToVideo - wavespeed-ai/wan-2.1/t2v-480p`, async () => { | ||
const res = await client.textToVideo({ | ||
model: "wavespeed-ai/wan-2.1/t2v-480p", | ||
provider: "wavespeed-ai", | ||
inputs: | ||
"A cool street dancer, wearing a baggy hoodie and hip-hop pants, dancing in front of a graffiti wall, night neon background, quick camera cuts, urban trends.", | ||
parameters: { | ||
guidance_scale: 5, | ||
num_inference_steps: 30, | ||
seed: -1, | ||
}, | ||
duration: 5, | ||
enable_safety_checker: true, | ||
flow_shift: 2.9, | ||
size: "480*832", | ||
}); | ||
expect(res).toBeInstanceOf(Blob); | ||
}); | ||
|
||
it(`imageToImage - wavespeed-ai/hidream-e1-full`, async () => { | ||
const res = await client.imageToImage({ | ||
model: "wavespeed-ai/hidream-e1-full", | ||
provider: "wavespeed-ai", | ||
inputs: new Blob([readTestFile("cheetah.png")], { type: "image / png" }), | ||
parameters: { | ||
prompt: "The leopard chases its prey", | ||
guidance_scale: 5, | ||
num_inference_steps: 30, | ||
seed: -1, | ||
}, | ||
}); | ||
expect(res).toBeInstanceOf(Blob); | ||
}); | ||
}, | ||
60000 * 5 | ||
); | ||
}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
has deleted