Skip to content

Commit cd9a5fa

Browse files
Add torchao mps lowbit ops to llama runner
1 parent daf9aee commit cd9a5fa

File tree

3 files changed

+39
-6
lines changed

3 files changed

+39
-6
lines changed

examples/models/llama/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ cmake_dependent_option(
3838
)
3939

4040
option(EXECUTORCH_BUILD_TORCHAO "Build the torchao kernels" OFF)
41+
option(EXECUTORCH_BUILD_TORCHAO_MPS "Build the torchao mps kernels" OFF)
4142

4243
if(NOT PYTHON_EXECUTABLE)
4344
set(PYTHON_EXECUTABLE python3)
@@ -130,6 +131,13 @@ if(EXECUTORCH_BUILD_TORCHAO)
130131
list(APPEND link_libraries torchao_ops_executorch)
131132
endif()
132133

134+
# Optional: build and link the torchao MPS lowbit kernels into the llama runner.
# Gated by the EXECUTORCH_BUILD_TORCHAO_MPS option declared earlier in this file.
if(EXECUTORCH_BUILD_TORCHAO_MPS)
135+
# Tell the torchao subproject to build its ExecuTorch op registrations
# (consumed by the add_subdirectory below).
set(TORCHAO_BUILD_EXECUTORCH_OPS ON)
136+
# Pull in the torchao MPS ops from the third-party checkout; the second
# argument places its build tree at a mirrored path under the binary dir.
# NOTE(review): ${CMAKE_CURRENT_BINARY_DIR}/../../../ escapes the current
# binary dir — presumably intentional to mirror the source layout; confirm
# it stays inside the top-level build tree.
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental/ops/mps ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental/ops/mps)
137+
# Project helper (defined outside this file) — presumably forces a
# whole-archive link so the ops' static registrars are not dropped by the
# linker; TODO confirm helper semantics.
target_link_options_shared_lib(torchao_ops_mps_executorch)
138+
# Accumulate into the runner's link list, mirroring the
# EXECUTORCH_BUILD_TORCHAO branch above.
list(APPEND link_libraries torchao_ops_mps_executorch)
139+
endif()
133141
set(XNNPACK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack)
134142
# Extra compile option and include dir for pthreadpool
135143
if(EXECUTORCH_BUILD_PTHREADPOOL)

examples/models/llama/export_llama_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ def get_quantizer_and_quant_params(args):
600600

601601
def _qmode_type(value):
602602
choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
603-
patterns = [r"torchao:8da(\d+)w"]
603+
patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]
604604

605605
if value in choices:
606606
return value

examples/models/llama/source_transformation/quantize.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,37 @@ def quantize( # noqa C901
7272
if qmode == "int8":
7373
# Add quantization mode options here: group size, bit width, etc.
7474
return WeightOnlyInt8QuantHandler(model).quantized_model()
75-
elif qmode.startswith("torchao:"):
75+
elif qmode.startswith("torchao:fpa"):
76+
pattern = r"torchao:fpa(\d+)w"
77+
matches = re.findall(pattern, qmode)
78+
assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
79+
bitwidth = int(matches[0][0])
80+
_load_torchao_aten_lib(
81+
libname="libtorchao_ops_mps_linear_fp_act_xbit_weight_aten"
82+
)
83+
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
84+
85+
with torch.no_grad():
86+
model = (
87+
UIntxWeightOnlyLinearQuantizer(
88+
device="mps",
89+
precision=torch.float32,
90+
groupsize=group_size,
91+
bitwidth=bitwidth,
92+
)
93+
.quantize(model)
94+
.to("cpu")
95+
)
96+
97+
if verbose:
98+
print("quantized model:", model)
99+
return model
100+
elif qmode.startswith("torchao:8da"):
76101
pattern = r"torchao:8da(\d+)w"
77102
matches = re.findall(pattern, qmode)
78103
assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
79104
bitwidth = int(matches[0][0])
80-
_load_torchao_ops_aten()
105+
_load_torchao_aten_lib(libname="libtorchao_ops_aten")
81106
from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer
82107

83108
with torch.no_grad():
@@ -729,7 +754,7 @@ def get_quant_embedding_transform(args):
729754
bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
730755
group_size = int(group_size)
731756
bitwidth = int(bitwidth)
732-
_load_torchao_ops_aten()
757+
_load_torchao_aten_lib(libname="libtorchao_ops_aten")
733758
from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
734759

735760
def _torchao_embedding_quantizer(model):
@@ -785,15 +810,15 @@ def get_quant_weight_transform(args, dtype_override, verbose):
785810
)
786811

787812

788-
def _load_torchao_ops_aten():
813+
def _load_torchao_aten_lib(libname):
789814
import glob
790815
import os
791816

792817
libs = glob.glob(
793818
os.path.abspath(
794819
os.path.join(
795820
os.environ.get("CMAKE_INSTALL_PREFIX", ""),
796-
"lib/libtorchao_ops_aten.*",
821+
f"lib/{libname}.*",
797822
)
798823
)
799824
)

0 commit comments

Comments (0)