Skip to content

Commit 4a801e6

Browse files
mcr229facebook-github-bot
authored andcommitted
Dynamic Shapes (#2442)
Summary: Pull Request resolved: #2442 Only need to look at tester.py file for the tester changes. Change is from `.run_method().compare_outputs() ` to `.run_method_and_compare_outputs()` now if Tester is initialized with dynamic inputs, we will generate random dynamic inputs (according to the specification of the dynamic shapes) to run on the model. This allows us to test that the inputs fed into the model can be dynamic. We ad a num_runs to run_method_and_compare_outputs so that we can choose to run a number of different dynamic inputs with dynamic shapes. Reviewed By: digantdesai, kirklandsign Differential Revision: D54650121
1 parent 554cd27 commit 4a801e6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+283
-317
lines changed

backends/xnnpack/test/models/deeplab_v3.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,5 @@ def test_fp32_dl3(self):
3636
.partition()
3737
.to_executorch()
3838
.serialize()
39-
.run_method()
40-
.compare_outputs()
39+
.run_method_and_compare_outputs()
4140
)

backends/xnnpack/test/models/edsr.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ def test_fp32_edsr(self):
2525
.partition()
2626
.to_executorch()
2727
.serialize()
28-
.run_method()
29-
.compare_outputs()
28+
.run_method_and_compare_outputs()
3029
)
3130

3231
def test_qs8_edsr(self):
@@ -38,6 +37,5 @@ def test_qs8_edsr(self):
3837
.partition()
3938
.to_executorch()
4039
.serialize()
41-
.run_method()
42-
.compare_outputs()
40+
.run_method_and_compare_outputs()
4341
)

backends/xnnpack/test/models/emformer_rnnt.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ def __init__(self):
2121
self.rnnt = decoder.model
2222

2323
class Joiner(EmformerRnnt):
24-
def forward(self, predict_inputs):
25-
return self.rnnt.join(*predict_inputs)
24+
def forward(self, a, b, c, d):
25+
return self.rnnt.join(a, b, c, d)
2626

2727
def get_example_inputs(self):
2828
join_inputs = (
@@ -31,7 +31,7 @@ def get_example_inputs(self):
3131
torch.rand([1, 128, 1024]),
3232
torch.tensor([128]),
3333
)
34-
return (join_inputs,)
34+
return join_inputs
3535

3636
def test_fp32_emformer_joiner(self):
3737
joiner = self.Joiner()
@@ -43,21 +43,19 @@ def test_fp32_emformer_joiner(self):
4343
.check(["torch.ops.higher_order.executorch_call_delegate"])
4444
.to_executorch()
4545
.serialize()
46-
.run_method()
47-
.compare_outputs()
46+
.run_method_and_compare_outputs()
4847
)
4948

5049
class Predictor(EmformerRnnt):
51-
def forward(self, predict_inputs):
52-
return self.rnnt.predict(*predict_inputs)
50+
def forward(self, a, b):
51+
return self.rnnt.predict(a, b, None)
5352

5453
def get_example_inputs(self):
5554
predict_inputs = (
5655
torch.zeros([1, 128], dtype=int),
5756
torch.tensor([128], dtype=int),
58-
None,
5957
)
60-
return (predict_inputs,)
58+
return predict_inputs
6159

6260
@unittest.skip("T183426271")
6361
def test_fp32_emformer_predictor(self):
@@ -70,20 +68,19 @@ def test_fp32_emformer_predictor(self):
7068
.check(["torch.ops.higher_order.executorch_call_delegate"])
7169
.to_executorch()
7270
.serialize()
73-
.run_method()
74-
.compare_outputs()
71+
.run_method_and_compare_outputs()
7572
)
7673

7774
class Transcriber(EmformerRnnt):
78-
def forward(self, predict_inputs):
79-
return self.rnnt.transcribe(*predict_inputs)
75+
def forward(self, a, b):
76+
return self.rnnt.transcribe(a, b)
8077

8178
def get_example_inputs(self):
8279
transcribe_inputs = (
8380
torch.randn(1, 128, 80),
8481
torch.tensor([128]),
8582
)
86-
return (transcribe_inputs,)
83+
return transcribe_inputs
8784

8885
def test_fp32_emformer_transcriber(self):
8986
transcriber = self.Transcriber()
@@ -95,6 +92,5 @@ def test_fp32_emformer_transcriber(self):
9592
.check(["torch.ops.higher_order.executorch_call_delegate"])
9693
.to_executorch()
9794
.serialize()
98-
.run_method()
99-
.compare_outputs()
95+
.run_method_and_compare_outputs()
10096
)

