-
Notifications
You must be signed in to change notification settings - Fork 434
(WIP) Dynamic inference provider mapping #1173
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1230f3d
8537222
add91cf
a4e83d9
8994771
7b00be0
4a7c7ad
10436c7
fdd317f
aa0e049
bdea401
e943a35
72de8ad
bee85ae
08931c0
03a4c4d
2b799be
12e4bed
2737081
660a54c
5cb069f
07069e9
630462b
d646c9d
9d05764
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,4 @@ | ||
export type { ProviderMapping } from "./providers/types"; | ||
export { HfInference, HfInferenceEndpoint } from "./HfInference"; | ||
export { InferenceOutputError } from "./lib/InferenceOutputError"; | ||
export { FAL_AI_SUPPORTED_MODEL_IDS } from "./providers/fal-ai"; | ||
export { REPLICATE_SUPPORTED_MODEL_IDS } from "./providers/replicate"; | ||
export { SAMBANOVA_SUPPORTED_MODEL_IDS } from "./providers/sambanova"; | ||
export { TOGETHER_SUPPORTED_MODEL_IDS } from "./providers/together"; | ||
export * from "./types"; | ||
export * from "./tasks"; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import type { WidgetType } from "@huggingface/tasks"; | ||
import type { InferenceProvider, InferenceTask, ModelId, Options, RequestArgs } from "../types"; | ||
import { HF_HUB_URL } from "../config"; | ||
import { HARDCODED_MODEL_ID_MAPPING } from "../providers/consts"; | ||
|
||
/**
 * For a single model: maps each inference provider to the model ID on the
 * provider's side, plus the deployment status ("live" | "staging") and the
 * task (widget type) the model is registered for with that provider.
 */
type InferenceProviderMapping = Partial<
	Record<InferenceProvider, { providerId: string; status: "live" | "staging"; task: WidgetType }>
