Skip to content

Use sample input for edge ops #291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions exir/dialects/edge/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,15 @@ def __init__(
alias: AllowedDtypeSet(set(types)) for alias, types in type_alias.items()
}
self.type_constraint: List[Dict[str, str]] = type_constraint
# type_constraint's non return entries should be same as all tensor args.
# type_constraint's non return entries should include all tensor-like arguments.
for t_constraint in self.type_constraint:
type_constraint_names = set(t_constraint)
all_tensor_arg_names = set(
self.essential_tensor_io_names + self.optional_tensor_io_names
)
if type_constraint_names != all_tensor_arg_names:
if not all_tensor_arg_names.issubset(type_constraint_names):
raise RuntimeError(
"Each input entry of type_constraint must be tensor-like,"
"Input entries of type_constraint must contain all tensor-like arguments, "
+ f"but get {type_constraint_names} and {all_tensor_arg_names}"
)

Expand Down
5 changes: 0 additions & 5 deletions exir/dialects/edge/arg/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,3 @@ def to_out(self, *, name: Optional[str] = None) -> OutArg:
value=self.value,
deps=self.deps,
)


def get_callable(name):
main, suffix = name.split(".")
return getattr(getattr(torch.ops.aten, main), suffix)
2 changes: 2 additions & 0 deletions exir/dialects/edge/dtype/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ python_library(
srcs = [
"runner.py",
"supported.py",
"utils.py",
],
deps = [
"//caffe2:torch",
"//executorch/exir/dialects/edge/arg:lib",
"//executorch/exir/dialects/edge/op:lib",
],
)
37 changes: 22 additions & 15 deletions exir/dialects/edge/dtype/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,10 @@

import torch
import torch.testing._internal.common_dtype as common_dtype
from executorch.exir.dialects.edge.arg.model import (
ArgMode,
BaseArg,
BaseKwarg,
GenMode,
get_callable,
)
from executorch.exir.dialects.edge.arg.model import ArgMode, BaseArg, BaseKwarg, GenMode
from executorch.exir.dialects.edge.arg.type import ArgType
from executorch.exir.dialects.edge.dtype.utils import extract_return_dtype
from executorch.exir.dialects.edge.op.api import get_callable


class DtypeRunner:
Expand Down Expand Up @@ -92,12 +88,13 @@ def _get_type_tuples(
types = DtypeRunner._get_types(inputs)

def mapping(t):
type_dtypes = []
if t.is_optional():
return [None]
elif t.is_scalar():
return self.scalar_dtypes
type_dtypes = [None]
if t.is_scalar():
return type_dtypes + self.scalar_dtypes
elif t.is_scalar_type() or t.is_tensor() or t.is_tensor_list():
return self.tensor_dtypes
return type_dtypes + self.tensor_dtypes
else:
raise ValueError("Type {t.name} does not have dtype")

Expand Down Expand Up @@ -142,19 +139,29 @@ def run_dtypes(
args, kwargs = DtypeRunner._get_args_kwargs(inputs, dtypes, argmode)
op = get_callable(name)
try:
op(*args, **kwargs)
return (True, name, dtypes, args, kwargs)
res = op(*args, **kwargs)
ret_dtypes = ()
if "returns" in inputs:
ret_dtypes = tuple(extract_return_dtype(res, inputs["returns"]))
return (True, name, dtypes + ret_dtypes, args, kwargs)
except AssertionError as e:
raise RuntimeError(
f"opname: {name}, inputs: {inputs}, dtypes: {dtypes}, argmode {argmode}"
) from e
except Exception as e:
if argmode == ArgMode.ONES:
return (False, name, dtypes, args, kwargs)
ones_args, ones_kwargs = DtypeRunner._get_args_kwargs(
inputs, dtypes, ArgMode.ONES
)
try:
op(*ones_args, **ones_kwargs)
res = op(*args, **kwargs)
ret_dtypes = ()
if "returns" in inputs:
ret_dtypes = tuple(extract_return_dtype(res, inputs["returns"]))
print(e)
print(name, dtypes, args, kwargs)
return (True, name, dtypes, ones_args, ones_kwargs)
return (True, name, dtypes + ret_dtypes, ones_args, ones_kwargs)
except Exception:
return (False, name, dtypes, ones_args, ones_kwargs)

Expand Down
36 changes: 36 additions & 0 deletions exir/dialects/edge/dtype/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import collections
from typing import Any, List

import torch
from executorch.exir.dialects.edge.arg.model import BaseArg

from executorch.exir.dialects.edge.arg.type import ArgType


def extract_return_dtype(
returns: Any, sample_returns: List[BaseArg]
) -> List[torch.dtype]:
"""Extract the dtype from a return value."""
if not isinstance(returns, collections.abc.Sequence):
returns = [returns]
result = []
for ret, sample in zip(returns, sample_returns):
if sample.type == ArgType.TensorList or sample.type == ArgType.TensorOptList:
# Assuming all tensors in tensor list has the same dtype, and we only add 1 dtype to result.
assert (
ret is not None
), f"Expecting non-None return value for {sample} but got None"
result.append(ret.dtype)
break
elif sample.type == ArgType.Tensor or sample.type == ArgType.TensorOpt:
assert (
ret is not None
), f"Expecting non-None return value for {sample} but got None"
result.append(ret.dtype)
return result
Loading