
Commit 70c5be3

freddan80 authored and facebook-github-bot committed

Cut dependencies and clean up Arm backend unit tester (#2231)

Summary:
bypass-github-pytorch-ci-checks
bypass-github-export-checks

Pull Request resolved: #2231
Reviewed By: mergennachin
Differential Revision: D54640970
Pulled By: digantdesai
fbshipit-source-id: 5bab38b60cff1ceb74d1a0b06694e240af1ba9d1

1 parent 47d2737 · commit 70c5be3

File tree

3 files changed (+155, -142 lines)

backends/arm/test/ops/test_add.py

Lines changed: 1 addition & 3 deletions

@@ -88,7 +88,7 @@ def _test_add_tosa_BI_pipeline(
             .to_executorch()
         )
         if TOSA_REF_MODEL_INSTALLED:
-            tester.run_method().compare_outputs()
+            tester.run_method().compare_outputs(qtol=1)
         else:
             logger.warning(
                 "TOSA ref model tool not installed, skip numerical correctness tests"
@@ -118,8 +118,6 @@ def test_add_tosa_MI(self):
         test_data = (torch.randn(4, 4, 4),)
         self._test_add_tosa_MI_pipeline(self.Add(), test_data)
 
-    # TODO: Will this type of parametrization be supported? pytest seem
-    # have issue with it.
     @parameterized.expand(
        [
            (torch.ones(5),),  # test_data
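
Note on the change above: qtol=1 lets the quantized (BI) comparison tolerate a difference of one quantization step. For orientation, a minimal sketch of how such a BI pipeline is driven is shown below; the ArmTester name, its constructor arguments, and the exact stage chain are assumptions inferred from the arm_tester.py diff further down, not copied from this file.

import logging

import torch

# Import paths follow the file layout in this commit; the ArmTester name and
# its constructor signature are assumptions based on the arm_tester.py diff.
from executorch.backends.arm.test.tester.arm_tester import (
    ArmBackendSelector,
    ArmTester,
)
from executorch.backends.arm.test.tosautil.tosa_test_utils import TosaProfile

logger = logging.getLogger(__name__)
TOSA_REF_MODEL_INSTALLED = False  # placeholder; the real test probes for the tool


def run_add_bi_pipeline(module: torch.nn.Module, test_data: tuple):
    tester = (
        ArmTester(
            module,
            inputs=test_data,
            profile=TosaProfile.BI,
            backend=ArmBackendSelector.TOSA,
        )
        .quantize()
        .export()
        .to_edge()
        .partition()
        .to_executorch()
    )
    if TOSA_REF_MODEL_INSTALLED:
        # qtol=1: allow results to differ by one quantization step
        tester.run_method().compare_outputs(qtol=1)
    else:
        logger.warning(
            "TOSA ref model tool not installed, skip numerical correctness tests"
        )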

backends/arm/test/tester/arm_tester.py

Lines changed: 111 additions & 41 deletions

@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from enum import Enum
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 from executorch.backends.arm.arm_backend import (
@@ -15,6 +15,7 @@
 from executorch.backends.arm.arm_partitioner import ArmPartitioner
 
 from executorch.backends.arm.test.tosautil.tosa_test_utils import (
+    QuantizationParams,
     TosaProfile,
     TosaTestUtils,
 )
@@ -32,6 +33,7 @@
     get_symmetric_quantization_config,
     XNNPACKQuantizer,
 )
+from torch.export import ExportedProgram
 
 
 class ArmBackendSelector(Enum):
@@ -61,6 +63,7 @@ def __init__(
             TosaProfile.BI or TosaProfile.MI
         """
         self.tosa_test_util = None
+        self.is_quantized = profile == TosaProfile.BI
         if backend == ArmBackendSelector.TOSA:
             self.tosa_test_util = TosaTestUtils(profile=profile)
             # The spec below tiggers arm_backend.py to output two files:
@@ -119,54 +122,121 @@ def run_method(
         ), "self.tosa_test_util is not initialized, cannot use run_method()"
         inputs_to_run = inputs or self.inputs
 
-        # TODO: we can't possible need to use all these stages??
-        export_stage = self.stages[
-            self.stage_name(Export)
-        ]  # this is what XNNpack use to get quant params
-        toedge_stage = self.stages[
-            self.stage_name(ToEdge)
-        ]  # this is what get_input_quantization_params use to get quant params
-        partition_stage = self.stages[
-            self.stage_name(Partition)
-        ]  # this is what tosa_ref_dump_inputs use....
-
-        # TODO: I'd prefer to use this TOSA buffer instead of output.tosa,
-        # generated by arm_backend.py. The issue is that we're still depending
-        # on desc.json, which is created from TosaSerializer class, not from
-        # the serialized TOSA buffer. Leave this here for review purposes.
-        # ts_serialized = self._get_serialized_tosa_buffer(  # unused
-        #     partition_stage.artifact
-        # )
-
-        # This is where the torch reference output is calculated and set
-        # TODO: This sets self.quantization_scale, which is duplicates
-        # self.tosa_test_util.quantization.output.scales (?). Fixme.
-        (
-            self.reference_output,
-            self.quantization_scale,
-        ) = self._calculate_reference_output(export_stage.artifact, inputs_to_run)
-
-        # Convert the torch inputs to something TOSA ref model can use
-        tensor_names_and_inputs_np = self.tosa_test_util.convert_inputs_to_tosa(
-            partition_stage.artifact, toedge_stage.artifact, inputs_to_run
+        export_stage = self.stages[self.stage_name(Export)]
+
+        (input_names, qp_input) = self._get_input_params(export_stage.artifact)
+        (output_name, qp_output) = self._get_output_param(export_stage.artifact)
+
+        # Calculate the reference output using the original module or the quant
+        # module. self.quantization_scale is used by compare_outputs() to
+        # calculate the tolerance
+        self.quantization_scale = None if qp_output is None else qp_output.scale
+        if self.is_quantized:
+            module_for_ref = self.stages[self.stage_name(Quantize)].artifact
+        else:
+            module_for_ref = self.original_module
+        self.reference_output = self._calculate_reference_output(
+            module_for_ref, inputs_to_run
         )
 
         # Run the TOSA ref model to get the output tensor, which will be
         # compared to the torch output in compare_outputs()
         self.stage_output = self.tosa_test_util.run_tosa_ref_model(
-            tensor_names_and_inputs_np
+            params_input=(input_names, qp_input),
+            param_output=(output_name, qp_output),
+            inputs=inputs_to_run,
         )
 
         return self
 
-    def _get_serialized_tosa_buffer(self, partition_stage: Partition) -> bytes:
+    def _get_input_params(
+        self, program: ExportedProgram
+    ) -> Tuple[str, Union[List[QuantizationParams], List[None]]]:
         """
-        This is just a prototype...
-        Todo:
-            * The "_0" indicates that there are many lowered modules. Loop it!
-            * There's probably a better way to get this buffer. An API? Yes,
-              it seems the serialize stage does this for you...
+        Get the names and, optionally, the quantization parameters of the
+        inputs to this model.
+
+        Args:
+            program (ExportedProgram): The program to get input parameters from
+        Returns:
+            Tuple[str, Optional[QuantizationParams]]: A tuple containing the
+            input node names and their quantization parameters.
+        """
+        input_names = []
+        # E.g. bias and weights are 'placeholders' as well. This is used to
+        # get only the user inputs.
+        usr_inputs = program.graph_signature.user_inputs
+        for node in program.graph.nodes:
+            if node.op == "placeholder" and node.name in usr_inputs:
+                input_names.append(node.name)
+                continue
+
+        if self.is_quantized:
+            quant_params = []
+            for node in program.graph.nodes:
+                if (
+                    node.target
+                    == torch.ops.quantized_decomposed.quantize_per_tensor.default
+                    and node.args[0].name in input_names
+                ):
+                    qp = QuantizationParams(
+                        node_name=node.args[0].name, scale=node.args[1], zp=node.args[2]
+                    )
+                    quant_params.append(qp)
+                    if len(quant_params) == len(
+                        input_names
+                    ):  # break early if we have all the inputs' quantization parameters
+                        break
+            assert len(quant_params) != 0, "Quantization parameters not found"
+            return (input_names, quant_params)
+        else:
+            return (input_names, len(input_names) * [None])  # return a list of None's
+
+    def _get_output_param(
+        self, program: ExportedProgram
+    ) -> Tuple[str, Union[QuantizationParams, None]]:
         """
-        return partition_stage._edge_programs[
-            "forward"
-        ]._graph_module.lowered_module_0.processed_bytes
+        Get the name and, optionally, the quantization parameters of the
+        output of this model.
+
+        Args:
+            program (ExportedProgram): The program to get output parameters from.
+        Returns:
+            Tuple[str, Optional[QuantizationParams]]: A tuple containing the
+            output node name and its quantization parameters.
+        """
+        output_node = None
+        for node in program.graph.nodes:
+            if node.op == "output":
+                output_node = node
+                break
+
+        if self.is_quantized:
+            quant_params = None
+            for node in program.graph.nodes:
+                if (
+                    node.target
+                    == torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                    and node == output_node.args[0][0]
+                ):
+                    quant_params = QuantizationParams(
+                        node_name=node.args[0].name, scale=node.args[1], zp=node.args[2]
+                    )
+                    break  # break early, there's only one output node
+            assert quant_params is not None, "Quantization parameters not found"
+            return (output_node.name, quant_params)
+        else:
+            return (output_node.name, None)
+
+    @staticmethod
+    def _calculate_reference_output(
+        module: Union[torch.fx.GraphModule, torch.nn.Module], inputs
+    ) -> torch.Tensor:
+        """
+        Note: I'd prefer to use the base class method here, but since it uses
+        the exported program, I can't. The partitioner stage clears the
+        state_dict of the exported program, which causes an issue when
+        evaluating the module.
+        """
+
+        return module.forward(*inputs)
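
As the comments in this diff note, compare_outputs() derives its tolerance from the output scale collected here. Below is a self-contained sketch of the affine quantization arithmetic behind QuantizationParams(node_name, scale, zp); the class in the sketch is a stand-in mirroring only the fields used in the diff, not the repo's actual implementation.

from dataclasses import dataclass

import torch


# Stand-in for the QuantizationParams imported from tosa_test_utils in the
# diff above; only the fields used there (node_name, scale, zp) are mirrored.
@dataclass
class QuantizationParams:
    node_name: str
    scale: float
    zp: int


def quantize(x: torch.Tensor, qp: QuantizationParams) -> torch.Tensor:
    # Affine quantization: q = round(x / scale) + zp
    return torch.round(x / qp.scale) + qp.zp


def dequantize(q: torch.Tensor, qp: QuantizationParams) -> torch.Tensor:
    # Inverse mapping: x ~= (q - zp) * scale
    return (q - qp.zp) * qp.scale


if __name__ == "__main__":
    qp = QuantizationParams(node_name="x", scale=0.05, zp=128)
    x = torch.randn(4)
    roundtrip = dequantize(quantize(x, qp), qp)
    # Rounding error is at most scale / 2 per element, so a tolerance of one
    # quantization step (qtol=1 in test_add.py, i.e. atol = 1 * scale) passes.
    assert torch.allclose(x, roundtrip, atol=1 * qp.scale)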
