Skip to content

Commit 68ebf6d

Browse files
committed
add initial support for torch.ops.aten.neg.default converter
1 parent aa1c843 commit 68ebf6d

File tree

3 files changed

+58
-0
lines changed

3 files changed

+58
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,24 @@ def aten_ops_rsqrt(
251251
)
252252

253253

254+
@dynamo_tensorrt_converter(torch.ops.aten.neg.default)
255+
def aten_ops_neg(
256+
network: TRTNetwork,
257+
target: Target,
258+
args: Tuple[Argument, ...],
259+
kwargs: Dict[str, Argument],
260+
name: str,
261+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
262+
263+
return impl.unary.neg(
264+
network,
265+
target,
266+
SourceIR.ATEN,
267+
name,
268+
args[0],
269+
)
270+
271+
254272
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim)
255273
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims)
256274
def aten_ops_squeeze(

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,19 @@ def sign(
9696
double_floor_div_output,
9797
1,
9898
)
99+
100+
def neg(
101+
network: TRTNetwork,
102+
target: Target,
103+
source_ir: Optional[SourceIR],
104+
name: str,
105+
input_val: TRTTensor,
106+
) -> TRTTensor:
107+
return convert_unary(
108+
network,
109+
target,
110+
source_ir,
111+
name,
112+
trt.UnaryOperation.EXP,
113+
input_val
114+
)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt.dynamo.test_utils import DispatchTestCase
6+
from torch_tensorrt import Input
7+
8+
9+
class TestNegConverter(DispatchTestCase):
10+
def test_neg(self):
11+
class neg(nn.Module):
12+
def forward(self, input):
13+
return torch.neg(input)
14+
15+
inputs = [torch.randn(1, 10)]
16+
self.run_test(
17+
neg(),
18+
inputs,
19+
expected_ops={torch.ops.aten.neg.default},
20+
)
21+
22+
23+
if __name__ == "__main__":
24+
run_tests()

0 commit comments

Comments
 (0)