1
1
import type { PipelineType } from "../pipelines.js" ;
2
+ import type { ChatCompletionInputMessage , GenerationParameters } from "../tasks/index.js" ;
2
3
import { getModelInputSnippet } from "./inputs.js" ;
3
- import type { InferenceSnippet , ModelDataMinimal } from "./types.js" ;
4
+ import type {
5
+ GenerationConfigFormatter ,
6
+ GenerationMessagesFormatter ,
7
+ InferenceSnippet ,
8
+ ModelDataMinimal ,
9
+ } from "./types.js" ;
4
10
5
11
export const snippetBasic = ( model : ModelDataMinimal , accessToken : string ) : InferenceSnippet => ( {
6
12
content : `curl https://api-inference.huggingface.co/models/${ model . id } \\
@@ -10,20 +16,58 @@ export const snippetBasic = (model: ModelDataMinimal, accessToken: string): Infe
10
16
-H "Authorization: Bearer ${ accessToken || `{API_TOKEN}` } "` ,
11
17
} ) ;
12
18
13
- export const snippetTextGeneration = ( model : ModelDataMinimal , accessToken : string ) : InferenceSnippet => {
19
+ const formatGenerationMessages : GenerationMessagesFormatter = ( { messages, sep, start, end } ) =>
20
+ start +
21
+ messages
22
+ . map ( ( { role, content } ) => {
23
+ // escape single quotes since single quotes is used to define http post body inside curl requests
24
+ // TODO: handle the case below
25
+ content = content ?. replace ( / ' / g, "'\\''" ) ;
26
+ return `{ "role": "${ role } ", "content": "${ content } " }` ;
27
+ } )
28
+ . join ( sep ) +
29
+ end ;
30
+
31
+ const formatGenerationConfig : GenerationConfigFormatter = ( { config, sep, start, end } ) =>
32
+ start +
33
+ Object . entries ( config )
34
+ . map ( ( [ key , val ] ) => `"${ key } ": ${ val } ` )
35
+ . join ( sep ) +
36
+ end ;
37
+
38
+ export const snippetTextGeneration = (
39
+ model : ModelDataMinimal ,
40
+ accessToken : string ,
41
+ opts ?: {
42
+ streaming ?: boolean ;
43
+ messages ?: ChatCompletionInputMessage [ ] ;
44
+ temperature ?: GenerationParameters [ "temperature" ] ;
45
+ max_tokens ?: GenerationParameters [ "max_tokens" ] ;
46
+ top_p ?: GenerationParameters [ "top_p" ] ;
47
+ }
48
+ ) : InferenceSnippet => {
14
49
if ( model . tags . includes ( "conversational" ) ) {
15
50
// Conversational model detected, so we display a code snippet that features the Messages API
51
+ const streaming = opts ?. streaming ?? true ;
52
+ const messages : ChatCompletionInputMessage [ ] = opts ?. messages ?? [
53
+ { role : "user" , content : "What is the capital of France?" } ,
54
+ ] ;
55
+
56
+ const config = {
57
+ temperature : opts ?. temperature ,
58
+ max_tokens : opts ?. max_tokens ?? 500 ,
59
+ top_p : opts ?. top_p ,
60
+ } ;
16
61
return {
17
62
content : `curl 'https://api-inference.huggingface.co/models/${ model . id } /v1/chat/completions' \\
18
63
-H "Authorization: Bearer ${ accessToken || `{API_TOKEN}` } " \\
19
64
-H 'Content-Type: application/json' \\
20
- -d '{
21
- "model": "${ model . id } ",
22
- "messages": [{"role": "user", "content": "What is the capital of France?"}],
23
- "max_tokens": 500,
24
- "stream": false
25
- }'
26
- ` ,
65
+ --data '{
66
+ "model": "${ model . id } ",
67
+ "messages": ${ formatGenerationMessages ( { messages, sep : ",\n " , start : `[\n ` , end : `\n]` } ) } ,
68
+ ${ formatGenerationConfig ( { config, sep : ",\n " , start : "" , end : "" } ) } ,
69
+ "stream": ${ ! ! streaming }
70
+ }'` ,
27
71
} ;
28
72
} else {
29
73
return snippetBasic ( model , accessToken ) ;
@@ -76,7 +120,7 @@ export const snippetFile = (model: ModelDataMinimal, accessToken: string): Infer
76
120
export const curlSnippets : Partial <
77
121
Record <
78
122
PipelineType ,
79
- ( model : ModelDataMinimal , accessToken : string , opts ?: Record < string , string | boolean | number > ) => InferenceSnippet
123
+ ( model : ModelDataMinimal , accessToken : string , opts ?: Record < string , unknown > ) => InferenceSnippet
80
124
>
81
125
> = {
82
126
// Same order as in js/src/lib/interfaces/Types.ts
0 commit comments