Skip to content

Commit ad4dbaf

Browse files
Add torchao mps lowbit ops to llama runner
1 parent: daf9aee · commit: ad4dbaf

File tree

3 files changed

+35
-6
lines changed

3 files changed

+35
-6
lines changed

examples/models/llama/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ cmake_dependent_option(
3838
)
3939

4040
option(EXECUTORCH_BUILD_TORCHAO "Build the torchao kernels" OFF)
41+
option(EXECUTORCH_BUILD_TORCHAO_MPS "Build the torchao mps kernels" OFF)
4142

4243
if(NOT PYTHON_EXECUTABLE)
4344
set(PYTHON_EXECUTABLE python3)
@@ -130,6 +131,13 @@ if(EXECUTORCH_BUILD_TORCHAO)
130131
list(APPEND link_libraries torchao_ops_executorch)
131132
endif()
132133

134+
if(EXECUTORCH_BUILD_TORCHAO_MPS)
135+
set(TORCHAO_BUILD_EXECUTORCH_OPS ON)
136+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental/ops/mps ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental/ops/mps)
137+
target_link_options_shared_lib(torchao_ops_mps_executorch)
138+
list(APPEND link_libraries torchao_ops_mps_executorch)
139+
endif()
140+
133141
set(XNNPACK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack)
134142
# Extra compile option and include dir for pthreadpool
135143
if(EXECUTORCH_BUILD_PTHREADPOOL)

examples/models/llama/export_llama_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ def get_quantizer_and_quant_params(args):
600600

601601
def _qmode_type(value):
602602
choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
603-
patterns = [r"torchao:8da(\d+)w"]
603+
patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]
604604

605605
if value in choices:
606606
return value

examples/models/llama/source_transformation/quantize.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,33 @@ def quantize( # noqa C901
7272
if qmode == "int8":
7373
# Add quantization mode options here: group size, bit width, etc.
7474
return WeightOnlyInt8QuantHandler(model).quantized_model()
75-
elif qmode.startswith("torchao:"):
75+
elif qmode.startswith("torchao:fpa"):
76+
pattern = r"torchao:fpa(\d+)w"
77+
matches = re.findall(pattern, qmode)
78+
assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
79+
bitwidth = int(matches[0][0])
80+
_load_torchao_aten_lib(
81+
libname="libtorchao_ops_mps_linear_fp_act_xbit_weight_aten"
82+
)
83+
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
84+
85+
with torch.no_grad():
86+
model = UIntxWeightOnlyLinearQuantizer(
87+
device="mps",
88+
precision=torch.float32,
89+
groupsize=group_size,
90+
bitwidth=bitwidth
91+
).quantize(model).to("cpu")
92+
93+
if verbose:
94+
print("quantized model:", model)
95+
return model
96+
elif qmode.startswith("torchao:8da"):
7697
pattern = r"torchao:8da(\d+)w"
7798
matches = re.findall(pattern, qmode)
7899
assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
79100
bitwidth = int(matches[0][0])
80-
_load_torchao_ops_aten()
101+
_load_torchao_aten_lib(libname="libtorchao_ops_aten")
81102
from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer
82103

83104
with torch.no_grad():
@@ -729,7 +750,7 @@ def get_quant_embedding_transform(args):
729750
bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
730751
group_size = int(group_size)
731752
bitwidth = int(bitwidth)
732-
_load_torchao_ops_aten()
753+
_load_torchao_aten_lib(libname="libtorchao_ops_aten")
733754
from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
734755

735756
def _torchao_embedding_quantizer(model):
@@ -785,15 +806,15 @@ def get_quant_weight_transform(args, dtype_override, verbose):
785806
)
786807

787808

788-
def _load_torchao_ops_aten():
809+
def _load_torchao_aten_lib(libname):
789810
import glob
790811
import os
791812

792813
libs = glob.glob(
793814
os.path.abspath(
794815
os.path.join(
795816
os.environ.get("CMAKE_INSTALL_PREFIX", ""),
796-
"lib/libtorchao_ops_aten.*",
817+
f"lib/{libname}.*",
797818
)
798819
)
799820
)

0 commit comments

Comments (0)