Skip to content

[Inference] Implement a "1 class = 1 provider<>task pair" logic to isolate provider-specific code #1315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 32 commits into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
fc0926c
refactor providers
hanouticelina Mar 26, 2025
4e6cf94
nit
hanouticelina Mar 26, 2025
5514cde
nit
hanouticelina Mar 26, 2025
84745e3
remove unnecessary check
hanouticelina Mar 26, 2025
9466000
fix linting
hanouticelina Mar 27, 2025
fa3f8f0
implement individual classes for sambanova, cohere and cerebras
hanouticelina Mar 27, 2025
a4b4682
add hf-inference helpers
hanouticelina Mar 27, 2025
4dd7e02
fix text-to-image
hanouticelina Mar 27, 2025
e3cb303
Merge branch 'main' of github.com:huggingface/huggingface.js into ref…
hanouticelina Mar 27, 2025
213e658
use conversational task
hanouticelina Mar 27, 2025
6c3823e
backward compatibility hf-inference tasks
hanouticelina Mar 27, 2025
8ccab73
fix tests
hanouticelina Mar 28, 2025
59d1457
Merge branch 'main' of github.com:huggingface/huggingface.js into ref…
hanouticelina Mar 28, 2025
c6252ae
fixes
hanouticelina Mar 28, 2025
e34f2a2
add text-to-audio
hanouticelina Mar 28, 2025
677afc0
improvements and lint
hanouticelina Mar 28, 2025
cc4b255
remove code and add missing tasks for replicate
hanouticelina Mar 28, 2025
b374452
nit
hanouticelina Mar 28, 2025
d0d0f73
Merge branch 'main' of github.com:huggingface/huggingface.js into ref…
hanouticelina Apr 1, 2025
d138032
regenerate fal-ai snippet
hanouticelina Apr 1, 2025
20a8864
nit
hanouticelina Apr 1, 2025
a9ddea1
no need to 'override'
hanouticelina Apr 1, 2025
5a8c576
fix
hanouticelina Apr 1, 2025
6da3a75
apply suggestions
hanouticelina Apr 1, 2025
f87081b
Merge branch 'main' into refactor-providers
hanouticelina Apr 1, 2025
2276e94
some fixes
hanouticelina Apr 2, 2025
6835182
Merge branch 'refactor-providers' of github.com:huggingface/huggingfa…
hanouticelina Apr 2, 2025
67afa27
fix code style
hanouticelina Apr 2, 2025
9d16e7c
group abstract methods
hanouticelina Apr 2, 2025
7782f2b
Better typing + make task functions generic (#1338)
hanouticelina Apr 4, 2025
7753620
Merge branch 'main' of github.com:huggingface/huggingface.js into ref…
hanouticelina Apr 4, 2025
21f9d34
nit
hanouticelina Apr 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 270 additions & 0 deletions packages/inference/src/lib/getProviderHelper.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
import * as BlackForestLabs from "../providers/black-forest-labs";
import * as Cerebras from "../providers/cerebras";
import * as Cohere from "../providers/cohere";
import * as FalAI from "../providers/fal-ai";
import * as Fireworks from "../providers/fireworks-ai";
import * as HFInference from "../providers/hf-inference";

import * as Hyperbolic from "../providers/hyperbolic";
import * as Nebius from "../providers/nebius";
import * as Novita from "../providers/novita";
import * as OpenAI from "../providers/openai";
import type {
AudioClassificationTaskHelper,
AudioToAudioTaskHelper,
AutomaticSpeechRecognitionTaskHelper,
ConversationalTaskHelper,
DocumentQuestionAnsweringTaskHelper,
FeatureExtractionTaskHelper,
FillMaskTaskHelper,
ImageClassificationTaskHelper,
ImageSegmentationTaskHelper,
ImageToImageTaskHelper,
ImageToTextTaskHelper,
ObjectDetectionTaskHelper,
QuestionAnsweringTaskHelper,
SentenceSimilarityTaskHelper,
SummarizationTaskHelper,
TableQuestionAnsweringTaskHelper,
TabularClassificationTaskHelper,
TabularRegressionTaskHelper,
TaskProviderHelper,
TextClassificationTaskHelper,
TextGenerationTaskHelper,
TextToAudioTaskHelper,
TextToImageTaskHelper,
TextToSpeechTaskHelper,
TextToVideoTaskHelper,
TokenClassificationTaskHelper,
TranslationTaskHelper,
VisualQuestionAnsweringTaskHelper,
ZeroShotClassificationTaskHelper,
ZeroShotImageClassificationTaskHelper,
} from "../providers/providerHelper";
import * as Replicate from "../providers/replicate";
import * as Sambanova from "../providers/sambanova";
import * as Together from "../providers/together";
import type { InferenceProvider, InferenceTask } from "../types";

export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
"black-forest-labs": {
"text-to-image": new BlackForestLabs.BlackForestLabsTextToImageTask(),
},
cerebras: {
conversational: new Cerebras.CerebrasConversationalTask(),
},
cohere: {
conversational: new Cohere.CohereConversationalTask(),
},
"fal-ai": {
"text-to-image": new FalAI.FalAITextToImageTask(),
"text-to-speech": new FalAI.FalAITextToSpeechTask(),
"text-to-video": new FalAI.FalAITextToVideoTask(),
"automatic-speech-recognition": new FalAI.FalAIAutomaticSpeechRecognitionTask(),
},
"hf-inference": {
"text-to-image": new HFInference.HFInferenceTextToImageTask(),
conversational: new HFInference.HFInferenceConversationalTask(),
"text-generation": new HFInference.HFInferenceTextGenerationTask(),
"text-classification": new HFInference.HFInferenceTextClassificationTask(),
"question-answering": new HFInference.HFInferenceQuestionAnsweringTask(),
"audio-classification": new HFInference.HFInferenceAudioClassificationTask(),
"automatic-speech-recognition": new HFInference.HFInferenceAutomaticSpeechRecognitionTask(),
"fill-mask": new HFInference.HFInferenceFillMaskTask(),
"feature-extraction": new HFInference.HFInferenceFeatureExtractionTask(),
"image-classification": new HFInference.HFInferenceImageClassificationTask(),
"image-segmentation": new HFInference.HFInferenceImageSegmentationTask(),
"document-question-answering": new HFInference.HFInferenceDocumentQuestionAnsweringTask(),
"image-to-text": new HFInference.HFInferenceImageToTextTask(),
"object-detection": new HFInference.HFInferenceObjectDetectionTask(),
"audio-to-audio": new HFInference.HFInferenceAudioToAudioTask(),
"zero-shot-image-classification": new HFInference.HFInferenceZeroShotImageClassificationTask(),
"zero-shot-classification": new HFInference.HFInferenceZeroShotClassificationTask(),
"image-to-image": new HFInference.HFInferenceImageToImageTask(),
"sentence-similarity": new HFInference.HFInferenceSentenceSimilarityTask(),
"table-question-answering": new HFInference.HFInferenceTableQuestionAnsweringTask(),
"tabular-classification": new HFInference.HFInferenceTabularClassificationTask(),
"text-to-speech": new HFInference.HFInferenceTextToSpeechTask(),
"token-classification": new HFInference.HFInferenceTokenClassificationTask(),
translation: new HFInference.HFInferenceTranslationTask(),
summarization: new HFInference.HFInferenceSummarizationTask(),
"visual-question-answering": new HFInference.HFInferenceVisualQuestionAnsweringTask(),
"tabular-regression": new HFInference.HFInferenceTabularRegressionTask(),
"text-to-audio": new HFInference.HFInferenceTextToAudioTask(),
},
"fireworks-ai": {
conversational: new Fireworks.FireworksConversationalTask(),
},
hyperbolic: {
"text-to-image": new Hyperbolic.HyperbolicTextToImageTask(),
conversational: new Hyperbolic.HyperbolicConversationalTask(),
"text-generation": new Hyperbolic.HyperbolicTextGenerationTask(),
},
nebius: {
"text-to-image": new Nebius.NebiusTextToImageTask(),
conversational: new Nebius.NebiusConversationalTask(),
"text-generation": new Nebius.NebiusTextGenerationTask(),
},
novita: {
conversational: new Novita.NovitaConversationalTask(),
"text-generation": new Novita.NovitaTextGenerationTask(),
},
openai: {
conversational: new OpenAI.OpenAIConversationalTask(),
},
replicate: {
"text-to-image": new Replicate.ReplicateTextToImageTask(),
"text-to-speech": new Replicate.ReplicateTextToSpeechTask(),
"text-to-video": new Replicate.ReplicateTextToVideoTask(),
},
sambanova: {
conversational: new Sambanova.SambanovaConversationalTask(),
},
together: {
"text-to-image": new Together.TogetherTextToImageTask(),
conversational: new Together.TogetherConversationalTask(),
"text-generation": new Together.TogetherTextGenerationTask(),
},
};

