
Commit 170581a

feat(fast cli): Import torch lazily in all places used by the CLI that don't need a model (#1349)
These changes add a little complexity with the lazy and local imports, but they also greatly improve the CLI's response time for --help, list, and where. Changes:

* Move `import torch` into the function bodies that need it
* Use `importlib.metadata.version` to check the torch version rather than `torch.__version__`
* Switch from using `torch.inference_mode` as a decorator to using it as a context manager
* I also removed it from `convert_hf_checkpoint_to_tune`, since that does not use torch at all
* In `build_utils`, wrap the dtype values in lambdas so they're lazily fetched

#1347
Branch: FasterCli-1347
Signed-off-by: Gabe Goodhart <[email protected]>

1 parent ac02ffb
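Note (illustrative, not part of the commit): the core pattern described above is simply keeping heavyweight imports out of module scope so that argparse-only paths such as --help never load torch. A minimal standalone sketch, with a made-up program name and flag:

import argparse


def run_generate(args) -> None:
    # The expensive import happens only when this subcommand actually runs.
    import torch

    torch.manual_seed(args.seed)
    print(torch.rand(2, 2))


def main() -> None:
    parser = argparse.ArgumentParser(prog="toy-cli")
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()  # --help exits here, before torch is ever imported
    run_generate(args)


if __name__ == "__main__":
    main()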

File tree

3 files changed: +57 / -38 lines

torchchat/cli/cli.py
Lines changed: 13 additions & 7 deletions

@@ -5,20 +5,16 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
+import importlib.metadata
 import json
 import logging
 import os
 import sys
 from pathlib import Path
 
-import torch
-
-from torchchat.cli.download import download_and_convert, is_model_downloaded
-
 from torchchat.utils.build_utils import (
     allowable_dtype_names,
     allowable_params_table,
-    get_device_str,
 )
 
 logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -42,6 +38,9 @@
 
 # Handle CLI arguments that are common to a majority of subcommands.
 def check_args(args, verb: str) -> None:
+    # Local import to avoid unnecessary expensive imports
+    from torchchat.cli.download import download_and_convert, is_model_downloaded
+
     # Handle model download. Skip this for download, since it has slightly
     # different semantics.
     if (
@@ -498,9 +497,10 @@ def _add_speculative_execution_args(parser) -> None:
 
 
 def arg_init(args):
-    if not (torch.__version__ > "2.3"):
+    torch_version = importlib.metadata.version("torch")
+    if not torch_version or (torch_version <= "2.3"):
         raise RuntimeError(
-            f"You are using PyTorch {torch.__version__}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release"
+            f"You are using PyTorch {torch_version}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release"
         )
 
     if sys.version_info.major != 3 or sys.version_info.minor < 10:
@@ -521,6 +521,9 @@ def arg_init(args):
             raise RuntimeError("Device not supported by ExecuTorch")
         args.device = "cpu"
     else:
+        # Localized import to minimize expensive imports
+        from torchchat.utils.build_utils import get_device_str
+
         args.device = get_device_str(
             args.quantize.get("executor", {}).get("accelerator", args.device)
         )
@@ -534,5 +537,8 @@ def arg_init(args):
         vars(args)["compile_prefill"] = False
 
     if hasattr(args, "seed") and args.seed:
+        # Localized import to minimize expensive imports
+        import torch
+
         torch.manual_seed(args.seed)
     return args
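A note on the version gate above: importlib.metadata.version reads the installed distribution's metadata on disk, so the check learns the version string without loading any torch modules. A minimal standalone sketch (illustrative code, not the torchchat implementation):

import importlib.metadata


def installed_torch_version() -> str | None:
    # Reads the version string from package metadata; torch itself is never imported.
    try:
        return importlib.metadata.version("torch")
    except importlib.metadata.PackageNotFoundError:
        return None


print(installed_torch_version())  # e.g. "2.4.0", or None if torch is not installed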

torchchat/cli/convert_hf_checkpoint.py
Lines changed: 12 additions & 13 deletions

@@ -11,25 +11,23 @@
 from pathlib import Path
 from typing import Optional
 
-import torch
-
-from torchchat.model import TransformerArgs
-
 # support running without installing as a package
 wd = Path(__file__).parent.parent
 sys.path.append(str(wd.resolve()))
 sys.path.append(str((wd / "build").resolve()))
 
-from torchchat.model import ModelArgs
-
 
-@torch.inference_mode()
 def convert_hf_checkpoint(
     *,
     model_dir: Optional[Path] = None,
     model_name: Optional[str] = None,
     remove_bin_files: bool = False,
 ) -> None:
+
+    # Local imports to avoid expensive imports
+    from torchchat.model import ModelArgs, TransformerArgs
+    import torch
+
     if model_dir is None:
         model_dir = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf")
     if model_name is None:
@@ -58,10 +56,11 @@ def convert_hf_checkpoint(
     tokenizer_pth = model_dir / "original" / "tokenizer.model"
     if consolidated_pth.is_file() and tokenizer_pth.is_file():
         # Confirm we can load it
-        loaded_result = torch.load(
-            str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True
-        )
-        del loaded_result  # No longer needed
+        with torch.inference_mode():
+            loaded_result = torch.load(
+                str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True
+            )
+            del loaded_result  # No longer needed
         print(f"Moving checkpoint to {model_dir / 'model.pth'}.")
         os.rename(consolidated_pth, model_dir / "model.pth")
         os.rename(tokenizer_pth, model_dir / "tokenizer.model")
@@ -130,7 +129,8 @@ def load_safetensors():
     state_dict = None
     for loader in loaders:
         try:
-            state_dict = loader()
+            with torch.inference_mode():
+                state_dict = loader()
             break
         except Exception:
             continue
@@ -173,7 +173,6 @@ def load_safetensors():
             os.remove(file)
 
 
-@torch.inference_mode()
 def convert_hf_checkpoint_to_tune(
     *,
     model_dir: Optional[Path] = None,
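Why the decorator had to go: @torch.inference_mode() is evaluated when the def statement runs, i.e. at module import, which forces torch to be loaded even for code paths that never touch a checkpoint; the with-statement form defers that cost to call time. A rough standalone sketch of the resulting shape (load_checkpoint is a hypothetical helper, not part of this commit):

from pathlib import Path


def load_checkpoint(path: Path):
    # Local import: the module defining this function stays cheap to import.
    import torch

    # inference_mode as a context manager instead of a decorator, mirroring the diff.
    with torch.inference_mode():
        return torch.load(str(path), map_location="cpu", mmap=True, weights_only=True)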

torchchat/utils/build_utils.py
Lines changed: 32 additions & 18 deletions

@@ -13,18 +13,31 @@
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
-import torch
 
 ##########################################################################
 ### unpack packed weights ###
 
 
+class _LazyImportTorch:
+    """This is a wrapper around the import of torch that only performs the
+    import when an actual attribute is needed off of torch.
+    """
+    @staticmethod
+    def __getattribute__(name: str) -> Any:
+        import torch
+        return getattr(torch, name)
+
+
+# Alias torch to the lazy import
+torch = _LazyImportTorch()
+
+
 def unpack_packed_weights(
     packed_weights: Dict[str, Any],
     packed_linear: Callable,
-    input_dtype: torch.dtype,
+    input_dtype: "torch.dtype",
     unpacked_dims: Tuple,
-) -> torch.Tensor:
+) -> "torch.Tensor":
     """Given a packed weight matrix `packed_weights`, a Callable
     implementing a packed linear function for the packed format, and the
     unpacked dimensions of the weights, recreate the unpacked weight
@@ -169,26 +182,27 @@ def name_to_dtype(name, device):
         return torch.bfloat16
 
     try:
-        return name_to_dtype_dict[name]
+        return _name_to_dtype_dict[name]()
     except KeyError:
         raise RuntimeError(f"unsupported dtype name {name} specified")
 
 
 def allowable_dtype_names() -> List[str]:
-    return name_to_dtype_dict.keys()
-
-
-name_to_dtype_dict = {
-    "fp32": torch.float,
-    "fp16": torch.float16,
-    "bf16": torch.bfloat16,
-    "float": torch.float,
-    "half": torch.float16,
-    "float32": torch.float,
-    "float16": torch.float16,
-    "bfloat16": torch.bfloat16,
-    "fast": None,
-    "fast16": None,
+    return _name_to_dtype_dict.keys()
+
+
+# NOTE: values are wrapped in lambdas to avoid proactive imports for torch
+_name_to_dtype_dict = {
+    "fp32": lambda: torch.float,
+    "fp16": lambda: torch.float16,
+    "bf16": lambda: torch.bfloat16,
+    "float": lambda: torch.float,
+    "half": lambda: torch.float16,
+    "float32": lambda: torch.float,
+    "float16": lambda: torch.float16,
+    "bfloat16": lambda: torch.bfloat16,
+    "fast": lambda: None,
+    "fast16": lambda: None,
 }
 
 