File tree Expand file tree Collapse file tree 3 files changed +58
-0
lines changed
py/torch_tensorrt/dynamo/conversion
tests/py/dynamo/converters Expand file tree Collapse file tree 3 files changed +58
-0
lines changed Original file line number Diff line number Diff line change @@ -251,6 +251,24 @@ def aten_ops_rsqrt(
251
251
)
252
252
253
253
254
+ @dynamo_tensorrt_converter (torch .ops .aten .neg .default )
255
+ def aten_ops_neg (
256
+ network : TRTNetwork ,
257
+ target : Target ,
258
+ args : Tuple [Argument , ...],
259
+ kwargs : Dict [str , Argument ],
260
+ name : str ,
261
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
262
+
263
+ return impl .unary .neg (
264
+ network ,
265
+ target ,
266
+ SourceIR .ATEN ,
267
+ name ,
268
+ args [0 ],
269
+ )
270
+
271
+
254
272
@dynamo_tensorrt_converter (torch .ops .aten .squeeze .dim )
255
273
@dynamo_tensorrt_converter (torch .ops .aten .squeeze .dims )
256
274
def aten_ops_squeeze (
Original file line number Diff line number Diff line change @@ -96,3 +96,19 @@ def sign(
96
96
double_floor_div_output ,
97
97
1 ,
98
98
)
99
+
100
+ def neg (
101
+ network : TRTNetwork ,
102
+ target : Target ,
103
+ source_ir : Optional [SourceIR ],
104
+ name : str ,
105
+ input_val : TRTTensor ,
106
+ ) -> TRTTensor :
107
+ return convert_unary (
108
+ network ,
109
+ target ,
110
+ source_ir ,
111
+ name ,
112
+ trt .UnaryOperation .EXP ,
113
+ input_val
114
+ )
Original file line number Diff line number Diff line change
1
+ import torch
2
+ import torch .nn as nn
3
+ from parameterized import parameterized
4
+ from torch .testing ._internal .common_utils import run_tests
5
+ from torch_tensorrt .dynamo .test_utils import DispatchTestCase
6
+ from torch_tensorrt import Input
7
+
8
+
9
+ class TestNegConverter (DispatchTestCase ):
10
+ def test_neg (self ):
11
+ class neg (nn .Module ):
12
+ def forward (self , input ):
13
+ return torch .neg (input )
14
+
15
+ inputs = [torch .randn (1 , 10 )]
16
+ self .run_test (
17
+ neg (),
18
+ inputs ,
19
+ expected_ops = {torch .ops .aten .neg .default },
20
+ )
21
+
22
+
23
+ if __name__ == "__main__" :
24
+ run_tests ()
You can’t perform that action at this time.
0 commit comments