
(WIP) Dynamic inference provider mapping #1173


Merged · 25 commits · Feb 6, 2025
1 change: 1 addition & 0 deletions packages/hub/src/lib/list-models.ts
@@ -25,6 +25,7 @@ export const MODEL_EXPANDABLE_KEYS = [
"downloadsAllTime",
"gated",
"gitalyUid",
"inferenceProviderMapping",
"lastModified",
"library_name",
"likes",
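For context, a minimal sketch of consuming the new expandable key through the hub package's listModels, assuming its existing additionalFields option (the owner filter and logging are illustrative, not part of this diff):

import { listModels } from "@huggingface/hub";

// Ask the Hub to also return the provider mapping for each model.
for await (const model of listModels({
	search: { owner: "meta-llama" },
	additionalFields: ["inferenceProviderMapping"],
})) {
	console.log(model.name, model.inferenceProviderMapping);
}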
5 changes: 4 additions & 1 deletion packages/hub/src/types/api/api-model.ts
@@ -1,4 +1,4 @@
import type { ModelLibraryKey, TransformersInfo } from "@huggingface/tasks";
import type { ModelLibraryKey, TransformersInfo, WidgetType } from "@huggingface/tasks";
import type { License, PipelineType } from "../public";

export interface ApiModelInfo {
@@ -18,6 +18,9 @@ export interface ApiModelInfo {
downloadsAllTime: number;
files: string[];
gitalyUid: string;
inferenceProviderMapping: Partial<
Record<string, { providerId: string; status: "live" | "staging"; task: WidgetType }>
>;
lastAuthor: { email: string; user?: string };
lastModified: string; // convert to date
library_name?: ModelLibraryKey;
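For illustration, a value satisfying the new field could look like this; the entries mirror two of the static mappings removed further down in this PR, and the status values are made up:

const exampleMapping: ApiModelInfo["inferenceProviderMapping"] = {
	"fal-ai": { providerId: "fal-ai/flux/dev", status: "live", task: "text-to-image" },
	replicate: { providerId: "black-forest-labs/flux-dev", status: "live", task: "text-to-image" },
};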
5 changes: 0 additions & 5 deletions packages/inference/src/index.ts
@@ -1,9 +1,4 @@
export type { ProviderMapping } from "./providers/types";
export { HfInference, HfInferenceEndpoint } from "./HfInference";
export { InferenceOutputError } from "./lib/InferenceOutputError";
export { FAL_AI_SUPPORTED_MODEL_IDS } from "./providers/fal-ai";
export { REPLICATE_SUPPORTED_MODEL_IDS } from "./providers/replicate";
export { SAMBANOVA_SUPPORTED_MODEL_IDS } from "./providers/sambanova";
export { TOGETHER_SUPPORTED_MODEL_IDS } from "./providers/together";
export * from "./types";
export * from "./tasks";
74 changes: 74 additions & 0 deletions packages/inference/src/lib/getProviderModelId.ts
@@ -0,0 +1,74 @@
import type { WidgetType } from "@huggingface/tasks";
import type { InferenceProvider, InferenceTask, ModelId, Options, RequestArgs } from "../types";
import { HF_HUB_URL } from "../config";
import { HARDCODED_MODEL_ID_MAPPING } from "../providers/consts";

type InferenceProviderMapping = Partial<
Record<InferenceProvider, { providerId: string; status: "live" | "staging"; task: WidgetType }>
>;
const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();

export async function getProviderModelId(
params: {
model: string;
provider: InferenceProvider;
},
args: RequestArgs,
options: {
taskHint?: InferenceTask;
chatCompletion?: boolean;
fetch?: Options["fetch"];
} = {}
): Promise<string> {
if (params.provider === "hf-inference") {
return params.model;
}
if (!options.taskHint) {
throw new Error("taskHint must be specified when using a third-party provider");
}
const task: WidgetType =
options.taskHint === "text-generation" && options.chatCompletion ? "conversational" : options.taskHint;

// A dict called HARDCODED_MODEL_ID_MAPPING takes precedence in all cases (useful for dev purposes)
if (HARDCODED_MODEL_ID_MAPPING[params.model]) {
return HARDCODED_MODEL_ID_MAPPING[params.model];
}

let inferenceProviderMapping: InferenceProviderMapping | null;
if (inferenceProviderMappingCache.has(params.model)) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
inferenceProviderMapping = inferenceProviderMappingCache.get(params.model)!;
} else {
inferenceProviderMapping = await (options?.fetch ?? fetch)(
`${HF_HUB_URL}/api/models/${params.model}?expand[]=inferenceProviderMapping`,
{
headers: args.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${args.accessToken}` } : {},
}
)
.then((resp) => resp.json())
.then((json) => json.inferenceProviderMapping)
.catch(() => null);
}

if (!inferenceProviderMapping) {
throw new Error(`We have not been able to find inference provider information for model ${params.model}.`);
}

const providerMapping = inferenceProviderMapping[params.provider];
if (providerMapping) {
if (providerMapping.task !== task) {
throw new Error(
`Model ${params.model} is not supported for task ${task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
);
}
if (providerMapping.status === "staging") {
console.warn(
`Model ${params.model} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
);
}
// TODO: how is it handled server-side if model has multiple tasks (e.g. `text-generation` + `conversational`)?
Member (review comment, with a suggested change removing the TODO line above):
i think this is ok @Wauplin

return providerMapping.providerId;
}

throw new Error(`Model ${params.model} is not supported for provider ${params.provider}.`);
}
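To make the resolution concrete, the Hub call this helper performs can be reproduced standalone; a sketch, with the model id and the commented payload purely illustrative:

// Same endpoint as in getProviderModelId (HF_HUB_URL is https://huggingface.co).
const url = "https://huggingface.co/api/models/black-forest-labs/FLUX.1-dev?expand[]=inferenceProviderMapping";
const json = await fetch(url).then((resp) => resp.json());
// Expected shape per the InferenceProviderMapping type above, e.g.:
// { "fal-ai": { providerId: "fal-ai/flux/dev", status: "live", task: "text-to-image" } }
console.log(json.inferenceProviderMapping);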
63 changes: 14 additions & 49 deletions packages/inference/src/lib/makeRequestOptions.ts
@@ -1,13 +1,13 @@
import type { WidgetType } from "@huggingface/tasks";
import { HF_HUB_URL } from "../config";
import { FAL_AI_API_BASE_URL, FAL_AI_SUPPORTED_MODEL_IDS } from "../providers/fal-ai";
import { REPLICATE_API_BASE_URL, REPLICATE_SUPPORTED_MODEL_IDS } from "../providers/replicate";
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_SUPPORTED_MODEL_IDS } from "../providers/sambanova";
import { TOGETHER_API_BASE_URL, TOGETHER_SUPPORTED_MODEL_IDS } from "../providers/together";
import { FAL_AI_API_BASE_URL } from "../providers/fal-ai";
import { REPLICATE_API_BASE_URL } from "../providers/replicate";
import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
import { TOGETHER_API_BASE_URL } from "../providers/together";
import type { InferenceProvider } from "../types";
import type { InferenceTask, Options, RequestArgs } from "../types";
import { isUrl } from "./isUrl";
import { version as packageVersion, name as packageName } from "../../package.json";
import { getProviderModelId } from "./getProviderModelId";

const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_HUB_URL}/api/inference-proxy/{{PROVIDER}}`;

@@ -49,18 +49,16 @@ export async function makeRequestOptions(
if (maybeModel && isUrl(maybeModel)) {
throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
}

let model: string;
if (!maybeModel) {
if (taskHint) {
model = mapModel({ model: await loadDefaultModel(taskHint), provider, taskHint, chatCompletion });
} else {
throw new Error("No model provided, and no default model found for this task");
/// TODO : change error message ^
}
} else {
model = mapModel({ model: maybeModel, provider, taskHint, chatCompletion });
if (!maybeModel && !taskHint) {
throw new Error("No model provided, and no task has been specified.");
}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const hfModel = maybeModel ?? (await loadDefaultModel(taskHint!));
const model = await getProviderModelId({ model: hfModel, provider }, args, {
taskHint,
chatCompletion,
fetch: options?.fetch,
});

/// If accessToken is passed, it should take precedence over includeCredentials
const authMethod = accessToken
@@ -153,39 +151,6 @@
return { url, info };
}

function mapModel(params: {
model: string;
provider: InferenceProvider;
taskHint: InferenceTask | undefined;
chatCompletion: boolean | undefined;
}): string {
if (params.provider === "hf-inference") {
return params.model;
}
if (!params.taskHint) {
throw new Error("taskHint must be specified when using a third-party provider");
}
const task: WidgetType =
params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
const model = (() => {
switch (params.provider) {
case "fal-ai":
return FAL_AI_SUPPORTED_MODEL_IDS[task]?.[params.model];
case "replicate":
return REPLICATE_SUPPORTED_MODEL_IDS[task]?.[params.model];
case "sambanova":
return SAMBANOVA_SUPPORTED_MODEL_IDS[task]?.[params.model];
case "together":
return TOGETHER_SUPPORTED_MODEL_IDS[task]?.[params.model];
}
})();

if (!model) {
throw new Error(`Model ${params.model} is not supported for task ${task} and provider ${params.provider}`);
}
return model;
}

function makeUrl(params: {
authMethod: "none" | "hf-token" | "credentials-include" | "provider-key";
chatCompletion: boolean;
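From a caller's point of view nothing changes; the provider-side model id is simply resolved from the Hub instead of the static tables. A usage sketch, assuming the provider option exposed by HfInference (model, provider, and token are placeholders):

import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_***");
const out = await hf.chatCompletion({
	model: "meta-llama/Llama-3.3-70B-Instruct",
	provider: "sambanova",
	messages: [{ role: "user", content: "Hello!" }],
});
console.log(out.choices[0].message);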
15 changes: 15 additions & 0 deletions packages/inference/src/providers/consts.ts
@@ -0,0 +1,15 @@
import type { ModelId } from "../types";

type ProviderId = string;

/**
* If you want to try to run inference for a new model locally before it's registered on huggingface.co
* for a given Inference Provider,
* you can add it to the following dictionary, for dev purposes.
*/
export const HARDCODED_MODEL_ID_MAPPING: Record<ModelId, ProviderId> = {
/**
* "HF model ID" => "Model ID on Inference Provider's side"
*/
// "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
};
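During local development, registering a not-yet-mapped model is then a one-line change, e.g. by populating the commented sample entry:

export const HARDCODED_MODEL_ID_MAPPING: Record<ModelId, ProviderId> = {
	// Checked before the Hub lookup, so inference uses this provider id directly.
	"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
};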
46 changes: 16 additions & 30 deletions packages/inference/src/providers/fal-ai.ts
@@ -1,32 +1,18 @@
import type { ProviderMapping } from "./types";

export const FAL_AI_API_BASE_URL = "https://fal.run";

type FalAiId = string;

export const FAL_AI_SUPPORTED_MODEL_IDS: ProviderMapping<FalAiId> = {
Member (review comment):
same comment as in https://github.com/huggingface/huggingface_hub/pull/2836/files#r1942755169, i would for now keep the mappings (but make them empty) and still support them in code. That way we have slightly less breaking changes too, potentially

Contributor Author (reply):
@julien-c I reverted to keep both the hard-coded logic and the new dynamic mapping. Mapping from the Hub takes precedence over the hardcoded one.

"text-to-image": {
"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
"playgroundai/playground-v2.5-1024px-aesthetic": "fal-ai/playground-v25",
"ByteDance/SDXL-Lightning": "fal-ai/lightning-models",
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "fal-ai/pixart-sigma",
"stabilityai/stable-diffusion-3-medium": "fal-ai/stable-diffusion-v3-medium",
"Warlord-K/Sana-1024": "fal-ai/sana",
"fal/AuraFlow-v0.2": "fal-ai/aura-flow",
"stabilityai/stable-diffusion-xl-base-1.0": "fal-ai/fast-sdxl",
"stabilityai/stable-diffusion-3.5-large": "fal-ai/stable-diffusion-v35-large",
"stabilityai/stable-diffusion-3.5-large-turbo": "fal-ai/stable-diffusion-v35-large/turbo",
"stabilityai/stable-diffusion-3.5-medium": "fal-ai/stable-diffusion-v35-medium",
"Kwai-Kolors/Kolors": "fal-ai/kolors",
},
"automatic-speech-recognition": {
"openai/whisper-large-v3": "fal-ai/whisper",
},
"text-to-video": {
"genmo/mochi-1-preview": "fal-ai/mochi-v1",
"tencent/HunyuanVideo": "fal-ai/hunyuan-video",
"THUDM/CogVideoX-5b": "fal-ai/cogvideox-5b",
"Lightricks/LTX-Video": "fal-ai/ltx-video",
},
};
/**
* See the registered mapping of HF model ID => Fal model ID here:
*
* https://huggingface.co/api/partners/fal-ai/models
*
* This is a publicly available mapping.
*
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
*
* - If you work at Fal and want to update this mapping, please use the model mapping API we provide on huggingface.co
* - If you're a community member and want to add a new supported HF model to Fal, please open an issue on the present repo
* and we will tag Fal team members.
*
* Thanks!
*/
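Since the partner endpoint above is public, the live mapping can be inspected directly; a one-off sketch (the response shape is not guaranteed here):

const falMapping = await fetch("https://huggingface.co/api/partners/fal-ai/models").then((r) => r.json());
console.log(falMapping);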
44 changes: 16 additions & 28 deletions packages/inference/src/providers/replicate.ts
@@ -1,30 +1,18 @@
import type { ProviderMapping } from "./types";

export const REPLICATE_API_BASE_URL = "https://api.replicate.com";

type ReplicateId = string;

export const REPLICATE_SUPPORTED_MODEL_IDS: ProviderMapping<ReplicateId> = {
"text-to-image": {
"black-forest-labs/FLUX.1-dev": "black-forest-labs/flux-dev",
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
"ByteDance/Hyper-SD":
"bytedance/hyper-flux-16step:382cf8959fb0f0d665b26e7e80b8d6dc3faaef1510f14ce017e8c732bb3d1eb7",
"ByteDance/SDXL-Lightning":
"bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637",
"playgroundai/playground-v2.5-1024px-aesthetic":
"playgroundai/playground-v2.5-1024px-aesthetic:a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24",
"stabilityai/stable-diffusion-3.5-large-turbo": "stability-ai/stable-diffusion-3.5-large-turbo",
"stabilityai/stable-diffusion-3.5-large": "stability-ai/stable-diffusion-3.5-large",
"stabilityai/stable-diffusion-3.5-medium": "stability-ai/stable-diffusion-3.5-medium",
"stabilityai/stable-diffusion-xl-base-1.0":
"stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
},
"text-to-speech": {
"OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:3c645149db020c85d080e2f8cfe482a0e68189a922cde964fa9e80fb179191f3",
"hexgrad/Kokoro-82M": "jaaari/kokoro-82m:dfdf537ba482b029e0a761699e6f55e9162cfd159270bfe0e44857caa5f275a6",
},
"text-to-video": {
"genmo/mochi-1-preview": "genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460",
},
};
/**
* See the registered mapping of HF model ID => Replicate model ID here:
*
* https://huggingface.co/api/partners/replicate/models
*
* This is a publicly available mapping.
*
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
*
* - If you work at Replicate and want to update this mapping, please use the model mapping API we provide on huggingface.co
* - If you're a community member and want to add a new supported HF model to Replicate, please open an issue on the present repo
* and we will tag Replicate team members.
*
* Thanks!
*/
39 changes: 16 additions & 23 deletions packages/inference/src/providers/sambanova.ts
@@ -1,25 +1,18 @@
import type { ProviderMapping } from "./types";

export const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";

type SambanovaId = string;

export const SAMBANOVA_SUPPORTED_MODEL_IDS: ProviderMapping<SambanovaId> = {
/** Chat completion / conversational */
conversational: {
"allenai/Llama-3.1-Tulu-3-405B":"Llama-3.1-Tulu-3-405B",
"deepseek-ai/DeepSeek-R1-Distill-Llama-70B": "DeepSeek-R1-Distill-Llama-70B",
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
"Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
"meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
"meta-llama/Llama-3.2-1B-Instruct": "Meta-Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-3B-Instruct": "Meta-Llama-3.2-3B-Instruct",
"meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
"meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
"meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
"meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
"meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B",
},
};
/**
* See the registered mapping of HF model ID => Sambanova model ID here:
*
* https://huggingface.co/api/partners/sambanova/models
*
* This is a publicly available mapping.
*
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
*
* - If you work at Sambanova and want to update this mapping, please use the model mapping API we provide on huggingface.co
* - If you're a community member and want to add a new supported HF model to Sambanova, please open an issue on the present repo
* and we will tag Sambanova team members.
*
* Thanks!
*/