
Commit d1dd975

digantdesai authored and facebook-github-bot committed
Add a pass to replace memory_format ops
Summary: The goal is to replace aten ops that take torch.memory_format as an argument with edge dialect ops that instead take dim_order. This is a step toward the larger goal of moving away from memory_format to dim_order.

Pending items for AOT:

* Not running the pass yet, but the plan is to run it through `to_edge()`.
* Not all of the ops that consume a memory format are implemented yet, e.g.:
  * torch.clone
  * tensor.full_like
* Serialization - once this is done, support for memory_format needs to be removed altogether [BC breaking].

Reviewed By: larryliu0820

Differential Revision: D48195055

fbshipit-source-id: 5c623a6914043da80069e27d8094c5f3ce398c06
1 parent 1ac53ee commit d1dd975
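
Editor's note for context: the change hinges on a fixed correspondence between a torch.memory_format and a dim_order permutation. The sketch below is a hypothetical illustration of that mapping, not the actual get_dim_order from executorch.exir.dim_order_utils that the new pass imports; it covers only the two formats exercised in this commit.

# Hypothetical illustration of the memory_format -> dim_order mapping;
# not the real executorch.exir.dim_order_utils.get_dim_order.
from typing import List

import torch


def sketch_get_dim_order(memory_format: torch.memory_format, ndim: int) -> List[int]:
    # contiguous_format keeps dimensions in declaration order
    if memory_format == torch.contiguous_format:
        return list(range(ndim))
    # channels_last (rank 4 only) stores NCHW data in NHWC order,
    # so the channel dimension (dim 1) becomes the innermost one
    if memory_format == torch.channels_last and ndim == 4:
        return [0, 2, 3, 1]
    raise NotImplementedError(f"unsupported: {memory_format} at rank {ndim}")


assert sketch_get_dim_order(torch.contiguous_format, 3) == [0, 1, 2]
assert sketch_get_dim_order(torch.channels_last, 4) == [0, 2, 3, 1]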

File tree

exir/passes/TARGETS
exir/passes/__init__.py
exir/passes/memory_format_ops_pass.py
exir/tests/TARGETS
exir/tests/test_memory_format_ops_pass.py

5 files changed: +171 -0 lines changed

exir/passes/TARGETS

Lines changed: 16 additions & 0 deletions
@@ -10,6 +10,7 @@ python_library(
     deps = [
         ":const_prop_pass",
         ":debug_handle_generator_pass",
+        ":memory_format_ops_pass",
         ":memory_planning_pass",
         ":normalize_transpose_pass",
         ":pass_registry",
@@ -262,5 +263,20 @@ python_library(
     deps = [
         "//caffe2:torch",
         "//executorch/exir:dim_order_utils",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
+python_library(
+    name = "memory_format_ops_pass",
+    srcs = [
+        "memory_format_ops_pass.py",
+    ],
+    deps = [
+        ":dim_order_ops_registry",
+        "//caffe2:torch",
+        "//executorch/exir:dim_order_utils",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects/edge:lib",
     ],
 )

exir/passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,7 @@
 from executorch.exir.passes.debug_handle_generator_pass import DebugHandleGeneratorPass

 from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
+from executorch.exir.passes.memory_format_ops_pass import MemoryFormatOpsPass
 from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
 from executorch.exir.passes.normalize_transpose_pass import NormalizeTransposePass
 from executorch.exir.passes.pass_registry import PassRegistry
@@ -62,6 +63,7 @@
     "QuantFusionPass",
     "OpReplacePass",
     "EdgeToBackendOpsPass",
+    "MemoryFormatOpsPass",
     "SymShapeEvalPass",
 ]

exir/passes/memory_format_ops_pass.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import logging
+
+import torch
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.dim_order_utils import get_dim_order
+from executorch.exir.pass_base import ExportPass
+from executorch.exir.passes.dim_order_ops_registry import DimOrderOpsMap
+
+from torch._export.pass_infra.proxy_value import ProxyValue
+
+logger = logging.getLogger(__file__)
+logger.setLevel(logging.INFO)
+
+
+class MemoryFormatOpsPass(ExportPass):
+    """
+    This pass replaces ops that take torch.memory_format as an argument with
+    'equivalent' ops that take dim_order. This is towards the larger Executorch
+    goal of moving away from torch.memory_format. There is a 1:1 mapping between
+    the aten op and the new edge dialect dim_order op.
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if not (isinstance(op, EdgeOpOverload) and op.__name__ in DimOrderOpsMap):
+            return super().call_operator(
+                op,
+                args,
+                kwargs,
+                meta,
+            )
+        # new kwargs with dim_order, and no memory_format for the new op
+        nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable
+
+        # get the "to" memory format for the EdgeOp
+        mem_format = nkwargs.pop("memory_format", torch.contiguous_format)
+
+        # can always get the shape, assuming rank is specialized
+        if isinstance(args[0], ProxyValue) and args[0].is_tensor():
+            ndim = args[0].to_tensor().dim()
+        elif isinstance(args[0], torch.Tensor):
+            ndim = args[0].dim()
+        else:
+            assert 0, f"Expecting a Tensor or a ProxyValue but got {type(args[0])}"
+
+        nkwargs["dim_order"] = get_dim_order(mem_format, ndim)
+        logger.debug(
+            f"_to_copy = rank: {ndim}, memory_format: {mem_format}."
+            f" _to_dim_order_copy = dim_order: {nkwargs['dim_order']}"
+        )
+
+        t = DimOrderOpsMap[op.__name__]
+
+        return super().call_operator(
+            t,
+            args,
+            nkwargs,
+            meta,
+        )
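
Editor's note: the core of call_operator above is a kwarg rewrite - pop memory_format, then substitute the equivalent dim_order computed from the input's rank. Below is a minimal sketch of that step in isolation; it assumes get_dim_order(memory_format, ndim), as imported and called in the pass above, returns the permutation list (e.g. [0, 2, 3, 1] for channels_last at rank 4).

# Sketch of the kwarg rewrite performed by MemoryFormatOpsPass.call_operator,
# shown in isolation; mirrors the pass body above.
import copy

import torch
from executorch.exir.dim_order_utils import get_dim_order

kwargs = {"dtype": torch.double, "memory_format": torch.channels_last}
nkwargs = dict(copy.deepcopy(kwargs))

# drop memory_format, defaulting to contiguous when absent
mem_format = nkwargs.pop("memory_format", torch.contiguous_format)

# replace it with the equivalent dim_order for a rank-4 input
nkwargs["dim_order"] = get_dim_order(mem_format, 4)
# expected result: {"dtype": torch.double, "dim_order": [0, 2, 3, 1]}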

exir/tests/TARGETS

Lines changed: 13 additions & 0 deletions
@@ -432,3 +432,16 @@ python_unittest(
         "//executorch/exir:dim_order_utils",
     ],
 )
+
+python_unittest(
+    name = "memory_format_ops_pass",
+    srcs = [
+        "test_memory_format_ops_pass.py",
+    ],
+    supports_static_listing = True,
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:lib",
+        "//executorch/exir/passes:lib",
+    ],
+)
exir/tests/test_memory_format_ops_pass.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch import exir
+from executorch.exir import CaptureConfig, EdgeCompileConfig
+from executorch.exir.passes import MemoryFormatOpsPass
+from torch.testing import FileCheck
+
+
+class TestMemoryFormatOpsPass(unittest.TestCase):
+    def test_op_to_copy_replacement(self) -> None:
+        class F(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(
+                self, x: torch.Tensor, mem_format: torch.memory_format
+            ) -> torch.Tensor:
+                return x.to(dtype=torch.double, memory_format=mem_format)
+
+        module = F().eval()
+        sample_inputs = [
+            (torch.randn([2, 2], dtype=torch.float32), torch.contiguous_format),
+            (torch.randn([2, 2, 2], dtype=torch.float32), torch.contiguous_format),
+            (torch.randn([2, 2, 2, 2], dtype=torch.float32), torch.channels_last),
+            (
+                torch.rand_like(
+                    torch.zeros([2, 2, 2, 2]),
+                    dtype=torch.float32,
+                    memory_format=torch.channels_last,
+                ),
+                torch.contiguous_format,
+            ),
+        ]
+
+        aten_op_str = "torch.ops.aten._to_copy.default"
+        edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
+
+        for sample_input in sample_inputs:
+            before = exir.capture(
+                module,
+                sample_input,
+                CaptureConfig(pt2_mode=True, enable_dynamic_shape=True),
+            )
+
+            # check op strings before the pass
+            FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
+                edge_op_str
+            ).run(before.exported_program.graph_module.code)
+
+            ep = before.to_edge(
+                config=EdgeCompileConfig(_use_edge_ops=True)
+            )  # Only replacing edge_ops
+
+            # Run the pass - TODO move this into to_edge passes
+            after = ep.transform(MemoryFormatOpsPass())
+
+            # check op strings after the pass
+            FileCheck().check_not(aten_op_str).check_count(
+                edge_op_str, 1, exactly=True
+            ).run(after.exported_program.graph_module.code)
+
+            # check that the EdgeOp and the new BackendOp behave the same
+            expected = before(*sample_input)
+            actual = after(*sample_input)
+            self.assertTrue(torch.allclose(actual, expected))
+
+            # TODO - more
+            after.to_executorch()
