Skip to content

Commit 0bcfcd7

Browse files
axel7083coyotte508
andauthored
feat(hub): adding snapshot download method (#1038)
## Description We can now create a `snapshotDownload` method similator to the `snapshot_download` of the PY lib[^1], clone to the cache (only cache supported for now) a repository (either model, space or dataset) [^1]: https://huggingface.co/docs/huggingface_hub/en/guides/download#download-an-entire-repository ## Related issues/PR With the amazing help of @coyotte508 we were able to merge the following changes - #1034 - #1031 - #999 Which allow this PR to provide a python compliant clone of a hugging face repository to the cache directory. ## Testing - [x] unit tests are covering the new feature **Manually** ```ts await snapshotDownload({ repo: { name: 'OuteAI/OuteTTS-0.1-350M', type: 'model', }, }); ``` assert using the `huggingface-cli` tool (python) ``` $: huggingface-cli scan-cache REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH ----------------------------------- --------- ------------ -------- ----------------- ----------------- ---- ---------------------------------------------------------------------------------- OuteAI/OuteTTS-0.1-350M model 731.6M 14 5 minutes ago 5 minutes ago main /home/axel7083/.cache/huggingface/hub/models--OuteAI--OuteTTS-0.1-350M ``` --------- Co-authored-by: Eliott C. <[email protected]>
1 parent 352320e commit 0bcfcd7

File tree

6 files changed

+446
-2
lines changed

6 files changed

+446
-2
lines changed

packages/hub/README.md

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,51 @@ Checkout the demo: https://huggingface.co/spaces/huggingfacejs/client-side-oauth
117117

118118
The `@huggingface/hub` package provide basic capabilities to scan the cache directory. Learn more about [Manage huggingface_hub cache-system](https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache).
119119

120+
### `scanCacheDir`
121+
122+
You can get the list of cached repositories using the `scanCacheDir` function.
123+
120124
```ts
121125
import { scanCacheDir } from "@huggingface/hub";
122126

123127
const result = await scanCacheDir();
124128

125129
console.log(result);
126130
```
127-
Note that the cache directory is created and used only by the Python and Rust libraries. Downloading files using the `@huggingface/hub` package won't use the cache directory.
131+
Note: this does not work in the browser
132+
133+
### `downloadFileToCacheDir`
134+
135+
You can cache a file of a repository using the `downloadFileToCacheDir` function.
136+
137+
```ts
138+
import { downloadFileToCacheDir } from "@huggingface/hub";
139+
140+
const file = await downloadFileToCacheDir({
141+
repo: 'foo/bar',
142+
path: 'README.md'
143+
});
144+
145+
console.log(file);
146+
```
147+
Note: this does not work in the browser
148+
149+
### `snapshotDownload`
150+
151+
You can download an entire repository at a given revision in the cache directory using the `snapshotDownload` function.
152+
153+
```ts
154+
import { snapshotDownload } from "@huggingface/hub";
155+
156+
const directory = await snapshotDownload({
157+
repo: 'foo/bar',
158+
});
159+
160+
console.log(directory);
161+
```
162+
The code use internally the `downloadFileToCacheDir` function.
163+
164+
Note: this does not work in the browser
128165

129166
## Performance considerations
130167

packages/hub/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"./src/utils/FileBlob.ts": false,
2323
"./src/lib/cache-management.ts": false,
2424
"./src/lib/download-file-to-cache-dir.ts": false,
25+
"./src/lib/snapshot-download.ts": false,
2526
"./dist/index.js": "./dist/browser/index.js",
2627
"./dist/index.mjs": "./dist/browser/index.mjs"
2728
},

