Skip to content

Commit 5aa0ee2

Browse files
authored
Arm backend: Minor improvements in Arm testing. (#8337)
* Check that reference model was run in TosaReferenceModelDispatch If the if statment looking for executorch_call_delegates fails, the delegate can run the model in eager mode and still produce correct results without the user noticing. We need to make sure we actually run the reference model. * Add markers to flaky tests
1 parent 9746ce7 commit 5aa0ee2

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed

backends/arm/test/ops/test_bmm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def test_bmm_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
124124
self._test_bmm_tosa_MI_pipeline(self.BMM(), test_data)
125125

126126
@parameterized.expand(BMMSingleInput.test_data_generators)
127+
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
127128
def test_bmm_single_input_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
128129
test_data = test_data_generator()
129130
self._test_bmm_tosa_MI_pipeline(self.BMMSingleInput(), test_data)
@@ -144,6 +145,7 @@ def test_bmm_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
144145
self._test_bmm_tosa_BI_pipeline(self.BMM(), test_data)
145146

146147
@parameterized.expand(BMMSingleInput.test_data_generators)
148+
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
147149
def test_bmm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
148150
test_data = test_data_generator()
149151
self._test_bmm_tosa_BI_pipeline(self.BMMSingleInput(), test_data)

backends/arm/test/ops/test_mm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def test_mm_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
115115
self._test_mm_tosa_MI_pipeline(self.MM(), test_data)
116116

117117
@parameterized.expand(MMSingleInput.test_data_generators)
118+
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
118119
def test_mm_single_input_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
119120
test_data = test_data_generator()
120121
self._test_mm_tosa_MI_pipeline(self.MMSingleInput(), test_data)
@@ -126,6 +127,7 @@ def test_mm_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
126127
self._test_mm_tosa_BI_pipeline(self.MM(), test_data)
127128

128129
@parameterized.expand(MMSingleInput.test_data_generators)
130+
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
129131
def test_mm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
130132
test_data = test_data_generator()
131133
self._test_mm_tosa_BI_pipeline(self.MMSingleInput(), test_data)

backends/arm/test/runner_utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ def get_output_quantization_params(
157157
class TosaReferenceModelDispatch(TorchFunctionMode):
158158
"""A context manager for executing call_delegate nodes using the reference model"""
159159

160+
def __init__(self):
161+
self.ran_tosa_dispatch = False
162+
super().__init__()
163+
160164
def _tosa_dispatch(self, lowered_backend_module: LoweredBackendModule, inputs):
161165
tosa_buffer = lowered_backend_module.processed_bytes
162166
compile_specs = lowered_backend_module.compile_specs
@@ -168,13 +172,21 @@ def _tosa_dispatch(self, lowered_backend_module: LoweredBackendModule, inputs):
168172

169173
return run_tosa_graph(tosa_buffer, tosa_version, inputs)
170174

175+
def __exit__(self, exc_type, exc_val, exc_tb):
176+
super().__exit__(exc_type, exc_val, exc_tb)
177+
if not self.ran_tosa_dispatch:
178+
raise RuntimeError(
179+
"Ran model with TosaReferenceModelDispatch but never ran ArmBackend delegate."
180+
)
181+
171182
def __torch_function__(self, func, types, args=..., kwargs=None):
172183
if func is torch._higher_order_ops.executorch_call_delegate:
173184
lowered_backend_module = cast(LoweredBackendModule, args[0])
174185
if lowered_backend_module.backend_id == "ArmBackend":
186+
self.ran_tosa_dispatch = True
175187
return self._tosa_dispatch(lowered_backend_module, args[1:])
176188
else:
177-
logger.warning(
189+
raise RuntimeError(
178190
f"Ran model with TosaReferenceModelDispatch but call_delegate with {lowered_backend_module.backend_id=} != 'ArmBackend'."
179191
)
180192

0 commit comments

Comments
 (0)