Skip to content

Remove shape testing code from dialects edge #607

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion exir/dialects/edge/arg/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ oncall("ai_infra_mobile_platform")
python_library(
name = "lib",
srcs = [
"constraints.py",
"model.py",
"type.py",
],
Expand Down
12 changes: 0 additions & 12 deletions exir/dialects/edge/arg/constraints.py

This file was deleted.

53 changes: 0 additions & 53 deletions exir/dialects/edge/arg/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ def __init__(
nonzero=False,
nonneg=False,
bounded=False,
deps=None,
constraints=None,
):
self.type: ArgType = argtype

Expand All @@ -79,8 +77,6 @@ def __init__(
self.nonzero = nonzero
self.nonneg = nonneg
self.bounded = bounded
self.deps = deps
self.constraints = constraints

self._mode: ArgMode = ArgMode.DEFAULT
self._kw: bool = False
Expand Down Expand Up @@ -244,60 +240,12 @@ def get_val_with_dtype(self, dtype):
else:
raise ValueError(f"Unsupported Type: {self.type}")

def get_val_with_shape(self, shape):
if shape is None:
return None

def helper(s):
return torch.full(tuple(s), self.fill, dtype=self.dtype)

if self.type.is_tensor():
return helper(shape)
elif self.type.is_tensor_list():
return [helper(s) for s in shape]
else:
raise ValueError(f"Unsupported value with shape for type: {self.type}")

def get_val(self):
if self.type.has_dtype():
return self.get_val_with_dtype(self.dtype)
else:
return self.value

def get_shape(self):
if self.type.is_tensor():
return self.size
elif self.type.is_tensor_list():
if not self.value_given:
return []
return [s.size for s in self.value]
else:
raise ValueError(f"Unsupported get shape for type: {self.type}")

def get_constraints(self):
if self.type.is_dim():
constraints = {
"val_min": lambda deps: -deps[0].dim() if deps[0].dim() > 0 else -1,
"val_max": lambda deps: deps[0].dim() - 1 if deps[0].dim() > 0 else 0,
}
if self.type.is_dim_list():
constraints = {
"len_max": lambda deps: deps[0].dim(),
"val_min": lambda deps: -deps[0].dim() if deps[0].dim() > 0 else -1,
"val_max": lambda deps: deps[0].dim() - 1 if deps[0].dim() > 0 else 0,
"no_dups": True,
}
if self.type.is_index():
constraints = {
"val_min": lambda deps: -deps[0].size(deps[1]),
"val_max": lambda deps: deps[0].size(deps[1]) - 1,
}
if self.type.is_memory_format():
constraints = {"values": [None]}
if self.constraints is not None:
constraints.update(self.constraints)
return constraints


class BaseKwarg(BaseArg):
def __init__(self, argtype, argname, **kwargs):
Expand Down Expand Up @@ -360,5 +308,4 @@ def to_out(self, *, name: Optional[str] = None) -> OutArg:
size=self.size,
dtype=self.dtype,
value=self.value,
deps=self.deps,
)
91 changes: 0 additions & 91 deletions exir/dialects/edge/arg/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,13 @@
class ArgType(str, Enum):
Tensor = "Tensor"
TensorOpt = "Tensor?"

TensorList = "Tensor[]"
TensorOptList = "Tensor?[]"

Scalar = "Scalar"
ScalarOpt = "Scalar?"

ScalarType = "ScalarType"
ScalarTypeOpt = "ScalarType?"

Dim = "Dim"
DimOpt = "Dim?"
DimList = "Dim[]"
DimListOpt = "Dim[]?"

Shape = "Shape"
Stride = "Stride"
Index = "Index"
IndexOpt = "Index?"
Length = "Length"
LengthList = "Length[]"

Param = "Param"
Float = "Float"
FloatOpt = "Float?"
MemoryFormat = "MemoryFormat"

Bool = "Bool"
Keepdim = "Keepdim"

def is_tensor(self):
return self in [ArgType.Tensor, ArgType.TensorOpt]
Expand All @@ -52,62 +30,11 @@ def is_scalar(self):
def is_scalar_type(self):
return self in [ArgType.ScalarType, ArgType.ScalarTypeOpt]

def is_dim(self):
return self in [ArgType.Dim, ArgType.DimOpt]

def is_dim_list(self):
return self in [ArgType.DimList, ArgType.DimListOpt]

def is_shape(self):
return self in [ArgType.Shape]

def is_stride(self):
return self in [ArgType.Stride]

def is_index(self):
return self in [ArgType.Index, ArgType.IndexOpt]

def is_length(self):
return self in [ArgType.Length]

def is_length_list(self):
return self in [ArgType.LengthList]

def is_keepdim(self):
return self in [ArgType.Keepdim]

def is_param(self):
return self in [
ArgType.Param,
ArgType.Float,
ArgType.FloatOpt,
ArgType.MemoryFormat,
]

def is_bool(self):
return self in [ArgType.Bool, ArgType.Keepdim]

def is_float(self):
return self in [ArgType.Float, ArgType.FloatOpt]

def is_optional(self):
return self in [
ArgType.TensorOpt,
ArgType.ScalarOpt,
ArgType.ScalarTypeOpt,
ArgType.DimOpt,
ArgType.DimListOpt,
ArgType.FloatOpt,
ArgType.IndexOpt,
]

def is_list(self):
return self in [
ArgType.TensorList,
ArgType.TensorOptList,
ArgType.DimList,
ArgType.DimListOpt,
ArgType.LengthList,
]

def has_dtype(self):
Expand All @@ -117,21 +44,3 @@ def has_dtype(self):
or self.is_scalar()
or self.is_scalar_type()
)

def has_shape(self):
return self.is_tensor() or self.is_tensor_list()

def has_length(self):
return self.is_list()

def is_shape_relevant(self):
return (
self.is_dim()
or self.is_dim_list()
or self.is_shape()
or self.is_stride()
or self.is_index()
or self.is_length()
or self.is_length_list()
or self.is_keepdim()
)
6 changes: 3 additions & 3 deletions exir/dialects/edge/dtype/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import itertools
import random
from typing import Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple

import torch
import torch.testing._internal.common_dtype as common_dtype
Expand All @@ -28,7 +28,7 @@ def _get_types(inputs: Dict[str, List[BaseArg]]) -> List[ArgType]:

@staticmethod
def _get_args_kwargs(
inputs: Dict[str, List[Union[BaseArg]]],
inputs: Dict[str, List[BaseArg]],
dtypes: Tuple[Optional[torch.dtype]],
mode: ArgMode,
) -> Tuple[List[BaseArg], Dict[str, BaseKwarg]]:
Expand Down Expand Up @@ -174,7 +174,7 @@ def run_dtypes(
def run(
self,
name: str,
inputs: Dict[str, List[BaseArg]],
inputs: Dict[str, Any],
argmode: ArgMode = ArgMode.ONES,
) -> List[
Tuple[
Expand Down
Loading