Skip to content

✨ Integrate Nebius AI Studio into HuggingFace inference #1190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_FAL_KEY: dummy
HF_NEBIUS_KEY: dummy
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy
Expand Down Expand Up @@ -83,6 +84,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_FAL_KEY: dummy
HF_NEBIUS_KEY: dummy
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy
Expand Down Expand Up @@ -151,6 +153,7 @@ jobs:
NPM_CONFIG_REGISTRY: http://localhost:4874/
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_FAL_KEY: dummy
HF_NEBIUS_KEY: dummy
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy
Expand Down
6 changes: 4 additions & 2 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ You can send inference requests to third-party providers with the inference clie
Currently, we support the following providers:
- [Fal.ai](https://fal.ai)
- [Fireworks AI](https://fireworks.ai)
- [Nebius](https://studio.nebius.ai)
- [Replicate](https://replicate.com)
- [Sambanova](https://sambanova.ai)
- [Together](https://together.xyz)
Expand All @@ -71,12 +72,13 @@ When authenticated with a third-party provider key, the request is made directly
Only a subset of models are supported when requesting third-party providers. You can check the list of supported models per pipeline tasks here:
- [Fal.ai supported models](https://huggingface.co/api/partners/fal-ai/models)
- [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models)
- [Nebius supported models](https://huggingface.co/api/partners/nebius/models)
- [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
- [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
- [Together supported models](https://huggingface.co/api/partners/together/models)
- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)

❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
This is not an issue for LLMs as everyone converged on the OpenAI API anyways, but can be more tricky for other tasks like "text-to-image" or "automatic-speech-recognition" where there exists no standard API. Let us know if any help is needed or if we can make things easier for you!

👋**Want to add another provider?** Get in touch if you'd like to add support for another Inference provider, and/or request it on https://huggingface.co/spaces/huggingface/HuggingDiscussions/discussions/49
Expand Down Expand Up @@ -463,7 +465,7 @@ await hf.zeroShotImageClassification({
model: 'openai/clip-vit-large-patch14-336',
inputs: {
image: await (await fetch('https://placekitten.com/300/300')).blob()
},
},
parameters: {
candidate_labels: ['cat', 'dog']
}
Expand Down
19 changes: 18 additions & 1 deletion packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { HF_HUB_URL, HF_ROUTER_URL } from "../config";
import { FAL_AI_API_BASE_URL } from "../providers/fal-ai";
import { NEBIUS_API_BASE_URL } from "../providers/nebius";
import { REPLICATE_API_BASE_URL } from "../providers/replicate";
import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
import { TOGETHER_API_BASE_URL } from "../providers/together";
Expand Down Expand Up @@ -143,7 +144,7 @@ export async function makeRequestOptions(
? args.data
: JSON.stringify({
...otherArgs,
...(chatCompletion || provider === "together" ? { model } : undefined),
...(chatCompletion || provider === "together" || provider === "nebius" ? { model } : undefined),
}),
...(credentials ? { credentials } : undefined),
signal: options?.signal,
Expand Down Expand Up @@ -172,6 +173,22 @@ function makeUrl(params: {
: FAL_AI_API_BASE_URL;
return `${baseUrl}/${params.model}`;
}
case "nebius": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: NEBIUS_API_BASE_URL;

if (params.taskHint === "text-to-image") {
return `${baseUrl}/v1/images/generations`;
}
if (params.taskHint === "text-generation") {
if (params.chatCompletion) {
return `${baseUrl}/v1/chat/completions`;
}
return `${baseUrl}/v1/completions`;
}
return baseUrl;
}
case "replicate": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
Expand Down
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
"fal-ai": {},
"fireworks-ai": {},
"hf-inference": {},
nebius: {},
replicate: {},
sambanova: {},
together: {},
Expand Down
18 changes: 18 additions & 0 deletions packages/inference/src/providers/nebius.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
export const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";

/**
* See the registered mapping of HF model ID => Nebius model ID here:
*
* https://huggingface.co/api/partners/nebius/models
*
* This is a publicly available mapping.
*
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
*
* - If you work at Nebius and want to update this mapping, please use the model mapping API we provide on huggingface.co
* - If you're a community member and want to add a new supported HF model to Nebius, please open an issue on the present repo
* and we will tag Nebius team members.
*
* Thanks!
*/
6 changes: 5 additions & 1 deletion packages/inference/src/tasks/cv/textToImage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@ interface OutputUrlImageGeneration {
*/
export async function textToImage(args: TextToImageArgs, options?: Options): Promise<Blob> {
const payload =
args.provider === "together" || args.provider === "fal-ai" || args.provider === "replicate"
args.provider === "together" ||
args.provider === "fal-ai" ||
args.provider === "replicate" ||
args.provider === "nebius"
? {
...omit(args, ["inputs", "parameters"]),
...args.parameters,
...(args.provider !== "replicate" ? { response_format: "base64" } : undefined),
...(args.provider === "nebius" ? { response_format: "b64_json" } : undefined),
prompt: args.inputs,
}
: args;
Expand Down
7 changes: 5 additions & 2 deletions packages/inference/src/tasks/nlp/chatCompletion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@ export async function chatCompletion(
taskHint: "text-generation",
chatCompletion: true,
});

const isValidOutput =
typeof res === "object" &&
Array.isArray(res?.choices) &&
typeof res?.created === "number" &&
typeof res?.id === "string" &&
typeof res?.model === "string" &&
/// Together.ai does not output a system_fingerprint
(res.system_fingerprint === undefined || typeof res.system_fingerprint === "string") &&
/// Together.ai and Nebius do not output a system_fingerprint
(res.system_fingerprint === undefined ||
res.system_fingerprint === null ||
typeof res.system_fingerprint === "string") &&
typeof res?.usage === "object";

if (!isValidOutput) {
Expand Down
1 change: 1 addition & 0 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export type InferenceTask = Exclude<PipelineType, "other">;
export const INFERENCE_PROVIDERS = [
"fal-ai",
"fireworks-ai",
"nebius",
"hf-inference",
"replicate",
"sambanova",
Expand Down
50 changes: 50 additions & 0 deletions packages/inference/test/HfInference.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,56 @@ describe.concurrent("HfInference", () => {
TIMEOUT
);

describe.concurrent(
"Nebius",
() => {
const client = new HfInference(env.HF_NEBIUS_KEY);

HARDCODED_MODEL_ID_MAPPING.nebius = {
"meta-llama/Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
};

it("chatCompletion", async () => {
const res = await client.chatCompletion({
model: "meta-llama/Llama-3.1-8B-Instruct",
provider: "nebius",
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
});
if (res.choices && res.choices.length > 0) {
const completion = res.choices[0].message?.content;
expect(completion).toMatch(/(two|2)/i);
}
});

it("chatCompletion stream", async () => {
const stream = client.chatCompletionStream({
model: "meta-llama/Llama-3.1-70B-Instruct",
provider: "nebius",
messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }],
}) as AsyncGenerator<ChatCompletionStreamOutput>;
let out = "";
for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0) {
out += chunk.choices[0].delta.content;
}
}
expect(out).toMatch(/(two|2)/i);
});

it("textToImage", async () => {
const res = await client.textToImage({
model: "black-forest-labs/FLUX.1-schnell",
provider: "nebius",
inputs: "award winning high resolution photo of a giant tortoise",
});
expect(res).toBeInstanceOf(Blob);
});
},
TIMEOUT
);

describe.concurrent("3rd party providers", () => {
it("chatCompletion - fails with unsupported model", async () => {
expect(
Expand Down
73 changes: 73 additions & 0 deletions packages/inference/test/tapes.json
Original file line number Diff line number Diff line change
Expand Up @@ -6920,5 +6920,78 @@
"vary": "Accept-Encoding"
}
}
},
"90dc791157e9ec8ed109eaf07946d878e9208ed6eee79af8dd52a56ef7d40371": {
"url": "https://api.studio.nebius.ai/v1/chat/completions",
"init": {
"headers": {
"Content-Type": "application/json"
},
"method": "POST",
"body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Complete this sentence with words, one plus one is equal \"}],\"model\":\"meta-llama/Meta-Llama-3.1-8B-Instruct\"}"
},
"response": {
"body": "{\"id\":\"chatcmpl-89392f51529b4d1c82c3d58b210735c5\",\"choices\":[{\"finish_reason\":\"stop\",\"index\":0,\"logprobs\":null,\"message\":{\"content\":\"Two\",\"refusal\":null,\"role\":\"assistant\",\"audio\":null,\"function_call\":null,\"tool_calls\":[]},\"stop_reason\":null}],\"created\":1738968734,\"model\":\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\"object\":\"chat.completion\",\"service_tier\":null,\"system_fingerprint\":null,\"usage\":{\"completion_tokens\":2,\"prompt_tokens\":21,\"total_tokens\":23,\"completion_tokens_details\":null,\"prompt_tokens_details\":null},\"prompt_logprobs\":null}",
"status": 200,
"statusText": "OK",
"headers": {
"connection": "keep-alive",
"content-type": "application/json",
"strict-transport-security": "max-age=15724800; includeSubDomains"
}
}
},
"2b75bf387ea5775a8172608df8a1bf7d652b1c5e10f0263e39456ec56e20eedf": {
"url": "https://api.studio.nebius.ai/v1/chat/completions",
"init": {
"headers": {
"Content-Type": "application/json"
},
"method": "POST",
"body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Complete the equation 1 + 1 = , just the answer\"}],\"stream\":true,\"model\":\"meta-llama/Meta-Llama-3.1-70B-Instruct\"}"
},
"response": {
"body": "data: {\"id\":\"chatcmpl-2ef2941399e74965a2705410b42b4a3f\",\"choices\":[{\"delta\":{\"content\":\"\",\"role\":\"assistant\"},\"finish_reason\":null,\"index\":0,\"logprobs\":null}],\"created\":1738968734,\"model\":\"meta-llama/Meta-Llama-3.1-70B-Instruct\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-2ef2941399e74965a2705410b42b4a3f\",\"choices\":[{\"delta\":{\"content\":\"2\"},\"finish_reason\":null,\"index\":0,\"logprobs\":null}],\"created\":1738968734,\"model\":\"meta-llama/Meta-Llama-3.1-70B-Instruct\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-2ef2941399e74965a2705410b42b4a3f\",\"choices\":[{\"delta\":{\"content\":\".\"},\"finish_reason\":null,\"index\":0,\"logprobs\":null}],\"created\":1738968734,\"model\":\"meta-llama/Meta-Llama-3.1-70B-Instruct\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-2ef2941399e74965a2705410b42b4a3f\",\"choices\":[{\"delta\":{\"content\":\"\"},\"finish_reason\":\"stop\",\"index\":0,\"logprobs\":null,\"stop_reason\":null}],\"created\":1738968734,\"model\":\"meta-llama/Meta-Llama-3.1-70B-Instruct\",\"object\":\"chat.completion.chunk\"}\n\ndata: [DONE]\n\n",
"status": 200,
"statusText": "OK",
"headers": {
"connection": "keep-alive",
"content-type": "text/event-stream; charset=utf-8",
"strict-transport-security": "max-age=15724800; includeSubDomains",
"transfer-encoding": "chunked"
}
}
},
"b4345ef6e7eb30328b2bf84508cdfc172ecb697d8506c51cc7e447adc7323658": {
"url": "https://api.studio.nebius.ai/v1/images/generations",
"init": {
"headers": {
"Content-Type": "application/json"
},
"method": "POST",
"body": "{\"response_format\":\"b64_json\",\"prompt\":\"award winning high resolution photo of a giant tortoise\",\"model\":\"black-forest-labs/flux-schnell\"}"
},
"response": {
"body": "{\"data\":[{\"b64_json\":\"UklGRujgAQBXRUJQVlA4INzgAQAQegOdASoAAgACPhkIg0EhBgtzgAQAYSli6wABWWhZqn6p9Jvqf8h+z35S/JFxn1Heo/uH+X/4X+F/cz5b/+7/Ofkd2G9bf+n/PejR0Z/4v8l/r/26+YH+c/9H+m/2nwh/RH/d/y/79/QP+qv/G/un+g/b36Df8n9z/eR/dv+P+VPwW/rX+t/+X+y/4P//+Yj/tfub7uv8F/0/3K+An+1f77//f8TtPf309hD97/V6/8/7wf9P5YP7R/0P3K/7fyNf1L/R//b/Xf8H4AP//7fPMjwf+OH5r8lPM38Z+Yft39u/yf+W/t//2/2f2FfcX+P/ivDf6j/N/7n/J/5/2D/jv3C/Ff3b/Lf8X+9fu98l/6j/FfuT/jfUX8t/af8//i/8//x/8Z+7n2F/i38o/t39q/yv+j/uP7f/RB8j/tf8//qP/d/rfQt1r/D/6T/If4//j/v/9CPrj88/vv9w/yH+2/u/7n+3v/if5b/Qf+X3s/Uv8V/uf81+Rn2Bfyv+gf6D+7/5H/r/4n///+D7F/1H/U/1P+o/7nqWfav8z/1v87/q/2k+wH+Yf1L/Uf3z/L/+n/J////+/it/Mf8b/I/6H/xf5j///+j4s/nv9z/3H+K/zv/i/yv///9f6BfyD+d/4/+4/5H/n/4H///+P7pv+l+fPzu/Zf/ofnr9F/6r/7P86f3//9ifCC7lwK/mrcqzl7KswpK6fOxHeXWMQ66NvfdMdbfPjryRxBoK7Cl7y7JAFFu+4OPy3T0UhWdetGanczWNDZ6hP3af/YD7nH0cZyq5Rr2pdXTLcRWW0tONLzyK/1DKAvE1aKmARq2bjActrCWJBr1gAZAvPvwGN4dsjAz3pJhhhc2fNHYEWcUa+pJs9szFbKPCJXckqc+4Kf8bf84MsgozZo6FC7tto7W2DY7Nk6+pZzpRA+1qY81hRqJMTXVSduE+HlovvQL0CDxW2x2qSkGNulpY\"}],\"id\":\"text2img-b743e941-3756-4fb8-aeca-2883ab029516\"}",
"status": 200,
"statusText": "OK",
"headers": {
"connection": "keep-alive",
"content-type": "application/json",
"strict-transport-security": "max-age=15724800; includeSubDomains"
}
}
},
"76e27a0a58b167b19f3a059ab499955365e64ca8b816440ec321764b0f14fd98": {
"url": "",
"init": {},
"response": {
"body": "",
"status": 200,
"statusText": "OK",
"headers": {
"content-type": "image/jpeg"
}
}
}
}