Skip to content

Commit a2057d3

Browse files
committed
curl
1 parent 49042ca commit a2057d3

File tree

1 file changed

+54
-10
lines changed

1 file changed

+54
-10
lines changed

packages/tasks/src/snippets/curl.ts

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import type { PipelineType } from "../pipelines.js";
2+
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
23
import { getModelInputSnippet } from "./inputs.js";
3-
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
4+
import type {
5+
GenerationConfigFormatter,
6+
GenerationMessagesFormatter,
7+
InferenceSnippet,
8+
ModelDataMinimal,
9+
} from "./types.js";
410

511
export const snippetBasic = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
612
content: `curl https://api-inference.huggingface.co/models/${model.id} \\
@@ -10,20 +16,58 @@ export const snippetBasic = (model: ModelDataMinimal, accessToken: string): Infe
1016
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"`,
1117
});
1218

13-
export const snippetTextGeneration = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => {
19+
const formatGenerationMessages: GenerationMessagesFormatter = ({ messages, sep, start, end }) =>
20+
start +
21+
messages
22+
.map(({ role, content }) => {
23+
// escape single quotes since single quotes is used to define http post body inside curl requests
24+
// TODO: handle the case below
25+
content = content?.replace(/'/g, "'\\''");
26+
return `{ "role": "${role}", "content": "${content}" }`;
27+
})
28+
.join(sep) +
29+
end;
30+
31+
const formatGenerationConfig: GenerationConfigFormatter = ({ config, sep, start, end }) =>
32+
start +
33+
Object.entries(config)
34+
.map(([key, val]) => `"${key}": ${val}`)
35+
.join(sep) +
36+
end;
37+
38+
export const snippetTextGeneration = (
39+
model: ModelDataMinimal,
40+
accessToken: string,
41+
opts?: {
42+
streaming?: boolean;
43+
messages?: ChatCompletionInputMessage[];
44+
temperature?: GenerationParameters["temperature"];
45+
max_tokens?: GenerationParameters["max_tokens"];
46+
top_p?: GenerationParameters["top_p"];
47+
}
48+
): InferenceSnippet => {
1449
if (model.tags.includes("conversational")) {
1550
// Conversational model detected, so we display a code snippet that features the Messages API
51+
const streaming = opts?.streaming ?? true;
52+
const messages: ChatCompletionInputMessage[] = opts?.messages ?? [
53+
{ role: "user", content: "What is the capital of France?" },
54+
];
55+
56+
const config = {
57+
temperature: opts?.temperature,
58+
max_tokens: opts?.max_tokens ?? 500,
59+
top_p: opts?.top_p,
60+
};
1661
return {
1762
content: `curl 'https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions' \\
1863
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}" \\
1964
-H 'Content-Type: application/json' \\
20-
-d '{
21-
"model": "${model.id}",
22-
"messages": [{"role": "user", "content": "What is the capital of France?"}],
23-
"max_tokens": 500,
24-
"stream": false
25-
}'
26-
`,
65+
--data '{
66+
"model": "${model.id}",
67+
"messages": ${formatGenerationMessages({ messages, sep: ",\n ", start: `[\n `, end: `\n]` })},
68+
${formatGenerationConfig({ config, sep: ",\n ", start: "", end: "" })},
69+
"stream": ${!!streaming}
70+
}'`,
2771
};
2872
} else {
2973
return snippetBasic(model, accessToken);
@@ -76,7 +120,7 @@ export const snippetFile = (model: ModelDataMinimal, accessToken: string): Infer
76120
export const curlSnippets: Partial<
77121
Record<
78122
PipelineType,
79-
(model: ModelDataMinimal, accessToken: string, opts?: Record<string, string | boolean | number>) => InferenceSnippet
123+
(model: ModelDataMinimal, accessToken: string, opts?: Record<string, unknown>) => InferenceSnippet
80124
>
81125
> = {
82126
// Same order as in js/src/lib/interfaces/Types.ts

0 commit comments

Comments
 (0)