Skip to content

Commit 1e975b9

Browse files
committed
[ET-VK] Making stride equals dilation the default mode for conv2d dw.
Pull Request resolved: #7596 This diff makes changes make stride equals dilation the default mode for conv2d dw output op. Adds a different source file to handle stride not equal dilation case. Differential Revision: [D67979760](https://our.internmc.facebook.com/intern/diff/D67979760/) ghstack-source-id: 260951738
1 parent 8749884 commit 1e975b9

File tree

5 files changed

+101
-57
lines changed

5 files changed

+101
-57
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414

1515
#define TILE_SIZE ${TILE_SIZE}
1616

17-
#define STRIDE_EQ_DILATION ${STRIDE_EQ_DILATION}
18-
1917
#define BATCH_SIZE_X ${BATCH_SIZE_X}
2018

2119
#define BATCH_SIZE_Y ${BATCH_SIZE_Y}
@@ -45,7 +43,6 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4543
* output at a single output location.
4644
*/
4745

48-
#if STRIDE_EQ_DILATION
4946
void main() {
5047
// x and y are divided by batch size to determine 3d position
5148
// since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
@@ -127,42 +124,3 @@ void main() {
127124
}
128125
}
129126
}
130-
131-
#else
132-
void main() {
133-
const uint div_by_x = gl_GlobalInvocationID.x / out_limits.x;
134-
const ivec3 pos = ivec3(
135-
gl_GlobalInvocationID.x % out_limits.x,
136-
div_by_x % out_limits.y,
137-
div_by_x / out_limits.y);
138-
139-
if (any(greaterThanEqual(pos, out_limits))) {
140-
return;
141-
}
142-
143-
// Compute the index of the top-left element of the overlay region. Negative
144-
// indices indicate that the top-left element is in a region added by padding.
145-
const ivec2 ipos = pos.xy * stride - padding;
146-
147-
// Compute the start and end of the input indices to load. Padding is assumed
148-
// to be constant 0 padding, so any reads from the padding region is skipped.
149-
const ivec2 start = ipos;
150-
const ivec2 end = ipos + overlay_region.xy;
151-
152-
VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
153-
int kx = 0;
154-
for (int y = start.y, i = 0; i < TILE_SIZE; y += dilation.y, i++) {
155-
for (int x = start.x, j = 0; j < TILE_SIZE; x += dilation.x, j++) {
156-
// The weight kernel was rearranged such that every NxN filter is
157-
// flattened to fit in one row. Each filter was then stacked on top of
158-
// each other vertically.
159-
const vec4 in_texel = texelFetch(t_in, ivec3(x, y, pos.z), 0);
160-
sum = fma(in_texel, texelFetch(t_kernel, ivec2(kx, pos.z), 0), sum);
161-
kx++;
162-
}
163-
}
164-
165-
imageStore(t_out, pos, op(sum, out_min, out_max));
166-
}
167-
168-
#endif

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ conv2d_dw_output_tile:
1212
TILE_SIZE: 3
1313
BATCH_SIZE_X: 4
1414
BATCH_SIZE_Y: 2
15-
STRIDE_EQ_DILATION: 0
1615
generate_variant_forall:
1716
DTYPE:
1817
- VALUE: half
@@ -26,15 +25,3 @@ conv2d_dw_output_tile:
2625
- NAME: conv2d_dw_output_tile_5x5_clamp
2726
OPERATOR: clamp(X, A, B)
2827
TILE_SIZE: 5
29-
- NAME: conv2d_dw_sed_output_tile_3x3
30-
STRIDE_EQ_DILATION: 1
31-
- NAME: conv2d_dw_sed_output_tile_3x3_clamp
32-
OPERATOR: clamp(X, A, B)
33-
STRIDE_EQ_DILATION: 1
34-
- NAME: conv2d_dw_sed_output_tile_5x5
35-
TILE_SIZE: 5
36-
STRIDE_EQ_DILATION: 1
37-
- NAME: conv2d_dw_sed_output_tile_5x5_clamp
38-
OPERATOR: clamp(X, A, B)
39-
TILE_SIZE: 5
40-
STRIDE_EQ_DILATION: 1
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_type(DTYPE)}
14+
15+
#define TILE_SIZE ${TILE_SIZE}
16+
17+
#define op(X, A, B) ${OPERATOR}
18+
19+
#include "indexing_utils.h"
20+
21+
layout(std430) buffer;
22+
23+
${layout_declare_tensor(0, "w", "t_out", DTYPE, "texture3d")}
24+
${layout_declare_tensor(1, "r", "t_in", DTYPE, "texture3d")}
25+
${layout_declare_tensor(2, "r", "t_kernel", DTYPE, "texture2d")}
26+
${layout_declare_tensor(3, "r", "t_bias", DTYPE, "texture2d")}
27+
${layout_declare_ubo(4, "ivec3", "out_limits")}
28+
${layout_declare_ubo(5, "ivec4", "in_sizes")}
29+
${layout_declare_ubo(6, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "padding", "ivec2", "dilation")}
30+
${layout_declare_ubo(7, "ivec2", "overlay_region", "int", "in_group_size")}
31+
${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
32+
33+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
34+
35+
/*
36+
* Computes a depthwise convolution. Each shader invocation calculates the
37+
* output at a single output location.
38+
*/
39+
40+
void main() {
41+
const uint div_by_x = gl_GlobalInvocationID.x / out_limits.x;
42+
const ivec3 pos = ivec3(
43+
gl_GlobalInvocationID.x % out_limits.x,
44+
div_by_x % out_limits.y,
45+
div_by_x / out_limits.y);
46+
47+
if (any(greaterThanEqual(pos, out_limits))) {
48+
return;
49+
}
50+
51+
// Compute the index of the top-left element of the overlay region. Negative
52+
// indices indicate that the top-left element is in a region added by padding.
53+
const ivec2 ipos = pos.xy * stride - padding;
54+
55+
// Compute the start and end of the input indices to load. Padding is assumed
56+
// to be constant 0 padding, so any reads from the padding region is skipped.
57+
const ivec2 start = ipos;
58+
const ivec2 end = ipos + overlay_region.xy;
59+
60+
VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
61+
int kx = 0;
62+
for (int y = start.y, i = 0; i < TILE_SIZE; y += dilation.y, i++) {
63+
for (int x = start.x, j = 0; j < TILE_SIZE; x += dilation.x, j++) {
64+
// The weight kernel was rearranged such that every NxN filter is
65+
// flattened to fit in one row. Each filter was then stacked on top of
66+
// each other vertically.
67+
const vec4 in_texel = texelFetch(t_in, ivec3(x, y, pos.z), 0);
68+
sum = fma(in_texel, texelFetch(t_kernel, ivec2(kx, pos.z), 0), sum);
69+
kx++;
70+
}
71+
}
72+
73+
imageStore(t_out, pos, op(sum, out_min, out_max));
74+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
conv2d_dw_sned_output_tile:
8+
parameter_names_with_default_values:
9+
OPERATOR: X
10+
NDIM: 3
11+
DTYPE: float
12+
TILE_SIZE: 3
13+
generate_variant_forall:
14+
DTYPE:
15+
- VALUE: half
16+
- VALUE: float
17+
shader_variants:
18+
- NAME: conv2d_dw_sned_output_tile_3x3
19+
- NAME: conv2d_dw_sned_output_tile_3x3_clamp
20+
OPERATOR: clamp(X, A, B)
21+
- NAME: conv2d_dw_sned_output_tile_5x5
22+
TILE_SIZE: 5
23+
- NAME: conv2d_dw_sned_output_tile_5x5_clamp
24+
OPERATOR: clamp(X, A, B)
25+
TILE_SIZE: 5

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ vkapi::ShaderInfo get_conv2d_shader(
134134
case Conv2dMethod::Depthwise:
135135
kernel_name = "conv2d_dw";
136136
if (!prepack_weights) {
137-
if (stride_equals_dilation) {
138-
kernel_name += "_sed";
137+
if (!stride_equals_dilation) {
138+
kernel_name += "_sned";
139139
}
140140
const auto& weight_sizes = graph.get_tref(weight)->sizes;
141141
if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) {

0 commit comments

Comments
 (0)