Commit 18bd1f5

kitcat-dev and julien-c authored
✨ Integrate Nebius AI Studio into HuggingFace inference (#1190)
Added support for [the inference service from Nebius](https://nebius.com/services/studio-inference-service). I have added tests similar to the others in the `HfInference.spec.ts` file. They were passing just yesterday, before the hardcoded models were removed. I'll have to double-check once our models are dynamically mapped through your partners API.

---------

Co-authored-by: Julien Chaumond <[email protected]>
1 parent c1a8dfc commit 18bd1f5

File tree

10 files changed: +178 −6 lines changed

.github/workflows/test.yml

Lines changed: 3 additions & 0 deletions

```diff
@@ -42,6 +42,7 @@ jobs:
       env:
         HF_TOKEN: ${{ secrets.HF_TOKEN }}
         HF_FAL_KEY: dummy
+        HF_NEBIUS_KEY: dummy
         HF_REPLICATE_KEY: dummy
         HF_SAMBANOVA_KEY: dummy
         HF_TOGETHER_KEY: dummy
@@ -83,6 +84,7 @@ jobs:
       env:
         HF_TOKEN: ${{ secrets.HF_TOKEN }}
         HF_FAL_KEY: dummy
+        HF_NEBIUS_KEY: dummy
         HF_REPLICATE_KEY: dummy
         HF_SAMBANOVA_KEY: dummy
         HF_TOGETHER_KEY: dummy
@@ -151,6 +153,7 @@ jobs:
         NPM_CONFIG_REGISTRY: http://localhost:4874/
         HF_TOKEN: ${{ secrets.HF_TOKEN }}
         HF_FAL_KEY: dummy
+        HF_NEBIUS_KEY: dummy
         HF_REPLICATE_KEY: dummy
         HF_SAMBANOVA_KEY: dummy
         HF_TOGETHER_KEY: dummy
```

packages/inference/README.md

Lines changed: 4 additions & 2 deletions

```diff
@@ -49,6 +49,7 @@ You can send inference requests to third-party providers with the inference client.
 Currently, we support the following providers:
 - [Fal.ai](https://fal.ai)
 - [Fireworks AI](https://fireworks.ai)
+- [Nebius](https://studio.nebius.ai)
 - [Replicate](https://replicate.com)
 - [Sambanova](https://sambanova.ai)
 - [Together](https://together.xyz)
@@ -71,12 +72,13 @@ When authenticated with a third-party provider key, the request is made directly
 Only a subset of models are supported when requesting third-party providers. You can check the list of supported models per pipeline tasks here:
 - [Fal.ai supported models](https://huggingface.co/api/partners/fal-ai/models)
 - [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models)
+- [Nebius supported models](https://huggingface.co/api/partners/nebius/models)
 - [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
 - [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
 - [Together supported models](https://huggingface.co/api/partners/together/models)
 - [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)

-**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
+**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
 This is not an issue for LLMs as everyone converged on the OpenAI API anyways, but can be more tricky for other tasks like "text-to-image" or "automatic-speech-recognition" where there exists no standard API. Let us know if any help is needed or if we can make things easier for you!

 👋**Want to add another provider?** Get in touch if you'd like to add support for another Inference provider, and/or request it on https://huggingface.co/spaces/huggingface/HuggingDiscussions/discussions/49
@@ -463,7 +465,7 @@ await hf.zeroShotImageClassification({
   model: 'openai/clip-vit-large-patch14-336',
   inputs: {
     image: await (await fetch('https://placekitten.com/300/300')).blob()
-  },
+  },
   parameters: {
     candidate_labels: ['cat', 'dog']
   }
```

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 18 additions & 1 deletion

```diff
@@ -1,5 +1,6 @@
 import { HF_HUB_URL, HF_ROUTER_URL } from "../config";
 import { FAL_AI_API_BASE_URL } from "../providers/fal-ai";
+import { NEBIUS_API_BASE_URL } from "../providers/nebius";
 import { REPLICATE_API_BASE_URL } from "../providers/replicate";
 import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
 import { TOGETHER_API_BASE_URL } from "../providers/together";
@@ -143,7 +144,7 @@ export async function makeRequestOptions(
       ? args.data
       : JSON.stringify({
           ...otherArgs,
-          ...(chatCompletion || provider === "together" ? { model } : undefined),
+          ...(chatCompletion || provider === "together" || provider === "nebius" ? { model } : undefined),
         }),
     ...(credentials ? { credentials } : undefined),
     signal: options?.signal,
@@ -172,6 +173,22 @@ function makeUrl(params: {
         : FAL_AI_API_BASE_URL;
       return `${baseUrl}/${params.model}`;
     }
+    case "nebius": {
+      const baseUrl = shouldProxy
+        ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
+        : NEBIUS_API_BASE_URL;
+
+      if (params.taskHint === "text-to-image") {
+        return `${baseUrl}/v1/images/generations`;
+      }
+      if (params.taskHint === "text-generation") {
+        if (params.chatCompletion) {
+          return `${baseUrl}/v1/chat/completions`;
+        }
+        return `${baseUrl}/v1/completions`;
+      }
+      return baseUrl;
+    }
     case "replicate": {
       const baseUrl = shouldProxy
         ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
```
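The new `makeUrl` branch above can be sketched standalone. This is a simplified extract, not the library's actual function: the proxy handling (`HF_HUB_INFERENCE_PROXY_TEMPLATE`) is omitted, and `makeNebiusUrl` is a hypothetical helper name for illustration only.

```typescript
// Simplified sketch of the Nebius routing added to makeUrl above.
// Proxy handling is left out; task-hint names follow the diff.
const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";

function makeNebiusUrl(taskHint?: string, chatCompletion?: boolean): string {
  const baseUrl = NEBIUS_API_BASE_URL;
  if (taskHint === "text-to-image") {
    return `${baseUrl}/v1/images/generations`;
  }
  if (taskHint === "text-generation") {
    // chat-style requests go to the OpenAI-compatible chat endpoint
    return chatCompletion ? `${baseUrl}/v1/chat/completions` : `${baseUrl}/v1/completions`;
  }
  // any other task falls back to the bare base URL
  return baseUrl;
}
```

This mirrors the switch-case shape the other providers already use in `makeUrl`, so adding a provider stays a local change.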

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions

```diff
@@ -19,6 +19,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelId, string>> = {
 	"fal-ai": {},
 	"fireworks-ai": {},
 	"hf-inference": {},
+	nebius: {},
 	replicate: {},
 	sambanova: {},
 	together: {},
```
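For local development before the official mapping is registered on huggingface.co, the test suite fills the new `nebius` slot of `HARDCODED_MODEL_ID_MAPPING`. A minimal sketch of that pattern follows, with IDs taken from the tests in this commit; `resolveModel` is an illustrative helper, not the library's actual resolver.

```typescript
// Sketch: dev-time hardcoded HF model ID => Nebius model ID mapping,
// mirroring the nebius slot added to consts.ts above.
const HARDCODED_MODEL_ID_MAPPING: Record<string, Record<string, string>> = {
  nebius: {},
};

HARDCODED_MODEL_ID_MAPPING.nebius = {
  "meta-llama/Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
  "black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
};

// Illustrative lookup: use the provider-specific ID when mapped,
// otherwise fall back to the HF model ID unchanged.
function resolveModel(provider: string, hfModelId: string): string {
  return HARDCODED_MODEL_ID_MAPPING[provider]?.[hfModelId] ?? hfModelId;
}
```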
packages/inference/src/providers/nebius.ts

Lines changed: 18 additions & 0 deletions

```diff
@@ -0,0 +1,18 @@
+export const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
+
+/**
+ * See the registered mapping of HF model ID => Nebius model ID here:
+ *
+ * https://huggingface.co/api/partners/nebius/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+ *
+ * - If you work at Nebius and want to update this mapping, please use the model mapping API we provide on huggingface.co
+ * - If you're a community member and want to add a new supported HF model to Nebius, please open an issue on the present repo
+ *   and we will tag Nebius team members.
+ *
+ * Thanks!
+ */
```

packages/inference/src/tasks/cv/textToImage.ts

Lines changed: 5 additions & 1 deletion

```diff
@@ -21,11 +21,15 @@ interface OutputUrlImageGeneration {
  */
 export async function textToImage(args: TextToImageArgs, options?: Options): Promise<Blob> {
 	const payload =
-		args.provider === "together" || args.provider === "fal-ai" || args.provider === "replicate"
+		args.provider === "together" ||
+		args.provider === "fal-ai" ||
+		args.provider === "replicate" ||
+		args.provider === "nebius"
 			? {
 					...omit(args, ["inputs", "parameters"]),
 					...args.parameters,
 					...(args.provider !== "replicate" ? { response_format: "base64" } : undefined),
+					...(args.provider === "nebius" ? { response_format: "b64_json" } : undefined),
 					prompt: args.inputs,
 				}
 			: args;
```
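The payload shaping above relies on a spread-ordering detail: for Nebius both `response_format` spreads apply, and the later `"b64_json"` wins over the generic `"base64"`. A simplified sketch, with `omit(args, ...)` from the real code left out and `buildTextToImagePayload` a hypothetical name:

```typescript
// Sketch of the provider-specific payload shaping in textToImage above.
interface TextToImageSketchArgs {
  provider: string;
  inputs: string;
  parameters?: Record<string, unknown>;
}

function buildTextToImagePayload(args: TextToImageSketchArgs): Record<string, unknown> {
  const needsPromptShape = ["together", "fal-ai", "replicate", "nebius"].includes(args.provider);
  if (!needsPromptShape) {
    // other providers receive the args object unchanged
    return { ...args };
  }
  return {
    ...args.parameters,
    // later spreads win: Nebius overrides the generic "base64" with "b64_json"
    ...(args.provider !== "replicate" ? { response_format: "base64" } : undefined),
    ...(args.provider === "nebius" ? { response_format: "b64_json" } : undefined),
    prompt: args.inputs,
  };
}
```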

packages/inference/src/tasks/nlp/chatCompletion.ts

Lines changed: 5 additions & 2 deletions

```diff
@@ -15,14 +15,17 @@ export async function chatCompletion(
 		taskHint: "text-generation",
 		chatCompletion: true,
 	});
+
 	const isValidOutput =
 		typeof res === "object" &&
 		Array.isArray(res?.choices) &&
 		typeof res?.created === "number" &&
 		typeof res?.id === "string" &&
 		typeof res?.model === "string" &&
-		/// Together.ai does not output a system_fingerprint
-		(res.system_fingerprint === undefined || typeof res.system_fingerprint === "string") &&
+		/// Together.ai and Nebius do not output a system_fingerprint
+		(res.system_fingerprint === undefined ||
+			res.system_fingerprint === null ||
+			typeof res.system_fingerprint === "string") &&
 		typeof res?.usage === "object";

 	if (!isValidOutput) {
```
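The loosened check matters because Nebius returns `"system_fingerprint": null` (visible in the recorded tape below), which the old `undefined`-or-string check would have rejected. A standalone sketch of the predicate, extracted for illustration:

```typescript
// Sketch of the chat-completion output validation after this change:
// system_fingerprint may be a string, undefined, or (Together/Nebius) null.
function isValidChatCompletionOutput(res: any): boolean {
  return (
    typeof res === "object" &&
    res !== null &&
    Array.isArray(res.choices) &&
    typeof res.created === "number" &&
    typeof res.id === "string" &&
    typeof res.model === "string" &&
    (res.system_fingerprint === undefined ||
      res.system_fingerprint === null ||
      typeof res.system_fingerprint === "string") &&
    typeof res.usage === "object"
  );
}
```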

packages/inference/src/types.ts

Lines changed: 1 addition & 0 deletions

```diff
@@ -47,6 +47,7 @@ export type InferenceTask = Exclude<PipelineType, "other">;
 export const INFERENCE_PROVIDERS = [
 	"fal-ai",
 	"fireworks-ai",
+	"nebius",
 	"hf-inference",
 	"replicate",
 	"sambanova",
```

packages/inference/test/HfInference.spec.ts

Lines changed: 50 additions & 0 deletions

```diff
@@ -1064,6 +1064,56 @@ describe.concurrent("HfInference", () => {
 		TIMEOUT
 	);

+	describe.concurrent(
+		"Nebius",
+		() => {
+			const client = new HfInference(env.HF_NEBIUS_KEY);
+
+			HARDCODED_MODEL_ID_MAPPING.nebius = {
+				"meta-llama/Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+				"meta-llama/Llama-3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
+				"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
+			};
+
+			it("chatCompletion", async () => {
+				const res = await client.chatCompletion({
+					model: "meta-llama/Llama-3.1-8B-Instruct",
+					provider: "nebius",
+					messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+				});
+				if (res.choices && res.choices.length > 0) {
+					const completion = res.choices[0].message?.content;
+					expect(completion).toMatch(/(two|2)/i);
+				}
+			});
+
+			it("chatCompletion stream", async () => {
+				const stream = client.chatCompletionStream({
+					model: "meta-llama/Llama-3.1-70B-Instruct",
+					provider: "nebius",
+					messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }],
+				}) as AsyncGenerator<ChatCompletionStreamOutput>;
+				let out = "";
+				for await (const chunk of stream) {
+					if (chunk.choices && chunk.choices.length > 0) {
+						out += chunk.choices[0].delta.content;
+					}
+				}
+				expect(out).toMatch(/(two|2)/i);
+			});
+
+			it("textToImage", async () => {
+				const res = await client.textToImage({
+					model: "black-forest-labs/FLUX.1-schnell",
+					provider: "nebius",
+					inputs: "award winning high resolution photo of a giant tortoise",
+				});
+				expect(res).toBeInstanceOf(Blob);
+			});
+		},
+		TIMEOUT
+	);
+
 	describe.concurrent("3rd party providers", () => {
 		it("chatCompletion - fails with unsupported model", async () => {
 			expect(
```

packages/inference/test/tapes.json

Lines changed: 73 additions & 0 deletions

```diff
@@ -6920,5 +6920,78 @@
         "vary": "Accept-Encoding"
       }
     }
+  },
+  "90dc791157e9ec8ed109eaf07946d878e9208ed6eee79af8dd52a56ef7d40371": {
+    "url": "https://api.studio.nebius.ai/v1/chat/completions",
+    "init": {
+      "headers": {
+        "Content-Type": "application/json"
+      },
+      "method": "POST",
+      "body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Complete this sentence with words, one plus one is equal \"}],\"model\":\"meta-llama/Meta-Llama-3.1-8B-Instruct\"}"
+    },
+    "response": {
+      "body": "{\"id\":\"chatcmpl-89392f51529b4d1c82c3d58b210735c5\",\"choices\":[{\"finish_reason\":\"stop\",\"index\":0,\"logprobs\":null,\"message\":{\"content\":\"Two\",\"refusal\":null,\"role\":\"assistant\",\"audio\":null,\"function_call\":null,\"tool_calls\":[]},\"stop_reason\":null}],\"created\":1738968734,\"model\":\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\"object\":\"chat.completion\",\"service_tier\":null,\"system_fingerprint\":null,\"usage\":{\"completion_tokens\":2,\"prompt_tokens\":21,\"total_tokens\":23,\"completion_tokens_details\":null,\"prompt_tokens_details\":null},\"prompt_logprobs\":null}",
+      "status": 200,
+      "statusText": "OK",
+      "headers": {
+        "connection": "keep-alive",
+        "content-type": "application/json",
+        "strict-transport-security": "max-age=15724800; includeSubDomains"
+      }
+    }
+  },
+  "2b75bf387ea5775a8172608df8a1bf7d652b1c5e10f0263e39456ec56e20eedf": {
+    "url": "https://api.studio.nebius.ai/v1/chat/completions",
+    "init": {
+      "headers": {
+        "Content-Type": "application/json"
+      },
+      "method": "POST",
+      "body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Complete the equation 1 + 1 = , just the answer\"}],\"stream\":true,\"model\":\"meta-llama/Meta-Llama-3.1-70B-Instruct\"}"
+    },
+    "response": {
+      "body": "data: {\"id\":\"chatcmpl-2ef2941399e74965a2705410b42b4a3f\",\"choices\":[{\"delta\":{\"content\":\"\",\"role\":\"assistant\"},\"finish_reason\":null,\"index\":0,\"logprobs\":null}],\"created\":1738968734,\"model\":\"meta-llama/Meta-Llama-3.1-70B-Instruct\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-2ef2941399e74965a2705410b42b4a3f\",\"choices\":[{\"delta\":{\"content\":\"2\"},\"finish_reason\":null,\"index\":0,\"logprobs\":null}],\"created\":1738968734,\"model\":\"meta-llama/Meta-Llama-3.1-70B-Instruct\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-2ef2941399e74965a2705410b42b4a3f\",\"choices\":[{\"delta\":{\"content\":\".\"},\"finish_reason\":null,\"index\":0,\"logprobs\":null}],\"created\":1738968734,\"model\":\"meta-llama/Meta-Llama-3.1-70B-Instruct\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-2ef2941399e74965a2705410b42b4a3f\",\"choices\":[{\"delta\":{\"content\":\"\"},\"finish_reason\":\"stop\",\"index\":0,\"logprobs\":null,\"stop_reason\":null}],\"created\":1738968734,\"model\":\"meta-llama/Meta-Llama-3.1-70B-Instruct\",\"object\":\"chat.completion.chunk\"}\n\ndata: [DONE]\n\n",
+      "status": 200,
+      "statusText": "OK",
+      "headers": {
+        "connection": "keep-alive",
+        "content-type": "text/event-stream; charset=utf-8",
+        "strict-transport-security": "max-age=15724800; includeSubDomains",
+        "transfer-encoding": "chunked"
+      }
+    }
+  },
+  "b4345ef6e7eb30328b2bf84508cdfc172ecb697d8506c51cc7e447adc7323658": {
+    "url": "https://api.studio.nebius.ai/v1/images/generations",
+    "init": {
+      "headers": {
+        "Content-Type": "application/json"
+      },
+      "method": "POST",
+      "body": "{\"response_format\":\"b64_json\",\"prompt\":\"award winning high resolution photo of a giant tortoise\",\"model\":\"black-forest-labs/flux-schnell\"}"
+    },
+    "response": {
+      "body": "{\"data\":[{\"b64_json\":\"UklGRujgAQBXRUJQVlA4INzgAQAQegOdASoAAgACPhkIg0EhBgtzgAQAYSli6wABWWhZqn6p9Jvqf8h+z35S/JFxn1Heo/uH+X/4X+F/cz5b/+7/Ofkd2G9bf+n/PejR0Z/4v8l/r/26+YH+c/9H+m/2nwh/RH/d/y/79/QP+qv/G/un+g/b36Df8n9z/eR/dv+P+VPwW/rX+t/+X+y/4P//+Yj/tfub7uv8F/0/3K+An+1f77//f8TtPf309hD97/V6/8/7wf9P5YP7R/0P3K/7fyNf1L/R//b/Xf8H4AP//7fPMjwf+OH5r8lPM38Z+Yft39u/yf+W/t//2/2f2FfcX+P/ivDf6j/N/7n/J/5/2D/jv3C/Ff3b/Lf8X+9fu98l/6j/FfuT/jfUX8t/af8//i/8//x/8Z+7n2F/i38o/t39q/yv+j/uP7f/RB8j/tf8//qP/d/rfQt1r/D/6T/If4//j/v/9CPrj88/vv9w/yH+2/u/7n+3v/if5b/Qf+X3s/Uv8V/uf81+Rn2Bfyv+gf6D+7/5H/r/4n///+D7F/1H/U/1P+o/7nqWfav8z/1v87/q/2k+wH+Yf1L/Uf3z/L/+n/J////+/it/Mf8b/I/6H/xf5j///+j4s/nv9z/3H+K/zv/i/yv///9f6BfyD+d/4/+4/5H/n/4H///+P7pv+l+fPzu/Zf/ofnr9F/6r/7P86f3//9ifCC7lwK/mrcqzl7KswpK6fOxHeXWMQ66NvfdMdbfPjryRxBoK7Cl7y7JAFFu+4OPy3T0UhWdetGanczWNDZ6hP3af/YD7nH0cZyq5Rr2pdXTLcRWW0tONLzyK/1DKAvE1aKmARq2bjActrCWJBr1gAZAvPvwGN4dsjAz3pJhhhc2fNHYEWcUa+pJs9szFbKPCJXckqc+4Kf8bf84MsgozZo6FC7tto7W2DY7Nk6+pZzpRA+1qY81hRqJMTXVSduE+HlovvQL0CDxW2x2qSkGNulpY\"}],\"id\":\"text2img-b743e941-3756-4fb8-aeca-2883ab029516\"}",
+      "status": 200,
+      "statusText": "OK",
+      "headers": {
+        "connection": "keep-alive",
+        "content-type": "application/json",
+        "strict-transport-security": "max-age=15724800; includeSubDomains"
+      }
+    }
+  },
+  "76e27a0a58b167b19f3a059ab499955365e64ca8b816440ec321764b0f14fd98": {
+    "url": "data:image/jpeg;base64,UklGRujgAQBXRUJQVlA4INzgAQAQegOdASoAAgACPhkIg0EhBgtzgAQAYSli6wABWWhZqn6p9Jvqf8h+z35S/JFxn1Heo/uH+X/4X+F/cz5b/+7/Ofkd2G9bf+n/PejR0Z/4v8l/r/26+YH+c/9H+m/2nwh/RH/d/y/79/QP+qv/G/un+g/b36Df8n9z/eR/dv+P+VPwW/rX+t/+X+y/4P//+Yj/tfub7uv8F/0/3K+An+1f77//f8TtPf309hD97/V6/8/7wf9P5YP7R/0P3K/7fyNf1L/R//b/Xf8H4AP//7fPMjwf+OH5r8lPM38Z+Yft39u/yf+W/t//2/2f2FfcX+P/ivDf6j/N/7n/J/5/2D/jv3C/Ff3b/Lf8X+9fu98l/6j/FfuT/jfUX8t/af8//i/8//x/8Z+7n2F/i38o/t39q/yv+j/uP7f/RB8j/tf8//qP/d/rfQt1r/D/6T/If4//j/v/9CPrj88/vv9w/yH+2/u/7n+3v/if5b/Qf+X3s/Uv8V/uf81+Rn2Bfyv+gf6D+7/5H/r/4n///+D7F/1H/U/1P+o/7nqWfav8z/1v87/q/2k+wH+Yf1L/Uf3z/L/+n/J////+/it/Mf8b/I/6H/xf5j///+j4s/nv9z/3H+K/zv/i/yv///9f6BfyD+d/4/+4/5H/n/4H///+P7pv+l+fPzu/Zf/ofnr9F/6r/7P86f3//9ifCC7lwK/mrcqzl7KswpK6fOxHeXWMQ66NvfdMdbfPjryRxBoK7Cl7y7JAFFu+4OPy3T0UhWdetGanczWNDZ6hP3af/YD7nH0cZyq5Rr2pdXTLcRWW0tONLzyK/1DKAvE1aKmARq2bjActrCWJBr1gAZAvPvwGN4dsjAz3pJhhhc2fNHYEWcUa+pJs9szFbKPCJXckqc+4Kf8bf84MsgozZo6FC7tto7W2DY7Nk6+pZzpRA+1qY81hRqJMTXVSduE+HlovvQL0CDxW2x2qSkGNulpY",
+    "init": {},
+    "response": {
+      "body": "",
+      "status": 200,
+      "statusText": "OK",
+      "headers": {
+        "content-type": "image/jpeg"
+      }
+    }
   }
 }
```
