@@ -128,7 +128,7 @@ def to_edge(self, to_edge_stage: Optional[ToEdge] = None):
128
128
to_edge_stage = ToEdge (EdgeCompileConfig (_check_ir_validity = False ))
129
129
return super ().to_edge (to_edge_stage )
130
130
131
- def partition (self , partition_stage : Optional [Partition ] = None ):
131
+ def partition (self , partition_stage : Optional [Partition ] = None ): # pyre-ignore
132
132
if partition_stage is None :
133
133
arm_partitioner = ArmPartitioner (compile_spec = self .compile_spec )
134
134
partition_stage = Partition (arm_partitioner )
@@ -196,6 +196,34 @@ def run_method(
196
196
197
197
return self
198
198
199
+ def compare_outputs (self , atol = 1e-03 , rtol = 1e-03 , qtol = 0 ):
200
+ """
201
+ Compares the original of the original nn module with the output of the generated artifact.
202
+ This requres calling run_method before calling compare_outputs. As that runs the generated
203
+ artifact on the sample inputs and sets the stage output to be compared against the reference.
204
+ """
205
+ assert self .reference_output is not None
206
+ assert self .stage_output is not None
207
+
208
+ # Wrap both outputs as tuple, since executor output is always a tuple even if single tensor
209
+ if isinstance (self .reference_output , torch .Tensor ):
210
+ self .reference_output = (self .reference_output ,)
211
+ if isinstance (self .stage_output , torch .Tensor ):
212
+ self .stage_output = (self .stage_output ,)
213
+
214
+ # If a qtol is provided and we found an dequantization node prior to the output, relax the
215
+ # atol by qtol quant units.
216
+ if self .quantization_scale is not None :
217
+ atol += self .quantization_scale * qtol
218
+
219
+ self ._assert_outputs_equal (
220
+ self .stage_output ,
221
+ self .reference_output ,
222
+ atol = atol ,
223
+ rtol = rtol ,
224
+ )
225
+ return self
226
+
199
227
def _get_input_params (
200
228
self , program : ExportedProgram
201
229
) -> Tuple [str , Union [List [QuantizationParams ], List [None ]]]:
0 commit comments