
Use InferenceClient in (some) python inference snippets #971


Closed · wants to merge 1 commit
163 changes: 83 additions & 80 deletions packages/tasks/src/snippets/python.ts
@@ -2,25 +2,40 @@ import type { PipelineType } from "../pipelines.js";
import { getModelInputSnippet } from "./inputs.js";
import type { ModelDataMinimal } from "./types.js";

export const snippetConversational = (model: ModelDataMinimal, accessToken: string): string =>
// Import snippets

const snippetImportInferenceClient = (model: ModelDataMinimal, accessToken: string): string =>
`from huggingface_hub import InferenceClient

client = InferenceClient("${model.id}", token="${accessToken || "{API_TOKEN}"}")
`;

const snippetImportConversationalInferenceClient = (model: ModelDataMinimal, accessToken: string): string =>
// Same but uses OpenAI convention
`from huggingface_hub import InferenceClient

client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}")
`;

for message in client.chat_completion(
const snippetImportRequests = (model: ModelDataMinimal, accessToken: string): string =>
`import requests

API_URL = "https://api-inference.huggingface.co/models/${model.id}"
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}`;

export const snippetConversational = (model: ModelDataMinimal): string =>
`for message in client.chat_completion(
model="${model.id}",
messages=[{"role": "user", "content": "What is the capital of France?"}],
max_tokens=500,
stream=True,
):
print(message.choices[0].delta.content, end="")`;
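Concretely, pairing snippetImportConversationalInferenceClient with this body would produce something like the following for a conversational text-generation model (the model ID below is hypothetical):

from huggingface_hub import InferenceClient

client = InferenceClient(api_key="{API_TOKEN}")

for message in client.chat_completion(
    model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model ID
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    max_tokens=500,
    stream=True,
):
    print(message.choices[0].delta.content, end="")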

export const snippetConversationalWithImage = (model: ModelDataMinimal, accessToken: string): string =>
`from huggingface_hub import InferenceClient

client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}")
// InferenceClient-based snippets

image_url = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
export const snippetConversationalWithImage = (model: ModelDataMinimal): string =>
`image_url = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"

for message in client.chat_completion(
model="${model.id}",
@@ -38,31 +53,29 @@ for message in client.chat_completion(
):
print(message.choices[0].delta.content, end="")`;

export const snippetZeroShotClassification = (model: ModelDataMinimal): string =>
`def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()
export const snippetDocumentQuestionAnswering = (): string =>
`output = client.document_question_answering("cat.png", "What is in this image?")`;

output = query({
"inputs": ${getModelInputSnippet(model)},
"parameters": {"candidate_labels": ["refund", "legal", "faq"]},
})`;
export const snippetTabularClassification = (model: ModelDataMinimal): string =>
`output = client.tabular_classification(${getModelInputSnippet(model)})`;

export const snippetTabularRegression = (model: ModelDataMinimal): string =>
`output = client.tabular_regression(${getModelInputSnippet(model)})`;
export const snippetTextToImage = (model: ModelDataMinimal): string =>
`# output is a PIL.Image object
image = client.text_to_image(${getModelInputSnippet(model)})`;
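As a sketch, this one-liner combined with snippetImportInferenceClient would render to roughly the following (model ID and prompt are placeholders, not from this PR):

from huggingface_hub import InferenceClient

client = InferenceClient("stabilityai/stable-diffusion-2", token="{API_TOKEN}")  # placeholder model ID

# text_to_image returns a PIL.Image object
image = client.text_to_image("An astronaut riding a horse on the moon")  # placeholder prompt
image.save("astronaut.png")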

export const snippetZeroShotClassification = (model: ModelDataMinimal): string =>
`text = ${getModelInputSnippet(model)}
labels = ["refund", "legal", "faq"]
output = client.zero_shot_classification(text, labels)`;

export const snippetZeroShotImageClassification = (model: ModelDataMinimal): string =>
`def query(data):
with open(data["image_path"], "rb") as f:
img = f.read()
payload={
"parameters": data["parameters"],
"inputs": base64.b64encode(img).decode("utf-8")
}
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()
`image = ${getModelInputSnippet(model)}
labels = ["cat", "dog", "llama"]
output = client.zero_shot_image_classification(image, labels)`;

output = query({
"image_path": ${getModelInputSnippet(model)},
"parameters": {"candidate_labels": ["cat", "dog", "llama"]},
})`;
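Likewise, the zero-shot snippets become plain client calls. A rendered example for zero-shot classification, with a placeholder model ID and input text:

from huggingface_hub import InferenceClient

client = InferenceClient("facebook/bart-large-mnli", token="{API_TOKEN}")  # placeholder model ID

text = "I want a refund for my last order"  # placeholder input
labels = ["refund", "legal", "faq"]
output = client.zero_shot_classification(text, labels)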
// requests-based snippets

export const snippetBasic = (model: ModelDataMinimal): string =>
`def query(payload):
@@ -82,26 +95,6 @@ export const snippetFile = (model: ModelDataMinimal): string =>

output = query(${getModelInputSnippet(model)})`;
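For pipelines that stay on raw requests, the rendered snippet keeps the query helper. A sketch, assuming the collapsed middle of snippetBasic posts JSON and returns response.json() the same way the removed requests-based snippets do:

import requests

API_URL = "https://api-inference.huggingface.co/models/gpt2"  # placeholder model ID
headers = {"Authorization": f"Bearer {API_TOKEN}"}  # assumes API_TOKEN is defined

def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

output = query({"inputs": "placeholder input"})  # real snippets inline getModelInputSnippet(model)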

export const snippetTextToImage = (model: ModelDataMinimal): string =>
`def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.content
image_bytes = query({
"inputs": ${getModelInputSnippet(model)},
})
# You can access the image with PIL.Image for example
import io
from PIL import Image
image = Image.open(io.BytesIO(image_bytes))`;

export const snippetTabular = (model: ModelDataMinimal): string =>
`def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.content
response = query({
"inputs": {"data": ${getModelInputSnippet(model)}},
})`;

export const snippetTextToAudio = (model: ModelDataMinimal): string => {
// Transformers TTS pipeline and api-inference-community (AIC) pipeline outputs have diverged
// with the latest update to inference-api (IA).
@@ -131,19 +124,16 @@ Audio(audio, rate=sampling_rate)`;
}
};

export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): string =>
`def query(payload):
with open(payload["image"], "rb") as f:
img = f.read()
payload["image"] = base64.b64encode(img).decode("utf-8")
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()

output = query({
"inputs": ${getModelInputSnippet(model)},
})`;
const PIPELINES_USING_INFERENCE_CLIENT: PipelineType[] = [
"document-question-answering",
"tabular-classification",
"tabular-regression",
"text-to-image",
"zero-shot-classification",
"zero-shot-image-classification",
];

export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal, accessToken: string) => string>> = {
export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal) => string>> = {
// Same order as in tasks/src/pipelines.ts
"text-classification": snippetBasic,
"token-classification": snippetBasic,
@@ -165,8 +155,8 @@ export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinim
"audio-to-audio": snippetFile,
"audio-classification": snippetFile,
"image-classification": snippetFile,
"tabular-regression": snippetTabular,
"tabular-classification": snippetTabular,
"tabular-regression": snippetTabularRegression,
"tabular-classification": snippetTabularClassification,
"object-detection": snippetFile,
"image-segmentation": snippetFile,
"document-question-answering": snippetDocumentQuestionAnswering,
@@ -175,25 +165,38 @@ export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinim
};

export function getPythonInferenceSnippet(model: ModelDataMinimal, accessToken: string): string {
if (model.pipeline_tag === "text-generation" && model.tags.includes("conversational")) {
// Conversational model detected, so we display a code snippet that features the Messages API
return snippetConversational(model, accessToken);
} else if (model.pipeline_tag === "image-text-to-text" && model.tags.includes("conversational")) {
// Example sending an image to the Message API
return snippetConversationalWithImage(model, accessToken);
} else {
const body =
model.pipeline_tag && model.pipeline_tag in pythonSnippets
? pythonSnippets[model.pipeline_tag]?.(model, accessToken) ?? ""
: "";

return `import requests

API_URL = "https://api-inference.huggingface.co/models/${model.id}"
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
// Specific case for chat completion snippets
const isConversational =
	model.tags.includes("conversational") &&
	model.pipeline_tag !== undefined &&
	["text-generation", "image-text-to-text"].includes(model.pipeline_tag);

// Determine the import snippet based on model tags and pipeline tag
const getImportSnippet = () => {
if (isConversational) {
return snippetImportConversationalInferenceClient(model, accessToken);
} else if (model.pipeline_tag && PIPELINES_USING_INFERENCE_CLIENT.includes(model.pipeline_tag)) {
return snippetImportInferenceClient(model, accessToken);
} else {
return snippetImportRequests(model, accessToken);
}
};

// Determine the body snippet based on model tags and pipeline tag
const getBodySnippet = () => {
if (isConversational) {
return model.pipeline_tag === "text-generation"
? snippetConversational(model)
: snippetConversationalWithImage(model);
} else if (model.pipeline_tag && model.pipeline_tag in pythonSnippets) {
return pythonSnippets[model.pipeline_tag]?.(model) ?? "";
} else {
return "";
}
};

${body}`;
}
// Combine import and body snippets with newline separation
return `${getImportSnippet()}\n\n${getBodySnippet()}`;
}
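End to end, for a pipeline listed in PIPELINES_USING_INFERENCE_CLIENT, getPythonInferenceSnippet would therefore emit something like the following (placeholder document-question-answering model and file):

from huggingface_hub import InferenceClient

client = InferenceClient("impira/layoutlm-document-qa", token="{API_TOKEN}")  # placeholder model ID

output = client.document_question_answering("cat.png", "What is in this image?")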

export function hasPythonInferenceSnippet(model: ModelDataMinimal): boolean {