
Add debug script to check inference snippets + update Python text-to-image snippets #994


Status: Closed · wants to merge 12 commits
42 changes: 42 additions & 0 deletions .github/workflows/inference-check-snippets.yml
@@ -0,0 +1,42 @@
name: Inference check snippets
on:
  pull_request:
    paths:
      - "packages/tasks/src/snippets/**"
      - ".github/workflows/inference-check-snippets.yml"

jobs:
  check-snippets:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v3

      - run: corepack enable

      - uses: actions/setup-node@v3
        with:
          node-version: "20"
          cache: "pnpm"
          cache-dependency-path: "**/pnpm-lock.yaml"

      - run: |
          cd packages/tasks
          pnpm install

      # TODO: Find a way to run on all pipeline tags
      # TODO: print snippet only if it has changed since the last commit on main (?)
      # TODO: (even better: automated message on the PR with diff)
      - name: Print text-to-image snippets
        run: |
          cd packages/tasks
          pnpm run check-snippets --pipeline-tag="text-to-image"

      - name: Print simple text-generation snippets
        run: |
          cd packages/tasks
          pnpm run check-snippets --pipeline-tag="text-generation"

      - name: Print conversational text-generation snippets
        run: |
          cd packages/tasks
          pnpm run check-snippets --pipeline-tag="text-generation" --tags="conversational"
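The first TODO above could also be handled inside the script rather than with one workflow step per tag. A minimal sketch, assuming PIPELINE_DATA is exported from src/pipelines alongside PipelineType (the generators would be invoked exactly as in check-snippets.ts below); this is illustrative, not part of the PR:

// check-all-snippets.ts (hypothetical): iterate over every known pipeline tag
// instead of hard-coding one workflow step per tag.
import { PIPELINE_DATA } from "../src/pipelines";
import type { PipelineType } from "../src/pipelines";

for (const pipelineTag of Object.keys(PIPELINE_DATA) as PipelineType[]) {
	console.log(`\n===== ${pipelineTag} =====`);
	// ...generate and print the curl/python/js snippets for this tag,
	// exactly as check-snippets.ts does for a single tag.
}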
3 changes: 2 additions & 1 deletion packages/tasks/package.json
@@ -32,7 +32,8 @@
 	"check": "tsc",
 	"inference-codegen": "tsx scripts/inference-codegen.ts && prettier --write src/tasks/*/inference.ts",
 	"inference-tgi-import": "tsx scripts/inference-tgi-import.ts && prettier --write src/tasks/text-generation/spec/*.json && prettier --write src/tasks/chat-completion/spec/*.json",
-	"inference-tei-import": "tsx scripts/inference-tei-import.ts && prettier --write src/tasks/feature-extraction/spec/*.json"
+	"inference-tei-import": "tsx scripts/inference-tei-import.ts && prettier --write src/tasks/feature-extraction/spec/*.json",
+	"check-snippets": "tsx scripts/check-snippets.ts"
 },
 "type": "module",
 "files": [
55 changes: 55 additions & 0 deletions packages/tasks/scripts/check-snippets.ts
@@ -0,0 +1,55 @@
[Review thread on scripts/check-snippets.ts]

mishig25 (Collaborator) commented on Oct 31, 2024:

    I've implemented tests in https://github.com/huggingface/huggingface.js/pull/1003/files using the Vitest convention of xyz.spec.ts files (which run on pnpm test). I think we should just put more tests into xyz.spec.ts files rather than creating a custom mechanism of check-snippets.ts & inference-check-snippets.yml. Or am I missing some necessary details? Wdyt?

The PR author (Contributor) replied:

    Much better, yes!

The PR author (Contributor) replied:

    I went with this solution because I was frustrated by how difficult it was to debug things locally, but automated tests feel much more natural now that you mention it ^^

/*
 * Generates inference snippets as they would be shown on the Hub for curl, JS and Python.
 * Snippets are only printed to the terminal, to make it easier to debug changes to the snippets.
 *
 * Usage:
 *   pnpm run check-snippets --pipeline-tag="text-generation" --tags="conversational"
 *   pnpm run check-snippets --pipeline-tag="image-text-to-text" --tags="conversational"
 *   pnpm run check-snippets --pipeline-tag="text-to-image"
 *
 * This script is meant for debugging purposes only.
 */
import { python, curl, js } from "../src/snippets/index";
import type { InferenceSnippet, ModelDataMinimal } from "../src/snippets/types";
import type { PipelineType } from "../src/pipelines";

// Parse command-line arguments of the form --key=value
const args = process.argv.slice(2).reduce(
	(acc, arg) => {
		const [key, value] = arg.split("=");
		acc[key.replace("--", "")] = value;
		return acc;
	},
	{} as { [key: string]: string }
);
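// At this point `args` maps flag names to raw string values, e.g.
// ["--pipeline-tag=text-to-image"] => { "pipeline-tag": "text-to-image" }.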

const accessToken = "hf_**********";
const pipelineTag = (args["pipeline-tag"] || "text-generation") as PipelineType;
const tags = (args["tags"] ?? "").split(",");

// Placeholder model metadata: the id and inference status are dummy values;
// only pipeline_tag and tags influence which snippet template is rendered.
const modelMinimal: ModelDataMinimal = {
	id: "llama-6-1720B-Instruct",
	pipeline_tag: pipelineTag,
	tags: tags,
	inference: "****",
};

const printSnippets = (snippets: InferenceSnippet | InferenceSnippet[], language: string) => {
	const snippetArray = Array.isArray(snippets) ? snippets : [snippets];
	snippetArray.forEach((snippet) => {
		// Print a colored "<language> <client>" header, then the snippet in a fenced block
		console.log(`\n\x1b[33m${language} ${snippet.client}\x1b[0m`);
		console.log(`\n\`\`\`${language}\n${snippet.content}\n\`\`\`\n`);
	});
};
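// Example of what one printed snippet looks like in the terminal
// (header is ANSI yellow; the output shown here is illustrative):
//
//   python huggingface_hub
//
//   ```python
//   from huggingface_hub import InferenceClient
//   ...
//   ```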

const generateAndPrintSnippets = (
	generator: (model: ModelDataMinimal, token: string) => InferenceSnippet | InferenceSnippet[],
	language: string
) => {
	const snippets = generator(modelMinimal, accessToken);
	printSnippets(snippets, language);
};

generateAndPrintSnippets(curl.getCurlInferenceSnippet, "curl");
generateAndPrintSnippets(python.getPythonInferenceSnippet, "python");
generateAndPrintSnippets(js.getJsInferenceSnippet, "js");
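Following mishig25's suggestion in the review thread above, these checks could also live in a Vitest spec file rather than a standalone script. A minimal sketch, assuming the package picks up xyz.spec.ts files on pnpm test; the model id and expectations are illustrative only:

// python-snippets.spec.ts (hypothetical)
import { describe, expect, it } from "vitest";

import { python } from "../src/snippets/index";
import type { ModelDataMinimal } from "../src/snippets/types";

describe("python text-to-image snippets", () => {
	it("returns both a requests and a huggingface_hub snippet", () => {
		const model: ModelDataMinimal = {
			id: "black-forest-labs/FLUX.1-dev", // hypothetical model id
			pipeline_tag: "text-to-image",
			tags: [],
			inference: "",
		};
		const snippets = python.getPythonInferenceSnippet(model, "hf_**********");
		const clients = (Array.isArray(snippets) ? snippets : [snippets]).map((s) => s.client);
		expect(clients).toContain("requests");
		expect(clients).toContain("huggingface_hub");
	});
});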
27 changes: 24 additions & 3 deletions packages/tasks/src/snippets/python.ts
@@ -4,6 +4,11 @@ import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
 import { getModelInputSnippet } from "./inputs.js";
 import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
 
+const snippetImportInferenceClient = (model: ModelDataMinimal, accessToken: string): string =>
+	`from huggingface_hub import InferenceClient
+
+client = InferenceClient("${model.id}", token="${accessToken || "{API_TOKEN}"}")`;
+
 export const snippetConversational = (
 	model: ModelDataMinimal,
 	accessToken: string,
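For reference, the helper added above renders a Python prelude like the following; the model id shown is hypothetical:

// For a model with id "black-forest-labs/FLUX.1-dev" and a real token, it renders:
//
//   from huggingface_hub import InferenceClient
//
//   client = InferenceClient("black-forest-labs/FLUX.1-dev", token="hf_**********")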
@@ -168,18 +173,31 @@ export const snippetFile = (model: ModelDataMinimal): InferenceSnippet => ({
 output = query(${getModelInputSnippet(model)})`,
 });
 
-export const snippetTextToImage = (model: ModelDataMinimal): InferenceSnippet => ({
-	content: `def query(payload):
+export const snippetTextToImage = (model: ModelDataMinimal, accessToken: string): InferenceSnippet[] => {
+	return [
+		{
+			client: "requests",
+			content: `def query(payload):
 	response = requests.post(API_URL, headers=headers, json=payload)
 	return response.content
 
 image_bytes = query({
 	"inputs": ${getModelInputSnippet(model)},
 })
 # You can access the image with PIL.Image for example
 import io
 from PIL import Image
 image = Image.open(io.BytesIO(image_bytes))`,
-});
+		},
+		{
+			client: "huggingface_hub",
+			content: `${snippetImportInferenceClient(model, accessToken)}
+
+# output is a PIL.Image object
+image = client.text_to_image(${getModelInputSnippet(model)})`,
+		},
+	];
+};
 
 export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet => ({
 	content: `def query(payload):

[Review comment from mishig25 (Collaborator) on the client: "huggingface_hub" line]

    Maybe let's agree on the convention of putting huggingface_hub as the first item in the list.

@@ -284,6 +302,9 @@ export function getPythonInferenceSnippet(
 	if (model.tags.includes("conversational")) {
 		// Conversational model detected, so we display a code snippet that features the Messages API
 		return snippetConversational(model, accessToken, opts);
+	} else if (model.pipeline_tag === "text-to-image") {
+		// TODO: factorize this logic
+		return snippetTextToImage(model, accessToken);
 	} else {
 		let snippets =
 			model.pipeline_tag && model.pipeline_tag in pythonSnippets
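The "TODO: factorize this logic" comment hints at replacing the growing if/else chain with a lookup table. A rough sketch, assuming future pipeline-specific generators adopt the same (model, accessToken) signature; the names here are hypothetical:

// Map pipeline tags to token-aware generators so the dispatch stays data-driven.
type TokenAwareSnippetGenerator = (
	model: ModelDataMinimal,
	accessToken: string
) => InferenceSnippet | InferenceSnippet[];

const tokenAwarePythonSnippets: Partial<Record<PipelineType, TokenAwareSnippetGenerator>> = {
	"text-to-image": snippetTextToImage,
	// ...add other pipelines here as they gain huggingface_hub variants
};

// Inside getPythonInferenceSnippet, the else-if branch would become:
// const generator = model.pipeline_tag && tokenAwarePythonSnippets[model.pipeline_tag];
// if (generator) return generator(model, accessToken);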