[Inference] request() returns a request context to avoid redundant makeRequestOptions calls #1314

Merged · 9 commits · Apr 2, 2025
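
The gist of the change: task helpers previously went through the `request()` primitive, which rebuilt its options via `makeRequestOptions` on every call (including its own 503 retries). The new shared `innerRequest` helper in `src/utils/request.ts` resolves the options once and returns the parsed data together with the request context it used. Below is a minimal, self-contained sketch of that pattern; the `RequestContext` shape and field names are assumptions for illustration, since the diffs only show that `innerRequest` resolves to an object with a `data` field.

```typescript
// Sketch only (not the package's actual implementation): resolve the request
// options once, fetch, and hand back both the parsed body and the context it
// was fetched with, so callers never rebuild the options a second time.
type RequestContext = { url: string; info: RequestInit }; // assumed shape

interface InnerRequestResult<T> {
  data: T;
  requestContext: RequestContext; // kept for reuse instead of re-deriving it
}

async function innerRequestSketch<T>(
  makeOptions: () => Promise<RequestContext> // stand-in for makeRequestOptions
): Promise<InnerRequestResult<T>> {
  const requestContext = await makeOptions(); // resolved exactly once
  const response = await fetch(requestContext.url, requestContext.info);
  if (!response.ok) {
    throw new Error(`Request failed with status ${response.status}`);
  }
  const data = response.headers.get("Content-Type")?.startsWith("application/json")
    ? ((await response.json()) as T)
    : ((await response.blob()) as T);
  return { data, requestContext };
}
```

Each task function below then unwraps `.data`, and the deprecated `request`/`streamingRequest` primitives become thin wrappers over the same helpers.
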
25 changes: 0 additions & 25 deletions packages/inference/README.md
@@ -572,31 +572,6 @@ await hf.tabularClassification({
})
```

-## Custom Calls
-
-For models with custom parameters / outputs.
-
-```typescript
-await hf.request({
-  model: 'my-custom-model',
-  inputs: 'hello world',
-  parameters: {
-    custom_param: 'some magic',
-  }
-})
-
-// Custom streaming call, for models with custom parameters / outputs
-for await (const output of hf.streamingRequest({
-  model: 'my-custom-model',
-  inputs: 'hello world',
-  parameters: {
-    custom_param: 'some magic',
-  }
-})) {
-  ...
-}
-```
-
You can use any Chat Completion API-compatible provider with the `chatCompletion` method.

12 changes: 7 additions & 5 deletions packages/inference/src/tasks/audio/audioClassification.ts
@@ -1,7 +1,7 @@
import type { AudioClassificationInput, AudioClassificationOutput } from "@huggingface/tasks";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options } from "../../types";
-import { request } from "../custom/request";
+import { innerRequest } from "../../utils/request";
import type { LegacyAudioInput } from "./utils";
import { preparePayload } from "./utils";

@@ -16,10 +16,12 @@ export async function audioClassification(
options?: Options
): Promise<AudioClassificationOutput> {
const payload = preparePayload(args);
-  const res = await request<AudioClassificationOutput>(payload, {
-    ...options,
-    task: "audio-classification",
-  });
+  const res = (
+    await innerRequest<AudioClassificationOutput>(payload, {
+      ...options,
+      task: "audio-classification",
+    })
+  ).data;
const isValidOutput =
Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
if (!isValidOutput) {
12 changes: 7 additions & 5 deletions packages/inference/src/tasks/audio/audioToAudio.ts
@@ -1,6 +1,6 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options } from "../../types";
-import { request } from "../custom/request";
+import { innerRequest } from "../../utils/request";
import type { LegacyAudioInput } from "./utils";
import { preparePayload } from "./utils";

@@ -37,10 +37,12 @@ export interface AudioToAudioOutput {
*/
export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
const payload = preparePayload(args);
-  const res = await request<AudioToAudioOutput>(payload, {
-    ...options,
-    task: "audio-to-audio",
-  });
+  const res = (
+    await innerRequest<AudioToAudioOutput>(payload, {
+      ...options,
+      task: "audio-to-audio",
+    })
+  ).data;

return validateOutput(res);
}
14 changes: 8 additions & 6 deletions packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
@@ -2,10 +2,10 @@ import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options, RequestArgs } from "../../types";
import { base64FromBytes } from "../../utils/base64FromBytes";
-import { request } from "../custom/request";
+import { omit } from "../../utils/omit";
+import { innerRequest } from "../../utils/request";
import type { LegacyAudioInput } from "./utils";
import { preparePayload } from "./utils";
-import { omit } from "../../utils/omit";

export type AutomaticSpeechRecognitionArgs = BaseArgs & (AutomaticSpeechRecognitionInput | LegacyAudioInput);
/**
@@ -17,10 +17,12 @@ export async function automaticSpeechRecognition(
options?: Options
): Promise<AutomaticSpeechRecognitionOutput> {
const payload = await buildPayload(args);
-  const res = await request<AutomaticSpeechRecognitionOutput>(payload, {
-    ...options,
-    task: "automatic-speech-recognition",
-  });
+  const res = (
+    await innerRequest<AutomaticSpeechRecognitionOutput>(payload, {
+      ...options,
+      task: "automatic-speech-recognition",
+    })
+  ).data;
const isValidOutput = typeof res?.text === "string";
if (!isValidOutput) {
throw new InferenceOutputError("Expected {text: string}");
12 changes: 7 additions & 5 deletions packages/inference/src/tasks/audio/textToSpeech.ts
@@ -2,7 +2,7 @@ import type { TextToSpeechInput } from "@huggingface/tasks";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options } from "../../types";
import { omit } from "../../utils/omit";
-import { request } from "../custom/request";
+import { innerRequest } from "../../utils/request";
type TextToSpeechArgs = BaseArgs & TextToSpeechInput;

interface OutputUrlTextToSpeechGeneration {
@@ -22,10 +22,12 @@ export async function textToSpeech(args: TextToSpeechArgs, options?: Options): P
text: args.inputs,
}
: args;
-  const res = await request<Blob | OutputUrlTextToSpeechGeneration>(payload, {
-    ...options,
-    task: "text-to-speech",
-  });
+  const res = (
+    await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(payload, {
+      ...options,
+      task: "text-to-speech",
+    })
+  ).data;
if (res instanceof Blob) {
return res;
}
36 changes: 4 additions & 32 deletions packages/inference/src/tasks/custom/request.ts
@@ -1,8 +1,9 @@
import type { InferenceTask, Options, RequestArgs } from "../../types";
-import { makeRequestOptions } from "../../lib/makeRequestOptions";
+import { innerRequest } from "../../utils/request";

/**
* Primitive to make custom calls to the inference provider
+ * @deprecated Use specific task functions instead. This function will be removed in a future version.
*/
export async function request<T>(
args: RequestArgs,
@@ -13,35 +14,6 @@ export async function request<T>(
chatCompletion?: boolean;
}
): Promise<T> {
-  const { url, info } = await makeRequestOptions(args, options);
-  const response = await (options?.fetch ?? fetch)(url, info);
-
-  if (options?.retry_on_error !== false && response.status === 503) {
-    return request(args, options);
-  }
-
-  if (!response.ok) {
-    const contentType = response.headers.get("Content-Type");
-    if (["application/json", "application/problem+json"].some((ct) => contentType?.startsWith(ct))) {
-      const output = await response.json();
-      if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
-        throw new Error(
-          `Server ${args.model} does not seem to support chat completion. Error: ${JSON.stringify(output.error)}`
-        );
-      }
-      if (output.error || output.detail) {
-        throw new Error(JSON.stringify(output.error ?? output.detail));
-      } else {
-        throw new Error(output);
-      }
-    }
-    const message = contentType?.startsWith("text/plain;") ? await response.text() : undefined;
-    throw new Error(message ?? "An error occurred while fetching the blob");
-  }
-
-  if (response.headers.get("Content-Type")?.startsWith("application/json")) {
-    return await response.json();
-  }
-
-  return (await response.blob()) as T;
+  const result = await innerRequest<T>(args, options);
+  return result.data;
}
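
For downstream code still calling the deprecated primitive directly, the migration mirrors the task files in this PR. A hypothetical call site (the output type and surrounding function are placeholders; only the `innerRequest` usage is taken from the diff):

```typescript
import type { BaseArgs, Options } from "../../types";
import { innerRequest } from "../../utils/request";

type MyOutput = Array<{ label: string; score: number }>; // placeholder type

export async function myTask(args: BaseArgs & { inputs: string }, options?: Options): Promise<MyOutput> {
  // before: const res = await request<MyOutput>(args, { ...options, task: "audio-classification" });
  const res = (
    await innerRequest<MyOutput>(args, {
      ...options,
      task: "audio-classification",
    })
  ).data;
  return res;
}
```
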
89 changes: 3 additions & 86 deletions packages/inference/src/tasks/custom/streamingRequest.ts
@@ -1,10 +1,8 @@
import type { InferenceTask, Options, RequestArgs } from "../../types";
-import { makeRequestOptions } from "../../lib/makeRequestOptions";
-import type { EventSourceMessage } from "../../vendor/fetch-event-source/parse";
-import { getLines, getMessages } from "../../vendor/fetch-event-source/parse";
-
+import { innerStreamingRequest } from "../../utils/request";
/**
* Primitive to make custom inference calls that expect server-sent events, and returns the response through a generator
+ * @deprecated Use specific task functions instead. This function will be removed in a future version.
*/
export async function* streamingRequest<T>(
args: RequestArgs,
@@ -15,86 +13,5 @@ export async function* streamingRequest<T>(
chatCompletion?: boolean;
}
): AsyncGenerator<T> {
-  const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
-  const response = await (options?.fetch ?? fetch)(url, info);
-
-  if (options?.retry_on_error !== false && response.status === 503) {
-    return yield* streamingRequest(args, options);
-  }
-  if (!response.ok) {
-    if (response.headers.get("Content-Type")?.startsWith("application/json")) {
-      const output = await response.json();
-      if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
-        throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
-      }
-      if (typeof output.error === "string") {
-        throw new Error(output.error);
-      }
-      if (output.error && "message" in output.error && typeof output.error.message === "string") {
-        /// OpenAI errors
-        throw new Error(output.error.message);
-      }
-    }
-
-    throw new Error(`Server response contains error: ${response.status}`);
-  }
-  if (!response.headers.get("content-type")?.startsWith("text/event-stream")) {
-    throw new Error(
-      `Server does not support event stream content type, it returned ` + response.headers.get("content-type")
-    );
-  }
-
-  if (!response.body) {
-    return;
-  }
-
-  const reader = response.body.getReader();
-  let events: EventSourceMessage[] = [];
-
-  const onEvent = (event: EventSourceMessage) => {
-    // accumulate events in array
-    events.push(event);
-  };
-
-  const onChunk = getLines(
-    getMessages(
-      () => {},
-      () => {},
-      onEvent
-    )
-  );
-
-  try {
-    while (true) {
-      const { done, value } = await reader.read();
-      if (done) {
-        return;
-      }
-      onChunk(value);
-      for (const event of events) {
-        if (event.data.length > 0) {
-          if (event.data === "[DONE]") {
-            return;
-          }
-          const data = JSON.parse(event.data);
-          if (typeof data === "object" && data !== null && "error" in data) {
-            const errorStr =
-              typeof data.error === "string"
-                ? data.error
-                : typeof data.error === "object" &&
-                    data.error &&
-                    "message" in data.error &&
-                    typeof data.error.message === "string"
-                  ? data.error.message
-                  : JSON.stringify(data.error);
-            throw new Error(`Error forwarded from backend: ` + errorStr);
-          }
-          yield data as T;
-        }
-      }
-      events = [];
-    }
-  } finally {
-    reader.releaseLock();
-  }
+  yield* innerStreamingRequest(args, options);
}
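
The streaming primitive collapses to a single `yield*` delegation. As a self-contained illustration of why that preserves behavior (illustrative names, not the package's API): `yield*` forwards every yielded value, and an early `return` from the inner generator, such as on the `[DONE]` sentinel the old loop checked for, still ends the outer iteration.

```typescript
// Inner generator: stands in for innerStreamingRequest's event loop.
async function* innerSketch(events: string[]): AsyncGenerator<string> {
  for (const event of events) {
    if (event === "[DONE]") return; // SSE termination sentinel, as in the old loop
    yield event;
  }
}

// Deprecated wrapper: the whole body is one delegation, as in this PR.
async function* deprecatedSketch(events: string[]): AsyncGenerator<string> {
  yield* innerSketch(events);
}

// for await (const chunk of deprecatedSketch(["a", "b", "[DONE]", "c"])) {
//   console.log(chunk); // logs "a" then "b"; stops at the sentinel
// }
```
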
12 changes: 7 additions & 5 deletions packages/inference/src/tasks/cv/imageClassification.ts
@@ -1,7 +1,7 @@
import type { ImageClassificationInput, ImageClassificationOutput } from "@huggingface/tasks";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options } from "../../types";
-import { request } from "../custom/request";
+import { innerRequest } from "../../utils/request";
import { preparePayload, type LegacyImageInput } from "./utils";

export type ImageClassificationArgs = BaseArgs & (ImageClassificationInput | LegacyImageInput);
@@ -15,10 +15,12 @@ export async function imageClassification(
options?: Options
): Promise<ImageClassificationOutput> {
const payload = preparePayload(args);
-  const res = await request<ImageClassificationOutput>(payload, {
-    ...options,
-    task: "image-classification",
-  });
+  const res = (
+    await innerRequest<ImageClassificationOutput>(payload, {
+      ...options,
+      task: "image-classification",
+    })
+  ).data;
const isValidOutput =
Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
if (!isValidOutput) {
12 changes: 7 additions & 5 deletions packages/inference/src/tasks/cv/imageSegmentation.ts
@@ -1,7 +1,7 @@
import type { ImageSegmentationInput, ImageSegmentationOutput } from "@huggingface/tasks";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options } from "../../types";
-import { request } from "../custom/request";
+import { innerRequest } from "../../utils/request";
import { preparePayload, type LegacyImageInput } from "./utils";

export type ImageSegmentationArgs = BaseArgs & (ImageSegmentationInput | LegacyImageInput);
@@ -15,10 +15,12 @@ export async function imageSegmentation(
options?: Options
): Promise<ImageSegmentationOutput> {
const payload = preparePayload(args);
-  const res = await request<ImageSegmentationOutput>(payload, {
-    ...options,
-    task: "image-segmentation",
-  });
+  const res = (
+    await innerRequest<ImageSegmentationOutput>(payload, {
+      ...options,
+      task: "image-segmentation",
+    })
+  ).data;
const isValidOutput =
Array.isArray(res) &&
res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
12 changes: 7 additions & 5 deletions packages/inference/src/tasks/cv/imageToImage.ts
@@ -2,7 +2,7 @@ import type { ImageToImageInput } from "@huggingface/tasks";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options, RequestArgs } from "../../types";
import { base64FromBytes } from "../../utils/base64FromBytes";
-import { request } from "../custom/request";
+import { innerRequest } from "../../utils/request";

export type ImageToImageArgs = BaseArgs & ImageToImageInput;

@@ -26,10 +26,12 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P
),
};
}
-  const res = await request<Blob>(reqArgs, {
-    ...options,
-    task: "image-to-image",
-  });
+  const res = (
+    await innerRequest<Blob>(reqArgs, {
+      ...options,
+      task: "image-to-image",
+    })
+  ).data;
const isValidOutput = res && res instanceof Blob;
if (!isValidOutput) {
throw new InferenceOutputError("Expected Blob");
6 changes: 3 additions & 3 deletions packages/inference/src/tasks/cv/imageToText.ts
@@ -1,7 +1,7 @@
import type { ImageToTextInput, ImageToTextOutput } from "@huggingface/tasks";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options } from "../../types";
-import { request } from "../custom/request";
+import { innerRequest } from "../../utils/request";
import type { LegacyImageInput } from "./utils";
import { preparePayload } from "./utils";

@@ -12,11 +12,11 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
const payload = preparePayload(args);
const res = (
-    await request<[ImageToTextOutput]>(payload, {
+    await innerRequest<[ImageToTextOutput]>(payload, {
...options,
task: "image-to-text",
})
-  )?.[0];
+  ).data?.[0];

if (typeof res?.generated_text !== "string") {
throw new InferenceOutputError("Expected {generated_text: string}");