import type { PipelineType } from "../pipelines.js";
+ import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
import { getModelInputSnippet } from "./inputs.js";
- import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
+ import type {
+ 	GenerationConfigFormatter,
+ 	GenerationMessagesFormatter,
+ 	InferenceSnippet,
+ 	ModelDataMinimal,
+ } from "./types.js";

export const snippetBasic = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
	content: `async function query(data) {
@@ -24,22 +30,128 @@ query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
});`,
});

- export const snippetTextGeneration = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => {
+ const formatGenerationMessages: GenerationMessagesFormatter = ({ messages, sep, start, end }) =>
+ 	start + messages.map(({ role, content }) => `{ role: "${role}", content: "${content}" }`).join(sep) + end;
+
+ const formatGenerationConfig: GenerationConfigFormatter = ({ config, sep, start, end }) =>
+ 	start +
+ 	Object.entries(config)
+ 		.map(([key, val]) => `${key}: ${val}`)
+ 		.join(sep) +
+ 	end;
+
+ export const snippetTextGeneration = (
+ 	model: ModelDataMinimal,
+ 	accessToken: string,
+ 	opts?: {
+ 		streaming?: boolean;
+ 		messages?: ChatCompletionInputMessage[];
+ 		temperature?: GenerationParameters["temperature"];
+ 		max_tokens?: GenerationParameters["max_tokens"];
+ 		top_p?: GenerationParameters["top_p"];
+ 	}
+ ): InferenceSnippet | InferenceSnippet[] => {
	if (model.tags.includes("conversational")) {
		// Conversational model detected, so we display a code snippet that features the Messages API
- 		return {
- 			content: `import { HfInference } from "@huggingface/inference";
+ 		const streaming = opts?.streaming ?? true;
+ 		const messages: ChatCompletionInputMessage[] = opts?.messages ?? [
+ 			{ role: "user", content: "What is the capital of France?" },
+ 		];
+ 		const messagesStr = formatGenerationMessages({ messages, sep: ",\n\t\t", start: "[\n\t\t", end: "\n\t]" });

- const inference = new HfInference("${accessToken || `{API_TOKEN}`}");
+ 		const config = {
+ 			temperature: opts?.temperature,
+ 			max_tokens: opts?.max_tokens ?? 500,
+ 			top_p: opts?.top_p,
+ 		};
+ 		const configStr = formatGenerationConfig({ config, sep: ",\n\t", start: "", end: "" });

- for await (const chunk of inference.chatCompletionStream({
+ 		if (streaming) {
+ 			return [
+ 				{
+ 					client: "huggingface_hub",
+ 					content: `import { HfInference } from "@huggingface/inference"
+
+ const client = new HfInference("${accessToken || `{API_TOKEN}`}")
+
+ let out = "";
+
+ const stream = client.chatCompletionStream({
	model: "${model.id}",
- 	messages: [{ role: "user", content: "What is the capital of France?" }],
- 	max_tokens: 500,
- })) {
- 	process.stdout.write(chunk.choices[0]?.delta?.content || "");
+ 	messages: ${messagesStr},
+ 	${configStr}
+ });
+
+ for await (const chunk of stream) {
+ 	if (chunk.choices && chunk.choices.length > 0) {
+ 		const newContent = chunk.choices[0].delta.content;
+ 		out += newContent;
+ 		console.log(newContent);
+ 	}
}`,
- 		};
+ 				},
+ 				{
+ 					client: "openai",
+ 					content: `import { OpenAI } from "openai"
+
+ const client = new OpenAI({
+ 	baseURL: "https://api-inference.huggingface.co/v1/",
+ 	apiKey: "${accessToken || `{API_TOKEN}`}"
+ })
+
+ let out = "";
+
+ const stream = await client.chat.completions.create({
+ 	model: "${model.id}",
+ 	messages: ${messagesStr},
+ 	${configStr},
+ 	stream: true,
+ });
+
+ for await (const chunk of stream) {
+ 	if (chunk.choices && chunk.choices.length > 0) {
+ 		const newContent = chunk.choices[0].delta.content;
+ 		out += newContent;
+ 		console.log(newContent);
+ 	}
+ }`,
+ 				},
+ 			];
+ 		} else {
+ 			return [
+ 				{
+ 					client: "huggingface_hub",
+ 					content: `import { HfInference } from '@huggingface/inference'
+
+ const client = new HfInference("${accessToken || `{API_TOKEN}`}")
+
+ const chatCompletion = await client.chatCompletion({
+ 	model: "${model.id}",
+ 	messages: ${messagesStr},
+ 	${configStr}
+ });
+
+ console.log(chatCompletion.choices[0].message);`,
+ 				},
+ 				{
+ 					client: "openai",
+ 					content: `import { OpenAI } from "openai"
+
+ const client = new OpenAI({
+ 	baseURL: "https://api-inference.huggingface.co/v1/",
+ 	apiKey: "${accessToken || `{API_TOKEN}`}"
+ })
+
+ const chatCompletion = await client.chat.completions.create({
+ 	model: "${model.id}",
+ 	messages: ${messagesStr},
+ 	${configStr}
+ });
+
+ console.log(chatCompletion.choices[0].message);`,
+ 				},
+ 			];
+ 		}
	} else {
		return snippetBasic(model, accessToken);
	}
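
A rough caller-side sketch of the new signature (not part of this diff): the model literal, token, and import path below are made up for illustration, and `ModelDataMinimal` is abbreviated behind a cast.

import { snippetTextGeneration } from "./js.js"; // assumed module path
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";

// Hypothetical conversational model: only the fields the snippet builder reads are filled in.
const model = {
	id: "meta-llama/Llama-3.1-8B-Instruct",
	tags: ["conversational"],
} as unknown as ModelDataMinimal;

const result = snippetTextGeneration(model, "hf_xxx", { streaming: false, max_tokens: 256, temperature: 0.5 });

// Conversational models now return one snippet per client ("huggingface_hub", "openai");
// anything else still falls through to the single basic snippet.
const snippets: InferenceSnippet[] = Array.isArray(result) ? result : [result];
for (const snippet of snippets) {
	console.log(snippet.content);
}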
@@ -187,7 +299,11 @@ query(${getModelInputSnippet(model)}).then((response) => {
export const jsSnippets: Partial<
	Record<
		PipelineType,
- 		(model: ModelDataMinimal, accessToken: string, opts?: Record<string, string | boolean | number>) => InferenceSnippet
+ 		(
+ 			model: ModelDataMinimal,
+ 			accessToken: string,
+ 			opts?: Record<string, unknown>
+ 		) => InferenceSnippet | InferenceSnippet[]
	>
> = {
	// Same order as in js/src/lib/interfaces/Types.ts
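
Because each entry of this record can now return either a single snippet or an array, a consumer would normalize before rendering. A minimal sketch, assuming it lives in the same module as `jsSnippets` and that `ModelDataMinimal` exposes `pipeline_tag`; `getJsInferenceSnippets` is a hypothetical helper, not part of this diff.

// Hypothetical helper: look up the snippet builder for a pipeline and always hand back an array.
export function getJsInferenceSnippets(
	model: ModelDataMinimal,
	accessToken: string,
	opts?: Record<string, unknown>
): InferenceSnippet[] {
	const builder = model.pipeline_tag ? jsSnippets[model.pipeline_tag] : undefined;
	if (!builder) {
		return [];
	}
	const out = builder(model, accessToken, opts);
	return Array.isArray(out) ? out : [out];
}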