Skip to content

Commit 0cd9c57

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Use sample input for edge ops (#291)
Summary: Pull Request resolved: #291 Use sample inputs to generate dtype constraints for edge ops. This is depending on diff stack D49088856. The purpose of the whole stack is to increase edge ops dtype constraints coverage. This diff uses the sample input generator DtypeRunner to generate all possible dtype arguments and get back valid ones for each operator. Reviewed By: manuelcandales Differential Revision: D49182359 fbshipit-source-id: 7274b56037b53488ad70b28b66a430dcd48754ed
1 parent 632a6e1 commit 0cd9c57

File tree

12 files changed

+6510
-253
lines changed

12 files changed

+6510
-253
lines changed

exir/dialects/edge/_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,15 @@ def __init__(
8787
alias: AllowedDtypeSet(set(types)) for alias, types in type_alias.items()
8888
}
8989
self.type_constraint: List[Dict[str, str]] = type_constraint
90-
# type_constraint's non return entries should be same as all tensor args.
90+
# type_constraint's non return entries should include all tensor-like arguments.
9191
for t_constraint in self.type_constraint:
9292
type_constraint_names = set(t_constraint)
9393
all_tensor_arg_names = set(
9494
self.essential_tensor_io_names + self.optional_tensor_io_names
9595
)
96-
if type_constraint_names != all_tensor_arg_names:
96+
if not all_tensor_arg_names.issubset(type_constraint_names):
9797
raise RuntimeError(
98-
"Each input entry of type_constraint must be tensor-like,"
98+
"Input entries of type_constraint must contain all tensor-like arguments, "
9999
+ f"but get {type_constraint_names} and {all_tensor_arg_names}"
100100
)
101101

exir/dialects/edge/arg/model.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -362,8 +362,3 @@ def to_out(self, *, name: Optional[str] = None) -> OutArg:
362362
value=self.value,
363363
deps=self.deps,
364364
)
365-
366-
367-
def get_callable(name):
368-
main, suffix = name.split(".")
369-
return getattr(getattr(torch.ops.aten, main), suffix)

exir/dialects/edge/dtype/TARGETS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ python_library(
77
srcs = [
88
"runner.py",
99
"supported.py",
10+
"utils.py",
1011
],
1112
deps = [
1213
"//caffe2:torch",
1314
"//executorch/exir/dialects/edge/arg:lib",
15+
"//executorch/exir/dialects/edge/op:lib",
1416
],
1517
)

exir/dialects/edge/dtype/runner.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,10 @@
1010

1111
import torch
1212
import torch.testing._internal.common_dtype as common_dtype
13-
from executorch.exir.dialects.edge.arg.model import (
14-
ArgMode,
15-
BaseArg,
16-
BaseKwarg,
17-
GenMode,
18-
get_callable,
19-
)
13+
from executorch.exir.dialects.edge.arg.model import ArgMode, BaseArg, BaseKwarg, GenMode
2014
from executorch.exir.dialects.edge.arg.type import ArgType
15+
from executorch.exir.dialects.edge.dtype.utils import extract_return_dtype
16+
from executorch.exir.dialects.edge.op.api import get_callable
2117

2218

2319
class DtypeRunner:
@@ -92,12 +88,13 @@ def _get_type_tuples(
9288
types = DtypeRunner._get_types(inputs)
9389

9490
def mapping(t):
91+
type_dtypes = []
9592
if t.is_optional():
96-
return [None]
97-
elif t.is_scalar():
98-
return self.scalar_dtypes
93+
type_dtypes = [None]
94+
if t.is_scalar():
95+
return type_dtypes + self.scalar_dtypes
9996
elif t.is_scalar_type() or t.is_tensor() or t.is_tensor_list():
100-
return self.tensor_dtypes
97+
return type_dtypes + self.tensor_dtypes
10198
else:
10299
raise ValueError("Type {t.name} does not have dtype")
103100

@@ -142,19 +139,29 @@ def run_dtypes(
142139
args, kwargs = DtypeRunner._get_args_kwargs(inputs, dtypes, argmode)
143140
op = get_callable(name)
144141
try:
145-
op(*args, **kwargs)
146-
return (True, name, dtypes, args, kwargs)
142+
res = op(*args, **kwargs)
143+
ret_dtypes = ()
144+
if "returns" in inputs:
145+
ret_dtypes = tuple(extract_return_dtype(res, inputs["returns"]))
146+
return (True, name, dtypes + ret_dtypes, args, kwargs)
147+
except AssertionError as e:
148+
raise RuntimeError(
149+
f"opname: {name}, inputs: {inputs}, dtypes: {dtypes}, argmode {argmode}"
150+
) from e
147151
except Exception as e:
148152
if argmode == ArgMode.ONES:
149153
return (False, name, dtypes, args, kwargs)
150154
ones_args, ones_kwargs = DtypeRunner._get_args_kwargs(
151155
inputs, dtypes, ArgMode.ONES
152156
)
153157
try:
154-
op(*ones_args, **ones_kwargs)
158+
res = op(*args, **kwargs)
159+
ret_dtypes = ()
160+
if "returns" in inputs:
161+
ret_dtypes = tuple(extract_return_dtype(res, inputs["returns"]))
155162
print(e)
156163
print(name, dtypes, args, kwargs)
157-
return (True, name, dtypes, ones_args, ones_kwargs)
164+
return (True, name, dtypes + ret_dtypes, ones_args, ones_kwargs)
158165
except Exception:
159166
return (False, name, dtypes, ones_args, ones_kwargs)
160167

exir/dialects/edge/dtype/utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import collections
8+
from typing import Any, List
9+
10+
import torch
11+
from executorch.exir.dialects.edge.arg.model import BaseArg
12+
13+
from executorch.exir.dialects.edge.arg.type import ArgType
14+
15+
16+
def extract_return_dtype(
17+
returns: Any, sample_returns: List[BaseArg]
18+
) -> List[torch.dtype]:
19+
"""Extract the dtype from a return value."""
20+
if not isinstance(returns, collections.abc.Sequence):
21+
returns = [returns]
22+
result = []
23+
for ret, sample in zip(returns, sample_returns):
24+
if sample.type == ArgType.TensorList or sample.type == ArgType.TensorOptList:
25+
# Assuming all tensors in tensor list has the same dtype, and we only add 1 dtype to result.
26+
assert (
27+
ret is not None
28+
), f"Expecting non-None return value for {sample} but got None"
29+
result.append(ret.dtype)
30+
break
31+
elif sample.type == ArgType.Tensor or sample.type == ArgType.TensorOpt:
32+
assert (
33+
ret is not None
34+
), f"Expecting non-None return value for {sample} but got None"
35+
result.append(ret.dtype)
36+
return result

0 commit comments

Comments
 (0)