Skip to content

Commit 0076965

Browse files
[ET-VK] Minor improvements to conv2d pw and dw bounds check.
Pull Request resolved: #7815 This diff contains minor improvements to the conv2d pw and dw bounds check in the Vulkan backend for Executorch. ghstack-source-id: 263238731 @exported-using-ghexport Differential Revision: [D68400689](https://our.internmc.facebook.com/intern/diff/D68400689/) --------- Co-authored-by: Vivek Trivedi <[email protected]>
1 parent ea9058e commit 0076965

File tree

4 files changed

+4
-4
lines changed

4 files changed

+4
-4
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void main() {
4141
div_by_x % out_limits.y,
4242
div_by_x / out_limits.y);
4343

44-
if (any(greaterThanEqual(pos, out_limits))) {
44+
if (pos.z >= out_limits.z) {
4545
return;
4646
}
4747

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ void main() {
5959
pos.y *= BATCH_SIZE_Y;
6060

6161
// do not process if top pixel does not fit within the output range
62-
if (any(greaterThanEqual(pos, out_limits))) {
62+
if (pos.z >= out_limits.z) {
6363
return;
6464
}
6565

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_sned_output_tile.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ void main() {
4444
div_by_x % out_limits.y,
4545
div_by_x / out_limits.y);
4646

47-
if (any(greaterThanEqual(pos, out_limits))) {
47+
if (pos.z >= out_limits.z) {
4848
return;
4949
}
5050

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ void main() {
7979

8080
// If the top left position is out of bounds, then this invocation will have
8181
// no work to do.
82-
if (any(greaterThanEqual(ivec3(pos[0], gpos.z), out_limits.xyz))) {
82+
if (gpos.z >= out_limits.z) {
8383
return;
8484
}
8585

0 commit comments

Comments
 (0)