Commit fd62db4

Update conversational inference API snippets (#976)
### Description

This PR updates the signatures of the inference snippet generating functions. Before, the functions generated a `string`; now they generate `InferenceSnippet | InferenceSnippet[]`.

https://github.com/huggingface/huggingface.js/blob/5bc694b9e7b845d25743e6db30daef60f85883a7/packages/tasks/src/snippets/types.ts#L14-L17

Why do we need to generate `InferenceSnippet[]` and not just `InferenceSnippet`? Because for a given language we want to show multiple client options (for example `huggingface_hub` and `openai`); see the attached video below. This PR also greatly improves the conversational snippet.

### Screen recording

https://github.com/user-attachments/assets/6bc982c5-a855-4879-b9be-ca023d2b59fa
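For context, here is a minimal sketch of the new return shape and how a caller might consume it. The optional `client` field is an assumption on my part, inferred from the multi-client use case described above; the authoritative definition is at the `types.ts` link.

```ts
// Sketch only: assumed shape of InferenceSnippet (see types.ts#L14-L17 for the real definition).
interface InferenceSnippet {
	content: string;
	client?: string; // e.g. "huggingface_hub" or "openai" (assumed field)
}

// Callers can normalize both return shapes into an array:
function toSnippetList(result: InferenceSnippet | InferenceSnippet[]): InferenceSnippet[] {
	return Array.isArray(result) ? result : [result];
}
```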
1 parent da39402 commit fd62db4

5 files changed: +490 -114 lines changed

packages/tasks/src/snippets/common.ts

Lines changed: 63 additions & 0 deletions (new file)

```ts
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks";

export interface StringifyMessagesOptions {
	sep: string;
	start: string;
	end: string;
	attributeKeyQuotes?: boolean;
	customContentEscaper?: (str: string) => string;
}

export function stringifyMessages(messages: ChatCompletionInputMessage[], opts: StringifyMessagesOptions): string {
	const keyRole = opts.attributeKeyQuotes ? `"role"` : "role";
	const keyContent = opts.attributeKeyQuotes ? `"content"` : "content";

	const messagesStringified = messages.map(({ role, content }) => {
		if (typeof content === "string") {
			content = JSON.stringify(content).slice(1, -1);
			if (opts.customContentEscaper) {
				content = opts.customContentEscaper(content);
			}
			return `{ ${keyRole}: "${role}", ${keyContent}: "${content}" }`;
		} else {
			content = content.map(({ image_url, text, type }) => ({
				type,
				image_url,
				...(text ? { text: JSON.stringify(text).slice(1, -1) } : undefined),
			}));
			content = JSON.stringify(content).slice(1, -1);
			if (opts.customContentEscaper) {
				content = opts.customContentEscaper(content);
			}
			return `{ ${keyRole}: "${role}", ${keyContent}: ${content} }`;
		}
	});

	return opts.start + messagesStringified.join(opts.sep) + opts.end;
}

type PartialGenerationParameters = Partial<Pick<GenerationParameters, "temperature" | "max_tokens" | "top_p">>;

export interface StringifyGenerationConfigOptions {
	sep: string;
	start: string;
	end: string;
	attributeValueConnector: string;
	attributeKeyQuotes?: boolean;
}

export function stringifyGenerationConfig(
	config: PartialGenerationParameters,
	opts: StringifyGenerationConfigOptions
): string {
	const quote = opts.attributeKeyQuotes ? `"` : "";

	return (
		opts.start +
		Object.entries(config)
			.map(([key, val]) => `${quote}${key}${quote}${opts.attributeValueConnector}${val}`)
			.join(opts.sep) +
		opts.end
	);
}
```
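To see what these helpers produce, here is a small usage sketch. The commented output strings are my own illustration of the behavior, not fixtures from the PR:

```ts
import { stringifyGenerationConfig, stringifyMessages } from "./common.js";

// Render a message list the way the curl snippet does: quoted keys,
// tab-indented JSON-like array.
const messages = stringifyMessages(
	[{ role: "user", content: "What is the capital of France?" }],
	{ sep: ",\n\t\t", start: "[\n\t\t", end: "\n\t]", attributeKeyQuotes: true }
);
// => `[\n\t\t{ "role": "user", "content": "What is the capital of France?" }\n\t]`

// Render generation parameters as `"key": value` pairs.
const config = stringifyGenerationConfig(
	{ max_tokens: 500, temperature: 0.7 },
	{ sep: ",\n    ", start: "", end: "", attributeKeyQuotes: true, attributeValueConnector: ": " }
);
// => `"max_tokens": 500,\n    "temperature": 0.7`
```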

packages/tasks/src/snippets/curl.ts

Lines changed: 71 additions & 26 deletions
```diff
@@ -1,36 +1,73 @@
 import type { PipelineType } from "../pipelines.js";
+import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
+import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
 import { getModelInputSnippet } from "./inputs.js";
-import type { ModelDataMinimal } from "./types.js";
+import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
 
-export const snippetBasic = (model: ModelDataMinimal, accessToken: string): string =>
-	`curl https://api-inference.huggingface.co/models/${model.id} \\
+export const snippetBasic = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
+	content: `curl https://api-inference.huggingface.co/models/${model.id} \\
 	-X POST \\
 	-d '{"inputs": ${getModelInputSnippet(model, true)}}' \\
 	-H 'Content-Type: application/json' \\
-	-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"`;
+	-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"`,
+});
 
-export const snippetTextGeneration = (model: ModelDataMinimal, accessToken: string): string => {
+export const snippetTextGeneration = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	opts?: {
+		streaming?: boolean;
+		messages?: ChatCompletionInputMessage[];
+		temperature?: GenerationParameters["temperature"];
+		max_tokens?: GenerationParameters["max_tokens"];
+		top_p?: GenerationParameters["top_p"];
+	}
+): InferenceSnippet => {
 	if (model.tags.includes("conversational")) {
 		// Conversational model detected, so we display a code snippet that features the Messages API
-		return `curl 'https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions' \\
+		const streaming = opts?.streaming ?? true;
+		const messages: ChatCompletionInputMessage[] = opts?.messages ?? [
+			{ role: "user", content: "What is the capital of France?" },
+		];
+
+		const config = {
+			...(opts?.temperature ? { temperature: opts.temperature } : undefined),
+			max_tokens: opts?.max_tokens ?? 500,
+			...(opts?.top_p ? { top_p: opts.top_p } : undefined),
+		};
+		return {
+			content: `curl 'https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions' \\
 	-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}" \\
 	-H 'Content-Type: application/json' \\
-	-d '{
-		"model": "${model.id}",
-		"messages": [{"role": "user", "content": "What is the capital of France?"}],
-		"max_tokens": 500,
-		"stream": false
-	}'
-`;
+	--data '{
+		"model": "${model.id}",
+		"messages": ${stringifyMessages(messages, {
+			sep: ",\n\t\t",
+			start: `[\n\t\t`,
+			end: `\n\t]`,
+			attributeKeyQuotes: true,
+			customContentEscaper: (str) => str.replace(/'/g, "'\\''"),
+		})},
+		${stringifyGenerationConfig(config, {
+			sep: ",\n    ",
+			start: "",
+			end: "",
+			attributeKeyQuotes: true,
+			attributeValueConnector: ": ",
+		})},
+		"stream": ${!!streaming}
+	}'`,
+		};
 	} else {
 		return snippetBasic(model, accessToken);
 	}
 };
 
-export const snippetImageTextToTextGeneration = (model: ModelDataMinimal, accessToken: string): string => {
+export const snippetImageTextToTextGeneration = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => {
 	if (model.tags.includes("conversational")) {
 		// Conversational model detected, so we display a code snippet that features the Messages API
-		return `curl 'https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions' \\
+		return {
+			content: `curl 'https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions' \\
 	-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}" \\
 	-H 'Content-Type: application/json' \\
 	-d '{
@@ -47,26 +84,34 @@ export const snippetImageTextToTextGeneration = (model: ModelDataMinimal, access
 		"max_tokens": 500,
 		"stream": false
 	}'
-`;
+`,
+		};
 	} else {
 		return snippetBasic(model, accessToken);
 	}
 };
 
-export const snippetZeroShotClassification = (model: ModelDataMinimal, accessToken: string): string =>
-	`curl https://api-inference.huggingface.co/models/${model.id} \\
+export const snippetZeroShotClassification = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
+	content: `curl https://api-inference.huggingface.co/models/${model.id} \\
 	-X POST \\
 	-d '{"inputs": ${getModelInputSnippet(model, true)}, "parameters": {"candidate_labels": ["refund", "legal", "faq"]}}' \\
 	-H 'Content-Type: application/json' \\
-	-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"`;
+	-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"`,
+});
 
-export const snippetFile = (model: ModelDataMinimal, accessToken: string): string =>
-	`curl https://api-inference.huggingface.co/models/${model.id} \\
+export const snippetFile = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
+	content: `curl https://api-inference.huggingface.co/models/${model.id} \\
 	-X POST \\
 	--data-binary '@${getModelInputSnippet(model, true, true)}' \\
-	-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"`;
+	-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"`,
+});
 
-export const curlSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal, accessToken: string) => string>> = {
+export const curlSnippets: Partial<
+	Record<
+		PipelineType,
+		(model: ModelDataMinimal, accessToken: string, opts?: Record<string, unknown>) => InferenceSnippet
+	>
+> = {
 	// Same order as in js/src/lib/interfaces/Types.ts
 	"text-classification": snippetBasic,
 	"token-classification": snippetBasic,
@@ -93,10 +138,10 @@ export const curlSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal
 	"image-segmentation": snippetFile,
 };
 
-export function getCurlInferenceSnippet(model: ModelDataMinimal, accessToken: string): string {
+export function getCurlInferenceSnippet(model: ModelDataMinimal, accessToken: string): InferenceSnippet {
 	return model.pipeline_tag && model.pipeline_tag in curlSnippets
-		? curlSnippets[model.pipeline_tag]?.(model, accessToken) ?? ""
-		: "";
+		? curlSnippets[model.pipeline_tag]?.(model, accessToken) ?? { content: "" }
+		: { content: "" };
 }
 
 export function hasCurlInferenceSnippet(model: Pick<ModelDataMinimal, "pipeline_tag">): boolean {
```
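To tie the changes together, a hypothetical call site. The model object, model id, and token are placeholders, and I am assuming `text-generation` still maps to `snippetTextGeneration` in the part of the `curlSnippets` table elided by the diff:

```ts
import { getCurlInferenceSnippet } from "./curl.js";
import type { ModelDataMinimal } from "./types.js";

// Placeholder model data; enough for the conversational text-generation path.
const model = {
	id: "<MODEL_ID>",
	pipeline_tag: "text-generation",
	tags: ["conversational"],
} as ModelDataMinimal;

const snippet = getCurlInferenceSnippet(model, "<API_TOKEN>");
// snippet.content is the curl command for /v1/chat/completions, with the default
// message, "max_tokens": 500, and "stream": true (streaming defaults to true).
console.log(snippet.content);
```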
