Skip to content

Commit d64fc74

Browse files
Michael Gschwindfacebook-github-bot
authored andcommitted
No hardcoded default checkpoints (#2708)
Summary: Pull Request resolved: #2708 No hardcoded default checkpoints Reviewed By: kimishpatel, digantdesai, cpuhrsch Differential Revision: D55409487 fbshipit-source-id: 1be3eceb62d0d155a40127f5ab2ed0ea101270ad
1 parent 79a4ba0 commit d64fc74

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
import argparse
1010
import copy
1111
import logging
12+
import os
1213
import shlex
1314
from dataclasses import dataclass
1415

1516
from functools import partial
1617
from pathlib import Path
17-
from typing import List, Optional
18+
from typing import List, Optional, Union
1819

1920
import pkg_resources
2021
import torch
@@ -237,8 +238,12 @@ def quantize(
237238
else:
238239
torch_dtype = torch.float16
239240

240-
if checkpoint_path is None:
241-
checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
241+
assert checkpoint_path, "Need to specify a checkpoint"
242+
assert os.path.isfile(
243+
canonical_path(checkpoint_path)
244+
), f"{checkpoint_path} does not exist"
245+
# if checkpoint_path is None:
246+
# checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
242247

243248
if calibration_tasks is None:
244249
calibration_tasks = ["wikitext"]
@@ -457,7 +462,9 @@ def build_args_parser() -> argparse.ArgumentParser:
457462
return parser
458463

459464

460-
def canonical_path(path: str, *, dir: bool = False) -> str:
465+
def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
466+
467+
path = str(path)
461468

462469
if verbose_export():
463470
print(f"creating canonical path for {path}")

0 commit comments

Comments
 (0)