Commit 5dcba60

Merge branch 'main' into Arm-backend-Bump-cortex-m-size-test-threshold

2 parents: a2868fb + 954f0e9

File tree: 9 files changed, +330 −115 lines

backends/arm/_passes/arm_pass_manager.py (6 additions, 2 deletions)

```diff
@@ -62,7 +62,10 @@
     UnsqueezeScalarPlaceholdersPass,
 )
 
-from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.arm.tosa_specification import (
+    TosaLoweringContext,
+    TosaSpecification,
+)
 from executorch.backends.transforms.decompose_sdpa import (
     DecomposeScaledDotProductAttention,
 )
@@ -80,7 +83,8 @@ def __init__(self, tosa_spec: TosaSpecification) -> None:
         super().__init__()
 
     def _transform(self, graph_module: GraphModule):
-        return self(graph_module).graph_module
+        with TosaLoweringContext(self.tosa_spec):
+            return self(graph_module).graph_module
 
     def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
        self.add_pass(FuseQuantizedActivationPass())
```
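
The practical effect of the `_transform` change is that every pass run by the pass manager now executes inside a `TosaLoweringContext`, so a pass can look up the active TOSA spec instead of having it threaded through as an argument. A minimal sketch of what this enables; `hypothetical_pass` and the spec string are illustrative assumptions, and `create_from_string` is assumed to be the module's existing factory method, not something added by this commit:

```python
# Sketch only: code running inside the with-block can fetch the active
# spec via get_context_spec() instead of receiving it as a parameter.
from executorch.backends.arm.tosa_specification import (
    TosaLoweringContext,
    TosaSpecification,
    get_context_spec,
)


def hypothetical_pass() -> bool:
    # Raises RuntimeError if called outside a TosaLoweringContext.
    spec = get_context_spec()
    return spec.support_integer()


# "TOSA-0.80+BI" is an assumed example spec string.
spec = TosaSpecification.create_from_string("TOSA-0.80+BI")
with TosaLoweringContext(spec):
    print(hypothetical_pass())
```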

backends/arm/quantizer/arm_quantizer.py (3 additions, 3 deletions)

```diff
@@ -247,9 +247,9 @@ def set_module_name(
         quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator
         patterns in the submodule with this module name with the given `quantization_config`
         """
-        assert (
-            quantization_config is not None
-        ), " quantization_config == None is not supported yet"
+        # Validate that quantization_config is provided
+        if quantization_config is None:
+            raise ValueError("quantization_config == None is not supported yet")
         self.module_name_config[module_name] = quantization_config
         return self
 
```
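
The motivation for swapping `assert` for an explicit `raise` here (and throughout quantization_config.py below) is that assertions are stripped when Python runs with optimizations enabled, so the check would silently vanish. A minimal illustration of the difference, not code from this commit:

```python
# Under `python -O`, __debug__ is False and assert statements are
# compiled away, so set_cfg_assert(None) would return None unchecked.
def set_cfg_assert(cfg):
    assert cfg is not None, "cfg == None is not supported"  # stripped by -O
    return cfg


# The explicit check is enforced in every interpreter mode.
def set_cfg_raise(cfg):
    if cfg is None:
        raise ValueError("cfg == None is not supported")
    return cfg
```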

backends/arm/quantizer/quantization_config.py (26 additions, 14 deletions)

```diff
@@ -29,30 +29,40 @@ def get_input_act_qspec(self) -> QuantizationSpec | None:
         """Returns QuantizationSpec 'input_activation' after asserting that input_activation.qscheme is valid."""
         if self.input_activation is None:
             return None
-        assert self.input_activation.qscheme in [
+        # Validate that input_activation uses a supported qscheme
+        if self.input_activation.qscheme not in [
             torch.per_tensor_affine,
             torch.per_tensor_symmetric,
-        ], f"Unsupported quantization_spec {self.input_activation} for input_activation."
+        ]:
+            raise ValueError(
+                f"Unsupported quantization_spec {self.input_activation} for input_activation."
+            )
         return self.input_activation
 
     def get_output_act_qspec(self) -> QuantizationSpec | None:
         """Returns QuantizationSpec 'output_activation' after asserting that output_activation.qscheme is valid."""
         if self.output_activation is None:
             return None
-        assert self.output_activation.qscheme in [
+        # Validate that output_activation uses a supported qscheme
+        if self.output_activation.qscheme not in [
             torch.per_tensor_affine,
             torch.per_tensor_symmetric,
-        ], f"Unsupported quantization_spec {self.output_activation} for output_activation."
+        ]:
+            raise ValueError(
+                f"Unsupported quantization_spec {self.output_activation} for output_activation."
+            )
         return self.output_activation
 
     def get_weight_qspec(self) -> QuantizationSpec | None:
         """Returns QuantizationSpec 'weight' after asserting that weight.qscheme is valid."""
         if self.weight is None:
             return None
-        assert self.weight.qscheme in [
+        # Validate that weight uses a supported qscheme
+        if self.weight.qscheme not in [
             torch.per_tensor_symmetric,
             torch.per_channel_symmetric,
-        ], f"Unsupported quantization_spec {self.weight} for weight"
+        ]:
+            raise ValueError(f"Unsupported quantization_spec {self.weight} for weight")
         return self.weight
 
     def get_bias_qspec(self, node: torch.fx.Node) -> QuantizationSpec | None:
@@ -61,11 +71,11 @@ def get_bias_qspec(self, node: torch.fx.Node) -> QuantizationSpec | None:
         def _derive_qparams_fn(
             obs_or_fqs: list[ObserverOrFakeQuantize],
         ) -> tuple[torch.Tensor, torch.Tensor]:
-            assert (
-                len(obs_or_fqs) == 2
-            ), "Expecting two obs/fqs, one for activation and one for weight, got: {}".format(
-                len(obs_or_fqs)
-            )
+            # Validate expected number of observers/fake-quantizes
+            if len(obs_or_fqs) != 2:
+                raise ValueError(
+                    f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
+                )
             act_obs_or_fq = obs_or_fqs[0]
             weight_obs_or_fq = obs_or_fqs[1]
             act_scale, act_zp = act_obs_or_fq.calculate_qparams()
@@ -94,9 +104,11 @@ def _derive_qparams_fn(
 
         if self.bias is None:
             return None
-        assert (
-            self.bias.dtype == torch.float
-        ), "Only float dtype for bias is supported for bias right now"
+        # Validate that bias dtype is floating-point
+        if self.bias.dtype != torch.float:
+            raise ValueError(
+                "Only float dtype for bias is supported for bias right now"
+            )
         return self.bias
 
     def get_fixed_qspec(
```
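
For context on the `_derive_qparams_fn` hunk: the two observers it validates supply the activation and weight quantization parameters from which the bias parameters are derived. The conventional derivation multiplies the two scales and uses a zero-valued int32 zero-point; the sketch below assumes that convention, since the actual arithmetic sits outside the changed lines:

```python
import torch


# Assumed conventional bias qparam derivation for quantized conv/linear
# (not quoted from this diff): bias scale is the product of the activation
# and weight scales, with a zero-point of 0 at int32 precision.
def derive_bias_qparams(
    act_scale: torch.Tensor, weight_scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    bias_scale = act_scale * weight_scale
    bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
    return bias_scale, bias_zero_point
```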

backends/arm/runtime/EthosUBackend.cpp (10 additions, 42 deletions)

```diff
@@ -261,24 +261,12 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
           event_tracer,
           "+EthosUBackend::execute()handles.input.permute_CHW_to_HWC()");
       // permuted byte copy CHW to HWC
-      int c, h, w;
-      if (tensor_in.dim() == 4) {
-        c = tensor_in.size(1);
-        h = tensor_in.size(2);
-        w = tensor_in.size(3);
-      } else if (tensor_in.dim() == 5) {
-        c = tensor_in.size(2);
-        h = tensor_in.size(3);
-        w = tensor_in.size(4);
-      } else {
-        ET_LOG(
-            Error,
-            "Unsupported input tensor dimension %d, expected 4 or 5",
-            tensor_in.dim());
-        return Error::InvalidProgram;
-      }
       permute_CHW_to_HWC(
-          tensor_in.mutable_data_ptr<char>(), scratch_addr, c, h, w);
+          tensor_in.mutable_data_ptr<char>(),
+          scratch_addr,
+          tensor_in.size(1),
+          tensor_in.size(2),
+          tensor_in.size(3));
     } else if (both_char or both_int or both_short) {
       EXECUTORCH_PROF_SCOPE(
           event_tracer, "+EthosUBackend::execute()handles.input.memcpy()");
@@ -376,24 +364,12 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
           "+EthosUBackend::execute()handles.output.permute_HWC_to_CHW()");
 
       char* output_address = (char*)output_addr;
-      int c, h, w;
-      if (tensor_out.dim() == 4) {
-        c = tensor_out.size(1);
-        h = tensor_out.size(2);
-        w = tensor_out.size(3);
-      } else if (tensor_out.dim() == 5) {
-        c = tensor_out.size(2);
-        h = tensor_out.size(3);
-        w = tensor_out.size(4);
-      } else {
-        ET_LOG(
-            Error,
-            "Unsupported output tensor dimension %d, expected 4 or 5",
-            tensor_out.dim());
-        return Error::InvalidProgram;
-      }
       permute_HWC_to_CHW(
-          output_address, tensor_out.mutable_data_ptr<char>(), c, h, w);
+          output_address,
+          tensor_out.mutable_data_ptr<char>(),
+          tensor_out.size(1),
+          tensor_out.size(2),
+          tensor_out.size(3));
     } else {
       EXECUTORCH_PROF_SCOPE(
           event_tracer, "+EthosUBackend::execute()handles.output.move()");
@@ -454,14 +430,6 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
       if (permuted_shape) {
         ET_LOG(Debug, "Tensor input/output %d will be permuted", index);
       }
-    } else if (tensor.dim() == 5) {
-      // Same as above, but for 5D tensors.
-      permuted_shape = tensor.size(0) == io->shape[0] &&
-          tensor.size(1) == io->shape[1] && tensor.size(2) == io->shape[4] &&
-          tensor.size(3) == io->shape[2] && tensor.size(4) == io->shape[3];
-      if (permuted_shape) {
-        ET_LOG(Debug, "Tensor input/output %d will be permuted", index);
-      }
     }
     *is_permuted = permuted_shape;
     return Error::Ok;
```
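
The simplification works because the 5D handling is removed along with the matching permuted-shape check in the last hunk, leaving 4D (NCHW) as the only permuted layout, so channel, height, and width always sit at sizes 1, 2, and 3. For readers unfamiliar with the layout shuffle, here is a NumPy sketch of what `permute_CHW_to_HWC` computes (illustrative only; the runtime version is a raw byte copy in C++):

```python
import numpy as np


# (C, H, W) -> (H, W, C): the element at [c, h, w] moves to [h, w, c].
def permute_chw_to_hwc(src: np.ndarray) -> np.ndarray:
    assert src.ndim == 3, "expects a single CHW image plane"
    return np.ascontiguousarray(src.transpose(1, 2, 0))


chw = np.arange(2 * 3 * 4, dtype=np.int8).reshape(2, 3, 4)
hwc = permute_chw_to_hwc(chw)
assert hwc[1, 2, 0] == chw[0, 1, 2]
```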

backends/arm/tosa_specification.py (32 additions, 0 deletions)

```diff
@@ -11,6 +11,7 @@
 # JIT compiler flows.
 #
 
+import contextvars
 import re
 from typing import List
 
@@ -214,3 +215,34 @@ def support_integer(self):
 
     def support_float(self):
         return "FP" in self.profiles
+
+
+class TosaLoweringContext:
+    """
+    A context manager to handle the TOSA specific aspects of the lowering process.
+    For now it only handles the TOSA specification context, but it can be extended
+    to include other policies or configurations.
+    """
+
+    # Define a context variable for the spec
+    tosa_spec_var: contextvars.ContextVar = contextvars.ContextVar("tosa_spec")
+
+    def __init__(self, spec: TosaSpecification):
+        self.spec = spec
+
+    def __enter__(self):
+        # Set the spec in the context variable and store the token for later reset
+        self.token = TosaLoweringContext.tosa_spec_var.set(self.spec)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # Reset the context variable to its previous state
+        TosaLoweringContext.tosa_spec_var.reset(self.token)
+
+
+# A helper function to retrieve the current spec anywhere in your code
+def get_context_spec() -> TosaSpecification:
+    try:
+        return TosaLoweringContext.tosa_spec_var.get()
+    except LookupError:
+        raise RuntimeError("Function must be executed within a TosaLoweringContext")
```
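
The class leans on the stdlib `contextvars` token protocol: `set()` returns a token that remembers the previous value, and `reset(token)` restores it, which makes nested contexts behave like a stack. A standalone sketch of that mechanic (pure stdlib, no executorch imports):

```python
import contextvars

var: contextvars.ContextVar = contextvars.ContextVar("spec")

outer = var.set("outer-spec")
inner = var.set("inner-spec")  # a nested set shadows the outer value
assert var.get() == "inner-spec"

var.reset(inner)  # restores the shadowed value
assert var.get() == "outer-spec"

var.reset(outer)  # back to unset: get() now raises LookupError
try:
    var.get()
except LookupError:
    print("unset, as before the first set()")
```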

backends/qualcomm/targets.bzl (1 addition, 0 deletions)

```diff
@@ -92,4 +92,5 @@ def define_common_targets():
         exported_deps = [
             ":schema",
         ],
+        platforms = [ANDROID],
     )
```

examples/apple/mps/scripts/mps_example.py (1 addition, 1 deletion)

```diff
@@ -145,7 +145,7 @@ def get_model_config(args):
     return model_config
 
 
-if __name__ == "__main__":
+if __name__ == "__main__":  # noqa: C901
     args = parse_args()
 
     if args.model_name not in MODEL_NAME_TO_MODEL:
```
