Skip to content

Commit bd36835

Browse files
committed
zero shot image class
1 parent d776a4b commit bd36835

File tree

2 files changed

+31
-11
lines changed

2 files changed

+31
-11
lines changed

packages/tasks-gen/scripts/generate-snippets-fixtures.ts

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -189,17 +189,17 @@ const TEST_CASES: {
189189
providers: ["hf-inference"],
190190
languages: ["py"],
191191
},
192-
// {
193-
// testName: "zero-shot-image-classification",
194-
// model: {
195-
// id: "openai/clip-vit-large-patch14",
196-
// pipeline_tag: "zero-shot-image-classification",
197-
// tags: [],
198-
// inference: "",
199-
// },
200-
// providers: ["hf-inference"],
201-
// languages: ["py"],
202-
// },
192+
{
193+
testName: "zero-shot-image-classification",
194+
model: {
195+
id: "openai/clip-vit-large-patch14",
196+
pipeline_tag: "zero-shot-image-classification",
197+
tags: [],
198+
inference: "",
199+
},
200+
providers: ["hf-inference"],
201+
languages: ["py"],
202+
},
203203
] as const;
204204

205205
const GET_SNIPPET_FN = {
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import base64
2+
import requests
3+
4+
API_URL = "https://router.huggingface.co/hf-inference/models/openai/clip-vit-large-patch14"
5+
headers = {"Authorization": "Bearer api_token"}
6+
7+
def query(data):
8+
with open(data["image_path"], "rb") as f:
9+
img = f.read()
10+
payload={
11+
"parameters": data["parameters"],
12+
"inputs": base64.b64encode(img).decode("utf-8")
13+
}
14+
response = requests.post(API_URL, headers=headers, json=payload)
15+
return response.json()
16+
17+
output = query({
18+
"image_path": "cats.jpg",
19+
"parameters": {"candidate_labels": ["cat", "dog", "llama"]},
20+
})

0 commit comments

Comments
 (0)