
Commit 1863fa6

Update on "[ET-VK] Clean up api::vTensor class"
## Context

Now that we have forked the `api/` directory from PyTorch Vulkan, we can clean up the `vTensor` class and remove functionality that is not necessary for the ExecuTorch Vulkan delegate. The following changes are made:

* Remove unused member variables and member functions from `vTensor` and `vTensorStorage`.
* Remove all quantization-related member variables, member functions, and the `vTensor` constructor for quantized tensors. The Quantization API will be reworked from the ground up.
* Rename `view_` (which is an instance of `vTensorStorage`) to `storage_`.

Finally, the critical change introduced is that `storage_` is now stored as a direct `vTensorStorage` member variable in `vTensor` instead of as a `std::shared_ptr<vTensorStorage>`. For context, `storage_` was originally stored as a shared pointer to stay compliant with ATen Tensors, which need to support copy construction so that the following works:

```
at::Tensor b = at::rand(...);
// Oftentimes this will create a "view" of the tensor. a and b will point to the
// same underlying storage, but with different metadata.
at::Tensor a = b;
```

However, in the ExecuTorch delegate this is no longer necessary. Each tensor is associated with its own independent storage and is responsible for managing its own memory. **By getting rid of `std::shared_ptr`, we can avoid a heap allocation and avoid chasing pointers whenever we need to access the resources of a `vTensor`.**

Differential Revision: [D55811279](https://our.internmc.facebook.com/intern/diff/D55811279/)

[ghstack-poisoned]
2 parents 4ea3fd2 + e11d7af commit 1863fa6
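
To illustrate the point about heap allocation and pointer chasing, here is a minimal standalone sketch contrasting a `std::shared_ptr` member with a by-value member. The `Storage`, `TensorSharedPtr`, and `TensorByValue` names are invented for this example and are not the actual ExecuTorch classes.

```cpp
#include <memory>

// Hypothetical stand-in for the Vulkan resources owned by a tensor's storage.
struct Storage {
  int image_handle = 0;
};

// Shared-pointer layout: the Storage lives in a separate heap allocation, and
// every access to the tensor's resources goes through a pointer.
class TensorSharedPtr {
 public:
  TensorSharedPtr() : storage_(std::make_shared<Storage>()) {}
  int handle() const { return storage_->image_handle; }  // pointer chase

 private:
  std::shared_ptr<Storage> storage_;
};

// By-value layout: the Storage is embedded directly in the tensor object, so
// there is no extra allocation and no indirection on access.
class TensorByValue {
 public:
  int handle() const { return storage_.image_handle; }  // direct member access

 private:
  Storage storage_{};
};

int main() {
  TensorSharedPtr a;
  TensorByValue b;
  return a.handle() + b.handle();
}
```

With the by-value layout, constructing the tensor performs no separate allocation for the storage, and reading from `storage_` is a plain member access rather than a dereference.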

File tree

8 files changed: +99 -26 lines changed


backends/qualcomm/builders/node_visitor.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -29,6 +29,7 @@
     QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
 }
 QNN_TENSOR_TYPE_MAP = {
+    torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
     torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
     torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8,
     torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16,
```

backends/qualcomm/partition/common_defs.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -13,6 +13,8 @@
     exir_ops.edge.aten.clone.default,
     exir_ops.edge.aten.index.Tensor,
     exir_ops.edge.aten.full.default,
+    exir_ops.edge.aten.slice_scatter.default,
+    exir_ops.edge.aten.index_put.default,
 ]
 
 allow_list_operator = [
```

backends/vulkan/runtime/api/Tensor.cpp

Lines changed: 4 additions & 4 deletions

```diff
@@ -307,16 +307,16 @@ api::VulkanImage allocate_image(
   };
 
   VkImageType image_type = VK_IMAGE_TYPE_3D;
-  VkImageViewType image_storage_type = VK_IMAGE_VIEW_TYPE_3D;
+  VkImageViewType image_view_type;
 
   switch (storage_type) {
     case api::kTexture3D:
      image_type = VK_IMAGE_TYPE_3D;
-      image_storage_type = VK_IMAGE_VIEW_TYPE_3D;
+      image_view_type = VK_IMAGE_VIEW_TYPE_3D;
      break;
    case api::kTexture2D:
      image_type = VK_IMAGE_TYPE_2D;
-      image_storage_type = VK_IMAGE_VIEW_TYPE_2D;
+      image_view_type = VK_IMAGE_VIEW_TYPE_2D;
      break;
    default:
      // Return an empty VulkanImage by default
@@ -329,7 +329,7 @@ api::VulkanImage allocate_image(
       api::create_extent3d(extents),
       image_format,
       image_type,
-      image_storage_type,
+      image_view_type,
       sampler_props,
       sampler,
       /*allow_transfer = */ true,
```

backends/vulkan/runtime/api/Tensor.h

Lines changed: 2 additions & 5 deletions

```diff
@@ -42,14 +42,14 @@ class vTensorStorage final {
       const api::ScalarType dtype,
       const bool allocate_memory = true);
 
-  ~vTensorStorage();
-
   vTensorStorage(const vTensorStorage& other) = delete;
   vTensorStorage& operator=(const vTensorStorage& other) = delete;
 
   vTensorStorage(vTensorStorage&& other) = default;
   vTensorStorage& operator=(vTensorStorage&& other) = default;
 
+  ~vTensorStorage();
+
   friend class vTensor;
 
  private:
@@ -130,9 +130,6 @@ class vTensor final {
   // image texture that can be passed into a shader.
   std::shared_ptr<api::UniformParamsBuffer> extents_uniform_;
 
-  // Store the backing storage of the tensor as a shared pointer to allow two
-  // tensors to share the same underlying resource, but with different metadata.
-  // This will
   vTensorStorage storage_;
 
  public:
```
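
A side note on the special members shown above (a general C++ illustration with hypothetical names, not code from this diff): because `vTensorStorage` deletes its copy operations and defaults its move operations, a class that holds it by value remains movable but cannot be copied by accident.

```cpp
#include <utility>

// Hypothetical move-only storage, mirroring the deleted-copy / defaulted-move
// pattern in the header above.
struct MoveOnlyStorage {
  MoveOnlyStorage() = default;
  MoveOnlyStorage(const MoveOnlyStorage&) = delete;
  MoveOnlyStorage& operator=(const MoveOnlyStorage&) = delete;
  MoveOnlyStorage(MoveOnlyStorage&&) = default;
  MoveOnlyStorage& operator=(MoveOnlyStorage&&) = default;
};

// Holding the storage by value: the compiler still generates move operations
// for Holder, but its copy constructor is implicitly deleted.
struct Holder {
  MoveOnlyStorage storage_;
};

int main() {
  Holder a;
  Holder b = std::move(a);  // OK: moves the storage member
  // Holder c = b;          // would not compile: copying is disallowed
  (void)b;
  return 0;
}
```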

examples/models/llama2/README.md

Lines changed: 20 additions & 4 deletions

````diff
@@ -17,9 +17,9 @@ Please note that the models are subject to the [acceptable use policy](https://g
 
 # Results
 
-Since 7B Llama2 model needs at least 4-bit quantization to fit even within some of the highend phones, results presented here correspond to 4-bit groupwise post-training quantized model.
+Since 7B Llama2 model needs at least 4-bit quantization to fit even within some of the highend phones, results presented here correspond to 4-bit groupwise post-training quantized model.
 
-For Llama3, we can use the same process. Note that it's only supported in the ExecuTorch main branch.
+For Llama3, we can use the same process. Note that it's only supported in the ExecuTorch main branch.
 
 ## Quantization:
 We employed 4-bit groupwise per token dynamic quantization of all the linear layers of the model. Dynamic quantization refers to quantizating activations dynamically, such that quantization parameters for activations are calculated, from min/max range, at runtime. Here we quantized activations with 8bits (signed integer). Furthermore, weights are statically quantized. In our case weights were per-channel groupwise quantized with 4bit signed integer. For more information refer to this [page](https://github.com/pytorch-labs/ao/).
@@ -57,7 +57,7 @@ Performance was measured on Samsung Galaxy S22, S24, One Plus 12 and iPhone 15 m
 - For Llama7b, your device may require at least 32GB RAM. If this is a constraint for you, please try the smaller stories model.
 
 ## Step 1: Setup
-1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch
+1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch. For installation run `./install_requirements.sh --pybind xnnpack`
 2. Run `examples/models/llama2/install_requirements.sh` to install a few dependencies.
 
 ## Step 2: Prepare model
@@ -103,6 +103,16 @@ If you want to deploy and run a smaller model for educational purposes. From `ex
 python -m examples.models.llama2.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin
 ```
 
+### Option C: Download and export Llama3 8B model
+
+You can export and run the original Llama3 8B model.
+
+1. Llama3 pretrained parameters can be downloaded from [Meta's official llama3 repository](https://github.com/meta-llama/llama3/).
+
+2. Export model and generate `.pte` file
+```
+python -m examples.models.llama2.export_llama --checkpoint <consolidated.00.pth> -p <params.json> -d=fp32 -X -qmode 8da4w -kv --use_sdpa_with_kv_cache --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte" group_size 128 --metadata '{"get_bos_id":128000, "get_eos_id":128001}' --embedding-quantize 4,32
+```
 
 ## (Optional) Finetuning
@@ -148,6 +158,7 @@ The Uncyclotext results generated above used: `{max_seq_len: 2048, limit: 1000}`
     -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
     -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
     -DEXECUTORCH_BUILD_XNNPACK=ON \
+    -DEXECUTORCH_BUILD_QUANTIZED=ON \
     -DEXECUTORCH_BUILD_OPTIMIZED=ON \
     -DEXECUTORCH_BUILD_CUSTOM=ON \
     -Bcmake-out .
@@ -163,17 +174,22 @@ The Uncyclotext results generated above used: `{max_seq_len: 2048, limit: 1000}`
     -DEXECUTORCH_BUILD_CUSTOM=ON \
     -DEXECUTORCH_BUILD_OPTIMIZED=ON \
     -DEXECUTORCH_BUILD_XNNPACK=ON \
+    -DEXECUTORCH_BUILD_QUANTIZED=ON \
     -Bcmake-out/examples/models/llama2 \
     examples/models/llama2
 
 cmake --build cmake-out/examples/models/llama2 -j16 --config Release
 ```
 
+For Llama3, add `-DEXECUTORCH_USE_TIKTOKEN=ON` option when building the llama runner.
+
 3. Run model. Run options available [here](https://github.com/pytorch/executorch/blob/main/examples/models/llama2/main.cpp#L18-L40).
 ```
 cmake-out/examples/models/llama2/llama_main --model_path=<model pte file> --tokenizer_path=<tokenizer.bin> --prompt=<prompt>
 ```
 
+For Llama3, you can pass the original `tokenizer.model` (without converting to `.bin` file).
+
 ## Step 5: Run benchmark on Android phone
 
 **1. Build llama runner binary for Android**
@@ -271,7 +287,7 @@ This example tries to reuse the Python code, with minimal modifications to make
 ```
 git clean -xfd
 pip uninstall executorch
-./install_requirements.sh <options>
+./install_requirements.sh --pybind xnnpack
 
 rm -rf cmake-out
 ```
````

examples/models/llama2/export_llama_lib.py

Lines changed: 67 additions & 10 deletions

```diff
@@ -355,6 +355,13 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--pt2e_quantize",
         default=None,
+        choices=[
+            "xnnpack_dynamic",
+            "xnnpack_dynamic_qc4",
+            "qnn_8a8w",
+            "qnn_16a16w",
+            "qnn_16a4w",
+        ],
         help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
     )
     parser.add_argument(
@@ -627,6 +634,9 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     if args.use_sdpa_with_kv_cache:
         transforms.append(replace_sdpa_with_custom_op)
 
+    if args.qnn and args.use_kv_cache:
+        transforms.append(replace_sdpa_with_simple_sdpa)
+        transforms.append(replace_causal_mask)
     return (
         load_llama_model(
             modelname=modelname,
@@ -650,13 +660,16 @@ def _export_llama(modelname, args) -> str: # noqa: C901
     # export_to_edge
     pt2e_quant_params = _get_pt2e_quantization_params(args)
     quantizers = get_pt2e_quantizers(pt2e_quant_params, args)
-    if args.qnn:
-        assert (
-            args.quantization_mode is None
-        ), "Currently qnn backend only supports QnnQuantizer via pt2e flow"
+    quant_dtype = None
+    if args.qnn and args.pt2e_quantize:
         try:
             # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer`
-            from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
+            from executorch.backends.qualcomm.quantizer.quantizer import (
+                get_16a4w_qnn_ptq_config,
+                get_default_16bit_qnn_ptq_config,
+                QnnQuantizer,
+                QuantDtype,
+            )
 
             # reset quantizers and pt2e_quant_params from xnnpack backend
             pt2e_quant_params = None
@@ -666,10 +679,41 @@ def _export_llama(modelname, args) -> str: # noqa: C901
                 "Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm.html"
             )
 
+        backend, quant_config = args.pt2e_quantize.split("_")
+        assert (
+            backend == "qnn"
+        ), f"The quantization config is for backend {backend} instead of qnn."
         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
         qnn_quantizer = QnnQuantizer()
         # more custom quantization are supported including 16a4w etc. default to 8bit quantized
         custom_annotations = ()
+        if quant_config == "8a8w":
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+            quant_dtype = QuantDtype.use_8a8w
+            pass
+        elif quant_config == "16a16w":
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+            quant_dtype = QuantDtype.use_16a16w
+            qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS)
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+            qnn_quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config())
+        elif quant_config == "16a4w":
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+            quant_dtype = QuantDtype.use_16a4w
+            qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS)
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+            qnn_quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config())
+            qnn_quantizer.set_per_channel_weight_dtype(
+                weight_dtype_for_16bit_act="int4"
+            )
+        else:
+            raise AssertionError(
+                f"No support for quant type {quant_config}. Support 8a8w, 16a16w and 16a4w."
+            )
+
+        assert (
+            args.quantization_mode is None
+        ), "Currently qnn backend only supports QnnQuantizer via pt2e flow"
         qnn_quantizer.add_custom_quant_annotations(custom_annotations)
         quantizers.append(qnn_quantizer)
 
@@ -786,25 +830,38 @@ def _export_llama(modelname, args) -> str: # noqa: C901
                 "Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm.html"
             )
 
-        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
-        backend_options = generate_htp_compiler_spec(use_fp16=False)
+        use_fp16 = True
+        skip_node_op_set = {}
+        if args.pt2e_quantize:
+            use_fp16 = False
+            # TODO: fix the lowering error without skipping nodes
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+            if quant_dtype == QuantDtype.use_8a8w:
+                raise NotImplementedError("8a8w for llama is still under development")
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+            elif quant_dtype == QuantDtype.use_16a16w:
+                raise NotImplementedError("16a16w for llama is still under development")
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+            elif quant_dtype == QuantDtype.use_16a4w:
+                raise NotImplementedError("16a4w for llama is still under development")
         partitioners.append(
             # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
             QnnPartitioner(
                 # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
                 generate_qnn_executorch_compiler_spec(
                     # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
                     soc_model=QcomChipset.SM8650,  # default to SM8650
-                    backend_options=backend_options,
+                    # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
+                    backend_options=generate_htp_compiler_spec(use_fp16=use_fp16),
                     debug=False,
                     saver=False,
                 ),
                 skip_node_id_set={},
-                skip_node_op_set={},
+                skip_node_op_set=skip_node_op_set,
             )
         )
         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
-        _transform(builder_exported_to_edge.export_program())
+        _transform(builder_exported_to_edge.edge_manager.exported_program())
 
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
```

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -179,7 +179,7 @@ def embedding_byte_dtype_out_meta(
 
 quantized_decomposed_lib.define(
     "embedding_4bit.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
-    "int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None) -> Tensor",
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
 )
 
 quantized_decomposed_lib.define(
```

extension/module/module.cpp

Lines changed: 2 additions & 2 deletions

```diff
@@ -51,8 +51,8 @@ Module::Module(
     std::unique_ptr<EventTracer> event_tracer)
     : data_loader_(std::move(data_loader)),
       memory_allocator_(
-          std::move(memory_allocator)
-              ?: std::make_unique<util::MallocMemoryAllocator>()),
+          memory_allocator ? std::move(memory_allocator)
+                           : std::make_unique<util::MallocMemoryAllocator>()),
       event_tracer_(std::move(event_tracer)) {
   runtime_init();
 }
```
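
The change above replaces GCC's non-standard `?:` ("elvis") extension with an explicit conditional that tests the pointer before moving from it. A minimal standalone sketch of the same fallback-to-default pattern (the `pick` helper and the `std::string` payload are invented for this example, not the ExecuTorch API):

```cpp
#include <memory>
#include <string>

// Standard-C++ version of "use the provided object, or construct a default":
// check the pointer first, and only move from it in the branch that uses it.
std::unique_ptr<std::string> pick(std::unique_ptr<std::string> provided) {
  return provided ? std::move(provided)
                  : std::make_unique<std::string>("default");
}

int main() {
  auto fallback = pick(nullptr);                                 // uses the default
  auto custom = pick(std::make_unique<std::string>("custom"));   // keeps the argument
  return (*fallback == "default" && *custom == "custom") ? 0 : 1;
}
```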
