@@ -22,12 +22,14 @@ import * as path from "node:path/posix";
22
22
import { snippets } from "@huggingface/inference" ;
23
23
import type { SnippetInferenceProvider , InferenceSnippet , ModelDataMinimal } from "@huggingface/tasks" ;
24
24
25
- type LANGUAGE = "sh" | "js" | "py" ;
25
+ const LANGUAGES = [ "sh" , "js" , "python" ] as const ;
26
+ type Language = ( typeof LANGUAGES ) [ number ] ;
27
+ const EXTENSIONS : Record < Language , string > = { sh : "sh" , js : "js" , python : "py" } ;
26
28
27
29
const TEST_CASES : {
28
30
testName : string ;
29
31
model : ModelDataMinimal ;
30
- languages : LANGUAGE [ ] ;
32
+ languages : Language [ ] ;
31
33
providers : SnippetInferenceProvider [ ] ;
32
34
opts ?: Record < string , unknown > ;
33
35
} [ ] = [
@@ -39,7 +41,7 @@ const TEST_CASES: {
39
41
tags : [ ] ,
40
42
inference : "" ,
41
43
} ,
42
- languages : [ "py " ] ,
44
+ languages : [ "python " ] ,
43
45
providers : [ "hf-inference" ] ,
44
46
} ,
45
47
{
@@ -50,7 +52,7 @@ const TEST_CASES: {
50
52
tags : [ "conversational" ] ,
51
53
inference : "" ,
52
54
} ,
53
- languages : [ "sh" , "js" , "py " ] ,
55
+ languages : [ "sh" , "js" , "python " ] ,
54
56
providers : [ "hf-inference" , "together" ] ,
55
57
opts : { streaming : false } ,
56
58
} ,
@@ -62,7 +64,7 @@ const TEST_CASES: {
62
64
tags : [ "conversational" ] ,
63
65
inference : "" ,
64
66
} ,
65
- languages : [ "sh" , "js" , "py " ] ,
67
+ languages : [ "sh" , "js" , "python " ] ,
66
68
providers : [ "hf-inference" , "together" ] ,
67
69
opts : { streaming : true } ,
68
70
} ,
@@ -74,7 +76,7 @@ const TEST_CASES: {
74
76
tags : [ "conversational" ] ,
75
77
inference : "" ,
76
78
} ,
77
- languages : [ "sh" , "js" , "py " ] ,
79
+ languages : [ "sh" , "js" , "python " ] ,
78
80
providers : [ "hf-inference" , "fireworks-ai" ] ,
79
81
opts : { streaming : false } ,
80
82
} ,
@@ -86,7 +88,7 @@ const TEST_CASES: {
86
88
tags : [ "conversational" ] ,
87
89
inference : "" ,
88
90
} ,
89
- languages : [ "sh" , "js" , "py " ] ,
91
+ languages : [ "sh" , "js" , "python " ] ,
90
92
providers : [ "hf-inference" , "fireworks-ai" ] ,
91
93
opts : { streaming : true } ,
92
94
} ,
@@ -98,7 +100,7 @@ const TEST_CASES: {
98
100
tags : [ ] ,
99
101
inference : "" ,
100
102
} ,
101
- languages : [ "py " ] ,
103
+ languages : [ "python " ] ,
102
104
providers : [ "hf-inference" ] ,
103
105
} ,
104
106
{
@@ -109,7 +111,7 @@ const TEST_CASES: {
109
111
tags : [ ] ,
110
112
inference : "" ,
111
113
} ,
112
- languages : [ "py " ] ,
114
+ languages : [ "python " ] ,
113
115
providers : [ "hf-inference" ] ,
114
116
} ,
115
117
{
@@ -121,7 +123,7 @@ const TEST_CASES: {
121
123
inference : "" ,
122
124
} ,
123
125
providers : [ "hf-inference" ] ,
124
- languages : [ "py " ] ,
126
+ languages : [ "python " ] ,
125
127
} ,
126
128
{
127
129
testName : "text-to-audio-transformers" ,
@@ -132,7 +134,7 @@ const TEST_CASES: {
132
134
inference : "" ,
133
135
} ,
134
136
providers : [ "hf-inference" ] ,
135
- languages : [ "py " ] ,
137
+ languages : [ "python " ] ,
136
138
} ,
137
139
{
138
140
testName : "text-to-image" ,
@@ -143,7 +145,7 @@ const TEST_CASES: {
143
145
inference : "" ,
144
146
} ,
145
147
providers : [ "hf-inference" , "fal-ai" ] ,
146
- languages : [ "sh" , "js" , "py " ] ,
148
+ languages : [ "sh" , "js" , "python " ] ,
147
149
} ,
148
150
{
149
151
testName : "text-to-video" ,
@@ -154,7 +156,7 @@ const TEST_CASES: {
154
156
inference : "" ,
155
157
} ,
156
158
providers : [ "replicate" , "fal-ai" ] ,
157
- languages : [ "js" , "py " ] ,
159
+ languages : [ "js" , "python " ] ,
158
160
} ,
159
161
{
160
162
testName : "text-classification" ,
@@ -165,7 +167,7 @@ const TEST_CASES: {
165
167
inference : "" ,
166
168
} ,
167
169
providers : [ "hf-inference" ] ,
168
- languages : [ "sh" , "js" , "py " ] ,
170
+ languages : [ "sh" , "js" , "python " ] ,
169
171
} ,
170
172
{
171
173
testName : "basic-snippet--token-classification" ,
@@ -176,7 +178,7 @@ const TEST_CASES: {
176
178
inference : "" ,
177
179
} ,
178
180
providers : [ "hf-inference" ] ,
179
- languages : [ "py " ] ,
181
+ languages : [ "python " ] ,
180
182
} ,
181
183
{
182
184
testName : "zero-shot-classification" ,
@@ -187,7 +189,7 @@ const TEST_CASES: {
187
189
inference : "" ,
188
190
} ,
189
191
providers : [ "hf-inference" ] ,
190
- languages : [ "py " ] ,
192
+ languages : [ "python " ] ,
191
193
} ,
192
194
{
193
195
testName : "zero-shot-image-classification" ,
@@ -198,14 +200,14 @@ const TEST_CASES: {
198
200
inference : "" ,
199
201
} ,
200
202
providers : [ "hf-inference" ] ,
201
- languages : [ "py " ] ,
203
+ languages : [ "python " ] ,
202
204
} ,
203
205
] as const ;
204
206
205
207
const GET_SNIPPET_FN = {
206
208
sh : snippets . curl . getCurlInferenceSnippet ,
207
209
js : snippets . js . getJsInferenceSnippet ,
208
- py : snippets . python . getPythonInferenceSnippet ,
210
+ python : snippets . python . getPythonInferenceSnippet ,
209
211
} as const ;
210
212
211
213
const rootDirFinder = ( ) : string => {
@@ -228,42 +230,51 @@ function getFixtureFolder(testName: string): string {
228
230
229
231
function generateInferenceSnippet (
230
232
model : ModelDataMinimal ,
231
- language : LANGUAGE ,
233
+ language : Language ,
232
234
provider : SnippetInferenceProvider ,
233
235
opts ?: Record < string , unknown >
234
236
) : InferenceSnippet [ ] {
235
237
const providerModelId = provider === "hf-inference" ? model . id : `<${ provider } alias for ${ model . id } >` ;
236
- return GET_SNIPPET_FN [ language ] ( model , "api_token" , provider , providerModelId , opts ) ;
238
+ const snippets = GET_SNIPPET_FN [ language ] ( model , "api_token" , provider , providerModelId , opts ) as InferenceSnippet [ ] ;
239
+ return snippets . sort ( ( snippetA , snippetB ) => snippetA . client . localeCompare ( snippetB . client ) ) ;
237
240
}
238
241
239
242
async function getExpectedInferenceSnippet (
240
243
testName : string ,
241
- language : LANGUAGE ,
244
+ language : Language ,
242
245
provider : SnippetInferenceProvider
243
246
) : Promise < InferenceSnippet [ ] > {
244
247
const fixtureFolder = getFixtureFolder ( testName ) ;
245
- const files = await fs . readdir ( fixtureFolder ) ;
248
+ const languageFolder = path . join ( fixtureFolder , language ) ;
249
+ const files = await fs . readdir ( languageFolder , { recursive : true } ) ;
246
250
247
251
const expectedSnippets : InferenceSnippet [ ] = [ ] ;
248
- for ( const file of files . filter ( ( file ) => file . endsWith ( "." + language ) && file . includes ( `.${ provider } .` ) ) . sort ( ) ) {
249
- const client = path . basename ( file ) . split ( "." ) . slice ( 1 , - 2 ) . join ( "." ) ; // e.g. '0.huggingface.js.replicate.js' => "huggingface.js"
250
- const content = await fs . readFile ( path . join ( fixtureFolder , file ) , { encoding : "utf-8" } ) ;
252
+ for ( const file of files . filter ( ( file ) => file . includes ( `.${ provider } .` ) ) . sort ( ) ) {
253
+ const client = file . split ( "/" ) [ 0 ] ; // e.g. fal_client/1.fal-ai.python => fal_client
254
+ const content = await fs . readFile ( path . join ( languageFolder , file ) , { encoding : "utf-8" } ) ;
251
255
expectedSnippets . push ( { client, content } ) ;
252
256
}
253
257
return expectedSnippets ;
254
258
}
255
259
256
260
async function saveExpectedInferenceSnippet (
257
261
testName : string ,
258
- language : LANGUAGE ,
262
+ language : Language ,
259
263
provider : SnippetInferenceProvider ,
260
264
snippets : InferenceSnippet [ ]
261
265
) {
262
266
const fixtureFolder = getFixtureFolder ( testName ) ;
263
267
await fs . mkdir ( fixtureFolder , { recursive : true } ) ;
264
268
265
- for ( const [ index , snippet ] of snippets . entries ( ) ) {
266
- const file = path . join ( fixtureFolder , `${ index } .${ snippet . client ?? "default" } .${ provider } .${ language } ` ) ;
269
+ const indexPerClient = new Map < string , number > ( ) ;
270
+ for ( const snippet of snippets ) {
271
+ const extension = EXTENSIONS [ language ] ;
272
+ const client = snippet . client ;
273
+ const index = indexPerClient . get ( client ) ?? 0 ;
274
+ indexPerClient . set ( client , index + 1 ) ;
275
+
276
+ const file = path . join ( fixtureFolder , language , snippet . client , `${ index } .${ provider } .${ extension } ` ) ;
277
+ await fs . mkdir ( path . dirname ( file ) , { recursive : true } ) ;
267
278
await fs . writeFile ( file , snippet . content ) ;
268
279
}
269
280
}
0 commit comments