Skip to content

Commit 8cf4d1f

Browse files
JacobSzwejbka authored and facebook-github-bot committed
Add support for SchemaKind.mutable to to_out_variant (#2541)
Summary: Pull Request resolved: #2541. Ideally custom mutator ops would not be in aten OpOverload but would instead be edge ops. The pass that results in mutator ops being in the graph is shared with core for now though, and it's hard to mix and match EP passes and GM passes. So just add the logic to the otherwise unused aten to_out_variant code. Reviewed By: larryliu0820 Differential Revision: D55156605 fbshipit-source-id: 3003a51b83bb211345c1d6c6e5bb7e97c19dfeb1
1 parent 7bff771 commit 8cf4d1f

File tree

4 files changed

+114
-50
lines changed

4 files changed

+114
-50
lines changed

exir/operator/TARGETS

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,12 @@ python_library(
3636
)
3737

3838
python_unittest(
39-
name = "test_util",
39+
name = "test_operator",
4040
srcs = [
41-
"test/test_util.py",
41+
"test/test_operator.py",
4242
],
4343
deps = [
44+
":convert",
4445
":util",
4546
"//caffe2:torch",
4647
],

exir/operator/convert.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
# existing case with cache miss.
4242
_func_to_out_variant_map: Dict[OpOverload, Optional[OpOverload]] = {}
4343
_out_variant_to_scratch_map: Dict[OpOverload, Optional[OpOverload]] = {}
44+
_mutable_to_out_variant_map: Dict[OpOverload, Optional[OpOverload]] = {}
4445

4546
# We've found a functional and an out variant with the same name, but their
4647
# schemas mismatch. This map collects all of these cases and provides proper
@@ -135,9 +136,9 @@ def schema_to_opoverload(schema: FunctionSchema) -> OpOverload:
135136

136137
def set_mapping_for_op(op: OpOverload) -> None:
137138
"""
138-
op can either be a functional op or out variant op.
139+
op can either be a functional op, mutable op, or out variant op.
139140
This method is only called if
140-
1. either op is a functiona op and it's missing in the _func_to_out_variant_map cache.
141+
1. either op is a functional op and it's missing in the _func_to_out_variant_map cache.
141142
2. or op is a out variant op and it's missing in the _out_variant_to_scratch_map cache.
142143
143144
Setup entries in _func_to_out_variant_map and _out_variant_to_scratch_map for all ops sharing the same
@@ -148,13 +149,17 @@ def set_mapping_for_op(op: OpOverload) -> None:
148149
assert native_schema.kind() in (
149150
SchemaKind.functional,
150151
SchemaKind.out,
152+
SchemaKind.mutable,
151153
)
152154
assert not (
153155
native_schema.kind() == SchemaKind.functional and op in _func_to_out_variant_map
154156
)
155157
assert not (
156158
native_schema.kind() == SchemaKind.out and op in _out_variant_to_scratch_map
157159
)
160+
assert not (
161+
native_schema.kind() == SchemaKind.mutable and op in _mutable_to_out_variant_map
162+
)
158163
qualified_opname = str(op._schema.name)
159164

160165
all_schemas = [
@@ -192,6 +197,7 @@ def set_mapping_for_op(op: OpOverload) -> None:
192197
for group_by_kind in group_by_signature.values():
193198
func_op_schema = group_by_kind.get(SchemaKind.functional)
194199
out_var_schema = group_by_kind.get(SchemaKind.out)
200+
mutable_op_schema = group_by_kind.get(SchemaKind.mutable)
195201
scratch_schema = group_by_kind.get(SchemaKind.scratch)
196202

197203
# update the map even if out_var_schema is None to cache the negative
@@ -216,30 +222,41 @@ def set_mapping_for_op(op: OpOverload) -> None:
216222
_out_variant_to_scratch_map[schema_to_opoverload(out_var_schema)] = (
217223
schema_to_opoverload(scratch_schema) if scratch_schema else None
218224
)
225+
if mutable_op_schema:
226+
_mutable_to_out_variant_map[schema_to_opoverload(mutable_op_schema)] = (
227+
schema_to_opoverload(out_var_schema) if out_var_schema else None
228+
)
219229

220230

221231
def to_out_variant(op_overload: OpOverload) -> Tuple[OpOverload, Tuple[str]]:
222232
r"""
223233
Convert the passed in OpOverload to its out variant. Raise an exception if
224234
on return the op_overload is not guaranteed to be an out variant.
225235
226-
If a conversion is found, return the out variant OpOverlaod alongwith the name of out
236+
If a conversion is found, return the out variant OpOverload alongwith the name of out
227237
arguments.
228238
"""
229239
schema = _get_overload_schema(op_overload)
230240
if schema.is_out_fn(): # pyre-ignore
231-
return op_overload, get_out_args_from_schema(schema) # pyre-ignore
241+
return op_overload, get_out_args_from_schema(schema) # pyre-ignore[6]
232242

233-
# should be a functional op here
243+
# should be a functionalish op here
234244
assert (
235-
schema.kind() == SchemaKind.functional # pyre-ignore
236-
), f"Expect an functional op, but get {schema}"
237-
238-
if op_overload not in _func_to_out_variant_map:
245+
schema.kind() == SchemaKind.functional # pyre-ignore[16]
246+
or schema.kind() == SchemaKind.mutable
247+
), f"Expect a functionalish op, but get {schema.kind()} {schema}"
248+
249+
if (
250+
op_overload not in _func_to_out_variant_map
251+
and op_overload not in _mutable_to_out_variant_map
252+
):
239253
# setup out_var
240254
set_mapping_for_op(op_overload)
241255

242-
out_var = _func_to_out_variant_map.get(op_overload)
256+
if op_overload in _mutable_to_out_variant_map:
257+
out_var = _mutable_to_out_variant_map[op_overload]
258+
else:
259+
out_var = _func_to_out_variant_map.get(op_overload)
243260

244261
if not out_var:
245262
msg = f"Missing out variant for functional op: {schema} . Make sure you have loaded your custom operator library for compiler. E.g., custom_ops_generated_lib"

exir/operator/test/test_operator.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import unittest
10+
11+
import torch
12+
from executorch.exir.operator.convert import _get_overload_schema, to_out_variant
13+
from executorch.exir.operator.util import gen_out_variant_schema
14+
from torch.library import _scoped_library, impl, impl_abstract
15+
16+
17+
class TestOperator(unittest.TestCase):
18+
def setUp(self) -> None:
19+
super().setUp()
20+
21+
def test_gen_out_variant_schema_from_functional(self) -> None:
22+
func_schema = str(torch.ops.aten.mul.Scalar._schema)
23+
24+
out_schema = gen_out_variant_schema(func_schema)
25+
self.assertEqual(out_schema, str(torch.ops.aten.mul.Scalar_out._schema))
26+
27+
def test_gen_out_variant_schema_from_inplace(self) -> None:
28+
func_schema = str(torch.ops.aten.add_.Scalar._schema)
29+
30+
out_schema = gen_out_variant_schema(func_schema)
31+
self.assertEqual(out_schema, str(torch.ops.aten.add.Scalar_out._schema))
32+
33+
def test_gen_out_variant_schema_for_custom_ops(self) -> None:
34+
func_schema = "custom::foo(Tensor a, Tensor b) -> (Tensor c, Tensor d)"
35+
36+
out_schema = gen_out_variant_schema(func_schema)
37+
self.assertEqual(
38+
out_schema,
39+
"custom::foo.out(Tensor a, Tensor b, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))",
40+
)
41+
42+
def test_to_out_variant_mutable(self) -> None:
43+
44+
with _scoped_library("DO_NOT_USE_TEST_ONLY", "DEF") as lib:
45+
46+
lib.define("custom_mutator(Tensor x, Tensor(a!) y) -> Tensor")
47+
lib.define(
48+
"custom_mutator.out(Tensor x, Tensor(a!) y, *, Tensor(b!) out) -> Tensor(b!)"
49+
)
50+
51+
@impl(lib, "custom_mutator", "Meta")
52+
def custom_mutator_meta(
53+
x: torch.Tensor,
54+
y: torch.Tensor,
55+
) -> torch.Tensor:
56+
return torch.empty_like(x)
57+
58+
@impl(lib, "custom_mutator", "CompositeExplicitAutograd")
59+
def custom_mutator(
60+
x: torch.Tensor,
61+
y: torch.Tensor,
62+
) -> torch.Tensor:
63+
return x + y.add_(1)
64+
65+
@impl_abstract("DO_NOT_USE_TEST_ONLY::custom_mutator.out")
66+
def custom_mutator_out(
67+
x: torch.Tensor,
68+
y: torch.Tensor,
69+
out: torch.Tensor,
70+
) -> torch.Tensor:
71+
out = custom_mutator_meta(
72+
x,
73+
y,
74+
)
75+
return out
76+
77+
out, _ = to_out_variant(
78+
torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator.default
79+
)
80+
schema = _get_overload_schema(out)
81+
self.assertEqual(
82+
schema.__str__(),
83+
"DO_NOT_USE_TEST_ONLY::custom_mutator.out(Tensor x, Tensor(a!) y, *, Tensor(b!) out) -> Tensor(b!)",
84+
)

exir/operator/test/test_util.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

0 commit comments

Comments (0)