Skip to content

Commit 22dfacf

Browse files
Gasoonjiafacebook-github-bot
authored andcommitted
make arg verifier support optional tensor and tensor list
Summary: update edge dialect argment verifier to support operators with tensorlist and optional tensor type input and ouput Reviewed By: larryliu0820 Differential Revision: D47811606 fbshipit-source-id: 6225885ced59f588c46bf134250610d44f84df91
1 parent e2909b9 commit 22dfacf

File tree

4 files changed

+105
-21
lines changed

4 files changed

+105
-21
lines changed

exir/dialects/edge/TARGETS

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ python_library(
1010
"yaml_generator.py",
1111
],
1212
deps = [
13-
"fbsource//third-party/pypi/expecttest:expecttest",
13+
"fbsource//third-party/pypi/expecttest:expecttest", # @manual
1414
"fbsource//third-party/pypi/ruamel-yaml:ruamel-yaml",
1515
":support_dtypes",
1616
":utils",
@@ -26,6 +26,7 @@ python_binary(
2626
],
2727
main_module = "executorch.exir.dialects.edge.yaml_generator",
2828
deps = [
29+
"fbsource//third-party/pypi/expecttest:expecttest", # @manual
2930
"fbsource//third-party/pypi/ruamel-yaml:ruamel-yaml",
3031
":support_dtypes",
3132
":utils",

exir/dialects/edge/_ops.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def __init__(
104104
+ f"but get {type_constraint_names} and {all_tensor_arg_names}"
105105
)
106106

107-
def validate(self, types: Dict[str, torch.dtype]) -> bool:
107+
def validate(self, types: Dict[str, Optional[torch.dtype]]) -> bool:
108108
"""Check if the given input type combination a legal one of current function.
109109
110110
Args:
@@ -135,7 +135,11 @@ def validate(self, types: Dict[str, torch.dtype]) -> bool:
135135
valid_type = True
136136
# Narrow down the type_alias based on contraint and actual input
137137
for arg_name, arg_type in types.items():
138-
if arg_type in self.type_alias[constraint[arg_name]]:
138+
if arg_type is None:
139+
# None means the user didn't set dtype for this argment
140+
# (i.e. empty tensorlist), skipping the validation.
141+
continue
142+
elif arg_type in self.type_alias[constraint[arg_name]]:
139143
self.type_alias[constraint[arg_name]].reduce_to(arg_type)
140144
else:
141145
valid_type = False

exir/tests/test_verification.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ def __init__(self):
151151
self.register_buffer("a", torch.randn(1, 3, 100, 100))
152152

153153
def forward(self, x):
154-
return self.a + x
154+
b = self.a + x
155+
return torch._to_cpu([b, x])
155156

156157
m = TestModel()
157158
egm = (
@@ -167,14 +168,60 @@ def forward(self, x):
167168
verifier(egm)
168169
self.assertTrue(verifier.is_valid(egm))
169170

171+
def test_edge_happy_with_optional_tensor_input(self) -> None:
172+
class TestModel(torch.nn.Module):
173+
def __init__(self):
174+
super().__init__()
175+
176+
def forward(self, x, weight, bias):
177+
# weight and bias here are optional tensor inputs.
178+
return torch.group_norm(x, 4, weight, bias)
179+
180+
m = TestModel()
181+
egm = (
182+
exir.capture(
183+
m,
184+
(torch.rand(16, 8, 32, 32), torch.rand(8), torch.rand(8)),
185+
exir.CaptureConfig(pt2_mode=True),
186+
)
187+
.to_edge()
188+
.exported_program.graph_module
189+
)
190+
verifier = EXIREdgeDialectVerifier()
191+
verifier(egm)
192+
self.assertTrue(verifier.is_valid(egm))
193+
194+
def test_edge_happy_with_empty_tensorlist_input(self) -> None:
195+
class TestModel(torch.nn.Module):
196+
def __init__(self):
197+
super().__init__()
198+
199+
def forward(self, x):
200+
return torch._to_cpu(x)
201+
202+
m = TestModel()
203+
egm = (
204+
exir.capture(
205+
m,
206+
([],),
207+
exir.CaptureConfig(pt2_mode=True),
208+
)
209+
.to_edge()
210+
.exported_program.graph_module
211+
)
212+
verifier = EXIREdgeDialectVerifier()
213+
verifier(egm)
214+
self.assertTrue(verifier.is_valid(egm))
215+
170216
def test_edge_sad(self) -> None:
171217
class TestModel(torch.nn.Module):
172218
def __init__(self):
173219
super().__init__()
174220
self.register_buffer("a", torch.randn(1, 3, 100, 100))
175221

176222
def forward(self, x):
177-
return self.a + x
223+
b = self.a + x
224+
return torch._to_cpu([b, x])
178225

179226
m = TestModel()
180227
egm = exir.capture(

exir/verification/arg_validator.py

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from collections import defaultdict
8-
from typing import Any, Dict, Sequence, Tuple
8+
from typing import Any, Dict, Optional, Sequence, Tuple
99

1010
import torch
1111
from executorch.exir.dialects.edge._ops import EdgeDialectFunctionSchema, EdgeOpOverload
@@ -37,9 +37,9 @@ class EdgeOpArgValidator(torch.fx.Interpreter):
3737

3838
def __init__(self, graph_module: torch.fx.GraphModule) -> None:
3939
super().__init__(graph_module)
40-
self.violating_ops: Dict[EdgeOpOverload, Dict[str, torch.dtype]] = defaultdict(
41-
dict
42-
)
40+
self.violating_ops: Dict[
41+
EdgeOpOverload, Dict[str, Optional[torch.dtype]]
42+
] = defaultdict(dict)
4343

4444
def run_node(self, n: torch.fx.Node) -> None:
4545
self.node = n
@@ -52,6 +52,16 @@ def run_node(self, n: torch.fx.Node) -> None:
5252
raise InternalError(str(e)) from e
5353
return ret
5454

55+
def _get_kernel_arg(self, schema_arg, schema_arg_idx, args, kwargs):
56+
if schema_arg.name in kwargs:
57+
kernel_arg = kwargs[schema_arg.name]
58+
elif not schema_arg.kwarg_only and schema_arg_idx < len(args):
59+
kernel_arg = args[schema_arg_idx]
60+
else:
61+
kernel_arg = schema_arg.default_value
62+
63+
return kernel_arg
64+
5565
def call_function(
5666
self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
5767
) -> Any:
@@ -64,19 +74,32 @@ def call_function(
6474
if isinstance(target, HigherOrderOperator):
6575
raise RunHigherOrderOperatorError("Can't run delegate")
6676
return super().call_function(target, args, kwargs)
67-
tensor_arg_types: Dict[str, torch.dtype] = {}
77+
78+
# TODO(gasoonjia): Update Optional[torch.dtype] to a concrete class to support mixed dtypes in tensorlist.
79+
tensor_arg_types: Dict[str, Optional[torch.dtype]] = {}
6880
for i, schema_arg in enumerate(target._schema.arguments):
69-
if not isinstance(schema_arg.type, torch.TensorType):
70-
continue
71-
if schema_arg.name in kwargs:
72-
kernel_arg = kwargs[schema_arg.name]
73-
elif not schema_arg.kwarg_only and i < len(args):
74-
kernel_arg = args[i]
75-
else:
76-
kernel_arg = schema_arg.default_value
77-
if not isinstance(kernel_arg, torch.Tensor):
78-
continue
79-
tensor_arg_types[schema_arg.name] = kernel_arg.dtype
81+
if (
82+
isinstance(schema_arg.type, torch.TensorType)
83+
or schema_arg.type == torch.OptionalType.ofTensor()
84+
):
85+
kernel_arg = self._get_kernel_arg(schema_arg, i, args, kwargs)
86+
if not isinstance(kernel_arg, torch.Tensor):
87+
continue
88+
tensor_arg_types[schema_arg.name] = kernel_arg.dtype
89+
elif schema_arg.type == torch.ListType.ofTensors():
90+
kernel_arg = self._get_kernel_arg(schema_arg, i, args, kwargs)
91+
if not isinstance(kernel_arg, list) or not all(
92+
isinstance(kernel_arg[i], torch.Tensor)
93+
for i in range(len(kernel_arg))
94+
):
95+
continue
96+
if len(kernel_arg):
97+
tensor_arg_types[schema_arg.name] = kernel_arg[0].dtype
98+
else:
99+
# If kernel_arg is an empty list, treat its type as None.
100+
# FunctionDtypeConstraint.validate will take None as any legal dtype.
101+
tensor_arg_types[schema_arg.name] = None
102+
80103
ret_index = 0
81104
kernel_rets = self.node.meta["val"]
82105
ret_iter = iter(
@@ -85,11 +108,20 @@ def call_function(
85108
for schema_ret in target._schema.returns:
86109
name = schema_ret.name if schema_ret.name else f"__ret_{ret_index}"
87110
kernel_ret = next(ret_iter)
111+
# Return value should not be in OptionalTensor type, so only check torch.TensorType here.
88112
if isinstance(schema_ret.type, torch.TensorType) and isinstance(
89113
kernel_ret, torch.Tensor
90114
):
91115
tensor_arg_types[name] = kernel_ret.dtype
92116
ret_index += 1
117+
elif schema_ret.type == torch.ListType.ofTensors() and all(
118+
isinstance(kernel_ret[i], torch.Tensor) for i in range(len(kernel_ret))
119+
):
120+
if len(kernel_ret):
121+
tensor_arg_types[name] = kernel_ret[0].dtype
122+
else:
123+
tensor_arg_types[name] = None
124+
ret_index += 1
93125

94126
valid = target._schema.dtype_constraint.validate(tensor_arg_types)
95127
if not valid:

0 commit comments

Comments
 (0)