Skip to content

Commit 0bf93c6

Browse files
committed
fix: Formatting + TRTTensor casting
1 parent ffe53e0 commit 0bf93c6

File tree

2 files changed

+4
-10
lines changed

2 files changed

+4
-10
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def aten_ops_batch_norm(
7070
)
7171

7272

73-
@dynamo_tensorrt_converter(torch.ops.aten.cat.default)
73+
@dynamo_tensorrt_converter(torch.ops.aten.cat.default) # type: ignore[misc]
7474
def aten_ops_cat(
7575
ctx: ConversionContext,
7676
target: Target,
@@ -1724,6 +1724,7 @@ def aten_ops_reshape(
17241724
)
17251725

17261726

1727+
@enforce_tensor_types({0: (TRTTensor,)}) # type: ignore[misc]
17271728
@dynamo_tensorrt_converter(torch.ops.aten.argmax.default) # type: ignore[misc]
17281729
def aten_ops_argmax(
17291730
ctx: ConversionContext,

py/torch_tensorrt/dynamo/conversion/impl/argmax.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Union
1+
from typing import Optional
22

33
import tensorrt as trt
44
from torch.fx.node import Target
@@ -24,14 +24,9 @@ def argmax(
2424
source_ir: Optional[SourceIR],
2525
name: str,
2626
input: TRTTensor,
27-
dim: Union[int, None],
27+
dim: Optional[int],
2828
keep_dim: bool = False,
2929
) -> TRTTensor:
30-
if not isinstance(input, TRTTensor):
31-
raise RuntimeError(
32-
f"argmax received input {input} that is not part " "of the TensorRT region!"
33-
)
34-
3530
if input.dtype == trt.int32:
3631
input = cast_trt_tensor(ctx, input, trt.float32, name)
3732

@@ -51,8 +46,6 @@ def argmax(
5146
shuffle_layer.reshape_dims = (*input.shape, 1)
5247
set_layer_name(shuffle_layer, target, name + "_broadcast")
5348
out = shuffle_layer.get_output(0)
54-
elif dim < 0:
55-
dim = len(tuple(input.shape)) + dim
5649

5750
reduce_mask = get_axes_for_reduce_op(0)
5851
if dim is not None:

0 commit comments

Comments
 (0)