
Commit 43b9364

Wauplin and SBrandeis authored
Use makeRequestOptions to generate inference snippets (#1273)
The broader goal of this PR is to use `makeRequestOptions` from the JS InferenceClient to get all the implementation details right (correct URL, correct authorization header, correct payload, etc.). The JS InferenceClient is the ground truth here.

**In practice:**

- fixed `makeUrl` when `chatCompletion` + `image-text-to-text` (review [here](https://github.com/huggingface/huggingface.js/pull/1273/files#diff-a6509c908fd0fb05fdbd3803492d6e9e2570d6dff2a21db575a76b26bff4d565) + other providers)
- fixed the wrong URL in the `openai` Python snippet (e.g. [here](https://github.com/huggingface/huggingface.js/pull/1273/files#diff-338b930b960057f85d0d5dd27032b73cd4834a6a7c7ce4db60af19395b8e56f9), [here](https://github.com/huggingface/huggingface.js/pull/1273/files#diff-a253bcdfdf33df1ac53a4051a8ce7bb047a99f32618c048754d617cb55815c14))
- fixed the DQA `requests` snippet ([here](https://github.com/huggingface/huggingface.js/pull/1273/files#diff-3a47136351b4572144f2fd42a2518da9be108b66fd5dca392d9d899a125b02d9))

**Technically, this PR:**

- splits `makeRequestOptions` in two parts: the async part that does the model ID resolution (depending on task + provider) and the sync part that generates the URL, headers, body, etc. For snippets we only need the second part, which is a sync call => new (internal) method `makeRequestOptionsFromResolvedModel`
- moves most of the logic inside `snippetGenerator`
  - the logic is: _get inputs_ => _make request options_ => _prepare template data_ => _iterate over clients_ => _generate snippets_
  - **Next:** now that the logic is unified, adapting cURL and JS to use the same logic should be fairly easy (e.g. "just" need to create the Jinja templates)
  - the final goal is to handle all languages/clients/providers with the same code and only swap the templates
- updates most providers to use the `/chat/completions` endpoint when `chatCompletion` is enabled
  - previously we also checked that the task is `text-generation`; now `/chat/completions` is also used for `image-text-to-text` models
  - that was mostly a bug in the existing codebase, detected thanks to the snippets
- updates `./packages/inference/package.json` to allow dev mode. Running `pnpm run dev` in `@inference` makes it much easier to work with `@tasks-gen` (no need to rebuild after each change)

---

**EDIT:** ~~there is definitely a breaking change in how I handle the `makeRequestOptions` split (hence the broken CI). Will fix this.~~ => fixed.

---------

Co-authored-by: Simon Brandeis <[email protected]>
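The unified flow described above (get inputs => make request options => prepare template data => iterate over clients => generate snippets) can be sketched as follows. This is a minimal illustration, not the actual `snippetGenerator`: the function names mirror the PR, but the base URL, argument shapes, and inline "templates" are hypothetical stand-ins.

```typescript
type RequestOptions = { url: string; headers: Record<string, string>; body: unknown };

// Stand-in for the sync makeRequestOptionsFromResolvedModel: no network calls,
// just URL/header/body construction from an already-resolved model ID.
// The base URL is illustrative; the /models and /v1/chat/completions path shapes
// mirror the hf-inference provider.
function makeRequestOptionsFromResolvedModel(
	resolvedModel: string,
	args: { accessToken: string; inputs: unknown; chatCompletion?: boolean }
): RequestOptions {
	const baseUrl = "https://example.test"; // illustrative
	const base = `${baseUrl}/models/${resolvedModel}`;
	return {
		url: args.chatCompletion ? `${base}/v1/chat/completions` : base,
		headers: { Authorization: `Bearer ${args.accessToken}` },
		body: { inputs: args.inputs },
	};
}

function generateSnippets(
	resolvedModel: string,
	args: { accessToken: string; inputs: unknown; chatCompletion?: boolean },
	clients: string[]
): Record<string, string> {
	// 1. make request options: the single source of truth for URL/headers/body
	const options = makeRequestOptionsFromResolvedModel(resolvedModel, args);
	// 2. prepare template data shared by every client template
	const templateData = { url: options.url, body: JSON.stringify(options.body) };
	// 3. iterate over clients and render one snippet each (real templates are Jinja)
	return Object.fromEntries(
		clients.map((client) => [client, `# ${client}\nPOST ${templateData.url}\n${templateData.body}`])
	);
}
```

Because every client snippet is rendered from the same request options, a URL fix in the client automatically propagates to all generated snippets.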
1 parent e5e8eb6 commit 43b9364

File tree: 141 files changed (+1938, −1725 lines)


packages/inference/package.json

Lines changed: 2 additions & 1 deletion
```diff
@@ -56,7 +56,8 @@
 		"prepublishOnly": "pnpm run build",
 		"test": "vitest run --config vitest.config.mts",
 		"test:browser": "vitest run --browser.name=chrome --browser.headless --config vitest.config.mts",
-		"check": "tsc"
+		"check": "tsc",
+		"dev": "tsup src/index.ts --format cjs,esm --watch"
 	},
 	"dependencies": {
 		"@huggingface/tasks": "workspace:^",
```

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 43 additions & 20 deletions
```diff
@@ -45,7 +45,8 @@ const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
 };
 
 /**
- * Helper that prepares request arguments
+ * Helper that prepares request arguments.
+ * This async version handle the model ID resolution step.
  */
 export async function makeRequestOptions(
 	args: RequestArgs & {
@@ -56,17 +57,15 @@ export async function makeRequestOptions(
 		/** In most cases (unless we pass a endpointUrl) we know the task */
 		task?: InferenceTask;
 		chatCompletion?: boolean;
-		/* Used internally to generate inference snippets (in which case model mapping is done separately) */
-		skipModelIdResolution?: boolean;
 	}
 ): Promise<{ url: string; info: RequestInit }> {
-	const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
+	const { provider: maybeProvider, model: maybeModel } = args;
 	const provider = maybeProvider ?? "hf-inference";
 	const providerConfig = providerConfigs[provider];
+	const { task, chatCompletion } = options ?? {};
 
-	const { includeCredentials, task, chatCompletion, signal, skipModelIdResolution } = options ?? {};
-
-	if (endpointUrl && provider !== "hf-inference") {
+	// Validate inputs
+	if (args.endpointUrl && provider !== "hf-inference") {
 		throw new Error(`Cannot use endpointUrl with a third-party provider.`);
 	}
 	if (maybeModel && isUrl(maybeModel)) {
@@ -81,19 +80,43 @@
 	if (providerConfig.clientSideRoutingOnly && !maybeModel) {
 		throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
 	}
+
 	// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
 	const hfModel = maybeModel ?? (await loadDefaultModel(task!));
-	const model = skipModelIdResolution
-		? hfModel
-		: providerConfig.clientSideRoutingOnly
-		  ? // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
-		    removeProviderPrefix(maybeModel!, provider)
-		  : // For closed-models API providers, one needs to pass the model ID directly (e.g. "gpt-3.5-turbo")
-		    await getProviderModelId({ model: hfModel, provider }, args, {
-				task,
-				chatCompletion,
-				fetch: options?.fetch,
-		    });
+	const resolvedModel = providerConfig.clientSideRoutingOnly
+		? // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+		  removeProviderPrefix(maybeModel!, provider)
+		: await getProviderModelId({ model: hfModel, provider }, args, {
+				task,
+				chatCompletion,
+				fetch: options?.fetch,
+		  });
+
+	// Use the sync version with the resolved model
+	return makeRequestOptionsFromResolvedModel(resolvedModel, args, options);
+}
+
+/**
+ * Helper that prepares request arguments. - for internal use only
+ * This sync version skips the model ID resolution step
+ */
+export function makeRequestOptionsFromResolvedModel(
+	resolvedModel: string,
+	args: RequestArgs & {
+		data?: Blob | ArrayBuffer;
+		stream?: boolean;
+	},
+	options?: Options & {
+		task?: InferenceTask;
+		chatCompletion?: boolean;
+	}
+): { url: string; info: RequestInit } {
+	const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
+
+	const provider = maybeProvider ?? "hf-inference";
+	const providerConfig = providerConfigs[provider];
+
+	const { includeCredentials, task, chatCompletion, signal } = options ?? {};
 
 	const authMethod = (() => {
 		if (providerConfig.clientSideRoutingOnly) {
@@ -123,7 +146,7 @@
 		authMethod !== "provider-key"
 			? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider)
 			: providerConfig.baseUrl,
-		model,
+		model: resolvedModel,
 		chatCompletion,
 		task,
 	});
@@ -154,7 +177,7 @@
 		: JSON.stringify(
 			providerConfig.makeBody({
 				args: remainingArgs as Record<string, unknown>,
-				model,
+				model: resolvedModel,
 				task,
 				chatCompletion,
 			})
```
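The shape of the split above can be illustrated in isolation: the async wrapper only performs the model ID resolution, then delegates to a sync function that snippet generation can call directly. This is a simplified sketch with a hypothetical mapping table and an illustrative base URL, not the library's implementation.

```typescript
type Opts = { chatCompletion?: boolean };

// Stand-in for getProviderModelId: map a Hub model ID to the provider's own ID.
// The mapping entry here is purely illustrative.
async function resolveModelId(model: string, _provider: string): Promise<string> {
	const mapping: Record<string, string> = { "meta-llama/Llama-3.1-8B-Instruct": "llama-3.1-8b" };
	return mapping[model] ?? model;
}

// Sync part: builds the URL from an already-resolved model ID, so it can be
// called without awaiting anything (which is all the snippet generator needs).
function makeRequestOptionsFromResolvedModel(resolvedModel: string, opts: Opts): { url: string } {
	const baseUrl = "https://example-provider.test"; // illustrative
	return {
		url: opts.chatCompletion ? `${baseUrl}/v1/chat/completions` : `${baseUrl}/models/${resolvedModel}`,
	};
}

// Async part: resolution is the only awaited step, then delegate to the sync part.
async function makeRequestOptions(model: string, provider: string, opts: Opts): Promise<{ url: string }> {
	const resolvedModel = await resolveModelId(model, provider);
	return makeRequestOptionsFromResolvedModel(resolvedModel, opts);
}
```

Keeping the sync part free of network calls is what lets the snippet generator reuse it without duplicating any URL/header/body logic.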

packages/inference/src/providers/fireworks-ai.ts

Lines changed: 1 addition & 1 deletion
```diff
@@ -30,7 +30,7 @@ const makeHeaders = (params: HeaderParams): Record<string, string> => {
 };
 
 const makeUrl = (params: UrlParams): string => {
-	if (params.task === "text-generation" && params.chatCompletion) {
+	if (params.chatCompletion) {
 		return `${params.baseUrl}/inference/v1/chat/completions`;
 	}
 	return `${params.baseUrl}/inference`;
```

packages/inference/src/providers/hf-inference.ts

Lines changed: 1 addition & 1 deletion
```diff
@@ -29,7 +29,7 @@ const makeUrl = (params: UrlParams): string => {
 		/// when deployed on hf-inference, those two tasks are automatically compatible with one another.
 		return `${params.baseUrl}/pipeline/${params.task}/${params.model}`;
 	}
-	if (params.task === "text-generation" && params.chatCompletion) {
+	if (params.chatCompletion) {
 		return `${params.baseUrl}/models/${params.model}/v1/chat/completions`;
 	}
 	return `${params.baseUrl}/models/${params.model}`;
```

packages/inference/src/providers/nebius.ts

Lines changed: 3 additions & 3 deletions
```diff
@@ -33,10 +33,10 @@ const makeUrl = (params: UrlParams): string => {
 	if (params.task === "text-to-image") {
 		return `${params.baseUrl}/v1/images/generations`;
 	}
+	if (params.chatCompletion) {
+		return `${params.baseUrl}/v1/chat/completions`;
+	}
 	if (params.task === "text-generation") {
-		if (params.chatCompletion) {
-			return `${params.baseUrl}/v1/chat/completions`;
-		}
 		return `${params.baseUrl}/v1/completions`;
 	}
 	return params.baseUrl;
```
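The effect of reordering these checks is that `chatCompletion` now takes precedence over the task, so chat-enabled `image-text-to-text` models route to `/v1/chat/completions` instead of falling through to the default base URL. A self-contained sketch of the reordered `makeUrl` (with `UrlParams` reduced to the fields this example needs, and an illustrative base URL):

```typescript
type UrlParams = { baseUrl: string; task?: string; chatCompletion?: boolean };

// Mirrors the new ordering: chatCompletion is checked before the task-based
// branches, so any chat-enabled model gets the chat-completions route.
const makeUrl = (params: UrlParams): string => {
	if (params.task === "text-to-image") {
		return `${params.baseUrl}/v1/images/generations`;
	}
	if (params.chatCompletion) {
		return `${params.baseUrl}/v1/chat/completions`;
	}
	if (params.task === "text-generation") {
		return `${params.baseUrl}/v1/completions`;
	}
	return params.baseUrl;
};
```

With the old ordering, `{ task: "image-text-to-text", chatCompletion: true }` would have returned the bare base URL; with the new ordering it returns the chat-completions route.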

packages/inference/src/providers/novita.ts

Lines changed: 3 additions & 4 deletions
```diff
@@ -30,10 +30,9 @@ const makeHeaders = (params: HeaderParams): Record<string, string> => {
 };
 
 const makeUrl = (params: UrlParams): string => {
-	if (params.task === "text-generation") {
-		if (params.chatCompletion) {
-			return `${params.baseUrl}/v3/openai/chat/completions`;
-		}
+	if (params.chatCompletion) {
+		return `${params.baseUrl}/v3/openai/chat/completions`;
+	} else if (params.task === "text-generation") {
 		return `${params.baseUrl}/v3/openai/completions`;
 	} else if (params.task === "text-to-video") {
 		return `${params.baseUrl}/v3/hf/${params.model}`;
```

packages/inference/src/providers/sambanova.ts

Lines changed: 1 addition & 1 deletion
```diff
@@ -30,7 +30,7 @@ const makeHeaders = (params: HeaderParams): Record<string, string> => {
 };
 
 const makeUrl = (params: UrlParams): string => {
-	if (params.task === "text-generation" && params.chatCompletion) {
+	if (params.chatCompletion) {
 		return `${params.baseUrl}/v1/chat/completions`;
 	}
 	return params.baseUrl;
```

packages/inference/src/providers/together.ts

Lines changed: 3 additions & 3 deletions
```diff
@@ -33,10 +33,10 @@ const makeUrl = (params: UrlParams): string => {
 	if (params.task === "text-to-image") {
 		return `${params.baseUrl}/v1/images/generations`;
 	}
+	if (params.chatCompletion) {
+		return `${params.baseUrl}/v1/chat/completions`;
+	}
 	if (params.task === "text-generation") {
-		if (params.chatCompletion) {
-			return `${params.baseUrl}/v1/chat/completions`;
-		}
 		return `${params.baseUrl}/v1/completions`;
 	}
 	return params.baseUrl;
```

packages/inference/src/snippets/curl.ts

Lines changed: 0 additions & 177 deletions
This file was deleted.
