Skip to content

Commit d776a4b

Browse files
committed
zero shot class
1 parent f1c3367 commit d776a4b

File tree

4 files changed

+26
-13
lines changed

4 files changed

+26
-13
lines changed

packages/inference/src/snippets/templates/python/requests/zeroShotClassification.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ def query(payload):
33
return response.json()
44

55
output = query({
6-
"inputs": {{ inputs }},
6+
"inputs": {{ inputs.asObj.inputs }},
77
"parameters": {"candidate_labels": ["refund", "legal", "faq"]},
88
})

packages/inference/src/snippets/templates/python/requests/zeroShotImageClassification.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@ def query(data):
99
return response.json()
1010

1111
output = query({
12-
"image_path": {{ inputs }},
12+
"image_path": {{ inputs.asObj.inputs }},
1313
"parameters": {"candidate_labels": ["cat", "dog", "llama"]},
1414
})

packages/tasks-gen/scripts/generate-snippets-fixtures.ts

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -178,17 +178,17 @@ const TEST_CASES: {
178178
providers: ["hf-inference"],
179179
languages: ["py"],
180180
},
181-
// {
182-
// testName: "zero-shot-classification",
183-
// model: {
184-
// id: "facebook/bart-large-mnli",
185-
// pipeline_tag: "zero-shot-classification",
186-
// tags: [],
187-
// inference: "",
188-
// },
189-
// providers: ["hf-inference"],
190-
// languages: ["py"],
191-
// },
181+
{
182+
testName: "zero-shot-classification",
183+
model: {
184+
id: "facebook/bart-large-mnli",
185+
pipeline_tag: "zero-shot-classification",
186+
tags: [],
187+
inference: "",
188+
},
189+
providers: ["hf-inference"],
190+
languages: ["py"],
191+
},
192192
// {
193193
// testName: "zero-shot-image-classification",
194194
// model: {
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import requests
2+
3+
API_URL = "https://router.huggingface.co/hf-inference/models/facebook/bart-large-mnli"
4+
headers = {"Authorization": "Bearer api_token"}
5+
6+
def query(payload):
7+
response = requests.post(API_URL, headers=headers, json=payload)
8+
return response.json()
9+
10+
output = query({
11+
"inputs": "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!",
12+
"parameters": {"candidate_labels": ["refund", "legal", "faq"]},
13+
})

0 commit comments

Comments
 (0)