Skip to content

Commit a809953

Browse files
authored
Add torchao kernels to llama runner
Differential Revision: D64942925 Pull Request resolved: #6195
1 parent 146ca1b commit a809953

File tree

8 files changed

+110
-9
lines changed

8 files changed

+110
-9
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,6 @@
6464
[submodule "third-party/pybind11"]
6565
path = third-party/pybind11
6666
url = https://github.com/pybind/pybind11.git
67+
[submodule "third-party/ao"]
68+
path = third-party/ao
69+
url = https://github.com/pytorch/ao.git

examples/models/llama/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ cmake_dependent_option(
3737
"NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF
3838
)
3939

40+
# Opt-in switch (default OFF). When ON, this file adds the torchao experimental
# subdirectory from third-party/ao and appends torchao_ops_executorch to
# link_libraries.
option(EXECUTORCH_BUILD_TORCHAO "Build the torchao kernels" OFF)
41+
4042
if(NOT PYTHON_EXECUTABLE)
4143
set(PYTHON_EXECUTABLE python3)
4244
endif()
@@ -121,6 +123,13 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM)
121123
list(APPEND link_libraries custom_ops)
122124
endif()
123125

126+
# Optional torchao kernels: pull in the experimental torchao CMake project from
# the third-party/ao checkout and add its ExecuTorch ops library to the set of
# libraries collected in link_libraries.
if(EXECUTORCH_BUILD_TORCHAO)
  # NOTE(review): presumably consumed by the torchao CMake project added
  # below - confirm against third-party/ao/torchao/experimental.
  set(TORCHAO_BUILD_EXECUTORCH_OPS ON)
  add_subdirectory(
    ${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental
    ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental
  )
  target_link_options_shared_lib(torchao_ops_executorch)
  list(APPEND link_libraries torchao_ops_executorch)
endif()
132+
124133
set(XNNPACK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack)
125134
# Extra compile option and include dir for pthreadpool
126135
if(EXECUTORCH_BUILD_PTHREADPOOL)

examples/models/llama/export_llama_lib.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
import copy
1313
import json
1414
import logging
15+
import re
1516
import shlex
1617
from enum import Enum
1718
from json import JSONDecodeError
1819
from pathlib import Path
1920
from typing import Callable, List, Optional, Union
2021

2122
import pkg_resources
22-
2323
import torch
2424

2525
from executorch.devtools.etrecord import generate_etrecord
@@ -153,12 +153,12 @@ def build_args_parser() -> argparse.ArgumentParser:
153153
],
154154
help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
155155
)
156+
156157
parser.add_argument(
157158
"-qmode",
158159
"--quantization_mode",
159-
type=str,
160+
type=_qmode_type,
160161
default=None,
161-
choices=["int8", "8da4w", "8da4w-gptq", "vulkan_4w"],
162162
help="type of quantization",
163163
)
164164

@@ -568,6 +568,23 @@ def get_quantizer_and_quant_params(args):
568568
return pt2e_quant_params, quantizers, quant_dtype
569569

570570

