
Commit c980472

mikekgfb, shoumikhin, metascroy, malfet, and lucylq committed
Quantization, fp acceleration, and testing (#572)
* code beautification
* code beautification, move functions together
* make --device fast the default (#515)
* make --device fast the default
* Update iOS.md (#517)
* Update iOS.md
* Update iOS.md
* Pip to pip3 (#504)
* remove macos-12 test
* pip to pip3
* break aoti CI jobs separately (#500)
* init
* fixes
* more fixes
* fixes
* fix
* fix
* bug fix
* add objcopy update
* suppress int8
* undefined variable

---------

Co-authored-by: Michael Gschwind <[email protected]>

* Support llama3 in chat in run.cpp (#486)
* refactor chat runner in preparation for llama3
* add sketch for llama3 prompt template and move to returning tokens
* fix tiktoken
* fixes to chat
* add default llama_ver
* Add tests for quantize json, add cuda device specification and precision to cuda.json (#519)
* remove code for no KV Cache path (#527)
* Update ADVANCED-USERS.md (#529): update the Advanced Users description to reflect changes in the repo since the description was initially created
* runner-aoti on cuda (#531)
* runner-aoti on cuda
* transfer results back to CPU
* transfer results back to CPU
* runner-aoti on cuda
* Update runner_build.md (#530): update the description of the runner and build process in runner_build.md
* clean up runner code a little (#532)
* clean up runner code a little
* update
* update
* pull out generate loop in chat
* updates
* edit docs
* typo
* move int8 linear class and function into qops.py (#534)
* add dtype tests for runner-aoti + runner-et (#539)
* add dtype tests for runner-aoti + runner-et
* typo
* Quantized embedding (#536)
* move int8 linear class and function into qops.py
* move Quantized Embedding to qops.py
* Move Linear int4 to qops (#537)
* move int8 linear class and function into qops.py
* move Quantized Embedding to qops.py
* move int4 linear to qops
* Revert "add dtype tests for runner-aoti + runner-et (#539)" (#548): this reverts commit a7a24577a65be67ac9ae4dc05452f35d9c49e5d1
* fix generate for llama3 (#538)
* fix generate for llama3
* switch more things to C
* remove C++ header
* add delegation visualization instructions (#551)
* Add dtype runner aoti (#552)
* add dtype tests for runner-aoti + runner-et
* typo
* add dtype test runner-aoti
* test sdpa with fp16 (#553)
* test sdpa with fp16
* kv cache fp32
* typo
* update (#560)
* Only support newest versions of lm-eval (#556): remove support for lm-eval 0.3 to reduce the options we have (test plan: CI)
* split cpu eval CI by dtype (#554)
* split cpu eval CI by dtype
* fix
* differentiate names with checks
* keep one name the same as old
* fix
* Removing duplicate HF issue message from README (#559) (Co-authored-by: Michael Gschwind <[email protected]>)
* doc updates (#567)
* Add VM-safe MPS check

---------

Co-authored-by: Anthony Shoumikhin <[email protected]>
Co-authored-by: metascroy <[email protected]>
Co-authored-by: Nikita Shulga <[email protected]>
Co-authored-by: lucylq <[email protected]>
Co-authored-by: Jerry Zhang <[email protected]>
Co-authored-by: Jack-Khuu <[email protected]>

* add unpacking support (#525)
* add unpacking support
* fix typos and linter
* perform parallel prefill when possible (#568)
* perform parallel prefill when possible
* typo
* disable hack
* remove print
* remove debug messages which prevent export
* fixes
* stream results in generate.py (#571)
* remove logging interfering with export

---------

Co-authored-by: Anthony Shoumikhin <[email protected]>
Co-authored-by: metascroy <[email protected]>
Co-authored-by: Nikita Shulga <[email protected]>
Co-authored-by: lucylq <[email protected]>
Co-authored-by: Jerry Zhang <[email protected]>
Co-authored-by: Jack-Khuu <[email protected]>
1 parent e677fe8 commit c980472

File tree

7 files changed: +134, -23 lines


.github/workflows/more-tests.yml

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+name: Run parallel prefill
+
+on:
+  pull_request:
+  push:
+    branches:
+      - main
+  workflow_dispatch:
+
+jobs:
+  test-cuda:
+    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    with:
+      runner: linux.g5.4xlarge.nvidia.gpu
+      gpu-arch-type: cuda
+      gpu-arch-version: "12.1"
+      script: |
+        echo "::group::Print machine info"
+        uname -a
+        echo "::endgroup::"
+
+        echo "::group::Install newer objcopy that supports --set-section-alignment"
+        yum install -y devtoolset-10-binutils
+        export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
+        echo "::endgroup::"
+
+        echo "::group::Install requirements"
+        pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
+        pip3 install -r requirements.txt
+        pip3 list
+        python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
+        echo "::endgroup::"
+
+        echo "::group::Download checkpoints"
+        mkdir -p checkpoints/stories15M
+        pushd checkpoints/stories15M
+        wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
+        wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
+        popd
+        echo "::endgroup::"
+
+        echo "::group::Run inference"
+        export MODEL_PATH=checkpoints/stories15M/stories15M.pt
+        export MODEL_NAME=stories15M
+        export MODEL_DIR=/tmp
+
+        for DTYPE in bfloat16 float16 float32; do
+          ###################################################################
+          # group with different temperatures and top-k values
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 0
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 0.9
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 1.0
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 100
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 200
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 500
+          ###################################################################
+          # group with different temperatures and top-k values, with compile and compile-prefill
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 0 --compile --compile-prefill
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 0.9 --compile --compile-prefill
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 1.0 --compile --compile-prefill
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 100 --compile --compile-prefill
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 200 --compile --compile-prefill
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 500 --compile --compile-prefill
+          ###################################################################
+          # group with different temperatures and top-k values, with sequential prefill
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 0 --sequential-prefill
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 0.9 --sequential-prefill
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 1.0 --sequential-prefill
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 100 --sequential-prefill
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 200 --sequential-prefill
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 500 --sequential-prefill
+          ###################################################################
+          # group with different temperatures and top-k values, with sequential prefill and compile
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 0 --sequential-prefill --compile
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 0.9 --sequential-prefill --compile
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 1.0 --sequential-prefill --compile
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 100 --sequential-prefill --compile
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 200 --sequential-prefill --compile
+          python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 500 --sequential-prefill --compile
+        done
+
+        echo "tests complete"
+        echo "******************************************"
+        echo "::endgroup::"
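
The sweep above can also be reproduced locally without the CI wrapper. The following is a minimal sketch in Python, assuming generate.py and the stories15M checkpoint are present at the same paths the workflow downloads them to; the flags are taken verbatim from the commands above.

# Minimal local sketch of the dtype/prefill sweep in the workflow above.
# Assumes generate.py and checkpoints/stories15M/stories15M.pt exist as in the CI script.
import itertools
import subprocess

MODEL_PATH = "checkpoints/stories15M/stories15M.pt"

prefill_variants = [
    [],                                    # default parallel prefill
    ["--compile", "--compile-prefill"],    # compiled decode and compiled prefill
    ["--sequential-prefill"],              # sequential prefill
    ["--sequential-prefill", "--compile"]  # sequential prefill with compiled decode
]

for dtype, extra in itertools.product(["bfloat16", "float16", "float32"], prefill_variants):
    cmd = [
        "python3", "generate.py",
        "--checkpoint-path", MODEL_PATH,
        "--device", "cpu",
        "--dtype", dtype,
        "--temperature", "0",
        *extra,
    ]
    subprocess.run(cmd, check=True)  # stop on the first failing configuration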

build/builder.py

Lines changed: 9 additions & 1 deletion
@@ -117,6 +117,14 @@ def from_args(cls, args):  # -> BuilderArgs:
         if "chat" in path_basename or "instruct" in path_basename:
             is_chat_model = True
 
+        if args.output_pte_path and args.dtype.startswith("fast"):
+            if args.dtype == "fast":
+                dtype = torch.float32
+            else:
+                dtype = torch.float16
+        else:
+            dtype = name_to_dtype(args.dtype)
+
         return cls(
             checkpoint_dir=checkpoint_dir,
             checkpoint_path=checkpoint_path,
@@ -127,7 +135,7 @@ def from_args(cls, args):  # -> BuilderArgs:
             dso_path=args.dso_path,
             pte_path=args.pte_path,
             device=args.device,
-            precision=name_to_dtype(args.dtype),
+            precision=dtype,
             setup_caches=(args.output_dso_path or args.output_pte_path),
             use_tp=False,
             is_chat_model=is_chat_model,
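
Taken on its own, the precision selection added here behaves like the sketch below: when exporting an ExecuTorch .pte artifact, the platform-dependent "fast" aliases are pinned to concrete dtypes so the exported file does not depend on the machine that produced it. This is an illustrative rewrite only, and the import path for name_to_dtype is assumed.

# Illustrative sketch of the dtype resolution above; not the actual builder code.
import torch
from build.utils import name_to_dtype  # assumed import path

def resolve_precision(dtype_name: str, output_pte_path=None) -> torch.dtype:
    if output_pte_path and dtype_name.startswith("fast"):
        # .pte export pins "fast" to float32 and "fast16" to float16
        return torch.float32 if dtype_name == "fast" else torch.float16
    # every other path resolves the name normally, including the platform-dependent aliases
    return name_to_dtype(dtype_name)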

build/utils.py

Lines changed: 14 additions & 1 deletion
@@ -130,7 +130,17 @@ def get_precision():
 
 ##########################################################################
 ### dtype name to torch.dtype mapping ###
+
+
 def name_to_dtype(name):
+    if (name == "fast") or (name == "fast16"):
+        import platform
+
+        if platform.processor() == "arm":
+            return torch.float16
+        else:
+            return torch.bfloat16
+
     if name in name_to_dtype_dict:
         return name_to_dtype_dict[name]
     else:
@@ -150,6 +160,8 @@ def allowable_dtype_names() -> List[str]:
     "float32": torch.float,
     "float16": torch.float16,
     "bfloat16": torch.bfloat16,
+    "fast": None,
+    "fast16": None,
 }
 
 
@@ -208,6 +220,7 @@ def state_dict_device(d, device="cpu") -> Dict:
 #########################################################################
 ### move state dict to specified device ###
 
+
 def is_mps_available() -> bool:
     if not torch.backends.mps.is_available():
         return False
@@ -219,7 +232,7 @@ def is_mps_available() -> bool:
     except:
         return False
 
-# MPS, is that you?
+    # MPS, is that you?
     return True
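
A quick way to see what the new aliases resolve to on a given machine is the snippet below. It is illustrative only: the import path is assumed, and the result depends on the host CPU.

# Illustrative check of the "fast"/"fast16" aliases added above.
import platform
from build.utils import name_to_dtype  # assumed import path

print(platform.processor())       # "arm" on Apple Silicon, typically "x86_64" elsewhere
print(name_to_dtype("fast"))      # torch.float16 on arm, torch.bfloat16 otherwise
print(name_to_dtype("fast16"))    # resolved the same way as "fast" in this change
print(name_to_dtype("float32"))   # explicit names still go through name_to_dtype_dict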

cli.py

Lines changed: 2 additions & 3 deletions
@@ -210,11 +210,10 @@ def _add_arguments_common(parser):
         help="Use the specified ExecuTorch .pte model file",
     )
     parser.add_argument(
-        "-d",
         "--dtype",
-        default="float32",
+        default="fast",
         choices=allowable_dtype_names(),
-        help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32",
+        help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32, fast16, fast",
     )
     parser.add_argument(
         "-v",

generate.py

Lines changed: 1 addition & 1 deletion
@@ -172,7 +172,7 @@ def prefill(
     sequential_prefill=True,
     **sampling_kwargs,
 ) -> torch.Tensor:
-    logging.debug(f"x: {x}, input_pos: {input_pos}")
+    # logging.debug(f"x: {x}, input_pos: {input_pos}")
     width = x.size(1)
     assert input_pos.size(0) == width

qops.py

Lines changed: 18 additions & 0 deletions
@@ -305,3 +305,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
     @classmethod
     def _check_k(cls, *, k, groupsize=1, inner_k_tiles=1):
         return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
+
+    @classmethod
+    def _prepare_weight_and_scales_and_zeros(
+        cls, weight_bf16, groupsize, inner_k_tiles
+    ):
+        from quantize import group_quantize_tensor
+
+        weight_int32, scales_and_zeros = group_quantize_tensor(
+            weight_bf16, n_bit=4, groupsize=groupsize
+        )
+        weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+            weight_int32, inner_k_tiles
+        )
+        return weight_int4pack, scales_and_zeros
+
+    @classmethod
+    def _calc_padded_size(cls, *, k, groupsize=1, innner_k_tiles=1):
+        return find_multiple(k, 1024)
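
The relocated helper can be exercised directly, as in the sketch below. This is illustrative only: the shapes and groupsize are made up, the import path is assumed, and torch.ops.aten._convert_weight_to_int4pack requires a PyTorch build that provides the int4 packing kernel for the target device.

# Illustrative use of the relocated int4 packing helper; not part of the commit.
import torch
from qops import WeightOnlyInt4Linear  # assumed import path

groupsize, inner_k_tiles = 128, 8
weight = torch.randn(4096, 4096)  # example linear weight, out_features x in_features

if WeightOnlyInt4Linear._check_k(k=weight.shape[1], groupsize=groupsize, inner_k_tiles=inner_k_tiles):
    packed, scales_and_zeros = WeightOnlyInt4Linear._prepare_weight_and_scales_and_zeros(
        weight, groupsize, inner_k_tiles
    )
    print(packed.shape, scales_and_zeros.shape)  # packed int4 weight and per-group scales/zeros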

quantize.py

Lines changed: 1 addition & 17 deletions
@@ -595,22 +595,6 @@ def quantized_model(self) -> nn.Module:
 ##### weight only int4 per channel groupwise quantized code ######
 
 
-def _int4_prepare_int4_weight_and_scales_and_zeros(
-    weight_bf16, groupsize, inner_k_tiles
-):
-    weight_int32, scales_and_zeros = group_quantize_tensor(
-        weight_bf16, n_bit=4, groupsize=groupsize
-    )
-    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-        weight_int32, inner_k_tiles
-    )
-    return weight_int4pack, scales_and_zeros
-
-
-def _int4_calc_padded_size(k, groupsize=1, innner_k_tiles=1):
-    return find_multiple(k, 1024)
-
-
 def replace_linear_int4(
     module,
     device,
@@ -705,7 +689,7 @@ def create_quantized_state_dict(self):
                     )
                     continue
                 weight_int4pack, scales_and_zeros = (
-                    _int4_prepare_int4_weight_and_scales_and_zeros(
+                    WeightOnlyInt4Linear._prepare_weight_and_scales_and_zeros(
                         weight.to(torch.float), self.groupsize, self.inner_k_tiles
                     )
                 )
