
Qualcomm AI Engine Direct - LPBQ enablement #9313

Merged 1 commit on Mar 23, 2025

1 change: 1 addition & 0 deletions backends/qualcomm/README.md
@@ -124,6 +124,7 @@ PRs are always welcome to help improve the codebase in a comprehensive manner. B
- [shewu-quic](https://github.com/shewu-quic)
- [chunit-quic](https://github.com/chunit-quic)
- [winskuo-quic](https://github.com/winskuo-quic)
- [DannyYuyang-quic](https://github.com/DannyYuyang-quic)
- [haowhsu-quic](https://github.com/haowhsu-quic)

Thanks again for your contribution!
2 changes: 0 additions & 2 deletions backends/qualcomm/_passes/__init__.py
@@ -2,7 +2,6 @@
from .annotate_quant_attrs import AnnotateQuantAttrs
from .constant_i64_to_i32 import ConstantI64toI32
from .convert_bmm_to_matmul import ConvertBmmToMatmul
from .convert_interpolate_with_upsample2d import ConvertInterpolateWithUpsample2D
from .convert_to_linear import ConvertToLinear
from .decompose_any import DecomposeAny
from .decompose_einsum import DecomposeEinsum
@@ -30,7 +29,6 @@
AnnotateQuantAttrs,
ConstantI64toI32,
ConvertBmmToMatmul,
ConvertInterpolateWithUpsample2D,
RecomposePReLU,
ConvertToLinear,
DecomposeAny,
17 changes: 15 additions & 2 deletions backends/qualcomm/_passes/annotate_quant_attrs.py
@@ -10,6 +10,7 @@
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
from executorch.backends.qualcomm.utils.constants import (
QCOM_AXIS,
QCOM_BLOCK_SIZE,
QCOM_DTYPE,
QCOM_ENCODING,
QCOM_QUANT_ATTRS,
@@ -122,13 +123,25 @@ def _dequant_fold_params(self, n, quant_attrs, param):
scales = self._expand(quant_attrs[QCOM_SCALES], dim, axis)
offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis)
param = param.sub(offsets).mul(scales).to(torch.float32).contiguous()
set_parameter(param, n.args[0], self.edge_program)
elif quant_attrs[QCOM_ENCODING] in [
exir_ops.edge.pt2e_quant.dequantize_affine.default
]:
param = torch.ops.pt2e_quant.dequantize_affine(
param,
block_size=quant_attrs[QCOM_BLOCK_SIZE],
scale=quant_attrs[QCOM_SCALE],
zero_point=quant_attrs[QCOM_ZERO_POINT],
input_dtype=quant_attrs[QCOM_DTYPE],
quant_min=quant_attrs[QCOM_QUANT_MIN],
quant_max=quant_attrs[QCOM_QUANT_MAX],
output_dtype=torch.float32,
)
else:
scale = quant_attrs[QCOM_SCALE]
offset = quant_attrs[QCOM_ZERO_POINT]
param = param.sub(offset).mul(scale).to(torch.float32).contiguous()
set_parameter(param, n.args[0], self.edge_program)

set_parameter(param, n.args[0], self.edge_program)
n.args[0].meta["val"] = param

def _annotate_quant_attrs(
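A note on the new dequantize_affine branch above: for block-quantized (LPBQ) parameters, each (scale, zero_point) pair covers a block of elements rather than a whole channel, so folding the weight back to float means expanding every per-block scale across its block before the usual subtract-and-multiply. Below is a minimal pure-PyTorch sketch of those semantics with hypothetical shapes and block size; the actual folding goes through the torch.ops.pt2e_quant.dequantize_affine op shown in the diff:

```python
import torch

def dequantize_blockwise(q, scale, zero_point, block, output_dtype=torch.float32):
    """Blockwise affine dequantization: one (scale, zero_point) per block.

    q:          quantized weight, e.g. shape (out_channels, in_channels)
    scale:      per-block scales, shape (out_channels, in_channels // block)
    zero_point: per-block zero points, same shape as scale
    block:      number of elements along dim 1 that share one scale
    """
    # Broadcast each per-block parameter across its block.
    s = scale.repeat_interleave(block, dim=1)
    zp = zero_point.repeat_interleave(block, dim=1)
    return (q.to(torch.int32) - zp).to(output_dtype) * s

# Hypothetical 4-bit weight of shape (2, 8) with blocks of 4 along dim 1.
q = torch.randint(-8, 8, (2, 8), dtype=torch.int8)
scale = torch.rand(2, 2)
zero_point = torch.zeros(2, 2, dtype=torch.int32)
print(dequantize_blockwise(q, scale, zero_point, block=4).shape)  # torch.Size([2, 8])
```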

This file was deleted: backends/qualcomm/_passes/convert_interpolate_with_upsample2d.py (the source of the removed ConvertInterpolateWithUpsample2D pass).

1 change: 1 addition & 0 deletions backends/qualcomm/_passes/layout_transform.py
@@ -39,6 +39,7 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.pixel_shuffle.default,
exir_ops.edge.aten.pixel_unshuffle.default,
exir_ops.edge.aten.upsample_bilinear2d.default,
exir_ops.edge.aten.upsample_bilinear2d.vec,
exir_ops.edge.aten.upsample_nearest2d.default,
exir_ops.edge.aten.upsample_nearest2d.vec,
}
11 changes: 6 additions & 5 deletions backends/qualcomm/_passes/utils.py
@@ -6,7 +6,7 @@

import torch
from executorch.backends.qualcomm.builders.utils import get_parameter
from executorch.backends.qualcomm.utils.constants import QCOM_ENCODING
from executorch.backends.qualcomm.utils.constants import QCOM_DTYPE, QCOM_ENCODING
from executorch.exir.dialects._ops import ops as exir_ops
from torch._subclasses import FakeTensor

@@ -42,6 +42,10 @@ def get_quant_attrs(
value = get_parameter(attr_n, edge_program)
quant_attrs[quant_attr_keys[i - 1]] = value

# remap key for compatibility - block quantization only
if dtype := quant_attrs.get("input_dtype", None):
quant_attrs[QCOM_DTYPE] = dtype

quant_attrs[QCOM_ENCODING] = quant_node.target
return quant_attrs

@@ -62,7 +66,6 @@ def get_passes_dependency_for_capture_program():
AnnotateQuantAttrs,
ConstantI64toI32,
ConvertBmmToMatmul,
ConvertInterpolateWithUpsample2D,
ConvertToLinear,
DecomposeAny,
DecomposeLinalgVectorNorm,
@@ -85,11 +88,9 @@ def get_passes_dependency_for_capture_program():
ConvertToLinear,
RecomposePReLU,
ConvertBmmToMatmul,
ConvertInterpolateWithUpsample2D,
],
ConstantI64toI32: [ConvertInterpolateWithUpsample2D],
ConstantI64toI32: [RemoveRedundancy],
ConvertBmmToMatmul: [ConvertToLinear],
ConvertInterpolateWithUpsample2D: [RemoveRedundancy],
ConvertToLinear: [RecomposePixelUnshuffle],
DecomposeAny: [RemoveRedundancy],
DecomposeLinalgVectorNorm: [RemoveRedundancy],
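The edits above drop ConvertInterpolateWithUpsample2D from the dependency table and make ConstantI64toI32 depend on RemoveRedundancy directly. The table maps each pass to the passes that must run before it, which is exactly the node-to-predecessors shape a topological sort consumes. A small illustrative sketch with abbreviated entries follows; graphlib is used here only to demonstrate the ordering, not necessarily what capture_program does internally:

```python
from graphlib import TopologicalSorter

# Hypothetical miniature of the dependency table: pass -> passes that must
# run before it.
passes_dependency = {
    "AnnotateQuantAttrs": ["ConvertBmmToMatmul", "ConvertToLinear"],
    "ConstantI64toI32": ["RemoveRedundancy"],
    "ConvertBmmToMatmul": ["ConvertToLinear"],
    "ConvertToLinear": ["RecomposePixelUnshuffle"],
    "RecomposePixelUnshuffle": [],
    "RemoveRedundancy": [],
}

# TopologicalSorter takes node -> predecessors, matching the table's shape.
order = list(TopologicalSorter(passes_dependency).static_order())
print(order)  # every pass appears after all of its prerequisites
```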
11 changes: 11 additions & 0 deletions backends/qualcomm/aot/ir/qcir.fbs
@@ -54,6 +54,13 @@ enum QuantizeType : byte {
AXIS_SCALE_OFFSET,
BW_SCALE_OFFSET,
BW_AXIS_SCALE_OFFSET,
BLOCKWISE_EXPANSION,
UNDEFINED,
}

enum BlockScaleStorageType: byte {
BITWIDTH_SCALE_STORAGE_8 = 0,
BITWIDTH_SCALE_STORAGE_16,
UNDEFINED,
}

@@ -72,6 +79,10 @@ table QuantizeParam {
offsets: [int];
// used by general quantization
data: [ScaleOffset];
// used by block quantization
num_blocks_per_axis: uint;
block_scale_storage_type: BlockScaleStorageType;
block_scale: [ubyte];
}

table Tensor {
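The new schema fields keep one ScaleOffset per channel in data, plus num_blocks_per_axis raw block scales per channel in block_scale, with block_scale_storage_type selecting one or two bytes per stored scale. A hypothetical helper showing how the buffer size follows from those fields, the same arithmetic the C++ encoder below performs:

```python
# One block scale per (channel, block) pair, stored in 1 or 2 bytes
# depending on block_scale_storage_type.
def block_scale_nbytes(num_channels: int, num_blocks_per_axis: int,
                       storage_bits: int) -> int:
    assert storage_bits in (8, 16), "the schema defines 8- and 16-bit storage"
    return num_channels * num_blocks_per_axis * (storage_bits // 8)

print(block_scale_nbytes(64, 4, 8))   # 256 bytes
print(block_scale_nbytes(64, 4, 16))  # 512 bytes
```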
63 changes: 61 additions & 2 deletions backends/qualcomm/aot/ir/qcir_utils.cpp
@@ -118,17 +118,22 @@ flatbuffers::Offset<qcir::QuantizeParam> ToQuantizeParam(
qcir::QuantizeType::BW_SCALE_OFFSET},
{QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET,
qcir::QuantizeType::BW_AXIS_SCALE_OFFSET},
{QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION,
qcir::QuantizeType::BLOCKWISE_EXPANSION},
{QNN_QUANTIZATION_ENCODING_UNDEFINED,
qcir::QuantizeType::UNDEFINED},
};

int32_t axis = 0;
uint32_t bitwidth = 0;
uint32_t bitwidth = 0, num_blocks_per_axis = 0;
auto param = QNN_TENSOR_VER_PTR(tensor)->quantizeParams;
auto quant_type = type_map.at(param.quantizationEncoding);
std::vector<qcir::ScaleOffset> data;
std::vector<uint8_t> block_scale;
std::vector<float> scales;
std::vector<int32_t> offsets;
qcir::BlockScaleStorageType block_scale_storage_type =
qcir::BlockScaleStorageType::BITWIDTH_SCALE_STORAGE_8;
switch (quant_type) {
case qcir::QuantizeType::SCALE_OFFSET: {
data.emplace_back(qcir::ScaleOffset(
@@ -160,6 +165,28 @@
offsets.push_back(param.bwAxisScaleOffsetEncoding.offsets[i]);
}
} break;
case qcir::QuantizeType::BLOCKWISE_EXPANSION: {
bitwidth = param.blockwiseExpansion->blockScaleBitwidth;
axis = param.blockwiseExpansion->axis;
uint num_channels = QNN_TENSOR_VER_PTR(tensor)->dimensions[axis];
for (uint i = 0; i < num_channels; ++i) {
data.emplace_back(qcir::ScaleOffset(
param.blockwiseExpansion->scaleOffsets[i].scale,
param.blockwiseExpansion->scaleOffsets[i].offset));
}
num_blocks_per_axis = param.blockwiseExpansion->numBlocksPerAxis;
uint multiplier = 1;
if (param.blockwiseExpansion->blockScaleStorageType ==
QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_16) {
multiplier = 2;
block_scale_storage_type =
qcir::BlockScaleStorageType::BITWIDTH_SCALE_STORAGE_16;
}
uint total_bytes = num_channels * num_blocks_per_axis * multiplier;
block_scale = std::vector<uint8_t>(
param.blockwiseExpansion->blocksScale8,
param.blockwiseExpansion->blocksScale8 + total_bytes);
} break;
default:
// encodings are not required if lowering with floating point precision
break;
@@ -172,7 +199,10 @@
axis,
&scales,
&offsets,
&data);
&data,
num_blocks_per_axis,
block_scale_storage_type,
&block_scale);
}
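
On the encode side above, the block scale buffer is copied byte-for-byte, with total_bytes doubled for 16-bit storage. Reading it back amounts to reinterpreting the bytes at the right width and reshaping to (num_channels, num_blocks_per_axis); a sketch of that inverse, using NumPy purely for illustration:

```python
import numpy as np

# Hypothetical decode of the flattened block_scale buffer written by
# ToQuantizeParam: reinterpret as uint8/uint16 levels, one row per channel.
def decode_block_scales(buf: bytes, num_channels: int,
                        num_blocks_per_axis: int, storage_bits: int):
    dtype = np.uint8 if storage_bits == 8 else np.uint16
    levels = np.frombuffer(buf, dtype=dtype)
    return levels.reshape(num_channels, num_blocks_per_axis)

raw = bytes(range(8))  # 2 channels x 4 blocks at 8-bit storage
print(decode_block_scales(raw, 2, 4, 8))
```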

Qnn_QuantizeParams_t ToQuantizeParam(const tensor_type& tensor) {
@@ -192,9 +222,14 @@ Qnn_QuantizeParams_t ToQuantizeParam(const tensor_type& tensor) {
QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET},
{qcir::QuantizeType::BW_AXIS_SCALE_OFFSET,
QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET},
{qcir::QuantizeType::BLOCKWISE_EXPANSION,
QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION},
{qcir::QuantizeType::UNDEFINED,
QNN_QUANTIZATION_ENCODING_UNDEFINED},
};
// Qnn_BlockwiseExpansion_t is stored as a pointer inside Qnn_QuantizeParams_t;
// keep these allocations alive to guard the life cycle of that pointer.
static std::vector<std::unique_ptr<Qnn_BlockwiseExpansion_t>> block_param;

Qnn_QuantizeParams_t p = QNN_QUANTIZE_PARAMS_INIT;
auto param = tensor->qparam();
@@ -226,6 +261,30 @@
p.bwAxisScaleOffsetEncoding.offsets =
const_cast<int32_t*>(param->offsets()->data());
} break;
case QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION: {
block_param.emplace_back(std::make_unique<Qnn_BlockwiseExpansion_t>());
p.blockwiseExpansion = block_param.back().get();
p.blockwiseExpansion->axis = param->axis();
p.blockwiseExpansion->scaleOffsets = reinterpret_cast<Qnn_ScaleOffset_t*>(
const_cast<uint8_t*>(param->data()->Data()));
p.blockwiseExpansion->numBlocksPerAxis = param->num_blocks_per_axis();
switch (param->block_scale_storage_type()) {
case qcir::BlockScaleStorageType::BITWIDTH_SCALE_STORAGE_8:
p.blockwiseExpansion->blockScaleStorageType =
QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8;
break;
case qcir::BlockScaleStorageType::BITWIDTH_SCALE_STORAGE_16:
p.blockwiseExpansion->blockScaleStorageType =
QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_16;
break;
default:
p.blockwiseExpansion->blockScaleStorageType =
QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_UNDEFINED;
break;
}
p.blockwiseExpansion->blocksScale8 =
const_cast<uint8_t*>(param->block_scale()->Data());
} break;
default:
// encodings are not required if lowering with floating point precision
break;
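For context on what the stored block scales mean: LPBQ keeps one float scale/offset per channel plus a small integer level per block, and the effective per-block scale is their product. The exact QNN expansion formula is not spelled out in this diff, so the reconstruction below is an assumption for illustration only:

```python
import numpy as np

# Assumed LPBQ expansion: effective_scale[c, b] = channel_scale[c] * level[c, b].
channel_scale = np.array([0.02, 0.05])             # one float scale per channel
block_levels = np.array([[1, 2, 4, 8],             # integer level per block
                         [1, 1, 2, 4]], dtype=np.uint8)
effective_scale = channel_scale[:, None] * block_levels  # shape (2, 4)
print(effective_scale)
```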
49 changes: 45 additions & 4 deletions backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp
@@ -59,6 +59,28 @@ std::unique_ptr<QuantizeParamsWrapper> CreateQuantizationParamWrapper(
int32_t offset = quant_info["offset"].cast<int32_t>();
quantize_param_wrapper =
std::make_unique<ScaleOffsetQuantizeParamsWrapper>(scale, offset);
} else if (encoding == QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION) {
int32_t axis = quant_info["axis"].cast<int32_t>();
std::vector<Qnn_ScaleOffset_t> scale_offset =
quant_info["block_scale_offset"].cast<std::vector<Qnn_ScaleOffset_t>>();
uint32_t num_blocks_per_axis =
quant_info["num_blocks_per_axis"].cast<uint32_t>();
uint32_t block_scale_bitwidth =
quant_info["block_scale_bitwidth"].cast<uint32_t>();
Qnn_BlockwiseExpansionBlockScaleStorageType_t block_storage_type =
quant_info["block_storage_type"]
.cast<Qnn_BlockwiseExpansionBlockScaleStorageType_t>();
std::vector<uint8_t> buf =
quant_info["block_scales"].cast<std::vector<uint8_t>>();
quantize_param_wrapper =
std::make_unique<BlockwiseExpansionQuantizeParamsWrapper>(
axis,
scale_offset,
num_blocks_per_axis,
block_scale_bitwidth,
block_storage_type,
buf.data(),
buf.size());
} else {
QNN_EXECUTORCH_LOG_ERROR(
"Unknown the encoding of quantization: %d", encoding);
@@ -179,9 +201,6 @@ PYBIND11_MODULE(PyQnnWrapperAdaptor, m) {
.export_values();

py::enum_<Qnn_QuantizationEncoding_t>(m, "Qnn_QuantizationEncoding_t")
.value(
"QNN_QUANTIZATION_ENCODING_UNDEFINED",
Qnn_QuantizationEncoding_t::QNN_QUANTIZATION_ENCODING_UNDEFINED)
.value(
"QNN_QUANTIZATION_ENCODING_SCALE_OFFSET",
Qnn_QuantizationEncoding_t::QNN_QUANTIZATION_ENCODING_SCALE_OFFSET)
@@ -196,6 +215,29 @@
"QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET",
Qnn_QuantizationEncoding_t::
QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET)
.value(
"QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION",
Qnn_QuantizationEncoding_t::
QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION)
.value(
"QNN_QUANTIZATION_ENCODING_UNDEFINED",
Qnn_QuantizationEncoding_t::QNN_QUANTIZATION_ENCODING_UNDEFINED)
.export_values();

py::enum_<Qnn_BlockwiseExpansionBlockScaleStorageType_t>(
m, "Qnn_BlockwiseExpansionBlockScaleStorageType_t")
.value(
"QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8",
Qnn_BlockwiseExpansionBlockScaleStorageType_t::
QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8)
.value(
"QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_16",
Qnn_BlockwiseExpansionBlockScaleStorageType_t::
QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_16)
.value(
"QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_UNDEFINED",
Qnn_BlockwiseExpansionBlockScaleStorageType_t::
QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_UNDEFINED)
.export_values();

py::class_<OpWrapper, std::shared_ptr<OpWrapper>>(m, "OpWrapper")
@@ -476,7 +518,6 @@ PYBIND11_MODULE(PyQnnWrapperAdaptor, m) {
return std::vector<Qnn_ScaleOffset_t>(
aso.scaleOffset, aso.scaleOffset + aso.numScaleOffsets);
});
// op_wrapper.GetParams() get std::vector<ParamWrapper*>
}
} // namespace qnn
} // namespace backends