Skip to content

Commit 81fce9c

Browse files
Use int4mm weight packing MPS kernel (#866)
* Use int4mm weight packing MPS kernel
* Update torch nightly version
1 parent d71783c commit 81fce9c

File tree

2 files changed

+4
-10
lines changed

2 files changed

+4
-10
lines changed

install_requirements.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ $PIP_EXECUTABLE install -r requirements.txt --extra-index-url https://download.p
4646
# NOTE: If a newly-fetched version of the executorch repo changes the value of
4747
# NIGHTLY_VERSION, you should re-run this script to install the necessary
4848
# package versions.
49-
NIGHTLY_VERSION=dev20240613
49+
NIGHTLY_VERSION=dev20240624
5050

5151
# Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
5252
$PIP_EXECUTABLE uninstall -y triton

quantization/qops.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -395,15 +395,9 @@ def _prepare_weight_and_scales_and_zeros(
395395
weight_int32, scales_and_zeros = group_quantize_tensor(
396396
weight_bf16, n_bit=4, groupsize=groupsize
397397
)
398-
if weight_bf16.device.type == "mps":
399-
# There are still no MPS-accelerated conversion OP
400-
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
401-
weight_int32.cpu(), inner_k_tiles
402-
).to("mps")
403-
else:
404-
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
405-
weight_int32, inner_k_tiles
406-
)
398+
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
399+
weight_int32, inner_k_tiles
400+
)
407401
return weight_int4pack, scales_and_zeros
408402

409403
@classmethod

0 commit comments

Comments (0)