
Commit 560585e

[Inference] Prepare release (#1112)

- Expose `ProviderMapping` type
- Update READMEs

Parent: 959e58b

File tree

5 files changed (+64, -17 lines)

README.md

Lines changed: 18 additions & 2 deletions

```diff
@@ -27,7 +27,7 @@ await uploadFile({
   }
 });

-// Use Inference API
+// Use HF Inference API

 await inference.chatCompletion({
   model: "meta-llama/Llama-3.1-8B-Instruct",
@@ -53,7 +53,7 @@ await inference.textToImage({

 This is a collection of JS libraries to interact with the Hugging Face API, with TS types included.

-- [@huggingface/inference](packages/inference/README.md): Use Inference API (serverless) and Inference Endpoints (dedicated) to make calls to 100,000+ Machine Learning models
+- [@huggingface/inference](packages/inference/README.md): Use Inference API (serverless), Inference Endpoints (dedicated) and third-party Inference providers to make calls to 100,000+ Machine Learning models
 - [@huggingface/hub](packages/hub/README.md): Interact with huggingface.co to create or delete repos and commit / download files
 - [@huggingface/agents](packages/agents/README.md): Interact with HF models through a natural language interface
 - [@huggingface/gguf](packages/gguf/README.md): A GGUF parser that works on remotely hosted files.
@@ -144,6 +144,22 @@ for await (const chunk of inference.chatCompletionStream({
   console.log(chunk.choices[0].delta.content);
 }

+// Using a third-party provider:
+await inference.chatCompletion({
+  model: "meta-llama/Llama-3.1-8B-Instruct",
+  messages: [{ role: "user", content: "Hello, nice to meet you!" }],
+  max_tokens: 512,
+  provider: "sambanova"
+})
+
+await inference.textToImage({
+  model: "black-forest-labs/FLUX.1-dev",
+  inputs: "a picture of a green bird",
+  provider: "together"
+})
+
+
 // You can also omit "model" to use the recommended model for the task
 await inference.translation({
   inputs: "My name is Wolfgang and I live in Amsterdam",
```

packages/inference/README.md

Lines changed: 28 additions & 0 deletions

````diff
@@ -42,6 +42,34 @@ const hf = new HfInference('your access token')

 Your access token should be kept private. If you need to protect it in front-end applications, we suggest setting up a proxy server that stores the access token.

+### Requesting third-party inference providers
+
+You can request inference from third-party providers with the inference client.
+
+Currently, we support the following providers: [Fal.ai](https://fal.ai), [Replicate](https://replicate.com), [Together](https://together.xyz) and [Sambanova](https://sambanova.ai).
+
+To make a request to a third-party provider, pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
+```ts
+const accessToken = "hf_..."; // Either an HF access token, or an API key from the third-party provider (Replicate in this example)
+
+const client = new HfInference(accessToken);
+await client.textToImage({
+  provider: "replicate",
+  model: "black-forest-labs/FLUX.1-dev",
+  inputs: "A black forest cake"
+})
+```
+
+When authenticated with a Hugging Face access token, the request is routed through https://huggingface.co.
+When authenticated with a third-party provider key, the request is made directly against that provider's inference API.
+
+Only a subset of models is supported when requesting third-party providers. You can check the list of supported models per pipeline task here:
+- [Fal.ai supported models](./src/providers/fal-ai.ts)
+- [Replicate supported models](./src/providers/replicate.ts)
+- [Sambanova supported models](./src/providers/sambanova.ts)
+- [Together supported models](./src/providers/together.ts)
+- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
+
 #### Tree-shaking

 You can import the functions you need directly from the module instead of using the `HfInference` class.
````
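The routing rule in this README (HF token goes through huggingface.co, provider key goes direct) can be sketched as follows. This is a simplified illustration, not the library's actual code: the `hf_` prefix check and the provider base URLs below are assumptions for the sketch.

```typescript
// Hypothetical provider base URLs, for illustration only.
const PROVIDER_BASE_URLS: Record<string, string> = {
  replicate: "https://api.replicate.com",
  together: "https://api.together.xyz",
};

// Sketch of the routing decision: an HF access token (assumed to be
// recognizable by its "hf_" prefix) routes the call through huggingface.co,
// which proxies to the provider; any other key is treated as a third-party
// provider key and the request goes straight to that provider's API.
function resolveBaseUrl(accessToken: string, provider: string): string {
  const isHfToken = accessToken.startsWith("hf_");
  return isHfToken ? "https://huggingface.co" : PROVIDER_BASE_URLS[provider];
}

resolveBaseUrl("hf_abc123", "replicate"); // routed through HF
resolveBaseUrl("r8_xyz", "replicate"); // direct to the provider
```

Either way the `provider` parameter selects which third-party backend serves the request; only the authentication method changes where the HTTP call is sent.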

packages/inference/src/index.ts

Lines changed: 1 addition & 0 deletions

```diff
@@ -1,3 +1,4 @@
+export type { ProviderMapping } from "./providers/types";
 export { HfInference, HfInferenceEndpoint } from "./HfInference";
 export { InferenceOutputError } from "./lib/InferenceOutputError";
 export { FAL_AI_SUPPORTED_MODEL_IDS } from "./providers/fal-ai";
```

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 14 additions & 13 deletions

```diff
@@ -1,3 +1,4 @@
+import type { WidgetType } from "@huggingface/tasks";
 import { HF_HUB_URL, HF_INFERENCE_API_URL } from "../config";
 import { FAL_AI_API_BASE_URL, FAL_AI_SUPPORTED_MODEL_IDS } from "../providers/fal-ai";
 import { REPLICATE_API_BASE_URL, REPLICATE_SUPPORTED_MODEL_IDS } from "../providers/replicate";
@@ -65,21 +66,21 @@ export async function makeRequestOptions(
       ? "hf-token"
       : "provider-key"
     : includeCredentials === "include"
-    ? "credentials-include"
-    : "none";
+      ? "credentials-include"
+      : "none";

   const url = endpointUrl
     ? chatCompletion
       ? endpointUrl + `/v1/chat/completions`
       : endpointUrl
     : makeUrl({
-      authMethod,
-      chatCompletion: chatCompletion ?? false,
-      forceTask,
-      model,
-      provider: provider ?? "hf-inference",
-      taskHint,
-    });
+        authMethod,
+        chatCompletion: chatCompletion ?? false,
+        forceTask,
+        model,
+        provider: provider ?? "hf-inference",
+        taskHint,
+      });

   const headers: Record<string, string> = {};
   if (accessToken) {
@@ -133,9 +134,9 @@ export async function makeRequestOptions(
     body: binary
       ? args.data
       : JSON.stringify({
-        ...otherArgs,
-        ...(chatCompletion || provider === "together" ? { model } : undefined),
-      }),
+          ...otherArgs,
+          ...(chatCompletion || provider === "together" ? { model } : undefined),
+        }),
     ...(credentials ? { credentials } : undefined),
     signal: options?.signal,
   };
@@ -155,7 +156,7 @@ function mapModel(params: {
   if (!params.taskHint) {
     throw new Error("taskHint must be specified when using a third-party provider");
   }
-  const task = params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
+  const task: WidgetType = params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
   const model = (() => {
     switch (params.provider) {
       case "fal-ai":
```
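The `mapModel` change above only adds a `WidgetType` annotation; the task-selection rule itself is unchanged. A minimal standalone sketch of that rule, with `WidgetType` stubbed as a plain string since the real union lives in `@huggingface/tasks`:

```typescript
// Stand-in for the WidgetType union exported by @huggingface/tasks
// (assumption: a string union of task names).
type WidgetType = string;

// Mirrors the line in the diff: a chat-completion request against a
// text-generation model is mapped to the "conversational" task; every
// other task hint passes through unchanged.
function selectTask(taskHint: WidgetType, chatCompletion: boolean): WidgetType {
  return taskHint === "text-generation" && chatCompletion ? "conversational" : taskHint;
}

selectTask("text-generation", true); // "conversational"
selectTask("text-to-image", false); // "text-to-image"
```

Typing the result as `WidgetType` is what lets it index into `ProviderMapping` (see the `providers/types.ts` change in this commit), since that mapping is keyed by task.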
packages/inference/src/providers/types.ts

Lines changed: 3 additions & 2 deletions

```diff
@@ -1,5 +1,6 @@
-import type { InferenceTask, ModelId } from "../types";
+import type { WidgetType } from "@huggingface/tasks";
+import type { ModelId } from "../types";

 export type ProviderMapping<ProviderId extends string> = Partial<
-	Record<InferenceTask | "conversational", Partial<Record<ModelId, ProviderId>>>
+	Record<WidgetType, Partial<Record<ModelId, ProviderId>>>
 >;
```
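Since `ProviderMapping` is now part of the public API, consumers can type their own task-keyed model tables with it. A small self-contained sketch, using local stand-ins for `WidgetType` and `ModelId` and a hypothetical mapping (the model-id pairs below are illustrative, not taken from the real provider files):

```typescript
type ModelId = string; // stand-in for the package's ModelId alias
type WidgetType = string; // stand-in for the union from @huggingface/tasks

// Same shape as the exported type: task -> (HF model id -> provider-side id).
type ProviderMapping<ProviderId extends string> = Partial<
  Record<WidgetType, Partial<Record<ModelId, ProviderId>>>
>;

// Hypothetical mapping in the style of the *_SUPPORTED_MODEL_IDS tables.
const EXAMPLE_IDS: ProviderMapping<string> = {
  "text-to-image": {
    "black-forest-labs/FLUX.1-dev": "black-forest-labs/flux-dev",
  },
};

// Resolve the provider-side id for a (task, model) pair, if supported.
function lookup(
  mapping: ProviderMapping<string>,
  task: WidgetType,
  model: ModelId
): string | undefined {
  return mapping[task]?.[model];
}
```

The double `Partial` is the design point: both the set of tasks and the set of models per task are sparse, so every lookup can return `undefined` for unsupported combinations.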

0 commit comments
