Arm backend: Replace asserts with exceptions in quantizer module #11519

Merged: 1 commit, Jun 10, 2025
6 changes: 3 additions & 3 deletions backends/arm/quantizer/arm_quantizer.py
@@ -247,9 +247,9 @@ def set_module_name(
        quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator
        patterns in the submodule with this module name with the given `quantization_config`
        """
-        assert (
-            quantization_config is not None
-        ), " quantization_config == None is not supported yet"
+        # Validate that quantization_config is provided
+        if quantization_config is None:
+            raise ValueError("quantization_config == None is not supported yet")
        self.module_name_config[module_name] = quantization_config
        return self

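For context, here is a minimal, self-contained sketch of the changed guard. The class below is a hypothetical stand-in for illustration, not the repo's quantizer class; only the guard and the error message come from this diff.

class _QuantizerSketch:
    # Hypothetical reduced stand-in showing only the guard added in set_module_name.

    def __init__(self):
        self.module_name_config = {}

    def set_module_name(self, module_name, quantization_config):
        # The replaced assert was skipped under `python -O`; this explicit
        # raise always fires, matching the new behavior in the diff.
        if quantization_config is None:
            raise ValueError("quantization_config == None is not supported yet")
        self.module_name_config[module_name] = quantization_config
        return self

quantizer = _QuantizerSketch()
quantizer.set_module_name("blocks.sub", {"dummy": "config"})  # accepted
try:
    quantizer.set_module_name("blocks.sub", None)  # now raises ValueError
except ValueError as err:
    print(err)

Callers that relied on catching AssertionError around this check would now need to catch ValueError instead.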
40 changes: 26 additions & 14 deletions backends/arm/quantizer/quantization_config.py
@@ -29,30 +29,40 @@ def get_input_act_qspec(self) -> QuantizationSpec | None:
        """Returns QuantizationSpec 'input_activation' after asserting that input_activation.qscheme is valid."""
        if self.input_activation is None:
            return None
-        assert self.input_activation.qscheme in [
+        # Validate that input_activation uses a supported qscheme
+        if self.input_activation.qscheme not in [
            torch.per_tensor_affine,
            torch.per_tensor_symmetric,
-        ], f"Unsupported quantization_spec {self.input_activation} for input_activation."
+        ]:
+            raise ValueError(
+                f"Unsupported quantization_spec {self.input_activation} for input_activation."
+            )
        return self.input_activation

    def get_output_act_qspec(self) -> QuantizationSpec | None:
        """Returns QuantizationSpec 'output_activation' after asserting that output_activation.qscheme is valid."""
        if self.output_activation is None:
            return None
-        assert self.output_activation.qscheme in [
+        # Validate that output_activation uses a supported qscheme
+        if self.output_activation.qscheme not in [
            torch.per_tensor_affine,
            torch.per_tensor_symmetric,
-        ], f"Unsupported quantization_spec {self.output_activation} for output_activation."
+        ]:
+            raise ValueError(
+                f"Unsupported quantization_spec {self.output_activation} for output_activation."
+            )
        return self.output_activation

    def get_weight_qspec(self) -> QuantizationSpec | None:
        """Returns QuantizationSpec 'weight' after asserting that weight.qscheme is valid."""
        if self.weight is None:
            return None
-        assert self.weight.qscheme in [
+        # Validate that weight uses a supported qscheme
+        if self.weight.qscheme not in [
            torch.per_tensor_symmetric,
            torch.per_channel_symmetric,
-        ], f"Unsupported quantization_spec {self.weight} for weight"
+        ]:
+            raise ValueError(f"Unsupported quantization_spec {self.weight} for weight")
        return self.weight

    def get_bias_qspec(self, node: torch.fx.Node) -> QuantizationSpec | None:
@@ -61,11 +71,11 @@ def get_bias_qspec(self, node: torch.fx.Node) -> QuantizationSpec | None:
        def _derive_qparams_fn(
            obs_or_fqs: list[ObserverOrFakeQuantize],
        ) -> tuple[torch.Tensor, torch.Tensor]:
-            assert (
-                len(obs_or_fqs) == 2
-            ), "Expecting two obs/fqs, one for activation and one for weight, got: {}".format(
-                len(obs_or_fqs)
-            )
+            # Validate expected number of observers/fake-quantizes
+            if len(obs_or_fqs) != 2:
+                raise ValueError(
+                    f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
+                )
            act_obs_or_fq = obs_or_fqs[0]
            weight_obs_or_fq = obs_or_fqs[1]
            act_scale, act_zp = act_obs_or_fq.calculate_qparams()
@@ -94,9 +104,11 @@ def _derive_qparams_fn(

        if self.bias is None:
            return None
-        assert (
-            self.bias.dtype == torch.float
-        ), "Only float dtype for bias is supported for bias right now"
+        # Validate that bias dtype is floating-point
+        if self.bias.dtype != torch.float:
+            raise ValueError(
+                "Only float dtype for bias is supported for bias right now"
+            )
        return self.bias

    def get_fixed_qspec(
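Similarly, a self-contained sketch of the qscheme guard these getters now apply. `_SpecSketch` below is a hypothetical stand-in for `QuantizationSpec`, reduced to the single attribute being checked; only the accepted qscheme list and the error text come from this diff.

import torch
from dataclasses import dataclass

@dataclass
class _SpecSketch:
    # Hypothetical stand-in for QuantizationSpec, keeping only the checked field.
    qscheme: torch.qscheme

def get_input_act_qspec_sketch(spec):
    # Mirrors the new guard: unsupported qschemes raise ValueError instead of asserting.
    if spec is None:
        return None
    if spec.qscheme not in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
        raise ValueError(f"Unsupported quantization_spec {spec} for input_activation.")
    return spec

print(get_input_act_qspec_sketch(_SpecSketch(torch.per_tensor_affine)))    # accepted
try:
    get_input_act_qspec_sketch(_SpecSketch(torch.per_channel_symmetric))   # rejected
except ValueError as err:
    print(err)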