Commit 8a06828

[ET-VK][EZ] Split Graph.h classes into multiple files
I always get lost finding the right class within `Graph.h/cpp`, so we split them into:

- `ComputeGraph.h/cpp`
- `ExecuteNode.h`
- `PrepackNode.h`
- `SharedObject.h/cpp`

and move `ValueRef`/`IOValueRef` into `Value.h`.

Differential Revision: [D54272392](https://our.internmc.facebook.com/intern/diff/D54272392/)

ghstack-source-id: 216691329
Pull Request resolved: #2150
1 parent 4e77258 commit 8a06828

13 files changed (+244, -142 lines)

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -6,7 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */

-#include <executorch/backends/vulkan/runtime/graph/Graph.h>
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
 #include <executorch/backends/vulkan/runtime/graph/OperatorRegistry.h>

 #include <executorch/backends/vulkan/runtime/VulkanDelegateHeader.h>
```

backends/vulkan/runtime/graph/Graph.cpp renamed to backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 1 addition & 63 deletions
```diff
@@ -6,76 +6,14 @@
  * LICENSE file in the root directory of this source tree.
  */

-#include <executorch/backends/vulkan/runtime/graph/Graph.h>
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

 #include <executorch/backends/vulkan/runtime/graph/ops/Staging.h>

 namespace at {
 namespace native {
 namespace vulkan {

-//
-// SharedObject
-//
-
-void SharedObject::add_user(ComputeGraph* const graph, const ValueRef idx) {
-  vTensor& t = graph->get_val(idx).toTensor();
-
-  //
-  // Aggregate Memory Requirements
-  //
-
-  const VkMemoryRequirements mem_reqs = t.get_memory_requirements();
-  aggregate_memory_requirements.size =
-      std::max(mem_reqs.size, aggregate_memory_requirements.size);
-  aggregate_memory_requirements.alignment =
-      std::max(mem_reqs.alignment, aggregate_memory_requirements.alignment);
-  aggregate_memory_requirements.memoryTypeBits |= mem_reqs.memoryTypeBits;
-
-  //
-  // Aggregate Allocation Create Info
-  //
-
-  const VmaAllocationCreateInfo create_info = t.get_allocation_create_info();
-  // Clear out CREATE_STRATEGY bit flags in case of conflict
-  VmaAllocationCreateFlags clear_mask = ~VMA_ALLOCATION_CREATE_STRATEGY_MASK;
-  VmaAllocationCreateFlags create_flags = create_info.flags & clear_mask;
-  // Use the default allocation strategy
-  aggregate_create_info.flags = create_flags | api::DEFAULT_ALLOCATION_STRATEGY;
-
-  // Set the usage flag if it is currently not set
-  if (aggregate_create_info.usage == VMA_MEMORY_USAGE_UNKNOWN) {
-    aggregate_create_info.usage = create_info.usage;
-  }
-  // Otherwise check that there is no conflict regarding usage
-  VK_CHECK_COND(aggregate_create_info.usage == create_info.usage);
-  aggregate_create_info.requiredFlags |= create_info.requiredFlags;
-  aggregate_create_info.preferredFlags |= create_info.preferredFlags;
-
-  users.emplace_back(idx);
-}
-
-void SharedObject::allocate(ComputeGraph* const graph) {
-  if (aggregate_memory_requirements.size == 0) {
-    return;
-  }
-  allocation = graph->context()->adapter_ptr()->vma().create_allocation(
-      aggregate_memory_requirements, aggregate_create_info);
-}
-
-void SharedObject::bind_users(ComputeGraph* const graph) {
-  if (users.empty()) {
-    return;
-  }
-  for (const ValueRef idx : users) {
-    graph->get_val(idx).toTensor().bind_allocation(allocation);
-  }
-}
-
-//
-// ComputeGraph
-//
-
 ComputeGraph::ComputeGraph(GraphConfig config)
     : config_{config},
       context_{new api::Context(
```

backends/vulkan/runtime/graph/Graph.h renamed to backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 3 additions & 73 deletions
```diff
@@ -17,85 +17,15 @@
 #include <ATen/native/vulkan/api/Types.h>

 #include <executorch/backends/vulkan/runtime/graph/Config.h>
+#include <executorch/backends/vulkan/runtime/graph/ExecuteNode.h>
+#include <executorch/backends/vulkan/runtime/graph/PrepackNode.h>
+#include <executorch/backends/vulkan/runtime/graph/SharedObject.h>
 #include <executorch/backends/vulkan/runtime/graph/Value.h>

 namespace at {
 namespace native {
 namespace vulkan {

-using ValueRef = int32_t;
-
-struct IOValueRef {
-  ValueRef value;
-  ValueRef staging;
-};
-
-class ComputeGraph;
-
-/*
- * Represents a single prepacking op in a ML model. In graph mode, ops will be
- * implemented in a derived class that implements encode, which will implement
- * encoding of shaders transferring necessary data (such as weights and biases)
- * to the GPU.
- */
-class PrepackNode {
-  friend class ComputeGraph;
-
- public:
-  PrepackNode(ValueRef tref, ValueRef packed) : tref_{tref}, packed_{packed} {}
-
-  virtual ~PrepackNode() = default;
-
- protected:
-  ValueRef tref_;
-  ValueRef packed_;
-
- public:
-  virtual void encode(ComputeGraph* graph) const = 0;
-};
-
-/*
- * Represents a single execution op in a ML model. In graph mode, ops will be
- * implemented in a derived class that implements encode, which will implement
- * encoding of the shader corresponding to the op into the command buffer of a
- * ComputeGraph.
- */
-class ExecuteNode {
-  friend class ComputeGraph;
-
- public:
-  ExecuteNode(ValueRef input, ValueRef output)
-      : inputs_{input}, outputs_{output} {}
-  ExecuteNode(
-      const std::vector<ValueRef>& inputs,
-      const std::vector<ValueRef>& outputs)
-      : inputs_(inputs), outputs_(outputs) {}
-
-  virtual ~ExecuteNode() = default;
-
- protected:
-  std::vector<ValueRef> inputs_;
-  std::vector<ValueRef> outputs_;
-
- public:
-  virtual void encode(ComputeGraph* graph) const = 0;
-};
-
-struct SharedObject {
-  friend class ComputeGraph;
-
-  explicit SharedObject() = default;
-
-  VkMemoryRequirements aggregate_memory_requirements;
-  VmaAllocationCreateInfo aggregate_create_info;
-  std::vector<ValueRef> users;
-  api::MemoryAllocation allocation;
-
-  void add_user(ComputeGraph* const graph, const ValueRef idx);
-  void allocate(ComputeGraph* const graph);
-  void bind_users(ComputeGraph* const graph);
-};
-
 /*
  * This is the core data structure used to execute Vulkan models in graph mode.
  * As opposed to ATen/eager mode where a command buffer is encoded every
```
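The `ValueRef` alias and `IOValueRef` struct deleted above are moved rather than removed; per the commit message they now live in `Value.h`, whose diff is not shown in this view. Reproduced from the deleted lines, the relocated definitions look roughly like this:

```cpp
// Sketch of the definitions relocated to Value.h, reproduced from the lines
// deleted above; this is not the full Value.h diff.
#include <cstdint>

namespace at {
namespace native {
namespace vulkan {

// Index referring to a value owned by a ComputeGraph.
using ValueRef = int32_t;

// Pairs a graph value with the staging value used to move its data to/from the host.
struct IOValueRef {
  ValueRef value;
  ValueRef staging;
};

} // namespace vulkan
} // namespace native
} // namespace at
```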
backends/vulkan/runtime/graph/ExecuteNode.h

Lines changed: 56 additions & 0 deletions

```diff
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#ifdef USE_VULKAN_API
+
+#include <ATen/native/vulkan/api/Context.h>
+#include <ATen/native/vulkan/api/Tensor.h>
+#include <ATen/native/vulkan/api/Types.h>
+
+#include <executorch/backends/vulkan/runtime/graph/Value.h>
+
+namespace at {
+namespace native {
+namespace vulkan {
+
+class ComputeGraph;
+
+/*
+ * Represents a single execution op in a ML model. In graph mode, ops will be
+ * implemented in a derived class that implements encode, which will implement
+ * encoding of the shader corresponding to the op into the command buffer of a
+ * ComputeGraph.
+ */
+class ExecuteNode {
+  friend class ComputeGraph;
+
+ public:
+  ExecuteNode(ValueRef input, ValueRef output)
+      : inputs_{input}, outputs_{output} {}
+  ExecuteNode(
+      const std::vector<ValueRef>& inputs,
+      const std::vector<ValueRef>& outputs)
+      : inputs_(inputs), outputs_(outputs) {}
+
+  virtual ~ExecuteNode() = default;
+
+ protected:
+  std::vector<ValueRef> inputs_;
+  std::vector<ValueRef> outputs_;
+
+ public:
+  virtual void encode(ComputeGraph* graph) const = 0;
+};
+
+} // namespace vulkan
+} // namespace native
+} // namespace at
+
+#endif /* USE_VULKAN_API */
```
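To show how this header is meant to be used, here is a hypothetical subclass; `AddNode` and its `encode` body are illustrative assumptions, not code from this commit:

```cpp
// Hypothetical example (not part of this commit): an op implemented as an
// ExecuteNode subclass. encode() is where the op would record its shader
// dispatch into the ComputeGraph's command buffer.
class AddNode : public ExecuteNode {
 public:
  AddNode(ValueRef in_a, ValueRef in_b, ValueRef out)
      : ExecuteNode({in_a, in_b}, {out}) {}

  void encode(ComputeGraph* graph) const override {
    // A real op would look up the tensors behind inputs_/outputs_ via the
    // graph and dispatch its compute shader through graph->context();
    // omitted here because the shader-dispatch API is outside this diff.
    (void)graph;
  }
};
```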

backends/vulkan/runtime/graph/Functions.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -10,7 +10,7 @@

 #ifdef USE_VULKAN_API

-#include <executorch/backends/vulkan/runtime/graph/Graph.h>
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

 namespace at {
 namespace native {
```

backends/vulkan/runtime/graph/OperatorRegistry.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -10,7 +10,7 @@

 #ifdef USE_VULKAN_API

-#include <executorch/backends/vulkan/runtime/graph/Graph.h>
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

 #include <functional>
 #include <unordered_map>
```
backends/vulkan/runtime/graph/PrepackNode.h

Lines changed: 51 additions & 0 deletions

```diff
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#ifdef USE_VULKAN_API
+
+#include <ATen/native/vulkan/api/Context.h>
+#include <ATen/native/vulkan/api/Tensor.h>
+#include <ATen/native/vulkan/api/Types.h>
+
+#include <executorch/backends/vulkan/runtime/graph/Value.h>
+
+namespace at {
+namespace native {
+namespace vulkan {
+
+class ComputeGraph;
+
+/*
+ * Represents a single prepacking op in a ML model. In graph mode, ops will be
+ * implemented in a derived class that implements encode, which will implement
+ * encoding of shaders transferring necessary data (such as weights and biases)
+ * to the GPU.
+ */
+class PrepackNode {
+  friend class ComputeGraph;
+
+ public:
+  PrepackNode(ValueRef tref, ValueRef packed) : tref_{tref}, packed_{packed} {}
+
+  virtual ~PrepackNode() = default;
+
+ protected:
+  ValueRef tref_;
+  ValueRef packed_;
+
+ public:
+  virtual void encode(ComputeGraph* graph) const = 0;
+};
+
+} // namespace vulkan
+} // namespace native
+} // namespace at
+
+#endif /* USE_VULKAN_API */
```
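Analogously to `ExecuteNode`, a prepack op would subclass `PrepackNode`; the `WeightPrepackNode` name and its `encode` body below are illustrative assumptions, not code from this commit:

```cpp
// Hypothetical example (not part of this commit): a prepack op implemented as
// a PrepackNode subclass. encode() is where the transfer of the source tensor
// data (tref_) into its packed GPU representation (packed_) would be recorded.
class WeightPrepackNode : public PrepackNode {
 public:
  WeightPrepackNode(ValueRef tref, ValueRef packed)
      : PrepackNode(tref, packed) {}

  void encode(ComputeGraph* graph) const override {
    // A real op would record the staging copy / packing shader for
    // tref_ -> packed_ here; the staging API lives outside this diff.
    (void)graph;
  }
};
```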
backends/vulkan/runtime/graph/SharedObject.cpp

Lines changed: 73 additions & 0 deletions

```diff
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/SharedObject.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
+
+namespace at {
+namespace native {
+namespace vulkan {
+
+void SharedObject::add_user(ComputeGraph* const graph, const ValueRef idx) {
+  vTensor& t = graph->get_val(idx).toTensor();
+
+  //
+  // Aggregate Memory Requirements
+  //
+
+  const VkMemoryRequirements mem_reqs = t.get_memory_requirements();
+  aggregate_memory_requirements.size =
+      std::max(mem_reqs.size, aggregate_memory_requirements.size);
+  aggregate_memory_requirements.alignment =
+      std::max(mem_reqs.alignment, aggregate_memory_requirements.alignment);
+  aggregate_memory_requirements.memoryTypeBits |= mem_reqs.memoryTypeBits;
+
+  //
+  // Aggregate Allocation Create Info
+  //
+
+  const VmaAllocationCreateInfo create_info = t.get_allocation_create_info();
+  // Clear out CREATE_STRATEGY bit flags in case of conflict
+  VmaAllocationCreateFlags clear_mask = ~VMA_ALLOCATION_CREATE_STRATEGY_MASK;
+  VmaAllocationCreateFlags create_flags = create_info.flags & clear_mask;
+  // Use the default allocation strategy
+  aggregate_create_info.flags = create_flags | api::DEFAULT_ALLOCATION_STRATEGY;
+
+  // Set the usage flag if it is currently not set
+  if (aggregate_create_info.usage == VMA_MEMORY_USAGE_UNKNOWN) {
+    aggregate_create_info.usage = create_info.usage;
+  }
+  // Otherwise check that there is no conflict regarding usage
+  VK_CHECK_COND(aggregate_create_info.usage == create_info.usage);
+  aggregate_create_info.requiredFlags |= create_info.requiredFlags;
+  aggregate_create_info.preferredFlags |= create_info.preferredFlags;
+
+  users.emplace_back(idx);
+}
+
+void SharedObject::allocate(ComputeGraph* const graph) {
+  if (aggregate_memory_requirements.size == 0) {
+    return;
+  }
+  allocation = graph->context()->adapter_ptr()->vma().create_allocation(
+      aggregate_memory_requirements, aggregate_create_info);
+}
+
+void SharedObject::bind_users(ComputeGraph* const graph) {
+  if (users.empty()) {
+    return;
+  }
+  for (const ValueRef idx : users) {
+    graph->get_val(idx).toTensor().bind_allocation(allocation);
+  }
+}
+
+} // namespace vulkan
+} // namespace native
+} // namespace at
```
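The three methods above imply a lifecycle for `SharedObject`: register each tensor that will share the memory, allocate once against the aggregated requirements, then bind every user. A minimal sketch of that flow, assuming a graph and two existing tensor values; the helper function itself is hypothetical:

```cpp
// Hypothetical usage of SharedObject's interface; graph, t1, and t2 are
// assumed to already exist. Not code from this commit.
void share_memory_between(ComputeGraph* graph, ValueRef t1, ValueRef t2) {
  // Value-initialize so the aggregate Vk/Vma structs start zeroed.
  SharedObject so{};

  // Fold each tensor's memory requirements and allocation-create info into
  // the aggregate and record it as a user.
  so.add_user(graph, t1);
  so.add_user(graph, t2);

  // Create a single VMA allocation sized and aligned for the largest user.
  so.allocate(graph);

  // Bind every registered tensor to the shared allocation.
  so.bind_users(graph);
}
```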
