Skip to content

Commit aafffa2

Browse files
Gasoonjiafacebook-github-bot
authored andcommitted
introduce model-level end2end tests to dim order tests with different delegate (#6093)
Summary: This diff introduced end2end tests on several models + delegation combinations. Models: llama2, resnet18, mobilenet_v3 Delegate: no delegate, xnnpack Differential Revision: D64174329
1 parent df5b2ab commit aafffa2

File tree

4 files changed

+130
-15
lines changed

4 files changed

+130
-15
lines changed

exir/tests/TARGETS

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ python_unittest(
377377
":test_memory_format_ops_pass_utils",
378378
"//caffe2:torch",
379379
"//executorch/extension/pybindings:aten_lib", # @manual
380+
"//pytorch/vision:torchvision", # @manual
380381
],
381382
)
382383

@@ -394,6 +395,7 @@ python_unittest(
394395
"//executorch/exir/dialects:lib",
395396
"//executorch/exir/dialects/edge:lib",
396397
"//executorch/extension/pybindings:portable_lib", # @manual
398+
"//pytorch/vision:torchvision", # @manual
397399
],
398400
)
399401

@@ -404,6 +406,7 @@ python_library(
404406
],
405407
deps = [
406408
"//caffe2:torch",
409+
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
407410
"//executorch/exir:dim_order_utils",
408411
"//executorch/exir:lib",
409412
"//executorch/exir/capture:config",

exir/tests/test_memory_format_ops_pass.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
from torch.export import export
3737
from torch.testing import FileCheck
3838

39+
import torchvision
40+
3941

4042
class TestMemoryFormatOpsPass(unittest.TestCase):
4143
def test_op_to_copy_replacement_2d(self) -> None:
@@ -264,3 +266,67 @@ def call_operator(self, op, args, kwargs, meta):
264266

265267
self.assertTrue(is_contiguous_dim_order(actual))
266268
self.assertTrue(is_contiguous_dim_order(expected))
269+
270+
271+
def test_resnet18(self) -> None:
272+
model = torchvision.models.resnet18()
273+
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
274+
self,
275+
MemoryFormatTestSet(
276+
module=model.eval(),
277+
sample_input=(torch.randn(1, 3, 224, 224),),
278+
target_memory_format=torch.contiguous_format,
279+
op_level_check=False,
280+
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
281+
atol=1e-3,
282+
rtol=1e-3,
283+
),
284+
)
285+
286+
287+
def test_resnet18_xnnpack(self) -> None:
288+
model = torchvision.models.resnet18()
289+
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
290+
self,
291+
MemoryFormatTestSet(
292+
module=model.eval(),
293+
sample_input=(torch.randn(1, 3, 224, 224),),
294+
target_memory_format=torch.contiguous_format,
295+
op_level_check=False,
296+
use_xnnpack=True,
297+
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
298+
atol=1e-3,
299+
rtol=1e-3,
300+
),
301+
)
302+
303+
def test_mobilenet_v3(self) -> None:
304+
model = torchvision.models.mobilenetv3.mobilenet_v3_small(pretrained=True)
305+
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
306+
self,
307+
MemoryFormatTestSet(
308+
module=model.eval(),
309+
sample_input=(torch.randn(1, 3, 224, 224),),
310+
target_memory_format=torch.contiguous_format,
311+
op_level_check=False,
312+
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
313+
atol=1e-3,
314+
rtol=1e-3,
315+
),
316+
)
317+
318+
def test_mobilenet_v3_xnnpack(self) -> None:
319+
model = torchvision.models.mobilenetv3.mobilenet_v3_small(pretrained=True)
320+
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
321+
self,
322+
MemoryFormatTestSet(
323+
module=model.eval(),
324+
sample_input=(torch.randn(1, 3, 224, 224),),
325+
target_memory_format=torch.contiguous_format,
326+
op_level_check=False,
327+
use_xnnpack=True,
328+
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
329+
atol=1e-3,
330+
rtol=1e-3,
331+
),
332+
)

exir/tests/test_memory_format_ops_pass_aten.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from executorch.extension.pybindings.aten_lib import ( # @manual
2020
_load_for_executorch_from_buffer,
2121
)
22+
import torchvision
23+
2224

2325

