@@ -62,7 +62,9 @@ def rand_validator(rand_node: Node) -> bool:
62
62
return False
63
63
64
64
65
- @dynamo_tensorrt_converter (torch .ops .aten .rand .default )
65
+ @dynamo_tensorrt_converter (
66
+ torch .ops .aten .rand .default , capability_validator = rand_validator
67
+ )
66
68
def aten_ops_rand (
67
69
ctx : ConversionContext ,
68
70
target : Target ,
@@ -71,7 +73,7 @@ def aten_ops_rand(
71
73
name : str ,
72
74
) -> Union [TRTTensor , Sequence [TRTTensor ]]:
73
75
device = kwargs .get ("device" , None )
74
- return np .random .rand (* args ). to ( device = device )
76
+ return np .random .rand (* args )
75
77
76
78
77
79
def randn_validator (randn_node : Node ) -> bool :
@@ -87,7 +89,9 @@ def randn_validator(randn_node: Node) -> bool:
87
89
return False
88
90
89
91
90
- @dynamo_tensorrt_converter (torch .ops .aten .randn .default )
92
+ @dynamo_tensorrt_converter (
93
+ torch .ops .aten .randn .default , capability_validator = randn_validator
94
+ )
91
95
def aten_ops_randn (
92
96
ctx : ConversionContext ,
93
97
target : Target ,
@@ -96,7 +100,7 @@ def aten_ops_randn(
96
100
name : str ,
97
101
) -> Union [TRTTensor , Sequence [TRTTensor ]]:
98
102
device = kwargs .get ("device" , None )
99
- return np .random .randn (* args ). to ( device = device )
103
+ return np .random .randn (* args )
100
104
101
105
102
106
def randperm_validator (randperm_node : Node ) -> bool :
@@ -112,7 +116,9 @@ def randperm_validator(randperm_node: Node) -> bool:
112
116
return False
113
117
114
118
115
- @dynamo_tensorrt_converter (torch .ops .aten .randperm .default )
119
+ @dynamo_tensorrt_converter (
120
+ torch .ops .aten .randperm .default , capability_validator = randperm_validator
121
+ )
116
122
def aten_ops_randperm (
117
123
ctx : ConversionContext ,
118
124
target : Target ,
@@ -124,4 +130,4 @@ def aten_ops_randperm(
124
130
input = args [0 ]
125
131
if not isinstance (input , int ):
126
132
raise RuntimeError (f"The input must be an integer" )
127
- return np .random .randperm (* args ). to ( device = device )
133
+ return np .random .permutation (* args )
0 commit comments