
Commit 3943581

[et][dim order] aot support for dim order variant empty op
Pull Request resolved: #7168

This diff adds AOT support for the dim-order variant of the empty op, including the operator implementation and registration, an update to the memory-format pass, and end-to-end tests in both ATen and lean modes.

ghstack-source-id: 256440140
Differential Revision: [D66738618](https://our.internmc.facebook.com/intern/diff/D66738618/)

Co-authored-by: gasoonjia <[email protected]>
1 parent c4054f1 commit 3943581

File tree: 5 files changed (+140, -15 lines)
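For context on the terminology in this diff: a dim_order is a permutation of dimension indices describing how a tensor's dimensions are laid out in memory, and it is (roughly) the edge-dialect replacement for torch.memory_format. The snippet below is an illustrative sketch, not part of this commit, that recovers a tensor's dim order from its strides:

```python
import torch

def dim_order_from_strides(t: torch.Tensor):
    # Sort dimensions from outermost to innermost (largest stride first).
    # This simple sketch ignores stride ties, which the real ExecuTorch
    # utilities handle more carefully.
    return sorted(range(t.dim()), key=lambda d: -t.stride(d))

x = torch.empty(1, 10, 24, 24)  # default contiguous (NCHW-style) layout
y = torch.empty(1, 10, 24, 24, memory_format=torch.channels_last)

print(dim_order_from_strides(x))  # [0, 1, 2, 3]
print(dim_order_from_strides(y))  # [0, 2, 3, 1]
```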

exir/passes/dim_order_ops_registry.py (21 additions, 1 deletion)

@@ -15,11 +15,19 @@
     "_to_dim_order_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, int[]? dim_order=None) -> Tensor"
 )
 
-# Out variant drops TensorOptions
+lib.define(
+    "_empty_dim_order(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, int[]? dim_order=None) -> Tensor"
+)
+
+# Out variant of aten::_to_copy and aten::empty drops TensorOptions, so do their dim order variants
 lib.define(
     "_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
+)
+
 
 def _op_impl(target, *args, **kwargs):
     kwargs["memory_format"] = get_memory_format(kwargs.get("dim_order", None))
@@ -39,18 +47,30 @@ def _to_dim_order_copy_out_impl(*args, **kwargs):
     return _op_impl(torch.ops.aten._to_copy.out, *args, **kwargs)
 
 
+@impl(lib, "_empty_dim_order", "CompositeImplicitAutograd")
+def _empty_dim_order_impl(*args, **kwargs):
+    return _op_impl(torch.ops.aten.empty.memory_format, *args, **kwargs)
+
+
+@impl(lib, "_empty_dim_order.out", "CompositeImplicitAutograd")
+def _empty_dim_order_out_impl(*args, **kwargs):
+    return _op_impl(torch.ops.aten.empty.out, *args, **kwargs)
+
+
 """
 Defines a map of aten or edge ops to the corresponding dim_order ops for quick lookup
 """
 DimOrderOpsMap = {
     "aten._to_copy.default": exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+    "aten.empty.memory_format": exir_ops.edge.dim_order_ops._empty_dim_order.default,
 }
 
 """
 Defines a map of aten or edge ops to the corresponding memory format ops for quick lookup
 """
 MemoryFormatOpsMap = {
     "dim_order_ops._to_dim_order_copy.default": exir_ops.edge.aten._to_copy.default,
+    "dim_order_ops._empty_dim_order.default": exir_ops.edge.aten.empty.memory_format,
 }
 
 # If we are replacing an aten op with a dim_order op, we must have a 1:1 mapping through these dicts.
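The two new CompositeImplicitAutograd registrations route the dim-order variant through _op_impl, which converts the dim_order kwarg into a memory_format and dispatches to the underlying aten op. A rough standalone sketch of that behavior for the default (non-out) overload, using a hand-rolled dim_order-to-memory_format mapping in place of ExecuTorch's own helper:

```python
import torch

def _empty_dim_order_sketch(size, *, dtype=None, dim_order=None):
    # Assumed convention: None or [0, 1, ..., n-1] means contiguous;
    # [0, 2, 3, 1] on a 4-D size means channels_last.
    if dim_order == [0, 2, 3, 1]:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format
    return torch.empty(size, dtype=dtype, memory_format=memory_format)

t = _empty_dim_order_sketch([1, 10, 24, 24], dim_order=[0, 2, 3, 1])
assert t.is_contiguous(memory_format=torch.channels_last)
```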

exir/passes/memory_format_ops_pass.py (15 additions, 10 deletions)

@@ -39,6 +39,7 @@ def call_operator(self, op, args, kwargs, meta):
                 kwargs,
                 meta,
             )
+
         # new kwargs with dim_order, and no memory_format for the new op
         nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable
 
@@ -50,17 +51,20 @@ def call_operator(self, op, args, kwargs, meta):
             ndim = args[0].to_tensor().dim()
         elif isinstance(args[0], torch.Tensor):
             ndim = args[0].dim()
+        elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
+            ndim = len(args[0])
         else:
-            assert 0, f"Expecting a Tensor or a ProxyValue buy got {type(args[0])}"
+            assert (
+                0
+            ), f"Expecting a Tensor, a ProxyValue, or a Sequence, but got {type(args[0])}"
 
         nkwargs["dim_order"] = get_dim_order(mem_format, ndim)
         logger.debug(
-            f"_to_copy = rank: {ndim}, memory_format: {mem_format}."
-            f" _to_dim_order_copy = dim_order: {nkwargs['dim_order']}"
+            f"{op.__name__} = rank: {ndim}, memory_format: {mem_format}."
+            f" {DimOrderOpsMap[op.__name__].__name__} = dim_order: {nkwargs['dim_order']}"
         )
 
-        t = DimOrderOpsMap.get(op.__name__, None)
-        assert t is not None, f"{op.__name__} not found in DimOrderOpsMap"
+        t = DimOrderOpsMap[op.__name__]
 
         return super().call_operator(
             t,
@@ -92,8 +96,10 @@ def call_operator(self, op, args, kwargs, meta):
             ndim = args[0].to_tensor().dim()
         elif isinstance(args[0], torch.Tensor):
             ndim = args[0].dim()
+        elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
+            ndim = len(args[0])
         else:
-            assert 0, f"Expecting a Tensor or a ProxyValue buy got {type(args[0])}"
+            assert 0, f"Expecting a Tensor or a ProxyValue but got {type(args[0])}"
 
         # get the "to" memory format for the EdgeOp
         default_dim_order = list(range(ndim))
@@ -102,12 +108,11 @@ def call_operator(self, op, args, kwargs, meta):
         nkwargs["memory_format"] = get_memory_format(dim_order)
 
         logger.debug(
-            f" _to_dim_order_copy = dim_order: {dim_order}."
-            f"_to_copy = rank: {ndim}, memory_format: {nkwargs['memory_format']}."
+            f" {op.__name__} = dim_order: {dim_order}."
+            f" {MemoryFormatOpsMap[op.__name__].__name__} = rank: {ndim}, memory_format: {nkwargs['memory_format']}."
        )
 
-        t = MemoryFormatOpsMap.get(op.__name__, None)
-        assert t is not None, f"{op.__name__} not found in MemoryFormatOpsMap"
+        t = MemoryFormatOpsMap[op.__name__]
 
         return super().call_operator(
             t,
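The new immutable_list branch exists because the two mapped ops find their rank in different places: aten._to_copy takes the input tensor as its first argument, while aten.empty.memory_format takes a size list. A minimal sketch of that dispatch (the helper name here is illustrative, not the pass's actual code):

```python
import torch

def infer_rank(first_arg):
    if isinstance(first_arg, torch.Tensor):
        return first_arg.dim()   # _to_copy-style ops: first arg is the input tensor
    if isinstance(first_arg, (list, tuple)):
        return len(first_arg)    # empty-style ops: first arg is the size list
    raise AssertionError(f"Expecting a Tensor or a size list, got {type(first_arg)}")

assert infer_rank(torch.zeros(2, 3, 4)) == 3
assert infer_rank([1, 10, 24, 24]) == 4
```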

exir/tests/test_memory_format_ops_pass.py (34 additions, 0 deletions)

@@ -27,6 +27,8 @@
     MemoryFormatOpsPassTestUtils,
     MemoryFormatTestSet,
     PropagateToCopyChannalsLastModule,
+    SimpleEmptyChannelLastModule,
+    SimpleEmptyContiguoustModule,
     SimpleToCopyChannelsLastModule,
     SimpleToCopyContiguousModule,
 )
@@ -45,6 +47,7 @@ def test_op_to_copy_replacement_2d(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=SimpleToCopyContiguousModule().eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn([3, 4, 5], dtype=torch.float32),),
                 target_memory_format=torch.contiguous_format,
                 _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
@@ -56,17 +59,43 @@ def test_op_to_copy_replacement_4d(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=SimpleToCopyContiguousModule().eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn([3, 4, 5, 6], dtype=torch.float32),),
                 target_memory_format=torch.contiguous_format,
                 _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
             ),
         )
 
+    def test_op_empty_replacement_channels_last(self) -> None:
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=SimpleEmptyChannelLastModule().eval(),
+                op=torch.ops.aten.empty.memory_format,
+                sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),),
+                target_memory_format=torch.channels_last,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+        )
+
+    def test_op_empty_replacement_contiguous(self) -> None:
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=SimpleEmptyContiguoustModule().eval(),
+                op=torch.ops.aten.empty.memory_format,
+                sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),),
+                target_memory_format=torch.contiguous_format,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+        )
+
     def test_op_dim_order_update(self) -> None:
         MemoryFormatOpsPassTestUtils.memory_format_test_runner(
             self,
             MemoryFormatTestSet(
                 module=SimpleToCopyChannelsLastModule().eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(
                     torch.rand_like(
                         torch.zeros([2, 2, 2, 2]),
@@ -84,6 +113,7 @@ def test_op_dim_order_propagation(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=PropagateToCopyChannalsLastModule().eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(
                     torch.rand_like(
                         torch.zeros([2, 2, 2, 2]),
@@ -273,6 +303,7 @@ def test_resnet18(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=model.eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn(1, 3, 224, 224),),
                 target_memory_format=torch.contiguous_format,
                 op_level_check=False,
@@ -288,6 +319,7 @@ def test_resnet18_xnnpack(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=model.eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn(1, 3, 224, 224),),
                 target_memory_format=torch.contiguous_format,
                 op_level_check=False,
@@ -304,6 +336,7 @@ def test_mobilenet_v3(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=model.eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn(1, 3, 224, 224),),
                 target_memory_format=torch.contiguous_format,
                 op_level_check=False,
@@ -319,6 +352,7 @@ def test_mobilenet_v3_xnnpack(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=model.eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn(1, 3, 224, 224),),
                 target_memory_format=torch.contiguous_format,
                 op_level_check=False,
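The two new empty tests drive a module that allocates an empty tensor in the target memory format and copies data into it. In eager mode (outside the export and lowering pipeline the test runner actually exercises) the channels_last case behaves roughly like this:

```python
import torch

x = torch.randn(1, 10, 24, 24)
empty_tensor = torch.empty(x.size(), memory_format=torch.channels_last)
empty_tensor.copy_(x.to(memory_format=torch.channels_last))

# The allocated tensor keeps the requested layout and the copied values.
assert empty_tensor.is_contiguous(memory_format=torch.channels_last)
assert torch.allclose(empty_tensor, x)
```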

exir/tests/test_memory_format_ops_pass_aten.py (32 additions, 0 deletions)

@@ -13,6 +13,8 @@
     MemoryFormatOpsPassTestUtils,
     MemoryFormatTestSet,
     PropagateToCopyChannalsLastModule,
+    SimpleEmptyChannelLastModule,
+    SimpleEmptyContiguoustModule,
     SimpleToCopyChannelsLastModule,
     SimpleToCopyContiguousModule,
 )
@@ -28,6 +30,7 @@ def test_op_to_copy_replacement_2d_aten(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=SimpleToCopyContiguousModule().eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn([3, 4, 5], dtype=torch.float32),),
                 target_memory_format=torch.contiguous_format,
                 _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
@@ -39,17 +42,43 @@ def test_op_to_copy_replacement_4d_aten(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=SimpleToCopyContiguousModule().eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn([3, 4, 5, 6], dtype=torch.float32),),
                 target_memory_format=torch.contiguous_format,
                 _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
             ),
         )
 
+    def test_op_empty_replacement_channels_last(self) -> None:
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=SimpleEmptyChannelLastModule().eval(),
+                op=torch.ops.aten.empty.memory_format,
+                sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),),
+                target_memory_format=torch.channels_last,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+        )
+
+    def test_op_empty_replacement_contiguous(self) -> None:
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=SimpleEmptyContiguoustModule().eval(),
+                op=torch.ops.aten.empty.memory_format,
+                sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),),
+                target_memory_format=torch.contiguous_format,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+        )
+
     def test_op_dim_order_update_aten(self) -> None:
         MemoryFormatOpsPassTestUtils.memory_format_test_runner(
             self,
             MemoryFormatTestSet(
                 module=SimpleToCopyChannelsLastModule().eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(
                     torch.rand_like(
                         torch.zeros([2, 2, 2, 2]),
@@ -67,6 +96,7 @@ def test_op_dim_order_propagation_aten(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=PropagateToCopyChannalsLastModule().eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(
                     torch.rand_like(
                         torch.zeros([2, 2, 2, 2]),
@@ -85,6 +115,7 @@ def test_resnet18(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=model.eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn(1, 3, 224, 224),),
                 target_memory_format=torch.contiguous_format,
                 op_level_check=False,
@@ -100,6 +131,7 @@ def test_mobilenet_v3(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=model.eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn(1, 3, 224, 224),),
                 target_memory_format=torch.contiguous_format,
                 op_level_check=False,

exir/tests/test_memory_format_ops_pass_utils.py (38 additions, 4 deletions)

@@ -8,7 +8,7 @@
 
 import unittest
 from dataclasses import dataclass
-from typing import Any, Tuple
+from typing import Any, Dict, List, Tuple
 
 import torch
 
@@ -26,11 +26,24 @@
 from torch.utils._pytree import tree_flatten
 
 
+MemoryFormatOps2Str: Dict[torch._ops.OpOverload, List[str]] = {
+    torch.ops.aten._to_copy.default: (
+        "torch.ops.aten._to_copy.default",
+        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default",
+    ),
+    torch.ops.aten.empty.memory_format: (
+        "torch.ops.aten.empty.memory_format",
+        "executorch_exir_dialects_edge__ops_dim_order_ops__empty_dim_order_default",
+    ),
+}
+
+
 @dataclass
 class MemoryFormatTestSet:
     module: torch.nn.Module
     sample_input: Tuple[Any, ...]
     target_memory_format: torch.memory_format
+    op: torch._ops.OpOverload
     _load_for_executorch_from_buffer: Any
     op_level_check: bool = True
     use_xnnpack: bool = False
@@ -54,6 +67,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x.to(dtype=torch.double, memory_format=torch.channels_last)
 
 
+class SimpleEmptyContiguoustModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        empty_tensor = torch.empty(x.size(), memory_format=torch.contiguous_format)
+        x = x.to(memory_format=torch.contiguous_format)
+        empty_tensor.copy_(x)
+        return empty_tensor
+
+
+class SimpleEmptyChannelLastModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        empty_tensor = torch.empty(x.size(), memory_format=torch.channels_last)
+        x = x.to(memory_format=torch.channels_last)
+        empty_tensor.copy_(x)
+        return empty_tensor
+
+
 class PropagateToCopyChannalsLastModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -86,9 +121,7 @@ def memory_format_test_runner(
 
         # check memory format ops, if needed
         if test_set.op_level_check:
-            aten_op_str = "torch.ops.aten._to_copy.default"
-            edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
-
+            aten_op_str, edge_op_str = MemoryFormatOps2Str[test_set.op]
            # check op strings before
             FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
                 edge_op_str
@@ -126,6 +159,7 @@ def memory_format_test_runner(
         runtime_output = executorch_module.run_method(
             "forward", tuple(inputs_flattened)
         )[0]
+
         test_class.assertTrue(
             torch.allclose(
                 runtime_output, expected, atol=test_set.atol, rtol=test_set.rtol
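With MemoryFormatOps2Str in place, the op-level check picks the expected aten and edge op strings from the op under test instead of hard-coding the _to_copy pair. A condensed sketch of that check using torch's FileCheck, where the graph text is illustrative rather than an actual exported graph:

```python
import torch
from torch.testing import FileCheck

# Same shape as the table added in this commit, trimmed to the empty op.
MemoryFormatOps2Str = {
    torch.ops.aten.empty.memory_format: (
        "torch.ops.aten.empty.memory_format",
        "executorch_exir_dialects_edge__ops_dim_order_ops__empty_dim_order_default",
    ),
}

aten_op_str, edge_op_str = MemoryFormatOps2Str[torch.ops.aten.empty.memory_format]

# Before edge lowering, the captured graph code should mention only the aten op.
pre_lowering_code = "empty = torch.ops.aten.empty.memory_format([1, 10, 24, 24])"
FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(edge_op_str).run(
    pre_lowering_code
)
```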
