
Commit 9cb93be

complete snippetConversational
1 parent dc7eaf4 · commit 9cb93be

File tree

packages/tasks/src/snippets/python.ts
packages/tasks/src/snippets/types.ts

2 files changed (+104, -14 lines)


packages/tasks/src/snippets/python.ts

Lines changed: 79 additions & 14 deletions
@@ -1,20 +1,81 @@
 import type { PipelineType } from "../pipelines.js";
+import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
 import { getModelInputSnippet } from "./inputs.js";
-import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
+import type {
+	GenerationConfigFormatter,
+	GenerationMessagesFormatter,
+	InferenceSnippet,
+	ModelDataMinimal,
+} from "./types.js";
 
-export const snippetConversational = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
-	content: `from huggingface_hub import InferenceClient
+const formatGenerationMessages: GenerationMessagesFormatter = ({ messages, sep, start, end }) =>
+	start + messages.map(({ role, content }) => `{ "role": "${role}", "content": "${content}" }`).join(sep) + end;
+
+const formatGenerationConfig: GenerationConfigFormatter = ({ config, sep, start, end, connector }) =>
+	start +
+	Object.entries(config)
+		.map(([key, val]) => `${key}${connector}${val}`)
+		.join(sep) +
+	end;
+
+export const snippetConversational = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	opts?: {
+		streaming?: boolean;
+		messages?: ChatCompletionInputMessage[];
+		temperature?: GenerationParameters["temperature"];
+		max_tokens?: GenerationParameters["max_tokens"];
+		top_p?: GenerationParameters["top_p"];
+	}
+): InferenceSnippet => {
+	const streaming = opts?.streaming ?? true;
+	const messages: ChatCompletionInputMessage[] = opts?.messages ?? [
+		{ role: "user", content: "What is the capital of France?" },
+	];
+
+	const config = {
+		temperature: opts?.temperature,
+		max_tokens: opts?.max_tokens ?? 500,
+		top_p: opts?.top_p,
+	};
+
+	if (streaming) {
+		return {
+			content: `from huggingface_hub import InferenceClient
 
 client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}")
 
-for message in client.chat_completion(
-	model="${model.id}",
-	messages=[{"role": "user", "content": "What is the capital of France?"}],
-	max_tokens=500,
-	stream=True,
-):
-	print(message.choices[0].delta.content, end="")`,
-});
+messages = ${formatGenerationMessages({ messages, sep: ",\n\t", start: `[\n\t`, end: `\n]` })}
+
+stream = client.chat.completions.create(
+	model="${model.id}",
+	messages=messages,
+	${formatGenerationConfig({ config, sep: ",\n\t", start: "", end: "", connector: "=" })},
+	stream=True
+)
+
+for chunk in stream:
+	print(chunk.choices[0].delta.content)`,
+		};
+	} else {
+		return {
+			content: `from huggingface_hub import InferenceClient
+
+client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}")
+
+messages = ${formatGenerationMessages({ messages, sep: ",\n\t", start: `[\n\t`, end: `\n]` })}
+
+completion = client.chat.completions.create(
+	model="${model.id}",
+	messages=messages,
+	${formatGenerationConfig({ config, sep: ",\n\t", start: "", end: "", connector: "=" })}
+)
+
+print(completion.choices[0].message)`,
+		};
+	}
+};
 
 export const snippetConversationalWithImage = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
 	content: `from huggingface_hub import InferenceClient
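
For a quick sanity check of the two formatters introduced above, here is a standalone TypeScript sketch with the delimiter types inlined. renderMessages and renderConfig are hypothetical stand-ins for formatGenerationMessages and formatGenerationConfig, re-declared so the sketch runs on its own:

// Standalone sketch: hypothetical stand-ins for the formatters above,
// with their delimiter parameters inlined as plain types.
type Msg = { role: string; content: string };

const renderMessages = (p: { messages: Msg[]; sep: string; start: string; end: string }): string =>
	p.start + p.messages.map(({ role, content }) => `{ "role": "${role}", "content": "${content}" }`).join(p.sep) + p.end;

const renderConfig = (p: { config: Record<string, unknown>; sep: string; start: string; end: string; connector: string }): string =>
	p.start +
	Object.entries(p.config)
		.map(([key, val]) => `${key}${p.connector}${val}`)
		.join(p.sep) +
	p.end;

// Same delimiters snippetConversational passes:
console.log(
	renderMessages({
		messages: [{ role: "user", content: "What is the capital of France?" }],
		sep: ",\n\t",
		start: "[\n\t",
		end: "\n]",
	})
);
// [
// 	{ "role": "user", "content": "What is the capital of France?" }
// ]

console.log(renderConfig({ config: { max_tokens: 500 }, sep: ",\n\t", start: "", end: "", connector: "=" }));
// max_tokens=500

With sep ",\n\t", start "[\n\t", and end "\n]", the messages array renders as a multi-line Python list literal that is spliced directly into the generated snippet.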
@@ -159,7 +220,7 @@ output = query({
 export const pythonSnippets: Partial<
 	Record<
 		PipelineType,
-		(model: ModelDataMinimal, accessToken: string, opts?: Record<string, string | boolean | number>) => InferenceSnippet
+		(model: ModelDataMinimal, accessToken: string, opts?: Record<string, unknown>) => InferenceSnippet
 	>
 > = {
 	// Same order as in tasks/src/pipelines.ts
@@ -192,10 +253,14 @@ export const pythonSnippets: Partial<
 	"zero-shot-image-classification": snippetZeroShotImageClassification,
 };
 
-export function getPythonInferenceSnippet(model: ModelDataMinimal, accessToken: string): InferenceSnippet {
+export function getPythonInferenceSnippet(
+	model: ModelDataMinimal,
+	accessToken: string,
+	opts?: Record<string, unknown>
+): InferenceSnippet {
 	if (model.pipeline_tag === "text-generation" && model.tags.includes("conversational")) {
 		// Conversational model detected, so we display a code snippet that features the Messages API
-		return snippetConversational(model, accessToken);
+		return snippetConversational(model, accessToken, opts);
 	} else if (model.pipeline_tag === "image-text-to-text" && model.tags.includes("conversational")) {
 		// Example sending an image to the Message API
 		return snippetConversationalWithImage(model, accessToken);
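
A hedged sketch of a call site for the extended entry point. The model object and token are placeholders, and the sketch assumes the ModelDataMinimal fields omitted here are optional:

import { getPythonInferenceSnippet } from "./python.js";
import type { ModelDataMinimal } from "./types.js";

// Hypothetical, abbreviated model object; assumes fields not shown
// (e.g. library_name, config) are optional on ModelDataMinimal.
const model: ModelDataMinimal = {
	id: "username/some-conversational-model", // placeholder id
	pipeline_tag: "text-generation",
	tags: ["conversational"],
};

// opts is forwarded to snippetConversational for conversational text-generation models.
const snippet = getPythonInferenceSnippet(model, "hf_xxx", {
	streaming: false,
	max_tokens: 256,
});

console.log(snippet.content); // the non-streaming client.chat.completions.create(...) Python snippet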

packages/tasks/src/snippets/types.ts

Lines changed: 25 additions & 0 deletions
@@ -1,4 +1,5 @@
 import type { ModelData } from "../model-data";
+import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks";
 
 /**
  * Minimal model data required for snippets.
@@ -14,3 +15,27 @@ export interface InferenceSnippet {
 	content: string;
 	client?: string; // for instance: `client` could be huggingface_hub or openai client for Python snippets
 }
+
+interface GenerationSnippetDelimiter {
+	sep: string;
+	start: string;
+	end: string;
+	connector?: string;
+}
+
+type PartialGenerationParameters = Partial<Pick<GenerationParameters, "temperature" | "max_tokens" | "top_p">>;
+
+export type GenerationMessagesFormatter = ({
+	messages,
+	sep,
+	start,
+	end,
+}: GenerationSnippetDelimiter & { messages: ChatCompletionInputMessage[] }) => string;
+
+export type GenerationConfigFormatter = ({
+	config,
+	sep,
+	start,
+	end,
+	connector,
+}: GenerationSnippetDelimiter & { config: PartialGenerationParameters }) => string;
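
Because both exported formatter types build on the shared GenerationSnippetDelimiter shape, other snippet generators could conform to them too. Purely as an illustration (hypothetical; this commit only wires the types into python.ts), a GenerationConfigFormatter that renders the config as a JSON-style object body might look like:

import type { GenerationConfigFormatter } from "./types.js";

// Hypothetical formatter conforming to GenerationConfigFormatter that renders
// the config as `"key": value` pairs, e.g. for a JSON request-body snippet.
const formatConfigAsJsonBody: GenerationConfigFormatter = ({ config, sep, start, end, connector }) =>
	start +
	Object.entries(config)
		.map(([key, val]) => `"${key}"${connector ?? ": "}${val}`)
		.join(sep) +
	end;

Calling it with sep ",\n\t", start "{\n\t", and end "\n}" on config { max_tokens: 500, top_p: 0.9 } would yield '{\n\t"max_tokens": 500,\n\t"top_p": 0.9\n}'.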