>;
// In-memory, per-process cache: HF model ID -> mapping fetched from the Hub.
const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
|
||
export async function getProviderModelId( | ||
params: { | ||
model: string; | ||
provider: InferenceProvider; | ||
}, | ||
args: RequestArgs, | ||
options: { | ||
taskHint?: InferenceTask; | ||
chatCompletion?: boolean; | ||
fetch?: Options["fetch"]; | ||
} = {} | ||
): Promise<string> { | ||
if (params.provider === "hf-inference") { | ||
return params.model; | ||
} | ||
if (!options.taskHint) { | ||
throw new Error("taskHint must be specified when using a third-party provider"); | ||
} | ||
const task: WidgetType = | ||
options.taskHint === "text-generation" && options.chatCompletion ? "conversational" : options.taskHint; | ||
|
||
// A dict called HARDCODED_MODEL_ID_MAPPING takes precedence in all cases (useful for dev purposes) | ||
if (HARDCODED_MODEL_ID_MAPPING[params.model]) { | ||
return HARDCODED_MODEL_ID_MAPPING[params.model]; | ||
} | ||
|
||
let inferenceProviderMapping: InferenceProviderMapping | null; | ||
if (inferenceProviderMappingCache.has(params.model)) { | ||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion | ||
inferenceProviderMapping = inferenceProviderMappingCache.get(params.model)!; | ||
} else { | ||
inferenceProviderMapping = await (options?.fetch ?? fetch)( | ||
`${HF_HUB_URL}/api/models/${params.model}?expand[]=inferenceProviderMapping`, | ||
{ | ||
headers: args.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${args.accessToken}` } : {}, | ||
} | ||
) | ||
.then((resp) => resp.json()) | ||
.then((json) => json.inferenceProviderMapping) | ||
.catch(() => null); | ||
} | ||
|
||
if (!inferenceProviderMapping) { | ||
throw new Error(`We have not been able to find inference provider information for model ${params.model}.`); | ||
} | ||
|
||
const providerMapping = inferenceProviderMapping[params.provider]; | ||
if (providerMapping) { | ||
if (providerMapping.task !== task) { | ||
throw new Error( | ||
`Model ${params.model} is not supported for task ${task} and provider ${params.provider}. Supported task: ${providerMapping.task}.` | ||
); | ||
} | ||
if (providerMapping.status === "staging") { | ||
console.warn( | ||
`Model ${params.model} is in staging mode for provider ${params.provider}. Meant for test purposes only.` | ||
); | ||
} | ||
// TODO: how is it handled server-side if model has multiple tasks (e.g. `text-generation` + `conversational`)? | ||
return providerMapping.providerId; | ||
} | ||
|
||
throw new Error(`Model ${params.model} is not supported provider ${params.provider}.`); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import type { ModelId } from "../types"; | ||
|
||
type ProviderId = string; | ||
|
||
/** | ||
* If you want to try to run inference for a new model locally before it's registered on huggingface.co | ||
* for a given Inference Provider, | ||
* you can add it to the following dictionary, for dev purposes. | ||
*/ | ||
export const HARDCODED_MODEL_ID_MAPPING: Record<ModelId, ProviderId> = { | ||
/** | ||
* "HF model ID" => "Model ID on Inference Provider's side" | ||
*/ | ||
// "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct", | ||
}; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,18 @@ | ||
import type { ProviderMapping } from "./types"; | ||
|
||
export const FAL_AI_API_BASE_URL = "https://fal.run"; | ||
|
||
type FalAiId = string; | ||
|
||
export const FAL_AI_SUPPORTED_MODEL_IDS: ProviderMapping<FalAiId> = { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same comment as in https://github.com/huggingface/huggingface_hub/pull/2836/files#r1942755169: I would, for now, keep the mappings (but make them empty). That way we have slightly fewer breaking changes too, potentially. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @julien-c I reverted to keep both the hard-coded logic and the new dynamic mapping. The mapping from the Hub takes precedence over the hardcoded one. |
||
"text-to-image": { | ||
"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell", | ||
"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev", | ||
"playgroundai/playground-v2.5-1024px-aesthetic": "fal-ai/playground-v25", | ||
"ByteDance/SDXL-Lightning": "fal-ai/lightning-models", | ||
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "fal-ai/pixart-sigma", | ||
"stabilityai/stable-diffusion-3-medium": "fal-ai/stable-diffusion-v3-medium", | ||
"Warlord-K/Sana-1024": "fal-ai/sana", | ||
"fal/AuraFlow-v0.2": "fal-ai/aura-flow", | ||
"stabilityai/stable-diffusion-xl-base-1.0": "fal-ai/fast-sdxl", | ||
"stabilityai/stable-diffusion-3.5-large": "fal-ai/stable-diffusion-v35-large", | ||
"stabilityai/stable-diffusion-3.5-large-turbo": "fal-ai/stable-diffusion-v35-large/turbo", | ||
"stabilityai/stable-diffusion-3.5-medium": "fal-ai/stable-diffusion-v35-medium", | ||
"Kwai-Kolors/Kolors": "fal-ai/kolors", | ||
}, | ||
"automatic-speech-recognition": { | ||
"openai/whisper-large-v3": "fal-ai/whisper", | ||
}, | ||
"text-to-video": { | ||
"genmo/mochi-1-preview": "fal-ai/mochi-v1", | ||
"tencent/HunyuanVideo": "fal-ai/hunyuan-video", | ||
"THUDM/CogVideoX-5b": "fal-ai/cogvideox-5b", | ||
"Lightricks/LTX-Video": "fal-ai/ltx-video", | ||
}, | ||
}; | ||
/** | ||
* See the registered mapping of HF model ID => Fal model ID here: | ||
* | ||
* https://huggingface.co/api/partners/fal-ai/models | ||
* | ||
* This is a publicly available mapping. | ||
* | ||
* If you want to try to run inference for a new model locally before it's registered on huggingface.co, | ||
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes. | ||
* | ||
* - If you work at Fal and want to update this mapping, please use the model mapping API we provide on huggingface.co | ||
* - If you're a community member and want to add a new supported HF model to Fal, please open an issue on the present repo | ||
* and we will tag Fal team members. | ||
* | ||
* Thanks! | ||
*/ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,18 @@ | ||
import type { ProviderMapping } from "./types"; | ||
|
||
export const REPLICATE_API_BASE_URL = "https://api.replicate.com"; | ||
|
||
type ReplicateId = string; | ||
|
||
export const REPLICATE_SUPPORTED_MODEL_IDS: ProviderMapping<ReplicateId> = { | ||
"text-to-image": { | ||
"black-forest-labs/FLUX.1-dev": "black-forest-labs/flux-dev", | ||
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell", | ||
"ByteDance/Hyper-SD": | ||
"bytedance/hyper-flux-16step:382cf8959fb0f0d665b26e7e80b8d6dc3faaef1510f14ce017e8c732bb3d1eb7", | ||
"ByteDance/SDXL-Lightning": | ||
"bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637", | ||
"playgroundai/playground-v2.5-1024px-aesthetic": | ||
"playgroundai/playground-v2.5-1024px-aesthetic:a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24", | ||
"stabilityai/stable-diffusion-3.5-large-turbo": "stability-ai/stable-diffusion-3.5-large-turbo", | ||
"stabilityai/stable-diffusion-3.5-large": "stability-ai/stable-diffusion-3.5-large", | ||
"stabilityai/stable-diffusion-3.5-medium": "stability-ai/stable-diffusion-3.5-medium", | ||
"stabilityai/stable-diffusion-xl-base-1.0": | ||
"stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc", | ||
}, | ||
"text-to-speech": { | ||
"OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:3c645149db020c85d080e2f8cfe482a0e68189a922cde964fa9e80fb179191f3", | ||
"hexgrad/Kokoro-82M": "jaaari/kokoro-82m:dfdf537ba482b029e0a761699e6f55e9162cfd159270bfe0e44857caa5f275a6", | ||
}, | ||
"text-to-video": { | ||
"genmo/mochi-1-preview": "genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460", | ||
}, | ||
}; | ||
/** | ||
* See the registered mapping of HF model ID => Replicate model ID here: | ||
* | ||
* https://huggingface.co/api/partners/replicate/models | ||
* | ||
* This is a publicly available mapping. | ||
* | ||
* If you want to try to run inference for a new model locally before it's registered on huggingface.co, | ||
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes. | ||
* | ||
* - If you work at Replicate and want to update this mapping, please use the model mapping API we provide on huggingface.co | ||
* - If you're a community member and want to add a new supported HF model to Replicate, please open an issue on the present repo | ||
* and we will tag Replicate team members. | ||
* | ||
* Thanks! | ||
*/ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,18 @@ | ||
import type { ProviderMapping } from "./types"; | ||
|
||
export const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai"; | ||
|
||
type SambanovaId = string; | ||
|
||
export const SAMBANOVA_SUPPORTED_MODEL_IDS: ProviderMapping<SambanovaId> = { | ||
/** Chat completion / conversational */ | ||
conversational: { | ||
"allenai/Llama-3.1-Tulu-3-405B":"Llama-3.1-Tulu-3-405B", | ||
"deepseek-ai/DeepSeek-R1-Distill-Llama-70B": "DeepSeek-R1-Distill-Llama-70B", | ||
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct", | ||
"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct", | ||
"Qwen/QwQ-32B-Preview": "QwQ-32B-Preview", | ||
"meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct", | ||
"meta-llama/Llama-3.2-1B-Instruct": "Meta-Llama-3.2-1B-Instruct", | ||
"meta-llama/Llama-3.2-3B-Instruct": "Meta-Llama-3.2-3B-Instruct", | ||
"meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct", | ||
"meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct", | ||
"meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct", | ||
"meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct", | ||
"meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct", | ||
"meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B", | ||
}, | ||
}; | ||
/** | ||
* See the registered mapping of HF model ID => Sambanova model ID here: | ||
* | ||
* https://huggingface.co/api/partners/sambanova/models | ||
* | ||
* This is a publicly available mapping. | ||
* | ||
* If you want to try to run inference for a new model locally before it's registered on huggingface.co, | ||
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes. | ||
* | ||
* - If you work at Sambanova and want to update this mapping, please use the model mapping API we provide on huggingface.co | ||
* - If you're a community member and want to add a new supported HF model to Sambanova, please open an issue on the present repo | ||
* and we will tag Sambanova team members. | ||
* | ||
* Thanks! | ||
*/ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think this is ok @Wauplin