Skip to content

Commit 3e78986

Browse files
Kaihuang724, julien-c, connorch, and SBrandeis
authored
draft: add hyperbolic support (#1191)
Added Hyperbolic as an inference provider --------- Co-authored-by: Julien Chaumond <[email protected]> Co-authored-by: Connor Chevli <[email protected]> Co-authored-by: SBrandeis <[email protected]>
1 parent 5a394d2 commit 3e78986

File tree

10 files changed

+379
-12
lines changed

10 files changed

+379
-12
lines changed

.github/workflows/test.yml

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -41,14 +41,15 @@ jobs:
4141
run: VCR_MODE=playback pnpm --filter ...[${{ steps.since.outputs.SINCE }}] test
4242
env:
4343
HF_TOKEN: ${{ secrets.HF_TOKEN }}
44+
HF_BLACK_FOREST_LABS_KEY: dummy
4445
HF_FAL_KEY: dummy
46+
HF_FIREWORKS_KEY: dummy
47+
HF_HYPERBOLIC_KEY: dummy
4548
HF_NEBIUS_KEY: dummy
49+
HF_NOVITA_KEY: dummy
4650
HF_REPLICATE_KEY: dummy
4751
HF_SAMBANOVA_KEY: dummy
4852
HF_TOGETHER_KEY: dummy
49-
HF_NOVITA_KEY: dummy
50-
HF_FIREWORKS_KEY: dummy
51-
HF_BLACK_FOREST_LABS_KEY: dummy
5253

5354
browser:
5455
runs-on: ubuntu-latest
@@ -85,14 +86,15 @@ jobs:
8586
run: VCR_MODE=playback pnpm --filter ...[${{ steps.since.outputs.SINCE }}] test:browser
8687
env:
8788
HF_TOKEN: ${{ secrets.HF_TOKEN }}
89+
HF_BLACK_FOREST_LABS_KEY: dummy
8890
HF_FAL_KEY: dummy
91+
HF_FIREWORKS_KEY: dummy
92+
HF_HYPERBOLIC_KEY: dummy
8993
HF_NEBIUS_KEY: dummy
94+
HF_NOVITA_KEY: dummy
9095
HF_REPLICATE_KEY: dummy
9196
HF_SAMBANOVA_KEY: dummy
9297
HF_TOGETHER_KEY: dummy
93-
HF_NOVITA_KEY: dummy
94-
HF_FIREWORKS_KEY: dummy
95-
HF_BLACK_FOREST_LABS_KEY: dummy
9698

9799
e2e:
98100
runs-on: ubuntu-latest
@@ -156,11 +158,12 @@ jobs:
156158
env:
157159
NPM_CONFIG_REGISTRY: http://localhost:4874/
158160
HF_TOKEN: ${{ secrets.HF_TOKEN }}
161+
HF_BLACK_FOREST_LABS_KEY: dummy
159162
HF_FAL_KEY: dummy
163+
HF_FIREWORKS_KEY: dummy
164+
HF_HYPERBOLIC_KEY: dummy
160165
HF_NEBIUS_KEY: dummy
166+
HF_NOVITA_KEY: dummy
161167
HF_REPLICATE_KEY: dummy
162168
HF_SAMBANOVA_KEY: dummy
163169
HF_TOGETHER_KEY: dummy
164-
HF_NOVITA_KEY: dummy
165-
HF_FIREWORKS_KEY: dummy
166-
HF_BLACK_FOREST_LABS_KEY: dummy

packages/inference/README.md

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,7 @@ You can send inference requests to third-party providers with the inference clie
4949
Currently, we support the following providers:
5050
- [Fal.ai](https://fal.ai)
5151
- [Fireworks AI](https://fireworks.ai)
52+
- [Hyperbolic](https://hyperbolic.xyz)
5253
- [Nebius](https://studio.nebius.ai)
5354
- [Novita](https://novita.ai/?utm_source=github_huggingface&utm_medium=github_readme&utm_campaign=link)
5455
- [Replicate](https://replicate.com)
@@ -74,6 +75,7 @@ When authenticated with a third-party provider key, the request is made directly
7475
Only a subset of models are supported when requesting third-party providers. You can check the list of supported models per pipeline tasks here:
7576
- [Fal.ai supported models](https://huggingface.co/api/partners/fal-ai/models)
7677
- [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models)
78+
- [Hyperbolic supported models](https://huggingface.co/api/partners/hyperbolic/models)
7779
- [Nebius supported models](https://huggingface.co/api/partners/nebius/models)
7880
- [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
7981
- [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@ import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
66
import { TOGETHER_API_BASE_URL } from "../providers/together";
77
import { NOVITA_API_BASE_URL } from "../providers/novita";
88
import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
9+
import { HYPERBOLIC_API_BASE_URL } from "../providers/hyperbolic";
910
import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/black-forest-labs";
1011
import type { InferenceProvider } from "../types";
1112
import type { InferenceTask, Options, RequestArgs } from "../types";
@@ -132,7 +133,11 @@ export async function makeRequestOptions(
132133
? args.data
133134
: JSON.stringify({
134135
...otherArgs,
135-
...(chatCompletion || provider === "together" || provider === "nebius" ? { model } : undefined),
136+
...(taskHint === "text-to-image" && provider === "hyperbolic"
137+
? { model_name: model }
138+
: chatCompletion || provider === "together" || provider === "nebius" || provider === "hyperbolic"
139+
? { model }
140+
: undefined),
136141
}),
137142
...(credentials ? { credentials } : undefined),
138143
signal: options?.signal,
@@ -229,6 +234,16 @@ function makeUrl(params: {
229234
}
230235
return baseUrl;
231236
}
237+
case "hyperbolic": {
238+
const baseUrl = shouldProxy
239+
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
240+
: HYPERBOLIC_API_BASE_URL;
241+
242+
if (params.taskHint === "text-to-image") {
243+
return `${baseUrl}/v1/images/generations`;
244+
}
245+
return `${baseUrl}/v1/chat/completions`;
246+
}
232247
case "novita": {
233248
const baseUrl = shouldProxy
234249
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
2020
"fal-ai": {},
2121
"fireworks-ai": {},
2222
"hf-inference": {},
23+
hyperbolic: {},
2324
nebius: {},
2425
replicate: {},
2526
sambanova: {},
Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,18 @@
1+
export const HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
2+
3+
/**
4+
* See the registered mapping of HF model ID => Hyperbolic model ID here:
5+
*
6+
* https://huggingface.co/api/partners/hyperbolic/models
7+
*
8+
* This is a publicly available mapping.
9+
*
10+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
11+
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
12+
*
13+
* - If you work at Hyperbolic and want to update this mapping, please use the model mapping API we provide on huggingface.co
14+
* - If you're a community member and want to add a new supported HF model to Hyperbolic, please open an issue on the present repo
15+
* and we will tag Hyperbolic team members.
16+
*
17+
* Thanks!
18+
*/

packages/inference/src/tasks/cv/textToImage.ts

Lines changed: 20 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,10 @@ interface Base64ImageGeneration {
1515
interface OutputUrlImageGeneration {
1616
output: string[];
1717
}
18+
interface HyperbolicTextToImageOutput {
19+
images: Array<{ image: string }>;
20+
}
21+
1822
interface BlackForestLabsResponse {
1923
id: string;
2024
polling_url: string;
@@ -50,7 +54,11 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
5054
prompt: args.inputs,
5155
};
5256
const res = await request<
53-
TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration | BlackForestLabsResponse
57+
| TextToImageOutput
58+
| Base64ImageGeneration
59+
| OutputUrlImageGeneration
60+
| BlackForestLabsResponse
61+
| HyperbolicTextToImageOutput
5462
>(payload, {
5563
...options,
5664
taskHint: "text-to-image",
@@ -64,6 +72,17 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
6472
const image = await fetch(res.images[0].url);
6573
return await image.blob();
6674
}
75+
if (
76+
args.provider === "hyperbolic" &&
77+
"images" in res &&
78+
Array.isArray(res.images) &&
79+
res.images[0] &&
80+
typeof res.images[0].image === "string"
81+
) {
82+
const base64Response = await fetch(`data:image/jpeg;base64,${res.images[0].image}`);
83+
const blob = await base64Response.blob();
84+
return blob;
85+
}
6786
if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
6887
const base64Data = res.data[0].b64_json;
6988
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);

packages/inference/src/tasks/nlp/textGeneration.ts

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";
88
import type { BaseArgs, Options } from "../../types";
99
import { toArray } from "../../utils/toArray";
1010
import { request } from "../custom/request";
11+
import { omit } from "../../utils/omit";
1112

1213
export type { TextGenerationInput, TextGenerationOutput };
1314

@@ -21,6 +22,12 @@ interface TogeteherTextCompletionOutput extends Omit<ChatCompletionOutput, "choi
2122
}>;
2223
}
2324

25+
interface HyperbolicTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
26+
choices: Array<{
27+
message: { content: string };
28+
}>;
29+
}
30+
2431
/**
2532
* Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
2633
*/
@@ -43,6 +50,30 @@ export async function textGeneration(
4350
return {
4451
generated_text: completion.text,
4552
};
53+
} else if (args.provider === "hyperbolic") {
54+
const payload = {
55+
messages: [{ content: args.inputs, role: "user" }],
56+
...(args.parameters
57+
? {
58+
max_tokens: args.parameters.max_new_tokens,
59+
...omit(args.parameters, "max_new_tokens"),
60+
}
61+
: undefined),
62+
...omit(args, ["inputs", "parameters"]),
63+
};
64+
const raw = await request<HyperbolicTextCompletionOutput>(payload, {
65+
...options,
66+
taskHint: "text-generation",
67+
});
68+
const isValidOutput =
69+
typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
70+
if (!isValidOutput) {
71+
throw new InferenceOutputError("Expected ChatCompletionOutput");
72+
}
73+
const completion = raw.choices[0];
74+
return {
75+
generated_text: completion.message.content,
76+
};
4677
} else {
4778
const res = toArray(
4879
await request<TextGenerationOutput | TextGenerationOutput[]>(args, {

packages/inference/src/types.ts

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@ export const INFERENCE_PROVIDERS = [
3333
"fal-ai",
3434
"fireworks-ai",
3535
"hf-inference",
36+
"hyperbolic",
3637
"nebius",
3738
"novita",
3839
"replicate",

packages/inference/test/HfInference.spec.ts

Lines changed: 81 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,8 @@ import { assert, describe, expect, it } from "vitest";
22

33
import type { ChatCompletionStreamOutput } from "@huggingface/tasks";
44

5-
import { chatCompletion, HfInference, textToImage } from "../src";
5+
import type { TextToImageArgs } from "../src";
6+
import { chatCompletion, chatCompletionStream, HfInference, textGeneration, textToImage } from "../src";
67
import { textToVideo } from "../src/tasks/cv/textToVideo";
78
import { readTestFile } from "./test-files";
89
import "./vcr";
@@ -1176,6 +1177,85 @@ describe.concurrent("HfInference", () => {
11761177
TIMEOUT
11771178
);
11781179

1180+
describe.concurrent(
1181+
"Hyperbolic",
1182+
() => {
1183+
HARDCODED_MODEL_ID_MAPPING.hyperbolic = {
1184+
"meta-llama/Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
1185+
"meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct",
1186+
"stabilityai/stable-diffusion-2": "SD2",
1187+
"meta-llama/Llama-3.1-405B": "meta-llama/Meta-Llama-3.1-405B-Instruct",
1188+
};
1189+
1190+
it("chatCompletion - hyperbolic", async () => {
1191+
const res = await chatCompletion({
1192+
accessToken: env.HF_HYPERBOLIC_KEY,
1193+
model: "meta-llama/Llama-3.2-3B-Instruct",
1194+
provider: "hyperbolic",
1195+
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
1196+
temperature: 0.1,
1197+
});
1198+
1199+
expect(res).toBeDefined();
1200+
expect(res.choices).toBeDefined();
1201+
expect(res.choices?.length).toBeGreaterThan(0);
1202+
1203+
if (res.choices && res.choices.length > 0) {
1204+
const completion = res.choices[0].message?.content;
1205+
expect(completion).toBeDefined();
1206+
expect(typeof completion).toBe("string");
1207+
expect(completion).toContain("two");
1208+
}
1209+
});
1210+
1211+
it("chatCompletion stream", async () => {
1212+
const stream = chatCompletionStream({
1213+
accessToken: env.HF_HYPERBOLIC_KEY,
1214+
model: "meta-llama/Llama-3.3-70B-Instruct",
1215+
provider: "hyperbolic",
1216+
messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }],
1217+
}) as AsyncGenerator<ChatCompletionStreamOutput>;
1218+
let out = "";
1219+
for await (const chunk of stream) {
1220+
if (chunk.choices && chunk.choices.length > 0) {
1221+
out += chunk.choices[0].delta.content;
1222+
}
1223+
}
1224+
expect(out).toContain("2");
1225+
});
1226+
1227+
it("textToImage", async () => {
1228+
const res = await textToImage({
1229+
accessToken: env.HF_HYPERBOLIC_KEY,
1230+
model: "stabilityai/stable-diffusion-2",
1231+
provider: "hyperbolic",
1232+
inputs: "award winning high resolution photo of a giant tortoise",
1233+
parameters: {
1234+
height: 128,
1235+
width: 128,
1236+
},
1237+
} satisfies TextToImageArgs);
1238+
expect(res).toBeInstanceOf(Blob);
1239+
});
1240+
1241+
it("textGeneration", async () => {
1242+
const res = await textGeneration({
1243+
accessToken: env.HF_HYPERBOLIC_KEY,
1244+
model: "meta-llama/Llama-3.1-405B",
1245+
provider: "hyperbolic",
1246+
inputs: "Paris is",
1247+
parameters: {
1248+
temperature: 0,
1249+
top_p: 0.01,
1250+
max_new_tokens: 10,
1251+
},
1252+
});
1253+
expect(res).toMatchObject({ generated_text: "...the capital and most populous city of France," });
1254+
});
1255+
},
1256+
TIMEOUT
1257+
);
1258+
11791259
describe.concurrent(
11801260
"Novita",
11811261
() => {

0 commit comments

Comments (0)