Commit df0b06c

Add torchao mps lowbit ops to llama runner (#7037)
* Add torchao mps lowbit ops to llama runner
* Update ao submodule
Parent: 80f1c1b

4 files changed (+37, -7 lines)

examples/models/llama/CMakeLists.txt (7 additions, 0 deletions)

@@ -128,6 +128,13 @@ if(EXECUTORCH_BUILD_TORCHAO)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental)
   target_link_options_shared_lib(torchao_ops_executorch)
   list(APPEND link_libraries torchao_ops_executorch)
+  if(CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
+    add_subdirectory(
+      ${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental/ops/mps
+      ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental/ops/mps)
+    target_link_options_shared_lib(torchao_ops_mps_executorch)
+    list(APPEND link_libraries torchao_ops_mps_executorch)
+  endif()
 endif()

 set(XNNPACK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack)
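Note: the new block only takes effect on Apple Silicon macOS hosts. If an export script wants to mirror that guard at runtime before requesting an MPS-quantized model, a platform check along these lines would do (a hypothetical helper, not part of this commit; the name can_use_torchao_mps_lowbit is illustrative):

import platform

import torch


def can_use_torchao_mps_lowbit() -> bool:
    # Mirror the CMake guard: macOS (Darwin) on arm64, with an MPS-capable torch build.
    return (
        platform.system() == "Darwin"
        and platform.machine() == "arm64"
        and torch.backends.mps.is_available()
    )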

examples/models/llama/export_llama_lib.py (1 addition, 1 deletion)

@@ -600,7 +600,7 @@ def get_quantizer_and_quant_params(args):

 def _qmode_type(value):
     choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
-    patterns = [r"torchao:8da(\d+)w"]
+    patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]

     if value in choices:
         return value
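The hunk cuts off before the pattern check itself; presumably _qmode_type goes on to accept any value matching one of the regexes and rejects everything else with an argparse error. A minimal sketch of that tail, under that assumption (the error message is illustrative):

import argparse
import re


def _qmode_type_tail(value, choices, patterns):
    # Assumed continuation: accept regex-shaped qmodes such as "torchao:fpa4w".
    for pattern in patterns:
        if re.fullmatch(pattern, value):
            return value
    raise argparse.ArgumentTypeError(
        f"Got qmode {value}, expected one of {choices} or a match for {patterns}"
    )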

examples/models/llama/source_transformation/quantize.py (28 additions, 5 deletions)
@@ -72,12 +72,35 @@ def quantize( # noqa C901
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
-    elif qmode.startswith("torchao:"):
+    elif qmode.startswith("torchao:fpa"):
+        pattern = r"torchao:fpa(\d+)w"
+        matches = re.findall(pattern, qmode)
+        assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
+        bitwidth = int(matches[0][0])
+        _load_torchao_aten_lib(libname="libtorchao_ops_mps_aten")
+        from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
+
+        with torch.no_grad():
+            model = (
+                UIntxWeightOnlyLinearQuantizer(
+                    device="mps",
+                    precision=torch.float32,
+                    groupsize=group_size,
+                    bitwidth=bitwidth,
+                )
+                .quantize(model)
+                .to("cpu")
+            )
+
+        if verbose:
+            print("quantized model:", model)
+        return model
+    elif qmode.startswith("torchao:8da"):
         pattern = r"torchao:8da(\d+)w"
         matches = re.findall(pattern, qmode)
         assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
         bitwidth = int(matches[0][0])
-        _load_torchao_ops_aten()
+        _load_torchao_aten_lib(libname="libtorchao_ops_aten")
         from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer

         with torch.no_grad():
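For reference, the bitwidth extraction in the new branch is plain re semantics: re.findall returns the captured digit group, and matches[0][0] takes its first character, which equals the full capture for the single-digit bitwidths this pattern is used with. Runnable as-is:

import re

qmode = "torchao:fpa4w"  # example qmode accepted by the new pattern
matches = re.findall(r"torchao:fpa(\d+)w", qmode)
print(matches)  # ['4']
bitwidth = int(matches[0][0])  # first character of the capture -> 4
print(bitwidth)  # 4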
@@ -729,7 +752,7 @@ def get_quant_embedding_transform(args):
         bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
         group_size = int(group_size)
         bitwidth = int(bitwidth)
-        _load_torchao_ops_aten()
+        _load_torchao_aten_lib(libname="libtorchao_ops_aten")
         from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer

         def _torchao_embedding_quantizer(model):
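As the first context line of this hunk shows, the --embedding_quantize flag is parsed as "<prefix>:<bitwidth>,<group_size>" (the prefix is presumably "torchao:" given the branch this sits in). For example:

# Illustrative flag value; the split mirrors the code in the hunk above.
embedding_quantize = "torchao:4,32"
bitwidth, group_size = embedding_quantize.split(":")[1].split(",")
group_size = int(group_size)  # 32
bitwidth = int(bitwidth)  # 4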
@@ -785,15 +808,15 @@ def get_quant_weight_transform(args, dtype_override, verbose):
     )


-def _load_torchao_ops_aten():
+def _load_torchao_aten_lib(libname):
     import glob
     import os

     libs = glob.glob(
         os.path.abspath(
             os.path.join(
                 os.environ.get("CMAKE_INSTALL_PREFIX", ""),
-                "lib/libtorchao_ops_aten.*",
+                f"lib/{libname}.*",
             )
         )
     )
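The hunk ends at the glob, so the rest of _load_torchao_aten_lib is not shown. Presumably the helper then registers the matched shared library with PyTorch; a sketch of that assumed tail:

import glob
import os

import torch


def _load_torchao_aten_lib(libname):
    libs = glob.glob(
        os.path.abspath(
            os.path.join(os.environ.get("CMAKE_INSTALL_PREFIX", ""), f"lib/{libname}.*")
        )
    )
    # Assumed tail: exactly one matching library, loaded into the torch op registry.
    assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
    torch.ops.load_library(libs[0])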

third-party/ao (submodule updated: 354 files)
