Skip to content

Commit a43b4a6

Browse files
Gasoonjiafacebook-github-bot
authored andcommitted
introduce model-level end2end tests to dim order tests with different delegate (#6093)
Summary: Pull Request resolved: #6093 This diff introduced end2end tests on several models + delegation combinations. Models: llama2, resnet18, mobilenet_v3 Delegate: no delegate, xnnpack Reviewed By: digantdesai, larryliu0820 Differential Revision: D64174329 fbshipit-source-id: 0807e0282d136bf1ef6d5be88e0c9f8512580f38
1 parent 69766fb commit a43b4a6

File tree

4 files changed

+138
-15
lines changed

4 files changed

+138
-15
lines changed

exir/tests/TARGETS

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ python_unittest(
377377
":test_memory_format_ops_pass_utils",
378378
"//caffe2:torch",
379379
"//executorch/extension/pybindings:aten_lib", # @manual
380+
"//pytorch/vision:torchvision", # @manual
380381
],
381382
)
382383

@@ -394,6 +395,7 @@ python_unittest(
394395
"//executorch/exir/dialects:lib",
395396
"//executorch/exir/dialects/edge:lib",
396397
"//executorch/extension/pybindings:portable_lib", # @manual
398+
"//pytorch/vision:torchvision", # @manual
397399
],
398400
)
399401

@@ -404,6 +406,7 @@ python_library(
404406
],
405407
deps = [
406408
"//caffe2:torch",
409+
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
407410
"//executorch/exir:dim_order_utils",
408411
"//executorch/exir:lib",
409412
"//executorch/exir/capture:config",

exir/tests/test_memory_format_ops_pass.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from typing import Union
1111

1212
import torch
13+
14+
import torchvision
1315
from executorch.exir import EdgeCompileConfig, to_edge
1416
from executorch.exir.dialects._ops import ops as exir_ops
1517
from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -264,3 +266,65 @@ def call_operator(self, op, args, kwargs, meta):
264266

265267
self.assertTrue(is_contiguous_dim_order(actual))
266268
self.assertTrue(is_contiguous_dim_order(expected))
269+
270+
def test_resnet18(self) -> None:
271+
model = torchvision.models.resnet18()
272+
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
273+
self,
274+
MemoryFormatTestSet(
275+
module=model.eval(),
276+
sample_input=(torch.randn(1, 3, 224, 224),),
277+
target_memory_format=torch.contiguous_format,
278+
op_level_check=False,
279+
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
280+
atol=1e-3,
281+
rtol=1e-3,
282+
),
283+
)
284+
285+
def test_resnet18_xnnpack(self) -> None:
286+
model = torchvision.models.resnet18()
287+
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
288+
self,
289+
MemoryFormatTestSet(
290+
module=model.eval(),
291+
sample_input=(torch.randn(1, 3, 224, 224),),
292+
target_memory_format=torch.contiguous_format,
293+
op_level_check=False,
294+
use_xnnpack=True,
295+
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
296+
atol=1e-3,
297+
rtol=1e-3,
298+
),
299+
)
300+
301+
def test_mobilenet_v3(self) -> None:
302+
model = torchvision.models.mobilenetv3.mobilenet_v3_small(pretrained=True)
303+
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
304+
self,
305+
MemoryFormatTestSet(
306+
module=model.eval(),
307+
sample_input=(torch.randn(1, 3, 224, 224),),
308+
target_memory_format=torch.contiguous_format,
309+
op_level_check=False,
310+
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
311+
atol=1e-3,
312+
rtol=1e-3,
313+
),
314+
)
315+
316+
def test_mobilenet_v3_xnnpack(self) -> None:
317+
model = torchvision.models.mobilenetv3.mobilenet_v3_small(pretrained=True)
318+
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
319+
self,
320+
MemoryFormatTestSet(
321+
module=model.eval(),
322+
sample_input=(torch.randn(1, 3, 224, 224),),
323+
target_memory_format=torch.contiguous_format,
324+
op_level_check=False,
325+
use_xnnpack=True,
326+
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
327+
atol=1e-3,
328+
rtol=1e-3,
329+
),
330+
)

exir/tests/test_memory_format_ops_pass_aten.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import unittest
88

99
import torch
10+
import torchvision
1011

