Skip to content

Commit ac1e7db

Browse files
committed
update type enforcement
1 parent 03f0ebb commit ac1e7db

File tree

2 files changed

+43
-13
lines changed

2 files changed

+43
-13
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1703,6 +1703,11 @@ def aten_ops_logical_xor(
17031703

17041704
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor) # type: ignore[misc]
17051705
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar) # type: ignore[misc]
1706+
@enforce_tensor_types(
1707+
{
1708+
0: (TRTTensor,),
1709+
}
1710+
) # type: ignore[misc]
17061711
def aten_ops_eq(
17071712
ctx: ConversionContext,
17081713
target: Target,
@@ -1722,6 +1727,11 @@ def aten_ops_eq(
17221727

17231728
@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor) # type: ignore[misc]
17241729
@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar) # type: ignore[misc]
1730+
@enforce_tensor_types(
1731+
{
1732+
0: (TRTTensor,),
1733+
}
1734+
) # type: ignore[misc]
17251735
def aten_ops_ne(
17261736
ctx: ConversionContext,
17271737
target: Target,
@@ -1741,6 +1751,11 @@ def aten_ops_ne(
17411751

17421752
@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor) # type: ignore[misc]
17431753
@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar) # type: ignore[misc]
1754+
@enforce_tensor_types(
1755+
{
1756+
0: (TRTTensor,),
1757+
}
1758+
) # type: ignore[misc]
17441759
def aten_ops_gt(
17451760
ctx: ConversionContext,
17461761
target: Target,
@@ -1760,6 +1775,11 @@ def aten_ops_gt(
17601775

17611776
@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor) # type: ignore[misc]
17621777
@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar) # type: ignore[misc]
1778+
@enforce_tensor_types(
1779+
{
1780+
0: (TRTTensor,),
1781+
}
1782+
) # type: ignore[misc]
17631783
def aten_ops_ge(
17641784
ctx: ConversionContext,
17651785
target: Target,
@@ -1779,6 +1799,11 @@ def aten_ops_ge(
17791799

17801800
@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor) # type: ignore[misc]
17811801
@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar) # type: ignore[misc]
1802+
@enforce_tensor_types(
1803+
{
1804+
0: (TRTTensor,),
1805+
}
1806+
) # type: ignore[misc]
17821807
def aten_ops_lt(
17831808
ctx: ConversionContext,
17841809
target: Target,
@@ -1798,6 +1823,11 @@ def aten_ops_lt(
17981823

17991824
@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor) # type: ignore[misc]
18001825
@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar) # type: ignore[misc]
1826+
@enforce_tensor_types(
1827+
{
1828+
0: (TRTTensor,),
1829+
}
1830+
) # type: ignore[misc]
18011831
def aten_ops_le(
18021832
ctx: ConversionContext,
18031833
target: Target,

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Union
1+
from typing import Optional, Sequence, Union
22

33
import numpy as np
44
import tensorrt as trt
@@ -428,8 +428,8 @@ def eq(
428428
target: Target,
429429
source_ir: Optional[SourceIR],
430430
name: str,
431-
lhs_val: Union[TRTTensor, int, float],
432-
rhs_val: Union[TRTTensor, int, float],
431+
lhs_val: TRTTensor,
432+
rhs_val: Union[TRTTensor, int, float, bool, Sequence[Union[int, float, bool]]],
433433
) -> TRTTensor:
434434
return convert_binary_elementwise(
435435
ctx,
@@ -447,8 +447,8 @@ def ne(
447447
target: Target,
448448
source_ir: Optional[SourceIR],
449449
name: str,
450-
lhs_val: Union[TRTTensor, int, float],
451-
rhs_val: Union[TRTTensor, int, float],
450+
lhs_val: TRTTensor,
451+
rhs_val: Union[TRTTensor, int, float, bool, Sequence[Union[int, float, bool]]],
452452
) -> TRTTensor:
453453
return impl.unary.logical_not(
454454
ctx,
@@ -464,8 +464,8 @@ def gt(
464464
target: Target,
465465
source_ir: Optional[SourceIR],
466466
name: str,
467-
lhs_val: Union[TRTTensor, int, float],
468-
rhs_val: Union[TRTTensor, int, float],
467+
lhs_val: TRTTensor,
468+
rhs_val: Union[TRTTensor, int, float, Sequence[Union[int, float]]],
469469
) -> TRTTensor:
470470
return convert_binary_elementwise(
471471
ctx,
@@ -483,8 +483,8 @@ def ge(
483483
target: Target,
484484
source_ir: Optional[SourceIR],
485485
name: str,
486-
lhs_val: Union[TRTTensor, int, float],
487-
rhs_val: Union[TRTTensor, int, float],
486+
lhs_val: TRTTensor,
487+
rhs_val: Union[TRTTensor, int, float, Sequence[Union[int, float]]],
488488
) -> TRTTensor:
489489
return logical_or(
490490
ctx,
@@ -501,8 +501,8 @@ def lt(
501501
target: Target,
502502
source_ir: Optional[SourceIR],
503503
name: str,
504-
lhs_val: Union[TRTTensor, int, float],
505-
rhs_val: Union[TRTTensor, int, float],
504+
lhs_val: TRTTensor,
505+
rhs_val: Union[TRTTensor, int, float, Sequence[Union[int, float]]],
506506
) -> TRTTensor:
507507
return convert_binary_elementwise(
508508
ctx,
@@ -520,8 +520,8 @@ def le(
520520
target: Target,
521521
source_ir: Optional[SourceIR],
522522
name: str,
523-
lhs_val: Union[TRTTensor, int, float],
524-
rhs_val: Union[TRTTensor, int, float],
523+
lhs_val: TRTTensor,
524+
rhs_val: Union[TRTTensor, int, float, Sequence[Union[int, float]]],
525525
) -> TRTTensor:
526526
return logical_or(
527527
ctx,

0 commit comments

Comments
 (0)