Skip to content

Commit c6896d9

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Move graph runtime from PT directory to ET directory
Summary: ## Context Move Vulkan graph runtime from PyTorch directory to ExecuTorch directory to improve development logistics: * ExecuTorch delegate changes will no longer require export to PyTorch directory * Makes it much easier to enable OSS build for Vulkan delegate Reviewed By: shoumikhin Differential Revision: D54133350
1 parent 78ce089 commit c6896d9

21 files changed

+2102
-3
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <ATen/native/vulkan/graph/Graph.h>
10-
#include <ATen/native/vulkan/graph/OperatorRegistry.h>
9+
#include <executorch/backends/vulkan/runtime/graph/Graph.h>
10+
#include <executorch/backends/vulkan/runtime/graph/OperatorRegistry.h>
1111

1212
#include <executorch/backends/vulkan/runtime/VulkanDelegateHeader.h>
1313
#include <executorch/backends/vulkan/schema_generated.h>
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#ifdef USE_VULKAN_API
12+
13+
#include <ATen/native/vulkan/api/Context.h>
14+
15+
namespace at {
16+
namespace native {
17+
namespace vulkan {
18+
19+
struct GraphConfig final {
20+
api::ContextConfig contextConfig;
21+
};
22+
23+
} // namespace vulkan
24+
} // namespace native
25+
} // namespace at
26+
27+
#endif /* USE_VULKAN_API */
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/Constant.h>
10+
11+
namespace at {
12+
namespace native {
13+
namespace vulkan {
14+
15+
TensorRef::TensorRef(
16+
const std::vector<int64_t>& t_sizes,
17+
api::ScalarType t_dtype,
18+
const void* const t_data)
19+
: sizes{}, dtype{t_dtype}, data{t_data} {
20+
size_t ndim = t_sizes.size();
21+
sizes.resize(ndim);
22+
for (int i = 0; i < ndim; ++i) {
23+
sizes[i] = t_sizes.at(i);
24+
}
25+
}
26+
27+
} // namespace vulkan
28+
} // namespace native
29+
} // namespace at
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#ifdef USE_VULKAN_API
12+
13+
#include <ATen/native/vulkan/api/Context.h>
14+
15+
namespace at {
16+
namespace native {
17+
namespace vulkan {
18+
19+
/*
20+
* Represents a reference to a tensor that has been serialized with the model,
21+
* such as a serialized weight tensor. It contains some metadata as well as a
22+
* raw pointer to the data of the tensor, which is assumed to be contiguous.
23+
*/
24+
struct TensorRef final {
25+
std::vector<int64_t> sizes;
26+
api::ScalarType dtype;
27+
const void* data;
28+
29+
explicit TensorRef(
30+
const std::vector<int64_t>& t_sizes,
31+
api::ScalarType t_dtype,
32+
const void* const t_data);
33+
34+
TensorRef(const TensorRef&) = default;
35+
TensorRef& operator=(const TensorRef&) = default;
36+
37+
TensorRef(TensorRef&&) = default;
38+
TensorRef& operator=(TensorRef&&) = default;
39+
};
40+
41+
} // namespace vulkan
42+
} // namespace native
43+
} // namespace at
44+
45+
#endif /* USE_VULKAN_API */
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <ATen/native/vulkan/impl/Arithmetic.h>
10+
#include <ATen/native/vulkan/impl/Common.h>
11+
12+
#include <executorch/backends/vulkan/runtime/graph/Functions.h>
13+
14+
#include <executorch/backends/vulkan/runtime/graph/ops/Arithmetic.h>
15+
16+
namespace at {
17+
namespace native {
18+
namespace vulkan {
19+
20+
#define DEFINE_ARITHMETIC_FN(function, op_type) \
21+
ValueRef function(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
22+
return add_arithmetic_node( \
23+
graph, \
24+
args[0], \
25+
args[1], \
26+
args[2], \
27+
arithmetic::OpType::op_type, \
28+
args[3]); \
29+
}
30+
31+
DEFINE_ARITHMETIC_FN(add, ADD);
32+
DEFINE_ARITHMETIC_FN(sub, SUB);
33+
DEFINE_ARITHMETIC_FN(mul, MUL);
34+
DEFINE_ARITHMETIC_FN(div, DIV);
35+
DEFINE_ARITHMETIC_FN(floor_div, FLOOR_DIV);
36+
DEFINE_ARITHMETIC_FN(pow, POW);
37+
38+
} // namespace vulkan
39+
} // namespace native
40+
} // namespace at
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#ifdef USE_VULKAN_API
12+
13+
#include <executorch/backends/vulkan/runtime/graph/Graph.h>
14+
15+
namespace at {
16+
namespace native {
17+
namespace vulkan {
18+
19+
#define DEFINE_OP_FN(name) \
20+
ValueRef name(ComputeGraph& graph, const std::vector<ValueRef>& args);
21+
22+
DEFINE_OP_FN(add);
23+
DEFINE_OP_FN(sub);
24+
DEFINE_OP_FN(mul);
25+
DEFINE_OP_FN(div);
26+
DEFINE_OP_FN(floor_div);
27+
DEFINE_OP_FN(pow);
28+
29+
} // namespace vulkan
30+
} // namespace native
31+
} // namespace at
32+
33+
#endif /* USE_VULKAN_API */

0 commit comments

Comments
 (0)