Skip to content

Commit c0f38b0

Browse files
hanouticelina, julien-c, and SBrandeis
authored
[Inference Providers] Async calls for text-to-video with fal.ai (#1292)
## What does this PR do?

This PR adds asynchronous polling to the fal.ai text-to-video generation. This allows running inference with models that may take more than 2 minutes to generate results. The other motivation behind this PR is to align the Python and JS clients; the Python client change has already been merged into main: huggingface/huggingface_hub#2927

## Main Changes

- Replaced the static `baseUrl` property with a `makeBaseUrl()` function across all providers. This is needed to customize the base URL based on the task: we want to use `FAL_AI_API_BASE_URL_QUEUE` for `text-to-video` only. I'm not convinced this is the simplest or the best way to do that.
- Added a `pollFalResponse()` function for `text-to-video` (similar to what is done with BFL for `text-to-image`).

Any refactoring suggestions are welcome! I'm willing to spend some additional time to make provider-specific updates easier to implement and to better align our two clients 🙂

btw, I did not update the VCR tests, as we discussed that it would be best to remove the VCR tests for `text-to-video`. Maybe we should remove them here? **EDIT**: removed the text-to-video tests in [f8a6386](f8a6386).

I've tested it locally with [tencent/HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo), for which the generation takes more than 2 minutes, and it works fine: https://github.com/user-attachments/assets/3cd38900-c4ed-4b28-ae79-8a4e724f58d1

---------

Co-authored-by: Julien Chaumond <[email protected]>
Co-authored-by: Simon Brandeis <[email protected]>
1 parent 49f23e1 commit c0f38b0

18 files changed

+168
-501
lines changed

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,11 @@ export function makeRequestOptionsFromResolvedModel(
143143
? endpointUrl + `/v1/chat/completions`
144144
: endpointUrl
145145
: providerConfig.makeUrl({
146+
authMethod,
146147
baseUrl:
147148
authMethod !== "provider-key"
148149
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider)
149-
: providerConfig.baseUrl,
150+
: providerConfig.makeBaseUrl(task),
150151
model: resolvedModel,
151152
chatCompletion,
152153
task,

packages/inference/src/providers/black-forest-labs.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@
1414
*
1515
* Thanks!
1616
*/
17-
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
17+
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
1818

1919
const BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
2020

21+
/** Returns the Black Forest Labs API base URL; the value does not depend on any argument. */
const makeBaseUrl = (): string => BLACK_FOREST_LABS_AI_API_BASE_URL;
24+
2125
const makeBody = (params: BodyParams): Record<string, unknown> => {
2226
return params.args;
2327
};
@@ -35,7 +39,7 @@ const makeUrl = (params: UrlParams): string => {
3539
};
3640

3741
export const BLACK_FOREST_LABS_CONFIG: ProviderConfig = {
38-
baseUrl: BLACK_FOREST_LABS_AI_API_BASE_URL,
42+
makeBaseUrl,
3943
makeBody,
4044
makeHeaders,
4145
makeUrl,

packages/inference/src/providers/cerebras.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@
1414
*
1515
* Thanks!
1616
*/
17-
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
17+
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
1818

1919
const CEREBRAS_API_BASE_URL = "https://api.cerebras.ai";
2020

21+
/** Returns the Cerebras API base URL; the value does not depend on any argument. */
const makeBaseUrl = (): string => CEREBRAS_API_BASE_URL;
24+
2125
const makeBody = (params: BodyParams): Record<string, unknown> => {
2226
return {
2327
...params.args,
@@ -34,7 +38,7 @@ const makeUrl = (params: UrlParams): string => {
3438
};
3539

3640
export const CEREBRAS_CONFIG: ProviderConfig = {
37-
baseUrl: CEREBRAS_API_BASE_URL,
41+
makeBaseUrl,
3842
makeBody,
3943
makeHeaders,
4044
makeUrl,

packages/inference/src/providers/cohere.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@
1414
*
1515
* Thanks!
1616
*/
17-
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
17+
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
1818

1919
const COHERE_API_BASE_URL = "https://api.cohere.com";
2020

21+
/** Returns the Cohere API base URL; the value does not depend on any argument. */
const makeBaseUrl = (): string => COHERE_API_BASE_URL;
24+
2125
const makeBody = (params: BodyParams): Record<string, unknown> => {
2226
return {
2327
...params.args,
@@ -34,7 +38,7 @@ const makeUrl = (params: UrlParams): string => {
3438
};
3539

3640
export const COHERE_CONFIG: ProviderConfig = {
37-
baseUrl: COHERE_API_BASE_URL,
41+
makeBaseUrl,
3842
makeBody,
3943
makeHeaders,
4044
makeUrl,

packages/inference/src/providers/fal-ai.ts

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,17 @@
1414
*
1515
* Thanks!
1616
*/
17-
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
17+
import { InferenceOutputError } from "../lib/InferenceOutputError";
18+
import { isUrl } from "../lib/isUrl";
19+
import type { BodyParams, HeaderParams, InferenceTask, ProviderConfig, UrlParams } from "../types";
20+
import { delay } from "../utils/delay";
1821

1922
const FAL_AI_API_BASE_URL = "https://fal.run";
23+
const FAL_AI_API_BASE_URL_QUEUE = "https://queue.fal.run";
24+
25+
/**
 * Selects the fal.ai base URL for a task.
 *
 * @param task - Inference task being requested; may be omitted.
 * @returns The queue base URL for `text-to-video`, the regular base URL for
 *   every other task or when no task is given.
 */
const makeBaseUrl = (task?: InferenceTask): string => {
	if (task === "text-to-video") {
		return FAL_AI_API_BASE_URL_QUEUE;
	}
	return FAL_AI_API_BASE_URL;
};
2028

2129
const makeBody = (params: BodyParams): Record<string, unknown> => {
2230
return params.args;
@@ -29,12 +37,86 @@ const makeHeaders = (params: HeaderParams): Record<string, string> => {
2937
};
3038

3139
const makeUrl = (params: UrlParams): string => {
32-
return `${params.baseUrl}/${params.model}`;
40+
const baseUrl = `${params.baseUrl}/${params.model}`;
41+
if (params.authMethod !== "provider-key" && params.task === "text-to-video") {
42+
return `${baseUrl}?_subdomain=queue`;
43+
}
44+
return baseUrl;
3345
};
3446

3547
export const FAL_AI_CONFIG: ProviderConfig = {
36-
baseUrl: FAL_AI_API_BASE_URL,
48+
makeBaseUrl,
3749
makeBody,
3850
makeHeaders,
3951
makeUrl,
4052
};
53+
54+
/**
 * Shape of the queue response consumed by `pollFalResponse`.
 */
export interface FalAiQueueOutput {
	/** Identifier of the queued request; polling refuses to start without it. */
	request_id: string;
	/** Queue status; polling stops once it equals `"COMPLETED"`. */
	status: string;
	/** URL whose pathname is reused to build the status and result URLs. */
	response_url: string;
}
59+
60+
/**
 * Polls the fal.ai queue until a submitted request reports `"COMPLETED"`,
 * then fetches the result and downloads the generated video.
 *
 * @param res - Queue response received when the request was submitted.
 * @param url - URL the request was submitted to; its protocol, host and query
 *   string are reused to build the status and result URLs.
 * @param headers - Headers sent with the status and result requests. They are
 *   NOT sent when downloading the video file itself.
 * @returns The generated video as a `Blob`.
 * @throws InferenceOutputError when the queue response lacks a request ID or
 *   response URL, when a status request fails or cannot be parsed, when the
 *   result does not have the `{ video: { url: string } }` shape, or when the
 *   video download fails.
 *
 * NOTE(review): the polling loop has no upper bound on attempts or elapsed
 * time; callers that need a deadline must enforce it themselves.
 */
export async function pollFalResponse(
	res: FalAiQueueOutput,
	url: string,
	headers: Record<string, string>
): Promise<Blob> {
	const requestId = res.request_id;
	if (!requestId) {
		throw new InferenceOutputError("No request ID found in the response");
	}
	// Guard explicitly so a missing URL surfaces as an InferenceOutputError
	// instead of a raw TypeError from the URL constructor below.
	if (!res.response_url) {
		throw new InferenceOutputError("No response URL found in the response");
	}
	let status = res.status;

	const parsedUrl = new URL(url);
	// When the request went through router.huggingface.co, the "/fal-ai"
	// provider prefix has to be part of the base URL.
	const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
		parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
	}`;

	// extracting the provider model id for status and result urls
	// from the response as it might be different from the mapped model in `url`
	const modelId = new URL(res.response_url).pathname;
	const queryParams = parsedUrl.search;

	const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
	const resultUrl = `${baseUrl}${modelId}${queryParams}`;

	while (status !== "COMPLETED") {
		await delay(500);
		const statusResponse = await fetch(statusUrl, { headers });

		if (!statusResponse.ok) {
			throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
		}
		let statusBody: unknown;
		try {
			statusBody = await statusResponse.json();
		} catch {
			throw new InferenceOutputError("Failed to parse status response from fal-ai API");
		}
		// A payload without a string `status` can never reach "COMPLETED";
		// fail fast instead of polling forever.
		if (
			typeof statusBody !== "object" ||
			!statusBody ||
			!("status" in statusBody) ||
			typeof statusBody.status !== "string"
		) {
			throw new InferenceOutputError("Failed to parse status response from fal-ai API");
		}
		status = statusBody.status;
	}

	// A non-OK result response is not rejected here: if it carries a JSON body,
	// the shape check below reports that body in the error message.
	const resultResponse = await fetch(resultUrl, { headers });
	let result: unknown;
	try {
		result = await resultResponse.json();
	} catch {
		throw new InferenceOutputError("Failed to parse result response from fal-ai API");
	}
	if (
		typeof result === "object" &&
		!!result &&
		"video" in result &&
		typeof result.video === "object" &&
		!!result.video &&
		"url" in result.video &&
		typeof result.video.url === "string" &&
		isUrl(result.video.url)
	) {
		const urlResponse = await fetch(result.video.url);
		// Without this check an HTTP error page would be returned as the "video".
		if (!urlResponse.ok) {
			throw new InferenceOutputError(
				`Failed to download the generated video from fal-ai (HTTP ${urlResponse.status})`
			);
		}
		return await urlResponse.blob();
	} else {
		throw new InferenceOutputError(
			"Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
		);
	}
}

packages/inference/src/providers/fireworks-ai.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@
1414
*
1515
* Thanks!
1616
*/
17-
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
17+
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
1818

1919
const FIREWORKS_AI_API_BASE_URL = "https://api.fireworks.ai";
2020

21+
/** Returns the Fireworks AI API base URL; the value does not depend on any argument. */
const makeBaseUrl = (): string => FIREWORKS_AI_API_BASE_URL;
24+
2125
const makeBody = (params: BodyParams): Record<string, unknown> => {
2226
return {
2327
...params.args,
@@ -37,7 +41,7 @@ const makeUrl = (params: UrlParams): string => {
3741
};
3842

3943
export const FIREWORKS_AI_CONFIG: ProviderConfig = {
40-
baseUrl: FIREWORKS_AI_API_BASE_URL,
44+
makeBaseUrl,
4145
makeBody,
4246
makeHeaders,
4347
makeUrl,

packages/inference/src/providers/hf-inference.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
* Thanks!
1212
*/
1313
import { HF_ROUTER_URL } from "../config";
14-
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
14+
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
15+
16+
/** Returns the HF Inference base URL: the `hf-inference` path under `HF_ROUTER_URL`. */
const makeBaseUrl = (): string => `${HF_ROUTER_URL}/hf-inference`;
1519

1620
const makeBody = (params: BodyParams): Record<string, unknown> => {
1721
return {
@@ -36,7 +40,7 @@ const makeUrl = (params: UrlParams): string => {
3640
};
3741

3842
export const HF_INFERENCE_CONFIG: ProviderConfig = {
39-
baseUrl: `${HF_ROUTER_URL}/hf-inference`,
43+
makeBaseUrl,
4044
makeBody,
4145
makeHeaders,
4246
makeUrl,

packages/inference/src/providers/hyperbolic.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@
1414
*
1515
* Thanks!
1616
*/
17-
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
17+
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
1818

1919
const HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
2020

21+
/** Returns the Hyperbolic API base URL; the value does not depend on any argument. */
const makeBaseUrl = (): string => HYPERBOLIC_API_BASE_URL;
24+
2125
const makeBody = (params: BodyParams): Record<string, unknown> => {
2226
return {
2327
...params.args,
@@ -37,7 +41,7 @@ const makeUrl = (params: UrlParams): string => {
3741
};
3842

3943
export const HYPERBOLIC_CONFIG: ProviderConfig = {
40-
baseUrl: HYPERBOLIC_API_BASE_URL,
44+
makeBaseUrl,
4145
makeBody,
4246
makeHeaders,
4347
makeUrl,

packages/inference/src/providers/nebius.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@
1414
*
1515
* Thanks!
1616
*/
17-
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
17+
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
1818

1919
const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
2020

21+
/** Returns the Nebius API base URL; the value does not depend on any argument. */
const makeBaseUrl = (): string => NEBIUS_API_BASE_URL;
24+
2125
const makeBody = (params: BodyParams): Record<string, unknown> => {
2226
return {
2327
...params.args,
@@ -43,7 +47,7 @@ const makeUrl = (params: UrlParams): string => {
4347
};
4448

4549
export const NEBIUS_CONFIG: ProviderConfig = {
46-
baseUrl: NEBIUS_API_BASE_URL,
50+
makeBaseUrl,
4751
makeBody,
4852
makeHeaders,
4953
makeUrl,

packages/inference/src/providers/novita.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@
1414
*
1515
* Thanks!
1616
*/
17-
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
17+
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
1818

1919
const NOVITA_API_BASE_URL = "https://api.novita.ai";
2020

21+
/** Returns the Novita API base URL; the value does not depend on any argument. */
const makeBaseUrl = (): string => NOVITA_API_BASE_URL;
2124
const makeBody = (params: BodyParams): Record<string, unknown> => {
2225
return {
2326
...params.args,
@@ -41,7 +44,7 @@ const makeUrl = (params: UrlParams): string => {
4144
};
4245

4346
export const NOVITA_CONFIG: ProviderConfig = {
44-
baseUrl: NOVITA_API_BASE_URL,
47+
makeBaseUrl,
4548
makeBody,
4649
makeHeaders,
4750
makeUrl,

packages/inference/src/providers/openai.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
/**
22
* Special case: provider configuration for a private models provider (OpenAI in this case).
33
*/
4-
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
4+
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
55

66
const OPENAI_API_BASE_URL = "https://api.openai.com";
77

8+
/** Returns the OpenAI API base URL; the value does not depend on any argument. */
const makeBaseUrl = (): string => OPENAI_API_BASE_URL;
11+
812
const makeBody = (params: BodyParams): Record<string, unknown> => {
913
if (!params.chatCompletion) {
1014
throw new Error("OpenAI only supports chat completions.");
@@ -27,7 +31,7 @@ const makeUrl = (params: UrlParams): string => {
2731
};
2832

2933
export const OPENAI_CONFIG: ProviderConfig = {
30-
baseUrl: OPENAI_API_BASE_URL,
34+
makeBaseUrl,
3135
makeBody,
3236
makeHeaders,
3337
makeUrl,

packages/inference/src/providers/replicate.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@
1414
*
1515
* Thanks!
1616
*/
17-
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
17+
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
1818

1919
export const REPLICATE_API_BASE_URL = "https://api.replicate.com";
2020

21+
/** Returns the Replicate API base URL; the value does not depend on any argument. */
const makeBaseUrl = (): string => REPLICATE_API_BASE_URL;
24+
2125
const makeBody = (params: BodyParams): Record<string, unknown> => {
2226
return {
2327
input: params.args,
@@ -39,7 +43,7 @@ const makeUrl = (params: UrlParams): string => {
3943
};
4044

4145
export const REPLICATE_CONFIG: ProviderConfig = {
42-
baseUrl: REPLICATE_API_BASE_URL,
46+
makeBaseUrl,
4347
makeBody,
4448
makeHeaders,
4549
makeUrl,

packages/inference/src/providers/sambanova.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@
1414
*
1515
* Thanks!
1616
*/
17-
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
17+
import type { BodyParams, HeaderParams, ProviderConfig, UrlParams } from "../types";
1818

1919
const SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai";
2020

21+
/** Returns the SambaNova API base URL; the value does not depend on any argument. */
const makeBaseUrl = (): string => SAMBANOVA_API_BASE_URL;
24+
2125
const makeBody = (params: BodyParams): Record<string, unknown> => {
2226
return {
2327
...params.args,
@@ -37,7 +41,7 @@ const makeUrl = (params: UrlParams): string => {
3741
};
3842

3943
export const SAMBANOVA_CONFIG: ProviderConfig = {
40-
baseUrl: SAMBANOVA_API_BASE_URL,
44+
makeBaseUrl,
4145
makeBody,
4246
makeHeaders,
4347
makeUrl,

0 commit comments

Comments
 (0)