Skip to content

[Inference Providers] Async calls for text-to-video with fal.ai #1292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
19b1de5
add async calls for fal-ai
hanouticelina Mar 17, 2025
246e764
update fal output
hanouticelina Mar 17, 2025
4eca289
fix
hanouticelina Mar 17, 2025
1975dc7
remove comment
hanouticelina Mar 18, 2025
5f77388
fix lint
hanouticelina Mar 18, 2025
6534c9c
Merge branch 'async-calls-falai' of github.com:huggingface/huggingfac…
hanouticelina Mar 18, 2025
0d193dd
Merge branch 'main' into async-calls-falai
hanouticelina Mar 18, 2025
80dc091
Update packages/inference/src/tasks/cv/textToVideo.ts
hanouticelina Mar 18, 2025
e4a7568
Update packages/inference/src/providers/fal-ai.ts
hanouticelina Mar 18, 2025
cf2d1ac
fixes
hanouticelina Mar 18, 2025
77458a4
fix
hanouticelina Mar 18, 2025
8b0f09b
Merge branch 'main' into async-calls-falai
hanouticelina Mar 18, 2025
188175c
use 0.5s for the interval polling
hanouticelina Mar 18, 2025
30ba4cb
Merge branch 'async-calls-falai' of github.com:huggingface/huggingfac…
hanouticelina Mar 18, 2025
1ee3029
Merge branch 'main' into async-calls-falai
hanouticelina Mar 18, 2025
f8a6386
remove text-to-video tests
hanouticelina Mar 18, 2025
4d30eea
Merge branch 'main' of github.com:huggingface/huggingface.js into asy…
hanouticelina Mar 18, 2025
b97f6cf
Merge branch 'async-calls-falai' of github.com:huggingface/huggingfac…
hanouticelina Mar 18, 2025
36c56ed
fix status and result urls construction
hanouticelina Mar 20, 2025
8ec3e55
Merge branch 'main' of github.com:huggingface/huggingface.js into asy…
hanouticelina Mar 20, 2025
3c346af
review suggestions
hanouticelina Mar 24, 2025
1d0f6e2
Merge branch 'main' into async-calls-falai
hanouticelina Mar 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,11 @@ export function makeRequestOptionsFromResolvedModel(
? endpointUrl + `/v1/chat/completions`
: endpointUrl
: providerConfig.makeUrl({
authMethod,
baseUrl:
authMethod !== "provider-key"
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider)
: providerConfig.baseUrl,
: providerConfig.makeBaseUrl(task),
model: resolvedModel,
chatCompletion,
task,
Expand Down
8 changes: 6 additions & 2 deletions packages/inference/src/providers/black-forest-labs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
*
* Thanks!
*/
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";

const BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";

/** Returns the root URL of the Black Forest Labs API. */
const makeBaseUrl = (): string => BLACK_FOREST_LABS_AI_API_BASE_URL;

/** Uses the caller's arguments as the request body, unchanged. */
const makeBody = (params: BodyParams): Record<string, unknown> => params.args;
Expand All @@ -35,7 +39,7 @@ const makeUrl = (params: UrlParams): string => {
};

export const BLACK_FOREST_LABS_CONFIG: ProviderConfig = {
baseUrl: BLACK_FOREST_LABS_AI_API_BASE_URL,
makeBaseUrl,
makeBody,
makeHeaders,
makeUrl,
Expand Down
8 changes: 6 additions & 2 deletions packages/inference/src/providers/cerebras.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
*
* Thanks!
*/
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";

const CEREBRAS_API_BASE_URL = "https://api.cerebras.ai";

/** Returns the root URL of the Cerebras API. */
const makeBaseUrl = (): string => CEREBRAS_API_BASE_URL;

const makeBody = (params: BodyParams): Record<string, unknown> => {
return {
...params.args,
Expand All @@ -34,7 +38,7 @@ const makeUrl = (params: UrlParams): string => {
};

export const CEREBRAS_CONFIG: ProviderConfig = {
baseUrl: CEREBRAS_API_BASE_URL,
makeBaseUrl,
makeBody,
makeHeaders,
makeUrl,
Expand Down
8 changes: 6 additions & 2 deletions packages/inference/src/providers/cohere.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
*
* Thanks!
*/
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";

const COHERE_API_BASE_URL = "https://api.cohere.com";

/** Returns the root URL of the Cohere API. */
const makeBaseUrl = (): string => COHERE_API_BASE_URL;

const makeBody = (params: BodyParams): Record<string, unknown> => {
return {
...params.args,
Expand All @@ -34,7 +38,7 @@ const makeUrl = (params: UrlParams): string => {
};

export const COHERE_CONFIG: ProviderConfig = {
baseUrl: COHERE_API_BASE_URL,
makeBaseUrl,
makeBody,
makeHeaders,
makeUrl,
Expand Down
88 changes: 85 additions & 3 deletions packages/inference/src/providers/fal-ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,17 @@
*
* Thanks!
*/
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
import { InferenceOutputError } from "../lib/InferenceOutputError";
import { isUrl } from "../lib/isUrl";
import type { BodyParams, HeaderParams, InferenceTask, ProviderConfig, UrlParams } from "../types";
import { delay } from "../utils/delay";

const FAL_AI_API_BASE_URL = "https://fal.run";
const FAL_AI_API_BASE_URL_QUEUE = "https://queue.fal.run";

/**
 * Selects the fal.ai host for a task.
 *
 * @param task - Inference task being requested, if known.
 * @returns The queue host for "text-to-video", the regular host otherwise.
 */
const makeBaseUrl = (task?: InferenceTask): string => {
  if (task === "text-to-video") {
    return FAL_AI_API_BASE_URL_QUEUE;
  }
  return FAL_AI_API_BASE_URL;
};

const makeBody = (params: BodyParams): Record<string, unknown> => {
return params.args;
Expand All @@ -29,12 +37,86 @@ const makeHeaders = (params: HeaderParams): Record<string, string> => {
};

/**
 * Builds the request URL for a fal.ai model.
 *
 * @param params - URL parameters; `baseUrl`, `model`, `authMethod` and `task` are read.
 * @returns `${baseUrl}/${model}`, with `?_subdomain=queue` appended for
 * "text-to-video" requests that are not authenticated with a provider key.
 */
const makeUrl = (params: UrlParams): string => {
  const baseUrl = `${params.baseUrl}/${params.model}`;
  // With a provider key the queue host is already selected through the base URL;
  // otherwise the queue is requested with a query parameter.
  if (params.authMethod !== "provider-key" && params.task === "text-to-video") {
    return `${baseUrl}?_subdomain=queue`;
  }
  return baseUrl;
};

/**
 * Provider configuration for fal.ai, bundling the request-building helpers
 * defined in this module.
 */
export const FAL_AI_CONFIG: ProviderConfig = {
// NOTE(review): a static `baseUrl` is listed next to `makeBaseUrl`; the request
// builder calls `providerConfig.makeBaseUrl(task)` — confirm whether this entry is still needed.
baseUrl: FAL_AI_API_BASE_URL,
makeBaseUrl,
makeBody,
makeHeaders,
makeUrl,
};

/**
 * Queue response consumed by `pollFalResponse`.
 */
export interface FalAiQueueOutput {
/** Identifier of the queued request; `pollFalResponse` throws when it is empty. */
request_id: string;
/** Queue state; `pollFalResponse` keeps polling until it equals "COMPLETED". */
status: string;
/** URL whose pathname `pollFalResponse` reuses to build the status and result URLs. */
response_url: string;
}

/**
 * Polls the fal.ai queue until the request is reported as completed, then
 * fetches the result and downloads the generated video.
 *
 * @param res - Queue response received when the request was submitted.
 * @param url - URL the request was submitted to; its origin and query string
 * are reused for the status and result requests.
 * @param headers - Headers sent with the status and result requests.
 * @returns The generated video as a Blob.
 * @throws InferenceOutputError when the request id is missing, the response URL
 * is invalid, a status request fails or is malformed, the result does not have
 * the `{ video: { url } }` shape, or the video download fails.
 */
export async function pollFalResponse(
  res: FalAiQueueOutput,
  url: string,
  headers: Record<string, string>
): Promise<Blob> {
  const requestId = res.request_id;
  if (!requestId) {
    throw new InferenceOutputError("No request ID found in the response");
  }
  let status = res.status;

  const parsedUrl = new URL(url);
  // Requests routed through the Hugging Face router keep the provider prefix in the path.
  const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
    parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
  }`;

  // extracting the provider model id for status and result urls
  // from the response as it might be different from the mapped model in `url`
  let modelId: string;
  try {
    modelId = new URL(res.response_url).pathname;
  } catch (error) {
    // Report a malformed queue response the same way as the other output errors
    // instead of leaking a raw TypeError from the URL constructor.
    throw new InferenceOutputError("Invalid response URL in the fal-ai queue response");
  }
  const queryParams = parsedUrl.search;

  const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
  const resultUrl = `${baseUrl}${modelId}${queryParams}`;

  while (status !== "COMPLETED") {
    await delay(500);
    const statusResponse = await fetch(statusUrl, { headers });

    if (!statusResponse.ok) {
      throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
    }
    let statusPayload: unknown;
    try {
      statusPayload = await statusResponse.json();
    } catch (error) {
      throw new InferenceOutputError("Failed to parse status response from fal-ai API");
    }
    // A payload without a string status can never reach "COMPLETED"; fail fast
    // rather than polling forever.
    if (
      typeof statusPayload !== "object" ||
      !statusPayload ||
      !("status" in statusPayload) ||
      typeof statusPayload.status !== "string"
    ) {
      throw new InferenceOutputError(
        "Expected { status: string } status format, got instead: " + JSON.stringify(statusPayload)
      );
    }
    status = statusPayload.status;
  }

  const resultResponse = await fetch(resultUrl, { headers });
  let result: unknown;
  try {
    result = await resultResponse.json();
  } catch (error) {
    throw new InferenceOutputError("Failed to parse result response from fal-ai API");
  }
  if (
    typeof result === "object" &&
    !!result &&
    "video" in result &&
    typeof result.video === "object" &&
    !!result.video &&
    "url" in result.video &&
    typeof result.video.url === "string" &&
    isUrl(result.video.url)
  ) {
    const urlResponse = await fetch(result.video.url);
    // Do not hand back an error page as if it were the video.
    if (!urlResponse.ok) {
      throw new InferenceOutputError("Failed to download the generated video from fal-ai API");
    }
    return await urlResponse.blob();
  } else {
    throw new InferenceOutputError(
      "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
    );
  }
}
8 changes: 6 additions & 2 deletions packages/inference/src/providers/fireworks-ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
*
* Thanks!
*/
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";

const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai";

/** Returns the root URL of the Fireworks AI API. */
const makeBaseUrl = (): string => FIREWORKS_AI_API_BASE_URL;

const makeBody = (params: BodyParams): Record<string, unknown> => {
return {
...params.args,
Expand All @@ -37,7 +41,7 @@ const makeUrl = (params: UrlParams): string => {
};

export const FIREWORKS_AI_CONFIG: ProviderConfig = {
baseUrl: FIREWORKS_AI_API_BASE_URL,
makeBaseUrl,
makeBody,
makeHeaders,
makeUrl,
Expand Down
8 changes: 6 additions & 2 deletions packages/inference/src/providers/hf-inference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
* Thanks!
*/
import { HF_ROUTER_URL } from "../config";
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";

/** Returns the hf-inference base URL, i.e. the HF router URL followed by `/hf-inference`. */
const makeBaseUrl = (): string => `${HF_ROUTER_URL}/hf-inference`;

const makeBody = (params: BodyParams): Record<string, unknown> => {
return {
Expand All @@ -36,7 +40,7 @@ const makeUrl = (params: UrlParams): string => {
};

export const HF_INFERENCE_CONFIG: ProviderConfig = {
baseUrl: `${HF_ROUTER_URL}/hf-inference`,
makeBaseUrl,
makeBody,
makeHeaders,
makeUrl,
Expand Down
8 changes: 6 additions & 2 deletions packages/inference/src/providers/hyperbolic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
*
* Thanks!
*/
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";

const HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";

/** Returns the root URL of the Hyperbolic API. */
const makeBaseUrl = (): string => HYPERBOLIC_API_BASE_URL;

const makeBody = (params: BodyParams): Record<string, unknown> => {
return {
...params.args,
Expand All @@ -37,7 +41,7 @@ const makeUrl = (params: UrlParams): string => {
};

export const HYPERBOLIC_CONFIG: ProviderConfig = {
baseUrl: HYPERBOLIC_API_BASE_URL,
makeBaseUrl,
makeBody,
makeHeaders,
makeUrl,
Expand Down
8 changes: 6 additions & 2 deletions packages/inference/src/providers/nebius.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
*
* Thanks!
*/
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";

const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";

/** Returns the root URL of the Nebius API. */
const makeBaseUrl = (): string => NEBIUS_API_BASE_URL;

const makeBody = (params: BodyParams): Record<string, unknown> => {
return {
...params.args,
Expand All @@ -43,7 +47,7 @@ const makeUrl = (params: UrlParams): string => {
};

export const NEBIUS_CONFIG: ProviderConfig = {
baseUrl: NEBIUS_API_BASE_URL,
makeBaseUrl,
makeBody,
makeHeaders,
makeUrl,
Expand Down
7 changes: 5 additions & 2 deletions packages/inference/src/providers/novita.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
*
* Thanks!
*/
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";

const NOVITA_API_BASE_URL = "https://api.novita.ai";

/** Returns the root URL of the Novita API. */
const makeBaseUrl = (): string => NOVITA_API_BASE_URL;
const makeBody = (params: BodyParams): Record<string, unknown> => {
return {
...params.args,
Expand All @@ -41,7 +44,7 @@ const makeUrl = (params: UrlParams): string => {
};

export const NOVITA_CONFIG: ProviderConfig = {
baseUrl: NOVITA_API_BASE_URL,
makeBaseUrl,
makeBody,
makeHeaders,
makeUrl,
Expand Down
8 changes: 6 additions & 2 deletions packages/inference/src/providers/openai.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
/**
* Special case: provider configuration for a private models provider (OpenAI in this case).
*/
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";

const OPENAI_API_BASE_URL = "https://api.openai.com";

/** Returns the root URL of the OpenAI API. */
const makeBaseUrl = (): string => OPENAI_API_BASE_URL;

const makeBody = (params: BodyParams): Record<string, unknown> => {
if (!params.chatCompletion) {
throw new Error("OpenAI only supports chat completions.");
Expand All @@ -27,7 +31,7 @@ const makeUrl = (params: UrlParams): string => {
};

export const OPENAI_CONFIG: ProviderConfig = {
baseUrl: OPENAI_API_BASE_URL,
makeBaseUrl,
makeBody,
makeHeaders,
makeUrl,
Expand Down
8 changes: 6 additions & 2 deletions packages/inference/src/providers/replicate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
*
* Thanks!
*/
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";

export const REPLICATE_API_BASE_URL = "https://api.replicate.com";

/** Returns the root URL of the Replicate API. */
const makeBaseUrl = (): string => REPLICATE_API_BASE_URL;

const makeBody = (params: BodyParams): Record<string, unknown> => {
return {
input: params.args,
Expand All @@ -39,7 +43,7 @@ const makeUrl = (params: UrlParams): string => {
};

export const REPLICATE_CONFIG: ProviderConfig = {
baseUrl: REPLICATE_API_BASE_URL,
makeBaseUrl,
makeBody,
makeHeaders,
makeUrl,
Expand Down
8 changes: 6 additions & 2 deletions packages/inference/src/providers/sambanova.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
*
* Thanks!
*/
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";

const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";

/** Returns the root URL of the SambaNova API. */
const makeBaseUrl = (): string => SAMBANOVA_API_BASE_URL;

const makeBody = (params: BodyParams): Record<string, unknown> => {
return {
...params.args,
Expand All @@ -37,7 +41,7 @@ const makeUrl = (params: UrlParams): string => {
};

export const SAMBANOVA_CONFIG: ProviderConfig = {
baseUrl: SAMBANOVA_API_BASE_URL,
makeBaseUrl,
makeBody,
makeHeaders,
makeUrl,
Expand Down
Loading