 from torch_tensorrt.fx.converters.converter_utils import (
     get_positive_dim,
     get_trt_tensor,
+    prepend_ones,
     set_layer_name,
 )
 from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor
@@ -27,7 +28,7 @@ def unsqueeze(
     )

     dim = cast(int, dim)
-    input_shape = input_val.shape
+
     input_shape_size = (
         len(input_val.shape) + 1
         if network.has_implicit_batch_dimension
@@ -46,5 +47,55 @@ def unsqueeze(
     layer.reshape_dims = (
         tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:]
     )
-    set_layer_name(layer, target, name)
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
+
+
+def expand(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_t: TRTTensor,
+    shape: Shape,
+) -> TRTTensor:
+    if not isinstance(input_t, TRTTensor):
+        raise RuntimeError(
+            f"expand received input {input_t} that is not a TensorRT ITensor"
+        )
+
+    shape_rank = len(shape)
+    initial_tensor_rank = len(input_t.shape)
+
+    # If the rank of the input tensor is less than the shape's rank, pad with ones
+    if initial_tensor_rank < shape_rank:
+        input_t = prepend_ones(
+            network,
+            input_t,
+            name + "_expand_ones_padding",
+            shape_rank - initial_tensor_rank,
+        )
+    # If the rank of the input tensor is more than the shape's rank, raise error
+    elif initial_tensor_rank > shape_rank:
+        raise RuntimeError(
+            f"expand called with {shape_rank}-dimensional shape on Tensor with {initial_tensor_rank} dimensions. "
+            "Cannot expand to shape with rank smaller than original tensor."
+        )
+
+    # After the above padding, the shape and tensor rank must be equal
+    assert len(input_t.shape) == shape_rank
+
+    # -1 denotes taking the shape from the original input tensor
+    shape = tuple(
+        [input_t.shape[i] if shape[i] == -1 else shape[i] for i in range(shape_rank)]
+    )
+
+    # Establish the desired output shape, strides, and starting indices
+    input_tensor_shape = tuple(input_t.shape)
+    start = tuple([0] * shape_rank)
+    stride = tuple(
+        [int(i == o) for i, o in zip(input_tensor_shape, shape)]
+    )  # stride == 1 if dimensions match, 0 otherwise
+    layer = network.add_slice(input_t, start=start, shape=shape, stride=stride)
+    set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
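For context when reading the diff: the new expand converter implements broadcasting as a single TensorRT slice layer, using stride 0 along dimensions that are broadcast and stride 1 along dimensions that already match the input. The sketch below is an illustration only, not part of the patch; resolve_expand_params is a hypothetical helper that mirrors, in plain Python, the start/shape/stride computation the converter feeds to network.add_slice.

# Illustrative sketch (hypothetical helper, not part of the patch)
def resolve_expand_params(input_shape, target_shape):
    rank = len(target_shape)
    # Mirror prepend_ones: left-pad the input shape with 1s up to the target rank
    padded = (1,) * (rank - len(input_shape)) + tuple(input_shape)
    # -1 in the target shape keeps the corresponding (padded) input dimension
    out_shape = tuple(
        padded[i] if target_shape[i] == -1 else target_shape[i] for i in range(rank)
    )
    start = (0,) * rank
    # stride 1 where the dimension is unchanged, stride 0 where it is broadcast
    stride = tuple(int(i == o) for i, o in zip(padded, out_shape))
    return start, out_shape, stride

print(resolve_expand_params((3, 1), (2, 3, 4)))  # ((0, 0, 0), (2, 3, 4), (0, 1, 0))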