Skip to content

Commit 156263f

Browse files
committed
consolidating the two validators and removing assertion check from evaluator
1 parent 13a6f94 commit 156263f

File tree

1 file changed

+3
-17
lines changed

1 file changed

+3
-17
lines changed

py/torch_tensorrt/dynamo/conversion/ops_evaluators.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def rand_validator(rand_node: Node) -> bool:
6060
if layout is not None:
6161
_LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
6262
return False
63+
return True
6364

6465

6566
@dynamo_tensorrt_converter(
@@ -76,21 +77,8 @@ def aten_ops_rand(
7677
return np.random.rand(*args)
7778

7879

79-
def randn_validator(randn_node: Node) -> bool:
80-
dtype = randn_node.kwargs.get("dtype", None)
81-
layout = randn_node.kwargs.get("layout", None)
82-
if dtype is not None:
83-
_LOGGER.debug(
84-
f"Currently we don't support specifying output dtype, got {dtype}."
85-
)
86-
return False
87-
if layout is not None:
88-
_LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
89-
return False
90-
91-
9280
@dynamo_tensorrt_converter(
93-
torch.ops.aten.randn.default, capability_validator=randn_validator
81+
torch.ops.aten.randn.default, capability_validator=rand_validator
9482
)
9583
def aten_ops_randn(
9684
ctx: ConversionContext,
@@ -118,6 +106,7 @@ def randperm_validator(randperm_node: Node) -> bool:
118106
if layout is not None:
119107
_LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
120108
return False
109+
return True
121110

122111

123112
@dynamo_tensorrt_converter(
@@ -131,7 +120,4 @@ def aten_ops_randperm(
131120
name: str,
132121
) -> Union[TRTTensor, Sequence[TRTTensor]]:
133122
device = kwargs.get("device", None)
134-
input = args[0]
135-
if not isinstance(input, int):
136-
raise RuntimeError(f"The input must be an integer")
137123
return np.random.permutation(*args)

0 commit comments

Comments
 (0)