Skip to content

Commit ad98ae2

Browse files
mikekgfbmalfet
authored andcommitted
replicate dtype from quantization dictionary to args.dtype (#494)
1 parent e027cac commit ad98ae2

File tree

3 files changed

+22
-12
lines changed

3 files changed

+22
-12
lines changed

build/builder.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -357,9 +357,9 @@ def _initialize_model(
357357
_set_gguf_kwargs(builder_args, is_et=is_pte, context="generate")
358358

359359
if builder_args.dso_path:
360-
assert (
361-
quantize is None or quantize == "{ }"
362-
), "quantize not valid for exported DSO model. Specify quantization during export."
360+
# assert (
361+
# quantize is None or quantize == "{ }"
362+
# ), "quantize not valid for exported DSO model. Specify quantization during export."
363363

364364
t0 = time.time()
365365
model = _load_model(builder_args, only_config=True)
@@ -379,9 +379,9 @@ def _initialize_model(
379379
except:
380380
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
381381
elif builder_args.pte_path:
382-
assert (
383-
quantize is None or quantize == "{ }"
384-
), "quantize not valid for exported PTE model. Specify quantization during export."
382+
# assert (
383+
# quantize is None or quantize == "{ }"
384+
# ), "quantize not valid for exported PTE model. Specify quantization during export."
385385

386386
t0 = time.time()
387387
model = _load_model(builder_args, only_config=True)

cli.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,10 +295,16 @@ def _add_arguments_common(parser):
295295

296296

297297
def arg_init(args):
298-
if hasattr(args, 'quantize') and Path(args.quantize).is_file():
298+
if hasattr(args, "quantize") and Path(args.quantize).is_file():
299299
with open(args.quantize, "r") as f:
300300
args.quantize = json.loads(f.read())
301301

302-
if hasattr(args, 'seed') and args.seed:
302+
if isinstance(args.quantize, str):
303+
args.quantize = json.loads(args.quantize)
304+
305+
# if we specify dtype in quantization recipe, replicate it as args.dtype
306+
args.dtype = args.quantize.get("precision", {}).get("dtype", args.dtype)
307+
308+
if hasattr(args, "seed") and args.seed:
303309
torch.manual_seed(args.seed)
304310
return args

eval.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def _model_call(self, inps):
158158
x = seq.index_select(0, input_pos).view(1, -1)
159159
start = time.time()
160160
logits = model_forward(self._model, x, input_pos)
161-
self.times.append(time.time()-start)
161+
self.times.append(time.time() - start)
162162
return logits
163163

164164
def _model_generate(self, context, max_length, eos_token_id):
@@ -266,9 +266,13 @@ def main(args) -> None:
266266
device=builder_args.device,
267267
)
268268
print(f"Time to run eval: {time.time() - t1:.02f}s.")
269-
times=torch.tensor(result["times"])
270-
print(f"Time in model.forward: {times.sum():.02f}s, over {times.numel()} model evaluations")
271-
print(f"forward run time stats - Median: {times.median():.02f}s Min: {times.min():.02f}s Max: {times.max():.02f}s")
269+
times = torch.tensor(result["times"])
270+
print(
271+
f"Time in model.forward: {times.sum():.02f}s, over {times.numel()} model evaluations"
272+
)
273+
print(
274+
f"forward run time stats - Median: {times.median():.02f}s Min: {times.min():.02f}s Max: {times.max():.02f}s"
275+
)
272276
if builder_args.dso_path:
273277
print(f"For model {builder_args.dso_path}")
274278
elif builder_args.pte_path:

0 commit comments

Comments
 (0)