Skip to content

Commit 2e0c2c6

Browse files
axel7083coyotte508
andauthored
feat(hub): adding downloadFileToCacheDir (#1034)
## Description Following #1031 which added a `pathsInfo` method which can return the etag/commitHash for a given file. Allowing to be compliant with the `_hf_hub_download_to_cache_dir`[^1] method from the python library. [^1]: [huggingface_hub/file_download.py#L882](https://github.com/huggingface/huggingface_hub/blob/c547c839dbbe0163e3ca422d017daad7c7f9361f/src/huggingface_hub/file_download.py#L882) ## Potential issue The JS implementation do not handle the .lock files as the python library does.. This could be a problem if using the JS and PY function.. ? The JS could make a basic implementation of the lock file that the PY lib is doing if this is a hard requirement. ## Testing I wrote tests for the existing `downloadFile` function (no change to the implementation) and the new one added `downloadFileToCacheDir`. - [x] unit tests has been added --------- Co-authored-by: Eliott C. <[email protected]> Co-authored-by: Eliott C. <[email protected]>
1 parent f057492 commit 2e0c2c6

File tree

7 files changed

+442
-5
lines changed

7 files changed

+442
-5
lines changed

packages/hub/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"./src/utils/sha256-node.ts": false,
2222
"./src/utils/FileBlob.ts": false,
2323
"./src/lib/cache-management.ts": false,
24+
"./src/lib/download-file-to-cache-dir.ts": false,
2425
"./dist/index.js": "./dist/browser/index.js",
2526
"./dist/index.mjs": "./dist/browser/index.mjs"
2627
},

packages/hub/src/lib/cache-management.ts

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,19 @@ function getHuggingFaceHubCache(): string {
1616
return process.env["HUGGINGFACE_HUB_CACHE"] ?? getDefaultCachePath();
1717
}
1818

