 from typing import Optional, Sequence, Union

+import numpy as np
+import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    cast_trt_tensor,
+    get_positive_dim,
+    set_layer_name,
+)
+from torch_tensorrt.dynamo.conversion.impl.elementwise import ne
 from torch_tensorrt.fx.types import TRTTensor
-from torch_tensorrt.fx.utils import get_dynamic_dims


 def squeeze(
@@ -29,24 +34,90 @@ def squeeze(
         dims.append(dim)

     new_dims = []
+    dim_has_dynamic_shape = False
     for dim in dims:
         dim = get_positive_dim(
             dim,
             len(input.shape),
         )

-        assert input.shape[dim] != -1, "We don't support squeeze dynamic dim."
-        assert (
-            len(get_dynamic_dims(input.shape)) <= 1
-        ), "Currently more than one dynamic dim for input to squeeze is not supported."
+        if input.shape[dim] == -1:
+            dim_has_dynamic_shape = True
         new_dims.append(dim)

-    output_shape = []
-    for i, s in enumerate(input.shape):
-        if (i in new_dims) and s == 1:
-            continue
-        output_shape.append(s)
     layer = ctx.net.add_shuffle(input)
-    layer.reshape_dims = tuple(output_shape)
     set_layer_name(layer, target, name, source_ir)
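+    # When a squeezed axis has dynamic shape, the output shape is unknown at
+    # build time, so it is computed in-graph: scatter ones onto the squeeze
+    # positions, keep each axis whose runtime extent differs from that mask,
+    # and wire the gathered extents into the shuffle layer as a shape tensor.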
+    if dim_has_dynamic_shape:
+        num_shape = len(input.shape)
+
+        tensor_shape_layer = ctx.net.add_shape(input)
+        tensor_shape = tensor_shape_layer.get_output(0)
+        tensor_shape = cast_trt_tensor(
+            ctx, tensor_shape, trt.int32, name + "shape_casted", "shape"
+        )
+
+        # TODO: change it to get_trt_tensor
+        one_layer = ctx.net.add_constant(
+            (num_shape,),
+            np.ascontiguousarray([1] * num_shape, np.int32),
+        )
+        set_layer_name(one_layer, target, name + "_one", source_ir)
+
+        zero_layer = ctx.net.add_constant(
+            (num_shape,),
+            np.zeros((num_shape,), dtype=np.int32),
+        )
+        set_layer_name(zero_layer, target, name + "_zero", source_ir)
+
+        # pad the squeeze indices to rank length by repeating the last index,
+        # so the scatter indices match the shape of the updates tensor
+        num_append = num_shape - len(new_dims)
+        if num_append > 0:
+            new_dims += [new_dims[-1]] * num_append
+
+        index_value = np.array(new_dims, dtype=np.int32)
+        index_layer = ctx.net.add_constant(index_value.shape, index_value)
+        set_layer_name(index_layer, target, name + "_index", source_ir)
+
+        scatter_layer = ctx.net.add_scatter(
+            zero_layer.get_output(0),
+            index_layer.get_output(0),
+            one_layer.get_output(0),
+            trt.ScatterMode.ELEMENT,
+        )
+        set_layer_name(scatter_layer, target, name + "_scatter", source_ir)
+
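+        # ELEMENT-mode scatter writes a 1 at each squeeze index; the padded
+        # duplicate indices simply rewrite the same slot, so the result is a
+        # 0/1 mask over the input axes.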
+        # e.g. squeezing dims (2, 4) of a (1, 2, 1, 3, 1) input:
+        # runtime shape:  [1, 2, 1, 3, 1]
+        # scatter mask:   [0, 0, 1, 0, 1]
+        # shape != mask:  [t, t, f, t, f]
+        ne_tensor = ne(
+            ctx,
+            target,
+            source_ir,
+            name + "_ne",
+            tensor_shape,
+            scatter_layer.get_output(0),
+        )
+
+        # [t, t, f, t, f] -> [0, 1, 3]
+        non_zero_layer = ctx.net.add_non_zero(ne_tensor)
+        set_layer_name(non_zero_layer, target, name + "_non_zero", source_ir)
+
+        # add_non_zero returns the indices as (1, k); transpose to (k, 1) so
+        # each row is an index tuple for the ND gather below
+        non_zero_shuffle_layer = ctx.net.add_shuffle(non_zero_layer.get_output(0))
+        set_layer_name(non_zero_shuffle_layer, target, name + "_shuffle", source_ir)
+        non_zero_shuffle_layer.second_transpose = (1, 0)
+
+        # gather the kept extents: (1, 2, 1, 3, 1) + [0, 1, 3] -> [1, 2, 3]
+        gather_layer = ctx.net.add_gather_v2(
+            tensor_shape, non_zero_shuffle_layer.get_output(0), mode=trt.GatherMode.ND
+        )
+        set_layer_name(gather_layer, target, name + "_gather", source_ir)
+
+        # feed the runtime shape as the shuffle layer's reshape-dims input
+        layer.set_input(1, gather_layer.get_output(0))
+    else:
+        # static case: drop size-1 axes at the squeeze positions at build time
+        output_shape = []
+        for i, s in enumerate(input.shape):
+            if (i in new_dims) and s == 1:
+                continue
+            output_shape.append(s)
+        layer.reshape_dims = tuple(output_shape)
     return layer.get_output(0)
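
The layer graph in the dynamic branch is just a shape computation. A minimal NumPy sketch of the same mask trick, assuming a valid squeeze (every targeted axis has runtime extent 1); the function name and example values are illustrative only, not part of the commit:

    import numpy as np

    def squeeze_shape_via_mask(shape, squeeze_dims):
        # Mimics the layer graph: scatter -> ne -> non_zero -> gather.
        n = len(shape)
        dims = list(squeeze_dims)
        dims += [dims[-1]] * (n - len(dims))  # pad like the converter does
        mask = np.zeros(n, dtype=np.int32)    # zero_layer
        mask[np.array(dims)] = 1              # scatter ones at squeeze positions
        keep = np.asarray(shape) != mask      # ne: squeezed axes have extent 1 == mask 1
        idx = np.nonzero(keep)[0]             # non_zero
        return np.asarray(shape)[idx]         # gather

    print(squeeze_shape_via_mask((1, 2, 1, 3, 1), [2, 4]))  # -> [1 2 3]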