
Commit ea047d1

mishig25 and nbroad1881 authored

[Conversational snippet] Fix, refactor, & add tests (#1003)

### Description

[Conversational snippet] Fix, refactor, & add tests. Closes #1010.

1. Fixes an error flagged in review: https://github.com/huggingface/huggingface.js/pull/1003/files#r1824269481
2. Redoes https://github.com/huggingface/huggingface.js/pull/1003/files#r1824270922
3. Adds tests

Co-authored-by: Nicholas Broad <[email protected]>
1 parent 138366b commit ea047d1

9 files changed: +269 −72 lines changed

packages/tasks/package.json

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@
 		"watch": "npm-run-all --parallel watch:export watch:types",
 		"prepare": "pnpm run build",
 		"check": "tsc",
+		"test": "vitest run",
 		"inference-codegen": "tsx scripts/inference-codegen.ts && prettier --write src/tasks/*/inference.ts",
 		"inference-tgi-import": "tsx scripts/inference-tgi-import.ts && prettier --write src/tasks/text-generation/spec/*.json && prettier --write src/tasks/chat-completion/spec/*.json",
 		"inference-tei-import": "tsx scripts/inference-tei-import.ts && prettier --write src/tasks/feature-extraction/spec/*.json"

packages/tasks/src/snippets/common.ts

Lines changed: 27 additions & 51 deletions
@@ -1,63 +1,39 @@
 import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks";

-export interface StringifyMessagesOptions {
-	sep: string;
-	start: string;
-	end: string;
-	attributeKeyQuotes?: boolean;
-	customContentEscaper?: (str: string) => string;
-}
-
-export function stringifyMessages(messages: ChatCompletionInputMessage[], opts: StringifyMessagesOptions): string {
-	const keyRole = opts.attributeKeyQuotes ? `"role"` : "role";
-	const keyContent = opts.attributeKeyQuotes ? `"content"` : "content";
-
-	const messagesStringified = messages.map(({ role, content }) => {
-		if (typeof content === "string") {
-			content = JSON.stringify(content).slice(1, -1);
-			if (opts.customContentEscaper) {
-				content = opts.customContentEscaper(content);
-			}
-			return `{ ${keyRole}: "${role}", ${keyContent}: "${content}" }`;
-		} else {
-			2;
-			content = content.map(({ image_url, text, type }) => ({
-				type,
-				image_url,
-				...(text ? { text: JSON.stringify(text).slice(1, -1) } : undefined),
-			}));
-			content = JSON.stringify(content).slice(1, -1);
-			if (opts.customContentEscaper) {
-				content = opts.customContentEscaper(content);
-			}
-			return `{ ${keyRole}: "${role}", ${keyContent}: [${content}] }`;
-		}
-	});
-
-	return opts.start + messagesStringified.join(opts.sep) + opts.end;
+export function stringifyMessages(
+	messages: ChatCompletionInputMessage[],
+	opts?: {
+		indent?: string;
+		attributeKeyQuotes?: boolean;
+		customContentEscaper?: (str: string) => string;
+	}
+): string {
+	let messagesStr = JSON.stringify(messages, null, "\t");
+	if (opts?.indent) {
+		messagesStr = messagesStr.replaceAll("\n", `\n${opts.indent}`);
+	}
+	if (!opts?.attributeKeyQuotes) {
+		messagesStr = messagesStr.replace(/"([^"]+)":/g, "$1:");
+	}
+	if (opts?.customContentEscaper) {
+		messagesStr = opts.customContentEscaper(messagesStr);
+	}
+	return messagesStr;
 }

 type PartialGenerationParameters = Partial<Pick<GenerationParameters, "temperature" | "max_tokens" | "top_p">>;

-export interface StringifyGenerationConfigOptions {
-	sep: string;
-	start: string;
-	end: string;
-	attributeValueConnector: string;
-	attributeKeyQuotes?: boolean;
-}
-
 export function stringifyGenerationConfig(
 	config: PartialGenerationParameters,
-	opts: StringifyGenerationConfigOptions
+	opts: {
+		indent: string;
+		attributeValueConnector: string;
+		attributeKeyQuotes?: boolean;
+	}
 ): string {
 	const quote = opts.attributeKeyQuotes ? `"` : "";

-	return (
-		opts.start +
-		Object.entries(config)
-			.map(([key, val]) => `${quote}${key}${quote}${opts.attributeValueConnector}${val}`)
-			.join(opts.sep) +
-		opts.end
-	);
+	return Object.entries(config)
+		.map(([key, val]) => `${quote}${key}${quote}${opts.attributeValueConnector}${val}`)
+		.join(`,${opts.indent}`);
 }
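
For reference, a minimal usage sketch of the refactored helpers (not part of the diff; the relative import and example inputs are illustrative):

import { stringifyGenerationConfig, stringifyMessages } from "./common";

// Pretty-prints via JSON.stringify(messages, null, "\t"), shifts every line by
// the extra `indent`, and (since attributeKeyQuotes is unset) strips the
// quotes around object keys:
const messagesStr = stringifyMessages([{ role: "user", content: "What is the capital of France?" }], {
	indent: "\t",
});
// messagesStr:
// [
// 		{
// 			role: "user",
// 			content: "What is the capital of France?"
// 		}
// 	]

// Config entries are joined with "," followed by the given indent:
const configStr = stringifyGenerationConfig(
	{ temperature: 0.5, max_tokens: 500 },
	{ indent: "\n\t", attributeValueConnector: ": " }
);
// configStr === `temperature: 0.5,\n\tmax_tokens: 500`
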
packages/tasks/src/snippets/curl.spec.ts (new file)

Lines changed: 68 additions & 0 deletions

import type { ModelDataMinimal } from "./types";
import { describe, expect, it } from "vitest";
import { snippetTextGeneration } from "./curl";

describe("inference API snippets", () => {
	it("conversational llm", async () => {
		const model: ModelDataMinimal = {
			id: "meta-llama/Llama-3.1-8B-Instruct",
			pipeline_tag: "text-generation",
			tags: ["conversational"],
			inference: "",
		};
		const snippet = snippetTextGeneration(model, "api_token");

		expect(snippet.content)
			.toEqual(`curl 'https://api-inference.huggingface.co/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions' \\
-H "Authorization: Bearer api_token" \\
-H 'Content-Type: application/json' \\
--data '{
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "messages": [
		{
			"role": "user",
			"content": "What is the capital of France?"
		}
	],
    "max_tokens": 500,
    "stream": true
}'`);
	});

	it("conversational vlm", async () => {
		const model: ModelDataMinimal = {
			id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
			pipeline_tag: "image-text-to-text",
			tags: ["conversational"],
			inference: "",
		};
		const snippet = snippetTextGeneration(model, "api_token");

		expect(snippet.content)
			.toEqual(`curl 'https://api-inference.huggingface.co/models/meta-llama/Llama-3.2-11B-Vision-Instruct/v1/chat/completions' \\
-H "Authorization: Bearer api_token" \\
-H 'Content-Type: application/json' \\
--data '{
    "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
    "messages": [
		{
			"role": "user",
			"content": [
				{
					"type": "text",
					"text": "Describe this image in one sentence."
				},
				{
					"type": "image_url",
					"image_url": {
						"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
					}
				}
			]
		}
	],
    "max_tokens": 500,
    "stream": true
}'`);
	});
});

packages/tasks/src/snippets/curl.ts

Lines changed: 2 additions & 6 deletions
@@ -41,16 +41,12 @@ export const snippetTextGeneration = (
 --data '{
     "model": "${model.id}",
     "messages": ${stringifyMessages(messages, {
-		sep: ",\n\t\t",
-		start: `[\n\t\t`,
-		end: `\n\t]`,
+		indent: "\t",
 		attributeKeyQuotes: true,
 		customContentEscaper: (str) => str.replace(/'/g, "'\\''"),
 	})},
     ${stringifyGenerationConfig(config, {
-		sep: ",\n    ",
-		start: "",
-		end: "",
+		indent: "\n    ",
 		attributeKeyQuotes: true,
 		attributeValueConnector: ": ",
 	})},
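
A side note on the customContentEscaper passed above: the JSON payload is embedded in a single-quoted shell string, so literal single quotes must be escaped. A quick sketch of the effect (the helper name and input are illustrative, not from the diff):

// In POSIX shells a literal ' cannot appear inside a '-quoted string; it is
// written as '\'' (close the quote, emit an escaped ', reopen the quote).
const escapeForSingleQuotes = (str: string) => str.replace(/'/g, "'\\''");

console.log(escapeForSingleQuotes(`{ "content": "It's sunny" }`));
// logs: { "content": "It'\''s sunny" }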

packages/tasks/src/snippets/inputs.ts

Lines changed: 1 addition & 0 deletions
@@ -128,6 +128,7 @@ const modelInputSnippets: {
 	"tabular-classification": inputsTabularPrediction,
 	"text-classification": inputsTextClassification,
 	"text-generation": inputsTextGeneration,
+	"image-text-to-text": inputsTextGeneration,
 	"text-to-image": inputsTextToImage,
 	"text-to-speech": inputsTextToSpeech,
 	"text-to-audio": inputsTextToAudio,
packages/tasks/src/snippets/js.spec.ts (new file)

Lines changed: 86 additions & 0 deletions

import type { InferenceSnippet, ModelDataMinimal } from "./types";
import { describe, expect, it } from "vitest";
import { snippetTextGeneration } from "./js";

describe("inference API snippets", () => {
	it("conversational llm", async () => {
		const model: ModelDataMinimal = {
			id: "meta-llama/Llama-3.1-8B-Instruct",
			pipeline_tag: "text-generation",
			tags: ["conversational"],
			inference: "",
		};
		const snippet = snippetTextGeneration(model, "api_token") as InferenceSnippet[];

		expect(snippet[0].content).toEqual(`import { HfInference } from "@huggingface/inference"

const client = new HfInference("api_token")

let out = "";

const stream = client.chatCompletionStream({
	model: "meta-llama/Llama-3.1-8B-Instruct",
	messages: [
		{
			role: "user",
			content: "What is the capital of France?"
		}
	],
	max_tokens: 500
});

for await (const chunk of stream) {
	if (chunk.choices && chunk.choices.length > 0) {
		const newContent = chunk.choices[0].delta.content;
		out += newContent;
		console.log(newContent);
	}
}`);
	});

	it("conversational vlm", async () => {
		const model: ModelDataMinimal = {
			id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
			pipeline_tag: "image-text-to-text",
			tags: ["conversational"],
			inference: "",
		};
		const snippet = snippetTextGeneration(model, "api_token") as InferenceSnippet[];

		expect(snippet[0].content).toEqual(`import { HfInference } from "@huggingface/inference"

const client = new HfInference("api_token")

let out = "";

const stream = client.chatCompletionStream({
	model: "meta-llama/Llama-3.2-11B-Vision-Instruct",
	messages: [
		{
			role: "user",
			content: [
				{
					type: "text",
					text: "Describe this image in one sentence."
				},
				{
					type: "image_url",
					image_url: {
						url: "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
					}
				}
			]
		}
	],
	max_tokens: 500
});

for await (const chunk of stream) {
	if (chunk.choices && chunk.choices.length > 0) {
		const newContent = chunk.choices[0].delta.content;
		out += newContent;
		console.log(newContent);
	}
}`);
	});
});

packages/tasks/src/snippets/js.ts

Lines changed: 2 additions & 4 deletions
@@ -42,17 +42,15 @@ export const snippetTextGeneration = (
 	const streaming = opts?.streaming ?? true;
 	const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
 	const messages = opts?.messages ?? exampleMessages;
-	const messagesStr = stringifyMessages(messages, { sep: ",\n\t\t", start: "[\n\t\t", end: "\n\t]" });
+	const messagesStr = stringifyMessages(messages, { indent: "\t" });

 	const config = {
 		...(opts?.temperature ? { temperature: opts.temperature } : undefined),
 		max_tokens: opts?.max_tokens ?? 500,
 		...(opts?.top_p ? { top_p: opts.top_p } : undefined),
 	};
 	const configStr = stringifyGenerationConfig(config, {
-		sep: ",\n\t",
-		start: "",
-		end: "",
+		indent: "\n\t",
 		attributeValueConnector: ": ",
 	});

packages/tasks/src/snippets/python.spec.ts (new file)

Lines changed: 78 additions & 0 deletions

import type { ModelDataMinimal } from "./types";
import { describe, expect, it } from "vitest";
import { snippetConversational } from "./python";

describe("inference API snippets", () => {
	it("conversational llm", async () => {
		const model: ModelDataMinimal = {
			id: "meta-llama/Llama-3.1-8B-Instruct",
			pipeline_tag: "text-generation",
			tags: ["conversational"],
			inference: "",
		};
		const snippet = snippetConversational(model, "api_token");

		expect(snippet[0].content).toEqual(`from huggingface_hub import InferenceClient

client = InferenceClient(api_key="api_token")

messages = [
	{
		"role": "user",
		"content": "What is the capital of France?"
	}
]

stream = client.chat.completions.create(
	model="meta-llama/Llama-3.1-8B-Instruct",
	messages=messages,
	max_tokens=500,
	stream=True
)

for chunk in stream:
	print(chunk.choices[0].delta.content, end="")`);
	});

	it("conversational vlm", async () => {
		const model: ModelDataMinimal = {
			id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
			pipeline_tag: "image-text-to-text",
			tags: ["conversational"],
			inference: "",
		};
		const snippet = snippetConversational(model, "api_token");

		expect(snippet[0].content).toEqual(`from huggingface_hub import InferenceClient

client = InferenceClient(api_key="api_token")

messages = [
	{
		"role": "user",
		"content": [
			{
				"type": "text",
				"text": "Describe this image in one sentence."
			},
			{
				"type": "image_url",
				"image_url": {
					"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
				}
			}
		]
	}
]

stream = client.chat.completions.create(
	model="meta-llama/Llama-3.2-11B-Vision-Instruct",
	messages=messages,
	max_tokens=500,
	stream=True
)

for chunk in stream:
	print(chunk.choices[0].delta.content, end="")`);
	});
});
