
Commit caba051

up
1 parent d6149e1 commit caba051

3 files changed: +43 -19 lines changed

examples/models/llama/CMakeLists.txt

Lines changed: 12 additions & 8 deletions
@@ -116,16 +116,20 @@ endif()
 
 if(EXECUTORCH_BUILD_TORCHAO)
   set(TORCHAO_BUILD_EXECUTORCH_OPS ON)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental)
+  set(TORCHAO_BUILD_CPU_AARCH64 ON)
+  add_subdirectory(
+    ${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental
+    ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental
+  )
   target_link_options_shared_lib(torchao_ops_executorch)
   list(APPEND link_libraries torchao_ops_executorch)
-  if(CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
-    add_subdirectory(
-      ${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental/ops/mps
-      ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental/ops/mps)
-    target_link_options_shared_lib(torchao_ops_mps_executorch)
-    list(APPEND link_libraries torchao_ops_mps_executorch)
-  endif()
+  # if(CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
+  #   add_subdirectory(
+  #     ${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental/ops/mps
+  #     ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental/ops/mps)
+  #   target_link_options_shared_lib(torchao_ops_mps_executorch)
+  #   list(APPEND link_libraries torchao_ops_mps_executorch)
+  # endif()
 endif()
 
 set(XNNPACK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack)

examples/models/llama/source_transformation/quantize.py

Lines changed: 28 additions & 11 deletions
@@ -98,21 +98,38 @@ def quantize( # noqa C901
         matches = re.findall(pattern, qmode)
         assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
         bitwidth = int(matches[0][0])
-        _load_torchao_aten_lib(libname="libtorchao_ops_aten")
-        from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer
+        # _load_torchao_aten_lib(libname="libtorchao_ops_aten")
+        # from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer
+        from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight, Int8DynActIntxWeightLinearQuantizer
+        from torchao.quantization.quant_api import quantize_
+        from torchao.utils import unwrap_tensor_subclass
+        from torchao.quantization.granularity import PerRow, PerGroup
 
         with torch.no_grad():
-            model = Int8DynActIntxWeightLinearQuantizer(
-                device="cpu",
-                precision=torch.float32,
-                groupsize=group_size,
-                bitwidth=bitwidth,
-                has_weight_zeros=False,
-            ).quantize(model)
-
+            # model = Int8DynActIntxWeightLinearQuantizer(
+            #     device="cpu",
+            #     precision=torch.float32,
+            #     groupsize=group_size,
+            #     bitwidth=bitwidth,
+            #     has_weight_zeros=False,
+            # ).quantize(model)
+
+            quantize_(model,
+                int8_dynamic_activation_intx_weight(
+                    # group_size=group_size,
+                    # nbit=bitwidth,
+                    # has_weight_zeros=False,
+                    weight_dtype=getattr(torch, f"int{bitwidth}"),
+                    granularity=PerRow() if group_size == 0 else PerGroup(group_size),
+                    has_weight_zeros=False,
+                ),
+            )
+            model = unwrap_tensor_subclass(model)
         if verbose:
             print("quantized model:", model)
         return model
+
+        return model
     elif qmode == "8da4w":
         # Check for required args
         if group_size is None:
@@ -752,7 +769,7 @@ def get_quant_embedding_transform(args):
         bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
         group_size = int(group_size)
         bitwidth = int(bitwidth)
-        _load_torchao_aten_lib(libname="libtorchao_ops_aten")
+        # _load_torchao_aten_lib(libname="libtorchao_ops_aten")
        from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
 
        def _torchao_embedding_quantizer(model):
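For context on the quantize.py change above: the diff swaps the eager Int8DynActIntxWeightLinearQuantizer module swap for torchao's quantize_ API configured with int8_dynamic_activation_intx_weight, then calls unwrap_tensor_subclass so the quantized weights survive export. Below is a minimal standalone sketch of that flow; the toy nn.Linear model and the bitwidth/group_size values are illustrative only (not part of the commit), and the imports assume a torchao checkout that exposes these experimental APIs, as the diff does.

# Minimal sketch of the quantize_ flow used in the diff above.
import torch
import torch.nn as nn

from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.quantization.quant_api import quantize_
from torchao.quantization.granularity import PerGroup, PerRow
from torchao.utils import unwrap_tensor_subclass

bitwidth = 4        # parsed from a "torchao:8da4w"-style qmode in the real code
group_size = 128    # group_size == 0 would select per-row granularity instead

# Toy stand-in for the llama model; any module containing nn.Linear layers works.
model = nn.Sequential(nn.Linear(512, 512, bias=False))

with torch.no_grad():
    quantize_(
        model,
        int8_dynamic_activation_intx_weight(
            weight_dtype=getattr(torch, f"int{bitwidth}"),  # torch.int4 for bitwidth=4
            granularity=PerRow() if group_size == 0 else PerGroup(group_size),
            has_weight_zeros=False,
        ),
    )
    # Replace tensor-subclass weights with plain tensors so torch.export can trace the model.
    model = unwrap_tensor_subclass(model)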

run.sh

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+for i in {1..5}; do
+  ./cmake-out/examples/models/llama/llama_main --model_path=$MODEL_OUT --tokenizer_path=$TOKENIZER --prompt="Once upon a time,"
+done
