
Commit b744a43

bump torchao pin
1 parent ecc628d commit b744a43

File tree: 4 files changed (+61, -29 lines)

.github/workflows/pull.yml

Lines changed: 4 additions & 4 deletions
@@ -1132,14 +1132,14 @@ jobs:
       wget -O ./tokenizer.model https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
       export PRMT="Once upon a time in a land far away"
       echo "Generate eager"
-      python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
+      python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
       echo "Generate compile"
-      python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --compile
+      python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile
       echo "Export and run ET (C++ runner)"
-      python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
+      python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
       ./cmake-out/et_run ./model.pte -z ./tokenizer.model -t 0 -i "${PRMT}"
       echo "Export and run AOTI (C++ runner)"
-      python torchchat.py export stories110M --output-dso-path ./model.so --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
+      python torchchat.py export stories110M --output-dso-path ./model.so --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
       ./cmake-out/aoti_run ./model.so -z ./tokenizer.model -t 0 -i "${PRMT}"
       echo "Generate AOTI"
       python torchchat.py generate stories110M --dso-path ./model.so --prompt "${PRMT}"
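The only change in these four CI commands is the `--quantize` JSON, which now layers an embedding:wx config on top of linear:a8wxdq. As a minimal sketch of what that string decodes to (variable names here are illustrative; the parsing mirrors the `json.loads` call in torchchat/utils/quantize.py shown further below):

```
import json

# The --quantize argument used in the updated CI steps above.
quantize_json = (
    '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, '
    '"linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
)

# torchchat parses this into a dict mapping quantizer name -> kwargs
# and applies each quantizer in turn (see quantize_model below).
options = json.loads(quantize_json)
for quantizer, kwargs in options.items():
    print(quantizer, kwargs)
# embedding:wx {'bitwidth': 2, 'groupsize': 32}
# linear:a8wxdq {'bitwidth': 3, 'groupsize': 128, 'has_weight_zeros': False}
```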

docs/quantization.md

Lines changed: 15 additions & 8 deletions
@@ -121,22 +121,29 @@ python3 torchchat.py generate llama3 --pte-path llama3.pte --prompt "Hello my n
 ## Experimental TorchAO lowbit kernels
 
 ### Use
-The quantization scheme a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize.
+
+#### linear:a8wxdq
+The quantization scheme linear:a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize.
 It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7), groupsize, and has_weight_zeros (true, false).
 The argument has_weight_zeros indicates whether the weights are quantized with scales only (has_weight_zeros: false) or with both scales and zeros (has_weight_zeros: true).
 Roughly speaking, {bitwidth: 4, groupsize: 32, has_weight_zeros: false} is similar to GGML's Q4_0 quantization scheme.
 
-You should expect high performance on ARM CPU if bitwidth is 1, 2, 3, 4, or 5 and groupsize is divisible by 16. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
+You should expect high performance on ARM CPU if bitwidth is 1, 2, 3, 4, 5, or 6 and groupsize is divisible by 16. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
+
+#### embedding:wx
+The quantization scheme embedding:wx quantizes embeddings in a groupwise manner with the specified bitwidth and groupsize. It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7) and groupsize. Unlike linear:a8wxdq, embedding:wx always quantizes with scales and zeros.
+
+You should expect high performance on ARM CPU if bitwidth is 1, 2, 3, 4, 5, or 6 and groupsize is divisible by 32. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
 
 ### Setup
-To use a8wxdq, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
+To use linear:a8wxdq and embedding:wx, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
 
 From the torchchat root directory, run
 ```
 sh torchchat/utils/scripts/build_torchao_ops.sh
 ```
 
-This should take about 10 seconds to complete. Once finished, you can use a8wxdq in torchchat.
+This should take about 10 seconds to complete.
 
 Note: if you want to use the new kernels in the AOTI and C++ runners, you must pass the flag link_torchao_ops when running the scripts the build the runners.
 
@@ -156,17 +163,17 @@ Below we show how to use the new kernels. Except for ExecuTorch, you can specif
 
 #### Eager mode
 ```
-OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --prompt "Once upon a time," --num-samples 5
+OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --prompt "Once upon a time," --num-samples 5
 ```
 
 #### torch.compile
 ```
-OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --compile --prompt "Once upon a time," --num-samples 5
+OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile --prompt "Once upon a time," --num-samples 5
 ```
 
 #### AOTI
 ```
-OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --output-dso llama3_1.so
+OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-dso llama3_1.so
 OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --dso-path llama3_1.so --prompt "Once upon a time," --num-samples 5
 ```
 
@@ -178,7 +185,7 @@ OMP_NUM_THREADS=6 ./cmake-out/aoti_run llama3_1.so -z $HOME/.torchchat/model-cac
 
 #### ExecuTorch
 ```
-python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --output-pte llama3_1.pte
+python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-pte llama3_1.pte
 ```
 
 Note: only the ExecuTorch C++ runner in torchchat when built using the instructions in the setup can run the exported *.pte file. It will not work with the `python torchchat.py generate` command.
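The doc text above distinguishes scales-only quantization (has_weight_zeros: false) from scales-plus-zeros quantization (has_weight_zeros: true). A small self-contained sketch of that generic idea for a single weight group follows; it is not the torchao experimental kernel, and the rounding and clamping details are assumptions made purely for illustration:

```
# Illustrative only: generic groupwise quantization of one weight group,
# not the torchao experimental kernels themselves.
def quantize_group(weights, bitwidth, has_weight_zeros):
    qmin, qmax = -(2 ** (bitwidth - 1)), 2 ** (bitwidth - 1) - 1
    lo, hi = min(weights), max(weights)
    if has_weight_zeros:
        # Scale and zero point: covers the asymmetric range [lo, hi].
        scale = (hi - lo) / (qmax - qmin) or 1.0
        zero = round(qmin - lo / scale)
        q = [max(qmin, min(qmax, round(w / scale) + zero)) for w in weights]
        dequant = [(v - zero) * scale for v in q]
    else:
        # Scales only: symmetric around zero.
        scale = max(abs(lo), abs(hi)) / qmax or 1.0
        q = [max(qmin, min(qmax, round(w / scale))) for w in weights]
        dequant = [v * scale for v in q]
    return q, dequant

group = [0.10, -0.32, 0.55, 0.07]  # one groupsize-sized slice of a weight row
print(quantize_group(group, bitwidth=4, has_weight_zeros=False))
print(quantize_group(group, bitwidth=4, has_weight_zeros=True))
```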

torchchat/utils/quantize.py

Lines changed: 37 additions & 12 deletions
@@ -52,7 +52,7 @@
 
 
 # Flag for whether the a8wxdq quantizer is available.
-a8wxdq_load_error: Optional[Exception] = None
+torchao_experimental_load_error: Optional[Exception] = None
 
 #########################################################################
 ###                torchchat quantization API                        ###
@@ -79,9 +79,14 @@ def quantize_model(
         quantize_options = json.loads(quantize_options)
 
     for quantizer, q_kwargs in quantize_options.items():
-        # Test if a8wxdq quantizer is available; Surface error if not.
-        if quantizer == "linear:a8wxdq" and a8wxdq_load_error is not None:
-            raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}")
+        # Test if torchao experimental quantizer is available; Surface error if not.
+        if (
+            quantizer in ["linear:a8wxdq", "embedding:wx"]
+            and torchao_experimental_load_error is not None
+        ):
+            raise Exception(
+                f"Note: Failed to load torchao experimental {quantizer} quantizer with error: {torchao_experimental_load_error}"
+            )
 
         if (
             quantizer not in quantizer_class_dict
@@ -95,7 +100,7 @@ def quantize_model(
             if not support_tensor_subclass:
                 unwrap_tensor_subclass(model)
             continue
-        # We set global precision from quantize options if it is specified at cli.py:485
+        # We set global precision from quantize options if it is specified at cli.py:485
         # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
        precision = get_precision()
 
@@ -108,10 +113,19 @@ def quantize_model(
                     groupsize=q_kwargs.get("groupsize", 128),
                     has_weight_zeros=q_kwargs.get("has_weight_zeros", False),
                 )
+            elif quantizer == "embedding:wx":
+                quant_handler = ao_quantizer_class_dict[quantizer](
+                    device=device,
+                    precision=precision,
+                    bitwidth=q_kwargs.get("bitwidth", 4),
+                    groupsize=q_kwargs.get("groupsize", 32),
+                )
             else:
                 # Easier to ask forgiveness than permission
                 quant_handler = ao_quantizer_class_dict[quantizer](
-                    groupsize=q_kwargs["groupsize"], device=device, precision=precision
+                    groupsize=q_kwargs["groupsize"],
+                    device=device,
+                    precision=precision,
                 )
         except TypeError as e:
             if "unexpected keyword argument 'device'" in str(e):
@@ -877,24 +891,35 @@ def quantized_model(self) -> nn.Module:
 
 try:
     import importlib.util
-    import sys
     import os
+    import sys
+
     torchao_build_path = f"{os.getcwd()}/torchao-build"
 
     # Try loading quantizer
     torchao_experimental_quant_api_spec = importlib.util.spec_from_file_location(
         "torchao_experimental_quant_api",
         f"{torchao_build_path}/src/ao/torchao/experimental/quant_api.py",
     )
-    torchao_experimental_quant_api = importlib.util.module_from_spec(torchao_experimental_quant_api_spec)
+    torchao_experimental_quant_api = importlib.util.module_from_spec(
+        torchao_experimental_quant_api_spec
+    )
     sys.modules["torchao_experimental_quant_api"] = torchao_experimental_quant_api
-    torchao_experimental_quant_api_spec.loader.exec_module(torchao_experimental_quant_api)
-    from torchao_experimental_quant_api import Int8DynActIntxWeightQuantizer
-    ao_quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer
+    torchao_experimental_quant_api_spec.loader.exec_module(
+        torchao_experimental_quant_api
+    )
+    from torchao_experimental_quant_api import (
+        Int8DynActIntxWeightLinearQuantizer,
+        IntxWeightEmbeddingQuantizer,
+    )
+
+    ao_quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer
+    ao_quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer
 
     # Try loading custom op
     try:
         import glob
+
         libs = glob.glob(f"{torchao_build_path}/cmake-out/lib/libtorchao_ops_aten.*")
         libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
         torch.ops.load_library(libs[0])
@@ -903,4 +928,4 @@ def quantized_model(self) -> nn.Module:
         print("Slow fallback kernels will be used.")
 
 except Exception as e:
-    a8wxdq_load_error = e
+    torchao_experimental_load_error = e
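This change generalizes an existing pattern in quantize.py: record the optional-import failure at module load time, and surface it only if the user actually requests one of the affected quantizers. A minimal standalone sketch of that pattern follows; get_quantizer and OPTIONAL_QUANTIZERS are illustrative names, not torchchat APIs:

```
from typing import Optional

# Illustrative names; the real code lives in torchchat/utils/quantize.py and
# loads torchao_experimental_quant_api via importlib from torchao-build/.
OPTIONAL_QUANTIZERS = {"linear:a8wxdq", "embedding:wx"}
optional_load_error: Optional[Exception] = None

try:
    import torchao_experimental_quant_api  # noqa: F401  (stand-in for the real importlib load)
except Exception as e:
    optional_load_error = e  # remember the failure instead of crashing at import time


def get_quantizer(name: str):
    # Surface the deferred error only when an affected quantizer is requested.
    if name in OPTIONAL_QUANTIZERS and optional_load_error is not None:
        raise Exception(
            f"Failed to load torchao experimental {name} quantizer: {optional_load_error}"
        )
    ...  # construct and return the quantizer here


if __name__ == "__main__":
    try:
        get_quantizer("linear:a8wxdq")
    except Exception as err:
        print(err)  # prints the original load error at the point of use
```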

torchchat/utils/scripts/install_utils.sh

Lines changed: 5 additions & 5 deletions
@@ -176,9 +176,10 @@ clone_torchao() {
     pushd ${TORCHCHAT_ROOT}/torchao-build/src
     echo $pwd
 
-    git clone https://github.com/pytorch/ao.git
-    cd ao
-    git checkout $(cat ${TORCHCHAT_ROOT}/install/.pins/torchao-pin.txt)
+    # git clone https://github.com/pytorch/ao.git
+    # cd ao
+    # git checkout $(cat ${TORCHCHAT_ROOT}/install/.pins/torchao-pin.txt)
+    cp -R $HOME/fbsource/fbcode/pytorch/ao .
 
     popd
 }
@@ -191,7 +192,6 @@ install_torchao_aten_ops() {
     cmake -DCMAKE_PREFIX_PATH=${MY_CMAKE_PREFIX_PATH} \
        -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT_DIR} \
        -DCMAKE_BUILD_TYPE="Release" \
-       -DTORCHAO_OP_TARGET="aten" \
        -S . \
        -B ${CMAKE_OUT_DIR} -G Ninja
     cmake --build ${CMAKE_OUT_DIR} --target install --config Release
@@ -207,7 +207,7 @@ install_torchao_executorch_ops() {
     cmake -DCMAKE_PREFIX_PATH=${MY_CMAKE_PREFIX_PATH} \
        -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT_DIR} \
        -DCMAKE_BUILD_TYPE="Release" \
-       -DTORCHAO_OP_TARGET="executorch" \
+       -DTORCHAO_BUILD_EXECUTORCH_OPS=ON \
        -DEXECUTORCH_INCLUDE_DIRS="${EXECUTORCH_INCLUDE_DIRS}" \
        -DEXECUTORCH_LIBRARIES="${EXECUTORCH_LIBRARIES}" \
        -S . \
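Once install_torchao_aten_ops installs libtorchao_ops_aten into torchao-build/cmake-out/lib, torchchat picks it up at import time. A condensed sketch of that lookup, mirroring the glob/load_library logic in the quantize.py diff above (the path layout is the one these build scripts use, relative to the torchchat root):

```
import glob
import os

import torch

# Build output location used by the scripts above, relative to the torchchat root.
torchao_build_path = f"{os.getcwd()}/torchao-build"

libs = glob.glob(f"{torchao_build_path}/cmake-out/lib/libtorchao_ops_aten.*")
libs = [lib for lib in libs if lib.endswith(("so", "dylib"))]
if libs:
    torch.ops.load_library(libs[0])  # registers the custom ATen ops with PyTorch
else:
    print("Slow fallback kernels will be used.")
```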
