Skip to content

Commit 52b89ed

Browse files
committed
update type enforcement
1 parent eb480ee commit 52b89ed

File tree

2 files changed

+43
-13
lines changed

2 files changed

+43
-13
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1649,6 +1649,11 @@ def aten_ops_logical_xor(
16491649

16501650
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor) # type: ignore[misc]
16511651
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar) # type: ignore[misc]
1652+
@enforce_tensor_types(
1653+
{
1654+
0: (TRTTensor,),
1655+
}
1656+
) # type: ignore[misc]
16521657
def aten_ops_eq(
16531658
ctx: ConversionContext,
16541659
target: Target,
@@ -1668,6 +1673,11 @@ def aten_ops_eq(
16681673

16691674
@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor) # type: ignore[misc]
16701675
@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar) # type: ignore[misc]
1676+
@enforce_tensor_types(
1677+
{
1678+
0: (TRTTensor,),
1679+
}
1680+
) # type: ignore[misc]
16711681
def aten_ops_ne(
16721682
ctx: ConversionContext,
16731683
target: Target,
@@ -1687,6 +1697,11 @@ def aten_ops_ne(
16871697

16881698
@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor) # type: ignore[misc]
16891699
@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar) # type: ignore[misc]
1700+
@enforce_tensor_types(
1701+
{
1702+
0: (TRTTensor,),
1703+
}
1704+
) # type: ignore[misc]
16901705
def aten_ops_gt(
16911706
ctx: ConversionContext,
16921707
target: Target,
@@ -1706,6 +1721,11 @@ def aten_ops_gt(
17061721

17071722
@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor) # type: ignore[misc]
17081723
@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar) # type: ignore[misc]
1724+
@enforce_tensor_types(
1725+
{
1726+
0: (TRTTensor,),
1727+
}
1728+
) # type: ignore[misc]
17091729
def aten_ops_ge(
17101730
ctx: ConversionContext,
17111731
target: Target,
@@ -1725,6 +1745,11 @@ def aten_ops_ge(
17251745

17261746
@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor) # type: ignore[misc]
17271747
@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar) # type: ignore[misc]
1748+
@enforce_tensor_types(
1749+
{
1750+
0: (TRTTensor,),
1751+
}
1752+
) # type: ignore[misc]
17281753
def aten_ops_lt(
17291754
ctx: ConversionContext,
17301755
target: Target,
@@ -1744,6 +1769,11 @@ def aten_ops_lt(
17441769

17451770
@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor) # type: ignore[misc]
17461771
@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar) # type: ignore[misc]
1772+
@enforce_tensor_types(
1773+
{
1774+
0: (TRTTensor,),
1775+
}
1776+
) # type: ignore[misc]
17471777
def aten_ops_le(
17481778
ctx: ConversionContext,
17491779
target: Target,

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Union
1+
from typing import Optional, Sequence, Union
22

33
import numpy as np
44
import tensorrt as trt
@@ -428,8 +428,8 @@ def eq(
428428
target: Target,
429429
source_ir: Optional[SourceIR],
430430
name: str,
431-
lhs_val: Union[TRTTensor, int, float],
432-
rhs_val: Union[TRTTensor, int, float],
431+
lhs_val: TRTTensor,
432+
rhs_val: Union[TRTTensor, int, float, bool, Sequence[Union[int, float, bool]]],
433433
) -> TRTTensor:
434434
return convert_binary_elementwise(
435435
ctx,
@@ -447,8 +447,8 @@ def ne(
447447
target: Target,
448448
source_ir: Optional[SourceIR],
449449
name: str,
450-
lhs_val: Union[TRTTensor, int, float],
451-
rhs_val: Union[TRTTensor, int, float],
450+
lhs_val: TRTTensor,
451+
rhs_val: Union[TRTTensor, int, float, bool, Sequence[Union[int, float, bool]]],
452452
) -> TRTTensor:
453453
return impl.unary.logical_not(
454454
ctx,
@@ -464,8 +464,8 @@ def gt(
464464
target: Target,
465465
source_ir: Optional[SourceIR],
466466
name: str,
467-
lhs_val: Union[TRTTensor, int, float],
468-
rhs_val: Union[TRTTensor, int, float],
467+
lhs_val: TRTTensor,
468+
rhs_val: Union[TRTTensor, int, float, Sequence[Union[int, float]]],
469469
) -> TRTTensor:
470470
return convert_binary_elementwise(
471471
ctx,
@@ -483,8 +483,8 @@ def ge(
483483
target: Target,
484484
source_ir: Optional[SourceIR],
485485
name: str,
486-
lhs_val: Union[TRTTensor, int, float],
487-
rhs_val: Union[TRTTensor, int, float],
486+
lhs_val: TRTTensor,
487+
rhs_val: Union[TRTTensor, int, float, Sequence[Union[int, float]]],
488488
) -> TRTTensor:
489489
return logical_or(
490490
ctx,
@@ -501,8 +501,8 @@ def lt(
501501
target: Target,
502502
source_ir: Optional[SourceIR],
503503
name: str,
504-
lhs_val: Union[TRTTensor, int, float],
505-
rhs_val: Union[TRTTensor, int, float],
504+
lhs_val: TRTTensor,
505+
rhs_val: Union[TRTTensor, int, float, Sequence[Union[int, float]]],
506506
) -> TRTTensor:
507507
return convert_binary_elementwise(
508508
ctx,
@@ -520,8 +520,8 @@ def le(
520520
target: Target,
521521
source_ir: Optional[SourceIR],
522522
name: str,
523-
lhs_val: Union[TRTTensor, int, float],
524-
rhs_val: Union[TRTTensor, int, float],
523+
lhs_val: TRTTensor,
524+
rhs_val: Union[TRTTensor, int, float, Sequence[Union[int, float]]],
525525
) -> TRTTensor:
526526
return logical_or(
527527
ctx,

0 commit comments

Comments
 (0)