Skip to content

Commit 62e314a

Browse files
saksham36julien-cSBrandeis
authored
Black Forest Labs Image Models (#1193)
Co-authored-by: Julien Chaumond <[email protected]> Co-authored-by: SBrandeis <[email protected]>
1 parent 57154a5 commit 62e314a

File tree

12 files changed

+190
-8
lines changed

12 files changed

+190
-8
lines changed

.github/workflows/test.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ jobs:
4848
HF_TOGETHER_KEY: dummy
4949
HF_NOVITA_KEY: dummy
5050
HF_FIREWORKS_KEY: dummy
51+
HF_BLACK_FOREST_LABS_KEY: dummy
5152

5253
browser:
5354
runs-on: ubuntu-latest
@@ -91,6 +92,7 @@ jobs:
9192
HF_TOGETHER_KEY: dummy
9293
HF_NOVITA_KEY: dummy
9394
HF_FIREWORKS_KEY: dummy
95+
HF_BLACK_FOREST_LABS_KEY: dummy
9496

9597
e2e:
9698
runs-on: ubuntu-latest
@@ -161,3 +163,4 @@ jobs:
161163
HF_TOGETHER_KEY: dummy
162164
HF_NOVITA_KEY: dummy
163165
HF_FIREWORKS_KEY: dummy
166+
HF_BLACK_FOREST_LABS_KEY: dummy

packages/inference/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ Currently, we support the following providers:
5454
- [Replicate](https://replicate.com)
5555
- [Sambanova](https://sambanova.ai)
5656
- [Together](https://together.xyz)
57+
- [Blackforestlabs](https://blackforestlabs.ai)
5758

5859
To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
5960
```ts

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
66
import { TOGETHER_API_BASE_URL } from "../providers/together";
77
import { NOVITA_API_BASE_URL } from "../providers/novita";
88
import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
9+
import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/black-forest-labs";
910
import type { InferenceProvider } from "../types";
1011
import type { InferenceTask, Options, RequestArgs } from "../types";
1112
import { isUrl } from "./isUrl";
@@ -80,8 +81,13 @@ export async function makeRequestOptions(
8081

8182
const headers: Record<string, string> = {};
8283
if (accessToken) {
83-
headers["Authorization"] =
84-
provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
84+
if (provider === "fal-ai" && authMethod === "provider-key") {
85+
headers["Authorization"] = `Key ${accessToken}`;
86+
} else if (provider === "black-forest-labs" && authMethod === "provider-key") {
87+
headers["X-Key"] = accessToken;
88+
} else {
89+
headers["Authorization"] = `Bearer ${accessToken}`;
90+
}
8591
}
8692

8793
// e.g. @huggingface/inference/3.1.3
@@ -148,6 +154,12 @@ function makeUrl(params: {
148154

149155
const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
150156
switch (params.provider) {
157+
case "black-forest-labs": {
158+
const baseUrl = shouldProxy
159+
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
160+
: BLACKFORESTLABS_AI_API_BASE_URL;
161+
return `${baseUrl}/${params.model}`;
162+
}
151163
case "fal-ai": {
152164
const baseUrl = shouldProxy
153165
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
2+
3+
/**
4+
* See the registered mapping of HF model ID => Black Forest Labs model ID here:
5+
*
6+
* https://huggingface.co/api/partners/blackforestlabs/models
7+
*
8+
* This is a publicly available mapping.
9+
*
10+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
11+
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
12+
*
13+
* - If you work at Black Forest Labs and want to update this mapping, please use the model mapping API we provide on huggingface.co
14+
* - If you're a community member and want to add a new supported HF model to Black Forest Labs, please open an issue on the present repo
15+
* and we will tag Black Forest Labs team members.
16+
*
17+
* Thanks!
18+
*/

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
1616
* Example:
1717
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
1818
*/
19+
"black-forest-labs": {},
1920
"fal-ai": {},
2021
"fireworks-ai": {},
2122
"hf-inference": {},

packages/inference/src/providers/novita.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@ export const NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
1515
* and we will tag Novita team members.
1616
*
1717
* Thanks!
18-
*/
18+
*/

packages/inference/src/tasks/cv/textToImage.ts

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";
33
import type { BaseArgs, InferenceProvider, Options } from "../../types";
44
import { omit } from "../../utils/omit";
55
import { request } from "../custom/request";
6+
import { delay } from "../../utils/delay";
67

78
export type TextToImageArgs = BaseArgs & TextToImageInput;
89

@@ -14,6 +15,10 @@ interface Base64ImageGeneration {
1415
interface OutputUrlImageGeneration {
1516
output: string[];
1617
}
18+
interface BlackForestLabsResponse {
19+
id: string;
20+
polling_url: string;
21+
}
1722

1823
function getResponseFormatArg(provider: InferenceProvider) {
1924
switch (provider) {
@@ -44,12 +49,17 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
4449
...getResponseFormatArg(args.provider),
4550
prompt: args.inputs,
4651
};
47-
const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(payload, {
52+
const res = await request<
53+
TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration | BlackForestLabsResponse
54+
>(payload, {
4855
...options,
4956
taskHint: "text-to-image",
5057
});
5158

5259
if (res && typeof res === "object") {
60+
if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
61+
return await pollBflResponse(res.polling_url);
62+
}
5363
if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
5464
const image = await fetch(res.images[0].url);
5565
return await image.blob();
@@ -72,3 +82,33 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
7282
}
7383
return res;
7484
}
85+
86+
async function pollBflResponse(url: string): Promise<Blob> {
87+
const urlObj = new URL(url);
88+
for (let step = 0; step < 5; step++) {
89+
await delay(1000);
90+
console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
91+
urlObj.searchParams.set("attempt", step.toString(10));
92+
const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
93+
if (!resp.ok) {
94+
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
95+
}
96+
const payload = await resp.json();
97+
if (
98+
typeof payload === "object" &&
99+
payload &&
100+
"status" in payload &&
101+
typeof payload.status === "string" &&
102+
payload.status === "Ready" &&
103+
"result" in payload &&
104+
typeof payload.result === "object" &&
105+
payload.result &&
106+
"sample" in payload.result &&
107+
typeof payload.result.sample === "string"
108+
) {
109+
const image = await fetch(payload.result.sample);
110+
return await image.blob();
111+
}
112+
}
113+
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
114+
}

packages/inference/src/types.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,15 @@ export interface Options {
2929
export type InferenceTask = Exclude<PipelineType, "other">;
3030

3131
export const INFERENCE_PROVIDERS = [
32+
"black-forest-labs",
3233
"fal-ai",
3334
"fireworks-ai",
34-
"nebius",
3535
"hf-inference",
36+
"nebius",
37+
"novita",
3638
"replicate",
3739
"sambanova",
3840
"together",
39-
"novita",
4041
] as const;
4142

4243
export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];

packages/inference/src/utils/delay.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
export function delay(ms: number): Promise<void> {
2+
return new Promise((resolve) => {
3+
setTimeout(() => resolve(), ms);
4+
});
5+
}

packages/inference/test/HfInference.spec.ts

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { assert, describe, expect, it } from "vitest";
22

33
import type { ChatCompletionStreamOutput } from "@huggingface/tasks";
44

5-
import { chatCompletion, HfInference } from "../src";
5+
import { chatCompletion, HfInference, textToImage } from "../src";
66
import { textToVideo } from "../src/tasks/cv/textToVideo";
77
import { readTestFile } from "./test-files";
88
import "./vcr";
@@ -1214,4 +1214,30 @@ describe.concurrent("HfInference", () => {
12141214
},
12151215
TIMEOUT
12161216
);
1217+
describe.concurrent(
1218+
"Black Forest Labs",
1219+
() => {
1220+
HARDCODED_MODEL_ID_MAPPING["black-forest-labs"] = {
1221+
"black-forest-labs/FLUX.1-dev": "flux-dev",
1222+
// "black-forest-labs/FLUX.1-schnell": "flux-pro",
1223+
};
1224+
1225+
it("textToImage", async () => {
1226+
const res = await textToImage({
1227+
model: "black-forest-labs/FLUX.1-dev",
1228+
provider: "black-forest-labs",
1229+
accessToken: env.HF_BLACK_FOREST_LABS_KEY,
1230+
inputs: "A raccoon driving a truck",
1231+
parameters: {
1232+
height: 256,
1233+
width: 256,
1234+
num_inference_steps: 4,
1235+
seed: 8817,
1236+
},
1237+
});
1238+
expect(res).toBeInstanceOf(Blob);
1239+
});
1240+
},
1241+
TIMEOUT
1242+
);
12171243
});

packages/inference/test/tapes.json

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6994,5 +6994,80 @@
69946994
"transfer-encoding": "chunked"
69956995
}
69966996
}
6997+
},
6998+
"b320223c78e20541a47c961d89d24f507b0b0257224d91cd05744c93f2d67d2c": {
6999+
"url": "https://api.us1.bfl.ai/v1/flux-dev",
7000+
"init": {
7001+
"headers": {
7002+
"Content-Type": "application/json"
7003+
},
7004+
"method": "POST",
7005+
"body": "{\"height\":256,\"width\":256,\"num_inference_steps\":4,\"seed\":8817,\"prompt\":\"A raccoon driving a truck\"}"
7006+
},
7007+
"response": {
7008+
"body": "{\"id\":\"dd8b1c92-587b-428c-9095-d8c12a641160\",\"polling_url\":\"https://api.us1.bfl.ai/v1/get_result?id=dd8b1c92-587b-428c-9095-d8c12a641160\"}",
7009+
"status": 200,
7010+
"statusText": "OK",
7011+
"headers": {
7012+
"connection": "keep-alive",
7013+
"content-type": "application/json",
7014+
"strict-transport-security": "max-age=31536000; includeSubDomains"
7015+
}
7016+
}
7017+
},
7018+
"23eefbade142f7a1e33d50dd6bfaf56e7b959689f7990025db1b353214890a03": {
7019+
"url": "https://api.us1.bfl.ai/v1/get_result?id=dd8b1c92-587b-428c-9095-d8c12a641160&attempt=0",
7020+
"init": {
7021+
"headers": {
7022+
"Content-Type": "application/json"
7023+
}
7024+
},
7025+
"response": {
7026+
"body": "{\"id\":\"dd8b1c92-587b-428c-9095-d8c12a641160\",\"status\":\"Pending\",\"result\":null,\"progress\":0.6}",
7027+
"status": 200,
7028+
"statusText": "OK",
7029+
"headers": {
7030+
"connection": "keep-alive",
7031+
"content-type": "application/json",
7032+
"retry-after": "1",
7033+
"strict-transport-security": "max-age=31536000; includeSubDomains"
7034+
}
7035+
}
7036+
},
7037+
"5803254b4092ae6ac445292c617480002607bb30cce9ba8dc37ce9bb2754f94b": {
7038+
"url": "https://api.us1.bfl.ai/v1/get_result?id=dd8b1c92-587b-428c-9095-d8c12a641160&attempt=1",
7039+
"init": {
7040+
"headers": {
7041+
"Content-Type": "application/json"
7042+
}
7043+
},
7044+
"response": {
7045+
"body": "{\"id\":\"dd8b1c92-587b-428c-9095-d8c12a641160\",\"status\":\"Ready\",\"result\":{\"sample\":\"https://delivery-us1.bfl.ai/results/aa7ab8da64b946dca070d455854a0c3e/sample.jpeg?se=2025-02-13T16%3A12%3A37Z&sp=r&sv=2024-11-04&sr=b&rsct=image/jpeg&sig=u0wzLXKBr8dCMnk9US51zQs7Ma/x/l0lEJvEM3pMUrA%3D\",\"prompt\":\"A raccoon driving a truck\",\"seed\":8817,\"start_time\":1739462555.336884,\"end_time\":1739462557.9051642,\"duration\":2.5682802200317383},\"progress\":null}",
7046+
"status": 200,
7047+
"statusText": "OK",
7048+
"headers": {
7049+
"connection": "keep-alive",
7050+
"content-type": "application/json",
7051+
"retry-after": "1",
7052+
"strict-transport-security": "max-age=31536000; includeSubDomains"
7053+
}
7054+
}
7055+
},
7056+
"548ead8522302cb1123833c27e33d193a8fd619633271414bd2d84e1a71469f0": {
7057+
"url": "https://delivery-us1.bfl.ai/results/aa7ab8da64b946dca070d455854a0c3e/sample.jpeg?se=2025-02-13T16%3A12%3A37Z&sp=r&sv=2024-11-04&sr=b&rsct=image/jpeg&sig=u0wzLXKBr8dCMnk9US51zQs7Ma/x/l0lEJvEM3pMUrA%3D",
7058+
"init": {},
7059+
"response": {
7060+
"body": "",
7061+
"status": 200,
7062+
"statusText": "OK",
7063+
"headers": {
7064+
"accept-ranges": "bytes",
7065+
"connection": "keep-alive",
7066+
"content-md5": "RRVPlYWCAb46mkS5lSs7RQ==",
7067+
"content-type": "image/jpeg",
7068+
"etag": "\"0x8DD4C47D65702E7\"",
7069+
"last-modified": "Thu, 13 Feb 2025 16:02:37 GMT"
7070+
}
7071+
}
69977072
}
69987073
}

packages/inference/test/vcr.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ async function vcr(
181181
const tape: Tape = {
182182
url,
183183
init: {
184-
headers: init.headers && omit(init.headers as Record<string, string>, ["Authorization", "User-Agent"]),
184+
headers: init.headers && omit(init.headers as Record<string, string>, ["Authorization", "User-Agent", "X-Key"]),
185185
method: init.method,
186186
body: typeof init.body === "string" && init.body.length < 1_000 ? init.body : undefined,
187187
},

0 commit comments

Comments
 (0)