backends/xnnpack/test/models/inception_v3.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ def test_fp32_ic3(self):
4242
.check_not(list(self.all_operators))
4343
.to_executorch()
4444
.serialize()
45-
.run_method()
46-
.compare_outputs()
45+
.run_method_and_compare_outputs()
4746
)
4847

4948
def test_qs8_ic3(self):
@@ -63,6 +62,5 @@ def test_qs8_ic3(self):
6362
.check_not(list(ops_after_quantization))
6463
.to_executorch()
6564
.serialize()
66-
.run_method()
67-
.compare_outputs()
65+
.run_method_and_compare_outputs()
6866
)

backends/xnnpack/test/models/inception_v4.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ def test_fp32_ic4(self):
3939
.check_not(list(self.all_operators))
4040
.to_executorch()
4141
.serialize()
42-
.run_method()
43-
.compare_outputs()
42+
.run_method_and_compare_outputs()
4443
)
4544

4645
def test_qs8_ic4(self):
@@ -60,6 +59,5 @@ def test_qs8_ic4(self):
6059
.check_not(list(ops_after_quantization))
6160
.to_executorch()
6261
.serialize()
63-
.run_method()
64-
.compare_outputs()
62+
.run_method_and_compare_outputs()
6563
)

backends/xnnpack/test/models/llama2_et_example.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,5 @@ def _test(self, dtype: torch.dtype = torch.float):
4545
.dump_artifact()
4646
.to_executorch()
4747
.serialize()
48-
.run_method()
49-
.compare_outputs(atol=5e-2)
48+
.run_method_and_compare_outputs(atol=5e-2)
5049
)

backends/xnnpack/test/models/mobilebert.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,5 @@ def test_fp32_mobilebert(self):
3838
.check_not(list(self.supported_ops))
3939
.to_executorch()
4040
.serialize()
41-
.run_method()
42-
.compare_outputs()
41+
.run_method_and_compare_outputs()
4342
)

backends/xnnpack/test/models/mobilenet_v2.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ def test_fp32_mv2(self):
4040
.check_not(list(self.all_operators))
4141
.to_executorch()
4242
.serialize()
43-
.run_method()
44-
.compare_outputs()
43+
.run_method_and_compare_outputs()
4544
)
4645

4746
def test_qs8_mv2(self):
@@ -61,6 +60,5 @@ def test_qs8_mv2(self):
6160
.check_not(list(ops_after_quantization))
6261
.to_executorch()
6362
.serialize()
64-
.run_method()
65-
.compare_outputs()
63+
.run_method_and_compare_outputs()
6664
)

backends/xnnpack/test/models/mobilenet_v3.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ def test_fp32_mv3(self):
4242
.check_not(list(self.all_operators))
4343
.to_executorch()
4444
.serialize()
45-
.run_method()
46-
.compare_outputs()
45+
.run_method_and_compare_outputs()
4746
)
4847

4948
def test_qs8_mv3(self):
@@ -63,6 +62,5 @@ def test_qs8_mv3(self):
6362
.check_not(list(ops_after_lowering))
6463
.to_executorch()
6564
.serialize()
66-
.run_method()
67-
.compare_outputs()
65+
.run_method_and_compare_outputs()
6866
)

backends/xnnpack/test/models/resnet.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ def test_fp32_resnet18(self):
2323
.partition()
2424
.to_executorch()
2525
.serialize()
26-
.run_method()
27-
.compare_outputs()
26+
.run_method_and_compare_outputs()
2827
)
2928

3029
def test_qs8_resnet18(self):
@@ -37,6 +36,5 @@ def test_qs8_resnet18(self):
3736
.partition()
3837
.to_executorch()
3938
.serialize()
40-
.run_method()
41-
.compare_outputs()
39+
.run_method_and_compare_outputs()
4240
)

backends/xnnpack/test/models/torchvision_vit.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,5 @@ def test_fp32_vit(self):
5757
.check_not(list(lowerable_xnn_operators))
5858
.to_executorch()
5959
.serialize()
60-
.run_method()
61-
.compare_outputs()
60+
.run_method_and_compare_outputs()
6261
)

backends/xnnpack/test/models/very_big_model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,5 @@ def test_very_big_model(self):
3939
.check(["torch.ops.higher_order.executorch_call_delegate"])
4040
.to_executorch()
4141
.serialize()
42-
.run_method()
43-
.compare_outputs()
42+
.run_method_and_compare_outputs()
4443
)