packages/hub/src/lib/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ export * from "./oauth-handle-redirect";
2121
export * from "./oauth-login-url";
2222
export * from "./parse-safetensors-metadata";
2323
export * from "./paths-info";
24+
export * from "./snapshot-download";
2425
export * from "./space-info";
2526
export * from "./upload-file";
2627
export * from "./upload-files";
Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
import { expect, test, describe, vi, beforeEach } from "vitest";
2+
import { dirname, join } from "node:path";
3+
import { mkdir, writeFile } from "node:fs/promises";
4+
import { getHFHubCachePath } from "./cache-management";
5+
import { downloadFileToCacheDir } from "./download-file-to-cache-dir";
6+
import { snapshotDownload } from "./snapshot-download";
7+
import type { ListFileEntry } from "./list-files";
8+
import { listFiles } from "./list-files";
9+
import { modelInfo } from "./model-info";
10+
import type { ModelEntry } from "./list-models";
11+
import type { ApiModelInfo } from "../types/api/api-model";
12+
import { datasetInfo } from "./dataset-info";
13+
import type { DatasetEntry } from "./list-datasets";
14+
import type { ApiDatasetInfo } from "../types/api/api-dataset";
15+
import { spaceInfo } from "./space-info";
16+
import type { SpaceEntry } from "./list-spaces";
17+
import type { ApiSpaceInfo } from "../types/api/api-space";
18+
19+
vi.mock("node:fs/promises", () => ({
20+
writeFile: vi.fn(),
21+
mkdir: vi.fn(),
22+
}));
23+
24+
vi.mock("./space-info", () => ({
25+
spaceInfo: vi.fn(),
26+
}));
27+
28+
vi.mock("./dataset-info", () => ({
29+
datasetInfo: vi.fn(),
30+
}));
31+
32+
vi.mock("./model-info", () => ({
33+
modelInfo: vi.fn(),
34+
}));
35+
36+
vi.mock("./list-files", () => ({
37+
listFiles: vi.fn(),
38+
}));
39+
40+
vi.mock("./download-file-to-cache-dir", () => ({
41+
downloadFileToCacheDir: vi.fn(),
42+
}));
43+
44+
const DUMMY_SHA = "dummy-sha";
45+
46+
// utility method to transform an array of ListFileEntry to an AsyncGenerator<ListFileEntry>
47+
async function* toAsyncGenerator(content: ListFileEntry[]): AsyncGenerator<ListFileEntry> {
48+
for (const entry of content) {
49+
yield Promise.resolve(entry);
50+
}
51+
}
52+
53+
beforeEach(() => {
54+
vi.resetAllMocks();
55+
vi.mocked(listFiles).mockReturnValue(toAsyncGenerator([]));
56+
57+
// mock repo info
58+
vi.mocked(modelInfo).mockResolvedValue({
59+
sha: DUMMY_SHA,
60+
} as ModelEntry & ApiModelInfo);
61+
vi.mocked(datasetInfo).mockResolvedValue({
62+
sha: DUMMY_SHA,
63+
} as DatasetEntry & ApiDatasetInfo);
64+
vi.mocked(spaceInfo).mockResolvedValue({
65+
sha: DUMMY_SHA,
66+
} as SpaceEntry & ApiSpaceInfo);
67+
});
68+
69+
describe("snapshotDownload", () => {
70+
test("empty AsyncGenerator should not call downloadFileToCacheDir", async () => {
71+
await snapshotDownload({
72+
repo: {
73+
name: "foo/bar",
74+
type: "space",
75+
},
76+
});
77+
78+
expect(downloadFileToCacheDir).not.toHaveBeenCalled();
79+
});
80+
81+
test("repo type model should use modelInfo", async () => {
82+
await snapshotDownload({
83+
repo: {
84+
name: "foo/bar",
85+
type: "model",
86+
},
87+
});
88+
expect(modelInfo).toHaveBeenCalledOnce();
89+
expect(modelInfo).toHaveBeenCalledWith({
90+
name: "foo/bar",
91+
additionalFields: ["sha"],
92+
revision: "main",
93+
repo: {
94+
name: "foo/bar",
95+
type: "model",
96+
},
97+
});
98+
});
99+
100+
test("repo type dataset should use datasetInfo", async () => {
101+
await snapshotDownload({
102+
repo: {
103+
name: "foo/bar",
104+
type: "dataset",
105+
},
106+
});
107+
expect(datasetInfo).toHaveBeenCalledOnce();
108+
expect(datasetInfo).toHaveBeenCalledWith({
109+
name: "foo/bar",
110+
additionalFields: ["sha"],
111+
revision: "main",
112+
repo: {
113+
name: "foo/bar",
114+
type: "dataset",
115+
},
116+
});
117+
});
118+
119+
test("repo type space should use spaceInfo", async () => {
120+
await snapshotDownload({
121+
repo: {
122+
name: "foo/bar",
123+
type: "space",
124+
},
125+
});
126+
expect(spaceInfo).toHaveBeenCalledOnce();
127+
expect(spaceInfo).toHaveBeenCalledWith({
128+
name: "foo/bar",
129+
additionalFields: ["sha"],
130+
revision: "main",
131+
repo: {
132+
name: "foo/bar",
133+
type: "space",
134+
},
135+
});
136+
});
137+
138+
test("commitHash should be saved to ref folder", async () => {
139+
await snapshotDownload({
140+
repo: {
141+
name: "foo/bar",
142+
type: "space",
143+
},
144+
revision: "dummy-revision",
145+
});
146+
147+
// cross-platform testing
148+
const expectedPath = join(getHFHubCachePath(), "spaces--foo--bar", "refs", "dummy-revision");
149+
expect(mkdir).toHaveBeenCalledWith(dirname(expectedPath), { recursive: true });
150+
expect(writeFile).toHaveBeenCalledWith(expectedPath, DUMMY_SHA);
151+
});
152+
153+
test("directory ListFileEntry should mkdir it", async () => {
154+
vi.mocked(listFiles).mockReturnValue(
155+
toAsyncGenerator([
156+
{
157+
oid: "dummy-etag",
158+
type: "directory",
159+
path: "potatoes",
160+
size: 0,
161+
lastCommit: {
162+
date: new Date().toISOString(),
163+
id: DUMMY_SHA,
164+
title: "feat: best commit",
165+
},
166+
},
167+
])
168+
);
169+
170+
await snapshotDownload({
171+
repo: {
172+
name: "foo/bar",
173+
type: "space",
174+
},
175+
});
176+
177+
// cross-platform testing
178+
const expectedPath = join(getHFHubCachePath(), "spaces--foo--bar", "snapshots", DUMMY_SHA, "potatoes");
179+
expect(mkdir).toHaveBeenCalledWith(expectedPath, { recursive: true });
180+
});
181+
182+
test("files in ListFileEntry should download them", async () => {
183+
const entries: ListFileEntry[] = Array.from({ length: 10 }, (_, i) => ({
184+
oid: `dummy-etag-${i}`,
185+
type: "file",
186+
path: `file-${i}.txt`,
187+
size: i,
188+
lastCommit: {
189+
date: new Date().toISOString(),
190+
id: DUMMY_SHA,
191+
title: "feat: best commit",
192+
},
193+
}));
194+
vi.mocked(listFiles).mockReturnValue(toAsyncGenerator(entries));
195+
196+
await snapshotDownload({
197+
repo: {
198+
name: "foo/bar",
199+
type: "space",
200+
},
201+
});
202+
203+
for (const entry of entries) {
204+
expect(downloadFileToCacheDir).toHaveBeenCalledWith(
205+
expect.objectContaining({
206+
repo: {
207+
name: "foo/bar",
208+
type: "space",
209+
},
210+
path: entry.path,
211+
revision: DUMMY_SHA,
212+
})
213+
);
214+
}
215+
});
216+
217+
test("custom params should be propagated", async () => {
218+
// fetch mock
219+
const fetchMock: typeof fetch = vi.fn();
220+
const hubMock = "https://foor.bar";
221+
const accessTokenMock = "dummy-access-token";
222+
223+
vi.mocked(listFiles).mockReturnValue(
224+
toAsyncGenerator([
225+
{
226+
oid: `dummy-etag`,
227+
type: "file",
228+
path: `file.txt`,
229+
size: 10,
230+
lastCommit: {
231+
date: new Date().toISOString(),
232+
id: DUMMY_SHA,
233+
title: "feat: best commit",
234+
},
235+
},
236+
])
237+
);
238+
239+
await snapshotDownload({
240+
repo: {
241+
name: "foo/bar",
242+
type: "space",
243+
},
244+
hubUrl: hubMock,
245+
fetch: fetchMock,
246+
accessToken: accessTokenMock,
247+
});
248+
249+
expect(spaceInfo).toHaveBeenCalledWith(
250+
expect.objectContaining({
251+
fetch: fetchMock,
252+
hubUrl: hubMock,
253+
accessToken: accessTokenMock,
254+
})
255+
);
256+
257+
// list files should receive custom fetch
258+
expect(listFiles).toHaveBeenCalledWith(
259+
expect.objectContaining({
260+
fetch: fetchMock,
261+
hubUrl: hubMock,
262+
accessToken: accessTokenMock,
263+
})
264+
);
265+
266+
// download file to cache should receive custom fetch
267+
expect(downloadFileToCacheDir).toHaveBeenCalledWith(
268+
expect.objectContaining({
269+
fetch: fetchMock,
270+
hubUrl: hubMock,
271+
accessToken: accessTokenMock,
272+
})
273+
);
274+
});
275+
});

0 commit comments

Comments
 (0)