1112
from executorch.exir.tests.test_memory_format_ops_pass_utils import (
1213
MemoryFormatOpsPassTestUtils,
@@ -77,3 +78,33 @@ def test_op_dim_order_propagation_aten(self) -> None:
7778
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
7879
),
7980
)
81+
82+
def test_resnet18(self) -> None:
83+
model = torchvision.models.resnet18()
84+
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
85+
self,
86+
MemoryFormatTestSet(
87+
module=model.eval(),
88+
sample_input=(torch.randn(1, 3, 224, 224),),
89+
target_memory_format=torch.contiguous_format,
90+
op_level_check=False,
91+
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
92+
atol=1e-3,
93+
rtol=1e-3,
94+
),
95+
)
96+
97+
def test_mobilenet_v3(self) -> None:
98+
model = torchvision.models.mobilenetv3.mobilenet_v3_small(pretrained=True)
99+
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
100+
self,
101+
MemoryFormatTestSet(
102+
module=model.eval(),
103+
sample_input=(torch.randn(1, 3, 224, 224),),
104+
target_memory_format=torch.contiguous_format,
105+
op_level_check=False,
106+
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
107+
atol=1e-3,
108+
rtol=1e-3,
109+
),
110+
)

exir/tests/test_memory_format_ops_pass_utils.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
from typing import Any, Tuple
1212

1313
import torch
14-
from executorch.exir import to_edge
14+
15+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
16+
from executorch.exir import to_edge, to_edge_transform_and_lower
1517
from executorch.exir.capture._config import EdgeCompileConfig
1618

1719
from executorch.exir.dim_order_utils import (
@@ -30,6 +32,10 @@ class MemoryFormatTestSet:
3032
sample_input: Tuple[Any, ...]
3133
target_memory_format: torch.memory_format
3234
_load_for_executorch_from_buffer: Any
35+
op_level_check: bool = True
36+
use_xnnpack: bool = False
37+
rtol: float = 1e-05
38+
atol: float = 1e-08
3339

3440

3541
class SimpleToCopyContiguousModule(torch.nn.Module):
@@ -63,27 +69,42 @@ class MemoryFormatOpsPassTestUtils:
6369
def memory_format_test_runner(
6470
test_class: unittest.TestCase, test_set: MemoryFormatTestSet
6571
):
66-
aten_op_str = "torch.ops.aten._to_copy.default"
67-
edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
68-
6972
before = export(test_set.module, test_set.sample_input)
7073

71-
# check op strings before
72-
FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
73-
edge_op_str
74-
).run(before.graph_module.code)
74+
if test_set.use_xnnpack:
75+
epm = to_edge_transform_and_lower(
76+
before,
77+
compile_config=EdgeCompileConfig(
78+
_skip_dim_order=False, _check_ir_validity=False
79+
),
80+
partitioner=[XnnpackPartitioner()],
81+
)
82+
else:
83+
epm = to_edge(
84+
before, compile_config=EdgeCompileConfig(_skip_dim_order=False)
85+
)
86+
87+
# check memory format ops, if needed
88+
if test_set.op_level_check:
89+
aten_op_str = "torch.ops.aten._to_copy.default"
90+
edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
7591

76-
epm = to_edge(before, compile_config=EdgeCompileConfig(_skip_dim_order=False))
92+
# check op strings before
93+
FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
94+
edge_op_str
95+
).run(before.graph_module.code)
7796

78-
# check op strings
79-
FileCheck().check_not(aten_op_str).check_count(
80-
edge_op_str, 1, exactly=True
81-
).run(epm.exported_program().graph_module.code)
97+
# check op strings
98+
FileCheck().check_not(aten_op_str).check_count(
99+
edge_op_str, 1, exactly=True
100+
).run(epm.exported_program().graph_module.code)
82101

83102
# check EdgeOp and the new BackendOp should behave the same
84103
expected = before.module()(*test_set.sample_input)
85104
actual = epm.exported_program().module()(*test_set.sample_input)
86-
test_class.assertTrue(torch.allclose(actual, expected))
105+
test_class.assertTrue(
106+
torch.allclose(actual, expected, atol=test_set.atol, rtol=test_set.rtol)
107+
)
87108
test_class.assertEqual(
88109
is_channel_last_dim_order(actual),
89110
is_channel_last_dim_order(expected),
@@ -105,7 +126,11 @@ def memory_format_test_runner(
105126
runtime_output = executorch_module.run_method(
106127
"forward", tuple(inputs_flattened)
107128
)[0]
108-
test_class.assertTrue(torch.allclose(runtime_output, expected))
129+
test_class.assertTrue(
130+
torch.allclose(
131+
runtime_output, expected, atol=test_set.atol, rtol=test_set.rtol
132+
)
133+
)
109134
test_class.assertEqual(
110135
is_channel_last_dim_order(runtime_output),
111136
is_channel_last_dim_order(expected),

0 commit comments

Comments
 (0)