backends/xnnpack/test/models/w2l.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ def test_fp32_w2l(self):
3434
.check(["torch.ops.higher_order.executorch_call_delegate"])
3535
.to_executorch()
3636
.serialize()
37-
.run_method()
38-
.compare_outputs()
37+
.run_method_and_compare_outputs()
3938
)
4039

4140
def test_qs8_w2l(self):
@@ -54,6 +53,5 @@ def test_qs8_w2l(self):
5453
.check(["torch.ops.higher_order.executorch_call_delegate"])
5554
.to_executorch()
5655
.serialize()
57-
.run_method()
58-
.compare_outputs()
56+
.run_method_and_compare_outputs()
5957
)

backends/xnnpack/test/ops/abs.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ def _test_abs(self, inputs):
3131
.check_not(["executorch_exir_dialects_edge__ops_aten_abs_default"])
3232
.to_executorch()
3333
.serialize()
34-
.run_method()
35-
.compare_outputs()
34+
.run_method_and_compare_outputs()
3635
)
3736

3837
def test_fp16_abs(self):

backends/xnnpack/test/ops/add.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ def _test_add(self, inputs):
5454
.check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
5555
.to_executorch()
5656
.serialize()
57-
.run_method()
58-
.compare_outputs()
57+
.run_method_and_compare_outputs()
5958
)
6059

6160
def test_fp16_add(self):
@@ -79,8 +78,7 @@ def test_fp32_add_constant(self):
7978
.check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
8079
.to_executorch()
8180
.serialize()
82-
.run_method()
83-
.compare_outputs()
81+
.run_method_and_compare_outputs()
8482
)
8583

8684
def test_qs8_add_constant(self):
@@ -121,8 +119,7 @@ def test_qs8_add(self):
121119
)
122120
.to_executorch()
123121
.serialize()
124-
.run_method()
125-
.compare_outputs()
122+
.run_method_and_compare_outputs()
126123
)
127124

128125
def test_qs8_add2(self):
@@ -145,8 +142,7 @@ def test_qs8_add2(self):
145142
)
146143
.to_executorch()
147144
.serialize()
148-
.run_method()
149-
.compare_outputs()
145+
.run_method_and_compare_outputs()
150146
)
151147

152148
def test_qs8_add3(self):
@@ -169,8 +165,7 @@ def test_qs8_add3(self):
169165
)
170166
.to_executorch()
171167
.serialize()
172-
.run_method()
173-
.compare_outputs()
168+
.run_method_and_compare_outputs()
174169
)
175170

176171
class AddRelu(torch.nn.Module):
@@ -194,8 +189,7 @@ def test_fp32_add_relu(self):
194189
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
195190
.to_executorch()
196191
.serialize()
197-
.run_method()
198-
.compare_outputs()
192+
.run_method_and_compare_outputs()
199193
)
200194

201195
def test_qs8_add_relu(self):
@@ -214,8 +208,7 @@ def test_qs8_add_relu(self):
214208
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
215209
.to_executorch()
216210
.serialize()
217-
.run_method()
218-
.compare_outputs()
211+
.run_method_and_compare_outputs()
219212
)
220213

221214
def test_qs8_add_relu_seq(self):
@@ -261,6 +254,5 @@ def forward(self, x, z):
261254
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
262255
.to_executorch()
263256
.serialize()
264-
.run_method()
265-
.compare_outputs()
257+
.run_method_and_compare_outputs()
266258
)

backends/xnnpack/test/ops/avgpool2d.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ def _test_argpool2d(self, inputs):
4242
.check_not(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"])
4343
.to_executorch()
4444
.serialize()
45-
.run_method()
46-
.compare_outputs()
45+
.run_method_and_compare_outputs()
4746
)
4847

4948
def test_fp16_avgpool2d(self):

backends/xnnpack/test/ops/bilinear2d.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ def test_fp32_static_resize_bilinear2d(self):
8787
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
8888
.to_executorch()
8989
.serialize()
90-
.run_method()
91-
.compare_outputs()
90+
.run_method_and_compare_outputs()
9291
)
9392

9493
def test_fp32_static_resize_bilinear2d_with_align_cornesr(self):
@@ -103,8 +102,7 @@ def test_fp32_static_resize_bilinear2d_with_align_cornesr(self):
103102
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
104103
.to_executorch()
105104
.serialize()
106-
.run_method()
107-
.compare_outputs()
105+
.run_method_and_compare_outputs()
108106
)
109107

110108
def test_fp32_static_resize_bilinear2d_antialiased(self):

0 commit comments

Comments
 (0)