Skip to content

Commit c001002

Browse files
authored
[internal] refactor tests for inference snippets (#1287)
This PR does nothing other than update `packages/tasks-gen/scripts/generate-snippets-fixtures.ts`, which is an internal script used to test the inference snippets. The goal of this PR is to store generated snippets under a new file structure like this: ``` ./snippets-fixtures/automatic-speech-recognition/python/huggingface_hub/1.hf-inference.py ``` instead of ``` ./snippets-fixtures/automatic-speech-recognition/1.huggingface_hub.hf-inference.py ``` In practice, the previous file naming was annoying, as it meant that adding a new snippet for a client type could lead to renaming another file (due to the `0.`, `1.`, ... prefixes). --- Typically, in #1273 it made the PR much bigger by e.g. deleting [`1.openai.hf-inference.py`](https://github.com/huggingface/huggingface.js/pull/1273/files#diff-4759b74a67cc4caa7b2d273d7c2a8015ba062a19a8fad5cb2e227ca935dcb749) and creating [`2.openai.hf-inference.py`](https://github.com/huggingface/huggingface.js/pull/1273/files#diff-522e7173f8dd851189bb9b7ff311f4ee78ca65a3994caae803ff4fda5fe59733) just because a new [`1.requests.hf-inference.py`](https://github.com/huggingface/huggingface.js/pull/1273/files#diff-c8c5536f5af1631e8f1802155b66b0a23a4316eaaf5fcfce1a036da490acaa22) had been added. Separating files by language + client avoids these unnecessary problems.
1 parent 1f2d0b6 commit c001002

File tree

70 files changed

+39
-28
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

70 files changed

+39
-28
lines changed

packages/tasks-gen/scripts/generate-snippets-fixtures.ts

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@ import * as path from "node:path/posix";
2222
import { snippets } from "@huggingface/inference";
2323
import type { SnippetInferenceProvider, InferenceSnippet, ModelDataMinimal } from "@huggingface/tasks";
2424

25-
type LANGUAGE = "sh" | "js" | "py";
25+
const LANGUAGES = ["sh", "js", "python"] as const;
26+
type Language = (typeof LANGUAGES)[number];
27+
const EXTENSIONS: Record<Language, string> = { sh: "sh", js: "js", python: "py" };
2628

2729
const TEST_CASES: {
2830
testName: string;
2931
model: ModelDataMinimal;
30-
languages: LANGUAGE[];
32+
languages: Language[];
3133
providers: SnippetInferenceProvider[];
3234
opts?: Record<string, unknown>;
3335
}[] = [
@@ -39,7 +41,7 @@ const TEST_CASES: {
3941
tags: [],
4042
inference: "",
4143
},
42-
languages: ["py"],
44+
languages: ["python"],
4345
providers: ["hf-inference"],
4446
},
4547
{
@@ -50,7 +52,7 @@ const TEST_CASES: {
5052
tags: ["conversational"],
5153
inference: "",
5254
},
53-
languages: ["sh", "js", "py"],
55+
languages: ["sh", "js", "python"],
5456
providers: ["hf-inference", "together"],
5557
opts: { streaming: false },
5658
},
@@ -62,7 +64,7 @@ const TEST_CASES: {
6264
tags: ["conversational"],
6365
inference: "",
6466
},
65-
languages: ["sh", "js", "py"],
67+
languages: ["sh", "js", "python"],
6668
providers: ["hf-inference", "together"],
6769
opts: { streaming: true },
6870
},
@@ -74,7 +76,7 @@ const TEST_CASES: {
7476
tags: ["conversational"],
7577
inference: "",
7678
},
77-
languages: ["sh", "js", "py"],
79+
languages: ["sh", "js", "python"],
7880
providers: ["hf-inference", "fireworks-ai"],
7981
opts: { streaming: false },
8082
},
@@ -86,7 +88,7 @@ const TEST_CASES: {
8688
tags: ["conversational"],
8789
inference: "",
8890
},
89-
languages: ["sh", "js", "py"],
91+
languages: ["sh", "js", "python"],
9092
providers: ["hf-inference", "fireworks-ai"],
9193
opts: { streaming: true },
9294
},
@@ -98,7 +100,7 @@ const TEST_CASES: {
98100
tags: [],
99101
inference: "",
100102
},
101-
languages: ["py"],
103+
languages: ["python"],
102104
providers: ["hf-inference"],
103105
},
104106
{
@@ -109,7 +111,7 @@ const TEST_CASES: {
109111
tags: [],
110112
inference: "",
111113
},
112-
languages: ["py"],
114+
languages: ["python"],
113115
providers: ["hf-inference"],
114116
},
115117
{
@@ -121,7 +123,7 @@ const TEST_CASES: {
121123
inference: "",
122124
},
123125
providers: ["hf-inference"],
124-
languages: ["py"],
126+
languages: ["python"],
125127
},
126128
{
127129
testName: "text-to-audio-transformers",
@@ -132,7 +134,7 @@ const TEST_CASES: {
132134
inference: "",
133135
},
134136
providers: ["hf-inference"],
135-
languages: ["py"],
137+
languages: ["python"],
136138
},
137139
{
138140
testName: "text-to-image",
@@ -143,7 +145,7 @@ const TEST_CASES: {
143145
inference: "",
144146
},
145147
providers: ["hf-inference", "fal-ai"],
146-
languages: ["sh", "js", "py"],
148+
languages: ["sh", "js", "python"],
147149
},
148150
{
149151
testName: "text-to-video",
@@ -154,7 +156,7 @@ const TEST_CASES: {
154156
inference: "",
155157
},
156158
providers: ["replicate", "fal-ai"],
157-
languages: ["js", "py"],
159+
languages: ["js", "python"],
158160
},
159161
{
160162
testName: "text-classification",
@@ -165,7 +167,7 @@ const TEST_CASES: {
165167
inference: "",
166168
},
167169
providers: ["hf-inference"],
168-
languages: ["sh", "js", "py"],
170+
languages: ["sh", "js", "python"],
169171
},
170172
{
171173
testName: "basic-snippet--token-classification",
@@ -176,7 +178,7 @@ const TEST_CASES: {
176178
inference: "",
177179
},
178180
providers: ["hf-inference"],
179-
languages: ["py"],
181+
languages: ["python"],
180182
},
181183
{
182184
testName: "zero-shot-classification",
@@ -187,7 +189,7 @@ const TEST_CASES: {
187189
inference: "",
188190
},
189191
providers: ["hf-inference"],
190-
languages: ["py"],
192+
languages: ["python"],
191193
},
192194
{
193195
testName: "zero-shot-image-classification",
@@ -198,14 +200,14 @@ const TEST_CASES: {
198200
inference: "",
199201
},
200202
providers: ["hf-inference"],
201-
languages: ["py"],
203+
languages: ["python"],
202204
},
203205
] as const;
204206

