
Commit a3cda44

mikekgfb authored and malfet committed
MPS CI runs (#162)

* MPS quantization
* mps dtypes
* updates
* fix names
* typo
* no bfloat16 for older macOS
* fix typo
* remove failing embedding quantization from MPS runs
* bfloat -> current model precision
* typo
* missed bfloat16 to switch to default precision
* remove int8 quantization on mps
* enable cpu fallback for mps on int4
* hack int4pack_mm for torch.float
* typo
* disable int4 because fp16 int4pack_mm not working for float16
1 parent 76c330e commit a3cda44

File tree

5 files changed: +104 −17 lines changed

.github/workflows/compile-bf16.yml

Lines changed: 3 additions & 3 deletions
@@ -44,9 +44,9 @@ jobs:
   export MODEL_NAME=stories15M
   export MODEL_DIR=/tmp
   for DTYPE in bfloat16 float16 float32; do
-    if [ $(uname -s) == Darwin ]; then
-      export DTYPE=float16
-    fi
+    # if [ $(uname -s) == Darwin ]; then
+    #   export DTYPE=float16
+    # fi
     python generate.py --dtype ${DTYPE} --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
     cat ./output_eager
     python generate.py --dtype ${DTYPE} --compile --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
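Commenting out the Darwin override means the macOS job now iterates over bfloat16, float16 and float32 instead of being pinned to float16. As a quick local sanity check (a sketch, not part of the commit), eager-mode bfloat16 math does run on CPU:

import torch

# Sketch only: confirm eager bfloat16 matmul works on CPU, the path the
# workflow takes when DTYPE=bfloat16 is no longer overridden on macOS.
a = torch.randn(4, 8, dtype=torch.bfloat16)
b = torch.randn(8, 2, dtype=torch.bfloat16)
print((a @ b).dtype)  # torch.bfloat16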

.github/workflows/test_mps-dtype.yml

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+name: Run eager tests on MPS with dtypes
+
+on:
+  pull_request:
+  push:
+    branches:
+      - main
+  workflow_dispatch:
+
+jobs:
+  test-mps:
+    uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
+    with:
+      runner: macos-m1-stable
+      script: |
+        set -eou pipefail
+
+        echo "::group::Print machine info"
+        uname -a
+        if [ $(uname -s) == Darwin ]; then
+          sysctl machdep.cpu.brand_string
+          sysctl machdep.cpu.core_count
+        fi
+        echo "::endgroup::"
+
+        echo "::group::Install requirements"
+        # Install requirements
+        pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
+        ls -la
+        pwd
+        pip install -r requirements.txt
+        echo "::endgroup::"
+
+        echo "::group::Download checkpoints"
+        (
+          mkdir -p checkpoints/stories15M
+          pushd checkpoints/stories15M
+          curl -fsSL -O https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
+          curl -fsSL -O https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
+          popd
+        )
+        echo "::endgroup::"
+
+        echo "::group::Run inference"
+        export MODEL_PATH=checkpoints/stories15M/stories15M.pt
+        export MODEL_NAME=stories15M
+        export MODEL_DIR=/tmp
+        for DTYPE in float16 float32; do
+          # if [ $(uname -s) == Darwin ]; then
+          #   export DTYPE=float16
+          # fi
+
+          python generate.py --dtype ${DTYPE} --device mps --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+          # python generate.py --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          # cat ./output_eager
+          # python generate.py --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          # cat ./output_eager
+          # python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          # cat ./output_eager
+          # python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          # cat ./output_eager
+          # PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --dtype ${DTYPE} --device mps --quant '{"linear:int4" : {"group_size": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          # cat ./output_eager
+        done
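The dtype loop above covers only float16 and float32; per the commit message, bfloat16 is skipped because it is not available on older macOS builds of the MPS backend. A minimal probe (a sketch, not part of the commit) for checking which dtypes a local MPS setup accepts:

import torch

# Sketch: verify the MPS backend is present and try to allocate a tensor in
# each dtype; bfloat16 raises on older macOS / PyTorch combinations.
if not torch.backends.mps.is_available():
    print("MPS backend not available")
else:
    for dtype in (torch.float16, torch.float32, torch.bfloat16):
        try:
            x = torch.ones(2, 2, dtype=dtype, device="mps")
            print(f"{dtype}: ok, sum={(x + x).sum().item()}")
        except (TypeError, RuntimeError) as err:
            print(f"{dtype}: unsupported here ({err})")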

.github/workflows/test_mps.yml

Lines changed: 13 additions & 1 deletion
@@ -1,4 +1,4 @@
-name: Run compile tests on MPS
+name: Run eager tests on MPS
 
 on:
   pull_request:
@@ -45,5 +45,17 @@ jobs:
   export MODEL_PATH=checkpoints/stories15M/stories15M.pt
   export MODEL_NAME=stories15M
   export MODEL_DIR=/tmp
+
   python generate.py --device mps --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
   cat ./output_eager
+  # python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+  # cat ./output_eager
+  # python generate.py --device mps --quant '{"embedding" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+  # cat ./output_eager
+  # python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+  # cat ./output_eager
+  # python generate.py --device mps --quant '{"linear:int8" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+  # cat ./output_eager
+  # PYTORCH_ENABLE_MPS_FALLBACK=1 python generate.py --device mps --quant '{"linear:int4" : {"group_size": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+  # cat ./output_eager
+
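The commented-out int4 invocation above is prefixed with PYTORCH_ENABLE_MPS_FALLBACK=1, which tells PyTorch to run ops that have no MPS kernel on the CPU instead of raising ("enable cpu fallback for mps on int4" in the commit message). The same effect from inside a script, as a sketch (the variable has to be set before torch is imported):

import os

# Sketch: opt in to the CPU fallback for unsupported MPS ops. Without it,
# such ops raise NotImplementedError on the mps device.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch  # imported after setting the env var on purpose

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.randn(4, 4, device=device)
print(x.device)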

quantize.py

Lines changed: 20 additions & 10 deletions
@@ -465,11 +465,12 @@ def __init__(
         self.register_buffer(
             "weight", torch.empty((out_features, in_features), dtype=torch.int8)
         )
-        if groupsize is None or (groupsize == 0):
-            self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
+        dtype=get_precision()
+        if group_size is None or (group_size == 0):
+            self.register_buffer("scales", torch.ones(out_features, dtype=dtype))
         else:
-            groups = (in_features + groupsize - 1) // groupsize
-            self.register_buffer("scales", torch.ones(out_features, groups, dtype=torch.bfloat16))
+            groups = (in_features + group_size - 1) // group_size
+            self.register_buffer("scales", torch.ones(out_features, groups, dtype=dtype))
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         scales = self.scales
@@ -683,12 +684,21 @@ def _int4_calc_padded_size(k, groupsize=1, innner_k_tiles=1):
 def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
     origin_x_size = x.size()
     x = x.reshape(-1, origin_x_size[-1])
-    c = torch.ops.aten._weight_int4pack_mm(
-        x.to(dtype=torch.bfloat16),
-        weight_int4pack,
-        groupsize,
-        scales_and_zeros.to(dtype=torch.bfloat16)
-    ).to(dtype=x.dtype)
+    if x.dtype == torch.float:
+        # work around missing int4pack_mm for torch.float
+        c = torch.ops.aten._weight_int4pack_mm(
+            x.to(torch.float16),
+            weight_int4pack,
+            groupsize,
+            scales_and_zeros.to(torch.float16),
+        ).to(torch.float)
+    else:
+        c = torch.ops.aten._weight_int4pack_mm(
+            x,
+            weight_int4pack,
+            groupsize,
+            scales_and_zeros,
+        )
     new_shape = origin_x_size[:-1] + (out_features,)
     c = c.reshape(new_shape)
     return c
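The first quantize.py hunk makes the int8 scales buffer follow the model's active precision through get_precision() rather than hard-coding bfloat16 ("bfloat -> current model precision" in the commit message). A toy sketch of that pattern; get_precision matches the name used in the diff, while set_precision and the module-level default are stand-ins for this illustration:

import torch

_precision = torch.float32  # stand-in for the repo's global default precision

def set_precision(dtype):   # hypothetical setter, for this sketch only
    global _precision
    _precision = dtype

def get_precision():        # same name as the helper called in the diff
    return _precision

set_precision(torch.float16)                     # e.g. what --dtype float16 selects
scales = torch.ones(256, dtype=get_precision())
print(scales.dtype)                              # torch.float16, not hard-coded bfloat16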

quantized_ops.py

Lines changed: 3 additions & 3 deletions
@@ -120,11 +120,11 @@ def linear_int4(
     origin_input_size = input.size()
     input = input.reshape(-1, origin_input_size[-1])
     c = torch.ops.aten._weight_int4pack_mm(
-        input.to(dtype=torch.bfloat16),
+        input,
         weight_int4pack,
         groupsize,
-        scales_and_zeros.to(dtype=torch.bfloat16)
-    ).to(dtype=input.dtype)
+        scales_and_zeros,
+    )
     new_shape = origin_input_size[:-1] + (out_features,)
     c = c.reshape(new_shape)
     return c
