Skip to content

feat(hub): adding downloadFileToCacheDir #1034

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/hub/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"./src/utils/sha256-node.ts": false,
"./src/utils/FileBlob.ts": false,
"./src/lib/cache-management.ts": false,
"./src/lib/download-file-to-cache-dir.ts": false,
"./dist/index.js": "./dist/browser/index.js",
"./dist/index.mjs": "./dist/browser/index.mjs"
},
Expand Down
15 changes: 11 additions & 4 deletions packages/hub/src/lib/cache-management.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,19 @@ function getHuggingFaceHubCache(): string {
return process.env["HUGGINGFACE_HUB_CACHE"] ?? getDefaultCachePath();
}

function getHFHubCache(): string {
export function getHFHubCachePath(): string {
return process.env["HF_HUB_CACHE"] ?? getHuggingFaceHubCache();
}

const FILES_TO_IGNORE: string[] = [".DS_Store"];

export const REPO_ID_SEPARATOR: string = "--";

export function getRepoFolderName({ name, type }: RepoId): string {
const parts = [`${type}s`, ...name.split("/")]
return parts.join(REPO_ID_SEPARATOR);
}

export interface CachedFileInfo {
path: string;
/**
Expand Down Expand Up @@ -63,7 +70,7 @@ export interface HFCacheInfo {
}

export async function scanCacheDir(cacheDir: string | undefined = undefined): Promise<HFCacheInfo> {
if (!cacheDir) cacheDir = getHFHubCache();
if (!cacheDir) cacheDir = getHFHubCachePath();

const s = await stat(cacheDir);
if (!s.isDirectory()) {
Expand Down Expand Up @@ -107,12 +114,12 @@ export async function scanCacheDir(cacheDir: string | undefined = undefined): Pr
export async function scanCachedRepo(repoPath: string): Promise<CachedRepoInfo> {
// get the directory name
const name = basename(repoPath);
if (!name.includes("--")) {
if (!name.includes(REPO_ID_SEPARATOR)) {
throw new Error(`Repo path is not a valid HuggingFace cache directory: ${name}`);
}

// parse the repoId from directory name
const [type, ...remaining] = name.split("--");
const [type, ...remaining] = name.split(REPO_ID_SEPARATOR);
const repoType = parseRepoType(type);
const repoId = remaining.join("/");

Expand Down
234 changes: 234 additions & 0 deletions packages/hub/src/lib/download-file-to-cache-dir.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
import { expect, test, describe, vi, beforeEach } from "vitest";
import type { RepoDesignation, RepoId } from "../types/public";
import { dirname, join } from "node:path";
import { lstat, mkdir, stat, symlink, writeFile, rename } from "node:fs/promises";
import { pathsInfo } from "./paths-info";
import type { Stats } from "node:fs";
import { getHFHubCachePath, getRepoFolderName } from "./cache-management";
import { toRepoId } from "../utils/toRepoId";
import { downloadFileToCacheDir } from "./download-file-to-cache-dir";

vi.mock('node:fs/promises', () => ({
writeFile: vi.fn(),
rename: vi.fn(),
symlink: vi.fn(),
lstat: vi.fn(),
mkdir: vi.fn(),
stat: vi.fn()
}));

vi.mock('./paths-info', () => ({
pathsInfo: vi.fn(),
}));

const DUMMY_REPO: RepoId = {
name: 'hello-world',
type: 'model',
};

const DUMMY_ETAG = "dummy-etag";

// utility test method to get blob file path
function _getBlobFile(params: {
repo: RepoDesignation;
etag: string;
cacheDir?: string, // default to {@link getHFHubCache}
}) {
return join(params.cacheDir ?? getHFHubCachePath(), getRepoFolderName(toRepoId(params.repo)), "blobs", params.etag);
}

// utility test method to get snapshot file path
function _getSnapshotFile(params: {
repo: RepoDesignation;
path: string;
revision : string;
cacheDir?: string, // default to {@link getHFHubCache}
}) {
return join(params.cacheDir ?? getHFHubCachePath(), getRepoFolderName(toRepoId(params.repo)), "snapshots", params.revision, params.path);
}

describe('downloadFileToCacheDir', () => {
const fetchMock: typeof fetch = vi.fn();
beforeEach(() => {
vi.resetAllMocks();
// mock 200 request
vi.mocked(fetchMock).mockResolvedValue({
status: 200,
ok: true,
body: 'dummy-body'
} as unknown as Response);

// prevent to use caching
vi.mocked(stat).mockRejectedValue(new Error('Do not exists'));
vi.mocked(lstat).mockRejectedValue(new Error('Do not exists'));
});

test('should throw an error if fileDownloadInfo return nothing', async () => {
await expect(async () => {
await downloadFileToCacheDir({
repo: DUMMY_REPO,
path: '/README.md',
fetch: fetchMock,
});
}).rejects.toThrowError('cannot get path info for /README.md');

expect(pathsInfo).toHaveBeenCalledWith(expect.objectContaining({
repo: DUMMY_REPO,
paths: ['/README.md'],
fetch: fetchMock,
}));
});

test('existing symlinked and blob should not re-download it', async () => {
// <cache>/<repo>/<revision>/snapshots/README.md
const expectPointer = _getSnapshotFile({
repo: DUMMY_REPO,
path: '/README.md',
revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
});
// stat ensure a symlink and the pointed file exists
vi.mocked(stat).mockResolvedValue({} as Stats) // prevent default mocked reject

const output = await downloadFileToCacheDir({
repo: DUMMY_REPO,
path: '/README.md',
fetch: fetchMock,
revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
});

expect(stat).toHaveBeenCalledOnce();
// Get call argument for stat
const starArg = vi.mocked(stat).mock.calls[0][0];

expect(starArg).toBe(expectPointer)
expect(fetchMock).not.toHaveBeenCalledWith();

expect(output).toBe(expectPointer);
});

test('existing blob should only create the symlink', async () => {
// <cache>/<repo>/<revision>/snapshots/README.md
const expectPointer = _getSnapshotFile({
repo: DUMMY_REPO,
path: '/README.md',
revision: "dummy-commit-hash",
});
// <cache>/<repo>/blobs/<etag>
const expectedBlob = _getBlobFile({
repo: DUMMY_REPO,
etag: DUMMY_ETAG,
});

// mock existing blob only no symlink
vi.mocked(lstat).mockResolvedValue({} as Stats);
// mock pathsInfo resolve content
vi.mocked(pathsInfo).mockResolvedValue([{
oid: DUMMY_ETAG,
size: 55,
path: 'README.md',
type: 'file',
lastCommit: {
date: new Date(),
id: 'dummy-commit-hash',
title: 'Commit msg',
},
}]);

const output = await downloadFileToCacheDir({
repo: DUMMY_REPO,
path: '/README.md',
fetch: fetchMock,
});

expect(stat).not.toHaveBeenCalled();
// should have check for the blob
expect(lstat).toHaveBeenCalled();
expect(vi.mocked(lstat).mock.calls[0][0]).toBe(expectedBlob);

// symlink should have been created
expect(symlink).toHaveBeenCalledOnce();
// no download done
expect(fetchMock).not.toHaveBeenCalled();

expect(output).toBe(expectPointer);
});

test('expect resolve value to be the pointer path of downloaded file', async () => {
// <cache>/<repo>/<revision>/snapshots/README.md
const expectPointer = _getSnapshotFile({
repo: DUMMY_REPO,
path: '/README.md',
revision: "dummy-commit-hash",
});
// <cache>/<repo>/blobs/<etag>
const expectedBlob = _getBlobFile({
repo: DUMMY_REPO,
etag: DUMMY_ETAG,
});

vi.mocked(pathsInfo).mockResolvedValue([{
oid: DUMMY_ETAG,
size: 55,
path: 'README.md',
type: 'file',
lastCommit: {
date: new Date(),
id: 'dummy-commit-hash',
title: 'Commit msg',
},
}]);

const output = await downloadFileToCacheDir({
repo: DUMMY_REPO,
path: '/README.md',
fetch: fetchMock,
});

// expect blobs and snapshots folder to have been mkdir
expect(vi.mocked(mkdir).mock.calls[0][0]).toBe(dirname(expectedBlob));
expect(vi.mocked(mkdir).mock.calls[1][0]).toBe(dirname(expectPointer));

expect(output).toBe(expectPointer);
});

test('should write fetch response to blob', async () => {
// <cache>/<repo>/<revision>/snapshots/README.md
const expectPointer = _getSnapshotFile({
repo: DUMMY_REPO,
path: '/README.md',
revision: "dummy-commit-hash",
});
// <cache>/<repo>/blobs/<etag>
const expectedBlob = _getBlobFile({
repo: DUMMY_REPO,
etag: DUMMY_ETAG,
});

// mock pathsInfo resolve content
vi.mocked(pathsInfo).mockResolvedValue([{
oid: DUMMY_ETAG,
size: 55,
path: 'README.md',
type: 'file',
lastCommit: {
date: new Date(),
id: 'dummy-commit-hash',
title: 'Commit msg',
},
}]);

await downloadFileToCacheDir({
repo: DUMMY_REPO,
path: '/README.md',
fetch: fetchMock,
});

const incomplete = `${expectedBlob}.incomplete`;
// 1. should write fetch#response#body to incomplete file
expect(writeFile).toHaveBeenCalledWith(incomplete, 'dummy-body');
// 2. should rename the incomplete to the blob expected name
expect(rename).toHaveBeenCalledWith(incomplete, expectedBlob);
// 3. should create symlink pointing to blob
expect(symlink).toHaveBeenCalledWith(expectedBlob, expectPointer);
});
});
Loading
Loading