Skip to content

Commit 3d149ef

Browse files
committed
chore: updates
1 parent 16088e6 commit 3d149ef

File tree

1 file changed

+13
-0
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl

1 file changed

+13
-0
lines changed

py/torch_tensorrt/dynamo/conversion/impl/shape.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from torch_tensorrt.dynamo._SourceIR import SourceIR
1010
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
1111
from torch_tensorrt.dynamo.conversion.converter_utils import (
12+
cast_trt_tensor,
1213
get_positive_dim,
1314
get_trt_tensor,
1415
)
@@ -38,6 +39,12 @@ def shape(
3839
"""
3940
shape_layer = ctx.net.add_shape(input_val)
4041
input_shape = shape_layer.get_output(0)
42+
input_shape = cast_trt_tensor(
43+
ctx,
44+
input_shape,
45+
trt.int32,
46+
name + "_shape_casted",
47+
)
4148
set_layer_name(shape_layer, target, name + "_shape", source_ir)
4249

4350
n_dims = len(input_val.shape)
@@ -82,6 +89,12 @@ def get_shape_with_dynamic_shape(
8289
"""
8390
# Ger real shape info for input_val
8491
input_shape = ctx.net.add_shape(input_val).get_output(0)
92+
input_shape = cast_trt_tensor(
93+
ctx,
94+
input_shape,
95+
trt.int32,
96+
name + "_int32_casted",
97+
)
8598
# input_shape.dtype is int64 in TRT 10.0
8699
input_np_dtype = unified_dtype_converter(input_shape.dtype, Frameworks.NUMPY)
87100
scale_layer = ctx.net.add_constant(

0 commit comments

Comments
 (0)