
Commit 8bb7de1

manuelcandales authored and facebook-github-bot committed
Remove shape testing code from dialects edge (#607)
Summary: The purpose of this diff stack is to clean up the OpInput code that lives inside exir/dialects/edge, leaving only what is needed to search allowed dtypes per op. Code that is only used for test generation is removed from exir/dialects/edge.

Reviewed By: larryliu0820

Differential Revision: D49891113
1 parent 17fee78 commit 8bb7de1
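
For orientation, here is a minimal sketch of how the retained dtype-search path might be driven after this cleanup. Only the run(name, inputs, argmode) signature is taken from the runner.py hunk further down; the DtypeRunner class name, the placeholder op name, and the shape of the inputs dict are assumptions made for illustration, not part of this commit.

# Hedged sketch of exercising the retained dtype-search entry point.
# Assumptions (not shown in this diff): the runner class is named DtypeRunner,
# and `inputs` is built the same way the per-op dtype tests build it.
from typing import Any, Dict, List, Tuple

from exir.dialects.edge.dtype.runner import DtypeRunner  # class name assumed


def report_allowed_dtypes(op_name: str, inputs: Dict[str, Any]) -> None:
    runner = DtypeRunner()
    # run() is typed run(name: str, inputs: Dict[str, Any], argmode: ArgMode = ArgMode.ONES)
    # in the hunk below; each result describes one dtype combination that was tried.
    results: List[Tuple] = runner.run(op_name, inputs)
    for result in results:
        print(result)
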

File tree

7 files changed: +107 -357 lines changed


exir/dialects/edge/arg/TARGETS

Lines changed: 0 additions & 1 deletion
@@ -5,7 +5,6 @@ oncall("ai_infra_mobile_platform")
 python_library(
     name = "lib",
     srcs = [
-        "constraints.py",
         "model.py",
         "type.py",
     ],

exir/dialects/edge/arg/constraints.py

Lines changed: 0 additions & 12 deletions
This file was deleted.

exir/dialects/edge/arg/model.py

Lines changed: 0 additions & 53 deletions
@@ -61,8 +61,6 @@ def __init__(
         nonzero=False,
         nonneg=False,
         bounded=False,
-        deps=None,
-        constraints=None,
     ):
         self.type: ArgType = argtype
 
@@ -79,8 +77,6 @@ def __init__(
         self.nonzero = nonzero
         self.nonneg = nonneg
         self.bounded = bounded
-        self.deps = deps
-        self.constraints = constraints
 
         self._mode: ArgMode = ArgMode.DEFAULT
         self._kw: bool = False
@@ -244,60 +240,12 @@ def get_val_with_dtype(self, dtype):
         else:
             raise ValueError(f"Unsupported Type: {self.type}")
 
-    def get_val_with_shape(self, shape):
-        if shape is None:
-            return None
-
-        def helper(s):
-            return torch.full(tuple(s), self.fill, dtype=self.dtype)
-
-        if self.type.is_tensor():
-            return helper(shape)
-        elif self.type.is_tensor_list():
-            return [helper(s) for s in shape]
-        else:
-            raise ValueError(f"Unsupported value with shape for type: {self.type}")
-
     def get_val(self):
         if self.type.has_dtype():
             return self.get_val_with_dtype(self.dtype)
         else:
             return self.value
 
-    def get_shape(self):
-        if self.type.is_tensor():
-            return self.size
-        elif self.type.is_tensor_list():
-            if not self.value_given:
-                return []
-            return [s.size for s in self.value]
-        else:
-            raise ValueError(f"Unsupported get shape for type: {self.type}")
-
-    def get_constraints(self):
-        if self.type.is_dim():
-            constraints = {
-                "val_min": lambda deps: -deps[0].dim() if deps[0].dim() > 0 else -1,
-                "val_max": lambda deps: deps[0].dim() - 1 if deps[0].dim() > 0 else 0,
-            }
-        if self.type.is_dim_list():
-            constraints = {
-                "len_max": lambda deps: deps[0].dim(),
-                "val_min": lambda deps: -deps[0].dim() if deps[0].dim() > 0 else -1,
-                "val_max": lambda deps: deps[0].dim() - 1 if deps[0].dim() > 0 else 0,
-                "no_dups": True,
-            }
-        if self.type.is_index():
-            constraints = {
-                "val_min": lambda deps: -deps[0].size(deps[1]),
-                "val_max": lambda deps: deps[0].size(deps[1]) - 1,
-            }
-        if self.type.is_memory_format():
-            constraints = {"values": [None]}
-        if self.constraints is not None:
-            constraints.update(self.constraints)
-        return constraints
-
 
 class BaseKwarg(BaseArg):
     def __init__(self, argtype, argname, **kwargs):
@@ -360,5 +308,4 @@ def to_out(self, *, name: Optional[str] = None) -> OutArg:
             size=self.size,
             dtype=self.dtype,
             value=self.value,
-            deps=self.deps,
         )

exir/dialects/edge/arg/type.py

Lines changed: 0 additions & 91 deletions
@@ -10,35 +10,13 @@
 class ArgType(str, Enum):
     Tensor = "Tensor"
     TensorOpt = "Tensor?"
-
     TensorList = "Tensor[]"
     TensorOptList = "Tensor?[]"
-
     Scalar = "Scalar"
     ScalarOpt = "Scalar?"
-
     ScalarType = "ScalarType"
     ScalarTypeOpt = "ScalarType?"
-
-    Dim = "Dim"
-    DimOpt = "Dim?"
-    DimList = "Dim[]"
-    DimListOpt = "Dim[]?"
-
-    Shape = "Shape"
-    Stride = "Stride"
-    Index = "Index"
-    IndexOpt = "Index?"
-    Length = "Length"
-    LengthList = "Length[]"
-
     Param = "Param"
-    Float = "Float"
-    FloatOpt = "Float?"
-    MemoryFormat = "MemoryFormat"
-
-    Bool = "Bool"
-    Keepdim = "Keepdim"
 
     def is_tensor(self):
         return self in [ArgType.Tensor, ArgType.TensorOpt]
@@ -52,62 +30,11 @@ def is_scalar(self):
     def is_scalar_type(self):
         return self in [ArgType.ScalarType, ArgType.ScalarTypeOpt]
 
-    def is_dim(self):
-        return self in [ArgType.Dim, ArgType.DimOpt]
-
-    def is_dim_list(self):
-        return self in [ArgType.DimList, ArgType.DimListOpt]
-
-    def is_shape(self):
-        return self in [ArgType.Shape]
-
-    def is_stride(self):
-        return self in [ArgType.Stride]
-
-    def is_index(self):
-        return self in [ArgType.Index, ArgType.IndexOpt]
-
-    def is_length(self):
-        return self in [ArgType.Length]
-
-    def is_length_list(self):
-        return self in [ArgType.LengthList]
-
-    def is_keepdim(self):
-        return self in [ArgType.Keepdim]
-
-    def is_param(self):
-        return self in [
-            ArgType.Param,
-            ArgType.Float,
-            ArgType.FloatOpt,
-            ArgType.MemoryFormat,
-        ]
-
-    def is_bool(self):
-        return self in [ArgType.Bool, ArgType.Keepdim]
-
-    def is_float(self):
-        return self in [ArgType.Float, ArgType.FloatOpt]
-
     def is_optional(self):
         return self in [
             ArgType.TensorOpt,
             ArgType.ScalarOpt,
             ArgType.ScalarTypeOpt,
-            ArgType.DimOpt,
-            ArgType.DimListOpt,
-            ArgType.FloatOpt,
-            ArgType.IndexOpt,
-        ]
-
-    def is_list(self):
-        return self in [
-            ArgType.TensorList,
-            ArgType.TensorOptList,
-            ArgType.DimList,
-            ArgType.DimListOpt,
-            ArgType.LengthList,
         ]
 
     def has_dtype(self):
@@ -117,21 +44,3 @@ def has_dtype(self):
             or self.is_scalar()
             or self.is_scalar_type()
         )
-
-    def has_shape(self):
-        return self.is_tensor() or self.is_tensor_list()
-
-    def has_length(self):
-        return self.is_list()
-
-    def is_shape_relevant(self):
-        return (
-            self.is_dim()
-            or self.is_dim_list()
-            or self.is_shape()
-            or self.is_stride()
-            or self.is_index()
-            or self.is_length()
-            or self.is_length_list()
-            or self.is_keepdim()
-        )
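
As a quick illustration, the dtype-relevant checks kept by this file still behave as the surviving context lines show. A small hedged sanity sketch, based only on the lines the hunks above keep (the import path comes from the file header):

# Sanity sketch grounded in the kept lines of exir/dialects/edge/arg/type.py.
from exir.dialects.edge.arg.type import ArgType

assert ArgType.Tensor.is_tensor()              # is_tensor still covers Tensor and Tensor?
assert ArgType.ScalarTypeOpt.is_scalar_type()  # is_scalar_type still covers ScalarType and ScalarType?
assert ArgType.TensorOpt.is_optional()         # is_optional now lists only the three *Opt members kept above
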

exir/dialects/edge/dtype/runner.py

Lines changed: 3 additions & 3 deletions
@@ -6,7 +6,7 @@
 
 import itertools
 import random
-from typing import Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple
 
 import torch
 import torch.testing._internal.common_dtype as common_dtype
@@ -28,7 +28,7 @@ def _get_types(inputs: Dict[str, List[BaseArg]]) -> List[ArgType]:
 
     @staticmethod
     def _get_args_kwargs(
-        inputs: Dict[str, List[Union[BaseArg]]],
+        inputs: Dict[str, List[BaseArg]],
         dtypes: Tuple[Optional[torch.dtype]],
         mode: ArgMode,
     ) -> Tuple[List[BaseArg], Dict[str, BaseKwarg]]:
@@ -174,7 +174,7 @@ def run_dtypes(
     def run(
         self,
         name: str,
-        inputs: Dict[str, List[BaseArg]],
+        inputs: Dict[str, Any],
        argmode: ArgMode = ArgMode.ONES,
     ) -> List[
         Tuple[
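
One note on the annotation change in _get_args_kwargs: typing.Union with a single member collapses to that member, so Dict[str, List[Union[BaseArg]]] was already equivalent to Dict[str, List[BaseArg]]; the edit drops the redundant wrapper. A quick check of that Python behavior (not code from this commit):

from typing import Union

# Union[X] with a single argument is simply X.
assert Union[int] is int
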
