
Commit 7af41a3

Gasoonjia authored and facebook-github-bot committed

Introduce model-level end-to-end tests for dim order with different delegates

Summary: This diff introduces end-to-end tests covering several model + delegate combinations.

Models: llama2, resnet18, mobilenet_v3
Delegates: no delegate, XNNPACK

Differential Revision: D64174329

1 parent 83c95df · commit 7af41a3

File tree

exir/tests/TARGETS
exir/tests/test_memory_format_ops_pass.py
exir/tests/test_memory_format_ops_pass_aten.py
exir/tests/test_memory_format_ops_pass_utils.py

4 files changed: +180 −15 lines

exir/tests/TARGETS

Lines changed: 5 additions & 0 deletions
@@ -376,7 +376,9 @@ python_unittest(
     deps = [
         ":test_memory_format_ops_pass_utils",
         "//caffe2:torch",
+        "//executorch/examples/models/llama2:llama2_model",
         "//executorch/extension/pybindings:aten_lib",  # @manual
+        "//pytorch/vision:torchvision",  # @manual
     ],
 )

@@ -393,7 +395,9 @@ python_unittest(
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",
         "//executorch/exir/dialects/edge:lib",
+        "//executorch/examples/models/llama2:llama2_model",
         "//executorch/extension/pybindings:portable_lib",  # @manual
+        "//pytorch/vision:torchvision",  # @manual
     ],
 )

@@ -404,6 +408,7 @@ python_library(
     ],
     deps = [
         "//caffe2:torch",
+        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
         "//executorch/exir:dim_order_utils",
         "//executorch/exir:lib",
         "//executorch/exir/capture:config",

exir/tests/test_memory_format_ops_pass.py

Lines changed: 98 additions & 0 deletions
@@ -21,6 +21,8 @@
 )
 from executorch.exir.pass_base import ExportPass, ProxyValue

+from executorch.examples.models.llama2.model import Llama2Model
+
 from executorch.exir.tests.test_memory_format_ops_pass_utils import (
     MemoryFormatOpsPassTestUtils,
     MemoryFormatTestSet,

@@ -36,6 +38,8 @@
 from torch.export import export
 from torch.testing import FileCheck

+import torchvision
+

 class TestMemoryFormatOpsPass(unittest.TestCase):
     def test_op_to_copy_replacement_2d(self) -> None:

@@ -264,3 +268,97 @@ def call_operator(self, op, args, kwargs, meta):

         self.assertTrue(is_contiguous_dim_order(actual))
         self.assertTrue(is_contiguous_dim_order(expected))
+
+    def test_llama2(self) -> None:
+        llama2 = Llama2Model()
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=llama2.get_eager_model().eval(),
+                sample_input=llama2.get_example_inputs(),
+                target_memory_format=torch.contiguous_format,
+                op_level_check=False,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+                atol=5e-2,
+            ),
+        )
+
+    def test_llama2_xnnpack(self) -> None:
+        llama2 = Llama2Model()
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=llama2.get_eager_model().eval(),
+                sample_input=llama2.get_example_inputs(),
+                target_memory_format=torch.contiguous_format,
+                op_level_check=False,
+                use_xnnpack=True,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+                atol=5e-2,
+            ),
+        )
+
+    def test_resnet18(self) -> None:
+        model = torchvision.models.resnet18()
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                sample_input=(torch.randn(1, 3, 224, 224),),
+                target_memory_format=torch.contiguous_format,
+                op_level_check=False,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+                atol=1e-3,
+                rtol=1e-3,
+            ),
+        )
+
+    def test_resnet18_xnnpack(self) -> None:
+        model = torchvision.models.resnet18()
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                sample_input=(torch.randn(1, 3, 224, 224),),
+                target_memory_format=torch.contiguous_format,
+                op_level_check=False,
+                use_xnnpack=True,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+                atol=1e-3,
+                rtol=1e-3,
+            ),
+        )
+
+    def test_mobilenet_v3(self) -> None:
+        model = torchvision.models.mobilenetv3.mobilenet_v3_small(pretrained=True)
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                sample_input=(torch.randn(1, 3, 224, 224),),
+                target_memory_format=torch.contiguous_format,
+                op_level_check=False,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+                atol=1e-3,
+                rtol=1e-3,
+            ),
+        )
+
+    def test_mobilenet_v3_xnnpack(self) -> None:
+        model = torchvision.models.mobilenetv3.mobilenet_v3_small(pretrained=True)
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                sample_input=(torch.randn(1, 3, 224, 224),),
+                target_memory_format=torch.contiguous_format,
+                op_level_check=False,
+                use_xnnpack=True,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+                atol=1e-3,
+                rtol=1e-3,
+            ),
+        )
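The per-model tolerances above feed directly into torch.allclose, which passes when |actual − expected| ≤ atol + rtol·|expected| elementwise. A quick standalone sketch of why llama2's loosened atol=5e-2 matters (the tensor values here are illustrative, not from this diff):

    import torch

    expected = torch.tensor([1.0, 2.0])
    actual = expected + 0.04  # absolute error within llama2's atol=5e-2 budget

    # The defaults (rtol=1e-05, atol=1e-08, as in MemoryFormatTestSet) reject
    # a 0.04 absolute error...
    assert not torch.allclose(actual, expected)
    # ...while the atol used by the llama2 tests accepts it.
    assert torch.allclose(actual, expected, atol=5e-2)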

exir/tests/test_memory_format_ops_pass_aten.py

Lines changed: 49 additions & 0 deletions
@@ -20,6 +20,10 @@
     _load_for_executorch_from_buffer,
 )

+from executorch.examples.models.llama2.model import Llama2Model
+import torchvision
+
+

 class TestMemoryFormatOpsPass(unittest.TestCase):
     def test_op_to_copy_replacement_2d_aten(self) -> None:

@@ -77,3 +81,48 @@ def test_op_dim_order_propagation_aten(self) -> None:
                 _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
             ),
         )
+
+    def test_llama2(self) -> None:
+        llama2 = Llama2Model()
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=llama2.get_eager_model().eval(),
+                sample_input=llama2.get_example_inputs(),
+                target_memory_format=torch.contiguous_format,
+                op_level_check=False,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+                atol=5e-2,
+            ),
+        )
+
+    def test_resnet18(self) -> None:
+        model = torchvision.models.resnet18()
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                sample_input=(torch.randn(1, 3, 224, 224),),
+                target_memory_format=torch.contiguous_format,
+                op_level_check=False,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+                atol=1e-3,
+                rtol=1e-3,
+            ),
+        )
+
+    def test_mobilenet_v3(self) -> None:
+        model = torchvision.models.mobilenetv3.mobilenet_v3_small(pretrained=True)
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                sample_input=(torch.randn(1, 3, 224, 224),),
+                target_memory_format=torch.contiguous_format,
+                op_level_check=False,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+                atol=1e-3,
+                rtol=1e-3,
+            ),
+        )

exir/tests/test_memory_format_ops_pass_utils.py

Lines changed: 28 additions & 15 deletions
@@ -11,14 +11,16 @@
 from typing import Any, Tuple

 import torch
-from executorch.exir import to_edge
+from executorch.exir import to_edge, to_edge_transform_and_lower
 from executorch.exir.capture._config import EdgeCompileConfig

 from executorch.exir.dim_order_utils import (
     is_channel_last_dim_order,
     is_contiguous_dim_order,
 )

+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+
 from torch.export import export
 from torch.testing import FileCheck
 from torch.utils._pytree import tree_flatten

@@ -30,6 +32,10 @@ class MemoryFormatTestSet:
     sample_input: Tuple[Any, ...]
     target_memory_format: torch.memory_format
     _load_for_executorch_from_buffer: Any
+    op_level_check: bool = True
+    use_xnnpack: bool = False
+    rtol: float = 1e-05
+    atol: float = 1e-08


 class SimpleToCopyContiguousModule(torch.nn.Module):

@@ -63,27 +69,34 @@ class MemoryFormatOpsPassTestUtils:
     def memory_format_test_runner(
         test_class: unittest.TestCase, test_set: MemoryFormatTestSet
     ):
-        aten_op_str = "torch.ops.aten._to_copy.default"
-        edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
-
         before = export(test_set.module, test_set.sample_input)

-        # check op strings before
-        FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
-            edge_op_str
-        ).run(before.graph_module.code)
+        if test_set.use_xnnpack:
+            epm = to_edge_transform_and_lower(
+                before,
+                compile_config=EdgeCompileConfig(
+                    _skip_dim_order=False, _check_ir_validity=False
+                ),
+                partitioner=[XnnpackPartitioner()],
+            )
+        else:
+            epm = to_edge(before, compile_config=EdgeCompileConfig(_skip_dim_order=False))
+
+        # check memory format ops, if needed
+        if test_set.op_level_check:
+            aten_op_str = "torch.ops.aten._to_copy.default"
+            edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
+
+            # check op strings before
+            FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
+                edge_op_str
+            ).run(before.graph_module.code)

-        epm = to_edge(before, compile_config=EdgeCompileConfig(_skip_dim_order=False))
+            # check op strings
+            FileCheck().check_not(aten_op_str).check_count(
+                edge_op_str, 1, exactly=True
+            ).run(epm.exported_program().graph_module.code)

-        # check op strings
-        FileCheck().check_not(aten_op_str).check_count(
-            edge_op_str, 1, exactly=True
-        ).run(epm.exported_program().graph_module.code)

         # check EdgeOp and the new BackendOp should behave the same
         expected = before.module()(*test_set.sample_input)
         actual = epm.exported_program().module()(*test_set.sample_input)
-        test_class.assertTrue(torch.allclose(actual, expected))
+        test_class.assertTrue(torch.allclose(actual, expected, atol=test_set.atol, rtol=test_set.rtol))
         test_class.assertEqual(
             is_channel_last_dim_order(actual),
             is_channel_last_dim_order(expected),

@@ -105,7 +118,7 @@ def memory_format_test_runner(
         runtime_output = executorch_module.run_method(
             "forward", tuple(inputs_flattened)
         )[0]
-        test_class.assertTrue(torch.allclose(runtime_output, expected))
+        test_class.assertTrue(torch.allclose(runtime_output, expected, atol=test_set.atol, rtol=test_set.rtol))
         test_class.assertEqual(
             is_channel_last_dim_order(runtime_output),
             is_channel_last_dim_order(expected),
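Taken together, adding coverage for another model reduces to constructing a MemoryFormatTestSet with the new knobs and handing it to the shared runner. A minimal sketch, assuming the same utilities and the portable_lib pybindings used by the non-aten test file above; the mobilenet_v2 choice, the test class, and the test name are illustrative, not part of this diff:

    import unittest

    import torch
    import torchvision

    from executorch.exir.tests.test_memory_format_ops_pass_utils import (
        MemoryFormatOpsPassTestUtils,
        MemoryFormatTestSet,
    )
    from executorch.extension.pybindings.portable_lib import (
        _load_for_executorch_from_buffer,
    )


    class TestAnotherModel(unittest.TestCase):  # hypothetical test class
        def test_mobilenet_v2_xnnpack(self) -> None:
            # Model-level run: skip the per-op FileCheck assertions
            # (op_level_check=False), delegate partitioned subgraphs to
            # XNNPACK, and loosen tolerances, mirroring the tests above.
            model = torchvision.models.mobilenet_v2().eval()
            MemoryFormatOpsPassTestUtils.memory_format_test_runner(
                self,
                MemoryFormatTestSet(
                    module=model,
                    sample_input=(torch.randn(1, 3, 224, 224),),
                    target_memory_format=torch.contiguous_format,
                    op_level_check=False,
                    use_xnnpack=True,
                    _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
                    atol=1e-3,
                    rtol=1e-3,
                ),
            )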
