@@ -28,7 +28,7 @@ export const snippetConversational = (
		max_tokens?: GenerationParameters["max_tokens"];
		top_p?: GenerationParameters["top_p"];
	}
-): InferenceSnippet => {
+): InferenceSnippet[] => {
	const streaming = opts?.streaming ?? true;
	const messages: ChatCompletionInputMessage[] = opts?.messages ?? [
		{ role: "user", content: "What is the capital of France?" },
@@ -41,8 +41,10 @@ export const snippetConversational = (
	};

	if (streaming) {
-		return {
-			content: `from huggingface_hub import InferenceClient
+		return [
+			{
+				client: "huggingface_hub",
+				content: `from huggingface_hub import InferenceClient

client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}")

@@ -57,10 +59,34 @@ stream = client.chat.completions.create(

for chunk in stream:
	print(chunk.choices[0].delta.content)`,
-		};
+			},
+			{
+				client: "openai",
+				content: `from openai import OpenAI
+
+client = OpenAI(
+	base_url="https://api-inference.huggingface.co/v1/",
+	api_key="${accessToken || "{API_TOKEN}"}"
+)
+
+messages = ${formatGenerationMessages({ messages, sep: ",\n\t", start: `[\n\t`, end: `\n]` })}
+
+stream = client.chat.completions.create(
+	model="${model.id}",
+	messages=messages,
+	${formatGenerationConfig({ config, sep: ",\n\t", start: "", end: "", connector: "=" })},
+	stream=True
+)
+
+for chunk in stream:
+	print(chunk.choices[0].delta.content)`,
+			},
+		];
	} else {
-		return {
-			content: `from huggingface_hub import InferenceClient
+		return [
+			{
+				client: "huggingface_hub",
+				content: `from huggingface_hub import InferenceClient

client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}")

@@ -73,7 +99,27 @@ completion = client.chat.completions.create(
)

print(completion.choices[0].message)`,
-		};
+			},
+			{
+				client: "openai",
+				content: `from openai import OpenAI
+
+client = OpenAI(
+	base_url="https://api-inference.huggingface.co/v1/",
+	api_key="${accessToken || "{API_TOKEN}"}"
+)
+
+messages = ${formatGenerationMessages({ messages, sep: ",\n\t", start: `[\n\t`, end: `\n]` })}
+
+completion = client.chat.completions.create(
+	model="${model.id}",
+	messages=messages,
+	${formatGenerationConfig({ config, sep: ",\n\t", start: "", end: "", connector: "=" })}
+)
+
+print(completion.choices[0].message)`,
+			},
+		];
	}
};

@@ -220,7 +266,11 @@ output = query({
export const pythonSnippets: Partial<
	Record<
		PipelineType,
-		(model: ModelDataMinimal, accessToken: string, opts?: Record<string, unknown>) => InferenceSnippet
+		(
+			model: ModelDataMinimal,
+			accessToken: string,
+			opts?: Record<string, unknown>
+		) => InferenceSnippet | InferenceSnippet[]
	>
> = {
	// Same order as in tasks/src/pipelines.ts
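
Note on the return-type change: `snippetConversational` now returns `InferenceSnippet[]` (one entry per client), while the other pipeline snippets still return a single `InferenceSnippet`, which is why the `pythonSnippets` value type is widened to the union. Below is a minimal sketch of how a caller might normalize the two shapes before rendering; the `InferenceSnippet` interface here is a simplified stand-in for the real type, and `toSnippetArray` is a hypothetical helper, not part of this diff.

```ts
// Simplified stand-in for the real InferenceSnippet type in @huggingface/tasks;
// `client` is the new optional field introduced by this diff.
interface InferenceSnippet {
	content: string;
	client?: string; // e.g. "huggingface_hub" or "openai"
}

// Hypothetical helper: wrap single snippets so downstream code always
// iterates over an array, regardless of which pipeline produced the result.
function toSnippetArray(result: InferenceSnippet | InferenceSnippet[]): InferenceSnippet[] {
	return Array.isArray(result) ? result : [result];
}

// Example usage with a conversational-style result (two client variants).
const result: InferenceSnippet | InferenceSnippet[] = [
	{ client: "huggingface_hub", content: "from huggingface_hub import InferenceClient" },
	{ client: "openai", content: "from openai import OpenAI" },
];
for (const snippet of toSnippetArray(result)) {
	console.log(`--- ${snippet.client ?? "default"} ---\n${snippet.content}`);
}
```

Normalizing this way lets the non-conversational snippets keep working unchanged while the conversational task exposes both the `huggingface_hub` and `openai` variants, for example as per-client tabs.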