Commit 49042ca

js hf & oai clients
1 parent e6d2e4c commit 49042ca

File tree: 1 file changed (+128, -12 lines)

packages/tasks/src/snippets/js.ts

Lines changed: 128 additions & 12 deletions
@@ -1,6 +1,12 @@
 import type { PipelineType } from "../pipelines.js";
+import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
 import { getModelInputSnippet } from "./inputs.js";
-import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
+import type {
+	GenerationConfigFormatter,
+	GenerationMessagesFormatter,
+	InferenceSnippet,
+	ModelDataMinimal,
+} from "./types.js";
 
 export const snippetBasic = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
 	content: `async function query(data) {
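
The two formatter types imported above are declared in ./types.js, which this diff does not touch. Judging from their call sites below, they presumably look something like the following sketch (hypothetical shapes, not the committed declarations):

// Hypothetical: inferred from the call sites in this diff; the actual
// declarations live in ./types.js and may differ.
type GenerationMessagesFormatter = (args: {
	messages: ChatCompletionInputMessage[];
	sep: string;
	start: string;
	end: string;
}) => string;

type GenerationConfigFormatter = (args: {
	config: Record<string, unknown>;
	sep: string;
	start: string;
	end: string;
}) => string;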
@@ -24,22 +30,128 @@ query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
 });`,
 });
 
-export const snippetTextGeneration = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => {
+const formatGenerationMessages: GenerationMessagesFormatter = ({ messages, sep, start, end }) =>
+	start + messages.map(({ role, content }) => `{ role: "${role}", content: "${content}" }`).join(sep) + end;
+
+const formatGenerationConfig: GenerationConfigFormatter = ({ config, sep, start, end }) =>
+	start +
+	Object.entries(config)
+		.map(([key, val]) => `${key}: ${val}`)
+		.join(sep) +
+	end;
+
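To make the string building concrete, here is what the two helpers return for sample inputs (illustrative, not part of the commit):

// Illustrative inputs chosen to match the separators used below.
formatGenerationMessages({
	messages: [{ role: "user", content: "What is the capital of France?" }],
	sep: ",\n\t\t",
	start: "[\n\t\t",
	end: "\n\t]",
});
// => '[\n\t\t{ role: "user", content: "What is the capital of France?" }\n\t]'

formatGenerationConfig({ config: { temperature: 0.5, max_tokens: 500 }, sep: ",\n\t", start: "", end: "" });
// => 'temperature: 0.5,\n\tmax_tokens: 500'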
+export const snippetTextGeneration = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	opts?: {
+		streaming?: boolean;
+		messages?: ChatCompletionInputMessage[];
+		temperature?: GenerationParameters["temperature"];
+		max_tokens?: GenerationParameters["max_tokens"];
+		top_p?: GenerationParameters["top_p"];
+	}
+): InferenceSnippet | InferenceSnippet[] => {
 	if (model.tags.includes("conversational")) {
 		// Conversational model detected, so we display a code snippet that features the Messages API
-		return {
-			content: `import { HfInference } from "@huggingface/inference";
+		const streaming = opts?.streaming ?? true;
+		const messages: ChatCompletionInputMessage[] = opts?.messages ?? [
+			{ role: "user", content: "What is the capital of France?" },
+		];
+		const messagesStr = formatGenerationMessages({ messages, sep: ",\n\t\t", start: "[\n\t\t", end: "\n\t]" });
 
-const inference = new HfInference("${accessToken || `{API_TOKEN}`}");
+		const config = {
+			temperature: opts?.temperature,
+			max_tokens: opts?.max_tokens ?? 500,
+			top_p: opts?.top_p,
+		};
+		const configStr = formatGenerationConfig({ config, sep: ",\n\t", start: "", end: "" });
 
-for await (const chunk of inference.chatCompletionStream({
+		if (streaming) {
+			return [
+				{
+					client: "huggingface_hub",
+					content: `import { HfInference } from "@huggingface/inference"
+
+const client = new HfInference("${accessToken || `{API_TOKEN}`}")
+
+let out = "";
+
+const stream = client.chatCompletionStream({
 	model: "${model.id}",
-	messages: [{ role: "user", content: "What is the capital of France?" }],
-	max_tokens: 500,
-})) {
-	process.stdout.write(chunk.choices[0]?.delta?.content || "");
+	messages: ${messagesStr},
+	${configStr}
+});
+
+for await (const chunk of stream) {
+	if (chunk.choices && chunk.choices.length > 0) {
+		const newContent = chunk.choices[0].delta.content;
+		out += newContent;
+		console.log(newContent);
+	}
 }`,
-	};
+				},
+				{
+					client: "openai",
+					content: `import { OpenAI } from "openai"
+
+const client = new OpenAI({
+	baseURL: "https://api-inference.huggingface.co/v1/",
+	apiKey: "${accessToken || `{API_TOKEN}`}"
+})
+
+let out = "";
+
+const stream = await client.chat.completions.create({
+	model: "${model.id}",
+	messages: ${messagesStr},
+	${configStr},
+	stream: true,
+});
+
+for await (const chunk of stream) {
+	if (chunk.choices && chunk.choices.length > 0) {
+		const newContent = chunk.choices[0].delta.content;
+		out += newContent;
+		console.log(newContent);
+	}
+}`,
+				},
+			];
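To see what the streaming branch actually emits: with an empty access token, a hypothetical model id of meta-llama/Meta-Llama-3-8B-Instruct, and opts of { temperature: 0.1, top_p: 0.9 } (all illustrative), the huggingface_hub entry above renders to:

import { HfInference } from "@huggingface/inference"

const client = new HfInference("{API_TOKEN}")

let out = "";

const stream = client.chatCompletionStream({
	model: "meta-llama/Meta-Llama-3-8B-Instruct",
	messages: [
		{ role: "user", content: "What is the capital of France?" }
	],
	temperature: 0.1,
	max_tokens: 500,
	top_p: 0.9
});

for await (const chunk of stream) {
	if (chunk.choices && chunk.choices.length > 0) {
		const newContent = chunk.choices[0].delta.content;
		out += newContent;
		console.log(newContent);
	}
}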
+		} else {
+			return [
+				{
+					client: "huggingface_hub",
+					content: `import { HfInference } from '@huggingface/inference'
+
+const client = new HfInference("${accessToken || `{API_TOKEN}`}")
+
+const chatCompletion = await client.chatCompletion({
+	model: "${model.id}",
+	messages: ${messagesStr},
+	${configStr}
+});
+
+console.log(chatCompletion.choices[0].message);`,
+				},
+				{
+					client: "openai",
+					content: `import { OpenAI } from "openai"
+
+const client = new OpenAI({
+	baseURL: "https://api-inference.huggingface.co/v1/",
+	apiKey: "${accessToken || `{API_TOKEN}`}"
+})
+
+const chatCompletion = await client.chat.completions.create({
+	model: "${model.id}",
+	messages: ${messagesStr},
+	${configStr}
+});
+
+console.log(chatCompletion.choices[0].message);`,
+				},
+			];
+		}
 	} else {
 		return snippetBasic(model, accessToken);
 	}
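
Putting the pieces together, the new signature can be exercised like this (the model literal is trimmed to the two fields the function reads, which is likely not the full ModelDataMinimal shape; illustrative):

// Illustrative call; the model literal only carries the fields this function
// reads (id and tags), hence the cast.
const snippets = snippetTextGeneration(
	{ id: "meta-llama/Meta-Llama-3-8B-Instruct", tags: ["conversational"] } as ModelDataMinimal,
	"",
	{ streaming: false, max_tokens: 256 }
);
// => two entries: { client: "huggingface_hub", content: ... } and { client: "openai", content: ... }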
@@ -187,7 +299,11 @@ query(${getModelInputSnippet(model)}).then((response) => {
 export const jsSnippets: Partial<
 	Record<
 		PipelineType,
-		(model: ModelDataMinimal, accessToken: string, opts?: Record<string, string | boolean | number>) => InferenceSnippet
+		(
+			model: ModelDataMinimal,
+			accessToken: string,
+			opts?: Record<string, unknown>
+		) => InferenceSnippet | InferenceSnippet[]
 	>
 > = {
 	// Same order as in js/src/lib/interfaces/Types.ts
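
Because the record's value type now returns InferenceSnippet | InferenceSnippet[], downstream callers have to normalize the union. A minimal sketch of what a consumer might do (hypothetical helper, not part of this commit):

// Hypothetical consumer-side helper: flattens the new union into an array.
function getJsInferenceSnippets(
	model: ModelDataMinimal,
	pipeline: PipelineType,
	accessToken: string
): InferenceSnippet[] {
	const result = jsSnippets[pipeline]?.(model, accessToken);
	return result === undefined ? [] : Array.isArray(result) ? result : [result];
}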
