
[Inference Providers] Refactor: better typing? #1332

Closed
wants to merge 1 commit into from
77 changes: 65 additions & 12 deletions packages/inference/src/lib/getProviderHelper.ts
@@ -22,14 +22,62 @@ import {
import { NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask } from "../providers/nebius";
import { NovitaConversationalTask, NovitaTextGenerationTask } from "../providers/novita";
import { OpenAIConversationalTask } from "../providers/openai";
import type { TaskProviderHelper } from "../providers/providerHelper";
import type { TaskProviderHelper, TextGenerationTaskHelper, TextToImageTaskHelper } from "../providers/providerHelper";
import { ReplicateTextToImageTask, ReplicateTextToSpeechTask, ReplicateTextToVideoTask } from "../providers/replicate";
import { SambanovaConversationalTask } from "../providers/sambanova";
import { TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask } from "../providers/together";
import type { InferenceProvider, InferenceTask } from "../types";
import { typedIn } from "../utils/typedIn";
import { typedInclude } from "../utils/typedInclude";
import { typedKeys } from "../utils/typedKeys";

export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
"black-forest-labs": {
export const HELPERS = {
"text-to-image": {
"black-forest-labs": new BlackForestLabsTextToImageTask(),
"fal-ai": new FalAITextToImageTask(),
"hf-inference": new HFInferenceTextToImageTask(),
hyperbolic: new HyperbolicTextToImageTask(),
nebius: new NebiusTextToImageTask(),
together: new TogetherTextToImageTask(),
} satisfies Partial<Record<InferenceProvider, TextToImageTaskHelper>>,
// "text-to-speech": {
// "fal-ai": new FalAITextToSpeechTask(),
// "hf-inference": new HFInferenceTask("text-to-speech"),
// replicate: new ReplicateTextToSpeechTask(),
// } satisfies Partial<Record<InferenceProvider, TextToSpeechTaskHelper>>,
// "text-to-video": {
// "fal-ai": new FalAITextToVideoTask(),
// replicate: new ReplicateTextToVideoTask(),
// } satisfies Partial<Record<InferenceProvider, TextToVideoTaskHelper>>,
// "automatic-speech-recognition": {
// "fal-ai": new FalAIAutomaticSpeechRecognitionTask(),
// "hf-inference": new HFInferenceTask("automatic-speech-recognition"),
// } satisfies Partial<Record<InferenceProvider, AutomaticSpeechRecognitionTaskHelper>>,
// "text-generation": {
// "hf-inference": new HFInferenceTextGenerationTask(),
// "hyperbolic": new HyperbolicTextGenerationTask(),
// nebius: new NebiusTextGenerationTask(),
// "novita": new NovitaTextGenerationTask(),
// "together": new TogetherTextGenerationTask(),
// } satisfies Partial<Record<InferenceProvider, TextGenerationTaskHelper>>,
// "conversational": {
// cerebras: new CerebrasConversationalTask(),
// cohere: new CohereConversationalTask(),
// "fireworks-ai": new FireworksConversationalTask(),
// "hf-inference": new HFInferenceConversationalTask(),
// hyperbolic: new HyperbolicConversationalTask(),
// nebius: new NebiusConversationalTask(),
// novita: new NovitaConversationalTask(),
// openai: new OpenAIConversationalTask(),
// replicate: new ReplicateTextToImageTask(),
// sambanova: new SambanovaConversationalTask(),
// together: new TogetherConversationalTask(),
// } satisfies Partial<Record<InferenceProvider, ConversationalTaskHelper>>,
} satisfies Partial<Record<InferenceTask, Partial<Record<InferenceProvider, TaskProviderHelper>>>>;

export const SUPPORTED_TASKS = typedKeys(HELPERS);
/**
"black-forest-labs": {
"text-to-image": new BlackForestLabsTextToImageTask(),
},
cerebras: {
@@ -106,11 +154,13 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
"text-generation": new TogetherTextGenerationTask(),
conversational: new TogetherConversationalTask(),
},
};
}
*/

/**
* Get provider helper instance by name and task
*/
export function getProviderHelper(provider: InferenceProvider, task: "text-to-image"): TextToImageTaskHelper & TaskProviderHelper;
export function getProviderHelper(provider: InferenceProvider, task: InferenceTask | undefined): TaskProviderHelper {
// special case for hf-inference, where the task is optional
if (provider === "hf-inference") {
@@ -121,14 +171,17 @@ export function getProviderHelper(provider: InferenceProvider, task: InferenceTa
if (!task) {
throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
}
if (!(provider in PROVIDERS)) {
throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
if (!typedInclude(SUPPORTED_TASKS, task)) {
throw new Error(`Task '${task}' not supported. Available tasks: ${Object.keys(HELPERS)}`);
}
const providerTasks = PROVIDERS[provider];
if (!providerTasks || !(task in providerTasks)) {
throw new Error(
`Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(providerTasks ?? {})}`
);
switch (task) {
case "text-to-image": {
if (!typedIn(HELPERS["text-to-image"], provider)) {
throw new Error(
`Provider '${provider}' not supported for task '${task}'. Available providers: ${Object.keys(HELPERS["text-to-image"] ?? {})}`
);
}
return HELPERS["text-to-image"][provider];
}
}
return providerTasks[task] as TaskProviderHelper;
}
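
For context on what the narrowed overload buys callers, here is a minimal usage sketch (not part of the diff; the import path and argument names are illustrative, mirroring the textToImage.ts change further down). With the "text-to-image" overload, the returned helper is already typed as TextToImageTaskHelper, so the call site no longer needs the @ts-expect-error escape hatch.

```ts
// Sketch only — assumes the "text-to-image" overload declared above; import path is illustrative.
import { getProviderHelper } from "../lib/getProviderHelper";

async function generateImage(res: unknown, url: string, headers: HeadersInit): Promise<string | Blob> {
	// The returned type is TextToImageTaskHelper & TaskProviderHelper, so getResponse is
	// known to accept outputType and to resolve to Promise<string | Blob>.
	const helper = getProviderHelper("together", "text-to-image");
	return helper.getResponse(res, url, headers, "blob");
}
```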
4 changes: 2 additions & 2 deletions packages/inference/src/providers/black-forest-labs.ts
@@ -18,15 +18,15 @@ import { InferenceOutputError } from "../lib/InferenceOutputError";
import type { BodyParams, HeaderParams, UrlParams } from "../types";
import { delay } from "../utils/delay";
import { omit } from "../utils/omit";
import { TaskProviderHelper } from "./providerHelper";
import { TaskProviderHelper, type TextToImageTaskHelper } from "./providerHelper";

const BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
interface BlackForestLabsResponse {
id: string;
polling_url: string;
}

export class BlackForestLabsTextToImageTask extends TaskProviderHelper {
export class BlackForestLabsTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
constructor() {
super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL, "text-to-image");
}
5 changes: 3 additions & 2 deletions packages/inference/src/providers/hf-inference.ts
@@ -14,6 +14,7 @@ import { HF_ROUTER_URL } from "../config";
import { InferenceOutputError } from "../lib/InferenceOutputError";
import type { BodyParams, InferenceTask, UrlParams } from "../types";
import { toArray } from "../utils/toArray";
import type { TextToImageTaskHelper } from "./providerHelper";
import { TaskProviderHelper } from "./providerHelper";

interface Base64ImageGeneration {
@@ -53,7 +54,7 @@ export class HFInferenceTask extends TaskProviderHelper {
}
}

export class HFInferenceTextToImageTask extends HFInferenceTask {
export class HFInferenceTextToImageTask extends HFInferenceTask implements TextToImageTaskHelper {
constructor() {
super("text-to-image");
}
@@ -63,7 +64,7 @@ export class HFInferenceTextToImageTask extends HFInferenceTask {
url?: string,
headers?: Record<string, string>,
outputType?: "url" | "blob"
): Promise<unknown> {
): Promise<string | Blob> {
if (!response) {
throw new InferenceOutputError("response is undefined");
}
15 changes: 8 additions & 7 deletions packages/inference/src/providers/hyperbolic.ts
@@ -18,6 +18,7 @@ import type { ChatCompletionOutput, TextGenerationOutput } from "@huggingface/ta
import { InferenceOutputError } from "../lib/InferenceOutputError";
import type { BodyParams, UrlParams } from "../types";
import { omit } from "../utils/omit";
import type { TextGenerationTaskHelper, TextToImageTaskHelper } from "./providerHelper";
import { BaseConversationalTask, BaseTextGenerationTask, TaskProviderHelper } from "./providerHelper";

const HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
@@ -38,7 +39,7 @@ export class HyperbolicConversationalTask extends BaseConversationalTask {
}
}

export class HyperbolicTextGenerationTask extends BaseTextGenerationTask {
export class HyperbolicTextGenerationTask extends BaseTextGenerationTask implements TextGenerationTaskHelper {
constructor() {
super("hyperbolic", HYPERBOLIC_API_BASE_URL);
}
@@ -53,16 +54,16 @@ export class HyperbolicTextGenerationTask extends BaseTextGenerationTask {
messages: [{ content: params.args.inputs, role: "user" }],
...(params.args.parameters
? {
max_tokens: (params.args.parameters as Record<string, unknown>).max_new_tokens,
...omit(params.args.parameters as Record<string, unknown>, "max_new_tokens"),
}
max_tokens: (params.args.parameters as Record<string, unknown>).max_new_tokens,
...omit(params.args.parameters as Record<string, unknown>, "max_new_tokens"),
}
: undefined),
...omit(params.args, ["inputs", "parameters"]),
model: params.model,
};
}

override getResponse(response: HyperbolicTextCompletionOutput): TextGenerationOutput {
override async getResponse(response: HyperbolicTextCompletionOutput): Promise<TextGenerationOutput> {
if (
typeof response === "object" &&
"choices" in response &&
@@ -79,7 +80,7 @@ export class HyperbolicTextGenerationTask extends BaseTextGenerationTask {
}
}

export class HyperbolicTextToImageTask extends TaskProviderHelper {
export class HyperbolicTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
constructor() {
super("hyperbolic", HYPERBOLIC_API_BASE_URL, "text-to-image");
}
@@ -98,7 +99,7 @@ export class HyperbolicTextToImageTask extends TaskProviderHelper {
};
}

getResponse(response: HyperbolicTextToImageOutput, outputType?: "url" | "blob"): Promise<Blob> | string {
async getResponse(response: HyperbolicTextToImageOutput, outputType?: "url" | "blob"): Promise<Blob | string> {
if (
typeof response === "object" &&
"images" in response &&
5 changes: 3 additions & 2 deletions packages/inference/src/providers/nebius.ts
@@ -17,6 +17,7 @@
import { InferenceOutputError } from "../lib/InferenceOutputError";
import type { BodyParams, UrlParams } from "../types";
import { omit } from "../utils/omit";
import type { TextToImageTaskHelper } from "./providerHelper";
import { BaseConversationalTask, BaseTextGenerationTask, TaskProviderHelper } from "./providerHelper";

const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
@@ -39,7 +40,7 @@ export class NebiusTextGenerationTask extends BaseTextGenerationTask {
}
}

export class NebiusTextToImageTask extends TaskProviderHelper {
export class NebiusTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
constructor() {
super("nebius", NEBIUS_API_BASE_URL, "text-to-image");
}
@@ -59,7 +60,7 @@ export class NebiusTextToImageTask extends TaskProviderHelper {
return "v1/images/generations";
}

getResponse(response: NebiusBase64ImageGeneration, outputType?: "url" | "blob"): string | Promise<Blob> {
async getResponse(response: NebiusBase64ImageGeneration, outputType?: "url" | "blob"): Promise<string | Blob> {
if (
typeof response === "object" &&
"data" in response &&
50 changes: 40 additions & 10 deletions packages/inference/src/providers/providerHelper.ts
@@ -1,7 +1,8 @@
import type { ChatCompletionOutput, TextGenerationOutput } from "@huggingface/tasks";
import type { ChatCompletionOutput, TextGenerationInput, TextGenerationOutput, TextToImageInput } from "@huggingface/tasks";
import { InferenceOutputError } from "../lib/InferenceOutputError";
import type { BodyParams, HeaderParams, UrlParams } from "../types";
import type { BaseArgs, BodyParams, HeaderParams, UrlParams } from "../types";
import { toArray } from "../utils/toArray";
import type { ChatCompletionInput } from "@huggingface/tasks/src/tasks";
/**
* Base class for task-specific provider helpers
*/
@@ -11,7 +12,7 @@ export abstract class TaskProviderHelper {
private baseUrl: string,
private task?: string,
readonly clientSideRoutingOnly: boolean = false
) {}
) { }

/**
* Return the response in the expected format.
@@ -20,9 +21,9 @@
abstract getResponse(
response: unknown,
url?: string,
headers?: Record<string, string>,
headers?: HeadersInit,
outputType?: "url" | "blob"
): unknown;
): Promise<unknown>;

/**
* Prepare the base URL for the request
@@ -73,7 +74,36 @@
abstract preparePayload(params: BodyParams): unknown;
}

export class BaseConversationalTask extends TaskProviderHelper {
export interface TextToImageTaskHelper {
getResponse(response: unknown,
url?: string,
headers?: HeadersInit,
outputType?: "url" | "blob"
): Promise<string | Blob>;

preparePayload(params: BodyParams<TextToImageInput & BaseArgs>): Record<string, unknown>;
}

Review comment (Contributor): when making response unknown here, is it still possible to declare the getResponse method in BlackForestLabsTextToImageTask with response: BlackForestLabsResponse? Same for the other downstream classes that will implement TextToImageTaskHelper.

Reply (Contributor): got my answer: it should be good 😄 I was just overthinking it.
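
A minimal sketch of why that narrowing type-checks (simplified, standalone type names; not the actual package code): TypeScript compares parameters of methods declared with method syntax bivariantly, so a class implementing the interface may accept a narrower response type than unknown.

```ts
// Sketch only — simplified stand-ins for the real interfaces, to show the bivariance rule.
interface TextToImageTaskHelperSketch {
	getResponse(response: unknown, url?: string, outputType?: "url" | "blob"): Promise<string | Blob>;
}

interface BlackForestLabsResponseSketch {
	id: string;
	polling_url: string;
}

class BlackForestLabsTextToImageSketch implements TextToImageTaskHelperSketch {
	// Narrower parameter type than the interface's `unknown`; this compiles because
	// method parameters are compared bivariantly (strictFunctionTypes does not apply to methods).
	async getResponse(response: BlackForestLabsResponseSketch, url?: string): Promise<string | Blob> {
		return url ?? response.polling_url;
	}
}
```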

export interface TextGenerationTaskHelper {
getResponse(response: unknown,
url?: string,
headers?: HeadersInit,
): Promise<TextGenerationOutput>;

preparePayload(params: BodyParams<TextGenerationInput & BaseArgs>): Record<string, unknown>;

}

export interface ConversationalTaskHelper {
getResponse(response: unknown,
url?: string,
headers?: HeadersInit,
): Promise<ChatCompletionOutput>;

preparePayload(params: BodyParams<ChatCompletionInput & BaseArgs>): Record<string, unknown>;
}

export class BaseConversationalTask extends TaskProviderHelper implements ConversationalTaskHelper {
constructor(provider: string, baseUrl: string, clientSideRoutingOnly: boolean = false) {
super(provider, baseUrl, "conversational", clientSideRoutingOnly);
}
@@ -83,14 +113,14 @@ export class BaseConversationalTask extends TaskProviderHelper {
return "v1/chat/completions";
}

preparePayload(params: BodyParams): Record<string, unknown> {
preparePayload(params: BodyParams<ChatCompletionInput & BaseArgs>): Record<string, unknown> {
return {
...params.args,
model: params.model,
};
}

getResponse(response: ChatCompletionOutput): ChatCompletionOutput {
async getResponse(response: ChatCompletionOutput): Promise<ChatCompletionOutput> {
if (
typeof response === "object" &&
Array.isArray(response?.choices) &&
@@ -110,7 +140,7 @@ export class BaseConversationalTask extends TaskProviderHelper {
}
}

export class BaseTextGenerationTask extends TaskProviderHelper {
export class BaseTextGenerationTask extends TaskProviderHelper implements TextGenerationTaskHelper {
constructor(provider: string, baseUrl: string, clientSideRoutingOnly: boolean = false) {
super(provider, baseUrl, "text-generation", clientSideRoutingOnly);
}
@@ -125,7 +155,7 @@ export class BaseTextGenerationTask extends TaskProviderHelper {
return "v1/completions";
}

getResponse(response: unknown): TextGenerationOutput {
async getResponse(response: unknown): Promise<TextGenerationOutput> {
const res = toArray(response);
// @ts-expect-error - We need to check properties on unknown type
if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
7 changes: 4 additions & 3 deletions packages/inference/src/providers/together.ts
@@ -18,6 +18,7 @@ import type { ChatCompletionOutput, TextGenerationOutput, TextGenerationOutputFi
import { InferenceOutputError } from "../lib/InferenceOutputError";
import type { BodyParams, UrlParams } from "../types";
import { omit } from "../utils/omit";
import type { TextToImageTaskHelper } from "./providerHelper";
import { BaseConversationalTask, BaseTextGenerationTask, TaskProviderHelper } from "./providerHelper";

const TOGETHER_API_BASE_URL = "https://api.together.xyz";
@@ -57,7 +58,7 @@ export class TogetherTextGenerationTask extends BaseTextGenerationTask {
};
}

override getResponse(response: TogeteherTextCompletionOutput): TextGenerationOutput {
override async getResponse(response: TogeteherTextCompletionOutput): Promise<TextGenerationOutput> {
if (
typeof response === "object" &&
"choices" in response &&
@@ -73,7 +74,7 @@
}
}

export class TogetherTextToImageTask extends TaskProviderHelper {
export class TogetherTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
constructor() {
super("together", TOGETHER_API_BASE_URL, "text-to-image");
}
@@ -93,7 +94,7 @@ export class TogetherTextToImageTask extends TaskProviderHelper {
};
}

getResponse(response: TogetherBase64ImageGeneration, outputType?: "url" | "blob"): string | Promise<Blob> {
async getResponse(response: TogetherBase64ImageGeneration, outputType?: "url" | "blob"): Promise<string | Blob> {
if (
typeof response === "object" &&
"data" in response &&
5 changes: 2 additions & 3 deletions packages/inference/src/tasks/cv/textToImage.ts
@@ -25,11 +25,10 @@ export async function textToImage(
export async function textToImage(args: TextToImageArgs, options?: TextToImageOptions): Promise<Blob | string> {
const provider = args.provider ?? "hf-inference";
const providerHelper = getProviderHelper(provider, "text-to-image");
const res = await request<Record<string, unknown>>(args, {
const res = await request(args, {
...options,
task: "text-to-image",
});
const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-image" });
// @ts-expect-error - Provider-specific implementations accept the outputType parameter
return providerHelper.getResponse(res, url, info.headers as Record<string, string>, options?.outputType);
return await providerHelper.getResponse(res, url, info.headers, options?.outputType);
}
4 changes: 2 additions & 2 deletions packages/inference/src/types.ts
@@ -115,8 +115,8 @@
task?: InferenceTask;
}

export interface BodyParams {
args: Record<string, unknown>;
export interface BodyParams<T extends Record<string, unknown> = Record<string, unknown>> {
args: T;
model: string;
task?: InferenceTask;
}
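
A short sketch of what the new generic parameter enables (the args type is an illustrative stand-in, not from the package): provider helpers can declare the exact argument shape they expect and read fields off params.args without casting from Record<string, unknown>.

```ts
// Sketch only — BodyParams as changed above; TextToImageArgsSketch is illustrative.
interface BodyParams<T extends Record<string, unknown> = Record<string, unknown>> {
	args: T;
	model: string;
}

type TextToImageArgsSketch = { inputs: string; parameters?: { num_inference_steps?: number } } & Record<string, unknown>;

function preparePayload(params: BodyParams<TextToImageArgsSketch>): Record<string, unknown> {
	// params.args.inputs is typed as string here; with the non-generic BodyParams it
	// would be unknown and require a cast.
	return { prompt: params.args.inputs, model: params.model, ...params.args.parameters };
}
```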