Skip to content

Commit 7004980

Browse files
authored
Improve text-generation code snippets (#700)
This PR is the first step towards improving auto-generated code snippets, mainly focusing on improving chat model inputs. Highlights of the PR: - [x] JS snippets were missing the content type header (`"Content-Type": "application/json"`) - [x] Code adapted from https://huggingface.co/blog/tgi-messages-api - [x] Moved snippet generation code to separate folders, but I think we can move this all into `tasks/[task]/snippet.ts`. The reason against keeping it all in a single file (which was `snippets/inputs.ts`) is that this will grow in complexity as we improve code snippets across all other tasks. - [x] Some models don't support system messages, and will throw an error or ignore the system message. How should we handle this? (EDIT: Fixed by only specifying a user message by default, which almost all models should support)
1 parent 3b68418 commit 7004980

File tree

5 files changed

+112
-37
lines changed

5 files changed

+112
-37
lines changed

packages/tasks/src/snippets/curl.ts

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,24 @@ export const snippetBasic = (model: ModelDataMinimal, accessToken: string): stri
1010
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"
1111
`;
1212

13+
export const snippetTextGeneration = (model: ModelDataMinimal, accessToken: string): string => {
14+
if (model.config?.tokenizer_config?.chat_template) {
15+
// Conversational model detected, so we display a code snippet that features the Messages API
16+
return `curl 'https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions' \\
17+
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}" \\
18+
-H 'Content-Type: application/json' \\
19+
-d '{
20+
"model": "${model.id}",
21+
"messages": [{"role": "user", "content": "What is the capital of France?"}],
22+
"max_tokens": 500,
23+
"stream": false
24+
}'
25+
`;
26+
} else {
27+
return snippetBasic(model, accessToken);
28+
}
29+
};
30+
1331
export const snippetZeroShotClassification = (model: ModelDataMinimal, accessToken: string): string =>
1432
`curl https://api-inference.huggingface.co/models/${model.id} \\
1533
-X POST \\
@@ -35,7 +53,7 @@ export const curlSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal
3553
translation: snippetBasic,
3654
summarization: snippetBasic,
3755
"feature-extraction": snippetBasic,
38-
"text-generation": snippetBasic,
56+
"text-generation": snippetTextGeneration,
3957
"text2text-generation": snippetBasic,
4058
"fill-mask": snippetBasic,
4159
"sentence-similarity": snippetBasic,

packages/tasks/src/snippets/inputs.ts

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,30 +11,30 @@ const inputsSummarization = () =>
1111

1212
const inputsTableQuestionAnswering = () =>
1313
`{
14-
"query": "How many stars does the transformers repository have?",
15-
"table": {
16-
"Repository": ["Transformers", "Datasets", "Tokenizers"],
17-
"Stars": ["36542", "4512", "3934"],
18-
"Contributors": ["651", "77", "34"],
19-
"Programming language": [
20-
"Python",
21-
"Python",
22-
"Rust, Python and NodeJS"
23-
]
24-
}
25-
}`;
14+
"query": "How many stars does the transformers repository have?",
15+
"table": {
16+
"Repository": ["Transformers", "Datasets", "Tokenizers"],
17+
"Stars": ["36542", "4512", "3934"],
18+
"Contributors": ["651", "77", "34"],
19+
"Programming language": [
20+
"Python",
21+
"Python",
22+
"Rust, Python and NodeJS"
23+
]
24+
}
25+
}`;
2626

2727
const inputsVisualQuestionAnswering = () =>
2828
`{
29-
"image": "cat.png",
30-
"question": "What is in this image?"
31-
}`;
29+
"image": "cat.png",
30+
"question": "What is in this image?"
31+
}`;
3232

