diff --git a/packages/inference/README.md b/packages/inference/README.md
index 21e46625b..a7cb149ad 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -63,6 +63,7 @@ Currently, we support the following providers:
 - [Cohere](https://cohere.com)
 - [Cerebras](https://cerebras.ai/)
 - [Groq](https://groq.com)
+- [Dat1](https://dat1.co)
 
 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. The default value of the `provider` parameter is "auto", which will select the first of the providers available for the model, sorted by your preferred order in https://hf.co/settings/inference-providers.
 
@@ -95,6 +96,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Together supported models](https://huggingface.co/api/partners/together/models)
 - [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
 - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
+- [Dat1 supported models](https://huggingface.co/api/partners/dat1/models)
 - [Groq supported models](https://console.groq.com/docs/models)
 - [Novita AI supported models](https://huggingface.co/api/partners/novita/models)
diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index 9afcb8980..7ca455f86 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -47,8 +47,10 @@ import type {
 import * as Replicate from "../providers/replicate.js";
 import * as Sambanova from "../providers/sambanova.js";
 import * as Together from "../providers/together.js";
+import * as Dat1 from "../providers/dat1.js";
 import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from "../types.js";
+
 export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
 	"black-forest-labs": {
 		"text-to-image": new BlackForestLabs.BlackForestLabsTextToImageTask(),
 	},
@@ -59,6 +61,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
 	cohere: {
 		conversational: new Cohere.CohereConversationalTask(),
 	},
+	dat1: {
+		conversational: new Dat1.Dat1ConversationalTask(),
+		"text-to-image": new Dat1.Dat1TextToImageTask(),
+	},
 	"fal-ai": {
diff --git a/packages/inference/src/providers/dat1.ts b/packages/inference/src/providers/dat1.ts
new file mode 100644
--- /dev/null
+++ b/packages/inference/src/providers/dat1.ts
@@ -0,0 +1,83 @@
+/**
+ * See the registered mapping of HF model ID => Dat1 model ID here:
+ *
+ * https://huggingface.co/api/partners/dat1/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+ *
+ * - If you work at Dat1 and want to update this mapping, please use the model mapping API we provide on huggingface.co
+ * - If you're a community member and want to add a new supported HF model to Dat1, please open an issue on the present repo
+ *   and we will tag Dat1 team members.
+ *
+ * Thanks!
+ */
+import { InferenceOutputError } from "../lib/InferenceOutputError.js";
+import type { BodyParams } from "../types.js";
+import { omit } from "../utils/omit.js";
+import {
+	BaseConversationalTask,
+	TaskProviderHelper,
+	type TextToImageTaskHelper,
+} from "./providerHelper.js";
+
+const DAT1_API_BASE_URL = "https://api.dat1.co/api/v1/hf";
+
+interface Dat1Base64ImageGeneration {
+	data: Array<{
+		b64_json: string;
+	}>;
+}
+
+export class Dat1ConversationalTask extends BaseConversationalTask {
+	constructor() {
+		super("dat1", DAT1_API_BASE_URL);
+	}
+
+	override makeRoute(): string {
+		return "/chat/completions";
+	}
+}
+
+export class Dat1TextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
+	constructor() {
+		super("dat1", DAT1_API_BASE_URL);
+	}
+
+	override makeRoute(): string {
+		return "/images/generations";
+	}
+
+	// Map HF-style arguments (inputs/parameters) to the provider's OpenAI-style image generation payload.
+	preparePayload(params: BodyParams): Record<string, unknown> {
+		return {
+			...omit(params.args, ["inputs", "parameters"]),
+			...(params.args.parameters as Record<string, unknown>),
+			prompt: params.args.inputs,
+			response_format: "base64",
+			model: params.model,
+		};
+	}
+
+	// Decode the base64 image payload into a data URL or a Blob, depending on the requested output type.
+	async getResponse(response: Dat1Base64ImageGeneration, outputType?: "url" | "blob"): Promise<string | Blob> {
+		if (
+			typeof response === "object" &&
+			"data" in response &&
+			Array.isArray(response.data) &&
+			response.data.length > 0 &&
+			"b64_json" in response.data[0] &&
+			typeof response.data[0].b64_json === "string"
+		) {
+			const base64Data = response.data[0].b64_json;
+			if (outputType === "url") {
+				return `data:image/jpeg;base64,${base64Data}`;
+			}
+			return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
+		}
+
+		throw new InferenceOutputError("Expected Dat1 text-to-image response format");
+	}
+}
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index e2f9682c9..4268ceb9b 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -41,6 +41,7 @@ export const INFERENCE_PROVIDERS = [
 	"black-forest-labs",
 	"cerebras",
 	"cohere",
+	"dat1",
 	"fal-ai",
 	"featherless-ai",
 	"fireworks-ai",
diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index 602e034cd..7826a24eb 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -989,6 +989,65 @@ describe.skip("InferenceClient", () => {
 		TIMEOUT
 	);
 
+	describe.concurrent(
+		"dat1",
+		() => {
+			const client = new InferenceClient(env.DAT1_KEY ?? "dummy");
"dummy"); + + HARDCODED_MODEL_INFERENCE_MAPPING["dat1"] = { + "unsloth/Llama-3.2-3B-Instruct-GGUF": { + hfModelId: "unsloth/Llama-3.2-3B-Instruct-GGUF", + providerId: "unsloth-Llama-32-3B-Instruct-GGUF", + status: "live", + task: "conversational", + }, + "Kwai-Kolors/Kolors": { + hfModelId: "Kwai-Kolors/Kolors", + providerId: "Kwai-Kolors-Kolors", + status: "live", + task: "text-to-image", + }, + }; + + it("chatCompletion", async () => { + const res = await client.chatCompletion({ + model: "unsloth/Llama-3.2-3B-Instruct-GGUF", + provider: "dat1", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + }); + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toContain("two"); + } + }); + + it("chatCompletion stream", async () => { + const stream = client.chatCompletionStream({ + model: "unsloth/Llama-3.2-3B-Instruct-GGUF", + provider: "dat1", + messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }], + }) as AsyncGenerator; + let out = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + out += chunk.choices[0].delta.content; + } + } + expect(out).toContain("2"); + }); + + it("textToImage", async () => { + const res = await client.textToImage({ + model: "Kwai-Kolors/Kolors", + provider: "dat1", + inputs: "award winning high resolution photo of a giant tortoise", + }); + expect(res).toBeInstanceOf(Blob); + }); + }, + TIMEOUT + ) + /** * Compatibility with third-party Inference Providers */