
Commit f78bf7a

radames, coyotte508, and julien-c authored
Add chat completion method (#645)
Supersedes #581. Thanks to @Wauplin, I can import the types from "@huggingface/tasks". I've followed the pattern for `textGeneration` and `textGenerationStream`.

---------

Co-authored-by: coyotte508 <[email protected]>
Co-authored-by: Julien Chaumond <[email protected]>
1 parent de30544 commit f78bf7a
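
Editor's note: as a rough sketch of the typing pattern the commit message describes — importing the chat completion types from `@huggingface/tasks` and mirroring the `textGeneration` / `textGenerationStream` pair — the new method's signatures look roughly like this. The exact exported type names are assumptions based on the commit message, not a verified listing of the package:

```typescript
// Hedged sketch: type names assumed from the commit message, not verified.
import type {
  ChatCompletionInput,
  ChatCompletionOutput,
  ChatCompletionStreamOutput,
} from "@huggingface/tasks";

// Mirrors the textGeneration / textGenerationStream pattern mentioned above.
declare function chatCompletion(args: ChatCompletionInput): Promise<ChatCompletionOutput>;
declare function chatCompletionStream(args: ChatCompletionInput): AsyncGenerator<ChatCompletionStreamOutput>;
```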

File tree

14 files changed: +1535, -84 lines

README.md

Lines changed: 30 additions & 5 deletions
@@ -1,6 +1,6 @@
 <p align="center">
   <br/>
-  <picture>
+  <picture>
   <source media="(prefers-color-scheme: dark)" srcset="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/huggingfacejs-dark.svg">
   <source media="(prefers-color-scheme: light)" srcset="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/huggingfacejs-light.svg">
   <img alt="huggingface javascript library logo" src="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/huggingfacejs-light.svg" width="376" height="59" style="max-width: 100%;">
@@ -56,8 +56,7 @@ This is a collection of JS libraries to interact with the Hugging Face API, with
 - [@huggingface/tasks](packages/tasks/README.md): The definition files and source-of-truth for the Hub's main primitives like pipeline tasks, model libraries, etc.


-
-We use modern features to avoid polyfills and dependencies, so the libraries will only work on modern browsers / Node.js >= 18 / Bun / Deno.
+We use modern features to avoid polyfills and dependencies, so the libraries will only work on modern browsers / Node.js >= 18 / Bun / Deno.

 The libraries are still very young, please help us by opening issues!

@@ -108,7 +107,6 @@ import { HfAgent } from "npm:@huggingface/agents";
 import { createRepo, commit, deleteRepo, listFiles } from "npm:@huggingface/hub"
 ```

-
 ## Usage examples

 Get your HF access token in your [account settings](https://huggingface.co/settings/tokens).
@@ -122,6 +120,23 @@ const HF_TOKEN = "hf_...";

 const inference = new HfInference(HF_TOKEN);

+// Chat completion API
+const out = await inference.chatCompletion({
+  model: "mistralai/Mistral-7B-Instruct-v0.2",
+  messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+  max_tokens: 100
+});
+console.log(out.choices[0].message);
+
+// Streaming chat completion API
+for await (const chunk of inference.chatCompletionStream({
+  model: "mistralai/Mistral-7B-Instruct-v0.2",
+  messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+  max_tokens: 100
+})) {
+  console.log(chunk.choices[0].delta.content);
+}
+
 // You can also omit "model" to use the recommended model for the task
 await inference.translation({
   model: 't5-base',
@@ -144,6 +159,17 @@ await inference.imageToText({
 // Using your own dedicated inference endpoint: https://hf.co/docs/inference-endpoints/
 const gpt2 = inference.endpoint('https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2');
 const { generated_text } = await gpt2.textGeneration({inputs: 'The answer to the universe is'});
+
+// Chat completion
+const mistral = inference.endpoint(
+  "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
+);
+const out = await mistral.chatCompletion({
+  model: "mistralai/Mistral-7B-Instruct-v0.2",
+  messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+  max_tokens: 100,
+});
+console.log(out.choices[0].message);
 ```

 ### @huggingface/hub examples
@@ -200,7 +226,6 @@ const messages = await agent.run("Draw a picture of a cat wearing a top hat. The
 console.log(messages);
 ```

-
 There are more features of course, check each library's README!

 ## Formatting & testing
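
Editor's note: for readers who want to see what the `chatCompletion` call added above translates to on the wire, here is a hedged `fetch`-level sketch. The URL shape follows from the `makeRequestOptions` changes further down in this commit, but the header details here are illustrative assumptions, not the library's actual request code:

```typescript
// Illustrative only: roughly what inference.chatCompletion(...) above posts.
// The /v1/chat/completions suffix comes from the routing added in makeRequestOptions.
const HF_TOKEN = "hf_...";
const response = await fetch(
  "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2/v1/chat/completions",
  {
    method: "POST",
    headers: {
      Authorization: `Bearer ${HF_TOKEN}`,
      "Content-Type": "application/json", // assumption: shown explicitly for clarity
    },
    body: JSON.stringify({
      model: "mistralai/Mistral-7B-Instruct-v0.2",
      messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
      max_tokens: 100,
    }),
  }
);
const out = await response.json(); // OpenAI-style shape: { choices: [{ message: { role, content } }] }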

packages/inference/README.md

Lines changed: 122 additions & 27 deletions
@@ -5,7 +5,7 @@ It works with both [Inference API (serverless)](https://huggingface.co/docs/api-

 Check out the [full documentation](https://huggingface.co/docs/huggingface.js/inference/README).

-You can also try out a live [interactive notebook](https://observablehq.com/@huggingface/hello-huggingface-js-inference), see some demos on [hf.co/huggingfacejs](https://huggingface.co/huggingfacejs), or watch a [Scrimba tutorial that explains how Inference Endpoints works](https://scrimba.com/scrim/cod8248f5adfd6e129582c523).
+You can also try out a live [interactive notebook](https://observablehq.com/@huggingface/hello-huggingface-js-inference), see some demos on [hf.co/huggingfacejs](https://huggingface.co/huggingfacejs), or watch a [Scrimba tutorial that explains how Inference Endpoints works](https://scrimba.com/scrim/cod8248f5adfd6e129582c523).

 ## Getting Started

@@ -30,7 +30,6 @@ import { HfInference } from "https://esm.sh/@huggingface/inference"
 import { HfInference } from "npm:@huggingface/inference"
 ```

-
 ### Initialize

 ```typescript
@@ -43,7 +42,6 @@ const hf = new HfInference('your access token')

 Your access token should be kept private. If you need to protect it in front-end applications, we suggest setting up a proxy server that stores the access token.

-
 #### Tree-shaking

 You can import the functions you need directly from the module instead of using the `HfInference` class.
@@ -63,6 +61,85 @@ This will enable tree-shaking by your bundler.

 ## Natural Language Processing

+### Text Generation
+
+Generates text from an input prompt.
+
+[Demo](https://huggingface.co/spaces/huggingfacejs/streaming-text-generation)
+
+```typescript
+await hf.textGeneration({
+  model: 'gpt2',
+  inputs: 'The answer to the universe is'
+})
+
+for await (const output of hf.textGenerationStream({
+  model: "google/flan-t5-xxl",
+  inputs: 'repeat "one two three four"',
+  parameters: { max_new_tokens: 250 }
+})) {
+  console.log(output.token.text, output.generated_text);
+}
+```
+
+### Text Generation (Chat Completion API Compatible)
+
+Using the `chatCompletion` method, you can generate text with models compatible with the OpenAI Chat Completion API. All models served by [TGI](https://api-inference.huggingface.co/framework/text-generation-inference) on Hugging Face support the Messages API.
+
+[Demo](https://huggingface.co/spaces/huggingfacejs/streaming-chat-completion)
+
+```typescript
+// Non-streaming API
+const out = await hf.chatCompletion({
+  model: "mistralai/Mistral-7B-Instruct-v0.2",
+  messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+  max_tokens: 500,
+  temperature: 0.1,
+  seed: 0,
+});
+
+// Streaming API
+let out = "";
+for await (const chunk of hf.chatCompletionStream({
+  model: "mistralai/Mistral-7B-Instruct-v0.2",
+  messages: [
+    { role: "user", content: "Complete the equation 1+1=, just the answer" },
+  ],
+  max_tokens: 500,
+  temperature: 0.1,
+  seed: 0,
+})) {
+  if (chunk.choices && chunk.choices.length > 0) {
+    out += chunk.choices[0].delta.content;
+  }
+}
+```
+
+It's also possible to call Mistral or OpenAI endpoints directly:
+
+```typescript
+const openai = new HfInference(OPENAI_TOKEN).endpoint("https://api.openai.com");
+
+let out = "";
+for await (const chunk of openai.chatCompletionStream({
+  model: "gpt-3.5-turbo",
+  messages: [
+    { role: "user", content: "Complete the equation 1+1=, just the answer" },
+  ],
+  max_tokens: 500,
+  temperature: 0.1,
+  seed: 0,
+})) {
+  if (chunk.choices && chunk.choices.length > 0) {
+    out += chunk.choices[0].delta.content;
+  }
+}
+
+// For Mistral AI:
+// endpointUrl: "https://api.mistral.ai"
+// model: "mistral-tiny"
+```
+
 ### Fill Mask

 Tries to fill in a hole with a missing word (token to be precise).
@@ -131,27 +208,6 @@ await hf.textClassification({
 })
 ```

-### Text Generation
-
-Generates text from an input prompt.
-
-[Demo](https://huggingface.co/spaces/huggingfacejs/streaming-text-generation)
-
-```typescript
-await hf.textGeneration({
-  model: 'gpt2',
-  inputs: 'The answer to the universe is'
-})
-
-for await (const output of hf.textGenerationStream({
-  model: "google/flan-t5-xxl",
-  inputs: 'repeat "one two three four"',
-  parameters: { max_new_tokens: 250 }
-})) {
-  console.log(output.token.text, output.generated_text);
-}
-```
-
 ### Token Classification

 Used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text.
@@ -177,9 +233,9 @@ await hf.translation({
   model: 'facebook/mbart-large-50-many-to-many-mmt',
   inputs: textToTranslate,
   parameters: {
-    "src_lang": "en_XX",
-    "tgt_lang": "fr_XX"
-  }
+    "src_lang": "en_XX",
+    "tgt_lang": "fr_XX"
+  }
 })
 ```

@@ -497,13 +553,52 @@ for await (const output of hf.streamingRequest({
 }
 ```

+You can use any Chat Completion API-compatible provider with the `chatCompletion` method.
+
+```typescript
+// Chat Completion Example
+const MISTRAL_KEY = process.env.MISTRAL_KEY;
+const hf = new HfInference(MISTRAL_KEY);
+const ep = hf.endpoint("https://api.mistral.ai");
+const stream = ep.chatCompletionStream({
+  model: "mistral-tiny",
+  messages: [{ role: "user", content: "Complete the equation one + one =, just the answer" }],
+});
+let out = "";
+for await (const chunk of stream) {
+  if (chunk.choices && chunk.choices.length > 0) {
+    out += chunk.choices[0].delta.content;
+    console.log(out);
+  }
+}
+```
+
 ## Custom Inference Endpoints

 Learn more about using your own inference endpoints [here](https://hf.co/docs/inference-endpoints/)

 ```typescript
 const gpt2 = hf.endpoint('https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2');
 const { generated_text } = await gpt2.textGeneration({inputs: 'The answer to the universe is'});
+
+// Chat Completion Example
+const ep = hf.endpoint(
+  "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
+);
+const stream = ep.chatCompletionStream({
+  model: "tgi",
+  messages: [{ role: "user", content: "Complete the equation 1+1=, just the answer" }],
+  max_tokens: 500,
+  temperature: 0.1,
+  seed: 0,
+});
+let out = "";
+for await (const chunk of stream) {
+  if (chunk.choices && chunk.choices.length > 0) {
+    out += chunk.choices[0].delta.content;
+    console.log(out);
+  }
+}
 ```

 By default, all calls to the inference endpoint will wait until the model is
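
Editor's note: given the tree-shaking support described earlier in this README, the new task should also be importable as a standalone function. A minimal sketch, assuming `chatCompletion` is exported the same way `textGeneration` is (an assumption based on the commit message, not a verified export list):

```typescript
// Hedged sketch: assumes a standalone export mirroring textGeneration.
import { chatCompletion } from "@huggingface/inference";

const res = await chatCompletion({
  accessToken: "hf_...",
  model: "mistralai/Mistral-7B-Instruct-v0.2",
  messages: [{ role: "user", content: "one plus one is equal to " }],
  max_tokens: 16,
});
console.log(res.choices[0].message);
```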

packages/inference/src/HfInference.ts

Lines changed: 4 additions & 4 deletions
@@ -14,9 +14,9 @@ type TaskWithNoAccessToken = {
   ) => ReturnType<Task[key]>;
 };

-type TaskWithNoAccessTokenNoModel = {
+type TaskWithNoAccessTokenNoEndpointUrl = {
   [key in keyof Task]: (
-    args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "model">,
+    args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "endpointUrl">,
     options?: Parameters<Task[key]>[1]
   ) => ReturnType<Task[key]>;
 };
@@ -57,12 +57,12 @@ export class HfInferenceEndpoint {
       enumerable: false,
       value: (params: RequestArgs, options: Options) =>
         // eslint-disable-next-line @typescript-eslint/no-explicit-any
-        fn({ ...params, accessToken, model: endpointUrl } as any, { ...defaultOptions, ...options }),
+        fn({ ...params, accessToken, endpointUrl } as any, { ...defaultOptions, ...options }),
     });
   }
 }

 export interface HfInference extends TaskWithNoAccessToken {}

-export interface HfInferenceEndpoint extends TaskWithNoAccessTokenNoModel {}
+export interface HfInferenceEndpoint extends TaskWithNoAccessTokenNoEndpointUrl {}
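
Editor's note: the rename is more than cosmetic. Endpoint-scoped methods previously omitted `model` (the endpoint URL was smuggled in as `model`); they now omit `endpointUrl` and keep `model` available, which is what lets the README examples pass `model: "tgi"` to a fixed endpoint. A self-contained sketch of the type mechanics, using a simplified stand-in `Task` map (the real map is much larger):

```typescript
// Simplified stand-ins for illustration; the real Task map and helper live in the library.
type DistributiveOmit<T, K extends PropertyKey> = T extends unknown ? Omit<T, K> : never;

interface ChatArgs {
  accessToken?: string;
  endpointUrl?: string;
  model?: string;
  messages: Array<{ role: string; content: string }>;
}

type Task = {
  chatCompletion: (args: ChatArgs) => Promise<unknown>;
};

type TaskWithNoAccessTokenNoEndpointUrl = {
  [key in keyof Task]: (
    args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "endpointUrl">
  ) => ReturnType<Task[key]>;
};

// On an endpoint instance, `model` survives the Omit, so this type-checks:
declare const ep: TaskWithNoAccessTokenNoEndpointUrl;
void ep.chatCompletion({ model: "tgi", messages: [{ role: "user", content: "hi" }] });
```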

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 17 additions & 7 deletions
@@ -1,4 +1,5 @@
 import type { InferenceTask, Options, RequestArgs } from "../types";
+import { omit } from "../utils/omit";
 import { HF_HUB_URL } from "./getDefaultTask";
 import { isUrl } from "./isUrl";

@@ -22,10 +23,10 @@ export async function makeRequestOptions(
     forceTask?: string | InferenceTask;
     /** To load default model if needed */
     taskHint?: InferenceTask;
+    chatCompletion?: boolean;
   }
 ): Promise<{ url: string; info: RequestInit }> {
-  // eslint-disable-next-line @typescript-eslint/no-unused-vars
-  const { accessToken, model: _model, ...otherArgs } = args;
+  const { accessToken, endpointUrl, ...otherArgs } = args;
   let { model } = args;
   const {
     forceTask: task,
@@ -34,7 +35,7 @@ export async function makeRequestOptions(
     wait_for_model,
     use_cache,
     dont_load_model,
-    ...otherOptions
+    chatCompletion,
   } = options ?? {};

 const headers: Record<string, string> = {};
@@ -77,18 +78,28 @@ export async function makeRequestOptions(
     headers["X-Load-Model"] = "0";
   }

-  const url = (() => {
+  let url = (() => {
+    if (endpointUrl && isUrl(model)) {
+      throw new TypeError("Both model and endpointUrl cannot be URLs");
+    }
     if (isUrl(model)) {
+      console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead");
       return model;
     }
-
+    if (endpointUrl) {
+      return endpointUrl;
+    }
     if (task) {
       return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
     }

     return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
   })();

+  if (chatCompletion && !url.endsWith("/chat/completions")) {
+    url += "/v1/chat/completions";
+  }
+
   /**
    * For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
    */
@@ -105,8 +116,7 @@ export async function makeRequestOptions(
       body: binary
         ? args.data
         : JSON.stringify({
-            ...otherArgs,
-            options: options && otherOptions,
+            ...(otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs),
          }),
       ...(credentials && { credentials }),
       signal: options?.signal,
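
Editor's note: condensing the hunks above, URL resolution now follows four rules: reject `model` and `endpointUrl` both being URLs, honor a URL `model` with a deprecation warning, prefer `endpointUrl`, else fall back to the Inference API; chat-completion requests are then routed to the OpenAI-compatible path. A standalone sketch of that logic (simplified `isUrl`, `forceTask` branch dropped; not the actual source):

```typescript
// Condensed sketch of the routing logic added above; isUrl is simplified.
const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
const isUrl = (s: string): boolean => /^https?:\/\//.test(s);

function resolveUrl(model: string, endpointUrl?: string, chatCompletion?: boolean): string {
  if (endpointUrl && isUrl(model)) {
    throw new TypeError("Both model and endpointUrl cannot be URLs");
  }
  let url: string;
  if (isUrl(model)) {
    url = model; // deprecated: pass endpointUrl instead
  } else if (endpointUrl) {
    url = endpointUrl;
  } else {
    url = `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
  }
  if (chatCompletion && !url.endsWith("/chat/completions")) {
    url += "/v1/chat/completions"; // OpenAI-compatible route, e.g. on TGI servers
  }
  return url;
}

// resolveUrl("mistral-tiny", "https://api.mistral.ai", true)
//   -> "https://api.mistral.ai/v1/chat/completions"
```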

packages/inference/src/tasks/custom/request.ts

Lines changed: 5 additions & 0 deletions
@@ -11,6 +11,8 @@ export async function request<T>(
     task?: string | InferenceTask;
     /** To load default model if needed */
     taskHint?: InferenceTask;
+    /** Is chat completion compatible */
+    chatCompletion?: boolean;
   }
 ): Promise<T> {
   const { url, info } = await makeRequestOptions(args, options);
@@ -26,6 +28,9 @@ export async function request<T>(
   if (!response.ok) {
     if (response.headers.get("Content-Type")?.startsWith("application/json")) {
       const output = await response.json();
+      if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
+        throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
+      }
       if (output.error) {
         throw new Error(output.error);
       }
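
Editor's note: in practice the new branch gives a clearer failure mode when the target server lacks the chat-completion route. A hedged usage sketch of what a caller would observe (the model choice is illustrative, not a tested case):

```typescript
import { HfInference } from "@huggingface/inference";

const hf = new HfInference(process.env.HF_TOKEN);

try {
  const out = await hf.chatCompletion({
    model: "gpt2", // illustrative: a model unlikely to sit behind a Messages-API-capable server
    messages: [{ role: "user", content: "one plus one is equal to " }],
  });
  console.log(out.choices[0].message);
} catch (err) {
  // On a 400/404/422/500 JSON error, the library now throws:
  // "Server gpt2 does not seem to support chat completion. Error: ..."
  console.error(err);
}
```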
