@@ -34,12 +34,12 @@ void resize_conv2d_node(
34
34
if (ndim == 4 ) {
35
35
new_out_sizes.at (ndim - 4 ) = self.sizes ().at (ndim - 4 );
36
36
}
37
- const auto weight_sizes = graph->get_val (extra_args[0 ]).toTensorRef ().sizes ;
37
+ const auto & weight_sizes = graph->get_val (extra_args[0 ]).toTensorRef ().sizes ;
38
38
new_out_sizes.at (ndim - 3 ) =
39
39
transposed ? weight_sizes.at (ndim - 3 ) : weight_sizes.at (ndim - 4 );
40
40
41
41
// Height, Width
42
- const auto new_out_sizes_hw = calc_out_sizes_hw (
42
+ const auto & new_out_sizes_hw = calc_out_sizes_hw (
43
43
*graph,
44
44
self.sizes (),
45
45
extra_args[0 ],
@@ -87,13 +87,24 @@ enum class Conv2dMethod : uint8_t {
87
87
};
88
88
89
89
api::ShaderInfo get_conv2d_shader (
90
+ ComputeGraph& graph,
90
91
const vTensor& t_out,
91
92
const bool prepack_weights,
92
- const Conv2dMethod method) {
93
+ const Conv2dMethod method,
94
+ const ValueRef weight) {
93
95
std::stringstream kernel_name;
94
96
switch (method) {
95
97
case Conv2dMethod::Depthwise:
96
98
kernel_name << " conv2d_dw" ;
99
+ if (!prepack_weights) {
100
+ const auto & weight_sizes = graph.get_val (weight).toTensorRef ().sizes ;
101
+ if (weight_sizes.at (2 ) == 3 && weight_sizes.at (3 ) == 3 ) {
102
+ kernel_name << " _output_tile_3x3" ;
103
+ }
104
+ if (weight_sizes.at (2 ) == 5 && weight_sizes.at (3 ) == 5 ) {
105
+ kernel_name << " _output_tile_5x5" ;
106
+ }
107
+ }
97
108
break ;
98
109
case Conv2dMethod::SlidingWindow:
99
110
kernel_name << " conv2d" ;
@@ -156,7 +167,7 @@ ValueRef prepack_weights(
156
167
const ValueRef vref,
157
168
const Conv2dMethod method) {
158
169
const auto original_sizes = graph.get_val (vref).toTensorRef ().sizes ;
159
- const auto final_sizes = get_final_sizes (original_sizes, method);
170
+ const auto & final_sizes = get_final_sizes (original_sizes, method);
160
171
161
172
ValueRef v = graph.add_tensor (
162
173
final_sizes,
@@ -169,9 +180,9 @@ ValueRef prepack_weights(
169
180
api::utils::uvec3 local_size = adaptive_work_group_size (global_size);
170
181
171
182
api::ShaderInfo shader =
172
- get_conv2d_shader (t, /* prepack_weights = */ true , method);
183
+ get_conv2d_shader (graph, t, /* prepack_weights = */ true , method, vref );
173
184
174
- const auto padded_sizes = get_padded_sizes (original_sizes, method);
185
+ const auto & padded_sizes = get_padded_sizes (original_sizes, method);
175
186
176
187
graph.prepack_nodes ().emplace_back (new PrepackNode (
177
188
graph,
@@ -210,13 +221,13 @@ Conv2dParams create_conv2d_params(
210
221
const ValueRef weight,
211
222
const KernelParams& p,
212
223
const bool transposed) {
213
- const auto overlay_region = api::utils::make_ivec2 ({
224
+ const auto & overlay_region = api::utils::make_ivec2 ({
214
225
p.kernel_size .data [0 ] +
215
226
(p.kernel_size .data [0 ] - 1 ) * (p.dilation .data [0 ] - 1 ),
216
227
p.kernel_size .data [1 ] +
217
228
(p.kernel_size .data [1 ] - 1 ) * (p.dilation .data [1 ] - 1 ),
218
229
});
219
- const auto weight_sizes = graph.get_val (weight).toTensorRef ().sizes ;
230
+ const auto & weight_sizes = graph.get_val (weight).toTensorRef ().sizes ;
220
231
const int32_t in_group_size =
221
232
api::utils::safe_downcast<int32_t >(api::utils::align_up (
222
233
transposed ? weight_sizes.at (0 ) : weight_sizes.at (1 ), INT64_C (4 )));
@@ -244,7 +255,7 @@ Conv2dMethod get_conv2d_method(
244
255
const ValueRef weight,
245
256
const int64_t groups,
246
257
const bool transposed) {
247
- const auto weight_sizes = graph.get_val (weight).toTensorRef ().sizes ;
258
+ const auto & weight_sizes = graph.get_val (weight).toTensorRef ().sizes ;
248
259
if (!transposed && weight_sizes.at (0 ) == groups && weight_sizes.at (1 ) == 1 ) {
249
260
return Conv2dMethod::Depthwise;
250
261
}
@@ -298,8 +309,8 @@ void add_conv2d_node(
298
309
299
310
check_conv2d_params (kernel_params, transposed_val);
300
311
301
- api::ShaderInfo shader =
302
- get_conv2d_shader ( t_out, /* prepack_weights = */ false , method);
312
+ api::ShaderInfo shader = get_conv2d_shader (
313
+ graph, t_out, /* prepack_weights = */ false , method, weight );
303
314
304
315
graph.execute_nodes ().emplace_back (new ExecuteNode (
305
316
graph,
0 commit comments