
Commit e172a27

Add script to check inference snippet changes + update Python textToImage snippet
1 parent 461d989 commit e172a27

3 files changed (+74, -4 lines)
.github/workflows/inference-check-snippets.yml

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+name: Inference check snippets
+on:
+  pull_request:
+    paths:
+      - "packages/tasks/src/snippets/**"
+      - ".github/workflows/inference-check-snippets.yml"
+
+jobs:
+  check-snippets:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v3
+
+      - run: corepack enable
+
+      - uses: actions/setup-node@v3
+        with:
+          node-version: "20"
+          cache: "pnpm"
+          cache-dependency-path: "**/pnpm-lock.yaml"
+      - run: |
+          cd packages/tasks
+          pnpm install --frozen-lockfile --filter .
+          pnpm install --frozen-lockfile --filter ...[${{ steps.since.outputs.SINCE }}]...
+          pnpm --filter ...[${{ steps.since.outputs.SINCE }}]... build
+
+      # TODO: Find a way to run on all pipeline tags
+      # TODO: print snippet only if it has changed since the last commit on main (?)
+      # TODO: (even better: automated message on the PR with diff)
+      - name: Print text-to-image snippets
+        run: |
+          cd packages/tasks
+          pnpm run check-snippets --pipeline-tag="text-to-image"
+
+      - name: Print simple text-generation snippets
+        run: |
+          cd packages/tasks
+          pnpm run check-snippets --pipeline-tag="text-generation"
+
+      - name: Print conversational text-generation snippets
+        run: |
+          cd packages/tasks
+          pnpm run check-snippets --pipeline-tag="text-generation" --tags="conversational"
+
+      - name: Print conversational image-text-to-text snippets
+        run: |
+          cd packages/tasks
+          pnpm run check-snippets --pipeline-tag="image-text-to-text" --tags="conversational"

packages/tasks/scripts/check-snippets.ts

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ import minimist from "minimist";
 const args = minimist(process.argv.slice(2));
 
 const accessToken = "hf_**********";
-const pipelineTag = args["pipeline-type"] || "text-generation";
+const pipelineTag = args["pipeline-tag"] || "text-generation";
 const tags = (args["tags"] || "").split(",");
 
 const modelMinimal: ModelDataMinimal = {
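This one-line fix matters because the new workflow passes --pipeline-tag, while the script was reading --pipeline-type and silently falling back to text-generation. As a rough sketch of how the parsed flags end up being consumed (import paths and the placeholder model are assumptions; the rest of the script is not shown in this diff):

// Sketch only: mirrors the argument parsing visible above.
import minimist from "minimist";
import type { ModelDataMinimal } from "../src/snippets/types.js"; // assumed path
import { getPythonInferenceSnippet } from "../src/snippets/python.js"; // assumed path

const args = minimist(process.argv.slice(2));
const pipelineTag = args["pipeline-tag"] || "text-generation"; // was args["pipeline-type"]
const tags: string[] = (args["tags"] || "").split(",");

// Placeholder model: the real script builds its own ModelDataMinimal.
const model = { id: "some-org/some-model", pipeline_tag: pipelineTag, tags } as unknown as ModelDataMinimal;

// Print the generated Python snippet(s) so reviewers can eyeball them in the CI logs.
console.log(getPythonInferenceSnippet(model, "hf_**********"));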

packages/tasks/src/snippets/python.ts

Lines changed: 24 additions & 3 deletions
@@ -4,6 +4,11 @@ import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
 import { getModelInputSnippet } from "./inputs.js";
 import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
 
+const snippetImportInferenceClient = (model: ModelDataMinimal, accessToken: string): string =>
+	`from huggingface_hub import InferenceClient
+
+client = InferenceClient(${model.id}, token="${accessToken || "{API_TOKEN}"}")`;
+
 export const snippetConversational = (
 	model: ModelDataMinimal,
 	accessToken: string,
@@ -184,18 +189,31 @@ export const snippetFile = (model: ModelDataMinimal): InferenceSnippet => ({
 output = query(${getModelInputSnippet(model)})`,
 });
 
-export const snippetTextToImage = (model: ModelDataMinimal): InferenceSnippet => ({
-	content: `def query(payload):
+export const snippetTextToImage = (model: ModelDataMinimal, accessToken: string): InferenceSnippet[] => {
+	return [
+		{
+			client: "requests",
+			content: `def query(payload):
 	response = requests.post(API_URL, headers=headers, json=payload)
 	return response.content
+
 image_bytes = query({
 	"inputs": ${getModelInputSnippet(model)},
 })
 # You can access the image with PIL.Image for example
 import io
 from PIL import Image
 image = Image.open(io.BytesIO(image_bytes))`,
-});
+		},
+		{
+			client: "huggingface_hub",
+			content: `${snippetImportInferenceClient(model, accessToken)}
+
+# output is a PIL.Image object
+image = client.text_to_image(${getModelInputSnippet(model)})`,
+		},
+	];
+};
 
 export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet => ({
 	content: `def query(payload):
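To make the new shape concrete, a hypothetical call to the updated snippetTextToImage (placeholder model id, assumed import paths) now yields one snippet per client instead of a single object:

// Sketch only, not part of this commit.
import type { ModelDataMinimal } from "./types.js";
import { snippetTextToImage } from "./python.js"; // assumed import path

// Placeholder model: real callers pass a full ModelDataMinimal from the Hub.
const model = {
	id: "black-forest-labs/FLUX.1-dev",
	pipeline_tag: "text-to-image",
	tags: [],
} as unknown as ModelDataMinimal;

const snippets = snippetTextToImage(model, "hf_**********");
// snippets[0].client === "requests"        -> raw requests.post + PIL decoding snippet
// snippets[1].client === "huggingface_hub" -> InferenceClient(...).text_to_image(...) snippet
for (const { client, content } of snippets) {
	console.log(`--- ${client} ---\n${content}\n`);
}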
@@ -300,6 +318,9 @@ export function getPythonInferenceSnippet(
 	if (model.tags.includes("conversational")) {
 		// Conversational model detected, so we display a code snippet that features the Messages API
 		return snippetConversational(model, accessToken, opts);
+	} else if (model.pipeline_tag == "text-to-image") {
+		// TODO: factorize this logic
+		return snippetTextToImage(model, accessToken);
 	} else {
 		let snippets =
 			model.pipeline_tag && model.pipeline_tag in pythonSnippets
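One consequence of this dispatch is that getPythonInferenceSnippet can now return either a single InferenceSnippet or an InferenceSnippet[] depending on the pipeline, so callers have to handle both shapes. A minimal, hypothetical normalization on the caller side (not part of this commit, and assuming the return type is that union) could be:

// Sketch only: flatten the union return type for downstream consumers.
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
import { getPythonInferenceSnippet } from "./python.js"; // assumed import path

// Hypothetical helper: always hand back a list, whether the generator
// produced a single InferenceSnippet or an InferenceSnippet[].
function getPythonSnippetList(model: ModelDataMinimal, accessToken: string): InferenceSnippet[] {
	const result = getPythonInferenceSnippet(model, accessToken);
	return Array.isArray(result) ? result : [result];
}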

0 commit comments
