Add support for SchemaKind.mutable to to_out_variant #2541

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
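
In torchgen terms, the change lets to_out_variant accept overloads whose FunctionSchema.kind() is SchemaKind.mutable — ops that mutate one of their arguments without being the plain in-place (trailing-underscore) variant — and pairs each of them with the matching .out overload through a new _mutable_to_out_variant_map cache. Below is a minimal sketch of that classification, not part of the PR, reusing the schema shapes from the new test further down; it only needs torchgen, which convert.py already relies on for schema parsing:

from torchgen.model import FunctionSchema, SchemaKind

# Same schema shapes as the DO_NOT_USE_TEST_ONLY::custom_mutator pair registered in
# the new test (namespace omitted for brevity).
mutable = FunctionSchema.parse("custom_mutator(Tensor x, Tensor(a!) y) -> Tensor")
out_var = FunctionSchema.parse(
    "custom_mutator.out(Tensor x, Tensor(a!) y, *, Tensor(b!) out) -> Tensor(b!)"
)

print(mutable.kind())  # SchemaKind.mutable: mutates y, but is not a plain in-place variant
print(out_var.kind())  # SchemaKind.out: has an explicit out argument
# The two fall into the same signature group, which is how set_mapping_for_op pairs
# them up and fills the new _mutable_to_out_variant_map.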
5 changes: 3 additions & 2 deletions exir/operator/TARGETS
@@ -36,11 +36,12 @@ python_library(
 )
 
 python_unittest(
-    name = "test_util",
+    name = "test_operator",
     srcs = [
-        "test/test_util.py",
+        "test/test_operator.py",
     ],
     deps = [
         ":convert",
+        ":util",
         "//caffe2:torch",
     ],
37 changes: 27 additions & 10 deletions exir/operator/convert.py
@@ -41,6 +41,7 @@
 # existing case with cache miss.
 _func_to_out_variant_map: Dict[OpOverload, Optional[OpOverload]] = {}
 _out_variant_to_scratch_map: Dict[OpOverload, Optional[OpOverload]] = {}
+_mutable_to_out_variant_map: Dict[OpOverload, Optional[OpOverload]] = {}
 
 # We've found a functional and an out variant with the same name, but their
 # schemas mismatch. This map collects all of these cases and provides proper
@@ -135,9 +136,9 @@ def schema_to_opoverload(schema: FunctionSchema) -> OpOverload:

 def set_mapping_for_op(op: OpOverload) -> None:
     """
-    op can either be a functional op or out variant op.
+    op can either be a functional op, mutable op, or out variant op.
     This method is only called if
-    1. either op is a functiona op and it's missing in the _func_to_out_variant_map cache.
+    1. either op is a functional op and it's missing in the _func_to_out_variant_map cache.
     2. or op is a out variant op and it's missing in the _out_variant_to_scratch_map cache.
 
     Setup entries in _func_to_out_variant_map and _out_variant_to_scratch_map for all ops sharing the same
@@ -148,13 +149,17 @@ def set_mapping_for_op(op: OpOverload) -> None:
     assert native_schema.kind() in (
         SchemaKind.functional,
         SchemaKind.out,
+        SchemaKind.mutable,
     )
     assert not (
         native_schema.kind() == SchemaKind.functional and op in _func_to_out_variant_map
     )
     assert not (
         native_schema.kind() == SchemaKind.out and op in _out_variant_to_scratch_map
     )
+    assert not (
+        native_schema.kind() == SchemaKind.mutable and op in _mutable_to_out_variant_map
+    )
     qualified_opname = str(op._schema.name)
 
     all_schemas = [
@@ -192,6 +197,7 @@ def set_mapping_for_op(op: OpOverload) -> None:
     for group_by_kind in group_by_signature.values():
         func_op_schema = group_by_kind.get(SchemaKind.functional)
         out_var_schema = group_by_kind.get(SchemaKind.out)
+        mutable_op_schema = group_by_kind.get(SchemaKind.mutable)
         scratch_schema = group_by_kind.get(SchemaKind.scratch)
 
         # update the map even if out_var_schema is None to cache the negative
@@ -216,30 +222,41 @@ def set_mapping_for_op(op: OpOverload) -> None:
             _out_variant_to_scratch_map[schema_to_opoverload(out_var_schema)] = (
                 schema_to_opoverload(scratch_schema) if scratch_schema else None
             )
+        if mutable_op_schema:
+            _mutable_to_out_variant_map[schema_to_opoverload(mutable_op_schema)] = (
+                schema_to_opoverload(out_var_schema) if out_var_schema else None
+            )
 
 
 def to_out_variant(op_overload: OpOverload) -> Tuple[OpOverload, Tuple[str]]:
     r"""
     Convert the passed in OpOverload to its out variant. Raise an exception if
     on return the op_overload is not guaranteed to be an out variant.
 
-    If a conversion is found, return the out variant OpOverlaod alongwith the name of out
+    If a conversion is found, return the out variant OpOverload alongwith the name of out
     arguments.
     """
     schema = _get_overload_schema(op_overload)
     if schema.is_out_fn():  # pyre-ignore
-        return op_overload, get_out_args_from_schema(schema)  # pyre-ignore
+        return op_overload, get_out_args_from_schema(schema)  # pyre-ignore[6]
 
-    # should be a functional op here
+    # should be a functionalish op here
     assert (
-        schema.kind() == SchemaKind.functional  # pyre-ignore
-    ), f"Expect an functional op, but get {schema}"
-
-    if op_overload not in _func_to_out_variant_map:
+        schema.kind() == SchemaKind.functional  # pyre-ignore[16]
+        or schema.kind() == SchemaKind.mutable
+    ), f"Expect a functionalish op, but get {schema.kind()} {schema}"
+
+    if (
+        op_overload not in _func_to_out_variant_map
+        and op_overload not in _mutable_to_out_variant_map
+    ):
         # setup out_var
         set_mapping_for_op(op_overload)
 
-    out_var = _func_to_out_variant_map.get(op_overload)
+    if op_overload in _mutable_to_out_variant_map:
+        out_var = _mutable_to_out_variant_map[op_overload]
+    else:
+        out_var = _func_to_out_variant_map.get(op_overload)
 
     if not out_var:
         msg = f"Missing out variant for functional op: {schema} . Make sure you have loaded your custom operator library for compiler. E.g., custom_ops_generated_lib"
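The net effect on to_out_variant: a mutable overload no longer trips the functional-only assert and is instead resolved through _mutable_to_out_variant_map, which set_mapping_for_op populates. The following rough usage sketch is not part of the PR; it mirrors the new test below, and assumes executorch and torch.library are importable in your environment:

import torch
from torch.library import Library
from executorch.exir.operator.convert import to_out_variant

# Register the same op pair the new test uses: the base op mutates y but still
# returns a fresh tensor, so its schema kind is mutable rather than functional.
lib = Library("DO_NOT_USE_TEST_ONLY", "DEF")
lib.define("custom_mutator(Tensor x, Tensor(a!) y) -> Tensor")
lib.define("custom_mutator.out(Tensor x, Tensor(a!) y, *, Tensor(b!) out) -> Tensor(b!)")

# Before this change the call below failed the SchemaKind.functional assert; now the
# mutable overload is mapped to its .out counterpart, returned together with the
# names of the out arguments.
out_op, out_args = to_out_variant(torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator.default)
print(out_op)    # expected: the DO_NOT_USE_TEST_ONLY::custom_mutator.out overload
print(out_args)  # expected: ("out",)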
84 changes: 84 additions & 0 deletions exir/operator/test/test_operator.py
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch
from executorch.exir.operator.convert import _get_overload_schema, to_out_variant
from executorch.exir.operator.util import gen_out_variant_schema
from torch.library import _scoped_library, impl, impl_abstract


class TestOperator(unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()

    def test_gen_out_variant_schema_from_functional(self) -> None:
        func_schema = str(torch.ops.aten.mul.Scalar._schema)

        out_schema = gen_out_variant_schema(func_schema)
        self.assertEqual(out_schema, str(torch.ops.aten.mul.Scalar_out._schema))

    def test_gen_out_variant_schema_from_inplace(self) -> None:
        func_schema = str(torch.ops.aten.add_.Scalar._schema)

        out_schema = gen_out_variant_schema(func_schema)
        self.assertEqual(out_schema, str(torch.ops.aten.add.Scalar_out._schema))

    def test_gen_out_variant_schema_for_custom_ops(self) -> None:
        func_schema = "custom::foo(Tensor a, Tensor b) -> (Tensor c, Tensor d)"

        out_schema = gen_out_variant_schema(func_schema)
        self.assertEqual(
            out_schema,
            "custom::foo.out(Tensor a, Tensor b, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))",
        )

    def test_to_out_variant_mutable(self) -> None:

        with _scoped_library("DO_NOT_USE_TEST_ONLY", "DEF") as lib:

            lib.define("custom_mutator(Tensor x, Tensor(a!) y) -> Tensor")
            lib.define(
                "custom_mutator.out(Tensor x, Tensor(a!) y, *, Tensor(b!) out) -> Tensor(b!)"
            )

            @impl(lib, "custom_mutator", "Meta")
            def custom_mutator_meta(
                x: torch.Tensor,
                y: torch.Tensor,
            ) -> torch.Tensor:
                return torch.empty_like(x)

            @impl(lib, "custom_mutator", "CompositeExplicitAutograd")
            def custom_mutator(
                x: torch.Tensor,
                y: torch.Tensor,
            ) -> torch.Tensor:
                return x + y.add_(1)

            @impl_abstract("DO_NOT_USE_TEST_ONLY::custom_mutator.out")
            def custom_mutator_out(
                x: torch.Tensor,
                y: torch.Tensor,
                out: torch.Tensor,
            ) -> torch.Tensor:
                out = custom_mutator_meta(
                    x,
                    y,
                )
                return out

            out, _ = to_out_variant(
                torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator.default
            )
            schema = _get_overload_schema(out)
            self.assertEqual(
                schema.__str__(),
                "DO_NOT_USE_TEST_ONLY::custom_mutator.out(Tensor x, Tensor(a!) y, *, Tensor(b!) out) -> Tensor(b!)",
            )
38 changes: 0 additions & 38 deletions exir/operator/test/test_util.py

This file was deleted.