Add filter function to XNNPack Quantizer #10626

Merged · 1 commit · May 2, 2025
backends/xnnpack/quantizer/xnnpack_quantizer.py: 26 additions & 3 deletions
@@ -292,6 +292,9 @@ def __init__(self) -> None:
        ] = {}
        self.module_type_config: dict[Callable, Optional[QuantizationConfig]] = {}
        self.module_name_config: dict[str, Optional[QuantizationConfig]] = {}
+        # If specified, only quantize nodes that return True for the filter
+        # function.
+        self.filter_fn: Optional[Callable[[Node], bool]] = None

    @classmethod
    def get_supported_quantization_configs(cls) -> list[QuantizationConfig]:
@@ -355,6 +358,14 @@ def set_module_name(
        self.module_name_config[module_name] = quantization_config
        return self

+    def set_filter_function(self, filter_fn: Callable[[Node], bool]):
+        """
+        Set the filter function. We only quantize nodes that return True for
+        the filter function.
+        """
+        self.filter_fn = filter_fn
+        return self
+
    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
@@ -378,17 +389,29 @@ def _annotate_all_patterns(
        if quantization_config is None:
            return model

+        # Create a combined filter function, which returns True only when
+        # both filter_fn and self.filter_fn return True.
+        def combined_filter_fn(n: Node) -> bool:
+            combined_filter = [self.filter_fn, filter_fn]
+            return all(f(n) for f in combined_filter if f is not None)
+
        for pattern in self.SUPPORTED_PATTERNS:
            if operator_target and operator_target not in pattern.op_overloads:
                # if operator_target is specified, skip patterns that aren't
                # associated with that target
                continue
            if quantization_config.input_activation.is_dynamic and pattern.is_dynamic:
-                OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
+                OP_TO_ANNOTATOR[pattern.name](
+                    model, quantization_config, combined_filter_fn
+                )
            elif quantization_config.is_qat and pattern.is_qat:
-                OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
+                OP_TO_ANNOTATOR[pattern.name](
+                    model, quantization_config, combined_filter_fn
+                )
            elif not quantization_config.input_activation.is_dynamic:
-                OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
+                OP_TO_ANNOTATOR[pattern.name](
+                    model, quantization_config, combined_filter_fn
+                )

        return model

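A minimal usage sketch of the new API (not part of this PR), assuming the executorch import path; skip_embeddings is a hypothetical predicate. Nodes are annotated for quantization only when the filter returns True, and set_filter_function returns self so it chains with the other setters.

import torch
from torch.fx import Node

from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# Hypothetical filter: quantize every node except those whose names suggest
# they belong to an embedding subgraph.
def skip_embeddings(n: Node) -> bool:
    return "embedding" not in n.name

quantizer = (
    XNNPACKQuantizer()
    .set_global(get_symmetric_quantization_config(is_per_channel=True))
    .set_filter_function(skip_embeddings)
)
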
backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py: 30 additions & 0 deletions
@@ -297,6 +297,36 @@ def test_obs_sharing_ops(self):
        ]
        self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list)

+    def test_set_filter_fn(self):
+        quantizer = XNNPACKQuantizer()
+        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+        quantizer.set_global(quantization_config)
+        m_eager = TestHelperModules.TwoLinearModule().eval()
+
+        # Set the filter function so that the second linear is not quantized
+        def filter_fn(n):
+            return n.name != "linear_1"
+
+        quantizer.set_filter_function(filter_fn)
+
+        # Test with 2d inputs
+        example_inputs_2d = (torch.randn(9, 8),)
+        node_occurrence = {
+            # input and output of the first linear op will be (de)quantized
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+            # quantize_per_channel for weights are const propagated
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+            # weight for the first linear will be dequantized
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+        }
+        self._test_quantizer(
+            m_eager,
+            example_inputs_2d,
+            quantizer,
+            node_occurrence,
+        )
+
    def test_set_module_name(self):
        class Sub(torch.nn.Module):
            def __init__(self) -> None:
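To see the filter end to end, a minimal sketch (not part of this PR) of the PT2E flow that _test_quantizer wraps, assuming the torch.ao.quantization.quantize_pt2e entry points; TwoLinear is a hypothetical stand-in for TestHelperModules.TwoLinearModule, which is internal to the test suite.

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# Stand-in for the helper module used by the test: two stacked linears, so
# the exported graph contains nodes named "linear" and "linear_1".
class TwoLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(8, 16)
        self.linear2 = torch.nn.Linear(16, 8)

    def forward(self, x):
        return self.linear2(self.linear1(x))

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
quantizer.set_filter_function(lambda n: n.name != "linear_1")  # skip 2nd linear

example_inputs = (torch.randn(9, 8),)
# Export to an FX graph, annotate with the filtered quantizer, run one
# calibration pass, then convert; only nodes passing the filter are quantized.
m = torch.export.export_for_training(TwoLinear().eval(), example_inputs).module()
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m)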