@@ -66,102 +66,102 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
              }})
         .pattern(
             {"aten::unflatten.int(Tensor self, int dim, int[] sizes) -> (Tensor)",
-             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               auto in = args[0].ITensorOrFreeze(ctx);
-               auto dim = args[1].unwrapToInt();
-               auto in_shape = util::toVec(in->getDimensions());
-               std::vector<int64_t> new_shape;
-               nvinfer1::ITensor* shape_tensor;
-               if (ctx->input_is_dynamic) {
-                 /*
-                  * In case the dim is negative
-                  * If the dim in negative range is larger than in_shape,
-                  * then it should run into index out of bound error as expected
-                  */
-                 if (dim < 0) {
-                   dim = in_shape.size() + dim;
-                 }
-                 std::cout << "Dynamic shape case" << std::endl;
-                 LOG_DEBUG("Using dynamic version of reshape layer");
-                 if (args[2].isITensorList()) {
-                   std::cout << "isTensorList case" << std::endl;
-                   LOG_DEBUG("Shape tensor is an ITensorList");
-                   auto expand_shape = args[2].unwrapToITensorList();
-                   auto shape_layer = ctx->net->addShape(*in);
-                   TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
-                   auto shape_1d_tensor = shape_layer->getOutput(0);
-
-                   std::vector<int> before_dim_indices_vector(dim);
-                   std::iota(before_dim_indices_vector.begin(), before_dim_indices_vector.end(), 0);
-
-                   nvinfer1::ITensor* before_dim_gather_out = nullptr;
-                   if (before_dim_indices_vector.size()){
-                     at::Tensor before_dim_indices = torch::tensor(before_dim_indices_vector).to(torch::kI32);
-                     auto before_dim_indices_out = converters::tensor_to_const(ctx, before_dim_indices);
-                     auto before_dim_gather_layer = ctx->net->addGather(*shape_1d_tensor, *before_dim_indices_out, 0);
-                     TORCHTRT_CHECK(before_dim_gather_layer, "Unable to create gather layer from node: " << *n);
-                     before_dim_gather_out = before_dim_gather_layer->getOutput(0);
-                   }
-
-                   std::vector<int> after_dim_indices_vector(in_shape.size() - (dim + 1));
-                   std::iota(after_dim_indices_vector.begin(), after_dim_indices_vector.end(), dim + 1);
-
-                   nvinfer1::ITensor* after_dim_gather_out = nullptr;
-                   if (after_dim_indices_vector.size()){
-                     at::Tensor after_dim_indices = torch::tensor(after_dim_indices_vector).to(torch::kI32);
-                     auto after_dim_indices_out = converters::tensor_to_const(ctx, after_dim_indices);
-                     auto after_dim_gather_layer = ctx->net->addGather(*shape_1d_tensor, *after_dim_indices_out, 0);
-                     TORCHTRT_CHECK(after_dim_gather_layer, "Unable to create gather layer from node: " << *n);
-                     after_dim_gather_out = after_dim_gather_layer->getOutput(0);
-                   }
-
-                   std::vector<nvinfer1::ITensor*> shape_tensors;
-                   if (before_dim_gather_out){
-                     shape_tensors.push_back(before_dim_gather_out);
-                   }
-                   for (auto new_shape_tensor : expand_shape){
-                     shape_tensors.push_back(new_shape_tensor);
-                   }
-                   if (after_dim_gather_out){
-                     shape_tensors.push_back(after_dim_gather_out);
-                   }
-
-                   auto shape_cat_layer = ctx->net->addConcatenation(shape_tensors.data(), shape_tensors.size());
-                   TORCHTRT_CHECK(shape_cat_layer, "Unable to create cat layer from node: " << *n);
-                   shape_tensor = shape_cat_layer->getOutput(0);
-                   LOG_DEBUG("Shape tensor shape: " << shape_tensor->getDimensions());
-                 } else if (args[2].isIntList()) {
-                   auto shape_vec = args[2].unwrapToIntList().vec();
-                   // New shape
-                   new_shape.insert(new_shape.end(), in_shape.begin(), in_shape.begin() + dim);
-                   new_shape.insert(new_shape.end(), shape_vec.begin(), shape_vec.end());
-                   new_shape.insert(new_shape.end(), in_shape.begin() + dim + 1, in_shape.end());
-
-                   shape_tensor = tensor_to_const(ctx, torch::tensor(new_shape).to(torch::kI32));
-                 } else {
-                   LOG_ERROR(
-                       "Invalid IValue type of " << args[2].ivalue_type()
-                                                 << " detected for shape tensor from node: " << *n);
-                 }
-               }
-               else {
-                 new_shape = torch::unflatten(torch::rand(in_shape), dim, args[2].unwrapToIntList().vec()).sizes().vec();
-               }
-               auto shuffle = ctx->net->addShuffle(*in);
-               shuffle->setName(util::node_info(n).c_str());
-               TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
-
-               if (ctx->input_is_dynamic) {
-                 shuffle->setInput(1, *shape_tensor);
-               } else {
-                 shuffle->setReshapeDimensions(util::toDims(new_shape));
-               }
-
-               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
-               LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
-
-               return true;
-             }})
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto in = args[0].ITensorOrFreeze(ctx);
+               auto dim = args[1].unwrapToInt();
+               auto in_shape = util::toVec(in->getDimensions());
+               std::vector<int64_t> new_shape;
+               nvinfer1::ITensor* shape_tensor;
+               if (ctx->input_is_dynamic) {
+                 /*
+                  * In case the dim is negative
+                  * If the dim in negative range is larger than in_shape,
+                  * then it should run into index out of bound error as expected
+                  */
+                 if (dim < 0) {
+                   dim = in_shape.size() + dim;
+                 }
+                 std::cout << "Dynamic shape case" << std::endl;
+                 LOG_DEBUG("Using dynamic version of reshape layer");
+                 if (args[2].isITensorList()) {
+                   std::cout << "isTensorList case" << std::endl;
+                   LOG_DEBUG("Shape tensor is an ITensorList");
+                   auto expand_shape = args[2].unwrapToITensorList();
+                   auto shape_layer = ctx->net->addShape(*in);
+                   TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
+                   auto shape_1d_tensor = shape_layer->getOutput(0);
+
+                   std::vector<int> before_dim_indices_vector(dim);
+                   std::iota(before_dim_indices_vector.begin(), before_dim_indices_vector.end(), 0);
+
+                   nvinfer1::ITensor* before_dim_gather_out = nullptr;
+                   if (before_dim_indices_vector.size()) {
+                     at::Tensor before_dim_indices = torch::tensor(before_dim_indices_vector).to(torch::kI32);
+                     auto before_dim_indices_out = converters::tensor_to_const(ctx, before_dim_indices);
+                     auto before_dim_gather_layer = ctx->net->addGather(*shape_1d_tensor, *before_dim_indices_out, 0);
+                     TORCHTRT_CHECK(before_dim_gather_layer, "Unable to create gather layer from node: " << *n);
+                     before_dim_gather_out = before_dim_gather_layer->getOutput(0);
+                   }
+
+                   std::vector<int> after_dim_indices_vector(in_shape.size() - (dim + 1));
+                   std::iota(after_dim_indices_vector.begin(), after_dim_indices_vector.end(), dim + 1);
+
+                   nvinfer1::ITensor* after_dim_gather_out = nullptr;
+                   if (after_dim_indices_vector.size()) {
+                     at::Tensor after_dim_indices = torch::tensor(after_dim_indices_vector).to(torch::kI32);
+                     auto after_dim_indices_out = converters::tensor_to_const(ctx, after_dim_indices);
+                     auto after_dim_gather_layer = ctx->net->addGather(*shape_1d_tensor, *after_dim_indices_out, 0);
+                     TORCHTRT_CHECK(after_dim_gather_layer, "Unable to create gather layer from node: " << *n);
+                     after_dim_gather_out = after_dim_gather_layer->getOutput(0);
+                   }
+
+                   std::vector<nvinfer1::ITensor*> shape_tensors;
+                   if (before_dim_gather_out) {
+                     shape_tensors.push_back(before_dim_gather_out);
+                   }
+                   for (auto new_shape_tensor : expand_shape) {
+                     shape_tensors.push_back(new_shape_tensor);
+                   }
+                   if (after_dim_gather_out) {
+                     shape_tensors.push_back(after_dim_gather_out);
+                   }
+
+                   auto shape_cat_layer = ctx->net->addConcatenation(shape_tensors.data(), shape_tensors.size());
+                   TORCHTRT_CHECK(shape_cat_layer, "Unable to create cat layer from node: " << *n);
+                   shape_tensor = shape_cat_layer->getOutput(0);
+                   LOG_DEBUG("Shape tensor shape: " << shape_tensor->getDimensions());
+                 } else if (args[2].isIntList()) {
+                   auto shape_vec = args[2].unwrapToIntList().vec();
+                   // New shape
+                   new_shape.insert(new_shape.end(), in_shape.begin(), in_shape.begin() + dim);
+                   new_shape.insert(new_shape.end(), shape_vec.begin(), shape_vec.end());
+                   new_shape.insert(new_shape.end(), in_shape.begin() + dim + 1, in_shape.end());
+
+                   shape_tensor = tensor_to_const(ctx, torch::tensor(new_shape).to(torch::kI32));
+                 } else {
+                   LOG_ERROR(
+                       "Invalid IValue type of " << args[2].ivalue_type()
+                                                 << " detected for shape tensor from node: " << *n);
+                 }
+               } else {
+                 new_shape =
+                     torch::unflatten(torch::rand(in_shape), dim, args[2].unwrapToIntList().vec()).sizes().vec();
+               }
+               auto shuffle = ctx->net->addShuffle(*in);
+               shuffle->setName(util::node_info(n).c_str());
+               TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
+
+               if (ctx->input_is_dynamic) {
+                 shuffle->setInput(1, *shape_tensor);
+               } else {
+                 shuffle->setReshapeDimensions(util::toDims(new_shape));
+               }
+
+               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
+               LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+
+               return true;
+             }})
         .pattern(
             {"aten::reshape(Tensor self, int[] shape) -> (Tensor)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
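Note on the static-shape branch of this converter: rather than computing the unflattened sizes by hand, it calls `torch::unflatten` on a dummy tensor of the input's shape and reads back the resulting sizes, which keeps the converter's semantics (including `-1` size inference) identical to PyTorch's. A minimal standalone sketch of those semantics, assuming only a libtorch install (the shapes here are illustrative, not from the diff):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // unflatten splits one dimension into several: here dim 1 (size 6) becomes {2, 3}.
  auto x = torch::rand({4, 6, 5});
  auto y = torch::unflatten(x, /*dim=*/1, /*sizes=*/{2, 3});
  std::cout << y.sizes() << std::endl; // [4, 2, 3, 5]

  // A -1 entry infers that size from the remaining elements, as in reshape.
  auto z = torch::unflatten(x, 1, {3, -1});
  std::cout << z.sizes() << std::endl; // [4, 3, 2, 5]
  return 0;
}
```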