Skip to content

Commit de26ffa

Browse files
committed
tests: use task methods for better typing + match hf.js API
1 parent 6bdd200 commit de26ffa

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

packages/inference/test/HfInference.spec.ts

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { assert, describe, expect, it } from "vitest";
33
import type { ChatCompletionStreamOutput } from "@huggingface/tasks";
44

55
import type { TextToImageArgs } from "../src";
6-
import { chatCompletion, HfInference } from "../src";
6+
import { chatCompletion, chatCompletionStream, HfInference, textGeneration, textToImage } from "../src";
77
import { textToVideo } from "../src/tasks/cv/textToVideo";
88
import { readTestFile } from "./test-files";
99
import "./vcr";
@@ -1180,8 +1180,6 @@ describe.concurrent("HfInference", () => {
11801180
describe.concurrent(
11811181
"Hyperbolic",
11821182
() => {
1183-
const client = new HfInference(env.HF_HYPERBOLIC_KEY);
1184-
11851183
HARDCODED_MODEL_ID_MAPPING.hyperbolic = {
11861184
"meta-llama/Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
11871185
"meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct",
@@ -1190,7 +1188,8 @@ describe.concurrent("HfInference", () => {
11901188
};
11911189

11921190
it("chatCompletion - hyperbolic", async () => {
1193-
const res = await client.chatCompletion({
1191+
const res = await chatCompletion({
1192+
accessToken: env.HF_HYPERBOLIC_KEY,
11941193
model: "meta-llama/Llama-3.2-3B-Instruct",
11951194
provider: "hyperbolic",
11961195
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
@@ -1210,7 +1209,8 @@ describe.concurrent("HfInference", () => {
12101209
});
12111210

12121211
it("chatCompletion stream", async () => {
1213-
const stream = client.chatCompletionStream({
1212+
const stream = chatCompletionStream({
1213+
accessToken: env.HF_HYPERBOLIC_KEY,
12141214
model: "meta-llama/Llama-3.3-70B-Instruct",
12151215
provider: "hyperbolic",
12161216
messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }],
@@ -1225,12 +1225,12 @@ describe.concurrent("HfInference", () => {
12251225
});
12261226

12271227
it("textToImage", async () => {
1228-
const res = await client.textToImage({
1228+
const res = await textToImage({
1229+
accessToken: env.HF_HYPERBOLIC_KEY,
12291230
model: "stabilityai/stable-diffusion-2",
12301231
provider: "hyperbolic",
12311232
inputs: "award winning high resolution photo of a giant tortoise",
12321233
parameters: {
1233-
model_name: "SD2",
12341234
height: 128,
12351235
width: 128,
12361236
},
@@ -1239,13 +1239,16 @@ describe.concurrent("HfInference", () => {
12391239
});
12401240

12411241
it("textGeneration", async () => {
1242-
const res = await client.textGeneration({
1242+
const res = await textGeneration({
1243+
accessToken: env.HF_HYPERBOLIC_KEY,
12431244
model: "meta-llama/Llama-3.1-405B",
12441245
provider: "hyperbolic",
1245-
messages: [{ role: "user", content: "Paris is" }],
1246-
temperature: 0,
1247-
top_p: 0.01,
1248-
max_tokens: 10,
1246+
inputs: "Paris is",
1247+
parameters: {
1248+
temperature: 0,
1249+
top_p: 0.01,
1250+
max_new_tokens: 10,
1251+
}
12491252
});
12501253
expect(res).toMatchObject({ generated_text: "...the capital and most populous city of France," });
12511254
});

0 commit comments

Comments
 (0)