File tree Expand file tree Collapse file tree 1 file changed +13
-0
lines changed
py/torch_tensorrt/dynamo/conversion/impl Expand file tree Collapse file tree 1 file changed +13
-0
lines changed Original file line number Diff line number Diff line change 9
9
from torch_tensorrt .dynamo ._SourceIR import SourceIR
10
10
from torch_tensorrt .dynamo .conversion ._ConversionContext import ConversionContext
11
11
from torch_tensorrt .dynamo .conversion .converter_utils import (
12
+ cast_trt_tensor ,
12
13
get_positive_dim ,
13
14
get_trt_tensor ,
14
15
)
@@ -38,6 +39,12 @@ def shape(
38
39
"""
39
40
shape_layer = ctx .net .add_shape (input_val )
40
41
input_shape = shape_layer .get_output (0 )
42
+ input_shape = cast_trt_tensor (
43
+ ctx ,
44
+ input_shape ,
45
+ trt .int32 ,
46
+ name + "_shape_casted" ,
47
+ )
41
48
set_layer_name (shape_layer , target , name + "_shape" , source_ir )
42
49
43
50
n_dims = len (input_val .shape )
@@ -82,6 +89,12 @@ def get_shape_with_dynamic_shape(
82
89
"""
83
90
# Ger real shape info for input_val
84
91
input_shape = ctx .net .add_shape (input_val ).get_output (0 )
92
+ input_shape = cast_trt_tensor (
93
+ ctx ,
94
+ input_shape ,
95
+ trt .int32 ,
96
+ name + "_int32_casted" ,
97
+ )
85
98
# input_shape.dtype is int64 in TRT 10.0
86
99
input_np_dtype = unified_dtype_converter (input_shape .dtype , Frameworks .NUMPY )
87
100
scale_layer = ctx .net .add_constant (
You can’t perform that action at this time.
0 commit comments