Skip to content

Commit e0a7525

Browse files
authored
fix: Move aten.neg test case (#2310)
1 parent 044d4d6 commit e0a7525

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,11 @@ def neg(
393393
name: str,
394394
input_val: TRTTensor,
395395
) -> TRTTensor:
396+
if (isinstance(input_val, TRTTensor)) and (
397+
input_val.dtype == trt.int8 or input_val.dtype == trt.int32
398+
):
399+
input_val = cast_trt_tensor(network, input_val, trt.float32, name)
400+
396401
return convert_unary(
397402
network, target, source_ir, name, trt.UnaryOperation.NEG, input_val
398403
)

tests/py/dynamo/converters/test_neg_aten.py renamed to tests/py/dynamo/conversion/test_neg_aten.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
55
from torch_tensorrt import Input
6-
from torch_tensorrt.dynamo.test_utils import DispatchTestCase
6+
7+
from .harness import DispatchTestCase
78

89

910
class TestNegConverter(DispatchTestCase):
@@ -43,8 +44,8 @@ def forward(self, input):
4344
self.run_test(
4445
neg(),
4546
inputs,
46-
output_dtypes=[torch.int32],
4747
expected_ops={torch.ops.aten.neg.default},
48+
check_dtype=False,
4849
)
4950

5051

0 commit comments

Comments
 (0)