Skip to content

Commit 8ee280b

Browse files
Arm backend: Add MobileNet v3 testcase (#9223)
* Add missing inplace operators to quantization annotator. * Add atol, rtol & qtol to test_pipeline classes. Co-authored-by: Oscar Andersson <[email protected]> Signed-off-by: Tom Allsop <[email protected]>
1 parent 771588a commit 8ee280b

File tree

3 files changed

+116
-4
lines changed

3 files changed

+116
-4
lines changed

backends/arm/quantizer/quantization_annotator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def _match_pattern(
137137
torch.ops.aten.sum.dim_IntList,
138138
torch.ops.aten.hardsigmoid.default,
139139
torch.ops.aten.hardswish.default,
140+
torch.ops.aten.hardswish_.default,
140141
torch.ops.aten.full_like.default,
141142
]
142143

@@ -196,6 +197,7 @@ def _match_pattern(
196197
torch.ops.aten.full.default,
197198
torch.ops.aten.flatten.using_ints,
198199
torch.ops.aten.dropout.default,
200+
torch.ops.aten.dropout_.default,
199201
torch.ops.aten.clamp.default,
200202
torch.ops.aten.clamp.Tensor,
201203
operator.getitem,
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import common
9+
import pytest
10+
11+
import torch
12+
13+
from executorch.backends.arm.test.tester.test_pipeline import (
14+
EthosU55PipelineBI,
15+
EthosU85PipelineBI,
16+
TosaPipelineBI,
17+
TosaPipelineMI,
18+
)
19+
20+
from torchvision import models, transforms
21+
22+
mv3 = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights)
23+
mv3 = mv3.eval()
24+
25+
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
26+
27+
input_tensor = torch.rand(1, 3, 232, 232)
28+
29+
model_inputs = (normalize(input_tensor),)
30+
input_t = Tuple[torch.Tensor]
31+
32+
33+
@pytest.mark.slow
34+
def test_mv3_tosa_MI():
35+
pipeline = TosaPipelineMI[input_t](
36+
mv3, model_inputs, aten_op=[], exir_op=[], use_to_edge_transform_and_lower=True
37+
)
38+
pipeline.run()
39+
40+
41+
@pytest.mark.slow
42+
def test_mv3_tosa_BI():
43+
pipeline = TosaPipelineBI[input_t](
44+
mv3,
45+
model_inputs,
46+
aten_op=[],
47+
exir_op=[],
48+
use_to_edge_transform_and_lower=True,
49+
atol=0.3,
50+
qtol=1,
51+
)
52+
pipeline.run()
53+
54+
55+
@pytest.mark.slow
56+
@pytest.mark.corstone_fvp
57+
@common.XfailIfNoCorstone300
58+
def test_mv3_u55_BI():
59+
pipeline = EthosU55PipelineBI[input_t](
60+
mv3,
61+
model_inputs,
62+
aten_ops=[],
63+
exir_ops=[],
64+
run_on_fvp=True,
65+
use_to_edge_transform_and_lower=True,
66+
atol=0.3,
67+
qtol=1,
68+
)
69+
pipeline.run()
70+
71+
72+
@pytest.mark.slow
73+
@pytest.mark.corstone_fvp
74+
@common.XfailIfNoCorstone320
75+
def test_mv3_u85_BI():
76+
pipeline = EthosU85PipelineBI[input_t](
77+
mv3,
78+
model_inputs,
79+
aten_ops=[],
80+
exir_ops=[],
81+
run_on_fvp=True,
82+
use_to_edge_transform_and_lower=True,
83+
atol=0.3,
84+
qtol=1,
85+
)
86+
pipeline.run()

backends/arm/test/tester/test_pipeline.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,9 @@ def __init__(
274274
symmetric_io_quantization: bool = False,
275275
use_to_edge_transform_and_lower: bool = True,
276276
custom_path: str = None,
277+
atol: float = 1e-03,
278+
rtol: float = 1e-03,
279+
qtol: int = 0,
277280
):
278281
compile_spec = common.get_tosa_compile_spec(
279282
tosa_version, custom_path=custom_path
@@ -322,7 +325,11 @@ def __init__(
322325
)
323326

324327
self.add_stage(
325-
self.tester.run_method_and_compare_outputs, inputs=self.test_data
328+
self.tester.run_method_and_compare_outputs,
329+
atol=atol,
330+
rtol=rtol,
331+
qtol=qtol,
332+
inputs=self.test_data,
326333
)
327334

328335

@@ -353,6 +360,9 @@ def __init__(
353360
tosa_version: str = "TOSA-0.80+MI",
354361
use_to_edge_transform_and_lower: bool = True,
355362
custom_path: str = None,
363+
atol: float = 1e-03,
364+
rtol: float = 1e-03,
365+
qtol: int = 0,
356366
):
357367
compile_spec = common.get_tosa_compile_spec(
358368
tosa_version, custom_path=custom_path
@@ -376,7 +386,11 @@ def __init__(
376386
)
377387

378388
self.add_stage(
379-
self.tester.run_method_and_compare_outputs, inputs=self.test_data
389+
self.tester.run_method_and_compare_outputs,
390+
atol=atol,
391+
rtol=rtol,
392+
qtol=qtol,
393+
inputs=self.test_data,
380394
)
381395

382396

@@ -406,6 +420,9 @@ def __init__(
406420
symmetric_io_quantization: bool = False,
407421
use_to_edge_transform_and_lower: bool = False,
408422
custom_path: str = None,
423+
atol: float = 1e-03,
424+
rtol: float = 1e-03,
425+
qtol: int = 1,
409426
):
410427
compile_spec = common.get_u55_compile_spec(custom_path=custom_path)
411428
quant_stage = (
@@ -458,7 +475,9 @@ def __init__(
458475
self.add_stage(self.tester.serialize)
459476
self.add_stage(
460477
self.tester.run_method_and_compare_outputs,
461-
qtol=1,
478+
atol=atol,
479+
rtol=rtol,
480+
qtol=qtol,
462481
inputs=self.test_data,
463482
)
464483

@@ -489,6 +508,9 @@ def __init__(
489508
symmetric_io_quantization: bool = False,
490509
use_to_edge_transform_and_lower: bool = False,
491510
custom_path: str = None,
511+
atol: float = 1e-03,
512+
rtol: float = 1e-03,
513+
qtol: int = 1,
492514
):
493515
compile_spec = common.get_u85_compile_spec(custom_path=custom_path)
494516
quant_stage = (
@@ -541,7 +563,9 @@ def __init__(
541563
self.add_stage(self.tester.serialize)
542564
self.add_stage(
543565
self.tester.run_method_and_compare_outputs,
544-
qtol=1,
566+
atol=atol,
567+
rtol=rtol,
568+
qtol=qtol,
545569
inputs=self.test_data,
546570
)
547571

0 commit comments

Comments
 (0)