205207
const GET_SNIPPET_FN = {
206208
sh: snippets.curl.getCurlInferenceSnippet,
207209
js: snippets.js.getJsInferenceSnippet,
208-
py: snippets.python.getPythonInferenceSnippet,
210+
python: snippets.python.getPythonInferenceSnippet,
209211
} as const;
210212

211213
const rootDirFinder = (): string => {
@@ -228,42 +230,51 @@ function getFixtureFolder(testName: string): string {
228230

229231
function generateInferenceSnippet(
230232
model: ModelDataMinimal,
231-
language: LANGUAGE,
233+
language: Language,
232234
provider: SnippetInferenceProvider,
233235
opts?: Record<string, unknown>
234236
): InferenceSnippet[] {
235237
const providerModelId = provider === "hf-inference" ? model.id : `<${provider} alias for ${model.id}>`;
236-
return GET_SNIPPET_FN[language](model, "api_token", provider, providerModelId, opts);
238+
const snippets = GET_SNIPPET_FN[language](model, "api_token", provider, providerModelId, opts) as InferenceSnippet[];
239+
return snippets.sort((snippetA, snippetB) => snippetA.client.localeCompare(snippetB.client));
237240
}
238241

239242
async function getExpectedInferenceSnippet(
240243
testName: string,
241-
language: LANGUAGE,
244+
language: Language,
242245
provider: SnippetInferenceProvider
243246
): Promise<InferenceSnippet[]> {
244247
const fixtureFolder = getFixtureFolder(testName);
245-
const files = await fs.readdir(fixtureFolder);
248+
const languageFolder = path.join(fixtureFolder, language);
249+
const files = await fs.readdir(languageFolder, { recursive: true });
246250

247251
const expectedSnippets: InferenceSnippet[] = [];
248-
for (const file of files.filter((file) => file.endsWith("." + language) && file.includes(`.${provider}.`)).sort()) {
249-
const client = path.basename(file).split(".").slice(1, -2).join("."); // e.g. '0.huggingface.js.replicate.js' => "huggingface.js"
250-
const content = await fs.readFile(path.join(fixtureFolder, file), { encoding: "utf-8" });
252+
for (const file of files.filter((file) => file.includes(`.${provider}.`)).sort()) {
253+
const client = file.split("/")[0]; // e.g. fal_client/1.fal-ai.python => fal_client
254+
const content = await fs.readFile(path.join(languageFolder, file), { encoding: "utf-8" });
251255
expectedSnippets.push({ client, content });
252256
}
253257
return expectedSnippets;
254258
}
255259

256260
async function saveExpectedInferenceSnippet(
257261
testName: string,
258-
language: LANGUAGE,
262+
language: Language,
259263
provider: SnippetInferenceProvider,
260264
snippets: InferenceSnippet[]
261265
) {
262266
const fixtureFolder = getFixtureFolder(testName);
263267
await fs.mkdir(fixtureFolder, { recursive: true });
264268

265-
for (const [index, snippet] of snippets.entries()) {
266-
const file = path.join(fixtureFolder, `${index}.${snippet.client ?? "default"}.${provider}.${language}`);
269+
const indexPerClient = new Map<string, number>();
270+
for (const snippet of snippets) {
271+
const extension = EXTENSIONS[language];
272+
const client = snippet.client;
273+
const index = indexPerClient.get(client) ?? 0;
274+
indexPerClient.set(client, index + 1);
275+
276+
const file = path.join(fixtureFolder, language, snippet.client, `${index}.${provider}.${extension}`);
277+
await fs.mkdir(path.dirname(file), { recursive: true });
267278
await fs.writeFile(file, snippet.content);
268279
}
269280
}

0 commit comments

Comments
 (0)