19-
function getHFHubCache(): string {
19+
export function getHFHubCachePath(): string {
2020
return process.env["HF_HUB_CACHE"] ?? getHuggingFaceHubCache();
2121
}
2222

2323
const FILES_TO_IGNORE: string[] = [".DS_Store"];
2424

25+
export const REPO_ID_SEPARATOR: string = "--";
26+
27+
export function getRepoFolderName({ name, type }: RepoId): string {
28+
const parts = [`${type}s`, ...name.split("/")]
29+
return parts.join(REPO_ID_SEPARATOR);
30+
}
31+
2532
export interface CachedFileInfo {
2633
path: string;
2734
/**
@@ -63,7 +70,7 @@ export interface HFCacheInfo {
6370
}
6471

6572
export async function scanCacheDir(cacheDir: string | undefined = undefined): Promise<HFCacheInfo> {
66-
if (!cacheDir) cacheDir = getHFHubCache();
73+
if (!cacheDir) cacheDir = getHFHubCachePath();
6774

6875
const s = await stat(cacheDir);
6976
if (!s.isDirectory()) {
@@ -107,12 +114,12 @@ export async function scanCacheDir(cacheDir: string | undefined = undefined): Pr
107114
export async function scanCachedRepo(repoPath: string): Promise<CachedRepoInfo> {
108115
// get the directory name
109116
const name = basename(repoPath);
110-
if (!name.includes("--")) {
117+
if (!name.includes(REPO_ID_SEPARATOR)) {
111118
throw new Error(`Repo path is not a valid HuggingFace cache directory: ${name}`);
112119
}
113120

114121
// parse the repoId from directory name
115-
const [type, ...remaining] = name.split("--");
122+
const [type, ...remaining] = name.split(REPO_ID_SEPARATOR);
116123
const repoType = parseRepoType(type);
117124
const repoId = remaining.join("/");
118125

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
import { expect, test, describe, vi, beforeEach } from "vitest";
2+
import type { RepoDesignation, RepoId } from "../types/public";
3+
import { dirname, join } from "node:path";
4+
import { lstat, mkdir, stat, symlink, writeFile, rename } from "node:fs/promises";
5+
import { pathsInfo } from "./paths-info";
6+
import type { Stats } from "node:fs";
7+
import { getHFHubCachePath, getRepoFolderName } from "./cache-management";
8+
import { toRepoId } from "../utils/toRepoId";
9+
import { downloadFileToCacheDir } from "./download-file-to-cache-dir";
10+
11+
vi.mock('node:fs/promises', () => ({
12+
writeFile: vi.fn(),
13+
rename: vi.fn(),
14+
symlink: vi.fn(),
15+
lstat: vi.fn(),
16+
mkdir: vi.fn(),
17+
stat: vi.fn()
18+
}));
19+
20+
vi.mock('./paths-info', () => ({
21+
pathsInfo: vi.fn(),
22+
}));
23+
24+
const DUMMY_REPO: RepoId = {
25+
name: 'hello-world',
26+
type: 'model',
27+
};
28+
29+
const DUMMY_ETAG = "dummy-etag";
30+
31+
// utility test method to get blob file path
32+
function _getBlobFile(params: {
33+
repo: RepoDesignation;
34+
etag: string;
35+
cacheDir?: string, // default to {@link getHFHubCache}
36+
}) {
37+
return join(params.cacheDir ?? getHFHubCachePath(), getRepoFolderName(toRepoId(params.repo)), "blobs", params.etag);
38+
}
39+
40+
// utility test method to get snapshot file path
41+
function _getSnapshotFile(params: {
42+
repo: RepoDesignation;
43+
path: string;
44+
revision : string;
45+
cacheDir?: string, // default to {@link getHFHubCache}
46+
}) {
47+
return join(params.cacheDir ?? getHFHubCachePath(), getRepoFolderName(toRepoId(params.repo)), "snapshots", params.revision, params.path);
48+
}
49+
50+
describe('downloadFileToCacheDir', () => {
51+
const fetchMock: typeof fetch = vi.fn();
52+
beforeEach(() => {
53+
vi.resetAllMocks();
54+
// mock 200 request
55+
vi.mocked(fetchMock).mockResolvedValue({
56+
status: 200,
57+
ok: true,
58+
body: 'dummy-body'
59+
} as unknown as Response);
60+
61+
// prevent to use caching
62+
vi.mocked(stat).mockRejectedValue(new Error('Do not exists'));
63+
vi.mocked(lstat).mockRejectedValue(new Error('Do not exists'));
64+
});
65+
66+
test('should throw an error if fileDownloadInfo return nothing', async () => {
67+
await expect(async () => {
68+
await downloadFileToCacheDir({
69+
repo: DUMMY_REPO,
70+
path: '/README.md',
71+
fetch: fetchMock,
72+
});
73+
}).rejects.toThrowError('cannot get path info for /README.md');
74+
75+
expect(pathsInfo).toHaveBeenCalledWith(expect.objectContaining({
76+
repo: DUMMY_REPO,
77+
paths: ['/README.md'],
78+
fetch: fetchMock,
79+
}));
80+
});
81+
82+
test('existing symlinked and blob should not re-download it', async () => {
83+
// <cache>/<repo>/<revision>/snapshots/README.md
84+
const expectPointer = _getSnapshotFile({
85+
repo: DUMMY_REPO,
86+
path: '/README.md',
87+
revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
88+
});
89+
// stat ensure a symlink and the pointed file exists
90+
vi.mocked(stat).mockResolvedValue({} as Stats) // prevent default mocked reject
91+
92+
const output = await downloadFileToCacheDir({
93+
repo: DUMMY_REPO,
94+
path: '/README.md',
95+
fetch: fetchMock,
96+
revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
97+
});
98+
99+
expect(stat).toHaveBeenCalledOnce();
100+
// Get call argument for stat
101+
const starArg = vi.mocked(stat).mock.calls[0][0];
102+
103+
expect(starArg).toBe(expectPointer)
104+
expect(fetchMock).not.toHaveBeenCalledWith();
105+
106+
expect(output).toBe(expectPointer);
107+
});
108+
109+
test('existing blob should only create the symlink', async () => {
110+
// <cache>/<repo>/<revision>/snapshots/README.md
111+
const expectPointer = _getSnapshotFile({
112+
repo: DUMMY_REPO,
113+
path: '/README.md',
114+
revision: "dummy-commit-hash",
115+
});
116+
// <cache>/<repo>/blobs/<etag>
117+
const expectedBlob = _getBlobFile({
118+
repo: DUMMY_REPO,
119+
etag: DUMMY_ETAG,
120+
});
121+
122+
// mock existing blob only no symlink
123+
vi.mocked(lstat).mockResolvedValue({} as Stats);
124+
// mock pathsInfo resolve content
125+
vi.mocked(pathsInfo).mockResolvedValue([{
126+
oid: DUMMY_ETAG,
127+
size: 55,
128+
path: 'README.md',
129+
type: 'file',
130+
lastCommit: {
131+
date: new Date(),
132+
id: 'dummy-commit-hash',
133+
title: 'Commit msg',
134+
},
135+
}]);
136+
137+
const output = await downloadFileToCacheDir({
138+
repo: DUMMY_REPO,
139+
path: '/README.md',
140+
fetch: fetchMock,
141+
});
142+
143+
expect(stat).not.toHaveBeenCalled();
144+
// should have check for the blob
145+
expect(lstat).toHaveBeenCalled();
146+
expect(vi.mocked(lstat).mock.calls[0][0]).toBe(expectedBlob);
147+
148+
// symlink should have been created
149+
expect(symlink).toHaveBeenCalledOnce();
150+
// no download done
151+
expect(fetchMock).not.toHaveBeenCalled();
152+
153+
expect(output).toBe(expectPointer);
154+
});
155+
156+
test('expect resolve value to be the pointer path of downloaded file', async () => {
157+
// <cache>/<repo>/<revision>/snapshots/README.md
158+
const expectPointer = _getSnapshotFile({
159+
repo: DUMMY_REPO,
160+
path: '/README.md',
161+
revision: "dummy-commit-hash",
162+
});
163+
// <cache>/<repo>/blobs/<etag>
164+
const expectedBlob = _getBlobFile({
165+
repo: DUMMY_REPO,
166+
etag: DUMMY_ETAG,
167+
});
168+
169+
vi.mocked(pathsInfo).mockResolvedValue([{
170+
oid: DUMMY_ETAG,
171+
size: 55,
172+
path: 'README.md',
173+
type: 'file',
174+
lastCommit: {
175+
date: new Date(),
176+
id: 'dummy-commit-hash',
177+
title: 'Commit msg',
178+
},
179+
}]);
180+
181+
const output = await downloadFileToCacheDir({
182+
repo: DUMMY_REPO,
183+
path: '/README.md',
184+
fetch: fetchMock,
185+
});
186+
187+
// expect blobs and snapshots folder to have been mkdir
188+
expect(vi.mocked(mkdir).mock.calls[0][0]).toBe(dirname(expectedBlob));
189+
expect(vi.mocked(mkdir).mock.calls[1][0]).toBe(dirname(expectPointer));
190+
191+
expect(output).toBe(expectPointer);
192+
});
193+
194+
test('should write fetch response to blob', async () => {
195+
// <cache>/<repo>/<revision>/snapshots/README.md
196+
const expectPointer = _getSnapshotFile({
197+
repo: DUMMY_REPO,
198+
path: '/README.md',
199+
revision: "dummy-commit-hash",
200+
});
201+
// <cache>/<repo>/blobs/<etag>
202+
const expectedBlob = _getBlobFile({
203+
repo: DUMMY_REPO,
204+
etag: DUMMY_ETAG,
205+
});
206+
207+
// mock pathsInfo resolve content
208+
vi.mocked(pathsInfo).mockResolvedValue([{
209+
oid: DUMMY_ETAG,
210+
size: 55,
211+
path: 'README.md',
212+
type: 'file',
213+
lastCommit: {
214+
date: new Date(),
215+
id: 'dummy-commit-hash',
216+
title: 'Commit msg',
217+
},
218+
}]);
219+
220+
await downloadFileToCacheDir({
221+
repo: DUMMY_REPO,
222+
path: '/README.md',
223+
fetch: fetchMock,
224+
});
225+
226+
const incomplete = `${expectedBlob}.incomplete`;
227+
// 1. should write fetch#response#body to incomplete file
228+
expect(writeFile).toHaveBeenCalledWith(incomplete, 'dummy-body');
229+
// 2. should rename the incomplete to the blob expected name
230+
expect(rename).toHaveBeenCalledWith(incomplete, expectedBlob);
231+
// 3. should create symlink pointing to blob
232+
expect(symlink).toHaveBeenCalledWith(expectedBlob, expectPointer);
233+
});
234+
});

0 commit comments

Comments
 (0)