-
Notifications
You must be signed in to change notification settings - Fork 441
Better typing + make task functions generic #1338
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
6d77313
better typing and refactor hf-inference
hanouticelina 50a0313
nit
hanouticelina 5182fe8
reorder
hanouticelina ee8aaa8
remove unused task property
SBrandeis c512b63
no cast when checking type in text generation helper
SBrandeis 83065f7
tweak: code style
SBrandeis d14a0a5
Enforce narrower type for provider arg in constructors
SBrandeis 8981a88
format + lint
SBrandeis de1c13e
fix type predicqte
SBrandeis 57b304c
fix snippet and add text-to-audio
hanouticelina File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,131 +1,270 @@ | ||
import { BlackForestLabsTextToImageTask } from "../providers/black-forest-labs"; | ||
import { CerebrasConversationalTask } from "../providers/cerebras"; | ||
import { CohereConversationalTask } from "../providers/cohere"; | ||
import { | ||
FalAIAutomaticSpeechRecognitionTask, | ||
FalAITextToImageTask, | ||
FalAITextToSpeechTask, | ||
FalAITextToVideoTask, | ||
} from "../providers/fal-ai"; | ||
import { FireworksConversationalTask } from "../providers/fireworks-ai"; | ||
import { | ||
HFInferenceConversationalTask, | ||
HFInferenceTask, | ||
HFInferenceTextGenerationTask, | ||
HFInferenceTextToImageTask, | ||
} from "../providers/hf-inference"; | ||
import { | ||
HyperbolicConversationalTask, | ||
HyperbolicTextGenerationTask, | ||
HyperbolicTextToImageTask, | ||
} from "../providers/hyperbolic"; | ||
import { NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask } from "../providers/nebius"; | ||
import { NovitaConversationalTask, NovitaTextGenerationTask } from "../providers/novita"; | ||
import { OpenAIConversationalTask } from "../providers/openai"; | ||
import type { TaskProviderHelper } from "../providers/providerHelper"; | ||
import { ReplicateTextToImageTask, ReplicateTextToSpeechTask, ReplicateTextToVideoTask } from "../providers/replicate"; | ||
import { SambanovaConversationalTask } from "../providers/sambanova"; | ||
import { TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask } from "../providers/together"; | ||
import * as BlackForestLabs from "../providers/black-forest-labs"; | ||
import * as Cerebras from "../providers/cerebras"; | ||
import * as Cohere from "../providers/cohere"; | ||
import * as FalAI from "../providers/fal-ai"; | ||
import * as Fireworks from "../providers/fireworks-ai"; | ||
import * as HFInference from "../providers/hf-inference"; | ||
|
||
import * as Hyperbolic from "../providers/hyperbolic"; | ||
import * as Nebius from "../providers/nebius"; | ||
import * as Novita from "../providers/novita"; | ||
import * as OpenAI from "../providers/openai"; | ||
import type { | ||
AudioClassificationTaskHelper, | ||
AudioToAudioTaskHelper, | ||
AutomaticSpeechRecognitionTaskHelper, | ||
ConversationalTaskHelper, | ||
DocumentQuestionAnsweringTaskHelper, | ||
FeatureExtractionTaskHelper, | ||
FillMaskTaskHelper, | ||
ImageClassificationTaskHelper, | ||
ImageSegmentationTaskHelper, | ||
ImageToImageTaskHelper, | ||
ImageToTextTaskHelper, | ||
ObjectDetectionTaskHelper, | ||
QuestionAnsweringTaskHelper, | ||
SentenceSimilarityTaskHelper, | ||
SummarizationTaskHelper, | ||
TableQuestionAnsweringTaskHelper, | ||
TabularClassificationTaskHelper, | ||
TabularRegressionTaskHelper, | ||
TaskProviderHelper, | ||
TextClassificationTaskHelper, | ||
TextGenerationTaskHelper, | ||
TextToAudioTaskHelper, | ||
TextToImageTaskHelper, | ||
TextToSpeechTaskHelper, | ||
TextToVideoTaskHelper, | ||
TokenClassificationTaskHelper, | ||
TranslationTaskHelper, | ||
VisualQuestionAnsweringTaskHelper, | ||
ZeroShotClassificationTaskHelper, | ||
ZeroShotImageClassificationTaskHelper, | ||
} from "../providers/providerHelper"; | ||
import * as Replicate from "../providers/replicate"; | ||
import * as Sambanova from "../providers/sambanova"; | ||
import * as Together from "../providers/together"; | ||
import type { InferenceProvider, InferenceTask } from "../types"; | ||
|
||
export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = { | ||
"black-forest-labs": { | ||
"text-to-image": new BlackForestLabsTextToImageTask(), | ||
"text-to-image": new BlackForestLabs.BlackForestLabsTextToImageTask(), | ||
}, | ||
cerebras: { | ||
conversational: new CerebrasConversationalTask(), | ||
conversational: new Cerebras.CerebrasConversationalTask(), | ||
}, | ||
cohere: { | ||
conversational: new CohereConversationalTask(), | ||
conversational: new Cohere.CohereConversationalTask(), | ||
}, | ||
"fal-ai": { | ||
"automatic-speech-recognition": new FalAIAutomaticSpeechRecognitionTask(), | ||
"text-to-image": new FalAITextToImageTask(), | ||
"text-to-speech": new FalAITextToSpeechTask(), | ||
"text-to-video": new FalAITextToVideoTask(), | ||
}, | ||
"fireworks-ai": { | ||
conversational: new FireworksConversationalTask(), | ||
"text-to-image": new FalAI.FalAITextToImageTask(), | ||
"text-to-speech": new FalAI.FalAITextToSpeechTask(), | ||
"text-to-video": new FalAI.FalAITextToVideoTask(), | ||
"automatic-speech-recognition": new FalAI.FalAIAutomaticSpeechRecognitionTask(), | ||
}, | ||
"hf-inference": { | ||
"text-to-image": new HFInferenceTextToImageTask(), | ||
conversational: new HFInferenceConversationalTask(), | ||
"text-generation": new HFInferenceTextGenerationTask(), | ||
"text-classification": new HFInferenceTask("text-classification"), | ||
"text-to-audio": new HFInferenceTask("text-to-audio"), | ||
"question-answering": new HFInferenceTask("question-answering"), | ||
"audio-classification": new HFInferenceTask("audio-classification"), | ||
"automatic-speech-recognition": new HFInferenceTask("automatic-speech-recognition"), | ||
"fill-mask": new HFInferenceTask("fill-mask"), | ||
"feature-extraction": new HFInferenceTask("feature-extraction"), | ||
"image-classification": new HFInferenceTask("image-classification"), | ||
"image-segmentation": new HFInferenceTask("image-segmentation"), | ||
"document-question-answering": new HFInferenceTask("document-question-answering"), | ||
"image-to-text": new HFInferenceTask("image-to-text"), | ||
"object-detection": new HFInferenceTask("object-detection"), | ||
"audio-to-audio": new HFInferenceTask("audio-to-audio"), | ||
"zero-shot-image-classification": new HFInferenceTask("zero-shot-image-classification"), | ||
"zero-shot-classification": new HFInferenceTask("zero-shot-classification"), | ||
"image-to-image": new HFInferenceTask("image-to-image"), | ||
"sentence-similarity": new HFInferenceTask("sentence-similarity"), | ||
"table-question-answering": new HFInferenceTask("table-question-answering"), | ||
"tabular-classification": new HFInferenceTask("tabular-classification"), | ||
"text-to-speech": new HFInferenceTask("text-to-speech"), | ||
"token-classification": new HFInferenceTask("token-classification"), | ||
translation: new HFInferenceTask("translation"), | ||
summarization: new HFInferenceTask("summarization"), | ||
"visual-question-answering": new HFInferenceTask("visual-question-answering"), | ||
"text-to-image": new HFInference.HFInferenceTextToImageTask(), | ||
conversational: new HFInference.HFInferenceConversationalTask(), | ||
"text-generation": new HFInference.HFInferenceTextGenerationTask(), | ||
"text-classification": new HFInference.HFInferenceTextClassificationTask(), | ||
"question-answering": new HFInference.HFInferenceQuestionAnsweringTask(), | ||
"audio-classification": new HFInference.HFInferenceAudioClassificationTask(), | ||
"automatic-speech-recognition": new HFInference.HFInferenceAutomaticSpeechRecognitionTask(), | ||
"fill-mask": new HFInference.HFInferenceFillMaskTask(), | ||
"feature-extraction": new HFInference.HFInferenceFeatureExtractionTask(), | ||
"image-classification": new HFInference.HFInferenceImageClassificationTask(), | ||
"image-segmentation": new HFInference.HFInferenceImageSegmentationTask(), | ||
"document-question-answering": new HFInference.HFInferenceDocumentQuestionAnsweringTask(), | ||
"image-to-text": new HFInference.HFInferenceImageToTextTask(), | ||
"object-detection": new HFInference.HFInferenceObjectDetectionTask(), | ||
"audio-to-audio": new HFInference.HFInferenceAudioToAudioTask(), | ||
"zero-shot-image-classification": new HFInference.HFInferenceZeroShotImageClassificationTask(), | ||
"zero-shot-classification": new HFInference.HFInferenceZeroShotClassificationTask(), | ||
"image-to-image": new HFInference.HFInferenceImageToImageTask(), | ||
"sentence-similarity": new HFInference.HFInferenceSentenceSimilarityTask(), | ||
"table-question-answering": new HFInference.HFInferenceTableQuestionAnsweringTask(), | ||
"tabular-classification": new HFInference.HFInferenceTabularClassificationTask(), | ||
"text-to-speech": new HFInference.HFInferenceTextToSpeechTask(), | ||
"token-classification": new HFInference.HFInferenceTokenClassificationTask(), | ||
translation: new HFInference.HFInferenceTranslationTask(), | ||
summarization: new HFInference.HFInferenceSummarizationTask(), | ||
"visual-question-answering": new HFInference.HFInferenceVisualQuestionAnsweringTask(), | ||
"tabular-regression": new HFInference.HFInferenceTabularRegressionTask(), | ||
"text-to-audio": new HFInference.HFInferenceTextToAudioTask(), | ||
}, | ||
"fireworks-ai": { | ||
conversational: new Fireworks.FireworksConversationalTask(), | ||
}, | ||
hyperbolic: { | ||
"text-to-image": new HyperbolicTextToImageTask(), | ||
conversational: new HyperbolicConversationalTask(), | ||
"text-generation": new HyperbolicTextGenerationTask(), | ||
"text-to-image": new Hyperbolic.HyperbolicTextToImageTask(), | ||
conversational: new Hyperbolic.HyperbolicConversationalTask(), | ||
"text-generation": new Hyperbolic.HyperbolicTextGenerationTask(), | ||
}, | ||
nebius: { | ||
"text-to-image": new NebiusTextToImageTask(), | ||
conversational: new NebiusConversationalTask(), | ||
"text-generation": new NebiusTextGenerationTask(), | ||
"text-to-image": new Nebius.NebiusTextToImageTask(), | ||
conversational: new Nebius.NebiusConversationalTask(), | ||
"text-generation": new Nebius.NebiusTextGenerationTask(), | ||
}, | ||
novita: { | ||
"text-generation": new NovitaTextGenerationTask(), | ||
conversational: new NovitaConversationalTask(), | ||
conversational: new Novita.NovitaConversationalTask(), | ||
"text-generation": new Novita.NovitaTextGenerationTask(), | ||
}, | ||
openai: { | ||
conversational: new OpenAIConversationalTask(), | ||
conversational: new OpenAI.OpenAIConversationalTask(), | ||
}, | ||
replicate: { | ||
"text-to-image": new ReplicateTextToImageTask(), | ||
"text-to-speech": new ReplicateTextToSpeechTask(), | ||
"text-to-video": new ReplicateTextToVideoTask(), | ||
"text-to-image": new Replicate.ReplicateTextToImageTask(), | ||
"text-to-speech": new Replicate.ReplicateTextToSpeechTask(), | ||
"text-to-video": new Replicate.ReplicateTextToVideoTask(), | ||
}, | ||
sambanova: { | ||
conversational: new SambanovaConversationalTask(), | ||
conversational: new Sambanova.SambanovaConversationalTask(), | ||
}, | ||
together: { | ||
"text-to-image": new TogetherTextToImageTask(), | ||
"text-generation": new TogetherTextGenerationTask(), | ||
conversational: new TogetherConversationalTask(), | ||
"text-to-image": new Together.TogetherTextToImageTask(), | ||
conversational: new Together.TogetherConversationalTask(), | ||
"text-generation": new Together.TogetherTextGenerationTask(), | ||
}, | ||
}; | ||
|
||
/** | ||
* Get provider helper instance by name and task | ||
*/ | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "text-to-image" | ||
): TextToImageTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "conversational" | ||
): ConversationalTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "text-generation" | ||
): TextGenerationTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "text-to-speech" | ||
): TextToSpeechTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "text-to-audio" | ||
): TextToAudioTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "automatic-speech-recognition" | ||
): AutomaticSpeechRecognitionTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "text-to-video" | ||
): TextToVideoTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "text-classification" | ||
): TextClassificationTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "question-answering" | ||
): QuestionAnsweringTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "audio-classification" | ||
): AudioClassificationTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "audio-to-audio" | ||
): AudioToAudioTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "fill-mask" | ||
): FillMaskTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "feature-extraction" | ||
): FeatureExtractionTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "image-classification" | ||
): ImageClassificationTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "image-segmentation" | ||
): ImageSegmentationTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "document-question-answering" | ||
): DocumentQuestionAnsweringTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "image-to-text" | ||
): ImageToTextTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "object-detection" | ||
): ObjectDetectionTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "zero-shot-image-classification" | ||
): ZeroShotImageClassificationTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "zero-shot-classification" | ||
): ZeroShotClassificationTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "image-to-image" | ||
): ImageToImageTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "sentence-similarity" | ||
): SentenceSimilarityTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "table-question-answering" | ||
): TableQuestionAnsweringTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "tabular-classification" | ||
): TabularClassificationTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "tabular-regression" | ||
): TabularRegressionTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "token-classification" | ||
): TokenClassificationTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "translation" | ||
): TranslationTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "summarization" | ||
): SummarizationTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper( | ||
provider: InferenceProvider, | ||
task: "visual-question-answering" | ||
): VisualQuestionAnsweringTaskHelper & TaskProviderHelper; | ||
export function getProviderHelper(provider: InferenceProvider, task: InferenceTask | undefined): TaskProviderHelper; | ||
|
||
export function getProviderHelper(provider: InferenceProvider, task: InferenceTask | undefined): TaskProviderHelper { | ||
// special case for hf-inference, where the task is optional | ||
if (provider === "hf-inference") { | ||
if (!task) { | ||
return new HFInferenceTask(); | ||
return new HFInference.HFInferenceTask(); | ||
} | ||
} | ||
if (!task) { | ||
throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'"); | ||
} | ||
const helper = PROVIDERS[provider][task]; | ||
if (!helper) { | ||
if (!(provider in PROVIDERS)) { | ||
throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`); | ||
} | ||
const providerTasks = PROVIDERS[provider]; | ||
if (!providerTasks || !(task in providerTasks)) { | ||
throw new Error( | ||
`Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(PROVIDERS[provider])}` | ||
`Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(providerTasks ?? {})}` | ||
); | ||
} | ||
return helper; | ||
return providerTasks[task] as TaskProviderHelper; | ||
} | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.