Skip to content

Commit 4f04904

Browse files
authored
✨ Add inferenceProvider filter when listing models (#1198)
This PR adds the missing query param `inference_provider` when listing models (`listModels`).
1 parent 3e78986 commit 4f04904

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

packages/hub/src/lib/list-models.spec.ts

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,4 +80,39 @@ describe("listModels", () => {
8080

8181
expect(count).to.equal(10);
8282
});
83+
84+
it("should search model by inference provider", async () => {
85+
let count = 0;
86+
for await (const entry of listModels({
87+
search: { inferenceProviders: ["together"] },
88+
additionalFields: ["inferenceProviderMapping"],
89+
limit: 10,
90+
})) {
91+
count++;
92+
if (Array.isArray(entry.inferenceProviderMapping)) {
93+
expect(entry.inferenceProviderMapping.map(({ provider }) => provider)).to.include("together");
94+
}
95+
}
96+
97+
expect(count).to.equal(10);
98+
});
99+
100+
it("should search model by several inference providers", async () => {
101+
let count = 0;
102+
const inferenceProviders = ["together", "replicate"];
103+
for await (const entry of listModels({
104+
search: { inferenceProviders },
105+
additionalFields: ["inferenceProviderMapping"],
106+
limit: 10,
107+
})) {
108+
count++;
109+
if (Array.isArray(entry.inferenceProviderMapping)) {
110+
expect(
111+
entry.inferenceProviderMapping.filter(({ provider }) => inferenceProviders.includes(provider)).length
112+
).toBeGreaterThan(0);
113+
}
114+
}
115+
116+
expect(count).to.equal(10);
117+
});
83118
});

packages/hub/src/lib/list-models.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ export async function* listModels<
6363
owner?: string;
6464
task?: PipelineType;
6565
tags?: string[];
66+
/**
67+
* Will search for models that have one of the inference providers in the list.
68+
*/
69+
inferenceProviders?: string[];
6670
};
6771
hubUrl?: string;
6872
additionalFields?: T[];
@@ -84,6 +88,9 @@ export async function* listModels<
8488
...(params?.search?.owner ? { author: params.search.owner } : undefined),
8589
...(params?.search?.task ? { pipeline_tag: params.search.task } : undefined),
8690
...(params?.search?.query ? { search: params.search.query } : undefined),
91+
...(params?.search?.inferenceProviders
92+
? { inference_provider: params.search.inferenceProviders.join(",") }
93+
: undefined),
8794
}),
8895
...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
8996
...MODEL_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),

0 commit comments

Comments
 (0)