
Commit 9f1032d

mikekgfb authored and malfet committed
implement --device 'fast' (#511)
* implement --device 'fast'
* land --device fast, but don't make it default
1 parent 4d56583 commit 9f1032d
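This change teaches torchchat's CLI to accept --device fast, which resolves at runtime to the best available backend (CUDA, then MPS, then CPU). A hypothetical invocation, assuming the repository's usual torchchat.py entry point and generate subcommand; only the --device fast flag itself comes from this commit:

    python3 torchchat.py generate --device fast --prompt "Hello, "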

File tree

3 files changed: 69 additions, 12 deletions

* build/utils.py
* cli.py
* quantize.py

build/utils.py

Lines changed: 29 additions & 3 deletions
@@ -9,7 +9,7 @@
 import logging
 import os
 from pathlib import Path
-from typing import List, Dict
+from typing import Dict, List
 
 import torch
 
@@ -143,10 +143,36 @@ def canonical_path(path):
 
 
 #########################################################################
-### general utility functions ###
+### move state dict to specified device ###
+
 
-def state_dict_device(d, device = "cpu") -> Dict:
+def state_dict_device(d, device="cpu") -> Dict:
     for key, weight in d.items():
         d[key] = weight.to(device=device)
 
     return d
+
+
+#########################################################################
+### resolve "fast" to a concrete device ###
+
+
+def get_device_str(device) -> str:
+    if isinstance(device, str) and device == "fast":
+        return (
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps" if torch.backends.mps.is_available() else "cpu"
+        )
+    else:
+        return str(device)
+
+
+def get_device(device) -> torch.device:
+    if isinstance(device, str) and device == "fast":
+        device = (
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps" if torch.backends.mps.is_available() else "cpu"
+        )
+    return torch.device(device)
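To make the fallback order these helpers implement concrete, a minimal sketch of their behavior once this lands (the printed values depend on the host hardware):

    from build.utils import get_device, get_device_str

    # "fast" picks CUDA when available, then MPS, then CPU.
    print(get_device_str("fast"))  # "cuda", "mps", or "cpu"
    print(get_device("fast"))      # the same choice, as a torch.device

    # Anything other than "fast" passes through unchanged.
    print(get_device_str("cpu"))   # "cpu"
    print(get_device("cpu"))       # device(type='cpu')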

cli.py

Lines changed: 6 additions & 3 deletions
@@ -9,7 +9,7 @@
 
 import torch
 
-from build.utils import allowable_dtype_names, allowable_params_table
+from build.utils import allowable_dtype_names, allowable_params_table, get_device_str
 from download import download_and_convert, is_model_downloaded
 
 default_device = "cpu"
@@ -252,7 +252,7 @@ def _add_arguments_common(parser):
         "--device",
         type=str,
         default=default_device,
-        choices=["cpu", "cuda", "mps"],
+        choices=["fast", "cpu", "cuda", "mps"],
         help="Hardware device to use. Options: cpu, cuda, mps",
     )
     parser.add_argument(
@@ -300,7 +300,7 @@ def arg_init(args):
            "You are using PyTorch {torch.__version__}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release"
        )
 
-    if hasattr(args, 'quantize') and Path(args.quantize).is_file():
+    if hasattr(args, "quantize") and Path(args.quantize).is_file():
        with open(args.quantize, "r") as f:
            args.quantize = json.loads(f.read())
 
@@ -309,6 +309,9 @@ def arg_init(args):
 
    # if we specify dtype in quantization recipe, replicate it as args.dtype
    args.dtype = args.quantize.get("precision", {}).get("dtype", args.dtype)
+    args.device = get_device_str(
+        args.quantize.get("executor", {}).get("accelerator", args.device)
+    )
 
    if hasattr(args, "seed") and args.seed:
        torch.manual_seed(args.seed)
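Because arg_init now consults the quantization recipe, a JSON recipe passed via --quantize can pin the executor and override the --device flag. A sketch of that code path; only the "executor"/"accelerator" and "precision"/"dtype" lookups come from the diff above, the concrete values are illustrative:

    import json

    from build.utils import get_device_str

    # Hypothetical recipe contents, as they would arrive through --quantize.
    recipe = json.loads('{"executor": {"accelerator": "fast"}, "precision": {"dtype": "bf16"}}')

    # Mirrors the new lines in arg_init: the recipe's accelerator wins
    # over the command-line device when present.
    device = get_device_str(recipe.get("executor", {}).get("accelerator", "cpu"))
    print(device)  # "cuda", "mps", or "cpu" depending on the host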

quantize.py

Lines changed: 34 additions & 6 deletions
@@ -15,7 +15,13 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from build.utils import find_multiple, get_precision, name_to_dtype, use_et_backend, state_dict_device
+from build.utils import (
+    find_multiple,
+    get_precision,
+    name_to_dtype,
+    state_dict_device,
+    use_et_backend,
+)
 
 
 #########################################################################
@@ -116,6 +122,28 @@ def quantized_model(self) -> nn.Module:
         return self.model_.to(device=self.device, dtype=self.dtype)
 
 
+#########################################################################
+### wrapper for setting device as a QuantHandler ###
+
+
+class ExecutorHandler(QuantHandler):
+    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, accelerator):
+        self.model_ = model
+
+        if isinstance(accelerator, str):
+            device = get_device_str(accelerator)
+        self.device = device
+
+    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
+        pass
+
+    def convert_for_runtime(self) -> nn.Module:
+        pass
+
+    def quantized_model(self) -> nn.Module:
+        return self.model_.to(device=self.device)
+
+
 #########################################################################
 ##### Quantization Primitives ######
 
@@ -407,8 +435,8 @@ def __init__(
     @torch.no_grad()
     def create_quantized_state_dict(self) -> Dict:
         cur_state_dict = state_dict_device(self.model_.state_dict())
-        dict_device = "cpu" # self.device
-
+        dict_device = "cpu"  # self.device
+
         if self.bitwidth == 4:
             range_min = -8
             range_max = 7
@@ -824,12 +852,11 @@ def __init__(
         assert groupsize in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]
 
-
     # @torch.no_grad()
     # def p(self):
     #     cur_state_dict = state_dict_device(self.model_.state_dict())
     #     dict_device = "cpu"  # self.device
-    #
+    #
     #     for fqn, mod in self.model_.named_modules():
     #         if hasattr(mod, "weight"):
     #             print(f"device={str(mod.weight.data.device)}")
@@ -838,7 +865,7 @@ def __init__(
     def create_quantized_state_dict(self):
         cur_state_dict = state_dict_device(self.model_.state_dict())
         dict_device = "cpu"  # self.device
-
+
         for fqn, mod in self.model_.named_modules():
             if isinstance(mod, torch.nn.Linear):
                 assert not mod.bias
@@ -1282,4 +1309,5 @@ def quantized_model(self) -> nn.Module:
     "linear:int4-gptq": WeightOnlyInt4GPTQQuantHandler,
     "linear:hqq": WeightOnlyInt4HqqQuantHandler,
     "precision": PrecisionHandler,
+    "executor": ExecutorHandler,
 }
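Registering ExecutorHandler under "executor" means a recipe's executor entry simply moves the model to the resolved device; the state-dict and conversion hooks are deliberate no-ops. A minimal sketch of exercising the handler directly with a toy model (note that ExecutorHandler calls get_device_str, which lives in build.utils):

    import torch.nn as nn

    from quantize import ExecutorHandler

    model = nn.Linear(16, 16)  # toy stand-in for the real transformer

    # Resolves "fast" via get_device_str and moves the model there;
    # create_quantized_state_dict/convert_for_runtime are no-ops.
    handler = ExecutorHandler(model, accelerator="fast")
    model = handler.quantized_model()
    print(next(model.parameters()).device)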