3333
const inputsQuestionAnswering = () =>
3434
`{
35-
"question": "What is my name?",
36-
"context": "My name is Clara and I live in Berkeley."
37-
}`;
35+
"question": "What is my name?",
36+
"context": "My name is Clara and I live in Berkeley."
37+
}`;
3838

3939
const inputsTextClassification = () => `"I like you. I love you"`;
4040

@@ -48,13 +48,13 @@ const inputsFillMask = (model: ModelDataMinimal) => `"The answer to the universe
4848

4949
const inputsSentenceSimilarity = () =>
5050
`{
51-
"source_sentence": "That is a happy person",
52-
"sentences": [
53-
"That is a happy dog",
54-
"That is a very happy person",
55-
"Today is a sunny day"
56-
]
57-
}`;
51+
"source_sentence": "That is a happy person",
52+
"sentences": [
53+
"That is a happy dog",
54+
"That is a very happy person",
55+
"Today is a sunny day"
56+
]
57+
}`;
5858

5959
const inputsFeatureExtraction = () => `"Today is a sunny day and I will get some ice cream."`;
6060

packages/tasks/src/snippets/js.ts

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ export const snippetBasic = (model: ModelDataMinimal, accessToken: string): stri
77
const response = await fetch(
88
"https://api-inference.huggingface.co/models/${model.id}",
99
{
10-
headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
10+
headers: {
11+
Authorization: "Bearer ${accessToken || `{API_TOKEN}`}"
12+
"Content-Type": "application/json",
13+
},
1114
method: "POST",
1215
body: JSON.stringify(data),
1316
}
@@ -20,12 +23,34 @@ query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
2023
console.log(JSON.stringify(response));
2124
});`;
2225

26+
export const snippetTextGeneration = (model: ModelDataMinimal, accessToken: string): string => {
27+
if (model.config?.tokenizer_config?.chat_template) {
28+
// Conversational model detected, so we display a code snippet that features the Messages API
29+
return `import { HfInference } from "@huggingface/inference";
30+
31+
const inference = new HfInference("${accessToken || `{API_TOKEN}`}");
32+
33+
for await (const chunk of inference.chatCompletionStream({
34+
model: "${model.id}",
35+
messages: [{ role: "user", content: "What is the capital of France?" }],
36+
max_tokens: 500,
37+
})) {
38+
process.stdout.write(chunk.choices[0]?.delta?.content || "");
39+
}
40+
`;
41+
} else {
42+
return snippetBasic(model, accessToken);
43+
}
44+
};
2345
export const snippetZeroShotClassification = (model: ModelDataMinimal, accessToken: string): string =>
2446
`async function query(data) {
2547
const response = await fetch(
2648
"https://api-inference.huggingface.co/models/${model.id}",
2749
{
28-
headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
50+
headers: {
51+
Authorization: "Bearer ${accessToken || `{API_TOKEN}`}"
52+
"Content-Type": "application/json",
53+
},
2954
method: "POST",
3055
body: JSON.stringify(data),
3156
}
@@ -45,7 +70,10 @@ export const snippetTextToImage = (model: ModelDataMinimal, accessToken: string)
4570
const response = await fetch(
4671
"https://api-inference.huggingface.co/models/${model.id}",
4772
{
48-
headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
73+
headers: {
74+
Authorization: "Bearer ${accessToken || `{API_TOKEN}`}"
75+
"Content-Type": "application/json",
76+
},
4977
method: "POST",
5078
body: JSON.stringify(data),
5179
}
@@ -62,7 +90,10 @@ export const snippetTextToAudio = (model: ModelDataMinimal, accessToken: string)
6290
const response = await fetch(
6391
"https://api-inference.huggingface.co/models/${model.id}",
6492
{
65-
headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
93+
headers: {
94+
Authorization: "Bearer ${accessToken || `{API_TOKEN}`}"
95+
"Content-Type": "application/json",
96+
},
6697
method: "POST",
6798
body: JSON.stringify(data),
6899
}
@@ -99,7 +130,10 @@ export const snippetFile = (model: ModelDataMinimal, accessToken: string): strin
99130
const response = await fetch(
100131
"https://api-inference.huggingface.co/models/${model.id}",
101132
{
102-
headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
133+
headers: {
134+
Authorization: "Bearer ${accessToken || `{API_TOKEN}`}"
135+
"Content-Type": "application/json",
136+
},
103137
method: "POST",
104138
body: data,
105139
}
@@ -122,7 +156,7 @@ export const jsSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal,
122156
translation: snippetBasic,
123157
summarization: snippetBasic,
124158
"feature-extraction": snippetBasic,
125-
"text-generation": snippetBasic,
159+
"text-generation": snippetTextGeneration,
126160
"text2text-generation": snippetBasic,
127161
"fill-mask": snippetBasic,
128162
"sentence-similarity": snippetBasic,

packages/tasks/src/snippets/python.ts

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,22 @@ import type { PipelineType } from "../pipelines.js";
22
import { getModelInputSnippet } from "./inputs.js";
33
import type { ModelDataMinimal } from "./types.js";
44

5+
/**
 * Python snippet for conversational (chat-template) models, using
 * `huggingface_hub.InferenceClient.chat_completion` with streaming output.
 */
export const snippetConversational = (model: ModelDataMinimal, accessToken: string): string => {
	// Placeholder the user can substitute when no token was provided.
	const token = accessToken || "{API_TOKEN}";
	return `from huggingface_hub import InferenceClient