/**
* Get provider helper instance by name and task
*/
export function getProviderHelper(
provider: InferenceProvider,
task: "text-to-image"
): TextToImageTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "conversational"
): ConversationalTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "text-generation"
): TextGenerationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "text-to-speech"
): TextToSpeechTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "text-to-audio"
): TextToAudioTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "automatic-speech-recognition"
): AutomaticSpeechRecognitionTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "text-to-video"
): TextToVideoTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "text-classification"
): TextClassificationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "question-answering"
): QuestionAnsweringTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "audio-classification"
): AudioClassificationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "audio-to-audio"
): AudioToAudioTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "fill-mask"
): FillMaskTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "feature-extraction"
): FeatureExtractionTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "image-classification"
): ImageClassificationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "image-segmentation"
): ImageSegmentationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "document-question-answering"
): DocumentQuestionAnsweringTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "image-to-text"
): ImageToTextTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "object-detection"
): ObjectDetectionTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "zero-shot-image-classification"
): ZeroShotImageClassificationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "zero-shot-classification"
): ZeroShotClassificationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "image-to-image"
): ImageToImageTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "sentence-similarity"
): SentenceSimilarityTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "table-question-answering"
): TableQuestionAnsweringTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "tabular-classification"
): TabularClassificationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "tabular-regression"
): TabularRegressionTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "token-classification"
): TokenClassificationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "translation"
): TranslationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "summarization"
): SummarizationTaskHelper & TaskProviderHelper;
export function getProviderHelper(
provider: InferenceProvider,
task: "visual-question-answering"
): VisualQuestionAnsweringTaskHelper & TaskProviderHelper;
export function getProviderHelper(provider: InferenceProvider, task: InferenceTask | undefined): TaskProviderHelper;

export function getProviderHelper(provider: InferenceProvider, task: InferenceTask | undefined): TaskProviderHelper {
if (provider === "hf-inference") {
if (!task) {
return new HFInference.HFInferenceTask();
}
}
if (!task) {
throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
}
if (!(provider in PROVIDERS)) {
throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
}
const providerTasks = PROVIDERS[provider];
if (!providerTasks || !(task in providerTasks)) {
throw new Error(
`Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(providerTasks ?? {})}`
);
}
return providerTasks[task] as TaskProviderHelper;
}
Loading