Skip to content

Commit 9579f18

Browse files
authored
Merge branch 'main' into pinbump1111
2 parents dbb090f + 570aebc commit 9579f18

File tree

6 files changed

+104
-21
lines changed

6 files changed

+104
-21
lines changed

.github/workflows/pull.yml

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,3 +1104,41 @@ jobs:
11041104
echo "Generate AOTI"
11051105
python torchchat.py generate stories110M --aoti-package-path ./model.pt2 --prompt "${PRMT}"
11061106
echo "Tests complete."
1107+
1108+
test-torchao-experimental-mps:
1109+
strategy:
1110+
matrix:
1111+
runner: [macos-m1-stable]
1112+
runs-on: ${{matrix.runner}}
1113+
steps:
1114+
- name: Checkout repo
1115+
uses: actions/checkout@v3
1116+
with:
1117+
submodules: true
1118+
- name: Setup Python
1119+
uses: actions/setup-python@v2
1120+
with:
1121+
python-version: 3.10.11
1122+
- name: Print machine info
1123+
run: |
1124+
uname -a
1125+
if [ $(uname -s) == Darwin ]; then
1126+
sysctl machdep.cpu.brand_string
1127+
sysctl machdep.cpu.core_count
1128+
fi
1129+
- name: Install torchchat
1130+
run: |
1131+
echo "Installing pip3 packages"
1132+
./install/install_requirements.sh
1133+
pip3 list
1134+
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
1135+
- name: Install torchao-ops-mps
1136+
id: install-torchao-ops-mps
1137+
run: |
1138+
bash torchchat/utils/scripts/build_torchao_ops.sh mps
1139+
- name: Run inference
1140+
run: |
1141+
python torchchat.py download stories110M
1142+
export PRMT="Once upon a time in a land far away"
1143+
echo "Generate eager"
1144+
python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device mps --dtype float32 --quantize '{"linear:afpwx": {"bitwidth": 3, "groupsize": 32}}'

