@@ -60,6 +60,7 @@ def rand_validator(rand_node: Node) -> bool:
60
60
if layout is not None :
61
61
_LOGGER .debug (f"Currently we don't support specifying layout, got { layout } ." )
62
62
return False
63
+ return True
63
64
64
65
65
66
@dynamo_tensorrt_converter (
@@ -76,21 +77,8 @@ def aten_ops_rand(
76
77
return np .random .rand (* args )
77
78
78
79
79
- def randn_validator (randn_node : Node ) -> bool :
80
- dtype = randn_node .kwargs .get ("dtype" , None )
81
- layout = randn_node .kwargs .get ("layout" , None )
82
- if dtype is not None :
83
- _LOGGER .debug (
84
- f"Currently we don't support specifying output dtype, got { dtype } ."
85
- )
86
- return False
87
- if layout is not None :
88
- _LOGGER .debug (f"Currently we don't support specifying layout, got { layout } ." )
89
- return False
90
-
91
-
92
80
@dynamo_tensorrt_converter (
93
- torch .ops .aten .randn .default , capability_validator = randn_validator
81
+ torch .ops .aten .randn .default , capability_validator = rand_validator
94
82
)
95
83
def aten_ops_randn (
96
84
ctx : ConversionContext ,
@@ -118,6 +106,7 @@ def randperm_validator(randperm_node: Node) -> bool:
118
106
if layout is not None :
119
107
_LOGGER .debug (f"Currently we don't support specifying layout, got { layout } ." )
120
108
return False
109
+ return True
121
110
122
111
123
112
@dynamo_tensorrt_converter (
@@ -131,7 +120,4 @@ def aten_ops_randperm(
131
120
name : str ,
132
121
) -> Union [TRTTensor , Sequence [TRTTensor ]]:
133
122
device = kwargs .get ("device" , None )
134
- input = args [0 ]
135
- if not isinstance (input , int ):
136
- raise RuntimeError (f"The input must be an integer" )
137
123
return np .random .permutation (* args )
0 commit comments