Skip to content

Remove runtime dependency on ATen/native/vulkan/impl #2270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions backends/vulkan/runtime/graph/ops/OpUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>

namespace at {
namespace native {
namespace vulkan {

api::utils::uvec3 adaptive_work_group_size(
const api::utils::uvec3& global_work_group) {
api::utils::uvec3 local_group_size = {4, 4, 4};
if (global_work_group.data[2u] == 1) {
if (global_work_group.data[1u] < 8) {
local_group_size.data[0u] = 16;
local_group_size.data[1u] = 4;
local_group_size.data[2u] = 1;
} else {
local_group_size.data[0u] = 8;
local_group_size.data[1u] = 8;
local_group_size.data[2u] = 1;
}
}
return local_group_size;
}

} // namespace vulkan
} // namespace native
} // namespace at
87 changes: 87 additions & 0 deletions backends/vulkan/runtime/graph/ops/OpUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/api.h>

namespace at {
namespace native {
namespace vulkan {

/*
* Maps a semantic dimension name to an integer that corresponds to its
* innermost ordering in a 4D tensor in NCHW format. Width is the innermost
* dimension, so it corresponds to 1, height is the next innermost, so it
* corresponds to 2, and so on.
*/
struct Dim4D {
static constexpr uint32_t Width = 1u;
static constexpr uint32_t Height = 2u;
static constexpr uint32_t Channel = 3u;
static constexpr uint32_t Batch = 4u;
};

/*
* Semantic dimension names for a 1D tensor
*/
struct Dim1D {
static constexpr uint32_t Length = 1u;
};

/*
* Semantic dimension names for a 2D Convolution kernel.
*/
struct DimConv2DKernel {
static constexpr uint32_t Width = 1u;
static constexpr uint32_t Height = 2u;
static constexpr uint32_t InChannels = 3u;
static constexpr uint32_t OutChannels = 4u;
};

/*
* The same as the above, except for a 2D Transposed Convolution kernel.
*/
struct DimTConv2DKernel {
static constexpr uint32_t Width = 1u;
static constexpr uint32_t Height = 2u;
static constexpr uint32_t OutChannels = 3u;
static constexpr uint32_t InChannels = 4u;
};

/*
* The functions below safely return the size of the dimension at the N-th
* innermost index. If the dimensionality of the size array is not sufficient
* then 1 will be returned. The structs above are intended to be used with
* these functions.
*/
template <uint32_t N>
uint32_t dim_at(const std::vector<int64_t>& sizes) {
const uint32_t dims = sizes.size();
return dims < N ? 1 : api::utils::safe_downcast<uint32_t>(sizes[dims - N]);
}

template <uint32_t N>
uint32_t dim_at(const vTensor& v_in) {
return dim_at<N>(v_in.sizes());
}

/*
* For most global work group sizes, returns {4, 4, 4}, but adjusts the size for
* 2D global work group sizes. Always maintains a total of 64 invocations
*/
api::utils::uvec3 adaptive_work_group_size(
const api::utils::uvec3& global_work_group);

} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */
3 changes: 1 addition & 2 deletions backends/vulkan/runtime/graph/ops/StagingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@

#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>

#include <ATen/native/vulkan/impl/Common.h>

namespace at {
namespace native {
namespace vulkan {
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>

namespace at {
namespace native {
namespace vulkan {
Expand Down
2 changes: 0 additions & 2 deletions backends/vulkan/runtime/graph/ops/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/impl/Common.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

namespace at {
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h>

#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

namespace at {
Expand Down
2 changes: 0 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/Arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/impl/Arithmetic.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>
Expand Down
3 changes: 1 addition & 2 deletions backends/vulkan/runtime/graph/ops/impl/Staging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>

#include <ATen/native/vulkan/impl/Common.h>

namespace at {
namespace native {
namespace vulkan {
Expand Down
1 change: 0 additions & 1 deletion backends/vulkan/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def define_common_targets():
"@EXECUTORCH_CLIENTS",
],
exported_deps = [
"//caffe2:torch_vulkan_ops",
"//caffe2:torch_vulkan_spv",
],
define_static_target = False,
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <ATen/native/vulkan/api/api.h>

#include <ATen/native/vulkan/impl/Arithmetic.h>
#include <ATen/native/vulkan/impl/Common.h>
#include <ATen/native/vulkan/impl/Packing.h>

Expand Down