docs/quantization.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,32 @@ Note: only the ExecuTorch C++ runner in torchchat when built using the instructi
196196
./cmake-out/et_run llama3_1.pte -z $HOME/.torchchat/model-cache/meta-llama/Meta-Llama-3.1-8B-Instruct/tokenizer.model -l 3 -i "Once upon a time,"
197197
```
198198

199+
## Experimental TorchAO MPS lowbit kernels
200+
201+
WARNING: These kernels only work on devices with Apple Silicon.
202+
203+
### Use
204+
205+
#### linear:afpwx
206+
The quantization scheme linear:afpwx quantizes only the weights in a groupwise manner with a specified bitwidth and groupsize.
207+
It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7) and groupsize (32, 64, 128, 256).
208+
209+
### Setup
210+
To use linear:afpwx, you must set up the torchao mps experimental kernels. These will only work on devices with Apple Silicon.
211+
Currently, torchchat can only run them in Eager mode.
212+
213+
From the torchchat root directory, run
214+
```
215+
sh torchchat/utils/scripts/build_torchao_ops.sh mps
216+
```
217+
218+
### Examples
219+
220+
#### Eager mode
221+
```
222+
python3 torchchat.py generate stories110M --device mps --dtype float32 --quantize '{"linear:afpwx": {"bitwidth": 4, "groupsize": 256}}' --prompt "Once upon a time," --num-samples 5
223+
```
224+
199225
## Quantization Profiles
200226

201227
Four [sample profiles](https://github.com/pytorch/torchchat/tree/main/torchchat/quant_config/) are included with the torchchat distribution: `cuda.json`, `desktop.json`, `mobile.json`, `pi5.json`

install/.pins/torchao-pin.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2f97b0955953fa1a46594a27f0df2bc48d93e79d
1+
7d7c14e898eca3fe66138d2a9445755a9270b800

torchchat/utils/quantize.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@
6363
def get_named_parameters(func: Callable) -> List[str]:
6464
# Get the signature of the function
6565
signature = inspect.signature(func)
66-
66+
6767
# Extract the parameters from the signature
6868
parameters = signature.parameters
69-
69+
7070
# Filter and return named parameters
7171
named_params = [
7272
name for name, param in parameters.items()
@@ -80,8 +80,8 @@ def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer:
8080
print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
8181
del q_kwargs[key]
8282
return q_kwargs
83-
84-
83+
84+
8585
#########################################################################
8686
### torchchat quantization API ###
8787

@@ -116,15 +116,18 @@ def quantize_model(
116116
if not support_tensor_subclass:
117117
unwrap_tensor_subclass(model)
118118
continue
119-
119+
120120
if quantizer in ["linear:a8wxdq", "embedding:wx"]:
121121
# These quantizers require float32 input weights. Note that after quantization,
122122
# the weights will no longer be float32, but lowbit integers
123123
if get_precision() != torch.float32:
124124
print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
125125
set_precision(torch.float32)
126-
127-
# We set global precision from quantize options if it is specified at cli.py:485
126+
127+
if quantizer == "linear:afpwx" and device != "mps":
128+
raise RuntimeError("linear:afpwx quantization can only run on mps device!")
129+
130+
# We set global precision from quantize options if it is specified at cli.py:485
128131
# so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
129132
precision = get_precision()
130133

@@ -915,10 +918,12 @@ def quantized_model(self) -> nn.Module:
915918
from torchao_experimental_quant_api import (
916919
Int8DynActIntxWeightLinearQuantizer,
917920
IntxWeightEmbeddingQuantizer,
921+
UIntxWeightOnlyLinearQuantizer,
918922
)
919923

920924
quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer
921925
quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer
926+
quantizer_class_dict["linear:afpwx"] = UIntxWeightOnlyLinearQuantizer
922927

923928
# Try loading custom op
924929
try:
@@ -928,15 +933,14 @@ def quantized_model(self) -> nn.Module:
928933
libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
929934
torch.ops.load_library(libs[0])
930935
except Exception as e:
931-
print("Failed to torchao ops library with error: ", e)
932-
print("Slow fallback kernels will be used.")
936+
print("Unable to load torchao cpu ops library. Slow fallback kernels will be used.")
937+
938+
try:
939+
libname = "libtorchao_ops_mps_aten.dylib"
940+
libpath = f"{torchao_build_path}/cmake-out/lib/{libname}"
941+
torch.ops.load_library(libpath)
942+
except Exception as e:
943+
print("Unable to load torchao mps ops library.")
933944

934945
except Exception as e:
935-
class ErrorHandler(QuantHandler):
936-
def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None):
937-
global torchao_experimental_load_error
938-
raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}")
939-
940-
torchao_experimental_load_error = e
941-
quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
942-
quantizer_class_dict["embedding:wx"] = ErrorHandler
946+
print("Unable to import torchao experimental quant_api with error: ", e)

torchchat/utils/scripts/build_torchao_ops.sh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,17 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8+
device=${1:-cpu}
89

10+
if [[ "$device" != "cpu" && "$device" != "mps" ]]; then
11+
echo "Invalid argument: $device. Valid values are 'cpu' or 'mps'." >&2
12+
exit 1
13+
fi
914

1015
source "$(dirname "${BASH_SOURCE[0]}")/install_utils.sh"
1116

1217
pushd ${TORCHCHAT_ROOT}
1318
find_cmake_prefix_path
1419
clone_torchao
15-
install_torchao_aten_ops
20+
install_torchao_aten_ops "$device"
1621
popd

torchchat/utils/scripts/install_utils.sh

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,18 @@ clone_torchao() {
184184
}
185185

186186
install_torchao_aten_ops() {
187-
echo "Building torchao custom ops for ATen"
188-
pushd ${TORCHCHAT_ROOT}/torchao-build/src/ao/torchao/experimental
187+
local device=${1:-cpu}
188+
189+
if [[ "$device" == "cpu" ]]; then
190+
echo "Building torchao custom ops for ATen"
191+
pushd ${TORCHCHAT_ROOT}/torchao-build/src/ao/torchao/experimental
192+
elif [[ "$device" == "mps" ]]; then
193+
echo "Building torchao mps custom ops for ATen"
194+
pushd ${TORCHCHAT_ROOT}/torchao-build/src/ao/torchao/experimental/ops/mps
195+
else
196+
echo "Invalid argument: $device. Valid values are 'cpu' or 'mps'." >&2
197+
return 1
198+
fi
189199

190200
CMAKE_OUT_DIR=${TORCHCHAT_ROOT}/torchao-build/cmake-out
191201
cmake -DCMAKE_PREFIX_PATH=${MY_CMAKE_PREFIX_PATH} \

0 commit comments

Comments
 (0)