
Commit 2ce622d: text generation working

1 parent: 95d33d8

3 files changed: +27 −3 lines


packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 3 additions & 1 deletion

```diff
@@ -145,7 +145,9 @@ export async function makeRequestOptions(
 				? args.data
 				: JSON.stringify({
 						...otherArgs,
-						...(chatCompletion || provider === "together" || provider === "nebius" ? { model } : undefined),
+						...(chatCompletion || provider === "together" || provider === "nebius" || provider === "hyperbolic"
+							? { model }
+							: undefined),
 				  }),
 	...(credentials ? { credentials } : undefined),
 	signal: options?.signal,
```
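The change adds "hyperbolic" to the providers whose request body must carry a `model` field. The pattern works because spreading `undefined` into an object literal is a no-op. A minimal standalone sketch of that pattern, with a hypothetical `buildBody` helper rather than the actual `makeRequestOptions` internals:

```ts
// Minimal sketch of the conditional-spread pattern; names are illustrative.
type Provider = "hf-inference" | "together" | "nebius" | "hyperbolic";

function buildBody(
	otherArgs: Record<string, unknown>,
	model: string,
	provider: Provider,
	chatCompletion: boolean
): string {
	return JSON.stringify({
		...otherArgs,
		// `...undefined` spreads nothing, so `model` is only serialized for
		// chat completions or for providers whose APIs expect it in the body.
		...(chatCompletion || provider === "together" || provider === "nebius" || provider === "hyperbolic"
			? { model }
			: undefined),
	});
}

// buildBody({ prompt: "Paris is" }, "meta-llama/Meta-Llama-3.1-405B-Instruct", "hyperbolic", false)
// -> '{"prompt":"Paris is","model":"meta-llama/Meta-Llama-3.1-405B-Instruct"}'
```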

packages/inference/src/tasks/nlp/textGeneration.ts

Lines changed: 21 additions & 0 deletions

```diff
@@ -21,6 +21,12 @@ interface TogeteherTextCompletionOutput extends Omit<ChatCompletionOutput, "choi
 	}>;
 }
 
+interface HyperbolicTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
+	choices: Array<{
+		message: { content: string };
+	}>;
+}
+
 /**
  * Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
  */
@@ -43,6 +49,21 @@ export async function textGeneration(
 		return {
 			generated_text: completion.text,
 		};
+	} else if (args.provider === "hyperbolic") {
+		args.prompt = args.inputs;
+		const raw = await request<HyperbolicTextCompletionOutput>(args, {
+			...options,
+			taskHint: "text-generation",
+		});
+		const isValidOutput =
+			typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
+		if (!isValidOutput) {
+			throw new InferenceOutputError("Expected ChatCompletionOutput");
+		}
+		const completion = raw.choices[0];
+		return {
+			generated_text: completion.message.content,
+		};
 	} else {
 		const res = toArray(
 			await request<TextGenerationOutput | TextGenerationOutput[]>(args, {
```
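In the new branch, `provider: "hyperbolic"` routing copies `inputs` into `prompt`, validates that the response looks like a chat-completion payload, and unwraps `choices[0].message.content` into the usual `generated_text` shape. A rough caller-side usage sketch (the token handling and model choice here are assumptions, not part of the diff):

```ts
// Hypothetical usage sketch; assumes an HF_TOKEN env var and the model
// mapping exercised by the tests below.
import { HfInference } from "@huggingface/inference";

const hf = new HfInference(process.env.HF_TOKEN);

const res = await hf.textGeneration({
	model: "meta-llama/Llama-3.1-405B",
	provider: "hyperbolic",
	inputs: "Paris is",
});

// The hyperbolic branch surfaces choices[0].message.content here.
console.log(res.generated_text);
```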

packages/inference/test/HfInference.spec.ts

Lines changed: 3 additions & 2 deletions

```diff
@@ -1186,7 +1186,7 @@ describe.concurrent("HfInference", () => {
 		"meta-llama/Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
 		"meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct",
 		"stabilityai/stable-diffusion-2": "stabilityai/stable-diffusion-2",
-		"meta-llama/Llama-3.1-405B": "meta-llama/Llama-3.1-405B",
+		"meta-llama/Llama-3.1-405B": "meta-llama/Meta-Llama-3.1-405B-Instruct",
 	};
 
 	it("chatCompletion - hyperbolic", async () => {
@@ -1244,9 +1244,10 @@ describe.concurrent("HfInference", () => {
 			provider: "hyperbolic",
 			messages: [{ role: "user", content: "Paris is" }],
 			temperature: 0,
+			top_p: 0.01,
 			max_tokens: 10,
 		});
-		expect(res).toMatchObject({ generated_text: " city of love" });
+		expect(res).toMatchObject({ generated_text: "...the capital and most populous city of France," });
 	});
 },
 TIMEOUT
```
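Two things change in the tests. The mapping table now points the Hub ID `meta-llama/Llama-3.1-405B` at `meta-llama/Meta-Llama-3.1-405B-Instruct`, the ID Hyperbolic serves, and the generation test pins `top_p: 0.01` alongside `temperature: 0` so the exact-string assertion on `generated_text` stays deterministic. A tiny illustrative sketch of that kind of lookup (the helper is hypothetical, not repo code):

```ts
// Illustrative only: map Hub model IDs to the IDs the provider expects.
const HYPERBOLIC_MODEL_MAP: Record<string, string> = {
	"meta-llama/Llama-3.1-405B": "meta-llama/Meta-Llama-3.1-405B-Instruct",
};

function toProviderModelId(hubModelId: string): string {
	// Fall back to the Hub ID when the provider uses the same name.
	return HYPERBOLIC_MODEL_MAP[hubModelId] ?? hubModelId;
}
```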