client = InferenceClient(
    "${model.id}",
    token="${token}",
)

for message in client.chat_completion(
	messages=[{"role": "user", "content": "What is the capital of France?"}],
	max_tokens=500,
	stream=True,
):
    print(message.choices[0].delta.content, end="")
`;
};
20+
521
export const snippetZeroShotClassification = (model: ModelDataMinimal): string =>
622
`def query(payload):
723
response = requests.post(API_URL, headers=headers, json=payload)
@@ -107,7 +123,7 @@ output = query({
107123
"inputs": ${getModelInputSnippet(model)},
108124
})`;
109125

110-
export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal) => string>> = {
126+
export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal, accessToken: string) => string>> = {
111127
// Same order as in tasks/src/pipelines.ts
112128
"text-classification": snippetBasic,
113129
"token-classification": snippetBasic,
@@ -138,15 +154,22 @@ export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinim
138154
};
139155

140156
export function getPythonInferenceSnippet(model: ModelDataMinimal, accessToken: string): string {
141-
const body =
142-
model.pipeline_tag && model.pipeline_tag in pythonSnippets ? pythonSnippets[model.pipeline_tag]?.(model) ?? "" : "";
157+
if (model.pipeline_tag === "text-generation" && model.config?.tokenizer_config?.chat_template) {
158+
// Conversational model detected, so we display a code snippet that features the Messages API
159+
return snippetConversational(model, accessToken);
160+
} else {
161+
const body =
162+
model.pipeline_tag && model.pipeline_tag in pythonSnippets
163+
? pythonSnippets[model.pipeline_tag]?.(model, accessToken) ?? ""
164+
: "";
143165

144-
return `import requests
166+
return `import requests
145167
146168
API_URL = "https://api-inference.huggingface.co/models/${model.id}"
147169
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
148170
149171
${body}`;
172+
}
150173
}
151174

152175
export function hasPythonInferenceSnippet(model: ModelDataMinimal): boolean {

packages/tasks/src/snippets/types.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ import type { ModelData } from "../model-data";
55
*
66
* Add more fields as needed.
77
*/
8-
export type ModelDataMinimal = Pick<ModelData, "id" | "pipeline_tag" | "mask_token" | "library_name">;
8+
/** Minimal model data needed by snippet generators; `config` is read for `tokenizer_config.chat_template` to detect conversational models. */
export type ModelDataMinimal = Pick<ModelData, "id" | "pipeline_tag" | "mask_token" | "library_name" | "config">;

0 commit comments

Comments
 (0)