 import type { PipelineType } from "../pipelines.js";
+import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
 import { getModelInputSnippet } from "./inputs.js";
-import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
+import type {
+	GenerationConfigFormatter,
+	GenerationMessagesFormatter,
+	InferenceSnippet,
+	ModelDataMinimal,
+} from "./types.js";

-export const snippetConversational = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
-	content: `from huggingface_hub import InferenceClient
+const formatGenerationMessages: GenerationMessagesFormatter = ({ messages, sep, start, end }) =>
+	start + messages.map(({ role, content }) => `{ "role": "${role}", "content": "${content}" }`).join(sep) + end;
+
+const formatGenerationConfig: GenerationConfigFormatter = ({ config, sep, start, end, connector }) =>
+	start +
+	Object.entries(config)
+		.map(([key, val]) => `${key}${connector}${val}`)
+		.join(sep) +
+	end;
+
+export const snippetConversational = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	opts?: {
+		streaming?: boolean;
+		messages?: ChatCompletionInputMessage[];
+		temperature?: GenerationParameters["temperature"];
+		max_tokens?: GenerationParameters["max_tokens"];
+		top_p?: GenerationParameters["top_p"];
+	}
+): InferenceSnippet => {
+	const streaming = opts?.streaming ?? true;
+	const messages: ChatCompletionInputMessage[] = opts?.messages ?? [
+		{ role: "user", content: "What is the capital of France?" },
+	];
+
+	const config = {
+		temperature: opts?.temperature,
+		max_tokens: opts?.max_tokens ?? 500,
+		top_p: opts?.top_p,
+	};
+
+	if (streaming) {
+		return {
+			content: `from huggingface_hub import InferenceClient

 client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}")

-for message in client.chat_completion(
-	model="${model.id}",
-	messages=[{"role": "user", "content": "What is the capital of France?"}],
-	max_tokens=500,
-	stream=True,
-):
-	print(message.choices[0].delta.content, end="")`,
-});
+messages = ${formatGenerationMessages({ messages, sep: ",\n\t", start: `[\n\t`, end: `\n]` })}
+
+stream = client.chat.completions.create(
+	model="${model.id}",
+	messages=messages,
+	${formatGenerationConfig({ config, sep: ",\n\t", start: "", end: "", connector: "=" })},
+	stream=True
+)
+
+for chunk in stream:
+	print(chunk.choices[0].delta.content)`,
+		};
+	} else {
+		return {
+			content: `from huggingface_hub import InferenceClient
+
+client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}")
+
+messages = ${formatGenerationMessages({ messages, sep: ",\n\t", start: `[\n\t`, end: `\n]` })}
+
+completion = client.chat.completions.create(
+	model="${model.id}",
+	messages=messages,
+	${formatGenerationConfig({ config, sep: ",\n\t", start: "", end: "", connector: "=" })}
+)
+
+print(completion.choices[0].message)`,
+		};
+	}
+};

 export const snippetConversationalWithImage = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
 	content: `from huggingface_hub import InferenceClient
@@ -159,7 +220,7 @@ output = query({
 export const pythonSnippets: Partial<
 	Record<
 		PipelineType,
-		(model: ModelDataMinimal, accessToken: string, opts?: Record<string, string | boolean | number>) => InferenceSnippet
+		(model: ModelDataMinimal, accessToken: string, opts?: Record<string, unknown>) => InferenceSnippet
 	>
 > = {
 	// Same order as in tasks/src/pipelines.ts
@@ -192,10 +253,14 @@ export const pythonSnippets: Partial<
 	"zero-shot-image-classification": snippetZeroShotImageClassification,
 };

-export function getPythonInferenceSnippet(model: ModelDataMinimal, accessToken: string): InferenceSnippet {
+export function getPythonInferenceSnippet(
+	model: ModelDataMinimal,
+	accessToken: string,
+	opts?: Record<string, unknown>
+): InferenceSnippet {
 	if (model.pipeline_tag === "text-generation" && model.tags.includes("conversational")) {
 		// Conversational model detected, so we display a code snippet that features the Messages API
-		return snippetConversational(model, accessToken);
+		return snippetConversational(model, accessToken, opts);
 	} else if (model.pipeline_tag === "image-text-to-text" && model.tags.includes("conversational")) {
 		// Example sending an image to the Message API
 		return snippetConversationalWithImage(model, accessToken);
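
Below is a rough usage sketch, not part of the diff, showing how the extended entry point could be called once this change lands. The import path, model id, token placeholder, and the assumption that `ModelDataMinimal` needs no fields beyond those read by the conversational snippet are all illustrative, not taken from the PR.

// Usage sketch — illustrative only; paths and the model object are assumptions.
import { getPythonInferenceSnippet } from "./python.js";
import type { ModelDataMinimal } from "./types.js";

const model: ModelDataMinimal = {
	id: "meta-llama/Llama-3.1-8B-Instruct", // hypothetical conversational model
	pipeline_tag: "text-generation",
	tags: ["conversational"],
};

// With no opts, behaviour matches the old snippet (streaming on, max_tokens=500).
// Here we request the non-streaming variant and override the temperature.
const snippet = getPythonInferenceSnippet(model, "hf_***", { streaming: false, temperature: 0.7 });
console.log(snippet.content);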