@@ -47,3 +47,83 @@ def aten_ops_arange_start_step(
47
47
name : str ,
48
48
) -> Union [TRTTensor , Sequence [TRTTensor ]]:
49
49
return np .arange (* args )
50
def rand_validator(rand_node: Node) -> bool:
    """Check whether an ``aten.rand`` node is supported by the converter.

    The converter cannot honor an explicit output ``dtype`` or ``layout``,
    so nodes specifying either kwarg are rejected.

    Args:
        rand_node: FX graph node for the ``aten.rand`` call.

    Returns:
        True if the node can be converted, False otherwise.
    """
    # NOTE(review): this validator is not passed to the converter
    # registration below — presumably intended as its capability
    # validator; confirm the registration wiring.
    dtype = rand_node.kwargs.get("dtype", None)
    layout = rand_node.kwargs.get("layout", None)
    if dtype is not None:
        _LOGGER.debug(
            f"Currently we don't support specifying output dtype, got {dtype}."
        )
        return False
    if layout is not None:
        _LOGGER.debug(
            f"Currently we don't support specifying layout, got {layout}."
        )
        return False
    # Bug fix: the function previously fell through and implicitly returned
    # None (falsy), so every valid node was rejected. Supported nodes must
    # report True.
    return True
@dynamo_tensorrt_converter(torch.ops.aten.rand.default)
def aten_ops_rand(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.rand.default`` by sampling uniform values with NumPy.

    Args:
        ctx: Conversion context (unused here; required by the converter API).
        target: The fx target being converted.
        args: Positional arguments of the node; ``args[0]`` is the output
            shape (a sequence of ints) per the ``aten::rand(SymInt[] size, ...)``
            schema.
        kwargs: Keyword arguments of the node.
        name: Name for the generated layer (unused here).

    Returns:
        A NumPy array of uniform samples in [0, 1) with the requested shape.
    """
    # Read for interface parity but intentionally unused: NumPy arrays have
    # no ``.to`` method (that is the torch.Tensor API), so the previous
    # ``.to(device=device)`` call raised AttributeError at runtime.
    # TODO confirm how device placement should be handled downstream.
    device = kwargs.get("device", None)
    # aten.rand.default passes the output shape as a single sequence in
    # args[0]; np.random.rand takes the dimensions as separate positional
    # ints, so unpack the shape rather than the args tuple.
    shape = args[0]
    return np.random.rand(*shape)
def randn_validator(randn_node: Node) -> bool:
    """Check whether an ``aten.randn`` node is supported by the converter.

    The converter cannot honor an explicit output ``dtype`` or ``layout``,
    so nodes specifying either kwarg are rejected.

    Args:
        randn_node: FX graph node for the ``aten.randn`` call.

    Returns:
        True if the node can be converted, False otherwise.
    """
    # NOTE(review): this validator is not passed to the converter
    # registration below — presumably intended as its capability
    # validator; confirm the registration wiring.
    dtype = randn_node.kwargs.get("dtype", None)
    layout = randn_node.kwargs.get("layout", None)
    if dtype is not None:
        _LOGGER.debug(
            f"Currently we don't support specifying output dtype, got {dtype}."
        )
        return False
    if layout is not None:
        _LOGGER.debug(
            f"Currently we don't support specifying layout, got {layout}."
        )
        return False
    # Bug fix: previously fell through and returned None (falsy), rejecting
    # every valid node.
    return True
@dynamo_tensorrt_converter(torch.ops.aten.randn.default)
def aten_ops_randn(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.randn.default`` by sampling normal values with NumPy.

    Args:
        ctx: Conversion context (unused here; required by the converter API).
        target: The fx target being converted.
        args: Positional arguments of the node; ``args[0]`` is the output
            shape (a sequence of ints) per the ``aten::randn(SymInt[] size, ...)``
            schema.
        kwargs: Keyword arguments of the node.
        name: Name for the generated layer (unused here).

    Returns:
        A NumPy array of standard-normal samples with the requested shape.
    """
    # Read for interface parity but intentionally unused: NumPy arrays have
    # no ``.to`` method (that is the torch.Tensor API), so the previous
    # ``.to(device=device)`` call raised AttributeError at runtime.
    # TODO confirm how device placement should be handled downstream.
    device = kwargs.get("device", None)
    # aten.randn.default passes the output shape as a single sequence in
    # args[0]; np.random.randn takes the dimensions as separate positional
    # ints, so unpack the shape rather than the args tuple.
    shape = args[0]
    return np.random.randn(*shape)
def randperm_validator(randperm_node: Node) -> bool:
    """Check whether an ``aten.randperm`` node is supported by the converter.

    The converter cannot honor an explicit output ``dtype`` or ``layout``,
    so nodes specifying either kwarg are rejected.

    Args:
        randperm_node: FX graph node for the ``aten.randperm`` call.

    Returns:
        True if the node can be converted, False otherwise.
    """
    # NOTE(review): this validator is not passed to the converter
    # registration below — presumably intended as its capability
    # validator; confirm the registration wiring.
    dtype = randperm_node.kwargs.get("dtype", None)
    layout = randperm_node.kwargs.get("layout", None)
    if dtype is not None:
        _LOGGER.debug(
            f"Currently we don't support specifying output dtype, got {dtype}."
        )
        return False
    if layout is not None:
        _LOGGER.debug(
            f"Currently we don't support specifying layout, got {layout}."
        )
        return False
    # Bug fix: previously fell through and returned None (falsy), rejecting
    # every valid node.
    return True
@dynamo_tensorrt_converter(torch.ops.aten.randperm.default)
def aten_ops_randperm(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.randperm.default`` by generating the permutation with NumPy.

    Args:
        ctx: Conversion context (unused here; required by the converter API).
        target: The fx target being converted.
        args: Positional arguments of the node; ``args[0]`` is the
            permutation length ``n`` (an int).
        kwargs: Keyword arguments of the node.
        name: Name for the generated layer (unused here).

    Returns:
        A NumPy array containing a random permutation of ``0..n-1``.

    Raises:
        RuntimeError: If ``args[0]`` is not an integer.
    """
    # Read for interface parity but intentionally unused: NumPy arrays have
    # no ``.to`` method (that is the torch.Tensor API), so the previous
    # ``.to(device=device)`` call raised AttributeError at runtime.
    # TODO confirm how device placement should be handled downstream.
    device = kwargs.get("device", None)
    n = args[0]
    if not isinstance(n, int):
        # The original message was an f-string with no placeholder; include
        # the offending value to make the failure actionable.
        raise RuntimeError(f"The input must be an integer, got {n!r}")
    # Bug fix: ``np.random.randperm`` does not exist (``randperm`` is the
    # torch API); NumPy's equivalent is ``np.random.permutation(n)``.
    return np.random.permutation(n)