571+
def _qmode_type(value):
572+
choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
573+
patterns = [r"torchao:8da(\d+)w"]
574+
575+
if value in choices:
576+
return value
577+
578+
for pattern in patterns:
579+
matches = re.findall(pattern, value)
580+
if len(matches) == 1:
581+
return value
582+
583+
raise argparse.ArgumentTypeError(
584+
f"Got qmode {value}, but expected one of {choices}, or one of the regex patterns {patterns}."
585+
)
586+
587+
571588
def _validate_args(args):
572589
"""
573590
TODO: Combine all the backends under --backend args
@@ -581,6 +598,19 @@ def _validate_args(args):
581598
if args.num_sharding > 0 and not args.qnn:
582599
raise ValueError("Model shard is only supported with qnn backend now.")
583600

601+
if (
602+
args.quantization_mode is not None
603+
and args.quantization_mode.startswith("torchao:")
604+
) or (
605+
args.embedding_quantize is not None
606+
and args.embedding_quantize.startswith("torchao:")
607+
):
608+
if args.enable_dynamic_shape:
609+
# Adjacent string literals are joined verbatim, so the first sentence carries
# its own trailing space to keep the two sentences from running together.
raise ValueError(
    "Dynamic shape is not currently supported with torchao ops. Please use --disable_dynamic_shape. "
    "If you need this feature, please file an issue."
)
613+
584614

585615
def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
586616
_validate_args(args)

examples/models/llama/install_requirements.sh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
pip install snakeviz sentencepiece
1111

1212
# Install torchao.
13-
TORCHAO_VERSION=$(cat "$(dirname "$0")"/../../../.ci/docker/ci_commit_pins/torchao.txt)
14-
pip install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${TORCHAO_VERSION}"
13+
pip install "$(dirname "$0")/../../../third-party/ao"
1514

1615
# Install lm-eval for Model Evaluation with lm-evalution-harness
1716
# Install tiktoken for tokenizer

examples/models/llama/source_transformation/quantize.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import logging
8+
import re
79
from functools import partial
810
from pathlib import Path
911
from typing import Any, Dict, Optional
@@ -70,6 +72,26 @@ def quantize( # noqa C901
7072
if qmode == "int8":
7173
# Add quantization mode options here: group size, bit width, etc.
7274
return WeightOnlyInt8QuantHandler(model).quantized_model()
75+
elif qmode.startswith("torchao:"):
76+
pattern = r"torchao:8da(\d+)w"
77+
matches = re.findall(pattern, qmode)
78+
assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
79+
# re.findall with a single capture group returns a list of strings, so
# matches[0] is already the whole digit run; indexing it again would keep
# only its first digit (e.g. "16" -> "1").
bitwidth = int(matches[0])
80+
_load_torchao_ops_aten()
81+
from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer
82+
83+
with torch.no_grad():
84+
model = Int8DynActIntxWeightLinearQuantizer(
85+
device="cpu",
86+
precision=torch.float32,
87+
groupsize=group_size,
88+
bitwidth=bitwidth,
89+
has_weight_zeros=False,
90+
).quantize(model)
91+
92+
if verbose:
93+
print("quantized model:", model)
94+
return model
7395
elif qmode == "8da4w":
7496
# Check for required args
7597
if group_size is None:
@@ -79,6 +101,7 @@ def quantize( # noqa C901
79101
model = Int8DynActInt4WeightQuantizer(
80102
precision=torch_dtype, groupsize=group_size
81103
).quantize(model)
104+
82105
if verbose:
83106
print("quantized model:", model)
84107
return model
@@ -692,6 +715,25 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
692715

693716

694717
def get_quant_embedding_transform(args):
718+
if args.embedding_quantize.startswith("torchao:"):
719+
bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
720+
group_size = int(group_size)
721+
bitwidth = int(bitwidth)
722+
_load_torchao_ops_aten()
723+
from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
724+
725+
def _torchao_embedding_quantizer(model):
726+
with torch.no_grad():
727+
model = IntxWeightEmbeddingQuantizer(
728+
device="cpu",
729+
precision=torch.float32,
730+
bitwidth=bitwidth,
731+
groupsize=group_size,
732+
).quantize(model)
733+
return model
734+
735+
return _torchao_embedding_quantizer
736+
695737
bitwidth, group_size = args.embedding_quantize.split(",")
696738
if group_size == "none" or group_size == "None" or group_size == "0":
697739
group_size = None
@@ -733,4 +775,23 @@ def get_quant_weight_transform(args, dtype_override, verbose):
733775
)
734776

735777

778+
def _load_torchao_ops_aten():
779+
import glob
780+
import os
781+
782+
libs = glob.glob(
783+
os.path.abspath(
784+
os.path.join(
785+
os.environ.get("CMAKE_INSTALL_PREFIX", ""),
786+
"lib/libtorchao_ops_aten.*",
787+
)
788+
)
789+
)
790+
assert (
791+
len(libs) == 1
792+
), f"Expected 1 library but got {len(libs)}. If you installed the torchao ops in a non-standard location, please set CMAKE_INSTALL_PREFIX correctly."
793+
logging.info(f"Loading custom ops library: {libs[0]}")
794+
torch.ops.load_library(libs[0])
795+
796+
736797
############################ Source Transform End #######################

examples/models/llama3_2_vision/install_requirements.sh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,4 @@
99
pip install --pre torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu --no-cache-dir
1010

1111
# Install torchao.
12-
TORCHAO_VERSION=$(cat "$(dirname "$0")"/../../../.ci/docker/ci_commit_pins/torchao.txt)
13-
pip install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${TORCHAO_VERSION}"
12+
pip install "$(dirname "$0")/../../../third-party/ao"

examples/models/phi-3-mini-lora/install_requirements.sh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,4 @@ pip install torchtune
1010
pip install tiktoken
1111

1212
# Install torchao.
13-
TORCHAO_VERSION=$(cat "$(dirname "$0")"/../../../.ci/docker/ci_commit_pins/torchao.txt)
14-
pip install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${TORCHAO_VERSION}"
13+
pip install "$(dirname "$0")/../../../third-party/ao"

third-party/ao

Submodule ao added at 75d0693

0 commit comments

Comments
 (0)