Skip to content

Qualcomm AI Engine Direct - Enable per channel linear op #2822

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions backends/qualcomm/builders/op_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ def define_node(
linear_input_tensors.append(input_tensor_wrapper)

weight_node = node.args[1]
if (
quant_attrs := weight_node.meta.get("quant_attrs")
) and "scales" in quant_attrs:
# Dimension of weight is [m, n], per channel quant params is [m]
# Change to [m, 1] to fit the tensor.div(s).add(z)
quant_attrs["scales"] = quant_attrs["scales"].reshape([-1, 1])
quant_attrs["zero_points"] = quant_attrs["zero_points"].reshape([-1, 1])

weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor_wrapper = self.define_tensor(
weight_node,
Expand All @@ -52,6 +60,12 @@ def define_node(

if len(node.args) >= 3:
bias_node = node.args[2]

# TODO remove this when qnn sdk support
if "scales" in bias_node.meta.get("quant_attrs"):
print(
f"[WARNING] Fallback linear bias, {bias_node}. per channel bias quantization is not support yet."
)
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
Expand Down
26 changes: 18 additions & 8 deletions backends/qualcomm/quantizer/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def __init__(self):
self.custom_quant_annotations: Sequence[Callable] = []
self.discard_nodes: Set[str] = set()

self.enable_per_channel_conv_quant: bool = True
self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()
# the weight quantized for activation 8 bits and 16 bits
self.per_channel_weight_dtype: Dict = {
"8bit_act": torch.int8,
Expand All @@ -290,16 +290,13 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]:
"""
Priority:
1. per channel config when enable_per_channel_conv_quant is True
1. is one of use_per_channel_weight_quant_ops
2. int8 / int16 config
"""
if type(op) == str:
return

if self.enable_per_channel_conv_quant and op in [
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
]:
if op in self.use_per_channel_weight_quant_ops:
if op in self.bit16_quant_ops:
return get_ptq_per_channel_weight_config(
torch.uint16, self.per_channel_weight_dtype["16bit_act"]
Expand All @@ -316,6 +313,12 @@ def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig

print(f"No quant config is implemented for op, {op}")

def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
if enable:
self.use_per_channel_weight_quant_ops.update(ops)
else:
self.use_per_channel_weight_quant_ops.difference(ops)

def add_16bit_quant_ops(self, ops: Set[OpOverload]) -> None:
for op in ops:
assert (
Expand Down Expand Up @@ -368,8 +371,15 @@ def set_per_channel_weight_dtype(
if weight_dtype_for_16bit_act:
self.per_channel_weight_dtype["16bit_act"] = weight_dtype_for_16bit_act

def set_per_channel_quant(self, enable: bool) -> None:
self.enable_per_channel_conv_quant = enable
def set_per_channel_conv_quant(self, enable: bool) -> None:
conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
self._update_per_channel_weight_quant_ops(conv_ops, enable)

def set_per_channel_linear_quant(self, enable: bool) -> None:
linear_ops = {
torch.ops.aten.linear.default,
}
self._update_per_channel_weight_quant_ops(linear_ops, enable)

def transform_for_annotation(self, model: GraphModule) -> GraphModule:
model = RemoveClone()(model).graph_module
Expand Down
10 changes: 5 additions & 5 deletions backends/qualcomm/quantizer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,11 +520,11 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None
)
nodes_to_mark_annotated = [node, weight_node]
if bias_node:
_annotate_input_qspec_map(
node,
bias_node,
quantization_config.bias,
)
if callable(quantization_config.bias):
bias_config = quantization_config.bias(node)
else:
bias_config = quantization_config.bias
_annotate_input_qspec_map(node, bias_node, bias_config)
nodes_to_mark_annotated.append(bias_node)
_annotate_output_qspec(node, quantization_config.output_activation)
_mark_nodes_as_annotated(nodes_to_mark_annotated)
Expand Down
4 changes: 2 additions & 2 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,9 @@ def forward(self, x):


class Linear(torch.nn.Module):
def __init__(self):
def __init__(self, use_bias: bool = True):
super().__init__()
self.linear = torch.nn.Linear(4, 5).eval()
self.linear = torch.nn.Linear(4, 5, use_bias).eval()

def forward(self, x):
return self.linear(x)
Expand Down
28 changes: 27 additions & 1 deletion backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,33 @@ def test_qnn_backend_16a4w_linear(self):
module = Linear() # noqa: F405
sample_input = (torch.randn([3, 4]),)
module = self.get_qdq_module(
module, sample_input, quant_dtype=QuantDtype.use_16a4w
module,
sample_input,
quant_dtype=QuantDtype.use_16a4w,
)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_16a4w_per_channel_linear(self):
module = Linear(use_bias=False) # noqa: F405
sample_input = (torch.randn([3, 4]),)
module = self.get_qdq_module(
module,
sample_input,
is_linear_per_channel=True,
quant_dtype=QuantDtype.use_16a4w,
)
self.lower_module_and_test_output(module, sample_input)

# Is not enabled in the current qnn sdk release
@unittest.expectedFailure
def test_qnn_backend_16a4w_per_channel_linear_with_bias(self):
module = Linear() # noqa: F405
sample_input = (torch.randn([3, 4]),)
module = self.get_qdq_module(
module,
sample_input,
is_linear_per_channel=True,
quant_dtype=QuantDtype.use_16a4w,
)
self.lower_module_and_test_output(module, sample_input)

Expand Down
4 changes: 3 additions & 1 deletion backends/qualcomm/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,16 @@ def get_qdq_module(
module: torch.nn.Module,
inputs: Tuple[torch.Tensor],
is_conv_per_channel: Optional[bool] = True,
is_linear_per_channel: Optional[bool] = False,
custom_quant_annotations: Tuple[Callable] = (),
quant_dtype: QuantDtype = QuantDtype.use_8a8w,
) -> torch.fx.GraphModule:
m = torch._export.capture_pre_autograd_graph(module, inputs)

quantizer = QnnQuantizer()
quantizer.add_custom_quant_annotations(custom_quant_annotations)
quantizer.set_per_channel_quant(is_conv_per_channel)
quantizer.set_per_channel_conv_quant(is_conv_per_channel)
quantizer.set_per_channel_linear_quant(is_linear_per_channel)

if quant_dtype == QuantDtype.use_8a8w:
pass # default setting
Expand Down