
Commit 46639bb

mikekgfb authored and malfet committed
fix device linear:int8 quant (#206)
* fix device int8 quant
* fix duplicate device on linear int8
* typo
* typo (missed comma made it a tuple)
* missing fqn
1 parent 36d0fd4 commit 46639bb
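
The "missed comma" item refers to a stray trailing comma in an assignment (self.device = device,), which Python parses as a one-element tuple rather than the device string. A minimal illustration of the pitfall, not repository code:

device = "cuda"
as_tuple = device,   # trailing comma -> ("cuda",), a tuple
as_string = device   # -> "cuda"
assert isinstance(as_tuple, tuple)
assert isinstance(as_string, str)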

File tree

4 files changed: +52 −35 lines changed

.github/workflows/compile_t4.yml
.github/workflows/test_mps.yml
build/builder.py
quantize.py


.github/workflows/compile_t4.yml

Lines changed: 21 additions & 21 deletions
@@ -63,35 +63,35 @@ jobs:
           echo "******************************************"
           echo "******** Emb: group-wise quantized *******"
           echo "******************************************"
-          # python generate.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-          # cat ./output_eager
-          # python generate.py --device cuda --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
-          # cat ./output_compiled
-          # python export.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
-          # python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
-          # cat ./output_aoti
+          python generate.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+          python generate.py --device cuda --compile --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
+          cat ./output_compiled
+          python export.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
+          python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
+          cat ./output_aoti

           echo "******************************************"
           echo "******* INT8 channel-wise quantized ******"
           echo "******************************************"
-          # python generate.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-          # cat ./output_eager
-          # python generate.py --device cuda --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
-          # cat ./output_compiled
-          # python export.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
-          # python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
-          # cat ./output_aoti
+          python generate.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+          python generate.py --device cuda --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
+          cat ./output_compiled
+          python export.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
+          python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
+          cat ./output_aoti

           echo "******************************************"
           echo "******** INT8 group-wise quantized *******"
           echo "******************************************"
-          # python generate.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-          # cat ./output_eager
-          # python generate.py --device cuda --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
-          # cat ./output_compiled
-          # python export.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
-          # python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
-          # cat ./output_aoti
+          python generate.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+          python generate.py --device cuda --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
+          cat ./output_compiled
+          python export.py --device cuda --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
+          python generate.py --device cuda --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
+          cat ./output_aoti

           echo "tests complete"
           echo "******************************************"

.github/workflows/test_mps.yml

Lines changed: 19 additions & 4 deletions
@@ -48,14 +48,29 @@ jobs:

           python generate.py --device mps --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           cat ./output_eager
+
+          echo "************************************************************"
+          echo "*** embedding"
+          echo "************************************************************"
+
           python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           cat ./output_eager
           python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           cat ./output_eager
-          # python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-          # cat ./output_eager
-          # python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-          # cat ./output_eager
+
+          echo "************************************************************"
+          echo "*** linear int8"
+          echo "************************************************************"
+
+          python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+          python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+
+          echo "************************************************************"
+          echo "*** linear int4"
+          echo "************************************************************"
+
           # PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
           # cat ./output_eager
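
The MPS job runs the same eager paths on Apple-silicon runners, with linear int4 still commented out pending the MPS fallback path. A minimal, hedged sketch of guarding device selection in a driver script, using only standard PyTorch APIs rather than anything from this repository:

import torch

# Pick the eager device the way these jobs do by hand: MPS if present, else CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.ones(2, 2, device=device)
print(device, x.device)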

build/builder.py

Lines changed: 1 addition & 1 deletion
@@ -250,7 +250,7 @@ def _initialize_model(
         # assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
         assert quantize is None or quantize == "{ }", f"quantize not valid for exported PTE model. Specify quantization during export."
         try:
-            from model_et import PTEModel
+            from build.model_et import PTEModel
             model = PTEModel(model_.config, builder_args.pte_path)
         except Exception as e:
             raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
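
The builder.py change qualifies the import so PTEModel resolves when builder.py is imported as part of the build package from the repository root. A hedged sketch of one way to make such an import robust to either module layout; this helper is illustrative only and not the repository's code:

import importlib

def _load_pte_model_class():
    # Try the package-qualified module first (repo-root imports), then the bare
    # module name (when build/ itself is on sys.path).
    for name in ("build.model_et", "model_et"):
        try:
            return importlib.import_module(name).PTEModel
        except ImportError:
            continue
    raise RuntimeError("no module providing PTEModel found")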

quantize.py

Lines changed: 11 additions & 9 deletions
@@ -349,7 +349,7 @@ def quantized_model(self) -> nn.Module:
 ##### Weight-only int8 per-channel quantized code ######


-def replace_linear_weight_only_int8_per_channel(module, node_type, groupsize=None):
+def replace_linear_weight_only_int8_per_channel(module, device, node_type, groupsize=None):
     if groupsize is not None and groupsize != 0:
         pass # groupsize = 2 ** groupsize

@@ -367,10 +367,10 @@ def replace_linear_weight_only_int8_per_channel(module, node_type, groupsize=Non
             setattr(
                 module,
                 name,
-                WeightOnlyInt8Linear(child.in_features, child.out_features, groupsize),
+                WeightOnlyInt8Linear(device, child.in_features, child.out_features, groupsize),
             )
         else:
-            replace_linear_weight_only_int8_per_channel(child, node_type, groupsize)
+            replace_linear_weight_only_int8_per_channel(child, device, node_type, groupsize)


 class WeightOnlyInt8QuantHandler(QuantHandler):
@@ -384,7 +384,7 @@ def __init__(
         groupsize: Optional[int] = None,
     ):
         self.mod = mod
-        self.device = device,
+        self.device = device
         self.groupsize = groupsize
         self.node_type = node_type
         if bitwidth is None:
@@ -434,14 +434,16 @@ def create_quantized_state_dict(self) -> Dict:
                     scales_dtype=mod.weight.dtype,
                 )

+                weight = weight.to(device=self.device)
+                scales = scales.to(device=self.device)
                 cur_state_dict[f"{fqn}.weight"] = weight
                 # squeeze makes groupsize=rowsize unidimensional
                 cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)

         return cur_state_dict

     def convert_for_runtime(self) -> nn.Module:
-        replace_linear_weight_only_int8_per_channel(self.mod, self.node_type, self.groupsize)
+        replace_linear_weight_only_int8_per_channel(self.mod, self.device, self.node_type, self.groupsize)
         return self.mod

     def quantized_model(self) -> nn.Module:
@@ -459,11 +461,11 @@ class WeightOnlyInt8Linear(torch.nn.Module):

     def __init__(
         self,
+        device,
         in_features: int,
         out_features: int,
         groupsize: Optional[int] = None,
         bias: bool = True,
-        device=None,
         dtype=None,
     ) -> None:
         super().__init__()
@@ -472,14 +474,14 @@ def __init__(
         self.in_features = in_features
         self.out_features = out_features
         self.register_buffer(
-            "weight", torch.empty((out_features, in_features), dtype=torch.int8)
+            "weight", torch.empty((out_features, in_features), dtype=torch.int8, device=device)
         )
         dtype=get_precision()
         if groupsize is None or (groupsize == 0):
-            self.register_buffer("scales", torch.ones(out_features, dtype=dtype))
+            self.register_buffer("scales", torch.ones(out_features, dtype=dtype, device=device))
         else:
             groups = (in_features + groupsize - 1) // groupsize
-            self.register_buffer("scales", torch.ones(out_features, groups, dtype=dtype))
+            self.register_buffer("scales", torch.ones(out_features, groups, dtype=dtype, device=device))

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         scales = self.scales