[Conversational snippet] Fix, refactor, & add tests #1003

Merged: 6 commits, Nov 4, 2024

1 change: 1 addition & 0 deletions packages/tasks/package.json
@@ -30,6 +30,7 @@
"watch": "npm-run-all --parallel watch:export watch:types",
"prepare": "pnpm run build",
"check": "tsc",
"test": "vitest run",
"inference-codegen": "tsx scripts/inference-codegen.ts && prettier --write src/tasks/*/inference.ts",
"inference-tgi-import": "tsx scripts/inference-tgi-import.ts && prettier --write src/tasks/text-generation/spec/*.json && prettier --write src/tasks/chat-completion/spec/*.json",
"inference-tei-import": "tsx scripts/inference-tei-import.ts && prettier --write src/tasks/feature-extraction/spec/*.json"
78 changes: 27 additions & 51 deletions packages/tasks/src/snippets/common.ts
@@ -1,63 +1,39 @@
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks";

export interface StringifyMessagesOptions {
sep: string;
start: string;
end: string;
attributeKeyQuotes?: boolean;
customContentEscaper?: (str: string) => string;
}

export function stringifyMessages(messages: ChatCompletionInputMessage[], opts: StringifyMessagesOptions): string {
const keyRole = opts.attributeKeyQuotes ? `"role"` : "role";
const keyContent = opts.attributeKeyQuotes ? `"content"` : "content";

const messagesStringified = messages.map(({ role, content }) => {
if (typeof content === "string") {
content = JSON.stringify(content).slice(1, -1);
if (opts.customContentEscaper) {
content = opts.customContentEscaper(content);
}
return `{ ${keyRole}: "${role}", ${keyContent}: "${content}" }`;
} else {
2;
content = content.map(({ image_url, text, type }) => ({
type,
image_url,
...(text ? { text: JSON.stringify(text).slice(1, -1) } : undefined),
}));
content = JSON.stringify(content).slice(1, -1);
if (opts.customContentEscaper) {
content = opts.customContentEscaper(content);
}
return `{ ${keyRole}: "${role}", ${keyContent}: [${content}] }`;
}
});

return opts.start + messagesStringified.join(opts.sep) + opts.end;
export function stringifyMessages(
[Review comment from Collaborator Author: redo stringifyMessages function to be cleaner]
messages: ChatCompletionInputMessage[],
opts?: {
indent?: string;
attributeKeyQuotes?: boolean;
customContentEscaper?: (str: string) => string;
}
): string {
let messagesStr = JSON.stringify(messages, null, "\t");
if (opts?.indent) {
messagesStr = messagesStr.replaceAll("\n", `\n${opts.indent}`);
}
if (!opts?.attributeKeyQuotes) {
messagesStr = messagesStr.replace(/"([^"]+)":/g, "$1:");
}
if (opts?.customContentEscaper) {
messagesStr = opts.customContentEscaper(messagesStr);
}
return messagesStr;
}
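
For context, here is a minimal sketch of what the refactored helper produces; the input message and options are assumed, and the output matches the expectations in the new spec files below.

// Hedged usage sketch: stringifyMessages with the options js.ts passes.
const messagesStr = stringifyMessages(
	[{ role: "user", content: "What is the capital of France?" }],
	{ indent: "\t" }
);
// JSON.stringify(messages, null, "\t") pretty-prints the array, the indent
// option prefixes every line break with one extra "\t", and the regex strips
// the quotes around keys, giving:
//
// [
// 		{
// 			role: "user",
// 			content: "What is the capital of France?"
// 		}
// 	]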

type PartialGenerationParameters = Partial<Pick<GenerationParameters, "temperature" | "max_tokens" | "top_p">>;

export interface StringifyGenerationConfigOptions {
sep: string;
start: string;
end: string;
attributeValueConnector: string;
attributeKeyQuotes?: boolean;
}

export function stringifyGenerationConfig(
config: PartialGenerationParameters,
opts: StringifyGenerationConfigOptions
opts: {
indent: string;
attributeValueConnector: string;
attributeKeyQuotes?: boolean;
}
): string {
const quote = opts.attributeKeyQuotes ? `"` : "";

return (
opts.start +
Object.entries(config)
.map(([key, val]) => `${quote}${key}${quote}${opts.attributeValueConnector}${val}`)
.join(opts.sep) +
opts.end
);
return Object.entries(config)
.map(([key, val]) => `${quote}${key}${quote}${opts.attributeValueConnector}${val}`)
.join(`,${opts.indent}`);
}
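
And a similar sketch for the config helper, again with the options js.ts passes (values assumed):

// Hedged usage sketch: entries are joined with "," plus the indent string.
const configStr = stringifyGenerationConfig(
	{ temperature: 0.5, max_tokens: 500 },
	{ indent: "\n\t", attributeValueConnector: ": " }
);
// configStr === `temperature: 0.5,\n\tmax_tokens: 500`; with
// attributeKeyQuotes set (as in curl.ts), the keys would render as
// "temperature" and "max_tokens" instead.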
68 changes: 68 additions & 0 deletions packages/tasks/src/snippets/curl.spec.ts
@@ -0,0 +1,68 @@
import type { ModelDataMinimal } from "./types";
import { describe, expect, it } from "vitest";
import { snippetTextGeneration } from "./curl";

describe("inference API snippets", () => {
it("conversational llm", async () => {
const model: ModelDataMinimal = {
id: "meta-llama/Llama-3.1-8B-Instruct",
pipeline_tag: "text-generation",
tags: ["conversational"],
inference: "",
};
const snippet = snippetTextGeneration(model, "api_token");

expect(snippet.content)
.toEqual(`curl 'https://api-inference.huggingface.co/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions' \\
-H "Authorization: Bearer api_token" \\
-H 'Content-Type: application/json' \\
--data '{
"model": "meta-llama/Llama-3.1-8B-Instruct",
"messages": [
{
"role": "user",
"content": "What is the capital of France?"
}
],
"max_tokens": 500,
"stream": true
}'`);
});

it("conversational vlm", async () => {
const model: ModelDataMinimal = {
id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
pipeline_tag: "image-text-to-text",
tags: ["conversational"],
inference: "",
};
const snippet = snippetTextGeneration(model, "api_token");

expect(snippet.content)
.toEqual(`curl 'https://api-inference.huggingface.co/models/meta-llama/Llama-3.2-11B-Vision-Instruct/v1/chat/completions' \\
-H "Authorization: Bearer api_token" \\
-H 'Content-Type: application/json' \\
--data '{
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Describe this image in one sentence."
},
{
"type": "image_url",
"image_url": {
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
}
}
]
}
],
"max_tokens": 500,
"stream": true
}'`);
});
});
8 changes: 2 additions & 6 deletions packages/tasks/src/snippets/curl.ts
@@ -41,16 +41,12 @@ export const snippetTextGeneration = (
--data '{
"model": "${model.id}",
"messages": ${stringifyMessages(messages, {
sep: ",\n\t\t",
start: `[\n\t\t`,
end: `\n\t]`,
indent: "\t",
attributeKeyQuotes: true,
customContentEscaper: (str) => str.replace(/'/g, "'\\''"),
})},
${stringifyGenerationConfig(config, {
sep: ",\n ",
start: "",
end: "",
indent: "\n ",
attributeKeyQuotes: true,
attributeValueConnector: ": ",
})},
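One detail worth calling out in the curl snippet: the JSON body is embedded in a single-quoted shell string, so the customContentEscaper rewrites any literal single quote using the standard close-escape-reopen idiom. A minimal sketch with an assumed input:

// A ' cannot appear inside a single-quoted shell string; the escaper closes
// the string, emits an escaped quote, and reopens it: ' becomes '\''
const escaped = "It's sunny".replace(/'/g, "'\\''");
// escaped === `It'\''s sunny`, which the shell reassembles as: It's sunny
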
1 change: 1 addition & 0 deletions packages/tasks/src/snippets/inputs.ts
@@ -128,6 +128,7 @@ const modelInputSnippets: {
"tabular-classification": inputsTabularPrediction,
"text-classification": inputsTextClassification,
"text-generation": inputsTextGeneration,
"image-text-to-text": inputsTextGeneration,
[Review comment from Collaborator Author: This fix was needed; the missing mapping caused an error on the Hub (see Slack for context). I've created test files covering the snippets for both LLMs & VLMs.]

"text-to-image": inputsTextToImage,
"text-to-speech": inputsTextToSpeech,
"text-to-audio": inputsTextToAudio,
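For context on the fix above: the modelInputSnippets record is what resolves a model's example inputs by pipeline_tag, so a missing key means no example inputs for that task. A hedged sketch of the lookup (function name and signature assumed):

// Before this change, "image-text-to-text" had no entry here, so VLM models
// resolved to undefined example inputs, which surfaced as the Hub error
// mentioned in the review comment.
function lookupInputSnippet(model: { pipeline_tag: string }) {
	return modelInputSnippets[model.pipeline_tag];
}
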
86 changes: 86 additions & 0 deletions packages/tasks/src/snippets/js.spec.ts
@@ -0,0 +1,86 @@
import type { InferenceSnippet, ModelDataMinimal } from "./types";
import { describe, expect, it } from "vitest";
import { snippetTextGeneration } from "./js";

describe("inference API snippets", () => {
it("conversational llm", async () => {
const model: ModelDataMinimal = {
id: "meta-llama/Llama-3.1-8B-Instruct",
pipeline_tag: "text-generation",
tags: ["conversational"],
inference: "",
};
const snippet = snippetTextGeneration(model, "api_token") as InferenceSnippet[];

expect(snippet[0].content).toEqual(`import { HfInference } from "@huggingface/inference"

const client = new HfInference("api_token")

let out = "";

const stream = client.chatCompletionStream({
model: "meta-llama/Llama-3.1-8B-Instruct",
messages: [
{
role: "user",
content: "What is the capital of France?"
}
],
max_tokens: 500
});

for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0) {
const newContent = chunk.choices[0].delta.content;
out += newContent;
console.log(newContent);
}
}`);
});

it("conversational vlm", async () => {
const model: ModelDataMinimal = {
id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
pipeline_tag: "image-text-to-text",
tags: ["conversational"],
inference: "",
};
const snippet = snippetTextGeneration(model, "api_token") as InferenceSnippet[];

expect(snippet[0].content).toEqual(`import { HfInference } from "@huggingface/inference"

const client = new HfInference("api_token")

let out = "";

const stream = client.chatCompletionStream({
model: "meta-llama/Llama-3.2-11B-Vision-Instruct",
messages: [
{
role: "user",
content: [
{
type: "text",
text: "Describe this image in one sentence."
},
{
type: "image_url",
image_url: {
url: "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
}
}
]
}
],
max_tokens: 500
});

for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0) {
const newContent = chunk.choices[0].delta.content;
out += newContent;
console.log(newContent);
}
}`);
});
});
6 changes: 2 additions & 4 deletions packages/tasks/src/snippets/js.ts
@@ -42,17 +42,15 @@ export const snippetTextGeneration = (
const streaming = opts?.streaming ?? true;
const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
const messages = opts?.messages ?? exampleMessages;
const messagesStr = stringifyMessages(messages, { sep: ",\n\t\t", start: "[\n\t\t", end: "\n\t]" });
const messagesStr = stringifyMessages(messages, { indent: "\t" });

const config = {
...(opts?.temperature ? { temperature: opts.temperature } : undefined),
max_tokens: opts?.max_tokens ?? 500,
...(opts?.top_p ? { top_p: opts.top_p } : undefined),
};
const configStr = stringifyGenerationConfig(config, {
sep: ",\n\t",
start: "",
end: "",
indent: "\n\t",
attributeValueConnector: ": ",
});

78 changes: 78 additions & 0 deletions packages/tasks/src/snippets/python.spec.ts
@@ -0,0 +1,78 @@
import type { ModelDataMinimal } from "./types";
import { describe, expect, it } from "vitest";
import { snippetConversational } from "./python";

describe("inference API snippets", () => {
it("conversational llm", async () => {
const model: ModelDataMinimal = {
id: "meta-llama/Llama-3.1-8B-Instruct",
pipeline_tag: "text-generation",
tags: ["conversational"],
inference: "",
};
const snippet = snippetConversational(model, "api_token");

expect(snippet[0].content).toEqual(`from huggingface_hub import InferenceClient

client = InferenceClient(api_key="api_token")

messages = [
{
"role": "user",
"content": "What is the capital of France?"
}
]

stream = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
messages=messages,
max_tokens=500,
stream=True
)

for chunk in stream:
print(chunk.choices[0].delta.content, end="")`);
});

it("conversational vlm", async () => {
const model: ModelDataMinimal = {
id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
pipeline_tag: "image-text-to-text",
tags: ["conversational"],
inference: "",
};
const snippet = snippetConversational(model, "api_token");

expect(snippet[0].content).toEqual(`from huggingface_hub import InferenceClient

client = InferenceClient(api_key="api_token")

messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Describe this image in one sentence."
},
{
"type": "image_url",
"image_url": {
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
}
}
]
}
]

stream = client.chat.completions.create(
model="meta-llama/Llama-3.2-11B-Vision-Instruct",
messages=messages,
max_tokens=500,
stream=True
)

for chunk in stream:
print(chunk.choices[0].delta.content, end="")`);
});
});
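
Taken together, the three new spec files pin the exact rendered snippet text for one LLM and one VLM across curl, JS, and Python, so any future formatting change in stringifyMessages or stringifyGenerationConfig will surface as a test failure rather than a silent regression on the Hub.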