Skip to content

Commit 38d04b6

Browse files
digantdesaifacebook-github-bot
authored andcommitted
Update compare_output medhod for tester (#3016)
Summary: Method name update Reviewed By: mcr229 Differential Revision: D56072265
1 parent 057e432 commit 38d04b6

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

backends/arm/test/tester/arm_tester.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def to_edge(self, to_edge_stage: Optional[ToEdge] = None):
128128
to_edge_stage = ToEdge(EdgeCompileConfig(_check_ir_validity=False))
129129
return super().to_edge(to_edge_stage)
130130

131-
def partition(self, partition_stage: Optional[Partition] = None):
131+
def partition(self, partition_stage: Optional[Partition] = None): # pyre-ignore
132132
if partition_stage is None:
133133
arm_partitioner = ArmPartitioner(compile_spec=self.compile_spec)
134134
partition_stage = Partition(arm_partitioner)
@@ -196,6 +196,34 @@ def run_method(
196196

197197
return self
198198

199+
def compare_outputs(self, atol=1e-03, rtol=1e-03, qtol=0):
200+
"""
201+
Compares the original of the original nn module with the output of the generated artifact.
202+
This requres calling run_method before calling compare_outputs. As that runs the generated
203+
artifact on the sample inputs and sets the stage output to be compared against the reference.
204+
"""
205+
assert self.reference_output is not None
206+
assert self.stage_output is not None
207+
208+
# Wrap both outputs as tuple, since executor output is always a tuple even if single tensor
209+
if isinstance(self.reference_output, torch.Tensor):
210+
self.reference_output = (self.reference_output,)
211+
if isinstance(self.stage_output, torch.Tensor):
212+
self.stage_output = (self.stage_output,)
213+
214+
# If a qtol is provided and we found an dequantization node prior to the output, relax the
215+
# atol by qtol quant units.
216+
if self.quantization_scale is not None:
217+
atol += self.quantization_scale * qtol
218+
219+
self._assert_outputs_equal(
220+
self.stage_output,
221+
self.reference_output,
222+
atol=atol,
223+
rtol=rtol,
224+
)
225+
return self
226+
199227
def _get_input_params(
200228
self, program: ExportedProgram
201229
) -> Tuple[str, Union[List[QuantizationParams], List[None]]]:

backends/xnnpack/test/tester/tester.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,22 @@ def check_node_count(self, input: Dict[Any, int]):
534534

535535
return self
536536

537+
def run_method(
538+
self, stage: Optional[str] = None, inputs: Optional[Tuple[torch.Tensor]] = None
539+
):
540+
# This is to avoid accidental ommition of compare_outputs resulting in
541+
# false positive of the test passing.
542+
raise NotImplementedError(
543+
"run_method is deprecated, please use run_method_and_compare_outputs"
544+
)
545+
546+
def compare_outputs(self, atol=1e-03, rtol=1e-03, qtol=0):
547+
# This is to avoid accidental ommition of compare_outputs resulting in
548+
# false positive of the test passing.
549+
raise NotImplementedError(
550+
"compare_outputs is deprecated, please use run_method_and_compare_outputs"
551+
)
552+
537553
def run_method_and_compare_outputs(
538554
self,
539555
stage: Optional[str] = None,

0 commit comments

Comments
 (0)