Skip to content

Commit 6ed9d44

Browse files
nbarr07SBrandeishanouticelina
authored
Integrate Nscale-cloud into HuggingFace inference (#1260)
# Integrate Nscale provider for HuggingFace Inference This PR adds support for Nscale inference to the HuggingFace inference API. Note that our inference service is not publicly live yet but will be soon - this draft PR is for review and preparation. - Implemented standard provider integration for nscale - Added support for text-generation, conversational, and text-to-image tasks - Included tests following the established patterns The tests were all passing when I tried. Any feedback is welcome! --------- Co-authored-by: SBrandeis <[email protected]> Co-authored-by: célina <[email protected]> Co-authored-by: Simon Brandeis <[email protected]>
1 parent b5bb4f4 commit 6ed9d44

File tree

6 files changed

+149
-0
lines changed

6 files changed

+149
-0
lines changed

packages/inference/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ Currently, we support the following providers:
5252
- [Hyperbolic](https://hyperbolic.xyz)
5353
- [Nebius](https://studio.nebius.ai)
5454
- [Novita](https://novita.ai/?utm_source=github_huggingface&utm_medium=github_readme&utm_campaign=link)
55+
- [Nscale](https://nscale.com)
5556
- [Replicate](https://replicate.com)
5657
- [Sambanova](https://sambanova.ai)
5758
- [Together](https://together.xyz)
@@ -79,6 +80,7 @@ Only a subset of models are supported when requesting third-party providers. You
7980
- [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models)
8081
- [Hyperbolic supported models](https://huggingface.co/api/partners/hyperbolic/models)
8182
- [Nebius supported models](https://huggingface.co/api/partners/nebius/models)
83+
- [Nscale supported models](https://huggingface.co/api/partners/nscale/models)
8284
- [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
8385
- [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
8486
- [Together supported models](https://huggingface.co/api/partners/together/models)

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import * as HFInference from "../providers/hf-inference";
88
import * as Hyperbolic from "../providers/hyperbolic";
99
import * as Nebius from "../providers/nebius";
1010
import * as Novita from "../providers/novita";
11+
import * as Nscale from "../providers/nscale";
1112
import * as OpenAI from "../providers/openai";
1213
import type {
1314
AudioClassificationTaskHelper,
@@ -109,6 +110,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
109110
conversational: new Novita.NovitaConversationalTask(),
110111
"text-generation": new Novita.NovitaTextGenerationTask(),
111112
},
113+
nscale: {
114+
"text-to-image": new Nscale.NscaleTextToImageTask(),
115+
conversational: new Nscale.NscaleConversationalTask(),
116+
},
112117
openai: {
113118
conversational: new OpenAI.OpenAIConversationalTask(),
114119
},

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
2828
hyperbolic: {},
2929
nebius: {},
3030
novita: {},
31+
nscale: {},
3132
openai: {},
3233
replicate: {},
3334
sambanova: {},
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/**
2+
* See the registered mapping of HF model ID => Nscale model ID here:
3+
*
4+
 * https://huggingface.co/api/partners/nscale/models
5+
*
6+
* This is a publicly available mapping.
7+
*
8+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
9+
 * you can add it to the dictionary "HARDCODED_MODEL_INFERENCE_MAPPING" in consts.ts, for dev purposes.
10+
*
11+
* - If you work at Nscale and want to update this mapping, please use the model mapping API we provide on huggingface.co
12+
* - If you're a community member and want to add a new supported HF model to Nscale, please open an issue on the present repo
13+
* and we will tag Nscale team members.
14+
*
15+
* Thanks!
16+
*/
17+
import type { TextToImageInput } from "@huggingface/tasks";
18+
import { InferenceOutputError } from "../lib/InferenceOutputError";
19+
import type { BodyParams } from "../types";
20+
import { omit } from "../utils/omit";
21+
import { BaseConversationalTask, TaskProviderHelper, type TextToImageTaskHelper } from "./providerHelper";
22+
23+
const NSCALE_API_BASE_URL = "https://inference.api.nscale.com";
24+
25+
interface NscaleCloudBase64ImageGeneration {
26+
data: Array<{
27+
b64_json: string;
28+
}>;
29+
}
30+
31+
export class NscaleConversationalTask extends BaseConversationalTask {
32+
constructor() {
33+
super("nscale", NSCALE_API_BASE_URL);
34+
}
35+
}
36+
37+
export class NscaleTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
38+
constructor() {
39+
super("nscale", NSCALE_API_BASE_URL);
40+
}
41+
42+
preparePayload(params: BodyParams<TextToImageInput>): Record<string, unknown> {
43+
return {
44+
...omit(params.args, ["inputs", "parameters"]),
45+
...params.args.parameters,
46+
response_format: "b64_json",
47+
prompt: params.args.inputs,
48+
model: params.model,
49+
};
50+
}
51+
52+
makeRoute(): string {
53+
return "v1/images/generations";
54+
}
55+
56+
async getResponse(
57+
response: NscaleCloudBase64ImageGeneration,
58+
url?: string,
59+
headers?: HeadersInit,
60+
outputType?: "url" | "blob"
61+
): Promise<string | Blob> {
62+
if (
63+
typeof response === "object" &&
64+
"data" in response &&
65+
Array.isArray(response.data) &&
66+
response.data.length > 0 &&
67+
"b64_json" in response.data[0] &&
68+
typeof response.data[0].b64_json === "string"
69+
) {
70+
const base64Data = response.data[0].b64_json;
71+
if (outputType === "url") {
72+
return `data:image/jpeg;base64,${base64Data}`;
73+
}
74+
return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
75+
}
76+
77+
throw new InferenceOutputError("Expected Nscale text-to-image response format");
78+
}
79+
}

packages/inference/src/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ export const INFERENCE_PROVIDERS = [
4747
"hyperbolic",
4848
"nebius",
4949
"novita",
50+
"nscale",
5051
"openai",
5152
"replicate",
5253
"sambanova",

packages/inference/test/InferenceClient.spec.ts

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1690,4 +1690,65 @@ describe.skip("InferenceClient", () => {
16901690
},
16911691
TIMEOUT
16921692
);
1693+
describe.concurrent(
1694+
"Nscale",
1695+
() => {
1696+
const client = new InferenceClient(env.HF_NSCALE_KEY ?? "dummy");
1697+
1698+
HARDCODED_MODEL_INFERENCE_MAPPING["nscale"] = {
1699+
"meta-llama/Llama-3.1-8B-Instruct": {
1700+
hfModelId: "meta-llama/Llama-3.1-8B-Instruct",
1701+
providerId: "nscale",
1702+
status: "live",
1703+
task: "conversational",
1704+
},
1705+
"black-forest-labs/FLUX.1-schnell": {
1706+
hfModelId: "black-forest-labs/FLUX.1-schnell",
1707+
providerId: "flux-schnell",
1708+
status: "live",
1709+
task: "text-to-image",
1710+
},
1711+
};
1712+
1713+
it("chatCompletion", async () => {
1714+
const res = await client.chatCompletion({
1715+
model: "meta-llama/Llama-3.1-8B-Instruct",
1716+
provider: "nscale",
1717+
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
1718+
});
1719+
if (res.choices && res.choices.length > 0) {
1720+
const completion = res.choices[0].message?.content;
1721+
expect(completion).toContain("two");
1722+
}
1723+
});
1724+
it("chatCompletion stream", async () => {
1725+
const stream = client.chatCompletionStream({
1726+
model: "meta-llama/Llama-3.1-8B-Instruct",
1727+
provider: "nscale",
1728+
messages: [{ role: "user", content: "Say 'this is a test'" }],
1729+
stream: true,
1730+
}) as AsyncGenerator<ChatCompletionStreamOutput>;
1731+
let fullResponse = "";
1732+
for await (const chunk of stream) {
1733+
if (chunk.choices && chunk.choices.length > 0) {
1734+
const content = chunk.choices[0].delta?.content;
1735+
if (content) {
1736+
fullResponse += content;
1737+
}
1738+
}
1739+
}
1740+
expect(fullResponse).toBeTruthy();
1741+
expect(fullResponse.length).toBeGreaterThan(0);
1742+
});
1743+
it("textToImage", async () => {
1744+
const res = await client.textToImage({
1745+
model: "black-forest-labs/FLUX.1-schnell",
1746+
provider: "nscale",
1747+
inputs: "An astronaut riding a horse",
1748+
});
1749+
expect(res).toBeInstanceOf(Blob);
1750+
});
1751+
},
1752+
TIMEOUT
1753+
);
16931754
});

0 commit comments

Comments
 (0)