Commit 86b1f2e

julien-c and SBrandeis authored

[Inference] compatibility with third-party Inference providers (#1077)

# TL;DR

Allow users to request 3rd party inference providers (Sambanova, Replicate, Together, Fal) with `@huggingface/inference`, for a curated set of models on the HF Hub.

For now, requesting a 3rd party inference provider requires users to pass an API key from this provider as a parameter to the inference function.

---------

Co-authored-by: SBrandeis <[email protected]>
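As a usage sketch (hedged: based on the `provider` request parameter and provider-key handling visible in the diff below; the model choice and key variable are illustrative), requesting a curated model through fal.ai looks roughly like this:

```ts
import { HfInference } from "@huggingface/inference";

// For now the *provider's* API key is passed as the access token;
// an hf_ token throws "Inference proxying is not implemented yet".
const client = new HfInference(process.env.FAL_AI_API_KEY);

// `provider` is the new per-call parameter; the HF model ID is mapped
// to the provider's own model ID (here: fal-ai/flux/schnell).
const image = await client.textToImage({
  model: "black-forest-labs/FLUX.1-schnell",
  provider: "fal-ai",
  inputs: "An astronaut riding a horse",
});
```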
1 parent 29715d4 commit 86b1f2e

20 files changed: +1644 −765 lines

.github/workflows/test.yml

Lines changed: 12 additions & 0 deletions
@@ -41,6 +41,10 @@ jobs:
         run: VCR_MODE=playback pnpm --filter ...[${{ steps.since.outputs.SINCE }}] test
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
+          HF_FAL_KEY: dummy
+          HF_REPLICATE_KEY: dummy
+          HF_SAMBANOVA_KEY: dummy
+          HF_TOGETHER_KEY: dummy
 
   browser:
     runs-on: ubuntu-latest
@@ -77,6 +81,10 @@ jobs:
         run: VCR_MODE=playback pnpm --filter ...[${{ steps.since.outputs.SINCE }}] test:browser
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
+          HF_FAL_KEY: dummy
+          HF_REPLICATE_KEY: dummy
+          HF_SAMBANOVA_KEY: dummy
+          HF_TOGETHER_KEY: dummy
 
   e2e:
     runs-on: ubuntu-latest
@@ -140,3 +148,7 @@ jobs:
         env:
           NPM_CONFIG_REGISTRY: http://localhost:4874/
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
+          HF_FAL_KEY: dummy
+          HF_REPLICATE_KEY: dummy
+          HF_SAMBANOVA_KEY: dummy
+          HF_TOGETHER_KEY: dummy
README.md

Lines changed: 8 additions & 6 deletions
@@ -13,7 +13,7 @@
 // Programatically interact with the Hub
 
 await createRepo({
-  repo: {type: "model", name: "my-user/nlp-model"},
+  repo: { type: "model", name: "my-user/nlp-model" },
   accessToken: HF_TOKEN
 });
 
@@ -53,11 +53,13 @@ await inference.textToImage({
 
 This is a collection of JS libraries to interact with the Hugging Face API, with TS types included.
 
-- [@huggingface/inference](packages/inference/README.md): Use Inference Endpoints (dedicated) and Inference API (serverless) to make calls to 100,000+ Machine Learning models
+- [@huggingface/inference](packages/inference/README.md): Use Inference API (serverless) and Inference Endpoints (dedicated) to make calls to 100,000+ Machine Learning models
 - [@huggingface/hub](packages/hub/README.md): Interact with huggingface.co to create or delete repos and commit / download files
 - [@huggingface/agents](packages/agents/README.md): Interact with HF models through a natural language interface
 - [@huggingface/gguf](packages/gguf/README.md): A GGUF parser that works on remotely hosted files.
+- [@huggingface/dduf](packages/dduf/README.md): Similar package for DDUF (DDUF Diffusers Unified Format)
 - [@huggingface/tasks](packages/tasks/README.md): The definition files and source-of-truth for the Hub's main primitives like pipeline tasks, model libraries, etc.
+- [@huggingface/jinja](packages/jinja/README.md): A minimalistic JS implementation of the Jinja templating engine, to be used for ML chat templates.
 - [@huggingface/space-header](packages/space-header/README.md): Use the Space `mini_header` outside Hugging Face
 
 
@@ -165,7 +167,7 @@ await inference.imageToText({
 const gpt2 = inference.endpoint('https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2');
 const { generated_text } = await gpt2.textGeneration({inputs: 'The answer to the universe is'});
 
-//Chat Completion
+// Chat Completion
 const llamaEndpoint = inference.endpoint(
   "https://api-inference.huggingface.co/models/meta-llama/Llama-3.1-8B-Instruct"
 );
@@ -185,7 +187,7 @@ import { createRepo, uploadFile, deleteFiles } from "@huggingface/hub";
 const HF_TOKEN = "hf_...";
 
 await createRepo({
-  repo: "my-user/nlp-model", // or {type: "model", name: "my-user/nlp-test"},
+  repo: "my-user/nlp-model", // or { type: "model", name: "my-user/nlp-test" },
   accessToken: HF_TOKEN
 });
 
@@ -200,7 +202,7 @@ await uploadFile({
 });
 
 await deleteFiles({
-  repo: {type: "space", name: "my-user/my-space"}, // or "spaces/my-user/my-space"
+  repo: { type: "space", name: "my-user/my-space" }, // or "spaces/my-user/my-space"
   accessToken: HF_TOKEN,
   paths: ["README.md", ".gitattributes"]
 });
@@ -209,7 +211,7 @@ await deleteFiles({
 ### @huggingface/agents example
 
 ```ts
-import {HfAgent, LLMFromHub, defaultTools} from '@huggingface/agents';
+import { HfAgent, LLMFromHub, defaultTools } from '@huggingface/agents';
 
 const HF_TOKEN = "hf_...";
 
packages/inference/LICENSE

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2022 Tim Mikeladze
+Copyright (c) 2022 Tim Mikeladze and the Hugging Face team
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal

packages/inference/README.md

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 # 🤗 Hugging Face Inference Endpoints
 
-A Typescript powered wrapper for the Hugging Face Inference Endpoints API. Learn more about Inference Endpoints at [Hugging Face](https://huggingface.co/inference-endpoints).
-It works with both [Inference API (serverless)](https://huggingface.co/docs/api-inference/index) and [Inference Endpoints (dedicated)](https://huggingface.co/docs/inference-endpoints/index).
+A Typescript powered wrapper for the Hugging Face Inference API (serverless), Inference Endpoints (dedicated), and third-party Inference Providers.
+It works with [Inference API (serverless)](https://huggingface.co/docs/api-inference/index) and [Inference Endpoints (dedicated)](https://huggingface.co/docs/inference-endpoints/index), and even with supported third-party Inference Providers.
 
 Check out the [full documentation](https://huggingface.co/docs/huggingface.js/inference/README).
 

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 77 additions & 17 deletions
@@ -1,12 +1,17 @@
-import type { InferenceTask, Options, RequestArgs } from "../types";
+import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai";
+import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate";
+import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
+import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
+import { INFERENCE_PROVIDERS, type InferenceTask, type Options, type RequestArgs } from "../types";
 import { omit } from "../utils/omit";
 import { HF_HUB_URL } from "./getDefaultTask";
 import { isUrl } from "./isUrl";
 
 const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
 
 /**
- * Loaded from huggingface.co/api/tasks if needed
+ * Lazy-loaded from huggingface.co/api/tasks when needed
+ * Used to determine the default model to use when it's not user defined
  */
 let tasks: Record<string, { models: { id: string }[] }> | null = null;
 
@@ -26,21 +31,14 @@ export async function makeRequestOptions(
 		chatCompletion?: boolean;
 	}
 ): Promise<{ url: string; info: RequestInit }> {
-	const { accessToken, endpointUrl, ...otherArgs } = args;
+	const { accessToken, endpointUrl, provider, ...otherArgs } = args;
 	let { model } = args;
-	const {
-		forceTask: task,
-		includeCredentials,
-		taskHint,
-		wait_for_model,
-		use_cache,
-		dont_load_model,
-		chatCompletion,
-	} = options ?? {};
+	const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion } =
+		options ?? {};
 
 	const headers: Record<string, string> = {};
 	if (accessToken) {
-		headers["Authorization"] = `Bearer ${accessToken}`;
+		headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
 	}
 
 	if (!model && !tasks && taskHint) {
@@ -61,6 +59,35 @@ export async function makeRequestOptions(
 	if (!model) {
 		throw new Error("No model provided, and no default model found for this task");
 	}
+	if (provider) {
+		if (!INFERENCE_PROVIDERS.includes(provider)) {
+			throw new Error("Unknown Inference provider");
+		}
+		if (!accessToken) {
+			throw new Error("Specifying an Inference provider requires an accessToken");
+		}
+
+		const modelId = (() => {
+			switch (provider) {
+				case "replicate":
+					return REPLICATE_MODEL_IDS[model];
+				case "sambanova":
+					return SAMBANOVA_MODEL_IDS[model];
+				case "together":
+					return TOGETHER_MODEL_IDS[model]?.id;
+				case "fal-ai":
+					return FAL_AI_MODEL_IDS[model];
+				default:
+					return model;
+			}
+		})();
+
+		if (!modelId) {
+			throw new Error(`Model ${model} is not supported for provider ${provider}`);
+		}
+
+		model = modelId;
+	}
 
 	const binary = "data" in args && !!args.data;
 
@@ -77,6 +104,9 @@ export async function makeRequestOptions(
 	if (dont_load_model) {
 		headers["X-Load-Model"] = "0";
 	}
+	if (provider === "replicate") {
+		headers["Prefer"] = "wait";
+	}
 
 	let url = (() => {
 		if (endpointUrl && isUrl(model)) {
@@ -89,8 +119,33 @@ export async function makeRequestOptions(
 		if (endpointUrl) {
 			return endpointUrl;
 		}
-		if (task) {
-			return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
+		if (forceTask) {
+			return `${HF_INFERENCE_API_BASE_URL}/pipeline/${forceTask}/${model}`;
+		}
+		if (provider) {
+			if (!accessToken) {
+				throw new Error("Specifying an Inference provider requires an accessToken");
+			}
+			if (accessToken.startsWith("hf_")) {
+				/// TODO we will proxy the request server-side (using our own keys) and handle billing for it on the user's HF account.
+				throw new Error("Inference proxying is not implemented yet");
+			} else {
+				switch (provider) {
+					case "fal-ai":
+						return `${FAL_AI_API_BASE_URL}/${model}`;
+					case "replicate":
+						return `${REPLICATE_API_BASE_URL}/v1/models/${model}/predictions`;
+					case "sambanova":
+						return SAMBANOVA_API_BASE_URL;
+					case "together":
+						if (taskHint === "text-to-image") {
+							return `${TOGETHER_API_BASE_URL}/v1/images/generations`;
+						}
+						return TOGETHER_API_BASE_URL;
+					default:
+						break;
+				}
+			}
 		}
 
 		return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
@@ -99,6 +154,9 @@ export async function makeRequestOptions(
 	if (chatCompletion && !url.endsWith("/chat/completions")) {
 		url += "/v1/chat/completions";
 	}
+	if (provider === "together" && taskHint === "text-generation" && !chatCompletion) {
+		url += "/v1/completions";
+	}
 
 	/**
 	 * For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
@@ -116,9 +174,11 @@ export async function makeRequestOptions(
 		body: binary
 			? args.data
 			: JSON.stringify({
-					...(otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs),
+					...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate" || provider === "fal-ai"
+						? omit(otherArgs, "model")
+						: { ...otherArgs, model }),
 				}),
-		...(credentials && { credentials }),
+		...(credentials ? { credentials } : undefined),
 		signal: options?.signal,
 	};

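To make the routing concrete, here is an illustrative sketch (not code from this commit) of how a fal-ai request is resolved by the logic above: the HF model ID goes through `FAL_AI_MODEL_IDS`, the URL is built on `FAL_AI_API_BASE_URL`, and the Authorization header uses fal's `Key` scheme instead of `Bearer` (the key variable is illustrative):

```ts
// Illustrative only — mirrors the switch statements in makeRequestOptions.
const FAL_AI_API_BASE_URL = "https://fal.run";
const FAL_AI_MODEL_IDS: Record<string, string> = {
  "openai/whisper-large-v3": "fal-ai/whisper",
};

const model = "openai/whisper-large-v3";
const appId = FAL_AI_MODEL_IDS[model]; // "fal-ai/whisper"
if (!appId) {
  throw new Error(`Model ${model} is not supported for provider fal-ai`);
}
const url = `${FAL_AI_API_BASE_URL}/${appId}`; // "https://fal.run/fal-ai/whisper"
const headers = { Authorization: `Key ${process.env.FAL_AI_API_KEY ?? ""}` };
```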
packages/inference/src/providers/fal-ai.ts

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+import type { ModelId } from "../types";
+
+export const FAL_AI_API_BASE_URL = "https://fal.run";
+
+type FalAiId = string;
+
+/**
+ * Mapping from HF model ID -> fal.ai app id
+ */
+export const FAL_AI_MODEL_IDS: Record<ModelId, FalAiId> = {
+	/** text-to-image */
+	"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
+	"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
+
+	/** automatic-speech-recognition */
+	"openai/whisper-large-v3": "fal-ai/whisper",
+};
packages/inference/src/providers/replicate.ts

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+import type { ModelId } from "../types";
+
+export const REPLICATE_API_BASE_URL = "https://api.replicate.com";
+
+type ReplicateId = string;
+
+/**
+ * Mapping from HF model ID -> Replicate model ID
+ *
+ * Available models can be fetched with:
+ * ```
+ * curl -s \
+ *   -H "Authorization: Bearer $REPLICATE_API_TOKEN" \
+ *   'https://api.replicate.com/v1/models'
+ * ```
+ */
+export const REPLICATE_MODEL_IDS: Record<ModelId, ReplicateId> = {
+	/** text-to-image */
+	"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
+	"ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step",
+};
packages/inference/src/providers/sambanova.ts

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+import type { ModelId } from "../types";
+
+export const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
+
+/**
+ * Note for reviewers: our goal would be to ask Sambanova to support
+ * our model ids too, so we don't have to define a mapping
+ * or keep it up-to-date.
+ *
+ * As a fallback, if the above is not possible, ask Sambanova to
+ * provide the mapping as a fetchable API.
+ */
+type SambanovaId = string;
+
+/**
+ * https://community.sambanova.ai/t/supported-models/193
+ */
+export const SAMBANOVA_MODEL_IDS: Record<ModelId, SambanovaId> = {
+	/** Chat completion / conversational */
+	"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
+	"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
+	"Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
+	"meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
+	"meta-llama/Llama-3.2-1B": "Meta-Llama-3.2-1B-Instruct",
+	"meta-llama/Llama-3.2-3B": "Meta-Llama-3.2-3B-Instruct",
+	"meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
+	"meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
+	"meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
+	"meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
+	"meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
+	"meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B",
+};
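Since the Sambanova mapping only lists conversational models, the typical call is a chat completion; per the URL logic in makeRequestOptions above, it resolves to `https://api.sambanova.ai/v1/chat/completions` with the mapped model ID. A hedged usage sketch (the key variable and prompt are illustrative):

```ts
import { HfInference } from "@huggingface/inference";

const client = new HfInference(process.env.SAMBANOVA_API_KEY);

// "meta-llama/Llama-3.1-8B-Instruct" is mapped to "Meta-Llama-3.1-8B-Instruct"
// and the request goes to https://api.sambanova.ai/v1/chat/completions.
const res = await client.chatCompletion({
  model: "meta-llama/Llama-3.1-8B-Instruct",
  provider: "sambanova",
  messages: [{ role: "user", content: "What is the capital of France?" }],
});
console.log(res.choices[0].message.content);
```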
packages/inference/src/providers/together.ts

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+import type { ModelId } from "../types";
+
+export const TOGETHER_API_BASE_URL = "https://api.together.xyz";
+
+/**
+ * Same comment as in sambanova.ts
+ */
+type TogetherId = string;
+
+/**
+ * https://docs.together.ai/reference/models-1
+ */
+export const TOGETHER_MODEL_IDS: Record<
+	ModelId,
+	{ id: TogetherId; type: "chat" | "embedding" | "image" | "language" | "moderation" }
+> = {
+	/** text-to-image */
+	"black-forest-labs/FLUX.1-Canny-dev": { id: "black-forest-labs/FLUX.1-canny", type: "image" },
+	"black-forest-labs/FLUX.1-Depth-dev": { id: "black-forest-labs/FLUX.1-depth", type: "image" },
+	"black-forest-labs/FLUX.1-dev": { id: "black-forest-labs/FLUX.1-dev", type: "image" },
+	"black-forest-labs/FLUX.1-Redux-dev": { id: "black-forest-labs/FLUX.1-redux", type: "image" },
+	"black-forest-labs/FLUX.1-schnell": { id: "black-forest-labs/FLUX.1-pro", type: "image" },
+	"stabilityai/stable-diffusion-xl-base-1.0": { id: "stabilityai/stable-diffusion-xl-base-1.0", type: "image" },
+
+	/** chat completion */
+	"databricks/dbrx-instruct": { id: "databricks/dbrx-instruct", type: "chat" },
+	"deepseek-ai/deepseek-llm-67b-chat": { id: "deepseek-ai/deepseek-llm-67b-chat", type: "chat" },
+	"google/gemma-2-9b-it": { id: "google/gemma-2-9b-it", type: "chat" },
+	"google/gemma-2b-it": { id: "google/gemma-2-27b-it", type: "chat" },
+	"llava-hf/llava-v1.6-mistral-7b-hf": { id: "llava-hf/llava-v1.6-mistral-7b-hf", type: "chat" },
+	"meta-llama/Llama-2-13b-chat-hf": { id: "meta-llama/Llama-2-13b-chat-hf", type: "chat" },
+	"meta-llama/Llama-2-70b-hf": { id: "meta-llama/Llama-2-70b-hf", type: "language" },
+	"meta-llama/Llama-2-7b-chat-hf": { id: "meta-llama/Llama-2-7b-chat-hf", type: "chat" },
+	"meta-llama/Llama-3.2-11B-Vision-Instruct": { id: "meta-llama/Llama-Vision-Free", type: "chat" },
+	"meta-llama/Llama-3.2-3B-Instruct": { id: "meta-llama/Llama-3.2-3B-Instruct-Turbo", type: "chat" },
+	"meta-llama/Llama-3.2-90B-Vision-Instruct": { id: "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", type: "chat" },
+	"meta-llama/Llama-3.3-70B-Instruct": { id: "meta-llama/Llama-3.3-70B-Instruct-Turbo", type: "chat" },
+	"meta-llama/Meta-Llama-3-70B-Instruct": { id: "meta-llama/Llama-3-70b-chat-hf", type: "chat" },
+	"meta-llama/Meta-Llama-3-8B-Instruct": { id: "togethercomputer/Llama-3-8b-chat-hf-int4", type: "chat" },
+	"meta-llama/Meta-Llama-3.1-405B-Instruct": { id: "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", type: "chat" },
+	"meta-llama/Meta-Llama-3.1-70B-Instruct": { id: "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", type: "chat" },
+	"meta-llama/Meta-Llama-3.1-8B-Instruct": { id: "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K", type: "chat" },
+	"microsoft/WizardLM-2-8x22B": { id: "microsoft/WizardLM-2-8x22B", type: "chat" },
+	"mistralai/Mistral-7B-Instruct-v0.3": { id: "mistralai/Mistral-7B-Instruct-v0.3", type: "chat" },
+	"mistralai/Mixtral-8x22B-Instruct-v0.1": { id: "mistralai/Mixtral-8x22B-Instruct-v0.1", type: "chat" },
+	"mistralai/Mixtral-8x7B-Instruct-v0.1": { id: "mistralai/Mixtral-8x7B-Instruct-v0.1", type: "chat" },
+	"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": { id: "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", type: "chat" },
+	"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": { id: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", type: "chat" },
+	"Qwen/Qwen2-72B-Instruct": { id: "Qwen/Qwen2-72B-Instruct", type: "chat" },
+	"Qwen/Qwen2.5-72B-Instruct": { id: "Qwen/Qwen2.5-72B-Instruct-Turbo", type: "chat" },
+	"Qwen/Qwen2.5-7B-Instruct": { id: "Qwen/Qwen2.5-7B-Instruct-Turbo", type: "chat" },
+	"Qwen/Qwen2.5-Coder-32B-Instruct": { id: "Qwen/Qwen2.5-Coder-32B-Instruct", type: "chat" },
+	"Qwen/QwQ-32B-Preview": { id: "Qwen/QwQ-32B-Preview", type: "chat" },
+	"scb10x/llama-3-typhoon-v1.5-8b-instruct": { id: "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct", type: "chat" },
+	"scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": { id: "scb10x/scb10x-llama3-typhoon-v1-5x-4f316", type: "chat" },
+
+	/** text-generation */
+	"meta-llama/Meta-Llama-3-8B": { id: "meta-llama/Meta-Llama-3-8B", type: "language" },
+	"mistralai/Mixtral-8x7B-v0.1": { id: "mistralai/Mixtral-8x7B-v0.1", type: "language" },
+};
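While the `type` field records what kind of model each entry is, the endpoint Together receives is chosen from the task, per the URL logic in makeRequestOptions above: text-to-image goes to `/v1/images/generations`, chat completion gets `/v1/chat/completions` appended, and plain text-generation gets `/v1/completions`. An illustrative sketch of that dispatch (not code from this commit):

```ts
// Mirrors the combined URL construction for provider === "together".
const TOGETHER_API_BASE_URL = "https://api.together.xyz";

function togetherUrl(taskHint: string, chatCompletion: boolean): string {
  if (taskHint === "text-to-image") {
    return `${TOGETHER_API_BASE_URL}/v1/images/generations`;
  }
  if (chatCompletion) {
    return `${TOGETHER_API_BASE_URL}/v1/chat/completions`;
  }
  if (taskHint === "text-generation") {
    return `${TOGETHER_API_BASE_URL}/v1/completions`;
  }
  return TOGETHER_API_BASE_URL;
}

console.log(togetherUrl("text-to-image", false)); // …/v1/images/generations
console.log(togetherUrl("text-generation", true)); // …/v1/chat/completions
console.log(togetherUrl("text-generation", false)); // …/v1/completions
```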