2426
class TestMemoryFormatOpsPass(unittest.TestCase):
@@ -77,3 +79,34 @@ def test_op_dim_order_propagation_aten(self) -> None:
7779
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
7880
),
7981
)
82+
83+
def test_resnet18(self) -> None:
84+
model = torchvision.models.resnet18()
85+
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
86+
self,
87+
MemoryFormatTestSet(
88+
module=model.eval(),
89+
sample_input=(torch.randn(1, 3, 224, 224),),
90+
target_memory_format=torch.contiguous_format,
91+
op_level_check=False,
92+
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
93+
atol=1e-3,
94+
rtol=1e-3,
95+
),
96+
)
97+
98+
99+
def test_mobilenet_v3(self) -> None:
100+
model = torchvision.models.mobilenetv3.mobilenet_v3_small(pretrained=True)
101+
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
102+
self,
103+
MemoryFormatTestSet(
104+
module=model.eval(),
105+
sample_input=(torch.randn(1, 3, 224, 224),),
106+
target_memory_format=torch.contiguous_format,
107+
op_level_check=False,
108+
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
109+
atol=1e-3,
110+
rtol=1e-3,
111+
),
112+
)

exir/tests/test_memory_format_ops_pass_utils.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@
1111
from typing import Any, Tuple
1212

1313
import torch
14-
from executorch.exir import to_edge
14+
from executorch.exir import to_edge, to_edge_transform_and_lower
1515
from executorch.exir.capture._config import EdgeCompileConfig
1616

1717
from executorch.exir.dim_order_utils import (
1818
is_channel_last_dim_order,
1919
is_contiguous_dim_order,
2020
)
2121

22+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
23+
2224
from torch.export import export
2325
from torch.testing import FileCheck
2426
from torch.utils._pytree import tree_flatten
@@ -30,6 +32,10 @@ class MemoryFormatTestSet:
3032
sample_input: Tuple[Any, ...]
3133
target_memory_format: torch.memory_format
3234
_load_for_executorch_from_buffer: Any
35+
op_level_check: bool = True
36+
use_xnnpack: bool = False
37+
rtol: float = 1e-05
38+
atol: float = 1e-08
3339

3440

3541
class SimpleToCopyContiguousModule(torch.nn.Module):
@@ -63,27 +69,34 @@ class MemoryFormatOpsPassTestUtils:
6369
def memory_format_test_runner(
6470
test_class: unittest.TestCase, test_set: MemoryFormatTestSet
6571
):
66-
aten_op_str = "torch.ops.aten._to_copy.default"
67-
edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
68-
6972
before = export(test_set.module, test_set.sample_input)
7073

71-
# check op strings before
72-
FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
73-
edge_op_str
74-
).run(before.graph_module.code)
74+
if test_set.use_xnnpack:
75+
epm = to_edge_transform_and_lower(
76+
before, compile_config=EdgeCompileConfig(_skip_dim_order=False, _check_ir_validity=False), partitioner=[XnnpackPartitioner()])
77+
else:
78+
epm = to_edge(before, compile_config=EdgeCompileConfig(_skip_dim_order=False))
79+
80+
# check memory format ops, if needed
81+
if test_set.op_level_check:
82+
aten_op_str = "torch.ops.aten._to_copy.default"
83+
edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
84+
85+
# check op strings before
86+
FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
87+
edge_op_str
88+
).run(before.graph_module.code)
7589

76-
epm = to_edge(before, compile_config=EdgeCompileConfig(_skip_dim_order=False))
90+
# check op strings
91+
FileCheck().check_not(aten_op_str).check_count(
92+
edge_op_str, 1, exactly=True
93+
).run(epm.exported_program().graph_module.code)
7794

78-
# check op strings
79-
FileCheck().check_not(aten_op_str).check_count(
80-
edge_op_str, 1, exactly=True
81-
).run(epm.exported_program().graph_module.code)
8295

8396
# check EdgeOp and the new BackendOp should behave the same
8497
expected = before.module()(*test_set.sample_input)
8598
actual = epm.exported_program().module()(*test_set.sample_input)
86-
test_class.assertTrue(torch.allclose(actual, expected))
99+
test_class.assertTrue(torch.allclose(actual, expected, atol=test_set.atol, rtol=test_set.rtol))
87100
test_class.assertEqual(
88101
is_channel_last_dim_order(actual),
89102
is_channel_last_dim_order(expected),
@@ -105,7 +118,7 @@ def memory_format_test_runner(
105118
runtime_output = executorch_module.run_method(
106119
"forward", tuple(inputs_flattened)
107120
)[0]
108-
test_class.assertTrue(torch.allclose(runtime_output, expected))
121+
test_class.assertTrue(torch.allclose(runtime_output, expected,atol=test_set.atol, rtol=test_set.rtol))
109122
test_class.assertEqual(
110123
is_channel_last_dim_order(runtime_output),
111124
is_channel_last_dim_order(expected),

0 commit comments

Comments
 (0)