Skip to content

Commit 103dd7e

Browse files
committed
[ArmBackend] Minor improvements to model unit tests
* Tighten the numerical tolerance for the MobileNetV2 test and ensure randomness by using the test framework to generate the test vectors
* Make sure to calibrate and test the model with different data

Change-Id: I5e345a2ba1fee8272abb498eceda4b829e2b9e72
Signed-off-by: Fredrik Knutsson <[email protected]>
1 parent 8ad15f3 commit 103dd7e

File tree

4 files changed

+70
-57
lines changed

4 files changed

+70
-57
lines changed

backends/arm/test/models/test_conformer.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
logger.setLevel(logging.INFO)
1919

2020

21+
def get_test_inputs(dim, lengths, num_examples):
22+
return (torch.rand(num_examples, int(lengths.max()), dim), lengths)
23+
24+
2125
class TestConformer(unittest.TestCase):
2226
"""Tests Torchaudio Conformer"""
2327

@@ -41,8 +45,9 @@ class TestConformer(unittest.TestCase):
4145
}
4246

4347
dim = 16
44-
lengths = torch.randint(1, 100, (10,), dtype=torch.int32)
45-
input_data = torch.rand(10, int(lengths.max()), dim)
48+
num_examples = 10
49+
lengths = torch.randint(1, 100, (num_examples,), dtype=torch.int32)
50+
model_example_inputs = get_test_inputs(dim, lengths, num_examples)
4651
conformer = Conformer(
4752
input_dim=dim,
4853
num_heads=4,
@@ -56,7 +61,7 @@ def test_conformer_tosa_MI(self):
5661
(
5762
ArmTester(
5863
self.conformer,
59-
example_inputs=(self.input_data, self.lengths),
64+
example_inputs=self.model_example_inputs,
6065
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-0.80+MI"),
6166
)
6267
.export()
@@ -66,7 +71,9 @@ def test_conformer_tosa_MI(self):
6671
.to_executorch()
6772
# TODO(MLETORCH-632): Fix numerical errors
6873
.run_method_and_compare_outputs(
69-
inputs=(self.input_data, self.lengths), rtol=1, atol=5
74+
rtol=1.0,
75+
atol=5.0,
76+
inputs=get_test_inputs(self.dim, self.lengths, self.num_examples),
7077
)
7178
)
7279

@@ -75,15 +82,18 @@ def test_conformer_tosa_BI(self):
7582
(
7683
ArmTester(
7784
self.conformer,
78-
example_inputs=(self.input_data, self.lengths),
85+
example_inputs=self.model_example_inputs,
7986
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-0.80+BI"),
8087
)
8188
.quantize()
8289
.export()
8390
.to_edge_transform_and_lower()
8491
.to_executorch()
8592
.run_method_and_compare_outputs(
86-
qtol=1, rtol=1, atol=5, inputs=(self.input_data, self.lengths)
93+
qtol=1.0,
94+
rtol=1.0,
95+
atol=5.0,
96+
inputs=get_test_inputs(self.dim, self.lengths, self.num_examples),
8797
)
8898
)
8999

@@ -92,7 +102,7 @@ def test_conformer_u55_BI(self):
92102
tester = (
93103
ArmTester(
94104
self.conformer,
95-
example_inputs=(self.input_data, self.lengths),
105+
example_inputs=self.model_example_inputs,
96106
compile_spec=common.get_u55_compile_spec(),
97107
)
98108
.quantize()
@@ -103,15 +113,18 @@ def test_conformer_u55_BI(self):
103113
)
104114
if conftest.is_option_enabled("corstone_fvp"):
105115
tester.run_method_and_compare_outputs(
106-
atol=1.0, qtol=1, inputs=(self.input_data, self.lengths)
116+
qtol=1.0,
117+
rtol=1.0,
118+
atol=5.0,
119+
inputs=get_test_inputs(self.dim, self.lengths, self.num_examples),
107120
)
108121

109122
@unittest.expectedFailure # TODO(MLETORCH-635)
110123
def test_conformer_u85_BI(self):
111124
tester = (
112125
ArmTester(
113126
self.conformer,
114-
example_inputs=(self.input_data, self.lengths),
127+
example_inputs=self.model_example_inputs,
115128
compile_spec=common.get_u85_compile_spec(),
116129
)
117130
.quantize()
@@ -122,5 +135,8 @@ def test_conformer_u85_BI(self):
122135
)
123136
if conftest.is_option_enabled("corstone_fvp"):
124137
tester.run_method_and_compare_outputs(
125-
atol=1.0, qtol=1, inputs=(self.input_data, self.lengths)
138+
qtol=1.0,
139+
rtol=1.0,
140+
atol=5.0,
141+
inputs=get_test_inputs(self.dim, self.lengths, self.num_examples),
126142
)

backends/arm/test/models/test_dl3_arm.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,36 +17,38 @@ class TestDl3(unittest.TestCase):
1717
"""Tests DeepLabv3."""
1818

1919
dl3 = deeplab_v3.DeepLabV3ResNet50Model()
20-
model_inputs = dl3.get_example_inputs()
20+
model_example_inputs = dl3.get_example_inputs()
2121
dl3 = dl3.get_eager_model()
2222

2323
@unittest.expectedFailure
2424
def test_dl3_tosa_MI(self):
2525
(
2626
ArmTester(
2727
self.dl3,
28-
example_inputs=self.model_inputs,
28+
example_inputs=self.model_example_inputs,
2929
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
3030
)
3131
.export()
3232
.to_edge_transform_and_lower()
3333
.to_executorch()
34-
.run_method_and_compare_outputs(self.model_inputs)
34+
.run_method_and_compare_outputs(inputs=self.dl3.get_example_inputs())
3535
)
3636

3737
@unittest.expectedFailure
3838
def test_dl3_tosa_BI(self):
3939
(
4040
ArmTester(
4141
self.dl3,
42-
example_inputs=self.model_inputs,
42+
example_inputs=self.model_example_inputs,
4343
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
4444
)
4545
.quantize()
4646
.export()
4747
.to_edge_transform_and_lower()
4848
.to_executorch()
49-
.run_method_and_compare_outputs(atol=1.0, qtol=1, inputs=self.model_inputs)
49+
.run_method_and_compare_outputs(
50+
atol=1.0, qtol=1, inputs=self.dl3.get_example_inputs()
51+
)
5052
)
5153

5254
@pytest.mark.slow
@@ -56,7 +58,7 @@ def test_dl3_u55_BI(self):
5658
tester = (
5759
ArmTester(
5860
self.dl3,
59-
example_inputs=self.model_inputs,
61+
example_inputs=self.model_example_inputs,
6062
compile_spec=common.get_u55_compile_spec(),
6163
)
6264
.quantize()
@@ -67,7 +69,7 @@ def test_dl3_u55_BI(self):
6769
)
6870
if conftest.is_option_enabled("corstone_fvp"):
6971
tester.run_method_and_compare_outputs(
70-
atol=1.0, qtol=1, inputs=self.model_inputs
72+
atol=1.0, qtol=1, inputs=self.dl3.get_example_inputs()
7173
)
7274

7375
@pytest.mark.slow
@@ -77,7 +79,7 @@ def test_dl3_u85_BI(self):
7779
tester = (
7880
ArmTester(
7981
self.dl3,
80-
example_inputs=self.model_inputs,
82+
example_inputs=self.model_example_inputs,
8183
compile_spec=common.get_u85_compile_spec(),
8284
)
8385
.quantize()
@@ -88,5 +90,5 @@ def test_dl3_u85_BI(self):
8890
)
8991
if conftest.is_option_enabled("corstone_fvp"):
9092
tester.run_method_and_compare_outputs(
91-
atol=1.0, qtol=1, inputs=self.model_inputs
93+
atol=1.0, qtol=1, inputs=self.dl3.get_example_inputs()
9294
)

backends/arm/test/models/test_lstm_arm.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@
1616
from torch.nn.quantizable.modules import rnn
1717

1818

19+
def get_test_inputs():
20+
return (
21+
torch.randn(5, 3, 10), # input
22+
(torch.randn(2, 3, 20), torch.randn(2, 3, 20)), # (h0, c0)
23+
)
24+
25+
1926
class TestLSTM(unittest.TestCase):
2027
"""Tests quantizable LSTM module."""
2128

@@ -27,46 +34,43 @@ class TestLSTM(unittest.TestCase):
2734
lstm = rnn.LSTM(10, 20, 2)
2835
lstm = lstm.eval()
2936

30-
input_tensor = torch.randn(5, 3, 10)
31-
h0 = torch.randn(2, 3, 20)
32-
c0 = torch.randn(2, 3, 20)
33-
34-
model_inputs = (input_tensor, (h0, c0))
37+
# Used e.g. for quantization calibration and shape extraction in the tester
38+
model_example_inputs = get_test_inputs()
3539

3640
def test_lstm_tosa_MI(self):
3741
(
3842
ArmTester(
3943
self.lstm,
40-
example_inputs=self.model_inputs,
44+
example_inputs=self.model_example_inputs,
4145
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
4246
)
4347
.export()
4448
.to_edge_transform_and_lower()
4549
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
4650
.to_executorch()
47-
.run_method_and_compare_outputs(inputs=self.model_inputs)
51+
.run_method_and_compare_outputs(inputs=get_test_inputs())
4852
)
4953

5054
def test_lstm_tosa_BI(self):
5155
(
5256
ArmTester(
5357
self.lstm,
54-
example_inputs=self.model_inputs,
58+
example_inputs=self.model_example_inputs,
5559
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
5660
)
5761
.quantize()
5862
.export()
5963
.to_edge_transform_and_lower()
6064
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
6165
.to_executorch()
62-
.run_method_and_compare_outputs(atol=3e-1, qtol=1, inputs=self.model_inputs)
66+
.run_method_and_compare_outputs(atol=3e-1, qtol=1, inputs=get_test_inputs())
6367
)
6468

6569
def test_lstm_u55_BI(self):
6670
tester = (
6771
ArmTester(
6872
self.lstm,
69-
example_inputs=self.model_inputs,
73+
example_inputs=self.model_example_inputs,
7074
compile_spec=common.get_u55_compile_spec(),
7175
)
7276
.quantize()
@@ -78,14 +82,14 @@ def test_lstm_u55_BI(self):
7882
)
7983
if conftest.is_option_enabled("corstone_fvp"):
8084
tester.run_method_and_compare_outputs(
81-
atol=3e-1, qtol=1, inputs=self.model_inputs
85+
atol=3e-1, qtol=1, inputs=get_test_inputs()
8286
)
8387

8488
def test_lstm_u85_BI(self):
8589
tester = (
8690
ArmTester(
8791
self.lstm,
88-
example_inputs=self.model_inputs,
92+
example_inputs=self.model_example_inputs,
8993
compile_spec=common.get_u85_compile_spec(),
9094
)
9195
.quantize()
@@ -97,5 +101,5 @@ def test_lstm_u85_BI(self):
97101
)
98102
if conftest.is_option_enabled("corstone_fvp"):
99103
tester.run_method_and_compare_outputs(
100-
atol=3e-1, qtol=1, inputs=self.model_inputs
104+
atol=3e-1, qtol=1, inputs=get_test_inputs()
101105
)

backends/arm/test/models/test_mobilenet_v2_arm.py

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -32,50 +32,37 @@ class TestMobileNetV2(unittest.TestCase):
3232
normalize = transforms.Normalize(
3333
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
3434
)
35-
model_inputs = (normalize(torch.randn((1, 3, 224, 224))),)
3635

37-
all_operators = {
38-
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
39-
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
40-
"executorch_exir_dialects_edge__ops_aten_permute_copy_default",
41-
"executorch_exir_dialects_edge__ops_aten_addmm_default",
42-
"executorch_exir_dialects_edge__ops_aten_mean_dim",
43-
"executorch_exir_dialects_edge__ops_aten_hardtanh_default",
44-
"executorch_exir_dialects_edge__ops_aten_convolution_default",
45-
}
46-
47-
operators_after_quantization = all_operators - {
48-
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
49-
}
36+
# Used e.g. for quantization calibration and shape extraction in the tester
37+
model_example_inputs = (normalize(torch.randn((1, 3, 224, 224))),)
5038

5139
def test_mv2_tosa_MI(self):
5240
(
5341
ArmTester(
5442
self.mv2,
55-
example_inputs=self.model_inputs,
43+
example_inputs=self.model_example_inputs,
5644
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
5745
)
5846
.export()
5947
.to_edge_transform_and_lower()
48+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
6049
.to_executorch()
61-
.run_method_and_compare_outputs(inputs=self.model_inputs)
50+
.run_method_and_compare_outputs()
6251
)
6352

6453
def test_mv2_tosa_BI(self):
6554
(
6655
ArmTester(
6756
self.mv2,
68-
example_inputs=self.model_inputs,
57+
example_inputs=self.model_example_inputs,
6958
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
7059
)
7160
.quantize()
7261
.export()
7362
.to_edge_transform_and_lower()
63+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
7464
.to_executorch()
75-
# atol=1.0 is a defensive upper limit
76-
# TODO MLETROCH-72
77-
# TODO MLETROCH-149
78-
.run_method_and_compare_outputs(atol=1.0, qtol=1, inputs=self.model_inputs)
65+
.run_method_and_compare_outputs(rtol=0.001, atol=0.2, qtol=1)
7966
)
8067

8168
@pytest.mark.slow
@@ -84,7 +71,7 @@ def test_mv2_u55_BI(self):
8471
tester = (
8572
ArmTester(
8673
self.mv2,
87-
example_inputs=self.model_inputs,
74+
example_inputs=self.model_example_inputs,
8875
compile_spec=common.get_u55_compile_spec(),
8976
)
9077
.quantize()
@@ -95,7 +82,9 @@ def test_mv2_u55_BI(self):
9582
)
9683
if conftest.is_option_enabled("corstone_fvp"):
9784
tester.run_method_and_compare_outputs(
98-
atol=1.0, qtol=1, inputs=self.model_inputs
85+
rtol=0.001,
86+
atol=0.2,
87+
qtol=1,
9988
)
10089

10190
@pytest.mark.slow
@@ -104,7 +93,7 @@ def test_mv2_u85_BI(self):
10493
tester = (
10594
ArmTester(
10695
self.mv2,
107-
example_inputs=self.model_inputs,
96+
example_inputs=self.model_example_inputs,
10897
compile_spec=common.get_u85_compile_spec(),
10998
)
11099
.quantize()
@@ -115,5 +104,7 @@ def test_mv2_u85_BI(self):
115104
)
116105
if conftest.is_option_enabled("corstone_fvp"):
117106
tester.run_method_and_compare_outputs(
118-
atol=1.0, qtol=1, inputs=self.model_inputs
107+
rtol=0.001,
108+
atol=0.2,
109+
qtol=1,
119110
)

0 commit comments

Comments
 (0)