Skip to content

Commit bcba739

Browse files
jorgep31415facebook-github-bot
authored andcommitted
Remove runtime dependency on ATen/native/vulkan/impl (#2270)
Summary: Pull Request resolved: #2270 bypass-github-export-checks The only missing logic is copied from ``` ATen/native/vulkan/impl/Common.h/cpp ``` to ``` executorch/backends/vulkan/runtime/graph/ops/OpUtils.h/cpp ``` We can create a utils directory and improve their file organization, in a follow up change. Reviewed By: SS-JIA Differential Revision: D54555273 fbshipit-source-id: 3281391ee60623382b9eece2d6c9cf26678e9342
1 parent 0de3a97 commit bcba739

File tree

10 files changed

+128
-9
lines changed

10 files changed

+128
-9
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>
10+
11+
namespace at {
12+
namespace native {
13+
namespace vulkan {
14+
15+
api::utils::uvec3 adaptive_work_group_size(
16+
const api::utils::uvec3& global_work_group) {
17+
api::utils::uvec3 local_group_size = {4, 4, 4};
18+
if (global_work_group.data[2u] == 1) {
19+
if (global_work_group.data[1u] < 8) {
20+
local_group_size.data[0u] = 16;
21+
local_group_size.data[1u] = 4;
22+
local_group_size.data[2u] = 1;
23+
} else {
24+
local_group_size.data[0u] = 8;
25+
local_group_size.data[1u] = 8;
26+
local_group_size.data[2u] = 1;
27+
}
28+
}
29+
return local_group_size;
30+
}
31+
32+
} // namespace vulkan
33+
} // namespace native
34+
} // namespace at
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#ifdef USE_VULKAN_API
12+
13+
#include <ATen/native/vulkan/api/api.h>
14+
15+
namespace at {
16+
namespace native {
17+
namespace vulkan {
18+
19+
/*
20+
* Maps a semantic dimension name to an integer that corresponds to its
21+
* innermost ordering in a 4D tensor in NCHW format. Width is the innermost
22+
* dimension, so it corresponds to 1, height is the next innermost, so it
23+
* corresponds to 2, and so on.
24+
*/
25+
struct Dim4D {
26+
static constexpr uint32_t Width = 1u;
27+
static constexpr uint32_t Height = 2u;
28+
static constexpr uint32_t Channel = 3u;
29+
static constexpr uint32_t Batch = 4u;
30+
};
31+
32+
/*
33+
* Semantic dimension names for a 1D tensor
34+
*/
35+
struct Dim1D {
36+
static constexpr uint32_t Length = 1u;
37+
};
38+
39+
/*
40+
* Semantic dimension names for a 2D Convolution kernel.
41+
*/
42+
struct DimConv2DKernel {
43+
static constexpr uint32_t Width = 1u;
44+
static constexpr uint32_t Height = 2u;
45+
static constexpr uint32_t InChannels = 3u;
46+
static constexpr uint32_t OutChannels = 4u;
47+
};
48+
49+
/*
50+
* The same as the above, except for a 2D Transposed Convolution kernel.
51+
*/
52+
struct DimTConv2DKernel {
53+
static constexpr uint32_t Width = 1u;
54+
static constexpr uint32_t Height = 2u;
55+
static constexpr uint32_t OutChannels = 3u;
56+
static constexpr uint32_t InChannels = 4u;
57+
};
58+
59+
/*
60+
* The functions below safely return the size of the dimension at the N-th
61+
* innermost index. If the dimensionality of the size array is not sufficient
62+
* then 1 will be returned. The structs above are intended to be used with
63+
* these functions.
64+
*/
65+
template <uint32_t N>
66+
uint32_t dim_at(const std::vector<int64_t>& sizes) {
67+
const uint32_t dims = sizes.size();
68+
return dims < N ? 1 : api::utils::safe_downcast<uint32_t>(sizes[dims - N]);
69+
}
70+
71+
template <uint32_t N>
72+
uint32_t dim_at(const vTensor& v_in) {
73+
return dim_at<N>(v_in.sizes());
74+
}
75+
76+
/*
77+
* For most global work group sizes, returns {4, 4, 4}, but adjusts the size for
78+
* 2D global work group sizes. Always maintains a total of 64 invocations
79+
*/
80+
api::utils::uvec3 adaptive_work_group_size(
81+
const api::utils::uvec3& global_work_group);
82+
83+
} // namespace vulkan
84+
} // namespace native
85+
} // namespace at
86+
87+
#endif /* USE_VULKAN_API */

backends/vulkan/runtime/graph/ops/StagingUtils.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>
1213

13-
#include <ATen/native/vulkan/impl/Common.h>
14-
1514
namespace at {
1615
namespace native {
1716
namespace vulkan {

backends/vulkan/runtime/graph/ops/Utils.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>
12+
1113
namespace at {
1214
namespace native {
1315
namespace vulkan {

backends/vulkan/runtime/graph/ops/Utils.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010

1111
#ifdef USE_VULKAN_API
1212

13-
#include <ATen/native/vulkan/impl/Common.h>
14-
1513
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
1614

1715
namespace at {

backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>
12+
1113
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1214

1315
namespace at {

backends/vulkan/runtime/graph/ops/impl/Arithmetic.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010

1111
#ifdef USE_VULKAN_API
1212

13-
#include <ATen/native/vulkan/impl/Arithmetic.h>
14-
1513
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
1614

1715
#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>
1213
#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>
1314

14-
#include <ATen/native/vulkan/impl/Common.h>
15-
1615
namespace at {
1716
namespace native {
1817
namespace vulkan {

backends/vulkan/targets.bzl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ def define_common_targets():
5353
"@EXECUTORCH_CLIENTS",
5454
],
5555
exported_deps = [
56-
"//caffe2:torch_vulkan_ops",
5756
"//caffe2:torch_vulkan_spv",
5857
],
5958
define_static_target = False,

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <ATen/native/vulkan/api/api.h>
1212

13+
#include <ATen/native/vulkan/impl/Arithmetic.h>
1314
#include <ATen/native/vulkan/impl/Common.h>
1415
#include <ATen/native/vulkan/impl/Packing.h>
1516

0 commit comments

Comments
 (0)